In [1]:
!pip install musdb
# !pip install librosa
# !pip install museval

Collecting musdb
  Downloading musdb-0.4.0-py2.py3-none-any.whl (29 kB)
Collecting stempeg>=0.2.3
  Downloading stempeg-0.2.3-py3-none-any.whl (963 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m963.5/963.5 kB[0m [31m36.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ffmpeg-python>=0.2.0
  Downloading ffmpeg_python-0.2.0-py3-none-any.whl (25 kB)
Installing collected packages: ffmpeg-python, stempeg, musdb
Successfully installed ffmpeg-python-0.2.0 musdb-0.4.0 stempeg-0.2.3
[0m

In [2]:
import os
import random
import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import musdb
import warnings
# import librosa
from tqdm.auto import tqdm
from torchaudio.transforms import Spectrogram, MelSpectrogram, InverseSpectrogram

In [3]:
class MUSDB18Dataset(Dataset):
    def __init__(self, musdb18_root, subset="train", split = "train", sr=44100, duration=2.0, seed = 42):
        self.sr = sr
        self.duration = duration
        self.seed = seed
        if(subset == 'test'):
            self.musdb18 = musdb.DB(root = musdb18_root, subsets=subset)
            print(f"Number of Testing Samples: {len(self.musdb18)}")
        else:
            self.musdb18 = musdb.DB(root = musdb18_root, subsets = subset, split = split)
            if split == "train":
                print(f"Number of Training Samples: {len(self.musdb18)}")
            else:
                print(f"Number of Validation Samples: {len(self.musdb18)}")
        
    def __getitem__(self, index):
        random.seed(self.seed)
        track = self.musdb18.tracks[index]
        track.chunk_duration = self.duration
        track.chunk_start = random.uniform(0, track.duration - track.chunk_duration)
        
        # Load the mixture waveform
        audio = track.audio.T
        audio = audio.mean(axis = 0, keepdims = True)
        audio_tensor = torch.from_numpy(audio).float()
        
        # Load the target waveform (vocals)
        vocals = track.targets["vocals"].audio.T
        vocals = vocals.mean(axis = 0, keepdims = True)
        others = track.targets["accompaniment"].audio.T
        others = others.mean(axis = 0, keepdims = True)
        
        vocals_tensor = torch.from_numpy(np.concatenate((vocals, others), axis = 0)).float()
        
        return audio_tensor, vocals_tensor

    def __len__(self):
        return len(self.musdb18.tracks)

In [4]:
def get_musdb18_dataloaders(root, batch_size=8, num_workers=0, seed = 42):
    train_set = MUSDB18Dataset(root, subset="train")
    val_set = MUSDB18Dataset(root, subset="train", split = 'valid')
    test_set = MUSDB18Dataset(root, subset="test")

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

In [5]:
class TransformerModel(nn.Module):
    def __init__(self, d_model = 512, num_heads = 4, num_layers = 4, dropout = 0.1):
        super().__init__()
        
        # input embedding layer
        self.input_embed = nn.Linear(101, d_model)

        # transformer encoder layers
        encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dropout=dropout)
        self.encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
        
        decoder_layers = nn.TransformerDecoderLayer(d_model=d_model, nhead = num_heads, dropout=dropout)
        self.decoder = nn.TransformerDecoder(decoder_layers, num_layers = num_layers//2)
                
        # output embedding layer
        self.output_embed = nn.Linear(d_model, 202)
        
        self.transform = Spectrogram(n_fft = 200)
        self.invtransform = InverseSpectrogram(n_fft = 200)
        
    def forward(self, x):
        
        length = x.size(-1)
        
        x = self.transform(x)
        x = x.permute(3, 0, 1, 2)
        x = x.reshape(x.size(0), -1, x.size(-1))
        
        # apply input embedding layer
        x = self.input_embed(x)
        
        # apply transformer encoder layers
        x = self.encoder(x)
        memory = x
        
        # apply transformer decoder layers
        x = self.decoder(x, memory)
        
        # apply output embedding layer
        x = self.output_embed(x)
        
        # reshape x back to (batch_size, num_channels, num_frames, num_bins)
        x = x.reshape(x.size(1), 2, -1, x.size(0))
        
        y = self.invtransform(x.type(torch.complex64), length)
        
        return y

In [6]:
def train(model, dataloader, optimizer, loss_fn):
    model.train()
    total_loss = 0

    for i, (audio, vocals) in enumerate(dataloader):
        # move data to GPU if available
        if torch.cuda.is_available():
            audio = audio.cuda()
            vocals = vocals.cuda()

        # forward pass
        prediction = model(audio)
        loss = loss_fn(prediction, vocals)

        # backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update total loss
        total_loss += loss.item()

    # compute average loss
    avg_loss = total_loss / len(dataloader)

    return avg_loss

In [7]:
def validate(model, dataloader, loss_fn):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for i, (audio, vocals) in enumerate(dataloader):
            # move data to GPU if available
            if torch.cuda.is_available():
                audio = audio.cuda()
                vocals = vocals.cuda()

            # forward pass
            prediction = model(audio)
            loss = loss_fn(prediction, vocals)

            # update total loss
            total_loss += loss.item()

    # compute average loss
    avg_loss = total_loss / len(dataloader)

    return avg_loss

In [8]:
def calculate_sdr(predicted_output, ground_truth):
    """
    Calculates the Signal-to-Distortion Ratio (SDR) metric between a predicted output and its corresponding ground truth.

    Args:
    predicted_output: A numpy array of shape (number of channels, number of frames, number of bins).
    ground_truth: A numpy array of shape (number of channels, number of frames, number of bins).

    Returns:
    sdr: A scalar representing the SDR value between predicted_output and ground_truth.
    """
    eps = np.finfo(np.float32).eps  # To avoid division by zero errors
    num_channels = predicted_output.shape[0]
    sdr_sum = 0
    
#     print(predicted_output.shape, ground_truth.shape)
    
    for c in range(num_channels):
        # Compute the power of the true source signal
        true_source_power = np.sum(ground_truth[c]**2)

        # Compute the scalar product between true source signal and predicted signal
        true_pred_scalar = np.sum(ground_truth[c] * predicted_output[c])

        # Compute the SDR for this channel
        sdr = 10 * np.log10(true_source_power / (np.sum(ground_truth[c]**2) - true_pred_scalar + eps) + eps)
        if not math.isnan(sdr):
            sdr_sum += sdr

    # Compute the average SDR across all channels
    sdr = sdr_sum / num_channels

    return -sdr

In [9]:
def test(model, dataloader):
    model.eval()
    all_targets = []
    all_predictions = []
    sdr = 0
    count = 0
    with torch.no_grad():
        for i, (audio, vocals) in tqdm(enumerate(dataloader)):
            # move data to GPU if available
            if torch.cuda.is_available():
                audio = audio.cuda()
                vocals = vocals.cuda()

            # forward pass
            prediction = model(audio)
            
            # convert predictions and ground truth to numpy arrays
            prediction = prediction.cpu().numpy()
            vocals = vocals.cpu().numpy()
            
            for i in range(prediction.shape[0]):
                pred = prediction[i,:,:]
                targ = vocals[i,:,:]
                sdr += calculate_sdr(pred, targ)
                count += 1
    mean_sdr = sdr/count
    return mean_sdr

In [10]:
def fit(model, optimizer, loss_fn, num_epochs, train_dataloader, val_dataloader, test_loader):
    for epoch in tqdm(range(num_epochs)):
        # train model for one epoch
        train_loss = train(model, train_dataloader, optimizer, loss_fn)
        print(f'Epoch {epoch + 1} - Train loss: {train_loss:.4f}')

        # evaluate model on validation data
        val_loss = validate(model, val_dataloader, loss_fn)
        print(f'Epoch {epoch + 1} - Val loss: {val_loss:.4f}')
        print("---------------------------------------------")
        
        if ((epoch+1)%20 == 0):
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=RuntimeWarning)    
                print(f"Epoch: {epoch+1} | Training SDR: {test(model, train_loader)}")
                print(f"Epoch: {epoch+1} | Validation SDR: {test(model, val_loader)}")
                print(f"Epoch: {epoch+1} | Testing SDR: {test(model, test_loader)}")

In [11]:
root = "/kaggle/input/musdb18/musdb18"
train_loader, val_loader, test_loader = get_musdb18_dataloaders(root, batch_size = 2)
# for x, y in val_loader:
#     print(x.shape, y.shape)

Number of Training Samples: 87
Number of Validation Samples: 13
Number of Testing Samples: 50


In [12]:
# define model and optimizer
model = TransformerModel(num_heads = 8, num_layers = 6)
# model = PretrainedTransformer(output_size=257, num_heads = 8, num_layers = 6)
if torch.cuda.is_available():
    model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# define loss function
loss_fn = nn.MSELoss()

fit(model, optimizer, loss_fn, 150, train_loader, val_loader, test_loader)

  0%|          | 0/150 [00:00<?, ?it/s]

Epoch 1 - Train loss: 0.0087
Epoch 1 - Val loss: 0.0081
---------------------------------------------
Epoch 2 - Train loss: 0.0085
Epoch 2 - Val loss: 0.0081
---------------------------------------------
Epoch 3 - Train loss: 0.0085
Epoch 3 - Val loss: 0.0081
---------------------------------------------
Epoch 4 - Train loss: 0.0086
Epoch 4 - Val loss: 0.0081
---------------------------------------------
Epoch 5 - Train loss: 0.0085
Epoch 5 - Val loss: 0.0081
---------------------------------------------
Epoch 6 - Train loss: 0.0086
Epoch 6 - Val loss: 0.0081
---------------------------------------------
Epoch 7 - Train loss: 0.0084
Epoch 7 - Val loss: 0.0081
---------------------------------------------
Epoch 8 - Train loss: 0.0085
Epoch 8 - Val loss: 0.0081
---------------------------------------------
Epoch 9 - Train loss: 0.0084
Epoch 9 - Val loss: 0.0081
---------------------------------------------
Epoch 10 - Train loss: 0.0084
Epoch 10 - Val loss: 0.0081
------------------------

0it [00:00, ?it/s]

Epoch: 20 | Training SDR: 4.381951302589786


0it [00:00, ?it/s]

Epoch: 20 | Validation SDR: 2.8559357632980644


0it [00:00, ?it/s]

Epoch: 20 | Testing SDR: 3.4786836161328436
Epoch 21 - Train loss: 0.0084
Epoch 21 - Val loss: 0.0081
---------------------------------------------
Epoch 22 - Train loss: 0.0084
Epoch 22 - Val loss: 0.0081
---------------------------------------------
Epoch 23 - Train loss: 0.0084
Epoch 23 - Val loss: 0.0081
---------------------------------------------
Epoch 24 - Train loss: 0.0085
Epoch 24 - Val loss: 0.0081
---------------------------------------------
Epoch 25 - Train loss: 0.0084
Epoch 25 - Val loss: 0.0081
---------------------------------------------
Epoch 26 - Train loss: 0.0084
Epoch 26 - Val loss: 0.0081
---------------------------------------------
Epoch 27 - Train loss: 0.0085
Epoch 27 - Val loss: 0.0081
---------------------------------------------
Epoch 28 - Train loss: 0.0085
Epoch 28 - Val loss: 0.0081
---------------------------------------------
Epoch 29 - Train loss: 0.0084
Epoch 29 - Val loss: 0.0081
---------------------------------------------
Epoch 30 - Train los

0it [00:00, ?it/s]

Epoch: 40 | Training SDR: 4.37362010303388


0it [00:00, ?it/s]

Epoch: 40 | Validation SDR: 2.604438680902009


0it [00:00, ?it/s]

Epoch: 40 | Testing SDR: 3.3811675321711108
Epoch 41 - Train loss: 0.0085
Epoch 41 - Val loss: 0.0081
---------------------------------------------
Epoch 42 - Train loss: 0.0084
Epoch 42 - Val loss: 0.0081
---------------------------------------------
Epoch 43 - Train loss: 0.0085
Epoch 43 - Val loss: 0.0081
---------------------------------------------
Epoch 44 - Train loss: 0.0085
Epoch 44 - Val loss: 0.0081
---------------------------------------------
Epoch 45 - Train loss: 0.0085
Epoch 45 - Val loss: 0.0081
---------------------------------------------
Epoch 46 - Train loss: 0.0085
Epoch 46 - Val loss: 0.0081
---------------------------------------------
Epoch 47 - Train loss: 0.0084
Epoch 47 - Val loss: 0.0081
---------------------------------------------
Epoch 48 - Train loss: 0.0084
Epoch 48 - Val loss: 0.0081
---------------------------------------------
Epoch 49 - Train loss: 0.0084
Epoch 49 - Val loss: 0.0081
---------------------------------------------
Epoch 50 - Train los

0it [00:00, ?it/s]

Epoch: 60 | Training SDR: 4.380830574204235


0it [00:00, ?it/s]

Epoch: 60 | Validation SDR: 2.7118411976705548


0it [00:00, ?it/s]

Epoch: 60 | Testing SDR: 3.501611276765067
Epoch 61 - Train loss: 0.0085
Epoch 61 - Val loss: 0.0081
---------------------------------------------
Epoch 62 - Train loss: 0.0084
Epoch 62 - Val loss: 0.0081
---------------------------------------------
Epoch 63 - Train loss: 0.0085
Epoch 63 - Val loss: 0.0081
---------------------------------------------
Epoch 64 - Train loss: 0.0084
Epoch 64 - Val loss: 0.0081
---------------------------------------------
Epoch 65 - Train loss: 0.0084
Epoch 65 - Val loss: 0.0081
---------------------------------------------
Epoch 66 - Train loss: 0.0084
Epoch 66 - Val loss: 0.0081
---------------------------------------------
Epoch 67 - Train loss: 0.0084
Epoch 67 - Val loss: 0.0081
---------------------------------------------
Epoch 68 - Train loss: 0.0085
Epoch 68 - Val loss: 0.0081
---------------------------------------------
Epoch 69 - Train loss: 0.0085
Epoch 69 - Val loss: 0.0081
---------------------------------------------
Epoch 70 - Train loss

0it [00:00, ?it/s]

Epoch: 80 | Training SDR: 4.37803028384051


0it [00:00, ?it/s]

Epoch: 80 | Validation SDR: 2.8329063120176774


0it [00:00, ?it/s]

Epoch: 80 | Testing SDR: 3.4517191176068422
Epoch 81 - Train loss: 0.0084
Epoch 81 - Val loss: 0.0081
---------------------------------------------
Epoch 82 - Train loss: 0.0084
Epoch 82 - Val loss: 0.0081
---------------------------------------------
Epoch 83 - Train loss: 0.0084
Epoch 83 - Val loss: 0.0081
---------------------------------------------
Epoch 84 - Train loss: 0.0085
Epoch 84 - Val loss: 0.0081
---------------------------------------------
Epoch 85 - Train loss: 0.0084
Epoch 85 - Val loss: 0.0081
---------------------------------------------
Epoch 86 - Train loss: 0.0085
Epoch 86 - Val loss: 0.0081
---------------------------------------------
Epoch 87 - Train loss: 0.0084
Epoch 87 - Val loss: 0.0081
---------------------------------------------
Epoch 88 - Train loss: 0.0085
Epoch 88 - Val loss: 0.0081
---------------------------------------------
Epoch 89 - Train loss: 0.0085
Epoch 89 - Val loss: 0.0081
---------------------------------------------
Epoch 90 - Train los

0it [00:00, ?it/s]

Epoch: 100 | Training SDR: 4.382703280385917


0it [00:00, ?it/s]

Epoch: 100 | Validation SDR: 2.851630327315451


0it [00:00, ?it/s]

Epoch: 100 | Testing SDR: 3.5017789044590972
Epoch 101 - Train loss: 0.0084
Epoch 101 - Val loss: 0.0081
---------------------------------------------
Epoch 102 - Train loss: 0.0084
Epoch 102 - Val loss: 0.0081
---------------------------------------------
Epoch 103 - Train loss: 0.0085
Epoch 103 - Val loss: 0.0081
---------------------------------------------
Epoch 104 - Train loss: 0.0084
Epoch 104 - Val loss: 0.0081
---------------------------------------------
Epoch 105 - Train loss: 0.0085
Epoch 105 - Val loss: 0.0081
---------------------------------------------
Epoch 106 - Train loss: 0.0084
Epoch 106 - Val loss: 0.0081
---------------------------------------------
Epoch 107 - Train loss: 0.0084
Epoch 107 - Val loss: 0.0081
---------------------------------------------
Epoch 108 - Train loss: 0.0084
Epoch 108 - Val loss: 0.0081
---------------------------------------------
Epoch 109 - Train loss: 0.0084
Epoch 109 - Val loss: 0.0081
---------------------------------------------
E

0it [00:00, ?it/s]

Epoch: 120 | Training SDR: 4.368817907445449


0it [00:00, ?it/s]

Epoch: 120 | Validation SDR: 2.658878710767567


0it [00:00, ?it/s]

Epoch: 120 | Testing SDR: 3.4365972452315305
Epoch 121 - Train loss: 0.0085
Epoch 121 - Val loss: 0.0081
---------------------------------------------
Epoch 122 - Train loss: 0.0085
Epoch 122 - Val loss: 0.0081
---------------------------------------------
Epoch 123 - Train loss: 0.0084
Epoch 123 - Val loss: 0.0081
---------------------------------------------
Epoch 124 - Train loss: 0.0084
Epoch 124 - Val loss: 0.0081
---------------------------------------------
Epoch 125 - Train loss: 0.0085
Epoch 125 - Val loss: 0.0081
---------------------------------------------
Epoch 126 - Train loss: 0.0085
Epoch 126 - Val loss: 0.0081
---------------------------------------------
Epoch 127 - Train loss: 0.0084
Epoch 127 - Val loss: 0.0081
---------------------------------------------
Epoch 128 - Train loss: 0.0085
Epoch 128 - Val loss: 0.0081
---------------------------------------------
Epoch 129 - Train loss: 0.0085
Epoch 129 - Val loss: 0.0081
---------------------------------------------
E

0it [00:00, ?it/s]

Epoch: 140 | Training SDR: 4.371554973340649


0it [00:00, ?it/s]

Epoch: 140 | Validation SDR: 2.928474809066214


0it [00:00, ?it/s]

Epoch: 140 | Testing SDR: 3.4561650745970267
Epoch 141 - Train loss: 0.0085
Epoch 141 - Val loss: 0.0081
---------------------------------------------
Epoch 142 - Train loss: 0.0085
Epoch 142 - Val loss: 0.0081
---------------------------------------------
Epoch 143 - Train loss: 0.0085
Epoch 143 - Val loss: 0.0081
---------------------------------------------
Epoch 144 - Train loss: 0.0084
Epoch 144 - Val loss: 0.0081
---------------------------------------------
Epoch 145 - Train loss: 0.0085
Epoch 145 - Val loss: 0.0081
---------------------------------------------
Epoch 146 - Train loss: 0.0085
Epoch 146 - Val loss: 0.0081
---------------------------------------------
Epoch 147 - Train loss: 0.0085
Epoch 147 - Val loss: 0.0081
---------------------------------------------
Epoch 148 - Train loss: 0.0084
Epoch 148 - Val loss: 0.0081
---------------------------------------------
Epoch 149 - Train loss: 0.0085
Epoch 149 - Val loss: 0.0081
---------------------------------------------
E

In [13]:
print(f"Training MSE: {validate(model, train_loader, loss_fn)}")
print(f"Validation MSE: {validate(model, val_loader, loss_fn)}")
print(f"Testing MSE: {validate(model, test_loader, loss_fn)}")

Training MSE: 0.008412279839500447
Validation MSE: 0.008107607452464955
Testing MSE: 0.0066387111600488425


In [14]:
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=RuntimeWarning)    
    print(f"Training SDR: {test(model, train_loader)}")
    print(f"Validation SDR: {test(model, val_loader)}")
    print(f"Testing SDR: {test(model, test_loader)}")

0it [00:00, ?it/s]

Training SDR: 4.369476915549454


0it [00:00, ?it/s]

Validation SDR: 2.8677127313681163


0it [00:00, ?it/s]

Testing SDR: 3.4226250170195756
