In [None]:
!pip install musdb
# !pip install librosa
# !pip install museval

[0m

In [7]:
import os
import random
import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import musdb
# import librosa
from tqdm.auto import tqdm
from torchaudio.transforms import Spectrogram, MelSpectrogram, InverseSpectrogram

In [8]:
class MUSDB18Dataset(Dataset):
    """MUSDB18 tracks served as mono ``(mixture, [vocals, accompaniment])`` excerpts.

    Each item is a ``duration``-second excerpt of one track. Every signal is down-mixed
    to a single channel by averaging over its channel axis.

    Args:
        musdb18_root: path passed to ``musdb.DB(root=...)``.
        subset: musdb subset, ``"train"`` or ``"test"``.
        split: ``"train"`` or ``"valid"`` split of the train subset (not passed for ``"test"``).
        sr: stored as ``self.sr``; not used when loading audio.
        duration: excerpt length in seconds.
        seed: seed for excerpt-offset sampling.
        random_chunks: if True, every call draws a fresh offset (crop augmentation).
            If False, a given index always maps to the same offset. ``None`` (default)
            means True for the training split and False otherwise.

    NOTE(review): with ``num_workers > 0`` each DataLoader worker receives a copy of
    ``self._rng`` in the same state, so workers draw identical offset sequences.
    """

    def __init__(self, musdb18_root, subset="train", split = "train", sr=44100, duration=2.0, seed = 42,
                 random_chunks=None):
        self.sr = sr
        self.duration = duration
        self.seed = seed
        is_training_split = subset == "train" and split == "train"
        self.random_chunks = is_training_split if random_chunks is None else random_chunks
        # Private RNG: sampling neither depends on nor disturbs the global `random` state.
        self._rng = random.Random(seed)
        if subset == 'test':
            self.musdb18 = musdb.DB(root = musdb18_root, subsets=subset)
            print(f"Number of Testing Samples: {len(self.musdb18)}")
        else:
            self.musdb18 = musdb.DB(root = musdb18_root, subsets = subset, split = split)
            if split == "train":
                print(f"Number of Training Samples: {len(self.musdb18)}")
            else:
                print(f"Number of Validation Samples: {len(self.musdb18)}")

    def _chunk_start(self, track, index):
        """Return the excerpt offset in seconds for ``track`` at dataset position ``index``."""
        # Clamp so tracks shorter than `duration` start at 0 instead of a negative offset.
        max_start = max(0.0, track.duration - self.duration)
        if self.random_chunks:
            rng = self._rng
        else:
            # Deterministic per index (string seeds are hashed stably across runs), so
            # evaluation crops are reproducible yet differ between tracks.
            rng = random.Random(f"{self.seed}:{index}")
        return rng.uniform(0, max_start)

    def __getitem__(self, index):
        """Return ``(mixture, targets)`` float tensors of shapes (1, n) and (2, n).

        ``targets`` stacks vocals (row 0) and accompaniment (row 1).
        """
        track = self.musdb18.tracks[index]
        track.chunk_duration = self.duration
        track.chunk_start = self._chunk_start(track, index)

        # `.T` puts channels first; averaging them yields a mono (1, n) signal.
        mixture = track.audio.T.mean(axis = 0, keepdims = True)
        vocals = track.targets["vocals"].audio.T.mean(axis = 0, keepdims = True)
        accompaniment = track.targets["accompaniment"].audio.T.mean(axis = 0, keepdims = True)

        audio_tensor = torch.from_numpy(mixture).float()
        targets_tensor = torch.from_numpy(np.concatenate((vocals, accompaniment), axis = 0)).float()
        return audio_tensor, targets_tensor

    def __len__(self):
        return len(self.musdb18.tracks)

In [9]:
def get_musdb18_dataloaders(root, batch_size=8, num_workers=0, seed = 42):
    """Build train / validation / test DataLoaders over MUSDB18.

    Args:
        root: MUSDB18 root directory passed to each ``MUSDB18Dataset``.
        batch_size: batch size for all three loaders.
        num_workers: DataLoader worker processes.
        seed: forwarded to every dataset (excerpt-offset sampling). It also seeds the
            generator that drives the training-set shuffle order.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader shuffles.
    """
    train_set = MUSDB18Dataset(root, subset="train", seed=seed)
    val_set = MUSDB18Dataset(root, subset="train", split='valid', seed=seed)
    test_set = MUSDB18Dataset(root, subset="test", seed=seed)

    # A dedicated generator makes the shuffle order reproducible from `seed`.
    shuffle_generator = torch.Generator().manual_seed(seed)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                              generator=shuffle_generator)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

In [50]:
class TransformerModel(nn.Module):
    """Transformer that separates a mono mixture into ``num_sources`` waveforms by STFT masking.

    The mixture goes through a complex STFT. Its magnitude frames are embedded and passed
    through a Transformer encoder/decoder, which predicts one sigmoid mask per source and
    frequency bin. The masks multiply the complex mixture STFT, so the mixture phase is
    reused, and the result is inverted back to waveforms.

    Args:
        d_model: Transformer embedding width.
        num_heads: attention heads per layer.
        num_layers: encoder layers; the decoder gets ``num_layers // 2``.
        dropout: dropout inside the Transformer layers.
        n_fft: STFT size (Hann window of the same length, hop ``n_fft // 2``). The default
            of 200 gives 101 frequency bins, matching the previous hard-coded layer sizes.
        num_sources: number of output waveforms (2 = vocals and accompaniment, in the order
            ``MUSDB18Dataset`` stacks its targets).

    NOTE(review): no positional encoding is added, so attention treats the STFT frames as
    an unordered set. Consider adding one.
    """

    def __init__(self, d_model = 512, num_heads = 8, num_layers = 6, dropout = 0.1, n_fft = 200, num_sources = 2):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = n_fft // 2
        self.num_bins = n_fft // 2 + 1
        self.num_sources = num_sources

        # Per-frame magnitude spectrum -> model width.
        self.input_embed = nn.Linear(self.num_bins, d_model)

        # Layers use the default sequence-first layout: (frames, batch, d_model).
        encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dropout=dropout)
        self.encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)

        decoder_layers = nn.TransformerDecoderLayer(d_model=d_model, nhead = num_heads, dropout=dropout)
        self.decoder = nn.TransformerDecoder(decoder_layers, num_layers = num_layers//2)

        # One mask value per (source, frequency bin) for every frame.
        self.output_embed = nn.Linear(d_model, num_sources * self.num_bins)

        # Shared analysis/synthesis window; registered as a buffer so .cuda()/.to() move it.
        self.register_buffer("window", torch.hann_window(n_fft), persistent=False)

    def forward(self, x):
        """Separate ``x`` of shape (batch, 1, time) into (batch, num_sources, time) waveforms.

        Raises:
            ValueError: if ``x`` is not a 3-D tensor with a single channel.
        """
        if x.dim() != 3 or x.size(1) != 1:
            raise ValueError(f"expected mono input of shape (batch, 1, time), got {tuple(x.shape)}")
        batch, _, length = x.shape

        # Complex STFT (keeps phase): (batch, bins, frames).
        spec = torch.stft(x[:, 0], n_fft=self.n_fft, hop_length=self.hop_length, window=self.window,
                          center=True, pad_mode="reflect", return_complex=True)

        # Magnitude frames in sequence-first layout: (frames, batch, bins).
        features = spec.abs().permute(2, 0, 1)
        hidden = self.input_embed(features)

        memory = self.encoder(hidden)
        hidden = self.decoder(memory, memory)

        # (frames, batch, S*bins) -> (batch, S, bins, frames). The frame axis must be moved
        # with a permute; a plain reshape would mix frames across batch items.
        masks = torch.sigmoid(self.output_embed(hidden))
        masks = masks.view(masks.size(0), batch, self.num_sources, self.num_bins).permute(1, 2, 3, 0)

        # Apply the masks to the complex mixture STFT, then invert each source.
        source_specs = masks * spec.unsqueeze(1)
        num_frames = source_specs.size(-1)
        source_specs = source_specs.reshape(batch * self.num_sources, self.num_bins, num_frames)
        waveforms = torch.istft(source_specs, n_fft=self.n_fft, hop_length=self.hop_length,
                                window=self.window, center=True, length=length)
        return waveforms.reshape(batch, self.num_sources, length)

In [51]:
def train(model, dataloader, optimizer, loss_fn):
    """Run one optimisation epoch and return the mean per-sample training loss.

    Args:
        model: module mapping a mixture batch to a prediction comparable to the targets.
        dataloader: iterable of ``(mixture, target)`` tensor batches.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable returning a scalar loss; assumed batch-mean reduced
            (e.g. ``nn.MSELoss()``).

    Returns:
        Batch losses averaged with batch-size weights, or ``nan`` if ``dataloader`` is empty.
    """
    model.train()
    # Send batches to wherever the model lives instead of guessing from CUDA availability.
    device = next(model.parameters()).device
    total_loss = 0.0
    num_samples = 0

    for audio, vocals in dataloader:
        audio = audio.to(device)
        vocals = vocals.to(device)

        prediction = model(audio)
        loss = loss_fn(prediction, vocals)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the epoch mean.
        batch_size = audio.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    return total_loss / num_samples if num_samples else float("nan")

In [52]:
def validate(model, dataloader, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without gradients and return the mean per-sample loss.

    Args:
        model: module mapping a mixture batch to a prediction comparable to the targets.
        dataloader: iterable of ``(mixture, target)`` tensor batches.
        loss_fn: callable returning a scalar loss; assumed batch-mean reduced
            (e.g. ``nn.MSELoss()``).

    Returns:
        Batch losses averaged with batch-size weights, or ``nan`` if ``dataloader`` is empty.
        The model is left in eval mode.
    """
    model.eval()
    # Send batches to wherever the model lives instead of guessing from CUDA availability.
    device = next(model.parameters()).device
    total_loss = 0.0
    num_samples = 0

    with torch.no_grad():
        for audio, vocals in dataloader:
            audio = audio.to(device)
            vocals = vocals.to(device)

            prediction = model(audio)
            loss = loss_fn(prediction, vocals)

            # Weight by batch size so a smaller final batch does not skew the mean.
            batch_size = audio.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size

    return total_loss / num_samples if num_samples else float("nan")

In [53]:
def fit(model, optimizer, loss_fn, num_epochs, train_dataloader, val_dataloader):
    """Alternate one training epoch and one validation pass for ``num_epochs`` epochs.

    Prints both losses after each epoch.

    Returns:
        ``{"train_loss": [...], "val_loss": [...]}``, with one entry per epoch, so the
        curves can be plotted or inspected after training.
    """
    history = {"train_loss": [], "val_loss": []}
    for epoch in tqdm(range(num_epochs)):
        train_loss = train(model, train_dataloader, optimizer, loss_fn)
        print(f'Epoch {epoch + 1} - Train loss: {train_loss:.4f}')

        val_loss = validate(model, val_dataloader, loss_fn)
        print(f'Epoch {epoch + 1} - Val loss: {val_loss:.4f}')
        print("---------------------------------------------")

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
    return history

In [54]:
def calculate_sdr(predicted_output, ground_truth):
    """
    Calculates the Signal-to-Distortion Ratio (SDR), in dB, between a prediction and its ground truth.

    For each channel c (leading axis), with reference s_c and estimate s_hat_c:
        SDR_c = 10 * log10(||s_c||^2 / ||s_c - s_hat_c||^2)
    and the per-channel values are averaged. Higher is better: a perfect prediction gives a
    large positive value and an all-zero prediction gives about 0 dB. This is the plain
    energy ratio, not the filtered BSS Eval (museval) SDR.

    Args:
    predicted_output: array-like whose first axis indexes channels/sources, e.g. waveforms
        of shape (channels, samples). Any remaining axes are flattened.
    ground_truth: array-like with the same shape as predicted_output.

    Returns:
    sdr: float, the mean SDR over channels.

    Raises:
    ValueError: if the shapes differ or there is no channel axis / no channels.
    """
    eps = np.finfo(np.float32).eps  # keeps silent references and perfect predictions finite
    pred = np.asarray(predicted_output, dtype=np.float64)
    true = np.asarray(ground_truth, dtype=np.float64)
    if pred.shape != true.shape:
        raise ValueError(f"shape mismatch: prediction {pred.shape} vs ground truth {true.shape}")
    if true.ndim == 0 or true.shape[0] == 0:
        raise ValueError("inputs must have at least one channel along the first axis")

    num_channels = true.shape[0]
    true = true.reshape(num_channels, -1)
    pred = pred.reshape(num_channels, -1)

    signal_power = np.sum(true ** 2, axis=1)
    # Distortion energy is ||s - s_hat||^2, not ||s||^2 - <s, s_hat>, which can go negative.
    distortion_power = np.sum((true - pred) ** 2, axis=1)
    sdr_per_channel = 10 * np.log10((signal_power + eps) / (distortion_power + eps))

    return float(np.mean(sdr_per_channel))

In [55]:
def test(model, dataloader):
    """Return the mean per-example SDR (via ``calculate_sdr``) of ``model`` over ``dataloader``.

    Each batch is predicted without gradients in eval mode. Every example
    ``prediction[i]`` vs ``target[i]`` is scored separately, and the scores are averaged
    over all examples.

    Returns:
        The mean SDR, or ``nan`` if ``dataloader`` yields no examples.
    """
    model.eval()
    # Send batches to wherever the model lives instead of guessing from CUDA availability.
    device = next(model.parameters()).device
    sdr_total = 0.0
    count = 0
    with torch.no_grad():
        for audio, vocals in tqdm(dataloader, desc="SDR"):
            predictions = model(audio.to(device)).cpu().numpy()
            targets = vocals.cpu().numpy()

            for pred, targ in zip(predictions, targets):
                sdr_total += calculate_sdr(pred, targ)
                count += 1

    return sdr_total / count if count else float("nan")

In [56]:
# Location of the MUSDB18 stems; set the MUSDB18_ROOT environment variable to override
# the Kaggle default when running elsewhere.
root = os.environ.get("MUSDB18_ROOT", "/kaggle/input/musdb18/musdb18")
train_loader, val_loader, test_loader = get_musdb18_dataloaders(root, batch_size = 2)

Number of Training Samples: 87
Number of Validation Samples: 13
Number of Testing Samples: 50


In [57]:
# Seed every RNG this run touches so weight init, dropout and shuffling are reproducible.
SEED = 42
NUM_EPOCHS = 250
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Define the model and optimizer on a single explicit device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerModel(num_heads = 8, num_layers = 6).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Define the loss function.
loss_fn = nn.MSELoss()

fit(model, optimizer, loss_fn, NUM_EPOCHS, train_loader, val_loader)

In [58]:
# Report the trained model's loss (loss_fn, an MSE) on every split; validate() runs in eval mode.
for split_label, split_loader in (
    ("Training", train_loader),
    ("Validation", val_loader),
    ("Testing", test_loader),
):
    print(f"{split_label} MSE: {validate(model, split_loader, loss_fn)}")

In [None]:
# Report the SDR (as computed by calculate_sdr via test()) on every split.
# RuntimeWarnings are deliberately left visible: a log of a non-positive ratio or a NaN
# here signals a broken metric or model output and should not be silenced.
for split_label, split_loader in (
    ("Training", train_loader),
    ("Validation", val_loader),
    ("Testing", test_loader),
):
    print(f"{split_label} SDR: {test(model, split_loader)}")