# 

# Part B

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
import torchvision.transforms.functional as TF
from PIL import Image
import matplotlib.pyplot as plt
import wandb
import random
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Set random seeds for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Constants
NUM_CLASSES = 10
BATCH_SIZE = 32
NUM_EPOCHS = 10
LEARNING_RATE = 0.001
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cuda


In [None]:
# Define paths
TRAIN_DIR = "/kaggle/input/inaturalist/inaturalist_12K/train"
TEST_DIR = "/kaggle/input/inaturalist/inaturalist_12K/val"

# Custom Dataset for iNaturalist
class INaturalistDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        
        self.samples = []
        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            for img_name in os.listdir(class_dir):
                img_path = os.path.join(class_dir, img_name)
                self.samples.append((img_path, self.class_to_idx[class_name]))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

# Data transformations
# ImageNet normalization values
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Define transformations for training and testing
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet models require 224x224 input
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

# Load datasets
train_dataset = INaturalistDataset(TRAIN_DIR, transform=train_transform)

# Split into training and validation
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_subset, val_subset = random_split(train_dataset, [train_size, val_size])

# Create DataLoaders
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Load test dataset
test_dataset = INaturalistDataset(TEST_DIR, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

In [None]:
# Updated main function with enhanced wandb logging
def run_experiment(model_name, freeze_strategy, num_classes, train_loader, val_loader, 
                  test_loader, test_dataset, num_epochs=NUM_EPOCHS):
    """Run a complete fine-tuning experiment with comprehensive wandb logging."""
    # Initialize wandb run
    run_name = f"{model_name}_{freeze_strategy}"
    wandb.init(project="inaturalist_fine_tuning", name=run_name, config={
        "model": model_name,
        "freeze_strategy": freeze_strategy,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "epochs": num_epochs,
        "num_classes": num_classes
    })
    
    # Load model with specified freezing strategy
    model = load_pretrained_model(model_name=model_name, 
                                 freeze_layers=freeze_strategy, 
                                 num_classes=num_classes)
    
    # Calculate and log trainable parameters
    trainable_params = count_trainable_parameters(model)
    total_params = sum(p.numel() for p in model.parameters())
    wandb.log({
        "trainable_parameters": trainable_params,
        "total_parameters": total_params,
        "percent_trainable": (trainable_params / total_params) * 100
    })
    print(f"Strategy: {freeze_strategy} - Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%})")
    
    # Set up optimizer based on strategy
    if freeze_strategy == "none":
        # Different learning rates for pre-trained vs new layers
        params_to_update = []
        params_new = []
        
        for name, param in model.named_parameters():
            if name.startswith('fc') or name.startswith('classifier') or name.startswith('heads'):
                params_new.append(param)
            else:
                params_to_update.append(param)
        
        optimizer = optim.SGD([
            {'params': params_to_update, 'lr': LEARNING_RATE * 0.1},
            {'params': params_new, 'lr': LEARNING_RATE}
        ], momentum=0.9)
    else:
        # Regular optimizer for frozen models
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 
                             lr=LEARNING_RATE, momentum=0.9)
    
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    criterion = nn.CrossEntropyLoss()
    
    # Create a table to log per-epoch metrics
    columns = ["epoch", "train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
    metrics_table = wandb.Table(columns=columns)
    
    # Train model
    print(f"Training model: {model_name} with strategy: {freeze_strategy}")
    best_val_acc = 0.0
    best_model_wts = None
    
    # History to track metrics
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        
        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0
        
        for inputs, labels in train_loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                
                # Backward pass + optimize
                loss.backward()
                optimizer.step()
            
            # Statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        
        if scheduler:
            scheduler.step()
            
        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects.double() / len(train_loader.dataset)
        
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc.item())
        
        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        
        # Validation phase
        model.eval()
        running_loss = 0.0
        running_corrects = 0
        
        for inputs, labels in val_loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            
            # Forward pass
            with torch.no_grad():
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
            
            # Statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        
        val_loss = running_loss / len(val_loader.dataset)
        val_acc = running_corrects.double() / len(val_loader.dataset)
        
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc.item())
        
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')
        
        # Add row to metrics table
        curr_lr = optimizer.param_groups[0]['lr']
        metrics_table.add_data(epoch+1, epoch_loss, epoch_acc.item(), val_loss, val_acc.item(), curr_lr)
        
        # Log metrics for this epoch
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": epoch_loss,
            "train_acc": epoch_acc.item(),
            "val_loss": val_loss,
            "val_acc": val_acc.item(),
            "learning_rate": curr_lr
        })
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_wts = model.state_dict().copy()
            model_filename = f'best_model_{model_name}_{freeze_strategy}.pth'
            torch.save(best_model_wts, model_filename)
            print(f"Saved new best model with accuracy: {best_val_acc:.4f}")
            
            # Log best model as artifact
            model_artifact = wandb.Artifact(f"model-{run_name}", type="model")
            model_artifact.add_file(model_filename)
            wandb.log_artifact(model_artifact)
    
    # Log final metrics table
    wandb.log({"training_metrics": metrics_table})
    
    # Load best model and evaluate on test set
    model.load_state_dict(best_model_wts)
    class_names = get_class_names(test_dataset)
    test_loss, test_acc, cm, all_preds, all_labels = evaluate_model(model, test_loader, criterion, class_names, model_name, freeze_strategy)
    
    # Log final test metrics
    wandb.log({
        "final_test_accuracy": test_acc,
        "final_test_loss": test_loss
    })
    
    # Visualize incorrect predictions and log
    visualize_incorrect_predictions(test_dataset, test_loader, model, class_names)
    
    # Create and log training curves
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(history['train_acc'], label='Train')
    plt.plot(history['val_acc'], label='Validation')
    plt.title(f'{model_name}_{freeze_strategy} accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(history['train_loss'], label='Train')
    plt.plot(history['val_loss'], label='Validation')
    plt.title(f'{model_name}_{freeze_strategy} loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.tight_layout()
    wandb.log({"training_curves": wandb.Image(plt)})
    plt.savefig(f'training_curves_{freeze_strategy}.png')
    plt.close()
    
    wandb.finish()
    return model, history, test_acc

In [None]:
def get_class_names(dataset):
    """Get the class names from the dataset."""
    return dataset.classes

# Function to load a pre-trained model and modify the last layer
def load_pretrained_model(model_name="resnet50", freeze_layers="all_except_last", num_classes=10):
    """
    Load a pre-trained model and modify it for fine-tuning
    
    Parameters:
        model_name: Name of the model to load (resnet50, vgg16, etc.)
        freeze_layers: Strategy for freezing layers
            - "all_except_last": Freeze all layers except the last layer
            - "none": Don't freeze any layers (full fine-tuning)
            - "first_k": Freeze only the first k layers
            - "all_except_k": Freeze all layers except the last k layers
        num_classes: Number of output classes
        
    Returns:
        model: Modified model ready for fine-tuning
    """
    if model_name == "resnet50":
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        
        # Replace the final fully connected layer
        in_features = model.fc.in_features
        model.fc = nn.Linear(in_features, num_classes)
        
        # Freeze layers according to the strategy
        if freeze_layers == "all_except_last":
            # Freeze all layers except the final fc layer
            for param in model.parameters():
                param.requires_grad = False
            for param in model.fc.parameters():
                param.requires_grad = True
                
        elif freeze_layers == "none":
            # Don't freeze any layers (full fine-tuning)
            pass
            
        elif freeze_layers.startswith("first_"):
            # Freeze the first k layers
            k = int(freeze_layers.split("_")[1])
            layers_to_freeze = list(model.named_children())[:k]
            for name, layer in layers_to_freeze:
                for param in layer.parameters():
                    param.requires_grad = False
                    
        elif freeze_layers.startswith("all_except_"):
            # Freeze all layers except the last k layers
            k = int(freeze_layers.split("_")[2])
            total_layers = len(list(model.named_children()))
            layers_to_freeze = list(model.named_children())[:(total_layers-k)]
            for name, layer in layers_to_freeze:
                for param in layer.parameters():
                    param.requires_grad = False
                    
    elif model_name == "vgg16":
        model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        # Replace classifier
        model.classifier[6] = nn.Linear(4096, num_classes)
        
        # Apply freezing strategy
        if freeze_layers == "all_except_last":
            for param in model.parameters():
                param.requires_grad = False
            for param in model.classifier[6].parameters():
                param.requires_grad = True
                
    elif model_name == "efficientnet_v2_s":
        model = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        # Replace classifier
        in_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(in_features, num_classes)
        
        # Apply freezing strategy
        if freeze_layers == "all_except_last":
            for param in model.parameters():
                param.requires_grad = False
            for param in model.classifier[1].parameters():
                param.requires_grad = True
                
    elif model_name == "googlenet":
        model = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1)
        # Replace fc layer
        in_features = model.fc.in_features
        model.fc = nn.Linear(in_features, num_classes)
        
        # Apply freezing strategy
        if freeze_layers == "all_except_last":
            for param in model.parameters():
                param.requires_grad = False
            for param in model.fc.parameters():
                param.requires_grad = True
                
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
        # Replace the head
        in_features = model.heads.head.in_features
        model.heads.head = nn.Linear(in_features, num_classes)
        
        # Apply freezing strategy
        if freeze_layers == "all_except_last":
            for param in model.parameters():
                param.requires_grad = False
            for param in model.heads.head.parameters():
                param.requires_grad = True
                
    else:
        raise ValueError(f"Model {model_name} not supported")
    
    # Move model to device
    model = model.to(DEVICE)
    return model

# Function to count trainable parameters
def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler=None, num_epochs=10):
    """Train the model and return training history."""
    # Initialize wandb for logging
    wandb.init(project="inaturalist_fine_tuning")
    
    # Log model architecture and hyperparameters
    wandb.config.update({
        "model": model.__class__.__name__,
        "trainable_params": count_trainable_parameters(model),
        "epochs": num_epochs,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "optimizer": optimizer.__class__.__name__
    })
    
    # Training history
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }
    
    # Best model tracking
    best_val_acc = 0.0
    best_model_wts = None
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        
        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0
        
        for inputs, labels in train_loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                
                # Backward pass + optimize
                loss.backward()
                optimizer.step()
            
            # Statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        
        if scheduler:
            scheduler.step()
            
        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects.double() / len(train_loader.dataset)
        
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc.item())
        
        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        
        # Validation phase
        model.eval()
        running_loss = 0.0
        running_corrects = 0
        
        for inputs, labels in val_loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            
            # Forward pass
            with torch.no_grad():
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
            
            # Statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        
        epoch_loss = running_loss / len(val_loader.dataset)
        epoch_acc = running_corrects.double() / len(val_loader.dataset)
        
        history['val_loss'].append(epoch_loss)
        history['val_acc'].append(epoch_acc.item())
        
        print(f'Val Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        
        # Save best model
        if epoch_acc > best_val_acc:
            best_val_acc = epoch_acc
            best_model_wts = model.state_dict().copy()
            torch.save(best_model_wts, 'best_model.pth')
            print(f"Saved new best model with accuracy: {best_val_acc:.4f}")
        
        # Log to wandb
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": history['train_loss'][-1],
            "train_acc": history['train_acc'][-1],
            "val_loss": history['val_loss'][-1],
            "val_acc": history['val_acc'][-1],
            "learning_rate": optimizer.param_groups[0]['lr']
        })
    
    # Load best model
    model.load_state_dict(best_model_wts)
    wandb.finish()
    return model, history

# Evaluation function
def evaluate_model(model, test_loader, criterion, class_names, model_name, freeze_strategy):
    """Evaluate the model on test data and log results to wandb."""
    model.eval()
    
    running_loss = 0.0
    running_corrects = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    test_loss = running_loss / len(test_loader.dataset)
    test_acc = running_corrects.double() / len(test_loader.dataset)
    
    print(f'Test Loss: {test_loss:.4f} Acc: {test_acc:.4f}')
    
    # Compute confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    
    # Log test metrics to wandb
    wandb.log({
        "test_loss": test_loss,
        "test_accuracy": test_acc.item(),
    })
    
    # Create and log confusion matrix visualization
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix: {model_name}_{freeze_strategy}')
    plt.tight_layout()
    
    # Log confusion matrix to wandb
    wandb.log({"confusion_matrix": wandb.Image(plt)})
    plt.savefig('confusion_matrix.png')
    plt.close()
    
    return test_loss, test_acc.item(), cm, all_preds, all_labels

# Function to visualize incorrect predictions
def visualize_incorrect_predictions(test_dataset, test_loader, model, class_names, num_images=10):
    model.eval()
    incorrect_samples = []
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            # Find incorrect predictions
            incorrect_mask = preds != labels
            incorrect_indices = torch.nonzero(incorrect_mask).squeeze().cpu()
            
            if len(incorrect_indices.shape) == 0 and incorrect_indices.numel() > 0:
                incorrect_indices = incorrect_indices.unsqueeze(0)
            
            for idx in incorrect_indices:
                img_tensor = inputs[idx].cpu()
                true_label = labels[idx].item()
                pred_label = preds[idx].item()
                incorrect_samples.append((img_tensor, true_label, pred_label))
                
                if len(incorrect_samples) >= num_images:
                    break
            
            if len(incorrect_samples) >= num_images:
                break
    
    # Plot incorrect predictions
    if incorrect_samples:
        fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))
        axes = axes.flatten()
        
        for i, (img_tensor, true_label, pred_label) in enumerate(incorrect_samples[:num_images]):
            # Denormalize the image
            img = img_tensor.numpy().transpose((1, 2, 0))
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            img = std * img + mean
            img = np.clip(img, 0, 1)
            
            axes[i].imshow(img)
            axes[i].set_title(f"True: {class_names[true_label]}\nPred: {class_names[pred_label]}")
            axes[i].axis('off')
        
        plt.tight_layout()
        
        # Log incorrect predictions to wandb
        wandb.log({"incorrect_predictions": wandb.Image(fig)})
        plt.savefig('incorrect_predictions.png')
        plt.close()

In [None]:
# Modified main function to run all experiments and compare them
def main():
    wandb.login(key="e030007b097df00d9a751748294abc8440f932b1")
    
    """Run all fine-tuning experiments and compare results."""
    # Get class names
    class_names = get_class_names(train_dataset)
    
    # Dictionary to store results
    results = {}
    model_name = "resnet50"
    
    # Strategy 1: Freeze all layers except the last layer
    model1, history1, acc1 = run_experiment(
        model_name=model_name, 
        freeze_strategy="all_except_last",
        num_classes=NUM_CLASSES,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        test_dataset=test_dataset
    )
    results["all_except_last"] = {"accuracy": acc1, "history": history1}
    
    # Strategy 2: Full fine-tuning (no freezing)
    model2, history2, acc2 = run_experiment(
        model_name=model_name, 
        freeze_strategy="none",
        num_classes=NUM_CLASSES,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        test_dataset=test_dataset
    )
    results["none"] = {"accuracy": acc2, "history": history2}

    # Strategy 3: Freeze first 6 layers
    model3, history3, acc3 = run_experiment(
        model_name=model_name, 
        freeze_strategy="first_6",
        num_classes=NUM_CLASSES,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        test_dataset=test_dataset
    )
    results["first_6"] = {"accuracy": acc3, "history": history3}

    # Final comparison run
    wandb.init(project="inaturalist_fine_tuning", name="Model_comparison")
    
    # Compare strategies with a table
    comparison_table = wandb.Table(columns=["Strategy", "Test Accuracy", "Trainable Parameters"])
    
    comparison_table.add_data("Freeze all except last", results["all_except_last"]["accuracy"], 
                             count_trainable_parameters(model1))
    comparison_table.add_data("Full fine-tuning", results["none"]["accuracy"], 
                             count_trainable_parameters(model2))
    comparison_table.add_data("Freeze first 6 layers", results["first_6"]["accuracy"], 
                             count_trainable_parameters(model3))

    comparison_table = wandb.Table(columns=["Strategy", "Test Accuracy", "Trainable Parameters"])
    
    wandb.log({"strategy_comparison": comparison_table})
    
    # Plot comparison chart
    plt.figure(figsize=(15, 6))
    
    plt.subplot(1, 2, 1)
    for strategy, data in results.items():
        plt.plot(data["history"]["train_acc"], linestyle='-', label=f'{strategy} (Train)')
        plt.plot(data["history"]["val_acc"], linestyle='--', label=f'{strategy} (Val)')
    plt.title(f'Accuracy Comparison Across Strategies: {model_name}')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    for strategy, data in results.items():
        plt.plot(data["history"]["train_loss"], linestyle='-', label=f'{strategy} (Train)')
        plt.plot(data["history"]["val_loss"], linestyle='--', label=f'{strategy} (Val)')
    plt.title(f'Loss Comparison Across Strategies: {model_name}')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.tight_layout()
    wandb.log({"strategy_comparison_chart": wandb.Image(plt)})
    plt.savefig('strategy_comparison.png')
    plt.close()
    
    # Create bar chart of test accuracies 
    plt.figure(figsize=(10, 6))
    strategies = list(results.keys())
    accuracies = [results[s]["accuracy"] for s in strategies]
    
    plt.bar(strategies, accuracies)
    plt.title(f'Test Accuracy by Fine-tuning Strategy: {model_name}')
    plt.xlabel('Strategy')
    plt.ylabel('Test Accuracy')
    plt.ylim(0, 1.0)
    
    for i, acc in enumerate(accuracies):
        plt.text(i, acc + 0.01, f'{acc:.4f}', ha='center')
    
    wandb.log({"accuracy_comparison": wandb.Image(plt)})
    plt.savefig('accuracy_comparison.png')
    plt.close()
    
    wandb.finish()
    
    print("All experiments completed!")
    print("\nTest Accuracies:")
    for model_key, data in results.items():
        print(f"- {model_key}: {data['accuracy']:.4f}")

In [None]:
# Modified main function to run all experiments and compare them
def main():
    key = ""      # Add your WandB API key here
    wandb.login(key=key)
    
    """Run all fine-tuning experiments and compare results."""
    # Get class names
    class_names = get_class_names(train_dataset)
    
    # Dictionary to store results
    results = {}
    model_name = "resnet50"
    
    model2, history2, acc2 = run_experiment(
        model_name="vgg16", 
        freeze_strategy="all_except_last",
        num_classes=NUM_CLASSES,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        test_dataset=test_dataset
    )
    results["vgg16"] = {"accuracy": acc2, "history": history2}

    
    model3, history3, acc3 = run_experiment(
        model_name="efficientnet_v2_s", 
        freeze_strategy="all_except_last",
        num_classes=NUM_CLASSES,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        test_dataset=test_dataset
    )
    results["efficientnet_v2_s"] = {"accuracy": acc3, "history": history3}

    
    model4, history4, acc4 = run_experiment(
        model_name="vit_b_16", 
        freeze_strategy="all_except_last",
        num_classes=NUM_CLASSES,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        test_dataset=test_dataset
    )
    results["vit_b_16"] = {"accuracy": acc4, "history": history4}
    
    # Final comparison run
    wandb.init(project="inaturalist_fine_tuning", name="Model_comparison")
    
    comparison_table = wandb.Table(columns=["Model", "Test Accuracy", "Trainable Parameters"])
    
    comparison_table.add_data("VGG16", results["vgg16"]["accuracy"], 
                             count_trainable_parameters(model2))
    comparison_table.add_data("EfficientNet", results["efficientnet_v2_s"]["accuracy"], 
                             count_trainable_parameters(model3))
    comparison_table.add_data("vit_b_16", results["vit_b_16"]["accuracy"], 
                             count_trainable_parameters(model4))
    
    wandb.log({"models_comparison": comparison_table})
    
    # Plot comparison chart
    plt.figure(figsize=(15, 6))
    
    plt.subplot(1, 2, 1)
    for model_name, data in results.items():
        plt.plot(data["history"]["train_acc"], linestyle='-', label=f'{model_name} (Train)')
        plt.plot(data["history"]["val_acc"], linestyle='--', label=f'{model_name} (Val)')
    plt.title(f'Accuracy Comparison Across Strategies: {model_name}')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    for model_name, data in results.items():
        plt.plot(data["history"]["train_loss"], linestyle='-', label=f'{model_name} (Train)')
        plt.plot(data["history"]["val_loss"], linestyle='--', label=f'{model_name} (Val)')
    plt.title(f'Loss Comparison Across Strategies: {model_name}')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.tight_layout()
    wandb.log({"model_comparison_chart": wandb.Image(plt)})
    plt.savefig('model_comparison.png')
    plt.close()
    
    # Create bar chart of test accuracies 
    plt.figure(figsize=(10, 6))
    models = list(results.keys())
    accuracies = [results[s]["accuracy"] for s in models]
    
    plt.bar(models, accuracies)
    plt.title(f'Test Accuracy by Fine-tuning Strategy: {"all_except_last"}')
    plt.xlabel('Model')
    plt.ylabel('Test Accuracy')
    plt.ylim(0, 1.0)
    
    for i, acc in enumerate(accuracies):
        plt.text(i, acc + 0.01, f'{acc:.4f}', ha='center')
    
    wandb.log({"accuracy_comparison": wandb.Image(plt)})
    plt.savefig('accuracy_comparison.png')
    plt.close()
    
    wandb.finish()
    
    print("All experiments completed!")
    print("\nTest Accuracies:")
    for model_key, data in results.items():
        print(f"- {model_key}: {data['accuracy']:.4f}")

In [6]:
main()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmm21b044[0m ([33mmm21b044-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 195MB/s]  


Strategy: all_except_last - Trainable parameters: 40,970 (0.03%)
Training model: vgg16 with strategy: all_except_last
Epoch 1/10
----------
Train Loss: 1.1370 Acc: 0.6191
Val Loss: 0.8624 Acc: 0.7185
Saved new best model with accuracy: 0.7185
Epoch 2/10
----------
Train Loss: 0.9301 Acc: 0.6852
Val Loss: 0.8042 Acc: 0.7275
Saved new best model with accuracy: 0.7275
Epoch 3/10
----------
Train Loss: 0.8968 Acc: 0.6947
Val Loss: 0.7770 Acc: 0.7330
Saved new best model with accuracy: 0.7330
Epoch 4/10
----------
Train Loss: 0.8599 Acc: 0.7030
Val Loss: 0.7720 Acc: 0.7415
Saved new best model with accuracy: 0.7415
Epoch 5/10
----------
Train Loss: 0.8459 Acc: 0.7073
Val Loss: 0.7654 Acc: 0.7350
Epoch 6/10
----------
Train Loss: 0.8409 Acc: 0.7161
Val Loss: 0.7613 Acc: 0.7385
Epoch 7/10
----------
Train Loss: 0.8377 Acc: 0.7082
Val Loss: 0.7468 Acc: 0.7470
Saved new best model with accuracy: 0.7470
Epoch 8/10
----------
Train Loss: 0.7874 Acc: 0.7337
Val Loss: 0.7285 Acc: 0.7520
Saved new b

0,1
epoch,▁▂▃▃▄▅▆▆▇█
final_test_accuracy,▁
final_test_loss,▁
learning_rate,██████▁▁▁▁
percent_trainable,▁
test_accuracy,▁
test_loss,▁
total_parameters,▁
train_acc,▁▅▆▆▆▇▆███
train_loss,█▄▃▃▂▂▂▁▁▁

0,1
epoch,10.0
final_test_accuracy,0.769
final_test_loss,0.6743
learning_rate,0.0001
percent_trainable,0.03051
test_accuracy,0.769
test_loss,0.6743
total_parameters,134301514.0
train_acc,0.72959
train_loss,0.77947


Downloading: "https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_v2_s-dd5fe13b.pth
100%|██████████| 82.7M/82.7M [00:00<00:00, 180MB/s] 


Strategy: all_except_last - Trainable parameters: 12,810 (0.06%)
Training model: efficientnet_v2_s with strategy: all_except_last
Epoch 1/10
----------
Train Loss: 1.8086 Acc: 0.4894
Val Loss: 1.4164 Acc: 0.6580
Saved new best model with accuracy: 0.6580
Epoch 2/10
----------
Train Loss: 1.2965 Acc: 0.6477
Val Loss: 1.1686 Acc: 0.6680
Saved new best model with accuracy: 0.6680
Epoch 3/10
----------
Train Loss: 1.1539 Acc: 0.6612
Val Loss: 1.3341 Acc: 0.6800
Saved new best model with accuracy: 0.6800
Epoch 4/10
----------
Train Loss: 1.0925 Acc: 0.6686
Val Loss: 1.0173 Acc: 0.6975
Saved new best model with accuracy: 0.6975
Epoch 5/10
----------
Train Loss: 1.0468 Acc: 0.6775
Val Loss: 1.7514 Acc: 0.6965
Epoch 6/10
----------
Train Loss: 1.0187 Acc: 0.6808
Val Loss: 2.1136 Acc: 0.7095
Saved new best model with accuracy: 0.7095
Epoch 7/10
----------
Train Loss: 1.0027 Acc: 0.6845
Val Loss: 0.9355 Acc: 0.7045
Epoch 8/10
----------
Train Loss: 0.9904 Acc: 0.6876
Val Loss: 0.9594 Acc: 0.7065

0,1
epoch,▁▂▃▃▄▅▆▆▇█
final_test_accuracy,▁
final_test_loss,▁
learning_rate,██████▁▁▁▁
percent_trainable,▁
test_accuracy,▁
test_loss,▁
total_parameters,▁
train_acc,▁▇▇▇██████
train_loss,█▄▂▂▂▁▁▁▁▁

0,1
epoch,10.0
final_test_accuracy,0.725
final_test_loss,1.46532
learning_rate,0.0001
percent_trainable,0.06345
test_accuracy,0.725
test_loss,1.46532
total_parameters,20190298.0
train_acc,0.68959
train_loss,0.98454


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 206MB/s] 


Strategy: all_except_last - Trainable parameters: 7,690 (0.01%)
Training model: vit_b_16 with strategy: all_except_last
Epoch 1/10
----------
Train Loss: 1.1922 Acc: 0.6975
Val Loss: 0.7963 Acc: 0.7920
Saved new best model with accuracy: 0.7920
Epoch 2/10
----------
Train Loss: 0.7180 Acc: 0.8102
Val Loss: 0.6555 Acc: 0.8225
Saved new best model with accuracy: 0.8225
Epoch 3/10
----------
Train Loss: 0.6311 Acc: 0.8257
Val Loss: 0.6082 Acc: 0.8260
Saved new best model with accuracy: 0.8260
Epoch 4/10
----------
Train Loss: 0.5872 Acc: 0.8346
Val Loss: 0.5994 Acc: 0.8295
Saved new best model with accuracy: 0.8295
Epoch 5/10
----------
Train Loss: 0.5537 Acc: 0.8391
Val Loss: 0.5657 Acc: 0.8355
Saved new best model with accuracy: 0.8355
Epoch 6/10
----------
Train Loss: 0.5320 Acc: 0.8479
Val Loss: 0.5743 Acc: 0.8250
Epoch 7/10
----------
Train Loss: 0.5160 Acc: 0.8469
Val Loss: 0.5656 Acc: 0.8320
Epoch 8/10
----------
Train Loss: 0.4992 Acc: 0.8516
Val Loss: 0.5587 Acc: 0.8295
Epoch 9/1

0,1
epoch,▁▂▃▃▄▅▆▆▇█
final_test_accuracy,▁
final_test_loss,▁
learning_rate,██████▁▁▁▁
percent_trainable,▁
test_accuracy,▁
test_loss,▁
total_parameters,▁
train_acc,▁▆▇▇▇█████
train_loss,█▃▂▂▂▁▁▁▁▁

0,1
epoch,10.0
final_test_accuracy,0.843
final_test_loss,0.55227
learning_rate,0.0001
percent_trainable,0.00896
test_accuracy,0.843
test_loss,0.55227
total_parameters,85806346.0
train_acc,0.85236
train_loss,0.50242


All experiments completed!

Test Accuracies:
- vgg16: 0.7690
- efficientnet_v2_s: 0.7250
- vit_b_16: 0.8430
