# Set up environment

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

%cd /content/gdrive/MyDrive/Voice_coversion

import torch
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))

In [2]:
from model import Generator
import torch
import torch.nn.functional as F
import time
import datetime
import data_loader

In [3]:
# hyper_parameters

# model: positional arguments to Generator(dim_neck, dim_emb, dim_pre, freq).
# NOTE(review): AutoVC naming suggests bottleneck size, speaker-embedding size,
# decoder pre-net size and bottleneck down-sampling factor -- confirm in model.py.
dim_neck = 32
dim_emb = 256
dim_pre = 512
freq = 32

# training
lambda_cd = 1           # weight of the encoder (content-code) L1 loss in the total loss
learning_rate = 0.0001  # Adam learning rate
batch_size = 2          # passed to data_loader.get_loader
num_iter = 100000       # loop runs while i <= num_iter, i.e. iterations 0..num_iter inclusive
len_crop = 128          # passed to data_loader.get_loader; presumably crop length in frames -- confirm

# log: print the losses every `log_step` iterations
log_step = 100

# checkpointing: save a checkpoint every `save_rate` iterations
save_rate = 10000

# directory where training checkpoints (autovc_<iteration>.ckpt) are written and resumed from
checkpoint_dir = 'checkpoints/AutoVC_custom_16khz'

In [None]:
# Training data directory (presumably precomputed 16 kHz mel-spectrograms + speaker
# embeddings -- the loader yields (x_real, emb_org) pairs in the training loop).
data_path = 'spmel_16khz'

# Positional arguments: data directory, batch size, crop length.
loader = data_loader.get_loader(
    data_path,
    batch_size,
    len_crop,
)

In [5]:
import torch.backends.cudnn as cudnn

# Let cuDNN time candidate convolution algorithms and reuse the fastest one;
# this helps when input shapes stay constant (presumably true here, since the
# loader crops to len_crop -- confirm).
cudnn.benchmark = True

In [6]:
# Build the AutoVC generator and move it to the GPU *before* constructing the
# optimizer, as the PyTorch optimizer docs recommend, so the optimizer is built
# over the device-resident parameters.
generator = Generator(dim_neck, dim_emb, dim_pre, freq).to(torch.device("cuda"))
optimizer = torch.optim.Adam(generator.parameters(), learning_rate)

In [7]:
# Resume from the most recent checkpoint in `checkpoint_dir`, if there is one.
# The training loop writes autovc_<i>.ckpt *after* iteration i's optimizer step,
# so training resumes at iteration i + 1. When no checkpoint exists, start at 0.
# Errors while loading an existing checkpoint (corrupt file, key mismatch) are
# deliberately NOT swallowed: silently restarting would overwrite saved progress.
import os
import re

ckpt_iters = []
if os.path.isdir(checkpoint_dir):
  for fname in os.listdir(checkpoint_dir):
    match = re.fullmatch(r'autovc_(\d+)\.ckpt', fname)
    if match:
      ckpt_iters.append(int(match.group(1)))

if ckpt_iters:
  ckpt_path = os.path.join(checkpoint_dir, 'autovc_{}.ckpt'.format(max(ckpt_iters)))
  checkpoint = torch.load(ckpt_path, map_location=torch.device('cuda'))
  generator.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  i = checkpoint['iter'] + 1  # iteration `iter` already completed its step
  loss = checkpoint['loss']
  print('Resumed from {}; next iteration: {}'.format(ckpt_path, i))
else:
  i = 0

In [10]:
"""
    Only run this cell if you want to check if your model is correct
    This cell will load the pre-trained model
"""
checkpoint = torch.load('checkpoints/AutoVC/autovc.ckpt', map_location='cuda:0')
generator.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [None]:
# The main training loop: self-reconstruction training, continuing from iteration `i`
# (set by the resume cell). emb_org is passed as both embedding arguments.
import os

device = torch.device("cuda")
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create missing directories

# Explicitly start a pass over the data instead of relying on a NameError on first use.
data_iter = iter(loader)

# Nothing in the loop switches to eval mode, so training mode is set once.
generator.train()

while i <= num_iter:
  # data preparation: start a new pass when the loader is exhausted. Only
  # StopIteration is caught so data errors and KeyboardInterrupt still surface.
  try:
    x_real, emb_org = next(data_iter)
  except StopIteration:
    data_iter = iter(loader)
    x_real, emb_org = next(data_iter)

  x_real = x_real.to(device)
  emb_org = emb_org.to(device)

  # forward
  x_identic, x_identic_psnt, code_real = generator(x_real, emb_org, emb_org)
  # Drop dim 1 so the outputs line up with x_real for the losses
  # (presumably a singleton channel dim -- confirm against model.py).
  x_identic = torch.squeeze(x_identic, 1)
  x_identic_psnt = torch.squeeze(x_identic_psnt, 1)

  # Decoder (reconstruction) losses; "psnt" is presumably the post-net output.
  g_loss_id = F.mse_loss(x_real, x_identic)
  g_loss_id_psnt = F.mse_loss(x_real, x_identic_psnt)

  # Encoder loss: re-encode the reconstruction and compare with the code of the
  # real input (the generator is called with None as the last argument, which
  # presumably returns only the content code -- see model.py).
  code_reconst = generator(x_identic_psnt, emb_org, None)
  g_loss_cd = F.l1_loss(code_real, code_reconst)

  # Total loss
  loss = g_loss_id + g_loss_id_psnt + lambda_cd * g_loss_cd

  # back propagation
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  # Log scalar values (.item()) rather than whole CUDA tensors with their grad_fn.
  if i % log_step == 0:
    print('Iteration: {}/{}, reconst_loss: {}, reconst_post_loss: {}, encoder_loss: {}'.format(
        i, num_iter, g_loss_id.item(), g_loss_id_psnt.item(), g_loss_cd.item()))

  # Save after this iteration's step; the resume cell continues from i + 1.
  if i % save_rate == 0:
    torch.save({
        'iter': i,
        'model_state_dict': generator.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),  # plain float: no autograd graph or GPU needed to load
    }, os.path.join(checkpoint_dir, 'autovc_{}.ckpt'.format(i)))
  i += 1

In [None]:
# Report the losses from the most recent training step. If the training loop did
# not run in this kernel session (e.g. training had already finished when resuming),
# the loss variables do not exist, so say so instead of raising NameError.
try:
  print('Reconst_loss: {}, reconst_post_loss: {}, encoder_loss: {}'.format(
      g_loss_id.item(), g_loss_id_psnt.item(), g_loss_cd.item()))
except NameError:
  print('No training step has run in this session; no losses to report.')