# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import tensorflow as tf
from tensorflow.python.client import device_lib
from scipy.io import wavfile
from torch.utils.data import Dataset
from tensorflow.keras import layers, Model
from tensorflow.keras.layers import Conv1D, Conv2D, MaxPooling2D, Flatten, Dense, Input, Dropout, BatchNormalization, Reshape, Conv1DTranspose, UpSampling1D, UpSampling2D, MaxPooling1D, UpSampling1D, Layer, Embedding, Add, Multiply
import torchaudio
from sklearn.model_selection import train_test_split
import IPython.display as ipd
import matplotlib.pyplot as plt
import wandb
from keras.models import Model
from keras.layers import Input, Dense, Flatten, Reshape, Concatenate
from keras.layers import Conv1D, Conv2D, MaxPooling2D, LeakyReLU
from keras.layers import BatchNormalization, Dropout
from keras.optimizers import Adam
from keras import backend as K
import gc
import time

# --- Training schedule -------------------------------------------------------
EPOCHS = 50                     # passes over the training file list per sweep value
BATCH_SIZE = 50                 # audio frames stacked into one optimizer step
TRAINING_DATA_AMOUNT = 100_000  # cap on wav files taken before the 80/20 train/val split
LEARNING_RATE = 1e-3            # initial Adam learning rate (decayed by StepLR)
WANDB_LOG = True                # log run config and metrics to Weights & Biases

# --- Audio framing -----------------------------------------------------------
SAMPLE_RATE = 24000                     # samples per second assumed for the input audio
FRAME_LENGTH = int(SAMPLE_RATE * 0.02)  # 20 ms frames -> 480 samples at 24 kHz

# --- Vector quantizer --------------------------------------------------------
NUM_LATENTS = 16       # codebook size (number of code vectors)
LATENT_DIM = 16        # length of each code vector; note the sweep loop reassigns it to NUM_LATENTS
COMMITMENT_COST = 0.8  # weight of the VQ commitment loss in the training objective

# In [None]:

import matplotlib.pyplot as plt
import IPython.display as ipd
from vector_quantize_pytorch import VectorQuantize


def plot_validation_data(num_examples=5):
    """Plot and play back validation files next to their reconstructions.

    Uses the module-level ``val_dataset`` (a sequence of audio file paths) and
    ``autoencoder`` (whose forward returns ``(reconstruction, commit_loss)``),
    so it can only be called once both exist.

    Args:
        num_examples: How many validation files to show, starting from index 0.
            Defaults to 5 (the previous hard-coded count) and is capped at
            ``len(val_dataset)``.
    """
    # Feed the input to whatever device the model lives on (CPU or CUDA).
    model_device = next(autoencoder.parameters()).device
    was_training = autoencoder.training
    # Inference only: eval mode keeps training-mode state (e.g. the VQ codebook,
    # configured with an EMA decay) from being updated by these forward passes.
    autoencoder.eval()
    try:
        with torch.no_grad():
            for idx in range(min(num_examples, len(val_dataset))):
                waveform, sr = torchaudio.load(val_dataset[idx])
                signal = waveform.unsqueeze(0).to(model_device)
                # The model returns a tuple; only the reconstruction is plotted.
                reconstruction, _ = autoencoder(signal)

                original_np = signal.squeeze().cpu().numpy()
                reconstruction_np = reconstruction.squeeze().cpu().numpy()

                plt.figure(figsize=(10, 4))
                plt.plot(original_np)
                plt.plot(reconstruction_np)
                plt.show()

                ipd.display(ipd.Audio(original_np, rate=sr))
                ipd.display(ipd.Audio(reconstruction_np, rate=sr))
    finally:
        # Restore whatever mode the caller had the model in.
        autoencoder.train(was_training)


class DataSetLoader:
    """Collect ``.wav`` file paths from dataset split directories.

    The three ``*_data_set_path`` attributes only store default split
    locations; call :meth:`load_data_set` with one of them (or any other
    directory) to populate ``file_names``.
    """

    def __init__(self,
                 dev_data_set_path='dev-clean/',
                 train_data_set_path='train-clean-360/',
                 test_data_set_path='test-clean/'):
        self.dev_data_set_path = dev_data_set_path
        self.train_data_set_path = train_data_set_path
        self.test_data_set_path = test_data_set_path
        # Shared across load_data_set() calls: each call appends to this list.
        self.file_names = []

    def load_data_set(self, data_set_path):
        """Recursively append every ``.wav`` file under *data_set_path*.

        Files are visited in a deterministic (sorted) order so that slicing or
        splitting the result is reproducible across machines.

        Args:
            data_set_path: Root directory to walk. A missing directory yields
                no files (``os.walk`` produces nothing for it).

        Returns:
            ``self.file_names`` itself (not a copy), including entries added by
            earlier calls.
        """
        for subdir, dirs, files in os.walk(data_set_path):
            # Sorting ``dirs`` in place fixes the order in which os.walk descends;
            # otherwise it follows whatever order the filesystem reports.
            dirs.sort()
            for file in sorted(files):
                # splitext copes with names without a dot (e.g. "README") and
                # inspects only the final extension ("a.normalized.txt" -> ".txt").
                if os.path.splitext(file)[1].lower() == ".wav":
                    self.file_names.append(os.path.join(subdir, file))

        print("Loaded in: \n- wave_files:", len(self.file_names))

        return self.file_names

    def __len__(self):
        """Return the number of collected file paths."""
        return len(self.file_names)

    def __getitem__(self, index):
        """Read the *index*-th collected file with ``scipy.io.wavfile.read``.

        Returns:
            The ``(sample_rate, data)`` pair produced by ``wavfile.read``.
        """
        return wavfile.read(self.file_names[index])


# Gather the wav files of the training split, cap them at TRAINING_DATA_AMOUNT
# and split them 80/20 into training and validation file lists.
dataset_loader = DataSetLoader()
dataset_files = dataset_loader.load_data_set(dataset_loader.train_data_set_path)
dataset_files = dataset_files[:TRAINING_DATA_AMOUNT]
if not dataset_files:
    raise RuntimeError(
        f"No .wav files found under {dataset_loader.train_data_set_path!r}"
    )
# random_split requires lengths that sum to len(dataset_files). Derive them from
# the number of files actually found (not TRAINING_DATA_AMOUNT, which breaks when
# fewer files exist); validation takes the remainder so nothing is dropped.
num_train_files = int(len(dataset_files) * 0.8)
num_val_files = len(dataset_files) - num_train_files
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset_files, [num_train_files, num_val_files]
)
class MyAutoEncoder(nn.Module):
    """1-D convolutional autoencoder with a vector-quantized bottleneck and
    simulated packet loss on the quantized latents.

    Input is a (batch, 1, samples) waveform tensor (the first Conv1d has one
    input channel). Each of the four MaxPool1d(2) stages halves the time axis
    (rounding down). Each encoded time step is a LATENT_DIM-channel vector that
    is quantized against a NUM_LATENTS-entry codebook. Each quantized vector is
    then zeroed with probability PACKET_LOSS_PERCENTAGE before the decoder
    upsamples it back to a Tanh-bounded waveform.

    LATENT_DIM and NUM_LATENTS are module globals read at construction time;
    PACKET_LOSS_PERCENTAGE is a module global read on every forward pass.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=16, out_channels=8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(in_channels=8, out_channels=LATENT_DIM, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
        )
        # Codebook of NUM_LATENTS vectors, each LATENT_DIM long.
        self.quantize = VectorQuantize(
            dim=LATENT_DIM,
            codebook_size=NUM_LATENTS,
            decay=0.8,
            kmeans_init=True,
            kmeans_iters=10,
        )
        self.decoder = nn.Sequential(
            # The decoder consumes the quantized vectors, which have LATENT_DIM
            # channels (the codebook size is irrelevant to the shape).
            nn.ConvTranspose1d(in_channels=LATENT_DIM, out_channels=8, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose1d(in_channels=8, out_channels=16, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose1d(in_channels=16, out_channels=32, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose1d(in_channels=32, out_channels=1, kernel_size=2, stride=2),
            nn.Tanh(),  # Map the output to [-1, 1], the range of the input audio.
        )

    @staticmethod
    def _drop_packets(quantized):
        """Zero each latent vector (last axis) independently with probability
        PACKET_LOSS_PERCENTAGE, simulating lost packets."""
        if PACKET_LOSS_PERCENTAGE <= 0:
            return quantized
        # One keep/drop decision per vector, broadcast over its features. Done
        # out of place on the tensor's own device instead of writing CPU zeros
        # row by row into an autograd tensor.
        keep = torch.rand(*quantized.shape[:-1], 1, device=quantized.device) >= PACKET_LOSS_PERCENTAGE
        return quantized * keep.to(quantized.dtype)

    def forward(self, x):
        """Encode, quantize, drop packets and decode *x*.

        Returns:
            ``(reconstruction, commit_loss)``: ``reconstruction`` has shape
            (batch, 1, 16 * encoded_frames); ``commit_loss`` is the commitment
            loss reported by VectorQuantize.
        """
        encoded = self.encoder(x)  # (batch, LATENT_DIM, frames)
        # Quantize one LATENT_DIM vector per time step. VectorQuantize expects
        # features on the last axis, so move channels last. A flat
        # .view(-1, channels) of the (batch, channels, frames) tensor would
        # instead cut runs of consecutive samples that straddle channels.
        vectors = encoded.permute(0, 2, 1).contiguous()  # (batch, frames, LATENT_DIM)
        quantized, _, commit_loss = self.quantize(vectors)
        quantized = self._drop_packets(quantized)
        reconstruction = self.decoder(quantized.permute(0, 2, 1))
        return reconstruction, commit_loss


def _load_frames(path):
    """Load an audio file and cut it into non-overlapping FRAME_LENGTH frames.

    Returns a tuple of (channels, FRAME_LENGTH) tensors; a trailing partial
    frame is dropped, so short files yield an empty tuple.
    """
    waveform, _ = torchaudio.load(path)
    frames = torch.split(waveform, FRAME_LENGTH, dim=1)
    if frames and frames[-1].size(1) != FRAME_LENGTH:
        frames = frames[:-1]
    return frames


def _train_one_epoch(model, optimizer, criterion, files, device):
    """Train *model* for one pass over *files*, in batches of BATCH_SIZE frames.

    The per-batch objective is reconstruction loss + COMMITMENT_COST * commit loss.

    Returns:
        The mean of that objective per training frame (0.0 if no frames).
    """
    model.train()
    total_loss = 0.0
    total_frames = 0
    for path in files:
        frames = _load_frames(path)
        for start in range(0, len(frames), BATCH_SIZE):
            batch = torch.stack(frames[start:start + BATCH_SIZE]).to(device)

            optimizer.zero_grad()
            outputs, commit_loss = model(batch)
            reconstruction_loss = criterion(outputs, batch)
            loss = reconstruction_loss + COMMITMENT_COST * commit_loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * batch.size(0)
            total_frames += batch.size(0)
    # Average over frames (what the sum is weighted by), not over files.
    return total_loss / max(total_frames, 1)


def _validate(model, criterion, files, device):
    """Return the summed per-file reconstruction loss of *model* over *files*.

    Each whole file is passed through the model; the input is truncated to the
    output length because the pooling stages round the length down.
    """
    # eval + no_grad: no autograd graph, and training-mode state (e.g. the VQ
    # codebook) is not updated from validation data.
    model.eval()
    total = 0.0
    with torch.no_grad():
        for path in files:
            waveform, _ = torchaudio.load(path)
            inputs = waveform.unsqueeze(0).to(device)
            outputs, _ = model(inputs)
            loss = criterion(outputs, inputs[:, :, :outputs.size(2)])
            total += loss.item() * inputs.size(0)
    model.train()
    return total


def _show_examples(model, files, device, num_examples=5):
    """Time, plot and play back the first *num_examples* files through *model*."""
    model.eval()
    with torch.no_grad():
        for idx in range(min(num_examples, len(files))):
            waveform, sr = torchaudio.load(files[idx])
            inputs = waveform.unsqueeze(0).to(device)

            # CUDA kernels run asynchronously; fence so the timer measures the
            # forward pass rather than just its launch.
            if device.type == "cuda":
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            output, _ = model(inputs)
            if device.type == "cuda":
                torch.cuda.synchronize()
            inference_time = time.perf_counter() - start_time

            print(f"Packet Loss Percentage: {PACKET_LOSS_PERCENTAGE}")
            print(f"Inference time: {inference_time:.4f} s")
            if WANDB_LOG:
                wandb.log({"inference_time": inference_time})

            original = inputs.squeeze().cpu().numpy()
            reconstruction = output.squeeze().cpu().numpy()
            plt.figure(figsize=(10, 4))
            plt.plot(original)
            plt.plot(reconstruction)
            plt.show()

            ipd.display(ipd.Audio(original, rate=sr))
            ipd.display(ipd.Audio(reconstruction, rate=sr))
            print("--------------------")
    model.train()


for PACKET_LOSS_PERCENTAGE in [0.0, 0.2, 0.4, 0.6, 0.8]:
    # Experiment convention: latent dimension equals codebook size. The model's
    # shapes no longer depend on the two being equal.
    LATENT_DIM = NUM_LATENTS
    if WANDB_LOG:
        wandb.init(
            project="FINAL_VQ_SWEEPING_PACKET_LOSS_TEST",
            config={
                "architecture": "AutoEncoderVQ_PACKET_LOSS",
                "dataset": "LibriTTS Corpus",
                "epochs": EPOCHS,
                "batch_size": BATCH_SIZE,
                "training_data_amount": TRAINING_DATA_AMOUNT,
                "learning_rate": LEARNING_RATE,
                "model_complexity": "25percent latent VQ",
                "packet_loss_percentage": PACKET_LOSS_PERCENTAGE,
                "frame_length": FRAME_LENGTH,
                "num_latents": NUM_LATENTS,
                "latent_dim": LATENT_DIM,
                "commitment_cost": COMMITMENT_COST,
            },
        )

    # Set device (GPU if available, else CPU) and build a fresh model per sweep value.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    autoencoder = MyAutoEncoder().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(autoencoder.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    for epoch in range(EPOCHS):
        running_loss = _train_one_epoch(autoencoder, optimizer, criterion, train_dataset, device)
        scheduler.step()

        validation_loss = _validate(autoencoder, criterion, val_dataset, device)
        avg_validation_loss = validation_loss / max(len(val_dataset), 1)

        print(f"Epoch {epoch}, Loss: {running_loss:.6f}", f"Validation Loss: {validation_loss:.6f}")
        print("Average Validation Loss: ", avg_validation_loss)

        if WANDB_LOG:
            wandb.log({
                "epoch": epoch,
                "train_loss": running_loss,
                "val_loss": validation_loss,
                "avg_val_loss": avg_validation_loss,
            })

    _show_examples(autoencoder, val_dataset, device)

    if WANDB_LOG:
        wandb.finish()
    print("Finished Training")