# Experiment 026

In this experiment, we'll train with a dataset that is balanced between block spawns and block falls, and see whether this improves the model metrics. We will compute the model metrics on the original, unbalanced dataset too, so we can compare like-for-like with the model trained on the unbalanced dataset.

In [24]:
import os
from pathlib import Path
import datetime

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from models import TetrisModel, TetrisDiscriminator
from recording import RecordingDatabase
import metrics

In [25]:
class RecordingDataset(Dataset):
    def __init__(self, path: str):
        self._db = RecordingDatabase(path)

    def __len__(self):
        return len(self._db)

    def __getitem__(self, idx):
        boards = self._db[idx]
        x = self._transform(boards[-2]) # Ignore all boards except the last two
        y = self._transform(boards[-1])
        return x, y
    
    def _transform(self, board):
        board = torch.tensor(board, dtype=torch.long)
        board = F.one_hot(board, 2) # One-hot encode the cell types
        board = board.type(torch.float) # Convert to floating-point
        board = board.permute((2, 0, 1)) # Move channels/classes to dimension 0
        return board

In [26]:
train_dataset = RecordingDataset(os.path.join("data", "balanced", "train"))
test_dataset = RecordingDataset(os.path.join("data", "balanced", "test"))
batch_size = 4
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

x, y = next(iter(train_dataloader))
print(x.shape, x.dtype)
print(y.shape, y.dtype)

torch.Size([4, 2, 22, 10]) torch.float32
torch.Size([4, 2, 22, 10]) torch.float32


In [27]:
device = (
    "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Using {device} device")

Using cpu device


In [28]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [29]:
gen = TetrisModel().to(device)
disc = TetrisDiscriminator().to(device)

with torch.no_grad():
    X, y = next(iter(train_dataloader))
    y_gen = gen(X)
    pred_on_real = F.sigmoid(disc(X, y)[0])
    pred_on_fake = F.sigmoid(disc(X, y_gen)[0])
    print(f"Number of generator parameters: {count_parameters(gen)}")
    print(f"Number of discriminator parameters: {count_parameters(disc)}")
    print(f"Predicted label for real data: {pred_on_real}")
    print(f"Predicted label for fake data: {pred_on_fake}")

Number of generator parameters: 17996
Number of discriminator parameters: 7057
Predicted label for real data: 0.4241085648536682
Predicted label for fake data: 0.4778788983821869
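
As a sanity check on these initial scores, the binary cross-entropy losses they imply can be worked out by hand. The sketch below assumes the two printed values are the discriminator's sigmoid outputs for a single example, and that the losses are the standard GAN terms used later in `train_loop`. The helper name `bce` is made up for illustration.

```python
import math


def bce(prob, label):
    """Binary cross-entropy for one predicted probability and a 0/1 label."""
    return -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))


# Scores printed above for the first example of the batch.
pred_on_real = 0.4241085648536682
pred_on_fake = 0.4778788983821869

# Discriminator loss: real examples labelled 1, generated examples labelled 0.
d_loss = bce(pred_on_real, 1.0) + bce(pred_on_fake, 0.0)
# Generator loss: generated examples scored against the "real" label.
g_loss = bce(pred_on_fake, 1.0)

# Reference point: a discriminator that always outputs 0.5.
d_loss_chance = 2.0 * math.log(2.0)
g_loss_chance = math.log(2.0)

print(f"D loss: {d_loss:.4f} (chance level {d_loss_chance:.4f})")
print(f"G loss: {g_loss:.4f} (chance level {g_loss_chance:.4f})")
```

Losses close to the chance-level values indicate a discriminator that has not yet learned to separate real from generated boards, which is what we would expect from freshly initialised networks.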


# Training the model

In [30]:
import itertools

def find_interesting_examples(dataset, num=3):
    num_spawns = num
    
    def inner():
        num_spawns_left = num_spawns

        for x, y in dataset:
            # Check for block spawn
            if (x.argmax(0)[0] == 0).all() & (y.argmax(0)[0] == 1).any():
                if num_spawns_left > 0:
                    num_spawns_left -= 1
                    yield x, y
                else:
                    continue
            
    return list(itertools.islice(inner(), num))

In [31]:
def render_prediction(x, pred, y):
    """Renders an example and prediction into a single-image array.
    
    Inputs:
        x: Tensor of shape (height, width), the model input.
        pred: Tensor of shape (height, width), the model prediction.
        y: Tensor of shape (height, width), the target.
    """
    assert len(x.shape) == 2, f"Expected tensors of shape (height, width) but got {x.shape}"
    assert x.shape == pred.shape, f"Shapes do not match: {x.shape} != {pred.shape}"
    assert x.shape == y.shape, f"Shapes do not match: {x.shape} != {y.shape}"
    height, width = x.shape
    with torch.no_grad():
        separator = torch.ones(height, 1, dtype=x.dtype)
        return torch.cat((x, separator, pred, separator, y), dim=-1)

In [32]:
blocks = [
    torch.tensor(
        [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # I

    torch.tensor(
        [[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # O

    torch.tensor(
        [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # J

    torch.tensor(
        [[0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # T

    torch.tensor(
        [[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # S

    torch.tensor(
        [[0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # L

    torch.tensor(
        [[0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int) # Z
]

def get_valid_block_spawns(classes_X, classes_y_fake):
    """Determines whether predicted block spawns have a valid shape.
    
    Inputs:
        classes_X: Tensor of int32 of shape (batch_size, height, width), the first time step (with argmax applied on cell types).
        classes_y_fake: Tensor of int32 of shape (batch_size, height, width), the model's prediction (with argmax applied on cell types).

    Returns: Tensor of bool of shape (batch_size,), whether the items are predicted block spawns AND valid.
    """
    with torch.no_grad():
        batch_size = classes_X.size(0)
        ret = torch.full((batch_size,), False, device=classes_X.device)

        # Take difference to see which cells are full but weren't before.
        diff = classes_y_fake - classes_X

        # It's only a valid block spawn if the change in the first 3 rows matches
        # one of the valid configurations.
        for block in blocks:
            # Move the template to the input's device so the check also works on GPU.
            ret |= (diff[:, :3, :] == block.to(diff.device)).all(-1).all(-1)
        
        return ret
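
To illustrate the template-matching idea on its own, here is a minimal sketch that checks a tiny hand-made batch against a single template. The O-block template is copied from the `blocks` list, while `matches_template` and the toy boards are hypothetical names introduced only for this example.

```python
import torch

# 3x10 template for the O block, copied from the `blocks` list.
O_BLOCK = torch.tensor(
    [[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
     [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
    dtype=torch.int)


def matches_template(classes_x, classes_y, template):
    """Return a bool per batch item: does the change in the top rows equal the template?"""
    diff = classes_y - classes_x
    rows = template.size(0)
    return (diff[:, :rows, :] == template).all(-1).all(-1)


# Two empty 22x10 boards as the "before" frames.
before = torch.zeros(2, 22, 10, dtype=torch.int)
after = before.clone()
after[0, :3, :] = O_BLOCK  # Item 0: a well-formed O spawn.
after[1, 0, 0] = 1         # Item 1: a single stray cell in the top row.

result = matches_template(before, after, O_BLOCK)
print(result)
```

Only the first item matches, because the second item's top-row change does not equal any complete block shape.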


In [33]:
real_label = 1.0
fake_label = 0.0


def train_loop(dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc):
    gen.train()
    disc.train()

    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        ##################################################################
        # (1) Update discriminator: minimize -log(D(x)) - log(1 - D(G(z)))
        ##################################################################
        disc.zero_grad()

        ## Train with all-real batch
        # Format batch
        X, y = X.to(device), y.to(device)
        batch_size = X.size(0)
        real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through discriminator
        output = torch.flatten(disc(X, y))
        # Calculate loss on all-real batch
        err_disc_real = loss_fn(output, real_labels)
        # Calculate gradients for discriminator in backward pass
        err_disc_real.backward()

        ## Train with all-fake batch
        # Generate fake image batch with generator
        y_fake = gen(X)
        fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)
        # Classify all fake batch with discriminator
        output = torch.flatten(disc(X, y_fake.detach()))
        # Calculate discriminator's loss on the all-fake batch
        err_disc_fake = loss_fn(output, fake_labels)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        err_disc_fake.backward()

        ## Update discriminator weights
        # Compute error of discriminator as sum over the fake and the real batches
        err_disc = err_disc_real + err_disc_fake
        # Update discriminator
        optimizer_disc.step()

        ##############################################
        # (2) Update generator: minimize -log(D(G(z)))
        ##############################################
        gen.zero_grad()
        # Since we just updated the discriminator, perform another forward pass of the all-fake batch through it
        output = torch.flatten(disc(X, y_fake))
        # Calculate the generator's loss based on this output
        # We use real labels because the generator wants to fool the discriminator
        err_gen = loss_fn(output, real_labels)
        # Calculate gradients for generator
        err_gen.backward()
        # Update generator
        optimizer_gen.step()

        # Output training stats
        if batch % 20 == 0:
            current = batch * dataloader.batch_size + batch_size
            print(f"[{current}/{size}] D loss: {err_disc.item():.4f}, G loss: {err_gen.item():.4f}")


def test_loop(split_name, dataloader, gen, disc, loss_fn, tb_writer, epoch, examples):
    gen.eval()
    disc.eval()

    loss_disc = 0.0
    loss_gen = 0.0
    disc_accuracy = 0.0
    cell_accuracy = 0.0
    board_accuracy = 0.0
    spawn_recall = 0.0
    num_spawns = 0.0
    spawn_validity = 0.0
    num_predicted_spawns = 0.0
    spawn_precision = 0.0
    scores_real = np.zeros(len(dataloader.dataset))
    scores_fake = np.zeros(len(dataloader.dataset))
    spawn_diversity = metrics.SpawnDiversity()

    num_batches = len(dataloader)
    with torch.no_grad():        
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            
            batch_size = X.size(0)
            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)

            output_real = disc(X, y)
            loss_disc += loss_fn(output_real, real_labels).item()

            y_fake = gen(X)
            output_fake = disc(X, y_fake)
            
            loss_disc += loss_fn(output_fake, fake_labels).item()
            loss_gen += loss_fn(output_fake, real_labels).item()

            pred_real = (output_real > 0.0)
            pred_fake = (output_fake > 0.0)
            disc_accuracy += pred_real.type(torch.float).mean().item()
            disc_accuracy += (~pred_fake).type(torch.float).mean().item()

            classes_X = torch.argmax(X, dim=1)
            classes_y = torch.argmax(y, dim=1)
            classes_y_fake = torch.argmax(y_fake, dim=1)
            cell_accuracy += (classes_y_fake == classes_y).type(torch.float).mean().item()
            board_accuracy += (classes_y_fake == classes_y).all(-1).all(-1).type(torch.float).mean().item()

            actual_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y[:, 0, :] == 1).any(-1)
            predicted_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y_fake[:, 0, :] == 1).any(-1)
            num_true_positives = (actual_spawns & predicted_spawns).type(torch.float).sum().item()
            spawn_recall += num_true_positives
            spawn_precision += num_true_positives
            num_spawns += actual_spawns.type(torch.float).sum().item()
            valid_spawns = get_valid_block_spawns(classes_X, classes_y_fake)
            spawn_validity += valid_spawns.type(torch.float).sum().item()
            num_predicted_spawns += predicted_spawns.type(torch.float).sum().item()

            start_index = dataloader.batch_size * batch
            end_index = start_index + batch_size
            scores_real[start_index:end_index] = torch.sigmoid(output_real).cpu().numpy()
            scores_fake[start_index:end_index] = torch.sigmoid(output_fake).cpu().numpy()

            spawn_diversity.update_state(classes_X, classes_y_fake)

    loss_disc /= num_batches
    loss_gen /= num_batches
    cell_accuracy /= num_batches
    board_accuracy /= num_batches
    spawn_recall = np.nan if (num_spawns == 0.0) else spawn_recall / num_spawns
    spawn_precision = np.nan if (num_predicted_spawns == 0.0) else spawn_precision / num_predicted_spawns
    disc_accuracy /= (2.0 * num_batches)
    spawn_validity = np.nan if (num_predicted_spawns == 0.0) else spawn_validity / num_predicted_spawns

    print(f"{split_name} error: \n D loss: {loss_disc:>8f}, G loss: {loss_gen:>8f}, D accuracy: {(100*disc_accuracy):>0.1f}%, cell accuracy: {(100*cell_accuracy):>0.1f}%, board accuracy: {(100*board_accuracy):>0.1f}% \n")

    tb_writer.add_scalar(f"Discriminator loss/{split_name}", loss_disc, epoch)
    tb_writer.add_scalar(f"Loss/{split_name}", loss_gen, epoch)
    tb_writer.add_scalar(f"Discriminator accuracy/{split_name}", disc_accuracy, epoch)
    tb_writer.add_scalar(f"Cell accuracy/{split_name}", cell_accuracy, epoch)
    tb_writer.add_scalar(f"Board accuracy/{split_name}", board_accuracy, epoch)
    tb_writer.add_scalar(f"Spawn recall/{split_name}", spawn_recall, epoch)
    tb_writer.add_scalar(f"Spawn precision/{split_name}", spawn_precision, epoch)
    tb_writer.add_scalar(f"Spawn validity/{split_name}", spawn_validity, epoch)
    tb_writer.add_scalar(f"Spawn diversity/{split_name}", spawn_diversity.result(), epoch)

    with torch.no_grad():
        for i, (X, y) in enumerate(examples):
            X, y = X.unsqueeze(0), y.unsqueeze(0)
            y_fake = gen(X)
            X, y, y_fake = X.squeeze(0), y.squeeze(0), y_fake.squeeze(0)
            X, y, y_fake = X.argmax(0), y.argmax(0), y_fake.argmax(0)
            img = render_prediction(X, y_fake, y)
            tb_writer.add_image(f"Predictions/{split_name}/{i}", img, epoch, dataformats="HW")
    
    tb_writer.add_histogram(f"Discriminator scores/{split_name}/real", scores_real, epoch)
    tb_writer.add_histogram(f"Discriminator scores/{split_name}/fake", scores_fake, epoch)


In [36]:
def train(run_name="", learning_rate=1e-4, epochs=300):
    gen = TetrisModel().to(device)
    disc = TetrisDiscriminator().to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
    optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

    log_dir = os.path.join("runs", "experiment_026")
    log_subdir = os.path.join(log_dir, run_name + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    tb_writer = SummaryWriter(log_subdir)

    train_examples = find_interesting_examples(train_dataset)
    test_examples = find_interesting_examples(test_dataset)

    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)
        gen_zero_grads = 0
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Weights/{name}", weight, epoch)
            if weight.grad is not None:
                tb_writer.add_histogram(f"Gradients/{name}", weight.grad, epoch)
                gen_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Zero gradients", gen_zero_grads, epoch)
        disc_zero_grads = 0
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
            disc_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar(f"Discriminator zero gradients", disc_zero_grads, epoch)

    tb_writer.close()
    print("Done!")
    return gen

In [37]:
for i in range(3):
    _ = train(run_name="balanced_data", epochs=600)

Epoch 0
-------------------------------
[4/1778] D loss: 1.4073, G loss: 0.7001
[84/1778] D loss: 1.3274, G loss: 0.7195
[164/1778] D loss: 1.3071, G loss: 0.7270
[244/1778] D loss: 1.1872, G loss: 0.7774
[324/1778] D loss: 1.1228, G loss: 0.7483
[404/1778] D loss: 0.9999, G loss: 0.7785
[484/1778] D loss: 0.8691, G loss: 0.8268
[564/1778] D loss: 0.7692, G loss: 0.9266
[644/1778] D loss: 0.6271, G loss: 1.0039
[724/1778] D loss: 0.5768, G loss: 1.1120
[804/1778] D loss: 0.4475, G loss: 1.3683
[884/1778] D loss: 0.3245, G loss: 1.6131
[964/1778] D loss: 0.3822, G loss: 1.7702
[1044/1778] D loss: 0.2689, G loss: 2.1092
[1124/1778] D loss: 0.3474, G loss: 2.3060
[1204/1778] D loss: 0.3615, G loss: 2.3263
[1284/1778] D loss: 0.2793, G loss: 2.6655
[1364/1778] D loss: 0.1953, G loss: 2.8758
[1444/1778] D loss: 0.2449, G loss: 2.7926
[1524/1778] D loss: 0.1649, G loss: 2.9331
[1604/1778] D loss: 0.1103, G loss: 3.0187
[1684/1778] D loss: 0.2174, G loss: 3.1872
[1764/1778] D loss: 0.1825, G 

  probs = self.predicted_spawn_type_counts / num_predicted_spawns


test error: 
 D loss: 0.243578, G loss: 3.128194, D accuracy: 96.1%, cell accuracy: 74.7%, board accuracy: 0.0% 

Epoch 1
-------------------------------
[4/1778] D loss: 0.2936, G loss: 3.2984
[84/1778] D loss: 0.1959, G loss: 3.4427
[164/1778] D loss: 0.2026, G loss: 3.6304
[244/1778] D loss: 0.1349, G loss: 3.5214
[324/1778] D loss: 0.1983, G loss: 3.9002
[404/1778] D loss: 0.1554, G loss: 3.6043
[484/1778] D loss: 0.1597, G loss: 3.7297
[564/1778] D loss: 0.1644, G loss: 3.6465
[644/1778] D loss: 0.2533, G loss: 3.5457
[724/1778] D loss: 0.1200, G loss: 4.1722
[804/1778] D loss: 0.1216, G loss: 4.4184
[884/1778] D loss: 0.1700, G loss: 3.6839
[964/1778] D loss: 0.1335, G loss: 3.9940
[1044/1778] D loss: 0.1591, G loss: 4.1370
[1124/1778] D loss: 0.1802, G loss: 4.2627
[1204/1778] D loss: 0.1393, G loss: 4.1989
[1284/1778] D loss: 0.2622, G loss: 3.9675
[1364/1778] D loss: 0.0978, G loss: 4.4127
[1444/1778] D loss: 0.1805, G loss: 3.9141
[1524/1778] D loss: 0.1392, G loss: 4.3429
[1

Indeed, training with a balanced dataset seems to help significantly. The spawn recall, precision and validity increase faster than with the unbalanced dataset, don't have such huge down-spikes, and don't get "stuck" at or near zero. The spawn recall is quite unstable up to about epoch 50 (jumping between near-zero and near-one), but it stabilises after that.

In terms of board accuracy, the model reaches about 50%, but this is to be expected, because half the dataset consists of block spawns and these are inherently unpredictable because the block type is randomly decided. Similarly, the cell accuracy reaches "only" 99%, whereas with the unbalanced dataset, the model can reach 99.8%.
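
The expected ceiling on board accuracy can be estimated with a little arithmetic. This is a rough sketch under assumptions that are not verified here: all block falls are predicted perfectly, there are 7 equally likely block types, and a predicted spawn matches the true block type only by chance. The function name `board_accuracy_ceiling` is hypothetical.

```python
def board_accuracy_ceiling(spawn_fraction, num_block_types=7):
    """Estimate the best achievable board accuracy for a given share of spawn examples.

    Assumes falls are always predicted correctly and spawns are correct
    only when the guessed block type happens to match (probability 1/num_block_types).
    """
    fall_part = 1.0 - spawn_fraction
    spawn_part = spawn_fraction / num_block_types
    return fall_part + spawn_part


balanced = board_accuracy_ceiling(0.5)     # 50/50 split between falls and spawns
unbalanced = board_accuracy_ceiling(0.05)  # roughly the original 95/5 split

print(f"Balanced ceiling: {balanced:.1%}")
print(f"Unbalanced ceiling: {unbalanced:.1%}")
```

Under these assumptions the ceiling on the balanced dataset sits only a few percentage points above 50%, so a board accuracy of about 50% is close to the best we can hope for.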

Based on the training curves, we should train the model for about 320 epochs on the balanced dataset to maximise our chances of having good metric scores, but sometimes even 150 epochs is enough.

In [38]:
model = train(run_name="balanced_data_return_model", epochs=150)

Epoch 0
-------------------------------
[4/1778] D loss: 1.9051, G loss: 0.2380
[84/1778] D loss: 1.6418, G loss: 0.3569
[164/1778] D loss: 1.5128, G loss: 0.4992
[244/1778] D loss: 1.4304, G loss: 0.5564
[324/1778] D loss: 1.3866, G loss: 0.6107
[404/1778] D loss: 1.2973, G loss: 0.6599
[484/1778] D loss: 1.3022, G loss: 0.6765
[564/1778] D loss: 1.2112, G loss: 0.6224
[644/1778] D loss: 1.1422, G loss: 0.7488
[724/1778] D loss: 1.0963, G loss: 0.7515
[804/1778] D loss: 1.0531, G loss: 0.8167
[884/1778] D loss: 1.0002, G loss: 0.8885
[964/1778] D loss: 0.8869, G loss: 0.8440
[1044/1778] D loss: 0.7758, G loss: 1.0624
[1124/1778] D loss: 0.6281, G loss: 1.2240
[1204/1778] D loss: 0.5825, G loss: 1.1803
[1284/1778] D loss: 0.5424, G loss: 1.2549
[1364/1778] D loss: 0.4781, G loss: 1.3934
[1444/1778] D loss: 0.5452, G loss: 1.6668
[1524/1778] D loss: 0.4533, G loss: 1.9338
[1604/1778] D loss: 0.3384, G loss: 1.8804
[1684/1778] D loss: 0.3744, G loss: 2.2851
[1764/1778] D loss: 0.4004, G 

# Evaluate on the unbalanced dataset

In [48]:
def evaluate_generator(split_name, dataloader, gen):
    gen.eval()

    cell_accuracy = 0.0
    board_accuracy = 0.0
    spawn_recall = 0.0
    num_spawns = 0.0
    spawn_validity = 0.0
    num_predicted_spawns = 0.0
    spawn_precision = 0.0
    spawn_diversity = metrics.SpawnDiversity()

    num_batches = len(dataloader)
    with torch.no_grad():        
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            
            y_fake = gen(X)
            
            classes_X = torch.argmax(X, dim=1)
            classes_y = torch.argmax(y, dim=1)
            classes_y_fake = torch.argmax(y_fake, dim=1)
            cell_accuracy += (classes_y_fake == classes_y).type(torch.float).mean().item()
            board_accuracy += (classes_y_fake == classes_y).all(-1).all(-1).type(torch.float).mean().item()

            actual_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y[:, 0, :] == 1).any(-1)
            predicted_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y_fake[:, 0, :] == 1).any(-1)
            num_true_positives = (actual_spawns & predicted_spawns).type(torch.float).sum().item()
            spawn_recall += num_true_positives
            spawn_precision += num_true_positives
            num_spawns += actual_spawns.type(torch.float).sum().item()
            valid_spawns = get_valid_block_spawns(classes_X, classes_y_fake)
            spawn_validity += valid_spawns.type(torch.float).sum().item()
            num_predicted_spawns += predicted_spawns.type(torch.float).sum().item()

            spawn_diversity.update_state(classes_X, classes_y_fake)

    cell_accuracy /= num_batches
    board_accuracy /= num_batches
    spawn_recall = np.nan if (num_spawns == 0.0) else spawn_recall / num_spawns
    spawn_precision = np.nan if (num_predicted_spawns == 0.0) else spawn_precision / num_predicted_spawns
    spawn_validity = np.nan if (num_predicted_spawns == 0.0) else spawn_validity / num_predicted_spawns

    print(f"Evaluating model on {split_name} dataset...")
    print(f"Cell accuracy/{split_name}: {cell_accuracy:.2%}")
    print(f"Board accuracy/{split_name}: {board_accuracy:.2%}")
    print(f"Spawn recall/{split_name}: {spawn_recall:.2%}")
    print(f"Spawn precision/{split_name}: {spawn_precision:.2%}")
    print(f"Spawn validity/{split_name}: {spawn_validity:.2%}")
    print(f"Spawn diversity/{split_name}: {spawn_diversity.result():.2%}")
    print()

In [49]:
test_dataset_unb = RecordingDataset(os.path.join("data", "tetris_emulator", "test"))
test_dataloader_unb = DataLoader(test_dataset_unb, batch_size=batch_size, shuffle=True)
evaluate_generator("test", test_dataloader, model)
evaluate_generator("test_unbalanced", test_dataloader_unb, model)

Evaluating model on test dataset...
Cell accuracy/test: 99.00%
Board accuracy/test: 39.41%
Spawn recall/test: 99.55%
Spawn precision/test: 88.76%
Spawn validity/test: 87.95%
Spawn diversity/test: 17.62%

Evaluating model on test_unbalanced dataset...
Cell accuracy/test_unbalanced: 99.46%
Board accuracy/test_unbalanced: 62.61%
Spawn recall/test_unbalanced: 100.00%
Spawn precision/test_unbalanced: 48.89%
Spawn validity/test_unbalanced: 91.11%
Spawn diversity/test_unbalanced: 15.26%



The values of cell accuracy, board accuracy, spawn recall and spawn validity are the same or better on the unbalanced dataset, whereas the spawn precision (48.89% vs. 88.76%) and spawn diversity (15.26% vs. 17.62%) are lower.

The board accuracy is only 62.61% on the unbalanced data, which is surprisingly low, but perhaps this is because the model only has 39.41% board accuracy (instead of near 50%) on the balanced data. Perhaps a model trained to near 50% board accuracy on the balanced data would do better on the unbalanced data.

The spawn diversity is low on the unbalanced data, but again this could be because of the low spawn diversity on the balanced data. Let's try training a better model on the balanced data and see how that performs on the unbalanced data.

The low spawn precision on the unbalanced data could be because of the low board accuracy and spawn precision on the balanced data. Across several runs training on the balanced data, the model struggles to achieve spawn precision above 95% and the one time it does, this seems to come at the cost of spawn diversity.
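
Another factor that might contribute, independently of model quality, is the change in base rate: precision depends on how common spawns are in the evaluation data. The sketch below is only an illustration of that effect, using made-up recall and false-positive-rate values rather than measurements from this experiment. It also simplifies matters by treating every non-spawn example as a potential false positive, and the function name `precision_at_prevalence` is hypothetical.

```python
def precision_at_prevalence(recall, false_positive_rate, prevalence):
    """Precision implied by a fixed recall and false positive rate at a given class prevalence."""
    true_positives = recall * prevalence
    false_positives = false_positive_rate * (1.0 - prevalence)
    if true_positives + false_positives == 0.0:
        return float("nan")
    return true_positives / (true_positives + false_positives)


# Illustrative values only, not taken from the runs in this notebook.
recall = 1.0
false_positive_rate = 0.1

for prevalence in (0.5, 0.05):
    precision = precision_at_prevalence(recall, false_positive_rate, prevalence)
    print(f"Spawn share {prevalence:.0%}: precision {precision:.1%}")
```

With the same classifier behaviour, precision falls as spawns become rarer, so some drop in spawn precision on the unbalanced dataset would be expected even if nothing else changed.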

In [50]:
model = train(run_name="balanced_data_return_model", epochs=350)
evaluate_generator("test", test_dataloader, model)
evaluate_generator("test_unbalanced", test_dataloader_unb, model)

Epoch 0
-------------------------------
[4/1778] D loss: 1.3993, G loss: 1.0097
[84/1778] D loss: 1.3436, G loss: 0.9149
[164/1778] D loss: 1.2996, G loss: 0.8446
[244/1778] D loss: 1.1867, G loss: 0.8812
[324/1778] D loss: 1.1147, G loss: 0.8650
[404/1778] D loss: 1.0183, G loss: 0.9228
[484/1778] D loss: 0.8409, G loss: 0.9297
[564/1778] D loss: 0.8274, G loss: 1.1411
[644/1778] D loss: 0.5655, G loss: 1.2855
[724/1778] D loss: 0.5780, G loss: 1.3525
[804/1778] D loss: 0.4039, G loss: 1.5613
[884/1778] D loss: 0.3493, G loss: 1.8835
[964/1778] D loss: 0.2424, G loss: 2.3122
[1044/1778] D loss: 0.3249, G loss: 2.3232
[1124/1778] D loss: 0.1624, G loss: 2.6294
[1204/1778] D loss: 0.2001, G loss: 2.8858
[1284/1778] D loss: 0.2504, G loss: 2.9156
[1364/1778] D loss: 0.2309, G loss: 3.9157
[1444/1778] D loss: 0.1274, G loss: 3.8606
[1524/1778] D loss: 0.1853, G loss: 3.9801
[1604/1778] D loss: 0.1784, G loss: 3.4235
[1684/1778] D loss: 0.1743, G loss: 3.7989
[1764/1778] D loss: 0.1153, G 

  probs = self.predicted_spawn_type_counts / num_predicted_spawns


test error: 
 D loss: 0.174675, G loss: 4.146503, D accuracy: 99.7%, cell accuracy: 73.6%, board accuracy: 0.0% 

Epoch 1
-------------------------------
[4/1778] D loss: 0.2006, G loss: 4.9283
[84/1778] D loss: 0.1376, G loss: 4.5195
[164/1778] D loss: 0.1537, G loss: 4.3445
[244/1778] D loss: 0.1049, G loss: 3.9333
[324/1778] D loss: 0.0986, G loss: 4.9085
[404/1778] D loss: 0.0984, G loss: 4.0549
[484/1778] D loss: 0.1573, G loss: 4.6401
[564/1778] D loss: 0.1259, G loss: 4.9317
[644/1778] D loss: 0.1011, G loss: 5.4547
[724/1778] D loss: 0.0875, G loss: 4.6207
[804/1778] D loss: 0.1333, G loss: 5.4427
[884/1778] D loss: 0.0793, G loss: 4.9253
[964/1778] D loss: 0.0903, G loss: 6.0168
[1044/1778] D loss: 0.1116, G loss: 4.5600
[1124/1778] D loss: 0.0655, G loss: 5.4243
[1204/1778] D loss: 0.0971, G loss: 5.2014
[1284/1778] D loss: 0.1163, G loss: 5.7276
[1364/1778] D loss: 0.2083, G loss: 4.0071
[1444/1778] D loss: 0.1259, G loss: 5.2025
[1524/1778] D loss: 0.1851, G loss: 6.0705
[1

With better board accuracy on the balanced dataset, we also get better board accuracy on the unbalanced dataset. The same is true of spawn precision. I expect the same is true of spawn diversity too, although the numbers here are inconclusive because we again got an unlucky run in terms of spawn diversity. Higher spawn diversity on the balanced dataset should yield higher spawn diversity on the unbalanced dataset, because spawn diversity is only evaluated on block spawns, and there is nothing inherently different between the block spawns in one dataset and those in the other.

# Evaluate on the game

The reduced board accuracy on the unbalanced dataset, as compared with a model trained on the unbalanced dataset, might affect the model's performance as a game emulator. Let's save a version and test it out on the game.

In [52]:
torch.save(model.state_dict(), "tetris_emulator_balanced.pth")

As predicted, the model does worse in a real game scenario. On block fall frames, the block usually falls correctly, but the model usually makes mistakes as the block lands. Extra cells get filled below the block, making it land earlier. Some cells randomly appear near the bottom of the board. Also, all block spawns are of type O, even though the model appears to have nontrivial (i.e. greater than 1/7 = 14.29%) spawn diversity. This suggests that the model has overfit block spawn type to the training data and cannot randomise block spawns when taking its own output as input.

This raises another interesting point. Perhaps the reason for all the botched block landings by this model is the method of recording test data. When we capture and save recordings, every frame has a defined position within the recording. For example, the frame when the first block spawns is always the second frame of a recording, never the first. Since most of the block shapes have a height of 2 and the recording length is 2 (1 input frame, 1 output frame), this may bias the dataset towards odd y-values in the grid behaving differently from even y-values. In a later experiment, we could record a dataset with potentially overlapping recordings to try to remove any such bias, and see if this makes the model easier to train.
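
The suspected parity bias can be illustrated with frame indices alone. This is a toy sketch, assuming a block that drops one row per frame so that the frame index stands in for the block's y-position. The function `make_pairs` and its `stride` parameter are hypothetical and do not reflect the actual recording code.

```python
def make_pairs(num_frames, stride):
    """Return (input_index, output_index) pairs of consecutive frames.

    stride=2 mimics back-to-back recordings of length 2;
    stride=1 mimics overlapping recordings.
    """
    return [(i, i + 1) for i in range(0, num_frames - 1, stride)]


non_overlapping = make_pairs(8, stride=2)
overlapping = make_pairs(8, stride=1)

# Which parities (even = 0, odd = 1) ever appear as the input frame?
parities_non_overlapping = {i % 2 for i, _ in non_overlapping}
parities_overlapping = {i % 2 for i, _ in overlapping}

print(non_overlapping, parities_non_overlapping)
print(overlapping, parities_overlapping)
```

With non-overlapping recordings, the input frame always has the same parity, whereas overlapping recordings cover both parities and also yield nearly twice as many examples from the same frames.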

# Conclusion

Training on a dataset that is balanced between block falls and block spawns improves performance - and stability of performance - on the block spawns. However, the board accuracy on the unbalanced dataset is lower than for a model that was trained on the unbalanced data, and this manifests itself as poorer in-game performance on block falls, particularly as the block is landing.

We noted that the poor block landing performance in the model trained on the balanced dataset may be due to a different type of bias in the recording process - that of preferring odd block y-coordinates in the first (input) frame and even y-coordinates in the second (output) frame. We can fix such bias, if it exists, by allowing recordings to overlap.

The model trained on the balanced dataset still displays instability in spawn diversity, and even with nontrivial spawn diversity, may not produce diverse block spawns when used as an emulator since it repeatedly consumes its own output. More engineering is needed to achieve true spawn diversity.

Given that balancing the dataset does improve and stabilise the block-spawn-related metrics, it makes sense to use more block spawn examples in training than in the original, unbalanced dataset, where the split between falls and spawns was about 95/5. A 50/50 split degrades performance in production, so we should choose a split that is somewhere between 95/5 and 50/50. Perhaps 80/20 would work.
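
One possible way to try an intermediate split without re-recording data is to resample the existing examples with per-example weights. The following is a minimal sketch under assumptions: each example already carries a spawn/fall flag, and `mixing_weights` is a hypothetical helper, not part of this project's code. It uses only the standard library; the same weights could in principle be passed to `torch.utils.data.WeightedRandomSampler` when building the training `DataLoader`.

```python
import random


def mixing_weights(is_spawn, spawn_fraction):
    """Per-example sampling weights so that spawns make up `spawn_fraction` of draws."""
    num_spawns = sum(is_spawn)
    num_falls = len(is_spawn) - num_spawns
    if num_spawns == 0 or num_falls == 0:
        raise ValueError("need at least one spawn and one fall example")
    spawn_weight = spawn_fraction / num_spawns
    fall_weight = (1.0 - spawn_fraction) / num_falls
    return [spawn_weight if s else fall_weight for s in is_spawn]


# Toy dataset with roughly the original 95/5 split between falls and spawns.
is_spawn = [False] * 950 + [True] * 50
weights = mixing_weights(is_spawn, spawn_fraction=0.2)

# Draw indices with replacement and measure the share of spawns actually sampled.
rng = random.Random(0)
drawn = rng.choices(range(len(is_spawn)), weights=weights, k=10_000)
observed = sum(is_spawn[i] for i in drawn) / len(drawn)

print(f"Observed spawn share: {observed:.1%}")
```

Sampling with replacement means the rarer spawn examples are repeated more often, so this approach trades off convenience against the risk of overfitting to a small pool of spawns.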