In [1]:
!pip install soundfile

Collecting soundfile
  Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl.metadata (16 kB)
Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m0m
[?25hInstalling collected packages: soundfile
Successfully installed soundfile-0.13.1


In [2]:
!pip install librosa

Collecting librosa
  Downloading librosa-0.11.0-py3-none-any.whl.metadata (8.7 kB)
Collecting audioread>=2.1.9 (from librosa)
  Downloading audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)
Collecting pooch>=1.1 (from librosa)
  Downloading pooch-1.8.2-py3-none-any.whl.metadata (10 kB)
Collecting soxr>=0.3.2 (from librosa)
  Downloading soxr-1.0.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.6 kB)
Downloading librosa-0.11.0-py3-none-any.whl (260 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m260.7/260.7 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0mm
[?25hDownloading audioread-3.0.1-py3-none-any.whl (23 kB)
Downloading pooch-1.8.2-py3-none-any.whl (64 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.6/64.6 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading soxr-1.0.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (242 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[

## Data Engineering

In [1]:
import os
from pathlib import Path
import json
import random
import math
from typing import List, Tuple, Optional, Dict

import numpy as np
import librosa
import soundfile as sf
from scipy import signal
from tqdm import tqdm

In [None]:
# -------------------------
# Noise generator (your function, slightly hardened)
# -------------------------
def create_custom_noise_profile(duration, sample_rate, overall_gain_db=-25):
    """Return a noise array length = duration*sample_rate (dtype=float32)."""
    t = np.linspace(0, duration, int(duration * sample_rate), endpoint=False)
    harmonic_noise = np.zeros_like(t)

    # harmonic hum
    fundamental_freqs = [440, 516, 645]
    for fundamental in fundamental_freqs:
        for harmonic in range(2, 7):
            freq = fundamental * harmonic
            detuned_freq = freq * (1 + np.random.uniform(-0.01, 0.01))
            amplitude = 0.15 / harmonic
            am_depth = 0.1
            am_rate = 0.5
            am_mod = 1 + am_depth * np.sin(2 * np.pi * am_rate * t)
            harmonic_noise += amplitude * am_mod * np.sin(2 * np.pi * detuned_freq * t)

    # resonant peaks (narrow band)
    resonant_freqs = [3158, 3856, 5109]
    resonant_amplitudes = [0.08, 0.08, 0.06]
    resonant_noise = np.zeros_like(t)
    nyquist = sample_rate / 2.0
    for freq, amp in zip(resonant_freqs, resonant_amplitudes):
        white = np.random.normal(0, 1, len(t))
        low_cut = max(0.0001, (freq * 0.9) / nyquist)
        high_cut = min(0.9999, (freq * 1.1) / nyquist)
        try:
            b, a = signal.butter(4, [low_cut, high_cut], btype='band')
            narrow = signal.filtfilt(b, a, white)
        except Exception:
            narrow = white
        resonant_noise += amp * narrow

    # broadband hiss shaped by FIR
    white_noise = np.random.normal(0, 1, len(t))
    try:
        from scipy.signal import firwin2
        freq_points = [0, 1000, 1290, 1548, 1858, 2229, 2675, 4000, 6000, sample_rate/2]
        gain_response = [10, 15, 15, 15, 15, 15, 15, 8, 5, 5]
        norm = np.array(freq_points) / (sample_rate/2)
        norm = np.clip(norm, 0.0, 1.0)
        fir_coeffs = firwin2(1025, norm, gain_response)
        shaped_hiss = signal.filtfilt(fir_coeffs, [1.0], white_noise)
    except Exception:
        shaped_hiss = white_noise

    # combine -> apply notch -> normalize -> gain
    combined = harmonic_noise + resonant_noise + shaped_hiss
    try:
        b_notch, a_notch = signal.iirnotch(3179.3, 4, sample_rate)
        combined = signal.filtfilt(b_notch, a_notch, combined)
    except Exception:
        pass

    maxabs = np.max(np.abs(combined)) + 1e-12
    combined = combined / maxabs
    gain_lin = 10 ** (overall_gain_db / 20.0)
    return (combined * gain_lin).astype(np.float32)


# -------------------------
# SAD: find non-silent intervals (librosa)
# -------------------------
def detect_activity_intervals(audio: np.ndarray, sr: int, top_db: float = 30.0, frame_length: int = 2048, hop_length: int = 512) -> List[Tuple[int, int]]:
    """
    Return list of (start_sample, end_sample) intervals containing activity.
    Uses librosa.effects.split which is a simple SAD (energy thresholding).
    top_db: threshold in dB below reference to consider silence (lower -> more aggressive keep)
    """
    intervals = librosa.effects.split(y=audio, top_db=top_db, frame_length=frame_length, hop_length=hop_length)
    return [(int(s), int(e)) for s, e in intervals]


# -------------------------
# Helper: sample one fixed-length clip from non-silent intervals
# -------------------------
def sample_clip_from_intervals(audio: np.ndarray, sr: int, intervals: List[Tuple[int,int]], clip_duration: float, rng: Optional[random.Random] = None) -> np.ndarray:
    """
    Choose a random interval that can contain a clip of clip_duration.
    If no single interval is long enough, try to stitch or center-pad shorter audio to clip length.
    Returns a numpy array of length = clip_duration * sr
    """
    rng = rng or random
    clip_len = int(round(clip_duration * sr))
    # Filter intervals long enough
    long_intervals = [iv for iv in intervals if (iv[1] - iv[0]) >= clip_len]
    if long_intervals:
        s, e = rng.choice(long_intervals)
        start = rng.randint(s, e - clip_len)
        clip = audio[start:start + clip_len]
        return clip.astype(np.float32)
    # otherwise try to sample from any interval, possibly concatenating up to clip_len by wrapping/padding:
    if intervals:
        # pick a random interval, extract it, then either pad or loop to reach clip_len
        s, e = rng.choice(intervals)
        seg = audio[s:e].astype(np.float32)
        if len(seg) >= clip_len:
            # deterministic crop
            start = rng.randint(0, len(seg) - clip_len)
            return seg[start:start+clip_len]
        else:
            # repeat or pad center
            needed = clip_len - len(seg)
            left = needed // 2
            right = needed - left
            return np.pad(seg, (left, right), mode='constant', constant_values=0.0)
    # if no intervals (silent file) -> zero pad or use entire audio center
    if len(audio) >= clip_len:
        center = len(audio) // 2
        start = max(0, center - clip_len // 2)
        return audio[start:start + clip_len].astype(np.float32)
    else:
        return np.pad(audio.astype(np.float32), (0, clip_len - len(audio)), mode='constant')


# -------------------------
# RMS utilities
# -------------------------
def rms(x: np.ndarray, eps=1e-12) -> float:
    return float(np.sqrt(np.mean(x.astype(np.float64)**2) + eps))


# -------------------------
# Build one simulated mixture example (core of augmentation)
# -------------------------
def build_simulated_mixture(
    stem_paths: List[Path],
    sr: int,
    clip_duration: float = 3.0,
    min_stems: int = 1,
    max_stems: int = 8,
    energy_db_range: Tuple[float, float] = (-10.0, 10.0),
    rng: Optional[random.Random] = None,
    top_db: float = 30.0
) -> Tuple[np.ndarray, List[Dict]]:
    """
    Given a list of available stems (Paths), randomly select k stems (k in [min_stems, max_stems])
    and produce a clean mixture (1D numpy array of length clip_duration*sr) and metadata list.
    Metadata describes which stems, start samples, applied dB gains, original RMS.
    """
    rng = rng or random
    n_available = len(stem_paths)
    if n_available == 0:
        raise ValueError("No stems provided to build_simulated_mixture()")

    k = rng.randint(min_stems, min(max_stems, n_available))
    selected = rng.sample(stem_paths, k)
    clip_len = int(round(clip_duration * sr))

    mixture = np.zeros(clip_len, dtype=np.float32)
    metadata = []

    for p in selected:
        audio, file_sr = librosa.load(str(p), sr=None, mono=True)
        # resample if needed
        if file_sr != sr:
            audio = librosa.resample(audio, orig_sr=file_sr, target_sr=sr)
        intervals = detect_activity_intervals(audio, sr, top_db=top_db)
        clip = sample_clip_from_intervals(audio, sr, intervals, clip_duration, rng=rng)
        orig_rms = rms(clip)
        db_change = rng.uniform(energy_db_range[0], energy_db_range[1])
        gain_lin = 10 ** (db_change / 20.0)
        scaled = (clip * gain_lin).astype(np.float32)
        # sum to mixture
        mixture = mixture + scaled
        metadata.append({
            "stem_path": str(p),
            "db_change": float(db_change),
            "gain_lin": float(gain_lin),
            "orig_rms": float(orig_rms)
        })

    # After summing, avoid clipping: scale mixture by peak if needed, but preserve RMS relationships.
    peak = float(np.max(np.abs(mixture)) + 1e-12)
    if peak > 0.99:
        mixture = (mixture / peak * 0.99).astype(np.float32)

    return mixture, metadata


# -------------------------
# Add synthetic noise to get input features + return metadata
# -------------------------
def add_noise_to_mixture(
    clean_mixture: np.ndarray,
    sr: int,
    snr_db: float,
    noise_func=create_custom_noise_profile,
    overall_noise_gain_db: float = -25.0
) -> Tuple[np.ndarray, Dict]:
    """
    Create noise using noise_func, scale to target SNR with respect to clean_mixture RMS,
    return noisy_mixture and noise metadata.
    """
    duration = len(clean_mixture) / sr
    noise = noise_func(duration, sr, overall_gain_db=overall_noise_gain_db)
    if len(noise) > len(clean_mixture):
        noise = noise[:len(clean_mixture)]
    elif len(noise) < len(clean_mixture):
        noise = np.pad(noise, (0, len(clean_mixture)-len(noise)))

    rms_clean = rms(clean_mixture)
    rms_noise = rms(noise)
    target_lin = 10 ** (snr_db / 20.0)
    required_noise_rms = (rms_clean / target_lin) if target_lin > 0 else rms_clean
    noise_gain = (required_noise_rms / (rms_noise + 1e-12))
    adjusted_noise = (noise * noise_gain).astype(np.float32)
    noisy = clean_mixture + adjusted_noise

    # prevent clipping
    peak = float(np.max(np.abs(noisy)) + 1e-12)
    if peak > 1.0:
        noisy = (noisy / peak * 0.99).astype(np.float32)

    meta = {
        "snr_db_target": float(snr_db),
        "rms_clean": float(rms_clean),
        "rms_noise_before_gain": float(rms_noise),
        "noise_gain": float(noise_gain),
        "overall_noise_profile_db": float(overall_noise_gain_db)
    }
    return noisy, meta


# -------------------------
# Pipeline orchestration (Airflow friendly)
# -------------------------
def run_augmentation_pipeline(
    stems_root: str,
    output_base: str = "dataset",
    dataset_name: str = "aug_mixtures_v1",
    sample_rate: int = 22050,
    clip_duration: float = 3.0,
    min_stems: int = 1,
    max_stems: int = 8,
    energy_db_range: Tuple[float,float] = (-10.0, 10.0),
    snr_db_range: Tuple[float,float] = (5.0, 20.0),
    max_files: Optional[int] = None,         # restrict number of stems considered (N); None -> all
    n_examples: int = 1000,                  # number of augmented examples to synthesize
    top_db: float = 30.0,
    seed: Optional[int] = None
):
    """
    High-level pipeline:
      - discover stems (N)
      - limit to max_files if provided
      - for i in [0..n_examples): create one mixture example:
          - randomly choose between min_stems..max_stems stems
          - sample clip_duration from each stem (using SAD)
          - apply dB scaling per stem
          - sum -> clean_mixture (target)
          - sample snr in snr_db_range -> create noisy (feature)
          - save noisy and clean wavs + metadata json
    """
    rng = random.Random(seed)
    stems_root = Path(stems_root)
    stem_paths = sorted([p for p in stems_root.rglob("*") if p.suffix.lower() in (".wav", ".flac", ".mp3")])
    if not stem_paths:
        raise RuntimeError(f"No audio stems found in {stems_root}")

    if max_files is not None:
        stem_paths = stem_paths[:max_files]

    out_root = Path(output_base) / "processed" / dataset_name
    clean_dir = out_root / "clean"
    noisy_dir = out_root / "noisy"
    meta_dir = out_root / "meta"
    clean_dir.mkdir(parents=True, exist_ok=True)
    noisy_dir.mkdir(parents=True, exist_ok=True)
    meta_dir.mkdir(parents=True, exist_ok=True)

    print(f"Found {len(stem_paths)} stems. Will create {n_examples} examples using up to {max_stems} stems each.")

    for idx in tqdm(range(n_examples), desc="Synth examples"):
        # Build clean mixture
        mixture, stems_meta = build_simulated_mixture(
            stem_paths=stem_paths,
            sr=sample_rate,
            clip_duration=clip_duration,
            min_stems=min_stems,
            max_stems=max_stems,
            energy_db_range=energy_db_range,
            rng=rng,
            top_db=top_db
        )

        # Choose an SNR for noise
        snr = float(rng.uniform(snr_db_range[0], snr_db_range[1]))
        noisy, noise_meta = add_noise_to_mixture(mixture, sr=sample_rate, snr_db=snr)

        # Save files
        basename = f"example_{idx:06d}"
        clean_path = clean_dir / f"{basename}_clean.wav"
        noisy_path = noisy_dir / f"{basename}_noisy.wav"
        meta_path = meta_dir / f"{basename}.json"

        # sf.write(str(clean_path), mixture, sample_rate)
        # sf.write(str(noisy_path), noisy, sample_rate)

        # Derived components
        target = mixture                      # clean signal
        residual = noisy - mixture             # noise component

        target_path = clean_dir / f"{basename}_target.wav"
        residual_path = clean_dir / f"{basename}_residual.wav"
        mix_path = noisy_dir / f"{basename}_mix.wav"

        sf.write(str(target_path), target, sample_rate)
        sf.write(str(residual_path), residual, sample_rate)
        sf.write(str(mix_path), noisy, sample_rate)

        full_meta = {
            "example_id": basename,
            "sample_rate": int(sample_rate),
            "clip_duration": float(clip_duration),
            "chosen_snr_db": float(snr),
            "paths": {
                "target": str(target_path),
                "residual": str(residual_path),
                "mix": str(mix_path),
            },
            "stems_meta": stems_meta,
            "noise_meta": noise_meta
        }

        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(full_meta, f, indent=2)

    print(f"Saved {n_examples} examples to {out_root}")
    return str(out_root)


In [3]:
input_dir = "IDMT-SMT-GUITAR_V2/dataset2/audio/"

run_augmentation_pipeline(
    stems_root=input_dir,
    output_base="guitar_dataset",
    dataset_name="dataset2-v2",
    n_examples=200
)

Found 261 stems. Will create 200 examples using up to 8 stems each.


Synth examples:   0%|          | 0/200 [00:00<?, ?it/s]

Synth examples: 100%|██████████| 200/200 [01:59<00:00,  1.67it/s]

Saved 200 examples to guitar_dataset\processed\dataset2-v2





'guitar_dataset\\processed\\dataset2-v2'

## Model Engineering

In [198]:
"""
model_pipeline.py

Airflow-friendly model engineering module:
- load configs
- dataset creation (clean/noisy pairs)
- model build (Wave-U-Net 1D)
- train_epoch / validate_epoch
- checkpointing, plotting
- optional MLflow logging
"""

import os
import json
from pathlib import Path
import random
import math
import time
import logging
from typing import Tuple, Dict, List, Optional

import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [199]:
# Optional YAML support
try:
    import yaml
except Exception:
    yaml = None

# Optional MLflow
try:
    import mlflow
    _mlflow_available = True
except Exception:
    mlflow = None
    _mlflow_available = False

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("model_pipeline")

In [200]:
# ---------------------------
# 1) Config loading utility
# ---------------------------
def load_configs(train_config_path: str, metrics_config_path: str) -> Tuple[dict, dict]:
    """
    Load training configuration and metrics/plot configuration from YAML or JSON.
    Returns (train_cfg, metrics_cfg)
    """
    def _load(path: str):
        p = Path(path)
        if not p.exists():
            raise FileNotFoundError(f"Config file not found: {path}")
        text = p.read_text(encoding="utf-8")
        if yaml and (p.suffix.lower() in (".yml", ".yaml")):
            return yaml.safe_load(text)
        # try json
        try:
            return json.loads(text)
        except Exception:
            if yaml:
                return yaml.safe_load(text)
            raise

    train_cfg = _load(train_config_path)
    metrics_cfg = _load(metrics_config_path)
    return train_cfg, metrics_cfg

In [201]:
# ---------------------------
# 2) Dataset for target/residual/mix triplets
# ---------------------------
class WaveUNetDataset(Dataset):
    """
    Loads triplets from processed dataset directory:
      processed/{dataset_name}/clean/*_target.wav
      processed/{dataset_name}/clean/*_residual.wav
      processed/{dataset_name}/noisy/*_mix.wav

    It matches by basename prefix (e.g. example_000001_target.wav / example_000001_residual.wav / example_000001_mix.wav)
    """

    def __init__(self, processed_root: str, dataset_name: str, split: str = "train"):
        """
        processed_root: base path (e.g., datasets/processed)
        dataset_name: the dataset folder name
        split: 'train' or 'test' - supports:
            - processed/{dataset_name}/clean + /noisy
            - processed/{dataset_name}/{split}/clean + /noisy
        """
        base = Path(processed_root) / dataset_name
        if not base.exists():
            raise RuntimeError(f"Processed dataset not found: {base}")

        # possible directory layouts
        possible_clean = base / "clean"
        possible_noisy = base / "noisy"
        alt_clean = base / split / "clean"
        alt_noisy = base / split / "noisy"

        if alt_clean.exists() and alt_noisy.exists():
            self.clean_dir = alt_clean
            self.noisy_dir = alt_noisy
        elif possible_clean.exists() and possible_noisy.exists():
            self.clean_dir = possible_clean
            self.noisy_dir = possible_noisy
        else:
            raise RuntimeError(f"Could not find clean/noisy directories in {base} or {base}/{split}")

        # list all relevant files
        target_files = sorted(self.clean_dir.glob("*_target.wav"))
        residual_files = sorted(self.clean_dir.glob("*_residual.wav"))
        mix_files = sorted(self.noisy_dir.glob("*_mix.wav"))

        # build basename maps
        def key_from_path(p: Path):
            stem = p.stem
            for s in ("_target", "_residual", "_mix"):
                if stem.endswith(s):
                    return stem[: -len(s)]
            return stem

        target_map = {key_from_path(p): p for p in target_files}
        residual_map = {key_from_path(p): p for p in residual_files}
        mix_map = {key_from_path(p): p for p in mix_files}

        # intersection of all three
        keys = sorted(list(set(target_map.keys()) & set(residual_map.keys()) & set(mix_map.keys())))
        if not keys:
            raise RuntimeError(f"No matching target/residual/mix triplets found in {self.clean_dir} and {self.noisy_dir}")

        self.triplets = [(target_map[k], residual_map[k], mix_map[k]) for k in keys]

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        target_p, residual_p, mix_p = self.triplets[idx]

        # load
        target, sr1 = sf.read(str(target_p))
        residual, sr2 = sf.read(str(residual_p))
        mix, sr3 = sf.read(str(mix_p))

        # ensure mono + same sr
        if target.ndim > 1:
            target = np.mean(target, axis=1)
        if residual.ndim > 1:
            residual = np.mean(residual, axis=1)
        if mix.ndim > 1:
            mix = np.mean(mix, axis=1)
        if not (sr1 == sr2 == sr3):
            raise RuntimeError(f"Sample rates mismatch for example {target_p.stem}")

        # make sure lengths align
        min_len = min(len(target), len(residual), len(mix))
        target = target[:min_len]
        residual = residual[:min_len]
        mix = mix[:min_len]

        # convert to (1, L) tensors
        target = torch.from_numpy(target.astype(np.float32)).unsqueeze(0)
        residual = torch.from_numpy(residual.astype(np.float32)).unsqueeze(0)
        mix = torch.from_numpy(mix.astype(np.float32)).unsqueeze(0)

        return mix, target, residual  # input, clean target, noise residual

In [202]:
# ---------------------------
# 3) Splitting helper (for target/residual/mix dataset)
# ---------------------------
def create_train_test_splits(processed_root: str, dataset_name: str, train_frac: float = 0.9, seed: int = 42):
    """
    Creates train/test subfolders for WaveUNetDataset-style processed datasets.
    Works with:
      processed/{dataset_name}/clean/*_target.wav
      processed/{dataset_name}/clean/*_residual.wav
      processed/{dataset_name}/noisy/*_mix.wav

    Output:
      processed/{dataset_name}/train/{clean,noisy}
      processed/{dataset_name}/test/{clean,noisy}

    Returns:
      (train_dataset_dir, test_dataset_dir)
    """
    base = Path(processed_root) / dataset_name
    train_clean = base / "train" / "clean"
    train_noisy = base / "train" / "noisy"
    test_clean = base / "test" / "clean"
    test_noisy = base / "test" / "noisy"

    # if already split, skip
    if train_clean.exists() and test_clean.exists():
        print("Train/test subfolders already exist; skipping split creation.")
        return str(base), str(base)

    # locate files
    all_target = sorted((base / "clean").glob("*_target.wav"))
    all_residual = sorted((base / "clean").glob("*_residual.wav"))
    all_mix = sorted((base / "noisy").glob("*_mix.wav"))

    if not all_target or not all_residual or not all_mix:
        raise RuntimeError(f"Missing one or more sets of files in {base} (need target, residual, mix)")

    # match keys like WaveUNetDataset
    def key(p: Path):
        s = p.stem
        for suf in ("_target", "_residual", "_mix"):
            if s.endswith(suf):
                return s[:-len(suf)]
        return s

    target_map = {key(p): p for p in all_target}
    residual_map = {key(p): p for p in all_residual}
    mix_map = {key(p): p for p in all_mix}

    keys = sorted(list(set(target_map.keys()) & set(residual_map.keys()) & set(mix_map.keys())))
    if not keys:
        raise RuntimeError(f"No matching triplets found under {base}")

    # random split
    rng = random.Random(seed)
    rng.shuffle(keys)
    n_train = int(math.floor(len(keys) * train_frac))
    train_keys = keys[:n_train]
    test_keys = keys[n_train:]

    # make dirs
    for d in [train_clean, train_noisy, test_clean, test_noisy]:
        d.mkdir(parents=True, exist_ok=True)

    from shutil import copy2
    def copy_triplet(k: str, dst_clean: Path, dst_noisy: Path):
        copy2(target_map[k], dst_clean / target_map[k].name)
        copy2(residual_map[k], dst_clean / residual_map[k].name)
        copy2(mix_map[k], dst_noisy / mix_map[k].name)

    for k in train_keys:
        copy_triplet(k, train_clean, train_noisy)
    for k in test_keys:
        copy_triplet(k, test_clean, test_noisy)

    print(f"Created train/test split: {len(train_keys)} train, {len(test_keys)} test examples.")
    return str(base), str(base)


In [203]:
# ---------------------------
# 5) Wave-UNET Architecture
# ---------------------------

# conv.py

class ConvLayer(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, conv_type, transpose=False):
        super(ConvLayer, self).__init__()
        self.transpose = transpose
        self.stride = stride
        self.kernel_size = kernel_size
        self.conv_type = conv_type

        # How many channels should be normalised as one group if GroupNorm is activated
        # WARNING: Number of channels has to be divisible by this number!
        NORM_CHANNELS = 8

        if self.transpose:
            self.filter = nn.ConvTranspose1d(n_inputs, n_outputs, self.kernel_size, stride, padding=kernel_size-1)
        else:
            self.filter = nn.Conv1d(n_inputs, n_outputs, self.kernel_size, stride)

        if conv_type == "gn":
            assert(n_outputs % NORM_CHANNELS == 0)
            self.norm = nn.GroupNorm(n_outputs // NORM_CHANNELS, n_outputs)
        elif conv_type == "bn":
            self.norm = nn.BatchNorm1d(n_outputs, momentum=0.01)
        # Add you own types of variations here!

    def forward(self, x):
        # Apply the convolution
        if self.conv_type == "gn" or self.conv_type == "bn":
            out = nn.functional.relu(self.norm((self.filter(x))))
        else: # Add your own variations here with elifs conditioned on "conv_type" parameter!
            assert(self.conv_type == "normal")
            out =nn.functional.leaky_relu(self.filter(x))
        return out

    def get_input_size(self, output_size):
        # Strided conv/decimation
        if not self.transpose:
            curr_size = (output_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1
        else:
            curr_size = output_size

        # Conv
        curr_size = curr_size + self.kernel_size - 1 # o = i + p - k + 1

        # Transposed
        if self.transpose:
            assert ((curr_size - 1) % self.stride == 0)# We need to have a value at the beginning and end
            curr_size = ((curr_size - 1) // self.stride) + 1
        assert(curr_size > 0)
        return curr_size

    def get_output_size(self, input_size):
        # Transposed
        if self.transpose:
            assert(input_size > 1)
            curr_size = (input_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1
        else:
            curr_size = input_size

        # Conv
        curr_size = curr_size - self.kernel_size + 1 # o = i + p - k + 1
        assert (curr_size > 0)

        # Strided conv/decimation
        if not self.transpose:
            assert ((curr_size - 1) % self.stride == 0)  # We need to have a value at the beginning and end
            curr_size = ((curr_size - 1) // self.stride) + 1

        return curr_size

In [204]:
# crop.py

def centre_crop(x, target):
    '''
    Center-crop 3-dim. input tensor along last axis so it fits the target tensor shape
    :param x: Input tensor
    :param target: Shape of this tensor will be used as target shape
    :return: Cropped input tensor
    '''
    if x is None:
        return None
    if target is None:
        return x

    target_shape = target.shape
    diff = x.shape[-1] - target_shape[-1]
    assert (diff % 2 == 0)
    crop = diff // 2

    if crop == 0:
        return x
    if crop < 0:
        raise ArithmeticError

    return x[:, :, crop:-crop].contiguous()

In [205]:
# resample.py

class Resample1d(nn.Module):
    def __init__(self, channels, kernel_size, stride, transpose=False, padding="reflect", trainable=False):
        '''
        Creates a resampling layer for time series data (using 1D convolution) - (N, C, W) input format
        :param channels: Number of features C at each time-step
        :param kernel_size: Width of sinc-based lowpass-filter (>= 15 recommended for good filtering performance)
        :param stride: Resampling factor (integer)
        :param transpose: False for down-, true for upsampling
        :param padding: Either "reflect" to pad or "valid" to not pad
        :param trainable: Optionally activate this to train the lowpass-filter, starting from the sinc initialisation
        '''
        super(Resample1d, self).__init__()

        self.padding = padding
        self.kernel_size = kernel_size
        self.stride = stride
        self.transpose = transpose
        self.channels = channels

        cutoff = 0.5 / stride

        assert(kernel_size > 2)
        assert ((kernel_size - 1) % 2 == 0)
        assert(padding == "reflect" or padding == "valid")

        filter = build_sinc_filter(kernel_size, cutoff)

        self.filter = torch.nn.Parameter(torch.from_numpy(np.repeat(np.reshape(filter, [1, 1, kernel_size]), channels, axis=0)), requires_grad=trainable)

    def forward(self, x):
        # Pad here if not using transposed conv
        input_size = x.shape[2]
        if self.padding != "valid":
            num_pad = (self.kernel_size-1)//2
            out = nn.functional.pad(x, (num_pad, num_pad), mode=self.padding)
        else:
            out = x

        # Lowpass filter (+ 0 insertion if transposed)
        if self.transpose:
            expected_steps = ((input_size - 1) * self.stride + 1)
            if self.padding == "valid":
                expected_steps = expected_steps - self.kernel_size + 1

            out = nn.functional.conv_transpose1d(out, self.filter, stride=self.stride, padding=0, groups=self.channels)
            diff_steps = out.shape[2] - expected_steps
            if diff_steps > 0:
                assert(diff_steps % 2 == 0)
                out = out[:,:,diff_steps//2:-diff_steps//2]
        else:
            assert(input_size % self.stride == 1)
            out = nn.functional.conv1d(out, self.filter, stride=self.stride, padding=0, groups=self.channels)

        return out

    def get_output_size(self, input_size):
        '''
        Returns the output dimensionality (number of timesteps) for a given input size
        :param input_size: Number of input time steps (Scalar, each feature is one-dimensional)
        :return: Output size (scalar)
        '''
        assert(input_size > 1)
        if self.transpose:
            if self.padding == "valid":
                return ((input_size - 1) * self.stride + 1) - self.kernel_size + 1
            else:
                return ((input_size - 1) * self.stride + 1)
        else:
            assert(input_size % self.stride == 1) # Want to take first and last sample
            if self.padding == "valid":
                return input_size - self.kernel_size + 1
            else:
                return input_size

    def get_input_size(self, output_size):
        '''
        Returns the input dimensionality (number of timesteps) for a given output size
        :param input_size: Number of input time steps (Scalar, each feature is one-dimensional)
        :return: Output size (scalar)
        '''

        # Strided conv/decimation
        if not self.transpose:
            curr_size = (output_size - 1)*self.stride + 1 # o = (i-1)//s + 1 => i = (o - 1)*s + 1
        else:
            curr_size = output_size

        # Conv
        if self.padding == "valid":
            curr_size = curr_size + self.kernel_size - 1 # o = i + p - k + 1

        # Transposed
        if self.transpose:
            assert ((curr_size - 1) % self.stride == 0)# We need to have a value at the beginning and end
            curr_size = ((curr_size - 1) // self.stride) + 1
        assert(curr_size > 0)
        return curr_size

def build_sinc_filter(kernel_size, cutoff):
    # FOLLOWING https://www.analog.com/media/en/technical-documentation/dsp-book/dsp_book_Ch16.pdf
    # Sinc lowpass filter
    # Build sinc kernel
    assert(kernel_size % 2 == 1)
    M = kernel_size - 1
    filter = np.zeros(kernel_size, dtype=np.float32)
    for i in range(kernel_size):
        if i == M//2:
            filter[i] = 2 * np.pi * cutoff
        else:
            filter[i] = (np.sin(2 * np.pi * cutoff * (i - M//2)) / (i - M//2)) * \
                    (0.42 - 0.5 * np.cos((2 * np.pi * i) / M) + 0.08 * np.cos(4 * np.pi * M))

    filter = filter / np.sum(filter)
    return filter

In [206]:
# utils.py

def save_model(model, optimizer, state, path):
    if isinstance(model, torch.nn.DataParallel):
        model = model.module  # save state dict of wrapped module
    if len(os.path.dirname(path)) > 0 and not os.path.exists(os.path.dirname(path)):
        os.makedirs(os.path.dirname(path))
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'state': state,  # state of training loop (was 'step')
    }, path)


def load_model(model, optimizer, path, cuda):
    if isinstance(model, torch.nn.DataParallel):
        model = model.module  # load state dict of wrapped module
    if cuda:
        checkpoint = torch.load(path)
    else:
        checkpoint = torch.load(path, map_location='cpu')
    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except:
        # work-around for loading checkpoints where DataParallel was saved instead of inner module
        from collections import OrderedDict
        model_state_dict_fixed = OrderedDict()
        prefix = 'module.'
        for k, v in checkpoint['model_state_dict'].items():
            if k.startswith(prefix):
                k = k[len(prefix):]
            model_state_dict_fixed[k] = v
        model.load_state_dict(model_state_dict_fixed)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if 'state' in checkpoint:
        state = checkpoint['state']
    else:
        # older checkpoints only store step, rest of state won't be there
        state = {'step': checkpoint['step']}
    return state


def compute_loss(model, inputs, targets, criterion, compute_grad=False):
    '''
    Computes gradients of model with given inputs and targets and loss function.
    Optionally backpropagates to compute gradients for weights.
    Procedure depends on whether we have one model for each source or not
    :param model: Model to train with
    :param inputs: Input mixture
    :param targets: Target sources
    :param criterion: Loss function to use (L1, L2, ..)
    :param compute_grad: Whether to compute gradients
    :return: Model outputs, Average loss over batch
    '''
    all_outputs = {}

    if model.separate:
        avg_loss = 0.0
        num_sources = 0
        for inst in model.instruments:
            output = model(inputs, inst)
            loss = criterion(output[inst], targets[inst])

            if compute_grad:
                loss.backward()

            avg_loss += loss.item()
            num_sources += 1

            all_outputs[inst] = output[inst].detach().clone()

        avg_loss /= float(num_sources)
    else:
        loss = 0
        all_outputs = model(inputs)
        for inst in all_outputs.keys():
            loss += criterion(all_outputs[inst], targets[inst])

        if compute_grad:
            loss.backward()

        avg_loss = loss.item() / float(len(all_outputs))

    return all_outputs, avg_loss


class DataParallel(torch.nn.DataParallel):
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__(module, device_ids, output_device, dim)

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)

In [207]:
# wave_unet.py

class UpsamplingBlock(nn.Module):
    def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
        super(UpsamplingBlock, self).__init__()
        assert(stride > 1)

        # CONV 1 for UPSAMPLING
        if res == "fixed":
            self.upconv = Resample1d(n_inputs, 15, stride, transpose=True)
        else:
            self.upconv = ConvLayer(n_inputs, n_inputs, kernel_size, stride, conv_type, transpose=True)

        self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_outputs, kernel_size, 1, conv_type)] +
                                                [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])

        # CONVS to combine high- with low-level information (from shortcut)
        self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_outputs + n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
                                                 [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in range(depth - 1)])

    def forward(self, x, shortcut):
        # UPSAMPLE HIGH-LEVEL FEATURES
        upsampled = self.upconv(x)

        for conv in self.pre_shortcut_convs:
            upsampled = conv(upsampled)

        # Prepare shortcut connection
        combined = centre_crop(shortcut, upsampled)

        # Combine high- and low-level features
        for conv in self.post_shortcut_convs:
            combined = conv(torch.cat([combined, centre_crop(upsampled, combined)], dim=1))
        return combined

    def get_output_size(self, input_size):
        curr_size = self.upconv.get_output_size(input_size)

        # Upsampling convs
        for conv in self.pre_shortcut_convs:
            curr_size = conv.get_output_size(curr_size)

        # Combine convolutions
        for conv in self.post_shortcut_convs:
            curr_size = conv.get_output_size(curr_size)

        return curr_size

class DownsamplingBlock(nn.Module):
    def __init__(self, n_inputs, n_shortcut, n_outputs, kernel_size, stride, depth, conv_type, res):
        super(DownsamplingBlock, self).__init__()
        assert(stride > 1)

        self.kernel_size = kernel_size
        self.stride = stride

        # CONV 1
        self.pre_shortcut_convs = nn.ModuleList([ConvLayer(n_inputs, n_shortcut, kernel_size, 1, conv_type)] +
                                                [ConvLayer(n_shortcut, n_shortcut, kernel_size, 1, conv_type) for _ in range(depth - 1)])

        self.post_shortcut_convs = nn.ModuleList([ConvLayer(n_shortcut, n_outputs, kernel_size, 1, conv_type)] +
                                                 [ConvLayer(n_outputs, n_outputs, kernel_size, 1, conv_type) for _ in
                                                  range(depth - 1)])

        # CONV 2 with decimation
        if res == "fixed":
            self.downconv = Resample1d(n_outputs, 15, stride) # Resampling with fixed-size sinc lowpass filter
        else:
            self.downconv = ConvLayer(n_outputs, n_outputs, kernel_size, stride, conv_type)

    def forward(self, x):
        # PREPARING SHORTCUT FEATURES
        shortcut = x
        for conv in self.pre_shortcut_convs:
            shortcut = conv(shortcut)

        # PREPARING FOR DOWNSAMPLING
        out = shortcut
        for conv in self.post_shortcut_convs:
            out = conv(out)

        # DOWNSAMPLING
        out = self.downconv(out)

        return out, shortcut

    def get_input_size(self, output_size):
        curr_size = self.downconv.get_input_size(output_size)

        for conv in reversed(self.post_shortcut_convs):
            curr_size = conv.get_input_size(curr_size)

        for conv in reversed(self.pre_shortcut_convs):
            curr_size = conv.get_input_size(curr_size)
        return curr_size

class Waveunet(nn.Module):
    def __init__(self, num_inputs, num_channels, num_outputs, instruments, kernel_size, target_output_size, conv_type, res, separate=False, depth=1, strides=2):
        super(Waveunet, self).__init__()

        self.num_levels = len(num_channels)
        self.strides = strides
        self.kernel_size = kernel_size
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.depth = depth
        self.instruments = instruments
        self.separate = separate

        # Only odd filter kernels allowed
        assert(kernel_size % 2 == 1)

        self.waveunets = nn.ModuleDict()

        model_list = instruments if separate else ["ALL"]
        # Create a model for each source if we separate sources separately, otherwise only one (model_list=["ALL"])
        for instrument in model_list:
            module = nn.Module()

            module.downsampling_blocks = nn.ModuleList()
            module.upsampling_blocks = nn.ModuleList()

            for i in range(self.num_levels - 1):
                in_ch = num_inputs if i == 0 else num_channels[i]

                module.downsampling_blocks.append(
                    DownsamplingBlock(in_ch, num_channels[i], num_channels[i+1], kernel_size, strides, depth, conv_type, res))

            for i in range(0, self.num_levels - 1):
                module.upsampling_blocks.append(
                    UpsamplingBlock(num_channels[-1-i], num_channels[-2-i], num_channels[-2-i], kernel_size, strides, depth, conv_type, res))

            module.bottlenecks = nn.ModuleList(
                [ConvLayer(num_channels[-1], num_channels[-1], kernel_size, 1, conv_type) for _ in range(depth)])

            # Output conv
            outputs = num_outputs if separate else num_outputs * len(instruments)
            module.output_conv = nn.Conv1d(num_channels[0], outputs, 1)

            self.waveunets[instrument] = module

        self.set_output_size(target_output_size)

    def set_output_size(self, target_output_size):
        self.target_output_size = target_output_size

        self.input_size, self.output_size = self.check_padding(target_output_size)
        print("Using valid convolutions with " + str(self.input_size) + " inputs and " + str(self.output_size) + " outputs")

        assert((self.input_size - self.output_size) % 2 == 0)
        self.shapes = {"output_start_frame" : (self.input_size - self.output_size) // 2,
                       "output_end_frame" : (self.input_size - self.output_size) // 2 + self.output_size,
                       "output_frames" : self.output_size,
                       "input_frames" : self.input_size}

    def check_padding(self, target_output_size):
        # Ensure number of outputs covers a whole number of cycles so each output in the cycle is weighted equally during training
        bottleneck = 1

        while True:
            out = self.check_padding_for_bottleneck(bottleneck, target_output_size)
            if out is not False:
                return out
            bottleneck += 1

    def check_padding_for_bottleneck(self, bottleneck, target_output_size):
        module = self.waveunets[[k for k in self.waveunets.keys()][0]]
        try:
            curr_size = bottleneck
            for idx, block in enumerate(module.upsampling_blocks):
                curr_size = block.get_output_size(curr_size)
            output_size = curr_size

            # Bottleneck-Conv
            curr_size = bottleneck
            for block in reversed(module.bottlenecks):
                curr_size = block.get_input_size(curr_size)
            for idx, block in enumerate(reversed(module.downsampling_blocks)):
                curr_size = block.get_input_size(curr_size)

            assert(output_size >= target_output_size)
            return curr_size, output_size
        except AssertionError as e:
            return False

    def forward_module(self, x, module):
        '''
        A forward pass through a single Wave-U-Net (multiple Wave-U-Nets might be used, one for each source)
        :param x: Input mix
        :param module: Network module to be used for prediction
        :return: Source estimates
        '''
        shortcuts = []
        out = x

        # DOWNSAMPLING BLOCKS
        for block in module.downsampling_blocks:
            out, short = block(out)
            shortcuts.append(short)

        # BOTTLENECK CONVOLUTION
        for conv in module.bottlenecks:
            out = conv(out)

        # UPSAMPLING BLOCKS
        for idx, block in enumerate(module.upsampling_blocks):
            out = block(out, shortcuts[-1 - idx])

        # OUTPUT CONV
        out = module.output_conv(out)
        if not self.training:  # At test time clip predictions to valid amplitude range
            out = out.clamp(min=-1.0, max=1.0)
        return out

    def forward(self, x, inst=None):
        curr_input_size = x.shape[-1]
        assert(curr_input_size == self.input_size) # User promises to feed the proper input himself, to get the pre-calculated (NOT the originally desired) output size

        if self.separate:
            return {inst : self.forward_module(x, self.waveunets[inst])}
        else:
            assert(len(self.waveunets) == 1)
            out = self.forward_module(x, self.waveunets["ALL"])

            out_dict = {}
            for idx, inst in enumerate(self.instruments):
                out_dict[inst] = out[:, idx * self.num_outputs:(idx + 1) * self.num_outputs]
            return out_dict

In [208]:
# ---------------------------
# 5) Losses
# ---------------------------

def si_sdr(estimation, reference, eps=1e-8):
    # estimation & reference: (B, C, T) or (B, 1, T)
    # Returns SI-SDR in dB (per-example average over channels)
    B, C, T = estimation.shape
    est = estimation - estimation.mean(dim=-1, keepdim=True)
    ref = reference - reference.mean(dim=-1, keepdim=True)

    # projection of est on ref
    ref_energy = torch.sum(ref * ref, dim=-1, keepdim=True) + eps
    scalar = torch.sum(est * ref, dim=-1, keepdim=True) / ref_energy
    s_target = scalar * ref
    e_noise = est - s_target

    target_energy = torch.sum(s_target ** 2, dim=-1)
    noise_energy = torch.sum(e_noise ** 2, dim=-1) + eps
    si_sdr_val = 10 * torch.log10((target_energy + eps) / noise_energy)
    # average over channels then batch
    return si_sdr_val.mean()


def si_sdr_loss(estimation, reference):
    # negative SI-SDR (assuming you already have a function si_sdr)
    return -si_sdr(estimation, reference)

def reconstruction_loss(target_est, residual_est, mix):
    """
    Reconstruction term that aligns input (mix) with shorter model output.
    """
   
    return nn.functional.l1_loss(target_est + residual_est, mix)

def combined_loss(target_est, residual_est, reference, mix):
    """
    Final combined objective.
    """
    return nn.MSELoss()(target_est + residual_est, mix)



In [209]:
# ---------------------------
# 6) Metrics (MSE, MAE, SNR, SI-SDR)
# ---------------------------
def mse_metric(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    return float(np.mean((y_true - y_pred)**2))


def mae_metric(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    return float(np.mean(np.abs(y_true - y_pred)))


def snr_db_metric(y_true: np.ndarray, y_pred: np.ndarray, eps=1e-8) -> float:
    # SNR = 20*log10(rms_true / rms_error)
    rms_true = math.sqrt(np.mean(y_true**2) + eps)
    rms_err = math.sqrt(np.mean((y_true - y_pred)**2) + eps)
    return 20.0 * math.log10(rms_true / (rms_err + 1e-12))


def si_sdr_metric(y_true: np.ndarray, y_pred: np.ndarray, eps=1e-8) -> float:
    """
    Scale-Invariant SDR for single-channel signals
    y_true, y_pred: 1D numpy arrays (same length)
    """
    # remove mean
    s = y_true.astype(np.float64) - np.mean(y_true)
    s_hat = y_pred.astype(np.float64) - np.mean(y_pred)
    # projection
    s_target = (np.dot(s_hat, s) / (np.dot(s, s) + eps)) * s
    e_noise = s_hat - s_target
    num = np.sum(s_target**2)
    den = np.sum(e_noise**2) + eps
    return 10.0 * math.log10((num + eps) / den)


def sdr_metric(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-8) -> float:
    """
    Signal to Distortion Ratio (SDR)
    
    SDR = 10 * log10(||s_target||^2 / (||e_interf + e_noise + e_artif||^2))
    
    Args:
        y_true: Reference signal (target)
        y_pred: Estimated signal
        eps: Small value to avoid division by zero
    
    Returns:
        SDR in dB
    """
    # Remove mean (optional but often done)
    s = y_true.astype(np.float64) - np.mean(y_true)
    s_hat = y_pred.astype(np.float64) - np.mean(y_pred)
    
    # Projection for target component
    alpha = np.dot(s_hat, s) / (np.dot(s, s) + eps)
    s_target = alpha * s
    
    # Error (distortion)
    e_total = s_hat - s_target
    
    # Calculate energies
    target_energy = np.sum(s_target**2)
    distortion_energy = np.sum(e_total**2)
    
    return 10.0 * math.log10((target_energy + eps) / (distortion_energy + eps))


def sir_metric(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-8) -> float:
    """
    Signal to Interference Ratio (SIR)
    
    SIR = 10 * log10(||s_target||^2 / ||e_interf||^2)
    
    Args:
        y_true: Reference signal (target)
        y_pred: Estimated signal
        eps: Small value to avoid division by zero
    
    Returns:
        SIR in dB
    """
    # Remove mean
    s = y_true.astype(np.float64) - np.mean(y_true)
    s_hat = y_pred.astype(np.float64) - np.mean(y_pred)
    
    # Projection for target component
    alpha = np.dot(s_hat, s) / (np.dot(s, s) + eps)
    s_target = alpha * s
    
    # For SIR, we need the interference component
    # In single-channel case, interference is the part that correlates with other sources
    # For simplicity in single-channel, we use the residual after removing target
    e_interf = s_hat - s_target
    
    # Calculate energies
    target_energy = np.sum(s_target**2)
    interf_energy = np.sum(e_interf**2)
    
    return 10.0 * math.log10((target_energy + eps) / (interf_energy + eps))


def sar_metric(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-8) -> float:
    """
    Signal to Artifacts Ratio (SAR)
    
    SAR = 10 * log10(||s_target + e_interf||^2 / ||e_artif||^2)
    
    Args:
        y_true: Reference signal (target)
        y_pred: Estimated signal
        eps: Small value to avoid division by zero
    
    Returns:
        SAR in dB
    """
    # Remove mean
    s = y_true.astype(np.float64) - np.mean(y_true)
    s_hat = y_pred.astype(np.float64) - np.mean(y_pred)
    
    # Projection for target component
    alpha = np.dot(s_hat, s) / (np.dot(s, s) + eps)
    s_target = alpha * s
    
    # In single-channel case, artifacts are typically the residual
    # For SAR, we consider s_target + e_interf vs artifacts
    # In single-channel context, artifacts ≈ e_noise
    e_interf = s_hat - s_target
    
    # Signal + interference
    signal_plus_interf = s_target + e_interf
    
    # For single-channel, artifacts are typically modeled as the non-linear distortions
    # We approximate artifacts as the residual that cannot be explained by linear projection
    e_artif = s_hat - signal_plus_interf  # This would be zero in linear model
    
    # More practical approach for single-channel SAR
    signal_plus_interf_energy = np.sum(signal_plus_interf**2)
    artifacts_energy = np.sum(e_artif**2) if np.sum(e_artif**2) > eps else eps
    
    return 10.0 * math.log10((signal_plus_interf_energy + eps) / (artifacts_energy + eps))


def isr_metric(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-8) -> float:
    """
    Image to Spatial distortion Ratio (ISR) - Also known as Source Image to Spatial distortion Ratio
    
    ISR = 10 * log10(||s_target||^2 / ||e_spat||^2)
    
    Note: In single-channel audio, ISR is less commonly used and may be approximated.
    This implementation provides a reasonable approximation for single-channel case.
    
    Args:
        y_true: Reference signal (target)
        y_pred: Estimated signal
        eps: Small value to avoid division by zero
    
    Returns:
        ISR in dB
    """
    # Remove mean
    s = y_true.astype(np.float64) - np.mean(y_true)
    s_hat = y_pred.astype(np.float64) - np.mean(y_pred)
    
    # Projection for target component
    alpha = np.dot(s_hat, s) / (np.dot(s, s) + eps)
    s_target = alpha * s
    
    # For ISR in single-channel, we approximate spatial distortion
    # as the part that doesn't align with the target signal
    e_spat = s_hat - s_target
    
    # Calculate energies
    target_energy = np.sum(s_target**2)
    spat_energy = np.sum(e_spat**2)
    
    return 10.0 * math.log10((target_energy + eps) / (spat_energy + eps))


_METRIC_FUNCS = {
    "mse": mse_metric,
    "mae": mae_metric,
    "snr_db": snr_db_metric,
    "si_sdr": si_sdr_metric,
    "sdr": sdr_metric,
    "sir": sir_metric,
    "sar": sar_metric,
    "isr": isr_metric
}

In [210]:
# ---------------------------
# 6) train_epoch & validate_epoch
# ---------------------------
def train_epoch(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer,
                criterion: nn.Module, device: torch.device, clip_grad: Optional[float] = 5.0):
    model.train()
    running_loss = 0.0
    n_samples = 0
    pbar = tqdm(loader, desc="train", leave=False)
    for mix, target, residual in pbar:

        input_size = model.shapes["input_frames"]
        output_size = model.target_output_size

        # Original mix signal with sample size of output_size
        orig_mix = mix.clone()

        # Padding from left and right for mixed signal
        sample_diff = input_size - mix.shape[-1]
        if sample_diff > 0:
            # pad left side
            mix = nn.functional.pad(mix, (sample_diff // 2, 0))
            # pad right side
            mix = nn.functional.pad(mix, (0, sample_diff - sample_diff // 2))
        else:
            raise ValueError(f"Expected a input_size > mix.shape[-1], but got {input_size} < {mix.shape[-1]}")
        
        orig_mix = orig_mix.to(device)
        mix = mix.to(device)
        target = target.to(device)
        residual = residual.to(device)

        optimizer.zero_grad()

        out_dict = model(mix)
        target_est, residual_est = out_dict["target"], out_dict["residual"]

        # Padding or cropping the target signal
        sample_diff = output_size - target_est.shape[-1]
        if sample_diff > 0:
            # pad left side
            target_est = nn.functional.pad(target_est, (sample_diff // 2, 0))
            # pad right side
            target_est = nn.functional.pad(target_est, (0, sample_diff - sample_diff // 2))
        else:
            # crop extra samples
            target_est = target_est[..., :output_size]
        
        # Padding or cropping the residual signal
        sample_diff = output_size - residual_est.shape[-1]
        if sample_diff > 0:
            # pad left side
            residual_est = nn.functional.pad(residual_est, (sample_diff // 2, 0))
            # pad right side
            residual_est = nn.functional.pad(residual_est, (0, sample_diff - sample_diff // 2))
        else:
            # crop extra samples
            residual_est = residual_est[..., :output_size]

        loss = criterion(target_est, residual_est, target, orig_mix)
        loss.backward()
        if clip_grad:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
        optimizer.step()
        batch_size = mix.shape[0]
        running_loss += loss.item() * batch_size
        n_samples += batch_size
        pbar.set_postfix(loss=running_loss / n_samples)
    return running_loss / max(1, n_samples)


def validate_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                   device: torch.device, metric_names: List[str]):
    model.eval()
    running_loss = 0.0
    n_samples = 0
    metric_sums = {m: 0.0 for m in metric_names}
    with torch.no_grad():
        pbar = tqdm(loader, desc="val", leave=False)
        for mix, target, residual in pbar:
            input_size = model.shapes["input_frames"]
            output_size = model.target_output_size

            # Original mix signal with sample size of output_size
            orig_mix = mix.clone()

            # Padding from left and right for mixed signal
            sample_diff = input_size - mix.shape[-1]
            if sample_diff > 0:
                # pad left side
                mix = nn.functional.pad(mix, (sample_diff // 2, 0))
                # pad right side
                mix = nn.functional.pad(mix, (0, sample_diff - sample_diff // 2))
            else:
                raise ValueError(f"Expected a input_size > mix.shape[-1], but got {input_size} < {mix.shape[-1]}")
            
            orig_mix = orig_mix.to(device)
            mix = mix.to(device)
            target = target.to(device)
            residual = residual.to(device)

            out_dict = model(mix)
            target_est, residual_est = out_dict["target"], out_dict["residual"]

            # Padding or cropping the target signal
            sample_diff = output_size - target_est.shape[-1]
            if sample_diff > 0:
                # pad left side
                target_est = nn.functional.pad(target_est, (sample_diff // 2, 0))
                # pad right side
                target_est = nn.functional.pad(target_est, (0, sample_diff - sample_diff // 2))
            else:
                # crop extra samples
                target_est = target_est[..., :output_size]
            
            # Padding or cropping the residual signal
            sample_diff = output_size - residual_est.shape[-1]
            if sample_diff > 0:
                # pad left side
                residual_est = nn.functional.pad(residual_est, (sample_diff // 2, 0))
                # pad right side
                residual_est = nn.functional.pad(residual_est, (0, sample_diff - sample_diff // 2))
            else:
                # crop extra samples
                residual_est = residual_est[..., :output_size]

            loss = criterion(target_est, residual_est, target, orig_mix)
            batch_size = mix.shape[0]
            running_loss += loss.item() * batch_size
            n_samples += batch_size
            # compute metrics sample-wise in numpy
            est_np = target_est.detach().cpu().numpy()
            clean_np = target.detach().cpu().numpy()
            for b in range(batch_size):
                y_true = clean_np[b, 0, :]
                y_pred = est_np[b, 0, :]
                for m in metric_names:
                    metric_value = _METRIC_FUNCS[m](y_true, y_pred)
                    metric_sums[m] += metric_value
            pbar.set_postfix(loss=running_loss / n_samples)
    avg_loss = running_loss / max(1, n_samples)
    avg_metrics = {m: (metric_sums[m] / n_samples) for m in metric_names}
    return avg_loss, avg_metrics

In [211]:
# ---------------------------
# 7) plotting helpers
# ---------------------------
def plot_history(history: dict, out_dir: str):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # loss curve
    plt.figure()
    plt.plot(history.get("train_loss", []), label="train_loss")
    plt.plot(history.get("val_loss", []), label="val_loss")
    plt.xlabel("epoch"); plt.ylabel("loss"); plt.legend(); plt.grid(True)
    f1 = out_dir / "loss_curve.png"
    plt.savefig(f1); plt.close()
    # metrics
    for m in history.get("metrics", {}).keys():
        plt.figure()
        plt.plot(history["metrics"][m], label=m)
        plt.xlabel("epoch"); plt.ylabel(m); plt.legend(); plt.grid(True)
        plt.savefig(out_dir / f"metric_{m}.png")
        plt.close()
    return out_dir

In [212]:
# ---------------------------
# 8) High-level train & evaluate pipeline
# ---------------------------
def train_and_evaluate(
    processed_root: str,
    dataset_name: str,
    train_config_path: str,
    metrics_config_path: str,
    output_dir: str = "model_output",
    mlflow_enabled: bool = False,
    seed: int = 42
) -> dict:
    """
    High-level entrypoint for training and evaluation. This is Airflow callable.
    Returns a summary dict with final metrics and paths.
    """

    # load configs
    train_cfg, metrics_cfg = load_configs(train_config_path, metrics_config_path)
    # set seeds
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # create train/test splits if needed
    create_train_test_splits(processed_root, dataset_name, train_frac=float(train_cfg.get("train_frac", 0.9)), seed=seed)

    # dataset & dataloaders
    train_dataset = WaveUNetDataset(processed_root, dataset_name, split="train")
    val_dataset = WaveUNetDataset(processed_root, dataset_name, split="test")
    train_loader = DataLoader(train_dataset, batch_size=train_cfg["batch_size"], shuffle=True, num_workers=int(train_cfg.get("num_workers", 4)), pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=train_cfg["batch_size"], shuffle=False, num_workers=int(train_cfg.get("num_workers", 4)), pin_memory=True)

    # model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    kernel_size = int(train_cfg.get("kernel_size", 15))

    levels = int(train_cfg.get("levels", 6))
    features = int(train_cfg.get("features", 32))
    feature_growth = train_cfg.get("feature_growth", "add")

    depth = int(train_cfg.get("depth", 1))
    strides = int(train_cfg.get("stride", 4))


    num_features = [features*i for i in range(1, levels+1)] if feature_growth == "add" else \
        [features*2**i for i in range(0, levels)]

    model = Waveunet(
        num_inputs=1,
        num_channels=num_features,
        num_outputs=1,
        instruments=['target', 'residual'],
        kernel_size=kernel_size,
        target_output_size=66150,
        conv_type="gn",
        res="fixed",
        separate=False,
        depth=depth,
        strides=strides
    )

    model.to(device)

    # criterion, optimizer, scheduler
    criterion = combined_loss
    optimizer = torch.optim.Adam(model.parameters(), lr=float(train_cfg.get("lr", 1e-4)))
    scheduler = None
    if train_cfg.get("lr_step"):
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=train_cfg["lr_step"], gamma=float(train_cfg.get("lr_gamma", 0.5)))

    # metrics and monitoring
    metric_names = metrics_cfg.get("metrics", ["mse"])
    monitor = train_cfg.get("monitor_metric", "val_loss")  # e.g. "val_loss" or "si_sdr"
    monitor_mode = train_cfg.get("monitor_mode", "min")   # "min" or "max"
    best_score = math.inf if monitor_mode == "min" else -math.inf
    best_ckpt_path = None

    # output dirs
    out_root = Path(output_dir)
    ckpt_dir = out_root / "checkpoints"; ckpt_dir.mkdir(parents=True, exist_ok=True)
    plot_dir = out_root / "plots"; plot_dir.mkdir(parents=True, exist_ok=True)
    history = {"train_loss": [], "val_loss": [], "metrics": {m: [] for m in metric_names}}

    # MLflow start run
    mlflow_run = None
    if mlflow_enabled:
        if not _mlflow_available:
            logger.warning("MLflow requested but not available; continuing without MLflow.")
            mlflow_enabled = False
        else:
            mlflow.start_run()
            mlflow_run = mlflow.active_run()
            mlflow.log_params(train_cfg)
            mlflow.log_params({"metrics_cfg": metrics_cfg})

    num_epochs = int(train_cfg.get("epochs", 50))
    for epoch in range(num_epochs):
        logger.info(f"Epoch {epoch+1}/{num_epochs}")
        t0 = time.time()
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device, clip_grad=float(train_cfg.get("clip_grad", 5.0)))
        val_loss, val_metrics = validate_epoch(model, val_loader, criterion, device, metric_names)
        logger.info(f"Epoch {epoch+1} train_loss={train_loss:.6f} val_loss={val_loss:.6f} metrics={val_metrics} time={(time.time()-t0):.1f}s")

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        for m in metric_names:
            history["metrics"][m].append(val_metrics.get(m, None))

        # scheduler step
        if scheduler:
            scheduler.step()

        # monitoring and checkpointing
        # If monitor_metric is "val_loss" use that; if it's in val_metrics use that.
        if monitor == "val_loss":
            current = val_loss
        else:
            current = val_metrics.get(monitor)
            if current is None:
                logger.warning(f"Monitor metric {monitor} not found in val metrics; defaulting to val_loss.")
                current = val_loss

        is_better = (current < best_score) if monitor_mode == "min" else (current > best_score)
        if is_better:
            best_score = current
            best_ckpt_path = ckpt_dir / f"best_{monitor}_{epoch+1}.pt"
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_cfg": train_cfg,
                "metrics_cfg": metrics_cfg,
                "best_score": best_score
            }, str(best_ckpt_path))
            logger.info(f"Saved new best checkpoint: {best_ckpt_path}")
            if mlflow_enabled:
                mlflow.log_metric(f"best_{monitor}", float(best_score), step=epoch)

        # log epoch metrics to mlflow
        if mlflow_enabled:
            mlflow.log_metric("train_loss", float(train_loss), step=epoch)
            mlflow.log_metric("val_loss", float(val_loss), step=epoch)
            for m, v in val_metrics.items():
                mlflow.log_metric(m, float(v), step=epoch)

    # after training: load best model and evaluate on test set (here val set is test)
    if best_ckpt_path is not None:
        ckpt = torch.load(str(best_ckpt_path), map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])
        logger.info(f"Loaded best model from {best_ckpt_path} for final evaluation.")
    else:
        logger.warning("No checkpoint saved during training; using last model for evaluation.")

    # Final evaluation (on val/test loader)
    final_loss, final_metrics = validate_epoch(model, val_loader, criterion, device, metric_names)
    logger.info(f"Final evaluation: loss={final_loss:.6f}, metrics={final_metrics}")

    # Save history & plots
    hist_json = out_root / "history.json"
    hist_json.write_text(json.dumps(history, indent=2))
    plot_history(history, plot_dir)

    # MLflow final logging & artifacts
    if mlflow_enabled:
        mlflow.log_metric("final_val_loss", float(final_loss))
        for m, v in final_metrics.items():
            mlflow.log_metric(f"final_{m}", float(v))
        # log artifacts
        mlflow.log_artifacts(str(plot_dir), artifact_path="plots")
        mlflow.log_artifact(str(hist_json), artifact_path="history")
        if best_ckpt_path:
            mlflow.log_artifact(str(best_ckpt_path), artifact_path="checkpoints")
        mlflow.end_run()

    summary = {
        "best_checkpoint": str(best_ckpt_path) if best_ckpt_path else None,
        "final_val_loss": float(final_loss),
        "final_metrics": final_metrics,
        "history_path": str(hist_json),
        "plot_dir": str(plot_dir)
    }
    return summary

In [213]:
train_and_evaluate(
    processed_root="guitar_dataset/processed",
    dataset_name="dataset2-v2",
    train_config_path="train_config.yaml",
    metrics_config_path="metrics_config.yaml",
    mlflow_enabled=False
)

INFO:model_pipeline:Epoch 1/10


Train/test subfolders already exist; skipping split creation.
Using valid convolutions with 68237 inputs and 66165 outputs


INFO:model_pipeline:Epoch 1 train_loss=0.080157 val_loss=0.002062 metrics={'mse': 0.00780771931167692, 'mae': 0.0666709529235959, 'snr_db': 1.7844986663677989, 'si_sdr': 10.788614429431686} time=4.5s
INFO:model_pipeline:Saved new best checkpoint: model_output/checkpoints/best_snr_db_1.pt
INFO:model_pipeline:Epoch 2/10
INFO:model_pipeline:Epoch 2 train_loss=0.001079 val_loss=0.001101 metrics={'mse': 0.009318339102901519, 'mae': 0.07712731659412383, 'snr_db': 0.7483060713206292, 'si_sdr': 11.133012477296319} time=4.9s
INFO:model_pipeline:Saved new best checkpoint: model_output/checkpoints/best_snr_db_2.pt
INFO:model_pipeline:Epoch 3/10
INFO:model_pipeline:Epoch 3 train_loss=0.000509 val_loss=0.000818 metrics={'mse': 0.009796246467158199, 'mae': 0.07916171066462993, 'snr_db': 0.5407234038804357, 'si_sdr': 10.50453961018179} time=4.7s
INFO:model_pipeline:Saved new best checkpoint: model_output/checkpoints/best_snr_db_3.pt
INFO:model_pipeline:Epoch 4/10
INFO:model_pipeline:Epoch 4 train_los

{'best_checkpoint': 'model_output/checkpoints/best_snr_db_3.pt',
 'final_val_loss': 0.0008176200906746089,
 'final_metrics': {'mse': 0.009796246467158199,
  'mae': 0.07916171066462993,
  'snr_db': 0.5407234038804357,
  'si_sdr': 10.50453961018179},
 'history_path': 'model_output/history.json',
 'plot_dir': 'model_output/plots'}

## Inference

In [220]:
# inference_pipeline.py
import os
import json
import time
from pathlib import Path
from typing import List, Tuple, Optional, Dict

import numpy as np
import soundfile as sf
import librosa
from tqdm.auto import tqdm

import torch

In [221]:
# ------------------------
# Helpers: chunking & IO
# ------------------------
def split_audio_into_chunks(
    audio: np.ndarray,
    sample_rate: int,
    chunk_duration: float
) -> Tuple[List[np.ndarray], int]:
    """
    Split 1D numpy audio into non-overlapping chunks of chunk_duration (seconds).
    Pads the last chunk with zeros if needed.

    Returns:
      chunks: list of numpy arrays (each length == chunk_len)
      orig_len: original audio length in samples (for trimming after reconstruction)
    """
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)  # to mono

    orig_len = len(audio)
    chunk_len = int(round(chunk_duration * sample_rate))
    if chunk_len <= 0:
        raise ValueError("chunk_duration too small for given sample_rate")

    n_chunks = (orig_len + chunk_len - 1) // chunk_len
    chunks = []
    for i in range(n_chunks):
        s = i * chunk_len
        e = s + chunk_len
        if e <= orig_len:
            chunks.append(audio[s:e].astype(np.float32))
        else:
            # pad right
            pad = e - orig_len
            chunk = np.pad(audio[s:orig_len], (0, pad), mode='constant').astype(np.float32)
            chunks.append(chunk)
    return chunks, orig_len


def reconstruct_from_chunks(chunks_preds: List[np.ndarray], orig_len: int) -> np.ndarray:
    """
    Concatenate list of 1D arrays (all same length) and trim to orig_len samples.
    """
    if not chunks_preds:
        return np.zeros(orig_len, dtype=np.float32)
    out = np.concatenate(chunks_preds, axis=0)
    if len(out) >= orig_len:
        return out[:orig_len].astype(np.float32)
    # if concatenated length is shorter (shouldn't happen), pad
    pad = orig_len - len(out)
    return np.pad(out, (0, pad), mode='constant').astype(np.float32)

In [222]:
# ------------------------
# Model loading
# ------------------------
def load_best_model(checkpoint_path: str, device: Optional[torch.device] = None):
    """
    Loads a checkpoint saved by train_and_evaluate(...) and returns a model in eval mode.
    The checkpoint is expected to contain 'model_state_dict' and optionally 'train_cfg' dict.
    If train_cfg exists, it is used to re-create the WaveUNet1D with the same architecture.
    """
    device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    ckpt = torch.load(checkpoint_path, map_location=device)

    # try to infer model config
    train_cfg = ckpt.get("train_cfg") or ckpt.get("args") or {}
    kernel_size = int(train_cfg.get("kernel_size", 15))

    levels = int(train_cfg.get("levels", 6))
    features = int(train_cfg.get("features", 32))
    feature_growth = train_cfg.get("feature_growth", "add")

    depth = int(train_cfg.get("depth", 1))
    strides = int(train_cfg.get("stride", 4))


    num_features = [features*i for i in range(1, levels+1)] if feature_growth == "add" else \
        [features*2**i for i in range(0, levels)]

    model = Waveunet(
        num_inputs=1,
        num_channels=num_features,
        num_outputs=1,
        instruments=['target', 'residual'],
        kernel_size=kernel_size,
        target_output_size=66150,
        conv_type="gn",
        res="fixed",
        separate=False,
        depth=depth,
        strides=strides
    )
    
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device)
    model.eval()
    return model, train_cfg

In [223]:
# ------------------------
# Core inference for a single recording
# ------------------------
def infer_single_recording(
    input_path: str,
    model: torch.nn.Module,
    sample_rate: int,
    chunk_duration: float,
    batch_size: int,
    device: torch.device,
    output_dir: str,
    checkpoint_path: str,
    dtype=np.float32
) -> Dict:
    """
    Process one recording: split into chunks, run model in batches, reconstruct, save output and metadata.

    Returns metadata dict with paths and basic stats.
    """
    input_path = Path(input_path)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # load and ensure mono
    audio, sr = librosa.load(str(input_path), sr=sample_rate, mono=True)
    chunks, orig_len = split_audio_into_chunks(audio, sample_rate, chunk_duration)

    preds_target = []
    preds_residual = []
    model_device = device

    with torch.no_grad():
        for i in range(0, len(chunks), batch_size):
            batch_chunks = chunks[i:i + batch_size]
            # build tensor shape (B,1,L)
            batch_arr = np.stack(batch_chunks, axis=0)  # (B, L)
            batch_tensor = torch.from_numpy(batch_arr).float().unsqueeze(1).to(model_device)  # (B,1,L)

            input_size = model.shapes["input_frames"]

            # Padding from left and right for mixed signal
            sample_diff = input_size - batch_tensor.shape[-1]
            if sample_diff > 0:
                # pad left side
                batch_tensor = nn.functional.pad(batch_tensor, (sample_diff // 2, 0))
                # pad right side
                batch_tensor = nn.functional.pad(batch_tensor, (0, sample_diff - sample_diff // 2))
            else:
                raise ValueError(f"Expected a input_size > mix.shape[-1], but got {input_size} < {mix.shape[-1]}")

            out_dict = model(batch_tensor)  # expected (B,1,L) thanks to model alignment
            target_est, residual_est = out_dict["target"], out_dict["residual"]
            target_tensor = target_est.detach().cpu().numpy()  # (B,1,L)
            residual_tensor = residual_est.detach().cpu().numpy()
            
            out_arrs_target = [o[0].astype(dtype) for o in target_tensor]  # list of 1D arrays
            out_arrs_residual = [o[0].astype(dtype) for o in residual_tensor]
            preds_target.extend(out_arrs_target)
            preds_residual.extend(out_arrs_residual)

    # reconstruct
    reconstructed_target = reconstruct_from_chunks(preds_target, orig_len)
    reconstructed_residual = reconstruct_from_chunks(preds_residual, orig_len)

    # save
    out_basename = input_path.stem + "_target_recon.wav"
    out_path = output_dir / out_basename
    sf.write(str(out_path), reconstructed_target, sample_rate)

    out_basename = input_path.stem + "_resid_recon.wav"
    out_path = output_dir / out_basename
    sf.write(str(out_path), reconstructed_residual, sample_rate)

    # save metadata
    meta = {
        "input_path": str(input_path),
        "output_path": str(out_path),
        "checkpoint": str(checkpoint_path),
        "sample_rate": int(sample_rate),
        "chunk_duration": float(chunk_duration),
        "n_chunks": int(len(chunks)),
        "orig_len_samples": int(orig_len),
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    }
    meta_path = output_dir / (input_path.stem + "_meta.json")
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    return meta

In [224]:
# ------------------------
# Top-level Airflow-friendly entrypoint
# ------------------------
def run_inference_pipeline(
    input_root: str,
    model_checkpoint: str,
    output_root: str,
    chunk_duration: float = 3.0,
    sample_rate: int = 22050,
    batch_size: int = 8,
    device_str: Optional[str] = None,
    glob_patterns: Optional[List[str]] = None,
    max_files: Optional[int] = None
) -> Dict:
    """
    Main function to be called by Airflow PythonOperator.

    Parameters:
      - input_root: folder with recordings (will search recursively)
      - model_checkpoint: path to checkpoint (best model)
      - output_root: where to save reconstructed recordings and metadata
      - chunk_duration: seconds per chunk to pass to model
      - sample_rate: sampling rate for loading/saving
      - batch_size: inference batch size
      - device_str: 'cuda' or 'cpu' (if None, automatically picked)
      - glob_patterns: list of glob patterns to find files (default ['**/*.wav'])
      - max_files: optional cap on number of recordings to process

    Returns:
      summary dict with list of processed files and metadata paths.
    """
    device = torch.device(device_str if device_str is not None else ("cuda" if torch.cuda.is_available() else "cpu"))
    model, train_cfg = load_best_model(model_checkpoint, device=device)

    input_root = Path(input_root)
    out_root = Path(output_root)
    out_root.mkdir(parents=True, exist_ok=True)
    glob_patterns = glob_patterns or ["**/*.wav"]

    # collect files
    files = []
    for pat in glob_patterns:
        files.extend(sorted(input_root.glob(pat)))
    if not files:
        raise RuntimeError(f"No audio files found in {input_root} with patterns {glob_patterns}")

    if max_files is not None:
        files = files[:max_files]

    processed = []
    pbar = tqdm(files, desc="Inference files", dynamic_ncols=True)
    for f in pbar:
        try:
            meta = infer_single_recording(
                input_path=str(f),
                model=model,
                sample_rate=sample_rate,
                chunk_duration=chunk_duration,
                batch_size=batch_size,
                device=device,
                output_dir=str(out_root),
                checkpoint_path=model_checkpoint
            )
            processed.append(meta)
        except Exception as e:
            # don't crash whole DAG on single file; instead record error
            processed.append({"input_path": str(f), "error": str(e)})
            # you may also choose to re-raise if you want failure semantics in Airflow
            # raise

    summary = {
        "model_checkpoint": str(model_checkpoint),
        "n_files_requested": len(files),
        "n_processed": len(processed),
        "output_root": str(out_root),
        "processed": processed
    }
    # save summary
    with open(out_root / "inference_summary.json", "w", encoding="utf-8") as fh:
        json.dump(summary, fh, indent=2)
    return summary

In [225]:
run_inference_pipeline(
    input_root="guitar_dataset/processed/quality_test/orig",
    model_checkpoint="model_output/checkpoints/best_snr_db_5.pt",
    output_root="guitar_dataset/processed/quality_test/recon"
)

Using valid convolutions with 99661 inputs and 66229 outputs


Inference files:   0%|          | 0/1 [00:00<?, ?it/s]

{'model_checkpoint': 'model_output/checkpoints/best_snr_db_5.pt',
 'n_files_requested': 1,
 'n_processed': 1,
 'output_root': 'guitar_dataset/processed/quality_test/recon',
 'processed': [{'input_path': 'guitar_dataset/processed/quality_test/orig/quality_test_1.wav',
   'output_path': 'guitar_dataset/processed/quality_test/recon/quality_test_1_resid_recon.wav',
   'checkpoint': 'model_output/checkpoints/best_snr_db_5.pt',
   'sample_rate': 22050,
   'chunk_duration': 3.0,
   'n_chunks': 6,
   'orig_len_samples': 377732,
   'created_at': '2025-10-06T14:06:37Z'}]}

In [None]:
# ------------------------
# Example Airflow PythonOperator usage:
# ------------------------
# from airflow import DAG
# from airflow.operators.python import PythonOperator
# from datetime import datetime
#
# with DAG("waveunet_inference", start_date=datetime(2025,10,4), schedule=None, catchup=False) as dag:
#     infer_task = PythonOperator(
#         task_id="run_inference",
#         python_callable=run_inference_pipeline,
#         op_kwargs={
#             "input_root": "/mnt/data/to_process",
#             "model_checkpoint": "/path/to/best_checkpoint.pt",
#             "output_root": "/mnt/data/inference_out",
#             "chunk_duration": 3.0,
#             "sample_rate": 22050,
#             "batch_size": 8,
#             "device_str": "cuda",
#             "glob_patterns": ["**/*.wav"],
#             "max_files": 200
#         },
#     )
