Installs

In [None]:
# Install the third-party packages used below (musdb, museval, stempeg).
# ffmpeg is presumably needed by stempeg/musdb to decode the stem audio files
# — confirm against the musdb documentation.
!pip install musdb
!pip install museval
!apt install -y ffmpeg
!pip install stempeg

Imports

In [None]:
import math
import torch as th
from torch import nn
from torch.utils.data import DataLoader
import musdb
import museval
from IPython.display import Audio, display
from pathlib import Path
from google.colab import drive
import numpy as np
import random
import os

Copy dataset

In [None]:
'''
I've downloaded and extracted the musdb dataset from 
https://zenodo.org/record/1117372#.YKTWk5MzbAI into my google drive.
Here I'm copying it inside the colab VM for faster loading in the rest of the
code.
'''
# Mount Google Drive; checkpoints (model_path) and estimates (estimates_path)
# are also written under drive/MyDrive.
drive.mount("/content/drive")
# Copy the dataset to /content/musdb on the VM. Later cells open Path("musdb"),
# which assumes the working directory is /content (Colab default — confirm).
!rsync -rah --info=progress2 /content/drive/MyDrive/musdb /content

Parameters

In [None]:
# Seed for the random generators
seed = 42

# Parameters of the neural network
channels = 64          # output channels of the first encoder level
depth = 6              # number of encoder (and decoder) levels
encoder_kernel_1 = 8   # strided Conv1d of each encoder level
encoder_stride_1 = 4
encoder_kernel_2 = 1   # second Conv1d of each encoder level (feeds the GLU)
encoder_stride_2 = 1
decoder_kernel_1 = 3   # first Conv1d of each decoder level (feeds the GLU)
decoder_stride_1 = 1
decoder_kernel_2 = 8   # ConvTranspose1d of each decoder level
decoder_stride_2 = 4
lstm_layers = 2        # layers of the bidirectional LSTM bottleneck
growth = 2             # channel multiplier between consecutive encoder levels
rescale = 0.1          # reference value used by the weight rescaling at init

# Parameters of the dataset generation
samplerate = 44100
sample_length = 10          # seconds of audio per training example
samples_per_minute = 1.25   # random training chunks extracted per minute of track
shift_seconds = 1           # extra seconds per chunk, cropped by the random `shift`
# os.cpu_count() returns None when the count cannot be determined; DataLoader
# needs an int (0 = load batches in the main process).
workers = os.cpu_count() or 0

# Parameters of training
batch_size = 8
epochs = 240
augmentation = True
learning_rate = 3e-4
loss_function = "L2"  # "L1" (nn.L1Loss) or "L2" (nn.MSELoss)
apply_shifts = 10     # random shifts averaged by apply_model at inference

# Path to the saved model checkpoint
model_path = "drive/MyDrive/checkpoint/model-l2.pt"

# Path to the folder where to save the estimates
estimates_path = "drive/MyDrive/estimates-l2/"

Init

In [None]:
# Seed every random generator the notebook draws from (torch, Python's
# `random`, numpy), then run on the GPU whenever CUDA is available.
th.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

device = th.device('cuda' if th.cuda.is_available() else 'cpu')

Model definition

In [None]:
class Model(nn.Module):
  """Demucs-style waveform U-Net.

  A stack of `depth` Conv1d encoder levels, a bidirectional LSTM bottleneck
  and `depth` ConvTranspose1d decoder levels with additive skip connections.
  Hyper-parameters are read from the parameter cell (channels, depth, growth,
  lstm_layers, encoder_*/decoder_* kernels and strides).

  Input:  [batch, audio_channels, length] mixture.
  Output: [batch, tracks, audio_channels, length'] source estimates. length'
          can be shorter than length (unpadded convolutions); the training
          loop center-trims its targets to match.
  """

  def __init__(self):
    super().__init__()

    self.audio_channels = 2
    self.tracks = 4

    C_in = self.audio_channels
    C_out = channels

    # Encoder: the first level maps the audio channels to `channels`, each
    # following level multiplies the channel count by `growth`.
    self.encoder = nn.ModuleList()
    for i in range(depth):
      if i > 0:
        C_in = C_out
        C_out = growth * C_in
      self.encoder.append(nn.Sequential(
        nn.Conv1d(C_in, C_out, encoder_kernel_1, encoder_stride_1),
        nn.ReLU(),
        # 2*C_out channels, gated back down to C_out by the GLU.
        nn.Conv1d(C_out, 2 * C_out, encoder_kernel_2, encoder_stride_2),
        nn.GLU(dim=1)
      ))

    # Bottleneck: the bidirectional LSTM doubles the feature size and the
    # linear layer projects it back to C_out.
    self.lstm = nn.LSTM(hidden_size=C_out, input_size=C_out, bidirectional=True,
                        num_layers=lstm_layers)
    self.linear = nn.Linear(2 * C_out, C_out)

    # Decoder: mirrors the encoder; the last level outputs one signal per
    # (track, audio channel) and has no trailing ReLU.
    self.decoder = nn.ModuleList()
    for i in range(depth):
      C_in = C_out
      C_out = int(C_in / growth)
      if i == depth - 1:
        C_out = self.audio_channels * self.tracks
      blocks = [
        nn.Conv1d(C_in, 2 * C_in, decoder_kernel_1, decoder_stride_1),
        nn.GLU(dim=1),
        nn.ConvTranspose1d(C_in, C_out, decoder_kernel_2, decoder_stride_2)
      ]
      if i < depth - 1:
        blocks.append(nn.ReLU())
      self.decoder.append(nn.Sequential(*blocks))

  def forward(self, x):
    # Keep every encoder output; the decoder consumes them in reverse order.
    skips = []
    for encode in self.encoder:
      x = encode(x)
      skips.append(x)

    # conv1d/ConvTranspose1d take and return [batch, channels, length], while
    # the LSTM (batch_first=False) and linear layers want
    # [length, batch, size (i.e., channels)].
    # Always use this instance's layers, never the global `model`.
    x = x.permute(2, 0, 1)
    x = self.lstm(x)[0]
    x = self.linear(x)
    x = x.permute(1, 2, 0)

    for decode in self.decoder:
      skip = skips.pop()
      # Center-trim the skip connection to the current (shorter) length.
      skip = th.narrow(skip, 2, (skip.size(2) - x.size(2)) // 2, x.size(2))
      x = x + skip
      x = decode(x)

    x = x.view(x.size(0), self.tracks, self.audio_channels, x.size(-1))
    return x

model = Model()
model.to(device)

In [None]:
def valid_length(length):
  """Return the nearest valid length to use with the model.

  As stated in the original source code, it returns "the nearest valid length
  to use with the model so that there is no time steps leftover in
  convolutions". The number of samples of input mixtures must be "valid",
  i.e., such that "all convolution windows are full". This "prevents hard to
  debug mistake with the prediction being shifted compared to the input
  mixture".

  The first loop follows the encoder down (Conv1d with encoder_kernel_1 /
  encoder_stride_1) and adds the samples consumed by each decoder level's
  first Conv1d (decoder_kernel_1). The second loop follows the decoder's
  ConvTranspose1d layers (decoder_kernel_2 / decoder_stride_2) back up.
  Assumes encoder_kernel_2 == encoder_stride_2 == 1 and decoder_stride_1 == 1,
  as in the parameter cell — TODO generalize if those change.
  """
  for _ in range(depth):
    length = math.ceil((length - encoder_kernel_1) / encoder_stride_1) + 1
    length = max(1, length)
    length += decoder_kernel_1 - 1
  for _ in range(depth):
    # Use the decoder's own transposed-conv parameters (equal to the encoder's
    # values in the current configuration).
    length = (length - 1) * decoder_stride_2 + decoder_kernel_2
  return int(length)

In [None]:
def apply_model(model, mix, shifts=None, split=False):
  """Apply `model` to a single un-batched mixture.

  As explained in the paper, the demucs model is not time equivariant. When
  `shifts` is set this works around it with the "shift trick": it "samples S
  random shifts of an input mixture x and averages the predictions of the
  model for each, after having applied the opposite shift".

  Args:
    model: callable taking a [1, audio_channels, n] batch and returning
      [1, tracks, audio_channels, n'] (n' may be shorter; it is center-trimmed).
    mix: mixture tensor whose last dimension is time.
    shifts: if truthy, the number S of random shifts (0 to samplerate/2
      samples each) whose predictions are averaged.
    split: if True, process the mixture in independent chunks of 10 seconds
      (the last one zero-padded) to reduce the amount of VRAM needed.

  Returns:
    The estimates, with the same time length as `mix`.
  """
  # NOTE: the helpers are deliberately not named `split`/`shift`: a nested
  # `def split` would rebind the `split` parameter and make `if split:`
  # always true, ignoring split=False.

  def _apply(chunk):
    # Pad to a valid length, run the model without autograd, center-trim back.
    length = chunk.size(-1)
    delta = valid_length(length) - length
    padded = nn.functional.pad(chunk, (delta // 2, delta - delta // 2))
    with th.no_grad():
      output = model(padded.unsqueeze(0))[0]
    return th.narrow(output, -1, (output.size(-1) - length) // 2, length)

  def _apply_shifted(chunk):
    length = chunk.size(-1)
    max_shift = int(samplerate / 2)
    padded = nn.functional.pad(chunk, (max_shift, max_shift))
    offsets = [random.randint(0, max_shift) for _ in range(shifts)]
    output = 0
    for offset in offsets:
      shifted = th.narrow(padded, -1, offset, length + max_shift)
      shifted_output = _apply(shifted)
      # Undo the shift: original sample t sits at index t + max_shift - offset.
      output += th.narrow(shifted_output, -1, max_shift - offset, length)
    return output / shifts

  def _process(chunk):
    return _apply_shifted(chunk) if shifts else _apply(chunk)

  def _apply_split(chunk):
    length = chunk.size(-1)
    chunk_len = samplerate * 10
    padded_len = math.ceil(length / chunk_len) * chunk_len
    padded = nn.functional.pad(chunk, (0, padded_len - length))
    outputs = [_process(piece) for piece in th.split(padded, chunk_len, dim=-1)]
    return th.narrow(th.cat(outputs, dim=-1), -1, 0, length)

  if split:
    return _apply_split(mix)
  return _process(mix)

Rescaling

In [None]:
# Weight rescaling at initialization: every Conv1d / ConvTranspose1d of the
# model has its weight and bias divided by sqrt(std(weight) / rescale).
def _rescale_conv(conv, reference):
  """Divide conv's weight (and bias, if any) in place by sqrt(std / reference)."""
  divisor = math.sqrt(conv.weight.std().detach() / reference)
  conv.weight.data /= divisor
  if conv.bias is not None:
    conv.bias.data /= divisor

conv_types = (nn.Conv1d, nn.ConvTranspose1d)
for module in model.modules():
  if isinstance(module, conv_types):
    _rescale_conv(module, rescale)

Dataset

In [None]:
'''
This cell extracts and loads in memory the training dataset. It first computes 
the mean and the standard deviation of each training track in the musDB database 
and then extracts random chunks of length "sample_length+shift_seconds" from 
every track and applies z-score normalization using the previously computed mean 
and standard deviation.
'''

trainingMusdb = musdb.DB(Path("musdb"), subsets=["train"], split="train")

# First pass: mean/std of the channel-averaged (mono) mixture of every track.
# No chunk is set on the tracks yet, so track.audio is presumably the whole
# track.
train_meta = []
for i, track in enumerate(trainingMusdb.tracks, 0):
  print('\r', '%d/%d' % (i + 1, len(trainingMusdb)), end = "")
  mono = np.mean(track.audio, axis=1)
  train_meta.append({"mean": mono.mean(), "std": mono.std()})
  del mono
  
# Chunk duration in seconds: a valid model input of sample_length seconds plus
# shift_seconds of slack, which the random `shift` augmentation crops away.
sample_duration = valid_length(sample_length*samplerate)/samplerate+shift_seconds

training_set = []

# Second pass: draw max(1, int(samples_per_minute * minutes)) random start
# times per track. All starts are drawn first, so the order of random.randint
# calls is fixed. Note that random.randint raises ValueError if a track is
# shorter than sample_duration + 1 s.
for i, track in enumerate(trainingMusdb.tracks, 0):
  track.chunk_duration = sample_duration

  print('\r', '%d/%d' % (i + 1, len(trainingMusdb)), end = "")

  for start in [random.randint(0, int(track.duration-sample_duration)-1) 
                for j in range(max(1, int(samples_per_minute*track.duration/60)))]:
    # chunk_start/chunk_duration must be set before reading track.stems so
    # that only this chunk is returned (musdb behavior — confirm in its docs).
    track.chunk_start = start
    # presumably stems are (sources, samples, channels), mixture at index 0;
    # transpose to (sources, channels, samples) as the training loop expects.
    song = track.stems.transpose(0, 2, 1)
    # z-score with the whole-track statistics from the first pass.
    song = (song - train_meta[i]["mean"]) / train_meta[i]["std"]
    training_set.append(th.from_numpy(song).float())

trainingLoader = DataLoader(training_set, batch_size=batch_size, 
                            num_workers=workers, shuffle=True)

In [None]:
'''
The validation tracks are instead loaded from the disk every time to reserve as 
much memory as possible for the training dataset. As shown later in the training
loop, the validation loss is computed only once every 10 epochs and using only 
the first 30 seconds from each track.
'''
# Validation tracks are the "valid" split of the musdb train subset.
# Note: the training loop also computes the validation loss after the first
# epoch (epoch == 0), in addition to every 10 epochs.
validationMusdb = musdb.DB(Path("musdb"), subsets=["train"], split="valid")

Augmentation

In [None]:
# Randomly shifts the audio in time.
def shift(batch, time):
  """Randomly crop `time` samples from every source of every example.

  Args:
    batch: tensor [batch, sources, channels, length].
    time: number of samples to drop. Each (example, source) pair gets its own
      offset drawn with random.randrange(time), so sources are shifted
      independently of each other.

  Returns:
    Tensor [batch, sources, channels, length - time]. With time == 0 (e.g.
    shift_seconds = 0) the batch is returned unchanged, because
    random.randrange(0) would raise ValueError.
  """
  if time == 0:
    return batch
  length = batch.size(-1) - time
  cropped = []
  for example in batch:
    sources = [th.narrow(source, -1, random.randrange(time), length)
               for source in example]
    cropped.append(th.stack(sources))
  return th.stack(cropped)

In [None]:
# Randomly flips the audio channels.
def flipChannels(batch):
  """Swap the channels of each (example, source) with probability 1/2.

  Only stereo batches (dim 2 of size 2) are touched, and they are modified in
  place. One random.getrandbits(1) draw is made per (example, source) pair, in
  row-major order. Returns the same tensor.
  """
  if batch.size(2) != 2:
    return batch
  for i in range(len(batch)):
    for j in range(len(batch[i])):
      if random.getrandbits(1):
        batch[i, j] = batch[i, j].flip(0)
  return batch

In [None]:
# Randomly flips the sign.
def flipSign(batch):
  """Negate each (example, source) waveform with probability 1/2, in place.

  One random.getrandbits(1) draw is made per (example, source) pair, in
  row-major order. Returns the same (mutated) tensor.
  """
  for i in range(len(batch)):
    for j in range(len(batch[i])):
      if random.getrandbits(1):
        batch[i, j] = -batch[i, j]
  return batch

In [None]:
# Shuffles the sources within the batch.
def remix(batch):
  """Build new mixtures by permuting every source independently across the batch.

  For each source index s, one torch.randperm over the examples is drawn, and
  output[b, s] = batch[perm_s[b], s]. Returns a new tensor of the same shape.
  """
  n_examples, n_sources, n_channels, n_samples = batch.size()
  perms = th.stack([th.randperm(n_examples, device=batch.device)
                    for _ in range(n_sources)], dim=1)
  index = perms[:, :, None, None].expand(-1, -1, n_channels, n_samples)
  return batch.gather(0, index)

Training

In [None]:
last_epoch = 0
optimizer = th.optim.Adam(model.parameters(), lr=learning_rate)

# Map the configured loss name to its criterion. Fail loudly on an unknown
# name (e.g. a lowercase "l2") instead of silently falling back to L1.
loss_criteria = {"L1": nn.L1Loss, "L2": nn.MSELoss}
if loss_function not in loss_criteria:
  raise ValueError("loss_function must be one of %s, got %r"
                   % (sorted(loss_criteria), loss_function))
criterion = loss_criteria[loss_function]()

In [None]:
# Execute this cell to load a checkpoint (if one exists) and resume training.
# A missing file means a fresh start. Any other failure (corrupt file,
# architecture mismatch, ...) is raised instead of being swallowed, because
# silently restarting from 0 would overwrite the checkpoint at epoch 10.
checkpoint_file = Path(model_path)
if checkpoint_file.exists():
  # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
  checkpoint = th.load(checkpoint_file, map_location=device)
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  last_epoch = checkpoint['epoch']
else:
  print("Start from 0")

In [None]:
# Train the network. Validation runs after epoch 1 and every 10 epochs, on the
# first 30 s of each validation track. A checkpoint is saved to Google Drive
# every 10 epochs and after the final epoch.
for epoch in range(last_epoch, epochs):
  model.train()
  train_loss_sum = 0
  train_loss = 0

  for i, batch in enumerate(trainingLoader, 0):
    batch = batch.to(device)

    # Index 0 is the mixture. The input is rebuilt by summing the (augmented)
    # sources, so it always matches the targets.
    sources = batch[:, 1:]
    if augmentation:
      sources = remix(shift(flipChannels(flipSign(sources)),
                            shift_seconds*samplerate))
    else:
      sources = shift(sources, shift_seconds*samplerate)
    mixture = sources.sum(dim=1)
    estimated_output = model(mixture)
    # The model output can be shorter than its input; center-trim the targets.
    expected_output = th.narrow(sources, 3,
                                (sources.size(3)-estimated_output.size(3)) // 2,
                                estimated_output.size(3))

    loss = criterion(estimated_output, expected_output)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    train_loss_sum += loss.item()
    train_loss = train_loss_sum / (1 + i)
    print('\r', 'epoch: %d, batch: %d, loss: %.4f'
          % (epoch + 1, i + 1, train_loss), end="")

    del batch, mixture, loss, expected_output, estimated_output
    th.cuda.empty_cache()

  model.eval()
  valid_loss_sum = 0
  valid_loss = 0

  if epoch == 0 or (epoch + 1) % 10 == 0:
    for i, track in enumerate(validationMusdb, 0):
      track.chunk_duration = min(30, track.duration)
      mono = np.mean(track.audio, axis=1)
      song = track.stems.transpose(0, 2, 1)
      song = (song - mono.mean()) / mono.std()
      streams = th.from_numpy(song).float()

      streams = streams.to(device)
      expected_output = streams[1:]
      mixture = streams[0]
      estimated_output = apply_model(model, mixture, shifts=apply_shifts)

      loss = criterion(estimated_output, expected_output)

      valid_loss_sum += loss.item()
      valid_loss = valid_loss_sum / (1 + i)

      del streams, mixture, loss, expected_output, estimated_output
      th.cuda.empty_cache()

    print('\r', 'epoch: %d, training loss: %.4f, validation loss: %.4f'
          % (epoch + 1, train_loss, valid_loss))
  else:
    print('\r', 'epoch: %d, training loss: %.4f'
          % (epoch + 1, train_loss))

  # Also save after the last epoch, so a run whose length is not a multiple
  # of 10 does not lose its final epochs.
  if (epoch + 1) % 10 == 0 or epoch + 1 == epochs:
    th.save({
      'epoch': epoch+1,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      }, Path(model_path))


print('Finished Training')

The log below is the output of the training cell from my last training run.

```
 epoch: 1, training loss: 0.2445, validation loss: 0.1931
 epoch: 2, training loss: 0.1871
 epoch: 3, training loss: 0.1759
 epoch: 4, training loss: 0.1704
 epoch: 5, training loss: 0.1629
 epoch: 6, training loss: 0.1570
 epoch: 7, training loss: 0.1543
 epoch: 8, training loss: 0.1492
 epoch: 9, training loss: 0.1475
 epoch: 10, training loss: 0.1431, validation loss: 0.1424
 epoch: 11, training loss: 0.1448
 epoch: 12, training loss: 0.1425
 epoch: 13, training loss: 0.1373
 epoch: 14, training loss: 0.1386
 epoch: 15, training loss: 0.1378
 epoch: 16, training loss: 0.1344
 epoch: 17, training loss: 0.1329
 epoch: 18, training loss: 0.1307
 epoch: 19, training loss: 0.1264
 epoch: 20, training loss: 0.1263, validation loss: 0.1419
 epoch: 21, training loss: 0.1249
 epoch: 22, training loss: 0.1261
 epoch: 23, training loss: 0.1240
 epoch: 24, training loss: 0.1196
 epoch: 25, training loss: 0.1205
 epoch: 26, training loss: 0.1186
 epoch: 27, training loss: 0.1187
 epoch: 28, training loss: 0.1156
 epoch: 29, training loss: 0.1161
 epoch: 30, training loss: 0.1172, validation loss: 0.1230
 epoch: 31, training loss: 0.1163
 epoch: 32, training loss: 0.1168
 epoch: 33, training loss: 0.1159
 epoch: 34, training loss: 0.1113
 epoch: 35, training loss: 0.1128
 epoch: 36, training loss: 0.1115
 epoch: 37, training loss: 0.1110
 epoch: 38, training loss: 0.1095
 epoch: 39, training loss: 0.1098
 epoch: 40, training loss: 0.1080, validation loss: 0.1136
 epoch: 41, training loss: 0.1081
 epoch: 42, training loss: 0.1085
 epoch: 43, training loss: 0.1101
 epoch: 44, training loss: 0.1062
 epoch: 45, training loss: 0.1044
 epoch: 46, training loss: 0.1071
 epoch: 47, training loss: 0.1077
 epoch: 48, training loss: 0.1044
 epoch: 49, training loss: 0.1048
 epoch: 50, training loss: 0.1042, validation loss: 0.1191
 epoch: 51, training loss: 0.1049
 epoch: 52, training loss: 0.1043
 epoch: 53, training loss: 0.1035
 epoch: 54, training loss: 0.1036
 epoch: 55, training loss: 0.1016
 epoch: 56, training loss: 0.1018
 epoch: 57, training loss: 0.1007
 epoch: 58, training loss: 0.1013
 epoch: 59, training loss: 0.1008
 epoch: 60, training loss: 0.1016, validation loss: 0.1179
 epoch: 61, training loss: 0.1018
 epoch: 62, training loss: 0.0991
 epoch: 63, training loss: 0.1001
 epoch: 64, training loss: 0.0991
 epoch: 65, training loss: 0.0988
 epoch: 66, training loss: 0.0994
 epoch: 67, training loss: 0.0991
 epoch: 68, training loss: 0.0994
 epoch: 69, training loss: 0.0958
 epoch: 70, training loss: 0.0970, validation loss: 0.1119
 epoch: 71, training loss: 0.0965
 epoch: 72, training loss: 0.0950
 epoch: 73, training loss: 0.0948
 epoch: 74, training loss: 0.0946
 epoch: 75, training loss: 0.0954
 epoch: 76, training loss: 0.0949
 epoch: 77, training loss: 0.0952
 epoch: 78, training loss: 0.0942
 epoch: 79, training loss: 0.0936
 epoch: 80, training loss: 0.0937, validation loss: 0.1038
 epoch: 81, training loss: 0.0937
 epoch: 82, training loss: 0.0927
 epoch: 83, training loss: 0.0920
 epoch: 84, training loss: 0.0899
 epoch: 85, training loss: 0.0922
 epoch: 86, training loss: 0.0921
 epoch: 87, training loss: 0.0910
 epoch: 88, training loss: 0.0915
 epoch: 89, training loss: 0.0894
 epoch: 90, training loss: 0.0915, validation loss: 0.1026
 epoch: 91, training loss: 0.0891
 epoch: 92, training loss: 0.0911
 epoch: 93, training loss: 0.0886
 epoch: 94, training loss: 0.0920
 epoch: 95, training loss: 0.0910
 epoch: 96, training loss: 0.0890
 epoch: 97, training loss: 0.0900
 epoch: 98, training loss: 0.0892
 epoch: 99, training loss: 0.0885
 epoch: 100, training loss: 0.0859, validation loss: 0.1091
 epoch: 101, training loss: 0.0887
 epoch: 102, training loss: 0.0883
 epoch: 103, training loss: 0.0863
 epoch: 104, training loss: 0.0876
 epoch: 105, training loss: 0.0868
 epoch: 106, training loss: 0.0870
 epoch: 107, training loss: 0.0879
 epoch: 108, training loss: 0.0856
 epoch: 109, training loss: 0.0845
 epoch: 110, training loss: 0.0842, validation loss: 0.1080
 epoch: 111, training loss: 0.0847
 epoch: 112, training loss: 0.0836
 epoch: 113, training loss: 0.0866
 epoch: 114, training loss: 0.0832
 epoch: 115, training loss: 0.0833
 epoch: 116, training loss: 0.0848
 epoch: 117, training loss: 0.0853
 epoch: 118, training loss: 0.0837
 epoch: 119, training loss: 0.0828
 epoch: 120, training loss: 0.0837, validation loss: 0.1051
 epoch: 121, training loss: 0.0817
 epoch: 122, training loss: 0.0817
 epoch: 123, training loss: 0.0807
 epoch: 124, training loss: 0.0822
 epoch: 125, training loss: 0.0820
 epoch: 126, training loss: 0.0805
 epoch: 127, training loss: 0.0801
 epoch: 128, training loss: 0.0802
 epoch: 129, training loss: 0.0811
 epoch: 130, training loss: 0.0798, validation loss: 0.1001
 epoch: 131, training loss: 0.0803
 epoch: 132, training loss: 0.0797
 epoch: 133, training loss: 0.0799
 epoch: 134, training loss: 0.0821
 epoch: 135, training loss: 0.0793
 epoch: 136, training loss: 0.0783
 epoch: 137, training loss: 0.0797
 epoch: 138, training loss: 0.0802
 epoch: 139, training loss: 0.0785
 epoch: 140, training loss: 0.0791, validation loss: 0.0907
 epoch: 141, training loss: 0.0792
 epoch: 142, training loss: 0.0791
 epoch: 143, training loss: 0.0781
 epoch: 144, training loss: 0.0784
 epoch: 145, training loss: 0.0783
 epoch: 146, training loss: 0.0777
 epoch: 147, training loss: 0.0773
 epoch: 148, training loss: 0.0761
 epoch: 149, training loss: 0.0757
 epoch: 150, training loss: 0.0752, validation loss: 0.0906
 epoch: 151, training loss: 0.0775
 epoch: 152, training loss: 0.0776
 epoch: 153, training loss: 0.0779
 epoch: 154, training loss: 0.0765
 epoch: 155, training loss: 0.0757
 epoch: 156, training loss: 0.0748
 epoch: 157, training loss: 0.0758
 epoch: 158, training loss: 0.0765
 epoch: 159, training loss: 0.0762
 epoch: 160, training loss: 0.0758, validation loss: 0.0960
 epoch: 161, training loss: 0.0755
 epoch: 162, training loss: 0.0742
 epoch: 163, training loss: 0.0753
 epoch: 164, training loss: 0.0767
 epoch: 165, training loss: 0.0753
 epoch: 166, training loss: 0.0737
 epoch: 167, training loss: 0.0744
 epoch: 168, training loss: 0.0743
 epoch: 169, training loss: 0.0743
 epoch: 170, training loss: 0.0767, validation loss: 0.0988
 epoch: 171, training loss: 0.0746
 epoch: 172, training loss: 0.0726
 epoch: 173, training loss: 0.0722
 epoch: 174, training loss: 0.0738
 epoch: 175, training loss: 0.0733
 epoch: 176, training loss: 0.0720
 epoch: 177, training loss: 0.0723
 epoch: 178, training loss: 0.0729
 epoch: 179, training loss: 0.0721
 epoch: 180, training loss: 0.0727, validation loss: 0.0911
 epoch: 181, training loss: 0.0733
 epoch: 182, training loss: 0.0730
 epoch: 183, training loss: 0.0714
 epoch: 184, training loss: 0.0698
 epoch: 185, training loss: 0.0714
 epoch: 186, training loss: 0.0699
 epoch: 187, training loss: 0.0731
 epoch: 188, training loss: 0.0720
 epoch: 189, training loss: 0.0736
 epoch: 190, training loss: 0.0734, validation loss: 0.0890
 epoch: 191, training loss: 0.0744
 epoch: 192, training loss: 0.0720
 epoch: 193, training loss: 0.0702
 epoch: 194, training loss: 0.0723
 epoch: 195, training loss: 0.0708
 epoch: 196, training loss: 0.0717
 epoch: 197, training loss: 0.0693
 epoch: 198, training loss: 0.0701
 epoch: 199, training loss: 0.0689
 epoch: 200, training loss: 0.0700, validation loss: 0.0933
 epoch: 201, training loss: 0.0702
 epoch: 202, training loss: 0.0701
 epoch: 203, training loss: 0.0707
 epoch: 204, training loss: 0.0705
 epoch: 205, training loss: 0.0701
 epoch: 206, training loss: 0.0680
 epoch: 207, training loss: 0.0674
 epoch: 208, training loss: 0.0677
 epoch: 209, training loss: 0.0704
 epoch: 210, training loss: 0.0680, validation loss: 0.0792
 epoch: 211, training loss: 0.0697
 epoch: 212, training loss: 0.0702
 epoch: 213, training loss: 0.0691
 epoch: 214, training loss: 0.0679
 epoch: 215, training loss: 0.0700
 epoch: 216, training loss: 0.0702
 epoch: 217, training loss: 0.0683
 epoch: 218, training loss: 0.0669
 epoch: 219, training loss: 0.0676
 epoch: 220, training loss: 0.0669, validation loss: 0.0736
 epoch: 221, training loss: 0.0678
 epoch: 222, training loss: 0.0696
 epoch: 223, training loss: 0.0678
 epoch: 224, training loss: 0.0677
 epoch: 225, training loss: 0.0675
 epoch: 226, training loss: 0.0667
 epoch: 227, training loss: 0.0688
 epoch: 228, training loss: 0.0681
 epoch: 229, training loss: 0.0663
 epoch: 230, training loss: 0.0673, validation loss: 0.0865
 epoch: 231, training loss: 0.0677
 epoch: 232, training loss: 0.0650
 epoch: 233, training loss: 0.0652
 epoch: 234, training loss: 0.0660
 epoch: 235, training loss: 0.0664
 epoch: 236, training loss: 0.0659
 epoch: 237, training loss: 0.0664
 epoch: 238, training loss: 0.0652
 epoch: 239, training loss: 0.0661
 epoch: 240, training loss: 0.0665, validation loss: 0.0772
Finished Training
```




Example test

In [None]:
# Apply the model to the first validation track (first 30 s only) as a quick
# listening test of the training results.
track = validationMusdb[0]
track.chunk_duration = min(30, track.duration)

print("mixture")
display(Audio(track.targets['linear_mixture'].audio.T, rate=samplerate))

print("expected output")
for name in ('drums', 'bass', 'other', 'vocals'):
  display(Audio(track.targets[name].audio.T, rate=samplerate))

# Normalize with the statistics of the (chunked) mono mixture, as in validation.
mono = np.mean(track.audio, axis=1)
song = track.stems.transpose(0, 2, 1)
song = (song - mono.mean()) / mono.std()
streams = th.from_numpy(song).float()
streams = streams.to(device)

output = apply_model(model, streams[0], shifts=apply_shifts)
# Undo the normalization before listening.
output = output * mono.std() + mono.mean()

print("estimated output")
for target in output:
  display(Audio(target.cpu().detach().numpy(), rate=samplerate))

del streams, output
th.cuda.empty_cache()

Save test estimates and compute loss

In [None]:
# Apply the model to all mixtures of the test set, save the estimates to
# Google Drive (estimates_path) and print the running average test loss.
# If host memory is tight, `del training_set` before running this cell.
test_set = musdb.DB(Path("musdb"), subsets=["test"])
model.eval()
# Drop stale gradients to free memory. Parameters are deliberately NOT frozen
# (requires_grad stays True): apply_model already runs under th.no_grad().
# Freezing them would break any later re-run of the training cell, because
# loss.backward() fails when no parameter requires grad.
for p in model.parameters():
  p.grad = None

test_loss_sum = 0
test_loss = 0

for i, track in enumerate(test_set):
  print('\r', '%d/%d' % (i + 1, len(test_set)), end = "")

  # Normalize with whole-track mono statistics, as for the training set.
  mono = np.mean(track.audio, axis=1)
  song = track.stems.transpose(0, 2, 1)
  song = (song - mono.mean()) / mono.std()

  streams = th.from_numpy(song).float().to(device)
  reference = streams[1:]
  mixture = streams[0]

  # Full tracks are processed in 10 s chunks to bound memory use.
  output = apply_model(model, mixture, shifts=apply_shifts, split=True)

  loss = criterion(output, reference)
  test_loss_sum += loss.item()
  test_loss = test_loss_sum / (1 + i)

  print('\r', '%d/%d, loss: %.4f' % (i + 1, len(test_set), test_loss), end="")

  # Undo the normalization before writing the estimates.
  output = output * mono.std() + mono.mean()

  estimates = output.cpu().numpy().transpose(0, 2, 1)
  estimates_dict = {
    "drums": estimates[0],
    "bass": estimates[1],
    "other": estimates[2],
    "vocals": estimates[3],
  }
  test_set.save_estimates(estimates_dict, track, Path(estimates_path))

  del streams, reference, mixture, output, loss
  th.cuda.empty_cache()

After my last training run, the average test loss was 0.1027.