# Experiment 015

In this experiment, we simplify the task of training the discriminator temporarily so we can find an architecture that works. We will train a single-frame discriminator to just distinguish between a single frame and random noise.

In [1]:
import os
from pathlib import Path
import shutil
import datetime

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

In [2]:
class RecordingDataset(Dataset):
    """Dataset of recorded board sequences, one NumPy file per example.

    Files in ``path`` are expected to be named ``<index><ext>`` with an
    integer index. Each item is a pair ``(x, y)`` built from the last two
    boards stored in the file, one-hot encoded into 2 classes with the class
    dimension moved to the front.
    """

    def __init__(self, path: str):
        """Scan ``path`` to find the file extension and the highest index.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            ValueError: if a file name (without extension) is not an integer.
        """
        self.path = path
        if not os.path.exists(path):
            raise FileNotFoundError(f"Dataset directory not found: {path}")
        # Materialise the listing first: os.scandir returns a one-shot
        # iterator, so peeking at the first entry and then iterating again
        # would silently drop that entry from the max() below.
        with os.scandir(self.path) as it:
            names = [entry.name for entry in it]
        # All files are assumed to share one extension; take it from any entry.
        self.ext = os.path.splitext(names[0])[1] if names else ""
        # An empty directory yields -1, i.e. a dataset of length 0.
        self.highest_index = max((int(Path(name).stem) for name in names), default=-1)

    def __len__(self):
        """Return the highest file index plus one."""
        return self.highest_index + 1

    @staticmethod
    def _encode_board(board):
        """One-hot encode a 2-D integer board as a float (2, H, W) tensor."""
        board = torch.tensor(board, dtype=torch.long)
        board = F.one_hot(board, 2)  # One-hot encode the cell types
        board = board.type(torch.float)  # Convert to floating-point
        return board.permute((2, 0, 1))  # Move channels/classes to dimension 0

    def __getitem__(self, idx):
        """Load example ``idx`` and return its last two boards as ``(x, y)``.

        Raises:
            IndexError: if no file exists for ``idx``.
        """
        file = os.path.join(self.path, f"{idx}{self.ext}")
        if not os.path.exists(file):
            raise IndexError(f"No recording for index {idx}")
        boards = np.load(file)

        # Ignore all boards except the last two
        x = self._encode_board(boards[-2])
        y = self._encode_board(boards[-1])
        return x, y
        

In [3]:
# Build the train/test datasets from the recorded emulator data on disk and
# wrap them in shuffling loaders. These module-level names are used by the
# later cells (check_model, train).
train_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "train"))
test_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "test"))
batch_size = 4
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Peek at one batch to confirm the tensor shapes and dtypes.
x, y = next(iter(train_dataloader))
print(x.shape, x.dtype)
print(y.shape, y.dtype)

torch.Size([4, 2, 22, 10]) torch.float32
torch.Size([4, 2, 22, 10]) torch.float32


In [4]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using {device} device")

Using cpu device


In [6]:
class SingleFrameDiscriminator(nn.Module):
    """Baseline discriminator that scores a single one-hot encoded board.

    Two conv -> batch norm -> ReLU -> 2x2 max-pool stages feed a small fully
    connected head. The output is one raw logit per sample (no sigmoid is
    applied here), flattened to shape ``(batch,)``.
    """

    def __init__(self):
        super().__init__()
        # Bias is disabled on every layer that is directly followed by batch norm.
        conv_stages = [
            nn.Conv2d(2, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        head = [
            nn.Flatten(),
            # 160 = 16 channels x 5 x 2 after two poolings of a 22x10 board.
            nn.Linear(160, 16, bias=False),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),  # (batch, 1) -> (batch,)
        ]
        self.body = nn.Sequential(*conv_stages, *head)

    def forward(self, y):
        """Return the logits produced by ``self.body`` for the batch ``y``."""
        return self.body(y)

In [7]:
def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

def check_model(model):
    """Print a quick sanity check for an (untrained) discriminator.

    Draws one batch from the module-level ``train_dataloader``, builds a
    uniform-noise batch of the same shape, and prints the model's parameter
    count plus its sigmoid-activated prediction for the first real and the
    first noise sample.

    NOTE: the model's train/eval mode is not changed here. A freshly built
    module is in training mode, in which batch-norm layers normalise with
    the statistics of the batch they are given.
    """
    with torch.no_grad():
        X, y = next(iter(train_dataloader))
        # Run on whatever device the model lives on; the loader yields CPU
        # tensors, which would fail against a model that was moved to CUDA.
        first_param = next(model.parameters(), None)
        target_device = first_param.device if first_param is not None else y.device
        y = y.to(target_device)
        y_fake = torch.rand(X.shape, device=target_device)
        pred_on_real = F.sigmoid(model(y)[0])
        pred_on_fake = F.sigmoid(model(y_fake)[0])
        print(f"Number of discriminator parameters: {count_parameters(model)}")
        print(f"Predicted label for real data: {pred_on_real}")
        print(f"Predicted label for fake data: {pred_on_fake}")

In [8]:
# Sanity-check the untrained baseline: print its parameter count and its
# predictions on one real and one noise sample.
disc = SingleFrameDiscriminator().to(device)
check_model(disc)

Number of discriminator parameters: 5265
Predicted label for real data: 0.6543416976928711
Predicted label for fake data: 0.6104115843772888


In [9]:
# Target values for the binary real-vs-fake classification.
real_label = 1.0
fake_label = 0.0

def train_loop(dataloader, disc, loss_fn, optimizer_disc):
    """Train the discriminator for one pass over ``dataloader``.

    For every batch, the discriminator is shown the real frames ``y``
    (label ``real_label``) and a uniform-noise batch of the same shape
    (label ``fake_label``). Gradients from both passes are accumulated
    before a single optimizer step.

    Args:
        dataloader: Yields ``(X, y)`` batches. Only the shape of ``X`` is
            read; ``y`` is the real frame given to the discriminator.
        disc: Discriminator returning one logit per sample.
        loss_fn: Loss taking ``(logits, labels)``.
        optimizer_disc: Optimizer over ``disc``'s parameters.
    """
    disc.train()
    # Keep every tensor on the model's device so the loop works on CPU and GPU.
    model_device = next(disc.parameters()).device

    size = len(dataloader.dataset)
    # Iterate the loader that was passed in, not the module-level one.
    for batch, (X, y) in enumerate(dataloader):
        ##################################################################
        # (1) Update discriminator: minimize -log(D(x)) - log(1 - D(G(z)))
        ##################################################################
        disc.zero_grad()

        ## Train with all-real batch
        y = y.to(model_device)
        batch_size, channels, height, width = X.shape
        real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=model_device)
        # Forward pass real batch through discriminator
        output = torch.flatten(disc(y))
        # Calculate loss on all-real batch and backpropagate
        err_disc_real = loss_fn(output, real_labels)
        err_disc_real.backward()

        ## Train with all-fake batch
        # Generate the fake batch as random noise, directly on the model's device
        y_fake = torch.rand(batch_size, channels, height, width, device=model_device)
        fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=model_device)
        # Classify all fake batch with discriminator
        output = torch.flatten(disc(y_fake))
        # Calculate discriminator's loss on the all-fake batch
        err_disc_fake = loss_fn(output, fake_labels)
        # Gradients are accumulated (summed) with those of the real batch
        err_disc_fake.backward()

        ## Update discriminator weights
        optimizer_disc.step()

        # Output training stats; the reported loss is the real + fake sum
        if batch % 20 == 0:
            err_disc = err_disc_real + err_disc_fake
            current = batch * dataloader.batch_size + batch_size
            print(f"[{current}/{size}] D loss: {err_disc.item():.4f}")


def test_loop(split_name, dataloader, disc, loss_fn, tb_writer, epoch):
    """Evaluate the discriminator on real frames versus uniform noise.

    Prints and logs to ``tb_writer`` the loss and accuracy for the split.
    The loss is the per-batch sum of the real and the fake loss, averaged
    over batches. The accuracy is the mean of the per-batch accuracies on
    real and fake inputs, with logits above 0 counted as "real".

    Args:
        split_name: Label used in the printout and the TensorBoard tags.
        dataloader: Yields ``(X, y)`` batches. Only the shape of ``X`` is read.
        disc: Discriminator returning one logit per sample.
        loss_fn: Loss taking ``(logits, labels)``.
        tb_writer: Object exposing ``add_scalar(tag, value, step)``.
        epoch: Step value attached to the logged scalars.
    """
    disc.eval()
    # Keep inputs and labels on the model's device; the loader yields CPU
    # tensors, which would fail against a model that was moved to CUDA.
    model_device = next(disc.parameters()).device

    loss_disc = 0.0
    disc_accuracy = 0.0

    num_batches = len(dataloader)
    with torch.no_grad():
        for X, y in dataloader:
            y = y.to(model_device)
            batch_size, channels, height, width = X.shape
            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=model_device)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=model_device)

            output_real = disc(y)
            loss_disc += loss_fn(output_real, real_labels).item()

            y_fake = torch.rand(batch_size, channels, height, width, device=model_device)
            output_fake = disc(y_fake)

            loss_disc += loss_fn(output_fake, fake_labels).item()

            # Outputs are logits, so 0.0 is the decision boundary.
            pred_real = (output_real > 0.0)
            pred_fake = (output_fake > 0.0)
            disc_accuracy += pred_real.type(torch.float).mean().item()
            disc_accuracy += (~pred_fake).type(torch.float).mean().item()

    loss_disc /= num_batches
    # Two accuracy terms (real and fake) were added per batch.
    disc_accuracy /= (2.0 * num_batches)

    print(f"{split_name} error: \n D loss: {loss_disc:>8f}, D accuracy: {(100*disc_accuracy):>0.1f}% \n")

    tb_writer.add_scalar(f"Discriminator loss/{split_name}", loss_disc, epoch)
    tb_writer.add_scalar(f"Discriminator accuracy/{split_name}", disc_accuracy, epoch)


In [10]:
def train(run_name="", cls=SingleFrameDiscriminator, learning_rate=1e-1, epochs=100):
    """Train a freshly constructed discriminator and log to TensorBoard.

    Uses the module-level ``train_dataloader``, ``test_dataloader`` and
    ``device``. After every epoch the model is evaluated on both splits and
    histograms of its weights and gradients are logged.

    Args:
        run_name: Prefix of the run directory created under
            ``runs/experiment_015``; a timestamp is appended.
        cls: Zero-argument callable that builds the discriminator.
        learning_rate: Learning rate for the SGD optimizer.
        epochs: Number of passes over the training data.
    """
    disc = cls().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_disc = torch.optim.SGD(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_015")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    # Close the writer even if training is interrupted or raises, so that
    # already-recorded events are flushed to disk.
    try:
        for epoch in range(epochs):
            print(f"Epoch {epoch}\n-------------------------------")
            train_loop(train_dataloader, disc, loss_fn, optimizer_disc)
            test_loop("train", train_dataloader, disc, loss_fn, tb_writer, epoch)
            test_loop("test", test_dataloader, disc, loss_fn, tb_writer, epoch)
            for name, weight in disc.named_parameters():
                tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
                # A parameter that received no gradient has grad=None.
                if weight.grad is not None:
                    tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
    finally:
        tb_writer.close()
    print("Done!")

In [11]:
# Baseline run: SingleFrameDiscriminator with the default learning rate (1e-1).
train()

Epoch 0
-------------------------------
[4/1762] D loss: 1.4210
[84/1762] D loss: 1.4010
[164/1762] D loss: 1.3750
[244/1762] D loss: 1.3888
[324/1762] D loss: 1.3900
[404/1762] D loss: 1.3864
[484/1762] D loss: 1.3895
[564/1762] D loss: 1.3902
[644/1762] D loss: 1.3824
[724/1762] D loss: 1.3891
[804/1762] D loss: 1.3890
[884/1762] D loss: 1.3886
[964/1762] D loss: 1.3861
[1044/1762] D loss: 1.3861
[1124/1762] D loss: 1.3920
[1204/1762] D loss: 1.3882
[1284/1762] D loss: 1.3857
[1364/1762] D loss: 1.3880
[1444/1762] D loss: 1.3860
[1524/1762] D loss: 1.3873
[1604/1762] D loss: 1.3853
[1684/1762] D loss: 1.3862
[1762/1762] D loss: 1.3865
train error: 
 D loss: 1.372075, D accuracy: 69.7% 

test error: 
 D loss: 1.372402, D accuracy: 69.5% 

Epoch 1
-------------------------------
[4/1762] D loss: 1.3882
[84/1762] D loss: 1.3873
[164/1762] D loss: 1.3832
[244/1762] D loss: 1.3853
[324/1762] D loss: 1.3862
[404/1762] D loss: 1.3814
[484/1762] D loss: 1.3843
[564/1762] D loss: 1.3856
[644/

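The discriminator doesn't perform well on this task, and training is unstable, as seen before. After 50 epochs it even "gives up" and settles at a constant 50% accuracy and a loss of 1.386. That value is 2·ln 2 ≈ 1.386, which is exactly what the summed real + fake binary cross-entropy gives when the discriminator outputs a probability of 0.5 for every input.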

Let's first try some lower learning rates.

# Learning rate

In [14]:
# Sweep lower learning rates with the baseline architecture.
for learning_rate in [1e-2, 1e-3, 1e-4, 1e-5]:
    # Encode the rate without "-" or "." in the run name, e.g. 1e-2 -> "lr_1p0em02".
    run_name = "lr_" + f"{learning_rate:.1e}".replace("-", "m").replace(".", "p")
    train(run_name=run_name, learning_rate=learning_rate)

Epoch 0
-------------------------------
[4/1762] D loss: 1.4178
[84/1762] D loss: 1.4295
[164/1762] D loss: 1.4090
[244/1762] D loss: 1.3803
[324/1762] D loss: 1.4036
[404/1762] D loss: 1.4025
[484/1762] D loss: 1.4063
[564/1762] D loss: 1.4031
[644/1762] D loss: 1.3777
[724/1762] D loss: 1.3949
[804/1762] D loss: 1.4011
[884/1762] D loss: 1.4057
[964/1762] D loss: 1.4086
[1044/1762] D loss: 1.3781
[1124/1762] D loss: 1.3776
[1204/1762] D loss: 1.3998
[1284/1762] D loss: 1.3856
[1364/1762] D loss: 1.3924
[1444/1762] D loss: 1.3837
[1524/1762] D loss: 1.4013
[1604/1762] D loss: 1.3704
[1684/1762] D loss: 1.3966
[1762/1762] D loss: 1.3954
train error: 
 D loss: 1.358698, D accuracy: 62.5% 

test error: 
 D loss: 1.362209, D accuracy: 60.3% 

Epoch 1
-------------------------------
[4/1762] D loss: 1.3825
[84/1762] D loss: 1.3821
[164/1762] D loss: 1.3743
[244/1762] D loss: 1.3860
[324/1762] D loss: 1.3768
[404/1762] D loss: 1.4087
[484/1762] D loss: 1.3825
[564/1762] D loss: 1.3863
[644/

We see an odd behaviour here as the learning rate decreases. At 1e-3, the wild oscillations in accuracy subside, but the model stays around 60%. At 1e-4 and 1e-5, the model quickly stabilises on an accuracy value which could be as high as 80% or as low as 20%, then stays at that value until the end of the training process. This suggests that the problem can't be fixed by tweaking the learning rate.

# Model architecture

Let's try some of the architectures from Experiment 014 and some new ones.

In [16]:
class DeeperDisc(nn.Module):
    """Deeper variant: two conv -> batch norm -> ReLU blocks per pooling stage.

    The fully connected head is the same shape as in the baseline. The
    output is one raw logit per sample, flattened to shape ``(batch,)``.
    """

    def __init__(self):
        super().__init__()

        def conv_block(in_channels):
            # 3x3 "same" convolution; bias is off because batch norm follows.
            return [
                nn.Conv2d(in_channels, 16, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
            ]

        layers = [
            *conv_block(2),
            *conv_block(16),
            nn.MaxPool2d(2),
            *conv_block(16),
            *conv_block(16),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(160, 16, bias=False),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),  # (batch, 1) -> (batch,)
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, y):
        """Return the logits produced by ``self.body`` for the batch ``y``."""
        return self.body(y)


# Sanity-check the untrained model: parameter count and predictions on one real and one noise sample.
check_model(DeeperDisc().to(device))

Number of discriminator parameters: 9937
Predicted label for real data: 0.4662979245185852
Predicted label for fake data: 0.37488076090812683


In [17]:
class WiderDisc(nn.Module):
    """Wider variant of the baseline: 24 convolutional channels instead of 16.

    The output is one raw logit per sample, flattened to shape ``(batch,)``.
    """

    def __init__(self):
        super().__init__()
        width = 24
        conv_stages = [
            nn.Conv2d(2, width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        head = [
            nn.Flatten(),
            # 240 = 24 channels x 5 x 2 after two poolings of a 22x10 board.
            nn.Linear(240, 16, bias=False),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),  # (batch, 1) -> (batch,)
        ]
        self.body = nn.Sequential(*conv_stages, *head)

    def forward(self, y):
        """Return the logits produced by ``self.body`` for the batch ``y``."""
        return self.body(y)


# Sanity-check the untrained model: parameter count and predictions on one real and one noise sample.
check_model(WiderDisc().to(device))

Number of discriminator parameters: 9601
Predicted label for real data: 0.5021631121635437
Predicted label for fake data: 0.759419322013855


In [18]:
class DiscWithLessBatchNorm(nn.Module):
    """Baseline with the batch norm after the first convolution removed.

    NOTE: the first convolution still has ``bias=False`` although no batch
    norm follows it any more. The notebook text identifies this as a bug and
    defines ``DiscWithLessBatchNormFixed`` for comparison; this class is
    kept as it was run.
    """

    def __init__(self):
        super().__init__()
        conv_stages = [
            nn.Conv2d(2, 16, kernel_size=3, padding=1, bias=False),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        head = [
            nn.Flatten(),
            nn.Linear(160, 16, bias=False),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),  # (batch, 1) -> (batch,)
        ]
        self.body = nn.Sequential(*conv_stages, *head)

    def forward(self, y):
        """Return the logits produced by ``self.body`` for the batch ``y``."""
        return self.body(y)


# Sanity-check the untrained model: parameter count and predictions on one real and one noise sample.
check_model(DiscWithLessBatchNorm().to(device))

Number of discriminator parameters: 5233
Predicted label for real data: 0.5039087533950806
Predicted label for fake data: 0.5060989260673523


In [21]:
class DiscWithNoMaxPooling(nn.Module):
    """Variant that shrinks the feature map with unpadded convolutions.

    One padded convolution is followed by four unpadded 3x3 convolutions,
    each with batch norm and ReLU. A final average pool with a fixed
    ``(14, 2)`` kernel and a single linear layer produce one raw logit per
    sample, flattened to shape ``(batch,)``.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(2, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        ]
        # Each unpadded 3x3 convolution trims 2 from the height and the width.
        for _ in range(4):
            layers += [
                nn.Conv2d(16, 16, kernel_size=3, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
            ]
        layers += [
            # A 22x10 board is 14x2 at this point, so this pools to 1x1.
            nn.AvgPool2d(kernel_size=(14, 2)),
            nn.Flatten(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),  # (batch, 1) -> (batch,)
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, y):
        """Return the logits produced by ``self.body`` for the batch ``y``."""
        return self.body(y)


# Sanity-check the untrained model: parameter count and predictions on one real and one noise sample.
check_model(DiscWithNoMaxPooling().to(device))

Number of discriminator parameters: 9681
Predicted label for real data: 0.6092426180839539
Predicted label for fake data: 0.6277297139167786


In [22]:
# Train every candidate architecture twice at lr=1e-2. Each train() call
# builds a fresh model and writes to its own timestamped TensorBoard run.
for i in range(2):
    for cls in [DeeperDisc, WiderDisc, DiscWithLessBatchNorm, DiscWithNoMaxPooling]:
        train(run_name=cls.__name__, cls=cls, learning_rate=1e-2)

Epoch 0
-------------------------------
[4/1762] D loss: 1.4196
[84/1762] D loss: 1.3933
[164/1762] D loss: 1.3890
[244/1762] D loss: 1.3890
[324/1762] D loss: 1.3774
[404/1762] D loss: 1.4184
[484/1762] D loss: 1.4075
[564/1762] D loss: 1.4114
[644/1762] D loss: 1.4138
[724/1762] D loss: 1.3944
[804/1762] D loss: 1.4019
[884/1762] D loss: 1.3872
[964/1762] D loss: 1.3984
[1044/1762] D loss: 1.4009
[1124/1762] D loss: 1.3952
[1204/1762] D loss: 1.4204
[1284/1762] D loss: 1.3913
[1364/1762] D loss: 1.4106
[1444/1762] D loss: 1.4037
[1524/1762] D loss: 1.3790
[1604/1762] D loss: 1.3955
[1684/1762] D loss: 1.4208
[1762/1762] D loss: 1.4118
train error: 
 D loss: 2.612664, D accuracy: 49.9% 

test error: 
 D loss: 2.588432, D accuracy: 50.0% 

Epoch 1
-------------------------------
[4/1762] D loss: 1.3910
[84/1762] D loss: 1.3801
[164/1762] D loss: 1.4017
[244/1762] D loss: 1.3872
[324/1762] D loss: 1.3779
[404/1762] D loss: 1.3915
[484/1762] D loss: 1.3902
[564/1762] D loss: 1.4054
[644/

These models all suffer from serious problems. For `WiderDisc`, `DeeperDisc` and `DiscWithLessBatchNorm`, the loss increases and the accuracy swings wildly, often hitting 0%. With `DiscWithNoMaxPooling`, the loss steadily goes up, but the accuracy stays around 50% after an initial wild oscillation in the first 5 epochs.

The last run, of `DiscWithNoMaxPooling`, is very interesting. It quickly achieves 100% training and test accuracy and low loss, even though on the other run the loss steadily increased and the accuracy stayed at 50%. This suggests the problem might be to do with weight initialization, which is mentioned in the DCGAN paper and the associated PyTorch tutorial.

As a next step, let's try a model with no batch normalization. I also noticed a bug with `DiscWithLessBatchNorm` where it has no bias in the first layer (an artifact from removing the batch norm), so I'll try a fixed version of that.

In [23]:
class DiscWithLessBatchNormFixed(nn.Module):
    """``DiscWithLessBatchNorm`` with the bias restored on the first convolution.

    The first convolution has no batch norm after it, so it uses the
    default bias. The remaining layers match the baseline. The output is
    one raw logit per sample, flattened to shape ``(batch,)``.
    """

    def __init__(self):
        super().__init__()
        conv_stages = [
            nn.Conv2d(2, 16, kernel_size=3, padding=1),  # bias on: no batch norm follows
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        head = [
            nn.Flatten(),
            nn.Linear(160, 16, bias=False),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),  # (batch, 1) -> (batch,)
        ]
        self.body = nn.Sequential(*conv_stages, *head)

    def forward(self, y):
        """Return the logits produced by ``self.body`` for the batch ``y``."""
        return self.body(y)


# Sanity-check the untrained model: parameter count and predictions on one real and one noise sample.
check_model(DiscWithLessBatchNormFixed().to(device))

Number of discriminator parameters: 5249
Predicted label for real data: 0.35802119970321655
Predicted label for fake data: 0.39517462253570557


In [24]:
class DiscWithNoBatchNorm(nn.Module):
    """Baseline layout with every batch-norm layer removed.

    All convolutional and linear layers use their default bias. The output
    is one raw logit per sample, flattened to shape ``(batch,)``.
    """

    def __init__(self):
        super().__init__()
        conv_stages = [
            nn.Conv2d(2, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        head = [
            nn.Flatten(),
            nn.Linear(160, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),  # (batch, 1) -> (batch,)
        ]
        self.body = nn.Sequential(*conv_stages, *head)

    def forward(self, y):
        """Return the logits produced by ``self.body`` for the batch ``y``."""
        return self.body(y)


# Sanity-check the untrained model: parameter count and predictions on one real and one noise sample.
check_model(DiscWithNoBatchNorm().to(device))

Number of discriminator parameters: 5217
Predicted label for real data: 0.5477533340454102
Predicted label for fake data: 0.5454436540603638


In [25]:
# Train both variants twice at lr=1e-2. Each train() call builds a fresh
# model and writes to its own timestamped TensorBoard run.
for i in range(2):
    for cls in [DiscWithLessBatchNormFixed, DiscWithNoBatchNorm]:
        train(run_name=cls.__name__, cls=cls, learning_rate=1e-2)

Epoch 0
-------------------------------
[4/1762] D loss: 1.3845
[84/1762] D loss: 1.4057
[164/1762] D loss: 1.4367
[244/1762] D loss: 1.3906
[324/1762] D loss: 1.4010
[404/1762] D loss: 1.4082
[484/1762] D loss: 1.4233
[564/1762] D loss: 1.3806
[644/1762] D loss: 1.3924
[724/1762] D loss: 1.4037
[804/1762] D loss: 1.3983
[884/1762] D loss: 1.4164
[964/1762] D loss: 1.3866
[1044/1762] D loss: 1.3911
[1124/1762] D loss: 1.3878
[1204/1762] D loss: 1.3880
[1284/1762] D loss: 1.4105
[1364/1762] D loss: 1.4010
[1444/1762] D loss: 1.3824
[1524/1762] D loss: 1.3995
[1604/1762] D loss: 1.3946
[1684/1762] D loss: 1.4043
[1762/1762] D loss: 1.3821
train error: 
 D loss: 1.528025, D accuracy: 49.3% 

test error: 
 D loss: 1.523935, D accuracy: 49.0% 

Epoch 1
-------------------------------
[4/1762] D loss: 1.3984
[84/1762] D loss: 1.3782
[164/1762] D loss: 1.3893
[244/1762] D loss: 1.3932
[324/1762] D loss: 1.3954
[404/1762] D loss: 1.4009
[484/1762] D loss: 1.3857
[564/1762] D loss: 1.3855
[644/

Incredible! The `DiscWithNoBatchNorm` model reached 100% accuracy and 5e-5 loss in a single epoch! The only difference is that it has no batch normalization and the bias is enabled on the corresponding convolutional and linear layers. Meanwhile, the `DiscWithLessBatchNormFixed` model suffers very similar problems to the other models.

Here are some hypotheses and things to try next:
* Perhaps the 2D batch norm was fine but the 1D batch norm wasn't, because DCGAN has no linear layers or 1D batch norm. Try removing the 1D batch norm but not the 2D batch norm.
* Perhaps the problem is not the batch norm itself, but disabling the bias on the other layers for some reason. Try with batch norm and with bias enabled.
* Perhaps the fact that real and fake data are fed as separate batches is what's causing them to be batch-normalized to look similar to each other: in training mode, batch norm normalizes each batch with that batch's own statistics.
* Perhaps the batches are too small to use batch norm, because the sample statistics calculated by it aren't good estimates of the population parameters when the batch size is only 4.

In [26]:
class DiscWithNo1dBatchNorm(nn.Module):
    """``DiscWithLessBatchNormFixed`` with the 1D batch norm in the head removed.

    The 2D batch norm after the second convolution is kept. The output is
    one raw logit per sample, flattened to shape ``(batch,)``.

    NOTE(review): the first linear layer still has ``bias=False`` although
    no batch norm follows it any more. Confirm whether that is intended;
    it is kept because the recorded results were produced this way.
    """

    def __init__(self):
        super().__init__()
        conv_stages = [
            nn.Conv2d(2, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        head = [
            nn.Flatten(),
            nn.Linear(160, 16, bias=False),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),  # (batch, 1) -> (batch,)
        ]
        self.body = nn.Sequential(*conv_stages, *head)

    def forward(self, y):
        """Return the logits produced by ``self.body`` for the batch ``y``."""
        return self.body(y)


# Sanity-check the untrained model: parameter count and predictions on one real and one noise sample.
check_model(DiscWithNo1dBatchNorm().to(device))

Number of discriminator parameters: 5217
Predicted label for real data: 0.3660827577114105
Predicted label for fake data: 0.4189343750476837


In [27]:
class DiscWithBiasEnabled(nn.Module):
    """``DiscWithLessBatchNormFixed`` with the bias enabled on every layer.

    Both remaining batch-norm layers (2D after the second convolution, 1D
    in the head) are kept, and the layers in front of them use their
    default bias. The output is one raw logit per sample, flattened to
    shape ``(batch,)``.
    """

    def __init__(self):
        super().__init__()
        conv_stages = [
            nn.Conv2d(2, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        head = [
            nn.Flatten(),
            nn.Linear(160, 16),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),  # (batch, 1) -> (batch,)
        ]
        self.body = nn.Sequential(*conv_stages, *head)

    def forward(self, y):
        """Return the logits produced by ``self.body`` for the batch ``y``."""
        return self.body(y)


# Sanity-check the untrained model: parameter count and predictions on one real and one noise sample.
check_model(DiscWithBiasEnabled().to(device))

Number of discriminator parameters: 5281
Predicted label for real data: 0.5547911524772644
Predicted label for fake data: 0.5008944869041443


In [28]:
# Train both variants twice at lr=1e-2. Each train() call builds a fresh
# model and writes to its own timestamped TensorBoard run.
for i in range(2):
    for cls in [DiscWithNo1dBatchNorm, DiscWithBiasEnabled]:
        train(run_name=cls.__name__, cls=cls, learning_rate=1e-2)

Epoch 0
-------------------------------
[4/1762] D loss: 1.2824
[84/1762] D loss: 0.5820
[164/1762] D loss: 0.2057
[244/1762] D loss: 0.1047
[324/1762] D loss: 0.0729
[404/1762] D loss: 0.0399
[484/1762] D loss: 0.0219
[564/1762] D loss: 0.0158
[644/1762] D loss: 0.0129
[724/1762] D loss: 0.0127
[804/1762] D loss: 0.0098
[884/1762] D loss: 0.0097
[964/1762] D loss: 0.0108
[1044/1762] D loss: 0.0064
[1124/1762] D loss: 0.0033
[1204/1762] D loss: 0.0049
[1284/1762] D loss: 0.0083
[1364/1762] D loss: 0.0044
[1444/1762] D loss: 0.0032
[1524/1762] D loss: 0.0026
[1604/1762] D loss: 0.0041
[1684/1762] D loss: 0.0029
[1762/1762] D loss: 0.0056
train error: 
 D loss: 0.004873, D accuracy: 100.0% 

test error: 
 D loss: 0.004826, D accuracy: 100.0% 

Epoch 1
-------------------------------
[4/1762] D loss: 0.0024
[84/1762] D loss: 0.0052
[164/1762] D loss: 0.0032
[244/1762] D loss: 0.0024
[324/1762] D loss: 0.0033
[404/1762] D loss: 0.0017
[484/1762] D loss: 0.0019
[564/1762] D loss: 0.0025
[64

`DiscWithBiasEnabled` does terribly, but `DiscWithNo1dBatchNorm` does just as well as `DiscWithNoBatchNorm`. So the 1D batch norm at the end was the culprit.

# Conclusion

We have trained a version of the discriminator to reliably distinguish between the real data and random noise on a single frame. This was achieved by removing the 1D batch normalization at the end of the model. We should continue with architectures that don't have this 1D batch norm at the end.

Removing the other 2D batch normalization in lower layers of the model doesn't hurt performance, but let's keep it for now. The combined depth of the generator and discriminator might mean that we get internal covariate shift in these layers, even though the discriminator isn't very deep by itself.

As a next step, let's copy Experiment 014 but use this architectural insight.