In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

#%config InlineBackend.figure_format = 'svg'
#%config InlineBackend.figure_format = 'pdf'

In [None]:
import kmod
import kmod.glo as glo
import kmod.plot as plot
import kmod.util as util
import kmod.kernel as kernel
import kmod.ex.exutil as exutil

import matplotlib
import matplotlib.pyplot as plt
import os
import numpy as np
import scipy.stats as stats
import torch
import torch.autograd
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from kmod.gan_ume_opt import ume_power_criterion

In [None]:

# Global matplotlib styling: 18pt font and line width 2 for all figures.
font = {'size': 18}
plt.rc('font', **font)
plt.rc('lines', linewidth=2)

# Embed Type 42 (TrueType) fonts in PDF/PS output instead of Type 3 fonts.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

In [None]:
# Replace the leading True with False to force CPU execution even if a GPU exists.
use_cuda = True and torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    default_type = torch.cuda.FloatTensor
    # No extra options are passed to the checkpoint loader on a GPU machine.
    load_options = {}
else:
    device = torch.device("cpu")
    default_type = torch.FloatTensor
    # map_location keeps every deserialized storage where it is created (CPU),
    # so checkpoints saved from GPU tensors can be loaded without a GPU.
    load_options = {'map_location': lambda storage, loc: storage}
# Make this the default floating-point tensor type (CUDA tensors when a GPU is used).
torch.set_default_tensor_type(default_type)

In [None]:
# Locate the pretrained MNIST CNN classifier in the shared resource folder.
shared_resource_path = glo.shared_resource_folder()
model_folder = glo.shared_resource_folder('prob_models', 'mnist_cnn')

# Checkpoint naming scheme: mnist_cnn_ep<epochs>_s<seed>.pt
epochs, seed = 20, 1
model_fname = 'mnist_cnn_ep{}_s{}.pt'.format(epochs, seed)
model_fpath = os.path.join(model_folder, model_fname)

print('Shared resource path at: {}'.format(shared_resource_path))
print('Model folder: {}'.format(model_folder))
print('Model file: ', model_fname)

In [None]:
from kmod.mnist.classify import MnistClassifier

# Restore the pretrained classifier (a torch.nn.Module), put it in evaluation
# mode and move it to the selected device in one chain.
classifier = MnistClassifier.load(model_fpath, **load_options).eval().to(device)
display(classifier)

In [None]:
def norm(x, loc, scale):
    """Standardize ``x``: subtract ``loc`` and divide by ``scale``."""
    centered = x - loc
    return centered / scale


def mnist_norm(x):
    """Apply MNIST normalization with mean 0.1307 and std 0.3081.

    These are the same constants as the ``transforms.Normalize`` used for the
    MNIST test set in this notebook; ``x`` is presumably in [0, 1] pixel
    units (assumption — confirm for each caller).
    """
    mnist_mean, mnist_std = 0.1307, 0.3081
    return norm(x, mnist_mean, mnist_std)


def trans_gan(x):
    """Map GAN samples into the classifier's normalized input space.

    First rescales with ``(x + 1) / 2`` (maps [-1, 1] to [0, 1]; presumably
    the generators emit values in [-1, 1] — confirm), then applies the MNIST
    normalization.
    """
    unit_range = norm(x, -1.0, 2.0)
    return mnist_norm(unit_range)

def trans_vae(x):
    """MNIST-normalize the tensor ``x`` and reshape it to a batch of 1x28x28 images."""
    normalized = mnist_norm(x)
    return normalized.view(-1, 1, 28, 28)

def get_trans(model_type):
    """Return the transform that maps samples of ``model_type`` into the
    classifier's normalized input space.

    Parameters
    ----------
    model_type : str
        Case-insensitive model name; must be in ``exutil.mnist_model_names``.

    Returns
    -------
    callable
        ``trans_gan`` for any name containing 'gan', ``mnist_norm`` for 'vae'.

    Raises
    ------
    ValueError
        If the name is not a known model, or no transform is defined for it.
    """
    name = model_type.lower()
    if name not in exutil.mnist_model_names:
        raise ValueError('Model name has to be one of '
                         '{} and was {}'.format(exutil.mnist_model_names, name))
    print('Model: {}'.format(name))
    if 'gan' in name:
        return trans_gan
    if name == 'vae':
        return mnist_norm
    # Known model without a transform: fail loudly instead of returning None,
    # which would only surface later as "'NoneType' object is not callable".
    raise ValueError('No input transform defined for model {}'.format(name))

In [None]:
# these two lines are for loading DCGAN 
from kmod.mnist.dcgan import Generator
from kmod.mnist.dcgan import DCGAN

import kmod.mnist.dcgan as mnist_dcgan
import kmod.net as net
import kmod.gen as gen

def vae_sample(vae, n, latent_dim=20):
    """Draw ``n`` images from a VAE by decoding standard-normal latent codes.

    Parameters
    ----------
    vae : object
        Must provide ``decode(z)`` for a batch of latent codes ``z``.
    n : int
        Number of images to generate.
    latent_dim : int, optional
        Dimensionality of the latent code (default 20, the value this
        notebook's VAE uses).

    Returns
    -------
    torch.Tensor
        Decoded images reshaped with ``view(n, -1, 28, 28)``; no autograd
        graph is attached.
    """
    latent = torch.randn(n, latent_dim).to(device)
    # Sampling only: do not build an autograd graph through the decoder.
    with torch.no_grad():
        gen_imgs = vae.decode(latent)
    # To obtain binary images instead, sample from
    # torch.distributions.Bernoulli(probs=gen_imgs).
    return gen_imgs.view(n, -1, 28, 28)

## Load models and generate samples

In [None]:
from kmod.ex import exutil
# Model p: generator of type 'lsgan' loaded at epoch 30.
model_type_p, epoch_p = 'lsgan', 30
gen_p = exutil.load_mnist_gen(model_type_p, epoch_p, default_type, **load_options)
# Display name used in plot legends, e.g. 'LSGAN-30'.
model_name_p = '{}-{}'.format(model_type_p.upper(), epoch_p)

In [None]:
# Model q: generator of type 'dcgan' loaded at epoch 50.
model_type_q, epoch_q = 'dcgan', 50
gen_q = exutil.load_mnist_gen(model_type_q, epoch_q, default_type, **load_options)
# Display name used in plot legends, e.g. 'DCGAN-50'.
model_name_q = '{}-{}'.format(model_type_q.upper(), epoch_q)

In [None]:
# Number of images drawn from each generator for display and classification.
n_gen = 4000
# Generated images from p (raw generator output, not yet transformed).
gen_imgs_p = gen_p.sample(n_gen)

In [None]:
import kmod.mnist.util as mnist_util
# Everywhere else in this notebook the classifier receives transformed inputs
# (trans_p/trans_q, or the Normalize'd MNIST dataset). Apply the model-specific
# transform here too. Otherwise the digit labels used to sort the display are
# predicted from mis-scaled inputs.
trans_display_p = get_trans(model_type_p)


def digit_mapper(Xte):
    """Predicted digit (argmax of the classifier output) for each p-image."""
    with torch.no_grad():
        return torch.argmax(classifier(trans_display_p(Xte)), dim=1)


print('Sample from p =', model_type_p)
mnist_util.show_sorted_digits(gen_imgs_p, digit_mapper, n_per_row=8, figsize=(8, 8))

In [None]:
gen_imgs_q = gen_q.sample(n_gen)  # raw generator output from q (not yet transformed)

In [None]:
# Classify q-images after q's own transform (the classifier expects
# normalized inputs; see the classification cells below).
trans_display_q = get_trans(model_type_q)


def digit_mapper_q(Xte):
    """Predicted digit (argmax of the classifier output) for each q-image."""
    with torch.no_grad():
        return torch.argmax(classifier(trans_display_q(Xte)), dim=1)


print('Sample from q = ', model_type_q)
mnist_util.show_sorted_digits(gen_imgs_q, digit_mapper_q, n_per_row=8, figsize=(8, 8))

## Classify generated samples

In [None]:
# Per-model transforms into the classifier's normalized input space
# (evaluated left to right: p first, then q).
trans_p, trans_q = get_trans(model_type_p), get_trans(model_type_q)

In [None]:
# Classify the generated p-images in mini-batches, then count how many are
# assigned to each digit 0-9.
pred_results = []
batch_size = 100
# Inference only: skip autograd bookkeeping for every classifier call.
with torch.no_grad():
    for i in range(0, n_gen, batch_size):
        x = trans_p(gen_imgs_p[i:i+batch_size])
        pred_results.append(torch.argmax(classifier(x), dim=1))
pred_results_p = torch.cat(pred_results)
pred_num_p = [torch.sum(pred_results_p == digit).item() for digit in range(10)]

In [None]:
# Classify the generated q-images in mini-batches, then count how many are
# assigned to each digit 0-9.
pred_results = []
batch_size = 100
# Inference only: skip autograd bookkeeping for every classifier call.
with torch.no_grad():
    for i in range(0, n_gen, batch_size):
        x = trans_q(gen_imgs_q[i:i+batch_size])
        pred_results.append(torch.argmax(classifier(x), dim=1))
pred_results_q = torch.cat(pred_results)
pred_num_q = [torch.sum(pred_results_q == digit).item() for digit in range(10)]

In [None]:
# Percentage of p-samples the classifier assigns to each digit.
plt.bar(np.arange(10), np.asarray(pred_num_p) / np.sum(pred_num_p) * 100, alpha=0.5)
plt.xlabel('class')
plt.ylabel('pred[%]')
print('Class proportions of generated samples from p')

In [None]:
# Percentage of q-samples the classifier assigns to each digit.
print('Class proportions of generated samples from q')
plt.bar(np.arange(10), np.asarray(pred_num_q) / np.sum(pred_num_q) * 100, alpha=0.5)
plt.xlabel('class')
plt.ylabel('pred[%]');

In [None]:
# Side-by-side bars comparing the predicted-digit proportions of p and q.
# Create the figure and axes in one call so the requested size applies to the
# axes that are actually drawn on (no stray empty figure).
fig, ax = plt.subplots(figsize=(10, 6))
index = np.arange(10)
bar_width = 0.35
opacity = 0.7
pred_per_p = np.asarray(pred_num_p) / np.sum(pred_num_p) * 100
pred_per_q = np.asarray(pred_num_q) / np.sum(pred_num_q) * 100

rects1 = ax.bar(index, pred_per_p, bar_width, alpha=opacity,
                color='r', label=model_name_p, hatch='.')
rects2 = ax.bar(index + bar_width, pred_per_q, bar_width, alpha=opacity,
                color='b', label=model_name_q, hatch='')
ax.set_xlabel('Digit')
ax.set_ylabel('Proportion [%]')
# Center each digit label under its pair of bars.
ax.set_xticks(index + bar_width / 2)
ax.set_xticklabels(index)
ax.legend(loc='upper right', ncol=2, bbox_to_anchor=(1.0, 1.25));


## Power criterion bar chart

In [None]:
def extractor(imgs):
    """
    Feature extractor: pass ``imgs`` through the classifier's two
    convolution stages (conv -> 2x2 max-pool -> ReLU) and flatten the
    result with ``view(-1, 320)``.
    """
    net = classifier
    h = imgs
    for conv in (net.conv1, net.conv2):
        h = F.relu(F.max_pool2d(conv(h), 2))
    return h.view(-1, 320)

def extractor_cls(imgs):
    """
    Feature extractor one layer deeper than ``extractor``: the flattened
    320-dim convolutional features passed through the classifier's ``fc1``
    layer followed by a ReLU.
    """
    # Reuse the shared convolutional trunk instead of duplicating it.
    features = extractor(imgs)
    return F.relu(classifier.fc1(features))


In [None]:
# Feature map used by the power criterion: the classifier's own output.
# Alternatives defined above: `extractor` (flattened conv features) or
# `extractor_cls` (features after fc1).
featurizer = classifier

# Real MNIST test images, normalized with the same mean/std as mnist_norm.
mnist_folder = glo.data_file('mnist')
mnist_dataset = torchvision.datasets.MNIST(
    mnist_folder,
    train=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)

In [None]:
# Experiment configuration.
num_sample = 2000   # images drawn from each generator per trial
num_classes = 10
J = 40              # held-out real images per digit, used as V in ume_power_criterion
reg = 1e-4          # regularization passed to ume_power_criterion
n_sample_per_class = num_sample // num_classes  # real images per digit in Z
len_data = len(mnist_dataset)
input_Z = []

# Materialize the (transformed) test set once instead of iterating it twice.
samples = [mnist_dataset[i] for i in range(len_data)]
# int() accepts both plain-int labels (newer torchvision) and 0-dim tensors
# (older torchvision); torch.stack fails on plain ints. Keep labels on the
# CPU, like the images, so boolean masks built from them can index mnist_X.
mnist_Y = torch.tensor([int(label) for _, label in samples],
                       dtype=torch.long, device='cpu')
mnist_X = torch.stack([img for img, _ in samples])
del samples

In [None]:
def slice_array(arr, sizes):
    """Split ``arr`` into consecutive pieces whose lengths are given by ``sizes``.

    Piece ``k`` is ``arr[o_k:o_k + sizes[k]]``, where ``o_k`` is the sum of the
    preceding sizes. No bounds check is made, so a piece may come out shorter
    than requested (or empty) when ``arr`` runs out.

    Raises
    ------
    ValueError
        If ``sizes`` is empty/None or its entries sum to zero.
    """
    invalid = (not sizes) or sum(sizes) == 0 or len(sizes) == 0
    if invalid:
        raise ValueError('sizes cannot be empty. Was {}'.format(sizes))
    pieces = []
    start = 0
    for size in sizes:
        stop = start + size
        pieces.append(arr[start:stop])
        start = stop
    return pieces

In [None]:
# Repeat the power-criterion evaluation num_trials times. Each trial draws
# fresh generator samples and a fresh random split of the real MNIST images,
# then scores every digit class.
num_trials = 100
results = np.empty([num_trials, num_classes])

for i in range(num_trials):
    seed = i
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    X = featurizer(trans_p(gen_p.sample(num_sample)))
    Y = featurizer(trans_q(gen_q.sample(num_sample)))
    # For each digit: reorder that digit's real images with util.subsample_ind
    # (presumably a seeded random permutation — confirm). Then take the first
    # n_sample_per_class images into Z and the next J into that digit's V.
    V_list = []
    Z_list = []
    for j in range(num_classes):
        idx = (mnist_Y == j)
        X_j = mnist_X[idx]
        n_j = len(X_j)
        rand_idx = util.subsample_ind(n_j, n_j, seed=seed)
        Z, V = slice_array(X_j[rand_idx], [n_sample_per_class, J])
        Z_list.append(Z)
        V_list.append(V)
    Z = featurizer(torch.cat(Z_list).to(device))

    # Kernel width from util.meddistance on the pooled X, Y, Z features
    # (presumably the median pairwise distance); PTKGauss takes its square.
    XYZ = np.vstack((X.cpu().data.numpy(), Y.cpu().data.numpy(), Z.cpu().data.numpy()))
    med = util.meddistance(XYZ, subsample=1000)
    # NOTE(review): requires_grad looks unnecessary for pure evaluation — confirm.
    gwidth2 = torch.tensor(med**2, requires_grad=True, device=device)
    k = kernel.PTKGauss(gwidth2)
    for j in range(num_classes):
        V = featurizer(V_list[j].to(device))
        # float() accepts a Python number or a 1-element tensor (even on CUDA
        # or tracking gradients). numpy cannot reliably store the latter directly.
        results[i, j] = float(ume_power_criterion(X, Y, Z, V, V, k, reg))

In [None]:
# Mean power criterion per digit across all trials.
plt.bar(np.arange(10), np.mean(results, 0), alpha=0.5)
plt.xlabel('Digit')
plt.ylabel('Power Criterion')
# Trailing ';' suppresses the repr of the returned tick objects.
plt.xticks(np.arange(10));

In [None]:
print(results.std(axis=0))  # per-digit std of the power criterion across trials

Plot the results as a stack of histograms of power criteria (one histogram for each digit).

In [None]:
# https://matplotlib.org/examples/mplot3d/bars3d_demo.html
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# One density-normalized 10-bin histogram of the trial-wise power criteria per
# digit, placed along the y (digit) axis.
for di in range(10):
    counts, edges = np.histogram(results[:, di], bins=10, density=True)
    bin_centers = 0.5 * (edges[:-1] + edges[1:])
    assert len(counts) == len(bin_centers)
    bin_width = edges[1] - edges[0]
    ax.bar(bin_centers, counts, width=bin_width, zs=di, zdir='y', alpha=0.6)

ax.set_xlabel('Power Criterion')
ax.set_ylabel('Digit')
ax.set_yticks(np.arange(10))
ax.set_zlabel('Density')
for out_path in ('mnist_3d_powcri_hists.pdf', 'mnist_3d_powcri_hists.png'):
    fig.savefig(out_path, bbox_inches='tight')

Plot the results as a stack of violin plots.

In [None]:
# One violin per digit showing the spread of the power criterion over trials.
list_results = list(results.T)
plt.violinplot(list_results, positions=np.arange(10))
plt.xticks(np.arange(10))
plt.xlabel('Digit')
plt.ylabel('Power Criterion');

Plot the results as notched box plots (one box per digit).

In [None]:
# Notched box plot per digit with thick dark-red median lines.
medianprops = {'linestyle': '-', 'linewidth': 3, 'color': 'firebrick'}
plt.figure(figsize=(8, 4))
plt.boxplot(list_results, notch=True, medianprops=medianprops)
# boxplot places boxes at x = 1..10; relabel them with the digits 0..9.
plt.xticks(np.arange(1, 11), np.arange(0, 10))
plt.xlabel('Digit')
plt.ylabel('Power Criterion');