<!DOCTYPE html>
<html>
<head>
    <style>
        body {
            font-family: Arial, sans-serif;
            line-height: 1.6;
            color: #333;
        }
        h1 {
            color: #2F4F4F;
        }
        h2 {
            color: #4682B4;
        }
        p {
            margin: 10px 0;
        }
        .key-points, .open-questions {
            background-color: #F0F8FF;
            padding: 15px;
            border-left: 5px solid #4682B4;
            margin-bottom: 20px;
        }
        .key-points ul, .open-questions ul {
            margin: 0;
            padding-left: 20px;
        }
        .highlight {
            color: #B22222;
            font-weight: bold;
        }
          div {
            margin: 10px 0;
        }
    </style>
</head>
<body>

<h1 style="color: #2F4F4F;">1. Summary of the Approach</h1>
<p>In this notebook, I tackled the ARC challenge by framing it as a sequence-to-sequence (seq2seq) translation task, similar to translating from English to French. Here, the input grids and output grids are treated as sequences that need to be translated from one form to another. This approach leverages the power of transformer architectures, which have proven highly effective in language translation tasks.</p>

<div class="key-points">
    <h2 style="color: #4682B4;">Key Points of the Approach:</h2>
    <ul>
        <li><span class="highlight">Data Preparation:</span> Each input and output grid is scaled up by the largest integer factor that fits and then padded to a fixed 30x30 size, so that every grid has the same dimensions. The training pairs are also augmented with 90°, 180°, and 270° rotations to diversify the training data.</li>
        <li><span class="highlight">Custom Transformer Model:</span> The model consists of separate encoder and decoder layers, designed to handle categorical data. Each grid cell is represented as a category, and the model uses <code>nn.Embedding</code> layers to convert these categorical values into dense embeddings. Sinusoidal positional encodings are then added over the flattened (row-major) grid sequence to encode each cell's position.</li>
        <li><span class="highlight">Start, Separator, and End Tokens:</span> To mark sequence boundaries, I introduced special tokens (<code>&lt;START&gt;</code> and <code>&lt;END&gt;</code>), plus a <code>&lt;SEP&gt;</code> token between an input grid and its output grid, similar to techniques used in NLP tasks.</li>
        <li><span class="highlight">Training:</span> The model is trained with a relatively high dropout rate of 0.4, intended to regularize it and reduce overfitting. Surprisingly, the model shows promising results during training and appears to learn to translate input grids to output grids.</li>
    </ul>
</div>

<h1 style="color: #2F4F4F;">2. Issue with Inference</h1>
<p>While the model learns well during training, the inference phase does not perform as expected. This discrepancy might stem from the way the transformer model is implemented for this grid-based problem. Although I initially treated the task as a sequence-to-sequence translation problem, similar to language translation, this approach may not be the most suitable for handling 2D grid transformations. Grid-based reasoning may require a different architecture or handling mechanism that I haven't fully determined yet.</p>

<div style="background-color: #F0F8FF; padding: 15px; border-left: 5px solid #4682B4; margin-bottom: 20px;">
    <h2 style="color: #4682B4;">Key Points to Consider:</h2>
    <ul>
        <li><span style="color: #B22222; font-weight: bold;">Possible Misalignment with Grid-Based Problem:</span> The current model might not correctly capture the spatial relationships and transformations required for grid tasks. The way positional encodings and embeddings are applied during inference might not fully align with the structure of grid data.</li>
        <li><span style="color: #B22222; font-weight: bold;">Reconsideration of Inference Approach:</span> Unlike typical transformer tasks, starting inference with a <code>&lt;START&gt;</code> token may not be the best strategy for grid-based tasks. An alternative approach could involve a one-shot inference technique, where the model predicts the entire grid output in one go, rather than sequentially. However, the exact implementation of such an approach remains unclear.</li>
        <li><span style="color: #B22222; font-weight: bold;">Potential Use of Patches as in Vision Transformers:</span> Another idea could be to treat parts of the grid as patches, similar to the Vision Transformer (ViT) approach, and process them in parallel to capture spatial dependencies more effectively. This would require a redesign of the model architecture to handle grid patches.</li>
        <li><span style="color: #B22222; font-weight: bold;">Inconsistent Performance Between Training and Inference:</span> It's puzzling that the model performs so well during training but fails to generalize effectively during inference. This inconsistency could indicate an issue with overfitting to the training data, or it might suggest that the model hasn't fully learned the underlying spatial patterns required for unseen data.</li>
    </ul>
</div>


<h1 style="color: #2F4F4F;">3. Reflection: Not the Path to AGI, but a Step Towards It</h1>
<p>While this approach is not the definitive way to achieve Artificial General Intelligence (AGI), it is a small but meaningful step towards exploring how transformers and other deep learning techniques can be applied to grid-based reasoning tasks. The process of framing a grid transformation task as a sequence-to-sequence problem allows us to harness the strengths of transformers while identifying their limitations in such domains.</p>

<p>Through this work, I aim to contribute to the broader discussion on the versatility of transformers and their applicability beyond traditional NLP tasks. The challenges faced here also underscore the need for more specialized architectures that can handle grid-based and spatial reasoning tasks more effectively.</p>

<div class="open-questions">
    <h2>Open Questions:</h2>
    <ul>
        <li>Why does the model excel during training with a relatively high dropout rate of 0.4, but struggles significantly during inference? Is this due to overfitting to the training data, or is there a fundamental flaw in how the model processes grid data during inference?</li>
    </ul>
    <p>I'm eager to hear thoughts from the community on what might be going wrong and suggestions on the right approach to solve grid-based problems using transformer architectures.</p>
</div>

</body>
</html>
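The patch idea raised under "Key Points to Consider" starts with cutting each padded 30x30 grid into tiles. The NumPy sketch below assumes a 5x5 tile (30 is divisible by 5, which gives 36 tiles of 25 cells each); `grid_to_patches` and `patches_to_grid` are hypothetical helpers, not part of this notebook. In a ViT-style model each tile's categorical cells would still have to be embedded before being projected to `d_model`; that step is not shown.

```python
import numpy as np


def grid_to_patches(grid: np.ndarray, patch: int = 5) -> np.ndarray:
    """Split an (H, W) grid into non-overlapping patch x patch tiles.

    Returns shape (num_tiles, patch * patch): one flattened tile per row,
    with tiles ordered row-major across the grid.
    """
    h, w = grid.shape
    if h % patch or w % patch:
        raise ValueError("grid height and width must be divisible by the patch size")
    tiles = grid.reshape(h // patch, patch, w // patch, patch)  # (tile_row, r, tile_col, c)
    tiles = tiles.transpose(0, 2, 1, 3)                         # (tile_row, tile_col, r, c)
    return tiles.reshape(-1, patch * patch)


def patches_to_grid(patches: np.ndarray, h: int, w: int, patch: int = 5) -> np.ndarray:
    """Inverse of grid_to_patches: reassemble the tiles into an (h, w) grid."""
    tiles = patches.reshape(h // patch, w // patch, patch, patch)
    return tiles.transpose(0, 2, 1, 3).reshape(h, w)


grid = np.arange(30 * 30).reshape(30, 30) % 10  # stand-in for a padded ARC grid
patches = grid_to_patches(grid)
print(patches.shape)  # (36, 25)
```

Because the tiles are processed as 36 tokens instead of 900, attention cost drops sharply, at the price of having to decode a whole tile at a time.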


## Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import json
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from termcolor import colored
import math
import random
import matplotlib.pyplot as plt

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

KAGGLE = False
DEBUG = False

NUM_CLASSES = 10
IGNORE_CLASS = 0
START_TOKEN = NUM_CLASSES   # 10
END_TOKEN = NUM_CLASSES + 1 # 11
SEP_TOKEN = NUM_CLASSES + 2 # 12

## Default Seeds

In [None]:
# Set seed for PyTorch
torch.manual_seed(42)

# If using CUDA
torch.cuda.manual_seed(42)
torch.cuda.manual_seed_all(42)  # if you are using multi-GPU.

# Set seed for NumPy
np.random.seed(42)

# Set seed for Python's built-in random library
random.seed(42)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
def log_message(message: str, color: str = "black"):
    if DEBUG:
        print(colored(message, color))


def log_with_condition(message: str, condition=False, color: str = "blue"):
    if condition:
        print(colored(message, color))


def log_important(message: str, color: str = "blue"):
    print(colored(message, color))


def log_error(message: str, color: str = "red"):
    print(colored(message, color))

In [None]:
colors = [
    [0, 0, 0],
    [0, 116, 217],
    [255, 65, 54],
    [46, 204, 64],
    [255, 220, 0],
    [170, 170, 170],
    [240, 18, 190],
    [255, 133, 27],
    [127, 219, 255],
    [135, 12, 37],
]


def paint(*matrices):
    num_matrices = len(matrices)
    fig, axes = plt.subplots(1, num_matrices, figsize=(3 * num_matrices, 2))

    if num_matrices == 1:
        axes = [axes]  # Ensure axes is a list even for a single subplot

    for i, matrix in enumerate(matrices):
        # Work on a clipped copy so the caller's array is not modified in place
        matrix = np.clip(np.array(matrix), 0, 9)
        try:
            # Ensure matrix is a 2D array
            m, n = matrix.shape
            print(f"Matrix {i} shape: {m, n}")
            unique_values = np.unique(matrix)
            print(f"Unique values in matrix {i}: {unique_values}")

            # Convert matrix values to corresponding colors, ignoring -1
            matrix_colored = np.array(
                [
                    [
                        colors[element] if element != -1 else [255, 255, 255]
                        for element in row
                    ]
                    for row in matrix
                ]
            )
            axes[i].imshow(matrix_colored, interpolation='nearest')
            axes[i].set_title(f'({m} x {n})')
            axes[i].axis('off')
        except Exception as ex:
            print(f"Error processing matrix {i}: {ex}")
            pass

    plt.tight_layout()
    plt.show()

In [None]:
def pad_to_dimensions(grid, target_size, ignore_class=IGNORE_CLASS):
    h, w = grid.shape[0], grid.shape[1]

    # Calculate padding to match target_size
    pad_bottom = target_size[0] - h
    pad_right = target_size[1] - w

    # Pad within the original dimensions with zeros
    padded_grid = np.pad(
        grid,
        ((0, pad_bottom), (0, pad_right)),
        mode='constant',
        constant_values=0,
    )

    # Pad the rest to maximum size of 30x30 with ignore_class
    max_size = (30, 30)
    if padded_grid.shape[0] < max_size[0] or padded_grid.shape[1] < max_size[1]:
        pad_bottom = max_size[0] - padded_grid.shape[0]
        pad_right = max_size[1] - padded_grid.shape[1]
        padded_grid = np.pad(
            padded_grid,
            ((0, pad_bottom), (0, pad_right)),
            mode='constant',
            constant_values=ignore_class,
        )

    return padded_grid

In [None]:
def rotate_grid(grid, angle):
    if angle == 0:
        return grid
    elif angle == 90:
        return np.rot90(grid, k=1, axes=(1, 0)).tolist()
    elif angle == 180:
        return np.rot90(grid, k=2, axes=(1, 0)).tolist()
    elif angle == 270:
        return np.rot90(grid, k=3, axes=(1, 0)).tolist()
    else:
        raise ValueError(
            "Unsupported rotation angle. Supported angles are 0, 90, 180, and 270 degrees."
        )
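`rotate_grid` passes `axes=(1, 0)` to `np.rot90`, which turns the grid clockwise. The sketch below checks that direction and illustrates a side effect of rotation augmentation: rotating an (input, output) pair preserves the task's rule only when the rule ignores orientation. For an orientation-dependent rule, a rotated pair demonstrates a rotated version of the rule, so mixing all rotations into a single task's training pairs (as `prepare_data` does) combines several orientations of it. `rotate_cw`, `recolor`, and `shift_right` are hypothetical toy helpers.

```python
import numpy as np


def rotate_cw(grid, k=1):
    # Same call rotate_grid uses: axes=(1, 0) turns the grid clockwise k times
    return np.rot90(np.asarray(grid), k=k, axes=(1, 0))


def recolor(grid):
    # Toy rule that ignores orientation: paint every 1 as 2
    return np.where(grid == 1, 2, grid)


def shift_right(grid):
    # Toy rule that depends on orientation: move every column one step right
    out = np.zeros_like(grid)
    out[:, 1:] = grid[:, :-1]
    return out


print(rotate_cw([[1, 2], [3, 4]]).tolist())  # [[3, 1], [4, 2]]

g = np.array([[1, 0, 0],
              [0, 1, 0],
              [0, 0, 0]])

# Orientation-free rule: the rotated pair still follows the same rule
recolor_ok = np.array_equal(rotate_cw(recolor(g)), recolor(rotate_cw(g)))
# Orientation-dependent rule: the rotated pair now shows a "shift down", not a "shift right"
shift_ok = np.array_equal(rotate_cw(shift_right(g)), shift_right(rotate_cw(g)))
print(recolor_ok, shift_ok)  # True False
```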

In [None]:
def pad_to_final_size(grid, final_size=(30, 30), ignore_class=IGNORE_CLASS):
    h, w = grid.shape[0], grid.shape[1]

    # Calculate padding to match final_size
    pad_bottom = final_size[0] - h
    pad_right = final_size[1] - w

    # Pad the bottom and right edges with ignore_class
    padded_grid = np.pad(
        grid,
        ((0, pad_bottom), (0, pad_right)),
        mode='constant',
        constant_values=ignore_class,
    )

    return padded_grid

In [None]:
def shift_grid(grid, step_right=0, step_down=0):
    """
    Shift the grid right and down by the specified steps, filling new cells with zeros.

    Args:
    grid (numpy.ndarray): The input grid to be shifted.
    step_right (int): The number of columns to shift right.
    step_down (int): The number of rows to shift down.

    Returns:
    numpy.ndarray: The shifted grid.

    Example usage:
    grid = np.array([[1, 2, 3, 0], [4, 5, 6, 0], [7, 8, 9, 0], [0, 0, 0, 0]])
    shifted_grid = shift_grid(grid, step_right=1, step_down=0)
    print(shifted_grid)
    """
    h, w = grid.shape
    shifted_grid = np.zeros_like(grid)

    # Non-positive steps mean "no shift" in that direction
    step_right, step_down = max(step_right, 0), max(step_down, 0)
    if step_right >= w or step_down >= h:
        # Everything is shifted out of the grid
        return shifted_grid

    # Copy the top-left block to its shifted position; vacated rows/columns stay zero
    shifted_grid[step_down:, step_right:] = grid[: h - step_down, : w - step_right]

    return shifted_grid

In [None]:
def scale_up_grid(grid, scale_factor):
    # Scale up a grid by repeating its elements along both axes.
    # Repeat elements `scale_factor` times along the row (axis=0) and column (axis=1).
    return np.repeat(np.repeat(grid, scale_factor, axis=0), scale_factor, axis=1)


def custom_zoom(input_grid, final_size=30):
    # Custom zoom function to scale the input grid up to a specified final size.
    # The input grid is scaled proportionally to the maximum dimension.

    # Determine the largest dimension (width or height) of the input grid.
    max_dim = max(input_grid.shape)

    # Calculate the scaling factor to ensure the larger dimension fits within the final size.
    scale_factor = final_size // max_dim

    # Scale up the input grid using the calculated scaling factor.
    scaled_input_grid = scale_up_grid(input_grid, scale_factor)

    # If the scaled grid does not match the final desired size (30x30 by default),
    # pad the grid to match the final size using the `pad_to_final_size` function.
    if scaled_input_grid.shape != (final_size, final_size):
        scaled_input_grid = pad_to_final_size(scaled_input_grid, (final_size, final_size))

    # Return the scaled (and potentially padded) grid.
    return scaled_input_grid


def downscale_grid(grid, target_size):
    # Downscale a grid to a target size by taking the most frequent element (mode)
    # in each non-overlapping block of the original grid.

    # Determine the original size (height and width) of the input grid.
    original_size = grid.shape

    # Use one scaling factor for both axes, because custom_zoom scales height and
    # width by the same integer factor (final_size // max dimension)
    scale_factor = min(original_size[0] // target_size[0], original_size[1] // target_size[1])
    scale_factor_x = scale_factor_y = scale_factor

    # Initialize a new grid with the target size filled with zeros.
    downscaled_grid = np.zeros(target_size, dtype=grid.dtype)

    # Iterate through each block of the grid to downscale it to the target size.
    for i in range(target_size[0]):
        for j in range(target_size[1]):
            # Extract a block from the original grid that corresponds to the current position.
            block = grid[
                i * scale_factor_y : (i + 1) * scale_factor_y,
                j * scale_factor_x : (j + 1) * scale_factor_x,
            ]

            # Find the most frequent element (mode) in the current block and assign it to the downscaled grid.
            downscaled_grid[i, j] = np.bincount(block.flatten()).argmax()

    # Return the downscaled grid.
    return downscaled_grid
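Mapping a 30x30 prediction back to the task's output shape has to undo `custom_zoom`, which scales both axes by the same integer factor (`final_size // max(h, w)`) and then pads the bottom and right. Below is a vectorized sketch of that round trip, assuming the original shape is known; `zoom` and `unzoom` are hypothetical stand-ins written to mirror `custom_zoom` and `downscale_grid`.

```python
import numpy as np


def zoom(grid: np.ndarray, final_size: int = 30):
    """Mirror custom_zoom: repeat each cell `factor` times per axis, then zero-pad."""
    factor = final_size // max(grid.shape)
    up = np.repeat(np.repeat(grid, factor, axis=0), factor, axis=1)
    padded = np.zeros((final_size, final_size), dtype=grid.dtype)
    padded[: up.shape[0], : up.shape[1]] = up
    return padded, factor


def unzoom(zoomed: np.ndarray, original_shape, factor: int) -> np.ndarray:
    """Crop the padding, then take the most frequent color in each factor x factor block."""
    h, w = original_shape
    cropped = zoomed[: h * factor, : w * factor]
    blocks = cropped.reshape(h, factor, w, factor).transpose(0, 2, 1, 3).reshape(h, w, -1)
    return np.apply_along_axis(lambda block: np.bincount(block).argmax(), 2, blocks)


grid = np.arange(28).reshape(4, 7) % 10  # non-square toy grid
zoomed, factor = zoom(grid)
restored = unzoom(zoomed, grid.shape, factor)
print(factor, zoomed.shape, np.array_equal(restored, grid))  # 4 (30, 30) True
```

Deriving a separate factor per axis (30 // 4 = 7 rows, 30 // 7 = 4 columns) would misalign the row blocks for this 4x7 grid, since only the larger dimension determines the zoom factor.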

In [None]:
def prepare_data(challenges, task_id_to_index):
    input_grids, output_grids, task_idxs = [], [], []
    dimensions = {}

    for task_id, data in challenges.items():
        # task_ids.append(task_id)
        task_idx = task_id_to_index[task_id]
        dimensions[task_id] = {}
        dimensions[task_id]['inputs'] = []
        dimensions[task_id]['outputs'] = []
        previous_input_dimensions = (0, 0)

        ## Rotation augmentation: keep the original pairs and add 90/180/270-degree copies
        train = list(data['train'])
        temp = train.copy()
        for t in train:
            xinput = t['input']
            output = t['output']
            for angle in [90, 180, 270]:
                temp.append(
                    {
                        "input": rotate_grid(xinput, angle),
                        "output": rotate_grid(output, angle),
                    }
                )

        data['train'] = temp

        ## Scale up and pad every pair to 30x30
        for sample in data['train']:
            input_grid = np.array(sample['input'], dtype=np.int32)
            output_grid = np.array(sample['output'], dtype=np.int32)

            # Store dimensions in dictionaries
            if previous_input_dimensions != input_grid.shape:
                dimensions[task_id]['inputs'].append(input_grid.shape)
                dimensions[task_id]['outputs'].append(output_grid.shape)
                previous_input_dimensions = input_grid.shape

            input_grid = custom_zoom(input_grid)
            output_grid = custom_zoom(output_grid)

            input_grids.append(torch.tensor(input_grid, dtype=torch.long))
            output_grids.append(torch.tensor(output_grid, dtype=torch.long))
            task_idxs.append(task_idx)

    return (torch.stack(input_grids), torch.stack(output_grids), dimensions)

In [None]:
def load_data(challenges_path, solutions_path):
    with open(challenges_path, 'r') as f:
        challenges = json.load(f)
    with open(solutions_path, 'r') as f:
        solutions = json.load(f)
    
    for task_id in challenges.keys():
        solution = solutions[task_id][0]
        challenges[task_id]['test'][0]['output'] = solution

    return challenges
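`load_data` attaches only `solutions[task_id][0]` to `test[0]`, so tasks that come with more than one test input keep their later test inputs without an output. If those are needed, a variant could pair every test input with its solution. The sketch below uses a made-up toy task (`toy_task` and the temporary file names are illustrative only); `load_data_all_tests` is a hypothetical helper.

```python
import json
import os
import tempfile


def load_data_all_tests(challenges_path, solutions_path):
    """Like load_data, but attach a solution to every test input of each task."""
    with open(challenges_path) as f:
        challenges = json.load(f)
    with open(solutions_path) as f:
        solutions = json.load(f)
    for task_id, task in challenges.items():
        # solutions[task_id] lists one output grid per test input, in order
        for test_pair, solution in zip(task["test"], solutions[task_id]):
            test_pair["output"] = solution
    return challenges


# Toy files with the same layout as the ARC challenge/solution JSON
toy_challenges = {
    "toy_task": {
        "train": [{"input": [[1]], "output": [[2]]}],
        "test": [{"input": [[3]]}, {"input": [[4]]}],
    }
}
toy_solutions = {"toy_task": [[[6]], [[8]]]}

with tempfile.TemporaryDirectory() as tmp:
    c_path = os.path.join(tmp, "challenges.json")
    s_path = os.path.join(tmp, "solutions.json")
    with open(c_path, "w") as f:
        json.dump(toy_challenges, f)
    with open(s_path, "w") as f:
        json.dump(toy_solutions, f)
    merged = load_data_all_tests(c_path, s_path)

print([t["output"] for t in merged["toy_task"]["test"]])  # [[[6]], [[8]]]
```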

## Our Transformer Implementation

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=7200):
        super(PositionalEncoding, self).__init__()
        self.encoding = self.generate_encoding(d_model, max_len)
        self.inference_encoding = self.generate_inference_encoding(d_model)

    def generate_encoding(self, d_model, max_len):
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)
        encoding = encoding.unsqueeze(0)
        return encoding

    def generate_inference_encoding(self, d_model, height=30, width=30):
        num_positions = height * width
        position = torch.arange(0, num_positions).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        encoding = torch.zeros(num_positions, d_model)
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)
        #         encoding = encoding.view(1, height, width, d_model)  # Reshape for 2D grid
        encoding = encoding.unsqueeze(0)

        return encoding

    def forward(self, x, mode="training"):
        # `mode` is currently unused; training and inference share the same encoding
        log_message(f"x shape: {x.shape}")

        # Normalize the input to (batch, seq_len, d_model)
        if x.dim() == 2:
            # A single unbatched sequence: (seq_len, d_model)
            x = x.unsqueeze(0)
        elif x.dim() > 3:
            # e.g. (batch, 1, 30, 30, d_model) from predict: flatten the grid dimensions
            x = x.reshape(x.size(0), -1, x.size(-1))

        if x.size(1) > self.encoding.size(1):
            raise ValueError(
                f"Sequence length {x.size(1)} exceeds max_len {self.encoding.size(1)}"
            )

        return x + self.encoding[:, : x.size(1), : x.size(2)].to(x.device)
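The encoding above is 1D: it indexes the flattened, row-major sequence, so two vertically adjacent cells sit 30 positions apart. A commonly used alternative for grids is a 2D sinusoidal encoding that spends half the channels on the row index and half on the column index. A minimal sketch, assuming `d_model` is divisible by 4 and the grid is 30x30; `positional_encoding_2d` is a hypothetical helper, and the special tokens (`<START>`, `<SEP>`, `<END>`) would still need positions of their own if this were used with the notebook's token layout.

```python
import math

import torch


def positional_encoding_2d(d_model: int, height: int = 30, width: int = 30) -> torch.Tensor:
    """Return a (height * width, d_model) encoding, row-major over the grid.

    The first d_model // 2 channels encode the row index and the rest encode the
    column index, each with the usual 1D sin/cos scheme.
    """
    if d_model % 4 != 0:
        raise ValueError("d_model must be divisible by 4")
    half = d_model // 2
    div_term = torch.exp(torch.arange(0, half, 2) * (-math.log(10000.0) / half))

    def encode(positions: torch.Tensor) -> torch.Tensor:
        enc = torch.zeros(positions.numel(), half)
        enc[:, 0::2] = torch.sin(positions.unsqueeze(1) * div_term)
        enc[:, 1::2] = torch.cos(positions.unsqueeze(1) * div_term)
        return enc

    rows = encode(torch.arange(height, dtype=torch.float32))  # (height, half)
    cols = encode(torch.arange(width, dtype=torch.float32))   # (width, half)
    # Cell (r, c) gets the concatenation [row_enc[r], col_enc[c]]
    grid = torch.cat(
        [rows.unsqueeze(1).expand(height, width, half),
         cols.unsqueeze(0).expand(height, width, half)],
        dim=-1,
    )
    return grid.reshape(height * width, d_model)


pe = positional_encoding_2d(256)
x = torch.zeros(2, 900, 256)  # stand-in for a batch of embedded 30x30 grids
x = x + pe.unsqueeze(0)       # broadcast the same encoding over the batch
print(pe.shape, x.shape)
```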

In [None]:
class CustomTransformerCategorical(nn.Module):
    def __init__(
        self,
        d_model,
        num_categories,
        nhead,
        num_encoder_layers,
        num_decoder_layers,
        dim_feedforward,
        dropout_rate=0.1,
    ):
        super(CustomTransformerCategorical, self).__init__()
        self.num_categories = (
            num_categories + 3
        )  # Accounting for <start>, <end>, and <sep> tokens
        self.grid_embedding = nn.Embedding(
            self.num_categories, d_model
        )  # Embedding layer for categories
        self.grid_positional_encoding = PositionalEncoding(d_model)
        self.grid_positional_encoding_inference = PositionalEncoding(
            d_model, max_len=30 * 30
        )
        self.d_model = d_model

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout_rate,
            ),
            num_layers=num_encoder_layers,
        )

        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=self.d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout_rate,
            ),
            num_layers=num_decoder_layers,
        )

        self.fc_out = nn.Linear(
            self.d_model, self.num_categories
        )  # Output layer for categories
        self.dropout = nn.Dropout(dropout_rate)

    def generate_square_subsequent_mask(self, sz):
        # -inf above the diagonal blocks attention to future positions; 0 elsewhere
        mask = torch.triu(torch.ones(sz, sz), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask
    
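    def generate_upper_triangular_mask(self, N, M):
        """Generate an NxN float mask where the first M tokens form a fully visible prefix.

        Position i may attend to position j unless j lies after the prefix (j >= M)
        and in the future (j > i); blocked entries are -inf, all others 0.
        """
        idx = torch.arange(N)
        blocked = (idx.unsqueeze(0) > idx.unsqueeze(1)) & (idx.unsqueeze(0) >= M)
        mask = torch.zeros(N, N).masked_fill(blocked, float('-inf'))
        return mask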

    def forward(self, encoder_input, decoder_in, src_mask=None, tgt_mask=None):
        memories = []
        for sample in encoder_input:
            # Embedding and positional encoding for encoder sequence
            encoder_emb = self.grid_embedding(sample)
            encoder_emb = self.dropout(encoder_emb)  # Apply dropout after embedding
            encoder_emb_pos = self.grid_positional_encoding(encoder_emb)

            # Encoder
            memory = self.encoder(encoder_emb_pos.transpose(0, 1), src_key_padding_mask=src_mask)
            memories.append(memory)
        
        encoder_mem = torch.mean(torch.stack(memories), dim=0)

        # Embedding and positional encoding for output grids
        decoder_emb = self.grid_embedding(decoder_in)
        decoder_emb = self.dropout(decoder_emb)
        decoder_emb_pos = self.grid_positional_encoding(decoder_emb)

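        # Length of the decoder input up to and including <SEP> (the known prefix)
        sep_positions = (decoder_in == SEP_TOKEN).nonzero(as_tuple=True)[-1]
        decoder_mem_length = int(sep_positions[0]) + 1

        # tgt_mask: the known prefix (<SEP> and everything before it) is fully visible;
        # the remaining positions are masked causally (upper-triangular -inf)
        if tgt_mask is None:
            tgt_mask = self.generate_upper_triangular_mask(
                decoder_emb_pos.size(1), decoder_mem_length
            ).to(decoder_emb_pos.device)
            log_message(f"tgt_mask shape: {tgt_mask.shape}")

        # Decoder (sequence-first layout: [seq_len, batch, d_model])
        # The state at position t predicts token t + 1, so the predictions for the
        # tokens after <SEP> come from positions decoder_mem_length - 1 .. seq_len - 2
        transformer_output = self.decoder(
            decoder_emb_pos.transpose(0, 1),
            encoder_mem,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=src_mask,
        )[decoder_mem_length - 1 : -1]
        transformer_output = self.dropout(transformer_output)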

        # Logits for categories
        output = self.fc_out(transformer_output.transpose(0, 1))

        return output

    # TODO: rewrite predict
    def predict(self, input_grid):
        self.eval()
        with torch.no_grad():
            # as_tensor avoids an extra copy (and a warning) when input_grid is already a tensor
            input_grid_tensor = (
                torch.as_tensor(input_grid, dtype=torch.long).unsqueeze(0).to(DEVICE)
            )

            input_emb = self.grid_embedding(input_grid_tensor)
            log_message(f"input_emb.shape: {input_emb.shape} ")
            grid_emb = self.grid_positional_encoding(input_emb)
            log_message(f"grid_emb.shape: {grid_emb.shape}")
            memory = self.encoder(grid_emb.transpose(0, 1))
            log_message(f"memory shape: {memory.shape}")

            # Initialize the output grid with the <start> token
            output_grids = torch.full(
                (input_grid_tensor.size(0), 1),
                START_TOKEN,
                dtype=torch.long,
                device=DEVICE,
            )

            # Decode the output grid greedily, one token at a time
            for _ in range(1, grid_emb.size(1)):
                output_emb = self.grid_embedding(output_grids)
                output_emb = self.grid_positional_encoding_inference(output_emb)
                # Causal mask (-inf above the diagonal) so each position only sees earlier tokens
                seq_len = output_grids.size(1)
                causal_mask = torch.triu(
                    torch.full((seq_len, seq_len), float('-inf'), device=DEVICE), diagonal=1
                )
                transformer_output = self.decoder(
                    output_emb.transpose(0, 1), memory, tgt_mask=causal_mask
                )
                output_logits = self.fc_out(transformer_output.transpose(0, 1))

                next_token = torch.argmax(output_logits[:, -1, :], dim=-1, keepdim=True)
                output_grids = torch.cat([output_grids, next_token], dim=1)

                if next_token.item() == END_TOKEN:
                    break

            output_grids = output_grids.squeeze().cpu().numpy()

            return output_grids
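A quick way to check a decoder mask is a leak test: perturb one token and confirm that no earlier position's output changes. The sketch below builds the prefix-plus-causal pattern that `generate_upper_triangular_mask` is meant to produce (the first `m` tokens fully visible, everything after them causal) and runs that test on a tiny `nn.TransformerDecoder` with the same sequence-first layout as the notebook. `prefix_causal_mask` and all sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn


def prefix_causal_mask(n: int, m: int) -> torch.Tensor:
    """(n, n) float mask: -inf where key j is after the prefix (j >= m) and in the future (j > i)."""
    idx = torch.arange(n)
    future = idx.unsqueeze(0) > idx.unsqueeze(1)   # [i, j] -> j > i
    after_prefix = (idx >= m).unsqueeze(0)         # [i, j] -> j >= m
    return torch.zeros(n, n).masked_fill(future & after_prefix, float("-inf"))


print((prefix_causal_mask(5, 2) == 0).int().tolist())
# [[1, 1, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]

torch.manual_seed(0)
d_model, n, m = 16, 8, 3
layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=2, dim_feedforward=32, dropout=0.0)
decoder = nn.TransformerDecoder(layer, num_layers=2).eval()

tgt = torch.randn(n, 1, d_model)        # (seq_len, batch, d_model), batch_first=False
memory = torch.randn(10, 1, d_model)
mask = prefix_causal_mask(n, m)

with torch.no_grad():
    out_a = decoder(tgt, memory, tgt_mask=mask)
    tgt_b = tgt.clone()
    tgt_b[5] += torch.randn(1, d_model)  # perturb one token after the prefix
    out_b = decoder(tgt_b, memory, tgt_mask=mask)

# Positions before the perturbed token must be unaffected; position 5 itself changes
unchanged = torch.allclose(out_a[:5], out_b[:5], atol=1e-5)
changed = not torch.allclose(out_a[5], out_b[5])
print(unchanged, changed)  # True True
```

Running the same test with a candidate mask that lets earlier positions see later ones would make `unchanged` come out `False`.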

In [None]:
def add_start_end_tokens(grid):
    return [START_TOKEN] + grid + [END_TOKEN]


def remove_start_end_tokens(grid):
    return grid[1:-1]


def add_special_tokens_to_grids(grid1, grid2):
    return [START_TOKEN] + grid1 + [SEP_TOKEN] + grid2 + [END_TOKEN]
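For two 30x30 grids, `add_special_tokens_to_grids` yields 1 + 900 + 1 + 900 + 1 = 1803 tokens, with `<SEP>` at index 901. The toy example below uses 2x2 grids to make the teacher-forcing alignment visible: under a causal mask the decoder state at position t can see token t, so it is trained to predict token t + 1, and the targets after `<SEP>` line up with the states that start one position earlier, at `<SEP>` itself. Reading logits from the same positions as the targets would let each state see the token it is scored on, which is one thing worth ruling out when training looks much better than autoregressive inference. The block repeats the notebook's token values and a copy of `add_special_tokens_to_grids`; the grids are made up.

```python
import torch

# Same values as the notebook's START_TOKEN, END_TOKEN, and SEP_TOKEN
START_TOKEN, END_TOKEN, SEP_TOKEN = 10, 11, 12


def add_special_tokens_to_grids(grid1, grid2):
    return [START_TOKEN] + grid1 + [SEP_TOKEN] + grid2 + [END_TOKEN]


# Two flattened 2x2 toy grids
seq = torch.tensor(add_special_tokens_to_grids([1, 2, 3, 4], [5, 6, 7, 8]))

sep_index = int((seq == SEP_TOKEN).nonzero()[0])  # position of <SEP>
prefix_len = sep_index + 1                        # tokens known before decoding starts

# The state at position t predicts token t + 1, so the states at positions
# prefix_len - 1 .. len(seq) - 2 predict the tokens at prefix_len .. len(seq) - 1
state_positions = list(range(prefix_len - 1, len(seq) - 1))
targets = seq[prefix_len:].tolist()
print(sep_index, state_positions, targets)  # 5 [5, 6, 7, 8, 9] [5, 6, 7, 8, 11]

full_len = len(add_special_tokens_to_grids([0] * 900, [0] * 900))
print(full_len)  # 1803
```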

In [None]:
def prepare_data_with_tokens(
    input_grids, output_grids, start_token=START_TOKEN, end_token=END_TOKEN
):
    input_grids_with_tokens = []
    output_grids_with_tokens = []

    for input_grid, output_grid in zip(input_grids, output_grids):
        # Flatten the input and output grids
        input_grid_flat = input_grid.flatten().tolist()
        output_grid_flat = output_grid.flatten().tolist()

        # Add start and end tokens
        input_grid_with_tokens = add_start_end_tokens(input_grid_flat)
        output_grid_with_tokens = add_start_end_tokens(output_grid_flat)

        # Convert to tensors
        input_grids_with_tokens.append(
            torch.tensor(input_grid_with_tokens, dtype=torch.long).to(DEVICE)
        )
        output_grids_with_tokens.append(
            torch.tensor(output_grid_with_tokens, dtype=torch.long).to(DEVICE)
        )

    return input_grids_with_tokens, output_grids_with_tokens

In [None]:
def prepare_encoder_data_with_tokens(input_grids, output_grids):
    all_data = []

    for input_grid, output_grid in zip(input_grids, output_grids):
        # Flatten the input and output grids
        input_grid_flat = input_grid.flatten().tolist()
        output_grid_flat = output_grid.flatten().tolist()

        # Wrap the two grids as <START> input <SEP> output <END>
        tokens = add_special_tokens_to_grids(input_grid_flat, output_grid_flat)

        # Convert to tensors
        all_data.append(torch.tensor(tokens, dtype=torch.long).to(DEVICE))

    return all_data

In [None]:
def train_model(dataset, model, criterion, optimizer, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        # `dataset` is the task dictionary returned by load_data: {task_id: task}
        for step, data in enumerate(dataset.values()):
            # Encoder data: every demonstration pair, zoomed to 30x30 so that all
            # token sequences have the same length
            encoder_inputs = [custom_zoom(np.array(x['input'], dtype=np.int32)) for x in data['train']]
            encoder_outputs = [custom_zoom(np.array(x['output'], dtype=np.int32)) for x in data['train']]

            # List[num_samples, seq_length]
            encoder_tokens = prepare_encoder_data_with_tokens(
                encoder_inputs, encoder_outputs
            )

            # Decoder data: the first test pair, laid out as <START> input <SEP> output <END>
            test_pair = data['test'][0]
            decoder_tokens = prepare_encoder_data_with_tokens(
                [custom_zoom(np.array(test_pair['input'], dtype=np.int32))],
                [custom_zoom(np.array(test_pair['output'], dtype=np.int32))],
            )[0]
            # Targets: everything after <SEP> (the 900 output cells followed by <END>)
            sep_index = int((decoder_tokens == SEP_TOKEN).nonzero()[0])
            targets = decoder_tokens[sep_index + 1 :]

            optimizer.zero_grad()

            output_logits = model(encoder_tokens, decoder_tokens)
            if step % 20 == 0:
                predictions = torch.argmax(output_logits, dim=-1).cpu().detach().numpy()
                actual_outputs = targets.cpu().numpy()

                # Drop the trailing <END> position and show the 30x30 grids side by side
                paint(
                    predictions[0][:900].reshape(30, 30),
                    actual_outputs[:900].reshape(30, 30),
                )

            # Compute loss using the logits and true class indices
            loss = criterion(
                output_logits.reshape(-1, model.num_categories), targets.reshape(-1)
            )  # Flatten to match shapes

            loss.backward()
            optimizer.step()

            if step % 10 == 0:  # Print progress every 10 steps
                print(
                    f'Epoch [{epoch + 1}/{num_epochs}], Step [{step}/{len(dataset)}], Loss: {loss.item():.4f}'
                )
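The loss printed during training is a per-token average, while ARC counts a test output as solved only when the entire predicted grid matches the expected one. Tracking both numbers on held-out tasks can make the gap between training and inference easier to see. Below is a sketch with a hypothetical `grid_scores` helper, assuming the prediction has already been mapped back to the task's output shape (comparing padded 30x30 grids would inflate cell accuracy, since most of those cells are padding).

```python
import numpy as np


def grid_scores(pred, target):
    """Return (cell_accuracy, exact_match) for a predicted grid and its target."""
    pred, target = np.asarray(pred), np.asarray(target)
    if pred.shape != target.shape:
        # A wrong output shape can never be an exact match
        return 0.0, False
    return float((pred == target).mean()), bool(np.array_equal(pred, target))


target = np.array([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
almost = target.copy()
almost[0, 0] = 0  # one wrong cell out of nine

print(grid_scores(target, target))      # (1.0, True)
print(grid_scores(almost, target))      # cell accuracy 8/9 ≈ 0.889, exact match False
print(grid_scores(target[:2], target))  # (0.0, False)
```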

In [None]:
if KAGGLE:
    train_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_training_challenges.json'
    train_solutions_path = '/kaggle/input/arc-prize-2024/arc-agi_training_solutions.json'
    eval_challenges_path = '/kaggle/input/arc-prize-2024/arc-agi_evaluation_challenges.json'
    eval_solutions_path = '/kaggle/input/arc-prize-2024/arc-agi_evaluation_solutions.json'
    submission_path = '/kaggle/working/submission.json'
else:
    train_challenges_path = './data/arc-agi_training_challenges.json'
    train_solutions_path = './data/arc-agi_training_solutions.json'
    eval_challenges_path = './data/arc-agi_evaluation_challenges.json'
    eval_solutions_path = './data/arc-agi_evaluation_solutions.json'
    submission_path = 'submission.json'

train_data = load_data(train_challenges_path, train_solutions_path)
val_data = load_data(eval_challenges_path, eval_solutions_path)

# train_task_ids = list(challenges.keys())
# task_id_to_index = {task_id: idx for idx, task_id in enumerate(train_task_ids)}

# input_grids, output_grids, dimensions = prepare_data(challenges, task_id_to_index)

# input_grids, output_grids, task_ids = prepare_data(challenges, task_id_to_index)
# val_output_grids, val_task_ids_indexes, val_task_ids = prepare_validation_data(solutions, task_id_to_index)

num_categories = NUM_CLASSES  # Number of categories (0-9)

In [None]:
# from torch.utils.data import DataLoader, TensorDataset

# # Create the dataset
# dataset = TensorDataset(input_grids, output_grids)

# # Shuffle the dataset before splitting
# indices = torch.randperm(len(dataset))

# # Split the dataset into training and validation sets (80-20 split)
# train_size = int(0.95 * len(indices))
# train_indices = indices[:train_size]
# val_indices = indices[train_size:]

# # Create subsets
# train_dataset = torch.utils.data.Subset(dataset, train_indices)
# val_dataset = torch.utils.data.Subset(dataset, val_indices)

# # Create DataLoaders for training and validation
# batch_size = 4
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
def reset_cuda_memory():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

In [None]:
# Define the model, loss function, and optimizer
num_categories = 10

if 'model' in globals():
    try:
        del model
        reset_cuda_memory()
    except:
        pass

model = CustomTransformerCategorical(
    256,  # Match the embedding dimension
    num_categories,
    nhead=4,  # 8,6,6
    num_encoder_layers=4,
    num_decoder_layers=4,
    dim_feedforward=2048,
    dropout_rate=0.4,
).to(DEVICE)

In [None]:
criterion = nn.CrossEntropyLoss()  # For categorical output
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

In [None]:
num_epochs = 1
train_model(train_data, model, criterion, optimizer, num_epochs)

In [None]:
MODEL_PATH = (
    '/kaggle/working/arc_categorical_avg_model.pth' if KAGGLE else 'arc_categorical_avg_model.pth'
)
torch.save(model.state_dict(), MODEL_PATH)

In [None]:
saved_model = CustomTransformerCategorical(
    256,  # Match the embedding dimension
    num_categories,
    nhead=4,  # 8,6,6
    num_encoder_layers=4,
    num_decoder_layers=4,
    dim_feedforward=2048,
).to(DEVICE)
saved_model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
saved_model.to(DEVICE)

In [None]:
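# Task ids (and their indices) for the training challenges loaded earlier
train_task_ids = list(train_data.keys())
task_id_to_index = {task_id: idx for idx, task_id in enumerate(train_task_ids)}


def inference(task_id):
    test_grid = train_data[task_id]['test'][0]['input']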
    test_grid = np.array(test_grid, dtype=np.int32)
    original_shape = test_grid.shape
    test_grid_padded = custom_zoom(test_grid)

    paint(test_grid_padded)

    t_id = task_id_to_index[task_id]

    print(t_id)

    test_grid_padded = torch.tensor(test_grid_padded).to(DEVICE)

    print(f"test_grid_padded shape: {test_grid_padded.shape}")
    output = saved_model.predict(test_grid_padded.reshape(1, 30, 30))
    return output


for i in range(13, 14, 1):
    task_id = train_task_ids[i]
    t_id = task_id_to_index[task_id]
    print(t_id)
    output = inference(task_id)
    paint(output.reshape(30, 30))