In [None]:
from architectures.gru_seq2seq_bidirectional_enc import Builder as AudioVAEBuilder
from architectures.gru_seq2seq_bidirectional_enc import Wrapper as AudioVAEWrapper

from readers import AudioReader

In [None]:
from utils.audio import (
    concat_FT,
    reverse_FT,
    get_waveform_from_spectrogram_tensor,
    play_audio
)

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torch
from scipy import signal
import sounddevice as sd
import numpy as np

# Model Definition

In [5]:
# Build the GRU seq2seq VAE. The hyper-parameters are collected in one dict so
# they can be read (and tweaked) in a single place.
# embedding_dim / context_length match the last two dimensions of the
# concat_FT output printed further down ([1, 944, 2050]).
builder = AudioVAEBuilder()
model_kwargs = dict(
    embedding_dim=2050,
    latent_dim=128,
    context_length=944,
    num_layers=1,
)
audio_model = builder.build(**model_kwargs)

In [6]:
# Wrap the built model. The wrapper is the object called in the sanity-check
# forward pass and the training loop, and its parameters are what the
# optimizer updates.
audio_wrapper = AudioVAEWrapper(audio_model)

# Data Preparation

In [8]:
# Fourier-transform settings shared by the dataset reader and by the
# spectrogram -> waveform conversion used when playing audio back.
fourier_params = dict(
    fs=16000,
    window_size=2048,
    window_shift=1024,
    type="hamming",
)

In [9]:
# Create the dataset from the Fourier-transform settings above. Iterating it
# yields one tensor per item (the next cell prints the shape of the first one).
dataset = AudioReader(fourier_params)

In [10]:
# Peek at the first dataset item only, to check its shape; the loop is left
# after the first iteration.
for audio in dataset:
    print(audio.shape)
    break

torch.Size([1, 944, 1025])


In [11]:
# Move the model to the GPU. Being the last expression of the cell, the call's
# return value is displayed, which is the module summary shown in the output.
audio_model.cuda()

ImageVAE(
  (encoder): BidirectionalEncoder(
    (encoder): GRU(2050, 2050, batch_first=True, bidirectional=True)
    (mu_proj): Sequential(
      (0): Linear(in_features=4100, out_features=128, bias=True)
    )
    (sigma_proj): Sequential(
      (0): Linear(in_features=4100, out_features=128, bias=True)
    )
  )
  (decoder): AutoregressiveDecoder(
    (proj_h): Linear(in_features=128, out_features=2050, bias=True)
    (decoder): GRU(2050, 2050, batch_first=True)
  )
)

In [12]:
# Sanity check: push the first dataset item through the wrapper and compare the
# input shape with the shape of the first returned value.
# Gradients are not needed for a shape check, and `output` stays bound at
# notebook level afterwards, so the pass runs under torch.no_grad() to avoid
# keeping the autograd graph of the whole recurrent pass alive on the GPU.
with torch.no_grad():
    for audio in dataset:
        X = concat_FT(audio).cuda()
        output = audio_wrapper(X)
        print(X.shape, output[0].shape)
        break

torch.Size([1, 944, 2050]) torch.Size([1, 944, 2050])


# Training

In [13]:
# Shuffling loader over the dataset, 32 items per batch.
# NOTE(review): the training loop in this notebook iterates `dataset`
# directly, so this loader is not consumed there; confirm which one is intended.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [14]:
# Reconstruction objective: mean squared error between prediction and input.
criterion = nn.MSELoss()

# Adam over the wrapper's parameters, with a lowered first-moment decay
# (beta1 = 0.5) and a small L2 weight decay.
optimizer = torch.optim.Adam(
    audio_wrapper.parameters(),
    lr=1e-3,
    betas=(0.5, 0.999),
    weight_decay=1e-5,
)

In [15]:
# Number of full passes over the dataset performed by the training loop below.
epochs = 100

In [16]:
# Switch the model to training mode before the optimisation loop. The call's
# return value is the cell's displayed output (the module summary).
audio_model.train()

ImageVAE(
  (encoder): BidirectionalEncoder(
    (encoder): GRU(2050, 2050, batch_first=True, bidirectional=True)
    (mu_proj): Sequential(
      (0): Linear(in_features=4100, out_features=128, bias=True)
    )
    (sigma_proj): Sequential(
      (0): Linear(in_features=4100, out_features=128, bias=True)
    )
  )
  (decoder): AutoregressiveDecoder(
    (proj_h): Linear(in_features=128, out_features=2050, bias=True)
    (decoder): GRU(2050, 2050, batch_first=True)
  )
)

In [None]:
# Train the model to reconstruct its own input for `epochs` passes over
# `dataset`, printing the mean per-item loss after every pass.
# NOTE(review): items are taken from `dataset` directly (one optimizer step per
# item), not from the `dataloader` created earlier; confirm this is intended.
for epoch in range(epochs):
    # Running sum of the per-item losses for this epoch.
    total_loss = 0
    for spectrogram in dataset:
        optimizer.zero_grad()
        audio = concat_FT(spectrogram).cuda()

        # The input doubles as the target, and teacher_forcing=0.5 is passed
        # through to the wrapper. The wrapper returns three values; only the
        # first (the reconstruction) contributes to the loss.
        # NOTE(review): if the two discarded values are the latent statistics
        # (the encoder has mu_proj / sigma_proj heads), no KL term is applied
        # and the model is trained as a plain autoencoder; confirm.
        audio_pred, _, _ = audio_wrapper(audio, audio, teacher_forcing=0.5)
        loss = criterion(audio_pred, audio)
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    print(f'Epoch: {epoch}, Loss: {total_loss/len(dataset)}')

Epoch: 0, Loss: 4.588586220734214e-05
Epoch: 1, Loss: 9.217205045075616e-07
Epoch: 2, Loss: 2.8083185218408244e-07
Epoch: 3, Loss: 2.7794322525753757e-07
Epoch: 4, Loss: 2.7914018762231763e-07
Epoch: 5, Loss: 2.8066664911507997e-07
Epoch: 6, Loss: 2.8208117858863346e-07
Epoch: 7, Loss: 2.8364625015875333e-07


# Evaluation

In [19]:
# Switch the model to evaluation mode before running reconstructions. The
# call's return value is the cell's displayed output (the module summary).
audio_model.eval()

ImageVAE(
  (encoder): BidirectionalEncoder(
    (encoder): GRU(2050, 2050, batch_first=True, bidirectional=True)
    (mu_proj): Sequential(
      (0): Linear(in_features=4100, out_features=128, bias=True)
    )
    (sigma_proj): Sequential(
      (0): Linear(in_features=4100, out_features=128, bias=True)
    )
  )
  (decoder): AutoregressiveDecoder(
    (proj_h): Linear(in_features=128, out_features=2050, bias=True)
    (decoder): GRU(2050, 2050, batch_first=True)
  )
)

In [20]:
# Bind the first dataset item to `wave` for the listening cells below; the
# loop exits immediately after the first iteration.
for wave in dataset:
    break

In [23]:
# Listen to the original item: convert its spectrogram tensor (moved to the
# CPU) back into a waveform with the same Fourier settings, then play it.
wave_on_cpu = wave.cpu()
audio = get_waveform_from_spectrogram_tensor(wave_on_cpu, fourier_params)
play_audio(audio)

In [24]:
# Listen to the model's reconstruction of the same item.
# Inference does not need gradients, so the forward pass runs under
# torch.no_grad(); this avoids building the autograd graph for the whole
# recurrent pass on the GPU.
with torch.no_grad():
    # The model returns several values; the first one is the reconstruction.
    wave_recon = audio_model(concat_FT(wave).cuda())[0]

# Undo concat_FT, then convert the spectrogram back into a waveform.
spectrogram_recon = reverse_FT(wave_recon)
audio_recon = get_waveform_from_spectrogram_tensor(spectrogram_recon.detach().cpu(), fourier_params)
play_audio(audio_recon)