# Training cell predictor

In this notebook, we train a model to predict a cell's next state in Conway's Game of Life from its current state and the states of its eight neighbours.

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [None]:
class GameOfLifeDataset(torch.utils.data.Dataset):
    """Every possible 3x3 Game of Life window paired with its centre cell's next state.

    Index ``idx`` (0 <= idx < 512) is read as a 9-bit pattern, most significant
    bit first, laid out row by row over the 3x3 window, so the centre cell is
    bit position 4 of the padded binary string.

    Each item is ``(X, y)``:
        X: float32 tensor of shape (1, 3, 3) -- (channels, width, height),
           1.0 for an alive cell and 0.0 for a dead one.
        y: float32 tensor of shape (1, 1) -- 1.0 if the centre cell is alive in
           the next generation, else 0.0.
    """

    _size = 512  # 2**9 possible states of a 3x3 window

    def __len__(self):
        return self._size

    def __getitem__(self, idx):
        """Return ``(X, y)`` for window pattern ``idx``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-512, 512)``. Raising IndexError
                (rather than failing inside ``reshape``/``float``) also lets plain
                iteration such as ``list(dataset)`` terminate cleanly.
        """
        pattern = idx + self._size if idx < 0 else idx
        if not 0 <= pattern < self._size:
            raise IndexError(f"GameOfLifeDataset index out of range: {idx}")

        # 12 -> "0b1100" -> "1100" -> "000001100"
        idx_bin = bin(pattern)[2:].rjust(9, "0")
        X = torch.tensor([float(ch) for ch in idx_bin], dtype=torch.float32).reshape(1, 3, 3)  # (channels, width, height)

        # Row-major layout puts the centre cell at string position 4.
        alive = idx_bin[4] == "1"
        alive_neighbours = idx_bin.count("1") - int(alive)
        # Conway's rules: exactly 3 neighbours -> alive (birth or survival);
        # exactly 2 -> an alive cell survives; anything else -> dead.
        next_alive = alive_neighbours == 3 or (alive and alive_neighbours == 2)
        y = torch.tensor([[float(next_alive)]], dtype=torch.float32)  # (width, height)
        return X, y

In [None]:
# One pass over this loader visits every dataset item once, in shuffled batches of 4.
dataset = GameOfLifeDataset()
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

In [None]:
# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
class CellPredictorNeuralNetwork(nn.Module):
    """Fully-convolutional predictor of next-generation cell states.

    A 3x3 convolution looks at each cell together with its eight neighbours.
    Two 1x1 convolutions then reduce that to a single logit per cell. ReLU is
    applied after the first two layers.

    Inputs:
        x: Tensor of shape (batch_size, 1, width+2, height+2). The extra cell of
           padding on every side lets the 3x3 kernel produce a prediction for the
           cells on the boundary of the width x height game grid.

    Returns:
        Tensor of shape (batch_size, width, height) holding the logits of the
        predicted states.
    """

    def __init__(self):
        super().__init__()
        # These attribute names become the state_dict keys of saved weights.
        self.conv0 = nn.Conv2d(in_channels=1, out_channels=85, kernel_size=3)
        self.conv1 = nn.Conv2d(in_channels=85, out_channels=10, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=1, kernel_size=1)

    def forward(self, x):
        hidden = self.conv0(x).relu()
        hidden = self.conv1(hidden).relu()
        # Drop the single output-channel axis: (B, 1, W, H) -> (B, W, H).
        return self.conv2(hidden).squeeze(1)

In [None]:
# Instantiate the network and move its parameters to the selected device.
model = CellPredictorNeuralNetwork()
model = model.to(device)
print(model)

In [None]:
# Sanity check: print the model's predicted next state for the first sample of one batch.
X, y = next(iter(dataloader))
# The DataLoader yields CPU tensors; they must be on the model's device or model(X) fails.
X = X.to(device)
with torch.no_grad():
    logits = model(X)[0]
pred_probab = F.sigmoid(logits)
y_pred = int(pred_probab.item() > 0.5)
print(f"Predicted state: {y_pred}")

In [None]:
print(f"Model structure: {model}\n\n")

# Show each parameter's name and shape, plus only its first two entries along
# dim 0 to keep the output short.
for layer_name, weights in model.named_parameters():
    shape = weights.size()
    preview = weights[:2]
    print(f"Layer: {layer_name} | Size: {shape} | Values : {preview} \n")

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader`` and return the mean batch loss.

    Each batch is moved to the device holding the model's parameters (the CPU
    if the model has none), so the loop works whether the model lives on the
    CPU, CUDA or MPS.

    Args:
        dataloader: DataLoader yielding ``(X, y)`` batches;
            ``len(dataloader.dataset)`` is used for the progress printout.
        model: Module to train; it is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        The average of the per-batch losses, or 0.0 if the loader yields no batches.
    """
    param = next(model.parameters(), None)
    model_device = param.device if param is not None else torch.device("cpu")
    size = len(dataloader.dataset)
    model.train()

    total_loss = 0.0
    num_batches = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(model_device), y.to(model_device)

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        if batch % 20 == 0:
            current = (batch + 1) * len(X)
            print(f"loss: {batch_loss:>7f}  [{current:>5d}/{size:>5d}]")

    # Each batch_loss is already one value per batch (a mean under the default
    # reduction), so average over batches, not samples.
    return total_loss / num_batches if num_batches else 0.0


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            prob = F.sigmoid(pred)
            correct += ((prob > 0.5).type(torch.float) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return {
        "loss": test_loss,
        "acc": correct
    }

In [None]:
learning_rate = 1e-1
batch_size = 4
epochs = 20

# Build the loaders from `batch_size` so that the hyperparameter above actually
# takes effect. Since our dataset represents the full set of possible states, we
# can safely use it for both training and testing (evaluation order does not
# matter, so that loader is not shuffled).
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

loss_history = []
test_loss_history = []
test_acc_history = []

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    epoch_loss = train_loop(train_dataloader, model, loss_fn, optimizer)
    loss_history.append(epoch_loss)
    test_metrics = test_loop(test_dataloader, model, loss_fn)
    test_loss_history.append(test_metrics["loss"])
    test_acc_history.append(test_metrics["acc"])

print("Done!")

In [None]:
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2, figsize=(10, 8), constrained_layout=True)
axs = axs.flatten()
ax_loss, ax_test_loss, ax_test_acc, ax_blank = axs

# (axes, per-epoch series, title, y-axis label) for each panel.
panels = [
    (ax_loss, loss_history, "Training loss", "loss"),
    (ax_test_loss, test_loss_history, "Test loss", "loss"),
    (ax_test_acc, test_acc_history, "Test accuracy", "accuracy"),
]
for ax, history, title, ylabel in panels:
    ax.plot(history)
    ax.set_title(title)
    ax.set_xlabel("epoch")
    ax.set_ylabel(ylabel)

# Only three panels are needed; hide the fourth cell of the 2x2 grid.
ax_blank.axis("off")

# plt.plot() with no arguments draws nothing; plt.show() renders the figure.
plt.show()

In [None]:
# Persist only the learned parameters (the state_dict), not the module object.
state = model.state_dict()
torch.save(state, "model_weights.pth")