In [11]:
import os
import sys
import pickle
import numpy as np
import torch 

os.environ["CONFIG_PATH"] = "../configs/default.toml"
sys.path.append("../src")

import neural_net
from configuration import config
from neural_net import NeuralNet, ResidualBlock

""

''

In [12]:
BOARD_SIZE = config()["game"]["board_size"]

In [8]:
model = NeuralNet().to("mps")
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
sum([np.prod(p.size()) for p in model_parameters])

85842749

In [27]:
import time

def time_per_eval(num_evaluations, batch_size):
    num_batches = num_evaluations // batch_size
    num_evaluations = num_batches * batch_size

    random_arrays = np.random.random((num_batches, batch_size, 4, BOARD_SIZE, BOARD_SIZE))
    print("Starting...")

    start = time.perf_counter()
    for i in range(num_batches):
        occupancies = torch.from_numpy(random_arrays[i]).to(dtype=torch.float, device="mps")
        model(occupancies)
    elapsed = time.perf_counter() - start

    return elapsed / (num_batches * batch_size)

In [67]:
# Time per move:
500 * time_per_eval(
    num_evaluations=5000,
    batch_size=57,
)

Starting...


0.21635593113439305

In [4]:
import torch.nn as nn
import torch.nn.functional as F

In [5]:
nn.CrossEntropyLoss()(
    torch.tensor([[1.0, 5.0, 3.0]]),
    F.softmax(torch.tensor([[1.0, 5.0, 3.0]]), dim=1),
)

tensor(0.4411)

In [6]:
nn.CrossEntropyLoss(reduction="mean")(
    torch.tensor([
        [1.0, 5.0, 3.0],
        [1.0, 5.0, 3.0],
    ]),
    F.softmax(torch.tensor([
        [1.0, 5.0, 3.0],
        [1.0, 5.0, 3.0]
    ]), dim=1),
)

tensor(0.4411)

In [13]:
def myFunc():
    print("Starting tree search")
    results = []
    for i in range(4):
        result = yield f"board_{i}"
        results.append(result)
    print("Done with tree search. Got results:", results)

generators = {
    "generator_1": myFunc(),
    "generator_2": myFunc(),
}

# Have any new board evaluation results come in on the processing queue?
# If so, send them to the appropriate thread.
while True:
    socket_data = socket.get()
    if socket_data is None:
        # There's nothing to do -- add a new thread!
        generators["generator_3"] = myFunc()
        socket.send(next(generators["generator_3"]))

    thread_id, result = socket_data

    try:
        occupancies = generators[thread_id].send(result)
    except StopIteration as e:
        del generators[thread_id]

    socket.send(occupancies)

Starting tree search
Got board to evaluate: board_0
Got board to evaluate: board_1
Got board to evaluate: board_2
Got board to evaluate: board_3
Done with tree search. Got results: ['result_of_evaluating_board_0', 'result_of_evaluating_board_1', 'result_of_evaluating_board_2', 'result_of_evaluating_board_3']
