In [None]:
import random, numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

import torchaudio
import torchaudio.transforms as AT

from dataclasses import dataclass
from typing import Optional, Dict, Any

# Global seeding for reproducibility: Python `random` (used to subsample the
# dataset when profiling/fitting preprocessing), NumPy, and torch CPU/CUDA RNGs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Dataset config
DATA_ROOT = "./data"  # root directory passed to torchaudio SPEECHCOMMANDS
BATCH_SIZE = 128
NUM_WORKERS = 2  # DataLoader worker processes

# These will be inferred automatically
# (get_dataloaders() sets them from the automatically chosen preprocessing config)
TARGET_SR: Optional[int] = None
CLIP_SECONDS: Optional[float] = None
TARGET_SAMPLES: Optional[int] = None

# Will be filled in get_dataloaders and used by the models
NUM_CLASSES: Optional[int] = None


def _seed_worker(worker_id: int):
    """
    DataLoader ``worker_init_fn``: seed Python's and NumPy's RNGs in a worker.

    Inside a worker process, ``torch.initial_seed()`` is the per-worker seed
    the DataLoader derives from its generator. It differs per worker and per
    epoch, yet is reproducible when the loader is given a seeded generator.
    Deriving the other RNG seeds from it (the recipe from PyTorch's
    reproducibility notes) avoids replaying identical random streams every
    epoch, which a fixed ``SEED + worker_id`` would do.

    Args:
        worker_id: Index of the worker. Unused; the per-worker seed already
            encodes it. Kept for the ``worker_init_fn`` signature.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def pad_or_trim(waveform: torch.Tensor, target_len: int) -> torch.Tensor:
    """
    Right-pad with zeros or truncate ``waveform`` to ``target_len`` samples
    along its last (time) axis.

    A 1-D input is treated as a single channel and returned with shape
    ``(1, target_len)``. Inputs of higher rank keep all leading dimensions.
    The input is returned unchanged (same object) if it already has the
    target length.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    length = waveform.shape[-1]
    if length > target_len:
        # `...` always slices the time axis; `[:, :n]` would cut dim 1 for 3-D input.
        return waveform[..., :target_len]
    if length < target_len:
        return F.pad(waveform, (0, target_len - length))
    return waveform

def print_preprocessing_report(config, profile, img_size: int):
    """
    Print a short report of what the automatic audio preprocessing decided.

    Shows: target SR and clip length, duration / RMS / sample-rate
    distributions, min/median files per class, spectrogram shapes before and
    after resizing, and the SpecAugment configuration.

    Args:
        config: Chosen preprocessing config. Reads target_sr, clip_seconds,
            n_mels, hop_length, use_global_norm, use_spec_augment,
            time/freq_mask_param and num_time/freq_masks.
        profile: Dataset profile dict with keys "durations" and "rms_values"
            (numpy arrays), "sr_counts" (sample rate -> count),
            "files_per_class" (label -> count) and "num_files". The per-class
            numbers are whatever the profile recorded (possibly a subsample).
        img_size: Side length of the square image the spectrogram is resized to.
    """
    durations = profile["durations"]
    rms_vals = profile["rms_values"]
    sr_counts = profile["sr_counts"]
    files_per_class = profile["files_per_class"]
    num_files = profile["num_files"]

    # --- per-class stats ---
    class_counts = list(files_per_class.values()) if files_per_class else []
    num_classes = len(class_counts)
    min_per_class = min(class_counts) if class_counts else 0
    median_per_class = int(np.median(class_counts)) if class_counts else 0

    # --- duration stats ---
    if durations.size > 0:
        dur_mean = float(durations.mean())
        dur_p5, dur_p50, dur_p95 = (
            float(v) for v in np.percentile(durations, [5, 50, 95])
        )
    else:
        dur_mean = dur_p5 = dur_p50 = dur_p95 = 0.0

    # --- RMS stats ---
    if rms_vals.size > 0:
        rms_mean = float(rms_vals.mean())
        rms_std = float(rms_vals.std())
    else:
        rms_mean = rms_std = 0.0

    # --- SR distribution string (most common first) ---
    if sr_counts:
        sr_items = sorted(sr_counts.items(), key=lambda kv: -kv[1])
        sr_str = ", ".join(f"{sr}Hz({cnt})" for sr, cnt in sr_items)
    else:
        sr_str = "N/A"

    # --- spectrogram shape estimates ---
    # AudioPreprocessor builds torchaudio's MelSpectrogram with its default
    # center=True: the signal is padded by n_fft // 2 on both sides, so the
    # frame count is 1 + samples // hop_length (independent of win_length).
    target_samples = int(config.clip_seconds * config.target_sr)
    est_frames = max(1, 1 + target_samples // config.hop_length)
    spec_before = f"(1, {config.n_mels}, {est_frames})"
    spec_after = f"(1, {img_size}, {img_size})"

    # --- one-line summary ---
    summary = (
        f"[Audio Preproc] SR={config.target_sr} | "
        f"Clip={config.clip_seconds:.2f}s | "
        f"Mels={config.n_mels} | "
        f"GlobalNorm={'ON' if config.use_global_norm else 'OFF'} | "
        f"SpecAug={'ON' if config.use_spec_augment else 'OFF'} "
        f"(Tmask={config.time_mask_param}x{config.num_time_masks}, "
        f"Fmask={config.freq_mask_param}x{config.num_freq_masks}) | "
        f"Files={num_files}, Classes={num_classes}, Min/class={min_per_class}"
    )

    print(summary)
    print(
        f"  Durations (s): mean={dur_mean:.3f}, p5={dur_p5:.3f}, "
        f"p50={dur_p50:.3f}, p95={dur_p95:.3f}"
    )
    print(
        f"  RMS: mean={rms_mean:.5f}, std={rms_std:.5f} | "
        f"Files/class median={median_per_class}"
    )
    print(f"  Sample rates: {sr_str}")
    print(f"  Spectrogram shapes: raw={spec_before} -> resized={spec_after}")




# ============================================================
#   Automated Audio Preprocessing (classes + best practices)
# ============================================================

@dataclass
class AudioPreprocessConfig:
    """
    Configuration for audio preprocessing (waveform -> log-Mel image).

    The values below are defaults; AutoAudioPreprocessor picks target_sr,
    clip_seconds, use_global_norm and the SpecAugment fields from a
    dataset profile.
    """
    # Sample rate (Hz) every clip is resampled to.
    target_sr: int = 16000
    # Fixed clip length in seconds; clips are zero-padded or truncated to it.
    clip_seconds: float = 1.0
    # MelSpectrogram parameters (n_fft / hop_length / win_length in samples).
    n_mels: int = 64
    n_fft: int = 512
    hop_length: int = 160
    win_length: int = 400
    # True: normalize with dataset-wide mean/std learned by fit(); else per clip.
    use_global_norm: bool = True
    # SpecAugment settings: on/off, max mask widths (in spectrogram frames /
    # Mel bins, applied before resizing) and the number of masks of each kind.
    use_spec_augment: bool = True
    time_mask_param: int = 8
    freq_mask_param: int = 8
    num_time_masks: int = 1
    num_freq_masks: int = 1


class AudioPreprocessor:
    """
    Scikit-learn style audio front-end:
      - fit(dataset) learns global mean/std over log-Mel specs
      - transform(waveform, sr, img_size) -> normalized log-Mel image

    Per-clip pipeline: downmix to mono -> resample to ``config.target_sr`` ->
    pad/trim to ``config.clip_seconds`` -> Mel power spectrogram -> dB ->
    normalization -> optional SpecAugment -> bilinear resize to a square image.
    """
    def __init__(self, config: Optional[AudioPreprocessConfig] = None):
        self.config = config or AudioPreprocessConfig()
        # Mel and masking modules are created lazily by _ensure_transforms().
        self._mel = None
        self._to_db = AT.AmplitudeToDB(stype="power")
        # Set by fit(); while None, transform() normalizes each clip by itself.
        self.global_mean: Optional[float] = None
        self.global_std: Optional[float] = None
        self._time_mask = None
        self._freq_mask = None

    def _ensure_transforms(self):
        """Create MelSpectrogram and SpecAugment modules lazily."""
        cfg = self.config
        if self._mel is None:
            self._mel = AT.MelSpectrogram(
                sample_rate=cfg.target_sr,
                n_fft=cfg.n_fft,
                hop_length=cfg.hop_length,
                win_length=cfg.win_length,
                n_mels=cfg.n_mels,
            )
        if cfg.use_spec_augment and self._time_mask is None:
            self._time_mask = AT.TimeMasking(time_mask_param=cfg.time_mask_param)
            self._freq_mask = AT.FrequencyMasking(freq_mask_param=cfg.freq_mask_param)

    def _waveform_to_logmel(self, waveform: torch.Tensor, sr: int) -> torch.Tensor:
        """
        Convert a waveform to a log-Mel (dB) spectrogram without resizing.

        Accepts a 1-D signal or a (channels, time) tensor; multi-channel audio
        is averaged to mono. Returns a tensor of shape (1, n_mels, time_frames).
        """
        cfg = self.config
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        elif waveform.size(0) > 1:
            # Downmix so the output always has a single channel.
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != cfg.target_sr:
            waveform = torchaudio.functional.resample(waveform, sr, cfg.target_sr)
        target_len = int(cfg.clip_seconds * cfg.target_sr)
        waveform = pad_or_trim(waveform, target_len)

        self._ensure_transforms()
        spec = self._mel(waveform)      # (1, n_mels, time) power spectrogram
        return self._to_db(spec)        # log-mel (dB)

    def fit(self, dataset, max_items: int = 512) -> "AudioPreprocessor":
        """
        Learn global mean/std over log-Mel spectrograms of up to ``max_items``
        randomly chosen dataset items (uses the global ``random`` module).

        The dataset is expected to yield (waveform, sr, ...) items; only the
        first two fields are used. With no items, falls back to mean=0, std=1.
        """
        total = 0.0
        total_sq = 0.0
        count = 0

        indices = list(range(len(dataset)))
        random.shuffle(indices)

        with torch.no_grad():
            for i in indices[:max_items]:
                sample = dataset[i]
                # SpeechCommands returns: (waveform, sample_rate, label, speaker_id, utt_num)
                waveform, sr = sample[0], sample[1]
                # Accumulate in float64 to limit rounding in the sum of squares.
                v = self._waveform_to_logmel(waveform, sr).reshape(-1).double()
                total += float(v.sum())
                total_sq += float((v * v).sum())
                count += v.numel()

        if count > 0:
            mean = total / count
            # E[x^2] - E[x]^2 can come out slightly negative from rounding
            # (e.g. near-constant input); clamp so the sqrt stays real.
            var = max(0.0, total_sq / count - mean ** 2)
            self.global_mean = mean
            self.global_std = max(1e-6, var ** 0.5)
        else:
            self.global_mean, self.global_std = 0.0, 1.0

        return self

    def transform(self, waveform: torch.Tensor, sr: int, img_size: int, augment: bool = False) -> torch.Tensor:
        """
        Full preprocessing pipeline: waveform -> normalized log-Mel image.

        Args:
            waveform: 1-D signal or (channels, time) tensor.
            sr: Sample rate of ``waveform``.
            img_size: Output side length.
            augment: Apply SpecAugment (training only) when the config enables it.

        Returns:
            Tensor of shape (1, img_size, img_size).
        """
        cfg = self.config
        spec = self._waveform_to_logmel(waveform, sr)  # (1, n_mels, time)

        if cfg.use_global_norm and self.global_mean is not None:
            spec = (spec - self.global_mean) / (self.global_std + 1e-6)
        else:
            # Per-sample normalization fallback
            spec = (spec - spec.mean()) / (spec.std() + 1e-6)

        if augment and cfg.use_spec_augment:
            self._ensure_transforms()
            # Each call draws a fresh random mask, so repeated calls stack masks.
            for _ in range(cfg.num_time_masks):
                spec = self._time_mask(spec)
            for _ in range(cfg.num_freq_masks):
                spec = self._freq_mask(spec)

        # Resize to a square image for the CNN backbone: (1, H, W) -> (1, 1, H, W) -> back.
        spec = F.interpolate(spec.unsqueeze(0), size=(img_size, img_size),
                             mode="bilinear", align_corners=False)
        return spec.squeeze(0)


class AutoAudioPreprocessor:
    """
    Wrapper that:
    - profiles a random subsample of the dataset
    - chooses config values from that profile (on a copy of ``base_config``)
    - fits global normalization statistics
    - exposes train/test transforms compatible with the existing pipeline
    """
    def __init__(self, base_config: Optional[AudioPreprocessConfig] = None):
        self.base_config = base_config or AudioPreprocessConfig()
        self.config: Optional[AudioPreprocessConfig] = None
        self.preproc: Optional[AudioPreprocessor] = None
        self.profile: Optional[Dict[str, Any]] = None

    def _profile_dataset(self, dataset, max_items: int = 256) -> Dict[str, Any]:
        """
        Collect SR / duration / RMS / per-class statistics from up to
        ``max_items`` randomly chosen items (uses the global ``random`` module).

        "num_files" is the full dataset length. "num_sampled" is how many items
        the other statistics were computed from.
        """
        sr_counts: Dict[int, int] = {}
        durations = []
        rms_values = []
        files_per_class: Dict[str, int] = {}

        n = len(dataset)
        idxs = list(range(n))
        random.shuffle(idxs)
        idxs = idxs[:max_items]

        for i in idxs:
            waveform, sr, label, *_ = dataset[i]
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            sr_counts[sr] = sr_counts.get(sr, 0) + 1
            durations.append(float(waveform.shape[-1] / sr))
            rms_values.append(float(torch.sqrt((waveform ** 2).mean())))
            files_per_class[label] = files_per_class.get(label, 0) + 1

        return {
            "sr_counts": sr_counts,
            "durations": np.array(durations, dtype=np.float32),
            "rms_values": np.array(rms_values, dtype=np.float32),
            "files_per_class": files_per_class,
            "num_files": n,
            "num_sampled": len(idxs),
        }

    def _choose_config(self, img_size: int) -> AudioPreprocessConfig:
        """Derive a config from ``self.profile``; ``self.base_config`` is not modified."""
        from dataclasses import replace

        p = self.profile
        # Work on a copy so a caller-supplied base config is never mutated.
        cfg = replace(self.base_config)

        # Target sample rate: most common in the profile, fallback to 16 kHz.
        sr_counts = p["sr_counts"]
        cfg.target_sr = max(sr_counts, key=sr_counts.get) if sr_counts else 16000

        # Clip duration: 95th percentile of durations, clamped to [0.5, 5.0] s.
        if p["durations"].size > 0:
            p95 = float(np.percentile(p["durations"], 95))
            cfg.clip_seconds = float(max(0.5, min(5.0, p95)))
        else:
            cfg.clip_seconds = 1.0

        # Global normalization: good when you have enough data.
        cfg.use_global_norm = p["num_files"] > 500

        # Per-class counts come from a subsample; scale them to estimate the
        # full-dataset count before comparing against the threshold below.
        counts = p["files_per_class"]
        if counts:
            num_sampled = p.get("num_sampled") or p["num_files"]
            scale = p["num_files"] / max(1, num_sampled)
            min_per_class = min(counts.values()) * scale
        else:
            min_per_class = 0

        # Base mask size roughly proportional to image size; fewer examples
        # per class -> stronger augmentation (wider and more masks).
        base_mask = max(2, img_size // 10)
        strong = min_per_class < 50
        cfg.use_spec_augment = True
        cfg.time_mask_param = base_mask * 2 if strong else base_mask
        cfg.freq_mask_param = base_mask * 2 if strong else base_mask
        cfg.num_time_masks = 2 if strong else 1
        cfg.num_freq_masks = 2 if strong else 1

        self.config = cfg
        return cfg

    def fit(self, dataset, img_size: int) -> "AutoAudioPreprocessor":
        """Profile ``dataset``, choose a config for ``img_size`` and fit normalization."""
        # 1) Profile dataset
        self.profile = self._profile_dataset(dataset)
        # 2) Choose config based on profile + image size
        cfg = self._choose_config(img_size)
        # 3) Build preprocessor and fit its global normalization (if enabled)
        self.preproc = AudioPreprocessor(cfg)
        if cfg.use_global_norm:
            self.preproc.fit(dataset)
        return self

    def make_transforms(self, img_size: int):
        """
        Return ``(train_tf, test_tf)``, each callable as ``tf(waveform, sr)``.

        train_tf applies SpecAugment; test_tf does not. They are
        ``functools.partial`` objects over the fitted preprocessor's bound
        ``transform`` method. Unlike nested functions, these can be pickled,
        which DataLoader workers started with the "spawn" method require.
        """
        assert self.preproc is not None, "Call fit(...) before make_transforms()."
        from functools import partial

        train_tf = partial(self.preproc.transform, img_size=img_size, augment=True)
        test_tf = partial(self.preproc.transform, img_size=img_size, augment=False)
        return train_tf, test_tf


# ============================================================
#   Dataset wrapper + dataloaders (API unchanged)
# ============================================================

class SpeechCommandsWrapped(torch.utils.data.Dataset):
    """
    Adapter turning raw ``(waveform, sr, label, ...)`` items into ``(x, y)``.

    ``x = tf(waveform, sr)`` and ``y = label2idx[label]``. Any trailing
    fields (speaker id, utterance number, ...) are ignored.
    """
    def __init__(self, base_ds, label2idx, tf):
        self.ds = base_ds
        self.label2idx = label2idx
        self.tf = tf

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        audio, rate, name, *_extra = self.ds[idx]
        features = self.tf(audio, rate)
        return features, self.label2idx[name]


def _build_label_mapping(train_base):
    """
    Build a deterministic label <-> index mapping from a training dataset.

    Labels are sorted, so indices are stable across runs. If the dataset
    exposes ``get_metadata(n)``, labels are read from it and no audio has to
    be decoded. torchaudio's SPEECHCOMMANDS provides this method, returning
    the item's fields with a file path in place of the waveform. Otherwise
    each item is loaded, and the label is its third field.

    Returns:
        (label2idx, idx2label) dicts.
    """
    get_metadata = getattr(train_base, "get_metadata", None)
    if callable(get_metadata):
        fetch = get_metadata
    else:
        def fetch(i):
            return train_base[i]

    labels = set()
    for i in range(len(train_base)):
        _, _, label, *_ = fetch(i)
        labels.add(label)

    ordered = sorted(labels)
    label2idx = {lab: i for i, lab in enumerate(ordered)}
    idx2label = dict(enumerate(ordered))
    return label2idx, idx2label


def get_dataloaders(img_size: int):
    """
    Build train/val/test DataLoaders for torchaudio SPEECHCOMMANDS.

    Public API is unchanged:
        train_loader, val_loader, test_loader = get_dataloaders(img_size)

    Internally this:
      - builds the label mapping from the training split,
      - profiles the training audio and auto-configures preprocessing,
      - learns global log-Mel normalization (when enabled) and prints a report,
      - wraps each split so items are (transformed spectrogram, label index).

    Side effects: sets the module globals NUM_CLASSES, TARGET_SR, CLIP_SECONDS
    and TARGET_SAMPLES. Datasets are created with download=True under DATA_ROOT.
    Only the training loader shuffles (and applies augmentation).
    """
    global NUM_CLASSES, TARGET_SR, CLIP_SECONDS, TARGET_SAMPLES

    # Built-in SpeechCommands splits
    train_base = torchaudio.datasets.SPEECHCOMMANDS(
        root=DATA_ROOT, download=True, subset="training"
    )
    val_base = torchaudio.datasets.SPEECHCOMMANDS(
        root=DATA_ROOT, download=True, subset="validation"
    )
    test_base = torchaudio.datasets.SPEECHCOMMANDS(
        root=DATA_ROOT, download=True, subset="testing"
    )

    # Label mapping based on training subset
    label2idx, idx2label = _build_label_mapping(train_base)
    NUM_CLASSES = len(label2idx)

    # Automated audio preprocessing
    auto_preproc = AutoAudioPreprocessor()
    auto_preproc.fit(train_base, img_size=img_size)
    # Print what the automatic preprocessing decided.
    print_preprocessing_report(auto_preproc.config, auto_preproc.profile, img_size)

    train_tf, test_tf = auto_preproc.make_transforms(img_size=img_size)

    # Expose inferred low-level parameters if you still need them anywhere
    TARGET_SR = auto_preproc.config.target_sr
    CLIP_SECONDS = auto_preproc.config.clip_seconds
    TARGET_SAMPLES = int(TARGET_SR * CLIP_SECONDS)

    # Wrap datasets: augmentation only for training
    train_set = SpeechCommandsWrapped(train_base, label2idx, train_tf)
    val_set = SpeechCommandsWrapped(val_base, label2idx, test_tf)
    test_set = SpeechCommandsWrapped(test_base, label2idx, test_tf)

    # Pinned host memory only helps when batches are copied to a GPU; without
    # CUDA it is useless (and recent PyTorch versions warn about it).
    pin = torch.cuda.is_available()
    g = torch.Generator().manual_seed(SEED)

    def _loader(ds, shuffle):
        # One seeded generator shared by all loaders, as before.
        return DataLoader(
            ds, batch_size=BATCH_SIZE, shuffle=shuffle,
            num_workers=NUM_WORKERS, pin_memory=pin,
            worker_init_fn=_seed_worker, generator=g,
        )

    train_loader = _loader(train_set, shuffle=True)
    val_loader = _loader(val_set, shuffle=False)
    test_loader = _loader(test_set, shuffle=False)
    return train_loader, val_loader, test_loader


In [None]:
!pip install -q torchcodec

import time
import torch
import torch.nn as nn
import torch.optim as optim
import os
import math
import torch.nn.functional as F

os.makedirs(DATA_ROOT, exist_ok=True)

# Use the GPU when available; the training/eval loops enable fp16 autocast only on CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

if device.type == "cuda":
    # Let cuDNN benchmark conv algorithms for the fixed input size
    # (usually faster; may reduce run-to-run determinism).
    torch.backends.cudnn.benchmark = True

# Sweep configs
# Each entry: run name, backbone key for build_backbone(), number of trailing
# blocks to train (999 = train everything), square input size, peak learning rate.
SWEEP = [
    {"name": "dscnn_plus_full_res64",         "backbone": "dscnn_plus",         "unfreeze_blocks": 999, "img_size": 64, "lr": 1.2e-3},
    {"name": "dscnn_plus_last3_res64",        "backbone": "dscnn_plus",         "unfreeze_blocks": 3,   "img_size": 64, "lr": 1.5e-3},
    {"name": "mbv2_tiny_full_res64",          "backbone": "mobilenetv2_tiny",   "unfreeze_blocks": 999, "img_size": 64, "lr": 1.0e-3},
    {"name": "mbv2_small_full_res64",         "backbone": "mobilenetv2_small",  "unfreeze_blocks": 999, "img_size": 64, "lr": 9.0e-4},
]

# Epochs per sweep trial, and for the final run with the best config.
EPOCHS_PER_TRIAL = 3
FINAL_EPOCHS = 10

# Regularization: label-smoothing factor and mixup Beta(alpha, alpha) parameter.
LABEL_SMOOTHING = 0.08
MIXUP_ALPHA = 0.35
USE_MIXUP = True


# Blocks / Models
class DSConvBlock(nn.Module):
    """
    Depthwise-separable conv block:
    3x3 depthwise conv (with ``stride``) -> 1x1 pointwise conv -> BatchNorm
    -> ReLU -> Dropout2d (an identity when ``dropout`` is not positive).
    """
    def __init__(self, in_ch, out_ch, stride=1, dropout=0.0):
        super().__init__()
        # Depthwise: one 3x3 filter per input channel; padding=1 keeps the size at stride 1.
        self.dw = nn.Conv2d(
            in_ch, in_ch,
            kernel_size=3, stride=stride, padding=1,
            groups=in_ch, bias=False,
        )
        # Pointwise 1x1 conv mixes channels; no bias because BatchNorm follows.
        self.pw = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)
        self.drop = nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x):
        return self.drop(self.act(self.bn(self.pw(self.dw(x)))))


class DSCNNPlus(nn.Module):
    """
    Depthwise-separable CNN classifier for single-channel spectrogram images.

    Strided 3x3 stem (1 -> 48 ch) -> five DSConvBlocks
    (48->96, 96->192 /2, 192->192, 192->256 /2, 256->256) -> global average
    pool -> dropout -> linear head with ``num_classes`` outputs.
    """
    def __init__(self, num_classes: int, drop=0.15):
        super().__init__()
        # (in_ch, out_ch, stride) of each depthwise-separable block, in order.
        spec = (
            (48, 96, 1),
            (96, 192, 2),
            (192, 192, 1),
            (192, 256, 2),
            (256, 256, 1),
        )
        stem_ch = spec[0][0]
        feat_ch = spec[-1][1]

        self.stem = nn.Sequential(
            nn.Conv2d(1, stem_ch, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_ch),
            nn.ReLU(inplace=True),
        )
        self.blocks = nn.ModuleList(
            DSConvBlock(c_in, c_out, stride=s, dropout=drop) for c_in, c_out, s in spec
        )
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.head_drop = nn.Dropout(drop)
        self.head = nn.Linear(feat_ch, num_classes)

    def forward(self, x):
        h = self.stem(x)
        for block in self.blocks:
            h = block(h)
        h = torch.flatten(self.pool(h), 1)
        return self.head(self.head_drop(h))


def _make_divisible(v, divisor=8):
    """Round ``v`` up to the next multiple of ``divisor`` and return an int (e.g. 25 -> 32)."""
    multiples = math.ceil(v / divisor)
    return int(multiples * divisor)


class InvertedResidual(nn.Module):
    """
    MobileNetV2-style inverted residual block.

    Optional 1x1 expansion (when ``expand_ratio != 1``) -> 3x3 depthwise conv
    with ``stride`` -> linear 1x1 projection to ``out_ch`` (BatchNorm, no
    activation) -> optional Dropout2d. The input is added back only when the
    shape is preserved (stride 1 and ``in_ch == out_ch``).
    """
    def __init__(self, in_ch, out_ch, stride, expand_ratio, drop=0.0):
        super().__init__()
        assert stride in [1, 2]
        hidden = int(round(in_ch * expand_ratio))
        self.use_res = (stride == 1 and in_ch == out_ch)

        # 1x1 expansion; skipped when expand_ratio == 1 (then hidden == in_ch).
        expand = (
            [
                nn.Conv2d(in_ch, hidden, 1, bias=False),
                nn.BatchNorm2d(hidden),
                nn.ReLU6(inplace=True),
            ]
            if expand_ratio != 1
            else []
        )
        depthwise = [
            nn.Conv2d(hidden, hidden, 3, stride=stride, padding=1, groups=hidden, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU6(inplace=True),
        ]
        project = [
            nn.Conv2d(hidden, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
        ]
        self.block = nn.Sequential(*expand, *depthwise, *project)
        self.drop = nn.Dropout2d(drop) if drop > 0 else nn.Identity()

    def forward(self, x):
        branch = self.drop(self.block(x))
        return x + branch if self.use_res else branch


class MobileNetV2Spec(nn.Module):
    """
    Width-scaled MobileNetV2-style classifier for single-channel spectrogram images.

    Strided 3x3 stem -> inverted-residual stages -> 1x1 conv + BN + ReLU6 ->
    global average pool -> dropout -> linear head. Every channel count is
    multiplied by ``width_mult`` and passed through ``_make_divisible``.
    """
    def __init__(self, num_classes: int, width_mult=0.75, drop=0.10):
        super().__init__()
        # (expand_ratio, base_channels, repeats, stride of the first repeat)
        stages = (
            (1, 16, 1, 1),
            (6, 24, 2, 2),
            (6, 32, 3, 2),
            (6, 64, 3, 2),
            (6, 96, 2, 1),
            (6, 160, 1, 2),
        )

        def scaled(channels):
            return _make_divisible(channels * width_mult)

        ch = scaled(32)
        self.stem = nn.Sequential(
            nn.Conv2d(1, ch, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ch),
            nn.ReLU6(inplace=True),
        )

        layers = []
        for expand, base, repeats, first_stride in stages:
            target = scaled(base)
            for k in range(repeats):
                layers.append(
                    InvertedResidual(
                        ch, target,
                        stride=first_stride if k == 0 else 1,
                        expand_ratio=expand, drop=drop,
                    )
                )
                ch = target
        self.blocks = nn.ModuleList(layers)

        last_ch = scaled(192)
        self.last = nn.Sequential(
            nn.Conv2d(ch, last_ch, 1, bias=False),
            nn.BatchNorm2d(last_ch),
            nn.ReLU6(inplace=True),
        )

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.head_drop = nn.Dropout(drop)
        self.head = nn.Linear(last_ch, num_classes)

    def forward(self, x):
        h = self.stem(x)
        for block in self.blocks:
            h = block(h)
        h = torch.flatten(self.pool(self.last(h)), 1)
        return self.head(self.head_drop(h))


def build_backbone(backbone: str, num_classes: int):
    """
    Instantiate a fresh (randomly initialized) classifier by name.

    Supported names: "dscnn_plus", "mobilenetv2_tiny", "mobilenetv2_small".
    Raises ValueError for anything else.
    """
    # name -> (model class, constructor kwargs besides num_classes)
    registry = {
        "dscnn_plus": (DSCNNPlus, {"drop": 0.15}),
        # Very close to DSCNN speed; slightly different capacity/bias
        "mobilenetv2_tiny": (MobileNetV2Spec, {"width_mult": 0.50, "drop": 0.10}),
        # Still fast, generally stronger than tiny
        "mobilenetv2_small": (MobileNetV2Spec, {"width_mult": 0.75, "drop": 0.10}),
    }
    try:
        model_cls, kwargs = registry[backbone]
    except KeyError:
        raise ValueError(f"Unknown backbone: {backbone}") from None
    return model_cls(num_classes=num_classes, **kwargs)

def unfreeze_module(m: nn.Module):
    """Enable gradients for every parameter of ``m`` (recursively). Returns None."""
    m.requires_grad_(True)

def freeze_module(m: nn.Module):
    """Disable gradients for every parameter of ``m`` (recursively). Returns None."""
    m.requires_grad_(False)

def get_head_module(model: nn.Module) -> nn.Module:
    """Return ``model.head``; raise ValueError if the model has no ``head`` attribute."""
    try:
        return model.head
    except AttributeError:
        raise ValueError("Could not find classifier head (expected .head)") from None

def get_block_list(backbone: str, model: nn.Module):
    """
    Return ``model.blocks`` as a plain list.

    ``backbone`` is only used in the error message. Raises ValueError if
    the model has no ``blocks`` attribute.
    """
    try:
        block_container = model.blocks
    except AttributeError:
        raise ValueError(f"Unsupported backbone for block unfreezing: {backbone}") from None
    return list(block_container)

def freeze_all_but_head(model: nn.Module):
    """Freeze every parameter of ``model``, then make only its ``.head`` trainable again."""
    freeze_module(model)
    head = get_head_module(model)
    unfreeze_module(head)

def unfreeze_last_n_blocks(model: nn.Module, backbone: str, n_blocks: int):
    """
    Configure which parameters are trainable for partial fine-tuning.

    - ``n_blocks >= 999`` (sentinel) or ``>= len(model.blocks)``: the whole
      model is unfrozen.
    - Otherwise everything is frozen except ``model.head`` and the last
      ``n_blocks`` entries of ``model.blocks``. ``n_blocks <= 0`` leaves only
      the head trainable. Layers outside ``model.blocks`` (e.g. the stem)
      stay frozen in this case.
    """
    blocks = get_block_list(backbone, model)
    if n_blocks >= 999 or n_blocks >= len(blocks):
        unfreeze_module(model)
        return
    freeze_all_but_head(model)
    if n_blocks <= 0:
        # blocks[-0:] would select *every* block, so handle this explicitly.
        return
    for b in blocks[-n_blocks:]:
        unfreeze_module(b)

def smooth_one_hot(y: torch.Tensor, num_classes: int, smoothing: float):
    """
    Float32 one-hot targets for integer labels ``y`` with optional label smoothing.

    With smoothing ``s > 0`` each row becomes ``(1 - s) * one_hot + s / num_classes``,
    so rows still sum to 1. The result does not require grad.
    """
    targets = F.one_hot(y, num_classes).to(torch.float32)
    if smoothing > 0:
        targets = targets * (1.0 - smoothing) + smoothing / num_classes
    return targets

def mixup_batch(x, y, alpha: float, num_classes: int, smoothing: float):
    """
    Per-sample mixup with label-smoothed soft targets.

    Each sample i is blended with a randomly permuted partner using
    ``lam_i ~ Beta(alpha, alpha)`` folded to ``max(lam, 1 - lam)``. The
    original sample therefore always keeps a weight of at least 0.5. Targets
    are blended with the same weights. With ``alpha <= 0`` no mixing happens
    and only smoothed targets are returned.

    Returns:
        (mixed inputs, soft targets of shape (batch, num_classes))
    """
    if alpha <= 0:
        return x, smooth_one_hot(y, num_classes, smoothing)

    batch = x.size(0)
    lam = torch.distributions.Beta(alpha, alpha).sample((batch,)).to(x.device)
    lam = torch.maximum(lam, 1.0 - lam)
    perm = torch.randperm(batch, device=x.device)

    targets_a = smooth_one_hot(y, num_classes, smoothing)
    targets_b = smooth_one_hot(y[perm], num_classes, smoothing)

    w_x = lam.view(-1, 1, 1, 1)  # broadcast over (C, H, W)
    w_y = lam.view(-1, 1)        # broadcast over classes
    mixed_x = x * w_x + x[perm] * (1.0 - w_x)
    mixed_y = targets_a * w_y + targets_b * (1.0 - w_y)
    return mixed_x, mixed_y

def soft_target_ce(logits: torch.Tensor, soft_targets: torch.Tensor):
    """Cross-entropy against soft (probability) targets along dim 1, averaged over the batch."""
    per_sample = torch.sum(soft_targets * F.log_softmax(logits, dim=1), dim=1)
    return -per_sample.mean()


# Optimizer / scheduler
def make_optimizer(model, lr: float):
    """AdamW (weight_decay=1e-4) over only the parameters that currently require grad."""
    trainable = list(filter(lambda p: p.requires_grad, model.parameters()))
    return optim.AdamW(trainable, lr=lr, weight_decay=1e-4)

def make_scheduler(opt, total_steps: int, warmup_steps: int):
    """
    Per-step LambdaLR: linear warmup, then cosine decay to zero.

    The LR multiplier rises linearly from 0 to 1 over ``warmup_steps`` steps,
    then follows ``0.5 * (1 + cos(pi * progress))``. ``progress`` is clamped
    to [0, 1], so stepping past ``total_steps`` keeps the LR at 0 instead of
    letting the cosine climb back up.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        progress = min(1.0, max(0.0, progress))
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return optim.lr_scheduler.LambdaLR(opt, lr_lambda)

def count_params(model: nn.Module):
    """Return ``(total parameter count, trainable parameter count)``."""
    total = trainable = 0
    for param in model.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return total, trainable


# Train / eval utils
@torch.no_grad()
def evaluate(model, loader):
    """
    Evaluate ``model`` on ``loader`` without gradients.

    The loss is the same label-smoothed soft-target cross-entropy used in
    training (smoothing = LABEL_SMOOTHING), not plain CE. Runs under fp16
    autocast when ``device`` is CUDA and leaves the model in eval mode.

    Returns:
        (mean loss per sample, top-1 accuracy); (0.0, 0.0) for an empty loader.
    """
    model.eval()
    amp_enabled = (device.type == "cuda")
    seen = 0
    hits = 0
    loss_sum = 0.0

    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=amp_enabled):
            logits = model(inputs)
            soft = smooth_one_hot(targets, NUM_CLASSES, LABEL_SMOOTHING)
            batch_loss = soft_target_ce(logits, soft)

        n = inputs.size(0)
        loss_sum += batch_loss.item() * n
        hits += (logits.argmax(dim=1) == targets).sum().item()
        seen += n

    denom = max(1, seen)
    return loss_sum / denom, hits / denom


def train_one_epoch(model, loader, optimizer, scheduler, scaler=None):
    """
    Train ``model`` for one pass over ``loader``.

    Applies mixup (when USE_MIXUP) with label-smoothed soft targets. Uses
    fp16 autocast with gradient scaling on CUDA, and steps ``scheduler``
    (if given) once per batch.

    Args:
        scaler: GradScaler for the CUDA path. If None on CUDA, a fresh scaler
            is created for this call, so loss-scale state does not carry over
            between epochs. Unused on CPU.

    Returns:
        (mean loss per sample, accuracy). Accuracy compares predictions with
        the un-mixed labels ``y``. Under mixup these are the dominant
        component of each sample (mixup_batch keeps their weight >= 0.5).
    """
    model.train()
    use_amp = (device.type == "cuda")
    if use_amp and scaler is None:
        # Never run fp16 backward unscaled: small gradients could underflow.
        scaler = torch.amp.GradScaler(enabled=True)

    correct, total = 0, 0
    total_loss = 0.0

    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        if USE_MIXUP:
            x, y_soft = mixup_batch(x, y, MIXUP_ALPHA, NUM_CLASSES, LABEL_SMOOTHING)
        else:
            y_soft = smooth_one_hot(y, NUM_CLASSES, LABEL_SMOOTHING)

        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
            logits = model(x)
            loss = soft_target_ce(logits, y_soft)

        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        batch = x.size(0)
        total_loss += loss.item() * batch
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += batch

    return total_loss / max(1, total), correct / max(1, total)


# ---- Hyperparameter sweep ----
# Each trial trains a fresh model for EPOCHS_PER_TRIAL epochs and is ranked by
# the validation accuracy of its LAST epoch.
# NOTE: get_dataloaders() runs per trial, so the audio is re-profiled and the
# preprocessing re-fitted every time (this also advances the global `random`
# state used for that subsampling).
# NOTE(review): build_backbone() returns a freshly initialized model and no
# pretrained weights are loaded here, so with unfreeze_blocks below the number
# of blocks the frozen layers stay at their random initialization.
best_cfg = None
best_val_acc = -1.0  # lower than any accuracy, so the first trial always becomes the baseline

for cfg in SWEEP:
    t0 = time.time()

    train_loader, val_loader, test_loader = get_dataloaders(cfg["img_size"])

    model = build_backbone(cfg["backbone"], num_classes=NUM_CLASSES).to(device)
    unfreeze_last_n_blocks(model, backbone=cfg["backbone"], n_blocks=cfg["unfreeze_blocks"])

    # Optimizer only sees parameters that are trainable after (un)freezing.
    opt = make_optimizer(model, lr=cfg["lr"])

    # Per-batch schedule: warmup for 5% of the trial's steps (at least 20), then cosine decay.
    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * EPOCHS_PER_TRIAL
    warmup_steps = max(20, int(0.05 * total_steps))
    sched = make_scheduler(opt, total_steps=total_steps, warmup_steps=warmup_steps)

    scaler = torch.amp.GradScaler(enabled=(device.type == "cuda"))

    for epoch in range(1, EPOCHS_PER_TRIAL + 1):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, opt, sched, scaler=scaler)
        va_loss, va_acc = evaluate(model, val_loader)
        print(f"[{cfg['name']}] epoch {epoch}/{EPOCHS_PER_TRIAL} | tr_acc={tr_acc:.4f} | va_acc={va_acc:.4f}")

    elapsed = time.time() - t0
    total_p, train_p = count_params(model)
    print(f"[{cfg['name']}] done in {elapsed:.1f}s | params={total_p/1e6:.2f}M (trainable {train_p/1e6:.2f}M) | val_acc={va_acc:.4f}\n")

    # va_acc here is the last epoch's validation accuracy.
    if va_acc > best_val_acc:
        best_val_acc = va_acc
        best_cfg = cfg

print("Best cfg:", best_cfg)
print("Best val acc:", best_val_acc)

# Final train with best cfg
# Trains for FINAL_EPOCHS, keeps a CPU copy of the weights from the epoch with
# the best validation accuracy, restores them, and evaluates on test once.
train_loader, val_loader, test_loader = get_dataloaders(best_cfg["img_size"])
best_model = build_backbone(best_cfg["backbone"], num_classes=NUM_CLASSES).to(device)
unfreeze_last_n_blocks(best_model, backbone=best_cfg["backbone"], n_blocks=best_cfg["unfreeze_blocks"])

optimizer = make_optimizer(best_model, lr=best_cfg["lr"])

# Longer warmup than the sweep: 8% of all steps, at least 50.
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * FINAL_EPOCHS
warmup_steps = max(50, int(0.08 * total_steps))
scheduler = make_scheduler(optimizer, total_steps=total_steps, warmup_steps=warmup_steps)

scaler = torch.amp.GradScaler(enabled=(device.type == "cuda"))

best_val = 0.0
best_state = None  # CPU snapshot of the best-validation weights

for epoch in range(1, FINAL_EPOCHS + 1):
    tr_loss, tr_acc = train_one_epoch(best_model, train_loader, optimizer, scheduler, scaler=scaler)
    va_loss, va_acc = evaluate(best_model, val_loader)

    print(f"[FINAL] epoch {epoch}/{FINAL_EPOCHS} | tr_acc={tr_acc:.4f} | va_acc={va_acc:.4f}")

    if va_acc > best_val:
        best_val = va_acc
        # Clone to CPU so later training steps don't overwrite the snapshot.
        best_state = {k: v.detach().cpu().clone() for k, v in best_model.state_dict().items()}

if best_state is not None:
    best_model.load_state_dict(best_state)

test_loss, final_test_acc = evaluate(best_model, test_loader)
print(f"Final test accuracy: {final_test_acc:.4f}")

In [None]:
import os, copy
import torch
import torch.nn as nn

@torch.no_grad()
def _evaluate_acc(model, loader, device):
    """
    Top-1 accuracy of ``model`` over ``loader`` (batches of (inputs, labels)).

    Inputs are cast to fp16 when the model's first parameter is fp16.
    Puts the model in eval mode. Returns 0.0 for an empty loader.
    """
    model.eval()

    first_param = next(model.parameters(), None)
    to_half = first_param is not None and first_param.dtype == torch.float16

    hits = seen = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        if to_half:
            inputs = inputs.half()
        predictions = model(inputs).argmax(dim=1)
        hits += (predictions == labels).sum().item()
        seen += labels.numel()
    return float(hits / max(1, seen))

def _state_dict_size_mb(model, tmp_path="__tmp_size.pth"):
    """
    Size of ``model.state_dict()`` serialized with ``torch.save``, in MiB.

    Serializes into an in-memory buffer, so nothing is written to disk.
    Concurrent callers therefore cannot clobber a shared temp file, and a
    read-only working directory does not matter.

    Args:
        model: Module whose state dict is measured.
        tmp_path: Accepted for backward compatibility; no longer used.
    """
    import io

    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    return buf.getbuffer().nbytes / (1024 * 1024)

@torch.no_grad()
def _probe_forward(model, loader, device):
    """
    Check that ``model`` can run a forward pass on one sample from ``loader``.

    Uses the first item of the first batch, cast to fp16 if the model's first
    parameter is fp16. Deliberately catches any exception, including an
    empty loader.

    Returns:
        (True, "") on success, otherwise (False, "<ExceptionType>: <message>").
    """
    model.eval()
    try:
        first_batch, _ = next(iter(loader))
        sample = first_batch[:1].to(device)

        # Match dtype if fp16 model
        param = next(model.parameters(), None)
        if param is not None and param.dtype == torch.float16:
            sample = sample.half()

        model(sample)
    except Exception as exc:
        return False, f"{type(exc).__name__}: {exc}"
    return True, ""

def quantize_model(best_model, device, test_loader, max_abs_drop=0.01, min_size_gain=0.05):
    """
    Choose between the FP32 model and an FP16 copy (FP16 is only tried on CUDA).

    The FP16 copy is selected only if its accuracy on ``test_loader`` drops by
    at most ``max_abs_drop`` (absolute) and its serialized state dict shrinks
    by at least ``min_size_gain`` (fraction of the FP32 size). No integer
    quantization is performed.

    Note: ``best_model`` itself is moved to ``device`` and put in eval mode.
    The FP16 candidate is a deep copy.
    NOTE(review): selection uses accuracy on the loader passed in. If that is
    the test set, the chosen model's test accuracy is no longer an untouched
    estimate.

    Returns:
        (selected_model, report): report has "fp32", "fp16" and "selected"
        entries. Raises RuntimeError if the FP32 model cannot run a forward pass.
    """
    report = {"fp32": {}, "fp16": {}, "selected": None}

    # ---- baseline fp32 ----
    fp32 = best_model.to(device).eval()
    ok, reason = _probe_forward(fp32, test_loader, device)
    if not ok:
        raise RuntimeError(f"FP32 probe failed (unexpected): {reason}")

    fp32_acc = _evaluate_acc(fp32, test_loader, device)
    fp32_size = _state_dict_size_mb(fp32)
    report["fp32"] = {"supported": True, "acc": fp32_acc, "size_mb": fp32_size}

    # ---- candidate fp16 ----
    fp16_supported = False
    fp16_acc = None
    fp16_size = None
    fp16_model = None
    fp16_reason = ""

    if device.type == "cuda":
        fp16_model = copy.deepcopy(fp32).to(device).half().eval()
        fp16_supported, fp16_reason = _probe_forward(fp16_model, test_loader, device)
        if fp16_supported:
            fp16_acc = _evaluate_acc(fp16_model, test_loader, device)
            fp16_size = _state_dict_size_mb(fp16_model)

    report["fp16"] = {
        "supported": bool(fp16_supported),
        "reason": "" if fp16_supported else fp16_reason,
        "acc": fp16_acc,
        "size_mb": fp16_size,
    }

    # ---- selection (reuses the numbers measured above; no re-evaluation) ----
    selected_name, selected_model = "fp32", fp32
    chosen_acc, chosen_size = fp32_acc, fp32_size

    if fp16_supported:
        abs_drop = fp32_acc - fp16_acc
        size_gain = (fp32_size - fp16_size) / max(1e-9, fp32_size)
        if abs_drop <= max_abs_drop and size_gain >= min_size_gain:
            selected_name, selected_model = "fp16", fp16_model
            chosen_acc, chosen_size = fp16_acc, fp16_size

    report["selected"] = {
        "name": selected_name,
        "baseline_acc": fp32_acc,
        "baseline_size_mb": fp32_size,
        "chosen_acc": chosen_acc,
        "chosen_size_mb": chosen_size,
        "max_abs_drop": float(max_abs_drop),
        "min_size_gain": float(min_size_gain),
    }

    return selected_model, report


In [None]:
# Compare the trained FP32 model with an FP16 copy (tried only on CUDA) and keep
# FP16 only if test accuracy drops by <= 0.01 (1 percentage point) and the
# serialized state dict shrinks by >= 5%.
optimized_best_model, opt_report = quantize_model(
    best_model=best_model,
    device=device,
    test_loader=test_loader,
    max_abs_drop=0.01,
    min_size_gain=0.05
)

print("Optimization report:", opt_report)


In [None]:
import os
import torch

# Save the best model's weights (quantize_model only converts a deep copy,
# so best_model is still FP32) and report the file size.
# `final_test_acc` comes from the training cell.
save_path = "best_model.pth"
torch.save(best_model.state_dict(), save_path)

size_mb = os.path.getsize(save_path) / (1024 * 1024)

print("Saved:", save_path)
print(f"Best model test accuracy: {final_test_acc:.4f}")
print(f"Saved .pth size: {size_mb:.2f} MB")

In [None]:
import os
import torch

# Re-evaluate the model selected by quantize_model on the test set, save its
# weights, and compare with the FP32 test accuracy from training.
optimized_test_acc = _evaluate_acc(optimized_best_model, test_loader, device)

save_path = "optimized_best_model.pth"
torch.save(optimized_best_model.state_dict(), save_path)
size_mb = os.path.getsize(save_path) / (1024 * 1024)

print("Saved:", save_path)
print(f"Optimized selected={opt_report['selected']['name']} test accuracy: {optimized_test_acc:.4f}")
print(f"Saved .pth size: {size_mb:.2f} MB")

# Only compare if the training cell defined final_test_acc in this kernel.
if "final_test_acc" in globals():
    delta = optimized_test_acc - final_test_acc
    print(f"Accuracy vs FP32 best: {delta:+.4f}")
