## Decision-Making-VAE Experiements 

In [8]:
import numpy as np
from scipy.linalg import sqrtm

import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
device = 'cuda' if torch.cuda.is_available() else 'cpu'

First experiment : pPCA
In the following code, latent_dim  will relate to the dimension of the latent space

In [9]:
#Generating data for pPCA

seed = 0
dim_z=10 # d
dim_x=100 # p
nu=0.5
n_samples=1000
np.random.seed(seed)
# Parameters

A = 1 / np.sqrt(dim_z) * np.random.normal(size=(dim_x, dim_z))
px_mean = np.zeros((dim_x,))

# conditional covariance

#Sigma_x_cond_z = nu * np.diag(np.random.normal(loc=1, scale=2, size=dim_x) ** 2)
gamma = nu * np.diag(np.random.normal(loc=1, scale=2, size=dim_x) ** 2)
#inv_Sigma_x_cond_z = np.linalg.inv(Sigma_x_cond_z)
inv_gamma = np.linalg.inv(gamma)
px_var = gamma + np.dot(A, A.T)
#posterior

inv_pz_condx_var = np.eye(dim_z) + np.dot(np.dot(A.T, inv_gamma),A)
pz_condx_var = np.linalg.inv(inv_pz_condx_var)
mz_cond_x_mean = np.dot(pz_condx_var, np.dot(A.T, inv_gamma))


covar_joint = np.block([[np.eye(dim_z), A.T], [A, px_var]])
pxz_log_det = np.log(np.linalg.det(covar_joint))
pxz_inv_sqrt = sqrtm(np.linalg.inv(covar_joint))

#generating data
data = np.random.multivariate_normal(px_mean, px_var, size=(n_samples,))


#generating data
#x = np.random.multivariate_normal(np.zeros((dim_x,)),Sigma_x_cond_z+np.dot(A,A.T),size=(n_samples,))
#z = np.random.multivariate_normal(np.zeros((dim_z)),np.eye((dim_z)),size=(n_samples,))


# posterior expression
#inv_Sigma_z_cond_x = np.eye(dim_z) + np.dot(np.dot(A.T, inv_Sigma_x_cond_z), A)
#print(inv_Sigma_z_cond_x.shape)
#Sigma_z_cond_x = np.linalg.inv(inv_Sigma_z_cond_x)

#Mz_cond_x = np.dot(Sigma_z_cond_x, np.dot(A.T, inv_Sigma_x_cond_z))
#z_cond_x = np.random.multivariate_normal(np.dot(Mz_cond_x,x),Sigma_z_cond_x,size=(n_samples,))

# other stuff we need 
#covar_joint = np.block([[np.eye(dim_z), A.T], [A, Sigma_x_cond_z+np.dot(A,A.T)]])
#pxz_log_det = np.log(np.linalg.det(covar_joint))
#pxz_inv_sqrt = sqrtm(np.linalg.inv(covar_joint))


In [10]:
# Variational AutoEncoder structure

latent_dims = dim_z

class VariationalEncoder(nn.Module):
    def __init__(self, latent_dims):
        super(VariationalEncoder, self).__init__()

        self.linear1 = nn.Linear(dim_x, dim_x//2)
        self.linear2 = nn.Linear(dim_x//2, latent_dims)
        self.linear3 = nn.Linear(dim_x//2, latent_dims)
        self.mu =0
        self.var = 0

        self.N = torch.distributions.Normal(0, 1)
        self.N.loc = self.N.loc.cuda() # hack to get sampling on the GPU
        self.N.scale = self.N.scale.cuda()
        self.kl = 0

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        mu =  self.linear2(x)
        sigma = torch.exp(self.linear3(x))
        z = mu + sigma*self.N.sample(mu.shape)
        # N(0, I)  -> p(z)
        # KL(N(mu,sigma);N(0,I)) = sum(sigma²+mu²-log sigma - 1/2)
        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum() # KL(q,p(z))
        self.mu = mu
        self.var =sigma
        return z

class Decoder(nn.Module):
    def __init__(self, latent_dims):
        super(Decoder, self).__init__()
        
        self.linear1 = nn.Linear(latent_dims, dim_x//2)
        self.linear2 = nn.Linear(dim_x//2, dim_x)

    def forward(self, z):
        z = F.relu(self.linear1(z))
        z = torch.sigmoid(self.linear2(z))
        return z.reshape((-1, 1, dim_x)) 
           
class VariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims):
        super(VariationalAutoencoder, self).__init__()

        self.encoder = VariationalEncoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        """
        Encodes the input by passing through the encoder network and returns the latent codes.
        """
        z = self.encoder(x)
        return self.decoder(z)

    def get_distribution(self):
      return(self.encoder.mu,self.encoder.var)

In [11]:
def train(autoencoder, data, epochs=20):
    opt = torch.optim.Adam(autoencoder.parameters())
    for epoch in range(epochs):
        for x in data:
            x = x.float().to(device) # GPU
            opt.zero_grad()
            x_hat = autoencoder(x)
            loss = ((x - x_hat)**2).sum() + autoencoder.encoder.kl
            #(sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum() + ((x - x_hat)**2).sum()
            loss.backward()
            opt.step()
    return autoencoder

In [12]:
data_loader = torch.utils.data.DataLoader(data,
                                          batch_size=1,
                                          shuffle=True)

In [13]:
vae = VariationalAutoencoder(latent_dims).to(device) # GPU
vae = train(vae, data_loader)

In [14]:
q_mu,q_sigma = vae.get_distribution()

In [15]:
print(q_mu.cpu().detach().numpy(),q_sigma.cpu().detach().numpy())

[[ 0.88169205  0.8570416   0.63243055 -0.3322851   1.4059932   1.0859292
   0.33169502  0.64419377 -1.6628016  -0.33515942]] [[0.5551765  0.25311744 0.35380572 0.2844929  0.22901307 0.2849016
  0.3074084  0.27310145 0.27291787 0.1789326 ]]


Plugin estimator for the paper ELBO toy example

In [16]:
threshold = 0.5

q_mu_a = q_mu.cpu().detach().numpy()[0]
q_var_a = q_sigma.cpu().detach().numpy()[0]
f = lambda z: z[0] > threshold 

def Q_plug(n,f):
  z_array = np.random.multivariate_normal(q_mu_a,np.diag(q_var_a),size=(n,))
  print(z_array.shape)
  return(np.sum(f(z_array)/n))

Q_plug(1000,f)

(1000, 10)


0.005

Some tries on the SNIS Estimator for pPCA

In [17]:
class IWVariationalEncoder(nn.Module):
    def __init__(self, latent_dims):
        super(IWVariationalEncoder, self).__init__()

        self.linear1 = nn.Linear(dim_x, dim_x//2)
        self.linear2 = nn.Linear(dim_x//2, latent_dims)
        self.linear3 = nn.Linear(dim_x//2, latent_dims)
        self.mu =0
        self.std = 0

        self.N = torch.distributions.Normal(0, 1)
        self.N.loc = self.N.loc.cuda() # hack to get sampling on the GPU
        self.N.scale = self.N.scale.cuda()
        self.kl = 0
    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        mu =  self.linear2(x)
        sigma = torch.sigmoid(self.linear3(x))
        z = mu + sigma*self.N.sample(mu.shape)
        # N(0, I)  -> p(z)
        # KL(N(mu,sigma);N(0,I)) = sum(sigma²+mu²-log sigma - 1/2)
        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum() # KL(q,p(z))
        self.mu = mu
        self.std =sigma
        return z


class IWVariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims):
        super(IWVariationalAutoencoder, self).__init__()
        
        self.encoder = IWVariationalEncoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def loss(self,x, k):
      """
        Computes the IWVAE loss functionn using the IWELBO
      """
      z= self.encoder(x)
      z_mu, z_std = self.encoder.mu, self.encoder.std
      x_hat = self.decoder(z)
      x_hat = torch.sigmoid(torch.flatten(x_hat, start_dim=1))

      log_p_z = torch.distributions.Normal(0, 1).log_prob(z).sum(dim=-1)
      log_p_xGz = torch.distributions.Bernoulli(x_hat).log_prob(x).sum(dim=-1)
      log_q_zGx = torch.distributions.Normal(z_mu, z_std).log_prob(z).sum(dim=-1)

      log_weights = log_p_z + log_p_xGz - log_q_zGx
      elbo = log_weights.logsumexp(dim=0) - np.log(k)

      return -elbo.mean()

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

In [18]:
def IWtrain(autoencoder, data, epochs=20):
    opt = torch.optim.Adam(autoencoder.parameters())
    for epoch in range(epochs):
        for x in data:
            x = x.float().to(device) # GPU
            opt.zero_grad()
            x_hat = autoencoder(x)
            loss = autoencoder.loss(x,autoencoder.encoder.kl)
            loss.backward()
            opt.step()
    return autoencoder

In [19]:
vae = IWVariationalAutoencoder(latent_dims).to(device) # GPU
vae = IWtrain(vae, data_loader)

ValueError: ignored

-----