In [89]:
import pytorch_lightning as pl
import numpy
import torchaudio
import torch
import torch.optim
import os
import pandas as pd
import matplotlib.pyplot as plt
import IPython.display as ipd

from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import Trainer
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoderLayer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
from utils import AudioDataset, Wave_Block, WaveNet

In [90]:
class StyleAutoEncoder(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, dilation, device, n_head=8, dropout=0.2):
        super(StyleAutoEncoder, self).__init__()
        self.styleEncoder = WaveNet(in_channels, out_channels, kernel_size, dilation, dropout)
        self.speechEncoder = WaveNet(in_channels, out_channels, kernel_size, dilation, dropout)
        self.style_encoder_attention = TransformerEncoderLayer(out_channels, nhead=n_head)
        self.speech_encoder_attention = TransformerEncoderLayer(out_channels, nhead=n_head)
        # self.decoder_attention = TransformerDecoderLayer(out_channels*2, nhead=n_head)
        self.decoder = WaveNet(out_channels*2, in_channels, kernel_size, dilation, dropout)
        # self.wav2vec = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
        # self.tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")

    def forward(self, speech, style):
        if x.ndim == 2: x = x.unsqueeze(1)

        style = self.styleEncoder(style)
        speech = self.speechEncoder(speech)
        
        print(style.shape)
        style = style.permute(2, 0, 1)  # Reordenamos las dimensiones para que sea compatible con la capa de atención
        print(style.shape)
        style = self.style_encoder_attention(style)
        style = style.permute(1, 2, 0)  # Restauramos el orden de las dimensiones
        
        speech = speech.permute(2, 0, 1)  # Reordenamos las dimensiones para que sea compatible con la capa de atención
        speech = self.speech_encoder_attention(speech)
        speech = speech.permute(1, 2, 0)  # Restauramos el orden de las dimensiones

        x = torch.cat((style, speech), dim=1)
        #x = x.permute(2, 0, 1)  # Reordenamos las dimensiones para que sea compatible con la capa de atención
        # x = self.decoder_attention(x)
        #x = x.permute(1, 2, 0)  # Restauramos el orden de las dimensiones

        x = self.decoder(x)
        # input_ids = self.tokenizer(x, return_tensors="pt").input_values
        #logits = self.wav2vec(input_ids).logits
        # transcriptions = self.tokenizer.batch_decode(logits.logits, skip_special_tokens=True)
        return x.squeeze(0) #, transcriptions


In [91]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [92]:
batch_size = 8
sample_rate = 16000
metadata = pd.read_csv('metadata.csv')
dataset = AudioDataset(metadata, device, sample_rate=sample_rate)
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset) * 0.8), len(dataset) - int(len(dataset) * 0.8)])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True)
model = StyleAutoEncoder(1, 8, 3, 12, device=device)
model = model.to(device)

In [93]:
trainer = pl.Trainer(
    accelerator='gpu',
    max_epochs=10,
    benchmark=True,
    # deterministic=True,
    precision=16,
    callbacks=[
        pl.callbacks.ModelCheckpoint(monitor='val_loss', mode='min'),
        pl.callbacks.LearningRateMonitor(logging_interval='step')
    ]
)

Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [94]:
# trainer.fit(model, train_loader, val_loader)

In [95]:
random_audio = train_dataset[0]
random_audio = random_audio.to(device)
model.eval()
model = model.to(device)
with torch.no_grad():
    predicted_audio = model(random_audio.unsqueeze(0))

print(random_audio.shape)
display(ipd.Audio(random_audio.cpu().numpy(), rate=sample_rate))
predicted_audio = predicted_audio.squeeze(0)
print(predicted_audio.shape)
display(ipd.Audio(predicted_audio.cpu().numpy(), rate=sample_rate))

# plt.plot(random_audio.cpu().numpy())
plt.plot(predicted_audio.cpu().numpy())
plt.show()

In [96]:
random_audio = train_dataset[0]
random_audio = random_audio.to(device)

audio_path = metadata['path'][0]
print(audio_path)

# Cargar el modelo pre-entrenado y el tokenizer
# wav2vec = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")#.to(device)
# tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")

# input_ids = tokenizer(audio_path, return_tensors="pt").input_values
# logits = wav2vec(input_ids).logits
# transcriptions = tokenizer.batch_decode(logits, skip_special_tokens=True)

VoxCelebTest/id10270/5r0dWxy17C8/00001.wav
