In [2]:
import pytorch_lightning as pl
import numpy as np
import torchaudio
import torch
import torch.optim
import pandas as pd
import matplotlib.pyplot as plt
import IPython.display as ipd
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch import nn
from torchaudio import transforms
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoderLayer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
from utils import WaveNet, AudioMELSpectogramDataset, CalcContentLoss, CalcStyleLoss, plot_mel_spectrogram, mel_to_wav

In [3]:
# Allow reduced-precision (e.g. TF32) kernels for float32 matmuls on supported GPUs.
# NOTE(review): this setting only affects float32 matmuls; it has no effect on
# float64 tensors.
torch.set_float32_matmul_precision('medium')

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _use_cuda else torch.device('cpu')
device

device(type='cuda')

In [4]:
class StyleAutoEncoder(nn.Module):
    """Two-branch WaveNet autoencoder that mixes a content ("speech") input with a style input.

    Both inputs pass through one shared WaveNet encoder, then through a
    branch-specific WaveNet encoder (``speechEncoder`` / ``styleEncoder``).
    Each branch is refined by its own Transformer encoder layer. The
    concatenated refined features are the *memory* of a Transformer decoder
    layer whose *target* is the concatenation of the unrefined branch features.
    A final WaveNet maps the result back to ``in_channels``.

    Args:
        in_channels: channels of the input and output features (presumably mel
            bins — TODO confirm against the dataset).
        shared_channels: channels produced by the shared encoder.
        out_channels: channels produced by each branch encoder. Must be divisible
            by ``nhead`` (a ``torch.nn.MultiheadAttention`` requirement).
        kernel_size: forwarded to every ``WaveNet``.
        dilation: forwarded to every ``WaveNet``.
        nhead: number of attention heads in every Transformer layer.
        dropout: dropout used by the WaveNets and the Transformer layers.
        dim_feedforward: hidden size of the Transformer feed-forward sublayers.

    Raises:
        ValueError: if ``nhead`` is not a positive int (catches accidentally
            passing ``dropout`` positionally into the ``nhead`` slot).
    """

    def __init__(self, in_channels, shared_channels, out_channels, kernel_size, dilation, nhead=8, dropout=0.2, dim_feedforward=1024):
        super().__init__()
        if isinstance(nhead, bool) or not isinstance(nhead, int) or nhead <= 0:
            raise ValueError(f"nhead must be a positive int, got {nhead!r}")

        # Attribute names are kept unchanged so existing checkpoints/state dicts still load.
        self.sharedEncoder = WaveNet(in_channels, shared_channels, kernel_size, dilation, dropout)
        self.speechEncoder = WaveNet(shared_channels, out_channels, kernel_size, dilation, dropout)
        self.styleEncoder = WaveNet(shared_channels, out_channels, kernel_size, dilation, dropout)
        self.style_encoder_attention = TransformerEncoderLayer(
            d_model=out_channels,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu'
        )
        self.speech_encoder_attention = TransformerEncoderLayer(
            d_model=out_channels,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu'
        )
        # The decoder works on the channel-wise concatenation of both branches.
        self.decoder_attention = TransformerDecoderLayer(
            d_model=out_channels * 2,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu'
        )
        self.decoder = WaveNet(out_channels * 2, in_channels, kernel_size, dilation, dropout)

    def forward(self, speech, style):
        """Re-synthesize ``speech`` conditioned on ``style``.

        Args:
            speech: content features, assumed (batch, in_channels, time) — TODO confirm.
            style: style features with the same assumed layout. Both branches are
                concatenated along channels, so their time lengths must match
                after the WaveNet encoders.

        Returns:
            Output of the final ``WaveNet`` decoder, presumably
            (batch, in_channels, time).
        """
        shared_style = self.sharedEncoder(style)
        shared_speech = self.sharedEncoder(speech)

        style_feat = self.styleEncoder(shared_style)      # assumed (B, C, T)
        speech_feat = self.speechEncoder(shared_speech)   # assumed (B, C, T)

        # The Transformer layers are not batch_first, so they take (T, B, C).
        style_ctx = self.style_encoder_attention(style_feat.permute(2, 0, 1))
        speech_ctx = self.speech_encoder_attention(speech_feat.permute(2, 0, 1))

        # Decoder target: raw branch features; memory: attention-refined features.
        # Concatenation order (speech, style) is the same for both.
        tgt = torch.cat((speech_feat, style_feat), dim=1).permute(2, 0, 1)  # (T, B, 2C)
        memory = torch.cat((speech_ctx, style_ctx), dim=2)                  # (T, B, 2C)

        x = self.decoder_attention(tgt, memory)
        x = x.permute(1, 2, 0)  # back to (B, 2C, T) for the WaveNet decoder
        return self.decoder(x)


class MelSpecVCAutoencoderModule(pl.LightningModule):
    """Lightning wrapper that trains ``StyleAutoEncoder`` with weighted content + style losses.

    Args:
        in_channels, shared_channels, out_channels, kernel_size, dilation:
            forwarded to ``StyleAutoEncoder``.
        device: stored as ``self.todevice``; not read anywhere in this class.
        sample_rate: stored as ``self.sample_rate``; not used for training.
        dropout: dropout forwarded to ``StyleAutoEncoder``.
        lr: Adam learning rate (default keeps the previous hard-coded 1e-2).
    """

    def __init__(self, in_channels, shared_channels, out_channels, kernel_size, dilation, device, sample_rate, dropout=0.2, lr=1e-2):
        super().__init__()
        # dropout must be passed by keyword: the 6th positional slot of
        # StyleAutoEncoder is ``nhead``, not ``dropout``.
        self.encoder_generator = StyleAutoEncoder(
            in_channels, shared_channels, out_channels, kernel_size, dilation, dropout=dropout
        )
        self.todevice = device
        self.sample_rate = sample_rate
        self.lr = lr
        # Relative weights of the content and style terms in the total loss.
        self.speech_weight = 1
        self.style_weight = 100

    def forward(self, speech, style):
        """Run the wrapped ``StyleAutoEncoder``."""
        return self.encoder_generator(speech, style)

    def debug(self, loss_speech, loss_style, loss):
        """Hook for inspecting per-step loss components; intentionally a no-op."""
        pass

    def apply_loss(self, batch):
        """Compute the weighted content + style loss for one ``(speech, style)`` batch.

        2-D inputs get a singleton dimension inserted at axis 1 before the
        forward pass.
        """
        speech, style = batch
        if speech.ndim == 2:
            speech = speech.unsqueeze(1)
        if style.ndim == 2:
            style = style.unsqueeze(1)
        y = self(speech, style)
        loss_speech = CalcContentLoss(y, speech) * self.speech_weight
        loss_style = CalcStyleLoss(y, style) * self.style_weight
        loss = loss_speech + loss_style
        self.debug(loss_speech, loss_style, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log ``train_loss``.

        The loss is logged as a tensor to avoid a host sync from ``.item()``.
        """
        loss = self.apply_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log ``val_loss``, which is monitored by the LR scheduler."""
        loss = self.apply_loss(batch)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """Return Adam with a ReduceLROnPlateau scheduler monitoring ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_loss'
        }

In [5]:
batch_size = 32
sample_rate = 16000
div_ratio = 0.8  # fraction of the dataset used for training
n_fft = 800
audio_length = 3
in_channels, shared_channels, out_channels, kernel_size, dilation = 128, 64, 32, 3, 12
SEED = 42  # fixes weight init and the train/val split for reproducible runs

pl.seed_everything(SEED)

metadata = pd.read_csv('metadata.csv')
dataset = AudioMELSpectogramDataset(metadata, device, sample_rate=sample_rate, audio_length=audio_length, n_fft=n_fft)

# Seeded split so the same files land in train/val on every run.
n_train = int(len(dataset) * div_ratio)
n_val = len(dataset) - n_train
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [n_train, n_val], generator=torch.Generator().manual_seed(SEED)
)

# num_workers=0 kept: the dataset receives `device`, so it may create CUDA tensors,
# which do not mix well with worker processes — TODO confirm.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
# Keep the last partial batch for validation: with drop_last=True a val split
# smaller than batch_size would yield no batches, and 'val_loss' would never be logged.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=False)

model = MelSpecVCAutoencoderModule(in_channels, shared_channels, out_channels, kernel_size, dilation, device=device, sample_rate=sample_rate)
# NOTE(review): cast to float64, presumably to match the dataset's tensor dtype — confirm.
model = model.to(device).to(torch.float64)



In [6]:
# Keep only the best checkpoint by validation loss. The filename template may only
# reference logged metrics ('train_loss', 'val_loss'); a missing metric is rendered
# as 0, which is why the former '{val_r2:.2f}' field was dropped.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath='./checkpoints',
    filename='{epoch}-{val_loss:.2f}',
    save_top_k=1,
    monitor='val_loss',
    every_n_epochs=1,
    mode='min',
)

# Follow the device chosen in the config cell instead of assuming a GPU.
use_cuda = device.type == 'cuda'
trainer = pl.Trainer(
    accelerator='gpu' if use_cuda else 'cpu',
    max_epochs=10,
    benchmark=True,
    # NOTE(review): the model is cast to float64 in the setup cell; CUDA autocast does
    # not downcast float64 tensors, so 16-bit AMP likely has little effect — confirm.
    precision=16 if use_cuda else 32,
    accumulate_grad_batches=6,
    callbacks=[
        checkpoint_callback,
        pl.callbacks.LearningRateMonitor(logging_interval='step'),
    ],
)

Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Train; validation runs after each epoch and feeds 'val_loss' to the scheduler/checkpoint.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type             | Params
-------------------------------------------------------
0 | encoder_generator | StyleAutoEncoder | 2.3 M 
-------------------------------------------------------
2.3 M     Trainable params
0         Non-trainable params
2.3 M     Total params
4.506     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [25]:
def mel_to_wav(mel_spectrogram, sample_rate=16000, n_fft=256, hop_length=None, n_mels=128):
    """Approximately invert a mel spectrogram to a waveform with Griffin-Lim.

    Note: this redefinition shadows the ``mel_to_wav`` imported from ``utils``.

    Args:
        mel_spectrogram: mel tensor whose last two dims are presumably
            (n_mels, frames) — TODO confirm against AudioMELSpectogramDataset.
            Assumed to be a power (not dB) mel spectrogram, since GriffinLim
            is used with its default ``power=2``.
        sample_rate: sample rate used to build the inverse mel filterbank.
        n_fft: FFT size. Must match the n_fft used to compute the mel
            spectrogram, otherwise the STFT bin count is wrong.
        hop_length: hop between frames; defaults to ``n_fft // 2``.
        n_mels: number of mel bands.

    Returns:
        numpy array with the reconstructed waveform (moved to CPU).
    """
    if hop_length is None:
        hop_length = n_fft // 2
    # Build the transforms on the same device as the input.
    target_device = mel_spectrogram.device
    inverse_mel_scale = transforms.InverseMelScale(
        n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate
    ).to(target_device)
    linear_spectrogram = inverse_mel_scale(mel_spectrogram.detach().float())
    # NOTE(review): multiplying by the max does not restore an original scale; it only
    # rescales the amplitude (ipd.Audio normalizes by default). Kept as before.
    linear_spectrogram = linear_spectrogram * linear_spectrogram.max()
    griffin_lim_transform = transforms.GriffinLim(n_fft=n_fft, hop_length=hop_length).to(target_device)
    waveform = griffin_lim_transform(linear_spectrogram)
    return waveform.cpu().numpy()

# Fixed example index from the training split (not random despite the name).
random_speech = train_dataset[120]
speech_spec = random_speech[0].to(device)
style_spec = random_speech[1].to(device)
# The dataset was built with n_fft=n_fft, so invert with the same FFT size.
# NOTE(review): hop_length assumes the dataset uses n_fft // 2 — confirm.
display(ipd.Audio(mel_to_wav(speech_spec, sample_rate=sample_rate, n_fft=n_fft), rate=sample_rate))

model.eval()
speech_spec = speech_spec.unsqueeze(0)
style_spec = style_spec.unsqueeze(0)
# Inference only: no autograd graph needed.
with torch.no_grad():
    y = model(speech_spec, style_spec)
y = y.detach()
display(ipd.Audio(mel_to_wav(y, sample_rate=sample_rate, n_fft=n_fft), rate=sample_rate))

estaaa


estaaa
