# Deep Learning on Tiny-ImageNet with PyTorch

This notebook implements a deep learning pipeline for the Tiny-ImageNet dataset. It covers:
1.  **Data Loading**: Using Hugging Face `datasets`.
2.  **Preprocessing**: Augmentation and Normalization.
3.  **Models**:
    *   CNN Classifier (Custom ResNet-like)
    *   Autoencoder (Image Reconstruction)
    *   Vision Transformer (ViT) for 64x64 images
4.  **Training**: Gradient clipping, LR scheduling, Early stopping.
5.  **Evaluation**: Accuracy and Reconstruction visualization.

## Setup
Ensure you are running on a GPU runtime (Runtime > Change runtime type > GPU in Colab).

In [None]:
# @title Install Requirements
!pip install datasets transformers timm torch torchvision matplotlib tqdm -q

In [None]:
# @title Imports & Configuration
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import time
import copy

# Reproducibility: seed the NumPy and torch (CPU + CUDA) random generators
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Hyperparameters shared by the cells below
IMAGE_SIZE = 64
NUM_CLASSES = 200
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
NUM_EPOCHS = 15

## 1. Data Loading & Preprocessing
We use the `zh-plus/tiny-imagenet` dataset from Hugging Face.

In [None]:
# Fetch the Tiny-ImageNet splits from the Hugging Face Hub
print("Loading Tiny-ImageNet dataset...")
dataset = load_dataset("zh-plus/tiny-imagenet")

# Both pipelines end the same way: convert to a tensor, then standardise
# each channel with the ImageNet mean / std.
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

# Training: random flip, padded random crop back to 64x64 (the native
# Tiny-ImageNet resolution) and mild colour jitter, then the shared tail.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(64, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    *_tensor_and_normalize,
])

# Validation: no augmentation, only tensor conversion + normalisation
val_transforms = transforms.Compose(list(_tensor_and_normalize))

# Helper to apply transforms to HF dataset
def preprocess_train(examples):
    """Add augmented training tensors to a batch of examples.

    Every image under ``examples['image']`` is converted to RGB, run through
    ``train_transforms`` and stored under the new ``'pixel_values'`` key.
    The same (mutated) ``examples`` mapping is returned.
    """
    tensors = []
    for pil_image in examples['image']:
        tensors.append(train_transforms(pil_image.convert("RGB")))
    examples['pixel_values'] = tensors
    return examples

def preprocess_val(examples):
    """Add validation tensors (no augmentation) to a batch of examples.

    Every image under ``examples['image']`` is converted to RGB, run through
    ``val_transforms`` and stored under the new ``'pixel_values'`` key.
    The same (mutated) ``examples`` mapping is returned.
    """
    tensors = []
    for pil_image in examples['image']:
        tensors.append(val_transforms(pil_image.convert("RGB")))
    examples['pixel_values'] = tensors
    return examples

# Register the per-split preprocessing (set_transform is lazy and efficient)
for split_name, split_transform in (('train', preprocess_train), ('valid', preprocess_val)):
    dataset[split_name].set_transform(split_transform)

# Create DataLoaders
def collate_fn(batch):
    """Merge a list of example dicts into one ``(pixel_values, labels)`` pair.

    Each item must provide a tensor under ``'pixel_values'`` and a label
    under ``'label'``. The tensors are stacked along a new leading batch
    dimension and the labels are gathered into a 1-D tensor.
    """
    image_tensors = []
    targets = []
    for sample in batch:
        image_tensors.append(sample['pixel_values'])
        targets.append(sample['label'])
    return torch.stack(image_tensors), torch.tensor(targets)

# Settings common to both loaders; only the split and shuffling differ
_loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=2)
train_loader = DataLoader(dataset['train'], shuffle=True, **_loader_kwargs)
val_loader = DataLoader(dataset['valid'], shuffle=False, **_loader_kwargs)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 2. Model Architectures

### Model 1: CNN Classifier (ResNet-like Block)

In [None]:
class SimpleResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped by a residual connection.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).

    The shortcut is the identity unless the stride or the channel count
    changes, in which case a 1x1 conv + batch-norm projects the input so it
    can be added to the main branch.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            # Projection shortcut: match spatial size and channel count
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty Sequential passes its input through unchanged
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)

        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)

        return self.relu(main + residual)

class TinyCNN(nn.Module):
    """ResNet-like image classifier built from ``SimpleResidualBlock`` stages.

    A 3x3 stem convolution (stride 1, 64 channels) is followed by four
    stages of two residual blocks each (64, 128, 256, 512 channels); stages
    2-4 halve the spatial resolution. Global average pooling and a linear
    layer produce ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=200):
        super().__init__()
        # Running channel count, updated by _make_layer as stages are added
        self.in_channels = 64

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)

        # Classification head
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, out_channels, blocks, stride):
        """Build one stage: a (possibly strided) block plus ``blocks - 1`` stride-1 blocks."""
        stage = [SimpleResidualBlock(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels
        stage.extend(
            SimpleResidualBlock(out_channels, out_channels) for _ in range(blocks - 1)
        )
        return nn.Sequential(*stage)

    def forward(self, x):
        features = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = self.avg_pool(features)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

### Model 2: Autoencoder

In [None]:
class Autoencoder(nn.Module):
    """Convolutional autoencoder with three stride-2 stages in each direction.

    For a 64x64 input the encoder produces a 128-channel 8x8 code
    (64 -> 32 -> 16 -> 8) and the decoder mirrors it back to 3 channels at
    64x64. The final Sigmoid bounds every output value to [0, 1].
    """

    def __init__(self):
        super().__init__()
        down = dict(kernel_size=3, stride=2, padding=1)
        up = dict(kernel_size=3, stride=2, padding=1, output_padding=1)

        # Encoder: each conv halves the spatial resolution
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, **down),     # 64 -> 32
            nn.ReLU(),
            nn.Conv2d(32, 64, **down),    # 32 -> 16
            nn.ReLU(),
            nn.Conv2d(64, 128, **down),   # 16 -> 8
            nn.ReLU(),
        )

        # Decoder: each transposed conv doubles the spatial resolution
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, **up),  # 8 -> 16
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, **up),   # 16 -> 32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, **up),    # 32 -> 64
            nn.Sigmoid(),                       # squash outputs into [0, 1]
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

### Model 3: Vision Transformer (ViT)
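A lightweight ViT adapted for small images: each 64x64 input is split into 8x8 patches (64 patch tokens), and a learnable class token is prepended for classification.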

In [None]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size``
    maps every patch to an ``embed_dim``-dimensional vector.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Channels of the input image.
        embed_dim: Size of each patch embedding.

    Attributes:
        num_patches: ``(img_size // patch_size) ** 2``, the token count.
    """

    def __init__(self, img_size=64, patch_size=8, in_channels=3, embed_dim=256):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # [B, C, H, W] -> [B, E, H', W'] -> [B, E, N] -> [B, N, E]
        return self.proj(x).flatten(2).transpose(1, 2)

class TinyViT(nn.Module):
    """Small Vision Transformer classifier.

    The image is turned into patch tokens by ``PatchEmbedding``, a learnable
    class token is prepended, learnable positional embeddings are added, and
    the sequence is processed by a stack of ``nn.TransformerEncoderLayer``s.
    The class token's final state is layer-normalised and projected to
    ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        embed_dim: Token embedding size (``d_model`` of the encoder).
        depth: Number of transformer encoder layers.
        heads: Number of attention heads per layer.
        mlp_dim: Hidden size of each layer's feed-forward network.
    """

    def __init__(self, num_classes=200, embed_dim=256, depth=6, heads=8, mlp_dim=512):
        super().__init__()
        # Forward embed_dim so the patch tokens have the same width as
        # cls_token / pos_embed and the encoder's d_model. Relying on
        # PatchEmbedding's default (256) breaks for any other embed_dim.
        self.patch_embed = PatchEmbedding(embed_dim=embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One positional embedding per patch token plus one for the class token
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.num_patches, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=heads, dim_feedforward=mlp_dim, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)  # [B, N, E]

        # Prepend the class token to every sequence in the batch
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # [B, 1 + N, E]
        x = x + self.pos_embed

        x = self.transformer(x)

        # Use CLS token for classification
        out = self.mlp_head(x[:, 0])
        return out

## 3. Training Utilities

In [None]:
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with the current validation loss.
    ``early_stop`` becomes ``True`` after ``patience`` consecutive calls
    without an improvement.

    Args:
        patience: Number of consecutive non-improving epochs tolerated.
        min_delta: Minimum decrease of the loss that counts as an
            improvement.

    Attributes:
        counter: Consecutive non-improving epochs seen so far.
        best_loss: Lowest loss accepted as an improvement (``None`` until
            the first call).
        early_stop: Whether training should stop.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            # First observation becomes the baseline
            self.best_loss = val_loss
            return

        # Only a strict decrease of more than min_delta is an improvement.
        # Testing for the improvement (rather than for the absence of one)
        # means a plateau at exactly the best value, or a NaN loss, counts
        # against patience instead of resetting the counter.
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=10, task='classification',
                patience=5, max_grad_norm=1.0):
    """Train ``model``, validate after every epoch and restore the best weights.

    Args:
        model: Network to train; it is moved to the global ``device``.
        train_loader: Iterable yielding ``(inputs, labels)`` training batches.
        val_loader: Iterable yielding ``(inputs, labels)`` validation batches.
        criterion: Loss callable. It is called as ``criterion(outputs, labels)``
            for classification and ``criterion(outputs, inputs)`` otherwise.
            Epoch averages assume it returns the mean loss of the batch.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: LR scheduler stepped once per epoch, or a falsy value
            to disable scheduling.
        num_epochs: Maximum number of epochs.
        task: ``'classification'`` to train against the labels and track
            accuracy; any other value trains the model to reconstruct its
            own input (autoencoder).
        patience: Epochs without validation-loss improvement tolerated
            before stopping early.
        max_grad_norm: Maximum gradient norm used for clipping.

    Returns:
        ``(model, history)``. ``model`` carries the weights of the best
        epoch (highest validation accuracy for classification, lowest
        validation loss otherwise). ``history`` maps ``'train_loss'``,
        ``'val_loss'`` and ``'val_acc'`` to per-epoch lists; ``'val_acc'``
        stays empty for non-classification tasks.
    """
    is_classification = task == 'classification'

    model.to(device)
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_loss = float('inf')

    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
    early_stopper = EarlyStopping(patience=patience)

    for epoch in range(num_epochs):
        # Training Phase
        model.train()
        running_loss = 0.0
        train_seen = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for inputs, labels in pbar:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)

            # Classification is scored against the labels. For the
            # autoencoder the target is the input batch itself, exactly as
            # it comes out of the data pipeline, so the model's output
            # range must be able to cover the range of those inputs.
            targets = labels if is_classification else inputs
            loss = criterion(outputs, targets)

            loss.backward()

            # Gradient Clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

            optimizer.step()

            batch_size = inputs.size(0)
            running_loss += loss.item() * batch_size
            train_seen += batch_size
            pbar.set_postfix({'loss': loss.item()})

        # Average over the samples actually processed, so the value stays
        # correct when the loader drops the last batch or the dataset has
        # no usable length.
        epoch_loss = running_loss / max(train_seen, 1)
        history['train_loss'].append(epoch_loss)

        # Validation Phase
        model.eval()
        val_running_loss = 0.0
        val_seen = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)

                if is_classification:
                    loss = criterion(outputs, labels)
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                else:
                    loss = criterion(outputs, inputs)

                batch_size = inputs.size(0)
                val_running_loss += loss.item() * batch_size
                val_seen += batch_size

        val_loss = val_running_loss / max(val_seen, 1)
        history['val_loss'].append(val_loss)

        if is_classification:
            # Guard against an empty validation loader
            val_acc = 100 * correct / total if total else 0.0
            history['val_acc'].append(val_acc)
            print(f"Epoch {epoch+1}: Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            if val_acc > best_acc:
                best_acc = val_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        else:
            print(f"Epoch {epoch+1}: Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}")
            if val_loss < best_loss:
                best_loss = val_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        if scheduler:
            scheduler.step()

        # Early stopping always monitors the validation loss
        early_stopper(val_loss)
        if early_stopper.early_stop:
            print("Early stopping triggered")
            break

    model.load_state_dict(best_model_wts)
    return model, history

## 4. Training & Evaluation

### Train CNN Classifier

In [None]:
# Residual CNN with cross-entropy loss, AdamW and a cosine LR schedule
# that spans the whole run
cnn_model = TinyCNN(num_classes=NUM_CLASSES)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    params=cnn_model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-4,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=NUM_EPOCHS)

print("Training CNN Classifier...")
cnn_model, cnn_history = train_model(
    model=cnn_model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=NUM_EPOCHS,
    task='classification',
)

In [None]:
# Plot Classification Results
# Explicit axes with labelled x/y so each panel can be read on its own.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

# 1-based epoch numbers, matching the "Epoch N" lines printed during training
epoch_numbers = range(1, len(cnn_history['train_loss']) + 1)

loss_ax.plot(epoch_numbers, cnn_history['train_loss'], label='Train Loss')
loss_ax.plot(epoch_numbers, cnn_history['val_loss'], label='Val Loss')
loss_ax.set(title='CNN Loss', xlabel='Epoch', ylabel='Loss')
loss_ax.legend()

acc_ax.plot(epoch_numbers, cnn_history['val_acc'], label='Val Accuracy')
acc_ax.set(title='CNN Accuracy', xlabel='Epoch', ylabel='Accuracy (%)')
acc_ax.legend()

fig.tight_layout()
plt.show()

### Train Autoencoder

In [None]:
# Adjust Autoencoder for Normalized Data
# Since inputs are normalized (approx -2 to 2), we remove the final Sigmoid from the decoder 
# to allow the model to predict negative values.
class AutoencoderLinear(Autoencoder):
    """``Autoencoder`` variant whose decoder has no final Sigmoid.

    Without the Sigmoid the last transposed convolution's output is
    unbounded, so the model can predict the negative values that occur in
    normalized inputs.
    """

    def __init__(self):
        super().__init__()
        # Reuse the parent's decoder instead of re-declaring it: slicing the
        # Sequential keeps layers 0-4 (same state_dict keys) and drops only
        # the trailing Sigmoid.
        if isinstance(self.decoder[-1], nn.Sigmoid):
            self.decoder = self.decoder[:-1]

# Reconstruction setup: MSE against the input batch, AdamW, and a step
# schedule that halves the learning rate every 5 epochs
ae_model = AutoencoderLinear()
criterion_ae = nn.MSELoss()
optimizer_ae = optim.AdamW(params=ae_model.parameters(), lr=LEARNING_RATE)
scheduler_ae = optim.lr_scheduler.StepLR(optimizer=optimizer_ae, step_size=5, gamma=0.5)

print("Training Autoencoder...")
ae_model, ae_history = train_model(
    model=ae_model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion_ae,
    optimizer=optimizer_ae,
    scheduler=scheduler_ae,
    num_epochs=10,
    task='autoencoder',
)

In [None]:
# Visualize Reconstructions
def imshow(img, title):
    """Draw one normalized image tensor on the current matplotlib axes.

    Args:
        img: Tensor indexed as (channel, row, column) with 3 channels,
            normalized with the ImageNet mean / std used in this notebook.
        title: Text shown above the image.

    The normalization is undone, values are clipped to [0, 1], and the
    axis ticks are hidden.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    # Channel-first tensor -> channel-last NumPy array, as plt.imshow expects
    pixels = img.cpu().numpy().transpose((1, 2, 0))
    # Un-normalize for display
    pixels = np.clip(pixels * channel_std + channel_mean, 0, 1)

    plt.imshow(pixels)
    plt.title(title)
    plt.axis('off')

# Reconstruct the first five images of the first validation batch
ae_model.eval()
images, _ = next(iter(val_loader))
images = images[:5]
images = images.to(device)
with torch.no_grad():
    reconstructed = ae_model(images)

# Top row: originals; bottom row: the matching reconstructions
plt.figure(figsize=(10, 4))
for column in range(1, 6):
    plt.subplot(2, 5, column)
    imshow(images[column - 1], "Original")
    plt.subplot(2, 5, column + 5)
    imshow(reconstructed[column - 1], "Reconstructed")
plt.show()

### (Optional) Train Vision Transformer

In [None]:
# Vision Transformer with cross-entropy loss, AdamW and a cosine LR
# schedule that spans the whole run
vit_model = TinyViT(num_classes=NUM_CLASSES)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    params=vit_model.parameters(),
    lr=5e-4,
    weight_decay=1e-4,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=NUM_EPOCHS)

print("Training Vision Transformer...")
vit_model, vit_history = train_model(
    model=vit_model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=NUM_EPOCHS,
    task='classification',
)