In [None]:
import os
import librosa
import librosa.display
import matplotlib.pyplot as plt
import IPython.display as ipd
import torch
import numpy as np
from torch.utils.data import DataLoader

from model import MixingModel
from dataset import AudioMixingDataset
from inference_utils import mix_song

%load_ext autoreload
%autoreload 2

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Report the environment this notebook is executing in.
print('Torch version: ', torch.__version__)
print('Device: ', device)
print(torch.backends.cudnn.version())

# Enable the cuDNN benchmark flag for subsequent GPU work.
torch.backends.cudnn.benchmark = True

In [None]:
# Root folder of the dataset (it is expected to contain 'train' and 'test'
# sub-folders, which the dataset cell joins onto it).
# Set the MUSDB18HQ_DIR environment variable to point at a different location
# without editing the notebook; the fallback is the original hard-coded path.
base_path = os.environ.get(
    'MUSDB18HQ_DIR',
    '/media/apelykh/bottomless-pit/datasets/mixing/MUSDB18HQ')

# Directory used for model checkpoints and the saved loss plot.
weights_dir = './weights'

# Seed forwarded to every AudioMixingDataset constructor.
seed = 321

# Chunk length forwarded to every AudioMixingDataset constructor
# (units are not visible in this notebook; presumably seconds, TODO confirm).
chunk_length = 5

In [None]:
# Constructor arguments that are identical for all three splits.
common_args = dict(chunk_length=chunk_length, seed=seed)

# Train and validation sets are both built from the 'train' folder
# (fractions 0.95 and 0.05); the test set uses the whole 'test' folder.
# NOTE(review): whether the train and val song subsets are disjoint depends on
# how AudioMixingDataset interprets `train_val_test_split`; confirm in dataset.py.
d_train = AudioMixingDataset(os.path.join(base_path, 'train'),
                             train_val_test_split=(0.95, 0.0, 0.0),
                             mode='train', **common_args)

d_val = AudioMixingDataset(os.path.join(base_path, 'train'),
                           train_val_test_split=(0.0, 0.05, 0.0),
                           mode='val', **common_args)

d_test = AudioMixingDataset(os.path.join(base_path, 'test'),
                            train_val_test_split=(0.0, 0.0, 1.0),
                            mode='test', **common_args)

# One summary line per split: number of songs and number of chunks.
for summary, split in (('Train: {} tracks, {} chunks', d_train),
                       ('Val: {} tracks, {} chunks', d_val),
                       ('Test: {} tracks, {} chunks', d_test)):
    print(summary.format(split.get_num_songs(), len(split)))

In [None]:
# Inspect one validation chunk: compare the sum of the per-track input
# features with the ground-truth mixture features.
sample = d_val[10]

train_features = sample['train_features']
gt_features = sample['gt_features']

# Collapse the leading axis of the input features into a single 2-D array.
summed_spec = train_features.sum(axis=0)

# Numeric summary of both arrays (shape, value range, raw values).
print(train_features.shape)
print(np.min(summed_spec), np.max(summed_spec))
print(summed_spec)
print(summed_spec.shape)
print(np.min(gt_features), np.max(gt_features))
print(gt_features)
print(gt_features.shape)

# Two stacked panels sharing the x axis: summed inputs on top, target below.
plt.figure(figsize=(12, 12))

ax1 = plt.subplot(2, 1, 1)
librosa.display.specshow(summed_spec)
plt.title('Summed track spectrograms')

ax2 = plt.subplot(2, 1, 2, sharex=ax1)
librosa.display.specshow(gt_features)
plt.title('Mixture spectrogram')

plt.show()

In [None]:
# Walk through the first chunks of the validation set and play the drums stem
# of chunk 20.
# The original cell iterated over `d`, a name that is not defined anywhere in
# this notebook, so it failed on a fresh kernel; `d_val` is the dataset the
# preceding inspection cell uses.
for i in range(len(d_val)):
    print('CHUNK: {}'.format(i))
    print('---------------')
    sample = d_val[i]
    if i == 20:
        # An Audio object is only rendered automatically when it is the last
        # expression of a cell; inside a loop it must be displayed explicitly.
        ipd.display(ipd.Audio(sample['drums_audio'], rate=44100))
        print('---------------')
        break

In [None]:
# Options shared by all three loaders: dataset order is preserved
# (shuffle=False), loading happens in the main process (num_workers=0) and
# the last, possibly smaller, batch is kept (drop_last=False).
# NOTE(review): the training loader does not shuffle either; confirm this is
# intended before using it for optimisation.
loader_options = dict(shuffle=False,
                      num_workers=0, collate_fn=None,
                      pin_memory=False, drop_last=False, timeout=0,
                      worker_init_fn=None)

train_loader = DataLoader(d_train, batch_size=32, **loader_options)
val_loader = DataLoader(d_val, batch_size=136, **loader_options)
test_loader = DataLoader(d_test, batch_size=32, **loader_options)

In [None]:
# Sanity check: fetch only the first validation batch and show the shape of
# its ground-truth features.
batch = next(iter(val_loader), None)
if batch is not None:
    print(batch['gt_features'].shape)

---
### Defining and training the model

In [None]:
# Instantiate the network and move it to the selected device.
model = MixingModel()
model = model.to(device)

# Count only the parameters that will receive gradient updates.
trainable_sizes = [p.numel() for p in model.parameters() if p.requires_grad]
num_trainable_param = sum(trainable_sizes)
print('{} trainable parameters'.format(num_trainable_param))

In [None]:
# Optionally warm-start from an existing checkpoint.
# The path is built from `weights_dir` so it follows the configured directory.
# The load is skipped when the file is absent: this checkpoint is the file the
# save cell further down writes, so on a fresh checkout it does not exist yet
# and an unconditional load would stop "Run All" before training.
weights = os.path.join(weights_dir, 'mixmodel_bs136_0020_3.346.pt')

if os.path.isfile(weights):
    # map_location lets a checkpoint saved on one device load on another.
    model.load_state_dict(torch.load(weights, map_location=device))
    print('Loaded weights from {}'.format(weights))
else:
    print('Checkpoint {} not found; keeping the current model weights'.format(weights))

In [None]:
# Adam over all model parameters, library-default hyper-parameters.
optimizer = torch.optim.Adam(model.parameters())

# Mean-squared-error loss between the model output and the target features.
criterion = torch.nn.MSELoss()

In [None]:
def train_model(start_epoch, num_epochs, loader=None, log_every=1):
    """Optimise the notebook-level `model` and return the per-epoch mean loss.

    Uses the notebook globals `model`, `optimizer`, `criterion` and `device`.

    Args:
        start_epoch: index of the first epoch to run (0-based).
        num_epochs: exclusive upper bound of the epoch index, i.e. the loop
            runs `range(start_epoch, num_epochs)`.
        loader: iterable of batches, each a mapping with 'train_features' and
            'gt_features' tensors. Defaults to the global `val_loader`, which
            is what this function has always iterated.
            NOTE(review): fitting on the validation loader looks unintended,
            since `train_loader` is defined but never used; pass
            `loader=train_loader` once confirmed.
        log_every: print the batch loss every `log_every` batches.

    Returns:
        A list with one entry per executed epoch: the mean of the batch losses.
    """
    if loader is None:
        loader = val_loader

    loss_hist = []

    # Make sure layers with train/eval-specific behaviour are in training mode,
    # even if an earlier cell switched the model to eval().
    model.train()

    for epoch in range(start_epoch, num_epochs):
        running_loss = 0.0

        for i, batch in enumerate(loader):
            # Clear gradients accumulated by the previous step.
            optimizer.zero_grad()

            # The model returns the masked output and the masks; only the
            # masked output enters the loss.
            masked, _masks = model(batch['train_features'].to(device))

            loss = criterion(masked, batch['gt_features'].to(device))
            loss.backward()
            optimizer.step()

            # Read the scalar once and reuse it for logging and averaging.
            batch_loss = loss.item()

            if i % log_every == log_every - 1:
                print('[%d, %4d] loss: %.3f' % (epoch + 1, i + 1, batch_loss))

            running_loss += batch_loss

        avg_epoch_loss = running_loss / len(loader)
        print('Epoch {} loss: {}\n'.format(epoch + 1, avg_epoch_loss))
        loss_hist.append(avg_epoch_loss)

    return loss_hist

In [None]:
# Training runs over the epoch indices range(start_epoch, num_epochs);
# the returned history holds one mean loss per epoch.
start_epoch, num_epochs = 0, 20
loss_hist = train_model(start_epoch=start_epoch, num_epochs=num_epochs)

In [None]:
# Save the trained weights under a name that reflects the run that produced
# them: batch size, last epoch number and final mean epoch loss.
# These values used to be hard-coded (136 / 20 / 3.346), so every run wrote the
# same file name regardless of its actual result.

# torch.save does not create missing directories.
os.makedirs(weights_dir, exist_ok=True)

# loss_hist is empty when no epoch was executed (start_epoch >= num_epochs).
final_loss = loss_hist[-1] if loss_hist else float('nan')

weights_file = os.path.join(
    weights_dir,
    'mixmodel_bs{}_{:04d}_{:.3f}.pt'.format(
        val_loader.batch_size,  # val_loader is the loader train_model iterates by default
        num_epochs,
        final_loss))
torch.save(model.state_dict(), weights_file)

In [None]:
# Plot the per-epoch mean loss returned by train_model and save the figure
# next to the checkpoints.
plt.figure(figsize=(7, 5))
plt.plot(loss_hist, label='Train loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
# Size the x axis from the plotted history instead of a hard-coded 20;
# max(..., 1) avoids a degenerate (0, 0) range when the history is empty.
plt.xlim(0, max(len(loss_hist), 1))
plt.legend()
plt.tight_layout()

# savefig does not create missing directories.
os.makedirs(weights_dir, exist_ok=True)
plt.savefig(os.path.join(weights_dir, 'loss.png'))

---
## Model Inference

In [None]:
# Pick one chunk and listen to the plain sum of its stems (no mixing applied).
song = d_train[120]
print('Song index: ', song['song_index'])
print('Song name: ', song['song_name'])

# Accumulator with the same shape and dtype as the reference mixture.
sum_audio = np.zeros_like(song['mixture_audio'])

# Take the track names from the dataset the chunk came from, through its public
# accessor (as the full-song cells do), rather than reaching into the private
# `_tracklist` of the unrelated test dataset.
for track in d_train.get_tracklist():
    # The reference mix itself must not be added to the sum of stems.
    if track != 'mixture':
        sum_audio += song['{}_audio'.format(track)]

ipd.Audio(sum_audio, rate=44100)

In [None]:
# Reference mixture of the same chunk, for comparison with the summed stems.
chunk_mixture = song['mixture_audio']
ipd.Audio(data=chunk_mixture, rate=44100)

In [None]:
# Run the model on a single chunk; np.newaxis adds the leading batch axis.
features = torch.Tensor(song['train_features'][np.newaxis, :])

# Inference should not build an autograd graph and should use eval-mode
# behaviour of the layers. The previous mode is restored afterwards so that
# re-running the training cell is unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        masked, masks = model(features.to(device))
finally:
    model.train(was_training)

# Back to a NumPy array on the host for printing and plotting.
res = masked.to('cpu').detach().numpy()
print(res[0])

In [None]:
# Show the model output for the first (and only) item of the batch.
predicted_spec = res[0]
librosa.display.specshow(predicted_spec)

---
### Mixing the full song

In [None]:
# Load every track of one test song (stems and reference mixture) at 44.1 kHz.
song_path = os.path.join(base_path, 'test/The Easton Ellises - Falcon 69')

# librosa.load returns (samples, sample_rate); only the samples are kept,
# keyed by track name.
loaded_tracks = {
    track: librosa.load(os.path.join(song_path, '{}.wav'.format(track)), sr=44100)[0]
    for track in d_train.get_tracklist()
}

In [None]:
# Ground truth: the reference mixture loaded for the song.
full_mixture = loaded_tracks['mixture']
ipd.Audio(data=full_mixture, rate=44100)

In [None]:
# Baseline: plain sum of the stems, with no mixing applied.
stem_names = [name for name in d_train.get_tracklist() if name != 'mixture']

# Start from silence with the same shape and dtype as the reference mixture,
# then add the stems one by one in tracklist order.
sum_audio = np.zeros_like(loaded_tracks['mixture'])
for stem in stem_names:
    sum_audio += loaded_tracks['{}'.format(stem)]

ipd.Audio(sum_audio, rate=44100)

In [None]:
# Model result: the full song mixed by mix_song from the loaded tracks.
mixed_song = mix_song(d_train, model, loaded_tracks)
ipd.Audio(data=mixed_song, rate=44100)