In [1]:
import pytorch_lightning as pl
import numpy as np
import torchaudio
import torch
import torch.optim
import pandas as pd
import matplotlib.pyplot as plt
import IPython.display as ipd

from torch.utils.data import DataLoader
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoderLayer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
from utils import WaveNet, AudioMELSpectogramDataset



In [2]:
class StyleAutoEncoder(nn.Module):
    """Encode a speech signal and a style signal, fuse them with self-attention, and decode.

    Both inputs go through a shared WaveNet encoder, then through a branch-specific
    WaveNet encoder and a branch-specific Transformer encoder layer. The two branches
    are concatenated along the channel axis, mixed by a third Transformer encoder
    layer, and decoded back to ``in_channels`` channels.

    Args:
        in_channels: channel count of the input signals (and of the decoded output).
        out_channels: channel width of each branch; the shared encoder emits half of it.
        kernel_size, dilation: forwarded to every ``WaveNet``.
        n_head: NOTE(review): currently unused — every attention layer is built with
            ``nhead=2``.
        dropout: forwarded to every ``WaveNet`` (the attention layers keep their own default).
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, n_head=8, dropout=0.2):
        super().__init__()
        half_width = out_channels // 2
        fused_width = out_channels * 2
        # Sub-modules are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        self.sharedEncoder = WaveNet(in_channels, half_width, kernel_size, dilation, dropout)
        self.styleEncoder = WaveNet(half_width, out_channels, kernel_size, dilation, dropout)
        self.speechEncoder = WaveNet(half_width, out_channels, kernel_size, dilation, dropout)
        self.style_encoder_attention = TransformerEncoderLayer(out_channels, nhead=2)
        self.speech_encoder_attention = TransformerEncoderLayer(out_channels, nhead=2)
        self.decoder_attention = TransformerEncoderLayer(fused_width, nhead=2)
        self.decoder = WaveNet(fused_width, in_channels, kernel_size, dilation, dropout)

    @staticmethod
    def _attend(layer, features):
        """Apply a sequence-first TransformerEncoderLayer to channel-second features.

        The layers are built without ``batch_first``, so ``features`` (dim 1 = channels,
        the last dim = time) is moved to (time, dim0, channels) for the layer and then
        moved back to its original layout.
        """
        sequence_first = features.permute(2, 0, 1)
        attended = layer(sequence_first)
        return attended.permute(1, 2, 0)

    def forward(self, speech, style):
        """Return the decoded signal; ``squeeze(0)`` drops dim 0 only when it has size 1."""
        # Same call order as before (matters for dropout RNG in training mode).
        shared_style = self.sharedEncoder(style)
        shared_speech = self.sharedEncoder(speech)

        style_features = self.styleEncoder(shared_style)
        speech_features = self.speechEncoder(shared_speech)

        style_features = self._attend(self.style_encoder_attention, style_features)
        speech_features = self._attend(self.speech_encoder_attention, speech_features)

        # Stack both branches along the channel axis before the joint attention.
        fused = torch.cat((style_features, speech_features), dim=1)
        fused = self._attend(self.decoder_attention, fused)

        decoded = self.decoder(fused)
        return decoded.squeeze(0)
    
class WaveNetDiscriminator(nn.Module):
    """Score a batch of signals with a WaveNet feature extractor and a linear head.

    ``forward`` returns raw logits of shape (batch, 1) so they can be fed straight
    into ``nn.BCEWithLogitsLoss`` (which applies the sigmoid itself and is safe under
    mixed-precision autocast). Use ``self.sigmoid(logits)`` or ``torch.sigmoid`` when
    a probability is needed.

    Args:
        in_channels: channels of the input signal; the WaveNet keeps this width.
        out_channels: hidden width of the linear head.
        kernel_size, dilation, dropout: forwarded to ``WaveNet``.
        sample_rate: together with ``in_channels`` sets the flattened feature size
            expected by ``linear1`` (``in_channels * sample_rate``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, sample_rate, dropout=0.2):
        super(WaveNetDiscriminator, self).__init__()
        self.wavenet = WaveNet(in_channels, in_channels, kernel_size, dilation, dropout)
        # The flattened WaveNet output must have exactly in_channels * sample_rate
        # features per item. NOTE(review): if WaveNet preserves length, inputs must be
        # sample_rate samples long — confirm against the dataset's item length.
        self.linear1 = nn.Linear(in_channels * sample_rate, out_channels)
        self.linear2 = nn.Linear(out_channels, 1)
        # Not applied in forward(): kept so callers can convert logits to probabilities.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return unnormalised logits of shape (batch, 1) for input ``x``."""
        features = self.wavenet(x)
        # Flatten everything but the batch axis; unlike view() this also works when
        # the WaveNet output is non-contiguous.
        features = features.flatten(start_dim=1)
        # NOTE(review): there is no non-linearity between linear1 and linear2, so the
        # head is effectively a single affine map.
        hidden = self.linear1(features)
        return self.linear2(hidden)

class MelSpecVCAutoencoderModule(pl.LightningModule):
    """Lightning module chaining two StyleAutoEncoder generators, each with its own discriminator.

    ``encoder_generator`` maps (speech, style) to ``y``; ``decoder_generator`` maps
    (``y``, style) to ``y1``. ``encoder_discriminator`` scores ``y`` and
    ``decoder_discriminator`` scores ``y1``.

    Per-batch loss::

        MSE(y, speech) + MSE(y1, style)
        + BCEWithLogits(enc_disc_out, 0) + BCEWithLogits(dec_disc_out, 0)

    NOTE(review): generators and discriminators are updated by one optimizer on this
    single summed loss with constant zero labels, so there is no adversarial (min-max)
    game — the discriminator terms only push both discriminators towards predicting 0.
    Confirm this is intended.

    Args:
        in_channels, out_channels, kernel_size, dilation: forwarded to both generators
            and both discriminators.
        device: stored as ``self.todevice``; not read anywhere in this class.
        sample_rate: forwarded to the discriminators (sizes their first Linear layer).
        dropout: forwarded to all four sub-modules.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, device, sample_rate, dropout=0.2):
        super(MelSpecVCAutoencoderModule, self).__init__()
        # dropout is passed by keyword: StyleAutoEncoder's fifth positional parameter is
        # n_head, so a positional value would silently bind to n_head instead.
        self.encoder_generator = StyleAutoEncoder(in_channels, out_channels, kernel_size, dilation, dropout=dropout)
        self.encoder_discriminator = WaveNetDiscriminator(in_channels, out_channels, kernel_size, dilation, sample_rate, dropout)
        self.decoder_generator = StyleAutoEncoder(in_channels, out_channels, kernel_size, dilation, dropout=dropout)
        self.decoder_discriminator = WaveNetDiscriminator(in_channels, out_channels, kernel_size, dilation, sample_rate, dropout)
        self.loss_speech_loss_fn = nn.MSELoss()
        self.loss_style_loss_fn = nn.MSELoss()
        # BCEWithLogitsLoss applies the sigmoid itself, so it must receive raw
        # discriminator logits; an output that is already sigmoid-squashed would be
        # squashed twice.
        self.disc_loss_fn = nn.BCEWithLogitsLoss()
        # Not used here; forward() relies on Lightning's own `self.device` instead.
        self.todevice = device

    def forward(self, speech, style):
        """Run both generators and both discriminators.

        Returns:
            ``(y, y1, enc_disc_out, dec_disc_out, labels)`` where ``labels`` is an
            all-zero float tensor of shape ``(y.shape[0], 1)`` on the module's device,
            used as the discriminator target.
        """
        y = self.encoder_generator(speech, style)
        y1 = self.decoder_generator(y, style)
        enc_disc_out = self.encoder_discriminator(y)
        dec_disc_out = self.decoder_discriminator(y1)
        labels = torch.zeros(y.shape[0], 1, device=self.device)
        return y, y1, enc_disc_out, dec_disc_out, labels

    def _shared_step(self, batch, log_name):
        """Compute the combined loss for one ``(speech, style)`` batch and log it as ``log_name``."""
        speech, style = batch
        # A (batch, time) tensor gets an explicit channel axis -> (batch, 1, time).
        if speech.ndim == 2:
            speech = speech.unsqueeze(1)
        if style.ndim == 2:
            style = style.unsqueeze(1)
        y, y1, enc_disc_out, dec_disc_out, labels = self(speech, style)
        loss_speech = self.loss_speech_loss_fn(y, speech)
        loss_style = self.loss_style_loss_fn(y1, style)
        loss_disc = self.disc_loss_fn(enc_disc_out, labels) + self.disc_loss_fn(dec_disc_out, labels)
        loss = loss_speech + loss_style + loss_disc
        # Log the tensor rather than loss.item(): .item() forces a device->host sync on
        # every call, whereas Lightning can handle (and aggregate) the tensor itself.
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` (logged as ``train_loss``)."""
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` (logged as ``val_loss``, monitored by the scheduler)."""
        return self._shared_step(batch, 'val_loss')

    def configure_optimizers(self):
        """Adam over all parameters with ReduceLROnPlateau driven by ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_loss'
        }

In [3]:
# 'medium' lets float32 matmuls trade precision for speed ('high' is the stricter option).
torch.set_float32_matmul_precision('medium')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [4]:
# Data / split / model configuration.
SEED = 42  # fixes the train/val split and model initialisation across kernel restarts
batch_size = 4
sample_rate = 16000
div_ratio = 0.8  # fraction of the dataset used for training
n_fft = 800

# Seed torch/numpy/random so model weights (and any dataset randomness) are reproducible.
pl.seed_everything(SEED)

metadata = pd.read_csv('metadata.csv')
dataset = AudioMELSpectogramDataset(metadata, device, sample_rate=sample_rate, audio_length=2, n_fft=n_fft)

n_train = int(len(dataset) * div_ratio)
n_val = len(dataset) - n_train
# A seeded generator gives the same split every run, so val_loss values and
# indices such as train_dataset[120] stay comparable between sessions.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [n_train, n_val], generator=split_generator)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True)

# NOTE(review): WaveNetDiscriminator's first Linear expects in_channels * sample_rate
# features per item; confirm that matches the dataset's items (audio_length=2).
model = MelSpecVCAutoencoderModule(
    in_channels=1, out_channels=64, kernel_size=3, dilation=12,
    device=device, sample_rate=sample_rate,
)
model = model.to(device)



In [5]:
# Keep only the best checkpoint by validation loss.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath='./checkpoints',
    # Only reference metrics that are actually logged: Lightning substitutes 0 for
    # unknown placeholders, so the former '{val_r2:.2f}' always read 'val_r2=0.00'.
    filename='{epoch}-{val_loss:.2f}',
    save_top_k=1,
    monitor='val_loss',
    every_n_epochs=1,
    mode='min'
)

# Follow the device cell: 16-bit AMP is a GPU feature, so fall back to CPU with full
# precision when no GPU was found.
on_gpu = device.type == 'cuda'

trainer = pl.Trainer(
    accelerator='gpu' if on_gpu else 'cpu',
    max_epochs=10,
    benchmark=True,
    # deterministic=True,
    precision=16 if on_gpu else 32,
    accumulate_grad_batches=6,
    callbacks=[
        checkpoint_callback,
        pl.callbacks.LearningRateMonitor(logging_interval='step')
    ]
)

Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
# Training is disabled: with this line commented out, the inference cells below use
# `model` as constructed (initial weights) unless a checkpoint is loaded elsewhere.
# trainer.fit(model, train_loader, val_loader)

In [29]:
# Inspect one (speech, style) training pair: invert its mel spectrogram to audio and plot it.
random_speech = train_dataset[120]
mel_spectrogram = random_speech[0]
style_spec = random_speech[1]

# The dataset was built with `device`, so items may live on the GPU; the torchaudio
# transforms below are created on the CPU and matplotlib needs CPU data.
mel_cpu = mel_spectrogram.detach().cpu().float()
# NOTE(review): assumes mel_spectrogram is a (n_mels, frames) power spectrogram built with
# torchaudio's MelSpectrogram defaults apart from n_fft (HTK mel scale, f_min=0,
# f_max=sample_rate/2, hop=n_fft//2, power=2.0) — confirm against AudioMELSpectogramDataset.
n_mels = mel_cpu.shape[0]

# Mel -> linear-frequency spectrogram. n_stft must equal the STFT bin count
# (n_fft // 2 + 1) and n_mels the number of mel bands in the input.
mel_to_linear = torchaudio.transforms.InverseMelScale(
    n_stft=n_fft // 2 + 1, n_mels=n_mels, sample_rate=sample_rate
)
linear_spectrogram = mel_to_linear(mel_cpu)

# Linear spectrogram -> waveform; n_fft must match the STFT that produced the spectrogram.
griffin_lim = torchaudio.transforms.GriffinLim(n_fft=n_fft)
waveform = griffin_lim(linear_spectrogram)

# Show the spectrogram in decibels (10*log10 of power, floored so silent bins are not -inf)
# so the colorbar's dB label is accurate.
mel_db = torchaudio.transforms.AmplitudeToDB(stype='power')(mel_cpu)

fig, ax = plt.subplots(figsize=(10, 4))
image = ax.imshow(mel_db.numpy(), cmap='inferno', origin='lower', aspect='auto')

# Label the frequency axis with each mel band's centre frequency: the band centres sit on
# mel points evenly spaced between mel(f_min) and mel(f_max), endpoints excluded (HTK formula).
num_ticks = 10
tick_rows = np.linspace(0, n_mels - 1, num_ticks)
mel_f_max = 2595.0 * np.log10(1.0 + (sample_rate / 2) / 700.0)
mel_points = np.linspace(0.0, mel_f_max, n_mels + 2)
band_centres_hz = 700.0 * (10.0 ** (mel_points[1:-1] / 2595.0) - 1.0)
tick_hz = np.interp(tick_rows, np.arange(n_mels), band_centres_hz)
ax.set_yticks(tick_rows)
ax.set_yticklabels(["{:.0f}".format(hz) for hz in tick_hz])

fig.colorbar(image, ax=ax, format='%+2.0f dB')
ax.set(xlabel='Time (frames)', ylabel='Frequency (Hz)', title='Mel Spectrogram')
plt.show()

# display(ipd.Audio(waveform.cpu().numpy(), rate=sample_rate))

In [8]:
# Run the speech->style generator on one training pair and listen to the result.
# The pair is loaded here so the cell does not depend on variables left over from
# earlier (or deleted) cells. NOTE(review): these are the dataset's (speech, style)
# items — the same tensors the model trains on; confirm they are waveforms before
# playing them through ipd.Audio.
random_audio, style_audio = train_dataset[120]

model.eval()
model = model.to(device)

# Add a batch axis, then mirror training_step: a (batch, time) tensor also gets a
# channel axis -> (batch, 1, time).
speech_batch = random_audio.unsqueeze(0).to(device)
style_batch = style_audio.unsqueeze(0).to(device)
if speech_batch.ndim == 2:
    speech_batch = speech_batch.unsqueeze(1)
if style_batch.ndim == 2:
    style_batch = style_batch.unsqueeze(1)

with torch.no_grad():
    predicted_audio = model.encoder_generator(speech_batch, style_batch)

print(random_audio.shape, style_audio.shape, predicted_audio.shape)

display(ipd.Audio(random_audio.cpu().numpy(), rate=sample_rate))
display(ipd.Audio(style_audio.cpu().numpy(), rate=sample_rate))
# encoder_generator already squeezed the size-1 batch axis; squeeze(0) again drops a
# leading size-1 axis if one remains (presumably the single output channel).
predicted_audio = predicted_audio.squeeze(0)
display(ipd.Audio(predicted_audio.cpu().numpy(), rate=sample_rate))


In [9]:
# Overlay input, style reference and generated signal; the legend says which is which.
fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(random_audio.cpu().numpy(), label='input speech', alpha=0.7)
ax.plot(style_audio.cpu().numpy(), label='style reference', alpha=0.7)
ax.plot(predicted_audio.cpu().numpy(), label='generated', alpha=0.7)
ax.set(title='Input vs. style vs. generated signal', xlabel='Sample', ylabel='Amplitude')
ax.legend()
plt.show()