# 

In [1]:
import os
# Print the full path of every file available under the Kaggle input directory.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for name in files:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Part A

In [2]:
# Install required packages
!pip install wandb pytorch-lightning

# Import necessary libraries
import math
import os
import random
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import wandb
from PIL import Image
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.datasets import ImageFolder

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available the CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with benchmarking disabled.

    Args:
        seed (int): Seed value applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all random number generators with the default seed.
set_seed()

# Select the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [4]:
class CustomCNN(LightningModule):
    """Configurable CNN classifier made of conv -> [batch norm] -> activation -> max-pool blocks.

    One block is created per entry of ``filter_counts``; the blocks are
    followed by a dropout-regularised two-layer dense classifier.
    """

    # Supported activation names; any other name falls back to ReLU.
    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
        'silu': nn.SiLU,
        'mish': nn.Mish,
    }

    def __init__(self,
                 num_classes=10,
                 filter_counts=None,
                 filter_sizes=None,
                 activation='relu',
                 dense_neurons=512,
                 input_channels=3,
                 input_size=224,
                 dropout_rate=0.5,
                 learning_rate=0.001,
                 batch_norm=False):
        """
        Custom CNN architecture with flexible hyperparameters

        Args:
            num_classes (int): Number of output classes
            filter_counts (list): Number of filters in each conv layer
                (defaults to [32, 32, 64, 64, 128])
            filter_sizes (list): Size of filters in each conv layer
                (defaults to 3 for every layer)
            activation (str): Activation function ('relu', 'gelu', 'silu', 'mish')
            dense_neurons (int): Number of neurons in the dense layer
            input_channels (int): Number of input channels (3 for RGB)
            input_size (int): Size of input images (assumes square)
            dropout_rate (float): Dropout rate
            learning_rate (float): Learning rate for optimizer
            batch_norm (bool): Whether to use batch normalization

        Raises:
            ValueError: If the filter lists are empty or of different length,
                or if the input is too small to survive all pooling stages.
        """
        super().__init__()

        # Resolve the list defaults per instance (no shared mutable default
        # arguments) before the hyperparameters are recorded.
        filter_counts = [32, 32, 64, 64, 128] if filter_counts is None else list(filter_counts)
        filter_sizes = [3] * len(filter_counts) if filter_sizes is None else list(filter_sizes)
        if not filter_counts:
            raise ValueError("filter_counts must contain at least one layer")
        if len(filter_counts) != len(filter_sizes):
            raise ValueError(
                f"filter_counts has {len(filter_counts)} entries but "
                f"filter_sizes has {len(filter_sizes)}"
            )

        self.save_hyperparameters()

        # Configure activation function (a single module instance is reused
        # everywhere, which is fine because these activations are stateless).
        self.activation = self._ACTIVATIONS.get(activation, nn.ReLU)()

        # Build the convolutional blocks while tracking the spatial size of
        # the feature map after each pooling stage.
        self.conv_layers = nn.ModuleList()
        feature_size = input_size
        feature_sizes = [feature_size]
        in_channels = input_channels

        for out_channels, filter_size in zip(filter_counts, filter_sizes):
            conv_block = [
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=filter_size, padding=filter_size // 2)
            ]
            if batch_norm:
                conv_block.append(nn.BatchNorm2d(out_channels))
            conv_block.append(self.activation)
            conv_block.append(nn.MaxPool2d(kernel_size=2, stride=2))
            self.conv_layers.append(nn.Sequential(*conv_block))

            # Max pooling with kernel 2 / stride 2 floors the halved size.
            feature_size = self._conv_output_size(feature_size, filter_size) // 2
            if feature_size < 1:
                raise ValueError(
                    f"input_size={input_size} is too small for "
                    f"{len(filter_counts)} conv/pool blocks"
                )
            feature_sizes.append(feature_size)
            in_channels = out_channels

        # Calculate flattened features size
        self.flattened_size = filter_counts[-1] * feature_size * feature_size

        # Classifier
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout_rate),
            nn.Linear(self.flattened_size, dense_neurons),
            self.activation,
            nn.Dropout(dropout_rate),
            nn.Linear(dense_neurons, num_classes)
        )

        # Store additional parameters
        self.learning_rate = learning_rate
        self.num_classes = num_classes
        self.filter_counts = filter_counts
        self.filter_sizes = filter_sizes
        self.feature_sizes = feature_sizes

        # Per-batch test predictions, consumed in on_test_epoch_end.
        self._test_outputs = []

        # Calculate parameters and computations
        self.total_params = self.calculate_total_params()
        self.total_computations = self.calculate_total_computations()

    @staticmethod
    def _conv_output_size(size, kernel_size):
        """Return the spatial size after a stride-1 conv padded with ``kernel_size // 2``."""
        return size + 2 * (kernel_size // 2) - kernel_size + 1

    def forward(self, x):
        """Forward pass through the network"""
        for conv_layer in self.conv_layers:
            x = conv_layer(x)
        return self.classifier(x)

    def calculate_total_params(self):
        """Calculate the total number of learnable parameters in the network"""
        total = 0

        in_channels = self.hparams.input_channels
        for out_channels, filter_size in zip(self.filter_counts, self.filter_sizes):
            # Conv weights (out * in * k * k) plus one bias per output channel.
            total += out_channels * in_channels * filter_size * filter_size + out_channels
            if self.hparams.batch_norm:
                # BatchNorm2d learns one scale and one shift per channel.
                total += 2 * out_channels
            in_channels = out_channels

        # First dense layer: flattened_size * dense_neurons + dense_neurons (bias)
        total += self.flattened_size * self.hparams.dense_neurons + self.hparams.dense_neurons
        # Output layer: dense_neurons * num_classes + num_classes (bias)
        total += self.hparams.dense_neurons * self.num_classes + self.num_classes

        return total

    def calculate_total_computations(self):
        """Calculate the total number of multiply operations in the conv and dense layers"""
        total = 0

        in_channels = self.hparams.input_channels
        for i, (out_channels, filter_size) in enumerate(zip(self.filter_counts, self.filter_sizes)):
            # feature_sizes[i] is the block's input size; the conv output is
            # what the multiplications are counted over.
            conv_size = self._conv_output_size(self.feature_sizes[i], filter_size)
            total += (out_channels * in_channels * filter_size * filter_size
                      * conv_size * conv_size)
            in_channels = out_channels

        # First dense layer: flattened_size * dense_neurons
        total += self.flattened_size * self.hparams.dense_neurons
        # Output layer: dense_neurons * num_classes
        total += self.hparams.dense_neurons * self.num_classes

        return total

    def configure_optimizers(self):
        """Configure optimizer"""
        return optim.Adam(self.parameters(), lr=self.learning_rate)

    def _shared_step(self, batch):
        """Run one batch through the model and return (loss, accuracy, predictions, targets)."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        return loss, acc, preds, y

    def training_step(self, batch, batch_idx):
        """Training step"""
        loss, acc, _, _ = self._shared_step(batch)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step"""
        loss, acc, _, _ = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Test step"""
        loss, acc, preds, y = self._shared_step(batch)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)

        # Keep predictions on the CPU so they can be aggregated at epoch end
        # without holding on to accelerator memory.
        self._test_outputs.append({
            'preds': preds.detach().cpu(),
            'targets': y.detach().cpu(),
        })

        return {'loss': loss, 'preds': preds, 'targets': y}

    def on_test_epoch_end(self):
        """Log a confusion matrix over all test predictions gathered this epoch.

        This hook replaces the legacy ``test_epoch_end(outputs)`` hook, which
        Lightning 2.x refuses to run; the outputs are collected manually in
        ``test_step`` instead.
        """
        if not self._test_outputs:
            return

        all_preds = torch.cat([out['preds'] for out in self._test_outputs])
        all_targets = torch.cat([out['targets'] for out in self._test_outputs])
        self._test_outputs.clear()

        # Only log when a wandb run is active; wandb.log fails otherwise.
        if wandb.run is not None:
            wandb.log({"confusion_matrix": wandb.plot.confusion_matrix(
                preds=all_preds.numpy(),
                y_true=all_targets.numpy(),
                class_names=[str(i) for i in range(self.num_classes)])})

In [5]:
class iNaturalistDataModule(LightningDataModule):
    """Data module serving train/validation/test loaders from an ImageFolder layout.

    Expects ``data_dir`` to contain ``train`` and ``test`` sub-directories in
    ``torchvision.datasets.ImageFolder`` format. The validation set is carved
    out of ``train`` with a per-class (stratified) random split.

    Args:
        data_dir (str): Root directory holding the ``train`` and ``test`` folders.
        batch_size (int): Batch size used by all dataloaders.
        num_workers (int): Worker processes per dataloader.
        input_size (int): Side length images are resized/cropped to.
        val_split (float): Fraction of each class held out for validation.
        augmentation (bool): Whether to apply random augmentation to training images.
    """

    def __init__(self, data_dir='./inaturalist_data', batch_size=32, num_workers=4,
                 input_size=224, val_split=0.2, augmentation=False):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.input_size = input_size
        self.val_split = val_split
        self.augmentation = augmentation

    def setup(self, stage=None):
        """Setup data transformations and load datasets"""
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        # Deterministic pipeline used for validation, test and
        # non-augmented training.
        eval_transform = transforms.Compose([
            transforms.Resize((self.input_size, self.input_size)),
            transforms.ToTensor(),
            normalize,
        ])

        if self.augmentation:
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(self.input_size),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(10),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            train_transform = eval_transform

        # Load datasets
        train_dir = os.path.join(self.data_dir, 'train')
        test_dir = os.path.join(self.data_dir, 'test')

        self.train_dataset = ImageFolder(root=train_dir, transform=train_transform)
        self.test_dataset = ImageFolder(root=test_dir, transform=eval_transform)

        # Group sample indices by class label so the split is stratified.
        class_indices = defaultdict(list)
        for idx, (_, label) in enumerate(self.train_dataset.samples):
            class_indices[label].append(idx)

        train_indices = []
        val_indices = []
        for members in class_indices.values():
            # Uses the global NumPy RNG, so the split follows np.random.seed.
            np.random.shuffle(members)
            split_idx = int(len(members) * (1 - self.val_split))
            train_indices.extend(members[:split_idx])
            val_indices.extend(members[split_idx:])

        # Create samplers for train and validation sets
        self.train_sampler = SubsetRandomSampler(train_indices)
        self.val_sampler = SubsetRandomSampler(val_indices)

        # Same files as the training folder, but read through the
        # deterministic transform so validation images are never augmented.
        self.val_dataset = ImageFolder(root=train_dir, transform=eval_transform)

    def train_dataloader(self):
        """Return train dataloader"""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=self.train_sampler,
            num_workers=self.num_workers
        )

    def val_dataloader(self):
        """Return validation dataloader"""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            sampler=self.val_sampler,
            num_workers=self.num_workers
        )

    def test_dataloader(self):
        """Return test dataloader"""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers
        )

In [None]:
def setup_wandb_sweep():
    """Build the W&B sweep configuration used for hyperparameter tuning.

    Returns:
        dict: A Bayesian-optimisation sweep that maximises ``val_acc`` over a
        discrete grid of architecture and training hyperparameters.
    """
    # Candidate values for each tunable hyperparameter.
    search_space = {
        'filter_counts_strategy': ['same', 'doubling', 'halving'],
        'base_filters': [16, 32, 64],
        'filter_size': [3, 5],
        'activation': ['relu', 'gelu', 'silu', 'mish'],
        'dense_neurons': [128, 256, 512],
        'dropout_rate': [0.2, 0.3, 0.5],
        'learning_rate': [0.0001, 0.001, 0.01],
        'batch_norm': [True, False],
        'batch_size': [16, 32, 64],
        'augmentation': [True, False],
    }

    parameters = {}
    for name, choices in search_space.items():
        parameters[name] = {'values': choices}

    return {
        'method': 'bayes',
        'metric': {'name': 'val_acc', 'goal': 'maximize'},
        'parameters': parameters,
    }

def train_model_sweep():
    """Train one CustomCNN using the hyperparameters of the current sweep run.

    Intended to be passed to ``wandb.agent``: the hyperparameters are read
    from ``wandb.config`` after ``wandb.init()``.

    Returns:
        tuple: ``(model, best_val_acc)`` where ``best_val_acc`` is the best
        validation accuracy seen by the checkpoint callback, as a float.

    Raises:
        ValueError: If ``filter_counts_strategy`` is not a known strategy.
    """
    # Initialize wandb
    wandb.init()

    # Get hyperparameters from wandb
    config = wandb.config

    # Generate filter counts based on strategy
    base_filters = config.base_filters
    strategies = {
        'same': [base_filters] * 5,
        'doubling': [base_filters * (2 ** i) for i in range(5)],
        'halving': [base_filters * (2 ** (4 - i)) for i in range(5)],
    }
    if config.filter_counts_strategy not in strategies:
        raise ValueError(
            f"Unknown filter_counts_strategy: {config.filter_counts_strategy!r}"
        )
    filter_counts = strategies[config.filter_counts_strategy]

    # Generate filter sizes
    filter_sizes = [config.filter_size] * 5

    # Create data module
    data_module = iNaturalistDataModule(
        batch_size=config.batch_size,
        augmentation=config.augmentation
    )
    data_module.setup()

    # Create model with hyperparameters
    model = CustomCNN(
        num_classes=10,  # Assuming 10 classes in iNaturalist subset
        filter_counts=filter_counts,
        filter_sizes=filter_sizes,
        activation=config.activation,
        dense_neurons=config.dense_neurons,
        dropout_rate=config.dropout_rate,
        learning_rate=config.learning_rate,
        batch_norm=config.batch_norm
    )

    # Setup callbacks (the checkpoint callback is kept by name because it
    # tracks the best monitored score).
    checkpoint_callback = ModelCheckpoint(
        monitor='val_acc',
        filename='best-{epoch:02d}-{val_acc:.4f}',
        save_top_k=1,
        mode='max'
    )
    early_stopping = EarlyStopping(
        monitor='val_acc',
        patience=5,
        mode='max'
    )

    # Setup wandb logger
    wandb_logger = WandbLogger(project="inaturalist_cnn_sweep")

    # Create trainer
    trainer = Trainer(
        max_epochs=20,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        callbacks=[checkpoint_callback, early_stopping],
        logger=wandb_logger,
        log_every_n_steps=10
    )

    # Train model
    trainer.fit(model, data_module.train_dataloader(), data_module.val_dataloader())

    # callback_metrics only holds the most recent epoch's val_acc, which is
    # generally not the best one; prefer the checkpoint callback's best score.
    best_score = checkpoint_callback.best_model_score
    if best_score is None:
        best_score = trainer.callback_metrics.get('val_acc', 0)
    best_val_acc = float(best_score)

    # Log metrics
    wandb.log({
        'best_val_acc': best_val_acc,
        'total_params': model.total_params,
        'total_computations': model.total_computations
    })

    return model, best_val_acc

def run_sweep():
    """Log in to W&B, register the sweep and run 30 trials with a local agent."""
    project_name = "inaturalist_cnn_sweep"

    wandb.login()

    # Register the sweep definition and get back its identifier.
    sweep_id = wandb.sweep(setup_wandb_sweep(), project=project_name)

    # Each trial calls train_model_sweep with a fresh set of hyperparameters.
    wandb.agent(sweep_id, function=train_model_sweep, count=30)

In [None]:
def train_final_model(config):
    """Train final model with best hyperparameters

    Args:
        config (dict): Hyperparameters with the keys
            ``filter_counts_strategy``, ``base_filters``, ``filter_size``,
            ``batch_size``, ``augmentation``, ``activation``,
            ``dense_neurons``, ``dropout_rate``, ``learning_rate`` and
            ``batch_norm``.

    Returns:
        tuple: ``(model, test_results)`` where ``test_results`` is the list
        returned by ``Trainer.test``.

    Raises:
        ValueError: If ``filter_counts_strategy`` is not a known strategy.
    """
    # Initialize wandb
    wandb.init(project="inaturalist_cnn_final", config=config)

    # Generate filter counts based on strategy
    base_filters = config['base_filters']
    strategies = {
        'same': [base_filters] * 5,
        'doubling': [base_filters * (2 ** i) for i in range(5)],
        'halving': [base_filters * (2 ** (4 - i)) for i in range(5)],
    }
    strategy = config['filter_counts_strategy']
    if strategy not in strategies:
        raise ValueError(f"Unknown filter_counts_strategy: {strategy!r}")
    filter_counts = strategies[strategy]

    # Generate filter sizes
    filter_sizes = [config['filter_size']] * 5

    # Create data module
    data_module = iNaturalistDataModule(
        batch_size=config['batch_size'],
        augmentation=config['augmentation']
    )
    data_module.setup()

    # Create model with hyperparameters
    model = CustomCNN(
        num_classes=10,  # Assuming 10 classes in iNaturalist subset
        filter_counts=filter_counts,
        filter_sizes=filter_sizes,
        activation=config['activation'],
        dense_neurons=config['dense_neurons'],
        dropout_rate=config['dropout_rate'],
        learning_rate=config['learning_rate'],
        batch_norm=config['batch_norm']
    )

    # Log gradients and parameters of the model to wandb
    wandb.watch(model, log="all")

    # Setup callbacks
    callbacks = [
        ModelCheckpoint(
            monitor='val_acc',
            filename='best-{epoch:02d}-{val_acc:.4f}',
            save_top_k=1,
            mode='max'
        )
    ]

    # Setup wandb logger
    wandb_logger = WandbLogger(project="inaturalist_cnn_final")

    # Create trainer
    trainer = Trainer(
        max_epochs=50,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        callbacks=callbacks,
        logger=wandb_logger,
        log_every_n_steps=10
    )

    # Train model
    trainer.fit(model, data_module.train_dataloader(), data_module.val_dataloader())

    # Test model (the Trainer.test keyword is `dataloaders`, plural)
    test_results = trainer.test(model, dataloaders=data_module.test_dataloader())

    # Log test results
    wandb.log({
        'test_acc': test_results[0]['test_acc'],
        'test_loss': test_results[0]['test_loss']
    })

    # Log model architecture
    wandb.log({
        'total_params': model.total_params,
        'total_computations': model.total_computations
    })

    return model, test_results

def visualize_test_samples(model, data_module, num_samples=30):
    """Visualize test samples with predictions

    Draws up to ``num_samples`` test images in a three-column grid, titled
    with the true and predicted class (green when correct, red otherwise),
    and logs the figure to wandb under ``test_predictions``.

    Args:
        model: Classifier mapping an image batch to per-class logits.
        data_module: Object exposing ``test_dataloader()`` and
            ``test_dataset.classes``.
        num_samples (int): Maximum number of images to show.

    Raises:
        ValueError: If ``num_samples`` is not positive or the test
            dataloader yields no samples.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    # Get test dataloader
    test_loader = data_module.test_dataloader()

    # Move to device and switch to eval mode so dropout is disabled and
    # batch norm uses its running statistics while predicting.
    model = model.to(device)
    was_training = model.training
    model.eval()

    image_chunks, label_chunks, pred_chunks = [], [], []
    collected = 0
    try:
        with torch.no_grad():
            # A single batch may hold fewer than num_samples images, so keep
            # reading batches until enough have been gathered.
            for batch_images, batch_labels in test_loader:
                outputs = model(batch_images.to(device))
                _, batch_preds = torch.max(outputs, 1)

                image_chunks.append(batch_images.cpu().numpy())
                label_chunks.append(batch_labels.cpu().numpy())
                pred_chunks.append(batch_preds.cpu().numpy())

                collected += batch_images.shape[0]
                if collected >= num_samples:
                    break
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)

    if collected == 0:
        raise ValueError("The test dataloader did not yield any samples")

    images = np.concatenate(image_chunks)[:num_samples]
    labels = np.concatenate(label_chunks)[:num_samples]
    predicted = np.concatenate(pred_chunks)[:num_samples]
    num_shown = len(images)

    # Get class names
    class_names = data_module.test_dataset.classes

    # Size the grid to the number of images actually available
    # (30 samples give the 10 x 3 layout).
    num_cols = 3
    num_rows = (num_shown + num_cols - 1) // num_cols
    fig, axes = plt.subplots(num_rows, num_cols,
                             figsize=(15, 3 * num_rows), squeeze=False)

    # ImageNet statistics used to undo the input normalisation.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= num_shown:
            continue

        # Transpose image from (C, H, W) to (H, W, C), then denormalize
        img = np.transpose(images[i], (1, 2, 0))
        img = np.clip(std * img + mean, 0, 1)
        ax.imshow(img)

        # Get true and predicted labels
        true_label = class_names[labels[i]]
        pred_label = class_names[predicted[i]]

        title_color = 'green' if labels[i] == predicted[i] else 'red'
        ax.set_title(f"True: {true_label}\nPred: {pred_label}", color=title_color)

    plt.tight_layout()

    # Log to wandb
    wandb.log({"test_predictions": wandb.Image(fig)})

def visualize_filters(model):
    """Visualize filters in the first convolutional layer

    Shows up to the first 64 filters of ``model.conv_layers[0][0]`` in an
    8 x 8 grid and logs the figure to wandb under ``first_layer_filters``.

    Args:
        model: Model whose ``conv_layers[0][0]`` is the first convolution.
            NOTE(review): each filter is drawn as a colour image, so this
            presumably expects 3 input channels — confirm for other inputs.
    """
    # Get first layer filters
    filters = model.conv_layers[0][0].weight.data.cpu().numpy()

    # Create figure
    fig, axes = plt.subplots(8, 8, figsize=(12, 12))

    for i, ax in enumerate(axes.flat):
        # Hide the frame on every cell, including those left empty when the
        # layer has fewer than 64 filters.
        ax.axis('off')
        if i >= filters.shape[0]:
            continue

        # Rescale the filter to [0, 1] for display.
        f = filters[i].transpose(1, 2, 0)
        f_min = f.min()
        value_range = f.max() - f_min
        if value_range > 0:
            f = (f - f_min) / value_range
        else:
            # A constant filter has no contrast; avoid dividing by zero.
            f = np.zeros_like(f)

        ax.imshow(f)

    plt.tight_layout()

    # Log to wandb
    wandb.log({"first_layer_filters": wandb.Image(fig)})

# Part B

In [None]:
class PretrainedModel(LightningModule):
    """ImageNet-pretrained torchvision classifier with a replaced output head.

    Supported ``model_name`` values are 'resnet50', 'vgg16', 'googlenet',
    'efficientnet_v2_s' and 'vit_b_16'.
    """

    def __init__(self, model_name='resnet50', num_classes=10, learning_rate=0.001,
                 fine_tuning_strategy='last_layer', feature_extract=True, unfreeze_layers=3):
        """
        Fine-tune a pre-trained model

        Args:
            model_name (str): Name of pre-trained model ('resnet50', 'vgg16', etc.)
            num_classes (int): Number of output classes
            learning_rate (float): Learning rate for optimizer
            fine_tuning_strategy (str): Strategy for fine-tuning ('last_layer', 'all_layers', 'k_last_layers')
            feature_extract (bool): If True, only update the reshaped layer params
                (recorded as a hyperparameter; the freezing itself is driven
                by ``fine_tuning_strategy``)
            unfreeze_layers (int): Number of layers to unfreeze for 'k_last_layers' strategy

        Raises:
            ValueError: If ``model_name`` is not supported.
        """
        super().__init__()
        self.save_hyperparameters()

        self.model_name = model_name
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.fine_tuning_strategy = fine_tuning_strategy
        self.unfreeze_layers = unfreeze_layers

        # Initialize the pre-trained model
        self.model = self._initialize_model()

    def _initialize_model(self):
        """Load the pre-trained network, replace its head and apply the fine-tuning strategy"""
        models = torchvision.models

        if self.model_name == 'resnet50':
            model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
            model.fc = nn.Linear(model.fc.in_features, self.num_classes)

        elif self.model_name == 'vgg16':
            model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
            model.classifier[6] = nn.Linear(model.classifier[6].in_features, self.num_classes)

        elif self.model_name == 'googlenet':
            model = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1)
            model.fc = nn.Linear(model.fc.in_features, self.num_classes)

        elif self.model_name == 'efficientnet_v2_s':
            model = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, self.num_classes)

        elif self.model_name == 'vit_b_16':
            model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
            model.heads.head = nn.Linear(model.heads.head.in_features, self.num_classes)

        else:
            raise ValueError(f"Unsupported model_name: {self.model_name!r}")

        # Apply fine-tuning strategy
        self._apply_fine_tuning_strategy(model)

        return model

    @staticmethod
    def _set_trainable(module, trainable):
        """Set ``requires_grad`` on every parameter of ``module``"""
        for param in module.parameters():
            param.requires_grad = trainable

    def _head(self, model):
        """Return the freshly created output layer of ``model``"""
        if self.model_name in ('resnet50', 'googlenet'):
            return model.fc
        if self.model_name == 'vgg16':
            return model.classifier[6]
        if self.model_name == 'efficientnet_v2_s':
            return model.classifier[1]
        return model.heads.head  # vit_b_16

    def _trailing_blocks(self, model):
        """Return the model's blocks ordered from the output backwards (non-VGG models)"""
        if self.model_name == 'resnet50':
            return [model.fc, model.layer4, model.layer3]
        if self.model_name == 'googlenet':
            return [model.fc, model.inception5b, model.inception5a]
        if self.model_name == 'efficientnet_v2_s':
            return [model.classifier, *reversed(list(model.features))]
        return [model.heads, *reversed(list(model.encoder.layers))]  # vit_b_16

    def _apply_fine_tuning_strategy(self, model):
        """Apply the fine-tuning strategy to the model

        'last_layer' trains only the new head, 'all_layers' trains
        everything, and 'k_last_layers' trains the last ``unfreeze_layers``
        blocks. The head is always trainable under 'k_last_layers' so the
        optimizer never receives an empty parameter list. Any other strategy
        name leaves the parameters untouched.
        """
        if self.fine_tuning_strategy == 'last_layer':
            self._set_trainable(model, False)
            self._set_trainable(self._head(model), True)

        elif self.fine_tuning_strategy == 'all_layers':
            self._set_trainable(model, True)

        elif self.fine_tuning_strategy == 'k_last_layers':
            # First freeze all parameters
            self._set_trainable(model, False)

            # At least one block (the head) must remain trainable.
            num_layers = max(1, self.unfreeze_layers)

            if self.model_name == 'vgg16':
                # Count backwards through the classifier modules first ...
                classifier_modules = list(model.classifier)
                num_unfrozen = min(num_layers, len(classifier_modules))
                for module in classifier_modules[-num_unfrozen:]:
                    self._set_trainable(module, True)

                # ... then, if needed, continue into the feature extractor.
                remaining = num_layers - len(classifier_modules)
                if remaining > 0:
                    feature_modules = list(model.features)
                    num_unfrozen_features = min(remaining, len(feature_modules))
                    for module in feature_modules[-num_unfrozen_features:]:
                        self._set_trainable(module, True)
            else:
                for block in self._trailing_blocks(model)[:num_layers]:
                    self._set_trainable(block, True)

    def forward(self, x):
        """Forward pass"""
        return self.model(x)

    def configure_optimizers(self):
        """Configure optimizer over the parameters left trainable by the strategy"""
        params_to_update = [p for p in self.parameters() if p.requires_grad]
        return optim.Adam(params_to_update, lr=self.learning_rate)

    def _shared_step(self, batch):
        """Run one batch through the model and return (loss, accuracy, predictions, targets)."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        return loss, acc, preds, y

    def training_step(self, batch, batch_idx):
        """Training step"""
        loss, acc, _, _ = self._shared_step(batch)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step"""
        loss, acc, _, _ = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Test step"""
        loss, acc, preds, y = self._shared_step(batch)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return {'loss': loss, 'preds': preds, 'targets': y}

In [None]:
def train_finetune_model(config):
    """Fine-tune one pre-trained model configuration and evaluate it on the test split.

    Args:
        config: Mapping that provides ``'model_name'``, ``'learning_rate'``,
            ``'fine_tuning_strategy'`` and ``'batch_size'``. ``'unfreeze_layers'``
            is optional and falls back to 3.

    Returns:
        A ``(model, test_results)`` tuple. ``test_results`` is whatever
        ``Trainer.test`` returns; its first entry is read for ``'test_acc'``
        and ``'test_loss'``.
    """
    # One W&B run per configuration. It is closed in the ``finally`` below so
    # that consecutive calls do not log into a run that is still open.
    wandb.init(project="inaturalist_finetune", config=config)

    try:
        # Create data module
        data_module = iNaturalistDataModule(
            batch_size=config['batch_size'],
            augmentation=True  # Always use augmentation for fine-tuning
        )
        data_module.setup()

        # Create model
        model = PretrainedModel(
            model_name=config['model_name'],
            num_classes=10,  # Assuming 10 classes in iNaturalist subset
            learning_rate=config['learning_rate'],
            fine_tuning_strategy=config['fine_tuning_strategy'],
            unfreeze_layers=config.get('unfreeze_layers', 3)
        )

        # Track gradients and parameters of the model in W&B
        wandb.watch(model, log="all")

        # Keep the checkpoint with the best validation accuracy and stop once
        # it has not improved for 5 validation rounds
        callbacks = [
            ModelCheckpoint(
                monitor='val_acc',
                filename='best-{epoch:02d}-{val_acc:.4f}',
                save_top_k=1,
                mode='max'
            ),
            EarlyStopping(
                monitor='val_acc',
                patience=5,
                mode='max'
            )
        ]

        # Setup wandb logger
        wandb_logger = WandbLogger(project="inaturalist_finetune")

        # 16-bit mixed precision is only requested when a GPU is used; on CPU
        # the trainer runs in regular 32-bit precision.
        use_gpu = torch.cuda.is_available()
        trainer = Trainer(
            max_epochs=20,
            accelerator='gpu' if use_gpu else 'cpu',
            devices=1,
            callbacks=callbacks,
            logger=wandb_logger,
            log_every_n_steps=10,
            precision=16 if use_gpu else 32
        )

        # Train model
        trainer.fit(model, data_module.train_dataloader(), data_module.val_dataloader())

        # Test model (the keyword accepted by Trainer.test is ``dataloaders``)
        test_results = trainer.test(model, dataloaders=data_module.test_dataloader())

        # Log test results
        wandb.log({
            'test_acc': test_results[0]['test_acc'],
            'test_loss': test_results[0]['test_loss']
        })

        # Detach the W&B hooks so the returned model can be used after the
        # run has been closed.
        wandb.unwatch(model)
    finally:
        # Close the run even if training or testing failed.
        wandb.finish()

    return model, test_results

def compare_finetuning_strategies():
    """Train every predefined fine-tuning configuration and report its test metrics.

    Each configuration is passed to ``train_finetune_model``; the test accuracy
    and loss from the first entry of its test output are collected, one summary
    line per configuration is printed, and the collected entries are returned.

    Returns:
        List of dicts with the keys ``'config'``, ``'test_acc'`` and
        ``'test_loss'``, in the order the configurations were trained.
    """
    # ResNet-50 with the three freezing strategies
    resnet_last_layer = {
        'model_name': 'resnet50',
        'fine_tuning_strategy': 'last_layer',
        'learning_rate': 0.001,
        'batch_size': 32
    }
    resnet_k_last_layers = {
        'model_name': 'resnet50',
        'fine_tuning_strategy': 'k_last_layers',
        'unfreeze_layers': 3,
        'learning_rate': 0.0001,
        'batch_size': 32
    }
    resnet_all_layers = {
        'model_name': 'resnet50',
        'fine_tuning_strategy': 'all_layers',
        'learning_rate': 0.00001,
        'batch_size': 32
    }
    # Alternative backbones, last layer only
    efficientnet_last_layer = {
        'model_name': 'efficientnet_v2_s',
        'fine_tuning_strategy': 'last_layer',
        'learning_rate': 0.001,
        'batch_size': 32
    }
    vit_last_layer = {
        'model_name': 'vit_b_16',
        'fine_tuning_strategy': 'last_layer',
        'learning_rate': 0.001,
        'batch_size': 32
    }
    configurations = [
        resnet_last_layer,
        resnet_k_last_layers,
        resnet_all_layers,
        efficientnet_last_layer,
        vit_last_layer,
    ]

    results = []
    for cfg in configurations:
        _, test_output = train_finetune_model(cfg)
        metrics = test_output[0]
        entry = {
            'config': cfg,
            'test_acc': metrics['test_acc'],
            'test_loss': metrics['test_loss']
        }
        results.append(entry)

    # One summary line per trained configuration
    for entry in results:
        cfg = entry['config']
        print(f"Model: {cfg['model_name']}, "
              f"Strategy: {cfg['fine_tuning_strategy']}, "
              f"Test Acc: {entry['test_acc']:.4f}")

    return results

# Main

In [None]:
# Main script to run the assignment

import os
import argparse
import torch
import wandb
from datetime import datetime

# Custom modules
# Note: nothing is imported here. The functions and classes from the previous code blocks
# must already be in scope, or be saved in appropriate Python files and imported.

def _run_part_a(args):
    """Run Part A: optional sweep, then optional training and visualization."""
    print("Running Part A: Training from Scratch")

    # Run hyperparameter sweep
    if args.sweep:
        print("Running hyperparameter sweep...")
        run_sweep()

    # Everything below needs a trained model
    if not args.train:
        return

    print("Training final model...")

    # Best hyperparameters from sweep
    best_config = {
        'filter_counts_strategy': 'doubling',
        'base_filters': 32,
        'filter_size': 3,
        'activation': 'relu',
        'dense_neurons': 512,
        'dropout_rate': 0.3,
        'learning_rate': 0.001,
        'batch_norm': True,
        'batch_size': 32,
        'augmentation': True
    }

    # Train model (the returned test results are not used here)
    model_a, _ = train_final_model(best_config)

    # Test and visualize
    if args.test:
        data_module = iNaturalistDataModule(
            data_dir=args.data_dir,
            batch_size=best_config['batch_size'],
            augmentation=best_config['augmentation']
        )
        data_module.setup()

        visualize_test_samples(model_a, data_module)
        visualize_filters(model_a)


def _run_part_b(args):
    """Run Part B: compare fine-tuning strategies and optionally visualize the best."""
    print("Running Part B: Fine-tuning Pre-trained Model")

    # Everything below needs trained models
    if not args.train:
        return

    print("Training and comparing fine-tuning strategies...")
    results = compare_finetuning_strategies()

    # Log the comparison in a run of its own: close any run that is still
    # open (a no-op when none is), and finish this one before more training.
    wandb.finish()
    wandb.init(project="inaturalist_finetune_comparison")

    # Create a table for the results
    table = wandb.Table(columns=["Model", "Strategy", "Test Accuracy"])
    for result in results:
        table.add_data(
            result['config']['model_name'],
            result['config']['fine_tuning_strategy'],
            result['test_acc']
        )

    wandb.log({"finetuning_comparison": table})
    wandb.finish()

    # Find best model
    best_result = max(results, key=lambda x: x['test_acc'])
    print(f"Best fine-tuning result: {best_result}")

    # Test and visualize best model
    if args.test:
        # The comparison only returns metrics, so the best configuration is
        # trained again to obtain a model object.
        best_model, _ = train_finetune_model(best_result['config'])

        # Setup data module
        data_module = iNaturalistDataModule(
            data_dir=args.data_dir,
            batch_size=best_result['config']['batch_size'],
            augmentation=True
        )
        data_module.setup()

        # Visualize test samples
        visualize_test_samples(best_model, data_module)


def main(argv=None):
    """Parse the command line and run the requested parts of the assignment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            ``argparse`` read ``sys.argv``, which is the original behavior;
            passing a list allows calling ``main`` programmatically.
    """
    parser = argparse.ArgumentParser(description='DA6401 Assignment 2 - CNN Training')
    parser.add_argument('--part', type=str, default='both', choices=['a', 'b', 'both'],
                        help='Which part of the assignment to run (a, b, or both)')
    parser.add_argument('--sweep', action='store_true', help='Run hyperparameter sweep')
    parser.add_argument('--train', action='store_true', help='Train final model')
    parser.add_argument('--test', action='store_true', help='Test model')
    parser.add_argument('--data_dir', type=str, default='./inaturalist_data',
                        help='Directory containing the dataset')
    parser.add_argument('--output_dir', type=str, default='./output',
                        help='Directory to save output')
    args = parser.parse_args(argv)

    # Testing/visualization only happens on a model trained in the same
    # invocation, so --test alone would silently do nothing.
    if args.test and not args.train:
        print("Warning: --test has no effect without --train")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Initialize wandb
    wandb.login()

    if args.part in ['a', 'both']:
        _run_part_a(args)

    if args.part in ['b', 'both']:
        _run_part_b(args)

# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()