# Experiment 018

In this experiment, we will try some common "GAN hacks" to try improve the board accuracy of the generator.

In [1]:
import os
from pathlib import Path
import shutil
import datetime

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

In [2]:
class RecordingDataset(Dataset):
    def __init__(self, path: str):
        self.path = path
        if not os.path.exists(path):
            raise FileNotFoundError()
        with os.scandir(self.path) as it:
            entry: os.DirEntry = next(iter(it))
            _, self.ext = os.path.splitext(entry.name)
            self.highest_index = max((int(Path(file.path).stem) for file in it), default=-1)

    def __len__(self):
        return self.highest_index + 1

    def __getitem__(self, idx):
        file = os.path.join(self.path, f"{idx}{self.ext}")
        if not os.path.exists(file):
            raise IndexError()
        boards = np.load(file)

        def transform(board):
            board = torch.tensor(board, dtype=torch.long)
            board = F.one_hot(board, 2) # One-hot encode the cell types
            board = board.type(torch.float) # Convert to floating-point
            board = board.permute((2, 0, 1)) # Move channels/classes to dimension 0
            return board

        x = transform(boards[-2]) # Ignore all boards except the last two
        y = transform(boards[-1])
        return x, y
        

In [3]:
train_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "train"))
test_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "test"))
batch_size = 4
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

x, y = next(iter(train_dataloader))
print(x.shape, x.dtype)
print(y.shape, y.dtype)

torch.Size([4, 2, 22, 10]) torch.float32
torch.Size([4, 2, 22, 10]) torch.float32


In [4]:
device = (
    "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using {device} device")

Using cpu device


In [5]:
class TetrisModel(nn.Module):
    """Predicts the next state of the cells.

    Inputs:
        x: Tensor of float32 of shape (batch_size, channels, height, width). channels = 2 is the one-hot encoding of cell types, with
           0 for empty cells and 1 for filled cells. height = 22 and width = 10 are the dimensions of the game board. The entries
           should be 0 for empty cells and 1 for filled cells.
        z: Tensor of float32 of shape (batch_size, 4). The entries should be random numbers sampled from a uniform distribution.
    
    Returns: Tensor of float32 of shape (batch_size, height, width), logits for the new cells. Probabilities close to 0 (negative logits)
             correspond to empty cells, and probabilities close to 1 (positive logits) correspond to filled cells.
    """

    def __init__(self):
        super().__init__()
        self.loc = nn.Sequential(
            nn.Conv2d(6, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )
        self.glob = nn.Sequential(
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(160, 10)
        )
        self.head = nn.Sequential(
            nn.Conv2d(26, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 2, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, x, z):
        batch_size, channels, height, width = x.shape
        
        z = z[:, :, None, None] # Expand dims to match x
        z = z.repeat(1, 1, height, width) # Upscale to image size
        x = torch.cat((x, z), dim=1)

        x = self.loc(x)

        x_glob = self.glob(x)
        x_glob = x_glob[:, :, None, None] # Expand dims
        x_glob = x_glob.repeat(1, 1, height, width) # Upscale to image size
        x = torch.cat((x, x_glob), dim=1)

        y = self.head(x)
        return y

In [6]:
class TetrisDiscriminator(nn.Module):
    """A discriminator for the cell state predictions. Assesses the output of the generator.

    Inputs:
        x: Tensor of float32 of shape (batch_size, channels, height, width). channels = 2 is the one-hot encoding of cell types, with
           0 for empty cells and 1 for filled cells. height = 22 and width = 10 are the dimensions of the game board. The entries
           should be 0 for empty cells and 1 for filled cells.
        y: Tensor of float32 of shape (batch_size, channels, height, width), as with x. This should be either the output of the
           generator (with exp applied) or the one-hot encoding of the ground truth of the next cell states.
    
    Returns: Tensor of float32 of shape (batch_size, 1), decisions on whether the data are real or fake. Probabilities close to 0 (negative logits)
             correspond to fake data, and probabilities close to 1 (positive logits) correspond to real data.
    """

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
                nn.Conv2d(4, 16, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(16, 16, 3, padding=1, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Flatten(),
                nn.Linear(160, 16),
                nn.ReLU(),
                nn.Linear(16, 1),
                nn.Flatten(start_dim=0)
            )
    
    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits

In [7]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def check_disc(disc):
    gen = TetrisModel().to(device)

    with torch.no_grad():
        X, y = next(iter(train_dataloader))
        z = torch.rand(batch_size, 4)
        y_gen = gen(X, z)
        pred_on_real = F.sigmoid(disc(X, y)[0])
        pred_on_fake = F.sigmoid(disc(X, y_gen)[0])
        print(f"Number of discriminator parameters: {count_parameters(disc)}")
        print(f"Predicted label for real data: {pred_on_real}")
        print(f"Predicted label for fake data: {pred_on_fake}")

In [8]:
import itertools

def find_interesting_examples(dataset, num=4):
    num_spawns = num // 2
    num_normal = num - num_spawns
    
    def inner():
        num_spawns_left = num_spawns
        num_normal_left = num_normal

        for x, y in dataset:
            # Check for block spawn
            if (x.argmax(0)[0] == 0).all() & (y.argmax(0)[0] == 1).any():
                if num_spawns_left > 0:
                    num_spawns_left -= 1
                    yield x, y
                else:
                    continue
            # Yield general examples
            if num_normal_left > 0:
                num_normal_left -= 1
                yield x, y
            else:
                continue
            
    return list(itertools.islice(inner(), num))

In [9]:
def render_prediction(x, pred, y):
    """Renders an example and prediction into a single-image array.
    
    Inputs:
        x: Tensor of shape (height, width), the model input.
        pred: Tensor of shape (height, width), the model prediction.
        y: Tensor of shape (height, width), the target.
    """
    assert len(x.shape) == 2, f"Expected tensors of shape (width, height) but got {x.shape}"
    assert x.shape == pred.shape, f"Shapes do not match: {x.shape} != {pred.shape}"
    assert x.shape == y.shape, f"Shapes do not match: {x.shape} != {y.shape}"
    height, width = x.shape
    with torch.no_grad():
        separator = torch.ones(height, 1, dtype=x.dtype)
        return torch.cat((x, separator, pred, separator, y), dim=-1)

In [10]:
real_label = 1.0
fake_label = 0.0

def train_loop(dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc):
    gen.train()
    disc.train()

    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(train_dataloader):
        ##################################################################
        # (1) Update discriminator: minimize -log(D(x)) - log(1 - D(G(z)))
        ##################################################################
        disc.zero_grad()

        ## Train with all-real batch
        # Format batch
        X, y = X.to(device), y.to(device)
        batch_size = X.size(0)
        real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through discriminator
        output = torch.flatten(disc(X, y))
        # Calculate loss on all-real batch
        err_disc_real = loss_fn(output, real_labels)
        # Calculate gradients for discriminator in backward pass
        err_disc_real.backward()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        z = torch.rand(batch_size, 4, device=device)
        # Generate fake image batch with generator
        y_fake = gen(X, z)
        fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)
        # Classify all fake batch with discriminator
        output = torch.flatten(disc(X, y_fake.detach()))
        # Calculate discriminator's loss on the all-fake batch
        err_disc_fake = loss_fn(output, fake_labels)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        err_disc_fake.backward()

        ## Update discriminator weights
        # Compute error of discriminator as sum over the fake and the real batches
        err_disc = err_disc_real + err_disc_fake
        # Update discriminator
        optimizer_disc.step()

        ##############################################
        # (2) Update generator: minimize -log(D(G(z)))
        ##############################################
        gen.zero_grad()
        # Since we just updated the discriminator, perform another forward pass of the all-fake batch through it
        output = torch.flatten(disc(X, y_fake))
        # Calculate the generator's loss based on this output
        # We use real labels because the generator wants to fool the discriminator
        err_gen = loss_fn(output, real_labels)
        # Calculate gradients for generator
        err_gen.backward()
        # Update generator
        optimizer_gen.step()

        # Output training stats
        if batch % 20 == 0:
            current = batch * dataloader.batch_size + batch_size
            print(f"[{current}/{size}] D loss: {err_disc.item():.4f}, G loss: {err_gen.item():.4f}")


def test_loop(split_name, dataloader, gen, disc, loss_fn, tb_writer, epoch, examples):
    gen.eval()
    disc.eval()

    loss_disc = 0.0
    loss_gen = 0.0
    disc_accuracy = 0.0
    cell_accuracy = 0.0
    board_accuracy = 0.0
    spawn_recall = 0.0
    num_spawns = 0.0
    spawn_precision = 0.0
    num_predicted_spawns = 0.0

    num_batches = len(dataloader)
    with torch.no_grad():        
        for X, y in dataloader:
            batch_size = X.size(0)
            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)

            output_real = disc(X, y)
            loss_disc += loss_fn(output_real, real_labels).item()

            z = torch.rand(batch_size, 4, device=device)
            y_fake = gen(X, z)
            output_fake = disc(X, y_fake)
            
            loss_disc += loss_fn(output_fake, fake_labels).item()
            loss_gen += loss_fn(output_fake, real_labels).item()

            pred_real = (output_real > 0.0)
            pred_fake = (output_fake > 0.0)
            disc_accuracy += pred_real.type(torch.float).mean().item()
            disc_accuracy += (~pred_fake).type(torch.float).mean().item()

            classes_X = torch.argmax(X, dim=1)
            classes_y = torch.argmax(y, dim=1)
            classes_y_fake = torch.argmax(y_fake, dim=1)
            cell_accuracy += (classes_y_fake == classes_y).type(torch.float).mean().item()
            board_accuracy += (classes_y_fake == classes_y).all(-1).all(-1).type(torch.float).mean().item()

            actual_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y[:, 0, :] == 1).any(-1)
            predicted_spawns = (classes_y_fake[:, 0, :] == 1).any(-1)
            num_true_positives = (actual_spawns & predicted_spawns).type(torch.float).sum().item()
            spawn_recall += num_true_positives
            num_spawns += actual_spawns.type(torch.float).sum().item()
            spawn_precision += num_true_positives
            num_predicted_spawns += predicted_spawns.type(torch.float).sum().item()

    loss_disc /= num_batches
    loss_gen /= num_batches
    cell_accuracy /= num_batches
    board_accuracy /= num_batches
    spawn_recall /= num_spawns
    spawn_precision = np.nan if (num_predicted_spawns == 0.0) else (spawn_precision / num_predicted_spawns)
    disc_accuracy /= (2.0 * num_batches)

    print(f"{split_name} error: \n D loss: {loss_disc:>8f}, G loss: {loss_gen:>8f}, D accuracy: {(100*disc_accuracy):>0.1f}%, cell accuracy: {(100*cell_accuracy):>0.1f}%, board accuracy: {(100*board_accuracy):>0.1f}% \n")

    tb_writer.add_scalar(f"Discriminator loss/{split_name}", loss_disc, epoch)
    tb_writer.add_scalar(f"Loss/{split_name}", loss_gen, epoch)
    tb_writer.add_scalar(f"Discriminator accuracy/{split_name}", disc_accuracy, epoch)
    tb_writer.add_scalar(f"Cell accuracy/{split_name}", cell_accuracy, epoch)
    tb_writer.add_scalar(f"Board accuracy/{split_name}", board_accuracy, epoch)
    tb_writer.add_scalar(f"Spawn recall/{split_name}", spawn_recall, epoch)

    with torch.no_grad():
        for i, (X, y) in enumerate(examples):
            X, y = X.unsqueeze(0), y.unsqueeze(0)
            z = torch.rand(1, 4, device=device)
            y_fake = gen(X, z)
            X, y, y_fake = X.squeeze(0), y.squeeze(0), y_fake.squeeze(0)
            X, y, y_fake = X.argmax(0), y.argmax(0), y_fake.argmax(0)
            img = render_prediction(X, y_fake, y)
            tb_writer.add_image(f"Predictions/{split_name}/{i}", img, epoch, dataformats="HW")


In [13]:
def train(run_name="", disc_cls=TetrisDiscriminator):
    learning_rate = 1e-2
    epochs = 50

    gen = TetrisModel().to(device)
    disc = disc_cls().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.SGD(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.SGD(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_018")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Gradients/{name}", weight.grad, epoch)
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)

    tb_writer.close()
    print("Done!")

# Leaky ReLU

In [14]:
class DiscWithLeakyReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
                nn.Conv2d(4, 16, 3, padding=1),
                nn.LeakyReLU(0.1),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(16, 16, 3, padding=1, bias=False),
                nn.BatchNorm2d(16),
                nn.LeakyReLU(0.1),
                nn.MaxPool2d(kernel_size=2),
                nn.Flatten(),
                nn.Linear(160, 16),
                nn.LeakyReLU(0.1),
                nn.Linear(16, 16),
                nn.LeakyReLU(0.1),
                nn.Linear(16, 1),
                nn.Flatten(start_dim=0)
            )
    
    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits


check_disc(DiscWithLeakyReLU().to(device))

Number of discriminator parameters: 5793
Predicted label for real data: 0.5493099093437195
Predicted label for fake data: 0.5405417680740356


In [15]:
for cls in [TetrisDiscriminator, DiscWithLeakyReLU]:
    train(run_name=cls.__name__, disc_cls=cls)

Epoch 0
-------------------------------
[4/1762] D loss: 1.4854, G loss: 0.5560
[84/1762] D loss: 0.8090, G loss: 1.1255
[164/1762] D loss: 0.9261, G loss: 0.6823
[244/1762] D loss: 0.6851, G loss: 0.7790
[324/1762] D loss: 0.9878, G loss: 0.4966
[404/1762] D loss: 0.7942, G loss: 2.0219
[484/1762] D loss: 0.9289, G loss: 1.5304
[564/1762] D loss: 1.3108, G loss: 0.3322
[644/1762] D loss: 0.9893, G loss: 0.6577
[724/1762] D loss: 1.1333, G loss: 0.5262
[804/1762] D loss: 1.3874, G loss: 1.8283
[884/1762] D loss: 1.1802, G loss: 0.9525
[964/1762] D loss: 1.1824, G loss: 1.0871
[1044/1762] D loss: 1.2453, G loss: 1.1719
[1124/1762] D loss: 1.5359, G loss: 1.3899
[1204/1762] D loss: 1.3761, G loss: 0.7429
[1284/1762] D loss: 1.2421, G loss: 0.9432
[1364/1762] D loss: 1.7282, G loss: 1.6486
[1444/1762] D loss: 1.3544, G loss: 0.9336
[1524/1762] D loss: 1.3495, G loss: 0.7664
[1604/1762] D loss: 1.3690, G loss: 0.7975
[1684/1762] D loss: 1.4459, G loss: 0.5958
[1762/1762] D loss: 1.4248, G 

Leaky ReLU doesn't seem to improve anything.

# Use Adam

What??? Apparently we've been using SGD instead of Adam this whole time! Let's fix that.

In [28]:
def train(run_name="", disc_cls=TetrisDiscriminator, learning_rate=1e-2, beta1=0.9, batch_size=4):
    epochs = 50

    gen = TetrisModel().to(device)
    disc = disc_cls().to(device)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate, betas=(beta1, 0.999))
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate, betas=(beta1, 0.999))

    log_dir = os.path.join("runs", "experiment_018")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Gradients/{name}", weight.grad, epoch)
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)

    tb_writer.close()
    print("Done!")

In [17]:
for i in range(2):
    train(run_name="adam")

Epoch 0
-------------------------------
[4/1762] D loss: 1.3761, G loss: 0.8881
[84/1762] D loss: 0.0480, G loss: 9.0529
[164/1762] D loss: 0.7109, G loss: 1.9697
[244/1762] D loss: 0.0792, G loss: 7.8136
[324/1762] D loss: 0.1144, G loss: 8.3762
[404/1762] D loss: 1.4968, G loss: 1.7351
[484/1762] D loss: 0.4892, G loss: 4.0273
[564/1762] D loss: 0.8031, G loss: 1.2203
[644/1762] D loss: 0.1756, G loss: 3.2178
[724/1762] D loss: 0.6650, G loss: 4.5854
[804/1762] D loss: 0.0932, G loss: 8.4619
[884/1762] D loss: 0.1161, G loss: 4.6898
[964/1762] D loss: 0.2300, G loss: 2.9076
[1044/1762] D loss: 0.5348, G loss: 2.9772
[1124/1762] D loss: 1.1148, G loss: 6.1005
[1204/1762] D loss: 0.2612, G loss: 3.7832
[1284/1762] D loss: 0.1535, G loss: 4.6657
[1364/1762] D loss: 0.0946, G loss: 5.2144
[1444/1762] D loss: 0.2621, G loss: 2.9920
[1524/1762] D loss: 0.3452, G loss: 3.1620
[1604/1762] D loss: 1.0174, G loss: 1.4952
[1684/1762] D loss: 0.2567, G loss: 3.9271
[1762/1762] D loss: 0.2069, G 

Adam seems to have some significant instability; let's try reducing the learning rate.

In [19]:
for learning_rate in [1e-3, 1e-4, 1e-5]:
    run_name = "adam_lr_" + f"{learning_rate:.0e}".replace("-", "m")
    train(run_name=run_name, learning_rate=learning_rate)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3816, G loss: 0.6912
[84/1762] D loss: 1.0127, G loss: 0.9967
[164/1762] D loss: 0.6188, G loss: 1.5769
[244/1762] D loss: 0.1177, G loss: 3.6155
[324/1762] D loss: 0.1746, G loss: 3.0600
[404/1762] D loss: 0.1901, G loss: 2.9943
[484/1762] D loss: 0.3561, G loss: 3.5213
[564/1762] D loss: 0.2798, G loss: 2.6448
[644/1762] D loss: 0.1786, G loss: 3.7024
[724/1762] D loss: 0.6618, G loss: 1.5532
[804/1762] D loss: 1.2984, G loss: 2.0507
[884/1762] D loss: 0.7367, G loss: 1.3706
[964/1762] D loss: 1.0316, G loss: 1.0764
[1044/1762] D loss: 0.6845, G loss: 0.7339
[1124/1762] D loss: 1.2141, G loss: 0.5853
[1204/1762] D loss: 1.6664, G loss: 1.2998
[1284/1762] D loss: 1.2940, G loss: 0.8930
[1364/1762] D loss: 1.3494, G loss: 1.6234
[1444/1762] D loss: 1.1270, G loss: 0.8822
[1524/1762] D loss: 1.0956, G loss: 1.1624
[1604/1762] D loss: 1.0677, G loss: 0.6873
[1684/1762] D loss: 1.3318, G loss: 0.7817
[1762/1762] D loss: 0.2863, G 

So far, Adam doesn't seem to be helping much. Let's try setting beta1 to 0.5 as per the DCGAN paper.

In [22]:
#for learning_rate in [1e-3, 1e-4, 1e-5]:
for learning_rate in [1e-3]:
    run_name = "adam_lr_" + f"{learning_rate:.0e}".replace("-", "m") + "_beta1_0p5"
    train(run_name=run_name, learning_rate=learning_rate, beta1=0.5)

Epoch 0
-------------------------------
[4/1762] D loss: 1.4119, G loss: 0.8702
[84/1762] D loss: 0.6370, G loss: 1.4336
[164/1762] D loss: 0.1705, G loss: 2.9810
[244/1762] D loss: 0.0605, G loss: 3.9275
[324/1762] D loss: 0.4647, G loss: 1.7480
[404/1762] D loss: 2.5029, G loss: 0.1242
[484/1762] D loss: 0.4363, G loss: 1.2609
[564/1762] D loss: 0.4308, G loss: 3.0727
[644/1762] D loss: 0.9062, G loss: 0.5194
[724/1762] D loss: 0.6781, G loss: 1.5322
[804/1762] D loss: 1.4525, G loss: 1.5776
[884/1762] D loss: 1.3220, G loss: 3.1353
[964/1762] D loss: 1.5245, G loss: 1.3114
[1044/1762] D loss: 0.9586, G loss: 1.3693
[1124/1762] D loss: 1.4398, G loss: 1.6562
[1204/1762] D loss: 1.2565, G loss: 0.4819
[1284/1762] D loss: 1.0660, G loss: 0.7753
[1364/1762] D loss: 1.1739, G loss: 1.1530
[1444/1762] D loss: 1.0411, G loss: 0.5582
[1524/1762] D loss: 1.2335, G loss: 0.7985
[1604/1762] D loss: 1.3291, G loss: 0.5242
[1684/1762] D loss: 1.0776, G loss: 1.0718
[1762/1762] D loss: 1.4934, G 

The change to beta1 seems not to make too much difference.

Let's keep using Adam as it's supposed to be less sensitive to the exact choice of learning rate than SGD.

# Remove linear layer

One GAN hack suggested not having linear layers at the end of the discriminator, or at least not having multiple. Let's try removing one of the linear layers from the top of the discriminator.

In [26]:
class DiscWithLessLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
                nn.Conv2d(4, 16, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(16, 16, 3, padding=1, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Flatten(),
                nn.Linear(160, 1),
                nn.Flatten(start_dim=0)
            )
    
    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits

In [27]:
for cls in [DiscWithLessLinear]:
    train(run_name=cls.__name__ + "_lr_1em03", learning_rate=1e-3, disc_cls=cls)

Epoch 0
-------------------------------
[4/1762] D loss: 1.6167, G loss: 1.1186
[84/1762] D loss: 0.8856, G loss: 1.2471
[164/1762] D loss: 0.5144, G loss: 1.5778
[244/1762] D loss: 0.2832, G loss: 2.2208
[324/1762] D loss: 0.1694, G loss: 2.9587
[404/1762] D loss: 0.1285, G loss: 2.7698
[484/1762] D loss: 0.1271, G loss: 3.2402
[564/1762] D loss: 0.2399, G loss: 1.8456
[644/1762] D loss: 0.3772, G loss: 2.6360
[724/1762] D loss: 0.5194, G loss: 2.0709
[804/1762] D loss: 1.0834, G loss: 1.9496
[884/1762] D loss: 0.3965, G loss: 1.4100
[964/1762] D loss: 0.7231, G loss: 1.6464
[1044/1762] D loss: 0.2092, G loss: 1.9423
[1124/1762] D loss: 1.5056, G loss: 2.5327
[1204/1762] D loss: 1.4458, G loss: 1.6801
[1284/1762] D loss: 0.9211, G loss: 1.3423
[1364/1762] D loss: 1.2521, G loss: 0.9912
[1444/1762] D loss: 1.1146, G loss: 0.8716
[1524/1762] D loss: 1.3431, G loss: 0.6098
[1604/1762] D loss: 1.0341, G loss: 1.3978
[1684/1762] D loss: 1.5701, G loss: 1.2832
[1762/1762] D loss: 1.2502, G 

Nope, no difference.

# Batch size

Just in case it makes a difference, let's try using a larger batch size.

In [29]:
for i in range(2):
    for bs in [8, 16]:
        train(run_name=f"bs_{bs}", learning_rate=1e-3, batch_size=bs)

Epoch 0
-------------------------------
[4/1762] D loss: 1.4096, G loss: 0.8376
[164/1762] D loss: 0.9216, G loss: 0.9966
[324/1762] D loss: 0.4052, G loss: 1.7929
[484/1762] D loss: 0.0583, G loss: 3.3663
[644/1762] D loss: 0.1307, G loss: 4.0234
[804/1762] D loss: 0.0342, G loss: 4.1476
[964/1762] D loss: 0.0555, G loss: 3.8795
[1124/1762] D loss: 0.0362, G loss: 4.2559
[1284/1762] D loss: 1.7508, G loss: 0.8830
[1444/1762] D loss: 0.4192, G loss: 2.6205
[1604/1762] D loss: 0.1526, G loss: 4.0958
[1764/1762] D loss: 0.4872, G loss: 2.8617
[1924/1762] D loss: 0.3653, G loss: 1.5350
[2084/1762] D loss: 0.8694, G loss: 1.4778
[2244/1762] D loss: 0.2608, G loss: 3.4962
[2404/1762] D loss: 0.2585, G loss: 1.6643
[2564/1762] D loss: 0.7230, G loss: 1.7321
[2724/1762] D loss: 1.5001, G loss: 2.3424
[2884/1762] D loss: 0.7781, G loss: 1.7318
[3044/1762] D loss: 0.9170, G loss: 2.0009
[3204/1762] D loss: 1.3824, G loss: 1.0525
[3364/1762] D loss: 1.3952, G loss: 0.6370
[3522/1762] D loss: 0.8

Increasing the batch size doesn't help.

# Conclusion

None of the "GAN hacks" seem to work. We'll have to find another approach of improving the models.