# Experiment 019

In this experiment, we try "one-sided" GAN training: we freeze either the generator or the discriminator after some number of epochs and keep training the other one, to assess the degree to which the model still being trained is "playing catch-up" with the frozen one.

In [15]:
import os
from pathlib import Path
import shutil
import datetime

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

In [16]:
class RecordingDataset(Dataset):
    def __init__(self, path: str):
        self.path = path
        if not os.path.exists(path):
            raise FileNotFoundError(path)
        with os.scandir(self.path) as it:
            entries = list(it)
        if not entries:
            raise FileNotFoundError(f"No recordings found in {path}")
        _, self.ext = os.path.splitext(entries[0].name)
        # Take the max over all entries, including the first one
        self.highest_index = max(int(Path(entry.path).stem) for entry in entries)

    def __len__(self):
        return self.highest_index + 1

    def __getitem__(self, idx):
        file = os.path.join(self.path, f"{idx}{self.ext}")
        if not os.path.exists(file):
            raise IndexError()
        boards = np.load(file)

        def transform(board):
            board = torch.tensor(board, dtype=torch.long)
            board = F.one_hot(board, 2) # One-hot encode the cell types
            board = board.type(torch.float) # Convert to floating-point
            board = board.permute((2, 0, 1)) # Move channels/classes to dimension 0
            return board

        x = transform(boards[-2]) # Ignore all boards except the last two
        y = transform(boards[-1])
        return x, y
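
To make the `transform` step above concrete, here is a minimal standalone sketch of the same one-hot encoding applied to a tiny 2×3 board. The board values are made up for illustration; the real boards are 22×10.

```python
import numpy as np
import torch
import torch.nn.functional as F

# Hypothetical 2x3 board: 0 = empty cell, 1 = filled cell
board = np.array([[0, 1, 0],
                  [1, 1, 0]])

t = torch.tensor(board, dtype=torch.long)
t = F.one_hot(t, 2)          # shape (2, 3, 2): classes in the last dimension
t = t.type(torch.float)
t = t.permute((2, 0, 1))     # shape (2, 2, 3): classes moved to dimension 0

print(t.shape)               # torch.Size([2, 2, 3])
print(t.argmax(0))           # recovers the original board
```

Taking `argmax` over dimension 0 undoes the encoding, which is how the test loop later converts boards back to class indices.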
        

In [17]:
train_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "train"))
test_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "test"))
batch_size = 4
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

x, y = next(iter(train_dataloader))
print(x.shape, x.dtype)
print(y.shape, y.dtype)

torch.Size([4, 2, 22, 10]) torch.float32
torch.Size([4, 2, 22, 10]) torch.float32


In [18]:
device = (
    "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using {device} device")

Using cpu device


In [19]:
class TetrisModel(nn.Module):
    """Predicts the next state of the cells.

    Inputs:
        x: Tensor of float32 of shape (batch_size, channels, height, width). channels = 2 is the one-hot encoding of cell types:
           channel 0 is 1 for empty cells and channel 1 is 1 for filled cells. height = 22 and width = 10 are the dimensions of
           the game board.
        z: Tensor of float32 of shape (batch_size, 4), the latent vector. The entries should be random numbers sampled from a
           uniform distribution.
    
    Returns: Tensor of float32 of shape (batch_size, channels, height, width), class probabilities for the new cells (softmax over
             the channel dimension). Channel 0 is the probability that a cell is empty and channel 1 the probability that it is filled.
    """

    def __init__(self):
        super().__init__()
        self.loc = nn.Sequential(
            nn.Conv2d(6, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )
        self.glob = nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(160, 10)
        )
        self.head = nn.Sequential(
            nn.Conv2d(26, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 2, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, x, z):
        batch_size, channels, height, width = x.shape
        
        z = z[:, :, None, None] # Expand dims to match x
        z = z.repeat(1, 1, height, width) # Upscale to image size
        x = torch.cat((x, z), dim=1)

        x = self.loc(x)

        x_glob = self.glob(x)
        x_glob = x_glob[:, :, None, None] # Expand dims
        x_glob = x_glob.repeat(1, 1, height, width) # Upscale to image size
        x = torch.cat((x, x_glob), dim=1)

        y = self.head(x)
        return y
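
The first step of `forward` broadcasts the latent vector over the board so that it can be concatenated with the input as extra channels. The following is a small standalone sketch of just that step, with a dummy all-zero input tensor standing in for a real batch.

```python
import torch

batch_size, height, width = 4, 22, 10
x = torch.zeros(batch_size, 2, height, width)   # dummy stand-in for a one-hot board batch
z = torch.rand(batch_size, 4)                   # latent vector

z_img = z[:, :, None, None]                     # (4, 4, 1, 1)
z_img = z_img.repeat(1, 1, height, width)       # (4, 4, 22, 10): each latent value fills one channel
xz = torch.cat((x, z_img), dim=1)               # (4, 6, 22, 10)

print(xz.shape)                                 # torch.Size([4, 6, 22, 10])
```

The 2 board channels plus 4 latent channels explain why the first convolution in `self.loc` takes 6 input channels. The same trick is used for the 10 global features, which gives the 26 input channels of `self.head`.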

In [20]:
class TetrisDiscriminator(nn.Module):
    """A discriminator for the cell state predictions. Assesses the output of the generator.

    Inputs:
        x: Tensor of float32 of shape (batch_size, channels, height, width). channels = 2 is the one-hot encoding of cell types:
           channel 0 is 1 for empty cells and channel 1 is 1 for filled cells. height = 22 and width = 10 are the dimensions of
           the game board.
        y: Tensor of float32 of shape (batch_size, channels, height, width), as with x. This should be either the output of the
           generator (class probabilities) or the one-hot encoding of the ground truth of the next cell states.
    
    Returns: Tensor of float32 of shape (batch_size,), logits for the decision on whether the data are real or fake. Negative logits
             (probabilities close to 0) correspond to fake data, and positive logits (probabilities close to 1) correspond to real data.
    """

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
                nn.Conv2d(4, 16, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(16, 16, 3, padding=1, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Flatten(),
                nn.Linear(160, 16),
                nn.ReLU(),
                nn.Linear(16, 1),
                nn.Flatten(start_dim=0)
            )
    
    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits
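
The `nn.Linear(160, 16)` layer depends on the board size, so it can be useful to check where the 160 comes from. The sketch below is an illustrative check only: it applies the two pooling layers to a dummy 16-channel feature map of board size.

```python
import torch
from torch import nn

# Two 2x2 max-pools with the default floor rounding: 22 -> 11 -> 5 and 10 -> 5 -> 2
pools = nn.Sequential(nn.MaxPool2d(kernel_size=2), nn.MaxPool2d(kernel_size=2), nn.Flatten())
features = pools(torch.zeros(1, 16, 22, 10))

print(features.shape)  # torch.Size([1, 160])
```

16 channels × 5 rows × 2 columns gives 160 features. The same calculation applies to the `glob` branch of the generator.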

In [21]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
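
Because `count_parameters` only counts parameters with `requires_grad=True`, it also offers a quick way to confirm that a model has actually been frozen. A minimal sketch on a stand-in layer (not one of the models above):

```python
from torch import nn

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

layer = nn.Linear(10, 5)          # 10 * 5 weights + 5 biases
trainable_before = count_parameters(layer)
layer.requires_grad_(False)       # freeze all parameters of the layer
trainable_after = count_parameters(layer)

print(trainable_before, trainable_after)  # 55 0
```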

In [22]:
import itertools

def find_interesting_examples(dataset, num=4):
    num_spawns = num // 2
    num_normal = num - num_spawns
    
    def inner():
        num_spawns_left = num_spawns
        num_normal_left = num_normal

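        for x, y in dataset:
            # Check for block spawn: top row empty in x, at least one filled cell in the top row of y
            is_spawn = (x.argmax(0)[0] == 0).all() & (y.argmax(0)[0] == 1).any()
            if is_spawn:
                # Yield spawn examples until the quota is used up; never count them as general examples
                if num_spawns_left > 0:
                    num_spawns_left -= 1
                    yield x, y
            elif num_normal_left > 0:
                # Yield general examples
                num_normal_left -= 1
                yield x, y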
            
    return list(itertools.islice(inner(), num))

In [23]:
def render_prediction(x, pred, y):
    """Renders an example and prediction into a single-image array.
    
    Inputs:
        x: Tensor of shape (height, width), the model input.
        pred: Tensor of shape (height, width), the model prediction.
        y: Tensor of shape (height, width), the target.
    """
    assert len(x.shape) == 2, f"Expected tensors of shape (height, width) but got {x.shape}"
    assert x.shape == pred.shape, f"Shapes do not match: {x.shape} != {pred.shape}"
    assert x.shape == y.shape, f"Shapes do not match: {x.shape} != {y.shape}"
    height, width = x.shape
    with torch.no_grad():
        # Create the separator on the same device as the inputs so torch.cat also works on GPU
        separator = torch.ones(height, 1, dtype=x.dtype, device=x.device)
        return torch.cat((x, separator, pred, separator, y), dim=-1)

In [24]:
real_label = 1.0
fake_label = 0.0

def train_loop(dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc):
    gen.train()
    disc.train()

    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        ##################################################################
        # (1) Update discriminator: minimize -log(D(x)) - log(1 - D(G(z)))
        ##################################################################
        disc.zero_grad()

        ## Train with all-real batch
        # Format batch
        X, y = X.to(device), y.to(device)
        batch_size = X.size(0)
        real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through discriminator
        output = torch.flatten(disc(X, y))
        # Calculate loss on all-real batch
        err_disc_real = loss_fn(output, real_labels)
        # Calculate gradients for discriminator in backward pass
        err_disc_real.backward()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        z = torch.rand(batch_size, 4, device=device)
        # Generate fake image batch with generator
        y_fake = gen(X, z)
        fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)
        # Classify all fake batch with discriminator
        output = torch.flatten(disc(X, y_fake.detach()))
        # Calculate discriminator's loss on the all-fake batch
        err_disc_fake = loss_fn(output, fake_labels)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        err_disc_fake.backward()

        ## Update discriminator weights
        # Compute error of discriminator as sum over the fake and the real batches
        err_disc = err_disc_real + err_disc_fake
        # Update discriminator
        optimizer_disc.step()

        ##############################################
        # (2) Update generator: minimize -log(D(G(z)))
        ##############################################
        gen.zero_grad()
        # Since we just updated the discriminator, perform another forward pass of the all-fake batch through it
        output = torch.flatten(disc(X, y_fake))
        # Calculate the generator's loss based on this output
        # We use real labels because the generator wants to fool the discriminator
        err_gen = loss_fn(output, real_labels)
        # Calculate gradients for generator
        err_gen.backward()
        # Update generator
        optimizer_gen.step()

        # Output training stats
        if batch % 20 == 0:
            current = batch * dataloader.batch_size + batch_size
            print(f"[{current}/{size}] D loss: {err_disc.item():.4f}, G loss: {err_gen.item():.4f}")


def test_loop(split_name, dataloader, gen, disc, loss_fn, tb_writer, epoch, examples):
    gen.eval()
    disc.eval()

    loss_disc = 0.0
    loss_gen = 0.0
    disc_accuracy = 0.0
    cell_accuracy = 0.0
    board_accuracy = 0.0
    spawn_recall = 0.0
    num_spawns = 0.0
    spawn_precision = 0.0
    num_predicted_spawns = 0.0

    num_batches = len(dataloader)
    with torch.no_grad():        
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            batch_size = X.size(0)
            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)

            output_real = disc(X, y)
            loss_disc += loss_fn(output_real, real_labels).item()

            z = torch.rand(batch_size, 4, device=device)
            y_fake = gen(X, z)
            output_fake = disc(X, y_fake)
            
            loss_disc += loss_fn(output_fake, fake_labels).item()
            loss_gen += loss_fn(output_fake, real_labels).item()

            pred_real = (output_real > 0.0)
            pred_fake = (output_fake > 0.0)
            disc_accuracy += pred_real.type(torch.float).mean().item()
            disc_accuracy += (~pred_fake).type(torch.float).mean().item()

            classes_X = torch.argmax(X, dim=1)
            classes_y = torch.argmax(y, dim=1)
            classes_y_fake = torch.argmax(y_fake, dim=1)
            cell_accuracy += (classes_y_fake == classes_y).type(torch.float).mean().item()
            board_accuracy += (classes_y_fake == classes_y).all(-1).all(-1).type(torch.float).mean().item()

            actual_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y[:, 0, :] == 1).any(-1)
            predicted_spawns = (classes_y_fake[:, 0, :] == 1).any(-1)
            num_true_positives = (actual_spawns & predicted_spawns).type(torch.float).sum().item()
            spawn_recall += num_true_positives
            num_spawns += actual_spawns.type(torch.float).sum().item()
            spawn_precision += num_true_positives
            num_predicted_spawns += predicted_spawns.type(torch.float).sum().item()

    loss_disc /= num_batches
    loss_gen /= num_batches
    cell_accuracy /= num_batches
    board_accuracy /= num_batches
    spawn_recall = np.nan if (num_spawns == 0.0) else (spawn_recall / num_spawns)
    spawn_precision = np.nan if (num_predicted_spawns == 0.0) else (spawn_precision / num_predicted_spawns)
    disc_accuracy /= (2.0 * num_batches)

    print(f"{split_name} error: \n D loss: {loss_disc:>8f}, G loss: {loss_gen:>8f}, D accuracy: {(100*disc_accuracy):>0.1f}%, cell accuracy: {(100*cell_accuracy):>0.1f}%, board accuracy: {(100*board_accuracy):>0.1f}% \n")

    tb_writer.add_scalar(f"Discriminator loss/{split_name}", loss_disc, epoch)
    tb_writer.add_scalar(f"Loss/{split_name}", loss_gen, epoch)
    tb_writer.add_scalar(f"Discriminator accuracy/{split_name}", disc_accuracy, epoch)
    tb_writer.add_scalar(f"Cell accuracy/{split_name}", cell_accuracy, epoch)
    tb_writer.add_scalar(f"Board accuracy/{split_name}", board_accuracy, epoch)
    tb_writer.add_scalar(f"Spawn recall/{split_name}", spawn_recall, epoch)

    with torch.no_grad():
        for i, (X, y) in enumerate(examples):
            X, y = X.unsqueeze(0).to(device), y.unsqueeze(0).to(device)
            z = torch.rand(1, 4, device=device)
            y_fake = gen(X, z)
            X, y, y_fake = X.squeeze(0), y.squeeze(0), y_fake.squeeze(0)
            X, y, y_fake = X.argmax(0), y.argmax(0), y_fake.argmax(0)
            img = render_prediction(X, y_fake, y)
            tb_writer.add_image(f"Predictions/{split_name}/{i}", img, epoch, dataformats="HW")


# Freeze generator

In [27]:
def train_and_freeze_gen(run_name="", freeze_epoch=0):
    learning_rate = 1e-3
    epochs = 50

    gen = TetrisModel().to(device)
    disc = TetrisDiscriminator().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_019")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        if epoch == freeze_epoch:
            gen.requires_grad_(False)
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)

    tb_writer.close()
    print("Done!")
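
One caveat that may be worth checking: `requires_grad_(False)` stops gradient updates, but `train_loop` still calls `gen.train()`, and in training mode `BatchNorm2d` layers keep updating their running statistics on every forward pass. The sketch below illustrates this on a standalone layer; it is not part of the experiment code. Whether it matters here is an open question, since the running statistics are only used in `eval()` mode, that is, in `test_loop`.

```python
import torch
from torch import nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(3)
bn.requires_grad_(False)   # "freeze" the layer the same way the generator is frozen

# Training mode: running statistics are still updated by a forward pass
bn.train()
before = bn.running_mean.clone()
bn(torch.randn(8, 3, 4, 4) + 5.0)
changed_in_train = not torch.equal(before, bn.running_mean)

# Eval mode: running statistics stay fixed
bn.eval()
before = bn.running_mean.clone()
bn(torch.randn(8, 3, 4, 4) + 5.0)
changed_in_eval = not torch.equal(before, bn.running_mean)

print(changed_in_train, changed_in_eval)  # True False
```

If a fully static generator is wanted, one option would be to call `gen.eval()` after `gen.train()` inside the training loop once the generator is frozen. This is a suggestion, not something the runs in this notebook did.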

In [28]:
train_and_freeze_gen(run_name="freeze_gen_epoch_0", freeze_epoch=0)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3967, G loss: 0.7000
[84/1762] D loss: 0.6812, G loss: 1.2337
[164/1762] D loss: 0.1283, G loss: 2.6326
[244/1762] D loss: 0.0180, G loss: 4.3576
[324/1762] D loss: 0.0095, G loss: 5.2025
[404/1762] D loss: 0.0043, G loss: 5.7294
[484/1762] D loss: 0.0049, G loss: 5.9614
[564/1762] D loss: 0.0022, G loss: 6.5074
[644/1762] D loss: 0.0014, G loss: 6.8991
[724/1762] D loss: 0.0019, G loss: 6.7251
[804/1762] D loss: 0.0014, G loss: 7.0886
[884/1762] D loss: 0.0011, G loss: 7.2144
[964/1762] D loss: 0.0015, G loss: 7.2108
[1044/1762] D loss: 0.0013, G loss: 7.3081
[1124/1762] D loss: 0.0007, G loss: 7.7848
[1204/1762] D loss: 0.0004, G loss: 7.9899
[1284/1762] D loss: 0.0006, G loss: 8.0130
[1364/1762] D loss: 0.0003, G loss: 8.2685
[1444/1762] D loss: 0.0004, G loss: 8.1606
[1524/1762] D loss: 0.0005, G loss: 8.0983
[1604/1762] D loss: 0.0003, G loss: 8.4089
[1684/1762] D loss: 0.0003, G loss: 8.3041
[1762/1762] D loss: 0.0003, G 

In [30]:
train_and_freeze_gen(run_name="freeze_gen_epoch_5", freeze_epoch=5)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3628, G loss: 0.7337
[84/1762] D loss: 0.5267, G loss: 1.8589
[164/1762] D loss: 0.1896, G loss: 3.0379
[244/1762] D loss: 0.0452, G loss: 4.2842
[324/1762] D loss: 0.0156, G loss: 5.3431
[404/1762] D loss: 0.0315, G loss: 5.3833
[484/1762] D loss: 0.0203, G loss: 5.1763
[564/1762] D loss: 0.0087, G loss: 6.5806
[644/1762] D loss: 0.0066, G loss: 5.4703
[724/1762] D loss: 0.0158, G loss: 5.8748
[804/1762] D loss: 0.2884, G loss: 5.5163
[884/1762] D loss: 0.3284, G loss: 3.3779
[964/1762] D loss: 0.3438, G loss: 1.3096
[1044/1762] D loss: 0.0829, G loss: 2.8302
[1124/1762] D loss: 0.6703, G loss: 3.7446
[1204/1762] D loss: 0.3261, G loss: 2.5430
[1284/1762] D loss: 0.8957, G loss: 2.5023
[1364/1762] D loss: 0.7811, G loss: 2.0888
[1444/1762] D loss: 0.7707, G loss: 0.9743
[1524/1762] D loss: 0.7088, G loss: 0.7719
[1604/1762] D loss: 1.6000, G loss: 0.4953
[1684/1762] D loss: 1.0720, G loss: 0.8630
[1762/1762] D loss: 0.9721, G 

If the generator is frozen immediately, the discriminator easily learns to distinguish between fake and real data.

However, if the generator is frozen after 5 epochs, the discriminator loss becomes unstable. The frozen generator's highest board accuracy is 86.56%, which suggests that the discriminator can expect to classify at most 56.72% of examples correctly*. The discriminator's actual accuracy for the corresponding epochs was 55.07% and 52.75%, which suggests there is some room for improvement.

The instability can be seen most clearly in the discriminator loss, which jumps up and down and on average goes up.

*Let $ p_g $ be the probability that the generator produces passable output, as defined by the rules of Tetris. Let $ a_g $ be the probability that the generator's output matches the training data, and $ a_d $ the probability of the discriminator making a correct classification. The discriminator is shown 50% real and 50% fake data, and the best strategy for a perfect discriminator is to label data as real when it looks real and fake when it looks fake. If it does so, it marks all of the real data correctly and $ (1 - p_g) $ of the fake data correctly, so $ a_d \le \frac{1}{2} + \frac{1}{2} (1 - p_g) = 1 - \frac{1}{2} p_g $. It's difficult to measure $ p_g $ directly, but we know $ a_g \le p_g $. ("If it matches the training data, then it looks realistic.")

Therefore we get the inequality $ a_d + \frac{1}{2} a_g \le 1 $. With $ a_g = 0.8656 $, this gives $ a_d \le 1 - 0.4328 = 0.5672 $, which is the 56.72% quoted above.

There is a lower bound too. In the worst case, the generator produces terrible output ($ a_g = 0 $) and the discriminator guesses randomly ($ a_d = \frac{1}{2} $), so we get $ a_d + \frac{1}{2} a_g \ge \frac{1}{2} $.

This suggests the metric $ 2 a_d + a_g - 1 $, which should be between 0 and 1.

In practice, this metric might be negative if the discriminator is doing worse than random chance, but in this case it could just flip the prediction and score higher. The metric could also be slightly above 1, because the discriminator might get lucky with its guesses.

When the metric is high, we can interpret this as "given the current performance of the generator, the discriminator is able to identify almost all of its mistakes".
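
As a quick sanity check of the arithmetic, the bound and the metric can be evaluated on the numbers quoted above. This is a small sketch with hypothetical helper names. It assumes that both discriminator accuracies can be paired with the generator's highest board accuracy of 86.56%, which may not hold exactly if the board accuracy differed between those epochs.

```python
def max_disc_accuracy(a_g):
    """Upper bound on discriminator accuracy implied by a_d + a_g / 2 <= 1."""
    return 1.0 - 0.5 * a_g

def combined_accuracy(a_d, a_g):
    """The metric 2 * a_d + a_g - 1 from the text."""
    return 2.0 * a_d + a_g - 1.0

a_g = 0.8656  # highest board accuracy of the frozen generator, from the text

print(f"{max_disc_accuracy(a_g):.4f}")          # 0.5672
print(f"{combined_accuracy(0.5507, a_g):.4f}")  # 0.9670
print(f"{combined_accuracy(0.5275, a_g):.4f}")  # 0.9206
```

Under that assumption, both values sit a little below 1, consistent with the "some room for improvement" reading above.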

Let's redefine the test loop to include this metric.

# Combined accuracy

In [33]:
def test_loop(split_name, dataloader, gen, disc, loss_fn, tb_writer, epoch, examples):
    gen.eval()
    disc.eval()

    loss_disc = 0.0
    loss_gen = 0.0
    disc_accuracy = 0.0
    cell_accuracy = 0.0
    board_accuracy = 0.0
    spawn_recall = 0.0
    num_spawns = 0.0
    spawn_precision = 0.0
    num_predicted_spawns = 0.0

    num_batches = len(dataloader)
    with torch.no_grad():        
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            batch_size = X.size(0)
            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)

            output_real = disc(X, y)
            loss_disc += loss_fn(output_real, real_labels).item()

            z = torch.rand(batch_size, 4, device=device)
            y_fake = gen(X, z)
            output_fake = disc(X, y_fake)
            
            loss_disc += loss_fn(output_fake, fake_labels).item()
            loss_gen += loss_fn(output_fake, real_labels).item()

            pred_real = (output_real > 0.0)
            pred_fake = (output_fake > 0.0)
            disc_accuracy += pred_real.type(torch.float).mean().item()
            disc_accuracy += (~pred_fake).type(torch.float).mean().item()

            classes_X = torch.argmax(X, dim=1)
            classes_y = torch.argmax(y, dim=1)
            classes_y_fake = torch.argmax(y_fake, dim=1)
            cell_accuracy += (classes_y_fake == classes_y).type(torch.float).mean().item()
            board_accuracy += (classes_y_fake == classes_y).all(-1).all(-1).type(torch.float).mean().item()

            actual_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y[:, 0, :] == 1).any(-1)
            predicted_spawns = (classes_y_fake[:, 0, :] == 1).any(-1)
            num_true_positives = (actual_spawns & predicted_spawns).type(torch.float).sum().item()
            spawn_recall += num_true_positives
            num_spawns += actual_spawns.type(torch.float).sum().item()
            spawn_precision += num_true_positives
            num_predicted_spawns += predicted_spawns.type(torch.float).sum().item()

    loss_disc /= num_batches
    loss_gen /= num_batches
    cell_accuracy /= num_batches
    board_accuracy /= num_batches
    spawn_recall = np.nan if (num_spawns == 0.0) else (spawn_recall / num_spawns)
    spawn_precision = np.nan if (num_predicted_spawns == 0.0) else (spawn_precision / num_predicted_spawns)
    disc_accuracy /= (2.0 * num_batches)
    combined_accuracy = board_accuracy + 2.0 * disc_accuracy - 1.0

    print(f"{split_name} error: \n D loss: {loss_disc:>8f}, G loss: {loss_gen:>8f}, D accuracy: {(100*disc_accuracy):>0.1f}%, cell accuracy: {(100*cell_accuracy):>0.1f}%, board accuracy: {(100*board_accuracy):>0.1f}% \n")

    tb_writer.add_scalar(f"Discriminator loss/{split_name}", loss_disc, epoch)
    tb_writer.add_scalar(f"Loss/{split_name}", loss_gen, epoch)
    tb_writer.add_scalar(f"Discriminator accuracy/{split_name}", disc_accuracy, epoch)
    tb_writer.add_scalar(f"Cell accuracy/{split_name}", cell_accuracy, epoch)
    tb_writer.add_scalar(f"Board accuracy/{split_name}", board_accuracy, epoch)
    tb_writer.add_scalar(f"Spawn recall/{split_name}", spawn_recall, epoch)
    tb_writer.add_scalar(f"Combined accuracy/{split_name}", combined_accuracy, epoch)

    with torch.no_grad():
        for i, (X, y) in enumerate(examples):
            X, y = X.unsqueeze(0).to(device), y.unsqueeze(0).to(device)
            z = torch.rand(1, 4, device=device)
            y_fake = gen(X, z)
            X, y, y_fake = X.squeeze(0), y.squeeze(0), y_fake.squeeze(0)
            X, y, y_fake = X.argmax(0), y.argmax(0), y_fake.argmax(0)
            img = render_prediction(X, y_fake, y)
            tb_writer.add_image(f"Predictions/{split_name}/{i}", img, epoch, dataformats="HW")


In [34]:
#train_and_freeze_gen(run_name="freeze_gen_epoch_0", freeze_epoch=0)
train_and_freeze_gen(run_name="freeze_gen_epoch_5", freeze_epoch=5)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3582, G loss: 0.7652
[84/1762] D loss: 0.5708, G loss: 1.3413
[164/1762] D loss: 0.2174, G loss: 2.0594
[244/1762] D loss: 0.0727, G loss: 3.8658
[324/1762] D loss: 0.0449, G loss: 3.9796
[404/1762] D loss: 0.0299, G loss: 4.6527
[484/1762] D loss: 0.0272, G loss: 4.5964
[564/1762] D loss: 0.0789, G loss: 4.6917
[644/1762] D loss: 0.1421, G loss: 4.2983
[724/1762] D loss: 0.1460, G loss: 2.6522
[804/1762] D loss: 0.4003, G loss: 3.4869
[884/1762] D loss: 0.0498, G loss: 4.6915
[964/1762] D loss: 1.2528, G loss: 2.1709
[1044/1762] D loss: 0.4615, G loss: 0.8731
[1124/1762] D loss: 2.1258, G loss: 0.4588
[1204/1762] D loss: 0.8899, G loss: 1.2532
[1284/1762] D loss: 0.6042, G loss: 0.8578
[1364/1762] D loss: 0.6118, G loss: 1.4607
[1444/1762] D loss: 1.0362, G loss: 1.2556
[1524/1762] D loss: 1.1309, G loss: 2.9071
[1604/1762] D loss: 0.8528, G loss: 1.1346
[1684/1762] D loss: 0.6562, G loss: 2.8039
[1762/1762] D loss: 1.6781, G 

The combined accuracy metric oscillates between 89% and 111%. This averages out at 100%, but the variance seems too high. Perhaps the discriminator is doing alright but the generator is not improving for some reason.

# Freeze discriminator

In [35]:
def train_and_freeze_disc(run_name="", freeze_epoch=0):
    learning_rate = 1e-3
    epochs = 50

    gen = TetrisModel().to(device)
    disc = TetrisDiscriminator().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_019")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        if epoch == freeze_epoch:
            disc.requires_grad_(False)
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Generator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Generator gradients/{name}", weight.grad, epoch)

    tb_writer.close()
    print("Done!")

In [39]:
# Apparently we need to update the training loop to support tensors that don't require grad

def train_loop(dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc):
    gen.train()
    disc.train()

    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        ##################################################################
        # (1) Update discriminator: minimize -log(D(x)) - log(1 - D(G(z)))
        ##################################################################
        disc.zero_grad()

        ## Train with all-real batch
        # Format batch
        X, y = X.to(device), y.to(device)
        batch_size = X.size(0)
        real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through discriminator
        output = torch.flatten(disc(X, y))
        # Calculate loss on all-real batch
        err_disc_real = loss_fn(output, real_labels)
        # Calculate gradients for discriminator in backward pass
        if err_disc_real.requires_grad:
            err_disc_real.backward()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        z = torch.rand(batch_size, 4, device=device)
        # Generate fake image batch with generator
        y_fake = gen(X, z)
        fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)
        # Classify all fake batch with discriminator
        output = torch.flatten(disc(X, y_fake.detach()))
        # Calculate discriminator's loss on the all-fake batch
        err_disc_fake = loss_fn(output, fake_labels)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        if err_disc_fake.requires_grad:
            err_disc_fake.backward()

        ## Update discriminator weights
        # Compute error of discriminator as sum over the fake and the real batches
        err_disc = err_disc_real + err_disc_fake
        # Update discriminator
        optimizer_disc.step()

        ##############################################
        # (2) Update generator: minimize -log(D(G(z)))
        ##############################################
        gen.zero_grad()
        # Since we just updated the discriminator, perform another forward pass of the all-fake batch through it
        output = torch.flatten(disc(X, y_fake))
        # Calculate the generator's loss based on this output
        # We use real labels because the generator wants to fool the discriminator
        err_gen = loss_fn(output, real_labels)
        # Calculate gradients for generator
        err_gen.backward()
        # Update generator
        optimizer_gen.step()

        # Output training stats
        if batch % 20 == 0:
            current = batch * dataloader.batch_size + batch_size
            print(f"[{current}/{size}] D loss: {err_disc.item():.4f}, G loss: {err_gen.item():.4f}")
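
The reason for the `requires_grad` guards can be reproduced in isolation. In the sketch below, a small linear layer stands in for the frozen discriminator. Since neither its parameters nor its inputs require grad, the loss has no autograd graph and calling `backward()` on it raises a `RuntimeError`.

```python
import torch
from torch import nn

disc = nn.Linear(4, 1)        # hypothetical stand-in for the discriminator
disc.requires_grad_(False)    # frozen: no parameter requires grad
loss_fn = nn.BCEWithLogitsLoss()

output = torch.flatten(disc(torch.randn(2, 4)))
loss = loss_fn(output, torch.ones(2))
print(loss.requires_grad)     # False

try:
    loss.backward()
    raised = False
except RuntimeError:
    raised = True
print(raised)                 # True

# The guard used in the updated loop simply skips the backward pass
if loss.requires_grad:
    loss.backward()
```

The generator update needs no such guard when only the discriminator is frozen, because `y_fake` still carries gradients back to the generator's parameters.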

In [40]:
train_and_freeze_disc(run_name="freeze_disc_epoch_0", freeze_epoch=0)
train_and_freeze_disc(run_name="freeze_disc_epoch_5", freeze_epoch=5)

Epoch 0
-------------------------------
[4/1762] D loss: 1.4153, G loss: 0.5613
[84/1762] D loss: 1.4798, G loss: 0.5093
[164/1762] D loss: 1.5648, G loss: 0.4513
[244/1762] D loss: 1.5752, G loss: 0.4422
[324/1762] D loss: 1.6349, G loss: 0.4127
[404/1762] D loss: 1.6430, G loss: 0.4090
[484/1762] D loss: 1.6900, G loss: 0.3823
[564/1762] D loss: 1.7131, G loss: 0.3743
[644/1762] D loss: 1.7187, G loss: 0.3823
[724/1762] D loss: 1.6856, G loss: 0.3939
[804/1762] D loss: 1.7639, G loss: 0.3545
[884/1762] D loss: 1.7295, G loss: 0.3840
[964/1762] D loss: 1.7367, G loss: 0.3794
[1044/1762] D loss: 1.7267, G loss: 0.3667
[1124/1762] D loss: 1.7729, G loss: 0.3558
[1204/1762] D loss: 1.7590, G loss: 0.3697
[1284/1762] D loss: 1.7550, G loss: 0.3567
[1364/1762] D loss: 1.7479, G loss: 0.3659
[1444/1762] D loss: 1.7279, G loss: 0.3667
[1524/1762] D loss: 1.7807, G loss: 0.3520
[1604/1762] D loss: 1.7501, G loss: 0.3549
[1684/1762] D loss: 1.7437, G loss: 0.3769
[1762/1762] D loss: 1.7445, G 

In [41]:
train_and_freeze_disc(run_name="freeze_disc_epoch_20", freeze_epoch=20)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3696, G loss: 0.7796
[84/1762] D loss: 0.6314, G loss: 1.4119
[164/1762] D loss: 0.1031, G loss: 3.2123
[244/1762] D loss: 0.0252, G loss: 4.6887
[324/1762] D loss: 0.0307, G loss: 5.0415
[404/1762] D loss: 0.0320, G loss: 4.3509
[484/1762] D loss: 0.1698, G loss: 4.8203
[564/1762] D loss: 0.3615, G loss: 2.9130
[644/1762] D loss: 0.2233, G loss: 2.3559
[724/1762] D loss: 0.3234, G loss: 2.9364
[804/1762] D loss: 0.6554, G loss: 2.1477
[884/1762] D loss: 0.3603, G loss: 3.6595
[964/1762] D loss: 0.9268, G loss: 1.1082
[1044/1762] D loss: 0.9498, G loss: 1.6084
[1124/1762] D loss: 1.1009, G loss: 1.4884
[1204/1762] D loss: 0.7465, G loss: 1.3513
[1284/1762] D loss: 1.2277, G loss: 0.5546
[1364/1762] D loss: 1.6587, G loss: 0.3056
[1444/1762] D loss: 1.3217, G loss: 1.0860
[1524/1762] D loss: 1.2649, G loss: 0.6344
[1604/1762] D loss: 1.4319, G loss: 0.5391
[1684/1762] D loss: 1.3456, G loss: 0.7154
[1762/1762] D loss: 1.2866, G 

When the discriminator is frozen, the generator doesn't have the same problem. It easily learns to fool the discriminator and reduces its own loss to very low numbers around 1.0e-4 or even lower.
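
For reference, here is a minimal sketch of what freezing does mechanically. It uses tiny stand-in networks (`toy_gen` and `toy_disc` are hypothetical, not the notebook's `TetrisModel` and `TetrisDiscriminator`). Calling `requires_grad_(False)` on the discriminator stops its parameters from accumulating gradients, but gradients still flow through it back to the generator, which is why the generator can keep learning against a frozen discriminator.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins for the generator and discriminator.
toy_gen = nn.Linear(8, 8)
toy_disc = nn.Sequential(nn.Linear(8, 4), nn.Tanh(), nn.Linear(4, 1))

# Freeze the discriminator: its parameters no longer receive gradients.
toy_disc.requires_grad_(False)

loss_fn = nn.BCEWithLogitsLoss()
x = torch.randn(4, 8)
logits = toy_disc(toy_gen(x)).flatten()
# Generator loss: it wants the discriminator to answer "real" (label 1).
loss = loss_fn(logits, torch.ones(4))
loss.backward()

disc_has_grads = any(p.grad is not None for p in toy_disc.parameters())
gen_has_grads = all(p.grad is not None for p in toy_gen.parameters())
print(f"disc grads: {disc_has_grads}, gen grads: {gen_has_grads}")
```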

# Freeze gen and replace disc

There are at least two possible explanations for why freezing the generator doesn't have the desired effect:
* After a few epochs, the generator becomes so good that the discriminator architecture is incapable of learning the task of distinguishing between its output and the real data.
* Or, after a few epochs, the discriminator gets "locked into a bad state", where it cannot adapt to the generator's new behaviour.

We can distinguish between these two situations by first training the GAN as normal for a few epochs (say 5), then freezing the generator and replacing the discriminator with a newly initialized one. If the discriminator architecture is capable of learning to spot fakes from a partially trained generator, then we know the discriminator's architecture is valid. Otherwise, we can experiment with different architectures.

In [45]:
def train_and_freeze_gen_and_replace_disc(run_name="", freeze_epoch=5, new_learning_rate=1e-4):
    learning_rate = 1e-3
    epochs = 50

    gen = TetrisModel().to(device)
    disc = TetrisDiscriminator().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_019")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        if epoch == freeze_epoch:
            gen.requires_grad_(False)
            disc = TetrisDiscriminator().to(device)
            optimizer_disc = torch.optim.Adam(disc.parameters(), lr=new_learning_rate)
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)

    tb_writer.close()
    print("Done!")

In [44]:
train_and_freeze_gen_and_replace_disc("freeze_gen_replace_disc_epoch_5", freeze_epoch=5)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3479, G loss: 0.7482
[84/1762] D loss: 0.9303, G loss: 0.8759
[164/1762] D loss: 0.3762, G loss: 1.8266
[244/1762] D loss: 0.1265, G loss: 3.7376
[324/1762] D loss: 0.0473, G loss: 3.6683
[404/1762] D loss: 0.1313, G loss: 3.1214
[484/1762] D loss: 0.2862, G loss: 5.2596
[564/1762] D loss: 0.3426, G loss: 3.4827
[644/1762] D loss: 0.4483, G loss: 4.3587
[724/1762] D loss: 0.6159, G loss: 2.2959
[804/1762] D loss: 0.8739, G loss: 1.8865
[884/1762] D loss: 1.2328, G loss: 1.3862
[964/1762] D loss: 0.9520, G loss: 0.8686
[1044/1762] D loss: 1.1783, G loss: 0.8119
[1124/1762] D loss: 1.6243, G loss: 1.6233
[1204/1762] D loss: 1.1677, G loss: 1.2102
[1284/1762] D loss: 0.9198, G loss: 0.9742
[1364/1762] D loss: 1.4271, G loss: 0.4428
[1444/1762] D loss: 1.5194, G loss: 0.4273
[1524/1762] D loss: 1.2375, G loss: 0.7237
[1604/1762] D loss: 1.1886, G loss: 0.8156
[1684/1762] D loss: 1.4111, G loss: 0.8492
[1762/1762] D loss: 0.7200, G 

Even a newly initialized discriminator struggles on this problem. This suggests that the original discriminator is not in a "locked-in" state; it just isn't capable of learning the problem with the current setup.

To improve the situation, we could reduce the discriminator learning rate, or modify the discriminator architecture. In particular, we could increase the capacity and/or add more batch normalization to it.

The gradient histograms for the discriminator show that many layers have values in the range (-2, 2) or (-5, 5), which is larger than I've seen with other ML models in the past. The final linear layer tends to have some gradients below -3, and in one of the training runs they sometimes go below -10. This suggests the discriminator might have a mild version of the exploding gradient problem. We could try a different initialization scheme as per the DCGAN paper, or target our architecture changes towards minimising this.
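
One way to try the DCGAN-style initialization is sketched below. The DCGAN paper draws weights from a zero-centred normal distribution with standard deviation 0.02. The helper name `dcgan_init` and the decision to zero the biases are assumptions made for this sketch, and it has not been tested on this task.

```python
import torch
from torch import nn

def dcgan_init(module: nn.Module) -> None:
    """Draw conv/linear weights from N(0, 0.02^2); zero the biases (our choice)."""
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

torch.manual_seed(0)
# Small demo network, unrelated to the notebook's architectures.
demo = nn.Sequential(
    nn.Conv2d(4, 16, 3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 1),
)
demo.apply(dcgan_init)  # nn.Module.apply visits every submodule
weight_std = demo[0].weight.std().item()
print(f"std of first conv weights: {weight_std:.4f}")
```

In the notebook this would amount to calling `disc.apply(dcgan_init)` right after constructing the discriminator.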

## Learning rate

In [48]:
train_and_freeze_gen_and_replace_disc("fgrd_nlr_1em5", freeze_epoch=5, new_learning_rate=1e-5)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3850, G loss: 0.6712
[84/1762] D loss: 0.9374, G loss: 0.9944
[164/1762] D loss: 0.4977, G loss: 1.4863
[244/1762] D loss: 0.1457, G loss: 2.6692
[324/1762] D loss: 0.0961, G loss: 3.0312
[404/1762] D loss: 0.0894, G loss: 3.0216
[484/1762] D loss: 0.2817, G loss: 3.8023
[564/1762] D loss: 0.3669, G loss: 1.7353
[644/1762] D loss: 0.4591, G loss: 2.1730
[724/1762] D loss: 0.6833, G loss: 2.5311
[804/1762] D loss: 0.3689, G loss: 0.9843
[884/1762] D loss: 1.4407, G loss: 2.0988
[964/1762] D loss: 1.7064, G loss: 1.6584
[1044/1762] D loss: 0.7854, G loss: 0.9837
[1124/1762] D loss: 1.0819, G loss: 1.3380
[1204/1762] D loss: 1.4301, G loss: 0.5185
[1284/1762] D loss: 1.2398, G loss: 1.2098
[1364/1762] D loss: 1.4596, G loss: 0.9523
[1444/1762] D loss: 1.4115, G loss: 0.5996
[1524/1762] D loss: 1.0958, G loss: 1.6730
[1604/1762] D loss: 1.0966, G loss: 0.7482
[1684/1762] D loss: 1.2072, G loss: 0.7445
[1762/1762] D loss: 1.3103, G 

A learning rate of 1e-5 for the second discriminator seems to fix the problem. This new discriminator doesn't suffer from any instability, but depending on the training run it either improves its accuracy consistently or stays at about the same accuracy.

However, applying this knowledge to the main GAN training might be tricky, because the generator is constantly adapting. If we lower the learning rate of the discriminator, we should probably lower the learning rate of the generator even more so the discriminator doesn't fall behind too easily.
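
A minimal sketch of how the two learning rates could be decoupled in the main training function. The helper `make_optimizers` is hypothetical, and the tenfold ratio between the rates is a guess rather than a tested setting.

```python
import torch
from torch import nn

def make_optimizers(gen: nn.Module, disc: nn.Module, lr_gen: float, lr_disc: float):
    """Build one Adam optimizer per model so the two learning rates can differ."""
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=lr_gen)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=lr_disc)
    return optimizer_gen, optimizer_disc

# Hypothetical values: discriminator at 1e-5 (as in the run above),
# generator ten times lower so the discriminator can keep up.
toy_gen, toy_disc = nn.Linear(4, 4), nn.Linear(4, 1)
opt_gen, opt_disc = make_optimizers(toy_gen, toy_disc, lr_gen=1e-6, lr_disc=1e-5)
print(opt_gen.param_groups[0]["lr"], opt_disc.param_groups[0]["lr"])
```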

## Vary discriminator architecture

Let's try a few different discriminator architectures to see how they handle the frozen generator.

In [49]:
def check_disc(disc):
    print(f"{disc.__class__.__name__} has {count_parameters(disc)} parameters.")
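    X, y = next(iter(train_dataloader))
    X, y = X.to(device), y.to(device)  # Match the device the discriminator lives on
    with torch.no_grad():
        pred = disc(X, y)  # Raw logits, one per example in the batch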
    print(f"Prediction on real data: {pred}")

check_disc(TetrisDiscriminator().to(device))

TetrisDiscriminator has 5521 parameters.
Prediction on real data: tensor([0.1933, 0.1997, 0.2426, 0.2720])


In [50]:
class DiscWithMoreConvBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
                nn.Conv2d(4, 16, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(16, 16, 3, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.Conv2d(16, 16, 3, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(112, 16),
                nn.ReLU(),
                nn.Linear(16, 1),
                nn.Flatten(start_dim=0)
            )
    
    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits


check_disc(DiscWithMoreConvBN().to(device))

DiscWithMoreConvBN has 7089 parameters.
Prediction on real data: tensor([-0.1519, -0.0263, -0.0626, -0.3540])


In [51]:
class DiscWithMoreConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
                nn.Conv2d(4, 16, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(16, 16, 3),
                nn.ReLU(),
                nn.Conv2d(16, 16, 3),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(112, 16),
                nn.ReLU(),
                nn.Linear(16, 1),
                nn.Flatten(start_dim=0)
            )
    
    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits


check_disc(DiscWithMoreConv().to(device))

DiscWithMoreConv has 7057 parameters.
Prediction on real data: tensor([0.1698, 0.1646, 0.1693, 0.1697])


In [52]:
class DiscWithMoreConvBNPad(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
                nn.Conv2d(4, 16, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(16, 16, 3, padding=1, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.Conv2d(16, 16, 3, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.Conv2d(16, 16, 3, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(112, 16),
                nn.ReLU(),
                nn.Linear(16, 1),
                nn.Flatten(start_dim=0)
            )
    
    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits


check_disc(DiscWithMoreConvBNPad().to(device))

DiscWithMoreConvBNPad has 9425 parameters.
Prediction on real data: tensor([ 0.1754,  0.1373, -0.0601,  0.2204])


In [53]:
class DiscWithMoreConvPad(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
                nn.Conv2d(4, 16, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(16, 16, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(16, 16, 3),
                nn.ReLU(),
                nn.Conv2d(16, 16, 3),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(112, 16),
                nn.ReLU(),
                nn.Linear(16, 1),
                nn.Flatten(start_dim=0)
            )
    
    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits


check_disc(DiscWithMoreConvPad().to(device))

DiscWithMoreConvPad has 9377 parameters.
Prediction on real data: tensor([0.1670, 0.1669, 0.1654, 0.1663])
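
To see where `nn.Linear(112, 16)` and the reported parameter counts come from, the sketch below traces activation shapes through a copy of the `DiscWithMoreConv` body. It assumes boards are 22 rows by 10 columns, which is a guess inferred from the 112 flattened features (16 channels × 7 × 1); this section of the notebook does not state the board size. The local `count_params` helper stands in for the notebook's `count_parameters`.

```python
import torch
from torch import nn

# Copy of the DiscWithMoreConv body defined above.
body = nn.Sequential(
    nn.Conv2d(4, 16, 3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(16, 16, 3),
    nn.ReLU(),
    nn.Conv2d(16, 16, 3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(112, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
    nn.Flatten(start_dim=0),
)

def count_params(model: nn.Module) -> int:
    """Count trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Assumed input: batch of 1, 4 channels (two one-hot boards), 22x10 cells.
h = torch.zeros(1, 4, 22, 10)
shapes = []
with torch.no_grad():
    for layer in body:
        h = layer(h)
        shapes.append((layer.__class__.__name__, tuple(h.shape)))

for name, shape in shapes:
    print(f"{name:<10} -> {shape}")
print(f"{count_params(body)} parameters")
```

Under that assumption, the pooling step halves 22×10 to 11×5 and each unpadded 3×3 convolution trims 2 from both dimensions, giving 9×3 and then 7×1. The `Pad` variants insert a padded convolution that keeps 11×5, so the flattened size stays at 112.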


In [54]:
def train_and_freeze_gen_and_replace_disc(run_name="", freeze_epoch=5, new_learning_rate=1e-3, disc_cls=TetrisDiscriminator):
    learning_rate = 1e-3
    epochs = 50

    gen = TetrisModel().to(device)
    disc = disc_cls().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_019")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        if epoch == freeze_epoch:
            gen.requires_grad_(False)
            disc = disc_cls().to(device)
            optimizer_disc = torch.optim.Adam(disc.parameters(), lr=new_learning_rate)
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)

    tb_writer.close()
    print("Done!")

In [55]:
for i in range(2):
    for cls in [DiscWithMoreConvBN, DiscWithMoreConv, DiscWithMoreConvBNPad, DiscWithMoreConvPad]:
        run_name = "fgrd_" + cls.__name__
        train_and_freeze_gen_and_replace_disc(run_name=run_name, disc_cls=cls)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3894, G loss: 0.7718
[84/1762] D loss: 1.1976, G loss: 0.8137
[164/1762] D loss: 0.5615, G loss: 1.5732
[244/1762] D loss: 0.3054, G loss: 2.1695
[324/1762] D loss: 0.1317, G loss: 3.4589
[404/1762] D loss: 0.0304, G loss: 4.5751
[484/1762] D loss: 0.0262, G loss: 4.4166
[564/1762] D loss: 0.0771, G loss: 5.0794
[644/1762] D loss: 0.0238, G loss: 4.7528
[724/1762] D loss: 0.0956, G loss: 6.9500
[804/1762] D loss: 0.0311, G loss: 5.8188
[884/1762] D loss: 0.1161, G loss: 5.9866
[964/1762] D loss: 0.0973, G loss: 4.0344
[1044/1762] D loss: 0.1484, G loss: 5.8313
[1124/1762] D loss: 0.7030, G loss: 5.7809
[1204/1762] D loss: 0.0240, G loss: 4.0951
[1284/1762] D loss: 0.1479, G loss: 3.2149
[1364/1762] D loss: 0.0409, G loss: 4.5501
[1444/1762] D loss: 0.0034, G loss: 6.6371
[1524/1762] D loss: 0.0058, G loss: 7.1821
[1604/1762] D loss: 0.0572, G loss: 5.1598
[1684/1762] D loss: 0.3683, G loss: 3.4146
[1762/1762] D loss: 0.6881, G 

The results here are very clear: batch normalization in the discriminator hurts performance. This is supported by the paper _[A Large-Scale Study on Regularization and Normalization in GANs](https://arxiv.org/pdf/1807.04720.pdf)_, which suggests other normalization schemes that we could try instead. However, just removing the batch norm for now is easiest.
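
As one candidate to try later, spectral normalization is among the schemes that study evaluates, and PyTorch provides it as `torch.nn.utils.spectral_norm`. The sketch below is untested on this task: the class name `DiscWithSpectralNorm` is hypothetical, wrapping every weight layer is an arbitrary choice, and the 22×10 board size is an assumption inferred from the 112-feature linear layer.

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm

class DiscWithSpectralNorm(nn.Module):
    """Hypothetical variant of DiscWithMoreConv with spectral norm on each weight layer."""

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
            spectral_norm(nn.Conv2d(4, 16, 3, padding=1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            spectral_norm(nn.Conv2d(16, 16, 3)),
            nn.ReLU(),
            spectral_norm(nn.Conv2d(16, 16, 3)),
            nn.ReLU(),
            nn.Flatten(),
            spectral_norm(nn.Linear(112, 16)),
            nn.ReLU(),
            spectral_norm(nn.Linear(16, 1)),
            nn.Flatten(start_dim=0),
        )

    def forward(self, x, y):
        return self.body(torch.cat((x, y), dim=1))

# Assumed board size of 22x10, with 2 one-hot channels per board.
x = torch.zeros(4, 2, 22, 10)
y = torch.zeros(4, 2, 22, 10)
logits = DiscWithSpectralNorm()(x, y)
print(logits.shape)
```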

Out of `DiscWithMoreConv` and `DiscWithMoreConvPad`, `DiscWithMoreConvPad` seems to perform better overall, but let's try both architectures in the main GAN training and see which leads to better overall performance.

I also noticed a problem with the "combined accuracy" which explains why it goes above 100%, and shows that it might just be a useless metric. We assume that the input to the discriminator D is the thresholded output of the generator G, but in fact the input is the non-thresholded version. This means that the accuracy $ a_G $ used in the calculation is artificially high given the definition of $ a_D $.
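
A sketch of the thresholding step that the metric implicitly assumed. It presumes the generator emits a `(batch, 2, height, width)` tensor of per-class scores, matching the one-hot layout produced by `RecordingDataset`; the helper name `threshold_board` is hypothetical.

```python
import torch
import torch.nn.functional as F

def threshold_board(scores: torch.Tensor) -> torch.Tensor:
    """Convert per-class scores (B, C, H, W) into a hard one-hot board of the same shape."""
    classes = scores.argmax(dim=1)                              # (B, H, W) class indices
    one_hot = F.one_hot(classes, num_classes=scores.shape[1])   # (B, H, W, C)
    return one_hot.permute(0, 3, 1, 2).to(scores.dtype)         # back to (B, C, H, W)

# One board with a single row of two cells: the first cell favours class 0,
# the second favours class 1.
scores = torch.tensor([[[[0.9, 0.2]], [[0.1, 0.8]]]])  # shape (1, 2, 1, 2)
hard = threshold_board(scores)
print(hard)
```

Because `argmax` is not differentiable, this would only be suitable when evaluating the metric, not inside the training loop.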

# Apply to main training

In [59]:
def train(run_name="", disc_cls=TetrisDiscriminator, epochs=50):
    learning_rate = 1e-3

    gen = TetrisModel().to(device)
    disc = disc_cls().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_019")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Generator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Generator gradients/{name}", weight.grad, epoch)
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)

    tb_writer.close()
    print("Done!")

In [58]:
for i in range(2):
    for cls in [DiscWithMoreConv, DiscWithMoreConvPad]:
        train(run_name=cls.__name__, disc_cls=cls)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3892, G loss: 0.6414
[84/1762] D loss: 1.3302, G loss: 0.7201
[164/1762] D loss: 0.7441, G loss: 1.4206
[244/1762] D loss: 0.2290, G loss: 3.1195
[324/1762] D loss: 0.0895, G loss: 3.2415
[404/1762] D loss: 0.1292, G loss: 3.1448
[484/1762] D loss: 0.4249, G loss: 1.3855
[564/1762] D loss: 0.5185, G loss: 0.8465
[644/1762] D loss: 0.5203, G loss: 2.4053
[724/1762] D loss: 1.4124, G loss: 0.3857
[804/1762] D loss: 0.9023, G loss: 1.1524
[884/1762] D loss: 0.7865, G loss: 1.7076
[964/1762] D loss: 0.8257, G loss: 1.6789
[1044/1762] D loss: 0.5351, G loss: 1.1946
[1124/1762] D loss: 0.2774, G loss: 1.2280
[1204/1762] D loss: 1.9897, G loss: 0.5594
[1284/1762] D loss: 1.5844, G loss: 1.7941
[1364/1762] D loss: 1.2731, G loss: 1.6554
[1444/1762] D loss: 1.0932, G loss: 0.9276
[1524/1762] D loss: 1.0267, G loss: 0.9000
[1604/1762] D loss: 0.9021, G loss: 1.0433
[1684/1762] D loss: 1.1549, G loss: 1.7695
[1762/1762] D loss: 0.8986, G 

In terms of discriminator loss, we see a clear improvement in stability with both these architectures. The generator loss also remains stable, as it was before. The generator's board accuracy develops some instability though, which is more pronounced with `DiscWithMoreConvPad`. In terms of spawn recall, most runs stay at 0% for most of the 50 epochs. For one run of `DiscWithMoreConv` and `DiscWithMoreConvPad`, the spawn recall jumps up to about 100% at epoch 5 and subsequently shoots straight back down to around 0%, staying there for the rest of the training run. For one run of `DiscWithMoreConvPad`, the spawn recall unstably climbs up to 100% near the end of the training run and stays there for 3 epochs until the end of the run.

Let's rerun these setups for more epochs.

In [60]:
for i in range(2):
    for cls in [TetrisDiscriminator, DiscWithMoreConv, DiscWithMoreConvPad]:
        train(run_name=cls.__name__, disc_cls=cls, epochs=200)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3851, G loss: 0.7879
[84/1762] D loss: 0.9136, G loss: 1.0055
[164/1762] D loss: 0.4617, G loss: 1.7729
[244/1762] D loss: 0.2613, G loss: 2.3178
[324/1762] D loss: 0.0654, G loss: 3.6756
[404/1762] D loss: 0.0686, G loss: 3.7125
[484/1762] D loss: 0.0520, G loss: 4.7508
[564/1762] D loss: 0.3333, G loss: 4.4908
[644/1762] D loss: 0.4289, G loss: 2.5978
[724/1762] D loss: 0.1807, G loss: 2.1890
[804/1762] D loss: 0.9813, G loss: 1.1784
[884/1762] D loss: 0.2758, G loss: 3.6516
[964/1762] D loss: 0.3733, G loss: 4.9151
[1044/1762] D loss: 0.8165, G loss: 1.1965
[1124/1762] D loss: 0.8963, G loss: 0.9477
[1204/1762] D loss: 1.3434, G loss: 1.1568
[1284/1762] D loss: 1.2089, G loss: 0.9935
[1364/1762] D loss: 1.1547, G loss: 1.0692
[1444/1762] D loss: 1.1294, G loss: 1.0402
[1524/1762] D loss: 1.7733, G loss: 0.3653
[1604/1762] D loss: 1.1462, G loss: 1.7192
[1684/1762] D loss: 1.2643, G loss: 0.8560
[1762/1762] D loss: 0.9819, G 

Even though the two new discriminator architectures don't push the generator to the best possible board accuracy, they still display some desirable properties over the original `TetrisDiscriminator`. The discriminator loss is visibly a lot more stable with the new architectures, and there is even some improvement in the stability of the generator loss! The discriminator accuracy is higher for a given value of the generator board accuracy, which suggests that the discriminator is doing well against the generator, but the generator fails to improve to correct its mistakes.

Looking at the generator gradients, they are basically all zero (or close to zero)! This must be why the generator stops improving. Is it because of dead ReLU?
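
One way to check the dead-ReLU hypothesis is to measure, for each `nn.ReLU` module, the fraction of units that output zero for every example in a batch. The sketch below does this with forward hooks. The function name `dead_relu_fraction` is hypothetical, it only sees ReLUs that are `nn.ReLU` modules (not calls to `F.relu`), and a single batch is only a rough estimate.

```python
import torch
from torch import nn

def dead_relu_fraction(model: nn.Module, *inputs: torch.Tensor) -> dict:
    """For each nn.ReLU, return the fraction of units that are zero for the whole batch."""
    fractions, handles = {}, []

    def make_hook(name):
        def hook(module, args, output):
            # A unit counts as dead if it is inactive for every sample (dim 0).
            dead = (output <= 0).all(dim=0)
            fractions[name] = dead.float().mean().item()
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            handles.append(module.register_forward_hook(make_hook(name)))
    with torch.no_grad():
        model(*inputs)
    for handle in handles:
        handle.remove()  # Leave the model unmodified afterwards
    return fractions

# Toy check: a large negative bias forces every unit of the first ReLU to be dead.
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())
with torch.no_grad():
    toy[0].bias.fill_(-100.0)
result = dead_relu_fraction(toy, torch.randn(16, 4))
print(result)
```

For the discriminator, which takes two inputs, the call would look like `dead_relu_fraction(disc, X, y)`.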

In contrast, the discriminator clearly has nonzero gradients even until the end of the training process.

Since the two new discriminator architectures have such similar performance, let's pick `DiscWithMoreConv` as it's slightly simpler.

# Conclusion

By freezing one model and training the other, we diagnosed a stability problem in the discriminator. The new `DiscWithMoreConv` architecture solves this by having a larger capacity and no batch normalization. We should update the main GAN training notebook to use this new architecture.

As a next step, we should figure out why the generator stops improving. Perhaps it is to do with dying ReLU in either the discriminator or the generator. (Dying ReLU in the discriminator could cause this, because backpropagation for the generator always goes via the discriminator.)
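
The parenthetical above can be illustrated with a toy example: if every ReLU unit in the discriminator is inactive, the generator receives exactly zero gradient, even though the generator itself has no dead units. The networks here are hypothetical stand-ins, not the notebook's models.

```python
import torch
from torch import nn

torch.manual_seed(0)
toy_gen = nn.Linear(4, 4)
toy_disc = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# Force every hidden unit of the discriminator to be inactive.
with torch.no_grad():
    toy_disc[0].bias.fill_(-100.0)

loss_fn = nn.BCEWithLogitsLoss()
logits = toy_disc(toy_gen(torch.randn(16, 4))).flatten()
loss = loss_fn(logits, torch.ones(16))  # generator wants "real" predictions
loss.backward()

# The ReLU passes no gradient for negative inputs, so nothing reaches the generator.
gen_grad_norm = toy_gen.weight.grad.norm().item()
# The discriminator's output layer still gets a gradient through its bias.
disc_out_bias_grad = toy_disc[2].bias.grad.abs().item()
print(f"generator gradient norm: {gen_grad_norm}")
print(f"discriminator output bias gradient: {disc_out_bias_grad:.4f}")
```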