<a href="https://colab.research.google.com/github/AndoorAlanD/DA6401-Assignment-2/blob/main/PartA/Part_A_Question_4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import wandb
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict

In [None]:
import wandb
import os

# SECURITY: never commit a W&B API key to a notebook. Any key that has been
# published this way should be revoked from the W&B account settings.
# wandb.login() picks the key up from the WANDB_API_KEY environment variable
# (e.g. injected through the platform's secret store) and otherwise prompts for
# it interactively, so no literal key is needed here.
wandb.login()


[34m[1mwandb[0m: Currently logged in as: [33malandandoor[0m ([33malandandoor-iit-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Sweep definition for the 'Best_Model' run. Every hyperparameter lists exactly
# one candidate value, so each trial of this sweep trains the same configuration.
# The sweep is set to maximise the logged 'val_accuracy' metric.
sweep_config = dict(
    name='Best_Model',
    method='bayes',
    metric=dict(name='val_accuracy', goal='maximize'),
    parameters={
        hyperparameter: {'values': [chosen]}
        for hyperparameter, chosen in (
            ('kernel_size', [3, 3, 3, 3, 3]),
            ('num_epochs', 10),
            ('dropout', 0.3),
            ('lr', 0.001),
            ('activation', 'Mish'),
            ('optimizer', 'adam'),
            ('batch_norm', 'true'),
            ('filt_org', 'same'),
            ('num_filters', 64),
            ('data_aug', 'true'),
            ('batch_size', 128),
            ('num_dense', 128),
        )
    },
)


# Register the sweep defined by `sweep_config` with W&B under the 'DL_A2' project
# and keep the returned sweep ID for the agent call.
sweep_id = wandb.sweep(sweep=sweep_config, project='DL_A2')

Create sweep with ID: 5s6ki4zb
Sweep URL: https://wandb.ai/alandandoor-iit-madras/DL_A2/sweeps/5s6ki4zb


In [None]:
class CNN(nn.Module):
    """Convolutional image classifier bundled with its data pipeline and training loop.

    The network is a stack of conv -> activation -> (optional batch norm) ->
    2x2 max-pool blocks (one block per entry of ``config.kernel_size``),
    followed by one dense layer with dropout and a linear output layer.

    ``config`` must expose the attributes ``num_epochs``, ``kernel_size``,
    ``dropout``, ``lr``, ``activation``, ``optimizer``, ``batch_norm``,
    ``filt_org``, ``num_filters``, ``data_aug``, ``batch_size`` and
    ``num_dense``. ``batch_norm`` and ``data_aug`` are the strings
    ``'true'``/anything else, not booleans.

    Constructing an instance builds the transforms, loads the datasets from
    ``TRAIN_DIR``/``TEST_DIR``, builds the layers, moves them to the
    module-level ``device`` and creates the loss and optimizer.
    """

    # Dataset locations used by this notebook.
    TRAIN_DIR = '/kaggle/input/dl-assignment-2/inaturalist_12K/train'
    TEST_DIR = '/kaggle/input/dl-assignment-2/inaturalist_12K/val'
    # Side length every image is resized to; build_model() derives the size of
    # the flattened feature map from it.
    IMAGE_SIZE = 256

    def __init__(self, config, num_classes=10):
        super(CNN, self).__init__()
        self.config = config
        self.num_epochs = config.num_epochs

        self.build_transforms()
        self.prepare_data()
        self.build_model(num_classes)
        # Move to the target device only once the layers exist: calling .to()
        # on a module that has no parameters yet moves nothing.
        self.to(device)
        self.build_training_utils()

    def build_transforms(self):
        """Create the plain (``self.transform``) and augmented (``self.transform_aug``) pipelines.

        Both resize to IMAGE_SIZE x IMAGE_SIZE, convert to a tensor and
        normalise with mean 0.5 / std 0.5; the augmented one additionally
        applies random flips, a rotation of up to 20 degrees and colour jitter.
        """
        size = (self.IMAGE_SIZE, self.IMAGE_SIZE)

        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ])
        self.transform_aug = transforms.Compose([
            transforms.Resize(size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(20),
            transforms.ColorJitter(),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ])

    def prepare_data(self, val_fraction=0.2):
        """Build the train/validation/test datasets and their loaders.

        Args:
            val_fraction: share of the training folder held out for
                validation. The split sizes are derived from the actual
                dataset length instead of being hard-coded.

        The training subset uses the augmented transform when
        ``config.data_aug == 'true'``. The validation subset always uses the
        plain transform, so validation metrics are not computed on randomly
        flipped/rotated/jittered images.
        """
        use_aug = self.config.data_aug == 'true'
        train_transform = self.transform_aug if use_aug else self.transform

        train_source = torchvision.datasets.ImageFolder(
            root=self.TRAIN_DIR,
            transform=train_transform
        )
        if use_aug:
            # Second view of the same folder with the non-augmented transform;
            # the two views are indexed with the same permutation below.
            val_source = torchvision.datasets.ImageFolder(
                root=self.TRAIN_DIR,
                transform=self.transform
            )
        else:
            val_source = train_source

        n_total = len(train_source)
        n_val = int(round(n_total * val_fraction))
        shuffled = torch.randperm(n_total).tolist()
        # Disjoint index sets: the first n_val go to validation, the rest to training.
        self.val_dataset = torch.utils.data.Subset(val_source, shuffled[:n_val])
        self.train_dataset = torch.utils.data.Subset(train_source, shuffled[n_val:])

        batch_size = self.config.batch_size
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
        # Evaluation metrics do not depend on sample order, so no shuffling here.
        self.val_loader = torch.utils.data.DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)

        self.test_dataset = torchvision.datasets.ImageFolder(
            root=self.TEST_DIR,
            transform=self.transform
        )
        # Kept shuffled: visualize_predictions() stops as soon as it has three
        # samples of every class, which needs batches that mix the classes.
        self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_size=batch_size, shuffle=True)

    def build_model(self, num_classes):
        """Create the convolutional blocks, the dense head and the activation function.

        Layers are registered as ``convL1``/``batN1`` ... ``convLk``/``batNk``
        (one pair per entry of ``config.kernel_size``), then ``maxPool``,
        ``f_Conn``, ``batN_de``, ``dropout`` and ``opL``.
        """
        # Multiplier applied to the filter count from one conv layer to the next.
        filt_scale = {'half': 0.5, 'double': 2}
        self.filt_size = filt_scale.get(self.config.filt_org, 1)

        kernel_sizes = list(self.config.kernel_size)
        self.num_conv_layers = len(kernel_sizes)

        in_channels = 3
        out_channels = self.config.num_filters
        side = self.IMAGE_SIZE
        for idx, k in enumerate(kernel_sizes, start=1):
            if idx > 1:
                out_channels = int(out_channels * self.filt_size)
            setattr(self, f'convL{idx}', nn.Conv2d(in_channels, out_channels, k, stride=1, padding=1))
            setattr(self, f'batN{idx}', nn.BatchNorm2d(out_channels))
            in_channels = out_channels
            # Conv with stride 1 and padding 1 gives side - k + 3; the 2x2 pool halves it.
            side = (side - k + 3) // 2

        self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # Number of features after flattening the last pooled feature map.
        self.x_shape = out_channels * side * side

        self.f_Conn = nn.Linear(self.x_shape, self.config.num_dense)
        self.batN_de = nn.BatchNorm1d(self.config.num_dense)
        self.dropout = nn.Dropout(p=self.config.dropout)
        self.opL = nn.Linear(self.config.num_dense, num_classes)

        # Any name other than the three listed falls back to Mish.
        activations = {'ReLU': F.relu, 'SiLU': F.silu, 'GELU': F.gelu}
        self.activation = activations.get(self.config.activation, F.mish)

    def build_training_utils(self):
        """Create the cross-entropy loss and the optimizer named by ``config.optimizer``.

        Raises:
            KeyError: if ``config.optimizer`` is neither ``'adam'`` nor ``'nadam'``.
        """
        self.criterion = nn.CrossEntropyLoss()
        optimizers = {
            'adam': optim.Adam,
            'nadam': optim.NAdam
        }
        self.optimizer = optimizers[self.config.optimizer](self.parameters(), lr=self.config.lr)

    def forward(self, x):
        """Return the class logits for a batch of IMAGE_SIZE x IMAGE_SIZE RGB images."""
        use_bn = self.config.batch_norm == 'true'

        for idx in range(1, self.num_conv_layers + 1):
            x = self.activation(getattr(self, f'convL{idx}')(x))
            if use_bn:
                x = getattr(self, f'batN{idx}')(x)
            x = self.maxPool(x)

        x = x.view(-1, self.x_shape)
        x = self.activation(self.f_Conn(x))
        if use_bn:
            x = self.batN_de(x)
        x = self.dropout(x)
        return self.opL(x)

    def _model_device(self):
        """Return the device the model's parameters currently live on."""
        return next(self.parameters()).device

    def accuracy(self, loader):
        """Evaluate the model on ``loader``.

        Args:
            loader: iterable of ``(inputs, labels)`` batches.

        Returns:
            ``(accuracy, loss)`` where accuracy is a fraction in [0, 1] and
            loss is the per-sample mean cross-entropy. ``(0.0, 0.0)`` is
            returned for an empty loader.

        The model is put in eval mode for the evaluation and switched back to
        train mode before returning.
        """
        run_device = self._model_device()
        correct, seen, loss_sum = 0, 0, 0.0
        self.eval()
        with torch.no_grad():
            for x, y_act in loader:
                x, y_act = x.to(run_device), y_act.to(run_device)
                logits = self(x)
                n_batch = y_act.size(0)
                seen += n_batch
                correct += (logits.argmax(dim=1) == y_act).sum().item()
                # criterion returns the batch mean; weight by batch size so a
                # smaller last batch does not skew the average.
                loss_sum += self.criterion(logits, y_act).item() * n_batch

        self.train()
        if seen == 0:
            return 0.0, 0.0
        return correct / seen, loss_sum / seen

    def visualize_predictions(self):
        """Log a 10 x 3 grid of test images with actual and predicted class to W&B.

        Up to three test samples per class are collected; titles are green for
        correct predictions and red otherwise. The model's train/eval mode is
        restored afterwards.
        """
        class_names = {
            0: 'Amphibia', 1: 'Animalia', 2: 'Arachnida', 3: 'Aves', 4: 'Fungi',
            5: 'Insecta', 6: 'Mammalia', 7: 'Mollusca', 8: 'Plantae', 9: 'Reptilia'
        }
        rows, cols = 10, 3
        class_samples = defaultdict(list)

        def grid_is_full():
            return all(len(class_samples[i]) == cols for i in range(rows))

        run_device = self._model_device()
        was_training = self.training
        self.eval()

        fig, axes = plt.subplots(rows, cols, figsize=(9, 30))
        fig.tight_layout(pad=3.0)

        with torch.no_grad():
            for x, y in self.test_loader:
                x, y = x.to(run_device), y.to(run_device)
                preds = torch.argmax(self(x), dim=1)

                for img, label, pred in zip(x, y, preds):
                    label_id = label.item()
                    if len(class_samples[label_id]) < cols:
                        class_samples[label_id].append((img.cpu(), label_id, pred.item()))
                    if grid_is_full():
                        break
                if grid_is_full():
                    break

        for class_id in range(rows):
            for j in range(cols):
                ax = axes[class_id, j]
                ax.axis('off')
                if j >= len(class_samples[class_id]):
                    continue

                img, gt, pred = class_samples[class_id][j]
                # Undo Normalize(mean=0.5, std=0.5) and clamp away rounding
                # error so imshow receives values in [0, 1].
                img = (img * 0.5 + 0.5).clamp(0.0, 1.0).permute(1, 2, 0).numpy()

                ax.imshow(img)
                actual = class_names.get(gt, str(gt))
                predicted = class_names.get(pred, str(pred))
                title_color = "green" if gt == pred else "red"
                ax.set_title(f"Actual Class: {actual}\n Predicted Class: {predicted}", color=title_color, fontsize=9)

        fig.suptitle("Class-wise Sample Predictions (3 per class)", fontsize=16, fontweight='bold', y=1.01)
        plt.subplots_adjust(hspace=0.8)
        wandb.log({"Sample Prediction Grid (3 Images/Class)": wandb.Image(fig)})
        plt.close(fig)

        self.train(was_training)

    def _train_one_epoch(self, epoch=0):
        """Run one optimisation pass over ``self.train_loader``.

        Args:
            epoch: zero-based epoch index, used only in the progress message.

        Returns:
            ``(loss, accuracy)``: per-sample mean training loss and the
            fraction of correctly classified training samples. Both are
            divided by the number of samples actually seen, not by
            ``len(loader) * batch_size``, which over-counts whenever the last
            batch is partial.
        """
        run_device = self._model_device()
        total_steps = len(self.train_loader)
        loss_sum, correct, seen = 0.0, 0, 0

        self.train()
        for step, (x, y_act) in enumerate(self.train_loader, start=1):
            x, y_act = x.to(run_device), y_act.to(run_device)
            logits = self(x)
            loss = self.criterion(logits, y_act)

            n_batch = y_act.size(0)
            seen += n_batch
            correct += (logits.argmax(dim=1) == y_act).sum().item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            loss_sum += loss.item() * n_batch

            if step % 25 == 0:
                print(f"Epoch [{epoch + 1}/{self.num_epochs}]| Step [{step}/{total_steps}]")

        if seen == 0:
            return 0.0, 0.0
        return loss_sum / seen, correct / seen

    def train_model(self):
        """Train for ``self.num_epochs`` epochs, logging metrics to W&B after each one.

        After every epoch the train, validation and test accuracy (in percent)
        and loss are printed and logged; once training ends the sample
        prediction grid is logged as well.
        """
        for epoch in range(self.num_epochs):
            tr_loss, tr_acc = self._train_one_epoch(epoch)

            val_acc, val_loss = self.accuracy(self.val_loader)
            test_acc, test_loss = self.accuracy(self.test_loader)

            tr_acc *= 100
            val_acc *= 100
            test_acc *= 100

            print("Train Accuracy:", tr_acc, "\nTrain Loss:", tr_loss)
            print("Validation Accuracy:", val_acc, "\nValidation Loss:", val_loss)
            print("Test Accuracy:", test_acc, "\nTest Loss:", test_loss, "\n")
            wandb.log({
                'train_accuracy': tr_acc,
                'train_loss': tr_loss,
                'val_accuracy': val_acc,
                'val_loss': val_loss,
                'test_accuracy': test_acc,
                'test_loss': test_loss,
            })
        self.visualize_predictions()


In [None]:
def main():
    """Run one sweep trial: start a W&B run, name it after its hyperparameters and train."""
    with wandb.init() as run:
        cfg = wandb.config

        # Flags are stored as the strings 'true'/'false'; show them as 1/0 in the run name.
        bn_flag = 1 if cfg.batch_norm == 'true' else 0
        aug_flag = 1 if cfg.data_aug == 'true' else 0
        # Concatenate the five per-layer kernel sizes, e.g. [3,3,3,3,3] -> "33333".
        kernel_tag = ''.join(str(cfg.kernel_size[idx]) for idx in range(5))

        name_parts = [
            f"{cfg.activation}",
            f"{cfg.optimizer}",
            f"bn_{bn_flag}",
            f"da_{aug_flag}",
            f"do_{cfg.dropout}",
            f"bs_{cfg.batch_size}",
            f"lr_{cfg.lr}",
            f"f_{cfg.num_filters}",
            f"{cfg.filt_org}",
            f"ks_{kernel_tag}",
            f"fc_{cfg.num_dense}",
        ]
        wandb.run.name = "-".join(name_parts)

        network = CNN(cfg, num_classes=10).to(device)
        network.train_model()

# The sweep uses the 'bayes' search method, which does not stop on its own, and
# its search space holds a single configuration. Without `count` the agent would
# keep launching identical trials until interrupted, so cap it at one run.
wandb.agent(sweep_id, function=main, count=1)
wandb.finish()