# Experiment 014

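In this experiment, we test the hypothesis that the discriminator is not learning anything useful. To do so, we train the same architecture with the same hyperparameters to distinguish the real next frame from random noise. A discriminator that can learn anything at all should find this task easy.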

In [27]:
import os
from pathlib import Path
import shutil
import datetime

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

In [28]:
class RecordingDataset(Dataset):
    """Dataset of recorded Tetris games stored as one NumPy file per integer index.

    The directory is expected to contain files named ``<index><ext>`` (e.g. ``0.npy``, ``1.npy``, ...),
    each loadable with ``np.load`` and holding a sequence of boards of 0/1 cell types. Item ``idx`` is the
    pair (second-to-last board, last board) of file ``idx``, each one-hot encoded to a float32 tensor of
    shape (2, height, width).

    Args:
        path: Directory containing the recording files.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """

    def __init__(self, path: str):
        self.path = path
        if not os.path.exists(path):
            raise FileNotFoundError(path)

        self.ext = ""
        indices = []
        with os.scandir(self.path) as it:
            # Scan *every* entry: the length must account for all index-named files, including the first one.
            for entry in it:
                if not entry.is_file():
                    continue
                stem, ext = os.path.splitext(entry.name)
                try:
                    index = int(stem)
                except ValueError:
                    # Skip files that are not named after an integer index (e.g. stray metadata files).
                    continue
                if not indices:
                    # The extension is taken from the first index-named file found.
                    self.ext = ext
                indices.append(index)

        # The length is derived from the highest index; gaps in the numbering surface as IndexError in __getitem__.
        self.highest_index = max(indices, default=-1)

    def __len__(self):
        return self.highest_index + 1

    @staticmethod
    def _transform(board):
        """One-hot encode a board of 0/1 cell types into a float32 tensor of shape (2, height, width)."""
        board = torch.tensor(board, dtype=torch.long)
        board = F.one_hot(board, 2)  # One-hot encode the cell types -> (height, width, 2)
        board = board.type(torch.float)  # Convert to floating-point
        board = board.permute((2, 0, 1))  # Move channels/classes to dimension 0
        return board

    def __getitem__(self, idx):
        """Return (x, y): the second-to-last and last boards of recording ``idx``.

        Raises:
            IndexError: If no file exists for ``idx`` (also ends sequence-protocol iteration).
            ValueError: If the recording holds fewer than two boards.
        """
        file = os.path.join(self.path, f"{idx}{self.ext}")
        if not os.path.exists(file):
            raise IndexError(idx)
        boards = np.load(file)
        if len(boards) < 2:
            # Raise ValueError rather than letting boards[-2] raise IndexError, which would silently stop iteration.
            raise ValueError(f"{file} holds {len(boards)} board(s); at least 2 are needed to form an (x, y) pair")

        x = self._transform(boards[-2])  # Ignore all boards except the last two
        y = self._transform(boards[-1])
        return x, y
        

In [29]:
# Each sample is an (x, y) pair: the second-to-last and last boards of one recording file.
train_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "train"))
test_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "test"))
# NOTE(review): every discriminator below uses BatchNorm; a batch size of 4 gives very noisy batch
# statistics in train mode — confirm this is intentional.
batch_size = 4
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# NOTE(review): evaluation does not need shuffling; shuffle=False would make the test batches reproducible.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Sanity check: print the shape and dtype of one batch of inputs and targets.
x, y = next(iter(train_dataloader))
print(x.shape, x.dtype)
print(y.shape, y.dtype)

torch.Size([4, 2, 22, 10]) torch.float32
torch.Size([4, 2, 22, 10]) torch.float32


In [30]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using {device} device")

Using cpu device


In [31]:
class TetrisDiscriminator(nn.Module):
    """A discriminator for the cell state predictions. Assesses the output of the generator.

    Args:
        height: Board height in cells (default 22). Together with ``width`` it sizes the first Linear layer.
        width: Board width in cells (default 10).

    Inputs:
        x: Tensor of float32 of shape (batch_size, channels, height, width). channels = 2 is the one-hot encoding of cell types, with
           0 for empty cells and 1 for filled cells. height and width are the dimensions of the game board (22 and 10 by default).
        y: Tensor of float32 of shape (batch_size, channels, height, width), as with x. This should be either the output of the
           generator (with exp applied) or the one-hot encoding of the ground truth of the next cell states. In this notebook the
           fake y is uniform random noise.

    Returns: Tensor of float32 logits of shape (batch_size,); the final Flatten(start_dim=0) removes the trailing singleton
             dimension. Apply sigmoid to get probabilities: probabilities close to 0 (negative logits) correspond to fake data,
             and probabilities close to 1 (positive logits) correspond to real data.
    """

    def __init__(self, height: int = 22, width: int = 10):
        super().__init__()
        # Two 2x2 max-pools (each a floor division) shrink the board to (height // 4, width // 4):
        # 22x10 -> 11x5 -> 5x2, i.e. 16 * 5 * 2 = 160 features for the default board.
        flat_features = 16 * (height // 4) * (width // 4)
        self.body = nn.Sequential(
                # x and y are concatenated along the channel dimension: 2 + 2 = 4 input channels.
                nn.Conv2d(4, 16, 3, padding=1, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(16, 16, 3, padding=1, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
                nn.Flatten(),
                nn.Linear(flat_features, 16, bias=False),
                nn.BatchNorm1d(16),
                nn.ReLU(),
                nn.Linear(16, 1),
                nn.Flatten(start_dim=0)
            )

    def forward(self, x, y):
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits

In [53]:
def count_parameters(model):
    """Return the number of trainable (``requires_grad``) scalar parameters in ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def check_model(model, dataloader=None):
    """Smoke-test a discriminator on one real batch and one uniform-noise batch.

    Prints the trainable-parameter count and the sigmoid of the logit for the first sample of the real
    batch (X, y) and of the fake batch (X, y_fake), where y_fake is uniform noise shaped like X.

    Args:
        model: Discriminator called as ``model(X, y)`` and returning one logit per sample.
        dataloader: Iterable yielding (X, y) batches. Defaults to the notebook-global ``train_dataloader``.

    NOTE(review): the model is not switched to eval mode, so if it is in train mode (as a freshly built
    module is), its BatchNorm layers use batch statistics and update their running statistics here.
    """
    if dataloader is None:
        dataloader = train_dataloader
    # Run on whatever device the model lives on (CPU for a parameter-less model).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        X, y = next(iter(dataloader))
        # The batches may be on the CPU; move them next to the model's parameters.
        X, y = X.to(device), y.to(device)
        batch_size, channels, height, width = X.shape
        y_fake = torch.rand(batch_size, channels, height, width, device=device)
        pred_on_real = torch.sigmoid(model(X, y)[0])
        pred_on_fake = torch.sigmoid(model(X, y_fake)[0])
        print(f"Number of discriminator parameters: {count_parameters(model)}")
        print(f"Predicted label for real data: {pred_on_real}")
        print(f"Predicted label for fake data: {pred_on_fake}")

In [54]:
# Build the baseline discriminator on the selected device and run the one-batch smoke test.
disc = TetrisDiscriminator()
disc = disc.to(device)
check_model(disc)

Number of discriminator parameters: 5553
Predicted label for real data: 0.5012620687484741
Predicted label for fake data: 0.5037605166435242


In [34]:
real_label = 1.0
fake_label = 0.0

def train_loop(dataloader, disc, loss_fn, optimizer_disc):
    disc.train()

    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(train_dataloader):
        ##################################################################
        # (1) Update discriminator: minimize -log(D(x)) - log(1 - D(G(z)))
        ##################################################################
        disc.zero_grad()

        ## Train with all-real batch
        # Format batch
        X, y = X.to(device), y.to(device)
        batch_size, channels, height, width = X.shape
        real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through discriminator
        output = torch.flatten(disc(X, y))
        # Calculate loss on all-real batch
        err_disc_real = loss_fn(output, real_labels)
        # Calculate gradients for discriminator in backward pass
        err_disc_real.backward()

        ## Train with all-fake batch
        # Generate fake image batch as random noise
        y_fake = torch.rand(batch_size, channels, height, width)
        fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)
        # Classify all fake batch with discriminator
        output = torch.flatten(disc(X, y_fake.detach()))
        # Calculate discriminator's loss on the all-fake batch
        err_disc_fake = loss_fn(output, fake_labels)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        err_disc_fake.backward()

        ## Update discriminator weights
        # Compute error of discriminator as sum over the fake and the real batches
        err_disc = err_disc_real + err_disc_fake
        # Update discriminator
        optimizer_disc.step()

        # Output training stats
        if batch % 20 == 0:
            current = batch * dataloader.batch_size + batch_size
            print(f"[{current}/{size}] D loss: {err_disc.item():.4f}")


def test_loop(split_name, dataloader, disc, loss_fn, tb_writer, epoch):
    disc.eval()

    loss_disc = 0.0
    disc_accuracy = 0.0

    num_batches = len(dataloader)
    with torch.no_grad():        
        for X, y in dataloader:
            batch_size, channels, height, width = X.shape
            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)

            output_real = disc(X, y)
            loss_disc += loss_fn(output_real, real_labels).item()

            y_fake = torch.rand(batch_size, channels, height, width)
            output_fake = disc(X, y_fake)
            
            loss_disc += loss_fn(output_fake, fake_labels).item()

            pred_real = (output_real > 0.0)
            pred_fake = (output_fake > 0.0)
            disc_accuracy += pred_real.type(torch.float).mean().item()
            disc_accuracy += (~pred_fake).type(torch.float).mean().item()

    loss_disc /= num_batches
    disc_accuracy /= (2.0 * num_batches)

    print(f"{split_name} error: \n D loss: {loss_disc:>8f}, D accuracy: {(100*disc_accuracy):>0.1f}% \n")

    tb_writer.add_scalar(f"Discriminator loss/{split_name}", loss_disc, epoch)
    tb_writer.add_scalar(f"Discriminator accuracy/{split_name}", disc_accuracy, epoch)


In [49]:
def train(run_name="", cls=TetrisDiscriminator, learning_rate=1e-1, epochs=100):
    """Train a fresh discriminator of class ``cls`` to separate real next frames from uniform noise.

    Each epoch runs one ``train_loop`` pass over ``train_dataloader``. It then evaluates on the train and test
    dataloaders with ``test_loop`` and logs weight/gradient histograms to TensorBoard under
    ``runs/experiment_014/<run_name>_<timestamp>``.

    Args:
        run_name: Prefix for the TensorBoard run directory.
        cls: Discriminator class, constructed with no arguments and moved to the global ``device``.
        learning_rate: SGD learning rate.
        epochs: Number of training epochs (default 100).
    """
    disc = cls().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_disc = torch.optim.SGD(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_014")
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_subdir = os.path.join(log_dir, run_name + "_" + timestamp)
    tb_writer = SummaryWriter(log_subdir)

    # Close the writer even if training is interrupted (e.g. KeyboardInterrupt), so pending events are flushed.
    try:
        for epoch in range(epochs):
            print(f"Epoch {epoch}\n-------------------------------")
            train_loop(train_dataloader, disc, loss_fn, optimizer_disc)
            test_loop("train", train_dataloader, disc, loss_fn, tb_writer, epoch)
            test_loop("test", test_dataloader, disc, loss_fn, tb_writer, epoch)
            for name, weight in disc.named_parameters():
                tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
                # Gradients are those left from the last training batch (test_loop runs under no_grad);
                # they are None if no backward pass has happened yet.
                if weight.grad is not None:
                    tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
    finally:
        tb_writer.close()
    print("Done!")

In [38]:
# Baseline run at the default learning rate (1e-1); named so it can be told apart from "lr_1em2" in TensorBoard.
train(run_name="lr_1em1")

Epoch 0
-------------------------------
[4/1762] D loss: 1.4138
[84/1762] D loss: 1.3968
[164/1762] D loss: 1.3868
[244/1762] D loss: 1.3822
[324/1762] D loss: 1.3919
[404/1762] D loss: 1.3794
[484/1762] D loss: 1.3876
[564/1762] D loss: 1.3849
[644/1762] D loss: 1.3914
[724/1762] D loss: 1.3882
[804/1762] D loss: 1.3844
[884/1762] D loss: 1.3899
[964/1762] D loss: 1.3881
[1044/1762] D loss: 1.3825
[1124/1762] D loss: 1.3812
[1204/1762] D loss: 1.3892
[1284/1762] D loss: 1.3882
[1364/1762] D loss: 1.3838
[1444/1762] D loss: 1.3873
[1524/1762] D loss: 1.3862
[1604/1762] D loss: 1.3840
[1684/1762] D loss: 1.3782


KeyboardInterrupt: 

The discriminator loss stays roughly flat for the first 5 epochs and then starts oscillating, sometimes jumping up drastically. The same pattern appears in both the training and test losses. The final value after 50 epochs is barely lower than the value at epoch 0, if at all. This means the discriminator is failing to tell the next frame apart from random noise.

Let's try reducing the learning rate to see if it avoids the spikes.

# Learning rate

In [40]:
# Same baseline architecture, with the learning rate lowered by 10x to see whether the loss spikes go away.
train(learning_rate=1e-2, run_name="lr_1em2")

Epoch 0
-------------------------------
[4/1762] D loss: 1.4328
[84/1762] D loss: 1.3760
[164/1762] D loss: 1.4027
[244/1762] D loss: 1.3879
[324/1762] D loss: 1.4245
[404/1762] D loss: 1.3937
[484/1762] D loss: 1.3966
[564/1762] D loss: 1.4006
[644/1762] D loss: 1.4135
[724/1762] D loss: 1.4033
[804/1762] D loss: 1.3756
[884/1762] D loss: 1.3870
[964/1762] D loss: 1.4014
[1044/1762] D loss: 1.3949
[1124/1762] D loss: 1.3927
[1204/1762] D loss: 1.3889
[1284/1762] D loss: 1.3852
[1364/1762] D loss: 1.3818
[1444/1762] D loss: 1.3911
[1524/1762] D loss: 1.3786
[1604/1762] D loss: 1.3991
[1684/1762] D loss: 1.3852
[1762/1762] D loss: 1.3892
train error: 
 D loss: 1.469195, D accuracy: 28.2% 

test error: 
 D loss: 1.463980, D accuracy: 30.2% 

Epoch 1
-------------------------------
[4/1762] D loss: 1.4052
[84/1762] D loss: 1.4016
[164/1762] D loss: 1.3852
[244/1762] D loss: 1.3785
[324/1762] D loss: 1.4004
[404/1762] D loss: 1.3835
[484/1762] D loss: 1.3979
[564/1762] D loss: 1.3854
[644/

Now the discriminator loss barely changes until about epoch 40, when it starts oscillating; it then has huge spikes around epochs 60 and 80. The accuracy oscillates around 50% and never rises significantly above 80%.

The lack of loss change at the start might mean that the model doesn't have enough capacity for the task.

# Model architecture

In [55]:
class DeeperDisc(nn.Module):
    """Deeper variant of TetrisDiscriminator: two 3x3 conv layers (instead of one) before each 2x2 max-pool.

    forward(x, y) concatenates x and y along dim 1 (together they must have 4 channels) and returns one
    logit per sample, shape (batch_size,). The Linear(160, ...) layer fixes the flattened size at
    16 channels x 5 x 2 after the two max-pools, which is consistent with 22x10 boards.
    """

    def __init__(self):
        super().__init__()

        def conv_bn_relu(in_ch, out_ch):
            # 3x3 "same" convolution; bias is omitted because the BatchNorm that follows adds its own shift.
            return [nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False), nn.BatchNorm2d(out_ch), nn.ReLU()]

        layers = [
            *conv_bn_relu(4, 16),
            *conv_bn_relu(16, 16),
            nn.MaxPool2d(kernel_size=2),
            *conv_bn_relu(16, 16),
            *conv_bn_relu(16, 16),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(160, 16, bias=False),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, x, y):
        stacked = torch.cat((x, y), dim=1)
        return self.body(stacked)


# Smoke-test the deeper variant on one batch.
check_model(model=DeeperDisc().to(device))

Number of discriminator parameters: 10225
Predicted label for real data: 0.5207337141036987
Predicted label for fake data: 0.3333703279495239


In [56]:
class WiderDisc(nn.Module):
    """Wider variant of TetrisDiscriminator: 24 instead of 16 channels in both conv stages.

    forward(x, y) concatenates x and y along dim 1 (together 4 channels) and returns one logit per
    sample, shape (batch_size,). The flattened size is 24 channels x 5 x 2 = 240 after two 2x2
    max-pools, which is consistent with 22x10 boards.
    """

    def __init__(self):
        super().__init__()
        channels = 24
        stages = []
        for in_ch in (4, channels):
            stages.extend((
                nn.Conv2d(in_ch, channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
            ))
        self.body = nn.Sequential(
            *stages,
            nn.Flatten(),
            nn.Linear(channels * 5 * 2, 16, bias=False),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),
        )

    def forward(self, x, y):
        return self.body(torch.cat((x, y), dim=1))


# Smoke-test the wider variant on one batch.
check_model(model=WiderDisc().to(device))

Number of discriminator parameters: 10033
Predicted label for real data: 0.5722304582595825
Predicted label for fake data: 0.5735625624656677


In [48]:
# Compare model sizes; each architecture is instantiated fresh just to count its trainable parameters.
for cls in (TetrisDiscriminator, DeeperDisc, WiderDisc):
    print("Model architecture {} has {} parameters.".format(cls.__name__, count_parameters(cls())))

Model architecture TetrisDiscriminator has 5553 parameters.
Model architecture DeeperDisc has 10225 parameters.
Model architecture WiderDisc has 10033 parameters.


In [52]:
# Train each candidate architecture twice at lr=1e-2 to see how much runs vary.
for repeat in range(2):
    for cls in (DeeperDisc, WiderDisc):
        train(cls=cls, run_name=cls.__name__, learning_rate=1e-2)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3913
[84/1762] D loss: 1.3920
[164/1762] D loss: 1.4067
[244/1762] D loss: 1.4008
[324/1762] D loss: 1.3731
[404/1762] D loss: 1.4017
[484/1762] D loss: 1.4178
[564/1762] D loss: 1.3987
[644/1762] D loss: 1.4003
[724/1762] D loss: 1.3894
[804/1762] D loss: 1.3892
[884/1762] D loss: 1.3924
[964/1762] D loss: 1.3894
[1044/1762] D loss: 1.3876
[1124/1762] D loss: 1.3996
[1204/1762] D loss: 1.4029
[1284/1762] D loss: 1.3947
[1364/1762] D loss: 1.3889
[1444/1762] D loss: 1.3876
[1524/1762] D loss: 1.3832
[1604/1762] D loss: 1.3976
[1684/1762] D loss: 1.3898
[1762/1762] D loss: 1.3934
train error: 
 D loss: 1.616359, D accuracy: 36.6% 

test error: 
 D loss: 1.614720, D accuracy: 37.2% 

Epoch 1
-------------------------------
[4/1762] D loss: 1.4009
[84/1762] D loss: 1.4038
[164/1762] D loss: 1.3878
[244/1762] D loss: 1.4059
[324/1762] D loss: 1.3845
[404/1762] D loss: 1.3841
[484/1762] D loss: 1.4031
[564/1762] D loss: 1.3857
[644/

The deeper and wider models don't do any better. All three architectures, including the original, oscillate wildly in both loss and accuracy. However, the wider model's loss seems to oscillate much less than that of the other two.

Let's try some more architectures.

In [57]:
class DiscWithLessBatchNorm(nn.Module):
    """TetrisDiscriminator without the BatchNorm after the first convolution.

    forward(x, y) concatenates x and y along dim 1 (together 4 channels) and returns one logit per
    sample, shape (batch_size,).
    NOTE(review): the first Conv2d still has bias=False even though no BatchNorm follows it, so that
    layer has no learnable offset at all — confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        feature_extractor = [
            nn.Conv2d(4, 16, 3, padding=1, bias=False),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(16, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        head = [
            nn.Flatten(),
            nn.Linear(160, 16, bias=False),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),
        ]
        self.body = nn.Sequential(*feature_extractor, *head)

    def forward(self, x, y):
        joint = torch.cat((x, y), dim=1)
        return self.body(joint)


# Smoke-test the variant with one BatchNorm removed.
check_model(model=DiscWithLessBatchNorm().to(device))

Number of discriminator parameters: 5521
Predicted label for real data: 0.4289992153644562
Predicted label for fake data: 0.44049957394599915


In [61]:
class DiscWithSeparateFrameProcessing(nn.Module):
    """Discriminator that encodes x and y separately with one shared encoder, then compares the codes.

    The same ``enc`` module (shared weights) maps each 2-channel board to 160 features: 16 channels x 5 x 2
    after two 2x2 max-pools, which is consistent with 22x10 boards. The two feature vectors are concatenated
    into 320 features, and ``cmp`` turns them into one logit per sample, shape (batch_size,).
    NOTE(review): in train mode the encoder's BatchNorm layers see x and y as separate batches, so their
    running statistics are updated twice per forward call.
    """

    def __init__(self):
        super().__init__()
        encoder_layers = []
        for in_ch in (2, 16):
            encoder_layers += [
                nn.Conv2d(in_ch, 16, 3, padding=1, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
            ]
        encoder_layers.append(nn.Flatten())
        self.enc = nn.Sequential(*encoder_layers)
        self.cmp = nn.Sequential(
            nn.Linear(320, 16, bias=False),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),
        )

    def forward(self, x, y):
        x_code, y_code = self.enc(x), self.enc(y)
        return self.cmp(torch.cat((x_code, y_code), dim=1))


# Smoke-test the shared-encoder variant on one batch.
check_model(model=DiscWithSeparateFrameProcessing().to(device))

Number of discriminator parameters: 7825
Predicted label for real data: 0.4225638806819916
Predicted label for fake data: 0.40801844000816345


In [62]:
# Train each alternative architecture twice at lr=1e-2 to see how much runs vary.
for repeat in range(2):
    for cls in (DiscWithLessBatchNorm, DiscWithSeparateFrameProcessing):
        train(cls=cls, run_name=cls.__name__, learning_rate=1e-2)

Epoch 0
-------------------------------
[4/1762] D loss: 1.5778
[84/1762] D loss: 1.4600
[164/1762] D loss: 1.4365
[244/1762] D loss: 1.4034
[324/1762] D loss: 1.4003
[404/1762] D loss: 1.3981
[484/1762] D loss: 1.4063
[564/1762] D loss: 1.3916
[644/1762] D loss: 1.3946
[724/1762] D loss: 1.3937
[804/1762] D loss: 1.3906
[884/1762] D loss: 1.3817
[964/1762] D loss: 1.3895
[1044/1762] D loss: 1.3986
[1124/1762] D loss: 1.3849
[1204/1762] D loss: 1.3972
[1284/1762] D loss: 1.3776
[1364/1762] D loss: 1.3834
[1444/1762] D loss: 1.3948
[1524/1762] D loss: 1.3871
[1604/1762] D loss: 1.3969
[1684/1762] D loss: 1.3920
[1762/1762] D loss: 1.3907
train error: 
 D loss: 1.478141, D accuracy: 18.7% 

test error: 
 D loss: 1.473266, D accuracy: 20.3% 

Epoch 1
-------------------------------
[4/1762] D loss: 1.3840
[84/1762] D loss: 1.3889
[164/1762] D loss: 1.3867
[244/1762] D loss: 1.3919
[324/1762] D loss: 1.3826
[404/1762] D loss: 1.3791
[484/1762] D loss: 1.3845
[564/1762] D loss: 1.3879
[644/

Again, these architectures don't show much of an improvement in performance. The discriminator with separate frame processing suffers from even worse loss spikes.

# Conclusion

We tried a few alternative architectures for the discriminator, but none of them showed a noticeable improvement. In the next experiment, let's try training a discriminator to just distinguish a single game frame `y` from random noise `y_fake`, ignoring the previous time step `x` for now.