In [6]:
%load_ext autoreload
%autoreload 2

import torch
from alpha_connect import AlphaZeroModelConnect4, RandomAgent, state_to_supervised_input
import time
from game import ConnectState

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# generate 10000 random states
states = set()
random_agent = RandomAgent()
state = ConnectState.sample_initial_state()
while len(states) < 4096:
    if state.has_ended:
        state = ConnectState.sample_initial_state()
    states.add(state)

    action = random_agent.sample_move(state)
    state = action.sample_next_state()

In [9]:
model = AlphaZeroModelConnect4()
model.load_state_dict(torch.load("../data/latest.pth"))
model.to("mps")

batch_sizes = [4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1]
states = list(states)
for batch_size in batch_sizes:
    start_time = time.time()
    for i in range(0, len(states), batch_size):
        input_tensor = [
            state_to_supervised_input(state) for state in states[i : i + batch_size]
        ]
        input_tensor = (
            torch.stack(input_tensor)
            .type(torch.float32)
            .view(batch_size, 3, 6, 7)
            .to("mps")
        )
        a, b = model(input_tensor)
    print(
        f"Processed {len(states)} states in {time.time()-start_time} seconds with batch size {batch_size}"
    )

Processed 4096 states in 2.6952261924743652 seconds with batch size 4096
Processed 4096 states in 6.91292405128479 seconds with batch size 2048
Processed 4096 states in 2.618597984313965 seconds with batch size 1024
Processed 4096 states in 2.1275992393493652 seconds with batch size 512
Processed 4096 states in 2.1979100704193115 seconds with batch size 256
Processed 4096 states in 2.2819929122924805 seconds with batch size 128
Processed 4096 states in 2.568247079849243 seconds with batch size 64
Processed 4096 states in 3.096703052520752 seconds with batch size 32
Processed 4096 states in 5.507427930831909 seconds with batch size 16
Processed 4096 states in 10.644919157028198 seconds with batch size 8
Processed 4096 states in 21.00096321105957 seconds with batch size 4
Processed 4096 states in 41.22349309921265 seconds with batch size 2
Processed 4096 states in 83.01374006271362 seconds with batch size 1
