# Import Functions 

In [1]:
import math
import os
from collections import defaultdict
from functools import reduce
from typing import *

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torch import nn, Tensor
from torch.distributions import Distribution
from torch.distributions import Normal
from torch.nn.functional import softplus
from torch.nn.parameter import Parameter
from torchvision.transforms import ToTensor

In [2]:
# import nbimporter
from generating_synthetic_data import generate_3Z_synthetic_data
from generating_synthetic_data import ReparameterizedDiagonalGaussian
from generating_synthetic_data import plot_latent_2d, plot_y_dist, p

from base_VAE import VariationalInference, run_training

# Generate Data 

In [3]:
# Seed the RNGs up front so the synthetic data and everything downstream
# (weight init, reparameterized sampling) are reproducible on Restart & Run All.
# NOTE(review): seeds both numpy and torch because it is not visible here which
# RNG generate_3Z_synthetic_data draws from — confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

E = [0.2, 2, 3, 5]  # environmental factors
envs, Xs, train_loader, test_loader = generate_3Z_synthetic_data(E)
X = Xs[0]  # first entry of Xs (presumably the first environment), used for testing

# Neural Network to Reconstruct X using 1 channel


## Model

In [4]:
class VariationalAutoencoder(nn.Module):
    r"""A Variational Autoencoder with
    * a Gaussian observation model `p_\theta(x | z) = N(x | g_\theta(z), obs_scale^2 I)`
    * a Gaussian prior `p(z) = N(z | 0, I)`
    * a Gaussian posterior `q_\phi(z|x) = N(z | \mu(x), \sigma(x))`

    Args:
        input_shape: shape of a single observation, excluding the batch dimension.
        latent_features: dimensionality of the latent variable `z`.
        obs_scale: fixed standard deviation of the Gaussian observation model
            `p(x|z)`. Defaults to 0.1.
    """

    def __init__(self, input_shape: torch.Size, latent_features: int, obs_scale: float = 0.1) -> None:
        super().__init__()

        self.input_shape = input_shape
        self.latent_features = latent_features
        # np.prod returns a NumPy scalar; nn.Linear expects a plain Python int.
        self.observation_features = int(np.prod(input_shape))
        self.obs_scale = obs_scale

        # Inference network h_phi(x): maps a flattened observation to the posterior
        # parameters [mu(x), log(sigma(x))], hence 2*latent_features outputs.
        self.encoder = nn.Sequential(
            nn.Linear(in_features=self.observation_features, out_features=10),
            nn.ReLU(),
            nn.Linear(in_features=10, out_features=2 * latent_features),
        )

        # Generative network g_theta(z): maps a latent sample to the mean of p(x|z).
        # The observation scale is fixed (obs_scale), so only the mean is predicted.
        self.decoder = nn.Sequential(
            nn.Linear(in_features=latent_features, out_features=10),
            nn.ReLU(),
            nn.Linear(in_features=10, out_features=self.observation_features),
        )

        # Prior parameters [mu, log_sigma], all zeros, chosen as p(z) = N(0, I).
        # A buffer moves with the module across devices but is not trained.
        self.register_buffer('prior_params', torch.zeros(torch.Size([1, 2 * latent_features])))

    def posterior(self, x: Tensor) -> Distribution:
        """Return the approximate posterior `q(z|x) = N(z | mu(x), sigma(x))`.

        Args:
            x: observations, flattened to (batch, observation_features) as
                `forward` does before calling this.
        """
        h_x = self.encoder(x)
        # First half of the encoder output is the mean, second half the log-std.
        mu, log_sigma = h_x.chunk(2, dim=-1)
        return ReparameterizedDiagonalGaussian(mu, log_sigma)

    def prior(self, batch_size: int = 1) -> Distribution:
        """Return the prior `p(z)` with its parameters expanded to `batch_size` rows."""
        prior_params = self.prior_params.expand(batch_size, *self.prior_params.shape[-1:])
        mu, log_sigma = prior_params.chunk(2, dim=-1)
        return ReparameterizedDiagonalGaussian(mu, log_sigma)

    def observation_model(self, z: Tensor) -> Distribution:
        """Return `p(x|z) = N(x | g(z), obs_scale^2 I)` with mean shaped `(-1, *input_shape)`."""
        px_loc = self.decoder(z)
        px_loc = px_loc.view(-1, *self.input_shape)
        return Normal(px_loc, self.obs_scale)

    def forward(self, x: Tensor) -> Dict[str, Any]:
        """Encode `x` into q(z|x), draw `z ~ q(z|x)` via the reparameterization trick, and decode.

        Returns:
            dict with 'px' (p(x|z)), 'pz' (prior p(z)), 'qz' (posterior q(z|x)) and 'z' (the sample).
        """
        # Flatten everything except the batch dimension; reshape also handles
        # non-contiguous inputs, where view would raise.
        x = x.reshape(x.size(0), -1)
        qz = self.posterior(x)
        pz = self.prior(batch_size=x.size(0))
        z = qz.rsample()
        px = self.observation_model(z)
        return {'px': px, 'pz': pz, 'qz': qz, 'z': z}

    def sample_from_prior(self, batch_size: int = 100) -> Dict[str, Any]:
        """Sample `z ~ p(z)` and return p(x|z) together with p(z) and the sample."""
        pz = self.prior(batch_size=batch_size)
        z = pz.rsample()
        px = self.observation_model(z)
        return {'px': px, 'pz': pz, 'z': z}


# Build one model purely to inspect its layer structure.
latent_features = 3
single_obs_shape = X[0].shape  # shape of one row of X, i.e. one observation
vae = VariationalAutoencoder(input_shape=single_obs_shape, latent_features=latent_features)
print(vae)

VariationalAutoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=10, out_features=10, bias=True)
    (1): ReLU()
    (2): Linear(in_features=10, out_features=6, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=3, out_features=10, bias=True)
    (1): ReLU()
    (2): Linear(in_features=10, out_features=10, bias=True)
  )
)


## Training

In [5]:
# Build a fresh VAE (newly initialized weights) for training.
latent_features = 3
vae = VariationalAutoencoder(input_shape=X[0].shape, latent_features=latent_features)

# Training objective / evaluator, parameterized by beta.
beta = 1
vi = VariationalInference(beta=beta)

# Adam over all VAE parameters; it tends to work well for VAEs.
optimizer = torch.optim.Adam(params=vae.parameters(), lr=1e-3)


In [6]:
num_epochs = 50
outputs, loss, diagnostics = run_training(
    vae, train_loader, test_loader, optimizer, vi, num_epochs
)

>> Using device: cuda:0


In [8]:
# Draw one prior sample per observation in X, so that the sample mean of
# x_samples can be compared with the per-feature mean of X.
# X[0].size(0) is the feature count of a single row, not the number of rows.
px = vae.sample_from_prior(batch_size=X.size(0))['px']
x_samples = px.sample()

# Fixed observation scale of p(x|z) in the training outputs.
outputs['px'].scale

tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
         0.1000],
        [0

In [None]:
# Define true_x here so this cell does not depend on a later cell having run.
# Rows from index 910 onward are presumably the test split — confirm.
true_x = X.detach().numpy()[910:, :]
np.std(true_x[-1, :])

In [None]:
# Reconstruction residuals (predicted - true) for the last test observation.
# NOTE(review): assumes the rows of outputs['px'] line up with X[910:] in order — confirm against run_training.
# detach() first: .numpy() raises on tensors that require grad.
pred_x = outputs['px'].loc.detach().cpu().numpy()
true_x = X.detach().numpy()[910:, :]

plot_df = pd.concat([p(true_x[-1, :]), p(pred_x[-1, :] - true_x[-1, :])], axis=1)
plot_df.columns = ['x', 'y']

fig, ax = plt.subplots(figsize=(8, 4), dpi=100)
sns.scatterplot(data=plot_df, x='x', y='y', ax=ax)
# Zero-residual reference spanning the whole x-axis, rather than a hard-coded [-8, 2] range.
ax.axhline(0, color='k', linewidth=0.8)
ax.set(xlabel='true x', ylabel='predicted - true x', title='Reconstruction residuals (last test observation)');

In [None]:

# FIXME(review): plot_df is built with exactly two columns (named ['x', 'y'] in the
# residual-plot cell), so assigning three names here raises a pandas length-mismatch
# ValueError. An 'Env' column must be added to plot_df first; where it should come
# from is not visible here (presumably derived from `envs`) — confirm and fix.
plot_df.columns=['x', 'y', 'Env']
# Scatter of the same data, colored by the 'Env' column.
sns.scatterplot(data=plot_df, x='x', y='y', hue='Env')

In [None]:
# Compare the distribution of latents inferred by the VAE (red) with the true
# synthetic latents, both on the test set; one figure per latent dimension.
# NOTE(review): `Zs` is not created by any visible cell (generate_3Z_synthetic_data
# returns envs, Xs, train_loader, test_loader), so it must come from elsewhere — confirm.
Za = outputs['z'].detach().cpu().numpy()  # Z's inferred from 'x' of test set; detach before .numpy()
for dim in range(2):
    fig, ax = plt.subplots()
    sns.kdeplot(Za[:, dim], color='red', label='inferred z (VAE)', ax=ax)
    sns.kdeplot(Zs[910:, dim, 0], label='true Z (synthetic, test set)', ax=ax)
    ax.set(xlabel=f'z[{dim}]', title=f'Latent dimension {dim}: inferred vs. true')
    ax.legend()
    plt.show()

In [None]:
# detach() before .numpy(): the conversion raises if the tensor requires grad.
p(outputs['z'].detach().cpu().numpy())

In [None]:
# Per-feature mean over all rows of X.
X.mean(dim=0)

In [None]:
# Per-feature mean over rows 910 onward (presumably the test split).
X[910:].mean(dim=0)

In [None]:
# Per-feature mean of the samples drawn from p(x|z) with z ~ p(z).
x_samples.mean(dim=0)