In [1]:
# Imports
import torch
cuda = torch.cuda.is_available()
import numpy as np
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
plt.switch_backend("agg")
import sys
sys.path.append("../../semi-supervised")

from models import AuxiliaryDeepGenerativeModel, DeepGenerativeModel, StackedDeepGenerativeModel, VariationalAutoencoder

In [2]:
features = VariationalAutoencoder([784, 50, [500, 500]]).cuda()
features.load_state_dict(torch.load("./vae_mnist.ckpt"))
stacked = StackedDeepGenerativeModel([784, 10, 50, [500]], features)
stacked.dgm.load_state_dict(torch.load("./m1m2_mnist.ckpt"))
stacked.dgm = stacked.dgm.cuda()

adgm = AuxiliaryDeepGenerativeModel([784, 10, 100, 100, [500, 500]])
adgm.load_state_dict(torch.load("./adgm_mnist.ckpt"))
adgm = adgm.cuda()




  init.xavier_normal(m.weight.data)


50
Linear layers! [ReLU(), BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)]
784
Linear layers! [ReLU(), BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), Linear(in_features=500, out_features=500, bias=True), ReLU(), BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)]


  init.xavier_normal(m.weight.data)


[784, 100]
Linear layers! [ReLU(), BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), Linear(in_features=500, out_features=500, bias=True), ReLU(), BatchNorm1d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)]


In [3]:
from datautils import get_mnist, get_svhn

labelled, unlabelled, validation, mnist_mean, mnist_std = get_mnist(location="./", batch_size=100, labels_per_class=10, preprocess=False)


In [4]:
adgm.eval()
stacked.dgm.eval()
z_dim = 100

z = torch.randn(100, z_dim).cuda()
y = np.zeros((100, 10))
y[np.arange(100), np.arange(100) // 10] = 1.
y = torch.tensor(y, dtype=torch.float).cuda()

x_mu = adgm.sample(z, y)




In [8]:
adgm.eval()

accuracy = 0.
for x, y in validation:

    if cuda:
        x, y = x.cuda(device=0), y.cuda(device=0)

    # x, _, _ = features.encoder(x)
    logits = adgm.classify(x)
    _, pred_idx = torch.max(logits, 1)
    _, lab_idx = torch.max(y, 1)
    accuracy += torch.mean((torch.max(logits, 1)[1].data == torch.max(y, 1)[1].data).float())

print(100 - accuracy)


tensor(5.1100, device='cuda:0')


In [6]:
f, axarr = plt.subplots(10, 10, figsize=(10, 10))

samples = x_mu.cpu().data.view(-1, 28, 28).numpy()
# samples = x_mu.data.view(-1, 3, 32, 32).cpu().numpy().transpose(0, 2, 3, 1)


# mnist_means = np.tile(mnist_mean.reshape((1, -1)), (len(samples), 1))
# mnist_means[:, mnist_std > 0.1] = samples
# samples = mnist_means.reshape(-1, 28, 28)


for i, ax in enumerate(axarr.flat):
    ax.imshow(samples[i], cmap="gray")
    ax.axis("off")
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.show()

In [9]:
from metrics import sample_from_classes, interpolation, cyclic_interpolation
stacked.features.eval()

im_shape = (28, 28, 1)
classes_num = 10
z_dim = 100
labels_names = [str(idx) for idx in range(10)]

cyclic_interpolation("adgm_mnist", adgm, validation.dataset, im_shape, classes_num, labels_names)
cyclic_interpolation("m1m2_mnist", stacked, validation.dataset, im_shape, classes_num, labels_names)

interpolation("adgm_mnist", adgm, validation.dataset, im_shape)
interpolation("m1m2_mnist", stacked, validation.dataset, im_shape)

sample_from_classes("adgm_mnist", adgm, im_shape, 100, classes_num)
sample_from_classes("m1m2_mnist", stacked, im_shape, 50, classes_num)

