# Set up environment

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

%cd /content/gdrive/MyDrive/Voice_coversion

import torch
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))

Mounted at /content/gdrive
/content/gdrive/MyDrive/Voice_coversion
True
Tesla P100-PCIE-16GB


In [2]:
import datetime
import os
import time

import torch
import torch.nn.functional as F

import data_loader
from model import Generator

In [3]:
# hyper_parameters

# model: forwarded positionally to Generator(dim_neck, dim_emb, dim_pre, freq)
dim_neck = 32
dim_emb = 256
dim_pre = 512
freq = 32

# training
lambda_cd = 1            # weight of the content-code (L1) term in the total loss
learning_rate = 0.0001   # Adam learning rate
batch_size = 2           # passed to data_loader.get_loader
num_iter = 100000        # the loop runs while i <= num_iter, i.e. num_iter + 1 steps starting from i = 0
len_crop = 128           # passed to data_loader.get_loader; presumably the crop length in mel frames -- TODO confirm

# logging: print the losses every log_step iterations
log_step = 100

# checkpointing: save a checkpoint every save_rate iterations
save_rate = 10000

# directory that training checkpoints are saved into
checkpoint_dir = 'checkpoints/AutoVC_custom_16khz'

In [4]:
# Training data directory; presumably precomputed 16 kHz mel-spectrograms (per the
# 'spmel' name) -- TODO confirm against the preprocessing script.
data_path = 'spmel_16khz'

# Batches of (spectrogram, speaker embedding) pairs, as unpacked by the training loop.
loader = data_loader.get_loader(
  data_path,
  batch_size,
  len_crop,
)

Number of processes: 3
Finish loading spearker p225
Finish loading spearker p247
Finish loading spearker p236
Finish loading spearker p226
Finish loading spearker p237
Finish loading spearker p248
Finish loading spearker p227
Finish loading spearker p249
Finish loading spearker p238
Finish loading spearker p228
Finish loading spearker p250
Finish loading spearker p229
Finish loading spearker p239
Finish loading spearker p251
Finish loading spearker p230
Finish loading spearker p240
Finish loading spearker p252
Finish loading spearker p241
Finish loading spearker p231
Finish loading spearker p253
Finish loading spearker p243
Finish loading spearker p232
Finish loading spearker p254
Finish loading spearker p244
Finish loading spearker p233
Finish loading spearker p255
Finish loading spearker p245
Finish loading spearker p234
Done joining process
Finish loading spearker p256
Finish loading spearker p246
Done joining process
Done joining process
Finished loading the dataset...


In [5]:
import torch.backends.cudnn as cudnn

# Let cuDNN benchmark candidate algorithms and cache the fastest one. This pays off
# when input shapes stay constant between steps (presumably the case here, given the
# fixed batch_size / len_crop -- confirm the loader always yields full-size batches).
cudnn.benchmark = True

In [6]:
# Build the generator and move it to the GPU *before* creating the optimizer: the
# torch.optim docs recommend placing a model on its device first, so the optimizer
# is constructed over the parameters that will actually be trained.
generator = Generator(dim_neck, dim_emb, dim_pre, freq).to(torch.device("cuda"))
optimizer = torch.optim.Adam(generator.parameters(), learning_rate)

In [7]:
# Warm-start from the pretrained AutoVC checkpoint when it exists.
# Only the model weights (its 'model' key) are restored; optimizer state and the
# iteration count are not resumed, so training starts from i = 0 either way.
# (Checkpoints written by the training loop use the keys 'model_state_dict',
# 'optimizer_state_dict' and 'iter' instead.)
i = 0
try:
  checkpoint = torch.load('checkpoints/AutoVC/autovc.ckpt', map_location='cuda:0')
except FileNotFoundError:
  # A missing pretrained file is an expected case: keep the fresh initialization.
  print('Pretrained checkpoint not found; training from scratch.')
else:
  # Errors here (e.g. mismatched state-dict keys) are real problems and must surface.
  generator.load_state_dict(checkpoint['model'])

In [10]:
# Explicitly (re)load the pretrained weights onto cuda:0; the last expression displays
# the load_state_dict result so key mismatches are visible in the cell output.
checkpoint = torch.load(
  'checkpoints/AutoVC/autovc.ckpt',
  map_location='cuda:0',
)
generator.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [17]:
# The training loop continues from whatever `i` currently holds; reset it so that
# the next run starts from iteration 0.
start_iter = 0
i = start_iter

In [18]:
# The main training loop.
# Continues from the current value of `i` and runs while i <= num_iter. Each step feeds
# emb_org as both the source and the target embedding, so every loss term compares the
# model's reconstruction against its own input.
device = torch.device("cuda")
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create missing directories

# Explicit iterator over the loader; it is re-created whenever an epoch is exhausted.
data_iter = iter(loader)

# Nothing in the loop switches to eval mode, so training mode only needs setting once.
generator.train()

while i <= num_iter:
  # data preparation: cycle through the loader indefinitely
  try:
    x_real, emb_org = next(data_iter)
  except StopIteration:
    data_iter = iter(loader)
    x_real, emb_org = next(data_iter)

  x_real = x_real.to(device)
  emb_org = emb_org.to(device)

  # forward: decoder output, its refined version (presumably the post-net output,
  # per the `_psnt` suffix), and the content codes of the input
  x_identic, x_identic_psnt, code_real = generator(x_real, emb_org, emb_org)
  x_identic = torch.squeeze(x_identic, 1)
  x_identic_psnt = torch.squeeze(x_identic_psnt, 1)

  # Decoder loss: reconstruction error against the input
  g_loss_id = F.mse_loss(x_real, x_identic)
  g_loss_id_psnt = F.mse_loss(x_real, x_identic_psnt)

  # Encoder loss: re-encode the reconstruction. The result is compared directly with
  # code_real, so with a None target the generator presumably returns only the codes.
  code_reconst = generator(x_identic_psnt, emb_org, None)
  g_loss_cd = F.l1_loss(code_real, code_reconst)

  # Total loss
  loss = g_loss_id + g_loss_id_psnt + lambda_cd * g_loss_cd

  # back propagation
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  # print the losses every log_step iterations
  if i % log_step == 0:
    print('Iteration: {}/{}, reconst_loss: {}, reconst_post_loss: {}, encoder_loss: {}'.format(
        i, num_iter, g_loss_id.item(), g_loss_id_psnt.item(), g_loss_cd.item()))

  # save a checkpoint every save_rate iterations
  if i % save_rate == 0:
    torch.save({
        'iter': i,
        'model_state_dict': generator.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # store a plain float rather than the GPU loss tensor attached to the graph
        'loss': loss.item(),
    }, os.path.join(checkpoint_dir, 'autovc_{}.ckpt'.format(i)))
  i += 1

Iteration: 0/100000, reconst_loss: 0.004596559796482325, reconst_post_loss: 0.004610959440469742, encoder_lost: 0.0015303650870919228
Iteration: 100/100000, reconst_loss: 0.0030123249161988497, reconst_post_loss: 0.0029735432472079992, encoder_lost: 0.0016755943652242422
Iteration: 200/100000, reconst_loss: 0.002097021322697401, reconst_post_loss: 0.0021038358099758625, encoder_lost: 0.001157688908278942
Iteration: 300/100000, reconst_loss: 0.003510294947773218, reconst_post_loss: 0.003473800141364336, encoder_lost: 0.0017209912184625864
Iteration: 400/100000, reconst_loss: 0.004250830039381981, reconst_post_loss: 0.004243754781782627, encoder_lost: 0.0017797271721065044
Iteration: 500/100000, reconst_loss: 0.002108479617163539, reconst_post_loss: 0.0020903346594423056, encoder_lost: 0.0010746745392680168
Iteration: 600/100000, reconst_loss: 0.0034026135690510273, reconst_post_loss: 0.0033722773659974337, encoder_lost: 0.0011233342811465263
Iteration: 700/100000, reconst_loss: 0.002902

In [11]:
# Losses from the most recent training step (g_loss_* are left in scope by the training loop).
print('Reconst_loss: {}, reconst_post_loss: {}, encoder_loss: {}'.format(
    g_loss_id.item(), g_loss_id_psnt.item(), g_loss_cd.item()))

Reconst_loss: 0.02399316616356373, reconst_post_loss: 0.023965515196323395, encoder_lost: 0.00027673013391904533
