In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torch.utils.data as utils
import librosa
import soundfile as sf
import time
import os
from torch.utils import data
from wavenet import Wavenet
from transformData import x_mu_law_encode,y_mu_law_encode,mu_law_decode,onehot,cateToSignal
from readDataset import Dataset

In [2]:
# ---- Run configuration -------------------------------------------------------
sampleSize=32000  # number of audio samples in each window sliced out for the model
quantization_channels=256  # passed to Wavenet; presumably the number of mu-law classes -- TODO confirm
sample_rate=16000  # sample rate handed to soundfile when the predicted wav files are written
dilations=[2**i for i in range(9)]*5  # 1,2,...,256 repeated 5 times (idea from WaveNet: larger receptive field)
residualDim=128  # passed to Wavenet (residual channels in the diagram below)
skipDim=512  # passed to Wavenet (skip channels in the diagram below)
shapeoftest = 190500  # NOTE(review): not referenced by any visible cell
filterSize=3  # NOTE(review): not referenced by any visible cell
resumefile='testac' # checkpoint file name; train() saves it under ./model/
lossname='testacloss.txt' # loss log file name; train() writes it under ./lossRecord/
continueTrain=False # when True, the resume cell tries to load the checkpoint
lossrecord=[]  # per-step training losses, appended to by train()
# Extra context (in samples) taken on each side of every window. It is 0, so
# windows are exactly `sampleSize` long. The earlier `pad = np.sum(dilations)`
# assignment was immediately overwritten by this value and was dead code; restore
# it here only if the model is meant to receive un-padded context -- TODO confirm.
pad=0

    #            |----------------------------------------|     *residual*
    #            |                                        |
    #            |    |-- conv -- tanh --|                |
    # -> dilate -|----|                  * ----|-- 1x1 -- + -->	*input*
    #                 |-- conv -- sigm --|     |    ||
    #                                         1x1=residualDim
    #                                          |
    # ---------------------------------------> + ------------->	*skip=skipDim*
    Diagram adapted from https://github.com/vincentherrmann/pytorch-wavenet/blob/master/wavenet_model.py

In [3]:
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use specific GPU

In [4]:
# Seed torch's random number generator.
torch.manual_seed(1)

# Pick the compute device: the GPU when torch can see one, the CPU otherwise.
# (To force CPU execution, assign torch.device("cpu") to `device` instead.)
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    # Extra loader options prepared for GPU runs.
    kwargs = dict(num_workers=1, pin_memory=True)
else:
    device = torch.device("cpu")
    kwargs = dict()

In [5]:
# Options shared by both DataLoaders: one item per batch, shuffled, one worker.
params = dict(batch_size=1, shuffle=True, num_workers=1)

# Training data: 'origin_mix' paired with 'origin_vocal', both rooted at ./vsCorpus/.
# (Presumably the arguments are input folders, target folders, input root,
# target root -- confirm against readDataset.Dataset.)
training_set = Dataset(['origin_mix'], ['origin_vocal'], './vsCorpus/', './vsCorpus/')
# Test data: 'pred_mix' is given for both folder arguments.
testing_set = Dataset(['pred_mix'], ['pred_mix'], './vsCorpus/', './vsCorpus/')

# Wrap both datasets in PyTorch DataLoaders.
loadtr = data.DataLoader(training_set, **params)
loadval = data.DataLoader(testing_set, **params)

In [6]:
# Build the network and move it to the device chosen in the device cell.
# `.to(device)` (rather than a hard-coded `.cuda()`) keeps the notebook runnable
# on machines without a GPU, matching the CPU fallback used to define `device`.
model = Wavenet(pad,skipDim,quantization_channels,residualDim,dilations).to(device)

# Classification loss over the model's output classes.
# (The WaveNet paper reports cross-entropy working far better than MSE.)
criterion = nn.CrossEntropyLoss()

# Adam optimiser with a small weight decay.
optimizer = optim.Adam(model.parameters(), lr=1e-3,weight_decay=1e-5)
# Alternatives that were tried and left disabled:
#optimizer = optim.SGD(model.parameters(), lr = 0.1, momentum=0.9, weight_decay=1e-5)
#scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
#scheduler = MultiStepLR(optimizer, milestones=[20,40], gamma=0.1)

In [7]:
# Epoch index to (re)start from; stays 0 unless a checkpoint is loaded below.
start_epoch = 0
if continueTrain:  # only look for a checkpoint when resuming was requested
    # train() writes checkpoints to './model/' + resumefile, so look there first.
    # The bare `resumefile` path is kept as a fallback for checkpoints that were
    # placed next to the notebook.
    checkpoint_candidates = ['./model/' + resumefile, resumefile]
    checkpoint_path = next((p for p in checkpoint_candidates if os.path.isfile(p)), None)
    if checkpoint_path is not None:
        print("=> loading checkpoint '{}'".format(checkpoint_path))
        # map_location lets a checkpoint saved on a GPU be restored on whatever
        # device this session uses (including CPU-only machines).
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(checkpoint_path, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_candidates[0]))

In [None]:
def val(xtrain,ytrain):
    """Validate on the tail of the training audio and save a reconstruction of it.

    The validation region spans from ``10*sampleSize`` to ``sampleSize`` samples
    (plus ``pad``) before the end of ``xtrain``. One window with a randomly
    chosen start inside that region is scored (accuracy and loss are printed);
    then the region is predicted in consecutive ``sampleSize`` windows, decoded
    with ``mu_law_decode`` and written to ``./vsCorpus/notexval.wav``.

    Args:
        xtrain: model input, sliced along its last axis with three indices.
        ytrain: target classes, sliced along its last axis with two indices.
    """
    model.eval()
    tic = time.time()
    total = xtrain.shape[-1]
    region_start = total - pad - 10 * sampleSize
    region_stop = total - pad - sampleSize
    with torch.no_grad():
        # Candidate window starts lie 1000 samples apart; shuffle, take the first.
        starts = np.arange(region_start, region_stop, 1000)
        np.random.shuffle(starts)
        begin = starts[0]
        window = xtrain[:, :, begin - pad:pad + begin + sampleSize].to(device)
        target = ytrain[:, begin:begin + sampleSize].to(device)

        output = model(window)
        # Index of the largest output along dim 1 is the predicted class.
        _, pred = output.max(1, keepdim=True)
        hits = pred.eq(target.view_as(pred)).sum().item()
        correct = hits / pred.shape[-1]
        val_loss = criterion(output, target).item()
        print(correct,'accurate')
        print('\nval set:loss{:.4f}:, ({:.3f} sec/step)\n'.format(val_loss,time.time()-tic))

        # Predict the whole validation region window by window.
        chunks = []
        for ind in range(region_start, region_stop, sampleSize):
            segment = xtrain[:, :, ind - pad:ind + sampleSize + pad].to(device)
            _, labels = model(segment).max(1, keepdim=True)
            chunks.append(labels.cpu().numpy().reshape(-1))
        ans = mu_law_decode(np.concatenate(chunks))
        sf.write('./vsCorpus/notexval.wav', ans, sample_rate)
        print('val stored done',time.time() - tic)
        

def test(xtrain):
    """Predict the test audio and the full training audio and store both as wav.

    For every item yielded by ``loadval`` the item is predicted in consecutive
    ``sampleSize`` windows and written to ``./vsCorpus/notexte.wav``; then
    ``xtrain`` is predicted the same way and written to
    ``./vsCorpus/notextr.wav``.

    NOTE(review): both file names are fixed, so with more than one test item each
    iteration overwrites the previous files and re-predicts ``xtrain``.

    Args:
        xtrain: training input, sliced along its last axis.
    """
    model.eval()
    tic = time.time()

    def reconstruct(signal):
        """Predict `signal` window by window and return the decoded result."""
        pieces = []
        for ind in range(pad, signal.shape[-1] - pad, sampleSize):
            window = signal[:, :, ind - pad:ind + sampleSize + pad].to(device)
            # Index of the largest output along dim 1 is the predicted class.
            _, labels = model(window).max(1, keepdim=True)
            pieces.append(labels.cpu().numpy().reshape(-1))
        return mu_law_decode(np.concatenate(pieces))

    with torch.no_grad():
        for xtest, _ in loadval:
            sf.write('./vsCorpus/notexte.wav', reconstruct(xtest), sample_rate)
            sf.write('./vsCorpus/notextr.wav', reconstruct(xtrain), sample_rate)
            print('test stored done',time.time() - tic)
    
def train(epoch):
    """Run one pass over every training file, then validate, test and checkpoint.

    For each item from ``loadtr`` the window starts are taken every 16000 samples
    from ``pad`` up to ``11*sampleSize + pad`` samples before the end (the tail is
    left for ``val``), shuffled, and one optimiser step is taken per window. The
    accumulated losses are rewritten to ``./lossRecord/<lossname>`` every 100
    steps. After each file ``val`` and ``test`` are run and a checkpoint is
    saved to ``./model/<resumefile>``.

    Args:
        epoch: current epoch index; printed in the progress line and stored
            (as ``epoch + 1``) in the checkpoint.
    """
    # Create the output folders up front. Without this a missing './model/'
    # only fails at torch.save, after a whole pass of training has been spent.
    for out_dir in ('./lossRecord/', './model/'):
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    for iloader,(xtrain,ytrain) in enumerate(loadtr):
        # val()/test() below switch the model to eval mode, so training mode has
        # to be re-enabled for every file, not just once per epoch.
        model.train()
        idx = np.arange(pad,xtrain.shape[-1]-pad-11*sampleSize,16000)
        np.random.shuffle(idx)  # visit the window starts in random order
        for i, ind in enumerate(idx):
            start_time = time.time()
            batch = xtrain[:,:,ind-pad:ind+sampleSize+pad].to(device)
            target = ytrain[:,ind:ind+sampleSize].to(device)
            output = model(batch)
            loss = criterion(output, target)
            lossrecord.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss:{:.6f}: , ({:.3f} sec/step)'.format(
                    epoch, i, len(idx),100. * i / len(idx), loss.item(),time.time() - start_time))
            if i % 100 == 0:
                # Rewrite the complete loss history, one value per line.
                with open("./lossRecord/"+lossname, "w") as f:
                    for s in lossrecord:
                        f.write(str(s) +"\n")
                print('write finish')

        val(xtrain,ytrain)
        test(xtrain)
        state={'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()}
        torch.save(state, './model/'+resumefile)

In [None]:
# When resuming, continue the epoch counter from the value restored by the
# checkpoint cell instead of restarting at 0, so the progress output and the
# 'epoch' field of newly saved checkpoints stay consistent. `start_epoch` only
# exists if a checkpoint was actually loaded, hence the guarded lookup.
first_epoch = globals().get('start_epoch', 0) if continueTrain else 0
for epoch in range(first_epoch, 100000):
    train(epoch)

write finish
write finish


write finish
0.0201875 accurate

val set:loss4.6670:, (0.590 sec/step)

val stored done 5.926646709442139
test stored done 74.01042699813843
write finish


write finish


write finish
0.0248125 accurate

val set:loss4.4408:, (0.588 sec/step)

val stored done 5.8873982429504395
test stored done 74.41709685325623
write finish


write finish
write finish


0.0366875 accurate

val set:loss4.2518:, (0.584 sec/step)

val stored done 5.865708589553833
test stored done 74.16592073440552
write finish
write finish


write finish
0.0325625 accurate

val set:loss4.2051:, (0.588 sec/step)

val stored done 5.8973047733306885
test stored done 74.79969644546509
write finish


write finish


write finish
0.031375 accurate

val set:loss4.3368:, (0.586 sec/step)

val stored done 5.886323690414429
test stored done 74.31520056724548
write finish


write finish


write finish
0.0345625 accurate

val set:loss4.2316:, (0.588 sec/step)

val stored done 5.892812490463257
test stored done 75.00507473945618
write finish
write finish


write finish
0.028 accurate

val set:loss4.4489:, (0.736 sec/step)

val stored done 7.188734531402588
test stored done 91.42265200614929
write finish


write finish


write finish
0.02815625 accurate

val set:loss4.4423:, (0.738 sec/step)

val stored done 7.301671743392944
test stored done 91.71854424476624
write finish


write finish


write finish
0.03815625 accurate

val set:loss4.1484:, (0.586 sec/step)

val stored done 5.888523101806641
test stored done 74.33047866821289
write finish


write finish
write finish
0.02490625 accurate

val set:loss4.5893:, (0.587 sec/step)

val stored done 5.887896537780762
test stored done 74.27925896644592
write finish


write finish


write finish
0.03334375 accurate

val set:loss4.4894:, (0.585 sec/step)

val stored done 5.868078708648682
test stored done 73.96073961257935
write finish


write finish


write finish
0.02821875 accurate

val set:loss4.3575:, (0.587 sec/step)

val stored done 5.883631229400635
test stored done 74.97752141952515
write finish


write finish


write finish
0.02465625 accurate

val set:loss4.9410:, (0.723 sec/step)

val stored done 7.283166170120239
test stored done 91.17386245727539
write finish
write finish


write finish
0.03634375 accurate

val set:loss4.5321:, (0.739 sec/step)

val stored done 7.38289999961853
test stored done 91.42072033882141
write finish


write finish


write finish
0.0068125 accurate

val set:loss4.8748:, (0.630 sec/step)

val stored done 7.172504186630249
test stored done 91.58436346054077
write finish


write finish


write finish
0.03975 accurate

val set:loss4.6479:, (0.726 sec/step)

val stored done 7.258971214294434
test stored done 91.16660904884338
write finish


write finish
write finish


0.01659375 accurate

val set:loss4.7410:, (0.727 sec/step)

val stored done 7.286979675292969
test stored done 91.45322465896606
write finish
write finish


write finish
0.0065 accurate

val set:loss5.0808:, (0.731 sec/step)

val stored done 7.28698468208313
test stored done 91.15280628204346
write finish


write finish


write finish
0.0078125 accurate

val set:loss5.0112:, (0.720 sec/step)

val stored done 7.2551047801971436
test stored done 91.3019003868103
write finish


write finish


write finish
0.03521875 accurate

val set:loss5.1914:, (0.729 sec/step)

val stored done 7.16628623008728
test stored done 91.91409826278687
write finish
write finish


write finish
0.03521875 accurate

val set:loss5.1014:, (0.726 sec/step)

val stored done 7.285945892333984
test stored done 91.11758303642273
write finish


write finish


write finish
0.02759375 accurate

val set:loss5.3451:, (0.728 sec/step)

val stored done 7.26424765586853
test stored done 76.12918210029602
write finish


write finish


write finish
0.02453125 accurate

val set:loss5.8660:, (0.587 sec/step)

val stored done 5.8901026248931885
test stored done 74.46688580513
write finish


write finish
write finish


0.0400625 accurate

val set:loss5.7774:, (0.589 sec/step)

val stored done 5.901871204376221
test stored done 74.32988667488098
write finish
write finish


write finish
0.03346875 accurate

val set:loss6.1093:, (0.589 sec/step)

val stored done 5.909090518951416
test stored done 75.08097267150879
write finish


write finish


write finish
0.02225 accurate

val set:loss5.7597:, (0.590 sec/step)

val stored done 5.914597034454346
test stored done 74.6892740726471
write finish


write finish


write finish
0.0240625 accurate

val set:loss6.0204:, (0.590 sec/step)

val stored done 5.932009935379028
test stored done 74.57061886787415
write finish


write finish
write finish
0.02259375 accurate

val set:loss6.4602:, (0.590 sec/step)

val stored done 5.920278072357178
test stored done 75.07645440101624
write finish


write finish


write finish
0.0335 accurate

val set:loss6.3502:, (0.588 sec/step)

val stored done 5.889867305755615
test stored done 74.81329274177551
write finish


write finish


write finish
0.009875 accurate

val set:loss5.6119:, (0.587 sec/step)

val stored done 5.8898844718933105
test stored done 74.22093439102173
write finish


write finish


write finish
0.03784375 accurate

val set:loss6.6984:, (0.589 sec/step)

val stored done 5.910546541213989
test stored done 74.89516997337341
write finish
write finish


write finish
0.01190625 accurate

val set:loss6.0836:, (0.587 sec/step)

val stored done 5.88667368888855
test stored done 74.3187038898468
write finish


write finish


write finish
0.02759375 accurate

val set:loss7.2745:, (0.590 sec/step)

val stored done 5.915464639663696
test stored done 74.21878409385681
write finish


write finish


write finish
0.0339375 accurate

val set:loss7.3304:, (0.591 sec/step)

val stored done 5.92846941947937
test stored done 74.57233929634094
write finish


write finish
write finish


0.03540625 accurate

val set:loss7.1597:, (0.587 sec/step)

val stored done 5.890265703201294
test stored done 74.66832733154297
write finish
write finish


write finish
0.03528125 accurate

val set:loss7.1343:, (0.589 sec/step)

val stored done 5.903706789016724
test stored done 74.09449291229248
write finish


write finish


write finish
0.0389375 accurate

val set:loss7.3116:, (0.588 sec/step)

val stored done 5.9065845012664795
test stored done 74.5035891532898
write finish


write finish


write finish
0.032875 accurate

val set:loss7.5327:, (0.587 sec/step)

val stored done 5.8890464305877686
test stored done 74.41300439834595
write finish
write finish


write finish
0.01365625 accurate

val set:loss6.5445:, (0.588 sec/step)

val stored done 5.891075611114502
test stored done 75.18269491195679
write finish


write finish


write finish
0.0144375 accurate

val set:loss7.0244:, (0.587 sec/step)

val stored done 5.880619287490845
test stored done 73.99495434761047
write finish


write finish


write finish
0.0301875 accurate

val set:loss8.2331:, (0.588 sec/step)

val stored done 5.899067401885986
test stored done 74.41257405281067
write finish


write finish
write finish


0.03584375 accurate

val set:loss8.1588:, (0.590 sec/step)

val stored done 5.916539669036865
test stored done 74.5551815032959
write finish
write finish


write finish
0.03075 accurate

val set:loss8.0317:, (0.587 sec/step)

val stored done 5.890748977661133
test stored done 74.62664866447449
write finish


write finish


write finish
0.03621875 accurate

val set:loss8.5739:, (0.588 sec/step)

val stored done 5.896947383880615
test stored done 74.11543607711792
write finish


write finish


write finish
0.02934375 accurate

val set:loss9.0843:, (0.588 sec/step)

val stored done 5.90200662612915
test stored done 74.22055768966675
write finish


write finish
write finish
0.03328125 accurate

val set:loss8.6480:, (0.588 sec/step)

val stored done 5.896418571472168
test stored done 74.7164146900177
write finish


write finish


write finish
0.0278125 accurate

val set:loss9.2776:, (0.590 sec/step)

val stored done 5.917497873306274
test stored done 74.5824363231659
write finish


write finish


write finish
0.03253125 accurate

val set:loss8.5840:, (0.587 sec/step)

val stored done 5.885419607162476
test stored done 74.6684684753418
write finish


write finish


write finish
0.04225 accurate

val set:loss8.2571:, (0.590 sec/step)

val stored done 5.923754453659058
test stored done 74.54901123046875
write finish
write finish


write finish
0.02565625 accurate

val set:loss8.1895:, (0.588 sec/step)

val stored done 5.901029109954834
test stored done 74.72187876701355
write finish


write finish


write finish
0.0336875 accurate

val set:loss8.4538:, (0.589 sec/step)

val stored done 5.907353401184082
test stored done 74.87490224838257
write finish


write finish


write finish
0.03196875 accurate

val set:loss9.0451:, (0.588 sec/step)

val stored done 5.89915919303894
test stored done 74.67136406898499
write finish


write finish
write finish


0.04384375 accurate

val set:loss8.2555:, (0.588 sec/step)

val stored done 5.901579141616821
test stored done 74.16980743408203
write finish
write finish


write finish
0.03371875 accurate

val set:loss9.0265:, (0.588 sec/step)

val stored done 5.907092571258545
test stored done 74.41222810745239
write finish
