In [2]:
import os
import musdb
import torch
import torchaudio
from torch.utils.data import Dataset
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

In [None]:
mus = musdb.DB(root="../musdb18hq", is_wav=True)
print(f"Number of tracks: {len(mus)}")

mus_train = mus.load_mus_tracks(subsets="train")
mus_test = mus.load_mus_tracks(subsets="test")
print(f"Training tracks: {len(mus_train)}")
print(f"Test tracks: {len(mus_test)}")

Number of tracks: 150
Training tracks: 100
Test tracks: 50


In [4]:
def chunk_spectrogram(spec, chunk_size=512):
    num_frames = spec.shape[-1]
    chunks = []
    for start in range(0, num_frames, chunk_size):
        end = min(start + chunk_size, num_frames)
        chunk = spec[..., start:end]
        if chunk.shape[-1] < chunk_size:
            padding = (0, chunk_size - chunk.shape[-1])
            chunk = torch.nn.functional.pad(chunk, padding)
        chunks.append(chunk)
    return chunks

In [None]:
class PrecomputedMusdbDataset(Dataset):
    def __init__(self, spec_dir="musdb_specs", chunk_size=512):
        self.spec_dir = spec_dir
        self.chunk_size = chunk_size
        self.chunk_indices = self._prepare_chunk_indices()

    def _prepare_chunk_indices(self):
        chunk_indices = []
        for i in range(100):
            mix_data = torch.load(f"{self.spec_dir}/mix_{i}.pt")
            num_chunks = len(mix_data['chunks'])
            for chunk_idx in range(num_chunks):
                chunk_indices.append((i, chunk_idx))
        return chunk_indices

    def __len__(self): return len(self.chunk_indices)
    def __getitem__(self, idx):
        track_idx, chunk_idx = self.chunk_indices[idx]
        mix_data = torch.load(f"{self.spec_dir}/mix_{track_idx}.pt")
        target_data = torch.load(f"{self.spec_dir}/target_{track_idx}.pt")
        return mix_data['chunks'][chunk_idx], target_data['chunks'][chunk_idx]

In [6]:
class ConvLayer(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, conv_type, padding=0, transpose=False, dropout=0.1):
        super(ConvLayer, self).__init__()
        self.transpose = transpose
        self.stride = stride
        self.kernel_size = kernel_size
        self.conv_type = conv_type
        self.padding = padding

        if self.transpose:
            self.filter = nn.ConvTranspose2d(n_inputs, n_outputs, self.kernel_size, stride, padding=self.padding)
        else:
            self.filter = nn.Conv2d(n_inputs, n_outputs, self.kernel_size, stride, padding=self.padding)

        NORM_CHANNELS = 8
        if conv_type == "gn":
            assert n_outputs % NORM_CHANNELS == 0
            self.norm = nn.GroupNorm(n_outputs // NORM_CHANNELS, n_outputs)
        elif conv_type == "bn":
            self.norm = nn.BatchNorm2d(n_outputs, momentum=0.01)
        else:
            self.norm = None

        self.dropout = nn.Dropout2d(dropout) if dropout > 0 else None

    def forward(self, x):
        out = self.filter(x)
        if self.norm:
            out = self.norm(out)
        out = F.leaky_relu(out, negative_slope=0.2)
        if self.dropout:
            out = self.dropout(out)
        return out

In [7]:
def centre_crop(x, target):
    """
    Center-crop 3D or 4D input tensor along the last two spatial dimensions to match target shape.
    """
    if x is None or target is None:
        return x
    if x.size(2) == target.size(2) and x.size(3) == target.size(3):
        return x  # No cropping needed

    diff_h = x.size(2) - target.size(2)
    diff_w = x.size(3) - target.size(3)

    if diff_h < 0 or diff_w < 0:
        # If x is smaller, interpolate instead of cropping
        return F.interpolate(x, size=(target.size(2), target.size(3)), mode='bilinear', align_corners=False)

    crop_h1 = diff_h // 2
    crop_h2 = diff_h - crop_h1
    crop_w1 = diff_w // 2
    crop_w2 = diff_w - crop_w1

    return x[:, :, crop_h1:x.size(2) - crop_h2, crop_w1:x.size(3) - crop_w2].contiguous()

In [8]:
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.depthwise_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, groups=channels)
        self.pointwise_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.block = nn.Sequential(
            self.depthwise_conv,
            nn.GroupNorm(8, channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            self.pointwise_conv,
            nn.GroupNorm(8, channels)
        )

    def forward(self, x):
        return x + self.block(x)

In [9]:
class AttentionBlock(nn.Module):
    def __init__(self, channels):
        super(AttentionBlock, self).__init__()
        self.attention = nn.Sequential(
            nn.GroupNorm(8, channels),
            nn.Conv2d(channels, channels // 8, kernel_size=1, groups=4),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 8, channels, kernel_size=1, groups=4),
            nn.Sigmoid()
        )

    def forward(self, x):
        return x * self.attention(x)

In [None]:
class ModifiedUNet(nn.Module):
    def __init__(self, in_channels=2, out_channels=8):
        super(ModifiedUNet, self).__init__()

        # Encoder
        self.enc1 = nn.Sequential(ConvLayer(in_channels, 64, 4, 2, "gn", padding=1), ResidualBlock(64))
        self.enc2 = nn.Sequential(ConvLayer(64, 128, 4, 2, "gn", padding=1), ResidualBlock(128))
        self.enc3 = nn.Sequential(ConvLayer(128, 256, 4, 2, "gn", padding=1), ResidualBlock(256))

        # Bottleneck
        self.bottleneck = nn.Sequential(ConvLayer(256, 512, 3, 1, "gn", padding=1), AttentionBlock(512))

        # Decoder
        self.dec3 = nn.Sequential(ConvLayer(512 + 256, 256, 4, 2, "gn", padding=1, transpose=True), AttentionBlock(256))
        self.dec2 = nn.Sequential(ConvLayer(256 + 128, 128, 4, 2, "gn", padding=1, transpose=True), AttentionBlock(128))
        self.dec1 = nn.Sequential(ConvLayer(128 + 64, 64, 4, 2, "gn", padding=1, transpose=True), AttentionBlock(64))

        # Final layer
        self.final = nn.Conv2d(64, out_channels, 3, 1, 1)

    def forward(self, x):
        x = x[:, :, :512, :]
        enc1 = checkpoint(self.enc1, x, use_reentrant=False)
        enc2 = checkpoint(self.enc2, enc1, use_reentrant=False)
        enc3 = checkpoint(self.enc3, enc2, use_reentrant=False)
        bottleneck = checkpoint(self.bottleneck, enc3, use_reentrant=False)
        dec3 = checkpoint(self.dec3, torch.cat((bottleneck, enc3), dim=1), use_reentrant=False)
        dec2 = checkpoint(self.dec2, torch.cat((dec3, enc2), dim=1), use_reentrant=False)
        dec1 = checkpoint(self.dec1, torch.cat((dec2, enc1), dim=1), use_reentrant=False)
        out = self.final(dec1)
        out = torch.sigmoid(out)
        out = out.view(-1, 4, 2, *out.shape[2:])
        return F.pad(out, (0, 0, 0, 1))

In [None]:
model = ModifiedUNet(in_channels=2, out_channels=8)
x = torch.randn(2, 2, 513, 512)
out = model(x)
print(f"Output shape: {out.shape}")

Output shape: torch.Size([2, 4, 2, 513, 512])


In [12]:
# Loss Function with Clamping and Debugging
class StereoSpectrogramLoss(nn.Module):
    def __init__(self, w_log_mag=1.0, w_lin_mag=1.0):
        super(StereoSpectrogramLoss, self).__init__()
        self.w_log_mag = w_log_mag
        self.w_lin_mag = w_lin_mag
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        pred_mid = (pred[:, :, 0] + pred[:, :, 1]) / 2
        pred_side = (pred[:, :, 0] - pred[:, :, 1]) / 2
        target_mid = (target[:, :, 0] + target[:, :, 1]) / 2
        target_side = (target[:, :, 0] - target[:, :, 1]) / 2

        # Clamp to ensure non-negative for log1p
        pred_mid = torch.clamp(pred_mid, min=0)
        pred_side = torch.clamp(pred_side, min=0)
        target_mid = torch.clamp(target_mid, min=0)
        target_side = torch.clamp(target_side, min=0)

        # Debugging
        if not (torch.all(torch.isfinite(pred_mid)) and torch.all(torch.isfinite(pred_side))):
            print("NaN/Inf in pred_mid or pred_side")
            print(f"pred_mid: {pred_mid.min()}, {pred_mid.max()}")
            print(f"pred_side: {pred_side.min()}, {pred_side.max()}")
        if not (torch.all(torch.isfinite(target_mid)) and torch.all(torch.isfinite(target_side))):
            print("NaN/Inf in target_mid or target_side")
            print(f"target_mid: {target_mid.min()}, {target_mid.max()}")
            print(f"target_side: {target_side.min()}, {target_side.max()}")

        log_loss_mid = self.mse(torch.log1p(pred_mid), torch.log1p(target_mid))
        log_loss_side = self.mse(torch.log1p(pred_side), torch.log1p(target_side))
        if not (torch.isfinite(log_loss_mid) and torch.isfinite(log_loss_side)):
            print("NaN in log losses")
            print(f"log_loss_mid: {log_loss_mid}, log_loss_side: {log_loss_side}")
        log_loss = (log_loss_mid + log_loss_side) / 2

        lin_loss_mid = self.mse(pred_mid, target_mid)
        lin_loss_side = self.mse(pred_side, target_side)
        if not (torch.isfinite(lin_loss_mid) and torch.isfinite(lin_loss_side)):
            print("NaN in lin losses")
            print(f"lin_loss_mid: {lin_loss_mid}, lin_loss_side: {lin_loss_side}")
        lin_loss = (lin_loss_mid + lin_loss_side) / 2

        total_loss = self.w_log_mag * log_loss + self.w_lin_mag * lin_loss
        if not torch.isfinite(total_loss):
            print("NaN in total_loss")
            print(f"log_loss: {log_loss}, lin_loss: {lin_loss}")
        return total_loss

In [None]:
def preprocess_dataset(musdb_root, output_dir="../musdb_specs"):
    os.makedirs(output_dir, exist_ok=True)
    mus = musdb.DB(root=musdb_root, is_wav=True)
    tracks = mus.load_mus_tracks(subsets="train")
    transform = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256, power=2)
    for i, track in enumerate(tqdm(tracks, desc="Preprocessing Tracks")):
        mix = torch.tensor(track.audio.T, dtype=torch.float32)
        targets = [torch.tensor(track.targets[src].audio.T, dtype=torch.float32) 
                   for src in ['vocals', 'drums', 'bass', 'other']]
        mix_spec = torch.log1p(transform(mix))
        target_spec = torch.stack([torch.log1p(transform(t)) for t in targets])
        torch.save({'spec': mix_spec, 'chunks': chunk_spectrogram(mix_spec, 512)}, f"{output_dir}/mix_{i}.pt")
        torch.save({'spec': target_spec, 'chunks': chunk_spectrogram(target_spec, 512)}, f"{output_dir}/target_{i}.pt")
    return len(tracks)

In [None]:
preprocess_dataset("../musdb18hq")

Preprocessing Tracks: 100%|██████████| 100/100 [12:18<00:00,  7.39s/it]


100

In [15]:
dataset = PrecomputedMusdbDataset(spec_dir="musdb_specs")  # 7,754 chunks
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)  # No workers
model = ModifiedUNet(in_channels=2, out_channels=8)
criterion = StereoSpectrogramLoss(w_log_mag=1.0, w_lin_mag=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ModifiedUNet(
  (enc1): Sequential(
    (0): ConvLayer(
      (filter): Conv2d(2, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
      (dropout): Dropout2d(p=0.1, inplace=False)
    )
    (1): ResidualBlock(
      (depthwise_conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
      (pointwise_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (block): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        (1): GroupNorm(8, 64, eps=1e-05, affine=True)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
        (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (4): GroupNorm(8, 64, eps=1e-05, affine=True)
      )
    )
  )
  (enc2): Sequential(
    (0): ConvLayer(
      (filter): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (norm): GroupNorm(16, 128, eps=1e-05, affine=True)
      (dr

In [16]:
# Verify chunk and batch count
print(f"Total chunks in training dataset: {len(dataset)}")
print(f"Batches per epoch: {len(dataloader)}")

Total chunks in training dataset: 7754
Batches per epoch: 3877


In [None]:
# Training Function with Multi-GPU Support
def train(model, dataloader, criterion, optimizer, epochs=10, device_ids=[0, 1]):
    model.train()
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model, device_ids=device_ids)
    model.to(device)

    for epoch in tqdm(range(epochs), desc="Training Progress", unit="epoch"):
        running_loss = 0.0
        batch_iterator = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}", leave=False, unit="batch", total=len(dataloader))
        for i, (mix_batch, target_batch) in enumerate(batch_iterator):
            try:
                mix_batch, target_batch = mix_batch.to(device), target_batch.to(device)
                optimizer.zero_grad()
                pred_spec = model(mix_batch)  # [batch_size, 4, 2, 513, 512]
                loss = criterion(pred_spec, target_batch)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                running_loss += loss.item()
            except Exception as e:
                tqdm.write(f"Crash at Batch {i}: {str(e)}")
                raise

        avg_loss = running_loss / len(dataloader)
        tqdm.write(f"Epoch {epoch + 1}/{epochs} Completed - Average Loss: {avg_loss:.4f}")
        torch.save(model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict(), 
                   f"unet_model_epoch{epoch + 1}.pth")

In [18]:
train(model, dataloader, criterion, optimizer, epochs=10)

Training Progress:  10%|█         | 1/10 [3:05:42<27:51:18, 11142.05s/epoch]

Epoch 1/10 Completed - Average Loss: 0.0474


Training Progress:  10%|█         | 1/10 [5:55:49<53:22:23, 21349.33s/epoch]


KeyboardInterrupt: 

In [None]:
import os
import torch
import torchaudio
from tqdm import tqdm
import soundfile as sf

def chunk_spectrogram(spec, chunk_size=512):
    num_frames = spec.shape[-1]
    chunks = []
    for start in range(0, num_frames, chunk_size):
        end = min(start + chunk_size, num_frames)
        chunk = spec[..., start:end]
        if chunk.shape[-1] < chunk_size:
            padding = (0, chunk_size - chunk.shape[-1])
            chunk = torch.nn.functional.pad(chunk, padding)
        chunks.append(chunk)
    return chunks

def test_model_on_file(model, audio_path, device, output_dir="separated_audio", batch_size=2):
    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    audio, sample_rate = torchaudio.load(audio_path)
    if sample_rate != 44100:
        resampler = torchaudio.transforms.Resample(sample_rate, 44100)
        audio = resampler(audio)
    if audio.shape[0] != 2:
        raise ValueError("Input audio must be stereo (2 channels).")

    transform = torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256, power=None)
    complex_spec = transform(audio) 
    magnitude_spec = complex_spec.abs() 
    phase_spec = complex_spec.angle()    
    mix_spec = torch.log1p(magnitude_spec)  
    mix_chunks = chunk_spectrogram(mix_spec, 512)  

    num_chunks = len(mix_chunks)
    batches = [mix_chunks[i:i + batch_size] for i in range(0, num_chunks, batch_size)]
    
    full_pred_specs = {s: [] for s in range(4)} 
    with torch.no_grad():
        for batch in tqdm(batches, desc="Processing Chunks"):
            batch_tensor = torch.stack(batch).to(device)
            pred_spec = model(batch_tensor)             
            pred_spec = pred_spec.cpu()
            for b in range(pred_spec.size(0)):
                for s in range(4):
                    full_pred_specs[s].append(torch.expm1(pred_spec[b, s]))

    target_time_frames = magnitude_spec.shape[-1] 
    full_pred_magnitudes = [torch.cat(full_pred_specs[s], dim=-1)[..., :target_time_frames] for s in range(4)]
    full_mix_magnitude = torch.cat(mix_chunks, dim=-1)[..., :target_time_frames] 
    full_mix_magnitude = torch.expm1(full_mix_magnitude)  

    full_pred_specs = [mag * torch.exp(1j * phase_spec) for mag in full_pred_magnitudes]  
    full_mix_spec = full_mix_magnitude * torch.exp(1j * phase_spec) 

    inverse_transform = torchaudio.transforms.InverseSpectrogram(n_fft=1024, hop_length=256)
    mix_audio = inverse_transform(full_mix_spec) 
    source_audios = [inverse_transform(pred_spec) for pred_spec in full_pred_specs] 

    file_name = os.path.splitext(os.path.basename(audio_path))[0]
    sf.write(f"{output_dir}/{file_name}_mix.wav", mix_audio.T.numpy(), 44100)
    for s, source_audio in enumerate(source_audios):
        source_name = ['vocals', 'drums', 'bass', 'other'][s]
        sf.write(f"{output_dir}/{file_name}_{source_name}.wav", source_audio.T.numpy(), 44100)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ModifiedUNet(in_channels=2, out_channels=8).to(device)

model.load_state_dict(torch.load("unet_model_epoch1.pth"))
print("Loaded pretrained model from epoch 1")

audio_path = "/home/sid/Desktop/Projects/Samsung_PRISM/Project_ausep/Code/musdb18hq/test/Al James - Schoolboy Facination/mixture.wav"
test_model_on_file(model, audio_path, device, output_dir="separated_audio")

print("Testing complete. Check 'separated_audio' folder for WAV files.")

Loaded pretrained model from epoch 1


Processing Chunks: 100%|██████████| 34/34 [00:11<00:00,  2.98it/s]


Testing complete. Check 'separated_audio' folder for WAV files.
