In [3]:
import sys; sys.path.insert(0, '..')
from text import _clean_text
from text import _symbol_to_id

In [17]:
import yaml
import tgt
import pandas as pd
import os
import numpy as np  
from tqdm import tqdm

import torch 
import torch.nn as nn 
import torch.nn.functional as F
import torchaudio as TA
from torchaudio.transforms import MelSpectrogram

import librosa
from scipy.io import wavfile

import pytorch_lightning as pl

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
preprocessing_config = yaml.load(open('../config/preprocessing.yaml', 'r'), Loader=yaml.FullLoader)
model_config = yaml.load(open('../config/model.yaml', 'r'), Loader=yaml.FullLoader)

raw_path = preprocessing_config['path']['raw_path']
input_dir = preprocessing_config['path']['input_dir']

sampling_rate = preprocessing_config['audio']['sampling_rate']
hop_length = preprocessing_config['audio']['hop_length']
win_length = preprocessing_config['audio']['win_length']
n_fft = preprocessing_config['audio']['n_fft']
n_mels = preprocessing_config['audio']['n_mel_channels']
fmin = preprocessing_config['audio']['mel_fmin']
fmax = preprocessing_config['audio']['mel_fmax']
cleaners = preprocessing_config['text']['text_cleaners']

metadata = pd.read_csv(os.path.join(raw_path, 'metadata.csv'), sep='|', header=None)
metadata.columns = ['file', 'text', 'text_']
metadata.drop(['text_'], axis=1, inplace=True)

Create input dir and copy files over

In [23]:
input_dir

'./../data/raw_data'

In [28]:
temp_dir = './../kaggle_data/raw_data'

if not os.path.exists(input_dir):
    os.makedirs(input_dir)

if not os.path.exists(temp_dir):
    os.makedirs(temp_dir)

In [27]:
input_dir

'./../data/raw_data'

In [30]:
silent_phones = ["sil", "sp", "spn"]
mel_spec_transform = MelSpectrogram(sample_rate=sampling_rate, n_fft=n_fft, win_length=win_length, hop_length=hop_length, f_min=fmin, f_max=fmax, n_mels=n_mels)

for line in tqdm(metadata.iterrows()):
    file = line[1]['file']
    text = line[1]['text']
    clean_text = _clean_text(text, cleaner_names=cleaners)
    
    textgrid = tgt.io.read_textgrid(os.path.join(raw_path, 'TextGrid', 'LJSpeech', f'{file}.TextGrid'))
    textgrid = textgrid.get_tier_by_name('phones')

    phonemes = []
    durations = []

    start_time = 0.0
    end_time = 0.0
    end_idx = 0
    
    for tier in textgrid._objects:

        # trim the initial silent phonemes
        if tier.text in silent_phones and len(durations) == 0:
            continue

        # record the start time
        if len(phonemes) == 0:
            start_time = tier.start_time

        phonemes.append(tier.text)
        durations.append(round( tier.end_time * sampling_rate / hop_length ) - round( tier.start_time * sampling_rate / hop_length ))

        # trim the last silent phonemes
        if tier.text not in silent_phones:
            end_time = tier.end_time
            end_idx = len(phonemes)

       
    # compute the mel-spectogram
    phonemes = phonemes[:end_idx]
    durations = durations[:end_idx]
    phoneme_seq = "{" + " ".join(phonemes) + "}"

    wav_path = os.path.join(raw_path, 'wavs', f'{file}.wav')
    audio, _ = librosa.load(wav_path, sr=sampling_rate)
    audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
    mel = mel_spec_transform(audio).squeeze(0)[:, :sum(durations)]

    assert mel.shape[-1] == np.sum(durations), f"Number of frames not equal for file {line[0]}"

    sample_input_dir = os.path.join(input_dir, "LJSpeech", file)
    audio_file_name = os.path.join(sample_input_dir, f'{file}.wav')
    mel_file_name = os.path.join(sample_input_dir, f'{file}-mel.npy')
    duration_file_name = os.path.join(sample_input_dir, f'{file}-duration.npy')
    text_file_name = os.path.join(sample_input_dir, f'{file}.csv')

    if not os.path.exists(sample_input_dir):
        os.makedirs(sample_input_dir)

    wavfile.write(audio_file_name, sampling_rate, audio.reshape(-1, 1).numpy())
    np.save(mel_file_name, mel.T)
    np.save(duration_file_name, durations)
    np.save(mel_file_name, mel.T)

    with open(text_file_name, 'w') as f:
        f.write(f'{file[0]} | {clean_text} | {phoneme_seq}')

    sample_temp_dir = os.path.join(temp_dir, "LJSpeech", file)
    audio_file_name = os.path.join(sample_temp_dir, f'{file}.wav')
    mel_file_name = os.path.join(sample_temp_dir, f'{file}-mel.npy')
    duration_file_name = os.path.join(sample_temp_dir, f'{file}-duration.npy')
    text_file_name = os.path.join(sample_temp_dir, f'{file}.csv')

    if not os.path.exists(sample_temp_dir):
        os.makedirs(sample_temp_dir)

    #wavfile.write(audio_file_name, sampling_rate, audio.reshape(-1, 1).numpy())
    np.save(mel_file_name, mel.T)
    np.save(duration_file_name, durations)
    np.save(mel_file_name, mel.T)

    with open(text_file_name, 'w') as f:
        f.write(f'{file[0]} | {clean_text} | {phoneme_seq}')

13100it [07:58, 27.38it/s]


### Define dataset and dataloaders

In [5]:
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class LJSpeechDataset(Dataset):
    def __init__(self,  input_dir, input_file, max_src_len, max_trg_len, split='train', test_batch=False):
        metadata = pd.read_csv(input_file, sep='|', header=None)
        metadata.columns = ['file', 'text', 'text_']

        self.input_dir = input_dir
        self.max_src_len = max_src_len
        self.max_trg_len = max_trg_len
        
        file_names = metadata['file'].to_numpy()

        if test_batch:
            self.file_names = file_names[:16]
            return

        x_train, x_test = train_test_split(file_names, test_size=0.2, random_state=42, shuffle=True)
        
        if split == 'train':
            self.file_names = x_train
        else:
            self.file_names = x_test

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        mel = np.load(os.path.join(self.input_dir, file_name, f'{file_name}-mel.npy'))
        duration = np.load(os.path.join(self.input_dir, file_name, f'{file_name}-duration.npy'))
        phones = pd.read_csv(os.path.join(self.input_dir, file_name, f'{file_name}.csv'), sep='|', header=None).iloc[0][2].replace('{', '').replace('}', '').strip().split(' ')
        phone_mapping = torch.tensor([ _symbol_to_id[symbol] for symbol in phones ])
        

        src_len = torch.tensor(len(phones))
        trg_len = torch.tensor(mel.shape[0])

        phoneme_pad_length = self.max_src_len - src_len
        mel_pad_length = self.max_trg_len - trg_len
        
        phone_mapping = F.pad(phone_mapping, (0, phoneme_pad_length), mode='constant', value=0)
        duration = F.pad( torch.tensor(duration), (0, phoneme_pad_length), mode='constant', value=0)
        mel = F.pad(torch.tensor(mel), (0, 0, 0, mel_pad_length), mode='constant', value=0)
        
        return  src_len, trg_len, duration, phone_mapping, mel 

In [6]:
train_dataset = LJSpeechDataset(os.path.join(input_dir, "LJSpeech"), os.path.join(raw_path, 'metadata.csv'), max_src_len=200, max_trg_len=1000, split='train')
val_dataset = LJSpeechDataset(os.path.join(input_dir, "LJSpeech"), os.path.join(raw_path, 'metadata.csv'), max_src_len=200, max_trg_len=1000, split='test')

# set shuffle to false since we already shuffle when the data is split into train/test
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=4)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=4)

### Building the Torch model

In [19]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads=8, embed_dim=256):
        super().__init__()
        self.n_heads = n_heads
        self.embed_dim = embed_dim 
        self.head_dim = embed_dim // n_heads

        assert self.head_dim * self.n_heads == self.embed_dim, "The embedding dimension must be divisible by the number of heads"

        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(self.head_dim * self.n_heads, self.embed_dim, bias=False)

    def forward(self, query, key, value, mask=None):
        # query_dim -> B, num_phonemes, embed_dim -> B, num_phonemes, n_heads, head_dim
        B = query.shape[0]
        query_len = query.shape[1]
        key_len = key.shape[1]
        value_len = value.shape[1] 

        query = query.reshape(B, query_len, self.n_heads, self.head_dim)
        value = value.reshape(B, value_len, self.n_heads, self.head_dim)
        key = key.reshape(B, key_len, self.n_heads, self.head_dim)

        query = self.queries(query)
        value = self.values(value)
        key = self.keys(key)

        # compute energy 
        # B, query_len, n_heads, head_dim * B, key_len, n_heads, head_dim -> B, n_heads, query_len, key_len

        energy = torch.einsum('bqnh,bknh->bnqk', [query, key])

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float('1e-20'))

        attention = torch.softmax(energy / (self.embed_dim ** (0.5)), dim=3)

        # compute output 
        # attention dim -> (N, heads, query_len, key_len)
        # values dim -> (N, value_len, heads, head_dim)
        # output dim -> (N, query_len, heads, head_dim)
        
        output = torch.einsum('nhqk,nvhd->nqhd', [attention, value]).reshape(B, query_len, self.n_heads * self.head_dim)
        output = self.fc_out(output)

        return output, attention

#MultiHeadAttention()(torch.randn(5, 5, 256), torch.randn(5, 5, 256), torch.randn(5, 5, 256)).shape

class PositionWiseFeedForward(nn.Module):
    def __init__(self, kernel_size=9, embed_dim=256, forward_expansion=2, dropout=0.5):
        super().__init__()

        self.layers = nn.Sequential( 
            nn.Conv1d(embed_dim, embed_dim * forward_expansion, kernel_size, padding = (kernel_size - 1) // 2),
            nn.ReLU(),
            nn.Conv1d(embed_dim * forward_expansion, embed_dim, kernel_size, padding = (kernel_size - 1) // 2),
            nn.Dropout(dropout),
            nn.ReLU()
        )

    def forward(self, x):
        return self.layers(x)
    
#PositionWiseFeedForward()(torch.randn(5, 256, 100)).shape

def get_positional_encoding(seq_len, d_model):
    encoding = torch.zeros((seq_len, d_model))

    for k in range(seq_len):
        for i in range(d_model // 2):
            encoding[k, 2*i] = np.sin(k / (10000 ** ((2 * i) / d_model)))
            encoding[k, 2*i + 1] = np.cos(k / (10000 ** ((2 * i) / d_model)))

    return nn.Parameter(encoding)

class TransformerEncoderLayer(nn.Module):
    def __init__(self, n_heads=8, embed_dim=256, dropout=0.5, forward_expansion=2):
        super().__init__()

        self.attention = MultiHeadAttention(n_heads=n_heads, embed_dim=embed_dim)
        self.feed_forward = PositionWiseFeedForward(embed_dim=embed_dim, forward_expansion=forward_expansion, dropout=dropout)
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        src_seq, src_mask = x
        out, _ = self.attention(src_seq, src_seq, src_seq, src_mask)
        out = self.dropout(self.layer_norm_1(out + src_seq))
        out_2 = self.feed_forward(out.permute(0, 2, 1))
        out_2 = self.dropout(self.layer_norm_2(out + out_2.permute(0, 2, 1)))
        return out_2, src_mask

class TransformerEncoder(nn.Module):
    def __init__(self, max_seq_len=200, src_vocab_size=64, n_layers=4, n_heads=8, embed_dim=256, dropout=0.5, forward_expansion=2):
        super().__init__()
        self.pos_embedding = get_positional_encoding(max_seq_len, embed_dim)
        self.phone_embedding = nn.Embedding(src_vocab_size, embed_dim, padding_idx=0)
        self.encoder = nn.Sequential( *[ TransformerEncoderLayer(n_heads=n_heads, embed_dim=embed_dim, dropout=dropout, forward_expansion=forward_expansion) for _ in range(n_layers) ])

    def forward(self, phonemes, mask=None):
        phoneme_embeddings = self.phone_embedding(phonemes) # max_seq_len, embed_dim
        model_in = self.pos_embedding + phoneme_embeddings
        out = self.encoder((model_in, mask))
        return out

class VariancePredictor(nn.Module):
    def __init__(self, embed_size=256, out_dim=1, kernel_size=3, dropout=0.2):
        super().__init__()

        self.conv1 = nn.Conv1d(embed_size, embed_size, kernel_size, padding=1)
        self.layer_norm_1 = nn.LayerNorm(embed_size)
        self.dropout_1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(embed_size, embed_size, kernel_size, padding=1)
        self.layer_norm_2 = nn.LayerNorm(embed_size)
        self.dropout_2 = nn.Dropout(dropout)

        self.fc_out = nn.Linear(embed_size, out_dim)

    def forward(self, x):
        out = F.relu(self.conv1(x.permute(0, 2, 1)))
        out = self.layer_norm_1(out.permute(0, 2, 1))
        out = self.dropout_1(out)
        out = F.relu(self.conv2(out.permute(0, 2, 1)))
        out = self.layer_norm_2(out.permute(0, 2, 1))
        out = self.dropout_2(out)
        out = F.relu(self.fc_out(out))

        return out
    
class LengthRegulator(nn.Module):
    def __init__(self, max_trg_len=1000):
        super().__init__()

        self.max_trg_len = max_trg_len

    def forward(self, encoder_output, variance):
        B = encoder_output.shape[0]
        mels = list()

        for b_idx in range(B):
            expanded_seq = torch.concat([ encoder_output[b_idx,i,:].expand(v, -1) for i, v in enumerate(variance[b_idx, :]) ], dim=0)
            seq_len = expanded_seq.shape[0]
            pad_len = self.max_trg_len - seq_len

            if pad_len < 0:
                padded_seq = expanded_seq[:self.max_trg_len, :]
            elif pad_len > 0:
                padded_seq = F.pad( expanded_seq, (0, 0, 0, pad_len), "constant", -1 )

            mels.append(padded_seq)


        expanded_batch = torch.stack(mels, dim=0)
        return expanded_batch
    
class TransformerDecoderLayer(nn.Module):
    def __init__(self, n_heads=8, embed_dim=256, dropout=0.5, forward_expansion=2):
        super().__init__()

        self.attention = MultiHeadAttention(n_heads=n_heads, embed_dim=embed_dim)
        self.feed_forward = PositionWiseFeedForward(embed_dim=embed_dim, forward_expansion=forward_expansion, dropout=dropout)
        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        trg_seq, trg_mask = x
        out, _ = self.attention(trg_seq, trg_seq, trg_seq, trg_mask)
        out = self.dropout(self.layer_norm_1(out + trg_seq))
        out_2 = self.feed_forward(out.permute(0, 2, 1))
        out_2 = self.dropout(self.layer_norm_2(out + out_2.permute(0, 2, 1)))
        return out_2, trg_mask

class TransformerDecoder(nn.Module):
    def __init__(self, max_seq_len=200, trg_vocab_size=64, n_layers=4, n_heads=8, embed_dim=256, dropout=0.5, forward_expansion=2):
        super().__init__()
        self.pos_embedding = get_positional_encoding(max_seq_len, embed_dim)
       # self.mel_embedding = nn.Embedding(trg_vocab_size, embed_dim)
        self.decoder = nn.Sequential( *[ TransformerDecoderLayer(n_heads=n_heads, embed_dim=embed_dim, dropout=dropout, forward_expansion=forward_expansion) for _ in range(n_layers) ])

    def forward(self, phonemes, trg_masks=None):
        #mel_embeddings = self.mel_embedding(phonemes) # max_seq_len, embed_dim
        model_in = self.pos_embedding + phonemes
        out = self.decoder((model_in, trg_masks))
        return out
    
class Hidden2Mel(nn.Module):
    def __init__(self, in_dim=256, out_dim=80):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, out_dim),
            nn.ReLU()
         )
        
    def forward(self, x):
        return self.layers(x)
    
class PostNet(nn.Module):
    def __init__(self, in_dim=80, postnet_dim=512, n_layers=5, kernel_size=5):
        super().__init__()

        layers = list()
        
        layers.append(nn.Sequential(
            nn.Conv1d(in_dim, postnet_dim, kernel_size=5, padding=(kernel_size - 1) // 2, bias=True),
            nn.BatchNorm1d(postnet_dim)))
        
        for _ in range(n_layers):
            layers.append(
                nn.Sequential(
                    nn.Conv1d(postnet_dim, postnet_dim, kernel_size=5, padding=(kernel_size - 1) // 2, bias=True),
                    nn.BatchNorm1d(postnet_dim)
                )
            )

        layers.append(
            nn.Sequential(
                    nn.Conv1d(postnet_dim, in_dim, kernel_size=5, padding=(kernel_size - 1) // 2, bias=True),
                    nn.BatchNorm1d(in_dim)
            )
        )

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return x  + self.layers(x)
    
#PostNet()(torch.randn(5, 80, 100)).shape

class FastSpeechLoss(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.encoder_max_seq_len = config['model']['encoder']['max_seq_len']
        self.decoder_max_seq_len = config['model']['decoder']['max_seq_len']

        self.duration_loss = nn.MSELoss()
        self.mel_loss = nn.MSELoss()
        self.h2m_loss = nn.MSELoss()

    def _get_mask(self, seq_lens, max_seq_len):
    # durations will be B, seq_len
    # mels will be B, mel_channels, seq_len
    # seq_lens is B, seq_len
    # mask is B, 1, 1, max_seq_len
        B = seq_lens.shape[0]
        masks = torch.zeros(max_seq_len).repeat(B, 1, 1)

        for b_idx in range(B):
            masks[b_idx, :] = torch.tensor( [ [ seq_len > idx  for idx in torch.arange(0, max_seq_len) ]  for seq_len in seq_lens[b_idx, :] ])

        return masks.to(device)

    def forward(self, src_seq_len, trg_seq_len, pred_durations, trg_durations, h2m_pred_mels, pred_mels, trg_mels):

        duration_mask = self._get_mask(src_seq_len.reshape(-1, 1), self.encoder_max_seq_len)
        mel_mask = self._get_mask(trg_seq_len.reshape(-1, 1), self.decoder_max_seq_len)

        h2m_pred_mels = h2m_pred_mels.permute(0, 2, 1)
        pred_mels = pred_mels
        pred_durations = pred_durations.permute(0, 2, 1)
        trg_mels = trg_mels.permute(0, 2, 1)
        if duration_mask is not None:
            pred_durations = pred_durations.masked_fill(duration_mask == 0, float('0'))

        if mel_mask is not None:
            pred_mels = pred_mels.masked_fill(mel_mask == 0, float('0'))
            h2m_pred_mels = h2m_pred_mels.masked_fill(mel_mask == 0, float('0'))

        duration_loss = self.duration_loss(pred_durations.squeeze(1), trg_durations.float())
        h2m_loss = self.h2m_loss(h2m_pred_mels, trg_mels)
        mel_loss = self.mel_loss(pred_mels, trg_mels)

        return duration_loss, h2m_loss, mel_loss

class FastSpeech(nn.Module):
    def __init__(self, config):
        super().__init__() 

        self.encoder_max_seq_len = config['model']['encoder']['max_seq_len']
        self.decoder_max_seq_len = config['model']['decoder']['max_seq_len']

        encoder_max_seq_len = config['model']['encoder']['max_seq_len']
        encoder_src_vocab_size = config['model']['encoder']['src_vocab_size']
        encoder_n_layers = config['model']['encoder']['n_layers']
        encoder_n_heads = config['model']['encoder']['n_heads']
        encoder_embed_dim = config['model']['encoder']['embed_dim']
        encoder_dropout = config['model']['encoder']['dropout']
        encoder_forward_expansion = config['model']['encoder']['forward_expansion']

        decoder_max_seq_len = config['model']['decoder']['max_seq_len']
        decoder_trg_vocab_size = config['model']['decoder']['trg_vocab_size']
        decoder_n_layers = config['model']['decoder']['n_layers']
        decoder_n_heads = config['model']['decoder']['n_heads']
        decoder_embed_dim = config['model']['decoder']['embed_dim']
        decoder_dropout = config['model']['decoder']['dropout']
        decoder_forward_expansion = config['model']['decoder']['forward_expansion']

        variance_predictor_embed_dim = config['model']['variance_predictor']['embed_dim']
        variance_predictor_out_dim = config['model']['variance_predictor']['out_dim']
        variance_predictor_kernel_size = config['model']['variance_predictor']['kernel_size']
        variance_predictor_dropout = config['model']['variance_predictor']['dropout']

        length_regulator_max_trg_len = config['model']['length_regulator']['max_trg_len']

        hidden2mel_in_dim = config['model']['hidden_2_mel']['in_dim']
        hidden2mel_out_dim = config['model']['hidden_2_mel']['out_dim']

        postnet_in_dim = config['model']['postnet']['in_dim']
        postnet_postnet_dim = config['model']['postnet']['postnet_dim']
        postnet_n_layers = config['model']['postnet']['n_layers']
        postnet_kernel_size = config['model']['postnet']['kernel_size']
        
        # encoder, variance adaptor, length regulator, decoder, hidden2mel, postnet

        self.encoder = TransformerEncoder(max_seq_len=encoder_max_seq_len, src_vocab_size=encoder_src_vocab_size, n_layers=encoder_n_layers, n_heads=encoder_n_heads, embed_dim=encoder_embed_dim, dropout=encoder_dropout, forward_expansion=encoder_forward_expansion)
        self.decoder = TransformerDecoder(max_seq_len=decoder_max_seq_len, trg_vocab_size=decoder_trg_vocab_size, n_layers=decoder_n_layers, n_heads=decoder_n_heads, embed_dim=decoder_embed_dim, dropout=decoder_dropout, forward_expansion=decoder_forward_expansion)
        self.duration_predictor = VariancePredictor(embed_size=variance_predictor_embed_dim, out_dim=1, kernel_size=variance_predictor_kernel_size, dropout=variance_predictor_dropout)
        self.length_regulator = LengthRegulator(max_trg_len=length_regulator_max_trg_len)
        self.hidden2mel = Hidden2Mel(in_dim=hidden2mel_in_dim, out_dim=hidden2mel_out_dim)
        self.postnet = PostNet(in_dim=postnet_in_dim, postnet_dim=postnet_postnet_dim, n_layers=postnet_n_layers, kernel_size=postnet_kernel_size)

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _get_mask(self, seq_lens, max_seq_len):
    # seq_lens is B, seq_len
    # mask is B, 1, 1, max_seq_len
        B = seq_lens.shape[0]
        masks = torch.zeros(max_seq_len).repeat(B, 1, 1, 1)

        for b_idx in range(B):
            masks[b_idx, :] = torch.tensor( [ [ seq_len > idx  for idx in torch.arange(0, max_seq_len) ]  for seq_len in seq_lens[b_idx, :] ])

        return masks.to(device)

    def forward(self, src_seq, src_seq_len, trg_seq, trg_seq_len, trg_durations):
        # the src seq is the list of phonemes, the trg_seq is the list of mel_specs
        #enc_out = self.encoder(torch.ones(5, 100, 512, dtype = torch.int32))
        
        src_masks = self._get_mask(src_seq_len.reshape(-1, 1), self.encoder_max_seq_len)
        trg_masks = self._get_mask(trg_seq_len.reshape(-1, 1), self.decoder_max_seq_len)
        
        enc_out, _ = self.encoder(src_seq, src_masks)
        pred_durations = self.duration_predictor(enc_out)
        adapted_enc_out = self.length_regulator(enc_out, trg_durations)
        dec_out, _ = self.decoder(adapted_enc_out, trg_masks)
        h2m_out = self.hidden2mel(dec_out)
        pred_mel = self.postnet(h2m_out.permute(0, 2, 1))

        return pred_mel, h2m_out, pred_durations 
    
src_len, trg_len, dur, phone, mel = next(iter(train_loader))
m  = FastSpeech(config=model_config)
# pred_mel, h2m_out, pred_durations = m(phone, src_len, mel, trg_len, dur)
# lo = FastSpeechLoss(config=model_config)
# lo(src_len, trg_len, pred_durations, dur, h2m_out, pred_mel, mel)

In [33]:
# class FastSpeechLoss(nn.Module):
#     def __init__(self, config):
#         super().__init__()

#         self.encoder_max_seq_len = config['model']['encoder']['max_seq_len']
#         self.decoder_max_seq_len = config['model']['decoder']['max_seq_len']

#         self.duration_loss = nn.MSELoss()
#         self.mel_loss = nn.MSELoss()
#         self.h2m_loss = nn.MSELoss()

#     def _get_mask(self, seq_lens, max_seq_len):
#     # durations will be B, seq_len
#     # mels will be B, mel_channels, seq_len
#     # seq_lens is B, seq_len
#     # mask is B, 1, 1, max_seq_len
#         B = seq_lens.shape[0]
#         masks = torch.zeros(max_seq_len).repeat(B, 1, 1)

#         for b_idx in range(B):
#             masks[b_idx, :] = torch.tensor( [ [ seq_len > idx  for idx in torch.arange(0, max_seq_len) ]  for seq_len in seq_lens[b_idx, :] ])

#         return masks

#     def forward(self, src_seq_len, trg_seq_len, pred_durations, trg_durations, h2m_pred_mels, pred_mels, trg_mels):

#         duration_mask = self._get_mask(src_seq_len.reshape(-1, 1), self.encoder_max_seq_len)
#         mel_mask = self._get_mask(trg_seq_len.reshape(-1, 1), self.decoder_max_seq_len)

#         h2m_pred_mels = h2m_pred_mels.permute(0, 2, 1)
#         pred_mels = pred_mels
#         pred_durations = pred_durations.permute(0, 2, 1)
#         trg_mels = trg_mels.permute(0, 2, 1)
#         if duration_mask is not None:
#             pred_durations = pred_durations.masked_fill(duration_mask == 0, float('0'))

#         if mel_mask is not None:
#             pred_mels = pred_mels.masked_fill(mel_mask == 0, float('0'))
#             h2m_pred_mels = h2m_pred_mels.masked_fill(mel_mask == 0, float('0'))

#         duration_loss = self.duration_loss(pred_durations.squeeze(1), trg_durations)
#         h2m_loss = self.h2m_loss(h2m_pred_mels, trg_mels)
#         mel_loss = self.mel_loss(pred_mels, trg_mels)

#         return duration_loss, h2m_loss, mel_loss

# # src_seq_len, trg_seq_len, pred_durations, trg_durations, h2m_pred_mels, pred_mels, trg_mels)
# lo = FastSpeechLoss(config=model_config)
# lo(src_len, trg_len, pred_durations, dur, h2m_out, pred_mel, mel)

torch.Size([4, 1, 200]) torch.Size([4, 200]) torch.Size([4, 200])


(tensor(29.8330, grad_fn=<MseLossBackward0>),
 tensor(4999.7119, grad_fn=<MseLossBackward0>),
 tensor(5000.2305, grad_fn=<MseLossBackward0>))

In [21]:
pred_durations.shape

torch.Size([4, 200, 1])

In [180]:
src_len, trg_len, dur, phone, mel = next(iter(train_loader))
m  = FastSpeech(config=model_config)
src_len

tensor([101,  66,  36,  61,  59,  63,  96,  73,  47,  86,  57,  36,  83,  71,
         63,  82,  42,  48,  86,  71,  95,  47,  51,  76,  60,  51,  73,  46,
         98,  34,  72,  67, 106,  96,  52,  97,  60,  30,  79,  99,  90,  98,
         62,  66,  69, 104,  80,  37, 108,  60,  71,  51,  98,  18,  62,  42,
         74,  67, 100,  39,  54,  34,  41,  91])

In [20]:
class FastSpeechModule(pl.LightningModule):
    def __init__(self, model_config_path):
        super().__init__()

        model_config = yaml.load(open(model_config_path, 'r'), Loader=yaml.FullLoader)

        self.save_hyperparameters()
        self.model = FastSpeech(config=model_config)
        self.loss_module = FastSpeechLoss(config=model_config)
        
    def forward(self, src_seq, src_seq_len, trg_seq, trg_seq_len, trg_durations):
        return self.model(src_seq, src_seq_len, trg_seq, trg_seq_len, trg_durations)
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [scheduler]
    
    def training_step(self,  batch, batch_idx):
        #print(batch)'
        #print('THIS IS A TRAINING STEP!!!!!!!!!!!!')
        #src_len, trg_len, duration, phone_mapping, mel 
        src_seq_len, trg_seq_len, trg_durations, src_seq, trg_seq = batch
        #src_seq, src_seq_len, trg_seq, trg_seq_len, trg_durations = batch 
        print(src_seq.shape, trg_seq.shape)
        pred_mel, pred_h2m, pred_durations = self.model(src_seq, src_seq_len, trg_seq, trg_seq_len, trg_durations)
        dur_loss, h2m_loss, mel_loss = self.loss_module(src_seq_len, trg_seq_len, pred_durations, trg_durations, pred_h2m, pred_mel, trg_seq)

        self.log('train_dur_error', dur_loss.item(), rank_zero_only=True)
        self.log('train_h2m_error', h2m_loss.item(), rank_zero_only=True)
        self.log('train_mel_error', mel_loss.item(), rank_zero_only=True)

        return dur_loss + h2m_loss + mel_loss
    
    def validation_step(self, batch, batch_idx):
        src_seq_len, trg_seq_len, trg_durations, src_seq, trg_seq = batch
        pred_mel, pred_h2m, pred_durations = self.model(src_seq, src_seq_len, trg_seq, trg_seq_len, trg_durations)
        dur_loss, h2m_loss, mel_loss = self.loss_module(src_seq_len, trg_seq_len, pred_durations, trg_durations, pred_h2m, pred_mel, trg_seq)

        self.log('val_dur_error', dur_loss.item(), rank_zero_only=True)
        self.log('val_h2m_error', h2m_loss.item(), rank_zero_only=True)
        self.log('val_mel_error', mel_loss.item(), rank_zero_only=True)
    
    def test_step(self, batch, batch_idx):
        src_seq_len, trg_seq_len, trg_durations, src_seq, trg_seq = batch
        pred_mel, pred_h2m, pred_durations = self.model(src_seq, src_seq_len, trg_seq, trg_seq_len, trg_durations)
        dur_loss, h2m_loss, mel_loss = self.loss_module(src_seq_len, trg_seq_len, pred_durations, trg_durations, pred_h2m, pred_mel, trg_seq)

        self.log('test_dur_error', dur_loss.item())
        self.log('test_h2m_error', h2m_loss.item())
        self.log('test_mel_error', mel_loss.item())

#FastSpeechModule(model_config_path='../config/model.yaml')

In [21]:
train_dataset = LJSpeechDataset(os.path.join(input_dir, "LJSpeech"), os.path.join(raw_path, 'metadata.csv'), max_src_len=200, max_trg_len=1000, test_batch=True)
val_dataset = LJSpeechDataset(os.path.join(input_dir, "LJSpeech"), os.path.join(raw_path, 'metadata.csv'), max_src_len=200, max_trg_len=1000, test_batch=True)

# set shuffle to false since we already shuffle when the data is split into train/test
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=len(train_dataset))
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=len(val_dataset))

In [12]:
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [22]:
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

CHECKPOINT_PATH = 'B:\Masters FS2\checkpoints'

def train_model(save_name='FastSpeech', **kwargs):
    """
    Inputs:
        model_name - Name of the model you want to run. Is used to look up the class in "model_dict"
        save_name (optional) - If specified, this name will be used for creating the checkpoint and logging directory.
    """
    

    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, save_name),                          # Where to save models
                         accelerator="gpu" if torch.cuda.is_available() else "cpu",                     # We run on a GPU (if possible)
                         devices=1,                                                                          # How many GPUs/CPUs we want to use (1 is enough for the notebooks)
                         max_epochs=2,               
                                                                               num_sanity_val_steps=0,                                                      # How many epochs to train for if no patience is set
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_mel_error"),  # Save the best checkpoint based on the maximum val_acc recorded. Saves only weights and not optimizer
                                    LearningRateMonitor("epoch")],                                           # Log learning rate every epoch
                         enable_progress_bar=True,
                         )                                                           # Set to False if you do not want a progress bar
    trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, save_name + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        model = FastSpeechModule.load_from_checkpoint(pretrained_filename) # Automatically loads the model with the saved hyperparameters
    else:
        pl.seed_everything(42) # To be reproducable
        model = FastSpeechModule(model_config_path='../config/model.yaml')
        trainer.fit(model, train_loader, val_loader)
        #model = FastSpeechModule.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    return model, val_result

train_model()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 42
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type           | Params
-----------------------------------------------
0 | model       | FastSpeech     | 27.6 M
1 | loss_module | FastSpeechLoss | 0     
-----------------------------------------------
27.6 M    Trainable params
0         Non-trainable params
27.6 M    Total params
110.371   Total estimated model params size (MB)


Epoch 0:   0%|          | 0/2 [00:00<?, ?it/s] 

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


torch.Size([16, 200]) torch.Size([16, 1000, 80])
Epoch 1:   0%|          | 0/2 [00:00<?, ?it/s, loss=5.82e+04, v_num=19]        torch.Size([16, 200]) torch.Size([16, 1000, 80])
Epoch 1: 100%|██████████| 2/2 [00:12<00:00,  6.30s/it, loss=5.81e+04, v_num=19]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 2/2 [00:13<00:00,  6.61s/it, loss=5.81e+04, v_num=19]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Testing DataLoader 0: 100%|██████████| 1/1 [00:04<00:00,  4.35s/it]


(FastSpeechModule(
   (model): FastSpeech(
     (encoder): TransformerEncoder(
       (phone_embedding): Embedding(99, 256, padding_idx=0)
       (encoder): Sequential(
         (0): TransformerEncoderLayer(
           (attention): MultiHeadAttention(
             (keys): Linear(in_features=128, out_features=128, bias=False)
             (queries): Linear(in_features=128, out_features=128, bias=False)
             (values): Linear(in_features=128, out_features=128, bias=False)
             (fc_out): Linear(in_features=256, out_features=256, bias=False)
           )
           (feed_forward): PositionWiseFeedForward(
             (layers): Sequential(
               (0): Conv1d(256, 512, kernel_size=(9,), stride=(1,), padding=(4,))
               (1): ReLU()
               (2): Conv1d(512, 256, kernel_size=(9,), stride=(1,), padding=(4,))
               (3): Dropout(p=0.4, inplace=False)
               (4): ReLU()
             )
           )
           (layer_norm_1): LayerNorm((256,), 

In [32]:
def get_mask(seq_lens, max_seq_len):
    # durations will be B, seq_len
    # mels will be B, mel_channels, seq_len
    # seq_lens is B, seq_len
    # mask is B, 1, 1, max_seq_len
    B = seq_lens.shape[0]
    masks = torch.zeros(max_seq_len).repeat(B, 1, 1)

    for b_idx in range(B):
        masks[b_idx, :] = torch.tensor( [ [ seq_len > idx  for idx in torch.arange(0, max_seq_len) ]  for seq_len in seq_lens[b_idx, :] ])

    return masks

mask = get_mask(src_len.reshape(-1, 1), 200)
mask.shape

torch.Size([4, 1, 200])

In [40]:
def get_mask(seq_lens, max_seq_len):
    # durations will be B, seq_len
    # mels will be B, mel_channels, seq_len
    # seq_lens is B, seq_len
    # mask is B, 1, 1, max_seq_len
    B = seq_lens.shape[0]
    masks = torch.zeros(max_seq_len).repeat(B, 1, 1)

    for b_idx in range(B):
        masks[b_idx, :] = torch.tensor( [ [ seq_len > idx  for idx in torch.arange(0, max_seq_len) ]  for seq_len in seq_lens[b_idx, :] ])

    return masks

mask = get_mask(trg_len.reshape(-1, 1), 1000)
mask.shape

torch.Size([4, 1, 1000])

In [68]:
mel.permute(0, 2, 1).masked_fill(mask == 0, float(-6))[0][:, 723:]

tensor([[ 4.0496e-05, -6.0000e+00, -6.0000e+00,  ..., -6.0000e+00,
         -6.0000e+00, -6.0000e+00],
        [ 5.2617e-04, -6.0000e+00, -6.0000e+00,  ..., -6.0000e+00,
         -6.0000e+00, -6.0000e+00],
        [ 9.6685e-04, -6.0000e+00, -6.0000e+00,  ..., -6.0000e+00,
         -6.0000e+00, -6.0000e+00],
        ...,
        [ 6.7222e-04, -6.0000e+00, -6.0000e+00,  ..., -6.0000e+00,
         -6.0000e+00, -6.0000e+00],
        [ 3.4299e-04, -6.0000e+00, -6.0000e+00,  ..., -6.0000e+00,
         -6.0000e+00, -6.0000e+00],
        [ 1.3991e-04, -6.0000e+00, -6.0000e+00,  ..., -6.0000e+00,
         -6.0000e+00, -6.0000e+00]])

In [57]:
mask.shape

torch.Size([4, 1, 1000])

In [38]:
dur.reshape(4, 1, 200).masked_fill(mask == 0, float(-6))[0]

tensor([[ 3,  6,  4,  3,  5,  6,  4,  7,  7,  8,  7,  4,  3,  4,  5,  8,  7,  5,
          5,  8,  9,  5,  4,  4,  7,  4, 10,  8,  5,  3,  7, 11,  6,  6, 11, 29,
          6,  2,  6,  3,  4, 10, 17,  8,  2,  3,  6, 10,  5, 13, 11,  4, 10,  8,
          9,  7, 13,  6,  9,  3,  6,  3,  8,  7,  8,  5,  7,  5,  4,  5,  7,  5,
          5,  7,  9, 10,  5,  5,  9,  5,  3,  2,  9, 11, 10,  4,  4, 10,  6, 11,
          8,  3, 10, 13,  4,  8,  9,  3, 24, 13, 21, -6, -6, -6, -6, -6, -6, -6,
         -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6,
         -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6,
         -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6,
         -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6,
         -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6, -6,
         -6, -6]], dtype=torch.int32)

In [201]:
class LengthRegulator(nn.Module):
    def __init__(self, max_trg_len=5):
        super().__init__()

        self.max_trg_len = max_trg_len

    def forward(self, encoder_output, variance):
        B = encoder_output.shape[0]
        mels = list()

        for b_idx in range(B):
            expanded_seq = torch.concat([ (i+1) * encoder_output[b_idx,i,:].expand(v, -1) for i, v in enumerate(variance[b_idx, :]) ], dim=0)
            seq_len = expanded_seq.shape[0]
            pad_len = self.max_trg_len - seq_len

            if pad_len < 0:
                padded_seq = expanded_seq[:self.max_trg_len, :]
            elif pad_len > 0:
                padded_seq = F.pad( expanded_seq, (0, 0, 0, pad_len), "constant", 0 )

            mels.append(padded_seq)


        expanded_batch = torch.stack(mels, dim=0)
        return expanded_batch
    
LengthRegulator()(torch.ones(2, 3, 256), torch.tensor([ [2, 1, 0], [2, 1, 1] ]))[0]

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [2., 2., 2.,  ..., 2., 2., 2.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [155]:
def get_mask(seq_lens, max_seq_len):
    # seq_lens is B, seq_len
    # mask is B, 1, 1, max_seq_len
    B = seq_lens.shape[0]
    masks = torch.zeros(max_seq_len).repeat(B, 1, 1, 1)

    for b_idx in range(B):
        masks[b_idx, :] = torch.tensor( [ [ seq_len > idx  for idx in torch.arange(0, max_seq_len) ]  for seq_len in seq_lens[b_idx, :] ])

    return masks

    

m = get_mask(torch.tensor([[3], [5]]), 10)
m

IndexError: too many indices for tensor of dimension 1

In [154]:
model_config = yaml.load(open('../config/model.yaml', 'r'), Loader=yaml.FullLoader)

torch.Size([2, 1])

In [636]:
#B, n_heads, query_len, key_len



tensor([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],

         [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],

         [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],

         [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
      

In [166]:
x = torch.ones(2, 10, 256)
mask = get_mask(torch.tensor([[0], [1]]), 10)
out, attn = MultiHeadAttention()(x, x, x, mask)


In [627]:
for row in attn[0][0]:
    print(row.detach().numpy())

[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]


In [None]:
# class TransformerDecoderLayer(nn.Module):
#     def __init__(self, n_heads=8, embed_dim=256, dropout=0.5, forward_expansion=2):
#         super().__init__()

#         self.attention_1 = MultiHeadAttention(n_heads=n_heads, embed_dim=embed_dim)
#         self.attention_2 = MultiHeadAttention(n_heads=n_heads, embed_dim=embed_dim)
#         self.feed_forward = PositionWiseFeedForward(embed_dim=embed_dim, forward_expansion=forward_expansion, dropout=dropout)
#         self.layer_norm_1 = nn.LayerNorm(embed_dim)
#         self.layer_norm_2 = nn.LayerNorm(embed_dim)
#         self.layer_norm_3 = nn.LayerNorm(embed_dim)
#         self.dropout = nn.Dropout(dropout)

#     def forward(self, enc_seq, trg_seq, src_mask=None, trg_mask=None):
#         out, _ = self.attention_1(trg_seq, trg_seq, trg_seq, trg_mask)
#         out = self.dropout(self.layer_norm_1(out + trg_seq))
        
#         out_2, _ = self.attention_2(trg_seq, enc_seq, enc_seq, src_mask)
#         out_2 = self.dropout(self.layer_norm_2(out + out_2))

#         out_3 = self.feed_forward(out_2.permute(0, 2, 1))
#         out_3 = self.dropout(self.layer_norm_3(out_2 + out_3.permute(0, 2, 1)))
#         return out_3


# class TransformerDecoder(nn.Module):
#     def __init__(self, max_seq_len=1000, trg_vocab_size=256, n_layers=4, n_heads=8, embed_dim=256, dropout=0.5, forward_expansion=2):
#         self.pos_embedding = get_positional_encoding(max_seq_len, embed_dim)
#         self.mel_embedding = nn.Embedding(trg_vocab_size, embed_dim) # B, max_enc_seq_len, enc_embed_dim -> B, max_trg_seq_len 
#         self.decoder = nn.Sequential( *[ TransformerDecoderLayer(n_heads=n_heads, embed_dim=embed_dim, dropout=dropout, forward_expansion=forward_expansion) for _ in range(n_layers) ])

#     def forward(self, enc_out):
#         mel_embeddings = self.mel_embedding(phonemes) # max_seq_len, embed_dim
#         model_in = self.pos_embedding + phoneme_embeddings
#         out = self.encoder(model_in)
#         return out

#TransformerDecoderLayer()(torch.randn(5, 100, 256), torch.randn(5, 1000, 256)).shape
#LengthRegulator()(torch.randn(3, 5, 256), torch.tensor([ [1, 1, 2, 1, 1], [2, 1, 1, 1, 1], [1, 1, 1, 1, 6] ])).shape
#VariancePredictor()(torch.randn(5, 100, 256))
#TransformerEncoder()(torch.zeros(200, dtype=torch.int32).unsqueeze(0)).shape
#TransformerEncoderLayer()(torch.randn(5, 100, 256))


In [374]:
#torch.tensor([[1,1,1,1], [2,2,2,2], [3,3,3,3], [4,4,4,4]])
variance = torch.tensor([3, 1, 3, 1])
t = torch.tensor([[1,2,3,4], [1,2,3,4], [1,2,3,4], [1,2,3,4], [1,2,3,4], [1,2,3,4]])
torch.concat([ t[:,i].expand(v, -1) for i, v in enumerate(variance) ], dim=0).T

tensor([[1, 1, 1, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 3, 3, 3, 4]])

In [400]:
# B, seq_len, embed_dim 

b = torch.range(1, 5, dtype=torch.int32).unsqueeze(0).repeat(256, 1).T

  torch.ones(256, dtype=torch.int32).unsqueeze(0).repeat((5, 1)) * torch.range(1, 5, dtype=torch.int32).unsqueeze(0).repeat(256, 1).T


tensor([[1, 1, 1,  ..., 1, 1, 1],
        [2, 2, 2,  ..., 2, 2, 2],
        [3, 3, 3,  ..., 3, 3, 3],
        [4, 4, 4,  ..., 4, 4, 4],
        [5, 5, 5,  ..., 5, 5, 5]], dtype=torch.int32)

In [410]:
def tokens_gen(num_tokens, embed_dim=256):
    return torch.range(1, num_tokens, dtype=torch.int32).unsqueeze(0).repeat(embed_dim, 1).T.unsqueeze(0)

In [415]:
torch.concat([tokens_gen(5), tokens_gen(3), tokens_gen(8)], dim=1).shape

  return torch.range(1, num_tokens, dtype=torch.int32).unsqueeze(0).repeat(embed_dim, 1).T.unsqueeze(0)


torch.Size([1, 16, 256])

In [413]:
tokens_gen(5).shape

  return torch.range(1, num_tokens, dtype=torch.int32).unsqueeze(0).repeat(embed_dim, 1).T.unsqueeze(0)


torch.Size([1, 5, 256])

In [419]:
#torch.tensor([[1,1,1,1], [2,2,2,2], [3,3,3,3], [4,4,4,4]])
variance = torch.tensor([3, 1, 3, 1])
t = torch.tensor([[1,2,3,4], [1,2,3,4], [1,2,3,4], [1,2,3,4], [1,2,3,4], [1,2,3,4]])
torch.concat([ t[:,i].expand(v, -1) for i, v in enumerate(variance) ], dim=0).T

tensor([[1, 1, 1, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 3, 3, 3, 4],
        [1, 1, 1, 2, 3, 3, 3, 4]])

In [484]:
batch = torch.range(1, 5, dtype=torch.int32).unsqueeze(0).repeat(256, 1).T.unsqueeze(0).repeat(3, 1, 1)
v = torch.tensor([ [1, 1, 2, 1, 1], [2, 1, 1, 1, 1], [1, 1, 1, 1, 6] ])
max_trg_len = 15

def expand_seqs(encoder_output, variances):
    B = encoder_output.shape[0]
    mels = list()

    for b_idx in range(B):
        expanded_seq = torch.concat([ encoder_output[b_idx,i,:].expand(v, -1) for i, v in enumerate(variances[b_idx, :]) ], dim=0)
        seq_len = expanded_seq.shape[0]
        pad_len = max_trg_len - seq_len

        if pad_len < 0:
            padded_seq = expanded_seq[:max_trg_len, :]
        elif pad_len > 0:
            padded_seq = F.pad( expanded_seq, (0, 0, 0, pad_len), "constant", -1 )

        mels.append(padded_seq)


    expanded_batch = torch.stack(mels, dim=0)
    return expanded_batch


print(batch.shape)
print(expand_seqs(batch, v).shape)

torch.Size([3, 5, 256])
torch.Size([3, 15, 256])


  batch = torch.range(1, 5, dtype=torch.int32).unsqueeze(0).repeat(256, 1).T.unsqueeze(0).repeat(3, 1, 1)


In [460]:
batch[0]

tensor([[1, 1, 1,  ..., 1, 1, 1],
        [2, 2, 2,  ..., 2, 2, 2],
        [3, 3, 3,  ..., 3, 3, 3],
        [4, 4, 4,  ..., 4, 4, 4],
        [5, 5, 5,  ..., 5, 5, 5]], dtype=torch.int32)

In [471]:
F.pad( batch[0], (0, 0, 0, 2), "constant", -1 ).shape

torch.Size([7, 256])

In [445]:
batch[0][0].expand(2, -1).shape

torch.Size([2, 256])

In [447]:
batch[0].shape

torch.Size([5, 256])

In [450]:
torch.concat([batch[0], batch[0][0].expand(2, -1)]).shape

torch.Size([7, 256])