In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import random

from CookieTTS.utils.dataset.utils import load_wav_to_torch, load_filepaths_and_text

# 1 - Init Model

In [5]:
class ConditionalBatchNorm1d(nn.Module):
    """
    Conditional Batch Normalization
    https://github.com/yanggeng1995/GAN-TTS/blob/master/models/generator.py#L121-L144

    Normalises `inputs` with a non-affine BatchNorm1d, then applies a
    per-sample scale (gamma) and shift (beta) predicted from a conditioning
    vector by a spectrally-normalised linear layer.

    Args:
        num_features: number of channels of the tensor being normalised.
        z_channels:   size of the conditioning (noise) vector.
    """
    def __init__(self, num_features, z_channels=128):
        # The two-argument form of super() needs the instance as well;
        # `super(ConditionalBatchNorm1d)` alone never runs nn.Module.__init__,
        # and registering the sub-modules below then fails.
        super(ConditionalBatchNorm1d, self).__init__()
        self.num_features = num_features
        self.z_channels = z_channels
        self.batch_norm = nn.BatchNorm1d(num_features, affine=False)

        # One linear layer predicts gamma and beta together (hence `* 2`).
        self.layer = nn.utils.spectral_norm(nn.Linear(z_channels, num_features * 2))
        self.layer.weight.data.normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
        self.layer.bias.data.zero_()             # Initialise bias at 0

    def forward(self, inputs, noise):
        """
        Args:
            inputs: tensor to normalise, expected as [B, num_features, T].
            noise:  conditioning vector, [B, z_channels]. A trailing singleton
                    time axis ([B, z_channels, 1]) is accepted and dropped.

        Returns:
            Tensor with the same shape as `inputs`.
        """
        # The linear layer acts on the last dim, so a [B, z_channels, 1] tensor
        # has to be flattened to [B, z_channels] first.
        if noise.dim() == 3 and noise.size(-1) == 1:
            noise = noise.squeeze(-1)

        outputs = self.batch_norm(inputs)
        gamma, beta = self.layer(noise).chunk(2, dim=1)  # 2x [B, num_features]
        gamma = gamma.view(-1, self.num_features, 1)     # broadcast over T
        beta = beta.view(-1, self.num_features, 1)

        outputs = gamma * outputs + beta

        return outputs


class ResidualBlock(nn.Module):
    """
    One layer of a GBlock stack:
    (optional) conditional batch-norm -> activation -> (optional) upsample -> Conv1d.

    Args:
        in_dim:      input channels.
        out_dim:     output channels.
        z_dim:       size of the conditioning vector; when None no conditional
                     batch-norm is created and `z` is ignored in forward().
        dilation:    dilation of the convolution.
        kernel_size: kernel size of the convolution (length is preserved for
                     odd kernel sizes).
        act_func:    activation module applied before the convolution.
                     NOTE: the default instance is created once and shared by
                     every block that does not pass its own, and it runs
                     in-place, so without conditional batch-norm the tensor
                     handed to forward() is modified.
        bias:        whether the convolution has a bias term.
        scale:       integer upsampling factor applied before the convolution
                     (1 disables upsampling).
    """
    def __init__(self, in_dim, out_dim, z_dim=None, dilation=1, kernel_size=1, act_func=nn.LeakyReLU(negative_slope=0.1, inplace=True), bias=True, scale:int=1):
        super(ResidualBlock, self).__init__()
        self.z_dim = z_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.kernel_size = kernel_size
        # "Same" padding has to grow with the dilation, otherwise every dilated
        # convolution trims dilation*(kernel_size-1) - 2*padding steps from T.
        self.padding = dilation * (kernel_size - 1) // 2
        self.dilation = dilation
        self.bias = bias
        # `scale` / `bn` only exist as attributes when they are in use;
        # forward() checks for them with hasattr().
        if scale != 1:
            self.scale = scale

        if self.z_dim is not None:
            self.bn = ConditionalBatchNorm1d(self.in_dim, self.z_dim)
        self.act_func = act_func

        self.conv = nn.Conv1d(self.in_dim, self.out_dim, self.kernel_size, padding=self.padding, dilation=self.dilation, bias=bias)

    def forward(self, x, z=None):                                 # [B, in_dim, T]
        """
        Args:
            x: [B, in_dim, T]
            z: conditioning vector for the conditional batch-norm; only used
               when the block was built with a `z_dim`.

        Returns:
            [B, out_dim, scale*T]
        """
        if hasattr(self, 'bn'):
            x = self.bn(x, z)
        x = self.act_func(x)
        if hasattr(self, 'scale'):
            x = F.interpolate(x, scale_factor=self.scale) # [B, in_dim, T]   -> [B, in_dim, x*T]
        x = self.conv(x)                                  # [B, in_dim, x*T] -> [B, out_dim, x*T]
        return x                                          # [B, out_dim, x*T]


class GBlock(nn.Module):
    """
    Generator block: a stack of `len(dilations)` ResidualBlocks with two
    residual connections.

    The input is upsampled by `scale` and projected by a 1x1 convolution to
    form the first residual. That residual is added after resblock
    `res_block_id`; the sum then becomes the residual that is added to the
    output of the last resblock.

    Args:
        input_dim:         channels of the incoming features.
        output_dim:        channels produced by every resblock.
        z_dim:             size of the conditioning vector handed to each resblock.
        kernel_size:       kernel size of every resblock convolution.
        dilations:         one dilation per resblock (only read, never mutated).
        scale:             integer upsampling factor of the block.
        upsample_block_id: index of the resblock that performs the upsampling.
        res_block_id:      index of the resblock after which the first
                           residual is added.
    """
    def __init__(self, input_dim, output_dim, z_dim, kernel_size=3, dilations=[1,2,4,8], scale:int=1, upsample_block_id=0, res_block_id=1):
        super(GBlock, self).__init__()
        self.resblocks = nn.ModuleList()
        self.upsample_block_id = upsample_block_id
        self.res_block_id = res_block_id
        self.scale = scale

        for i, dilation in enumerate(dilations):
            # Only the first resblock sees `input_dim` channels.
            in_dim = input_dim if i == 0 else output_dim
            # Only one resblock in the stack upsamples.
            scale_f = scale if i == self.upsample_block_id else int(1)
            resblock = ResidualBlock(in_dim, output_dim, z_dim=z_dim, dilation=dilation, kernel_size=kernel_size, scale=scale_f)
            self.resblocks.append(resblock)

        self.skip_conv = nn.Conv1d(input_dim, output_dim, 1)

    def forward(self, h, z):
        """
        Args:
            h: [B, input_dim, T]
            z: conditioning vector, forwarded unchanged to every resblock.

        Returns:
            [B, output_dim, scale*T]
        """
        # `z` conditions the resblocks; it is not concatenated onto `h`, since
        # that would change the channel count away from `input_dim`.
        # The skip path must be upsampled too, so it matches the resblock output length.
        scaled_h = F.interpolate(h, scale_factor=self.scale) if self.scale != 1 else h# [B, input_dim, T] -> [B, input_dim, x*T]
        residual = self.skip_conv(scaled_h)# [B, input_dim, x*T] -> [B, output_dim, x*T]

        for i, resblock in enumerate(self.resblocks): # [B, input_dim, T] -> [B, output_dim, x*T]
            h = resblock(h, z)
            if i == self.res_block_id:
                h = h + residual
                residual = h  # start of the second residual connection

        return h + residual # [B, output_dim, x*T]


class ConditionalDBlock(nn.Module):
    """Placeholder block (not implemented yet): forward() is an identity mapping."""

    def __init__(self, hp):
        # `hp` is accepted but not read yet.
        super().__init__()

    def forward(self, x):
        """Return `x` unchanged."""
        return x


class DBlock(nn.Module):
    """Placeholder block (not implemented yet): forward() is an identity mapping."""

    def __init__(self, hp):
        # `hp` is accepted but not read yet.
        super().__init__()

    def forward(self, x):
        """Return `x` unchanged."""
        return x


class Decoder(nn.Module):
    """
    Input convolution -> chain of GBlocks -> output convolution with one channel.

    Reads from `hp`: in_channels, decoder_dims, decoder_scales, z_dim,
    gblock_kernel_size and dilations. The first GBlock keeps
    `decoder_dims[0]` channels; each following GBlock maps
    `decoder_dims[i]` -> `decoder_dims[i+1]` using `decoder_scales[i+1]`.

    Both plain convolutions use kernel_size=3 without padding, so each of
    them shortens the time axis by 2.
    """
    def __init__(self, hp):
        super().__init__()
        dims = hp.decoder_dims
        scales = hp.decoder_scales

        self.start = nn.Conv1d(hp.in_channels, dims[0], kernel_size=3)

        # (in_dim, out_dim, scale) of every GBlock, in execution order.
        specs = [(dims[0], dims[0], scales[0])]
        for idx in range(len(dims) - 1):
            specs.append((dims[idx], dims[idx + 1], scales[idx + 1]))

        blocks = []
        for block_in, block_out, block_scale in specs:
            blocks.append(
                GBlock(
                    block_in,
                    block_out,
                    hp.z_dim,
                    kernel_size=hp.gblock_kernel_size,
                    dilations=hp.dilations,
                    scale=block_scale,
                )
            )
        self.Gblocks = nn.ModuleList(blocks)

        self.end = nn.Conv1d(dims[-1], 1, kernel_size=3)

    def forward(self, x, z):
        """Run `x` through start conv, every GBlock (conditioned on `z`) and end conv."""
        hidden = self.start(x)
        for block in self.Gblocks:
            hidden = block(hidden, z)
        return self.end(hidden)


class GANTTS(nn.Module):
    """
    Top-level model. Only the decoder is wired up so far; the encoder and
    duration predictor are still commented out, and forward() expects their
    outputs as its input.
    """
    def __init__(self, hp):
        super(GANTTS, self).__init__()
        #self.encoder = Encoder(hp) # Text -> Encoder Features
        #self.durpred = DurPred(hp) # Encoder Features -> Durations
        self.decoder = Decoder(hp) # Durations + Encoder Features -> Audio
        self.z_dim = hp.z_dim

    def parse_encoder_outputs(self, encoder_outputs, durations):
        """
        Acts as Monotonic Attention for Encoder Outputs.

        [B, enc_T, enc_dim] x [B, enc_T, durations] -> [B, dec_T, enc_dim]

        Raises:
            NotImplementedError: always; the expansion is not written yet.
        """
        # Fail explicitly instead of returning a name that was never assigned.
        raise NotImplementedError("GANTTS.parse_encoder_outputs is not implemented yet")

    # Intentionally has no `self`: forward() calls it as
    # generate_noise(tensor, z_dim).
    @torch.jit.script
    def generate_noise(x, z_dim:int):
        # One noise vector per batch item, on the same device/dtype as `x`: [B, z_dim, 1]
        noise = torch.randn(x.shape[0], z_dim, 1, device=x.device, dtype=x.dtype)
        return noise

    def forward(self, inputs):
        """
        Args:
            inputs: sequence whose first two items are
                    (encoder_outputs, durations); any further items are ignored.

        Returns:
            The decoder output for the expanded encoder features.
        """
        encoder_outputs, durations, *_ = inputs

        noise = self.generate_noise(encoder_outputs, self.z_dim)

        attention_contexts = self.parse_encoder_outputs(encoder_outputs, durations)

        # NOTE(review): parse_encoder_outputs documents a [B, dec_T, enc_dim]
        # result while the decoder starts with a Conv1d (channels on dim 1) -
        # confirm whether a transpose is needed here once it is implemented.
        pred_audio = self.decoder(attention_contexts, noise)

        return pred_audio

In [6]:
# Scratch check: `is` compares object identity, so a str is never the `True`
# singleton and this evaluates to False (see the cell output).
# NOTE(review): `is` with a literal emits a SyntaxWarning on recent Python
# versions; this cell looks like leftover debugging and can probably be removed.
"blah" is True

False

# 2 - Init Dataset

# 3 - Config