# Advanced Models for Profession Classification

This notebook explores Vision Transformers (ViT) and Kolmogorov–Arnold Network (KAN) layers for profession classification on the IdenProf dataset.

In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision import transforms
import pandas as pd

# Add project root to path
sys.path.append('..')

from src.data.download import download_and_extract_idenprof
from src.data.preprocess import get_data_loaders, get_augmented_data_loaders, get_class_names
from src.models.vit import vit_tiny, vit_small, vit_base
from src.models.kan import add_kan_layer
from src.utils.metrics import evaluate_accuracy

## 1. Download and Prepare Dataset

In [None]:
# Fix random seeds so torch-driven randomness used in this notebook (model
# initialisation, loader iteration/augmentation, `torch.randperm` sampling)
# repeats across Restart & Run All. GPU kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices
np.random.seed(SEED)

# Download and extract dataset
data_dir = download_and_extract_idenprof(data_dir='../data')
print(f"Dataset directory: {data_dir}")

# Get class names
class_names = get_class_names(data_dir)
print(f"Class names: {class_names}")

# Set parameters
batch_size = 32
image_size = 224

# Create data loaders: a plain train/test pair, plus an augmented training loader
train_loader, test_loader = get_data_loaders(data_dir, batch_size, image_size)
aug_train_loader, _ = get_augmented_data_loaders(
    data_dir,
    batch_size=batch_size,
    image_size=image_size,
    rotation=30,
    hue=0.05,
    saturation=0.05
)

## 2. Explore Vision Transformer Models

In [None]:
# Create Vision Transformer models. They are not moved to `device` here, so
# they stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

models = {
    "ViT-Tiny": vit_tiny(num_classes=len(class_names)),
    "ViT-Small": vit_small(num_classes=len(class_names)),
    "ViT-Base": vit_base(num_classes=len(class_names))
}

# Count trainable parameters per model: collect plain records first, then build
# the DataFrame once.
model_stats = []
for name, model in models.items():
    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    model_stats.append({
        "Model": name,
        "Parameters": param_count,
        "Parameters (M)": param_count / 1_000_000
    })

# Chain set_index instead of mutating in place, so re-running the cell is idempotent
df_stats = pd.DataFrame(model_stats).set_index("Model")
display(df_stats)

### 2.1 Explore Model Architecture

In [None]:
# Pull one batch from the training loader and display its first image
images, labels = next(iter(train_loader))
print(f"Batch shape: {images.shape}")

# CHW -> HWC for matplotlib; `/ 2 + 0.5` reverses a mean=0.5, std=0.5
# normalisation (assumed — TODO confirm against src.data.preprocess)
img = images[0].permute(1, 2, 0) / 2 + 0.5
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(img)
ax.set_title(f"Class: {class_names[labels[0].item()]}")
ax.axis('off')
plt.show()

In [None]:
# Inspect ViT-Tiny: first the patch embedding, then the first transformer block
vit = models["ViT-Tiny"]

embed = vit.patch_embed
print(f"Patch embedding layer: {embed}")
print(f"Patch size: {embed.patch_size}")
print(f"Number of patches: {embed.num_patches}")

first_block = vit.blocks[0]
print(f"\nFirst transformer block:")
print(f"Attention layer: {first_block.attn}")
print(f"Number of attention heads: {first_block.attn.num_heads}")
print(f"MLP layer: {first_block.mlp}")

### 2.2 Visualize Patch Embedding

In [None]:
# Function to visualize patches
def visualize_patches(image, patch_size=16):
    """Show an image split into a grid of non-overlapping square patches.

    Pixels beyond the last full patch in each direction (when H or W is not a
    multiple of ``patch_size``) are not shown.

    Args:
        image: tensor indexed as (channels, height, width). The ``/ 2 + 0.5``
            below assumes mean=0.5, std=0.5 normalisation — TODO confirm
            against the data loader.
        patch_size: side length of each square patch, in pixels.

    Raises:
        ValueError: if ``patch_size`` is not positive, or the image is smaller
            than one patch in either dimension.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    _, h, w = image.shape
    n_h, n_w = h // patch_size, w // patch_size
    if n_h == 0 or n_w == 0:
        raise ValueError(
            f"image of size {h}x{w} is smaller than one {patch_size}x{patch_size} patch"
        )

    # squeeze=False keeps `axes` 2-D even for a single row or column of patches
    fig, axes = plt.subplots(n_h, n_w, figsize=(10, 10), squeeze=False)

    for i in range(n_h):
        for j in range(n_w):
            patch = image[:, i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size]
            # CHW -> HWC, undo normalisation, clamp into imshow's valid [0, 1] range
            axes[i, j].imshow((patch.permute(1, 2, 0) / 2 + 0.5).clamp(0, 1))
            axes[i, j].axis('off')

    plt.tight_layout()
    plt.show()

# Use the first image of the batch loaded earlier: show it whole, then as patches
sample_image = images[0]

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(sample_image.permute(1, 2, 0) / 2 + 0.5)
ax.set_title(f"Original Image: {class_names[labels[0].item()]}")
ax.axis('off')
plt.show()

print("Patches (16x16):")
visualize_patches(sample_image, patch_size=16)

### 2.3 Visualize Attention Maps

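Let's look at the attention patterns in ViT-Tiny. The model has not been trained in this notebook, so these maps reflect its initial weights rather than learned behavior.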

In [None]:
# Function to get attention maps
def get_attention_maps(model, image):
    """Collect the softmax attention matrix of every transformer block.

    A forward hook on each ``block.attn`` recomputes ``softmax(q @ k^T * scale)``
    from the module's input. A forward hook's ``input`` argument is the tuple of
    arguments the module was called with (the token tensor), not q/k/v, so q and
    k are rebuilt with the module's own ``qkv`` projection.

    NOTE(review): assumes a timm-style attention module exposing ``qkv``
    (Linear producing 3 * heads * head_dim features), ``num_heads`` and
    ``scale``, and called with a (batch, tokens, dim) tensor. Any q/k
    normalisation inside the module is not reproduced — confirm against
    ``src.models.vit``.

    Args:
        model: network with an iterable ``blocks`` whose items have ``.attn``.
        image: a single input without batch dimension (a batch axis is added).

    Returns:
        List with one numpy array per block, shaped
        (1, num_heads, tokens, tokens); rows sum to 1.
    """
    # Run on whichever device the model's weights live on, so a CPU model still
    # works when a GPU is available.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    batch = image.unsqueeze(0).to(model_device)

    attention_maps = []

    def hook_fn(module, inputs, output):
        x = inputs[0]  # token tensor passed to the attention module
        n_batch, n_tokens, _ = x.shape
        qkv = module.qkv(x).reshape(n_batch, n_tokens, 3, module.num_heads, -1)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, tokens, head_dim)
        q, k = qkv[0], qkv[1]
        attn = (q @ k.transpose(-2, -1)) * module.scale
        attention_maps.append(attn.softmax(dim=-1).detach().cpu().numpy())

    hooks = [block.attn.register_forward_hook(hook_fn) for block in model.blocks]

    # eval() disables dropout for a clean pass; the original mode is restored after
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            model(batch)
    finally:
        # Always detach hooks (even if the forward pass fails) so they don't accumulate
        for hook in hooks:
            hook.remove()
        model.train(was_training)

    return attention_maps

# Collect per-layer attention maps of ViT-Tiny for the sample image
attention_maps = get_attention_maps(model=models["ViT-Tiny"], image=sample_image)

# Visualize attention maps
def plot_attention_maps(attention_maps, layer_idx=0, head_idx=0, ax=None):
    """Draw one head of one layer's attention as a heatmap with a colorbar.

    Args:
        attention_maps: sequence indexed by layer; each entry is indexed as
            ``[0, head_idx]`` to get a 2-D (query, key) matrix, i.e. only the
            first batch element is drawn.
        layer_idx: zero-based layer to draw (shown 1-based in the title).
        head_idx: zero-based head to draw (shown 1-based in the title).
        ax: matplotlib Axes to draw into; a new 8x8-inch figure is made if None.

    Returns:
        The Axes that was drawn into.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))

    matrix = attention_maps[layer_idx][0, head_idx]

    heatmap = ax.imshow(matrix, cmap='viridis')
    ax.set_title(f"Layer {layer_idx+1}, Head {head_idx+1}")
    plt.colorbar(heatmap, ax=ax)

    ax.set_xlabel("Key (patches)")
    ax.set_ylabel("Query (patches)")

    # Token 0 is labelled CLS; after that every 5th token index gets a tick
    x_ticks = [0, *range(5, matrix.shape[1], 5)]
    y_ticks = [0, *range(5, matrix.shape[0], 5)]
    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    ax.set_xticklabels(["CLS"] + [str(t) for t in x_ticks[1:]])
    ax.set_yticklabels(["CLS"] + [str(t) for t in y_ticks[1:]])

    return ax

# Plot every head of the first layer. The head count is read from the attention
# array (indexed [batch, head, query, key]) instead of being hard-coded, so this
# also works for models with a different number of heads.
n_heads = attention_maps[0].shape[1]
fig, axes = plt.subplots(1, n_heads, figsize=(5 * n_heads, 5), squeeze=False)
for i in range(n_heads):
    plot_attention_maps(attention_maps, layer_idx=0, head_idx=i, ax=axes[0, i])
plt.tight_layout()
plt.show()

# Plot the first head of every layer on a grid of at most 4 columns, rather than
# one row whose width grows by 4 inches per layer. squeeze=False keeps `axes`
# 2-D even for a single layer.
n_layers = len(attention_maps)
n_cols = min(4, n_layers)
n_rows = (n_layers + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i < n_layers:
        plot_attention_maps(attention_maps, layer_idx=i, head_idx=0, ax=ax)
    else:
        ax.axis('off')  # hide unused cells in the last grid row
plt.tight_layout()
plt.show()

## 3. Explore KAN Layers

In [None]:
# Attach KAN layers to a fresh ViT-Tiny backbone via add_kan_layer
vit_tiny_model = vit_tiny(num_classes=len(class_names))
vit_with_kan = add_kan_layer(
    vit_tiny_model,
    kan_hidden_sizes=[128, 64],
    num_classes=len(class_names),
    kan_width=16,
)

# Print the header and the full module tree (same output as two separate prints)
print("ViT with KAN Layers:", vit_with_kan, sep="\n")

In [None]:
# Pull out the KAN classifier head and report its structure
kan_network = vit_with_kan.kan_classifier
print(f"KAN Network Structure:\n{kan_network}")

# Shapes and sizes of the first KAN layer
kan_layer = kan_network.layers[0]
print(f"\nKAN Layer Parameters:")
print(f"Input features: {kan_layer.in_features}")
print(f"Output features: {kan_layer.out_features}")
print(f"Width: {kan_layer.width}")
print(f"W1 shape: {kan_layer.W1.shape}")
print(f"W2 shape: {kan_layer.W2.shape}")
print(f"Univariate weights shape: {kan_layer.univariate_weights.shape}")

### 3.1 Visualize KAN Univariate Functions

In [None]:
# Function to visualize KAN univariate functions
def _chebyshev_basis(x, n_terms):
    """Return Chebyshev polynomials T_0 .. T_{n_terms-1} evaluated at ``x``.

    Uses the recurrence T_{k+1}(x) = 2x T_k(x) - T_{k-1}(x).

    Args:
        x: 1-D tensor of evaluation points.
        n_terms: number of polynomials to produce; must be >= 1.

    Returns:
        Tensor of shape ``(len(x), n_terms)``; column ``k`` holds T_k(x).

    Raises:
        ValueError: if ``n_terms < 1``.
    """
    if n_terms < 1:
        raise ValueError(f"n_terms must be >= 1, got {n_terms}")
    terms = [torch.ones_like(x)]
    if n_terms > 1:
        terms.append(x)
    while len(terms) < n_terms:
        terms.append(2 * x * terms[-1] - terms[-2])
    return torch.stack(terms, dim=1)


def visualize_kan_univariate_functions(kan_layer, n_samples=5):
    """Plot a random sample of a KAN layer's univariate functions on [-3, 3].

    Each function ``idx`` is rebuilt as a Chebyshev expansion
    ``f(x) = sum_k univariate_weights[idx, k] * T_k(x / 3) + univariate_bias[idx]``.

    The number of Chebyshev terms is taken from ``univariate_weights.shape[1]``
    (the dimension contracted by the matmul), so any polynomial degree works.

    NOTE(review): the ``x / 3`` input scaling is assumed to mirror what the KAN
    layer does internally — confirm against ``src.models.kan``.

    Args:
        kan_layer: layer exposing ``univariate_weights`` and ``univariate_bias``.
        n_samples: number of functions to plot; capped at the number available.
    """
    x = torch.linspace(-3, 3, 100)

    weights = kan_layer.univariate_weights.detach().cpu()
    bias = kan_layer.univariate_bias.detach().cpu()

    cheb = _chebyshev_basis(x / 3.0, weights.shape[1])

    # Never ask for more subplots than there are functions to draw
    n_samples = min(n_samples, weights.shape[0])
    indices = torch.randperm(weights.shape[0])[:n_samples]

    plt.figure(figsize=(12, 8))
    for i, idx in enumerate(indices.tolist()):
        y = torch.matmul(cheb, weights[idx]) + bias[idx]

        plt.subplot(n_samples, 1, i + 1)
        plt.plot(x.numpy(), y.numpy())
        plt.grid(True)
        plt.title(f"Univariate Function {idx + 1}")
        plt.xlabel("Input")
        plt.ylabel("Output")

    plt.tight_layout()
    plt.show()

# Plot five randomly chosen univariate functions of the first KAN layer
visualize_kan_univariate_functions(kan_layer=kan_layer, n_samples=5)

## 4. Test Models with Sample Images

In [None]:
# Test models with sample images
def predict_image(model, image, class_names, top_k=3):
    """Classify a single image and return the best class plus a ranked top-k.

    Side effect: switches ``model`` to eval mode and leaves it there.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        image: one input without batch dimension (a batch axis is added).
        class_names: sequence mapping class index -> display name.
        top_k: how many ranked predictions to return; capped at the number of
            classes. Must be >= 1.

    Returns:
        ``(pred_class, pred_prob, top_classes, top_probs)`` where ``top_probs``
        is a numpy array aligned with the ``top_classes`` list.

    Raises:
        ValueError: if ``top_k < 1``.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    # Feed the input to the device the model's weights are on
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    batch = image.unsqueeze(0).to(model_device)

    model.eval()
    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(batch), dim=1)[0]

    # topk would fail if asked for more entries than there are classes
    k = min(top_k, probs.numel())
    top_probs, top_idxs = torch.topk(probs, k)
    top_classes = [class_names[i] for i in top_idxs.tolist()]

    # The first top-k entry is the argmax (softmax preserves the logit ordering)
    pred_class = top_classes[0]
    pred_prob = top_probs[0].item()

    return pred_class, pred_prob, top_classes, top_probs.cpu().numpy()

# Take one test batch and compare the two (untrained) models on a few images
test_images, test_labels = next(iter(test_loader))
n_images = min(5, len(test_images))  # a short batch may hold fewer than 5 images

# Initialize models (random weights, not trained)
vit_tiny_model = vit_tiny(num_classes=len(class_names)).to(device)
vit_with_kan = add_kan_layer(
    vit_tiny(num_classes=len(class_names)),
    kan_hidden_sizes=[128, 64],
    num_classes=len(class_names),
    kan_width=16,
).to(device)

# One row per image: the image itself, then each model's top predictions.
# squeeze=False keeps `axes` 2-D even when n_images == 1.
fig, axes = plt.subplots(n_images, 3, figsize=(15, 4 * n_images), squeeze=False)
for i in range(n_images):
    image = test_images[i]
    true_class = class_names[test_labels[i].item()]

    img_ax, vit_ax, kan_ax = axes[i]
    img_ax.imshow(image.permute(1, 2, 0) / 2 + 0.5)
    img_ax.set_title(f"True Class: {true_class}")
    img_ax.axis('off')

    for ax, label, (pred_class, pred_prob, top_classes, top_probs) in (
        (vit_ax, "ViT", predict_image(vit_tiny_model, image, class_names)),
        (kan_ax, "ViT+KAN", predict_image(vit_with_kan, image, class_names)),
    ):
        positions = list(range(len(top_probs)))
        ax.bar(positions, top_probs)
        ax.set_xticks(positions)
        ax.set_xticklabels(top_classes, rotation=45)
        ax.set_ylim(0, 1)
        ax.set_title(f"{label}: {pred_class} ({pred_prob:.2f})")

plt.tight_layout()
plt.show()

## 5. Training a ViT Model with KAN Layers

Full training is done with the project's training script (see the commands in the conclusion). Here we only run a short training loop on a few batches, for demonstration purposes.

In [None]:
# Demonstration of a simple training loop (not full training)
def train_for_demo(model, data_loader, num_batches=10, lr=0.0001):
    """Run a handful of optimisation steps as a smoke test (not real training).

    Puts ``model`` in train mode and creates a fresh Adam optimiser on every
    call, so optimiser state is not carried over between calls. Prints the
    loss of each batch.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        data_loader: iterable yielding ``(images, labels)`` batches.
        num_batches: maximum number of batches to train on (<= 0 trains none).
        lr: Adam learning rate.

    Returns:
        List of per-batch loss values (Python floats), one per batch used.
    """
    from itertools import islice

    # Move batches to whatever device the model's weights are on
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    losses = []
    # islice stops after num_batches without pulling one more batch from the loader
    for i, (images, labels) in enumerate(islice(data_loader, max(0, num_batches))):
        images, labels = images.to(model_device), labels.to(model_device)

        loss = criterion(model(images), labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)
        print(f"Batch {i+1}/{num_batches}, Loss: {loss_value:.4f}")

    return losses

# Train ViT+KAN on a few batches, then plot the per-batch loss curve
print("Training ViT with KAN for a few batches...")
losses = train_for_demo(vit_with_kan, train_loader, num_batches=5)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(losses)
ax.set_title("Training Loss")
ax.set_xlabel("Batch")
ax.set_ylabel("Loss")
ax.grid(True)
plt.show()

## 6. Conclusion

In this notebook, we have explored Vision Transformer models and KAN layers for the profession classification task. We have seen how to:

1. Create and visualize ViT models of different sizes
2. Understand the patch embedding and attention mechanisms
3. Add KAN layers to ViT models
4. Visualize the univariate functions in KAN layers
5. Make predictions with (untrained) ViT and ViT+KAN models
6. Run a short demonstration training loop on a ViT+KAN model

For complete training, use the training scripts in the project:

```bash
python -m src.training.train --model vit_tiny --use_kan --optimizer adam --lr 0.0003 --batch_size 32 --augment
```

Or run hyperparameter optimization with Optuna:

```bash
python -m src.training.optuna_optimization --data_dir data/idenprof --num_trials 20
```