In [195]:
import pytorch_lightning as pl
import numpy as np
import torchaudio
import torch
import torch.optim
import pandas as pd
import matplotlib.pyplot as plt
import IPython.display as ipd

from torch.utils.data import DataLoader
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoderLayer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
from utils import WaveNet, AudioMELSpectogramDataset, ContentLoss, StyleLoss

In [196]:
class StyleAutoEncoder(nn.Module):
    """Two-branch (speech / style) autoencoder with attention.

    Both inputs go through one shared ``WaveNet`` encoder and then through a
    branch-specific ``WaveNet`` encoder. Each branch is additionally passed
    through its own ``TransformerEncoderLayer``. The raw branch features
    (concatenated over channels) are the *target* of a
    ``TransformerDecoderLayer`` whose *memory* is the concatenation of the two
    attended branches; the result is decoded back to ``in_channels`` by a final
    ``WaveNet``.

    Assumes ``WaveNet`` maps ``(batch, channels, time)`` to
    ``(batch, channels', time)`` -- TODO confirm against ``utils.WaveNet``.

    Args:
        in_channels: channel count of both inputs (and of the decoder output).
        shared_channels: channel count produced by the shared encoder.
        out_channels: channel count produced by each branch encoder; also the
            ``d_model`` of the two encoder attention layers. The decoder
            attention works on ``2 * out_channels``.
        kernel_size: forwarded to every ``WaveNet``.
        dilation: forwarded to every ``WaveNet``.
        nhead: accepted for interface compatibility.
            NOTE(review): it is currently not forwarded -- all attention layers
            are built with 8 heads. Wiring it through is only safe once every
            call site passes ``dropout`` by keyword, since ``nhead`` sits
            directly before ``dropout`` in the positional order.
        dropout: dropout rate for every ``WaveNet`` and attention layer.
        dim_feedforward: hidden width of the attention layers' feed-forward part.
    """

    def __init__(self, in_channels, shared_channels, out_channels, kernel_size, dilation, nhead=8, dropout=0.2, dim_feedforward=1024):
        super(StyleAutoEncoder, self).__init__()
        attention_heads = 8  # see the NOTE(review) on ``nhead`` in the class docstring

        self.sharedEncoder = WaveNet(in_channels, shared_channels, kernel_size, dilation, dropout)
        self.speechEncoder = WaveNet(shared_channels, out_channels, kernel_size, dilation, dropout)
        self.styleEncoder = WaveNet(shared_channels, out_channels, kernel_size, dilation, dropout)
        self.style_encoder_attention = TransformerEncoderLayer(
            d_model=out_channels,
            nhead=attention_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu'
        )
        self.speech_encoder_attention = TransformerEncoderLayer(
            d_model=out_channels,
            nhead=attention_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu'
        )
        self.decoder_attention = TransformerDecoderLayer(
            d_model=out_channels * 2,
            nhead=attention_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu'
        )
        self.decoder = WaveNet(out_channels * 2, in_channels, kernel_size, dilation, dropout)

    def forward(self, speech, style):
        """Encode both inputs, fuse them with attention and decode.

        Args:
            speech: content input, fed to the shared and the speech encoder.
            style: style input, fed to the shared and the style encoder.

        Returns:
            The decoder output with an extra singleton dimension inserted at
            position 1.
        """
        style = self.sharedEncoder(style)
        speech = self.sharedEncoder(speech)

        style = self.styleEncoder(style)
        speech = self.speechEncoder(speech)

        # The attention layers are built without batch_first, so they take the
        # sequence axis first: (batch, channels, time) -> (time, batch, channels).
        style_attended = self.style_encoder_attention(style.permute(2, 0, 1))
        speech_attended = self.speech_encoder_attention(speech.permute(2, 0, 1))

        # Target: raw branch features stacked over channels, sequence-first.
        target = torch.cat((speech, style), dim=1).permute(2, 0, 1)
        # Memory: the attended branch features, stacked over the same (now last)
        # channel axis. Previously these were overwritten by un-attended tensors,
        # so both encoder attention layers were dead weight.
        memory = torch.cat((speech_attended, style_attended), dim=2)

        x = self.decoder_attention(target, memory)
        x = x.permute(1, 2, 0)  # back to (batch, channels, time) for the decoder

        x = self.decoder(x)
        return x.unsqueeze(1)
 
class MelSpecVCAutoencoderModule(pl.LightningModule):
    """Lightning wrapper that trains a ``StyleAutoEncoder`` with a content + style loss.

    Args:
        in_channels, shared_channels, out_channels, kernel_size, dilation:
            forwarded to ``StyleAutoEncoder``.
        device: stored on ``self.todevice``; not otherwise used by this class.
        sample_rate: accepted for interface compatibility; not used by this class.
        dropout: dropout rate forwarded to ``StyleAutoEncoder``.
    """

    def __init__(self, in_channels, shared_channels, out_channels, kernel_size, dilation, device, sample_rate, dropout=0.2):
        super(MelSpecVCAutoencoderModule, self).__init__()
        # ``dropout`` must go by keyword: the sixth positional parameter of
        # StyleAutoEncoder is ``nhead``, so a positional pass never reached the
        # network's dropout setting.
        self.encoder_generator = StyleAutoEncoder(
            in_channels, shared_channels, out_channels, kernel_size, dilation, dropout=dropout
        )
        self.todevice = device

    def forward(self, speech, style):
        """Run the wrapped autoencoder on a (speech, style) pair."""
        return self.encoder_generator(speech, style)

    def _shared_step(self, batch, metric_name):
        """Compute and log the combined loss for one batch.

        Args:
            batch: a ``(speech, style)`` pair.
            metric_name: key under which the loss is logged.

        Returns:
            The sum of the content loss and the style loss.
        """
        speech, style = batch
        # 2-D inputs get a singleton dimension inserted at position 1 before
        # being fed to the network.
        if speech.ndim == 2:
            speech = speech.unsqueeze(1)
        if style.ndim == 2:
            style = style.unsqueeze(1)

        y = self(speech, style)

        # The loss objects are built around their reference and then applied to
        # the prediction.
        content_criterion = ContentLoss(speech)
        style_criterion = StyleLoss(style)
        loss = content_criterion(y) + style_criterion(y)

        self.log(metric_name, loss.item())
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: loss for one training batch, logged as ``train_loss``."""
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Lightning hook: loss for one validation batch, logged as ``val_loss``."""
        return self._shared_step(batch, 'val_loss')

    def configure_optimizers(self):
        """Adam plus a ReduceLROnPlateau schedule that monitors ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=5,
            verbose=True,
            threshold=0.0001,
            threshold_mode='rel',
            cooldown=0,
            min_lr=0,
            eps=1e-08,
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_loss'
        }

In [197]:
# Allow reduced-precision internals for float32 matrix multiplications.
torch.set_float32_matmul_precision('medium')

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
device

device(type='cuda')

In [198]:
# --- Hyper-parameters -------------------------------------------------------
batch_size = 4
sample_rate = 16000
div_ratio = 0.8      # fraction of the dataset used for training
n_fft = 800
audio_length = 2     # presumably seconds per clip -- TODO confirm in utils.AudioMELSpectogramDataset
split_seed = 42      # fixes the train/val split so indexed samples are reproducible
in_channels, shared_channels, out_channels, kernel_size, dilation = 128, 64, 32, 3, 12

# --- Data -------------------------------------------------------------------
metadata = pd.read_csv('metadata.csv')
dataset = AudioMELSpectogramDataset(metadata, device, sample_rate=sample_rate, audio_length=audio_length, n_fft=n_fft)

# Seeded split: without a generator every kernel restart produced a different
# partition, so later cells that index `train_dataset` saw different samples.
n_train = int(len(dataset) * div_ratio)
n_val = len(dataset) - n_train
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(split_seed),
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True)

# --- Model ------------------------------------------------------------------
model = MelSpecVCAutoencoderModule(in_channels, shared_channels, out_channels, kernel_size, dilation, device=device, sample_rate=sample_rate)
model = model.to(device)



In [199]:
# Keep only the best checkpoint according to the validation loss.
# The filename used to contain a `val_r2` field as well, but the module in this
# notebook only logs `train_loss` and `val_loss`, so that field carried no
# information.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath='./checkpoints',
    filename='{epoch}-{val_loss:.2f}',
    save_top_k=1,
    monitor='val_loss',
    every_n_epochs=1,
    mode='min'
)

# Mixed-precision GPU training; gradients are accumulated over 6 batches
# before each optimizer step.
trainer = pl.Trainer(
    accelerator='gpu',
    max_epochs=10,
    benchmark=True,
    # deterministic=True,
    precision=16,
    accumulate_grad_batches=6,
    callbacks=[
        checkpoint_callback,
        pl.callbacks.LearningRateMonitor(logging_interval='step')
    ]
)

Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [200]:
# Train on the training loader and validate on the validation loader.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type             | Params
-------------------------------------------------------
0 | encoder_generator | StyleAutoEncoder | 2.3 M 
-------------------------------------------------------
2.3 M     Trainable params
0         Non-trainable params
2.3 M     Total params
4.506     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


In [201]:
random_speech = train_dataset[120]
# Host copy: matplotlib needs a CPU tensor, and the torchaudio transforms below
# are created on the CPU.
mel_spectrogram = random_speech[0].detach().cpu()
style_spec = random_speech[1]

n_mels = mel_spectrogram.shape[0]

# Convert mel spectrogram to linear scale.
# The linear spectrogram must have n_fft // 2 + 1 frequency bins, and Griffin-Lim
# must be built with the same n_fft; the previous n_stft=800 combined with
# GriffinLim()'s default n_fft=400 could not fit together.
# NOTE(review): assumes the dataset's mel transform used this notebook's `n_fft`
# and torchaudio's default hop length / mel settings -- confirm against
# utils.AudioMELSpectogramDataset and pass hop_length here if it differs.
n_stft = n_fft // 2 + 1
mel_to_linear = torchaudio.transforms.InverseMelScale(n_stft=n_stft, n_mels=n_mels, sample_rate=sample_rate)
linear_spectrogram = mel_to_linear(mel_spectrogram)

# Convert linear spectrogram to waveform
griffin_lim = torchaudio.transforms.GriffinLim(n_fft=n_fft)
waveform = griffin_lim(linear_spectrogram)

# Plot mel spectrogram in decibels so the colorbar unit matches the data.
# NOTE(review): 10 * log10 assumes a power spectrogram -- confirm; the clamp
# avoids -inf for empty bins.
mel_db = 10.0 * torch.log10(mel_spectrogram.clamp_min(1e-10))
plt.figure(figsize=(10, 4))
plt.imshow(mel_db, cmap='inferno', origin='lower', aspect='auto')

# Set frequency axis ticks.
# torchaudio's MelScale has no mel_to_hz method, so the mel -> Hz mapping is
# computed directly with the HTK formula.
# NOTE(review): assumes HTK mel scale with f_min = 0 and f_max = sample_rate / 2
# (torchaudio defaults) -- confirm against the dataset's transform.
num_ticks = 10
freq_bins = np.linspace(0, n_mels, num_ticks)
mel_max = 2595.0 * np.log10(1.0 + (sample_rate / 2.0) / 700.0)
tick_mels = (freq_bins + 1.0) / (n_mels + 1.0) * mel_max  # bin i is centred on mel point i + 1
hz_ticks = 700.0 * (10.0 ** (tick_mels / 2595.0) - 1.0)
plt.yticks(freq_bins, ["{:.0f}".format(hz) for hz in hz_ticks])

plt.colorbar(format='%+2.0f dB')
plt.xlabel('Time')
plt.ylabel('Frequency (Hz)')
plt.title('Mel Spectrogram')
plt.show()

# display(ipd.Audio(waveform.cpu().numpy(), rate=sample_rate))

In [None]:
# NOTE(review): `random_audio` and `style_audio` are not assigned in any cell
# visible in this notebook -- they presumably come from an earlier or deleted
# cell. Define them before this cell so it survives Restart & Run All.

# Put the model on the target device and switch it to evaluation mode.
model = model.to(device)
model.eval()

# Forward pass without gradient tracking; a leading batch dimension is added
# to both inputs.
with torch.no_grad():
    predicted_audio = model.encoder_generator(
        random_audio.unsqueeze(0),
        style_audio.unsqueeze(0),
    )

print(random_audio.shape, style_audio.shape, predicted_audio.shape)

# Drop the batch dimension again, then render the two inputs and the prediction
# as audio widgets.
predicted_audio = predicted_audio.squeeze(0)
for clip in (random_audio, style_audio, predicted_audio):
    display(ipd.Audio(clip.cpu().numpy(), rate=sample_rate))


In [None]:
# Scratch experiment: push one sample through a WaveNet encoder, a transformer
# encoder layer and a transformer decoder layer, printing shapes along the way.
random_speech = train_dataset[1]
mel_spectrogram = random_speech[0]
style_spec = random_speech[1]

# Add a leading batch dimension.
mel_spectrogram = mel_spectrogram.unsqueeze(0)
style_spec = style_spec.unsqueeze(0)

print('MEL', mel_spectrogram.shape)
print('style_spec', style_spec.shape)

sharedEncoder = WaveNet(in_channels, out_channels, kernel_size, dilation)

mel_spectrogram = sharedEncoder(mel_spectrogram)
style_spec = sharedEncoder(style_spec)

print('MEL', mel_spectrogram.shape)

# The attention layers' d_model must equal the channel count coming out of the
# encoder. It was hard-coded to 64 while the encoder emits `out_channels`
# channels, so it is now read from the encoded tensor.
d_model = mel_spectrogram.shape[1]

# spec image to encoder transformer layer
encoder = TransformerEncoderLayer(
    d_model=d_model,
    nhead=8,
    dim_feedforward=2048,
    dropout=0.2,
    activation='relu'
)

# (batch, channels, time) -> (time, batch, channels) for the attention layer
mel_spectrogram = mel_spectrogram.permute(2, 0, 1)
print('MEL', mel_spectrogram.shape)

x = encoder(mel_spectrogram)
x = x.permute(1, 2, 0)
mel_spectrogram = mel_spectrogram.permute(1, 2, 0)

# Residual connection, out-of-place so the tensor the encoder consumed is not
# modified through a view.
mel_spectrogram = mel_spectrogram + x

print('x', x.shape)
print('MEL', mel_spectrogram.shape)
print('style_spec', style_spec.shape)

mel_spectrogram = mel_spectrogram.permute(2, 0, 1)
style_spec = style_spec.permute(2, 0, 1)
decoder_attention = TransformerDecoderLayer(
    d_model=d_model,
    nhead=8,
    dim_feedforward=2048,
    dropout=0.2,
    activation='relu'
)

# Target = style features, memory = attended mel features.
x = decoder_attention(style_spec, mel_spectrogram)

print('x', x.shape)
print('MEL', mel_spectrogram.shape)
print('style_spec', style_spec.shape)

MEL torch.Size([1, 128, 161])
style_spec torch.Size([1, 128, 161])
MEL torch.Size([1, 32, 161])
MEL torch.Size([161, 1, 32])
