In [None]:
# Import the packages we'll use

import numpy as np
import os, glob, csv

# librosa is a widely-used audio processing library
import librosa

import sklearn
import scipy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
from torchvision import transforms

from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import MultiStepLR

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

# for plotting
%matplotlib inline
import matplotlib.pyplot as plt

import math

# for accuracy and confusion matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
# for data normalization
from sklearn.preprocessing import StandardScaler

# use GPU if available, otherwise, use cpu
# NOTE(review): `device` is not read anywhere else in the visible notebook;
# train() decides on CUDA through its own torch.cuda.is_available() check.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
import mir_eval  
import museval.metrics as metrics
import numpy as np
from scipy.io import wavfile

In [None]:
# Track folders of the training split; hidden entries (names starting with '.') are skipped.
ids = [entry for entry in os.listdir('./musdb18hq/train') if not entry.startswith('.')]
filenames = list()

In [None]:
def load_audio(path, seq):
    """Return up to 60 s of mono samples from `path`, starting `seq * 30` seconds in.

    The file is resampled to 44.1 kHz by librosa; the sample rate itself is
    not returned.
    """
    start_seconds = seq * 30.0
    samples, _ = librosa.load(path, sr=44100, offset=start_seconds, duration=60.0)
    return samples

In [None]:
def load_mixture(filename,seq, win_len=0.05, hop_len=0.0125, n_mels=64):
    """Load a 60 s excerpt of ``<filename>.wav`` and return its STFT magnitude and phase.

    Parameters
    ----------
    filename : str
        Path of the wav file *without* the ``.wav`` extension.
    seq : int
        Segment index; the excerpt starts ``seq * 30`` seconds into the file.
    win_len, hop_len : float
        Window and hop length in seconds, converted to samples at 44.1 kHz.
    n_mels : int
        Unused; kept so existing calls keep working.

    Returns
    -------
    (spec, phase)
        Magnitude and phase arrays from ``librosa.magphase``, both with exactly
        4802 columns (frames) along the last axis.
    """
    n_frames = 4802  # fixed number of STFT frames every returned array has

    audio, sr = librosa.load("%s.wav" % (filename), sr=44100, offset=seq*30.0, duration=60.0)
    n_fft = int(win_len*sr)
    hop = int(hop_len*sr)
    # `center` is a boolean flag; the previous string 'True' only worked because
    # any non-empty string is truthy.
    stft = librosa.stft(audio, n_fft=n_fft, hop_length=hop, window='hann', center=True)
    spec, phase = librosa.magphase(stft)

    missing = n_frames - spec.shape[1]
    if missing > 0:
        # Short excerpts are zero-padded at the *front* of the time axis,
        # exactly as before.
        spec = np.pad(spec, ((0, 0), (missing, 0)))
        phase = np.pad(phase, ((0, 0), (missing, 0)))
    return spec[:, :n_frames], phase[:, :n_frames]

In [None]:
def load_segments(filename,seq, win_len=0.05, hop_len=0.0125, n_mels=64):
    """Return only the STFT magnitude of a 60 s excerpt of ``<filename>.wav``.

    This is the same load / STFT / pad-or-trim-to-4802-frames pipeline as
    ``load_mixture``; it simply drops the phase, so it delegates instead of
    duplicating the code.

    Parameters mirror ``load_mixture`` (``n_mels`` is unused there as well).
    """
    spec, _ = load_mixture(filename, seq, win_len=win_len, hop_len=hop_len, n_mels=n_mels)
    return spec

In [None]:
class Dataset(BaseDataset):
    """One item per track folder: mixture magnitude/phase plus the four stem magnitudes.

    Parameters
    ----------
    ids : list of str
        Track folder names; entries starting with '.' are dropped.
    seq : int
        Segment index forwarded to ``load_mixture`` / ``load_segments``.
    path : str
        Directory that contains the track folders.
    transforms :
        Accepted but not used.
    """

    def __init__(self,ids, seq, path='./musdb18hq/train', transforms=None):
        self.ids = [track for track in ids if not track.startswith('.')]
        self.path = path + '/'
        self.seq = seq

    def __getitem__(self, i):
        # Every file of a track lives in <path>/<track id>/<stem>.wav
        track_dir = self.path + self.ids[i]
        mixture, m_phase = load_mixture(track_dir + '/mixture', seq=self.seq)
        bass, vocals, drums, others = [
            load_segments(track_dir + '/' + stem, seq=self.seq)
            for stem in ('bass', 'vocals', 'drums', 'other')
        ]
        return mixture, m_phase, bass, vocals, drums, others

    def __len__(self):
        return len(self.ids)

In [None]:
class ConvNet(nn.Module):
    """Convolutional encoder with four parallel transposed-convolution decoders.

    ``forward`` takes a 3-D tensor (batch, rows, cols), adds a channel axis and
    returns four non-negative tensors of shape (batch, 1, rows, cols), one per
    decoder branch. The (1102, 1) kernels require at least 1102 rows; the
    notebook feeds 1103 x 4802 spectrograms.

    ``input_shape`` is accepted for backward compatibility but not used.
    A third conv/deconv stage (conv3 / deconv3_*) existed only as disabled
    code in the original and is not part of the model.
    """

    def __init__(self,input_shape=[1103, 4803]):
        super().__init__()
        # --- encoder -------------------------------------------------------
        self.conv1 = nn.Conv2d(1, 10, kernel_size=(1102, 1), stride=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=(1, 25))
        # --- decoders ------------------------------------------------------
        # Attribute names and creation order are kept so that saved state
        # dicts still load and parameter initialisation draws the same
        # random numbers.
        for branch in range(1, 5):
            setattr(self, f'deconv2_{branch}',
                    nn.ConvTranspose2d(20, 10, kernel_size=(1, 25), stride=1))
        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        for branch in range(1, 5):
            setattr(self, f'deconv1_{branch}',
                    nn.ConvTranspose2d(10, 5, kernel_size=(1102, 1), stride=1))
        # The final 1x1 convolution and the dropout layer are shared by all branches.
        self.conv = nn.Conv2d(5, 1, kernel_size=(1, 1))
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        _, n_rows, n_cols = x.shape

        # encoder
        x = x.view(-1, 1, n_rows, n_cols)
        x = F.relu(self.maxpool1(self.conv1(x)))
        x = F.relu(self.conv2(x))

        inner_deconvs = (self.deconv2_1, self.deconv2_2, self.deconv2_3, self.deconv2_4)
        outer_deconvs = (self.deconv1_1, self.deconv1_2, self.deconv1_3, self.deconv1_4)

        # Each stage is applied to all four branches before the next stage
        # starts, which keeps the order of the dropout draws unchanged.
        branches = [deconv(x) for deconv in inner_deconvs]
        branches = [self.dropout(b) for b in branches]
        branches = [self.upsample2(b) for b in branches]
        branches = [deconv(b) for deconv, b in zip(outer_deconvs, branches)]
        branches = [self.dropout(b) for b in branches]

        x1, x2, x3, x4 = [F.relu(self.conv(b)) for b in branches]
        return x1, x2, x3, x4

In [None]:
# Standalone model instance; in the visible notebook it is only passed to
# torchsummary's summary() in the following cell.
netb = ConvNet([1103, 4802])

In [None]:
# Print a layer-by-layer summary of `netb` for an input of size (1103, 4802).
# NOTE(review): presumably torchsummary adds the batch dimension itself — confirm.
from torchsummary import summary
summary(netb, (1103, 4802))

In [None]:
mean_var_path= "../Processed/"
# Checkpoints are written to ./Weights by train(); exist_ok avoids the
# separate existence check (and the race between check and creation).
os.makedirs('Weights', exist_ok=True)
#os.environ["CUDA_VISIBLE_DEVICES"]="0"
#--------------------------
class Average(object):
    """Running accumulator: a total (`sum`) and the number of samples (`count`)."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.sum, self.count = 0, 0

    def update(self, val, n=1):
        """Add `val` to the total and `n` to the sample count."""
        self.count += n
        self.sum += val

    # avg is a plain method, not a property: call it as ``avg()``
    def avg(self):
        """Return total / count (raises ZeroDivisionError before any update)."""
        return self.sum / self.count
#------------------------------
# import csv
# tensorboardX writer; train() logs its loss scalars through it
writer = SummaryWriter()
#----------------------------------------

# shape handed to ConvNet(...) when a model is built
inp_size = [1103, 4802]
# weights of the three terms MixedSquaredError.forward subtracts from the
# squared error: alpha -> L_diff, beta -> L_other, beta_vocals -> L_othervocals
alpha = 0.005
beta = 0.05
beta_vocals = 0.08
# DataLoader batch size used for both the training and validation loaders
batch_size = 15
# train() iterates `for epoch in range(cur_epoch, num_epochs+1)`
num_epochs = 50
cur_epoch =47

class MixedSquaredError(nn.Module):
    """Summed squared error over the four stems minus three weighted penalty terms.

    loss = L_sq - alpha * L_diff - beta * L_other - beta_vocals * L_othervocals

    ``alpha``, ``beta`` and ``beta_vocals`` are read from module level.
    ``weight`` and ``size_average`` are accepted but ignored.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, pred_bass,pred_vocals,pred_drums,pred_others, gt_bass,gt_vocals,gt_drums, gt_others):

        def sq_dist(a, b):
            # summed squared difference over every element
            return torch.sum((a - b).pow(2))

        # reconstruction error of each stem against its own target
        L_sq = (sq_dist(pred_bass, gt_bass) + sq_dist(pred_vocals, gt_vocals)
                + sq_dist(pred_drums, gt_drums) + sq_dist(pred_others, gt_others))
        # distance of the bass / drums estimates from the "other" target
        L_other = sq_dist(pred_bass, gt_others) + sq_dist(pred_drums, gt_others)
        # same for the vocals estimate, weighted separately
        L_othervocals = sq_dist(pred_vocals, gt_others)
        # pairwise distance between the bass, vocals and drums estimates
        L_diff = (sq_dist(pred_bass, pred_vocals) + sq_dist(pred_bass, pred_drums)
                  + sq_dist(pred_vocals, pred_drums))

        return (L_sq - alpha*L_diff - beta*L_other - beta_vocals*L_othervocals)

def TimeFreqMasking(bass,vocals,drums,others,cuda=0):
    """Turn four source estimates into soft masks: ``|source| / (sum of |sources| + eps)``.

    Parameters
    ----------
    bass, vocals, drums, others : torch.Tensor
        Estimates of identical shape.
    cuda :
        Ignored; kept for backward compatibility. All work happens on the
        device the inputs already live on.

    Returns
    -------
    tuple of four tensors (bass, vocals, drums, others masks), each in [0, 1].

    The previous version added ``10e-8 * N(0, 1)`` noise to the denominator.
    That noise is signed, so the denominator could become negative (or zero),
    the result changed from call to call, and the ``torch.cuda.FloatTensor``
    branch failed whenever the flag and the tensors' device disagreed. A
    constant positive epsilon of the same magnitude avoids all three issues.
    """
    eps = 10e-8
    mag_bass, mag_vocals, mag_drums, mag_others = (
        torch.abs(bass), torch.abs(vocals), torch.abs(drums), torch.abs(others))
    den = mag_bass + mag_vocals + mag_drums + mag_others + eps

    bass = mag_bass/den
    vocals = mag_vocals/den
    drums = mag_drums/den
    others = mag_others/den

    return bass,vocals,drums,others


def _tracks_with_audio(split_dir, track_ids, seq):
    """Return the ids whose ``mixture.wav`` still yields samples at segment `seq`."""
    kept = []
    for track in track_ids:
        audio = load_audio(split_dir + '/' + track + '/mixture.wav', seq=seq)
        if audio.shape[0] > 0:
            kept.append(track)
    return kept


def _batch_loss(net, criterion, inp, targets, cuda):
    """Forward one batch and return ``(loss, number of samples in the batch)``.

    `inp` is the mixture magnitude batch, `targets` the (bass, vocals, drums,
    others) ground-truth magnitudes in that order.
    """
    inp = inp.float()
    # The network sees the batch standardised with its own mean / std; the
    # masks are applied to the un-normalised magnitudes.
    inp_n = (inp - torch.mean(inp)) / torch.std(inp)
    gt_bass, gt_vocals, gt_drums, gt_others = [t.float() for t in targets]
    if cuda:
        inp = inp.cuda()
        inp_n = inp_n.cuda()
        gt_bass = gt_bass.cuda()
        gt_vocals = gt_vocals.cuda()
        gt_drums = gt_drums.cuda()
        gt_others = gt_others.cuda()

    # inp_n is passed as-is: wrapping an already-moved tensor in
    # torch.FloatTensor(...) (as the training loops used to do) is expected to
    # be rejected by the legacy constructor on GPU runs.
    bass, vocals, drums, others = net(inp_n)
    masks = TimeFreqMasking(bass, vocals, drums, others, cuda)
    # The network outputs are (batch, 1, rows, cols); drop only the channel axis.
    mask_bass, mask_vocals, mask_drums, mask_others = [m.squeeze(1) for m in masks]

    loss = criterion(inp*mask_bass, inp*mask_vocals, inp*mask_drums, inp*mask_others,
                     gt_bass, gt_vocals, gt_drums, gt_others)
    return loss, inp.size(0)


def train():
    """Resume training from a fixed checkpoint and return the trained network.

    For every epoch in ``range(cur_epoch, num_epochs+1)`` the model is trained
    on six overlapping 60 s segments (seq 0..5, each starting ``seq * 30`` s
    into the track) of every training track that is long enough, then
    evaluated on segment 0 of the validation tracks. A checkpoint named
    ``Weights/Weights_<epoch+1>_<validation loss>.pth`` is written after each
    epoch.

    Reads the module-level ``inp_size``, ``batch_size``, ``cur_epoch``,
    ``num_epochs`` and ``writer``.
    """
    n_segments = 6
    train_dir = './musdb18hq/train'
    val_dir = './musdb18hq/val'

    cuda = torch.cuda.is_available()
    net = ConvNet(inp_size)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine;
    # the model is moved to the GPU below when one is available.
    net.load_state_dict(torch.load('Weights/Weights_9_128225384.0.pth', map_location='cpu'))
    criterion = MixedSquaredError()
    if cuda:
        net = net.cuda()
        criterion = criterion.cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr = 0.001)

    print("preparing training data ...")

    ids = [x for x in os.listdir(train_dir) if not x.startswith('.')]
    val_ids = [x for x in os.listdir(val_dir) if not x.startswith('.')]

    val_filenames = _tracks_with_audio(val_dir, val_ids, seq=0)
    val_set = Dataset(ids = val_filenames, seq = 0, path=val_dir, transforms = None)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

    # Which tracks are long enough for which segment does not change between
    # epochs, so the (slow) scan over the audio files is done once.
    segment_tracks = [_tracks_with_audio(train_dir, ids, seq) for seq in range(n_segments)]
    print("done ...")

    global_step = 0  # TensorBoard x-axis: one step per optimiser update
    for epoch in range(cur_epoch, num_epochs+1):
        train_loss = Average()
        net.train()
        batch = 0
        for seq, filenames in enumerate(segment_tracks):
            print(len(filenames))
            if not filenames:
                # a shuffling DataLoader cannot be built over an empty dataset
                continue
            train_set = Dataset(ids = filenames, seq = seq, transforms = None)
            train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

            for inp, phase, gt_bass, gt_vocals, gt_drums, gt_others in train_loader:
                batch += 1
                print(f'epoch {epoch+1}.... batch {batch}')
                optimizer.zero_grad()
                loss, n_samples = _batch_loss(
                    net, criterion, inp, (gt_bass, gt_vocals, gt_drums, gt_others), cuda)
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                print(loss_value)
                global_step += 1
                writer.add_scalar('Train Loss', loss_value, global_step)
                train_loss.update(loss_value, n_samples)

        val_loss = Average()
        net.eval()
        # no_grad: validation needs no autograd graph
        with torch.no_grad():
            for val_inp, val_phase, gt_bass, gt_vocals, gt_drums, gt_others in val_loader:
                vloss, n_samples = _batch_loss(
                    net, criterion, val_inp, (gt_bass, gt_vocals, gt_drums, gt_others), cuda)
                # weight by the size of the *validation* batch (the old code
                # used the size of the last training batch here)
                val_loss.update(vloss.item(), n_samples)
        writer.add_scalar('Validation loss', val_loss.avg(), epoch)

        print("Epoch {}, Training Loss: {}, Validation Loss: {}".format(epoch+1, train_loss.avg(), val_loss.avg()))
        torch.save(net.state_dict(), 'Weights/Weights_{}_{}.pth'.format(epoch+1, val_loss.avg()))
    return net

In [None]:
# Run the training loop. The returned network is not bound to a name here;
# the cells below build a fresh ConvNet and load a checkpoint from Weights/.
train()

In [None]:
# Evaluate a checkpoint on (up to) the first 11 test tracks with
# mir_eval's bss_eval_images and average SDR / ISR / SIR / SAR per stem.
net = ConvNet(inp_size)
# net.load_state_dict(torch.load('Weights/Weights_200_3722932.6015625.pth')) #least score Weights so far
net.load_state_dict(torch.load('Weights/Weights_49_178278304.0.pth', map_location='cpu'))
net.eval()

ids = os.listdir('./musdb18hq/test')
ids = [x for x in ids if not x.startswith('.')]
filenames = []
for count, i in enumerate(ids, start=1):
    examplefkey = './musdb18hq/test/'+i+'/mixture.wav'
    example_audio = load_audio(examplefkey, seq = 0)
    if example_audio.shape[0] > 0:
        filenames.append(i)
    if count > 10:
        break

print(len(filenames))

stems = ('bass', 'drum', 'vocal', 'other')
metric_names = ('SDR', 'ISR', 'SIR', 'SAR')
# The totals must be created once, *before* the loop: resetting them per file
# (as before) made every "mean" equal to last_file_score / number_of_files.
totals = {(metric, stem): 0.0 for stem in stems for metric in metric_names}

for i in filenames:
    examplefkey = './musdb18hq/test/'+i
    mixture, m_phase = load_mixture(examplefkey+'/mixture',0)
    references = {
        'bass': load_segments(examplefkey+'/bass', seq = 0),
        'vocal': load_segments(examplefkey+'/vocals', seq = 0),
        'drum': load_segments(examplefkey+'/drums', seq = 0),
        'other': load_segments(examplefkey+'/other', seq = 0),
    }

    # Same preprocessing as in training and in the inference cell below:
    # standardise the input, then turn the four outputs into soft masks that
    # are applied to the un-normalised mixture magnitude.
    mix = torch.FloatTensor(mixture)
    mean = torch.mean(mix)
    std = torch.std(mix)
    inp_n = ((mix - mean) / std).unsqueeze(0)
    with torch.no_grad():
        bass_mag, vocals_mag, drums_mag, others_mag = net(inp_n)
        mask_bass, mask_vocals, mask_drums, mask_others = TimeFreqMasking(
            bass_mag, vocals_mag, drums_mag, others_mag)
        # each stem uses its own mask (bass and drums used to be swapped,
        # with the bass estimate built from the already-multiplied drums one)
        estimates = {
            'bass': (mix * mask_bass.squeeze()).numpy(),
            'vocal': (mix * mask_vocals.squeeze()).numpy(),
            'drum': (mix * mask_drums.squeeze()).numpy(),
            'other': (mix * mask_others.squeeze()).numpy(),
        }

    # bss_eval_images expects (n_sources, n_samples, n_channels)
    n_values = references['drum'].size
    for stem in stems:
        scores = mir_eval.separation.bss_eval_images(
            references[stem].reshape((1, n_values, 1)),
            estimates[stem].reshape((1, n_values, 1)))
        # scores = (sdr, isr, sir, sar, perm); only the four metrics are kept.
        # NOTE(review): abs() hides negative scores; kept so the numbers stay
        # comparable with earlier runs. The metrics are also computed on
        # flattened magnitude spectrograms rather than on audio — confirm intent.
        for metric, values in zip(metric_names, scores[:4]):
            totals[(metric, stem)] += abs(values[0])

n_files = len(filenames)
mean_SDR_bass = totals[('SDR', 'bass')]/n_files
mean_ISR_bass = totals[('ISR', 'bass')]/n_files
mean_SIR_bass = totals[('SIR', 'bass')]/n_files
mean_SAR_bass = totals[('SAR', 'bass')]/n_files
mean_SDR_drum = totals[('SDR', 'drum')]/n_files
mean_ISR_drum = totals[('ISR', 'drum')]/n_files
mean_SIR_drum = totals[('SIR', 'drum')]/n_files
mean_SAR_drum = totals[('SAR', 'drum')]/n_files
mean_SDR_vocal = totals[('SDR', 'vocal')]/n_files
mean_ISR_vocal = totals[('ISR', 'vocal')]/n_files
mean_SIR_vocal = totals[('SIR', 'vocal')]/n_files
mean_SAR_vocal = totals[('SAR', 'vocal')]/n_files
mean_SDR_other = totals[('SDR', 'other')]/n_files
mean_ISR_other = totals[('ISR', 'other')]/n_files
mean_SIR_other = totals[('SIR', 'other')]/n_files
mean_SAR_other = totals[('SAR', 'other')]/n_files

In [None]:
# bass: SDR, ISR, SIR, SAR (one value per line)
print(mean_SDR_bass, mean_ISR_bass, mean_SIR_bass, mean_SAR_bass, sep='\n')

In [None]:
# drums: SDR, ISR, SIR, SAR (one value per line)
print(mean_SDR_drum, mean_ISR_drum, mean_SIR_drum, mean_SAR_drum, sep='\n')

In [None]:
# vocals: SDR, ISR, SIR, SAR (one value per line)
print(mean_SDR_vocal, mean_ISR_vocal, mean_SIR_vocal, mean_SAR_vocal, sep='\n')

In [None]:
# other: SDR, ISR, SIR, SAR (one value per line)
print(mean_SDR_other, mean_ISR_other, mean_SIR_other, mean_SAR_other, sep='\n')

In [None]:
# Separate one test excerpt (segment 1 of a single track) into four complex
# spectrograms: masked mixture magnitude combined with the mixture phase.
net = ConvNet(inp_size)
# net.load_state_dict(torch.load('Weights/Weights_200_3722932.6015625.pth')) #least score Weights so far
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
net.load_state_dict(torch.load('Weights/Weights_13_95039608.0.pth', map_location='cpu'))
net.eval()
mixture, m_phase = load_mixture('./musdb18hq/test/Side Effects Project - Sing With Me/mixture', seq=1)

mix = torch.FloatTensor(mixture)
mean = torch.mean(mix)
std = torch.std(mix)
# standardised copy for the network, with a leading batch axis of size 1
inp_n = ((mix - mean)/std).unsqueeze(0)

# inference only: no autograd graph needed
with torch.no_grad():
    bass_mag, vocals_mag, drums_mag, others_mag = net(inp_n)
    mask_bass,mask_vocals,mask_drums,mask_others = TimeFreqMasking(bass_mag, vocals_mag, drums_mag, others_mag)
    # masks are applied to the un-normalised mixture magnitude
    drums_mag = mix*mask_drums.squeeze()
    vocals_mag = mix*mask_vocals.squeeze()
    bass_mag = mix*mask_bass.squeeze()
    others_mag = mix*mask_others.squeeze()

# re-attach the mixture phase to obtain complex spectrograms for the inverse STFT
vocals = vocals_mag.numpy() * m_phase
bass = bass_mag.numpy() * m_phase
drums = drums_mag.numpy() * m_phase
others = others_mag.numpy() * m_phase

In [None]:
# Back to the time domain. The hop must match the one used for analysis in
# load_mixture: int(0.0125 * 44100) = 551 samples. The previous value of 1103
# (the number of frequency bins) stretched the output to about twice its length.
# NOTE(review): the analysis window was int(0.05 * 44100) = 2205 samples, but
# istft infers n_fft = 2 * (1103 - 1) = 2204 from the matrix shape, so
# win_length stays at 2204; pass n_fft=2205 as well if the installed librosa
# supports that argument.
istft_hop = int(0.0125 * 44100)
vocals_audio = librosa.istft(vocals, win_length=2204, hop_length=istft_hop, window='hann', center=True)
bass_audio = librosa.istft(bass, win_length=2204, hop_length=istft_hop, window='hann', center=True)
drums_audio = librosa.istft(drums, win_length=2204, hop_length=istft_hop, window='hann', center=True)
others_audio = librosa.istft(others, win_length=2204, hop_length=istft_hop, window='hann', center=True)

In [None]:
# Write the four separated stems as 44.1 kHz wav files.
# The target folder is created first; wavfile.write does not create it.
os.makedirs('./musdb18hq/model1', exist_ok=True)
scipy.io.wavfile.write('./musdb18hq/model1/vocal.wav', 44100, vocals_audio)
scipy.io.wavfile.write('./musdb18hq/model1/bass.wav', 44100, bass_audio)
scipy.io.wavfile.write('./musdb18hq/model1/drums.wav',  44100, drums_audio)
scipy.io.wavfile.write('./musdb18hq/model1/other.wav',  44100, others_audio)