In [22]:
!nvidia-smi
%env CUDA_VISIBLE_DEVICES=1
%env CUDA_VISIBLE_DEVICES

Fri Apr 30 12:44:49 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA Tesla V1...  On   | 00000000:15:00.0 Off |                    0 |
| N/A   48C    P0   274W / 300W |  12796MiB / 32510MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA Tesla V1...  On   | 00000000:16:00.0 Off |                    0 |
| N/A   38C    P0    56W / 300W |   6469MiB / 32510MiB |      0%      Default |
|       

'1'

In [3]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [4]:
import torchaudio
import torch
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm
from datetime import datetime
import numpy as np
from WaveNetTTS.model import WaveNet
import os
import random
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [5]:
random.seed(12)
os.environ['PYTHONHASHSEED'] = str(12)
np.random.seed(12)
torch.manual_seed(12)
torch.cuda.manual_seed(12)

In [6]:
sp_freq = 4000
seq_len = 4000
bins = 128
batch_size = 58
channels = 256
kernel_size = 2
dilation_depth = 9
blocks = 2
condition_size = 256

MuLawEncoding = torchaudio.transforms.MuLawEncoding(quantization_channels=bins)
Resample = torchaudio.transforms.Resample(22050, sp_freq)

hugging_face_model = 'bert-base-uncased'#'distilbert-base-uncased'#
tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', hugging_face_model)

Using cache found in /zhome/22/c/137477/.cache/torch/hub/huggingface_pytorch-transformers_master


In [7]:
from torch.cuda.amp.grad_scaler import GradScaler
from torch.cuda.amp.autocast_mode import autocast

scaler = GradScaler()
from transformers import AdamW

model = WaveNet(quantization_bins=bins, kernel_size=kernel_size, channels=channels, dilation_depth=dilation_depth, blocks=blocks, condition_size=condition_size, global_condition=True, local_condition=True)
model = model.to(device)

no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [{
    'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and not 'bert' in n],
    'weight_decay': 0.01,
    'lr': 1e-4
}, {
    'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and not 'bert' in n],
    'weight_decay': 0.0,
    'lr': 1e-4
}, {
    'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and 'bert' in n],
    'weight_decay': 0.01,
    'lr': 5e-5
}, {
    'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and 'bert' in n],
    'weight_decay': 0.0,
    'lr': 5e-5
}]
optim = AdamW(optimizer_grouped_parameters, correct_bias=False, eps=1e-8)

criterion = torch.nn.CrossEntropyLoss()

losses = []

Using cache found in /zhome/22/c/137477/.cache/torch/hub/huggingface_pytorch-transformers_master


In [8]:
print("Trainable parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

Trainable parameters: 123869568


In [9]:
def collate_fn(batch_in):
    y_true_out = []
    gc_out = []
    lc_out = []
    for waveform, _, _, transcript in batch_in:
        #Mu Law Encoding of waveform
        y_true = MuLawEncoding(Resample(waveform[0])).to(device)
        #Trim random segment of the waveform with length seq_len
        random_idx = np.random.randint(len(y_true)-seq_len)
        y_true_trim = y_true[random_idx:random_idx+seq_len]

        y_true_out.append(y_true_trim)

        #Tokenize the transcript with the BERT tokenizer
        tokens = tokenizer(transcript, return_attention_mask=False, return_token_type_ids=False,return_tensors='pt')['input_ids'].to(device)

        #Feed into sentence embedding class
        gc_embed, lc_embed = model.sentence_embedding(tokens)

        gc_out.append(gc_embed)

        #Interpolate the locally conditioned signal from BERT so it fits with the waveform size and then trim the same portion of the signal as for the waveform.
        lc_embed = F.interpolate(lc_embed, size=waveform.size(1))[:,:,random_idx:random_idx+seq_len]
        lc_out.append(lc_embed)
    return torch.stack(y_true_out,0), torch.cat(gc_out, 0), torch.cat(lc_out, 0)

In [10]:
dataset = torchaudio.datasets.LJSPEECH('', download=False)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [None]:
model.train()
while True:
    with tqdm(iter(dataloader)) as t_bar:
        for y_true, gc, lc in t_bar:
            optim.zero_grad()
            
            y_true, gc, lc = y_true.to(device), gc.to(device), lc.to(device)
            # Model predictions
            y_preds = model(y_true, gc=gc, lc=lc)

            # Calculates loss. The whole indexation show is just to align predictions with the true values.
            loss = criterion(y_preds[:, :, :-1], y_true[:, -y_preds.size(2)+1:])
            
            
            #scaler.scale(loss).backward()
            #scaler.step(optim)
            #scaler.update()
            
            loss.backward()
            optim.step()
            
            # Updates
            losses.append(loss.item())
            
            
            t_bar.set_postfix_str(f'Loss: {loss.item()}, Receptive Field: {model.receptive_field}, Learned Size: {y_preds.size(2)}')
    torch.save({'model':model.state_dict(), 'optim':optim.state_dict()}, f'LJ_speech_WaveNet_{datetime.now().strftime("%d-%m-%Y")}-seq_L{seq_len}-bins{bins}-batch{batch_size}-C{channels}-k{kernel_size}-dil{dilation_depth}b{blocks}-cs{condition_size}-sp_freq{sp_freq}.pt')

100%|██████████| 226/226 [13:18<00:00,  3.53s/it, Loss: 2.6827635765075684, Receptive Field: 1022, Learned Size: 2978]
100%|██████████| 226/226 [13:17<00:00,  3.53s/it, Loss: 2.506441831588745, Receptive Field: 1022, Learned Size: 2978] 
100%|██████████| 226/226 [13:20<00:00,  3.54s/it, Loss: 2.4744057655334473, Receptive Field: 1022, Learned Size: 2978]
100%|██████████| 226/226 [13:40<00:00,  3.63s/it, Loss: 2.405155658721924, Receptive Field: 1022, Learned Size: 2978] 
 90%|████████▉ | 203/226 [11:56<01:21,  3.53s/it, Loss: 2.4261863231658936, Receptive Field: 1022, Learned Size: 2978]