In [1]:
import os
from torch import utils, Tensor, abs
import random
import lightning as L
import pandas as pd
from lightning.pytorch.callbacks import EarlyStopping
import sys

sys.path.append("../")
from src.alpha_connect.supervised_model import AlphaZeroModel, MyDataset

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
data_path = os.path.join(os.getcwd(), "..", "data")
print(data_path)

/Users/alberttroussard/Documents/alpha-connect/notebooks/../data


In [4]:
model = AlphaZeroModel()

In [5]:
# load data.csv


data = pd.read_csv(os.path.join(data_path, "data.csv"))
# shuffle the data
data = data.sample(frac=1).reset_index(drop=True)

x = data.iloc[:, :-2].values

policy = data.iloc[:, -2].values
value = data.iloc[:, -1].values


# keep only the first digit of the labels
train_ratio = 0.85
validation_ratio = 0.008
x_train, x_validation, x_test = (
    x[: int(train_ratio * len(x))],
    x[int(train_ratio * len(x)) : int((train_ratio + validation_ratio) * len(x))],
    x[int((train_ratio + validation_ratio) * len(x)) :],
)
policy_train, policy_validation, policy_test = (
    [int(random.sample(str(i), 1)[0]) for i in policy][
        : int(train_ratio * len(policy))
    ],
    policy[
        int(train_ratio * len(policy)) : int(
            (train_ratio + validation_ratio) * len(policy)
        )
    ],
    policy[int((train_ratio + validation_ratio) * len(policy)) :],
)
value_train, value_validation, value_test = (
    value[: int(train_ratio * len(value))],
    value[
        int(train_ratio * len(value)) : int(
            (train_ratio + validation_ratio) * len(value)
        )
    ],
    value[int((train_ratio + validation_ratio) * len(value)) :],
)

# create a dataset
train_dataset = MyDataset(Tensor(x_train), Tensor(policy_train), Tensor(value_train))
validation_dataset = MyDataset(
    Tensor(x_validation), Tensor(policy_validation), Tensor(value_validation)
)
test_dataset = MyDataset(Tensor(x_test), Tensor(policy_test), Tensor(value_test))

# create a dataloader
train_loader = utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=9
)
validation_loader = utils.data.DataLoader(
    validation_dataset, batch_size=1, shuffle=False, num_workers=9
)
test_loader = utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False, num_workers=9
)

In [6]:
class ValidationLossCallback(L.Callback):
    def __init__(self, validation_dataloader):
        self.validation_dataloader = validation_dataloader

    def on_train_epoch_end(self, trainer, pl_module):
        # if trainer.current_epoch % 5 != 0 :
        #     return

        pl_module.eval()
        correct = 0
        correct_value = 0
        total = 0
        i = 0
        for x, policy, value in self.validation_dataloader:
            i += 1
            policy = [int(a) for a in str(int(policy[0]))]
            predicted_policy, predicted_value = pl_module(x.to("mps"))
            _, predicted = predicted_policy.max(1)
            total += 1
            correct += 1 if (predicted in policy) else 0
            correct_value += 1 if abs(predicted_value - value.to("mps")) < 0.5 else 0
        print(f"Value Accuracy: {correct_value/total}")
        print(f"Accuracy: {correct/total}")

        pl_module.train()

In [7]:
import torch


torch.autograd.set_detect_anomaly(True)
es = EarlyStopping(
    monitor="total_loss_step",
    min_delta=0.01,
    verbose=True,
    mode="min",
    patience=10,
    check_on_train_epoch_end=True,
)
trainer = L.Trainer(
    min_epochs=15,
    max_epochs=400,
    limit_train_batches=256,
    default_root_dir=os.path.join(data_path, "supervised"),
    check_val_every_n_epoch=1,
    callbacks=[es, ValidationLossCallback(validation_loader)],
)
trainer.fit(model=model, train_dataloaders=train_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/alberttroussard/Documents/alpha-connect/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name        | Type       | Params
-------------------------------------------
0 | conv_layers | Sequential | 22.5 M
1 | policy_head | PolicyHead | 1.1 K 
2 | value_head  | ValueHead  | 11.5 K
-------------------------------------------
22.5 M    Trainable params
0         Non-t

Training: |          | 0/? [00:00<?, ?it/s]

Metric total_loss_step improved. New best score: 2.359


Value Accuracy: 0.801150895140665
Accuracy: 0.4162404092071611
Value Accuracy: 0.9526854219948849
Accuracy: 0.5703324808184144


Metric total_loss_step improved by 0.565 >= min_delta = 0.01. New best score: 1.794


Value Accuracy: 0.9533248081841432
Accuracy: 0.5965473145780051
Value Accuracy: 0.9494884910485933
Accuracy: 0.618925831202046
Value Accuracy: 0.9322250639386189
Accuracy: 0.6118925831202046


Metric total_loss_step improved by 0.266 >= min_delta = 0.01. New best score: 1.528


Value Accuracy: 0.9763427109974424
Accuracy: 0.6911764705882353
Value Accuracy: 0.9699488491048593
Accuracy: 0.7186700767263428


Metric total_loss_step improved by 0.046 >= min_delta = 0.01. New best score: 1.481


Value Accuracy: 0.9501278772378516
Accuracy: 0.6726342710997443
Value Accuracy: 0.9533248081841432
Accuracy: 0.6777493606138107
Value Accuracy: 0.967391304347826
Accuracy: 0.7416879795396419
Value Accuracy: 0.9341432225063938
Accuracy: 0.7097186700767263
Value Accuracy: 0.9648337595907929
Accuracy: 0.7180306905370843
Value Accuracy: 0.8350383631713555
Accuracy: 0.6163682864450127
Value Accuracy: 0.9725063938618926
Accuracy: 0.7493606138107417
Value Accuracy: 0.7570332480818415
Accuracy: 0.5805626598465473


Metric total_loss_step improved by 0.058 >= min_delta = 0.01. New best score: 1.423


Value Accuracy: 0.94693094629156
Accuracy: 0.7551150895140665
Value Accuracy: 0.969309462915601
Accuracy: 0.7608695652173914
Value Accuracy: 0.9641943734015346
Accuracy: 0.7781329923273658
Value Accuracy: 0.9578005115089514
Accuracy: 0.760230179028133


Metric total_loss_step improved by 0.146 >= min_delta = 0.01. New best score: 1.277


Value Accuracy: 0.9763427109974424
Accuracy: 0.7659846547314578
Value Accuracy: 0.9386189258312021
Accuracy: 0.7512787723785166
Value Accuracy: 0.9731457800511509
Accuracy: 0.8056265984654731
Value Accuracy: 0.9718670076726342
Accuracy: 0.7647058823529411


In [None]:
# test the model
model.eval()

correct_value = 0
correct = 0
total = 0
moves = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0}
i = 0
for x, policy, value in test_loader:
    i += 1
    policy = [int(a) for a in str(int(policy[0]))]
    predicted_policy, predicted_value = model(x)
    _, predicted = predicted_policy.max(1)
    moves[predicted.item()] += 1
    total += 1
    correct += 1 if (predicted in policy) else 0

    correct_value += 1 if torch.abs(predicted_value - value) < 0.5 else 0
    if i % 1000 == 0:
        print(f"Value Accuracy: {100*correct_value/total}")
        print(f"Accuracy: {100*correct/total}")
        print(moves)

print(f"Value Accuracy: {100*correct_value/total}")
print(f"Accuracy: {100*correct/total}")
print(moves)
model.train()

Value Accuracy: 94.5
Accuracy: 73.7
{0: 52, 1: 76, 2: 165, 3: 198, 4: 160, 5: 76, 6: 273}
Value Accuracy: 93.5
Accuracy: 73.15
{0: 95, 1: 167, 2: 331, 3: 385, 4: 311, 5: 164, 6: 547}
Value Accuracy: 93.73333333333333
Accuracy: 73.2
{0: 135, 1: 262, 2: 500, 3: 567, 4: 470, 5: 240, 6: 826}
Value Accuracy: 93.55
Accuracy: 72.65
{0: 194, 1: 341, 2: 677, 3: 767, 4: 615, 5: 326, 6: 1080}
Value Accuracy: 93.6
Accuracy: 72.46
{0: 251, 1: 449, 2: 839, 3: 954, 4: 772, 5: 407, 6: 1328}
Value Accuracy: 93.76666666666667
Accuracy: 72.28333333333333
{0: 309, 1: 543, 2: 1013, 3: 1153, 4: 921, 5: 480, 6: 1581}
Value Accuracy: 93.72857142857143
Accuracy: 72.1
{0: 361, 1: 657, 2: 1170, 3: 1320, 4: 1074, 5: 550, 6: 1868}
Value Accuracy: 93.7875
Accuracy: 72.3875
{0: 403, 1: 732, 2: 1342, 3: 1539, 4: 1219, 5: 639, 6: 2126}
Value Accuracy: 93.68888888888888
Accuracy: 72.21111111111111
{0: 447, 1: 821, 2: 1497, 3: 1733, 4: 1375, 5: 710, 6: 2417}
Value Accuracy: 93.56
Accuracy: 72.09
{0: 501, 1: 899, 2: 1648

AlphaZeroModel(
  (conv_layers): Sequential(
    (0): ConvBlock(
      (n): Sequential(
        (0): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (1): ResidualBlock(
      (n): Sequential(
        (0): ConvBlock(
          (n): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
        )
        (1): ConvBlock(
          (n): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
        )
      )
      (r): ReLU()
    )
    (2): ResidualBlock(
      (n): Sequential(
        (0): ConvBl

In [None]:
# test the model
model.eval()

train_loader_batch1 = utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=False, num_workers=9
)
correct_value = 0
correct = 0
total = 0
moves = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0}
i = 0
for x, policy, value in train_loader_batch1:
    i += 1
    if i > 1000:
        break
    predicted_policy, predicted_value = model(x)
    _, predicted = predicted_policy.max(1)
    moves[predicted.item()] += 1
    total += 1
    correct += 1 if (predicted in policy) else 0

    correct_value += 1 if torch.abs(predicted_value - value) < 0.5 else 0
    if i % 1000 == 0:
        print(f"Value Accuracy: {100*correct_value/total}")
        print(f"Accuracy: {100*correct/total}")

print(f"Value Accuracy: {100*correct_value/total}")
print(f"Accuracy: {100*correct/total}")
print(moves)
model.train()

Value Accuracy: 93.5
Accuracy: 51.0
Value Accuracy: 93.5
Accuracy: 51.0
{0: 48, 1: 89, 2: 143, 3: 208, 4: 150, 5: 75, 6: 287}


AlphaZeroModel(
  (conv_layers): Sequential(
    (0): ConvBlock(
      (n): Sequential(
        (0): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (1): ResidualBlock(
      (n): Sequential(
        (0): ConvBlock(
          (n): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
        )
        (1): ConvBlock(
          (n): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
        )
      )
      (r): ReLU()
    )
    (2): ResidualBlock(
      (n): Sequential(
        (0): ConvBl