In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import torch.utils.data as utils
import librosa
import soundfile as sf
import time
import os

In [2]:
sampleSize=50000
sample_rate=16000
quantization_channels=256
dilations=[2**i for i in range(8)]*15
skipDim=256
residualDim=48
filterSize=3

In [3]:
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [4]:
use_cuda = torch.cuda.is_available()
torch.manual_seed(1)
device = torch.device("cuda" if use_cuda else "cpu")
#device = 'cpu'
torch.set_default_tensor_type('torch.cuda.FloatTensor')
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

In [5]:
def mu_law_encode(audio, quantization_channels=quantization_channels):
    '''Quantizes waveform amplitudes.'''
    mu = (quantization_channels - 1)*1.0
    # Perform mu-law companding transformation (ITU-T, 1988).
    # Minimum operation is here to deal with rare large amplitudes caused
    # by resampling.
    safe_audio_abs = np.minimum(np.abs(audio), 1.0)
    magnitude = np.log1p(mu * safe_audio_abs) / np.log1p(mu)
    signal = np.sign(audio) * magnitude
    # Quantize signal to the specified number of levels.
    #if(forX):return signal
    return ((signal + 1) / 2 * mu + 0.5).astype(int)
def mu_law_decode(output, quantization_channels=quantization_channels):
    '''Recovers waveform from quantized values.'''
    mu = quantization_channels - 1
    # Map values back to [-1, 1].
    signal = 2 * ((output*1.0) / mu) - 1
    # Perform inverse of mu-law transformation.
    magnitude = (1 / mu) * ((1 + mu)**np.abs(signal) - 1)
    return np.sign(signal) * magnitude

In [6]:
def readAudio(name):
    audio0, samplerate = sf.read(name, dtype='float32')
    return librosa.resample(audio0.T, samplerate, sample_rate).reshape(-1)
p=['./vsCorpus/origin_mix.wav','./vsCorpus/origin_vocal.wav',
   './vsCorpus/origin_mix.wav','./vsCorpus/origin_vocal.wav','./vsCorpus/pred_mix.wav']
xtrain,ytrain,xval,yval,xtest=readAudio(p[0]),readAudio(p[1]),readAudio(p[2]),readAudio(p[3]),readAudio(p[4])
assert((xtrain==xval).all())
assert((ytrain==yval).all())
assert((xtrain != ytrain).any())
#xtrain,ytrain,xval,yval=xtrain[:-sampleSize],ytrain[:-sampleSize],xval[-sampleSize:],yval[-sampleSize:]
xtrain,ytrain,xval,yval=xtrain[:-sampleSize],ytrain[:-sampleSize],xval[:sampleSize],yval[:sampleSize]
xtrain,xval,xtest=xtrain.reshape(1,1,-1),xval.reshape(1,1,-1),xtest.reshape(1,1,-1)
ytrain,yval=ytrain.reshape(1,-1),yval.reshape(1,-1)

In [7]:
xmean,xstd = xtrain.mean(),xtrain.std()
xtrain=(xtrain-xmean)/xstd
xval=(xval-xmean)/xstd
xtest=(xtest-xmean)/xstd
ytrain,yval=mu_law_encode(ytrain),mu_law_encode(yval)

In [8]:
xtrain,ytrain,xval,yval,xtest = torch.from_numpy(xtrain).type(torch.float32),\
                                torch.from_numpy(ytrain).type(torch.LongTensor),\
                                torch.from_numpy(xval).type(torch.float32),\
                                torch.from_numpy(yval).type(torch.LongTensor),\
                                torch.from_numpy(xtest).type(torch.float32)

In [9]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        sd,qd,rd = skipDim,quantization_channels,residualDim
        self.causal = nn.Conv1d(in_channels=1,out_channels=rd,kernel_size=3,padding=1)
        self.layer=dict()
        for i, d in enumerate(dilations):
            self.layer['tanh'+str(i)] = nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=3,padding=d,dilation=d)
            self.layer['sigmoid'+str(i)] = nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=3,padding=d,dilation=d)
            self.layer['skip'+str(i)] = nn.Conv1d(in_channels=rd,out_channels=sd,kernel_size=1,padding=0)
            self.layer['dense'+str(i)] = nn.Conv1d(in_channels=rd,out_channels=rd,kernel_size=1,padding=0)
        self.post1 = nn.Conv1d(in_channels=sd,out_channels=sd,kernel_size=1,padding=0)
        self.post2 = nn.Conv1d(in_channels=sd,out_channels=qd,kernel_size=1,padding=0)
        self.tanh,self.sigmoid = nn.Tanh(),nn.Sigmoid()

    def forward(self, x):
        x = self.causal(x)
        #print('x.shape',x.shape)
        skip_connections = torch.zeros([1,skipDim,x.shape[2]],dtype=torch.float32,device=device)
        #skip_connections = []
        #print('skip_connections',skip_connections.shape)
        for i, dilation in enumerate(dilations):
            xinput = x.clone()
            x1 = self.tanh(self.layer['tanh'+str(i)](x))
            #print('tanh.shape',x1.shape)
            x2 = self.sigmoid(self.layer['sigmoid'+str(i)](x))
            #print('sigmoid.shape',x2.shape)
            x = x1*x2
            #print('multi',x3.shape)
            skip_connections += self.layer['skip'+str(i)](x)
            #print(skip_connections.cpu()[:,:,:10])
            #cur =self.layer['skip'+str(i)](x3)
            #skip_connections.append(cur)
            #print('skip.shape',self.layer['skip'+str(i)](x).shape)
            x = self.layer['dense'+str(i)](x)
            #print('dense.shape',x.shape)
            x += xinput
        #skip_connections = torch.cat(skip_connections,dim=0)
        #skip_connections = torch.sum(skip_connections,dim=0,keepdim=True)
        x = self.post2(F.relu(self.post1(F.relu(skip_connections))))
        return x

model = Net().cuda()
criterion = nn.CrossEntropyLoss().cuda()
#optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
optimizer = optim.Adam(model.parameters(),weight_decay=1e-5)

In [10]:
#torch.sum(c,dim=0,keepdim=True)

In [11]:
#model = torch.load('torchmodel0.681600')

In [14]:
def val():
    model.eval()
    startval_time = time.time()
    with torch.no_grad():
        #data, target = xval.to(device), yval.to(device)
        data, target = xtrain[:,:,:176000].to(device), ytrain[:,:176000].to(device)
        output = model(data)
        #print(output.shape)
        #print(output[:,:,:10])
        val_loss = criterion(output, target).item()
    print('\nval set:  {:.4f}, ({:.3f} sec/step)\n'.format(val_loss,time.time()-startval_time))

def test():
    model.eval()
    startval_time = time.time()
    with torch.no_grad():
        output = model(xtest.to(device))
        pred = output.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
        ans = mu_law_decode(pred)
        sf.write('./vsCorpus/resultxte.wav', ans, sample_rate)
        
        #output = model(xtrain[:,:,:sampleSize].to(device))
        output = model(xtrain[:,:,:176000].to(device))
        pred = output.max(1, keepdim=True)[1].cpu().numpy().reshape(-1)
        ans = mu_law_decode(pred)
        sf.write('./vsCorpus/resultxtr.wav', ans, sample_rate)
        print('stored done\n')
    
def train(epoch):
    model.train()
    #idx = np.arange(xtrain.shape[-1] - 2 * sampleSize,1000)
    #176000
    idx = np.arange(0,126000,3000)
    np.random.shuffle(idx)
    for i, ind in enumerate(idx):
        start_time = time.time()
        data, target = xtrain[:,:,ind:ind+sampleSize].to(device), ytrain[:,ind:ind+sampleSize].to(device)
        output = model(data)
        #print(output.shape)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} , ({:.3f} sec/step)'.format(
                epoch, i, len(idx),100. * i / len(idx), loss.item(),time.time() - start_time))
        if i % 100 == 0:
            val()
            test()
            torch.save(model, '15*2**8')

In [15]:
for epoch in range(100000):
    train(epoch)
    test()


val set:  3.9157, (4.142 sec/step)

stored done



  "type " + obj.__name__ + ". It won't be checked "


stored done


val set:  3.8946, (4.144 sec/step)

stored done

stored done


val set:  3.8724, (4.145 sec/step)

stored done

stored done


val set:  3.8595, (4.144 sec/step)

stored done



stored done


val set:  3.8273, (4.144 sec/step)

stored done

stored done


val set:  3.7979, (4.145 sec/step)

stored done

stored done


val set:  3.7834, (4.143 sec/step)

stored done



stored done


val set:  3.7554, (4.145 sec/step)

stored done

stored done


val set:  3.7353, (4.144 sec/step)

stored done

stored done


val set:  3.7091, (4.143 sec/step)

stored done



stored done


val set:  3.7119, (4.145 sec/step)

stored done

stored done


val set:  3.6737, (4.142 sec/step)

stored done

stored done


val set:  3.6549, (4.146 sec/step)

stored done



stored done


val set:  3.6679, (4.140 sec/step)

stored done

stored done


val set:  3.6372, (4.146 sec/step)

stored done

stored done


val set:  3.6189, (4.143 sec/step)

stored done



stored done


val set:  3.6048, (4.142 sec/step)

stored done

stored done


val set:  3.5956, (4.141 sec/step)

stored done

stored done


val set:  3.5689, (4.142 sec/step)

stored done



stored done


val set:  3.5468, (4.147 sec/step)

stored done

stored done


val set:  3.5400, (4.141 sec/step)

stored done

stored done


val set:  3.5354, (4.141 sec/step)

stored done



stored done


val set:  3.5219, (4.143 sec/step)

stored done

stored done


val set:  3.5040, (4.144 sec/step)

stored done

stored done


val set:  3.4971, (4.143 sec/step)

stored done



stored done


val set:  3.4799, (4.143 sec/step)

stored done

stored done


val set:  3.4697, (4.144 sec/step)

stored done

stored done


val set:  3.4813, (4.144 sec/step)

stored done



stored done


val set:  3.4585, (4.143 sec/step)

stored done

stored done


val set:  3.4422, (4.144 sec/step)

stored done

stored done


val set:  3.4336, (4.143 sec/step)

stored done



stored done


val set:  3.4255, (4.144 sec/step)

stored done

stored done


val set:  3.4271, (4.145 sec/step)

stored done

stored done


val set:  3.4124, (4.143 sec/step)

stored done



stored done


val set:  3.3958, (4.143 sec/step)

stored done

stored done


val set:  3.3946, (4.144 sec/step)

stored done

stored done


val set:  3.3777, (4.145 sec/step)

stored done



stored done


val set:  3.3594, (4.145 sec/step)

stored done

stored done


val set:  3.3567, (4.145 sec/step)

stored done

stored done


val set:  3.3456, (4.145 sec/step)

stored done



stored done


val set:  3.3399, (4.143 sec/step)

stored done

stored done


val set:  3.3399, (4.145 sec/step)

stored done

stored done


val set:  3.3285, (4.145 sec/step)

stored done



stored done


val set:  3.3142, (4.144 sec/step)

stored done

stored done


val set:  3.3210, (4.142 sec/step)

stored done

stored done


val set:  3.3109, (4.143 sec/step)

stored done



stored done


val set:  3.2958, (4.144 sec/step)

stored done

stored done


val set:  3.2872, (4.142 sec/step)

stored done

stored done


val set:  3.2801, (4.143 sec/step)

stored done



stored done


val set:  3.2829, (4.144 sec/step)

stored done

stored done


val set:  3.2779, (4.146 sec/step)

stored done

stored done


val set:  3.2819, (4.143 sec/step)

stored done



stored done


val set:  3.2542, (4.143 sec/step)

stored done

stored done


val set:  3.2466, (4.143 sec/step)

stored done

stored done


val set:  3.2342, (4.143 sec/step)

stored done



stored done


val set:  3.2351, (4.144 sec/step)

stored done

stored done


val set:  3.2359, (4.143 sec/step)

stored done

stored done


val set:  3.2271, (4.146 sec/step)

stored done



stored done


val set:  3.2200, (4.145 sec/step)

stored done

stored done


val set:  3.2262, (4.145 sec/step)

stored done

stored done


val set:  3.2028, (4.147 sec/step)

stored done



stored done


val set:  3.2072, (4.147 sec/step)

stored done

stored done


val set:  3.1985, (4.146 sec/step)

stored done

stored done


val set:  3.1955, (4.144 sec/step)

stored done



stored done


val set:  3.1837, (4.145 sec/step)

stored done

stored done


val set:  3.1783, (4.143 sec/step)

stored done

stored done


val set:  3.1757, (4.143 sec/step)

stored done



stored done


val set:  3.1745, (4.145 sec/step)

stored done

stored done


val set:  3.1636, (4.144 sec/step)

stored done

stored done


val set:  3.1864, (4.146 sec/step)

stored done



stored done


val set:  3.1713, (4.144 sec/step)

stored done

stored done


val set:  3.1506, (4.146 sec/step)

stored done

stored done


val set:  3.1629, (4.144 sec/step)

stored done



stored done


val set:  3.1433, (4.147 sec/step)

stored done

stored done


val set:  3.1393, (4.144 sec/step)

stored done

stored done


val set:  3.1340, (4.143 sec/step)

stored done



stored done


val set:  3.1471, (4.140 sec/step)

stored done

stored done


val set:  3.1494, (4.143 sec/step)

stored done

stored done


val set:  3.1251, (4.142 sec/step)

stored done



stored done


val set:  3.1190, (4.141 sec/step)

stored done

stored done


val set:  3.1171, (4.142 sec/step)

stored done

stored done


val set:  3.1111, (4.144 sec/step)

stored done



stored done


val set:  3.1072, (4.145 sec/step)

stored done

stored done


val set:  3.1076, (4.143 sec/step)

stored done

stored done


val set:  3.0941, (4.143 sec/step)

stored done



stored done


val set:  3.0974, (4.143 sec/step)

stored done

stored done


val set:  3.0924, (4.141 sec/step)

stored done

stored done


val set:  3.0995, (4.140 sec/step)

stored done



stored done


val set:  3.0826, (4.142 sec/step)

stored done

stored done


val set:  3.0965, (4.140 sec/step)

stored done

stored done


val set:  3.0886, (4.143 sec/step)

stored done



stored done


val set:  3.0896, (4.142 sec/step)

stored done

stored done


val set:  3.0871, (4.142 sec/step)

stored done

stored done


val set:  3.0780, (4.144 sec/step)

stored done



stored done


val set:  3.0671, (4.141 sec/step)

stored done

stored done


val set:  3.0590, (4.142 sec/step)

stored done

stored done


val set:  3.0644, (4.141 sec/step)

stored done



stored done


val set:  3.0651, (4.143 sec/step)

stored done

stored done


val set:  3.0549, (4.142 sec/step)

stored done

stored done


val set:  3.0620, (4.142 sec/step)

stored done



stored done


val set:  3.0518, (4.143 sec/step)

stored done

stored done


val set:  3.0478, (4.145 sec/step)

stored done

stored done


val set:  3.0606, (4.142 sec/step)

stored done



stored done


val set:  3.0419, (4.146 sec/step)

stored done

stored done


val set:  3.0441, (4.141 sec/step)

stored done

stored done


val set:  3.0500, (4.141 sec/step)

stored done



stored done


val set:  3.0355, (4.144 sec/step)

stored done

stored done


val set:  3.0471, (4.144 sec/step)

stored done

stored done


val set:  3.0412, (4.147 sec/step)

stored done



stored done


val set:  3.0305, (4.146 sec/step)

stored done

stored done


val set:  3.0396, (4.146 sec/step)

stored done

stored done


val set:  3.0559, (4.144 sec/step)



stored done

stored done


val set:  3.0318, (4.146 sec/step)

stored done

stored done


val set:  3.0248, (4.142 sec/step)

stored done



stored done


val set:  3.0228, (4.145 sec/step)

stored done

stored done


val set:  3.0218, (4.144 sec/step)

stored done

stored done


val set:  3.0249, (4.144 sec/step)

stored done



stored done


val set:  3.0214, (4.145 sec/step)

stored done

stored done


val set:  3.0209, (4.146 sec/step)

stored done

stored done


val set:  3.0095, (4.145 sec/step)

stored done



stored done


val set:  3.0069, (4.146 sec/step)

stored done

stored done


val set:  2.9951, (4.146 sec/step)

stored done

stored done


val set:  3.0053, (4.146 sec/step)

stored done



stored done


val set:  2.9960, (4.146 sec/step)

stored done

stored done


val set:  3.0015, (4.144 sec/step)

stored done

stored done


val set:  2.9912, (4.146 sec/step)

stored done



stored done


val set:  2.9891, (4.143 sec/step)

stored done

stored done


val set:  2.9807, (4.146 sec/step)

stored done

stored done


val set:  2.9824, (4.145 sec/step)

stored done



stored done


val set:  3.0061, (4.146 sec/step)

stored done

stored done


val set:  2.9875, (4.148 sec/step)

stored done

stored done


val set:  2.9866, (4.145 sec/step)

stored done



stored done


val set:  2.9713, (4.145 sec/step)

stored done

stored done


val set:  2.9856, (4.143 sec/step)

stored done

stored done


val set:  2.9908, (4.147 sec/step)

stored done



stored done


val set:  3.0334, (4.145 sec/step)

stored done

stored done


val set:  2.9782, (4.146 sec/step)

stored done

stored done


val set:  2.9672, (4.145 sec/step)

stored done



stored done


val set:  2.9699, (4.147 sec/step)

stored done

stored done


val set:  2.9688, (4.143 sec/step)

stored done

stored done


val set:  2.9689, (4.144 sec/step)

stored done



stored done


val set:  2.9643, (4.145 sec/step)

stored done

stored done


val set:  2.9729, (4.146 sec/step)

stored done

stored done


val set:  2.9672, (4.147 sec/step)

stored done



stored done


val set:  2.9666, (4.145 sec/step)

stored done

stored done


val set:  2.9528, (4.146 sec/step)

stored done

stored done


val set:  2.9565, (4.144 sec/step)

stored done



stored done


val set:  2.9543, (4.146 sec/step)

stored done

stored done


val set:  2.9602, (4.143 sec/step)

stored done

stored done


val set:  2.9542, (4.147 sec/step)

stored done



stored done


val set:  2.9524, (4.142 sec/step)

stored done

stored done


val set:  2.9472, (4.145 sec/step)

stored done

stored done


val set:  2.9671, (4.142 sec/step)

stored done



stored done


val set:  2.9397, (4.144 sec/step)

stored done

stored done


val set:  2.9548, (4.143 sec/step)

stored done

stored done


val set:  2.9362, (4.142 sec/step)

stored done



stored done


val set:  2.9395, (4.139 sec/step)

stored done

stored done


val set:  2.9431, (4.142 sec/step)

stored done

stored done


val set:  2.9350, (4.143 sec/step)

stored done



stored done


val set:  2.9436, (4.142 sec/step)

stored done

stored done


val set:  2.9896, (4.144 sec/step)

stored done

stored done


val set:  2.9352, (4.143 sec/step)

stored done



stored done


val set:  2.9421, (4.145 sec/step)

stored done

stored done


val set:  2.9309, (4.142 sec/step)

stored done

stored done


val set:  2.9355, (4.144 sec/step)

stored done



stored done


val set:  2.9325, (4.142 sec/step)

stored done

stored done


val set:  2.9344, (4.142 sec/step)

stored done

stored done


val set:  2.9380, (4.143 sec/step)

stored done



stored done


val set:  2.9188, (4.144 sec/step)

stored done

stored done


val set:  2.9298, (4.141 sec/step)

stored done

stored done


val set:  2.9307, (4.142 sec/step)

stored done



stored done


val set:  2.9318, (4.141 sec/step)

stored done

stored done


val set:  2.9433, (4.144 sec/step)

stored done

stored done


val set:  2.9276, (4.143 sec/step)

stored done



stored done


val set:  2.9090, (4.143 sec/step)

stored done

stored done


val set:  2.9264, (4.141 sec/step)

stored done

stored done




val set:  2.9171, (4.144 sec/step)

stored done

stored done


val set:  2.9172, (4.143 sec/step)

stored done

stored done


val set:  2.9180, (4.142 sec/step)

stored done



stored done


val set:  2.9181, (4.144 sec/step)

stored done

stored done


val set:  2.9157, (4.144 sec/step)

stored done

stored done


val set:  2.9168, (4.144 sec/step)

stored done



stored done


val set:  2.9157, (4.143 sec/step)

stored done

stored done


val set:  2.9114, (4.142 sec/step)

stored done

stored done


val set:  2.9272, (4.141 sec/step)

stored done



stored done


val set:  2.9056, (4.140 sec/step)

stored done

stored done


val set:  2.9094, (4.140 sec/step)

stored done

stored done


val set:  2.9116, (4.141 sec/step)

stored done



stored done


val set:  2.9092, (4.142 sec/step)

stored done

stored done


val set:  2.9110, (4.141 sec/step)

stored done

stored done


val set:  2.9175, (4.144 sec/step)

stored done



stored done


val set:  2.9144, (4.142 sec/step)

stored done

stored done


val set:  2.8995, (4.144 sec/step)

stored done

stored done


val set:  2.8950, (4.145 sec/step)

stored done



stored done


val set:  2.9034, (4.145 sec/step)

stored done

stored done


val set:  2.8951, (4.145 sec/step)

stored done

stored done


val set:  2.8907, (4.145 sec/step)

stored done



stored done


val set:  2.9016, (4.144 sec/step)

stored done

stored done


val set:  2.8954, (4.142 sec/step)

stored done

stored done


val set:  2.9104, (4.143 sec/step)

stored done



stored done


val set:  2.8831, (4.142 sec/step)

stored done

stored done


val set:  2.8868, (4.143 sec/step)

stored done

stored done


val set:  2.8862, (4.139 sec/step)

stored done



stored done


val set:  2.8943, (4.144 sec/step)

stored done

stored done


val set:  2.8850, (4.144 sec/step)

stored done

stored done


val set:  2.9019, (4.142 sec/step)

stored done



stored done


val set:  2.8857, (4.146 sec/step)

stored done

stored done


val set:  2.8887, (4.145 sec/step)

stored done

stored done


val set:  2.9055, (4.144 sec/step)

stored done



stored done


val set:  2.9305, (4.143 sec/step)

stored done

stored done


val set:  2.8815, (4.143 sec/step)

stored done

stored done


val set:  2.8998, (4.143 sec/step)

stored done



stored done


val set:  2.8757, (4.145 sec/step)

stored done

stored done


val set:  2.8816, (4.143 sec/step)

stored done

stored done


val set:  2.8773, (4.146 sec/step)

stored done



stored done


val set:  2.8770, (4.143 sec/step)

stored done

stored done


val set:  2.8833, (4.145 sec/step)

stored done

stored done


val set:  2.8719, (4.145 sec/step)

stored done





KeyboardInterrupt: 

In [None]:
#model = torch.load('torchmodel')

In [17]:
torch.save(model, 'loss23*1.+4*3.+16*2.step_9072_repeat15*2**8resu32sample50000')

  "type " + obj.__name__ + ". It won't be checked "


In [None]:
torch.

In [16]:
model

Net(
  (causal): Conv1d(1, 32, kernel_size=(3,), stride=(1,), padding=(1,))
  (post1): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
  (post2): Conv1d(512, 256, kernel_size=(1,), stride=(1,))
  (tanh): Tanh()
  (sigmoid): Sigmoid()
)