<a href="https://colab.research.google.com/github/Sai-sakunthala/Assignment2/blob/main/Assignment_2_partB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install lightning

In [None]:
#import required libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
import random
from collections import defaultdict
from torch.utils.data import Subset
from torchvision.models import efficientnet_v2_m, EfficientNet_V2_M_Weights

In [None]:
class FineTunedModel(pl.LightningModule):
    """EfficientNetV2-M image classifier fine-tuned with gradual unfreezing.

    The first ``freeze_k`` blocks of ``model.features`` start frozen. Every
    ``unfreeze_every`` epochs one more frozen block is made trainable again,
    starting with the frozen block closest to the classifier and moving
    towards the input.

    Logged metrics: ``train_loss``, ``train_accuracy``, ``val_loss`` and
    ``val_accuracy``.

    Args:
        num_classes: Number of output classes of the new classification layer.
        freeze_k: Number of leading feature blocks to freeze initially.
            Values outside ``[0, len(model.features)]`` are clamped.
        unfreeze_every: Number of epochs between two unfreezing steps.
            A value ``<= 0`` disables unfreezing.
        dropout_prob: Dropout probability applied right before the new
            linear classification layer.
    """

    def __init__(self, num_classes=10, freeze_k = 2, unfreeze_every=2, dropout_prob = 0.4):
        super().__init__()

        #loading EfficientNet_V2_M model with ImageNet weights
        self.model = efficientnet_v2_m(weights=EfficientNet_V2_M_Weights.IMAGENET1K_V1)

        #total number of blocks
        self.total_blocks = len(self.model.features)

        #number of leading blocks that are currently frozen (kept in sync
        #with requires_grad by on_train_epoch_start)
        self.freeze_k = max(0, min(freeze_k, self.total_blocks))

        #after how many epochs the unfreezing happens
        self.unfreeze_every = unfreeze_every

        #freeze the first freeze_k blocks
        for i, block in enumerate(self.model.features):
            if i < self.freeze_k:
                for param in block.parameters():
                    param.requires_grad = False

        #replace the final linear layer by dropout + a new linear layer with
        #num_classes outputs; classifier[0] is left untouched
        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        """Return the class logits for a batch of images."""
        return self.model(x)

    def configure_optimizers(self):
        """Return Adam (with weight decay) and a cosine-annealing LR scheduler."""
        #all parameters are handed to the optimizer, including the frozen
        #ones: parameters without a gradient are skipped by Adam, so blocks
        #unfrozen later are picked up without rebuilding the optimizer
        optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-5, weight_decay=5e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
        return [optimizer], [scheduler]

    def _compute_loss_and_accuracy(self, batch):
        """Run the model on one ``(images, labels)`` batch.

        Returns:
            Tuple ``(loss, accuracy)`` with the mean cross-entropy loss and
            the fraction of correctly classified samples in the batch.
        """
        images, labels = batch
        outputs = self.model(images)
        loss = nn.functional.cross_entropy(outputs, labels)
        acc = (outputs.argmax(dim=1) == labels).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy of one batch."""
        loss, acc = self._compute_loss_and_accuracy(batch)

        #log metrics
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_accuracy', acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/accuracy of one batch."""
        loss, acc = self._compute_loss_and_accuracy(batch)

        #log metrics
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_accuracy', acc, prog_bar=True)
        return loss

    def on_train_epoch_start(self):
        """Unfreeze one more frozen block every ``unfreeze_every`` epochs."""
        #nothing to do when unfreezing is disabled or nothing is frozen
        if self.unfreeze_every <= 0 or self.freeze_k <= 0:
            return

        #epoch 0 keeps the initial freeze; afterwards act only on multiples
        #of unfreeze_every
        if self.current_epoch == 0 or self.current_epoch % self.unfreeze_every != 0:
            return

        #the deepest frozen block has index freeze_k - 1; unfreeze it and
        #shrink the frozen prefix by one
        self.freeze_k -= 1
        for param in self.model.features[self.freeze_k].parameters():
            param.requires_grad = True

In [None]:
def train():
    """Fine-tune ``FineTunedModel`` on the iNaturalist 12K training folder.

    Steps:
      1. Seed ``random`` and ``torch`` for a reproducible split.
      2. Start a wandb run and attach a ``WandbLogger``.
      3. Build a stratified 80/20 train/validation split (per class) of the
         images found under ``/root/inaturalist_12K/train``.
      4. Train for up to 25 epochs on one GPU with 16-bit precision, keeping
         the checkpoint with the best ``val_accuracy``.

    The wandb run is always finished, even if data loading or training fails.
    """
    #for reproducibility
    random.seed(42)
    torch.manual_seed(42)

    #initialize wandb project
    wandb.init(project="inaturalist_finetune", name="efficient_net_4")

    try:
        wandb_logger = WandbLogger(project="inaturalist_finetune", name="efficient_net_4")

        #side length of the square input images, shared by train and validation
        image_size = 224
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        #augment data and resize to fit the efficientnet image dimensions
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ])

        #non augmented data for validation, at the same resolution as training
        val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            normalize,
        ])

        #load training data (no transform: only used to read labels)
        data_dir = "/root/inaturalist_12K/train"
        full_dataset = datasets.ImageFolder(root=data_dir)
        num_classes = len(full_dataset.classes)

        #group sample indices by class label
        class_to_indices = defaultdict(list)
        for idx, (_, label) in enumerate(full_dataset.samples):
            class_to_indices[label].append(idx)

        #list for splitting to train and val indices
        train_indices = []
        val_indices = []

        #stratified split: 80% of every class for training, the rest for validation
        for indices in class_to_indices.values():
            random.shuffle(indices)
            split = int(0.8 * len(indices))
            train_indices.extend(indices[:split])
            val_indices.extend(indices[split:])

        random.shuffle(train_indices)

        #load train and val datasets; the same folder is wrapped twice so
        #that each subset gets its own transform
        train_dataset = Subset(datasets.ImageFolder(root=data_dir, transform=transform), train_indices)
        val_dataset = Subset(datasets.ImageFolder(root=data_dir, transform=val_transform), val_indices)

        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2, pin_memory=True)

        #initialize model with our required configurations
        model = FineTunedModel(num_classes, 2, 2, 0.4)

        #add callbacks; the monitored key must match the name logged by the
        #model in validation_step ('val_accuracy')
        callbacks = [
            #optional: pl.callbacks.EarlyStopping(monitor="val_accuracy", mode="max", patience=5),
            pl.callbacks.ModelCheckpoint(monitor="val_accuracy", mode="max", save_top_k=1)
        ]

        #train model
        trainer = pl.Trainer(
            max_epochs=25,
            precision=16,
            logger=wandb_logger,
            accelerator="gpu",
            devices=1,
            callbacks=callbacks,
            gradient_clip_val=0.5
        )
        trainer.fit(model, train_loader, val_loader)
    finally:
        #close the wandb run no matter where the pipeline stopped
        wandb.finish()

#run the full fine-tuning pipeline (data split, model setup, training) when this cell is executed
train()