# Visualize Optimal Augmentations for CIFAR-100

This notebook visualizes optimal data augmentations for self-supervised learning with Barlow Twins on CIFAR-100, using ResNet-18 as the target model.

**Based on:** "A Theoretical Characterization of Optimal Data Augmentations in Self-Supervised Learning" (arXiv:2411.01767v3)

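## Contents
1. Import required libraries
2. Load configuration
3. Load CIFAR-100 dataset
4. Visualize original images
5. Load target model and extract representations
6. Generate augmentations with different kernels
7. Visualize augmentations by kernel type
8. Side-by-side comparison across kernels
9. Analyze augmentation properties
10. Key insights
11. Save results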

## 1. Import Required Libraries

In [None]:
import os
import sys
import yaml
import numpy as np
import torch
import torchvision  # needed for torchvision.utils.make_grid in the visualization cells
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Add the parent directory to the path so the `src` package is importable
sys.path.insert(0, os.path.join(os.getcwd(), '..'))

from src.utils import (
    load_cifar100,
    images_to_matrix,
    matrix_to_images,
    denormalize_images,
    save_image_grid,
    create_comparison_grid,
    set_seed,
)
from src.kernels import get_kernel, check_kernel_conditions
from src.target_models import TargetModel
from src.augmentation_generator import BarlowTwinsAugmentationGenerator

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (15, 10)

print("✓ All libraries imported successfully")

## 2. Load Configuration

In [None]:
# Load main configuration
with open('../configs/barlow_twins_cifar100.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Load kernel configurations
with open('../configs/kernels.yaml', 'r') as f:
    kernel_configs = yaml.safe_load(f)

# Set random seed
set_seed(config['experiment']['seed'])

print("Configuration loaded:")
print(f"  Dataset: {config['data']['dataset']}")
print(f"  Target model: {config['target_model']['architecture']}")
print(f"  PCA dim: {config['target_model']['pca_dim']}")
print(f"  Num samples: {config['data']['num_train_samples']}")
print(f"  Device: {config['experiment']['device']}")
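
`set_seed` is imported from `src.utils` and its body is not shown in this notebook. A minimal sketch of what such a helper might do, assuming it seeds Python's `random` module, NumPy, and (when installed) PyTorch; the name `set_seed_sketch` is hypothetical:

```python
import random

import numpy as np


def set_seed_sketch(seed: int) -> None:
    """Seed the common random number generators (hypothetical stand-in for src.utils.set_seed)."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch  # optional: only seeded when PyTorch is installed

        torch.manual_seed(seed)
    except ImportError:
        pass


# Re-seeding with the same value should reproduce the same random draws.
set_seed_sketch(0)
first = np.random.rand(3)
set_seed_sketch(0)
second = np.random.rand(3)
print(np.allclose(first, second))
```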

## 3. Load CIFAR-100 Dataset

In [None]:
# Load CIFAR-100
print("Loading CIFAR-100 dataset...")
images, labels = load_cifar100(
    data_dir=config['data']['data_dir'],
    train=True,
    download=True,
    normalize=config['data']['normalize'],
    num_samples=config['data']['num_train_samples'],
)

print(f"✓ Loaded {len(images)} images")
print(f"  Shape: {images.shape}")
print(f"  Classes: {len(np.unique(labels))}")

# Convert to matrix format for kernel methods
X_train = images_to_matrix(images)
print(f"  Data matrix shape: {X_train.shape} (features x samples)")
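
The cell above relies on `images_to_matrix` producing a (features x samples) matrix, and later cells invert it with `matrix_to_images`. Their implementations live in `src.utils` and are not shown, so the following is only a sketch of one plausible layout, assuming images arrive as an `(n, C, H, W)` NumPy array; the `_sketch` function names are hypothetical:

```python
import numpy as np


def images_to_matrix_sketch(images: np.ndarray) -> np.ndarray:
    """Flatten (n, C, H, W) images into a (C*H*W, n) matrix, one column per image."""
    n = images.shape[0]
    return images.reshape(n, -1).T


def matrix_to_images_sketch(matrix: np.ndarray, image_shape: tuple) -> np.ndarray:
    """Invert the flattening: (C*H*W, n) back to (n, C, H, W)."""
    n = matrix.shape[1]
    return matrix.T.reshape(n, *image_shape)


batch = np.random.default_rng(0).normal(size=(5, 3, 32, 32))
X = images_to_matrix_sketch(batch)
restored = matrix_to_images_sketch(X, (3, 32, 32))
print(X.shape)  # (3072, 5)
print(np.array_equal(batch, restored))
```

With this layout, selecting samples is a column slice, which matches the `X_train[:, indices]` indexing used later in the notebook.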

## 4. Visualize Original Images

In [None]:
# Visualize a sample of original images
num_display = 64
images_to_display = denormalize_images(images[:num_display])

fig = plt.figure(figsize=(12, 12))
for i in range(min(len(images_to_display), 64)):  # the 8x8 grid holds at most 64 images
    plt.subplot(8, 8, i + 1)
    img = images_to_display[i].transpose(1, 2, 0)
    plt.imshow(img)
    plt.axis('off')

plt.suptitle('Sample Original CIFAR-100 Images', fontsize=16)
plt.tight_layout()
plt.show()
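
`denormalize_images` is defined in `src.utils` and not shown here. As a rough sketch of the idea, assuming the loader applied per-channel `(x - mean) / std` normalization, the inverse could look like the following. The statistics are commonly quoted CIFAR-100 training-set values and are an assumption, as is the name `denormalize_sketch`; the repository may use different constants.

```python
import numpy as np

# Assumed per-channel statistics (commonly quoted for the CIFAR-100 training set).
ASSUMED_MEAN = np.array([0.5071, 0.4865, 0.4409]).reshape(1, 3, 1, 1)
ASSUMED_STD = np.array([0.2673, 0.2564, 0.2762]).reshape(1, 3, 1, 1)


def denormalize_sketch(images: np.ndarray) -> np.ndarray:
    """Undo (x - mean) / std on (n, 3, H, W) images and clip to [0, 1] for imshow."""
    return np.clip(images * ASSUMED_STD + ASSUMED_MEAN, 0.0, 1.0)


raw = np.random.default_rng(0).uniform(size=(4, 3, 32, 32))
normalized = (raw - ASSUMED_MEAN) / ASSUMED_STD
recovered = denormalize_sketch(normalized)
print(np.allclose(raw, recovered))
```

Clipping matters mostly for generated images, whose pixel values need not stay inside the range of the original data.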

## 5. Load Target Model and Extract Representations

We use a pretrained ResNet-18 as our target model f*, following the paper's approach.

In [None]:
print("Loading target model...")
target_model = TargetModel(
    architecture=config['target_model']['architecture'],
    pretrained=config['target_model']['pretrained'],
    weights=config['target_model']['weights'],
    pca_dim=config['target_model']['pca_dim'],
    device=config['experiment']['device'],
)
print(f"✓ {target_model}")

print("\nExtracting target representations...")
F_target = target_model.get_target_representations(
    torch.from_numpy(images).float(),
    fit_pca=True,
    batch_size=64,
)
print(f"✓ Target representations shape: {F_target.shape}")

# Convert to (d, n) format for algorithm
F_target = F_target.T
print(f"  Reformatted to: {F_target.shape} (d x n)")
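
The internals of `TargetModel.get_target_representations` are not shown. The sketch below illustrates only the PCA step and the final transpose, assuming the features form an `(n, D)` array (D = 512 for ResNet-18's penultimate layer) and that PCA is computed by an SVD of the centered features; `pca_reduce_sketch` is a hypothetical name.

```python
import numpy as np


def pca_reduce_sketch(features: np.ndarray, pca_dim: int) -> np.ndarray:
    """Project (n, D) features onto their top `pca_dim` principal components."""
    centered = features - features.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:pca_dim].T


features = np.random.default_rng(0).normal(size=(200, 512))  # stand-in for network features
F = pca_reduce_sketch(features, pca_dim=16)  # (n, d)
F_dn = F.T  # (d, n), the layout used after the transpose in the cell above
print(F.shape, F_dn.shape)
```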

## 6. Generate Augmentations with Different Kernels

Following Figure 3 in the paper, we generate augmentations using different kernel functions to show that "different function classes require different augmentations to achieve the same representations."

In [None]:
# Select kernels to compare
kernels_to_test = ['linear', 'rbf_medium', 'rbf_large']
num_vis_samples = 100  # Number of samples to visualize

augmentation_results = {}

for kernel_name in kernels_to_test:
    print(f"\n{'='*60}")
    print(f"Generating augmentations with: {kernel_name}")
    print(f"{'='*60}")
    
    # Get kernel configuration
    kernel_config = kernel_configs['kernels'][kernel_name]
    print(f"Config: {kernel_config}")
    
    # Create kernel
    kernel = get_kernel(**kernel_config)
    print(f"Kernel: {kernel}")
    
    # Create augmentation generator
    generator = BarlowTwinsAugmentationGenerator(
        kernel=kernel,
        lambda_ridge=config['augmentation']['lambda_ridge'],
        mu_p=config['augmentation']['mu_p'],
        check_conditions=True,
    )
    
    # Generate augmentations
    indices = np.arange(num_vis_samples)
    X_augmented = generator.fit_transform(X_train, F_target, indices=indices)
    
    # Convert back to images
    image_shape = (3, 32, 32)
    images_aug = matrix_to_images(X_augmented, image_shape)
    
    # Store results
    augmentation_results[kernel_name] = {
        'augmented': images_aug,
        'original': matrix_to_images(X_train[:, indices], image_shape),
        'kernel': kernel,
        'generator': generator,
    }
    
    print(f"✓ Generated {len(images_aug)} augmented images")
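
The kernels themselves come from `get_kernel` and `configs/kernels.yaml`, neither of which is shown. As an illustration of what the names suggest, here is a sketch of a linear and an RBF Gram matrix on column-wise data. That `rbf_medium` and `rbf_large` differ only in bandwidth is an assumption, and the function names and the `sigma` parameter are hypothetical.

```python
import numpy as np


def linear_kernel_sketch(X: np.ndarray, Y: np.ndarray) -> np.ndarray:
    """k(x, y) = <x, y> for the columns of X (m, n1) and Y (m, n2)."""
    return X.T @ Y


def rbf_kernel_sketch(X: np.ndarray, Y: np.ndarray, sigma: float) -> np.ndarray:
    """k(x, y) = exp(-||x - y||^2 / (2 * sigma^2)) for the columns of X and Y."""
    sq_dists = (
        np.sum(X**2, axis=0)[:, None]
        + np.sum(Y**2, axis=0)[None, :]
        - 2.0 * X.T @ Y
    )
    # Clamp tiny negative values caused by floating-point cancellation.
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * sigma**2))


X = np.random.default_rng(0).normal(size=(3072, 20))  # 20 flattened 3x32x32 images
K_lin = linear_kernel_sketch(X, X)
K_rbf = rbf_kernel_sketch(X, X, sigma=50.0)

# Both Gram matrices should be symmetric positive semidefinite.
for name, K in [("linear", K_lin), ("rbf", K_rbf)]:
    print(name, K.shape, np.linalg.eigvalsh(K).min() >= -1e-8)
```

A larger `sigma` makes the RBF kernel flatter, so more pairs of images count as similar; this is one way two RBF configurations could lead to visibly different augmentations.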

## 7. Visualize Augmentations by Kernel Type

In [None]:
# Visualize augmentations for each kernel
for kernel_name, results in augmentation_results.items():
    print(f"\nVisualizing {kernel_name} kernel augmentations...")
    
    # Denormalize for visualization
    orig_vis = denormalize_images(results['original'][:64])
    aug_vis = denormalize_images(results['augmented'][:64])
    
    # Create comparison figure
    fig, axes = plt.subplots(2, 1, figsize=(16, 8))
    
    # Original images
    grid_orig = torch.from_numpy(orig_vis).float()
    grid_orig = torchvision.utils.make_grid(grid_orig, nrow=8, padding=2, normalize=False)
    axes[0].imshow(grid_orig.permute(1, 2, 0))
    axes[0].set_title('Original Images', fontsize=14)
    axes[0].axis('off')
    
    # Augmented images
    grid_aug = torch.from_numpy(aug_vis).float()
    grid_aug = torchvision.utils.make_grid(grid_aug, nrow=8, padding=2, normalize=False)
    axes[1].imshow(grid_aug.permute(1, 2, 0))
    axes[1].set_title(f'Augmented Images ({kernel_name} kernel)', fontsize=14)
    axes[1].axis('off')
    
    plt.suptitle(f'Comparison: {kernel_name.upper()} Kernel', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()
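
The cell above builds its grids with `torchvision.utils.make_grid`. For readers who want to see the tiling logic itself, here is a NumPy-only sketch, assuming `(n, C, H, W)` input and no padding between tiles (unlike the `padding=2` used above); `tile_images_sketch` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def tile_images_sketch(images: np.ndarray, nrow: int) -> np.ndarray:
    """Arrange (n, C, H, W) images into one (rows*H, nrow*W, C) array, nrow images per row."""
    n, c, h, w = images.shape
    rows = -(-n // nrow)  # ceiling division
    padded = np.zeros((rows * nrow, c, h, w), dtype=images.dtype)
    padded[:n] = images  # unused slots in the last row stay black
    grid = padded.reshape(rows, nrow, c, h, w).transpose(0, 3, 1, 4, 2)
    return grid.reshape(rows * h, nrow * w, c)


# Ten tiny images, where image k is filled with the constant value k + 1.
demo = np.stack([np.full((3, 4, 4), k + 1, dtype=np.float32) for k in range(10)])
grid = tile_images_sketch(demo, nrow=8)
print(grid.shape)  # (8, 32, 3)
```

The result is already channels-last, so it can be passed to `plt.imshow` directly once the values are in [0, 1].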

## 8. Side-by-Side Comparison Across Kernels

This recreates the visualization style of Figure 3 in the paper, showing that different kernels produce semantically different augmentations.

In [None]:
import torchvision

# Create multi-kernel comparison
num_examples = 16
fig, axes = plt.subplots(len(kernels_to_test) + 1, 1, figsize=(16, 4 * (len(kernels_to_test) + 1)))

# Show originals first
orig_vis = denormalize_images(augmentation_results[kernels_to_test[0]]['original'][:num_examples])
grid = torch.from_numpy(orig_vis).float()
grid = torchvision.utils.make_grid(grid, nrow=8, padding=4, normalize=False)
axes[0].imshow(grid.permute(1, 2, 0))
axes[0].set_title('Original CIFAR-100 Images', fontsize=14, fontweight='bold')
axes[0].axis('off')

# Show augmentations for each kernel
for idx, kernel_name in enumerate(kernels_to_test, start=1):
    aug_vis = denormalize_images(augmentation_results[kernel_name]['augmented'][:num_examples])
    grid = torch.from_numpy(aug_vis).float()
    grid = torchvision.utils.make_grid(grid, nrow=8, padding=4, normalize=False)
    axes[idx].imshow(grid.permute(1, 2, 0))
    axes[idx].set_title(f'Augmented with {kernel_name.upper()} Kernel', fontsize=14, fontweight='bold')
    axes[idx].axis('off')

plt.suptitle('Optimal Augmentations Vary by Kernel Function\n(Recreating Figure 3 from Paper)',
             fontsize=18, fontweight='bold', y=1.00)
plt.tight_layout()
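os.makedirs('../results/generated_images', exist_ok=True)  # ensure the output directory exists before saving
plt.savefig('../results/generated_images/kernel_comparison.png', dpi=150, bbox_inches='tight')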
plt.show()

print("✓ Saved comparison to: ../results/generated_images/kernel_comparison.png")

## 9. Analyze Augmentation Properties

Examine the properties of the generated augmentations and the transformation matrices.

In [None]:
# Analyze transformation properties for each kernel
print("="*60)
print("AUGMENTATION TRANSFORMATION ANALYSIS")
print("="*60)

for kernel_name, results in augmentation_results.items():
    print(f"\n{kernel_name.upper()} Kernel:")
    print("-" * 40)
    
    generator = results['generator']
    aug_info = generator.get_augmentation_distribution()
    
    print(f"  Transformation matrix shape: {aug_info['transformation_matrix'].shape}")
    print(f"  Min eigenvalue: {aug_info['min_eigenvalue']:.6e}")
    print(f"  Max eigenvalue: {aug_info['max_eigenvalue']:.6e}")
    print(f"  Condition number: {aug_info['condition_number']:.6e}")
    
    # Visualize eigenvalue spectrum
    eigvals = aug_info['eigenvalues']
    plt.figure(figsize=(8, 4))
    plt.plot(np.sort(eigvals)[::-1], 'o-', linewidth=2, markersize=4)
    plt.xlabel('Index', fontsize=12)
    plt.ylabel('Eigenvalue', fontsize=12)
    plt.title(f'Eigenvalue Spectrum of T_H Matrix ({kernel_name})', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
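
How `get_augmentation_distribution` computes these quantities is not shown. The sketch below reproduces the same kind of summary for a synthetic matrix, assuming the transformation matrix is symmetric so that `np.linalg.eigvalsh` applies; the dictionary keys mirror those read in the cell above, and `spectrum_summary_sketch` is a hypothetical name.

```python
import numpy as np


def spectrum_summary_sketch(T: np.ndarray) -> dict:
    """Eigenvalue summary of a symmetric matrix T."""
    T_sym = 0.5 * (T + T.T)  # remove tiny asymmetries from floating-point error
    eigvals = np.linalg.eigvalsh(T_sym)  # returned in ascending order
    abs_vals = np.abs(eigvals)
    return {
        "eigenvalues": eigvals,
        "min_eigenvalue": eigvals[0],
        "max_eigenvalue": eigvals[-1],
        "condition_number": abs_vals.max() / abs_vals.min(),
    }


A = np.random.default_rng(0).normal(size=(8, 8))
T = A @ A.T + 0.1 * np.eye(8)  # synthetic symmetric positive definite matrix
summary = spectrum_summary_sketch(T)
print(f"min={summary['min_eigenvalue']:.3e}  max={summary['max_eigenvalue']:.3e}  "
      f"cond={summary['condition_number']:.3e}")
```

A very large condition number would suggest that the matrix is close to singular along some directions, which is worth checking before interpreting the generated images.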

## 10. Key Insights

Based on the visualizations above, we can observe:

1. **Architecture Matters**: Different kernel functions (representing different architectures) produce recognizably different augmentations for the same target representations.

2. **Augmentations Need Not Resemble Data**: The generated augmentations can look quite different from the original CIFAR-100 distribution, yet they achieve the desired representations.

3. **Single Augmentation Suffices**: We use just one non-trivial augmentation (plus identity) as described in the paper, following Moutakanni et al. (2024).

4. **Theoretical Guarantees**: All augmentations are guaranteed to achieve the target representations up to orthogonal transformation (Theorem 4.4).
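
One way to check the "up to orthogonal transformation" statement in point 4 numerically is to solve an orthogonal Procrustes problem between learned and target representations. The sketch below uses synthetic `(d, n)` matrices rather than outputs of this notebook, and `align_orthogonal_sketch` is a hypothetical helper:

```python
import numpy as np


def align_orthogonal_sketch(F_learned: np.ndarray, F_target: np.ndarray):
    """Find the orthogonal Q minimizing ||Q @ F_learned - F_target||_F for (d, n) inputs."""
    # Orthogonal Procrustes: with F_target @ F_learned.T = U S V^T, the minimizer is Q = U V^T.
    U, _, Vt = np.linalg.svd(F_target @ F_learned.T)
    Q = U @ Vt
    rel_err = np.linalg.norm(Q @ F_learned - F_target) / np.linalg.norm(F_target)
    return Q, rel_err


rng = np.random.default_rng(0)
F_target = rng.normal(size=(16, 100))
Q_true, _ = np.linalg.qr(rng.normal(size=(16, 16)))  # random orthogonal matrix
F_learned = Q_true.T @ F_target  # equals the target up to a rotation/reflection

Q, rel_err = align_orthogonal_sketch(F_learned, F_target)
print(np.allclose(Q, Q_true), rel_err < 1e-10)
```

A relative error near zero after alignment indicates that the two sets of representations differ only by an orthogonal map; a large residual indicates a genuine mismatch.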

## 11. Save Results

In [None]:
# Save individual kernel results
os.makedirs('../results/generated_images', exist_ok=True)

for kernel_name, results in augmentation_results.items():
    print(f"Saving {kernel_name} results...")
    
    # Save augmented images grid
    aug_vis = denormalize_images(results['augmented'][:100])
    save_image_grid(
        aug_vis,
        f'../results/generated_images/{kernel_name}_augmented.png',
        nrow=10,
        title=f'Augmented Images ({kernel_name} kernel)',
    )
    
    # Save original images grid
    orig_vis = denormalize_images(results['original'][:100])
    save_image_grid(
        orig_vis,
        f'../results/generated_images/{kernel_name}_original.png',
        nrow=10,
        title='Original Images',
    )

print("\n✓ All results saved to ../results/generated_images/")