In [1]:
import torch
import numpy as np
from vqvae import VQVAE
from torchaudio.datasets import VCTK_092
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn as nn 
import os

with_gpu = torch.cuda.is_available()

if with_gpu:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    
print('We are now using %s.' % device)

dataset = VCTK_092(root=r"C:\Users\JadHa\Desktop\Uni\VoiceConversion", download=False)

vqvae = VQVAE(in_channel=1).to(device)

print(sum(p.numel() for p in vqvae.parameters() if p.requires_grad))

optimizer = optim.Adam(params=vqvae.parameters(), lr=3e-4)

criterion = nn.CrossEntropyLoss()

speaker_id_list = dataset._speaker_ids


We are now using cuda.
496640


In [2]:
from wavenet_model import quantize_data
dataset_train = torch.utils.data.Subset(dataset, np.arange(35121))
dataset_val = torch.utils.data.Subset(dataset, np.arange(35121, len(dataset)))
segment_length = int(2**15)

def my_collate(batch):
    wav_data = []
    speaker_ids = []
    labels = []
    for wav, _, _, speaker_id, _ in batch :
        wav_extended_len = int(segment_length * np.ceil(wav.shape[1]/segment_length))
        wav_extended = torch.zeros((1,wav_extended_len))
        wav_extended[:, :wav.shape[1]] = wav / wav.abs().max()
        wav_extended = torch.stack(torch.split(wav_extended, split_size_or_sections=segment_length, dim=1), dim=0)
        wav_data.append(wav_extended)
        labels.append(torch.from_numpy(quantize_data(wav_extended, 256)).squeeze())
        speaker_id_onehot = torch.zeros((wav_extended.shape[0], 128))
        speaker_id_onehot[:, speaker_id_list.index(speaker_id)] = 1
        speaker_ids.append(speaker_id_onehot)
    return [torch.cat(wav_data, dim=0), torch.cat(labels, dim=0), torch.cat(speaker_ids, dim=0)]

train_loader = DataLoader(dataset_train, batch_size=2, shuffle=True, collate_fn=my_collate)
val_loader = DataLoader(dataset_val, batch_size=2, shuffle=True, collate_fn=my_collate)

In [3]:
sample = next(iter(train_loader))
print(sample[0].shape)
print(sample[1].shape)
print(sample[2].shape)

torch.Size([15, 1, 32768])
torch.Size([15, 32768])
torch.Size([15, 128])


In [5]:
import gc
epochs = 10
vq_loss_weight = 0.25
for epoch in range(epochs):
    epoch_loss = 0
    vqvae.train()
    for batch_idx, (input, quantized_input, speaker_id) in enumerate(train_loader):
        torch.cuda.empty_cache()
        gc.collect()
        optimizer.zero_grad()
        input = input.to(device)
        quantized_input = quantized_input.to(device)
        speaker_id = speaker_id.to(device)
        output, vq_loss = vqvae(input, speaker_id)
        recon_loss = criterion(output, quantized_input)
        vq_loss = vq_loss.mean()
        loss = recon_loss + vq_loss_weight * vq_loss
        epoch_loss += loss.detach().cpu()
        loss.backward()
        optimizer.step()
    print("Epoch %d , Training Loss : %.2f " %(epoch + 1, epoch_loss))
    # Validation
    vqvae.eval()
    epoch_loss = 0
    with torch.no_grad():
        for batch_idx, (input, quantized_input, speaker_id) in enumerate(val_loader):
            input = input.to(device)
            quantized_input = quantized_input.to(device)
            speaker_id = speaker_id.to(device)
            output, vq_loss = vqvae(input, speaker_id)
            recon_loss = criterion(output, quantized_input)
            vq_loss = vq_loss.mean()
            loss = recon_loss + vq_loss_weight * vq_loss
            epoch_loss += loss.detach().cpu()
    print("Epoch %d , Validation Loss : %.2f " %(epoch + 1, epoch_loss))
    chkpoint = {'model_state_dict': vqvae.state_dict()}
    torch.save(chkpoint, os.path.join("saved_models", "vqvae_vctk.pt"))
    print('VQ-VAE is stored at folder:{}'.format('saved_models/'+'vqvae_vctk.pt'))

OutOfMemoryError: CUDA out of memory. Tried to allocate 576.00 MiB. GPU 0 has a total capacty of 6.00 GiB of which 0 bytes is free. Of the allocated memory 4.52 GiB is allocated by PyTorch, and 323.00 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF