In [None]:
import musdb
# Ask musdb to download the dataset, then report how many tracks it exposes
# and which root directory it resolved to.
mus = musdb.DB(download=True)
for caption, value in (("Tracks:", len(mus.tracks)), ("Root:", mus.root)):
    print(caption, value)

In [None]:
import os
import musdb
import soundfile as sf

# Paths
musdb_root = "/MUSDB18/MUSDB18-7"
# Every later cell in this notebook (playback, mixture generation, training)
# reads the exported stems from "/MUSDB18-train/<stem>", so write them there.
output_root = "/MUSDB18-train"

# musdb target name -> output folder name.
# The fourth MUSDB18 stem is the "other" target.  "accompaniment" is musdb's
# drums + bass + other mix, so exporting it as the "others" stem would leak
# drums and bass into every file that is later labelled as "others" only.
stem_folders = {"vocals": "vocals", "drums": "drums", "bass": "bass", "other": "others"}

# Create folders for stems
stems = list(stem_folders.values())
for stem in stems:
    os.makedirs(os.path.join(output_root, stem), exist_ok=True)

# Load training dataset
mus_train = musdb.DB(root=musdb_root, subsets="train")

# Iterate over all tracks
# NOTE: if an earlier run left "*_accompaniment.wav" files in the "others"
# folder, delete them first -- later cells pair stems by sorted position and
# stale files would shift that pairing.
for track in mus_train.tracks:
    print(f"Processing track: {track.name}")
    # Save each stem individually as "<track name>_<musdb target>.wav"
    for stem_name, out_stem_name in stem_folders.items():
        audio = track.targets[stem_name].audio  # shape: (samples, channels)
        out_file = os.path.join(output_root, out_stem_name, f"{track.name}_{stem_name}.wav")
        sf.write(out_file, audio, track.rate)

print("All tracks separated!")


In [None]:
import os
from IPython.display import Audio, display
import soundfile as sf

SR = 44100  # not used below: playback uses the sample rate stored in each file
idx = 6  # Change this to play a different sample

# Root folder holding one sub-folder per stem
stems_root = "/MUSDB18-train"
stem_names = ["bass", "drums", "vocals", "others"]


def _list_wavs(folder):
    """Return the sorted ``.wav`` file names found directly inside ``folder``."""
    return sorted(f for f in os.listdir(folder) if f.endswith(".wav"))


def _play(label, folder, files, index):
    """Load ``files[index]`` from ``folder``, down-mix to mono and show a player.

    Raises:
        IndexError: If ``index`` does not address an entry of ``files``.
    """
    if not -len(files) <= index < len(files):
        raise IndexError(
            f"idx={index} is out of range: {folder!r} contains {len(files)} .wav file(s)"
        )
    file_name = files[index]
    audio, sr = sf.read(os.path.join(folder, file_name))
    if audio.ndim > 1:
        # average across axis 1 so the player receives a 1-D signal
        audio = audio.mean(axis=1)
    print(f"{label}:", file_name)
    display(Audio(audio, rate=sr))


# List every folder first so a missing folder fails before anything is played.
stem_dirs = {name: os.path.join(stems_root, name) for name in stem_names}
stem_listings = {name: _list_wavs(folder) for name, folder in stem_dirs.items()}

# Load and play bass, drums, vocals and others (in that order)
for name in stem_names:
    _play(name.capitalize(), stem_dirs[name], stem_listings[name], idx)


In [None]:
import os
import musdb
import soundfile as sf

# Paths
musdb_root = "/MUSDB18/MUSDB18-7"
# The test-phase cells (playback, mixture generation, evaluation) read the
# exported stems from "/MUSDB18-test/<stem>", so write them there.
output_root = "/MUSDB18-test"

# musdb target name -> output folder name.
# The fourth MUSDB18 stem is the "other" target.  "accompaniment" is musdb's
# drums + bass + other mix, so exporting it as the "others" stem would leak
# drums and bass into every file that is later labelled as "others" only.
stem_folders = {"vocals": "vocals", "drums": "drums", "bass": "bass", "other": "others"}

# Create folders for stems
stems = list(stem_folders.values())
for stem in stems:
    os.makedirs(os.path.join(output_root, stem), exist_ok=True)

# Load test dataset
mus_test = musdb.DB(root=musdb_root, subsets="test")

# Iterate over all tracks
# NOTE: if an earlier run left "*_accompaniment.wav" files in the "others"
# folder, delete them first -- later cells pair stems by sorted position and
# stale files would shift that pairing.
for track in mus_test.tracks:
    print(f"Processing track: {track.name}")
    # Save each stem individually as "<track name>_<musdb target>.wav"
    for stem_name, out_stem_name in stem_folders.items():
        audio = track.targets[stem_name].audio  # shape: (samples, channels)
        out_file = os.path.join(output_root, out_stem_name, f"{track.name}_{stem_name}.wav")
        sf.write(out_file, audio, track.rate)

print("All tracks separated!")

In [None]:
import os
import itertools
import soundfile as sf
import numpy as np

# Base path where original stems are stored
base_path = "/MUSDB18-train/"

# Original stem folders
stems = ["bass", "vocals", "drums", "others"]

# Output path for synthetic mixtures.
# The playback cell and main_training("/MUSDB18-train") both look for the
# combination folders (e.g. "bass_vocals") directly under the dataset root,
# next to the single-stem folders.  Writing them into a nested
# "synthetic_mixtures" folder would hide them from both consumers.
output_base = base_path
os.makedirs(output_base, exist_ok=True)

# Sorted list of .wav file names for each stem folder
stem_files = {
    stem: sorted(
        f for f in os.listdir(os.path.join(base_path, stem)) if f.endswith(".wav")
    )
    for stem in stems
}

# Function to load audio as mono
def load_audio(path):
    """Read an audio file and return ``(samples, sample_rate)`` as mono.

    Data with more than one dimension is averaged across axis 1;
    one-dimensional data is returned unchanged.
    """
    samples, rate = sf.read(path)
    is_multichannel = samples.ndim > 1
    return (samples.mean(axis=1) if is_multichannel else samples), rate

# Function to save audio
def save_audio(audio, sr, out_path):
    """Write ``audio`` to ``out_path`` at sample rate ``sr`` using soundfile."""
    sf.write(file=out_path, data=audio, samplerate=sr)

# Generate all 2-stem and 3-stem combinations
for r in [2, 3]:
    for combo in itertools.combinations(stems, r):
        combo_name = "_".join(combo)
        out_folder = os.path.join(output_base, combo_name)
        os.makedirs(out_folder, exist_ok=True)

        # Stems are paired by position in their sorted file lists
        # (assuming all stem folders hold the same tracks in the same order).
        num_files = len(stem_files[combo[0]])

        for idx in range(num_files):
            first_name = stem_files[combo[0]][idx]

            # Load each stem in the combination
            audios = []
            sr = None
            for stem in combo:
                file_path = os.path.join(base_path, stem, stem_files[stem][idx])
                audio, sr = load_audio(file_path)
                audios.append(audio)

            # Check lengths and truncate to minimum length
            lengths = [a.shape[0] for a in audios]
            min_len = min(lengths)
            if len(set(lengths)) > 1:
                print(f"Truncating track {first_name} for combination {combo_name} "
                      f"from lengths {lengths} to minimum length {min_len}")
            audios = [a[:min_len] for a in audios]

            # Sum to create mixture
            mixture = np.sum(audios, axis=0)

            # Peak-normalize to avoid clipping.  A completely silent (or empty)
            # mixture has peak 0; dividing by it would write NaN/inf samples,
            # which later poison the training loss, so leave it untouched.
            peak = np.max(np.abs(mixture)) if mixture.size else 0.0
            if peak > 0:
                mixture = mixture / peak

            # Save synthetic mixture under the first stem's file name
            save_audio(mixture, sr, os.path.join(out_folder, first_name))

print("All synthetic 2-stem and 3-stem mixtures generated successfully!")

In [None]:
import os
from IPython.display import Audio, display
import soundfile as sf

SR = 44100  # not used below: playback uses the sample rate stored in each file
idx = 5  # Change this to play a different sample

# Base folder containing all synthetic mixture combinations
synthetic_base = "/MUSDB18-train/"

# Combination folders, in playback order (all 2-stem, then all 3-stem mixtures)
combo_names = [
    "bass_vocals",
    "bass_drums",
    "bass_others",
    "vocals_drums",
    "vocals_others",
    "drums_others",
    "bass_vocals_drums",
    "bass_vocals_others",
    "bass_drums_others",
    "vocals_drums_others",
]
combo_paths = {name: os.path.join(synthetic_base, name) for name in combo_names}


# Function to load audio as mono
def load_mono_audio(path):
    """Read ``path`` and return ``(samples, sample_rate)``.

    Data with more than one dimension is averaged across axis 1.
    """
    audio, sr = sf.read(path)
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    return audio, sr


# Sorted .wav listings; built for every folder up front so that a missing
# folder is reported before any audio is played.
combo_files = {
    name: sorted(f for f in os.listdir(folder) if f.endswith(".wav"))
    for name, folder in combo_paths.items()
}

# Play the idx-th file of every combination
for name in combo_names:
    files = combo_files[name]
    if not -len(files) <= idx < len(files):
        raise IndexError(
            f"idx={idx} is out of range: {combo_paths[name]!r} contains {len(files)} .wav file(s)"
        )
    file_name = files[idx]
    file_path = os.path.join(combo_paths[name], file_name)
    audio, sr = load_mono_audio(file_path)
    # "bass_vocals_drums" -> "Bass + Vocals + Drums"
    label = " + ".join(part.capitalize() for part in name.split("_"))
    print(f"{label}:", file_name)
    display(Audio(audio, rate=sr))


In [None]:
import librosa

# One exported stem file, used to check the sampling rate and clip length
audio_path = "/MUSDB18-train/bass/A Classic Education - NightOwl_bass.wav"

# sr=None keeps the file's original sampling rate (no resampling)
wav, sr = librosa.load(audio_path, sr=None)
n_samples = len(wav)

print("Sampling Rate =", sr)
print("Number of samples =", n_samples)
print("Duration (sec) =", n_samples / sr)

In [None]:
import os
import glob
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import pywt
import librosa

# -----------------------------
# PARAMETERS
# -----------------------------
# Label vocabulary: position i of every multi-hot label vector refers to STEMS[i].
# NOTE(review): the stem folders created earlier in this notebook are named
# "others" (plural) while this list uses "other" -- keep both spellings in mind
# when folder names are matched against these entries.
STEMS = ["vocals", "drums", "bass", "other"]
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 4      # batch size can be increased after testing memory
EPOCHS = 150        # number of passes over the training split in main_training
LEARNING_RATE = 1e-3  # learning rate handed to the Adam optimizer
DWT_LEVELS = 5      # number of single-level Haar DWT passes (each keeps only the approximation)
MAX_LEN = 9376      # fixed input length every example is padded/trimmed to; length after 5-level DWT of ~300k samples

# Per-epoch loss history, appended to by main_training.  These are module-level
# lists, so calling main_training again without re-running this cell keeps
# appending to the same history.
train_losses = []
val_losses = []

# -----------------------------
# DATASET
# -----------------------------
def get_file_label_list(root_dir, stems=None):
    """Collect ``.wav`` files and multi-hot labels from the sub-folders of ``root_dir``.

    Each immediate sub-folder name is split on ``"_"`` and every token that
    names a stem switches on the matching label position, e.g. with the
    default stems the folder ``bass_vocals`` yields ``[1, 0, 1, 0]``.
    Files that sit directly in ``root_dir`` and non-directories are ignored.

    The stem folders in this notebook are called ``others`` while ``STEMS``
    spells the class ``other``; both spellings are treated as the same stem,
    otherwise that label position could never be set.

    Args:
        root_dir: Directory whose sub-folders are named after the stems they
            contain, joined by underscores.
        stems: Ordered stem names defining the label layout.  Defaults to the
            module-level ``STEMS``.

    Returns:
        ``(file_paths, labels)``: two parallel lists, where each label is a
        ``float32`` NumPy vector of length ``len(stems)``.  Folders and files
        are visited in sorted order so the result is deterministic.
    """
    if stems is None:
        stems = STEMS

    def canonical(name):
        # fold the plural folder spelling onto the class name
        return "other" if name == "others" else name

    wanted = [canonical(s) for s in stems]

    file_paths = []
    labels = []

    # traverse folders for multi-labels
    for sub in sorted(os.listdir(root_dir)):
        sub_path = os.path.join(root_dir, sub)
        if not os.path.isdir(sub_path):
            continue

        present = {canonical(token) for token in sub.split("_")}
        lab = np.array([1.0 if s in present else 0.0 for s in wanted], dtype=np.float32)

        # get all wav files in this folder; escape the folder part so glob
        # metacharacters in a directory name are taken literally
        pattern = os.path.join(glob.escape(sub_path), "*.wav")
        for w in sorted(glob.glob(pattern)):
            file_paths.append(w)
            labels.append(lab.copy())

    return file_paths, labels

class WaveletDataset(Dataset):
    """Dataset of fixed-length, DWT-downsampled waveforms with multi-hot labels.

    Every item is loaded from disk with librosa, reduced with ``dwt_levels``
    Haar DWT passes, padded or trimmed to ``max_len`` values and returned as
    ``(x, y)`` where ``x`` has shape ``(1, max_len)``.
    """

    def __init__(self, file_paths, labels, max_len=MAX_LEN, dwt_levels=DWT_LEVELS):
        self.file_paths, self.labels = file_paths, labels
        self.max_len, self.dwt_levels = max_len, dwt_levels

    def __len__(self):
        """Number of audio files in the dataset."""
        return len(self.file_paths)

    def dwt_downsample(self, x):
        """Apply the Haar DWT ``dwt_levels`` times, keeping only the first output."""
        remaining = self.dwt_levels
        while remaining > 0:
            # pywt.dwt returns a pair; only the first (approximation) part is kept
            x = pywt.dwt(x, 'haar')[0]
            remaining -= 1
        return x

    def pad_or_trim(self, x):
        """Return ``x`` cut to, or zero-padded at the end up to, ``max_len``."""
        missing = self.max_len - len(x)
        if missing < 0:
            return x[:self.max_len]
        if missing > 0:
            return np.pad(x, (0, missing))
        return x

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(waveform_tensor, label_tensor)``."""
        path, label = self.file_paths[idx], self.labels[idx]

        # sr=None: load at the file's own sampling rate (no resampling)
        wav, _ = librosa.load(path, sr=None)
        coeffs = self.pad_or_trim(self.dwt_downsample(wav))

        x = torch.tensor(coeffs, dtype=torch.float32).unsqueeze(0)  # shape: (1, T)
        y = torch.tensor(label, dtype=torch.float32)
        return x, y

# -----------------------------
# MODEL
# -----------------------------
class ResBlock1D(nn.Module):
    """Residual 1-D block: conv-BN-ReLU-conv-BN plus shortcut, ReLU, max-pool(2).

    The shortcut is a 1x1 convolution when the channel count changes and an
    identity mapping otherwise.  The final pooling halves the time dimension.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        conv_args = (kernel_size, stride, padding)
        self.conv1 = nn.Conv1d(in_channels, out_channels, *conv_args)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(out_channels, out_channels, *conv_args)
        self.bn2 = nn.BatchNorm1d(out_channels)
        if in_channels == out_channels:
            self.downsample = nn.Identity()
        else:
            # 1x1 conv so the shortcut matches the main branch's channel count
            self.downsample = nn.Conv1d(in_channels, out_channels, 1)
        self.pool = nn.MaxPool1d(2)  # reduce temporal dimension by 2

    def forward(self, x):
        shortcut = self.downsample(x)
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        merged = self.relu(branch + shortcut)
        return self.pool(merged)  # halve time dimension

class WaveletCNN(nn.Module):
    """1-D CNN classifier: conv stem, four residual blocks, global average pool, linear head.

    The forward pass returns raw logits of shape ``(batch, num_classes)``;
    no sigmoid is applied here.
    """

    def __init__(self, in_channels=1, num_classes=4):
        super().__init__()
        # Stem: wide kernel, padding keeps the time length unchanged
        self.stem = nn.Conv1d(in_channels, 8, kernel_size=7, padding=3)
        self.bn0 = nn.BatchNorm1d(8)
        self.relu = nn.ReLU()
        # Residual stages double the channel count each time: 8 -> 16 -> 32 -> 64 -> 128
        self.layer1 = ResBlock1D(8, 16)
        self.layer2 = ResBlock1D(16, 32)
        self.layer3 = ResBlock1D(32, 64)
        self.layer4 = ResBlock1D(64, 128)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        features = self.relu(self.bn0(self.stem(x)))
        for block in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = block(features)
        # collapse the time axis to one value per channel, then drop it
        pooled = self.global_pool(features).squeeze(-1)
        return self.fc(pooled)

# -----------------------------
# TRAINING FUNCTIONS
# -----------------------------
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    The model is put in training mode.  Each batch loss is weighted by the
    batch size and the sum is divided by ``len(loader.dataset)``.
    """
    model.train()
    running = 0.0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item() * inputs.size(0)
    return running / len(loader.dataset)

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        ``(mean_loss, probabilities, labels)`` where the loss is the
        batch-size-weighted sum divided by ``len(loader.dataset)``,
        ``probabilities`` are the sigmoid of the logits and both tensors are
        concatenated over all batches on the CPU.
    """
    model.eval()
    running = 0.0
    prob_chunks, label_chunks = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            running += criterion(logits, targets).item() * inputs.size(0)
            prob_chunks.append(torch.sigmoid(logits).cpu())
            label_chunks.append(targets.cpu())
    probabilities = torch.cat(prob_chunks, dim=0)
    label_tensor = torch.cat(label_chunks, dim=0)
    return running / len(loader.dataset), probabilities, label_tensor

# -----------------------------
# MAIN TRAINING
# -----------------------------
def main_training(root_dir, seed=42):
    """Train ``WaveletCNN`` on the labelled ``.wav`` folders under ``root_dir``.

    The files are split 90/10 into train/validation at file level, the model
    is trained for ``EPOCHS`` epochs with Adam and ``BCEWithLogitsLoss``, the
    weights with the lowest validation loss are written to
    ``best_wavelet_cnn.pth``, and the per-epoch losses are appended to the
    module-level ``train_losses`` / ``val_losses`` lists and saved to
    ``train_losses.npy`` / ``val_losses.npy``.

    Args:
        root_dir: Directory handed to ``get_file_label_list``.
        seed: Seed for the train/val permutation and for torch's global RNG
            (weight initialisation, loader shuffling) so a run can be
            reproduced.  Pass ``None`` for the previous unseeded behaviour.

    Raises:
        ValueError: If fewer than two ``.wav`` files are found, because one of
            the two splits would be empty.
    """
    # build file lists
    file_paths, labels = get_file_label_list(root_dir)

    n = len(file_paths)
    if n < 2:
        raise ValueError(
            f"Need at least 2 .wav files under {root_dir!r} to build a train/val split, found {n}."
        )

    if seed is not None:
        torch.manual_seed(seed)

    # split 90/10 train/val
    # NOTE(review): the split is per file.  Mixtures built from the same track
    # can land on both sides, so the validation loss may look optimistic --
    # consider splitting by track name if that matters.
    idxs = np.random.default_rng(seed).permutation(n)
    split = int(n*0.9)
    train_idx, val_idx = idxs[:split], idxs[split:]
    train_files = [file_paths[i] for i in train_idx]
    train_labels = [labels[i] for i in train_idx]
    val_files = [file_paths[i] for i in val_idx]
    val_labels = [labels[i] for i in val_idx]

    # datasets & loaders
    train_ds = WaveletDataset(train_files, train_labels)
    val_ds = WaveletDataset(val_files, val_labels)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    # model, optimizer, loss
    model = WaveletCNN().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.BCEWithLogitsLoss()

    best_val_loss = float('inf')
    for epoch in range(1, EPOCHS+1):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_loss, _, _ = validate(model, val_loader, criterion, DEVICE)
        print(f"Epoch {epoch}: train_loss={train_loss:.4f} val_loss={val_loss:.4f}")
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # save best model (checkpoint whenever the validation loss improves)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_wavelet_cnn.pth")
            print("Saved best model.")

    # persist the loss curves so they can be plotted without re-training
    np.save("train_losses.npy", np.array(train_losses))
    np.save("val_losses.npy", np.array(val_losses))

if __name__ == "__main__":
    main_training("/MUSDB18-train")


## Testing Phase:

In [None]:
import os
from IPython.display import Audio, display
import soundfile as sf

SR = 44100  # not used below: playback uses the sample rate stored in each file
idx = 5  # Change this to play a different sample

# Root folder holding one sub-folder per stem
stems_root = "/MUSDB18-test"
stem_names = ["bass", "drums", "vocals", "others"]


def _list_wavs(folder):
    """Return the sorted ``.wav`` file names found directly inside ``folder``."""
    return sorted(f for f in os.listdir(folder) if f.endswith(".wav"))


def _play(label, folder, files, index):
    """Load ``files[index]`` from ``folder``, down-mix to mono and show a player.

    Raises:
        IndexError: If ``index`` does not address an entry of ``files``.
    """
    if not -len(files) <= index < len(files):
        raise IndexError(
            f"idx={index} is out of range: {folder!r} contains {len(files)} .wav file(s)"
        )
    file_name = files[index]
    audio, sr = sf.read(os.path.join(folder, file_name))
    if audio.ndim > 1:
        # average across axis 1 so the player receives a 1-D signal
        audio = audio.mean(axis=1)
    print(f"{label}:", file_name)
    display(Audio(audio, rate=sr))


# List every folder first so a missing folder fails before anything is played.
stem_dirs = {name: os.path.join(stems_root, name) for name in stem_names}
stem_listings = {name: _list_wavs(folder) for name, folder in stem_dirs.items()}

# Load and play bass, drums, vocals and others (in that order)
for name in stem_names:
    _play(name.capitalize(), stem_dirs[name], stem_listings[name], idx)


In [None]:
import os
import itertools
import soundfile as sf
import numpy as np

# Base path where original stems are stored
base_path = "/MUSDB18-test/"

# Original stem folders
stems = ["bass", "vocals", "drums", "others"]

# Output path for synthetic mixtures.
# main_test("/MUSDB18-test") labels files by the name of the folder sitting
# directly under the dataset root, so the combination folders (e.g.
# "bass_vocals") must live next to the single-stem folders.  Writing them into
# a nested "synthetic_mixtures" folder would hide them from the evaluation.
output_base = base_path
os.makedirs(output_base, exist_ok=True)

# Sorted list of .wav file names for each stem folder
stem_files = {
    stem: sorted(
        f for f in os.listdir(os.path.join(base_path, stem)) if f.endswith(".wav")
    )
    for stem in stems
}

# Function to load audio as mono
def load_audio(path):
    """Load ``path`` with soundfile and return ``(samples, sample_rate)`` as mono.

    Arrays with more than one dimension are collapsed by taking the mean
    along axis 1; one-dimensional arrays pass through unchanged.
    """
    data, rate = sf.read(path)
    if data.ndim <= 1:
        return data, rate
    return data.mean(axis=1), rate

# Function to save audio
def save_audio(audio, sr, out_path):
    """Write ``audio`` to ``out_path`` at sample rate ``sr`` using soundfile."""
    sf.write(file=out_path, data=audio, samplerate=sr)

# Generate all 2-stem and 3-stem combinations
for r in [2, 3]:
    for combo in itertools.combinations(stems, r):
        combo_name = "_".join(combo)
        out_folder = os.path.join(output_base, combo_name)
        os.makedirs(out_folder, exist_ok=True)

        # Stems are paired by position in their sorted file lists
        # (assuming all stem folders hold the same tracks in the same order).
        num_files = len(stem_files[combo[0]])

        for idx in range(num_files):
            first_name = stem_files[combo[0]][idx]

            # Load each stem in the combination
            audios = []
            sr = None
            for stem in combo:
                file_path = os.path.join(base_path, stem, stem_files[stem][idx])
                audio, sr = load_audio(file_path)
                audios.append(audio)

            # Check lengths and truncate to minimum length
            lengths = [a.shape[0] for a in audios]
            min_len = min(lengths)
            if len(set(lengths)) > 1:
                print(f"Truncating track {first_name} for combination {combo_name} "
                      f"from lengths {lengths} to minimum length {min_len}")
            audios = [a[:min_len] for a in audios]

            # Sum to create mixture
            mixture = np.sum(audios, axis=0)

            # Peak-normalize to avoid clipping.  A completely silent (or empty)
            # mixture has peak 0; dividing by it would write NaN/inf samples
            # into the evaluation set, so leave such a mixture untouched.
            peak = np.max(np.abs(mixture)) if mixture.size else 0.0
            if peak > 0:
                mixture = mixture / peak

            # Save synthetic mixture under the first stem's file name
            save_audio(mixture, sr, os.path.join(out_folder, first_name))

print("All synthetic 2-stem and 3-stem mixtures generated successfully!")

In [None]:
import os
import glob
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import pywt
import librosa
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# -----------------------------
# PARAMETERS
# -----------------------------
# Label vocabulary: position i of every multi-hot label vector refers to STEMS[i].
# NOTE(review): the stem folders created earlier in this notebook are named
# "others" (plural) while this list uses "other" -- keep both spellings in mind
# when folder names are matched against these entries.
STEMS = ["vocals", "drums", "bass", "other"]
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Batch size of the test DataLoader built in main_test.
BATCH_SIZE = 1
# Number of single-level Haar DWT passes applied to each waveform.
DWT_LEVELS = 5
MAX_LEN = 9376  # must match training

# -----------------------------
# DATASET
# -----------------------------
def get_file_label_list(root_dir, stems=None):
    """Collect ``.wav`` files and multi-hot labels from the sub-folders of ``root_dir``.

    Each immediate sub-folder name is split on ``"_"`` and every token that
    names a stem switches on the matching label position, e.g. with the
    default stems the folder ``bass_vocals`` yields ``[1, 0, 1, 0]``.
    Files that sit directly in ``root_dir`` and non-directories are ignored.

    The stem folders in this notebook are called ``others`` while ``STEMS``
    spells the class ``other``; both spellings are treated as the same stem,
    otherwise that label position could never be set.

    Args:
        root_dir: Directory whose sub-folders are named after the stems they
            contain, joined by underscores.
        stems: Ordered stem names defining the label layout.  Defaults to the
            module-level ``STEMS``.

    Returns:
        ``(file_paths, labels)``: two parallel lists, where each label is a
        ``float32`` NumPy vector of length ``len(stems)``.  Folders and files
        are visited in sorted order so the result is deterministic.
    """
    if stems is None:
        stems = STEMS

    def canonical(name):
        # fold the plural folder spelling onto the class name
        return "other" if name == "others" else name

    wanted = [canonical(s) for s in stems]

    file_paths = []
    labels = []

    for sub in sorted(os.listdir(root_dir)):
        sub_path = os.path.join(root_dir, sub)
        if not os.path.isdir(sub_path):
            continue

        present = {canonical(token) for token in sub.split("_")}
        lab = np.array([1.0 if s in present else 0.0 for s in wanted], dtype=np.float32)

        # escape the folder part so glob metacharacters in a directory name
        # are taken literally
        pattern = os.path.join(glob.escape(sub_path), "*.wav")
        for w in sorted(glob.glob(pattern)):
            file_paths.append(w)
            labels.append(lab.copy())

    return file_paths, labels

class WaveletDataset(Dataset):
    """Test-time dataset of fixed-length, DWT-downsampled waveforms and labels.

    Every item is loaded with librosa, reduced with ``dwt_levels`` Haar DWT
    passes, padded or trimmed to ``max_len`` values and returned as ``(x, y)``
    where ``x`` has shape ``(1, max_len)``.
    """

    def __init__(self, file_paths, labels, max_len=MAX_LEN, dwt_levels=DWT_LEVELS):
        self.file_paths, self.labels = file_paths, labels
        self.max_len, self.dwt_levels = max_len, dwt_levels

    def dwt_downsample(self, x):
        """Apply the Haar DWT ``dwt_levels`` times, keeping only the first output."""
        remaining = self.dwt_levels
        while remaining > 0:
            # pywt.dwt returns a pair; only the first (approximation) part is kept
            x = pywt.dwt(x, 'haar')[0]
            remaining -= 1
        return x

    def pad_or_trim(self, x):
        """Return ``x`` cut to, or zero-padded at the end up to, ``max_len``."""
        missing = self.max_len - len(x)
        if missing < 0:
            return x[:self.max_len]
        if missing > 0:
            return np.pad(x, (0, missing))
        return x

    def __len__(self):
        """Number of audio files in the dataset."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(waveform_tensor, label_tensor)``."""
        path, label = self.file_paths[idx], self.labels[idx]

        # sr=None: load at the file's own sampling rate (no resampling)
        wav, _ = librosa.load(path, sr=None)
        coeffs = self.pad_or_trim(self.dwt_downsample(wav))
        x = torch.tensor(coeffs, dtype=torch.float32).unsqueeze(0)
        y = torch.tensor(label, dtype=torch.float32)
        return x, y

# -----------------------------
# MODEL
# -----------------------------
class ResBlock1D(nn.Module):
    """Residual 1-D block: conv-BN-ReLU-conv-BN plus shortcut, ReLU, max-pool(2).

    Attribute names are part of the checkpoint format: the state dict saved
    during training is loaded into this class by ``main_test``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        conv_args = (kernel_size, stride, padding)
        self.conv1 = nn.Conv1d(in_channels, out_channels, *conv_args)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(out_channels, out_channels, *conv_args)
        self.bn2 = nn.BatchNorm1d(out_channels)
        if in_channels == out_channels:
            self.downsample = nn.Identity()
        else:
            # 1x1 conv so the shortcut matches the main branch's channel count
            self.downsample = nn.Conv1d(in_channels, out_channels, 1)
        self.pool = nn.MaxPool1d(2)

    def forward(self, x):
        shortcut = self.downsample(x)
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        merged = self.relu(branch + shortcut)
        return self.pool(merged)

class WaveletCNN(nn.Module):
    """1-D CNN classifier: conv stem, four residual blocks, global average pool, linear head.

    Layer names and sizes mirror the training-time definition so that the
    saved state dict can be loaded.  The forward pass returns raw logits.
    """

    def __init__(self, in_channels=1, num_classes=4):
        super().__init__()
        # Stem: wide kernel, padding keeps the time length unchanged
        self.stem = nn.Conv1d(in_channels, 8, kernel_size=7, padding=3)
        self.bn0 = nn.BatchNorm1d(8)
        self.relu = nn.ReLU()
        # Residual stages: 8 -> 16 -> 32 -> 64 -> 128 channels
        self.layer1 = ResBlock1D(8, 16)
        self.layer2 = ResBlock1D(16, 32)
        self.layer3 = ResBlock1D(32, 64)
        self.layer4 = ResBlock1D(64, 128)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        features = self.relu(self.bn0(self.stem(x)))
        for block in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = block(features)
        # collapse the time axis to one value per channel, then drop it
        pooled = self.global_pool(features).squeeze(-1)
        return self.fc(pooled)

# -----------------------------
# EVALUATION
# -----------------------------
def evaluate_model(model, dataloader, device, threshold=0.5):
    """Run ``model`` over ``dataloader`` and print multi-label metrics.

    Sigmoid probabilities above ``threshold`` count as a positive prediction.
    Accuracy comes from ``accuracy_score``; precision, recall and F1 use
    ``average='samples'`` with ``zero_division=0``.

    Returns:
        ``(all_preds, all_labels)`` as NumPy arrays stacked over all batches.
    """
    model.eval()
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            decisions = (torch.sigmoid(model(inputs)) > threshold).float()
            pred_chunks.append(decisions.cpu())
            label_chunks.append(targets.cpu())
    all_preds = torch.cat(pred_chunks, dim=0).numpy()
    all_labels = torch.cat(label_chunks, dim=0).numpy()

    # Compute metrics
    per_sample = dict(average='samples', zero_division=0)
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, **per_sample)
    rec = recall_score(all_labels, all_preds, **per_sample)
    f1 = f1_score(all_labels, all_preds, **per_sample)
    print(f"Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1-score: {f1:.4f}")
    return all_preds, all_labels

# -----------------------------
# MAIN
# -----------------------------
def main_test(test_root_dir, model_path="best_wavelet_cnn.pth"):
    """Evaluate the saved ``WaveletCNN`` weights on the folders under ``test_root_dir``.

    Builds the test dataset and loader, loads the state dict stored at
    ``model_path``, prints the metrics reported by ``evaluate_model`` and
    finally prints the first five prediction/label pairs.
    """
    # build test dataset
    test_files, test_labels = get_file_label_list(test_root_dir)
    test_loader = DataLoader(
        WaveletDataset(test_files, test_labels),
        batch_size=BATCH_SIZE,
        shuffle=False,
    )

    # load model
    model = WaveletCNN().to(DEVICE)
    state = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state)

    # evaluate
    preds, labels = evaluate_model(model, test_loader, DEVICE)

    # optional: print first 5 predictions
    print("First 5 predictions vs labels:")
    for pair in zip(preds[:5], labels[:5]):
        print(*pair)

if __name__ == "__main__":
    main_test("/MUSDB18-test")


## Same Model Architecture Without Using DWT as Inputs

In [None]:
import os
import glob
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import librosa

# -----------------------------
# PARAMETERS
# -----------------------------
# Label vocabulary: position i of every multi-hot label vector refers to STEMS[i].
# NOTE(review): the stem folders created earlier in this notebook are named
# "others" (plural) while this list uses "other" -- keep both spellings in mind
# when folder names are matched against these entries.
STEMS = ["vocals", "drums", "bass", "other"]
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 1       # batch size 1 for variable-length input
EPOCHS = 150         # presumably the epoch count of the training loop further down (not visible here)
LEARNING_RATE = 1e-3  # presumably the optimizer learning rate -- confirm against the training loop
TARGET_SR = 22050     # downsample to 22 kHz

# Loss history for the raw-waveform run, kept separate from the DWT run's lists.
# Presumably filled by the training loop further down -- confirm.
train_losses_withoutDWT = []
val_losses_withoutDWT = []

# -----------------------------
# DATASET
# -----------------------------
def get_file_label_list(root_dir, stems=None):
    """Collect ``.wav`` files and multi-hot labels from the sub-folders of ``root_dir``.

    Each immediate sub-folder name is split on ``"_"`` and every token that
    names a stem switches on the matching label position, e.g. with the
    default stems the folder ``bass_vocals`` yields ``[1, 0, 1, 0]``.
    Files that sit directly in ``root_dir`` and non-directories are ignored.

    The stem folders in this notebook are called ``others`` while ``STEMS``
    spells the class ``other``; both spellings are treated as the same stem,
    otherwise that label position could never be set.

    Args:
        root_dir: Directory whose sub-folders are named after the stems they
            contain, joined by underscores.
        stems: Ordered stem names defining the label layout.  Defaults to the
            module-level ``STEMS``.

    Returns:
        ``(file_paths, labels)``: two parallel lists, where each label is a
        ``float32`` NumPy vector of length ``len(stems)``.  Folders and files
        are visited in sorted order so the result is deterministic.
    """
    if stems is None:
        stems = STEMS

    def canonical(name):
        # fold the plural folder spelling onto the class name
        return "other" if name == "others" else name

    wanted = [canonical(s) for s in stems]

    file_paths, labels = [], []
    for sub in sorted(os.listdir(root_dir)):
        sub_path = os.path.join(root_dir, sub)
        if not os.path.isdir(sub_path):
            continue

        present = {canonical(token) for token in sub.split("_")}
        lab = np.array([1.0 if s in present else 0.0 for s in wanted], dtype=np.float32)

        # escape the folder part so glob metacharacters in a directory name
        # are taken literally
        pattern = os.path.join(glob.escape(sub_path), "*.wav")
        for w in sorted(glob.glob(pattern)):
            file_paths.append(w)
            labels.append(lab.copy())
    return file_paths, labels

class RawWaveformDataset(Dataset):
    """Dataset yielding raw waveforms (resampled to ``target_sr``) and labels.

    No padding or trimming is applied, so the time length of ``x`` varies
    from item to item.
    """

    def __init__(self, file_paths, labels, target_sr=TARGET_SR):
        self.file_paths, self.labels, self.target_sr = file_paths, labels, target_sr

    def __len__(self):
        """Number of audio files in the dataset."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(waveform_tensor, label_tensor)``."""
        path, label = self.file_paths[idx], self.labels[idx]

        # Load at the file's own rate, then resample only when it differs
        wav, native_sr = librosa.load(path, sr=None)
        needs_resampling = native_sr != self.target_sr
        if needs_resampling:
            wav = librosa.resample(wav, orig_sr=native_sr, target_sr=self.target_sr)

        return (
            torch.tensor(wav, dtype=torch.float32).unsqueeze(0),  # shape: (1, T)
            torch.tensor(label, dtype=torch.float32),
        )

# -----------------------------
# MODEL
# -----------------------------
class ResBlock1D(nn.Module):
    """Residual 1-D block: conv-BN-ReLU-conv-BN plus shortcut, ReLU, max-pool(4).

    The shortcut is a 1x1 convolution when the channel count changes and an
    identity mapping otherwise.  The final pooling divides the time length by 4.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        conv_args = (kernel_size, stride, padding)
        self.conv1 = nn.Conv1d(in_channels, out_channels, *conv_args)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(out_channels, out_channels, *conv_args)
        self.bn2 = nn.BatchNorm1d(out_channels)
        if in_channels == out_channels:
            self.downsample = nn.Identity()
        else:
            # 1x1 conv so the shortcut matches the main branch's channel count
            self.downsample = nn.Conv1d(in_channels, out_channels, 1)
        self.pool = nn.MaxPool1d(4)  # aggressive pooling to reduce sequence length

    def forward(self, x):
        shortcut = self.downsample(x)
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        merged = self.relu(branch + shortcut)
        return self.pool(merged)

class RawWaveCNN(nn.Module):
    """Raw-waveform classifier: conv stem, four residual blocks, pooled linear head.

    Args:
        in_channels: channels of the input waveform tensor.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels=1, num_classes=4):
        super().__init__()
        widths = (8, 16, 32, 64, 128)

        self.stem = nn.Conv1d(in_channels, widths[0], kernel_size=7, padding=3)
        self.bn0 = nn.BatchNorm1d(widths[0])
        self.relu = nn.ReLU()
        # Registers layer1..layer4 in order, each going to the next channel width.
        for number, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"layer{number}", ResBlock1D(c_in, c_out))
        # Averaging over the whole time axis makes the head length-independent.
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(widths[-1], num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)`` for input ``x``."""
        feats = self.relu(self.bn0(self.stem(x)))
        for block in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = block(feats)
        pooled = self.global_pool(feats).squeeze(-1)
        return self.fc(pooled)

# -----------------------------
# TRAINING FUNCTIONS
# -----------------------------
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Each batch loss is weighted by its batch size and the sum is divided by
    ``len(loader.dataset)``, giving a per-sample average.
    """
    model.train()
    running = 0.0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_x), batch_y)
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item() * batch_x.size(0)
    n_samples = len(loader.dataset)
    return running / n_samples

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        ``(mean_loss, probabilities, labels)``: the per-sample average loss,
        the sigmoid of every logit, and the matching targets (both
        concatenated along dim 0 on the CPU).
    """
    model.eval()
    running = 0.0
    prob_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)
            running += criterion(logits, batch_y).item() * batch_x.size(0)
            prob_chunks.append(torch.sigmoid(logits).cpu())
            label_chunks.append(batch_y.cpu())
    probabilities = torch.cat(prob_chunks, dim=0)
    targets = torch.cat(label_chunks, dim=0)
    return running / len(loader.dataset), probabilities, targets

# -----------------------------
# TRAINING LOOP
# -----------------------------
def main_training(root_dir, seed=None):
    """Train ``RawWaveCNN`` on the audio found under ``root_dir``.

    Files and labels come from ``get_file_label_list``; 90% are used for
    training and the rest for validation. The weights with the lowest
    validation loss are written to ``best_rawwave_cnn.pth`` and the per-epoch
    losses to ``train_losses_withoutDWT.npy`` / ``val_losses_withoutDWT.npy``.

    Args:
        root_dir: directory passed to ``get_file_label_list``.
        seed: optional seed for the train/validation split only (model
            initialisation and batch order are not affected). With ``None``
            the split uses NumPy's global random state, as before.

    Raises:
        ValueError: if fewer than two audio files are found, because either
            the training or the validation split would be empty.
    """
    # Build dataset
    file_paths, labels = get_file_label_list(root_dir)
    n = len(file_paths)
    if n < 2:
        raise ValueError(
            f"Need at least 2 audio files under {root_dir!r} to build a "
            f"train/validation split, found {n}."
        )

    idxs = np.arange(n)
    if seed is None:
        np.random.shuffle(idxs)
    else:
        np.random.default_rng(seed).shuffle(idxs)
    split = int(n*0.9)
    train_idx, val_idx = idxs[:split], idxs[split:]
    train_files = [file_paths[i] for i in train_idx]
    train_labels = [labels[i] for i in train_idx]
    val_files = [file_paths[i] for i in val_idx]
    val_labels = [labels[i] for i in val_idx]

    train_ds = RawWaveformDataset(train_files, train_labels, target_sr=TARGET_SR)
    val_ds   = RawWaveformDataset(val_files, val_labels, target_sr=TARGET_SR)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    # Model, optimizer, loss
    model = RawWaveCNN().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.BCEWithLogitsLoss()

    # A fresh model is trained on every call, so drop any history left in the
    # module-level lists by a previous run. They are cleared in place so that
    # other references to the same lists stay valid.
    train_losses_withoutDWT.clear()
    val_losses_withoutDWT.clear()

    best_val_loss = float('inf')
    for epoch in range(1, EPOCHS+1):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_loss, _, _ = validate(model, val_loader, criterion, DEVICE)
        print(f"Epoch {epoch}: train_loss={train_loss:.4f} val_loss={val_loss:.4f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best_rawwave_cnn.pth")
            print("Saved best model.")

        # Save losses
        train_losses_withoutDWT.append(train_loss)
        val_losses_withoutDWT.append(val_loss)

    # Save loss history
    np.save("train_losses_withoutDWT.npy", np.array(train_losses_withoutDWT))
    np.save("val_losses_withoutDWT.npy", np.array(val_losses_withoutDWT))

# -----------------------------
# RUN
# -----------------------------
# Cell entry point: train on the folders found under /MUSDB18-train.
# NOTE(review): the stem-extraction cell at the top of the notebook writes to
# /MUSDB18/MUSDB18-train — confirm this root is the intended one.
if __name__ == "__main__":
    main_training("/MUSDB18-train")

In [None]:
import os
import glob
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import librosa
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
# -----------------------------
# PARAMETERS
# -----------------------------
# Label order: position i of every label vector refers to STEMS[i].
# NOTE(review): the stem-extraction cell at the top of the notebook names its
# fourth folder "others", not "other" — confirm folder names match these tokens.
STEMS = ["vocals", "drums", "bass", "other"]
# Use the GPU when one is visible to PyTorch, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 1       # one file per batch: waveforms are not padded to a common length
EPOCHS = 150
LEARNING_RATE = 1e-3
TARGET_SR = 22050     # sample rate (Hz) that RawWaveformDataset resamples to

# Per-epoch loss histories; main_training appends to these module-level lists.
train_losses_withoutDWT = []
val_losses_withoutDWT = []

# -----------------------------
# DATASET
# -----------------------------
def get_file_label_list(root_dir, stems=None):
    """Collect ``.wav`` files under ``root_dir`` with multi-hot stem labels.

    Every immediate sub-folder name is split on ``"_"``; position ``i`` of the
    label is 1.0 when ``stems[i]`` is one of the resulting tokens. All
    ``*.wav`` files directly inside the folder share that label.

    Args:
        root_dir: directory whose sub-folders are named after the stems they
            contain (e.g. ``vocals_drums``).
        stems: stem names defining the label order. Defaults to the
            module-level ``STEMS``.

    Returns:
        ``(file_paths, labels)``: parallel lists of paths and float32 arrays.
        Folders and files are visited in sorted order so the result does not
        depend on filesystem listing order.
    """
    stems = STEMS if stems is None else list(stems)
    # The extraction cell at the top of the notebook names its folder "others"
    # while the stem list uses "other"; treat the two spellings as one stem.
    aliases = {"others": "other"}

    file_paths, labels = [], []
    for sub in sorted(os.listdir(root_dir)):
        sub_path = os.path.join(root_dir, sub)
        if not os.path.isdir(sub_path):
            continue

        tokens = {aliases.get(tok, tok) for tok in sub.split("_")}
        lab = np.array(
            [1.0 if aliases.get(s, s) in tokens else 0.0 for s in stems],
            dtype=np.float32,
        )

        # Escape the directory part so characters such as "[" in a path are
        # not interpreted as glob wildcards.
        pattern = os.path.join(glob.escape(sub_path), "*.wav")
        for w in sorted(glob.glob(pattern)):
            file_paths.append(w)
            labels.append(lab.copy())
    return file_paths, labels

class RawWaveformDataset(Dataset):
    """Dataset that yields ``(waveform, label)`` pairs read lazily from disk.

    Args:
        file_paths: sequence of audio file paths.
        labels: sequence of label vectors, aligned with ``file_paths``.
        target_sr: sample rate every waveform is brought to before it is
            returned.
    """

    def __init__(self, file_paths, labels, target_sr=TARGET_SR):
        self.file_paths, self.labels, self.target_sr = file_paths, labels, target_sr

    def __len__(self):
        """Number of audio files in the dataset."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(x, y)`` as float32 tensors.

        ``x`` has a leading channel axis added (shape ``(1, T)``); ``y`` is the
        stored label converted to a tensor.
        """
        audio_path = self.file_paths[idx]
        target = self.labels[idx]

        # Read at the file's own rate, then resample only when it differs
        # from the requested rate.
        samples, native_sr = librosa.load(audio_path, sr=None)
        if native_sr != self.target_sr:
            samples = librosa.resample(samples, orig_sr=native_sr, target_sr=self.target_sr)

        waveform = torch.tensor(samples, dtype=torch.float32).unsqueeze(0)
        return waveform, torch.tensor(target, dtype=torch.float32)

# -----------------------------
# MODEL
# -----------------------------
class ResBlock1D(nn.Module):
    """1-D residual block: two conv+BN stages, a shortcut, ReLU, then max-pool.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the block.
        kernel_size: kernel size of both main-path convolutions.
        stride: stride of the first convolution (and of the shortcut
            projection, so both paths shrink by the same factor).
        padding: padding of both main-path convolutions.

    With the default ``stride=1`` the layers (names, shapes, creation order)
    are the same as before, so existing checkpoints still load.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        # Only the first conv strides. Striding here as well would shrink the
        # main path twice and make it impossible to add the shortcut.
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, 1, padding)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # The shortcut needs a 1x1 projection whenever the main path changes
        # the channel count or the sequence length.
        if in_channels != out_channels or stride != 1:
            self.downsample = nn.Conv1d(in_channels, out_channels, 1, stride)
        else:
            self.downsample = nn.Identity()
        self.pool = nn.MaxPool1d(4)  # aggressive pooling to reduce sequence length

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, in_channels, length)``."""
        identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        out = self.pool(out)
        return out

class RawWaveCNN(nn.Module):
    """Raw-waveform classifier: conv stem, four residual blocks, pooled linear head.

    Args:
        in_channels: channels of the input waveform tensor.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels=1, num_classes=4):
        super().__init__()
        widths = (8, 16, 32, 64, 128)

        self.stem = nn.Conv1d(in_channels, widths[0], kernel_size=7, padding=3)
        self.bn0 = nn.BatchNorm1d(widths[0])
        self.relu = nn.ReLU()
        # Registers layer1..layer4 in order, each going to the next channel width.
        for number, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"layer{number}", ResBlock1D(c_in, c_out))
        # Averaging over the whole time axis makes the head length-independent.
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(widths[-1], num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)`` for input ``x``."""
        feats = self.relu(self.bn0(self.stem(x)))
        for block in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = block(feats)
        pooled = self.global_pool(feats).squeeze(-1)
        return self.fc(pooled)

# -----------------------------
# TEST / EVALUATION
# -----------------------------
def evaluate_model(model, dataloader, device, threshold=0.5):
    """Threshold the model's sigmoid outputs and print multi-label metrics.

    Args:
        model: network returning one logit per class.
        dataloader: iterable of ``(x, y)`` batches.
        device: device the batches are moved to.
        threshold: a class is predicted when its probability exceeds this.

    Returns:
        ``(predictions, labels)`` as NumPy arrays stacked over all batches.
    """
    model.eval()
    pred_chunks = []
    label_chunks = []
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            probabilities = torch.sigmoid(model(batch_x))
            pred_chunks.append((probabilities > threshold).float().cpu())
            label_chunks.append(batch_y.cpu())
    all_preds = torch.cat(pred_chunks, dim=0).numpy()
    all_labels = torch.cat(label_chunks, dim=0).numpy()

    # Precision, recall and F1 are averaged per sample; a sample with an
    # undefined score counts as 0.
    per_sample = dict(average='samples', zero_division=0)
    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, **per_sample)
    rec = recall_score(all_labels, all_preds, **per_sample)
    f1 = f1_score(all_labels, all_preds, **per_sample)
    print(f"Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1-score: {f1:.4f}")
    return all_preds, all_labels

def main_test(test_root_dir, model_path="best_rawwave_cnn.pth"):
    """Load saved ``RawWaveCNN`` weights and evaluate them on ``test_root_dir``.

    Prints the metrics reported by ``evaluate_model`` followed by the first
    five prediction/label rows.
    """
    test_files, test_targets = get_file_label_list(test_root_dir)
    test_loader = DataLoader(
        RawWaveformDataset(test_files, test_targets),
        batch_size=BATCH_SIZE,
        shuffle=False,
    )

    net = RawWaveCNN().to(DEVICE)
    state = torch.load(model_path, map_location=DEVICE)
    net.load_state_dict(state)

    pred_rows, label_rows = evaluate_model(net, test_loader, DEVICE)
    print("First 5 predictions vs labels:")
    for pred_row, label_row in zip(pred_rows[:5], label_rows[:5]):
        print(pred_row, label_row)


# -----------------------------
# RUN
# -----------------------------
if __name__ == "__main__":
    # Evaluate the saved best checkpoint (main_test's default model_path) on
    # the folders under /MUSDB18-test.
    main_test("/MUSDB18-test")


## Evaluating the Models with AWGN (Additive White Gaussian Noise)

In [None]:
import numpy as np

def add_awgn(x, snr_db, rng=None):
    """Return ``x`` plus white Gaussian noise scaled to a target SNR.

    Args:
        x: audio samples (array or any sequence ``np.asarray`` accepts).
        snr_db: target signal-to-noise ratio in dB.
        rng: optional random source exposing ``normal(loc, scale, size)``,
            e.g. ``np.random.default_rng(seed)``, for reproducible noise.
            With ``None`` the global ``np.random`` state is used, as before.

    Returns:
        Array with the shape of ``x`` holding the noisy signal. The noise is
        generated in float64, so NumPy promotes the sum accordingly.
    """
    x = np.asarray(x)

    # Signal power, computed in float64 so that integer PCM input cannot
    # overflow when squared.
    sig_power = np.mean(x.astype(np.float64) ** 2)

    # Noise power
    snr_linear = 10 ** (snr_db / 10)
    noise_power = sig_power / snr_linear

    # Generate noise
    draw = np.random.normal if rng is None else rng.normal
    noise = draw(0, np.sqrt(noise_power), x.shape)

    return x + noise

In [None]:
import torch
import numpy as np
import librosa
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Device name (a plain string here) that the predict_* helpers in this cell
# move their input tensors to: "cuda" when PyTorch sees a GPU, else "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ----------------------------
# Add AWGN noise function
# ----------------------------
def add_awgn(x, snr_db, rng=None):
    """Return ``x`` plus white Gaussian noise scaled to a target SNR.

    Args:
        x: audio samples (array or any sequence ``np.asarray`` accepts).
        snr_db: target signal-to-noise ratio in dB.
        rng: optional random source exposing ``normal(loc, scale, size)``,
            e.g. ``np.random.default_rng(seed)``, for reproducible noise.
            With ``None`` the global ``np.random`` state is used, as before.

    Returns:
        Array with the shape of ``x`` holding the noisy signal. The noise is
        generated in float64, so NumPy promotes the sum accordingly.
    """
    x = np.asarray(x)
    # Power in float64 so that integer PCM input cannot overflow when squared.
    sig_power = np.mean(x.astype(np.float64) ** 2)
    snr_linear = 10 ** (snr_db / 10)
    noise_power = sig_power / snr_linear
    draw = np.random.normal if rng is None else rng.normal
    noise = draw(0, np.sqrt(noise_power), x.shape)
    return x + noise

# ----------------------------
# Get predictions (WaveRCNN)
# ----------------------------
def predict_wavercnn(model, audio, sr=44100, max_len=300032):
    """Return the index of the largest logit the model gives for ``audio``.

    The waveform is zero-padded at the end or truncated to exactly ``max_len``
    samples, wrapped as a ``(1, 1, max_len)`` float32 tensor and moved to
    ``DEVICE``. ``sr`` is accepted but not used by this function.
    """
    # Fix the input length: pad when short, cut when long.
    shortfall = max_len - len(audio)
    if shortfall > 0:
        audio = np.pad(audio, (0, shortfall))
    else:
        audio = audio[:max_len]

    batch = torch.tensor(audio, dtype=torch.float32)[None, None]
    batch = batch.to(DEVICE)

    with torch.no_grad():
        scores = model(batch)
    return scores.argmax(dim=1).item()

# ----------------------------
# Get predictions (DWT-CNN)
# ----------------------------
def compute_dwt(audio, wavelet, level):
    """Flatten a multilevel DWT of ``audio`` into a single vector.

    Runs ``pywt.wavedec`` with the given wavelet and level and joins all
    returned coefficient arrays, in the order returned, along the last axis.
    """
    import pywt
    bands = pywt.wavedec(audio, wavelet=wavelet, level=level)
    return np.concatenate(list(bands), axis=-1)

def predict_dwtcnn(model, audio, wavelet, level, target_len=4096):
    """Return the index of the largest logit the model gives for the DWT of ``audio``.

    Args:
        model: network fed a ``(1, 1, target_len)`` float32 tensor.
        audio: 1-D waveform passed to ``compute_dwt``.
        wavelet: wavelet name forwarded to ``compute_dwt``.
        level: decomposition level forwarded to ``compute_dwt``.
        target_len: length the coefficient vector is resized to before it is
            given to the model (default 4096, the value previously hard-coded).

    Returns:
        Predicted class index as a Python int.
    """
    dwt_vec = compute_dwt(audio, wavelet, level)

    x = torch.tensor(dwt_vec, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    # Resize to the model's input length using interpolate's default mode.
    x = torch.nn.functional.interpolate(x, size=target_len)
    x = x.to(DEVICE)

    with torch.no_grad():
        logits = model(x)
    return torch.argmax(logits, dim=1).item()

# ----------------------------
# Noisy evaluation loop
# ----------------------------
def test_under_noise(wavercnn_model, dwt_model, test_files, test_labels,
                     wavelet="db4", level=4, snr_levels=(20, 30, 40)):
    """Score both models on AWGN-corrupted test audio at several SNRs.

    Every file is loaded at 44.1 kHz, corrupted once per SNR with
    ``add_awgn``, and the same noisy signal is given to both models.

    Args:
        wavercnn_model: model used by ``predict_wavercnn``.
        dwt_model: model used by ``predict_dwtcnn``.
        test_files: audio file paths.
        test_labels: one label per file, either an integer class index or a
            one-hot/multi-hot vector (reduced to the index of its maximum).
        wavelet: wavelet name for the DWT branch.
        level: decomposition level for the DWT branch.
        snr_levels: SNR values in dB to evaluate; defaults to the values that
            were previously hard-coded.

    Returns:
        ``{snr: {"WaveRCNN": {...}, "DWT-CNN": {...}}}`` where each inner dict
        holds ``Accuracy``, ``Precision``, ``Recall`` and ``F1`` (macro average).

    NOTE(review): ``np.argmax`` keeps only the first active stem of a multi-hot
    label — confirm the test folders are single-stem. Audio is loaded at
    44.1 kHz here, whereas ``RawWaveformDataset`` in this notebook resamples
    to ``TARGET_SR`` — confirm the models expect this rate.
    """
    results = {}

    for snr in snr_levels:
        y_true = []
        y_pred_wave = []
        y_pred_dwt = []

        for audio_path, label in zip(test_files, test_labels):
            audio, _sr = librosa.load(audio_path, sr=44100)

            noisy_audio = add_awgn(audio, snr)

            # WaveRCNN prediction
            y_pred_wave.append(predict_wavercnn(wavercnn_model, noisy_audio))

            # DWT-CNN prediction
            y_pred_dwt.append(predict_dwtcnn(dwt_model, noisy_audio, wavelet, level))

            # The predictions are class indices, so vector labels must be
            # reduced to an index too. Without this, sklearn rejects the mix
            # of multilabel and multiclass targets.
            if isinstance(label, (np.ndarray, list)):
                y_true.append(int(np.argmax(label)))
            else:
                y_true.append(label)

        # compute metrics
        acc_w = accuracy_score(y_true, y_pred_wave)
        acc_d = accuracy_score(y_true, y_pred_dwt)

        # zero_division=0 scores a class that is never predicted as 0 without
        # emitting a warning, matching evaluate_model.
        p_w, r_w, f_w, _ = precision_recall_fscore_support(
            y_true, y_pred_wave, average='macro', zero_division=0)
        p_d, r_d, f_d, _ = precision_recall_fscore_support(
            y_true, y_pred_dwt, average='macro', zero_division=0)

        results[snr] = {
            "WaveRCNN": {
                "Accuracy": acc_w,
                "Precision": p_w,
                "Recall": r_w,
                "F1": f_w,
            },
            "DWT-CNN": {
                "Accuracy": acc_d,
                "Precision": p_d,
                "Recall": r_d,
                "F1": f_d,
            }
        }

    return results


In [None]:
import os
import numpy as np
import torch
import librosa
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# ----------------------------
# PARAMETERS
# ----------------------------
# Label order: position i of every label vector refers to STEMS[i].
# NOTE(review): the stem-extraction cell at the top of the notebook names its
# fourth folder "others", not "other" — confirm folder names match these tokens.
STEMS = ["vocals", "drums", "bass", "other"]
# Use the GPU when one is visible to PyTorch, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 300032  # number of samples predict_wavercnn pads or truncates each waveform to

# ----------------------------
# Add AWGN noise
# ----------------------------
def add_awgn(x, snr_db, rng=None):
    """Return ``x`` plus white Gaussian noise scaled to a target SNR.

    Args:
        x: audio samples (array or any sequence ``np.asarray`` accepts).
        snr_db: target signal-to-noise ratio in dB.
        rng: optional random source exposing ``normal(loc, scale, size)``,
            e.g. ``np.random.default_rng(seed)``, for reproducible noise.
            With ``None`` the global ``np.random`` state is used, as before.

    Returns:
        Array with the shape of ``x`` holding the noisy signal. The noise is
        generated in float64, so NumPy promotes the sum accordingly.
    """
    x = np.asarray(x)
    # Power in float64 so that integer PCM input cannot overflow when squared.
    sig_power = np.mean(x.astype(np.float64) ** 2)
    snr_linear = 10 ** (snr_db / 10)
    noise_power = sig_power / snr_linear
    draw = np.random.normal if rng is None else rng.normal
    noise = draw(0, np.sqrt(noise_power), x.shape)
    return x + noise

# ----------------------------
# DWT helper
# ----------------------------
def compute_dwt(audio, wavelet, level):
    """Flatten a multilevel DWT of ``audio`` into a single vector.

    Runs ``pywt.wavedec`` with the given wavelet and level and joins all
    returned coefficient arrays, in the order returned, along the last axis.
    """
    import pywt
    bands = pywt.wavedec(audio, wavelet=wavelet, level=level)
    return np.concatenate(list(bands), axis=-1)

# ----------------------------
# Predictions
# ----------------------------
def predict_wavercnn(model, audio, max_len=MAX_LEN):
    """Return the index of the largest logit the model gives for ``audio``.

    The waveform is zero-padded at the end or truncated to exactly ``max_len``
    samples, wrapped as a ``(1, 1, max_len)`` float32 tensor and moved to
    ``DEVICE`` before the forward pass.
    """
    # Fix the input length: pad when short, cut when long.
    shortfall = max_len - len(audio)
    if shortfall > 0:
        audio = np.pad(audio, (0, shortfall))
    else:
        audio = audio[:max_len]
    batch = torch.tensor(audio, dtype=torch.float32)[None, None].to(DEVICE)
    with torch.no_grad():
        scores = model(batch)
    return scores.argmax(dim=1).item()

def predict_dwtcnn(model, audio, wavelet, level, target_len=4096):
    """Return the index of the largest logit the model gives for the DWT of ``audio``.

    The flattened coefficients from ``compute_dwt`` are wrapped as a
    ``(1, 1, N)`` float32 tensor, resized to ``target_len`` with
    ``interpolate`` (default mode), and moved to ``DEVICE``.
    """
    coeff_vec = compute_dwt(audio, wavelet, level)
    features = torch.tensor(coeff_vec, dtype=torch.float32)[None, None]
    resized = torch.nn.functional.interpolate(features, size=target_len)
    resized = resized.to(DEVICE)
    with torch.no_grad():
        scores = model(resized)
    return scores.argmax(dim=1).item()

# ----------------------------
# Load test files & labels
# ----------------------------
def get_file_label_list(root_dir, stems=None):
    """Collect wav files under ``root_dir`` with multi-hot stem labels.

    Every immediate sub-folder name is split on ``"_"``; position ``i`` of the
    label is 1.0 when ``stems[i]`` is one of the resulting tokens. All wav
    files that ``librosa.util.find_files`` reports for the folder share that
    label.

    Args:
        root_dir: directory whose sub-folders are named after the stems they
            contain (e.g. ``vocals_drums``).
        stems: stem names defining the label order. Defaults to the
            module-level ``STEMS``.

    Returns:
        ``(file_paths, labels)``: parallel lists of paths and float32 arrays.
        Folders are visited in sorted order so the result does not depend on
        filesystem listing order.
    """
    stems = STEMS if stems is None else list(stems)
    # The extraction cell at the top of the notebook names its folder "others"
    # while the stem list uses "other"; treat the two spellings as one stem.
    aliases = {"others": "other"}

    file_paths, labels = [], []
    for sub in sorted(os.listdir(root_dir)):
        sub_path = os.path.join(root_dir, sub)
        if not os.path.isdir(sub_path):
            continue
        tokens = {aliases.get(tok, tok) for tok in sub.split("_")}
        lab = np.array(
            [1.0 if aliases.get(s, s) in tokens else 0.0 for s in stems],
            dtype=np.float32,
        )
        wav_files = librosa.util.find_files(sub_path, ext='wav')
        for w in wav_files:
            file_paths.append(w)
            labels.append(lab.copy())
    return file_paths, labels

# ----------------------------
# Test under noise
# ----------------------------
def test_under_noise(wavercnn_model, dwt_model, test_files, test_labels,
                     wavelet="db4", level=4, snr_levels=(0, 5, 10)):
    """Score both models on AWGN-corrupted test audio at several SNRs.

    Every file is loaded at 44.1 kHz, corrupted once per SNR with
    ``add_awgn``, and the same noisy signal is given to both models.

    Args:
        wavercnn_model: model used by ``predict_wavercnn``.
        dwt_model: model used by ``predict_dwtcnn``.
        test_files: audio file paths.
        test_labels: one label per file, either an integer class index or a
            one-hot/multi-hot vector (reduced to the index of its maximum).
        wavelet: wavelet name for the DWT branch.
        level: decomposition level for the DWT branch.
        snr_levels: SNR values in dB to evaluate; defaults to the values that
            were previously hard-coded.

    Returns:
        ``{snr: {"WaveRCNN": {...}, "DWT-CNN": {...}}}`` where each inner dict
        holds ``Accuracy``, ``Precision``, ``Recall`` and ``F1`` (macro average).

    NOTE(review): ``np.argmax`` keeps only the first active stem of a multi-hot
    label — confirm the test folders are single-stem. Audio is loaded at
    44.1 kHz here, whereas ``RawWaveformDataset`` in this notebook resamples
    to ``TARGET_SR`` — confirm the models expect this rate.
    """
    results = {}

    for snr in snr_levels:
        y_true, y_pred_wave, y_pred_dwt = [], [], []

        for audio_path, label in zip(test_files, test_labels):
            audio, _sr = librosa.load(audio_path, sr=44100)
            noisy_audio = add_awgn(audio, snr)

            # WaveRCNN prediction
            y_pred_wave.append(predict_wavercnn(wavercnn_model, noisy_audio))

            # DWT-CNN prediction
            y_pred_dwt.append(predict_dwtcnn(dwt_model, noisy_audio, wavelet, level))

            # Predictions are class indices, so reduce vector labels to one.
            if isinstance(label, (np.ndarray, list)):
                y_true.append(int(np.argmax(label)))
            else:
                y_true.append(label)

        # compute metrics
        acc_w = accuracy_score(y_true, y_pred_wave)
        acc_d = accuracy_score(y_true, y_pred_dwt)

        # zero_division=0 scores a class that is never predicted as 0 without
        # emitting a warning, matching evaluate_model.
        p_w, r_w, f_w, _ = precision_recall_fscore_support(
            y_true, y_pred_wave, average='macro', zero_division=0)
        p_d, r_d, f_d, _ = precision_recall_fscore_support(
            y_true, y_pred_dwt, average='macro', zero_division=0)

        results[snr] = {
            "WaveRCNN": {"Accuracy": acc_w, "Precision": p_w, "Recall": r_w, "F1": f_w},
            "DWT-CNN": {"Accuracy": acc_d, "Precision": p_d, "Recall": r_d, "F1": f_d}
        }

    return results

# ----------------------------
# MAIN RUN
# ----------------------------
if __name__ == "__main__":
    from model_dwtcnn import WaveletCNN  
    from model_wavercnn import RawWaveCNN

    # Restore the raw-waveform classifier from its checkpoint.
    wavercnn_model = RawWaveCNN().to(DEVICE)
    wavercnn_state = torch.load("best_rawwave_cnn.pth", map_location=DEVICE)
    wavercnn_model.load_state_dict(wavercnn_state)
    wavercnn_model.eval()

    # Restore the wavelet-feature classifier from its checkpoint.
    dwt_model = WaveletCNN().to(DEVICE)
    dwt_state = torch.load("best_wavelet_cnn.pth", map_location=DEVICE)
    dwt_model.load_state_dict(dwt_state)
    dwt_model.eval()

    # Gather the test set and score both models at each noise level.
    test_files, test_labels = get_file_label_list("/MUSDB18-test")
    results = test_under_noise(wavercnn_model, dwt_model, test_files, test_labels,
                               wavelet="db4", level=4)

    # Report: one banner per SNR, then the four metrics for each model.
    for snr, metrics in results.items():
        print(f"\n==================== SNR = {snr} dB ====================")
        for banner, model_key in (("---- WaveRCNN ----", "WaveRCNN"),
                                  ("\n---- DWT-CNN ----", "DWT-CNN")):
            scores = metrics[model_key]
            print(banner)
            print(f"Accuracy : {scores['Accuracy']:.4f}")
            print(f"Precision: {scores['Precision']:.4f}")
            print(f"Recall   : {scores['Recall']:.4f}")
            print(f"F1 Score : {scores['F1']:.4f}")


## Visualization ##

In [None]:
import musdb
import pywt
import matplotlib.pyplot as plt
import numpy as np

# Load the training subset and use its first track as the example.
mus_train = musdb.DB(root="/MUSDB18/MUSDB18-7", subsets="train")
track = mus_train.tracks[0]

# Convert to mono by averaging over axis 1.
mixture = track.audio.mean(axis=1)

# Wavelet settings
wavelet = 'db4'
max_level = 5

# Cascade of single-level DWTs: each pass decomposes the previous
# approximation, so approxs[k] holds the level-(k+1) approximation
# coefficients (level 1 first, level max_level last).
approxs = []
current = mixture.copy()
for level in range(1, max_level + 1):
    c = pywt.wavedec(current, wavelet, level=1)
    approxs.append(c[0])  # approximation at this level
    current = c[0]        # go one level deeper

# Full DWT for details. wavedec returns [cA_n, cD_n, ..., cD_1], so coeffs[i]
# is the detail band of level (max_level - i + 1). Each band is reconstructed
# on its own by passing None for every other coefficient.
coeffs = pywt.wavedec(mixture, wavelet, level=max_level)
details = [pywt.waverec([None]+[coeffs[i]] + [None]*(max_level-i), wavelet) for i in range(1, max_level+1)]

# Plotting
plt.figure(figsize=(15, 12))

# Original mixture
plt.subplot(max_level*2 + 1, 1, 1)
plt.plot(mixture)
plt.title("Original Mixture")
plt.xlabel("Sample Index")
plt.ylabel("Amplitude")

# Approximations: approxs[i] is level i+1 (shallowest first).
for i, a in enumerate(approxs):
    plt.subplot(max_level*2 + 1, 1, i+2)
    plt.plot(a)
    plt.title(f"Approximation Level {i+1}")
    plt.xlabel("Sample Index")
    plt.ylabel("Amplitude")

# Details: details[i] comes from coeffs[i+1], i.e. level max_level-i
# (deepest first).
for i, d in enumerate(details):
    plt.subplot(max_level*2 + 1, 1, max_level+2 + i)
    plt.plot(d)
    plt.title(f"Detail Level {max_level-i}")
    plt.xlabel("Sample Index")
    plt.ylabel("Amplitude")

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Read the saved per-epoch loss histories.
train_losses = np.load('train_losses.npy')  # replace with your training loss file if different
val_losses = np.load('val_losses.npy')

# Epoch axis, numbered from 1 and sized from the validation history.
# (Both curves are plotted against it, so the two files must be equally long.)
epochs = range(1, len(val_losses) + 1)

# Draw both curves on a single set of axes.
fig, ax = plt.subplots(figsize=(8,6))
ax.plot(epochs, train_losses, 'b-', label='Training Loss')
ax.plot(epochs, val_losses, 'r-', label='Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss for DWT Features')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Read the saved per-epoch loss histories.
train_losses = np.load('train_losses_withoutDWT.npy')  # replace with your training loss file if different
val_losses = np.load('val_losses_withoutDWT.npy')

# Epoch axis, numbered from 1 and sized from the validation history.
# (Both curves are plotted against it, so the two files must be equally long.)
epochs = range(1, len(val_losses) + 1)

# Draw both curves on a single set of axes.
fig, ax = plt.subplots(figsize=(8,6))
ax.plot(epochs, train_losses, 'b-', label='Training Loss')
ax.plot(epochs, val_losses, 'r-', label='Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss for RAW Features')
ax.legend()
ax.grid(True)
plt.show()
