In [1]:
import numpy as np
import pandas as pd
import pydub
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import DataLoader, Dataset
from torchaudio.functional import spectrogram
from torchaudio.transforms import Spectrogram, MelScale, InverseMelScale, GriffinLim




In [2]:
class Hyperparams:
  """Global configuration: data paths, text/audio preprocessing, model sizes, training.

  All values are class attributes; the shared instance ``hp`` created below is what the
  rest of the notebook reads.
  """
  seed = 42

  csv_path = "LJSpeech-1.1/metadata.csv"
  wav_path = "LJSpeech-1.1/wavs"
  save_path = "params"
  log_path = "train_logs"

  save_name = "SimpleTransfromerTTS.pt"

  # Text transformations params.
  # Index 0 ('EOS') is the end-of-sequence id appended by text_to_seq.
  symbols = [
    'EOS', ' ', '!', ',', '-', '.',
    ';', '?', 'a', 'b', 'c', 'd', 'e', 'f',
    'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
    'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'à',
    'â', 'è', 'é', 'ê', 'ü', '’', '“', '”'
  ]

  # Sounds transformations params
  sr = 22050
  n_fft = 2048
  n_stft = int((n_fft//2) + 1)

  # NOTE(review): frame_shift / frame_length are not read by the visible code; the
  # STFT actually uses hop_length = 256 samples (~11.6 ms at 22.05 kHz) and
  # win_length = 1024 samples (~46.4 ms), both derived from n_fft.
  frame_shift = 0.0125 # seconds
  hop_length = int(n_fft/8.0)

  frame_length = 0.05 # seconds
  win_length = int(n_fft/2.0)

  mel_freq = 128
  max_mel_time = 1024  # size of TransformerTTS's positional-embedding table

  max_db = 100
  scale_db = 10
  ref = 4.0
  power = 2.0
  norm_db = 10
  ampl_multiplier = 10.0
  ampl_amin = 1e-10
  db_multiplier = 1.0
  ampl_ref = 1.0
  ampl_power = 1.0

  # Model params
  text_num_embeddings = 2*len(symbols)
  embedding_size = 256
  encoder_embedding_size = 512

  dim_feedforward = 1024
  postnet_embedding_size = 1024

  encoder_kernel_size = 3
  postnet_kernel_size = 5

  # Other
  batch_size = 32
  grad_clip = 1.0
  lr = 2.0 * 1e-4
  r_gate = 1.0

  step_print = 1000
  step_test = 8000
  step_save = 8000
  # Lower dB bound read by norm_mel_spec_db / denorm_mel_spec_db (as hp.min_level_db).
  min_level_db = -100
  # Legacy (misspelled) name kept so existing references keep working.
  min_label_db = min_level_db
hp = Hyperparams()


    

# Text To Sequence Conversion

In [3]:
# Lookup table: supported character -> integer id (its index in hp.symbols).
symbol_to_id = {
    s: i for i, s in enumerate(hp.symbols)
}


def text_to_seq(text):
    """Convert a string into a 1-D IntTensor of symbol ids terminated by the EOS id.

    The text is lower-cased first. Characters that are not in ``hp.symbols`` are
    skipped; previously the check tested the builtin ``id`` instead of the looked-up
    value, so unknown characters appended ``None`` and tensor construction failed.

    Args:
        text: input string.

    Returns:
        torch.IntTensor holding the ids of the kept characters followed by EOS.
    """
    seq = [symbol_to_id[ch] for ch in text.lower() if ch in symbol_to_id]
    seq.append(symbol_to_id['EOS'])

    return torch.IntTensor(seq)

In [4]:
print(text_to_seq("Hello, World"))

tensor([15, 12, 19, 19, 22,  3,  1, 30, 22, 25, 19, 11,  0], dtype=torch.int32)


# Mask For Sequence Length

In [5]:
def mask_from_seq_length(
    sequence_lengths: torch.Tensor,
    max_length: int
) -> torch.BoolTensor:
    """Build a boolean mask that is True at valid (non-padded) positions.

    Example: for ``sequence_lengths = [2, 2, 3]`` and ``max_length = 4`` the result is
    ``[[1, 1, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0]]`` (as booleans).

    Args:
        sequence_lengths: 1-D integer tensor with one length per batch element.
        max_length: number of columns of the mask (int or 0-d integer tensor).

    Returns:
        Bool tensor of shape (batch_size, max_length) on ``sequence_lengths``' device.
    """
    # (batch_size, max_length): every row holds 1, 2, ..., max_length
    ones = sequence_lengths.new_ones(sequence_lengths.size(0), max_length)
    range_tensor = ones.cumsum(dim=1)
    return sequence_lengths.unsqueeze(1) >= range_tensor


# Alias for callers that use the plural spelling.
mask_from_seq_lengths = mask_from_seq_length

# Mel Spectrogram Transformation

In [6]:
from torchaudio.functional import spectrogram
import torchaudio

In [7]:
import torchaudio
from torchaudio.functional import spectrogram


# STFT geometry shared by the forward spectrogram and the Griffin-Lim inversion.
_stft_kwargs = dict(
    n_fft=hp.n_fft,
    win_length=hp.win_length,
    hop_length=hp.hop_length,
)
# Mel filter-bank geometry shared by the forward and inverse mel transforms.
_mel_kwargs = dict(
    n_mels=hp.mel_freq,
    sample_rate=hp.sr,
    n_stft=hp.n_stft,
)

# waveform -> spectrogram raised to hp.power; left on the CPU (it is used inside TextMelDataset).
spec_transform = torchaudio.transforms.Spectrogram(power=hp.power, **_stft_kwargs)

# linear-frequency spectrogram -> mel spectrogram (CPU).
mel_scale_transform = torchaudio.transforms.MelScale(**_mel_kwargs)

# mel spectrogram -> linear-frequency estimate (GPU).
mel_inverse_transform = torchaudio.transforms.InverseMelScale(**_mel_kwargs).cuda()

# linear spectrogram -> waveform via Griffin-Lim phase reconstruction (GPU).
griffnlim_transform = torchaudio.transforms.GriffinLim(**_stft_kwargs).cuda()

del _stft_kwargs, _mel_kwargs


def _min_level_db():
  """Return the configured lower dB bound, accepting the legacy ``min_label_db`` spelling."""
  if hasattr(hp, "min_level_db"):
    return hp.min_level_db
  return hp.min_label_db


def norm_mel_spec_db(mel_spec):
  """Map a dB-scaled mel spectrogram to the model's normalized range.

  Computes ``(2*x - min_level_db) / (max_db/norm_db) - 1`` and clips the result to
  ``[-ref*norm_db, ref*norm_db]``. ``denorm_mel_spec_db`` inverts it (except for the clip).
  """
  mel_spec = ((2.0*mel_spec - _min_level_db()) / (hp.max_db/hp.norm_db)) - 1.0
  mel_spec = torch.clip(mel_spec, -hp.ref*hp.norm_db, hp.ref*hp.norm_db)
  return mel_spec


def denorm_mel_spec_db(mel_spec):
  """Invert ``norm_mel_spec_db`` (values clipped there cannot be recovered)."""
  mel_spec = (((1.0 + mel_spec) * (hp.max_db/hp.norm_db)) + _min_level_db()) / 2.0
  return mel_spec


def pow_to_db_mel_spec(mel_spec):
  """Convert a power mel spectrogram to a scaled decibel representation.

  Delegates to ``torchaudio.functional.amplitude_to_DB`` with the ``hp.ampl_multiplier``,
  ``hp.ampl_amin`` and ``hp.db_multiplier`` settings and ``top_db=hp.max_db``, then divides
  by ``hp.scale_db``.
  NOTE(review): ``db_to_power_mel_spec`` does not reference ``hp.db_multiplier``; confirm
  whether the round trip is meant to be exact.
  """
  db_settings = {
    "multiplier": hp.ampl_multiplier,
    "amin": hp.ampl_amin,
    "db_multiplier": hp.db_multiplier,
    "top_db": hp.max_db,
  }
  return torchaudio.functional.amplitude_to_DB(mel_spec, **db_settings) / hp.scale_db


def db_to_power_mel_spec(mel_spec):
  """Undo the ``hp.scale_db`` scaling and map decibels back to power.

  Uses ``torchaudio.functional.DB_to_amplitude`` with ``ref=hp.ampl_ref`` and
  ``power=hp.ampl_power``.
  """
  decibels = mel_spec * hp.scale_db
  return torchaudio.functional.DB_to_amplitude(
    decibels, ref=hp.ampl_ref, power=hp.ampl_power
  )


def convert_to_mel_spec(wav):
  """Turn a waveform tensor into a scaled-dB mel spectrogram.

  Pipeline: ``spec_transform`` -> ``mel_scale_transform`` -> ``pow_to_db_mel_spec``, then
  dimension 0 is squeezed (a mono ``torchaudio.load`` result becomes (mel_freq, frames),
  e.g. ``[128, 832]`` in the demo cell).
  """
  db_mel_spec = pow_to_db_mel_spec(mel_scale_transform(spec_transform(wav)))
  return db_mel_spec.squeeze(0)


def inverse_mel_spec_to_wav(mel_spec):
  """Reconstruct an approximate waveform from a scaled-dB mel spectrogram.

  Steps: ``db_to_power_mel_spec`` -> ``mel_inverse_transform`` (mel -> linear) ->
  ``griffnlim_transform``. Both inverse transforms were moved to the GPU, so ``mel_spec``
  must be a CUDA tensor.
  """
  linear_spec = mel_inverse_transform(db_to_power_mel_spec(mel_spec))
  return griffnlim_transform(linear_spec)

In [8]:
# Round-trip one LJSpeech clip: waveform -> mel spectrogram (CPU) -> Griffin-Lim waveform (GPU).
wav_path = "LJSpeech-1.1/wavs/LJ001-0001.wav" 
waveform, sample_rate = torchaudio.load(wav_path, normalize=True)
mel_spec = convert_to_mel_spec(waveform)
print("mel_spec:", mel_spec.shape)
# The inverse transforms live on the GPU, hence .cuda().
pseudo_wav = inverse_mel_spec_to_wav(mel_spec.cuda())
print("pseudo_wav:", pseudo_wav.shape)

mel_spec: torch.Size([128, 832])
pseudo_wav: torch.Size([212736])


# Dataset And DataLoader

In [9]:
class TextMelDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(text_ids, mel_spectrogram)`` pairs.

    ``df`` must provide a ``"wav"`` column (clip id; audio is read from
    ``{hp.wav_path}/{wav}.wav``) and a ``"text_norm"`` column. Featurized items are
    memoised in ``self.cache`` keyed by wav id.
    NOTE(review): with DataLoader ``num_workers > 0`` each worker process presumably keeps
    its own copy of this cache — confirm the memory budget.
    """

    def __init__(self, df):
        self.df = df
        self.cache = {}

    def get_item(self, row):
        """Featurize one row: ``(text_to_seq(text_norm), convert_to_mel_spec(waveform))``."""
        audio_file = f"{hp.wav_path}/{row['wav']}.wav"
        token_ids = text_to_seq(row["text_norm"])

        waveform, sample_rate = torchaudio.load(audio_file, normalize=True)
        assert sample_rate == hp.sr

        return (token_ids, convert_to_mel_spec(waveform))

    def __getitem__(self, index):
        row = self.df.iloc[index]
        key = row["wav"]

        cached = self.cache.get(key)
        if cached is not None:
            return cached

        cached = self.get_item(row)
        self.cache[key] = cached
        return cached

    def __len__(self):
        return len(self.df)


def text_mel_collate_fn(batch):
  """Pad a list of ``(text_ids, mel)`` pairs into batch tensors.

  Args:
    batch: list of ``(text, mel)`` from ``TextMelDataset``; ``text`` is 1-D and ``mel``
      has time as its last dimension.

  Returns:
    texts_padded: (N, S_max), right-padded with 0.
    text_lengths: (N,) int32.
    mels_padded: (N, T_max, FREQ) — padded with 0 along time, then transposed.
    mel_lengths: (N,) int32.
    stop_token_padded: (N, T_max) float; 1.0 at padded frames and in the last column.
  """
  texts = [text for text, _ in batch]
  mels = [mel for _, mel in batch]

  text_lengths = torch.tensor([t.shape[-1] for t in texts], dtype=torch.int32)
  mel_lengths = torch.tensor([m.shape[-1] for m in mels], dtype=torch.int32)
  text_length_max = text_lengths.max()
  mel_length_max = mel_lengths.max()

  texts_padded = torch.stack(
    [
      torch.nn.functional.pad(t, pad=[0, text_length_max - t.shape[-1]], value=0)
      for t in texts
    ],
    0
  )
  mels_padded = torch.stack(
    [
      torch.nn.functional.pad(m, pad=[0, mel_length_max - m.shape[-1]], value=0)
      for m in mels
    ],
    0
  ).transpose(1, 2)

  # Frames past each utterance's end are stop targets; the final column always is.
  valid_frames = mask_from_seq_length(mel_lengths, mel_length_max)
  stop_token_padded = (~valid_frames).float()
  stop_token_padded[:, -1] = 1.0

  return texts_padded, text_lengths, mels_padded, mel_lengths, stop_token_padded

# Create Dataset and DataLoader

In [10]:
# Dataset over every metadata row plus a shuffling loader, used for the shape check below.
# NOTE(review): TextMelDataset reads "wav" and "text_norm" columns; if hp.csv_path is the
# untouched LJSpeech metadata file it may need explicit sep/header/column names — confirm.
df = pd.read_csv(hp.csv_path)
dataset = TextMelDataset(df)
train_loader = DataLoader(
    dataset,
    batch_size=hp.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=True,
    collate_fn=text_mel_collate_fn,
)

In [11]:
def names_shape(names, shape):
    """Format paired dimension names and sizes, e.g. ``names_shape(["N", "S"], (2, 3))`` -> ``"(N=2, S=3)"``."""
    pairs = (f"{name}={size}" for name, size in zip(names, shape))
    return f"({', '.join(pairs)})"

# Print the named shapes of the first two collated batches as a sanity check.
for i, batch in enumerate(train_loader):
    text_padded, text_lengths, mel_padded, mel_lengths, stop_token_padded = batch
    print(f"=========batch {i}=========")
    for label, dims, value in (
        ("text_padded", ["N", "S"], text_padded),
        ("text_lengths", ["N"], text_lengths),
        ("mel_padded", ["N", "TIME", "FREQ"], mel_padded),
        ("mel_lengths", ["N"], mel_lengths),
        ("stop_token_padded", ["N", "TIME"], stop_token_padded),
    ):
        print(f"{label}:", names_shape(dims, value.shape))
    if i >= 1:
        break

In [11]:
class TTSLoss(torch.nn.Module):
    """Training objective: mel reconstruction MSE plus a weighted stop-token BCE.

    total = MSE(mel_out, mel_target) + MSE(mel_postnet_out, mel_target)
            + BCEWithLogits(stop_token_out, stop_token_target) * hp.r_gate
    Stop-token tensors are flattened to (-1, 1) before the BCE.
    """

    def __init__(self):
        super().__init__()
        self.mse_loss = torch.nn.MSELoss()
        self.bce_loss = torch.nn.BCEWithLogitsLoss()

    def forward(
        self,
        mel_postnet_out,
        mel_out,
        stop_token_out,
        mel_target,
        stop_token_target
    ):
        gate_target = stop_token_target.view(-1, 1)
        gate_logits = stop_token_out.view(-1, 1)

        reconstruction = (
            self.mse_loss(mel_out, mel_target)
            + self.mse_loss(mel_postnet_out, mel_target)
        )
        gate = self.bce_loss(gate_logits, gate_target) * hp.r_gate

        return reconstruction + gate

In [12]:
def write_mp3(x, f="audio.mp3", sr=hp.sr, normalized=True):
    """Encode a mono waveform tensor as MP3 through pydub.

    pydub interprets its raw ``data`` as PCM bytes of ``sample_width`` (2 -> int16), so
    the samples are converted to int16 bytes first; passing the float array directly
    produced garbage audio.

    Args:
        x: 1-D waveform tensor (any device; it is detached and copied to the CPU).
        f: output path or file object given to ``AudioSegment.export``.
        sr: sample rate in Hz.
        normalized: if True, ``x`` holds float samples in [-1, 1] (values outside are
            clipped) and is scaled to the int16 range; if False, ``x`` is assumed to be in
            int16 units already (it is still rounded and clipped to that range).
    """
    samples = x.detach().cpu().numpy()
    if normalized:
        samples = np.clip(samples, -1.0, 1.0) * 32767.0
    # Clip before the cast so out-of-range values saturate instead of wrapping around.
    pcm = np.clip(np.rint(samples), -32768, 32767).astype(np.int16)
    pydub.AudioSegment(
        pcm.tobytes(),
        frame_rate=sr,
        sample_width=2,
        channels=1
    ).export(f, format="mp3")

In [18]:
class EncoderBlock(nn.Module):
  """Pre-norm Transformer encoder layer (4-head self-attention + ReLU feed-forward).

    x = x + Dropout(SelfAttn(LayerNorm(x)))
    x = x + Dropout(Linear(Dropout(ReLU(Linear(LayerNorm(x))))))

  Attention uses ``batch_first=True``; features have size ``hp.embedding_size``.
  """

  def __init__(self):
    super(EncoderBlock, self).__init__()
    d_model = hp.embedding_size

    # Creation order is kept stable: it fixes parameter order and initialization.
    self.norm_1 = nn.LayerNorm(normalized_shape=d_model)
    self.attn = torch.nn.MultiheadAttention(
      embed_dim=d_model, num_heads=4, dropout=0.1, batch_first=True
    )
    self.dropout_1 = torch.nn.Dropout(0.1)

    self.norm_2 = nn.LayerNorm(normalized_shape=d_model)
    self.linear_1 = nn.Linear(d_model, hp.dim_feedforward)
    self.dropout_2 = torch.nn.Dropout(0.1)
    self.linear_2 = nn.Linear(hp.dim_feedforward, d_model)
    self.dropout_3 = torch.nn.Dropout(0.1)

  def forward(
    self,
    x,
    attn_mask = None,
    key_padding_mask = None
  ):
    # Self-attention sub-layer: normalize, attend, residual add.
    h = self.norm_1(x)
    h = self.attn(
      query=h, key=h, value=h,
      attn_mask=attn_mask,
      key_padding_mask=key_padding_mask
    )[0]
    x = x + self.dropout_1(h)

    # Feed-forward sub-layer: normalize, expand + ReLU, project back, residual add.
    h = self.linear_1(self.norm_2(x))
    h = self.dropout_2(F.relu(h))
    h = self.dropout_3(self.linear_2(h))
    return x + h


class DecoderBlock(nn.Module):
  """Post-norm Transformer decoder layer: masked self-attention, cross-attention, feed-forward.

  Each sub-layer is ``x = LayerNorm(x + Dropout(sublayer(x)))``. Attention uses 4 heads
  with ``batch_first=True``; features have size ``hp.embedding_size``.
  """

  def __init__(self):
    super(DecoderBlock, self).__init__()
    d_model = hp.embedding_size

    # Creation order is kept stable: it fixes parameter order and initialization.
    self.norm_1 = nn.LayerNorm(normalized_shape=d_model)
    self.self_attn = torch.nn.MultiheadAttention(
      embed_dim=d_model, num_heads=4, dropout=0.1, batch_first=True
    )
    self.dropout_1 = torch.nn.Dropout(0.1)

    self.norm_2 = nn.LayerNorm(normalized_shape=d_model)
    self.attn = torch.nn.MultiheadAttention(
      embed_dim=d_model, num_heads=4, dropout=0.1, batch_first=True
    )
    self.dropout_2 = torch.nn.Dropout(0.1)

    self.norm_3 = nn.LayerNorm(normalized_shape=d_model)
    self.linear_1 = nn.Linear(d_model, hp.dim_feedforward)
    self.dropout_3 = torch.nn.Dropout(0.1)
    self.linear_2 = nn.Linear(hp.dim_feedforward, d_model)
    self.dropout_4 = torch.nn.Dropout(0.1)

  def forward(
    self,
    x,
    memory,
    x_attn_mask = None,
    x_key_padding_mask = None,
    memory_attn_mask = None,
    memory_key_padding_mask = None
  ):
    # Masked self-attention over the decoder sequence.
    h = self.self_attn(
      query=x, key=x, value=x,
      attn_mask=x_attn_mask,
      key_padding_mask=x_key_padding_mask
    )[0]
    x = self.norm_1(x + self.dropout_1(h))

    # Cross-attention from decoder positions to the encoder memory.
    h = self.attn(
      query=x, key=memory, value=memory,
      attn_mask=memory_attn_mask,
      key_padding_mask=memory_key_padding_mask
    )[0]
    x = self.norm_2(x + self.dropout_2(h))

    # Position-wise feed-forward.
    h = self.dropout_3(F.relu(self.linear_1(x)))
    h = self.dropout_4(self.linear_2(h))
    return self.norm_3(x + h)


class EncoderPreNet(nn.Module):
  """Text front-end: embedding -> linear -> 3 x (Conv1d, BatchNorm, ReLU, Dropout 0.5) -> linear.

  Maps token ids (N, S) to features (N, S, hp.embedding_size).
  """

  def __init__(self):
    super(EncoderPreNet, self).__init__()
    width = hp.encoder_embedding_size

    def conv_layer():
      # Stride-1 convolution whose padding keeps the sequence length for odd kernels.
      return nn.Conv1d(
        width,
        width,
        kernel_size=hp.encoder_kernel_size,
        stride=1,
        padding=int((hp.encoder_kernel_size - 1) / 2),
        dilation=1
      )

    # Creation order is kept stable: it fixes parameter order and initialization.
    self.embedding = nn.Embedding(
      num_embeddings=hp.text_num_embeddings,
      embedding_dim=width
    )
    self.linear_1 = nn.Linear(width, width)
    self.linear_2 = nn.Linear(width, hp.embedding_size)

    self.conv_1 = conv_layer()
    self.bn_1 = nn.BatchNorm1d(width)
    self.dropout_1 = torch.nn.Dropout(0.5)

    self.conv_2 = conv_layer()
    self.bn_2 = nn.BatchNorm1d(width)
    self.dropout_2 = torch.nn.Dropout(0.5)

    self.conv_3 = conv_layer()
    self.bn_3 = nn.BatchNorm1d(width)
    self.dropout_3 = torch.nn.Dropout(0.5)

  def forward(self, text):
    x = self.linear_1(self.embedding(text))  # (N, S, E_enc)
    x = x.transpose(2, 1)                     # (N, E_enc, S) for Conv1d
    for conv, bn, drop in (
      (self.conv_1, self.bn_1, self.dropout_1),
      (self.conv_2, self.bn_2, self.dropout_2),
      (self.conv_3, self.bn_3, self.dropout_3),
    ):
      x = drop(F.relu(bn(conv(x))))
    x = x.transpose(1, 2)                     # (N, S, E_enc)
    return self.linear_2(x)                   # (N, S, E)


class PostNet(nn.Module):
  """Mel refiner: 5 x (Conv1d, BatchNorm, tanh, Dropout 0.5), then (Conv1d, BatchNorm, Dropout 0.5).

  Input and output are (N, TIME, hp.mel_freq); TransformerTTS adds the output to its input.
  Layers are registered as conv_k / bn_k / dropout_k for k = 1..6.
  """

  def __init__(self):
    super(PostNet, self).__init__()

    def conv_layer(in_channels, out_channels):
      # Stride-1 convolution whose padding keeps the time length for odd kernels.
      return nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=hp.postnet_kernel_size,
        stride=1,
        padding=int((hp.postnet_kernel_size - 1) / 2),
        dilation=1
      )

    hidden = hp.postnet_embedding_size
    # Channel widths at each layer boundary: mel -> hidden x5 -> mel.
    widths = [hp.mel_freq] + [hidden] * 5 + [hp.mel_freq]
    for idx in range(1, 7):
      # Registration order conv_k, bn_k, dropout_k matches the original explicit layout.
      setattr(self, f"conv_{idx}", conv_layer(widths[idx - 1], widths[idx]))
      setattr(self, f"bn_{idx}", nn.BatchNorm1d(widths[idx]))
      setattr(self, f"dropout_{idx}", torch.nn.Dropout(0.5))

  def forward(self, x):
    # (N, TIME, FREQ) -> (N, FREQ, TIME) for Conv1d
    x = x.transpose(2, 1)
    for idx in range(1, 7):
      x = getattr(self, f"bn_{idx}")(getattr(self, f"conv_{idx}")(x))
      if idx < 6:  # the final layer has no tanh
        x = torch.tanh(x)
      x = getattr(self, f"dropout_{idx}")(x)
    return x.transpose(1, 2)  # back to (N, TIME, FREQ)


class DecoderPreNet(nn.Module):
  """Two (Linear, ReLU, Dropout p=0.5) layers mapping hp.mel_freq -> hp.embedding_size.

  Dropout is called with ``training=True``, so it stays active even when the module is
  in eval mode.
  """

  def __init__(self):
    super(DecoderPreNet, self).__init__()
    self.linear_1 = nn.Linear(hp.mel_freq, hp.embedding_size)
    self.linear_2 = nn.Linear(hp.embedding_size, hp.embedding_size)

  def forward(self, x):
    for layer in (self.linear_1, self.linear_2):
      x = F.dropout(F.relu(layer(x)), p=0.5, training=True)
    return x


class TransformerTTS(nn.Module):
  """Transformer text-to-mel model.

  Three pre-norm encoder blocks, three post-norm decoder blocks, a linear mel
  projection refined by a residual PostNet, and a linear stop-token head.
  ``forward`` is the teacher-forced pass used in training; ``inference`` generates
  frames autoregressively.
  """

  def __init__(self, device="cuda"):
    # NOTE(review): ``device`` is accepted but unused; callers move the model with .cuda().
    super(TransformerTTS, self).__init__()

    self.encoder_prenet = EncoderPreNet()
    self.decoder_prenet = DecoderPreNet()
    self.postnet = PostNet()

    # Learned positions shared by encoder and decoder; the table size bounds both the
    # text length and the number of mel frames in one forward pass.
    self.pos_encoding = nn.Embedding(
        num_embeddings=hp.max_mel_time,
        embedding_dim=hp.embedding_size
    )

    self.encoder_block_1 = EncoderBlock()
    self.encoder_block_2 = EncoderBlock()
    self.encoder_block_3 = EncoderBlock()

    self.decoder_block_1 = DecoderBlock()
    self.decoder_block_2 = DecoderBlock()
    self.decoder_block_3 = DecoderBlock()

    self.linear_1 = nn.Linear(hp.embedding_size, hp.mel_freq)
    self.linear_2 = nn.Linear(hp.embedding_size, 1)

    self.norm_memory = nn.LayerNorm(
      normalized_shape=hp.embedding_size
    )

  @staticmethod
  def _padding_mask(lengths, batch_size, max_length, device):
    """Float (batch_size, max_length) mask: 0 at valid positions, -inf at padding."""
    return torch.zeros(
      (batch_size, max_length),
      device=device
    ).masked_fill(
      ~mask_from_seq_length(lengths, max_length=max_length),
      float("-inf")
    )

  @staticmethod
  def _upper_triangular_mask(rows, cols, device):
    """Float (rows, cols) mask: 0 on/below the main diagonal, -inf strictly above it."""
    blocked = torch.triu(
      torch.full((rows, cols), True, device=device, dtype=torch.bool),
      diagonal=1
    )
    return torch.zeros((rows, cols), device=device).masked_fill(blocked, float("-inf"))

  def forward(
    self,
    text,
    text_len,
    mel,
    mel_len
  ):
    """Teacher-forced pass.

    Args:
      text: (N, S) symbol ids.
      text_len: (N,) valid text lengths.
      mel: (N, TIME, hp.mel_freq) decoder input frames.
      mel_len: (N,) valid frame counts.

    Returns:
      (mel_postnet, mel_linear, stop_token): (N, TIME, FREQ), (N, TIME, FREQ), (N, TIME).
      Padded frames are zeroed in the mel outputs and set to 1e3 in the stop logits.

    Raises:
      ValueError: if S or TIME exceeds hp.max_mel_time (the positional table size).
    """
    N = text.shape[0]
    S = text.shape[1]
    TIME = mel.shape[1]

    if S > hp.max_mel_time or TIME > hp.max_mel_time:
      raise ValueError(
        f"text length {S} / mel length {TIME} exceeds hp.max_mel_time={hp.max_mel_time}"
      )

    # Masks are kept as attributes (as before) so they can be inspected after a call.
    self.src_key_padding_mask = self._padding_mask(text_len, N, S, text.device)
    # NOTE(review): the encoder self-attention mask is causal, which is unusual for a
    # text encoder — confirm it is intended.
    self.src_mask = self._upper_triangular_mask(S, S, text.device)
    self.tgt_key_padding_mask = self._padding_mask(mel_len, N, TIME, mel.device)
    self.tgt_mask = self._upper_triangular_mask(TIME, TIME, mel.device)
    # Cross-attention: mel frame t may only attend to text positions 0..t.
    self.memory_mask = self._upper_triangular_mask(TIME, S, mel.device)

    pos_codes = self.pos_encoding(
      torch.arange(hp.max_mel_time).to(mel.device)
    ) # (MAX_TIME, E)

    text_x = self.encoder_prenet(text) # (N, S, E)
    text_x = text_x + pos_codes[:S]
    for block in (self.encoder_block_1, self.encoder_block_2, self.encoder_block_3):
      text_x = block(
        text_x,
        attn_mask=self.src_mask,
        key_padding_mask=self.src_key_padding_mask
      )
    text_x = self.norm_memory(text_x) # (N, S, E)

    mel_x = self.decoder_prenet(mel) # (N, TIME, E)
    mel_x = mel_x + pos_codes[:TIME]
    for block in (self.decoder_block_1, self.decoder_block_2, self.decoder_block_3):
      mel_x = block(
        x=mel_x,
        memory=text_x,
        x_attn_mask=self.tgt_mask,
        x_key_padding_mask=self.tgt_key_padding_mask,
        memory_attn_mask=self.memory_mask,
        memory_key_padding_mask=self.src_key_padding_mask
      ) # (N, TIME, E)

    mel_linear = self.linear_1(mel_x) # (N, TIME, FREQ)
    mel_postnet = mel_linear + self.postnet(mel_linear) # (N, TIME, FREQ)
    stop_token = self.linear_2(mel_x) # (N, TIME, 1)

    # True at padded frames; (N, TIME, 1) broadcasts over the FREQ axis.
    padded = self.tgt_key_padding_mask.ne(0).unsqueeze(-1)

    mel_linear = mel_linear.masked_fill(padded, 0)
    mel_postnet = mel_postnet.masked_fill(padded, 0)
    # Large positive logit => "stop" at padded frames.
    stop_token = stop_token.masked_fill(padded, 1e3).squeeze(2)

    return mel_postnet, mel_linear, stop_token

  @torch.no_grad()
  def inference(self, text, max_length=800, stop_token_threshold=0.5, with_tqdm=True):
    """Autoregressively generate a mel spectrogram for one utterance.

    Starting from an all-zero frame, each step runs ``forward`` on the frames so far and
    appends the newest post-net frame, until sigmoid(stop logit) of the newest frame
    exceeds ``stop_token_threshold`` or the step budget is used up.

    Args:
      text: (1, S) symbol ids, e.g. ``text_to_seq(s).unsqueeze(0)``.
      max_length: maximum number of decoding steps (capped at ``hp.max_mel_time``).
      stop_token_threshold: probability above which generation stops.
      with_tqdm: accepted for call-site compatibility; no progress bar is shown.

    Returns:
      (mel_postnet, stop_token_outputs): the post-net output of the last step,
      (1, T, hp.mel_freq), and the collected stop logits, (1, T').
    """
    self.eval()
    device = text.device

    text_lengths = torch.tensor([text.shape[1]], device=device)
    mel_padded = torch.zeros((1, 1, hp.mel_freq), device=device)  # start-of-sequence frame
    mel_lengths = torch.tensor([1], device=device)
    stop_token_outputs = torch.zeros((1, 0), device=device)
    mel_postnet = mel_padded

    for _ in range(min(max_length, hp.max_mel_time)):
      mel_postnet, _, stop_token = self(text, text_lengths, mel_padded, mel_lengths)
      last_stop = stop_token[:, -1:]
      if torch.sigmoid(last_stop).item() > stop_token_threshold:
        break
      stop_token_outputs = torch.cat([stop_token_outputs, last_stop], dim=1)
      mel_padded = torch.cat([mel_padded, mel_postnet[:, -1:, :]], dim=1)
      mel_lengths = torch.tensor([mel_padded.shape[1]], device=device)

    return mel_postnet, stop_token_outputs

In [24]:
from sklearn.model_selection import train_test_split

# 90/10 split of the metadata rows into training and validation sets.
train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)

train_dataset = TextMelDataset(train_df)
val_dataset = TextMelDataset(val_df)

train_loader = DataLoader(
    train_dataset,
    batch_size=hp.batch_size,
    shuffle=True,
    collate_fn=text_mel_collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=hp.batch_size,
    shuffle=False,
    collate_fn=text_mel_collate_fn,
)

In [16]:
# Display the validation DataLoader object created in the previous cell.
val_loader

<torch.utils.data.dataloader.DataLoader at 0x1fed90df050>

In [25]:
# Interactive model / optimizer / loss for inspection. The training cell below builds
# its own; this one mirrors it (AdamW with the configured hp.lr) instead of a magic lr.
model = TransformerTTS().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=hp.lr)
criterion = TTSLoss().cuda()


In [26]:
# Print the module tree of the model created above.
print(model)

TransformerTTS(
  (encoder_prenet): EncoderPreNet(
    (embedding): Embedding(86, 512)
    (linear_1): Linear(in_features=512, out_features=512, bias=True)
    (linear_2): Linear(in_features=512, out_features=256, bias=True)
    (conv_1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    (bn_1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout_1): Dropout(p=0.5, inplace=False)
    (conv_2): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    (bn_2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout_2): Dropout(p=0.5, inplace=False)
    (conv_3): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    (bn_3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout_3): Dropout(p=0.5, inplace=False)
  )
  (decoder_prenet): DecoderPreNet(
    (linear_1): Linear(in_features=128, out_features=256, bias=True)
    (linear_2): Lin

In [29]:
from torch.utils.tensorboard import SummaryWriter
import os
import time
import matplotlib.pyplot as plt

def batch_process(batch):
  """Move a collated batch to the GPU and build the teacher-forcing decoder input.

  The decoder input is the target mel shifted right by one frame: an all-zero
  start-of-sequence frame is prepended and the last target frame dropped.

  Returns:
    (text_padded, text_lengths, mel_padded, mel_lengths, mel_input, stop_token_padded)
  """
  text_padded, text_lengths, mel_padded, mel_lengths, stop_token_padded = batch
  text_padded, text_lengths, mel_padded, stop_token_padded, mel_lengths = (
    t.cuda()
    for t in (text_padded, text_lengths, mel_padded, stop_token_padded, mel_lengths)
  )

  # (N, 1, FREQ) start-of-sequence frame.
  sos = torch.zeros((mel_padded.shape[0], 1, hp.mel_freq), device=mel_padded.device)
  mel_input = torch.cat([sos, mel_padded[:, :-1, :]], dim=1)

  return text_padded, text_lengths, mel_padded, mel_lengths, mel_input, stop_token_padded



def inference_utterance(model, text):
  """Synthesize ``text`` with ``model.inference``; return ``(audio, fig)``.

  ``audio`` is the Griffin-Lim waveform of the first utterance's post-net mel and
  ``fig`` is a matplotlib figure of that mel (transposed, assuming (1, TIME, FREQ) output).
  NOTE(review): a sigmoid can never exceed ``stop_token_threshold=1e5``, so the stop
  criterion presumably never fires — confirm that full-length generation is intended.
  """
  token_ids = text_to_seq(text).unsqueeze(0).cuda()
  postnet_mel, _stop_token = model.inference(
    token_ids,
    stop_token_threshold=1e5,
    with_tqdm=False
  )
  audio = inverse_mel_spec_to_wav(postnet_mel.detach()[0].T)

  fig, ax = plt.subplots(1, 1)
  ax.imshow(postnet_mel[0, :, :].detach().cpu().numpy().T)

  return audio, fig


def calculate_test_loss(model, test_loader):
  """Average the notebook-level ``criterion`` over every batch of ``test_loader``.

  Switches ``model`` to eval mode (and leaves it there) and runs without gradients.

  Returns:
    Mean per-batch loss as a float, or ``float("nan")`` when the loader yields no batch.
  """
  model.eval()
  loss_sum = 0.0
  num_batches = 0

  with torch.no_grad():
    for test_batch in test_loader:
      # Process this loader's batch, not any outer-scope training ``batch``.
      (
        test_text_padded,
        test_text_lengths,
        test_mel_padded,
        test_mel_lengths,
        test_mel_input,
        test_stop_token_padded,
      ) = batch_process(test_batch)

      test_post_mel_out, test_mel_out, test_stop_token_out = model(
        test_text_padded,
        test_text_lengths,
        test_mel_input,
        test_mel_lengths
      )
      test_loss = criterion(
        mel_postnet_out = test_post_mel_out,
        mel_out = test_mel_out,
        stop_token_out = test_stop_token_out,
        mel_target = test_mel_padded,
        stop_token_target = test_stop_token_padded
      )

      loss_sum += test_loss.item()
      num_batches += 1

  if num_batches == 0:
    return float("nan")
  return loss_sum / num_batches


if __name__ == "__main__":
  torch.manual_seed(hp.seed)

  df = pd.read_csv(hp.csv_path)
  # Small fixed held-out set (64 rows) for periodic test loss and audio logging.
  train_df, test_df = train_test_split(
    df,
    test_size=64,
    random_state=hp.seed
  )
  train_loader = torch.utils.data.DataLoader(
      TextMelDataset(train_df),
      num_workers=2,
      shuffle=True,
      sampler=None,
      batch_size=hp.batch_size,
      pin_memory=True,
      drop_last=True,
      collate_fn=text_mel_collate_fn
  )
  test_loader = torch.utils.data.DataLoader(
      TextMelDataset(test_df),
      num_workers=2,
      shuffle=True,
      sampler=None,
      batch_size=8,
      pin_memory=True,
      drop_last=True,
      collate_fn=text_mel_collate_fn
  )

  train_saved_path = f"{hp.save_path}/train_{hp.save_name}"
  test_saved_path = f"{hp.save_path}/test_{hp.save_name}"
  # torch.save does not create missing directories.
  os.makedirs(hp.save_path, exist_ok=True)

  print("train_saved_path:", train_saved_path)
  print("test_saved_path:", test_saved_path)

  logger = SummaryWriter(hp.log_path)
  criterion = TTSLoss().cuda()
  model = TransformerTTS().cuda()
  optimizer = torch.optim.AdamW(model.parameters(), lr=hp.lr)
  scaler = torch.cuda.amp.GradScaler()

  best_test_loss_mean = float("inf")
  best_train_loss_mean = float("inf")

  train_loss_mean = 0.0
  train_loss_steps = 0  # number of losses accumulated into train_loss_mean
  epoch = 0
  i = 0

  if os.path.isfile(train_saved_path):
    state = torch.load(train_saved_path)
    state_model = state["model"]
    state_optimizer = state["optimizer"]

    i = state["i"] + 1
    epoch = state.get("epoch", 0)
    best_test_loss_mean = state.get("test_loss", float("inf"))
    best_train_loss_mean = state.get("train_loss", float("inf"))

    model.load_state_dict(state_model)
    optimizer.load_state_dict(state_optimizer)
    # Older checkpoints carry no GradScaler state; keep the fresh scaler then.
    if "scaler" in state:
      scaler.load_state_dict(state["scaler"])

    print(f"Load: {i}; test_loss: {np.round(best_test_loss_mean, 5)}; train_loss: {np.round(best_train_loss_mean, 5)}")
  else:
    print("Start from zero!")


  start_time_sec = time.time()
  while True:
    for batch in train_loader:
      text_padded, \
      text_lengths, \
      mel_padded, \
      mel_lengths, \
      mel_input, \
      stop_token_padded = batch_process(batch)

      model.train(True)
      model.zero_grad()

      with torch.autocast(device_type='cuda', dtype=torch.float16):
        post_mel_out, mel_out, stop_token_out = model(
          text_padded,
          text_lengths,
          mel_input,
          mel_lengths
        )
        loss = criterion(
          mel_postnet_out = post_mel_out,
          mel_out = mel_out,
          stop_token_out = stop_token_out,
          mel_target = mel_padded,
          stop_token_target = stop_token_padded
        )

      scaler.scale(loss).backward()
      # Unscale first so clip_grad_norm_ operates on unscaled gradients.
      scaler.unscale_(optimizer)
      torch.nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip)
      scaler.step(optimizer)
      scaler.update()

      train_loss_mean += loss.item()
      train_loss_steps += 1

      if i != 0 and i % hp.step_print == 0:
        # Average over the steps actually accumulated since the last report.
        train_loss_mean = train_loss_mean / train_loss_steps
        logger.add_scalar("Loss/train_loss", train_loss_mean, global_step=i)

        if i % hp.step_test == 0:
          test_loss_mean = calculate_test_loss(model, test_loader)
          audio, fig = inference_utterance(model, "Hello, World.")

          logger.add_scalar("Loss/test_loss", test_loss_mean, global_step=i)
          logger.add_figure(f"Img/img_{i}", fig, global_step=i)
          logger.add_audio(f"Utterance/audio_{i}",audio, sample_rate=hp.sr, global_step=i)

          print(f"{epoch}-{i}) Test loss: {np.round(test_loss_mean, 5)}")

          if i % hp.step_save == 0:
            is_best_train = train_loss_mean < best_train_loss_mean
            is_best_test = test_loss_mean < best_test_loss_mean

            state = {
              "model": model.state_dict(),
              "optimizer": optimizer.state_dict(),
              "scaler": scaler.state_dict(),
              "i": i,
              "epoch": epoch,
              "test_loss": test_loss_mean,
              "train_loss": train_loss_mean
            }

            if is_best_train:
              print(f"{epoch}-{i}) Save best train")
              torch.save(state, train_saved_path)
              best_train_loss_mean = train_loss_mean

            if is_best_test:
              print(f"{epoch}-{i}) Save best test")
              torch.save(state, test_saved_path)
              best_test_loss_mean = test_loss_mean


        end_time_sec = time.time()
        time_sec = np.round(end_time_sec - start_time_sec, 3)
        start_time_sec = end_time_sec

        print(f"{epoch}-{i}) Train loss: {np.round(train_loss_mean, 5)}; Duration: {time_sec} sec.")
        train_loss_mean = 0.0
        train_loss_steps = 0

      i += 1
    epoch += 1


train_saved_path: params/train_SimpleTransfromerTTS.pt
test_saved_path: params/test_SimpleTransfromerTTS.pt
Start from zero!
