This model will still be a variational autoencoder, but I will use a CNN within the architecture. 

In [172]:
import csv
import os

import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm

In [173]:
class VariationalAutoEncoder(nn.Module):
    """1-D convolutional variational autoencoder.

    The encoder is three stride-2 ``Conv1d`` layers (channels
    ``input_channels -> 16 -> 32 -> 64``), followed by two linear heads that
    produce ``mu`` and ``sigma`` of size ``z_dim``. The decoder maps a latent
    vector back through a linear layer and three stride-2 ``ConvTranspose1d``
    layers that mirror the encoder.

    Args:
        input_channels: Number of channels of the input signal.
        input_length: Number of samples per channel. The size of the
            flattened bottleneck is derived from this value, so lengths other
            than 200 are supported.
        h_dim: Unused by any layer; kept so existing calls that pass it
            keep working.
        z_dim: Dimensionality of the latent code.

    Raises:
        ValueError: If ``input_length`` is smaller than 1.
    """

    def __init__(self, input_channels=2, input_length=200, h_dim=128, z_dim=20):
        super().__init__()
        if input_length < 1:
            raise ValueError(f"input_length must be >= 1, got {input_length}")

        # Each Conv1d(kernel_size=3, stride=2, padding=1) maps a length L to
        # floor((L - 1) / 2) + 1; apply that three times to get the length of
        # the final feature map (25 for the default input_length=200).
        encoded_length = input_length
        for _ in range(3):
            encoded_length = (encoded_length - 1) // 2 + 1
        self.input_length = input_length
        self.encoded_length = encoded_length
        flat_dim = 64 * encoded_length

        # encoder
        self.encoder = nn.Sequential(
            nn.Conv1d(input_channels, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.flatten = nn.Flatten()

        # code
        self.fc_mu = nn.Linear(flat_dim, z_dim)
        self.fc_sigma = nn.Linear(flat_dim, z_dim)
        self.fc_z = nn.Linear(z_dim, flat_dim)

        # decoder: with output_padding=1 every transposed conv exactly doubles
        # the length, so the raw output has length 8 * encoded_length.
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(16, input_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
        )

    def encode(self, x):
        """Return ``(mu, sigma)``, each of shape ``(N, z_dim)``, for a batch ``x``.

        ``sigma`` is the raw output of a linear layer, so it is unconstrained
        and may be negative.
        """
        h = self.encoder(x)
        h_flat = self.flatten(h)
        mu = self.fc_mu(h_flat)
        sigma = self.fc_sigma(h_flat)
        return mu, sigma

    def decode(self, z):
        """Map latent vectors ``z`` of shape ``(N, z_dim)`` back to signal space."""
        h = self.fc_z(z)
        h = h.view(h.size(0), 64, self.encoded_length)
        x_reconstructed = self.decoder(h)
        # When input_length is not a multiple of 8 the decoder overshoots;
        # drop the trailing samples (a no-op for the default length of 200).
        return x_reconstructed[..., :self.input_length]

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            ``(x_reconstructed, mu, sigma)``. The latent code is sampled as
            ``mu + sigma * epsilon`` with ``epsilon ~ N(0, I)``, so the
            reconstruction is stochastic.
        """
        mu, sigma = self.encode(x)
        epsilon = torch.randn_like(sigma)
        z_new = mu + sigma * epsilon
        x_reconstructed = self.decode(z_new)
        return x_reconstructed, mu, sigma

In [174]:
# Smoke test on random data: a batch of 4 two-channel signals of length 200
# should come back with the same shape, plus one (mu, sigma) pair per sample.
x = torch.randn(4, 2, 200)
vae = VariationalAutoEncoder()
x_reconstructed, mu, sigma = vae(x)
print(x_reconstructed.shape, mu.shape, sigma.shape, sep="\n")

torch.Size([4, 2, 200])
torch.Size([4, 20])
torch.Size([4, 20])


In [175]:
# Run configuration: everything a reader might want to tune lives here.

# Seed torch's random number generators so that weight initialisation and the
# noise sampled in the VAE forward pass are repeatable from a fresh kernel.
seed = 42
torch.manual_seed(seed)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

z_dim = 40           # size of the latent code
num_epochs = 10      # passes over the training set
batch_size = 32
lr_rate = 3e-5       # Adam learning rate
weight_decay = None  # None -> the optimizer is built without weight decay
device

device(type='cuda')

In [176]:
# dataset loading
# np.load on an .npz archive returns an NpzFile that keeps the underlying file
# open; the context manager closes it once the 'data' array has been read.
with np.load('../data/background.npz') as background_archive:
    background = background_archive['data']
bbh = np.load('../data/bbh_for_challenge.npy')
sglf = np.load('../data/sglf_for_challenge.npy')

def normalize(data):
    """Scale ``data`` to unit standard deviation along its last axis.

    The standard deviation is computed over the last axis only (the mean is
    not subtracted). Rows whose standard deviation is exactly zero are
    returned unscaled instead of being divided by zero, which would otherwise
    fill them with inf/NaN.

    Args:
        data: Array-like with at least one dimension.

    Returns:
        Array of the same shape as ``data``.
    """
    stds = np.std(data, axis=-1, keepdims=True)
    # Constant rows have std == 0; dividing by 1 leaves them unchanged.
    stds[stds == 0] = 1.0
    return data / stds

# Apply the same normalize() scaling to all three datasets.
background, bbh, sglf = (normalize(arr) for arr in (background, bbh, sglf))

# Hold out 20% of the background samples for evaluation (fixed random_state).
x_train, x_test = train_test_split(background, random_state=42, test_size=0.2)

# Convert everything the model will see to float32 tensors.
x_train, x_test, bbh, sglf = (
    torch.tensor(arr, dtype=torch.float32) for arr in (x_train, x_test, bbh, sglf)
)

# Batches are served in dataset order; shuffle is left at its default (False).
train_loader = DataLoader(x_train, batch_size=batch_size)

In [177]:
# model setup
model = VariationalAutoEncoder(z_dim=z_dim).to(device)

# Adam's weight_decay defaults to 0, so passing 0 when the config value is
# None builds the same optimizer as omitting the argument.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr_rate,
    weight_decay=0 if weight_decay is None else weight_decay,
)

# Summed (not averaged) squared error over every element of the batch.
loss_fn = nn.MSELoss(reduction="sum")


In [178]:
for epoch in range(num_epochs):
    # Make sure the model is in training mode even if an evaluation helper
    # switched it to eval mode before this cell was (re-)run.
    model.train()

    # Iterating the loader directly lets tqdm read its length and show a
    # proper progress bar with a total.
    loop = tqdm(train_loader, desc=f"epoch {epoch + 1}/{num_epochs}")
    for x in loop:
        # forward pass
        x = x.to(device)
        x_reconstructed, mu, sigma = model(x)

        # compute loss
        # Reconstruction term: loss_fn applied to the reconstruction and the input.
        reconstruction_loss = loss_fn(x_reconstructed, x)
        # Regulariser between N(mu, sigma^2) and N(0, I). Note there is no
        # 0.5 factor here, so this is twice the analytic KL divergence, i.e.
        # the KL term is effectively weighted by 2.
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

        # backprop
        loss = reconstruction_loss + kl_div  # could add an alpha term to make it a disentangled autoencoder
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()        # compute gradients of the loss w.r.t. the parameters
        optimizer.step()       # update the parameters using those gradients
        loop.set_postfix(loss=loss.item())



2500it [00:08, 289.29it/s, loss=1.32e+4]
2500it [00:09, 274.31it/s, loss=1.28e+4]
2500it [00:08, 286.99it/s, loss=1.28e+4]
2500it [00:08, 282.05it/s, loss=1.28e+4]
2500it [00:09, 277.32it/s, loss=1.28e+4]
2500it [00:08, 286.99it/s, loss=1.28e+4]
2500it [00:09, 270.56it/s, loss=1.28e+4]
2500it [00:08, 292.03it/s, loss=1.28e+4]
2500it [00:08, 278.99it/s, loss=1.28e+4]
2500it [00:09, 270.99it/s, loss=1.28e+4]


In [179]:
def compute_reconstruction_errors(model, data):
    """Evaluate per-sample average losses of ``model`` on ``data``.

    The whole of ``data`` is moved to the module-level ``device`` and pushed
    through the model in a single forward pass with gradients disabled. The
    model is put in eval mode for the duration of the call and its previous
    train/eval mode is restored afterwards, so calling this helper does not
    silently change how a later training cell behaves.

    Args:
        model: Module whose call returns ``(reconstructed, mu, sigma)``.
        data: Tensor whose first dimension indexes samples.

    Returns:
        Tuple of Python floats ``(reconstruction_loss, kl_div, total_loss)``,
        each divided by the number of samples. The reconstruction term uses
        the module-level ``loss_fn``; the KL term uses the same expression as
        the training loop (without a 0.5 factor).
    """
    was_training = model.training
    model.eval()
    try:
        data = data.to(device)
        num_samples = data.size(0)
        with torch.no_grad():
            reconstructed, mu, sigma = model(data)
            reconstruction_loss = loss_fn(reconstructed, data) / num_samples
            kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2)) / num_samples
            total_loss = reconstruction_loss + kl_div
        return reconstruction_loss.item(), kl_div.item(), total_loss.item()
    finally:
        # Put the model back into whichever mode the caller had it in.
        model.train(was_training)

In [180]:
# Move the three evaluation sets onto the compute device in one go.
x_test, sglf, bbh = (tensor.to(device) for tensor in (x_test, sglf, bbh))

In [181]:
# Score the held-out background split and the two signal datasets.
test_rec_loss, test_kl_loss, test_total_loss = compute_reconstruction_errors(model, x_test)
sglf_rec_loss, sglf_kl_loss, sglf_total_loss = compute_reconstruction_errors(model, sglf)
bbh_rec_loss, bbh_kl_loss, bbh_total_loss = compute_reconstruction_errors(model, bbh)

# One summary line per dataset.
print(
    f"Test Set - Reconstruction Loss: {test_rec_loss}, KL Divergence: {test_kl_loss}, Total Loss: {test_total_loss}",
    f"SGLF Set - Reconstruction Loss: {sglf_rec_loss}, KL Divergence: {sglf_kl_loss}, Total Loss: {sglf_total_loss}",
    f"BBH Set - Reconstruction Loss: {bbh_rec_loss}, KL Divergence: {bbh_kl_loss}, Total Loss: {bbh_total_loss}",
    sep="\n",
)

Test Set - Reconstruction Loss: 400.011962890625, KL Divergence: 0.1897244155406952, Total Loss: 400.2016906738281
SGLF Set - Reconstruction Loss: 399.98150634765625, KL Divergence: 0.17907898128032684, Total Loss: 400.16058349609375
BBH Set - Reconstruction Loss: 400.5328674316406, KL Divergence: 0.14685750007629395, Total Loss: 400.6797180175781


In [182]:
# Label under which this run is recorded in the results CSV.
# NOTE(review): the name ends in "_wd", but weight_decay is None in the config
# cell of this notebook — confirm the label matches the run being logged.
model_name = "VAE_increased_zdim_40_wd"

In [183]:
def log_model_results(model_name, z_dim, batch_size, epochs, lr_rate, weight_decay, test_rec_loss, test_kl_loss, test_total_loss, sglf_rec_loss, sglf_kl_loss, sglf_total_loss, bbh_rec_loss, bbh_kl_loss, bbh_total_loss, file_name="vae_model_results.csv"):
    """Append one run's hyperparameters and evaluation losses to a CSV log.

    Each call appends exactly one data row. A header row is written first when
    ``file_name`` does not exist yet or exists but is empty, so a log that was
    created empty still ends up with column names.

    Args:
        model_name: Label identifying the run.
        z_dim, batch_size, epochs, lr_rate, weight_decay: Hyperparameters of
            the run, written as given.
        test_rec_loss, test_kl_loss, test_total_loss: Losses on the test set.
        sglf_rec_loss, sglf_kl_loss, sglf_total_loss: Losses on the SGLF set.
        bbh_rec_loss, bbh_kl_loss, bbh_total_loss: Losses on the BBH set.
        file_name: Path of the CSV file to append to.
    """
    # Insertion order of this dict defines the column order of the CSV.
    log_entry = {
        "model_name": model_name,
        "z_dim": z_dim,
        "batch_size": batch_size,
        "epochs": epochs,
        "lr_rate": lr_rate,
        "weight_decay": weight_decay,
        "test_rec_loss": test_rec_loss,
        "test_kl_loss": test_kl_loss,
        "test_total_loss": test_total_loss,
        "sglf_rec_loss": sglf_rec_loss,
        "sglf_kl_loss": sglf_kl_loss,
        "sglf_total_loss": sglf_total_loss,
        "bbh_rec_loss": bbh_rec_loss,
        "bbh_kl_loss": bbh_kl_loss,
        "bbh_total_loss": bbh_total_loss,
    }

    # A missing file and an existing-but-empty file both need a header.
    needs_header = not os.path.exists(file_name) or os.path.getsize(file_name) == 0

    with open(file_name, mode='a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=list(log_entry))
        if needs_header:
            writer.writeheader()
        writer.writerow(log_entry)

In [184]:
# log_model_results(model_name, z_dim, batch_size, num_epochs, lr_rate, weight_decay, test_rec_loss, test_kl_loss, test_total_loss, sglf_rec_loss, sglf_kl_loss, sglf_total_loss, bbh_rec_loss, bbh_kl_loss, bbh_total_loss)