In [1]:
%%capture
!pip install nussl
!pip install git+https://github.com/source-separation/tutorial

In [18]:
%%capture
from common import data, viz
import nussl
from nussl.ml.networks.modules import AmplitudeToDB, BatchNorm, RecurrentStack, Embedding

import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

## Preparing the Data

In [3]:
# Prepare MUSDB
data.prepare_musdb('~/.nussl/tutorial/')

100%|██████████| 80/80 [00:24<00:00,  3.26it/s]
100%|██████████| 14/14 [00:04<00:00,  3.29it/s]
100%|██████████| 50/50 [00:15<00:00,  3.24it/s]


In [23]:
stft_params = nussl.STFTParams(window_length=512, hop_length=128, window_type='sqrt_hann')
fg_path = "~/.nussl/tutorial/train"
train_data = data.on_the_fly(stft_params, transform=None, fg_path=fg_path, num_mixtures=1000, coherent_prob=1.0)

item = train_data[0]

In [24]:
fg_path = "~/.nussl/tutorial/valid"
val_data = data.on_the_fly(stft_params, transform=None, fg_path=fg_path, num_mixtures=500)

fg_path = "~/.nussl/tutorial/test"
test_data = data.on_the_fly(stft_params, transform=None, fg_path=fg_path, num_mixtures=100)

In [137]:
for key in item:
    print(key, item[key], type(item[key]))

mix AudioSignal (unlabeled): 5.000 sec @ path unknown, 44100 Hz, 1 ch. <class 'nussl.core.audio_signal.AudioSignal'>
sources {'vocals': <nussl.core.audio_signal.AudioSignal object at 0x30a84a5c0>, 'drums': <nussl.core.audio_signal.AudioSignal object at 0x30a8e0fd0>, 'bass': <nussl.core.audio_signal.AudioSignal object at 0x30a8e0460>, 'other': <nussl.core.audio_signal.AudioSignal object at 0x30a8e1180>} <class 'dict'>
metadata {'jam': <JAMS(file_metadata=<FileMetadata(...)>,
      annotations=[1 annotation],
      sandbox=<Sandbox(...)>)>, 'idx': 0} <class 'dict'>


## Defining the Model

In [134]:
class StemSeparationModel(nn.Module):
    def __init__(self, num_features, num_audio_channels, hidden_size,
                 num_layers, bidirectional, dropout, num_sources, 
                activation='sigmoid'):
        super().__init__()
        
        self.verbose = False

        self.amplitude_to_db = AmplitudeToDB()
        self.input_normalization = BatchNorm(num_features)
        self.recurrent_stack = RecurrentStack(
            num_features * num_audio_channels, hidden_size, 
            num_layers, bool(bidirectional), dropout
        )
        hidden_size = hidden_size * (int(bidirectional) + 1)
        self.embedding = Embedding(num_features, hidden_size, 
                                   num_sources, activation, 
                                   num_audio_channels)
        
    def forward(self, item):
        # Get magnitude of mixture signal
        mixture_magnitude_np = np.abs(item['mix'].stft())
        mixture_magnitude = torch.tensor(mixture_magnitude_np).float()
        self.log(f"mixture_magnitude shape before batch dimension: {mixture_magnitude.shape}")

        # Add batch dimension if needed and reshape
        # Expected shape: (batch_size, freq_bins, time_steps)
        # mixture_magnitude shape: [1, 257, 1724]
        if mixture_magnitude.dim() == 3 and mixture_magnitude.size(-1) == 1:
            mixture_magnitude = mixture_magnitude.squeeze(-1)  # Remove last dimension of size 1
        if mixture_magnitude.dim() == 2:
            mixture_magnitude = mixture_magnitude.unsqueeze(0)  # Add batch dimension
        self.log(f"mixture_magnitude shape after batch dimension: {mixture_magnitude.shape}")

        # Convert to log amplitude
        mixture_log_amplitude = self.amplitude_to_db(mixture_magnitude)
        self.log(f"Shape after amplitude to db: {mixture_log_amplitude.shape}")
        
        # Normalize the data
        normalized = self.input_normalization(mixture_log_amplitude)
        self.log(f"Shape after normalization: {normalized.shape}")
    
        # Reshape for LSTM: (batch, time_steps, freq_bins)
        normalized = normalized.transpose(1, 2)
        self.log(f"Shape before LSTM (after transpose): {normalized.shape}")

        # Pass through LSTM
        output = self.recurrent_stack(normalized)
        self.log(f"Shape after LSTM: {output.shape}")

        # Generate mask - should have shape (batch, time_steps, freq_bins, num_sources)
        # mask shape: (1, 1724, 257, 1)
        mask = self.embedding(output)
        if mask.dim() == 5:
            mask = mask.squeeze(-1)  # Remove the extra dimension from the mask if it exists
        self.log(f"Shape of mask: {mask.shape}")
 
        # Reshape mixture_magnitude to align with mask dimensions
        # Current shape: (1, 257, 1724) -> Need: (1, 1724, 257, 1)
        mixture_magnitude = mixture_magnitude.transpose(1, 2)  # (1, 1724, 257)
        mixture_magnitude = mixture_magnitude.unsqueeze(-1)    # (1, 1724, 257, 1)
        self.log(f"Shape of reshaped mixture_magnitude: {mixture_magnitude.shape}")
    
        # Apply mask to get estimates
        estimates = mixture_magnitude * mask
        self.log(f"Shape of estimates: {estimates.shape}")

        return {
            'mask': mask,
            'estimates': estimates,
        }

    def log(self, s):
        if self.verbose:
            print(s)

In [135]:
num_features = stft_params.window_length // 2 + 1
num_audio_channels = 1
hidden_size = 50
num_layers = 2
bidirectional = True
dropout = 0.3
num_sources = 1
activation = 'sigmoid'

model = StemSeparationModel(
    num_features=num_features,
    num_audio_channels=num_audio_channels,
    hidden_size=hidden_size,
    num_layers=num_layers,
    bidirectional=bidirectional,
    dropout=dropout,
    num_sources=num_sources,
    activation=activation,
)

In [136]:
model.verbose = True

def process_item(model, item):
    with torch.no_grad():
        output = model(item)
    return output

output = process_item(model, item)
print()
for key in output:
    print(key, type(output[key]), output[key].shape)

model.verbose = False

mixture_magnitude shape before batch dimension: torch.Size([257, 1724, 1])
mixture_magnitude shape after batch dimension: torch.Size([1, 257, 1724])
Shape after amplitude to db: torch.Size([1, 257, 1724])
Shape after normalization: torch.Size([1, 257, 1724])
Shape before LSTM (after transpose): torch.Size([1, 1724, 257])
Shape after LSTM: torch.Size([1, 1724, 100])
Shape of mask: torch.Size([1, 1724, 257, 1])
Shape of reshaped mixture_magnitude: torch.Size([1, 1724, 257, 1])
Shape of estimates: torch.Size([1, 1724, 257, 1])

mask <class 'torch.Tensor'> torch.Size([1, 1724, 257, 1])
estimates <class 'torch.Tensor'> torch.Size([1, 1724, 257, 1])


## Training the Model