In [1]:
import os
from torch import utils, Tensor, abs
import random
import lightning as L
import pandas as pd
from lightning.pytorch.callbacks import EarlyStopping

from alpha_connect import AlphaZeroModelConnect4, MyDataset

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
data_path = os.path.join(os.getcwd(), "..", "data")
print(data_path)

/Users/alberttroussard/Documents/alpha-connect/notebooks/../data


In [4]:
model = AlphaZeroModelConnect4()

In [5]:
# load data.csv


data = pd.read_csv(os.path.join(data_path, "../data/data2.csv"))
# shuffle the data
data = data.sample(frac=1).reset_index(drop=True)

x = data.iloc[:, :-8].values

policy = data.iloc[:, -8:-1].values

value = data.iloc[:, -1].values


# keep only the first digit of the labels
train_ratio = 0.85
validation_ratio = 0.008
x_train, x_validation, x_test = (
    x[: int(train_ratio * len(x))],
    x[int(train_ratio * len(x)) : int((train_ratio + validation_ratio) * len(x))],
    x[int((train_ratio + validation_ratio) * len(x)) :],
)
policy_train, policy_validation, policy_test = (
    policy[: int(train_ratio * len(policy))],
    policy[
        int(train_ratio * len(policy)) : int(
            (train_ratio + validation_ratio) * len(policy)
        )
    ],
    policy[int((train_ratio + validation_ratio) * len(policy)) :],
)
value_train, value_validation, value_test = (
    value[: int(train_ratio * len(value))],
    value[
        int(train_ratio * len(value)) : int(
            (train_ratio + validation_ratio) * len(value)
        )
    ],
    value[int((train_ratio + validation_ratio) * len(value)) :],
)

# create a dataset
train_dataset = MyDataset(Tensor(x_train), Tensor(policy_train), Tensor(value_train))
validation_dataset = MyDataset(
    Tensor(x_validation), Tensor(policy_validation), Tensor(value_validation)
)
test_dataset = MyDataset(Tensor(x_test), Tensor(policy_test), Tensor(value_test))

# create a dataloader
train_loader = utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=9
)
validation_loader = utils.data.DataLoader(
    validation_dataset, batch_size=1, shuffle=False, num_workers=9
)
test_loader = utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False, num_workers=9
)

In [6]:
class ValidationLossCallback(L.Callback):
    def __init__(self, validation_dataloader):
        self.validation_dataloader = validation_dataloader

    def on_train_epoch_end(self, trainer, pl_module):
        # if trainer.current_epoch % 5 != 0 :
        #     return

        pl_module.eval()
        correct = 0
        correct_value = 0
        total = 0
        i = 0
        for x, policy_distribution, value in self.validation_dataloader:
            i += 1

            policy = [i for i in range(7) if policy_distribution[0][i] != 0]
            predicted_policy, predicted_value = pl_module(x.to("mps"))
            _, predicted = predicted_policy.max(1)
            total += 1
            correct += 1 if (predicted in policy) else 0
            correct_value += 1 if abs(predicted_value - value.to("mps")) < 0.5 else 0
            # if i ==1:
            #     print(f"Predicted Policy: {predicted_policy}")
            #     print(f"Policy: {policy_distribution}")
            #     print(f"Predicted Value: {float(predicted_value)}")
            #     print(f"Value: {float(value)}")
        print(f"Value Accuracy: {correct_value/total}")
        print(f"Accuracy: {correct/total}")

        pl_module.train()

In [7]:
import torch


torch.autograd.set_detect_anomaly(True)
es = EarlyStopping(
    monitor="total_loss_step",
    min_delta=0.01,
    verbose=True,
    mode="min",
    patience=10,
    check_on_train_epoch_end=True,
)
trainer = L.Trainer(
    min_epochs=15,
    max_epochs=400,
    limit_train_batches=512,
    default_root_dir=os.path.join(data_path, "supervised"),
    check_val_every_n_epoch=1,
    callbacks=[es, ValidationLossCallback(validation_loader)],
)
trainer.fit(model=model, train_dataloaders=train_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/alberttroussard/Documents/alpha-connect/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name        | Type       | Params
-------------------------------------------
0 | conv_layers | Sequential | 22.5 M
1 | policy_head | PolicyHead | 1.1 K 
2 | value_head  | ValueHead  | 11.5 K
-------------------------------------------
22.5 M    Trainable params
0         Non-t

Training: |          | 0/? [00:00<?, ?it/s]

Metric total_loss_step improved. New best score: 0.545


Value Accuracy: 0.5482428115015975
Accuracy: 0.3878594249201278


Metric total_loss_step improved by 0.183 >= min_delta = 0.01. New best score: 0.362


Value Accuracy: 0.9060702875399361
Accuracy: 0.6568690095846645
Value Accuracy: 0.6856230031948882
Accuracy: 0.5047923322683706
Value Accuracy: 0.8792332268370607
Accuracy: 0.734185303514377
Value Accuracy: 0.8562300319488818
Accuracy: 0.721405750798722


Metric total_loss_step improved by 0.041 >= min_delta = 0.01. New best score: 0.321


Value Accuracy: 0.8575079872204473
Accuracy: 0.7348242811501597
Value Accuracy: 0.8523961661341853
Accuracy: 0.7156549520766773


Metric total_loss_step improved by 0.033 >= min_delta = 0.01. New best score: 0.288


Value Accuracy: 0.9073482428115016
Accuracy: 0.802555910543131
Value Accuracy: 0.5936102236421725
Accuracy: 0.5725239616613419
Value Accuracy: 0.8958466453674121
Accuracy: 0.8006389776357827
Value Accuracy: 0.8926517571884984
Accuracy: 0.8242811501597445


Metric total_loss_step improved by 0.025 >= min_delta = 0.01. New best score: 0.263


Value Accuracy: 0.8428115015974441
Accuracy: 0.8089456869009585


Metric total_loss_step improved by 0.045 >= min_delta = 0.01. New best score: 0.218


Value Accuracy: 0.8587859424920128
Accuracy: 0.7712460063897764
Value Accuracy: 0.9105431309904153
Accuracy: 0.8268370607028754
Value Accuracy: 0.8440894568690096
Accuracy: 0.810223642172524
Value Accuracy: 0.8990415335463259
Accuracy: 0.8236421725239617
Value Accuracy: 0.7968051118210863
Accuracy: 0.744408945686901
Value Accuracy: 0.886261980830671
Accuracy: 0.8249201277955271
Value Accuracy: 0.8753993610223643
Accuracy: 0.8198083067092652
Value Accuracy: 0.9156549520766774
Accuracy: 0.8485623003194889
Value Accuracy: 0.8517571884984025
Accuracy: 0.8383386581469648


Metric total_loss_step improved by 0.050 >= min_delta = 0.01. New best score: 0.168


Value Accuracy: 0.9130990415335464
Accuracy: 0.8562300319488818
Value Accuracy: 0.8932907348242811
Accuracy: 0.8594249201277955
Value Accuracy: 0.8370607028753994
Accuracy: 0.8108626198083068
Value Accuracy: 0.9086261980830671
Accuracy: 0.8517571884984025
Value Accuracy: 0.9105431309904153
Accuracy: 0.8402555910543131
Value Accuracy: 0.9099041533546326
Accuracy: 0.8517571884984025
Value Accuracy: 0.9169329073482428
Accuracy: 0.8536741214057508
Value Accuracy: 0.8932907348242811
Accuracy: 0.8479233226837061


Metric total_loss_step improved by 0.022 >= min_delta = 0.01. New best score: 0.146


Value Accuracy: 0.9124600638977636
Accuracy: 0.8575079872204473
Value Accuracy: 0.9303514376996805
Accuracy: 0.8645367412140575


Metric total_loss_step improved by 0.021 >= min_delta = 0.01. New best score: 0.125


Value Accuracy: 0.9277955271565496
Accuracy: 0.868370607028754
Value Accuracy: 0.9329073482428115
Accuracy: 0.8677316293929712
Value Accuracy: 0.931629392971246
Accuracy: 0.8741214057507988


Metric total_loss_step improved by 0.020 >= min_delta = 0.01. New best score: 0.105


Value Accuracy: 0.929073482428115
Accuracy: 0.8792332268370607
Value Accuracy: 0.9297124600638977
Accuracy: 0.8741214057507988


Metric total_loss_step improved by 0.059 >= min_delta = 0.01. New best score: 0.046


Value Accuracy: 0.9252396166134186
Accuracy: 0.8747603833865815
Value Accuracy: 0.9322683706070287
Accuracy: 0.8753993610223643
Value Accuracy: 0.929073482428115
Accuracy: 0.8747603833865815
Value Accuracy: 0.929073482428115
Accuracy: 0.8766773162939298
Value Accuracy: 0.9297124600638977
Accuracy: 0.8792332268370607
Value Accuracy: 0.929073482428115
Accuracy: 0.876038338658147
Value Accuracy: 0.9303514376996805
Accuracy: 0.876038338658147
Value Accuracy: 0.9322683706070287
Accuracy: 0.8792332268370607
Value Accuracy: 0.9322683706070287
Accuracy: 0.8779552715654952
Value Accuracy: 0.9303514376996805
Accuracy: 0.8811501597444089


Monitored metric total_loss_step did not improve in the last 10 records. Best score: 0.046. Signaling Trainer to stop.


Value Accuracy: 0.9284345047923322
Accuracy: 0.8798722044728434


In [8]:
# test the model
model.eval()

correct_value = 0
correct = 0
total = 0
moves = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0}
i = 0
display = 5
for x, policy_distribution, value in test_loader:
    i += 1
    policy = [i for i in range(7) if policy_distribution[0][i] != 0]
    predicted_policy, predicted_value = model(x)
    _, predicted = predicted_policy.max(1)
    moves[predicted.item()] += 1
    total += 1
    correct += 1 if (predicted in policy) else 0

    correct_value += 1 if torch.abs(predicted_value - value) < 0.5 else 0
    if display > 0:
        print(f"Policy: {list(policy_distribution)[0]}")
        print(f"Predicted policy: {list(predicted_policy)[0]}")
        print(f"Value: {int(value)}")
        print(f"Predicted value: {float(predicted_value)}")
        display -= 1

    if i % 1000 == 0:
        print(f"Value Accuracy: {100*correct_value/total}")
        print(f"Accuracy: {100*correct/total}")
        print(moves)

print(f"Value Accuracy: {100*correct_value/total}")
print(f"Accuracy: {100*correct/total}")
print(moves)
model.train()

Policy: tensor([0.3333, 0.0000, 0.3333, 0.0000, 0.0000, 0.0000, 0.3333])
Predicted policy: tensor([0.1373, 0.0704, 0.1269, 0.0034, 0.1309, 0.1194, 0.3984],
       grad_fn=<UnbindBackward0>)
Value: 1
Predicted value: 0.9905009269714355
Policy: tensor([0., 0., 0., 0., 0., 1., 0.])
Predicted policy: tensor([-0.0067,  0.0104, -0.0596, -0.0173,  0.0236,  1.0273, -0.0023],
       grad_fn=<UnbindBackward0>)
Value: -1
Predicted value: -0.8876975774765015
Policy: tensor([0.2000, 0.2000, 0.0000, 0.0000, 0.2000, 0.2000, 0.2000])
Predicted policy: tensor([0.3265, 0.0922, 0.0402, 0.0141, 0.1750, 0.0807, 0.3134],
       grad_fn=<UnbindBackward0>)
Value: 1
Predicted value: 0.9988082051277161
Policy: tensor([0.0000, 0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000])
Predicted policy: tensor([ 0.0038, -0.0194,  0.4885,  0.4538,  0.0181,  0.0210,  0.0037],
       grad_fn=<UnbindBackward0>)
Value: 1
Predicted value: 0.9964443445205688
Policy: tensor([0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000])

AlphaZeroModelConnect4(
  (conv_layers): Sequential(
    (0): ConvBlock(
      (n): Sequential(
        (0): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (1): ResidualBlock(
      (n): Sequential(
        (0): ConvBlock(
          (n): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
        )
        (1): ConvBlock(
          (n): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
        )
      )
      (r): ReLU()
    )
    (2): ResidualBlock(
      (n): Sequential(
        (0)

In [9]:
# test the model
model.eval()

train_loader_batch1 = utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=False, num_workers=9
)
correct_value = 0
correct = 0
total = 0
moves = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0}
i = 0
for x, policy, value in train_loader_batch1:
    i += 1
    if i > 1000:
        break
    predicted_policy, predicted_value = model(x)
    _, predicted = predicted_policy.max(1)
    moves[predicted.item()] += 1
    total += 1
    correct += 1 if (predicted in policy) else 0

    correct_value += 1 if torch.abs(predicted_value - value) < 0.5 else 0
    if i % 1000 == 0:
        print(f"Value Accuracy: {100*correct_value/total}")
        print(f"Accuracy: {100*correct/total}")

print(f"Value Accuracy: {100*correct_value/total}")
print(f"Accuracy: {100*correct/total}")
print(moves)
model.train()

Value Accuracy: 98.3
Accuracy: 19.8
Value Accuracy: 98.3
Accuracy: 19.8
{0: 127, 1: 123, 2: 162, 3: 194, 4: 144, 5: 135, 6: 115}


AlphaZeroModelConnect4(
  (conv_layers): Sequential(
    (0): ConvBlock(
      (n): Sequential(
        (0): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
    (1): ResidualBlock(
      (n): Sequential(
        (0): ConvBlock(
          (n): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
        )
        (1): ConvBlock(
          (n): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
        )
      )
      (r): ReLU()
    )
    (2): ResidualBlock(
      (n): Sequential(
        (0)