In [1]:
import os
import sys
import pickle
import numpy as np
import torch 
from tqdm import tqdm

os.environ["CONFIG_PATH"] = "../configs/large_model.toml"
sys.path.append("../src")

import neural_net
from configuration import config
from neural_net import NeuralNet, ResidualBlock
import player_pov_helpers

""

Loaded config from file: ../configs/large_model.toml
Loading file: piece_indices
Loading file: rotation_mapping
Loading file: new_occupieds
Loading file: moves_ruled_out_for_all
Loading file: scores
Loading file: moves_ruled_out_for_player
Loading file: moves_enabled_for_player
Loading file: new_adjacents
Loading file: new_corners


''

In [2]:
BOARD_SIZE = config()["game"]["board_size"]

In [3]:
def model_size(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum(p.numel() for p in model_parameters)
    # return sum([np.prod(p.size()) for p in model_parameters])
    # return sum(p.numel() for p in model.parameters())

In [4]:
import time

def time_per_eval(num_evaluations, batch_size, model):
    num_batches = num_evaluations // batch_size
    num_evaluations = num_batches * batch_size

    random_arrays = np.random.random((num_batches, batch_size, 4, BOARD_SIZE, BOARD_SIZE))
    print("Starting...")

    start = time.perf_counter()
    for i in range(num_batches):
        occupancies = torch.from_numpy(random_arrays[i]).to(dtype=torch.float, device="mps")
        model(occupancies)
    elapsed = time.perf_counter() - start

    return elapsed / (num_batches * batch_size)

In [5]:
model = NeuralNet().to("mps")
model_size(model)
print(len(model.residual_blocks))

19


In [10]:
time_per_eval(
    num_evaluations=5000,
    batch_size=500,
    model=model,
)

Starting...


0.0003070897334022447

In [6]:
# Time per move:
times_per_eval = {}
for batch_size in tqdm(range(10, 600, 10)):
    times_per_eval[batch_size] = time_per_eval(
        num_evaluations=batch_size * 20,
        batch_size=batch_size,
        model=model,
    )

  0%|          | 0/59 [00:00<?, ?it/s]

Starting...


  2%|▏         | 1/59 [00:00<00:18,  3.11it/s]

Starting...


  3%|▎         | 2/59 [00:00<00:17,  3.32it/s]

Starting...


  5%|▌         | 3/59 [00:00<00:17,  3.14it/s]

Starting...


  7%|▋         | 4/59 [00:01<00:19,  2.88it/s]

Starting...


  8%|▊         | 5/59 [00:01<00:20,  2.60it/s]

Starting...


 10%|█         | 6/59 [00:02<00:22,  2.34it/s]

Starting...


 12%|█▏        | 7/59 [00:02<00:24,  2.16it/s]

Starting...


 14%|█▎        | 8/59 [00:03<00:25,  1.97it/s]

Starting...


 15%|█▌        | 9/59 [00:04<00:27,  1.80it/s]

Starting...


 17%|█▋        | 10/59 [00:04<00:29,  1.67it/s]

Starting...


 19%|█▊        | 11/59 [00:05<00:30,  1.55it/s]

Starting...


 20%|██        | 12/59 [00:06<00:32,  1.45it/s]

Starting...


 22%|██▏       | 13/59 [00:07<00:34,  1.34it/s]

Starting...


 24%|██▎       | 14/59 [00:08<00:36,  1.24it/s]

Starting...


 25%|██▌       | 15/59 [00:09<00:37,  1.16it/s]

Starting...


 27%|██▋       | 16/59 [00:10<00:39,  1.09it/s]

Starting...


 29%|██▉       | 17/59 [00:11<00:41,  1.02it/s]

Starting...


 31%|███       | 18/59 [00:12<00:42,  1.03s/it]

Starting...


 32%|███▏      | 19/59 [00:13<00:43,  1.09s/it]

Starting...


 34%|███▍      | 20/59 [00:15<00:45,  1.15s/it]

Starting...


 36%|███▌      | 21/59 [00:16<00:46,  1.23s/it]

Starting...


 37%|███▋      | 22/59 [00:17<00:48,  1.30s/it]

Starting...


 39%|███▉      | 23/59 [00:19<00:49,  1.36s/it]

Starting...


 41%|████      | 24/59 [00:21<00:50,  1.46s/it]

Starting...


 42%|████▏     | 25/59 [00:22<00:52,  1.56s/it]

Starting...


 44%|████▍     | 26/59 [00:24<00:56,  1.72s/it]

Starting...


 46%|████▌     | 27/59 [00:27<01:01,  1.93s/it]

Starting...


 47%|████▋     | 28/59 [00:29<01:01,  1.98s/it]

Starting...


 49%|████▉     | 29/59 [00:31<01:00,  2.00s/it]

Starting...


 51%|█████     | 30/59 [00:33<00:59,  2.07s/it]

Starting...


 53%|█████▎    | 31/59 [00:35<00:59,  2.11s/it]

Starting...


 54%|█████▍    | 32/59 [00:38<00:57,  2.12s/it]

Starting...


 56%|█████▌    | 33/59 [00:41<01:02,  2.39s/it]

Starting...


 58%|█████▊    | 34/59 [00:43<00:59,  2.40s/it]

Starting...


 59%|█████▉    | 35/59 [00:46<01:02,  2.60s/it]

Starting...


 61%|██████    | 36/59 [00:48<00:57,  2.52s/it]

Starting...


 63%|██████▎   | 37/59 [00:52<01:00,  2.77s/it]

Starting...


 64%|██████▍   | 38/59 [00:59<01:28,  4.22s/it]

Starting...


 66%|██████▌   | 39/59 [01:06<01:38,  4.91s/it]

Starting...


 68%|██████▊   | 40/59 [01:09<01:20,  4.24s/it]

Starting...


 69%|██████▉   | 41/59 [01:11<01:08,  3.78s/it]

Starting...


 71%|███████   | 42/59 [01:14<00:58,  3.47s/it]

Starting...


 73%|███████▎  | 43/59 [01:17<00:52,  3.26s/it]

Starting...


 75%|███████▍  | 44/59 [01:20<00:46,  3.13s/it]

Starting...


 76%|███████▋  | 45/59 [01:23<00:43,  3.08s/it]

Starting...


 78%|███████▊  | 46/59 [01:26<00:39,  3.06s/it]

Starting...


 80%|███████▉  | 47/59 [01:29<00:36,  3.07s/it]

Starting...


 81%|████████▏ | 48/59 [01:32<00:33,  3.08s/it]

Starting...


 83%|████████▎ | 49/59 [01:35<00:31,  3.11s/it]

Starting...


 85%|████████▍ | 50/59 [01:38<00:28,  3.15s/it]

Starting...


 86%|████████▋ | 51/59 [01:41<00:25,  3.19s/it]

Starting...


 88%|████████▊ | 52/59 [01:45<00:22,  3.25s/it]

Starting...


 90%|████████▉ | 53/59 [01:48<00:19,  3.31s/it]

Starting...


 92%|█████████▏| 54/59 [01:52<00:16,  3.37s/it]

Starting...


 93%|█████████▎| 55/59 [01:55<00:13,  3.42s/it]

Starting...


 95%|█████████▍| 56/59 [01:59<00:10,  3.47s/it]

Starting...


 97%|█████████▋| 57/59 [02:03<00:07,  3.53s/it]

Starting...


 98%|█████████▊| 58/59 [02:06<00:03,  3.61s/it]

Starting...


100%|██████████| 59/59 [02:10<00:00,  2.22s/it]


In [7]:
size_distribution = {0: 0, 10: 0, 20: 0, 30: 0, 40: 2, 50: 0, 60: 0, 70: 2, 80: 5, 90: 0, 100: 0, 110: 8, 120: 8, 130: 9, 140: 10, 150: 16, 160: 25, 170: 23, 180: 7, 190: 0, 200: 3, 210: 1, 220: 2, 230: 6, 240: 2, 250: 8, 260: 13, 270: 7, 280: 0, 290: 18, 300: 6, 310: 2, 320: 4, 330: 2, 340: 0, 350: 2, 360: 0, 370: 4, 380: 8, 390: 8, 400: 6, 410: 4, 420: 4, 430: 3, 440: 3, 450: 1, 460: 0, 470: 2, 480: 1, 490: 0, 500: 0, 510: 1, 520: 0, 530: 0, 540: 0, 550: 0, 560: 0, 570: 0, 580: 0, 590: 0, 600: 0, 610: 0, 620: 0, 630: 0, 640: 0, 650: 0, 660: 0, 670: 0, 680: 0, 690: 0, 700: 0, 710: 0, 720: 0, 730: 0, 740: 0, 750: 0, 760: 0, 770: 0, 780: 0, 790: 0, 800: 0, 810: 0, 820: 0, 830: 0, 840: 0, 850: 0, 860: 0, 870: 0, 880: 0, 890: 0, 900: 0, 910: 0, 920: 0, 930: 0, 940: 0, 950: 0, 960: 0, 970: 0, 980: 0, 990: 0}

total_time = 0
num_evaluations = 0
for size in size_distribution:
    if size_distribution[size] == 0:
        continue
    total_time += size_distribution[size] * times_per_eval[size]
    num_evaluations += size_distribution[size]
print(total_time / num_evaluations)

0.0003870881047853177


In [8]:
time_per_eval(
    num_evaluations=5000,
    batch_size=100,
    model=model,
)

Starting...


0.0003615001916012261

In [9]:
times_per_eval

{10: 0.001602485000039451,
 20: 0.0007163511449471117,
 30: 0.0005632886816844499,
 40: 0.00048660171876690583,
 50: 0.0004494004999869503,
 60: 0.0004214411108356823,
 70: 0.00038294005929076643,
 80: 0.0003738480731226446,
 90: 0.0003639423844449791,
 100: 0.0003480814164940966,
 110: 0.0003374426704513925,
 120: 0.00032602670124712557,
 130: 0.00034013450346090114,
 140: 0.0003350290178579079,
 150: 0.0003281990833347663,
 160: 0.0003244789715608931,
 170: 0.00032740600500427025,
 180: 0.00032284310194275655,
 190: 0.00031721354184188196,
 200: 0.0003257472394980141,
 210: 0.00033658351190665384,
 220: 0.00032906508522997186,
 230: 0.00032740791673701176,
 240: 0.00034571629333489303,
 250: 0.0003564988666039426,
 260: 0.00040498254019229743,
 270: 0.00044534993833318973,
 280: 0.00037290672607403916,
 290: 0.0003526340589651854,
 300: 0.00036806953483513403,
 310: 0.00035447608871054986,
 320: 0.0003311281185960979,
 330: 0.0004576303346998016,
 340: 0.0003531430085294072,
 350: 0.

In [None]:
import time

def time_per_eval(num_evaluations, batch_size, model):
    num_batches = num_evaluations // batch_size
    num_evaluations = num_batches * batch_size

    random_arrays = np.random.random((num_batches, batch_size, 4, BOARD_SIZE, BOARD_SIZE))
    print("Starting...")

    start = time.perf_counter()
    for i in range(num_batches):
        occupancies = torch.from_numpy(random_arrays[i]).to(dtype=torch.float, device="mps")
        model(occupancies)
    elapsed = time.perf_counter() - start

    return elapsed / (num_batches * batch_size)

In [21]:
player_pov_helpers.moves_indices_to_player_pov([625], 1)

array([820])

In [23]:
NUM_MOVES = config()["game"]["num_moves"]

universal_moves_array = np.zeros((NUM_MOVES,))
universal_moves_array[625] = 1

player_moves_array = player_pov_helpers.moves_array_to_player_pov(universal_moves_array, 1)
np.argmax(player_moves_array)

820

In [44]:
player_whose_turn_it_is = 1

# In my node, index 0 represents move 123, index 1 represents move 625, index 2 represents move 1052.
# This is purely compression, nothing to do with player rotations.
array_index_to_move_index = np.array([123, 625, 1052])

# This is the argument I pass into my NN
player_pov_valid_move_indices = player_pov_helpers.moves_indices_to_player_pov(array_index_to_move_index, player_whose_turn_it_is)
print("player_pov_valid_move_indices", player_pov_valid_move_indices)

# Use the old, trusted function to determine where universal move 625 maps to in this player's POV.
universal_moves_array = np.zeros((NUM_MOVES,))
universal_moves_array[625] = 1
nn_recommended_move_index = np.argmax(player_pov_helpers.moves_array_to_player_pov(universal_moves_array, player_whose_turn_it_is))
print("nn_recommended_move_index", nn_recommended_move_index)

# My NN says that in the player POV, move 820 is best. (That's the move which move 625 becomes)
policy_that_my_nn_returned = np.zeros((NUM_MOVES,))
policy_that_my_nn_returned[nn_recommended_move_index] = 1

# In my evaluate, I use the rotated move indices to filter the policy my NN returned.
policy_that_evaluate_returns = policy_that_my_nn_returned[player_pov_valid_move_indices]

universal_policy_that_i_save_to_my_node = policy_that_evaluate_returns
universal_policy_that_i_save_to_my_node

player_pov_valid_move_indices [ 560  820 1059]
nn_recommended_move_index 820


array([0., 1., 0.])

In [39]:
player_pov_valid_move_indices

array([ 560,  820, 1059])