In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torch.utils.data as utils
import librosa
import soundfile as sf
import time
import os
from torch.utils import data
from wavenet import Wavenet
from transformData import x_mu_law_encode,y_mu_law_encode,mu_law_decode,onehot,cateToSignal
from readDataset import Dataset

In [2]:
# ---- Audio / windowing ----------------------------------------------------
sample_rate = 16000             # sample rate handed to sf.write when audio is saved
quantization_channels = 256     # passed to Wavenet as the number of quantization channels
sampleSize = 32000              # number of samples in each window fed to the model
shapeoftest = 190500            # not referenced anywhere else in this notebook
songnum = 10                    # number of songs indexed when building the datasets

# ---- Network shape --------------------------------------------------------
# Five repetitions of the dilations 1, 2, 4, ..., 256 (idea from WaveNet:
# stacking dilated layers enlarges the receptive field).
dilations = [2 ** exponent for exponent in range(9)] * 5
residualDim = 128               # passed to Wavenet
skipDim = 512                   # passed to Wavenet
filterSize = 3                  # not referenced anywhere else in this notebook

# Extra context taken on each side of a window: the sum of all dilations.
pad = np.sum(dilations)
#pad=0

# ---- Files ----------------------------------------------------------------
savemusic = './vsCorpus/notextr{}.wav'   # output pattern; {} receives the song index
resumefile = './model/testac'            # checkpoint path
lossname = 'testacloss.txt'              # loss log file name
continueTrain = False                    # True -> try to resume from `resumefile`

# ---- Run-time state -------------------------------------------------------
lossrecord = []                 # every training loss value, appended step by step

    #            |----------------------------------------|     *residual*
    #            |                                        |
    #            |    |-- conv -- tanh --|                |
    # -> dilate -|----|                  * ----|-- 1x1 -- + -->	*input*
    #                 |-- conv -- sigm --|     |    ||
    #                                         1x1=residualDim
    #                                          |
    # ---------------------------------------> + ------------->	*skip=skipDim*
    Diagram adapted from https://github.com/vincentherrmann/pytorch-wavenet/blob/master/wavenet_model.py

In [3]:
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # use specific GPU

In [4]:
# Pick the compute device once; later cells move tensors with `.to(device)`.
use_cuda = torch.cuda.is_available()
torch.manual_seed(1)  # fixed seed for torch's random number generator

if use_cuda:
    device = torch.device("cuda")
    # Loader options intended for GPU runs.
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

# Alternatives left disabled by the original author:
#device = 'cpu'
#torch.set_default_tensor_type('torch.cuda.FloatTensor')

In [5]:
# One song per batch, visited in random order, loaded by a single worker.
params = dict(batch_size=1, shuffle=True, num_workers=1)

# NOTE(review): the training and validation datasets are built from the same
# song indices and the same directories, so `loadval` does not hold out any
# data -- confirm whether that is intended.
training_set = Dataset(
    np.arange(songnum),
    np.arange(songnum),
    'ccmixter2/x/',
    'ccmixter2/y/',
)
validation_set = Dataset(
    np.arange(songnum),
    np.arange(songnum),
    'ccmixter2/x/',
    'ccmixter2/y/',
)

# PyTorch DataLoaders wrapping the two datasets.
loadtr = data.DataLoader(training_set, **params)
loadval = data.DataLoader(validation_set, **params)

In [6]:
# Build the network on the device chosen in the device-setup cell. The original
# called `.cuda()` unconditionally, which raises on a machine without a GPU even
# though `device` already falls back to the CPU.
model = Wavenet(pad,skipDim,quantization_channels,residualDim,dilations).to(device)

# Classification loss over the quantized sample values; per the original
# author's note, the WaveNet paper reports cross-entropy working far better
# than MSELoss.
criterion = nn.CrossEntropyLoss()

# Adam with a small weight decay is the optimizer actually used for training.
optimizer = optim.Adam(model.parameters(), lr=1e-3,weight_decay=1e-5)

# Alternatives left disabled by the original author:
#optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
#optimizer = optim.SGD(model.parameters(), lr = 0.1, momentum=0.9, weight_decay=1e-5)
#scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
#scheduler = MultiStepLR(optimizer, milestones=[20,40], gamma=0.1)

In [7]:
# Epoch number stored in the checkpoint; defaults to 0 so the name is always
# defined, whether or not a checkpoint is loaded below.
start_epoch = 0

if continueTrain:# if continueTrain, the program will look for the checkpoint
    if os.path.isfile(resumefile):
        print("=> loading checkpoint '{}'".format(resumefile))
        # map_location lets a checkpoint saved on a GPU be restored on whatever
        # device this session uses (including CPU-only machines).
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        checkpoint = torch.load(resumefile, map_location=device)
        start_epoch = checkpoint['epoch']
        #best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(resumefile, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(resumefile))

In [None]:
# Disabled validation routine. The whole function is wrapped in a bare
# triple-quoted string, so it is never defined or executed; the matching call
# in train() is commented out as well.
'''def val(xtrain,ytrain): #validation last 15 seconds of the audio.
    model.eval()
    start_time = time.time()
    with torch.no_grad():
        idx = np.arange(xtrain.shape[-1]-pad-10*sampleSize,xtrain.shape[-1]-pad-sampleSize,1000)
        np.random.shuffle(idx)
        data = xtrain[:,:,idx[0]-pad:pad+idx[0]+sampleSize].to(device)
        target = ytrain[:,idx[0]:idx[0]+sampleSize].to(device)
        output = model(data)
        pred = output.max(1, keepdim=True)[1]
        correct = pred.eq(target.view_as(pred)).sum().item() / pred.shape[-1]
        val_loss = criterion(output, target).item()
        print(correct,'accurate')
        print('\nval set:loss{:.4f}:, ({:.3f} sec/step)\n'.format(val_loss,time.time()-start_time))
        
        listofpred = []
        for ind in range(xtrain.shape[-1]-pad-10*sampleSize,xtrain.shape[-1]-pad-sampleSize,sampleSize):
            output = model(xtrain[:, :, ind - pad:ind + sampleSize + pad].to(device))
            pred = output.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
            listofpred.append(pred)
        ans = mu_law_decode(np.concatenate(listofpred))
        sf.write('./vsCorpus/notexval.wav', ans, sample_rate)
        print('val stored done',time.time() - start_time)'''
        

def test(xtrain,iloader):# testing data
    """Run the model over a whole input in windows and save the decoded audio.

    The input is walked in steps of ``sampleSize`` along its last axis; each
    window is extended by ``pad`` samples on both sides before being passed to
    the model. The per-window class predictions (index of the maximum along
    dim 1 of the model output) are concatenated, decoded with
    ``mu_law_decode`` and written to ``savemusic.format(iloader)``.

    Args:
        xtrain: input tensor indexed as ``xtrain[:, :, start:stop]``
            (presumably (batch, channels, time) -- TODO confirm against the
            Dataset implementation).
        iloader: value substituted into the ``savemusic`` file-name pattern.

    Side effects:
        Writes an audio file and prints the elapsed time. The model's
        train/eval mode is restored on exit, so calling this in the middle of
        training no longer leaves the model in eval mode.
    """
    was_training = model.training  # remember the caller's mode
    model.eval()
    start_time = time.time()
    try:
        with torch.no_grad():
            listofpred=[]
            for ind in range(pad,xtrain.shape[-1]-pad,sampleSize):
                window = xtrain[:, :, ind-pad:ind+sampleSize+pad].to(device)
                output = model(window)
                # index of the largest score along dim 1, flattened to 1-D
                pred = output.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
                listofpred.append(pred)
            if not listofpred:
                # np.concatenate raises on an empty list; this happens when the
                # input is not longer than 2 * pad samples.
                print('test skipped: input too short for the padding', xtrain.shape[-1])
                return
            ans = mu_law_decode(np.concatenate(listofpred))
            sf.write(savemusic.format(iloader), ans, sample_rate)
            print('test stored done',time.time() - start_time)
    finally:
        model.train(was_training)
    
def train(epoch):# one training pass over every song in `loadtr`
    """Train on random windows of every song, then render and checkpoint.

    For each ``(xtrain, ytrain)`` pair from ``loadtr``:

    1. Candidate window starts are taken every 8000 samples, shuffled, and
       only the first ``len(starts) // songnum`` of them are kept.
    2. Each kept window (``sampleSize`` samples, extended by ``pad`` on both
       sides for the input) is used for one optimizer step; the loss is
       appended to ``lossrecord``.
    3. Every 100 steps the whole ``lossrecord`` list is rewritten to
       ``./lossRecord/<lossname>``.
    4. ``test`` renders the song and a checkpoint is saved to ``resumefile``.

    Args:
        epoch: epoch number, used in the progress message and stored
            (as ``epoch + 1``) in the checkpoint.

    Notes:
        ``model.train()`` is called at the start of every song because
        ``test`` puts the model in eval mode; calling it only once before
        the loop left every song after the first training in eval mode.
        Output directories are created if missing so that a long run does
        not fail when it first tries to write.
    """
    loss_dir = "./lossRecord/"
    ckpt_dir = os.path.dirname(resumefile)
    for iloader,(xtrain,ytrain) in enumerate(loadtr):
        model.train()
        idx = np.arange(pad,xtrain.shape[-1]-pad-sampleSize,8000)
        np.random.shuffle(idx)#random the starting points
        lens = idx.shape[-1] // songnum
        idx = idx[:lens]
        for i, ind in enumerate(idx):
            start_time = time.time()
            # the input carries `pad` extra samples on each side of the target window
            data = xtrain[:,:,ind-pad:ind+sampleSize+pad].to(device)
            target = ytrain[:,ind:ind+sampleSize].to(device)
            output = model(data)
            loss = criterion(output, target)
            lossrecord.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('Train Epoch: {} iloader:{} [{}/{} ({:.0f}%)] Loss:{:.6f}: , ({:.3f} sec/step)'.format(
                epoch, iloader, i, len(idx), 100. * i / len(idx), loss.item(), time.time() - start_time))
            if i % 100 == 0:
                if not os.path.isdir(loss_dir):
                    os.makedirs(loss_dir)
                # the file is rewritten from scratch with the full history
                with open(loss_dir+lossname, "w") as f:
                    f.writelines(str(s) + "\n" for s in lossrecord)
                print('write finish')

        #val(xtrain,ytrain)
        test(xtrain,iloader)
        state={'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()}
        if ckpt_dir and not os.path.isdir(ckpt_dir):
            os.makedirs(ckpt_dir)
        torch.save(state, resumefile)

In [None]:
# Continue the epoch count from the checkpoint when resuming. `start_epoch` is
# only set by the resume cell after a checkpoint was actually loaded, hence the
# guarded lookup; a fresh run starts at 0 exactly as before.
# NOTE: the checkpoint stores `epoch + 1` after every song, so resuming skips
# whatever was left of the epoch that was interrupted.
first_epoch = globals().get('start_epoch', 0) if continueTrain else 0
for epoch in range(first_epoch, 100000):
    train(epoch)

write finish
test stored done 64.27492713928223
write finish
test stored done 70.72624969482422
write finish


test stored done 71.52951955795288
write finish
test stored done 56.1863853931427
write finish
test stored done 35.044321060180664
write finish
test stored done 63.601567029953
write finish


test stored done 73.0145833492279
write finish
test stored done 66.5096161365509
write finish


test stored done 89.00384616851807
write finish
test stored done 79.97767663002014
write finish
test stored done 70.69148349761963
write finish


test stored done 71.51208257675171
write finish
test stored done 89.02503156661987
write finish


test stored done 66.50520205497742
write finish
test stored done 79.95967555046082
write finish
test stored done 56.19634771347046
write finish


test stored done 63.60625219345093
write finish
test stored done 64.28791904449463
write finish
test stored done 73.00286841392517
write finish


test stored done 35.033369064331055
write finish
test stored done 64.27975153923035
write finish
test stored done 70.71395349502563
write finish


test stored done 73.0069670677185
write finish
test stored done 63.48702931404114
write finish
test stored done 71.38219928741455
write finish


test stored done 88.83055353164673
write finish
test stored done 56.080259799957275
write finish


test stored done 79.82687830924988
write finish
test stored done 34.95906710624695
write finish
test stored done 66.33846163749695
write finish
test stored done 70.56275725364685
write finish


test stored done 56.100361824035645
write finish
test stored done 63.48997735977173
write finish


test stored done 79.7683470249176
write finish
test stored done 66.39697051048279
write finish
test stored done 88.80032467842102
write finish


test stored done 71.34004259109497
write finish
test stored done 72.8353271484375
write finish
test stored done 34.95787858963013
write finish


test stored done 64.12502574920654
write finish
test stored done 63.45907378196716
write finish
test stored done 34.95691967010498
write finish


test stored done 70.53538870811462
write finish
test stored done 66.37425947189331
write finish
test stored done 71.39173340797424
write finish


test stored done 88.7499451637268
write finish
test stored done 79.79587435722351
write finish
test stored done 56.05764651298523
write finish


test stored done 72.8036618232727
write finish
test stored done 64.12776231765747
write finish


test stored done 64.12171411514282
write finish
test stored done 34.952434062957764
write finish
test stored done 66.36403393745422
write finish


test stored done 88.80538439750671
write finish
test stored done 72.82335233688354
write finish
test stored done 63.45771837234497
write finish


test stored done 79.7855327129364
write finish
test stored done 70.5219931602478
write finish
test stored done 71.38245558738708
write finish


test stored done 56.08816337585449
write finish
test stored done 70.53410005569458
write finish


test stored done 79.78805446624756
write finish
test stored done 88.82363367080688
write finish
test stored done 71.39592432975769
write finish


test stored done 72.85842776298523
write finish
test stored done 66.3692193031311
write finish
test stored done 34.95184135437012
write finish


test stored done 64.10014986991882
write finish
test stored done 63.4474892616272
write finish
test stored done 56.04970169067383
write finish


test stored done 63.45013475418091
write finish
test stored done 71.32516145706177
write finish
test stored done 70.50719618797302
write finish


test stored done 66.37016677856445
write finish
test stored done 79.80000281333923
write finish
test stored done 56.056970834732056
write finish


test stored done 34.93739724159241
write finish
test stored done 88.87896871566772
write finish
test stored done 64.1409661769867
write finish


test stored done 72.84847068786621
write finish
test stored done 88.7982542514801
write finish


test stored done 79.80113315582275
write finish
test stored done 71.3973639011383
write finish
test stored done 63.49886727333069
write finish


test stored done 34.98350477218628
write finish
test stored done 56.086151361465454
write finish
test stored done 66.37013292312622
write finish


test stored done 70.50223350524902
write finish
test stored done 72.8022735118866
write finish
test stored done 64.07376599311829
write finish


test stored done 34.93520402908325
write finish
test stored done 66.33031749725342
write finish
test stored done 71.31624984741211
write finish


test stored done 70.4770519733429
write finish
test stored done 56.032044410705566
write finish
test stored done 79.73350143432617
write finish


test stored done 72.77487587928772
write finish
test stored done 88.74556255340576
write finish


test stored done 64.08675336837769
write finish
test stored done 63.41271662712097
write finish
test stored done 63.42870545387268
write finish


test stored done 79.7798707485199
write finish
test stored done 71.34308981895447
write finish
test stored done 34.94878339767456
write finish


test stored done 56.05046367645264
write finish
test stored done 64.0840196609497
write finish
test stored done 88.73350739479065
write finish


test stored done 66.33936190605164
write finish
test stored done 70.46780395507812
write finish


test stored done 72.7728123664856
write finish
test stored done 88.74460220336914
write finish
test stored done 63.412548780441284
write finish


test stored done 66.32095694541931
write finish
test stored done 34.941983222961426
write finish
test stored done 56.04954552650452
write finish


test stored done 64.07828521728516
write finish
test stored done 71.30526924133301
write finish
test stored done 72.820059299469
write finish


test stored done 70.52918028831482
write finish
test stored done 79.79555058479309
write finish
test stored done 34.93623185157776
write finish


test stored done 88.74073266983032
write finish
test stored done 70.4761815071106
write finish
test stored done 66.32724404335022
write finish


test stored done 64.15593695640564
write finish
test stored done 72.81554317474365
write finish


test stored done 79.75145196914673
write finish
test stored done 71.37162828445435
write finish
test stored done 63.39776825904846
write finish


test stored done 56.03331971168518
write finish
test stored done 72.77096962928772
write finish
test stored done 70.46854329109192
write finish


test stored done 56.05711913108826
write finish
test stored done 79.74118685722351
write finish


test stored done 88.7697765827179
write finish
test stored done 71.2848596572876
write finish
test stored done 34.93778872489929
write finish


test stored done 64.07473707199097
write finish
test stored done 63.39634323120117
write finish
test stored done 66.32851600646973
write finish


test stored done 79.74274945259094
write finish
test stored done 64.10464143753052
write finish
test stored done 56.07766819000244
write finish


test stored done 34.95481634140015
write finish
test stored done 71.37281107902527
write finish
test stored done 63.425315141677856
write finish


test stored done 66.3209581375122
write finish
test stored done 72.78591418266296
write finish
test stored done 70.47528052330017
write finish


test stored done 88.71038770675659
write finish
test stored done 88.71919107437134
write finish


test stored done 64.08709716796875
write finish
test stored done 70.46715021133423
write finish
test stored done 66.30240297317505
write finish


test stored done 63.401691198349
write finish
test stored done 34.931153297424316
write finish
test stored done 71.32907557487488
write finish


test stored done 56.04078412055969
write finish
test stored done 79.7365574836731
write finish
test stored done 72.7695050239563
write finish


test stored done 79.7359266281128
write finish
test stored done 63.41609168052673
write finish


test stored done 64.06553792953491
write finish
test stored done 56.040600299835205
write finish
test stored done 72.78596496582031
write finish
test stored done 34.96018862724304
write finish


test stored done 70.49038767814636
write finish
test stored done 88.74358534812927
write finish


test stored done 71.30882263183594
write finish
test stored done 66.3585262298584
write finish
test stored done 34.971320390701294
write finish


test stored done 72.79036402702332
write finish
test stored done 88.77255582809448
write finish
