# Training cell predictor

In this experiment, we will try to improve the inference speed of the model.

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np

The code below predicts the next state of a cell explicitly, using the Game of Life rules directly rather than a model. The timing results show that each prediction takes about 7.5 microseconds. Performing this for each cell in a 50 x 50 grid would take 7.5 * 50**2 = 18750 microseconds ~= 19 milliseconds ~= 0.02 seconds. This is barely noticeable when the time between display.flip calls is 0.1 seconds, and doesn't impact it too much when the time between display.flip calls is 0.01 seconds.

Let's compare this to model-based prediction.

In [None]:
%%timeit

# Rule-based baseline: compute the next state of the centre cell of a 3x3
# neighbourhood directly from the Game of Life rules. Everything in this cell,
# including building the arrays, is inside the timed region.
cell_states = np.array([[1, 0, 0], [0, 0, 0], [1, 0, 1]])
# 1 for each of the eight neighbours, 0 for the centre cell itself.
_neighbour_mask = np.array([[1, 1, 1], [1, 0, 1], [1, 1, 1]])

alive = cell_states[1, 1] == 1
alive_neighbours = np.sum(cell_states * _neighbour_mask)
# A live cell survives with 2 or 3 live neighbours; a dead cell becomes alive
# with exactly 3.
a = (
    int(alive_neighbours in [2, 3])
    if alive
    else int(alive_neighbours == 3)
)

In [None]:
# Choose a backend in order of preference: CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
class CellPredictorNeuralNetwork(nn.Module):
    """Predicts the next state of the centre cell of a 3x3 neighbourhood.

    Inputs:
        x: Tensor of shape (batch_size, 1, 3, 3): a single channel holding a 3x3 grid of
           cell states. The unpadded 3x3 convolution reduces this to 1x1, so the flattened
           features have exactly 85 entries, which is what ``linear1`` expects. Other
           spatial sizes fail (smaller ones in the convolution, larger ones at ``linear1``).

    Returns: Tensor of shape (batch_size, 1), the logits of the predicted states. The
        trailing dimension of size 1 comes from ``linear2``'s single output feature.
    """

    def __init__(self):
        super().__init__()
        # The attribute names become the state_dict keys, so they must stay in sync with
        # the saved weights that are loaded via load_state_dict.
        self.conv = nn.Conv2d(1, 85, 3)
        self.linear1 = nn.Linear(85, 10)
        self.linear2 = nn.Linear(10, 1)

    def forward(self, x):
        x = F.relu(self.conv(x))  # (N, 85, 1, 1)
        x = torch.flatten(x, 1)  # (N, 85)
        x = F.relu(self.linear1(x))  # (N, 10)
        logits = self.linear2(x)  # (N, 1)
        return logits

In [None]:
model = CellPredictorNeuralNetwork().to(device)
# map_location loads the saved tensors directly onto the current device. Without it,
# weights saved on a CUDA machine fail to load on a CPU- or MPS-only machine.
# NOTE(review): torch.load unpickles the file; consider weights_only=True (torch>=1.13)
# if this file could ever come from an untrusted source.
model.load_state_dict(torch.load("model_weights.pth", map_location=device))
print(model)

Below, the loaded model takes about 130 microseconds to make each prediction - roughly 17 times longer (130 / 7.5 ≈ 17) than the rule-based method. Performing this for each cell in a 50 x 50 grid would take 130 * 50**2 = 325000 microseconds = 325 milliseconds ~= 0.3 seconds. This is larger than the time between display.flip calls, which is likely why the difference between model-based prediction and rule-based prediction is so noticeable.

In [None]:
%%timeit

cell_states = np.array([[1, 0, 0], [0, 0, 0], [1, 0, 1]])
# Convert to torch tensor, add batch and channel dimensions
X = torch.tensor(cell_states, dtype=torch.float).expand((1, 1, -1, -1))
with torch.no_grad():
    logit = model(X).item()
result = int(logit > 0.0)

## How to time models effectively

Does the model need to be trained to be timed accurately?

In [None]:
# A fresh, untrained instance of the same architecture, used only for timing.
model = CellPredictorNeuralNetwork()
model = model.to(device)

In [None]:
%%timeit

cell_states = np.array([[1, 0, 0], [0, 0, 0], [1, 0, 1]])
# Convert to torch tensor, add batch and channel dimensions
X = torch.tensor(cell_states, dtype=torch.float).expand((1, 1, -1, -1))
with torch.no_grad():
    logit = model(X).item()
result = int(logit > 0.0)

Looks like an untrained model takes about the same amount of time to run as a trained model. So, we can try different model architectures without training them to compare inference times.

Does the data need to be representative or do all examples take the same amount of time to process?

In [None]:
class GameOfLifeDataset(torch.utils.data.Dataset):
    """Every possible 3x3 neighbourhood, labelled with the centre cell's next state.

    Index ``idx`` (0..511) is written as a 9-digit binary number, most significant
    bit first, and laid out row by row into a (1, 3, 3) float32 tensor. The target
    is a 1-element float32 tensor: 1.0 if the centre cell is alive in the next
    generation, otherwise 0.0.
    """

    _size = 512

    def __init__(self):
        """Nothing to set up: every example is generated on the fly from its index."""

    def __len__(self):
        return self._size

    def __getitem__(self, idx):
        # e.g. 12 -> "000001100"
        bits = format(idx, "09b")
        X = torch.tensor(list(map(float, bits)), dtype=torch.float32).reshape(1, 3, 3)  # (channels, width, height)
        centre = X[0, 1, 1]
        neighbours = X.sum() - centre
        if centre > 0.5:
            # Live cell: survives with 2 or 3 live neighbours.
            lives_on = 1.5 < neighbours < 3.5
        else:
            # Dead cell: born with exactly 3 live neighbours.
            lives_on = 2.5 < neighbours < 3.5
        y = torch.tensor([float(lives_on)], dtype=torch.float32)
        return X, y

In [None]:
dataset = GameOfLifeDataset()
# One example per batch, so each forward pass matches a single-cell inference call.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=1)

In [None]:
import time

def time_model(model, dataloader):
    """Return the average time, in seconds, of one forward pass per batch.

    Args:
        model: A ``torch.nn.Module``. Each batch is moved to the device of the
            model's first parameter (CPU if it has none) before timing.
        dataloader: An iterable of ``(X, y)`` batches that supports ``len()``;
            ``y`` is ignored.

    Returns:
        The mean forward-pass time per batch, in seconds. Only the ``model(X)``
        call is timed; data loading and host-to-device copies are excluded.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader must contain at least one batch")

    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    def _synchronize():
        # GPU kernels run asynchronously; wait for them so the timer measures the
        # computation itself rather than just the kernel launch.
        if model_device.type == "cuda":
            torch.cuda.synchronize()
        elif model_device.type == "mps" and hasattr(torch, "mps"):
            torch.mps.synchronize()

    time_taken = 0.0
    with torch.no_grad():
        for X, _ in dataloader:
            X = X.to(model_device)
            _synchronize()
            # perf_counter is a monotonic high-resolution clock; time.time() can have
            # millisecond-scale resolution, too coarse for ~100 µs forward passes.
            tic = time.perf_counter()
            model(X)
            _synchronize()
            toc = time.perf_counter()
            time_taken += toc - tic

    return time_taken / num_batches

# Time 100 full passes over the dataset; convert seconds to microseconds.
times = [1_000_000 * time_model(model, dataloader) for _ in range(100)]
avg_time, std_time = np.mean(times), np.std(times)
print(f"Average time: {avg_time:.2f} microseconds, std: {std_time:.2f} microseconds.")

Here we can see that inference time for a random example is roughly the same as the inference time for the specific example picked earlier. The times aren't exactly the same, but that's likely because of the different methodology between our custom timing code and the %%timeit command.

## Performance improvements

Now let's experiment with some performance improvements.

### Whole-grid prediction

Grid prediction could be done all at once by a fully-convolutional model, instead of one cell at a time.

In [None]:
class FullyConvolutionalCellPredictor(nn.Module):
    """Predicts the next state of every cell of a grid in a single forward pass.

    A 3x3 convolution reads each cell's neighbourhood; the two 1x1 convolutions then
    act as per-position fully-connected layers (85 -> 10 -> 1).

    Inputs:
        x: Tensor of shape (batch_size, 1, width + 2, height + 2). width and height are the
           dimensions of the game grid; the extra cell of padding on each side lets the
           unpadded 3x3 convolution produce predictions for the boundary cells too.

    Returns: Tensor of shape (batch_size, width, height), the logits of the predicted states.
    """

    def __init__(self):
        super().__init__()
        self.conv0 = nn.Conv2d(1, 85, 3)  # neighbourhood features
        self.conv1 = nn.Conv2d(85, 10, 1)  # per-position hidden layer
        self.conv2 = nn.Conv2d(10, 1, 1)  # per-position logit

    def forward(self, x):
        # Drop the singleton channel dimension at the end: (N, 1, W, H) -> (N, W, H).
        return self.conv2(F.relu(self.conv1(F.relu(self.conv0(x))))).squeeze(dim=1)

fc_model = FullyConvolutionalCellPredictor()

How do we time this model? We need some input data that looks like an entire grid, plus a one-cell border of boundary cells.

In [None]:
def make_fc_model_input(size=50):
    """Return a random padded grid of shape (1, 1, size + 2, size + 2).

    The inner size x size cells are drawn uniformly from {0.0, 1.0}. A one-cell
    border of zeros surrounds them, and leading batch and channel dimensions
    (each of size 1) are added, matching the input FullyConvolutionalCellPredictor takes.
    """
    grid = torch.randint(0, 2, (size, size), dtype=torch.float)
    bordered = F.pad(grid, (1, 1, 1, 1))  # zero padding: left, right, top, bottom
    return bordered[None, None]  # insert batch and channel dimensions

X = make_fc_model_input()
print(X.shape)
print(X)

In [None]:
X = make_fc_model_input(50)
%timeit fc_model(X)

The fully-convolutional model takes about 590 µs to make predictions on the whole grid, which is much faster (over 500 times faster!) than the 130 * 50**2 = 325,000 µs taken by running the single-cell model on the whole grid one cell at a time.

Varying the size (= width = height) of the grid to different powers of two gives the following data:

| Grid size | Inference time (µs) |
| --------- | ------------------- |
| 1 | 135 |
| 2 | 141 |
| 4 | 148 |
| 8 | 150 |
| 16 | 291 |
| 32 | 375 |
| 64 | 1900 |
| 128 | 6770 |
| 256 | 32200 |
| 512 | 121000 |
| 1024 | 478000 |
| 2048 | 1890000 |

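Plotting this data on a log-log plot and fitting a single straight line over the whole range gives a shallow slope (a fit of roughly $ \text{Time} = O((\text{Grid size})^{1.12}) $ was estimated). However, that fit is dragged down by the small grids, where fixed per-call overhead dominates and the time barely changes. From a grid size of 128 upwards, each doubling of the grid size multiplies the time by about 4 (e.g. 478000 → 1890000 µs going from 1024 to 2048). So asymptotically $ \text{Time} = O((\text{Grid size})^{2}) $, which is linear in the number of cells, as expected since every cell still has to be computed. Single-cell prediction is also at least $ O((\text{Grid size})^{2}) $, because it requires the model to be run separately on each cell. The fully-convolutional model therefore wins by a large constant factor (one call for the whole grid instead of one call per cell), not by a better growth rate.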

### Batched prediction

Instead of predicting one cell at a time, we can feed a batch of cells into the model and get the predictions all at once.

To make use of this, the game code wrapping this model would have to be updated so it batches up the cells and then reads the results out from the batch. This extra preprocessing and post-processing may introduce additional overhead and would be fiddly to code. Therefore, it seems most sensible to focus on the fully-convolutional approach for now.

## Training a fully-convolutional model

In [None]:
# We need a slightly different dataset because the fully-convolutional model keeps spatial
# dimensions in its output: for a 3x3 input it returns logits of shape (batch_size, 1, 1),
# because the channels dimension is squeezed out. The targets from this dataset therefore have
# shape (width, height) = (1, 1) per example, with no channels dimension. They batch to
# (batch_size, 1, 1), matching the model output.

class FullyConvolutionalDataset(torch.utils.data.Dataset):
    """Every possible 3x3 neighbourhood, labelled for the fully-convolutional model.

    Index ``idx`` (0..511) is written as a 9-digit binary number, most significant
    bit first, and laid out row by row into a (1, 3, 3) float32 input. The target is
    a (1, 1) float32 tensor: 1.0 if the centre cell is alive in the next generation,
    otherwise 0.0.
    """

    _size = 512

    def __init__(self):
        pass

    def __len__(self):
        return self._size

    def __getitem__(self, idx):
        # 12 -> "0b1100" -> "1100" -> "000001100"
        idx_bin = bin(idx)[2:].rjust(9, "0")
        X = torch.tensor([float(ch) for ch in idx_bin], dtype=torch.float32).reshape(1, 3, 3) # (channels, width, height)
        alive = X[0, 1, 1] > 0.5
        alive_neighbours = torch.sum(X) - X[0, 1, 1]
        # Live cells survive with 2 or 3 live neighbours; dead cells are born with exactly 3.
        next_alive = (alive and alive_neighbours > 1.5 and alive_neighbours < 3.5) or (not alive and alive_neighbours > 2.5 and alive_neighbours < 3.5)
        y = torch.tensor([[float(next_alive)]], dtype=torch.float32) # (width, height) = (1, 1)
        return X, y

In [None]:
# All 512 neighbourhoods, shuffled, in batches of 4.
dataset = FullyConvolutionalDataset()
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=4)

In [None]:
# Fresh fully-convolutional model on the selected device, ready for training.
model = FullyConvolutionalCellPredictor()
model = model.to(device)
print(model)

In [None]:
X, y = next(iter(dataloader))
# The model lives on `device`; the DataLoader yields CPU tensors, so move the inputs across.
X = X.to(device)
with torch.no_grad():
    logits = model(X)[0]  # first example of the batch: shape (1, 1) for a 3x3 input
pred_probab = torch.sigmoid(logits)
y_pred = int(pred_probab > 0.5)
print(f"Predicted state: {y_pred}")

In [None]:
print(f"Model structure: {model}\n\n")

# For each parameter, show its size and its first two slices along dimension 0.
for layer_name, tensor in model.named_parameters():
    print(f"Layer: {layer_name} | Size: {tensor.size()} | Values : {tensor[:2]} \n")

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Train ``model`` for one epoch and return the mean per-batch training loss.

    Args:
        dataloader: A ``DataLoader`` of ``(X, y)`` batches; ``len(dataloader.dataset)``
            is used only for progress reporting.
        model: The module to train. Each batch is moved to the device of its first
            parameter (CPU if it has none).
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        The average of ``loss.item()`` over all batches of the epoch. Dividing by the
        number of batches (not the number of examples) keeps this on the same scale
        as ``test_loop``'s loss, so the two curves are comparable.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader must contain at least one batch")

    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss = 0.0
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(model_device), y.to(model_device)

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        # Count actual examples so a smaller final batch is reported correctly.
        seen += len(X)

        if batch % 20 == 0:
            print(f"loss: {batch_loss:>7f}  [{seen:>5d}/{size:>5d}]")

    return total_loss / num_batches


def test_loop(dataloader, model, loss_fn):
    """Evaluate ``model`` on every batch of ``dataloader`` without updating it.

    Args:
        dataloader: A ``DataLoader`` of ``(X, y)`` batches with logit-style targets of
            the same shape as the model output.
        model: The module to evaluate. Each batch is moved to the device of its first
            parameter (CPU if it has none). The model is put in eval mode for the
            evaluation, and its previous train/eval mode is restored afterwards.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.

    Returns:
        ``{"loss": mean per-batch loss, "acc": number of correctly predicted target
        elements divided by len(dataloader.dataset)}``. A prediction is positive when
        sigmoid(logit) > 0.5.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader must contain at least one batch")

    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    test_loss, correct = 0.0, 0.0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(model_device), y.to(model_device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                prob = torch.sigmoid(pred)
                correct += ((prob > 0.5).type(torch.float) == y).type(torch.float).sum().item()
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return {
        "loss": test_loss,
        "acc": correct
    }

In [None]:
learning_rate = 1e-1
batch_size = 4
epochs = 20

# Build the loader from `batch_size` here, so the hyperparameter above actually
# controls the batch size used for training and evaluation.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

loss_history = []
test_loss_history = []
test_acc_history = []

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    # Since our dataset represents the full set of possible states, we can safely use it for both training and testing
    epoch_loss = train_loop(dataloader, model, loss_fn, optimizer)
    loss_history.append(epoch_loss)
    test_metrics = test_loop(dataloader, model, loss_fn)
    test_loss_history.append(test_metrics["loss"])
    test_acc_history.append(test_metrics["acc"])

print("Done!")

In [None]:
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 2, figsize=(10, 8), constrained_layout=True)
axs = axs.flatten()
ax_loss, ax_test_loss, ax_test_acc, ax_blank = axs

# (axes, per-epoch values, title, y-axis label) for each populated panel.
panels = [
    (ax_loss, loss_history, "Training loss", "loss"),
    (ax_test_loss, test_loss_history, "Test loss", "loss"),
    (ax_test_acc, test_acc_history, "Test accuracy", "accuracy"),
]
for ax, history, title, ylabel in panels:
    ax.plot(history)
    ax.set_title(title)
    ax.set_xlabel("epoch")
    ax.set_ylabel(ylabel)

# Only three panels are needed; hide the fourth.
ax_blank.axis("off")

# plt.plot() with no arguments draws nothing; plt.show() renders the figure.
plt.show()