# Experiment 027

In this experiment, we'll train with a dataset that has 80% block falls and 20% block spawns, and see whether we get the combined benefits of the original, unbalanced dataset and a fully 50-50 balanced dataset. We will compute the model metrics on the original, unbalanced dataset too, so we can compare like-for-like with the model trained on the unbalanced dataset.

In [6]:
import os
from pathlib import Path
import datetime

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from models import TetrisModel, TetrisDiscriminator
from recording import RecordingDatabase
import metrics

In [7]:
class RecordingDataset(Dataset):
    """Dataset of (previous board, next board) pairs backed by a RecordingDatabase."""

    def __init__(self, path: str):
        self._db = RecordingDatabase(path)

    def __len__(self):
        return len(self._db)

    def __getitem__(self, idx):
        # Only the final two boards of a recording are used: the second-to-last
        # one is the model input and the last one is the target.
        recording = self._db[idx]
        previous_board, next_board = recording[-2], recording[-1]
        return self._transform(previous_board), self._transform(next_board)

    def _transform(self, board):
        """One-hot encodes a 2-D board of cell types as a float tensor with the classes in dimension 0."""
        cell_types = torch.tensor(board, dtype=torch.long)
        one_hot = F.one_hot(cell_types, 2).type(torch.float)
        # one_hot puts the class axis last; move it to the front.
        return one_hot.permute((2, 0, 1))

In [12]:
# Load the 80/20 train and test splits and wrap each in a shuffling loader.
batch_size = 4
train_dataset = RecordingDataset(os.path.join("data", "80_20", "train"))
test_dataset = RecordingDataset(os.path.join("data", "80_20", "test"))
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=True, batch_size=batch_size)

# Peek at one training batch to check the tensor shapes and dtypes.
x, y = next(iter(train_dataloader))
print(x.shape, x.dtype)
print(y.shape, y.dtype)

torch.Size([4, 2, 22, 10]) torch.float32
torch.Size([4, 2, 22, 10]) torch.float32


In [13]:
# Prefer the GPU when PyTorch reports one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using {device} device")

Using cpu device


In [14]:
def count_parameters(model):
    """Returns the total number of elements across the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

In [15]:
# Instantiate fresh, untrained models and push one batch through them as a
# sanity check before training.
gen = TetrisModel().to(device)
disc = TetrisDiscriminator().to(device)

with torch.no_grad():
    X, y = next(iter(train_dataloader))
    # The models were moved to `device`, so the batch has to follow them;
    # otherwise the forward pass fails whenever a GPU is selected.
    X, y = X.to(device), y.to(device)
    y_gen = gen(X)
    # Sigmoid of the discriminator output for the first item of the batch.
    pred_on_real = F.sigmoid(disc(X, y)[0])
    pred_on_fake = F.sigmoid(disc(X, y_gen)[0])
    print(f"Number of generator parameters: {count_parameters(gen)}")
    print(f"Number of discriminator parameters: {count_parameters(disc)}")
    print(f"Predicted label for real data: {pred_on_real}")
    print(f"Predicted label for fake data: {pred_on_fake}")

Number of generator parameters: 17996
Number of discriminator parameters: 7057
Predicted label for real data: 0.7015882134437561
Predicted label for fake data: 0.6346157789230347


# Training the model

In [16]:
import itertools

def find_interesting_examples(dataset, num=3):
    """Collects up to `num` block-spawn examples from `dataset`.

    An (x, y) pair counts as a block spawn when the top row of x is entirely
    class 0 and the top row of y contains at least one class-1 cell (classes
    are taken as the argmax over dimension 0).

    Inputs:
        dataset: Iterable of (x, y) tensor pairs with the classes in dimension 0.
        num: Maximum number of examples to return.

    Returns: List of at most `num` (x, y) pairs, in dataset order.
    """

    def spawn_examples():
        for x, y in dataset:
            top_row_before = x.argmax(0)[0]
            top_row_after = y.argmax(0)[0]
            if (top_row_before == 0).all() and (top_row_after == 1).any():
                yield x, y

    # islice stops pulling from the generator once `num` examples were found,
    # so the dataset is not scanned further than necessary and no separate
    # counter is needed.
    return list(itertools.islice(spawn_examples(), num))

In [17]:
def render_prediction(x, pred, y):
    """Renders an example and prediction into a single-image array.

    The three boards are laid out side by side as input | prediction | target,
    with a one-cell-wide column of ones between neighbours.

    Inputs:
        x: Tensor of shape (height, width), the model input.
        pred: Tensor of shape (height, width), the model prediction.
        y: Tensor of shape (height, width), the target.

    Returns: Tensor of shape (height, 3 * width + 2).
    """
    assert len(x.shape) == 2, f"Expected tensors of shape (height, width) but got {x.shape}"
    assert x.shape == pred.shape, f"Shapes do not match: {x.shape} != {pred.shape}"
    assert x.shape == y.shape, f"Shapes do not match: {x.shape} != {y.shape}"
    height, _ = x.shape
    with torch.no_grad():
        # Create the separator on the same device as the boards so the
        # concatenation also works for tensors that are not on the CPU.
        separator = torch.ones(height, 1, dtype=x.dtype, device=x.device)
        return torch.cat((x, separator, pred, separator, y), dim=-1)

In [18]:
# Templates for the seven valid spawn shapes, each covering the top three rows
# of a ten-column board (1 = cell newly filled by the spawned block).
blocks = [
    torch.tensor(
        [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # I

    torch.tensor(
        [[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # O

    torch.tensor(
        [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # J

    torch.tensor(
        [[0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # T

    torch.tensor(
        [[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # S

    torch.tensor(
        [[0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # L

    torch.tensor(
        [[0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int) # Z
]

def get_valid_block_spawns(classes_X, classes_y_fake):
    """Determines whether predicted block spawns have a valid shape.

    Inputs:
        classes_X: Integer tensor of shape (batch_size, height, width), the first time step (with argmax applied on cell types).
        classes_y_fake: Integer tensor of shape (batch_size, height, width), the model's prediction (with argmax applied on cell types).

    Returns: Tensor of bool of shape (batch_size,), whether the items are predicted block spawns AND valid.
        The result is placed on the same device as classes_X.
    """
    with torch.no_grad():
        batch_size = classes_X.size(0)
        # Allocate the result on the inputs' device so the in-place OR below
        # also works when the boards are not on the CPU.
        ret = torch.zeros(batch_size, dtype=torch.bool, device=classes_X.device)

        # Take difference to see which cells are full but weren't before.
        diff = classes_y_fake - classes_X
        top_rows = diff[:, :3, :]

        # It's only a valid block spawn if the change in the first 3 rows matches
        # one of the valid configurations.
        for block in blocks:
            # The templates are created on the CPU; move each one next to the data.
            ret |= (top_rows == block.to(top_rows.device)).all(-1).all(-1)

        return ret


In [22]:
# Target values the discriminator is trained to output for real and generated boards.
real_label = 1.0
fake_label = 0.0


def train_loop(dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc):
    """Runs one epoch of adversarial training over `dataloader`.

    For every batch the discriminator is updated first (on the real pair and on
    the generator's prediction), then the generator is updated to make the
    discriminator label its prediction as real.

    Inputs:
        dataloader: DataLoader yielding (X, y) batches.
        gen: The generator; called as gen(X).
        disc: The discriminator; called as disc(X, y), output is treated as logits.
        loss_fn: Loss taking (logits, labels).
        optimizer_gen: Optimizer over the generator's parameters.
        optimizer_disc: Optimizer over the discriminator's parameters.
    """
    gen.train()
    disc.train()

    size = len(dataloader.dataset)
    # Iterate the loader that was passed in rather than the global training
    # loader, so the function trains on whatever data the caller provides.
    for batch, (X, y) in enumerate(dataloader):
        ##################################################################
        # (1) Update discriminator: minimize -log(D(x)) - log(1 - D(G(z)))
        ##################################################################
        disc.zero_grad()

        ## Train with all-real batch
        # Format batch
        X, y = X.to(device), y.to(device)
        batch_size = X.size(0)
        real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through discriminator
        output = torch.flatten(disc(X, y))
        # Calculate loss on all-real batch
        err_disc_real = loss_fn(output, real_labels)
        # Calculate gradients for discriminator in backward pass
        err_disc_real.backward()

        ## Train with all-fake batch
        # Generate fake image batch with generator
        y_fake = gen(X)
        fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)
        # Classify all fake batch with discriminator; detach so that no
        # gradients flow back into the generator during the discriminator update
        output = torch.flatten(disc(X, y_fake.detach()))
        # Calculate discriminator's loss on the all-fake batch
        err_disc_fake = loss_fn(output, fake_labels)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        err_disc_fake.backward()

        ## Update discriminator weights
        # Compute error of discriminator as sum over the fake and the real batches
        # (only used for the progress printout below)
        err_disc = err_disc_real + err_disc_fake
        # Update discriminator
        optimizer_disc.step()

        ##############################################
        # (2) Update generator: minimize -log(D(G(z)))
        ##############################################
        gen.zero_grad()
        # Since we just updated the discriminator, perform another forward pass of the all-fake batch through it
        output = torch.flatten(disc(X, y_fake))
        # Calculate the generator's loss based on this output
        # We use real labels because the generator wants to fool the discriminator
        err_gen = loss_fn(output, real_labels)
        # Calculate gradients for generator
        err_gen.backward()
        # Update generator
        optimizer_gen.step()

        # Output training stats
        if batch % 20 == 0:
            current = batch * dataloader.batch_size + batch_size
            print(f"[{current}/{size}] D loss: {err_disc.item():.4f}, G loss: {err_gen.item():.4f}")


def test_loop(split_name, dataloader, gen, disc, loss_fn, tb_writer, epoch, examples):
    """Evaluates the generator and discriminator on one split and logs the results.

    Prints a one-line summary and writes scalars, example prediction images and
    discriminator score histograms to TensorBoard.

    Inputs:
        split_name: Name of the split, used in the printed summary and in the TensorBoard tags.
        dataloader: DataLoader yielding (X, y) batches with the cell-type classes in dimension 1.
        gen: The generator; called as gen(X).
        disc: The discriminator; called as disc(X, y), output is treated as logits.
        loss_fn: Loss taking (logits, labels).
        tb_writer: SummaryWriter receiving the scalars, images and histograms.
        epoch: Global step to log against.
        examples: Iterable of unbatched (X, y) pairs rendered as prediction images.
    """
    gen.eval()
    disc.eval()

    loss_disc = 0.0
    loss_gen = 0.0
    disc_accuracy = 0.0
    cell_accuracy = 0.0
    board_accuracy = metrics.BoardAccuracy()
    board_plausibility = metrics.BoardPlausibility()
    spawn_recall = 0.0
    num_spawns = 0.0
    spawn_validity = 0.0
    num_predicted_spawns = 0.0
    spawn_precision = 0.0
    scores_real = np.zeros(len(dataloader.dataset))
    scores_fake = np.zeros(len(dataloader.dataset))
    spawn_diversity = metrics.SpawnDiversity()

    num_batches = len(dataloader)
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            batch_size = X.size(0)
            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)

            # Flatten the discriminator output to (batch_size,), as train_loop does,
            # so it lines up with the label tensors and the score buffers.
            output_real = torch.flatten(disc(X, y))
            loss_disc += loss_fn(output_real, real_labels).item()

            y_fake = gen(X)
            output_fake = torch.flatten(disc(X, y_fake))

            loss_disc += loss_fn(output_fake, fake_labels).item()
            loss_gen += loss_fn(output_fake, real_labels).item()

            # A logit above zero corresponds to a sigmoid score above 0.5, i.e. "real".
            pred_real = (output_real > 0.0)
            pred_fake = (output_fake > 0.0)
            disc_accuracy += pred_real.type(torch.float).mean().item()
            disc_accuracy += (~pred_fake).type(torch.float).mean().item()

            classes_X = torch.argmax(X, dim=1)
            classes_y = torch.argmax(y, dim=1)
            classes_y_fake = torch.argmax(y_fake, dim=1)
            cell_accuracy += (classes_y_fake == classes_y).type(torch.float).mean().item()
            board_accuracy.update_state(classes_y_fake, classes_y)
            board_plausibility.update_state(classes_X, classes_y_fake, classes_y)

            # A spawn is an empty top row in X followed by a non-empty top row
            # in the actual / predicted next board.
            actual_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y[:, 0, :] == 1).any(-1)
            predicted_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y_fake[:, 0, :] == 1).any(-1)
            num_true_positives = (actual_spawns & predicted_spawns).type(torch.float).sum().item()
            spawn_recall += num_true_positives
            spawn_precision += num_true_positives
            num_spawns += actual_spawns.type(torch.float).sum().item()
            valid_spawns = get_valid_block_spawns(classes_X, classes_y_fake)
            spawn_validity += valid_spawns.type(torch.float).sum().item()
            num_predicted_spawns += predicted_spawns.type(torch.float).sum().item()

            start_index = dataloader.batch_size * batch
            end_index = start_index + batch_size
            # Tensors must be on the CPU before they can be converted to NumPy.
            scores_real[start_index:end_index] = torch.sigmoid(output_real).cpu().numpy()
            scores_fake[start_index:end_index] = torch.sigmoid(output_fake).cpu().numpy()

            spawn_diversity.update_state(classes_X, classes_y_fake)

    loss_disc /= num_batches
    loss_gen /= num_batches
    cell_accuracy /= num_batches
    # Report NaN instead of dividing by zero when the split has no spawns at
    # all, matching how precision and validity are handled below.
    spawn_recall = np.nan if (num_spawns == 0.0) else spawn_recall / num_spawns
    spawn_precision = np.nan if (num_predicted_spawns == 0.0) else spawn_precision / num_predicted_spawns
    disc_accuracy /= (2.0 * num_batches)
    spawn_validity = np.nan if (num_predicted_spawns == 0.0) else spawn_validity / num_predicted_spawns

    # The "%" presentation type already appends a percent sign, so board
    # accuracy gets no literal "%" after it.
    print(f"{split_name} error: \n D loss: {loss_disc:>8f}, G loss: {loss_gen:>8f}, D accuracy: {(100*disc_accuracy):>0.1f}%, cell accuracy: {(100*cell_accuracy):>0.1f}%, board accuracy: {(board_accuracy.result()):>0.1%} \n")

    tb_writer.add_scalar(f"Discriminator loss/{split_name}", loss_disc, epoch)
    tb_writer.add_scalar(f"Loss/{split_name}", loss_gen, epoch)
    tb_writer.add_scalar(f"Discriminator accuracy/{split_name}", disc_accuracy, epoch)
    tb_writer.add_scalar(f"Cell accuracy/{split_name}", cell_accuracy, epoch)
    tb_writer.add_scalar(f"Board accuracy/{split_name}", board_accuracy.result(), epoch)
    tb_writer.add_scalar(f"Board plausibility/{split_name}", board_plausibility.result(), epoch)
    tb_writer.add_scalar(f"Spawn recall/{split_name}", spawn_recall, epoch)
    tb_writer.add_scalar(f"Spawn precision/{split_name}", spawn_precision, epoch)
    tb_writer.add_scalar(f"Spawn validity/{split_name}", spawn_validity, epoch)
    tb_writer.add_scalar(f"Spawn diversity/{split_name}", spawn_diversity.result(), epoch)

    with torch.no_grad():
        for i, (X, y) in enumerate(examples):
            # The examples come straight from the dataset, so move the input to
            # the generator's device before the forward pass.
            X, y = X.unsqueeze(0).to(device), y.unsqueeze(0)
            y_fake = gen(X)
            X, y, y_fake = X.squeeze(0), y.squeeze(0), y_fake.squeeze(0)
            # Bring everything back to the CPU so the rendered image is built
            # from tensors that all live on one device.
            X, y, y_fake = X.argmax(0).cpu(), y.argmax(0).cpu(), y_fake.argmax(0).cpu()
            img = render_prediction(X, y_fake, y)
            tb_writer.add_image(f"Predictions/{split_name}/{i}", img, epoch, dataformats="HW")

    tb_writer.add_histogram(f"Discriminator scores/{split_name}/real", scores_real, epoch)
    tb_writer.add_histogram(f"Discriminator scores/{split_name}/fake", scores_fake, epoch)


In [20]:
def _log_parameter_histograms(tb_writer, model, weights_tag, grads_tag, epoch):
    """Logs weight and gradient histograms for `model`; returns the number of gradient entries that are exactly zero."""
    zero_grads = 0
    for name, weight in model.named_parameters():
        tb_writer.add_histogram(f"{weights_tag}/{name}", weight, epoch)
        # Parameters that have not received a gradient yet have grad=None.
        if weight.grad is not None:
            tb_writer.add_histogram(f"{grads_tag}/{name}", weight.grad, epoch)
            zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
    return zero_grads


def train(run_name="", learning_rate=1e-4, epochs=300):
    """Trains a fresh generator/discriminator pair and logs progress to TensorBoard.

    Each epoch runs train_loop on the training loader, evaluates on both the
    training and test loaders with test_loop, and logs weight and gradient
    histograms for both models.

    Inputs:
        run_name: Prefix of the TensorBoard run directory under runs/experiment_027.
        learning_rate: Learning rate used by both Adam optimizers.
        epochs: Number of epochs to train for.

    Returns: The trained generator.
    """
    gen = TetrisModel().to(device)
    disc = TetrisDiscriminator().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_027")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    # Close the writer even if training is interrupted or raises, so that the
    # events logged so far are not left unflushed.
    try:
        train_examples = find_interesting_examples(train_dataset)
        test_examples = find_interesting_examples(test_dataset)

        for epoch in range(epochs):
            print(f"Epoch {epoch}\n-------------------------------")
            train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
            test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
            test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)

            # The gradients logged here are whatever is left on the parameters
            # after the last training batch of the epoch.
            gen_zero_grads = _log_parameter_histograms(tb_writer, gen, "Weights", "Gradients", epoch)
            tb_writer.add_scalar("Zero gradients", gen_zero_grads, epoch)
            disc_zero_grads = _log_parameter_histograms(
                tb_writer, disc, "Discriminator weights", "Discriminator gradients", epoch)
            tb_writer.add_scalar("Discriminator zero gradients", disc_zero_grads, epoch)
    finally:
        tb_writer.close()

    print("Done!")
    return gen

In [21]:
# Train once for 300 epochs on the 80-20 dataset; the returned generator is discarded.
_ = train(run_name="80_20_data", epochs=300)

Epoch 0
-------------------------------
[4/1778] D loss: 1.3751, G loss: 0.5417
[84/1778] D loss: 1.2914, G loss: 0.5953
[164/1778] D loss: 1.2658, G loss: 0.6508
[244/1778] D loss: 1.1499, G loss: 0.7308
[324/1778] D loss: 1.0671, G loss: 0.7471
[404/1778] D loss: 1.0235, G loss: 0.7886
[484/1778] D loss: 0.9535, G loss: 0.8059
[564/1778] D loss: 0.8469, G loss: 0.8776
[644/1778] D loss: 0.7938, G loss: 0.8950
[724/1778] D loss: 0.7591, G loss: 0.9126
[804/1778] D loss: 0.7721, G loss: 0.9194
[884/1778] D loss: 0.6351, G loss: 1.1920
[964/1778] D loss: 0.5774, G loss: 1.2264
[1044/1778] D loss: 0.5752, G loss: 1.4372
[1124/1778] D loss: 0.5654, G loss: 1.5138
[1204/1778] D loss: 0.4892, G loss: 1.6834
[1284/1778] D loss: 0.3893, G loss: 1.9320
[1364/1778] D loss: 0.4109, G loss: 1.7919
[1444/1778] D loss: 0.4686, G loss: 2.0739
[1524/1778] D loss: 0.5547, G loss: 2.0602
[1604/1778] D loss: 0.4005, G loss: 2.4006
[1684/1778] D loss: 0.3500, G loss: 2.2681
[1764/1778] D loss: 0.3395, G 

  probs = self.predicted_spawn_type_counts / num_predicted_spawns


test error: 
 D loss: 0.435457, G loss: 2.324965, D accuracy: 98.3%, cell accuracy: 80.4%, board accuracy: 0.0% 

Epoch 1
-------------------------------
[4/1778] D loss: 0.3120, G loss: 2.4435
[84/1778] D loss: 0.3304, G loss: 2.5153
[164/1778] D loss: 0.3459, G loss: 2.7655
[244/1778] D loss: 0.4835, G loss: 2.7124
[324/1778] D loss: 0.3327, G loss: 2.4725
[404/1778] D loss: 0.3338, G loss: 2.7053
[484/1778] D loss: 0.4009, G loss: 2.8256
[564/1778] D loss: 0.3953, G loss: 3.0328
[644/1778] D loss: 0.3574, G loss: 2.9567
[724/1778] D loss: 0.2809, G loss: 3.4689
[804/1778] D loss: 0.3939, G loss: 2.4758
[884/1778] D loss: 0.5648, G loss: 2.9032
[964/1778] D loss: 0.3198, G loss: 2.7521
[1044/1778] D loss: 0.3536, G loss: 3.5379
[1124/1778] D loss: 0.3523, G loss: 2.7003
[1204/1778] D loss: 0.4263, G loss: 3.3080
[1284/1778] D loss: 0.4888, G loss: 2.5017
[1364/1778] D loss: 0.2566, G loss: 2.8153
[1444/1778] D loss: 0.4458, G loss: 2.3728
[1524/1778] D loss: 0.4492, G loss: 2.4539
[1

In [23]:
# Two more 300-epoch runs; `model` ends up holding the generator from the last run.
for i in range(2):
    model = train(run_name="80_20_data", epochs=300)

Epoch 0
-------------------------------
[4/1778] D loss: 1.4796, G loss: 0.6203
[84/1778] D loss: 1.3435, G loss: 0.7075
[164/1778] D loss: 1.3686, G loss: 0.6595
[244/1778] D loss: 1.3499, G loss: 0.6549
[324/1778] D loss: 1.2096, G loss: 0.7657
[404/1778] D loss: 1.1346, G loss: 0.7182
[484/1778] D loss: 1.0992, G loss: 0.7521
[564/1778] D loss: 1.0841, G loss: 0.7425
[644/1778] D loss: 0.8951, G loss: 0.7791
[724/1778] D loss: 0.8290, G loss: 0.8760
[804/1778] D loss: 0.7708, G loss: 0.8882
[884/1778] D loss: 0.7085, G loss: 0.9062
[964/1778] D loss: 0.6307, G loss: 0.9983
[1044/1778] D loss: 0.6399, G loss: 1.1713
[1124/1778] D loss: 0.4470, G loss: 1.3905
[1204/1778] D loss: 0.4490, G loss: 1.5253
[1284/1778] D loss: 0.3199, G loss: 1.7535
[1364/1778] D loss: 0.4608, G loss: 1.7628
[1444/1778] D loss: 0.3495, G loss: 2.2193
[1524/1778] D loss: 0.4436, G loss: 2.0944
[1604/1778] D loss: 0.5785, G loss: 1.7898
[1684/1778] D loss: 0.3796, G loss: 2.8896
[1764/1778] D loss: 0.3286, G 

  probs = self.predicted_spawn_type_counts / num_predicted_spawns


[84/1778] D loss: 0.4628, G loss: 2.2382
[164/1778] D loss: 0.6996, G loss: 2.0263
[244/1778] D loss: 0.4112, G loss: 2.3226
[324/1778] D loss: 0.5193, G loss: 2.6132
[404/1778] D loss: 0.3249, G loss: 2.5148
[484/1778] D loss: 0.5341, G loss: 2.5438
[564/1778] D loss: 0.4980, G loss: 2.5624
[644/1778] D loss: 0.6377, G loss: 2.6422
[724/1778] D loss: 0.8447, G loss: 2.3219
[804/1778] D loss: 0.5412, G loss: 2.4113
[884/1778] D loss: 0.4974, G loss: 2.9047
[964/1778] D loss: 0.9138, G loss: 2.3756
[1044/1778] D loss: 0.4993, G loss: 2.5889
[1124/1778] D loss: 0.6899, G loss: 1.9691
[1204/1778] D loss: 0.5189, G loss: 2.2481
[1284/1778] D loss: 0.8048, G loss: 2.6200
[1364/1778] D loss: 0.5212, G loss: 2.4021
[1444/1778] D loss: 0.6208, G loss: 2.6600
[1524/1778] D loss: 0.6579, G loss: 3.3330
[1604/1778] D loss: 0.7036, G loss: 2.3488
[1684/1778] D loss: 0.6695, G loss: 2.9535
[1764/1778] D loss: 0.5768, G loss: 2.3641
train error: 
 D loss: 0.690770, G loss: 2.496312, D accuracy: 90.5

As expected, we get results between those we would have got training on the unbalanced or fully-balanced dataset. Unfortunately, we don't get the full benefit in terms of spawn recall, precision, validity and diversity. In particular, these four metrics are lower than with the fully-balanced dataset, spawn precision remains as stable as with the unbalanced dataset, and spawn diversity sometimes gets "stuck" at low values. Not only is board accuracy lower than with the unbalanced dataset (expected), but board plausibility is also lower (moderately surprising).

We note that when we trained with the fully-balanced dataset, the spawn recall and precision became more stable after 300 epochs. It would be interesting to see what happens with the 80-20 dataset when we train for 600 epochs.

But first, let's test the generator we just trained on the unbalanced dataset and on the game!

# Evaluate on the unbalanced dataset

In [24]:
def evaluate_generator(split_name, dataloader, gen):
    """Evaluates the generator alone on `dataloader` and prints its metrics.

    Inputs:
        split_name: Name of the split, used as a suffix in the printed metric names.
        dataloader: DataLoader yielding (X, y) batches with the cell-type classes in dimension 1.
        gen: The generator; called as gen(X).
    """
    gen.eval()

    cell_accuracy = 0.0
    board_accuracy = metrics.BoardAccuracy()
    board_plausibility = metrics.BoardPlausibility()
    spawn_recall = 0.0
    num_spawns = 0.0
    spawn_validity = 0.0
    num_predicted_spawns = 0.0
    spawn_precision = 0.0
    spawn_diversity = metrics.SpawnDiversity()

    num_batches = len(dataloader)
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            y_fake = gen(X)

            classes_X = torch.argmax(X, dim=1)
            classes_y = torch.argmax(y, dim=1)
            classes_y_fake = torch.argmax(y_fake, dim=1)
            cell_accuracy += (classes_y_fake == classes_y).type(torch.float).mean().item()
            board_accuracy.update_state(classes_y_fake, classes_y)
            board_plausibility.update_state(classes_X, classes_y_fake, classes_y)

            # A spawn is an empty top row in X followed by a non-empty top row
            # in the actual / predicted next board.
            actual_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y[:, 0, :] == 1).any(-1)
            predicted_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y_fake[:, 0, :] == 1).any(-1)
            num_true_positives = (actual_spawns & predicted_spawns).type(torch.float).sum().item()
            spawn_recall += num_true_positives
            spawn_precision += num_true_positives
            num_spawns += actual_spawns.type(torch.float).sum().item()
            valid_spawns = get_valid_block_spawns(classes_X, classes_y_fake)
            spawn_validity += valid_spawns.type(torch.float).sum().item()
            num_predicted_spawns += predicted_spawns.type(torch.float).sum().item()

            spawn_diversity.update_state(classes_X, classes_y_fake)

    cell_accuracy /= num_batches
    # Report NaN instead of dividing by zero when the split has no spawns at
    # all, matching how precision and validity are handled below.
    spawn_recall = np.nan if (num_spawns == 0.0) else spawn_recall / num_spawns
    spawn_precision = np.nan if (num_predicted_spawns == 0.0) else spawn_precision / num_predicted_spawns
    spawn_validity = np.nan if (num_predicted_spawns == 0.0) else spawn_validity / num_predicted_spawns

    print(f"Evaluating model on {split_name} dataset...")
    print(f"Cell accuracy/{split_name}: {cell_accuracy:.2%}")
    print(f"Board accuracy/{split_name}: {board_accuracy.result():.2%}")
    print(f"Board plausibility/{split_name}: {board_plausibility.result():.2%}")
    print(f"Spawn recall/{split_name}: {spawn_recall:.2%}")
    print(f"Spawn precision/{split_name}: {spawn_precision:.2%}")
    print(f"Spawn validity/{split_name}: {spawn_validity:.2%}")
    print(f"Spawn diversity/{split_name}: {spawn_diversity.result():.2%}")
    print()

In [25]:
# Evaluate the last trained generator on the 80-20 test split and on the
# test split of the original (unbalanced) recordings.
test_dataset_unb = RecordingDataset(os.path.join("data", "tetris_emulator", "test"))
test_dataloader_unb = DataLoader(dataset=test_dataset_unb, shuffle=True, batch_size=batch_size)
evaluate_generator(split_name="test", dataloader=test_dataloader, gen=model)
evaluate_generator(split_name="test_unbalanced", dataloader=test_dataloader_unb, gen=model)

Evaluating model on test dataset...
Cell accuracy/test: 99.70%
Board accuracy/test: 78.83%
Board plausibility/test: 91.67%
Spawn recall/test: 97.75%
Spawn precision/test: 93.55%
Spawn validity/test: 82.80%
Spawn diversity/test: 49.51%

Evaluating model on test_unbalanced dataset...
Cell accuracy/test_unbalanced: 99.82%
Board accuracy/test_unbalanced: 87.61%
Board plausibility/test_unbalanced: 94.37%
Spawn recall/test_unbalanced: 100.00%
Spawn precision/test_unbalanced: 88.00%
Spawn validity/test_unbalanced: 80.00%
Spawn diversity/test_unbalanced: 49.24%



When evaluated on the unbalanced dataset, our new model does about the same as a model trained on the unbalanced dataset in terms of cell accuracy and board accuracy. In terms of board plausibility, spawn recall, spawn validity and spawn diversity, it does significantly better! In terms of spawn precision, it does slightly worse. Overall, the model trained on the 80-20 dataset looks preferable so far.

Now let's save the model and use it in the game.

# Evaluate on the game

In [26]:
# Persist only the generator's parameters (its state_dict), not the module object itself.
torch.save(model.state_dict(), "tetris_emulator_80_20.pth")

The model actually generates variable block spawns! Aside from that, block falls all look good, except for the frame where the block lands: blocks relatively often change shape as they land. Also, many of the block spawns aren't valid shapes.

Overall the 80-20 model has its flaws, but it is generally quite good and much more interesting to observe than the one trained on the unbalanced dataset, mainly due to the variability of its block spawns.

Note that the 80-20 model has about 50% spawn diversity whereas the models trained on the unbalanced and balanced datasets only have around 15-20% diversity, so the observed increase in spawn diversity during gameplay corresponds to a higher spawn diversity at training/evaluation time.

# Training for longer periods

Now let's train on the dataset for 600 epochs and see if there is any change in the training curves after 300 epochs.

In [27]:
# Single 600-epoch run on the 80-20 dataset; the trained generator is kept in `model`.
model = train(run_name="80_20_data", epochs=600)

Epoch 0
-------------------------------
[4/1778] D loss: 1.3341, G loss: 0.7348
[84/1778] D loss: 1.2908, G loss: 0.6945
[164/1778] D loss: 1.2960, G loss: 0.6574
[244/1778] D loss: 1.1835, G loss: 0.6972
[324/1778] D loss: 1.1556, G loss: 0.7364
[404/1778] D loss: 1.1429, G loss: 0.6944
[484/1778] D loss: 1.0481, G loss: 0.8873
[564/1778] D loss: 0.9651, G loss: 0.8460
[644/1778] D loss: 0.9388, G loss: 0.9234
[724/1778] D loss: 0.8587, G loss: 0.9156
[804/1778] D loss: 0.7241, G loss: 1.2421
[884/1778] D loss: 0.7419, G loss: 1.3821
[964/1778] D loss: 0.6527, G loss: 1.3725
[1044/1778] D loss: 0.5915, G loss: 1.4335
[1124/1778] D loss: 0.5061, G loss: 1.9630
[1204/1778] D loss: 0.4453, G loss: 2.0465
[1284/1778] D loss: 0.5127, G loss: 1.3625
[1364/1778] D loss: 0.3541, G loss: 1.7313
[1444/1778] D loss: 0.5056, G loss: 1.5629
[1524/1778] D loss: 0.4256, G loss: 2.2671
[1604/1778] D loss: 0.3716, G loss: 1.9464
[1684/1778] D loss: 0.4213, G loss: 2.4174
[1764/1778] D loss: 0.3551, G 

  probs = self.predicted_spawn_type_counts / num_predicted_spawns


test error: 
 D loss: 0.378452, G loss: 2.258816, D accuracy: 98.5%, cell accuracy: 94.7%, board accuracy: 0.0%% 

Epoch 1
-------------------------------
[4/1778] D loss: 0.3992, G loss: 2.2504
[84/1778] D loss: 0.4111, G loss: 2.3219
[164/1778] D loss: 0.4125, G loss: 1.9744
[244/1778] D loss: 0.2670, G loss: 2.5588
[324/1778] D loss: 0.3330, G loss: 1.7484
[404/1778] D loss: 0.3035, G loss: 2.3321
[484/1778] D loss: 0.3824, G loss: 1.9459
[564/1778] D loss: 0.3301, G loss: 2.9219
[644/1778] D loss: 0.3667, G loss: 2.2031
[724/1778] D loss: 0.4421, G loss: 2.1937
[804/1778] D loss: 0.3976, G loss: 1.5527
[884/1778] D loss: 0.2844, G loss: 2.1055
[964/1778] D loss: 0.4386, G loss: 2.3712
[1044/1778] D loss: 0.3024, G loss: 2.4774
[1124/1778] D loss: 0.4404, G loss: 2.0335
[1204/1778] D loss: 0.5188, G loss: 1.5200
[1284/1778] D loss: 0.4995, G loss: 3.0329
[1364/1778] D loss: 0.3671, G loss: 2.6842
[1444/1778] D loss: 0.2911, G loss: 2.7538
[1524/1778] D loss: 0.3083, G loss: 2.4330
[

In [28]:
# Three more 600-epoch runs; `model` ends up holding the generator from the last run.
for i in range(3):
    model = train(run_name="80_20_data", epochs=600)

Epoch 0
-------------------------------
[4/1778] D loss: 1.5300, G loss: 0.4428
[84/1778] D loss: 1.2433, G loss: 0.6626
[164/1778] D loss: 1.2477, G loss: 0.6280
[244/1778] D loss: 1.1392, G loss: 0.7104
[324/1778] D loss: 1.0018, G loss: 0.7993
[404/1778] D loss: 0.9727, G loss: 0.8605
[484/1778] D loss: 0.7608, G loss: 1.0452
[564/1778] D loss: 0.6963, G loss: 1.0789
[644/1778] D loss: 0.6679, G loss: 1.3508
[724/1778] D loss: 0.5048, G loss: 1.4111
[804/1778] D loss: 0.4510, G loss: 1.6378
[884/1778] D loss: 0.2717, G loss: 1.9049
[964/1778] D loss: 0.2857, G loss: 2.2487
[1044/1778] D loss: 0.3257, G loss: 2.4022
[1124/1778] D loss: 0.2764, G loss: 2.4695
[1204/1778] D loss: 0.3648, G loss: 2.5209
[1284/1778] D loss: 0.2219, G loss: 2.7880
[1364/1778] D loss: 0.2115, G loss: 2.9779
[1444/1778] D loss: 0.1650, G loss: 2.9481
[1524/1778] D loss: 0.4463, G loss: 3.0031
[1604/1778] D loss: 0.3306, G loss: 3.4374
[1684/1778] D loss: 0.2857, G loss: 3.3626
[1764/1778] D loss: 0.3582, G 

  probs = self.predicted_spawn_type_counts / num_predicted_spawns


test error: 
 D loss: 0.353341, G loss: 3.215525, D accuracy: 97.7%, cell accuracy: 97.0%, board accuracy: 0.0%% 

Epoch 4
-------------------------------
[4/1778] D loss: 0.2650, G loss: 2.0266
[84/1778] D loss: 0.5597, G loss: 2.6263
[164/1778] D loss: 0.4453, G loss: 2.0940
[244/1778] D loss: 0.4299, G loss: 1.7797
[324/1778] D loss: 0.2682, G loss: 3.0789
[404/1778] D loss: 0.4616, G loss: 1.9859
[484/1778] D loss: 0.4772, G loss: 1.6087
[564/1778] D loss: 0.2297, G loss: 2.5531
[644/1778] D loss: 0.4687, G loss: 3.1551
[724/1778] D loss: 0.6063, G loss: 1.3382
[804/1778] D loss: 0.4866, G loss: 1.5756
[884/1778] D loss: 0.8855, G loss: 1.4527
[964/1778] D loss: 0.3314, G loss: 1.7565
[1044/1778] D loss: 0.5168, G loss: 1.4690
[1124/1778] D loss: 0.5678, G loss: 1.6439
[1204/1778] D loss: 0.8559, G loss: 0.9342
[1284/1778] D loss: 0.7105, G loss: 1.1554
[1364/1778] D loss: 0.8356, G loss: 1.7219
[1444/1778] D loss: 0.4433, G loss: 1.9159
[1524/1778] D loss: 0.7179, G loss: 2.7883
[

With the longer runs, the board accuracy and board plausibility are fairly stable, but never reach the level of the curves of the models trained on unbalanced data. However, as shown above, this could still mean that the 80-20 model's board plausibility is higher when evaluated on the unbalanced dataset.

In terms of discriminator loss, we see the loss very gradually going down on the training set but up on the test set between epochs 300 and 600. This suggests some overfitting in the discriminator.

For the discriminator scores, we still see a significant cluster around 1.0 for real data and around 0.0 for fake data, which suggests that the generator needs to improve relative to the discriminator.

For spawn recall and precision, the value improves and is quite stable. For spawn validity, stability worsens, with more frequent large down-spikes. Spawn diversity depends on the run: 2 runs have low diversity (14% and 20%) throughout, 1 run has middling diversity (30%) which peaks around 300 epochs and then degrades, and 1 run has diversity oscillating between 30% and 50% throughout.

# Conclusion

Training with the 80-20 dataset leads to a model that is less precise but more able to diversify its block spawns. Pruning data to this ratio is a promising approach, but has some drawbacks that should be addressed:
1. It requires us to explicitly decide which features of the data to balance on, which may be infeasible in general.
2. It requires us to collect a large amount of data and then discard it, instead of only collecting relevant data.
3. It adds extra hyperparameters: the proportions of each class in the dataset.
4. It means that the test set is less representative of real gameplay, so the test metrics become less useful.

Points (1) and (2) could be addressed by automating the decision of whether to store data and integrating this into the data recording process. However, this could skew the newly-collected data in the direction that the model struggles most with, so the new data may be unsuitable for training a model on that data alone. The new data could be combined with the existing data, but then we get a new hyperparameter: how many examples to add to the dataset before retraining the model.

Point (4) could be addressed by having two test sets: one that comes from the same distribution as the training set, and one that follows the distribution of real game data. The first of these test sets would be used for validation during training, and the second would be used for model evaluation.

Still, training on a more balanced dataset does not get the diversity reliably above 50%, so it cannot solve the spawn diversity problem on its own.