# Synthetic Thalamus for Image Classification

This notebook demonstrates how to use the Synthetic Thalamus module for image classification tasks.

In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl

# Add the parent directory to the path
sys.path.append('..')
from core.thalamus import SyntheticThalamus
from core.adapters import ImageAdapter

## Load ImageNet-A Dataset

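ImageNet-A is a dataset of natural adversarial examples that commonly fool classifiers trained on ImageNet. Because ImageNet-A must be downloaded separately, the cell below leaves its loader commented out and uses CIFAR10 as a stand-in for this demo.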

In [None]:
# Preprocessing pipeline: resize, centre-crop to 224x224, convert to a tensor,
# then normalise each channel with the given mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# To run on ImageNet-A instead, point ImageFolder at your local copy:
# dataset = torchvision.datasets.ImageFolder(root='./path/to/imagenet-a', transform=transform)
# dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# CIFAR10 stands in for ImageNet-A in this demo (downloaded to ./data on first run).
cifar10_train = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
cifar10_test = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# Only the training loader shuffles; the test loader keeps a fixed order.
cifar10_train_loader = DataLoader(
    cifar10_train, batch_size=32, shuffle=True, num_workers=2
)
cifar10_test_loader = DataLoader(
    cifar10_test, batch_size=32, shuffle=False, num_workers=2
)

## Define the Thalamus-based Image Classification Model

In [None]:
class _Transpose(nn.Module):
    """Swap two tensor dimensions; usable as a layer inside ``nn.Sequential``.

    ``torch.nn`` ships no ``Transpose`` layer, so this thin wrapper around
    ``Tensor.transpose`` fills that gap for the patch-embedding pipeline.
    """

    def __init__(self, dim0, dim1):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        return x.transpose(self.dim0, self.dim1)


class ThalamusImageClassifier(pl.LightningModule):
    """Patch-token image classifier that routes tokens through a SyntheticThalamus.

    Pipeline: strided convolution -> patch tokens -> thalamus gating ->
    mean over the gated tokens -> MLP classifier.

    Args:
        num_classes: Number of output classes.
        d_model: Width of each patch token.
        patch_size: Side length (in pixels) of each square patch; also the
            stride of the embedding convolution.
        phase_dim: Phase dimension passed to the thalamus. The classifier
            input width is ``d_model + phase_dim``, so both are driven by this
            one value instead of two independent literals.
    """

    def __init__(self, num_classes=10, d_model=128, patch_size=16, phase_dim=16):
        super().__init__()

        # Image adapter. NOTE: forward() below does not route through it; it is
        # kept so the module's attributes and parameters match the original.
        self.d_model = d_model
        self.adapter = ImageAdapter(d_model=self.d_model)

        # Patch embedding: [B, 3, H, W] -> [B, num_patches, d_model]
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(3, self.d_model, kernel_size=patch_size, stride=patch_size),
            nn.Flatten(2),       # [B, d_model, num_patches]
            _Transpose(1, 2),    # [B, num_patches, d_model]
        )

        # Synthetic thalamus for token gating
        self.thalamus = SyntheticThalamus(
            d_model=self.d_model, n_heads=4, k=8, phase_dim=phase_dim, task_dim=64, num_tasks=10
        )

        # Classifier head; input width must match the thalamus output width.
        self.classifier = nn.Sequential(
            nn.Linear(self.d_model + phase_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, task_id):
        """Return class logits for an image batch ``x`` conditioned on ``task_id``."""
        # Convert images to tokens
        tokens = self.to_patch_embedding(x)  # [B, num_patches, d_model]

        # Apply synthetic thalamus
        # (presumably returns [B, k, d_model + phase_dim] -- confirm against core.thalamus)
        gated_tokens = self.thalamus(tokens, task_id)

        # Average over the token dimension
        pooled = gated_tokens.mean(dim=1)

        # Classification
        return self.classifier(pooled)

    def _shared_step(self, batch):
        """Run one batch with task ID 0 and return ``(logits, targets, loss)``."""
        x, y = batch
        # Every sample uses the same fixed task ID (0) in this example.
        task_id = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        logits = self(x, task_id)
        return logits, y, self.loss_fn(logits, y)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        _, _, loss = self._shared_step(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss and top-1 accuracy for one batch."""
        logits, y, loss = self._shared_step(batch)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_acc', acc)
        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters with a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

## Train the Model

In [None]:
# Seed Python, NumPy and PyTorch before the model is built so weight
# initialisation and data shuffling are reproducible across runs.
pl.seed_everything(42, workers=True)

# Initialize model
model = ThalamusImageClassifier(num_classes=10)

# Initialize trainer
# `devices=1` is valid for both the GPU and the CPU accelerator; passing
# `devices=None` on the CPU path is rejected by newer Lightning releases.
trainer = pl.Trainer(
    max_epochs=5,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
)

# Train the model (the CIFAR10 test split doubles as the validation set)
trainer.fit(model, cifar10_train_loader, cifar10_test_loader)

## Analyze the Thalamus Gating Behavior

In [None]:
def visualize_gating(model, images, task_id=None):
    """Plot the first image of a batch beside a heatmap of its per-patch thalamus scores.

    The scores are recomputed here by calling the thalamus sub-modules
    (`task_embed`, `task_proj`, `scorer`) directly, so every patch gets a
    score; the score of a patch is the L2 norm of the scorer output for that
    token.

    Args:
        model: Model exposing `to_patch_embedding` and `thalamus`.
        images: Normalised image batch; only `images[0]` is drawn.
        task_id: Optional long tensor with one task ID per image. Defaults to
            all zeros.
    """
    # Run on the model's device so CPU batches work with a GPU-resident model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else images.device
    images = images.to(device)
    if task_id is None:
        task_id = torch.zeros(images.size(0), dtype=torch.long, device=device)
    else:
        task_id = task_id.to(device)

    # Score in eval mode so any dropout inside the model cannot add noise to
    # the heatmap; the previous mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        # Get the patch tokens
        tokens = model.to_patch_embedding(images)  # [B, num_patches, d_model]

        # Get the scores without gating
        B, N, D = tokens.size()
        task_embedding = model.thalamus.task_embed(task_id)
        task_embedding = model.thalamus.task_proj(task_embedding)
        # Broadcast the per-sample task vector across all N tokens.
        task_embedding = task_embedding.unsqueeze(1).expand(-1, N, -1)
        x_combined = tokens + task_embedding
        x_attn, _ = model.thalamus.scorer(x_combined, x_combined, x_combined)
        scores = x_attn.norm(dim=-1)  # one score per token
    finally:
        model.train(was_training)

    # Plot the first image and its gating pattern
    plt.figure(figsize=(10, 5))

    # Show the image, undoing the per-channel normalisation for display
    plt.subplot(1, 2, 1)
    img = images[0].detach().cpu().permute(1, 2, 0).numpy()
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = img * std + mean
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    plt.title('Input Image')
    plt.axis('off')

    # Show the attention scores as a heatmap. The grid side is derived from
    # the number of patches rather than assuming 224x224 inputs with 16x16
    # patches; a non-square patch count falls back to a single-row strip.
    plt.subplot(1, 2, 2)
    patch_scores = scores[0].detach().cpu()
    num_patches = patch_scores.numel()
    side = int(round(num_patches ** 0.5))
    if side * side == num_patches:
        attention_map = patch_scores.reshape(side, side).numpy()
    else:
        attention_map = patch_scores.reshape(1, num_patches).numpy()
    plt.imshow(attention_map, cmap='hot')
    plt.title('Thalamus Attention Scores')
    plt.colorbar(label='Score')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

# Get a batch of images.
# `images` stays in the notebook namespace and is reused by the
# phase-embedding visualisation cell further down.
dataiter = iter(cifar10_test_loader)
images, labels = next(dataiter)

# Visualize the gating behavior.
# Gradients are not needed for plotting, so autograd tracking is disabled.
with torch.no_grad():
    visualize_gating(model, images)

## Experiment with Different Task IDs

One of the interesting aspects of the synthetic thalamus is its task-conditioned behavior. Let's see how different task IDs affect the model's performance and attention patterns.

In [None]:
def evaluate_with_task(model, dataloader, task_id):
    """Return top-1 accuracy of `model` over `dataloader` with a fixed task ID.

    The model is switched to eval mode (and left there). Each batch is moved
    to the device of the model's parameters before the forward pass, so a
    CPU dataloader works with a GPU-resident model.

    Args:
        model: Callable as `model(images, task_ids)` returning class logits
            of shape [B, num_classes].
        dataloader: Iterable of `(images, labels)` batches.
        task_id: Integer task ID applied to every sample.

    Returns:
        Fraction of correctly classified samples as a float in [0, 1];
        0.0 when the dataloader yields no samples.
    """
    model.eval()

    # Models without parameters give no device hint; leave batches where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            if device is not None:
                images = images.to(device)
                labels = labels.to(device)
            # One copy of the task ID per sample, on the same device as the images.
            task_batch = torch.full(
                (images.size(0),), int(task_id), dtype=torch.long, device=images.device
            )
            predicted = model(images, task_batch).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty dataloader instead of dividing by zero.
    return correct / total if total else 0.0

# Sweep all ten task IDs, recording and reporting the test accuracy of each
task_accuracies = []
for task_id in range(10):
    acc = evaluate_with_task(model, cifar10_test_loader, task_id)
    task_accuracies.append(acc)
    print(f"Task ID {task_id}: Accuracy = {acc:.4f}")

# Bar chart of accuracy per task ID
plt.figure(figsize=(10, 6))
plt.bar(range(len(task_accuracies)), task_accuracies)
plt.gca().set(
    xlabel='Task ID',
    ylabel='Accuracy',
    title='Model Performance by Task ID',
    xticks=range(10),
    ylim=(0, 1.0),
)
plt.gca().grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

## Visualize Phase Embeddings

The phase embeddings are a key component of the synthetic thalamus, enabling phase-aware gating. Let's visualize these embeddings to understand their structure.

In [None]:
def visualize_phase_embeddings(model, images, task_id=None):
    """Plot the phase part of the thalamus output for the first image of a batch.

    The batch is embedded into patch tokens and passed through
    `model.thalamus`; the last `model.thalamus.phase_dim` features of each
    output token for sample 0 are shown as a heatmap and as one line per token.

    Args:
        model: Model exposing `to_patch_embedding` and `thalamus`.
        images: Image batch; only the first sample is plotted.
        task_id: Optional long tensor with one task ID per image. Defaults to
            all zeros.
    """
    # Run on the model's device so CPU batches work with a GPU-resident model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else images.device
    images = images.to(device)
    if task_id is None:
        task_id = torch.zeros(images.size(0), dtype=torch.long, device=device)
    else:
        task_id = task_id.to(device)

    # Use eval mode so any dropout inside the model cannot perturb the plot;
    # the previous mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        # Get the patch tokens
        tokens = model.to_patch_embedding(images)  # [B, num_patches, d_model]

        # Forward pass through thalamus
        # (presumably [B, k, d_model + phase_dim] -- confirm against core.thalamus)
        gated = model.thalamus(tokens, task_id)
    finally:
        model.train(was_training)

    # Extract the phase embeddings: trailing phase_dim features of sample 0
    phase_dim = model.thalamus.phase_dim
    phase_embeddings = gated[0, :, -phase_dim:].detach().cpu().numpy()

    # Plot the phase embeddings
    plt.figure(figsize=(12, 6))

    # Heatmap of phase values (colour scale fixed to [-1, 1])
    plt.subplot(1, 2, 1)
    plt.imshow(phase_embeddings, cmap='coolwarm', vmin=-1, vmax=1)
    plt.colorbar(label='Phase Value')
    plt.xlabel('Phase Dimension')
    plt.ylabel('Token Index')
    plt.title('Phase Embeddings Heatmap')

    # Line plot of each token's phase vector (at most the first 8 tokens)
    plt.subplot(1, 2, 2)
    for i in range(min(8, phase_embeddings.shape[0])):
        plt.plot(phase_embeddings[i], label=f'Token {i}')
    plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
    plt.xlabel('Phase Dimension')
    plt.ylabel('Phase Value')
    plt.title('Phase Embeddings by Token')
    plt.legend()

    plt.tight_layout()
    plt.show()

# Visualize phase embeddings for the test batch (`images`) fetched in the
# gating-analysis cell; gradients are not needed for plotting.
with torch.no_grad():
    visualize_phase_embeddings(model, images)

## Conclusion

In this notebook, we've demonstrated how the Synthetic Thalamus can be used for image classification. The key aspects of this approach include:

1. **Attention-based Salience Scoring**: The thalamus uses multi-head attention to score the importance of input tokens.
2. **Task-Conditioned Gating**: The model's behavior can be modulated via task embeddings.
3. **Phase-Aware Processing**: Phase tags add an extra dimension to token representation, potentially enabling more complex workspace dynamics.

This approach shows promise for building more efficient vision models that focus computational resources on the most relevant parts of the input, similar to how biological visual attention works.