In [3]:
import pandas as pd
import numpy as np
import torch.nn
import torchvision
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms

### A Linear Gaussian Latent Variable Model
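Latent variable $z\in \mathbb{R}^n$, observation $x\in \mathbb{R}^D$.

Prior: $p(z)=\mathcal{N}(y,\,\omega\mathbb{I})$

Likelihood: $x=Wz+b+\epsilon$, with $\epsilon\sim \mathcal{N}(u,\,\sigma\mathbb{I})$, $W\in\mathbb{R}^{D\times n}$ and $b,u\in\mathbb{R}^D$.

(Note: in the code below, `ome` and `sig` scale a standard-normal draw, i.e. they act as standard deviations; this coincides with the variances $\omega,\sigma$ above only because both are set to 1.)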


In [60]:
class Encoder(nn.Module):
    """Gaussian encoder q(z | x) for the VAE.

    The wrapped network's output is split in half along the last dimension:
    the first half is the mean ``mu`` and the second half ``va`` is used as a
    log-variance (``std = exp(0.5 * va)``).

    Args:
        encoder_net: module mapping ``x`` to a tensor whose last dimension is
            ``2 * n`` (e.g. ``nn.Linear(D, 2 * n)``).
    """

    # log(2*pi) as a plain Python float so it broadcasts with any tensor dtype.
    _LOG_2PI = float(np.log(2.0 * np.pi))

    def __init__(self, encoder_net):
        # `super(self)` raised TypeError; nn.Module must be initialised before
        # submodules are assigned to it.
        super(Encoder, self).__init__()
        self.encoder = encoder_net

    @staticmethod
    def reparametrization(mu, va):
        """Return mu + exp(0.5 * va) * eps with eps ~ N(0, I).

        Static because it uses no instance state; previously it lacked ``self``,
        so calling it through an instance passed one argument too many.
        """
        std = torch.exp(0.5 * va)
        return mu + std * torch.randn_like(std)

    def forward(self, x):
        """Return ``(mu, va)`` computed from ``x``."""
        h_e = self.encoder(x)
        # Split along the feature (last) dim so batched input (B, 2n) works;
        # the default dim=0 would split the batch instead.
        mu, va = torch.chunk(h_e, 2, dim=-1)
        return mu, va

    def encode(self, x):
        """Return the raw, unsplit output of the encoder network."""
        return self.encoder(x)

    def sample(self, mu, va):
        """Sample z from N(mu, exp(va)) using the reparametrisation trick."""
        return self.reparametrization(mu, va)

    def log_p(self, mu, va, z):
        """Element-wise log N(z; mu, exp(va)).

        Returns a tensor with the broadcast shape of ``mu``, ``va`` and ``z``;
        sum over the last dim to obtain log q(z | x) per sample.
        """
        # The normalising constant is -0.5*log(2*pi) per element; the old
        # `z.shape[0]` factor over-counted it (and `torch.PI` does not exist).
        return -0.5 * self._LOG_2PI - 0.5 * va - 0.5 * torch.exp(-va) * (z - mu) ** 2

    def log_prob(self, mu, va, z):
        """Alias of ``log_p`` matching the ``log_prob`` name used by Decoder/Prior."""
        return self.log_p(mu, va, z)

In [7]:
class Decoder(nn.Module):
    """Gaussian decoder p(x | z) = N(x; decoder_net(z), sigma^2 I).

    Args:
        decoder_net: module mapping a latent z to the mean of x.
        sigma: standard deviation of the observation noise (default 1.0).
            NOTE(review): treated as a std, as ``sample_x`` does; the markdown
            writes the noise covariance as sigma*I — confirm the intent.
    """

    # log(2*pi) as a plain Python float so it broadcasts with any tensor dtype.
    _LOG_2PI = float(np.log(2.0 * np.pi))

    def __init__(self, decoder_net, sigma=1.0):
        # `super(self)` raised TypeError; nn.Module must be initialised first.
        super(Decoder, self).__init__()
        self.decoder = decoder_net
        self.sigma = sigma

    def forward(self, z):
        """Return the decoder network's output (the mean of x) for ``z``."""
        return self.decoder(z)

    def decode(self, z):
        """Return the decoder network's output (the mean of x) for ``z``."""
        return self.decoder(z)

    def log_prob(self, x, z):
        """Return log N(x; decode(z), sigma^2 I), summed over the last dimension.

        The previous placeholder returned the constant 1, ignoring both ``x``
        and the decoder output, so the reconstruction term had no gradient.
        """
        mean = self.decode(z)
        sigma = torch.as_tensor(self.sigma, dtype=mean.dtype, device=mean.device)
        log_p = -0.5 * self._LOG_2PI - torch.log(sigma) - 0.5 * ((x - mean) / sigma) ** 2
        # One value per sample, matching the `.sum(-1)` applied to the KL term.
        return log_p.sum(-1)


In [59]:
class Prior(nn.Module):
    """Isotropic Gaussian prior over the latent z, with mean ``y`` and scale ``ome``.

    Args:
        n: latent dimensionality.
        y: prior mean; presumably a tensor of shape (n,) or a scalar.
        ome: prior scale, used as a standard deviation by both ``sample`` and
            ``log_prob``. NOTE(review): the markdown writes N(y, omega*I), i.e.
            omega as a variance; the two agree only for ome == 1 — confirm.
    """

    # log(2*pi) as a plain Python float so it broadcasts with any tensor dtype.
    _LOG_2PI = float(np.log(2.0 * np.pi))

    def __init__(self, n, y, ome):
        super(Prior, self).__init__()
        self.n = n
        self.y = y
        self.ome = ome

    def forward(self):
        pass

    def sample(self, batch_size=None):
        """Draw z = ome * eps + y with eps ~ N(0, I).

        Returns shape (n,) by default, or (batch_size, n) when ``batch_size``
        is given.
        """
        shape = (self.n,) if batch_size is None else (batch_size, self.n)
        return self.ome * torch.randn(shape) + self.y

    def log_prob(self, z):
        """Element-wise log-density of ``z`` under the distribution ``sample`` draws from.

        The previous version ignored ``y`` and ``ome`` (standard normal), scaled
        the constant by ``z.shape[0]``, and referenced the nonexistent ``torch.PI``.
        """
        ome = torch.as_tensor(self.ome, dtype=z.dtype, device=z.device)
        return -0.5 * self._LOG_2PI - torch.log(ome) - 0.5 * ((z - self.y) / ome) ** 2

In [57]:
class VAE(nn.Module):
    """Variational auto-encoder whose forward pass returns the negative ELBO.

    Args:
        encoder_net: network wrapped by ``Encoder`` (output last dim 2 * n).
        decoder_net: network wrapped by ``Decoder``.
        n, y, ome: latent size, prior mean and prior scale passed to ``Prior``.
    """

    def __init__(self, encoder_net, decoder_net, n, y, ome):
        super(VAE, self).__init__()
        self.encoder = Encoder(encoder_net)
        self.decoder = Decoder(decoder_net)
        self.prior = Prior(n, y, ome)

    def forward(self, x):
        """Return the single-sample estimate of -ELBO, averaged over the batch."""
        # Encoder.forward splits the network output into (mu, va); `encode`
        # returns the raw tensor, so unpacking it would split the batch dim.
        mu, va = self.encoder(x)
        z = self.encoder.sample(mu, va)
        RE = self.decoder.log_prob(x, z)
        # KL(q || p) estimated at the sampled z as log q(z|x) - log p(z).
        # The encoder's density method is `log_p`; `log_prob` did not exist on it.
        KL = (self.encoder.log_p(mu, va, z) - self.prior.log_prob(z)).sum(-1)
        return -(RE - KL).mean()

In [50]:
# Configuration of the linear Gaussian model and the networks.
SEED = 0
torch.manual_seed(SEED)  # makes W, b, network init and later sampling reproducible

D = 100  # observation dimension
n = 4    # latent dimension
y = torch.zeros(n); ome = 1  # prior mean and scale
# The noise mean lives in observation space: sample_x adds u to a D-vector,
# so a length-n u caused a shape mismatch.
u = torch.zeros(D); sig = 1
W = torch.randn(size=(D, n)); b = torch.randn(D)
encoder = nn.Sequential(nn.Linear(D, 2 * n))
decoder = nn.Sequential(nn.Linear(n, D))

In [51]:
def sample_x(z, W, b, u, sig):
    """Draw x = W z + b + eps with eps = sig * N(0, I) + u.

    Args:
        z: latent vector of shape (n,), or a batch of shape (..., n).
        W: weight matrix of shape (D, n).
        b: offset of shape (D,).
        u: noise mean, broadcastable to (..., D).
        sig: noise scale (multiplies a standard-normal draw, i.e. a std).

    Returns:
        Tensor of shape (..., D).
    """
    # Infer D from W instead of reading the global `D`; `z @ W.T` also handles
    # batched latents and equals `W @ z` for a 1-D z.
    mean = z @ W.T + b
    eps = sig * torch.randn_like(mean) + u
    return mean + eps