In [1]:
import pytorch_lightning as pl
import numpy
import torchaudio
import torch
import torch.optim
import os
import pandas as pd
import matplotlib.pyplot as plt
import IPython.display as ipd

from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import Trainer
from torch import nn
from utils import AudioDataset

In [6]:
class Wave_Block(nn.Module):
    def __init__(self, in_channels, out_channels, dilation_rates, kernel_size):
        super(Wave_Block, self).__init__()
        self.num_rates = dilation_rates
        self.convs = nn.ModuleList()
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()

        self.convs.append(nn.Conv1d(in_channels, out_channels, kernel_size=1))
        dilation_rates = [2 ** i for i in range(dilation_rates)]
        for dilation_rate in dilation_rates:
            self.filter_convs.append(
                nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, padding=int((dilation_rate*(kernel_size-1))/2), dilation=dilation_rate))
            self.gate_convs.append(
                nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, padding=int((dilation_rate*(kernel_size-1))/2), dilation=dilation_rate))
            self.convs.append(nn.Conv1d(out_channels, out_channels, kernel_size=1))

    def forward(self, x):
        x = self.convs[0](x)
        res = x
        for i in range(self.num_rates):
            x = torch.tanh(self.filter_convs[i](x)) * torch.sigmoid(self.gate_convs[i](x))
            x = self.convs[i + 1](x)
            res = res + x
        return res
    
class WaveNet(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, dilation, dropout=0.2):
        super(WaveNet, self).__init__()
        self.wave_block = Wave_Block(in_channels, out_channels, dilation, kernel_size)
        self.dropout = nn.Dropout(dropout)
        self.conv1 = nn.Conv1d(out_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        if x.ndim == 2:
            x = x.unsqueeze(1)
        x = self.wave_block(x)
        x = self.relu(self.conv1(x))
        x = self.dropout(x)
        return x
    
class WavenetAutoencoder(nn.Module):
    
        def __init__(self, in_channels, out_channels, kernel_size, dilation, dropout=0.2):
            super(WavenetAutoencoder, self).__init__()
            self.encoder = WaveNet(in_channels, out_channels, kernel_size, dilation, dropout)
            self.decoder = WaveNet(out_channels, in_channels, kernel_size, dilation, dropout)
    
        def forward(self, x):
            x = self.encoder(x)
            x = self.decoder(x)
            return x
    
class WavenetAutoencoderModule(pl.LightningModule):

    def __init__(self, in_channels, out_channels, kernel_size, dilation, device, dropout=0.2):
        super(WavenetAutoencoderModule, self).__init__()
        self.model = WavenetAutoencoder(in_channels, out_channels, kernel_size, dilation, dropout)
        self.loss = nn.MSELoss()
        self.todevice = device

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x = batch
        y = self(x)
        loss = self.loss(x, y)
        self.log('train_loss', loss)
        return loss 
    
    def validation_step(self, batch, batch_idx):
        x = batch
        y = self(x)
        loss = self.loss(x, y)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_loss'
        }

In [7]:
metadata = pd.read_csv('metadata.csv')
batch_size = 8
dataset = AudioDataset(metadata, device)
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset) * 0.8), len(dataset) - int(len(dataset) * 0.8)])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=True)

model = WavenetAutoencoderModule(1, 256, 3, 12, device)
model = model.to(device)

trainer = pl.Trainer(
    accelerator='gpu',
    max_epochs=10,
    benchmark=True,
    # deterministic=True,
    precision=16,
    callbacks=[
        pl.callbacks.ModelCheckpoint(monitor='val_loss', mode='min'),
        pl.callbacks.LearningRateMonitor(logging_interval='step')
    ]
)

Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
trainer.fit(model, train_loader, val_loader)

Missing logger folder: c:\Users\David Arcos\Documents\GitHub\Style-Transfer-VC-Model\lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type               | Params
---------------------------------------------
0 | model | WavenetAutoencoder | 5.6 M 
1 | loss  | MSELoss            | 0     
---------------------------------------------
5.6 M     Trainable params
0         Non-trainable params
5.6 M     Total params
11.162    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  return F.mse_loss(input, target, reduction=self.reduction)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# model test
model.eval()
# create random audio from train data
random_audio = train_dataset[0]
# to device
random_audio = random_audio.to(device)
# model to device
model = model.to(device)

# predict
with torch.no_grad():
    predicted_audio = model(random_audio.unsqueeze(0))

print(random_audio.shape)
display(ipd.Audio(random_audio.cpu().numpy(), rate=sample_rate))
predicted_audio = predicted_audio.squeeze(0)
print(predicted_audio.shape)
display(ipd.Audio(predicted_audio.cpu().numpy(), rate=sample_rate))


# plot
# plt.plot(random_audio.cpu().numpy())
# plt.plot(predicted_audio.cpu().numpy())
# plt.show()


torch.Size([16000])


torch.Size([1, 16000])


  scaled = data / normalization_factor * 32767
  return scaled.astype("<h").tobytes(), nchan


In [None]:
!tensorboard --logdir=lightning_logs

^C
