In [57]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import numpy as np

from config import Config

In [58]:
class Config:
    # default
    base_path = '../LJSpeech-1.1'
    preprocess_base_path = base_path + '/preprocessed'

    # preprocess
    wav_paths = preprocess_base_path + '/paths'
    mel_paths = preprocess_base_path + '/mels'
    spec_paths = preprocess_base_path + '/specs'
    transcript_paths = preprocess_base_path + '/transcripts'
    phoneme_paths = preprocess_base_path + '/phonemes'

    # metadata
    data_path = base_path + '/wavs'
    metadata_path = base_path + '/metadata.csv'

    # train
    batch_size = 16
    

    # model
    num_phonemes = 70
    num_mels = 80
    embedding_dim = 512
    d_model = 512
    num_heads = 8
    num_encoder_layers = 6
    num_decoder_layers = 6

In [59]:
import librosa
import matplotlib.pyplot as plt

def visualize_specs(S, mel_S, sr):
    plt.figure(figsize=(12, 8))
    plt.subplot(2, 1, 1)
    librosa.display.specshow(S, sr=sr, x_axis='time', y_axis='log')
    plt.colorbar(format='%+2.0f dB')
    plt.title('Spectrogram')

    plt.subplot(2, 1, 2)
    librosa.display.specshow(mel_S, sr=sr, x_axis='time', y_axis='mel')
    plt.colorbar(format='%+2.0f dB')
    plt.title('Mel spectrogram')
    plt.tight_layout()
    plt.show()

In [60]:
class TransformerTTSDataset(Dataset):
    def __init__(self):
        self.phoneme_to_index = {}
        with open(Config.metadata_path, 'r') as f:
            lines = f.readlines()
            self.wav_names = [line.split('|')[0] for line in lines]

    def __len__(self):
        return len(self.wav_names)

    def __getitem__(self, idx):
        wav_name = self.wav_names[idx]

        phoneme = np.load(f'{Config.phoneme_paths}/{wav_name}.npy')
        spectrogram = np.load(f'{Config.spec_paths}/{wav_name}.npy')
        melspectrogram = np.load(f'{Config.mel_paths}/{wav_name}.npy')

        # phoneme to index
        for ph in phoneme:
            if ph not in self.phoneme_to_index:
                self.phoneme_to_index[ph] = len(self.phoneme_to_index) + 1

        phoneme_seq = [self.phoneme_to_index[ph] for ph in phoneme]

        phoneme_seq = torch.LongTensor(phoneme_seq)
        spectrogram = torch.FloatTensor(spectrogram)
        melspectrogram = torch.FloatTensor(melspectrogram)
        
        return (
            phoneme_seq,
            spectrogram,
            melspectrogram
        )

In [61]:
def pad_sequence1D(seq):
    return nn.utils.rnn.pad_sequence(sequences=seq, batch_first=True, padding_value=0)


def pad_sequence2D(seqs):
    B                   = len(seqs)
    T                   = len(seqs[0])
    max_len             = max([len(seq[0]) for seq in seqs])
    padded_mel          = torch.zeros(B, T, max_len, ) # 멜 스펙트로그램 차원 맞춰주기
    for i, seq in enumerate(seqs):
        # padded_mel[i, :, :] = -80.0
        padded_mel[i, :seq.size(0), :seq.size(1)] = seq
    return padded_mel


def pad_sequence1D_stops(seqs):
    for seq in seqs:
        seq[-1] = 1
    return nn.utils.rnn.pad_sequence(sequences=seqs, batch_first=True, padding_value=1)

def collate_fn(batch):
    seqs = [item[0] for item in batch]
    spectrograms = [item[1] for item in batch]
    melspectrograms = [item[2] for item in batch]
    stops = [torch.zeros_like(mel[0]) for mel in melspectrograms]

    seq_lengths = torch.LongTensor([len(seq) for seq in seqs])
    mel_lengths = torch.LongTensor([len(mel[0]) for mel in melspectrograms])

    padded_seqs = pad_sequence1D(seqs)
    padded_spectrograms = pad_sequence2D(spectrograms)
    padded_melspectrograms = pad_sequence2D(melspectrograms)
    padded_stop_seqs = pad_sequence1D_stops(stops)

    return (
        padded_seqs, 
        padded_spectrograms, 
        padded_melspectrograms, 
        padded_stop_seqs, 
        seq_lengths,
        mel_lengths
    )

In [62]:
class Conv1DBN(nn.Module):
    def __init__(self, in_dim, out_dim, kernel_size, padding, act=None):
        super(Conv1DBN, self).__init__()
        self.conv_1d = nn.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=padding)
        self.bn = nn.BatchNorm1d(out_dim)
        self.dropout = nn.Dropout(0.5)
        self.activation = act

    def forward(self, x):
        x = self.conv_1d(x)
        x = self.bn(x)
        x = self.dropout(x)

        if self.activation is not None:
            x = self.activation(x)

        return x

In [63]:
class Prenet(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(Prenet, self).__init__()
        self.in_dim = in_dim
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(
            Conv1DBN(in_dim, hidden_dim, 5, 5//2, self.relu),
            Conv1DBN(hidden_dim, hidden_dim, 5, 5//2, self.relu),
            Conv1DBN(hidden_dim, out_dim, 5, 5//2, self.relu)
        )
        self.linear = nn.Linear(out_dim, out_dim)

    def forward(self, x):

        x = x.transpose(1, 2)
        x = self.layers(x)
        x = x.transpose(1, 2)
        return self.linear(x)

In [64]:
class ScaledPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(ScaledPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.max_len = max_len

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

        self.alpha = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return x + self.alpha * self.pe[:, :x.size(1)]

In [65]:
class PostNet(nn.Module):
    def __init__(self):
        super(PostNet, self).__init__()
        self.tanh = nn.Tanh()
        self.layers = nn.Sequential(
            Conv1DBN(80, 512, 5, 5//2, self.tanh),
            Conv1DBN(512, 512, 5, 5//2, self.tanh),
            Conv1DBN(512, 512, 5, 5//2, self.tanh),
            Conv1DBN(512, 512, 5, 5//2, self.tanh),
            Conv1DBN(512, 80, 5, 5//2)
        )
        self.linear = nn.Linear(80, 80)

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.layers(x)
        x = x.transpose(1, 2)
        return self.linear(x)

In [66]:
class TransformerTTS(nn.Module):
   
    def __init__(self):
        super(TransformerTTS, self).__init__()
        self.phoneme_embedding = nn.Embedding(num_embeddings=Config.num_phonemes, embedding_dim=Config.embedding_dim)
        self.encoder_prenet = Prenet(Config.embedding_dim, 512, 512)
        self.decoder_prenet = Prenet(Config.num_mels, 512, 512)
        
        self.scaled_positional_encoding = ScaledPositionalEncoding(Config.d_model)
        self.transformer = nn.Transformer(
            d_model=Config.d_model,
            num_encoder_layers=Config.num_encoder_layers,
            num_decoder_layers=Config.num_decoder_layers,
            nhead=Config.num_heads,
            batch_first=True,
        )
        
        self.mel_linear = nn.Linear(Config.d_model, Config.num_mels)
        self.stop_linear = nn.Linear(Config.d_model, 1)
        self.sigmoid = nn.Sigmoid()

        self.postnet = PostNet()


    def forward(self, phoneme_sequences, mel_spectrograms, phoneme_sequence_lengths=None, mel_spectrogram_lengths=None):
        self.initialize_masks(x_lengths=phoneme_sequence_lengths, y_lengths=mel_spectrogram_lengths)

        phoneme_embeddings = self.phoneme_embedding(phoneme_sequences)
        print('phoneme을 512차원 embedding vector로 변환', phoneme_embeddings.shape)
              
        prenet_phoneme_embeddings = self.encoder_prenet(phoneme_embeddings)
        print('encoder_prenet_out', prenet_phoneme_embeddings.shape)

        positional_phoneme_embeddings = self.scaled_positional_encoding(prenet_phoneme_embeddings)
        print('positional_encoded', positional_phoneme_embeddings.shape)

        decoder_input_mel_spectrograms = mel_spectrograms.transpose(1, 2)
        prenet_mel_spectrograms = self.decoder_prenet(decoder_input_mel_spectrograms)
        print('decoder_prenet_out', prenet_mel_spectrograms.shape)

        positional_mel_spectrograms = self.scaled_positional_encoding(prenet_mel_spectrograms)
        print('positional_encoded', positional_mel_spectrograms.shape)
        
        transformer_out = self.transformer(
            src=positional_phoneme_embeddings, 
            tgt=positional_mel_spectrograms,
            src_mask=self.src_mask,
            src_key_padding_mask=self.src_key_padding_mask,
            tgt_mask=self.tgt_mask,
            tgt_key_padding_mask=self.tgt_key_padding_mask,
            memory_mask=self.memory_mask
        )
        print('transformer_out', transformer_out.shape)

        mel_linear_out = self.mel_linear(transformer_out)
        print('mel_linear_out', mel_linear_out.shape)

        postnet_out = self.postnet(mel_linear_out)
        print('postnet_out', postnet_out.shape)

        mel_pred = mel_linear_out + postnet_out

        stop_token = self.sigmoid(self.stop_linear(transformer_out))
        print('stop_token', stop_token.shape)
        
        mel_pred = mel_pred.transpose(1, 2)
        stop_token = stop_token.squeeze(-1)

        # mel_pred에서 seq_len을 넘어가는 부분은 모두 -80.0 으로 채워준다.
        # for i, seq_len in enumerate(x_legnths):
        #     mel_pred[i, :, seq_len:] = -80.0

        mel_pred.data.masked_fill_(
            self.tgt_key_padding_mask.unsqueeze(-1).repeat(1, 1, mel_pred.size(-1)), 0.0)
        stop_token.data.masked_fill_(
            self.tgt_key_padding_mask.unsqueeze(-1), 1e3)

        return mel_pred, stop_token
    

    def generate_square_subsequent_mask(self, lsz, rsz):
        return torch.triu(torch.ones(lsz, rsz) * float('-inf'), diagonal=1)
    

    def generate_padding_mask(self, lengths, max_len=None):
        batch_size = lengths.size(0)
        if max_len is None:
            max_len = torch.max(lengths).item()
        ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(dtype=lengths.dtype, device=lengths.device)
        return ids >= lengths.unsqueeze(1).expand(-1, max_len)
    
    def initialize_masks(self, x_lengths=None, y_lengths=None):
        self.src_mask = None
        self.tgt_mask = None
        self.memory_mask = None
        self.src_key_padding_mask = None
        self.tgt_key_padding_mask = None
        if x_lengths is not None:
            S = x_lengths.max().item()
            self.src_mask = self.generate_square_subsequent_mask(S, S).to(device=x_lengths.device)         # text sequence self-attention mask
            self.src_key_padding_mask = self.generate_padding_mask(x_lengths).to(device=x_lengths.device)  # text sequence padding mask
        if y_lengths is not None:
            T = y_lengths.max().item()
            self.tgt_mask = self.generate_square_subsequent_mask(T, T).to(device=y_lengths.device)         # mel sequence self-attention mask
            self.tgt_key_padding_mask = self.generate_padding_mask(y_lengths).to(device=y_lengths.device)  # mel sequence padding mask
        if x_lengths is not None and y_lengths is not None:
            T = y_lengths.max().item()
            S = x_lengths.max().item()
            self.memory_mask = self.generate_square_subsequent_mask(T, S).to(device=y_lengths.device)      # text-mel cross attention mask
            

In [67]:
class TransformerTTSLoss():
    def __init__(self):
        self.mel_loss = nn.MSELoss()
        self.stop_loss = nn.BCELoss()
        self.alpha = 5.0

    def __call__(self, mel_pred, mel_target, stop_pred, stop_target):
        mel_loss = self.mel_loss(mel_pred, mel_target)
        stop_loss = self.stop_loss(stop_pred, stop_target)
        return mel_loss + stop_loss * self.alpha

In [68]:
transformer_tts_dataset = TransformerTTSDataset()
dataloader = DataLoader(transformer_tts_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

model = TransformerTTS()
criterion = TransformerTTSLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

print(len(dataloader))

# 데이터로더에서 데이터를 가져와서 확인하기
for i, data in enumerate(dataloader):
    phoneme, spec, mel, stop, phoneme_len, mel_len = data
    print(f'phoneme {phoneme.shape}, spec: {spec.shape}, mel: {mel.shape}, stop: {stop.shape}, phoneme_len: {phoneme_len}, mel_len: {mel_len}')
    print(phoneme_len)

    print("- model information")
    mel_pred, stop_pred = model(phoneme, mel, phoneme_len, mel_len)

    loss = criterion(mel_pred, mel, stop_pred, stop)
    print(f'loss: {loss.item()}')

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 2 == 0:
        visualize_specs(mel_pred[0].detach().numpy(), mel[0].detach().numpy(), sr=22050)

3275
phoneme torch.Size([4, 116]), spec: torch.Size([4, 1025, 354]), mel: torch.Size([4, 80, 354]), stop: torch.Size([4, 354]), phoneme_len: tensor([116,  78, 109, 114]), mel_len: tensor([354, 287, 344, 339])
tensor([116,  78, 109, 114])
- model information
phoneme을 512차원 embedding vector로 변환 torch.Size([4, 116, 512])
encoder_prenet_out torch.Size([4, 116, 512])
positional_encoded torch.Size([4, 116, 512])
decoder_prenet_out torch.Size([4, 354, 512])
positional_encoded torch.Size([4, 354, 512])




transformer_out torch.Size([4, 354, 512])
mel_linear_out torch.Size([4, 354, 80])
postnet_out torch.Size([4, 354, 80])
stop_token torch.Size([4, 354, 1])


RuntimeError: The size of tensor a (80) must match the size of tensor b (354) at non-singleton dimension 1

In [None]:
def generate_square_subsequent_mask(self, size=200): # Generate mask covering the top right triangle of a matrix
        mask = torch.triu(torch.full((size, size), float('-inf')), diagonal=1)
        return mask

print(generate_square_subsequent_mask(5))

tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer

# Sample Dataset
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.tokenizer(self.data[idx], return_tensors='pt', padding=True, truncation=True)

# Collate function to handle variable lengths and padding
def collate_fn(batch):
    input_ids = [item['input_ids'].squeeze(0) for item in batch]
    attention_masks = [item['attention_mask'].squeeze(0) for item in batch]

    input_ids = nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_masks = nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=0)

    return input_ids, attention_masks

# Sample data
data = ["Hello, how are you?", "I am fine, thank you!", "What about you?"]

# DataLoader
dataset = CustomDataset(data)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

# Transformer model with masked attention
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_length, d_model))
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask):
        src = self.embedding(src) + self.positional_encoding[:, :src.size(1), :]
        tgt = self.embedding(tgt) + self.positional_encoding[:, :tgt.size(1), :]
        output = self.transformer(src, tgt, src_key_padding_mask=src_mask, tgt_key_padding_mask=tgt_mask)
        return self.fc(output)

# Function to create masks
def create_mask(src, tgt):
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, src_seq_len)
    tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)  # (batch_size, 1, tgt_seq_len, 1)

    return src_mask, tgt_mask

# Initialize model, loss, and optimizer
vocab_size = 30522  # Vocabulary size of BERT tokenizer
d_model = 512
nhead = 8
num_encoder_layers = 3
num_decoder_layers = 3
dim_feedforward = 2048
max_seq_length = 50

model = TransformerModel(vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# Training loop
for epoch in range(10):
    model.train()
    for batch in dataloader:
        input_ids, attention_masks = batch
        tgt_input = input_ids[:, :-1]
        tgt_output = input_ids[:, 1:]

        src_mask, tgt_mask = create_mask(input_ids, tgt_input)

        optimizer.zero_grad()
        output = model(input_ids, tgt_input, src_mask, tgt_mask)

        loss = criterion(output.view(-1, vocab_size), tgt_output.view(-1))
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")


ModuleNotFoundError: No module named 'transformers'