<a href="https://colab.research.google.com/github/Sai-sakunthala/Assignment2/blob/main/Assignment_2_partB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install lightning

Collecting lightning
  Downloading lightning-2.5.1-py3-none-any.whl.metadata (39 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.5.1-py3-none-any.whl.metadata (20 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Co

In [9]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
import random
from collections import defaultdict
from torch.utils.data import Subset
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights

In [3]:
!unzip -q '/content/drive/MyDrive/nature_12K.zip' -d '/content/inaturalist_data'

In [10]:
class FineTunedModel(pl.LightningModule):
    """ImageNet-pretrained EfficientNetV2-S with a new final linear layer for `num_classes`.

    Args:
        num_classes: width of the new classification head.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-3).
        freeze_backbone: if True, the parameters of the convolutional feature
            extractor (`model.features`) are frozen and only the remaining layers
            are optimized. BatchNorm running statistics still update in train mode.
            Default False trains the whole network, as before.

    Logged metrics: `train_loss`, `train_accuracy`, `val_loss`, `val_accuracy`.
    """

    def __init__(self, num_classes=10, lr=1e-3, freeze_backbone=False):
        super().__init__()
        # Record constructor args in self.hparams and in checkpoints, so
        # FineTunedModel.load_from_checkpoint(path) can rebuild the model
        # without the caller re-supplying num_classes.
        self.save_hyperparameters()

        # torchvision loads (downloading/caching on first use) the ImageNet weights.
        self.model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)

        # Index 1 of torchvision's EfficientNet classifier is the final Linear layer;
        # replace it with a freshly initialised head of the requested width.
        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features, num_classes)

        if freeze_backbone:
            for param in self.model.features.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Return the network's raw (unnormalised) logits for input batch `x`."""
        return self.model(x)

    def configure_optimizers(self):
        """Adam over the trainable parameters only (all of them unless the backbone is frozen)."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=self.hparams.lr)

    def _shared_step(self, batch, stage):
        """Compute cross-entropy loss and accuracy for one batch and log them as `<stage>_*`.

        `batch` is an (images, labels) pair; labels are integer class indices.
        """
        images, labels = batch
        logits = self(images)
        loss = nn.functional.cross_entropy(logits, labels)
        acc = (logits.argmax(dim=1) == labels).float().mean()
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_accuracy", acc, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

In [8]:
def _stratified_split(targets, train_frac, seed):
    """Split dataset indices per class so each class contributes `train_frac` of its items to train.

    Args:
        targets: sequence of integer class labels, one per dataset item (position = index).
        train_frac: fraction of each class placed in the training split, in (0, 1).
        seed: seed for a private RNG, so the split depends only on `seed` and
            `targets`, not on whatever else has consumed the global `random` state.

    Returns:
        (train_indices, val_indices) as lists of dataset indices.
    """
    if not 0.0 < train_frac < 1.0:
        raise ValueError(f"train_frac must be in (0, 1), got {train_frac}")

    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)

    train_indices, val_indices = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        cut = int(train_frac * len(indices))
        train_indices.extend(indices[:cut])
        val_indices.extend(indices[cut:])
    return train_indices, val_indices


def train(
    data_dir="/content/inaturalist_data/inaturalist_12K/train",
    project="inaturalist_finetune",
    run_name="efficientnet_v2_s_run",
    batch_size=64,
    max_epochs=15,
    train_frac=0.8,
    seed=42,
    early_stopping_patience=None,
):
    """Fine-tune `FineTunedModel` on an ImageFolder dataset with a stratified train/val split.

    Every argument has a default, so `train()` can still be called with no arguments
    (e.g. as a wandb sweep function).

    Args:
        data_dir: ImageFolder root (one sub-directory per class).
        project, run_name: wandb project and run name.
        batch_size: batch size for both loaders.
        max_epochs: Trainer epoch limit.
        train_frac: per-class fraction of images used for training; the rest validate.
        seed: global seed (python/numpy/torch) and seed of the split RNG.
        early_stopping_patience: if not None, stop when `val_loss` has not improved
            for this many validation checks.

    Returns:
        Path of the best checkpoint by `val_loss` (empty string if none was saved).
    """
    # Seeds python `random`, numpy and torch in one call.
    pl.seed_everything(seed)

    # NOTE(review): wandb.init is called explicitly (WandbLogger then reuses the
    # active run) — presumably so this also works as a wandb.agent target; confirm.
    wandb.init(project=project, name=run_name)
    wandb_logger = WandbLogger(project=project, name=run_name)

    # Everything after wandb.init is inside try/finally so the run is always
    # finished, even if dataset loading or Trainer setup fails.
    try:
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # ImageNet channel statistics, matching the pretrained backbone.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        full_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
        num_classes = len(full_dataset.classes)

        train_indices, val_indices = _stratified_split(full_dataset.targets, train_frac, seed)

        # shuffle=True re-orders the training subset every epoch, so no pre-shuffle is needed.
        train_loader = DataLoader(
            Subset(full_dataset, train_indices),
            batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True,
        )
        val_loader = DataLoader(
            Subset(full_dataset, val_indices),
            batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True,
        )

        model = FineTunedModel(num_classes)

        checkpoint_cb = pl.callbacks.ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
        callbacks = [checkpoint_cb]
        if early_stopping_patience is not None:
            callbacks.append(
                pl.callbacks.EarlyStopping(monitor="val_loss", patience=early_stopping_patience)
            )

        trainer = pl.Trainer(
            max_epochs=max_epochs,
            precision="16-mixed",  # explicit name for what precision=16 maps to
            logger=wandb_logger,
            accelerator="gpu",
            devices=1,
            callbacks=callbacks,
            gradient_clip_val=0.5,
        )
        trainer.fit(model, train_loader, val_loader)
        return checkpoint_cb.best_model_path
    finally:
        wandb.finish()