In [None]:
import os
import math
import numpy as np
import pandas as pd

import torch
import torchaudio
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d


from scipy.io.wavfile import read
from IPython.display import Audio
from torch.utils.data import DataLoader

In [None]:
import sys
sys.path.append('/kaggle/input/vits-model')
import  monotonic_align, attentions
from common import slice_segments, sequence_mask, TextAudioCollate
from modules import WN, Flip, ResidualCouplingLayer, DurationPredictor
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch, spectrogram_torch

In [None]:
data_path = '/kaggle/input/the-lj-speech-dataset/LJSpeech-1.1'
wavs_path = f"{data_path}/wavs/"
metadata_path = f"{data_path}/metadata.csv"

# LJSpeech metadata is pipe-separated with no header row. quoting=3 is
# csv.QUOTE_NONE, so quote characters inside transcripts are kept verbatim.
# Only the file stem and the normalized transcription are kept.
metadata_df = (
    pd.read_csv(metadata_path, sep="|", header=None, quoting=3)
    .set_axis(["file_name", "transcription", "normalized_transcription"], axis=1)
    [["file_name", "normalized_transcription"]]
)
metadata_df.head()

In [None]:
# Text Encoding
_pad        = '_'
_punctuation = ';:,.!?¡¿—…"«»“”ü ()-' + "'" + "[]" 
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
symbols = [_pad] + list(_punctuation) + list(_letters)
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

def text_to_sequence(text):
    sequence = []
    for symbol in text:
        symbol_id = _symbol_to_id.get(symbol, None)
        if symbol_id is not None:
            sequence.append(symbol_id)
    return sequence

In [None]:
class TextAudioLoader(torch.utils.data.Dataset):
    """LJSpeech-style dataset returning ``[text_ids, spec, audio_norm]`` per item.

    ``text_ids`` is a LongTensor built with ``text_to_sequence``. ``spec`` is the
    linear spectrogram from ``spectrogram_torch`` with its leading dim squeezed.
    ``audio_norm`` is the waveform divided by 32768 with a leading channel dim
    added (mono wavs assumed).

    Args:
        audiopaths_and_text: DataFrame with a ``file_name`` column (wav stem,
            without the ``.wav`` extension) and a ``normalized_transcription``
            column.
        wavs_dir: directory containing the ``<file_name>.wav`` files. Defaults
            to the Kaggle LJSpeech location used elsewhere in this notebook.
    """
    def __init__(self, audiopaths_and_text,
                 wavs_dir='/kaggle/input/the-lj-speech-dataset/LJSpeech-1.1/wavs/'):
        self.audio_path = audiopaths_and_text['file_name'].values
        self.text = audiopaths_and_text['normalized_transcription'].values
        self.wavs_dir = wavs_dir
        self.sampling_rate = 22050

        # Divisor that maps integer PCM samples to [-1, 1); assumes 16-bit wavs.
        self.max_wav_value = 32768.0
        self.filter_length = 1024
        self.hop_length = 256
        self.win_length = 1024
        self._filter()

    def _filter(self):
        """Keep only items whose transcription is a string of 1..100 characters.

        Non-string entries (e.g. NaN from a malformed metadata row) are dropped
        instead of crashing on ``len``.
        """
        kept = [(path, text) for path, text in zip(self.audio_path, self.text)
                if isinstance(text, str) and 1 <= len(text) <= 100]
        self.audio_path = [path for path, _ in kept]
        self.text = [text for _, text in kept]

    def _wav_path(self, filename):
        """Full path of the wav file for stem ``filename``."""
        return os.path.join(self.wavs_dir, filename + '.wav')

    def get_audio(self, filename):
        """Load ``filename``'s wav and return ``(spec, audio_norm)``.

        Raises:
            ValueError: if the file's sampling rate differs from ``self.sampling_rate``.
        """
        sampling_rate, audio = read(self._wav_path(filename))
        if sampling_rate != self.sampling_rate:
            # The old format string had three placeholders but only two
            # arguments, so it raised IndexError instead of this ValueError.
            raise ValueError("{} {} SR doesn't match target {} SR".format(
                filename, sampling_rate, self.sampling_rate))

        audio = torch.FloatTensor(audio.astype(np.float32))
        audio_norm = audio / self.max_wav_value
        audio_norm = audio_norm.unsqueeze(0)
        spec = spectrogram_torch(audio_norm, self.filter_length,
            self.sampling_rate, self.hop_length, self.win_length,
            center=False)
        spec = torch.squeeze(spec, 0)
        return spec, audio_norm

    def get_text(self, text):
        """Encode ``text`` to a LongTensor of symbol ids."""
        text_norm = text_to_sequence(text)
        text_norm = torch.LongTensor(text_norm)
        return text_norm

    def __getitem__(self, index):
        text = self.get_text(self.text[index])
        spec, audio_norm = self.get_audio(self.audio_path[index])
        return [text, spec, audio_norm]

    def play_audio(self, index):
        """Return an IPython ``Audio`` player for item ``index`` (loaded with torchaudio)."""
        audio, sr = torchaudio.load(self._wav_path(self.audio_path[index]))
        return Audio(audio.numpy(), rate=sr)

    def __len__(self):
        return len(self.audio_path)
    

# Training dataset plus a shuffling loader; batching of the variable-length
# items is delegated to the project's TextAudioCollate (presumably pads them).
train_dataset = TextAudioLoader(metadata_df)
collate_fn = TextAudioCollate()
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    pin_memory=True,
    collate_fn=collate_fn,
)

In [None]:
class TextEncoder(nn.Module):
    """Text (prior) encoder.

    Embeds symbol ids (scaled by sqrt(hidden_channels)) and runs them through
    ``attentions.Encoder`` under a padding mask. It then projects every position
    to ``2 * out_channels`` channels, split into a mean ``m`` and a log-std ``logs``.

    forward(x, x_lengths) returns ``(hidden, m, logs, x_mask)``, where ``hidden``
    is [b, hidden_channels, t] and ``x_mask`` presumably [b, 1, t].
    """
    def __init__(self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout):
        super().__init__()

        # Hyper-parameters kept as attributes for later inspection.
        self.n_vocab, self.out_channels = n_vocab, out_channels
        self.hidden_channels, self.filter_channels = hidden_channels, filter_channels
        self.n_heads, self.n_layers = n_heads, n_layers
        self.kernel_size, self.p_dropout = kernel_size, p_dropout

        # Sub-modules are built in the same order (emb -> encoder -> proj) so
        # RNG-dependent weight initialisation is unchanged.
        self.emb = nn.Embedding(n_vocab, hidden_channels)
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
        self.proj = Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths):
        embedded = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        hidden = embedded.transpose(1, -1)                         # [b, h, t]
        x_mask = sequence_mask(x_lengths, hidden.size(2)).unsqueeze(1).to(hidden.dtype)
        hidden = self.encoder(hidden * x_mask, x_mask)

        m, logs = torch.split(self.proj(hidden) * x_mask, self.out_channels, dim=1)
        return hidden, m, logs, x_mask

In [None]:
class PosteriorEncoder(nn.Module):
    """Posterior encoder mapping a spectrogram to a sampled latent ``z``.

    A 1x1 conv lifts the input to ``hidden_channels`` and the project's ``WN``
    stack processes it under the length mask. A 1x1 conv then yields
    ``2 * out_channels`` channels, split into mean ``m`` and log-std ``logs``.
    ``z = m + eps * exp(logs)`` (reparameterisation), masked to valid frames.

    forward(x, x_lengths, g=None) returns ``(z, m, logs, x_mask)``.
    """
    def __init__(self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size=5,
        dilation_rate=1,
        n_layers=16,
        gin_channels=0):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size, self.dilation_rate = kernel_size, dilation_rate
        self.n_layers, self.gin_channels = n_layers, gin_channels

        # Creation order (pre -> proj -> enc) kept so weight init is unchanged.
        self.pre = Conv1d(in_channels, hidden_channels, 1)
        self.proj = Conv1d(hidden_channels, out_channels * 2, 1)
        self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)

    def forward(self, x, x_lengths, g=None):
        x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)
        hidden = self.enc(self.pre(x) * x_mask, x_mask, g=g)  # [b, h, spec_size]
        m, logs = torch.split(self.proj(hidden) * x_mask, self.out_channels, dim=1)

        # Reparameterisation trick: sample z from N(m, exp(logs)^2).
        eps = torch.randn_like(m)
        z = (m + eps * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask

In [None]:
def get_padding(kernel_size, dilation=1):
    """Padding that keeps the length unchanged for a stride-1 dilated conv with odd ``kernel_size``."""
    return (kernel_size - 1) * dilation // 2

def init_weights(m):
    """Re-initialise Conv1d weights with ``kaiming_normal_`` (nonlinearity='leaky_relu', default a=0).

    Other module types are left untouched, so this is safe to use with ``Module.apply``.
    """
    if isinstance(m, nn.Conv1d):
        nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')

class ResBlock(nn.Module):
    """HiFi-GAN style residual block.

    It has one branch per entry of ``dilation``. Each branch applies
    leaky_relu(0.2) -> dilated conv -> leaky_relu(0.2) -> conv, and its output is
    added back to the running signal. Convs use "same" padding, so the length is
    preserved.

    Args:
        channels: input/output channels.
        kernel_size: conv kernel size (odd, for exact "same" padding).
        dilation: dilations of the first conv of each branch. Any number of
            branches is supported; the default 3-tuple builds the same layers as
            before.
    """
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__()
        # Build order (all convs1, init, then all convs2, init) matches the
        # original, so RNG consumption is unchanged for a 3-entry dilation.
        self.convs1 = nn.ModuleList([
            Conv1d(channels, channels, kernel_size, dilation=d,
                   padding=get_padding(kernel_size, d))
            for d in dilation
        ])
        self.convs1.apply(init_weights)
        self.convs2 = nn.ModuleList([
            Conv1d(channels, channels, kernel_size, dilation=1,
                   padding=get_padding(kernel_size, 1))
            for _ in dilation
        ])
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply all residual branches; when ``x_mask`` is given, padded steps are zeroed before each conv and at the end."""
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, 0.2)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c1(xt)
            xt = F.leaky_relu(xt, 0.2)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c2(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

In [None]:
class Decoder(nn.Module):
    """HiFi-GAN style waveform decoder.

    Pipeline: conv_pre (k=7), then one stage per upsample rate. Each stage is
    leaky_relu(0.2) -> transposed conv (halving the channels) -> the average of
    ``len(resblock_kernel_sizes)`` ResBlocks. The stages are followed by
    leaky_relu -> conv_post (1 channel) -> tanh.

    Args:
        initial_channel: channels of the latent input.
        resblock_kernel_sizes: kernel size of each parallel ResBlock per stage.
        upsample_rates: stride of each transposed conv.
        upsample_initial_channel: channels after conv_pre (halved each stage).
        upsample_kernel_sizes: kernel size of each transposed conv.
        gin_channels: if non-zero, a 1x1 conv adds a conditioning input ``g``.
    """
    def __init__(self, initial_channel, resblock_kernel_sizes, upsample_rates,
                 upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
        super(Decoder, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)

        self.ups = nn.ModuleList()
        for stage, (rate, ksize) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            c_in = upsample_initial_channel // (2 ** stage)
            c_out = upsample_initial_channel // (2 ** (stage + 1))
            self.ups.append(ConvTranspose1d(c_in, c_out, ksize, rate, padding=(ksize - rate) // 2))

        # ResBlocks are stored flat: stage s owns indices [s*num_kernels, (s+1)*num_kernels).
        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            self.resblocks.extend([ResBlock(ch, ksize, [1, 3, 5]) for ksize in resblock_kernel_sizes])

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        if gin_channels != 0:
            self.cond = Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage in range(self.num_upsamples):
            x = self.ups[stage](F.leaky_relu(x, 0.2))
            first = stage * self.num_kernels
            # Average the parallel ResBlocks of this stage.
            x = sum(self.resblocks[first + j](x) for j in range(self.num_kernels)) / self.num_kernels

        x = F.leaky_relu(x)
        return torch.tanh(self.conv_post(x))

In [None]:
class ResidualCouplingBlock(nn.Module):
    """Normalizing flow built from ``n_flows`` (ResidualCouplingLayer, Flip) pairs.

    forward(x, x_mask, g=None, reverse=False):
      - ``reverse=False``: apply the layers in order. Each layer returns a pair;
        the second element (named ``logdet`` upstream) is discarded.
      - ``reverse=True``: apply the layers in reverse order, passing
        ``reverse=True`` to each one.
    Returns the transformed tensor.
    """
    def __init__(self,
        channels,
        hidden_channels,
        kernel_size=5,
        dilation_rate=1,
        n_layers=4,
        n_flows=4,
        gin_channels=0):
        super().__init__()
        self.channels, self.hidden_channels = channels, hidden_channels
        self.kernel_size, self.dilation_rate = kernel_size, dilation_rate
        self.n_layers, self.n_flows, self.gin_channels = n_layers, n_flows, gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = ResidualCouplingLayer(channels, hidden_channels, kernel_size,
                                             dilation_rate, n_layers, gin_channels=gin_channels)
            self.flows.append(coupling)
            self.flows.append(Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if reverse:
            for layer in reversed(self.flows):
                x = layer(x, x_mask, g=g, reverse=reverse)
            return x

        for layer in self.flows:
            x, _ = layer(x, x_mask, g=g, reverse=reverse)
        return x

In [None]:
class VITGenerator(nn.Module):
    """VITS-style generator (single speaker: the conditioning ``g`` is always None).

    Sub-modules:
        enc_p: TextEncoder giving the text-side prior (m_p, logs_p).
        enc_q: PosteriorEncoder mapping linear spectrograms to latent z.
        flow:  ResidualCouplingBlock mapping z -> z_p (and back with reverse=True).
        dp:    DurationPredictor giving log-durations per text token.
        dec:   Decoder turning latent frames into waveforms.

    Args:
        n_vocab: number of text symbols.
        inter_channels: latent channels.
        hidden_channels: hidden width of encoders / flow / duration predictor.
        filter_channels: feed-forward width inside the text encoder.
        spec_channels: linear-spectrogram bins fed to the posterior encoder.
    """
    def __init__(self, n_vocab, inter_channels, hidden_channels, filter_channels, spec_channels):
        super().__init__()
        self.enc_p = TextEncoder(n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads = 2,
            n_layers = 6,
            kernel_size = 3,
            p_dropout = 0.1)

        # Latent frames decoded per training step: 8192 samples / hop 256.
        self.seg_size = 8192 // 256
        self.dec = Decoder(inter_channels, [3,7,11], [8,8,2,2], 512, [16,16,4,4])
        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels)
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels)
        self.dp = DurationPredictor(hidden_channels)

    def forward(self, x, x_lengths, y, y_lengths):
        """Training pass.

        Args:
            x, x_lengths: padded text ids and their lengths.
            y, y_lengths: padded linear spectrograms and their frame lengths.

        Returns:
            ``(o, l_length, attn, ids_slice, x_mask, y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q))``. ``o`` is the decoded waveform
            of a ``seg_size``-frame slice of z, whose start indices are
            ``ids_slice`` (chosen by ``slice_segments``). ``l_length`` is the
            per-item duration loss. ``m_p`` / ``logs_p`` are already expanded to
            spectrogram frames via the alignment.
        """
        g = None
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        with torch.no_grad():
            # Log-likelihood of every z_p frame under every text token's Gaussian
            # prior, expanded as
            # -0.5*log(2*pi) - logs - 0.5*z^2/s^2 + z*m/s^2 - 0.5*m^2/s^2.
            s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
            neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True)
            neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r)
            neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r))
            neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True)
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            # Monotonic alignment search restricted to valid (text, frame) pairs.
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()

        # Target log-durations: number of frames aligned to each token.
        w = attn.sum(2)
        logw_ = torch.log(w + 1e-6) * x_mask
        logw = self.dp(x, x_mask, g=g)
        l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging

        # expand prior to spectrogram frames using the alignment
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
        z_slice, ids_slice = slice_segments(z, x_lengths = y_lengths, segment_size = self.seg_size)
        o = self.dec(z_slice, g=g)
        return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    def infer(self, x, x_lengths, max_len=None, noise_scale=1.0, length_scale=1.0):
        """Synthesise waveforms from text.

        Args:
            x, x_lengths: text ids and their lengths.
            max_len: optional cap on the number of latent frames decoded.
            noise_scale: multiplier on the prior std when sampling z_p
                (1.0 reproduces the previous behaviour; smaller values reduce variability).
            length_scale: multiplier on the predicted durations (1.0 reproduces
                the previous behaviour; values > 1 give more frames per token,
                i.e. slower speech).

        Returns:
            ``(o, attn, y_mask, (z, z_p, m_p, logs_p))``.
        """
        g = None
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        logw = self.dp(x, x_mask, g=g)

        # Predicted frame counts per token, rounded up; at least one frame overall.
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = monotonic_align.generate_path(w_ceil, attn_mask)
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:,:,:max_len], g=g)
        return o, attn, y_mask, (z, z_p, m_p, logs_p)

In [None]:
class HiFiGANDiscriminator(nn.Module):
    """HiFi-GAN period discriminator.

    The waveform ``[b, c, t]`` is reflect-padded on the right to a multiple of
    ``period``. It is then folded into ``[b, c, t // period, period]`` and
    processed by (k, 1) convolutions, so each column only sees samples spaced
    ``period`` apart.

    forward(x) returns ``(logits, fmap)``. ``logits`` is the flattened output
    of ``conv_post``; ``fmap`` lists every intermediate activation, used for
    the feature-matching loss.
    """
    def __init__(self, period):
        super(HiFiGANDiscriminator, self).__init__()
        self.period = period
        self.convs = nn.ModuleList([
            nn.Conv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
            nn.Conv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
            nn.Conv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
            nn.Conv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
            nn.Conv2d(1024, 1024, (5, 1), 1, padding=(2, 0))
        ])
        self.conv_post = nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        fmap = []
        b, c, t = x.shape
        if t % self.period != 0:
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), 'reflect')
            t = t + n_pad

        x = x.view(b, c, t // self.period, self.period)
        for l in self.convs:
            x = self.leaky_relu(l(x))
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap

class MultiScaleDiscriminator(nn.Module):
    """Runs one HiFiGANDiscriminator per period on real and generated audio.

    Despite its name this is a multi-*period* discriminator (HiFi-GAN's MPD).
    HiFi-GAN itself uses periods (2, 3, 5, 7, 11); the default here keeps the
    original (2, 3, 5).

    Args:
        periods: one sub-discriminator is created per entry, in order.

    forward(y, y_hat) returns ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``: per-period
    logits and feature maps for the real (``y``) and generated (``y_hat``) audio.
    """
    def __init__(self, periods=(2, 3, 5)):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([HiFiGANDiscriminator(p) for p in periods])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

In [None]:
def feature_loss(fmap_r, fmap_g):
    """Feature-matching loss.

    Returns twice the sum, over all discriminators and layers, of the mean
    absolute difference between real and generated feature maps. Real maps are
    detached so gradients only reach the generator.
    """
    total = sum(
        torch.mean(torch.abs(real.float().detach() - fake.float()))
        for disc_real, disc_fake in zip(fmap_r, fmap_g)
        for real, fake in zip(disc_real, disc_fake)
    )
    return total * 2

def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares GAN discriminator loss.

    Returns:
        (total, real_terms, fake_terms): ``total`` sums ``mean((1 - D(real))^2)
        + mean(D(fake)^2)`` over discriminators. The two lists hold each term
        as a Python float.
    """
    total = 0
    real_terms, fake_terms = [], []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = torch.mean((1 - real_out.float()) ** 2)
        fake_term = torch.mean(fake_out.float() ** 2)
        total += real_term + fake_term
        real_terms.append(real_term.item())
        fake_terms.append(fake_term.item())
    return total, real_terms, fake_terms

def generator_loss(disc_outputs):
    """Least-squares GAN generator loss: sum over discriminators of ``mean((1 - D(fake))^2)``.

    Returns:
        (total, per_disc) where ``per_disc`` is the list of tensor terms.
    """
    per_disc = [torch.mean((1 - out.float()) ** 2) for out in disc_outputs]
    return sum(per_disc), per_disc

def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """Masked KL term between the posterior sample and the prior N(m_p, exp(logs_p)^2).

    Per element: ``logs_p - logs_q - 0.5 + 0.5 * (z_p - m_p)^2 * exp(-2 * logs_p)``.
    It is summed over unmasked elements and divided by the mask sum.
    """
    z_p, logs_q, m_p, logs_p, z_mask = (
        t.float() for t in (z_p, logs_q, m_p, logs_p, z_mask))
    kl = (logs_p - logs_q - 0.5) + 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2. * logs_p)
    return torch.sum(kl * z_mask) / torch.sum(z_mask)

In [None]:
import time
import soundfile as sf
from tqdm import *
from torch.cuda.amp import autocast, GradScaler

def save_audio(y_hat, epochs, sample_rate=22050):
    """Write every waveform of a batch to ``audio_segment_{epochs}_{k}.wav``.

    Args:
        y_hat: tensor or ndarray indexed as ``[batch, channel, samples]``; only
            channel 0 is written.
        epochs: tag inserted into the file names.
        sample_rate: sample rate stored in the wav header.

    ``k`` is the 1-based position in the batch. Files go to the current working
    directory, and each saved name is printed.
    """
    waves = y_hat if isinstance(y_hat, np.ndarray) else y_hat.detach().cpu().numpy()
    for position, wave in enumerate(waves[:, 0, :], start=1):
        file_name = f"audio_segment_{epochs}_{position}.wav"
        sf.write(file_name, wave, sample_rate)
        print(f"Saved {file_name}")


In [None]:
class VitTrainer():
    """Adversarial trainer for a VITS-style generator and its discriminator.

    Each step draws one batch, then:
      1. Updates the discriminator on real vs. detached generated audio.
      2. Updates the generator with the adversarial, feature-matching, mel-L1
         (x45), duration and KL losses.
    Mixed precision uses ``autocast`` / ``GradScaler``. Every ``save_n_steps``
    steps a checkpoint is written to ``ckpt_dir`` and audio samples are saved.

    Args:
        g: generator exposing ``forward`` and ``infer`` like VITGenerator.
        d: discriminator called as ``d(y, y_hat)``.
        data_loader: iterable of ``(x, x_lengths, spec, spec_lengths, y, y_lengths)``.
        total_steps: last training step (inclusive).
        save_n_steps: checkpoint / sampling interval, in steps.
        ckpt_dir: directory receiving ``weight_tts_epoch{k}.pt`` files.
        LR: AdamW learning rate for both optimizers.
        device: torch device for models and batches.
        load_path: optional checkpoint to resume from.
    """
    def __init__(self, g, d, data_loader, total_steps, save_n_steps, \
                 ckpt_dir, LR, device='cuda', load_path = None):
        super().__init__()
        self.g = g.to(device)
        self.d = d.to(device)

        self.data_loader = data_loader
        # Persistent iterator over data_loader (see _next_batch).
        self._data_iter = None
        self.optim_g = torch.optim.AdamW(self.g.parameters(), LR, betas=[0.8, 0.99], eps=1e-9)
        self.optim_d = torch.optim.AdamW(self.d.parameters(), LR, betas=[0.8, 0.99], eps=1e-9)
        self.scaler = GradScaler(enabled=True)

        self.step = 1
        self.ckpt_dir = ckpt_dir
        self.device = device
        self.total_steps = total_steps
        self.save_n_steps = save_n_steps

        if load_path is not None:
            self.load_state_dict(load_path)
            print("sucessful load state dict !!!!!!")
            print(f"start from step {self.step}")

    def state_dict(self, step):
        """Checkpoint dict: last completed ``step``, model weights, optimizer and scaler states."""
        return {
            "step": step,
            'g': self.g.state_dict(),
            'd': self.d.state_dict(),
            'optim_g': self.optim_g.state_dict(),
            'optim_d': self.optim_d.state_dict(),
            'scaler': self.scaler.state_dict(),
        }

    def load_state_dict(self, path):
        """Restore a checkpoint written by ``state_dict`` and resume at the next step.

        Checkpoints from older versions that only contain ``step``/``g``/``d``
        are still accepted; optimizer and scaler state are then left as-is.
        """
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(path, map_location=self.device)
        self.g.load_state_dict(state_dict['g'])
        self.d.load_state_dict(state_dict['d'])
        if 'optim_g' in state_dict:
            self.optim_g.load_state_dict(state_dict['optim_g'])
        if 'optim_d' in state_dict:
            self.optim_d.load_state_dict(state_dict['optim_d'])
        # An empty scaler state (saved from a disabled GradScaler) cannot be loaded.
        if state_dict.get('scaler'):
            self.scaler.load_state_dict(state_dict['scaler'])
        # The stored step has already been completed; continue with the next one.
        self.step = state_dict['step'] + 1

    def _next_batch(self):
        """Return the next batch from a persistent iterator.

        The loader is restarted (re-shuffled when it shuffles) only once it is
        exhausted, instead of rebuilding an iterator for every step.
        """
        if getattr(self, '_data_iter', None) is None:
            self._data_iter = iter(self.data_loader)
        try:
            return next(self._data_iter)
        except StopIteration:
            self._data_iter = iter(self.data_loader)
            return next(self._data_iter)

    def train(self):
        """Run steps ``self.step .. total_steps`` (inclusive)."""
        start = time.time()
        print(f'Start of step {self.step}')

        for step in tqdm(range(self.step, self.total_steps+1), desc=f"Training progress"):
            x, x_lengths, spec, spec_lengths, y, y_lengths = self._next_batch()
            x, x_lengths = x.to(self.device), x_lengths.to(self.device)
            spec, spec_lengths = spec.to(self.device), spec_lengths.to(self.device)
            y, y_lengths = y.to(self.device), y_lengths.to(self.device)

            with autocast(enabled=True):
                y_hat, l_length, attn, ids_slice, x_mask, z_mask,\
                (z, z_p, m_p, logs_p, m_q, logs_q) = self.g(x, x_lengths, spec, spec_lengths)
                # Target mel for the same latent slice that the decoder produced.
                mel = spec_to_mel_torch(spec, 1024, 80, 22050, 0.0, None)
                y_mel, _ = slice_segments(mel, ids_str = ids_slice, segment_size = 8192 // 256)
                y_hat_mel = mel_spectrogram_torch(y_hat.squeeze(1), 1024, 80, 22050, 256, 1024, 0.0, None)
                # Matching ground-truth waveform slice (frame index * hop 256).
                y, _ = slice_segments(y, ids_str = ids_slice * 256, segment_size = 8192)

                # Discriminator sees detached fake audio so its loss does not reach G.
                y_d_hat_r, y_d_hat_g, _, _ = self.d(y, y_hat.detach())
                with autocast(enabled=False):
                    loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
                    loss_disc_all = loss_disc

            self.optim_d.zero_grad()
            self.scaler.scale(loss_disc_all).backward()
            self.scaler.unscale_(self.optim_d)
            # max_norm=1e9 effectively disables clipping; the call just measures the norm.
            self.grad_norm = nn.utils.clip_grad_norm_(self.d.parameters(), 1e9)
            self.scaler.step(self.optim_d)

            with autocast(enabled=True):
                y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.d(y, y_hat)
                with autocast(enabled=False):
                    loss_dur = torch.sum(l_length.float())
                    loss_mel = F.l1_loss(y_mel, y_hat_mel) * 45
                    loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask)

                    loss_fm = feature_loss(fmap_r, fmap_g)
                    loss_gen, losses_gen = generator_loss(y_d_hat_g)
                    loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl

            self.optim_g.zero_grad()
            self.scaler.scale(loss_gen_all).backward()
            self.scaler.unscale_(self.optim_g)
            self.grad_norm = nn.utils.clip_grad_norm_(self.g.parameters(), 1e9)
            self.scaler.step(self.optim_g)
            self.scaler.update()

            if step % self.save_n_steps == 0:
                epoch = step // self.save_n_steps
                time_minutes = (time.time() - start) / 60
                torch.save(self.state_dict(step),
                           os.path.join(self.ckpt_dir, f"weight_tts_epoch{epoch}.pt"))

                print(f"epoch: {epoch}, d_loss: {loss_disc_all.data} ~~~~~~")
                print(f"epoch: {epoch}, g_loss: {loss_gen_all.data} ~~~~~~")
                print (f'Time taken for epoch {epoch} is {time_minutes:.3f} min\n')
                print(f"sucessful saving epoch {epoch} state dict !!!!!!!")
                start = time.time()
                self.generate(epoch)

    def get_input(self, text):
        """Encode raw ``text`` into a batch of one: ``(ids [1, T], lengths [1])`` on ``self.device``.

        Raises:
            ValueError: if ``text`` contains no known symbols.
        """
        seq = text_to_sequence(text)
        if not seq:
            raise ValueError(f"text contains no known symbols: {text!r}")
        text_len = torch.tensor(len(seq)).to(self.device).unsqueeze(0)
        ids = torch.LongTensor(seq).to(self.device).unsqueeze(0)
        return ids, text_len

    def generate(self, epochs, text=None):
        """Synthesise audio with ``g.infer`` (no gradients) and save it via ``save_audio``.

        Args:
            epochs: tag used in the output file names.
            text: optional raw text. When None, the first 4 utterances of a
                freshly drawn training batch are synthesised instead.
        """
        if text is not None:
            x, x_lengths = self.get_input(text)
        else:
            # Fresh iterator so sampling does not disturb the training stream.
            x, x_lengths, *_ = next(iter(self.data_loader))
            x, x_lengths = x[:4].to(self.device), x_lengths[:4].to(self.device)

        was_training = self.g.training
        self.g.eval()
        try:
            with torch.no_grad():
                y_hat, *_ = self.g.infer(x, x_lengths, max_len=1000)
            save_audio(y_hat, epochs)
        finally:
            self.g.train(was_training)

In [None]:
# Model / training configuration
n_vocab = len(symbols)
spec_channels = 1024 // 2 + 1  # linear-spectrogram bins for filter_length 1024 (n_fft // 2 + 1)
spec_size = 8192 // 256        # latent frames per training segment (8192 samples / hop 256)
hop_length = 256

inter_channels = 192
hidden_channels = 192
filter_channels = 768

total_iters = 15000
save_n_iters = 1500
LR = 2e-4
load_path = None
ckpt_dir = './model_weight/'
os.makedirs(ckpt_dir, exist_ok=True)

# Seed torch's global RNG before building the models so weight initialisation
# (and the shuffling / sampling that draw from the same RNG) is repeatable.
# GPU kernels may still introduce some nondeterminism.
SEED = 1234
torch.manual_seed(SEED)

d_model = MultiScaleDiscriminator()
g_model = VITGenerator(n_vocab, inter_channels, hidden_channels, filter_channels, spec_channels)
trainer = VitTrainer(g_model, d_model, train_loader, total_iters,
                     save_n_iters, ckpt_dir, LR, load_path=load_path)

In [None]:
# Run the training loop from trainer.step up to total_iters; checkpoints and
# audio samples are written every save_n_iters steps.
trainer.train()

In [None]:
# Synthesise samples: with inp_text=None the trainer uses texts from a training
# batch; set a string here to synthesise custom text instead.
inp_text = None
trainer.generate(epochs=-1, text=inp_text)