In [1]:
import copy
import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from scipy.stats import chi2
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn import manifold
from sklearn.mixture import GaussianMixture
from sklearn.metrics.pairwise import cosine_distances
from sklearn.preprocessing import MinMaxScaler, StandardScaler

import statsmodels.api as sm

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from torch.autograd import Variable, grad
from torch.nn import functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

# VAE Model

In [2]:
class VAE(nn.Module):
    
    def __init__(self):
        super(VAE, self).__init__()

        self.dimensionality = 64
        
        # (W−F+2P)/S+1
        # After conv_1, we have a (28-2)/2 + 1 => (14x14) image with 64 channels
        self.conv_1 = nn.Conv2d(1, self.dimensionality, 
                                kernel_size=(2, 2),
                                stride=2)

        # After conv_2, we have a (14-2)/2 + 1 => (7x7) image with 128 channels
        self.conv_2 = nn.Conv2d(self.dimensionality, 2*self.dimensionality, 
                                kernel_size=(2, 2), 
                                stride=2)
 
        self.fc31 = nn.Linear((self.dimensionality*2)*7*7, 20)
        self.fc32 = nn.Linear((self.dimensionality*2)*7*7, 20)
        
        self.fc4 = nn.Linear(20, (2*self.dimensionality)*7*7)
        
        # S*(N-1) + F - 2P
        # After deconv_5, we have a 2*(7-1) + 2 => 14x14 image with 64 channels
        self.deconv_5 = nn.ConvTranspose2d(2*self.dimensionality, self.dimensionality,
                                           kernel_size=(2, 2),
                                           stride=2)
        # After deconv_6, we have a 2*(14-1) + 2 => 28x28 image with 1 channel
        self.deconv_6 = nn.ConvTranspose2d(self.dimensionality, 1, 
                                           kernel_size=(2, 2),
                                           stride=2)
                
        self.optimizer = optim.Adam(self.parameters(), lr=1e-3)

    def encode(self, x):
        x = x.view(-1, 1, 28, 28)
        x = F.relu(self.conv_1(x))
        x = F.relu(self.conv_2(x))
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        return self.fc31(x), self.fc32(x)

    def _reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = Variable(torch.randn(std.shape))
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h = F.relu(self.fc4(z))
        h = h.view(-1, 2*self.dimensionality, 7, 7)
        h = F.relu(self.deconv_5(h))
        h = torch.sigmoid(self.deconv_6(h))
        return h.view(-1, 28, 28)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self._reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    
    def loss(self, x, reconstructed_x, mu, log_var):

        # Maximize P(x)
        # => Since P(x,z) = P(z|x)P(x), P(x) = P(x,z) / P(z|x)
        # => log P(x) = log P(x,z) - log P(z|x)
        # => log P(x) = log P(x,z) - log Q(z) - log P(z|x) + log Q(z)
        # => log P(x) = log [P(x,z)/Q(z)] - log[P(z|x)/Q(z)]
        # => E_Q(z)[log(P(x))] = E_Q(z)[log [P(x,z)/Q(z)]] - E_Q(z)[log[P(z|x)/Q(z)]]
        # => log(P(x)) = E_Q(z)[log [P(x,z)/Q(z)]] - KL(P(z|x) || Q(z))
        # => log(P(x)) >= E_Q(z)[log [P(x,z)/Q(z)]]
        # => log(P(x)) >= E_Q(z)[log P(x,z)] - E_Q(z)[log Q(z)]
        # => log(P(x)) >= E_Q(z)[log P(x|z)] + E_Q(z)[log P(z)] - E_Q(z)[log Q(z)]
        # => log(P(x)) >= E_Q(z)[log P(x|z)] - KL(P(z) || Q(z))
        # So...
        # Minimizing - log(P(x)) can be accomplished by minimizing loss = KL(P(z) || Q(z)) - E_Q(z)[log P(x|z)]]
        
        # KL(P(z) || Q(z))
        # When P(z) and Q(z) are restricted to be gaussians with Q(z) ~ N(0, 1):
        # KL(P(z) || Q(z)) = -0.5*ln(sigma_i) + 0.5*sigma_i^2 + 0.5*mu_i^2 - 0.5
        # => KL(P(z) || Q(z)) = 0.5 * sum_i [sigma_i^2 + mu_i^2 - ln(sigma_i) - 1]
        kld = 0.5 * torch.sum(log_var.exp() + mu.pow(2) - log_var - 1.)
        
        # - E[log P(x|z)]        
        reconstruction = F.binary_cross_entropy(reconstructed_x.view(-1, 784), x.view(-1, 784))

        return reconstruction + kld
    
    def train(self, train_loader):
        sum_train_loss = 0
        sum_likelihood = 0
        
        for batch_idx, (data_batch, _) in enumerate(train_loader):
            data_batch = Variable(data_batch)
            
            self.zero_grad()
            reconstructed_batch, mu, log_var = self(data_batch)
            loss = self.loss(data_batch, reconstructed_batch, mu, log_var)
            loss.backward()
            sum_train_loss += loss.item()
            sum_likelihood += 0 # self.likelihood(data_batch, reconstructed_batch, mu, log_var)
            self.optimizer.step()
            
        return sum_train_loss, sum_likelihood

### VAE Train Loop

In [3]:
def train_vae(model, model_name, train_loader, num_epochs=1000, log_interval=5):
    for epoch in range(1, num_epochs+1):
        sum_train_loss, sum_likelihood = model.train(train_loader)
        print("EPOCH {0}".format(epoch))
        print("    SUM TRAIN LOSS : {0}".format(sum_train_loss))
        print("    MEAN LIKELIHOOD: {0}".format(sum_likelihood/len(train_loader)))
        
        if epoch % log_interval == 0:
            sample = Variable(torch.randn(64, 20))
            sample = model.decode(sample).data
            save_image(sample.view(64, 1, 28, 28), 'OUTPUT/reconstructed_samples/sample_{0}_{1}.png'.format(model_name, epoch))
            torch.save(model.state_dict(), 'OUTPUT/vae_models/{0}_epoch={1}.pt'.format(model_name, epoch))

## Train VAE on Full MNIST Dataset

In [4]:
# Hyperparameters
BATCH_SIZE = 128

In [None]:
# Load train and test data
mnist_train_dataset = datasets.MNIST('DATA/', train=True, download=True, transform=transforms.ToTensor())
mnist_test_dataset = datasets.MNIST('DATA/', train=False, download=True, transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(mnist_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Construct model
full_mnist_model = VAE()
# full_mnist_model.load_state_dict(torch.load('OUTPUT/vae_models/model.pt'))

In [None]:
# Train model
train_vae(full_mnist_model, 'full_mnist_model', train_loader)

ERROR:root:No traceback has been produced, nothing to debug.


EPOCH 1
    SUM TRAIN LOSS : 300.28398713469505
    MEAN LIKELIHOOD: 0.0
EPOCH 2
    SUM TRAIN LOSS : 124.61038218438625
    MEAN LIKELIHOOD: 0.0
EPOCH 3
    SUM TRAIN LOSS : 124.06747955083847
    MEAN LIKELIHOOD: 0.0


## Train VAE on Subset MNIST Dataset

In [None]:
def build_mnist_dataset_with_only_label_from_dataset(d, label, test=False):
    nd = copy.deepcopy(d)
    idxs_with_label = []
    
    for i in range(0, len(d)):  
        if d[i][1] == label:
            idxs_with_label.append(i)
    
    if test:
        nd.test_data = nd.test_data[idxs_with_label]
        nd.test_labels = nd.test_labels[idxs_with_label]
    else:
        nd.train_data = nd.train_data[idxs_with_label]
        nd.train_labels = nd.train_labels[idxs_with_label]
            
    return nd

In [None]:
mnist_train_dataset_8 = build_mnist_dataset_with_only_label_from_dataset(mnist_train_dataset, 8)
mnist_test_dataset_8 = build_mnist_dataset_with_only_label_from_dataset(mnist_test_dataset, 8, test=True)
mnist_test_dataset_6 = build_mnist_dataset_with_only_label_from_dataset(mnist_test_dataset, 6, test=True)
mnist_test_dataset_2 = build_mnist_dataset_with_only_label_from_dataset(mnist_test_dataset, 2, test=True)
mnist_test_dataset_1 = build_mnist_dataset_with_only_label_from_dataset(mnist_test_dataset, 1, test=True)

In [None]:
# Build train loader for dataset of 8's.
BATCH_SIZE = 128
mnist_8_train_loader = torch.utils.data.DataLoader(mnist_train_dataset_8, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Construct model
mnist_8_model = VAE()
mnist_8_model.load_state_dict(torch.load('OUTPUT/vae_models/mnist_8_model_epoch=230.pt'))

In [None]:
# Train VAE on dataset of only 8's.
#train_vae(mnist_8_model, 'mnist_8_model', mnist_8_train_loader)

# Experiment 1: Permute Latent Space Along Principal Components

In [None]:
# Encode training set into latent space
z_batches = []
data_batches = []

for batch_idx, (data_batch, _) in enumerate(train_loader):
    data_batches.append(data_batch)
    data_batch = Variable(data_batch)
    mu, logvar = full_mnist_model.encode(data_batch)
    z = full_mnist_model._reparameterize(mu, logvar)
    z_batches.append(z.data.numpy())
     
z = np.concatenate(z_batches, axis=0)
data = np.concatenate(data_batches, axis=0)

In [None]:
# Compute mean of latent space
z_mu = np.mean(z, axis=0)

In [None]:
# Apply PCA to latent feature vectors
pca = PCA(n_components=10)
pca.fit(z)

In [None]:
print(pca.components_.shape)

In [None]:
C = 3
i = 0
for principal_component in pca.components_:
    i += 1
    z_permuted = z_mu + C*principal_component
    z_permuted = Variable(torch.from_numpy(z_permuted))
    permuted_img = full_mnist_model.decode(z_permuted).data
    save_image(permuted_img.view(1, 1, 28, 28), 'OUTPUT/permuted_principal_component_samples/sample_' + str(i) + '.png')

# Experiment 2: Cluster Latent Encodings

In [None]:
kmeans = KMeans(n_clusters=10, random_state=0).fit(z)

1. Decode cluster center.
2. Find image in training set with closest latent space encoding to cluster_center.

In [None]:
i = 0
for cluster_center in kmeans.cluster_centers_:
    i += 1
    
    # Decode cluster center
    cluster_center_var = Variable(torch.from_numpy(cluster_center))
    decoded_cluster_center = full_mnist_model.decode(cluster_center_var)
    save_image(decoded_cluster_center.data.view(1, 1, 28, 28), 'OUTPUT/train_set_decoded_latent_space_cluster_centers/decoded_cluster_center_' + str(i) + '.png') 
    
    # Find image in training set with closest latent space encoding to cluster_center
    closest_distance = float("inf")
    closest_distance_data = None
    for (z_i, data_i) in zip(z, data):
        distance_i = cosine_distances([z_i], [cluster_center])
        if (distance_i < closest_distance):
            closest_distance = distance_i
            closest_distance_data = torch.from_numpy(data_i)
        
    save_image(closest_distance_data.view(1, 1, 28, 28), 'OUTPUT/train_set_imgs_closest_to_latent_space_cluster_centers/cluster_' + str(i) + '_representative.png') 

# Experiment 3: Contrived Outlier Problem (MNIST)

Plot examples of test dataset digits.

In [None]:
example_test_8 = mnist_test_dataset_8[0][0]
example_test_6 = mnist_test_dataset_6[1][0]
example_test_2 = mnist_test_dataset_2[1][0]
example_test_1 = mnist_test_dataset_1[1][0]

fig, axes = plt.subplots(2, 2)
axes[0, 0].imshow(example_test_8.view(28, 28))
axes[0, 1].imshow(example_test_6.view(28, 28))
axes[1, 0].imshow(example_test_2.view(28, 28))
axes[1, 1].imshow(example_test_1.view(28, 28))
plt.show()

### Experiment 3.1: Likelihood Statistic Approach

Reconstruct each of the example digits using the VAE trained on only 8s.

In [None]:
reconstructed_example_test_8, _, _ = mnist_8_model(Variable(example_test_8))
reconstructed_example_test_6, _, _ = mnist_8_model(Variable(example_test_6))
reconstructed_example_test_2, _, _ = mnist_8_model(Variable(example_test_2))
reconstructed_example_test_1, _, _ = mnist_8_model(Variable(example_test_1))

likelihood_statistic_example_test_8 = mnist_8_model.likelihood_statistic(example_test_8).data.item()
likelihood_statistic_example_test_6 = mnist_8_model.likelihood_statistic(example_test_6).data.item()
likelihood_statistic_example_test_2 = mnist_8_model.likelihood_statistic(example_test_2).data.item()
likelihood_statistic_example_test_1 = mnist_8_model.likelihood_statistic(example_test_1).data.item()

fig, axes = plt.subplots(2, 2)
axes[0, 0].imshow(reconstructed_example_test_8.view(28, 28).data)
axes[0, 1].imshow(reconstructed_example_test_6.view(28, 28).data)
axes[1, 0].imshow(reconstructed_example_test_2.view(28, 28).data)
axes[1, 1].imshow(reconstructed_example_test_1.view(28, 28).data)

axes[0, 0].set_title("LS: {0:.4f}".format(likelihood_statistic_example_test_8))
axes[0, 1].set_title("LS: {0:.4f}".format(likelihood_statistic_example_test_6))
axes[1, 0].set_title("LS: {0:.4f}".format(likelihood_statistic_example_test_2))
axes[1, 1].set_title("LS: {0:.4f}".format(likelihood_statistic_example_test_1))
plt.subplots_adjust(hspace=0.5)
plt.show()


Build null distribution of likelihood statistics from training set values.

In [None]:
mnist_train_dataset_8_likelihood_statistics = mnist_8_model.likelihood_statistics_for_dataset(mnist_train_dataset_8)

Build distributions of likelihood statistics from test set values.

In [None]:
mnist_test_dataset_8_likelihood_statistics = mnist_8_model.likelihood_statistics_for_dataset(mnist_test_dataset_8)
mnist_test_dataset_6_likelihood_statistics = mnist_8_model.likelihood_statistics_for_dataset(mnist_test_dataset_6)
mnist_test_dataset_2_likelihood_statistics = mnist_8_model.likelihood_statistics_for_dataset(mnist_test_dataset_2)
mnist_test_dataset_1_likelihood_statistics = mnist_8_model.likelihood_statistics_for_dataset(mnist_test_dataset_1)

Plot null and test distributions of likelihood statistics together!

In [None]:
plt.hist(mnist_train_dataset_8_likelihood_statistics, density=True, color=(0.5, 0.5, 0.5, 0.3))
plt.hist(mnist_test_dataset_8_likelihood_statistics, density=True, color=(1., 0, 0, 0.3))
plt.hist(mnist_test_dataset_6_likelihood_statistics, density=True, color=(0, 1., 0, 0.3))
plt.hist(mnist_test_dataset_2_likelihood_statistics, density=True, color=(0.1, 1., 1., 0.3))
plt.hist(mnist_test_dataset_1_likelihood_statistics, density=True, color=(0, 0, 1., 0.3))
plt.legend(['Training Set of 8s (Null)', 'Test Set of 8s', 'Test Set of 6s', 'Test Set of 2s', 'Test Set of 1s'])
plt.title("Distributions of Likelihood Statistics")
plt.xlabel("Likelihood Statistic Values")
plt.ylabel("Likelihood Statistic Density")
plt.show()

#### Experiment 3.1.1: Classify Outliers on Count of Null Statistics Below Test Pt Likelihood Statistic

In [None]:
def classify_outliers_report(null_pts, test_pts, alpha=0.05):
    
    num_outliers = 0
    num_non_outliers = 0
    N = len(null_pts)
    
    for test_pt in test_pts:
        p_value_left_tail = sum(null_pt < test_pt for null_pt in null_pts) / N
            
        if p_value_left_tail < alpha:
            num_outliers += 1
        else:
            num_non_outliers += 1
            
    print("%     OUTLIERS: {0}".format(100 * num_outliers / len(test_pts)))
    print("% NON OUTLIERS: {0}".format(100 * num_non_outliers / len(test_pts)))

Test inlier classification performance on training 8s

In [None]:
classify_outliers_report(mnist_train_dataset_8_likelihood_statistics, mnist_train_dataset_8_likelihood_statistics)

Test inlier classification performance on test 8s:

In [None]:
classify_outliers_report(mnist_train_dataset_8_likelihood_statistics, mnist_test_dataset_8_likelihood_statistics)

Test outlier classification performance on 6s, 2s, 1s:

In [None]:
classify_outliers_report(mnist_train_dataset_8_likelihood_statistics, mnist_test_dataset_6_likelihood_statistics)

In [None]:
classify_outliers_report(mnist_train_dataset_8_likelihood_statistics, mnist_test_dataset_2_likelihood_statistics)

In [None]:
classify_outliers_report(mnist_train_dataset_8_likelihood_statistics, mnist_test_dataset_1_likelihood_statistics)

### Experiment 3.2: Distance From Mean Statistic Approach

In [None]:
mnist_train_dataset_8_images = []
for elem in mnist_train_dataset_8:
    mnist_train_dataset_8_images.append(elem[0].view(28, 28).numpy())

mnist_train_dataset_8_images_mean = np.mean(mnist_train_dataset_8_images, axis=0)
plt.imshow(mnist_train_dataset_8_images_mean)
plt.show()

In [None]:
# IDEA: Distance in image space could be defined as distance in activations of a 
#  classifier trained on a similar problem.

def distance_between_images(img1, img2):
    return np.sum(pow(img1.flatten() - img2.flatten(), 2))

def compute_distances_from_image(dataset, image):
    distances = []
    for elem in dataset:
        distance = distance_between_images(elem[0].numpy(), image)
        distances.append(distance)
    return distances

In [None]:
mnist_train_dataset_8_distance_statistics = compute_distances_from_image(mnist_train_dataset_8, mnist_train_dataset_8_images_mean)
mnist_test_dataset_8_distance_statistics = compute_distances_from_image(mnist_test_dataset_8, mnist_train_dataset_8_images_mean)
mnist_test_dataset_6_distance_statistics = compute_distances_from_image(mnist_test_dataset_6, mnist_train_dataset_8_images_mean)
mnist_test_dataset_2_distance_statistics = compute_distances_from_image(mnist_test_dataset_2, mnist_train_dataset_8_images_mean)
mnist_test_dataset_1_distance_statistics = compute_distances_from_image(mnist_test_dataset_1, mnist_train_dataset_8_images_mean)

plt.hist(mnist_train_dataset_8_distance_statistics, density=True, color=(0.5, 0.5, 0.5, 0.3))
plt.hist(mnist_test_dataset_8_distance_statistics, density=True, color=(1., 0, 0, 0.3))
plt.hist(mnist_test_dataset_6_distance_statistics, density=True, color=(0, 1., 0, 0.3))
plt.hist(mnist_test_dataset_2_distance_statistics, density=True, color=(0.1, 1., 1., 0.3))
plt.hist(mnist_test_dataset_1_distance_statistics, density=True, color=(0, 0, 1., 0.3))
plt.legend(['Training Set of 8s (Null)', 'Test Set of 8s', 'Test Set of 6s', 'Test Set of 2s', 'Test Set of 1s'])
plt.title("Distributions of Distances to Mean")
plt.xlabel("(Distance to Mean) Statistic Values")
plt.ylabel("Statistic Density")
plt.show()

# Experiment 4: Move Through Image Space

In [None]:
def closest_image_in_dataset_to_image(dataset, query_img):

    closest_elem = None
    closest_elem_distance = float('inf')
    
    for elem in dataset:
        elem_distance = distance_between_images(elem[0].numpy(), query_img.numpy())
        if elem_distance < closest_elem_distance:
            closest_elem = elem[0]
            closest_elem_distance = elem_distance
    
    return closest_elem.view(28, 28).numpy()

In [None]:
closest_training_set_point_to_example_test_6 = closest_image_in_dataset_to_image(mnist_train_dataset_8, example_test_6)
plt.imshow(closest_training_set_point_to_example_test_6)
plt.show()

In [None]:
closest_training_set_point_to_example_test_1 = closest_image_in_dataset_to_image(mnist_train_dataset_8, example_test_1)
plt.imshow(closest_training_set_point_to_example_test_1)
plt.show()

In [None]:
# Interpolate pixel values individually between example_test_6 and closest_training_set_point_to_example_test_6
diff_6 = (closest_training_set_point_to_example_test_6 - example_test_6.numpy()).reshape(28,28)
diff_1 = (closest_training_set_point_to_example_test_1 - example_test_1.numpy()).reshape(28,28)

In [None]:
num_interpolation_steps = 5

f, axs = plt.subplots(1, 6)
for i in range(0, num_interpolation_steps+1):
    interpolation_i = example_test_6.view(28, 28).numpy() + (i/num_interpolation_steps)*diff_6
    axs[i].imshow(interpolation_i)
plt.show(f)
    
f2, axs2 = plt.subplots(1, 6)
for i in range(0, num_interpolation_steps+1):
    interpolation_i = example_test_1.view(28, 28).numpy() + (i/num_interpolation_steps)*diff_1
    axs2[i].imshow(interpolation_i)
plt.show(f2)

# Experiment 5: Move Through Latent Space

In [None]:
def distance_between_images_in_latent_space(img1, img2):
    
    img1 = Variable(img1)
    img2 = Variable(img2)
    
    mu1, logvar1 = mnist_8_model.encode(img1)
    mu2, logvar2 = mnist_8_model.encode(img2)
    
    # KL Divergence between univariate gaussians
    return torch.sum(logvar2 - logvar1 + (logvar1.exp() + pow(mu1 - mu2, 2))/(2*logvar2.exp()) - 0.5).item()

def closest_image_in_dataset_to_image(dataset, query_img):

    closest_elem = None
    closest_elem_distance = float('inf')
    
    for elem in dataset:
        elem_distance = distance_between_images_in_latent_space(elem[0], query_img)
        if elem_distance < closest_elem_distance:
            closest_elem = elem[0]
            closest_elem_distance = elem_distance
    
    return closest_elem.view(28, 28)

In [None]:
closest_training_set_point_in_latent_space_to_example_test_6 = closest_image_in_dataset_to_image(mnist_train_dataset_8, example_test_6)
plt.imshow(closest_training_set_point_in_latent_space_to_example_test_6)
plt.show()

In [None]:
closest_training_set_point_in_latent_space_to_example_test_2 = closest_image_in_dataset_to_image(mnist_train_dataset_8, example_test_2)
plt.imshow(closest_training_set_point_in_latent_space_to_example_test_2)
plt.show()

In [None]:
closest_training_set_point_in_latent_space_to_example_test_1 = closest_image_in_dataset_to_image(mnist_train_dataset_8, example_test_1)
plt.imshow(closest_training_set_point_in_latent_space_to_example_test_1)
plt.show()

In [None]:
def plot_samples_along_path_between_images(model, img1, img2, num_samples=10, parameter_space=True):
    
    mu1, varlog1 = model.encode(Variable(img1))
    mu2, varlog2 = model.encode(Variable(img2))
    
    mu_diff = mu2 - mu1
    varlog_diff = varlog2 - varlog1
    
    f, axs = plt.subplots(1, num_samples+1, figsize=(15,5))
    
    for i in range(0, num_samples+1):
        interpolation_factor = (i/num_samples)
        if parameter_space:
            z_i = model._reparameterize(mu1 + interpolation_factor*mu_diff, varlog1 + interpolation_factor*varlog_diff)
        else:
            z_i = Variable((mu1 + interpolation_factor*mu_diff).data)
        sample_i = model.decode(z_i).view(28, 28).data.numpy()
        
        axs[i].get_yaxis().set_visible(False)
        axs[i].get_xaxis().set_visible(False)
        axs[i].imshow(sample_i)

    return f

#### Latent Parameter Space

In [None]:
plot_samples_along_path_between_images(mnist_8_model, example_test_6, closest_training_set_point_in_latent_space_to_example_test_6)

In [None]:
plot_samples_along_path_between_images(mnist_8_model, example_test_1, closest_training_set_point_in_latent_space_to_example_test_1)

In [None]:
plot_samples_along_path_between_images(mnist_8_model, example_test_2, closest_training_set_point_in_latent_space_to_example_test_2)

In [None]:
plot_samples_along_path_between_images(mnist_8_model, example_test_2, example_test_6)

#### Latent Space

In [None]:
plot_samples_along_path_between_images(mnist_8_model, example_test_6, closest_training_set_point_in_latent_space_to_example_test_6, parameter_space=False)

In [None]:
plot_samples_along_path_between_images(mnist_8_model, example_test_1, closest_training_set_point_in_latent_space_to_example_test_1)

In [None]:
plot_samples_along_path_between_images(mnist_8_model, example_test_2, closest_training_set_point_in_latent_space_to_example_test_2)

In [None]:
plot_samples_along_path_between_images(mnist_8_model, example_test_2, example_test_6)

# Experiment 6: Latent Space Visualization

In [None]:
def imscatter(x, y, ax, imageData, zoom):
    images = []
    for i in range(len(x)):
        x0, y0 = x[i], y[i]
        
        # Convert to image
        img = imageData[i]*255.
        img = img.astype(np.uint8).reshape([28, 28])
        img = cv2.cvtColor(img,cv2.COLOR_GRAY2RGB)
        
        # Note: OpenCV uses BGR and plt uses RGB
        image = OffsetImage(img, zoom=zoom)
        ab = AnnotationBbox(image, (x0, y0), xycoords='data', frameon=False)
        images.append(ax.add_artist(ab))
    
    ax.update_datalim(np.column_stack([x, y]))
    ax.autoscale()

In [None]:
# Show dataset images with T-sne projection of latent space encoding
def plotTSNEProjectionOfLatentSpace(X, model):
    
    # Compute latent space representation
    mu, logvar = model.encode(Variable(X))
    Z = model._reparameterize(mu, logvar).data.numpy()

    # Compute t-SNE embedding of latent space
    tsne = manifold.TSNE(n_components=2, init='pca', random_state=0)
    Z_tsne = tsne.fit_transform(Z)

    # Plot images according to t-sne embedding
    fig, ax = plt.subplots(figsize=(30,30))
    imscatter(Z_tsne[:, 0], Z_tsne[:, 1], imageData=X.numpy(), ax=ax, zoom=0.6)
    return fig

In [None]:
# Pick first 1000 elements of train dataset
X = []
for i in range(0,1000):
    X.append(mnist_train_dataset[i][0])
X = torch.stack(X)

In [None]:
plotTSNEProjectionOfLatentSpace(X, mnist_8_model)

In [None]:
plotTSNEProjectionOfLatentSpace(X, full_mnist_model)

# Experiment 7: Saliency Maps

### Likelihood Statistic Saliency Map

In [None]:
def saliency_map(model, x):
    
    # Wrap the input tensor in a Variable
    x = Variable(x, requires_grad=True)

    # Forward pass
    reconstructed_x, mu, log_var = model(x)

    # Recall:
    # => log(P(x)) >= E_Q(z)[log P(x|z)] - KL(P(z) || Q(z))
        
    # E[log P(X|z)]
    reconstruction = - F.binary_cross_entropy(x.view(-1, 784), reconstructed_x.view(-1, 784).detach())
        
    # - KL(P(z) || Q(z))
    kld = -0.5 * torch.sum(log_var.exp() + mu.pow(2) - log_var - 1.)

    # E_Q(z)[log P(x|z)] - B*KL(P(z) || Q(z))
    beta = 0.01
    likelihood_statistic = reconstruction + beta*kld
    
    # Backward pass
    likelihood_statistic.backward(gradient=torch.ones_like(likelihood_statistic))    
    
    # Saliency map from gradients of input
    saliency = x.grad.data.numpy().reshape(28, 28)
    
    return saliency

def visualize_saliency_map(model, x):
    saliency = saliency_map(model, x)

    sns.heatmap(saliency, cmap='RdBu_r', alpha=1.0)
    plt.title("Saliency Map")
    plt.show()

In [None]:
visualize_saliency_map(full_mnist_model, example_test_8)

In [None]:
visualize_saliency_map(full_mnist_model, example_test_6)

In [None]:
visualize_saliency_map(full_mnist_model, example_test_2)

In [None]:
visualize_saliency_map(full_mnist_model, example_test_1)

### Encoder Mean and Var Saliency Map

In [None]:
def _encoder_saliency_map(model, x, target):
    
    # Wrap the input tensor in a Variable
    x = Variable(x, requires_grad=True)

    # Forward pass
    output = model.encode(x)[target]
    
    # Backward pass
    output.backward(gradient=torch.ones_like(output))
    
    # Saliency map from gradients of input
    saliency = x.grad.data.numpy().reshape(28, 28)
    
    return saliency
    
def encoder_mean_saliency_map(model, x):
    return _encoder_saliency_map(model, x, 0)

def encoder_var_saliency_map(model, x):
    return _encoder_saliency_map(model, x, 1)

def visualize_encoder_saliency_maps(model, x):
    mean_saliency = encoder_mean_saliency_map(model, x)
    var_saliency = encoder_var_saliency_map(model, x)

    sns.heatmap(mean_saliency, cmap='RdBu_r', alpha=0.8)
    plt.title("Encoder Mean Parameter Saliency Map")
    plt.show()
    
    sns.heatmap(var_saliency, cmap='RdBu_r', alpha=0.8)
    plt.title("Encoder Var Parameter Saliency Map")
    plt.show()

In [None]:
visualize_encoder_saliency_maps(full_mnist_model, example_test_8)

In [None]:
visualize_encoder_saliency_maps(full_mnist_model, example_test_6)

In [None]:
visualize_encoder_saliency_maps(full_mnist_model, example_test_2)

In [None]:
visualize_encoder_saliency_maps(full_mnist_model, example_test_1)