In [None]:
# 🧩 Cell 1: Imports
import wandb
from Dataset import Nature12KDataModule
from cnn_model import CNN
import torch.nn as nn
import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers import WandbLogger
from Dataset import Nature12KDataModule
from cnn_model import CNN

In [None]:
def train(
    data_dir="../../inaturalist_12K",
    image_size=(128, 128),
    max_epochs=10,
    run_test=True,
    seed=None,
):
    """Run one W&B sweep trial: build data + model from ``wandb.config``, fit, then test.

    Intended to be passed as ``function=train`` to ``wandb.agent``, which calls it
    with no arguments. Every parameter therefore has a default that reproduces
    the original hard-coded setup.

    Hyperparameters read from ``wandb.config``: ``batch_size``,
    ``data_augmentation``, ``conv_filters``, ``kernel_sizes``, ``activation``,
    ``dense_neurons``, ``lr``, ``batch_norm``, ``dropout``.

    Args:
        data_dir: Dataset root forwarded to ``Nature12KDataModule``.
        image_size: Image size forwarded to ``Nature12KDataModule``.
        max_epochs: Number of training epochs for ``pl.Trainer``.
        run_test: If True (default), evaluate on the test split after fitting.
            For pure hyperparameter search, consider False so model selection
            never looks at test data.
        seed: If not None, passed to ``pl.seed_everything`` before data and model
            are built, so trials are reproducible and comparable.

    Returns:
        ``(model, data_module)``: the trained model and its data module.

    Note:
        This function does not call ``wandb.finish()``. ``wandb.agent`` is expected
        to close the run after each trial. Call ``wandb.finish()`` yourself when
        using ``train`` standalone.
    """
    run = wandb.init()
    config = run.config

    if seed is not None:
        pl.seed_everything(seed)

    data_module = Nature12KDataModule(
        data_dir=data_dir,
        batch_size=config.batch_size,
        image_size=image_size,
        data_aug=config.data_augmentation,
    )

    # Called explicitly because the dataloaders are handed to the Trainer directly
    # (not as a datamodule), and class_names is read below to size the output
    # layer. NOTE(review): presumably setup() populates class_names; confirm in Dataset.py.
    data_module.prepare_data()
    data_module.setup()

    model = CNN(
        input_channels=3,
        conv_filters=config.conv_filters,
        kernel_sizes=config.kernel_sizes,
        activation=config.activation,
        dense_neurons=config.dense_neurons,
        num_classes=len(data_module.class_names),
        lr=config.lr,
        batch_norm=config.batch_norm,
        dropout=config.dropout,
    )

    # Log into the run started above. Lightning's WandbLogger attaches to an
    # already-active wandb run rather than creating a second one.
    wandb_logger = WandbLogger(project=run.project, name=run.name)

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        devices="auto",
        log_every_n_steps=10,
        logger=wandb_logger,
    )

    print("🚀 Training model...")
    trainer.fit(model, data_module.train_dataloader(), data_module.val_dataloader())

    if run_test:
        # Evaluates the weights left after the final epoch (no best-checkpoint reload).
        print("🧪 Evaluating on test set...")
        trainer.test(model, data_module.test_dataloader())

    return model, data_module


In [None]:
# 🌀 Cell 2: Define and register the sweep (the agent is launched in Cell 3)
# Bayesian search over CNN hyperparameters. Every key under 'parameters' is
# read by train() through wandb.config.
# NOTE(review): 'val_acc' must match the metric name the CNN LightningModule
# logs (check cnn_model.py). Otherwise the Bayesian optimizer has no signal.
sweep_config = {
    'method': 'bayes',
    'metric': {
        'name': 'val_acc',
        'goal': 'maximize'
    },
    'parameters': {
        # Single-option entries are effectively fixed for this sweep.
        'conv_filters': {
            'values': [[32, 64, 128, 256, 512, 512]]
        },
        'kernel_sizes': {
            'values': [[3, 3, 3, 3, 3, 3]]
        },
        'activation': {
            'values': ['relu', 'gelu', 'silu', 'mish']
        },
        'dense_neurons': {
            'values': [256]
        },
        # The learning-rate range spans three orders of magnitude. Sampling it
        # log-uniformly gives 1e-4..1e-3 as much coverage as 1e-2..1e-1;
        # plain uniform sampling would put ~90% of trials above 0.01.
        'lr': {
            'distribution': 'log_uniform_values',
            'min': 0.0001,
            'max': 0.1
        },
        'batch_norm': {
            'values': [True, False]
        },
        'dropout': {
            'values': [0.2, 0.3]
        },
        'batch_size': {
            'values': [16, 32]
        },
        'data_augmentation': {
            'values': [True, False]
        }
    }
}

# Register the sweep with W&B; the returned ID is consumed by wandb.agent in Cell 3.
sweep_id = wandb.sweep(sweep_config, project='iNaturalist_CNN_Sweep')


In [None]:
# 🚀 Cell 3: Launch the sweep agent (runs train() with W&B config)
# Upper bound on the number of trials this agent executes; adjust as needed.
MAX_SWEEP_RUNS = 100

# Pass the same project used in wandb.sweep so the agent resolves the sweep ID
# there, instead of relying on default project resolution.
wandb.agent(
    sweep_id,
    function=train,
    project='iNaturalist_CNN_Sweep',
    count=MAX_SWEEP_RUNS,
)
