In [1]:
!pip install musdb
# !pip install librosa
# !pip install museval

Collecting musdb
  Downloading musdb-0.4.0-py2.py3-none-any.whl (29 kB)
Collecting stempeg>=0.2.3
  Downloading stempeg-0.2.3-py3-none-any.whl (963 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m963.5/963.5 kB[0m [31m23.4 MB/s[0m eta [36m0:00:00[0m
Collecting ffmpeg-python>=0.2.0
  Downloading ffmpeg_python-0.2.0-py3-none-any.whl (25 kB)
Installing collected packages: ffmpeg-python, stempeg, musdb
Successfully installed ffmpeg-python-0.2.0 musdb-0.4.0 stempeg-0.2.3
[0m

In [2]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import musdb
import math
import librosa
from tqdm.auto import tqdm
import torchaudio
from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX

In [3]:
class MUSDB18Dataset(Dataset):
    def __init__(self, musdb18_root, subset="train", split = "train", sr=44100, duration=2, seed = 42):
        self.sr = sr
        self.duration = duration
        self.seed = seed
#         self.transformer = torchaudio.transforms.Spectrogram(n_fft = 600)
        if(subset == 'test'):
            self.musdb18 = musdb.DB(root = musdb18_root, subsets=subset)
            print(f"Number of Testing Samples: {len(self.musdb18)}")
        else:
            self.musdb18 = musdb.DB(root = musdb18_root, subsets = subset, split = split)
            if split == "train":
                print(f"Number of Training Samples: {len(self.musdb18)}")
            else:
                print(f"Number of Validation Samples: {len(self.musdb18)}")
        
    def __getitem__(self, index):
        random.seed(self.seed)
        track = self.musdb18.tracks[index]
        track.chunk_duration = self.duration
        track.chunk_start = random.uniform(0, track.duration - track.chunk_duration)
        
        # Load the mixture waveform
        audio = track.audio.T
        audio = audio.mean(axis = 0, keepdims = True)
        audio_tensor = torch.from_numpy(audio).float()
#         audio_tensor = self.transformer(audio_tensor)
        
        # Load the target waveform (vocals)
        vocals = track.targets["vocals"].audio.T
        vocals = vocals.mean(axis = 0, keepdims = True)
        
        
        others = track.targets["accompaniment"].audio.T
        others = others.mean(axis = 0, keepdims = True)
        
        vocals_tensor = torch.from_numpy(np.concatenate((vocals, others), axis = 0)).float()
        
#       vocals_tensor = self.transformer(vocals_tensor)

        return audio_tensor, vocals_tensor

    def __len__(self):
        return len(self.musdb18.tracks)

In [4]:
def get_musdb18_dataloaders(root, batch_size=8, num_workers=0, seed = 42):
    train_set = MUSDB18Dataset(root, subset="train")
    val_set = MUSDB18Dataset(root, subset="train", split = 'valid')
    test_set = MUSDB18Dataset(root, subset="test")

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

In [5]:
class ConvTasNetWithTransformer(nn.Module):
    def __init__(self, num_sources = 2, hidden_size=128, num_heads=1, num_layers=1):
        super(ConvTasNetWithTransformer, self).__init__()
        
        self.num_sources = num_sources
        
        conv_tasnet = CONVTASNET_BASE_LIBRI2MIX.get_model()
        self.conv_tasnet_encoder = conv_tasnet.encoder
        self.conv_tasnet_maskGenerator = conv_tasnet.mask_generator
        self.conv_tasnet_decoder = conv_tasnet.decoder
        
        self.input_embed = nn.Linear(512, hidden_size)
        self.transformer = nn.TransformerEncoderLayer(hidden_size, num_heads, hidden_size, dropout=0.1)
        self.encoder = nn.TransformerEncoder(self.transformer, num_layers=num_layers)
        
        self.output_embed = nn.Linear(hidden_size, 512)
        
        for param in self.conv_tasnet_encoder.parameters():
            param.requires_grad = False
            
        for param in self.conv_tasnet_maskGenerator.parameters():
            param.requires_grad = False
        
    def forward(self, mixture):
        # shape of mixture: [batch_size, num_channels, num_samples]
        
#         print(mixture.shape)
        
        # apply ConvTasNet for feature extraction and encoding
        encoded = self.conv_tasnet_encoder(mixture)
#         print(encoded.shape)
        
        encoded = self.conv_tasnet_maskGenerator(encoded)
#         print(encoded.shape)
        
        # permute dimensions for transformer input
        encoded = encoded.permute(3, 0, 1, 2)
        encoded = encoded.reshape(encoded.size(0), -1, encoded.size(-1))
#         print(encoded.shape)
        
        encoded = self.input_embed(encoded)
        
        # apply transformer
        encoded = self.encoder(encoded)
        
        decoded = self.output_embed(encoded)
#         print(decoded.shape)
        
        decoded = decoded.permute(1, 2, 0)
        decoded = self.conv_tasnet_decoder(decoded)
        
        separated = decoded.view(-1, self.num_sources, decoded.size(-1)) 
        
#         print(separated.shape)
        
        # shape of separated: [batch_size, num_sources, num_samples]
        return separated

In [6]:
def train(model, dataloader, optimizer, loss_fn):
    model.train()
    total_loss = 0

    for i, (audio, vocals) in enumerate(dataloader):
        # move data to GPU if available
        if torch.cuda.is_available():
            audio = audio.cuda()
            vocals = vocals.cuda()

        # forward pass
        prediction = model(audio)
        loss = loss_fn(prediction, vocals)

        # backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update total loss
        total_loss += loss.item()

    # compute average loss
    avg_loss = total_loss / len(dataloader)

    return avg_loss

In [7]:
def validate(model, dataloader, loss_fn):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for i, (audio, vocals) in enumerate(dataloader):
            # move data to GPU if available
            if torch.cuda.is_available():
                audio = audio.cuda()
                vocals = vocals.cuda()

            # forward pass
            prediction = model(audio)
            loss = loss_fn(prediction, vocals)

            # update total loss
            total_loss += loss.item()

    # compute average loss
    avg_loss = total_loss / len(dataloader)

    return avg_loss

In [8]:
def fit(model, optimizer, loss_fn, num_epochs, train_dataloader, val_dataloader):
    for epoch in tqdm(range(num_epochs)):
        # train model for one epoch
        train_loss = train(model, train_dataloader, optimizer, loss_fn)
        print(f'Epoch {epoch + 1} - Train loss: {train_loss:.4f}')

        # evaluate model on validation data
        val_loss = validate(model, val_dataloader, loss_fn)
        print(f'Epoch {epoch + 1} - Val loss: {val_loss:.4f}')
        print("---------------------------------------------")

In [9]:
def calculate_sdr(predicted_output, ground_truth):
    """
    Calculates the Signal-to-Distortion Ratio (SDR) metric between a predicted output and its corresponding ground truth.

    Args:
    predicted_output: A numpy array of shape (number of channels, number of frames, number of bins).
    ground_truth: A numpy array of shape (number of channels, number of frames, number of bins).

    Returns:
    sdr: A scalar representing the SDR value between predicted_output and ground_truth.
    """
    eps = np.finfo(np.float32).eps  # To avoid division by zero errors
    num_channels = predicted_output.shape[0]
    sdr_sum = 0
    
#     print(predicted_output.shape, ground_truth.shape)
    
    for c in range(num_channels):
        # Compute the power of the true source signal
        true_source_power = np.sum(ground_truth[c]**2)

        # Compute the scalar product between true source signal and predicted signal
        true_pred_scalar = np.sum(ground_truth[c] * predicted_output[c])

        # Compute the SDR for this channel
        sdr = 10 * np.log10(true_source_power / (np.sum(ground_truth[c]**2) - true_pred_scalar + eps) + eps)
        if(not math.isnan(sdr)):
            sdr_sum += sdr

    # Compute the average SDR across all channels
    sdr = sdr_sum / num_channels

    return -sdr

In [10]:
def test(model, dataloader):
    model.eval()
    all_targets = []
    all_predictions = []
    sdr = 0
    count = 0
    with torch.no_grad():
        for i, (audio, vocals) in tqdm(enumerate(dataloader)):
            # move data to GPU if available
            if torch.cuda.is_available():
                audio = audio.cuda()
                vocals = vocals.cuda()

            # forward pass
            prediction = model(audio)
            
            # convert predictions and ground truth to numpy arrays
            prediction = prediction.cpu().numpy()
            vocals = vocals.cpu().numpy()
            
            for i in range(prediction.shape[0]):
                pred = prediction[i,:,:]
                targ = vocals[i,:,:]
                sdr += calculate_sdr(pred, targ)
                count += 1
    mean_sdr = sdr/count
    return mean_sdr

In [11]:
root = "/kaggle/input/musdb18/musdb18"
train_loader, val_loader, test_loader = get_musdb18_dataloaders(root, batch_size = 2)
# for x, y in val_loader:
#     print(x.shape, y.shape)

Number of Training Samples: 87
Number of Validation Samples: 13
Number of Testing Samples: 50


In [12]:
# define model and optimizer
model = ConvTasNetWithTransformer()
# model = PretrainedTransformer(output_size=257, num_heads = 8, num_layers = 6)
if torch.cuda.is_available():
    model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# define loss function
loss_fn = nn.MSELoss()

fit(model, optimizer, loss_fn, 150, train_loader, val_loader)

  0%|          | 0.00/19.1M [00:00<?, ?B/s]

  0%|          | 0/150 [00:00<?, ?it/s]

Epoch 1 - Train loss: 111.2471
Epoch 1 - Val loss: 32.6345
---------------------------------------------
Epoch 2 - Train loss: 30.2429
Epoch 2 - Val loss: 14.3192
---------------------------------------------
Epoch 3 - Train loss: 19.4998
Epoch 3 - Val loss: 9.4439
---------------------------------------------
Epoch 4 - Train loss: 14.6934
Epoch 4 - Val loss: 6.9453
---------------------------------------------
Epoch 5 - Train loss: 11.8330
Epoch 5 - Val loss: 5.7341
---------------------------------------------
Epoch 6 - Train loss: 9.8390
Epoch 6 - Val loss: 4.7713
---------------------------------------------
Epoch 7 - Train loss: 8.4771
Epoch 7 - Val loss: 4.0259
---------------------------------------------
Epoch 8 - Train loss: 7.1362
Epoch 8 - Val loss: 3.3932
---------------------------------------------
Epoch 9 - Train loss: 6.1843
Epoch 9 - Val loss: 3.2421
---------------------------------------------
Epoch 10 - Train loss: 5.5622
Epoch 10 - Val loss: 2.6208
----------------

In [13]:
print(f"Training MSE: {validate(model, train_loader, loss_fn)}")
print(f"Validation MSE: {validate(model, val_loader, loss_fn)}")
print(f"Testing MSE: {validate(model, test_loader, loss_fn)}")

Training MSE: 0.027541908520189198
Validation MSE: 0.026301669755152295
Testing MSE: 0.03057378761470318


In [14]:
import warnings

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=RuntimeWarning)    
    print(f"Training SDR: {test(model, train_loader)}")
    print(f"Validation SDR: {test(model, val_loader)}")
    print(f"Testing SDR: {test(model, test_loader)}")

0it [00:00, ?it/s]

Training SDR: 4.347351231621799


0it [00:00, ?it/s]

Validation SDR: 3.3966928381526555


0it [00:00, ?it/s]

Testing SDR: 4.090461655007675
