In [None]:
import random, numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

import torchaudio
import torchaudio.transforms as AT

from dataclasses import dataclass
from typing import Optional, Dict, Any

# Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device) as soon
# as this module/cell is executed. SEED is also used by get_dataloaders for the
# DataLoader generator.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Dataset config
# DATA_ROOT is passed as `root` to the SPEECHCOMMANDS datasets (with
# download=True); BATCH_SIZE and NUM_WORKERS are passed to all three DataLoaders.
DATA_ROOT = "./data"
BATCH_SIZE = 128
NUM_WORKERS = 2

# These will be inferred automatically
# (get_dataloaders assigns them from the fitted preprocessing config;
# TARGET_SAMPLES is int(TARGET_SR * CLIP_SECONDS)).
TARGET_SR: Optional[int] = None
CLIP_SECONDS: Optional[float] = None
TARGET_SAMPLES: Optional[int] = None

# Will be filled in get_dataloaders and used by the models
# (set to the number of distinct labels found in the training subset; the
# consuming models are not visible in this chunk).
NUM_CLASSES: Optional[int] = None


def _seed_worker(worker_id: int):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    Intended to be passed as ``worker_init_fn``. The seed is derived from
    ``torch.initial_seed()``, which inside a worker is the loader's per-iterator
    base seed plus the worker id. This follows the recipe from the PyTorch
    reproducibility notes: runs stay reproducible (the base seed comes from the
    loader's seeded ``generator``), while each worker, each epoch and each
    loader gets a different NumPy/``random`` stream. A fixed
    ``SEED + worker_id`` would replay the exact same stream every epoch.

    Args:
        worker_id: Index of the worker. Unused directly because it is already
            folded into ``torch.initial_seed()``; kept because DataLoader
            calls ``worker_init_fn(worker_id)``.
    """
    # NumPy only accepts 32-bit seeds, hence the modulo.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def pad_or_trim(waveform: torch.Tensor, target_len: int) -> torch.Tensor:
    """Force the last (time) axis of ``waveform`` to exactly ``target_len``.

    A 1-D input is first promoted to shape ``(1, T)``. Longer inputs are
    truncated from the end; shorter inputs are right-padded with zeros.

    Args:
        waveform: Tensor whose last axis is time.
        target_len: Desired number of samples along the last axis.

    Returns:
        A tensor with the same leading dimensions as the (possibly promoted)
        input and ``target_len`` samples along the last axis. When no change
        is needed the input tensor itself is returned.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    num_samples = waveform.shape[-1]
    if num_samples == target_len:
        return waveform
    if num_samples > target_len:
        # Slice the *last* axis (not axis 1) so inputs with extra leading
        # dimensions, e.g. (batch, channel, time), are trimmed along time,
        # matching the axis that F.pad extends below.
        return waveform[..., :target_len]
    return F.pad(waveform, (0, target_len - num_samples))

def print_preprocessing_report(config, profile, img_size: int):
    """Print a compact summary of the preprocessing config and dataset profile.

    Args:
        config: Object exposing the ``AudioPreprocessConfig`` attributes read
            below (``target_sr``, ``clip_seconds``, ``n_mels``, ``hop_length``,
            the normalization/SpecAugment flags and the mask settings).
        profile: Mapping with keys ``"durations"`` and ``"rms_values"``
            (NumPy arrays), ``"sr_counts"`` and ``"files_per_class"`` (dicts),
            and ``"num_files"``.
        img_size: Side length of the square image the spectrogram is resized to.

    Returns:
        None. The report is written to stdout with ``print``.
    """
    durations = profile["durations"]
    rms_vals = profile["rms_values"]
    sr_counts = profile["sr_counts"]
    files_per_class = profile["files_per_class"]

    # --- basic dataset stats ---
    num_files = profile["num_files"]
    class_counts = list(files_per_class.values()) if files_per_class else []
    num_classes = len(class_counts)
    min_per_class = min(class_counts) if class_counts else 0
    median_per_class = int(np.median(class_counts)) if class_counts else 0

    # --- duration stats ---
    if durations.size > 0:
        dur_mean = float(durations.mean())
        dur_p5, dur_p50, dur_p95 = (
            float(np.percentile(durations, q)) for q in (5, 50, 95)
        )
    else:
        dur_mean = dur_p5 = dur_p50 = dur_p95 = 0.0

    # --- RMS stats ---
    if rms_vals.size > 0:
        rms_mean = float(rms_vals.mean())
        rms_std = float(rms_vals.std())
    else:
        rms_mean = rms_std = 0.0

    # --- SR distribution string (most frequent first) ---
    if sr_counts:
        sr_items = sorted(sr_counts.items(), key=lambda kv: -kv[1])
        sr_str = ", ".join(f"{sr}Hz({cnt})" for sr, cnt in sr_items)
    else:
        sr_str = "N/A"

    # --- spectrogram shape estimates ---
    # AudioPreprocessor builds MelSpectrogram with torchaudio's default
    # center=True, where the STFT yields 1 + num_samples // hop_length frames.
    # The non-centered formula (which subtracts win_length) under-reports the
    # frame count that is actually produced.
    target_samples = int(config.clip_seconds * config.target_sr)
    est_frames = 1 + target_samples // config.hop_length
    spec_before = f"(1, {config.n_mels}, {est_frames})"
    spec_after = f"(1, {img_size}, {img_size})"

    # --- one-line summary ---
    summary = (
        f"[Audio Preproc] SR={config.target_sr} | "
        f"Clip={config.clip_seconds:.2f}s | "
        f"Mels={config.n_mels} | "
        f"GlobalNorm={'ON' if config.use_global_norm else 'OFF'} | "
        f"SpecAug={'ON' if config.use_spec_augment else 'OFF'} "
        f"(Tmask={config.time_mask_param}×{config.num_time_masks}, "
        f"Fmask={config.freq_mask_param}×{config.num_freq_masks}) | "
        f"Files={num_files}, Classes={num_classes}, Min/class={min_per_class}"
    )

    print(summary)
    print(
        f"  Durations (s): mean={dur_mean:.3f}, p5={dur_p5:.3f}, "
        f"p50={dur_p50:.3f}, p95={dur_p95:.3f}"
    )
    print(
        f"  RMS: mean={rms_mean:.5f}, std={rms_std:.5f} | "
        f"Files/class median={median_per_class}"
    )
    print(
        f"  Sample rates: {sr_str}"
    )
    print(
        f"  Spectrogram shapes: raw={spec_before} -> resized={spec_after}"
    )

#   Automated Audio Preprocessing 
@dataclass
class AudioPreprocessConfig:
    """Tunable settings for the log-mel preprocessing pipeline.

    The defaults are starting values; ``AutoAudioPreprocessor._choose_config``
    overwrites ``target_sr``, ``clip_seconds``, ``use_global_norm`` and the
    SpecAugment fields based on a dataset profile.
    """

    # Sample rate (Hz) that waveforms are resampled to before feature extraction.
    target_sr: int = 16000
    # Clip length in seconds; waveforms are padded/trimmed to
    # int(clip_seconds * target_sr) samples.
    clip_seconds: float = 1.0
    # n_mels, n_fft, hop_length and win_length are forwarded unchanged to
    # torchaudio's MelSpectrogram transform.
    n_mels: int = 64
    n_fft: int = 512
    hop_length: int = 160
    win_length: int = 400
    # Normalize with dataset-wide mean/std (when they have been fitted) instead
    # of per-sample statistics.
    use_global_norm: bool = True
    # Enable time/frequency masking when a transform is called with augment=True.
    use_spec_augment: bool = True
    # Forwarded to torchaudio's TimeMasking / FrequencyMasking as their
    # `time_mask_param` / `freq_mask_param` arguments.
    time_mask_param: int = 8
    freq_mask_param: int = 8
    # Requested number of time / frequency masks per sample.
    num_time_masks: int = 1
    num_freq_masks: int = 1


class AudioPreprocessor:
    """Convert raw waveforms into normalized, square log-mel spectrograms.

    Pipeline per waveform: resample to ``config.target_sr`` -> pad/trim to
    ``config.clip_seconds`` -> mel spectrogram -> dB scale -> normalization
    -> optional SpecAugment masking -> bilinear resize to
    ``(img_size, img_size)``.

    Call :meth:`fit` to estimate dataset-wide mean/std for global
    normalization; without it :meth:`transform` normalizes each sample with
    its own statistics.
    """

    def __init__(self, config: Optional[AudioPreprocessConfig] = None):
        """Store the config; torchaudio transforms are created lazily.

        Args:
            config: Preprocessing settings. A default
                ``AudioPreprocessConfig`` is used when omitted.
        """
        self.config = config or AudioPreprocessConfig()
        self._mel = None
        self._to_db = AT.AmplitudeToDB(stype="power")
        self.global_mean: Optional[float] = None
        self.global_std: Optional[float] = None
        self._time_mask = None
        self._freq_mask = None

    def _ensure_transforms(self):
        """Create the mel and (if enabled) masking transforms on first use."""
        cfg = self.config
        if self._mel is None:
            self._mel = AT.MelSpectrogram(
                sample_rate=cfg.target_sr,
                n_fft=cfg.n_fft,
                hop_length=cfg.hop_length,
                win_length=cfg.win_length,
                n_mels=cfg.n_mels,
            )
        if cfg.use_spec_augment and self._time_mask is None:
            self._time_mask = AT.TimeMasking(time_mask_param=cfg.time_mask_param)
            self._freq_mask = AT.FrequencyMasking(freq_mask_param=cfg.freq_mask_param)

    def _waveform_to_logmel(self, waveform: torch.Tensor, sr: int) -> torch.Tensor:
        """Resample, fix the length, and return the dB-scaled mel spectrogram.

        Args:
            waveform: Audio samples; a 1-D tensor is promoted to ``(1, T)``.
            sr: Sample rate of ``waveform``.

        Returns:
            The output of ``AmplitudeToDB(MelSpectrogram(waveform))``.
        """
        cfg = self.config
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        if sr != cfg.target_sr:
            waveform = torchaudio.functional.resample(waveform, sr, cfg.target_sr)
        target_len = int(cfg.clip_seconds * cfg.target_sr)
        waveform = pad_or_trim(waveform, target_len)

        self._ensure_transforms()
        spec = self._mel(waveform)
        spec = self._to_db(spec)
        return spec

    def fit(self, dataset, max_items: int = 512) -> "AudioPreprocessor":
        """Estimate global mean/std of log-mel values from a random subset.

        Args:
            dataset: Indexable dataset whose items start with
                ``(waveform, sample_rate, ...)``.
            max_items: Maximum number of randomly chosen items to scan.

        Returns:
            ``self``, with ``global_mean`` / ``global_std`` set. An empty
            dataset yields mean 0.0 and std 1.0.
        """
        sums = 0.0
        sumsq = 0.0
        count = 0

        n = len(dataset)
        indices = list(range(n))
        random.shuffle(indices)
        indices = indices[:max_items]

        with torch.no_grad():
            for i in indices:
                sample = dataset[i]
                waveform, sr = sample[0], sample[1]
                v = self._waveform_to_logmel(waveform, sr).reshape(-1)
                sums += float(v.sum())
                sumsq += float((v ** 2).sum())
                count += v.numel()

        if count > 0:
            mean = sums / count
            # E[x^2] - E[x]^2 can come out slightly negative from rounding when
            # the data is (nearly) constant. A negative float ** 0.5 is a
            # complex number in Python, which would make max() raise TypeError,
            # so clamp the variance at zero first.
            variance = max(0.0, sumsq / count - mean ** 2)
            self.global_mean = mean
            self.global_std = max(1e-6, variance ** 0.5)
        else:
            self.global_mean, self.global_std = 0.0, 1.0

        return self

    def transform(self, waveform: torch.Tensor, sr: int, img_size: int, augment: bool = False) -> torch.Tensor:
        """Turn one waveform into a normalized ``img_size`` x ``img_size`` spectrogram.

        Args:
            waveform: Audio samples (1-D, or with a leading channel axis).
            sr: Sample rate of ``waveform``.
            img_size: Height and width of the returned spectrogram image.
            augment: Apply SpecAugment masking (only if the config enables it).

        Returns:
            The resized spectrogram; its last two dimensions are
            ``(img_size, img_size)``.
        """
        cfg = self.config
        spec = self._waveform_to_logmel(waveform, sr)

        # Normalization
        if cfg.use_global_norm and self.global_mean is not None:
            spec = (spec - self.global_mean) / (self.global_std + 1e-6)
        else:
            # Per-sample normalization fallback
            spec = (spec - spec.mean()) / (spec.std() + 1e-6)

        # SpecAugment: honour the configured number of masks. With the default
        # of one mask each this is exactly one time mask followed by one
        # frequency mask.
        if augment and cfg.use_spec_augment:
            self._ensure_transforms()
            for _ in range(cfg.num_time_masks):
                spec = self._time_mask(spec)
            for _ in range(cfg.num_freq_masks):
                spec = self._freq_mask(spec)

        # Resize to square image for the NAS/CNN backbone. F.interpolate in
        # bilinear mode needs a 4-D input, so add and then drop a batch axis.
        spec = spec.unsqueeze(0)
        spec = F.interpolate(spec, size=(img_size, img_size),
                             mode="bilinear", align_corners=False)
        spec = spec.squeeze(0)
        return spec


class AutoAudioPreprocessor:
    """Pick preprocessing settings from a dataset profile and build the preprocessor.

    Typical use: ``fit(dataset, img_size)`` profiles a random subset of the
    dataset, derives an ``AudioPreprocessConfig`` from it, and fits an
    ``AudioPreprocessor``; ``make_transforms(img_size)`` then returns the
    train/test transform callables.
    """

    def __init__(self, base_config: Optional[AudioPreprocessConfig] = None):
        """Remember the starting config; nothing is computed yet.

        Args:
            base_config: Settings to start from. Fields not chosen
                automatically (e.g. ``n_mels``, ``n_fft``) are taken from it.
        """
        self.base_config = base_config or AudioPreprocessConfig()
        self.config: Optional[AudioPreprocessConfig] = None
        self.preproc: Optional[AudioPreprocessor] = None
        self.profile: Optional[Dict[str, Any]] = None

    def _profile_dataset(self, dataset, max_items: int = 256) -> Dict[str, Any]:
        """Collect simple statistics from up to ``max_items`` random items.

        Args:
            dataset: Indexable dataset whose items start with
                ``(waveform, sample_rate, label, ...)``.
            max_items: Maximum number of items to inspect.

        Returns:
            Dict with ``"sr_counts"``, ``"durations"``, ``"rms_values"`` and
            ``"files_per_class"`` computed on the sampled items only, and
            ``"num_files"`` holding the size of the whole dataset.
        """
        sr_counts: Dict[int, int] = {}
        durations = []
        rms_values = []
        files_per_class: Dict[str, int] = {}

        n = len(dataset)
        idxs = list(range(n))
        random.shuffle(idxs)
        idxs = idxs[:max_items]

        for i in idxs:
            waveform, sr, label, *_ = dataset[i]
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            sr_counts[sr] = sr_counts.get(sr, 0) + 1
            durations.append(float(waveform.shape[-1] / sr))
            rms_values.append(float(torch.sqrt((waveform ** 2).mean())))
            files_per_class[label] = files_per_class.get(label, 0) + 1

        return {
            "sr_counts": sr_counts,
            "durations": np.array(durations, dtype=np.float32),
            "rms_values": np.array(rms_values, dtype=np.float32),
            "files_per_class": files_per_class,
            "num_files": n,
        }

    def _choose_config(self, img_size: int) -> AudioPreprocessConfig:
        """Derive the effective config from ``self.profile`` and ``img_size``.

        Args:
            img_size: Target image side length; used to scale the mask sizes.

        Returns:
            The chosen config, which is also stored in ``self.config``. It is
            a copy, so ``self.base_config`` is left untouched.
        """
        import copy  # local import keeps this block self-contained

        p = self.profile
        # Work on a copy: mutating base_config in place would silently change
        # an object the caller may share with other preprocessors.
        cfg = copy.copy(self.base_config)

        # Target sample rate: most common in the dataset, fallback to 16 kHz
        sr_counts = p["sr_counts"]
        cfg.target_sr = max(sr_counts, key=sr_counts.get) if sr_counts else 16000

        # Clip duration: 95th percentile of durations, clamped to [0.5, 5.0] s
        if p["durations"].size > 0:
            p95 = float(np.percentile(p["durations"], 95))
            cfg.clip_seconds = float(max(0.5, min(5.0, p95)))
        else:
            cfg.clip_seconds = 1.0

        # Global normalization only when the dataset is reasonably large
        cfg.use_global_norm = p["num_files"] > 500

        # Augmentation strength based on per-class counts.
        # files_per_class is counted on the profiled sample only, so scale the
        # smallest sampled class up to the full dataset size before comparing
        # it with the absolute threshold. Without this, any dataset with more
        # than a handful of classes always looks "small". The estimate is
        # coarse, and classes missing from the sample are not accounted for.
        class_counts = p["files_per_class"]
        if class_counts:
            sampled = sum(class_counts.values())
            scale = max(1.0, p["num_files"] / sampled) if sampled else 1.0
            min_per_class = min(class_counts.values()) * scale
        else:
            min_per_class = 0

        # Base mask size roughly proportional to image size; scarce data gets
        # masks twice as large and twice as many of them.
        base_mask = max(2, img_size // 10)
        strength = 2 if min_per_class < 50 else 1
        cfg.use_spec_augment = True
        cfg.time_mask_param = base_mask * strength
        cfg.freq_mask_param = base_mask * strength
        cfg.num_time_masks = strength
        cfg.num_freq_masks = strength

        self.config = cfg
        return cfg

    def fit(self, dataset, img_size: int) -> "AutoAudioPreprocessor":
        """Profile ``dataset``, choose a config, and fit the preprocessor.

        Args:
            dataset: Indexable dataset of ``(waveform, sample_rate, label, ...)``.
            img_size: Target image side length.

        Returns:
            ``self``, with ``profile``, ``config`` and ``preproc`` populated.
        """
        # 1) Profile dataset
        self.profile = self._profile_dataset(dataset)
        # 2) Choose config based on profile + image size
        cfg = self._choose_config(img_size)
        # 3) Build preprocessor and fit its global normalization
        self.preproc = AudioPreprocessor(cfg)
        if cfg.use_global_norm:
            self.preproc.fit(dataset)
        return self

    def make_transforms(self, img_size: int):
        """Return ``(train_tf, test_tf)`` callables taking ``(waveform, sr)``.

        ``train_tf`` applies augmentation, ``test_tf`` does not. Both are
        ``functools.partial`` objects over the fitted preprocessor's
        ``transform`` rather than local closures, so datasets holding them can
        be pickled for DataLoader workers started with ``spawn``/``forkserver``.
        They are bound to the preprocessor that exists when this method is
        called; call it again after re-fitting.

        Args:
            img_size: Side length of the square output image.
        """
        from functools import partial  # local import keeps this block self-contained

        assert self.preproc is not None, "Call fit(...) before make_transforms()."

        transform = self.preproc.transform
        train_tf = partial(transform, img_size=img_size, augment=True)
        test_tf = partial(transform, img_size=img_size, augment=False)
        return train_tf, test_tf



#   Dataset wrapper + dataloaders 


class SpeechCommandsWrapped(torch.utils.data.Dataset):
    """Adapter that maps raw dataset records to ``(features, class_index)`` pairs.

    Each record of the wrapped dataset must unpack as
    ``(waveform, sample_rate, label, ...)``; any extra fields are ignored.
    """

    def __init__(self, base_ds, label2idx, tf):
        # Only references are stored; nothing is copied or precomputed.
        self.ds, self.label2idx, self.tf = base_ds, label2idx, tf

    def __len__(self):
        # Same length as the wrapped dataset.
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]
        waveform, sr, label, *_extra = record
        # Run the transform first, then look up the integer class id.
        features = self.tf(waveform, sr)
        return features, self.label2idx[label]


def _build_label_mapping(train_base):
    """Build label <-> index lookup tables from a dataset's labels.

    Args:
        train_base: Sized, indexable dataset whose records unpack as
            ``(_, _, label, ...)``. If it provides a callable
            ``get_metadata(i)`` (as torchaudio's SPEECHCOMMANDS does,
            returning the same fields without decoding the audio), that is
            used instead of ``train_base[i]`` so that no waveform has to be
            loaded just to read its label.

    Returns:
        ``(label2idx, idx2label)``: labels are sorted, then numbered from 0.
    """
    fetch = getattr(train_base, "get_metadata", None)
    if not callable(fetch):
        # No metadata-only accessor available: fall back to regular indexing.
        fetch = train_base.__getitem__

    labels = set()
    for i in range(len(train_base)):
        _, _, label, *_ = fetch(i)
        labels.add(label)

    label2idx = {lab: i for i, lab in enumerate(sorted(labels))}
    idx2label = {i: lab for lab, i in label2idx.items()}
    return label2idx, idx2label


def get_dataloaders(img_size: int):
    """Download SpeechCommands, fit the preprocessing, and build the loaders.

    Side effects: downloads the dataset under ``DATA_ROOT`` if needed, prints
    the preprocessing report, and sets the module globals ``NUM_CLASSES``,
    ``TARGET_SR``, ``CLIP_SECONDS`` and ``TARGET_SAMPLES``.

    Args:
        img_size: Side length of the square spectrogram image fed to the model.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles and applies augmentation.
    """
    global NUM_CLASSES, TARGET_SR, CLIP_SECONDS, TARGET_SAMPLES

    # Built-in SpeechCommands splits
    train_base, val_base, test_base = (
        torchaudio.datasets.SPEECHCOMMANDS(
            root=DATA_ROOT, download=True, subset=subset
        )
        for subset in ("training", "validation", "testing")
    )

    # Label mapping based on training subset
    label2idx, idx2label = _build_label_mapping(train_base)
    NUM_CLASSES = len(label2idx)

    # Automated audio preprocessing (fitted on the training subset only)
    auto_preproc = AutoAudioPreprocessor()
    auto_preproc.fit(train_base, img_size=img_size)
    print_preprocessing_report(auto_preproc.config, auto_preproc.profile, img_size)

    train_tf, test_tf = auto_preproc.make_transforms(img_size=img_size)

    # Expose inferred low-level parameters if you still need them anywhere
    TARGET_SR = auto_preproc.config.target_sr
    CLIP_SECONDS = auto_preproc.config.clip_seconds
    TARGET_SAMPLES = int(TARGET_SR * CLIP_SECONDS)

    # Wrap datasets: augmentation only for training
    train_set = SpeechCommandsWrapped(train_base, label2idx, train_tf)
    val_set = SpeechCommandsWrapped(val_base, label2idx, test_tf)
    test_set = SpeechCommandsWrapped(test_base, label2idx, test_tf)

    def _make_loader(dataset, shuffle: bool) -> DataLoader:
        # Each loader gets its OWN seeded generator. A DataLoader draws from
        # its generator every time an iterator is created (for the worker base
        # seed, and for the shuffle order when shuffling), so sharing one
        # generator would make the training shuffle order depend on how often
        # the validation/test loaders happen to be iterated.
        return DataLoader(
            dataset, batch_size=BATCH_SIZE, shuffle=shuffle,
            num_workers=NUM_WORKERS, pin_memory=True,
            worker_init_fn=_seed_worker,
            generator=torch.Generator().manual_seed(SEED),
        )

    train_loader = _make_loader(train_set, shuffle=True)
    val_loader = _make_loader(val_set, shuffle=False)
    test_loader = _make_loader(test_set, shuffle=False)
    return train_loader, val_loader, test_loader
