# Training transformer

## Import packages

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, GPT2Config, set_seed
from datasets import load_dataset
from typing import Dict, Any, Optional


In [None]:
# Seed the run with 0 through the `set_seed` helper imported from transformers.
# NOTE(review): presumably this seeds the Python, NumPy and PyTorch RNGs so runs
# are repeatable — confirm against the transformers documentation.
set_seed(0)

In [None]:
from typing import List, Tuple, Union
import torch
from torch.utils.data import Dataset

class PixelSequenceDataset(Dataset):
    """Dataset of pixel-index sequences for next-token (autoregressive) modelling.

    Each item of ``data`` is a list of integer colour indices. What
    ``__getitem__`` returns depends on ``mode``:

    * ``"train"``: ``(input_ids, labels)`` where ``labels`` is ``input_ids``
      shifted left by one position (teacher forcing).
    * ``"dev"``: ``(input_ids, labels)`` where ``labels`` holds the last
      ``dev_target_len`` tokens and ``input_ids`` everything before them.
    * ``"test"``: ``input_ids`` only (the whole stored sequence).

    Args:
        data: Indexable collection of integer sequences.
        mode: One of ``"train"``, ``"dev"`` or ``"test"``.
        dev_target_len: Number of trailing tokens held out as labels in
            ``"dev"`` mode. Defaults to 160, the value previously hard-coded.

    Raises:
        ValueError: If ``dev_target_len`` is negative (at construction), or if
            ``mode`` is not recognised (when an item is fetched).
    """

    def __init__(self, data: List[List[int]], mode: str = "train", dev_target_len: int = 160):
        if dev_target_len < 0:
            raise ValueError(f"dev_target_len must be non-negative, got {dev_target_len}.")
        self.data = data
        self.mode = mode
        self.dev_target_len = dev_target_len

    def __len__(self) -> int:
        """Return the number of stored sequences."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Return the tensors for sequence ``idx`` according to ``self.mode``."""
        sequence = self.data[idx]  # colour-index sequence at the requested position

        if self.mode == "train":
            # Decoder input is everything except the last pixel; the target is
            # everything except the first one. Sequences have a fixed length,
            # so no start/end tokens and no tokenizer are required.
            input_ids = torch.tensor(sequence[:-1], dtype=torch.long)
            labels = torch.tensor(sequence[1:], dtype=torch.long)
            return input_ids, labels

        if self.mode == "dev":
            # Hold out the trailing `dev_target_len` pixels for validation.
            # Clamping at 0 keeps the previous behaviour for sequences shorter
            # than the hold-out (empty input, whole sequence as labels) and
            # also makes a hold-out of 0 behave sensibly.
            split = max(len(sequence) - self.dev_target_len, 0)
            input_ids = torch.tensor(sequence[:split], dtype=torch.long)
            labels = torch.tensor(sequence[split:], dtype=torch.long)
            return input_ids, labels

        if self.mode == "test":
            # Test items are returned as-is; there are no labels.
            return torch.tensor(sequence, dtype=torch.long)

        raise ValueError(f"Invalid mode: {self.mode}. Choose from 'train', 'dev', or 'test'.")
    
    
    

## Download dataset

In [None]:
# Fetch the Pokémon pixel-sequence splits from the Hugging Face Hub.
pokemon_dataset = load_dataset("lca0503/ml2025-hw4-pokemon")

# Fetch the colormap from the Hugging Face Hub. It is the lookup table from a
# colour index to a colour; `pixel_to_image` treats each entry as [R, G, B].
colormap = list(load_dataset("lca0503/ml2025-hw4-colormap")["train"]["color"])

# One class per colormap entry.
num_classes = len(colormap)

# Batch size shared by all three loaders.
batch_size = 16


def _build_split(split: str, shuffle: bool):
    """Return ``(dataset, dataloader)`` for one split; the split name doubles as the dataset mode."""
    dataset = PixelSequenceDataset(pokemon_dataset[split]["pixel_color"], mode=split)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataset, loader


# Only the training loader shuffles; dev and test keep the stored order.
train_dataset, train_dataloader = _build_split("train", shuffle=True)
dev_dataset, dev_dataloader = _build_split("dev", shuffle=False)
test_dataset, test_dataloader = _build_split("test", shuffle=False)

In [None]:
def pixel_to_image(pixel_color: List[int], colormap: List[List[int]]) -> Image.Image:
    """
    Converts a list of pixel indices into a 20x20 RGB image using a colormap.

    Sequences shorter than 400 entries are padded with index 0. The padding is
    applied to a copy, so the caller's ``pixel_color`` list is left untouched
    (the previous implementation appended to it in place).

    Args:
        pixel_color (List[int]): A list of pixel indices representing colors.
        colormap (List[List[int]]): A list where each index maps to an RGB color [R, G, B].

    Returns:
        Image.Image: A PIL Image object representing the reconstructed image.

    Raises:
        ValueError: Propagated from the NumPy reshape when more than 400
            indices are supplied.
    """
    # Work on a copy and pad it up to 400 entries with index 0.
    # `[0] * negative` is an empty list, so longer inputs are left unchanged here.
    padded = list(pixel_color)
    padded.extend([0] * (400 - len(padded)))

    # Map pixel indices to actual RGB colors using the colormap
    pixel_data = [colormap[pixel] for pixel in padded]

    # Convert to numpy array and reshape to 20x20x3 (RGB image)
    image_array = np.array(pixel_data, dtype=np.uint8).reshape(20, 20, 3)

    # Create a PIL Image from the array
    return Image.fromarray(image_array)

def show_images(images: List[Image.Image]) -> None:
    """
    Render up to 96 images on a fixed 6x16 Matplotlib grid.

    Args:
        images (List[Image.Image]): PIL images to draw; anything past the
            96th entry is ignored.

    Returns:
        None
    """
    shown = min(96, len(images))  # the grid has 6 * 16 = 96 cells

    # Build the 6-row by 16-column figure and walk its cells in flat order
    _, grid = plt.subplots(6, 16, figsize=(16, 6))
    for cell_index, cell in enumerate(grid.flatten()):
        if cell_index < shown:
            cell.imshow(images[cell_index])
        # Axes are hidden for every cell, whether it holds an image or not
        cell.axis('off')

    plt.tight_layout()  # keep neighbouring cells from overlapping
    plt.show()

In [None]:
# Render every training sample as an image and display the grid
train_rows = pokemon_dataset["train"]
train_images = [pixel_to_image(row["pixel_color"], colormap) for row in train_rows]
show_images(train_images)

In [None]:
# Render every test sample as an image and display the grid
test_rows = pokemon_dataset["test"]
test_images = [pixel_to_image(row["pixel_color"], colormap) for row in test_rows]
show_images(test_images)

## Prepare model

### Model Configuration
Here, we define the model configuration, including the architecture and key hyperparameters such as the number of attention heads, layers, embedding size, and more.
*   Hint 1: Adjust hyperparameters here for improved performance.
*   Hint 2: Experiment with different model architectures, such as Llama, Mistral, or Qwen, to enhance performance.
  * [LlamaConfig](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaConfig)
  * [MistralConfig](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralConfig)
  * [Qwen2Config](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Config)

In [None]:
# Define GPT-2 model configuration as a dictionary.
# The dictionary is handed to `GPT2Config.from_dict` below; `vocab_size` is tied
# to `num_classes`, i.e. one token per colormap entry.
gpt2_config = {
    "activation_function": "gelu_new",    # Activation function used in the model
    "architectures": ["GPT2LMHeadModel"],  # Specifies the model type
    "attn_pdrop": 0.1,            # Dropout rate for attention layers
    "embd_pdrop": 0.1,            # Dropout rate for embeddings
    "initializer_range": 0.02,        # Standard deviation for weight initialization
    "layer_norm_epsilon": 1e-05,       # Small constant to improve numerical stability in layer norm
    "model_type": "gpt2",           # Type of model
    "n_ctx": 128,               # Context size. NOTE(review): disagrees with n_positions (400) while generation below runs to 400 tokens — confirm this key is ignored by the installed transformers version, or align it
    "n_embd": 64,              # Embedding size
    "n_head": 2,               # Number of attention heads
    "n_layer": 2,              # Number of transformer layers
    "n_positions": 400,           # Maximum number of token positions
    "resid_pdrop": 0.1,           # Dropout rate for residual connections
    "vocab_size": num_classes,       # Number of unique tokens in vocabulary
    "pad_token_id": None,          # Padding token ID (None means no padding token)
    "eos_token_id": None,          # End-of-sequence token ID (None means not explicitly defined)
}

# Load GPT-2 model configuration from dictionary
config = GPT2Config.from_dict(gpt2_config)

In [None]:
# Build the causal language model from the configuration defined above
model = AutoModelForCausalLM.from_config(config)

print(model)

# Add up the element counts of all parameters that require gradients
trainable_params = 0
for param in model.parameters():
    if param.requires_grad:
        trainable_params += param.numel()

print(f"Trainable Parameters: {trainable_params:,}")

## Train and Inference

### Training Arguments
Here, we define the number of epochs for training, the learning rate, the optimizer, and the loss function.
*   Hint 3: Adjust the number of epochs and learning rate here to improve performance.

In [None]:
# Training hyperparameters
epochs = 50               # number of passes over the training loader
learning_rate = 1e-3      # step size handed to the optimizer
save_dir = "checkpoints"  # directory that receives model checkpoints

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Cross-entropy over token classes, optimised with AdamW (weight decay 0.1)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=learning_rate, weight_decay=0.1)

In [None]:
def save_model(model: torch.nn.Module, optimizer: torch.optim.Optimizer, epoch: int, loss: float, save_dir: str, filename: str = "best_model.pth") -> None:
    """
    Saves the model state, optimizer state, current epoch, and loss to a specified directory.

    The target directory is created if it does not exist yet, so the function
    no longer depends on the caller having prepared it.

    Args:
        model (torch.nn.Module): The PyTorch model to be saved.
        optimizer (torch.optim.Optimizer): The optimizer whose state will be saved.
        epoch (int): The zero-based epoch index; it is stored and printed as ``epoch + 1``.
        loss (float): The current loss value to track model performance.
        save_dir (str): The directory where the model checkpoint will be saved.
        filename (str, optional): The name of the file to save the model. Defaults to "best_model.pth".

    Returns:
        None
    """
    # Make sure the destination exists; an empty `save_dir` means the current
    # directory, which `os.makedirs` would reject.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Construct the full path for saving the model checkpoint
    save_path = os.path.join(save_dir, filename)

    # Save the model, optimizer state, and additional metadata (epoch and loss)
    torch.save({
        'epoch': epoch + 1,                # Save epoch + 1 for easier tracking
        'model_state_dict': model.state_dict(),       # Save model weights
        'optimizer_state_dict': optimizer.state_dict(),  # Save optimizer state (important for resuming training)
        'loss': loss                     # Save the current loss value
    }, save_path)

    # Print a confirmation message indicating the model has been saved
    print(f"Model saved at {save_path} (Loss: {loss:.4f}, Epoch: {epoch + 1})")

### Train

We save the checkpoint with the lowest training loss since validation set reconstruction accuracy doesn't directly reflect the model's image generation quality.
*   Hint 4: Train a classifier to check if an image looks like a Pokémon or not. (Optional)

In [None]:
# Create save directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)
# Initialize best loss as positive infinity for comparison during model checkpointing
best_loss: float = float('inf')
# Move model to the appropriate device (GPU or CPU)
model.to(device)

# Training Loop: one training pass, one validation pass, then checkpoint on
# the lowest average training loss seen so far.
for epoch in range(epochs):
    model.train()  # Set the model to training mode
    epoch_loss = 0.0  # Running sum of per-batch losses for this epoch

    # Iterate over training data batches
    for input_ids, labels in tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}/{epochs}"):
        input_ids, labels = input_ids.to(device), labels.to(device)  # Move data to the same device as the model

        # Forward pass; flatten logits to (tokens, vocab) so they line up with flattened labels
        outputs = model(input_ids=input_ids).logits.view(-1, config.vocab_size)
        labels = labels.view(-1)  # Flatten labels to match logits shape

        # Calculate loss using CrossEntropyLoss
        loss = criterion(outputs, labels)

        # Backpropagation and optimizer step
        optimizer.zero_grad()  # Reset gradients to zero
        loss.backward()     # Compute gradients
        optimizer.step()     # Update model weights

        # Accumulate the loss for the epoch
        epoch_loss += loss.item()

    # Compute average epoch loss (mean of the per-batch losses)
    avg_epoch_loss = epoch_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_epoch_loss:.4f}")

    # Evaluation Loop (Validation)
    model.eval()      # Set the model to evaluation mode
    correct_tokens = 0  # Generated tokens that match the held-out labels
    total_tokens = 0    # Held-out label tokens seen so far

    with torch.no_grad():  # Disable gradient calculation for validation
        # Iterate over validation data batches
        for inputs, labels in tqdm(dev_dataloader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)  # Move validation data to device
            attention_mask = torch.ones_like(inputs)          # Every input position is a real token

            # Continue each prefix up to a total length of 400 tokens
            generated_outputs = model.generate(inputs, attention_mask=attention_mask, max_length=400)

            # Keep only the generated continuation, i.e. as many trailing tokens as there are labels
            generated_tail = generated_outputs[:, -labels.size(1):]

            # Count matches token by token so a smaller final batch is weighted
            # by its true size instead of counting as a full batch
            correct_tokens += (generated_tail == labels).sum().item()
            total_tokens += labels.numel()

    # Compute reconstruction accuracy over all validation tokens (0.0 if there were none)
    avg_accuracy = correct_tokens / total_tokens if total_tokens else 0.0
    print(f"Epoch {epoch + 1}/{epochs}, Reconstruction Accuracy: {avg_accuracy:.4f}")

    # If the current epoch loss is better (lower) than the best loss, save the model
    if avg_epoch_loss < best_loss:
        best_loss = avg_epoch_loss                   # Update best loss
        save_model(model, optimizer, epoch, best_loss, save_dir)  # Save the model with the best loss

Training Epoch 1/50: 100%|██████████| 40/40 [00:00<00:00, 41.69it/s]


Epoch 1/50, Loss: 1.4755


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.67it/s]


Epoch 1/50, Reconstruction Accuracy: 0.3078
Model saved at checkpoints/best_model.pth (Loss: 1.4755, Epoch: 1)


Training Epoch 2/50: 100%|██████████| 40/40 [00:00<00:00, 43.35it/s]


Epoch 2/50, Loss: 1.4650


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s]


Epoch 2/50, Reconstruction Accuracy: 0.3329
Model saved at checkpoints/best_model.pth (Loss: 1.4650, Epoch: 2)


Training Epoch 3/50: 100%|██████████| 40/40 [00:00<00:00, 43.30it/s]


Epoch 3/50, Loss: 1.4609


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s]


Epoch 3/50, Reconstruction Accuracy: 0.3189
Model saved at checkpoints/best_model.pth (Loss: 1.4609, Epoch: 3)


Training Epoch 4/50: 100%|██████████| 40/40 [00:00<00:00, 44.39it/s]


Epoch 4/50, Loss: 1.4569


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s]


Epoch 4/50, Reconstruction Accuracy: 0.3291
Model saved at checkpoints/best_model.pth (Loss: 1.4569, Epoch: 4)


Training Epoch 5/50: 100%|██████████| 40/40 [00:00<00:00, 43.38it/s]


Epoch 5/50, Loss: 1.4531


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.68it/s]


Epoch 5/50, Reconstruction Accuracy: 0.2980
Model saved at checkpoints/best_model.pth (Loss: 1.4531, Epoch: 5)


Training Epoch 6/50: 100%|██████████| 40/40 [00:00<00:00, 43.41it/s]


Epoch 6/50, Loss: 1.4488


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s]


Epoch 6/50, Reconstruction Accuracy: 0.2934
Model saved at checkpoints/best_model.pth (Loss: 1.4488, Epoch: 6)


Training Epoch 7/50: 100%|██████████| 40/40 [00:00<00:00, 43.91it/s]


Epoch 7/50, Loss: 1.4475


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s]


Epoch 7/50, Reconstruction Accuracy: 0.3274
Model saved at checkpoints/best_model.pth (Loss: 1.4475, Epoch: 7)


Training Epoch 8/50: 100%|██████████| 40/40 [00:00<00:00, 43.38it/s]


Epoch 8/50, Loss: 1.4427


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s]


Epoch 8/50, Reconstruction Accuracy: 0.3239
Model saved at checkpoints/best_model.pth (Loss: 1.4427, Epoch: 8)


Training Epoch 9/50: 100%|██████████| 40/40 [00:00<00:00, 45.62it/s]


Epoch 9/50, Loss: 1.4395


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s]


Epoch 9/50, Reconstruction Accuracy: 0.3095
Model saved at checkpoints/best_model.pth (Loss: 1.4395, Epoch: 9)


Training Epoch 10/50: 100%|██████████| 40/40 [00:00<00:00, 42.71it/s]


Epoch 10/50, Loss: 1.4399


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.71it/s]


Epoch 10/50, Reconstruction Accuracy: 0.3264


Training Epoch 11/50: 100%|██████████| 40/40 [00:00<00:00, 43.06it/s]


Epoch 11/50, Loss: 1.4318


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s]


Epoch 11/50, Reconstruction Accuracy: 0.3441
Model saved at checkpoints/best_model.pth (Loss: 1.4318, Epoch: 11)


Training Epoch 12/50: 100%|██████████| 40/40 [00:00<00:00, 43.98it/s]


Epoch 12/50, Loss: 1.4273


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s]


Epoch 12/50, Reconstruction Accuracy: 0.3125
Model saved at checkpoints/best_model.pth (Loss: 1.4273, Epoch: 12)


Training Epoch 13/50: 100%|██████████| 40/40 [00:00<00:00, 43.79it/s]


Epoch 13/50, Loss: 1.4299


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s]


Epoch 13/50, Reconstruction Accuracy: 0.3577


Training Epoch 14/50: 100%|██████████| 40/40 [00:00<00:00, 43.86it/s]


Epoch 14/50, Loss: 1.4268


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.71it/s]


Epoch 14/50, Reconstruction Accuracy: 0.2975
Model saved at checkpoints/best_model.pth (Loss: 1.4268, Epoch: 14)


Training Epoch 15/50: 100%|██████████| 40/40 [00:01<00:00, 39.17it/s]


Epoch 15/50, Loss: 1.4221


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.68it/s]


Epoch 15/50, Reconstruction Accuracy: 0.3272
Model saved at checkpoints/best_model.pth (Loss: 1.4221, Epoch: 15)


Training Epoch 16/50: 100%|██████████| 40/40 [00:01<00:00, 36.51it/s]


Epoch 16/50, Loss: 1.4196


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s]


Epoch 16/50, Reconstruction Accuracy: 0.3290
Model saved at checkpoints/best_model.pth (Loss: 1.4196, Epoch: 16)


Training Epoch 17/50: 100%|██████████| 40/40 [00:01<00:00, 35.31it/s]


Epoch 17/50, Loss: 1.4161


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.68it/s]


Epoch 17/50, Reconstruction Accuracy: 0.3164
Model saved at checkpoints/best_model.pth (Loss: 1.4161, Epoch: 17)


Training Epoch 18/50: 100%|██████████| 40/40 [00:01<00:00, 35.29it/s]


Epoch 18/50, Loss: 1.4109


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s]


Epoch 18/50, Reconstruction Accuracy: 0.3055
Model saved at checkpoints/best_model.pth (Loss: 1.4109, Epoch: 18)


Training Epoch 19/50: 100%|██████████| 40/40 [00:01<00:00, 35.85it/s]


Epoch 19/50, Loss: 1.4103


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.71it/s]


Epoch 19/50, Reconstruction Accuracy: 0.3201
Model saved at checkpoints/best_model.pth (Loss: 1.4103, Epoch: 19)


Training Epoch 20/50: 100%|██████████| 40/40 [00:01<00:00, 35.35it/s]


Epoch 20/50, Loss: 1.4054


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s]


Epoch 20/50, Reconstruction Accuracy: 0.3189
Model saved at checkpoints/best_model.pth (Loss: 1.4054, Epoch: 20)


Training Epoch 21/50: 100%|██████████| 40/40 [00:01<00:00, 35.36it/s]


Epoch 21/50, Loss: 1.4042


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.71it/s]


Epoch 21/50, Reconstruction Accuracy: 0.2882
Model saved at checkpoints/best_model.pth (Loss: 1.4042, Epoch: 21)


Training Epoch 22/50: 100%|██████████| 40/40 [00:01<00:00, 35.63it/s]


Epoch 22/50, Loss: 1.4056


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s]


Epoch 22/50, Reconstruction Accuracy: 0.2552


Training Epoch 23/50: 100%|██████████| 40/40 [00:00<00:00, 42.93it/s]


Epoch 23/50, Loss: 1.3995


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s]


Epoch 23/50, Reconstruction Accuracy: 0.3168
Model saved at checkpoints/best_model.pth (Loss: 1.3995, Epoch: 23)


Training Epoch 24/50: 100%|██████████| 40/40 [00:00<00:00, 43.66it/s]


Epoch 24/50, Loss: 1.3992


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s]


Epoch 24/50, Reconstruction Accuracy: 0.2957
Model saved at checkpoints/best_model.pth (Loss: 1.3992, Epoch: 24)


Training Epoch 25/50: 100%|██████████| 40/40 [00:00<00:00, 43.86it/s]


Epoch 25/50, Loss: 1.3987


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.71it/s]


Epoch 25/50, Reconstruction Accuracy: 0.3231
Model saved at checkpoints/best_model.pth (Loss: 1.3987, Epoch: 25)


Training Epoch 26/50: 100%|██████████| 40/40 [00:00<00:00, 43.46it/s]


Epoch 26/50, Loss: 1.3945


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s]


Epoch 26/50, Reconstruction Accuracy: 0.3259
Model saved at checkpoints/best_model.pth (Loss: 1.3945, Epoch: 26)


Training Epoch 27/50: 100%|██████████| 40/40 [00:00<00:00, 44.49it/s]


Epoch 27/50, Loss: 1.3926


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s]


Epoch 27/50, Reconstruction Accuracy: 0.2769
Model saved at checkpoints/best_model.pth (Loss: 1.3926, Epoch: 27)


Training Epoch 28/50: 100%|██████████| 40/40 [00:00<00:00, 43.95it/s]


Epoch 28/50, Loss: 1.3877


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.69it/s]


Epoch 28/50, Reconstruction Accuracy: 0.2775
Model saved at checkpoints/best_model.pth (Loss: 1.3877, Epoch: 28)


Training Epoch 29/50: 100%|██████████| 40/40 [00:00<00:00, 44.11it/s]


Epoch 29/50, Loss: 1.3845


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.70it/s]


Epoch 29/50, Reconstruction Accuracy: 0.3395
Model saved at checkpoints/best_model.pth (Loss: 1.3845, Epoch: 29)


Training Epoch 30/50: 100%|██████████| 40/40 [00:00<00:00, 43.76it/s]


Epoch 30/50, Loss: 1.3842


Evaluating: 100%|██████████| 5/5 [00:01<00:00,  2.68it/s]


Epoch 30/50, Reconstruction Accuracy: 0.3191
Model saved at checkpoints/best_model.pth (Loss: 1.3842, Epoch: 30)


Training Epoch 31/50: 100%|██████████| 40/40 [00:00<00:00, 44.12it/s]


Epoch 31/50, Loss: 1.3821


Evaluating:  80%|████████  | 4/5 [00:01<00:00,  2.72it/s]

In [None]:
# Restore the weights stored in the best checkpoint and switch to evaluation mode
best_model_path = os.path.join(save_dir, "best_model.pth")
checkpoint = torch.load(best_model_path, weights_only=True, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Batched generation over the test loader; every generated sequence is collected here
results: list = []

with torch.no_grad():  # inference only, no gradients needed
    for inputs in tqdm(test_dataloader, desc="Generating Outputs"):
        inputs = inputs.to(device)                # move the input batch to the model's device
        attention_mask = torch.ones_like(inputs)  # every input position is a real token

        # Continue each prefix up to a total length of 400 tokens
        generated_outputs = model.generate(inputs, attention_mask=attention_mask, max_length=400)

        # Append this batch's sequences as plain Python lists
        results += generated_outputs.cpu().numpy().tolist()

# Write one space-separated sequence per line
output_file: str = "reconstructed_results.txt"
with open(output_file, "w") as f:
    f.writelines(" ".join(str(token) for token in seq) + "\n" for seq in results)

print(f"Reconstructed results saved to {output_file}")

In [None]:
# Render every generated test sequence as an image and display the grid
predicted_images = [pixel_to_image(generated, colormap) for generated in results]
show_images(predicted_images)