# Tacotron

- Tacotron은 2017년 arXiv에 발표된 논문에서 제안된 end-to-end 음성 합성 모델로, (텍스트, 음성) 쌍으로 학습하여 텍스트로부터 spectrogram을 생성한다.
- 이번 구현에서는 단일 화자가 녹음한 총 24시간 분량의 영어 음성 데이터셋인 LJSpeech를 사용한다.
- 논문 링크: [Tacotron: Towards End-to-End Speech Synthesis](https://arxiv.org/abs/1703.10135)
- 블로그 링크: https://orca0917.github.io/posts/Tacotron/

In [None]:
# Install the required package: g2p_en provides the G2p grapheme-to-phoneme
# converter used by LJSpeechDataset (IPython/Jupyter magic, not plain Python).
%pip install g2p_en

# 1. 데이터셋 다운로드

- LJSpeech dataset: https://keithito.com/LJ-Speech-Dataset/
- 크기: 약 2.6GB
- 데이터셋 형식: `tar.bz2` 확장자로 압축, 내부 음원은 `.wav` 확장자

In [None]:
import requests
from tqdm.notebook import tqdm
import tarfile
import os

def _download(url, filepath, block_size=1024 * 1024):
    """Stream ``url`` into ``filepath`` while showing a tqdm progress bar."""
    with requests.get(url, stream=True, timeout=60) as response:
        # Fail loudly on 4xx/5xx instead of saving an error page as the archive.
        response.raise_for_status()
        # Use the server-reported size when present; None gives an open-ended bar.
        total = int(response.headers.get('content-length', 0)) or None
        with open(filepath, 'wb') as file, \
                tqdm(total=total, unit='iB', unit_scale=True, leave=True) as progress_bar:
            for data in response.iter_content(block_size):
                file.write(data)
                progress_bar.update(len(data))


def _extract_tar_bz2(filepath, extract_to):
    """Extract a ``.tar.bz2`` archive into ``extract_to``.

    Returns True when the archive was extracted, False when ``filepath`` is
    not a ``.tar.bz2`` file (nothing is extracted in that case).
    """
    if not filepath.endswith('.tar.bz2'):
        return False
    with tarfile.open(filepath, 'r:bz2') as tar:
        # On Python versions that provide extraction filters, use the 'data'
        # filter so archive members cannot escape extract_to or create
        # special files.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=extract_to, filter='data')
        else:
            tar.extractall(path=extract_to)
    return True


def download_and_extract(dataset_url, extract_to='.'):
    """Download the archive at ``dataset_url`` into ``extract_to`` and unpack it.

    The archive is saved as ``extract_to/<last URL path component>``. After a
    successful extraction the downloaded archive is deleted. If the archive
    is not a ``.tar.bz2`` file it is NOT extracted and is kept on disk.

    Args:
        dataset_url: URL of the archive to download.
        extract_to: directory to download into and extract into (created if missing).

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    filename = dataset_url.split('/')[-1]
    filepath = os.path.join(extract_to, filename)
    os.makedirs(extract_to, exist_ok=True)

    # Download the dataset
    print(f"다운로드 중 {filename}...")
    _download(dataset_url, filepath)

    # Extract the archive
    print(f"압축 해제 중 {filename}...")
    if not _extract_tar_bz2(filepath, extract_to):
        print("알 수 없는 압축 형식입니다.. 오직 .tar.bz2 파일만 압축 해제할 수 있습니다.")
        # Nothing was extracted, so keep the download instead of deleting the only copy.
        return

    # Remove the downloaded archive only after it has been extracted.
    os.remove(filepath)
    print(f"{filename} 는 압축해제가 성공적으로 완료되었으며 다운로드 된 파일은 삭제하였습니다.")

# Download location of the LJSpeech 1.1 archive.
dataset_url = 'https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2'

# Directory that receives both the archive and its extracted contents.
extract_to = '/content/'

# Fetch the archive and unpack it into extract_to.
download_and_extract(dataset_url, extract_to=extract_to)


# 2. 커스텀 데이터셋 정의

- 커스텀 데이터셋은 (텍스트 음소 시퀀스, 정답 선형 스펙트로그램, 정답 멜 스펙트로그램) 쌍을 반환한다.
- 샘플마다 텍스트 길이와 음원 길이가 모두 다르므로, 배치 내 가장 긴 샘플에 맞춰 패딩하는 `collate_fn` 함수를 정의한다.

In [None]:
import os
import torch
import string
import librosa
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

from g2p_en import G2p
# from google.colab import drive
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
# drive.mount('/content/drive')

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'현재 사용중인 장치 = {device}')

In [None]:
# Root directory of the extracted LJSpeech-1.1 archive. The download cell
# extracts the archive into '/content/', so the dataset paths below must point
# at the same place; change both together if the data lives elsewhere.
ljspeech_root = '/content/LJSpeech-1.1'

args = {
    # Dataset / feature-extraction parameters
    'batch_size' : 16,
    'n_mels': 80,                   # number of mel filterbank channels
    'lin_dim': 1025,                # linear-spectrogram bins; must equal n_fft // 2 + 1
    'wav_path': f'{ljspeech_root}/wavs/',
    'metadata_path': f'{ljspeech_root}/metadata.csv',
    'pre_emphasis': 0.97,           # pre-emphasis filter coefficient
    'sr': 24000,                    # audio is resampled to this rate when loaded
    'frame_length': 0.050,          # STFT window length in seconds (win_length = sr * frame_length)
    'frame_shift': 0.0125,          # STFT hop in seconds (hop_length = sr * frame_shift)
    'n_fft': 2048,

    # Encoder parameters
    # Vocabulary = punctuation + g2p_en phonemes + space, plus id 0 reserved for padding.
    'n_vocab' : len(string.punctuation) + len(G2p().phonemes) + 2,
    'char_embedding_dim' : 256,
    'prenet_dims' : [256, 128],     # currently not read by the model code
    'enc_n_kernel' : 16,            # conv-bank kernel sizes 1..16
    'enc_in_dim' : 128,             # must equal the pre-net output width
    'enc_cbhg_projection_dims' : [128, 128],

    # Decoder parameters
    'r' : 1,                        # mel frames predicted per decoder step
    'dec_hidden_dim' : 256,
    'dec_n_kernel' : 8,
    'dec_cbhg_projection_dims' : [256, 80]  # last entry must equal n_mels (CBHG residual)
}

In [None]:
class LJSpeechDataset(Dataset):
    """
    LJSpeech dataset yielding Tacotron training triples.

    Each item is ``(sequence, spec, melspec)``:
      - sequence: 1-D int64 array of symbol ids (g2p_en phonemes, punctuation
        and space). Ids start at 1 so that 0 is free for padding.
      - spec: linear magnitude spectrogram in dB, shape (n_fft // 2 + 1, frames).
      - melspec: mel magnitude spectrogram in dB, shape (n_mels, frames).

    Both spectrograms are converted with ``amplitude_to_db(..., ref=np.max)``
    and librosa's default ``top_db=80``, so their values lie in [-80, 0] dB.

    Args:
        args: hyper-parameter dict; reads 'metadata_path', 'wav_path',
            'pre_emphasis', 'frame_length', 'frame_shift', 'n_mels', 'n_fft'
            and 'sr'. The wav files are looked up as ``wav_path/<id>.wav``.
    """
    def __init__(self, args):
        super().__init__()
        self.metadata_path  = args['metadata_path']
        self.pre_emphasis   = args['pre_emphasis']
        self.frame_length   = args['frame_length']
        self.frame_shift    = args['frame_shift']
        self.wav_path       = args['wav_path']
        self.n_mels         = args['n_mels']
        self.n_fft          = args['n_fft']
        self.sr             = args['sr']
        self.g2p            = G2p()

        # Symbol inventory: punctuation + g2p_en phonemes + the space character.
        self.symbols = list(string.punctuation) + self.g2p.phonemes + [' ']
        # Ids start at 1: id 0 is the padding value used by collate_fn.
        self.symbol_to_id = {s: i + 1 for i, s in enumerate(self.symbols)}

        # Column 0 holds the wav file id and column 2 the transcript used here.
        # Transcripts contain double-quote characters, so CSV quoting is
        # disabled (3 == csv.QUOTE_NONE). Otherwise pandas may treat them as
        # quoted fields and merge or mis-split rows.
        metadata_df         = pd.read_csv(self.metadata_path, delimiter='|', header=None, quoting=3)
        self.wav_file_names = metadata_df[0].values
        self.text_script    = metadata_df[2].values

    def __len__(self):
        return len(self.wav_file_names)

    def __getitem__(self, index):
        sequence        = self.get_sequence(self.text_script[index])
        spec, melspec   = self.get_spectrograms(self.wav_file_names[index])
        return sequence, spec, melspec

    def get_sequence(self, text):
        """Convert ``text`` to phonemes with g2p_en and map them to integer ids.

        Tokens that are not in the symbol inventory are dropped.
        """
        ids = [self.symbol_to_id[p] for p in self.g2p(text) if p in self.symbol_to_id]
        return np.array(ids, dtype=np.int64)

    def get_spectrograms(self, wav_file_name):
        """Return ``(linear dB spectrogram, mel dB spectrogram)`` for one utterance."""
        y, sr = librosa.load(os.path.join(self.wav_path, wav_file_name + '.wav'), sr=self.sr)
        # Pre-emphasis: y'[t] = y[t] - pre_emphasis * y[t-1]; the first sample is kept.
        y_preemphasized = np.append(y[0], y[1:] - self.pre_emphasis * y[:-1])

        # STFT window and hop sizes in samples.
        hop_length = int(sr * self.frame_shift)
        win_length = int(sr * self.frame_length)

        D = librosa.stft(y=y_preemphasized,
                         n_fft=self.n_fft,
                         hop_length=hop_length,
                         win_length=win_length,
                         window='hann')
        magnitude = np.abs(D)
        S = librosa.amplitude_to_db(magnitude, ref=np.max)

        # Project the magnitude (not the complex STFT) onto the mel filterbank.
        # Summing complex values across frequency bins lets their phases cancel,
        # which corrupts the mel energies.
        mel_S = librosa.feature.melspectrogram(S=magnitude,
                                               sr=sr,
                                               n_fft=self.n_fft,
                                               hop_length=hop_length,
                                               win_length=win_length,
                                               n_mels=self.n_mels)
        mel_S = librosa.amplitude_to_db(mel_S, ref=np.max)

        return S, mel_S


def _pad_frames(arrays, pad_value):
    """Right-pad each (freq, frames) array along the frame axis to the longest one; stack as float32."""
    max_len = max(a.shape[1] for a in arrays)
    return np.stack([
        np.pad(a, [(0, 0), (0, max_len - a.shape[1])], mode='constant', constant_values=pad_value)
        for a in arrays
    ]).astype(np.float32)


# Utterances in a batch differ in length, so pad everything to the longest sample.
def collate_fn(batch, spec_pad_value=-80.0):
    """
    Collate ``(sequence, spec, melspec)`` items into padded batch tensors.

    Args:
        batch: list of ``(sequence, spec, melspec)`` where ``sequence`` is a 1-D
            array of symbol ids and ``spec`` / ``melspec`` are ``(freq, frames)``
            arrays.
        spec_pad_value: value used for padded spectrogram frames. LJSpeechDataset
            returns dB values normalised with ``ref=np.max`` (range [-80, 0]).
            The default -80.0 therefore pads with silence; 0 would mark padded
            frames as the loudest possible ones.

    Returns:
        pad_seq: LongTensor (B, max_len), padded with id 0.
        pad_spec: FloatTensor (B, max_frames, freq) linear spectrograms.
        pad_melspec: FloatTensor (B, max_frames, n_mels) mel spectrograms.
    """
    sequences, specs, melspecs = zip(*batch)

    # Text: pad with 0, the id reserved for padding.
    max_seq_len = max(len(seq) for seq in sequences)
    pad_seq = np.stack([
        np.pad(np.asarray(seq, dtype=np.int64), (0, max_seq_len - len(seq)),
               mode='constant', constant_values=0)
        for seq in sequences
    ])

    # Spectrograms: pad the frame axis with silence.
    pad_spec = _pad_frames(specs, spec_pad_value)
    pad_melspec = _pad_frames(melspecs, spec_pad_value)

    # (B, freq, frames) -> (B, frames, freq)
    pad_seq     = torch.from_numpy(pad_seq)
    pad_spec    = torch.from_numpy(pad_spec).transpose(1, 2)
    pad_melspec = torch.from_numpy(pad_melspec).transpose(1, 2)

    return pad_seq, pad_spec, pad_melspec


dataset = LJSpeechDataset(args)
# Shuffle so each training batch mixes utterances from across the corpus
# instead of consecutive rows of metadata.csv.
dataloader = DataLoader(dataset,
                        batch_size=args['batch_size'],
                        shuffle=True,
                        num_workers=2,
                        collate_fn=collate_fn,)

# 3. Tacotron 모델 구현
- Encoder
- Decoder
- Griffin-Lim 보코더로 구성 (보코더는 이 노트북에 아직 구현되어 있지 않음)

![](https://orca0917.github.io/assets/img/tacotron/tacotron-figure1.png)

## 3.1. Encoder 구현

- Text Embedding
- PreNet
- CBHG (1-D Convolution Bank + Highway Network + Bidirectional GRU)

In [None]:
class HighwayNet(nn.Module):
    """
    Single highway layer: ``out = (1 - T(x)) * x + T(x) * relu(H(x))``.

    ``H`` (candidate transform) and ``T`` (sigmoid transform gate) are both
    ``dim -> dim`` linear layers over the last axis, so the output has the
    same shape as the input.
    """
    def __init__(self, dim):
        super().__init__()
        self.H = nn.Linear(dim, dim)   # candidate transform
        self.T = nn.Linear(dim, dim)   # transform gate
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.T(x))
        candidate = self.relu(self.H(x))
        carry = 1 - gate
        return carry * x + gate * candidate


class BNConv1D(nn.Module):
    """
    ``nn.Conv1d`` followed by ``nn.BatchNorm1d`` (no activation).

    Operates on ``(B, channels, length)`` tensors; the constructor arguments
    are forwarded to ``nn.Conv1d``.
    """
    def __init__(self, in_dim, out_dim, kernel_size, stride, padding):
        super().__init__()
        self.conv1d = nn.Conv1d(in_dim, out_dim, kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm1d(out_dim)

    def forward(self, x):
        return self.bn(self.conv1d(x))


class CBHG(nn.Module):
    """
    CBHG module: 1-D Convolution Bank + Highway network + bidirectional GRU.

    Pipeline on a ``(B, L, in_dim)`` input:
      1. K convolutions with kernel sizes 1..K (each Conv1d + BatchNorm + ReLU),
         concatenated over channels -> ``(B, in_dim * K, L)``;
      2. max pooling (kernel 2, stride 1) trimmed back to length L;
      3. projection convolutions (ReLU after all but the last);
      4. a bias-free linear layer plus a residual connection to the input;
      5. four highway layers and a bidirectional GRU -> ``(B, L, 2 * in_dim)``.

    Args:
        in_dim: input feature size; also the residual / highway / GRU width,
            so ``proj_dims[-1]`` must equal ``in_dim``.
        K: number of convolutions in the bank.
        proj_dims: output channels of each projection convolution.

    Raises:
        ValueError: if ``proj_dims[-1] != in_dim``.
    """
    def __init__(self, in_dim, K, proj_dims):
        super(CBHG, self).__init__()
        proj_dims = list(proj_dims)
        if proj_dims[-1] != in_dim:
            raise ValueError(
                f'proj_dims[-1] ({proj_dims[-1]}) must equal in_dim ({in_dim}) '
                'for the residual connection')

        # Kernel size k + 1 with padding k // 2: odd kernels keep length L,
        # even kernels give L - 1 (padded back to L in forward).
        self.conv_bank = nn.ModuleList([
            BNConv1D(in_dim, in_dim, k + 1, 1, k // 2) for k in range(K)])
        self.max_pool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        p_ins  = [in_dim * K] + proj_dims[:-1]
        p_outs = proj_dims

        self.projections = nn.ModuleList([
            BNConv1D(p_in, p_out, 3, 1, 1) for p_in, p_out in zip(p_ins, p_outs)])

        # ReLU after every projection except the last one, which stays linear.
        self.activations = nn.ModuleList([
            nn.ReLU() for _ in range(len(p_outs) - 1)] + [None])

        self.proj_linear  = nn.Linear(p_outs[-1], p_outs[-1], bias=False)
        self.highway      = nn.ModuleList([HighwayNet(in_dim) for _ in range(4)])
        self.bi_gru       = nn.GRU(in_dim, in_dim, 1, batch_first=True, bidirectional=True)

    def forward(self, x):
        residual = x                            # (B, L, in_dim)
        x = x.transpose(1, 2)                   # (B, in_dim, L)
        L = x.shape[-1]

        conv_result = []
        for conv in self.conv_bank:
            # Without this non-linearity the bank + projections would be purely
            # linear apart from max pooling.
            y = F.relu(conv(x))
            y = F.pad(y, (0, L - y.shape[-1]))  # restore length L for even kernels
            conv_result.append(y)

        x = torch.cat(conv_result, dim=1)       # (B, in_dim * K, L)
        x = self.max_pool(x)[:, :, :L]          # (B, in_dim * K, L)

        for proj, act in zip(self.projections, self.activations):
            x = proj(x)
            if act is not None:
                x = act(x)
        x = x.transpose(1, 2)                   # (B, L, in_dim)
        x = self.proj_linear(x)                 # (B, L, in_dim)

        # residual connection
        x = x + residual                        # (B, L, in_dim)

        for highway in self.highway:
            x = highway(x)                      # (B, L, in_dim)

        x, _ = self.bi_gru(x)                   # (B, L, 2 * in_dim)
        return x


class PreNet(nn.Module):
    """
    Pre-net: a stack of ``Linear -> ReLU -> Dropout`` layers.

    Args:
        in_dim: size of the last input dimension.
        hidden_dims: output width of each layer (default ``(256, 128)``).
        dropout: dropout probability after every layer (default 0.5).

    Accepts inputs of shape ``(..., in_dim)`` and returns ``(..., hidden_dims[-1])``.

    Raises:
        ValueError: if ``hidden_dims`` is empty.
    """
    def __init__(self, in_dim, hidden_dims=(256, 128), dropout=0.5):
        super().__init__()
        hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError('hidden_dims must contain at least one layer width')

        layers = []
        for d_in, d_out in zip([in_dim] + hidden_dims[:-1], hidden_dims):
            layers.extend([nn.Linear(d_in, d_out), nn.ReLU(), nn.Dropout(dropout)])
        # A single Sequential keeps the parameter names 'layers.0.*', 'layers.3.*', ...
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class Encoder(nn.Module):
    """
    Tacotron encoder: symbol embedding -> pre-net -> CBHG.

    Reads 'n_vocab', 'char_embedding_dim', 'enc_in_dim', 'enc_n_kernel' and
    'enc_cbhg_projection_dims' from ``args``. 'enc_in_dim' must match the
    pre-net output width that CBHG consumes.
    """
    def __init__(self, args):
        super().__init__()
        emb_dim = args['char_embedding_dim']

        self.char_embedding = nn.Embedding(args['n_vocab'], emb_dim)
        self.enc_prenet = PreNet(in_dim=emb_dim)
        self.cbhg = CBHG(in_dim=args['enc_in_dim'],
                         K=args['enc_n_kernel'],
                         proj_dims=args['enc_cbhg_projection_dims'])

    def forward(self, x: torch.LongTensor):
        """Map (B, L) symbol ids to (B, L, 2 * enc_in_dim) encoder states."""
        embedded = self.char_embedding(x)     # (B, L, char_embedding_dim)
        condensed = self.enc_prenet(embedded) # (B, L, enc_in_dim)
        return self.cbhg(condensed)

## 3.2. Decoder 구현

In [None]:
class Attention(nn.Module):
    """
    Additive attention: ``softmax_j( v^T tanh(W_enc h_j + W_dec s) )``.

    ``encoder_outputs`` is ``(B, L_enc, 256)`` and the query ``gru_output`` is
    ``(B, 256)``; the result is ``(B, L_enc)`` weights summing to 1 over L_enc.
    """
    def __init__(self):
        super().__init__()
        self.enc_W  = nn.Linear(256, 256)
        self.dec_W  = nn.Linear(256, 256)
        self.v      = nn.Linear(256, 1, bias=False)
        self.tanh   = nn.Tanh()

    def forward(self, encoder_outputs, gru_output):
        keys = self.enc_W(encoder_outputs)               # (B, L_enc, 256)
        query = self.dec_W(gru_output).unsqueeze(1)      # (B, 1, 256), broadcast over L_enc
        energies = self.v(self.tanh(keys + query))       # (B, L_enc, 1)
        return F.softmax(energies.squeeze(-1), dim=1)    # (B, L_enc)


class AttentionRNN(nn.Module):
    """
    Decoder attention RNN: a GRU cell whose new state queries the encoder
    outputs through ``Attention`` to produce the next context vector.
    """
    def __init__(self):
        super().__init__()
        # Input = pre-net output (128) concatenated with the previous context (256).
        self.gru_cell = nn.GRUCell(128 + 256, 256)
        self.attention = Attention()

    def forward(self, rnn_input, cell_hidden, attn_hidden, encoder_outputs):
        """
        Args:
            rnn_input: (B, 128) decoder pre-net output.
            cell_hidden: (B, 256) previous GRU state.
            attn_hidden: (B, 256) previous attention context.
            encoder_outputs: (B, L_enc, 256).

        Returns:
            ``(new GRU state (B, 256), new context (B, 256), weights (B, L_enc))``.
        """
        gru_in = torch.cat((rnn_input, attn_hidden), dim=-1)           # (B, 384)
        new_state = self.gru_cell(gru_in, cell_hidden)                 # (B, 256)
        weights = self.attention(encoder_outputs, new_state)           # (B, L_enc)

        # Context = attention-weighted sum of encoder outputs. The paper does
        # not spell this step out; it follows r9y9's implementation.
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)  # (B, 256)

        return new_state, context, weights


class DecoderRNN(nn.Module):
    """
    Residual stack of two GRU cells used by the Tacotron decoder.

    The attention-RNN state and the attention context are concatenated,
    projected to 256, and passed through two GRU cells with residual
    connections.
    """
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(512, 256, bias=False)
        self.grus = nn.ModuleList([
            nn.GRUCell(256, 256) for _ in range(2)
        ])

    def forward(self, cell_hidden, attn_hidden, dec_rnn_hiddens):
        """
        Args:
            cell_hidden: (B, 256) attention-RNN state.
            attn_hidden: (B, 256) attention context.
            dec_rnn_hiddens: mutable sequence (e.g. a list) of 2 tensors (B, 256)
                holding each GRU's previous state. It is UPDATED IN PLACE with
                the new states, so the caller's next step continues from them.

        Returns:
            (B, 256) output of the residual GRU stack.
        """
        decoder_in = torch.cat((cell_hidden, attn_hidden), dim=-1)  # (B, 512)
        decoder_in = self.linear(decoder_in)                        # (B, 256)
        for idx, gru in enumerate(self.grus):
            # Write the new state back; rebinding a loop variable would drop it
            # and every decoder step would restart from the initial state.
            dec_rnn_hiddens[idx] = gru(decoder_in, dec_rnn_hiddens[idx])  # (B, 256)
            decoder_in = dec_rnn_hiddens[idx] + decoder_in                # residual, (B, 256)

        return decoder_in                                           # (B, 256)


class Decoder(nn.Module):
    """
    Tacotron decoder: autoregressive attention decoder + post-processing CBHG.

    At every step the previous prediction (an all-zero <go> frame at step 0)
    goes through the pre-net, the attention RNN and the residual decoder GRUs,
    and is projected to ``r`` mel frames. The decoder always feeds back its
    own prediction (no teacher forcing). The stacked mel frames are then
    mapped to a linear spectrogram by a CBHG and a linear layer.

    Args:
        config: hyper-parameter dict providing 'r', 'n_mels', 'lin_dim',
            'dec_hidden_dim', 'dec_n_kernel' and 'dec_cbhg_projection_dims'.
            When omitted, the notebook-level ``args`` dict is used.

    Note: AttentionRNN and DecoderRNN use fixed 256-wide layers, so
    'dec_hidden_dim' must be 256.
    """
    def __init__(self, config=None):
        super(Decoder, self).__init__()
        cfg = args if config is None else config
        self.r              = cfg['r']
        self.mel_dim        = cfg['n_mels']
        self.linear_dim     = cfg['lin_dim']
        self.dec_hidden_dim = cfg['dec_hidden_dim']
        self.dec_prenet     = PreNet(in_dim=self.mel_dim * self.r)
        self.attention_rnn  = AttentionRNN()
        self.decoder_rnn    = DecoderRNN()
        self.projection     = nn.Linear(self.dec_hidden_dim, self.mel_dim * self.r)
        self.dec_cbhg       = CBHG(in_dim=self.mel_dim,
                                   K=cfg['dec_n_kernel'],
                                   proj_dims=cfg['dec_cbhg_projection_dims'])
        self.to_linear      = nn.Linear(self.mel_dim * 2, self.linear_dim)

    def forward(self, enc_out, T):
        """
        Args:
            enc_out: encoder outputs, (B, L_enc, 256).
            T: number of decoder steps (each emits ``r`` mel frames); must be >= 1.

        Returns:
            alignments: (B, T, L_enc) attention weights of every step.
            mel_frames: (B, T * r, n_mels) predicted mel spectrogram.
            spectrogram: (B, T * r, lin_dim) predicted linear spectrogram.

        Raises:
            ValueError: if ``T < 1``.
        """
        if T < 1:
            raise ValueError(f'T must be >= 1, got {T}')
        B = enc_out.shape[0]

        # Initial states live on enc_out's device/dtype, so the decoder works
        # wherever the model is placed (no dependency on a global `device`).
        prev_frame      = enc_out.new_zeros(B, self.mel_dim * self.r)   # <go> frame
        cell_hidden     = enc_out.new_zeros(B, self.dec_hidden_dim)
        attn_hidden     = enc_out.new_zeros(B, self.dec_hidden_dim)     # attention context
        dec_rnn_hiddens = [enc_out.new_zeros(B, self.dec_hidden_dim) for _ in range(2)]

        alignments, mel_frames = [], []
        for _ in range(T):
            prenet_out = self.dec_prenet(prev_frame)                                     # (B, 128)
            cell_hidden, attn_hidden, alignment = self.attention_rnn(
                prenet_out, cell_hidden, attn_hidden, enc_out)                           # alignment: (B, L_enc)
            dec_out = self.decoder_rnn(cell_hidden, attn_hidden, dec_rnn_hiddens)        # (B, 256)
            mel_frame = self.projection(dec_out)                                         # (B, n_mels * r)

            alignments.append(alignment)
            mel_frames.append(mel_frame)
            prev_frame = mel_frame          # feed the prediction back in at the next step

        alignments = torch.stack(alignments, dim=1)                  # (B, T, L_enc)
        mel_frames = torch.stack(mel_frames, dim=1)                  # (B, T, n_mels * r)
        mel_frames = mel_frames.view(B, -1, self.mel_dim)            # (B, T * r, n_mels)

        spectrogram = self.dec_cbhg(mel_frames)                      # (B, T * r, 2 * n_mels)
        spectrogram = self.to_linear(spectrogram)                    # (B, T * r, lin_dim)
        return alignments, mel_frames, spectrogram

class Tacotron(nn.Module):
    """
    Full Tacotron model: ``Encoder`` followed by ``Decoder``.

    ``forward(x, T)`` takes (B, L) symbol ids and the number of decoder steps
    ``T``, and returns the decoder's ``(alignments, mel spectrogram, linear
    spectrogram)`` tuple.
    """
    def __init__(self, args):
        super().__init__()
        self.encoder = Encoder(args)
        self.decoder = Decoder()

    def forward(self, x, T):
        encoder_states = self.encoder(x)      # (B, L, 256)
        return self.decoder(encoder_states, T)

# 모델 학습


In [None]:
import math

model = Tacotron(args)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

model = model.to(device)
loss_history = []  # mean training loss of each epoch

# Number of mel frames the decoder emits per step.
reduction_factor = args['r']

for epoch in range(1):
    model.train()
    global_loss = []

    for x, y_lin, y_mel in tqdm(dataloader, total=len(dataloader)):
        # x: (B, L) symbol ids; y_lin: (B, frames, lin_dim); y_mel: (B, frames, n_mels)
        x       = x.to(device)
        y_lin   = y_lin.to(device)
        y_mel   = y_mel.to(device)

        # Each decoder step produces `reduction_factor` frames. Run enough steps
        # to cover every target frame, then drop the surplus when the frame count
        # is not a multiple of r (a no-op for r == 1).
        n_frames = y_mel.shape[1]
        n_steps = math.ceil(n_frames / reduction_factor)
        alignment, pred_mel, pred_lin = model(x, n_steps)
        pred_mel = pred_mel[:, :n_frames]
        pred_lin = pred_lin[:, :n_frames]

        mel_loss = criterion(pred_mel, y_mel)
        lin_loss = criterion(pred_lin, y_lin)
        loss = mel_loss + lin_loss
        global_loss.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_history.append(np.mean(global_loss))

# Loss 값 시각화

In [None]:
import matplotlib.pyplot as plt

# loss_history holds one mean loss per epoch. Markers keep a single-epoch run
# visible; a lone point drawn as a line renders nothing.
epochs = range(1, len(loss_history) + 1)
plt.plot(epochs, loss_history, marker='o')
plt.title('Loss convergence', fontsize=20, fontweight='semibold')
plt.xlabel('epoch')
plt.ylabel('mean L1 loss (mel + linear)')
plt.show()