In [12]:
import os
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.optim import Adam
import torch.optim as optim  # for NAdam
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
import matplotlib.pyplot as plt
from pytorch_lightning.loggers import WandbLogger
from sklearn.model_selection import StratifiedShuffleSplit
from pytorch_lightning.callbacks import EarlyStopping

# -------------------------------
# Helper: Activation Function Getter
# -------------------------------
def get_activation(activation):
    """
    Resolve an activation specifier into a callable.

    Args:
        activation: Either a name (matched case-insensitively) or any other
            object. Recognised names are "none", "relu", "tanh", "sigmoid",
            "leaky_relu", "gelu", "silu", "mish" and "selu".

    Returns:
        - For "none": an identity function.
        - For a recognised name: the matching torch function
          ("mish" is computed as x * tanh(softplus(x))).
        - For an unrecognised name: F.relu (silent fallback, no error).
        - For a non-string argument: the argument itself, unchanged.
    """
    # Anything that is not a string is assumed to already be usable as-is.
    if not isinstance(activation, str):
        return activation

    name = activation.lower()
    if name == "none":
        return lambda x: x  # Identity function (no activation)

    by_name = {
        'relu': F.relu,
        'tanh': torch.tanh,
        'sigmoid': torch.sigmoid,
        'leaky_relu': F.leaky_relu,
        'gelu': F.gelu,
        'silu': F.silu,
        'mish': lambda x: x * torch.tanh(F.softplus(x)),
        'selu': F.selu
    }
    if name in by_name:
        return by_name[name]
    # Unknown names deliberately degrade to ReLU rather than raising.
    return F.relu

# -------------------------------
# Model Definition: iNaturalistCNN
# -------------------------------
class iNaturalistCNN(pl.LightningModule):
    """
    Five-block convolutional image classifier packaged as a LightningModule.

    Every block is conv -> (optional batch norm) -> activation -> 2x2 max-pool
    -> 2D dropout. The flattened feature map then passes through one hidden
    dense layer (with dropout) and a final linear layer that emits raw logits.
    """

    def __init__(self, num_classes=10, learning_rate=0.001, input_size=128,
                 conv_configs=[(32, 3), (64, 3), (128, 3), (256, 3), (512, 3)],
                 dense_neurons=512,
                 dropout_rate=0.5,
                 conv_activation='relu',
                 dense_activation='relu',
                 use_batch_norm=False,
                 optimizer_type="adam",
                 weight_decay=1e-4):
        """
        Args:
            num_classes (int): Number of target classes (size of the logit vector).
            learning_rate (float): Learning rate.
            input_size (int): Side length of the square input image (e.g. 224).
            conv_configs (list of tuples): List of 5 tuples, each (num_filters, kernel_size).
                The default list is only read, never mutated.
            dense_neurons (int): Number of neurons in the dense layer.
            dropout_rate (float): Dropout probability (used after every conv
                block and after the dense layer).
            conv_activation (str/callable): Activation for conv layers.
            dense_activation (str/callable): Activation for the dense layer.
            use_batch_norm (bool): Whether to include batch normalization.
            optimizer_type (str): Type of optimizer: "adam" or "nadam".
            weight_decay (float): Weight decay passed to the optimizer.

        Raises:
            ValueError: If conv_configs does not contain exactly five entries.
        """
        super().__init__()
        # Activations may be arbitrary callables, so they are kept out of the
        # saved hyperparameters; everything else (including weight_decay) is
        # stored on self.hparams.
        self.save_hyperparameters(ignore=['conv_activation', 'dense_activation'])
        self.use_batch_norm = use_batch_norm

        # Set activation functions.
        self.conv_activation = get_activation(conv_activation)
        self.dense_activation = get_activation(dense_activation)
        self.dropout_rate = dropout_rate

        # Build the convolutional layers.
        if len(conv_configs) != 5:
            raise ValueError("conv_configs must be a list of exactly five tuples.")
        self.conv_layers = nn.ModuleList()
        if self.use_batch_norm:
            self.bn_layers = nn.ModuleList()
        in_channels = 3  # for RGB images
        for out_channels, k_size in conv_configs:
            self.conv_layers.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=k_size,
                    stride=1,
                    padding=k_size // 2  # keeps spatial dims for odd kernel sizes
                )
            )
            if self.use_batch_norm:
                self.bn_layers.append(nn.BatchNorm2d(out_channels))
            in_channels = out_channels

        # Common max-pooling and dropout for convolutional blocks.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout_conv = nn.Dropout2d(p=dropout_rate)

        # Work out the flattened feature size with a dummy image. Batch norm
        # and dropout are skipped on purpose: they never change the shape, and
        # running batch norm in training mode here would fold an all-zero
        # batch into its running statistics before any real data is seen.
        # no_grad avoids building an autograd graph for a throwaway pass.
        with torch.no_grad():
            probe = torch.zeros(1, 3, input_size, input_size)
            for conv in self.conv_layers:
                probe = self.pool(self.conv_activation(conv(probe)))
        self.feature_dim = probe.view(1, -1).size(1)

        # Fully-connected (dense) layers.
        self.fc1 = nn.Linear(self.feature_dim, dense_neurons)
        self.dropout_fc = nn.Dropout(p=dropout_rate)
        # Output width follows num_classes instead of a hard-coded 10.
        self.fc2 = nn.Linear(dense_neurons, num_classes)

    def forward(self, x):
        """Map a batch of images to raw class logits (no softmax applied)."""
        # Convolutional blocks: conv -> (optional BN) -> activation -> pool -> dropout.
        for i, conv in enumerate(self.conv_layers):
            x = conv(x)
            if self.use_batch_norm:
                x = self.bn_layers[i](x)
            x = self.conv_activation(x)
            x = self.pool(x)
            x = self.dropout_conv(x)
        x = x.view(x.size(0), -1)  # Flatten the features.
        x = self.dense_activation(self.fc1(x))
        x = self.dropout_fc(x)
        x = self.fc2(x)  # Raw logits (F.cross_entropy applies softmax internally)
        return x

    def _loss_and_accuracy(self, batch):
        """Run one (inputs, labels) batch and return (cross-entropy loss, accuracy)."""
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute the training loss; logs "train_loss" and epoch-level "train_acc"."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss; logs "val_loss" and "val_acc" per epoch."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Compute the test loss; logs "test_loss" and "test_acc" per epoch."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log("test_loss", loss, on_epoch=True, on_step=False)
        self.log("test_acc", acc, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """
        Build the optimizer selected by the "optimizer_type" hyperparameter.

        Raises:
            ValueError: If optimizer_type is neither "adam" nor "nadam".
        """
        optimizer_type = self.hparams.optimizer_type.lower()
        # getattr keeps a fallback in case hparams lacks the key (for example
        # when hparams were saved before weight_decay became a constructor
        # argument).
        weight_decay = getattr(self.hparams, "weight_decay", 1e-4)
        if optimizer_type == "adam":
            optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate,
                             weight_decay=weight_decay)
        elif optimizer_type == "nadam":
            optimizer = torch.optim.NAdam(self.parameters(), lr=self.hparams.learning_rate,
                                          weight_decay=weight_decay)
        else:
            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
        return optimizer

# -------------------------------
# Data Splitting with Stratified Sampling
# -------------------------------
def get_train_val_split(dataset, test_size=0.2, random_state=42):
    """
    Produce one stratified train/validation partition of a dataset's indices.

    Args:
        dataset: Object exposing a `targets` sequence of class labels and
            supporting `len()` (e.g. an ImageFolder).
        test_size: Share of samples assigned to the validation side
            (0.2 gives an 80% / 20% split).
        random_state: Seed forwarded to StratifiedShuffleSplit.

    Returns:
        (train_idx, val_idx): the two index collections yielded by the splitter.
    """
    splitter = StratifiedShuffleSplit(
        n_splits=1,
        test_size=test_size,
        random_state=random_state,
    )
    labels = dataset.targets
    positions = list(range(len(dataset)))
    # n_splits=1, so the first pair from the generator is the only one.
    first_split = next(splitter.split(positions, labels))
    train_idx, val_idx = first_split
    return train_idx, val_idx

# -------------------------------
# Data Transforms (with or without augmentation)
# -------------------------------
def get_transforms(use_data_aug, image_size=128):
    """
    Build the torchvision pipelines for training and for evaluation.

    Args:
        use_data_aug (bool): When true, the training pipeline additionally
            applies a random horizontal flip and a random rotation of up to
            20 degrees between the resize and the tensor conversion.
        image_size (int): Side length images are resized to (square).
            Defaults to 128, the value previously hard-coded here.

    Returns:
        (transform_train, transform_test): two `transforms.Compose` objects.
        The test/validation pipeline is always resize + ToTensor only.
    """
    target_shape = (image_size, image_size)

    # Augmentation is the only difference between the two pipelines.
    augmentation = []
    if use_data_aug:
        augmentation = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(20),
        ]

    transform_train = transforms.Compose([
        transforms.Resize(target_shape),
        *augmentation,
        transforms.ToTensor(),
    ])
    # For test/validation, use minimal transforms.
    transform_test = transforms.Compose([
        transforms.Resize(target_shape),
        transforms.ToTensor(),
    ])
    return transform_train, transform_test

# -------------------------------
# Main Training Function (sweep-compatible)
# -------------------------------
def run_training():
    """
    Execute one sweep trial: build data and model from wandb.config, train,
    then evaluate on the test set.

    Config keys read:
        - Architecture: kernel_size, filter_organization, constant_filter /
          base_filter, dense_neurons, dropout_rate, conv_activation,
          dense_activation, batch_norm.
        - Optimisation: learning_rate, optimizer, weight_decay, batch_size,
          max_epochs.
        - Data and logging: data_augmentation, train_data_path,
          test_data_path, project_name.

    Raises:
        ValueError: If filter_organization is not "constant", "doubling" or
            "halving".
    """
    # Using the run as a context manager guarantees it is closed even when
    # training raises, so a failed trial does not leave a dangling run.
    with wandb.init():
        config = wandb.config

        # Process filter organization and kernel size.
        kernel_size = config.get("kernel_size", 3)
        filt_org = config.get("filter_organization", "constant")
        if filt_org == "constant":
            conv_configs = [(config.constant_filter, kernel_size)] * 5
        elif filt_org == "doubling":
            base = config.base_filter
            conv_configs = [(base * (2 ** i), kernel_size) for i in range(5)]
        elif filt_org == "halving":
            base = config.base_filter
            # Clamp to 1 so a small base can never yield a 0-channel conv layer.
            conv_configs = [(max(1, base // (2 ** i)), kernel_size) for i in range(5)]
        else:
            raise ValueError("Unknown filter organization!")

        # Data augmentation setup.
        use_data_aug = config.get("data_augmentation", True)
        transform_train, transform_test = get_transforms(use_data_aug)

        # The training folder is opened twice: a Subset shares the transform
        # of the dataset it wraps, so the validation split needs its own
        # ImageFolder to avoid receiving the random training augmentation.
        train_dataset_full = ImageFolder(root=config.train_data_path, transform=transform_train)
        val_dataset_full = ImageFolder(root=config.train_data_path, transform=transform_test)

        # Split the training folder into training (80%) and validation (20%).
        # Both ImageFolder objects read the same directory, so the indices
        # computed on one are valid for the other.
        train_idx, val_idx = get_train_val_split(train_dataset_full, test_size=0.2, random_state=42)
        train_dataset = Subset(train_dataset_full, train_idx)
        val_dataset = Subset(val_dataset_full, val_idx)

        # Load the test dataset from test_data_path.
        test_dataset = ImageFolder(root=config.test_data_path, transform=transform_test)

        # Create data loaders.
        batch_size = int(config.batch_size)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

        # Create the model.
        model = iNaturalistCNN(
            num_classes=len(train_dataset_full.classes),
            learning_rate=config.learning_rate,
            input_size=128,
            conv_configs=conv_configs,
            dense_neurons=int(config.dense_neurons),
            dropout_rate=config.dropout_rate,
            conv_activation=config.conv_activation,
            dense_activation=config.dense_activation,
            use_batch_norm=config.batch_norm,
            optimizer_type=config.optimizer  # Use the optimizer hyperparameter here.
        )
        # configure_optimizers looks up "weight_decay" on model.hparams, so the
        # swept value is stored there; otherwise the sweep setting is ignored.
        model.hparams["weight_decay"] = float(config.get("weight_decay", 1e-4))

        # Set up the Wandb logger.
        wandb_logger = WandbLogger(project=config.project_name)

        # Stop once val_loss has failed to improve for 3 consecutive epochs.
        early_stop_callback = EarlyStopping(
            monitor="val_loss",       # metric to watch
            patience=3,               # epochs without improvement before stopping
            mode="min",               # because we want to minimize val_loss
            verbose=True
        )

        # Define the PyTorch Lightning trainer. It must exist before fit() is
        # called, and the early-stopping callback only works if registered here.
        trainer = pl.Trainer(
            max_epochs=int(config.max_epochs),
            accelerator="auto",
            devices=1,
            logger=wandb_logger,
            callbacks=[early_stop_callback],
        )

        # Train the model using training and validation data.
        trainer.fit(model, train_loader, val_loader)

        # Test the final model on the test dataset.
        trainer.test(model, dataloaders=test_loader)

In [17]:
# -------------------------------
# Define Sweep Configuration as a Dictionary (no external YAML)
# -------------------------------
# Sweep definition handed to wandb.sweep(). Entries under "parameters" that
# use "values" are searched over; entries that use "value" are held constant.
sweep_config = dict(
    program="New_file.py",  # Ensure this matches your script filename.
    method="bayes",         # Options: "grid", "random", "bayes"
    metric=dict(
        name="val_loss",
        goal="minimize",
    ),
    parameters=dict(
        # --- Architecture ---
        filter_organization=dict(values=["constant", "doubling", "halving"]),
        constant_filter=dict(values=[32, 64]),
        base_filter=dict(values=[64, 128]),
        kernel_size=dict(values=[3, 5, 7]),
        conv_activation=dict(values=["relu", "gelu", "silu", "mish"]),
        dense_activation=dict(values=["relu", "gelu", "silu", "mish"]),
        # --- Regularisation / data ---
        data_augmentation=dict(values=[True, False]),
        batch_norm=dict(values=[True, False]),
        dropout_rate=dict(values=[0, 0.25]),
        # --- Optimisation ---
        learning_rate=dict(values=[0.0001, 0.001]),
        dense_neurons=dict(values=[256, 512]),
        batch_size=dict(values=[32, 64]),
        optimizer=dict(values=["adam"]),
        max_epochs=dict(value=20),
        weight_decay=dict(values=[0.0, 1e-4, 1e-3]),
        # --- Fixed paths and project (update the paths for your machine) ---
        train_data_path=dict(value="/home/user/kartikey_phd/DA6401/nature_12K/inaturalist_12K/train"),
        test_data_path=dict(value="/home/user/kartikey_phd/DA6401/nature_12K/inaturalist_12K/val"),
        project_name=dict(value="iNaturalist_Sweep_final_2"),
    ),
)

# -------------------------------
# Run the Sweep and Launch the Agent 
# -------------------------------
if __name__ == "__main__":
    # Register the sweep under the project named in its own parameters ...
    sweep_id = wandb.sweep(
        sweep_config,
        project=sweep_config["parameters"]["project_name"]["value"],
    )
    # ... then let an agent call run_training once per trial.
    wandb.agent(sweep_id, function=run_training)
