In [9]:
import os
import numpy as np
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional, List, Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset, random_split

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

import wandb
from wandb.sdk.wandb_run import Run
import matplotlib.pyplot as plt
import torchvision

In [2]:
wandb.login(key='f56388c51b488c425a228537fd2d35e5498a3a91')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/sathwikpentela/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mda24m017[0m ([33mda24m017-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
@dataclass
class ModelConfig:
    img_size: int = 128
    num_classes: int = 10
    batch_size: int = 32
    epochs: int = 10
    filter_organization: str = "double"  # 'same', 'double', 'half'
    activation: str = "silu"  # 'relu', 'gelu', 'silu', 'mish'
    data_augmentation: bool = True
    batch_norm: bool = True
    dropout: float = 0.3
    dense_neurons: int = 256
    learning_rate: float = 1e-4

# Sweep Configuration (for hyperparameter tuning)
sweep_config = {
    "method": "bayes",
    "metric": {"name": "val_acc", "goal": "maximize"},
    "parameters": {
        "filter_organization": {"values": ["same", "double", "half"]},
        "activation": {"values": ["relu", "gelu", "silu", "mish"]},
        "data_augmentation": {"values": [True, False]},
        "batch_norm": {"values": [True, False]},
        "dropout": {"values": [0.1, 0.2, 0.3]},
        "dense_neurons": {"values": [128, 256, 512]},
        "learning_rate": {"values": [1e-3, 1e-4]},
    },
}

In [6]:
class CustomCNN(pl.LightningModule):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.save_hyperparameters()
        self.config = config

        # Define filter progression
        if config.filter_organization == "same":
            filters = [32] * 5
        elif config.filter_organization == "double":
            filters = [32, 64, 128, 256, 512]
        elif config.filter_organization == "half":
            filters = [512, 256, 128, 64, 32]
        else:
            raise ValueError("Invalid filter organization")

        # Build convolutional blocks
        conv_blocks = []
        in_channels = 3
        for out_channels in filters:
            conv_blocks.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels) if config.batch_norm else nn.Identity(),
                self._get_activation(config.activation),
                nn.MaxPool2d(kernel_size=2),
            ])
            in_channels = out_channels
        self.conv_layers = nn.Sequential(*conv_blocks)

        # Calculate flattened size
        flattened_size = (config.img_size // (2 ** 5)) ** 2 * filters[-1]

        # Fully connected layers
        self.fc1 = nn.Linear(flattened_size, config.dense_neurons)
        self.dropout = nn.Dropout(config.dropout)
        self.fc2 = nn.Linear(config.dense_neurons, config.num_classes)

    def _get_activation(self, name: str) -> nn.Module:
        activations = {
            "relu": nn.ReLU(),
            "gelu": nn.GELU(),
            "silu": nn.SiLU(),
            "mish": nn.Mish(),
        }
        return activations.get(name.lower(), nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_layers(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self._get_activation(self.config.activation)(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(1) == y).float().mean()
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(1) == y).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return {"val_loss": loss, "val_acc": acc}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.config.learning_rate)

In [7]:
def get_transforms(config: ModelConfig) -> T.Compose:
    base_transforms = [
        T.Resize((config.img_size, config.img_size)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    if config.data_augmentation:
        base_transforms.insert(1, T.RandomHorizontalFlip())
        base_transforms.insert(2, T.RandomRotation(10))
    return T.Compose(base_transforms)

def prepare_data(config: ModelConfig) -> tuple[DataLoader, DataLoader]:
    dataset = datasets.ImageFolder("inaturalist_12K/train", transform=get_transforms(config))
    
    # Stratified split (alternative method)
    class_indices = defaultdict(list)
    for idx, (_, label) in enumerate(dataset.samples):
        class_indices[label].append(idx)
    
    train_indices, val_indices = [], []
    for label, indices in class_indices.items():
        split = int(0.8 * len(indices))
        train_indices.extend(indices[:split])
        val_indices.extend(indices[split:])
    
    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)
    
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
    
    return train_loader, val_loader

In [8]:
def train_sweep():
    run = wandb.init()
    config = run.config
    
    # Set the run name using the hyperparameters
    run.name = f"fo_{config.filter_organization}_act_{config.activation}_aug_{config.data_augmentation}_bn_{config.batch_norm}_do_{config.dropout}_dn_{config.dense_neurons}_lr_{config.learning_rate}"
    run.save()
    
    # Data
    train_loader, val_loader = prepare_data(config)
    
    # Model
    model = CustomCNN(config)
    
    # Logger & Trainer
    wandb_logger = WandbLogger(project="da6401_assignment2")
    trainer = pl.Trainer(
        max_epochs=config.epochs,
        logger=wandb_logger,
        accelerator="auto",
    )
    
    # Train
    trainer.fit(model, train_loader, val_loader)
    run.finish()


In [None]:
# Start sweep
sweep_id = wandb.sweep(sweep_config, project="da6401_assignment2")
wandb.agent(sweep_id, train_sweep, count=30)

In [10]:
def train_best_model():
    """Train the model with best hyperparameters found during the sweep"""
    # Best configuration from sweep results
    best_config = ModelConfig(
        img_size=128,
        num_classes=10,
        batch_size=32,
        epochs=10,
        filter_organization="half",
        activation="silu",
        data_augmentation=True,
        batch_norm=True,
        dropout=0.2,
        dense_neurons=512,
        learning_rate=1e-4
    )

    # Initialize wandb run for best model
    run = wandb.init(
        project="da6401_assignment2",
        config=best_config.__dict__,
        name="best_model_run"
    )
    
    # Prepare data
    train_loader, val_loader = prepare_data(best_config)
    
    # Setup model checkpointing
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc",
        dirpath="model_checkpoints",
        filename="best_model-{epoch:02d}-{val_acc:.2f}",
        save_top_k=1,
        mode="max",
        save_last=True
    )
    
    # Initialize model and trainer
    model = CustomCNN(best_config)
    wandb_logger = WandbLogger(project="DA6401_Assignment2", log_model="all")
    
    trainer = pl.Trainer(
        max_epochs=best_config.epochs,
        logger=wandb_logger,
        accelerator="auto",
        callbacks=[checkpoint_callback],
        deterministic=True
    )
    
    # Train the model
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    
    # Save the best model as artifact
    wandb.log_artifact(
        checkpoint_callback.best_model_path,
        name="best_cnn_model",
        type="model"
    )
    
    run.finish()



In [11]:
def denormalize(tensor: torch.Tensor) -> torch.Tensor:
    """Reverse normalization for visualization"""
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    return tensor * std + mean

In [12]:
def visualize_predictions(model: CustomCNN, test_dataset: datasets.ImageFolder, run: Run):
    """Create 10x3 grid of sample predictions"""
    # Collect 3 samples per class
    samples = {i: [] for i in range(model.config.num_classes)}
    for img, label in test_dataset:
        if len(samples[label]) < 3:
            samples[label].append(img)
        if all(len(v) == 3 for v in samples.values()):
            break
    
    # Get predictions
    predictions = {}
    for cls in range(model.config.num_classes):
        imgs = torch.stack(samples[cls])
        if torch.cuda.is_available():
            imgs = imgs.cuda()
        
        with torch.no_grad():
            outputs = model(imgs)
            _, preds = torch.max(outputs, 1)
        
        predictions[cls] = preds.cpu().tolist()
    
    # Create figure
    fig, axes = plt.subplots(10, 3, figsize=(12, 30))
    
    for cls in range(model.config.num_classes):
        for j in range(3):
            img = samples[cls][j]
            img = denormalize(img)
            
            ax = axes[cls][j]
            ax.imshow(img.permute(1, 2, 0))
            ax.set_title(
                f"True: {test_dataset.classes[cls]}\n"
                f"Pred: {test_dataset.classes[predictions[cls][j]]}"
            )
            ax.axis("off")
    
    plt.tight_layout()
    wandb.log({"sample_predictions": wandb.Image(fig)})
    plt.close()

In [None]:
def visualize_first_layer_filters(model: CustomCNN, run: Run):
    """Visualize filters from first convolutional layer"""
    first_conv = model.conv_layers[0]
    weights = first_conv.weight.data.cpu()
    
    # Normalize filter weights for visualization
    weights = (weights - weights.min()) / (weights.max() - weights.min())
    
    # Create grid of filters
    grid = torchvision.utils.make_grid(weights, nrow=8, padding=2)
    
    # Plot and log to wandb
    plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0))
    plt.title("First Layer Conv Filters")
    plt.axis("off")
    wandb.log({"conv_filters": wandb.Image(plt)})
    plt.close()

In [15]:
def evaluate_on_test_set():
    """Evaluate the best model on test data and visualize results"""
    # Load best configuration
    best_config = ModelConfig(
        img_size=128,
        num_classes=10,
        batch_size=32,
        epochs=10,
        filter_organization="half",
        activation="silu",
        data_augmentation=True,
        batch_norm=True,
        dropout=0.2,
        dense_neurons=512,
        learning_rate=1e-4
    )
    
    # Initialize wandb run for evaluation
    run = wandb.init(project="da6401_assignment2", job_type="evaluation")
    
    # Prepare test data transforms (no augmentation)
    test_transform = T.Compose([
        T.Resize((best_config.img_size, best_config.img_size)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Load test dataset
    test_path = "inaturalist_12K/test"
    if not os.path.exists(test_path):
        test_path = "inaturalist_12K/val"  # Fallback if test folder not found
    
    test_dataset = datasets.ImageFolder(root=test_path, transform=test_transform)
    test_loader = DataLoader(test_dataset, batch_size=best_config.batch_size, shuffle=False)
    
    # Load best model checkpoint
    checkpoint_path = "model_checkpoints/best_model.ckpt"  # Update with your actual path
    model = CustomCNN.load_from_checkpoint(
        checkpoint_path,
        config=best_config
    )
    model.eval()
    
    if torch.cuda.is_available():
        model = model.cuda()
    
    # Evaluate test performance
    test_results = {"correct": 0, "total": 0, "loss": 0.0}
    
    with torch.no_grad():
        for batch in test_loader:
            x, y = batch
            if torch.cuda.is_available():
                x, y = x.cuda(), y.cuda()
                
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            preds = logits.argmax(dim=1)
            
            test_results["loss"] += loss.item() * x.size(0)
            test_results["correct"] += (preds == y).sum().item()
            test_results["total"] += y.size(0)
    
    # Calculate metrics
    test_loss = test_results["loss"] / test_results["total"]
    test_acc = test_results["correct"] / test_results["total"]
    
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")
    wandb.log({"test_loss": test_loss, "test_acc": test_acc})
    
    # Visualization: Sample predictions
    visualize_predictions(model, test_dataset, run)
    
    # Visualization: First layer filters
    visualize_first_layer_filters(model, run)
    
    run.finish()


In [None]:


if __name__ == "__main__":
    train_best_model()
    evaluate_on_test_set()

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/sathwikpentela/miniforge3/envs/codify/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.

  | Name        | Type       | Params | Mode 
---------------------------------------------------
0 | conv_layers | Sequential | 1.6 M  | train
1 | fc1         | Linear     | 262 K  | train
2 | dropout     | Dropout    | 0      | train
3 | fc2         | Linear     | 5.1 K  | train
---------------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.405     Total estimated model params size (MB)
24        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/sathwikpentela/miniforge3/envs/codify/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/sathwikpentela/miniforge3/envs/codify/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]