# Experiment 025

In this experiment, we'll try out different weight initializations and see if they improve the performance. We'll continue using a learning rate of 1e-4.

In [1]:
import os
from pathlib import Path
import shutil
import datetime

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

In [2]:
class RecordingDataset(Dataset):
    def __init__(self, path: str):
        self.path = path
        if not os.path.exists(path):
            raise FileNotFoundError()
        with os.scandir(self.path) as it:
            entry: os.DirEntry = next(iter(it))
            _, self.ext = os.path.splitext(entry.name)
            self.highest_index = max((int(Path(file.path).stem) for file in it), default=-1)

    def __len__(self):
        return self.highest_index + 1

    def __getitem__(self, idx):
        file = os.path.join(self.path, f"{idx}{self.ext}")
        if not os.path.exists(file):
            raise IndexError()
        boards = np.load(file)

        def transform(board):
            board = torch.tensor(board, dtype=torch.long)
            board = F.one_hot(board, 2) # One-hot encode the cell types
            board = board.type(torch.float) # Convert to floating-point
            board = board.permute((2, 0, 1)) # Move channels/classes to dimension 0
            return board

        x = transform(boards[-2]) # Ignore all boards except the last two
        y = transform(boards[-1])
        return x, y
        

In [3]:
train_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "train"))
test_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "test"))
batch_size = 4
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

x, y = next(iter(train_dataloader))
print(x.shape, x.dtype)
print(y.shape, y.dtype)

torch.Size([4, 2, 22, 10]) torch.float32
torch.Size([4, 2, 22, 10]) torch.float32


In [4]:
device = (
    "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using {device} device")

Using cpu device


In [5]:
class TetrisModel(nn.Module):
    """Predicts the next state of the cells.

    Inputs:
        x: Tensor of float32 of shape (batch_size, channels, height, width). channels = 2 is the one-hot encoding of cell types, with
           0 for empty cells and 1 for filled cells. height = 22 and width = 10 are the dimensions of the game board. The entries
           should be 0 for empty cells and 1 for filled cells.
        z: Tensor of float32 of shape (batch_size, 4). The entries should be random numbers sampled from a uniform distribution.
    
    Returns: Tensor of float32 of shape (batch_size, height, width), logits for the new cells. Probabilities close to 0 (negative logits)
             correspond to empty cells, and probabilities close to 1 (positive logits) correspond to filled cells.
    """

    def __init__(self):
        super().__init__()
        self.loc = nn.Sequential(
            nn.Conv2d(6, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )
        self.glob = nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(160, 10)
        )
        self.head = nn.Sequential(
            nn.Conv2d(26, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 2, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, x, z):
        batch_size, channels, height, width = x.shape
        
        z = z[:, :, None, None] # Expand dims to match x
        z = z.repeat(1, 1, height, width) # Upscale to image size
        x = torch.cat((x, z), dim=1)

        x = self.loc(x)

        x_glob = self.glob(x)
        x_glob = x_glob[:, :, None, None] # Expand dims
        x_glob = x_glob.repeat(1, 1, height, width) # Upscale to image size
        x = torch.cat((x, x_glob), dim=1)

        y = self.head(x)
        return y

In [6]:
class TetrisDiscriminator(nn.Module):
    """A discriminator for the cell state predictions. Assesses the output of the generator.

    Inputs:
        x: Tensor of float32 of shape (batch_size, channels, height, width). channels = 2 is the one-hot encoding of cell types, with
           0 for empty cells and 1 for filled cells. height = 22 and width = 10 are the dimensions of the game board. The entries
           should be 0 for empty cells and 1 for filled cells.
        y: Tensor of float32 of shape (batch_size, channels, height, width), as with x. This should be either the output of the
           generator (with exp applied) or the one-hot encoding of the ground truth of the next cell states.
    
    Returns: Tensor of float32 of shape (batch_size, 1), decisions on whether the data are real or fake. Probabilities close to 0 (negative logits)
             correspond to fake data, and probabilities close to 1 (positive logits) correspond to real data.
    """

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
                nn.Conv2d(4, 16, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(16, 16, 3),
                nn.ReLU(),
                nn.Conv2d(16, 16, 3),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(112, 16),
                nn.ReLU(),
                nn.Linear(16, 1),
                nn.Flatten(start_dim=0)
            )
    
    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits

In [7]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [8]:
gen = TetrisModel().to(device)
disc = TetrisDiscriminator().to(device)

with torch.no_grad():
    X, y = next(iter(train_dataloader))
    z = torch.rand(batch_size, 4)
    y_gen = gen(X, z)
    pred_on_real = F.sigmoid(disc(X, y)[0])
    pred_on_fake = F.sigmoid(disc(X, y_gen)[0])
    print(f"Number of generator parameters: {count_parameters(gen)}")
    print(f"Number of discriminator parameters: {count_parameters(disc)}")
    print(f"Predicted label for real data: {pred_on_real}")
    print(f"Predicted label for fake data: {pred_on_fake}")

Number of generator parameters: 17996
Number of discriminator parameters: 7057
Predicted label for real data: 0.4499087631702423
Predicted label for fake data: 0.44912049174308777


# Understanding default initialization

According to [this page](https://discuss.pytorch.org/t/how-are-layer-weights-and-biases-initialized-by-default/13073) and [this page](https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/conv.py#L146), Linear and Conv2D layer weights and biases are both are initialized by sampling from $ \text{Uniform} \left( -\frac{1}{\sqrt{k}}, \frac{1}{\sqrt{k}} \right) $, where $ k $ is the "fan in" for the node.

For Linear, the "fan in" is `weight.size(1)` (i.e. `layer.size(0)`) and for Conv2D, it is `weight.size(1) * prod(*kernel_size)`. We aren't using biases for Conv2D because we're using batch normalization.

In [164]:
def describe_parameters(model):
    with torch.no_grad():
        for name, param in model.named_parameters():
            param_var = torch.var(param.data)
            param_shape = tuple(param.data.size())
            print(f"Parameter {name}: shape={param_shape}, var={param_var:.6f}, min={torch.min(param.data):.6f}, max={torch.max(param.data):.6f}")

In [165]:
print("Describing generator parameters:")
gen = TetrisModel().to(device)
describe_parameters(gen)

Describing generator parameters:
Parameter loc.0.weight: shape=(16, 6, 3, 3), var=0.006143, min=-0.136041, max=0.135677
Parameter loc.1.weight: shape=(16,), var=0.000000, min=1.000000, max=1.000000
Parameter loc.1.bias: shape=(16,), var=0.000000, min=0.000000, max=0.000000
Parameter loc.3.weight: shape=(16, 16, 3, 3), var=0.002333, min=-0.083331, max=0.083284
Parameter loc.4.weight: shape=(16,), var=0.000000, min=1.000000, max=1.000000
Parameter loc.4.bias: shape=(16,), var=0.000000, min=0.000000, max=0.000000
Parameter glob.0.weight: shape=(16, 16, 3, 3), var=0.002368, min=-0.083251, max=0.083316
Parameter glob.1.weight: shape=(16,), var=0.000000, min=1.000000, max=1.000000
Parameter glob.1.bias: shape=(16,), var=0.000000, min=0.000000, max=0.000000
Parameter glob.4.weight: shape=(16, 16, 3, 3), var=0.002329, min=-0.083330, max=0.083303
Parameter glob.5.weight: shape=(16,), var=0.000000, min=1.000000, max=1.000000
Parameter glob.5.bias: shape=(16,), var=0.000000, min=0.000000, max=0.0

The Linear and Conv2D parameters are consistent with our understanding. The batch normalization layers are initialized to have all biases 0 and all weights 1.

In [166]:
print("Describing discriminator parameters:")
disc = TetrisDiscriminator().to(device)
describe_parameters(disc)

Describing discriminator parameters:
Parameter body.0.weight: shape=(16, 4, 3, 3), var=0.008941, min=-0.166594, max=0.166392
Parameter body.0.bias: shape=(16,), var=0.009371, min=-0.158820, max=0.165020
Parameter body.3.weight: shape=(16, 16, 3, 3), var=0.002247, min=-0.083271, max=0.083245
Parameter body.3.bias: shape=(16,), var=0.002015, min=-0.077155, max=0.069286
Parameter body.5.weight: shape=(16, 16, 3, 3), var=0.002291, min=-0.083271, max=0.083330
Parameter body.5.bias: shape=(16,), var=0.002536, min=-0.067519, max=0.072143
Parameter body.8.weight: shape=(16, 112), var=0.002949, min=-0.094480, max=0.094468
Parameter body.8.bias: shape=(16,), var=0.002630, min=-0.080585, max=0.087770
Parameter body.10.weight: shape=(1, 16), var=0.021753, min=-0.237778, max=0.216132
Parameter body.10.bias: shape=(1,), var=nan, min=0.135435, max=0.135435


The default initialization doesn't seem tailored to our use case, according to the PReLU / Kaiming He initialization paper (https://arxiv.org/abs/1502.01852).

# Custom initialization function

Let's write a custom initialization function. I'd like to use Kaiming He initialization, especially on the convolutional layers. Here are the options we could vary:
* Initialize the biases to small random values too, or set the biases to be zero.
* Use custom initialization in the generator, discriminator, or both.
* Use a uniform or normal distribution for the normalization. (The Kaiming He paper focuses on the variance of the weights, so we could realise this with either distribution).
* Use Kaiming He initialization in "fan in" mode or "fan out" mode.

In [253]:
def get_init_fn(zero_bias, fan_mode):
    
    def init_weights(module):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight, nonlinearity="relu", mode=fan_mode)
            if zero_bias and module.bias is not None:
                module.bias.data.fill_(0.0)
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_uniform_(module.weight, nonlinearity="relu", mode=fan_mode)
            if zero_bias and module.bias is not None:
                module.bias.data.fill_(0.0)
    
    return init_weights


init_fn = get_init_fn(zero_bias=True, fan_mode="fan_in")

conv = nn.Conv2d(16, 6, kernel_size=3)
describe_parameters(conv)
conv.apply(init_fn)
describe_parameters(conv)

Parameter weight: shape=(6, 16, 3, 3), var=0.002387, min=-0.083272, max=0.083010
Parameter bias: shape=(6,), var=0.003762, min=-0.075199, max=0.076754
Parameter weight: shape=(6, 16, 3, 3), var=0.014088, min=-0.202931, max=0.203769
Parameter bias: shape=(6,), var=0.000000, min=0.000000, max=0.000000


# Training the model

In [247]:
import itertools

def find_interesting_examples(dataset, num=3):
    num_spawns = num
    
    def inner():
        num_spawns_left = num_spawns

        for x, y in dataset:
            # Check for block spawn
            if (x.argmax(0)[0] == 0).all() & (y.argmax(0)[0] == 1).any():
                if num_spawns_left > 0:
                    num_spawns_left -= 1
                    yield x, y
                else:
                    continue
            
    return list(itertools.islice(inner(), num))

In [248]:
def render_prediction(x, pred, y):
    """Renders an example and prediction into a single-image array.
    
    Inputs:
        x: Tensor of shape (height, width), the model input.
        pred: Tensor of shape (height, width), the model prediction.
        y: Tensor of shape (height, width), the target.
    """
    assert len(x.shape) == 2, f"Expected tensors of shape (width, height) but got {x.shape}"
    assert x.shape == pred.shape, f"Shapes do not match: {x.shape} != {pred.shape}"
    assert x.shape == y.shape, f"Shapes do not match: {x.shape} != {y.shape}"
    height, width = x.shape
    with torch.no_grad():
        separator = torch.ones(height, 1, dtype=x.dtype)
        return torch.cat((x, separator, pred, separator, y), dim=-1)

In [249]:
blocks = [
    torch.tensor(
        [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # I

    torch.tensor(
        [[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # O

    torch.tensor(
        [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # J

    torch.tensor(
        [[0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # T

    torch.tensor(
        [[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # S

    torch.tensor(
        [[0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # L

    torch.tensor(
        [[0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int) # Z
]

def get_valid_block_spawns(classes_X, classes_y_fake):
    """Determines whether predicted block spawns have a valid shape.
    
    Inputs:
        classes_X: Tensor of int32 of shape (batch_size, height, width), the first time step (with argmax applied on cell types).
        classes_y_fake: Tensor of int32 of shape (batch_size, height, width), the model's prediction (with argmax applied on cell types).

    Returns: Tensor of bool of shape (batch_size,), whether the items are predicted block spawns AND valid.
    """
    with torch.no_grad():
        batch_size = classes_X.size(0)
        ret = torch.full((batch_size,), False)

        # Take difference to see which cells are full but weren't before.
        diff = classes_y_fake - classes_X

        # It's only a valid block spawn if the change in the first 3 rows matches
        # one of the valid configurations.
        for block in blocks:
            ret |= (diff[:, :3, :] == block).all(-1).all(-1)
        
        return ret


In [250]:
real_label = 1.0
fake_label = 0.0


def train_loop(dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc):
    gen.train()
    disc.train()

    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(train_dataloader):
        ##################################################################
        # (1) Update discriminator: minimize -log(D(x)) - log(1 - D(G(z)))
        ##################################################################
        disc.zero_grad()

        ## Train with all-real batch
        # Format batch
        X, y = X.to(device), y.to(device)
        batch_size = X.size(0)
        real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through discriminator
        output = torch.flatten(disc(X, y))
        # Calculate loss on all-real batch
        err_disc_real = loss_fn(output, real_labels)
        # Calculate gradients for discriminator in backward pass
        err_disc_real.backward()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        z = torch.rand(batch_size, 4)
        # Generate fake image batch with generator
        y_fake = gen(X, z)
        fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)
        # Classify all fake batch with discriminator
        output = torch.flatten(disc(X, y_fake.detach()))
        # Calculate discriminator's loss on the all-fake batch
        err_disc_fake = loss_fn(output, fake_labels)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        err_disc_fake.backward()

        ## Update discriminator weights
        # Compute error of discriminator as sum over the fake and the real batches
        err_disc = err_disc_real + err_disc_fake
        # Update discriminator
        optimizer_disc.step()

        ##############################################
        # (2) Update generator: minimize -log(D(G(z)))
        ##############################################
        gen.zero_grad()
        # Since we just updated the discriminator, perform another forward pass of the all-fake batch through it
        output = torch.flatten(disc(X, y_fake))
        # Calculate the generator's loss based on this output
        # We use real labels because the generator wants to fool the discriminator
        err_gen = loss_fn(output, real_labels)
        # Calculate gradients for generator
        err_gen.backward()
        # Update generator
        optimizer_gen.step()

        # Output training stats
        if batch % 20 == 0:
            current = batch * dataloader.batch_size + batch_size
            print(f"[{current}/{size}] D loss: {err_disc.item():.4f}, G loss: {err_gen.item():.4f}")


def test_loop(split_name, dataloader, gen, disc, loss_fn, tb_writer, epoch, examples):
    gen.eval()
    disc.eval()

    loss_disc = 0.0
    loss_gen = 0.0
    disc_accuracy = 0.0
    cell_accuracy = 0.0
    board_accuracy = 0.0
    spawn_recall = 0.0
    num_spawns = 0.0
    spawn_validity = 0.0
    num_predicted_spawns = 0.0
    spawn_precision = 0.0
    scores_real = np.zeros(len(dataloader.dataset))
    scores_fake = np.zeros(len(dataloader.dataset))

    num_batches = len(dataloader)
    with torch.no_grad():        
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            
            batch_size = X.size(0)
            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)

            output_real = disc(X, y)
            loss_disc += loss_fn(output_real, real_labels).item()

            z = torch.rand(batch_size, 4)
            y_fake = gen(X, z)
            output_fake = disc(X, y_fake)
            
            loss_disc += loss_fn(output_fake, fake_labels).item()
            loss_gen += loss_fn(output_fake, real_labels).item()

            pred_real = (output_real > 0.0)
            pred_fake = (output_fake > 0.0)
            disc_accuracy += pred_real.type(torch.float).mean().item()
            disc_accuracy += (~pred_fake).type(torch.float).mean().item()

            classes_X = torch.argmax(X, dim=1)
            classes_y = torch.argmax(y, dim=1)
            classes_y_fake = torch.argmax(y_fake, dim=1)
            cell_accuracy += (classes_y_fake == classes_y).type(torch.float).mean().item()
            board_accuracy += (classes_y_fake == classes_y).all(-1).all(-1).type(torch.float).mean().item()

            actual_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y[:, 0, :] == 1).any(-1)
            predicted_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y_fake[:, 0, :] == 1).any(-1)
            num_true_positives = (actual_spawns & predicted_spawns).type(torch.float).sum().item()
            spawn_recall += num_true_positives
            spawn_precision += num_true_positives
            num_spawns += actual_spawns.type(torch.float).sum().item()
            valid_spawns = get_valid_block_spawns(classes_X, classes_y_fake)
            spawn_validity += valid_spawns.type(torch.float).sum().item()
            num_predicted_spawns += predicted_spawns.type(torch.float).sum().item()

            start_index = dataloader.batch_size * batch
            end_index = start_index + batch_size
            scores_real[start_index:end_index] = torch.sigmoid(output_real).numpy()
            scores_fake[start_index:end_index] = torch.sigmoid(output_fake).numpy()

    loss_disc /= num_batches
    loss_gen /= num_batches
    cell_accuracy /= num_batches
    board_accuracy /= num_batches
    spawn_recall /= num_spawns
    spawn_precision = np.nan if (num_predicted_spawns == 0.0) else spawn_precision / num_predicted_spawns
    disc_accuracy /= (2.0 * num_batches)
    spawn_validity = np.nan if (num_predicted_spawns == 0.0) else spawn_validity / num_predicted_spawns

    print(f"{split_name} error: \n D loss: {loss_disc:>8f}, G loss: {loss_gen:>8f}, D accuracy: {(100*disc_accuracy):>0.1f}%, cell accuracy: {(100*cell_accuracy):>0.1f}%, board accuracy: {(100*board_accuracy):>0.1f}% \n")

    tb_writer.add_scalar(f"Discriminator loss/{split_name}", loss_disc, epoch)
    tb_writer.add_scalar(f"Loss/{split_name}", loss_gen, epoch)
    tb_writer.add_scalar(f"Discriminator accuracy/{split_name}", disc_accuracy, epoch)
    tb_writer.add_scalar(f"Cell accuracy/{split_name}", cell_accuracy, epoch)
    tb_writer.add_scalar(f"Board accuracy/{split_name}", board_accuracy, epoch)
    tb_writer.add_scalar(f"Spawn recall/{split_name}", spawn_recall, epoch)
    tb_writer.add_scalar(f"Spawn precision/{split_name}", spawn_precision, epoch)
    tb_writer.add_scalar(f"Spawn validity/{split_name}", spawn_validity, epoch)

    with torch.no_grad():
        for i, (X, y) in enumerate(examples):
            X, y = X.unsqueeze(0), y.unsqueeze(0)
            z = torch.rand(1, 4)
            y_fake = gen(X, z)
            X, y, y_fake = X.squeeze(0), y.squeeze(0), y_fake.squeeze(0)
            X, y, y_fake = X.argmax(0), y.argmax(0), y_fake.argmax(0)
            img = render_prediction(X, y_fake, y)
            tb_writer.add_image(f"Predictions/{split_name}/{i}", img, epoch, dataformats="HW")
    
    tb_writer.add_histogram(f"Discriminator scores/{split_name}/real", scores_real, epoch)
    tb_writer.add_histogram(f"Discriminator scores/{split_name}/fake", scores_fake, epoch)


In [251]:
def train(run_name="", learning_rate=1e-4, epochs=100, zero_bias=False, fan_mode="fan_in"):
    gen = TetrisModel().to(device)
    disc = TetrisDiscriminator().to(device)

    init_fn = get_init_fn(zero_bias=zero_bias, fan_mode=fan_mode)
    gen.apply(init_fn)
    disc.apply(init_fn)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_025")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)
        gen_zero_grads = 0
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Weights/{name}", weight, epoch)
            if weight.grad is not None:
                tb_writer.add_histogram(f"Gradients/{name}", weight.grad, epoch)
                gen_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Zero gradients", gen_zero_grads, epoch)
        disc_zero_grads = 0
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
            disc_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Discriminator zero gradients", disc_zero_grads, epoch)

    tb_writer.close()
    print("Done!")

In [252]:
for zero_bias in [False, True]:
    for fan_mode in ["fan_in", "fan_out"]:
        run_name = f"ZB_{zero_bias}_FM_{fan_mode}"
        train(run_name=run_name, zero_bias=zero_bias, fan_mode=fan_mode)

Epoch 0
-------------------------------
[4/1762] D loss: 1.5795, G loss: 0.6018
[84/1762] D loss: 1.4371, G loss: 0.7317
[164/1762] D loss: 1.2823, G loss: 0.8106
[244/1762] D loss: 1.1704, G loss: 0.9172
[324/1762] D loss: 1.1056, G loss: 0.8183
[404/1762] D loss: 0.9570, G loss: 0.9315
[484/1762] D loss: 0.9237, G loss: 0.9371
[564/1762] D loss: 0.8048, G loss: 1.0012
[644/1762] D loss: 0.9164, G loss: 0.9583
[724/1762] D loss: 0.6617, G loss: 1.2998
[804/1762] D loss: 0.4900, G loss: 1.4743
[884/1762] D loss: 0.3579, G loss: 1.7398
[964/1762] D loss: 0.4884, G loss: 1.6869
[1044/1762] D loss: 0.6451, G loss: 1.5514
[1124/1762] D loss: 0.5755, G loss: 1.5912
[1204/1762] D loss: 0.3772, G loss: 1.9778
[1284/1762] D loss: 0.3811, G loss: 2.0983
[1364/1762] D loss: 0.4416, G loss: 1.7795
[1444/1762] D loss: 0.4798, G loss: 1.9658
[1524/1762] D loss: 0.3614, G loss: 2.4524
[1604/1762] D loss: 0.4262, G loss: 2.3321
[1684/1762] D loss: 0.4936, G loss: 1.7772
[1762/1762] D loss: 0.3858, G 

AttributeError: 'NoneType' object has no attribute 'data'

In [254]:
for zero_bias in [True]:
    for fan_mode in ["fan_in", "fan_out"]:
        run_name = f"ZB_{zero_bias}_FM_{fan_mode}"
        train(run_name=run_name, zero_bias=zero_bias, fan_mode=fan_mode)

Epoch 0
-------------------------------
[4/1762] D loss: 1.6817, G loss: 0.3644
[84/1762] D loss: 1.3869, G loss: 0.6778
[164/1762] D loss: 1.4101, G loss: 0.5751
[244/1762] D loss: 1.2696, G loss: 0.6878
[324/1762] D loss: 1.0476, G loss: 0.7896
[404/1762] D loss: 0.7619, G loss: 1.0153
[484/1762] D loss: 0.7485, G loss: 1.3508
[564/1762] D loss: 0.7114, G loss: 1.3831
[644/1762] D loss: 0.7037, G loss: 1.4669
[724/1762] D loss: 0.4612, G loss: 1.9203
[804/1762] D loss: 0.4378, G loss: 2.2233
[884/1762] D loss: 0.5533, G loss: 2.2048
[964/1762] D loss: 0.4521, G loss: 2.2171
[1044/1762] D loss: 0.4330, G loss: 2.7242
[1124/1762] D loss: 0.3608, G loss: 2.8024
[1204/1762] D loss: 0.2840, G loss: 3.1952
[1284/1762] D loss: 0.2989, G loss: 3.0980
[1364/1762] D loss: 0.3148, G loss: 2.9682
[1444/1762] D loss: 0.4032, G loss: 2.9932
[1524/1762] D loss: 0.2260, G loss: 3.4064
[1604/1762] D loss: 0.2569, G loss: 3.9677
[1684/1762] D loss: 0.3499, G loss: 3.5932
[1762/1762] D loss: 0.1581, G 

In [255]:
for zero_bias in [False, True]:
    for fan_mode in ["fan_in", "fan_out"]:
        run_name = f"ZB_{zero_bias}_FM_{fan_mode}"
        train(run_name=run_name, zero_bias=zero_bias, fan_mode=fan_mode, epochs=300)

Epoch 0
-------------------------------
[4/1762] D loss: 1.5075, G loss: 0.4152
[84/1762] D loss: 1.3148, G loss: 0.5636
[164/1762] D loss: 1.0318, G loss: 0.6562
[244/1762] D loss: 0.9976, G loss: 0.7413
[324/1762] D loss: 0.9575, G loss: 0.7170
[404/1762] D loss: 0.7545, G loss: 0.8901
[484/1762] D loss: 0.6869, G loss: 1.1073
[564/1762] D loss: 0.6234, G loss: 1.2257
[644/1762] D loss: 0.5034, G loss: 1.4947
[724/1762] D loss: 0.5329, G loss: 1.5756
[804/1762] D loss: 0.3964, G loss: 1.9713
[884/1762] D loss: 0.4336, G loss: 2.4447
[964/1762] D loss: 0.5281, G loss: 2.1964
[1044/1762] D loss: 0.4551, G loss: 2.1459
[1124/1762] D loss: 0.4125, G loss: 2.4917
[1204/1762] D loss: 0.5087, G loss: 2.2703
[1284/1762] D loss: 0.6410, G loss: 2.6732
[1364/1762] D loss: 0.4885, G loss: 2.6320
[1444/1762] D loss: 0.3194, G loss: 3.2964
[1524/1762] D loss: 0.3718, G loss: 3.7887
[1604/1762] D loss: 0.4479, G loss: 2.5095
[1684/1762] D loss: 0.2654, G loss: 3.6770
[1762/1762] D loss: 0.4341, G 

The results here look quite promising. In terms of board accuracy, all the custom initializations improve faster than the default initialization. The initializations that don't zero out the bias are more unstable, with occasional large jumps down in board accuracy, though these instabilities decrease in magnitude after 100 epochs. Generally, "fan out" mode with zero bias looks like the best initialization method.

In terms of spawn recall, all the custom initializations made a much better attempt at getting good spawn recall alongside high accuracy, though the spawn recall curves are very noisy. The spawn precision steadily rose (with noise) for some curves in the first 100 epochs, so I reran the experiment for 300 epochs as well. The curves where the bias hadn't been zeroed out initially eventually hit zero spawn recall and do not recover. Of the curves which did have bias zeroed out, the "fan in" mode one does hit zero spawn recall during training a few times but always recovers. The "fan out" mode curve never hits spawn recall zero after the first 50 epochs, and it always recovers from its lows. Both these curves stay about 60% spawn precision after 200 epochs. They also have around 80% spawn validity.

Clearly when doing custom initialization, zeroing out the bias seems to work better than not. The spawn recall and board accuracy also seem better than when using the default initialization. However, let's rerun the custom initialization against the default a few more times to make sure.

In [258]:
def compute_final_test_metrics(dataloader, gen):
    gen.eval()

    board_accuracy = 0.0
    spawn_recall = 0.0
    num_spawns = 0.0
    spawn_validity = 0.0
    num_predicted_spawns = 0.0
    spawn_precision = 0.0

    num_batches = len(dataloader)
    with torch.no_grad():        
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            
            batch_size = X.size(0)

            z = torch.rand(batch_size, 4)
            y_fake = gen(X, z)
            
            classes_X = torch.argmax(X, dim=1)
            classes_y = torch.argmax(y, dim=1)
            classes_y_fake = torch.argmax(y_fake, dim=1)
            board_accuracy += (classes_y_fake == classes_y).all(-1).all(-1).type(torch.float).mean().item()

            actual_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y[:, 0, :] == 1).any(-1)
            predicted_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y_fake[:, 0, :] == 1).any(-1)
            num_true_positives = (actual_spawns & predicted_spawns).type(torch.float).sum().item()
            spawn_recall += num_true_positives
            spawn_precision += num_true_positives
            num_spawns += actual_spawns.type(torch.float).sum().item()
            valid_spawns = get_valid_block_spawns(classes_X, classes_y_fake)
            spawn_validity += valid_spawns.type(torch.float).sum().item()
            num_predicted_spawns += predicted_spawns.type(torch.float).sum().item()

    board_accuracy /= num_batches
    spawn_recall /= num_spawns
    spawn_precision = np.nan if (num_predicted_spawns == 0.0) else spawn_precision / num_predicted_spawns
    spawn_validity = np.nan if (num_predicted_spawns == 0.0) else spawn_validity / num_predicted_spawns

    return {
        "board_accuracy": board_accuracy,
        "spawn_recall": spawn_recall,
        "spawn_precision": spawn_precision,
        "spawn_validity": spawn_validity,
    }

In [259]:
def train(run_name: str, custom_init: bool, fan_mode: str, epochs=100, learning_rate=1e-4):
    gen = TetrisModel().to(device)
    disc = TetrisDiscriminator().to(device)

    if custom_init:
        init_fn = get_init_fn(zero_bias=True, fan_mode=fan_mode)
        gen.apply(init_fn)
        disc.apply(init_fn)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_025")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)
        gen_zero_grads = 0
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Weights/{name}", weight, epoch)
            if weight.grad is not None:
                tb_writer.add_histogram(f"Gradients/{name}", weight.grad, epoch)
                gen_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Zero gradients", gen_zero_grads, epoch)
        disc_zero_grads = 0
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
            disc_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Discriminator zero gradients", disc_zero_grads, epoch)

    tb_writer.close()
    print("Done!")
    return compute_final_test_metrics(test_dataloader, gen)

In [261]:
def init_metrics():
    return {
        "board_accuracy": [],
        "spawn_recall": [],
        "spawn_precision": [],
        "spawn_validity": [],
    }

def copy_metrics(run_metrics, metrics):
    for name in ["board_accuracy", "spawn_recall", "spawn_precision", "spawn_validity"]:
        metrics[name].append(run_metrics[name])

baseline_metrics = init_metrics()
fan_in_metrics = init_metrics()
fan_out_metrics = init_metrics()

for i in range(3):
    run_metrics = train(run_name="baseline", custom_init=False, fan_mode=None, epochs=300)
    copy_metrics(run_metrics, baseline_metrics)
    run_metrics = train(run_name="ZB_True_FM_fan_in", custom_init=True, fan_mode="fan_in", epochs=300)
    copy_metrics(run_metrics, fan_in_metrics)
    run_metrics = train(run_name="ZB_True_FM_fan_out", custom_init=True, fan_mode="fan_out", epochs=300)
    copy_metrics(run_metrics, fan_out_metrics)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3902, G loss: 0.6080
[84/1762] D loss: 1.3832, G loss: 0.6123
[164/1762] D loss: 1.3725, G loss: 0.6178
[244/1762] D loss: 1.3437, G loss: 0.6305
[324/1762] D loss: 1.3164, G loss: 0.6475
[404/1762] D loss: 1.2642, G loss: 0.6783
[484/1762] D loss: 1.2000, G loss: 0.7289
[564/1762] D loss: 1.1141, G loss: 0.7745
[644/1762] D loss: 1.0164, G loss: 0.8650
[724/1762] D loss: 0.8450, G loss: 1.1013
[804/1762] D loss: 0.6391, G loss: 1.3917
[884/1762] D loss: 0.5564, G loss: 1.5377
[964/1762] D loss: 0.4168, G loss: 1.9368
[1044/1762] D loss: 0.4277, G loss: 2.1585
[1124/1762] D loss: 0.5023, G loss: 2.3791
[1204/1762] D loss: 0.3338, G loss: 2.4382
[1284/1762] D loss: 0.3095, G loss: 2.5720
[1364/1762] D loss: 0.3261, G loss: 2.9839
[1444/1762] D loss: 0.3123, G loss: 3.2847
[1524/1762] D loss: 0.2249, G loss: 3.1326
[1604/1762] D loss: 0.2717, G loss: 3.0876
[1684/1762] D loss: 0.2599, G loss: 3.4415
[1762/1762] D loss: 0.3794, G 

Visually, the graphs with custom initialization seem to outperform those without in terms of board accuracy, spawn recall, spawn precision and spawn validity. It is not just that the graphs for runs with custom initialization reach higher values - in fact most of the max values over the whole run are similar between the models - but the curves for models with custom initialization don't jump down as sharply, so the average values are higher and it gives a higher chance of the model having satisfactory performance at the epoch that the training is stopped.

To ensure that we make a reasoned choice, let's compare the final values of the key metrics between the different configurations.

In [280]:
from scipy.stats import ttest_ind

print(baseline_metrics)
print(fan_in_metrics)

print(ttest_ind(baseline_metrics["board_accuracy"], fan_in_metrics["board_accuracy"], equal_var=False, alternative="less"))
print(ttest_ind(baseline_metrics["spawn_recall"], fan_in_metrics["spawn_recall"], equal_var=False, alternative="less"))
print(ttest_ind(baseline_metrics["spawn_precision"], fan_in_metrics["spawn_precision"], equal_var=False, alternative="less"))
print(ttest_ind(baseline_metrics["spawn_validity"], fan_in_metrics["spawn_validity"], equal_var=False, alternative="less"))

{'board_accuracy': [0.875, 0.8795454545454545, 0.8795454545454545], 'spawn_recall': [0.175, 0.65, 0.5], 'spawn_precision': [0.5833333333333334, 1.0, 0.7692307692307693], 'spawn_validity': [0.0, 0.8461538461538461, 0.07692307692307693]}
{'board_accuracy': [0.9068181818181819, 0.8113636363636364, 0.884090909090909], 'spawn_recall': [0.0, 0.125, 0.0], 'spawn_precision': [nan, 0.1388888888888889, nan], 'spawn_validity': [nan, 0.8333333333333334, nan]}
Ttest_indResult(statistic=0.36791183185933857, pvalue=0.6259722179190691)
Ttest_indResult(statistic=2.7350538350313105, pvalue=0.9531246520195679)
Ttest_indResult(statistic=nan, pvalue=nan)
Ttest_indResult(statistic=nan, pvalue=nan)


Here, we cannot reject the null hypothesis according to the test, so there isn't clear evidence that we should prefer the fan_in variant over the baseline.

Now let's try the fan_out variant.

In [281]:
from scipy.stats import ttest_ind

print(baseline_metrics)
print(fan_out_metrics)

print(ttest_ind(baseline_metrics["board_accuracy"], fan_out_metrics["board_accuracy"], equal_var=False, alternative="less"))
print(ttest_ind(baseline_metrics["spawn_recall"], fan_out_metrics["spawn_recall"], equal_var=False, alternative="less"))
print(ttest_ind(baseline_metrics["spawn_precision"], fan_out_metrics["spawn_precision"], equal_var=False, alternative="less"))
print(ttest_ind(baseline_metrics["spawn_validity"], fan_out_metrics["spawn_validity"], equal_var=False, alternative="less"))

{'board_accuracy': [0.875, 0.8795454545454545, 0.8795454545454545], 'spawn_recall': [0.175, 0.65, 0.5], 'spawn_precision': [0.5833333333333334, 1.0, 0.7692307692307693], 'spawn_validity': [0.0, 0.8461538461538461, 0.07692307692307693]}
{'board_accuracy': [0.9090909090909091, 0.9022727272727272, 0.9090909090909091], 'spawn_recall': [0.0, 0.0, 0.0], 'spawn_precision': [nan, nan, nan], 'spawn_validity': [nan, nan, nan]}
Ttest_indResult(statistic=-10.539303728279362, pvalue=0.0004575975723908781)
Ttest_indResult(statistic=3.1505229808721493, pvalue=0.9561513718866513)
Ttest_indResult(statistic=nan, pvalue=nan)
Ttest_indResult(statistic=nan, pvalue=nan)


In terms of board accuracy, the fan_out variant clearly wins (though our sample size is small). However, there is something here I didn't notice before. Of the three fan_out runs, all of them have spawn recall zero at the end of the run, whereas at least some are nonzero with the baseline. The fan_in variant also looks bad here at first glance, but looking at the training curves, the fan_in variant still looks more stable (in spawn recall) than the baseline but we seem to have been unlucky with the stopping epoch. If we had stopped at, say, epoch 200, then it looks like we would have had a higher chance of getting a satisfactory fan_in model. Perhaps after this, the models encounter overfitting, mode collapse or another issue.

# Respecting the activation

As a next step, let's make the initialization depend on the activation function, since not all Conv2d / Linear layers are followed by ReLU - one layer in the generator is followed by Softmax, and the last layer of the discriminator is implicitly followed by Sigmoid. I've also noticed that the last layer of the generator's `glob` module doesn't have any activation function, so maybe that needs fixing. Finally, we can set biases to a small positive value instead of zero to avoid "dead" units at the start of the training (this is common advice with ReLU).

In [282]:
class GenWithSmartInit(nn.Module):
    def __init__(self):
        super().__init__()

        self.loc = nn.Sequential(
            nn.Conv2d(6, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )

        for i in [0, 3]:
            m = self.loc[i]
            nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")

        self.glob = nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(160, 10),
            nn.ReLU()
        )

        for i in [0, 4]:
            m = self.glob[i]
            nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
        
        m = self.glob[9]
        nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
        nn.init.constant_(m.bias, 0.01)

        self.head = nn.Sequential(
            nn.Conv2d(26, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 2, 1),
            nn.Softmax(dim=1)
        )

        for i in [0, 3, 6]:
            m = self.head[i]
            nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")

    def forward(self, x, z):
        batch_size, channels, height, width = x.shape
        
        z = z[:, :, None, None] # Expand dims to match x
        z = z.repeat(1, 1, height, width) # Upscale to image size
        x = torch.cat((x, z), dim=1)

        x = self.loc(x)

        x_glob = self.glob(x)
        x_glob = x_glob[:, :, None, None] # Expand dims
        x_glob = x_glob.repeat(1, 1, height, width) # Upscale to image size
        x = torch.cat((x, x_glob), dim=1)

        y = self.head(x)
        return y

In [283]:
class DiscWithSmartInit(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
                nn.Conv2d(4, 16, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(16, 16, 3),
                nn.ReLU(),
                nn.Conv2d(16, 16, 3),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(112, 16),
                nn.ReLU(),
                nn.Linear(16, 1),
                nn.Flatten(start_dim=0)
            )
        
        for i in [0, 3, 5, 8]:
            m = self.body[i]
            nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
            nn.init.constant_(m.bias, 0.01)
    
    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits

In [284]:
def train(run_name: str, gen_cls: nn.Module, disc_cls: nn.Module, epochs=100, learning_rate=1e-4):
    gen = gen_cls().to(device)
    disc = disc_cls().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_025")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)
        gen_zero_grads = 0
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Weights/{name}", weight, epoch)
            if weight.grad is not None:
                tb_writer.add_histogram(f"Gradients/{name}", weight.grad, epoch)
                gen_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Zero gradients", gen_zero_grads, epoch)
        disc_zero_grads = 0
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
            disc_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Discriminator zero gradients", disc_zero_grads, epoch)

    tb_writer.close()
    print("Done!")

In [285]:
train(run_name="smart_init", gen_cls=GenWithSmartInit, disc_cls=DiscWithSmartInit)

Epoch 0
-------------------------------
[4/1762] D loss: 1.5480, G loss: 0.4032
[84/1762] D loss: 1.4283, G loss: 0.5300
[164/1762] D loss: 1.3358, G loss: 0.5777
[244/1762] D loss: 1.2930, G loss: 0.5826
[324/1762] D loss: 1.2584, G loss: 0.6045
[404/1762] D loss: 1.2274, G loss: 0.5831
[484/1762] D loss: 1.1680, G loss: 0.6029
[564/1762] D loss: 1.0358, G loss: 0.6700
[644/1762] D loss: 0.9755, G loss: 0.6673
[724/1762] D loss: 0.9205, G loss: 0.6805
[804/1762] D loss: 0.8208, G loss: 0.7651
[884/1762] D loss: 0.7623, G loss: 0.7888
[964/1762] D loss: 0.8318, G loss: 0.8054
[1044/1762] D loss: 0.6268, G loss: 0.9621
[1124/1762] D loss: 0.6292, G loss: 1.1232
[1204/1762] D loss: 0.4983, G loss: 1.3276
[1284/1762] D loss: 0.4245, G loss: 1.5499
[1364/1762] D loss: 0.3337, G loss: 1.7937
[1444/1762] D loss: 0.2991, G loss: 2.0357
[1524/1762] D loss: 0.3915, G loss: 2.1592
[1604/1762] D loss: 0.3619, G loss: 2.5495
[1684/1762] D loss: 0.3323, G loss: 2.5705
[1762/1762] D loss: 0.2151, G 

In [288]:
train(run_name="smart_init", gen_cls=GenWithSmartInit, disc_cls=DiscWithSmartInit, epochs=300)

Epoch 0
-------------------------------
[4/1762] D loss: 1.4302, G loss: 0.4971
[84/1762] D loss: 1.3844, G loss: 0.5606
[164/1762] D loss: 1.3217, G loss: 0.5557
[244/1762] D loss: 1.2225, G loss: 0.6484
[324/1762] D loss: 1.1760, G loss: 0.6647
[404/1762] D loss: 1.1153, G loss: 0.6540
[484/1762] D loss: 0.9980, G loss: 0.7094
[564/1762] D loss: 0.8886, G loss: 0.7830
[644/1762] D loss: 0.7614, G loss: 0.8838
[724/1762] D loss: 0.6170, G loss: 1.0666
[804/1762] D loss: 0.5635, G loss: 1.2017
[884/1762] D loss: 0.5113, G loss: 1.3263
[964/1762] D loss: 0.3881, G loss: 1.6991
[1044/1762] D loss: 0.5150, G loss: 1.6651
[1124/1762] D loss: 0.4036, G loss: 1.8785
[1204/1762] D loss: 0.4050, G loss: 1.8413
[1284/1762] D loss: 0.4166, G loss: 2.2115
[1364/1762] D loss: 0.4231, G loss: 2.1704
[1444/1762] D loss: 0.4613, G loss: 2.1655
[1524/1762] D loss: 0.3109, G loss: 2.5156
[1604/1762] D loss: 0.2817, G loss: 2.7464
[1684/1762] D loss: 0.2633, G loss: 3.1566
[1762/1762] D loss: 0.4500, G 

In [289]:
train(run_name="smart_init", gen_cls=GenWithSmartInit, disc_cls=DiscWithSmartInit, epochs=300)
train(run_name="smart_init", gen_cls=GenWithSmartInit, disc_cls=DiscWithSmartInit, epochs=300)

Epoch 0
-------------------------------
[4/1762] D loss: 1.4017, G loss: 0.7401
[84/1762] D loss: 1.2772, G loss: 0.7452
[164/1762] D loss: 1.0795, G loss: 0.8004
[244/1762] D loss: 1.0160, G loss: 0.8316
[324/1762] D loss: 0.8129, G loss: 1.0013
[404/1762] D loss: 0.6928, G loss: 1.1339
[484/1762] D loss: 0.7068, G loss: 1.3336
[564/1762] D loss: 0.5667, G loss: 1.4359
[644/1762] D loss: 0.4554, G loss: 1.7969
[724/1762] D loss: 0.4525, G loss: 1.8288
[804/1762] D loss: 0.3008, G loss: 2.0978
[884/1762] D loss: 0.3396, G loss: 2.4096
[964/1762] D loss: 0.2397, G loss: 2.3941
[1044/1762] D loss: 0.2150, G loss: 2.5153
[1124/1762] D loss: 0.1906, G loss: 2.6784
[1204/1762] D loss: 0.1940, G loss: 2.9596
[1284/1762] D loss: 0.1498, G loss: 3.4117
[1364/1762] D loss: 0.1566, G loss: 2.4162
[1444/1762] D loss: 0.1206, G loss: 3.2591
[1524/1762] D loss: 0.1815, G loss: 3.3474
[1604/1762] D loss: 0.0906, G loss: 3.6867
[1684/1762] D loss: 0.0844, G loss: 3.2366
[1762/1762] D loss: 0.1197, G 

To do a fair comparison between the different weight initialization strategies, let's pick the optimal number of epochs to stop at. For each strategy, let's pick either 100, 200 or 300 epochs - whichever gives the best tradeoff between board accuracy, spawn recall and spawn precision - then write down the metrics of all these runs.

Here are the results:

| Configuration | Stopping epochs |
| ------------- | --------------- |
| Baseline | 300 |
| Kaiming fan_in | 200 |
| Kaiming fan_out | 200 |
| Smart init | 300 |

<p>(The metrics in the table below are averages of test metrics)</p>

| Configuration | Board accuracy | Spawn recall | Spawn precision | Spawn validity | Cell accuracy |
| ------------- | -------------- | ------------ | --------------- | -------------- | ------------- |
| Baseline | 0.8773 | 0.5188 | 0.7692 | 0.4223 | 0.9980 |
| Kaiming fan_in | 0.8620 | 0.5750 | 0.8248 | 0.7750 | 0.9977 |
| Kaiming fan_out | 0.8989 | 0.3250 | 0.9054 | 0.6490 | 0.9976 |
| Smart init | 0.8954 | 0.6083* | 0.9462 | 0.7811 | 0.9979 |

*One run of "smart init" stayed at 0 spawn recall for almost the entire run, so it seems like we got unlucky with the initialization. Still, this configuration seems to outperform the others regardless.

This table shows that custom initialization is not the silver bullet it looked like from earlier graphs. In particular, the "Kaiming fan_out" configuration looked good in the graphs, but this is because the poor spawn recall was masked by the fact that many graphs were overlaid on top of each other.

Overall, it looks like "smart init" is the best initialization scheme.

# Conclusion

Custom initialization is both theoretically sound (Kaiming He et al 2015) and empirically better. However, it is not a silver bullet and only gives incremental improvement. Of the initialization schemes tested, we prefer the one which uses Kaiming uniform initialization in `fan_in` mode before every ReLU layer and default initialization elsewhere. We also employ a couple of modifications: initialize biases before a ReLU to 0.01, and add ReLU to the end of the generator's `glob` module. We should incorporate these modifications into the main training notebook.