# Implementation of Tacotron2

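- The goal is to implement Tacotron2 at an undergraduate level **without copying a GitHub implementation verbatim**, writing as much of it by hand as possible.
    1. Sketch the overall model structure by hand, referring only to the basic modules of the GitHub implementations.
    2. Reimplement preprocessing along the same lines as the earlier Tacotron1 implementation.
    3. When a GitHub reference is needed, use the following repositories.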
        - **1. chldkato**: `https://github.com/chldkato/Tacotron-pytorch`
        - **2. NVIDIA**: `https://github.com/NVIDIA/tacotron2`
        - **3. hccho2**: `https://github.com/hccho2/Tacotron2-Wavenet-Korean-TTS`

- First, the NVIDIA model structure is shown below.
<img src="NVIDIA Model Description.png" width=70% height=70%>

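# Remaining Tasks

- This section records notes that come up during implementation; **it will be deleted once the implementation is complete.**

[Required]  
1. **Inference** is not implemented yet. The inference path must be implemented for every module before training.
2. The targets needed to compute the **loss** have not been built yet. This must be implemented before training.

[Optional]  
1. **Xavier initialization** is not implemented yet. Add it when finalizing the implementation.
2. **Hyperparameters should be soft-coded.** Hard-coding magic numbers (using literal numbers instead of named hyperparameter variables) makes the code harder to maintain.
3. The **SubPixelConvolution** layer is not implemented. If the transposed convolution layer does not work well when upsampling local_condition in WaveNet, implement it as a replacement.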
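
For the Xavier initialization item, one possible approach is sketched below. The helper name `init_xavier` and the toy two-layer module are hypothetical and only stand in for the real Tacotron2 sub-modules; the choice of gain per layer is an assumption that would need to be revisited for each activation.

```python
import torch
from torch import nn

def init_xavier(module: nn.Module, gain: float = 1.0) -> None:
    """Apply Xavier-uniform init to Linear/Conv1d weights and zero their biases."""
    for m in module.modules():
        if isinstance(m, (nn.Linear, nn.Conv1d)):
            nn.init.xavier_uniform_(m.weight, gain=gain)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

# Toy stand-in for a Post-Net-like stack (hypothetical, for illustration only)
toy = nn.Sequential(
    nn.Conv1d(80, 512, kernel_size=5, padding=2),
    nn.Tanh(),
    nn.Conv1d(512, 80, kernel_size=5, padding=2),
)
init_xavier(toy, gain=nn.init.calculate_gain('tanh'))
```

Calling `init_xavier(model)` once after constructing the model would cover every `nn.Linear` and `nn.Conv1d` it contains; LSTM weights are left at their defaults in this sketch.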

# 1. Preprocessing
- The Tacotron 1 implementation is reused here.
- chldkato Github: `https://github.com/chldkato/Tacotron-pytorch`

### Import

In [1]:
import os, re, glob, time, traceback

import numpy as np
import pandas as pd
import librosa
import scipy.signal  # scipy.signal.lfilter is used for pre-emphasis
from tqdm import tqdm

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

## 1.0. Hyperparameters

In [2]:
# 1.2. Audio Preprocessing
sample_rate = 22050
preemphasis = 0.97
n_fft = 1024
hop_length = 256
win_length = 1024
ref_db = 20
max_db = 100
mel_dim = 80

# Major Hyperparameters
batch_size = 64
checkpoint_step = 100

# 2.1. Encoder
symbol_length = 70 # len(symbols) = 70 (PAD + EOS + VALID_CHARS)
embedding_dim = 512

# 2.2. Decoder

# 2.3. WaveNet
dilations = [1, 2, 4, 8, 16, 32]*2
upsampling_factors = [16, 16] # The product must equal hop_length: 16 x 16 = 256
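
The constraint on `upsampling_factors` noted in the comment can be checked programmatically rather than by eye. The sketch below is a minimal guard; the function name `check_upsampling` is hypothetical and not part of the notebook.

```python
import math

hop_length = 256
upsampling_factors = [16, 16]

def check_upsampling(factors, hop):
    """Return the total upsampling ratio, raising if it does not match hop."""
    total = math.prod(factors)
    if total != hop:
        raise ValueError(f"prod({factors}) = {total}, expected hop_length = {hop}")
    return total

total_upsampling = check_upsampling(upsampling_factors, hop_length)
```

Running this right after the hyperparameter cell would catch a mismatch before the WaveNet upsampling layers are built.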

## 1.1. Text Preprocessing

In [3]:
!pip install jamo



In [4]:
from jamo import hangul_to_jamo

PAD = '_'
EOS = '~'
SPACE = ' '

JAMO_LEADS = "".join([chr(_) for _ in range(0x1100, 0x1113)])
JAMO_VOWELS = "".join([chr(_) for _ in range(0x1161, 0x1176)])
JAMO_TAILS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)])

VALID_CHARS = JAMO_LEADS + JAMO_VOWELS + JAMO_TAILS + SPACE
symbols = PAD + EOS + VALID_CHARS

_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

# Split text into initial, medial, and final jamo and return the corresponding ids
def text_to_sequence(text):
    sequence = []
    if not 0x1100 <= ord(text[0]) <= 0x1113:
        text = ''.join(list(hangul_to_jamo(text)))
    for s in text:
        sequence.append(_symbol_to_id[s])
    sequence.append(_symbol_to_id['~'])
    return sequence

def sequence_to_text(sequence):
    result = ''
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            result += s
    return result.replace('}{', ' ')
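
To make the symbol table concrete, the sketch below decomposes precomposed Hangul syllables with plain Unicode arithmetic instead of the `jamo` package, then maps the result to ids with the same symbol layout as above. `decompose_syllable` and `text_to_ids` are hypothetical names used only for this illustration, and the sketch assumes the input contains only Hangul syllables and spaces.

```python
PAD, EOS, SPACE = '_', '~', ' '
JAMO_LEADS = "".join(chr(c) for c in range(0x1100, 0x1113))   # 19 initial consonants
JAMO_VOWELS = "".join(chr(c) for c in range(0x1161, 0x1176))  # 21 vowels
JAMO_TAILS = "".join(chr(c) for c in range(0x11A8, 0x11C3))   # 27 final consonants
symbols = PAD + EOS + JAMO_LEADS + JAMO_VOWELS + JAMO_TAILS + SPACE
symbol_to_id = {s: i for i, s in enumerate(symbols)}

def decompose_syllable(ch):
    """Split one precomposed syllable (U+AC00..U+D7A3) into its jamo."""
    code = ord(ch)
    if not 0xAC00 <= code <= 0xD7A3:
        return ch  # leave spaces and anything else untouched
    index = code - 0xAC00
    lead, rest = divmod(index, 21 * 28)
    vowel, tail = divmod(rest, 28)
    jamo = JAMO_LEADS[lead] + JAMO_VOWELS[vowel]
    if tail:  # tail index 0 means "no final consonant"
        jamo += JAMO_TAILS[tail - 1]
    return jamo

def text_to_ids(text):
    """Decompose text into jamo, map to ids, and append EOS."""
    jamo = "".join(decompose_syllable(ch) for ch in text)
    return [symbol_to_id[s] for s in jamo] + [symbol_to_id[EOS]]

ids = text_to_ids("한글")
print(len(symbols), ids)
```

The table has 70 entries, which matches `symbol_length` in the hyperparameters, and each syllable contributes two or three ids followed by a single EOS id at the end of the sentence.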

In [7]:
import pandas as pd
import numpy as np
import os, librosa, re, glob, scipy
from tqdm import tqdm

text_dir = './archive/transcript.v.1.4.txt'
filters = '([.,!?])'

metadata = pd.read_csv(text_dir, dtype='object', sep='|', header=None)
text = metadata[3].values

out_dir = './data'
os.makedirs(out_dir, exist_ok=True)
os.makedirs(out_dir + '/text', exist_ok=True)
os.makedirs(out_dir + '/wav', exist_ok=True)
os.makedirs(out_dir + '/mel', exist_ok=True)
# os.makedirs(out_dir + '/dec', exist_ok=True) # Already handled inside the Tacotron2 model.
# os.makedirs(out_dir + '/spec', exist_ok=True) # Not needed for Tacotron2.

# text
print('Load Text')
text_len = []
for idx, s in enumerate(tqdm(text)):
    # Use a regular expression to match the filtered punctuation
    # and replace it with an empty string ('').
    sentence = re.sub(re.compile(filters), '', s)
    sentence = text_to_sequence(sentence)
    
    text_len.append(len(sentence))
    text_name = 'kss-text-%05d.npy' % idx
    np.save(os.path.join(out_dir + '/text', text_name), sentence, allow_pickle=False)
np.save(os.path.join(out_dir + '/text_len.npy'), np.array(text_len))
print('Text Done')

Load Text


100%|██████████████████████████████████████████████████████████████████████████| 12854/12854 [00:06<00:00, 1949.24it/s]

Text Done





## 1.2. Audio Preprocessing

In [18]:
# audio
wav_dir = metadata[0].values

print('Load Audio')
mel_len_list = []
for idx, fn in enumerate(tqdm(wav_dir)):
    file_dir = './archive/kss/' + fn
    wav, _ = librosa.load(file_dir, sr=sample_rate)
    wav, _ = librosa.effects.trim(wav) # trim leading and trailing silence

    # Pre-emphasis: y[n] = x[n] - preemphasis * x[n-1]
    # This filter boosts the high-frequency components, which are otherwise weak in speech.
    filtered_wav = scipy.signal.lfilter([1, -preemphasis], [1], wav)
    # The STFT output is complex; take the magnitude.
    stft = librosa.stft(filtered_wav, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
    stft = np.abs(stft)
    # Keyword arguments are required by recent librosa versions.
    mel_filter = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=mel_dim)
    mel_spec = np.dot(mel_filter, stft) # build the mel-spectrogram
    
    # Convert to the dB scale (a logarithmic scale closer to human loudness perception)
    mel_spec = 20 * np.log10(np.maximum(1e-5, mel_spec))
    stft = 20 * np.log10(np.maximum(1e-5, stft))
    
    # Normalize and clip
    mel_spec = np.clip((mel_spec - ref_db + max_db) / max_db, 1e-8, 1)
    stft = np.clip((stft - ref_db + max_db) / max_db, 1e-8, 1)
    
    mel_spec = mel_spec.T.astype(np.float32) # (Frames, 80)
    stft = stft.T.astype(np.float32) # (Frames, 513)
    
    mel_len_list.append([mel_spec.shape[0], idx]) # Stack of Frames
    
    mel_name = 'kss-mel-%05d.npy' % idx
    np.save(os.path.join(out_dir + '/mel', mel_name), mel_spec, allow_pickle=False)
    
    # Also save the wav array for use by the WaveNet vocoder.
    wav_name = 'kss-wav-%05d.npy' % idx
    np.save(os.path.join(out_dir + '/wav', wav_name), wav, allow_pickle=False)
    
    """
    # This step is not needed for Tacotron2.
    stft_name = 'kss-spec-%05d.npy' % idx
    np.save(os.path.join(out_dir + '/spec', stft_name), stft, allow_pickle=False)
    """
    
    """
    # In Tacotron2 this step is implemented inside the model.
    
    # Decoder Input
    mel_spec = mel_spec.reshape((-1, mel_dim))
    # Mel-spectrogram with a <GO> frame prepended; the last frame is dropped.
    dec_input = np.concatenate((np.zeros_like(mel_spec[:1, :]), mel_spec[:-1, :]), axis=0)
    dec_input = dec_input[:, -mel_dim:]
    dec_name = 'kss-dec-%05d.npy' % idx
    np.save(os.path.join(out_dir + '/dec', dec_name), dec_input, allow_pickle=False)
    """

mel_len = sorted(mel_len_list)
np.save(os.path.join(out_dir + '/mel_len.npy'), np.array(mel_len))
print('Audio Done')

Load Audio


100%|████████████████████████████████████████████████████████████████████████████| 12854/12854 [16:07<00:00, 13.28it/s]

Audio Done
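
The dB normalization used in the audio loop maps amplitudes into the range (0, 1]. The sketch below isolates that mapping and adds an inverse, which can be handy for plotting or sanity checks. The inverse is an assumption added for illustration (the notebook itself never undoes the normalization), and the function names are hypothetical.

```python
import numpy as np

ref_db, max_db = 20, 100

def normalize_db(amplitude):
    """Amplitude -> dB -> value in [1e-8, 1], as in the preprocessing loop."""
    db = 20 * np.log10(np.maximum(1e-5, amplitude))
    return np.clip((db - ref_db + max_db) / max_db, 1e-8, 1)

def denormalize_db(normalized):
    """Inverse mapping; exact only for values that were not clipped."""
    db = normalized * max_db - max_db + ref_db
    return np.power(10.0, db / 20)

amp = np.array([0.01, 0.1, 1.0, 10.0])
normalized = normalize_db(amp)
restored = denormalize_db(normalized)
```

With `ref_db = 20` and `max_db = 100`, amplitudes between roughly 1e-4 and 10 survive the round trip; anything outside that range is clipped and cannot be recovered.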





## 1.3. Dataset and DataLoader

In [5]:
import os, glob
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np

class TextMelDataset(Dataset):
    def __init__(self, data_dir):
        self.text_list = sorted(glob.glob(os.path.join(data_dir + '/text', '*.npy')))
        self.mel_list = sorted(glob.glob(os.path.join(data_dir + '/mel', '*.npy')))
        self.wav_list = sorted(glob.glob(os.path.join(data_dir + '/wav', '*.npy')))

        self.text_len = np.load(os.path.join(data_dir + '/text_len.npy'))
        self.mel_len = np.load(os.path.join(data_dir + '/mel_len.npy'))
        
    def __len__(self):
        return len(self.text_list)
    
    def __getitem__(self, idx):
        text = torch.from_numpy(np.load(self.text_list[idx]))
        text_len = len(text)
        
        mel = torch.from_numpy(np.load(self.mel_list[idx]))
        mel_len = mel.shape[0]
        
        wav = torch.from_numpy(np.load(self.wav_list[idx]))
        wav_len = wav.shape[0]
        return (text, text_len, mel, mel_len, wav, wav_len)

def collate_fn(batch):
    text = []
    text_len = []
    mel = []
    mel_len = []
    wav = []
    wav_len = []
    
    for t, tl, m, ml, w, wl in batch:
        text.append(t)
        text_len.append(tl)
        mel.append(m)
        mel_len.append(ml)
        wav.append(w)
        wav_len.append(wl)
        
    max_text_len = max(text_len)
    max_mel_len = max(mel_len)
    max_wav_len = max(wav_len)
    
    # text zero_padding (each x is already a tensor, so it is copied directly)
    padded_text_batch = torch.zeros((len(batch), max_text_len), dtype=torch.int32)
    for i, x in enumerate(text):
        padded_text_batch[i, :len(x)] = x
    
    # mel zero_padding
    padded_mel_batch = torch.zeros((len(batch), max_mel_len, mel_dim), dtype=torch.float32)
    for i, x in enumerate(mel):
        padded_mel_batch[i, :x.shape[0], :x.shape[1]] = x
        
    # wav zero_padding
    padded_wav_batch = torch.zeros((len(batch), max_wav_len), dtype=torch.float32)
    for i, x in enumerate(wav):
        padded_wav_batch[i, :x.shape[0]] = x
        
    return padded_text_batch, text_len, padded_mel_batch, mel_len, padded_wav_batch, wav_len

data_dir = './data'
dataset = TextMelDataset(data_dir)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size, collate_fn=collate_fn)

In [6]:
# Turn the DataLoader into an iterator
dataiter = iter(dataloader)

# Fetch a single batch
batch = next(dataiter)

print('batch[0](text) shape(B, Max_T):', batch[0].shape) # Max_T: The number of text length
print('batch[1](text_len) len(B):', len(batch[1]))
print('batch[2](mel) shape(B, Max_F, mel_dim):', batch[2].shape) # Max_F: The number of Mel-spectrogram frames
print('batch[3](mel_len) len(B):', len(batch[3]))
print('batch[4](wav) shape(B, Max_L):', batch[4].shape) # Max_L: duration (seconds) * sample_rate (22050)
print('batch[5](wav_len) len(B):', len(batch[5]))
print('Total:', len(dataset))
print('num_of_batches:', len(dataloader))

batch[0](text) shape(B, Max_T): torch.Size([64, 69])
batch[1](text_len) len(B): 64
batch[2](mel) shape(B, Max_F, mel_dim): torch.Size([64, 383, 80])
batch[3](mel_len) len(B): 64
batch[4](wav) shape(B, Max_L): torch.Size([64, 97792])
batch[5](wav_len) len(B): 64
Total: 12854
num_of_batches: 201
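
The printed shapes are consistent with the frame count of a centered STFT, which is what `librosa.stft` computes by default. The sketch below is a small check under that assumption; `num_frames` is a hypothetical helper name.

```python
hop_length = 256

def num_frames(num_samples, hop=hop_length):
    """Frame count of a centered STFT: one frame per hop, plus one."""
    return 1 + num_samples // hop

frames = num_frames(97792)            # Max_L from the batch printed above
upsampled_len = frames * hop_length   # mel length after x256 upsampling
print(frames, upsampled_len - 97792)
```

Because of the extra frame, a mel-spectrogram upsampled by `hop_length` is slightly longer than the waveform (256 samples longer in this batch), so the WaveNet conditioning would likely need to be trimmed or the waveform padded to match.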


# 2. Model

- The Tacotron2 Encoder and Decoder implementation refers to the following GitHub repository.
- NVIDIA Github: `https://github.com/NVIDIA/tacotron2`

## 2.1. Tacotron2 Encoder

In [6]:
from torch import nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        # Character Embedding
        self.embedding = nn.Embedding(symbol_length, embedding_dim) # (70, 512) | (B, T) -> (B, T, 512)
        
        # 3 Conv Layers
        convolutions = []
        for i in range(3):
            conv = nn.Sequential(
                nn.Conv1d(embedding_dim, embedding_dim, kernel_size=5, stride=1, padding=2, dilation=1), # padding=(5-1)//2 keeps T unchanged
                nn.BatchNorm1d(embedding_dim), nn.ReLU()) # (B, 512, T) -> (B, 512, T)
            convolutions.append(conv)
        self.conv = nn.ModuleList(convolutions) # 3 layers of Conv1d-BatchNorm-ReLU
        
        # Bidirectional LSTM
        self.lstm = nn.LSTM(embedding_dim, embedding_dim // 2, batch_first=True, bidirectional=True) # (B, T, 512) -> (B, T, 256*2)
        
    def forward(self, text, text_len):
        """
        =====inputs=====
        text: (B, Max_T)
        text_len: (B)
        =====outputs=====
        outputs: (B, Max_T, 512) # memory
        """
        x = self.embedding(text) # (B, T) -> (B, T, 512)
        
        x = x.transpose(1, 2) # (B, T, 512) -> (B, 512, T)
        for conv in self.conv:
            x = conv(x) # (B, 512, T) -> (B, 512, T)
            
        x = x.transpose(1, 2) # (B, 512, T) -> (B, T, 512)
        # Pack the sequence so the RNN only processes the non-padded positions.
        x = nn.utils.rnn.pack_padded_sequence(x, text_len, batch_first=True,
                                              enforce_sorted=False) # padded seq -> packed seq
        self.lstm.flatten_parameters() # CUDA ERROR FIX
        x, _ = self.lstm(x) # (B, T, 512) -> (B, T, 256*2)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) # packed seq -> padded seq
        
        return outputs
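
The shape comments in the encoder assume that every convolution keeps the time dimension at T. Whether that holds depends on the padding, which can be checked with the standard `Conv1d` output-length formula. The helper names below are hypothetical, and T = 85 is just an example length.

```python
def conv1d_out_len(length, kernel_size, padding, dilation=1, stride=1):
    """Output length of a 1-D convolution (same formula as torch.nn.Conv1d)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def same_padding(kernel_size, dilation=1):
    """Padding that preserves the length for odd kernel sizes and stride 1."""
    return dilation * (kernel_size - 1) // 2

T = 85
len_pad2 = conv1d_out_len(T, kernel_size=5, padding=same_padding(5))  # length preserved
len_pad3 = T
for _ in range(3):  # three stacked layers with padding=3
    len_pad3 = conv1d_out_len(len_pad3, kernel_size=5, padding=3)
print(len_pad2, len_pad3)
```

With kernel size 5, a padding of 2 preserves the length, while a padding of 3 adds two steps per layer. After three layers the features would be 6 steps longer and shifted relative to the text, even though `pack_padded_sequence` would silently cut the result back to the original lengths.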

## 2.2. Tacotron2 Decoder

In [7]:
class PreNet(nn.Module):
    def __init__(self):
        super(PreNet, self).__init__()
        # 2 Layers Pre-Net
        self.fc1 = nn.Sequential(
            nn.Linear(80, 256, bias=False),
            nn.ReLU()) # (B, 80) -> (B, 256)
        self.fc2 = nn.Sequential(
            nn.Linear(256, 256, bias=False),
            nn.ReLU()) # (B, 256) -> (B, 256)
    def forward(self, uni_mel):
        """
        =====inputs=====
        uni_mel: (B, 80) # a single frame taken from the mel-spectrogram
        =====outputs=====
        outputs: (B, 256)
        """
        x = F.dropout(self.fc1(uni_mel), p=0.5, training=True) # dropout stays on, even at inference
        outputs = F.dropout(self.fc2(x), p=0.5, training=True) # dropout stays on, even at inference
        return outputs
    
class LocationLayer(nn.Module):
    def __init__(self):
        super(LocationLayer, self).__init__()
        self.conv = nn.Conv1d(2, 32, kernel_size=31, stride=1, padding=15, dilation=1) # (B, 2, Max_T) -> (B, 32, Max_T)
        self.fc = nn.Linear(32, 128, bias=False) # (B, Max_T, 32) -> (B, Max_T, 128)
        
    def forward(self, attention_weights_cat):
        """
        =====inputs=====
        attention_weights_cat: (B, 2, Max_T) # previous step's attention weights concatenated with attention_weights_cum
        =====outputs=====
        outputs: (B, Max_T, 128)
        """
        x = self.conv(attention_weights_cat)
        x = x.transpose(1, 2)
        outputs = self.fc(x)
        return outputs
        
class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.fc_memory = nn.Linear(embedding_dim, 128, bias=False) # Embedding outputs
        self.fc_query = nn.Linear(1024, 128, bias=False) # Query
        self.fc_location = LocationLayer() # Previous Attention Weight
        
        self.v = nn.Linear(128, 1, bias=True)
        
    def get_attention_weights(self, memory, query, attention_weights_cat, text_len):
        """
        =====inputs=====
        memory: (B, Max_T, 512) # Encoder outputs
        query: (B, 1024) # Attention LSTM output
        attention_weights_cat: (B, 2, Max_T) # previous step's attention weights concatenated with attention_weights_cum
        text_len: (B) # list of text lengths
        =====outputs=====
        attention_weights: (B, Max_T) # attention weights at the current time step; stacked over time they form the alignment
        """
        h = self.fc_memory(memory) # (B, Max_T, 512) -> (B, Max_T, 128)
        d = self.fc_query(query.unsqueeze(1)) # (B, 1024) -> (B, 1, 128)
        f = self.fc_location(attention_weights_cat) # (B, 2, Max_T) -> (B, Max_T, 128)
        
        score = self.v(torch.tanh(h + d + f)) # (B, Max_T, 1)
        score = score.squeeze(dim=2) # (B, Max_T)
        
        for idx, length in enumerate(text_len):
            score[idx, length:] = -torch.inf
        
        attention_weights = F.softmax(score, dim=1)
        return attention_weights
    
    def forward(self, memory, query, attention_weights_cat, text_len):
        """
        =====inputs=====
        memory: (B, Max_T, 512) # Encoder outputs
        query: (B, 1024) # Attention LSTM output
        attention_weights_cat: (B, 2, Max_T) # previous step's attention weights concatenated with attention_weights_cum
        text_len: (B)
        =====outputs=====
        context: (B, 512)
        attention_weights: (B, Max_T) # attention weights at the current time step
        """
        attention_weights = self.get_attention_weights(memory, query, attention_weights_cat, text_len) # (B, Max_T)
        
        context = torch.bmm(attention_weights.unsqueeze(1), memory) # bmm: batch matrix-matrix product
        # (B, 1, Max_T)@(B, Max_T, 512) = (B, 1, 512)
        context = context.squeeze(1) # (B, 512)

        return context, attention_weights # (B, 512), (B, Max_T)

class PostNet(nn.Module):
    def __init__(self):
        super(PostNet, self).__init__()
        self.convolutions = nn.ModuleList()
        
        self.convolutions.append(nn.Sequential(
            nn.Conv1d(80, 512, kernel_size=5, stride=1, padding=2, dilation=1), # (B, 80, Max_F) -> (B, 512, Max_F)
            nn.BatchNorm1d(512), nn.Tanh()))
        
        for i in range(3):
            self.convolutions.append(nn.Sequential(
                nn.Conv1d(512, 512, kernel_size=5, stride=1, padding=2, dilation=1), # (B, 512, Max_F) -> (B, 512, Max_F)
                nn.BatchNorm1d(512), nn.Tanh())) # x3
            
        self.convolutions.append(nn.Sequential(
            nn.Conv1d(512, 80, kernel_size=5, stride=1, padding=2, dilation=1), # (B, 512, Max_F) -> (B, 80, Max_F)
            nn.BatchNorm1d(80)))
        
    def forward(self, mel_outputs):
        """
        =====inputs=====
        mel_outputs: (B, 80, Max_F) # Decoder outputs
        =====outputs=====
        outputs: (B, 80, Max_F)
        """
        x = mel_outputs
        for conv in self.convolutions:
            x = F.dropout(conv(x), p=0.5, training=self.training)
        outputs = x
        return outputs
    
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        ##### Layers #####
        # Pre-Net
        self.prenet = PreNet()
        # Attention LSTM Cell
        self.attention_lstm = nn.LSTMCell(256 + 512, 1024)
        # Attention
        self.attention = Attention()
        # Decoder LSTM Cell
        self.decoder_lstm = nn.LSTMCell(1024 + 512, 1024)
        # Linear Projection
        self.linear_projection = nn.Linear(1024 + 512, 80)
        # Gate Linear Projection
        self.gate = nn.Sequential(nn.Linear(1024 + 512, 1), nn.Sigmoid())
        
        # Post-Net
        self.postnet = PostNet()
        
    def init_weights(self, memory):
        """
        =====inputs=====
        memory: (B, Max_T, 512)
        """
        device = next(self.parameters()).device
        
        B = memory.size(0)
        Max_T = memory.size(1)
        ##### Weight #####
        self.attention_hidden = torch.zeros((B, 1024), dtype=torch.float32).to(device)
        self.attention_cell = torch.zeros((B, 1024), dtype=torch.float32).to(device)
        
        self.decoder_hidden = torch.zeros((B, 1024), dtype=torch.float32).to(device)
        self.decoder_cell = torch.zeros((B, 1024), dtype=torch.float32).to(device)
        
        self.attention_weights = torch.zeros((B, Max_T), dtype=torch.float32).to(device)
        self.attention_weights_cum = torch.zeros((B, Max_T), dtype=torch.float32).to(device)
        self.context = torch.zeros((B, 512), dtype=torch.float32).to(device)
        
    def forward(self, mel, memory, text_len):
        """
        =====inputs=====
        mel: (B, Max_F, mel_dim=80)
        memory: (B, Max_T, 512)
        text_len: (B)
        =====outputs=====
        mel_outputs: (B, 80, Max_F) # intermediate mel-spectrogram (before the Post-Net)
        postnet_mel_outputs: (B, 80, Max_F) # final mel-spectrogram | the loss sums the errors of both mel-spectrograms
        gate_outputs: (B, Max_F)
        alignments: (B, Max_T, Max_F)
        """
        device = next(self.parameters()).device
        
        self.init_weights(memory) # reset the recurrent and attention states
        
        mel_outputs = []
        gate_outputs = []
        alignments = []
        
        B = mel.size(0)
        mel_dim = mel.size(2)
        GO = torch.zeros((B, 1, mel_dim), dtype=torch.float32).to(device) # GO frame
        mel_inputs = torch.cat((GO, mel), dim=1) # (B, 1 + Max_F, mel_dim)
        
        for idx in range(mel_inputs.size(1) - 1): # loop Max_F times (teacher forcing)
            mel_input = mel_inputs[:, idx, :] # (B, mel_dim=80)
            x = self.prenet(mel_input) # (B, 80) -> (B, 256)

            # Attention LSTM Cell
            x = torch.cat((x, self.context), dim=1) # (B, 256 + 512)
            self.attention_hidden, self.attention_cell = self.attention_lstm(x, (self.attention_hidden, self.attention_cell))
            # (B, 256 + 512) -> (B, 1024)
            self.attention_hidden = F.dropout(self.attention_hidden, p=0.1, training=self.training)
            query = self.attention_hidden
            
            # Attention
            attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(1),
                                              self.attention_weights_cum.unsqueeze(1)), dim=1) # (B, 2, Max_T)
            self.context, self.attention_weights = self.attention(memory, query, attention_weights_cat, text_len)
            # (B, 512), (B, Max_T)
            self.attention_weights_cum += self.attention_weights # accumulate the attention weights
            
            # Decoder LSTM Cell
            x = torch.cat((query, self.context), dim=1) # (B, 1024 + 512)
            self.decoder_hidden, self.decoder_cell = self.decoder_lstm(x, (self.decoder_hidden, self.decoder_cell))
            # (B, 1024 + 512) -> (B, 1024)
            self.decoder_hidden = F.dropout(self.decoder_hidden, p=0.1, training=self.training)
            x = self.decoder_hidden
            
            x = torch.cat((x, self.context), dim=1) # (B, 1024 + 512)
            mel_output = self.linear_projection(x) # (B, 1024 + 512) -> (B, 80)
            gate_output = self.gate(x) # (B, 1024 + 512) -> (B, 1)
            
            mel_outputs.append(mel_output) # final: (B, 80) * Max_F
            gate_outputs.append(gate_output) # final: (B, 1) * Max_F
            alignments.append(self.attention_weights) # final: (B, Max_T) * Max_F
            
        mel_outputs = torch.stack(mel_outputs, dim=2) # (B, 80, Max_F)
        gate_outputs = torch.stack(gate_outputs, dim=1) # (B, Max_F, 1)
        gate_outputs = gate_outputs.squeeze(2) # (B, Max_F)
        alignments = torch.stack(alignments, dim=2) # (B, Max_T, Max_F)
        
        # Post-Net
        postnet_outputs = self.postnet(mel_outputs) # (B, 80, Max_F) -> (B, 80, Max_F)
        postnet_mel_outputs = mel_outputs + postnet_outputs # (B, 80, Max_F)
        
        return mel_outputs, postnet_mel_outputs, gate_outputs, alignments
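
The attention module masks padded text positions with a Python loop over the batch. A vectorized alternative is sketched below; `mask_scores` is a hypothetical helper, and it returns a new tensor instead of modifying `score` in place.

```python
import torch
import torch.nn.functional as F

def mask_scores(score, lengths):
    """Set scores at padded positions to -inf.

    score: (B, Max_T); lengths: list or 1-D tensor with the valid length per row.
    """
    lengths = torch.as_tensor(lengths, device=score.device)
    positions = torch.arange(score.size(1), device=score.device)
    pad_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)  # (B, Max_T), True on padding
    return score.masked_fill(pad_mask, float('-inf'))

score = torch.zeros(2, 4)
weights = F.softmax(mask_scores(score, [2, 4]), dim=1)
print(weights)
```

After the softmax, padded positions receive exactly zero weight and each row still sums to one. Since the mask depends only on `text_len`, it could be computed once per batch rather than at every decoder step.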

In [8]:
class Tacotron2(nn.Module):
    def __init__(self):
        super(Tacotron2, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        
    def forward(self, text, text_len, mel):
        """
        =====inputs=====
        text: (B, Max_T)
        text_len: (B)
        mel: (B, Max_F, 80) # as returned by collate_fn
        =====outputs=====
        mel_outputs: (B, 80, Max_F)
        postnet_mel_outputs: (B, 80, Max_F)
        gate_outputs: (B, Max_F)
        alignments: (B, Max_T, Max_F)
        """
        memory = self.encoder(text, text_len)
        mel_outputs, postnet_mel_outputs, gate_outputs, alignments = self.decoder(mel, memory, text_len)
        return mel_outputs, postnet_mel_outputs, gate_outputs, alignments
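
The loss targets are still listed as unimplemented, so the following is only a sketch of how the three model outputs could be turned into a loss. It assumes the gate output is already a probability (the gate layer ends in a sigmoid), marks the last valid frame and everything after it as "stop", and does not mask padded frames in the mel loss. The function name `tacotron2_loss` is hypothetical.

```python
import torch
import torch.nn.functional as F

def tacotron2_loss(mel_outputs, postnet_mel_outputs, gate_outputs, mel_target, mel_len):
    """
    mel_outputs, postnet_mel_outputs: (B, 80, Max_F)
    gate_outputs: (B, Max_F), probabilities after the sigmoid
    mel_target: (B, Max_F, 80), zero-padded as returned by collate_fn
    mel_len: list of valid frame counts, length B
    """
    mel_target = mel_target.transpose(1, 2)  # (B, 80, Max_F)
    lengths = torch.as_tensor(mel_len, device=gate_outputs.device)
    frames = torch.arange(gate_outputs.size(1), device=gate_outputs.device)
    # Stop target: 1 from the last valid frame onward, 0 before it
    gate_target = (frames.unsqueeze(0) >= (lengths.unsqueeze(1) - 1)).float()
    mel_loss = F.mse_loss(mel_outputs, mel_target) + F.mse_loss(postnet_mel_outputs, mel_target)
    gate_loss = F.binary_cross_entropy(gate_outputs, gate_target)
    return mel_loss + gate_loss, gate_target

# Toy tensors standing in for real model outputs
torch.manual_seed(0)
B, Max_F = 2, 5
mel_target = torch.rand(B, Max_F, 80)
mel_out = torch.rand(B, 80, Max_F)
gate_out = torch.sigmoid(torch.randn(B, Max_F))
loss, gate_target = tacotron2_loss(mel_out, mel_out, gate_out, mel_target, [3, 5])
```

If the sigmoid were removed from the gate layer, `F.binary_cross_entropy_with_logits` would be the numerically safer choice.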

# TEST

In [16]:
##### Attention Test #####
class BahdanauAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_memory = nn.Linear(embedding_dim, 128, bias=False) # Embedding outputs
        self.fc_query = nn.Linear(1024, 128, bias=False) # Query
        
    def forward(self, memory, query):
        q = self.fc_query(query.unsqueeze(1)).unsqueeze(2) # (B, 1024) -> (B, 1, 1, 128)
        v = self.fc_memory(memory).unsqueeze(1) # (B, Max_T, 512) -> (B, 1, Max_T, 128)
        
        score = torch.sum(torch.tanh(q + v), dim=-1) # (B, 1, Max_T)

        attention_weights = F.softmax(score, dim=-1) # (B, 1, Max_T)
        
        context = torch.matmul(attention_weights, memory) # batched matrix product
        # (B, 1, Max_T)@(B, Max_T, 512) = (B, 1, 512)
        context = context.squeeze(1) # (B, 512)
        attention_weights = attention_weights.squeeze(1) # (B, Max_T)
        
        return context, attention_weights
    
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        ##### Layers #####
        # Pre-Net
        self.prenet = PreNet()
        # Attention LSTM Cell
        self.attention_lstm = nn.LSTMCell(256 + 512, 1024)
        # Attention
        self.attention = BahdanauAttention()
        # Decoder LSTM Cell
        self.decoder_lstm = nn.LSTMCell(1024 + 512, 1024)
        # Linear Projection
        self.linear_projection = nn.Linear(1024 + 512, 80)
        # Gate Linear Projection
        self.gate = nn.Sequential(nn.Linear(1024 + 512, 1), nn.Sigmoid())
        
        # Post-Net
        self.postnet = PostNet()
        
    def init_weights(self, memory):
        """
        =====inputs=====
        memory: (B, Max_T, 512)
        """
        device = next(self.parameters()).device
        
        B = memory.size(0)
        Max_T = memory.size(1)
        ##### Weight #####
        self.attention_hidden = torch.zeros((B, 1024), dtype=torch.float32).to(device)
        self.attention_cell = torch.zeros((B, 1024), dtype=torch.float32).to(device)
        
        self.decoder_hidden = torch.zeros((B, 1024), dtype=torch.float32).to(device)
        self.decoder_cell = torch.zeros((B, 1024), dtype=torch.float32).to(device)
        
        self.attention_weights = torch.zeros((B, Max_T), dtype=torch.float32).to(device)
        self.context = torch.zeros((B, 512), dtype=torch.float32).to(device)
        
    def forward(self, mel, memory):
        """
        =====inputs=====
        mel: (B, Max_F, mel_dim=80)
        memory: (B, Max_T, 512)
        =====outputs=====
        mel_outputs: (B, 80, Max_F) # intermediate mel-spectrogram (before the Post-Net)
        postnet_mel_outputs: (B, 80, Max_F) # final mel-spectrogram | the loss sums the errors of both mel-spectrograms
        gate_outputs: (B, Max_F)
        alignments: (B, Max_T, Max_F)
        """
        device = next(self.parameters()).device
        
        self.init_weights(memory) # reset the recurrent and attention states
        
        mel_outputs = []
        gate_outputs = []
        alignments = []
        
        B = mel.size(0)
        mel_dim = mel.size(2)
        GO = torch.zeros((B, 1, mel_dim), dtype=torch.float32).to(device) # GO frame
        mel_inputs = torch.cat((GO, mel), dim=1) # (B, 1 + Max_F, mel_dim)
        
        for idx in range(mel_inputs.size(1) - 1): # loop Max_F times (teacher forcing)
            mel_input = mel_inputs[:, idx, :] # (B, mel_dim=80)
            x = self.prenet(mel_input) # (B, 80) -> (B, 256)

            # Attention LSTM Cell
            x = torch.cat((x, self.context), dim=1) # (B, 256 + 512)
            self.attention_hidden, self.attention_cell = self.attention_lstm(x, (self.attention_hidden, self.attention_cell))
            # (B, 256 + 512) -> (B, 1024)
            self.attention_hidden = F.dropout(self.attention_hidden, p=0.1, training=self.training)
            query = self.attention_hidden
            
            # Attention
            self.context, self.attention_weights = self.attention(memory, query)
            # (B, 512), (B, Max_T)
            
            x = torch.cat((query, self.context), dim=1) # (B, 1024 + 512)
            
            # Decoder LSTM Cell
            self.decoder_hidden, self.decoder_cell = self.decoder_lstm(x, (self.decoder_hidden, self.decoder_cell))
            # (B, 1024 + 512) -> (B, 1024)
            self.decoder_hidden = F.dropout(self.decoder_hidden, p=0.1, training=self.training)
            x = self.decoder_hidden
            
            x = torch.cat((x, self.context), dim=1) # (B, 1024 + 512)
            mel_output = self.linear_projection(x) # (B, 1024 + 512) -> (B, 80)
            gate_output = self.gate(x) # (B, 1024 + 512) -> (B, 1)
            
            mel_outputs.append(mel_output) # final: (B, 80) * Max_F
            gate_outputs.append(gate_output) # final: (B, 1) * Max_F
            alignments.append(self.attention_weights) # final: (B, Max_T) * Max_F
            
        mel_outputs = torch.stack(mel_outputs, dim=2) # (B, 80, Max_F)
        gate_outputs = torch.stack(gate_outputs, dim=1) # (B, Max_F, 1)
        gate_outputs = gate_outputs.squeeze(2) # (B, Max_F)
        alignments = torch.stack(alignments, dim=2) # (B, Max_T, Max_F)
        
        # Post-Net
        postnet_outputs = self.postnet(mel_outputs) # (B, 80, Max_F) -> (B, 80, Max_F)
        postnet_mel_outputs = mel_outputs + postnet_outputs # (B, 80, Max_F)
        
        return mel_outputs, postnet_mel_outputs, gate_outputs, alignments
    
class Tacotron2(nn.Module):
    def __init__(self):
        super(Tacotron2, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        
    def forward(self, text, text_len, mel):
        """
        =====inputs=====
        text: (B, Max_T)
        text_len: (B)
        mel: (B, Max_F, 80) # as returned by collate_fn
        =====outputs=====
        mel_outputs: (B, 80, Max_F)
        postnet_mel_outputs: (B, 80, Max_F)
        gate_outputs: (B, Max_F)
        alignments: (B, Max_T, Max_F)
        """
        memory = self.encoder(text, text_len)
        mel_outputs, postnet_mel_outputs, gate_outputs, alignments = self.decoder(mel, memory)
        return mel_outputs, postnet_mel_outputs, gate_outputs, alignments

### Debugging

In [13]:
import time
device = 'cpu'

# Turn the DataLoader into an iterator
dataiter = iter(dataloader)

# Fetch a single batch
batch = next(dataiter)

start = time.time()

# Run the batch through Tacotron2
a, b, c, d = Tacotron2()(batch[0], batch[1], batch[2])
print(a.shape, b.shape, c.shape, d.shape)

end = time.time()
print('cpu: ', end-start)

torch.Size([64, 80, 659]) torch.Size([64, 80, 659]) torch.Size([64, 659]) torch.Size([64, 93, 659])
cpu:  30.782493591308594


In [26]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Tacotron2()
model = model.to(device)

# Turn the DataLoader into an iterator
dataiter = iter(dataloader)

# Fetch a single batch
batch = next(dataiter)

start = time.time()

# Run the batch through Tacotron2
a, b, c, d = model(batch[0].to(device), batch[1], batch[2].to(device))
print(batch[0].size(), len(batch[1]), batch[2].size())
print(a.shape, b.shape, c.shape, d.shape)

end = time.time()
print('cuda: ', end-start)

torch.Size([64, 85]) 64 torch.Size([64, 479, 80])
torch.Size([64, 80, 479]) torch.Size([64, 80, 479]) torch.Size([64, 479]) torch.Size([64, 85, 479])
cuda:  1.3982977867126465


## 2.3. Wavenet Vocoder

- The WaveNet vocoder implementation refers to the following GitHub repository.
- hccho2 Github: `https://github.com/hccho2/Tacotron2-Wavenet-Korean-TTS`
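
Before building the dilated layers, it helps to know how much past audio the stack can see. The sketch below computes the receptive field of the dilated stack from the `dilations` hyperparameter, assuming a kernel size of 3 for every dilated convolution as in the `DilatedConv` class of this section. `receptive_field` is a hypothetical helper name.

```python
sample_rate = 22050
dilations = [1, 2, 4, 8, 16, 32] * 2
kernel_size = 3  # assumed to match the dilated convolutions in this section

def receptive_field(dilations, kernel_size):
    """Number of input samples that influence one output sample."""
    return 1 + sum((kernel_size - 1) * d for d in dilations)

rf = receptive_field(dilations, kernel_size)
rf_ms = 1000 * rf / sample_rate
print(rf, round(rf_ms, 2))
```

With these settings the stack covers 253 samples, or about 11.5 ms of audio at 22,050 Hz. If a kernel-size-3 causal input convolution precedes the stack, it would add 2 more samples. The relatively short context is one reason the vocoder leans on the mel-spectrogram as local conditioning.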

In [6]:
class DilatedConv(nn.Module):
    def __init__(self, in_channel, out_channel, dilation):
        """
        in_channel: (causal) 1 (dilated) 128
        out_channel: (causal) 128 (dilated) 512
        dilation: (causal) 1 (dilated) one of 1, 2, 4, 8, 16, 32, 64, 128, 256, 512
        """
        super(DilatedConv, self).__init__()
        self.padding = (3 - 1)*dilation
        self.dilated_conv = nn.Conv1d(in_channel, out_channel, kernel_size=3, stride=1, padding=0, dilation=dilation)
        # (B, in_channel, ?) -> (B, out_channel, ?) # ?: (training) Max_L-1, (inference) 1
        
    def forward(self, inputs):
        """
        =====inputs=====
        * ?: (training) Max_L-1, (inference) 1 (* the same notation is used below)
        inputs: (causal) (B, 1, ?) (dilation) (B, 128, ?)
        =====outputs=====
        output: (causal) (B, 128, ?) (dilation) (B, 512, ?)
        """
        B, C, L = inputs.size()
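        zero_pad = torch.zeros(B, C, self.padding, device=inputs.device, dtype=inputs.dtype) # follow the input's device instead of a global
        padded_inputs = torch.cat([zero_pad, inputs], dim=-1) # left-pad so the convolution is causal and the length is preserved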
        
        outputs = self.dilated_conv(padded_inputs) # (B, in_channel, ?+padding) -> (B, out_channel, ?)
        return outputs # (B, out_channel, ?)
    
class ResidualBlock(nn.Module):
    def __init__(self, dilation):
        """
        dilation: one of 1, 2, 4, 8, 16, 32, 64, 128, 256, 512
        """
        super(ResidualBlock, self).__init__()
        self.dilated_conv = DilatedConv(128, 512, dilation) # (B, 128, ?) -> (B, 512, ?)
        """
        Removed to reduce memory usage
        self.filter_conv = nn.Conv1d(256, 256, kernel_size=1, stride=1, padding=0, dilation=1) # (B, 256, ?) -> (B, 256, ?)
        self.gate_conv = nn.Conv1d(256, 256, kernel_size=1, stride=1, padding=0, dilation=1) # (B, 256, ?) -> (B, 256, ?)
        """
        self.local_filter_conv = nn.Conv1d(80, 256, kernel_size=1, stride=1, padding=0, dilation=1) # (B, 80, ?) -> (B, 256, ?)
        self.local_gate_conv = nn.Conv1d(80, 256, kernel_size=1, stride=1, padding=0, dilation=1) # (B, 80, ?) -> (B, 256, ?)
        self.global_filter_conv = nn.Conv1d(32, 256, kernel_size=1, stride=1, padding=0, dilation=1) # (B, 32, ?) -> (B, 256, ?)
        self.global_gate_conv = nn.Conv1d(32, 256, kernel_size=1, stride=1, padding=0, dilation=1) # (B, 32, ?) -> (B, 256, ?)
        self.residual_conv = nn.Conv1d(256, 128, kernel_size=1, stride=1, padding=0, dilation=1) # (B, 256, ?) -> (B, 128, ?)
        self.skip_conv = nn.Conv1d(256, 128, kernel_size=1, stride=1, padding=0, dilation=1) # (B, 256, ?) -> (B, 128, ?)
        
    def forward(self, inputs, local_condition, global_condition):
        """
        =====inputs=====
        inputs: (B, 128, ?)
        local_condition: (B, 80, ?) # upsampled mel_outputs from Tacotron2
        global_condition: (B, 32, 1) # embedded speaker_id
        =====outputs=====
        residual_outputs: (B, 128, ?)
        skip_outputs: (B, 128, ?)
        """
        x = self.dilated_conv(inputs) # (B, 128, ?) -> (B, 512, ?)
        x_filter, x_gate = torch.chunk(x, chunks=2, dim=1) # split x in half along dim=1: (B, 256, ?), (B, 256, ?)
        """
        Removed to reduce memory usage
        x_filter = self.filter_conv(x_filter) # (B, 256, ?)
        x_gate = self.gate_conv(x_gate) # (B, 256, ?)
        """
        
        if local_condition is not None:
            x_filter = x_filter + self.local_filter_conv(local_condition) # (B, 256, ?)
            x_gate = x_gate + self.local_gate_conv(local_condition) # (B, 256, ?)
        if global_condition is not None:
            x_filter = x_filter + self.global_filter_conv(global_condition) # (B, 256, ?)
            x_gate = x_gate + self.global_gate_conv(global_condition) # (B, 256, ?)
        x = torch.tanh(x_filter) * torch.sigmoid(x_gate) # (B, 256, ?)
        
        residual_outputs = self.residual_conv(x) # (B, 256, ?) -> (B, 128, ?)
        residual_outputs = inputs + residual_outputs # (B, 128, ?)
        skip_outputs = self.skip_conv(x) # (B, 256, ?) -> (B, 128, ?)
        
        return residual_outputs, skip_outputs
    
class StackOfResidualBlocks(nn.Module):
    def __init__(self, dilations):
        """
        dilations = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]*3
        """
        super().__init__()
        self.stack_of_residual_blocks = nn.ModuleList()
        for dilation in dilations:
            residual_block = ResidualBlock(dilation)
            self.stack_of_residual_blocks.append(residual_block)        
        
    def forward(self, inputs, local_condition, global_condition):
        """
        =====inputs=====
        inputs: (B, 128, ?)
        local_condition: (B, 80, ?) # upsampled mel_outputs from Tacotron2
        global_condition: (B, 32, 1) # embedded speaker_id
        =====outputs=====
        sum_of_skip_outputs: (B, 128, ?) # sum of the skip_outputs from all residual blocks
        """
        residual_outputs = inputs
        stack_of_skip_outputs = []
        for residual_block in self.stack_of_residual_blocks:
            residual_outputs, skip_outputs = residual_block(residual_outputs, local_condition, global_condition)
            # (B, 128, ?), (B, 128, ?)
            stack_of_skip_outputs.append(skip_outputs)
            
        sum_of_skip_outputs = torch.zeros_like(stack_of_skip_outputs[0])
        for skip_outputs in stack_of_skip_outputs:
            sum_of_skip_outputs += skip_outputs # (B, 128, ?)
        
        return sum_of_skip_outputs
    
class WaveNet(nn.Module):
    def __init__(self, upsampling_factors, dilations, num_of_speakers):
        """
        upsampling_factors = [16, 16]
        dilations = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]*3
        num_of_speakers: int # number of speakers, i.e. the number of distinct global_condition values
        """
        super().__init__()
        # Local condition Upsampling
        self.upsampling_convs = nn.ModuleList()
        for upscale in upsampling_factors:
            upsampling_conv = nn.ConvTranspose2d(in_channels=1, out_channels=1, kernel_size=(1, upscale), stride=(1, upscale),
                                                 padding=0, output_padding=0, bias=False, dilation=1)
            # (B, 1, 80, F) -> (B, 1, 80, upscale*F)
            self.upsampling_convs.append(upsampling_conv)
        # Global condition Embedding
        self.global_embedding = nn.Embedding(num_of_speakers, 32) # (B) -> (B, 32)
        
        # Causal Conv
        self.casual_conv = DilatedConv(in_channel=1, out_channel=128, dilation=1)
        # (B, 1, ?) -> (B, 128, ?)
        
        # Residual Blocks
        self.residual_blocks = StackOfResidualBlocks(dilations)
        
        # Post Layers
        self.post_layers = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(128, 128, kernel_size=1, stride=1, padding=0, dilation=1), # (B, 128, ?) -> (B, 128, ?)
            nn.ReLU(),
            nn.Conv1d(128, 30, kernel_size=1, stride=1, padding=0, dilation=1) # (B, 128, ?) -> (B, 30, ?)
        )
        
    def forward(self, input_waveform, local_condition, global_condition):
        """
        =====inputs=====
        * ?: (train) L-1, (inference) 1
        input_waveform: (B, ?)
        local_condition: (B, 80, Max_F) # mel_outputs from Tacotron2
        global_condition: (B) # speaker ID
        =====outputs=====
        outputs: (B, 30, ?) # features for the parameters (weights) of the mixture of logistics (MoL)
        """
        L = input_waveform.size(dim=1)
        # Preprocessing of input_waveform
        if self.training:
            input_waveform = input_waveform[:, :-1] # drop the last sample | (B, L) -> (B, L-1)
        # input_waveform: (train) (B, L-1), (inference) (B, 1)
        # From here on, ?: L-1 or 1
        input_waveform = input_waveform.unsqueeze(dim=1) # (B, ?) -> (B, 1, ?)
        
        # Preprocessing of local_condition
        local_condition = local_condition.unsqueeze(dim=1) # (B, 80, F) -> (B, 1, 80, F)
        for upsampling_conv in self.upsampling_convs:
            local_condition = upsampling_conv(local_condition)
            # (B, 1, 80, F) -> (B, 1, 80, 16*F) -> (B, 1, 80, 256*F)
        local_condition = local_condition[:, 0, :, :L-1] # (B, 80, L-1)
        
        # Preprocessing of global_condition
        global_condition = self.global_embedding(global_condition) # (B) -> (B, 32)
        global_condition = global_condition.unsqueeze(dim=-1) # (B, 32, 1)
        
        # WaveNet
        x = self.casual_conv(input_waveform) # (B, 1, ?) -> (B, 128, ?)
        x = self.residual_blocks(x, local_condition, global_condition) # (B, 128, ?)
        outputs = self.post_layers(x) # (B, 128, ?) -> (B, 30, ?)
        
        return outputs
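
The left-padding rule used by `DilatedConv` (pad by `(kernel_size - 1) * dilation` on the left only) can be checked on a toy input. The following is a minimal sketch, not the notebook's own test: the sizes, the helper name `causal_conv`, and the perturbation position are all made up for illustration. It also works out the receptive field implied by one causal convolution followed by the `dilations` list from the docstrings, assuming kernel size 3 everywhere.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

kernel_size, dilation = 3, 4  # toy values for illustration
conv = nn.Conv1d(1, 2, kernel_size=kernel_size, dilation=dilation)

def causal_conv(x):
    # Hypothetical helper: left-pad by (kernel_size - 1) * dilation so that
    # the output keeps the input length and output[t] depends only on inputs <= t.
    pad = (kernel_size - 1) * dilation
    return conv(F.pad(x, (pad, 0)))

x = torch.randn(1, 1, 32)
y = causal_conv(x)

# Change the input from position 20 onward; outputs before 20 must stay the same.
x_perturbed = x.clone()
x_perturbed[:, :, 20:] += 1.0
y_perturbed = causal_conv(x_perturbed)

same_length = y.shape[-1] == x.shape[-1]
past_unchanged = torch.allclose(y[:, :, :20], y_perturbed[:, :, :20])

# Receptive field: each kernel-3 layer adds (kernel_size - 1) * dilation samples.
dilations = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] * 3
receptive_field = 1 + (kernel_size - 1) * (1 + sum(dilations))

print(same_length, past_unchanged, receptive_field)
```

Under these assumptions the stack sees 6141 past samples, which is about 0.28 seconds at the 22050 Hz sample rate set in section 1.0.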

### Debugging

In [8]:
# Turn the DataLoader into an iterator
dataiter = iter(dataloader)

# Fetch a single batch
batch = next(dataiter)

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

start = time.time()

input_waveform = batch[4]
local_condition = batch[2].transpose(1, 2)
global_condition = torch.randint(low=0, high=2, size=(batch_size,))
num_of_speakers = 2

net = WaveNet(upsampling_factors, dilations, num_of_speakers)
net = net.to(device)

output = net(input_waveform.to(device), local_condition.to(device), global_condition.to(device))
print(output.size())

end = time.time()
print(device, ':', end-start)

RuntimeError: [enforce fail at C:\cb\pytorch_1000000000000\work\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 7985889280 bytes.
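
The size in the error message can be related back to the tensor shapes in the model. The arithmetic below is a rough sketch: the byte count comes from the traceback and the batch size and channel width from the code above, but which tensor actually triggered the allocation is an assumption, since the traceback does not say.

```python
# Values taken from the notebook; their mapping to one specific tensor is an assumption.
batch_size = 64
channels = 256             # width of the filter/gate branches in ResidualBlock
bytes_per_float32 = 4
requested_bytes = 7985889280  # from the RuntimeError message

# Solve batch_size * channels * time_steps * 4 bytes = requested_bytes for time_steps.
time_steps = requested_bytes // (batch_size * channels * bytes_per_float32)
gib = requested_bytes / 2**30

print(time_steps)      # length of the time axis of the failing tensor
print(round(gib, 2))   # size of the single allocation in GiB
```

The result is a time axis of 121855 samples, which matches `L-1` for a waveform of 121856 = 476 × 256 samples. So the request is consistent with a single float32 tensor of shape `(64, 256, 121855)`, roughly 7.4 GiB, such as the output of one of the 1x1 conditioning convolutions. With 30 residual blocks each producing tensors of this size, a full-length batch of 64 is unlikely to fit in memory on this machine.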

# 3. Training

## 3.0. Plot Alignment
- `plot_alignment` is taken as-is from the Tacotron 1 implementation.
- The main function implemented below also follows Tacotron 1.
- chldkato GitHub: `https://github.com/chldkato/Tacotron-pytorch`

In [9]:
# When saving Matplotlib figures to files, use the 'Agg' backend,
# a non-interactive backend that renders bitmap (raster) graphics.
# The figure is stored as pixels rather than as a vector description.
matplotlib.use('Agg')

# Set a font that can render Korean text in Matplotlib
font_name = fm.FontProperties(fname="malgun.ttf").get_name()
matplotlib.rc('font', family=font_name, size=14)

def plot_alignment(alignment, path, text, step, loss):
    text = text.rstrip('_').rstrip('~')
    alignment = alignment[:len(text)]
    
    # Create one figure (fig) object and one axes (ax) object
    _, ax = plt.subplots(figsize=(len(text)/3, 5))
    # Draw the alignment as an image on the axes
    im = ax.imshow(np.transpose(alignment), aspect='auto', origin='lower')
    
    plt.xlabel('Encoder timestep')
    plt.ylabel('Decoder timestep')
    # Replace the space character ' ' with an empty string ''
    text = [x if x != ' ' else '' for x in list(text)]
    # Set the x-axis ticks and their labels
    plt.xticks(range(len(text)), text)
    
    plt.title(f"step: {step}, loss: {loss:.5f}", loc="center", pad=10)
    
    # Adjust the figure layout
    plt.tight_layout()
    plt.colorbar(im, ax=ax) # colorbar
    plt.savefig(path, format='png')
    plt.close()

## 3.1. Tacotron2

In [10]:
class Tacotron2Loss(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, mel_outputs, postnet_mel_outputs, gate_outputs,
                mel_targets, gate_targets):
        """
        =====inputs=====
        mel_outputs: (B, 80, Max_F)
        postnet_mel_outputs: (B, 80, Max_F)
        gate_outputs: (B, Max_F)
        mel_targets: (B, 80, Max_F)
        gate_targets: (B, Max_F)
        =====outputs=====
        loss: scalar tensor # mel_loss + postnet_mel_loss + gate_loss (gate_outputs are treated as logits)
        """
        mel_targets.requires_grad = False
        gate_targets.requires_grad = False
        
        mel_loss = nn.MSELoss()(mel_outputs, mel_targets)
        postnet_mel_loss = nn.MSELoss()(postnet_mel_outputs, mel_targets)
        gate_loss = nn.BCEWithLogitsLoss()(gate_outputs, gate_targets)
        return mel_loss + postnet_mel_loss + gate_loss
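
As a quick shape check, the three terms of the loss can be evaluated on dummy tensors. This is only a sketch with toy sizes and random values, not data from the notebook. It repeats the same three PyTorch loss calls as `Tacotron2Loss` so that it runs on its own.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, n_mels, Max_F = 2, 80, 5  # toy sizes, not the notebook's real batch

mel_outputs = torch.randn(B, n_mels, Max_F)
postnet_mel_outputs = torch.randn(B, n_mels, Max_F)
gate_outputs = torch.randn(B, Max_F)      # raw logits, no sigmoid applied
mel_targets = torch.randn(B, n_mels, Max_F)
gate_targets = torch.zeros(B, Max_F)
gate_targets[:, -1] = 1.0                 # mark the last frame as the stop frame

# Same three terms as Tacotron2Loss.forward
mel_loss = nn.MSELoss()(mel_outputs, mel_targets)
postnet_mel_loss = nn.MSELoss()(postnet_mel_outputs, mel_targets)
gate_loss = nn.BCEWithLogitsLoss()(gate_outputs, gate_targets)
total = mel_loss + postnet_mel_loss + gate_loss

print(total.dim(), total.item() > 0)
```

Because `nn.BCEWithLogitsLoss` applies the sigmoid internally, the gate branch of the model is expected to output raw logits rather than probabilities.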

In [11]:
def train(model_name, check_step):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    train_loader = dataloader # defined in section 1.3
    model = Tacotron2()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters())
    
    size = len(dataloader.dataset)
    num_of_batches = len(dataloader)
    
    os.makedirs('ckpt/' + model_name + "/1", exist_ok=True)
    
    epoch, step = 1, 0
    if check_step is not None:
        check_point = "./ckpt/" + model_name + "/1/ckpt-" + str(check_step) + ".pt"
        ckpt = torch.load(check_point)
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        step = ckpt['step']
        epoch = ckpt['epoch']
        print(f'Load Status: Epoch {epoch}, Step {step}')

    # One of PyTorch's options for faster CUDA ops: cuDNN auto-tunes its convolution algorithms
    torch.backends.cudnn.benchmark = True
    
    start = time.time()
    while True:
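        # Iterate over the loader directly: next(iter(train_loader)) would rebuild the
        # iterator on every step and only ever use its first batch.
        for text, text_len, mel, mel_len, wav, _ in train_loader: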
            text = text.to(device)
            mel = mel.to(device) # (B, Max_F, 80)
            wav = wav.to(device)

            # Build gate_targets: 0 for valid frames, 1 for padded frames
            B, Max_F, _ = mel.size()
            gate_targets = torch.ones((B, Max_F), dtype=torch.float32).to(device)
            for idx, length in enumerate(mel_len):
                gate_targets[idx, :length] = 0
                
            # Build mel_targets
            mel_targets = mel.transpose(1, 2) # (B, 80, Max_F)

            mel_outputs, postnet_mel_outputs, gate_outputs, alignments = model(text, text_len, mel)
            loss = Tacotron2Loss()(mel_outputs, postnet_mel_outputs, gate_outputs,
                                   mel_targets, gate_targets)

            model.zero_grad()
            loss.backward()
            grad_norm = nn.utils.clip_grad_norm_(model.parameters(), 1.0) # clip the gradient norm to 1.0
            
            optimizer.step()

            step += 1
            if step % 10 == 0:
                print(f'| epoch: {epoch} | step: {step} | loss: {loss:.5f} | grad_norm: {grad_norm:.5f} | {time.time()-start:.3f} sec / 10 steps')
                start = time.time() # reset the timer

            if step % checkpoint_step == 0:
                save_dir = './ckpt/' + model_name + '/1'
                input_seq = sequence_to_text(text[0].cpu().numpy())
                input_seq = input_seq[:text_len[0]]
                alignment_dir = os.path.join(save_dir, f'step-{step}-align.png')
                plot_alignment(alignments[0].detach().cpu().numpy(), alignment_dir, input_seq, step, loss)
                torch.save({
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'step': step,
                    'epoch': epoch
                }, os.path.join(save_dir, 'ckpt-{}.pt'.format(step)))
        epoch += 1
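
The per-item loop that fills `gate_targets` in `train` can also be written as a single broadcast comparison. The sketch below uses made-up lengths (`mel_len = [3, 5, 2]`, `Max_F = 5`) and checks that both versions produce the same tensor. It is an alternative formulation for illustration, not the code used for the logged runs.

```python
import torch

mel_len = torch.tensor([3, 5, 2])  # hypothetical frame counts for a batch of 3
Max_F = 5

# Loop version, mirroring the training loop: start from ones, zero out valid frames.
gate_loop = torch.ones((len(mel_len), Max_F), dtype=torch.float32)
for idx, length in enumerate(mel_len):
    gate_loop[idx, :int(length)] = 0

# Broadcast version: a frame is padding (target 1) when its index >= the item's length.
positions = torch.arange(Max_F).unsqueeze(0)            # (1, Max_F)
gate_vec = (positions >= mel_len.unsqueeze(1)).float()  # (B, Max_F)

print(torch.equal(gate_loop, gate_vec))
print(gate_vec.tolist())
```

With either version, the longest item in a batch (length equal to `Max_F`) gets an all-zero row, so it contributes no positive stop target to the gate loss.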

- model_01: Original
- model_02: Original + grad_clipping
- model_03: grad_clipping + Bahdanau Attention
- model_04: grad_clipping + Bahdanau Attention + concat added between the LSTMs
- model_05: grad_clipping + Original Bahdanau Attention + concat added between the LSTMs
- model_06: Original + grad_clipping + concat added between the LSTMs
- model_07: Original + grad_clipping + concat added between the LSTMs + mask
- model_08: model_07 + attention_weight_cum introduced

In [12]:
##### INPUT #####
model_name = 'model_08'
check_step = None

In [13]:
if __name__ == "__main__":
    train(model_name, check_step)

| epoch: 1 | step: 10 | loss: 2.37144 | grad_norm: 0.46209 | 38.918 sec / 10 steps
| epoch: 1 | step: 20 | loss: 2.05989 | grad_norm: 0.39781 | 36.900 sec / 10 steps
| epoch: 1 | step: 30 | loss: 1.99175 | grad_norm: 0.34701 | 36.949 sec / 10 steps
| epoch: 1 | step: 40 | loss: 1.94190 | grad_norm: 0.45892 | 33.225 sec / 10 steps
| epoch: 1 | step: 50 | loss: 1.90032 | grad_norm: 0.32706 | 34.105 sec / 10 steps
| epoch: 1 | step: 60 | loss: 1.87740 | grad_norm: 0.33003 | 36.536 sec / 10 steps
| epoch: 1 | step: 70 | loss: 1.83515 | grad_norm: 0.38787 | 33.756 sec / 10 steps
| epoch: 1 | step: 80 | loss: 1.82902 | grad_norm: 0.40811 | 33.813 sec / 10 steps
| epoch: 1 | step: 90 | loss: 1.76325 | grad_norm: 0.34648 | 37.398 sec / 10 steps
| epoch: 1 | step: 100 | loss: 1.77798 | grad_norm: 0.31312 | 29.694 sec / 10 steps
| epoch: 1 | step: 110 | loss: 1.76574 | grad_norm: 0.32165 | 36.047 sec / 10 steps
| epoch: 1 | step: 120 | loss: 1.72661 | grad_norm: 0.30567 | 31.410 sec / 10 steps
|

| epoch: 5 | step: 990 | loss: 0.67535 | grad_norm: 0.10304 | 21.156 sec / 10 steps
| epoch: 5 | step: 1000 | loss: 0.61594 | grad_norm: 0.08954 | 24.244 sec / 10 steps
| epoch: 6 | step: 1010 | loss: 0.59434 | grad_norm: 0.09058 | 24.316 sec / 10 steps
| epoch: 6 | step: 1020 | loss: 0.61787 | grad_norm: 0.08979 | 23.579 sec / 10 steps
| epoch: 6 | step: 1030 | loss: 0.61094 | grad_norm: 0.08396 | 21.669 sec / 10 steps
| epoch: 6 | step: 1040 | loss: 0.65682 | grad_norm: 0.09645 | 22.515 sec / 10 steps
| epoch: 6 | step: 1050 | loss: 0.61551 | grad_norm: 0.08696 | 22.988 sec / 10 steps
| epoch: 6 | step: 1060 | loss: 0.61816 | grad_norm: 0.08677 | 24.668 sec / 10 steps
| epoch: 6 | step: 1070 | loss: 0.56728 | grad_norm: 0.07957 | 23.149 sec / 10 steps
| epoch: 6 | step: 1080 | loss: 0.62159 | grad_norm: 0.09206 | 22.409 sec / 10 steps
| epoch: 6 | step: 1090 | loss: 0.57316 | grad_norm: 0.08522 | 23.548 sec / 10 steps
| epoch: 6 | step: 1100 | loss: 0.58501 | grad_norm: 0.07838 | 23.

| epoch: 10 | step: 1960 | loss: 0.55536 | grad_norm: 0.01548 | 24.907 sec / 10 steps
| epoch: 10 | step: 1970 | loss: 0.55350 | grad_norm: 0.01594 | 21.057 sec / 10 steps
| epoch: 10 | step: 1980 | loss: 0.47277 | grad_norm: 0.01543 | 23.884 sec / 10 steps
| epoch: 10 | step: 1990 | loss: 0.48830 | grad_norm: 0.01181 | 23.804 sec / 10 steps
| epoch: 10 | step: 2000 | loss: 0.54263 | grad_norm: 0.01474 | 21.755 sec / 10 steps
| epoch: 10 | step: 2010 | loss: 0.55376 | grad_norm: 0.01358 | 24.113 sec / 10 steps
| epoch: 11 | step: 2020 | loss: 0.53745 | grad_norm: 0.01238 | 22.368 sec / 10 steps
| epoch: 11 | step: 2030 | loss: 0.49920 | grad_norm: 0.03926 | 22.567 sec / 10 steps
| epoch: 11 | step: 2040 | loss: 0.52008 | grad_norm: 0.02205 | 21.376 sec / 10 steps
| epoch: 11 | step: 2050 | loss: 0.53413 | grad_norm: 0.01367 | 22.656 sec / 10 steps
| epoch: 11 | step: 2060 | loss: 0.48786 | grad_norm: 0.02947 | 22.722 sec / 10 steps
| epoch: 11 | step: 2070 | loss: 0.51888 | grad_norm: 

| epoch: 15 | step: 2920 | loss: 0.52232 | grad_norm: 0.01565 | 23.692 sec / 10 steps
| epoch: 15 | step: 2930 | loss: 0.56413 | grad_norm: 0.01240 | 22.561 sec / 10 steps
| epoch: 15 | step: 2940 | loss: 0.55803 | grad_norm: 0.01205 | 22.502 sec / 10 steps
| epoch: 15 | step: 2950 | loss: 0.49209 | grad_norm: 0.02219 | 23.024 sec / 10 steps
| epoch: 15 | step: 2960 | loss: 0.54749 | grad_norm: 0.01575 | 21.424 sec / 10 steps
| epoch: 15 | step: 2970 | loss: 0.52542 | grad_norm: 0.01077 | 20.262 sec / 10 steps
| epoch: 15 | step: 2980 | loss: 0.50672 | grad_norm: 0.02829 | 22.948 sec / 10 steps
| epoch: 15 | step: 2990 | loss: 0.53617 | grad_norm: 0.01204 | 21.838 sec / 10 steps
| epoch: 15 | step: 3000 | loss: 0.53428 | grad_norm: 0.01450 | 22.195 sec / 10 steps
| epoch: 15 | step: 3010 | loss: 0.51497 | grad_norm: 0.01795 | 23.233 sec / 10 steps
| epoch: 16 | step: 3020 | loss: 0.54122 | grad_norm: 0.01519 | 23.392 sec / 10 steps
| epoch: 16 | step: 3030 | loss: 0.51039 | grad_norm: 

| epoch: 20 | step: 3880 | loss: 0.51438 | grad_norm: 0.01031 | 22.398 sec / 10 steps
| epoch: 20 | step: 3890 | loss: 0.47764 | grad_norm: 0.01018 | 22.558 sec / 10 steps
| epoch: 20 | step: 3900 | loss: 0.51928 | grad_norm: 0.00914 | 22.974 sec / 10 steps
| epoch: 20 | step: 3910 | loss: 0.53172 | grad_norm: 0.01366 | 22.425 sec / 10 steps
| epoch: 20 | step: 3920 | loss: 0.55355 | grad_norm: 0.01669 | 21.195 sec / 10 steps
| epoch: 20 | step: 3930 | loss: 0.52487 | grad_norm: 0.01051 | 23.140 sec / 10 steps
| epoch: 20 | step: 3940 | loss: 0.52032 | grad_norm: 0.00948 | 23.351 sec / 10 steps
| epoch: 20 | step: 3950 | loss: 0.51706 | grad_norm: 0.01832 | 23.490 sec / 10 steps
| epoch: 20 | step: 3960 | loss: 0.49939 | grad_norm: 0.01057 | 22.852 sec / 10 steps
| epoch: 20 | step: 3970 | loss: 0.52859 | grad_norm: 0.01370 | 22.538 sec / 10 steps
| epoch: 20 | step: 3980 | loss: 0.55959 | grad_norm: 0.01401 | 22.694 sec / 10 steps
| epoch: 20 | step: 3990 | loss: 0.57496 | grad_norm: 

| epoch: 25 | step: 4840 | loss: 0.56195 | grad_norm: 0.01332 | 23.688 sec / 10 steps
| epoch: 25 | step: 4850 | loss: 0.55750 | grad_norm: 0.01143 | 20.193 sec / 10 steps
| epoch: 25 | step: 4860 | loss: 0.52369 | grad_norm: 0.00880 | 20.959 sec / 10 steps
| epoch: 25 | step: 4870 | loss: 0.51193 | grad_norm: 0.00789 | 21.483 sec / 10 steps
| epoch: 25 | step: 4880 | loss: 0.48941 | grad_norm: 0.00639 | 23.509 sec / 10 steps
| epoch: 25 | step: 4890 | loss: 0.48568 | grad_norm: 0.04321 | 22.974 sec / 10 steps
| epoch: 25 | step: 4900 | loss: 0.52817 | grad_norm: 0.01644 | 22.565 sec / 10 steps
| epoch: 25 | step: 4910 | loss: 0.53542 | grad_norm: 0.01349 | 22.786 sec / 10 steps
| epoch: 25 | step: 4920 | loss: 0.50799 | grad_norm: 0.00925 | 21.134 sec / 10 steps
| epoch: 25 | step: 4930 | loss: 0.54353 | grad_norm: 0.01019 | 23.785 sec / 10 steps
| epoch: 25 | step: 4940 | loss: 0.56830 | grad_norm: 0.01974 | 22.677 sec / 10 steps
| epoch: 25 | step: 4950 | loss: 0.54913 | grad_norm: 

| epoch: 29 | step: 5800 | loss: 0.53664 | grad_norm: 0.00805 | 24.877 sec / 10 steps
| epoch: 29 | step: 5810 | loss: 0.52553 | grad_norm: 0.01134 | 22.743 sec / 10 steps
| epoch: 29 | step: 5820 | loss: 0.52190 | grad_norm: 0.00991 | 21.707 sec / 10 steps
| epoch: 30 | step: 5830 | loss: 0.52290 | grad_norm: 0.01574 | 23.221 sec / 10 steps
| epoch: 30 | step: 5840 | loss: 0.48436 | grad_norm: 0.01157 | 21.573 sec / 10 steps
| epoch: 30 | step: 5850 | loss: 0.51676 | grad_norm: 0.01667 | 23.108 sec / 10 steps
| epoch: 30 | step: 5860 | loss: 0.53110 | grad_norm: 0.00928 | 22.195 sec / 10 steps
| epoch: 30 | step: 5870 | loss: 0.52114 | grad_norm: 0.00906 | 23.905 sec / 10 steps
| epoch: 30 | step: 5880 | loss: 0.52654 | grad_norm: 0.00762 | 22.009 sec / 10 steps
| epoch: 30 | step: 5890 | loss: 0.51737 | grad_norm: 0.01019 | 24.567 sec / 10 steps
| epoch: 30 | step: 5900 | loss: 0.53797 | grad_norm: 0.01012 | 24.062 sec / 10 steps
| epoch: 30 | step: 5910 | loss: 0.47709 | grad_norm: 

| epoch: 34 | step: 6760 | loss: 0.50389 | grad_norm: 0.01761 | 21.905 sec / 10 steps
| epoch: 34 | step: 6770 | loss: 0.53783 | grad_norm: 0.01044 | 21.990 sec / 10 steps
| epoch: 34 | step: 6780 | loss: 0.52445 | grad_norm: 0.01181 | 22.826 sec / 10 steps
| epoch: 34 | step: 6790 | loss: 0.53315 | grad_norm: 0.04068 | 22.888 sec / 10 steps
| epoch: 34 | step: 6800 | loss: 0.51843 | grad_norm: 0.01003 | 23.190 sec / 10 steps
| epoch: 34 | step: 6810 | loss: 0.50592 | grad_norm: 0.01101 | 22.936 sec / 10 steps
| epoch: 34 | step: 6820 | loss: 0.55361 | grad_norm: 0.01347 | 21.959 sec / 10 steps
| epoch: 34 | step: 6830 | loss: 0.54149 | grad_norm: 0.01543 | 22.558 sec / 10 steps
| epoch: 35 | step: 6840 | loss: 0.54720 | grad_norm: 0.01185 | 22.378 sec / 10 steps
| epoch: 35 | step: 6850 | loss: 0.51714 | grad_norm: 0.00823 | 23.561 sec / 10 steps
| epoch: 35 | step: 6860 | loss: 0.51192 | grad_norm: 0.00775 | 24.606 sec / 10 steps
| epoch: 35 | step: 6870 | loss: 0.49848 | grad_norm: 

| epoch: 39 | step: 7720 | loss: 0.54526 | grad_norm: 0.00654 | 20.615 sec / 10 steps
| epoch: 39 | step: 7730 | loss: 0.50727 | grad_norm: 0.00675 | 23.024 sec / 10 steps
| epoch: 39 | step: 7740 | loss: 0.55211 | grad_norm: 0.00747 | 22.727 sec / 10 steps
| epoch: 39 | step: 7750 | loss: 0.53446 | grad_norm: 0.00839 | 21.319 sec / 10 steps
| epoch: 39 | step: 7760 | loss: 0.55758 | grad_norm: 0.00889 | 22.563 sec / 10 steps
| epoch: 39 | step: 7770 | loss: 0.53756 | grad_norm: 0.01492 | 22.446 sec / 10 steps
| epoch: 39 | step: 7780 | loss: 0.53298 | grad_norm: 0.00565 | 21.593 sec / 10 steps
| epoch: 39 | step: 7790 | loss: 0.52643 | grad_norm: 0.02102 | 21.810 sec / 10 steps
| epoch: 39 | step: 7800 | loss: 0.52782 | grad_norm: 0.00672 | 23.767 sec / 10 steps
| epoch: 39 | step: 7810 | loss: 0.52069 | grad_norm: 0.00538 | 22.925 sec / 10 steps
| epoch: 39 | step: 7820 | loss: 0.53332 | grad_norm: 0.01525 | 23.205 sec / 10 steps
| epoch: 39 | step: 7830 | loss: 0.48401 | grad_norm: 

| epoch: 44 | step: 8680 | loss: 0.48831 | grad_norm: 0.00618 | 22.542 sec / 10 steps
| epoch: 44 | step: 8690 | loss: 0.53513 | grad_norm: 0.01005 | 22.491 sec / 10 steps
| epoch: 44 | step: 8700 | loss: 0.53265 | grad_norm: 0.01211 | 22.267 sec / 10 steps
| epoch: 44 | step: 8710 | loss: 0.52196 | grad_norm: 0.00736 | 22.682 sec / 10 steps
| epoch: 44 | step: 8720 | loss: 0.49571 | grad_norm: 0.00752 | 21.619 sec / 10 steps
| epoch: 44 | step: 8730 | loss: 0.47846 | grad_norm: 0.01010 | 24.251 sec / 10 steps
| epoch: 44 | step: 8740 | loss: 0.53399 | grad_norm: 0.01132 | 22.525 sec / 10 steps
| epoch: 44 | step: 8750 | loss: 0.56100 | grad_norm: 0.01183 | 22.589 sec / 10 steps
| epoch: 44 | step: 8760 | loss: 0.50377 | grad_norm: 0.00803 | 23.860 sec / 10 steps
| epoch: 44 | step: 8770 | loss: 0.50662 | grad_norm: 0.00828 | 21.499 sec / 10 steps
| epoch: 44 | step: 8780 | loss: 0.50683 | grad_norm: 0.01144 | 23.084 sec / 10 steps
| epoch: 44 | step: 8790 | loss: 0.48668 | grad_norm: 

| epoch: 48 | step: 9640 | loss: 0.55695 | grad_norm: 0.01686 | 22.623 sec / 10 steps
| epoch: 49 | step: 9650 | loss: 0.53402 | grad_norm: 0.02004 | 21.341 sec / 10 steps
| epoch: 49 | step: 9660 | loss: 0.54497 | grad_norm: 0.01021 | 21.329 sec / 10 steps
| epoch: 49 | step: 9670 | loss: 0.52710 | grad_norm: 0.00737 | 22.805 sec / 10 steps
| epoch: 49 | step: 9680 | loss: 0.50448 | grad_norm: 0.00707 | 22.671 sec / 10 steps
| epoch: 49 | step: 9690 | loss: 0.53663 | grad_norm: 0.00784 | 20.754 sec / 10 steps
| epoch: 49 | step: 9700 | loss: 0.49623 | grad_norm: 0.01164 | 22.800 sec / 10 steps
| epoch: 49 | step: 9710 | loss: 0.56086 | grad_norm: 0.00898 | 21.892 sec / 10 steps
| epoch: 49 | step: 9720 | loss: 0.51194 | grad_norm: 0.00791 | 21.871 sec / 10 steps
| epoch: 49 | step: 9730 | loss: 0.53719 | grad_norm: 0.01338 | 22.655 sec / 10 steps
| epoch: 49 | step: 9740 | loss: 0.54031 | grad_norm: 0.00609 | 22.740 sec / 10 steps
| epoch: 49 | step: 9750 | loss: 0.54762 | grad_norm: 

| epoch: 53 | step: 10590 | loss: 0.53268 | grad_norm: 0.00677 | 21.246 sec / 10 steps
| epoch: 53 | step: 10600 | loss: 0.48694 | grad_norm: 0.00803 | 23.231 sec / 10 steps
| epoch: 53 | step: 10610 | loss: 0.54540 | grad_norm: 0.00620 | 24.258 sec / 10 steps
| epoch: 53 | step: 10620 | loss: 0.52807 | grad_norm: 0.01014 | 21.232 sec / 10 steps
| epoch: 53 | step: 10630 | loss: 0.54936 | grad_norm: 0.00461 | 21.534 sec / 10 steps
| epoch: 53 | step: 10640 | loss: 0.50171 | grad_norm: 0.00798 | 22.870 sec / 10 steps
| epoch: 53 | step: 10650 | loss: 0.53957 | grad_norm: 0.00885 | 23.117 sec / 10 steps
| epoch: 54 | step: 10660 | loss: 0.52709 | grad_norm: 0.00772 | 22.770 sec / 10 steps
| epoch: 54 | step: 10670 | loss: 0.51363 | grad_norm: 0.01071 | 23.020 sec / 10 steps
| epoch: 54 | step: 10680 | loss: 0.49120 | grad_norm: 0.00813 | 23.558 sec / 10 steps
| epoch: 54 | step: 10690 | loss: 0.50341 | grad_norm: 0.01053 | 23.222 sec / 10 steps
| epoch: 54 | step: 10700 | loss: 0.52927 |

| epoch: 58 | step: 11540 | loss: 0.50124 | grad_norm: 0.00617 | 22.763 sec / 10 steps
| epoch: 58 | step: 11550 | loss: 0.55292 | grad_norm: 0.00742 | 23.522 sec / 10 steps
| epoch: 58 | step: 11560 | loss: 0.52078 | grad_norm: 0.00545 | 22.455 sec / 10 steps
| epoch: 58 | step: 11570 | loss: 0.51732 | grad_norm: 0.00552 | 22.938 sec / 10 steps
| epoch: 58 | step: 11580 | loss: 0.55266 | grad_norm: 0.00945 | 23.773 sec / 10 steps
| epoch: 58 | step: 11590 | loss: 0.52649 | grad_norm: 0.01191 | 23.891 sec / 10 steps
| epoch: 58 | step: 11600 | loss: 0.55534 | grad_norm: 0.01040 | 22.519 sec / 10 steps
| epoch: 58 | step: 11610 | loss: 0.52314 | grad_norm: 0.00815 | 24.196 sec / 10 steps
| epoch: 58 | step: 11620 | loss: 0.51512 | grad_norm: 0.00794 | 20.853 sec / 10 steps
| epoch: 58 | step: 11630 | loss: 0.48904 | grad_norm: 0.00673 | 22.522 sec / 10 steps
| epoch: 58 | step: 11640 | loss: 0.53396 | grad_norm: 0.00673 | 23.255 sec / 10 steps
| epoch: 58 | step: 11650 | loss: 0.54070 |

| epoch: 63 | step: 12490 | loss: 0.49519 | grad_norm: 0.00507 | 25.887 sec / 10 steps
| epoch: 63 | step: 12500 | loss: 0.52451 | grad_norm: 0.00865 | 23.388 sec / 10 steps
| epoch: 63 | step: 12510 | loss: 0.55305 | grad_norm: 0.00644 | 22.947 sec / 10 steps
| epoch: 63 | step: 12520 | loss: 0.49716 | grad_norm: 0.00757 | 23.005 sec / 10 steps
| epoch: 63 | step: 12530 | loss: 0.53088 | grad_norm: 0.00814 | 23.423 sec / 10 steps
| epoch: 63 | step: 12540 | loss: 0.54298 | grad_norm: 0.00923 | 24.497 sec / 10 steps
| epoch: 63 | step: 12550 | loss: 0.48551 | grad_norm: 0.00970 | 21.525 sec / 10 steps
| epoch: 63 | step: 12560 | loss: 0.51341 | grad_norm: 0.00557 | 22.409 sec / 10 steps
| epoch: 63 | step: 12570 | loss: 0.52248 | grad_norm: 0.00759 | 21.502 sec / 10 steps
| epoch: 63 | step: 12580 | loss: 0.50263 | grad_norm: 0.00461 | 22.474 sec / 10 steps
| epoch: 63 | step: 12590 | loss: 0.56218 | grad_norm: 0.00547 | 21.212 sec / 10 steps
| epoch: 63 | step: 12600 | loss: 0.54133 |

| epoch: 67 | step: 13440 | loss: 0.48571 | grad_norm: 0.00871 | 24.731 sec / 10 steps
| epoch: 67 | step: 13450 | loss: 0.49259 | grad_norm: 0.00543 | 22.283 sec / 10 steps
| epoch: 67 | step: 13460 | loss: 0.52916 | grad_norm: 0.00466 | 23.664 sec / 10 steps
| epoch: 68 | step: 13470 | loss: 0.49020 | grad_norm: 0.00603 | 22.492 sec / 10 steps
| epoch: 68 | step: 13480 | loss: 0.52480 | grad_norm: 0.00516 | 22.796 sec / 10 steps
| epoch: 68 | step: 13490 | loss: 0.52658 | grad_norm: 0.00465 | 23.202 sec / 10 steps
| epoch: 68 | step: 13500 | loss: 0.52627 | grad_norm: 0.00616 | 22.122 sec / 10 steps
| epoch: 68 | step: 13510 | loss: 0.51590 | grad_norm: 0.00566 | 21.173 sec / 10 steps
| epoch: 68 | step: 13520 | loss: 0.49775 | grad_norm: 0.00721 | 24.057 sec / 10 steps
| epoch: 68 | step: 13530 | loss: 0.55826 | grad_norm: 0.00524 | 21.845 sec / 10 steps
| epoch: 68 | step: 13540 | loss: 0.55832 | grad_norm: 0.00564 | 22.464 sec / 10 steps
| epoch: 68 | step: 13550 | loss: 0.49088 |

| epoch: 72 | step: 14390 | loss: 0.54100 | grad_norm: 0.00939 | 23.836 sec / 10 steps
| epoch: 72 | step: 14400 | loss: 0.52438 | grad_norm: 0.00883 | 22.967 sec / 10 steps
| epoch: 72 | step: 14410 | loss: 0.49748 | grad_norm: 0.00510 | 23.050 sec / 10 steps
| epoch: 72 | step: 14420 | loss: 0.54175 | grad_norm: 0.01016 | 24.932 sec / 10 steps
| epoch: 72 | step: 14430 | loss: 0.56062 | grad_norm: 0.00733 | 21.444 sec / 10 steps
| epoch: 72 | step: 14440 | loss: 0.54934 | grad_norm: 0.02298 | 22.359 sec / 10 steps
| epoch: 72 | step: 14450 | loss: 0.55636 | grad_norm: 0.00630 | 21.889 sec / 10 steps
| epoch: 72 | step: 14460 | loss: 0.51883 | grad_norm: 0.00926 | 23.525 sec / 10 steps
| epoch: 72 | step: 14470 | loss: 0.49432 | grad_norm: 0.00683 | 21.985 sec / 10 steps
| epoch: 73 | step: 14480 | loss: 0.52013 | grad_norm: 0.00844 | 23.536 sec / 10 steps
| epoch: 73 | step: 14490 | loss: 0.50509 | grad_norm: 0.01139 | 22.762 sec / 10 steps
| epoch: 73 | step: 14500 | loss: 0.53574 |

| epoch: 77 | step: 15340 | loss: 0.48095 | grad_norm: 0.00755 | 23.557 sec / 10 steps
| epoch: 77 | step: 15350 | loss: 0.51662 | grad_norm: 0.00825 | 23.313 sec / 10 steps
| epoch: 77 | step: 15360 | loss: 0.51219 | grad_norm: 0.01481 | 25.027 sec / 10 steps
| epoch: 77 | step: 15370 | loss: 0.52371 | grad_norm: 0.00627 | 22.571 sec / 10 steps
| epoch: 77 | step: 15380 | loss: 0.54328 | grad_norm: 0.00928 | 22.246 sec / 10 steps
| epoch: 77 | step: 15390 | loss: 0.50028 | grad_norm: 0.00633 | 23.802 sec / 10 steps
| epoch: 77 | step: 15400 | loss: 0.53439 | grad_norm: 0.01560 | 22.084 sec / 10 steps
| epoch: 77 | step: 15410 | loss: 0.47864 | grad_norm: 0.00666 | 23.495 sec / 10 steps
| epoch: 77 | step: 15420 | loss: 0.53715 | grad_norm: 0.00507 | 23.976 sec / 10 steps
| epoch: 77 | step: 15430 | loss: 0.45727 | grad_norm: 0.00380 | 23.166 sec / 10 steps
| epoch: 77 | step: 15440 | loss: 0.53411 | grad_norm: 0.00495 | 22.485 sec / 10 steps
| epoch: 77 | step: 15450 | loss: 0.53922 |

| epoch: 82 | step: 16290 | loss: 0.57686 | grad_norm: 0.00786 | 21.490 sec / 10 steps
| epoch: 82 | step: 16300 | loss: 0.45852 | grad_norm: 0.00657 | 23.760 sec / 10 steps
| epoch: 82 | step: 16310 | loss: 0.47924 | grad_norm: 0.00681 | 24.918 sec / 10 steps
| epoch: 82 | step: 16320 | loss: 0.51596 | grad_norm: 0.00881 | 23.873 sec / 10 steps
| epoch: 82 | step: 16330 | loss: 0.49979 | grad_norm: 0.00513 | 21.855 sec / 10 steps
| epoch: 82 | step: 16340 | loss: 0.50387 | grad_norm: 0.00835 | 22.418 sec / 10 steps
| epoch: 82 | step: 16350 | loss: 0.51918 | grad_norm: 0.00665 | 21.644 sec / 10 steps
| epoch: 82 | step: 16360 | loss: 0.52677 | grad_norm: 0.00668 | 22.021 sec / 10 steps
| epoch: 82 | step: 16370 | loss: 0.55177 | grad_norm: 0.00548 | 21.712 sec / 10 steps
| epoch: 82 | step: 16380 | loss: 0.50254 | grad_norm: 0.00474 | 24.036 sec / 10 steps
| epoch: 82 | step: 16390 | loss: 0.55023 | grad_norm: 0.00974 | 21.511 sec / 10 steps
| epoch: 82 | step: 16400 | loss: 0.50938 |

| epoch: 86 | step: 17240 | loss: 0.47624 | grad_norm: 0.00526 | 23.873 sec / 10 steps
| epoch: 86 | step: 17250 | loss: 0.52139 | grad_norm: 0.00612 | 23.844 sec / 10 steps
| epoch: 86 | step: 17260 | loss: 0.47978 | grad_norm: 0.00543 | 24.115 sec / 10 steps
| epoch: 86 | step: 17270 | loss: 0.53190 | grad_norm: 0.00658 | 21.417 sec / 10 steps
| epoch: 86 | step: 17280 | loss: 0.51314 | grad_norm: 0.00455 | 22.816 sec / 10 steps
| epoch: 87 | step: 17290 | loss: 0.48545 | grad_norm: 0.00544 | 23.383 sec / 10 steps
| epoch: 87 | step: 17300 | loss: 0.52809 | grad_norm: 0.00992 | 23.713 sec / 10 steps
| epoch: 87 | step: 17310 | loss: 0.52069 | grad_norm: 0.00824 | 24.537 sec / 10 steps
| epoch: 87 | step: 17320 | loss: 0.50930 | grad_norm: 0.00662 | 21.571 sec / 10 steps
| epoch: 87 | step: 17330 | loss: 0.53250 | grad_norm: 0.00996 | 21.636 sec / 10 steps
| epoch: 87 | step: 17340 | loss: 0.48517 | grad_norm: 0.00590 | 24.079 sec / 10 steps
| epoch: 87 | step: 17350 | loss: 0.53296 |

KeyboardInterrupt: 