In [None]:
import os
import random
import warnings
from time import sleep

import IPython.display as ipd
import librosa
import librosa.display
import matplotlib.pyplot as plt
import mir_eval.separation as eval3
import museval as eval4
import numpy as np
import scipy.io.wavfile
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm_notebook

# NOTE(review): this silences *every* warning category (including deprecation
# warnings from librosa / torch); narrow it if you need to see them.
warnings.simplefilter(action='ignore', category=Warning)

# Checkpoints, loss traces, validation scores and estimates live under model_path/model_name.
model_path = 'baseline_model'
model_name = 'baseline'

# Preprocessed waveforms, read as <split>/<source>/<NN>.wav by the cells below.
musdb_path = '../data/musdb18/preprocessed/'
musdb_train_path = musdb_path + 'train/'
musdb_valid_path = musdb_path + 'valid/'
musdb_test_path = musdb_path + 'test/'


# Folder name of the mixture waveforms (used for validation / test inference).
mix_name = 'linear_mixture'
target_names = ['vocals', 'drums', 'bass', 'other']
# The single source this notebook trains and evaluates a separator for.
target_name = target_names[0]

# Spectrogram planes fed to the model: to_specs emits (real, imag) per audio
# channel, so 4 presumes stereo audio.
dim_c = 4
# Frequency rows used inside the model; the STFT yields n_fft // 2 + 1 = dim_f + 1
# rows (restore() reshapes to dim_f + 1).
dim_f = 2048
# STFT frames per training / inference chunk.
dim_t = 128
n_fft = 4096
hop_length = 1024
sampling_rate = 44100
# Number of samples whose non-centered STFT has exactly dim_t frames:
# frames = 1 + (N - n_fft) // hop_length.
chunk_size = n_fft + hop_length * (dim_t-1)
trim = 5000  # samples dropped from each end of every inference chunk before concatenation

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def to_specs(signal):
    """Turn a multichannel waveform into stacked real/imaginary STFT planes.

    Each channel is cast to float32 and transformed with a non-centered STFT
    (n_fft / hop_length from the config cell). For C channels the result holds
    2*C planes ordered [ch0.real, ch0.imag, ch1.real, ch1.imag, ...].
    """
    planes = []
    for channel in signal:
        stft = librosa.stft(np.array(channel, dtype=np.float32), n_fft=n_fft, center=False, hop_length=hop_length)
        planes.extend((stft.real, stft.imag))
    return np.array(planes)

def restore(specs):
    """Inverse of to_specs: rebuild one waveform per (real, imag) plane pair.

    `specs` is reshaped to (-1, 2, dim_f + 1, dim_t); each pair is combined into
    a complex spectrogram and inverted with a non-centered inverse STFT.
    Returns an array with one row per reconstructed channel.
    """
    pairs = np.reshape(specs, (-1, 2, dim_f+1, dim_t))
    return np.array([
        librosa.istft(real + 1.j * imag, center=False, hop_length=hop_length)
        for real, imag in pairs
    ])


def load(path, s=0.0, d=None):
    """Read an audio file at its native sampling rate without downmixing.

    `s` is the start offset and `d` the duration passed to librosa.load
    (None reads to the end). Only the waveform is returned; the sampling rate is dropped.
    """
    waveform, _native_sr = librosa.load(path, sr=None, mono=False, offset=s, duration=d)
    return waveform

# Duration in seconds of one chunk_size-sample window at sampling_rate.
print(chunk_size/sampling_rate)

In [None]:
# Load every training song into memory as spectrograms:
# musdb_specs[i][source] = to_specs(<train>/<source>/<i:02>.wav).
# Only the individual sources are loaded; MssDataset sums sources to build mixtures.
# NOTE(review): 86 is hard-coded, presumably the number of songs in the
# preprocessed train split; verify against the folder contents.
musdb_specs = []

for i in tqdm_notebook(range(86)):
    musdb_specs.append({})
    for t in target_names:
        musdb_specs[-1][t] = to_specs(load('{0}/{1:02}.wav'.format(musdb_train_path + t, i)))


In [None]:
class MssDataset(Dataset):
    """Random-window training set with cross-song mixing augmentation.

    Item `index` returns a (mix, target) pair of tensors. `target` is a random
    dim_t-frame window of source `target_name` from song `index`. `mix` is that
    window plus one random window of every other source in `target_names`, each
    taken from an independently chosen random song.

    Args:
        data: list with one dict per song mapping source name -> spectrogram
            array whose last axis is time (as built with to_specs).
    """
    def __init__(self, data):
        self.data = data
        self.size = len(self.data)

    def __len__(self):
        return self.size

    def __getitem__(self, index):

        def chunk(specs):
            # randint's upper bound is exclusive: the +1 lets the last full
            # window (start = n_frames - dim_t) be drawn, and keeps a song that
            # is exactly dim_t frames long usable instead of raising ValueError.
            s = np.random.randint(specs.shape[-1] - dim_t + 1)
            return specs[:,:,s:s+dim_t]

        target = chunk(self.data[index][target_name])

        # data augmentation (mix instruments from different songs)
        mix = target
        for t_name in target_names:
            if t_name!=target_name:
                index2 = np.random.randint(len(self))
                target2 = chunk(self.data[index2][t_name])
                mix = mix + target2  # builds a new array; `target` stays untouched

        return torch.tensor(mix), torch.tensor(target)
    
    
train_set = MssDataset(musdb_specs)
# Each __getitem__ draws fresh random windows, so an "epoch" is len(musdb_specs) random samples.
train_iter = DataLoader(train_set, batch_size=12, shuffle=True)

# Sanity check: shape of one mixture sample (expected (dim_c, dim_f + 1, dim_t) for stereo input).
train_set[0][0].shape

## Modeling

In [None]:
class DenseBlock(nn.Module):
    """Densely connected stack of `l` (BatchNorm -> Conv2d -> ReLU) layers.

    Layer k receives the channel-wise concatenation of every earlier layer's
    output and the block input, so its input width is in_c + k*g. Only the
    last layer's output (g channels) is returned. With odd kernel sizes the
    spatial shape is preserved (padding = kernel // 2, stride 1).

    Args:
        in_c: channels of the block input.
        l: number of layers (at least 1).
        g: growth rate, i.e. output channels of every layer.
        kx, ky: convolution kernel size.
    """

    def __init__(self, in_c, l, g, kx, ky):
        super().__init__()
        pad = (kx // 2, ky // 2)
        layers = []
        for depth in range(l):
            width = in_c + depth * g
            layers.append(nn.Sequential(
                nn.BatchNorm2d(width),
                nn.Conv2d(in_channels=width, out_channels=g, kernel_size=(kx, ky), stride=1, padding=pad),
                nn.ReLU(),
            ))
        self.H = nn.ModuleList(layers)

    def forward(self, x):
        out = self.H[0](x)
        for k in range(1, len(self.H)):
            # newest features first, then everything seen so far
            x = torch.cat((out, x), 1)
            out = self.H[k](x)
        return out

In [None]:
class Baseline(nn.Module):
    """U-Net of DenseBlocks with strided-conv down/up-sampling and skip links.

    Input is (batch, dim_c, n_freq, n_frames); output has dim_c channels. The
    first conv (kernel (2, 1), no padding) removes one frequency row and the
    final conv (padding (1, 0)) adds it back, so a dim_f + 1 row input comes
    out with dim_f + 1 rows. In between, the feature map is transposed to
    (batch, g, time, freq): every stage halves frequency, and stages listed in
    `t_scale` also halve time. The time and frequency sizes must be divisible
    by the accumulated factors for the skip concatenations to line up.

    Args:
        L: total number of sampling stages; L // 2 down- and L // 2 up-sampling
            stages are built.
        l: layers per DenseBlock.
        g: channel width (DenseBlock growth rate) used throughout.
        kx, ky: DenseBlock convolution kernel size.
        bn_factor: accepted for call compatibility; unused, since DenseBlock
            takes no bottleneck argument.
        t_scale: 0-based indices of down-sampling stages that also halve time;
            any iterable of ints (list, tuple, 1-D numpy array). The up path mirrors them.
    """
    def __init__(self, L, l, g, kx, ky, bn_factor, t_scale):
        super(Baseline, self).__init__()
        self.n = L//2
        # Normalise so membership tests work for lists, tuples and arrays alike.
        t_scale = {int(t) for t in t_scale}

        self.first_conv = nn.Sequential(
            nn.BatchNorm2d(dim_c),
            nn.Conv2d(in_channels=dim_c, out_channels=g, kernel_size=(2,1), stride=1),
            nn.ReLU(),
        )

        # Registration order (first_conv, ds_dense, ds, mid_dense, us_dense, us,
        # final_conv) fixes model.parameters() order, which saved optimizer state
        # is keyed by; keep it stable.
        self.ds_dense = nn.ModuleList()
        self.ds = nn.ModuleList()
        for i in range(self.n):
            self.ds_dense.append(DenseBlock(g, l, g, kx, ky))

            scale = (2,2) if i in t_scale else (1,2)
            self.ds.append(
                nn.Conv2d(in_channels=g, out_channels=g, kernel_size=scale, stride=scale)
            )

        self.mid_dense = DenseBlock(g, l, g, kx, ky)

        self.us_dense = nn.ModuleList()
        self.us = nn.ModuleList()
        for i in range(self.n):
            # Up-sampling stage i undoes down-sampling stage n-1-i.
            scale = (2,2) if (self.n-1 - i) in t_scale else (1,2)
            self.us.append(
                nn.ConvTranspose2d(in_channels=g, out_channels=g, kernel_size=scale, stride=scale)
            )
            # Input = up-sampled map concatenated with the matching skip (2*g channels).
            self.us_dense.append(DenseBlock(2*g, l, g, kx, ky))

        self.final_conv = nn.Conv2d(in_channels=g, out_channels=dim_c, kernel_size=(2,1), stride=1, padding=(1,0))

    def forward(self, x):
        x = self.first_conv(x)

        # (B, g, freq, time) -> (B, g, time, freq) so scale tuples read (time, freq)
        x = x.transpose(-1,-2)

        ds_outputs = []
        for i in range(self.n):
            x = self.ds_dense[i](x)
            ds_outputs.append(x)  # skip connection for the mirrored up stage
            x = self.ds[i](x)

        x = self.mid_dense(x)

        for i in range(self.n):
            x = self.us[i](x)
            x = torch.cat((x, ds_outputs[-i-1]), 1)
            x = self.us_dense[i](x)

        x = x.transpose(-1,-2)

        return self.final_conv(x)

## Training

In [None]:
# L=9 -> 4 down- and 4 up-sampling stages; stages 1-3 also halve the time axis.
# kx/ky are Baseline's DenseBlock kernel-size parameters.
model = Baseline(L=9, l=5, g=24, kx=3, ky=3, bn_factor=16, t_scale=np.array([1,2,3])).to(device)

#model = nn.DataParallel(model, device_ids=[0,1,2]).to(device)

# Per-epoch mean training loss; rebound by init_weights when restoring a checkpoint.
loss_trace = []

def init_weights(model, ckpt, lr, epoch=None):
    """Initialise `model` from scratch or restore it from saved checkpoints.

    Args:
        model: network to initialise / load in place.
        ckpt: 0 -> fresh start: Xavier-normal init of every parameter with more
            than one dimension (1-D biases and norm parameters keep their
            defaults). Any other value -> restore from disk.
        lr: learning rate; only used to build the checkpoint file names.
        epoch: None restores the rolling "latest" checkpoint; an int restores
            the snapshot saved at that epoch and truncates the loss history to
            its first `epoch` entries.

    Side effects when restoring: rebinds the global `loss_trace` and loads
    optimizer state into the global `optim`, which must already exist.
    """
    global loss_trace
    if ckpt==0:
        print('Start training')
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)
        return

    base = '{0}/{1}/{2}_lr{3}'.format(model_path, model_name, target_name, lr)
    # The loss history is only stored in the rolling file, even for epoch snapshots.
    loss_trace = list(np.load(base + '_loss.npy'))
    stem = base if epoch is None else '{0}_e{1:04}'.format(base, epoch)
    # map_location lets checkpoints written on a GPU be restored on a CPU-only machine.
    model.load_state_dict(torch.load(stem + '.pt', map_location=device))
    optim.load_state_dict(torch.load(stem + '_optim.pt', map_location=device))
    if epoch is not None:
        loss_trace = loss_trace[:epoch]
            
# Loss on the (real, imag) spectrogram planes of the target.
criterion = nn.MSELoss()

lr = 1e-3

# init_weights restores into this global `optim`, so it must exist before
# init_weights is called with ckpt != 0.
optim = torch.optim.RMSprop(model.parameters(), lr=lr)

# Number of trainable parameters (shown as the cell output).
sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
# Upper bound of the training epoch loop.
num_epochs = 80000
# Save an epoch-stamped snapshot every ckpt_steps epochs once e >= ckpt_min_epoch
# (separate from the rolling checkpoint written every save_steps epochs).
ckpt_steps = 500
ckpt_min_epoch = 5000

In [None]:
# Checkpoints and the loss history are written under model_path/model_name;
# create it so a fresh run does not fail at the first save.
os.makedirs('{0}/{1}'.format(model_path, model_name), exist_ok=True)

# Resume from the rolling "latest" checkpoint when one has been saved,
# otherwise start from a fresh Xavier initialisation (ckpt=0).
latest_loss_path = '{0}/{1}/{2}_lr{3}_loss.npy'.format(model_path, model_name, target_name, lr)
init_weights(model, int(os.path.exists(latest_loss_path)), lr=lr, epoch=None)


save_steps = 5
# Epochs are 1-based and loss_trace[k] holds epoch k+1, so a resumed run
# continues at len(loss_trace) + 1.
start_e = len(loss_trace) + 1

# +1: range() excludes its end, and epoch num_epochs must also be trained so
# its epoch-stamped snapshot gets written.
for e in range(start_e, num_epochs + 1):
    model.train()
    loss_sum = 0
    for mix, tar in tqdm_notebook(train_iter):
        mix = mix.to(device)
        tar = tar.to(device)
        y_hat = model(mix)
        loss = criterion(y_hat, tar)
        loss.backward()
        optim.step()
        optim.zero_grad()
        # weight by batch size so the epoch mean is per example (last batch may be short)
        loss_sum += loss.item() * mix.shape[0]

    ipd.clear_output(wait=True)
    print('epoch' , e)

    epoch_avg_loss = loss_sum / len(train_set)
    loss_trace.append(epoch_avg_loss)
    print(epoch_avg_loss)
    plt.plot(loss_trace)
    plt.ylim(0, 0.5)
    plt.show()

    # Rolling "latest" checkpoint plus the full loss history.
    if e%save_steps==0:
        torch.save(optim.state_dict(), '{0}/{1}/{2}_lr{3}_optim.pt'.format(model_path, model_name, target_name, lr))
        torch.save(model.state_dict(), '{0}/{1}/{2}_lr{3}.pt'.format(model_path, model_name, target_name, lr))
        np.save('{0}/{1}/{2}_lr{3}_loss.npy'.format(model_path, model_name, target_name, lr), loss_trace)
    # Permanent epoch-stamped snapshot (consumed by the validation sweep).
    if e>=ckpt_min_epoch and e%ckpt_steps==0:
        torch.save(optim.state_dict(), '{0}/{1}/{2}_lr{3}_e{4:04}_optim.pt'.format(model_path, model_name, target_name, lr, e))
        torch.save(model.state_dict(), '{0}/{1}/{2}_lr{3}_e{4:04}.pt'.format(model_path, model_name, target_name, lr, e))
    

## Validation

In [None]:
def preprocess_track(y):
    """Split a waveform into overlapping STFT chunks for chunked inference.

    The signal gets `trim` zero samples on both sides plus `pad` zeros at the
    end, so that its length is a multiple of gen_size = chunk_size - 2*trim.
    Chunks of chunk_size samples start every gen_size samples. Dropping `trim`
    samples from both ends of each reconstructed chunk and concatenating the
    pieces therefore covers the end-padded signal exactly once.

    Args:
        y: waveform array of shape (channels, n_samples).

    Returns:
        (specs, pad): a tensor stacking to_specs() of every chunk along a new
        first axis, and the number of end-padding samples. pad is in
        [1, gen_size]; it is a whole extra chunk when n_samples divides evenly.
    """
    n_channels, n_sample = y.shape[0], y.shape[1]

    gen_size = chunk_size-2*trim
    pad = gen_size - n_sample%gen_size
    # Zero margins are sized to y's own channel count.
    y_p = np.concatenate((np.zeros((n_channels, trim)), y, np.zeros((n_channels, pad + trim))), 1)

    all_specs = [to_specs(y_p[:, i:i+chunk_size]) for i in range(0, n_sample + pad, gen_size)]

    # Stack into one ndarray first: torch.tensor() on a list of ndarrays is very slow.
    return torch.tensor(np.stack(all_specs)), pad

def separate(model, mix_path):
    """Run `model` over a whole waveform and return the estimated target signal.

    Args:
        model: separator network; it is switched to eval mode here.
        mix_path: despite its name, the mixture *waveform* array of shape
            (channels, n_samples); callers pass the result of load().

    Returns:
        float64 array (channels, n_samples) holding the estimated target,
        trimmed back to the input length.
    """
    model.eval()

    mix_specs, pad_len = preprocess_track(mix_path)

    batch_size = 8
    pieces = []
    with torch.no_grad():
        for start in tqdm_notebook(range(0, mix_specs.shape[0], batch_size)):
            tar_specs = model(mix_specs[start:start + batch_size].to(device))
            for tar_spec in tar_specs.cpu().numpy():
                # drop the `trim`-sample margins each chunk was extended by
                pieces.append(restore(tar_spec)[:, trim:-trim])

    if not pieces:
        return np.zeros((mix_path.shape[0], 0))

    # Concatenate once (not inside the loop, which would be quadratic). Return
    # float64 so the caller's metrics are computed in double precision.
    tar_signal = np.concatenate(pieces, 1).astype(np.float64)
    # Strip the end padding; slicing to an explicit length stays correct for pad_len == 0.
    keep = max(tar_signal.shape[1] - pad_len, 0)
    return tar_signal[:, :keep]

In [None]:
def median_nan(a):
    """Median of every non-NaN entry of `a` (the array is flattened)."""
    keep = np.logical_not(np.isnan(a))
    return np.median(a[keep])

def musdb_sdr(ref, est, sr=sampling_rate):
    """Median framewise SDR of the first source, computed with museval.

    Windows and hops are `sr` samples long. NaN frame scores are ignored
    when taking the median.
    """
    metrics = eval4.metrics.bss_eval(ref, est, window=sr, hop=sr)
    sdr, _isr, _sir, _sar, _perm = metrics
    return median_nan(sdr[0])

def mse(ref, est):
    """Mean squared difference between `ref` and `est` over all elements."""
    diff = ref - est
    squared = diff * diff
    return squared.mean()

In [None]:
# Evaluate every saved snapshot from min_epoch to max_epoch (step cs) on the
# validation split; the metric is waveform MSE against the reference (lower is better).
max_epoch = 80000
min_epoch = 20000
cs = 500  # NOTE(review): must be a multiple of ckpt_steps so every evaluated snapshot exists

num_ckpts = (max_epoch - min_epoch) // cs + 1

score_path = '{0}/{1}/lr{2}_valid_{3}.npy'.format(model_path, model_name, lr, target_name)

# Resume an interrupted sweep: the file holds one mean score per already
# evaluated snapshot, in order, so evaluation continues after the last one.
scores_mean = list(np.load(score_path)) if os.path.exists(score_path) else []

for c in range(len(scores_mean), num_ckpts):
    ckpt_scores = []
    # Also restores the optimizer and truncates the global loss_trace.
    init_weights(model, 1, lr=lr, epoch=min_epoch + c*cs)
    # NOTE(review): 14 validation tracks hard-coded; verify against the folder.
    for i in tqdm_notebook(range(14)):
        est = separate(model, load('{0}/{1}/{2:02}.wav'.format(musdb_valid_path, mix_name, i)))
        ref = load('{0}/{1}/{2:02}.wav'.format(musdb_valid_path, target_name, i))

        score = mse(ref, est)
        print(score)

        ckpt_scores.append(score)

    ckpt_score_mean = np.array(ckpt_scores).mean()
    scores_mean.append(ckpt_score_mean)
    # Persist after every snapshot so a restart resumes instead of recomputing.
    np.save(score_path, np.array(scores_mean))

    ipd.clear_output(wait=True)
    print(ckpt_score_mean)
    plt.plot(scores_mean)
    plt.show()

## Generate

In [None]:
# NOTE(review): hand-picked snapshot, presumably the best one from the validation sweep; confirm.
init_weights(model, 1, lr=lr, epoch=49500)

for i in tqdm_notebook(range(50)):
    track_name = 'test_{0:02}'.format(i)
    tar_signal = separate(model, load('{0}{1}/{2:02}.wav'.format(musdb_test_path, mix_name, i)))

    t_path = '{0}/{1}/estimates_sources/{2}/{3}.wav'.format(model_path, model_name, track_name, target_name)
    os.makedirs(os.path.dirname(t_path), exist_ok=True)
    # librosa.output.write_wav was removed in librosa 0.8. scipy writes the same
    # float32 WAV but expects (samples, channels), hence the transpose.
    scipy.io.wavfile.write(t_path, sampling_rate, np.ascontiguousarray(np.asarray(tar_signal, dtype=np.float32).T))

## Evaluation

In [None]:
import math
import statistics as stats

# Per-track test SDR of the estimates written above (median over windows, see musdb_sdr).
SDR = []
for i in range(50):
    
    ref = load('{0}/{1}/{2:02}.wav'.format(musdb_test_path, target_name, i))
    est = load('{0}/{1}/estimates_sources/test_{2:02}/{3}.wav'.format(model_path, model_name, i, target_name))
    # Wrap the single target as a one-source stack: (1, n_samples, channels),
    # assuming load() returns (channels, n_samples).
    sdr = musdb_sdr(np.array([ref.T]), np.array([est.T]))
    
    SDR.append(sdr)

    
    ipd.clear_output(wait=True)
    print(sdr)
    plt.plot(SDR)
    plt.show()

    # Running statistics over the tracks evaluated so far.
    # NOTE(review): a NaN track score (all windows NaN) makes the mean NaN.
    print('SDR mean:', stats.mean(SDR))
    print('SDR median:', stats.median(SDR))


## Create folders

In [None]:
# Create the checkpoint folder and one estimates folder per test track.
# makedirs(exist_ok=True) also creates missing parents and keeps the cell re-runnable.
os.makedirs('{0}/{1}'.format(model_path, model_name), exist_ok=True)

os.makedirs('{0}/{1}/estimates_sources'.format(model_path, model_name), exist_ok=True)
for i in tqdm_notebook(range(50)):
    p = '{0}/{1}/estimates_sources/test_{2:02}'.format(model_path, model_name, i)
    os.makedirs(p, exist_ok=True)