In [1]:
import os
import logging

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Configure the root logger: INFO level, timestamped "<time> - <level> - <message>" lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger()

# Ensure the default checkpoint directory exists (no error if it is already there).
os.makedirs('checkpoints', exist_ok=True)

class ArcDataset(Dataset):
    """Dataset that yields two views of every stored sample for SimCLR training.

    Args:
        data: Indexable, sized collection of samples (anything supporting
            ``len(data)`` and ``data[idx]``).
        transform: Optional callable applied to a sample to build the second
            view. When it is falsy (e.g. ``None``) both views are the raw sample.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(view_1, view_2)`` for the sample at ``idx``.

        ``view_1`` is always the untouched sample; ``view_2`` is the transformed
        sample when a transform is configured, otherwise the sample itself.
        """
        sample = self.data[idx]
        if not self.transform:
            # No augmentation configured: the pair is the same sample twice.
            return sample, sample
        # TODO: Do we want to transform the original as well (i.e. use
        # self.transform(sample) for the first view too)? If so, how do we know
        # that both views should still be close in embedding space?
        return sample, self.transform(sample)

# Transformer Encoder
class TransformerEncoder(nn.Module):
    """Pixel-level transformer encoder that maps an image to one feature vector.

    Every scalar element of the input is treated as its own patch
    (``patch_size == 1``), linearly embedded to ``d_model`` dimensions,
    optionally offset by a learned positional encoding, passed through a stack
    of transformer encoder layers and finally mean-pooled over the patch axis.

    Args:
        img_size: Side length of the (square) input; the model expects
            ``img_size * img_size`` scalar elements per sample.
        d_model: Embedding / model width.
        num_layers: Number of stacked ``nn.TransformerEncoderLayer`` blocks.
        num_heads: Attention heads per layer.
        ff_dim: Hidden width of each layer's feed-forward network.
        dropout: Dropout probability used inside the encoder layers.
        use_positional_encoding: When True, add a learned (zero-initialised)
            positional encoding of shape ``(1, img_size * img_size, d_model)``.
    """

    def __init__(self, img_size, d_model, num_layers, num_heads, ff_dim, dropout, use_positional_encoding=True):
        super().__init__()
        self.patch_size = 1  # Each pixel is a patch
        self.num_patches = img_size * img_size

        # Patch embedding: one scalar pixel -> d_model features.
        self.embedding = nn.Linear(self.patch_size, d_model)

        # Learned positional encoding, broadcast over the batch dimension.
        self.use_positional_encoding = use_positional_encoding
        if self.use_positional_encoding:
            self.positional_encoding = nn.Parameter(torch.zeros(1, self.num_patches, d_model))

        # forward() feeds (batch, patches, d_model) tensors. Without
        # batch_first=True the layer would interpret dim 0 as the sequence axis
        # and attend across the samples of a batch instead of across patches.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into the patch axis. With positional
                encoding enabled each sample must contain exactly
                ``img_size * img_size`` elements.

        Returns:
            Tensor of shape ``(batch_size, d_model)``.
        """
        # Flatten every sample into a sequence of single-pixel patches:
        # (batch_size, num_patches, patch_size). reshape (rather than view)
        # also accepts non-contiguous inputs.
        x = x.reshape(x.size(0), -1, self.patch_size)

        # Linear embedding -> (batch_size, num_patches, d_model)
        x = self.embedding(x)

        # Add positional encoding if enabled (out-of-place, so the embedding
        # output is not mutated).
        if self.use_positional_encoding:
            x = x + self.positional_encoding

        # Self-attention over the patch axis of each sample.
        x = self.transformer_encoder(x)

        # Global average pooling over patches.
        return x.mean(dim=1)

# Define the projection head
class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) that projects encoder features.

    Args:
        input_dim: Size of the incoming feature vectors.
        output_dim: Size of the projected vectors.
        inter_dim: Width of the hidden layer. Defaults to 128, the value that
            was previously hard-coded.
    """

    def __init__(self, input_dim, output_dim, inter_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, inter_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(inter_dim, output_dim)

    def forward(self, x):
        """Project ``x`` of shape ``(..., input_dim)`` to ``(..., output_dim)``."""
        return self.fc2(self.relu(self.fc1(x)))

# Define the SimCLR model with transformer encoder
class SimCLR(nn.Module):
    """SimCLR model: an encoder followed by a projection head.

    Args:
        transformer_encoder: Encoder module. It must expose an ``embedding``
            attribute with ``out_features``; that value is used as the input
            width of the projection head.
        projection_dim: Dimensionality of the projected output vectors.
    """

    def __init__(self, transformer_encoder, projection_dim):
        super().__init__()
        self.encoder = transformer_encoder
        # The head's input width is read from the encoder's embedding layer.
        feature_dim = transformer_encoder.embedding.out_features
        self.projection_head = ProjectionHead(feature_dim, projection_dim)

    def forward(self, x):
        """Encode ``x`` and return its projection."""
        return self.projection_head(self.encoder(x))

# NT-Xent (normalized temperature-scaled cross-entropy) loss used by SimCLR
class NTXentLoss(nn.Module):
    """NT-Xent contrastive loss over two batches of paired embeddings.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair; every
    other embedding in the combined ``2 * B`` batch acts as a negative.

    Args:
        batch_size: Nominal batch size. Stored for backward compatibility only;
            ``forward`` uses the actual size of its inputs, so a smaller final
            batch is handled correctly.
        temperature: Divisor applied to the cosine similarities.
    """

    def __init__(self, batch_size, temperature):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, z_i, z_j):
        """Compute the mean NT-Xent loss.

        Args:
            z_i: Embeddings of the first views, shape ``(B, D)``.
            z_j: Embeddings of the second views, shape ``(B, D)``.

        Returns:
            Scalar tensor with the loss averaged over all ``2 * B`` anchors.

        Raises:
            ValueError: If ``z_i`` and ``z_j`` do not have the same shape.
        """
        if z_i.shape != z_j.shape:
            raise ValueError(
                f"z_i and z_j must have the same shape, got {tuple(z_i.shape)} and {tuple(z_j.shape)}"
            )

        # Use the real batch size so partial batches work.
        batch_size = z_i.size(0)
        n = 2 * batch_size

        # Concatenate along the batch dimension and L2-normalise each row so
        # the dot products below are true cosine similarities.
        z = nn.functional.normalize(torch.cat((z_i, z_j), dim=0), dim=1)

        # Temperature-scaled cosine similarity matrix, shape (2B, 2B).
        sim_matrix = torch.mm(z, z.t()) / self.temperature

        # Exclude self-similarities from the softmax by masking the diagonal.
        self_mask = torch.eye(n, dtype=torch.bool, device=sim_matrix.device)
        sim_matrix = sim_matrix.masked_fill(self_mask, float("-inf"))

        # The positive for anchor k (< B) is k + B, and for anchor k + B it is k.
        idx = torch.arange(batch_size, device=z.device)
        labels = torch.cat((idx + batch_size, idx))

        # Cross-entropy over each row with the positive partner as the target.
        return self.criterion(sim_matrix, labels)

# TODO: Replace with relevant ARC-AGI augmentations
augmentation = transforms.Compose(
    [
        # Random crop covering 80-100% of the image area, resized to 30x30.
        transforms.RandomResizedCrop(size=30, scale=(0.8, 1.0)),
        # Mirror left/right with the default probability of 0.5.
        transforms.RandomHorizontalFlip(p=0.5),
        # Apply colour jitter to 80% of the samples.
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)],
            p=0.8,
        ),
        # Convert 20% of the samples to grayscale.
        transforms.RandomGrayscale(p=0.2),
        # Convert to a tensor, then normalise with mean 0.5 and std 0.5.
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

def save_checkpoint(epoch, model, optimizer, loss, checkpoint_dir='checkpoints'):
    """Write a training checkpoint to ``<checkpoint_dir>/checkpoint_epoch_<epoch>.pth``.

    The file stores the epoch number, the model and optimizer state dicts and
    the given loss value.

    Args:
        epoch: Epoch number recorded in the checkpoint and used in the file name.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        loss: Loss value stored alongside the weights.
        checkpoint_dir: Output directory; created if it does not exist.
    """
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test and matches the directory setup at the top of the file.
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }
    torch.save(checkpoint, checkpoint_path)
    logger.info(f"Checkpoint saved at {checkpoint_path}")

In [None]:
# Training hyperparameters.
BATCH_SIZE = 256  # samples per DataLoader batch; also passed to NTXentLoss below
NUM_EPOCHS = 100  # number of passes over the dataset
CHECKPOINT_INTERVAL = 10  # write a checkpoint every this many epochs

# Data pipeline: each dataset item is a (raw sample, augmented sample) pair.
# NOTE(review): `images` is still a placeholder; ArcDataset calls len() on it and
# indexes it, so this cell cannot run until real data is loaded here.
images = None  # TODO: load images here
dataset = ArcDataset(images, transform=augmentation)
# NOTE(review): drop_last is not set, so the final batch of an epoch can be
# smaller than BATCH_SIZE - confirm the loss handles a partial batch.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)

# Initialize model, loss, and optimizer with transformer encoder
# The encoder treats a 30x30 input as 900 single-pixel patches of width d_model.
# NOTE(review): assumes every sample has exactly 30 * 30 elements (i.e. a single
# channel) so it matches the positional encoding - confirm against the data.
transformer_encoder = TransformerEncoder(
    img_size=30,
    d_model=128,
    num_layers=4,
    num_heads=4,
    ff_dim=256,
    dropout=0.1,
    use_positional_encoding=True
).to(device)

# Wrap the encoder with a projection head that outputs 128-dimensional vectors.
model = SimCLR(transformer_encoder, projection_dim=128).to(device)
# Contrastive loss, configured with the nominal batch size and temperature 0.5.
criterion = NTXentLoss(batch_size=BATCH_SIZE, temperature=0.5)
optimizer = optim.Adam(model.parameters(), lr=0.0003) # TODO: experiment w/ HPs like weight decay
# Multiplies the learning rate by 0.1 every 30 scheduler steps. The only
# scheduler.step() call in the training loop is currently commented out, so
# this schedule has no effect yet.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# Training loop: for every epoch, optimise the contrastive loss over all
# batches, log the mean batch loss and periodically write a checkpoint.
for epoch in range(NUM_EPOCHS):
    # Sum of per-batch losses for this epoch.
    running_loss = 0.0
    for img1, img2 in dataloader:
        # img1 / img2 are the two views of each sample produced by ArcDataset.
        img1, img2 = img1.to(device), img2.to(device)
        
        optimizer.zero_grad()

        # Project both views with the same model (shared weights).
        z_i = model(img1)
        z_j = model(img2)
        
        # NOTE(review): criterion was built with batch_size=BATCH_SIZE while the
        # last batch of an epoch may be smaller - confirm this is handled.
        loss = criterion(z_i, z_j)
        loss.backward()
        
        optimizer.step()

        running_loss += loss.item()
    
    # Mean loss per batch; raises ZeroDivisionError if the dataloader is empty.
    avg_loss = running_loss / len(dataloader)

    # TODO: Experiment with scheduler
    # scheduler.step()

    # Checkpoints are numbered with the 1-based epoch.
    if (epoch + 1) % CHECKPOINT_INTERVAL == 0:
        save_checkpoint(epoch + 1, model, optimizer, avg_loss)
        
    logger.info(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {avg_loss:.4f}")

logger.info("Training complete.")