In [1]:
!pip install musdb
# !pip install librosa
# !pip install museval

Collecting musdb
  Downloading musdb-0.4.0-py2.py3-none-any.whl (29 kB)
Collecting stempeg>=0.2.3
  Downloading stempeg-0.2.3-py3-none-any.whl (963 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m963.5/963.5 kB[0m [31m26.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ffmpeg-python>=0.2.0
  Downloading ffmpeg_python-0.2.0-py3-none-any.whl (25 kB)
Installing collected packages: ffmpeg-python, stempeg, musdb
Successfully installed ffmpeg-python-0.2.0 musdb-0.4.0 stempeg-0.2.3
[0m

In [2]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import musdb
import math
import librosa
from tqdm.auto import tqdm
import torchaudio
from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX

In [3]:
class MUSDB18Dataset(Dataset):
    def __init__(self, musdb18_root, subset="train", split = "train", sr=44100, duration=2, seed = 42):
        self.sr = sr
        self.duration = duration
        self.seed = seed
#         self.transformer = torchaudio.transforms.Spectrogram(n_fft = 600)
        if(subset == 'test'):
            self.musdb18 = musdb.DB(root = musdb18_root, subsets=subset)
            print(f"Number of Testing Samples: {len(self.musdb18)}")
        else:
            self.musdb18 = musdb.DB(root = musdb18_root, subsets = subset, split = split)
            if split == "train":
                print(f"Number of Training Samples: {len(self.musdb18)}")
            else:
                print(f"Number of Validation Samples: {len(self.musdb18)}")
        
    def __getitem__(self, index):
        random.seed(self.seed)
        track = self.musdb18.tracks[index]
        track.chunk_duration = self.duration
        track.chunk_start = random.uniform(0, track.duration - track.chunk_duration)
        
        # Load the mixture waveform
        audio = track.audio.T
        audio = audio.mean(axis = 0, keepdims = True)
        audio_tensor = torch.from_numpy(audio).float()
#         audio_tensor = self.transformer(audio_tensor)
        
        # Load the target waveform (vocals)
        vocals = track.targets["vocals"].audio.T
        vocals = vocals.mean(axis = 0, keepdims = True)
        
        
        others = track.targets["accompaniment"].audio.T
        others = others.mean(axis = 0, keepdims = True)
        
        vocals_tensor = torch.from_numpy(np.concatenate((vocals, others), axis = 0)).float()
        
#       vocals_tensor = self.transformer(vocals_tensor)

        return audio_tensor, vocals_tensor

    def __len__(self):
        return len(self.musdb18.tracks)

In [4]:
def get_musdb18_dataloaders(root, batch_size=8, num_workers=0, seed = 42):
    train_set = MUSDB18Dataset(root, subset="train")
    val_set = MUSDB18Dataset(root, subset="train", split = 'valid')
    test_set = MUSDB18Dataset(root, subset="test")

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

In [5]:
class ConvTasNetWithTransformer(nn.Module):
    def __init__(self, num_sources = 2, hidden_size=128, num_heads=1, num_layers=1):
        super(ConvTasNetWithTransformer, self).__init__()
        
        self.num_sources = num_sources
        
        conv_tasnet = CONVTASNET_BASE_LIBRI2MIX.get_model()
        self.conv_tasnet_encoder = conv_tasnet.encoder
        self.conv_tasnet_maskGenerator = conv_tasnet.mask_generator
        self.conv_tasnet_decoder = conv_tasnet.decoder
        
#         self.input_embed = nn.Linear(512, hidden_size)
#         self.transformer = nn.TransformerEncoderLayer(hidden_size, num_heads, hidden_size, dropout=0.1)
#         self.encoder = nn.TransformerEncoder(self.transformer, num_layers=num_layers)
        
#         self.output_embed = nn.Linear(hidden_size, 512)
        
        for param in self.conv_tasnet_encoder.parameters():
            param.requires_grad = False
            
        for param in self.conv_tasnet_maskGenerator.parameters():
            param.requires_grad = False
        
    def forward(self, mixture):
        # shape of mixture: [batch_size, num_channels, num_samples]
        
#         print(mixture.shape)
        
        # apply ConvTasNet for feature extraction and encoding
        encoded = self.conv_tasnet_encoder(mixture)
#         print(encoded.shape)
        
        encoded = self.conv_tasnet_maskGenerator(encoded)
#         print(encoded.shape)
        
        # permute dimensions for transformer input
#         encoded = encoded.permute(3, 0, 1, 2)
#         encoded = encoded.reshape(encoded.size(0), -1, encoded.size(-1))
#         print(encoded.shape)
        
#         encoded = self.input_embed(encoded)
        
        # apply transformer
#         encoded = self.encoder(encoded)
        
#         decoded = self.output_embed(encoded)
#         print(decoded.shape)
        
#         decoded = decoded.permute(1, 2, 0)
        encoded = encoded.reshape(-1, encoded.size(2), encoded.size(3))
        decoded = self.conv_tasnet_decoder(encoded)
        
        separated = decoded.view(-1, self.num_sources, decoded.size(-1)) 
        
#         print(separated.shape)
        
        # shape of separated: [batch_size, num_sources, num_samples]
        return separated

In [6]:
def train(model, dataloader, optimizer, loss_fn):
    model.train()
    total_loss = 0

    for i, (audio, vocals) in enumerate(dataloader):
        # move data to GPU if available
        if torch.cuda.is_available():
            audio = audio.cuda()
            vocals = vocals.cuda()

        # forward pass
        prediction = model(audio)
        loss = loss_fn(prediction, vocals)

        # backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update total loss
        total_loss += loss.item()

    # compute average loss
    avg_loss = total_loss / len(dataloader)

    return avg_loss

In [7]:
def validate(model, dataloader, loss_fn):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for i, (audio, vocals) in enumerate(dataloader):
            # move data to GPU if available
            if torch.cuda.is_available():
                audio = audio.cuda()
                vocals = vocals.cuda()

            # forward pass
            prediction = model(audio)
            loss = loss_fn(prediction, vocals)

            # update total loss
            total_loss += loss.item()

    # compute average loss
    avg_loss = total_loss / len(dataloader)

    return avg_loss

In [8]:
def calculate_sdr(predicted_output, ground_truth):
    """
    Calculates the Signal-to-Distortion Ratio (SDR) metric between a predicted output and its corresponding ground truth.

    Args:
    predicted_output: A numpy array of shape (number of channels, number of frames, number of bins).
    ground_truth: A numpy array of shape (number of channels, number of frames, number of bins).

    Returns:
    sdr: A scalar representing the SDR value between predicted_output and ground_truth.
    """
    eps = np.finfo(np.float32).eps  # To avoid division by zero errors
    num_channels = predicted_output.shape[0]
    sdr_sum = 0
    
#     print(predicted_output.shape, ground_truth.shape)
    
    for c in range(num_channels):
        # Compute the power of the true source signal
        true_source_power = np.sum(ground_truth[c]**2)

        # Compute the scalar product between true source signal and predicted signal
        true_pred_scalar = np.sum(ground_truth[c] * predicted_output[c])

        # Compute the SDR for this channel
        sdr = 10 * np.log10(true_source_power / (np.sum(ground_truth[c]**2) - true_pred_scalar + eps) + eps)
        if(not math.isnan(sdr)):
            sdr_sum += sdr

    # Compute the average SDR across all channels
    sdr = sdr_sum / num_channels

    return -sdr

In [9]:
def test(model, dataloader):
    model.eval()
    all_targets = []
    all_predictions = []
    sdr = 0
    count = 0
    with torch.no_grad():
        for i, (audio, vocals) in tqdm(enumerate(dataloader)):
            # move data to GPU if available
            if torch.cuda.is_available():
                audio = audio.cuda()
                vocals = vocals.cuda()

            # forward pass
            prediction = model(audio)
            
            # convert predictions and ground truth to numpy arrays
            prediction = prediction.cpu().numpy()
            vocals = vocals.cpu().numpy()
            
            for i in range(prediction.shape[0]):
                pred = prediction[i,:,:]
                targ = vocals[i,:,:]
                sdr += calculate_sdr(pred, targ)
                count += 1
    mean_sdr = sdr/count
    return mean_sdr

In [10]:
def fit(model, optimizer, loss_fn, num_epochs, train_dataloader, val_dataloader):
    for epoch in tqdm(range(num_epochs)):
        # train model for one epoch
        train_loss = train(model, train_dataloader, optimizer, loss_fn)
        print(f'Epoch {epoch + 1} - Train loss: {train_loss:.4f}')

        # evaluate model on validation data
        val_loss = validate(model, val_dataloader, loss_fn)
        print(f'Epoch {epoch + 1} - Val loss: {val_loss:.4f}')
        print("---------------------------------------------")
        
        if ((epoch+1)%20 == 0):
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=RuntimeWarning)    
                print(f"Epoch: {epoch+1} | Training SDR: {test(model, train_loader)}")
                print(f"Epoch: {epoch+1} | Validation SDR: {test(model, val_loader)}")
                print(f"Epoch: {epoch+1} | Testing SDR: {test(model, test_loader)}")

In [11]:
root = "/kaggle/input/musdb18/musdb18"
train_loader, val_loader, test_loader = get_musdb18_dataloaders(root, batch_size = 2)
# for x, y in val_loader:
#     print(x.shape, y.shape)

Number of Training Samples: 87
Number of Validation Samples: 13
Number of Testing Samples: 50


In [12]:
# define model and optimizer
model = ConvTasNetWithTransformer()
# model = PretrainedTransformer(output_size=257, num_heads = 8, num_layers = 6)
if torch.cuda.is_available():
    model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# define loss function
loss_fn = nn.MSELoss()

# fit(model, optimizer, loss_fn, 150, train_loader, val_loader)

100%|██████████| 19.1M/19.1M [00:00<00:00, 229MB/s]


In [13]:
print(f"Training MSE: {validate(model, train_loader, loss_fn)}")
print(f"Validation MSE: {validate(model, val_loader, loss_fn)}")
print(f"Testing MSE: {validate(model, test_loader, loss_fn)}")

Training MSE: 44958257.72727273
Validation MSE: 38945165.14285714
Testing MSE: 46983650.88


In [14]:
import warnings

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=RuntimeWarning)    
    print(f"Training SDR: {test(model, train_loader)}")
    print(f"Validation SDR: {test(model, val_loader)}")
    print(f"Testing SDR: {test(model, test_loader)}")

0it [00:00, ?it/s]

Training SDR: 24.069199849819316


0it [00:00, ?it/s]

Validation SDR: 30.689727067947388


0it [00:00, ?it/s]

Testing SDR: 24.82991266474128


In [15]:
fit(model, optimizer, loss_fn, 150, train_loader, val_loader)

  0%|          | 0/150 [00:00<?, ?it/s]

Epoch 1 - Train loss: 41252673.5909
Epoch 1 - Val loss: 32507081.1429
---------------------------------------------
Epoch 2 - Train loss: 34747185.2273
Epoch 2 - Val loss: 27277924.8571
---------------------------------------------
Epoch 3 - Train loss: 29283852.8182
Epoch 3 - Val loss: 23007441.1429
---------------------------------------------
Epoch 4 - Train loss: 24953053.2727
Epoch 4 - Val loss: 19583535.4286
---------------------------------------------
Epoch 5 - Train loss: 21411251.1818
Epoch 5 - Val loss: 16842906.8571
---------------------------------------------
Epoch 6 - Train loss: 18532320.1591
Epoch 6 - Val loss: 14706359.5714
---------------------------------------------
Epoch 7 - Train loss: 16356138.1818
Epoch 7 - Val loss: 13003115.5714
---------------------------------------------
Epoch 8 - Train loss: 14463191.7500
Epoch 8 - Val loss: 11641475.0000
---------------------------------------------
Epoch 9 - Train loss: 13015231.2273
Epoch 9 - Val loss: 10608737.2857
--

0it [00:00, ?it/s]

Epoch: 20 | Training SDR: 22.738654938922533


0it [00:00, ?it/s]

Epoch: 20 | Validation SDR: 24.976401649988613


0it [00:00, ?it/s]

Epoch: 20 | Testing SDR: 21.690710425376892
Epoch 21 - Train loss: 7723363.6818
Epoch 21 - Val loss: 6944525.1429
---------------------------------------------
Epoch 22 - Train loss: 7582551.9545
Epoch 22 - Val loss: 6852143.4286
---------------------------------------------
Epoch 23 - Train loss: 7458731.9545
Epoch 23 - Val loss: 6762514.9286
---------------------------------------------
Epoch 24 - Train loss: 7343141.4773
Epoch 24 - Val loss: 6675688.4286
---------------------------------------------
Epoch 25 - Train loss: 7244108.6705
Epoch 25 - Val loss: 6594054.5714
---------------------------------------------
Epoch 26 - Train loss: 7130313.6136
Epoch 26 - Val loss: 6517377.8571
---------------------------------------------
Epoch 27 - Train loss: 7036418.4318
Epoch 27 - Val loss: 6443228.0714
---------------------------------------------
Epoch 28 - Train loss: 6962777.9432
Epoch 28 - Val loss: 6367181.7857
---------------------------------------------
Epoch 29 - Train loss: 68566

0it [00:00, ?it/s]

Epoch: 40 | Training SDR: 22.55115480258547


0it [00:00, ?it/s]

Epoch: 40 | Validation SDR: 26.589902731088493


0it [00:00, ?it/s]

Epoch: 40 | Testing SDR: 21.823344230651855
Epoch 41 - Train loss: 5944675.7045
Epoch 41 - Val loss: 5595540.4286
---------------------------------------------
Epoch 42 - Train loss: 5903288.3636
Epoch 42 - Val loss: 5547515.7143
---------------------------------------------
Epoch 43 - Train loss: 5824536.5795
Epoch 43 - Val loss: 5499671.4286
---------------------------------------------
Epoch 44 - Train loss: 5784479.3750
Epoch 44 - Val loss: 5453979.2143
---------------------------------------------
Epoch 45 - Train loss: 5709695.8125
Epoch 45 - Val loss: 5405146.4286
---------------------------------------------
Epoch 46 - Train loss: 5658527.2955
Epoch 46 - Val loss: 5359232.9286
---------------------------------------------
Epoch 47 - Train loss: 5629695.8068
Epoch 47 - Val loss: 5312309.9286
---------------------------------------------
Epoch 48 - Train loss: 5558727.5398
Epoch 48 - Val loss: 5275420.5000
---------------------------------------------
Epoch 49 - Train loss: 56299

0it [00:00, ?it/s]

Epoch: 60 | Training SDR: 21.56274627680066


0it [00:00, ?it/s]

Epoch: 60 | Validation SDR: 25.991660494070786


0it [00:00, ?it/s]

Epoch: 60 | Testing SDR: 21.222721239924432
Epoch 61 - Train loss: 4942544.1705
Epoch 61 - Val loss: 4745767.8214
---------------------------------------------
Epoch 62 - Train loss: 4924103.4830
Epoch 62 - Val loss: 4710657.7857
---------------------------------------------
Epoch 63 - Train loss: 4855959.5227
Epoch 63 - Val loss: 4675832.5000
---------------------------------------------
Epoch 64 - Train loss: 4821005.7273
Epoch 64 - Val loss: 4637873.6429
---------------------------------------------
Epoch 65 - Train loss: 4794307.8068
Epoch 65 - Val loss: 4601077.4286
---------------------------------------------
Epoch 66 - Train loss: 4744832.6591
Epoch 66 - Val loss: 4569636.8571
---------------------------------------------
Epoch 67 - Train loss: 4723316.3295
Epoch 67 - Val loss: 4529254.5000
---------------------------------------------
Epoch 68 - Train loss: 4688787.4261
Epoch 68 - Val loss: 4495046.0714
---------------------------------------------
Epoch 69 - Train loss: 46464

0it [00:00, ?it/s]

Epoch: 80 | Training SDR: 20.28674705275174


0it [00:00, ?it/s]

Epoch: 80 | Validation SDR: 26.70686139510228


0it [00:00, ?it/s]

Epoch: 80 | Testing SDR: 20.46620301157236
Epoch 81 - Train loss: 4180199.9205
Epoch 81 - Val loss: 4065898.8214
---------------------------------------------
Epoch 82 - Train loss: 4166757.6591
Epoch 82 - Val loss: 4034552.2143
---------------------------------------------
Epoch 83 - Train loss: 4116892.0341
Epoch 83 - Val loss: 4011885.4286
---------------------------------------------
Epoch 84 - Train loss: 4203315.0852
Epoch 84 - Val loss: 3978629.1786
---------------------------------------------
Epoch 85 - Train loss: 4068021.2045
Epoch 85 - Val loss: 3949850.2500
---------------------------------------------
Epoch 86 - Train loss: 4032572.5568
Epoch 86 - Val loss: 3916272.5714
---------------------------------------------
Epoch 87 - Train loss: 4009856.6080
Epoch 87 - Val loss: 3890403.5357
---------------------------------------------
Epoch 88 - Train loss: 3968997.6193
Epoch 88 - Val loss: 3860725.8571
---------------------------------------------
Epoch 89 - Train loss: 395406

0it [00:00, ?it/s]

Epoch: 100 | Training SDR: 19.006508184575488


0it [00:00, ?it/s]

Epoch: 100 | Validation SDR: 25.77129508440311


0it [00:00, ?it/s]

Epoch: 100 | Testing SDR: 19.497002917528153
Epoch 101 - Train loss: 3593221.7102
Epoch 101 - Val loss: 3505368.7857
---------------------------------------------
Epoch 102 - Train loss: 3561243.0057
Epoch 102 - Val loss: 3479953.8571
---------------------------------------------
Epoch 103 - Train loss: 3529171.6307
Epoch 103 - Val loss: 3450562.7857
---------------------------------------------
Epoch 104 - Train loss: 3508546.3466
Epoch 104 - Val loss: 3430632.9643
---------------------------------------------
Epoch 105 - Train loss: 3482857.8807
Epoch 105 - Val loss: 3404653.5000
---------------------------------------------
Epoch 106 - Train loss: 3446110.3920
Epoch 106 - Val loss: 3382291.8929
---------------------------------------------
Epoch 107 - Train loss: 3426302.0852
Epoch 107 - Val loss: 3350731.7857
---------------------------------------------
Epoch 108 - Train loss: 3411829.4432
Epoch 108 - Val loss: 3331489.3571
---------------------------------------------
Epoch 109 -

0it [00:00, ?it/s]

Epoch: 120 | Training SDR: 18.347350453379853


0it [00:00, ?it/s]

Epoch: 120 | Validation SDR: 24.678949897105877


0it [00:00, ?it/s]

Epoch: 120 | Testing SDR: 18.50887430459261
Epoch 121 - Train loss: 3102486.1136
Epoch 121 - Val loss: 3035763.0714
---------------------------------------------
Epoch 122 - Train loss: 3102442.8239
Epoch 122 - Val loss: 3014082.7143
---------------------------------------------
Epoch 123 - Train loss: 3050943.5284
Epoch 123 - Val loss: 2996646.4286
---------------------------------------------
Epoch 124 - Train loss: 3041992.4148
Epoch 124 - Val loss: 2978890.6071
---------------------------------------------
Epoch 125 - Train loss: 3013833.7102
Epoch 125 - Val loss: 2956107.2143
---------------------------------------------
Epoch 126 - Train loss: 3002310.4176
Epoch 126 - Val loss: 2936510.3571
---------------------------------------------
Epoch 127 - Train loss: 2962294.6364
Epoch 127 - Val loss: 2918170.1429
---------------------------------------------
Epoch 128 - Train loss: 2939392.0739
Epoch 128 - Val loss: 2894927.5714
---------------------------------------------
Epoch 129 - 

0it [00:00, ?it/s]

Epoch: 140 | Training SDR: 17.370715888067224


0it [00:00, ?it/s]

Epoch: 140 | Validation SDR: 23.053802779087654


0it [00:00, ?it/s]

Epoch: 140 | Testing SDR: 18.05853741765022
Epoch 141 - Train loss: 2688556.5284
Epoch 141 - Val loss: 2654272.4286
---------------------------------------------
Epoch 142 - Train loss: 2690470.3125
Epoch 142 - Val loss: 2639925.1786
---------------------------------------------
Epoch 143 - Train loss: 2654034.7472
Epoch 143 - Val loss: 2619371.0000
---------------------------------------------
Epoch 144 - Train loss: 2657775.6278
Epoch 144 - Val loss: 2601802.7500
---------------------------------------------
Epoch 145 - Train loss: 2617941.0909
Epoch 145 - Val loss: 2590316.2143
---------------------------------------------
Epoch 146 - Train loss: 2600087.3494
Epoch 146 - Val loss: 2570210.1071
---------------------------------------------
Epoch 147 - Train loss: 2604234.0511
Epoch 147 - Val loss: 2554865.0714
---------------------------------------------
Epoch 148 - Train loss: 2577300.4148
Epoch 148 - Val loss: 2537673.1607
---------------------------------------------
Epoch 149 - 

In [16]:
print(f"Training MSE: {validate(model, train_loader, loss_fn)}")
print(f"Validation MSE: {validate(model, val_loader, loss_fn)}")
print(f"Testing MSE: {validate(model, test_loader, loss_fn)}")

Training MSE: 2543205.0255681816
Validation MSE: 2509882.035714286
Testing MSE: 2483696.15


In [17]:
import warnings

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=RuntimeWarning)    
    print(f"Training SDR: {test(model, train_loader)}")
    print(f"Validation SDR: {test(model, val_loader)}")
    print(f"Testing SDR: {test(model, test_loader)}")

0it [00:00, ?it/s]

Training SDR: 17.089233672481843


0it [00:00, ?it/s]

Validation SDR: 22.159432081075813


0it [00:00, ?it/s]

Testing SDR: 18.295340955257416
