In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torch.utils.data as utils
import librosa
import soundfile as sf
import time
import os
from torch.utils import data
from wavenet import Wavenet
from transformData import x_mu_law_encode,y_mu_law_encode,mu_law_decode,onehot,cateToSignal
from readDataset import Dataset

In [2]:
# ---- audio / windowing ----
sample_rate = 16000            # sampling rate handed to sf.write for every exported wav
sampleSize = 32000             # number of target samples in one training / inference window
quantization_channels = 256    # passed to Wavenet; presumably the number of mu-law classes (confirm in transformData)
shapeoftest = 190500           # not referenced by any code visible in this notebook

# ---- network shape ----
filterSize = 3                 # not referenced by any code visible in this notebook
residualDim = 128              # passed to Wavenet as the residual channel count
skipDim = 512                  # passed to Wavenet as the skip channel count
# Five repetitions of the dilation ladder 1, 2, 4, ..., 256 (idea from the
# WaveNet paper: stacking the ladder enlarges the receptive field).
dilations = [2 ** exponent for exponent in range(9)] * 5
# Context samples sliced on each side of a window before it is fed to the model.
pad = np.sum(dilations)

# ---- checkpointing / logging ----
continueTrain = True           # try to resume from `resumefile` when it exists
resumefile = './model/testac'  # checkpoint path (read on resume, overwritten by train())
lossname = 'testacloss.txt'    # file name of the loss log written under ./lossRecord/
lossrecord = []                # per-step training losses accumulated by train()

    #            |----------------------------------------|     *residual*
    #            |                                        |
    #            |    |-- conv -- tanh --|                |
    # -> dilate -|----|                  * ----|-- 1x1 -- + -->	*input*
    #                 |-- conv -- sigm --|     |    ||
    #                                         1x1=residualDim
    #                                          |
    # ---------------------------------------> + ------------->	*skip=skipDim*
    Diagram adapted from https://github.com/vincentherrmann/pytorch-wavenet/blob/master/wavenet_model.py

In [3]:
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use specific GPU

In [4]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
# Seed every RNG this notebook draws from: torch (weight init, DataLoader
# shuffling) and NumPy (train()/val() pick their windows with np.random.shuffle).
# Seeding only torch left the window order different on every run.
torch.manual_seed(1)
np.random.seed(1)
device = torch.device("cuda" if use_cuda else "cpu")
# Extra DataLoader options that only make sense when batches are copied to a GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

In [5]:
# Loader options shared by both splits: one whole track per batch.
params = {'batch_size': 1, 'shuffle': True, 'num_workers': 1}
# Merge in the device-dependent options prepared in `kwargs` (pinned host
# memory when CUDA is available); they were computed earlier but never used.
params.update(kwargs)
# NOTE(review): judging by the names, the first list is the mixture input and
# the second the vocal target -- confirm against readDataset.Dataset.
training_set = Dataset(['origin_mix'], ['origin_vocal'], './vsCorpus/', './vsCorpus/')
# The test split passes 'pred_mix' for both lists; test() discards the second
# element of each item, so only the input is actually consumed.
testing_set = Dataset(['pred_mix'], ['pred_mix'], './vsCorpus/', './vsCorpus/')
loadtr = data.DataLoader(training_set, **params)
# Inference gains nothing from shuffling; a fixed order keeps it deterministic
# which track is processed (and written) last.
loadval = data.DataLoader(testing_set, **dict(params, shuffle=False))

In [6]:
# Build the network on the device selected earlier. The previous hard-coded
# .cuda() raised on CPU-only machines even though `device` already holds the
# CPU fallback.
model = Wavenet(pad, skipDim, quantization_channels, residualDim, dilations).to(device)
# Classification loss over the quantized sample classes; the WaveNet paper
# reports cross-entropy working far better than an MSE regression loss.
criterion = nn.CrossEntropyLoss()
# Adam with a small weight decay. SGD with momentum and StepLR/MultiStepLR
# schedules were tried previously and are intentionally not used.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

In [7]:
# Epoch to start counting from; stays 0 unless a checkpoint is restored below,
# so the name is always defined for later cells.
start_epoch = 0
if continueTrain:  # resume only when requested and the checkpoint file exists
    if os.path.isfile(resumefile):
        print("=> loading checkpoint '{}'".format(resumefile))
        # map_location lets a checkpoint saved on a GPU be restored on whatever
        # device this session uses (including CPU-only machines).
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # you created yourself or otherwise trust.
        checkpoint = torch.load(resumefile, map_location=device)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(resumefile, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(resumefile))

=> loading checkpoint './model/testac'
=> loaded checkpoint './model/testac' (epoch 65)


In [None]:
def val(xtrain, ytrain):
    """Score the model on the held-out tail of a track and export its prediction.

    The held-out region starts ``pad + 10 * sampleSize`` samples before the end
    of the track and stops ``pad + sampleSize`` samples before it; train() only
    uses targets that end before this region. With the notebook's settings
    (sampleSize=32000 at 16 kHz) that is nine windows, i.e. 18 seconds.

    Steps:
      1. Pick one window start at random (1000-sample grid) inside the region,
         run the model on it and print the accuracy and cross-entropy loss.
      2. Predict the region in consecutive ``sampleSize`` steps, mu-law decode
         the class indices and write them to ``./vsCorpus/notexval.wav``.

    Args:
        xtrain: model input, indexed here as ``xtrain[:, :, t]``.
        ytrain: target classes, indexed here as ``ytrain[:, t]``.

    Raises:
        ValueError: if the track is too short to hold the held-out region plus
            the ``pad`` context samples in front of it.

    Side effects: switches ``model`` to eval mode (and leaves it there), prints
    to stdout and overwrites the wav file above.
    """
    model.eval()
    start_time = time.time()

    total_len = xtrain.shape[-1]
    tail_start = total_len - pad - 10 * sampleSize
    tail_end = total_len - pad - sampleSize
    if tail_start - pad < 0:
        # Without this check the negative start would wrap around in the slices
        # below and silently yield empty or wrong windows.
        raise ValueError(
            'track of {} samples is too short for validation; need at least {}'
            .format(total_len, 2 * pad + 10 * sampleSize))

    with torch.no_grad():
        # --- single random window: accuracy + loss ---
        starts = np.arange(tail_start, tail_end, 1000)
        np.random.shuffle(starts)
        begin = starts[0]
        # The input carries `pad` extra context samples on both sides of the target.
        inputs = xtrain[:, :, begin - pad:pad + begin + sampleSize].to(device)
        target = ytrain[:, begin:begin + sampleSize].to(device)
        output = model(inputs)
        # Index of the highest-scoring class along dim 1 for every position.
        pred = output.max(1, keepdim=True)[1]
        correct = pred.eq(target.view_as(pred)).sum().item() / pred.shape[-1]
        val_loss = criterion(output, target).item()
        print(correct,'accurate')
        print('\nval set:loss{:.4f}:, ({:.3f} sec/step)\n'.format(val_loss,time.time()-start_time))

        # --- whole held-out region: export as audio ---
        listofpred = []
        for ind in range(tail_start, tail_end, sampleSize):
            output = model(xtrain[:, :, ind - pad:ind + sampleSize + pad].to(device))
            listofpred.append(output.max(1, keepdim=True)[1].cpu().numpy().reshape(-1))
        ans = mu_law_decode(np.concatenate(listofpred))
        sf.write('./vsCorpus/notexval.wav', ans, sample_rate)
        print('val stored done',time.time() - start_time)
        

def _separate_to_wav(track, wav_path):
    """Predict `track` window by window and write the decoded audio to `wav_path`.

    Windows start every ``sampleSize`` samples from ``pad`` up to
    ``track.shape[-1] - pad``; each model input carries ``pad`` context samples
    on both sides. Expects to be called with gradients disabled.
    """
    pieces = []
    for ind in range(pad, track.shape[-1] - pad, sampleSize):
        output = model(track[:, :, ind - pad:ind + sampleSize + pad].to(device))
        # Keep the index of the highest-scoring class for every position.
        pieces.append(output.max(1, keepdim=True)[1].cpu().numpy().reshape(-1))
    ans = mu_law_decode(np.concatenate(pieces))
    sf.write(wav_path, ans, sample_rate)


def test(xtrain):
    """Export full-length model predictions for the test tracks and `xtrain`.

    Every track yielded by ``loadval`` is predicted and written to
    ``./vsCorpus/notexte.wav``; afterwards the given training track is written
    to ``./vsCorpus/notextr.wav``.

    The training-track export does not depend on the test item, so it now runs
    once after the loop instead of once per test item.

    NOTE(review): all test tracks share one output path, so with more than one
    test track only the last one survives on disk.

    Args:
        xtrain: training input, indexed as ``xtrain[:, :, t]``.

    Side effects: switches ``model`` to eval mode (and leaves it there), prints
    the elapsed time and overwrites the two wav files above.
    """
    model.eval()
    start_time = time.time()
    with torch.no_grad():
        for xtest, _ in loadval:
            _separate_to_wav(xtest, './vsCorpus/notexte.wav')
        _separate_to_wav(xtrain, './vsCorpus/notextr.wav')
    print('test stored done',time.time() - start_time)
    
def train(epoch):
    """Train for one pass over ``loadtr``, then validate, export audio and checkpoint.

    For every track, window starts are taken on a 16000-sample grid from
    ``pad`` up to ``pad + 11 * sampleSize`` samples before the end (the tail is
    left to val()), shuffled, and each window is used for one optimizer step.
    Losses are appended to the global ``lossrecord`` and the full history is
    rewritten to ``./lossRecord/<lossname>`` every 100 steps.

    Args:
        epoch: zero-based epoch number; used in log lines and stored as
            ``epoch + 1`` in the checkpoint.

    Side effects: updates ``model`` and ``optimizer``, appends to
    ``lossrecord``, writes the loss log and the checkpoint ``resumefile``, and
    runs val() / test(), which write wav files.
    """
    for iloader,(xtrain,ytrain) in enumerate(loadtr):
        # val()/test() at the bottom of this loop switch the model to eval
        # mode, so training mode has to be re-enabled for every track rather
        # than once per epoch.
        model.train()
        idx = np.arange(pad,xtrain.shape[-1]-pad-11*sampleSize,16000)
        np.random.shuffle(idx)  # randomise the order of the window starts
        for i, ind in enumerate(idx):
            start_time = time.time()
            # Inputs carry `pad` context samples on both sides of the target span.
            inputs = xtrain[:,:,ind-pad:ind+sampleSize+pad].to(device)
            target = ytrain[:,ind:ind+sampleSize].to(device)
            output = model(inputs)
            loss = criterion(output, target)
            lossrecord.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss:{:.6f}: , ({:.3f} sec/step)'.format(
                    epoch, i, len(idx),100. * i / len(idx), loss.item(),time.time() - start_time))
            if i % 100 == 0:
                # Create the log directory on demand so a fresh checkout does
                # not fail on the very first step.
                os.makedirs("./lossRecord/", exist_ok=True)
                with open("./lossRecord/"+lossname, "w") as f:
                    f.writelines(str(s) +"\n" for s in lossrecord)
                print('write finish')

        # Checkpoint before the slow evaluation/export so an interruption or
        # failure there cannot lose this track's training. Evaluation runs
        # under no_grad and does not change the saved state.
        state={'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()}
        os.makedirs(os.path.dirname(resumefile) or '.', exist_ok=True)
        torch.save(state, resumefile)

        val(xtrain,ytrain)
        test(xtrain)

In [None]:
# Continue the epoch numbering from the restored checkpoint instead of always
# restarting at 0; otherwise the logs and the 'epoch' field of the next saved
# checkpoint are reset on every resume. Falls back to 0 when no checkpoint was
# loaded (the name is then not defined).
first_epoch = globals().get('start_epoch', 0)
for epoch in range(first_epoch, 100000):
    train(epoch)

write finish
write finish


write finish
0.02553125 accurate

val set:loss10.4109:, (0.664 sec/step)

val stored done 6.647839784622192
test stored done 82.91593790054321
write finish


write finish


write finish
0.03553125 accurate

val set:loss10.1373:, (0.664 sec/step)

val stored done 6.645713567733765
test stored done 82.85042715072632
write finish


write finish
write finish


0.02378125 accurate

val set:loss9.0158:, (0.664 sec/step)

val stored done 6.648181676864624
test stored done 82.97957849502563
write finish
write finish


write finish
0.03221875 accurate

val set:loss10.2008:, (0.664 sec/step)

val stored done 6.662619113922119
test stored done 82.98644018173218
write finish


write finish


write finish
0.03296875 accurate

val set:loss10.3034:, (0.664 sec/step)

val stored done 6.651562452316284
test stored done 83.00631356239319
write finish


write finish


write finish
0.018125 accurate

val set:loss10.0647:, (0.664 sec/step)

val stored done 6.658130884170532
test stored done 82.93777394294739
write finish
write finish


write finish
0.0258125 accurate

val set:loss9.6623:, (0.664 sec/step)

val stored done 6.649457693099976
test stored done 82.9147424697876
write finish


write finish


write finish
0.0331875 accurate

val set:loss10.2298:, (0.664 sec/step)

val stored done 6.650701999664307
test stored done 82.90164375305176
write finish


write finish


write finish
0.02253125 accurate

val set:loss9.1831:, (0.664 sec/step)

val stored done 6.656346321105957
test stored done 82.94410181045532
write finish


write finish
write finish
0.0198125 accurate

val set:loss9.5021:, (0.664 sec/step)

val stored done 6.651694297790527
test stored done 82.90150451660156
write finish


write finish


write finish
0.0211875 accurate

val set:loss10.8961:, (0.664 sec/step)

val stored done 6.654399156570435
test stored done 82.89040851593018
write finish


write finish


write finish
0.02796875 accurate

val set:loss9.3806:, (0.664 sec/step)

val stored done 6.646002531051636
test stored done 82.95368242263794
write finish


write finish


write finish
0.02421875 accurate

val set:loss9.2614:, (0.664 sec/step)

val stored done 6.647869110107422
test stored done 82.88250684738159
write finish
write finish


write finish
0.0354375 accurate

val set:loss10.0139:, (0.663 sec/step)

val stored done 6.6461021900177
test stored done 82.87916421890259
write finish


write finish


write finish
0.02775 accurate

val set:loss11.4885:, (0.664 sec/step)

val stored done 6.648071050643921
test stored done 82.90081977844238
write finish


write finish


write finish
0.0195625 accurate

val set:loss10.1798:, (0.664 sec/step)

val stored done 6.643070220947266
test stored done 82.8748950958252
write finish


write finish
write finish


0.0335625 accurate

val set:loss11.2096:, (0.664 sec/step)

val stored done 6.648542404174805
test stored done 82.93701839447021
write finish
write finish


write finish
0.03296875 accurate

val set:loss10.4576:, (0.664 sec/step)

val stored done 6.648010492324829
test stored done 83.0082426071167
write finish


write finish


write finish
0.02609375 accurate

val set:loss9.5832:, (0.664 sec/step)

val stored done 6.6463258266448975
test stored done 82.91061687469482
write finish


write finish


write finish
0.03171875 accurate

val set:loss10.8100:, (0.664 sec/step)

val stored done 6.651413202285767
test stored done 82.97158932685852
write finish
write finish


write finish
0.00875 accurate

val set:loss8.5945:, (0.664 sec/step)

val stored done 6.646071910858154
test stored done 82.90546441078186
write finish


write finish


write finish
0.03996875 accurate

val set:loss10.7782:, (0.663 sec/step)

val stored done 6.647452354431152
test stored done 82.86050701141357
write finish


write finish


write finish
0.02678125 accurate

val set:loss10.7947:, (0.664 sec/step)

val stored done 6.644447565078735
test stored done 82.95937895774841
write finish


write finish
write finish


0.01946875 accurate

val set:loss10.6392:, (0.663 sec/step)

val stored done 6.646300554275513
test stored done 82.95183873176575
write finish
write finish


write finish
0.032375 accurate

val set:loss11.2553:, (0.664 sec/step)

val stored done 6.646388530731201
test stored done 82.88687586784363
write finish


write finish


write finish
0.03871875 accurate

val set:loss11.5679:, (0.663 sec/step)

val stored done 6.643892288208008
test stored done 82.95013093948364
write finish


write finish


write finish
0.0086875 accurate

val set:loss9.2374:, (0.663 sec/step)

val stored done 6.646273612976074
test stored done 82.94399523735046
write finish


write finish
write finish
0.032875 accurate

val set:loss10.8874:, (0.664 sec/step)

val stored done 6.649601697921753
test stored done 82.9429349899292
write finish


write finish


write finish
0.0344375 accurate

val set:loss10.8640:, (0.664 sec/step)

val stored done 6.64429235458374
test stored done 82.89151430130005
write finish


write finish


write finish
0.02428125 accurate

val set:loss10.0023:, (0.664 sec/step)

val stored done 6.647886753082275
test stored done 82.86152935028076
write finish


write finish


write finish
0.00984375 accurate

val set:loss8.2276:, (0.664 sec/step)

val stored done 6.644959211349487
test stored done 82.94935274124146
write finish
write finish


write finish
0.04115625 accurate

val set:loss11.1091:, (0.665 sec/step)

val stored done 6.652665138244629
test stored done 83.16311430931091
write finish


write finish


write finish
0.02640625 accurate

val set:loss10.5269:, (0.664 sec/step)

val stored done 6.653456687927246
test stored done 83.0834367275238
write finish


write finish


write finish
0.0104375 accurate

val set:loss9.3666:, (0.664 sec/step)

val stored done 6.653703451156616
test stored done 83.05955696105957
write finish


write finish
write finish


0.02478125 accurate

val set:loss10.6765:, (0.664 sec/step)

val stored done 6.649251461029053
test stored done 82.97306108474731
write finish
write finish


write finish
0.032625 accurate

val set:loss11.0769:, (0.664 sec/step)

val stored done 6.650378465652466
test stored done 82.96651077270508
write finish


write finish


write finish
0.0101875 accurate

val set:loss8.4150:, (0.664 sec/step)

val stored done 6.65460467338562
test stored done 83.06132936477661
write finish


write finish


write finish
0.0101875 accurate

val set:loss8.8060:, (0.664 sec/step)

val stored done 6.651818037033081
test stored done 83.15348863601685
write finish
write finish


write finish
0.02940625 accurate

val set:loss10.6623:, (0.664 sec/step)

val stored done 6.64756178855896
test stored done 83.00601601600647
write finish


write finish


write finish
0.0351875 accurate

val set:loss10.9535:, (0.664 sec/step)

val stored done 6.646124362945557
test stored done 82.89856743812561
write finish


write finish


write finish
0.00790625 accurate

val set:loss9.6227:, (0.665 sec/step)

val stored done 6.65597939491272
test stored done 83.04560470581055
write finish


write finish
write finish


0.02728125 accurate

val set:loss11.4747:, (0.666 sec/step)

val stored done 6.655364751815796
test stored done 83.00734853744507
write finish
write finish


write finish
0.02128125 accurate

val set:loss9.5113:, (0.664 sec/step)

val stored done 6.653069734573364
test stored done 82.99948573112488
write finish


write finish


write finish
0.03534375 accurate

val set:loss12.0981:, (0.665 sec/step)

val stored done 6.651838779449463
test stored done 83.06128191947937
write finish


write finish


write finish
0.0073125 accurate

val set:loss9.9254:, (0.665 sec/step)

val stored done 6.65830397605896
test stored done 83.15967893600464
write finish


write finish
write finish
0.00846875 accurate

val set:loss9.0288:, (0.664 sec/step)

val stored done 6.65102481842041
test stored done 82.94064402580261
write finish


write finish


write finish
0.0321875 accurate

val set:loss11.5043:, (0.664 sec/step)

val stored done 6.650300025939941
test stored done 82.94720029830933
write finish


write finish


write finish
0.0388125 accurate

val set:loss12.3408:, (0.664 sec/step)

val stored done 6.651224136352539
test stored done 82.95489168167114
write finish


write finish


write finish
0.00996875 accurate

val set:loss9.2795:, (0.665 sec/step)

val stored done 6.6535303592681885
test stored done 82.97800922393799
write finish
write finish


write finish
0.0230625 accurate

val set:loss10.2802:, (0.664 sec/step)

val stored done 6.650042772293091
test stored done 83.07275819778442
write finish


write finish


write finish
0.02684375 accurate

val set:loss11.7634:, (0.664 sec/step)

val stored done 6.6469244956970215
test stored done 82.98856353759766
write finish


write finish


write finish
0.033375 accurate

val set:loss11.3331:, (0.664 sec/step)

val stored done 6.646295070648193
test stored done 83.00599503517151
write finish


write finish
write finish


0.03465625 accurate

val set:loss11.9586:, (0.664 sec/step)

val stored done 6.6470232009887695
test stored done 82.89468288421631
write finish
write finish


write finish
0.02590625 accurate

val set:loss11.7934:, (0.664 sec/step)

val stored done 6.647669792175293
test stored done 82.89752984046936
write finish


write finish


write finish
0.02496875 accurate

val set:loss10.5100:, (0.664 sec/step)

val stored done 6.646557331085205
test stored done 82.91553401947021
write finish


write finish


write finish
0.01125 accurate

val set:loss8.7590:, (0.664 sec/step)

val stored done 6.648185729980469
test stored done 83.05592060089111
write finish
write finish


write finish
0.0284375 accurate

val set:loss10.7003:, (0.664 sec/step)

val stored done 6.647562265396118
test stored done 83.0210554599762
write finish


write finish


write finish
0.03659375 accurate

val set:loss12.6518:, (0.665 sec/step)

val stored done 6.650180816650391
test stored done 83.22836327552795
write finish


write finish


write finish
0.03428125 accurate

val set:loss11.7588:, (0.664 sec/step)

val stored done 6.648432493209839
test stored done 82.97642397880554
write finish


write finish
write finish


0.02625 accurate

val set:loss11.5391:, (0.664 sec/step)

val stored done 6.647578954696655
test stored done 82.92369484901428
write finish
write finish


write finish
0.03171875 accurate

val set:loss12.6186:, (0.664 sec/step)

val stored done 6.648333549499512
test stored done 82.94824194908142
write finish


write finish


write finish
0.0368125 accurate

val set:loss12.8142:, (0.664 sec/step)

val stored done 6.647141218185425
test stored done 83.04736256599426
write finish


write finish


write finish
0.02725 accurate

val set:loss12.6034:, (0.664 sec/step)

val stored done 6.648632764816284
test stored done 83.02273559570312
write finish


write finish
write finish
0.0375625 accurate

val set:loss12.7437:, (0.664 sec/step)

val stored done 6.645395755767822
test stored done 82.98952054977417
write finish


write finish


write finish
0.022375 accurate

val set:loss10.2007:, (0.664 sec/step)

val stored done 6.647779226303101
test stored done 83.0362617969513
write finish


write finish


write finish
0.04034375 accurate

val set:loss12.5155:, (0.665 sec/step)

val stored done 6.65143609046936
test stored done 83.03542399406433
write finish


write finish


write finish
0.02653125 accurate

val set:loss11.7520:, (0.664 sec/step)

val stored done 6.647926568984985
test stored done 82.90596890449524
write finish
write finish


write finish
0.0275625 accurate

val set:loss11.9856:, (0.664 sec/step)

val stored done 6.649839401245117
test stored done 83.06209945678711
write finish


write finish


write finish
0.02359375 accurate

val set:loss10.4095:, (0.664 sec/step)

val stored done 6.650144577026367
test stored done 83.03827333450317
write finish


write finish


write finish
0.02884375 accurate

val set:loss10.5286:, (0.664 sec/step)

val stored done 6.648297548294067
test stored done 83.20613360404968
write finish


write finish
write finish


0.0335625 accurate

val set:loss11.2684:, (0.664 sec/step)

val stored done 6.64932107925415
test stored done 83.0020399093628
write finish
write finish


write finish
0.03590625 accurate

val set:loss11.8261:, (0.664 sec/step)

val stored done 6.648693561553955
test stored done 83.06917309761047
write finish


write finish


write finish
0.0088125 accurate

val set:loss9.7661:, (0.665 sec/step)

val stored done 6.648249864578247
test stored done 82.91818451881409
write finish


write finish


write finish
0.03646875 accurate

val set:loss11.7949:, (0.664 sec/step)

val stored done 6.65378999710083
test stored done 83.24607181549072
write finish
write finish


write finish
0.022625 accurate

val set:loss11.7792:, (0.664 sec/step)

val stored done 6.64498496055603
test stored done 82.95652532577515
write finish


write finish


write finish
0.03125 accurate

val set:loss12.0767:, (0.664 sec/step)

val stored done 6.648767948150635
test stored done 83.05948042869568
write finish


write finish


write finish
0.03234375 accurate

val set:loss11.3159:, (0.664 sec/step)

val stored done 6.646686792373657
test stored done 83.05739569664001
write finish


write finish
write finish


0.0259375 accurate

val set:loss12.0235:, (0.664 sec/step)

val stored done 6.649054527282715
test stored done 82.93063521385193
write finish
write finish


write finish
0.01725 accurate

val set:loss10.6444:, (0.664 sec/step)

val stored done 6.647311449050903
test stored done 83.02737784385681
write finish


write finish


write finish
0.0384375 accurate

val set:loss12.0170:, (0.664 sec/step)

val stored done 6.651576519012451
test stored done 83.05910444259644
write finish


write finish


write finish
0.0320625 accurate

val set:loss12.7692:, (0.664 sec/step)

val stored done 6.645392656326294
test stored done 83.00469088554382
write finish


write finish
write finish
0.00875 accurate

val set:loss9.8711:, (0.664 sec/step)

val stored done 6.649216651916504
test stored done 83.04894948005676
write finish


write finish


write finish
0.0094375 accurate

val set:loss10.3364:, (0.664 sec/step)

val stored done 6.6473610401153564
test stored done 83.00356602668762
write finish


write finish


write finish
0.0368125 accurate

val set:loss12.3935:, (0.664 sec/step)

val stored done 6.6530468463897705
test stored done 83.0467677116394
write finish


write finish


write finish
0.031625 accurate

val set:loss12.3265:, (0.664 sec/step)

val stored done 6.64919638633728
test stored done 83.01692938804626
write finish
write finish


write finish
0.03578125 accurate

val set:loss12.3169:, (0.665 sec/step)

val stored done 6.654561519622803
test stored done 83.13341450691223
write finish


write finish


write finish
0.03553125 accurate

val set:loss11.2767:, (0.664 sec/step)

val stored done 6.64738130569458
test stored done 82.97796750068665
write finish


write finish


write finish
0.008875 accurate

val set:loss10.5350:, (0.665 sec/step)

val stored done 6.651462078094482
test stored done 83.05481553077698
