In [None]:

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
import os
import subprocess
import zipfile
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassF1Score
from Dataset import Nature12KDataModule
from cnn_model import CNN



In [None]:
# Set CUDA_VISIBLE_DEVICES to '2' for this process. The device index is
# hard-coded for the original machine.
# NOTE(review): presumably this has to run before the first CUDA call to take
# effect, and index 2 only exists on hosts with at least three GPUs — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [None]:
class Nature12KDataModule(pl.LightningDataModule):
    """LightningDataModule for the nature_12K image-folder dataset.

    The on-disk layout expected by ``setup`` is ``<data_dir>/train`` and
    ``<data_dir>/test``, each containing one sub-directory per class.
    ``prepare_data`` can create that layout by downloading and unpacking the
    archive when ``data_dir`` does not exist yet.

    Args:
        data_dir: Root directory of the dataset.
        batch_size: Batch size used by all three dataloaders.
        image_size: ``(height, width)`` every image is resized to.
        data_aug: If True, random flip/rotation/colour-jitter is applied to
            the *training* subset only.
        seed: Seed for the train/validation split, so the split is identical
            across ``setup`` calls and across runs.
        num_workers: Worker processes per dataloader.
    """

    def __init__(self, data_dir="data", batch_size=64, image_size=(512, 512), data_aug=False,
                 seed=42, num_workers=2):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.image_size = image_size
        self.data_aug = data_aug
        self.seed = seed
        self.num_workers = num_workers

    def prepare_data(self):
        """Download and unpack the dataset unless ``data_dir`` already exists.

        The archive's ``inaturalist_12K`` folder is renamed to ``data_dir`` and
        its ``val`` folder is renamed to ``test``.
        """
        if os.path.exists(self.data_dir):
            print("✅ Dataset already prepared.")
            return

        zip_path = "iNaturalist.zip"
        url = "https://storage.googleapis.com/wandb_datasets/nature_12K.zip"

        if not os.path.exists(zip_path):
            print("📥 Downloading dataset...")
            # check=True raises CalledProcessError if curl fails.
            subprocess.run(["curl", "-o", zip_path, "-L", url], check=True)
        else:
            print("✅ Zip file already exists.")

        print("📦 Extracting dataset...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(".")
        os.rename("inaturalist_12K", self.data_dir)
        os.rename(os.path.join(self.data_dir, "val"), os.path.join(self.data_dir, "test"))

    @staticmethod
    def get_transform(image_size, data_aug=False):
        """Build the preprocessing pipeline.

        Args:
            image_size: Target size passed to ``transforms.Resize``.
            data_aug: If True, insert random horizontal flip, rotation (15
                degrees) and colour jitter between the resize and the tensor
                conversion.

        Returns:
            A ``transforms.Compose`` ending in ``ToTensor`` and a fixed
            per-channel ``Normalize``.
        """
        transform_list = [transforms.Resize(image_size)]

        if data_aug:
            transform_list += [
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2)
            ]

        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]

        return transforms.Compose(transform_list)

    def setup(self, stage=None):
        """Create the train / validation / test datasets.

        20% of the ``train`` folder is held out for validation. The split is
        drawn with a seeded generator, and the validation subset reads the
        images through the non-augmented transform so that validation metrics
        are not affected by random augmentation.
        """
        train_dir = os.path.join(self.data_dir, "train")
        train_transform = self.get_transform(self.image_size, self.data_aug)
        eval_transform = self.get_transform(self.image_size, False)

        # Two views over the same folder: identical sample order, different
        # transforms. Indices chosen on one view are valid on the other.
        full_train = datasets.ImageFolder(train_dir, transform=train_transform)
        full_train_eval = datasets.ImageFolder(train_dir, transform=eval_transform)
        test_set = datasets.ImageFolder(os.path.join(self.data_dir, "test"), transform=eval_transform)

        val_size = int(0.2 * len(full_train))
        train_size = len(full_train) - val_size
        split_generator = torch.Generator().manual_seed(self.seed)
        train_subset, val_subset = random_split(
            full_train, [train_size, val_size], generator=split_generator
        )

        self.train_set = train_subset
        # Re-point the held-out indices at the non-augmented view.
        self.val_set = torch.utils.data.Subset(full_train_eval, val_subset.indices)
        self.test_set = test_set
        self.class_names = full_train.classes

    def train_dataloader(self):
        """Shuffled dataloader over the training subset."""
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers)

    def val_dataloader(self):
        """Unshuffled dataloader over the validation subset."""
        return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Unshuffled dataloader over the test folder."""
        return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers)


In [None]:
# NOTE(review): this notebook defines CNN in two cells (and also imports a CNN
# from cnn_model); the definition executed last is the one in effect.
class CNN(pl.LightningModule):
    """Configurable convolutional classifier.

    The network is a stack of ``Conv2d -> [BatchNorm2d] -> activation ->
    MaxPool2d(2, 2) -> [Dropout2d]`` blocks followed by a two-layer dense head.

    Args:
        input_channels: Number of channels of the input images.
        conv_filters: Output channels of each convolutional block.
        kernel_sizes: Kernel size of each convolutional block.
        activation: One of ``'relu'``, ``'gelu'``, ``'silu'``, ``'mish'``
            (case-insensitive).
        dense_neurons: Width of the hidden dense layer.
        num_classes: Number of output logits.
        lr: Learning rate for the Adam optimizer.
        batch_norm: If True, add ``BatchNorm2d`` after every convolution.
        dropout: ``Dropout2d`` probability after every block; 0 disables it.
        input_size: ``(height, width)`` (or a single int) of the images the
            model will receive. Used only to size the first dense layer.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(self,
                 input_channels,
                 conv_filters,
                 kernel_sizes,
                 activation,
                 dense_neurons,
                 num_classes,
                 lr,
                 batch_norm=False,
                 dropout=0.0,
                 input_size=(128, 128)):
        super().__init__()
        self.save_hyperparameters()

        self.activation_fn = self._get_activation_fn(activation)

        # Convolutional feature extractor.
        # zip() stops at the shorter list, so surplus entries in either
        # conv_filters or kernel_sizes are ignored.
        layers = []
        in_channels = input_channels
        for out_channels, ksize in zip(conv_filters, kernel_sizes):
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=ksize, padding=ksize // 2))
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(self._get_activation_fn(activation))
            layers.append(nn.MaxPool2d(2, 2))
            if dropout > 0:
                layers.append(nn.Dropout2d(dropout))
            in_channels = out_channels

        self.conv_blocks = nn.Sequential(*layers)

        flatten_dim = self._infer_flatten_dim(input_channels, input_size)

        # Dense classification head.
        self.classifier = nn.Sequential(
            nn.Linear(flatten_dim, dense_neurons),
            self.activation_fn,
            nn.Linear(dense_neurons, num_classes)
        )

        self.loss_fn = nn.CrossEntropyLoss()
        self.lr = lr

    def _infer_flatten_dim(self, input_channels, input_size):
        """Return the flattened feature size produced by ``conv_blocks``.

        The probe runs in eval mode so that BatchNorm running statistics are
        not updated by the all-zero dummy batch; the previous train/eval mode
        is restored afterwards.
        """
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        height, width = input_size

        was_training = self.conv_blocks.training
        self.conv_blocks.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, input_channels, height, width)
                return self.conv_blocks(probe).flatten(1).shape[1]
        finally:
            self.conv_blocks.train(was_training)

    def _get_activation_fn(self, name):
        """Return a new activation module for ``name`` (case-insensitive)."""
        name = name.lower()
        factories = {
            'relu': nn.ReLU,
            'gelu': nn.GELU,
            'silu': nn.SiLU,
            'mish': nn.Mish,
        }
        if name not in factories:
            raise ValueError(f"Unsupported activation: {name}")
        return factories[name]()

    def forward(self, x):
        """Map a batch of images to class logits."""
        x = self.conv_blocks(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def _loss_and_accuracy(self, batch):
        """Compute cross-entropy loss and top-1 accuracy for one batch."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` / ``train_acc`` and return the loss."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc`` (also shown in the progress bar)."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` / ``test_acc`` so that test runs report metrics."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


In [None]:
# NOTE(review): this notebook defines CNN in two cells (and also imports a CNN
# from cnn_model); the definition executed last is the one in effect.
class CNN(pl.LightningModule):
    """Configurable convolutional classifier.

    The network is a stack of ``Conv2d -> [BatchNorm2d] -> activation ->
    MaxPool2d(2, 2) -> [Dropout2d]`` blocks followed by a two-layer dense head.

    Args:
        input_channels: Number of channels of the input images.
        conv_filters: Output channels of each convolutional block.
        kernel_sizes: Kernel size of each convolutional block.
        activation: One of ``'relu'``, ``'gelu'``, ``'silu'``, ``'mish'``
            (case-insensitive).
        dense_neurons: Width of the hidden dense layer.
        num_classes: Number of output logits.
        lr: Learning rate for the Adam optimizer.
        batch_norm: If True, add ``BatchNorm2d`` after every convolution.
        dropout: ``Dropout2d`` probability after every block; 0 disables it.
        input_size: ``(height, width)`` (or a single int) of the images the
            model will receive. Used only to size the first dense layer.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(self,
                 input_channels,
                 conv_filters,
                 kernel_sizes,
                 activation,
                 dense_neurons,
                 num_classes,
                 lr,
                 batch_norm=False,
                 dropout=0.0,
                 input_size=(128, 128)):
        super().__init__()
        self.save_hyperparameters()

        self.activation_fn = self._get_activation_fn(activation)

        # Convolutional feature extractor.
        # zip() stops at the shorter list, so surplus entries in either
        # conv_filters or kernel_sizes are ignored.
        layers = []
        in_channels = input_channels
        for out_channels, ksize in zip(conv_filters, kernel_sizes):
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=ksize, padding=ksize // 2))
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(self._get_activation_fn(activation))
            layers.append(nn.MaxPool2d(2, 2))
            if dropout > 0:
                layers.append(nn.Dropout2d(dropout))
            in_channels = out_channels

        self.conv_blocks = nn.Sequential(*layers)

        flatten_dim = self._infer_flatten_dim(input_channels, input_size)

        # Dense classification head.
        self.classifier = nn.Sequential(
            nn.Linear(flatten_dim, dense_neurons),
            self.activation_fn,
            nn.Linear(dense_neurons, num_classes)
        )

        self.loss_fn = nn.CrossEntropyLoss()
        self.lr = lr

    def _infer_flatten_dim(self, input_channels, input_size):
        """Return the flattened feature size produced by ``conv_blocks``.

        The probe runs in eval mode so that BatchNorm running statistics are
        not updated by the all-zero dummy batch; the previous train/eval mode
        is restored afterwards.
        """
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        height, width = input_size

        was_training = self.conv_blocks.training
        self.conv_blocks.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, input_channels, height, width)
                return self.conv_blocks(probe).flatten(1).shape[1]
        finally:
            self.conv_blocks.train(was_training)

    def _get_activation_fn(self, name):
        """Return a new activation module for ``name`` (case-insensitive)."""
        name = name.lower()
        factories = {
            'relu': nn.ReLU,
            'gelu': nn.GELU,
            'silu': nn.SiLU,
            'mish': nn.Mish,
        }
        if name not in factories:
            raise ValueError(f"Unsupported activation: {name}")
        return factories[name]()

    def forward(self, x):
        """Map a batch of images to class logits."""
        x = self.conv_blocks(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def _loss_and_accuracy(self, batch):
        """Compute cross-entropy loss and top-1 accuracy for one batch."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` / ``train_acc`` and return the loss."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc`` (also shown in the progress bar)."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` / ``test_acc`` so that test runs report metrics."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log("test_loss", loss)
        self.log("test_acc", acc)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)


In [None]:
def launch_sweep(project='iNaturalist_CNN_Sweep', count=100):
    """Create a W&B Bayesian sweep over the CNN hyperparameters and run it.

    The sweep maximizes the ``val_acc`` metric. Each trial is executed by the
    ``train`` function, which is not defined in this section of the notebook
    and must already exist in the kernel when this function is called.

    Args:
        project: W&B project the sweep is created in.
        count: Maximum number of trials the agent runs.

    Returns:
        The sweep id returned by ``wandb.sweep``.
    """
    sweep_config = {
        'method': 'bayes',
        'metric': {
            'name': 'val_acc',
            'goal': 'maximize'
        },
        'parameters': {
            'conv_filters': {
                'values': [
                    [32, 64, 128, 256, 512, 512],
                    # Alternatives that were tried/considered:
                    # [64, 128, 256, 512, 1024],
                    # [32, 32, 32, 32, 32],
                    # [64, 64, 64, 64, 64]
                ]
            },
            'kernel_sizes': {
                'values': [
                    [3, 3, 3, 3, 3, 3],
                ]
            },
            'activation': {
                'values': ['relu', 'gelu', 'silu', 'mish']
            },
            'dense_neurons': {
                'values': [256]
            },
            'lr': {
                # Sample on a log scale: a plain uniform draw over
                # [1e-4, 1e-1] puts ~90% of the trials above 1e-2.
                'distribution': 'log_uniform_values',
                'min': 0.0001,
                'max': 0.1
            },
            'batch_norm': {
                'values': [True, False]
            },
            'dropout': {
                'values': [0.2, 0.3]
            },
            'batch_size': {
                'values': [16, 32]
            },
            'data_augmentation': {
                'values': [True, False]
            }
        }
    }
    sweep_id = wandb.sweep(sweep_config, project=project)
    wandb.agent(sweep_id, function=train, count=count)
    # NOTE(review): "6mpxfky1" was left here in the original — presumably the
    # id of an earlier sweep; confirm before relying on it.
    return sweep_id

In [None]:
# Create the sweep and run the agent (up to 100 trials by default).
# Requires a `train` function to be defined in the kernel beforehand.
launch_sweep()

In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

def visualize_predictions(model, dataloader, class_names, num_images=30,
                          cols=3,
                          mean=(0.485, 0.456, 0.406),
                          std=(0.229, 0.224, 0.225),
                          save_path="test_predictions_grid.png"):
    """Plot a grid of images with their true and predicted class names.

    Args:
        model: Module returning class logits; must expose ``.device``.
        dataloader: Iterable of ``(images, labels)`` batches.
        class_names: Sequence mapping a class index to its display name.
        num_images: Maximum number of images to draw.
        cols: Number of grid columns; rows are derived from ``num_images``.
        mean: Per-channel mean used when the images were normalized, or None
            to skip de-normalization. The default matches the values in
            ``Nature12KDataModule.get_transform``.
        std: Per-channel std used when the images were normalized, or None.
        save_path: File the figure is written to; falsy to skip saving.

    The model is switched to eval mode and left in it.
    """
    model.eval()

    # Ceiling division so any num_images fits; at least one row so that
    # plt.subplots always gets a valid grid.
    rows = max(1, -(-num_images // cols))
    # squeeze=False keeps axs a 2-D array even for a 1x1 grid.
    fig, axs = plt.subplots(rows, cols, figsize=(cols * 4, rows * 2.5), squeeze=False)
    axs = axs.flatten()
    for ax in axs:
        # Hide every frame up front so unused cells stay blank.
        ax.axis('off')

    denormalize = mean is not None and std is not None
    if denormalize:
        mean_t = torch.as_tensor(mean, dtype=torch.float32).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=torch.float32).view(-1, 1, 1)

    images_shown = 0
    with torch.no_grad():
        for x, y in dataloader:
            if images_shown >= num_images:
                break

            logits = model(x.to(model.device))
            preds = logits.argmax(dim=1).cpu()

            for img, label, pred in zip(x, y, preds):
                if images_shown >= num_images:
                    break

                img = img.detach().cpu().float()
                if denormalize:
                    # Undo Normalize; without this the tensor holds negative
                    # values and is rendered with corrupted colours.
                    img = img * std_t + mean_t
                img = img.clamp(0.0, 1.0).permute(1, 2, 0).numpy()
                if img.shape[-1] == 1:
                    # imshow expects HxW for single-channel images.
                    img = img[..., 0]

                ax = axs[images_shown]
                ax.imshow(img)
                ax.set_title(f"True: {class_names[int(label)]}\nPred: {class_names[int(pred)]}")
                images_shown += 1

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()


In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

def visualize_predictions(model, dataloader, class_names, num_images=30,
                          cols=3,
                          mean=(0.485, 0.456, 0.406),
                          std=(0.229, 0.224, 0.225),
                          save_path="test_predictions_grid.png"):
    """Plot a grid of images with their true and predicted class names.

    Args:
        model: Module returning class logits; must expose ``.device``.
        dataloader: Iterable of ``(images, labels)`` batches.
        class_names: Sequence mapping a class index to its display name.
        num_images: Maximum number of images to draw.
        cols: Number of grid columns; rows are derived from ``num_images``.
        mean: Per-channel mean used when the images were normalized, or None
            to skip de-normalization. The default matches the values in
            ``Nature12KDataModule.get_transform``.
        std: Per-channel std used when the images were normalized, or None.
        save_path: File the figure is written to; falsy to skip saving.

    The model is switched to eval mode and left in it.
    """
    model.eval()

    # Ceiling division so any num_images fits; at least one row so that
    # plt.subplots always gets a valid grid.
    rows = max(1, -(-num_images // cols))
    # squeeze=False keeps axs a 2-D array even for a 1x1 grid.
    fig, axs = plt.subplots(rows, cols, figsize=(cols * 4, rows * 2.5), squeeze=False)
    axs = axs.flatten()
    for ax in axs:
        # Hide every frame up front so unused cells stay blank.
        ax.axis('off')

    denormalize = mean is not None and std is not None
    if denormalize:
        mean_t = torch.as_tensor(mean, dtype=torch.float32).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=torch.float32).view(-1, 1, 1)

    images_shown = 0
    with torch.no_grad():
        for x, y in dataloader:
            if images_shown >= num_images:
                break

            logits = model(x.to(model.device))
            preds = logits.argmax(dim=1).cpu()

            for img, label, pred in zip(x, y, preds):
                if images_shown >= num_images:
                    break

                img = img.detach().cpu().float()
                if denormalize:
                    # Undo Normalize; without this the tensor holds negative
                    # values and is rendered with corrupted colours.
                    img = img * std_t + mean_t
                img = img.clamp(0.0, 1.0).permute(1, 2, 0).numpy()
                if img.shape[-1] == 1:
                    # imshow expects HxW for single-channel images.
                    img = img[..., 0]

                ax = axs[images_shown]
                ax.imshow(img)
                ax.set_title(f"True: {class_names[int(label)]}\nPred: {class_names[int(pred)]}")
                images_shown += 1

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()


In [None]:
def train_best_model(config_overrides=None, max_epochs=10):
    """Train the CNN with the hand-picked best configuration and test it.

    Args:
        config_overrides: Optional dict whose entries replace the matching
            keys of the built-in configuration.
        max_epochs: Number of training epochs.

    Returns:
        ``(model, data_module)`` — the trained model and the data module that
        has already been prepared and set up.
    """
    # Best configuration, entered manually.
    config = {
        "activation": "silu",
        "batch_norm": True,
        "batch_size": 32,
        "data_augmentation": True,
        "dense_neurons": 256,
        "dropout": 0.3,
        "input_channels": 3,
        "lr": 0.0004165458022262786,
        "num_classes": 10,
        "conv_filters": [32, 64, 128, 128, 128],
        "kernel_sizes": [3, 3, 3, 3, 3]
    }
    if config_overrides:
        config.update(config_overrides)

    # Data. image_size stays at 128x128 because the CNN sizes its first dense
    # layer from a 128x128 dummy input by default.
    data_module = Nature12KDataModule(
        data_dir="../../inaturalist_12K",
        batch_size=config["batch_size"],
        image_size=(128, 128),
        data_aug=config["data_augmentation"]
    )
    # Called explicitly so the returned data module is ready for use
    # (dataloaders, class_names) by the caller.
    data_module.prepare_data()
    data_module.setup()

    # Model built from the selected hyperparameters.
    model = CNN(
        input_channels=config["input_channels"],
        conv_filters=config["conv_filters"],
        kernel_sizes=config["kernel_sizes"],
        activation=config["activation"],
        dense_neurons=config["dense_neurons"],
        num_classes=config["num_classes"],
        lr=config["lr"],
        batch_norm=config["batch_norm"],
        dropout=config["dropout"]
    )

    # Metrics are logged to the "best_model_eval" W&B project.
    wandb_logger = WandbLogger(project="best_model_eval", name="Best_CNN_Model")

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices="auto",
        log_every_n_steps=10,
        logger=wandb_logger
    )

    print("🚀 Training best model...")
    trainer.fit(model, data_module.train_dataloader(), data_module.val_dataloader())

    print("🧪 Evaluating best model on test set...")
    trainer.test(model, data_module.test_dataloader())

    return model, data_module


In [None]:
# Train and test the best configuration, keeping the model and data module.
model, data_module = train_best_model()

# Draw a grid of test images labelled with true and predicted class names.
visualize_predictions(model, data_module.test_dataloader(), data_module.class_names)
