In [1]:
%%capture
!pip install nussl
!pip install git+https://github.com/source-separation/tutorial

In [2]:
%%capture
from common import data, utils, viz

import json
import copy
import argbind
from pathlib import Path

import IPython.display as display
from IPython.display import Audio

import nussl
from nussl.ml.networks.modules import AmplitudeToDB, BatchNorm, RecurrentStack, Embedding
from nussl.datasets import transforms as nussl_tfm

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn

## Preparing the Data

In [3]:
# Prepare MUSDB
data.prepare_musdb('~/.nussl/tutorial/')

100%|██████████| 80/80 [00:27<00:00,  2.87it/s]
100%|██████████| 14/14 [00:04<00:00,  3.03it/s]
100%|██████████| 50/50 [00:17<00:00,  2.84it/s]


In [4]:
stft_params = nussl.STFTParams(window_length=512, hop_length=128, window_type='sqrt_hann')

tfm = nussl_tfm.Compose([
    nussl_tfm.SumSources([['bass', 'drums', 'other']]),  # TODO: currently trying to output only the vocals source from the model
    nussl_tfm.MagnitudeSpectrumApproximation(),
    nussl_tfm.IndexSources('source_magnitudes', 1),
    nussl_tfm.ToSeparationModel(),
])

In [5]:
train_folder = "~/.nussl/tutorial/train"
val_folder = "~/.nussl/tutorial/valid"

MAX_MIXTURES = int(1e8) # Set to some impossibly high number for on-the-fly mixing.

train_data = data.on_the_fly(stft_params, transform=tfm, fg_path=train_folder, num_mixtures=MAX_MIXTURES, coherent_prob=1.0)
train_dataloader = torch.utils.data.DataLoader(train_data, num_workers=1, batch_size=10)

val_data = data.on_the_fly(stft_params, transform=tfm, fg_path=val_folder, num_mixtures=10, coherent_prob=1.0)
val_dataloader = torch.utils.data.DataLoader(val_data, num_workers=1, batch_size=10)

In [6]:
test_tfm = nussl_tfm.Compose([
    nussl_tfm.SumSources(  # TODO: currently trying to output only the vocals source from the model
        groupings=[['drums', 'bass', 'other']],
        group_names=['accompaniment'],
    ),
    nussl_tfm.MagnitudeSpectrumApproximation(),
])

test_folder = "~/.nussl/tutorial/test"

test_data = data.on_the_fly(stft_params, transform=test_tfm, fg_path=test_folder, num_mixtures=100)

In [7]:
item = test_data[0]
for key in item:
    print(key, type(item[key]), item[key].shape if isinstance(item[key], np.ndarray) else "")

mix <class 'nussl.core.audio_signal.AudioSignal'> 
sources <class 'collections.OrderedDict'> 
metadata <class 'dict'> 
index <class 'int'> 
mix_magnitude <class 'numpy.ndarray'> (257, 1724, 1)
ideal_binary_mask <class 'numpy.ndarray'> (257, 1724, 1, 2)
source_magnitudes <class 'numpy.ndarray'> (257, 1724, 1, 2)


In [8]:
item = train_data[0]
for key in item:
    print(key, type(item[key]), item[key].shape if isinstance(item[key], torch.Tensor) else "")

index <class 'int'> 
mix_magnitude <class 'torch.Tensor'> torch.Size([1724, 257, 1])
ideal_binary_mask <class 'torch.Tensor'> torch.Size([1724, 257, 1, 2])
source_magnitudes <class 'torch.Tensor'> torch.Size([1724, 257, 1, 1])


## Defining the Model

In [9]:
class StemSeparationModel(nn.Module):
    def __init__(self, num_features, num_audio_channels, hidden_size,
                 num_layers, bidirectional, dropout, num_sources, 
                activation='sigmoid'):
        super().__init__()
        
        self.verbose = False

        self.amplitude_to_db = AmplitudeToDB()
        self.input_normalization = BatchNorm(num_features)
        self.recurrent_stack = RecurrentStack(
            num_features * num_audio_channels, hidden_size, 
            num_layers, bool(bidirectional), dropout
        )
        hidden_size = hidden_size * (int(bidirectional) + 1)
        self.embedding = Embedding(num_features, hidden_size, 
                                   num_sources, activation, 
                                   num_audio_channels)
        
        self.set_up_config(num_features, num_audio_channels, hidden_size,
                 num_layers, bidirectional, dropout, num_sources, 
                activation)

    def set_up_config(self, num_features, num_audio_channels, hidden_size,
                 num_layers, bidirectional, dropout, num_sources, 
                activation='sigmoid'):
        modules = {
            'model': {
                'class': 'StemSeparationModel',
                'args': {
                    'num_features': num_features,
                    'num_audio_channels': num_audio_channels,
                    'hidden_size': hidden_size,
                    'num_layers': num_layers,
                    'bidirectional': bidirectional,
                    'dropout': dropout,
                    'num_sources': num_sources,
                    'activation': activation,
                }
            }
        }

        connections = [
            ['model', ['mix_magnitude']]
        ]

        for key in ['mask', 'vocal_estimate']:
            modules[key] = {'class': 'Alias'}
            connections.append([key, [f'model:{key}']])
        
        output = ['vocal_estimate', 'mask',]
        self.config = {
            'name': 'StemSeparationModel',
            'modules': modules,
            'connections': connections,
            'output': output,
        }
        self.metadata = {
            'config': self.config,
            'nussl_version': '0.0.1',
        }

    def log(self, s):
        if self.verbose:
            print(s)

    def forward(self, item):
        # Get magnitude of mixture signal
        mixture_magnitude = item['mix_magnitude']
        if mixture_magnitude.dim() == 3:
            mixture_magnitude = mixture_magnitude.unsqueeze(0)  # Add a batch dimension to the mixture magnitude if needed
        self.log(f"Shape of mixture_magnitude: {mixture_magnitude.shape}")

        # Convert to log amplitude
        mixture_log_amplitude = self.amplitude_to_db(mixture_magnitude)
        self.log(f"Shape after amplitude to db: {mixture_log_amplitude.shape}")
        
        # Normalize the data
        normalized = self.input_normalization(mixture_log_amplitude)
        self.log(f"Shape after normalization: {normalized.shape}")

        # Pass through LSTM
        output = self.recurrent_stack(normalized)
        self.log(f"Shape after LSTM: {output.shape}")

        # Generate mask
        mask = self.embedding(output)
        self.log(f"Shape of mask: {mask.shape}")
    
        # Apply mask to get estimates
        # TODO: right now this model is defined just to output an estimate of the vocals source - later can expand this to calculate a mask for every source
        vocals_estimate = mixture_magnitude.unsqueeze(-1) * mask
        self.log(f"Shape of vocals estimate: {vocals_estimate.shape}")

        return {
            'mask': mask,
            'vocals_estimate': vocals_estimate,
        }

    def save(self, location, metadata=None, train_data=None, val_data=None, trainer=None):
        torch.save(self, location)
        return location

    # TODO: copied from https://github.com/nussl/nussl/blob/master/nussl/ml/networks/separation_model.py
    def __repr__(self):
        output = super().__repr__()
        num_parameters = 0
        for p in self.parameters():
            if p.requires_grad:
                num_parameters += np.cumprod(p.size())[-1]
        output += '\nNumber of parameters: %d' % num_parameters
        return output

In [10]:
num_features = stft_params.window_length // 2 + 1
num_audio_channels = 1
hidden_size = 50
num_layers = 2
bidirectional = True
dropout = 0.3
num_sources = 1
activation = 'sigmoid'

model = StemSeparationModel(
    num_features=num_features,
    num_audio_channels=num_audio_channels,
    hidden_size=hidden_size,
    num_layers=num_layers,
    bidirectional=bidirectional,
    dropout=dropout,
    num_sources=num_sources,
    activation=activation,
)

In [11]:
model.verbose = True

def process_item(model, item):
    # Convert all tensors in the item dictionary to torch.float32
    for key, tensor in item.items():
        if isinstance(tensor, torch.Tensor):
            item[key] = tensor.to(torch.float32)

    with torch.no_grad():
        output = model(item)

    return output

output = process_item(model, item)
print()
for key in output:
    print(key, type(output[key]), output[key].shape)

model.verbose = False

Shape of mixture_magnitude: torch.Size([1, 1724, 257, 1])
Shape after amplitude to db: torch.Size([1, 1724, 257, 1])
Shape after normalization: torch.Size([1, 1724, 257, 1])
Shape after LSTM: torch.Size([1, 1724, 100])
Shape of mask: torch.Size([1, 1724, 257, 1, 1])
Shape of vocals estimate: torch.Size([1, 1724, 257, 1, 1])

mask <class 'torch.Tensor'> torch.Size([1, 1724, 257, 1, 1])
vocals_estimate <class 'torch.Tensor'> torch.Size([1, 1724, 257, 1, 1])


## Training the Model

In [12]:
# TODO: add a timer here so we can see how long it takes to train for some number of epochs

model.verbose = False

utils.logger()
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nussl.ml.train.loss.L1Loss()

def train_step(engine, batch):
    optimizer.zero_grad()
    output = model(batch) # forward pass
    loss = loss_fn(
        output['vocals_estimate'],
        batch['source_magnitudes']
    )
    
    loss.backward() # backwards + gradient step
    optimizer.step()
    
    loss_vals = {
        'L1Loss': loss.item(),
        'loss': loss.item()
    }
    
    return loss_vals

def val_step(engine, batch):
    with torch.no_grad():
        output = model(batch) # forward pass
    loss = loss_fn(
        output['vocals_estimate'],
        batch['source_magnitudes']
    )    
    loss_vals = {
        'L1Loss': loss.item(), 
        'loss': loss.item()
    }
    return loss_vals

# Create the engines
trainer, validator = nussl.ml.train.create_train_and_validation_engines(
    train_step, val_step, device=DEVICE
)

# We'll save the output relative to this notebook.
output_folder = Path('.').absolute()

# Adding handlers from nussl that print out details about model training
# run the validation step, and save the models.
nussl.ml.train.add_stdout_handler(trainer, validator)
nussl.ml.train.add_validate_and_checkpoint(output_folder, model, 
    optimizer, train_data, trainer, val_dataloader, validator)

trainer.run(
    train_dataloader, 
    epoch_length=10, 
    max_epochs=25,
)

  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer
  from t

State:
	iteration: 250
	epoch: 25
	epoch_length: 10
	max_epochs: 25
	output: <class 'dict'>
	batch: <class 'dict'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
	epoch_history: <class 'dict'>
	iter_history: <class 'dict'>
	past_iter_history: <class 'dict'>
	saved_model_path: /Users/shashankjarmale/Documents/Northeastern/Fall 2024 Semester/CS 5100 Foundations of Artificial Intelligence/Project/Stem-Separator-AMT/stem-separation/checkpoints/best.model.pth
	output_folder: <class 'pathlib.PosixPath'>

## Deployment

In [13]:
checkpoint_path = "checkpoints/best.model.pth"
model = torch.load(checkpoint_path)

  model = torch.load(checkpoint_path)


In [20]:
def process_audio(item, model):    
    # Convert mixture signal to tensor if needed
    if isinstance(item['mix_magnitude'], np.ndarray):
        item['mix_magnitude'] = torch.from_numpy(item['mix_magnitude']).to(DEVICE).float()
    
    # Transpose for model input
    item["mix_magnitude"] = item["mix_magnitude"].transpose(0, 1)
    
    # Get model output (estimate of vocals source)
    with torch.no_grad():
        output = model(item)

    # Process the vocals estimate
    vocals_estimate = output['vocals_estimate']
    if vocals_estimate.dim() == 5:  # Remove extra dimensions
        vocals_estimate = vocals_estimate.squeeze(0).squeeze(-1).squeeze(-1)
    vocals_estimate = vocals_estimate.cpu().data.numpy()

    # Get the original mixture phase
    mix_stft = item['mix'].stft()
    mix_phase = np.angle(mix_stft)

    # Make sure vocals_estimate matches the original mixture phase
    # We want shape to be (freq_bins, time_frames)
    vocals_estimate = vocals_estimate.transpose()
    
    # Match shapes for combining magnitude and phase
    if vocals_estimate.shape[-1] == 1:
        vocals_estimate = vocals_estimate.squeeze(-1)
    if mix_phase.shape[-1] == 1:
        mix_phase = mix_phase.squeeze(-1)

    # Verify shapes of magnitude and phase match exactly
    assert vocals_estimate.shape == mix_phase.shape, f"Shape mismatch: vocals_estimate {vocals_estimate.shape} vs mix_phase {mix_phase.shape}"
    
    # Reconstruct complex STFT
    vocals_estimate_stft = vocals_estimate * np.exp(1j * mix_phase)
    
    # Create new audio signal with the same parameters as the input audio
    new_signal = nussl.AudioSignal(
        stft=vocals_estimate_stft,
        sample_rate=item['mix'].sample_rate,
        stft_params=item['mix'].stft_params
    )

    # Perform inverse STFT
    new_signal.istft()
    
    # Ensure the output length matches the input exactly
    target_length = len(item['mix'].audio_data[0])
    current_length = len(new_signal.audio_data[0])
    
    # Pad the output audio to match the target length, if necessary
    # TODO: there are potentially some bugs we need to fix to prevent this from happening
    if current_length != target_length:
        print(f"WARNING: Length mismatch - target: {target_length}, current: {current_length}")
        if current_length < target_length:
            pad_length = target_length - current_length
            new_signal.audio_data = np.pad(
                new_signal.audio_data, 
                ((0, 0), (0, pad_length)), 
                mode='constant',
            )
        else:
            new_signal.audio_data = new_signal.audio_data[:, :target_length]
    
    print("Final adjusted audio length:", len(new_signal.audio_data[0]))
    return new_signal


item = test_data[0]
new_signal = process_audio(item, model)
new_signal.embed_audio(display=False)

Final adjusted audio length: 220500


ffmpeg version 6.1.1 Copyright (c) 2000-2023 the FFmpeg developers
  built with clang version 14.0.6
  configuration: --prefix=/opt/anaconda3/envs/stem-separation --cc=arm64-apple-darwin20.0.0-clang --ar=arm64-apple-darwin20.0.0-ar --nm=arm64-apple-darwin20.0.0-nm --ranlib=arm64-apple-darwin20.0.0-ranlib --strip=arm64-apple-darwin20.0.0-strip --disable-doc --enable-swresample --enable-swscale --enable-openssl --enable-libxml2 --enable-libtheora --enable-demuxer=dash --enable-postproc --enable-hardcoded-tables --enable-libfreetype --enable-libharfbuzz --enable-libfontconfig --enable-libdav1d --enable-zlib --enable-libaom --enable-pic --enable-shared --disable-static --disable-gpl --enable-version3 --disable-sdl2 --enable-libopenh264 --enable-libopus --enable-libmp3lame --enable-libopenjpeg --enable-libvorbis --enable-pthreads --enable-libtesseract --enable-libvpx --enable-librsvg
  libavutil      58. 29.100 / 58. 29.100
  libavcodec     60. 31.102 / 60. 31.102
  libavformat    60. 16.10

In [21]:
print('true mixture')
display.display(Audio(data=item['mix'].audio_data, rate=item['mix'].sample_rate))

for stem_label in item['sources'].keys():
    print(f"true {stem_label}")
    audio_player = Audio(data=item['sources'][stem_label].audio_data, rate=item['sources'][stem_label].sample_rate)
    display.display(audio_player)

true mixture


true accompaniment


true vocals


## Evaluation