In [2]:
%load_ext autoreload
%autoreload 2
from alphatoe import evals, data, game, train
import torch
from torch.nn.functional import cross_entropy
from tqdm import tqdm
from transformer_lens import HookedTransformerConfig, HookedTransformer
import json
from transformer_lens import HookedTransformer, HookedTransformerConfig

In [3]:
_, game_list = data.gen_games("all minimax")

ValueError: gametype must be one of 'all', 'strat', or 'all minimax'. Not all minimax

In [None]:
game_list[-2]

tensor([10,  8,  4,  7,  6,  2,  5,  3,  0,  1,  9])

In [4]:
mm_first = game.get_all_minimax_games([game.Board()], True)
mm_second =  game.get_all_minimax_games([game.Board()], False)

KeyboardInterrupt: 

In [None]:
print(len(mm_first))
print(len(mm_second))

31040
9440


In [6]:
cfg = HookedTransformerConfig(
    n_layers = 2,
    n_heads = 4,
    d_model = 16,
    d_head = 4,
    d_mlp = 64,
    act_fn = "relu",
    #normalization_type=None,
    normalization_type='LN',
    d_vocab=11,
    d_vocab_out=10,
    n_ctx=10,
    init_weights=True,
    device="cuda",
    seed = 1337,
)

lr = 1e-5
weight_decay = 1e-4
test_train_split = 0.7
epochs = 10_000
batch_size = 4096 * 2 * 2
minimax_is_first = True

In [None]:
model = HookedTransformer(cfg).to(cfg.device)

Moving model to device:  cuda


In [7]:
minimax_first = data.gen_data("minimax first", test_train_split)
minimax_second= data.gen_data("minimax second", test_train_split)

Generating minimax vs all...
Generated 31040 games
Generated array of moves
torch.Size([31040, 10])
Generated data and labels
One hot encoded labels
Generating minimax vs all...
Generated 9440 games
Generated array of moves
torch.Size([9440, 10])
Generated data and labels
One hot encoded labels


In [10]:
minimax_first[1].shape

torch.Size([21728, 10, 10])

In [None]:
loss_fn = cross_entropy
optimizer =  torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)

In [None]:
def remove_alternating_indices(t: torch.Tensor, dim: int=0, odd:bool = True) -> torch.Tensor:
    indices = [index for index in range(t.shape[1]) if (index + odd) % 2 != 0 ]
    indices = torch.tensor(indices).to(t.device)
    return torch.index_select(t, 1, indices)

In [None]:
for epoch in range(epochs):
    if minimax_is_first: train_data, train_labels, test_data, test_labels = minimax_first
    else:
        train_data, train_labels, test_data, test_labels = minimax_second

    train_data.to(cfg.device)
    train_labels.to(cfg.device)
    test_data.to(cfg.device)
    test_labels.to(cfg.device)

    for batch in range(0, len(train_data), batch_size):
        input_batch = train_data[batch : batch + batch_size]
        label_batch = train_labels[batch : batch + batch_size]

        logits_batch = model(input_batch)

        logits_batch = remove_alternating_indices(logits_batch, odd = not minimax_is_first).to(cfg.device)
        label_batch = remove_alternating_indices(label_batch, odd = not minimax_is_first).to(cfg.device)
        train_loss = loss_fn(train.rearrange(logits_batch), train.rearrange(label_batch))

        train_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        with torch.inference_mode():
            # test inference runs for every update on the whole test set
            test_logits = model(test_data)
            logits_batch = remove_alternating_indices(test_logits, odd = not minimax_is_first).to('cuda')
            label_batch = remove_alternating_indices(test_labels, odd = not minimax_is_first).to('cuda')
            test_loss = loss_fn(train.rearrange(logits_batch), train.rearrange(label_batch))
        
        minimax_is_first = not minimax_is_first

    print(
        f"Epoch {epoch} | Train Loss: {train_loss.item()} | Test Loss: {test_loss.item()}"
    )

Epoch 0 | Train Loss: 2.4865224361419678 | Test Loss: 2.477419137954712
Epoch 1 | Train Loss: 2.5231146812438965 | Test Loss: 2.5100507736206055
Epoch 2 | Train Loss: 2.4725170135498047 | Test Loss: 2.464345932006836
Epoch 3 | Train Loss: 2.512850522994995 | Test Loss: 2.4998228549957275
Epoch 4 | Train Loss: 2.4594502449035645 | Test Loss: 2.4515063762664795
Epoch 5 | Train Loss: 2.502655506134033 | Test Loss: 2.489680051803589
Epoch 6 | Train Loss: 2.4466352462768555 | Test Loss: 2.4388341903686523
Epoch 7 | Train Loss: 2.492593288421631 | Test Loss: 2.4796814918518066
Epoch 8 | Train Loss: 2.433976888656616 | Test Loss: 2.426286458969116
Epoch 9 | Train Loss: 2.482698917388916 | Test Loss: 2.469857931137085
Epoch 10 | Train Loss: 2.421436071395874 | Test Loss: 2.4138646125793457
Epoch 11 | Train Loss: 2.4729504585266113 | Test Loss: 2.4601500034332275
Epoch 12 | Train Loss: 2.4090347290039062 | Test Loss: 2.4015798568725586
Epoch 13 | Train Loss: 2.463355302810669 | Test Loss: 2.450

(batch seq) token

In [None]:
def remove_alternating_indices(t: torch.Tensor, dim: int=0, odd:bool = True) -> torch.Tensor:
    indices = [index for index in range(t.shape[1]) if (index + odd) % 2 != 0 ]
    indices = torch.tensor(indices).to(t.device)
    print("indices device is: ", indices.device)
    print("t device is: ", t.device)
    return torch.index_select(t, 1, indices)

In [None]:
b = remove_alternating_indices(a, odd= True)
print(b)
print(b.shape)

tensor([[1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.]])
torch.Size([10, 5])


In [None]:
board = game.Board()

In [None]:
a = evals._sample_game(model, 1)
print(a)
game.play_game(a)

[10, 3, 7, 6, 1, 4, 2, 5, 9, 9, 9]
Invalid game
|   | O | O |
| X | X | X |
| X | O |   |


In [None]:
torch.argmax(model(torch.tensor(seq))[0,-1])

tensor(9, device='cuda:0')

In [None]:
seq = [10,3,5,4,0,2,8,6]

In [None]:
board.make_move(seq[-1])
board.draw_board()

| O |   | X |
| X | X | O |
| X |   | O |


In [None]:
game.play_game(samples[123])

| X | O | O |
| X | X | X |
| X | O | O |


In [None]:
samples = evals.sample_games(model, 1, 1000)

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:11<00:00, 85.32it/s]


In [None]:
evals.eval_model(samples)

0it [00:00, ?it/s]

1000it [00:00, 22113.69it/s]


{'_check_played_repeat_moves': 0.043,
 '_check_played_after_player_victory': 0.08,
 '_check_played_after_draw_game': 0.0,
 'inappropriate_end_state': 0.392,
 '_check_if_illegal_moves': 0.509}