<a href="https://colab.research.google.com/github/Sai-sakunthala/Assignment2/blob/main/Assignment_2_partA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#install pytorch-lightning
!pip install pytorch-lightning

In [None]:
#import required libraries
import os
import torch
from torch import nn
import torch.nn.functional as functional
from torch.utils.data import DataLoader, random_split, Subset
import pytorch_lightning as pl
from torchvision import transforms, datasets
from collections import defaultdict
import random
from pytorch_lightning.loggers import WandbLogger
import wandb

In [None]:
class CNN(pl.LightningModule):
    """Configurable convolutional image classifier.

    The network is a stack of ``num_conv_layers`` blocks, each made of
    Conv2d -> (BatchNorm2d) -> activation -> (Dropout) -> MaxPool2d, followed
    by one hidden dense layer and a final linear classification layer that
    outputs raw logits.

    Args:
        initial_in_channels: number of channels of the input images.
        num_classes: number of output logits.
        num_conv_layers: number of convolution blocks.
        num_filters: number of filters in the first convolution block.
        kernel_size: convolution kernel size; padding is ``kernel_size // 2``.
        activation_fn: zero-argument callable returning an activation module
            (for example ``nn.SiLU``); one instance is created per use.
        dense_neurons: width of the hidden dense layer.
        learning_rate: learning rate handed to Adam.
        use_batchnorm: add batch normalization after every convolution and
            after the hidden dense layer.
        dropout_rate: dropout probability; dropout layers are only created
            when this is greater than 0.
        filter_organization: ``'same'`` keeps the filter count constant,
            ``'double'`` doubles it after every block, ``'half'`` halves it
            after every block (never going below 4). Any other value behaves
            like ``'same'``.
        data_augmentation: not used by the model itself; it is only recorded
            through ``save_hyperparameters``.
        weight_decay: weight decay handed to Adam.
        scheduler_t_max: ``T_max`` handed to ``CosineAnnealingLR``.
    """

    def __init__(self, initial_in_channels=3, num_classes=10, num_conv_layers=5, num_filters=64, kernel_size=3, activation_fn=nn.SiLU,
                 dense_neurons=256, learning_rate=1e-3, use_batchnorm=True, dropout_rate=0.3, filter_organization='same', data_augmentation = True,
                 weight_decay=5e-5, scheduler_t_max=30):

        super().__init__()
        self.save_hyperparameters()

        #dropout is only meaningful for a strictly positive rate
        #NOTE: this used to test `dropout_rate == 0`, which disabled dropout
        #for every non-zero rate
        use_dropout = dropout_rate > 0

        # initialize a list to save all convolution layers
        #NOTE: the position of each layer inside conv_block depends on whether
        #batch norm and dropout are enabled, and so do the state_dict keys
        layers_conv = []

        #number of input image channels for the first convolution
        input_channels = initial_in_channels

        #variable to track filters in current layer
        current_filters = num_filters

        #loop over number of convolution layers
        for _ in range(num_conv_layers):
            #number of output channels needed
            out_channels = current_filters

            #convolution layer with padding
            layers_conv.append(nn.Conv2d(input_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2))

            #if batch normalization is specified use it
            if use_batchnorm:
                layers_conv.append(nn.BatchNorm2d(out_channels))

            #activation layer
            layers_conv.append(activation_fn())

            #dropout is added after activation layer
            if use_dropout:
                layers_conv.append(nn.Dropout(dropout_rate))

            #maxpool layer halves the spatial resolution
            layers_conv.append(nn.MaxPool2d(kernel_size=2, stride=2))

            #update input channels
            input_channels = out_channels

            #update number of filters for following layers based on configuration
            if filter_organization == 'double':
                current_filters *= 2
            elif filter_organization == 'half':
                current_filters = max(4, current_filters // 2)

        #add all layers as convolution block
        self.conv_block = nn.Sequential(*layers_conv)

        #dense layer; LazyLinear infers its input size on the first forward pass
        self.fc1 = nn.LazyLinear(dense_neurons)
        self.bn_fc1 = nn.BatchNorm1d(dense_neurons) if use_batchnorm else None
        self.activation_dense = activation_fn()
        self.dropout_fc1 = nn.Dropout(dropout_rate) if use_dropout else None

        #final classification layer
        self.fc2 = nn.Linear(dense_neurons, num_classes)

        #optimizer settings read by configure_optimizers
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.scheduler_t_max = scheduler_t_max

    def forward(self, x):
        """Return the class logits for a batch of images."""
        #forward propagation of network
        x = self.conv_block(x)

        #flatten everything except the batch dimension
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)

        #optional layers are None when they were disabled in __init__
        if self.bn_fc1 is not None:
            x = self.bn_fc1(x)
        x = self.activation_dense(x)
        if self.dropout_fc1 is not None:
            x = self.dropout_fc1(x)
        x = self.fc2(x)
        return x

    def _loss_and_accuracy(self, batch):
        """Return (cross-entropy loss, accuracy) for one (inputs, labels) batch."""
        x, y = batch
        y_hat = self(x)
        loss = functional.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Log train_loss / train_acc and return the loss to optimize."""
        #training in batches
        loss, acc = self._loss_and_accuracy(batch)

        #log metrics
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log val_loss / val_acc for one validation batch."""
        #validation in batches
        loss, acc = self._loss_and_accuracy(batch)

        #log metrics
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log test_loss / test_acc and return both in a dict."""
        #useful for testing the model later
        loss, acc = self._loss_and_accuracy(batch)

        #log metrics
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)
        return {"test_loss": loss, "test_acc": acc}

    def configure_optimizers(self):
        """Return the Adam optimizer and its cosine-annealing scheduler."""
        #adam optimizer with weightdecay
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

        #learning rate scheduler
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.scheduler_t_max)
        return [optimizer], [scheduler]

In [None]:
#wandb sweep definition: exhaustive grid search that maximizes validation accuracy
sweep_config = {
    'method': 'grid',
    'metric': dict(name='val_acc', goal='maximize'),
    #every hyperparameter is wrapped in the {'values': [...]} form wandb expects
    'parameters': {
        hparam: {'values': list(choices)}
        for hparam, choices in (
            ('num_filters', (64,)),
            ('activation_fn', ('GELU', 'SiLU', "Mish")),
            ('filter_organization', ('same', 'double')),
            ('use_batchnorm', (True,)),
            ('dropout_rate', (0.3,)),
            ('dense_neurons', (256, 512)),
            ('learning_rate', (1e-3,)),
            ('batch_size', (64,)),
            ('data_augmentation', (True,)),
            ('kernel_size', (3,)),
        )
    },
}

In [None]:
#map an activation name from the sweep config to its torch.nn class
def get_activation(name):
    """Return the ``torch.nn`` activation class registered under ``name``.

    Supported names are "ReLU", "GELU", "SiLU" and "Mish". The class itself
    is returned, not an instance. Any other name raises ``KeyError``.
    """
    supported = (nn.ReLU, nn.GELU, nn.SiLU, nn.Mish)

    #each class is registered under its own class name
    registry = {activation_cls.__name__: activation_cls for activation_cls in supported}
    return registry[name]

#helpers used by the train function
def _stratified_split(samples, train_fraction=0.8):
    """Split sample indices into train/val lists, class by class.

    ``samples`` is a sequence of ``(item, label)`` pairs. The indices of each
    label are shuffled with the ``random`` module and the first
    ``train_fraction`` of them go to training, the rest to validation. The
    combined training indices are shuffled once more before returning.

    Returns:
        ``(train_indices, val_indices)`` as two lists of ints.
    """
    #group sample indices by class label
    class_to_indices = defaultdict(list)
    for idx, (_, label) in enumerate(samples):
        class_to_indices[label].append(idx)

    train_indices = []
    val_indices = []
    for indices in class_to_indices.values():
        random.shuffle(indices)
        split = int(train_fraction * len(indices))
        train_indices.extend(indices[:split])
        val_indices.extend(indices[split:])

    random.shuffle(train_indices)
    return train_indices, val_indices


def _build_transforms(augment, image_size=128):
    """Return ``(train_transform, val_transform)`` for square inputs.

    Both pipelines end with ``ToTensor`` and the same ``Normalize``. The
    validation pipeline only resizes. The training pipeline uses a random
    resized crop plus a random horizontal flip when ``augment`` is truthy,
    and otherwise matches the validation pipeline.
    """
    resize = transforms.Resize((image_size, image_size))
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    #non augmented data for validation
    val_transform = transforms.Compose([resize, *to_normalized_tensor])

    if augment:
        #RandomResizedCrop already produces image_size x image_size output,
        #so no extra Resize is needed afterwards
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            *to_normalized_tensor,
        ])
    else:
        train_transform = transforms.Compose([resize, *to_normalized_tensor])

    return train_transform, val_transform


#train function
def train(config=None):
    """Train one CNN for a single wandb run.

    Starts a wandb run, reads the hyperparameters from ``wandb.config``,
    builds stratified 80/20 train/validation loaders from the image folder at
    ``/root/inaturalist_12K/train``, and fits the model with a Lightning
    trainer that logs to wandb and checkpoints the best ``val_acc``.

    Args:
        config: optional configuration forwarded to ``wandb.init``.
    """
    #the with-block finishes the run on exit and receives any exception, so no
    #explicit wandb.finish() is needed (calling it in a finally clause would
    #close the run before the context manager sees the failure)
    with wandb.init(config=config) as run:
        #for reproducibility of results
        random.seed(42)
        torch.manual_seed(42)

        #load configuration from sweep
        config = wandb.config

        #rename runs
        run.name = f"{config.activation_fn}_{config.filter_organization}_r{config.dropout_rate}_fc{config.dense_neurons}"

        #wandb logger
        wandb_logger = WandbLogger(project="cnn-sweep", log_model='all')

        #augmentation of data if required (treated as off when not configured)
        use_augmentation = config.get("data_augmentation", False)
        transform, val_transform = _build_transforms(use_augmentation)

        #load training data; both views read the same folder and differ only
        #in the transform applied to each image
        data_dir = "/root/inaturalist_12K/train"
        train_source = datasets.ImageFolder(root=data_dir, transform=transform)
        val_source = datasets.ImageFolder(root=data_dir, transform=val_transform)
        num_classes = len(train_source.classes)

        #per-class 80/20 split of the sample indices
        train_indices, val_indices = _stratified_split(train_source.samples)

        #load train and val datasets
        train_dataset = Subset(train_source, train_indices)
        val_dataset = Subset(val_source, val_indices)

        train_loader = DataLoader(train_dataset, config.batch_size, shuffle=True, num_workers=2, pin_memory=True)
        val_loader = DataLoader(val_dataset, config.batch_size, shuffle=False, num_workers=2, pin_memory=True)

        #initialize model with wandb configurations
        model = CNN(
            initial_in_channels=3,
            num_classes=num_classes,
            num_conv_layers=5,
            num_filters=config.num_filters,
            kernel_size=config.kernel_size,
            activation_fn=get_activation(config.activation_fn),
            dense_neurons=config.dense_neurons,
            learning_rate=config.learning_rate,
            use_batchnorm=config.use_batchnorm,
            dropout_rate=config.dropout_rate,
            filter_organization=config.filter_organization,
            data_augmentation=use_augmentation
        )

        #keep only the checkpoint with the best validation accuracy
        callbacks = [
            pl.callbacks.ModelCheckpoint(monitor="val_acc", mode="max", save_top_k=1)
        ]

        #train model
        trainer = pl.Trainer(
            max_epochs=15,
            precision=16,
            logger=wandb_logger,
            accelerator="gpu",
            devices=1,
            callbacks=callbacks,
            gradient_clip_val=0.5
        )
        trainer.fit(model, train_loader, val_loader)

#register the grid sweep with wandb, then start an agent that calls train()
#for at most runs_per_agent configurations
sweep_project = "cnn-sweep"
runs_per_agent = 1

sweep_id = wandb.sweep(sweep_config, project=sweep_project)
wandb.agent(sweep_id, function=train, count=runs_per_agent)