In [None]:
import os
from shutil import copyfile

import torch
from torch import nn

In [None]:
tensor_size = 200  # product of all input dims; placeholder value — TODO set from the real data


class VariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder over flattened inputs.

    The encoder narrows the input through LeakyReLU-activated linear layers
    down to ``3*input_size//10`` features.  Two linear heads (``reparamMean``
    and ``reparamLogVarm``) map that to the mean and log-variance of a
    ``3*input_size//12``-dimensional latent Gaussian.  The decoder mirrors the
    encoder back to ``input_size`` features and ends in a sigmoid, so outputs
    lie in [0, 1].

    Args:
        input_size: size of the last input dimension.  Defaults to the
            module-level ``tensor_size``.

    Raises:
        ValueError: if ``input_size`` is too small for every layer to have at
            least one feature (``input_size < 4``).
    """

    def __init__(self, input_size: int = tensor_size):
        super().__init__()
        if input_size < 4:
            raise ValueError(f"input_size must be >= 4, got {input_size}")
        # Layer widths as fractions of the input size.  Integer division is
        # required because nn.Linear only accepts int feature counts.
        w4 = 3 * input_size // 4
        w6 = 3 * input_size // 6
        w8 = 3 * input_size // 8
        w10 = 3 * input_size // 10
        latent = 3 * input_size // 12

        self.encoder = nn.Sequential(
            nn.Linear(input_size, w4),
            nn.LeakyReLU(),
            nn.Linear(w4, w6),
            nn.LeakyReLU(),
            nn.Linear(w6, w8),
            nn.LeakyReLU(),
            nn.Linear(w8, w10),
        )
        # Separate heads for the latent mean and log-variance.
        self.reparamMean = nn.Linear(w10, latent)
        self.reparamLogVarm = nn.Linear(w10, latent)
        self.decoder = nn.Sequential(
            nn.Linear(latent, w10),
            nn.LeakyReLU(),
            nn.Linear(w10, w8),
            nn.LeakyReLU(),
            nn.Linear(w8, w6),
            nn.LeakyReLU(),
            nn.Linear(w6, w4),
            nn.LeakyReLU(),
            nn.Linear(w4, input_size),
            nn.Sigmoid(),
        )

    def reparam(self, mean: torch.Tensor, logvarm: torch.Tensor) -> torch.Tensor:
        """Sample a latent vector with the reparameterisation trick.

        In training mode this returns ``mean + eps * exp(0.5 * logvarm)`` with
        ``eps ~ N(0, I)``, which keeps the sample differentiable with respect
        to ``mean`` and ``logvarm``.  In eval mode it returns ``mean``
        unchanged, giving deterministic reconstructions.
        """
        if not self.training:
            return mean
        stdev = torch.exp(0.5 * logvarm)
        # eps must be *standard* normal; the mean/std are applied afterwards.
        eps = torch.randn_like(stdev)
        return mean + eps * stdev

    def forward(self, x: torch.Tensor):
        """Encode, sample the latent, and decode.

        Returns:
            ``(reconstruction, mean, logvarm)``.  The reconstruction has the
            same last-dimension size as ``x``; ``mean`` and ``logvarm`` are
            the latent Gaussian parameters needed for the KL term.
        """
        hidden = self.encoder(x)
        mean, logvarm = self.reparamMean(hidden), self.reparamLogVarm(hidden)
        z = self.reparam(mean, logvarm)
        return self.decoder(z), mean, logvarm

In [None]:
# Training configuration.  Seed before the network is built so that the
# initial weights are reproducible across Restart & Run All.
torch.manual_seed(42)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
resuming = False  # set to True by the training cell when a checkpoint is found
epochs = 400
learning_rate = 1e-2
loss_func = nn.MSELoss()  # reconstruction term of the VAE loss
net = VariationalAutoEncoder().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

In [None]:
def gaussian_kl(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)), averaged over all elements.

    Closed form per element: ``-0.5 * (1 + logvar - mean**2 - exp(logvar))``.
    The mean reduction (instead of summing over latent dims) keeps this term
    on the same scale as ``nn.MSELoss``'s default ``'mean'`` reduction.
    """
    return -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())


INTERIM_PATH = "models/interim_model.tar"
FINAL_PATH = "models/final_model.tar"
CHECKPOINT_EVERY = 100

torch.manual_seed(42)
os.makedirs("models", exist_ok=True)  # torch.save fails if the directory is missing

start_epoch = 0
if os.path.isfile(INTERIM_PATH):
    resuming = True
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load(INTERIM_PATH, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss1 = checkpoint['loss']
    # The checkpoint records the index of the last *completed* epoch, so
    # resume from the next one and continue up to `epochs` in total.
    start_epoch = epoch + 1

# NOTE(review): `x` (the training batch) is not defined in any visible cell;
# presumably a tensor whose last dim is tensor_size — TODO confirm.
inputs = x.to(device)
net.train()  # reparam() only samples in training mode

for it in range(start_epoch, epochs):
    prediction, mean, logvarm = net(inputs)
    ae_loss = loss_func(prediction, inputs)
    # https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
    kl_loss = gaussian_kl(mean, logvarm)
    loss = ae_loss + kl_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Checkpoint periodically and always on the final epoch, so the interim
    # file holds the finished weights when it is copied below.
    if it % CHECKPOINT_EVERY == CHECKPOINT_EVERY - 1 or it == epochs - 1:
        print("Epoch: ", it, "Loss: ", loss.item())
        torch.save({
            'epoch': it,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),  # plain float; avoids saving a graph-attached tensor
        }, INTERIM_PATH)

if os.path.isfile(INTERIM_PATH):
    copyfile(INTERIM_PATH, FINAL_PATH)
print('Done!')