# Experiment 028

In this experiment, we will increase the batch size to try to speed up training, and see whether it adversely affects the model's performance.

In [1]:
import os
from pathlib import Path
import shutil
import datetime

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt


from models import TetrisModel, TetrisDiscriminator
import metrics

In [2]:
class RecordingDataset(Dataset):
    """Dataset of recorded board sequences, stored as one array file per example.

    The directory at ``path`` is expected to contain files named
    ``<index><ext>`` (for example ``0.npy``, ``1.npy``, ...). Each file is
    loaded with ``np.load`` and indexed as a sequence of boards; only the last
    two boards are used, as the (input, target) pair.
    """

    def __init__(self, path: str):
        """Scan ``path`` to find the file extension and the highest example index.

        Args:
            path: Directory containing the recorded examples.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            ValueError: If a file name stem is not an integer.
        """
        self.path = path
        if not os.path.exists(path):
            raise FileNotFoundError(path)
        with os.scandir(self.path) as it:
            # The scandir iterator is single-pass, so collect the names once.
            # Peeking at the first entry with next() would otherwise hide it
            # from the max() below and under-count the dataset length.
            names = [entry.name for entry in it]
        # The extension is taken from whichever entry is listed first; an empty
        # directory yields an empty dataset instead of a StopIteration.
        self.ext = os.path.splitext(names[0])[1] if names else ""
        self.highest_index = max((int(Path(name).stem) for name in names), default=-1)

    def __len__(self):
        """Return the number of examples, i.e. the highest index plus one."""
        return self.highest_index + 1

    @staticmethod
    def _to_one_hot(board):
        """Convert a 2-D board of cell types (0 or 1) to a float tensor of shape (2, H, W)."""
        board = torch.tensor(board, dtype=torch.long)
        board = F.one_hot(board, 2)  # One-hot encode the cell types
        board = board.type(torch.float)  # Convert to floating-point
        board = board.permute((2, 0, 1))  # Move channels/classes to dimension 0
        return board

    def __getitem__(self, idx):
        """Load example ``idx`` and return its ``(x, y)`` pair of one-hot boards.

        Raises:
            IndexError: If no file exists for ``idx``. This also terminates
                plain ``for x, y in dataset`` iteration.
        """
        file = os.path.join(self.path, f"{idx}{self.ext}")
        if not os.path.exists(file):
            raise IndexError(idx)
        boards = np.load(file)

        x = self._to_one_hot(boards[-2])  # Ignore all boards except the last two
        y = self._to_one_hot(boards[-1])
        return x, y
        

In [3]:
# Load the train and test splits from their directories under data/tetris_emulator.
train_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "train"))
test_dataset = RecordingDataset(os.path.join("data", "tetris_emulator", "test"))

In [4]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using {device} device")

Using cpu device


In [5]:
def count_parameters(model):
    """Return the total number of elements across the trainable parameters of ``model``."""
    trainable = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad == False) are not counted.
        if param.requires_grad:
            trainable += param.numel()
    return trainable

In [6]:
import itertools

def find_interesting_examples(dataset, num=3):
    """Collect up to ``num`` examples from ``dataset`` in which a block spawns.

    An example ``(x, y)`` counts as a spawn when, after taking the argmax over
    dimension 0, the first row of ``x`` is entirely class 0 while the first row
    of ``y`` contains at least one class-1 cell.

    Returns:
        A list of ``(x, y)`` tuples in dataset order. It is shorter than
        ``num`` if the dataset contains fewer spawns.
    """

    def is_spawn(before, after):
        top_row_before = before.argmax(0)[0]
        top_row_after = after.argmax(0)[0]
        return (top_row_before == 0).all() & (top_row_after == 1).any()

    # Lazy filter: islice stops pulling from the dataset once `num` are found.
    spawns = ((x, y) for x, y in dataset if is_spawn(x, y))
    return list(itertools.islice(spawns, num))

In [7]:
def render_prediction(x, pred, y):
    """Renders an example and prediction into a single-image array.

    The three boards are laid out side by side in the order input, prediction,
    target, with a one-cell-wide column of ones between neighbours.

    Inputs:
        x: Tensor of shape (height, width), the model input.
        pred: Tensor of shape (height, width), the model prediction.
        y: Tensor of shape (height, width), the target.

    Returns:
        Tensor of shape (height, 3 * width + 2).

    Raises:
        AssertionError: If x is not 2-D or the three shapes differ.
    """
    assert len(x.shape) == 2, f"Expected tensors of shape (height, width) but got {x.shape}"
    assert x.shape == pred.shape, f"Shapes do not match: {x.shape} != {pred.shape}"
    assert x.shape == y.shape, f"Shapes do not match: {x.shape} != {y.shape}"
    height, _ = x.shape
    with torch.no_grad():
        # Allocate the separator on x's device so torch.cat also works for
        # tensors that do not live on the CPU.
        separator = torch.ones(height, 1, dtype=x.dtype, device=x.device)
        return torch.cat((x, separator, pred, separator, y), dim=-1)

In [15]:
# Target values for the discriminator loss: batches of real (X, y) pairs are
# labelled with real_label, batches of generated pairs with fake_label.
real_label = 1.0
fake_label = 0.0

def train_loop(dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc):
    """Train the generator and discriminator for one pass over ``dataloader``.

    For every batch the discriminator is updated first (on the real pair
    ``(X, y)`` and on the generated pair ``(X, gen(X))``), then the generator
    is updated to make the just-updated discriminator label its output as real.
    Losses are printed every 30 batches.

    Relies on the module-level ``device``, ``real_label`` and ``fake_label``.

    Args:
        dataloader: Yields ``(X, y)`` batches; its ``dataset`` and
            ``batch_size`` attributes are used for progress reporting.
        gen: Generator model, called as ``gen(X)``.
        disc: Discriminator model, called as ``disc(X, y)``.
        loss_fn: Loss applied to the flattened discriminator outputs and the
            per-example label tensors.
        optimizer_gen: Optimizer stepping the generator's parameters.
        optimizer_disc: Optimizer stepping the discriminator's parameters.
    """
    gen.train()
    disc.train()

    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        ##################################################################
        # (1) Update discriminator: minimize -log(D(x)) - log(1 - D(G(z)))
        ##################################################################
        disc.zero_grad()

        ## Train with all-real batch
        # Format batch
        X, y = X.to(device), y.to(device)
        batch_size = X.size(0)
        real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through discriminator
        output = torch.flatten(disc(X, y))
        # Calculate loss on all-real batch
        err_disc_real = loss_fn(output, real_labels)
        # Calculate gradients for discriminator in backward pass
        err_disc_real.backward()

        ## Train with all-fake batch
        # Generate fake image batch with generator
        y_fake = gen(X)
        fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)
        # Classify all fake batch with discriminator
        # (detach so this backward pass does not propagate into the generator)
        output = torch.flatten(disc(X, y_fake.detach()))
        # Calculate discriminator's loss on the all-fake batch
        err_disc_fake = loss_fn(output, fake_labels)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        err_disc_fake.backward()

        ## Update discriminator weights
        # Compute error of discriminator as sum over the fake and the real batches
        # (only used for the progress printout below)
        err_disc = err_disc_real + err_disc_fake
        # Update discriminator
        optimizer_disc.step()

        ##############################################
        # (2) Update generator: minimize -log(D(G(z)))
        ##############################################
        gen.zero_grad()
        # Since we just updated the discriminator, perform another forward pass of the all-fake batch through it
        output = torch.flatten(disc(X, y_fake))
        # Calculate the generator's loss based on this output
        # We use real labels because the generator wants to fool the discriminator
        err_gen = loss_fn(output, real_labels)
        # Calculate gradients for generator
        err_gen.backward()
        # Update generator
        optimizer_gen.step()

        # Output training stats
        if batch % 30 == 0:
            # Number of examples seen so far in this epoch, including this batch
            current = batch * dataloader.batch_size + batch_size
            print(f"[{current}/{size}] D loss: {err_disc.item():.4f}, G loss: {err_gen.item():.4f}")


def test_loop(split_name, dataloader, gen, disc, loss_fn, tb_writer, epoch, examples):
    """Evaluate the generator and discriminator on one split and log to TensorBoard.

    Runs both models in eval mode without gradients over ``dataloader`` and
    accumulates the discriminator loss, generator loss, discriminator accuracy
    and the metrics from the ``metrics`` module. A one-line summary is printed.
    Scalars, discriminator score histograms and rendered predictions for
    ``examples`` are written to ``tb_writer``.

    Losses and discriminator accuracy are averaged per batch, so every batch
    has equal weight regardless of its size.

    Relies on the module-level ``device``, ``real_label`` and ``fake_label``.

    Args:
        split_name: Name of the split; used in the printout and in every tag.
        dataloader: Yields ``(X, y)`` batches with classes along dimension 1.
        gen: Generator model, called as ``gen(X)``.
        disc: Discriminator model, called as ``disc(X, y)``; its output is
            treated as logits (thresholded at 0 and passed through a sigmoid).
        loss_fn: Loss applied to discriminator outputs and label tensors.
        tb_writer: TensorBoard writer receiving scalars, images and histograms.
        epoch: Global step passed to every ``tb_writer`` call.
        examples: Iterable of unbatched ``(X, y)`` pairs to render as images.
    """
    gen.eval()
    disc.eval()

    loss_disc = 0.0
    loss_gen = 0.0
    disc_accuracy = 0.0
    cell_accuracy = metrics.CellAccuracy()
    board_accuracy = metrics.BoardAccuracy()
    board_plausibility = metrics.BoardPlausibility()
    spawn_recall = metrics.SpawnRecall()
    spawn_precision = metrics.SpawnPrecision()
    spawn_validity = metrics.SpawnValidity()
    scores_real = np.zeros(len(dataloader.dataset))
    scores_fake = np.zeros(len(dataloader.dataset))
    spawn_diversity = metrics.SpawnDiversity()

    num_batches = len(dataloader)
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            # Move the batch to the models' device, exactly as train_loop does;
            # without this the forward passes fail when device is "cuda".
            X, y = X.to(device), y.to(device)
            batch_size = X.size(0)
            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)

            # Flatten to one logit per example, matching train_loop.
            output_real = torch.flatten(disc(X, y))
            loss_disc += loss_fn(output_real, real_labels).item()

            y_fake = gen(X)
            output_fake = torch.flatten(disc(X, y_fake))

            loss_disc += loss_fn(output_fake, fake_labels).item()
            loss_gen += loss_fn(output_fake, real_labels).item()

            # A positive logit means the discriminator predicts "real".
            pred_real = output_real > 0.0
            pred_fake = output_fake > 0.0
            disc_accuracy += pred_real.type(torch.float).mean().item()
            disc_accuracy += (~pred_fake).type(torch.float).mean().item()

            # Hand CPU tensors to the metrics (a no-op when already on the CPU),
            # which is the only configuration they were previously run with.
            classes_X = torch.argmax(X, dim=1).cpu()
            classes_y = torch.argmax(y, dim=1).cpu()
            classes_y_fake = torch.argmax(y_fake, dim=1).cpu()
            cell_accuracy.update_state(classes_y_fake, classes_y)
            board_accuracy.update_state(classes_y_fake, classes_y)
            board_plausibility.update_state(classes_X, classes_y_fake, classes_y)

            spawn_recall.update_state(classes_X, classes_y_fake, classes_y)
            spawn_precision.update_state(classes_X, classes_y_fake, classes_y)
            spawn_validity.update_state(classes_X, classes_y_fake)

            # Slot this batch's scores into the per-example arrays.
            start_index = dataloader.batch_size * batch
            end_index = start_index + batch_size
            # .cpu() is required before .numpy() for tensors on a CUDA device.
            scores_real[start_index:end_index] = torch.sigmoid(output_real).cpu().numpy()
            scores_fake[start_index:end_index] = torch.sigmoid(output_fake).cpu().numpy()

            spawn_diversity.update_state(classes_X, classes_y_fake)

    loss_disc /= num_batches
    loss_gen /= num_batches
    # Two accuracy terms (real and fake) were added per batch.
    disc_accuracy /= (2.0 * num_batches)

    print(f"{split_name} error: \n D loss: {loss_disc:>8f}, G loss: {loss_gen:>8f}, D accuracy: {(100*disc_accuracy):>0.1f}%, cell accuracy: {(cell_accuracy.result()):>0.1%}, board accuracy: {(board_accuracy.result()):>0.1%} \n")

    tb_writer.add_scalar(f"Discriminator loss/{split_name}", loss_disc, epoch)
    tb_writer.add_scalar(f"Loss/{split_name}", loss_gen, epoch)
    tb_writer.add_scalar(f"Discriminator accuracy/{split_name}", disc_accuracy, epoch)
    tb_writer.add_scalar(f"Cell accuracy/{split_name}", cell_accuracy.result(), epoch)
    tb_writer.add_scalar(f"Board accuracy/{split_name}", board_accuracy.result(), epoch)
    tb_writer.add_scalar(f"Board plausibility/{split_name}", board_plausibility.result(), epoch)
    tb_writer.add_scalar(f"Spawn recall/{split_name}", spawn_recall.result(), epoch)
    tb_writer.add_scalar(f"Spawn precision/{split_name}", spawn_precision.result(), epoch)
    tb_writer.add_scalar(f"Spawn validity/{split_name}", spawn_validity.result(), epoch)
    tb_writer.add_scalar(f"Spawn diversity/{split_name}", spawn_diversity.result(), epoch)

    with torch.no_grad():
        for i, (X, y) in enumerate(examples):
            # Run the generator on its own device, then bring everything back
            # to the CPU so the three boards can be concatenated for rendering.
            y_fake = gen(X.unsqueeze(0).to(device)).squeeze(0).cpu()
            X, y = X.cpu(), y.cpu()
            X, y, y_fake = X.argmax(0), y.argmax(0), y_fake.argmax(0)
            img = render_prediction(X, y_fake, y)
            tb_writer.add_image(f"Predictions/{split_name}/{i}", img, epoch, dataformats="HW")

    tb_writer.add_histogram(f"Discriminator scores/{split_name}/real", scores_real, epoch)
    tb_writer.add_histogram(f"Discriminator scores/{split_name}/fake", scores_fake, epoch)


In [16]:
def _log_parameter_stats(tb_writer, model, epoch, weights_tag, grads_tag, zero_grads_tag):
    """Log weight and gradient histograms for ``model``, plus its count of zero-valued gradient entries."""
    num_zero_grads = 0
    for name, weight in model.named_parameters():
        tb_writer.add_histogram(f"{weights_tag}/{name}", weight, epoch)
        if weight.grad is not None:
            tb_writer.add_histogram(f"{grads_tag}/{name}", weight.grad, epoch)
            num_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
    tb_writer.add_scalar(zero_grads_tag, num_zero_grads, epoch)


def train(run_name, batch_size, learning_rate=1e-4, epochs=300):
    """Train a fresh generator/discriminator pair and log the run to TensorBoard.

    Builds shuffled data loaders over the module-level ``train_dataset`` and
    ``test_dataset``, creates new ``TetrisModel`` and ``TetrisDiscriminator``
    instances with Adam optimizers, and for each epoch runs ``train_loop``
    followed by ``test_loop`` on both splits. Weight and gradient histograms
    are logged after each epoch; the gradients are whatever is left in
    ``.grad`` after the epoch's final training batch.

    Logs are written to ``runs/experiment_028/<run_name>_<timestamp>``.

    Args:
        run_name: Prefix of the log sub-directory name.
        batch_size: Batch size for both data loaders.
        learning_rate: Learning rate for both Adam optimizers.
        epochs: Number of training epochs.
    """
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    gen = TetrisModel().to(device)
    disc = TetrisDiscriminator().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_028")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    # Close the writer even if training is interrupted or raises, so that
    # already-recorded events are flushed to disk.
    try:
        train_examples = find_interesting_examples(train_dataset)
        test_examples = find_interesting_examples(test_dataset)

        for epoch in range(epochs):
            print(f"Epoch {epoch}\n-------------------------------")
            train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
            test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
            test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)
            _log_parameter_stats(
                tb_writer, gen, epoch,
                "Weights", "Gradients", "Zero gradients",
            )
            _log_parameter_stats(
                tb_writer, disc, epoch,
                "Discriminator weights", "Discriminator gradients", "Discriminator zero gradients",
            )
    finally:
        tb_writer.close()

    print("Done!")

In [17]:
# Train three separate runs at batch size 32. Each train() call builds new
# models and writes to its own timestamped TensorBoard log directory.
for i in range(3):
    for batch_size in [32]:
        train(run_name=f"batch_size_{batch_size}", batch_size=batch_size)

Epoch 0
-------------------------------
[32/1778] D loss: 1.4311, G loss: 0.7888
[992/1778] D loss: 1.3322, G loss: 0.8005
train error: 
 D loss: 1.186653, G loss: 0.872012, D accuracy: 93.0%, cell accuracy: 71.5%, board accuracy: 0.0% 



  probs = self.predicted_spawn_type_counts / num_predicted_spawns


test error: 
 D loss: 1.197825, G loss: 0.864579, D accuracy: 90.9%, cell accuracy: 71.9%, board accuracy: 0.0% 

Epoch 1
-------------------------------
[32/1778] D loss: 1.1906, G loss: 0.8823
[992/1778] D loss: 0.9429, G loss: 1.1019
train error: 
 D loss: 0.751932, G loss: 1.326031, D accuracy: 98.5%, cell accuracy: 73.8%, board accuracy: 0.0% 

test error: 
 D loss: 0.769185, G loss: 1.296466, D accuracy: 98.2%, cell accuracy: 73.3%, board accuracy: 0.0% 

Epoch 2
-------------------------------
[32/1778] D loss: 0.7490, G loss: 1.3697
[992/1778] D loss: 0.5516, G loss: 1.7730
train error: 
 D loss: 0.402367, G loss: 2.172505, D accuracy: 100.0%, cell accuracy: 73.2%, board accuracy: 0.0% 

test error: 
 D loss: 0.402484, G loss: 2.159161, D accuracy: 100.0%, cell accuracy: 73.3%, board accuracy: 0.0% 

Epoch 3
-------------------------------
[32/1778] D loss: 0.4274, G loss: 2.0371
[992/1778] D loss: 0.3150, G loss: 2.5256
train error: 
 D loss: 0.246664, G loss: 2.798213, D accu

  return self.num_true_positives / self.num_spawns_pred
  return self.num_valid_spawns_pred / self.num_spawns_pred


test error: 
 D loss: 0.123670, G loss: 3.464578, D accuracy: 99.7%, cell accuracy: 96.7%, board accuracy: 0.0% 

Epoch 20
-------------------------------
[32/1778] D loss: 0.1045, G loss: 3.3302
[992/1778] D loss: 0.1653, G loss: 4.2534
train error: 
 D loss: 0.160059, G loss: 2.970484, D accuracy: 98.9%, cell accuracy: 96.7%, board accuracy: 0.0% 

test error: 
 D loss: 0.174375, G loss: 2.974112, D accuracy: 98.8%, cell accuracy: 96.7%, board accuracy: 0.0% 

Epoch 21
-------------------------------
[32/1778] D loss: 0.1264, G loss: 3.3297
[992/1778] D loss: 0.1587, G loss: 4.4272
train error: 
 D loss: 0.169251, G loss: 2.945782, D accuracy: 98.9%, cell accuracy: 96.7%, board accuracy: 0.0% 

test error: 
 D loss: 0.184692, G loss: 2.943646, D accuracy: 98.1%, cell accuracy: 96.6%, board accuracy: 0.0% 

Epoch 22
-------------------------------
[32/1778] D loss: 0.0820, G loss: 3.2870
[992/1778] D loss: 0.1541, G loss: 3.0721
train error: 
 D loss: 0.197026, G loss: 2.641068, D acc

Oddly, training with a batch size of 32 as opposed to 4 significantly decreases the performance of the model, even on the training data. For example, the board accuracy increases more slowly and has more frequent and larger down-spikes. Perhaps the slower increase is because the model parameters get updated less often.

Let's try a less aggressive batch size increase. Let's try once with 8 and once with 16.

In [18]:
# One training run each for the intermediate batch sizes 8 and 16.
for batch_size in [8, 16]:
    train(run_name=f"batch_size_{batch_size}", batch_size=batch_size)

Epoch 0
-------------------------------
[8/1778] D loss: 1.4065, G loss: 0.7115
[248/1778] D loss: 1.3375, G loss: 0.7888
[488/1778] D loss: 1.1737, G loss: 0.8530
[728/1778] D loss: 0.9736, G loss: 0.9897
[968/1778] D loss: 0.7238, G loss: 1.3036
[1208/1778] D loss: 0.5274, G loss: 1.5965
[1448/1778] D loss: 0.3768, G loss: 1.8182
[1688/1778] D loss: 0.3144, G loss: 2.0990
train error: 
 D loss: 0.295427, G loss: 1.970658, D accuracy: 99.8%, cell accuracy: 82.2%, board accuracy: 0.0% 

test error: 
 D loss: 0.289595, G loss: 1.972734, D accuracy: 100.0%, cell accuracy: 82.3%, board accuracy: 0.0% 

Epoch 1
-------------------------------
[8/1778] D loss: 0.2322, G loss: 2.2380
[248/1778] D loss: 0.2031, G loss: 2.3563
[488/1778] D loss: 0.1988, G loss: 2.5311
[728/1778] D loss: 0.1309, G loss: 2.6791
[968/1778] D loss: 0.0922, G loss: 3.1813
[1208/1778] D loss: 0.0766, G loss: 3.2810
[1448/1778] D loss: 0.0708, G loss: 3.1621
[1688/1778] D loss: 0.0761, G loss: 3.6550
train error: 
 D

Even batch sizes 8 and 16 make the board accuracy lower and more unstable. The generator loss curves look about the same, though; only the discriminator loss curves show a difference. Perhaps the performance degradation is therefore due to the lack of any kind of normalization in the discriminator.

The table below shows how batch size affects training time:

| Batch size | Training time (minutes) |
| ---------- | ----------------------- |
| 4 | 59 |
| 8 | 40 |
| 16 | 31 |
| 32 | 26 |

Clearly, increasing the batch size reduces the training time significantly, but I believe it's not worth increasing the batch size due to the impact on model performance.

# Conclusion

Increasing batch size reduces training time but makes board accuracy and plausibility worse and more unstable. We will keep the batch size the same for now.

In future work, we may try using a higher batch size alongside spectral or layer normalization in the discriminator. (We wouldn't use batch normalization, as we previously showed that it significantly worsens performance and training stability.)