# Training a Tetris emulator

In this notebook, we train a model to emulate Tetris and provide a backend for our model-based game engine.

In [1]:
import os
from pathlib import Path
import shutil
import datetime

import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt


from models import TetrisModel, TetrisDiscriminator
import metrics

In [2]:
class RecordingDataset(Dataset):
    """Dataset of (current board, next board) pairs read from numbered recording files.

    The directory at ``path`` is expected to contain files named ``<index><ext>``
    (e.g. ``0.npy``, ``1.npy``, ...) that ``np.load`` can read. Only the last two
    boards of each file are used: the second-to-last is the input and the last is
    the target.
    """

    def __init__(self, path: str):
        """Scan ``path`` to find the file extension and the highest example index.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            ValueError: if a file name stem is not an integer.
        """
        self.path = path
        if not os.path.exists(path):
            raise FileNotFoundError(f"Recording directory not found: {path}")
        with os.scandir(self.path) as it:
            # Materialise the listing first: the scandir iterator is single-pass, so
            # peeking at the first entry and then iterating would silently skip it.
            names = [entry.name for entry in it]
        # The extension of the first listed file is assumed to apply to every file.
        self.ext = os.path.splitext(names[0])[1] if names else ""
        # An empty directory yields -1, i.e. a dataset of length 0.
        self.highest_index = max((int(Path(name).stem) for name in names), default=-1)

    def __len__(self):
        """Number of examples, assuming indices run contiguously from 0."""
        return self.highest_index + 1

    @staticmethod
    def _to_one_hot(board):
        """Convert a 2-D array of cell types (0 or 1) to a float tensor of shape (2, H, W)."""
        board = torch.tensor(board, dtype=torch.long)
        board = F.one_hot(board, 2)  # One-hot encode the cell types
        board = board.type(torch.float)  # Convert to floating-point
        return board.permute((2, 0, 1))  # Move channels/classes to dimension 0

    def __getitem__(self, idx):
        """Return ``(x, y)`` for example ``idx``.

        Raises:
            IndexError: if no file exists for ``idx``. This also terminates plain
                ``for x, y in dataset`` iteration.
        """
        file = os.path.join(self.path, f"{idx}{self.ext}")
        if not os.path.exists(file):
            raise IndexError()
        boards = np.load(file)

        # Ignore all boards except the last two
        x = self._to_one_hot(boards[-2])
        y = self._to_one_hot(boards[-1])
        return x, y
        

In [3]:
# Both splits live side by side under data/tetris_emulator/<split>.
data_root = os.path.join("data", "tetris_emulator")
train_dataset = RecordingDataset(os.path.join(data_root, "train"))
test_dataset = RecordingDataset(os.path.join(data_root, "test"))

batch_size = 4
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Sanity check: pull one batch and show the tensor layout that the models will see.
x, y = next(iter(train_dataloader))
for sample in (x, y):
    print(sample.shape, sample.dtype)

torch.Size([4, 2, 22, 10]) torch.float32
torch.Size([4, 2, 22, 10]) torch.float32


In [4]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


In [5]:
def count_parameters(model):
    """Return the total number of trainable scalar parameters in ``model``.

    Parameters with ``requires_grad == False`` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [6]:
# Smoke test: build untrained models, push one batch through them and report their sizes.
gen = TetrisModel().to(device)
disc = TetrisDiscriminator().to(device)

with torch.no_grad():
    X, y = next(iter(train_dataloader))
    # The models live on `device`, so the batch has to be moved there as well.
    X, y = X.to(device), y.to(device)
    y_gen = gen(X)
    # Sigmoid turns the discriminator's raw output for the first batch item into a score in (0, 1).
    pred_on_real = torch.sigmoid(disc(X, y)[0])
    pred_on_fake = torch.sigmoid(disc(X, y_gen)[0])
    print(f"Number of generator parameters: {count_parameters(gen)}")
    print(f"Number of discriminator parameters: {count_parameters(disc)}")
    print(f"Predicted label for real data: {pred_on_real}")
    print(f"Predicted label for fake data: {pred_on_fake}")

Number of generator parameters: 17996
Number of discriminator parameters: 7057
Predicted label for real data: 0.44655609130859375
Predicted label for fake data: 0.49669408798217773


In [7]:
import itertools

def find_interesting_examples(dataset, num=3):
    """Collect the first ``num`` block-spawn examples from ``dataset``.

    An example ``(x, y)`` counts as a block spawn when the top row of ``x`` is
    entirely class 0 and the top row of ``y`` contains at least one class-1 cell
    (classes are taken with ``argmax`` over dimension 0).

    Returns a list of at most ``num`` ``(x, y)`` tuples, in dataset order.
    """
    def is_block_spawn(before, after):
        top_row_before = before.argmax(0)[0]
        top_row_after = after.argmax(0)[0]
        return bool((top_row_before == 0).all() & (top_row_after == 1).any())

    # Lazy generator: iteration over the dataset stops as soon as `num` spawns were found.
    spawns = ((x, y) for x, y in dataset if is_block_spawn(x, y))
    return list(itertools.islice(spawns, num))

In [8]:
def render_prediction(x, pred, y):
    """Renders an example and prediction into a single-image array.

    The three boards are laid out side by side as ``x | pred | y``, separated by
    one-cell-wide columns of ones.

    Inputs:
        x: Tensor of shape (height, width), the model input.
        pred: Tensor of shape (height, width), the model prediction.
        y: Tensor of shape (height, width), the target.

    Returns: Tensor of shape (height, 3 * width + 2).
    """
    assert len(x.shape) == 2, f"Expected tensors of shape (height, width) but got {x.shape}"
    assert x.shape == pred.shape, f"Shapes do not match: {x.shape} != {pred.shape}"
    assert x.shape == y.shape, f"Shapes do not match: {x.shape} != {y.shape}"
    height, _ = x.shape
    with torch.no_grad():
        # Create the separator on x's device so concatenation also works off-CPU.
        separator = torch.ones(height, 1, dtype=x.dtype, device=x.device)
        return torch.cat((x, separator, pred, separator, y), dim=-1)

In [9]:
# Spawn templates: the expected change in the top three rows of a 10-wide board
# for each of the seven piece types.
blocks = [
    torch.tensor(
        [[0, 0, 0, 1, 1, 1, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # I

    torch.tensor(
        [[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # O

    torch.tensor(
        [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # J

    torch.tensor(
        [[0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # T

    torch.tensor(
        [[0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # S

    torch.tensor(
        [[0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int), # L

    torch.tensor(
        [[0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
        dtype=torch.int) # Z
]

def get_valid_block_spawns(classes_X, classes_y_fake):
    """Determines whether predicted block spawns have a valid shape.

    Inputs:
        classes_X: Integer tensor of shape (batch_size, height, width), the first time step (with argmax applied on cell types).
        classes_y_fake: Integer tensor of shape (batch_size, height, width), the model's prediction (with argmax applied on cell types).

    Returns: Tensor of bool of shape (batch_size,), whether the items are predicted block spawns AND valid.
        The result lives on the same device as the inputs.
    """
    with torch.no_grad():
        batch_size = classes_X.size(0)

        # Take difference to see which cells are full but weren't before.
        diff = classes_y_fake - classes_X
        top_rows = diff[:, :3, :]

        # Allocate on the inputs' device so the in-place OR below also works off-CPU.
        ret = torch.zeros(batch_size, dtype=torch.bool, device=diff.device)

        # It's only a valid block spawn if the change in the first 3 rows matches
        # one of the valid configurations.
        for block in blocks:
            ret |= (top_rows == block.to(diff.device)).all(-1).all(-1)

        return ret


In [10]:
# Target values the discriminator is trained to output for real and generated boards.
real_label = 1.0
fake_label = 0.0

def train_loop(dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc):
    """Run one epoch of adversarial training over ``dataloader``.

    For every batch the discriminator is updated first (on the real pair and on
    the generated pair), then the generator is updated to fool the freshly
    updated discriminator. Progress is printed every 30 batches.

    Inputs:
        dataloader: yields ``(X, y)`` batches; must expose ``dataset`` and ``batch_size``.
        gen: generator, called as ``gen(X)``.
        disc: discriminator, called as ``disc(X, y)``; its output is flattened to one value per item.
        loss_fn: callable taking (discriminator output, labels).
        optimizer_gen: optimizer for the generator's parameters.
        optimizer_disc: optimizer for the discriminator's parameters.
    """
    gen.train()
    disc.train()

    size = len(dataloader.dataset)
    # Iterate the loader that was passed in, not the notebook-level train_dataloader.
    for batch, (X, y) in enumerate(dataloader):
        ##################################################################
        # (1) Update discriminator: minimize -log(D(x)) - log(1 - D(G(z)))
        ##################################################################
        disc.zero_grad()

        ## Train with all-real batch
        # Format batch
        X, y = X.to(device), y.to(device)
        batch_size = X.size(0)
        real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through discriminator
        output = torch.flatten(disc(X, y))
        # Calculate loss on all-real batch
        err_disc_real = loss_fn(output, real_labels)
        # Calculate gradients for discriminator in backward pass
        err_disc_real.backward()

        ## Train with all-fake batch
        # Generate fake image batch with generator
        y_fake = gen(X)
        fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)
        # Classify all fake batch with discriminator; detach so no gradient reaches the generator here
        output = torch.flatten(disc(X, y_fake.detach()))
        # Calculate discriminator's loss on the all-fake batch
        err_disc_fake = loss_fn(output, fake_labels)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        err_disc_fake.backward()

        ## Update discriminator weights
        # Compute error of discriminator as sum over the fake and the real batches
        err_disc = err_disc_real + err_disc_fake
        # Update discriminator
        optimizer_disc.step()

        ##############################################
        # (2) Update generator: minimize -log(D(G(z)))
        ##############################################
        gen.zero_grad()
        # Since we just updated the discriminator, perform another forward pass of the all-fake batch through it
        output = torch.flatten(disc(X, y_fake))
        # Calculate the generator's loss based on this output
        # We use real labels because the generator wants to fool the discriminator
        err_gen = loss_fn(output, real_labels)
        # Calculate gradients for generator
        err_gen.backward()
        # Update generator
        optimizer_gen.step()

        # Output training stats
        if batch % 30 == 0:
            current = batch * dataloader.batch_size + batch_size
            print(f"[{current}/{size}] D loss: {err_disc.item():.4f}, G loss: {err_gen.item():.4f}")


def test_loop(split_name, dataloader, gen, disc, loss_fn, tb_writer, epoch, examples):
    """Evaluate the generator and discriminator on one split and log to TensorBoard.

    Prints a one-line summary and writes scalar metrics, rendered example
    predictions and discriminator score histograms under tags suffixed with
    ``split_name``.

    Inputs:
        split_name: label used in the printed summary and in the TensorBoard tags.
        dataloader: yields ``(X, y)`` batches; must expose ``dataset`` and ``batch_size``.
        gen: generator, called as ``gen(X)``.
        disc: discriminator, called as ``disc(X, y)``; its output is flattened to one value per item.
        loss_fn: callable taking (discriminator output, labels).
        tb_writer: TensorBoard writer receiving scalars, images and histograms.
        epoch: global step used for every logged value.
        examples: iterable of single ``(X, y)`` examples (without batch dimension) to render as images.
    """
    gen.eval()
    disc.eval()

    loss_disc = 0.0
    loss_gen = 0.0
    disc_accuracy = 0.0
    cell_accuracy = 0.0
    board_accuracy = 0.0
    spawn_recall = 0.0
    num_spawns = 0.0
    spawn_validity = 0.0
    num_predicted_spawns = 0.0
    spawn_precision = 0.0
    scores_real = np.zeros(len(dataloader.dataset))
    scores_fake = np.zeros(len(dataloader.dataset))
    spawn_diversity = metrics.SpawnDiversity()

    num_batches = len(dataloader)
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            # The models live on `device`, so the batch has to be moved there as well.
            X, y = X.to(device), y.to(device)
            batch_size = X.size(0)
            real_labels = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
            fake_labels = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)

            # Flatten to one value per item, as train_loop does.
            output_real = torch.flatten(disc(X, y))
            loss_disc += loss_fn(output_real, real_labels).item()

            y_fake = gen(X)
            output_fake = torch.flatten(disc(X, y_fake))

            loss_disc += loss_fn(output_fake, fake_labels).item()
            loss_gen += loss_fn(output_fake, real_labels).item()

            # A positive raw output corresponds to a sigmoid score above 0.5, i.e. "real".
            pred_real = (output_real > 0.0)
            pred_fake = (output_fake > 0.0)
            disc_accuracy += pred_real.type(torch.float).mean().item()
            disc_accuracy += (~pred_fake).type(torch.float).mean().item()

            # Board-level metrics are computed on CPU copies so the helper functions
            # never see a mix of CPU and device tensors.
            classes_X = torch.argmax(X, dim=1).cpu()
            classes_y = torch.argmax(y, dim=1).cpu()
            classes_y_fake = torch.argmax(y_fake, dim=1).cpu()
            cell_accuracy += (classes_y_fake == classes_y).type(torch.float).mean().item()
            board_accuracy += (classes_y_fake == classes_y).all(-1).all(-1).type(torch.float).mean().item()

            # A spawn is an empty top row in the input followed by a non-empty top row.
            actual_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y[:, 0, :] == 1).any(-1)
            predicted_spawns = (classes_X[:, 0, :] == 0).all(-1) & (classes_y_fake[:, 0, :] == 1).any(-1)
            num_true_positives = (actual_spawns & predicted_spawns).type(torch.float).sum().item()
            spawn_recall += num_true_positives
            spawn_precision += num_true_positives
            num_spawns += actual_spawns.type(torch.float).sum().item()
            valid_spawns = get_valid_block_spawns(classes_X, classes_y_fake)
            spawn_validity += valid_spawns.type(torch.float).sum().item()
            num_predicted_spawns += predicted_spawns.type(torch.float).sum().item()

            start_index = dataloader.batch_size * batch
            end_index = start_index + batch_size
            scores_real[start_index:end_index] = torch.sigmoid(output_real).cpu().numpy()
            scores_fake[start_index:end_index] = torch.sigmoid(output_fake).cpu().numpy()

            spawn_diversity.update_state(classes_X, classes_y_fake)

    loss_disc /= num_batches
    loss_gen /= num_batches
    cell_accuracy /= num_batches
    board_accuracy /= num_batches
    # Ratios with an empty denominator are reported as NaN instead of raising ZeroDivisionError.
    spawn_recall = np.nan if (num_spawns == 0.0) else spawn_recall / num_spawns
    spawn_precision = np.nan if (num_predicted_spawns == 0.0) else spawn_precision / num_predicted_spawns
    disc_accuracy /= (2.0 * num_batches)
    spawn_validity = np.nan if (num_predicted_spawns == 0.0) else spawn_validity / num_predicted_spawns

    print(f"{split_name} error: \n D loss: {loss_disc:>8f}, G loss: {loss_gen:>8f}, D accuracy: {(100*disc_accuracy):>0.1f}%, cell accuracy: {(100*cell_accuracy):>0.1f}%, board accuracy: {(100*board_accuracy):>0.1f}% \n")

    tb_writer.add_scalar(f"Discriminator loss/{split_name}", loss_disc, epoch)
    tb_writer.add_scalar(f"Loss/{split_name}", loss_gen, epoch)
    tb_writer.add_scalar(f"Discriminator accuracy/{split_name}", disc_accuracy, epoch)
    tb_writer.add_scalar(f"Cell accuracy/{split_name}", cell_accuracy, epoch)
    tb_writer.add_scalar(f"Board accuracy/{split_name}", board_accuracy, epoch)
    tb_writer.add_scalar(f"Spawn recall/{split_name}", spawn_recall, epoch)
    tb_writer.add_scalar(f"Spawn precision/{split_name}", spawn_precision, epoch)
    tb_writer.add_scalar(f"Spawn validity/{split_name}", spawn_validity, epoch)
    tb_writer.add_scalar(f"Spawn diversity/{split_name}", spawn_diversity.result(), epoch)

    with torch.no_grad():
        for i, (X, y) in enumerate(examples):
            # Add a batch dimension and run the generator on its own device.
            X, y = X.unsqueeze(0).to(device), y.unsqueeze(0).to(device)
            y_fake = gen(X)
            X, y, y_fake = X.squeeze(0), y.squeeze(0), y_fake.squeeze(0)
            # Rendering happens on CPU copies of the class maps.
            X, y, y_fake = X.argmax(0).cpu(), y.argmax(0).cpu(), y_fake.argmax(0).cpu()
            img = render_prediction(X, y_fake, y)
            tb_writer.add_image(f"Predictions/{split_name}/{i}", img, epoch, dataformats="HW")

    tb_writer.add_histogram(f"Discriminator scores/{split_name}/real", scores_real, epoch)
    tb_writer.add_histogram(f"Discriminator scores/{split_name}/fake", scores_fake, epoch)


In [11]:
learning_rate = 1e-4
epochs = 300
seed = 0

# Seed PyTorch before building the models so weight initialisation and the
# DataLoader shuffling order are reproducible across kernel restarts.
torch.manual_seed(seed)

gen = TetrisModel().to(device)
disc = TetrisDiscriminator().to(device)

loss_fn = nn.BCEWithLogitsLoss()
optimizer_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
optimizer_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)

# Each run logs to its own timestamped sub-directory of runs/tetris_emulator.
log_dir = os.path.join("runs", "tetris_emulator")
log_subdir = os.path.join(log_dir, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
tb_writer = SummaryWriter(log_subdir)

train_examples = find_interesting_examples(train_dataset)
test_examples = find_interesting_examples(test_dataset)

try:
    for epoch in range(epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)

        # Log generator weights and the gradients left over from the last training batch,
        # and count gradient entries that are exactly zero.
        gen_zero_grads = 0
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Weights/{name}", weight, epoch)
            if weight.grad is not None:
                tb_writer.add_histogram(f"Gradients/{name}", weight.grad, epoch)
                gen_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar("Zero gradients", gen_zero_grads, epoch)

        # Same bookkeeping for the discriminator.
        disc_zero_grads = 0
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            if weight.grad is not None:
                tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
                disc_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar("Discriminator zero gradients", disc_zero_grads, epoch)
finally:
    # Close the writer even when training is interrupted or fails part-way through.
    tb_writer.close()

print("Done!")

Epoch 0
-------------------------------
[4/1778] D loss: 1.3800, G loss: 0.7053
[124/1778] D loss: 1.2769, G loss: 0.7590
[244/1778] D loss: 1.2486, G loss: 0.8228
[364/1778] D loss: 1.0590, G loss: 0.8716
[484/1778] D loss: 0.9264, G loss: 0.9779
[604/1778] D loss: 0.6836, G loss: 1.1899
[724/1778] D loss: 0.6696, G loss: 1.3013
[844/1778] D loss: 0.6933, G loss: 1.5199
[964/1778] D loss: 0.5209, G loss: 1.9953
[1084/1778] D loss: 0.5419, G loss: 2.1550
[1204/1778] D loss: 0.6923, G loss: 2.5010
[1324/1778] D loss: 0.4646, G loss: 2.8830
[1444/1778] D loss: 0.2194, G loss: 3.3628
[1564/1778] D loss: 0.3282, G loss: 2.8467
[1684/1778] D loss: 0.2764, G loss: 3.2380
train error: 
 D loss: 0.297511, G loss: 3.940708, D accuracy: 98.4%, cell accuracy: 72.0%, board accuracy: 0.0% 



  probs = self.predicted_spawn_type_counts / num_predicted_spawns


test error: 
 D loss: 0.307771, G loss: 3.888197, D accuracy: 98.1%, cell accuracy: 72.2%, board accuracy: 0.0% 

Epoch 1
-------------------------------
[4/1778] D loss: 0.2616, G loss: 3.3098
[124/1778] D loss: 0.4101, G loss: 3.4060
[244/1778] D loss: 0.2586, G loss: 3.6057
[364/1778] D loss: 0.2040, G loss: 4.1561
[484/1778] D loss: 0.2863, G loss: 3.7932
[604/1778] D loss: 0.2322, G loss: 3.7849
[724/1778] D loss: 0.1215, G loss: 3.0120
[844/1778] D loss: 0.2646, G loss: 3.6381
[964/1778] D loss: 0.2269, G loss: 4.4583
[1084/1778] D loss: 0.2690, G loss: 4.6944
[1204/1778] D loss: 0.1991, G loss: 4.1890
[1324/1778] D loss: 0.2016, G loss: 4.0752
[1444/1778] D loss: 0.1840, G loss: 4.4541
[1564/1778] D loss: 0.1292, G loss: 4.0819
[1684/1778] D loss: 0.2164, G loss: 4.0738
train error: 
 D loss: 0.162994, G loss: 2.863518, D accuracy: 99.7%, cell accuracy: 92.7%, board accuracy: 0.0% 

test error: 
 D loss: 0.167977, G loss: 2.821768, D accuracy: 99.9%, cell accuracy: 92.5%, board 

In [None]:
# Run this as many times as needed to "top-up" the training

extra_epochs = 5

# Re-open a writer on the same log directory so the new epochs extend the same run.
tb_writer = SummaryWriter(log_subdir)
first_extra_epoch = epochs

try:
    for epoch in range(first_extra_epoch, first_extra_epoch + extra_epochs):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loop(train_dataloader, gen, disc, loss_fn, optimizer_gen, optimizer_disc)
        test_loop("train", train_dataloader, gen, disc, loss_fn, tb_writer, epoch, train_examples)
        test_loop("test", test_dataloader, gen, disc, loss_fn, tb_writer, epoch, test_examples)

        # Log generator weights and the gradients left over from the last training batch,
        # and count gradient entries that are exactly zero.
        gen_zero_grads = 0
        for name, weight in gen.named_parameters():
            tb_writer.add_histogram(f"Weights/{name}", weight, epoch)
            if weight.grad is not None:
                tb_writer.add_histogram(f"Gradients/{name}", weight.grad, epoch)
                gen_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar("Zero gradients", gen_zero_grads, epoch)

        # Same bookkeeping for the discriminator.
        disc_zero_grads = 0
        for name, weight in disc.named_parameters():
            tb_writer.add_histogram(f"Discriminator weights/{name}", weight, epoch)
            if weight.grad is not None:
                tb_writer.add_histogram(f"Discriminator gradients/{name}", weight.grad, epoch)
                disc_zero_grads += weight.grad.numel() - weight.grad.count_nonzero().item()
        tb_writer.add_scalar("Discriminator zero gradients", disc_zero_grads, epoch)

        # Advance the epoch counter as each epoch completes, so an interrupted
        # top-up resumes at the next unused TensorBoard step instead of re-logging.
        epochs = epoch + 1
finally:
    # Close the writer even when training is interrupted or fails part-way through.
    tb_writer.close()

print("Done!")

In [None]:
def show_prediction(example):
    """Plot the input board, the generator's prediction and the real next board side by side.

    Inputs:
        example: ``(x, y)`` pair of one-hot boards without a batch dimension
            (classes along dimension 0), as returned by the dataset.

    Uses the notebook-level ``gen`` and ``device``.
    """
    x, y = example
    # Inference only: skip autograd, run on the generator's device and bring the result back.
    with torch.no_grad():
        pred = gen(x.unsqueeze(0).to(device)).squeeze(0).cpu()
    # Collapse the class dimension to get 2-D boards that imshow can draw.
    x, y, pred = x.argmax(0).cpu(), y.argmax(0).cpu(), pred.argmax(0)

    fig, axs = plt.subplots(1, 3)
    fig.suptitle("Prediction vs reality")

    for ax in axs:
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    axs[0].set_title("Input")
    axs[1].set_title("Predicted")
    axs[2].set_title("Reality")

    # Fixed colour range so empty (0) and filled (1) cells look the same in every panel.
    axs[0].imshow(x, vmin=0, vmax=1)
    axs[1].imshow(pred, vmin=0, vmax=1)
    axs[2].imshow(y, vmin=0, vmax=1)

    plt.show()

In [None]:
# Show a random training prediction vs reality
import random

# Draw one index uniformly from the training set, announce it, then plot it.
idx = random.randrange(len(train_dataset))
print(f"Showing prediction for training example {idx}")
train_example = train_dataset[idx]
show_prediction(train_example)

In [None]:
# Show a random test prediction vs reality
# Draw one index uniformly from the test set, announce it, then plot it.
idx = random.randrange(len(test_dataset))
print(f"Showing prediction for test example {idx}")
test_example = test_dataset[idx]
show_prediction(test_example)

In [None]:
def get_mistake_counts(gen, dataloader):
    """Count the wrongly predicted cells for every example served by ``dataloader``.

    Inputs:
        gen: generator, called as ``gen(X)``.
        dataloader: yields ``(X, y)`` batches; must expose ``dataset`` and ``batch_size``.

    Returns: int32 array of length ``len(dataloader.dataset)``. Entry ``k`` belongs
        to the ``k``-th example in the order the loader served it, which only
        matches the dataset index when the loader does not shuffle.
    """
    mistake_counts = np.zeros(len(dataloader.dataset), dtype=np.int32)

    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            batch_size = X.size(0)

            X, y = X.to(device), y.to(device)
            y_fake = gen(X)

            classes_y = torch.argmax(y, dim=1)
            classes_y_fake = torch.argmax(y_fake, dim=1)
            # Bring the per-board counts back to the CPU before converting to NumPy.
            num_mistakes = (classes_y_fake != classes_y).type(torch.int).sum(dim=[1, 2]).cpu().numpy()
            batch_start = batch * dataloader.batch_size
            batch_end = batch_start + batch_size
            mistake_counts[batch_start:batch_end] = num_mistakes

    return mistake_counts

In [None]:
mistake_counts_train = get_mistake_counts(gen, train_dataloader)
mistake_counts_test = get_mistake_counts(gen, test_dataloader)
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
fig.suptitle("Mistake counts")
# Bin edges 0, 1, ..., max + 1 give every integer count its own bin, including the
# maximum, and stay valid when the maximum is 0.
axs[0].set_title("Train")
axs[0].hist(mistake_counts_train, bins=np.arange(mistake_counts_train.max() + 2))
axs[0].set_xlabel("Wrong cells per board")
axs[0].set_ylabel("Number of examples")
axs[1].set_title("Test")
axs[1].hist(mistake_counts_test, bins=np.arange(mistake_counts_test.max() + 2))
axs[1].set_xlabel("Wrong cells per board")
axs[1].set_ylabel("Number of examples")
plt.show()

In [None]:
def get_mistake_heatmap(gen, dataloader):
    """Count, per board cell, how often the generator's prediction is wrong.

    Inputs:
        gen: generator, called as ``gen(X)``.
        dataloader: iterable of ``(X, y)`` batches.

    Returns: float array of shape (height, width) taken from the data. If the
        loader yields no batches, an all-zero (22, 10) array is returned.
    """
    heatmap = None

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            y_fake = gen(X)

            classes_y = torch.argmax(y, dim=1)
            classes_y_fake = torch.argmax(y_fake, dim=1)
            # Sum mistakes over the batch dimension and bring them back to the CPU for NumPy.
            mistakes = (classes_y_fake != classes_y).type(torch.float).sum(dim=0).cpu().numpy()
            if heatmap is None:
                # Size the accumulator from the first batch instead of hard-coding the board size.
                heatmap = np.zeros(mistakes.shape)
            heatmap += mistakes

    if heatmap is None:
        # No batches at all: keep the historical default shape.
        heatmap = np.zeros((22, 10))

    return heatmap

In [None]:
heatmap_train = get_mistake_heatmap(gen, train_dataloader)
heatmap_test = get_mistake_heatmap(gen, test_dataloader)

# One panel per split, drawn left to right.
fig, axs = plt.subplots(1, 2, figsize=(6, 4))
fig.suptitle("Mistake heatmap")
panels = (("Train", heatmap_train), ("Test", heatmap_test))
for panel_index, (panel_title, panel_heatmap) in enumerate(panels):
    axs[panel_index].set_title(panel_title)
    axs[panel_index].imshow(panel_heatmap)
plt.show()

In [None]:
# Dump the raw per-cell mistake counts for the training split, followed by their total.
print(heatmap_train, heatmap_train.sum(), sep="\n")

In [None]:
def find_failed_example_by_cell(dataloader, cell):
    """Find the first served example whose prediction is wrong at ``cell``.

    Uses the notebook-level ``gen`` and ``device``.

    Inputs:
        dataloader: yields ``(X, y)`` batches; must expose ``batch_size``.
        cell: ``(row, column)`` of the board cell to inspect.

    Returns: ``(index, (classes_x, classes_y_fake, classes_y))`` with the three
        boards as CPU class maps, or ``None`` when no example fails at ``cell``.
        ``index`` is the position in the order the loader served the examples;
        it equals the dataset index only when the loader does not shuffle.
    """
    i, j = cell

    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            y_fake = gen(X)

            classes_X = torch.argmax(X, dim=1)
            classes_y = torch.argmax(y, dim=1)
            classes_y_fake = torch.argmax(y_fake, dim=1)
            failed_in_cell = (classes_y_fake != classes_y)[:, i, j]

            # Positions within the batch that failed, in ascending order.
            failed_positions = torch.nonzero(failed_in_cell).flatten()
            if failed_positions.numel() > 0:
                batch_idx = failed_positions[0].item()
                dataset_idx = batch * dataloader.batch_size + batch_idx
                # Return CPU tensors so callers can render/plot them directly.
                return dataset_idx, (
                    classes_X[batch_idx].cpu(),
                    classes_y_fake[batch_idx].cpu(),
                    classes_y[batch_idx].cpu(),
                )

    # Nothing failed at this cell.
    return None

In [None]:
cell = (1, 5)

# train_dataloader shuffles, so a position within its batches would not map back to a
# dataset index. Scan an unshuffled loader so the reported index is the real one.
ordered_train_dataloader = DataLoader(train_dataset, batch_size=train_dataloader.batch_size, shuffle=False)
failed_example = find_failed_example_by_cell(ordered_train_dataloader, cell)

if failed_example is None:
    print(f"No failed prediction found for cell {cell}")
else:
    idx, (x, y_fake, y) = failed_example
    print(f"Showing failed prediction for example {idx}")
    fig, ax = plt.subplots(1, 1)
    ax.set_title("Input | Predicted | Real")
    # Render on CPU copies so plotting also works when the models run on a GPU.
    ax.imshow(render_prediction(x.cpu(), y_fake.cpu(), y.cpu()))
    plt.show()

In [None]:
# Persist the trained generator's state dict to tetris_emulator.pth in the working directory.
torch.save(gen.state_dict(), "tetris_emulator.pth")