In [9]:
import torch
import torch.nn as nn

from tqdm.notebook import tqdm
import os

from models import MLP
from helpers import model_to_list

### Variational Autoencoder to generate weights

In [4]:
def reparameterize(mean, var):
    """Sample ``mean + noise * std`` with standard-normal noise.

    Despite its name, ``var`` is used as a log-variance: the standard
    deviation is obtained as ``exp(0.5 * var)``.

    Args:
        mean: Tensor holding the location of the distribution.
        var: Tensor holding the log-variance (same role as ``logvar`` in
            most VAE write-ups).

    Returns:
        A tensor sampled around ``mean``; the noise is drawn with
        ``torch.randn_like`` so it matches the shape/dtype/device of the
        computed standard deviation.
    """
    half_log_var = 0.5 * var
    sigma = torch.exp(half_log_var)
    # Fresh N(0, 1) noise; keeping the randomness outside of `mean`/`sigma`
    # lets gradients flow through both.
    noise = torch.randn_like(sigma)
    scaled_noise = noise * sigma
    return mean + scaled_noise

In [7]:
class Encoder(nn.Module):
    """Feed-forward encoder with two linear output heads (``mean``, ``var``).

    The layer sizes are configurable; calling ``Encoder()`` with no
    arguments builds the original 151 -> 100 -> 50 trunk followed by two
    50 -> 10 heads, with unchanged ``state_dict`` keys.

    Args:
        input_dim: Size of each input vector.
        hidden_dims: Widths of the hidden layers, in order. Each hidden
            layer is a ``Linear`` followed by a ``ReLU``.
        latent_dim: Size of each of the two output heads.
    """

    def __init__(self, input_dim=151, hidden_dims=(100, 50), latent_dim=10):
        super(Encoder, self).__init__()
        trunk = []
        width = input_dim
        for hidden in hidden_dims:
            trunk.append(torch.nn.Linear(width, hidden))
            trunk.append(torch.nn.ReLU())
            width = hidden
        self.encoder = torch.nn.Sequential(*trunk)
        # Both heads read the same trunk features.
        self.fc_mean = torch.nn.Linear(width, latent_dim)
        self.fc_var = torch.nn.Linear(width, latent_dim)

    def forward(self, x):
        """Return ``(mean, var)`` for input ``x``.

        ``var`` is the raw output of a linear layer (it can be negative);
        ``reparameterize`` in this file exponentiates it, i.e. treats it as
        a log-variance.
        """
        x = self.encoder(x)
        mean = self.fc_mean(x)
        var = self.fc_var(x)

        return mean, var
    

class Decoder(nn.Module):
    """Feed-forward decoder mapping a latent vector back to input space.

    The layer sizes are configurable; calling ``Decoder()`` with no
    arguments builds the original 10 -> 50 -> 100 -> 151 network with
    unchanged ``state_dict`` keys.

    Args:
        latent_dim: Size of each latent vector fed to the decoder.
        hidden_dims: Widths of the hidden layers, in order. Each hidden
            layer is a ``Linear`` followed by a ``ReLU``.
        output_dim: Size of each reconstructed vector.
    """

    def __init__(self, latent_dim=10, hidden_dims=(50, 100), output_dim=151):
        super(Decoder, self).__init__()
        layers = []
        width = latent_dim
        for hidden in hidden_dims:
            layers.append(torch.nn.Linear(width, hidden))
            layers.append(torch.nn.ReLU())
            width = hidden
        layers.append(torch.nn.Linear(width, output_dim))
        # The final Sigmoid bounds every output to (0, 1).
        # NOTE(review): presumably the training targets are scaled into that
        # range -- confirm against the data pipeline.
        layers.append(torch.nn.Sigmoid())
        self.decoder = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Decode latent batch ``x`` into a reconstruction."""
        return self.decoder(x)

class Autoencoder(nn.Module):
    """Chain an encoder, a stochastic sampling step and a decoder.

    Args:
        encoder: Module returning a ``(mean, var)`` pair for an input batch.
        decoder: Module mapping a sampled latent batch to a reconstruction.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder):
        super().__init__()
        # Registered in this order so parameters are listed encoder-first.
        self.encoder, self.decoder = encoder, decoder

    def forward(self, x):
        """Return ``(reconstruction, mean, var)`` for input ``x``.

        A latent sample is drawn with ``reparameterize`` on every call, so
        the reconstruction is stochastic.
        """
        latent_stats = self.encoder(x)
        mean, var = latent_stats
        z = reparameterize(mean, var)
        return self.decoder(z), mean, var
        


In [8]:
def total_loss_function(x, x_hat, mean, var, kl_weight=0.0):
    """Compute the training loss for the autoencoder.

    The reconstruction term is the mean absolute error between ``x_hat``
    and ``x``. An optional KL-divergence term (against a standard normal)
    can be added by passing a non-zero ``kl_weight``; with the default of
    ``0.0`` only the reconstruction loss is returned.

    Args:
        x: Target tensor.
        x_hat: Reconstructed tensor, same shape as ``x``.
        mean: Latent means produced by the encoder.
        var: Latent log-variances produced by the encoder (exponentiated
            in the KL term).
        kl_weight: Multiplier for the KL term. Note the KL term is summed
            over all elements while the L1 term is averaged, so the weight
            also has to absorb that scale difference.

    Returns:
        A scalar tensor suitable for ``.backward()`` / ``.item()``.
    """
    # `nn.L1Loss(x, x_hat)` would pass the tensors to the module
    # *constructor* instead of computing a loss; use the functional form.
    reconstruction_loss = nn.functional.l1_loss(x_hat, x)
    if kl_weight == 0:
        # Skip the KL computation entirely so a non-finite KL value cannot
        # leak into the loss when the term is disabled.
        return reconstruction_loss
    kl_divergence = -0.5 * torch.sum(1 + var - mean.pow(2) - var.exp())
    return reconstruction_loss + kl_weight * kl_divergence

In [None]:
# Model: build the encoder/decoder pair and wrap them in the autoencoder.
encoder = Encoder()
decoder = Decoder()
autoencoder = Autoencoder(encoder, decoder)

# Hyperparameters
epochs = 100
learning_rate = 0.001
# Adam updates every parameter of the wrapped encoder and decoder.
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=learning_rate)

# Training
# NOTE(review): `dataloader` is not defined in any cell visible here; it must
# exist in the kernel before this cell is run.
# Assumes each batch is a single tensor of shape (batch, 151) matching the
# encoder's input layer -- TODO confirm.
for epoch in tqdm(range(epochs)):
    # Running sum of the per-batch losses for this epoch (not an average).
    total_loss = 0
    for parameters in dataloader:
        # Cast to float32 so the batch matches the model's parameter dtype.
        parameters = parameters.float()
        pred, mean, var = autoencoder(parameters)
        loss = total_loss_function(parameters, pred, mean, var)
        total_loss += loss.item()
        # Clear gradients from the previous step before back-propagating.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}: Loss: {total_loss}")