In [1]:
%%capture
!pip install nussl
!pip install git+https://github.com/source-separation/tutorial

In [140]:
%%capture
from common import data, utils, viz

import json
import copy
import argbind
from pathlib import Path

import nussl
from nussl.ml.networks.modules import AmplitudeToDB, BatchNorm, RecurrentStack, Embedding
from nussl.datasets import transforms as nussl_tfm

import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

## Preparing the Data

In [3]:
# Prepare the MUSDB data under ~/.nussl/tutorial/; later cells read its
# train / valid / test subfolders.
data.prepare_musdb(
    '~/.nussl/tutorial/'
)

100%|██████████| 80/80 [00:24<00:00,  3.23it/s]
100%|██████████| 14/14 [00:04<00:00,  3.24it/s]
100%|██████████| 50/50 [00:15<00:00,  3.20it/s]


In [74]:
# STFT settings shared by every dataset built in this notebook.
stft_params = nussl.STFTParams(
    window_length=512,
    hop_length=128,
    window_type='sqrt_hann',
)

# Training transform pipeline:
#   1. merge the bass/drums/other stems into a single accompaniment source,
#   2. compute per-source magnitude spectrograms,
#   3. keep only source index 1 in 'source_magnitudes',
#   4. convert the item to the dict format consumed by the model.
# NOTE(review): index 1 is assumed to be the vocals source (the only one not
# merged) — confirm the source ordering used by MagnitudeSpectrumApproximation.
# TODO: currently trying to output only the vocals source from the model
tfm = nussl_tfm.Compose(
    [
        nussl_tfm.SumSources([['bass', 'drums', 'other']]),
        nussl_tfm.MagnitudeSpectrumApproximation(),
        nussl_tfm.IndexSources('source_magnitudes', 1),
        nussl_tfm.ToSeparationModel(),
    ]
)

In [75]:
# Folders written by data.prepare_musdb.
train_folder = "~/.nussl/tutorial/train"
val_folder = "~/.nussl/tutorial/valid"

# Set to some impossibly high number for on-the-fly mixing.
MAX_MIXTURES = 10 ** 8

# Training: effectively unlimited on-the-fly mixtures.
train_data = data.on_the_fly(
    stft_params,
    transform=tfm,
    fg_path=train_folder,
    num_mixtures=MAX_MIXTURES,
    coherent_prob=1.0,
)
train_dataloader = torch.utils.data.DataLoader(
    train_data, num_workers=1, batch_size=10
)

# Validation: a small fixed number of on-the-fly mixtures.
val_data = data.on_the_fly(
    stft_params,
    transform=tfm,
    fg_path=val_folder,
    num_mixtures=10,
    coherent_prob=1.0,
)
val_dataloader = torch.utils.data.DataLoader(
    val_data, num_workers=1, batch_size=10
)

In [76]:
# Inspect one training item: each key with its type and, for tensors, its shape.
item = train_data[0]
for key, value in item.items():
    shape = value.shape if isinstance(value, torch.Tensor) else ""
    print(key, type(value), shape)

index <class 'int'> 
mix_magnitude <class 'torch.Tensor'> torch.Size([1724, 257, 1])
ideal_binary_mask <class 'torch.Tensor'> torch.Size([1724, 257, 1, 2])
source_magnitudes <class 'torch.Tensor'> torch.Size([1724, 257, 1, 1])


In [10]:
# Test set: only the accompaniment stems are merged; unlike `tfm`, no magnitude
# or ToSeparationModel transforms are applied here.
test_folder = "~/.nussl/tutorial/test"
test_tfm = nussl_tfm.Compose(
    [
        # TODO: currently trying to extract only the vocals source
        nussl_tfm.SumSources([['bass', 'drums', 'other']]),
    ]
)
test_data = data.on_the_fly(
    stft_params,
    transform=test_tfm,
    fg_path=test_folder,
    num_mixtures=100,
)

## Defining the Model

In [None]:
# TODO: copied from https://github.com/nussl/nussl/blob/master/nussl/ml/networks/separation_model.py


def _remove_cache_from_tfms(transforms):
    """Return a deep copy of ``transforms`` with every ``Cache`` transform removed.

    Only a ``nussl.datasets.transforms.Compose`` is filtered; any other object is
    returned as an unmodified deep copy. The argument itself is never mutated.
    """
    transforms = copy.deepcopy(transforms)

    if isinstance(transforms, nussl.datasets.transforms.Compose):
        # Rebuild the list rather than calling .remove() while iterating over it,
        # which would skip the element following each removed one (e.g. two
        # adjacent Cache transforms would leave the second one in place).
        transforms.transforms = [
            t for t in transforms.transforms
            if not isinstance(t, nussl.datasets.transforms.Cache)
        ]

    return transforms


def _prep_metadata(metadata):
    """Helper function for preparing metadata before saving a model.
    """
    metadata = copy.deepcopy(metadata)
    if 'transforms' in metadata:
        metadata['transforms'] = _remove_cache_from_tfms(metadata['transforms'])
    return metadata

In [None]:
class StemSeparationModel(nn.Module):
    """Mask-inference network that estimates a vocals magnitude spectrogram.

    ``forward`` converts the mixture magnitude to log amplitude, normalizes it,
    runs it through a recurrent stack and uses an ``Embedding`` head to produce
    a mask that is multiplied with the mixture magnitude.

    ``config`` / ``metadata`` mirror the layout used by nussl's
    ``SeparationModel`` so ``save`` can write a nussl-style checkpoint.

    Args:
        num_features: Number of frequency bins per frame.
        num_audio_channels: Number of audio channels in the mixture.
        hidden_size: Hidden size passed to ``RecurrentStack``; the mask head
            sees twice this width when ``bidirectional`` is true.
        num_layers: Number of recurrent layers.
        bidirectional: Whether the recurrent stack is bidirectional.
        dropout: Dropout passed to ``RecurrentStack``.
        num_sources: Number of masks produced by the ``Embedding`` head.
        activation: Activation passed to the ``Embedding`` head.
    """

    def __init__(self, num_features, num_audio_channels, hidden_size,
                 num_layers, bidirectional, dropout, num_sources,
                 activation='sigmoid'):
        super().__init__()

        # When True, forward() prints intermediate tensor shapes.
        self.verbose = False

        self.amplitude_to_db = AmplitudeToDB()
        self.input_normalization = BatchNorm(num_features)
        self.recurrent_stack = RecurrentStack(
            num_features * num_audio_channels, hidden_size,
            num_layers, bool(bidirectional), dropout
        )
        # A bidirectional stack doubles the feature width seen by the mask head.
        # Keep it in a separate name so the config below records the original
        # constructor argument and can rebuild an identical model.
        embedding_input_size = hidden_size * (int(bidirectional) + 1)
        self.embedding = Embedding(num_features, embedding_input_size,
                                   num_sources, activation,
                                   num_audio_channels)

        self.set_up_config(num_features, num_audio_channels, hidden_size,
                           num_layers, bidirectional, dropout, num_sources,
                           activation)

    def set_up_config(self, num_features, num_audio_channels, hidden_size,
                      num_layers, bidirectional, dropout, num_sources,
                      activation='sigmoid'):
        """Build the nussl-style ``self.config`` and ``self.metadata`` dicts.

        The constructor arguments are recorded verbatim under
        ``config['modules']['model']['args']``. The declared outputs use the
        same names as the keys returned by ``forward``.
        """
        modules = {
            'model': {
                'class': 'StemSeparationModel',
                'args': {
                    'num_features': num_features,
                    'num_audio_channels': num_audio_channels,
                    'hidden_size': hidden_size,
                    'num_layers': num_layers,
                    'bidirectional': bidirectional,
                    'dropout': dropout,
                    'num_sources': num_sources,
                    'activation': activation,
                }
            }
        }

        connections = [
            ['model', ['mix_magnitude']]
        ]

        # Expose each forward() output under its own name through an Alias module.
        for key in ['mask', 'vocals_estimate']:
            modules[key] = {'class': 'Alias'}
            connections.append([key, [f'model:{key}']])

        output = ['vocals_estimate', 'mask']
        self.config = {
            'name': 'StemSeparationModel',
            'modules': modules,
            'connections': connections,
            'output': output,
        }
        self.metadata = {
            'config': self.config,
            # NOTE(review): hard-coded string, not the installed nussl version.
            'nussl_version': '0.0.1',
        }

    def log(self, s):
        """Print ``s`` only when ``self.verbose`` is set."""
        if self.verbose:
            print(s)

    def forward(self, item):
        """Estimate the vocals magnitude from ``item['mix_magnitude']``.

        Args:
            item: dict containing ``'mix_magnitude'``. A 3-D tensor is treated as
                a single unbatched example and gets a leading batch dimension
                (the inspection cell shows (batch, frames, freq, channels) —
                TODO confirm for other callers).

        Returns:
            dict with ``'mask'`` (output of the embedding head) and
            ``'vocals_estimate'`` (mixture magnitude with a trailing source axis,
            multiplied by the mask).
        """
        # Get magnitude of mixture signal
        mixture_magnitude = item['mix_magnitude']
        if mixture_magnitude.dim() == 3:
            # Add a batch dimension to the mixture magnitude if needed
            mixture_magnitude = mixture_magnitude.unsqueeze(0)
        self.log(f"Shape of mixture_magnitude: {mixture_magnitude.shape}")

        # Convert to log amplitude
        mixture_log_amplitude = self.amplitude_to_db(mixture_magnitude)
        self.log(f"Shape after amplitude to db: {mixture_log_amplitude.shape}")

        # Normalize the data
        normalized = self.input_normalization(mixture_log_amplitude)
        self.log(f"Shape after normalization: {normalized.shape}")

        # Pass through LSTM
        output = self.recurrent_stack(normalized)
        self.log(f"Shape after LSTM: {output.shape}")

        # Generate mask
        mask = self.embedding(output)
        self.log(f"Shape of mask: {mask.shape}")

        # Apply mask to get estimates; the trailing axis broadcasts the mixture
        # against each mask.
        # TODO: right now this model is defined just to output an estimate of the vocals source - later can expand this to calculate a mask for every source
        vocals_estimate = mixture_magnitude.unsqueeze(-1) * mask
        self.log(f"Shape of vocals estimate: {vocals_estimate.shape}")

        return {
            'mask': mask,
            'vocals_estimate': vocals_estimate,
        }

    # Adapted from https://github.com/nussl/nussl/blob/master/nussl/ml/networks/separation_model.py
    def save(self, location, metadata=None, train_data=None,
             val_data=None, trainer=None):
        """
        Saves the model's weights, configuration and metadata to ``location``.

        Args:
            location: (str or file-like) Where the model is saved; passed to
                ``torch.save``.
            metadata: (dict) Additional metadata to save along with the model. It
                is copied, not modified. Model config and ``nussl_version`` from
                ``self.metadata`` are always added and take precedence.
            train_data: (BaseDataset) Dataset used for training. Its STFT params,
                sample rate, channel count and metadata are saved alongside the
                model.
            val_data: (BaseDataset) Dataset used for validation. Its metadata is
                saved if it has a non-None ``metadata`` attribute; otherwise
                (e.g. ``None`` or a DataLoader) it is ignored.
            trainer: (ignite.Engine) Engine used for training. Selected fields of
                its state are saved alongside the model.

        Returns:
            The ``location`` argument.
        """
        save_dict = {
            'state_dict': self.state_dict(),
            'config': json.dumps(self.config)
        }

        # Copy so the caller's dict is not modified by the updates below.
        metadata = dict(metadata) if metadata else {}
        metadata.update(self.metadata)

        if train_data is not None:
            dataset_metadata = {
                'stft_params': train_data.stft_params,
                'sample_rate': train_data.sample_rate,
                'num_channels': train_data.num_channels,
                'train_dataset': _prep_metadata(train_data.metadata),
            }
            metadata.update(dataset_metadata)

        # Best effort: only record validation metadata when it is available,
        # instead of swallowing every possible exception.
        val_metadata = getattr(val_data, 'metadata', None)
        if val_metadata is not None:
            metadata['val_dataset'] = _prep_metadata(val_metadata)

        if trainer is not None:
            train_metadata = {
                'trainer.state_dict': {
                    'epoch': trainer.state.epoch,
                    'epoch_length': trainer.state.epoch_length,
                    'max_epochs': trainer.state.max_epochs,
                    'output': trainer.state.output,
                    'metrics': trainer.state.metrics,
                    'seed': trainer.state.seed,
                },
                'trainer.state.epoch_history': trainer.state.epoch_history,
            }
            metadata.update(train_metadata)

        save_dict = {**save_dict, 'metadata': metadata}
        torch.save(save_dict, location)
        return location

    # Adapted from https://github.com/nussl/nussl/blob/master/nussl/ml/networks/separation_model.py
    def __repr__(self):
        """Standard module repr followed by the number of trainable parameters."""
        output = super().__repr__()
        # numel() also handles 0-dim parameters, for which
        # np.cumprod(p.size())[-1] raises IndexError.
        num_parameters = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        output += '\nNumber of parameters: %d' % num_parameters
        return output

In [137]:
# Model hyperparameters.
num_features = stft_params.window_length // 2 + 1  # frequency bins per frame (257 here)
num_audio_channels = 1
hidden_size = 50
num_layers = 2
bidirectional = True
dropout = 0.3
num_sources = 1  # a single mask: the vocals estimate
activation = 'sigmoid'

# Arguments are passed positionally in the constructor's parameter order.
model = StemSeparationModel(
    num_features, num_audio_channels, hidden_size, num_layers,
    bidirectional, dropout, num_sources, activation,
)

In [138]:
# Print intermediate tensor shapes during this one-off inspection pass.
model.verbose = True


def process_item(model, item):
    """Run ``model`` once on a single dataset item without tracking gradients.

    Tensors in ``item`` are cast to float32 in a new dict, so the caller's item
    is left unchanged. The model is put in eval mode for the pass, so dropout is
    off and normalization layers do not update their running statistics. Its
    previous train/eval mode is restored afterwards, even if the pass raises.

    Args:
        model: the ``nn.Module`` to run.
        item: dict produced by the dataset; tensor values are cast to float32.

    Returns:
        Whatever ``model`` returns for the item.
    """
    float_item = {
        key: value.to(torch.float32) if isinstance(value, torch.Tensor) else value
        for key, value in item.items()
    }

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(float_item)
    finally:
        model.train(was_training)

    return output


output = process_item(model, item)
print()
for key in output:
    print(key, type(output[key]), output[key].shape)

model.verbose = False

Shape of mixture_magnitude: torch.Size([1, 1724, 257, 1])
Shape after amplitude to db: torch.Size([1, 1724, 257, 1])
Shape after normalization: torch.Size([1, 1724, 257, 1])
Shape after LSTM: torch.Size([1, 1724, 100])
Shape of mask: torch.Size([1, 1724, 257, 1, 1])
Shape of vocals estimate: torch.Size([1, 1724, 257, 1, 1])

mask <class 'torch.Tensor'> torch.Size([1, 1724, 257, 1, 1])
vocals_estimate <class 'torch.Tensor'> torch.Size([1, 1724, 257, 1, 1])


## Training the Model

In [141]:
model.verbose = False

utils.logger()
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# The engines below are created with device=DEVICE (presumably so batches are
# moved there), so the model's parameters must live on the same device. Move it
# before building the optimizer so Adam is set up with the on-device parameters.
model = model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nussl.ml.train.loss.L1Loss()

def train_step(engine, batch):
    """Run one optimization step on ``batch`` and report its L1 loss.

    Args:
        engine: the ignite engine invoking this step (unused).
        batch: dict with at least 'mix_magnitude' and 'source_magnitudes'.

    Returns:
        dict with the scalar batch loss under both 'L1Loss' and 'loss'.
    """
    # Make sure dropout and batch-statistics normalization are active, even if
    # an earlier cell or a validation pass left the model in eval mode.
    model.train()

    optimizer.zero_grad()
    output = model(batch)  # forward pass
    loss = loss_fn(
        output['vocals_estimate'],
        batch['source_magnitudes']
    )

    loss.backward()  # backwards + gradient step
    optimizer.step()

    loss_value = loss.item()
    return {
        'L1Loss': loss_value,
        'loss': loss_value,
    }

def val_step(engine, batch):
    """Compute the L1 loss on a validation ``batch`` without updating the model.

    The model runs in eval mode (no dropout, no running-statistics updates from
    validation data); its previous train/eval mode is restored afterwards, even
    if the forward pass raises.

    Args:
        engine: the ignite engine invoking this step (unused).
        batch: dict with at least 'mix_magnitude' and 'source_magnitudes'.

    Returns:
        dict with the scalar batch loss under both 'L1Loss' and 'loss'.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(batch)  # forward pass
            loss = loss_fn(
                output['vocals_estimate'],
                batch['source_magnitudes']
            )
    finally:
        model.train(was_training)

    loss_value = loss.item()
    return {
        'L1Loss': loss_value,
        'loss': loss_value,
    }

# Wrap the step functions in training / validation engines.
trainer, validator = nussl.ml.train.create_train_and_validation_engines(
    train_step,
    val_step,
    device=DEVICE,
)

# Checkpoints are written relative to this notebook's working directory.
output_folder = Path('.').absolute()

# nussl handlers: print training progress, run validation and save the models.
nussl.ml.train.add_stdout_handler(trainer, validator)
nussl.ml.train.add_validate_and_checkpoint(
    output_folder,
    model,
    optimizer,
    train_data,
    trainer,
    val_dataloader,
    validator,
)

# Short run: one epoch of 10 iterations.
trainer.run(train_dataloader, epoch_length=10, max_epochs=1)

  from torch.distributed.optim import ZeroRedundancyOptimizer
  from torch.distributed.optim import ZeroRedundancyOptimizer


State:
	iteration: 10
	epoch: 1
	epoch_length: 10
	max_epochs: 1
	output: <class 'dict'>
	batch: <class 'dict'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
	epoch_history: <class 'dict'>
	iter_history: <class 'dict'>
	past_iter_history: <class 'dict'>
	saved_model_path: /Users/shashankjarmale/Documents/Northeastern/Fall 2024 Semester/CS 5100 Foundations of Artificial Intelligence/Project/Stem-Separator-AMT/stem-separation/checkpoints/best.model.pth
	output_folder: <class 'pathlib.PosixPath'>

## Deployment