In [2]:
import torch
from collections import OrderedDict
from torch.autograd import Variable
from torch.optim import lr_scheduler
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
from types import SimpleNamespace
import random
import librosa
import pescador
import os
import numpy as np
import itertools
import time

In [3]:
def tensor2audio(audio_tensor):
    audio_numpy = audio_tensor[0].cpu().float().numpy()
    return audio_numpy

In [4]:
class BaseModel():
    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_audio_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_audibles(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda(gpu_ids[0])

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

In [5]:
class AudioPool():
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.audio_examples = []

    def query(self, audio_examples):
        if self.pool_size == 0:
            return Variable(audio_examples)
        return_audio_examples = []
        for audio_data in audio_examples:
            audio_data = torch.unsqueeze(audio_data, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.audio_examples.append(audio_data)
                return_audio_examples.append(audio_data)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size - 1)
                    tmp = self.audio_examples[random_id].clone()
                    self.audio_examples[random_id] = audio_data
                    return_audio_examples.append(tmp)
                else:
                    return_audio_examples.append(audio_data)
        return_audio_examples = Variable(torch.cat(return_audio_examples, 0))
        return return_audio_examples

In [6]:

def get_scheduler(optimizer, opt):
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    else:
        return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
    return scheduler

In [7]:
def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)

In [8]:
class CycleGANModel(BaseModel):
    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = WaveGANGenerator(model_size=opt.model_size, ngpus=opt.ngpus,
                                       num_channels=opt.num_channels,
                                       latent_dim=opt.latent_dim, alpha=opt.alpha,
                                       post_proc_filt_len=opt.post_proc_filt_len, verbose=opt.verbose)
        self.netG_B = WaveGANGenerator(model_size=opt.model_size, ngpus=opt.ngpus,
                                       num_channels=opt.num_channels,
                                       latent_dim=opt.latent_dim, alpha=opt.alpha,
                                       post_proc_filt_len=opt.post_proc_filt_len, verbose=opt.verbose)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = WaveGANDiscriminator(model_size=opt.model_size, ngpus=opt.ngpus,
                                               num_channels=opt.num_channels, shift_factor=opt.shift_factor,
                                               alpha=opt.alpha, batch_shuffle=opt.batch_shuffle,
                                               verbose=opt.verbose)
            self.netD_B = WaveGANDiscriminator(model_size=opt.model_size, ngpus=opt.ngpus,
                                               num_channels=opt.num_channels, shift_factor=opt.shift_factor,
                                               alpha=opt.alpha, batch_shuffle=opt.batch_shuffle,
                                               verbose=opt.verbose)
            

        if self.isTrain:
            self.fake_A_pool = AudioPool(opt.pool_size)
            self.fake_B_pool = AudioPool(opt.pool_size)
            # define loss functions
            self.criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            for optimizer in self.optimizers:
                self.schedulers.append(get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        print_network(self.netG_A)
        print_network(self.netG_B)
        if self.isTrain:
            print_network(self.netD_A)
            print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            input_A = input_A.cuda(self.gpu_ids[0], async=True)
            input_B = input_B.cuda(self.gpu_ids[0], async=True)
        self.input_A = input_A
        self.input_B = input_B
        self.audio_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        real_A = Variable(self.input_A, volatile=True)
        fake_B = self.netG_A(real_A)
        self.rec_A = self.netG_B(fake_B).data
        self.fake_B = fake_B.data

        real_B = Variable(self.input_B, volatile=True)
        fake_A = self.netG_B(real_B)
        self.rec_B = self.netG_A(fake_A).data
        self.fake_A = fake_A.data

    # get audio paths
    def get_audio_paths(self):
        return self.audio_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        self.loss_D_A = loss_D_A.data[0]

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.data[0]

    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.data[0]
            self.loss_idt_B = loss_idt_B.data[0]
        else:
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B
        # combined loss
        loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.data[0]
        self.loss_G_B = loss_G_B.data[0]
        self.loss_cycle_A = loss_cycle_A.data[0]
        self.loss_cycle_B = loss_cycle_B.data[0]

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
                                  ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
        if self.opt.lambda_identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_audibles(self):
        real_A = tensor2audio(self.input_A)
        fake_B = tensor2audio(self.fake_B)
        rec_A = tensor2audio(self.rec_A)
        real_B = tensor2audio(self.input_B)
        fake_A = tensor2audio(self.fake_A)
        rec_B = tensor2audio(self.rec_B)
        ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                                   ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
        if self.opt.isTrain and self.opt.lambda_identity > 0.0:
            ret_visuals['idt_A'] = tensor2audio(self.idt_A)
            ret_visuals['idt_B'] = tensor2audio(self.idt_B)
        return ret_visuals

    def save(self, label):
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)

In [9]:
# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input
class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)

In [10]:
class WaveGANGenerator(nn.Module):
    def __init__(self, model_size=64, ngpus=1, num_channels=1, latent_dim=100, alpha=0.2,
                 post_proc_filt_len=512, verbose=False):
        super(WaveGANGenerator, self).__init__()
        self.ngpus = ngpus
        self.model_size = model_size # d
        self.num_channels = num_channels # c
        self.latent_dim = latent_dim
        self.post_proc_filt_len = post_proc_filt_len
        self.alpha = alpha
        self.verbose = verbose
        
        self.conv1 = nn.DataParallel(nn.Conv1d(num_channels, model_size, 25, stride=4, padding=11))
        self.conv2 = nn.DataParallel(
            nn.Conv1d(model_size, 2 * model_size, 25, stride=4, padding=11))
        self.conv3 = nn.DataParallel(
            nn.Conv1d(2 * model_size, 4 * model_size, 25, stride=4, padding=11))
        self.conv4 = nn.DataParallel(
            nn.Conv1d(4 * model_size, 8 * model_size, 25, stride=4, padding=11))
        self.conv5 = nn.DataParallel(
            nn.Conv1d(8 * model_size, 16 * model_size, 25, stride=4, padding=11))

        self.fc1 = nn.DataParallel(nn.Linear(256 * model_size, latent_dim))
        self.fc2 = nn.DataParallel(nn.Linear(latent_dim, 256 * model_size))

        self.tconv1 = nn.DataParallel(
            nn.ConvTranspose1d(16 * model_size, 8 * model_size, 25, stride=4, padding=11,
                               output_padding=1))
        self.tconv2 = nn.DataParallel(
            nn.ConvTranspose1d(8 * model_size, 4 * model_size, 25, stride=4, padding=11,
                               output_padding=1))
        self.tconv3 = nn.DataParallel(
            nn.ConvTranspose1d(4 * model_size, 2 * model_size, 25, stride=4, padding=11,
                               output_padding=1))
        self.tconv4 = nn.DataParallel(
            nn.ConvTranspose1d(2 * model_size, model_size, 25, stride=4, padding=11,
                               output_padding=1))
        self.tconv5 = nn.DataParallel(
            nn.ConvTranspose1d(model_size, num_channels, 25, stride=4, padding=11,
                               output_padding=1))

        if post_proc_filt_len:
            self.ppfilter1 = nn.DataParallel(nn.Conv1d(num_channels, num_channels, post_proc_filt_len))

        for m in self.modules():
            if isinstance(m, nn.ConvTranspose1d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal(m.weight.data)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)

        x = F.leaky_relu(self.conv2(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)

        x = F.leaky_relu(self.conv3(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)

        x = F.leaky_relu(self.conv4(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)

        x = F.leaky_relu(self.conv5(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)

        x = x.view(-1, 256 * self.model_size)
        if self.verbose:
            print(x.shape)

        x = F.relu(self.fc1(x))
        
        x = F.relu(self.fc2(x)).view(-1, 16 * self.model_size, 16)
        if self.verbose:
            print(x.shape)

        x = F.relu(self.tconv1(x))
        if self.verbose:
            print(x.shape)

        x = F.relu(self.tconv2(x))
        if self.verbose:
            print(x.shape)

        x = F.relu(self.tconv3(x))
        if self.verbose:
            print(x.shape)

        x = F.relu(self.tconv4(x))
        if self.verbose:
            print(x.shape)

        output = F.tanh(self.tconv5(x))
        if self.verbose:
            print(output.shape)

        if self.post_proc_filt_len:
            # Pad for "same" filtering
            if (self.post_proc_filt_len % 2) == 0:
                pad_left = self.post_proc_filt_len // 2
                pad_right = pad_left - 1
            else:
                pad_left = (self.post_proc_filt_len - 1) // 2
                pad_right = pad_left
            output = self.ppfilter1(F.pad(output, (pad_left, pad_right)))
            if self.verbose:
                print(output.shape)

        return output

In [11]:
class PhaseShuffle(nn.Module):
    """
    Performs phase shuffling, i.e. shifting feature axis of a 3D tensor
    by a random integer in {-n, n} and performing reflection padding where
    necessary
    """
    def __init__(self, n, batch_shuffle=True):
        super(PhaseShuffle, self).__init__()
        self.n = n
        self.batch_shuffle = batch_shuffle
        
    def forward(self, x):
        # Make sure to use PyTorch to generate number RNG state is all shared
        if self.batch_shuffle:                                                              
            # Make sure to use PyTorcTrueh to generate number RNG state is all shared       
            k = int(torch.Tensor(1).random_(0, 2*self.n + 1)) - self.n

            # Return if no phase shift                                                      
            if k == 0:                                                                      
                return x                                                                    

            # Slice feature dimension                                                       
            if k > 0:                                                                       
                x_trunc = x[:, :, :-k]                                                      
                pad = (k, 0)                                                                
            else:                                                                           
                x_trunc = x[:, :, -k:]                                                      
                pad = (0, -k)                                                               

            # Reflection padding                                                            
            x_shuffle = F.pad(x_trunc, pad, mode='reflect')                                 

        else:                                                                               
            k_list = torch.Tensor(x.shape[0]).random_(0, 2*self.n+1) - self.n                                                         
            k_list = k_list.numpy().astype(int)                                             
            x_shuffle = x.clone()                                                           

            for idx, k in enumerate(k_list):                                                
                k = int(k)                                                                  
                if k > 0:                                                                   
                    xi_trunc = x[idx:idx+1, :, :-k]                                         
                    pad = (k, 0)                                                            
                else:                                                                       
                    xi_trunc = x[idx:idx+1, :, -k:]                                         
                    pad = (0, -k)                                                           

                x_shuffle[idx:idx+1] = F.pad(xi_trunc, pad, mode='reflect')                 


        assert x_shuffle.shape == x.shape, "{}, {}".format(x_shuffle.shape, x.shape)
        return x_shuffle
        

In [12]:
class WaveGANDiscriminator(nn.Module):
    def __init__(self, model_size=64, ngpus=1, num_channels=1, shift_factor=2, alpha=0.2, batch_shuffle=False, verbose=False):
        super(WaveGANDiscriminator, self).__init__()
        self.model_size = model_size # d
        self.ngpus = ngpus
        self.num_channels = num_channels # c
        self.shift_factor = shift_factor # n
        self.alpha = alpha
        self.verbose = verbose
        # Conv2d(in_channels, out_channels, kernel_size, stride=1, etc.)
        self.conv1 = nn.DataParallel(nn.Conv1d(num_channels, model_size, 25, stride=4, padding=11))
        self.conv2 = nn.DataParallel(
            nn.Conv1d(model_size, 2 * model_size, 25, stride=4, padding=11))
        self.conv3 = nn.DataParallel(
            nn.Conv1d(2 * model_size, 4 * model_size, 25, stride=4, padding=11))
        self.conv4 = nn.DataParallel(
            nn.Conv1d(4 * model_size, 8 * model_size, 25, stride=4, padding=11))
        self.conv5 = nn.DataParallel(
            nn.Conv1d(8 * model_size, 16 * model_size, 25, stride=4, padding=11))
        self.ps1 = PhaseShuffle(shift_factor, batch_shuffle=batch_shuffle)
        self.ps2 = PhaseShuffle(shift_factor, batch_shuffle=batch_shuffle)
        self.ps3 = PhaseShuffle(shift_factor, batch_shuffle=batch_shuffle)
        self.ps4 = PhaseShuffle(shift_factor, batch_shuffle=batch_shuffle)
        self.fc1 = nn.DataParallel(nn.Linear(256 * model_size, 1))

        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal(m.weight.data)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)
        x = self.ps1(x)

        x = F.leaky_relu(self.conv2(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)
        x = self.ps2(x)

        x = F.leaky_relu(self.conv3(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)
        x = self.ps3(x)

        x = F.leaky_relu(self.conv4(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)
        x = self.ps4(x)

        x = F.leaky_relu(self.conv5(x), negative_slope=self.alpha)
        if self.verbose:
            print(x.shape)

        x = x.view(-1, 256 * self.model_size)
        if self.verbose:
            print(x.shape)

        return F.sigmoid(self.fc1(x))

In [13]:
def file_sample_generator(filepath, window_length=16384, fs=16000):
    """
    Audio sample generator
    """
    try:
        audio_data, _ = librosa.load(filepath, sr=fs)
    except Exception as e:
        raise StopIteration()
        
    audio_len = len(audio_data)
    
    # Pad audio to at least a single frame
    if audio_len < window_length:
        pad_length = window_length - audio_len
        left_pad = pad_length // 2
        right_pad = pad_length - left_pad
        
        audio_data = np.pad(audio_data, (left_pad, right_pad), mode='constant')
        audio_len = len(audio_data)
        
    while True:
        if audio_len == window_length:
            # If we only have a single frame's worth of audio, just yield the whole audio
            sample = audio_data
        else:
            # Sample a random window from the audio file
            start_idx = np.random.randint(0,audio_len - window_length)
            end_idx = start_idx + window_length
            sample = audio_data[start_idx:end_idx]
            
        sample = sample.astype('float32')
        assert not np.any(np.isnan(sample))
            
        yield {'X': sample}
    
def create_batch_generator(audio_filepath_list, batch_size):
    streamers = []
    for audio_filepath in audio_filepath_list:
        s = pescador.Streamer(file_sample_generator, audio_filepath)
        streamers.append(s)
        
    mux = pescador.ShuffledMux(streamers)
    batch_gen = pescador.buffer_stream(mux, batch_size)
    
    return batch_gen

def get_all_audio_filepaths(audio_dir):
    return [os.path.join(root, fname)
            for (root, dir_names, file_names) in os.walk(audio_dir)
            for fname in file_names
            if fname.lower().endswith('.wav')]

def create_data_split(audio_filepath_list, valid_ratio, test_ratio, train_batch_size, valid_size, test_size):
    num_files = len(audio_filepath_list)
    num_valid = int(np.ceil(num_files * valid_ratio))
    num_test = int(np.ceil(num_files * test_ratio))
    num_train = num_files - num_valid - num_test
    
    assert num_valid > 0
    assert num_test > 0
    assert num_train > 0
    
    valid_files = audio_filepath_list[:num_valid]
    test_files = audio_filepath_list[num_valid:num_valid+num_test]
    train_files = audio_filepath_list[num_valid+num_test:]
    
    train_gen = create_batch_generator(train_files, train_batch_size)
    valid_data = next(iter(create_batch_generator(valid_files, valid_size)))
    test_data = next(iter(create_batch_generator(train_files, test_size)))
    
    return train_gen, valid_data, test_data

In [14]:
def np_to_input_tensor(data, use_cuda):
    data = data[:,np.newaxis,:]
    data = torch.Tensor(data)
    if use_cuda:
        data = data.cuda()
    return data

In [15]:
batch_size = 64
audio_dir_A = "/beegfs/jtc440/aml/TheDrumClub-Kit004-THEMEGABUNDLE/Rhodes Polaris"
audio_filepaths_A = get_all_audio_filepaths(audio_dir_A)
genA, valid_data_A, test_data_A = create_data_split(audio_filepaths_A, 0.1, 0.1, batch_size, 64, 64)

audio_dir_B = "/beegfs/jtc440/aml/TheDrumClub-Kit004-THEMEGABUNDLE/Korg M1"
audio_filepaths_B = get_all_audio_filepaths(audio_dir_B)
genB, valid_data_B, test_data_B = create_data_split(audio_filepaths_B, 0.1, 0.1, batch_size, 64, 64)


In [24]:
def save_tidegan_samples(output_dir, current_audibles, step, fs=16000):
    samples_dir = os.path.join(output_dir, 'samples', str(step))
    if not os.path.exists(samples_dir):
        os.makedirs(samples_dir)
        
    for out_type, data in current_audibles.items():
        for idx, sample in enumerate(data):
            output_path = os.path.join(samples_dir, "{}_{}.wav".format(out_type, idx+1))
            librosa.output.write_wav(output_path, sample, sr=fs)

In [25]:
opt = SimpleNamespace(
    epoch_count=0,
    niter=50,
    niter_decay = 0,
    print_freq = 1,
    batchSize = batch_size,
    display_freq = 1,
    update_html_freq = 1,
    display_id = 4,
    save_latest_freq = 1,
    isTrain = True,
    gpu_ids = list(range(torch.cuda.device_count())),
    checkpoints_dir='/scratch/jtc440/cyclegan',
    name='tidegan',
    model_size=64,
    ngpus=1,
    num_channels=1,
    latent_dim=100,
    alpha=0.2,
    post_proc_filt_len=512,
    batch_shuffle=False,
    verbose=False,
    shift_factor=2,
    pool_size=50,
    lr=0.0002,
    beta1=0.5,
    beta2=0.999,
    lambda_identity=0.5,
    batches_per_epoch=10,
    lambda_A=10,
    lambda_B=10,
    no_lsgan=False,
    which_direction='AtoB',
    lr_policy='lambda',
    lr_decay_iters=50,
    use_cuda=True,
    save_epoch_freq=1,
)

model_dir = os.path.join(opt.checkpoints_dir, opt.name)
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

model = CycleGANModel()
model.initialize(opt)
#visualizer = Visualizer(opt)
total_steps = 0

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    iter_data_time = time.time()
    epoch_iter = 0

    for batch_idx in range(opt.batches_per_epoch):
        data_A = np_to_input_tensor(next(genA)['X'], use_cuda=opt.use_cuda)
        data_B = np_to_input_tensor(next(genB)['X'], use_cuda=opt.use_cuda)  
        data = {'A': data_A, 'B': data_B, 'A_paths': [], 'B_paths': []}
        
        iter_start_time = time.time()
        if total_steps % opt.print_freq == 0:
            t_data = iter_start_time - iter_data_time
        #visualizer.reset()
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize
        model.set_input(data)
        model.optimize_parameters()

        if total_steps % opt.display_freq == 0:
            save_result = total_steps % opt.update_html_freq == 0
            save_tidegan_samples(opt.checkpoints_dir, model.get_current_audibles(), total_steps)
            #visualizer.display_current_results(, epoch, save_result)

        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            t = (time.time() - iter_start_time) / opt.batchSize
            #visualizer.print_current_errors(epoch, epoch_iter, errors, t, t_data)
            #if opt.display_id > 0:
            #    visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)

        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')

        iter_data_time = time.time()
    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, total_steps))
        model.save('latest')
        model.save(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
    model.update_learning_rate()

---------- Networks initialized -------------
WaveGANGenerator(
  (conv1): DataParallel(
    (module): Conv1d(1, 64, kernel_size=(25,), stride=(4,), padding=(11,))
  )
  (conv2): DataParallel(
    (module): Conv1d(64, 128, kernel_size=(25,), stride=(4,), padding=(11,))
  )
  (conv3): DataParallel(
    (module): Conv1d(128, 256, kernel_size=(25,), stride=(4,), padding=(11,))
  )
  (conv4): DataParallel(
    (module): Conv1d(256, 512, kernel_size=(25,), stride=(4,), padding=(11,))
  )
  (conv5): DataParallel(
    (module): Conv1d(512, 1024, kernel_size=(25,), stride=(4,), padding=(11,))
  )
  (fc1): DataParallel(
    (module): Linear(in_features=16384, out_features=100, bias=True)
  )
  (fc2): DataParallel(
    (module): Linear(in_features=100, out_features=16384, bias=True)
  )
  (tconv1): DataParallel(
    (module): ConvTranspose1d(1024, 512, kernel_size=(25,), stride=(4,), padding=(11,), output_padding=(1,))
  )
  (tconv2): DataParallel(
    (module): ConvTranspose1d(512, 256, kernel_

saving the latest model (epoch 5, total_steps 3584)
saving the latest model (epoch 5, total_steps 3648)
saving the latest model (epoch 5, total_steps 3712)
saving the latest model (epoch 5, total_steps 3776)
saving the latest model (epoch 5, total_steps 3840)
saving the model at the end of epoch 5, iters 3840
End of epoch 5 / 50 	 Time Taken: 57 sec
learning rate = 0.0002000
saving the latest model (epoch 6, total_steps 3904)
saving the latest model (epoch 6, total_steps 3968)
saving the latest model (epoch 6, total_steps 4032)
saving the latest model (epoch 6, total_steps 4096)
saving the latest model (epoch 6, total_steps 4160)
saving the latest model (epoch 6, total_steps 4224)
saving the latest model (epoch 6, total_steps 4288)
saving the latest model (epoch 6, total_steps 4352)
saving the latest model (epoch 6, total_steps 4416)
saving the latest model (epoch 6, total_steps 4480)
saving the model at the end of epoch 6, iters 4480
End of epoch 6 / 50 	 Time Taken: 55 sec
learning r

saving the latest model (epoch 18, total_steps 11648)
saving the latest model (epoch 18, total_steps 11712)
saving the latest model (epoch 18, total_steps 11776)
saving the latest model (epoch 18, total_steps 11840)
saving the latest model (epoch 18, total_steps 11904)
saving the latest model (epoch 18, total_steps 11968)
saving the latest model (epoch 18, total_steps 12032)
saving the latest model (epoch 18, total_steps 12096)
saving the latest model (epoch 18, total_steps 12160)
saving the model at the end of epoch 18, iters 12160
End of epoch 18 / 50 	 Time Taken: 55 sec
learning rate = 0.0002000
saving the latest model (epoch 19, total_steps 12224)
saving the latest model (epoch 19, total_steps 12288)
saving the latest model (epoch 19, total_steps 12352)
saving the latest model (epoch 19, total_steps 12416)
saving the latest model (epoch 19, total_steps 12480)
saving the latest model (epoch 19, total_steps 12544)
saving the latest model (epoch 19, total_steps 12608)
saving the late

saving the latest model (epoch 30, total_steps 19648)
saving the latest model (epoch 30, total_steps 19712)
saving the latest model (epoch 30, total_steps 19776)
saving the latest model (epoch 30, total_steps 19840)
saving the model at the end of epoch 30, iters 19840
End of epoch 30 / 50 	 Time Taken: 55 sec
learning rate = 0.0002000
saving the latest model (epoch 31, total_steps 19904)
saving the latest model (epoch 31, total_steps 19968)
saving the latest model (epoch 31, total_steps 20032)
saving the latest model (epoch 31, total_steps 20096)
saving the latest model (epoch 31, total_steps 20160)
saving the latest model (epoch 31, total_steps 20224)
saving the latest model (epoch 31, total_steps 20288)
saving the latest model (epoch 31, total_steps 20352)
saving the latest model (epoch 31, total_steps 20416)
saving the latest model (epoch 31, total_steps 20480)
saving the model at the end of epoch 31, iters 20480
End of epoch 31 / 50 	 Time Taken: 55 sec
learning rate = 0.0002000
sa

End of epoch 42 / 50 	 Time Taken: 55 sec
learning rate = 0.0002000
saving the latest model (epoch 43, total_steps 27584)
saving the latest model (epoch 43, total_steps 27648)
saving the latest model (epoch 43, total_steps 27712)
saving the latest model (epoch 43, total_steps 27776)
saving the latest model (epoch 43, total_steps 27840)
saving the latest model (epoch 43, total_steps 27904)
saving the latest model (epoch 43, total_steps 27968)
saving the latest model (epoch 43, total_steps 28032)
saving the latest model (epoch 43, total_steps 28096)
saving the latest model (epoch 43, total_steps 28160)
saving the model at the end of epoch 43, iters 28160
End of epoch 43 / 50 	 Time Taken: 55 sec
learning rate = 0.0002000
saving the latest model (epoch 44, total_steps 28224)
saving the latest model (epoch 44, total_steps 28288)
saving the latest model (epoch 44, total_steps 28352)
saving the latest model (epoch 44, total_steps 28416)
saving the latest model (epoch 44, total_steps 28480)
s