In [6]:
!pip install natsort




[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [1]:
import re
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from natsort import natsorted  # pip install natsort

import os

# Dataset root. Set the UROCH_DATASET_ROOT environment variable to point elsewhere
# instead of editing this machine-specific default.
ROOT = Path(os.environ.get("UROCH_DATASET_ROOT", r"C:\Users\13523\Desktop\URochDataset_trimmed"))

# Piece metadata is derived solely from the folder name, which is treated as authoritative.
# Expected layout: "<index><space/_/-><piece>_<instruments>", e.g. "01_Jupiter_vn_vc".
FOLDER_RE = re.compile(r"^(?P<idx>\d+)[ _-]+(?P<piece>[A-Za-z0-9]+)_(?P<instr>.+)$")

def parse_folder_meta(folder: Path):
    """Extract piece metadata from a piece folder's name.

    Returns a dict with:
      piece_index  -- the leading number, zero-padded to 2 digits (None if unparsable)
      piece        -- the piece name, title-cased (the raw folder name if unparsable)
      folder_instr -- the trailing instrument string, lower-cased (None if unparsable)
    """
    match = FOLDER_RE.match(folder.name)
    if match is None:
        # Unrecognised layout: keep the raw folder name as the piece label.
        return {"piece_index": None, "piece": folder.name, "folder_instr": None}
    number, title, instruments = match.group("idx", "piece", "instr")
    return {
        "piece_index": number.zfill(2),
        "piece": title.title(),
        "folder_instr": instruments.lower(),
    }

def list_clean_files(folder: Path):
    """Return the regular files directly inside `folder`, naturally sorted by lower-cased name.

    macOS AppleDouble resource-fork files (names starting with "._") are skipped.
    """
    kept = [
        entry
        for entry in folder.iterdir()
        if entry.is_file() and not entry.name.startswith("._")
    ]
    return natsorted(kept, key=lambda entry: entry.name.lower())

def block_indices(files):
    """Group file positions by category, using only filename-prefix checks.

    Returns a dict mapping each category ("AuMix", "AuSep", "F0s", "Notes",
    "Score", "Video") to the ascending list of indices into `files` whose name
    starts with that category's prefix.
    """
    prefixes = {
        "AuMix": "AuMix_",
        "AuSep": "AuSep_",
        "F0s": "F0s_",
        "Notes": "Notes_",
        "Score": "Sco_",
        "Video": "Vid_",
    }
    groups = {category: [] for category in prefixes}
    for position, f in enumerate(files):
        for category, prefix in prefixes.items():
            if f.name.startswith(prefix):
                groups[category].append(position)
    return groups

def index_piece_by_position(folder: Path):
    """Build manifest rows for one piece folder, numbering tracks purely by file order.

    Files are grouped by filename prefix (block_indices). AuSep, F0s and Notes files
    are numbered 1..n in the order list_clean_files returns them, with instrument left
    as None (no filename parsing). AuMix, Score and Video rows have no track and carry
    the instrument string parsed from the folder name. A warning is printed when the
    F0s or Notes count is neither 0 nor equal to the number of AuSep files.

    Returns:
        list of dicts: folder metadata plus category, track, instrument, ext, path, folder.
    """
    meta = parse_folder_meta(folder)
    files = list_clean_files(folder)
    groups = block_indices(files)
    folder_instr = meta["folder_instr"]

    def make_row(category, track, instrument, pos):
        f = files[pos]
        return {**meta, "category": category, "track": track,
                "instrument": instrument, "ext": f.suffix.lower(),
                "path": str(f.resolve()), "folder": folder.name}

    rows = [make_row("AuMix", None, folder_instr, pos) for pos in groups["AuMix"]]
    # Per-track categories: track number is simply the 1-based position within the group.
    for category in ("AuSep", "F0s", "Notes"):
        rows += [make_row(category, track, None, pos)
                 for track, pos in enumerate(groups[category], start=1)]
    for category in ("Score", "Video"):
        rows += [make_row(category, None, folder_instr, pos) for pos in groups[category]]

    # Sanity check: per-track categories should be absent or match the AuSep count.
    n_tracks = len(groups["AuSep"])
    n_f0s, n_notes = len(groups["F0s"]), len(groups["Notes"])
    if n_tracks and (n_f0s not in (0, n_tracks) or n_notes not in (0, n_tracks)):
        print(f"[WARN] {folder.name}: tracks inferred={n_tracks} "
              f"but F0s={n_f0s}, Notes={n_notes}")
    return rows

def build_manifest(root: Path):
    """Index every piece folder directly under `root` into one manifest DataFrame.

    Folders are visited in natural, case-insensitive name order. A non-empty result is
    sorted by piece_index, category, track (missing values last) and re-indexed from 0;
    an empty DataFrame is returned unchanged.
    """
    piece_dirs = natsorted(
        [entry for entry in root.iterdir() if entry.is_dir()],
        key=lambda entry: entry.name.lower(),
    )
    records = [row for piece_dir in piece_dirs for row in index_piece_by_position(piece_dir)]
    manifest = pd.DataFrame(records)
    if manifest.empty:
        return manifest
    ordered = manifest.sort_values(["piece_index", "category", "track"], na_position="last")
    return ordered.reset_index(drop=True)

# Build the manifest and save it alongside the dataset.
df = build_manifest(ROOT)
out_csv = ROOT / "_manifest_index_only.csv"
df.to_csv(out_csv, index=False)
print("[saved]", out_csv)
df.head(20)  # last expression -> rendered as a rich table by Jupyter


[saved] C:\Users\13523\Desktop\URochDataset_trimmed\_manifest_index_only.csv
   piece_index    piece folder_instr category  track instrument   ext  \
0           01  Jupiter        vn_vc    AuMix    NaN      vn_vc  .wav   
1           01  Jupiter        vn_vc    AuSep    1.0       None  .wav   
2           01  Jupiter        vn_vc    AuSep    2.0       None  .wav   
3           01  Jupiter        vn_vc      F0s    1.0       None  .txt   
4           01  Jupiter        vn_vc      F0s    2.0       None  .txt   
5           01  Jupiter        vn_vc    Notes    1.0       None  .txt   
6           01  Jupiter        vn_vc    Notes    2.0       None  .txt   
7           01  Jupiter        vn_vc    Score    NaN      vn_vc  .mid   
8           01  Jupiter        vn_vc    Score    NaN      vn_vc  .pdf   
9           01  Jupiter        vn_vc    Video    NaN      vn_vc  .mp4   
10          02   Sonata        vn_vn    AuMix    NaN      vn_vn  .wav   
11          02   Sonata        vn_vn    AuSep  

In [None]:
# MIDI -> audio rendering (per-instrument stems + normalized mixture)
import subprocess
import tempfile
from pathlib import Path

import numpy as np
import pretty_midi
import soundfile as sf

def render_midi_with_fluidsynth(midi_path, output_dir, soundfont_path, sr=32000, gain=0.5):
    """Render each MIDI instrument to its own WAV with FluidSynth, then write a peak-normalized mixture.

    Args:
        midi_path: Source .mid/.midi file.
        output_dir: Directory for stem WAVs and the mixture (created if missing).
        soundfont_path: SoundFont passed to the `fluidsynth` command-line tool.
        sr: Sample rate requested from FluidSynth (-r) and used for the mixture file.
        gain: FluidSynth master gain (-g).

    Returns:
        (mix_path, stems): path of the mixture WAV, and a dict mapping a per-instrument
        key to its mono float32 audio. Instruments sharing a name get an "_inst{i}"
        suffix on their key so none is dropped.

    Raises:
        subprocess.CalledProcessError: if fluidsynth exits with a non-zero status.
        FileNotFoundError: if the fluidsynth executable cannot be found.
    """
    midi_path = Path(midi_path)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    midi = pretty_midi.PrettyMIDI(str(midi_path))

    stems = {}
    for i, inst in enumerate(midi.instruments):
        stem_name = (inst.name or f"program{inst.program}").replace(" ", "_").replace("/", "_")
        out_wav = output_dir / f"{midi_path.stem}_inst{i}_{stem_name}.wav"  # output name

        # Single-instrument MIDI in a temp file. Close our handle before writing so the
        # file can be reopened by name on every platform, and always remove it afterwards.
        single = pretty_midi.PrettyMIDI()
        single.instruments.append(inst)
        with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as tmp:
            tmp_path = Path(tmp.name)
        try:
            single.write(str(tmp_path))
            subprocess.run(
                ["fluidsynth", "-ni", "-F", str(out_wav), "-r", str(sr), "-g", str(gain),
                 str(soundfont_path), str(tmp_path)],
                check=True, capture_output=True,
            )
        finally:
            tmp_path.unlink(missing_ok=True)

        audio, _ = sf.read(str(out_wav), dtype="float32")
        if audio.ndim == 2:
            # Downmix multichannel renders so every stem is 1-D like the mixture buffer.
            audio = audio.mean(axis=1)
        key = stem_name if stem_name not in stems else f"{stem_name}_inst{i}"
        stems[key] = audio

    # Mix: zero-pad shorter stems implicitly, then peak-normalize (guarding silence/empty MIDI).
    longest = max((len(a) for a in stems.values()), default=0)
    mix = np.zeros(longest, dtype=np.float32)
    for a in stems.values():
        mix[:len(a)] += a
    peak = float(np.max(np.abs(mix))) if mix.size else 0.0
    mix = mix / max(1e-6, peak)
    mix_path = output_dir / f"{midi_path.stem}_mixture.wav"
    sf.write(str(mix_path), mix, sr)
    return mix_path, stems

def render_midi_with_pretty_midi(midi_path, output_dir, sr=32000):
    """Synthesize each MIDI instrument with pretty_midi's built-in synth and write stems + a mixture.

    Args:
        midi_path: Source .mid/.midi file.
        output_dir: Directory for stem WAVs and the mixture (created if missing).
        sr: Synthesis/output sample rate.

    Returns:
        (mix_path, stems): path of the mixture WAV, and a dict mapping a per-instrument
        key to its float32 audio. Instruments sharing a name get an "_inst{i}" suffix
        on their key so none is dropped from the mixture.
    """
    midi_path = Path(midi_path)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    midi = pretty_midi.PrettyMIDI(str(midi_path))

    stems = {}
    for i, inst in enumerate(midi.instruments):
        audio = inst.synthesize(fs=sr).astype(np.float32)
        # "/" in a track name would otherwise be treated as a sub-directory in the filename.
        name = (inst.name or f"program{inst.program}").replace(" ", "_").replace("/", "_")
        sf.write(str(output_dir / f"{midi_path.stem}_inst{i}_{name}.wav"), audio, sr)
        key = name if name not in stems else f"{name}_inst{i}"
        stems[key] = audio

    longest = max((len(a) for a in stems.values()), default=0)
    mix = np.zeros(longest, dtype=np.float32)
    for a in stems.values():
        mix[:len(a)] += a
    peak = float(np.max(np.abs(mix))) if mix.size else 0.0
    mix = mix / max(1e-6, peak)
    mix_path = output_dir / f"{midi_path.stem}_mixture.wav"
    sf.write(str(mix_path), mix, sr)
    return mix_path, stems

def batch_render_midis(midi_dir, output_dir, method='fluidsynth', soundfont_path='/usr/share/sounds/sf2/FluidR3_GM.sf2', sr=32000, gain=0.5):
    """Render every MIDI file directly inside `midi_dir` into `output_dir`.

    Args:
        midi_dir: Directory scanned (non-recursively) for .mid/.midi files, any letter case.
        output_dir: Destination for stems and mixtures.
        method: "fluidsynth" uses render_midi_with_fluidsynth; any other value uses
            render_midi_with_pretty_midi.
        soundfont_path: SoundFont for the FluidSynth method.
        sr: Output sample rate.
        gain: FluidSynth master gain (ignored by the pretty_midi method).

    Returns:
        List of (midi_path, exception) for files that failed; empty if all succeeded.
        Failures are also printed, and the batch keeps going after one.
    """
    midi_dir = Path(midi_dir)
    # Case-insensitive suffix check: glob("*.mid") misses "*.MID" on case-sensitive filesystems.
    files = sorted(
        p for p in midi_dir.glob("*")
        if p.is_file() and p.suffix.lower() in {".mid", ".midi"}
    )
    print(f"Found {len(files)} MIDI files")
    failures = []
    for m in files:
        print(f"Rendering {m.name}…")
        try:
            if method == "fluidsynth":
                render_midi_with_fluidsynth(m, output_dir, soundfont_path, sr=sr, gain=gain)
            else:
                render_midi_with_pretty_midi(m, output_dir, sr=sr)
        except Exception as e:
            # Best-effort batch: report and continue, but hand the failure back to the caller.
            print("  ✗ Failed:", e)
            failures.append((m, e))
    return failures
In [None]:
# Audio preprocessing
import librosa
import numpy as np
import pyloudnorm as pyln
import torch
import torchaudio

eps = 1e-10

def ensure_mono(wav: torch.Tensor) -> torch.Tensor:
    """Return a single-channel (1, n) waveform.

    A 1-D input gains a leading channel axis. For an input with a leading channel
    axis, several channels are averaged and a single channel is returned as-is
    (same object).
    """
    if wav.dim() != 1:
        return wav if wav.shape[0] <= 1 else wav.mean(dim=0, keepdim=True)
    return wav.unsqueeze(0)

def loudness_normalize_lufs(wav: torch.Tensor, sr: int, target_lufs=-23.0) -> torch.Tensor:
    """Scale a mono (1, n) waveform to `target_lufs` integrated loudness, then clip to [-1, 1].

    The signal is left unscaled when its loudness cannot be used. That covers the
    ValueError pyloudnorm raises (e.g. for clips shorter than the meter's block) and
    non-finite measurements (e.g. silence), which would otherwise yield inf/NaN samples.

    Returns:
        float32 CPU tensor of shape (1, n).
    """
    x = wav.squeeze(0).cpu().numpy().astype(np.float32)
    meter = pyln.Meter(sr)
    try:
        lufs = meter.integrated_loudness(x)
        y = pyln.normalize.loudness(x, lufs, target_lufs) if np.isfinite(lufs) else x
    except ValueError:
        y = x
    # Keep float32 regardless of the dtype pyloudnorm's gain promotes to.
    y = np.clip(y, -1.0, 1.0).astype(np.float32)
    return torch.from_numpy(y).unsqueeze(0)

def preprocess_audio(audio_path, target_sr=32000, target_lufs=-23.0, trim_db=-40):
    """Load an audio file and return (mono, loudness-normalized, silence-trimmed waveform, sample rate).

    Pipeline: load -> downmix to mono -> resample to `target_sr` -> normalize to
    `target_lufs` -> trim leading/trailing frames quieter than `trim_db` (dB below peak).

    Returns:
        (wav, sr): (1, n) tensor and `target_sr`.
    """
    wav, sr = torchaudio.load(str(audio_path))  # (channels, n), native sample rate
    # Downmixing and resampling are both linear, so doing mono first gives the same
    # result (up to float rounding) while resampling only one channel.
    wav = ensure_mono(wav)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
        sr = target_sr
    wav = loudness_normalize_lufs(wav, sr, target_lufs)
    wav = trim_silence(wav, sr, threshold_db=trim_db)
    return wav, sr

def trim_silence(wav: torch.Tensor, sr: int, threshold_db=-40, frame_length=2048, hop_length=512, pad_frames=5):
    """Trim leading/trailing quiet audio from a mono (1, n) waveform.

    Frame RMS is converted to dB relative to the loudest frame. The kept span runs from
    `pad_frames` frames before the first frame above `threshold_db` to `pad_frames`
    frames after the last one (exclusive), clamped to the signal. If no frame exceeds
    the threshold, `wav` is returned unchanged. `sr` is accepted for API symmetry but
    not used.

    Returns:
        (1, m) float32 tensor.
    """
    x = wav.squeeze(0).cpu().numpy().astype(np.float32)
    rms = librosa.feature.rms(y=x, frame_length=frame_length, hop_length=hop_length)[0]
    # ref must be a number or callable; librosa's `amin` floor already keeps log10 away from 0.
    rms_db = librosa.amplitude_to_db(rms, ref=np.max)
    voiced = np.flatnonzero(rms_db > threshold_db)
    if voiced.size == 0:
        return wav
    start = max(0, voiced[0] - pad_frames)
    end = min(len(rms), voiced[-1] + pad_frames)
    start_s = start * hop_length
    end_s = min(len(x), end * hop_length)
    return torch.from_numpy(x[start_s:end_s]).unsqueeze(0)

def make_mel_transform(sr=32000, n_fft=2048, hop=512, n_mels=128, fmin=55, fmax=8000):
    """Build a (MelSpectrogram, AmplitudeToDB) transform pair.

    The mel transform outputs linear power (power=2.0) from a centered, Hann-windowed
    STFT, using HTK mel spacing with Slaney area normalization. The dB transform
    converts that power to decibels with top_db=80.
    """
    mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr, n_fft=n_fft, hop_length=hop, win_length=n_fft,
        n_mels=n_mels, f_min=fmin, f_max=fmax, window_fn=torch.hann_window,
        power=2.0, normalized=False, center=True, pad_mode="reflect",
        mel_scale="htk", norm="slaney"
    )
    to_db = torchaudio.transforms.AmplitudeToDB(stype="power", top_db=80)
    return mel, to_db

def crop_or_pad_spec(spec: torch.Tensor, target_frames: int) -> torch.Tensor:
    """Center-crop or symmetrically zero-pad the last (time) axis to `target_frames`.

    Accepts (F, T) as well as batched (..., F, T) tensors. With an odd difference, the
    extra frame is dropped from / padded on the right. Returns `spec` itself when it
    already has `target_frames` frames.
    """
    n_frames = spec.shape[-1]
    if n_frames == target_frames:
        return spec
    if n_frames > target_frames:
        start = (n_frames - target_frames) // 2
        return spec[..., start:start + target_frames]
    pad = target_frames - n_frames
    left = pad // 2
    return torch.nn.functional.pad(spec, (left, pad - left))

In [None]:
# Loss function for mel-spectrogram mask-based source separation

class SeparationLoss(nn.Module):
    """Weighted multi-term loss for mel-domain, mask-based source separation.

    Each term is scaled by cfg["loss"][<term>]["weight"]:
      * mask_loss           L1 between predicted and target masks; instruments marked
                            present weigh 1.0, absent ones 0.5.
      * reconstruction_loss squared error between the sum of masked sources and the
                            mixture, averaged over bins where the mixture exceeds 1e-6.
      * activity_loss       BCE between each instrument's peak mask value and its
                            presence label.
      * register_penalty    optional (cfg "enabled" flag): mean mask value outside
                            `band_mask`, or the overall mean mask when no band mask is set.
    """

    def __init__(self, cfg, band_mask=None):
        super().__init__()
        self.cfg = cfg["loss"]
        self.band_mask = band_mask  # optional tensor broadcastable to (B,4,F,T), e.g. (4,F,1)

    def forward(self, pred_masks, target_masks_lin, mixture_lin, presence):
        """
        pred_masks:      (B,4,F,T) in [0,1]
        target_masks_lin:(B,4,F,T) Wiener targets (linear mel power)
        mixture_lin:     (B,F,T)   linear mel power
        presence:        (B,4)     0/1

        Returns (total, losses): the weighted scalar loss tensor and a dict of the
        individual terms as Python floats.
        """
        # Do not unpack the shape into a local named `F`: that would shadow
        # torch.nn.functional and break F.binary_cross_entropy below.
        losses = {}

        # 1) Mask L1 (presence-aware weighting)
        pres_w = presence.unsqueeze(-1).unsqueeze(-1)          # (B,4,1,1)
        l1_err = (pred_masks - target_masks_lin).abs()
        l_mask = (l1_err * (0.5 + 0.5 * pres_w)).mean()
        losses["mask_loss"] = float(l_mask.detach())
        total = self.cfg["mask_loss"]["weight"] * l_mask

        # 2) Reconstruction (ignore near-silent mixture bins)
        pred_sources = pred_masks * mixture_lin.unsqueeze(1)   # (B,4,F,T)
        recon = pred_sources.sum(dim=1)                        # (B,F,T)
        mix_weight = (mixture_lin > 1e-6).float()
        recon_err = (recon - mixture_lin) ** 2
        l_recon = (recon_err * mix_weight).sum() / mix_weight.sum().clamp_min(1.0)
        losses["recon_loss"] = float(l_recon.detach())
        total = total + self.cfg["reconstruction_loss"]["weight"] * l_recon

        # 3) Activity BCE (clamped so log() stays finite)
        pred_act = pred_masks.amax(dim=(2, 3)).clamp(1e-6, 1 - 1e-6)  # (B,4)
        l_act = F.binary_cross_entropy(pred_act, presence.to(pred_act.dtype))
        losses["activity_loss"] = float(l_act.detach())
        total = total + self.cfg["activity_loss"]["weight"] * l_act

        # 4) Optional register/sparsity penalty
        if self.cfg.get("register_penalty", {}).get("enabled", False):
            if self.band_mask is not None:
                band = self.band_mask.to(device=pred_masks.device, dtype=pred_masks.dtype)
                outside = (1.0 - band) * pred_masks            # broadcast band over (B,4,F,T)
                l_reg = outside.mean()
            else:
                l_reg = pred_masks.mean()
            total = total + self.cfg["register_penalty"]["weight"] * l_reg
            losses["register_loss"] = float(l_reg.detach())

        return total, losses


In [None]:
"""PyTorch Datasets for chamber strings separation."""
from pathlib import Path
import json, random
import numpy as np
import torch
from torch.utils.data import Dataset
import soundfile as sf

from utils.preprocessing import make_mel_transform, crop_or_pad_spec

# Canonical instrument order: defines the channel order of masks/presence vectors and the
# index order expected in metadata.json "presence" lists.
INSTRS = "viola cello violin1 violin2".split()

def _to_tensor_mono(path):
    """Read an audio file as float32; return ((1, n) tensor, sample_rate).

    2-D (frames, channels) data is averaged across channels.
    """
    samples, rate = sf.read(str(path), dtype="float32")
    mono = samples.mean(axis=1) if samples.ndim == 2 else samples
    return torch.from_numpy(mono).unsqueeze(0), rate

# Memoized (mel, to_db) transform pairs keyed by their construction parameters, so the
# filterbank is not rebuilt for every mixture and stem of every item.
_MEL_CACHE = {}

def _mel_pair(wav1x: torch.Tensor, sr: int, cfg_data):
    """Compute linear mel power and its dB version for a mono (1, n) waveform.

    Audio whose rate differs from cfg_data["sr"] is first resampled with librosa.

    Returns:
        (S_lin, S_db), each of shape (F, T).
    """
    key = (cfg_data["sr"], cfg_data["n_fft"], cfg_data["hop"],
           cfg_data["n_mels"], cfg_data["fmin"], cfg_data["fmax"])
    if key not in _MEL_CACHE:
        _MEL_CACHE[key] = make_mel_transform(
            sr=cfg_data["sr"], n_fft=cfg_data["n_fft"], hop=cfg_data["hop"],
            n_mels=cfg_data["n_mels"], fmin=cfg_data["fmin"], fmax=cfg_data["fmax"]
        )
    mel, to_db = _MEL_CACHE[key]

    if sr != cfg_data["sr"]:
        import librosa  # this cell does not import librosa at module level
        resampled = librosa.resample(wav1x.squeeze(0).numpy(), orig_sr=sr, target_sr=cfg_data["sr"])
        # Keep float32 so the input dtype matches the transform's float32 window/filterbank.
        wav1x = torch.from_numpy(np.asarray(resampled, dtype=np.float32)).unsqueeze(0)
        sr = cfg_data["sr"]

    S_lin = mel(wav1x)                      # (1,F,T) linear power
    S_db = to_db(S_lin)                     # (1,F,T) log-power dB
    return S_lin.squeeze(0), S_db.squeeze(0)  # (F,T), (F,T)

def _target_frames(cfg_data):
    n = int(round(cfg_data["sec"] * cfg_data["sr"]))
    # frames ~= n / hop (centered). We'll crop/pad on spectrogram directly.
    return int(np.ceil(n / cfg_data["hop"]))

def _build_wiener_masks(stem_lin_list, mix_lin):
    """
    stem_lin_list: list of (F,T) linear mel power tensors for each instrument (missing -> zeros)
    mix_lin: (F,T) linear mel power of mixture
    Returns mask tensor (K,F,T) in [0,1]
    """
    # If you have all stems: use sum(stems) denominator; else fallback to mixture
    if len(stem_lin_list) > 0:
        S_sum = torch.stack(stem_lin_list, dim=0).sum(dim=0)  # (F,T)
        denom = torch.clamp(S_sum, min=1e-8)
        masks = [ (s / denom).clamp(0,1) for s in stem_lin_list ]
        return torch.stack(masks, dim=0)
    else:
        return torch.zeros((len(INSTRS),) + mix_lin.shape, dtype=mix_lin.dtype)

class ChamberMusicDataset(Dataset):
    """Mel-spectrogram dataset of string-ensemble clips with per-instrument mask targets.

    Expected files in `root` (only the mixtures are required):
      clip_001_mixture.wav
      clip_001_viola.wav / clip_001_cello.wav / clip_001_violin1.wav / clip_001_violin2.wav
      metadata.json   optional {clip_id: {"presence": [..]}}, values in INSTRS order

    An instrument whose stem file exists is marked present (1.0). For a missing stem,
    presence comes from metadata.json when the clip has a truthy "presence" entry,
    otherwise 0.0.
    """

    def __init__(self, root, cfg, split="train", augment=False):
        self.root = Path(root)
        self.cfg = cfg
        self.split = split
        self.augment = augment
        self.mixes = sorted(self.root.glob("*_mixture.wav"))
        self.meta = self._load_metadata(self.root / "metadata.json")
        self.target_T = _target_frames(cfg["data"])
        print(f"[{split}] {len(self.mixes)} mixtures from {self.root}")

    @staticmethod
    def _load_metadata(meta_file):
        """Parse metadata.json; a missing or unparsable file yields an empty dict."""
        if not meta_file.exists():
            return {}
        try:
            return json.loads(meta_file.read_text())
        except Exception:
            return {}

    def __len__(self):
        return len(self.mixes)

    def _load_stems(self, mix_path, clip_id, mix_lin, data_cfg):
        """Return (per-instrument (F,T) linear mel specs, presence floats), both in INSTRS order."""
        stems_lin, presence = [], []
        for slot, inst in enumerate(INSTRS):
            stem_path = mix_path.with_name(f"{clip_id}_{inst}.wav")
            if stem_path.exists():
                stem_wav, stem_rate = _to_tensor_mono(stem_path)
                stem_lin, _ = _mel_pair(stem_wav, stem_rate, data_cfg)
                stems_lin.append(crop_or_pad_spec(stem_lin, self.target_T))
                presence.append(1.0)
                continue
            # No stem on disk: silent target; presence from metadata when available.
            stems_lin.append(torch.zeros_like(mix_lin))
            labels = self.meta.get(clip_id, {}).get("presence")
            presence.append(float(labels[slot]) if labels else 0.0)
        return stems_lin, presence

    @staticmethod
    def _spec_augment(aug, mix_db, mix_lin, masks_t):
        """In-place SpecAugment: zero the same time/frequency bands in mix_db, mix_lin and all masks."""
        for _ in range(aug["n_time_masks"]):
            width = aug["time_mask_width"]
            n_time = mix_db.shape[1]
            if n_time <= width:
                continue
            t0 = random.randint(0, n_time - width)
            mix_db[:, t0:t0 + width] = 0
            masks_t[:, :, t0:t0 + width] = 0
            mix_lin[:, t0:t0 + width] = 0
        for _ in range(aug["n_freq_masks"]):
            width = aug["freq_mask_width"]
            n_freq = mix_db.shape[0]
            if n_freq <= width:
                continue
            f0 = random.randint(0, n_freq - width)
            mix_db[f0:f0 + width, :] = 0
            masks_t[:, f0:f0 + width, :] = 0
            mix_lin[f0:f0 + width, :] = 0

    def __getitem__(self, idx):
        mix_path = self.mixes[idx]
        clip_id = mix_path.stem.replace("_mixture", "")

        wav, rate = _to_tensor_mono(mix_path)
        data_cfg = self.cfg["data"]
        lin, db = _mel_pair(wav, rate, data_cfg)
        mix_lin = crop_or_pad_spec(lin, self.target_T)
        mix_db = crop_or_pad_spec(db, self.target_T)

        stems_lin, presence = self._load_stems(mix_path, clip_id, mix_lin, data_cfg)
        masks_t = _build_wiener_masks(stems_lin, mix_lin)  # (4,F,T)

        if self.augment and data_cfg["augmentation"]["enabled"]:
            self._spec_augment(data_cfg["augmentation"]["spec_augment"], mix_db, mix_lin, masks_t)

        return {
            "mixture_db": mix_db.unsqueeze(0),         # (1,F,T)
            "mixture_lin": mix_lin,                    # (F,T)
            "masks_t": masks_t,                        # (4,F,T)
            "presence": torch.tensor(presence, dtype=torch.float32),
            "clip_id": clip_id,
        }

class SynthMIDIDataset(ChamberMusicDataset):
    """Synthetic (MIDI-rendered) clips in the same file layout as ChamberMusicDataset.

    Stems are expected to exist for every instrument. No synthetic-specific behavior
    is implemented yet (TODO); everything is inherited unchanged.
    """

class CombinedDataset(torch.utils.data.Dataset):
    """Real recordings mixed with optional synthetic clips at a target sampling ratio.

    Each __getitem__ returns a random synthetic clip with probability `synth_weight`
    (when a non-empty synthetic set exists), otherwise real clip `idx % len(real)`.
    The epoch length is sized so that, on average, each real clip is drawn once.
    """
    def __init__(self, real_dir, synth_dir, synth_weight, cfg, split="train", augment=True):
        self.real = ChamberMusicDataset(real_dir, cfg, split=split, augment=augment)
        self.synth = SynthMIDIDataset(synth_dir, cfg, split=split, augment=augment) if synth_dir else None
        self.synth_weight = float(synth_weight) if synth_dir else 0.0
        n_synth = len(self.synth) if self.synth is not None else 0
        self.total = self._epoch_length(len(self.real), n_synth, self.synth_weight)
        print(f"[{split}] CombinedDataset total≈{self.total} (synth_weight={self.synth_weight})")

    @staticmethod
    def _epoch_length(n_real, n_synth, synth_weight):
        """Epoch length for `n_real` real and `n_synth` synthetic clips.

        Raises:
            ValueError: if synth_weight is outside [0, 1]. (Previously 1.0 divided by
                zero and values above 1 produced a negative length.)
        """
        if not 0.0 <= synth_weight <= 1.0:
            raise ValueError(f"synth_weight must be in [0, 1], got {synth_weight}")
        if n_synth == 0:
            return n_real
        if synth_weight >= 1.0:
            # Synthetic-only sampling: one pass over the synthetic set.
            return n_synth
        return int(n_real / (1.0 - synth_weight))

    def __len__(self): return self.total

    def __getitem__(self, idx):
        has_synth = self.synth is not None and len(self.synth) > 0
        if has_synth and random.random() < self.synth_weight:
            return self.synth[random.randint(0, len(self.synth) - 1)]
        return self.real[idx % len(self.real)]