In [4]:
import os
import numpy as np
import argparse
import torch
import time
import librosa
import pickle

import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import pdb

import preprocess
from trainingDataset import trainingDataset
from model_VC2 import Generator, Discriminator#, Encoder, Decoder

In [5]:
# Models
class GLU(nn.Module):
    """Self-gated activation that multiplies a tensor by its own sigmoid.

    This is a custom implementation: per the file's original note, the
    voice-conversion CycleGAN paper assumes the gated unit does not halve
    the tensor, so the output here always has the same shape as the input.
    """

    def __init__(self):
        super(GLU, self).__init__()

    def forward(self, input):
        """Return ``input * sigmoid(input)`` element-wise."""
        gate = torch.sigmoid(input)
        return torch.mul(input, gate)


class up_2Dsample(nn.Module):
    """Enlarge dims 2 and 3 of a tensor by an integer factor.

    The resize is delegated to ``F.interpolate`` with its default
    interpolation mode; dims 0 and 1 are left untouched.
    """

    def __init__(self, upscale_factor=2):
        super(up_2Dsample, self).__init__()
        self.scale_factor = upscale_factor

    def forward(self, input):
        """Return ``input`` resized to (dim2 * factor, dim3 * factor)."""
        height, width = input.shape[2], input.shape[3]
        target_size = [height * self.scale_factor, width * self.scale_factor]
        return F.interpolate(input, size=target_size)
       

class PixelShuffle(nn.Module):
    """Trade channels for width on a 3-D tensor.

    Custom implementation because PyTorch's own PixelShuffle requires a 4-D
    input, whereas this one is written for 3-D arrays. The operation is a
    plain ``view``: channels are divided by ``upscale_factor`` and the last
    dim is multiplied by it; elements are not interleaved.
    """

    def __init__(self, upscale_factor=2):
        super(PixelShuffle, self).__init__()
        self.upscale_factor = upscale_factor

    def forward(self, input):
        """Return ``input`` viewed as (dim0, dim1 // factor, dim2 * factor)."""
        factor = self.upscale_factor
        batch = input.shape[0]
        channels_out = input.shape[1] // factor
        width_out = input.shape[2] * factor
        return input.view(batch, channels_out, width_out)


class ResidualLayer(nn.Module):
    """1-D gated residual block.

    Two parallel Conv1d + InstanceNorm1d branches (values and gates) expand
    ``in_channels`` to ``out_channels``; their gated product is projected
    back to ``in_channels`` by a third Conv1d + InstanceNorm1d and added to
    the block input.

    Args:
        in_channels: channels of the block input (and output).
        out_channels: channels of the intermediate gated representation.
        kernel_size: kernel size shared by all three convolutions.
        stride: accepted for signature compatibility but not used; every
            convolution is built with stride 1.
        padding: padding shared by all three convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ResidualLayer, self).__init__()

        def conv_then_norm(n_in, n_out):
            # Stride is pinned to 1 here regardless of the `stride` argument.
            return nn.Sequential(
                nn.Conv1d(in_channels=n_in,
                          out_channels=n_out,
                          kernel_size=kernel_size,
                          stride=1,
                          padding=padding),
                nn.InstanceNorm1d(num_features=n_out, affine=True))

        # Creation order is kept (values, gates, projection) so that seeded
        # parameter initialisation is unchanged.
        self.conv1d_layer = conv_then_norm(in_channels, out_channels)
        self.conv_layer_gates = conv_then_norm(in_channels, out_channels)
        self.conv1d_out_layer = conv_then_norm(out_channels, in_channels)

    def forward(self, input):
        """Return ``input + project(values(input) * sigmoid(gates(input)))``."""
        values = self.conv1d_layer(input)
        gates = self.conv_layer_gates(input)

        # Gated linear unit: the gate branch modulates the value branch.
        gated = values * torch.sigmoid(gates)

        projected = self.conv1d_out_layer(gated)
        return input + projected


class downSample_Generator(nn.Module):
    """Gated 2-D convolution block used by the generator's downsampling path.

    Two parallel Conv2d + InstanceNorm2d branches are built with identical
    hyper-parameters; the output is ``values * sigmoid(gates)``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by each branch.
        kernel_size: convolution kernel size (shared by both branches).
        stride: convolution stride (shared by both branches).
        padding: convolution padding (shared by both branches).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(downSample_Generator, self).__init__()

        self.convLayer = nn.Sequential(nn.Conv2d(in_channels=in_channels,
                                                 out_channels=out_channels,
                                                 kernel_size=kernel_size,
                                                 stride=stride,
                                                 padding=padding),
                                       nn.InstanceNorm2d(num_features=out_channels,
                                                         affine=True))
        self.convLayer_gates = nn.Sequential(nn.Conv2d(in_channels=in_channels,
                                                       out_channels=out_channels,
                                                       kernel_size=kernel_size,
                                                       stride=stride,
                                                       padding=padding),
                                             nn.InstanceNorm2d(num_features=out_channels,
                                                               affine=True))

    def forward(self, input):
        """Return the gated product of the two branches.

        Each branch is evaluated exactly once; the previous version computed
        both branches, discarded the results and then ran them a second time,
        doubling the forward/backward cost of this block.
        """
        values = self.convLayer(input)
        gates = self.convLayer_gates(input)
        return values * torch.sigmoid(gates)


class upSample_Generator(nn.Module):
    """Gated 2-D convolution block that doubles the spatial size.

    Each of the two parallel branches is Conv2d -> up_2Dsample(x2) ->
    InstanceNorm2d; the block returns ``values * sigmoid(gates)``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by each branch.
        kernel_size: convolution kernel size (shared by both branches).
        stride: convolution stride (shared by both branches).
        padding: convolution padding (shared by both branches).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(upSample_Generator, self).__init__()

        def build_branch():
            # Layer positions inside the Sequential (0: conv, 1: upsample,
            # 2: norm) are part of the state_dict key names.
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=kernel_size,
                          stride=stride,
                          padding=padding),
                up_2Dsample(upscale_factor=2),
                nn.InstanceNorm2d(num_features=out_channels, affine=True))

        # Value branch first, gate branch second (keeps seeded init order).
        self.convLayer = build_branch()
        self.convLayer_gates = build_branch()

    def forward(self, input):
        """Return the upsampled value branch gated by the upsampled gate branch."""
        values = self.convLayer(input)
        gates = self.convLayer_gates(input)
        return values * torch.sigmoid(gates)

class Encoder(nn.Module):
    """Encode a feature sequence into a 512-channel 1-D embedding.

    Pipeline: gated 2-D conv (1 -> 128) -> two stride-2 gated downsampling
    blocks (128 -> 256 -> 512) -> flatten channel and height dims -> 1x1
    Conv1d (3072 -> 512) -> three gated residual blocks.

    NOTE(review): ``conv2`` expects 3072 = 512 * 6 input channels, so the
    height dim must be 6 after the two stride-2 blocks; presumably the input
    carries 24 features per frame -- confirm against the data pipeline.
    """

    def __init__(self):
        super(Encoder, self).__init__()

        # Gated input convolution: value branch and gate branch.
        self.conv1 = nn.Conv2d(1, 128, kernel_size=[5, 15], stride=1, padding=[2, 7])
        self.conv1_gates = nn.Conv2d(1, 128, kernel_size=[5, 15], stride=1, padding=[2, 7])

        # Downsample layers (each stride-2).
        self.downSample1 = downSample_Generator(in_channels=128, out_channels=256,
                                                kernel_size=5, stride=2, padding=2)
        self.downSample2 = downSample_Generator(in_channels=256, out_channels=512,
                                                kernel_size=5, stride=2, padding=2)

        # 1x1 projection applied after the 2-D map is flattened to 1-D.
        self.conv2 = nn.Conv1d(in_channels=3072, out_channels=512,
                               kernel_size=1, stride=1)

        # Residual blocks (identical configuration).
        residual_config = dict(in_channels=512, out_channels=1024,
                               kernel_size=3, stride=1, padding=1)
        self.residualLayer1 = ResidualLayer(**residual_config)
        self.residualLayer2 = ResidualLayer(**residual_config)
        self.residualLayer3 = ResidualLayer(**residual_config)

    def forward(self, input):
        """Return the embedding for ``input``.

        A channel dim is inserted at position 1 before the 2-D convolutions,
        so ``input`` is expected to be 3-D.
        """
        x = input.unsqueeze(1)
        gated = self.conv1(x) * torch.sigmoid(self.conv1_gates(x))
        down = self.downSample2(self.downSample1(gated))

        # Merge channel and height dims so Conv1d can run along the last dim.
        batch, steps = down.shape[0], down.shape[3]
        hidden = self.conv2(down.view(batch, -1, steps))

        for block in (self.residualLayer1, self.residualLayer2, self.residualLayer3):
            hidden = block(hidden)
        return hidden

class Decoder(nn.Module):
    """Decode a 512-channel 1-D embedding back into a feature sequence.

    Pipeline: three gated residual blocks -> 1x1 Conv1d (512 -> 3072) ->
    reshape to a 4-D map with 512 channels -> two x2 gated upsampling blocks
    (512 -> 1024 -> 512) -> Conv2d (512 -> 1) -> merge channel and height
    dims into a 3-D output.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.residualLayer4 = ResidualLayer(in_channels=512,
                                            out_channels=1024,
                                            kernel_size=3,
                                            stride=1,
                                            padding=1)
        self.residualLayer5 = ResidualLayer(in_channels=512,
                                            out_channels=1024,
                                            kernel_size=3,
                                            stride=1,
                                            padding=1)
        self.residualLayer6 = ResidualLayer(in_channels=512,
                                            out_channels=1024,
                                            kernel_size=3,
                                            stride=1,
                                            padding=1)
        # 1x1 projection whose output is reshaped into a 2-D map in forward().
        self.conv3 = nn.Conv1d(in_channels=512,
                               out_channels=3072,
                               kernel_size=1,
                               stride=1)

        # Channel count the first upsampling block expects; used to split the
        # 3072 flattened channels back into (channels, height).
        self.upsample_in_channels = 512

        # UpSample Layer
        self.upSample1 = upSample_Generator(in_channels=self.upsample_in_channels,
                                            out_channels=1024,
                                            kernel_size=5,
                                            stride=1,
                                            padding=2)

        self.upSample2 = upSample_Generator(in_channels=1024,
                                            out_channels=512,
                                            kernel_size=5,
                                            stride=1,
                                            padding=2)

        self.lastConvLayer = nn.Conv2d(in_channels=512,
                                       out_channels=1,
                                       kernel_size=[5,15],
                                       stride=1,
                                       padding=[2,7])

    def forward(self, input):
        """Return the decoded 3-D tensor for the 3-D embedding ``input``."""
        hidden = self.residualLayer4(input)
        hidden = self.residualLayer5(hidden)
        hidden = self.residualLayer6(hidden)
        hidden = self.conv3(hidden)

        # Restore the 4-D layout (batch, channels, height, steps). The previous
        # code read the shape from `downsample2`, a local variable of
        # Encoder.forward that does not exist here (NameError); the same shape
        # is recovered from this tensor and the channel count upSample1 needs.
        batch, steps = hidden.shape[0], hidden.shape[2]
        hidden = hidden.view(batch, self.upsample_in_channels, -1, steps)

        hidden = self.upSample1(hidden)
        hidden = self.upSample2(hidden)
        output = self.lastConvLayer(hidden)
        # Merge the single output channel with the height dim.
        output = output.view([output.shape[0], -1, output.shape[3]])
        return output


class DownSample_Discriminator(nn.Module):
    """Gated 2-D convolution block used by the discriminator.

    Two parallel Conv2d + InstanceNorm2d branches share the same
    hyper-parameters; the block returns ``values * sigmoid(gates)``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by each branch.
        kernel_size: convolution kernel size (shared by both branches).
        stride: convolution stride (shared by both branches).
        padding: convolution padding (shared by both branches).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(DownSample_Discriminator, self).__init__()

        def conv_then_norm():
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=kernel_size,
                          stride=stride,
                          padding=padding),
                nn.InstanceNorm2d(num_features=out_channels, affine=True))

        # Value branch first, gate branch second (keeps seeded init order).
        self.convLayer = conv_then_norm()
        self.convLayerGates = conv_then_norm()

    def forward(self, input):
        """Return the value branch modulated by the sigmoid of the gate branch."""
        values = self.convLayer(input)
        gates = self.convLayerGates(input)
        return values * torch.sigmoid(gates)


class Discriminator(nn.Module):
    """Convolutional discriminator producing a map of sigmoid scores.

    Pipeline: zero-padded gated 3x3 conv (1 -> 128) -> three stride-2 gated
    blocks (128 -> 256 -> 512 -> 1024), each preceded by one row/column of
    zero padding on the top/left -> a [1, 5] gated block -> a [1, 3] Conv2d
    to one channel -> sigmoid.

    Note: the returned scores ARE passed through a sigmoid; ``self.fc`` is
    constructed but not used by ``forward``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Gated input convolution: value branch and gate branch.
        first_conv = dict(in_channels=1, out_channels=128,
                          kernel_size=[3, 3], stride=[1, 1])
        self.convLayer1 = nn.Conv2d(**first_conv)
        self.convLayer1_gates = nn.Conv2d(**first_conv)

        # Note: kernel sizes were modified in this PyTorch implementation
        # compared to the paper in order to retain dimensionality. Unlike
        # TensorFlow, PyTorch had no padding='same', so kernel sizes were
        # altered to keep the dimensionality after each layer.

        # DownSample layers
        strided = dict(kernel_size=[3, 3], stride=[2, 2], padding=0)
        self.downSample1 = DownSample_Discriminator(in_channels=128, out_channels=256, **strided)
        self.downSample2 = DownSample_Discriminator(in_channels=256, out_channels=512, **strided)
        self.downSample3 = DownSample_Discriminator(in_channels=512, out_channels=1024, **strided)
        self.downSample4 = DownSample_Discriminator(in_channels=1024,
                                                    out_channels=1024,
                                                    kernel_size=[1, 5],
                                                    stride=[1, 1],
                                                    padding=[0, 2])

        # Fully connected layer (kept so the parameter set is unchanged).
        self.fc = nn.Linear(in_features=1024, out_features=1)

        # Output layer
        self.output_layer = nn.Conv2d(in_channels=1024,
                                      out_channels=1,
                                      kernel_size=[1, 3],
                                      stride=[1, 1],
                                      padding=[0, 1])

    def forward(self, input):
        """Score ``input`` and return sigmoid outputs with channels last.

        Per the original note, ``input`` has shape
        [batch_size, num_features, time]; a channel dim is inserted to give
        [batch_size, 1, num_features, time].
        """
        pad_all_sides = nn.ZeroPad2d((1, 1, 1, 1))
        pad_left_top = nn.ZeroPad2d((1, 0, 1, 0))

        x = pad_all_sides(input.unsqueeze(1))
        # Gated linear unit on the first layer.
        hidden = self.convLayer1(x) * torch.sigmoid(self.convLayer1_gates(x))

        for stage in (self.downSample1, self.downSample2, self.downSample3):
            hidden = stage(pad_left_top(hidden))

        hidden = self.downSample4(hidden)
        scores = self.output_layer(hidden)

        # Move the channel dim last: (N, C, H, W) -> (N, H, W, C).
        scores = scores.contiguous().permute(0, 2, 3, 1).contiguous()
        return torch.sigmoid(scores)

In [None]:
# set up
# ---- Cached inputs, checkpoint locations and evaluation folders ----
logf0s_normalization = "./cache/logf0s_normalization.npz"
mcep_normalization = "./cache/mcep_normalization.npz"
coded_sps_A_norm = "./cache/coded_sps_A_norm.pickle"
coded_sps_B_norm = "./cache/coded_sps_B_norm.pickle"
# NOTE(review): `model_checkpoint` was read below without ever being assigned
# in this cell; it is assumed to be the directory holding
# `resume_training_at` -- confirm.
model_checkpoint = "./cache/model_checkpoint/"
# Set to None to always start training from scratch.
resume_training_at = "./cache/model_checkpoint/_CycleGAN_CheckPoint"
validation_A_dir = "./data/vcc2016_training/evaluation_all/SF1/"
output_A_dir = "./data/vcc2016_training/converted_sound/SF1"
validation_B_dir = "./data/vcc2016_training/evaluation_all/TF2/"
output_B_dir = "./data/vcc2016_training/converted_sound/TF2/"
# =================================================
start_epoch = 0
num_epochs = 5000
mini_batch_size = 16


def _load_cached_pickle(path):
    """Unpickle and return the object stored at ``path``.

    Defined here because loadPickleFile() lives in a later cell and is not
    available yet when the notebook is run top to bottom.
    NOTE: pickle.load can execute arbitrary code; only load trusted caches.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


dataset_A = _load_cached_pickle(coded_sps_A_norm)
dataset_B = _load_cached_pickle(coded_sps_B_norm)
device = torch.device(
    'cuda' if torch.cuda.is_available() else 'cpu')

# Speech Parameters
# (each path variable is rebound to the archive loaded from it)
logf0s_normalization = np.load(logf0s_normalization)
log_f0s_mean_A = logf0s_normalization['mean_A']
log_f0s_std_A = logf0s_normalization['std_A']
log_f0s_mean_B = logf0s_normalization['mean_B']
log_f0s_std_B = logf0s_normalization['std_B']

mcep_normalization = np.load(mcep_normalization)
coded_sps_A_mean = mcep_normalization['mean_A']
coded_sps_A_std = mcep_normalization['std_A']
coded_sps_B_mean = mcep_normalization['mean_B']
coded_sps_B_std = mcep_normalization['std_B']

# Encoder and Decoder
encoder_noCNN = Encoder().to(device)
encoder_noRNN = Encoder().to(device)
encoder_noCNNRNN = Encoder().to(device)
decoder_2A_noCNN = Decoder().to(device)
decoder_2A_noRNN = Decoder().to(device)
decoder_2A_noCNNRNN = Decoder().to(device)
decoder_2B_noCNN = Decoder().to(device)
decoder_2B_noRNN = Decoder().to(device)
decoder_2B_noCNNRNN = Decoder().to(device)

# Discriminator
# NOTE(review): CNN_Discriminator, RNN_Discriminator and RealFake_Discriminator
# are not defined in the cells visible here -- confirm where they come from.
cnn_discriminator = CNN_Discriminator().to(device)
rnn_discriminator = RNN_Discriminator().to(device)
realfake_discriminator = RealFake_Discriminator().to(device)

# Loss Functions
criterion_mse = torch.nn.MSELoss()

# Initial learning rates
generator_lr = 0.0002
discriminator_lr = 0.0001

# Learning rate decay
generator_lr_decay = generator_lr / 200000
discriminator_lr_decay = discriminator_lr / 200000

# Iteration count after which the training loop starts decaying the learning
# rates. NOTE(review): the loop reads `start_decay` but nothing assigned it;
# 200000 is assumed to match the decay denominators above -- confirm.
start_decay = 200000

# Optimizers
def get_optimizer(one_module, learning_rate):
    """Return an Adam optimizer (betas 0.5 / 0.999) over ``one_module``'s parameters."""
    return torch.optim.Adam(list(one_module.parameters()), lr=learning_rate, betas=(0.5, 0.999))

encoder_noCNN_optimizer = get_optimizer(encoder_noCNN, generator_lr)
encoder_noRNN_optimizer = get_optimizer(encoder_noRNN, generator_lr)
encoder_noCNNRNN_optimizer = get_optimizer(encoder_noCNNRNN, generator_lr)

decoder_2A_noCNN_optimizer = get_optimizer(decoder_2A_noCNN, generator_lr)
decoder_2A_noRNN_optimizer = get_optimizer(decoder_2A_noRNN, generator_lr)
decoder_2A_noCNNRNN_optimizer = get_optimizer(decoder_2A_noCNNRNN, generator_lr)
decoder_2B_noCNN_optimizer = get_optimizer(decoder_2B_noCNN, generator_lr)
decoder_2B_noRNN_optimizer = get_optimizer(decoder_2B_noRNN, generator_lr)
decoder_2B_noCNNRNN_optimizer = get_optimizer(decoder_2B_noCNNRNN, generator_lr)

realfake_discriminator_optimizer = get_optimizer(realfake_discriminator, discriminator_lr)
cnn_discriminator_optimizer = get_optimizer(cnn_discriminator, discriminator_lr)
rnn_discriminator_optimizer = get_optimizer(rnn_discriminator, discriminator_lr)


# Location used to save / load model checkpoints
modelCheckpoint = model_checkpoint

# Storing Discriminator and Generator Loss
generator_loss_store = []
discriminator_loss_store = []
cnn_discriminator_loss_store = []
rnn_discriminator_loss_store = []

file_name = 'log_store_non_sigmoid.txt'

if resume_training_at is not None:
    if os.path.isfile(resume_training_at):
        # Training will resume from previous checkpoint
        # NOTE(review): loadModel() only exists as commented-out code in the
        # helper cell; it must be restored and run before this branch can work.
        start_epoch = loadModel(resume_training_at)
        print("Training resumed")
    else:
        print("No checkpoint found at {}; training starts at epoch {}".format(
            resume_training_at, start_epoch))

In [None]:
# Helper functions

def adjust_lr_rate(optimizer, name='generator'):
    """Apply one step of linear learning-rate decay and push it to ``optimizer``.

    The module-level ``generator_lr`` (when ``name == 'generator'``) or
    ``discriminator_lr`` (any other ``name``) is reduced by its matching
    ``*_lr_decay`` amount, clamped at 0, stored back in the module-level
    variable, and written to every parameter group of ``optimizer``.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry), e.g. a ``torch.optim`` optimizer.
        name: ``'generator'`` selects the generator schedule; anything else
            selects the discriminator schedule.
    """
    # The rates are rebound below, so they must be declared global; without
    # this Python treats them as locals and the read raises UnboundLocalError.
    global generator_lr, discriminator_lr

    if name == 'generator':
        generator_lr = max(0., generator_lr - generator_lr_decay)
        new_lr = generator_lr
    else:
        discriminator_lr = max(0., discriminator_lr - discriminator_lr_decay)
        new_lr = discriminator_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr

def reset_grad():
    """Call ``zero_grad()`` on every module-level optimizer.

    Covers the three encoder optimizers, the six decoder optimizers and the
    three discriminator optimizers, in that order.
    """
    all_optimizers = (
        # encoders
        encoder_noCNN_optimizer,
        encoder_noRNN_optimizer,
        encoder_noCNNRNN_optimizer,
        # decoders
        decoder_2A_noCNN_optimizer,
        decoder_2A_noRNN_optimizer,
        decoder_2A_noCNNRNN_optimizer,
        decoder_2B_noCNN_optimizer,
        decoder_2B_noRNN_optimizer,
        decoder_2B_noCNNRNN_optimizer,
        # discriminators
        realfake_discriminator_optimizer,
        cnn_discriminator_optimizer,
        rnn_discriminator_optimizer,
    )
    for optimizer in all_optimizers:
        optimizer.zero_grad()

def savePickle(variable, fileName):
    """Pickle ``variable`` into the file at ``fileName``, overwriting it."""
    with open(fileName, mode='wb') as sink:
        pickle.dump(variable, sink)

def loadPickleFile(fileName):
    """Return the object unpickled from the file at ``fileName``.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    with open(fileName, mode='rb') as source:
        loaded = pickle.load(source)
    return loaded

def store_to_file(doc):
    """Append ``doc`` plus a trailing newline to the module-level ``file_name`` log."""
    line = doc + "\n"
    with open(file_name, mode="a") as log_file:
        log_file.write(line)

# def saveModelCheckPoint(self, epoch, PATH):
#     torch.save({
#         'epoch': epoch,
#         'generator_loss_store': generator_loss_store,
#         'discriminator_loss_store': discriminator_loss_store,
#         'model_genA2B_state_dict': generator_A2B.state_dict(),
#         'model_genB2A_state_dict': generator_B2A.state_dict(),
#         'model_discriminatorA': discriminator_A.state_dict(),
#         'model_discriminatorB': discriminator_B.state_dict(),
#         'generator_optimizer': generator_optimizer.state_dict(),
#         'discriminator_optimizer': discriminator_optimizer.state_dict()
#     }, PATH)

# def loadModel(PATH):
#     checkPoint = torch.load(PATH)
#     generator_A2B.load_state_dict(
#         state_dict=checkPoint['model_genA2B_state_dict'])
#     generator_B2A.load_state_dict(
#         state_dict=checkPoint['model_genB2A_state_dict'])
#     discriminator_A.load_state_dict(
#         state_dict=checkPoint['model_discriminatorA'])
#     discriminator_B.load_state_dict(
#         state_dict=checkPoint['model_discriminatorB'])
#     generator_optimizer.load_state_dict(
#         state_dict=checkPoint['generator_optimizer'])
#     discriminator_optimizer.load_state_dict(
#         state_dict=checkPoint['discriminator_optimizer'])
#     epoch = int(checkPoint['epoch']) + 1
#     generator_loss_store = checkPoint['generator_loss_store']
#     discriminator_loss_store = checkPoint['discriminator_loss_store']
#     return epoch

In [None]:
# Training Begins...

# for realA, realB:
#   real -> embeddings -> fakeAB, fakeAA -> identity_loss
#   forback_AB (True) -> update AB_disc -> forback_RF (True) -> update RF_disc
#   zerograd encoder decoders
#   forback_AB (False) -> forback_RF (False) -> identity_loss.backward
#   fakeAB.Encoder.DecoderA -> fakeABA -> cycle_loss -> cycle_loss.backward
#   encoder & decoder update

# for realB, realA:
#   real -> embeddings -> fakeBA, fakeBB -> identity_loss
#   forback_AB (True) -> update AB_disc -> forback_RF (True) -> update RF_disc
#   zerograd encoder decoders
#   forback_AB (False) -> forback_RF (False) -> identity_loss.backward
#   fakeBA.Encoder.DecoderB -> fakeBAB -> cycle_loss -> cycle_loss.backward
#   encoder & decoder update

def emb_forward_backward_AB_discriminator(embeddings, isA=True, gradients_for_discriminator=True):
    """Run both A/B discriminators on three embeddings and back-propagate.

    Both discriminator optimizers are zeroed first. Each discriminator then
    scores each embedding, and the mean squared error against a scalar target
    is back-propagated immediately (six backward passes in total, CNN
    discriminator first).

    Args:
        embeddings: sequence of exactly three tensors, in the order
            (noCNN, noRNN, noCNNRNN).
        isA: the domain label is 0.0 when True and 1.0 when False.
        gradients_for_discriminator: when True every target is the domain
            label. When False the targets are 0.5 except for the CNN
            discriminator on the noRNN embedding and the RNN discriminator on
            the noCNN embedding, which keep the domain label.

    NOTE(review): every backward() call releases the autograd graph it
    walked. If the embeddings are still attached to their encoder's graph,
    the second pass through the same embedding will fail unless the graph is
    retained -- confirm how callers prepare ``embeddings``.
    """
    cnn_discriminator_optimizer.zero_grad()
    rnn_discriminator_optimizer.zero_grad()

    emb_noCNN, emb_noRNN, emb_noCNNRNN = embeddings
    ordered_embeddings = (emb_noCNN, emb_noRNN, emb_noCNNRNN)

    domain_label = 0.0 if isA else 1.0
    if gradients_for_discriminator:
        cnn_targets = (domain_label, domain_label, domain_label)
        rnn_targets = (domain_label, domain_label, domain_label)
    else:
        # Targets per embedding, ordered (noCNN, noRNN, noCNNRNN).
        cnn_targets = (0.5, domain_label, 0.5)
        rnn_targets = (domain_label, 0.5, 0.5)

    passes = ((cnn_discriminator, cnn_targets),
              (rnn_discriminator, rnn_targets))
    for discriminator, targets in passes:
        for embedding, target in zip(ordered_embeddings, targets):
            prediction = discriminator(embedding)
            loss = torch.mean((target - prediction) ** 2)
            loss.backward()
    

def update_AB_discriminator():
    """Step the CNN discriminator optimizer, then the RNN discriminator optimizer."""
    for optimizer in (cnn_discriminator_optimizer, rnn_discriminator_optimizer):
        optimizer.step()
    

def fakeVoice_AB_forward_backward_RealFake_discriminator(fakes, real, gradients_for_discriminator=True):
    """Run the real/fake discriminator on generated and real inputs and back-propagate.

    The real/fake optimizer is zeroed first. Each of the three fakes is
    scored and the mean squared error against target 0.0 is back-propagated.
    When ``gradients_for_discriminator`` is True, ``real`` is also scored and
    ``mean(3 * (1.0 - score) ** 2)`` is back-propagated.

    Args:
        fakes: sequence of exactly three tensors, in the order
            (noCNN, noRNN, noCNNRNN).
        real: real sample; only used when ``gradients_for_discriminator``.
        gradients_for_discriminator: see above.

    NOTE(review): the fakes are pushed towards 0.0 in both modes. When this
    is called to produce generator gradients, a target of 1.0 would normally
    be expected -- confirm this is intended.
    """
    realfake_discriminator_optimizer.zero_grad()

    fake_noCNN, fake_noRNN, fake_noCNNRNN = fakes
    for fake in (fake_noCNN, fake_noRNN, fake_noCNNRNN):
        score = realfake_discriminator(fake)
        loss = torch.mean((0.0 - score) ** 2)
        loss.backward()

    # When running for generator gradients, the real input is skipped because
    # its gradient does not flow to the generators.
    if not gradients_for_discriminator:
        return

    score = realfake_discriminator(real)
    # Weighted by 3 to balance the three fake passes above.
    loss = torch.mean(3 * (1.0 - score) ** 2)
    loss.backward()
    

def update_RealFake_discriminator():
    """Apply one step of `realfake_discriminator_optimizer`.

    Uses the gradients currently accumulated on the real/fake discriminator's
    parameters. Returns None.
    """
    apply_step = realfake_discriminator_optimizer.step
    apply_step()


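# Per-batch training schedule. Shorthand used below:
#   forback_AB (flag) = emb_forward_backward_AB_discriminator(..., gradients_for_discriminator=flag)
#   forback_RF (flag) = fakeVoice_AB_forward_backward_RealFake_discriminator(..., gradients_for_discriminator=flag)
#   AB_disc = CNN/RNN embedding discriminators, RF_disc = real/fake discriminator
#
# for realA, realB (A is the source, B the target):
#   real -> embeddings -> fakeAB, fakeAA -> identity_loss
#   forback_AB (True) -> update AB_disc -> forback_RF (True) -> update RF_disc
#   zero the gradients of the encoders and decoders
#   forback_AB (False) -> forback_RF (False) -> identity_loss.backward
#   fakeAB -> Encoder -> DecoderA -> fakeABA -> cycle_loss -> cycle_loss.backward
#   encoder & decoder update
#
# for realB, realA (B is the source, A the target):
#   real -> embeddings -> fakeBA, fakeBB -> identity_loss
#   forback_AB (True) -> update AB_disc -> forback_RF (True) -> update RF_disc
#   zero the gradients of the encoders and decoders
#   forback_AB (False) -> forback_RF (False) -> identity_loss.backward
#   fakeBA -> Encoder -> DecoderB -> fakeBAB -> cycle_loss -> cycle_loss.backward
#   encoder & decoder update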

# Main training loop. Every epoch rebuilds the dataset/DataLoader, then each
# mini-batch is trained in two symmetric passes: A as source (A -> B) and then
# B as source (B -> A). Each pass first updates the discriminators and then the
# encoders/decoders.
for epoch in range(start_epoch, num_epochs):
    start_time_epoch = time.time()

    # Constants
    # NOTE(review): in the visible part of this loop these lambdas are assigned
    # but never multiplied into identityLoss / cycleLoss — confirm whether the
    # weighting was meant to be applied before the .backward() calls.
    cycle_loss_lambda = 10
    identity_loss_lambda = 5
    #if epoch>20:#
        #cycle_loss_lambda = 15#
        #identity_loss_lambda = 0#

    # Preparing Dataset
    n_samples = len(dataset_A)

    # A new trainingDataset is created per epoch (presumably so that each epoch
    # draws fresh 128-frame segments — verify against trainingDataset).
    dataset = trainingDataset(datasetA=dataset_A,
                              datasetB=dataset_B,
                              n_frames=128)
    train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_size=mini_batch_size,
                                               shuffle=True,
                                               drop_last=False)

    for i, (real_A, real_B) in enumerate(train_loader):

        # Approximate global iteration index: batches per epoch are estimated
        # as n_samples // mini_batch_size.
        num_iterations = (
            n_samples // mini_batch_size) * epoch + i
        # print("iteration no: ", num_iterations, epoch)

        # Identity loss is meant to be switched off after 10k iterations
        # (see the note on the lambdas above).
        if num_iterations > 10000:
            identity_loss_lambda = 0
        # NOTE(review): both calls pass `generator_optimizer`; the second one is
        # labelled 'discriminator', so presumably a discriminator optimizer was
        # intended. Also, `generator_optimizer` is not among the encoder_*/
        # decoder_* optimizers stepped below — confirm it exists and is the
        # right object to decay.
        if num_iterations > start_decay:
            adjust_lr_rate(
                generator_optimizer, name='generator')
            adjust_lr_rate(
                generator_optimizer, name='discriminator')

        real_A = real_A.to(device).float()
        real_B = real_B.to(device).float()

        # ========================================
        # for realA, realB:
        # reset_grad is defined elsewhere; presumably it zeroes the gradients of
        # all optimizers — verify.
        reset_grad();

        #   real -> embeddings -> fakeAB, fakeAA
        embedding_A_noCNN = encoder_noCNN(real_A)
        embedding_A_noRNN = encoder_noRNN(real_A)
        embedding_A_noCNNRNN = encoder_noCNNRNN(real_A)
        embeddings = (embedding_A_noCNN, embedding_A_noRNN, embedding_A_noCNNRNN)

        # Conversion A -> B with each variant's B-decoder.
        fakeAB_noCNN = decoder_2B_noCNN(embedding_A_noCNN)
        fakeAB_noRNN = decoder_2B_noRNN(embedding_A_noRNN)
        fakeAB_noCNNRNN = decoder_2B_noCNNRNN(embedding_A_noCNNRNN)
        fakesAB = (fakeAB_noCNN, fakeAB_noRNN, fakeAB_noCNNRNN)

        # Reconstruction A -> A, used for the identity loss below.
        fakeAA_noCNN = decoder_2A_noCNN(embedding_A_noCNN)
        fakeAA_noRNN = decoder_2A_noRNN(embedding_A_noRNN)
        fakeAA_noCNNRNN = decoder_2A_noCNNRNN(embedding_A_noCNNRNN)
        # fakesAA is not read again in the visible part of the loop.
        fakesAA = (fakeAA_noCNN, fakeAA_noRNN, fakeAA_noCNNRNN)

        #   forback_AB (True) -> update AB_disc -> forback_RF (True) -> update RF_disc
        emb_forward_backward_AB_discriminator(embeddings, isA=True, gradients_for_discriminator=True)
        update_AB_discriminator()

        fakeVoice_AB_forward_backward_RealFake_discriminator(fakesAB, real_B, gradients_for_discriminator=True)
        update_RealFake_discriminator()

        #   zerograd encoder decoders (can also zero grad everything since discriminators have been updated)
        reset_grad()

        #   forback_AB (False) -> forback_RF (False) -> identity_loss.backward
        # NOTE(review): these two calls backpropagate over the same
        # `embeddings` / `fakesAB` graph that the discriminator passes above
        # already ran backward on. Unless the helpers detach their inputs in
        # discriminator mode or retain the graph, PyTorch raises on a second
        # backward through a freed graph — confirm.
        emb_forward_backward_AB_discriminator(embeddings, isA=True, gradients_for_discriminator=False)
        fakeVoice_AB_forward_backward_RealFake_discriminator(fakesAB, real_B, gradients_for_discriminator=False)

        # Identity loss: mean L1 distance between real_A and its A -> A
        # reconstruction, averaged over the three variants.
        identityLoss_noCNN = torch.mean(torch.abs(real_A - fakeAA_noCNN))
        identityLoss_noRNN = torch.mean(torch.abs(real_A - fakeAA_noRNN))
        identityLoss_noCNNRNN = torch.mean(torch.abs(real_A - fakeAA_noCNNRNN))
        identityLoss = (identityLoss_noCNN + identityLoss_noRNN + identityLoss_noCNNRNN) / 3.0
        # NOTE(review): cycleLoss below reaches the same encoder graph through
        # fakeAB_*, so this backward may need retain_graph=True — confirm.
        identityLoss.backward()

        #   fakeAB.Encoder.DecoderA -> fakeABA -> cycle_loss -> cycle_loss.backward
        fakeABA_noCNN = decoder_2A_noCNN(encoder_noCNN(fakeAB_noCNN))
        fakeABA_noRNN = decoder_2A_noRNN(encoder_noRNN(fakeAB_noRNN))
        fakeABA_noCNNRNN = decoder_2A_noCNNRNN(encoder_noCNNRNN(fakeAB_noCNNRNN))

        # Cycle loss: mean L1 distance between real_A and A -> B -> A,
        # averaged over the three variants.
        cycleLoss_noCNN = torch.mean(torch.abs(real_A - fakeABA_noCNN))
        cycleLoss_noRNN = torch.mean(torch.abs(real_A - fakeABA_noRNN))
        cycleLoss_noCNNRNN = torch.mean(torch.abs(real_A - fakeABA_noCNNRNN))
        cycleLoss = (cycleLoss_noCNN + cycleLoss_noRNN + cycleLoss_noCNNRNN) / 3.0
        cycleLoss.backward()

        #   encoder & decoder update
        # All encoders and both decoder families are stepped with the gradients
        # accumulated since the last reset_grad().
        encoder_noCNN_optimizer.step()
        encoder_noRNN_optimizer.step()
        encoder_noCNNRNN_optimizer.step()
        decoder_2A_noCNN_optimizer.step()
        decoder_2A_noRNN_optimizer.step()
        decoder_2A_noCNNRNN_optimizer.step()
        decoder_2B_noCNN_optimizer.step()
        decoder_2B_noRNN_optimizer.step()
        decoder_2B_noCNNRNN_optimizer.step()


        # ========================================
        # for realB, realA:
        # Mirror image of the pass above with the roles of A and B swapped.
        reset_grad();

        #   real -> embeddings -> fakeBA, fakeBB
        embedding_B_noCNN = encoder_noCNN(real_B)
        embedding_B_noRNN = encoder_noRNN(real_B)
        embedding_B_noCNNRNN = encoder_noCNNRNN(real_B)
        embeddings = (embedding_B_noCNN, embedding_B_noRNN, embedding_B_noCNNRNN)

        # Conversion B -> A with each variant's A-decoder.
        fakeBA_noCNN = decoder_2A_noCNN(embedding_B_noCNN)
        fakeBA_noRNN = decoder_2A_noRNN(embedding_B_noRNN)
        fakeBA_noCNNRNN = decoder_2A_noCNNRNN(embedding_B_noCNNRNN)
        fakesBA = (fakeBA_noCNN, fakeBA_noRNN, fakeBA_noCNNRNN)

        # Reconstruction B -> B, used for the identity loss below.
        fakeBB_noCNN = decoder_2B_noCNN(embedding_B_noCNN)
        fakeBB_noRNN = decoder_2B_noRNN(embedding_B_noRNN)
        fakeBB_noCNNRNN = decoder_2B_noCNNRNN(embedding_B_noCNNRNN)
        # fakesBB is not read again in the visible part of the loop.
        fakesBB = (fakeBB_noCNN, fakeBB_noRNN, fakeBB_noCNNRNN)

        #   forback_AB (True) -> update AB_disc -> forback_RF (True) -> update RF_disc
        emb_forward_backward_AB_discriminator(embeddings, isA=False, gradients_for_discriminator=True)
        update_AB_discriminator()

        fakeVoice_AB_forward_backward_RealFake_discriminator(fakesBA, real_A, gradients_for_discriminator=True)
        update_RealFake_discriminator()

        #   zerograd encoder decoders (can also zero grad everything since discriminators have been updated)
        reset_grad()

        #   forback_AB (False) -> forback_RF (False) -> identity_loss.backward
        # NOTE(review): same repeated-backward concern as in the A -> B pass —
        # `embeddings` / `fakesBA` were already backpropagated through above.
        emb_forward_backward_AB_discriminator(embeddings, isA=False, gradients_for_discriminator=False)
        fakeVoice_AB_forward_backward_RealFake_discriminator(fakesBA, real_A, gradients_for_discriminator=False)

        # Identity loss for domain B (mean L1, averaged over the three variants).
        identityLoss_noCNN = torch.mean(torch.abs(real_B - fakeBB_noCNN))
        identityLoss_noRNN = torch.mean(torch.abs(real_B - fakeBB_noRNN))
        identityLoss_noCNNRNN = torch.mean(torch.abs(real_B - fakeBB_noCNNRNN))
        identityLoss = (identityLoss_noCNN + identityLoss_noRNN + identityLoss_noCNNRNN) / 3.0
        # NOTE(review): may need retain_graph=True for the cycle loss below — confirm.
        identityLoss.backward()

        #   fakeBA.Encoder.DecoderB -> fakeBAB -> cycle_loss -> cycle_loss.backward
        fakeBAB_noCNN = decoder_2B_noCNN(encoder_noCNN(fakeBA_noCNN))
        fakeBAB_noRNN = decoder_2B_noRNN(encoder_noRNN(fakeBA_noRNN))
        fakeBAB_noCNNRNN = decoder_2B_noCNNRNN(encoder_noCNNRNN(fakeBA_noCNNRNN))

        # Cycle loss for domain B: mean L1 between real_B and B -> A -> B.
        cycleLoss_noCNN = torch.mean(torch.abs(real_B - fakeBAB_noCNN))
        cycleLoss_noRNN = torch.mean(torch.abs(real_B - fakeBAB_noRNN))
        cycleLoss_noCNNRNN = torch.mean(torch.abs(real_B - fakeBAB_noCNNRNN))
        cycleLoss = (cycleLoss_noCNN + cycleLoss_noRNN + cycleLoss_noCNNRNN) / 3.0
        cycleLoss.backward()

        #   encoder & decoder update
        encoder_noCNN_optimizer.step()
        encoder_noRNN_optimizer.step()
        encoder_noCNNRNN_optimizer.step()
        decoder_2A_noCNN_optimizer.step()
        decoder_2A_noRNN_optimizer.step()
        decoder_2A_noCNNRNN_optimizer.step()
        decoder_2B_noCNN_optimizer.step()
        decoder_2B_noRNN_optimizer.step()
        decoder_2B_noCNNRNN_optimizer.step()
        
        