In [None]:
import pytorch_lightning as pl
import wandb
from pytorch_lightning.loggers import WandbLogger
import os
from Dataset import Nature12KDataModule
from ResNet import ResNetFinetune


In [None]:
# Pin this kernel to GPU 1 unless the launcher already chose a device
# (e.g. `CUDA_VISIBLE_DEVICES=0 jupyter ...`, or one sweep agent per GPU);
# a plain assignment would silently override that choice.
# NOTE: CUDA only reads this variable when it is first initialised in the
# process, so it has no effect if torch.cuda was already used in this kernel.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')

In [None]:
def train(data_dir="../../inaturalist_12K"):
    """Run a single sweep trial: build data and model from the W&B config, then fit.

    Meant to be passed as ``function=`` to ``wandb.agent``, which calls it with
    no arguments. Every hyperparameter is read from the run's config, so the
    config must provide: batch_size, image_size, data_augmentation, lr,
    optimizer, momentum, weight_decay, scheduler, freeze_type,
    freeze_upto_layer, resnet_variant and max_epochs.

    Args:
        data_dir: Root directory of the iNaturalist-12K dataset, forwarded to
            ``Nature12KDataModule``. Defaults to the path used previously.

    Raises:
        Whatever data loading or training raises. The W&B run is closed via its
        context manager before the exception propagates.
    """
    # Using the run as a context manager guarantees it is closed even when
    # training crashes (e.g. out of memory for the larger image_size /
    # batch_size combinations), instead of lingering in the kernel.
    with wandb.init() as run:
        config = run.config

        data_module = Nature12KDataModule(
            data_dir=data_dir,
            batch_size=config.batch_size,
            # Square inputs: the sweep's image_size is a single edge length.
            image_size=(config.image_size, config.image_size),
            data_aug=config.data_augmentation,
        )

        # fit() receives the loaders directly (not the datamodule), so
        # prepare/setup are invoked explicitly here.
        data_module.prepare_data()
        data_module.setup()

        model = ResNetFinetune(
            num_classes=len(data_module.class_names),
            lr=config.lr,
            optimizer=config.optimizer,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
            scheduler=config.scheduler,
            freeze_type=config.freeze_type,
            freeze_upto_layer=config.freeze_upto_layer,
            resnet_variant=config.resnet_variant,
        )

        # Project/name mirror the run opened above.
        wandb_logger = WandbLogger(project=run.project, name=run.name)

        trainer = pl.Trainer(
            max_epochs=config.max_epochs,
            accelerator="auto",
            devices="auto",
            log_every_n_steps=10,
            logger=wandb_logger,
        )

        print("🚀 Training model...")
        trainer.fit(model, data_module.train_dataloader(), data_module.val_dataloader())


In [None]:
def launch_sweep(project='iNaturalist_ResNet_Sweep', count=100):
    """Create a Bayesian W&B sweep over the ResNet fine-tuning setup and run it locally.

    Args:
        project: W&B project the sweep is created in.
        count: Maximum number of trials the local agent runs; each trial calls
            ``train``.

    Returns:
        The sweep id returned by ``wandb.sweep``, so further agents can join
        or resume the same sweep.
    """
    sweep_config = {
        'method': 'bayes',
        # NOTE(review): must match the metric name ResNetFinetune logs —
        # confirm it logs 'val_acc'.
        'metric': {'name': 'val_acc', 'goal': 'maximize'},
        'parameters': {
            # lr and weight_decay span several orders of magnitude. Sampling
            # them log-uniformly gives each decade equal weight; a uniform
            # draw would put ~90% of lr trials above 1e-3.
            'lr': {'distribution': 'log_uniform_values', 'min': 1e-5, 'max': 1e-2},
            'weight_decay': {'distribution': 'log_uniform_values', 'min': 0.0001, 'max': 0.01},
            'momentum': {'min': 0.8, 'max': 0.99},  # used with the SGD optimizer below
            'batch_size': {'values': [32, 64]},
            'freeze_type': {'values': ['none', 'upto']},
            # NOTE(review): presumably only consulted when freeze_type == 'upto';
            # with 'none' it is an inert search dimension — confirm in ResNetFinetune.
            'freeze_upto_layer': {'values': [3, 5, 6, 7]},
            'image_size': {'values': [224, 256, 384]},
            'max_epochs': {'values': [10, 15, 20]},
            # Fixed single-value settings: still listed because train() reads
            # them from wandb.config.
            'data_augmentation': {'values': [True]},
            'resnet_variant': {'values': ['resnet101']},
            'scheduler': {'values': [True]},
            'optimizer': {'values': ['sgd']},
        },
    }

    sweep_id = wandb.sweep(sweep_config, project=project)
    wandb.agent(sweep_id, function=train, count=count)
    return sweep_id
