In [None]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torch.utils.data as utils
import librosa
import soundfile as sf
import time
import os
from torch.optim.lr_scheduler import StepLR,MultiStepLR

In [None]:
# ---- audio / quantisation ----
sampleSize = 16000             # number of samples in one training window
sample_rate = 16000            # sampling rate (Hz) every audio file is resampled to
quantization_channels = 256    # number of discrete mu-law classes

# ---- network ----
dilations = [1 << k for k in range(9)] * 1   # 1, 2, 4, ..., 256 (WaveNet-style dilation stack)
residualDim = 128              # channels of the residual (gated) layers
skipDim = 512                  # channels of the skip connections
filterSize = 3
shapeoftest = 190500           # number of samples scored by val()

# ---- checkpoint / logging ----
resumefile = 'array2model'     # checkpoint file name
lossname = 'array2loss.txt'    # loss-log file name
continueTrain = False          # whether to resume from the checkpoint
lossrecord = []                # running list of per-step training losses

# ---- derived sizes ----
pad = np.sum(dilations)        # total dilation of the stack
receptive_field = pad + 1      # context padded on each side of every window
#pad=0

In [None]:
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
# Select the compute device and make runs repeatable on the torch side.
use_cuda = torch.cuda.is_available()
torch.manual_seed(1)
device = torch.device("cuda" if use_cuda else "cpu")
#device = 'cpu'
# Only switch the default tensor type to CUDA when a GPU is actually present;
# requesting 'torch.cuda.FloatTensor' on a CPU-only machine raises, which would
# defeat the CPU fallback chosen for `device` above.
if use_cuda:
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
# DataLoader-style options (only meaningful when CUDA is used).
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

In [None]:
def mu_law_encode(audio, quantization_channels=quantization_channels,forX=False):
    """Apply mu-law companding (ITU-T, 1988) to a waveform.

    Args:
        audio: array of amplitudes; magnitudes above 1.0 are clipped to 1.0
            (rare overshoots can appear after resampling).
        quantization_channels: number of quantisation levels.
        forX: when true, return the companded signal in [-1, 1] without
            quantising it.

    Returns:
        The companded float signal if ``forX`` is true, otherwise the integer
        class index of every sample in ``[0, quantization_channels - 1]``.
    """
    mu = 1.0 * (quantization_channels - 1)
    clipped = np.minimum(np.abs(audio), 1.0)
    companded = np.sign(audio) * (np.log1p(mu * clipped) / np.log1p(mu))
    if not forX:
        # Shift [-1, 1] to [0, mu] and round to the nearest level.
        return ((companded + 1) / 2 * mu + 0.5).astype(int)
    return companded
def mu_law_decode(output, quantization_channels=quantization_channels):
    """Invert mu-law quantisation: class indices back to waveform amplitudes.

    Args:
        output: quantised class indices in ``[0, quantization_channels - 1]``.
        quantization_channels: number of quantisation levels used to encode.

    Returns:
        The expanded signal, with amplitudes in [-1, 1].
    """
    levels = quantization_channels - 1
    # Class index -> companded value in [-1, 1].
    companded = (output * 1.0) / levels * 2 - 1
    # Companded value -> linear amplitude (inverse mu-law).
    expanded = ((1 + levels) ** np.abs(companded) - 1) * (1 / levels)
    return np.sign(companded) * expanded
def cateToSignal(output, quantization_channels=quantization_channels,stage=0):
    """Run one of the two halves of mu-law decoding.

    Args:
        output: class indices (``stage == 0``) or companded values in
            [-1, 1] (any other ``stage``).
        quantization_channels: number of quantisation levels.
        stage: 0 maps class indices to the companded range [-1, 1]; any other
            value applies the inverse mu-law expansion to a companded signal.

    Returns:
        The companded signal (stage 0) or the linear-amplitude signal.
    """
    mu = quantization_channels - 1
    if stage != 0:
        expanded = (1 / mu) * ((1 + mu)**np.abs(output) - 1)
        return np.sign(output) * expanded
    return 2 * ((output*1.0) / mu) - 1

In [None]:
def readAudio(name):
    """Load an audio file and resample it to the notebook-wide ``sample_rate``.

    Args:
        name: path of the audio file to read.

    Returns:
        A flat 1-D float32 array holding the resampled audio.
    """
    audio0, samplerate = sf.read(name, dtype='float32')
    # Sample rates are passed by keyword: recent librosa releases made
    # `orig_sr` / `target_sr` keyword-only, and the keyword form also works
    # on older releases.
    resampled = librosa.resample(audio0.T, orig_sr=samplerate, target_sr=sample_rate)
    # NOTE(review): reshape(-1) flattens whatever shape comes back; for a
    # multi-channel file this would lay the channels end to end rather than
    # mixing them down. Presumably the corpus is mono - confirm.
    return resampled.reshape(-1)
# Files in order: training mix, training vocal, validation mix, validation vocal,
# and the mix to separate at inference time.
p=['./vsCorpus/origin_mix.wav','./vsCorpus/origin_vocal.wav',
   './vsCorpus/origin_mix.wav','./vsCorpus/origin_vocal.wav','./vsCorpus/pred_mix.wav']
xtrain, ytrain, xval, yval, xtest = map(readAudio, p)

# Sanity checks: validation data mirrors the training data, and the mix is
# not identical to the vocal track.
assert (xtrain == xval).all()
assert (ytrain == yval).all()
assert (xtrain != ytrain).any()

In [None]:
ytrain,yval=mu_law_encode(ytrain),mu_law_encode(yval)
xytrain,xyval=mu_law_encode(ytrain,forX=True),mu_law_encode(yval,forX=True)
xtrain,xval,xtest=mu_law_encode(xtrain,forX=True),mu_law_encode(xval,forX=True),mu_law_encode(xtest,forX=True)

In [None]:
#xmean,xstd = xtrain.mean(),xtrain.std()
#xtrain=(xtrain-xmean)/xstd
#xval=(xval-xmean)/xstd
#xtest=(xtest-xmean)/xstd

In [None]:
xtrain=np.pad(xtrain, (receptive_field, receptive_field), 'constant')
xval=np.pad(xval, (receptive_field, receptive_field), 'constant')
xtest=np.pad(xtest, (receptive_field, receptive_field), 'constant')
yval=np.pad(yval, (receptive_field, receptive_field), 'constant')
ytrain=np.pad(ytrain, (receptive_field, receptive_field), 'constant')
xyval=np.pad(xyval, (receptive_field, receptive_field), 'constant')
xytrain=np.pad(xytrain, (receptive_field, receptive_field), 'constant')

In [None]:
xtrain,xval,xtest=xtrain.reshape(1,1,-1),xval.reshape(1,1,-1),xtest.reshape(1,1,-1)
ytrain,yval=ytrain.reshape(1,-1),yval.reshape(1,-1)
xytrain,xyval=ytrain.reshape(1,1,-1),yval.reshape(1,1,-1)

In [None]:
#xtrain = np.concatenate((xtrain,xytrain.reshape(xtrain.shape)),axis=1)
#xval = np.concatenate((xval,xyval.reshape(xval.shape)),axis=1)
#xtest = np.concatenate((xtest,np.zeros_like(xtest)),axis=1)

In [None]:
# Convert the prepared numpy arrays to torch tensors.
# Mixture inputs for the x stream: float32.
xtrain, xval, xtest = (
    torch.from_numpy(arr).type(torch.float32) for arr in (xtrain, xval, xtest)
)
# Class-index targets for the loss: int64.
ytrain, yval = (
    torch.from_numpy(arr).type(torch.LongTensor) for arr in (ytrain, yval)
)
# Inputs for the y stream: float32.
xytrain, xyval = (
    torch.from_numpy(arr).type(torch.float32) for arr in (xytrain, xyval)
)

In [None]:
class Net(nn.Module):
    """Two-stream dilated-convolution network producing per-sample class logits.

    The x stream uses kernel-size-3 dilated convolutions with a residual
    connection per layer; the y stream uses kernel-size-2 dilated
    convolutions. Skip outputs of both streams are summed over the layers,
    concatenated on the channel axis and mapped to ``quantization_channels``
    logits by two 1x1 convolutions.

    No convolution uses padding, so inputs must carry ``receptive_field``
    extra samples on each side of the span to be predicted.
    """

    def __init__(self):
        super(Net, self).__init__()
        sd,qd,rd = skipDim,quantization_channels,residualDim
        # 1x1 input projections, one per stream.
        self.xcausal = nn.Conv1d(in_channels=1,out_channels=rd,kernel_size=1,padding=0)
        self.ycausal = nn.Conv1d(in_channels=1,out_channels=rd,kernel_size=1,padding=0)
        self.xtanh = nn.ModuleList()
        self.xsig = nn.ModuleList()
        self.xskip = nn.ModuleList()
        self.xdense = nn.ModuleList()
        self.ytanh = nn.ModuleList()
        self.ysig = nn.ModuleList()
        self.yskip = nn.ModuleList()
        # ydense is registered but not used by forward(); it is kept so that
        # existing checkpoints still match the module's state_dict keys.
        self.ydense = nn.ModuleList()
        for i, d in enumerate(dilations):
            self.xtanh.append(nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=3,dilation=d))
            self.xsig.append(nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=3,dilation=d))
            self.xskip.append(nn.Conv1d(in_channels=rd,out_channels=sd,kernel_size=1))
            self.xdense.append(nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=1))

            self.ytanh.append(nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=2,dilation=d))
            self.ysig.append(nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=2,dilation=d))
            self.yskip.append(nn.Conv1d(in_channels=rd,out_channels=sd,kernel_size=1))
            self.ydense.append(nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=1))
        self.post1 = nn.Conv1d(in_channels=sd*2,out_channels=sd*2,kernel_size=1,padding=0)
        self.post2 = nn.Conv1d(in_channels=sd*2,out_channels=qd,kernel_size=1,padding=0)
        self.tanh,self.sigmoid = nn.Tanh(),nn.Sigmoid()

    def forward(self, x, y):
        """Compute class logits for the central part of the input window.

        Args:
            x: tensor of shape (batch, 1, T) fed to the x stream.
            y: tensor of shape (batch, 1, T) fed to the y stream.

        Returns:
            Logits of shape (batch, quantization_channels,
            T - 2 * receptive_field).

        Raises:
            ValueError: if the window is not longer than 2 * receptive_field,
                or if x and y differ in length.
        """
        finallen = int(x.shape[-1] - 2 * receptive_field)
        if finallen <= 0:
            raise ValueError(
                'input length {} must exceed 2 * receptive_field = {}'.format(
                    x.shape[-1], 2 * int(receptive_field)))
        if y.shape[-1] != x.shape[-1]:
            raise ValueError(
                'x and y must have the same length, got {} and {}'.format(
                    x.shape[-1], y.shape[-1]))

        x = self.xcausal(x)
        y = self.ycausal(y)
        # Skip accumulators follow the input's batch size, dtype and device
        # instead of assuming batch size 1 and the notebook-global device.
        skipx = torch.zeros([x.shape[0], skipDim, finallen], dtype=x.dtype, device=x.device)
        skipy = torch.zeros_like(skipx)
        accumulate = 1  # 1 + sum of the dilations applied so far
        for i, dilation in enumerate(dilations):
            accumulate += dilation

            # y stream: gated activation only (no residual / dense step).
            y = self.tanh(self.ytanh[i](y)) * self.sigmoid(self.ysig[i](y))
            # Each kernel-2 layer shortens y by `dilation`; this offset selects
            # the finallen outputs aligned with the positions being predicted.
            skipy = skipy + self.yskip[i](y).narrow(
                2, int(receptive_field - accumulate), finallen)

            # x stream: gated activation + 1x1 conv + residual connection.
            # The kernel-3 layer trims `dilation` samples from each side, so
            # the residual input is cropped the same way (a view is enough).
            xinput = x[:, :, dilation:-dilation]
            x = self.tanh(self.xtanh[i](x)) * self.sigmoid(self.xsig[i](x))
            cutlen = (x.shape[-1] - finallen) // 2  # centre crop
            skipx = skipx + self.xskip[i](x).narrow(2, int(cutlen), finallen)
            x = self.xdense[i](x) + xinput

        fused = torch.cat((skipx, skipy), dim=1)
        return self.post2(F.relu(self.post1(F.relu(fused))))

# Build the network and place it explicitly on the selected device, rather
# than relying on the process-wide default tensor type. The optimizer is
# created afterwards so it tracks the parameters that are actually trained.
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
#optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
#optimizer = optim.SGD(model.parameters(), lr = 0.1, momentum=0.9, weight_decay=1e-4)
#scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
#scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)

In [None]:
# Epoch to start from; stays 0 unless a checkpoint is restored below.
start_epoch = 0
if continueTrain:
    # train() writes checkpoints to './model/' + resumefile, so look there
    # first; the bare file name is still accepted as a fallback.
    candidates = ('./model/' + resumefile, resumefile)
    checkpoint_path = next((c for c in candidates if os.path.isfile(c)), None)
    if checkpoint_path is not None:
        print("=> loading checkpoint '{}'".format(checkpoint_path))
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only machine (and vice versa).
        checkpoint = torch.load(checkpoint_path, map_location=device)
        start_epoch = checkpoint['epoch']
        #best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        #optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(checkpoint_path, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(candidates[0]))

In [None]:
def val():
    """Score the model on the first ``shapeoftest`` samples of the training data.

    Prints the fraction of correctly predicted classes and the
    cross-entropy loss, together with the elapsed time.
    """
    model.eval()
    startval_time = time.time()
    with torch.no_grad():
        # data, target = xval.to(device), yval.to(device)
        rf = int(receptive_field)
        # Never ask for more samples than the (padded) signal contains, so
        # that inputs and targets always have matching lengths.
        span = min(int(shapeoftest), int(xtrain.shape[-1]) - 2 * rf)
        datax = xtrain[:, :, 0:2 * rf + span].to(device)
        datay = xytrain[:, :, 0:2 * rf + span].to(device)
        target = ytrain[:, rf:rf + span].to(device)
        # The network takes both streams; calling it with the mixture alone
        # raises a TypeError.
        output = model(datax, datay)
        # print(output.shape)
        # print(output[:,:,:10])
        pred = output.max(1, keepdim=True)[1]
        correct = pred.eq(target.view_as(pred)).sum().item() / pred.shape[-1]
        val_loss = criterion(output, target).item()
    print(correct, 'accurate')
    print('\nval set:loss{:.4f}:, ({:.3f} sec/step)\n'.format(val_loss, time.time() - startval_time))


def _separate(mix):
    """Generate the output signal for one padded mixture, sample by sample.

    Args:
        mix: tensor of shape (1, 1, T), zero-padded by ``receptive_field``
            samples on both sides.

    Returns:
        1-D numpy array with the decoded waveform for the unpadded samples.
    """
    rf = int(receptive_field)
    total = int(mix.shape[-1])
    # Previously generated samples, in the companded [-1, 1] domain; they are
    # fed back to the y stream for the next prediction.
    generated = torch.zeros_like(mix)
    # Cover every unpadded sample: indices rf .. total - rf - 1.
    for ind in range(rf, total - rf):
        datax = mix[:, :, ind - rf:ind + 1 + rf].to(device)
        datay = generated[:, :, ind - rf:ind + 1 + rf].to(device)
        output = model(datax, datay)
        pred = output.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
        generated[:, :, ind] = cateToSignal(pred.item(), stage=0)
    # Drop the padding (exactly the samples predicted above) and undo the
    # mu-law companding on a plain numpy array.
    companded = generated.reshape(-1)[rf:total - rf].cpu().numpy()
    return cateToSignal(companded, stage=1)


def test():
    """Separate the training mix and the test mix and write both as WAV files.

    Writes './vsCorpus/2arrayxtr.wav' (from ``xtrain``) and
    './vsCorpus/2arrayxte.wav' (from ``xtest``) at ``sample_rate``.
    """
    model.eval()
    with torch.no_grad():
        # Each pass is sized from the signal it actually reads: the training
        # mix is no longer bounded by the length of the test mix.
        sf.write('./vsCorpus/2arrayxtr.wav', _separate(xtrain), sample_rate)
        print('stored xtr done\n')

        sf.write('./vsCorpus/2arrayxte.wav', _separate(xtest), sample_rate)
        print('stored xte done\n')


def train(epoch):
    """Run one training epoch over shuffled, non-overlapping windows.

    Each step trains on ``sampleSize`` consecutive samples (plus
    ``receptive_field`` samples of context on both sides). Every 100 steps
    the loss history and a checkpoint are written to disk.

    Args:
        epoch: index of the current epoch; used for logging and stored in
            the checkpoint as ``epoch + 1``.
    """
    model.train()
    # idx = np.arange(xtrain.shape[-1] - 2 * sampleSize,1000)
    # 176000
    # Window start positions, one window length apart.
    idx = np.arange(receptive_field, xtrain.shape[-1] - receptive_field - sampleSize - 1, sampleSize)
    np.random.shuffle(idx)

    # Make sure the output folders exist before the first write.
    loss_dir, model_dir = "./lossRecord/", './model/'
    os.makedirs(loss_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    for i, ind in enumerate(idx):
        start_time = time.time()
        datax = xtrain[:,:,ind - receptive_field:ind + sampleSize + receptive_field].to(device)
        datay = xytrain[:,:,ind - receptive_field:ind + sampleSize + receptive_field].to(device)
        target = ytrain[:,ind:ind + sampleSize].to(device)

        output = model(datax,datay)
        # print(output.shape)
        loss = criterion(output, target)
        lossrecord.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss:{:.6f}: , ({:.3f} sec/step)'.format(
            epoch, i, len(idx), 100. * i / len(idx), loss.item(), time.time() - start_time))
        if i % 100 == 0:
            # Rewrite the full loss history, then save a checkpoint.
            with open(loss_dir+lossname, "w") as f:
                for s in lossrecord:
                    f.write(str(s) +"\n")
            print('write finish')
            state={'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()}
            torch.save(state, model_dir+resumefile)



In [None]:
# Runs inference with whatever weights `model` currently holds and writes the
# separated WAV files.
# NOTE(review): this cell comes before the training-loop cell, so on a fresh
# top-to-bottom run with continueTrain=False the weights are still the random
# initialisation - confirm the intended execution order.
test()

In [None]:
'''Train Epoch: 335 [207/235 (88%)] Loss:2.003607: , (0.430 sec/step)
Train Epoch: 335 [208/235 (89%)] Loss:1.883396: , (0.430 sec/step)
Train Epoch: 335 [209/235 (89%)] Loss:1.929835: , (0.430 sec/step)
Train Epoch: 335 [210/235 (89%)] Loss:1.877405: , (0.430 sec/step)
Train Epoch: 335 [211/235 (90%)] Loss:2.432457: , (0.430 sec/step)
Train Epoch: 335 [212/235 (90%)] Loss:1.965565: , (0.430 sec/step)
Train Epoch: 335 [213/235 (91%)] Loss:2.112944: , (0.430 sec/step)
Train Epoch: 335 [214/235 (91%)] Loss:2.015351: , (0.430 sec/step)
Train Epoch: 335 [215/235 (91%)] Loss:2.040849: , (0.430 sec/step)
Train Epoch: 335 [216/235 (92%)] Loss:1.861006: , (0.430 sec/step)
Train Epoch: 335 [217/235 (92%)] Loss:2.068533: , (0.431 sec/step)
Train Epoch: 335 [218/235 (93%)] Loss:2.016664: , (0.430 sec/step)
Train Epoch: 335 [219/235 (93%)] Loss:1.942126: , (0.430 sec/step)
Train Epoch: 335 [220/235 (94%)] Loss:2.165619: , (0.429 sec/step)
Train Epoch: 335 [221/235 (94%)] Loss:1.912416: , (0.430 sec/step)
Train Epoch: 335 [222/235 (94%)] Loss:1.897818: , (0.429 sec/step)
Train Epoch: 335 [223/235 (95%)] Loss:2.100859: , (0.430 sec/step)
Train Epoch: 335 [224/235 (95%)] Loss:2.141154: , (0.429 sec/step)
Train Epoch: 335 [225/235 (96%)] Loss:1.984180: , (0.430 sec/step)
Train Epoch: 335 [226/235 (96%)] Loss:2.040324: , (0.429 sec/step)
Train Epoch: 335 [227/235 (97%)] Loss:1.824804: , (0.430 sec/step)
Train Epoch: 335 [228/235 (97%)] Loss:2.198944: , (0.430 sec/step)
Train Epoch: 335 [229/235 (97%)] Loss:1.941066: , (0.429 sec/step)
Train Epoch: 335 [230/235 (98%)] Loss:2.017070: , (0.429 sec/step)
Train Epoch: 335 [231/235 (98%)] Loss:2.079705: , (0.430 sec/step)
Train Epoch: 335 [232/235 (99%)] Loss:1.867700: , (0.430 sec/step)
Train Epoch: 335 [233/235 (99%)] Loss:1.923172: , (0.430 sec/step)
Train Epoch: 335 [234/235 (100%)] Loss:1.887689: , (0.429 sec/step)
Train Epoch: 336 [0/235 (0%)] Loss:2.155779: , (0.429 sec/step)'''

In [None]:
# Upper bound on the number of epochs; in practice the loop is interrupted by hand.
max_epochs = 100000
# Continue the epoch count from a restored checkpoint when the resume cell
# defined `start_epoch`; otherwise start from 0.
first_epoch = globals().get('start_epoch', 0)
for epoch in range(first_epoch, max_epochs):
    train(epoch)
    #test()