# **drum**

## **load data**

In [0]:
import torch
from torch import autograd
from torch.utils.data import DataLoader
from torch import optim
import torch.nn as nn
import torch.nn.parallel
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import librosa
import librosa.display
import pickle as pk
class AudioDatasets(torch.utils.data.Dataset):
    """Dataset of fixed-length waveforms stored as consecutively numbered ``.npy`` files.

    Item ``i`` is read from ``<data_dir>/<first_index + i % num_files>.npy`` and
    reshaped to ``(1, wave_num)``. The default arguments reproduce the values
    that used to be hard-coded: 320 clips numbered from 1166, 16384 samples each.

    Args:
        data_dir: Directory that contains the ``.npy`` clips.
        first_index: File number of the first clip.
        num_files: Number of consecutive clips; this is also the dataset length.
        wave_num: Number of samples each clip is reshaped to.

    Raises:
        ValueError: If ``num_files`` or ``wave_num`` is not positive.
    """

    def __init__(self, data_dir='./drive/My Drive/drums/npy', first_index=1166,
                 num_files=320, wave_num=16384):
        super(AudioDatasets, self).__init__()
        if num_files <= 0:
            raise ValueError('num_files must be positive, got {}'.format(num_files))
        if wave_num <= 0:
            raise ValueError('wave_num must be positive, got {}'.format(wave_num))
        self.data_dir = data_dir
        self.first_index = first_index
        self.num_files = num_files
        self.wave_num = wave_num

    def __getitem__(self, index):
        """Return ``(wave, song, family)`` for ``index``.

        ``wave`` is a tensor of shape ``(1, wave_num)``. ``song`` and ``family``
        are constant placeholders (always 1).
        """
        # Placeholder labels: every clip gets the same values.
        song = 1
        family = 1

        # Wrap the index so that any integer maps onto an existing file number.
        file_number = self.first_index + index % self.num_files
        loadpath = '{}/{}.npy'.format(self.data_dir, file_number)

        wave = np.load(loadpath).reshape(1, self.wave_num)
        return torch.from_numpy(wave), song, family

    def __len__(self):
        """Return the number of clips in the dataset."""
        return self.num_files
        
def loadData(batch_size):
    """Create the training ``DataLoader``.

    Wraps a fresh ``AudioDatasets`` instance in a shuffling loader that uses
    10 worker processes.

    Args:
        batch_size: Number of clips per batch.

    Returns:
        A ``DataLoader`` over ``AudioDatasets``.
    """
    return DataLoader(
        AudioDatasets(),
        batch_size=batch_size,
        shuffle=True,
        num_workers=10,
    )

## **phase shuffle**

In [0]:
#copy from https://github.com/chrisdonahue/wavegan
class PhaseShuffle(nn.Module):
    """Randomly shift a 3D tensor along its last axis, filling by reflection.

    Each shift is an integer drawn uniformly from ``[-shift_factor,
    shift_factor]``. A positive shift moves the content to the right and
    reflect-pads on the left; a negative shift does the opposite.

    With ``batch_shuffle=True`` one shift is drawn and applied to the whole
    batch; otherwise every sample in the batch gets its own shift.
    """

    def __init__(self, shift_factor, batch_shuffle=False):
        super().__init__()
        self.shift_factor = shift_factor
        self.batch_shuffle = batch_shuffle

    @staticmethod
    def _shift_reflect(signal, shift):
        """Shift ``signal`` by ``shift`` steps on its last axis with reflection padding."""
        if shift > 0:
            return F.pad(signal[..., :-shift], (shift, 0), mode='reflect')
        # shift <= 0; a zero shift slices nothing off and pads nothing.
        return F.pad(signal[..., -shift:], (0, -shift), mode='reflect')

    def forward(self, x):
        radius = self.shift_factor

        # Phase shuffle disabled: pass the input through untouched.
        if radius == 0:
            return x

        if self.batch_shuffle:
            # Draw with torch so the shift follows torch's global RNG state.
            shift = int(torch.Tensor(1).random_(0, 2 * radius + 1)) - radius
            if shift == 0:
                return x
            shifted = self._shift_reflect(x, shift)
        else:
            # One shift per sample in the batch.
            draws = torch.Tensor(x.shape[0]).random_(0, 2 * radius + 1) - radius

            # Group the sample indices by shift so each distinct shift is
            # applied with a single pad call.
            rows_by_shift = {}
            for row, shift in enumerate(draws.numpy().astype(int)):
                rows_by_shift.setdefault(int(shift), []).append(row)

            shifted = x.clone()
            for shift, rows in rows_by_shift.items():
                shifted[rows] = self._shift_reflect(x[rows], shift)

        assert shifted.shape == x.shape, "{}, {}".format(shifted.shape,
                                                         x.shape)
        return shifted

## **discriminator**

In [0]:
class Discriminator(nn.Module):
    """1-D convolutional discriminator / critic for raw audio.

    Five stride-4 ``Conv1d`` layers (kernel 25, padding 11) each shrink the
    time axis by 4x, with a ``PhaseShuffle`` after each of the first four.
    The final linear layer expects ``256 * model_size`` features per sample,
    i.e. ``16 * model_size`` channels x 16 time steps, which corresponds to an
    input of shape ``(batch, num_channels, 16384)``.

    Args:
        model_size: Base channel multiplier ``d``.
        num_channels: Number of input audio channels ``c``.
        shift_factor: Maximum phase-shuffle shift ``n`` (0 disables shuffling).
        alpha: Negative slope of the leaky ReLU activations.
        batch_shuffle: Passed to every ``PhaseShuffle`` layer.
        use_sigmoid: If True (default, the original behaviour) the score is
            squashed into (0, 1). The training code in this file feeds the
            output into Wasserstein-style losses, for which an unbounded
            critic score is the usual choice; pass False to get raw scores.
    """

    def __init__(self, model_size=64, num_channels=1, shift_factor=2, alpha=0.2, batch_shuffle=False,
                 use_sigmoid=True):
        super(Discriminator, self).__init__()
        self.model_size = model_size # d
        self.num_channels = num_channels # c
        self.shift_factor = shift_factor # n
        self.alpha = alpha
        self.use_sigmoid = use_sigmoid

        # Conv1d(in_channels, out_channels, kernel_size, stride, padding).
        # Every layer stays wrapped in DataParallel so that state_dict keys
        # keep their "<name>.module.*" form.
        self.conv1 = nn.DataParallel(nn.Conv1d(num_channels, model_size, 25, stride=4, padding=11))
        self.conv2 = nn.DataParallel(
            nn.Conv1d(model_size, 2 * model_size, 25, stride=4, padding=11))
        self.conv3 = nn.DataParallel(
            nn.Conv1d(2 * model_size, 4 * model_size, 25, stride=4, padding=11))
        self.conv4 = nn.DataParallel(
            nn.Conv1d(4 * model_size, 8 * model_size, 25, stride=4, padding=11))
        self.conv5 = nn.DataParallel(
            nn.Conv1d(8 * model_size, 16 * model_size, 25, stride=4, padding=11))
        self.ps1 = PhaseShuffle(shift_factor, batch_shuffle=batch_shuffle)
        self.ps2 = PhaseShuffle(shift_factor, batch_shuffle=batch_shuffle)
        self.ps3 = PhaseShuffle(shift_factor, batch_shuffle=batch_shuffle)
        self.ps4 = PhaseShuffle(shift_factor, batch_shuffle=batch_shuffle)
        self.fc1 = nn.DataParallel(nn.Linear(256 * model_size, 1))

        # modules() recurses into the DataParallel wrappers, so the wrapped
        # Conv1d / Linear layers are the ones re-initialised here.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)

    def forward(self, x):
        """Score a batch of waveforms.

        Args:
            x: Tensor of shape ``(batch, num_channels, 16384)``.

        Returns:
            Tensor of shape ``(batch, 1)``; in (0, 1) when ``use_sigmoid`` is
            True, otherwise unbounded.

        Raises:
            RuntimeError: If the input length does not produce
                ``256 * model_size`` features per sample.
        """
        stages = (
            (self.conv1, self.ps1),
            (self.conv2, self.ps2),
            (self.conv3, self.ps3),
            (self.conv4, self.ps4),
        )
        for conv, shuffle in stages:
            x = F.leaky_relu(conv(x), negative_slope=self.alpha)
            x = shuffle(x)

        x = F.leaky_relu(self.conv5(x), negative_slope=self.alpha)

        # Flatten per sample. Keeping the batch dimension fixed makes a
        # wrong-length input fail in fc1 with a shape error instead of being
        # silently folded into extra batch rows.
        x = x.view(x.size(0), -1)

        score = self.fc1(x)
        return torch.sigmoid(score) if self.use_sigmoid else score

## **generator**

In [0]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping a latent vector to a waveform.

    A linear layer projects the latent vector to ``16 * model_size`` channels
    x 16 time steps; five stride-4 ``ConvTranspose1d`` layers (kernel 25) then
    upsample the time axis by 4x each, giving 16384 samples. When
    ``post_proc_filt_len`` is non-zero, a learned ``Conv1d`` of that length is
    applied to the output with "same" zero padding.

    Args:
        model_size: Base channel multiplier ``d``.
        num_channels: Number of output audio channels ``c``.
        latent_dim: Size of the latent input vector.
        post_proc_filt_len: Length of the post-processing filter (0 disables it).
    """

    def __init__(self, model_size=64, num_channels=1, latent_dim=100,
                 post_proc_filt_len=512):
        super().__init__()
        self.model_size = model_size  # d
        self.num_channels = num_channels  # c
        self.latent_dim = latent_dim
        self.post_proc_filt_len = post_proc_filt_len

        self.fc1 = nn.DataParallel(nn.Linear(latent_dim, 256 * model_size))

        # Channel widths from the projected input down to the audio output;
        # consecutive pairs define tconv1 .. tconv5.
        widths = (16 * model_size, 8 * model_size, 4 * model_size,
                  2 * model_size, model_size, num_channels)
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            upsample = nn.ConvTranspose1d(c_in, c_out, 25, stride=4, padding=11,
                                          output_padding=1)
            setattr(self, 'tconv{}'.format(stage), nn.DataParallel(upsample))

        if post_proc_filt_len:
            self.ppfilter1 = nn.DataParallel(
                nn.Conv1d(num_channels, num_channels, post_proc_filt_len))

        # Only the transposed convolutions and the linear layer are
        # re-initialised; the post-processing filter keeps its default init.
        for layer in self.modules():
            if isinstance(layer, (nn.ConvTranspose1d, nn.Linear)):
                nn.init.kaiming_normal_(layer.weight.data)

    def forward(self, x):
        """Generate waveforms of shape ``(batch, num_channels, 16384)`` from ``x``.

        ``x`` is a batch of latent vectors of shape ``(batch, latent_dim)``.
        """
        hidden = F.relu(self.fc1(x).view(-1, 16 * self.model_size, 16))

        for upsample in (self.tconv1, self.tconv2, self.tconv3, self.tconv4):
            hidden = F.relu(upsample(hidden))
        output = torch.tanh(self.tconv5(hidden))

        taps = self.post_proc_filt_len
        if taps:
            # "Same" filtering: taps - 1 zeros in total, with the extra one on
            # the left when the filter length is even.
            left = taps // 2
            right = taps - 1 - left
            output = self.ppfilter1(F.pad(output, (left, right)))

        return output

## **loss compute**

In [0]:
#compute D and G loss
def _module_device(module):
    """Return the device of ``module``'s first parameter (CPU if it has none)."""
    first_param = next(module.parameters(), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def cal_D_loss(D,G,real_wave,noisev):
    """Accumulate the critic gradients for one batch and report the losses.

    Zeroes ``D``'s gradients, then back-propagates the three terms of
    ``D(fake) - D(real) + gradient_penalty`` one after another so each graph
    can be freed as soon as it has been used. The caller is expected to call
    the optimizer's ``step()`` afterwards.

    Args:
        D: Critic network.
        G: Generator network; it is only run forward, without gradient tracking.
        real_wave: Batch of real waveforms, on the same device as ``D``.
        noisev: Batch of latent vectors with the same batch size as ``real_wave``.

    Returns:
        ``(D_lost, Wass_D)``: detached 0-dim tensors holding the full critic
        loss and the estimate ``D(real) - D(fake)``.
    """
    D.zero_grad()

    # Real batch: maximise D(real), i.e. minimise -D(real).
    real_out = D(real_wave).mean()
    (-real_out).backward()

    # Fake batch: the generator is not trained here, so skip building its graph.
    with torch.no_grad():
        fake_wave = G(noisev)
    fake_out = D(fake_wave).mean()
    fake_out.backward()

    # Gradient penalty. The batch size is taken from the data itself instead
    # of relying on a module-level ``batch_size`` variable.
    gradient_penalty = calc_gradient_penalty(D, real_wave.detach(), fake_wave,
                                             real_wave.size(0))
    gradient_penalty.backward()

    D_lost = (fake_out - real_out + gradient_penalty).detach()
    Wass_D = (real_out - fake_out).detach()

    return D_lost, Wass_D

def cal_G_loss(G,D,batch_size, laten_dim):
    """Accumulate the generator gradients for one batch and return the loss.

    Zeroes ``G``'s gradients, draws ``batch_size`` latent vectors uniformly
    from [-1, 1) on ``G``'s device and back-propagates ``-mean(D(G(z)))``.
    The caller is expected to call the optimizer's ``step()`` afterwards.

    Args:
        G: Generator network.
        D: Critic network.
        batch_size: Number of latent vectors to draw.
        laten_dim: Size of each latent vector.

    Returns:
        Detached 0-dim tensor holding the generator loss.
    """
    G.zero_grad()

    noise = torch.empty(batch_size, laten_dim, device=_module_device(G)).uniform_(-1, 1)

    G_lost = -D(G(noise)).mean()
    G_lost.backward()
    return G_lost.detach()


def calc_gradient_penalty(netD, real_data, fake_data, batch_size, penalty_weight=0.1):
    """Compute the gradient penalty on random real/fake interpolates.

    Each sample is interpolated between its real and fake counterpart with an
    independent uniform weight. The penalty is the batch mean of
    ``(||dD/dx||_2 - 1) ** 2`` multiplied by ``penalty_weight``.

    Args:
        netD: Critic network.
        real_data: Real batch of shape ``(batch_size, channels, length)``.
        fake_data: Generated batch with the same shape and device as ``real_data``.
        batch_size: Number of samples in the batch.
        penalty_weight: Multiplier applied to the penalty (0.1, as before,
            by default).

    Returns:
        0-dim tensor that is differentiable with respect to ``netD``'s parameters.
    """
    # One mixing weight per sample, broadcast across channels and time.
    alpha = torch.rand(batch_size, 1, 1, device=real_data.device).expand_as(real_data)

    interpolates = (alpha * real_data + (1 - alpha) * fake_data).detach()
    interpolates.requires_grad_(True)

    disc_interpolates = netD(interpolates)

    # create_graph=True keeps the penalty differentiable w.r.t. netD's weights.
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True)[0]

    gradients = gradients.reshape(gradients.size(0), -1)

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * penalty_weight
    return gradient_penalty

## **save model and save audio**

In [0]:

def save_sample(data,sample_size,epoch, out_dir='./drive/My Drive/drums/sample',
                sample_rate=16000):
    """Write the first ``sample_size`` clips of ``data`` as mono WAV files.

    Clip ``i`` is flattened, converted to float32 and written to
    ``<out_dir>/<epoch>_<i>.wav``.

    ``librosa.output.write_wav`` was removed in librosa 0.8, so the files are
    written with ``scipy.io.wavfile.write`` instead.

    Args:
        data: Array-like indexed by clip, e.g. shape ``(n, 1, length)``.
        sample_size: Number of clips to write, starting from index 0.
        epoch: Label used as the file-name prefix.
        out_dir: Existing directory to write into.
        sample_rate: Sample rate stored in the WAV header.
    """
    # Imported here so the function carries its own dependency.
    from scipy.io import wavfile

    for i in range(sample_size):
        sample = np.asarray(data[i], dtype=np.float32).reshape(-1)
        wavfile.write('{}/{}_{}.wav'.format(out_dir, epoch, i), sample_rate, sample)

def showing_wave(data):
  """Print one clip and plot its waveform at 16 kHz.

  Args:
      data: NumPy array holding a single clip; it is flattened before use,
          so any shape with one clip's worth of samples works.
  """
  sample = data.reshape(-1)
  print(sample)
  plt.figure(figsize=(25,8))
  # librosa 0.10 removed ``waveplot`` in favour of ``waveshow``; use whichever
  # the installed version provides.
  draw_wave = getattr(librosa.display, 'waveshow', None)
  if draw_wave is None:
    draw_wave = librosa.display.waveplot
  draw_wave(sample, sr=16000)
  plt.show()
  plt.close()

def plot_loss(G,D):
  """Plot the generator and discriminator loss histories on the current axes.

  Args:
      G: Sequence of generator loss values.
      D: Sequence of discriminator loss values.
  """
  axes = plt.gca()
  axes.cla()  # start from empty axes so earlier curves do not pile up
  axes.plot(G, c='#4AD631', label='g_loss')
  axes.plot(D, label='d_loss')
  axes.set_xlabel('epoch')
  axes.set_ylabel('Loss')
  axes.set_title('learning rate 1e-6 drum')
  axes.legend()
  plt.show()
    

## **Training**

In [0]:
import time
if __name__=="__main__":
    # Training script: alternates ``d_epoch`` critic updates with one generator
    # update, ``batch_per_epoch`` times per epoch, and periodically plots the
    # losses, shows a generated waveform and saves checkpoints / samples.
    #
    # The settings below stay at module scope (not inside a function) so that
    # code defined in other cells can still read them as globals.

    # ---------------- hyperparameters ----------------
    batch_size = 64
    batch_per_epoch = 1   # generator updates per epoch
    epoch = 500
    d_epoch = 5           # critic updates per generator update
    lr = 0.000001 #0.0001, 0.00005,  0.00001,  0.000005,  0.000001,  0.0000005
    beta_1 = 0.5
    beta_2 = 0.9
    laten_dim = 100
    sample_size = 2
    sample_per_epoch = 1
    model_per_epoch = 1
    # The flag used to be True while the check below was inverted
    # (``if not load_model``), so nothing was ever loaded. The check is now
    # the right way round and the flag is False, keeping that behaviour.
    # Set it to True to resume from the checkpoint.
    load_model = False

    # Put both networks on the GPU explicitly when one is available instead
    # of relying on the DataParallel wrappers to move them.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    D = Discriminator(batch_shuffle=True).to(device)
    G = Generator().to(device)

    #---------training times-----------
    TRAIN_TIME = '1_'
    main_path = './drive/My Drive/drums/'

    G_model_path = main_path + 'model/G/'+ TRAIN_TIME
    D_model_path = main_path + 'model/D/'+ TRAIN_TIME

    # Optionally restore the generator weights.
    if load_model:
        print('loading model...')
        G.load_state_dict(torch.load(main_path + 'model/G/drum_clap.pkl',
                                     map_location=device))

    # Per-epoch history dicts, plus flat loss lists used for plotting.
    history = []
    plot_history_D = []
    plot_history_G = []

    #optimizer
    d_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(beta_1,beta_2))
    g_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(beta_1,beta_2))

    # time.clock() was removed in Python 3.8; perf_counter() measures
    # elapsed wall-clock time.
    start_time = time.perf_counter()
    for epoch_iter in range(epoch):
        tolerrent = 0
        trainloader = loadData(batch_size)
        iter_train = iter(trainloader)
        history_epoch = {
            'D':{
                'loss':[],
                'wass_loss':[],
                'valid_loss':[],
                'valid_wass':[]
            },
            'G':{
                'loss':[]
            }
        }

        epoch_time = time.perf_counter()
        print("Epoch: {}/{}".format(epoch_iter + 1, epoch))
        for batch_iter in range(batch_per_epoch):
            print("Batch: {}/{} ".format(batch_iter + 1, batch_per_epoch),end='')

            # The critic is trained first, so its parameters need gradients.
            for p in D.parameters():
                p.requires_grad = True

            # =================================================================
            # (1) train the D model
            # =================================================================
            for d_iter in range(d_epoch):
                print('#',end='')
                # Next batch of real data; restart the loader once exhausted.
                try:
                    traindata = next(iter_train)
                except StopIteration:
                    print('\n ---------------- data loader exauted ---------------')

                    trainloader = loadData(batch_size)
                    iter_train = iter(trainloader)
                    traindata = next(iter_train)

                real_wave = traindata[0].to(device)
                # Latent noise, uniform in [-1, 1).
                noisev = torch.Tensor(batch_size, laten_dim).uniform_(-1, 1).to(device)

                # cal_D_loss accumulates the gradients; step() applies them.
                D_loss, Wass_D = cal_D_loss( D, G, real_wave, noisev)
                d_optimizer.step()

                history_epoch['D']['loss'].append(D_loss.detach().cpu().numpy())
                history_epoch['D']['wass_loss'].append(Wass_D.detach().cpu().numpy())

            # =================================================================
            # (2) update the G model
            # =================================================================
            print('#',end='')
            # Freeze the critic so only generator gradients are computed.
            for p in D.parameters():
                p.requires_grad = False

            G_lost = cal_G_loss(G,D,batch_size, laten_dim)
            g_optimizer.step()

            history_epoch['G']['loss'].append(G_lost.detach().cpu().numpy())

            # Report the most recent losses.
            print(' d_loss: {:.4f},'
                    'd_wass: {:.4f},'
                    'g_loss: {:.4f}'.format(
                    history_epoch['D']['loss'][-1],
                    history_epoch['D']['wass_loss'][-1],
                    history_epoch['G']['loss'][-1]))
            plot_history_D.append(history_epoch['D']['loss'][-1])
            plot_history_G.append(history_epoch['G']['loss'][-1])
            # Count the batches whose generator loss is above -0.0001.
            if history_epoch['G']['loss'][-1] > -0.0001:
                tolerrent += 1

        # Drop the learning rate once enough batches crossed the threshold.
        if tolerrent >= 7 and lr > 0.000001:
            lr = 0.000001
            d_optimizer.param_groups[0]['lr'] = lr
            g_optimizer.param_groups[0]['lr'] = lr

            print('------------------changing learning rate to {}!------------------'.format(lr))

        # Elapsed time, in minutes.
        time_a = int(time.perf_counter()-start_time)/60
        time_b = int(time.perf_counter()-epoch_time)/60
        print('epoch time: {:.2f} min || total time: {:.2f} min'.format(time_b, time_a))

        history.append(history_epoch)

        # Checkpoint both networks every 200 epochs. Saving is best-effort:
        # a failure is reported, including its cause, but does not stop training.
        if (epoch_iter+1)%200 == 0:
            print('saving the model...')
            try:
                torch.save(G.state_dict(), G_model_path + str(epoch_iter+1) + '.pkl', pickle_protocol=pk.HIGHEST_PROTOCOL)
                torch.save(D.state_dict(), D_model_path + str(epoch_iter+1) + '.pkl', pickle_protocol=pk.HIGHEST_PROTOCOL)
            except Exception as err:
                print('failed to save model')
                print(err)

        # Plot the loss curves every 10 epochs.
        if (epoch_iter+1)%10 == 0:
            plot_loss(plot_history_G,plot_history_D)

        # Every 50 epochs generate a few samples and show the first one.
        if (epoch_iter+1)%50 == 0:
            sample_noisev = torch.Tensor(sample_size, laten_dim).uniform_(-1, 1).to(device)
            with torch.no_grad():  # inference only, no autograd graph needed
                sample_output = G(sample_noisev).cpu().numpy()
            showing_wave(sample_output[0])

            # Every 200 epochs (always a multiple of 50) also write them to disk.
            if (epoch_iter+1)%200 == 0:
                print('saving the sample...')
                try:
                    save_sample(sample_output,sample_size,epoch_iter+1)
                except Exception as err:
                    print('failed to save sample')
                    print(err)
        
        
          

Epoch: 1/500
Batch: 1/1 ##

KeyboardInterrupt: ignored