In [1]:
!pip install pyworld inflect tgt einops

Collecting pyworld
  Downloading pyworld-0.3.4.tar.gz (251 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m252.0/252.0 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting inflect
  Downloading inflect-7.0.0-py3-none-any.whl (34 kB)
Collecting tgt
  Downloading tgt-1.4.4.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hCollecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: pyworld, tgt
  Building wheel for pyworld (pyproject.toml) ... [?25ldone
[?25h  Created wheel for pyworld: filename=pyworld-0.3.4-cp310-cp310-linux_x86_64.whl size=202816 sha256=fce0ed9632444b3d9f016a4d9f2ecc532b868c240e013095c9

In [2]:
## Standard libraries
import os
import numpy as np
import pandas as pd 
import random
import math
import json
from functools import partial
from PIL import Image
import wandb
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

## tqdm for loading bars
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
from pytorch_lightning.loggers import WandbLogger
wandb_api_key = '430e8c7ef92cf79a3d7c3d02e3d961257153181f'
os.environ["WANDB_API_KEY"] = wandb_api_key

import sys; sys.path.insert(0, '/..')
os.chdir(os.path.join(os.getcwd(), 'masters-final'))


from text import _clean_text
from text import _symbol_to_id

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


In [3]:
from audio.tools import inv_mel_spec
from audio.stft import TacotronSTFT

In [4]:
# # Run the following if it's the first time running in this kernel
# # !git clone https://github.com/SitholeDavid/masters-final
!git pull

Already up to date.


In [5]:
# import librosa.display
# mel_spec = torch.from_numpy(np.load('./preprocessed_data/LJSpeech/mel/LJSpeech-mel-LJ048-0114.npy'))
# # mel_spec = torch.from_numpy(np.load('/kaggle/working/masters-final/preprocessed_data/SpanishSingleSpeaker/mel/SpanishSingleSpeaker-mel-19demarzo_0022.npy'))
# librosa.display.specshow(mel_spec.T.numpy())
# stft = TacotronSTFT(filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, sampling_rate=22050, mel_fmin=0, mel_fmax=8000)
# inv_mel_spec(mel_spec.T, 'file.wav', stft)

### Prepare Spanish Dataset

In [6]:
# !python3 prepare_align.py config/SpanishSingleSpeaker/preprocess.yaml
# !python3 preprocess.py config/SpanishSingleSpeaker/preprocess.yaml

### Dataset and Loaders

In [6]:
class SpanishDataset(Dataset):
    def __init__(self,  input_dir, input_file, max_src_len, max_trg_len):
        self.input_dir = input_dir
        self.max_src_len = max_src_len
        self.max_trg_len = max_trg_len
        
        files = pd.read_csv(os.path.join(input_dir, input_file), sep='|', header=None)
        files.columns = ['file', 'speaker', 'phones', 'text'] 
        
        # only consider valid phoneme sequences
        for index, row in tqdm(files.iterrows(), total=len(files)):
            if not all(symbol in _symbol_to_id for symbol in self._process_phones(row['phones'])):
                files.drop(index, inplace=True)

        self.file_names = files

    def _process_phones(self, phones):
        phones = phones.replace('{', '').replace('}', '').strip().split(' ')
        mapped_phones = [f'@{phone}' for phone in phones]
        return mapped_phones
        
    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        file_name = self.file_names.iloc[idx]['file']
        
        mel = np.load(os.path.join(self.input_dir, 'mel', f'SpanishSingleSpeaker-mel-{file_name}.npy'))
        duration = np.load(os.path.join(self.input_dir, 'duration', f'SpanishSingleSpeaker-duration-{file_name}.npy'))
        energy = np.load(os.path.join(self.input_dir, 'energy', f'SpanishSingleSpeaker-energy-{file_name}.npy'))
        pitch = np.load(os.path.join(self.input_dir, 'pitch', f'SpanishSingleSpeaker-pitch-{file_name}.npy'))
        phones = self._process_phones(self.file_names.iloc[idx]['phones'])
        
        phone_mapping = torch.tensor([ _symbol_to_id[symbol] for symbol in phones ])

        src_len = torch.tensor(len(phones))
        trg_len = torch.tensor(mel.shape[0])

        phoneme_pad_length = self.max_src_len - src_len
        mel_pad_length = self.max_trg_len - trg_len
        
        phone_mapping = F.pad(phone_mapping, (0, phoneme_pad_length), mode='constant', value=0)
        duration = F.pad( torch.tensor(duration), (0, phoneme_pad_length), mode='constant', value=0)
        energy = F.pad( torch.tensor(energy), (0, phoneme_pad_length), mode='constant', value=0)
        pitch = F.pad( torch.tensor(pitch), (0, phoneme_pad_length), mode='constant', value=0)
        mel = F.pad(torch.tensor(mel), (0, 0, 0, mel_pad_length), mode='constant', value=0)
        return  phone_mapping, src_len, mel, trg_len, duration
    
train_dataset = SpanishDataset('./preprocessed_data/SpanishSingleSpeaker', 'train.txt', max_src_len=200, max_trg_len=1000)
val_dataset = SpanishDataset('./preprocessed_data/SpanishSingleSpeaker', 'val.txt', max_src_len=200, max_trg_len=1000)

# set shuffle to false since we already shuffle when the data is split into train/test
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=4, pin_memory=True)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=4)

  0%|          | 0/10437 [00:00<?, ?it/s]

  0%|          | 0/512 [00:00<?, ?it/s]

In [46]:
class LJSpeechDataset(Dataset):
    def __init__(self,  input_dir, input_file, max_src_len, max_trg_len):
        self.input_dir = input_dir
        self.max_src_len = max_src_len
        self.max_trg_len = max_trg_len
        
        files = pd.read_csv(os.path.join(input_dir, input_file), sep='|', header=None)
        files.columns = ['file', 'speaker', 'phones', 'text'] 
        
        # only consider valid phoneme sequences
        for index, row in tqdm(files.iterrows(), total=len(files)):
            if not all(symbol in _symbol_to_id for symbol in self._process_phones(row['phones'])):
                files.drop(index, inplace=True)

        self.file_names = files

    def _process_phones(self, phones):
        phones = phones.replace('{', '').replace('}', '').strip().split(' ')
        mapped_phones = [f'@{phone}' if str.isalnum(phone) else phone for phone in phones]
        return mapped_phones
        
    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        file_name = self.file_names.iloc[idx]['file']
        
        mel = np.load(os.path.join(self.input_dir, 'mel', f'LJSpeech-mel-{file_name}.npy'))
        duration = np.load(os.path.join(self.input_dir, 'duration', f'LJSpeech-duration-{file_name}.npy'))
        energy = np.load(os.path.join(self.input_dir, 'energy', f'LJSpeech-energy-{file_name}.npy'))
        pitch = np.load(os.path.join(self.input_dir, 'pitch', f'LJSpeech-pitch-{file_name}.npy'))
        phones = self._process_phones(self.file_names.iloc[idx]['phones'])
        
        phone_mapping = torch.tensor([ _symbol_to_id[symbol] for symbol in phones ])

        src_len = torch.tensor(len(phones))
        trg_len = torch.tensor(mel.shape[0])

        phoneme_pad_length = self.max_src_len - src_len
        mel_pad_length = self.max_trg_len - trg_len
        
        phone_mapping = F.pad(phone_mapping, (0, phoneme_pad_length), mode='constant', value=0)
        duration = F.pad( torch.tensor(duration), (0, phoneme_pad_length), mode='constant', value=0)
        energy = F.pad( torch.tensor(energy), (0, phoneme_pad_length), mode='constant', value=0)
        pitch = F.pad( torch.tensor(pitch), (0, phoneme_pad_length), mode='constant', value=0)
        mel = F.pad(torch.tensor(mel), (0, 0, 0, mel_pad_length), mode='constant', value=0)
        return  phone_mapping, src_len, mel, trg_len, duration
    
train_dataset = LJSpeechDataset('./preprocessed_data/LJSpeech', 'train.txt', max_src_len=200, max_trg_len=1000)
val_dataset = LJSpeechDataset('./preprocessed_data/LJSpeech', 'val.txt', max_src_len=200, max_trg_len=1000)

# set shuffle to false since we already shuffle when the data is split into train/test
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=2 pin_memory=True)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=2

SyntaxError: invalid syntax. Perhaps you forgot a comma? (4114873204.py, line 53)

#### Model Layers

In [7]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len):
        """
        Inputs
            d_model - Hidden dimensionality of the input.
            max_len - Maximum length of a sequence to expect.
        """
        super().__init__()

        # Create matrix of [SeqLen, HiddenDim] representing the positional encoding for max_len inputs
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        # register_buffer => Tensor which is not a parameter, but should be part of the modules state.
        # Used for tensors that need to be on the same device as the module.
        # persistent=False tells PyTorch to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x

In [8]:
stft = TacotronSTFT(filter_length=1024, hop_length=256, win_length=1024, n_mel_channels=80, sampling_rate=22050, mel_fmin=0, mel_fmax=8000)
#inv_mel_spec(mel_spec.T, 'file.wav', stft)

In [9]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout, hidden_dim, kernel_size):
        super().__init__()
        
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout, batch_first=True)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        
        self.mlp = nn.Sequential(
            nn.LayerNorm(embed_dim),
            Rearrange('B S E -> B E S'),
            nn.Conv1d(embed_dim, hidden_dim, kernel_size, padding = (kernel_size - 1) // 2),
            #nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden_dim, embed_dim, kernel_size, padding = (kernel_size - 1) // 2),
            #nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
            Rearrange('B E S-> B S E')
        )
        
        self._init_weights()
        
    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        
    def forward(self, x, mask=None, fill_mask=None):
        x.masked_fill_(fill_mask, 0.0)
        attn_in = self.layer_norm(x)
        attn_out, _ = self.attention(attn_in, attn_in, attn_in, key_padding_mask=mask)
        x_out = attn_out + x
        
        x_out.masked_fill_(fill_mask, 0.0)
        mlp_out = self.mlp(x_out)
        out = mlp_out + x_out
        out.masked_fill_(fill_mask, 0.0)
        return out
    
    
class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, num_layers, kernel_size, max_src_seq_len, src_vocab_size, dropout=0.0):
        super().__init__()
        
        self.pos_embedding = PositionalEncoding(embed_dim, max_src_seq_len)
        self.phone_embedding = nn.Embedding(src_vocab_size, embed_dim, padding_idx=0)
        self.layers = nn.Sequential(*[TransformerBlock(embed_dim, num_heads, dropout, hidden_dim, kernel_size)  for _ in range(num_layers)])
        
    def forward(self, x, mask=None, fill_mask=None):
        x = self.phone_embedding(x)
        x = self.pos_embedding(x)
        
        for l in self.layers:
            x = l(x, mask, fill_mask)
        
        return x
        
class TransformerDecoder(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, num_layers, kernel_size, max_trg_seq_len, dropout=0.0):
        super().__init__()
        
        self.pos_embedding = PositionalEncoding(embed_dim, max_trg_seq_len)
        self.layers = nn.Sequential(*[TransformerBlock(embed_dim, num_heads, dropout, hidden_dim, kernel_size)  for _ in range(num_layers)])
        
    def forward(self, x, mask=None, fill_mask=None):
        x = self.pos_embedding(x)
        
        for l in self.layers:
            x = l(x, mask, fill_mask)
        
        return x
        
class Hidden2Mel(nn.Module):
    def __init__(self, embed_dim, n_mels):
        super().__init__()
        self.layers = nn.Linear(embed_dim, n_mels)
        
    def forward(self, x):
        return self.layers(x)
    
class LengthRegulator(nn.Module):
    def __init__(self, max_trg_len=1000):
        super().__init__()

        self.max_trg_len = max_trg_len

    def forward(self, encoder_output, variance):
        B = encoder_output.shape[0]
        mels = list()

        for b_idx in range(B):
            expanded_seq = torch.concat([ encoder_output[b_idx,i,:].expand(v, -1) for i, v in enumerate(variance[b_idx, :]) ], dim=0)
            seq_len = expanded_seq.shape[0]
            pad_len = self.max_trg_len - seq_len

            if pad_len < 0:
                padded_seq = expanded_seq[:self.max_trg_len, :]
            elif pad_len > 0:
                padded_seq = F.pad( expanded_seq, (0, 0, 0, pad_len), "constant", -1 )

            mels.append(padded_seq)
        expanded_batch = torch.stack(mels, dim=0)
        return expanded_batch
    
class FastSpeechLoss(nn.Module):
    def __init__(self, encoder_max_seq_len, decoder_max_seq_len):
        super().__init__()

        self.encoder_max_seq_len = encoder_max_seq_len
        self.decoder_max_seq_len = decoder_max_seq_len
        self.h2m_loss = nn.L1Loss()
   
    def forward(self,  h2m_pred_mels, trg_mels):
        h2m_loss = self.h2m_loss(h2m_pred_mels, trg_mels)
        return h2m_loss
          
class FastSpeech(nn.Module):
    def __init__(self, num_heads=2, num_layers=6, dropout=0.0, embed_dim=256, hidden_dim=512, kernel_size=3, max_trg_seq_len=1000, max_src_seq_len=200, src_vocab_size=360, n_mels=80):
        super().__init__()
        self.encoder_max_seq_len = max_src_seq_len
        self.decoder_max_seq_len = max_trg_seq_len
        self.embed_dim = embed_dim
        self.n_mels = n_mels
        
        self.encoder = TransformerEncoder(num_heads=num_heads,
                                          num_layers=num_layers,
                                          dropout=dropout,
                                          embed_dim=embed_dim,
                                          hidden_dim=hidden_dim,
                                          kernel_size=kernel_size,
                                          src_vocab_size=src_vocab_size,
                                          max_src_seq_len=max_src_seq_len)
        
        self.decoder = TransformerDecoder(num_heads=num_heads,
                                          num_layers=num_layers,
                                          dropout=dropout,
                                          embed_dim=embed_dim,
                                          hidden_dim=hidden_dim,
                                          kernel_size=kernel_size,
                                          max_trg_seq_len=max_trg_seq_len)
        
        self.hidden2mel = Hidden2Mel(n_mels=n_mels, embed_dim=embed_dim)
        
        self.length_regulator = LengthRegulator(max_trg_seq_len)
        
    def _get_masks(self, seq_lens, max_seq_len, embed_dim):
        B = seq_lens.shape[0]
        masks = [[ mask_idx >= seq_lens[seq_idx]  for mask_idx in torch.arange(max_seq_len)]  for seq_idx in torch.arange(0, B)]
        masks = torch.tensor(masks)
        fill_masks = repeat(masks.T, 'b s -> e b s', e=embed_dim).T.contiguous()
        
        return masks.to(device), fill_masks.to(device)
        
    def forward(self, src_seq, src_seq_len, trg_seq, trg_seq_len, trg_durations):
        src_masks, src_fill_masks = self._get_masks(src_seq_len, self.encoder_max_seq_len, self.embed_dim)
        trg_masks, trg_fill_masks = self._get_masks(trg_seq_len, self.decoder_max_seq_len, self.embed_dim)
        _, mel_fill_masks = self._get_masks(trg_seq_len, self.decoder_max_seq_len, self.n_mels)
        
        enc_out = self.encoder(src_seq, src_masks, src_fill_masks)
        adapted_enc_out = self.length_regulator(enc_out, trg_durations)
        dec_out = self.decoder(adapted_enc_out, trg_masks, trg_fill_masks)
        out = self.hidden2mel(dec_out).masked_fill(mel_fill_masks, 0.0)
        
        return out
    

# model = FastSpeech(
#     num_heads=2,
#     num_layers=6,
#     dropout=0.0, 
#     embed_dim=256,
#     hidden_dim=512,
#     kernel_size=3,
#     max_trg_seq_len=1000,
#     max_src_seq_len=200,
#     n_mels=80,
#     src_vocab_size=len(_symbol_to_id)
# )

# phonemes, phoneme_lens, mels, mel_lens, durations = next(iter(train_loader))

# o = model(phonemes, phoneme_lens, mels, mel_lens, durations)

In [10]:
class FastSpeechModule(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.model = FastSpeech(
                        num_heads=2,
                        num_layers=6,
                        dropout=0.0, 
                        embed_dim=256,
                        hidden_dim=512,
                        kernel_size=3,
                        max_trg_seq_len=1000,
                        max_src_seq_len=200,
                        n_mels=80,
                        src_vocab_size=len(_symbol_to_id)
                    )
        
        self.loss_module = FastSpeechLoss(encoder_max_seq_len=200, decoder_max_seq_len=1000)
        
    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100,150], gamma=0.1)
        return [optimizer], [lr_scheduler]
    
    def training_step(self, batch, batch_idx):
        phonemes, phoneme_lens, mels, mel_lens, durations = batch
        preds = self.model(phonemes, phoneme_lens, mels, mel_lens, durations)
        loss = self.loss_module(preds, mels)
        self.log(f'train_loss', loss, on_step=True, on_epoch=True)
        
        if self.current_epoch in [1, 60, 120, 198]:
            inv_mel_spec(preds[0].T, f'file_{self.current_epoch}.wav', stft)
            print('synth')
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        phonemes, phoneme_lens, mels, mel_lens, durations = batch
        preds = self.model(phonemes, phoneme_lens, mels, mel_lens, durations)
        loss = self.loss_module(preds, mels)
        self.log(f'val_loss', loss, on_epoch=True)
    
    def test_step(self, batch, batch_idx):
        phonemes, phoneme_lens, mels, mel_lens, durations = batch
        preds = self.model(phonemes, phoneme_lens, mels, mel_lens, durations)
        loss = self.loss_module(preds, mels)

In [11]:
import wandb
wandb.finish()
wandb_logger = WandbLogger(project='Fast Speech 2 - Improved', name='spanish cha cha cha 2')

[34m[1mwandb[0m: Currently logged in as: [33msitholedavid003[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
def train_model(**kwargs):
    trainer = pl.Trainer(default_root_dir=os.path.join('checkpoints', 'initial model x'),
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=200,
                         logger=wandb_logger,
                         overfit_batches=1,
                         callbacks=[
                             ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_loss"),
                                    LearningRateMonitor("epoch")]
                        )
    trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pl.seed_everything(42) # To be reproducable
    model = FastSpeechModule(**kwargs)
    trainer.fit(model, train_loader, val_loader)
    model = FastSpeechModule.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    result = {"val": val_result[0]["test_loss"]}
    return model, result

model, results = train_model(lr=3e-4)

Sanity Checking: 0it [00:00, ?it/s]

  fill_masks = repeat(masks.T, 'b s -> e b s', e=embed_dim).T.contiguous()
  return torch._native_multi_head_attention(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

synth


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

synth


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

synth


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

synth


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Testing: 0it [00:00, ?it/s]

KeyError: 'test_loss'

### ViT Stuff

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_channels, num_heads, num_layers, num_classes, patch_size, num_patches, kernel_size, dropout=0.0):
        super().__init__()
        
        input_dim = patch_size * patch_size * num_channels
        
        self.input_net = nn.Sequential(
            Rearrange('B C (h p1) (w p2) -> B (h w) (C p1 p2)', p1=patch_size, p2=patch_size),
            nn.Linear(input_dim, embed_dim)
        )
        
        self.positional_embedding = nn.Embedding(num_patches + 1, embed_dim) # +1 for the CLS token
        
        self.mlp = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes),
        )
        
        self.encoder = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, dropout, hidden_dim, kernel_size) for _ in range(num_layers)
        ])
        
        self.dropout = nn.Dropout(dropout)
        
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        
    def forward(self, x):
        x = self.input_net(x)
        B, T, _ = x.shape
        cls_token = repeat(self.cls_token, '1 1 N -> B 1 N', B=B)
        x = torch.cat([cls_token, x], dim=1)
        pos = torch.arange(0, T+1).to(device)
        x = x + self.positional_embedding(pos)
        x = self.dropout(x)
        x = self.encoder(x)
        out = self.mlp(x[:, 0, :])
        return out
    
t = VisionTransformer(**{ 'embed_dim': 256,
                                'hidden_dim': 512,
                                'num_heads': 8,
                                'num_layers': 6,
                                'patch_size': 4,
                                'num_channels': 3,
                                'num_patches': 64,
                                'num_classes': 10,
                                'dropout': 0.2,
                                'kernel_size': 3
                        })

# t(torch.randn(6, 3, 32, 32)).shape

class ViT(pl.LightningModule):

    def __init__(self, model_kwargs, lr):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)
        self.example_input_array = next(iter(train_loader))[0]

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100,150], gamma=0.1)
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode="train"):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        self.log(f'{mode}_loss', loss, on_step=True, on_epoch=True)
        self.log(f'{mode}_acc', acc, on_step=True, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")

In [None]:
wandb.finish()
wandb_logger = WandbLogger(project='FastSpeech2', name='ViT CNN v2')

In [None]:
def train_model(**kwargs):
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT CNN"),
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=180,
                         logger=wandb_logger,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
                                    LearningRateMonitor("epoch")])
    trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pl.seed_everything(42) # To be reproducable
    model = ViT(**kwargs)
    trainer.fit(model, train_loader, val_loader)
    model = ViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}
    return model, result

model, results = train_model(model_kwargs={
                                'embed_dim': 256,
                                'hidden_dim': 512,
                                'num_heads': 8,
                                'num_layers': 6,
                                'patch_size': 4,
                                'num_channels': 3,
                                'num_patches': 64,
                                'num_classes': 10,
                                'dropout': 0.2,
                                'kernel_size': 3
                            },
                            lr=3e-4)

print("ViT results", results)