In [3]:
## Installing WandB
!pip install wandb -qqq

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torchvision.transforms import ToTensor

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", device)

from tqdm import tqdm

Device:  cuda


In [1]:
import wandb, os
os.environ['WANDB_API_KEY'] = "5203e53880ceb7b6d2c0a93809e14ae43261f2ed" #your key here
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mcs24m016[0m ([33mcs24m016-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
!pip install lightning

# Question 1

In [None]:
import torch
import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(
        self,
        input_shape=(3, 224, 224),
        conv_filters=[32, 64, 128, 256, 512],
        filter_sizes=[3, 3, 3, 3, 3],
        activation_fn=nn.ReLU,
        dense_units=256,
        dense_activation_fn=nn.ReLU,
        dropout_rate=0.3,
        batch_norm=True,
        num_classes=10
    ):
        super(ConvNet, self).__init__()

        self.conv_blocks = nn.Sequential()
        in_channels = input_shape[0]
        h, w = input_shape[1], input_shape[2]

        # Add 5 Conv-BN-Activation-Pool blocks
        for i in range(5):
            out_channels = conv_filters[i]
            kernel_size = filter_sizes[i]
            padding = kernel_size // 2  # keep same spatial size before pooling

            self.conv_blocks.add_module(f"conv{i+1}", nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding))
            if batch_norm:
                self.conv_blocks.add_module(f"bn{i+1}", nn.BatchNorm2d(out_channels))
            self.conv_blocks.add_module(f"act{i+1}", activation_fn())
            self.conv_blocks.add_module(f"pool{i+1}", nn.MaxPool2d(2))
            if dropout_rate > 0:
                self.conv_blocks.add_module(f"dropout{i+1}", nn.Dropout2d(dropout_rate))

            in_channels = out_channels
            h, w = h // 2, w // 2  # due to MaxPool2d(2)

        # Compute the flattened size after conv blocks
        self.flattened_size = in_channels * h * w

        self.fc1 = nn.Linear(self.flattened_size, dense_units)
        self.fc1_act = dense_activation_fn()
        self.dropout = nn.Dropout(dropout_rate)

        self.output_layer = nn.Linear(dense_units, num_classes)

    def forward(self, x):
        x = self.conv_blocks(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(self.fc1_act(self.fc1(x)))
        return self.output_layer(x)


# Question 2

In [None]:
!pip install wandb


In [None]:
import os
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split, Subset
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np

def get_dataloaders(data_dir, batch_size=64, val_split=0.2, augment=True):
    # Transforms
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor()
    ]) if augment else transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    test_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    full_dataset = ImageFolder(root=data_dir, transform=train_transforms)

    # Stratified split
    targets = np.array(full_dataset.targets)
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=val_split, random_state=42)
    train_idx, val_idx = next(splitter.split(np.zeros(len(targets)), targets))

    train_set = Subset(full_dataset, train_idx)
    val_set = Subset(ImageFolder(root=data_dir, transform=test_transforms), val_idx)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, val_loader, len(full_dataset.classes)


In [None]:
import torch.nn.functional as F
import wandb

def train(model, train_loader, val_loader, optimizer, criterion, device, epochs=10):
    model.to(device)

    for epoch in range(epochs):
        model.train()
        total_loss, correct = 0, 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        train_accuracy = correct / len(train_loader.dataset)

        # Validation
        model.eval()
        val_correct, val_loss = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                val_correct += (outputs.argmax(dim=1) == labels).sum().item()

        val_accuracy = val_correct / len(val_loader.dataset)

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": total_loss / len(train_loader),
            "train_accuracy": train_accuracy,
            "val_loss": val_loss / len(val_loader),
            "val_accuracy": val_accuracy
        })

        print(f"Epoch {epoch+1} - Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}")


In [None]:
from torchvision import models
from torch import optim
import torch.nn as nn
import wandb


def main():
    wandb.init(project="DL_A2")

    config = wandb.config

    activation_map = {
        "ReLU": nn.ReLU,
        "GELU": nn.GELU,
        "SiLU": nn.SiLU,
        "Mish": nn.Mish
    }

    model = ConvNet(
        input_shape=(3, 224, 224),
        conv_filters=config.conv_filters,
        filter_sizes=config.filter_sizes,
        activation_fn=activation_map[config.activation_fn],
        dense_units=config.dense_units,
        dense_activation_fn=activation_map[config.activation_fn],
        dropout_rate=config.dropout,
        batch_norm=config.batch_norm,
        num_classes=10
    )

    train_loader, val_loader, _ = get_dataloaders(
        data_dir="/kaggle/input/nature-12k/inaturalist_12K/train",
        batch_size=config.batch_size,
        augment=config.augment
    )

    optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=0.9, weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss()

    train(model, train_loader, val_loader, optimizer, criterion, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), epochs=config.epochs)



In [None]:
sweep_config = {
    "method": "random",
    "metric": {
        "name": "val_accuracy",
        "goal": "maximize"
    },
    "parameters": {
        "conv_filters": {
            "values": [[32, 32, 64, 64, 128], [32, 64, 128, 256, 512]]
        },
        "filter_sizes": {
            "values": [[3, 3, 3, 3, 3]]
        },
        "activation_fn": {
            "values": ["ReLU", "GELU", "SiLU", "Mish"]
        },
        "dropout": {
            "values": [0.2, 0.3]
        },
        "dense_units": {
            "values": [128, 256]
        },
        "batch_norm": {
            "values": [True, False]
        },
        "augment": {
            "values": [True, False]
        },
        "batch_size": {
            "values": [64, 128]
        },
        "lr": {
            "values": [0.01, 0.001]
        },
        "epochs": {
            "value": 10
        }
    }
}


In [None]:
sweep_id = wandb.sweep(sweep=sweep_config, project='DL_A2')
wandb.agent(sweep_id, function=main, count=5)


In [None]:
sweep_id = wandb.sweep(sweep=sweep_config, project='DL_A2')
wandb.agent(sweep_id, function=main, count=50)


gpu mem clear

In [None]:
!pip install GPUtil

from GPUtil import showUtilization as gpu_usage
gpu_usage()                             


In [None]:
import torch
torch.cuda.empty_cache()


In [None]:
!pip install GPUtil

import torch
from GPUtil import showUtilization as gpu_usage
from numba import cuda

def free_gpu_cache():
    print("Initial GPU Usage")
    gpu_usage()                             

    torch.cuda.empty_cache()

    cuda.select_device(0)
    cuda.close()
    cuda.select_device(0)

    print("GPU Usage after emptying the cache")
    gpu_usage()

free_gpu_cache()                           


# question 3 

In [None]:

sweep_config = {
    "method": "random",
    "metric": {
        "name": "val_accuracy",
        "goal": "maximize"
    },
    "parameters": {
        "conv_filters": {
            "values": [[32, 32, 64, 64, 128],[512,256,128,64,32],[256,128,64,64,32], [32, 64, 128, 256, 512]]
        },
        "filter_sizes": {
            "values": [[3, 3, 3, 3, 3],[5,5,5,5,5],[7,7,7,7,7],[7,7,5,5,3],[7,5,3,3,3]]
        },
        "activation_fn": {
            "values": ["ReLU", "GELU", "SiLU", "Mish"]
        },
        "dropout": {
            "values": [0.0,0.2, 0.3]
        },
        "dense_units": {
            "values": [128, 256]
        },
        "batch_norm": {
            "values": [True]
        },
        "augment": {
            "values": [True, False]
        },
        "batch_size": {
            "values": [64, 128,256]
        },
        "lr": {
            "values": [0.01, 0.001]
        },
        "epochs": {
            "value": 10
        }
    }
}



In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
import wandb


def get_dataloaders(data_dir, batch_size=256, val_split=0.2, augment=True):
    # Enhanced data augmentation
    train_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]) if augment else transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    full_dataset = datasets.ImageFolder(root=data_dir, transform=train_transforms)
    
    # Stratified split
    targets = np.array(full_dataset.targets)
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=val_split, random_state=42)
    train_idx, val_idx = next(splitter.split(np.zeros(len(targets)), targets))
    
    train_set = Subset(full_dataset, train_idx)
    val_set = Subset(datasets.ImageFolder(root=data_dir, transform=val_transforms), val_idx)
    
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, 
                             num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, 
                           num_workers=4, pin_memory=True)
    
    return train_loader, val_loader, full_dataset.classes

class OptimizedCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(OptimizedCNN, self).__init__()
        
        # Larger filters in early layers, smaller in later layers
        self.conv_blocks = nn.Sequential(
            # Block 1: 64 filters, 7x7 kernel
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            
            # Block 2: 128 filters, 5x5 kernel
            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            
            # Block 3: 256 filters, 3x3 kernel
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            
            # Block 4: 512 filters, 3x3 kernel
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            
            # Block 5: 512 filters, 3x3 kernel
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        # Classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_classes))
        
        self._initialize_weights()
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, 0, 0.01)
                init.constant_(m.bias, 0)
    
    def forward(self, x):
        x = self.conv_blocks(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

def train(model, train_loader, val_loader, optimizer, criterion, scheduler, device, epochs=20):
    model.to(device)
    best_val_acc = 0.0
    
    for epoch in range(epochs):
        model.train()
        train_loss, train_correct = 0.0, 0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_correct += predicted.eq(labels).sum().item()
        
        train_acc = 100 * train_correct / len(train_loader.dataset)
        
        # Validation
        model.eval()
        val_loss, val_correct = 0.0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                val_correct += outputs.argmax(1).eq(labels).sum().item()
        
        val_acc = 100 * val_correct / len(val_loader.dataset)
        
        # Step the scheduler
        scheduler.step(val_loss)
        
        # Log metrics
        wandb.log({            
            "epoch": epoch + 1,
            "train_loss": train_loss / len(train_loader),
            "train_accuracy": train_acc,
            "val_loss": val_loss / len(val_loader),
            "val_accuracy": val_acc,
            "lr": optimizer.param_groups[0]['lr']
        })
        
        print(f"Epoch {epoch+1}/{epochs} - "
              f"Train Loss: {train_loss/len(train_loader):.4f}, "
              f"Train Acc: {train_acc:.2f}%, "
              f"Val Loss: {val_loss/len(val_loader):.4f}, "
              f"Val Acc: {val_acc:.2f}%")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
    
    return best_val_acc

def main():
    wandb.init(project="DL_A2")
    
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Get data loaders
    train_loader, val_loader, classes = get_dataloaders(
        data_dir="/kaggle/input/nature-12k/inaturalist_12K/train",
        batch_size=256,
        augment=True
    )
    
    # Initialize model
    model = OptimizedCNN(num_classes=len(classes))
    
    # Loss function
    criterion = nn.CrossEntropyLoss()
    
    # Optimizer with momentum and weight decay
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5)
    
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=3, verbose=True)
    
    # Train the model
    best_val_acc = train(
        model, train_loader, val_loader, 
        optimizer, criterion, scheduler,
        device=device, epochs=20
    )
    
    wandb.summary["best_val_acc"] = best_val_acc
    wandb.finish()



In [None]:
if __name__ == "__main__":
    main()

In [None]:
sweep_id = wandb.sweep(sweep=sweep_config, project='DL_A2')
wandb.agent(sweep_id, function=main, count=20)


# Question 4

# Best Model 

In [4]:
import os
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
import wandb



sweep_config = {
    "method": "random",
    "metric": {
        "name": "val_accuracy",
        "goal": "maximize"
    },
    "parameters": {
        "conv_filters": {
            "values": [[256,128,64,64,32], [32, 64, 128, 256, 512]]
        },
        "filter_sizes": {
            "values": [[3, 3, 3, 3, 3],[7,7,7,7,7]]
        },
        "activation_fn": {
            "values": [ "GELU"]
        },
        "dropout": {
            "values": [0.2, 0.3]
        },
        "dense_units": {
            "values": [128]
        },
        "batch_norm": {
            "values": [True]
        },
        "augment": {
            "values": [False]
        },
        "batch_size": {
            "values": [64,256]
        },
        "lr": {
            "values": [0.01, 0.001]
        },
        "epochs": {
            "value": 30
        }
    }
}



def get_dataloaders(data_dir, batch_size=256, val_split=0.2, augment=True):
    # Enhanced data augmentation
    train_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]) if augment else transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    full_dataset = datasets.ImageFolder(root=data_dir, transform=train_transforms)
    
    # Stratified split
    targets = np.array(full_dataset.targets)
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=val_split, random_state=42)
    train_idx, val_idx = next(splitter.split(np.zeros(len(targets)), targets))
    
    train_set = Subset(full_dataset, train_idx)
    val_set = Subset(datasets.ImageFolder(root=data_dir, transform=val_transforms), val_idx)
    
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, 
                             num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, 
                           num_workers=4, pin_memory=True)
    
    return train_loader, val_loader, full_dataset.classes

class OptimizedCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(OptimizedCNN, self).__init__()
        
        # Larger filters in early layers, smaller in later layers
        self.conv_blocks = nn.Sequential(
            # Block 1: 64 filters, 7x7 kernel
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            
            # Block 2: 128 filters, 5x5 kernel
            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            
            # Block 3: 256 filters, 3x3 kernel
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            
            # Block 4: 512 filters, 3x3 kernel
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            
            # Block 5: 512 filters, 3x3 kernel
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        
        # Classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_classes))
        
        self._initialize_weights()
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, 0, 0.01)
                init.constant_(m.bias, 0)
    
    def forward(self, x):
        x = self.conv_blocks(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

def train(model, train_loader, val_loader, optimizer, criterion, scheduler, device, epochs=20):
    model.to(device)
    best_val_acc = 0.0
    
    for epoch in range(epochs):
        model.train()
        train_loss, train_correct = 0.0, 0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_correct += predicted.eq(labels).sum().item()
        
        train_acc = 100 * train_correct / len(train_loader.dataset)
        
        # Validation
        model.eval()
        val_loss, val_correct = 0.0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                val_correct += outputs.argmax(1).eq(labels).sum().item()
        
        val_acc = 100 * val_correct / len(val_loader.dataset)
        
        # Step the scheduler
        scheduler.step(val_loss)
        
        # Log metrics
        wandb.log({            
            "epoch": epoch + 1,
            "train_loss": train_loss / len(train_loader),
            "train_accuracy": train_acc,
            "val_loss": val_loss / len(val_loader),
            "val_accuracy": val_acc,
            "lr": optimizer.param_groups[0]['lr']
        })
        
        print(f"Epoch {epoch+1}/{epochs} - "
              f"Train Loss: {train_loss/len(train_loader):.4f}, "
              f"Train Acc: {train_acc:.2f}%, "
              f"Val Loss: {val_loss/len(val_loader):.4f}, "
              f"Val Acc: {val_acc:.2f}%")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
    
    return best_val_acc

def main():
    wandb.init(project="DL_A2")
    
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Get data loaders
    train_loader, val_loader, classes = get_dataloaders(
        data_dir="/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/train",
        batch_size=256,
        augment=True
    )
    
    # Initialize model
    model = OptimizedCNN(num_classes=len(classes))
    
    # Loss function
    criterion = nn.CrossEntropyLoss()
    
    # Optimizer with momentum and weight decay
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5)
    
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=3, verbose=True)
    
    # Train the model
    best_val_acc = train(
        model, train_loader, val_loader, 
        optimizer, criterion, scheduler,
        device=device, epochs=30
    )
    
    wandb.summary["best_val_acc"] = best_val_acc
    wandb.finish()



In [5]:
sweep_id = wandb.sweep(sweep=sweep_config, project='DL_A2')
wandb.agent(sweep_id, function=main, count=20)


Create sweep with ID: ckiu7skp
Sweep URL: https://wandb.ai/cs24m016-indian-institute-of-technology-madras/DL_A2/sweeps/ckiu7skp


[34m[1mwandb[0m: Agent Starting Run: lknf7ukr with config:
[34m[1mwandb[0m: 	activation_fn: GELU
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_norm: True
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	conv_filters: [256, 128, 64, 64, 32]
[34m[1mwandb[0m: 	dense_units: 128
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	epochs: 30
[34m[1mwandb[0m: 	filter_sizes: [7, 7, 7, 7, 7]
[34m[1mwandb[0m: 	lr: 0.01


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.




Epoch 1/30 - Train Loss: 2.2102, Train Acc: 18.24%, Val Loss: 2.3575, Val Acc: 17.00%
Epoch 2/30 - Train Loss: 2.1465, Train Acc: 21.89%, Val Loss: 2.2811, Val Acc: 22.10%
Epoch 3/30 - Train Loss: 2.1058, Train Acc: 23.70%, Val Loss: 2.0659, Val Acc: 26.55%
Epoch 4/30 - Train Loss: 2.0861, Train Acc: 24.87%, Val Loss: 2.0728, Val Acc: 25.90%
Epoch 5/30 - Train Loss: 2.0563, Train Acc: 25.00%, Val Loss: 2.1041, Val Acc: 23.00%
Epoch 6/30 - Train Loss: 2.0689, Train Acc: 25.57%, Val Loss: 2.0626, Val Acc: 24.60%
Epoch 7/30 - Train Loss: 2.0330, Train Acc: 26.59%, Val Loss: 2.1449, Val Acc: 23.90%
Epoch 8/30 - Train Loss: 2.0199, Train Acc: 26.84%, Val Loss: 2.0000, Val Acc: 28.80%
Epoch 9/30 - Train Loss: 2.0084, Train Acc: 26.98%, Val Loss: 1.9952, Val Acc: 29.20%
Epoch 10/30 - Train Loss: 2.0070, Train Acc: 27.67%, Val Loss: 2.1523, Val Acc: 25.50%
Epoch 11/30 - Train Loss: 2.0068, Train Acc: 28.18%, Val Loss: 1.9711, Val Acc: 30.60%
Epoch 12/30 - Train Loss: 1.9800, Train Acc: 29.29%,

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
lr,█████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▇▇▇████████
train_loss,█▇▇▆▆▆▆▅▅▅▅▅▅▄▅▄▄▄▃▂▂▂▂▂▁▁▂▁▁▁
val_accuracy,▁▂▄▄▃▃▃▄▄▃▅▅▄▅▅▅▄▆▆▇██▇███████
val_loss,█▇▅▅▅▅▆▄▄▆▄▄▅▄▄▄▇▄▂▂▂▁▂▁▁▁▁▁▁▁

0,1
best_val_acc,41.75
epoch,30.0
lr,0.01
train_accuracy,39.90499
train_loss,1.68725
val_accuracy,41.75
val_loss,1.66899


[34m[1mwandb[0m: Agent Starting Run: 5lf2veqv with config:
[34m[1mwandb[0m: 	activation_fn: GELU
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_norm: True
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	conv_filters: [256, 128, 64, 64, 32]
[34m[1mwandb[0m: 	dense_units: 128
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	epochs: 30
[34m[1mwandb[0m: 	filter_sizes: [3, 3, 3, 3, 3]
[34m[1mwandb[0m: 	lr: 0.001




Epoch 1/30 - Train Loss: 2.2052, Train Acc: 18.26%, Val Loss: 2.2421, Val Acc: 20.55%
Epoch 2/30 - Train Loss: 2.1291, Train Acc: 22.17%, Val Loss: 2.1544, Val Acc: 22.30%
Epoch 3/30 - Train Loss: 2.1044, Train Acc: 23.67%, Val Loss: 2.2156, Val Acc: 20.50%
Epoch 4/30 - Train Loss: 2.0962, Train Acc: 23.85%, Val Loss: 2.0497, Val Acc: 27.15%
Epoch 5/30 - Train Loss: 2.0717, Train Acc: 25.15%, Val Loss: 2.0553, Val Acc: 25.00%
Epoch 6/30 - Train Loss: 2.0498, Train Acc: 25.50%, Val Loss: 2.1130, Val Acc: 26.05%
Epoch 7/30 - Train Loss: 2.0291, Train Acc: 26.79%, Val Loss: 2.0437, Val Acc: 26.05%
Epoch 8/30 - Train Loss: 2.0239, Train Acc: 27.22%, Val Loss: 2.0706, Val Acc: 26.75%
Epoch 9/30 - Train Loss: 2.0424, Train Acc: 26.32%, Val Loss: 2.0312, Val Acc: 26.40%
Epoch 10/30 - Train Loss: 1.9883, Train Acc: 27.90%, Val Loss: 2.0185, Val Acc: 29.45%
Epoch 11/30 - Train Loss: 1.9875, Train Acc: 28.47%, Val Loss: 1.9602, Val Acc: 29.70%
Epoch 12/30 - Train Loss: 1.9915, Train Acc: 28.00%,

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
lr,██████████████████▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▇▇▇▇██████
train_loss,█▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▂▂▂▂▂▁▁▁▁▁
val_accuracy,▁▂▁▃▂▃▃▃▃▄▄▄▄▅▅▃▅▆▄▇▇▇▇▇▇▇▇███
val_loss,█▇▇▅▆▆▅▆▅▅▄▄▅▄▃█▄▄▅▂▂▂▂▂▂▁▁▁▁▁

0,1
best_val_acc,42.1
epoch,30.0
lr,0.01
train_accuracy,39.9675
train_loss,1.69726
val_accuracy,40.9
val_loss,1.66619


[34m[1mwandb[0m: Agent Starting Run: 28shrd5m with config:
[34m[1mwandb[0m: 	activation_fn: GELU
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_norm: True
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	conv_filters: [32, 64, 128, 256, 512]
[34m[1mwandb[0m: 	dense_units: 128
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	epochs: 30
[34m[1mwandb[0m: 	filter_sizes: [3, 3, 3, 3, 3]
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1/30 - Train Loss: 2.1960, Train Acc: 18.53%, Val Loss: 2.2380, Val Acc: 19.10%
Epoch 2/30 - Train Loss: 2.1396, Train Acc: 22.02%, Val Loss: 2.1545, Val Acc: 21.85%
Epoch 3/30 - Train Loss: 2.0930, Train Acc: 23.58%, Val Loss: 2.0760, Val Acc: 25.10%
Epoch 4/30 - Train Loss: 2.0963, Train Acc: 24.25%, Val Loss: 2.0846, Val Acc: 24.05%
Epoch 5/30 - Train Loss: 2.0774, Train Acc: 24.58%, Val Loss: 2.0899, Val Acc: 26.75%
Epoch 6/30 - Train Loss: 2.0560, Train Acc: 25.93%, Val Loss: 2.0589, Val Acc: 26.95%
Epoch 7/30 - Train Loss: 2.0551, Train Acc: 25.77%, Val Loss: 2.0797, Val Acc: 26.05%
Epoch 8/30 - Train Loss: 2.0352, Train Acc: 26.69%, Val Loss: 2.0163, Val Acc: 28.85%
Epoch 9/30 - Train Loss: 2.0245, Train Acc: 26.98%, Val Loss: 1.9758, Val Acc: 30.15%
Epoch 10/30 - Train Loss: 2.0054, Train Acc: 28.44%, Val Loss: 2.0389, Val Acc: 27.60%
Epoch 11/30 - Train Loss: 1.9942, Train Acc: 27.83%, Val Loss: 1.9721, Val Acc: 31.35%
Epoch 12/30 - Train Loss: 1.9766, Train Acc: 29.02%,

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
lr,████████████████▂▂▂▂▂▂▂▂▂▂▂▂▁▁
train_accuracy,▁▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▇▇▇▇▇▇▇█████
train_loss,█▇▇▇▆▆▆▆▆▅▅▅▅▅▅▅▄▃▂▂▂▂▂▂▂▁▁▁▁▁
val_accuracy,▁▂▃▂▃▃▃▄▄▄▅▅▅▅▄▅▅▇▇▇▇▇▇▇▇█▇▇██
val_loss,█▇▆▆▆▆▆▅▅▆▅▄▄▄▆▄▅▂▂▂▂▂▂▂▂▂▂▂▂▁

0,1
best_val_acc,42.35
epoch,30.0
lr,0.001
train_accuracy,40.29254
train_loss,1.6806
val_accuracy,42.35
val_loss,1.6504


[34m[1mwandb[0m: Agent Starting Run: 0equy2h8 with config:
[34m[1mwandb[0m: 	activation_fn: GELU
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_norm: True
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	conv_filters: [32, 64, 128, 256, 512]
[34m[1mwandb[0m: 	dense_units: 128
[34m[1mwandb[0m: 	dropout: 0.2
[34m[1mwandb[0m: 	epochs: 30
[34m[1mwandb[0m: 	filter_sizes: [7, 7, 7, 7, 7]
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1/30 - Train Loss: 2.1942, Train Acc: 18.10%, Val Loss: 2.2092, Val Acc: 19.35%
Epoch 2/30 - Train Loss: 2.1247, Train Acc: 22.87%, Val Loss: 2.2178, Val Acc: 23.30%
Epoch 3/30 - Train Loss: 2.0943, Train Acc: 24.03%, Val Loss: 2.0769, Val Acc: 24.55%
Epoch 4/30 - Train Loss: 2.0754, Train Acc: 25.30%, Val Loss: 2.0337, Val Acc: 26.75%
Epoch 5/30 - Train Loss: 2.0541, Train Acc: 25.73%, Val Loss: 2.1464, Val Acc: 23.60%
Epoch 6/30 - Train Loss: 2.0448, Train Acc: 25.84%, Val Loss: 2.0318, Val Acc: 28.30%
Epoch 7/30 - Train Loss: 2.0367, Train Acc: 27.17%, Val Loss: 2.0229, Val Acc: 27.05%
Epoch 8/30 - Train Loss: 2.0085, Train Acc: 27.58%, Val Loss: 2.2619, Val Acc: 25.00%
Epoch 9/30 - Train Loss: 1.9899, Train Acc: 28.47%, Val Loss: 2.0894, Val Acc: 28.65%
Epoch 10/30 - Train Loss: 1.9758, Train Acc: 29.09%, Val Loss: 1.9670, Val Acc: 31.25%
Epoch 11/30 - Train Loss: 1.9655, Train Acc: 29.63%, Val Loss: 1.9811, Val Acc: 31.10%
Epoch 12/30 - Train Loss: 1.9518, Train Acc: 30.28%,

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
lr,███████████████████▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▅▇▇▇███████
train_loss,█▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▄▂▂▂▁▁▁▁▁▁▁
val_accuracy,▁▂▃▃▂▄▃▃▄▄▄▄▅▅▅▅▅▅▅▅▇▇▇▇██████
val_loss,▇▇▆▅▇▅▅█▆▅▅▅▅▄▄▄▄▄▅▅▂▂▂▁▁▁▁▁▂▁

0,1
best_val_acc,43.45
epoch,30.0
lr,0.01
train_accuracy,40.25503
train_loss,1.67766
val_accuracy,42.45
val_loss,1.68988


[34m[1mwandb[0m: Agent Starting Run: 4h9zdgyq with config:
[34m[1mwandb[0m: 	activation_fn: GELU
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_norm: True
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	conv_filters: [256, 128, 64, 64, 32]
[34m[1mwandb[0m: 	dense_units: 128
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	epochs: 30
[34m[1mwandb[0m: 	filter_sizes: [3, 3, 3, 3, 3]
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1/30 - Train Loss: 2.2035, Train Acc: 18.18%, Val Loss: 2.1724, Val Acc: 22.20%
Epoch 2/30 - Train Loss: 2.1374, Train Acc: 21.53%, Val Loss: 2.0869, Val Acc: 23.40%
Epoch 3/30 - Train Loss: 2.0977, Train Acc: 23.50%, Val Loss: 2.0843, Val Acc: 24.95%
Epoch 4/30 - Train Loss: 2.0892, Train Acc: 23.83%, Val Loss: 2.0459, Val Acc: 25.40%
Epoch 5/30 - Train Loss: 2.0618, Train Acc: 25.77%, Val Loss: 2.0605, Val Acc: 25.45%
Epoch 6/30 - Train Loss: 2.0402, Train Acc: 26.43%, Val Loss: 2.0509, Val Acc: 26.30%
Epoch 7/30 - Train Loss: 2.0423, Train Acc: 27.20%, Val Loss: 1.9934, Val Acc: 28.00%
Epoch 8/30 - Train Loss: 2.0058, Train Acc: 28.22%, Val Loss: 2.1463, Val Acc: 25.60%
Epoch 9/30 - Train Loss: 2.0067, Train Acc: 27.74%, Val Loss: 1.9837, Val Acc: 30.00%
Epoch 10/30 - Train Loss: 1.9863, Train Acc: 28.78%, Val Loss: 2.1871, Val Acc: 25.45%
Epoch 11/30 - Train Loss: 1.9884, Train Acc: 28.67%, Val Loss: 2.0306, Val Acc: 26.65%
Epoch 12/30 - Train Loss: 1.9931, Train Acc: 27.84%,

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
lr,████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▂▃▃▄▄▄▄▄▄▄▄▅▆▇▇▇▇▇▇▇▇▇▇▇█████
train_loss,█▇▆▆▆▆▆▅▅▅▅▅▄▃▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁
val_accuracy,▁▁▂▂▂▃▃▂▄▂▃▃▄▆▇▇▇▇▇▇▆▇█▇▇██▇█▆
val_loss,█▇▇▆▆▆▅▇▅█▆▆▆▃▃▂▂▂▂▂▃▂▂▂▂▁▂▂▁▄

0,1
best_val_acc,40.95
epoch,30.0
lr,0.01
train_accuracy,39.39242
train_loss,1.72496
val_accuracy,36.8
val_loss,1.87376


[34m[1mwandb[0m: Agent Starting Run: zvvvpj24 with config:
[34m[1mwandb[0m: 	activation_fn: GELU
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_norm: True
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	conv_filters: [32, 64, 128, 256, 512]
[34m[1mwandb[0m: 	dense_units: 128
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	epochs: 30
[34m[1mwandb[0m: 	filter_sizes: [7, 7, 7, 7, 7]
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1/30 - Train Loss: 2.1991, Train Acc: 17.76%, Val Loss: 2.2476, Val Acc: 17.25%
Epoch 2/30 - Train Loss: 2.1265, Train Acc: 22.19%, Val Loss: 2.2045, Val Acc: 23.10%
Epoch 3/30 - Train Loss: 2.0889, Train Acc: 23.97%, Val Loss: 2.0915, Val Acc: 26.45%
Epoch 4/30 - Train Loss: 2.0612, Train Acc: 24.93%, Val Loss: 2.3098, Val Acc: 20.85%
Epoch 5/30 - Train Loss: 2.0449, Train Acc: 26.57%, Val Loss: 2.0976, Val Acc: 23.95%
Epoch 6/30 - Train Loss: 2.0451, Train Acc: 26.50%, Val Loss: 2.1063, Val Acc: 25.80%
Epoch 7/30 - Train Loss: 2.0252, Train Acc: 27.14%, Val Loss: 2.0415, Val Acc: 25.80%
Epoch 8/30 - Train Loss: 1.9970, Train Acc: 28.38%, Val Loss: 1.9918, Val Acc: 28.50%
Epoch 9/30 - Train Loss: 1.9847, Train Acc: 28.69%, Val Loss: 1.9931, Val Acc: 29.45%
Epoch 10/30 - Train Loss: 1.9548, Train Acc: 29.98%, Val Loss: 2.0423, Val Acc: 30.65%
Epoch 11/30 - Train Loss: 1.9607, Train Acc: 30.00%, Val Loss: 1.9788, Val Acc: 31.40%
Epoch 12/30 - Train Loss: 1.9523, Train Acc: 29.93%,

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
lr,██████████████████████████▁▁▁▁
train_accuracy,▁▂▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▆▅▅▆▆▆▆▆▆▆▇██
train_loss,█▇▇▆▆▆▆▅▅▅▅▅▅▅▄▅▄▄▄▄▄▃▃▃▃▃▃▂▁▁
val_accuracy,▁▃▃▂▃▃▃▄▄▅▅▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▅█▇█
val_loss,▇▇▆█▆▆▅▅▅▅▅▄▄▄▅▃▄▄▃▄▃▃▃▄▄▄▅▁▁▁

0,1
best_val_acc,43.5
epoch,30.0
lr,0.01
train_accuracy,41.94274
train_loss,1.64145
val_accuracy,43.5
val_loss,1.62069


[34m[1mwandb[0m: Agent Starting Run: 4p66no5p with config:
[34m[1mwandb[0m: 	activation_fn: GELU
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_norm: True
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	conv_filters: [32, 64, 128, 256, 512]
[34m[1mwandb[0m: 	dense_units: 128
[34m[1mwandb[0m: 	dropout: 0.3
[34m[1mwandb[0m: 	epochs: 30
[34m[1mwandb[0m: 	filter_sizes: [3, 3, 3, 3, 3]
[34m[1mwandb[0m: 	lr: 0.001




Epoch 1/30 - Train Loss: 2.2048, Train Acc: 18.15%, Val Loss: 2.3809, Val Acc: 18.35%
Epoch 2/30 - Train Loss: 2.1295, Train Acc: 22.78%, Val Loss: 2.1052, Val Acc: 23.90%


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


In [None]:
#1. Testing the Best Model and Creating Prediction Grid
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid

def test_model(model, test_loader, device, classes):
    model.eval()
    test_correct = 0
    all_preds = []
    all_images = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            
            test_correct += preds.eq(labels).sum().item()
            all_preds.extend(preds.cpu().numpy())
            all_images.extend(images.cpu())
            all_labels.extend(labels.cpu().numpy())
    
    test_acc = 100 * test_correct / len(test_loader.dataset)
    print(f'Test Accuracy: {test_acc:.2f}%')
    
    return test_acc, all_images, all_labels, all_preds

def create_prediction_grid(images, labels, preds, classes, n=10):
    # Create a figure with n rows and 3 columns
    fig, axes = plt.subplots(n, 3, figsize=(10, 3*n))
    
    # Get random indices for samples
    indices = np.random.choice(len(images), n, replace=False)
    
    for i, idx in enumerate(indices):
        # Original image
        img = images[idx].permute(1, 2, 0).numpy()
        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])  # Unnormalize
        img = np.clip(img, 0, 1)
        
        axes[i, 0].imshow(img)
        axes[i, 0].axis('off')
        if i == 0:
            axes[i, 0].set_title('Original Image')
        
        # True label
        true_label = classes[labels[idx]]
        axes[i, 1].text(0.5, 0.5, f'True: {true_label}', 
                       ha='center', va='center', fontsize=12)
        axes[i, 1].axis('off')
        if i == 0:
            axes[i, 1].set_title('True Label')
        
        # Predicted label (color red if wrong, green if correct)
        pred_label = classes[preds[idx]]
        color = 'red' if preds[idx] != labels[idx] else 'green'
        axes[i, 2].text(0.5, 0.5, f'Pred: {pred_label}', 
                        ha='center', va='center', fontsize=12, color=color)
        axes[i, 2].axis('off')
        if i == 0:
            axes[i, 2].set_title('Prediction')
    
    plt.tight_layout()
    plt.savefig('prediction_grid.png', bbox_inches='tight')
    plt.show()

# Load test data
test_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_dataset = datasets.ImageFolder(
    root="/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/val",  # Using val as test
    transform=test_transforms
)

test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4)

# Load best model
best_model = OptimizedCNN(num_classes=len(test_dataset.classes))
best_model.load_state_dict(torch.load('/kaggle/input/cnn/pytorch/default/1/best_model.pth'))
best_model.to(device)

# Test and create grid
test_acc, test_images, test_labels, test_preds = test_model(best_model, test_loader, device, test_dataset.classes)
create_prediction_grid(test_images, test_labels, test_preds, test_dataset.classes)

need to change the grid and visualize on wandb

In [None]:
# Visualizing First Layer Filters
def visualize_first_layer_filters(model, image, save_path='first_layer_filters.png'):
    # Get first conv layer weights
    first_conv = model.conv_blocks[0]
    filters = first_conv.weight.data.cpu().numpy()
    
    # Normalize filters to 0-1 for visualization
    f_min, f_max = filters.min(), filters.max()
    filters = (filters - f_min) / (f_max - f_min)
    
    # Plot filters in 8x8 grid
    fig, axes = plt.subplots(8, 8, figsize=(12, 12))
    
    for i in range(8):
        for j in range(8):
            idx = i * 8 + j
            if idx < filters.shape[0]:  # In case we have less than 64 filters
                filter_img = filters[idx].transpose(1, 2, 0)
                axes[i, j].imshow(filter_img)
                axes[i, j].axis('off')
            else:
                axes[i, j].axis('off')
    
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.show()

# Get a random test image
random_idx = np.random.randint(len(test_dataset))
image, _ = test_dataset[random_idx]
image = image.unsqueeze(0).to(device)

# Visualize filters
visualize_first_layer_filters(best_model, image)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
import wandb

def visualize_first_layer(model, test_dir):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load transformation
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    # Load test dataset and get random image
    test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)
    random_idx = np.random.randint(0, len(test_dataset))
    img, label = test_dataset[random_idx]
    img = img.unsqueeze(0).to(device)  # Add batch dimension
    
    # Get the first convolutional layer
    first_conv = model.conv_blocks[0]
    num_filters = first_conv.out_channels  # Dynamically get number of filters
    
    # Visualize filters
    filters = first_conv.weight.data.cpu().numpy()
    
    # Normalize filters to 0-1 for visualization
    f_min, f_max = filters.min(), filters.max()
    filters = (filters - f_min) / (f_max - f_min)
    
    # Calculate grid size (square as possible)
    grid_size = int(np.ceil(np.sqrt(num_filters)))
    
    # Plot filters
    plt.figure(figsize=(12, 12))
    for i in range(num_filters):
        plt.subplot(grid_size, grid_size, i+1)
        # Show first channel only (assuming RGB input)
        plt.imshow(filters[i, 0], cmap='gray')
        plt.axis('off')
    plt.suptitle(f'First Layer Filters ({num_filters} total)', fontsize=16)
    plt.tight_layout()
    filters_fig = plt.gcf()
    
    # Get feature maps
    model.eval()
    with torch.no_grad():
        feature_maps = first_conv(img)
    
    # Normalize feature maps
    fmaps = feature_maps.squeeze(0).cpu().numpy()
    fmap_min, fmap_max = fmaps.min(), fmaps.max()
    fmaps = (fmaps - fmap_min) / (fmap_max - fmap_min)
    
    # Plot feature maps
    plt.figure(figsize=(12, 12))
    for i in range(num_filters):
        plt.subplot(grid_size, grid_size, i+1)
        plt.imshow(fmaps[i], cmap='viridis')
        plt.axis('off')
    plt.suptitle(f'Feature Maps ({num_filters} total)', fontsize=16)
    plt.tight_layout()
    fmap_fig = plt.gcf()
    
    # Show original image (denormalized)
    img_denorm = img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    img_denorm = img_denorm * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img_denorm = np.clip(img_denorm, 0, 1)
    
    plt.figure(figsize=(8, 8))
    plt.imshow(img_denorm)
    plt.title(f'Original Test Image\nClass: {test_dataset.classes[label]}')
    plt.axis('off')
    orig_fig = plt.gcf()
    
    # Additional analysis
    # 1. Filter magnitude distribution
    filter_magnitudes = torch.norm(first_conv.weight.data, dim=(1, 2, 3)).cpu().numpy()
    
    plt.figure(figsize=(10, 5))
    plt.hist(filter_magnitudes, bins=20)
    plt.title('Filter Magnitude Distribution')
    plt.xlabel('Magnitude (L2 norm)')
    plt.ylabel('Count')
    magnitude_fig = plt.gcf()
    
    # 2. Activation statistics
    activation_means = feature_maps.mean(dim=(0, 2, 3)).cpu().numpy()
    activation_max = feature_maps.amax(dim=(0, 2, 3)).cpu().numpy()
    
    plt.figure(figsize=(10, 5))
    plt.bar(range(num_filters), activation_means, alpha=0.5, label='Mean')
    plt.bar(range(num_filters), activation_max, alpha=0.5, label='Max')
    plt.title('Feature Map Activation Statistics')
    plt.xlabel('Filter Index')
    plt.ylabel('Activation Value')
    plt.legend()
    activation_fig = plt.gcf()
    
    # Log to wandb
    wandb.init(project="DL_A2", name="filter_visualization")
    wandb.log({
        "original_image": wandb.Image(orig_fig),
        "first_layer_filters": wandb.Image(filters_fig),
        "feature_maps": wandb.Image(fmap_fig),
        "filter_magnitudes": wandb.Image(magnitude_fig),
        "activation_stats": wandb.Image(activation_fig),
        "selected_class": test_dataset.classes[label]
    })
    
    plt.close('all')
    return {
        "num_filters": num_filters,
        "filter_magnitudes": filter_magnitudes,
        "activation_means": activation_means,
        "activation_max": activation_max
    }

# Load your model
model = OptimizedCNN(num_classes=len(test_dataset.classes))  # Adjust as needed
model.load_state_dict(torch.load('/kaggle/input/cnn/pytorch/default/1/best_model.pth'))
model = model.to(device)

# Run visualization
results = visualize_first_layer(
    model, 
    test_dir='/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/val'
)

# best feature map 

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
import wandb

def visualize_first_layer(model, test_dir):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load transformation
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                           std=[0.229, 0.224, 0.225])
    ])
    
    # Load test dataset and get random image
    test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)
    random_idx = np.random.randint(0, len(test_dataset))
    img, label = test_dataset[random_idx]
    img = img.unsqueeze(0).to(device)  # Add batch dimension
    
    # Get the first convolutional layer
    first_conv = model.conv_blocks[0]
    num_filters = first_conv.out_channels  # Dynamically get number of filters
    
    # Visualize filters
    filters = first_conv.weight.data.cpu().numpy()
    
    # Normalize filters to 0-1 for visualization
    f_min, f_max = filters.min(), filters.max()
    filters = (filters - f_min) / (f_max - f_min)
    
    # Calculate grid size (square as possible)
    grid_size = int(np.ceil(np.sqrt(num_filters)))
    
    # Plot filters
    plt.figure(figsize=(12, 12))
    for i in range(num_filters):
        plt.subplot(grid_size, grid_size, i+1)
        # Show first channel only (assuming RGB input)
        plt.imshow(filters[i, 0], cmap='gray')
        plt.axis('off')
    plt.suptitle(f'First Layer Filters ({num_filters} total)', fontsize=16)
    plt.tight_layout()
    filters_fig = plt.gcf()
    
    # Get feature maps
    model.eval()
    with torch.no_grad():
        feature_maps = first_conv(img)
    
    # Normalize feature maps
    fmaps = feature_maps.squeeze(0).cpu().numpy()
    fmap_min, fmap_max = fmaps.min(), fmaps.max()
    fmaps = (fmaps - fmap_min) / (fmap_max - fmap_min)
    
    # Plot feature maps
    plt.figure(figsize=(12, 12))
    for i in range(num_filters):
        plt.subplot(grid_size, grid_size, i+1)
        plt.imshow(fmaps[i], cmap='viridis')
        plt.axis('off')
    plt.suptitle(f'Feature Maps ({num_filters} total)', fontsize=16)
    plt.tight_layout()
    fmap_fig = plt.gcf()
    
    # Show original image (denormalized)
    img_denorm = img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    img_denorm = img_denorm * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img_denorm = np.clip(img_denorm, 0, 1)
    
    plt.figure(figsize=(8, 8))
    plt.imshow(img_denorm)
    plt.title(f'Original Test Image\nClass: {test_dataset.classes[label]}')
    plt.axis('off')
    orig_fig = plt.gcf()
    
    # Additional analysis
    # 1. Filter magnitude distribution (fixed calculation)
    filter_magnitudes = torch.norm(first_conv.weight.data.view(num_filters, -1), p=2, dim=1).cpu().numpy()
    
    plt.figure(figsize=(10, 5))
    plt.hist(filter_magnitudes, bins=20)
    plt.title('Filter Magnitude Distribution')
    plt.xlabel('Magnitude (L2 norm)')
    plt.ylabel('Count')
    magnitude_fig = plt.gcf()
    
    # 2. Activation statistics
    activation_means = feature_maps.mean(dim=(0, 2, 3)).cpu().numpy()
    activation_max = feature_maps.amax(dim=(0, 2, 3)).cpu().numpy()
    
    plt.figure(figsize=(10, 5))
    plt.bar(range(num_filters), activation_means, alpha=0.5, label='Mean')
    plt.bar(range(num_filters), activation_max, alpha=0.5, label='Max')
    plt.title('Feature Map Activation Statistics')
    plt.xlabel('Filter Index')
    plt.ylabel('Activation Value')
    plt.legend()
    activation_fig = plt.gcf()
    
    # Log to wandb
    wandb.init(project="DL_A2", name="filter_visualization")
    wandb.log({
        "original_image": wandb.Image(orig_fig),
        "first_layer_filters": wandb.Image(filters_fig),
        "feature_maps": wandb.Image(fmap_fig),
        "filter_magnitudes": wandb.Image(magnitude_fig),
        "activation_stats": wandb.Image(activation_fig),
        "selected_class": test_dataset.classes[label]
    })
    
    plt.close('all')
    return {
        "num_filters": num_filters,
        "filter_magnitudes": filter_magnitudes,
        "activation_means": activation_means,
        "activation_max": activation_max
    }

# Load your model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_dataset = datasets.ImageFolder(root='/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/val', transform=transforms.ToTensor())
model = OptimizedCNN(num_classes=len(test_dataset.classes))
model.load_state_dict(torch.load('/kaggle/input/cnn/pytorch/default/1/best_model.pth', map_location=device))
model = model.to(device)

# Run visualization
results = visualize_first_layer(
    model, 
    test_dir='/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/val'
)

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from collections import defaultdict
import wandb
import random
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

def create_enhanced_grid(images, labels, preds, class_names, n_rows=10, n_cols=3, title="Predictions"):
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, n_rows*2.5))
    fig.suptitle(title, fontsize=16, y=1.02)
    
    for i in range(n_rows):
        for j in range(n_cols):
            idx = i * n_cols + j
            if idx >= len(images):
                break
                
            ax = axes[i,j]
            img = images[idx].numpy().transpose((1, 2, 0))
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            img = np.clip((img * std + mean), 0, 1)
            
            ax.imshow(img)
            ax.axis('off')
            
            true_label = class_names[labels[idx]]
            pred_label = class_names[preds[idx]]
            is_correct = preds[idx] == labels[idx]
            
            # More informative title with confidence if available
            title_color = 'green' if is_correct else 'red'
            title_text = f"True: {true_label}\nPred: {pred_label}"
            
            if is_correct:
                title_text += "\n Correct"
            else:
                title_text += "\n Wrong"
                
            ax.set_title(title_text, fontsize=9, color=title_color, pad=2)
    
    plt.tight_layout()
    return fig

def plot_confusion_matrix(y_true, y_pred, class_names):
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    return plt.gcf()

def evaluate_testset(model_path, test_dir, num_grids=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = datasets.ImageFolder(root=test_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=4)
    class_names = dataset.classes

    model = OptimizedCNN(num_classes=len(class_names))
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Collect all predictions
    all_images, all_labels, all_preds = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1).cpu()
            all_images.extend(images.cpu())
            all_labels.extend(labels)
            all_preds.extend(preds)

    accuracy = 100 * np.mean(np.array(all_preds) == np.array(all_labels))
    print(f"Test Accuracy: {accuracy:.2f}%")

    wandb.init(project="DL_A2", name="test_evaluation", job_type="eval")

    wandb.log({"test_accuracy": accuracy})

    # Class-wise accuracy with zero-division handling
    class_correct = defaultdict(int)
    class_total = defaultdict(int)
    
    for label, pred in zip(all_labels, all_preds):
        class_total[label] += 1
        if label == pred:
            class_correct[label] += 1
    
    # Calculate accuracy only for classes that have samples
    class_acc = {}
    for i in range(len(class_names)):
        if class_total[i] > 0:
            class_acc[class_names[i]] = 100 * class_correct[i] / class_total[i]
        else:
            class_acc[class_names[i]] = float('nan')  # Mark as NaN if no samples
    
    # Create wandb table
    wandb.log({"class_accuracy": wandb.Table(
        columns=["Class", "Accuracy", "Samples"],
        data=[[class_names[i], 
              class_acc[class_names[i]], 
              class_total[i]] 
             for i in range(len(class_names))]
    )})

    # Confusion matrix (only for classes with samples)
    present_classes = [i for i in range(len(class_names)) if class_total[i] > 0]
    present_labels = [l for l in all_labels if l in present_classes]
    present_preds = [p for i, p in enumerate(all_preds) if all_labels[i] in present_classes]
    
    if present_classes:
        cm = confusion_matrix(present_labels, present_preds, labels=present_classes)
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                   xticklabels=[class_names[i] for i in present_classes],
                   yticklabels=[class_names[i] for i in present_classes])
        plt.title('Confusion Matrix (for classes with samples)')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        wandb.log({"confusion_matrix": wandb.Image(plt)})
        plt.close()

    # Create prediction grids only for classes with samples
    present_indices = [i for i in range(len(all_labels)) if all_labels[i] in present_classes]
    if present_indices:
        for i in range(min(num_grids, 10)):  # Ensure we don't request more grids than possible
            sample_size = min(30, len(present_indices))  # 10x3 grid
            indices = random.sample(present_indices, sample_size)
            sample_imgs = [all_images[j] for j in indices]
            sample_labels = [all_labels[j] for j in indices]
            sample_preds = [all_preds[j] for j in indices]

            fig = create_enhanced_grid(
                sample_imgs, sample_labels, sample_preds, 
                class_names, n_rows=10, n_cols=3,
                title=f"Sample Predictions - Grid {i+1}"
            )
            wandb.log({f"prediction_grid_{i}": wandb.Image(fig)})
            plt.close(fig)

    wandb.finish()

if __name__ == "__main__":
    evaluate_testset(
        model_path='/kaggle/input/cnn/pytorch/default/1/best_model.pth',
        test_dir='/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/val',
        num_grids=10
    )

#   best visualization for test dataset

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from collections import defaultdict
import wandb
import random
from sklearn.metrics import confusion_matrix
import seaborn as sns

def create_prediction_grid(images, labels, preds, class_names, n_rows=10, n_cols=3, title="Predictions"):
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, n_rows*2))
    fig.suptitle(title, fontsize=16, y=1.02)
    
    for i in range(n_rows):
        for j in range(n_cols):
            idx = i * n_cols + j
            if idx >= len(images):
                break
                
            ax = axes[i,j]
            img = images[idx].numpy().transpose((1, 2, 0))
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            img = np.clip((img * std + mean), 0, 1)
            
            ax.imshow(img)
            ax.axis('off')
            
            true_label = class_names[labels[idx]]
            pred_label = class_names[preds[idx]]
            is_correct = preds[idx] == labels[idx]
            
            title_color = 'green' if is_correct else 'red'
            title_text = f"True: {true_label}\nPred: {pred_label}"
            ax.set_title(title_text, fontsize=9, color=title_color, pad=2)
    
    plt.tight_layout()
    return fig

def evaluate_testset(model_path, test_dir):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Same transforms as validation
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Load test dataset
    test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4)
    class_names = test_dataset.classes

    # Load model
    model = OptimizedCNN(num_classes=len(class_names))
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Collect predictions and ground truth
    all_images = []
    all_labels = []
    all_preds = []
    
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            
            all_images.extend(images.cpu())
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    # Calculate overall accuracy
    accuracy = 100 * np.sum(np.array(all_labels) == np.array(all_preds)) / len(all_labels)
    print(f"Test Accuracy: {accuracy:.2f}%")

    # Initialize wandb
    wandb.init(project="DL_A2", name="test_evaluation", job_type="eval")
    wandb.log({"test_accuracy": accuracy})

    # Calculate class-wise accuracy
    class_correct = defaultdict(int)
    class_total = defaultdict(int)
    
    for label, pred in zip(all_labels, all_preds):
        class_total[label] += 1
        if label == pred:
            class_correct[label] += 1
    
    # Create accuracy table
    accuracy_table = wandb.Table(columns=["Class", "Accuracy", "Samples"])
    for class_idx in range(len(class_names)):
        if class_total[class_idx] > 0:
            acc = 100 * class_correct[class_idx] / class_total[class_idx]
        else:
            acc = float('nan')
        accuracy_table.add_data(class_names[class_idx], acc, class_total[class_idx])
    
    wandb.log({"class_accuracy": accuracy_table})

    # Create confusion matrix (only for classes with samples)
    present_classes = [c for c in range(len(class_names)) if class_total[c] > 0]
    if present_classes:
        cm = confusion_matrix(all_labels, all_preds, labels=present_classes)
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                   xticklabels=[class_names[c] for c in present_classes],
                   yticklabels=[class_names[c] for c in present_classes])
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        wandb.log({"confusion_matrix": wandb.Image(plt)})
        plt.close()

    # Create 10x3 prediction grid
    num_samples = 30  # 10 rows x 3 columns
    if len(all_images) >= num_samples:
        indices = random.sample(range(len(all_images)), num_samples)
        sample_images = [all_images[i] for i in indices]
        sample_labels = [all_labels[i] for i in indices]
        sample_preds = [all_preds[i] for i in indices]
        
        grid_fig = create_prediction_grid(
            sample_images, sample_labels, sample_preds, 
            class_names, title="Test Set Predictions (Random Sample)"
        )
        wandb.log({"prediction_grid": wandb.Image(grid_fig)})
        plt.close(grid_fig)
    else:
        print(f"Not enough samples ({len(all_images)}) to create full 10x3 grid")

    wandb.finish()

if __name__ == "__main__":
    evaluate_testset(
        model_path='/kaggle/input/cnn/pytorch/default/1/best_model.pth',
        test_dir='/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/val'
    )

In [None]:
import os
import random
import torch
import numpy as np
import matplotlib.pyplot as plt
import wandb
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.nn as nn


class OptimizedCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(OptimizedCNN, self).__init__()

        self.conv_blocks = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),

            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),

            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1)
        )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.conv_blocks(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def load_test_data(test_dir="/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/val", batch_size=256):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    dataset = datasets.ImageFolder(root=test_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    return dataset, loader


def evaluate_best_model(best_model_path, test_loader, test_dataset, slider_index=0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = OptimizedCNN(num_classes=len(test_dataset.classes)).to(device)
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    model.eval()

    correct = 0
    total = 0
    all_images, all_labels, all_preds = [], [], []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (preds == labels).sum().item()

            all_images.extend(images.cpu())
            all_labels.extend(labels.cpu())
            all_preds.extend(preds.cpu())

    test_accuracy = 100 * correct / total
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    return all_images, all_labels, all_preds, test_accuracy


def create_prediction_grid(images, labels, preds, class_names, slider_index=0):
    """Returns a 10x3 grid of images with predictions"""
    random.seed(slider_index)
    indices = random.sample(range(len(images)), 30)

    fig, axes = plt.subplots(10, 3, figsize=(10, 25))
    for i, ax in enumerate(axes.flat):
        img = images[indices[i]].numpy().transpose((1, 2, 0))
        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img = np.clip(img, 0, 1)
        ax.imshow(img)
        pred = class_names[preds[indices[i]]]
        label = class_names[labels[indices[i]]]
        title = f"Pred: {pred}\nTrue: {label}"
        ax.set_title(title, color='green' if pred == label else 'red', fontsize=8)
        ax.axis('off')

    plt.tight_layout()
    img_path = f"test_grid_{slider_index}.png"
    plt.savefig(img_path, dpi=300)
    return img_path


def main(best_model_path, slider_index=0):
    wandb.init(project="DL_A2", name=f"Test Evaluation - Grid {slider_index}")

    test_dataset, test_loader = load_test_data()
    all_images, all_labels, all_preds, test_acc = evaluate_best_model(
        best_model_path, test_loader, test_dataset, slider_index)

    grid_path = create_prediction_grid(all_images, all_labels, all_preds, test_dataset.classes, slider_index)

    wandb.log({
        "test_accuracy": test_acc,
        "prediction_grid": wandb.Image(grid_path),
        "grid_index": slider_index
    })
    wandb.finish()


# Call this with different slider values
for slider_index in range(0, 10):  # Slider index range from 0 to 9
    main(best_model_path="/kaggle/input/cnn/pytorch/default/1/best_model.pth", slider_index=slider_index)


In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from collections import defaultdict
import wandb
import random
from PIL import Image

# Load model architecture
class OptimizedCNN(torch.nn.Module):
    def __init__(self, num_classes=10):
        super(OptimizedCNN, self).__init__()
        self.conv_blocks = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, 7, 2, 3), torch.nn.BatchNorm2d(64), torch.nn.ReLU(), torch.nn.MaxPool2d(3, 2, 1),
            torch.nn.Conv2d(64, 128, 5, padding=2), torch.nn.BatchNorm2d(128), torch.nn.ReLU(), torch.nn.MaxPool2d(3, 2, 1),
            torch.nn.Conv2d(128, 256, 3, padding=1), torch.nn.BatchNorm2d(256), torch.nn.ReLU(), torch.nn.MaxPool2d(3, 2, 1),
            torch.nn.Conv2d(256, 512, 3, padding=1), torch.nn.BatchNorm2d(512), torch.nn.ReLU(), torch.nn.MaxPool2d(3, 2, 1),
            torch.nn.Conv2d(512, 512, 3, padding=1), torch.nn.BatchNorm2d(512), torch.nn.ReLU(), torch.nn.MaxPool2d(3, 2, 1)
        )
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(512, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.conv_blocks(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

# Visualization Utility
def create_grid(images, labels, preds, class_names, title="Predictions"):
    fig, axes = plt.subplots(10, 3, figsize=(12, 30))
    for idx, ax in enumerate(axes.flat):
        img = images[idx].numpy().transpose((1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = np.clip((img * std + mean), 0, 1)
        ax.imshow(img)
        ax.axis('off')
        color = 'green' if preds[idx] == labels[idx] else 'red'
        ax.set_title(f"True: {class_names[labels[idx]]}\nPred: {class_names[preds[idx]]}", fontsize=8, color=color)
    fig.suptitle(title, fontsize=16)
    return fig

# Main Evaluation
def evaluate_testset(model_path, test_dir, num_grids=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    dataset = datasets.ImageFolder(root=test_dir, transform=transform)
    loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=4)
    class_names = dataset.classes

    model = OptimizedCNN(num_classes=len(class_names))
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Collect all predictions
    all_images, all_labels, all_preds = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1).cpu()
            all_images.extend(images.cpu())
            all_labels.extend(labels)
            all_preds.extend(preds)

    accuracy = 100 * np.mean(np.array(all_preds) == np.array(all_labels))
    print(f"Test Accuracy: {accuracy:.2f}%")

    wandb.init(project="DL_A2", name="evaluate_test", job_type="test_eval", reinit=True)

    wandb.log({"test_accuracy": accuracy})

    # Create random panels of 30 images (10x3) with slider
    total = len(all_images)
    slider_images = []
    for i in range(num_grids):
        indices = random.sample(range(total), 30)
        sample_imgs = [all_images[j] for j in indices]
        sample_labels = [all_labels[j] for j in indices]
        sample_preds = [all_preds[j] for j in indices]

        fig = create_grid(sample_imgs, sample_labels, sample_preds, class_names, title=f"Random Grid {i+1}")
        grid_path = f"panel_grid_{i}.png"
        fig.savefig(grid_path, dpi=300, bbox_inches='tight')
        wandb.log({f"prediction_panel_{i}": wandb.Image(grid_path)})

    # Panel section slider (index from 0 to 9)
    wandb.log({
        "grid_slider": wandb.Image("panel_grid_0.png"),
        "index_slider": wandb.Html('<input type="range" min="0" max="9" value="0" step="1">')
    })

    wandb.finish()

# Run
if __name__ == "__main__":
    evaluate_testset(
        model_path='/kaggle/input/cnn/pytorch/default/1/best_model.pth',
        test_dir='/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/val',
        num_grids=10  # will generate 10 random 10×3 grids
    )


# Part B

## Question 5

In [5]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import wandb
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np

# ----------------------------
# Data loading and preprocessing
# ----------------------------
def get_dataloaders(data_dir, batch_size=64, val_split=0.2):
    transform_train = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    transform_val = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    full_dataset = datasets.ImageFolder(root=data_dir, transform=transform_train)
    targets = np.array(full_dataset.targets)

    # Stratified split
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=val_split, random_state=42)
    train_idx, val_idx = next(splitter.split(np.zeros(len(targets)), targets))

    train_dataset = torch.utils.data.Subset(full_dataset, train_idx)
    val_dataset = torch.utils.data.Subset(datasets.ImageFolder(root=data_dir, transform=transform_val), val_idx)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader, full_dataset.classes

# ----------------------------
# Fine-tune ViT model
# ----------------------------
def build_model(num_classes=10):
    model = models.vit_b_16(weights="IMAGENET1K_V1")
    model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    return model

# ----------------------------
# Train and validate
# ----------------------------
def train(model, train_loader, val_loader, optimizer, criterion, scheduler, device, epochs=10):
    model.to(device)
    best_val_acc = 0.0

    for epoch in range(epochs):
        model.train()
        total_loss, correct = 0, 0

        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()

        train_acc = correct / len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss, val_correct = 0, 0
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_correct += (outputs.argmax(1) == labels).sum().item()

        val_acc = val_correct / len(val_loader.dataset)
        scheduler.step(val_loss)

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": total_loss / len(train_loader),
            "val_loss": val_loss / len(val_loader),
            "train_accuracy": train_acc,
            "val_accuracy": val_acc,
            "lr": optimizer.param_groups[0]['lr']
        })

        print(f"Epoch {epoch+1}: Train Acc = {train_acc:.4f}, Val Acc = {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "vit_best_model.pth")
            print(f" Saved new best model with val_acc = {best_val_acc:.4f}")

# ----------------------------
# Main
# ----------------------------
def main():
    wandb.init(project="DL_A2", name="ViT_model")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, classes = get_dataloaders(
        data_dir="/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/train",
        batch_size=64
    )

    model = build_model(num_classes=len(classes))

    # Freeze all except classifier (optional)
    for param in model.parameters():
        param.requires_grad = False
    for param in model.heads.parameters():
        param.requires_grad = True

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=3)

    train(model, train_loader, val_loader, optimizer, criterion, scheduler, device, epochs=10)

    wandb.finish()



In [6]:
if __name__ == "__main__":
    main()


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 212MB/s]  


Epoch 1: Train Acc = 0.7223, Val Acc = 0.8430
 Saved new best model with val_acc = 0.8430
Epoch 2: Train Acc = 0.7896, Val Acc = 0.8500
 Saved new best model with val_acc = 0.8500
Epoch 3: Train Acc = 0.7981, Val Acc = 0.8590
 Saved new best model with val_acc = 0.8590
Epoch 4: Train Acc = 0.8019, Val Acc = 0.8560
Epoch 5: Train Acc = 0.8055, Val Acc = 0.8595
 Saved new best model with val_acc = 0.8595
Epoch 6: Train Acc = 0.8099, Val Acc = 0.8565
Epoch 7: Train Acc = 0.8074, Val Acc = 0.8575
Epoch 8: Train Acc = 0.8107, Val Acc = 0.8600
 Saved new best model with val_acc = 0.8600
Epoch 9: Train Acc = 0.8159, Val Acc = 0.8590
Epoch 10: Train Acc = 0.8117, Val Acc = 0.8520


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▆▇▇▇█▇███
train_loss,█▃▃▂▂▂▂▁▁▁
val_accuracy,▁▄█▆█▇▇██▅
val_loss,█▅▄▃▃▂▂▁▁▂

0,1
epoch,10.0
lr,0.01
train_accuracy,0.81173
train_loss,0.59423
val_accuracy,0.852
val_loss,0.47472


## Frozen Feature Extractor

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import wandb

# Initialize WandB
wandb.init(project="DL_A2",name="inaturalist-finetuning", config={
    "strategy": "Frozen Feature Extractor",
    "architecture": "ResNet50"
})

# Data Transforms
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load Data
train_data = datasets.ImageFolder('/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/train', transform=train_transform)
val_data = datasets.ImageFolder('/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/val', transform=val_transform)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)

# Model Setup
model = models.resnet50(weights='IMAGENET1K_V2')

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

# Replace final layer
model.fc = nn.Linear(model.fc.in_features, len(train_data.classes))

# Training Config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

# Training Loop
for epoch in range(10):
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
    # Validation
    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            val_loss += criterion(outputs, labels).item()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
    
    val_acc = 100 * correct / len(val_data)
    
    wandb.log({
        "epoch": epoch,
        "train_loss": loss.item(),
        "val_loss": val_loss/len(val_loader),
        "val_acc": val_acc
    })

torch.save(model.state_dict(), "finetuned_resnet50.pth")
wandb.finish()

0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_loss,▇█▄▆▂▃▅▅█▁
val_acc,▁▅▅▆▇▇█▇██
val_loss,█▅▃▂▂▂▁▁▁▁

0,1
epoch,9.0
train_loss,0.34243
val_acc,84.45
val_loss,0.50303


/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/

Progressive Unfreezing (Bottom-Up)

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import wandb

# Initialize WandB
wandb.init(project="DL_A2",name="inaturalist-progressive-unfreeze", config={
    "strategy": "Progressive Unfreezing (Bottom-Up)",
    "architecture": "ResNet50"
})

# Data Transforms
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load Data
train_data = datasets.ImageFolder('/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/train', transform=train_transform)
val_data = datasets.ImageFolder('/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/val', transform=val_transform)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_data, batch_size=64, num_workers=4)

# Model Setup
model = models.resnet50(weights='IMAGENET1K_V2')

# Freeze all layers initially
for param in model.parameters():
    param.requires_grad = False

# Replace final layer
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, 256),
    nn.ReLU(),
    nn.Linear(256, len(train_data.classes))
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Training Config
optimizer = optim.Adam([
    {'params': model.fc.parameters(), 'lr': 1e-3}
])
criterion = nn.CrossEntropyLoss()

# Progressive Unfreezing Schedule
unfreeze_schedule = {
    3: ['layer4'],    # Unfreeze layer4 after 3 epochs
    6: ['layer3'],    # Unfreeze layer3 after 6 epochs
    9: ['layer2']     # Unfreeze layer2 after 9 epochs
}

for epoch in range(15):
    # Check unfreezing condition
    if epoch in unfreeze_schedule:
        for layer_name in unfreeze_schedule[epoch]:
            for name, param in model.named_parameters():
                if layer_name in name:
                    param.requires_grad = True
            print(f"Unfrozen {layer_name}")
            
            # Add to optimizer with lower LR
            new_params = [p for n,p in model.named_parameters() 
                         if layer_name in n and p.requires_grad]
            optimizer.add_param_group({
                'params': new_params,
                'lr': 1e-4 * (0.1 ** (epoch//3))  # Decreasing LR
            })
    
    # Training
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    
    # Validation
    model.eval()
    val_loss, correct = 0, 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            val_loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(1) == labels).sum().item()
    
    val_acc = 100 * correct / len(val_data)
    
    wandb.log({
        "epoch": epoch,
        "train_loss": loss.item(),
        "val_loss": val_loss/len(val_loader),
        "val_acc": val_acc,
        "lr_fc": optimizer.param_groups[0]['lr'],
        "lr_layer4": optimizer.param_groups[1]['lr'] if len(optimizer.param_groups)>1 else 0
    })

torch.save(model.state_dict(), "progressive_unfreeze_resnet50.pth")
wandb.finish()

Unfrozen layer4
Unfrozen layer3
Unfrozen layer2


0,1
epoch,▁▁▂▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
lr_fc,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_layer4,▁▁▁████████████
train_loss,▅▆▆█▇▅█▆█▆▆▆▆▁▄▂▃▅
val_acc,▁▃▄▃▄▄▅▆▇▇▇▇▇█████
val_loss,█▆▄▅▄▄▃▂▂▂▂▂▁▁▁▁▁▁

0,1
epoch,14.0
lr_fc,0.001
lr_layer4,1e-05
train_loss,0.8123
val_acc,86.95
val_loss,0.38761


In [5]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
import wandb


# ------------------------ Sweep Config ------------------------ #
sweep_config = {
    "method": "random",
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        "augment": {"values": [True, False]},
        "batch_size": {"values": [64, 256]},
        "lr": {"values": [0.01, 0.001]},
        "epochs": {"value": 10}
    }
}


# ------------------------ Data Loader ------------------------ #
def get_dataloaders(data_dir, batch_size=256, val_split=0.2, augment=True):
    train_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(0.2, 0.2, 0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]) if augment else transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    val_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    full_dataset = datasets.ImageFolder(root=data_dir, transform=train_transforms)
    targets = np.array(full_dataset.targets)

    splitter = StratifiedShuffleSplit(n_splits=1, test_size=val_split, random_state=42)
    train_idx, val_idx = next(splitter.split(np.zeros(len(targets)), targets))

    train_set = Subset(full_dataset, train_idx)
    val_set = Subset(datasets.ImageFolder(root=data_dir, transform=val_transforms), val_idx)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader, full_dataset.classes


# ------------------------ Training Function ------------------------ #
def train(model, train_loader, val_loader, optimizer, criterion, scheduler, device, epochs=30):
    model.to(device)
    best_val_acc = 0

    for epoch in range(epochs):
        # Progressive Unfreezing
        if epoch == 5:
            for name, param in model.named_parameters():
                if "encoder.layer.10" in name or "encoder.layer.11" in name:
                    param.requires_grad = True

        model.train()
        train_loss, correct = 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            correct += (out.argmax(1) == y).sum().item()
        train_acc = 100. * correct / len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss, val_correct = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                out = model(x)
                loss = criterion(out, y)
                val_loss += loss.item()
                val_correct += (out.argmax(1) == y).sum().item()
        val_acc = 100. * val_correct / len(val_loader.dataset)
        scheduler.step(val_loss)

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss / len(train_loader),
            "train_accuracy": train_acc,
            "val_loss": val_loss / len(val_loader),
            "val_accuracy": val_acc,
            "lr": optimizer.param_groups[0]['lr']
        })

        print(f"Epoch {epoch+1}: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model_vit.pth")

    return best_val_acc


# ------------------------ Main Function ------------------------ #
def main():
    wandb.init(project="DL_A2")
    config = wandb.config
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, val_loader, classes = get_dataloaders(
        data_dir="/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/train",
        batch_size=config.batch_size,
        augment=config.augment
    )

    # Load pre-trained ViT
    vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)

    # Freeze all layers first
    for param in vit.parameters():
        param.requires_grad = False

    # Replace classifier head
    in_features = vit.heads[0].in_features
    vit.heads = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, len(classes))
    )

    # Enable training on new head
    for param in vit.heads.parameters():
        param.requires_grad = True

    optimizer = optim.SGD(vit.parameters(), lr=config.lr, momentum=0.9, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)
    criterion = nn.CrossEntropyLoss()

    best_val_acc = train(
        model=vit,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        device=device,
        epochs=config.epochs
    )

    wandb.summary["best_val_acc"] = best_val_acc
    wandb.finish()


# ------------------------ Start Sweep ------------------------ #
sweep_id = wandb.sweep(sweep_config, project="DL_A2")
wandb.agent(sweep_id, function=main, count=20)


Create sweep with ID: w0wfeap0
Sweep URL: https://wandb.ai/cs24m016-indian-institute-of-technology-madras/DL_A2/sweeps/w0wfeap0


[34m[1mwandb[0m: Agent Starting Run: 23i3wyb0 with config:
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.001




Epoch 1: Train Acc: 45.73%, Val Acc: 72.05%
Epoch 2: Train Acc: 76.27%, Val Acc: 81.25%
Epoch 3: Train Acc: 81.14%, Val Acc: 83.50%
Epoch 4: Train Acc: 83.09%, Val Acc: 84.15%
Epoch 5: Train Acc: 83.97%, Val Acc: 84.45%
Epoch 6: Train Acc: 84.57%, Val Acc: 84.75%
Epoch 7: Train Acc: 84.82%, Val Acc: 84.90%
Epoch 8: Train Acc: 85.61%, Val Acc: 85.30%
Epoch 9: Train Acc: 86.05%, Val Acc: 85.30%
Epoch 10: Train Acc: 86.41%, Val Acc: 85.35%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▆▇▇██████
train_loss,█▅▃▂▂▂▁▁▁▁
val_accuracy,▁▆▇▇██████
val_loss,█▅▃▂▂▂▁▁▁▁

0,1
best_val_acc,85.35
epoch,10.0
lr,0.001
train_accuracy,86.4108
train_loss,0.4941
val_accuracy,85.35
val_loss,0.51933


[34m[1mwandb[0m: Agent Starting Run: a1bwgbh4 with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.001




Epoch 1: Train Acc: 14.41%, Val Acc: 31.30%
Epoch 2: Train Acc: 36.34%, Val Acc: 58.90%
Epoch 3: Train Acc: 54.43%, Val Acc: 69.45%
Epoch 4: Train Acc: 63.91%, Val Acc: 74.60%
Epoch 5: Train Acc: 69.08%, Val Acc: 77.50%
Epoch 6: Train Acc: 71.52%, Val Acc: 79.10%
Epoch 7: Train Acc: 72.07%, Val Acc: 79.80%
Epoch 8: Train Acc: 73.68%, Val Acc: 80.55%
Epoch 9: Train Acc: 74.66%, Val Acc: 80.95%
Epoch 10: Train Acc: 76.13%, Val Acc: 81.55%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▆▇▇▇████
train_loss,█▇▆▆▅▄▃▂▂▁
val_accuracy,▁▅▆▇▇█████
val_loss,█▇▆▅▅▄▃▂▂▁

0,1
best_val_acc,81.55
epoch,10.0
lr,0.001
train_accuracy,76.13452
train_loss,1.18125
val_accuracy,81.55
val_loss,1.0794


[34m[1mwandb[0m: Agent Starting Run: bhksj3yq with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.001




Epoch 1: Train Acc: 13.19%, Val Acc: 30.05%
Epoch 2: Train Acc: 37.14%, Val Acc: 57.30%
Epoch 3: Train Acc: 54.73%, Val Acc: 67.85%
Epoch 4: Train Acc: 64.32%, Val Acc: 73.45%
Epoch 5: Train Acc: 67.60%, Val Acc: 76.25%
Epoch 6: Train Acc: 71.12%, Val Acc: 79.10%
Epoch 7: Train Acc: 73.37%, Val Acc: 80.15%
Epoch 8: Train Acc: 74.65%, Val Acc: 81.05%
Epoch 9: Train Acc: 74.65%, Val Acc: 81.75%
Epoch 10: Train Acc: 76.02%, Val Acc: 82.25%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▄▆▇▇▇████
train_loss,█▇▆▅▅▄▃▂▂▁
val_accuracy,▁▅▆▇▇█████
val_loss,█▇▆▅▄▄▃▂▂▁

0,1
best_val_acc,82.25
epoch,10.0
lr,0.001
train_accuracy,76.022
train_loss,1.17054
val_accuracy,82.25
val_loss,1.06778


[34m[1mwandb[0m: Agent Starting Run: 12vxj77x with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1: Train Acc: 47.69%, Val Acc: 80.50%
Epoch 2: Train Acc: 76.25%, Val Acc: 82.30%
Epoch 3: Train Acc: 78.21%, Val Acc: 84.30%
Epoch 4: Train Acc: 79.47%, Val Acc: 84.75%
Epoch 5: Train Acc: 80.26%, Val Acc: 84.95%
Epoch 6: Train Acc: 80.59%, Val Acc: 85.10%
Epoch 7: Train Acc: 81.21%, Val Acc: 85.45%
Epoch 8: Train Acc: 81.54%, Val Acc: 85.75%
Epoch 9: Train Acc: 81.84%, Val Acc: 85.75%
Epoch 10: Train Acc: 82.39%, Val Acc: 86.10%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▇▇▇██████
train_loss,█▃▂▂▁▁▁▁▁▁
val_accuracy,▁▃▆▆▇▇▇███
val_loss,█▃▂▂▁▁▁▁▁▁

0,1
best_val_acc,86.1
epoch,10.0
lr,0.01
train_accuracy,82.3853
train_loss,0.57101
val_accuracy,86.1
val_loss,0.46575


[34m[1mwandb[0m: Agent Starting Run: 4iygpx6d with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.001




Epoch 1: Train Acc: 42.32%, Val Acc: 73.45%
Epoch 2: Train Acc: 70.57%, Val Acc: 81.00%
Epoch 3: Train Acc: 75.73%, Val Acc: 82.30%
Epoch 4: Train Acc: 76.90%, Val Acc: 83.40%
Epoch 5: Train Acc: 77.53%, Val Acc: 84.20%
Epoch 6: Train Acc: 78.97%, Val Acc: 84.20%
Epoch 7: Train Acc: 78.78%, Val Acc: 84.30%
Epoch 8: Train Acc: 79.17%, Val Acc: 84.85%
Epoch 9: Train Acc: 79.58%, Val Acc: 84.75%
Epoch 10: Train Acc: 79.68%, Val Acc: 84.70%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▆▇▇██████
train_loss,█▅▃▂▂▂▁▁▁▁
val_accuracy,▁▆▆▇██████
val_loss,█▅▃▂▂▂▁▁▁▁

0,1
best_val_acc,84.85
epoch,10.0
lr,0.001
train_accuracy,79.68496
train_loss,0.6705
val_accuracy,84.7
val_loss,0.5439


[34m[1mwandb[0m: Agent Starting Run: yzgv9d3e with config:
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1: Train Acc: 51.27%, Val Acc: 80.30%
Epoch 2: Train Acc: 82.09%, Val Acc: 83.60%
Epoch 3: Train Acc: 84.37%, Val Acc: 84.80%
Epoch 4: Train Acc: 86.02%, Val Acc: 85.15%
Epoch 5: Train Acc: 86.77%, Val Acc: 85.35%
Epoch 6: Train Acc: 87.64%, Val Acc: 85.70%
Epoch 7: Train Acc: 88.54%, Val Acc: 85.95%
Epoch 8: Train Acc: 89.60%, Val Acc: 86.05%
Epoch 9: Train Acc: 89.59%, Val Acc: 86.50%
Epoch 10: Train Acc: 90.12%, Val Acc: 86.25%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▇▇▇▇█████
train_loss,█▄▂▂▂▁▁▁▁▁
val_accuracy,▁▅▆▆▇▇▇▇██
val_loss,█▃▂▂▁▁▁▁▁▁

0,1
best_val_acc,86.5
epoch,10.0
lr,0.01
train_accuracy,90.12377
train_loss,0.33415
val_accuracy,86.25
val_loss,0.45406


[34m[1mwandb[0m: Agent Starting Run: 59zknbis with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1: Train Acc: 69.30%, Val Acc: 84.30%
Epoch 2: Train Acc: 80.05%, Val Acc: 86.05%
Epoch 3: Train Acc: 81.25%, Val Acc: 86.55%
Epoch 4: Train Acc: 82.89%, Val Acc: 86.70%
Epoch 5: Train Acc: 82.81%, Val Acc: 87.10%
Epoch 6: Train Acc: 83.91%, Val Acc: 87.30%
Epoch 7: Train Acc: 84.31%, Val Acc: 87.55%
Epoch 8: Train Acc: 85.17%, Val Acc: 87.35%
Epoch 9: Train Acc: 85.67%, Val Acc: 87.70%
Epoch 10: Train Acc: 86.30%, Val Acc: 87.85%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▅▆▇▇▇▇███
train_loss,█▃▃▂▂▂▂▁▁▁
val_accuracy,▁▄▅▆▇▇▇▇██
val_loss,█▅▃▂▂▂▂▁▁▁

0,1
best_val_acc,87.85
epoch,10.0
lr,0.01
train_accuracy,86.29829
train_loss,0.42413
val_accuracy,87.85
val_loss,0.38344


[34m[1mwandb[0m: Agent Starting Run: 6bgdkzeg with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1: Train Acc: 69.35%, Val Acc: 84.95%
Epoch 2: Train Acc: 79.83%, Val Acc: 85.50%
Epoch 3: Train Acc: 81.31%, Val Acc: 86.05%
Epoch 4: Train Acc: 82.39%, Val Acc: 86.70%
Epoch 5: Train Acc: 82.85%, Val Acc: 86.55%
Epoch 6: Train Acc: 84.06%, Val Acc: 87.15%
Epoch 7: Train Acc: 84.46%, Val Acc: 87.10%
Epoch 8: Train Acc: 84.66%, Val Acc: 87.35%
Epoch 9: Train Acc: 85.61%, Val Acc: 86.90%
Epoch 10: Train Acc: 85.75%, Val Acc: 87.65%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▅▆▇▇▇▇███
train_loss,█▃▃▂▂▂▂▁▁▁
val_accuracy,▁▂▄▆▅▇▇▇▆█
val_loss,█▅▄▃▂▁▁▁▁▁

0,1
best_val_acc,87.65
epoch,10.0
lr,0.01
train_accuracy,85.74822
train_loss,0.42192
val_accuracy,87.65
val_loss,0.39498


[34m[1mwandb[0m: Agent Starting Run: 1mxu3mmd with config:
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1: Train Acc: 51.38%, Val Acc: 80.30%
Epoch 2: Train Acc: 81.76%, Val Acc: 83.35%
Epoch 3: Train Acc: 84.30%, Val Acc: 84.45%
Epoch 4: Train Acc: 85.49%, Val Acc: 84.90%
Epoch 5: Train Acc: 86.56%, Val Acc: 85.45%
Epoch 6: Train Acc: 87.74%, Val Acc: 85.70%
Epoch 7: Train Acc: 88.60%, Val Acc: 85.90%
Epoch 8: Train Acc: 89.07%, Val Acc: 86.65%
Epoch 9: Train Acc: 89.84%, Val Acc: 86.95%
Epoch 10: Train Acc: 90.32%, Val Acc: 87.10%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▆▇▇▇█████
train_loss,█▄▂▂▂▁▁▁▁▁
val_accuracy,▁▄▅▆▆▇▇███
val_loss,█▃▂▂▁▁▁▁▁▁

0,1
best_val_acc,87.1
epoch,10.0
lr,0.01
train_accuracy,90.32379
train_loss,0.33331
val_accuracy,87.1
val_loss,0.44397


[34m[1mwandb[0m: Agent Starting Run: zz5dyyg0 with config:
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1: Train Acc: 52.91%, Val Acc: 80.50%
Epoch 2: Train Acc: 81.89%, Val Acc: 84.30%
Epoch 3: Train Acc: 84.34%, Val Acc: 84.70%
Epoch 4: Train Acc: 85.75%, Val Acc: 85.05%
Epoch 5: Train Acc: 86.75%, Val Acc: 85.40%
Epoch 6: Train Acc: 87.76%, Val Acc: 85.65%
Epoch 7: Train Acc: 88.55%, Val Acc: 85.65%
Epoch 8: Train Acc: 89.01%, Val Acc: 85.65%
Epoch 9: Train Acc: 89.52%, Val Acc: 86.25%
Epoch 10: Train Acc: 90.07%, Val Acc: 86.05%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▆▇▇▇█████
train_loss,█▄▂▂▂▁▁▁▁▁
val_accuracy,▁▆▆▇▇▇▇▇██
val_loss,█▃▂▂▁▁▁▁▁▁

0,1
best_val_acc,86.25
epoch,10.0
lr,0.01
train_accuracy,90.07376
train_loss,0.33829
val_accuracy,86.05
val_loss,0.45322


[34m[1mwandb[0m: Agent Starting Run: o8am6wd8 with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.001




Epoch 1: Train Acc: 40.71%, Val Acc: 73.60%
Epoch 2: Train Acc: 71.02%, Val Acc: 80.50%
Epoch 3: Train Acc: 74.58%, Val Acc: 82.35%
Epoch 4: Train Acc: 76.71%, Val Acc: 83.15%
Epoch 5: Train Acc: 78.13%, Val Acc: 83.75%
Epoch 6: Train Acc: 78.16%, Val Acc: 83.80%
Epoch 7: Train Acc: 78.20%, Val Acc: 84.10%
Epoch 8: Train Acc: 79.71%, Val Acc: 84.55%
Epoch 9: Train Acc: 79.40%, Val Acc: 84.60%
Epoch 10: Train Acc: 79.26%, Val Acc: 84.70%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▆▇▇██████
train_loss,█▅▃▂▂▁▁▁▁▁
val_accuracy,▁▅▇▇▇▇████
val_loss,█▅▃▂▂▂▁▁▁▁

0,1
best_val_acc,84.7
epoch,10.0
lr,0.001
train_accuracy,79.25991
train_loss,0.68659
val_accuracy,84.7
val_loss,0.54984


[34m[1mwandb[0m: Agent Starting Run: 0406a3zu with config:
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.001




Epoch 1: Train Acc: 17.24%, Val Acc: 35.00%
Epoch 2: Train Acc: 41.94%, Val Acc: 56.40%
Epoch 3: Train Acc: 59.58%, Val Acc: 69.40%
Epoch 4: Train Acc: 68.47%, Val Acc: 74.70%
Epoch 5: Train Acc: 73.85%, Val Acc: 77.65%
Epoch 6: Train Acc: 76.81%, Val Acc: 79.25%
Epoch 7: Train Acc: 79.65%, Val Acc: 79.55%
Epoch 8: Train Acc: 80.05%, Val Acc: 80.25%
Epoch 9: Train Acc: 81.75%, Val Acc: 81.05%
Epoch 10: Train Acc: 81.66%, Val Acc: 81.55%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▄▆▇▇▇████
train_loss,█▇▆▆▅▄▃▂▂▁
val_accuracy,▁▄▆▇▇█████
val_loss,█▇▆▅▅▄▃▂▂▁

0,1
best_val_acc,81.55
epoch,10.0
lr,0.001
train_accuracy,81.66021
train_loss,1.07035
val_accuracy,81.55
val_loss,1.02605


[34m[1mwandb[0m: Agent Starting Run: 61ubh9lp with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.001




Epoch 1: Train Acc: 13.75%, Val Acc: 29.95%
Epoch 2: Train Acc: 35.48%, Val Acc: 55.30%
Epoch 3: Train Acc: 54.39%, Val Acc: 68.40%
Epoch 4: Train Acc: 63.16%, Val Acc: 74.55%
Epoch 5: Train Acc: 69.00%, Val Acc: 77.25%
Epoch 6: Train Acc: 71.68%, Val Acc: 79.45%
Epoch 7: Train Acc: 72.52%, Val Acc: 80.20%
Epoch 8: Train Acc: 73.81%, Val Acc: 81.30%
Epoch 9: Train Acc: 75.61%, Val Acc: 82.00%
Epoch 10: Train Acc: 75.40%, Val Acc: 82.20%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▆▇▇█████
train_loss,█▇▆▅▅▄▃▂▂▁
val_accuracy,▁▄▆▇▇█████
val_loss,█▇▆▅▅▄▃▂▂▁

0,1
best_val_acc,82.2
epoch,10.0
lr,0.001
train_accuracy,75.39692
train_loss,1.17337
val_accuracy,82.2
val_loss,1.07479


[34m[1mwandb[0m: Agent Starting Run: vli9wcvg with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.001




Epoch 1: Train Acc: 14.91%, Val Acc: 33.75%
Epoch 2: Train Acc: 37.62%, Val Acc: 58.70%
Epoch 3: Train Acc: 55.99%, Val Acc: 69.50%
Epoch 4: Train Acc: 63.86%, Val Acc: 74.25%
Epoch 5: Train Acc: 68.53%, Val Acc: 76.85%
Epoch 6: Train Acc: 71.05%, Val Acc: 78.80%
Epoch 7: Train Acc: 73.65%, Val Acc: 79.20%
Epoch 8: Train Acc: 74.25%, Val Acc: 80.20%
Epoch 9: Train Acc: 75.21%, Val Acc: 81.15%
Epoch 10: Train Acc: 75.78%, Val Acc: 81.70%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▄▆▇▇▇████
train_loss,█▇▆▆▅▄▃▂▂▁
val_accuracy,▁▅▆▇▇█████
val_loss,█▇▆▅▅▄▃▂▂▁

0,1
best_val_acc,81.7
epoch,10.0
lr,0.001
train_accuracy,75.78447
train_loss,1.16571
val_accuracy,81.7
val_loss,1.06896


[34m[1mwandb[0m: Agent Starting Run: 4jn2jijl with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.001




Epoch 1: Train Acc: 44.26%, Val Acc: 73.75%
Epoch 2: Train Acc: 71.12%, Val Acc: 80.70%
Epoch 3: Train Acc: 74.97%, Val Acc: 82.45%
Epoch 4: Train Acc: 76.50%, Val Acc: 83.25%
Epoch 5: Train Acc: 78.00%, Val Acc: 84.15%
Epoch 6: Train Acc: 78.45%, Val Acc: 84.65%
Epoch 7: Train Acc: 78.70%, Val Acc: 84.90%
Epoch 8: Train Acc: 79.20%, Val Acc: 84.95%
Epoch 9: Train Acc: 79.91%, Val Acc: 84.75%
Epoch 10: Train Acc: 80.15%, Val Acc: 85.40%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▆▇▇██████
train_loss,█▅▃▂▂▂▁▁▁▁
val_accuracy,▁▅▆▇▇█████
val_loss,█▅▃▂▂▂▁▁▁▁

0,1
best_val_acc,85.4
epoch,10.0
lr,0.001
train_accuracy,80.14752
train_loss,0.67872
val_accuracy,85.4
val_loss,0.54554


[34m[1mwandb[0m: Agent Starting Run: p5rw2eps with config:
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1: Train Acc: 74.35%, Val Acc: 85.10%
Epoch 2: Train Acc: 87.01%, Val Acc: 85.75%
Epoch 3: Train Acc: 89.20%, Val Acc: 86.50%
Epoch 4: Train Acc: 91.20%, Val Acc: 86.60%
Epoch 5: Train Acc: 92.42%, Val Acc: 86.50%
Epoch 6: Train Acc: 93.74%, Val Acc: 87.00%
Epoch 7: Train Acc: 95.05%, Val Acc: 87.15%
Epoch 8: Train Acc: 96.16%, Val Acc: 87.00%
Epoch 9: Train Acc: 97.05%, Val Acc: 86.90%
Epoch 10: Train Acc: 97.42%, Val Acc: 86.95%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,███████▁▁▁
train_accuracy,▁▅▆▆▆▇▇███
train_loss,█▄▃▂▂▂▁▁▁▁
val_accuracy,▁▃▆▆▆▇█▇▇▇
val_loss,█▄▂▁▁▂▂▂▂▂

0,1
best_val_acc,87.15
epoch,10.0
lr,0.001
train_accuracy,97.42468
train_loss,0.11205
val_accuracy,86.95
val_loss,0.44287


[34m[1mwandb[0m: Agent Starting Run: 87kmqdhx with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1: Train Acc: 48.12%, Val Acc: 79.70%
Epoch 2: Train Acc: 75.92%, Val Acc: 82.80%
Epoch 3: Train Acc: 78.25%, Val Acc: 83.65%
Epoch 4: Train Acc: 79.10%, Val Acc: 84.40%
Epoch 5: Train Acc: 79.68%, Val Acc: 84.85%
Epoch 6: Train Acc: 80.84%, Val Acc: 85.25%
Epoch 7: Train Acc: 81.34%, Val Acc: 85.75%
Epoch 8: Train Acc: 81.39%, Val Acc: 85.80%
Epoch 9: Train Acc: 81.46%, Val Acc: 86.35%
Epoch 10: Train Acc: 82.15%, Val Acc: 86.15%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▇▇▇▇█████
train_loss,█▃▂▂▂▁▁▁▁▁
val_accuracy,▁▄▅▆▆▇▇▇██
val_loss,█▃▂▂▂▁▁▁▁▁

0,1
best_val_acc,86.35
epoch,10.0
lr,0.01
train_accuracy,82.14777
train_loss,0.57153
val_accuracy,86.15
val_loss,0.46555


[34m[1mwandb[0m: Agent Starting Run: 7tgxxcw3 with config:
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1: Train Acc: 74.62%, Val Acc: 84.70%
Epoch 2: Train Acc: 86.90%, Val Acc: 85.50%
Epoch 3: Train Acc: 89.34%, Val Acc: 86.30%
Epoch 4: Train Acc: 91.05%, Val Acc: 86.95%
Epoch 5: Train Acc: 92.87%, Val Acc: 87.10%
Epoch 6: Train Acc: 93.94%, Val Acc: 87.00%
Epoch 7: Train Acc: 95.12%, Val Acc: 87.05%
Epoch 8: Train Acc: 96.20%, Val Acc: 87.00%
Epoch 9: Train Acc: 96.86%, Val Acc: 86.70%
Epoch 10: Train Acc: 97.90%, Val Acc: 87.00%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,████████▁▁
train_accuracy,▁▅▅▆▆▇▇▇██
train_loss,█▄▃▃▂▂▂▁▁▁
val_accuracy,▁▃▆█████▇█
val_loss,█▄▂▂▁▁▁▂▃▃

0,1
best_val_acc,87.1
epoch,10.0
lr,0.001
train_accuracy,97.89974
train_loss,0.09434
val_accuracy,87.0
val_loss,0.45085


[34m[1mwandb[0m: Agent Starting Run: onvjckzq with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1: Train Acc: 46.97%, Val Acc: 80.40%
Epoch 2: Train Acc: 75.12%, Val Acc: 82.85%
Epoch 3: Train Acc: 77.75%, Val Acc: 83.90%
Epoch 4: Train Acc: 78.41%, Val Acc: 84.55%
Epoch 5: Train Acc: 80.34%, Val Acc: 85.25%
Epoch 6: Train Acc: 80.64%, Val Acc: 85.35%
Epoch 7: Train Acc: 81.04%, Val Acc: 85.75%
Epoch 8: Train Acc: 82.05%, Val Acc: 85.75%
Epoch 9: Train Acc: 81.41%, Val Acc: 85.90%
Epoch 10: Train Acc: 82.59%, Val Acc: 86.15%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▇▇▇██████
train_loss,█▄▂▂▂▁▁▁▁▁
val_accuracy,▁▄▅▆▇▇████
val_loss,█▃▂▂▂▁▁▁▁▁

0,1
best_val_acc,86.15
epoch,10.0
lr,0.01
train_accuracy,82.58532
train_loss,0.56197
val_accuracy,86.15
val_loss,0.45918


[34m[1mwandb[0m: Agent Starting Run: 4fobcav1 with config:
[34m[1mwandb[0m: 	augment: False
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1: Train Acc: 50.94%, Val Acc: 80.00%
Epoch 2: Train Acc: 81.36%, Val Acc: 83.35%
Epoch 3: Train Acc: 84.15%, Val Acc: 84.30%
Epoch 4: Train Acc: 85.95%, Val Acc: 85.25%
Epoch 5: Train Acc: 86.80%, Val Acc: 85.70%
Epoch 6: Train Acc: 87.82%, Val Acc: 85.80%
Epoch 7: Train Acc: 88.50%, Val Acc: 86.15%
Epoch 8: Train Acc: 89.25%, Val Acc: 86.20%
Epoch 9: Train Acc: 89.86%, Val Acc: 86.35%
Epoch 10: Train Acc: 90.22%, Val Acc: 86.35%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▆▇▇▇█████
train_loss,█▄▂▂▂▁▁▁▁▁
val_accuracy,▁▅▆▇▇▇████
val_loss,█▃▂▂▁▁▁▁▁▁

0,1
best_val_acc,86.35
epoch,10.0
lr,0.01
train_accuracy,90.22378
train_loss,0.33171
val_accuracy,86.35
val_loss,0.45207


In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
import wandb


# ------------------------ Sweep Config ------------------------ #
sweep_config = {
    "method": "random",
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        "augment": {"values": [True, False]},
        "batch_size": {"values": [64, 256]},
        "lr": {"values": [0.01, 0.001]},
        "epochs": {"value": 10}
    }
}


# ------------------------ Data Loader ------------------------ #
def get_dataloaders(data_dir, batch_size=256, val_split=0.2, augment=True):
    train_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(0.2, 0.2, 0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]) if augment else transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    val_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    full_dataset = datasets.ImageFolder(root=data_dir, transform=train_transforms)
    targets = np.array(full_dataset.targets)

    splitter = StratifiedShuffleSplit(n_splits=1, test_size=val_split, random_state=42)
    train_idx, val_idx = next(splitter.split(np.zeros(len(targets)), targets))

    train_set = Subset(full_dataset, train_idx)
    val_set = Subset(datasets.ImageFolder(root=data_dir, transform=val_transforms), val_idx)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader, full_dataset.classes


# ------------------------ Training Function ------------------------ #
def train(model, train_loader, val_loader, optimizer, criterion, scheduler, device, epochs=30):
    model.to(device)
    best_val_acc = 0

    for epoch in range(epochs):
        # Progressive Unfreezing
        if epoch == 5:
            for name, param in model.named_parameters():
                if "encoder.layer.10" in name or "encoder.layer.11" in name:
                    param.requires_grad = True

        model.train()
        train_loss, correct = 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            correct += (out.argmax(1) == y).sum().item()
        train_acc = 100. * correct / len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss, val_correct = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                out = model(x)
                loss = criterion(out, y)
                val_loss += loss.item()
                val_correct += (out.argmax(1) == y).sum().item()
        val_acc = 100. * val_correct / len(val_loader.dataset)
        scheduler.step(val_loss)

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss / len(train_loader),
            "train_accuracy": train_acc,
            "val_loss": val_loss / len(val_loader),
            "val_accuracy": val_acc,
            "lr": optimizer.param_groups[0]['lr']
        })

        print(f"Epoch {epoch+1}: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model_vit.pth")

    return best_val_acc


# ------------------------ Main Function ------------------------ #
def main():
    wandb.init(project="DL_A2")
    config = wandb.config
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, val_loader, classes = get_dataloaders(
        data_dir="/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/train",
        batch_size=config.batch_size,
        augment=config.augment
    )

    # Load pre-trained ViT
    vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)

    # Freeze all layers first
    for param in vit.parameters():
        param.requires_grad = False

    # Replace classifier head
    in_features = vit.heads[0].in_features
    vit.heads = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, len(classes))
    )

    # Enable training on new head
    for param in vit.heads.parameters():
        param.requires_grad = True

    optimizer = optim.SGD(vit.parameters(), lr=config.lr, momentum=0.9, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)
    criterion = nn.CrossEntropyLoss()

    best_val_acc = train(
        model=vit,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        device=device,
        epochs=config.epochs
    )

    wandb.summary["best_val_acc"] = best_val_acc
    wandb.finish()

In [3]:
# ------------------------ Start Sweep ------------------------ #
sweep_id = wandb.sweep(sweep_config, project="DL_A2")
wandb.agent(sweep_id, function=main, count=3)

Create sweep with ID: ptr2tahu
Sweep URL: https://wandb.ai/cs24m016-indian-institute-of-technology-madras/DL_A2/sweeps/ptr2tahu


[34m[1mwandb[0m: Agent Starting Run: myykm01c with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.001


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 234MB/s] 


Epoch 1: Train Acc: 42.28%, Val Acc: 75.15%
Epoch 2: Train Acc: 71.26%, Val Acc: 81.55%
Epoch 3: Train Acc: 75.53%, Val Acc: 82.80%
Epoch 4: Train Acc: 76.85%, Val Acc: 83.35%
Epoch 5: Train Acc: 77.63%, Val Acc: 84.20%
Epoch 6: Train Acc: 77.57%, Val Acc: 84.25%
Epoch 7: Train Acc: 78.61%, Val Acc: 84.45%
Epoch 8: Train Acc: 79.55%, Val Acc: 84.60%
Epoch 9: Train Acc: 79.43%, Val Acc: 85.10%
Epoch 10: Train Acc: 79.41%, Val Acc: 85.30%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▆▇▇██████
train_loss,█▅▃▂▂▂▁▁▁▁
val_accuracy,▁▅▆▇▇▇▇███
val_loss,█▅▃▂▂▂▁▁▁▁

0,1
best_val_acc,85.3
epoch,10.0
lr,0.001
train_accuracy,79.40993
train_loss,0.6993
val_accuracy,85.3
val_loss,0.54427


[34m[1mwandb[0m: Agent Starting Run: 219evrtp with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1: Train Acc: 69.58%, Val Acc: 83.95%
Epoch 2: Train Acc: 80.20%, Val Acc: 86.00%
Epoch 3: Train Acc: 81.52%, Val Acc: 86.60%
Epoch 4: Train Acc: 82.50%, Val Acc: 86.75%
Epoch 5: Train Acc: 83.26%, Val Acc: 86.90%
Epoch 6: Train Acc: 84.32%, Val Acc: 87.25%
Epoch 7: Train Acc: 84.67%, Val Acc: 87.20%
Epoch 8: Train Acc: 84.84%, Val Acc: 87.15%
Epoch 9: Train Acc: 85.36%, Val Acc: 87.40%
Epoch 10: Train Acc: 86.07%, Val Acc: 87.75%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▆▆▆▇▇▇▇██
train_loss,█▃▃▂▂▂▁▁▁▁
val_accuracy,▁▅▆▆▆▇▇▇▇█
val_loss,█▅▃▂▂▁▁▁▁▁

0,1
best_val_acc,87.75
epoch,10.0
lr,0.01
train_accuracy,86.07326
train_loss,0.43099
val_accuracy,87.75
val_loss,0.39


[34m[1mwandb[0m: Agent Starting Run: be2ih0nn with config:
[34m[1mwandb[0m: 	augment: True
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	epochs: 10
[34m[1mwandb[0m: 	lr: 0.01




Epoch 1: Train Acc: 48.14%, Val Acc: 80.85%
Epoch 2: Train Acc: 76.13%, Val Acc: 83.20%
Epoch 3: Train Acc: 78.01%, Val Acc: 84.15%
Epoch 4: Train Acc: 79.82%, Val Acc: 84.55%
Epoch 5: Train Acc: 80.02%, Val Acc: 85.30%
Epoch 6: Train Acc: 80.95%, Val Acc: 85.30%
Epoch 7: Train Acc: 81.42%, Val Acc: 85.95%
Epoch 8: Train Acc: 81.90%, Val Acc: 86.25%
Epoch 9: Train Acc: 82.04%, Val Acc: 85.75%
Epoch 10: Train Acc: 82.37%, Val Acc: 85.95%


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▇▇▇██████
train_loss,█▃▂▂▂▁▁▁▁▁
val_accuracy,▁▄▅▆▇▇██▇█
val_loss,█▃▂▂▂▁▁▁▁▁

0,1
best_val_acc,86.25
epoch,10.0
lr,0.01
train_accuracy,82.3728
train_loss,0.5759
val_accuracy,85.95
val_loss,0.45992


In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
from sklearn.metrics import confusion_matrix
import random


def evaluate_vit_test(model_path, test_dir):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Transforms
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2)
    class_names = test_dataset.classes

    # Load pre-trained ViT
    model = models.vit_b_16(pretrained=False)

    # Rebuild classifier head to match training setup
    in_features = model.heads[0].in_features
    model.heads = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(512, len(class_names))
    )

    # Load trained weights
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    all_preds, all_labels = [], []
    correct = 0
    sample_imgs = []

    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            sample_imgs.extend(imgs.cpu())

    accuracy = 100 * correct / len(test_dataset)
    print(f"Test Accuracy: {accuracy:.2f}%")

    wandb.init(project="DL_A2", name="ViT Test Evaluation")
    wandb.log({"vit_test_accuracy": accuracy})

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap='Blues')
    plt.title("Confusion Matrix - ViT")
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    plt.xticks(rotation=45)
    wandb.log({"confusion_matrix_vit": wandb.Image(plt)})
    plt.close()

    # Sample prediction grid
    def show_predictions(images, labels, preds):
        fig, axes = plt.subplots(5, 6, figsize=(15, 10))
        for idx, ax in enumerate(axes.flat):
            if idx >= len(images): break
            img = images[idx].permute(1, 2, 0).numpy()
            img = 0.5 * img + 0.5  # unnormalize
            ax.imshow(np.clip(img, 0, 1))
            color = 'green' if labels[idx] == preds[idx] else 'red'
            ax.set_title(f"True: {class_names[labels[idx]]}\nPred: {class_names[preds[idx]]}", color=color, fontsize=8)
            ax.axis('off')
        plt.tight_layout()
        return fig

    indices = random.sample(range(len(sample_imgs)), min(30, len(sample_imgs)))
    sample_imgs_subset = [sample_imgs[i] for i in indices]
    sample_labels_subset = [all_labels[i] for i in indices]
    sample_preds_subset = [all_preds[i] for i in indices]

    fig = show_predictions(sample_imgs_subset, sample_labels_subset, sample_preds_subset)
    wandb.log({"vit_test_predictions": wandb.Image(fig)})
    plt.close(fig)

    wandb.finish()
    return accuracy


# Example usage
if __name__ == "__main__":
    evaluate_vit_test(
        model_path="/kaggle/input/vit/pytorch/default/1/vit_best_model.pth",
        test_dir="/kaggle/input/d/d4debeniitm/nature-12k/inaturalist_12K/val"
    )
