# Experiment 020

In this experiment, let's try to improve the generator's performance, in particular the spawn recall.

First, let's try smoothing the one-hot values of the second frame `y`. This makes the real data and the fake data a bit more similar, so that, hopefully, the discriminator won't "waste any energy" trying to identify real/fake data based on whether cells are marked as approximately 0/1 vs exactly 0/1.
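
As a minimal sketch of what this smoothing does to a one-hot target (the helper name `smooth_one_hot` and the toy 2×2 board are illustrative assumptions, not part of the experiment code): with a smoothing value of 0.1, every 1 becomes 0.9 and every 0 becomes 0.1, so the two channels of each cell still sum to 1 and the argmax is unchanged.

```python
import torch
import torch.nn.functional as F

def smooth_one_hot(y, smoothing):
    """Pull one-hot values away from exactly 0/1 (hypothetical helper)."""
    return y * (1.0 - smoothing) + (1.0 - y) * smoothing

# Toy 2x2 "board" of cell classes (0 = empty, 1 = filled); illustrative only
board = torch.tensor([[0, 1], [1, 0]])
y = F.one_hot(board, 2).type(torch.float).permute(2, 0, 1)  # (channels, height, width)
y_smooth = smooth_one_hot(y, 0.1)

print(y_smooth[1])  # Channel 1 ("filled") now holds 0.1 / 0.9 instead of 0 / 1
```

The same formula is applied to `y` inside `train_loop` and `test_loop` whenever `y_smoothing` is non-zero.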

In [2]:
import os
from pathlib import Path
import shutil
import datetime

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

In [3]:
class RecordingDataset(Dataset):
    def __init__(self, path: str):
        self.path = path
        if not os.path.exists(path):
            raise FileNotFoundError()
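        with os.scandir(self.path) as it:
            entries = list(it)  # Materialise the listing so that no entry is skipped below
        if not entries:
            raise FileNotFoundError(f"No recordings found in {path}")
        # All files are assumed to share the extension of the first entry
        _, self.ext = os.path.splitext(entries[0].name)
        self.highest_index = max(int(Path(entry.name).stem) for entry in entries)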

    def __len__(self):
        return self.highest_index + 1

    def __getitem__(self, idx):
        file = os.path.join(self.path, f"{idx}{self.ext}")
        if not os.path.exists(file):
            raise IndexError()
        boards = np.load(file)

        def transform(board):
            board = torch.tensor(board, dtype=torch.long)
            board = F.one_hot(board, 2) # One-hot encode the cell types
            board = board.type(torch.float) # Convert to floating-point
            board = board.permute((2, 0, 1)) # Move channels/classes to dimension 0
            return board

        x = transform(boards[-2]) # Ignore all boards except the last two
        y = transform(boards[-1])
        return x, y
        

In [4]:
train_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "train"))
test_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "test"))
batch_size = 4
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

x, y = next(iter(train_dataloader))
print(x.shape, x.dtype)
print(y.shape, y.dtype)

torch.Size([4, 2, 22, 10]) torch.float32
torch.Size([4, 2, 22, 10]) torch.float32


In [5]:
device = (
    "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using {device} device")

Using cpu device


In [6]:
class TetrisModel(nn.Module):
    """Predicts the next state of the cells.

    Inputs:
        x: Tensor of float32 of shape (batch_size, channels, height, width). channels = 2 is the one-hot encoding of cell types, with
           channel 0 for empty cells and channel 1 for filled cells. height = 22 and width = 10 are the dimensions of the game board.
           The entries should be 0 or 1.
        z: Tensor of float32 of shape (batch_size, 4). The entries should be random numbers sampled from a uniform distribution.
    
    Returns: Tensor of float32 of shape (batch_size, channels, height, width), softmax probabilities over cell types for the new cells.
             Channel 0 is the probability that a cell is empty, and channel 1 is the probability that it is filled.
    """

    def __init__(self):
        super().__init__()
        self.loc = nn.Sequential(
            nn.Conv2d(6, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )
        self.glob = nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(160, 10)
        )
        self.head = nn.Sequential(
            nn.Conv2d(26, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 2, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, x, z):
        batch_size, channels, height, width = x.shape
        
        z = z[:, :, None, None] # Expand dims to match x
        z = z.repeat(1, 1, height, width) # Upscale to image size
        x = torch.cat((x, z), dim=1)

        x = self.loc(x)

        x_glob = self.glob(x)
        x_glob = x_glob[:, :, None, None] # Expand dims
        x_glob = x_glob.repeat(1, 1, height, width) # Upscale to image size
        x = torch.cat((x, x_glob), dim=1)

        y = self.head(x)
        return y

In [26]:
class TetrisDiscriminator(nn.Module):
    """A discriminator for the cell state predictions. Assesses the output of the generator.

    Inputs:
        x: Tensor of float32 of shape (batch_size, channels, height, width). channels = 2 is the one-hot encoding of cell types, with
           channel 0 for empty cells and channel 1 for filled cells. height = 22 and width = 10 are the dimensions of the game board.
           The entries should be 0 or 1.
        y: Tensor of float32 of shape (batch_size, channels, height, width), as with x. This should be either the output of the
           generator (softmax probabilities) or the one-hot encoding of the ground truth of the next cell states.
    
    Returns: Tensor of float32 of shape (batch_size,), logits for whether the data are real or fake. Probabilities close to 0
             (negative logits) correspond to fake data, and probabilities close to 1 (positive logits) correspond to real data.
    """

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
                nn.Conv2d(4, 16, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(16, 16, 3),
                nn.ReLU(),
                nn.Conv2d(16, 16, 3),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(112, 16),
                nn.ReLU(),
                nn.Linear(16, 1),
                nn.Flatten(start_dim=0)
            )
    
    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits

In [8]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [9]:
gen = TetrisModel().to(device)
disc = TetrisDiscriminator().to(device)

with torch.no_grad():
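    X, y = next(iter(train_dataloader))
    X, y = X.to(device), y.to(device)  # Keep inputs on the same device as the models
    z = torch.rand(batch_size, 4, device=device)
    y_gen = gen(X, z)
    pred_on_real = torch.sigmoid(disc(X, y)[0])
    pred_on_fake = torch.sigmoid(disc(X, y_gen)[0])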
    print(f"Number of generator parameters: {count_parameters(gen)}")
    print(f"Number of discriminator parameters: {count_parameters(disc)}")
    print(f"Predicted label for real data: {pred_on_real}")
    print(f"Predicted label for fake data: {pred_on_fake}")

Number of generator parameters: 17996
Number of discriminator parameters: 7057
Predicted label for real data: 0.4737066626548767
Predicted label for fake data: 0.4734146296977997


In [10]:
import itertools

def find_interesting_examples(dataset, num=3):
    num_spawns = num
    
    def inner():
        num_spawns_left = num_spawns

        for x, y in dataset:
            # Check for block spawn
            if (x.argmax(0)[0] == 0).all() & (y.argmax(0)[0] == 1).any():
                if num_spawns_left > 0:
                    num_spawns_left -= 1
                    yield x, y
                else:
                    continue
            
    return list(itertools.islice(inner(), num))

In [11]:
def render_prediction(x, pred, y):
    """Renders an example and prediction into a single-image array.
    
    Inputs:
        x: Tensor of shape (height, width), the model input.
        pred: Tensor of shape (height, width), the model prediction.
        y: Tensor of shape (height, width), the target.
    """
    assert len(x.shape) == 2, f"Expected tensors of shape (height, width) but got {x.shape}"
    assert x.shape == pred.shape, f"Shapes do not match: {x.shape} != {pred.shape}"
    assert x.shape == y.shape, f"Shapes do not match: {x.shape} != {y.shape}"
    height, width = x.shape
    with torch.no_grad():
        separator = torch.ones(height, 1, dtype=x.dtype)
        return torch.cat((x, separator, pred, separator, y), dim=-1)

In [12]:
blocks = [
    torch.tensor(
        [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # I

    torch.tensor(
        [[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # O

    torch.tensor(
        [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # J

    torch.tensor(
        [[0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # T

    torch.tensor(
        [[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # S

    torch.tensor(
        [[0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # L

    torch.tensor(
        [[0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int) # Z
]

def get_valid_block_spawns(classes_X, classes_y_fake):
    """Determines whether predicted block spawns have a valid shape.
    
    Inputs:
        classes_X: Tensor of int32 of shape (batch_size, height, width), the first time step (with argmax applied on cell types).
        classes_y_fake: Tensor of int32 of shape (batch_size, height, width), the model's prediction (with argmax applied on cell types).

    Returns: Tensor of bool of shape (batch_size,), whether the items are predicted block spawns AND valid.
    """
    with torch.no_grad():
        batch_size = classes_X.size(0)
        ret = torch.full((batch_size,), False, device=classes_X.device)

        # Take difference to see which cells are full but weren't before.
        diff = classes_y_fake - classes_X

        # It's only a valid block spawn if the change in the first 3 rows matches
        # one of the valid configurations.
        for block in blocks:
            ret |= (diff[:, :3, :] == block.to(diff.device)).all(-1).all(-1)
        
        return ret


In [16]:
real_label = 1.0
fake_label = 0.0

def train_loop(dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc, y_smoothing=0.0):
    gen.train()
    disc.train()

    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        ##################################################################
        # (1) Update discriminator: minimize -log(D(x)) - log(1 - D(G(z)))
        ##################################################################
        disc.zero_grad()

        ## Train with all-real batch
        # Format batch
        X, y = X.to(device), y.to(device)
        if y_smoothing != 0.0:
            y = y * (1.0 - y_smoothing) + (1.0 - y) * y_smoothing
        batch_size = X.size(0)
        real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through discriminator
        output = torch.flatten(disc(X, y))
        # Calculate loss on all-real batch
        err_disc_real = loss_fn(output, real_labels)
        # Calculate gradients for discriminator in backward pass
        err_disc_real.backward()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        z = torch.rand(batch_size, 4, device=device)
        # Generate fake image batch with generator
        y_fake = gen(X, z)
        fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)
        # Classify all fake batch with discriminator
        output = torch.flatten(disc(X, y_fake.detach()))
        # Calculate discriminator's loss on the all-fake batch
        err_disc_fake = loss_fn(output, fake_labels)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        err_disc_fake.backward()

        ## Update discriminator weights
        # Compute error of discriminator as sum over the fake and the real batches
        err_disc = err_disc_real + err_disc_fake
        # Update discriminator
        optimizer_disc.step()

        ##############################################
        # (2) Update generator: minimize -log(D(G(z)))
        ##############################################
        gen.zero_grad()
        # Since we just updated the discriminator, perform another forward pass of the all-fake batch through it
        output = torch.flatten(disc(X, y_fake))
        # Calculate the generator's loss based on this output
        # We use real labels because the generator wants to fool the discriminator
        err_gen = loss_fn(output, real_labels)
        # Calculate gradients for generator
        err_gen.backward()
        # Update generator
        optimizer_gen.step()

        # Output training stats
        if batch % 20 == 0:
            current = batch * dataloader.batch_size + batch_size
            print(f"[{current}/{size}] D loss: {err_disc.item():.4f}, G loss: {err_gen.item():.4f}")


def test_loop(split_name, dataloader, gen, disc, loss_fn, tb_writer, epoch, examples, y_smoothing=0.0):
    gen.eval()
    disc.eval()

    loss_disc = 0.0
    loss_gen = 0.0
    disc_accuracy = 0.0
    cell_accuracy = 0.0
    board_accuracy = 0.0
    spawn_recall = 0.0
    num_spawns = 0.0
    spawn_validity = 0.0
    num_predicted_spawns = 0.0
    scores_real = np.zeros(len(dataloader.dataset))
    scores_fake = np.zeros(len(dataloader.dataset))

    num_batches = len(dataloader)
    with torch.no_grad():        
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            
            batch_size, channels, height, width = X.shape

            if y_smoothing != 0.0:
                # Apply the same deterministic smoothing as in train_loop
                y = y * (1.0 - y_smoothing) + (1.0 - y) * y_smoothing

            batch_size = X.size(0)
            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)

            output_real = disc(X, y)
            loss_disc += loss_fn(output_real, real_labels).item()

            z = torch.rand(batch_size, 4, device=device)
            y_fake = gen(X, z)
            output_fake = disc(X, y_fake)
            
            loss_disc += loss_fn(output_fake, fake_labels).item()
            loss_gen += loss_fn(output_fake, real_labels).item()

            pred_real = (output_real > 0.0)
            pred_fake = (output_fake > 0.0)
            disc_accuracy += pred_real.type(torch.float).mean().item()
            disc_accuracy += (~pred_fake).type(torch.float).mean().item()

            classes_X = torch.argmax(X, dim=1)
            classes_y = torch.argmax(y, dim=1)
            classes_y_fake = torch.argmax(y_fake, dim=1)
            cell_accuracy += (classes_y_fake == classes_y).type(torch.float).mean().item()
            board_accuracy += (classes_y_fake == classes_y).all(-1).all(-1).type(torch.float).mean().item()

            actual_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y[:, 0, :] == 1).any(-1)
            predicted_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y_fake[:, 0, :] == 1).any(-1)
            num_true_positives = (actual_spawns & predicted_spawns).type(torch.float).sum().item()
            spawn_recall += num_true_positives
            num_spawns += actual_spawns.type(torch.float).sum().item()
            valid_spawns = get_valid_block_spawns(classes_X, classes_y_fake)
            spawn_validity += valid_spawns.type(torch.float).sum().item()
            num_predicted_spawns += predicted_spawns.type(torch.float).sum().item()

            start_index = dataloader.batch_size * batch
            end_index = start_index + batch_size
            scores_real[start_index:end_index] = torch.sigmoid(output_real).cpu().numpy()
            scores_fake[start_index:end_index] = torch.sigmoid(output_fake).cpu().numpy()

    loss_disc /= num_batches
    loss_gen /= num_batches
    cell_accuracy /= num_batches
    board_accuracy /= num_batches
    spawn_recall = np.nan if (num_spawns == 0.0) else spawn_recall / num_spawns
    disc_accuracy /= (2.0 * num_batches)
    spawn_validity = np.nan if (num_predicted_spawns == 0.0) else spawn_validity / num_predicted_spawns

    print(f"{split_name} error: \n D loss: {loss_disc:>8f}, G loss: {loss_gen:>8f}, D accuracy: {(100*disc_accuracy):>0.1f}%, cell accuracy: {(100*cell_accuracy):>0.1f}%, board accuracy: {(100*board_accuracy):>0.1f}% \n")

    tb_writer.add_scalar(f"Discriminator loss/{split_name}", loss_disc, epoch)
    tb_writer.add_scalar(f"Loss/{split_name}", loss_gen, epoch)
    tb_writer.add_scalar(f"Discriminator accuracy/{split_name}", disc_accuracy, epoch)
    tb_writer.add_scalar(f"Cell accuracy/{split_name}", cell_accuracy, epoch)
    tb_writer.add_scalar(f"Board accuracy/{split_name}", board_accuracy, epoch)
    tb_writer.add_scalar(f"Spawn recall/{split_name}", spawn_recall, epoch)
    tb_writer.add_scalar(f"Spawn validity/{split_name}", spawn_validity, epoch)

    with torch.no_grad():
        for i, (X, y) in enumerate(examples):
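            X, y = X.unsqueeze(0).to(device), y.unsqueeze(0).to(device)
            z = torch.rand(1, 4, device=device)
            y_fake = gen(X, z)
            X, y, y_fake = X.squeeze(0), y.squeeze(0), y_fake.squeeze(0)
            # Move to the CPU for rendering, since render_prediction builds its separator there
            X, y, y_fake = X.argmax(0).cpu(), y.argmax(0).cpu(), y_fake.argmax(0).cpu()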
            img = render_prediction(X, y_fake, y)
            tb_writer.add_image(f"Predictions/{split_name}/{i}", img, epoch, dataformats="HW")
    
    tb_writer.add_histogram(f"Discriminator scores/{split_name}/real", scores_real, epoch)
    tb_writer.add_histogram(f"Discriminator scores/{split_name}/fake", scores_fake, epoch)


In [22]:
def train(run_name="", y_smoothing=0.0):
    learning_rate = 1e-3
    epochs = 100

    gen = TetrisModel().to(device)
    disc = TetrisDiscriminator().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_020")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc, y_smoothing=y_smoothing)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples, y_smoothing=y_smoothing)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples, y_smoothing=y_smoothing)
        gen_zero_grads = 0
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Gradients/{name}", weight.grad, epoch)
            gen_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Zero gradients", gen_zero_grads, epoch)
        disc_zero_grads = 0
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
            disc_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Discriminator zero gradients", disc_zero_grads, epoch)

    tb_writer.close()
    print("Done!")

In [19]:
train(run_name="smooth_0p1", y_smoothing=0.1)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3899, G loss: 0.6232
[84/1762] D loss: 1.2050, G loss: 0.7130
[164/1762] D loss: 0.8062, G loss: 1.6977
[244/1762] D loss: 0.9580, G loss: 1.0284
[324/1762] D loss: 1.3387, G loss: 0.8685
[404/1762] D loss: 1.2858, G loss: 0.7602
[484/1762] D loss: 1.3336, G loss: 0.7300
[564/1762] D loss: 1.1233, G loss: 1.0096
[644/1762] D loss: 1.6743, G loss: 0.7018
[724/1762] D loss: 1.4721, G loss: 0.8222
[804/1762] D loss: 1.5792, G loss: 0.6202
[884/1762] D loss: 1.3602, G loss: 0.7179
[964/1762] D loss: 1.4653, G loss: 0.6715
[1044/1762] D loss: 1.4695, G loss: 0.6754
[1124/1762] D loss: 1.4198, G loss: 0.6800
[1204/1762] D loss: 1.3984, G loss: 0.7196
[1284/1762] D loss: 1.4177, G loss: 0.6873
[1364/1762] D loss: 1.4315, G loss: 0.6952
[1444/1762] D loss: 1.3708, G loss: 0.7339
[1524/1762] D loss: 1.3827, G loss: 0.6832
[1604/1762] D loss: 1.4035, G loss: 0.6766
[1684/1762] D loss: 1.4119, G loss: 0.6814
[1762/1762] D loss: 1.4345, G 

Smoothing seems to make the discriminator unstable, and its score histograms become more spread out and overlapping. The generator performance is about the same as without smoothing, and the spawn recall does not improve.
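
As a reference point for reading the logged losses: `train_loop` reports the D loss as the sum of the real-batch and fake-batch binary cross-entropy terms, so a discriminator that cannot tell real from fake (logit 0 for every sample) scores about 1.386, and the matching G loss is about 0.693. The sketch below only reproduces that arithmetic in plain Python; `bce_with_logits` is an illustrative helper, not part of the notebook.

```python
import math

def bce_with_logits(logit, label):
    """Binary cross-entropy on a single logit, as nn.BCEWithLogitsLoss computes per element."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))

# Undecided discriminator: logit 0 (probability 0.5) for every sample
d_loss_chance = bce_with_logits(0.0, 1.0) + bce_with_logits(0.0, 0.0)  # real term + fake term
g_loss_chance = bce_with_logits(0.0, 1.0)  # The generator is scored against real labels

print(f"D loss at chance: {d_loss_chance:.4f}")  # 2 * ln 2
print(f"G loss at chance: {g_loss_chance:.4f}")  # ln 2
```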

# Monitor zero gradients

I've added a new metric to the training code: after each epoch, it counts the number of gradient entries that are exactly zero in both the generator and the discriminator.
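
To make the metric concrete, here is a small sketch on a toy layer (the helper name `count_zero_grads`, the hand-picked weights and the input are assumptions chosen for illustration). Two of the three ReLU units have negative pre-activations for this input, so every gradient entry that belongs to them is exactly zero.

```python
import torch
from torch import nn

def count_zero_grads(model):
    """Count gradient entries that are exactly zero (same formula as the per-epoch logging)."""
    return sum(p.grad.numel() - p.grad.count_nonzero().item() for p in model.parameters())

# Toy layer with hand-picked weights (illustrative assumption): unit 0 is active
# for the input below, units 1 and 2 have negative pre-activations.
layer = nn.Linear(2, 3)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[1.0, 1.0], [-1.0, -1.0], [1.0, 1.0]]))
    layer.bias.copy_(torch.tensor([0.0, 0.0, -10.0]))

x = torch.tensor([[1.0, 2.0]])
torch.relu(layer(x)).sum().backward()

num_zero = count_zero_grads(layer)
num_total = sum(p.numel() for p in layer.parameters())
print(f"{num_zero} / {num_total} gradient entries are zero")
```

Here 6 of the 9 entries are zero: two weights and one bias for each of the two inactive units.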

In [23]:
train(run_name="no_smoothing")

Epoch 0
-------------------------------
[4/1762] D loss: 1.3964, G loss: 0.6044
[84/1762] D loss: 1.1607, G loss: 0.7354
[164/1762] D loss: 0.3029, G loss: 2.8283
[244/1762] D loss: 0.0410, G loss: 5.4900
[324/1762] D loss: 0.3973, G loss: 6.0433
[404/1762] D loss: 0.0854, G loss: 4.8722
[484/1762] D loss: 0.6695, G loss: 4.4554
[564/1762] D loss: 0.6152, G loss: 6.2701
[644/1762] D loss: 0.4731, G loss: 4.1296
[724/1762] D loss: 0.5627, G loss: 4.1013
[804/1762] D loss: 0.5914, G loss: 2.3643
[884/1762] D loss: 0.8383, G loss: 2.9616
[964/1762] D loss: 0.4391, G loss: 2.0618
[1044/1762] D loss: 0.4119, G loss: 2.2814
[1124/1762] D loss: 2.2765, G loss: 0.3681
[1204/1762] D loss: 1.1637, G loss: 1.2333
[1284/1762] D loss: 1.3408, G loss: 1.1538
[1364/1762] D loss: 1.6352, G loss: 0.3607
[1444/1762] D loss: 0.7155, G loss: 1.0913
[1524/1762] D loss: 0.8522, G loss: 1.0842
[1604/1762] D loss: 0.9652, G loss: 0.9976
[1684/1762] D loss: 0.8149, G loss: 2.5588
[1762/1762] D loss: 0.5299, G 

The number of zero gradients in both models steadily climbs. After 100 epochs, 2880 / 17996 (16%) of the generator's gradients are zero, and 5357 / 7057 (76%) of the discriminator's gradients are zero.

# Leaky ReLU

Let's try using leaky ReLU in the discriminator and see if it reduces the number of zero gradients, and if that translates to an improvement in performance.
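
A quick sketch of why this should help, using three arbitrary input values (`input_grad` is an illustrative helper): plain ReLU passes no gradient for negative inputs, whereas leaky ReLU passes a gradient equal to the leak.

```python
import torch
from torch import nn

def input_grad(activation, values):
    """Gradient of activation(x).sum() with respect to x."""
    x = torch.tensor(values, requires_grad=True)
    activation(x).sum().backward()
    return x.grad

values = [-2.0, -0.5, 1.0]
relu_grad = input_grad(nn.ReLU(), values)
leaky_grad = input_grad(nn.LeakyReLU(0.1), values)

print(relu_grad)   # Zero for the negative inputs
print(leaky_grad)  # Equal to the leak (0.1) for the negative inputs
```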

In [27]:
# Define discriminator that uses leaky ReLU

class DiscWithLeakyReLU(nn.Module):
    def __init__(self, leak):
        super().__init__()
        self.body = nn.Sequential(
                nn.Conv2d(4, 16, 3, padding=1),
                nn.LeakyReLU(leak),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(16, 16, 3),
                nn.LeakyReLU(leak),
                nn.Conv2d(16, 16, 3),
                nn.LeakyReLU(leak),
                nn.Flatten(),
                nn.Linear(112, 16),
                nn.LeakyReLU(leak),
                nn.Linear(16, 1),
                nn.Flatten(start_dim=0)
            )
    
    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits

In [28]:
# Redefine training loop to use leaky ReLU

def train(run_name, leak):
    learning_rate = 1e-3
    epochs = 100

    gen = TetrisModel().to(device)
    disc = DiscWithLeakyReLU(leak).to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_020")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc, y_smoothing=0.0)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples, y_smoothing=0.0)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples, y_smoothing=0.0)
        gen_zero_grads = 0
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Gradients/{name}", weight.grad, epoch)
            gen_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Zero gradients", gen_zero_grads, epoch)
        disc_zero_grads = 0
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
            disc_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Discriminator zero gradients", disc_zero_grads, epoch)

    tb_writer.close()
    print("Done!")

In [29]:
train(run_name="LReLU_0p1", leak=0.1)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3850, G loss: 0.7165
[84/1762] D loss: 1.1689, G loss: 0.7529
[164/1762] D loss: 0.2074, G loss: 3.0603
[244/1762] D loss: 0.3808, G loss: 4.6779
[324/1762] D loss: 0.1221, G loss: 4.0611
[404/1762] D loss: 0.0846, G loss: 4.1780
[484/1762] D loss: 0.2582, G loss: 3.0850
[564/1762] D loss: 0.0694, G loss: 3.8776
[644/1762] D loss: 0.6929, G loss: 0.8410
[724/1762] D loss: 0.5694, G loss: 1.1098
[804/1762] D loss: 0.9283, G loss: 1.8484
[884/1762] D loss: 2.1984, G loss: 0.5692
[964/1762] D loss: 1.8836, G loss: 2.1332
[1044/1762] D loss: 0.9448, G loss: 1.0125
[1124/1762] D loss: 0.9535, G loss: 0.7981
[1204/1762] D loss: 0.9989, G loss: 1.9940
[1284/1762] D loss: 1.1616, G loss: 0.9524
[1364/1762] D loss: 1.3483, G loss: 0.9465
[1444/1762] D loss: 1.1970, G loss: 0.8780
[1524/1762] D loss: 0.9651, G loss: 1.4004
[1604/1762] D loss: 1.3323, G loss: 1.1833
[1684/1762] D loss: 1.4145, G loss: 1.2263
[1762/1762] D loss: 1.0553, G 

In [30]:
for leak in [0.01, 0.1, 0.2]:
    run_name = "LReLU_" + str(leak).replace(".", "p")
    train(run_name=run_name, leak=leak)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3909, G loss: 0.7230
[84/1762] D loss: 1.1247, G loss: 0.7592
[164/1762] D loss: 0.2634, G loss: 3.1230
[244/1762] D loss: 0.1594, G loss: 5.4866
[324/1762] D loss: 0.0802, G loss: 5.5689
[404/1762] D loss: 0.2223, G loss: 6.8260
[484/1762] D loss: 0.4439, G loss: 4.2998
[564/1762] D loss: 0.2115, G loss: 4.6934
[644/1762] D loss: 0.4044, G loss: 4.1957
[724/1762] D loss: 0.3183, G loss: 3.7658
[804/1762] D loss: 1.0453, G loss: 1.0787
[884/1762] D loss: 0.6515, G loss: 1.1139
[964/1762] D loss: 0.8461, G loss: 0.9658
[1044/1762] D loss: 0.3747, G loss: 2.1248
[1124/1762] D loss: 0.8628, G loss: 0.6585
[1204/1762] D loss: 0.9424, G loss: 1.3778
[1284/1762] D loss: 1.9373, G loss: 0.6898
[1364/1762] D loss: 0.8578, G loss: 1.0376
[1444/1762] D loss: 0.8388, G loss: 2.6960
[1524/1762] D loss: 0.8796, G loss: 1.1665
[1604/1762] D loss: 1.2876, G loss: 0.8187
[1684/1762] D loss: 0.7749, G loss: 1.8131
[1762/1762] D loss: 1.6671, G 

Regardless of the amount of "leak" in the discriminator, the generator gradients become zero at roughly the same rate. Of course, the discriminator gradients are barely ever zero when we use leaky ReLU.

In terms of generator board accuracy, the curves look basically the same for all values of leak in the discriminator that were tried (0.01, 0.1 and 0.2).

In terms of spawn recall, we see spawn recall jumping between ~0% and ~100% for discriminator leak 0.01, but this could just be a fluke of the run. Let's try 2 more runs and see how often it happens.

In [31]:
for i in range(2):
    train(run_name="LReLU_0p01", leak=0.01)

Epoch 0
-------------------------------
[4/1762] D loss: 1.4037, G loss: 0.5712
[84/1762] D loss: 1.2398, G loss: 0.5427
[164/1762] D loss: 0.8818, G loss: 0.6350
[244/1762] D loss: 0.2171, G loss: 1.9720
[324/1762] D loss: 0.6574, G loss: 4.8027
[404/1762] D loss: 0.6542, G loss: 3.9364
[484/1762] D loss: 0.2860, G loss: 4.0622
[564/1762] D loss: 0.4254, G loss: 2.7897
[644/1762] D loss: 0.6122, G loss: 3.4050
[724/1762] D loss: 0.4062, G loss: 3.3425
[804/1762] D loss: 0.9786, G loss: 1.9627
[884/1762] D loss: 0.7966, G loss: 3.6491
[964/1762] D loss: 0.8764, G loss: 2.1363
[1044/1762] D loss: 0.6998, G loss: 0.9334
[1124/1762] D loss: 0.7551, G loss: 1.3425
[1204/1762] D loss: 0.7082, G loss: 1.0690
[1284/1762] D loss: 0.2697, G loss: 2.0990
[1364/1762] D loss: 0.5889, G loss: 1.6427
[1444/1762] D loss: 1.2752, G loss: 1.8412
[1524/1762] D loss: 1.3477, G loss: 2.3651
[1604/1762] D loss: 0.4553, G loss: 2.6886
[1684/1762] D loss: 0.4956, G loss: 2.1004
[1762/1762] D loss: 1.2585, G 

2/3 runs with leak 0.01 have spawn recall that jumps between 0% and 100% instead of staying near 0%. In contrast, 0/3 runs with normal ReLU achieve this. This suggests that a small amount of leak may help, but the results are flaky.
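
To put a rough number on how little three runs per setting can tell us, here is a hand-rolled one-sided Fisher exact test on those counts. This is an added sanity check under the assumption that the six runs are independent, not part of the original experiment, and `hypergeom_upper_tail` is an illustrative helper.

```python
from math import comb

def hypergeom_upper_tail(k, successes, draws, population):
    """P(X >= k) when drawing `draws` items without replacement from `population`
    items, of which `successes` are marked."""
    total = comb(population, draws)
    upper = min(successes, draws)
    return sum(
        comb(successes, i) * comb(population - successes, draws - i)
        for i in range(k, upper + 1)
    ) / total

# 6 runs in total, 2 of which showed jumping spawn recall; 3 of the 6 used leak 0.01.
# Chance that both "jumping" runs land in the leaky group if leak made no difference:
p_value = hypergeom_upper_tail(k=2, successes=2, draws=3, population=6)
print(f"one-sided p = {p_value:.2f}")
```

A one-sided p of 0.20 means a split this lopsided would turn up by chance in about one in five such comparisons, so more runs would be needed before drawing a firm conclusion.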

# Learning rate

Let's try reinstating the original training setup (with the plain ReLU discriminator), but with a reduced learning rate. This is another way of addressing the dying ReLU problem.
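
The mechanism can be sketched on a single ReLU unit trained with plain gradient descent. This is a toy setup with made-up values, not the Adam-trained networks used here: one large step pushes the pre-activation below zero, after which the gradient is zero and the unit never recovers, while a smaller step size converges to the target.

```python
def train_single_relu(learning_rate, steps=50):
    """Plain gradient descent on one ReLU unit, relu(w * x + b), fitted to a single target.

    Toy setup (all values are illustrative assumptions): x = 1, target = 0.5, w = 1, b = 0.
    Returns the final pre-activation w * x + b.
    """
    x, target = 1.0, 0.5
    w, b = 1.0, 0.0
    for _ in range(steps):
        pre = w * x + b
        out = max(pre, 0.0)
        # d(loss)/d(pre) for loss = (out - target)^2; zero once the unit is inactive
        grad_pre = 2.0 * (out - target) if pre > 0.0 else 0.0
        w -= learning_rate * grad_pre * x
        b -= learning_rate * grad_pre
    return w * x + b

pre_high_lr = train_single_relu(learning_rate=1.0)
pre_low_lr = train_single_relu(learning_rate=0.1)
print(f"lr=1.0: final pre-activation {pre_high_lr:.3f}")  # Negative: the unit is dead
print(f"lr=0.1: final pre-activation {pre_low_lr:.3f}")   # Close to the target of 0.5
```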

In [34]:
def train(run_name="", learning_rate=1e-3, epochs=100):
    gen = TetrisModel().to(device)
    disc = TetrisDiscriminator().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_020")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc, y_smoothing=0.0)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples, y_smoothing=0.0)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples, y_smoothing=0.0)
        gen_zero_grads = 0
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Gradients/{name}", weight.grad, epoch)
            gen_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Zero gradients", gen_zero_grads, epoch)
        disc_zero_grads = 0
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
            disc_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Discriminator zero gradients", disc_zero_grads, epoch)

    tb_writer.close()
    print("Done!")

In [33]:
train(run_name="lr_1em4", learning_rate=1e-4)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3876, G loss: 0.7431
[84/1762] D loss: 1.3834, G loss: 0.7447
[164/1762] D loss: 1.3701, G loss: 0.7485
[244/1762] D loss: 1.3522, G loss: 0.7564
[324/1762] D loss: 1.3286, G loss: 0.7625
[404/1762] D loss: 1.2793, G loss: 0.7695
[484/1762] D loss: 1.2041, G loss: 0.7816
[564/1762] D loss: 1.1095, G loss: 0.7881
[644/1762] D loss: 0.9665, G loss: 0.8697
[724/1762] D loss: 0.8562, G loss: 0.9590
[804/1762] D loss: 0.7806, G loss: 1.0783
[884/1762] D loss: 0.6841, G loss: 1.2420
[964/1762] D loss: 0.6466, G loss: 1.4002
[1044/1762] D loss: 0.4334, G loss: 1.7853
[1124/1762] D loss: 0.4379, G loss: 1.9718
[1204/1762] D loss: 0.3848, G loss: 2.2124
[1284/1762] D loss: 0.3274, G loss: 2.2416
[1364/1762] D loss: 0.3817, G loss: 2.4955
[1444/1762] D loss: 0.4208, G loss: 2.7263
[1524/1762] D loss: 0.2952, G loss: 3.0468
[1604/1762] D loss: 0.2257, G loss: 2.9054
[1684/1762] D loss: 0.2939, G loss: 2.7178
[1762/1762] D loss: 0.7945, G 

Interestingly, with this lower learning rate of 1e-4, the number of dead ReLU units actually goes down over time instead of up. The spawn recall still shows instability but gradually goes up over time. The board accuracy is quite unstable and doesn't reach as high a value as with learning rate 1e-3. The discriminator loss stabilises at a higher value than before. The generator loss stabilises at a lower value, even though the board accuracy is also lower.

Let's try re-running with learning rate 1e-4 but with more iterations.

In [35]:
train(run_name="lr_1em4", learning_rate=1e-4, epochs=500)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3891, G loss: 0.6476
[84/1762] D loss: 1.3858, G loss: 0.6555
[164/1762] D loss: 1.3820, G loss: 0.6599
[244/1762] D loss: 1.3762, G loss: 0.6645
[324/1762] D loss: 1.3653, G loss: 0.6673
[404/1762] D loss: 1.3514, G loss: 0.6692
[484/1762] D loss: 1.3233, G loss: 0.6712
[564/1762] D loss: 1.2812, G loss: 0.6756
[644/1762] D loss: 1.2118, G loss: 0.6915
[724/1762] D loss: 1.1258, G loss: 0.7277
[804/1762] D loss: 1.0624, G loss: 0.7538
[884/1762] D loss: 0.9151, G loss: 0.8687
[964/1762] D loss: 0.8049, G loss: 0.9847
[1044/1762] D loss: 0.6970, G loss: 1.1900
[1124/1762] D loss: 0.6766, G loss: 1.1881
[1204/1762] D loss: 0.5981, G loss: 1.6927
[1284/1762] D loss: 0.5332, G loss: 1.7024
[1364/1762] D loss: 0.5447, G loss: 1.8337
[1444/1762] D loss: 0.4508, G loss: 2.0551
[1524/1762] D loss: 0.6209, G loss: 2.0660
[1604/1762] D loss: 0.3405, G loss: 2.5549
[1684/1762] D loss: 0.4292, G loss: 2.4969
[1762/1762] D loss: 0.1877, G 

Training for this longer period of time doesn't guarantee good performance without an appropriate stopping criterion. The board accuracy generally stays high, over 90%, but there are regular epochs where it jumps down (sometimes to 60%) and back up again. Spawn recall is also pretty unstable, sometimes oscillating and sometimes getting "stuck" at 0%.

Compared to runs with a higher learning rate, the discriminator scores are slightly more dispersed in the cluster around 0.5, and the cluster around 1.0 (real data) is smaller. The cluster around 0.0 (fake data) moves close to the 0.5 cluster by the end of the training run.

The discriminator gradients are more varied than runs with a higher learning rate, where they are all clustered very close to zero. However, the generator gradients are still very clustered around zero.

While reducing the learning rate to 1e-4 doesn't fix the board accuracy and spawn recall issues, it does have the appealing property of removing the dying ReLU issue, so let's keep this lower learning rate.

# Handicap discriminator

With learning rate 1e-4, we saw that the generator gradients become very small while the discriminator gradients don't. Perhaps the discriminator is learning too quickly for the generator so it doesn't get useful feedback. Let's try handicapping the discriminator.

## Learning rate

Let's try giving the generator a higher learning rate than the discriminator.

In [36]:
def train(run_name="", lr_gen=1e-4, lr_disc=1e-4, epochs=100):
    gen = TetrisModel().to(device)
    disc = TetrisDiscriminator().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=lr_gen)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=lr_disc)

    log_dir = os.path.join("runs", "experiment_020")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc, y_smoothing=0.0)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples, y_smoothing=0.0)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples, y_smoothing=0.0)
        gen_zero_grads = 0
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Gradients/{name}", weight.grad, epoch)
            gen_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Zero gradients", gen_zero_grads, epoch)
        disc_zero_grads = 0
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
            disc_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Discriminator zero gradients", disc_zero_grads, epoch)

    tb_writer.close()
    print("Done!")

In [37]:
train(run_name="lr_gen_1em3_disc_1em4", lr_gen=1e-3, lr_disc=1e-4, epochs=100)
train(run_name="lr_gen_1em4_disc_1em5", lr_gen=1e-4, lr_disc=1e-5, epochs=100)
train(run_name="lr_gen_1em5_disc_1em5", lr_gen=1e-5, lr_disc=1e-5, epochs=100)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3866, G loss: 0.6801
[84/1762] D loss: 1.3863, G loss: 0.6834
[164/1762] D loss: 1.3862, G loss: 0.6867
[244/1762] D loss: 1.3841, G loss: 0.6919
[324/1762] D loss: 1.3836, G loss: 0.6965
[404/1762] D loss: 1.3832, G loss: 0.7002
[484/1762] D loss: 1.3797, G loss: 0.7054
[564/1762] D loss: 1.3778, G loss: 0.7084
[644/1762] D loss: 1.3813, G loss: 0.7076
[724/1762] D loss: 1.3834, G loss: 0.7017
[804/1762] D loss: 1.3844, G loss: 0.6969
[884/1762] D loss: 1.3855, G loss: 0.6970
[964/1762] D loss: 1.3850, G loss: 0.7051
[1044/1762] D loss: 1.3850, G loss: 0.6936
[1124/1762] D loss: 1.3846, G loss: 0.6945
[1204/1762] D loss: 1.3854, G loss: 0.7030
[1284/1762] D loss: 1.3862, G loss: 0.6966
[1364/1762] D loss: 1.3839, G loss: 0.6893
[1444/1762] D loss: 1.3854, G loss: 0.6884
[1524/1762] D loss: 1.3843, G loss: 0.6960
[1604/1762] D loss: 1.3849, G loss: 0.6989
[1684/1762] D loss: 1.3863, G loss: 0.6935
[1762/1762] D loss: 1.3862, G 

For gen LR 1e-3 and disc LR 1e-4, we see the generator zero gradients quickly climb, and the board accuracy is very unstable. So, the reduction of learning rate to 1e-4 is not only necessary in the discriminator, it needs to be done in the generator as well.

For gen LR 1e-4 and disc LR 1e-5, the board accuracy rises more slowly and attains a lower peak. The discriminator accuracy stays around 50% for most of the training run, as opposed to 55% for other runs. The generator gradients are much smaller, on the order of 1e-4 or even as low as 1e-6. The spawn recall is very noisy.

It looks like in this training run, the discriminator did not learn fast enough to give useful feedback, and consequently the generator was effectively learning on random noise, which meant it didn't achieve as good performance.

For gen LR 1e-5 and disc LR 1e-5, training is *very* slow. By the end of 100 epochs, the board accuracy is only 7%.

Let's try another run where the discriminator learning rate is still lower than the generator one, but not by as much. Then, let's do a sanity check by making the generator's learning rate less than the discriminator's.

In [38]:
train(run_name="lr_gen_1em4_disc_3em5", lr_gen=1e-4, lr_disc=3e-5, epochs=100)
train(run_name="lr_gen_1em5_disc_1em4", lr_gen=1e-5, lr_disc=1e-4, epochs=100)
train(run_name="lr_gen_3em5_disc_1em4", lr_gen=3e-5, lr_disc=1e-4, epochs=100)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3897, G loss: 0.7660
[84/1762] D loss: 1.3893, G loss: 0.7628
[164/1762] D loss: 1.3886, G loss: 0.7615
[244/1762] D loss: 1.3872, G loss: 0.7577
[324/1762] D loss: 1.3855, G loss: 0.7557
[404/1762] D loss: 1.3866, G loss: 0.7533
[484/1762] D loss: 1.3851, G loss: 0.7523
[564/1762] D loss: 1.3846, G loss: 0.7454
[644/1762] D loss: 1.3803, G loss: 0.7436
[724/1762] D loss: 1.3804, G loss: 0.7399
[804/1762] D loss: 1.3775, G loss: 0.7406
[884/1762] D loss: 1.3738, G loss: 0.7412
[964/1762] D loss: 1.3725, G loss: 0.7421
[1044/1762] D loss: 1.3668, G loss: 0.7352
[1124/1762] D loss: 1.3631, G loss: 0.7380
[1204/1762] D loss: 1.3583, G loss: 0.7445
[1284/1762] D loss: 1.3503, G loss: 0.7474
[1364/1762] D loss: 1.3474, G loss: 0.7461
[1444/1762] D loss: 1.3448, G loss: 0.7443
[1524/1762] D loss: 1.3492, G loss: 0.7370
[1604/1762] D loss: 1.3524, G loss: 0.7444
[1684/1762] D loss: 1.3446, G loss: 0.7394
[1762/1762] D loss: 1.3369, G 

For gen LR 1e-4 and disc LR 3e-5, the board accuracy, discriminator accuracy and spawn recall are all very unstable. It seems worse than disc LRs 1e-4 and 1e-5 though, so maybe it is partly a fluke of the random initialization. Nonetheless, it doesn't seem like a good configuration.

For gen LR 3e-5 and disc LR 1e-4, the board accuracy curve looks a lot more stable, though it does seem to level off around 86% training board accuracy and meanwhile is 80% on the test set, so we see a bit of overfitting. The generator gradients are fairly close to zero after 100 epochs, similar to what we see with both learning rates at 1e-4. Some of the discriminator gradients are fairly large (values of around +/-10 or 20) while most are "normal" (values between -2 and 2). This combination of learning rates looks good, though to see its full potential we should probably train for more iterations. The spawn recall doesn't go above zero, but this is to be expected because of the relatively low board accuracy.

For gen LR 1e-5 and disc LR 1e-4, the board accuracy doesn't go above zero before epoch 63 (but we see the cell accuracy steadily improve). Things look fairly stable, but training is slow. It seems LR 1e-5 is too low.

Let's redo gen LR 3e-5, disc LR 1e-4, but train for more epochs.

In [40]:
train(run_name="lr_gen_3em5_disc_1em4", lr_gen=3e-5, lr_disc=1e-4, epochs=500)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3961, G loss: 0.5934
[84/1762] D loss: 1.3919, G loss: 0.5990
[164/1762] D loss: 1.3840, G loss: 0.6078
[244/1762] D loss: 1.3729, G loss: 0.6184
[324/1762] D loss: 1.3545, G loss: 0.6316
[404/1762] D loss: 1.3304, G loss: 0.6490
[484/1762] D loss: 1.2710, G loss: 0.6902
[564/1762] D loss: 1.1576, G loss: 0.7696
[644/1762] D loss: 1.0409, G loss: 0.8293
[724/1762] D loss: 0.8793, G loss: 0.9897
[804/1762] D loss: 0.7675, G loss: 1.1624
[884/1762] D loss: 0.5785, G loss: 1.6082
[964/1762] D loss: 0.3837, G loss: 1.9121
[1044/1762] D loss: 0.3845, G loss: 2.1101
[1124/1762] D loss: 0.2443, G loss: 2.5008
[1204/1762] D loss: 0.1973, G loss: 2.8896
[1284/1762] D loss: 0.1547, G loss: 3.4157
[1364/1762] D loss: 0.1658, G loss: 3.4066
[1444/1762] D loss: 0.1268, G loss: 3.4305
[1524/1762] D loss: 0.0926, G loss: 3.7162
[1604/1762] D loss: 0.0865, G loss: 4.0365
[1684/1762] D loss: 0.0545, G loss: 4.3674
[1762/1762] D loss: 0.0539, G 

For this repeat of gen LR 3e-5, disc LR 1e-4, we see quite high peaks of board accuracy (92.57% train, 90.68% test), but the spawn recall is zero, so this configuration does not live up to expectations. The generator gradients are all quite small. The discriminator gradients stay well away from zero right up to the end of the training run.

This suggests that the generator learns to fool an early-stage discriminator (which doesn't care about block spawns) but then fails to adapt once the discriminator starts taking block spawns into account.

## Optimizer

Some sources say to use SGD for the discriminator while using Adam for the generator. Let's try it (with learning rate 1e-4 for both) to see if it helps at all.

In [43]:
def train(run_name="", learning_rate=1e-4, disc_momentum=0.0, epochs=100):
    gen = TetrisModel().to(device)
    disc = TetrisDiscriminator().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.SGD(disc.parameters(), lr=learning_rate, momentum=disc_momentum)

    log_dir = os.path.join("runs", "experiment_020")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc, y_smoothing=0.0)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples, y_smoothing=0.0)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples, y_smoothing=0.0)
        gen_zero_grads = 0
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Gradients/{name}", weight.grad, epoch)
            gen_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Zero gradients", gen_zero_grads, epoch)
        disc_zero_grads = 0
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
            disc_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Discriminator zero gradients", disc_zero_grads, epoch)

    tb_writer.close()
    print("Done!")

In [42]:
train(run_name="different_optimizers")

Epoch 0
-------------------------------
[4/1762] D loss: 1.3928, G loss: 0.7581
[84/1762] D loss: 1.3917, G loss: 0.7577
[164/1762] D loss: 1.3919, G loss: 0.7582
[244/1762] D loss: 1.3930, G loss: 0.7571
[324/1762] D loss: 1.3936, G loss: 0.7561
[404/1762] D loss: 1.3939, G loss: 0.7555
[484/1762] D loss: 1.3935, G loss: 0.7554
[564/1762] D loss: 1.3945, G loss: 0.7554
[644/1762] D loss: 1.3946, G loss: 0.7545
[724/1762] D loss: 1.3948, G loss: 0.7537
[804/1762] D loss: 1.3959, G loss: 0.7540
[884/1762] D loss: 1.3948, G loss: 0.7539
[964/1762] D loss: 1.3952, G loss: 0.7538
[1044/1762] D loss: 1.3956, G loss: 0.7532
[1124/1762] D loss: 1.3948, G loss: 0.7525
[1204/1762] D loss: 1.3957, G loss: 0.7527
[1284/1762] D loss: 1.3948, G loss: 0.7530
[1364/1762] D loss: 1.3947, G loss: 0.7525
[1444/1762] D loss: 1.3956, G loss: 0.7525
[1524/1762] D loss: 1.3952, G loss: 0.7522
[1604/1762] D loss: 1.3957, G loss: 0.7520
[1684/1762] D loss: 1.3951, G loss: 0.7524
[1762/1762] D loss: 1.3977, G 

In [44]:
train(run_name="different_optimizers_mom_0p01", disc_momentum=0.01)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3951, G loss: 0.6058
[84/1762] D loss: 1.3952, G loss: 0.6053
[164/1762] D loss: 1.3953, G loss: 0.6054
[244/1762] D loss: 1.3956, G loss: 0.6055
[324/1762] D loss: 1.3949, G loss: 0.6057
[404/1762] D loss: 1.3953, G loss: 0.6060
[484/1762] D loss: 1.3959, G loss: 0.6061
[564/1762] D loss: 1.3960, G loss: 0.6054
[644/1762] D loss: 1.3955, G loss: 0.6057
[724/1762] D loss: 1.3960, G loss: 0.6065
[804/1762] D loss: 1.3960, G loss: 0.6058
[884/1762] D loss: 1.3957, G loss: 0.6064
[964/1762] D loss: 1.3967, G loss: 0.6061
[1044/1762] D loss: 1.3969, G loss: 0.6058
[1124/1762] D loss: 1.3967, G loss: 0.6051
[1204/1762] D loss: 1.3961, G loss: 0.6062
[1284/1762] D loss: 1.3961, G loss: 0.6056
[1364/1762] D loss: 1.3967, G loss: 0.6060
[1444/1762] D loss: 1.3956, G loss: 0.6063
[1524/1762] D loss: 1.3971, G loss: 0.6058
[1604/1762] D loss: 1.3967, G loss: 0.6059
[1684/1762] D loss: 1.3969, G loss: 0.6061
[1762/1762] D loss: 1.3967, G 

In [45]:
train(run_name="different_optimizers_mom_0p1", disc_momentum=0.1)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3878, G loss: 0.7204
[84/1762] D loss: 1.3876, G loss: 0.7195
[164/1762] D loss: 1.3872, G loss: 0.7194
[244/1762] D loss: 1.3884, G loss: 0.7191
[324/1762] D loss: 1.3877, G loss: 0.7184
[404/1762] D loss: 1.3887, G loss: 0.7185
[484/1762] D loss: 1.3887, G loss: 0.7179
[564/1762] D loss: 1.3884, G loss: 0.7180
[644/1762] D loss: 1.3902, G loss: 0.7173
[724/1762] D loss: 1.3898, G loss: 0.7175
[804/1762] D loss: 1.3893, G loss: 0.7173
[884/1762] D loss: 1.3898, G loss: 0.7169
[964/1762] D loss: 1.3883, G loss: 0.7165
[1044/1762] D loss: 1.3907, G loss: 0.7160
[1124/1762] D loss: 1.3892, G loss: 0.7163
[1204/1762] D loss: 1.3893, G loss: 0.7171
[1284/1762] D loss: 1.3921, G loss: 0.7153
[1364/1762] D loss: 1.3894, G loss: 0.7163
[1444/1762] D loss: 1.3902, G loss: 0.7157
[1524/1762] D loss: 1.3916, G loss: 0.7149
[1604/1762] D loss: 1.3910, G loss: 0.7154
[1684/1762] D loss: 1.3913, G loss: 0.7148
[1762/1762] D loss: 1.3917, G 

When using Adam for the generator and SGD for the discriminator without any hyperparameter tuning, the cell accuracy increases very slowly, and we start encountering the dying ReLU problem again. If we add momentum, the cell accuracy increases faster, but it still gets stuck at 80% at epoch 40 and doesn't recover by epoch 100.

While others may have had more success from using SGD for the discriminator, it looks like to use it to its full potential we'd have to tune the hyperparameters, and we've already done this exploration with Adam. Tuning hyperparameters again with SGD will take a significant amount of time and is not guaranteed to produce better results, so let's stick with Adam for both the generator and discriminator for now.

# Conclusion

We noted that both the generator and discriminator were suffering from a dying ReLU problem, and introduced a new metric to measure it. Reducing the learning rate of both to 1e-4 fixed the dying ReLU problem, but made the final board accuracy lower. The spawn recall was no longer stuck at zero, but it oscillated wildly.

One hypothesis is that the generator doesn't fully learn when blocks should be spawned, so in epochs where it recalls many block spawns, it also predicts spawns where none should occur. We can test this by reinstating the spawn precision metric.
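As a reminder of how precision and recall relate, here is a minimal sketch in plain Python. It is not the metric code from earlier experiments: the function name `spawn_precision_recall` is hypothetical, and it assumes each example has already been reduced to a boolean "a block spawned" flag for both the prediction and the real next frame.

```python
from typing import Iterable, Optional, Tuple


def spawn_precision_recall(
    predicted: Iterable[bool], actual: Iterable[bool]
) -> Tuple[Optional[float], Optional[float]]:
    """Compute (precision, recall) from per-example spawn flags.

    A metric is returned as None when its denominator is zero.
    """
    tp = fp = fn = 0
    for p, a in zip(predicted, actual):
        if p and a:
            tp += 1  # Spawn predicted and it really happened
        elif p and not a:
            fp += 1  # Spawn predicted but none happened
        elif a and not p:
            fn += 1  # Real spawn that the generator missed
    precision = tp / (tp + fp) if (tp + fp) else None
    recall = tp / (tp + fn) if (tp + fn) else None
    return precision, recall


# Hypothetical flags: the generator predicts a spawn in every example.
predicted = [True, True, True, True]
actual = [True, False, True, False]
precision, recall = spawn_precision_recall(predicted, actual)
print(precision, recall)  # 0.5 1.0
```

If the hypothesis holds, epochs with high spawn recall should show low spawn precision, as in this made-up example.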

We know that the current generator architecture is able to learn block spawns for a stationary problem, i.e. when scored directly against the next frame, so perhaps the reason it struggles to learn as well in a GAN setup is because it "wastes" resources at the start fooling an incapable discriminator that doesn't yet care about block spawns, and then is unable to adapt when the discriminator criterion changes. We could test or overcome this in a few ways. We could freeze the global module for the first few epochs, then enable it once the discriminator starts to care about block spawns. We could pretrain the discriminator for a few epochs before starting to train the generator alongside it. Or, we could increase the capacity of the discriminator with the hope that it will learn faster.
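The first idea, freezing part of the generator for a few epochs, could be sketched roughly as follows. This is an illustration under assumptions, not the real model: `ToyGenerator`, its `local` and `global_branch` attributes, the placeholder loss and `freeze_epochs = 2` are all made up, and the real `TetrisModel` may organise its global module differently.

```python
import torch
from torch import nn


def set_trainable(module: nn.Module, trainable: bool) -> None:
    """Enable or disable gradient updates for all parameters of a module."""
    for param in module.parameters():
        param.requires_grad_(trainable)


class ToyGenerator(nn.Module):
    """Stand-in generator; `local` and `global_branch` are hypothetical names."""

    def __init__(self):
        super().__init__()
        self.local = nn.Linear(4, 4)
        self.global_branch = nn.Linear(4, 4)

    def forward(self, x):
        return self.local(x) + self.global_branch(x)


gen = ToyGenerator()
optimizer = torch.optim.Adam(gen.parameters(), lr=1e-4)
freeze_epochs = 2  # Assumed value; would need tuning in practice

initial_weight = gen.global_branch.weight.detach().clone()
for epoch in range(4):
    # Keep the global branch frozen for the first few epochs, then release it.
    set_trainable(gen.global_branch, epoch >= freeze_epochs)
    loss = gen(torch.ones(1, 4)).sum()  # Placeholder for the real GAN loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch == freeze_epochs - 1:
        # Record whether the frozen branch stayed untouched so far.
        frozen_unchanged = torch.equal(gen.global_branch.weight, initial_weight)

print(f"Unchanged while frozen: {frozen_unchanged}")
```

Parameters with `requires_grad` set to `False` receive no gradient, so the optimizer skips them until they are unfrozen. The same helper could be used for the second idea, by freezing the whole generator while the discriminator trains alone for a few epochs.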