# NOTE(review): removed non-code residue left over from a web scrape
# ("Skip to content", "Switch branches/tags", "Go to file", contributor banner).
from __future__ import print_function
import torch
from torch import __init__
from torch.optim import Adam
from .dipvae_utils import VAE, DIPVAE
from .dipvae_utils import plot_reconstructions, plot_latent_traversal
from aix360.algorithms.die import DIExplainer
import os
import sys
import random
import time
import numpy as np
class DIPVAEExplainer(DIExplainer):
    """DIPVAEExplainer can be used to visualize the changes in the latent space of Disentangled Inferred Prior-VAE
    or DIPVAE [#1]_. This model is a Variational Autoencoder [#2]_ variant that leads to a
    disentangled latent space. This is achieved by matching the covariance of the prior distributions with the
    inferred prior.

    References:
        .. [#1] `Variational Inference of Disentangled Latent Concepts from Unlabeled Observations
           (DIP-VAE), ICLR 2018. Kumar, Sattigeri, Balakrishnan. <https://arxiv.org/abs/1711.00848>`_
        .. [#2] `Diederik P Kingma and Max Welling. Auto-encoding variational Bayes. ICLR, 2014.
           <https://arxiv.org/abs/1312.6114>`_
    """

    def __init__(self, model_args, dataset=None, net=None, cuda_available=None):
        """Initialize DIPVAEExplainer explainer.

        Args:
            model_args: This should contain all the parameters required for the generative model
                training and inference. This includes the model type (vae, dipvae-i, dipvae-ii,
                user-defined). A user-defined model can be passed via the ``net`` parameter. Each
                model should have encode and decode functions defined. See the notebook example
                for other model specific parameters.
            dataset: The dataset object.
            net: If not None this is the user specified generative model.
            cuda_available: If True use GPU.

        Raises:
            ValueError: If ``net`` is None and ``model_args.model`` is not one of
                "vae", "dipvae-i" or "dipvae-ii".
        """
        super(DIPVAEExplainer, self).__init__()
        self.model_args = model_args
        if net is None:
            # NOTE(review): the scraped source lost a keyword argument between ``latent_dim``
            # and ``args`` in each constructor call below (a double comma remained) -- most
            # likely the input dimensionality taken from ``dataset``. Confirm against the
            # VAE/DIPVAE signatures in dipvae_utils before relying on these calls.
            if self.model_args.model == "vae":
                self.net = VAE(num_nodes=self.model_args.num_nodes,
                               activation_type=self.model_args.activation_type,
                               latent_dim=self.model_args.latent_dim,
                               args=self.model_args,
                               cuda_available=cuda_available)
            elif self.model_args.model == "dipvae-i":
                self.net = DIPVAE(num_nodes=self.model_args.num_nodes,
                                  activation_type=self.model_args.activation_type,
                                  latent_dim=self.model_args.latent_dim,
                                  args=self.model_args,
                                  cuda_available=cuda_available,
                                  mode='i',
                                  output_activation_type=dataset.output_activation_type,
                                  likelihood_type=dataset.likelihood_type)
            elif self.model_args.model == "dipvae-ii":
                self.net = DIPVAE(num_nodes=self.model_args.num_nodes,
                                  activation_type=self.model_args.activation_type,
                                  latent_dim=self.model_args.latent_dim,
                                  args=self.model_args,
                                  cuda_available=cuda_available,
                                  mode='ii',
                                  output_activation_type=dataset.output_activation_type,
                                  likelihood_type=dataset.likelihood_type)
            else:
                # Fail fast instead of leaving self.net unset (which would surface later
                # as a confusing AttributeError).
                raise ValueError(
                    "Unsupported model type '{0}'; expected 'vae', 'dipvae-i' or 'dipvae-ii', "
                    "or pass a custom net.".format(self.model_args.model))
        else:
            self.net = net
        self.cuda_available = cuda_available
        self.dataset = dataset
        if self.cuda_available:
            self.net = self.net.cuda()

    def set_params(self, *argv, **kwargs):
        """Set parameters for the explainer."""
        print("TBD: Implement set params in DIPVAEExplainer")

    def explain(self, input_images, edit_dim_id, edit_dim_value, edit_z_sample=False):
        """Edits the images in the latent space and returns the generated images.

        Args:
            input_images: The input images.
            edit_dim_id: The latent dimension id that needs to be edited.
            edit_dim_value: The value that is assigned to the latent dimension with id
                edit_dim_id.
            edit_z_sample: If True will use the sample from encoder instead of the mean.

        Returns:
            Edited images.
        """
        # Encoder is expected to return (sample, mean, std) -- see usage below.
        reference_z, reference_mu, reference_std = self.net.encode(input_images)
        # Start from either the sampled latent or the posterior mean, then clamp the
        # chosen latent dimension to the requested value for every image in the batch.
        if edit_z_sample:
            edited_z = reference_z
        else:
            edited_z = reference_mu
        edited_z[:, edit_dim_id] = edit_dim_value
        edited_images = self.net.decode(edited_z)
        return edited_images

    def fit(self, visualize=False, save_dir="results"):
        """Train the underlying generative model.

        Args:
            visualize: Plot reconstructions during fit.
            save_dir: directory where plots and model will be saved.

        Returns:
            list: per-epoch ELBO estimates (negated mean loss per training instance).
        """
        optimizer = Adam(self.net.parameters(), lr=self.model_args.step_size)
        loss_epoch_list = []
        for epoch in np.arange(self.model_args.num_epochs):
            loss_epoch = 0.
            batch_id = 0
            for x, y in self.dataset.next_batch():
                if self.cuda_available:
                    x = x.cuda()
                    y = y.cuda()

                # forward
                optimizer.zero_grad()
                # NOTE(review): the condition's right-hand side was stripped from the scraped
                # source; ``self.dataset.name`` is the most plausible original -- confirm.
                if "mnist" in self.dataset.name:
                    # MNIST batches arrive as images; flatten to vectors for the
                    # fully-connected networks. Assumes the dataset exposes its
                    # per-image dimensions as ``data_dims`` -- TODO confirm.
                    loss = self.net.neg_elbo(x.view(-1, int(np.prod(self.dataset.data_dims))))
                else:
                    loss = self.net.neg_elbo(x)

                # backward
                loss.backward()
                optimizer.step()

                # .item() detaches the scalar so the autograd graph is not retained
                # across the whole epoch.
                loss_epoch += loss.item()
                batch_id += 1

                if visualize and batch_id % 10 == 0:
                    if not os.path.isdir(save_dir):
                        os.makedirs(save_dir)
                    plot_reconstructions(self.dataset, self.net, x, image_id_to_plot=2, epoch=epoch,
                                         batch_id=batch_id, save_dir=save_dir)
                    if batch_id % 100 == 0:
                        plot_latent_traversal(self.net, x, self.model_args, self.dataset,
                                              image_id_to_plot=2, epoch=epoch,
                                              batch_id=batch_id, save_dir=save_dir)
                        torch.save(self.net, os.path.join(save_dir, 'net.p'))

            # Report the (negated) mean loss per instance as an ELBO estimate.
            loss_epoch_list.append(-loss_epoch / self.dataset.num_training_instances)
            print("Epoch {0} | ELBO {1}".format(epoch, -loss_epoch / self.dataset.num_training_instances))
        return loss_epoch_list