# Imports

In [70]:
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
from tqdm import trange
import itertools
import os

from utility import load_pickle, save_pickle, show_mel

# Resources

Model design and training made following Albadawi 2020, and some open source github repos and tutorials:
- https://github.com/liusongxiang/StarGAN-Voice-Conversion/blob/master/model.py
- https://github.com/pritishyuvraj/Voice-Conversion-GAN/blob/master/model.py
- https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
- https://github.com/Lornatang/CycleGAN-PyTorch/blob/master/cyclegan_pytorch/models.py

# Models

Creating the residual block for the encoder. Instance norm 2d normalises samples with respect to them self rather than all neighbouring samples in a batch (as is done in batch normalization).

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=4, bias=False),
            nn.InstanceNorm2d(dim_out),
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim_out, dim_out, kernel_size=4, bias=False),
            nn.InstanceNorm2d(dim_out))
        
    def forward(self, x):
        return x + self.main(x)

The rational for the encoder is to rid of all the non-timbral related information in its dimensionality reduction to the latent space. The generator then builds up the new voice from the rich latent space.

In [4]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
         
        l = []
        
        # Initial linear convolutional mapping
        l.append(nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=7, bias=False), 
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2)))
        
        # Non-linear mapping convolutional layers
        l.append(nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, bias=False, stride=2), 
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2)))
        
        l.append(nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=4, bias=False, stride=2),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)))
                 
        # Residual blocks with skip connections for bottleneck
        l.append(ResidualBlock(512, 1024))
         
        # OLD: Starting with less residual blocks due to memory
        # l.append(ResidualBlock(512, 1024))
        # l.append(ResidualBlock(1024, 1156))
        # l.append(ResidualBlock(1156, 1280)) 
        
        self.main = nn.Sequential(*l)
        
        
    def forward(self, x):
        return self.main(x)  

Following Albadawi 2020, making the generator the reverse of the encoder for upsampling in a decoder-like manner.

In [5]:
class Generator(nn.Module):
    def __init__(self, sharedResBlock):
        super(Generator, self).__init__()
        
        l = []
                
        # Residual blocks for expanding from latent space       
        l.append(sharedResBlock) 
        
        # OLD: Starting with less residual blocks due to memory
        # l.append(sharedResBlock)  # First res block shared across G
        # l.append(ResidualBlock(1156, 1024))
        # l.append(ResidualBlock(1024, 512)) 
        
        # Non-linear mapping convolutional layers
        l.append(nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, bias=False, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2)))
        
        l.append(nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, bias=False, stride=2), 
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2)))
        
        # Final linear convolutional mapping 
        l.append(nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=7, bias=False), 
            nn.BatchNorm2d(128),
            nn.Tanh()))  # wrt DCGAN
        
        self.main = nn.Sequential(*l)
        
        
    def forward(self, x):
        return self.main(x)             

Discriminator is downsampling to a feature space for classification. Followed DCGAN intuition. Strided convolutions with ReLU based activations, and non strided convolution for the last fifth layer with a Sigmoid activation.

In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        l = []
        
        l.append(nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=4, bias=False, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2)))
        
        l.append(nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=4, bias=False, stride=2),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)))
        
        l.append(nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=4, bias=False, stride=2),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2)))
        
        l.append(nn.Sequential(
            nn.Conv2d(1024, 1280, kernel_size=4, bias=False, stride=2),
            nn.BatchNorm2d(1280),
            nn.LeakyReLU(0.2)))
                
        l.append(nn.Sequential(
            nn.Conv2d(1280, 1, kernel_size=4, bias=False),
            nn.BatchNorm2d(1280),
            nn.Sigmoid()))  # wrt DCGAN
        
        self.main = nn.Sequential(*l)
         
    def forward(self, x):
        return self.main(x)         

# Training

### Hyperparameters

In [9]:
# Training control
max_epochs = 2
max_duplets = 2 
batch_size = 2
learning_rate = 0.0001

# Regularisation
lambda_cycle = 10.0

### Data

Loading data and shuffling it and converting to torch.

In [7]:
# Loading
melset_7_128 = load_pickle('pool/melset_7_128.pickle')
melset_4_128 = load_pickle('pool/melset_4_128.pickle')

In [8]:
rng = np.random.default_rng()

# Shuffling melspectrograms
melset_7_128 = rng.permutation(np.array(melset_7_128))
melset_4_128 = rng.permutation(np.array(melset_4_128))

# Torch conversion
melset_7_128 = torch.from_numpy(melset_7_128)
melset_4_128 = torch.from_numpy(melset_4_128)

### Model Instantiation

Following Albadawi 2020, unlike the other models, there is only one universal encoder in attempt to train it to rid of frequency and loudness information across speakers.

The first residual block of each generator is shared because of the shared latent space assumption coming from the universal encoder. This is to further encourage a general latent space mapping focused on timbral aspect.

In [10]:
device = 'cuda' # torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.set_device(0)

# Shared models
E = Encoder().to(device)
R = ResidualBlock(1024, 512)  # Shared first residual blocks across all Gs

# OLD: Starting with less residual blocks due to memory
# R = ResidualBlock(1280, 1156)

# Generator and Discriminator for Speaker A to B
G_A2B = Generator(R).to(device)
D_B = Discriminator().to(device)

# Generator and Discriminator for Speaker B to A
G_B2A = Generator(R).to(device)
D_A = Discriminator().to(device)

### Training Utilities

Initialise weights from dist with mu=0, s=0.02

In [17]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [19]:
E.apply(weights_init)
R.apply(weights_init)
G_A2B.apply(weights_init)
G_B2A.apply(weights_init)
D_A.apply(weights_init)
D_B.apply(weights_init)
print()




Keeping an experience roleplay buffer for training discriminator on previous iterations. 

In [24]:
class ReplayBuffer:
    def __init__(self, max_size=50):
        assert (max_size > 0), "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return torch.cat(to_return)

In [25]:
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

### Objective Initialisation

MSE instead of BCE for adversarial loss making it a Least Squares GAN. Aims to minimize vanishing gradients by making it less discrete, done following CycleGAN implementation.

In [15]:
# Initialising optimizers
optim_E = torch.optim.Adam(E.parameters(), lr=learning_rate)
optim_R = torch.optim.Adam(R.parameters(), lr=learning_rate)
optim_G = torch.optim.Adam(itertools.chain(G_A2B.parameters(), G_B2A.parameters()),lr=learning_rate)
optim_D_A = torch.optim.Adam(D_A.parameters(), lr=learning_rate)
optim_D_B = torch.optim.Adam(D_B.parameters(), lr=learning_rate)

# Loss functions
loss_adversarial = torch.nn.MSELoss().to(device)
loss_cycle = torch.nn.L1Loss().to(device)

### Training loop

Training loop implemented based on Albadawi 2020 for theory and CycleGAN for implementation

In [74]:
for i in trange(max_epochs):
    for j in range(0, max_duplets, batch_size):
        
        # Loading real samples from each speaker in batches
        real_mel_A = melset_7_128[j:j+batch_size].to(device)
        real_mel_B = melset_4_128[j:j+batch_size].to(device)

        # Real data labelled 1, fake data labelled 0
        batch_size = real_mel_A.size(0)
        real_label = torch.full((batch_size, 1), 1, device=device, dtype=torch.float32)
        fake_label = torch.full((batch_size, 1), 0, device=device, dtype=torch.float32)
        
        # =====================================================
        #            Encoder and Decoding Generators update
        # =====================================================

        # Resetting gradients
        optim_E.zero_grad()
        optim_R.zero_grad()
        optim_G.zero_grad()   

        # Forward pass for B to A
        latent_mel_A = E(real_mel_A)
        post_latent_mel_A= R(latent_mel_A)
        fake_mel_A = G_B2A(post_latent_mel_B)
        fake_output_A = D_A(fake_mel_A)
        recon_mel_A = G_B2A(fake_mel_B)  # reconstuction
        
        # Forward pass for A to B
        latent_mel_B = E(real_mel_B)
        post_latent_mel_B = R(latent_mel_B)
        fake_mel_B = G_A2B(post_latent_mel_B)
        fake_output_B = D_B(fake_mel_B)
        recon_mel_B = G_A2B(fake_mel_A)  # reconstuction
        
        # VAE loss TODO
        
        # GAN loss
        loss_GAN_B2A = loss_adversarial(fake_mel_A, real_label)
        loss_GAN_A2B = loss_adversarial(fake_mel_B, real_label)
        
        # Cyclic loss
        loss_cycle_ABA = cycle_loss(recon_mel_A, real_mel_A) * lambda_cycle
        loss_cycle_BAB = cycle_loss(recon_mel_B, real_mel_B) * lambda_cycle
        
        # Update all  
        loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()
        optim_G.step()
        
        
        # =====================================================
        #                   Discriminators update
        # =====================================================
        
        # Resetting gradients
        optim_D_A.zero_grad()
        optim_D_B.zero_grad()
        
        # ==== Updating D_A ====
        # Real loss
        real_out_A = D_A(real_mel_A)
        loss_D_real_A = adversarial_loss(real_out_A, real_label)
        
        # Fake loss
        fake_mel_A = fake_A_buffer.push_and_pop(fake_mel_A)
        fake_out_A = D_A(fake_mel_A.detach())
        loss_D_fake_A = adversarial_loss(fake_out_A, fake_label)
        
        # Combine and update
        loss_D_A = (loss_D_real_A + loss_D_fake_A) / 2
        D_A.backward()
        optim_D_A.step()
        
        # ==== Updating D_B ====
        # Real loss
        real_out_B = D_B(real_mel_B)
        loss_D_real_B = adversarial_loss(real_out_B, real_label)
        
        # Fake loss
        fake_mel_B = fake_B_buffer.push_and_pop(fake_mel_B)
        fake_out_B = D_B(fake_mel_B.detach())
        loss_D_fake_B = adversarial_loss(fake_out_B, fake_label)
        
        # Combine and update
        loss_D_B = (loss_D_real_B + loss_D_fake_B) / 2
        D_B.backward()
        optim_D_B.step() 
        
        # TODO: Image saving
        
    # TODO: Checkpoint after each epoch

        
# TODO: Save last checkpoint       

  0%|          | 0/2 [00:00<?, ?it/s]


RuntimeError: Expected 4-dimensional input for 4-dimensional weight 512 1024 4 4, but got 3-dimensional input of size [2, 128, 128] instead