In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import wandb
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from training import SampleDataModule, TrainingGen
from nn import ConnectFourNet

In [3]:
# Directory handed to both TrainingGen.load_all and gen.get_games below.
base_dir = "../../training"
# Upper bound on how many entries of the load_all() result are kept.
n_gens = 5
gens = TrainingGen.load_all(base_dir)[:n_gens]

# Flatten generation -> game result -> sample into a single flat list.
samples = [
    game_sample
    for training_gen in gens
    for game_result in training_gen.get_games(base_dir).results
    for game_sample in game_result.samples
]

# Cell output: total number of collected samples.
len(samples)

428653

In [4]:
# Search space for the W&B hyperparameter sweep. The sweep minimises the
# metric named `val_loss`.
sweep_config = {
    'method': 'bayes',  # can be 'random', 'grid', 'bayes'
    'metric': {
        'name': 'val_loss',
        'goal': 'minimize'
    },
    'parameters': {
        'learning_rate': {
            # The range covers three orders of magnitude, so sample uniformly
            # in log-space. With the default uniform distribution roughly 90%
            # of the trials would fall in [0.01, 0.1] and the small learning
            # rates would hardly be explored.
            'distribution': 'log_uniform_values',
            'min': 0.0001,
            'max': 0.1
        },
        # Single-element list: batch size is held fixed for every trial.
        'batch_size': {
            'values': [2048]
        },
        'n_conv_layers': {
            'values': [1, 2, 4, 8, 16]
        },
        'conv_filter_size': {
            'values': [4, 8, 16, 32, 64, 128, 256]
        },
        'n_policy_layers': {
            'values': [1, 2, 4, 8, 16]
        },
        'n_value_layers': {
            'values': [1, 2, 4, 8, 16]
        },
    }
}

In [5]:
def train():
    """Run a single sweep trial: build a ConnectFourNet and fit it.

    Hyperparameters (``n_conv_layers``, ``conv_filter_size``,
    ``n_policy_layers``, ``n_value_layers``, ``learning_rate`` and
    ``batch_size``) are read from ``wandb.config`` after ``wandb.init()``.
    Training data comes from the notebook-level ``samples`` list.

    Takes no arguments and returns nothing; it is passed as the trial
    function to ``wandb.agent``.
    """
    # Use the run as a context manager so it is always closed (and marked as
    # failed on an exception) even when this function is called directly
    # rather than through wandb.agent.
    with wandb.init():
        config = wandb.config
        # A run is already active at this point, so the logger attaches to it.
        wandb_logger = WandbLogger(project="c4a0")

        # Create the model with the hyperparameters chosen for this trial.
        model = ConnectFourNet(
            n_conv_layers=config.n_conv_layers,
            conv_filter_size=config.conv_filter_size,
            n_policy_layers=config.n_policy_layers,
            n_value_layers=config.n_value_layers,
            learning_rate=config.learning_rate,
        )

        # Positional 80/20 split without shuffling: the held-out part is the
        # tail of `samples`.
        # NOTE(review): `samples` is built generation by generation, so the
        # held-out part comes from the last loaded games - confirm that this
        # is intended.
        split_idx = int(0.8 * len(samples))
        # Descriptive names; the original locals `train`/`test` shadowed this
        # function's own name and mislabelled the held-out split.
        train_samples = samples[:split_idx]
        val_samples = samples[split_idx:]

        # Second positional argument is presumably the validation split -
        # verify against SampleDataModule.
        data_module = SampleDataModule(train_samples, val_samples, config.batch_size)

        trainer = pl.Trainer(
            max_epochs=30,
            accelerator="auto",
            devices="auto",
            callbacks=[
                # Stop once val_loss has not improved for 4 validation checks.
                EarlyStopping(monitor="val_loss", patience=4, mode="min"),
            ],
            logger=wandb_logger,
        )
        trainer.fit(model, data_module)

In [6]:
# Register the sweep in the "c4a0" project, then run up to 100 trials of
# `train` in this kernel.
sweep_id = wandb.sweep(sweep=sweep_config, project="c4a0")
wandb.agent(sweep_id, function=train, count=100)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Create sweep with ID: n2g2mo4u
Sweep URL: https://wandb.ai/advait3000-advait/c4a0/sweeps/n2g2mo4u


[34m[1mwandb[0m: Agent Starting Run: h40qz4hc with config:
[34m[1mwandb[0m: 	batch_size: 2048
[34m[1mwandb[0m: 	conv_filter_size: 64
[34m[1mwandb[0m: 	learning_rate: 0.058514548646903766
[34m[1mwandb[0m: 	n_conv_layers: 2
[34m[1mwandb[0m: 	n_policy_layers: 1
[34m[1mwandb[0m: 	n_value_layers: 4
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33madvait3000[0m ([33madvait3000-advait[0m). Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/advait/c4a0/.venv/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.

  | Name          | Type             | Params
---------------------------------------------------
0 | conv          | Sequential       | 38.4 K
1 | fc_policy     | Sequential       | 18.8 K
2 | fc_value      | Sequential       | 21.7 M
3 | policy_kl_div | KLDivergence     | 0     
4 | value_mse     | MeanSquaredError | 0     
---------------------------------------------------
21.8 M    Trainable params
0         Non-trainable params
21.8 M    Total params
87.041    Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/Users/advait/c4a0/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


                                                                           

/Users/advait/c4a0/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 2:  68%|██████▊   | 227/335 [00:20<00:09, 10.92it/s, v_num=z4hc, train_loss=0.232, val_loss=0.263]