# Training cell predictor

In this notebook, we train a model to predict a cell's next state in Conway's Game of Life from its current state and the states of its eight neighbours.

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
class GameOfLifeDataset(torch.utils.data.Dataset):
    """Every 3x3 Game of Life neighbourhood paired with the centre cell's next state.

    Sample ``idx`` is the 9-bit binary expansion of ``idx`` (most significant
    bit first) laid out row by row on a 3x3 grid, so the 512 samples enumerate
    each possible neighbourhood exactly once.
    """

    _size = 512  # 2 ** 9 distinct binary 3x3 grids

    def __len__(self):
        """Return the number of distinct neighbourhoods."""
        return self._size

    def __getitem__(self, idx):
        """Return the ``(X, y)`` pair for neighbourhood number ``idx``.

        Returns:
            X: float32 tensor of shape (1, 3, 3) holding 0.0 / 1.0 cell states.
            y: float32 tensor of shape (1,); 1.0 if the centre cell is alive in
               the next generation, 0.0 otherwise.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``. Raising
                IndexError (rather than failing later in ``reshape``) is what
                lets plain iteration over the dataset terminate cleanly.
        """
        if not 0 <= idx < self._size:
            raise IndexError(f"index {idx} is out of range for a dataset of size {self._size}")

        # 12 -> "0b1100" -> "1100" -> "000001100"
        idx_bin = bin(idx)[2:].rjust(9, "0")
        X = torch.tensor([float(ch) for ch in idx_bin], dtype=torch.float32).reshape(1, 3, 3) # (channels, width, height)

        # Position 4 of the flattened grid is the centre cell, i.e. X[0, 1, 1].
        alive = idx_bin[4] == "1"
        alive_neighbours = idx_bin.count("1") - alive

        # Conway's rules: a cell is alive next step if it has exactly three live
        # neighbours, or if it is currently alive and has exactly two.
        next_alive = alive_neighbours == 3 or (alive and alive_neighbours == 2)
        y = torch.tensor([float(next_alive)], dtype=torch.float32)
        return X, y

In [None]:
from torch.utils.data import DataLoader

# Serve the dataset in shuffled mini-batches of four samples.
dataset = GameOfLifeDataset()
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

In [None]:
# Pick the best available backend: CUDA first, then MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
class CellPredictorNeuralNetwork(nn.Module):
    """Predicts the next state of the cells.

    A single 3x3 convolution (no padding) collapses each 3x3 neighbourhood to
    one value per channel; two fully connected layers then map those channel
    activations to a single logit.

    Args:
        conv_channels: Number of output channels of the convolution.
        hidden_units: Width of the hidden fully connected layer.

    Inputs:
        x: Tensor of shape (batch_size, channels, width, height), where channels=1, width=3 and height=3.

    Returns: Tensor of shape (batch_size, 1), the logits of the predicted states.
    """

    def __init__(self, conv_channels=85, hidden_units=10):
        super().__init__()
        self.conv = nn.Conv2d(1, conv_channels, 3)
        self.linear1 = nn.Linear(conv_channels, hidden_units)
        self.linear2 = nn.Linear(hidden_units, 1)

    def forward(self, x):
        x = F.relu(self.conv(x))
        # The conv output is (batch_size, conv_channels, 1, 1) for 3x3 inputs;
        # flatten everything but the batch dimension for the linear layers.
        x = torch.flatten(x, 1)
        x = F.relu(self.linear1(x))
        logits = self.linear2(x)
        return logits

In [None]:
# Build the network, then move its parameters onto the selected device.
model = CellPredictorNeuralNetwork()
model = model.to(device)
print(model)

In [None]:
# Smoke test: push one batch through the untrained model and decode the
# prediction for the first sample of that batch.
X, y = next(iter(dataloader))
X = X.to(device)  # the model lives on `device`; CPU inputs fail on cuda/mps
with torch.no_grad():  # inference only, no need to record gradients
    logits = model(X)[0]
pred_probab = F.sigmoid(logits)
y_pred = int(pred_probab > 0.5)
print(f"Predicted state: {y_pred}")

In [None]:
# Show the architecture, then the size and first two entries of each parameter.
print(f"Model structure: {model}\n\n")

named_params = model.named_parameters()
for name, param in named_params:
    print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        dataloader: Yields ``(X, y)`` tensor batches; ``dataloader.dataset``
            must support ``len()``.
        model: The module to optimise; it is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer already bound to ``model``'s parameters.

    Returns:
        The mean of the per-batch losses (sum of batch losses divided by the
        number of batches), which is the same normalisation ``test_loop``
        uses. Returns 0.0 when the dataloader yields no batches.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    # Batches come off the DataLoader on the CPU; send them wherever the
    # model's parameters live so the loop also works on cuda/mps.
    first_param = next(model.parameters(), None)
    model_device = None if first_param is None else first_param.device

    model.train()
    total_loss = 0.0
    for batch, (X, y) in enumerate(dataloader):
        if model_device is not None:
            X, y = X.to(model_device), y.to(model_device)

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        if batch % 30 == 0:
            current = (batch + 1) * len(X)
            print(f"loss: {batch_loss:>7f}  [{current:>5d}/{size:>5d}]")

    # Each accumulated term is one batch's loss, so normalise by the number of
    # batches (not the number of samples); max() guards the empty-loader case.
    return total_loss / max(num_batches, 1)


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            prob = F.sigmoid(pred)
            correct += ((prob > 0.5).type(torch.float) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss, correct

In [None]:
# Compare SGD learning rates: train one freshly initialised model per learning
# rate and plot its training loss, test loss and test accuracy per epoch.
learning_rates = [1e-3, 1e-2, 1e-1, 2e-1]
repeats = len(learning_rates)
batch_size = 4
epochs = 10

# Build the loader from the hyperparameter above so `batch_size` takes effect.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

fig, axs = plt.subplots(2, 2, figsize=(15, 10))
((ax_loss, ax_test_loss), (ax_test_acc, ax_unused)) = axs
ax_unused.axis("off")  # only three of the four panels are used
ax_loss.set_title("Training loss")
ax_loss.set_xlabel("epoch")
ax_loss.set_ylabel("loss")
ax_test_loss.set_title("Test loss")
ax_test_loss.set_xlabel("epoch")
ax_test_loss.set_ylabel("test loss")
ax_test_acc.set_title("Test accuracy")
ax_test_acc.set_xlabel("epoch")
ax_test_acc.set_ylabel("test accuracy")

for r, lr in enumerate(learning_rates):
    # Start every run from new random weights; reusing one model would let
    # later learning rates continue from the previous run's trained state.
    model = CellPredictorNeuralNetwork().to(device)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    loss_history = []
    test_loss_history = []
    test_acc_history = []

    for t in range(epochs):
        print(f"Repeat {r+1}, Epoch {t+1}\n------------------------------------------")
        # Since our dataset represents the full set of possible states, we can safely use it for both training and testing
        epoch_loss = train_loop(dataloader, model, loss_fn, optimizer)
        loss_history.append(epoch_loss)
        test_loss, test_acc = test_loop(dataloader, model, loss_fn)
        test_loss_history.append(test_loss)
        test_acc_history.append(test_acc)

    curve_label = f"Learning rate = {lr}"
    ax_loss.plot(loss_history, label=curve_label)
    ax_test_loss.plot(test_loss_history, label=curve_label)
    ax_test_acc.plot(test_acc_history, label=curve_label)
    print("Done!")

# Add legends only to the axes that actually hold curves.
for ax in (ax_loss, ax_test_loss, ax_test_acc):
    ax.legend()
plt.show()

# Conclusion

We can clearly see that using a much more aggressive learning rate of 0.1 or 0.2, as opposed to 0.001, is beneficial. However, the key insight from this experiment was that the test accuracy calculation had a bug in it, which meant the reported accuracy was much lower than the actual accuracy. Once this bug was fixed, it became clear that even with a learning rate of 0.001, the model achieves 100% accuracy very quickly.