In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torch.utils.data as utils
import librosa
import soundfile as sf
import time
import os

In [2]:
sampleSize=50000
sample_rate=16000
quantization_channels=256
#dilations=[2**i for i in range(8)]*20
#"residualDim=32
dilations=[2**i for i in range(10)]*5
residualDim=96
skipDim=256
filterSize=3
pad = np.sum(dilations)
shapeoftest=190500
lossrecord=[]
pad

5115

In [3]:
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [4]:
use_cuda = torch.cuda.is_available()
torch.manual_seed(1)
device = torch.device("cuda" if use_cuda else "cpu")
#device = 'cpu'
torch.set_default_tensor_type('torch.cuda.FloatTensor')
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

In [5]:
def mu_law_encode(audio, quantization_channels=quantization_channels):
    '''Quantizes waveform amplitudes.'''
    mu = (quantization_channels - 1)*1.0
    # Perform mu-law companding transformation (ITU-T, 1988).
    # Minimum operation is here to deal with rare large amplitudes caused
    # by resampling.
    safe_audio_abs = np.minimum(np.abs(audio), 1.0)
    magnitude = np.log1p(mu * safe_audio_abs) / np.log1p(mu)
    signal = np.sign(audio) * magnitude
    # Quantize signal to the specified number of levels.
    #if(forX):return signal
    return ((signal + 1) / 2 * mu + 0.5).astype(int)
def mu_law_decode(output, quantization_channels=quantization_channels):
    '''Recovers waveform from quantized values.'''
    mu = quantization_channels - 1
    # Map values back to [-1, 1].
    signal = 2 * ((output*1.0) / mu) - 1
    # Perform inverse of mu-law transformation.
    magnitude = (1 / mu) * ((1 + mu)**np.abs(signal) - 1)
    return np.sign(signal) * magnitude

In [6]:
def readAudio(name):
    audio0, samplerate = sf.read(name, dtype='float32')
    return librosa.resample(audio0.T, samplerate, sample_rate).reshape(-1)
p=['./vsCorpus/origin_mix.wav','./vsCorpus/origin_vocal.wav',
   './vsCorpus/origin_mix.wav','./vsCorpus/origin_vocal.wav','./vsCorpus/pred_mix.wav']
xtrain,ytrain,xval,yval,xtest=readAudio(p[0]),readAudio(p[1]),readAudio(p[2]),readAudio(p[3]),readAudio(p[4])
assert((xtrain==xval).all())
assert((ytrain==yval).all())
assert((xtrain != ytrain).any())

In [7]:
xmean,xstd = xtrain.mean(),xtrain.std()
xtrain=(xtrain-xmean)/xstd
xval=(xval-xmean)/xstd
xtest=(xtest-xmean)/xstd
ytrain,yval=mu_law_encode(ytrain),mu_law_encode(yval)

In [8]:
xtrain,ytrain=xtrain[:xtest.shape[0]],ytrain[:xtest.shape[0]]
xval,yval=xval[:xtest.shape[0]],yval[:xtest.shape[0]]
xtrain=np.pad(xtrain, (pad, pad), 'constant')
xval=np.pad(xval, (pad, pad), 'constant')
xtest=np.pad(xtest, (pad, pad), 'constant')
yval=np.pad(yval, (pad, pad), 'constant')
ytrain=np.pad(ytrain, (pad, pad), 'constant')

In [9]:
#xtrain,ytrain,xval,yval=xtrain[:-sampleSize],ytrain[:-sampleSize],xval[-sampleSize:],yval[-sampleSize:]
#xtrain,ytrain,xval,yval=xtrain[:-sampleSize],ytrain[:-sampleSize],xval[:sampleSize],yval[:sampleSize]
xtrain,xval,xtest=xtrain.reshape(1,1,-1),xval.reshape(1,1,-1),xtest.reshape(1,1,-1)
ytrain,yval=ytrain.reshape(1,-1),yval.reshape(1,-1)

In [10]:
xtrain,ytrain,xval,yval,xtest = torch.from_numpy(xtrain).type(torch.float32),\
                                torch.from_numpy(ytrain).type(torch.LongTensor),\
                                torch.from_numpy(xval).type(torch.float32),\
                                torch.from_numpy(yval).type(torch.LongTensor),\
                                torch.from_numpy(xtest).type(torch.float32)

In [11]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        sd,qd,rd = skipDim,quantization_channels,residualDim
        self.causal = nn.Conv1d(in_channels=1,out_channels=rd,kernel_size=3,padding=1)
        self.layer=dict()
        for i, d in enumerate(dilations):
            self.layer['tanh'+str(i)] = nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=3,padding=0,dilation=d)
            self.layer['sigmoid'+str(i)] = nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=3,padding=0,dilation=d)
            self.layer['skip'+str(i)] = nn.Conv1d(in_channels=rd,out_channels=sd,kernel_size=1,padding=0)
            self.layer['dense'+str(i)] = nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=1,padding=0)
        self.post1 = nn.Conv1d(in_channels=sd,out_channels=sd,kernel_size=1,padding=0)
        self.post2 = nn.Conv1d(in_channels=sd,out_channels=qd,kernel_size=1,padding=0)
        self.tanh,self.sigmoid = nn.Tanh(),nn.Sigmoid()

    def forward(self, x):
        finallen = x.shape[-1]-2*pad
        x = self.causal(x)
        #print('x.shape',x.shape)
        skip_connections = torch.zeros([1,skipDim,finallen],dtype=torch.float32,device=device)
        for i, dilation in enumerate(dilations):
            xinput = x.clone()[:,:,dilation:-dilation]
            x1 = self.tanh(self.layer['tanh'+str(i)](x))
            #print('tanh.shape',x1.shape)
            x2 = self.sigmoid(self.layer['sigmoid'+str(i)](x))
            #print('sigmoid.shape',x2.shape)
            x = x1*x2
            #print('multi',x3.shape)
            cutlen = (x.shape[-1] - finallen)//2
            skip_connections += (self.layer['skip'+str(i)](x)).narrow(2,int(cutlen),int(finallen))
            #cur =self.layer['skip'+str(i)](x3)
            x = self.layer['dense'+str(i)](x)
            #print('dense.shape',x.shape)
            x += xinput
        x = self.post2(F.relu(self.post1(F.relu(skip_connections))))
        return x

model = Net().cuda()
criterion = nn.CrossEntropyLoss().cuda()
#optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
optimizer = optim.Adam(model.parameters(),weight_decay=1e-5)

In [12]:
#torch.sum(c,dim=0,keepdim=True)

In [13]:
#model = torch.load('torchmodel0.681600')

In [None]:
def val():
    model.eval()
    startval_time = time.time()
    with torch.no_grad():
        #data, target = xval.to(device), yval.to(device)
        data, target = xtrain[:,:,0:2*pad+shapeoftest].to(device), ytrain[:,pad:shapeoftest+pad].to(device)
        output = model(data)
        #print(output.shape)
        #print(output[:,:,:10])
        pred = output.max(1, keepdim=True)[1]
        correct = pred.eq(target.view_as(pred)).sum().item() / pred.shape[-1]
        val_loss = criterion(output, target).item()
    print(correct,'accurate')
    print('\nval set:loss{:.4f}:, ({:.3f} sec/step)\n'.format(val_loss,time.time()-startval_time))

def test():
    model.eval()
    startval_time = time.time()
    with torch.no_grad():
        output = model(xtest.to(device))
        pred = output.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
        ans = mu_law_decode(pred)
        sf.write('./vsCorpus/resultxte.wav', ans, sample_rate)
        
        #output = model(xtrain[:,:,:sampleSize].to(device))
        output = model(xtrain[:,:,0:2*pad+shapeoftest].to(device))
        pred = output.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
        ans = mu_law_decode(pred)
        sf.write('./vsCorpus/resultxtr.wav', ans, sample_rate)
        print('stored done\n')
    
def train(epoch):
    model.train()
    #idx = np.arange(xtrain.shape[-1] - 2 * sampleSize,1000)
    #176000
    idx = np.arange(pad,shapeoftest+pad-sampleSize,1000)
    np.random.shuffle(idx)
    for i, ind in enumerate(idx):
        start_time = time.time()
        data, target = xtrain[:,:,ind-pad:ind+sampleSize+pad].to(device), ytrain[:,ind:ind+sampleSize].to(device)
        output = model(data)
        #print(output.shape)
        loss = criterion(output, target)
        lossrecord.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss:{:.6f}: , ({:.3f} sec/step)'.format(
                epoch, i, len(idx),100. * i / len(idx), loss.item(),time.time() - start_time))
        if i % 100 == 0:
            val()
            test()
            torch.save(model, 'padonlyonside')

In [None]:
for epoch in range(100000):
    train(epoch)
    test()

0.005889763779527559 accurate

val set:loss5.5286:, (2.553 sec/step)

stored done



  "type " + obj.__name__ + ". It won't be checked "


0.03512335958005249 accurate

val set:loss4.5443:, (2.576 sec/step)

stored done



stored done

0.03836745406824147 accurate

val set:loss4.4981:, (2.576 sec/step)

stored done

0.04386351706036745 accurate

val set:loss4.4284:, (2.576 sec/step)

stored done



stored done

0.04668241469816273 accurate

val set:loss4.4007:, (2.577 sec/step)

stored done



0.05082939632545932 accurate

val set:loss4.3704:, (2.579 sec/step)

stored done

stored done

0.05228346456692914 accurate

val set:loss4.3727:, (2.579 sec/step)

stored done



0.05604724409448819 accurate

val set:loss4.3324:, (2.578 sec/step)

stored done

stored done

0.056236220472440944 accurate

val set:loss4.3167:, (2.576 sec/step)

stored done



0.06036745406824147 accurate

val set:loss4.2832:, (2.574 sec/step)

stored done

stored done

0.06203674540682415 accurate

val set:loss4.2742:, (2.573 sec/step)

stored done



0.06534908136482939 accurate

val set:loss4.2550:, (2.572 sec/step)

stored done

stored done

0.06719685039370078 accurate

val set:loss4.2525:, (2.571 sec/step)

stored done



0.06709186351706037 accurate

val set:loss4.2448:, (2.573 sec/step)

stored done

stored done

0.06974278215223097 accurate

val set:loss4.2223:, (2.575 sec/step)

stored done



0.07224146981627297 accurate

val set:loss4.1971:, (2.576 sec/step)

stored done

stored done

0.07332808398950132 accurate

val set:loss4.1977:, (2.577 sec/step)

stored done



0.07555380577427821 accurate

val set:loss4.1726:, (2.576 sec/step)

stored done



stored done

0.07589501312335958 accurate

val set:loss4.1753:, (2.576 sec/step)

stored done

0.08081364829396326 accurate

val set:loss4.1420:, (2.575 sec/step)

stored done



stored done

0.07830446194225722 accurate

val set:loss4.1502:, (2.576 sec/step)

stored done



0.08166404199475065 accurate

val set:loss4.1308:, (2.575 sec/step)

stored done

stored done

0.08237270341207349 accurate

val set:loss4.1316:, (2.576 sec/step)

stored done



0.08318635170603675 accurate

val set:loss4.1211:, (2.575 sec/step)

stored done

stored done

0.085748031496063 accurate

val set:loss4.1099:, (2.573 sec/step)

stored done



0.08671916010498687 accurate

val set:loss4.1127:, (2.574 sec/step)

stored done

stored done

0.0882992125984252 accurate

val set:loss4.0935:, (2.575 sec/step)

stored done



0.08989501312335958 accurate

val set:loss4.0886:, (2.576 sec/step)

stored done

stored done

0.09319160104986876 accurate

val set:loss4.0667:, (2.576 sec/step)

stored done



0.0923989501312336 accurate

val set:loss4.0813:, (2.574 sec/step)

stored done

stored done

0.09305511811023622 accurate

val set:loss4.0715:, (2.574 sec/step)

stored done



0.09515485564304461 accurate

val set:loss4.0564:, (2.574 sec/step)

stored done



stored done

0.0965511811023622 accurate

val set:loss4.0481:, (2.573 sec/step)

stored done

0.09797900262467192 accurate

val set:loss4.0441:, (2.574 sec/step)

stored done



stored done

0.09795800524934384 accurate

val set:loss4.0502:, (2.575 sec/step)

stored done

0.10111286089238845 accurate

val set:loss4.0239:, (2.575 sec/step)

stored done



stored done

0.10102887139107612 accurate

val set:loss4.0245:, (2.575 sec/step)

stored done



0.10253018372703412 accurate

val set:loss4.0150:, (2.575 sec/step)

stored done

stored done

0.10361679790026247 accurate

val set:loss4.0053:, (2.574 sec/step)

stored done



0.10518110236220472 accurate

val set:loss4.0032:, (2.575 sec/step)

stored done

stored done

0.10465616797900262 accurate

val set:loss4.0049:, (2.576 sec/step)

stored done



0.10671916010498687 accurate

val set:loss3.9926:, (2.573 sec/step)

stored done

stored done

0.10572178477690289 accurate

val set:loss3.9964:, (2.574 sec/step)

stored done



0.10810498687664041 accurate

val set:loss3.9849:, (2.574 sec/step)

stored done

stored done

0.1095748031496063 accurate

val set:loss3.9769:, (2.575 sec/step)

stored done



0.11058267716535433 accurate

val set:loss3.9707:, (2.576 sec/step)

stored done

stored done

0.11014173228346456 accurate

val set:loss3.9730:, (2.575 sec/step)

stored done



0.1115748031496063 accurate

val set:loss3.9678:, (2.576 sec/step)

stored done



stored done

0.11241469816272966 accurate

val set:loss3.9607:, (2.575 sec/step)

stored done

0.11276115485564304 accurate

val set:loss3.9558:, (2.574 sec/step)

stored done



stored done

0.11412073490813648 accurate

val set:loss3.9539:, (2.573 sec/step)

stored done



0.11571653543307087 accurate

val set:loss3.9407:, (2.572 sec/step)

stored done

stored done

0.11551706036745407 accurate

val set:loss3.9603:, (2.573 sec/step)

stored done



0.11731758530183727 accurate

val set:loss3.9423:, (2.576 sec/step)

stored done

stored done

0.11601049868766404 accurate

val set:loss3.9433:, (2.577 sec/step)

stored done



0.11750656167979003 accurate

val set:loss3.9436:, (2.579 sec/step)

stored done

stored done

0.11833595800524935 accurate

val set:loss3.9302:, (2.578 sec/step)

stored done



0.12046194225721785 accurate

val set:loss3.9163:, (2.579 sec/step)

stored done

stored done

0.12045669291338583 accurate

val set:loss3.9243:, (2.577 sec/step)

stored done



0.1213753280839895 accurate

val set:loss3.9193:, (2.576 sec/step)

stored done

stored done

0.122 accurate

val set:loss3.9147:, (2.576 sec/step)

stored done



0.12246194225721785 accurate

val set:loss3.9221:, (2.576 sec/step)

stored done



stored done

0.12107611548556431 accurate

val set:loss3.9157:, (2.575 sec/step)

stored done

0.12250918635170603 accurate

val set:loss3.9185:, (2.573 sec/step)

stored done



stored done

0.1240734908136483 accurate

val set:loss3.9027:, (2.572 sec/step)

stored done

0.12385301837270342 accurate

val set:loss3.8945:, (2.572 sec/step)



stored done

stored done

0.1243989501312336 accurate

val set:loss3.8983:, (2.573 sec/step)

stored done



0.12669291338582678 accurate

val set:loss3.8834:, (2.575 sec/step)

stored done

stored done

0.12694488188976377 accurate

val set:loss3.8925:, (2.577 sec/step)

stored done



0.12637795275590552 accurate

val set:loss3.8893:, (2.580 sec/step)

stored done

stored done

0.12658792650918635 accurate

val set:loss3.8837:, (2.580 sec/step)

stored done



0.12604199475065617 accurate

val set:loss3.8880:, (2.577 sec/step)

stored done

stored done

0.1286771653543307 accurate

val set:loss3.8821:, (2.575 sec/step)

stored done



0.1290288713910761 accurate

val set:loss3.8729:, (2.572 sec/step)

stored done

stored done

0.12845669291338582 accurate

val set:loss3.8796:, (2.571 sec/step)

stored done



0.13063517060367455 accurate

val set:loss3.8717:, (2.575 sec/step)

stored done

stored done



0.13057742782152232 accurate

val set:loss3.8736:, (2.575 sec/step)

stored done

0.13075590551181102 accurate

val set:loss3.8659:, (2.576 sec/step)

stored done



stored done

0.1310236220472441 accurate

val set:loss3.8585:, (2.576 sec/step)

stored done

0.1309238845144357 accurate

val set:loss3.8654:, (2.575 sec/step)

stored done



stored done

0.13277165354330708 accurate

val set:loss3.8533:, (2.574 sec/step)

stored done



0.13230971128608923 accurate

val set:loss3.8570:, (2.576 sec/step)

stored done

stored done

0.1315433070866142 accurate

val set:loss3.8683:, (2.576 sec/step)

stored done



0.13464041994750656 accurate

val set:loss3.8521:, (2.575 sec/step)

stored done

stored done

0.13349081364829396 accurate

val set:loss3.8544:, (2.574 sec/step)

stored done



0.1336220472440945 accurate

val set:loss3.8506:, (2.573 sec/step)

stored done

stored done

0.13395800524934384 accurate

val set:loss3.8496:, (2.572 sec/step)

stored done



0.13474015748031495 accurate

val set:loss3.8459:, (2.573 sec/step)

stored done

stored done

0.13530708661417323 accurate

val set:loss3.8452:, (2.573 sec/step)

stored done



0.13425721784776903 accurate

val set:loss3.8439:, (2.575 sec/step)

stored done

stored done

0.13633070866141733 accurate

val set:loss3.8348:, (2.574 sec/step)

stored done



0.13594225721784778 accurate

val set:loss3.8350:, (2.575 sec/step)

stored done



stored done

0.13625721784776904 accurate

val set:loss3.8396:, (2.575 sec/step)

stored done

0.13697112860892388 accurate

val set:loss3.8366:, (2.576 sec/step)

stored done



stored done

0.13653543307086613 accurate

val set:loss3.8390:, (2.577 sec/step)

stored done



0.1346246719160105 accurate

val set:loss3.8501:, (2.575 sec/step)

stored done

stored done

0.13781102362204725 accurate

val set:loss3.8416:, (2.574 sec/step)

stored done



0.13903937007874015 accurate

val set:loss3.8236:, (2.574 sec/step)

stored done

stored done

0.13800524934383201 accurate

val set:loss3.8241:, (2.573 sec/step)

stored done



0.13930708661417324 accurate

val set:loss3.8159:, (2.573 sec/step)

stored done

stored done

0.13883464566929135 accurate

val set:loss3.8240:, (2.575 sec/step)

stored done



0.13880839895013122 accurate

val set:loss3.8308:, (2.577 sec/step)

stored done

stored done

0.13929658792650917 accurate

val set:loss3.8162:, (2.577 sec/step)

stored done



0.13992125984251969 accurate

val set:loss3.8172:, (2.576 sec/step)

stored done

stored done

0.14 accurate

val set:loss3.8177:, (2.575 sec/step)

stored done



0.13995275590551182 accurate

val set:loss3.8140:, (2.574 sec/step)

stored done



stored done

0.14068766404199476 accurate

val set:loss3.8167:, (2.574 sec/step)

stored done

0.14102887139107612 accurate

val set:loss3.8116:, (2.573 sec/step)

stored done



stored done

0.14058792650918636 accurate

val set:loss3.8183:, (2.573 sec/step)

stored done

0.1409238845144357 accurate

val set:loss3.8085:, (2.578 sec/step)

stored done



stored done

0.14299212598425196 accurate

val set:loss3.8016:, (2.578 sec/step)

stored done



0.14248818897637797 accurate

val set:loss3.8123:, (2.577 sec/step)

stored done

stored done

0.14300787401574802 accurate

val set:loss3.8014:, (2.576 sec/step)

stored done





In [None]:
#model = torch.load('torchmodel')

In [None]:
#torch.save(model, 'loss23*1.+4*3.+16*2.step_9072_repeat15*2**8resu32sample50000')