In [1]:
from torch.utils import data
import torch
import numpy as np
import os, pdb, pickle, random
       
from multiprocessing import Process, Manager   


class Utterances(data.Dataset):
    """Dataset of per-speaker mel-spectrogram utterances.

    Every entry of ``self.train_dataset`` is a list laid out as
    ``[speaker_id, speaker_embedding, mel_0, mel_1, ...]`` (one list per
    training speaker).  Indexing the dataset picks one speaker and returns a
    randomly chosen utterance, randomly cropped or zero-padded to
    ``len_crop`` frames, together with either the stored speaker embedding
    or a one-hot speaker label.
    """

    def __init__(self, config):
        """Select the training split, persist it and load every spectrogram.

        Args:
            config: object exposing ``data_dir``, ``len_crop``, ``file_name``,
                ``one_hot`` and ``train_size``.

        Side effects:
            Reads ``./all_meta_data.pkl``, seeds the global ``random`` module
            with 1 (so the split is reproducible) and writes
            ``training_meta_data.pkl`` into the run directory
            ``././my_data/my_autovc/model_data/<file_name>/``.
        """
        self.root_dir = config.data_dir
        self.len_crop = config.len_crop
        self.step = 10  # number of speakers handled by each loader process
        self.file_name = config.file_name
        self.one_hot = config.one_hot

        # NOTE: pickle can execute arbitrary code; only load trusted metadata.
        with open('./all_meta_data.pkl', "rb") as meta_file:
            meta_all_data = pickle.load(meta_file)

        # Reproducible choice of training speakers.
        num_training_speakers = config.train_size
        random.seed(1)
        training_indices = random.sample(range(0, len(meta_all_data)), num_training_speakers)

        # For every chosen speaker keep ~90% of its utterances for training.
        training_set = []
        for idx in training_indices:
            speaker_info = meta_all_data[idx]
            speaker_id_emb = speaker_info[:2]  # [speaker ID, speaker embedding]
            speaker_uttrs = speaker_info[2:]
            num_files = len(speaker_uttrs)
            training_file_num = round(num_files * 0.9)
            training_file_indices = random.sample(range(0, num_files), training_file_num)
            training_file_names = [speaker_uttrs[i] for i in training_file_indices]
            training_set.append(speaker_id_emb + training_file_names)

        with open('././my_data/my_autovc/model_data/' +self.file_name +'/training_meta_data.pkl', 'wb') as train_pack:
            pickle.dump(training_set, train_pack)

        # The one-hot table and speaker order follow the training split that
        # was just written; re-reading the pickle would yield the same data.
        self.one_hot_array = np.eye(len(training_set))
        self.spkr_id_list = [spkr[0] for spkr in training_set]

        # Load the spectrograms in parallel, one process per `self.step`
        # speakers.  The manager is shut down as soon as results are copied.
        with Manager() as manager:
            dataset = manager.list(len(training_set) * [None])
            processes = []
            for i in range(0, len(training_set), self.step):
                p = Process(target=self.load_data,
                            args=(training_set[i:i + self.step], dataset, i))
                p.start()
                processes.append(p)
            for p in processes:
                p.join()
            self.train_dataset = list(dataset)

        self.num_tokens = len(self.train_dataset)

        print('Finished loading the dataset...')

    def load_data(self, submeta, dataset, idx_offset):
        """Load the spectrograms of a slice of speakers into ``dataset``.

        Args:
            submeta: list of speaker entries ``[id, embedding, file, ...]``.
            dataset: shared list receiving the loaded entries.
            idx_offset: index in ``dataset`` of the first entry of ``submeta``.
        """
        for k, sbmt in enumerate(submeta):
            uttrs = len(sbmt) * [None]
            for j, tmp in enumerate(sbmt):
                if j < 2:  # speaker id and embedding are copied as-is
                    uttrs[j] = tmp
                else:  # remaining items are paths relative to root_dir
                    uttrs[j] = np.load(os.path.join(self.root_dir, tmp))
            dataset[idx_offset + k] = uttrs

    def __getitem__(self, index):
        """Return a random fixed-length utterance of speaker ``index``.

        Returns:
            ``(uttr, speaker_repr, speaker_name)`` where ``uttr`` has
            ``len_crop`` frames and ``speaker_repr`` is the stored embedding
            when ``self.one_hot`` is false, otherwise the speaker's row of
            ``self.one_hot_array``.
        """
        list_uttrs = self.train_dataset[index]
        speaker_name = list_uttrs[0]
        emb_org = list_uttrs[1]

        # Slots 0 and 1 hold id/embedding, so utterances start at index 2.
        a = np.random.randint(2, len(list_uttrs))
        spmel_tmp = list_uttrs[a]

        n_frames = spmel_tmp.shape[0]
        if n_frames < self.len_crop:
            len_pad = self.len_crop - n_frames
            uttr = np.pad(spmel_tmp, ((0, len_pad), (0, 0)), 'constant')
        elif n_frames > self.len_crop:
            # randint's upper bound is exclusive; +1 makes the final window
            # (ending on the last frame) selectable as well.
            left = np.random.randint(n_frames - self.len_crop + 1)
            uttr = spmel_tmp[left:left + self.len_crop, :]
        else:
            uttr = spmel_tmp

        if not self.one_hot:
            return uttr, emb_org, speaker_name

        # Position of the speaker in the training order selects the one-hot row.
        spkr_label = self.spkr_id_list.index(speaker_name)
        return uttr, self.one_hot_array[spkr_label], speaker_name

    def __len__(self):
        """Return the number of spkrs."""
        return self.num_tokens
    
    
    

def _seed_numpy_worker(worker_id):
    """Seed NumPy's global RNG in a DataLoader worker from its torch seed.

    Defined at module level (rather than as a lambda) so it can be pickled
    when DataLoader workers are started with the ``spawn`` method.
    """
    np.random.seed(torch.initial_seed() % (2**32))


def get_loader(config, num_workers=0):
    """Build and return a shuffling data loader over ``Utterances``.

    Args:
        config: configuration forwarded to ``Utterances``; ``batch_size`` is
            also read from it.
        num_workers: number of DataLoader worker processes.

    Returns:
        A ``DataLoader`` that shuffles, drops the last incomplete batch and
        re-seeds NumPy in every worker.
    """
    dataset = Utterances(config)

    return data.DataLoader(dataset=dataset,
                           batch_size=config.batch_size,
                           shuffle=True,
                           num_workers=num_workers,
                           drop_last=True,
                           worker_init_fn=_seed_numpy_worker)

In [2]:
import torch, pdb
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class LinearNorm(torch.nn.Module):
    """Linear layer whose weight matrix is Xavier-uniform initialised.

    Args:
        in_dim: size of each input feature vector.
        out_dim: size of each output feature vector.
        bias: whether the layer learns an additive bias.
        w_init_gain: nonlinearity name passed to ``calculate_gain`` to scale
            the Xavier initialisation.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super().__init__()
        layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(layer.weight, gain=gain)
        self.linear_layer = layer

    def forward(self, x):
        """Project ``x`` through the wrapped linear layer."""
        projected = self.linear_layer(x)
        return projected


class ConvNorm(torch.nn.Module):
    """1-D convolution with Xavier-uniform initialised weights.

    When ``padding`` is omitted it is derived from ``kernel_size`` and
    ``dilation`` so that, at stride 1, the output keeps the input length;
    this requires an odd ``kernel_size``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super().__init__()

        if padding is None:
            # Symmetric padding only exists for odd kernel widths.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        conv_options = dict(kernel_size=kernel_size, stride=stride,
                            padding=padding, dilation=dilation, bias=bias)
        self.conv = torch.nn.Conv1d(in_channels, out_channels, **conv_options)

        gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(self.conv.weight, gain=gain)

    def forward(self, signal):
        """Convolve ``signal`` with the wrapped ``Conv1d``."""
        return self.conv(signal)

# "4.2. The Content Encoder"
class Encoder(nn.Module):
    """Content encoder: three conv blocks followed by a 2-layer BiLSTM.

    The mel-spectrogram (80 bins) is stacked channel-wise with the speaker
    embedding, passed through three 5-wide convolutions with 512 channels
    (each with batch norm and ReLU) and then through a bidirectional LSTM
    with ``dim_neck`` units per direction.  The LSTM output is subsampled
    every ``freq`` frames to form the bottleneck codes.
    """

    def __init__(self, dim_neck, dim_emb, freq):
        super().__init__()
        self.dim_neck = dim_neck
        self.freq = freq

        # Only the first block sees the mel bins plus the speaker embedding;
        # the following two operate on 512 channels.
        input_widths = (80 + dim_emb, 512, 512)
        self.convolutions = nn.ModuleList([
            nn.Sequential(
                ConvNorm(width, 512,
                         kernel_size=5, stride=1,
                         padding=2,
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(512))
            for width in input_widths
        ])

        # Bidirectional, so the output feature size is 2 * dim_neck.
        self.lstm = nn.LSTM(512, dim_neck, 2, batch_first=True, bidirectional=True)

    def forward(self, x, c_org):
        """Encode ``x`` conditioned on the source speaker embedding ``c_org``.

        Returns:
            A list with one tensor per ``freq``-frame window; each tensor
            concatenates the first ``dim_neck`` LSTM features taken at the
            window's last frame with the remaining features taken at the
            window's first frame.
        """
        # assumes x is (batch, frames, 80), optionally with a singleton dim 1
        mel = x.squeeze(1).transpose(2, 1)
        n_frames = mel.size(-1)

        # Repeat the embedding along time so it can be stacked on the mel channels.
        speaker = c_org.unsqueeze(-1).expand(-1, -1, n_frames)
        hidden = torch.cat((mel, speaker), dim=1)

        for block in self.convolutions:
            hidden = F.relu(block(hidden))
        hidden = hidden.transpose(1, 2)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(hidden)

        # nn.LSTM lays out the forward direction first, then the backward one.
        out_forward = outputs[..., :self.dim_neck]
        out_backward = outputs[..., self.dim_neck:]

        # One code per window: forward state at the window's end, backward
        # state at its start.
        return [
            torch.cat((out_forward[:, start + self.freq - 1, :],
                       out_backward[:, start, :]), dim=-1)
            for start in range(0, outputs.size(1), self.freq)
        ]
      
        
class Decoder(nn.Module):
    """Decoder: LSTM -> three conv blocks -> 2-layer LSTM -> 80-dim projection.

    Args:
        dim_neck: per-direction size of the encoder bottleneck; the decoder
            input carries ``2 * dim_neck`` code features.
        dim_emb: size of the target speaker embedding appended to the codes.
        dim_pre: width of the first LSTM and of the convolution blocks.
    """

    def __init__(self, dim_neck, dim_emb, dim_pre):
        super().__init__()

        self.lstm1 = nn.LSTM(dim_neck * 2 + dim_emb, dim_pre, 1, batch_first=True)

        self.convolutions = nn.ModuleList([
            nn.Sequential(
                ConvNorm(dim_pre, dim_pre,
                         kernel_size=5, stride=1,
                         padding=2,
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(dim_pre))
            for _ in range(3)
        ])

        self.lstm2 = nn.LSTM(dim_pre, 1024, 2, batch_first=True)
        self.linear_projection = LinearNorm(1024, 80)

    def forward(self, x):
        """Map the upsampled codes plus speaker embedding to an 80-bin output."""
        hidden, _ = self.lstm1(x)

        # Conv1d wants channels before time; swap there and back.
        hidden = hidden.transpose(1, 2)
        for block in self.convolutions:
            hidden = F.relu(block(hidden))
        hidden = hidden.transpose(1, 2)

        recurrent_out, _ = self.lstm2(hidden)
        return self.linear_projection(recurrent_out)
    
# Still part of Decoder as indicated in paper Fig. 3 (c) - last two blocks 
class Postnet(nn.Module):
    """Postnet
        - Five 1-d convolutions (kernel size 5) with batch norm:
          80 -> 512, three times 512 -> 512, then 512 -> 80.
        - tanh follows every block except the last one.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, init gain) for each of the five blocks.
        layout = [(80, 512, 'tanh')]
        layout += [(512, 512, 'tanh') for _ in range(1, 5 - 1)]
        layout.append((512, 80, 'linear'))

        self.convolutions = nn.ModuleList()
        for n_in, n_out, gain_name in layout:
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(n_in, n_out,
                             kernel_size=5, stride=1,
                             padding=2,
                             dilation=1, w_init_gain=gain_name),
                    nn.BatchNorm1d(n_out))
            )

    def forward(self, x):
        """Run ``x`` through all blocks, leaving the final block linear."""
        final_index = len(self.convolutions) - 1
        for index, block in enumerate(self.convolutions):
            x = block(x)
            if index != final_index:
                x = torch.tanh(x)
        return x
    

class Generator(nn.Module):
    """Generator network: content encoder, decoder and residual postnet."""

    def __init__(self, dim_neck, dim_emb, dim_pre, freq):
        super().__init__()

        self.encoder = Encoder(dim_neck, dim_emb, freq)
        self.decoder = Decoder(dim_neck, dim_emb, dim_pre)
        self.postnet = Postnet()

    def forward(self, x, c_org, c_trg):
        """Convert ``x`` from speaker ``c_org`` to speaker ``c_trg``.

        Args:
            x: input spectrogram batch; its dim 1 is used as the frame count.
            c_org: source speaker embedding fed to the encoder.
            c_trg: target speaker embedding, or ``None`` to only encode.

        Returns:
            The concatenated content codes when ``c_trg`` is ``None``;
            otherwise ``(mel_outputs, mel_outputs_postnet, codes)`` where both
            mel tensors have a singleton channel dimension inserted at dim 1.
        """
        # The encoder yields a list with one code tensor per time window.
        codes = self.encoder(x, c_org)
        flat_codes = torch.cat(codes, dim=-1)
        if c_trg is None:
            return flat_codes

        # Repeat each code over the frames of its window to restore the
        # original time resolution.
        n_frames = x.size(1)
        frames_per_code = int(n_frames / len(codes))
        code_exp = torch.cat(
            [code.unsqueeze(1).expand(-1, frames_per_code, -1) for code in codes],
            dim=1)

        # Append the target speaker embedding to every frame.
        target = c_trg.unsqueeze(1).expand(-1, n_frames, -1)
        encoder_outputs = torch.cat((code_exp, target), dim=-1)

        mel_outputs = self.decoder(encoder_outputs)

        # The postnet predicts a residual that is added onto the decoder output.
        residual = self.postnet(mel_outputs.transpose(2, 1))
        mel_outputs_postnet = mel_outputs + residual.transpose(2, 1)

        return (mel_outputs.unsqueeze(1),
                mel_outputs_postnet.unsqueeze(1),
                flat_codes)



In [3]:
import torch
import math
import utils
from scipy.signal import medfilt
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
import time, pdb
import datetime

# SOLVER IS THE MAIN SETUP FOR THE NN ARCHITECTURE. INSIDE SOLVER IS THE GENERATOR (G)
class Solver(object):
    """Training harness that owns the ``Generator`` (``self.G``) and its Adam optimizer.

    The solver reads its hyper-parameters from ``config``, optionally resumes
    from a checkpoint, and runs the training loop in ``train``, periodically
    logging losses, saving spectrogram comparison images and checkpoints
    under ``././my_data/my_autovc/model_data/<file_name>/``.
    """

    def __init__(self, vcc_loader, config):
        """Store the configuration and build the model.

        Args:
            vcc_loader: iterable data loader yielding
                ``(x_real, emb_org, speaker_name)`` batches.
            config: object exposing the attributes read below.
        """
        # Data loader.
        self.vcc_loader = vcc_loader

        # Model configurations.
        self.lambda_cd = config.lambda_cd
        self.dim_neck = config.dim_neck
        self.dim_emb = config.dim_emb
        self.dim_pre = config.dim_pre
        self.freq = config.freq
        self.shape_adapt = config.shape_adapt
        self.which_cuda = config.which_cuda

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.load_ckpts = config.load_ckpts
        self.file_name = config.file_name
        self.one_hot = config.one_hot
        self.psnt_loss_weight = config.psnt_loss_weight
        self.prnt_loss_weight = config.prnt_loss_weight
        self.adam_init = config.adam_init

        # Miscellaneous.
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device(f'cuda:{self.which_cuda}' if self.use_cuda else 'cpu')
        self.log_step = config.log_step
        self.ckpt_freq = config.ckpt_freq
        self.spec_freq = config.spec_freq

        # Build the model and optimizer.
        self.build_model()

    def build_model(self):
        """Create the generator and optimizer, resuming from ``load_ckpts`` if set."""
        self.G = Generator(self.dim_neck, self.dim_emb, self.dim_pre, self.freq)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.adam_init)
        if self.load_ckpts != '':
            # map_location keeps loading usable when the checkpoint was saved
            # on a device that is not present on this machine.
            g_checkpoint = torch.load(self.load_ckpts, map_location=self.device)
            self.G.load_state_dict(g_checkpoint['model_state_dict'])
            self.g_optimizer.load_state_dict(g_checkpoint['optimizer_state_dict'])
            # Optimizer state must live on the same device as the model
            # (https://github.com/pytorch/pytorch/issues/2830); use the
            # configured device instead of the default CUDA device.
            for state in self.g_optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(self.device)

            self.previous_ckpt_iters = g_checkpoint['iteration']
        else:
            self.previous_ckpt_iters = 0
        self.G.to(self.device)

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()

    def train(self):
        """Run the training loop from ``previous_ckpt_iters`` to ``num_iters``.

        Each iteration computes the pre-postnet and post-postnet L1
        reconstruction losses and the L1 content-code loss, combines them
        with ``prnt_loss_weight``, ``psnt_loss_weight`` and ``lambda_cd``,
        and takes one optimizer step.
        """
        data_loader = self.vcc_loader
        data_iter = iter(data_loader)
        # One [loss_id, loss_id_psnt, loss_cd] row per iteration of this run.
        loss_history = []
        # Print logs in specified order
        keys = ['G/loss_id', 'G/loss_id_psnt', 'G/loss_cd']

        print('Start training...')
        start_time = time.time()
        for i in range(self.previous_ckpt_iters, self.num_iters):

            # =============================================================== #
            #                   1. Preprocess input data                      #
            # =============================================================== #

            # Restart the loader once it is exhausted; any other error
            # raised while fetching a batch is allowed to propagate.
            try:
                x_real, emb_org, speaker_name = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x_real, emb_org, speaker_name = next(data_iter)

            x_real = x_real.to(self.device)
            emb_org = emb_org.to(self.device).float()

            # =============================================================== #
            #                     2. Train the generator                      #
            # =============================================================== #
            self.G = self.G.train()

            # Identity mapping: source and target speaker are the same.
            x_identic, x_identic_psnt, code_real = self.G(x_real, emb_org, emb_org)
            residual_from_psnt = x_identic_psnt - x_identic
            if self.shape_adapt:
                # Drop the channel dim the generator inserts so the outputs
                # line up with x_real for the L1 losses.
                x_identic = x_identic.squeeze(1)
                x_identic_psnt = x_identic_psnt.squeeze(1)
                residual_from_psnt = residual_from_psnt.squeeze(1)
            g_loss_id = F.l1_loss(x_real, x_identic)
            g_loss_id_psnt = F.l1_loss(x_real, x_identic_psnt)

            # Content-code loss: re-encode the reconstruction (no target
            # embedding) and compare with the codes of the real input.
            code_reconst = self.G(x_identic_psnt, emb_org, None)
            g_loss_cd = F.l1_loss(code_real, code_reconst)

            # Backward and optimize.
            g_loss = (self.prnt_loss_weight * g_loss_id) + (self.psnt_loss_weight * g_loss_id_psnt) + (self.lambda_cd * g_loss_cd)
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()

            # Logging.
            loss = {}
            loss['G/loss_id'] = g_loss_id.item()
            loss['G/loss_id_psnt'] = g_loss_id_psnt.item()
            loss['G/loss_cd'] = g_loss_cd.item()
            loss_history.append([loss[tag] for tag in keys])

            # =============================================================== #
            #                        3. Miscellaneous                         #
            # =============================================================== #

            # Print out training information.
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag in keys:
                    log += ", {}: {:.4f}".format(tag, loss[tag])
                print(log)

            if (i+1) % self.spec_freq == 0:
                self._save_spectrogram_comparison(
                    i+1, x_real, x_identic, x_identic_psnt,
                    residual_from_psnt, speaker_name)

            if (i+1) % self.ckpt_freq == 0:
                self._save_checkpoint(i+1, loss, loss_history)

    def _save_spectrogram_comparison(self, step, x_real, x_identic,
                                     x_identic_psnt, residual_from_psnt,
                                     speaker_name):
        """Save a 4x2 grid comparing real, reconstructed and residual spectrograms."""
        def to_numpy(tensor):
            # Without shape_adapt the generator outputs still carry the
            # singleton channel dimension.
            if not self.shape_adapt:
                tensor = tensor.squeeze(1)
            return tensor.detach().cpu().numpy()

        x_real = x_real.detach().cpu().numpy()
        x_identic = to_numpy(x_identic)
        x_identic_psnt = to_numpy(x_identic_psnt)
        residual_from_psnt = to_numpy(residual_from_psnt)

        specs_list = [*x_real, *x_identic, *residual_from_psnt, *x_identic_psnt]

        # NOTE(review): the grid assumes a batch of exactly 2 examples
        # (8 images, titles taken from speaker_name[j % 2]).
        columns = 2
        rows = 4
        fig, axs = plt.subplots(4, 2)
        fig.tight_layout()
        for j in range(0, columns*rows):
            spec = np.rot90(specs_list[j])
            fig.add_subplot(rows, columns, j+1)
            # NOTE(review): with a batch of 2 the residuals are images 4 and
            # 5, yet indices 5 and 6 are rescaled here, and clim is applied
            # before this panel's image exists - confirm the intent.
            if j == 5 or j == 6:
                spec = spec - np.min(spec)
                plt.clim(0, 1)
            plt.imshow(spec)
            name = speaker_name[j % 2]
            plt.title(name)
            plt.colorbar()
        plt.savefig('././my_data/my_autovc/model_data/' +self.file_name +'/image_comparison/' +str(step) +'iterations')
        # Close this figure explicitly; closing by speaker name matched no
        # figure, so every saved figure stayed alive.
        plt.close(fig)

    def _save_checkpoint(self, step, loss, loss_history):
        """Save model/optimizer state and loss plots after ``step`` iterations."""
        print('Saving model...')
        checkpoint = {'model_state_dict': self.G.state_dict(),
                      'optimizer_state_dict': self.g_optimizer.state_dict(),
                      'iteration': step,
                      'loss': loss}
        torch.save(checkpoint, '././my_data/my_autovc/model_data/' +self.file_name +'/ckpts/' +'ckpt_' +str(step) +'.pth.tar')

        # Plot the history since the last checkpoint, downsampled to about
        # num_graph_vals points.
        print('Saving loss visuals...')
        hist_arr = np.asarray(loss_history)
        num_cols = 1
        num_graph_vals = 200
        down_samp_size = math.ceil(self.ckpt_freq/num_graph_vals)
        modified_array = hist_arr[-self.ckpt_freq::down_samp_size, :]
        file_path = '././my_data/my_autovc/model_data/' +self.file_name +'/ckpts/' +'ckpt_' +str(step) +'_loss.png'
        labels = ['iter_steps', 'loss', 'loss_id', 'loss_id_psnt', 'loss_cd']
        utils.saveContourPlots(modified_array, file_path, labels, num_cols)
        if step % (self.ckpt_freq*2) == 0:
            # Every second checkpoint also plots the whole recorded history.
            print('saving loss visuals of all history...')
            down_samp_size = math.ceil((step - 1)/num_graph_vals)
            modified_array = hist_arr[::down_samp_size, :]
            file_path = '././my_data/my_autovc/model_data/' +self.file_name +'/ckpts/' +'ckpt_' +str(step) +'_loss_all_history.png'
            utils.saveContourPlots(modified_array, file_path, labels, num_cols)


In [4]:
# Reuse the configuration of an earlier run, overriding the data location,
# run name and iteration/logging counts for a short test run.
config_path = '/homes/bdoc3/my_data/autovc_data/my_autovc/./my_data/my_autovc/model_data/1Hot16FreqL1Loss/config.pkl'
# NOTE: pickle can execute arbitrary code; only load config files you trust.
with open(config_path, 'rb') as config_file:
    config = pickle.load(config_file)
config.data_dir='/homes/bdoc3/my_data/autovc_data/spmel'
config.file_name='testAutoVcWithoutPitchDataUsing1Hot16FreqConfigFile'
config.num_iters = 100
config.log_step = 10

In [5]:
import os, pdb, pickle, argparse, shutil
from torch.backends import cudnn


def str2bool(v):
    """Return True only when ``v`` spells "true" (case-insensitive).

    The comparison is an exact match: ``('true')`` is a plain string, so an
    ``in`` test against it would also accept substrings such as ``''`` or
    ``'t'``.
    """
    return v.lower() == 'true'

def overwrite_dir(directory):
    """Ensure ``directory`` exists and is empty.

    Anything already present at that path is removed recursively before the
    directory (and any missing parents) is created again.
    """
    already_present = os.path.exists(directory)
    if already_present:
        shutil.rmtree(directory)
    os.makedirs(directory)
        

# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()

#     # Model configuration.
#     parser.add_argument('--lambda_cd', type=float, default=1, help='weight for hidden code loss')
#     parser.add_argument('--dim_neck', type=int, default=32)
#     parser.add_argument('--dim_emb', type=int, default=256)
#     parser.add_argument('--dim_pre', type=int, default=512)
#     parser.add_argument('--freq', type=int, default=32)
#     parser.add_argument('--one_hot', type=str2bool, default=False, help='Toggle 1-hot mode')
#     parser.add_argument('--shape_adapt', type=str2bool, default=True, help='adjust shapes of tensors to match automatically')
#     parser.add_argument('--which_cuda', type=int, default=0, help='Determine which cuda to use')
    
#     # Training configuration.
#     parser.add_argument('--file_name', type=str, default='defaultSetup')
#     parser.add_argument('--data_dir', type=str, default='./spmel')
#     parser.add_argument('--batch_size', type=int, default=2, help='mini-batch size')
#     parser.add_argument('--num_iters', type=int, default=1000000, help='number of total iterations')
#     parser.add_argument('--adam_init', type=float, default=0.0001, help='Define initial Adam optimizer learning rate')
#     parser.add_argument('--train_size', type=int, default=20, help='Define how many speakers are used in the training set')
#     parser.add_argument('--len_crop', type=int, default=128, help='dataloader output sequence length')
#     parser.add_argument('--psnt_loss_weight', type=float, default=1.0, help='Determine weight applied to postnet reconstruction loss')
#     parser.add_argument('--prnt_loss_weight', type=float, default=1.0, help='Determine weight applied to pre-net reconstruction loss')
 
#     # Miscellaneous.
#     parser.add_argument('--load_ckpts', type=str, default='', help='toggle checkpoint load function')
#     parser.add_argument('--ckpt_freq', type=int, default=50000, help='frequency in steps to mark checkpoints')
#     parser.add_argument('--spec_freq', type=int, default=10000, help='frequency in steps to print reconstruction illustrations')
#     parser.add_argument('--log_step', type=int, default=10)
#     config = parser.parse_args()

# In one-hot mode the speaker "embedding" is a one-hot vector, so its size
# equals the number of training speakers.
if config.one_hot == True:
    config.dim_emb = config.train_size

print(config)

# Start from a clean run directory and create the sub-folders it needs.
run_dir = '././my_data/my_autovc/model_data/' + config.file_name
overwrite_dir(run_dir)
for sub_dir in ('/ckpts', '/generated_wavs', '/image_comparison'):
    os.makedirs(run_dir + sub_dir)

# Keep a copy of the configuration next to the run's outputs.
with open(run_dir + '/config.pkl', 'wb') as config_file:
    pickle.dump(config, config_file)

# For fast training.
cudnn.benchmark = True

# Build the data loader, hand it to the solver and train.
vcc_loader = get_loader(config)
solver = Solver(vcc_loader, config)
solver.train()

Namespace(adam_init=0.0001, batch_size=2, ckpt_freq=50000, data_dir='/homes/bdoc3/my_data/autovc_data/spmel', dim_emb=20, dim_neck=32, dim_pre=512, file_name='testAutoVcWithoutPitchDataUsing1Hot16FreqConfigFile', freq=16, lambda_cd=1, len_crop=128, load_ckpts='', log_step=10, num_iters=100, one_hot=True, prnt_loss_weight=1.0, psnt_loss_weight=1.0, shape_adapt=True, spec_freq=10000, train_size=20, which_cuda=1)
Finished loading the dataset...
Start training...
Elapsed [0:00:03], Iteration [10/100], G/loss_id: 0.2987, G/loss_id_psnt: 0.8301, G/loss_cd: 0.1072
Elapsed [0:00:06], Iteration [20/100], G/loss_id: 0.1679, G/loss_id_psnt: 0.7784, G/loss_cd: 0.0886
Elapsed [0:00:10], Iteration [30/100], G/loss_id: 0.1658, G/loss_id_psnt: 0.7672, G/loss_cd: 0.0836
Elapsed [0:00:13], Iteration [40/100], G/loss_id: 0.1555, G/loss_id_psnt: 0.7418, G/loss_cd: 0.0759
Elapsed [0:00:16], Iteration [50/100], G/loss_id: 0.1478, G/loss_id_psnt: 0.7397, G/loss_cd: 0.0772
Elapsed [0:00:20], Iteration [60/100