# DA6401 Assignment 2 - Part A
## Training CNN from Scratch on iNaturalist Dataset

In [None]:
# Download the iNaturalist 12K subset, extract it into the current working
# directory (stdout and stderr of unzip are discarded, so extraction errors are
# silent), then delete the archive to free disk space.
# NOTE(review): NatureDataModule defaults to data_dir='/kaggle/working/inaturalist_12K/',
# which presumes this cell runs from /kaggle/working and that the archive unpacks
# to inaturalist_12K/ — confirm.
!curl https://storage.googleapis.com/wandb_datasets/nature_12K.zip --output nature_12K.zip
!unzip nature_12K.zip > /dev/null 2>&1
!rm nature_12K.zip

In [None]:
!python -m pip install lightning

In [None]:
!pip install wandb

import wandb
# Authenticate with W&B without embedding a secret in the notebook: wandb.login()
# reads the WANDB_API_KEY environment variable (or prompts interactively).
# The key previously hard-coded here was published with the notebook and should be revoked.
wandb.login()

In [None]:
# Import necessary libraries
import os
import math
import torch
import lightning as L
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
from torchmetrics import Accuracy
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
import wandb
import numpy as np
import matplotlib.pyplot as plt

### Building a Flexible CNN Model

In [None]:
class NatureCNN(L.LightningModule):
    """Configurable 5-block CNN for 10-class iNaturalist classification.

    Each block is ``Conv2d -> [BatchNorm2d] -> activation -> MaxPool2d(2, 2)``;
    the flattened features feed a hidden dense layer (with dropout) and a
    10-way output layer. The classifier input size is computed for 128x128
    RGB inputs, so the model expects images of that size.
    """

    def __init__(self,
                 base_filters=32,
                 filter_strategy='double',
                 base_kernel=3,
                 kernel_strategy='same',
                 batch_norm=True,
                 conv_activation='relu',
                 dense_activation='relu',
                 dense_size=128,
                 learning_rate=1e-3,
                 weight_decay=1e-4,
                 dropout=0.2):
        """
        Flexible CNN model for iNaturalist dataset classification.

        Args:
            base_filters: Starting number of filters in first convolutional layer
            filter_strategy: Strategy for scaling filters across layers ['same', 'double', 'halve', 'alternate']
            base_kernel: Base kernel size for convolutional layers
            kernel_strategy: Strategy for kernel sizes ['same', 'decrease', 'alternate', 'pyramid']
            batch_norm: Whether to use batch normalization
            conv_activation: Activation function for convolutional layers
            dense_activation: Activation function for dense layers
            dense_size: Number of neurons in dense layer
            learning_rate: Learning rate for optimizer
            weight_decay: Weight decay for regularization
            dropout: Dropout rate

        Raises:
            KeyError: If a strategy or activation name is not recognised.
        """
        super().__init__()
        self.save_hyperparameters()

        # Per-layer filter counts (5 conv blocks).
        self.filters = self.get_filter_strategy(base_filters, filter_strategy)
        # Even kernel sizes are bumped to the next odd value so that padding of
        # k // 2 keeps the spatial size unchanged. Store the *effective* sizes so
        # calculate_computations()/count_parameters() describe the real layers.
        self.kernel_sizes = [
            k if k % 2 else k + 1
            for k in self.generate_kernel_sizes(base_kernel, kernel_strategy)
        ]

        # Validate configuration
        if len(self.filters) != len(self.kernel_sizes):
            raise ValueError("Filter numbers and kernel sizes must match")

        # Build convolutional blocks
        self.features = nn.Sequential()
        in_channels = 3
        for i, (out_channels, k_size) in enumerate(zip(self.filters, self.kernel_sizes)):
            print(f"Layer {i+1}: {in_channels} -> {out_channels}, kernel: {k_size}x{k_size}")

            self.features.append(nn.Conv2d(in_channels, out_channels, k_size, padding=k_size // 2))
            if self.hparams.batch_norm:
                self.features.append(nn.BatchNorm2d(out_channels))
            self.features.append(self._get_activation(self.hparams.conv_activation))
            self.features.append(nn.MaxPool2d(2, 2))
            in_channels = out_channels

        # Infer the flattened feature size with a shape-only dummy pass. Run it in
        # eval mode so BatchNorm running statistics are not updated with an
        # all-zeros "batch"; modules are in training mode by default, so restore it.
        self.features.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 128, 128)
            self.feature_size = self.features(dummy).flatten(1).size(1)
        self.features.train()

        # Build classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.feature_size, dense_size),
            self._get_activation(self.hparams.dense_activation),
            nn.Dropout(dropout),
            nn.Linear(dense_size, 10)  # 10 classes in the iNaturalist subset
        )

        # Metrics
        self.train_acc = Accuracy(task='multiclass', num_classes=10)
        self.val_acc = Accuracy(task='multiclass', num_classes=10)
        self.test_acc = Accuracy(task='multiclass', num_classes=10)

    @staticmethod
    def get_filter_strategy(base: int, strategy: str, num_layers=5) -> list:
        """Generate per-layer filter counts from ``base`` and ``strategy`` (case-insensitive)."""
        strategies = {
            'same': [base] * num_layers,
            'double': [base * (2**i) for i in range(num_layers)],
            'halve': [max(8, base // (2**i)) for i in range(num_layers)],
            'alternate': [base * (2 if i % 2 else 1) for i in range(num_layers)]
        }
        return strategies[strategy.lower()]

    @staticmethod
    def generate_kernel_sizes(base: int, strategy: str, num_layers=5) -> list:
        """Generate per-layer kernel sizes from ``base`` and ``strategy`` (case-insensitive)."""
        strategies = {
            'same': [base] * num_layers,
            'decrease': [max(3, base - 2*i) for i in range(num_layers)],
            'alternate': [base if i % 2 else (base-2) for i in range(num_layers)],
            'pyramid': [base + 2*i for i in range(num_layers//2)] +
                       [base + 2*(num_layers//2 - i) for i in range(1, num_layers-num_layers//2+1)]
        }
        return strategies[strategy.lower()]

    def _get_activation(self, name: str) -> nn.Module:
        """Return a new activation module for ``name`` (case-insensitive)."""
        # Factories, so only the requested module is instantiated.
        factories = {
            'relu': nn.ReLU,
            'gelu': nn.GELU,
            'silu': nn.SiLU,
            'mish': nn.Mish,
            'leaky_relu': lambda: nn.LeakyReLU(0.1),
            'elu': nn.ELU,
            'selu': nn.SELU,
            'tanh': nn.Tanh,
            'sigmoid': nn.Sigmoid,
        }
        return factories[name.lower()]()

    def forward(self, x):
        x = self.features(x)
        x = x.flatten(1)
        return self.classifier(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        self.train_acc(logits, y)
        self.log('train_acc', self.train_acc, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        self.val_acc(logits, y)
        self.log('val_acc', self.val_acc, prog_bar=True, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('test_loss', loss)
        self.test_acc(logits, y)
        self.log('test_acc', self.test_acc, prog_bar=True, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """AdamW with LR halved when ``val_acc`` plateaus for 2 epochs."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=0.5,
            patience=2
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_acc',
                'interval': 'epoch',
                'frequency': 1
            }
        }

    def calculate_computations(self, image_size=128):
        """
        Estimate the operation count of one forward pass on a single image.

        Convolution and linear layers are counted as one operation per
        multiply-accumulate; batch norm adds 4 and the activation 1 operation
        per conv output element; the dense activation and dropout add 1 per
        hidden neuron each. Per-layer counts are printed.

        Args:
            image_size: Side length of the (square) input image. The dense-layer
                input size is derived from it (for 128 it equals ``self.feature_size``).

        Returns:
            Total operation count.
        """
        total_flops = 0
        spatial = image_size
        in_channels = 3

        for i, (out_channels, k_size) in enumerate(zip(self.filters, self.kernel_sizes)):
            # 'Same' padding: the conv output has spatial x spatial positions.
            out_elems = spatial**2 * out_channels
            flops_per_layer = out_elems * k_size**2 * in_channels
            if self.hparams.batch_norm:
                flops_per_layer += 4 * out_elems
            flops_per_layer += out_elems  # activation

            print(f"Layer {i+1}: {flops_per_layer} FLOPs")
            total_flops += flops_per_layer

            # MaxPool2d(2, 2) halves (floor) the spatial size for the next layer.
            spatial //= 2
            in_channels = out_channels

        # Flattened conv output for *this* image_size feeds the dense layer.
        feature_size = spatial**2 * in_channels
        flops_dense = feature_size * self.hparams.dense_size
        flops_dense += self.hparams.dense_size  # activation
        flops_dense += self.hparams.dense_size  # dropout
        print(f"Dense layer: {flops_dense} FLOPs")
        total_flops += flops_dense

        flops_output = self.hparams.dense_size * 10
        print(f"Output layer: {flops_output} FLOPs")
        total_flops += flops_output

        return total_flops

    def count_parameters(self):
        """
        Analytically count the trainable parameters of the network.

        Conv layers contribute weights + biases, batch norm (if enabled) its
        gamma and beta, and the two linear layers weights + biases. Per-layer
        counts are printed.

        Returns:
            Total number of trainable parameters.
        """
        total_params = 0
        in_channels = 3

        for i, (out_channels, k_size) in enumerate(zip(self.filters, self.kernel_sizes)):
            params_conv = out_channels * in_channels * k_size**2 + out_channels
            params_bn = 2 * out_channels if self.hparams.batch_norm else 0

            layer_params = params_conv + params_bn
            print(f"Layer {i+1}: {layer_params} parameters")
            total_params += layer_params
            in_channels = out_channels

        params_dense = self.feature_size * self.hparams.dense_size + self.hparams.dense_size
        print(f"Dense layer: {params_dense} parameters")
        total_params += params_dense

        params_output = self.hparams.dense_size * 10 + 10
        print(f"Output layer: {params_output} parameters")
        total_params += params_output

        return total_params


### Calculation of Computations and Parameters (Question 1)

In [None]:
# Create a model with default hyperparameters to calculate computations and parameters.
# NOTE: the defaults use filter_strategy='double' (32, 64, ..., 512 filters), so the
# numeric totals below are for that model, not the "m filters per layer" formula.
default_model = NatureCNN()

# Print the model architecture
print("\nModel Architecture:")
print(default_model)

# Calculate and print total computations
print("\nTotal Computations (FLOPs):")
total_flops = default_model.calculate_computations()
print(f"Total FLOPs: {total_flops:,}")

# Calculate and print total parameters
print("\nTotal Parameters:")
total_params = default_model.count_parameters()
print(f"Total Parameters: {total_params:,}")

# Closed-form expressions for 5 conv layers with m filters of size k x k each
# ('same' padding, 2x2 max-pool after every layer), n dense neurons, 10 outputs,
# on S x S RGB input. The first conv layer sees 3 input channels (not m) and the
# spatial size halves after every block, so the conv cost is a geometric sum:
#   sum_i (S/2^i)^2 * k^2 * c_in_i * m,  c_in_0 = 3, c_in_i = m  (i = 0..4)
# and the flattened feature size is m * (S/32)^2.
print("\nAnalytical calculations for generic model:")
print("For a model with m filters in each layer of size k×k and n neurons in dense layer")
print("(input S×S×3, 'same' padding, 2×2 max-pooling after each of the 5 conv layers;")
print(" multiply-accumulates only, ignoring batch-norm/activation/dropout terms; S/32 rounded down)")
print("Total computations (FLOPs) = m*k^2*S^2*(3 + m*(1/4 + 1/16 + 1/64 + 1/256)) + m*(S/32)^2*n + 10*n")
print("Total parameters = (3*m*k^2 + m) + 4*(m*m*k^2 + m) + m*(S/32)^2*n + n + 10*n + 10")
print("(with batch normalization add 2*m per conv layer, i.e. 10*m in total)")

### Training the Model and Hyperparameter Tuning

In [None]:
class NatureDataModule(L.LightningDataModule):
    """LightningDataModule for the iNaturalist 12K subset.

    Expects ``<data_dir>/train`` and ``<data_dir>/val`` in torchvision
    ``ImageFolder`` layout. ``train`` is split per class into training and
    validation subsets; ``val`` is used as the test set. Augmentation (when
    enabled) is applied to the training subset only.
    """

    def __init__(self, data_dir='/kaggle/working/inaturalist_12K/', image_size=128,
                 batch_size=32, num_workers=4, data_aug=True, val_fraction=0.2, seed=42):
        """
        Data module for the iNaturalist dataset with stratified train/validation split

        Args:
            data_dir: Directory containing the dataset
            image_size: Size to resize images to
            batch_size: Batch size for training
            num_workers: Number of workers for data loading
            data_aug: Whether to apply data augmentation to the training subset
            val_fraction: Fraction of each class of ``train`` held out for validation
            seed: Seed for the train/validation split, so repeated ``setup`` calls
                produce the same split (None for a fresh random split each time)
        """
        super().__init__()
        self.data_dir = data_dir
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data_aug = data_aug
        self.val_fraction = val_fraction
        self.seed = seed
        # Training pipeline (augmented when data_aug=True); attribute name kept
        # for backward compatibility.
        self.transforms = self._get_transforms()
        # Deterministic pipeline for validation and test data.
        self.eval_transforms = self._get_transforms(train=False)

        # Log configuration
        print("Initializing NatureDataModule:")
        print(f"- Data directory: {self.data_dir}")
        print(f"- Image size: {self.image_size}")
        print(f"- Batch size: {self.batch_size}")
        print(f"- Data augmentation: {self.data_aug}")

    def _get_transforms(self, train=True):
        """Build the preprocessing pipeline; augmentations only when ``train`` and ``data_aug``."""
        base = [
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]
        if not (train and self.data_aug):
            return transforms.Compose(base)

        augmentations = [
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.1, 0.1, 0.1),
            transforms.RandomRotation(15)
        ]
        print("Using data augmentation:")
        print(" - RandomHorizontalFlip")
        print(" - ColorJitter(0.1, 0.1, 0.1)")
        print(" - RandomRotation(15)")
        return transforms.Compose(augmentations + base)

    def setup(self, stage=None):
        """Set up the dataset with stratified train/validation split"""
        print("Setting up dataset...")
        train_root = os.path.join(self.data_dir, 'train')
        full_dataset = ImageFolder(train_root, transform=self.transforms)
        # Second, non-augmented view over the same files for validation. ImageFolder
        # orders samples deterministically, so indices refer to the same images.
        eval_view = ImageFolder(train_root, transform=self.eval_transforms)

        self.classes = full_dataset.classes
        print(f"Classes: {self.classes}")

        # Group indices by label from the stored targets; iterating the dataset
        # itself would load and transform every image just to read its label.
        class_indices = {}
        for idx, label in enumerate(full_dataset.targets):
            class_indices.setdefault(label, []).append(idx)

        rng = np.random.default_rng(self.seed)
        train_indices = []
        val_indices = []
        # Same fraction of every class goes to validation (stratified split).
        for label, indices in class_indices.items():
            n_val = int(len(indices) * self.val_fraction)
            shuffled = rng.permutation(indices).tolist()
            val_indices.extend(shuffled[:n_val])
            train_indices.extend(shuffled[n_val:])
            print(f"Class {label} ({full_dataset.classes[label]}): {len(indices)-n_val} train, {n_val} validation")

        self.train_dataset = Subset(full_dataset, train_indices)
        self.val_dataset = Subset(eval_view, val_indices)
        self.test_dataset = ImageFolder(os.path.join(self.data_dir, 'val'),
                                        transform=self.eval_transforms)

        print(f"Total training samples: {len(self.train_dataset)}")
        print(f"Total validation samples: {len(self.val_dataset)}")
        print(f"Total test samples: {len(self.test_dataset)}")

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers,
                          persistent_workers=self.num_workers > 0)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          persistent_workers=self.num_workers > 0)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers)


#### Hyperparameter Tuning with Weights & Biases

In [None]:

def visualize_model(model, dm):
    """Log sample predictions, first-layer filters and a confusion matrix to W&B.

    Args:
        model: Trained classifier whose ``features[0]`` is the first conv layer.
        dm: Data module on which ``setup()`` has run (``test_dataloader`` and
            ``classes`` are used).
    """
    # Inference mode, on whatever device the model's parameters live on
    # (dataloader batches are produced on the CPU).
    model.eval()
    device = next(model.parameters()).device
    test_loader = dm.test_dataloader()
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # --- Sample predictions from the first test batch (3x3 grid) ---
    images, labels = next(iter(test_loader))
    with torch.no_grad():
        preds = model(images.to(device)).argmax(dim=1).cpu()

    fig, axes = plt.subplots(3, 3, figsize=(10, 10))
    for i, ax in enumerate(axes.flatten()):
        if i < len(images):
            # Undo the normalisation applied by the data transforms
            img = np.clip(std * images[i].permute(1, 2, 0).cpu().numpy() + mean, 0, 1)
            ax.imshow(img)
            true_idx, pred_idx = int(labels[i]), int(preds[i])
            color = 'green' if pred_idx == true_idx else 'red'
            ax.set_title(f"True: {dm.classes[true_idx]}\nPred: {dm.classes[pred_idx]}", color=color)
        ax.axis('off')
    plt.tight_layout()
    wandb.log({"prediction_samples": wandb.Image(fig)})
    plt.close(fig)

    # --- First-layer filters (up to 32, averaged over input channels) ---
    first_layer = model.features[0]
    if hasattr(first_layer, 'weight'):
        filters = first_layer.weight.detach().cpu()
        fig, axes = plt.subplots(4, 8, figsize=(12, 6))
        for i, ax in enumerate(axes.flatten()):
            if i < filters.shape[0]:
                ax.imshow(filters[i].mean(0), cmap='viridis')
            ax.axis('off')
        plt.tight_layout()
        wandb.log({"first_layer_filters": wandb.Image(fig)})
        plt.close(fig)

    # --- Confusion matrix over the full test set ---
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            preds = model(images.to(device)).argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    confusion = wandb.plot.confusion_matrix(
        y_true=np.array(all_labels),
        preds=np.array(all_preds),
        class_names=dm.classes
    )
    wandb.log({"confusion_matrix": confusion})

In [None]:
# WandB sweep configuration: random search over architecture, optimisation and
# regularisation hyperparameters; runs are ranked by the logged `val_acc`.
import math

sweep_config = dict(
    method='random',
    metric=dict(name='val_acc', goal='maximize'),
    parameters=dict(
        # Architecture
        base_filters=dict(values=[16, 32, 64]),
        filter_strategy=dict(values=['double', 'halve', 'alternate', 'same']),
        base_kernel=dict(values=[3, 5, 7]),
        kernel_strategy=dict(values=['same', 'decrease', 'alternate', 'pyramid']),
        batch_norm=dict(values=[True, False]),
        conv_activation=dict(values=['relu', 'gelu', 'silu', 'mish']),
        dense_activation=dict(values=['relu', 'gelu', 'silu', 'mish']),
        dense_size=dict(values=[128, 256, 512]),
        # Optimisation: sampled log-uniformly between min and max
        learning_rate=dict(distribution='log_uniform_values', min=1e-4, max=1e-2),
        weight_decay=dict(distribution='log_uniform_values', min=1e-6, max=1e-3),
        # Regularisation and data pipeline
        dropout=dict(values=[0.0, 0.2, 0.3, 0.4, 0.5]),
        batch_size=dict(values=[32, 64, 128]),
        data_aug=dict(values=[True, False]),
    ),
)


# Sweep agent entry point: each call is one W&B run with one sampled configuration.
def train_sweep():
    """Run a single sweep trial.

    Reads the sampled hyperparameters from ``wandb.config``, builds the data
    module and ``NatureCNN`` from them, and fits for at most 5 epochs with
    early stopping and best-``val_acc`` checkpointing. Metrics, checkpoints
    (``log_model='all'``) and a string model summary are logged to W&B; the
    best checkpoint path is printed. Nothing is returned.
    """
    # Hyperparameters forwarded verbatim to NatureCNN.
    model_keys = (
        'base_filters', 'filter_strategy', 'base_kernel', 'kernel_strategy',
        'batch_norm', 'conv_activation', 'dense_size', 'dense_activation',
        'learning_rate', 'weight_decay', 'dropout',
    )

    with wandb.init():
        cfg = wandb.config
        logger = WandbLogger(log_model='all')

        print("Training with hyperparameters:")
        for name, val in cfg.items():
            print(f"- {name}: {val}")

        datamodule = NatureDataModule(
            image_size=128,
            batch_size=cfg.batch_size,
            data_aug=cfg.data_aug
        )
        net = NatureCNN(**{key: getattr(cfg, key) for key in model_keys})

        wandb.log({"model_summary": str(net)})

        stopper = EarlyStopping(monitor='val_acc', mode='max', patience=3)
        checkpointer = ModelCheckpoint(monitor='val_acc', mode='max',
                                       filename='best-{epoch:02d}-{val_acc:.4f}')
        trainer = L.Trainer(
            max_epochs=5,
            logger=logger,
            callbacks=[stopper, checkpointer],
            precision='16-mixed',  # mixed precision for faster training
            accelerator='auto',
            devices=1,
            log_every_n_steps=1000
        )

        trainer.fit(net, datamodule)

        print(f"Best model saved at: {trainer.checkpoint_callback.best_model_path}")
        # Test-set evaluation is intentionally skipped during the sweep; the best
        # configuration is retrained and tested separately.

#### Running the Sweep

In [None]:
# # Start the sweep
# sweep_id = wandb.sweep(sweep_config, project="Assignment_CNN-partA")
# wandb.agent(sweep_id, train_sweep, count=30) 

### Hyperparameter Analysis and Insights

In [None]:
import pandas as pd

# ID of the completed W&B sweep in project da24m005-iit-madras/Assignment_CNN-partA;
# the sweep-analysis and best-model evaluation code below look it up by this ID.
sweep_id = "mqi3zdx9"

# Function to analyze sweep results
def analyze_sweep_results():
    """Summarise the hyperparameter sweep ``sweep_id`` and print insights.

    Fetches all runs of the sweep through the W&B public API, ranks them by
    ``val_acc``, prints correlations and per-hyperparameter group statistics,
    and prints a short list of insights.

    Returns:
        dict: The best run's configuration, excluding keys that start with
        ``'_'`` and the ``'wandb'`` key.

    Raises:
        ValueError: If the sweep has no runs.
    """
    api = wandb.Api()
    sweep = api.sweep(f"da24m005-iit-madras/Assignment_CNN-partA/sweeps/{sweep_id}")
    runs = sorted(sweep.runs, key=lambda run: run.summary.get('val_acc', 0), reverse=True)
    if not runs:
        raise ValueError(f"Sweep {sweep_id!r} has no runs to analyze")

    print(f"Total runs: {len(runs)}")
    print(f"Best validation accuracy: {runs[0].summary.get('val_acc', 0):.4f}")

    # One (config, metrics) pair per run, best run first.
    rows = []
    for run in runs:
        config = {k: v for k, v in run.config.items()
                  if not k.startswith('_') and k != 'wandb'}
        # Keys match the metric names NatureCNN logs ('val_acc', 'test_acc').
        metrics = {'val_acc': run.summary.get('val_acc', 0),
                   'test_acc': run.summary.get('test_acc', 0)}
        rows.append((config, metrics))

    df = pd.DataFrame([{**config, **metrics} for config, metrics in rows])

    # Correlation only makes sense for numeric columns (bools/strings excluded).
    numeric_df = df.select_dtypes(include=[np.number])
    corr = numeric_df.corr()['val_acc'].sort_values(ascending=False)
    print("\nCorrelation with validation accuracy:")
    print(corr)

    def group_effect(column):
        """Mean / max / count of val_acc for each value of ``column``."""
        return df.groupby(column)['val_acc'].agg(['mean', 'max', 'count'])

    def true_beats_false(effect):
        """Whether the True group has higher mean val_acc than the False group,
        or None if the sweep never sampled one of the two settings."""
        if True in effect.index and False in effect.index:
            return bool(effect.loc[True, 'mean'] > effect.loc[False, 'mean'])
        return None

    print("\nEffect of filter strategy:")
    filter_strategy_effect = group_effect('filter_strategy')
    print(filter_strategy_effect.sort_values('max', ascending=False))

    print("\nEffect of activation function:")
    activation_effect = group_effect('conv_activation')
    print(activation_effect.sort_values('max', ascending=False))

    print("\nEffect of batch normalization:")
    bn_effect = group_effect('batch_norm')
    print(bn_effect)

    print("\nEffect of data augmentation:")
    aug_effect = group_effect('data_aug')
    print(aug_effect)

    print("\nKey insights from hyperparameter sweep:")

    # Insight 1: Filter strategy
    filter_reasons = {
        'double': 'increasing filter complexity in deeper layers',
        'same': 'maintaining consistent filters across layers',
        'alternate': 'using alternating filter patterns',
    }
    best_filter = filter_strategy_effect['max'].idxmax()
    print(f"1. Filter strategy: '{best_filter}' performed best, suggesting that "
          f"{filter_reasons.get(best_filter, 'reducing filter complexity in deeper layers')} "
          f"is effective for this dataset.")

    # Insight 2: Activation function
    best_activation = activation_effect['max'].idxmax()
    if best_activation in ['gelu', 'silu', 'mish']:
        activation_reason = 'better gradient flow'
    elif best_activation == 'relu':
        activation_reason = 'simplicity and efficiency'
    else:
        activation_reason = 'special properties'
    print(f"2. Activation function: '{best_activation}' yielded the highest accuracy, "
          f"which may be due to its {activation_reason}.")

    # Insight 3: Batch normalization
    bn_better = true_beats_false(bn_effect)
    if bn_better is None:
        print("3. Batch normalization: the sweep did not sample both settings, so its effect cannot be assessed.")
    else:
        print(f"3. Batch normalization {'improved' if bn_better else 'did not significantly improve'} model performance, "
              f"suggesting it {'helps normalize feature distributions' if bn_better else 'may not be necessary for this dataset'}.")

    # Insight 4: Data augmentation
    aug_better = true_beats_false(aug_effect)
    if aug_better is None:
        print("4. Data augmentation: the sweep did not sample both settings, so its effect cannot be assessed.")
    else:
        print(f"4. Data augmentation {'improved' if aug_better else 'did not significantly improve'} generalization, "
              f"indicating that {'increasing dataset diversity helps prevent overfitting' if aug_better else 'the dataset may already contain sufficient variety'}.")

    # Best run is first after sorting.
    return dict(rows[0][0])


# Analyze the finished sweep (printing statistics and insights) and keep the
# best run's hyperparameters in the module-level `best_config`.
best_config = analyze_sweep_results()


### Evaluation on Test Data

In [None]:

def visualize_test_predictions(model, dm, num_samples=30, seed=0):
    """Log a grid of randomly chosen test images with true and predicted labels.

    Images are placed column-major in a 3-column grid (10x3 for the default
    30 samples); titles are green for correct and red for wrong predictions.

    Args:
        model: Trained classifier; it is switched to eval mode here.
        dm: Data module on which ``setup()`` has run (uses ``test_dataset``
            and ``classes``).
        num_samples: Number of test images to show (capped at the test-set size).
        seed: Seed selecting which test images are shown.
    """
    dataset = dm.test_dataset
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    # The test loader is not shuffled and ImageFolder groups samples by class,
    # so the first 30 images would come (almost) only from the first class.
    # Draw a seeded random subset instead.
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(len(dataset), generator=generator)[:num_samples].tolist()
    samples = [dataset[idx] for idx in indices]
    images = torch.stack([img for img, _ in samples])
    labels = [int(label) for _, label in samples]

    # Without eval(), BatchNorm would use batch statistics and Dropout stays active.
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        preds = model(images.to(device)).argmax(dim=1).cpu().tolist()

    # Normalisation constants used by the data transforms, to undo for display.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    n_cols = 3
    n_rows = math.ceil(num_samples / n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 3 * n_rows), squeeze=False)

    for i, (img, label, pred) in enumerate(zip(images, labels, preds)):
        ax = axes[i % n_rows, i // n_rows]
        ax.imshow(np.clip(std * img.permute(1, 2, 0).numpy() + mean, 0, 1))
        # Green for correct, red for incorrect
        color = 'green' if label == pred else 'red'
        ax.set_title(f"True: {dm.classes[label]}\nPred: {dm.classes[pred]}", color=color)

    # Hide ticks/frames everywhere, including any unused cells.
    for ax in axes.flat:
        ax.axis('off')

    plt.tight_layout()
    wandb.log({"test_predictions_grid": wandb.Image(fig)})
    plt.close(fig)

def visualize_filters(model):
    """Log an 8x8 grid of the first conv layer's kernels to W&B.

    Each kernel is averaged over its input channels and min-max scaled to
    [0, 1] before plotting; at most 64 kernels are drawn and every grid cell
    has its axes hidden.
    """
    kernels = model.features[0].weight.detach().cpu()
    fig, grid = plt.subplots(8, 8, figsize=(12, 12))

    for idx, cell in enumerate(grid.flatten()):
        if idx < kernels.shape[0]:
            kernel = kernels[idx].mean(0)
            lo, hi = kernel.min(), kernel.max()
            # Epsilon guards against division by zero for constant kernels.
            cell.imshow((kernel - lo) / (hi - lo + 1e-8), cmap='viridis')
        cell.axis('off')

    plt.tight_layout()
    wandb.log({"first_layer_filters": wandb.Image(fig)})
    plt.close(fig)

In [None]:
def evaluate_best_model():
    """Retrain the sweep's best configuration, test it, and log visualisations to W&B.

    The run with the highest ``val_acc`` in sweep ``sweep_id`` supplies the
    hyperparameters. A fresh ``NatureCNN`` is trained for up to 25 epochs
    (early stopping on ``val_acc``); the best checkpoint is then evaluated on
    the test split and used for the prediction-grid and filter plots.

    Returns:
        The trained model, holding the best-checkpoint weights.

    Raises:
        ValueError: If the sweep has no runs.
    """
    wandb.init(project="Assignment2_CNN-partA", name="best_model_evaluation")
    # Ensure the W&B run is closed even if training or evaluation fails.
    try:
        wandb_logger = WandbLogger(log_model='all')

        # Retrieve the best model configuration from the sweep
        api = wandb.Api()
        sweep = api.sweep(f"da24m005-iit-madras/Assignment_CNN-partA/sweeps/{sweep_id}")
        runs = sorted(sweep.runs, key=lambda run: run.summary.get('val_acc', 0), reverse=True)
        if not runs:
            raise ValueError(f"Sweep {sweep_id!r} has no runs to evaluate")
        best_run = runs[0]

        print("Best model configuration:")
        config = {k: v for k, v in best_run.config.items() if not k.startswith('_') and k != 'wandb'}
        for k, v in config.items():
            print(f"- {k}: {v}")

        # NOTE: augmentation is disabled here regardless of the best run's
        # `data_aug` value.
        dm = NatureDataModule(
            image_size=128,
            batch_size=config['batch_size'],
            data_aug=False
        )
        dm.setup()

        # Build the model from the configuration fetched above rather than a
        # global produced by another notebook cell.
        model = NatureCNN(
            base_filters=config['base_filters'],
            filter_strategy=config['filter_strategy'],
            base_kernel=config['base_kernel'],
            kernel_strategy=config['kernel_strategy'],
            batch_norm=config['batch_norm'],
            conv_activation=config['conv_activation'],
            dense_activation=config['dense_activation'],
            dense_size=config['dense_size'],
            learning_rate=config['learning_rate'],
            weight_decay=config['weight_decay'],
            dropout=config['dropout']
        )

        trainer = L.Trainer(
            max_epochs=25,
            min_epochs=17,
            logger=wandb_logger,
            callbacks=[
                EarlyStopping(monitor='val_acc', mode='max', patience=3),
                ModelCheckpoint(monitor='val_acc', mode='max', filename='best-{epoch:02d}-{val_acc:.4f}')
            ],
            precision='16-mixed',  # Use mixed precision for faster training
            accelerator='auto',
            devices=1,
            log_every_n_steps=1000
        )

        trainer.fit(model, dm)

        # Evaluate the best checkpoint (by val_acc) rather than the last-epoch
        # weights; Lightning loads it into `model` before testing.
        trainer.test(model, dm, ckpt_path='best')

        # Make sure BatchNorm/Dropout are in inference mode for the plots.
        model.eval()
        visualize_test_predictions(model, dm)
        visualize_filters(model)
    finally:
        wandb.finish()

    return model

# Call the evaluation function

In [None]:
# Retrain the best sweep configuration, evaluate it on the test split and keep
# the returned model for further inspection.
model = evaluate_best_model()