In [4]:
import torch

In [5]:
class VAE(torch.nn.Module):
    """Variational autoencoder with one-hidden-layer MLP encoder and decoder.

    The encoder maps an input with ``d`` features to ``2 * latent_dim`` values,
    which ``forward`` splits into the posterior mean ``mu`` and log-variance
    ``log_var``. The decoder maps a latent sample back to ``d`` features.

    Parameters
    ----------
    d : int
        Number of input (and reconstructed output) features.
    latent_dim : int
        Dimensionality of the latent space.
    hidden_dim : int, default 100
        Width of the hidden layer in both the encoder and the decoder.
    """

    def __init__(self, d, latent_dim, hidden_dim=100):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(d, hidden_dim),
            torch.nn.ReLU(),
            # Output holds mu and log_var concatenated along the feature axis.
            torch.nn.Linear(hidden_dim, 2 * latent_dim),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, d),
        )

    def samplize(self, mu, log_var):
        """Draw z ~ N(mu, exp(log_var)) using the reparameterization trick.

        Writing z = mu + sigma * eps with eps ~ N(0, I) keeps the sample
        differentiable with respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        # Standard-normal noise (randn), not uniform (rand): the trick requires eps ~ N(0, I).
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns
        -------
        tuple of torch.Tensor
            ``(reconstruction, mu, log_var)``. ``mu`` and ``log_var`` have shape
            ``(N, latent_dim)``, where ``N`` is the number of rows obtained by
            flattening all leading dimensions of the encoder output.
        """
        # First latent_dim encoder outputs -> mu, next latent_dim -> log_var.
        stats = self.encoder(x).view(-1, 2, self.latent_dim)
        mu = stats[:, 0, :]
        log_var = stats[:, 1, :]
        z = self.samplize(mu, log_var)
        recon = self.decoder(z)
        return recon, mu, log_var

In [6]:
def VAE_loss(y, mu, log_var, x=None):
    """Negative ELBO for a Gaussian-latent VAE: reconstruction error + KL divergence.

    Parameters
    ----------
    y : torch.Tensor
        Reconstruction produced by the decoder (first output of ``VAE.forward``).
    mu, log_var : torch.Tensor
        Mean and log-variance of the approximate posterior, as returned by
        ``VAE.forward``.
    x : torch.Tensor
        Original input that ``y`` is compared against; must have the same shape
        as ``y``. It has a default only so the original 3-argument call still
        parses, but it is required.

    Returns
    -------
    torch.Tensor
        Scalar: summed squared reconstruction error plus KL(N(mu, sigma^2) || N(0, I)),
        both summed over batch and feature dimensions.

    Raises
    ------
    ValueError
        If ``x`` is not provided, since the reconstruction term needs a target.
    """
    if x is None:
        raise ValueError("VAE_loss needs the original input `x` to compute the reconstruction term")
    # Closed-form KL between N(mu, exp(log_var)) and N(0, 1), summed over all elements.
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    # Sum (not mean) so the reconstruction term is on the same scale as the summed KL term.
    mse = torch.nn.functional.mse_loss(y, x, reduction="sum")
    return kl + mse

In [9]:
# Seed so the random input, weight init and sampling noise are reproducible on re-run.
torch.manual_seed(0)
a = torch.randn((1, 10))
vae = VAE(10, 5)
recon, mu, log_var = vae(a)
# The decoder must reconstruct a tensor with the same shape as the input.
assert recon.shape == a.shape
recon

tensor([[ 0.1864, -0.0016, -0.2248,  0.0772, -0.2441, -0.2137,  0.3631, -0.3021,
         -0.4672, -0.2449]], grad_fn=<AddmmBackward0>)