In [2]:
import os
import time
import random
import pickle
import numpy as np
from copy import deepcopy
from contextlib import nullcontext
from tqdm import tqdm, trange

import matplotlib.pyplot as plt
from cycler import cycler; plt.rcParams["axes.prop_cycle"] = cycler(color=["#000000", "#2180FE", "#EB4275"])
from IPython.display import clear_output

import torch
from torch import nn
# Prefer the GPU whenever CUDA is usable; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Report the selected device and the CPU count reported by the OS.
print(f'device: {device}')
print(f'os.cpu_count(): {os.cpu_count()}')
!nvidia-smi -L

device: cuda
os.cpu_count(): 20
GPU 0: NVIDIA GeForce RTX 4070 SUPER (UUID: GPU-b10c6cac-b9b4-49fa-2e4d-8ba5c4c711ba)


In [3]:
# `Cube` comes from the project-local `cube` module.
# NOTE(review): `env` is not referenced by any later cell visible in this notebook
# (ScrambleGenerator builds its own `Cube()` per item), so this instance may be
# unused — confirm before removing.
from cube import Cube
env = Cube()

In [10]:
# Build the network from the project-local `model` module with its default arguments.
from model import Model
model = Model()
# Move the model to the selected device. Because this is the cell's last
# expression, its return value is what the notebook displays below the cell.
model.to(device)

Model(
  (embedding): LinearBlock(
    (fc): Linear(in_features=324, out_features=5000, bias=True)
    (relu): ReLU()
    (bn): BatchNorm1d(5000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (linear_layers): ModuleList(
    (0): LinearBlock(
      (fc): Linear(in_features=5000, out_features=1000, bias=True)
      (relu): ReLU()
      (bn): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (residual_blocks): ModuleList(
    (0-3): 4 x ResidualBlock(
      (layers): ModuleList(
        (0-1): 2 x LinearBlock(
          (fc): Linear(in_features=1000, out_features=1000, bias=True)
          (relu): ReLU()
          (bn): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
  )
  (output): Linear(in_features=1000, out_features=12, bias=True)
)

In [11]:
from config import TrainConfig


class ScrambleGenerator(torch.utils.data.Dataset):
    """Map-style dataset that generates a fresh scramble on every access.

    Each item is built from the first `max_depth` `(state, last_move)` pairs
    yielded by `Cube().scrambler(max_depth)`. The index passed to
    `__getitem__` is ignored; `total_samples` only controls how many items a
    `DataLoader` draws before the dataset is exhausted.

    Args:
        num_workers: Stored on the instance; not used by this class itself.
        max_depth: Number of data points (moves) per scramble.
        total_samples: Value reported by `__len__`.
    """

    # Upper bound on generation attempts per item. The previous implementation
    # retried through unbounded recursion, so a persistent failure ended in a
    # RecursionError that hid the real cause.
    MAX_ATTEMPTS = 10

    def __init__(
        self,
        num_workers=os.cpu_count(),
        max_depth=TrainConfig.max_depth,
        total_samples=TrainConfig.num_steps*TrainConfig.batch_size_per_depth
    ):
        self.num_workers = num_workers
        self.max_depth = max_depth
        self.total_samples = total_samples

    def __len__(self):
        """Return the nominal number of scrambles this dataset provides."""
        return self.total_samples

    def __getitem__(self, i):
        ''' generate one scramble, consisting of `self.max_depth` data points

        Args:
            i: Item index; ignored, since every call generates a new scramble.

        Returns:
            Tuple `(X, y)`: `X` is an int64 array of shape `(max_depth, 54)`
            holding the yielded states, `y` is an int64 array of shape
            `(max_depth,)` holding the yielded `last_move` values.

        Raises:
            RuntimeError: If generation fails `MAX_ATTEMPTS` times in a row;
                the last underlying exception is attached as the cause.
        '''
        last_error = None
        for _ in range(self.MAX_ATTEMPTS):
            try:
                return self._generate_scramble()
            except Exception as e:
                # Keep the original best-effort behavior (log and retry), but bounded.
                print(f'error: {e}')
                last_error = e
        raise RuntimeError(
            f'failed to generate a scramble after {self.MAX_ATTEMPTS} attempts'
        ) from last_error

    def _generate_scramble(self):
        """Build the `(X, y)` arrays for one scramble from a new `Cube` instance."""
        X = np.zeros((self.max_depth, 54), dtype=np.int64)
        y = np.zeros((self.max_depth,), dtype=np.int64)
        generator = Cube().scrambler(self.max_depth)
        for j in range(self.max_depth):
            state, last_move = next(generator)
            X[j, :] = state
            y[j] = last_move
        return X, y


# Single-process loader: each batch stacks `batch_size_per_depth` generated scrambles.
dataloader = torch.utils.data.DataLoader(
    dataset=ScrambleGenerator(),
    batch_size=TrainConfig.batch_size_per_depth,
    num_workers=0,
)

# Sanity check: pull a single batch and split it into inputs and targets.
sample_batch = next(iter(dataloader))
X, y = sample_batch[0], sample_batch[1]

# Shapes should be (batch_size, max_depth, 54) for X and (batch_size, max_depth) for y.
print("X shape:", X.shape)
print("y shape:", y.shape)

# Element types of both tensors.
print("X dtype:", X.dtype)
print("y dtype:", y.dtype)

# Show the first scramble of the batch: its states and its move labels.
print("Sample X:", X[0])
print("Sample y:", y[0])


X shape: torch.Size([1000, 26, 54])
y shape: torch.Size([1000, 26])
X dtype: torch.int64
y dtype: torch.int64
Sample X: tensor([[0, 0, 0,  ..., 1, 5, 5],
        [2, 2, 5,  ..., 1, 5, 0],
        [1, 2, 5,  ..., 2, 2, 2],
        ...,
        [0, 4, 5,  ..., 5, 4, 5],
        [0, 4, 5,  ..., 5, 4, 0],
        [0, 4, 5,  ..., 5, 4, 0]])
Sample y: tensor([ 7,  6, 10,  1,  3, 10,  0,  7, 10,  3,  7, 10,  2, 11,  2,  9, 11,  7,
         3,  5,  4,  0,  5,  6,  6,  5])


In [None]:
def plot_loss_curve(h):
    """Plot the recorded training losses against a log-scaled step axis.

    Args:
        h: Sequence of per-step loss values; element `k` is treated as the
            loss of step `k + 1`, matching the 1-based step counter in `train`.
    """
    # Use explicit 1-based steps: with the implicit 0-based x values the first
    # point sits at x = 0, which has no position on a logarithmic axis.
    steps = range(1, len(h) + 1)
    fig, ax = plt.subplots(1, 1)
    ax.plot(steps, h)
    ax.set_xlabel("Steps")
    ax.set_ylabel("Cross-entropy loss")
    ax.set_xscale("log")
    plt.show()
    # This is called repeatedly during training; release the figure so that
    # figures do not accumulate on backends that keep them open after show().
    plt.close(fig)

def train(model, dataloader):
    """Train `model` for `TrainConfig.num_steps` optimizer steps and return it.

    Each batch from `dataloader` is flattened to rows of 54 values (inputs) and
    a 1-D vector (targets), moved to the module-level `device`, and used for
    one Adam step on the cross-entropy loss. The loss curve is re-plotted every
    `TrainConfig.INTERVAL_PLOT` steps and the model's state dict is saved to
    `"{i}steps.pth"` every `TrainConfig.INTERVAL_SAVE` steps; a falsy interval
    disables the corresponding action.

    Args:
        model: Module mapping a flattened input batch to class scores.
        dataloader: Iterable yielding `(batch_x, batch_y)` pairs. If it runs
            out before `num_steps` batches were drawn, iteration restarts.

    Returns:
        The same `model` object, trained in place.

    Raises:
        StopIteration: If `dataloader` yields no batches at all.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=TrainConfig.learning_rate)
    g = iter(dataloader)
    h = []  # per-step loss history, used for plotting only

    use_fp16 = bool(TrainConfig.ENABLE_FP16)
    ctx = torch.cuda.amp.autocast(dtype=torch.float16) if use_fp16 else nullcontext()
    # float16 autocast needs loss scaling, otherwise small gradients can
    # underflow to zero. With enabled=False the scaler is a pass-through, so
    # the full-precision path behaves exactly as before.
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

    for i in trange(1, TrainConfig.num_steps + 1):
        try:
            batch_x, batch_y = next(g)
        except StopIteration:
            # The loader ran out before num_steps batches; start a new pass
            # instead of aborting training mid-run.
            g = iter(dataloader)
            batch_x, batch_y = next(g)
        batch_x = batch_x.reshape(-1, 54).to(device)
        batch_y = batch_y.reshape(-1).to(device)

        with ctx:
            pred_y = model(batch_x)
            loss = loss_fn(pred_y, batch_y)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)  # skips the step if the scaled gradients contain inf/NaN
        scaler.update()

        h.append(loss.item())
        if TrainConfig.INTERVAL_PLOT and i % TrainConfig.INTERVAL_PLOT == 0:
            clear_output()
            plot_loss_curve(h)
        if TrainConfig.INTERVAL_SAVE and i % TrainConfig.INTERVAL_SAVE == 0:
            torch.save(model.state_dict(), f"{i}steps.pth")
            print("Model saved.")
    print(f"Trained on data equivalent to {TrainConfig.batch_size_per_depth * TrainConfig.num_steps} solves.")
    return model

# Run the training loop. `train` returns the object it was given, so this
# rebinding keeps `model` pointing at the same, now trained, instance.
model = train(model, dataloader)