In [1]:
from pytorch_regression_data_prep import get_diabetes_data

In [2]:
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from numba import jit

In [3]:
train_dataloader, test_dataloader, val_dataloader = get_diabetes_data(min_max=True)

In [5]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc_mu = nn.Linear(5, 2)
        self.fc_logvar = nn.Linear(5, 2)
        self.fc3 = nn.Linear(2, 5)
        self.fc4 = nn.Linear(5, 10)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc_mu(h1), self.fc_logvar(h1)

    def reparameterise(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))
    
    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 10))
        z = self.reparameterise(mu, logvar)
        return self.decode(z), mu, logvar
    

def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 10), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def train():
    model = VAE()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    history = []

    num_epochs = 10000
    for epoch in range(num_epochs):
        for i, (x, y) in enumerate(train_dataloader):
            recon_x, mu, logvar = model(x)
            loss = loss_function(recon_x, x, mu, logvar)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        history.append((epoch+1,loss.item()))
    
    return model, history

In [6]:
trained_model, history = train()

In [21]:
history

[(1, 175.55772399902344),
 (2, 173.3131561279297),
 (3, 171.38291931152344),
 (4, 172.4010467529297),
 (5, 172.21463012695312),
 (6, 171.1937255859375),
 (7, 171.3336944580078),
 (8, 171.3549346923828),
 (9, 170.6791534423828),
 (10, 171.61392211914062),
 (11, 171.34349060058594),
 (12, 170.67666625976562),
 (13, 169.8026123046875),
 (14, 171.24159240722656),
 (15, 170.62081909179688),
 (16, 168.72872924804688),
 (17, 168.5198974609375),
 (18, 168.61062622070312),
 (19, 168.51956176757812),
 (20, 168.93606567382812),
 (21, 168.50160217285156),
 (22, 167.17041015625),
 (23, 167.173583984375),
 (24, 167.83367919921875),
 (25, 167.06570434570312),
 (26, 167.01373291015625),
 (27, 165.90283203125),
 (28, 167.0020751953125),
 (29, 166.7766876220703),
 (30, 166.21458435058594),
 (31, 167.06085205078125),
 (32, 165.94984436035156),
 (33, 166.36083984375),
 (34, 166.57318115234375),
 (35, 165.64393615722656),
 (36, 166.78250122070312),
 (37, 164.82566833496094),
 (38, 164.87339782714844),
 (39

In [22]:
trained_model

VAE(
  (fc1): Linear(in_features=10, out_features=5, bias=True)
  (fc_mu): Linear(in_features=5, out_features=2, bias=True)
  (fc_logvar): Linear(in_features=5, out_features=2, bias=True)
  (fc3): Linear(in_features=2, out_features=5, bias=True)
  (fc4): Linear(in_features=5, out_features=10, bias=True)
)