In [1]:
# Mount Google Drive so the LJSpeech data and checkpoints stored under
# /content/drive/MyDrive/project1 are reachable from this runtime.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [2]:
pip install hickle

Collecting hickle
  Downloading hickle-5.0.3-py3-none-any.whl.metadata (22 kB)
Downloading hickle-5.0.3-py3-none-any.whl (107 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/108.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: hickle
Successfully installed hickle-5.0.3


In [3]:
from string import printable
import numpy as np
import pandas as pd
from csv import QUOTE_NONE
import hickle as hkl
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, ToTensor

In [4]:
SEED = 42  # single global seed for reproducible runs

# PyTorch (CPU generator and current CUDA device), then NumPy's legacy global RNG.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

**Chuẩn bị dữ liệu**

In [5]:
ds_metadata = pd.read_csv(
    "/content/drive/MyDrive/project1/LJSpeech-1.1/LJSpeech/metadata.csv",
    sep = "|",
    # Treat '"' as an ordinary character so quotes inside transcripts are kept verbatim.
    quoting = QUOTE_NONE,
    names = ["WAV Name", "Transcripts", "Normalized Transcripts"],
)

In [6]:
# Preview the first rows of the metadata table.
ds_metadata.head(n = 5)

Unnamed: 0,WAV Name,Transcripts,Normalized Transcripts
0,LJ001-0001,"Printing, in the only sense with which we are ...","Printing, in the only sense with which we are ..."
1,LJ001-0002,in being comparatively modern.,in being comparatively modern.
2,LJ001-0003,For although the Chinese took impressions from...,For although the Chinese took impressions from...
3,LJ001-0004,"produced the block books, which were the immed...","produced the block books, which were the immed..."
4,LJ001-0005,the invention of movable metal letters in the ...,the invention of movable metal letters in the ...


In [7]:
# Only the normalised transcripts are used. Re-assigning (instead of inplace=True)
# with errors='ignore' keeps this cell safe to re-run.
ds_metadata = ds_metadata.drop(columns = ['Transcripts'], errors = 'ignore')

In [8]:
# Lower-case and normalise typography so characters can be looked up in `vocab`.
# Every step goes through the .str accessor: plain Series.replace (without
# regex=True) only replaces cells whose *entire* value equals the pattern, so the
# 'i.e.', ';' and double-space rules previously never matched anything.
ds_metadata['Normalized Transcripts'] = (
    ds_metadata['Normalized Transcripts']
    .str.lower()
    .str.replace('ü', 'u', regex = False)
    .str.replace('“', '"', regex = False)
    .str.replace('”', '"', regex = False)
    .str.replace('’', "'", regex = False)
    .str.replace('i.e.', 'i e ', regex = False)
    .str.replace(';', '', regex = False)
    # Collapse runs of spaces last, so gaps created by the rules above are merged too.
    .str.replace(' +', ' ', regex = True)
    .str.strip()
)

In [9]:
# Character vocabulary: printable ASCII without digits, upper-case letters and
# the control whitespace characters (the plain space is kept). A character's
# position in this string is its integer id.
vocab = ''.join(
    ch for ch in printable
    if ch not in '\t\n\r\x0b\x0c' and not ch.isdigit() and not ch.isupper()
)
print(vocab)
print(len(vocab))

abcdefghijklmnopqrstuvwxyz!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ 
59


In [10]:
# Encode characters as integer ids.

def ch_to_int(transcript, alphabet = None):
    '''
    Encode a transcript as the list of its characters' positions in `alphabet`.

    Args:
        transcript: text to encode (expected to be normalised / lower-cased).
        alphabet: string whose character positions define the ids; defaults to
            the global `vocab`.

    Returns:
        List of ints, or the sentinel string 'bad encoding' when a character is
        not in the alphabet or the value is not text (e.g. NaN for a missing
        transcript).
    '''
    if alphabet is None:
        alphabet = vocab
    try:
        return [alphabet.index(ch) for ch in transcript]
    except (ValueError, TypeError):
        # ValueError: character not in the alphabet; TypeError: non-string value (NaN).
        return 'bad encoding'

# Encode every transcript and keep only rows where every character is in `vocab`
# (ch_to_int returns a list of ids on success, the 'bad encoding' string otherwise).
# NOTE: despite the column name, values are integer character ids, not one-hot vectors.
ds_metadata['One-Hot Encoding'] = ds_metadata['Normalized Transcripts'].apply(ch_to_int)
ds_metadata = (
    ds_metadata[ds_metadata['One-Hot Encoding'].map(lambda enc: isinstance(enc, list))]
    .reset_index(drop = True)  # contiguous 0..n-1 index, no leftover 'index' column
)

In [13]:
import os
from tqdm import tqdm

# Check for a GPU: the mel transform below runs on it when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Đang sử dụng:", device)

# Audio-processing parameters (FRAME_SIZE and HOP_SIZE are in milliseconds).
SAMPLE_RATE = 22050
N_FFT = 2048
FRAME_SIZE = 50
HOP_SIZE = 12.5
N_MELS = 80

# Reuse the metadata prepared by the earlier cells (lower-cased, normalised and
# integer-encoded). Re-reading metadata.csv here would throw that work away,
# including the 'One-Hot Encoding' column the Dataset needs, and reading it
# without quoting=QUOTE_NONE lets '"' characters act as CSV quotes.
ds_metadata['WAV Name'] = ds_metadata['WAV Name'].str.strip()

def wav_to_melspec(root, metadata = None,
                   out_path = '/content/drive/MyDrive/project1/LJSpeech-1.1/LJSpeech/ljspeech_gpu.hkl'):
    '''
    Convert every utterance in `metadata` to a log-mel spectrogram and dump the
    list of records to `out_path` with hickle.

    Each record is (wav_name, normalized_transcript, char_ids, log_mel, stop_token):
    log_mel is an (N_MELS, n_frames) float array (first audio channel) and
    stop_token has length n_frames with 1.0 only on the last frame. This is the
    layout LJSpeech.__getitem__ reads (indices 2, 3 and 4).

    Args:
        root: directory containing the .wav files.
        metadata: DataFrame with 'WAV Name', 'Normalized Transcripts' and
            'One-Hot Encoding' columns; defaults to the global ds_metadata.
        out_path: destination .hkl file.

    Raises:
        KeyError: if `metadata` lacks one of the required columns.
    '''
    if metadata is None:
        metadata = ds_metadata
    missing = {'WAV Name', 'Normalized Transcripts', 'One-Hot Encoding'} - set(metadata.columns)
    if missing:
        # Fail once, up front, instead of logging an error for every single file.
        raise KeyError(f"metadata is missing columns: {sorted(missing)}")

    data = []

    # Mel filterbank on the same device as the audio.
    w2ms_trans = torchaudio.transforms.MelSpectrogram(SAMPLE_RATE,
                                                      n_fft=N_FFT,
                                                      win_length=FRAME_SIZE * SAMPLE_RATE // 1000,
                                                      hop_length=int(HOP_SIZE * SAMPLE_RATE // 1000),
                                                      f_min=0,
                                                      f_max=8000,
                                                      n_mels=N_MELS,
                                                      window_fn=torch.hann_window,
                                                      power=1.0,
                                                      center=False,
                                                      norm='slaney',
                                                      mel_scale='slaney'
                                                      ).to(device)

    for i in tqdm(range(len(metadata))):
        row = metadata.iloc[i]
        file_name = row['WAV Name']
        try:
            wave, sr = torchaudio.load(os.path.join(root, file_name + ".wav"))
            if sr != SAMPLE_RATE:
                # The filterbank is built for SAMPLE_RATE; resample anything else.
                wave = torchaudio.functional.resample(wave, sr, SAMPLE_RATE)
            wave = wave.to(device)

            # Magnitude mel spectrogram -> natural log (clipped to avoid log(0)).
            log_mel = torch.log(torch.clip(w2ms_trans(wave), min=1e-5))

            # Stop-token target: 1.0 on the final frame only.
            stop_token = np.zeros(log_mel.shape[-1])
            stop_token[-1] = 1.0

            data.append((
                file_name,
                row['Normalized Transcripts'],
                list(row['One-Hot Encoding']),
                log_mel.cpu().numpy()[0],
                stop_token
            ))

        except Exception as e:
            # Best effort: report the failing file and continue with the rest.
            print(f"Lỗi {file_name}: {e}")

    hkl.dump(data, out_path, compression='gzip')

# Run the conversion over the whole (filtered) metadata table.
root = "/content/drive/MyDrive/project1/LJSpeech-1.1/LJSpeech/wavs/"
wav_to_melspec(root)


Đang sử dụng: cuda


  2%|▏         | 222/13100 [01:34<1:31:17,  2.35it/s]


KeyboardInterrupt: 

In [None]:
# wav_to_melspec already dumps ljspeech_gpu.hkl straight to Drive; only move a
# copy left in /content by an older run instead of failing when none exists.
import os
import shutil

_local_hkl = '/content/ljspeech_gpu.hkl'
if os.path.exists(_local_hkl):
    shutil.move(_local_hkl, '/content/drive/MyDrive/project1/LJSpeech-1.1/LJSpeech/ljspeech_gpu.hkl')

**Tacotron2**

In [None]:
# --- Optimisation ---
EPOCHS = 200
BATCH_SIZE = 64  # 16 was used previously
LR = 0.001
EPS = 1e-08  # Adam epsilon (1e-06 was used previously)
WEIGHT_DECAY = 1e-06

# --- Model sizes ---
VOCAB_SIZE = 59  # must equal len(vocab)
EMBEDDING_SIZE = 512
NR_MELS = 80

# --- Inference ---
THRESHOLD = 0.5  # decoding stops once sigmoid(stop logit) >= THRESHOLD
MAX_DEC_STEPS = 1000  # hard cap on generated frames

# --- I/O and device ---
CHKP_PATH = '/content/drive/MyDrive/project1/LJSpeech-1.1/LJSpeech/taco2_1000_ds_size.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

**Dataset**

In [None]:
#@title Dataset { form-width: "22%" }
class LJSpeech(Dataset):
    '''
    Pre-computed LJSpeech features loaded from a hickle dump.

    Records are read positionally: index 2 = character ids, 3 = log-mel
    (n_mels, n_frames), 4 = stop-token targets. Items are returned as
    (char ids LongTensor, mel (n_frames, n_mels) FloatTensor, stop FloatTensor).
    '''
    def __init__(self, ds_path, size = 12000 + 95 + 500):
        '''
        Args:
            ds_path: path of the .hkl file to load.
            size: number of records to keep (train + valid + test by default).
                Clamped to the number of stored records so __len__ never
                promises items that do not exist (files may have been skipped
                during preprocessing).
        '''
        self.ds = hkl.load(ds_path)
        self.sz = min(size, len(self.ds))
        self.ds = self.ds[:self.sz]

    def __len__(self):
        return self.sz

    def __getitem__(self, index):
        record = self.ds[index]
        transcript = torch.LongTensor(record[2])
        melspec = torch.FloatTensor(record[3])
        stop_token = torch.FloatTensor(record[4])
        # (n_mels, T) -> (T, n_mels) so pad_sequence in collate_pad pads along time.
        return transcript, melspec.transpose(1, 0), stop_token

    def min_max(self):
        '''Print and return the (min, max) mel value over this dataset (no padding involved).'''
        lo, hi = float('inf'), float('-inf')
        for i in range(len(self)):
            melspec = self[i][1]
            lo = min(lo, melspec.min().item())
            hi = max(hi, melspec.max().item())
        print(lo, hi)
        return lo, hi

def collate_pad(batch):
    '''
    Pad a list of (transcript, mel (T, n_mels), stop_token) items into one batch.

    Returns:
        transcripts (B, max_chars) padded with 0,
        mels (B, n_mels, max_T) padded with 0,
        stop_tokens (B, max_T) forced to 1.0 from each item's last real frame on,
        mel_lens (B,), trans_lens (B,)

    NOTE(review): id 0 is also 'a' in `vocab`, so transcript padding looks like
    'a' to the encoder unless it is masked/packed downstream.
    '''
    transcripts, melspecs, stop_tokens = zip(*batch)
    pad = nn.utils.rnn.pad_sequence

    trans_lens = torch.LongTensor([t.size(0) for t in transcripts])
    mel_lens = torch.LongTensor([m.size(0) for m in melspecs])

    padded_trans = pad(transcripts, batch_first = True)
    padded_mels = pad(melspecs, batch_first = True)  # (B, max_T, n_mels)
    padded_stops = pad(stop_tokens, batch_first = True)
    # Padding frames count as "stopped".
    for row, length in zip(padded_stops, mel_lens):
        row[length - 1:] = 1.0

    return padded_trans, padded_mels.permute(0, 2, 1), padded_stops, mel_lens, trans_lens

**Encoder**

In [None]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_size):
        '''
        Text encoder: embedding -> 3 x (Conv1d k=5, BatchNorm, ReLU, Dropout)
        -> bidirectional LSTM.

        Args:
            vocab_size: number of character ids.
            embedding_size: embedding / conv channel size. The BiLSTM has
                embedding_size // 2 units per direction, so its output size is
                2 * (embedding_size // 2).
        '''
        super().__init__()

        self.embedding_size = embedding_size

        self.embed = nn.Embedding(vocab_size, embedding_size)
        self.conv = nn.Sequential(
            nn.Conv1d(embedding_size, embedding_size, 5, padding = 2, bias = True),
            nn.BatchNorm1d(embedding_size),
            nn.ReLU(),
            nn.Dropout(),
            nn.Conv1d(embedding_size, embedding_size, 5, padding = 2, bias = True),
            nn.BatchNorm1d(embedding_size),
            nn.ReLU(),
            nn.Dropout(),
            nn.Conv1d(embedding_size, embedding_size, 5, padding = 2, bias = True),
            nn.BatchNorm1d(embedding_size),
            nn.ReLU(),
            nn.Dropout()
        )

        # Xavier init scaled for the ReLU that follows each conv.
        for layer in self.conv:
            if isinstance(layer, nn.Conv1d):
                torch.nn.init.xavier_uniform_(layer.weight, gain = torch.nn.init.calculate_gain('relu'))

        self.bi_lstm = nn.LSTM(embedding_size, embedding_size // 2,
                               batch_first = True, bidirectional = True)
        self.drop = nn.Dropout(0.1)  # NOTE(review): not used in forward

    def forward(self, x, input_lengths = None):
        '''
        Args:
            x: (batch, seq_len) LongTensor of character ids.
            input_lengths: optional (batch,) true sequence lengths. When given,
                the BiLSTM runs on a packed sequence so padding does not leak
                into the backward direction; padded output steps are zeros.

        Returns:
            (batch, seq_len, 2 * (embedding_size // 2)) encoder outputs.
        '''
        x = self.embed(x).transpose(1, 2)  # (B, E, T) for Conv1d
        x = self.conv(x).transpose(1, 2)   # back to (B, T, E)
        if input_lengths is None:
            x, _ = self.bi_lstm(x)
            return x

        total_len = x.size(1)
        packed = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths.cpu(), batch_first = True, enforce_sorted = False)
        packed_out, _ = self.bi_lstm(packed)
        # total_length keeps the output aligned with the padded input length.
        x, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first = True, total_length = total_len)
        return x

**Aligner**
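Căn chỉnh (alignment) giữa chuỗi ký tự đầu vào và từng bước của decoder bằng cơ chế attention nhạy vị trí (location-sensitive attention)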

In [None]:
class AlignNN(nn.Module):
    def __init__(self, enc_hidden_size, dec_hidden_size):
        '''
        Additive scoring network of the location-sensitive attention.
        For every encoder step j it computes the energy
            e_j = ll_out(tanh(V h_j + W s + U f_j))
        from the encoder output h_j, the decoder state s and the location
        feature f_j.

        Args:
            enc_hidden_size: size of one encoder output vector
            dec_hidden_size: size of the decoder state used as the query
        '''
        super().__init__()
        attn_dim = 128
        self.ll_dec_prev_hidden = nn.Linear(dec_hidden_size, attn_dim)  # W: decoder query
        self.ll_enc_hidden = nn.Linear(enc_hidden_size, attn_dim)       # V: encoder memory
        self.ll_prev_step_att = nn.Linear(32, attn_dim)                 # U: 32 location filters
        self.ll_out = nn.Linear(attn_dim, 1)                            # scalar energy

        # Cached V h for the current utterance; the caller sets it once per batch.
        self.proj1 = None

    def forward(self, dec_prev_hidden, prev_step_att):
        '''
        Args:
            dec_prev_hidden: decoder query, (B, 1, dec_hidden_size)
            prev_step_att: location features, (B, 32, T_enc)
        Returns:
            (B, T_enc) unnormalised attention energies
        '''
        query = self.ll_dec_prev_hidden(dec_prev_hidden)                  # (B, 1, 128)
        location = self.ll_prev_step_att(prev_step_att.transpose(1, 2))   # (B, T_enc, 128)
        energies = self.ll_out(torch.tanh(self.proj1 + query + location))  # (B, T_enc, 1)
        return energies.squeeze(2)

class Attention(nn.Module):
    def __init__(self, enc_hidden_size, dec_hidden_size):
        '''
        Location-sensitive additive attention (https://arxiv.org/abs/1409.0473
        extended with location features computed from past attention weights).

        Args:
            enc_hidden_size: size of one encoder output vector
            dec_hidden_size: size of the decoder state used as the query
        '''
        super().__init__()
        # Location features come from 2 channels: the previous step's weights and
        # their running sum. The paper-style variant uses only the cumulative
        # weights, i.e. nn.Conv1d(1, 32, 31, padding = 15).
        self.location_att = nn.Conv1d(2, 32, 31, padding = 15)
        self.align_nn = AlignNN(enc_hidden_size, dec_hidden_size)
        # Optional bool mask (B, T_enc), True at padded encoder steps; set by the decoder.
        self.mask = None

    def forward(self, enc_out, dec_prev_h, prev_step_att):
        '''
        Args:
            enc_out: encoder outputs, (B, T_enc, enc_hidden_size)
            dec_prev_h: decoder query passed to AlignNN, (B, 1, dec_hidden_size)
            prev_step_att: previous and cumulative attention weights, (B, 2, T_enc)
        Returns:
            (attention weights (B, T_enc), context vector (B, enc_hidden_size))
        '''
        location_features = self.location_att(prev_step_att)      # (B, 32, T_enc)
        energies = self.align_nn(dec_prev_h, location_features)     # (B, T_enc)

        if self.mask is not None:
            # Out-of-place so autograd records the masking (writing via .data hides it).
            energies = energies.masked_fill(self.mask, -float('inf'))

        att_weights = torch.softmax(energies, dim = 1)
        context = torch.bmm(att_weights.unsqueeze(1), enc_out).squeeze(1)
        return att_weights, context

**Decoder**

In [None]:
class Decoder(nn.Module):
    '''
    Single-step decoder built on one 2-layer LSTM (earlier variant).

    NOTE: a later cell defines another `class Decoder`; once that cell has run,
    the name `Decoder` refers to the later class and this one is no longer used.
    '''
    def __init__(self, melspec_frame_shape):
        super().__init__()
        # Pre-net: 2 x (Linear -> ReLU -> Dropout), mel frame -> 256 features.
        self.pre_net = nn.Sequential(
            nn.Linear(melspec_frame_shape, 256),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout()
        )

        # 2-layer LSTM over [context vector, pre-net frame].
        self.lstm = nn.LSTM(EMBEDDING_SIZE + 256, 1024, 2, batch_first = True, dropout = 0.1)
        self.drop = nn.Dropout(0.1)
        # Mel-frame projection and stop-token logit from [context, top-layer LSTM state].
        self.linear = nn.Linear(EMBEDDING_SIZE + 1024, 80)
        self.stop_linear = nn.Sequential(
            nn.Linear(EMBEDDING_SIZE + 1024, 1),
        )
        # Post-net: 5 Conv1d layers (NR_MELS -> 512 x 4 -> NR_MELS) with BatchNorm.
        self.post_net = nn.Sequential(
            nn.Conv1d(NR_MELS, 512, 5, padding = 2),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Dropout(),
            nn.Conv1d(512, 512, 5, padding = 2),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Dropout(),
            nn.Conv1d(512, 512, 5, padding = 2),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Dropout(),
            nn.Conv1d(512, 512, 5, padding = 2),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Dropout(),
            nn.Conv1d(512, NR_MELS, 5, padding = 2),
            nn.BatchNorm1d(NR_MELS),
            nn.Dropout()
        )


    def forward(self, context_vector, prev_frame, h_x):
        '''
        Run one decoder step.

        Args:
            context_vector: attention context for the current step (BATCH_SIZE, EMBEDDING_SIZE)
            prev_frame: pre-net output of the previous frame (teacher-forced during
                training), shaped (256, BATCH_SIZE); it is transposed here
            h_x: (h_n, c_n) from the previous step, each
                 (num_layers = 2, BATCH_SIZE, hidden_size = 1024)
        Returns:
            ((h_n, c_n) new LSTM state, with dropout applied to h_n,
             x_out (BATCH_SIZE, 80) mel frame, x_stop (BATCH_SIZE, 1) stop logit)
        '''
        prev_frame = prev_frame.transpose(1, 0)
        #assert prev_frame.shape == (BATCH_SIZE, 256), prev_frame.shape
        x = torch.concat([context_vector, prev_frame], dim = 1)
        #assert x.shape == (BATCH_SIZE, EMBEDDING_SIZE + 256), x.shape
        x = x.unsqueeze(1) # LSTM input shape: (batch_size, seq_len = 1, input_len)
        _, (h_n, c_n) = self.lstm(x, h_x)
        h_n = self.drop(h_n)
        #assert h_n.shape == (2, BATCH_SIZE, 1024), h_n.shape
        # h_n[1] is the top LSTM layer's hidden state.
        x = torch.concat([context_vector, h_n[1]], dim = 1)
        #assert x.shape == (BATCH_SIZE, EMBEDDING_SIZE + h_n.shape[2]), x.shape
        x_out = self.linear(x)
        x_stop = self.stop_linear(x)
        return (h_n, c_n), x_out, x_stop

Decoder theo kiểu của NVIDIA: context vector của attention được đưa vào cả hai tầng LSTM (tầng đầu và tầng sau). Lưu ý: ô này định nghĩa lại lớp `Decoder` và ghi đè phiên bản ở ô phía trên.

In [None]:
class Decoder(nn.Module):
    '''
    NVIDIA-style decoder: two stacked single-layer LSTMs with location-sensitive
    attention between them. (This redefines the earlier `Decoder` class.)

    One step: [previous context, pre-net frame] -> LSTM 1 -> attention ->
    [LSTM 1 output, new context] -> LSTM 2 -> [context, LSTM 2 output] ->
    mel frame + stop logit. Recurrent and attention state is stored on the
    instance and reset by init_states(); `att_nn` must be assigned an
    Attention module before use.
    '''
    def __init__(self, melspec_frame_shape):
        super().__init__()
        # Pre-net: 2 x (Linear -> ReLU -> Dropout), mel frame -> 256 features.
        self.pre_net = nn.Sequential(
            nn.Linear(melspec_frame_shape, 256),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout()
        )

        self.lstm_l1 = nn.LSTM(EMBEDDING_SIZE + 256, 1024, 1, batch_first = True)
        self.att_nn = None # the attention network, assigned by the owner
        self.lstm_l2 = nn.LSTM(EMBEDDING_SIZE + 1024, 1024, 1, batch_first = True)

        # Mel-frame projection and stop-token logit from [context, LSTM 2 state].
        self.linear = nn.Linear(EMBEDDING_SIZE + 1024, 80)
        self.stop_linear = nn.Sequential(
            nn.Linear(EMBEDDING_SIZE + 1024, 1),
        )
        # Post-net: 5 Conv1d layers (NR_MELS -> 512 x 4 -> NR_MELS) with BatchNorm.
        self.post_net = nn.Sequential(
            nn.Conv1d(NR_MELS, 512, 5, padding = 2),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Dropout(),
            nn.Conv1d(512, 512, 5, padding = 2),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Dropout(),
            nn.Conv1d(512, 512, 5, padding = 2),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Dropout(),
            nn.Conv1d(512, 512, 5, padding = 2),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Dropout(),
            nn.Conv1d(512, NR_MELS, 5, padding = 2),
            nn.BatchNorm1d(NR_MELS),
            nn.Dropout()
        )


    def get_mask_from_lengths(self, lengths, max_len = None):
        '''
        Bool mask (B, max_len) that is True where position < length.

        Built on `lengths.device`, so it works on CPU as well as CUDA.
        max_len defaults to max(lengths).
        '''
        if max_len is None:
            max_len = torch.max(lengths).item()
        ids = torch.arange(0, max_len, device = lengths.device)
        return ids < lengths.unsqueeze(1)


    def init_states(self, enc_out, trans_lens, batch_size, seq_len, train_mode = True):
        '''
        Reset recurrent and attention state for a new batch.

        Args:
            enc_out: encoder outputs (batch_size, seq_len, enc_dim).
            trans_lens: (batch_size,) transcript lengths used to mask padded
                encoder steps, or None to disable masking (e.g. inference).
            batch_size, seq_len: batch size and encoder length.
            train_mode: kept for backward compatibility. The initial states are
                constant zeros and need no gradient.
        '''
        # (num_layers = 1, batch, hidden_size) for both LSTMs, on enc_out's device/dtype.
        state_shape = (1, batch_size, 1024)
        self.h_n1 = enc_out.new_zeros(state_shape)
        self.c_n1 = enc_out.new_zeros(state_shape)
        self.h_n2 = enc_out.new_zeros(state_shape)
        self.c_n2 = enc_out.new_zeros(state_shape)

        self.prev_step_att = enc_out.new_zeros((batch_size, seq_len))
        self.prev_steps_att = enc_out.new_zeros((batch_size, seq_len))

        # The context is a weighted sum of encoder outputs, so it has their width.
        self.context_vec = enc_out.new_zeros((batch_size, enc_out.size(-1)))

        self.enc_out = enc_out
        # The encoder projection does not change across decoder steps: compute it once.
        self.att_nn.align_nn.proj1 = self.att_nn.align_nn.ll_enc_hidden(enc_out)
        # Always overwrite the mask; a mask left over from a training batch would
        # have the wrong shape/content at inference time.
        if trans_lens is not None:
            self.att_nn.mask = ~self.get_mask_from_lengths(trans_lens, seq_len)
        else:
            self.att_nn.mask = None

    def forward(self, prev_frame):
        '''
        Run one decoder step.

        Args:
            prev_frame: pre-net output of the previous mel frame, (256, B);
                teacher-forced during training. Transposed to (B, 256) here.
        Returns:
            (x_out (B, 80) mel frame before the post-net, x_stop (B, 1) stop logit)
        '''
        prev_frame = prev_frame.transpose(1, 0)

        # LSTM 1 on [previous context, previous frame]; seq_len = 1.
        x = torch.cat([self.context_vec, prev_frame], dim = -1).unsqueeze(1)
        _, (self.h_n1, self.c_n1) = self.lstm_l1(x, (self.h_n1, self.c_n1))
        # Hidden-state dropout only while training (it used to stay on in eval too).
        self.h_n1 = nn.functional.dropout(self.h_n1, 0.1, self.training)

        # Attention conditioned on LSTM 1 and on previous + cumulative weights.
        att = torch.cat((self.prev_step_att.unsqueeze(1), self.prev_steps_att.unsqueeze(1)), dim = 1)
        self.prev_step_att, self.context_vec = self.att_nn(self.enc_out, self.h_n1.squeeze(0).unsqueeze(1), att)
        # Out-of-place so no leaf tensor is modified in place.
        self.prev_steps_att = self.prev_steps_att + self.prev_step_att

        # LSTM 2 on [LSTM 1 output, new context].
        x = torch.cat([self.h_n1.squeeze(0), self.context_vec], dim = -1).unsqueeze(1)
        _, (self.h_n2, self.c_n2) = self.lstm_l2(x, (self.h_n2, self.c_n2))
        self.h_n2 = nn.functional.dropout(self.h_n2, 0.1, self.training)

        x = torch.cat([self.context_vec, self.h_n2.squeeze(0)], dim = 1)
        x_out = self.linear(x)
        x_stop = self.stop_linear(x)
        return x_out, x_stop

**Mô hình Tacotron2**

In [None]:
class Tacotron2(nn.Module):
    '''
    Tacotron 2 acoustic model: character Encoder -> attention Decoder -> post-net.

    The instance also owns its Adam optimizer plus the training, validation and
    checkpoint helpers used by the notebook. `self.train_mode` selects what
    forward() does: teacher-forced training (True) or autoregressive
    inference (False).
    '''
    def __init__(self, enc_bi_lstm_size = 256, dec_lstm_size = 1024, train_mode = True):
        '''
        Args:
            enc_bi_lstm_size: units per direction of the encoder BiLSTM; the
                attention sees encoder outputs of size 2 * enc_bi_lstm_size.
            dec_lstm_size: size of the decoder state used as the attention query.
            train_mode: initial value of self.train_mode.
        '''
        super().__init__()

        # Honour the argument (it used to be ignored and forced to True).
        self.train_mode = train_mode
        self.prev_train_loss = float('inf')
        self.prev_valid_loss = float('inf')

        self.enc = Encoder(VOCAB_SIZE, EMBEDDING_SIZE)
        self.dec = Decoder(NR_MELS)
        self.dec.att_nn = Attention(enc_bi_lstm_size * 2, dec_lstm_size)

        # Created after all sub-modules so every parameter is optimised.
        self.optimizer = torch.optim.Adam(self.parameters(), lr = LR, eps = EPS, weight_decay = WEIGHT_DECAY)


    def init_states(self, batch_size, seq_len):
        '''
        Reset the all-zero initial decoder input frame, (batch_size, NR_MELS).

        h_n / c_n / prev_step_att / prev_steps_att are not read by this class
        (the Decoder keeps its own state); they are still created so external
        code using these attributes keeps working.
        '''
        self.frame = torch.zeros((batch_size, NR_MELS), device = DEVICE)
        self.h_n = torch.zeros((2, batch_size, 1024), device = DEVICE) # (nr_layers, BATCH_SIZE, hidden_size)
        self.c_n = torch.zeros((2, batch_size, 1024), device = DEVICE) # (nr_layers, BATCH_SIZE, hidden_size)

        self.prev_step_att = torch.zeros((batch_size, seq_len), device = DEVICE)
        self.prev_steps_att = torch.zeros((batch_size, seq_len), device = DEVICE)


    def mask_outputs(self, melspecs, res_melspecs, stop_tokens, mel_lens):
        '''
        Loop-based variant of parse_outputs: in place, zero mel frames at or past
        each item's length and set the stop logits there to 1e3.
        '''
        mel_mask = torch.zeros(melspecs.shape, dtype = torch.bool, device = melspecs.device)
        stop_mask = torch.zeros(stop_tokens.shape, dtype = torch.bool, device = stop_tokens.device)
        for i, length in enumerate(mel_lens):
            mel_mask[i, :, length:] = True
            stop_mask[i, length:] = True

        melspecs.data.masked_fill_(mel_mask, 0.0)
        res_melspecs.data.masked_fill_(mel_mask, 0.0)
        stop_tokens.data.masked_fill_(stop_mask, 1e3)

        return melspecs, res_melspecs, stop_tokens

    def _forward(self, transcript, trans_lens, melspec, stop_tokens, mel_lens):
        '''
        Teacher-forced pass over a padded batch.

        Args:
            transcript: (B, T_text) character ids.
            trans_lens: (B,) transcript lengths used to mask attention, or None.
            melspec: (B, NR_MELS, T_mel) target mel spectrograms.
            stop_tokens, mel_lens: unused here (kept for the shared signature).
        Returns:
            (mel before post-net (B, NR_MELS, T_mel),
             mel after post-net (B, NR_MELS, T_mel),
             stop logits (B, T_mel))
        '''
        batch_size, text_len = transcript.shape[0], transcript.shape[-1]
        enc_out = self.enc(transcript)
        self.init_states(batch_size, text_len)
        self.dec.init_states(enc_out, trans_lens, batch_size, text_len, self.train_mode)

        # Pre-net all target frames at once: (B, NR_MELS, T) -> (B, T, 256) -> (T, 256, B).
        prenet_frames = self.dec.pre_net(melspec.permute(0, 2, 1)).permute(1, 2, 0)
        # Step 0 is fed the pre-net of the all-zero initial frame.
        self.frame = self.dec.pre_net(self.frame).transpose(1, 0)

        x_outs, x_stops = [], []
        for crt_frame in prenet_frames:
            x_out, x_stop = self.dec(self.frame)
            x_outs.append(x_out)
            x_stops.append(x_stop)
            # Teacher forcing: the next input is the ground-truth frame.
            self.frame = crt_frame

        x_out = torch.stack(x_outs, dim = 2)
        out = x_out + self.dec.post_net(x_out)  # post-net predicts a residual
        x_stops = torch.cat(x_stops, dim = 1)
        return x_out, out, x_stops

    def _inference(self, transcript):
        '''
        Autoregressive synthesis for a single utterance (the stop check uses
        .item(), so the batch size must be 1). Each predicted frame is fed back
        through the pre-net until sigmoid(stop logit) >= THRESHOLD or
        MAX_DEC_STEPS frames have been produced.

        Returns:
            numpy array (1, NR_MELS, n_steps) with the post-net mel spectrogram.
        '''
        self.eval()
        with torch.no_grad():
            enc_out = self.enc(transcript)
            batch_size, text_len = transcript.shape[0], transcript.shape[-1]
            self.init_states(batch_size, text_len)
            self.dec.init_states(enc_out, None, batch_size, text_len)
            # No padding at inference; drop any mask left over from a training batch.
            self.dec.att_nn.mask = None
            x_stop = torch.FloatTensor([-1.0])
            crt_steps = 0
            melspec = []
            frame = torch.zeros((batch_size, NR_MELS)).to(DEVICE)
            while torch.sigmoid(x_stop).item() < THRESHOLD and crt_steps < MAX_DEC_STEPS:
                frame = self.dec.pre_net(frame).transpose(1, 0)
                x_out, x_stop = self.dec(frame)
                melspec.append(x_out)
                frame = x_out
                crt_steps += 1
            melspec = torch.stack(melspec, dim = 2)
            melspec = melspec + self.dec.post_net(melspec)
        return melspec.cpu().numpy()

    def forward(self, transcript, trans_lens = None, melspec = None, stop_tokens = None, mel_lens = None):
        '''Dispatch to teacher-forced _forward (train_mode) or to _inference.'''
        if self.train_mode:
            return self._forward(transcript, trans_lens, melspec, stop_tokens, mel_lens)
        else:
            return self._inference(transcript)

    def criterion(self, y_pred, y):
        '''
        Sum of MSE(mel before post-net, target), MSE(mel after post-net, target)
        and BCE-with-logits(stop logits, stop targets).
        '''
        # taken from: https://github.com/NVIDIA/tacotron2
        x_out, out, x_stop = y_pred[0], y_pred[1], y_pred[2]
        melspec, stop_tokens = y
        loss = nn.MSELoss()(x_out, melspec) + \
               nn.MSELoss()(out, melspec) + \
               nn.BCEWithLogitsLoss()(x_stop, stop_tokens)
        return loss

    def get_mask_from_lengths(self, lengths, max_len = None):
        '''
        Bool mask (B, max_len) that is True where position < length.

        Built on `lengths.device`, so it works on CPU as well as CUDA.
        max_len defaults to max(lengths).
        '''
        # adapted from: https://github.com/NVIDIA/tacotron2
        if max_len is None:
            max_len = torch.max(lengths).item()
        ids = torch.arange(0, max_len, device = lengths.device)
        return ids < lengths.unsqueeze(1)

    def parse_outputs(self, melspecs, res_melspecs, stop_tokens, mel_lens):
        '''
        In place, zero predicted mel frames past each item's length and set the
        stop logits there to 1e3 (sigmoid ~ 1, matching the padded targets).
        '''
        # taken from: https://github.com/NVIDIA/tacotron2
        if mel_lens is not None:
            mask = ~self.get_mask_from_lengths(mel_lens)
            mask = mask.expand(NR_MELS, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            melspecs.data.masked_fill_(mask, 0.0)
            res_melspecs.data.masked_fill_(mask, 0.0)
            stop_tokens.data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies
        return (melspecs, res_melspecs, stop_tokens)

    def fit(self, train_loader, valid_loader):
        '''
        Train for EPOCHS epochs, saving a checkpoint after every epoch.

        Args:
            train_loader: batches from collate_pad.
            valid_loader: accepted for API compatibility; validation is
                currently not run here (call validate() explicitly).
        Returns:
            list of per-epoch mean training losses (also plotted at the end).
        '''
        train_losses = []
        self.train_mode = True
        self.train()
        for e in range(EPOCHS):
            total_loss = 0
            for (trans, melspec, stop_tokens, mel_lens, trans_lens) in train_loader:
                trans = trans.to(DEVICE)
                melspec = melspec.to(DEVICE)
                stop_tokens = stop_tokens.to(DEVICE)
                mel_lens = mel_lens.to(DEVICE)
                trans_lens = trans_lens.to(DEVICE)

                out = self(trans, trans_lens, melspec, stop_tokens, mel_lens)
                out = self.parse_outputs(out[0], out[1], out[2], mel_lens)
                loss = self.criterion(out, (melspec, stop_tokens))
                total_loss += loss.item()

                self.zero_grad(set_to_none = True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
                self.optimizer.step()
            total_loss /= len(train_loader)
            train_losses.append(total_loss)
            print("#{} TRAIN LOSS: {:.4f}".format(e, total_loss), end = "\n")
            self.save(total_loss, valid_loss = None)

        # One loss curve after training (plotting inside the loop stacked a new line per epoch).
        fig, ax = plt.subplots()
        ax.plot(train_losses)
        ax.set(title = 'Tacotron2 training loss', xlabel = 'epoch', ylabel = 'loss')
        return train_losses

    def validate(self, loader):
        '''
        Mean teacher-forced loss over `loader` (batches from collate_pad), with
        dropout/BatchNorm in eval mode. Restores train mode afterwards.
        '''
        was_train_mode = self.train_mode
        self.train_mode = True  # teacher-forced _forward, not free-running inference
        self.eval()
        total_loss = 0.0
        with torch.no_grad():
            for (trans, melspec, stop_tokens, mel_lens, trans_lens) in loader:
                trans = trans.to(DEVICE)
                melspec = melspec.to(DEVICE)
                stop_tokens = stop_tokens.to(DEVICE)
                mel_lens = mel_lens.to(DEVICE)
                trans_lens = trans_lens.to(DEVICE)
                out = self(trans, trans_lens, melspec, stop_tokens, mel_lens)
                out = self.parse_outputs(out[0], out[1], out[2], mel_lens)
                total_loss += self.criterion(out, (melspec, stop_tokens)).item()
        total_loss /= max(len(loader), 1)
        self.train_mode = was_train_mode
        self.train()
        print("EVAL LOSS: {:.4f}".format(total_loss))
        return total_loss

    def save(self, train_loss, valid_loss = None):
        '''Write model/optimizer state and the given losses to CHKP_PATH.'''
        torch.save({
            'train_loss': train_loss,
            'valid_loss': valid_loss,
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            }, CHKP_PATH)

    def load(self):
        '''Restore model/optimizer state and the stored losses from CHKP_PATH.'''
        checkpoint = torch.load(CHKP_PATH, map_location = DEVICE)
        self.prev_train_loss = checkpoint['train_loss']
        self.prev_valid_loss = checkpoint['valid_loss']
        print("Current checkpoint train loss: {:.4f}".format(checkpoint['train_loss']))
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.load_state_dict(checkpoint['model_state_dict'])

In [None]:
ds = LJSpeech("/content/drive/MyDrive/project1/LJSpeech-1.1/LJSpeech/ljspeech_gpu.hkl")


print("Dataset size:", len(ds))
print("Transcript shape:", ds[0][0].shape)
print("Melspectrogram shape:", ds[0][1].shape)
print("Stop token shape:", ds[0][2].shape)

# Log-mel spectrogram of the first utterance (mel bins on the y axis, low bins at the bottom).
fig, ax = plt.subplots(figsize = (8, 6), dpi = 100)
ax.imshow(ds[0][1].transpose(1, 0), origin = 'lower', aspect = 'auto')
ax.set(title = 'Log-mel spectrogram (sample 0)', xlabel = 'frame', ylabel = 'mel bin')

# 12000 train / 95 valid / remainder test (500 when all 12595 items are present).
n_train, n_valid = 12000, 95
train_ds, valid_ds, test_ds = torch.utils.data.random_split(
    ds, [n_train, n_valid, len(ds) - n_train - n_valid],
    generator = torch.Generator().manual_seed(42))

train_loader = DataLoader(train_ds, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_pad)
# Evaluation does not need shuffling.
valid_loader = DataLoader(valid_ds, batch_size = BATCH_SIZE, shuffle = False, collate_fn = collate_pad)

# One batch: (transcripts (B, T_text), mels (B, NR_MELS, T_mel), stop tokens (B, T_mel), mel_lens (B,), trans_lens (B,))
sample = next(iter(train_loader))

In [None]:
# Build the model, then move its parameters to DEVICE.
taco = Tacotron2()
taco = taco.to(DEVICE)

In [None]:
import os

# Resume from the last checkpoint when one exists; a fresh run starts from scratch.
if os.path.exists(CHKP_PATH):
    taco.load()
else:
    print("No checkpoint found at", CHKP_PATH, "- training from scratch")

In [None]:
# Train; fit() writes a checkpoint to CHKP_PATH after every epoch.
taco.fit(train_loader = train_loader, valid_loader = valid_loader);

**Tổng hợp audio từ mel spectrogram bằng HiFi-GAN**

In [None]:
import IPython.display as ipd

In [None]:
# Reference recording, read from the same Drive folder used for preprocessing.
ipd.Audio('/content/drive/MyDrive/project1/LJSpeech-1.1/LJSpeech/wavs/LJ001-0001.wav')

In [None]:
import os

# Synthesise one held-out utterance and save the ground-truth and predicted mels
# (each (1, 80, n_frames) before [0]) for the HiFi-GAN inference script.
test_loader = DataLoader(test_ds, batch_size = 1, shuffle = True, collate_fn = collate_pad)
sample = next(iter(test_loader))
transcript = sample[0].to(DEVICE)
orig_melspec = sample[1]
taco.eval()
taco.train_mode = False  # forward() -> autoregressive _inference
melspec = taco(transcript) # the input is the transcript
print(orig_melspec.shape)
print(melspec.shape)
os.makedirs("test_mel_files", exist_ok = True)
np.save("test_mel_files/original.npy", orig_melspec[0].numpy())
np.save("test_mel_files/taco2_output.npy", melspec[0])

In [None]:
# Vocode the saved mels with the HiFi-GAN checkpoint `generator_v3`.
# Assumes the HiFi-GAN repo's inference_e2e.py and the checkpoint are in the
# working directory, and that the script reads test_mel_files/ and writes to
# generated_files_from_mel/ (its defaults) — TODO confirm.
!python inference_e2e.py --checkpoint_file generator_v3

In [None]:
# Waveform vocoded from the ground-truth mel (presumably test_mel_files/original.npy).
ipd.Audio('generated_files_from_mel/original_generated_e2e.wav')

In [None]:
# Waveform vocoded from the Tacotron2 prediction (presumably test_mel_files/taco2_output.npy).
ipd.Audio('generated_files_from_mel/taco2_output_generated_e2e.wav')