In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import tensorflow as tf
from tensorflow.python.client import device_lib
from scipy.io import wavfile
from torch.utils.data import Dataset
from tensorflow.keras import layers, Model
from tensorflow.keras.layers import Conv1D, Conv2D, MaxPooling2D, Flatten, Dense, Input, Dropout, BatchNormalization, Reshape, Conv1DTranspose, UpSampling1D, UpSampling2D, MaxPooling1D, UpSampling1D, Layer, Embedding, Add, Multiply
import torchaudio
from sklearn.model_selection import train_test_split
import IPython.display as ipd
import matplotlib.pyplot as plt
import wandb
from keras.models import Model
from keras.layers import Input, Dense, Flatten, Reshape, Concatenate
from keras.layers import Conv1D, Conv2D, MaxPooling2D, LeakyReLU
from keras.layers import BatchNormalization, Dropout
from keras.optimizers import Adam
from keras import backend as K
import gc
import time

# Training hyper-parameters and audio framing constants shared by the cells below.
EPOCHS = 50  # number of passes over the training split
BATCH_SIZE = 50  # number of frames stacked into one optimizer step
TRAINING_DATA_AMOUNT = 100000  # upper bound on how many wav files are used
SAMPLE_RATE = 24000  # Hz; presumably the sample rate of the corpus files -- TODO confirm
FRAME_LENGTH = int(SAMPLE_RATE * 0.02)  # samples in a 20 ms frame (480 at 24 kHz)
LEARNING_RATE = 0.001  # learning rate passed to the Adam optimizer


class DataSetLoader():
    """Collect ``.wav`` file paths from a directory tree and read them by index.

    Attributes are plain strings/lists:
      * ``dev_data_set_path`` / ``train_data_set_path`` / ``test_data_set_path``:
        root directories of the respective splits.
      * ``file_names``: paths gathered so far by :meth:`load_data_set`
        (accumulates across calls).
    """

    def __init__(self,
                dev_data_set_path='dev-clean/',
                train_data_set_path='train-clean-360/',
                test_data_set_path='test-clean/'):
        """Store the split directories and start with an empty file list."""
        self.dev_data_set_path = dev_data_set_path
        self.train_data_set_path = train_data_set_path
        self.test_data_set_path = test_data_set_path
        self.file_names = []

    def load_data_set(self, data_set_path):
        """Walk ``data_set_path`` recursively and record every ``.wav`` file.

        Args:
            data_set_path: Root directory to scan.

        Returns:
            The instance's ``file_names`` list (the same list object, which
            also contains paths found by earlier calls).
        """
        for subdir, _dirs, files in os.walk(data_set_path):
            for file in files:
                # splitext looks at the real extension, so names without a dot
                # no longer raise IndexError and names containing extra dots
                # (e.g. "1001_a.b.wav") are no longer skipped.
                if os.path.splitext(file)[1] == ".wav":
                    self.file_names.append(os.path.join(subdir, file))

        print("Loaded in: \n- wave_files:", len(self.file_names))

        return self.file_names

    @staticmethod
    def sample_to_frames(sample_data, frame_length=FRAME_LENGTH, frame_step=SAMPLE_RATE):
        """Slice ``sample_data`` into zero-padded frames with ``tf.signal.frame``.

        NOTE(review): the default ``frame_step`` is SAMPLE_RATE, which is much
        larger than the default ``frame_length``; with the defaults the frames
        are not contiguous. Confirm this is intended before relying on it.
        """
        return tf.signal.frame(sample_data, frame_length=frame_length, frame_step=frame_step, pad_end=True, pad_value=0)

    def __getitem__(self, index):
        """Return ``scipy.io.wavfile.read`` of the ``index``-th collected file."""
        return wavfile.read(self.file_names[index])

# Gather the training file paths and cap them at the configured amount.
dataset_loader = DataSetLoader()
dataset_files = dataset_loader.load_data_set(dataset_loader.train_data_set_path)
dataset_files = dataset_files[:TRAINING_DATA_AMOUNT]

# Size the 80/20 split from the number of files actually found: random_split
# requires the lengths to sum to len(dataset), which the fixed
# TRAINING_DATA_AMOUNT-based sizes only satisfy when exactly that many files exist.
train_size = int(len(dataset_files) * 0.8)
val_size = len(dataset_files) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset_files, [train_size, val_size])



In [None]:
import matplotlib.pyplot as plt
import IPython.display as ipd

# Sweep the packet-loss rate; each rate is tracked as its own WandB run
for PACKET_LOSS_PERCENTAGE in (0.8, 0.6, 0.4, 0.2, 0.0):
    # Hyper-parameters recorded alongside the run
    run_config = {
        "architecture": "AutoEncoder",
        "dataset": "LibriTTS Corpus",
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "training_data_amount": TRAINING_DATA_AMOUNT,
        "learning_rate": LEARNING_RATE,
        "packet_loss_percentage": PACKET_LOSS_PERCENTAGE,
        "model_complexity": "100percent latent",
        "frame_length": FRAME_LENGTH,
    }
    wandb.init(project="FINAL_AE100_SWEEPING_PACKET_LOSS_ZEROS", config=run_config)

    # Define the autoencoder model
    class MyAutoEncoder(nn.Module):
        """1-D convolutional autoencoder with simulated packet loss on the latent code.

        The encoder applies three Conv1d + ReLU + MaxPool1d(2) stages
        (1 -> 32 -> 16 -> 8 channels, length divided by 8); the decoder
        mirrors it with three stride-2 ConvTranspose1d layers and a final Tanh.

        Args:
            packet_loss_percentage: Probability in [0, 1] that a latent
                channel ("packet") of a sample is dropped. When ``None``
                (the default) the module-level ``PACKET_LOSS_PERCENTAGE`` is
                read on every forward pass, as before.
            lost_packet_value: Value written into every element of a dropped
                packet. Defaults to -1, the value the original code used.
        """

        def __init__(self, packet_loss_percentage=None, lost_packet_value=-1.0):
            super().__init__()
            self.packet_loss_percentage = packet_loss_percentage
            self.lost_packet_value = lost_packet_value
            # Define the layers of the autoencoder
            self.encoder = nn.Sequential(
                nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=2, stride=2),
                nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=2, stride=2),
                nn.Conv1d(in_channels=16, out_channels=8, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=2, stride=2)
            )
            self.decoder = nn.Sequential(
                nn.ConvTranspose1d(in_channels=8, out_channels=16, kernel_size=2, stride=2),
                nn.ReLU(),
                nn.ConvTranspose1d(in_channels=16, out_channels=32, kernel_size=2, stride=2),
                nn.ReLU(),
                nn.ConvTranspose1d(in_channels=32, out_channels=1, kernel_size=2, stride=2),
                nn.Tanh() # Maps the output to [-1, 1]; the input is assumed to be in this range too
            )

        def forward(self, x):
            """Encode ``x``, drop random latent packets, and decode.

            Args:
                x: Tensor whose last dimension is time and whose
                    second-to-last dimension is the single input channel,
                    e.g. ``(batch, 1, samples)``.

            Returns:
                The reconstruction produced by the decoder.
            """
            # Pass input through encoder
            x = self.encoder(x)

            # Simulate packet loss at the defined packet loss percentage.
            # The previous per-packet loop only rebound a local name and never
            # modified ``x``; a mask applied to the tensor actually drops packets.
            loss_rate = self.packet_loss_percentage
            if loss_rate is None:
                loss_rate = PACKET_LOSS_PERCENTAGE
            if loss_rate > 0:
                # One Bernoulli draw per (sample, channel), broadcast over time.
                lost = torch.rand(x.shape[:-1], device=x.device).unsqueeze(-1) < loss_rate
                x = x.masked_fill(lost, self.lost_packet_value)

            # Pass the output through the decoder
            x = self.decoder(x)
            return x

    # Pick the compute device: CUDA when present, CPU otherwise
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Build the model, then place it on the chosen device
    autoencoder = MyAutoEncoder()
    autoencoder = autoencoder.to(device)

    # Reconstruction objective: mean squared error
    criterion = nn.MSELoss()

    # Adam over every model parameter
    optimizer = optim.Adam(autoencoder.parameters(), lr=LEARNING_RATE)

    def plot_validation_data():
        """Plot and play the first five validation clips next to their reconstructions."""
        for sample_idx in range(5):
            waveform, sample_rate = torchaudio.load(val_dataset[sample_idx])
            # Add a leading batch dimension before feeding the model
            batched = waveform.unsqueeze(0).to(device)
            reconstruction = autoencoder(batched)

            original_np = batched.squeeze().detach().cpu().numpy()
            reconstructed_np = reconstruction.squeeze().detach().cpu().numpy()

            plt.figure(figsize=(10, 4))
            plt.plot(original_np, label="Input")
            plt.plot(reconstructed_np, label="Output")
            plt.legend()
            plt.show()

            for audio in (original_np, reconstructed_np):
                ipd.display(ipd.Audio(audio, rate=sample_rate))

    # Training loop: one pass over the training files per epoch, followed by
    # a gradient-free pass over the validation files.
    for epoch in range(EPOCHS):
        running_loss = 0.0
        frames_seen = 0
        for data in train_dataset:
            inputs, sr = torchaudio.load(data)
            # Cut the clip into consecutive FRAME_LENGTH-sample frames
            inputs = torch.split(inputs, FRAME_LENGTH, dim=1)

            # If the last input is not the correct size, remove it
            if inputs[-1].size(1) != FRAME_LENGTH:
                inputs = inputs[:-1]

            for i in range(0, len(inputs), BATCH_SIZE):
                # Gather the inputs for the batch
                inputs_batch = torch.stack(inputs[i:i+BATCH_SIZE]).to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                outputs = autoencoder(inputs_batch)

                # Compute the loss (target trimmed to the output length)
                loss = criterion(outputs, inputs_batch[:, :, :outputs.size(2)])

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                # Accumulate the per-frame loss sum and the frame count
                running_loss += loss.item() * inputs_batch.size(0)
                frames_seen += inputs_batch.size(0)

        # running_loss is a sum over frames, so normalise by the number of
        # frames (not by the number of files) to obtain a mean per-frame loss.
        training_loss = running_loss / max(frames_seen, 1)

        validation_loss = 0.0
        # No gradients are needed for validation; skipping them avoids
        # building an autograd graph over every full-length clip.
        with torch.no_grad():
            for data in val_dataset:
                inputs, sr = torchaudio.load(data)
                inputs = torch.unsqueeze(inputs, 0).to(device)
                outputs = autoencoder(inputs)
                loss = criterion(outputs, inputs[:, :, :outputs.size(2)])
                validation_loss += loss.item() * inputs.size(0)
        # Report the mean loss per validation clip instead of the raw sum
        validation_loss /= max(len(val_dataset), 1)
        wandb.log({"train_loss": training_loss, "val_loss": validation_loss, "epoch": epoch})

    # Plot and play up to 5 validation samples together with their reconstructions
    for i in range(min(5, len(val_dataset))):
        waveform, sr = torchaudio.load(val_dataset[i])
        # The model lives on `device`, so the input must be moved there too
        model_input = waveform.unsqueeze(0).to(device)
        with torch.no_grad():
            reconstruction = autoencoder(model_input)
            # Loss of this particular sample (previously a stale value left
            # over from the last validation step was printed)
            sample_loss = criterion(
                reconstruction, model_input[:, :, :reconstruction.size(2)]
            ).item()

        original_np = model_input.squeeze().cpu().numpy()
        reconstructed_np = reconstruction.squeeze().cpu().numpy()

        print(f"Packet Loss Percentage: {PACKET_LOSS_PERCENTAGE}")
        plt.figure(figsize=(10, 4))
        plt.plot(original_np)
        plt.plot(reconstructed_np)
        plt.show()

        ipd.display(ipd.Audio(original_np, rate=sr))
        ipd.display(ipd.Audio(reconstructed_np, rate=sr))
        print("Validation Loss: ", sample_loss)
        print("--------------------")
    
    # Calculating the inference time for a single frame
    frame, sr = torchaudio.load(val_dataset[0])
    # Keep the first FRAME_LENGTH samples, add a batch dimension and move the
    # frame to the model's device (required when running on CUDA)
    frame = frame[:, :FRAME_LENGTH].unsqueeze(0).to(device)
    with torch.no_grad():
        # CUDA kernels run asynchronously; synchronise so the timer brackets
        # the actual computation rather than just the kernel launch
        if device.type == "cuda":
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        output = autoencoder(frame)
        if device.type == "cuda":
            torch.cuda.synchronize()
        end_time = time.perf_counter()
    inference_time = end_time - start_time
    print("Inference Time: ", inference_time)


    # Close the WandB run belonging to this packet-loss setting
    wandb.finish()
    print("Finished Training")