In [56]:
import torch 
import torch.nn.functional as F 
from torch import nn
from tqdm import tqdm
from torch.utils.data import DataLoader

import numpy as np
from sklearn.model_selection import train_test_split


Coded from scratch using https://www.youtube.com/watch?v=VELQT1-hILo&t=1395s

In [57]:
# Input wave --> hideen dim --> mean, std --> parametrization trick -> decoder -> output wave
class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        #encoder
        self.wave_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        #decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2wave = nn.Linear(h_dim, input_dim)

        self.relu = nn.ReLU()

    def encode(self, x):
        h = self.relu(self.wave_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
        return  mu, sigma

    def decode(self, z):
        h = self.relu(self.z_2hid(z))
        # no need to sigmoid here i think because we are not trying to limit the range from 0 to 1
        return self.hid_2wave(h)

    def forward(self, x):
        mu, sigma = self.encode(x)
        epsilon = torch.randn_like(sigma)
        z_new = mu + sigma * epsilon
        x_reconstructed = self.decode(z_new)
        return x_reconstructed, mu, epsilon


In [58]:
# testing
x = torch.randn(4, 200, 2)
# 4 represents the number of data segments, 200 is the number of data points in each segment, 2 represents 2 different channels
x_flat = x.view(x.size(0), -1) # reshape into (4, 200)
vae = VariationalAutoEncoder(input_dim=400)
x_reconstructed, mu, sigma = vae(x_flat)
print(x_reconstructed.shape)
print(mu.shape)
print(sigma.shape)

torch.Size([4, 400])
torch.Size([4, 20])
torch.Size([4, 20])


In [59]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_dim = 400
h_dim = 200
z_dim = 20 # how much compression
num_epochs = 10
batch_size = 32 # when we train the NN, we only put in this number of samples into the network each time, we could increase the batch size 
lr_rate = 3e-5   

In [60]:
torch.cuda.is_available()

True

In [61]:
# dataset loading
background = np.load('../data/background.npz')['data']
bbh = np.load('../data/bbh_for_challenge.npy')
sglf = np.load('../data/sglf_for_challenge.npy')

def normalize(data):
    stds = np.std(data, axis=-1, keepdims=True)
    return data / stds

background = normalize(background)
x_flat = x.view(x.size(0), -1)
bbh = normalize(bbh)
sglf = normalize(sglf)

x_train, x_test = train_test_split(background, test_size=0.2, random_state=42)

x_train = torch.tensor(x_train, dtype=torch.float32)
x_test = torch.tensor(x_test, dtype=torch.float32)
bbh = torch.tensor(bbh, dtype=torch.float32)
sglf = torch.tensor(sglf, dtype=torch.float32)

x_train = x_train.view(x_train.size(0), -1)
x_test = x_test.view(x_test.size(0), -1)
bbh = bbh.view(bbh.size(0), -1)
sglf = sglf.view(sglf.size(0), -1)

train_loader = DataLoader(dataset=x_train, batch_size=batch_size)

model = VariationalAutoEncoder(input_dim, h_dim, z_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate, weight_decay=1e-5)
loss_fn = nn.MSELoss(reduction="sum")

In [62]:
for epoch in range(num_epochs):
    loop = tqdm(enumerate(train_loader))
    for i, x in loop:
        # forward pass
        x = x.to(device).view(x.shape[0], input_dim) # keep the inital sample size of 100,000 but reshape into the input dim of 400
        x_reconstructed, mu, sigma = model(x)

        # compute loss
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

        # backprop
        loss = reconstruction_loss + kl_div # could add an alpha term to make it a disentangled autoencoder
        optimizer.zero_grad() # ensure no gradients before
        loss.backward() 
        optimizer.step() # compute new gradients
        loop.set_postfix(loss=loss.item())



2500it [00:06, 368.39it/s, loss=1.32e+4]
2500it [00:07, 356.34it/s, loss=1.29e+4]
2500it [00:06, 363.50it/s, loss=1.27e+4]
2500it [00:06, 385.08it/s, loss=1.26e+4]
2500it [00:06, 380.73it/s, loss=1.25e+4]
2500it [00:06, 381.71it/s, loss=1.24e+4]
2500it [00:06, 398.50it/s, loss=1.25e+4]
2500it [00:06, 376.85it/s, loss=1.24e+4]
2500it [00:06, 377.87it/s, loss=1.23e+4]
2500it [00:06, 376.63it/s, loss=1.24e+4]


In [63]:
def compute_reconstruction_errors(model, data):
    model.eval()
    data = data.to(device) 
    num_samples = data.size(0)
    with torch.no_grad():
        data = data.view(data.size(0), -1)
        reconstructed, mu, sigma = model(data)
        reconstruction_loss = loss_fn(reconstructed, data) / num_samples
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2)) / num_samples
        total_loss = reconstruction_loss + kl_div
    return reconstruction_loss.item(), kl_div.item(), total_loss.item()


In [64]:
x_test = x_test.to(device)
sglf = sglf.to(device)
bbh = bbh.to(device)

In [65]:
test_rec_loss, test_kl_loss, test_total_loss = compute_reconstruction_errors(model, x_test)
sglf_rec_loss, sglf_kl_loss, sglf_total_loss = compute_reconstruction_errors(model, sglf)
bbh_rec_loss, bbh_kl_loss, bbh_total_loss = compute_reconstruction_errors(model, bbh)

print(f"Test Set - Reconstruction Loss: {test_rec_loss}, KL Divergence: {test_kl_loss}, Total Loss: {test_total_loss}")
print(f"SGLF Set - Reconstruction Loss: {sglf_rec_loss}, KL Divergence: {sglf_kl_loss}, Total Loss: {sglf_total_loss}")
print(f"BBH Set - Reconstruction Loss: {bbh_rec_loss}, KL Divergence: {bbh_kl_loss}, Total Loss: {bbh_total_loss}")

Test Set - Reconstruction Loss: 360.8668518066406, KL Divergence: 27.032485961914062, Total Loss: 387.89935302734375
SGLF Set - Reconstruction Loss: 355.3948059082031, KL Divergence: 27.199203491210938, Total Loss: 382.593994140625
BBH Set - Reconstruction Loss: 370.8006591796875, KL Divergence: 26.72834014892578, Total Loss: 397.52899169921875


Ideally the loss should be significantly higher for the test set compared to the sglf and bbh sets