# Hyperparameter Tuning for Fruit Ripeness CNN (Colab + W&B)

This notebook extends the baseline CNN to run hyperparameter searches with [Weights & Biases](https://wandb.ai/). It is structured for Google Colab (GPU runtime) or local execution and logs richer metrics such as macro F1 and ROC-AUC to guide model selection.

**Notebook outline**

- Install dependencies (Colab-friendly)
- Configure Kaggle API + dataset download
- Define reusable data loaders, models, and training utilities
- Launch single runs or W&B sweeps across hyperparameters and architectures

In [None]:

%%capture
!pip install -q kagglehub wandb torch torchvision torchaudio scikit-learn tqdm


In [None]:
import os
import random
import copy
from pathlib import Path
import contextlib

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import (
    resnet18,
    ResNet18_Weights,
    mobilenet_v3_small,
    MobileNet_V3_Small_Weights,
    efficientnet_b0,
    EfficientNet_B0_Weights,
)
from sklearn.metrics import (
    balanced_accuracy_score,
    f1_score,
    roc_auc_score,
    classification_report,
)
from tqdm.auto import tqdm
import wandb

try:
    from torch import amp
except ImportError:  # Fallback for older PyTorch versions
    from torch.cuda import amp  # type: ignore[attr-defined]

try:
    from google.colab import files  # type: ignore
    IS_COLAB = True
except ImportError:
    files = None
    IS_COLAB = False

# All paths are resolved relative to the notebook's current working directory.
PROJECT_ROOT = Path.cwd()
DATA_ROOT = PROJECT_ROOT.joinpath('data', 'fruit_ripeness_dataset')
# Use the GPU whenever PyTorch can see one.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Running on device: {DEVICE}')
print(f'Project root: {PROJECT_ROOT}')


In [None]:

# Make Kaggle credentials available for the download cell. Sources, checked in order:
# an existing ~/.kaggle/kaggle.json, the KAGGLE_USERNAME + KAGGLE_KEY environment
# variables, or (Colab only) an interactive upload of kaggle.json.
kaggle_dir = Path.home() / '.kaggle'
kaggle_dir.mkdir(parents=True, exist_ok=True)

kaggle_json = kaggle_dir / 'kaggle.json'
has_env_credentials = bool(os.environ.get('KAGGLE_USERNAME') and os.environ.get('KAGGLE_KEY'))

if kaggle_json.exists():
    print('kaggle.json already present; skipping upload.')
elif has_env_credentials:
    print('Using KAGGLE_USERNAME / KAGGLE_KEY from the environment; skipping upload.')
elif files is None:
    raise FileNotFoundError(
        'kaggle.json not found. Upload it in Colab, place it in ~/.kaggle/, '
        'or set the KAGGLE_USERNAME and KAGGLE_KEY environment variables.'
    )
else:
    print('Upload your kaggle.json file (Account > Create New API Token).')
    uploaded = files.upload()
    if not uploaded:
        raise ValueError('No files uploaded.')
    # Only the first uploaded file is used; it is stored as kaggle.json whatever its name.
    name, data = next(iter(uploaded.items()))
    if name != 'kaggle.json':
        print(f"Received '{name}'. Renaming to 'kaggle.json'.")
    kaggle_json.write_bytes(data)
    # API key file: restrict to owner read/write immediately after writing it.
    kaggle_json.chmod(0o600)
    print('kaggle.json uploaded.')


In [None]:

# Restrict the API key file to owner read/write (equivalent to `chmod 600`).
# Uses pathlib so it also runs outside a POSIX shell, and is skipped when the
# credentials come from environment variables instead of kaggle.json.
if kaggle_json.exists():
    kaggle_json.chmod(0o600)


In [None]:

import shutil
import zipfile
import kagglehub

# Kaggle dataset identifier (owner/dataset) passed to kagglehub below.
DATASET_SLUG = 'leftin/fruit-ripeness-unripe-ripe-and-rotten'
FORCE_DOWNLOAD = False  # Set to True to refresh the dataset
TARGET_DIR = DATA_ROOT  # local copy lives under the project's data/ folder

def iter_files(path: Path):
    """Return a list of every regular file anywhere below ``path`` (recursive)."""
    return list(filter(Path.is_file, path.rglob('*')))

def copy_contents(src: Path, dst: Path) -> None:
    """Mirror every file below ``src`` into ``dst``, keeping relative paths.

    Files are copied with ``shutil.copy2`` (data plus metadata); missing parent
    directories under ``dst`` are created. Does nothing when ``src`` has no files.
    """
    source_files = iter_files(src)
    if not source_files:
        return
    for source_file in tqdm(source_files, desc='Copying dataset files', unit='file'):
        destination = dst.joinpath(source_file.relative_to(src))
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(source_file, destination)

def extract_archives(path: Path) -> None:
    """Unpack every ``.zip`` found below ``path`` and delete the archive afterwards.

    ``foo.zip`` is extracted into a sibling directory ``foo/``. The archive list is
    collected once up front, so zips that appear only after extraction are not unpacked.
    """
    for archive in tqdm(list(path.rglob('*.zip')), desc='Extracting archives', unit='zip'):
        destination = archive.with_suffix('')
        destination.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(archive) as bundle:
            members = bundle.namelist()
            # Extract member by member so the inner progress bar can report per-file progress.
            for member in tqdm(members, desc=f'Extracting {archive.name}', leave=False, unit='file'):
                bundle.extract(member, destination)
        archive.unlink()

def find_split_dir(root: Path, name: str):
    """Locate the shallowest directory called ``name`` below ``root`` that holds data.

    A candidate qualifies when it contains at least one entry two levels down
    (e.g. ``train/<class>/<image>``). Returns the matching ``Path`` or None.
    """
    matches = (p for p in root.rglob(name) if p.is_dir())
    for candidate in sorted(matches, key=lambda p: len(p.parts)):
        has_nested_entries = next(candidate.glob('*/*'), None) is not None
        if has_nested_entries:
            return candidate
    return None

# Reuse an existing copy unless it is empty (e.g. left behind by an interrupted
# download, since the directory is created before downloading) or a refresh was requested.
dataset_present = TARGET_DIR.is_dir() and any(TARGET_DIR.iterdir())
if dataset_present and not FORCE_DOWNLOAD:
    print(f'Dataset already present at {TARGET_DIR.resolve()}')
else:
    if TARGET_DIR.is_dir():
        shutil.rmtree(TARGET_DIR)
    elif TARGET_DIR.exists():
        TARGET_DIR.unlink()
    TARGET_DIR.mkdir(parents=True, exist_ok=True)
    print(f'Downloading {DATASET_SLUG} with kagglehub …')
    downloaded_path = Path(kagglehub.dataset_download(DATASET_SLUG)).resolve()
    print(f'Download complete: {downloaded_path}')
    # Copy out of the kagglehub cache so the project owns a local, unpackable copy.
    copy_contents(downloaded_path, TARGET_DIR)
    extract_archives(TARGET_DIR)
    print(f'Dataset extracted to {TARGET_DIR.resolve()}')

TRAIN_DIR = find_split_dir(TARGET_DIR, 'train')
TEST_DIR = find_split_dir(TARGET_DIR, 'test')
if TRAIN_DIR is None or TEST_DIR is None:
    raise RuntimeError(
        f"Expected 'train' and 'test' folders inside the dataset directory {TARGET_DIR}. "
        'If an earlier download was incomplete, set FORCE_DOWNLOAD = True and re-run this cell.'
    )

print(f'Train directory: {TRAIN_DIR}')
print(f'Test directory: {TEST_DIR}')


In [None]:

# Untransformed ImageFolder views of both splits; per-run transforms are attached
# later through SubsetWithTransform.
BASE_TRAIN_DATASET, BASE_TEST_DATASET = (
    ImageFolder(split_dir, transform=None) for split_dir in (TRAIN_DIR, TEST_DIR)
)
# Label indices are only comparable across splits if both see the same class folders.
if BASE_TRAIN_DATASET.classes != BASE_TEST_DATASET.classes:
    raise RuntimeError('Class labels differ between train and test directories.')
CLASS_NAMES = BASE_TRAIN_DATASET.classes
NUM_CLASSES = len(CLASS_NAMES)
print(f'Classes: {CLASS_NAMES}')
print(f'Train images: {len(BASE_TRAIN_DATASET)} | Test images: {len(BASE_TEST_DATASET)}')


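**Metrics monitored during tuning**

- `accuracy`: fraction of correct predictions; a quick sanity check that can look good even when minority classes are ignored.
- `macro_f1`: unweighted mean of per-class F1, so every class counts equally, which makes it well suited to imbalanced classes. Validation macro F1 drives early stopping and selects the checkpoint that is evaluated on the test set.
- `weighted_f1`: per-class F1 averaged with weights proportional to class support, so frequent classes dominate. Compare it with `macro_f1` to see whether minority classes lag behind.
- `balanced_accuracy`: mean per-class recall, which exposes classes the model neglects.
- `roc_auc_ovr`: one-vs-rest ROC AUC computed from predicted probabilities. It is recorded as NaN when it cannot be computed, for example when a class is missing from the split.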

In [None]:

WANDB_PROJECT = 'fruit-ripeness-cnn'
WANDB_ENTITY = None  # Set to your team/user; None uses the default account
WANDB_TAGS_BASE = ['fruit-ripeness', 'cnn', 'hyperparam-tuning']
WANDB_NOTES = 'Hyperparameter sweeps for Assignment 2 CNN models.'

# Baseline hyperparameters. run_experiment only copies overrides whose key already
# appears here, so every option the training code reads (including optimizer-specific
# ones such as momentum/nesterov) needs a default in this dict to be tunable.
DEFAULT_CONFIG = {
    # Reproducibility and data
    'seed': 42,
    'image_size': 224,
    'val_split': 0.15,  # fraction of the train folder held out for validation
    'batch_size': 32,
    # Optimisation
    'epochs': 15,
    'learning_rate': 3e-4,
    'weight_decay': 1e-4,
    'optimizer': 'adamw',  # 'adamw' | 'adam' | 'sgd'
    'momentum': 0.9,  # SGD only
    'nesterov': True,  # SGD only
    # Model
    'architecture': 'simple_cnn',  # 'simple_cnn' | 'resnet18' | 'mobilenet_v3_small' | 'efficientnet_b0'
    'pretrained': False,  # torchvision architectures only
    'dropout': 0.4,
    'label_smoothing': 0.05,
    # Learning-rate schedule: 'cosine' | 'plateau'; any other value disables scheduling
    'scheduler': 'cosine',
    'min_lr': 1e-6,
    'freeze_backbone': False,  # torchvision architectures only
    # Train-time augmentation
    'aug_hflip': True,
    'aug_rotation': 10,  # max rotation in degrees; 0 disables
    'aug_color_jitter': 0.2,  # 0 disables
    'aug_random_erasing': 0.0,  # erasing probability; 0 disables
    # Training loop
    'use_amp': True,  # mixed precision, only applied on CUDA
    'max_grad_norm': 2.0,  # gradient clipping threshold; 0 disables
    'patience': 4,  # epochs without val macro-F1 improvement before early stopping; 0 disables
    'log_model': False,  # upload the best checkpoint as a W&B artifact
    'wandb_mode': os.environ.get('WANDB_MODE', 'online'),
}
print('Default config keys:', list(DEFAULT_CONFIG.keys()))
if os.environ.get('WANDB_API_KEY'):
    try:
        wandb.login()
    except Exception as err:
        print(f'W&B login failed: {err}')
else:
    print('Set WANDB_API_KEY or run wandb.login() to enable logging.')


In [None]:

def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) with ``seed``.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning, trading
    some speed for run-to-run reproducibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

class SubsetWithTransform(Dataset):
    """Index-based view over a base dataset that applies its own per-sample transform.

    Lets the training and validation splits share one ImageFolder while using
    different preprocessing (augmentation vs. plain evaluation transforms).
    """

    def __init__(self, dataset: ImageFolder, indices, transform):
        self.dataset = dataset
        self.indices = list(indices)  # materialised so ranges/tensors index uniformly
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        sample, target = self.dataset[self.indices[idx]]
        return (self.transform(sample) if self.transform else sample), target

def build_transforms(config, train: bool):
    """Return the torchvision preprocessing pipeline for one split.

    Every split is resized to a square ``image_size`` and normalised with the ImageNet
    channel statistics. With ``train=True`` the ``aug_*`` options add random flips,
    rotations and colour jitter (before ToTensor) and random erasing (after Normalize).
    """
    side = int(config['image_size'])
    # The numeric augmentation knobs are parsed for every split, even though only
    # the training pipeline uses them.
    rotation_deg = float(config.get('aug_rotation', 0))
    jitter_strength = float(config.get('aug_color_jitter', 0))
    erase_prob = float(config.get('aug_random_erasing', 0))

    pipeline = [transforms.Resize((side, side))]
    if train:
        if config.get('aug_hflip', True):
            pipeline.append(transforms.RandomHorizontalFlip())
        if rotation_deg:
            pipeline.append(transforms.RandomRotation(rotation_deg))
        if jitter_strength > 0:
            pipeline.append(
                transforms.ColorJitter(
                    brightness=jitter_strength,
                    contrast=jitter_strength,
                    saturation=min(1.0, jitter_strength),
                    hue=min(0.5, jitter_strength * 0.1),  # ColorJitter caps hue at 0.5
                )
            )

    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pipeline.extend([transforms.ToTensor(), transforms.Normalize(mean=imagenet_mean, std=imagenet_std)])
    # RandomErasing operates on tensors, so it has to come after ToTensor/Normalize.
    if train and erase_prob > 0:
        pipeline.append(transforms.RandomErasing(p=erase_prob, value='random'))
    return transforms.Compose(pipeline)

def create_dataloaders(config):
    """Build the train/validation/test DataLoaders for one run.

    The validation split is carved out of BASE_TRAIN_DATASET with a generator seeded
    by ``config['seed']``, so a given seed always produces the same split. Training
    samples get augmented transforms; validation and test samples get evaluation
    transforms.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)``.

    Raises:
        ValueError: if ``val_split`` leaves no images for training.
    """
    import multiprocessing

    val_split = float(config['val_split'])
    batch_size = int(config['batch_size'])
    generator = torch.Generator().manual_seed(int(config['seed']))

    n_images = len(BASE_TRAIN_DATASET)
    val_size = max(1, int(n_images * val_split))
    train_size = n_images - val_size
    if train_size < 1:
        raise ValueError(
            f'val_split={val_split} leaves no training images ({n_images} available).'
        )
    train_subset, val_subset = random_split(
        BASE_TRAIN_DATASET, [train_size, val_size], generator=generator
    )

    train_transforms = build_transforms(config, train=True)
    eval_transforms = build_transforms(config, train=False)

    train_dataset = SubsetWithTransform(BASE_TRAIN_DATASET, train_subset.indices, train_transforms)
    val_dataset = SubsetWithTransform(BASE_TRAIN_DATASET, val_subset.indices, eval_transforms)
    test_dataset = SubsetWithTransform(BASE_TEST_DATASET, range(len(BASE_TEST_DATASET)), eval_transforms)

    # Worker processes have to unpickle SubsetWithTransform, which is defined in the
    # notebook's __main__. That only works when workers are forked; under spawn or
    # forkserver (macOS/Windows defaults) fall back to in-process loading.
    # get_all_start_methods()[0] is the platform default and, unlike
    # get_start_method(), querying it does not pin the global start method.
    start_method = (
        multiprocessing.get_start_method(allow_none=True)
        or multiprocessing.get_all_start_methods()[0]
    )
    if IS_COLAB or start_method != 'fork':
        num_workers = 0
    else:
        num_workers = max(0, min(4, (os.cpu_count() or 1) - 1))
    pin_memory = torch.cuda.is_available()
    common = {'batch_size': batch_size, 'num_workers': num_workers, 'pin_memory': pin_memory}
    # Train/val loaders are iterated every epoch: keep their workers alive between epochs.
    reuse_workers = num_workers > 0

    train_loader = DataLoader(train_dataset, shuffle=True, persistent_workers=reuse_workers, **common)
    val_loader = DataLoader(val_dataset, shuffle=False, persistent_workers=reuse_workers, **common)
    test_loader = DataLoader(test_dataset, shuffle=False, **common)

    return train_loader, val_loader, test_loader, CLASS_NAMES


In [None]:
class SimpleCNN(nn.Module):
    """Small baseline CNN: three conv/BN/ReLU/max-pool stages plus a two-layer MLP head.

    Args:
        num_classes: number of output logits.
        image_size: side length of the square input; fixes the flattened feature size.
        dropout: dropout probability used after the conv stack and inside the head.
    """

    def __init__(self, num_classes: int, image_size: int = 224, dropout: float = 0.5):
        super().__init__()
        stages = []
        for in_channels, out_channels in ((3, 32), (32, 64), (64, 128)):
            stages += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]
        stages.append(nn.Dropout(p=dropout))
        self.features = nn.Sequential(*stages)

        # Three 2x2 max-pools shrink each spatial side to image_size // 8.
        pooled_side = image_size // 8
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * pooled_side * pooled_side, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

def build_model(config, num_classes):
    """Instantiate ``config['architecture']`` with a ``num_classes``-way head on DEVICE.

    ``pretrained`` loads the torchvision default weights. ``freeze_backbone`` disables
    gradients for the backbone while keeping the replacement head trainable. Both only
    apply to the torchvision architectures. ``dropout`` sets the head's dropout rate.

    Raises:
        ValueError: for an unknown architecture name.
    """
    arch = config['architecture']
    dropout = float(config.get('dropout', 0.4))
    pretrained = bool(config.get('pretrained', False))
    freeze_backbone = bool(config.get('freeze_backbone', False))
    image_size = int(config['image_size'])

    def set_trainable(module, flag):
        for param in module.parameters():
            param.requires_grad = flag

    if arch == 'simple_cnn':
        model = SimpleCNN(num_classes=num_classes, image_size=image_size, dropout=dropout)
    elif arch == 'resnet18':
        model = resnet18(weights=ResNet18_Weights.DEFAULT if pretrained else None)
        if freeze_backbone:
            set_trainable(model, False)
        head_in = model.fc.in_features
        model.fc = nn.Sequential(nn.Dropout(p=dropout), nn.Linear(head_in, num_classes))
        if freeze_backbone:
            set_trainable(model.fc, True)
    elif arch in ('mobilenet_v3_small', 'efficientnet_b0'):
        # Both models expose `features` and a `classifier` Sequential ending in a Linear;
        # they differ only in where the classifier's Dropout sits.
        if arch == 'mobilenet_v3_small':
            factory, weight_enum, dropout_slot = mobilenet_v3_small, MobileNet_V3_Small_Weights, 2
        else:
            factory, weight_enum, dropout_slot = efficientnet_b0, EfficientNet_B0_Weights, 0
        model = factory(weights=weight_enum.DEFAULT if pretrained else None)
        if freeze_backbone:
            set_trainable(model.features, False)
        head_in = model.classifier[-1].in_features
        model.classifier[dropout_slot] = nn.Dropout(p=dropout, inplace=True)
        model.classifier[-1] = nn.Linear(head_in, num_classes)
        if freeze_backbone:
            set_trainable(model.classifier, True)
    else:
        raise ValueError(f'Unknown architecture: {arch}')

    return model.to(DEVICE)

def build_optimizer(model, config):
    """Create the optimizer named by ``config['optimizer']`` over trainable parameters.

    Supported names (case-insensitive): ``'adamw'`` (default), ``'adam'`` and ``'sgd'``.
    SGD additionally reads ``momentum`` (default 0.9) and ``nesterov`` (default True);
    Nesterov is only enabled when momentum is positive, because torch rejects it otherwise.
    Parameters with ``requires_grad=False`` (e.g. a frozen backbone) are excluded.

    Raises:
        ValueError: for an unknown optimizer name (instead of silently using AdamW).
    """
    lr = float(config['learning_rate'])
    weight_decay = float(config.get('weight_decay', 0.0))
    optimizer_name = str(config.get('optimizer', 'adamw')).lower()
    parameters = [p for p in model.parameters() if p.requires_grad]

    if optimizer_name == 'adamw':
        return torch.optim.AdamW(parameters, lr=lr, weight_decay=weight_decay)
    if optimizer_name == 'adam':
        return torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay)
    if optimizer_name == 'sgd':
        momentum = float(config.get('momentum', 0.9))
        nesterov = bool(config.get('nesterov', True)) and momentum > 0
        return torch.optim.SGD(
            parameters, lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov
        )
    raise ValueError(
        f"Unknown optimizer: {config.get('optimizer')!r} (expected 'adamw', 'adam' or 'sgd')"
    )

def build_scheduler(optimizer, config):
    """Create the learning-rate scheduler named by ``config['scheduler']``.

    - ``'cosine'`` (default when the key is missing): CosineAnnealingLR with
      ``T_max=epochs`` steps, annealing down to ``min_lr``.
    - ``'plateau'``: ReduceLROnPlateau in ``max`` mode, halving the LR after 2 steps
      without improvement of the stepped metric, never below ``min_lr``.
    - ``None``, ``''`` or any other name: no scheduler; returns None.

    Names are case-insensitive. An explicit ``None`` (e.g. a null sweep value) disables
    scheduling instead of crashing on ``.lower()``.
    """
    scheduler_name = str(config.get('scheduler', 'cosine') or 'none').lower()
    epochs = int(config['epochs'])
    min_lr = float(config.get('min_lr', 1e-6))
    if scheduler_name == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, epochs), eta_min=min_lr)
    if scheduler_name == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, min_lr=min_lr)
    return None

def _roc_auc_ovr(y_true, y_probs):
    """One-vs-rest ROC AUC for float64 ``y_probs``, or NaN when sklearn cannot compute it."""
    # sklearn's multi-class check requires rows to sum to 1 within np.allclose
    # tolerance; low-precision (e.g. float16 AMP) softmax outputs can miss it.
    row_sums = y_probs.sum(axis=1, keepdims=True)
    normalised = np.divide(y_probs, row_sums, out=np.zeros_like(y_probs), where=row_sums > 0)
    try:
        if normalised.shape[1] == 2:
            # sklearn's binary path expects only the positive-class score.
            return float(roc_auc_score(y_true, normalised[:, 1]))
        return float(roc_auc_score(y_true, normalised, multi_class='ovr'))
    except ValueError:
        # e.g. a class absent from y_true in this split
        return float('nan')


def compute_metrics(y_true, y_probs):
    """Score predictions given true labels and per-class probabilities.

    Args:
        y_true: class labels, one per sample.
        y_probs: 2-D array of class probabilities (samples x classes); the
            prediction is the row-wise argmax.

    Returns:
        ``(metrics, y_pred)``. ``metrics`` holds ``accuracy``, ``macro_f1``,
        ``weighted_f1``, ``balanced_accuracy`` and ``roc_auc_ovr`` (NaN when it is
        undefined for the split). ``y_pred`` is the array of predicted labels.
    """
    y_true = np.asarray(y_true)
    # float64 (an exact widening of float16/float32) for stable sums in the AUC check.
    y_probs = np.asarray(y_probs, dtype=np.float64)
    y_pred = y_probs.argmax(axis=1)
    metrics = {
        'accuracy': float((y_pred == y_true).mean()),
        # zero_division=0 matches sklearn's default value but without the warning spam.
        'macro_f1': float(f1_score(y_true, y_pred, average='macro', zero_division=0)),
        'weighted_f1': float(f1_score(y_true, y_pred, average='weighted', zero_division=0)),
        'balanced_accuracy': float(balanced_accuracy_score(y_true, y_pred)),
    }
    metrics['roc_auc_ovr'] = _roc_auc_ovr(y_true, y_probs)
    return metrics, y_pred

def autocast_context(enabled: bool):
    """Return an autocast context for DEVICE when ``enabled``, else a no-op context."""
    if enabled:
        try:
            # torch.amp API: the device type is an explicit argument.
            return amp.autocast(device_type=DEVICE.type, enabled=True)
        except TypeError:
            # Fallback torch.cuda.amp.autocast does not accept device_type.
            return amp.autocast(enabled=True)
    return contextlib.nullcontext()

def create_grad_scaler(enabled: bool):
    """Return a GradScaler for mixed-precision training on DEVICE, or None when disabled.

    Prefers ``amp.GradScaler`` (the ``torch.amp`` API, whose target device is passed
    as ``device``). Falls back to ``torch.cuda.amp.GradScaler`` when the imported
    ``amp`` module does not provide a GradScaler.
    """
    if not enabled:
        return None
    scaler_cls = getattr(amp, 'GradScaler', None)
    if scaler_cls is None:
        return torch.cuda.amp.GradScaler(enabled=True)
    try:
        return scaler_cls(device=DEVICE.type, enabled=True)
    except TypeError:
        # torch.cuda.amp.GradScaler (the fallback `amp` import) has no device argument.
        return scaler_cls(enabled=True)

def train_one_epoch(model, loader, criterion, optimizer, scaler, config, epoch):
    """Run one optimisation pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Autocast and loss scaling are used only when ``scaler`` is an enabled GradScaler.
    Gradients are clipped to ``config['max_grad_norm']`` whenever that value is positive.
    Loss and accuracy are averaged per sample over the epoch.
    """
    model.train()
    amp_active = scaler is not None and scaler.is_enabled()
    clip_norm = config.get('max_grad_norm', 0)
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc=f'Train {epoch:02d}', leave=False)
    for images, labels in progress:
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with autocast_context(amp_active):
            logits = model(images)
            loss = criterion(logits, labels)

        if amp_active:
            scaler.scale(loss).backward()
            if clip_norm > 0:
                # Unscale first so the clipping threshold applies to the true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if clip_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += batch_size
        progress.set_postfix(loss=loss_sum / n_seen, acc=n_correct / n_seen)

    return loss_sum / n_seen, n_correct / n_seen

def evaluate(model, loader, criterion, config, split='Eval', epoch=None):
    """Run ``model`` over ``loader`` without gradients and score its predictions.

    Args:
        model: network to evaluate (switched to eval mode).
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss function applied to the logits.
        config: run config; ``use_amp`` enables autocast when running on CUDA.
        split: label for the progress bar (e.g. ``'Validation'``, ``'Test'``).
        epoch: optional epoch number appended to the progress-bar label.

    Returns:
        ``(avg_loss, metrics, y_true, y_probs, y_pred)``. ``metrics`` comes from
        ``compute_metrics`` and the arrays are NumPy arrays covering the whole split.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    use_amp = config.get('use_amp', True) and DEVICE.type == 'cuda'
    running_loss = 0.0
    running_correct = 0
    total = 0
    probs_buffer = []
    labels_buffer = []
    desc = f'{split} {epoch:02d}' if epoch is not None else split
    progress = tqdm(loader, desc=desc, leave=False)
    with torch.no_grad():
        for images, labels in progress:
            images = images.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            with autocast_context(use_amp):
                logits = model(images)
                loss = criterion(logits, labels)
            # Under AMP the logits are float16. A float16 softmax does not sum to 1
            # tightly enough for sklearn's multi-class ROC AUC check, so use float32.
            probs = torch.softmax(logits.float(), dim=1).cpu()
            labels_cpu = labels.cpu()
            batch_size = labels_cpu.size(0)
            running_loss += loss.item() * batch_size
            running_correct += (probs.argmax(dim=1) == labels_cpu).sum().item()
            total += batch_size
            probs_buffer.append(probs)
            labels_buffer.append(labels_cpu)
            # Running accuracy over the split so far, matching the training progress bar.
            progress.set_postfix(loss=running_loss / total, acc=running_correct / total)

    if total == 0:
        raise ValueError(f'{split} loader yielded no samples; cannot compute metrics.')
    avg_loss = running_loss / total
    y_probs = torch.cat(probs_buffer).numpy()
    y_true = torch.cat(labels_buffer).numpy()
    metrics, y_pred = compute_metrics(y_true, y_probs)
    return avg_loss, metrics, y_true, y_probs, y_pred


In [None]:
def _merge_run_config(config):
    """Overlay known keys of ``config`` onto DEFAULT_CONFIG and derive W&B tags/notes.

    Returns ``(cfg, tags, notes)``. Keys missing from DEFAULT_CONFIG are not copied;
    apart from the special ``tags``/``notes`` keys they are reported as ignored.
    """
    cfg = DEFAULT_CONFIG.copy()
    tags = WANDB_TAGS_BASE.copy()
    notes = WANDB_NOTES
    if not config:
        return cfg, tags, notes

    ignored = []
    for key, value in config.items():
        if key in cfg:
            cfg[key] = value
        elif key not in ('tags', 'notes'):
            ignored.append(key)
    if ignored:
        print(f'Ignoring config keys not present in DEFAULT_CONFIG: {ignored}')

    if 'tags' in config:
        extra = config['tags']
        if isinstance(extra, (list, tuple)):
            tags.extend(extra)
        else:
            tags.append(extra)
    if 'notes' in config:
        notes = f"{WANDB_NOTES} | {config['notes']}"
    return cfg, tags, notes


def _snapshot_state(model):
    """Copy the model's state_dict to CPU so the snapshot does not hold extra GPU memory."""
    return {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}


def _log_test_artifacts(y_true, y_probs, y_pred, class_names):
    """Log the test confusion matrix, ROC curve and per-class report table to W&B."""
    try:
        cm_plot = wandb.plot.confusion_matrix(y_true=y_true, preds=y_pred, class_names=class_names)
        wandb.log({'test_confusion_matrix': cm_plot})
    except Exception as err:
        print(f'Confusion matrix not logged: {err}')

    try:
        roc_plot = wandb.plot.roc_curve(y_true, y_probs, labels=class_names)
        wandb.log({'test_roc_curve': roc_plot})
    except Exception as err:
        print(f'ROC curve not logged: {err}')

    # Explicit labels keep the report aligned with target_names even when a class is
    # absent from the test split (otherwise classification_report raises).
    report = classification_report(
        y_true,
        y_pred,
        labels=list(range(len(class_names))),
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )
    report_table = wandb.Table(columns=['class', 'precision', 'recall', 'f1', 'support'])
    for row_name in [*class_names, 'weighted avg']:
        stats = report[row_name]
        report_table.add_data(
            row_name,
            float(stats['precision']),
            float(stats['recall']),
            float(stats['f1-score']),
            int(stats['support']),
        )
    wandb.log({'test_classification_report': report_table})


def run_experiment(config=None, enable_wandb=True, run_name=None):
    """Train, validate and test one configuration, logging everything to W&B.

    Args:
        config: optional overrides for DEFAULT_CONFIG (unknown keys are ignored), plus
            optional ``tags`` (str or list) and ``notes`` (str) for the W&B run.
        enable_wandb: when False the W&B run is created in ``disabled`` mode.
        run_name: optional W&B run name.

    Returns:
        The test-set metrics dict for the weights of the best validation-macro-F1 epoch.
    """
    cfg, extra_tags, notes = _merge_run_config(config)
    wandb_mode = cfg.get('wandb_mode', 'online') if enable_wandb else 'disabled'

    with wandb.init(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        config=cfg,
        mode=wandb_mode,
        job_type='tuning',
        tags=extra_tags,
        name=run_name,
        notes=notes,
    ) as run:
        # Sweep agents may override values, so everything below reads the run config.
        cfg_runtime = dict(run.config) if run is not None else cfg
        use_amp = bool(cfg_runtime.get('use_amp', True)) and DEVICE.type == 'cuda'
        scaler = create_grad_scaler(use_amp)
        seed_everything(int(cfg_runtime['seed']))
        train_loader, val_loader, test_loader, class_names = create_dataloaders(cfg_runtime)
        model = build_model(cfg_runtime, num_classes=len(class_names))
        criterion = nn.CrossEntropyLoss(label_smoothing=float(cfg_runtime.get('label_smoothing', 0.0)))
        optimizer = build_optimizer(model, cfg_runtime)
        scheduler = build_scheduler(optimizer, cfg_runtime)
        epochs = int(cfg_runtime['epochs'])
        patience = int(cfg_runtime.get('patience', 0) or 0)
        best_metric = float('-inf')
        best_state = None
        epochs_without_improvement = 0

        for epoch in range(1, epochs + 1):
            train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, scaler, cfg_runtime, epoch)
            val_loss, val_metrics, _, _, _ = evaluate(model, val_loader, criterion, cfg_runtime, split='Validation', epoch=epoch)
            val_macro_f1 = val_metrics['macro_f1']

            if scheduler:
                if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(val_macro_f1)
                else:
                    scheduler.step()

            if val_macro_f1 > best_metric:
                best_metric = val_macro_f1
                best_state = _snapshot_state(model)
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            log_payload = {
                'epoch': epoch,
                'train_loss': train_loss,
                'train_accuracy': train_acc,
                'val_loss': val_loss,
                'learning_rate': optimizer.param_groups[0]['lr'],
            }
            log_payload.update({f'val_{key}': value for key, value in val_metrics.items()})
            if best_state is not None:
                # Logged in the same step as this epoch's metrics.
                log_payload['best_val_macro_f1'] = best_metric
            wandb.log(log_payload)

            if patience and epochs_without_improvement >= patience:
                print(f'Early stopping at epoch {epoch} (best val macro F1: {best_metric:.4f})')
                break

        if best_state is not None:
            model.load_state_dict(best_state)

        test_loss, test_metrics, y_true_test, y_probs_test, y_pred_test = evaluate(model, test_loader, criterion, cfg_runtime, split='Test')
        test_log = {'test_loss': test_loss}
        test_log.update({f'test_{key}': value for key, value in test_metrics.items()})
        wandb.log(test_log)

        _log_test_artifacts(y_true_test, y_probs_test, y_pred_test, class_names)

        if cfg_runtime.get('log_model', False) and run is not None:
            ckpt_dir = PROJECT_ROOT / 'checkpoints'
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            run_label = run.name or run.id
            ckpt_path = ckpt_dir / f'{run_label}_best.pt'
            # Fall back to the final weights if no epoch ever improved (e.g. epochs == 0).
            state_to_save = best_state if best_state is not None else _snapshot_state(model)
            torch.save({'model_state_dict': state_to_save, 'config': cfg_runtime, 'class_names': class_names}, ckpt_path)
            artifact = wandb.Artifact(name=f'{run_label}-model', type='model')
            artifact.add_file(str(ckpt_path))
            run.log_artifact(artifact)
            print(f'Saved best model checkpoint to {ckpt_path}')

        print('Test metrics:')
        for key, value in test_metrics.items():
            print(f'  {key}: {value:.4f}')
        return test_metrics


In [None]:
# W&B sweep definition. The optimised metric is the run's best validation macro F1,
# i.e. the epoch whose weights run_experiment restores for the test evaluation.
# The last-epoch `val_macro_f1` can come several epochs after that peak
# (early stopping waits `patience` epochs without improvement).
SWEEP_CONFIG = {
    'name': 'cnn-architecture-hparam-search',
    'method': 'bayes',
    'metric': {'name': 'best_val_macro_f1', 'goal': 'maximize'},
    'parameters': {
        'architecture': {'values': ['simple_cnn', 'resnet18', 'mobilenet_v3_small', 'efficientnet_b0']},
        # pretrained / freeze_backbone only affect the torchvision architectures.
        # Freezing a non-pretrained backbone leaves only the classifier head trainable.
        'pretrained': {'values': [True, False]},
        'freeze_backbone': {'values': [True, False]},
        'learning_rate': {'distribution': 'log_uniform_values', 'min': 1e-5, 'max': 3e-3},
        'weight_decay': {'distribution': 'log_uniform_values', 'min': 1e-6, 'max': 1e-2},
        'batch_size': {'values': [32, 48, 64]},
        'dropout': {'min': 0.1, 'max': 0.6},
        'aug_color_jitter': {'values': [0.0, 0.15, 0.3]},
        'aug_random_erasing': {'values': [0.0, 0.1, 0.25]},
        'optimizer': {'values': ['adamw', 'adam', 'sgd']},
    },
}
print('Sweep ready. Primary metric:', SWEEP_CONFIG['metric'])


In [None]:

def sweep_entry(config=None):
    """Entry point for ``wandb.agent``: run one experiment with W&B logging enabled.

    Hyperparameters chosen by the sweep reach ``run_experiment`` through the W&B run
    config, which it reads back after ``wandb.init``.
    """
    return run_experiment(enable_wandb=True, config=config)

print(
    'To launch: sweep_id = wandb.sweep(SWEEP_CONFIG, project=WANDB_PROJECT, entity=WANDB_ENTITY)',
    'Then run: wandb.agent(sweep_id, sweep_entry, count=NUM_RUNS)',
    sep='\n',
)


**Quick smoke test (optional)**

Set the guard in the cell below to `True` to run a quick two-epoch `simple_cnn` sanity check with W&B logging disabled.

In [None]:

RUN_SMOKE_TEST = False  # Set to True for a quick offline run

if RUN_SMOKE_TEST:
    # Two short epochs on the baseline CNN, full precision, nothing sent to W&B.
    debug_config = dict(DEFAULT_CONFIG, epochs=2, batch_size=16, architecture='simple_cnn', use_amp=False)
    run_experiment(config=debug_config, enable_wandb=False, run_name='debug-run')


In [None]:

# Example sweep launch (uncomment to execute)
# sweep_id = wandb.sweep(SWEEP_CONFIG, project=WANDB_PROJECT, entity=WANDB_ENTITY)
# wandb.agent(sweep_id, function=sweep_entry, count=8)
