In [1]:
# Imports
import torch
cuda = torch.cuda.is_available()
import numpy as np
import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
plt.switch_backend("agg")
import sys
sys.path.append("../../semi-supervised")

from models import AuxiliaryDeepGenerativeModel, DeepGenerativeModel, StackedDeepGenerativeModel, VariationalAutoencoder

In [2]:
# Restore the three pretrained MNIST models used throughout this notebook:
# the VAE feature extractor, the stacked model built on top of it, and the
# auxiliary deep generative model (ADGM).
#
# Run on the GPU only when one is available, matching the `cuda` guard used in
# the evaluation cell, instead of calling `.cuda()` unconditionally.
device = torch.device("cuda" if cuda else "cpu")


def _restore(module, checkpoint_path):
    """Load a saved state dict into `module` and return the module on `device`.

    Args:
        module: a freshly constructed torch.nn.Module whose parameters match
            the checkpoint.
        checkpoint_path: path of a file written with torch.save(state_dict).

    Returns:
        The same module, moved to `device`.
    """
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved from a GPU session can still be opened on a CPU-only machine.
    state = torch.load(checkpoint_path, map_location=device)
    module.load_state_dict(state)
    return module.to(device)


features = _restore(
    VariationalAutoencoder(
        [784, 50, [500, 500]],
        batch_norm=False,
        activation_fn=torch.nn.Softplus),
    "./vae_mnist_new.ckpt")

print("After VAE")
stacked = StackedDeepGenerativeModel(
    [784, 10, 50, [300]], features,
    batch_norm=False,
    activation_fn=torch.nn.Softplus)
# Only the `dgm` sub-module has its own checkpoint; `features` (restored
# above) is handed to the constructor as-is.
stacked.dgm = _restore(stacked.dgm, "./m1m2_mnist_new.ckpt")
print(stacked.features, stacked.dgm)
print("After stacked")

adgm = _restore(
    AuxiliaryDeepGenerativeModel([784, 10, 100, 100, [500, 500]], batch_norm=False),
    "./adgm_mnist_new.ckpt")
# Bare last expression so the notebook displays the ADGM architecture.
adgm




  init.xavier_normal(m.weight.data)


After VAE
50
VariationalAutoencoder(
  (encoder): Encoder(
    (first_dense): ModuleList(
      (0): Linear(in_features=784, out_features=500, bias=True)
    )
    (hidden): ModuleList(
      (0): Softplus(beta=1, threshold=20)
      (1): Linear(in_features=500, out_features=500, bias=True)
      (2): Softplus(beta=1, threshold=20)
    )
    (sample): GaussianSample(
      (mu): Linear(in_features=500, out_features=50, bias=True)
      (log_var): Linear(in_features=500, out_features=50, bias=True)
    )
  )
  (decoder): Decoder(
    (first_dense): ModuleList(
      (0): Linear(in_features=50, out_features=500, bias=True)
    )
    (hidden): ModuleList(
      (0): Softplus(beta=1, threshold=20)
      (1): Linear(in_features=500, out_features=500, bias=True)
      (2): Softplus(beta=1, threshold=20)
    )
    (reconstruction): Linear(in_features=500, out_features=784, bias=True)
  )
) DeepGenerativeModel(
  (encoder): Encoder(
    (first_dense): ModuleList(
      (0): Linear(in_features=

  init.xavier_normal(m.weight.data)


AuxiliaryDeepGenerativeModel(
  (encoder): Encoder(
    (first_dense): ModuleList(
      (0): Linear(in_features=784, out_features=500, bias=True)
      (1): Linear(in_features=10, out_features=500, bias=True)
      (2): Linear(in_features=100, out_features=500, bias=True)
    )
    (hidden): ModuleList(
      (0): ReLU()
      (1): Linear(in_features=500, out_features=500, bias=True)
      (2): ReLU()
    )
    (sample): GaussianSample(
      (mu): Linear(in_features=500, out_features=100, bias=True)
      (log_var): Linear(in_features=500, out_features=100, bias=True)
    )
  )
  (decoder): Decoder(
    (first_dense): ModuleList(
      (0): Linear(in_features=100, out_features=500, bias=True)
      (1): Linear(in_features=10, out_features=500, bias=True)
    )
    (hidden): ModuleList(
      (0): ReLU()
      (1): Linear(in_features=500, out_features=500, bias=True)
      (2): ReLU()
    )
    (reconstruction): Linear(in_features=500, out_features=784, bias=True)
  )
  (classifier): 

In [3]:
from datautils import get_mnist, get_svhn

# Fetch the MNIST splits from the current directory. get_mnist returns five
# values, unpacked below; the first three are iterated as (x, y) batches or
# accessed via `.dataset` later in the notebook.
# NOTE(review): `labels_per_class=10` presumably limits the labelled split to
# 10 examples per digit - confirm against datautils.get_mnist.
mnist_splits = get_mnist(
    location="./",
    batch_size=100,
    labels_per_class=10,
    preprocess=False,
)
labelled, unlabelled, validation, mnist_mean, mnist_std = mnist_splits


In [4]:
# Draw 100 class-conditional samples from the ADGM: 10 latent draws for each
# of the 10 digit classes, laid out so that rows 0-9 are class 0, rows 10-19
# are class 1, and so on.
adgm.eval()
stacked.dgm.eval()
z_dim = 100

z = torch.randn(100, z_dim)

# One-hot labels: row i of the 10x10 identity matrix is the one-hot vector of
# class i, so indexing it with i // 10 yields class i // 10 for sample i.
y = np.eye(10)[np.arange(100) // 10]
y = torch.tensor(y, dtype=torch.float)

# Move the inputs to the GPU only when one is available, consistent with the
# `cuda` guard in the evaluation cell.
if cuda:
    z, y = z.cuda(), y.cuda()

x_mu = adgm.sample(z, y)




In [7]:
def classification_error(classify, loader, encode=None, n_repeats=100):
    """Return the classification error, in percent, of `classify` over `loader`.

    Each batch is replicated `n_repeats` times along the batch axis, classified,
    and the outputs are averaged over the replicas before taking the arg-max.

    Args:
        classify: callable mapping a batch of inputs to per-class scores.
        loader: iterable of (x, y) batches; `y` is compared via its arg-max
            along dim 1 (i.e. it is treated as one-hot / per-class scores).
        encode: optional callable applied to `x` before classification. It
            must return a 3-tuple whose first element is the new input.
        n_repeats: number of replicas averaged per input.

    Returns:
        100 * (1 - correct / total) as a Python float.

    Raises:
        ValueError: if `loader` yields no examples.
    """
    correct, total = 0, 0
    # Inference only: skip autograd bookkeeping for the replicated batch.
    with torch.no_grad():
        for x, y in loader:
            if cuda:
                x, y = x.cuda(device=0), y.cuda(device=0)
            if encode is not None:
                x, _, _ = encode(x)

            logits = classify(x.repeat(n_repeats, 1))
            # x.repeat tiles the whole batch, so the replicas form the
            # leading axis after this reshape; average them out.
            logits = logits.reshape(n_repeats, -1, logits.shape[-1]).mean(0)

            _, pred_idx = torch.max(logits, 1)
            _, lab_idx = torch.max(y, 1)
            correct += (pred_idx == lab_idx).sum().item()
            total += lab_idx.numel()

    if total == 0:
        raise ValueError("loader yielded no examples")
    # Count-based error: correct for any number of batches and for a ragged
    # final batch, unlike summing per-batch mean accuracies.
    return 100. * (1. - correct / total)


adgm.eval()
print("ADGM test error", classification_error(adgm.classify, validation))

stacked.features.eval()
stacked.dgm.eval()
# The stacked model classifies the output of the feature encoder, not raw x.
print("Stacked test error",
      classification_error(stacked.dgm.classify, validation,
                           encode=stacked.features.encoder))

ADGM test error 4.0399932861328125
Stacked test error 17.04998779296875


In [None]:
# Show the ADGM samples held in `x_mu` as a 10x10 grid of 28x28 grayscale
# images, one sample per panel in row-major order.
# NOTE(review): the imports cell selects the non-interactive "agg" backend,
# so plt.show() may not render anything inline - confirm.
samples = x_mu.detach().cpu().view(-1, 28, 28).numpy()

# Alternative kept for reference (3x32x32 colour images):
# samples = x_mu.data.view(-1, 3, 32, 32).cpu().numpy().transpose(0, 2, 3, 1)

# Alternative kept for reference (paste samples into the per-pixel mean
# wherever mnist_std > 0.1):
# mnist_means = np.tile(mnist_mean.reshape((1, -1)), (len(samples), 1))
# mnist_means[:, mnist_std > 0.1] = samples
# samples = mnist_means.reshape(-1, 28, 28)

f, axarr = plt.subplots(nrows=10, ncols=10, figsize=(10, 10))

for idx, panel in enumerate(axarr.flat):
    panel.imshow(samples[idx], cmap="gray")
    # Hide the frame and drop the tick marks so only the digit is visible.
    panel.axis("off")
    panel.set(xticks=[], yticks=[])

plt.tight_layout()
plt.show()

In [None]:
from metrics import sample_from_classes, interpolation, cyclic_interpolation, save_samples

# Put the stacked model's sub-modules in eval mode before running the metrics.
stacked.features.eval()
stacked.dgm.eval()

im_shape = [28, 28, 1]
classes_num = 10
z_dim = 100
labels_names = list(map(str, range(10)))

# (output tag, model) pairs. Each metric is run for the ADGM first and the
# stacked model second, and all runs of one metric finish before the next
# metric starts.
tagged_models = (("adgm_mnist", adgm), ("m1m2_mnist", stacked))

for tag, model in tagged_models:
    cyclic_interpolation(tag, model, validation.dataset, im_shape, classes_num, labels_names)

for tag, model in tagged_models:
    interpolation(tag, model, validation.dataset, im_shape)

# NOTE(review): the fourth argument (100 for ADGM, 50 for stacked) presumably
# is each model's latent size - confirm against metrics.sample_from_classes.
for (tag, model), size in zip(tagged_models, (100, 50)):
    sample_from_classes(tag, model, im_shape, size, classes_num)




In [None]:
# Write 10000 generated samples per model, ADGM first, then the stacked model.
# Relies on `im_shape` and `classes_num` defined in the metrics cell.
# NOTE(review): the last argument (100 / 50) presumably is each model's latent
# size - confirm against metrics.save_samples.
for tag, model, last_arg in (("adgm_mnist", adgm, 100), ("m1m2_mnist", stacked, 50)):
    save_samples(tag, model, im_shape, 10000, classes_num, last_arg)