# Model A: Time-Frequency Domain U-Net for Music Source Separation

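This notebook demonstrates music source separation using a U-Net architecture in the time-frequency domain: the network predicts a mask over the mixture spectrogram to extract the "other" stem from MUSDB18 songs, trained with a two-stage curriculum.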

## 1. Imports and Environment Setup

- Import required libraries (torch, numpy, matplotlib, etc.)

- Set device (CPU/GPU)

In [1]:
#%matplotlib inline
#%load_ext autoreload
#%autoreload 2

import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import os
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from IPython.display import Audio, display

sys.path.append('..')

from models import utils, model_A as ma

def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and all CUDA devices), and puts cuDNN into deterministic mode so
    repeated runs produce the same results.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # local import: the notebook's import cell does not bring `random` in

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN pick kernels by timing them, which can vary between runs.
    torch.backends.cudnn.benchmark = False
set_seed(42)

# Device setup: default to the CPU and upgrade to CUDA when a GPU is visible.
# Later cells read this global `device`.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Using device: {device}")

Using device: cpu


In [2]:
import os

# Training progress is appended to this file (relative to the notebook's working
# directory); start each session with an empty log by deleting any previous one.
LOG_FILE_PATH = 'training_progress.log'

if os.path.exists(LOG_FILE_PATH):
    try:
        os.remove(LOG_FILE_PATH)
    except Exception as err:
        # Best effort: a locked or read-only log should not stop the notebook.
        print(f'Error clearing old log file {LOG_FILE_PATH}: {err}')
    else:
        print(f'Cleared old log file: {LOG_FILE_PATH}')

## 2. Data Loading and Preprocessing

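- Load example mixture and target data (from .npy files): build or reuse the two-stage curriculum cache from MUSDB18, then collect the paired mixture/target files for each stage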

- Visualize waveforms and spectrograms

- Normalize or preprocess as needed

In [3]:
%matplotlib inline 

import sys
import numpy as np
import musdb
import matplotlib.pyplot as plt
from pathlib import Path

sys.path.append('..') 
from models import utils
from models.utils import AudioProcessor

# ==============================================================================
# 1. PREPARE CURRICULUM CACHE
# ==============================================================================
CURRICULUM_DIR = "../data/curriculum"

print("Checking Data Cache...")
# NOTE(review): per the musdb docs, download=True uses the 7-second preview
# version of MUSDB18 rather than the full dataset — confirm that is intended.
mus = musdb.DB(download=True)
# Writes the Stage 1 / Stage 2 mixture/target .npy files under CURRICULUM_DIR;
# generation is skipped when the cache already exists.
utils.prepare_curriculum_cache(mus, cache_dir=CURRICULUM_DIR)

# ==============================================================================
# 2. LOAD DATA PATHS FOR BOTH STAGES
# ==============================================================================
data_root = Path(CURRICULUM_DIR)
s1_mix_path = data_root / "stage1" / "mixture"
s1_tgt_path = data_root / "stage1" / "target"
s2_mix_path = data_root / "stage2" / "mixture"
s2_tgt_path = data_root / "stage2" / "target"

mix_files_stage1 = sorted(s1_mix_path.glob("*.npy"))
tgt_files_stage1 = sorted(s1_tgt_path.glob("*.npy"))
mix_files_stage2 = sorted(s2_mix_path.glob("*.npy"))
tgt_files_stage2 = sorted(s2_tgt_path.glob("*.npy"))


def _check_paired(stage_name, mixes, targets):
    """Raise ValueError if a stage's mixture and target lists differ in length.

    Files are paired purely by position after sorting, so a count mismatch
    would silently pair later mixtures with the wrong targets.
    """
    if len(mixes) != len(targets):
        raise ValueError(
            f"{stage_name}: found {len(mixes)} mixture files but {len(targets)} "
            "target files in the curriculum cache; pairs would be misaligned."
        )


_check_paired("Stage 1", mix_files_stage1, tgt_files_stage1)
_check_paired("Stage 2", mix_files_stage2, tgt_files_stage2)

print("\n Data Ready!")
print(f"   Stage 1 Samples: {len(mix_files_stage1)}")
print(f"   Stage 2 Samples: {len(mix_files_stage2)}")

Checking Data Cache...
Cache found at ..\data\curriculum. Skipping generation.

 Data Ready!
   Stage 1 Samples: 144
   Stage 2 Samples: 144


## 3. Model Architecture

- Show the U-Net model structure (the printed module tree of the default architecture)

- Print model summary

In [4]:
from models import model_A as ma

# Model summary (default architecture): build a throwaway instance purely to print
# its module tree. It is never bound to a name, so nothing lingers after this cell.
print(
    ma.TimeFrequencyDomainUNet(
        in_channels=1,
        out_channels=1,
        base_filters=64,
        num_layers=4,
    ).to(device)
)

TimeFrequencyDomainUNet(
  (encoders): ModuleList(
    (0): EncoderBlock(
      (block): Sequential(
        (0): ConvLayer2D(
          (block): Sequential(
            (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
        )
        (1): ConvLayer2D(
          (block): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
        )
      )
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): EncoderBlock(
      (block): Sequential(
        (0): ConvLayer2D(
          (block): Sequential(
            (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(128, eps=1e

## 4. Training Setup

In [5]:
#%matplotlib inline
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch
import torch.nn as nn
import torch.optim as optim
from models import model_A as ma
from models import utils

# --- Configurable model and training parameters ---
# These can be set externally (e.g., from main.ipynb) by injecting a config dict
# into this notebook's globals before the cell runs. An injected dict is merged
# over the defaults below, so it only needs the keys it wants to override.
def _config_with_defaults(name, defaults):
    """Return a new dict of `defaults` updated with the global dict `name`, if one is bound."""
    return {**defaults, **(globals().get(name) or {})}


MODEL_CONFIG = _config_with_defaults('MODEL_CONFIG', {
    'in_channels': 1,
    'out_channels': 1,
    'base_filters': 64,
    'num_layers': 4,
    'batchnorm': True,
    'dropout': 0.1,
})
TRAIN_CONFIG = _config_with_defaults('TRAIN_CONFIG', {
    'num_epochs': 50,
    'learning_rate': 1e-4,
    'patience': 10000,  # effectively disables early stopping
    'batch_size': 2,
})
OVERFIT_CONFIG = _config_with_defaults('OVERFIT_CONFIG', {
    'base_filters': 128,
    'num_layers': 4,
    'batchnorm': True,
    'dropout': 0.0,
    'learning_rate': 3e-4,
    'num_epochs': 100,
    'batch_size': 2,
    'patience': 10000,
})

# Device setup (an already-defined device, e.g. from an injecting notebook, is kept)
if 'device' not in globals():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
if 'overfit_device' not in globals():
    overfit_device = 'cuda' if torch.cuda.is_available() else 'cpu'

processor = utils.AudioProcessor(device=device)
overfit_processor = utils.AudioProcessor(device=overfit_device)


def _build_unet(arch_config, target_device):
    """Instantiate the U-Net on `target_device`.

    Channel counts always come from MODEL_CONFIG. base_filters, num_layers,
    batchnorm and dropout come from `arch_config`.
    """
    return ma.TimeFrequencyDomainUNet(
        in_channels=MODEL_CONFIG['in_channels'],
        out_channels=MODEL_CONFIG['out_channels'],
        base_filters=arch_config['base_filters'],
        num_layers=arch_config['num_layers'],
        batchnorm=arch_config['batchnorm'],
        dropout=arch_config['dropout'],
    ).to(target_device)


# Instantiate models (same order as before, so seeded weight init is unchanged)
model = _build_unet(MODEL_CONFIG, device)
overfit_model = _build_unet(OVERFIT_CONFIG, overfit_device)

# Optimizers and losses
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=TRAIN_CONFIG['learning_rate'])
overfit_loss_fn = nn.MSELoss()
overfit_optimizer = optim.Adam(overfit_model.parameters(), lr=OVERFIT_CONFIG['learning_rate'])

print(f"Setup Complete:")
print(f"   - Device: {device}")
print(f"   - Model: {model.__class__.__name__} (Filters={MODEL_CONFIG['base_filters']}, Layers={MODEL_CONFIG['num_layers']})")
print(f"   - Overfit Model: {overfit_model.__class__.__name__} (Filters={OVERFIT_CONFIG['base_filters']}, Layers={OVERFIT_CONFIG['num_layers']})")
print(f"   - Loss Function: {loss_fn.__class__.__name__}")
print(f"   - Optimizer: {optimizer.__class__.__name__} (lr={TRAIN_CONFIG['learning_rate']})")
print(f"   - Overfit Optimizer: {overfit_optimizer.__class__.__name__} (lr={OVERFIT_CONFIG['learning_rate']})")

Setup Complete:
   - Device: cpu
   - Model: TimeFrequencyDomainUNet (Filters=64, Layers=4)
   - Overfit Model: TimeFrequencyDomainUNet (Filters=128, Layers=4)
   - Loss Function: MSELoss
   - Optimizer: Adam (lr=0.0001)
   - Overfit Optimizer: Adam (lr=0.0003)


## 5. Training Loop

### 5.a Overfit Method

- Train a separate overfit model (more filters, no dropout) on a very small subset — the first 5 Stage 1 songs — to confirm that it can overfit and that the implementation is correct.


In [6]:
from torch.utils.data import DataLoader
from pathlib import Path
import numpy as np
import os
import torch
os.makedirs("../checkpoints", exist_ok=True)

N_OVERFIT_SONGS = 5  # size of the tiny subset the overfit model should memorise

print(f"Starting Overfit Test on {N_OVERFIT_SONGS} Songs...")

# Prepare small dataset for overfitting: the first N Stage 1 pairs.
# Mixtures and targets are paired by sorted position, so verify the full lists
# line up before slicing.
s1_root = Path("../data/curriculum/stage1")
all_mix_s1 = sorted((s1_root / "mixture").glob("*.npy"))
all_tgt_s1 = sorted((s1_root / "target").glob("*.npy"))
if len(all_mix_s1) != len(all_tgt_s1):
    raise ValueError(
        f"Stage 1 cache has {len(all_mix_s1)} mixture files but {len(all_tgt_s1)} "
        "target files; pairs would be misaligned."
    )
mix_files = all_mix_s1[:N_OVERFIT_SONGS]
tgt_files = all_tgt_s1[:N_OVERFIT_SONGS]

if len(mix_files) < 1 or len(tgt_files) < 1:
    raise ValueError("Not enough data for overfit test. Please check your dataset.")

tiny_ds = utils.StandardDataset(mix_files, tgt_files)
tiny_loader = DataLoader(tiny_ds, batch_size=OVERFIT_CONFIG['batch_size'], shuffle=False)

# Training and validation deliberately use the same tiny loader: the goal is memorisation.
trainer_overfit = utils.UniversalTrainer(
    model=overfit_model,
    train_loader=tiny_loader,
    val_loader=tiny_loader,
    processor=overfit_processor,
    optimizer=overfit_optimizer,
    loss_fn=overfit_loss_fn,
    device=overfit_device,
    patience=OVERFIT_CONFIG['patience']
)

save_path = f"../checkpoints/debug_overfit_{N_OVERFIT_SONGS}songs.pth"
history = {}

if not os.path.exists(save_path):
    print(f"   Training from scratch on {N_OVERFIT_SONGS} songs...")
    history = trainer_overfit.train(num_epochs=OVERFIT_CONFIG['num_epochs'], save_path=save_path)
else:
    print(f"✓ Found Checkpoint: {save_path}")
    # Map onto the overfit model's own device, which may differ from `device`.
    ckpt = torch.load(save_path, map_location=overfit_device)
    overfit_model.load_state_dict(ckpt['model_state_dict'])
    history = ckpt.get('history', {})

# utils.plot_loss_history(history, f"Overfit Learning Curve ({N_OVERFIT_SONGS} Songs)")

# Visualize results for each batch (optional, can be commented out for OOM safety).
# Only the first item of each batch is shown.
# for i, batch in enumerate(tiny_loader):
#     mix = batch['mix'].to(overfit_device)
#     tgt = batch['tgt'].to(overfit_device)
#     overfit_model.eval()
#     with torch.no_grad():
#         mix_log, _ = overfit_processor.to_spectrogram(mix)
#         tgt_log, _ = overfit_processor.to_spectrogram(tgt)
#         mix_in = mix_log.unsqueeze(1)
#         mask = overfit_model(mix_in)
#         if mask.shape != mix_in.shape:
#             mask = mask[:, :, :mix_in.shape[2], :mix_in.shape[3]]
#         est_linear = mask * torch.expm1(mix_in)
#         est_log = torch.log1p(est_linear)
#     utils.visualize_results(mix_log[0], tgt_log[0], est_log[0].squeeze(), title=f"Overfit Verification Batch {i+1} (Log Scale)")

Starting Overfit Test on 5 Songs...
✓ Found Checkpoint: ../checkpoints/debug_overfit_5songs.pth


  ckpt = torch.load(save_path, map_location=device)


# Listen to mixture, target, and predicted audio

In [7]:
# NOTE(review): this cell is disabled. It relies on `mix_log`, `tgt_log` and
# `est_log`, which are only created by the commented-out visualization loop in
# the overfit cell above — uncomment that loop first. The waveforms are rebuilt
# with an all-zero phase (torch.zeros_like), so expect audible artefacts.
# from models.utils import play_audio
#
# sr = 22050
# mix_wav = processor.to_waveform(mix_log[0].cpu(), torch.zeros_like(mix_log[0]))
# tgt_wav = processor.to_waveform(tgt_log[0].cpu(), torch.zeros_like(tgt_log[0]))
# est_wav = processor.to_waveform(est_log[0].cpu(), torch.zeros_like(est_log[0]))
#
# play_audio(mix_wav, sr=sr, title="Mixture Audio (Overfit)")
# play_audio(tgt_wav, sr=sr, title="Target Audio (Overfit)")
# play_audio(est_wav, sr=sr, title="Predicted Audio (Overfit)")

### 5.b Full Training

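- Train the model on the full training dataset as intended, using a two-stage curriculum: Stage 1 (vocals + other → other), then Stage 2 (full mix → other) with a 10× lower learning rate. Each stage uses an 80/20 train/validation split, and a stage is skipped (its checkpoint loaded instead) if its checkpoint already exists.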


In [None]:
print("Starting Full Training Pipeline...")


def _build_stage_loaders(stage_root, stage_name, batch_size, train_frac=0.8):
    """Build (train_loader, val_loader) for one curriculum stage.

    Mixture and target `.npy` files are paired by their position after sorting.
    The first `train_frac` of the pairs go to training and the rest to
    validation. No shuffling is done before the split, so it is deterministic.

    Args:
        stage_root: Directory containing `mixture/` and `target/` sub-folders.
        stage_name: Label used in error messages (e.g. "Stage 1").
        batch_size: Batch size for both loaders.
        train_frac: Fraction of pairs used for training.

    Raises:
        ValueError: if mixture/target counts differ or either split is empty.
    """
    stage_root = Path(stage_root)
    mixes = sorted((stage_root / "mixture").glob("*.npy"))
    targets = sorted((stage_root / "target").glob("*.npy"))
    if len(mixes) != len(targets):
        raise ValueError(
            f"{stage_name}: found {len(mixes)} mixture files but {len(targets)} "
            "target files; pairs would be misaligned."
        )

    split = int(len(mixes) * train_frac)
    train_ds = utils.StandardDataset(mixes[:split], targets[:split])
    val_ds = utils.StandardDataset(mixes[split:], targets[split:])
    if len(train_ds) < 1 or len(val_ds) < 1:
        raise ValueError(f"Not enough data for full training {stage_name}. Please check your dataset.")

    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True),
        DataLoader(val_ds, batch_size=batch_size, shuffle=False),
    )


def _set_learning_rate(opt, lr):
    """Overwrite the learning rate of every parameter group of `opt`."""
    for param_group in opt.param_groups:
        param_group['lr'] = lr


def _make_trainer(train_loader, val_loader):
    """Wrap the shared `model`, `optimizer` and `loss_fn` in a UniversalTrainer for one stage."""
    return utils.UniversalTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        processor=processor,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=device,
        patience=TRAIN_CONFIG['patience']
    )


def _run_stage(trainer, save_path):
    """Train with `trainer`, or load its existing checkpoint; return the loss history.

    If `save_path` already exists, the checkpoint's weights are loaded into
    `model` instead of training, and its stored history (or {}) is returned.
    """
    if not os.path.exists(save_path):
        return trainer.train(num_epochs=TRAIN_CONFIG['num_epochs'], save_path=save_path, log_file_path=LOG_FILE_PATH)
    print(f"✓ Found Checkpoint: {save_path}")
    ckpt = torch.load(save_path, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    return ckpt.get('history', {})


# Stage 1
print("\n--- Stage 1: Vocals + Other -> Other ---")

s1_root = Path("../data/curriculum/stage1")
train_loader1, val_loader1 = _build_stage_loaders(s1_root, "Stage 1", TRAIN_CONFIG['batch_size'])

# Stage 2 lowers the LR on this shared optimizer. Restore the base LR so that
# re-running this cell does not silently train Stage 1 at the reduced rate.
_set_learning_rate(optimizer, TRAIN_CONFIG['learning_rate'])

trainer_s1 = _make_trainer(train_loader1, val_loader1)
path_s1 = "../checkpoints/full_stage1.pth"
hist_s1 = _run_stage(trainer_s1, path_s1)

# utils.plot_loss_history(hist_s1, "Stage 1 Results")

# Stage 2
print("\n--- Stage 2: Full Mix -> Other ---")

s2_root = Path("../data/curriculum/stage2")
train_loader2, val_loader2 = _build_stage_loaders(s2_root, "Stage 2", TRAIN_CONFIG['batch_size'])

# Fine-tune on the harder stage with a 10x smaller learning rate.
stage2_lr = TRAIN_CONFIG['learning_rate'] * 0.1
_set_learning_rate(optimizer, stage2_lr)
print(f"✓ Optimizer LR reduced to {stage2_lr}")

trainer_s2 = _make_trainer(train_loader2, val_loader2)
path_s2 = "../checkpoints/full_stage2.pth"
hist_s2 = _run_stage(trainer_s2, path_s2)

# utils.plot_loss_history(hist_s2, "Stage 2 Results")

Starting Full Training Pipeline...

--- Stage 1: Vocals + Other -> Other ---
[DEBUG] Using tqdm.notebook for progress bars.


Total Progress:   0%|          | 0/50 [00:00<?, ?it/s]

[DEBUG] Using tqdm.notebook for progress bars.


Ep 1 Training:   0%|          | 0/58 [00:00<?, ?it/s]

## 6. Evaluation and Inference

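- Run the trained model on test data (this cell plots the training and validation loss curves stored in the Stage 2 checkpoint; separation examples follow in Section 7)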

- Visualize separated sources (waveforms, spectrograms)

- Optionally, compute evaluation metrics (e.g., SDR, SIR)


In [None]:
import os

import matplotlib.pyplot as plt
import torch

ckpt_path = "../checkpoints/full_stage2.pth"

# The Stage 2 checkpoint only exists once full training has finished. Guard
# against a missing file instead of crashing on a fresh kernel or interrupted run.
if not os.path.exists(ckpt_path):
    print(f"No checkpoint found at {ckpt_path}. Run the full training cell first.")
else:
    ckpt = torch.load(ckpt_path, map_location='cpu')
    # `or {}` also covers a checkpoint that stored history=None.
    history = ckpt.get('history') or {}
    train_loss = history.get('train_loss', [])
    val_loss = history.get('val_loss', [])
    if not train_loss and not val_loss:
        print("No loss history found in checkpoint.")
    else:
        fig, ax = plt.subplots(figsize=(10, 4))
        if train_loss:
            ax.plot(train_loss, label='Train Loss')
        if val_loss:
            ax.plot(val_loss, label='Val Loss')
        ax.set(title="Loss Curves from Checkpoint", xlabel="Epoch", ylabel="Loss")
        ax.legend()
        plt.show()

## 7. Listen to the masked songs

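- Listen to songs from the MUSDB18 dataset after masking (separation) by the trained model

- View the mixture, target, and predicted spectrograms

- Optionally, place a `.wav` file in the `model_A_input` folder to separate your own song (only its first 6 seconds are used)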

In [None]:
import os
import numpy as np
import torch
from models.utils import show_spectrogram, play_audio, AudioProcessor
from models.model_A import TimeFrequencyDomainUNet
from pathlib import Path
import matplotlib.pyplot as plt
import librosa
import librosa.display

# Parameters
sr = 22050  # Sample rate (Hz). NOTE(review): assumed to match the cached .npy clips — confirm.
duration = 6  # seconds
n_samples = sr * duration


def _as_4d(mag):
    """Return `mag` reshaped to (batch, channel, H, W), the layout the U-Net takes.

    2D input gains batch and channel axes, 3D input gains a channel axis at
    dim 1, and 4D input is returned unchanged.

    Raises:
        ValueError: for any other rank.
    """
    if mag.dim() == 2:
        return mag.unsqueeze(0).unsqueeze(0)
    if mag.dim() == 3:
        return mag.unsqueeze(1)
    if mag.dim() == 4:
        return mag
    raise ValueError(f"spectrogram must be 2D, 3D, or 4D, got shape {tuple(mag.shape)}")


def _separate(log_mag):
    """Apply `model`'s mask to a spectrogram from `processor.to_spectrogram`.

    The overfit verification code in this notebook treats `to_spectrogram`
    output as log1p magnitude and applies the mask in the linear domain
    (log1p(mask * expm1(x))). This helper follows the same convention.
    The result has the same shape as `log_mag` and lives on `device`.
    """
    mag_in = _as_4d(log_mag).to(device)
    with torch.no_grad():
        mask = model(mag_in)
        # Crop in case the U-Net output is spatially larger than its input.
        if mask.shape != mag_in.shape:
            mask = mask[:, :, :mag_in.shape[2], :mag_in.shape[3]]
        # clamp_min keeps log1p defined if the mask is not bounded to [0, 1].
        est_linear = (mask * torch.expm1(mag_in)).clamp_min(0)
        est_log = torch.log1p(est_linear)
    return est_log.reshape(log_mag.shape)


# Load a sample from the MUSDB18 curriculum cache (Stage 1 pairs)
data_root = Path("../data/curriculum")
s1_mix_path = data_root / "stage1" / "mixture"
s1_tgt_path = data_root / "stage1" / "target"
mix_files = sorted(list(s1_mix_path.glob("*.npy")))
tgt_files = sorted(list(s1_tgt_path.glob("*.npy")))
song_num = 12  # Change this to select a different sample from dataset musdb18

if len(mix_files) != len(tgt_files):
    raise ValueError(
        f"Found {len(mix_files)} mixture files but {len(tgt_files)} target files; pairs would be misaligned."
    )
if not -len(mix_files) <= song_num < len(mix_files):
    raise IndexError(
        f"song_num={song_num} is out of range: only {len(mix_files)} cached samples in {s1_mix_path}"
    )

sample_idx = song_num
# NOTE(review): slicing the first axis assumes each clip is a 1-D waveform — confirm.
mix_wav = np.load(mix_files[sample_idx])[:n_samples]
tgt_wav = np.load(tgt_files[sample_idx])[:n_samples]

# Compute and show spectrograms for the clipped sample
processor = AudioProcessor(device=device)
mix_mag, mix_phase = processor.to_spectrogram(torch.tensor(mix_wav))
tgt_mag, tgt_phase = processor.to_spectrogram(torch.tensor(tgt_wav))
show_spectrogram(mix_mag, title=f"Mixture Spectrogram ({duration} sec)")
show_spectrogram(tgt_mag, title=f"Target Spectrogram ({duration} sec)")

# Predict the masked output and resynthesize it with the mixture's phase
model.eval()
est_mag = _separate(mix_mag)
est_wav = processor.to_waveform(est_mag.cpu(), mix_phase.cpu())
show_spectrogram(est_mag.cpu(), title=f"Predicted Spectrogram ({duration} sec)")

# Play audio (optional, can comment out if not needed)
play_audio(mix_wav, sr=sr, title=f"Mixture Audio ({duration} sec)")
play_audio(tgt_wav, sr=sr, title=f"Target Audio ({duration} sec)")
play_audio(est_wav, sr=sr, title=f"Predicted Audio ({duration} sec)")

# Optionally upload and process your own song
input_dir = Path("../model_A_input")
os.makedirs(input_dir, exist_ok=True)
user_files = sorted(list(input_dir.glob("*.wav")))
if user_files:
    user_path = user_files[0]  # Only the first uploaded file is processed
    user_wav, user_sr = librosa.load(user_path, sr=sr)  # librosa resamples to `sr` on load
    user_wav = user_wav[:n_samples]
    user_mag, user_phase = processor.to_spectrogram(torch.tensor(user_wav))
    show_spectrogram(user_mag, title=f"User Mixture Spectrogram ({duration} sec): {user_path.name}")
    user_est_mag = _separate(user_mag)
    user_est_wav = processor.to_waveform(user_est_mag.cpu(), user_phase.cpu())
    show_spectrogram(user_est_mag.cpu(), title=f"User Predicted Spectrogram ({duration} sec): {user_path.name}")
    show_spectrogram(user_phase.cpu() if hasattr(user_phase, 'cpu') else user_phase, title=f"User Phase Spectrogram ({duration} sec): {user_path.name}")
    # play_audio(user_wav, sr=sr, title=f"User Mixture Audio (6 sec): {user_path.name}")
    # play_audio(user_est_wav, sr=sr, title=f"User Predicted Audio (6 sec): {user_path.name}")
else:
    print("No user .wav file found in model_A_input. Upload a song to try your own audio!")

In [None]:
import os
import sys

# Environment diagnostics: show where the kernel runs and what it can import from.
# Useful when the relative '../data' or '../checkpoints' paths fail to resolve.
# Each probe is a callable so values are read in the same order they are printed.
_diagnostics = (
    ('Current working directory:', os.getcwd),
    ('sys.path:', lambda: sys.path),
    ('Directory listing:', lambda: os.listdir('.')),
)
for _label, _probe in _diagnostics:
    print(_label, _probe())