In [1]:
# Import the packages we'll use

import numpy as np
import os, glob, csv

# librosa is a widely-used audio processing library
import librosa

import sklearn
import scipy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
from torchvision import transforms

from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import MultiStepLR

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

# for plotting
%matplotlib inline
import matplotlib.pyplot as plt

import math

# for accuracy and confusion matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
# for data normalization
from sklearn.preprocessing import StandardScaler

# use GPU if available, otherwise, use cpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
import mir_eval  
import museval.metrics as metrics
import numpy as np
from scipy.io import wavfile

In [3]:
def load_audio(path, seq):
    example_audio, sr = librosa.load(path, sr=44100, offset=seq*30.0, duration=60.0)
    return example_audio

In [4]:
def load_mixture(filename,seq, win_len=0.05, hop_len=0.0125, n_mels=64):
    audio, sr = librosa.load("%s.wav" % (filename), sr=44100, offset=seq*30.0, duration=60.0)
    win_len = int(win_len*sr)
    hop_len = int(hop_len*sr)
    spec, phase = librosa.magphase(librosa.stft(audio, n_fft=win_len,hop_length=hop_len,window='hann',center='True'))
    spec = spec.transpose((1,0))
    phase = phase.transpose((1,0))
    if(len(spec) < 4803):
        spec = np.pad(spec, ((4803-len(spec),0),(0,0)))
        phase = np.pad(phase, ((4803-len(phase),0),(0,0)))
    spec = spec.transpose((1,0))
    phase = phase.transpose((1,0))
    return spec, phase

In [5]:
def load_segments(filename,seq, win_len=0.05, hop_len=0.0125, n_mels=64):
    audio, sr = librosa.load("%s.wav" % (filename), sr=44100,offset=seq*30.0,duration=60.0)
    win_len = int(win_len*sr)
    hop_len = int(hop_len*sr)
    spec, phase = librosa.magphase(librosa.stft(audio, n_fft=win_len,hop_length=hop_len,window='hann',center='True'))
    spec = spec.transpose((1,0))
    if(len(spec) < 4803):
        spec = np.pad(spec, ((4803-len(spec),0),(0,0)))
    spec = spec.transpose((1,0))
    return spec

In [6]:
def load_segments_test(filename,seq, win_len=0.05, hop_len=0.0125, n_mels=64):
    audio, sr = librosa.load("%s.wav" % (filename), sr=44100,offset=seq*30.0,duration=60.0)
    win_len = int(win_len*sr)
    hop_len = int(hop_len*sr)
    spec = librosa.stft(audio, n_fft=win_len,hop_length=hop_len,window='hann',center='True')
    spec = spec.transpose((1,0))
    if(len(spec) < 4803):
        spec = np.pad(spec, ((4803-len(spec),0),(0,0)))
    spec = spec.transpose((1,0))
    return spec

In [7]:
class Dataset(BaseDataset):
    
    def __init__(self,ids, seq, path='./musdb18hq/train', transforms=None):
            
        self.ids = ids
        self.ids = [x for x in self.ids if not x.startswith('.')]
        self.path = path+'/'
        self.seq = seq
        
    def __getitem__(self, i):
        
        # read data
        mixture_path = '/mixture'
        bass_path = '/bass'
        vocals_path = '/vocals'
        drums_path = '/drums'
        others_path = '/other'
        mixture, m_phase = load_mixture(self.path+self.ids[i]+mixture_path,seq = self.seq)
        #phase = torch.load(mixture_path+self.list[index]+'_p')
        bass = load_segments(self.path+self.ids[i]+bass_path, seq = self.seq)
        vocals = load_segments(self.path+self.ids[i]+vocals_path, seq = self.seq)
        drums = load_segments(self.path+self.ids[i]+drums_path, seq = self.seq)
        others = load_segments(self.path+self.ids[i]+others_path, seq = self.seq)
            
        return mixture, m_phase, bass, vocals, drums, others
        
    def __len__(self):
        return len(self.ids)

In [8]:
class ConvNet(nn.Module):
    def __init__(self,input_shape=[1103,4803]):
        super(ConvNet, self).__init__()
        t1=1
        f1=1103
        t2=50
        f2=1
        N1=15
        N2=25
        NN=128
        self.vconv = nn.Conv2d(in_channels=1, out_channels=15,
                               kernel_size=(1103,1), stride=1)
        self.hconv = nn.Conv2d(15,25, kernel_size=(1,50))
        self.fc0 = nn.Linear(N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2), NN)
        self.fc1 = nn.Linear(NN,N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2))
        self.fc2 = nn.Linear(NN,N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2))
        self.fc3 = nn.Linear(NN,N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2))
        self.fc4 = nn.Linear(NN,N2*(input_shape[0]-f1-f2+2)*(input_shape[1]-t1-t2+2))
        self.hdeconv1 = nn.ConvTranspose2d(25, 15, kernel_size=(1,50))
        self.hdeconv2 = nn.ConvTranspose2d(25, 15, kernel_size=(1,50))
        self.hdeconv3 = nn.ConvTranspose2d(25, 15, kernel_size=(1,50))
        self.hdeconv4 = nn.ConvTranspose2d(25, 15, kernel_size=(1,50))
        self.vdeconv1 = nn.ConvTranspose2d(15, 1, kernel_size=(1103,1), stride=1)
        self.vdeconv2 = nn.ConvTranspose2d(15, 1, kernel_size=(1103,1), stride=1)
        self.vdeconv3 = nn.ConvTranspose2d(15, 1, kernel_size=(1103,1), stride=1)
        self.vdeconv4 = nn.ConvTranspose2d(15, 1, kernel_size=(1103,1), stride=1)
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x):
        (_, time_len, mel_bins) = x.shape

        x = x.view(-1, 1, time_len, mel_bins)
        x = F.relu(self.vconv(x))
        x = F.relu(self.hconv(x))
        s1 = x.shape
        
        x = x.view(s1[0],-1)

        x = F.relu(self.fc0(x))
        x = self.dropout(x)

        x1 = F.relu(self.fc1(x))
        x2 = F.relu(self.fc2(x))
        x3 = F.relu(self.fc3(x))
        x4 = F.relu(self.fc4(x))
        
        x1 = self.dropout(x1)
        x2 = self.dropout(x2)
        x3 = self.dropout(x3)
        x4 = self.dropout(x4)

        x1 = x1.view(s1[0], s1[1],s1[2],s1[3])
        x2 = x2.view(s1[0], s1[1],s1[2],s1[3])
        x3 = x3.view(s1[0], s1[1],s1[2],s1[3])
        x4 = x4.view(s1[0], s1[1],s1[2],s1[3])
        
        x1 = self.hdeconv1(x1)
        x2 = self.hdeconv2(x2)
        x3 = self.hdeconv3(x3)
        x4 = self.hdeconv4(x4)
        
        x1 = self.dropout(x1)
        x2 = self.dropout(x2)
        x3 = self.dropout(x3)
        x4 = self.dropout(x4)

        x1 = self.vdeconv1(x1)
        x2 = self.vdeconv2(x2)
        x3 = self.vdeconv3(x3)
        x4 = self.vdeconv4(x4)

        return x1, x2, x3, x4

In [9]:
inp_size = [1103, 4803]

In [10]:
netb = ConvNet(inp_size)

In [11]:
from torchsummary import summary
summary(netb, (1103, 4803))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 15, 1, 4803]          16,560
            Conv2d-2          [-1, 25, 1, 4754]          18,775
            Linear-3                  [-1, 128]      15,212,928
           Dropout-4                  [-1, 128]               0
            Linear-5               [-1, 118850]      15,331,650
            Linear-6               [-1, 118850]      15,331,650
            Linear-7               [-1, 118850]      15,331,650
            Linear-8               [-1, 118850]      15,331,650
           Dropout-9               [-1, 118850]               0
          Dropout-10               [-1, 118850]               0
          Dropout-11               [-1, 118850]               0
          Dropout-12               [-1, 118850]               0
  ConvTranspose2d-13          [-1, 15, 1, 4803]          18,765
  ConvTranspose2d-14          [-1, 15, 

In [12]:
mean_var_path= "../Processed/"
if not os.path.exists('Model2_Weights_v2'):
    os.makedirs('Model2_Weights_v2')
#os.environ["CUDA_VISIBLE_DEVICES"]="0"
#--------------------------
class Average(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val
        self.count += n

    #property
    def avg(self):
        return self.sum / self.count
#------------------------------
# import csv
writer = SummaryWriter()
#----------------------------------------

inp_size = [1103, 4803]
alpha = 0.003
beta = 0.05
beta_vocals = 0.11
batch_size = 15
num_epochs = 20
cur_ep = 0

class MixedSquaredError(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(MixedSquaredError, self).__init__()

    def forward(self, pred_bass,pred_vocals,pred_drums,pred_others, gt_bass,gt_vocals,gt_drums, gt_others):


        L_sq = torch.sum((pred_bass-gt_bass).pow(2)) + torch.sum((pred_vocals-gt_vocals).pow(2)) + torch.sum((pred_drums-gt_drums).pow(2)) 
        #+ torch.sum((pred_others-gt_others).pow(2))
        L_other = torch.sum((pred_bass-gt_others).pow(2)) + torch.sum((pred_drums-gt_others).pow(2)) + torch.sum((pred_others-gt_others).pow(2))
        L_othervocals = torch.sum((pred_vocals - gt_others).pow(2))
        L_diff = torch.sum((pred_bass-pred_vocals).pow(2)) + torch.sum((pred_bass-pred_drums).pow(2)) + torch.sum((pred_vocals-pred_drums).pow(2)) + torch.sum((pred_vocals-pred_others).pow(2))

        return (L_sq- alpha*L_diff - beta*L_other - beta_vocals*L_othervocals)

def TimeFreqMasking(bass,vocals,drums,others,cuda=0):
    den = torch.abs(bass) + torch.abs(vocals) + torch.abs(drums) + torch.abs(others)
    if(cuda):
        den = den + 10e-8*torch.cuda.FloatTensor(bass.size()).normal_()
    else:
        den = den + 10e-8*torch.FloatTensor(bass.size()).normal_()
    bass = torch.abs(bass)/den
    vocals = torch.abs(vocals)/den
    drums = torch.abs(drums)/den
    others = torch.abs(others)/den
    
    return bass,vocals,drums,others


def train():
    
    cuda = torch.cuda.is_available()
    net = ConvNet(inp_size)
    criterion = MixedSquaredError()
    if cuda:
        net = net.cuda()
        criterion = criterion.cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr = 0.0001)

    print("preparing training data ...")
    
    ids = os.listdir('./musdb18hq/train')
    ids = [x for x in ids if not x.startswith('.')]
    filenames = []
    
    val_ids = os.listdir('./musdb18hq/val')
    val_ids = [x for x in val_ids if not x.startswith('.')]
    val_filenames = []
    for i in val_ids:
        # load an example audio file, converting the data to mel spectrogram
        path = './musdb18hq/val/'+i+'/mixture.wav'
        example_audio = load_audio(path, seq = 0)
        if example_audio.shape[0] > 0:
            val_filenames.append(i)
    val_set = Dataset(ids = val_filenames, seq = 0, path='./musdb18hq/val',transforms = None)
    val_loader = DataLoader(val_set, batch_size=batch_size,shuffle=False)
    
    
    for epoch in range(cur_ep, num_epochs):
        batch = 0
        filenames = []
        for i in ids:
            # load an example audio file, converting the data to mel spectrogram
            path = './musdb18hq/train/'+i+'/mixture.wav'
            example_audio = load_audio(path, seq = 0)
            if example_audio.shape[0] > 0:
                filenames.append(i)
        print(len(filenames))
        train_set = Dataset(ids = filenames, seq = 0, transforms = None)
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        print("done ...")

        train_loss = Average()       

        net.train()
        for i, (inp,phase, gt_bass,gt_vocals,gt_drums,gt_others) in enumerate(train_loader):
            print(f'epoch {epoch+1}.... batch {batch+i+1}')
            inp = torch.FloatTensor(inp)
            mean = torch.mean(inp)
            std = torch.std(inp)
            inp_n = (inp-mean)/std
            inp_n = torch.FloatTensor(inp_n)
            gt_bass = torch.FloatTensor(gt_bass)
            gt_vocals = torch.FloatTensor(gt_vocals)
            gt_drums = torch.FloatTensor(gt_drums)
            gt_others= torch.FloatTensor(gt_others)
            if cuda:
                inp = inp.cuda()
                inp_n = inp_n.cuda()
                gt_bass = gt_bass.cuda()
                gt_vocals = gt_vocals.cuda()
                gt_drums = gt_drums.cuda()
                gt_others= gt_others.cuda()
            optimizer.zero_grad()
            bass, vocals, drums, others = net(torch.FloatTensor(inp_n))
            mask_bass,mask_vocals,mask_drums,mask_others = TimeFreqMasking(bass, vocals, drums, others,cuda)

            pred_drums=inp*np.squeeze(mask_drums)
            pred_vocals=inp*np.squeeze(mask_vocals)
            pred_bass=inp*np.squeeze(mask_bass)
            pred_others=inp*np.squeeze(mask_others)

            loss = criterion(pred_bass,pred_vocals,pred_drums,pred_others, gt_bass,gt_vocals,gt_drums,gt_others)
            writer.add_scalar('Train Loss',loss,epoch)
            loss.backward()
            optimizer.step()
            train_loss.update(loss.item(), inp.size(0))
            
            
        #..........
        filenames = []
        print(len(filenames))
        for i in ids:
            # load an example audio file, converting the data to mel spectrogram
            examplefkey = './musdb18hq/train/'+i+'/mixture.wav'
            example_audio = load_audio(examplefkey, seq = 1)
            if example_audio.shape[0] > 0:
                filenames.append(i)
        print(len(filenames))
        train_set = Dataset(ids = filenames, seq = 1, transforms = None)
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        print("done ...")

        for i, (inp,phase, gt_bass,gt_vocals,gt_drums,gt_others) in enumerate(train_loader):
            print(f'epoch {epoch+1}.... batch {batch+i+1}')
            inp = torch.FloatTensor(inp)
            mean = torch.mean(inp)
            std = torch.std(inp)
            inp_n = (inp-mean)/std
            inp_n = torch.FloatTensor(inp_n)
            gt_bass = torch.FloatTensor(gt_bass)
            gt_vocals = torch.FloatTensor(gt_vocals)
            gt_drums = torch.FloatTensor(gt_drums)
            gt_others= torch.FloatTensor(gt_others)
            if cuda:
                inp = inp.cuda()
                inp_n = inp_n.cuda()
                gt_bass = gt_bass.cuda()
                gt_vocals = gt_vocals.cuda()
                gt_drums = gt_drums.cuda()
                gt_others= gt_others.cuda()
            optimizer.zero_grad()
            bass, vocals, drums, others = net(torch.FloatTensor(inp_n))
            mask_bass,mask_vocals,mask_drums,mask_others = TimeFreqMasking(bass, vocals, drums, others,cuda)

            pred_drums=inp*np.squeeze(mask_drums)
            pred_vocals=inp*np.squeeze(mask_vocals)
            pred_bass=inp*np.squeeze(mask_bass)
            pred_others=inp*np.squeeze(mask_others)

            loss = criterion(pred_bass,pred_vocals,pred_drums,pred_others, gt_bass,gt_vocals,gt_drums,gt_others)
            writer.add_scalar('Train Loss',loss,epoch)
            loss.backward()
            optimizer.step()
            train_loss.update(loss.item(), inp.size(0))
        
        #..................
        filenames = []
        for i in ids:
            # load an example audio file, converting the data to mel spectrogram
            examplefkey = './musdb18hq/train/'+i+'/mixture.wav'
            example_audio = load_audio(examplefkey, seq = 2)
            if example_audio.shape[0] > 0:
                filenames.append(i)
        print(len(filenames))
        train_set = Dataset(ids = filenames, seq = 2, transforms = None)
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        print("done ...")

        for i, (inp,phase, gt_bass,gt_vocals,gt_drums,gt_others) in enumerate(train_loader):
            print(f'epoch {epoch+1}.... batch {batch+i+1}')
            inp = torch.FloatTensor(inp)
            mean = torch.mean(inp)
            std = torch.std(inp)
            inp_n = (inp-mean)/std
            inp_n = torch.FloatTensor(inp_n)
            gt_bass = torch.FloatTensor(gt_bass)
            gt_vocals = torch.FloatTensor(gt_vocals)
            gt_drums = torch.FloatTensor(gt_drums)
            gt_others= torch.FloatTensor(gt_others)
            if cuda:
                inp = inp.cuda()
                inp_n = inp_n.cuda()
                gt_bass = gt_bass.cuda()
                gt_vocals = gt_vocals.cuda()
                gt_drums = gt_drums.cuda()
                gt_others= gt_others.cuda()
            optimizer.zero_grad()
            bass, vocals, drums, others = net(torch.FloatTensor(inp_n))
            mask_bass,mask_vocals,mask_drums,mask_others = TimeFreqMasking(bass, vocals, drums, others,cuda)

            pred_drums=inp*np.squeeze(mask_drums)
            pred_vocals=inp*np.squeeze(mask_vocals)
            pred_bass=inp*np.squeeze(mask_bass)
            pred_others=inp*np.squeeze(mask_others)

            loss = criterion(pred_bass,pred_vocals,pred_drums,pred_others, gt_bass,gt_vocals,gt_drums,gt_others)
            writer.add_scalar('Train Loss',loss,epoch)
            loss.backward()
            optimizer.step()
            train_loss.update(loss.item(), inp.size(0))
        
        #..............
        filenames = []
        for i in ids:
            # load an example audio file, converting the data to mel spectrogram
            examplefkey = './musdb18hq/train/'+i+'/mixture.wav'
            example_audio = load_audio(examplefkey, seq = 3)
            if example_audio.shape[0] > 0:
                filenames.append(i)
        print(len(filenames))
        train_set = Dataset(ids = filenames, seq = 3, transforms = None)
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        print("done ...")

        for i, (inp,phase, gt_bass,gt_vocals,gt_drums,gt_others) in enumerate(train_loader):
            print(f'epoch {epoch+1}.... batch {batch+i+1}')
            inp = torch.FloatTensor(inp)
            mean = torch.mean(inp)
            std = torch.std(inp)
            inp_n = (inp-mean)/std
            inp_n = torch.FloatTensor(inp_n)
            gt_bass = torch.FloatTensor(gt_bass)
            gt_vocals = torch.FloatTensor(gt_vocals)
            gt_drums = torch.FloatTensor(gt_drums)
            gt_others= torch.FloatTensor(gt_others)
            if cuda:
                inp = inp.cuda()
                inp_n = inp_n.cuda()
                gt_bass = gt_bass.cuda()
                gt_vocals = gt_vocals.cuda()
                gt_drums = gt_drums.cuda()
                gt_others= gt_others.cuda()
            optimizer.zero_grad()
            bass, vocals, drums, others = net(torch.FloatTensor(inp_n))
            mask_bass,mask_vocals,mask_drums,mask_others = TimeFreqMasking(bass, vocals, drums, others,cuda)

            pred_drums=inp*np.squeeze(mask_drums)
            pred_vocals=inp*np.squeeze(mask_vocals)
            pred_bass=inp*np.squeeze(mask_bass)
            pred_others=inp*np.squeeze(mask_others)

            loss = criterion(pred_bass,pred_vocals,pred_drums,pred_others, gt_bass,gt_vocals,gt_drums,gt_others)
            writer.add_scalar('Train Loss',loss,epoch)
            loss.backward()
            optimizer.step()
            train_loss.update(loss.item(), inp.size(0))
        
        #......................
        filenames = []
        for i in ids:
            # load an example audio file, converting the data to mel spectrogram
            examplefkey = './musdb18hq/train/'+i+'/mixture.wav'
            example_audio = load_audio(examplefkey, seq = 4)
            if example_audio.shape[0] > 0:
                filenames.append(i)
        print(len(filenames))
        train_set = Dataset(ids = filenames, seq = 4, transforms = None)
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        print("done ...")

        for i, (inp,phase, gt_bass,gt_vocals,gt_drums,gt_others) in enumerate(train_loader):
            print(f'epoch {epoch+1}.... batch {batch+i+1}')
            inp = torch.FloatTensor(inp)
            mean = torch.mean(inp)
            std = torch.std(inp)
            inp_n = (inp-mean)/std
            inp_n = torch.FloatTensor(inp_n)
            gt_bass = torch.FloatTensor(gt_bass)
            gt_vocals = torch.FloatTensor(gt_vocals)
            gt_drums = torch.FloatTensor(gt_drums)
            gt_others= torch.FloatTensor(gt_others)
            if cuda:
                inp = inp.cuda()
                inp_n = inp_n.cuda()
                gt_bass = gt_bass.cuda()
                gt_vocals = gt_vocals.cuda()
                gt_drums = gt_drums.cuda()
                gt_others= gt_others.cuda()
            optimizer.zero_grad()
            bass, vocals, drums, others = net(torch.FloatTensor(inp_n))
            mask_bass,mask_vocals,mask_drums,mask_others = TimeFreqMasking(bass, vocals, drums, others,cuda)

            pred_drums=inp*np.squeeze(mask_drums)
            pred_vocals=inp*np.squeeze(mask_vocals)
            pred_bass=inp*np.squeeze(mask_bass)
            pred_others=inp*np.squeeze(mask_others)

            loss = criterion(pred_bass,pred_vocals,pred_drums,pred_others, gt_bass,gt_vocals,gt_drums,gt_others)
            writer.add_scalar('Train Loss',loss,epoch)
            loss.backward()
            optimizer.step()
            train_loss.update(loss.item(), inp.size(0))
        #.........................
        
        filenames = []
        for i in ids:
            # load an example audio file, converting the data to mel spectrogram
            examplefkey = './musdb18hq/train/'+i+'/mixture.wav'
            example_audio = load_audio(examplefkey, seq = 5)
            if example_audio.shape[0] > 0:
                filenames.append(i)
        print(len(filenames))
        train_set = Dataset(ids = filenames, seq = 5, transforms = None)
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        print("done ...")

        for i, (inp,phase, gt_bass,gt_vocals,gt_drums,gt_others) in enumerate(train_loader):
            print(f'epoch {epoch+1}.... batch {batch+i+1}')
            inp = torch.FloatTensor(inp)
            mean = torch.mean(inp)
            std = torch.std(inp)
            inp_n = (inp-mean)/std
            inp_n = torch.FloatTensor(inp_n)
            gt_bass = torch.FloatTensor(gt_bass)
            gt_vocals = torch.FloatTensor(gt_vocals)
            gt_drums = torch.FloatTensor(gt_drums)
            gt_others= torch.FloatTensor(gt_others)
            if cuda:
                inp = inp.cuda()
                inp_n = inp_n.cuda()
                gt_bass = gt_bass.cuda()
                gt_vocals = gt_vocals.cuda()
                gt_drums = gt_drums.cuda()
                gt_others= gt_others.cuda()
            optimizer.zero_grad()
            bass, vocals, drums, others = net(torch.FloatTensor(inp_n))
            mask_bass,mask_vocals,mask_drums,mask_others = TimeFreqMasking(bass, vocals, drums, others,cuda)

            pred_drums=inp*np.squeeze(mask_drums)
            pred_vocals=inp*np.squeeze(mask_vocals)
            pred_bass=inp*np.squeeze(mask_bass)
            pred_others=inp*np.squeeze(mask_others)

            loss = criterion(pred_bass,pred_vocals,pred_drums,pred_others, gt_bass,gt_vocals,gt_drums,gt_others)
            writer.add_scalar('Train Loss',loss,epoch)
            loss.backward()
            optimizer.step()
            train_loss.update(loss.item(), inp.size(0))
        #.........................

        val_loss = Average()
        net.eval()
        for i,(val_inp,val_phase, gt_bass,gt_vocals,gt_drums,gt_others) in enumerate(val_loader):
            
            val_inp = torch.FloatTensor(val_inp)
            val_mean = torch.mean(val_inp)
            val_std = torch.std(val_inp)
            val_inp_n = (val_inp-val_mean)/val_std
            val_inp_n = torch.FloatTensor(val_inp_n)
            gt_bass = torch.FloatTensor(gt_bass)
            gt_vocals = torch.FloatTensor(gt_vocals)
            gt_drums = torch.FloatTensor(gt_drums)
            gt_others= torch.FloatTensor(gt_others)
            if cuda:
                val_inp = val_inp.cuda()
                val_inp_n = val_inp_n.cuda()
                gt_bass = gt_bass.cuda()
                gt_vocals = gt_vocals.cuda()
                gt_drums = gt_drums.cuda()
                gt_others = gt_others.cuda()

            bass, vocals, drums, others = net(val_inp_n)
            mask_bass,mask_vocals,mask_drums,mask_others = TimeFreqMasking(bass, vocals, drums, others,cuda)

            pred_drums=val_inp*np.squeeze(mask_drums)
            pred_vocals=val_inp*np.squeeze(mask_vocals)
            pred_bass=val_inp*np.squeeze(mask_bass)
            pred_others=val_inp*np.squeeze(mask_others)
            

            """if (epoch)%10==0:
                writer.add_image('Validation Input',val_inp,epoch)
                writer.add_image('Validation Bass GT ',gt_bass,epoch)
                writer.add_image('Validation Bass Pred ',pred_bass,epoch)
                writer.add_image('Validation Vocals GT ',gt_vocals,epoch)
                writer.add_image('Validation Vocals Pred ',pred_vocals,epoch)
                writer.add_image('Validation Drums GT ',gt_drums,epoch)
                writer.add_image('Validation Drums Pred ',pred_drums,epoch)
                writer.add_image('Validation Other GT ',gt_others,epoch)
                writer.add_image('Validation Others Pred ',pred_others,epoch)"""

            vloss = criterion(pred_bass,pred_vocals,pred_drums,pred_others, gt_bass,gt_vocals,gt_drums, gt_others)
            writer.add_scalar('Validation loss',vloss,epoch)
            val_loss.update(vloss.item(), inp.size(0))

        print("Epoch {}, Training Loss: {}, Validation Loss: {}".format(epoch+1, train_loss.avg(), val_loss.avg()))
        torch.save(net.state_dict(), 'Model2_Weights_v2/Weights_{}_{}.pth'.format(epoch+1, val_loss.avg()))
    return net

In [13]:
train()

preparing training data ...
106
done ...
epoch 1.... batch 1
epoch 1.... batch 2
epoch 1.... batch 3
epoch 1.... batch 4
epoch 1.... batch 5
epoch 1.... batch 6
epoch 1.... batch 7
epoch 1.... batch 8
0




100
done ...
epoch 1.... batch 1
epoch 1.... batch 2
epoch 1.... batch 3
epoch 1.... batch 4
epoch 1.... batch 5
epoch 1.... batch 6
epoch 1.... batch 7




95
done ...
epoch 1.... batch 1
epoch 1.... batch 2
epoch 1.... batch 3
epoch 1.... batch 4
epoch 1.... batch 5
epoch 1.... batch 6
epoch 1.... batch 7




94
done ...
epoch 1.... batch 1
epoch 1.... batch 2
epoch 1.... batch 3
epoch 1.... batch 4
epoch 1.... batch 5
epoch 1.... batch 6
epoch 1.... batch 7




93
done ...
epoch 1.... batch 1
epoch 1.... batch 2
epoch 1.... batch 3
epoch 1.... batch 4
epoch 1.... batch 5
epoch 1.... batch 6
epoch 1.... batch 7




91
done ...
epoch 1.... batch 1
epoch 1.... batch 2
epoch 1.... batch 3
epoch 1.... batch 4
epoch 1.... batch 5
epoch 1.... batch 6
epoch 1.... batch 7
Epoch 1, Training Loss: 26696101.63385147, Validation Loss: 178887120.0
106
done ...
epoch 2.... batch 1
epoch 2.... batch 2
epoch 2.... batch 3
epoch 2.... batch 4
epoch 2.... batch 5
epoch 2.... batch 6
epoch 2.... batch 7
epoch 2.... batch 8
0




100
done ...
epoch 2.... batch 1
epoch 2.... batch 2
epoch 2.... batch 3
epoch 2.... batch 4
epoch 2.... batch 5
epoch 2.... batch 6
epoch 2.... batch 7




95
done ...
epoch 2.... batch 1
epoch 2.... batch 2
epoch 2.... batch 3
epoch 2.... batch 4
epoch 2.... batch 5
epoch 2.... batch 6
epoch 2.... batch 7




94
done ...
epoch 2.... batch 1
epoch 2.... batch 2
epoch 2.... batch 3
epoch 2.... batch 4
epoch 2.... batch 5
epoch 2.... batch 6
epoch 2.... batch 7




93
done ...
epoch 2.... batch 1
epoch 2.... batch 2
epoch 2.... batch 3
epoch 2.... batch 4
epoch 2.... batch 5
epoch 2.... batch 6
epoch 2.... batch 7




91
done ...
epoch 2.... batch 1
epoch 2.... batch 2
epoch 2.... batch 3
epoch 2.... batch 4
epoch 2.... batch 5
epoch 2.... batch 6
epoch 2.... batch 7
Epoch 2, Training Loss: 19930576.99654577, Validation Loss: 153018400.0
106
done ...
epoch 3.... batch 1
epoch 3.... batch 2
epoch 3.... batch 3
epoch 3.... batch 4
epoch 3.... batch 5
epoch 3.... batch 6
epoch 3.... batch 7
epoch 3.... batch 8
0




100
done ...
epoch 3.... batch 1
epoch 3.... batch 2
epoch 3.... batch 3
epoch 3.... batch 4
epoch 3.... batch 5
epoch 3.... batch 6
epoch 3.... batch 7




95
done ...
epoch 3.... batch 1
epoch 3.... batch 2
epoch 3.... batch 3
epoch 3.... batch 4
epoch 3.... batch 5
epoch 3.... batch 6
epoch 3.... batch 7




94
done ...
epoch 3.... batch 1
epoch 3.... batch 2
epoch 3.... batch 3
epoch 3.... batch 4
epoch 3.... batch 5
epoch 3.... batch 6
epoch 3.... batch 7




93
done ...
epoch 3.... batch 1
epoch 3.... batch 2
epoch 3.... batch 3
epoch 3.... batch 4
epoch 3.... batch 5
epoch 3.... batch 6
epoch 3.... batch 7




91
done ...
epoch 3.... batch 1
epoch 3.... batch 2
epoch 3.... batch 3
epoch 3.... batch 4
epoch 3.... batch 5
epoch 3.... batch 6
epoch 3.... batch 7
Epoch 3, Training Loss: 17417189.170984454, Validation Loss: 149552976.0
106
done ...
epoch 4.... batch 1
epoch 4.... batch 2
epoch 4.... batch 3
epoch 4.... batch 4
epoch 4.... batch 5
epoch 4.... batch 6
epoch 4.... batch 7
epoch 4.... batch 8
0




100
done ...
epoch 4.... batch 1
epoch 4.... batch 2
epoch 4.... batch 3
epoch 4.... batch 4
epoch 4.... batch 5
epoch 4.... batch 6
epoch 4.... batch 7




95
done ...
epoch 4.... batch 1
epoch 4.... batch 2
epoch 4.... batch 3
epoch 4.... batch 4
epoch 4.... batch 5
epoch 4.... batch 6
epoch 4.... batch 7




94
done ...
epoch 4.... batch 1
epoch 4.... batch 2
epoch 4.... batch 3
epoch 4.... batch 4
epoch 4.... batch 5
epoch 4.... batch 6
epoch 4.... batch 7




93
done ...
epoch 4.... batch 1
epoch 4.... batch 2
epoch 4.... batch 3
epoch 4.... batch 4
epoch 4.... batch 5
epoch 4.... batch 6
epoch 4.... batch 7




91
done ...
epoch 4.... batch 1
epoch 4.... batch 2
epoch 4.... batch 3
epoch 4.... batch 4
epoch 4.... batch 5
epoch 4.... batch 6
epoch 4.... batch 7
Epoch 4, Training Loss: 16128488.418825561, Validation Loss: 148545056.0
106
done ...
epoch 5.... batch 1
epoch 5.... batch 2
epoch 5.... batch 3
epoch 5.... batch 4
epoch 5.... batch 5
epoch 5.... batch 6
epoch 5.... batch 7
epoch 5.... batch 8
0




100
done ...
epoch 5.... batch 1
epoch 5.... batch 2
epoch 5.... batch 3
epoch 5.... batch 4
epoch 5.... batch 5
epoch 5.... batch 6
epoch 5.... batch 7




95
done ...
epoch 5.... batch 1
epoch 5.... batch 2
epoch 5.... batch 3
epoch 5.... batch 4
epoch 5.... batch 5
epoch 5.... batch 6
epoch 5.... batch 7




94
done ...
epoch 5.... batch 1
epoch 5.... batch 2
epoch 5.... batch 3
epoch 5.... batch 4
epoch 5.... batch 5
epoch 5.... batch 6
epoch 5.... batch 7




93
done ...
epoch 5.... batch 1
epoch 5.... batch 2
epoch 5.... batch 3
epoch 5.... batch 4
epoch 5.... batch 5
epoch 5.... batch 6
epoch 5.... batch 7




91
done ...
epoch 5.... batch 1
epoch 5.... batch 2
epoch 5.... batch 3
epoch 5.... batch 4
epoch 5.... batch 5
epoch 5.... batch 6
epoch 5.... batch 7
Epoch 5, Training Loss: 15381955.070811745, Validation Loss: 146300768.0
106
done ...
epoch 6.... batch 1
epoch 6.... batch 2
epoch 6.... batch 3
epoch 6.... batch 4
epoch 6.... batch 5
epoch 6.... batch 6
epoch 6.... batch 7
epoch 6.... batch 8
0




100
done ...
epoch 6.... batch 1
epoch 6.... batch 2
epoch 6.... batch 3
epoch 6.... batch 4
epoch 6.... batch 5
epoch 6.... batch 6
epoch 6.... batch 7




95
done ...
epoch 6.... batch 1
epoch 6.... batch 2
epoch 6.... batch 3
epoch 6.... batch 4
epoch 6.... batch 5
epoch 6.... batch 6
epoch 6.... batch 7




94
done ...
epoch 6.... batch 1
epoch 6.... batch 2
epoch 6.... batch 3
epoch 6.... batch 4
epoch 6.... batch 5
epoch 6.... batch 6
epoch 6.... batch 7




93
done ...
epoch 6.... batch 1
epoch 6.... batch 2
epoch 6.... batch 3
epoch 6.... batch 4
epoch 6.... batch 5
epoch 6.... batch 6
epoch 6.... batch 7




91
done ...
epoch 6.... batch 1
epoch 6.... batch 2
epoch 6.... batch 3
epoch 6.... batch 4
epoch 6.... batch 5
epoch 6.... batch 6
epoch 6.... batch 7
Epoch 6, Training Loss: 14702013.397236615, Validation Loss: 147635168.0
106
done ...
epoch 7.... batch 1
epoch 7.... batch 2
epoch 7.... batch 3
epoch 7.... batch 4
epoch 7.... batch 5
epoch 7.... batch 6
epoch 7.... batch 7
epoch 7.... batch 8
0




100
done ...
epoch 7.... batch 1
epoch 7.... batch 2
epoch 7.... batch 3
epoch 7.... batch 4
epoch 7.... batch 5
epoch 7.... batch 6
epoch 7.... batch 7




95
done ...
epoch 7.... batch 1
epoch 7.... batch 2
epoch 7.... batch 3
epoch 7.... batch 4
epoch 7.... batch 5
epoch 7.... batch 6
epoch 7.... batch 7




94
done ...
epoch 7.... batch 1
epoch 7.... batch 2
epoch 7.... batch 3
epoch 7.... batch 4
epoch 7.... batch 5
epoch 7.... batch 6
epoch 7.... batch 7




93
done ...
epoch 7.... batch 1
epoch 7.... batch 2
epoch 7.... batch 3
epoch 7.... batch 4
epoch 7.... batch 5
epoch 7.... batch 6
epoch 7.... batch 7




91
done ...
epoch 7.... batch 1
epoch 7.... batch 2
epoch 7.... batch 3
epoch 7.... batch 4
epoch 7.... batch 5
epoch 7.... batch 6
epoch 7.... batch 7
Epoch 7, Training Loss: 14276438.830742659, Validation Loss: 143649072.0
106
done ...
epoch 8.... batch 1
epoch 8.... batch 2
epoch 8.... batch 3
epoch 8.... batch 4
epoch 8.... batch 5
epoch 8.... batch 6
epoch 8.... batch 7
epoch 8.... batch 8
0




100
done ...
epoch 8.... batch 1
epoch 8.... batch 2
epoch 8.... batch 3
epoch 8.... batch 4
epoch 8.... batch 5
epoch 8.... batch 6
epoch 8.... batch 7




95
done ...
epoch 8.... batch 1
epoch 8.... batch 2
epoch 8.... batch 3
epoch 8.... batch 4
epoch 8.... batch 5
epoch 8.... batch 6
epoch 8.... batch 7




94
done ...
epoch 8.... batch 1
epoch 8.... batch 2
epoch 8.... batch 3
epoch 8.... batch 4
epoch 8.... batch 5
epoch 8.... batch 6
epoch 8.... batch 7




93
done ...
epoch 8.... batch 1
epoch 8.... batch 2
epoch 8.... batch 3
epoch 8.... batch 4
epoch 8.... batch 5
epoch 8.... batch 6
epoch 8.... batch 7




91
done ...
epoch 8.... batch 1
epoch 8.... batch 2
epoch 8.... batch 3
epoch 8.... batch 4
epoch 8.... batch 5
epoch 8.... batch 6
epoch 8.... batch 7
Epoch 8, Training Loss: 13918890.257340241, Validation Loss: 144891072.0
106
done ...
epoch 9.... batch 1
epoch 9.... batch 2
epoch 9.... batch 3
epoch 9.... batch 4
epoch 9.... batch 5
epoch 9.... batch 6
epoch 9.... batch 7
epoch 9.... batch 8
0




100
done ...
epoch 9.... batch 1
epoch 9.... batch 2
epoch 9.... batch 3
epoch 9.... batch 4
epoch 9.... batch 5
epoch 9.... batch 6
epoch 9.... batch 7




95
done ...
epoch 9.... batch 1
epoch 9.... batch 2
epoch 9.... batch 3
epoch 9.... batch 4
epoch 9.... batch 5
epoch 9.... batch 6
epoch 9.... batch 7




94
done ...
epoch 9.... batch 1
epoch 9.... batch 2
epoch 9.... batch 3
epoch 9.... batch 4
epoch 9.... batch 5
epoch 9.... batch 6
epoch 9.... batch 7




93
done ...
epoch 9.... batch 1
epoch 9.... batch 2
epoch 9.... batch 3
epoch 9.... batch 4
epoch 9.... batch 5
epoch 9.... batch 6
epoch 9.... batch 7




91
done ...
epoch 9.... batch 1
epoch 9.... batch 2
epoch 9.... batch 3
epoch 9.... batch 4
epoch 9.... batch 5
epoch 9.... batch 6
epoch 9.... batch 7
Epoch 9, Training Loss: 13550926.097582038, Validation Loss: 146864832.0
106
done ...
epoch 10.... batch 1
epoch 10.... batch 2
epoch 10.... batch 3
epoch 10.... batch 4
epoch 10.... batch 5
epoch 10.... batch 6
epoch 10.... batch 7
epoch 10.... batch 8
0




100
done ...
epoch 10.... batch 1
epoch 10.... batch 2
epoch 10.... batch 3
epoch 10.... batch 4
epoch 10.... batch 5
epoch 10.... batch 6
epoch 10.... batch 7




95
done ...
epoch 10.... batch 1
epoch 10.... batch 2
epoch 10.... batch 3
epoch 10.... batch 4
epoch 10.... batch 5
epoch 10.... batch 6
epoch 10.... batch 7




94
done ...
epoch 10.... batch 1
epoch 10.... batch 2
epoch 10.... batch 3
epoch 10.... batch 4
epoch 10.... batch 5
epoch 10.... batch 6
epoch 10.... batch 7




93
done ...
epoch 10.... batch 1
epoch 10.... batch 2
epoch 10.... batch 3
epoch 10.... batch 4
epoch 10.... batch 5
epoch 10.... batch 6
epoch 10.... batch 7




91
done ...
epoch 10.... batch 1
epoch 10.... batch 2
epoch 10.... batch 3
epoch 10.... batch 4
epoch 10.... batch 5
epoch 10.... batch 6
epoch 10.... batch 7
Epoch 10, Training Loss: 13362945.962003455, Validation Loss: 143580000.0
106
done ...
epoch 11.... batch 1
epoch 11.... batch 2
epoch 11.... batch 3
epoch 11.... batch 4
epoch 11.... batch 5
epoch 11.... batch 6
epoch 11.... batch 7
epoch 11.... batch 8
0




100
done ...
epoch 11.... batch 1
epoch 11.... batch 2
epoch 11.... batch 3
epoch 11.... batch 4
epoch 11.... batch 5
epoch 11.... batch 6
epoch 11.... batch 7




95
done ...
epoch 11.... batch 1
epoch 11.... batch 2
epoch 11.... batch 3
epoch 11.... batch 4
epoch 11.... batch 5
epoch 11.... batch 6
epoch 11.... batch 7




94
done ...
epoch 11.... batch 1
epoch 11.... batch 2
epoch 11.... batch 3
epoch 11.... batch 4
epoch 11.... batch 5
epoch 11.... batch 6
epoch 11.... batch 7




93
done ...
epoch 11.... batch 1
epoch 11.... batch 2
epoch 11.... batch 3
epoch 11.... batch 4
epoch 11.... batch 5
epoch 11.... batch 6
epoch 11.... batch 7




91
done ...
epoch 11.... batch 1
epoch 11.... batch 2
epoch 11.... batch 3
epoch 11.... batch 4
epoch 11.... batch 5
epoch 11.... batch 6
epoch 11.... batch 7
Epoch 11, Training Loss: 13145455.303972365, Validation Loss: 148357168.0
106
done ...
epoch 12.... batch 1
epoch 12.... batch 2
epoch 12.... batch 3
epoch 12.... batch 4
epoch 12.... batch 5
epoch 12.... batch 6
epoch 12.... batch 7
epoch 12.... batch 8
0




100
done ...
epoch 12.... batch 1
epoch 12.... batch 2
epoch 12.... batch 3
epoch 12.... batch 4
epoch 12.... batch 5
epoch 12.... batch 6
epoch 12.... batch 7




95
done ...
epoch 12.... batch 1
epoch 12.... batch 2
epoch 12.... batch 3
epoch 12.... batch 4
epoch 12.... batch 5
epoch 12.... batch 6
epoch 12.... batch 7




94
done ...
epoch 12.... batch 1
epoch 12.... batch 2
epoch 12.... batch 3
epoch 12.... batch 4
epoch 12.... batch 5
epoch 12.... batch 6
epoch 12.... batch 7




93
done ...
epoch 12.... batch 1
epoch 12.... batch 2
epoch 12.... batch 3
epoch 12.... batch 4
epoch 12.... batch 5
epoch 12.... batch 6
epoch 12.... batch 7




91
done ...
epoch 12.... batch 1
epoch 12.... batch 2
epoch 12.... batch 3
epoch 12.... batch 4
epoch 12.... batch 5
epoch 12.... batch 6
epoch 12.... batch 7
Epoch 12, Training Loss: 13003491.778065631, Validation Loss: 141569328.0
106
done ...
epoch 13.... batch 1
epoch 13.... batch 2
epoch 13.... batch 3
epoch 13.... batch 4
epoch 13.... batch 5
epoch 13.... batch 6
epoch 13.... batch 7
epoch 13.... batch 8
0




100
done ...
epoch 13.... batch 1
epoch 13.... batch 2
epoch 13.... batch 3
epoch 13.... batch 4
epoch 13.... batch 5
epoch 13.... batch 6
epoch 13.... batch 7




95
done ...
epoch 13.... batch 1
epoch 13.... batch 2
epoch 13.... batch 3
epoch 13.... batch 4
epoch 13.... batch 5
epoch 13.... batch 6
epoch 13.... batch 7




94
done ...
epoch 13.... batch 1
epoch 13.... batch 2
epoch 13.... batch 3
epoch 13.... batch 4
epoch 13.... batch 5
epoch 13.... batch 6
epoch 13.... batch 7




93
done ...
epoch 13.... batch 1
epoch 13.... batch 2
epoch 13.... batch 3
epoch 13.... batch 4
epoch 13.... batch 5
epoch 13.... batch 6
epoch 13.... batch 7




91
done ...
epoch 13.... batch 1
epoch 13.... batch 2
epoch 13.... batch 3
epoch 13.... batch 4
epoch 13.... batch 5
epoch 13.... batch 6
epoch 13.... batch 7
Epoch 13, Training Loss: 12822515.307426598, Validation Loss: 169421536.0
106
done ...
epoch 14.... batch 1
epoch 14.... batch 2
epoch 14.... batch 3
epoch 14.... batch 4
epoch 14.... batch 5
epoch 14.... batch 6
epoch 14.... batch 7
epoch 14.... batch 8
0




100
done ...
epoch 14.... batch 1
epoch 14.... batch 2
epoch 14.... batch 3
epoch 14.... batch 4
epoch 14.... batch 5
epoch 14.... batch 6
epoch 14.... batch 7




95
done ...
epoch 14.... batch 1
epoch 14.... batch 2
epoch 14.... batch 3
epoch 14.... batch 4
epoch 14.... batch 5
epoch 14.... batch 6
epoch 14.... batch 7




94
done ...
epoch 14.... batch 1
epoch 14.... batch 2
epoch 14.... batch 3
epoch 14.... batch 4
epoch 14.... batch 5
epoch 14.... batch 6
epoch 14.... batch 7




93
done ...
epoch 14.... batch 1
epoch 14.... batch 2
epoch 14.... batch 3
epoch 14.... batch 4
epoch 14.... batch 5
epoch 14.... batch 6
epoch 14.... batch 7




91
done ...
epoch 14.... batch 1
epoch 14.... batch 2
epoch 14.... batch 3
epoch 14.... batch 4
epoch 14.... batch 5
epoch 14.... batch 6
epoch 14.... batch 7
Epoch 14, Training Loss: 12695925.936096719, Validation Loss: 165894960.0
106
done ...
epoch 15.... batch 1
epoch 15.... batch 2
epoch 15.... batch 3
epoch 15.... batch 4
epoch 15.... batch 5
epoch 15.... batch 6
epoch 15.... batch 7
epoch 15.... batch 8
0




100
done ...
epoch 15.... batch 1
epoch 15.... batch 2
epoch 15.... batch 3
epoch 15.... batch 4
epoch 15.... batch 5
epoch 15.... batch 6
epoch 15.... batch 7




95
done ...
epoch 15.... batch 1
epoch 15.... batch 2
epoch 15.... batch 3
epoch 15.... batch 4
epoch 15.... batch 5
epoch 15.... batch 6
epoch 15.... batch 7




94
done ...
epoch 15.... batch 1
epoch 15.... batch 2
epoch 15.... batch 3
epoch 15.... batch 4
epoch 15.... batch 5
epoch 15.... batch 6
epoch 15.... batch 7




93
done ...
epoch 15.... batch 1
epoch 15.... batch 2
epoch 15.... batch 3
epoch 15.... batch 4
epoch 15.... batch 5
epoch 15.... batch 6
epoch 15.... batch 7




91
done ...
epoch 15.... batch 1
epoch 15.... batch 2
epoch 15.... batch 3
epoch 15.... batch 4
epoch 15.... batch 5
epoch 15.... batch 6
epoch 15.... batch 7
Epoch 15, Training Loss: 12523748.974093264, Validation Loss: 144422528.0
106
done ...
epoch 16.... batch 1
epoch 16.... batch 2
epoch 16.... batch 3
epoch 16.... batch 4
epoch 16.... batch 5
epoch 16.... batch 6
epoch 16.... batch 7
epoch 16.... batch 8
0




100
done ...
epoch 16.... batch 1
epoch 16.... batch 2
epoch 16.... batch 3
epoch 16.... batch 4
epoch 16.... batch 5
epoch 16.... batch 6
epoch 16.... batch 7




95
done ...
epoch 16.... batch 1
epoch 16.... batch 2
epoch 16.... batch 3
epoch 16.... batch 4
epoch 16.... batch 5
epoch 16.... batch 6
epoch 16.... batch 7




94
done ...
epoch 16.... batch 1
epoch 16.... batch 2
epoch 16.... batch 3
epoch 16.... batch 4
epoch 16.... batch 5
epoch 16.... batch 6
epoch 16.... batch 7




93
done ...
epoch 16.... batch 1
epoch 16.... batch 2
epoch 16.... batch 3
epoch 16.... batch 4
epoch 16.... batch 5
epoch 16.... batch 6
epoch 16.... batch 7




91
done ...
epoch 16.... batch 1
epoch 16.... batch 2
epoch 16.... batch 3
epoch 16.... batch 4
epoch 16.... batch 5
epoch 16.... batch 6
epoch 16.... batch 7
Epoch 16, Training Loss: 12290292.0041019, Validation Loss: 158943280.0
106
done ...
epoch 17.... batch 1
epoch 17.... batch 2
epoch 17.... batch 3
epoch 17.... batch 4
epoch 17.... batch 5
epoch 17.... batch 6
epoch 17.... batch 7
epoch 17.... batch 8
0




100
done ...
epoch 17.... batch 1
epoch 17.... batch 2
epoch 17.... batch 3
epoch 17.... batch 4
epoch 17.... batch 5
epoch 17.... batch 6
epoch 17.... batch 7




95
done ...
epoch 17.... batch 1
epoch 17.... batch 2
epoch 17.... batch 3
epoch 17.... batch 4
epoch 17.... batch 5
epoch 17.... batch 6
epoch 17.... batch 7




94
done ...
epoch 17.... batch 1
epoch 17.... batch 2
epoch 17.... batch 3
epoch 17.... batch 4
epoch 17.... batch 5
epoch 17.... batch 6
epoch 17.... batch 7




93
done ...
epoch 17.... batch 1
epoch 17.... batch 2
epoch 17.... batch 3
epoch 17.... batch 4
epoch 17.... batch 5
epoch 17.... batch 6
epoch 17.... batch 7




91
done ...
epoch 17.... batch 1
epoch 17.... batch 2
epoch 17.... batch 3
epoch 17.... batch 4
epoch 17.... batch 5
epoch 17.... batch 6
epoch 17.... batch 7
Epoch 17, Training Loss: 12104560.962219344, Validation Loss: 161474784.0
106
done ...
epoch 18.... batch 1
epoch 18.... batch 2
epoch 18.... batch 3
epoch 18.... batch 4
epoch 18.... batch 5
epoch 18.... batch 6
epoch 18.... batch 7
epoch 18.... batch 8
0




100
done ...
epoch 18.... batch 1
epoch 18.... batch 2
epoch 18.... batch 3
epoch 18.... batch 4
epoch 18.... batch 5
epoch 18.... batch 6
epoch 18.... batch 7




95
done ...
epoch 18.... batch 1
epoch 18.... batch 2
epoch 18.... batch 3
epoch 18.... batch 4
epoch 18.... batch 5
epoch 18.... batch 6
epoch 18.... batch 7




94
done ...
epoch 18.... batch 1
epoch 18.... batch 2
epoch 18.... batch 3
epoch 18.... batch 4
epoch 18.... batch 5
epoch 18.... batch 6
epoch 18.... batch 7




93
done ...
epoch 18.... batch 1
epoch 18.... batch 2
epoch 18.... batch 3
epoch 18.... batch 4
epoch 18.... batch 5
epoch 18.... batch 6
epoch 18.... batch 7




91
done ...
epoch 18.... batch 1
epoch 18.... batch 2
epoch 18.... batch 3
epoch 18.... batch 4
epoch 18.... batch 5
epoch 18.... batch 6
epoch 18.... batch 7
Epoch 18, Training Loss: 11971846.898100173, Validation Loss: 182884736.0
106
done ...
epoch 19.... batch 1
epoch 19.... batch 2
epoch 19.... batch 3
epoch 19.... batch 4
epoch 19.... batch 5
epoch 19.... batch 6
epoch 19.... batch 7
epoch 19.... batch 8
0




100
done ...
epoch 19.... batch 1
epoch 19.... batch 2
epoch 19.... batch 3
epoch 19.... batch 4
epoch 19.... batch 5
epoch 19.... batch 6
epoch 19.... batch 7




95
done ...
epoch 19.... batch 1
epoch 19.... batch 2
epoch 19.... batch 3
epoch 19.... batch 4
epoch 19.... batch 5
epoch 19.... batch 6
epoch 19.... batch 7




94
done ...
epoch 19.... batch 1
epoch 19.... batch 2
epoch 19.... batch 3
epoch 19.... batch 4
epoch 19.... batch 5
epoch 19.... batch 6
epoch 19.... batch 7




93
done ...
epoch 19.... batch 1
epoch 19.... batch 2
epoch 19.... batch 3
epoch 19.... batch 4
epoch 19.... batch 5
epoch 19.... batch 6
epoch 19.... batch 7




91
done ...
epoch 19.... batch 1
epoch 19.... batch 2
epoch 19.... batch 3
epoch 19.... batch 4
epoch 19.... batch 5
epoch 19.... batch 6
epoch 19.... batch 7
Epoch 19, Training Loss: 11810475.995628238, Validation Loss: 177394496.0
106
done ...
epoch 20.... batch 1
epoch 20.... batch 2
epoch 20.... batch 3
epoch 20.... batch 4
epoch 20.... batch 5
epoch 20.... batch 6
epoch 20.... batch 7
epoch 20.... batch 8
0




100
done ...
epoch 20.... batch 1
epoch 20.... batch 2
epoch 20.... batch 3
epoch 20.... batch 4
epoch 20.... batch 5
epoch 20.... batch 6
epoch 20.... batch 7




95
done ...
epoch 20.... batch 1
epoch 20.... batch 2
epoch 20.... batch 3
epoch 20.... batch 4
epoch 20.... batch 5
epoch 20.... batch 6
epoch 20.... batch 7




94
done ...
epoch 20.... batch 1
epoch 20.... batch 2
epoch 20.... batch 3
epoch 20.... batch 4
epoch 20.... batch 5
epoch 20.... batch 6
epoch 20.... batch 7




93
done ...
epoch 20.... batch 1
epoch 20.... batch 2
epoch 20.... batch 3
epoch 20.... batch 4
epoch 20.... batch 5
epoch 20.... batch 6
epoch 20.... batch 7




91
done ...
epoch 20.... batch 1
epoch 20.... batch 2
epoch 20.... batch 3
epoch 20.... batch 4
epoch 20.... batch 5
epoch 20.... batch 6
epoch 20.... batch 7
Epoch 20, Training Loss: 11715104.589810017, Validation Loss: 169725984.0


ConvNet(
  (vconv): Conv2d(1, 15, kernel_size=(1103, 1), stride=(1, 1))
  (hconv): Conv2d(15, 25, kernel_size=(1, 50), stride=(1, 1))
  (fc0): Linear(in_features=118850, out_features=128, bias=True)
  (fc1): Linear(in_features=128, out_features=118850, bias=True)
  (fc2): Linear(in_features=128, out_features=118850, bias=True)
  (fc3): Linear(in_features=128, out_features=118850, bias=True)
  (fc4): Linear(in_features=128, out_features=118850, bias=True)
  (hdeconv1): ConvTranspose2d(25, 15, kernel_size=(1, 50), stride=(1, 1))
  (hdeconv2): ConvTranspose2d(25, 15, kernel_size=(1, 50), stride=(1, 1))
  (hdeconv3): ConvTranspose2d(25, 15, kernel_size=(1, 50), stride=(1, 1))
  (hdeconv4): ConvTranspose2d(25, 15, kernel_size=(1, 50), stride=(1, 1))
  (vdeconv1): ConvTranspose2d(15, 1, kernel_size=(1103, 1), stride=(1, 1))
  (vdeconv2): ConvTranspose2d(15, 1, kernel_size=(1103, 1), stride=(1, 1))
  (vdeconv3): ConvTranspose2d(15, 1, kernel_size=(1103, 1), stride=(1, 1))
  (vdeconv4): ConvTr

In [14]:
net = ConvNet(inp_size)
    # net.load_state_dict(torch.load('Weights/Weights_200_3722932.6015625.pth')) #least score Weights so far
net.load_state_dict(torch.load('Model2_Weights_v2/Weights_12_141569328.0.pth'))
net.eval()
filenames = []
ids = os.listdir('./musdb18hq/test')
ids = [x for x in ids if not x.startswith('.')]
filenames = []
count = 0
for i in ids:
    examplefkey = './musdb18hq/test/'+i+'/mixture.wav'
    example_audio = load_audio(examplefkey, seq = 0)
    if example_audio.shape[0] > 0:
        filenames.append(i)
    count += 1
    if count > 10:
        break

print(len(filenames))
    
for i in filenames:
# load an example audio file, converting the data to mel spectrogram
    total_SDR_bass = total_ISR_bass =  total_SIR_bass = total_SAR_bass =  0
    total_SDR_drum = total_ISR_drum = total_SIR_drum = total_SAR_drum =  0
    total_SDR_vocal = total_ISR_vocal = total_SIR_vocal = total_SAR_vocal =  0
    total_SDR_other = total_ISR_other = total_SIR_other = total_SAR_other =  0
    
    examplefkey = './musdb18hq/test/'+i
    mixture, m_phase = load_mixture(examplefkey+'/mixture',0)
    bass_gt = load_segments(examplefkey+'/bass', seq = 0)
    vocals_gt = load_segments(examplefkey+'/vocals', seq = 0)
    drums_gt = load_segments(examplefkey+'/drums', seq = 0)
    others_gt = load_segments(examplefkey+'/other', seq = 0)
    mean = torch.mean(torch.tensor(mixture))
    std = torch.std(torch.tensor(mixture))
    inp_n = torch.tensor(mixture)
    (time_len, mel_bins) = inp_n.shape
    inp_n = inp_n.view(1,time_len, mel_bins)
    print(inp_n.shape)
    bass_mag, vocals_mag, drums_mag,others_mag = net(torch.tensor(inp_n))
    print(bass_mag.shape)
    drums_mag=torch.tensor(mixture)*np.squeeze(bass_mag)
    vocals_mag=torch.tensor(mixture)*np.squeeze(vocals_mag)
    bass_mag=torch.tensor(mixture)*np.squeeze(drums_mag)
    others_mag=torch.tensor(mixture)*np.squeeze(others_mag)
    print(bass_mag.shape)    
    vocals = np.squeeze(vocals_mag.detach().numpy())
        #print(vocals.shape)
    bass = np.squeeze(bass_mag.detach().numpy())
    drums = np.squeeze(drums_mag.detach().numpy())
    others = np.squeeze(others_mag.detach().numpy())
    
    shape = drums_gt.flatten().shape
    
    
    SDR_bass,ISR_bass, SIR_bass, SAR_bass, perm = mir_eval.separation.bss_eval_images(bass_gt.reshape((1, shape[0],1)), bass.reshape((1, shape[0], 1)))
    total_SDR_bass += SDR_bass[0]
    total_ISR_bass += ISR_bass[0]
    total_SIR_bass += SIR_bass[0]
    total_SAR_bass += SAR_bass[0]
    SDR_drum, ISR_drum, SIR_drum, SAR_drum, perm = mir_eval.separation.bss_eval_images(drums_gt.reshape((1, shape[0],1)), drums.reshape((1, shape[0],1)))
    total_SDR_drum += SDR_drum[0]
    total_ISR_drum += ISR_drum[0]
    total_SIR_drum += SIR_drum[0]
    total_SAR_drum += SAR_drum[0]
    SDR_vocal, ISR_vocal, SIR_vocal, SAR_vocal, perm = mir_eval.separation.bss_eval_images(vocals_gt.reshape((1, shape[0],1)), vocals.reshape((1, shape[0],1)))
    total_SDR_vocal += SDR_vocal[0]
    total_ISR_vocal += ISR_vocal[0]
    total_SIR_vocal += SIR_vocal[0]
    total_SAR_vocal += SAR_vocal[0]
    SDR_other, ISR_other, SIR_other, SAR_other, perm = mir_eval.separation.bss_eval_images(others_gt.reshape((1, shape[0],1)), others.reshape((1, shape[0],1)))
    total_SDR_other += SDR_other[0]
    total_ISR_other += ISR_other[0]
    total_SIR_other += SIR_other[0]
    total_SAR_other += SAR_other[0]
    
mean_SDR_bass = total_SDR_bass/len(filenames)
mean_ISR_bass = total_ISR_bass/len(filenames)
mean_SIR_bass = total_SIR_bass/len(filenames)
mean_SAR_bass = total_SAR_bass/len(filenames)
mean_SDR_drum = total_SDR_drum/len(filenames)
mean_ISR_drum = total_ISR_bass/len(filenames)
mean_SIR_drum = total_SIR_drum/len(filenames)
mean_SAR_drum = total_SAR_drum/len(filenames)
mean_SDR_vocal = total_SDR_vocal/len(filenames)
mean_ISR_vocal = total_ISR_bass/len(filenames)
mean_SIR_vocal = total_SIR_vocal/len(filenames)
mean_SAR_vocal = total_SAR_vocal/len(filenames)
mean_SDR_other = total_SDR_other/len(filenames)
mean_ISR_other = total_ISR_bass/len(filenames)
mean_SIR_other = total_SIR_other/len(filenames)
mean_SAR_other = total_SAR_other/len(filenames)

11
torch.Size([1, 1103, 4803])


  bass_mag, vocals_mag, drums_mag,others_mag = net(torch.tensor(inp_n))


torch.Size([1, 1, 1103, 4803])
torch.Size([1103, 4803])
torch.Size([1, 1103, 4803])
torch.Size([1, 1, 1103, 4803])
torch.Size([1103, 4803])
torch.Size([1, 1103, 4803])
torch.Size([1, 1, 1103, 4803])
torch.Size([1103, 4803])
torch.Size([1, 1103, 4803])
torch.Size([1, 1, 1103, 4803])
torch.Size([1103, 4803])
torch.Size([1, 1103, 4803])
torch.Size([1, 1, 1103, 4803])
torch.Size([1103, 4803])
torch.Size([1, 1103, 4803])
torch.Size([1, 1, 1103, 4803])
torch.Size([1103, 4803])
torch.Size([1, 1103, 4803])
torch.Size([1, 1, 1103, 4803])
torch.Size([1103, 4803])
torch.Size([1, 1103, 4803])
torch.Size([1, 1, 1103, 4803])
torch.Size([1103, 4803])
torch.Size([1, 1103, 4803])
torch.Size([1, 1, 1103, 4803])
torch.Size([1103, 4803])
torch.Size([1, 1103, 4803])
torch.Size([1, 1, 1103, 4803])
torch.Size([1103, 4803])
torch.Size([1, 1103, 4803])
torch.Size([1, 1, 1103, 4803])
torch.Size([1103, 4803])


In [15]:
print(mean_SDR_bass)
print(mean_ISR_bass)
print(mean_SIR_bass)
print(mean_SAR_bass)

-1.6775937939922414
-1.0329210423220858
inf
-0.39153640360242664


In [16]:
print(mean_SDR_drum)
print(mean_ISR_drum)
print(mean_SIR_drum)
print(mean_SAR_drum)

0.06757230033464652
-1.0329210423220858
inf
-0.12958991127356365


In [17]:
print(mean_SDR_vocal)
print(mean_ISR_vocal)
print(mean_SIR_vocal)
print(mean_SAR_vocal)

-0.09318872925502573
-1.0329210423220858
inf
-2.0605958297324296


In [18]:
print(mean_SDR_other)
print(mean_ISR_other)
print(mean_SIR_other)
print(mean_SAR_other)

-0.0438717610673701
-1.0329210423220858
inf
-0.9514529531832366


In [20]:
net = net = ConvNet(inp_size)
    # net.load_state_dict(torch.load('Weights/Weights_200_3722932.6015625.pth')) #least score Weights so far
net.load_state_dict(torch.load('Model2_Weights_v2/Weights_12_141569328.0.pth'))
net.eval()
mixture, m_phase = load_mixture('./musdb18hq/test/Side Effects Project - Sing With Me/mixture', seq=1)
mean = torch.mean(torch.tensor(mixture))
std = torch.std(torch.tensor(mixture))
inp_n = (torch.tensor(mixture)-mean)/std
(time_len, mel_bins) = inp_n.shape
inp_n = inp_n.view(1,time_len, mel_bins)
bass_mag, vocals_mag, drums_mag,others_mag = net(torch.FloatTensor(inp_n))
mask_bass,mask_vocals,mask_drums,mask_others = TimeFreqMasking(bass_mag, vocals_mag, drums_mag, others_mag)
drums_mag=torch.FloatTensor(mixture)*np.squeeze(mask_drums)
vocals_mag=torch.FloatTensor(mixture)*np.squeeze(mask_vocals)
bass_mag=torch.FloatTensor(mixture)*np.squeeze(mask_bass)
others_mag=torch.FloatTensor(mixture)*np.squeeze(mask_others)
    
vocals = vocals_mag.detach().numpy() * m_phase
	#print(vocals.shape)
bass = bass_mag.detach().numpy()* m_phase
drums = drums_mag.detach().numpy()* m_phase
others = others_mag.detach().numpy()* m_phase

In [21]:
vocals_audio = librosa.istft(vocals, win_length=2204,hop_length=1103,window='hann',center='True')
bass_audio = librosa.istft(bass, win_length=2204,hop_length=1103,window='hann',center='True')
drums_audio = librosa.istft(drums, win_length=2204,hop_length=1103,window='hann',center='True')
others_audio = librosa.istft(others, win_length=2204,hop_length=1103,window='hann',center='True')

In [22]:
scipy.io.wavfile.write('./musdb18hq/model1_v2/vocal.wav', 44100, vocals_audio)
scipy.io.wavfile.write('./musdb18hq/model1_v2/bass.wav', 44100, bass_audio)
scipy.io.wavfile.write('./musdb18hq/model1_v2/drums.wav',  44100, drums_audio)
scipy.io.wavfile.write('./musdb18hq/model1_v2/other.wav',  44100, others_audio)