# Experiment 016

In this experiment, we train the discriminator to distinguish between the correct next frame and random noise, as in Experiment 014. The difference is that we do not apply 1D batch norm (`BatchNorm1d`) after the first linear layer of the model, because it prevents the model from learning effectively.

In [1]:
import os
from pathlib import Path
import shutil
import datetime

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

In [2]:
class RecordingDataset(Dataset):
    """Dataset of recorded emulator sessions stored as one file per recording.

    The directory at ``path`` is expected to contain files named ``<index><ext>`` (e.g. ``0.npy``,
    ``1.npy``, ...) that ``np.load`` can read. Each file holds a sequence of integer board states;
    only the last two boards are used, as an ``(input, target)`` pair.

    Args:
        path: Directory containing the recording files.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """

    def __init__(self, path: str):
        self.path = path
        if not os.path.exists(path):
            raise FileNotFoundError(f"Recording directory not found: {path}")
        # Materialize the listing once. Both the extension and the highest index come from it,
        # and peeking at an entry for the extension must not drop it from the index scan.
        with os.scandir(self.path) as it:
            entries = list(it)
        # All recordings are assumed to share one extension; take it from any entry.
        self.ext = os.path.splitext(entries[0].name)[1] if entries else ""
        # Indices are assumed to be contiguous from 0, so the length is highest index + 1.
        self.highest_index = max((int(Path(entry.path).stem) for entry in entries), default=-1)

    def __len__(self):
        return self.highest_index + 1

    @staticmethod
    def _to_one_hot(board):
        """Convert an (H, W) board of cell types in {0, 1} to a float32 (2, H, W) one-hot tensor."""
        board = torch.tensor(board, dtype=torch.long)
        board = F.one_hot(board, 2)  # (H, W, 2): class 0 = empty cell, class 1 = filled cell
        board = board.type(torch.float)
        return board.permute((2, 0, 1))  # move the class dimension first -> (2, H, W)

    def __getitem__(self, idx):
        """Return ``(x, y)``: the second-to-last and last boards of recording ``idx``, one-hot encoded.

        Raises:
            IndexError: if no file exists for ``idx``.
        """
        file = os.path.join(self.path, f"{idx}{self.ext}")
        if not os.path.exists(file):
            raise IndexError(idx)
        boards = np.load(file)
        # Earlier boards in the recording are ignored.
        x = self._to_one_hot(boards[-2])
        y = self._to_one_hot(boards[-1])
        return x, y
        

In [3]:
# Recorded emulator sessions, split into train and test directories.
train_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "train"))
test_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "test"))

batch_size = 4
train_dataloader, test_dataloader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for dataset in (train_dataset, test_dataset)
)

# Peek at one training batch to confirm the tensor shapes and dtypes.
x, y = next(iter(train_dataloader))
print(x.shape, x.dtype)
print(y.shape, y.dtype)

torch.Size([4, 2, 22, 10]) torch.float32
torch.Size([4, 2, 22, 10]) torch.float32


In [4]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


In [5]:
class TetrisDiscriminator(nn.Module):
    """A discriminator for next-board predictions: does ``y`` look like the real next board after ``x``?

    Inputs:
        x: float32 tensor of shape (batch_size, 2, 22, 10), the one-hot encoded current board.
           Channel 0 is 1 where a cell is empty; channel 1 is 1 where a cell is filled. The first
           Linear layer expects 160 = 16 * 5 * 2 features, which fixes the board size at 22 x 10.
        y: float32 tensor with the same shape as x. Either the one-hot encoding of the ground-truth
           next board, or a candidate next board (e.g. generator output with exp applied, or noise).

    Returns:
        float32 tensor of shape (batch_size,) with one logit per sample. The final
        ``Flatten(start_dim=0)`` removes the trailing singleton dimension. Positive logits
        (sigmoid > 0.5) mean "real"; negative logits mean "fake".

    Note:
        This variant applies BatchNorm1d after the first Linear layer. In training mode,
        BatchNorm layers need more than one sample per batch.
    """

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(
                nn.Conv2d(4, 16, 3, padding=1, bias=False),  # 4 = x's 2 channels + y's 2 channels
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),  # 22 x 10 -> 11 x 5
                nn.Conv2d(16, 16, 3, padding=1, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),  # 11 x 5 -> 5 x 2
                nn.Flatten(),
                nn.Linear(160, 16, bias=False),  # bias is redundant before BatchNorm1d
                nn.BatchNorm1d(16),
                nn.ReLU(),
                nn.Linear(16, 1),
                nn.Flatten(start_dim=0)  # (batch_size, 1) -> (batch_size,)
            )

    def forward(self, x, y):
        # Stack the current and candidate next board along the channel dimension.
        x = torch.cat((x, y), dim=1)
        logits = self.body(x)
        return logits

In [6]:
def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable (``requires_grad``) parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def check_model(model, dataloader=None):
    """Print a quick sanity report for a discriminator: parameter count and two sample predictions.

    Takes the first batch ``(X, y)`` from ``dataloader``. Prints the sigmoid of the model's first
    output for the real pair ``(X, y)`` and for ``(X, noise)``, where ``noise`` is uniform in
    [0, 1) with the same shape as ``X``.

    The model is switched to eval mode for the check. This keeps BatchNorm layers from depending
    on the batch composition and from updating their running statistics (``torch.no_grad``
    alone does not prevent that). The previous train/eval mode is restored afterwards.

    Args:
        model: Module called as ``model(X, y)``; indexing ``[0]`` of the result is assumed to
            give the logit for the first sample.
        dataloader: Iterable of ``(X, y)`` batches. Defaults to the global ``train_dataloader``.
    """
    if dataloader is None:
        dataloader = train_dataloader
    # Put the inputs on the same device as the model, so the check also works on CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            X, y = next(iter(dataloader))
            X, y = X.to(device), y.to(device)
            y_fake = torch.rand(X.shape, device=device)
            pred_on_real = torch.sigmoid(model(X, y)[0])
            pred_on_fake = torch.sigmoid(model(X, y_fake)[0])
    finally:
        model.train(was_training)

    print(f"Number of discriminator parameters: {count_parameters(model)}")
    print(f"Predicted label for real data: {pred_on_real}")
    print(f"Predicted label for fake data: {pred_on_fake}")

In [7]:
# Build the baseline discriminator on the selected device and run the quick sanity check on it.
disc = TetrisDiscriminator().to(device)
check_model(model=disc)

Number of discriminator parameters: 5553
Predicted label for real data: 0.517902135848999
Predicted label for fake data: 0.39286068081855774


In [8]:
real_label = 1.0
fake_label = 0.0


def train_loop(dataloader, disc, loss_fn, optimizer_disc):
    """Run one training epoch of the discriminator: real next boards vs. uniform noise.

    For every batch ``(X, y)`` from ``dataloader``, the discriminator is trained to output
    ``real_label`` for ``disc(X, y)`` and ``fake_label`` for ``disc(X, noise)``. Here ``noise`` is
    uniform in [0, 1) with the same shape as ``X``. The gradients of both losses are accumulated
    before a single optimizer step.

    Args:
        dataloader: Yields ``(X, y)`` batches. Its ``.dataset`` and ``.batch_size`` are used
            only for progress reporting.
        disc: Discriminator module called as ``disc(X, y)``, returning logits.
        loss_fn: Loss on (logits, targets), e.g. ``nn.BCEWithLogitsLoss()``.
        optimizer_disc: Optimizer over ``disc``'s parameters.
    """
    # All tensors must live where the model lives (avoids CPU/CUDA mismatches).
    device = next(disc.parameters()).device
    disc.train()

    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        disc.zero_grad()

        # (1) Real batch: the true next board should be classified as real.
        X, y = X.to(device), y.to(device)
        n = X.shape[0]  # the last batch may be smaller than dataloader.batch_size
        real_labels = torch.full((n,), real_label, dtype=torch.float, device=device)
        output = torch.flatten(disc(X, y))
        err_disc_real = loss_fn(output, real_labels)
        err_disc_real.backward()

        # (2) Fake batch: uniform noise in place of the next board should be classified as fake.
        y_fake = torch.rand(X.shape, device=device)
        fake_labels = torch.full((n,), fake_label, dtype=torch.float, device=device)
        output = torch.flatten(disc(X, y_fake))
        err_disc_fake = loss_fn(output, fake_labels)
        # These gradients accumulate (sum) with those from the real batch.
        err_disc_fake.backward()

        # (3) One optimizer step on the combined real + fake gradients.
        err_disc = err_disc_real + err_disc_fake
        optimizer_disc.step()

        if batch % 20 == 0:
            current = batch * dataloader.batch_size + n
            print(f"[{current}/{size}] D loss: {err_disc.item():.4f}")


def test_loop(split_name, dataloader, disc, loss_fn, tb_writer, epoch):
    """Evaluate the discriminator on one split, print the metrics and log them to TensorBoard.

    For every batch ``(X, y)``, the discriminator scores the real pair ``(X, y)`` (target
    ``real_label``) and ``(X, noise)`` with uniform noise in [0, 1) (target ``fake_label``).

    Metrics:
        D loss: per batch, the real loss plus the fake loss, then averaged over batches.
        D accuracy: fraction of all real and fake samples classified correctly, where a logit
            > 0 means "real". Counting samples rather than averaging per-batch accuracies keeps
            a smaller final batch from being over-weighted.

    Args:
        split_name: Label used in the printout and the TensorBoard tags (e.g. "train", "test").
        dataloader: Yields ``(X, y)`` batches.
        disc: Discriminator called as ``disc(X, y)``, returning logits.
        loss_fn: Loss on (logits, targets), e.g. ``nn.BCEWithLogitsLoss()``.
        tb_writer: Object with ``add_scalar(tag, value, step)``, e.g. a ``SummaryWriter``.
        epoch: Step value for the logged scalars.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError(f"No batches to evaluate for split {split_name!r}")

    # All tensors must live where the model lives (avoids CPU/CUDA mismatches).
    device = next(disc.parameters()).device
    disc.eval()

    loss_disc = 0.0
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            n = X.shape[0]
            real_labels = torch.full((n,), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((n,), fake_label, dtype=torch.float, device=device)

            output_real = torch.flatten(disc(X, y))
            loss_disc += loss_fn(output_real, real_labels).item()

            y_fake = torch.rand(X.shape, device=device)
            output_fake = torch.flatten(disc(X, y_fake))
            loss_disc += loss_fn(output_fake, fake_labels).item()

            # A logit > 0 (sigmoid > 0.5) is a "real" prediction.
            num_correct += (output_real > 0.0).sum().item()
            num_correct += (output_fake <= 0.0).sum().item()
            num_samples += 2 * n

    loss_disc /= num_batches
    disc_accuracy = num_correct / num_samples

    print(f"{split_name} error: \n D loss: {loss_disc:>8f}, D accuracy: {(100*disc_accuracy):>0.1f}% \n")

    tb_writer.add_scalar(f"Discriminator loss/{split_name}", loss_disc, epoch)
    tb_writer.add_scalar(f"Discriminator accuracy/{split_name}", disc_accuracy, epoch)


In [9]:
def train(run_name="", cls=TetrisDiscriminator, learning_rate=1e-1, epochs=100):
    """Train a freshly constructed discriminator and log its metrics to TensorBoard.

    Each epoch runs ``train_loop`` on the global ``train_dataloader``, then evaluates on both the
    train and test dataloaders. It also logs histograms of every parameter and of its gradient.
    Logs go to ``runs/experiment_016/<run_name>_<timestamp>``.

    Args:
        run_name: Prefix for the TensorBoard run directory.
        cls: Discriminator class; instantiated with no arguments and moved to ``device``.
        learning_rate: SGD learning rate.
        epochs: Number of training epochs.
    """
    disc = cls().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_disc = torch.optim.SGD(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_016")
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_subdir = os.path.join(log_dir, run_name + "_" + timestamp)
    tb_writer = SummaryWriter(log_subdir)

    # Close the writer even if training is interrupted, so buffered events are flushed.
    try:
        for epoch in range(epochs):
            print(f"Epoch {epoch}\n-------------------------------")
            train_loop(train_dataloader, disc, loss_fn, optimizer_disc)
            test_loop("train", train_dataloader, disc, loss_fn, tb_writer, epoch)
            test_loop("test", test_dataloader, disc, loss_fn, tb_writer, epoch)
            for name, weight in disc.named_parameters():
                tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
                # A parameter that received no gradient has .grad == None, which cannot be logged.
                if weight.grad is not None:
                    tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
    finally:
        tb_writer.close()
    print("Done!")

In [11]:
class DiscWithNo1dBatchNorm(nn.Module):
    """Same discriminator as ``TetrisDiscriminator`` but without BatchNorm1d after the first Linear layer.

    Because no BatchNorm follows it, the first Linear layer keeps its bias. Inputs are the current
    board ``x`` and the candidate next board ``y``, each of shape (batch_size, 2, 22, 10). The
    output is a tensor of shape (batch_size,) holding one logit per sample (positive = "real").
    """

    def __init__(self):
        super().__init__()
        layers = []
        # Two conv stages; the first sees x and y stacked along the channel dimension (2 + 2 = 4).
        for in_channels in (4, 16):
            layers += [
                nn.Conv2d(in_channels, 16, 3, padding=1, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
            ]
        # Classifier head: 16 channels * 5 * 2 spatial positions = 160 features -> one logit.
        layers += [
            nn.Flatten(),
            nn.Linear(160, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Flatten(start_dim=0),
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, x, y):
        return self.body(torch.cat((x, y), dim=1))


# Sanity-check a fresh instance of the variant without BatchNorm1d.
check_model(model=DiscWithNo1dBatchNorm().to(device))

Number of discriminator parameters: 5537
Predicted label for real data: 0.3999292850494385
Predicted label for fake data: 0.3705187439918518


In [12]:
# Two independent training runs; train() constructs a fresh model each time.
for _run_index in range(2):
    train(run_name=DiscWithNo1dBatchNorm.__name__, cls=DiscWithNo1dBatchNorm)

Epoch 0
-------------------------------
[4/1762] D loss: 1.5434
[84/1762] D loss: 0.0122
[164/1762] D loss: 0.1179
[244/1762] D loss: 0.0198
[324/1762] D loss: 0.0146
[404/1762] D loss: 0.0018
[484/1762] D loss: 0.0034
[564/1762] D loss: 0.0018
[644/1762] D loss: 0.0023
[724/1762] D loss: 0.0033
[804/1762] D loss: 0.0005
[884/1762] D loss: 0.0001
[964/1762] D loss: 0.0003
[1044/1762] D loss: 0.0002
[1124/1762] D loss: 0.0009
[1204/1762] D loss: 0.0001
[1284/1762] D loss: 0.0003
[1364/1762] D loss: 0.0002
[1444/1762] D loss: 0.0001
[1524/1762] D loss: 0.0002
[1604/1762] D loss: 0.0007
[1684/1762] D loss: 0.0001
[1762/1762] D loss: 0.0002
train error: 
 D loss: 0.005741, D accuracy: 100.0% 

test error: 
 D loss: 0.005598, D accuracy: 100.0% 

Epoch 1
-------------------------------
[4/1762] D loss: 0.0001
[84/1762] D loss: 0.0003
[164/1762] D loss: 0.0001
[244/1762] D loss: 0.0001
[324/1762] D loss: 0.0000
[404/1762] D loss: 0.0001
[484/1762] D loss: 0.0001
[564/1762] D loss: 0.0003
[64

# Conclusion

It works! Without the `BatchNorm1d` layer, the discriminator reaches 100% accuracy on both the train and test splits within a single epoch.