In [1]:
!pip install soundfile



In [2]:
!pip install librosa



## Data Engineering

In [1]:
import os
from pathlib import Path
import json
import random
import math
from typing import List, Tuple, Optional, Dict

import numpy as np
import librosa
import soundfile as sf
from scipy import signal
from tqdm import tqdm

In [2]:
# -------------------------
# Noise generator (custom profile: harmonic hum + resonant peaks + shaped hiss)
# -------------------------
def create_custom_noise_profile(duration, sample_rate, overall_gain_db=-25, rng=None):
    """Synthesize a noise signal made of harmonic hum, resonant bands and shaped hiss.

    Args:
        duration: Length of the noise in seconds.
        sample_rate: Sampling rate in Hz.
        overall_gain_db: Peak level of the returned signal in dB relative to 1.0
            (the signal is peak-normalized, then scaled by 10**(overall_gain_db/20)).
        rng: Optional source of randomness exposing ``uniform`` and ``normal``
            (e.g. ``np.random.default_rng(seed)``). When None, the global
            ``np.random`` state is used, exactly as before.

    Returns:
        1-D float32 array with ``int(duration * sample_rate)`` samples. A
        zero-length request returns an empty array instead of failing during
        peak normalization.
    """
    n_samples = int(duration * sample_rate)
    if n_samples == 0:
        # np.max() on an empty array raises; an empty request has a trivial answer.
        return np.zeros(0, dtype=np.float32)

    rand = np.random if rng is None else rng
    t = np.linspace(0, duration, n_samples, endpoint=False)
    nyquist = sample_rate / 2.0

    # harmonic hum: harmonics 2..6 of each fundamental, each detuned by up to +/-1 %
    # and amplitude-modulated by a slow 0.5 Hz tremolo (loop-invariant, so hoisted).
    harmonic_noise = np.zeros_like(t)
    am_depth = 0.1
    am_rate = 0.5
    am_mod = 1 + am_depth * np.sin(2 * np.pi * am_rate * t)
    for fundamental in (440, 516, 645):
        for harmonic in range(2, 7):
            freq = fundamental * harmonic
            detuned_freq = freq * (1 + rand.uniform(-0.01, 0.01))
            amplitude = 0.15 / harmonic
            harmonic_noise += amplitude * am_mod * np.sin(2 * np.pi * detuned_freq * t)

    # resonant peaks (narrow band): white noise band-passed to +/-10 % around each centre
    resonant_noise = np.zeros_like(t)
    for freq, amp in zip((3158, 3856, 5109), (0.08, 0.08, 0.06)):
        # Draw first so the random stream does not depend on which bands are usable.
        white = rand.normal(0, 1, len(t))
        low_cut = max(0.0001, (freq * 0.9) / nyquist)
        high_cut = min(0.9999, (freq * 1.1) / nyquist)
        if low_cut >= high_cut:
            # The band lies above Nyquist at this sample rate: skip it rather than
            # letting the filter design fail and adding unfiltered white noise.
            continue
        try:
            b, a = signal.butter(4, [low_cut, high_cut], btype='band')
            narrow = signal.filtfilt(b, a, white)
        except ValueError:
            # Best-effort fallback kept from the original (e.g. signal shorter than
            # filtfilt's padding): use the unfiltered noise.
            narrow = white
        resonant_noise += amp * narrow

    # broadband hiss shaped by FIR
    white_noise = rand.normal(0, 1, len(t))
    try:
        freq_points = [0, 1000, 1290, 1548, 1858, 2229, 2675, 4000, 6000, nyquist]
        gain_response = [10, 15, 15, 15, 15, 15, 15, 8, 5, 5]
        norm = np.clip(np.array(freq_points) / nyquist, 0.0, 1.0)
        fir_coeffs = signal.firwin2(1025, norm, gain_response)
        shaped_hiss = signal.filtfilt(fir_coeffs, [1.0], white_noise)
    except ValueError:
        # Invalid frequency grid for this sample rate or a signal too short for the
        # 1025-tap filter: fall back to unshaped white noise.
        shaped_hiss = white_noise

    # combine -> apply notch -> normalize -> gain
    combined = harmonic_noise + resonant_noise + shaped_hiss
    try:
        b_notch, a_notch = signal.iirnotch(3179.3, 4, sample_rate)
        combined = signal.filtfilt(b_notch, a_notch, combined)
    except ValueError:
        # Notch frequency not representable at this sample rate (or signal too short):
        # leave the mix un-notched.
        pass

    maxabs = np.max(np.abs(combined)) + 1e-12
    combined = combined / maxabs
    gain_lin = 10 ** (overall_gain_db / 20.0)
    return (combined * gain_lin).astype(np.float32)


# -------------------------
# SAD: find non-silent intervals (librosa)
# -------------------------
def detect_activity_intervals(audio: np.ndarray, sr: int, top_db: float = 30.0, frame_length: int = 2048, hop_length: int = 512) -> List[Tuple[int, int]]:
    """
    Locate the non-silent regions of ``audio`` with ``librosa.effects.split``.

    audio: signal to analyse
    sr: accepted for signature symmetry with the other helpers; not used here
    top_db: threshold in dB below reference under which audio counts as silence
    frame_length / hop_length: analysis window and hop, forwarded to librosa

    Returns a list of (start_sample, end_sample) pairs as plain Python ints.
    """
    active_regions = librosa.effects.split(
        y=audio,
        top_db=top_db,
        frame_length=frame_length,
        hop_length=hop_length,
    )
    bounds: List[Tuple[int, int]] = []
    for begin, end in active_regions:
        # cast away numpy integer types so callers get plain ints
        bounds.append((int(begin), int(end)))
    return bounds


# -------------------------
# Helper: sample one fixed-length clip from non-silent intervals
# -------------------------
def sample_clip_from_intervals(audio: np.ndarray, sr: int, intervals: List[Tuple[int,int]], clip_duration: float, rng: Optional[random.Random] = None) -> np.ndarray:
    """
    Cut one fixed-length float32 clip out of ``audio``, preferring active regions.

    Strategy, in order:
      1. If at least one interval is long enough, pick one at random and take a
         random window of the clip length inside it.
      2. Otherwise, if there are intervals at all, pick one at random and
         zero-pad it symmetrically (centred) up to the clip length.
      3. With no intervals, take the centre window of the whole signal, or
         right-pad the signal with zeros if it is shorter than the clip.

    audio: 1-D signal
    sr: sample rate, used only to convert clip_duration to samples
    intervals: (start_sample, end_sample) pairs; assumed to satisfy
               0 <= start <= end <= len(audio)
    clip_duration: clip length in seconds
    rng: optional ``random.Random``; the ``random`` module is used when omitted

    Returns a numpy array of length round(clip_duration * sr).
    """
    rng = rng or random
    clip_len = int(round(clip_duration * sr))

    # 1) random window inside a random interval that can hold a full clip
    long_intervals = [iv for iv in intervals if (iv[1] - iv[0]) >= clip_len]
    if long_intervals:
        s, e = rng.choice(long_intervals)
        start = rng.randint(s, e - clip_len)  # randint is inclusive on both ends
        return audio[start:start + clip_len].astype(np.float32)

    # 2) every interval is shorter than the clip: centre one of them in silence.
    #    (The original also tested `len(seg) >= clip_len` here, which can never be
    #    true once step 1 has found no long interval.)
    if intervals:
        s, e = rng.choice(intervals)
        seg = audio[s:e].astype(np.float32)
        needed = clip_len - len(seg)
        left = needed // 2
        right = needed - left
        return np.pad(seg, (left, right), mode='constant', constant_values=0.0)

    # 3) no activity detected: centre crop, or right-pad a too-short signal
    if len(audio) >= clip_len:
        center = len(audio) // 2
        start = max(0, center - clip_len // 2)
        return audio[start:start + clip_len].astype(np.float32)
    return np.pad(audio.astype(np.float32), (0, clip_len - len(audio)), mode='constant')


# -------------------------
# RMS utilities
# -------------------------
def rms(x: np.ndarray, eps=1e-12) -> float:
    """Root-mean-square level of ``x``, computed in float64.

    ``eps`` is added to the mean square before the square root, so an all-zero
    signal yields sqrt(eps) rather than exactly 0. An empty input is treated as
    silence and also yields sqrt(eps) (the plain mean of an empty array would be
    NaN with a RuntimeWarning). Array-likes such as lists are accepted.
    """
    samples = np.asarray(x, dtype=np.float64)
    if samples.size == 0:
        return float(np.sqrt(eps))
    return float(np.sqrt(np.mean(samples**2) + eps))


# -------------------------
# Build one simulated mixture example (core of augmentation)
# -------------------------
def build_simulated_mixture(
    stem_paths: List[Path],
    sr: int,
    clip_duration: float = 3.0,
    min_stems: int = 1,
    max_stems: int = 8,
    energy_db_range: Tuple[float, float] = (-10.0, 10.0),
    rng: Optional[random.Random] = None,
    top_db: float = 30.0
) -> Tuple[np.ndarray, List[Dict]]:
    """
    Mix a random subset of stems into one clean mixture clip.

    Picks k stems at random, with k drawn uniformly from [min_stems, max_stems]
    after both bounds are capped at the number of available stems. For each stem
    it loads the audio (resampling to ``sr`` if needed), samples a clip of
    ``clip_duration`` seconds from its active regions, applies a random gain in
    ``energy_db_range`` (dB) and adds it to the mixture.

    Returns:
        (mixture, metadata): ``mixture`` is a 1-D float32 array of
        round(clip_duration * sr) samples. ``metadata`` has one dict per stem
        with keys "stem_path", "db_change", "gain_lin" and "orig_rms" (the RMS
        of the sampled clip before the gain is applied). The clip start position
        is not recorded.

        If the summed mixture peaks above 0.99 the whole mixture is scaled down
        to a 0.99 peak. That preserves the relative stem levels, but the
        recorded per-stem "gain_lin" values do not include this final scale.

    Raises:
        ValueError: if ``stem_paths`` is empty.
    """
    rng = rng or random
    n_available = len(stem_paths)
    if n_available == 0:
        raise ValueError("No stems provided to build_simulated_mixture()")

    # Cap both bounds at the number of stems actually available; previously
    # min_stems > n_available made randint() fail with "empty range".
    upper = min(max_stems, n_available)
    lower = min(min_stems, upper)
    k = rng.randint(lower, upper)
    selected = rng.sample(stem_paths, k)
    clip_len = int(round(clip_duration * sr))

    mixture = np.zeros(clip_len, dtype=np.float32)
    metadata = []

    for p in selected:
        audio, file_sr = librosa.load(str(p), sr=None, mono=True)
        # resample if needed
        if file_sr != sr:
            audio = librosa.resample(audio, orig_sr=file_sr, target_sr=sr)
        intervals = detect_activity_intervals(audio, sr, top_db=top_db)
        clip = sample_clip_from_intervals(audio, sr, intervals, clip_duration, rng=rng)
        orig_rms = rms(clip)
        db_change = rng.uniform(energy_db_range[0], energy_db_range[1])
        gain_lin = 10 ** (db_change / 20.0)
        scaled = (clip * gain_lin).astype(np.float32)
        # sum to mixture
        mixture = mixture + scaled
        metadata.append({
            "stem_path": str(p),
            "db_change": float(db_change),
            "gain_lin": float(gain_lin),
            "orig_rms": float(orig_rms)
        })

    # After summing, avoid clipping: one global scale keeps the stems' relative levels.
    peak = float(np.max(np.abs(mixture)) + 1e-12)
    if peak > 0.99:
        mixture = (mixture / peak * 0.99).astype(np.float32)

    return mixture, metadata


# -------------------------
# Add synthetic noise to get input features + return metadata
# -------------------------
def add_noise_to_mixture(
    clean_mixture: np.ndarray,
    sr: int,
    snr_db: float,
    noise_func=create_custom_noise_profile,
    overall_noise_gain_db: float = -25.0
) -> Tuple[np.ndarray, Dict]:
    """
    Add synthetic noise to ``clean_mixture`` at a target SNR.

    The noise is generated by ``noise_func(duration, sr, overall_gain_db=...)``,
    trimmed or zero-padded to the length of the clean signal, then scaled so that
    rms(clean) / rms(noise) equals 10**(snr_db / 20).

    Returns:
        (noisy, meta). ``meta`` holds "snr_db_target", "rms_clean",
        "rms_noise_before_gain", "noise_gain", "overall_noise_profile_db" and
        "noisy_peak_scale".

        If the sum peaks above 1.0, ``noisy`` is scaled down to a 0.99 peak.
        ``clean_mixture`` itself is not modified, so in that case the noisy signal
        and the clean target are no longer at the same level. "noisy_peak_scale"
        records the factor applied (1.0 when no rescale happened) so callers can
        apply the same factor to the target if they need matched levels.
    """
    n_samples = len(clean_mixture)
    duration = n_samples / sr
    noise = noise_func(duration, sr, overall_gain_db=overall_noise_gain_db)
    # The generator rounds duration*sr itself; force an exact length match.
    if len(noise) > n_samples:
        noise = noise[:n_samples]
    elif len(noise) < n_samples:
        noise = np.pad(noise, (0, n_samples - len(noise)))

    rms_clean = rms(clean_mixture)
    rms_noise = rms(noise)
    target_lin = 10 ** (snr_db / 20.0)
    # target_lin can only be 0 when the power underflows (extremely negative snr_db);
    # in that case fall back to noise at the same level as the clean signal.
    required_noise_rms = (rms_clean / target_lin) if target_lin > 0 else rms_clean
    noise_gain = (required_noise_rms / (rms_noise + 1e-12))
    adjusted_noise = (noise * noise_gain).astype(np.float32)
    noisy = clean_mixture + adjusted_noise

    # prevent clipping (an empty signal has no peak; np.max would raise on it)
    peak_scale = 1.0
    peak = float(np.max(np.abs(noisy)) + 1e-12) if noisy.size else 0.0
    if peak > 1.0:
        peak_scale = 0.99 / peak
        noisy = (noisy / peak * 0.99).astype(np.float32)

    meta = {
        "snr_db_target": float(snr_db),
        "rms_clean": float(rms_clean),
        "rms_noise_before_gain": float(rms_noise),
        "noise_gain": float(noise_gain),
        "overall_noise_profile_db": float(overall_noise_gain_db),
        "noisy_peak_scale": float(peak_scale)
    }
    return noisy, meta


# -------------------------
# Pipeline orchestration (Airflow friendly)
# -------------------------
def run_augmentation_pipeline(
    stems_root: str,
    output_base: str = "dataset",
    dataset_name: str = "aug_mixtures_v1",
    sample_rate: int = 22050,
    clip_duration: float = 3.0,
    min_stems: int = 1,
    max_stems: int = 8,
    energy_db_range: Tuple[float,float] = (-10.0, 10.0),
    snr_db_range: Tuple[float,float] = (5.0, 20.0),
    max_files: Optional[int] = None,         # restrict number of stems considered (N); None -> all
    n_examples: int = 1000,                  # number of augmented examples to synthesize
    top_db: float = 30.0,
    seed: Optional[int] = None
):
    """
    High-level pipeline:
      - discover stems (N): every .wav/.flac/.mp3 file below stems_root, sorted
      - limit to the first max_files stems if provided
      - for i in [0..n_examples): create one mixture example:
          - randomly choose between min_stems..max_stems stems
          - sample clip_duration from each stem (using SAD)
          - apply dB scaling per stem
          - sum -> clean_mixture (target)
          - sample snr in snr_db_range -> create noisy (feature)
          - save noisy and clean wavs + metadata json

    Output layout: {output_base}/processed/{dataset_name}/{clean,noisy,meta}/.

    Reproducibility: ``seed`` drives the stem/gain/SNR choices through a private
    ``random.Random``. The noise generator draws from NumPy's global RNG, so when
    a seed is given that global state is seeded for the duration of the run and
    restored afterwards. With ``seed=None`` nothing is seeded.

    Returns the output directory as a string.

    Raises:
        RuntimeError: if no audio stems are found, or ``max_files`` leaves none.
    """
    rng = random.Random(seed)
    stems_root = Path(stems_root)
    audio_suffixes = (".wav", ".flac", ".mp3")
    stem_paths = sorted(
        p for p in stems_root.rglob("*")
        # is_file() keeps directories that merely look like audio files out
        if p.is_file() and p.suffix.lower() in audio_suffixes
    )
    if not stem_paths:
        raise RuntimeError(f"No audio stems found in {stems_root}")

    if max_files is not None:
        stem_paths = stem_paths[:max_files]
        if not stem_paths:
            # Fail before creating any output folders instead of deep inside the loop.
            raise RuntimeError(f"max_files={max_files} leaves no stems to mix from {stems_root}")

    out_root = Path(output_base) / "processed" / dataset_name
    clean_dir = out_root / "clean"
    noisy_dir = out_root / "noisy"
    meta_dir = out_root / "meta"
    clean_dir.mkdir(parents=True, exist_ok=True)
    noisy_dir.mkdir(parents=True, exist_ok=True)
    meta_dir.mkdir(parents=True, exist_ok=True)

    print(f"Found {len(stem_paths)} stems. Will create {n_examples} examples using up to {max_stems} stems each.")

    # Make the noise reproducible too. The NumPy seed is derived from a separate
    # Random(seed) so the `rng` stream used for stem selection is left untouched.
    saved_numpy_state = None
    if seed is not None:
        saved_numpy_state = np.random.get_state()
        np.random.seed(random.Random(seed).getrandbits(32))

    try:
        for idx in tqdm(range(n_examples), desc="Synth examples"):
            # Build clean mixture
            mixture, stems_meta = build_simulated_mixture(
                stem_paths=stem_paths,
                sr=sample_rate,
                clip_duration=clip_duration,
                min_stems=min_stems,
                max_stems=max_stems,
                energy_db_range=energy_db_range,
                rng=rng,
                top_db=top_db
            )

            # Choose an SNR for noise
            snr = float(rng.uniform(snr_db_range[0], snr_db_range[1]))
            noisy, noise_meta = add_noise_to_mixture(mixture, sr=sample_rate, snr_db=snr)

            # Save files
            basename = f"example_{idx:06d}"
            clean_path = clean_dir / f"{basename}_clean.wav"
            noisy_path = noisy_dir / f"{basename}_noisy.wav"
            meta_path = meta_dir / f"{basename}.json"

            sf.write(str(clean_path), mixture, sample_rate)
            sf.write(str(noisy_path), noisy, sample_rate)

            full_meta = {
                "example_id": basename,
                "sample_rate": int(sample_rate),
                "clip_duration": float(clip_duration),
                "min_stems": int(min_stems),
                "max_stems": int(max_stems),
                "energy_db_range": [float(energy_db_range[0]), float(energy_db_range[1])],
                "snr_db_range": [float(snr_db_range[0]), float(snr_db_range[1])],
                "chosen_snr_db": float(snr),
                "stems_meta": stems_meta,
                "noise_meta": noise_meta
            }
            with open(meta_path, "w", encoding="utf-8") as f:
                json.dump(full_meta, f, indent=2)
    finally:
        # Do not leave the caller's global NumPy RNG reseeded.
        if saved_numpy_state is not None:
            np.random.set_state(saved_numpy_state)

    print(f"Saved {n_examples} examples to {out_root}")
    return str(out_root)


In [3]:
# Folder that is searched recursively for source stems (.wav/.flac/.mp3).
# Relative path: resolved against the notebook's working directory.
input_dir = "IDMT-SMT-GUITAR_V2/dataset2/audio/"

# Synthesize 1000 clean/noisy pairs plus per-example JSON metadata under
# guitar_dataset/processed/dataset2-test/{clean,noisy,meta}. The call returns that
# output directory, which is what the cell displays. All other parameters
# (sample rate, clip length, stem count, gain and SNR ranges) use the defaults.
run_augmentation_pipeline(
    stems_root=input_dir,
    output_base="guitar_dataset",
    dataset_name="dataset2-test",
    n_examples=1000,
    seed=1
)

Found 261 stems. Will create 1000 examples using up to 8 stems each.


Synth examples:   0%|          | 0/1000 [00:00<?, ?it/s]

Synth examples: 100%|██████████| 1000/1000 [08:22<00:00,  1.99it/s]

Saved 1000 examples to guitar_dataset\processed\dataset2-test





'guitar_dataset\\processed\\dataset2-test'

## Model Engineering

In [None]:
"""
model_pipeline.py

Airflow-friendly model engineering module:
- load configs
- dataset creation (clean/noisy pairs)
- model build (Wave-U-Net 1D)
- train_epoch / validate_epoch
- checkpointing, plotting
- optional MLflow logging
"""

import os
import json
from pathlib import Path
import random
import math
import time
import logging
from typing import Tuple, Dict, List, Optional

import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [21]:
# Optional YAML support: `yaml` is the PyYAML module when it can be imported and
# None otherwise; code that wants YAML must check for None first.
try:
    import yaml
except Exception:
    # Any failure while importing (not only a missing package) disables YAML support.
    yaml = None

# Optional MLflow: `_mlflow_available` tells callers whether `mlflow` is usable;
# `mlflow` itself is None when the import failed.
try:
    import mlflow
    _mlflow_available = True
except Exception:
    mlflow = None
    _mlflow_available = False

# Configure root logging at INFO and create the logger used by the pipeline functions.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("model_pipeline")

In [22]:
# ---------------------------
# 1) Config loading utility
# ---------------------------
def load_configs(train_config_path: str, metrics_config_path: str) -> Tuple[dict, dict]:
    """
    Load the training configuration and the metrics/plot configuration.

    Each file is parsed as follows:
      - ``.yml`` / ``.yaml`` files are parsed with PyYAML when it is installed;
      - everything else is parsed as JSON first, and re-tried as YAML (when
        PyYAML is installed) if the JSON parse fails.

    Returns (train_cfg, metrics_cfg) exactly as produced by the parser.

    Raises:
        FileNotFoundError: if either path does not exist.
        json.JSONDecodeError: if a file cannot be parsed and PyYAML is not
            available. For ``.yml`` / ``.yaml`` files the message says that
            PyYAML is missing, since that is the likely cause.
    """
    yaml_suffixes = (".yml", ".yaml")

    def _load(path: str):
        p = Path(path)
        if not p.exists():
            raise FileNotFoundError(f"Config file not found: {path}")
        text = p.read_text(encoding="utf-8")
        looks_like_yaml = p.suffix.lower() in yaml_suffixes
        if looks_like_yaml and yaml is not None:
            return yaml.safe_load(text)
        # try json
        try:
            return json.loads(text)
        except json.JSONDecodeError as exc:
            if yaml is not None:
                return yaml.safe_load(text)
            if looks_like_yaml:
                # Same exception type as before, but name the real problem.
                raise json.JSONDecodeError(
                    f"{path} looks like YAML but PyYAML is not installed; "
                    f"parsing it as JSON failed ({exc.msg})",
                    exc.doc,
                    exc.pos,
                ) from exc
            raise

    train_cfg = _load(train_config_path)
    metrics_cfg = _load(metrics_config_path)
    return train_cfg, metrics_cfg

In [23]:
# ---------------------------
# 2) Dataset for clean/noisy pairs
# ---------------------------
class CleanNoisyDataset(Dataset):
    """
    Paired (noisy, clean) waveform dataset read from a processed dataset folder.

    Expected layout, checked in this order:
      {processed_root}/{dataset_name}/{split}/clean/*.wav  and  .../{split}/noisy/*.wav
      {processed_root}/{dataset_name}/clean/*.wav          and  .../noisy/*.wav

    Files are paired by file stem with a trailing "_clean" / "_noisy" removed
    (example_000001_clean.wav <-> example_000001_noisy.wav). Files without a
    partner are ignored.
    """

    def __init__(self, processed_root: str, dataset_name: str, split: str = "train"):
        """
        processed_root: folder that contains the dataset folder (e.g. datasets/processed)
        dataset_name: name of the dataset folder
        split: name of the split sub-folder to prefer ('train' or 'test'); if it is
               missing, the un-split clean/noisy folders are used instead

        Raises RuntimeError when the dataset folder, the clean/noisy folders or
        any matching pair cannot be found.
        """
        root = Path(processed_root) / dataset_name
        if not root.exists():
            raise RuntimeError(f"Processed dataset not found: {root}")

        # Candidate layouts, most specific first.
        layouts = (
            (root / split / "clean", root / split / "noisy"),
            (root / "clean", root / "noisy"),
        )
        for clean_candidate, noisy_candidate in layouts:
            if clean_candidate.exists() and noisy_candidate.exists():
                self.clean_dir = clean_candidate
                self.noisy_dir = noisy_candidate
                break
        else:
            raise RuntimeError(f"Could not find clean/noisy directories in {root} or {root}/{split}")

        def pair_key(path: Path) -> str:
            # File stem without a trailing _clean/_noisy marker; the whole stem otherwise.
            name = path.stem
            marker = next((m for m in ("_clean", "_noisy") if name.endswith(m)), None)
            return name if marker is None else name[: -len(marker)]

        clean_by_key = {pair_key(p): p for p in sorted(self.clean_dir.glob("*.wav"))}
        noisy_by_key = {pair_key(p): p for p in sorted(self.noisy_dir.glob("*.wav"))}

        # Keep only examples that exist on both sides, in sorted key order.
        shared_keys = sorted(clean_by_key.keys() & noisy_by_key.keys())
        if not shared_keys:
            raise RuntimeError(f"No matching clean/noisy pairs found in {self.clean_dir} and {self.noisy_dir}")

        self.pairs = [(clean_by_key[key], noisy_by_key[key]) for key in shared_keys]

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return (noisy, clean) as float32 tensors of shape (1, L): feature first, target second."""
        clean_path, noisy_path = self.pairs[idx]
        clean_wave, clean_sr = sf.read(str(clean_path))
        noisy_wave, noisy_sr = sf.read(str(noisy_path))

        # Multi-channel files are averaged down to mono.
        if clean_wave.ndim > 1:
            clean_wave = clean_wave.mean(axis=1)
        if noisy_wave.ndim > 1:
            noisy_wave = noisy_wave.mean(axis=1)

        if clean_sr != noisy_sr:
            raise RuntimeError(f"Sample rates mismatch between {clean_path} and {noisy_path}")

        def to_channel_tensor(wave):
            # float32 tensor with a leading channel axis
            return torch.from_numpy(np.asarray(wave, dtype=np.float32)).unsqueeze(0)

        return to_channel_tensor(noisy_wave), to_channel_tensor(clean_wave)

In [24]:
# ---------------------------
# 3) Splitting helper
# ---------------------------
def create_train_test_splits(processed_root: str, dataset_name: str, train_frac: float = 0.9, seed: int = 42, out_dir: Optional[str] = None):
    """
    Create train/test sub-folders for a processed dataset by copying files.

    Reads the matched pairs from {processed_root}/{dataset_name}/{clean,noisy},
    shuffles the pair keys with ``random.Random(seed)`` and copies the first
    floor(len * train_frac) pairs to train/{clean,noisy} and the rest to
    test/{clean,noisy}. The original files are left in place.

    If train/clean and test/clean already exist, nothing is copied.

    ``out_dir`` is accepted but currently unused.

    Returns:
        (dataset_dir, dataset_dir): the same path twice, namely
        {processed_root}/{dataset_name} as a string. Note this is the dataset
        folder itself, not ``processed_root``.

    Raises:
        RuntimeError: if there are no clean or no noisy wav files, or if none of
            them pair up. The pairing check runs before any folder is created,
            because empty split folders would make every later call skip the
            split.
    """
    from shutil import copy2

    base = Path(processed_root) / dataset_name
    train_clean = base / "train" / "clean"
    train_noisy = base / "train" / "noisy"
    test_clean = base / "test" / "clean"
    test_noisy = base / "test" / "noisy"

    # if already split, return
    if train_clean.exists() and test_clean.exists():
        logger.info("Train/test subfolders already exist; skipping split creation.")
        return str(base), str(base)

    # otherwise create split
    all_clean = sorted(p for p in (base / "clean").glob("*.wav"))
    all_noisy = sorted(p for p in (base / "noisy").glob("*.wav"))
    if not all_clean or not all_noisy:
        raise RuntimeError(f"No clean/noisy files found under {base}")

    # match keys: file stem with a trailing _clean/_noisy removed
    def key(p: Path):
        s = p.stem
        for suf in ("_clean", "_noisy"):
            if s.endswith(suf):
                return s[:-len(suf)]
        return s

    clean_map = {key(p): p for p in all_clean}
    noisy_map = {key(p): p for p in all_noisy}
    keys = sorted(set(clean_map.keys()) & set(noisy_map.keys()))
    if not keys:
        raise RuntimeError(f"No matching clean/noisy pairs found under {base}")

    random.Random(seed).shuffle(keys)
    n_train = int(math.floor(len(keys) * train_frac))
    train_keys = keys[:n_train]
    test_keys = keys[n_train:]

    # create directories
    for folder in (train_clean, train_noisy, test_clean, test_noisy):
        folder.mkdir(parents=True, exist_ok=True)

    for subset_keys, clean_dst, noisy_dst in (
        (train_keys, train_clean, train_noisy),
        (test_keys, test_clean, test_noisy),
    ):
        for k in subset_keys:
            copy2(clean_map[k], clean_dst / clean_map[k].name)
            copy2(noisy_map[k], noisy_dst / noisy_map[k].name)

    logger.info(f"Created train/test split: {len(train_keys)} train, {len(test_keys)} test")
    return str(base), str(base)

In [None]:
# ---------------------------
# 4) Wave U-Net model (same as earlier, slightly improved)
# ---------------------------
class ConvBlock1D(nn.Module):
    """Conv1d -> BatchNorm1d -> activation.

    Args:
        in_ch, out_ch: input / output channel counts.
        kernel_size: convolution kernel size.
        stride: convolution stride (2 is used by the encoder to downsample).
        padding: explicit padding; defaults to (kernel_size - 1) // 2.
        activation: activation module; a fresh ``nn.LeakyReLU(0.2)`` is created per
            block when omitted, so blocks no longer share one default instance.
    """

    def __init__(self, in_ch, out_ch, kernel_size=15, stride=1, padding=None, activation=None):
        super().__init__()
        if padding is None:
            padding = (kernel_size - 1) // 2
        if activation is None:
            activation = nn.LeakyReLU(0.2)
        self.net = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm1d(out_ch),
            activation
        )

    def forward(self, x):
        return self.net(x)


class WaveUNet1D(nn.Module):
    """1-D U-Net over raw waveforms.

    Encoder: ``depth`` stages, each a stride-1 block followed by a stride-2
    (downsampling) block, with base_filters * 2**d channels at stage d. The output
    of every stage is kept as a skip connection.
    Decoder: for each skip, deepest first, the features are linearly interpolated
    to the skip's length, concatenated with it and reduced to the skip's channels.
    Head: the decoder output is interpolated to the input length, projected with
    a 1x1 convolution and squashed with tanh.

    The shallowest skip already sits at about half the input resolution (there is
    no full-resolution skip), so the last step to full length is plain linear
    interpolation.

    Input (B, input_channels, L) -> output (B, output_channels, L), values in [-1, 1].
    """

    def __init__(self, input_channels=1, output_channels=1, base_filters=24, depth=5, kernel_size=15):
        super().__init__()
        self.depth = depth
        self.downs = nn.ModuleList()
        # Registered here, before the bottleneck, so that parameter order matches
        # the previous version of this class.
        self.ups = nn.ModuleList()

        in_ch = input_channels
        for d in range(depth):
            out_ch = base_filters * (2 ** d)
            self.downs.append(
                nn.Sequential(
                    ConvBlock1D(in_ch, out_ch, kernel_size=kernel_size, stride=1),
                    ConvBlock1D(out_ch, out_ch, kernel_size=kernel_size, stride=2)  # downsample
                )
            )
            in_ch = out_ch

        self.bottleneck = nn.Sequential(
            ConvBlock1D(in_ch, in_ch * 2, kernel_size=kernel_size, stride=1),
            ConvBlock1D(in_ch * 2, in_ch, kernel_size=kernel_size, stride=1)
        )

        # `in_ch` is now the channel count leaving the bottleneck (deepest level).
        cur_channels = in_ch
        for d in reversed(range(depth)):
            skip_ch = base_filters * (2 ** d)      # channels of this encoder skip
            concat_ch = cur_channels + skip_ch     # channels after torch.cat([cur, skip], dim=1)
            self.ups.append(
                nn.Sequential(
                    ConvBlock1D(concat_ch, skip_ch, kernel_size=kernel_size, stride=1),
                    ConvBlock1D(skip_ch, skip_ch, kernel_size=kernel_size, stride=1)
                )
            )
            cur_channels = skip_ch

        # For depth >= 1 this equals base_filters (unchanged). Sizing the head from
        # the real decoder width also covers depth == 0, where forward() used to
        # build a throw-away, untrained Conv1d on every call.
        self.final_conv = nn.Conv1d(cur_channels, output_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        cur = x
        for block in self.downs:
            cur = block(cur)
            skips.append(cur)

        cur = self.bottleneck(cur)

        for up_block, skip in zip(self.ups, reversed(skips)):
            # interpolate() with an explicit size already yields the skip's length.
            cur = nn.functional.interpolate(cur, size=skip.shape[-1], mode='linear', align_corners=False)
            cur = torch.cat([cur, skip], dim=1)
            cur = up_block(cur)

        # Every skip was taken after a stride-2 block, so the decoder ends at roughly
        # half the input length. Upsample to the full length here; the old code
        # zero-padded the tail of the output instead.
        if cur.shape[-1] != x.shape[-1]:
            cur = nn.functional.interpolate(cur, size=x.shape[-1], mode='linear', align_corners=False)

        return torch.tanh(self.final_conv(cur))

In [None]:
# ---------------------------
# 5) Metrics (MSE, MAE, SNR, SI-SDR, SDR, SIR, SAR, ISR)
# ---------------------------
def mse_metric(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean squared error between the reference and the estimate."""
    residual = y_true - y_pred
    return float(np.mean(np.square(residual)))


def mae_metric(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Mean absolute error between the reference and the estimate."""
    abs_residual = np.abs(y_true - y_pred)
    return float(np.mean(abs_residual))


def snr_db_metric(y_true: np.ndarray, y_pred: np.ndarray, eps=1e-8) -> float:
    """SNR in dB: 20 * log10(rms(reference) / rms(reference - estimate)).

    ``eps`` is added to both mean-square terms before the square root, and a
    further 1e-12 guards the division.
    """
    reference_power = np.mean(np.square(y_true)) + eps
    error_power = np.mean(np.square(y_true - y_pred)) + eps
    level_ratio = math.sqrt(reference_power) / (math.sqrt(error_power) + 1e-12)
    return 20.0 * math.log10(level_ratio)


def si_sdr_metric(y_true: np.ndarray, y_pred: np.ndarray, eps=1e-8) -> float:
    """
    Scale-invariant SDR (dB) for single-channel signals.

    Both signals are mean-removed in float64; the estimate is projected onto the
    reference to obtain the target component, and the remainder counts as noise.
    y_true, y_pred: 1D numpy arrays of the same length
    """
    reference = y_true.astype(np.float64) - np.mean(y_true)
    estimate = y_pred.astype(np.float64) - np.mean(y_pred)

    # optimal scaling of the reference that best explains the estimate
    scale = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
    target = scale * reference
    residual = estimate - target

    target_energy = np.sum(np.square(target))
    residual_energy = np.sum(np.square(residual)) + eps
    return 10.0 * math.log10((target_energy + eps) / residual_energy)


def sdr_metric(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-8) -> float:
    """
    Signal to Distortion Ratio (SDR) in dB.

    SDR = 10 * log10(||s_target||^2 / ||s_hat - s_target||^2)

    As implemented here, both signals are mean-removed and the target is the
    projection of the estimate onto the reference, so the value is the same
    quantity that si_sdr_metric computes.

    Args:
        y_true: Reference signal (target)
        y_pred: Estimated signal
        eps: Small value to avoid division by zero

    Returns:
        SDR in dB
    """
    reference = y_true.astype(np.float64) - np.mean(y_true)
    estimate = y_pred.astype(np.float64) - np.mean(y_pred)

    # Target component: the estimate projected onto the reference.
    gain = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
    target = gain * reference

    # Everything in the estimate that the target does not explain.
    distortion = estimate - target

    energy_ratio = (np.sum(np.square(target)) + eps) / (np.sum(np.square(distortion)) + eps)
    return 10.0 * math.log10(energy_ratio)


def sir_metric(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-8) -> float:
    """
    Signal to Interference Ratio (SIR) in dB, single-channel approximation.

    SIR = 10 * log10(||s_target||^2 / ||e_interf||^2)

    With one reference there is no separate interfering source, so this
    implementation treats the whole residual (estimate minus its projection onto
    the reference) as interference.

    Args:
        y_true: Reference signal (target)
        y_pred: Estimated signal
        eps: Small value to avoid division by zero

    Returns:
        SIR in dB
    """
    reference = y_true.astype(np.float64) - np.mean(y_true)
    estimate = y_pred.astype(np.float64) - np.mean(y_pred)

    # Projection of the estimate onto the reference gives the target part.
    gain = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
    target = gain * reference

    # Residual after removing the target, used as the interference term.
    interference = estimate - target

    numerator = np.sum(np.square(target)) + eps
    denominator = np.sum(np.square(interference)) + eps
    return 10.0 * math.log10(numerator / denominator)


def sar_metric(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-8) -> float:
    """
    Signal to Artifacts Ratio (SAR) in dB for a single reference signal.

    SAR = 10 * log10(||s_target + e_interf||^2 / ||e_artif||^2)

    The estimate is decomposed as s_hat = s_target + e_interf + e_artif, where
    s_target is the projection of s_hat onto the reference. With a single
    reference there is no other source to project onto, so e_interf = 0 and the
    artifacts are the whole residual s_hat - s_target.

    The previous implementation set e_interf to the residual and then computed
    e_artif = s_hat - (s_target + e_interf), which is identically zero; the
    result therefore only measured the estimate's energy against ``eps``.

    Args:
        y_true: Reference signal (target)
        y_pred: Estimated signal
        eps: Small value to avoid division by zero

    Returns:
        SAR in dB
    """
    # Remove mean
    s = y_true.astype(np.float64) - np.mean(y_true)
    s_hat = y_pred.astype(np.float64) - np.mean(y_pred)

    # Projection for target component
    alpha = np.dot(s_hat, s) / (np.dot(s, s) + eps)
    s_target = alpha * s

    # Single reference: no interference term, the residual is all artifacts.
    e_artif = s_hat - s_target

    signal_energy = np.sum(s_target**2)
    artifacts_energy = np.sum(e_artif**2)

    return 10.0 * math.log10((signal_energy + eps) / (artifacts_energy + eps))


def isr_metric(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-8) -> float:
    """
    Image to Spatial distortion Ratio (ISR) in dB, single-channel approximation.

    ISR = 10 * log10(||s_target||^2 / ||e_spat||^2)

    Note: ISR is defined for multi-channel source images. For single-channel
    audio this implementation approximates the spatial distortion by the part of
    the estimate that does not align with the reference, i.e. the residual after
    projecting the estimate onto the reference.

    Args:
        y_true: Reference signal (target)
        y_pred: Estimated signal
        eps: Small value to avoid division by zero

    Returns:
        ISR in dB
    """
    reference = y_true.astype(np.float64) - np.mean(y_true)
    estimate = y_pred.astype(np.float64) - np.mean(y_pred)

    # Target part: estimate projected onto the reference.
    gain = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
    target = gain * reference

    # Stand-in for spatial distortion: what the projection leaves over.
    spatial_error = estimate - target

    target_energy = np.sum(np.square(target))
    spatial_energy = np.sum(np.square(spatial_error))
    return 10.0 * math.log10((target_energy + eps) / (spatial_energy + eps))


# Registry used to look metrics up by name: config name -> function (y_true, y_pred) -> float.
_METRIC_FUNCS = dict(
    mse=mse_metric,
    mae=mae_metric,
    snr_db=snr_db_metric,
    si_sdr=si_sdr_metric,
    sdr=sdr_metric,
    sir=sir_metric,
    sar=sar_metric,
    isr=isr_metric,
)

In [27]:
# ---------------------------
# 6) train_epoch & validate_epoch
# ---------------------------
def train_epoch(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer,
                criterion: nn.Module, device: torch.device, clip_grad: Optional[float] = 5.0):
    """
    Run one optimisation pass over `loader` and return the mean training loss.

    Args:
        model: Network to train; switched to train mode here.
        loader: Iterable yielding (noisy, clean) tensor pairs.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as criterion(estimate, clean).
        device: Device both tensors are moved to before the forward pass.
        clip_grad: Max gradient norm; clipping is skipped when falsy (None or 0).

    Returns:
        Sample-weighted average loss over the epoch (0.0 when the loader is empty).
    """
    model.train()
    loss_total, seen = 0.0, 0
    progress = tqdm(loader, desc="train", leave=False)
    for noisy_batch, clean_batch in progress:
        inputs = noisy_batch.to(device)
        targets = clean_batch.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        if clip_grad:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
        optimizer.step()

        # Weight each batch by its size so a short final batch is not over-counted.
        count = inputs.shape[0]
        loss_total += batch_loss.item() * count
        seen += count
        progress.set_postfix(loss=loss_total / seen)
    return loss_total / max(1, seen)


def validate_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                   device: torch.device, metric_names: List[str]):
    """
    Evaluate `model` on `loader` without gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable yielding (noisy, clean) tensor pairs. The per-sample
            indexing below (`[b, 0, :]`) requires rank-3 tensors; only index 0 of
            the second axis is scored.
        criterion: Loss called as criterion(estimate, clean).
        device: Device both tensors are moved to before the forward pass.
        metric_names: Names to look up in `_METRIC_FUNCS`.

    Returns:
        (avg_loss, avg_metrics): sample-weighted mean loss and a dict mapping each
        metric name to its per-sample mean as a plain Python float. Both are 0.0
        when the loader yields no samples.

    Raises:
        KeyError: If a requested metric name is not registered in `_METRIC_FUNCS`.
    """
    # Fail before running any forward pass rather than midway through the first batch.
    unknown = [m for m in metric_names if m not in _METRIC_FUNCS]
    if unknown:
        raise KeyError(f"Unknown metric name(s): {unknown}")

    model.eval()
    running_loss = 0.0
    n_samples = 0
    metric_sums = {m: 0.0 for m in metric_names}
    with torch.no_grad():
        pbar = tqdm(loader, desc="val", leave=False)
        for noisy, clean in pbar:
            noisy = noisy.to(device)
            clean = clean.to(device)
            est = model(noisy)
            loss = criterion(est, clean)
            batch_size = noisy.shape[0]
            running_loss += loss.item() * batch_size
            n_samples += batch_size
            # compute metrics sample-wise in numpy
            est_np = est.detach().cpu().numpy()
            clean_np = clean.detach().cpu().numpy()
            for b in range(batch_size):
                y_true = clean_np[b, 0, :]
                y_pred = est_np[b, 0, :]
                for m in metric_names:
                    # float() keeps the sums as Python floats even if a metric
                    # returns a numpy scalar, so the averages stay JSON-serialisable.
                    metric_sums[m] += float(_METRIC_FUNCS[m](y_true, y_pred))
            pbar.set_postfix(loss=running_loss / n_samples)

    # Same guard for loss and metrics: an empty loader yields zeros instead of
    # raising ZeroDivisionError.
    denom = max(1, n_samples)
    avg_loss = running_loss / denom
    avg_metrics = {m: metric_sums[m] / denom for m in metric_names}
    return avg_loss, avg_metrics

In [28]:
# ---------------------------
# 7) plotting helpers
# ---------------------------
def plot_history(history: dict, out_dir: str):
    """
    Write the training curves stored in `history` as PNG files.

    Produces `loss_curve.png` (train vs. validation loss) and one
    `metric_<name>.png` per entry of `history["metrics"]`.

    Args:
        history: Dict with optional "train_loss" / "val_loss" lists and an
            optional "metrics" dict mapping metric name to a list of values.
        out_dir: Directory for the figures; created if missing.

    Returns:
        The output directory as a Path.
    """
    target_dir = Path(out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    def _finish(ylabel, destination):
        # Shared axis decoration, then save the current figure and release it.
        plt.xlabel("epoch")
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)
        plt.savefig(destination)
        plt.close()

    # loss curve
    plt.figure()
    for series in ("train_loss", "val_loss"):
        plt.plot(history.get(series, []), label=series)
    _finish("loss", target_dir / "loss_curve.png")

    # one figure per tracked metric
    for metric_name, values in history.get("metrics", {}).items():
        plt.figure()
        plt.plot(values, label=metric_name)
        _finish(metric_name, target_dir / f"metric_{metric_name}.png")
    return target_dir

In [29]:
# ---------------------------
# 8) High-level train & evaluate pipeline
# ---------------------------
def train_and_evaluate(
    processed_root: str,
    dataset_name: str,
    train_config_path: str,
    metrics_config_path: str,
    output_dir: str = "model_output",
    mlflow_enabled: bool = False,
    seed: int = 42
) -> dict:
    """
    High-level entrypoint for training and evaluation. This is Airflow callable.

    Pipeline: load configs -> seed RNGs -> create train/test splits -> build
    loaders and a WaveUNet1D -> train for `epochs` epochs, checkpointing whenever
    the monitored quantity improves -> reload the best checkpoint -> final
    evaluation on the validation loader -> write history JSON and plots
    (and MLflow metrics/artifacts when enabled).

    Training-config keys read here (default in parentheses):
        train_frac (0.9), batch_size (required), num_workers (4),
        base_filters (24), depth (5), kernel_size (15),
        loss ("l1"; any other value selects MSE), lr (1e-4),
        lr_step (unset = no scheduler), lr_gamma (0.5),
        monitor_metric ("val_loss"), monitor_mode ("min"),
        epochs (50), clip_grad (5.0; null/0 disables clipping).
    Metrics-config keys: metrics (["mse"]).

    Args:
        processed_root: Root folder passed to the split helper and datasets.
        dataset_name: Dataset name passed to the split helper and datasets.
        train_config_path: Path to the training config file.
        metrics_config_path: Path to the metrics config file.
        output_dir: Root for checkpoints/, plots/ and history.json.
        mlflow_enabled: Log to MLflow when it is importable; otherwise a warning
            is emitted and training continues without it.
        seed: Seed for `random`, numpy and torch (including CUDA when available).

    Returns:
        Summary dict with keys "best_checkpoint", "final_val_loss",
        "final_metrics", "history_path" and "plot_dir".
    """

    # load configs
    train_cfg, metrics_cfg = load_configs(train_config_path, metrics_config_path)
    # set seeds
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # create train/test splits if needed
    create_train_test_splits(processed_root, dataset_name, train_frac=float(train_cfg.get("train_frac", 0.9)), seed=seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # dataset & dataloaders
    train_dataset = CleanNoisyDataset(processed_root, dataset_name, split="train")
    val_dataset = CleanNoisyDataset(processed_root, dataset_name, split="test")
    num_workers = int(train_cfg.get("num_workers", 4))
    # Pinned host memory only speeds up host->GPU copies; skip it on CPU-only hosts.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=train_cfg["batch_size"], shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=train_cfg["batch_size"], shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    # model
    model = WaveUNet1D(input_channels=1,
                       output_channels=1,
                       base_filters=int(train_cfg.get("base_filters", 24)),
                       depth=int(train_cfg.get("depth", 5)),
                       kernel_size=int(train_cfg.get("kernel_size", 15)))
    model.to(device)

    # criterion, optimizer, scheduler
    criterion = nn.L1Loss() if train_cfg.get("loss", "l1") == "l1" else nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=float(train_cfg.get("lr", 1e-4)))
    scheduler = None
    if train_cfg.get("lr_step"):
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=train_cfg["lr_step"], gamma=float(train_cfg.get("lr_gamma", 0.5)))

    # A null/0 clip_grad in the config disables clipping instead of crashing in float(None).
    clip_cfg = train_cfg.get("clip_grad", 5.0)
    clip_grad = float(clip_cfg) if clip_cfg else None

    # metrics and monitoring
    metric_names = metrics_cfg.get("metrics", ["mse"])
    monitor = train_cfg.get("monitor_metric", "val_loss")  # e.g. "val_loss" or "si_sdr"
    monitor_mode = train_cfg.get("monitor_mode", "min")   # "min" or "max"
    best_score = math.inf if monitor_mode == "min" else -math.inf
    best_ckpt_path = None

    # output dirs
    out_root = Path(output_dir)
    ckpt_dir = out_root / "checkpoints"; ckpt_dir.mkdir(parents=True, exist_ok=True)
    plot_dir = out_root / "plots"; plot_dir.mkdir(parents=True, exist_ok=True)
    history = {"train_loss": [], "val_loss": [], "metrics": {m: [] for m in metric_names}}

    # MLflow start run
    if mlflow_enabled and not _mlflow_available:
        logger.warning("MLflow requested but not available; continuing without MLflow.")
        mlflow_enabled = False
    if mlflow_enabled:
        mlflow.start_run()

    # Everything after start_run() is guarded so a failure (or a notebook interrupt)
    # still closes the run; otherwise the next start_run() in the same kernel fails
    # because a run is already active.
    try:
        if mlflow_enabled:
            mlflow.log_params(train_cfg)
            mlflow.log_params({"metrics_cfg": metrics_cfg})

        num_epochs = int(train_cfg.get("epochs", 50))
        for epoch in range(num_epochs):
            logger.info(f"Epoch {epoch+1}/{num_epochs}")
            t0 = time.time()
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device, clip_grad=clip_grad)
            val_loss, val_metrics = validate_epoch(model, val_loader, criterion, device, metric_names)
            logger.info(f"Epoch {epoch+1} train_loss={train_loss:.6f} val_loss={val_loss:.6f} metrics={val_metrics} time={(time.time()-t0):.1f}s")

            history["train_loss"].append(train_loss)
            history["val_loss"].append(val_loss)
            for m in metric_names:
                history["metrics"][m].append(val_metrics.get(m, None))

            # scheduler step
            if scheduler:
                scheduler.step()

            # monitoring and checkpointing
            # If monitor_metric is "val_loss" use that; if it's in val_metrics use that.
            if monitor == "val_loss":
                current = val_loss
            else:
                current = val_metrics.get(monitor)
                if current is None:
                    logger.warning(f"Monitor metric {monitor} not found in val metrics; defaulting to val_loss.")
                    current = val_loss

            is_better = (current < best_score) if monitor_mode == "min" else (current > best_score)
            if is_better:
                best_score = current
                # File name uses the 1-based epoch; the stored "epoch" field is 0-based.
                best_ckpt_path = ckpt_dir / f"best_{monitor}_{epoch+1}.pt"
                torch.save({
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "train_cfg": train_cfg,
                    "metrics_cfg": metrics_cfg,
                    "best_score": best_score
                }, str(best_ckpt_path))
                logger.info(f"Saved new best checkpoint: {best_ckpt_path}")
                if mlflow_enabled:
                    mlflow.log_metric(f"best_{monitor}", float(best_score), step=epoch)

            # log epoch metrics to mlflow
            if mlflow_enabled:
                mlflow.log_metric("train_loss", float(train_loss), step=epoch)
                mlflow.log_metric("val_loss", float(val_loss), step=epoch)
                for m, v in val_metrics.items():
                    mlflow.log_metric(m, float(v), step=epoch)

        # after training: load best model and evaluate on test set (here val set is test)
        if best_ckpt_path is not None:
            ckpt = torch.load(str(best_ckpt_path), map_location=device)
            model.load_state_dict(ckpt["model_state_dict"])
            logger.info(f"Loaded best model from {best_ckpt_path} for final evaluation.")
        else:
            logger.warning("No checkpoint saved during training; using last model for evaluation.")

        # Final evaluation (on val/test loader)
        final_loss, final_metrics = validate_epoch(model, val_loader, criterion, device, metric_names)
        logger.info(f"Final evaluation: loss={final_loss:.6f}, metrics={final_metrics}")

        # Save history & plots
        hist_json = out_root / "history.json"
        # default=float converts numpy scalar metrics (e.g. float32), which json
        # cannot serialise natively, so a long run is not lost at the last step.
        hist_json.write_text(json.dumps(history, indent=2, default=float))
        plot_history(history, plot_dir)

        # MLflow final logging & artifacts
        if mlflow_enabled:
            mlflow.log_metric("final_val_loss", float(final_loss))
            for m, v in final_metrics.items():
                mlflow.log_metric(f"final_{m}", float(v))
            # log artifacts
            mlflow.log_artifacts(str(plot_dir), artifact_path="plots")
            mlflow.log_artifact(str(hist_json), artifact_path="history")
            if best_ckpt_path:
                mlflow.log_artifact(str(best_ckpt_path), artifact_path="checkpoints")
    except BaseException:
        # BaseException so KeyboardInterrupt from the notebook also closes the run.
        if mlflow_enabled:
            mlflow.end_run(status="FAILED")
        raise

    if mlflow_enabled:
        mlflow.end_run()

    summary = {
        "best_checkpoint": str(best_ckpt_path) if best_ckpt_path else None,
        "final_val_loss": float(final_loss),
        "final_metrics": final_metrics,
        "history_path": str(hist_json),
        "plot_dir": str(plot_dir)
    }
    return summary

In [None]:
# Launch a full training + evaluation run with MLflow logging disabled.
# The call is intentionally left as the cell's final bare expression so the
# returned summary dict is shown as the cell output.
# NOTE(review): dataset and config paths are relative to the notebook's working
# directory -- confirm they exist before running.
train_and_evaluate(
    processed_root="guitar_dataset/processed",
    dataset_name="dataset2",
    train_config_path="train_config.yaml",
    metrics_config_path="metrics_config.yaml",
    mlflow_enabled=False
)

INFO:model_pipeline:Train/test subfolders already exist; skipping split creation.
INFO:model_pipeline:Epoch 1/40
train:   0%|          | 0/113 [00:00<?, ?it/s]

## Inference

In [None]:
# inference_pipeline.py
import os
import json
import time
from pathlib import Path
from typing import List, Tuple, Optional, Dict

import numpy as np
import soundfile as sf
import librosa
from tqdm.auto import tqdm

import torch
import torch.nn.functional as F

In [None]:
# ------------------------
# Helpers: chunking & IO
# ------------------------
def split_audio_into_chunks(
    audio: np.ndarray,
    sample_rate: int,
    chunk_duration: float
) -> Tuple[List[np.ndarray], int]:
    """
    Cut audio into consecutive, non-overlapping float32 chunks of
    `chunk_duration` seconds; the final chunk is zero-padded on the right.

    Input with more than one dimension is averaged over axis 1 first, i.e. a
    (samples, channels) layout is assumed for multi-channel audio.

    Returns:
      chunks: list of numpy arrays (each length == chunk_len)
      orig_len: original audio length in samples (for trimming after reconstruction)

    Raises:
      ValueError: if chunk_duration * sample_rate rounds to zero samples or less.
    """
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)  # to mono

    orig_len = len(audio)
    chunk_len = int(round(chunk_duration * sample_rate))
    if chunk_len <= 0:
        raise ValueError("chunk_duration too small for given sample_rate")

    chunks = []
    for start in range(0, orig_len, chunk_len):
        piece = audio[start:start + chunk_len]
        shortfall = chunk_len - len(piece)
        if shortfall > 0:
            # only the last window can come up short; fill the tail with zeros
            piece = np.pad(piece, (0, shortfall), mode='constant')
        chunks.append(piece.astype(np.float32))
    return chunks, orig_len


def reconstruct_from_chunks(chunks_preds: List[np.ndarray], orig_len: int) -> np.ndarray:
    """
    Join per-chunk predictions back into one float32 signal of `orig_len` samples.

    The chunks are concatenated along axis 0; surplus samples (the padding added
    when splitting) are cut off, and a result that is too short is zero-padded
    on the right. An empty chunk list yields `orig_len` zeros.
    """
    if not chunks_preds:
        return np.zeros(orig_len, dtype=np.float32)

    joined = np.concatenate(chunks_preds, axis=0)
    missing = orig_len - len(joined)
    if missing > 0:
        # shorter than requested (not expected in normal use): pad with zeros
        joined = np.pad(joined, (0, missing), mode='constant')
    else:
        joined = joined[:orig_len]
    return joined.astype(np.float32)

In [1]:
# ------------------------
# Model loading
# ------------------------
def load_best_model(checkpoint_path: str, device: Optional[torch.device] = None):
    """
    Rebuild a WaveUNet1D from a checkpoint and return it in eval mode.

    The checkpoint must contain 'model_state_dict'. Architecture settings are
    taken from its 'train_cfg' entry (or 'args' as a fallback); missing keys
    use the defaults listed below.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Target device; CUDA when available, otherwise CPU, if omitted.

    Returns:
        (model, train_cfg): the loaded model on `device` and the config dict
        found in the checkpoint (empty dict when none was stored).
    """
    if not device:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    ckpt = torch.load(checkpoint_path, map_location=device)

    # try to infer model config
    train_cfg = ckpt.get("train_cfg") or ckpt.get("args") or {}
    arch_defaults = {
        "base_filters": 24,
        "depth": 5,
        "kernel_size": 15,
        "input_channels": 1,
        "output_channels": 1,
    }
    arch = {key: int(train_cfg.get(key, fallback)) for key, fallback in arch_defaults.items()}

    model = WaveUNet1D(**arch)
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device)
    model.eval()
    return model, train_cfg

NameError: name 'Optional' is not defined

In [None]:
# ------------------------
# Core inference for a single recording
# ------------------------
def infer_single_recording(
    input_path: str,
    model: torch.nn.Module,
    sample_rate: int,
    chunk_duration: float,
    batch_size: int,
    device: torch.device,
    output_dir: str,
    checkpoint_path: str,
    dtype=np.float32
) -> Dict:
    """
    Run the model over one recording and write the result next to a metadata file.

    The file is loaded as mono at `sample_rate`, cut into fixed-length chunks,
    pushed through `model` in batches of `batch_size`, stitched back together
    and trimmed to the original length. Two files are written into `output_dir`:
    `<stem>_recon.wav` and `<stem>_meta.json`.

    Args:
        input_path: Audio file to process.
        model: Network called on a (batch, 1, chunk_len) float tensor; the first
            channel of its output is kept for each item.
        sample_rate: Rate used for loading and for the written wav.
        chunk_duration: Chunk length in seconds.
        batch_size: Number of chunks per forward pass.
        device: Device the input batches are moved to.
        output_dir: Destination directory; created if missing.
        checkpoint_path: Recorded in the metadata only.
        dtype: dtype each predicted chunk is cast to before reconstruction.

    Returns:
        The metadata dict that was also written to `<stem>_meta.json`.
    """
    source = Path(input_path)
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # load and ensure mono (the returned rate is not needed: it was requested explicitly)
    audio, _ = librosa.load(str(source), sr=sample_rate, mono=True)
    chunks, orig_len = split_audio_into_chunks(audio, sample_rate, chunk_duration)

    preds = []
    with torch.no_grad():
        for start in range(0, len(chunks), batch_size):
            stacked = np.stack(chunks[start:start + batch_size], axis=0)  # (B, L)
            batch = torch.from_numpy(stacked).float().unsqueeze(1).to(device)  # (B,1,L)
            outputs = model(batch).detach().cpu().numpy()
            # keep the first channel of every item as a 1D array
            preds.extend(item[0].astype(dtype) for item in outputs)

    # reconstruct
    reconstructed = reconstruct_from_chunks(preds, orig_len)

    # save
    out_path = target_dir / (source.stem + "_recon.wav")
    sf.write(str(out_path), reconstructed, sample_rate)

    # save metadata
    meta = {
        "input_path": str(source),
        "output_path": str(out_path),
        "checkpoint": str(checkpoint_path),
        "sample_rate": int(sample_rate),
        "chunk_duration": float(chunk_duration),
        "n_chunks": int(len(chunks)),
        "orig_len_samples": int(orig_len),
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    }
    meta_path = target_dir / (source.stem + "_meta.json")
    with open(meta_path, "w", encoding="utf-8") as handle:
        json.dump(meta, handle, indent=2)

    return meta

In [None]:
# ------------------------
# Top-level Airflow-friendly entrypoint
# ------------------------
def run_inference_pipeline(
    input_root: str,
    model_checkpoint: str,
    output_root: str,
    chunk_duration: float = 3.0,
    sample_rate: int = 22050,
    batch_size: int = 8,
    device_str: Optional[str] = None,
    glob_patterns: Optional[List[str]] = None,
    max_files: Optional[int] = None
) -> Dict:
    """
    Main function to be called by Airflow PythonOperator.

    Loads the model once, then processes every matching recording. A failure on
    one file is recorded in the summary instead of aborting the whole run.

    Parameters:
      - input_root: folder with recordings (will search recursively)
      - model_checkpoint: path to checkpoint (best model)
      - output_root: where to save reconstructed recordings and metadata
      - chunk_duration: seconds per chunk to pass to model
      - sample_rate: sampling rate for loading/saving
      - batch_size: inference batch size
      - device_str: 'cuda' or 'cpu' (if None, automatically picked)
      - glob_patterns: list of glob patterns to find files (default ['**/*.wav'])
      - max_files: optional cap on number of recordings to process

    Returns:
      summary dict (also written to <output_root>/inference_summary.json) with:
        - model_checkpoint, output_root
        - n_files_requested: number of distinct files selected for processing
        - n_processed: number of files processed successfully
        - n_failed: number of files that raised during processing
        - processed: one entry per file; failures carry "error" and "error_type"

    Raises:
      RuntimeError: if no file matches the patterns.

    NOTE(review): every output is written flat into output_root as
    <stem>_recon.wav, so two inputs with the same stem in different sub-folders
    overwrite each other. Left unchanged to keep the output layout stable.
    """
    device = torch.device(device_str if device_str is not None else ("cuda" if torch.cuda.is_available() else "cpu"))
    model, _ = load_best_model(model_checkpoint, device=device)

    input_root = Path(input_root)
    out_root = Path(output_root)
    out_root.mkdir(parents=True, exist_ok=True)
    glob_patterns = glob_patterns or ["**/*.wav"]

    # collect files: overlapping patterns must not schedule a file twice, and a
    # directory whose name happens to match a pattern is not a recording
    files = []
    seen = set()
    for pat in glob_patterns:
        for candidate in sorted(input_root.glob(pat)):
            if candidate in seen or not candidate.is_file():
                continue
            seen.add(candidate)
            files.append(candidate)
    if not files:
        raise RuntimeError(f"No audio files found in {input_root} with patterns {glob_patterns}")

    if max_files is not None:
        files = files[:max_files]

    processed = []
    n_failed = 0
    for f in tqdm(files, desc="Inference files", dynamic_ncols=True):
        try:
            meta = infer_single_recording(
                input_path=str(f),
                model=model,
                sample_rate=sample_rate,
                chunk_duration=chunk_duration,
                batch_size=batch_size,
                device=device,
                output_dir=str(out_root),
                checkpoint_path=model_checkpoint
            )
        except Exception as e:
            # don't crash whole DAG on single file; instead record error
            # (re-raise here instead if you want failure semantics in Airflow)
            n_failed += 1
            processed.append({"input_path": str(f), "error": str(e), "error_type": type(e).__name__})
        else:
            processed.append(meta)

    summary = {
        "model_checkpoint": str(model_checkpoint),
        "n_files_requested": len(files),
        # successes only; failures are counted separately so callers can detect them
        "n_processed": len(processed) - n_failed,
        "n_failed": n_failed,
        "output_root": str(out_root),
        "processed": processed
    }
    # save summary
    with open(out_root / "inference_summary.json", "w", encoding="utf-8") as fh:
        json.dump(summary, fh, indent=2)
    return summary

In [None]:
# ------------------------
# Example Airflow PythonOperator usage:
# ------------------------
# from airflow import DAG
# from airflow.operators.python import PythonOperator
# from datetime import datetime
#
# with DAG("waveunet_inference", start_date=datetime(2025,10,4), schedule=None, catchup=False) as dag:
#     infer_task = PythonOperator(
#         task_id="run_inference",
#         python_callable=run_inference_pipeline,
#         op_kwargs={
#             "input_root": "/mnt/data/to_process",
#             "model_checkpoint": "/path/to/best_checkpoint.pt",
#             "output_root": "/mnt/data/inference_out",
#             "chunk_duration": 3.0,
#             "sample_rate": 22050,
#             "batch_size": 8,
#             "device_str": "cuda",
#             "glob_patterns": ["**/*.wav"],
#             "max_files": 200
#         },
#     )
