In [51]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np
from vqvae import VQVAE
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn as nn 
import os
from vctk_dataset import VCTK_Dataset
import gc

with_gpu = torch.cuda.is_available()

if with_gpu:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
print('We are now using %s.' % device)

train_dataset_end_idx = 35121 # About 80% of the data with completely different Speakers in train and val

dataset_train = VCTK_Dataset(segment_length=int(2**16) - 256, n_speakers=128, start_idx=0, end_idx=train_dataset_end_idx, root=r"C:\Users\JadHa\Desktop\Uni\Audio SP\VoiceConversion", download=False) 
dataset_val = VCTK_Dataset(segment_length=int(2**16) - 256, n_speakers=128, start_idx=train_dataset_end_idx, end_idx=train_dataset_end_idx+512, root=r"C:\Users\JadHa\Desktop\Uni\Audio SP\VoiceConversion", download=False)

print("Segmented train dataset size : %d"%len(dataset_train))
print("Segmented validation dataset size : %d"%len(dataset_val))

vqvae = VQVAE(in_channel=1).to(device)

print(sum(p.numel() for p in vqvae.parameters() if p.requires_grad))

optimizer = optim.AdamW(params=vqvae.parameters(), lr=3e-4)

criterion = nn.MSELoss()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
We are now using cuda.
Segmented train dataset size : 42595
Segmented validation dataset size : 599
3027329


In [53]:
train_loader = DataLoader(dataset_train, batch_size=256, shuffle=True)
val_loader = DataLoader(dataset_val, batch_size=256, shuffle=True)

In [30]:
torch.cuda.empty_cache()
gc.collect()
sample = next(iter(train_loader))
print(sample[0][0].shape)
print(sample[1][0].shape)

torch.Size([1, 80, 128])
torch.Size([128])


In [7]:
from speechbrain.pretrained import HIFIGAN

# Load a pretrained HIFIGAN Vocoder
hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-libritts-16kHz", savedir="vocoder_16khz", run_opts={"device":"cuda"})

In [10]:
chkpoint = torch.load(os.path.join("saved_models", "vqvae_vctk_amp_clip1_2.pt"))
vqvae = VQVAE(in_channel=1).to(device)
vqvae.load_state_dict(chkpoint["model_state_dict"])
vqvae.eval();

In [117]:
print(np.unique(train_loader.dataset.speaker_onehot))

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
 73 74 75 76 77 78 79 80 81 82 83 84 85 86]


In [138]:
from IPython.display import Audio, display
torch.cuda.empty_cache()
gc.collect()
vqvae.eval()
input, speaker_id = val_loader.dataset[100]
print(speaker_id)

recon_input = vqvae(input.to(device).unsqueeze(0), speaker_id.unsqueeze(0).to(device))
speaker_id_new = torch.tensor(1).int()
recon_input_vc = vqvae(input.to(device).unsqueeze(0), speaker_id_new.unsqueeze(0).to(device))

wav = hifi_gan.decode_batch(input[0, :, :].cpu())
wav_recon = hifi_gan.decode_batch(recon_input[0, :, :].cpu())
wav_recon_vc = hifi_gan.decode_batch(recon_input_vc[0, :, :].cpu())

display(Audio(wav.squeeze().cpu().numpy(), rate=16000))
display(Audio(wav_recon.squeeze().cpu().numpy(), rate=16000))
display(Audio(wav_recon_vc.squeeze().cpu().numpy(), rate=16000))

tensor(87)


In [128]:
%load_ext autoreload
%autoreload 2

from vqvae import VQVAE

vqvae = VQVAE(in_channel=1, channel=64, embed_dim=64).to(device)
optimizer = optim.AdamW(params=vqvae.parameters(), lr=3e-4)
scaler = torch.cuda.amp.GradScaler()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [129]:
chkpoint = torch.load(os.path.join("saved_models", "vqvae_vctk_amp_clip1_2.pt"))
vqvae.load_state_dict(chkpoint["model_state_dict"])
optimizer.load_state_dict(chkpoint["optimizer"])
scaler.load_state_dict(chkpoint["scaler"])

In [130]:
epochs = 10000
for epoch in range(epochs):
    epoch_loss = 0
    vqvae.train()
    for batch_idx, (input, speaker_id) in enumerate(train_loader):
        torch.cuda.empty_cache()
        gc.collect()
        optimizer.zero_grad()
        input = input.to(device)
        speaker_id = speaker_id.to(device).int()
        with torch.autocast(device_type='cuda'):
            output = vqvae(input, speaker_id)
            loss = criterion(output, input)
        epoch_loss += loss

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(vqvae.parameters(), max_norm=1) # TODO Try different values and Scheduler
        scaler.step(optimizer)
        scaler.update()
    print("Epoch %d , Training Loss : %.2f " %(epoch + 1, epoch_loss.detach().cpu()))
    # Validation
    vqvae.eval()
    epoch_loss = 0
    with torch.no_grad():
        for batch_idx, (input, speaker_id) in enumerate(val_loader):
            input = input.to(device)
            speaker_id = speaker_id.to(device).int()
            with torch.autocast(device_type='cuda'):
                output = vqvae(input, speaker_id)
                loss = criterion(output, input)
            epoch_loss += loss
    print("Epoch %d , Validation Loss : %.2f " %(epoch + 1, epoch_loss.detach().cpu()))
    checkpoint = {"model_state_dict": vqvae.state_dict(),
              "optimizer": optimizer.state_dict(),
              "scaler": scaler.state_dict()}
    torch.save(checkpoint, os.path.join("saved_models", "vqvae_vctk_amp_clip1_3.pt"))
    print('VQ-VAE is stored at folder:{}'.format('saved_models/'+'vqvae_vctk_amp_clip1_3.pt'))

Epoch 1 , Training Loss : 36.60 
Epoch 1 , Validation Loss : 1.44 
VQ-VAE is stored at folder:saved_models/vqvae_vctk_amp_clip1_2.pt
Epoch 2 , Training Loss : 36.50 
Epoch 2 , Validation Loss : 1.49 
VQ-VAE is stored at folder:saved_models/vqvae_vctk_amp_clip1_2.pt
Epoch 3 , Training Loss : 36.53 
Epoch 3 , Validation Loss : 1.41 
VQ-VAE is stored at folder:saved_models/vqvae_vctk_amp_clip1_2.pt
Epoch 4 , Training Loss : 36.33 
Epoch 4 , Validation Loss : 1.45 
VQ-VAE is stored at folder:saved_models/vqvae_vctk_amp_clip1_2.pt
Epoch 5 , Training Loss : 36.29 
Epoch 5 , Validation Loss : 1.40 
VQ-VAE is stored at folder:saved_models/vqvae_vctk_amp_clip1_2.pt
Epoch 6 , Training Loss : 36.28 
Epoch 6 , Validation Loss : 1.43 
VQ-VAE is stored at folder:saved_models/vqvae_vctk_amp_clip1_2.pt
Epoch 7 , Training Loss : 36.33 
Epoch 7 , Validation Loss : 1.39 
VQ-VAE is stored at folder:saved_models/vqvae_vctk_amp_clip1_2.pt
Epoch 8 , Training Loss : 36.18 
Epoch 8 , Validation Loss : 1.38 
VQ

KeyboardInterrupt: 