In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torch.utils.data as utils
import librosa
import soundfile as sf
import time
import os

In [2]:
sampleSize=50000
sample_rate=16000
quantization_channels=256
#dilations=[2**i for i in range(8)]*20
#"residualDim=32
dilations=[2**i for i in range(9)]*5
residualDim=32
skipDim=512
filterSize=3
pad = np.sum(dilations)
shapeoftest=190500
lossrecord=[]
pad

2555

In [3]:
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [4]:
use_cuda = torch.cuda.is_available()
torch.manual_seed(1)
device = torch.device("cuda" if use_cuda else "cpu")
#device = 'cpu'
torch.set_default_tensor_type('torch.cuda.FloatTensor')
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

In [5]:
def mu_law_encode(audio, quantization_channels=quantization_channels,forX=False):
    '''Quantizes waveform amplitudes.'''
    mu = (quantization_channels - 1)*1.0
    # Perform mu-law companding transformation (ITU-T, 1988).
    # Minimum operation is here to deal with rare large amplitudes caused
    # by resampling.
    safe_audio_abs = np.minimum(np.abs(audio), 1.0)
    magnitude = np.log1p(mu * safe_audio_abs) / np.log1p(mu)
    signal = np.sign(audio) * magnitude
    # Quantize signal to the specified number of levels.
    if(forX):return signal
    return ((signal + 1) / 2 * mu + 0.5).astype(int)
def mu_law_decode(output, quantization_channels=quantization_channels):
    '''Recovers waveform from quantized values.'''
    mu = quantization_channels - 1
    # Map values back to [-1, 1].
    signal = 2 * ((output*1.0) / mu) - 1
    # Perform inverse of mu-law transformation.
    magnitude = (1 / mu) * ((1 + mu)**np.abs(signal) - 1)
    return np.sign(signal) * magnitude

In [8]:
def readAudio(name):
    audio0, samplerate = sf.read(name, dtype='float32')
    return librosa.resample(audio0.T, samplerate, sample_rate).reshape(-1)
p=['./vsCorpus/origin_mix.wav','./vsCorpus/origin_vocal.wav',
   './vsCorpus/origin_mix.wav','./vsCorpus/origin_vocal.wav','./vsCorpus/pred_mix.wav']
xtrain,ytrain,xval,yval,xtest=readAudio(p[0]),readAudio(p[1]),readAudio(p[2]),readAudio(p[3]),readAudio(p[4])
assert((xtrain==xval).all())
assert((ytrain==yval).all())
assert((xtrain != ytrain).any())

In [9]:
ytrain,yval=mu_law_encode(ytrain),mu_law_encode(yval)
xtrain,xval,xtest=mu_law_encode(xtrain,forX=True),mu_law_encode(xval,forX=True),mu_law_encode(xtest,forX=True)

In [7]:
'''xmean,xstd = xtrain.mean(),xtrain.std()
xtrain=(xtrain-xmean)/xstd
xval=(xval-xmean)/xstd
xtest=(xtest-xmean)/xstd'''

In [12]:
#xtrain,ytrain=xtrain[:xtest.shape[0]],ytrain[:xtest.shape[0]]
#xval,yval=xval[:xtest.shape[0]],yval[:xtest.shape[0]]
xtrain=np.pad(xtrain, (pad, pad), 'constant')
xval=np.pad(xval, (pad, pad), 'constant')
xtest=np.pad(xtest, (pad, pad), 'constant')
yval=np.pad(yval, (pad, pad), 'constant')
ytrain=np.pad(ytrain, (pad, pad), 'constant')

In [13]:
#xtrain,ytrain,xval,yval=xtrain[:-sampleSize],ytrain[:-sampleSize],xval[-sampleSize:],yval[-sampleSize:]
#xtrain,ytrain,xval,yval=xtrain[:-sampleSize],ytrain[:-sampleSize],xval[:sampleSize],yval[:sampleSize]
xtrain,xval,xtest=xtrain.reshape(1,1,-1),xval.reshape(1,1,-1),xtest.reshape(1,1,-1)
ytrain,yval=ytrain.reshape(1,-1),yval.reshape(1,-1)

In [14]:
xtrain,ytrain,xval,yval,xtest = torch.from_numpy(xtrain).type(torch.float32),\
                                torch.from_numpy(ytrain).type(torch.LongTensor),\
                                torch.from_numpy(xval).type(torch.float32),\
                                torch.from_numpy(yval).type(torch.LongTensor),\
                                torch.from_numpy(xtest).type(torch.float32)

In [15]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        sd,qd,rd = skipDim,quantization_channels,residualDim
        self.causal = nn.Conv1d(in_channels=1,out_channels=rd,kernel_size=3,padding=1)
        self.layer=dict()
        for i, d in enumerate(dilations):
            self.layer['tanh'+str(i)] = nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=3,padding=0,dilation=d)
            self.layer['sigmoid'+str(i)] = nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=3,padding=0,dilation=d)
            self.layer['skip'+str(i)] = nn.Conv1d(in_channels=rd,out_channels=sd,kernel_size=1,padding=0)
            self.layer['dense'+str(i)] = nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=1,padding=0)
        self.post1 = nn.Conv1d(in_channels=sd,out_channels=sd,kernel_size=1,padding=0)
        self.post2 = nn.Conv1d(in_channels=sd,out_channels=qd,kernel_size=1,padding=0)
        self.tanh,self.sigmoid = nn.Tanh(),nn.Sigmoid()

    def forward(self, x):
        finallen = x.shape[-1]-2*pad
        x = self.causal(x)
        #print('x.shape',x.shape)
        skip_connections = torch.zeros([1,skipDim,finallen],dtype=torch.float32,device=device)
        for i, dilation in enumerate(dilations):
            xinput = x.clone()[:,:,dilation:-dilation]
            x1 = self.tanh(self.layer['tanh'+str(i)](x))
            #print('tanh.shape',x1.shape)
            x2 = self.sigmoid(self.layer['sigmoid'+str(i)](x))
            #print('sigmoid.shape',x2.shape)
            x = x1*x2
            #print('multi',x3.shape)
            cutlen = (x.shape[-1] - finallen)//2
            skip_connections += (self.layer['skip'+str(i)](x)).narrow(2,int(cutlen),int(finallen))
            #cur =self.layer['skip'+str(i)](x3)
            x = self.layer['dense'+str(i)](x)
            #print('dense.shape',x.shape)
            x += xinput
        x = self.post2(F.relu(self.post1(F.relu(skip_connections))))
        return x

model = Net().cuda()
criterion = nn.CrossEntropyLoss().cuda()
#optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
optimizer = optim.Adam(model.parameters(),weight_decay=1e-5)

In [17]:
#model = torch.load('torchmodel0.681600')

In [25]:
def val():
    model.eval()
    startval_time = time.time()
    with torch.no_grad():
        #data, target = xval.to(device), yval.to(device)
        data, target = xtrain[:,:,0:2*pad+shapeoftest].to(device), ytrain[:,pad:shapeoftest+pad].to(device)
        output = model(data)
        #print(output.shape)
        #print(output[:,:,:10])
        pred = output.max(1, keepdim=True)[1]
        correct = pred.eq(target.view_as(pred)).sum().item() / pred.shape[-1]
        val_loss = criterion(output, target).item()
    print(correct,'accurate')
    print('\nval set:loss{:.4f}:, ({:.3f} sec/step)\n'.format(val_loss,time.time()-startval_time))

def test():
    model.eval()
    startval_time = time.time()
    with torch.no_grad():
        output = model(xtest.to(device))
        pred = output.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
        ans = mu_law_decode(pred)
        sf.write('./vsCorpus/resultxte.wav', ans, sample_rate)

        # output = model(xtrain[:,:,:sampleSize].to(device))
        ind = pad
        listofpred=[]
        while ind < xtrain.shape[-1]-pad:
            output = model(xtrain[:, :, ind-pad:ind+sampleSize+pad].to(device))
            pred = output.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
            listofpred.append(pred)
            ind += sampleSize
        ans = mu_law_decode(np.concatenate(listofpred))
        sf.write('./vsCorpus/resultxtr.wav', ans, sample_rate)
        print('stored done\n')
    
def train(epoch):
    model.train()
    #idx = np.arange(xtrain.shape[-1] - 2 * sampleSize,1000)
    #176000
    idx = np.arange(pad,shapeoftest+pad-sampleSize,1000)
    np.random.shuffle(idx)
    for i, ind in enumerate(idx):
        start_time = time.time()
        data, target = xtrain[:,:,ind-pad:ind+sampleSize+pad].to(device), ytrain[:,ind:ind+sampleSize].to(device)
        output = model(data)
        #print(output.shape)
        loss = criterion(output, target)
        lossrecord.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss:{:.6f}: , ({:.3f} sec/step)'.format(
                epoch, i, len(idx),100. * i / len(idx), loss.item(),time.time() - start_time))
        if i % 100 == 0:
            val()
            test()
            state={'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()}
            torch.save(state, 'allmulawalldata')

In [26]:
 test()

stored done



In [None]:
for epoch in range(100000):
    train(epoch)
    test()
    with open("lossfile.txt", "w") as f:
        for s in lossrecord:
            f.write(str(s) +"\n")
    print('write finish')

0.004535433070866142 accurate

val set:loss5.4985:, (1.499 sec/step)

stored done



  "type " + obj.__name__ + ". It won't be checked "


0.026824146981627297 accurate

val set:loss4.6916:, (1.506 sec/step)

stored done



stored done

0.02790551181102362 accurate

val set:loss4.6499:, (1.505 sec/step)

stored done

0.0350498687664042 accurate

val set:loss4.5745:, (1.506 sec/step)

stored done



stored done

0.036965879265091865 accurate

val set:loss4.5383:, (1.506 sec/step)

stored done



0.041748031496062994 accurate

val set:loss4.4838:, (1.507 sec/step)

stored done

stored done

0.044241469816272964 accurate

val set:loss4.4722:, (1.507 sec/step)

stored done



0.04794225721784777 accurate

val set:loss4.4383:, (1.506 sec/step)

stored done

stored done

0.04959055118110236 accurate

val set:loss4.4252:, (1.507 sec/step)

stored done



0.053086614173228346 accurate

val set:loss4.3954:, (1.507 sec/step)

stored done

stored done

0.054645669291338586 accurate

val set:loss4.3819:, (1.507 sec/step)

stored done



0.05899212598425197 accurate

val set:loss4.3438:, (1.508 sec/step)

stored done

stored done

0.060241469816272965 accurate

val set:loss4.3330:, (1.507 sec/step)

stored done



0.06438845144356956 accurate

val set:loss4.3273:, (1.507 sec/step)

stored done

stored done

0.06555380577427822 accurate

val set:loss4.2961:, (1.506 sec/step)

stored done



0.06847769028871391 accurate

val set:loss4.2817:, (1.507 sec/step)

stored done

stored done

0.07033595800524935 accurate

val set:loss4.2705:, (1.508 sec/step)

stored done



0.07438320209973753 accurate

val set:loss4.2398:, (1.507 sec/step)

stored done



stored done

0.07483989501312335 accurate

val set:loss4.2333:, (1.506 sec/step)

stored done

0.07862992125984251 accurate

val set:loss4.2050:, (1.507 sec/step)

stored done



stored done

0.08073490813648294 accurate

val set:loss4.1970:, (1.507 sec/step)

stored done



0.08488713910761155 accurate

val set:loss4.1687:, (1.506 sec/step)

stored done

stored done

0.084750656167979 accurate

val set:loss4.1719:, (1.506 sec/step)

stored done



0.08763254593175852 accurate

val set:loss4.1601:, (1.508 sec/step)

stored done

stored done

0.08803149606299213 accurate

val set:loss4.1533:, (1.508 sec/step)

stored done



0.09373228346456693 accurate

val set:loss4.1193:, (1.507 sec/step)

stored done

stored done

0.09520734908136483 accurate

val set:loss4.1154:, (1.508 sec/step)

stored done



0.09886089238845144 accurate

val set:loss4.0915:, (1.508 sec/step)

stored done

stored done

0.09920734908136483 accurate

val set:loss4.0838:, (1.508 sec/step)

stored done



0.10126509186351706 accurate

val set:loss4.0789:, (1.507 sec/step)

stored done

stored done

0.10457742782152231 accurate

val set:loss4.0571:, (1.508 sec/step)

stored done



0.10679790026246719 accurate

val set:loss4.0476:, (1.509 sec/step)

stored done



stored done

0.1083989501312336 accurate

val set:loss4.0386:, (1.509 sec/step)

stored done

0.1114278215223097 accurate

val set:loss4.0276:, (1.509 sec/step)

stored done



stored done

0.11296587926509187 accurate

val set:loss4.0241:, (1.508 sec/step)

stored done

0.11438845144356956 accurate

val set:loss4.0086:, (1.508 sec/step)

stored done



stored done

0.1198740157480315 accurate

val set:loss3.9910:, (1.509 sec/step)

stored done



0.12136482939632547 accurate

val set:loss3.9799:, (1.508 sec/step)

stored done

stored done

0.12108136482939633 accurate

val set:loss3.9688:, (1.508 sec/step)

stored done



0.12269291338582677 accurate

val set:loss3.9599:, (1.508 sec/step)

stored done

stored done

0.1251233595800525 accurate

val set:loss3.9539:, (1.508 sec/step)

stored done



0.12826771653543306 accurate

val set:loss3.9411:, (1.509 sec/step)

stored done

stored done

0.12716010498687663 accurate

val set:loss3.9585:, (1.507 sec/step)

stored done



0.13181102362204725 accurate

val set:loss3.9264:, (1.509 sec/step)

stored done

stored done

0.13068241469816272 accurate

val set:loss3.9289:, (1.509 sec/step)

stored done



0.13406299212598424 accurate

val set:loss3.9128:, (1.508 sec/step)

stored done

stored done

0.13419422572178477 accurate

val set:loss3.9038:, (1.509 sec/step)

stored done



0.13723359580052494 accurate

val set:loss3.8929:, (1.509 sec/step)

stored done



stored done

0.13902887139107611 accurate

val set:loss3.8855:, (1.509 sec/step)

stored done

0.14012073490813648 accurate

val set:loss3.8866:, (1.508 sec/step)

stored done





In [23]:
#torch.save(model, 'loss2.5~3.5step10h_repeat5*2**9resu32sample50000')

  "type " + obj.__name__ + ". It won't be checked "
