In [1]:
# Imports
import torch
# True when a CUDA device is usable; the evaluation cell consults this flag
# before moving batches to the GPU.
cuda = torch.cuda.is_available()
import numpy as np
import matplotlib
# Select the non-interactive Agg backend before pyplot is imported.
matplotlib.use("agg")
import matplotlib.pyplot as plt
# NOTE(review): this repeats the backend selection made by matplotlib.use above —
# confirm whether both calls are needed.
plt.switch_backend("agg")
import sys
# Extend the module search path; presumably this is where the project-local
# `models`, `datautils` and `metrics` modules live — verify against the repo layout.
sys.path.append("../../semi-supervised")

# Select PyTorch's 'file_system' strategy for sharing tensors across processes.
torch.multiprocessing.set_sharing_strategy('file_system')

from models import AuxiliaryDeepGenerativeModel, DeepGenerativeModel, StackedDeepGenerativeModel, VariationalAutoencoder

In [2]:
# Restore the three pretrained SVHN models from their checkpoints.
#
# The target device is derived from the `cuda` flag computed in the first cell
# so this cell also works on CPU-only hosts; `map_location` lets checkpoints
# that were saved from GPU tensors be deserialized onto whichever device is in use.
# "cuda:0" matches the `device=0` used when batches are moved in the evaluation cell.
device = "cuda:0" if cuda else "cpu"

# Feature extractor (VAE) used as the first stage of the stacked model.
# NOTE(review): the dims list looks like [input, latent, [hidden...]] — confirm
# against the VariationalAutoencoder constructor.
features = VariationalAutoencoder(
    [32 * 32 * 3, 300, [600, 600]],
    activation_fn=torch.nn.Softplus,
    batch_norm=False).to(device)
features.load_state_dict(torch.load("./vae_svhn_new.ckpt", map_location=device))

# Stacked model built on top of the pretrained feature extractor.
stacked = StackedDeepGenerativeModel(
    [32 * 32 * 3, 10, 100, [500]],
    features,
    activation_fn=torch.nn.Softplus,
    batch_norm=False
)

# Only the inner `dgm` has its own checkpoint; `features` was restored above.
stacked.dgm.load_state_dict(torch.load("./m1m2_svhn_new.ckpt", map_location=device))
stacked.dgm = stacked.dgm.to(device)

# Auxiliary deep generative model.
adgm = AuxiliaryDeepGenerativeModel([3072, 10, 300, 300, [1000, 1000]], batch_norm=False)
adgm.load_state_dict(torch.load("./adgm_svhn_new.ckpt", map_location=device))
adgm = adgm.to(device)

# Switch every network to evaluation mode. The last call is left as the final
# expression so the notebook still displays the feature extractor's structure.
adgm.eval()
stacked.dgm.eval()
stacked.features.eval()



  init.xavier_normal(m.weight.data)


300


  init.xavier_normal(m.weight.data)


3072
[3072, 300]


VariationalAutoencoder(
  (encoder): Encoder(
    (first_dense): ModuleList(
      (0): Linear(in_features=3072, out_features=600, bias=True)
    )
    (hidden): ModuleList(
      (0): Softplus(beta=1, threshold=20)
      (1): Linear(in_features=600, out_features=600, bias=True)
      (2): Softplus(beta=1, threshold=20)
    )
    (sample): GaussianSample(
      (mu): Linear(in_features=600, out_features=300, bias=True)
      (log_var): Linear(in_features=600, out_features=300, bias=True)
    )
  )
  (decoder): Decoder(
    (first_dense): ModuleList(
      (0): Linear(in_features=300, out_features=600, bias=True)
    )
    (hidden): ModuleList(
      (0): Softplus(beta=1, threshold=20)
      (1): Linear(in_features=600, out_features=600, bias=True)
      (2): Softplus(beta=1, threshold=20)
    )
    (reconstruction): Linear(in_features=600, out_features=3072, bias=True)
  )
)

In [3]:
from datautils import get_mnist, get_svhn

# Fetch the SVHN splits from the current directory. The first three return
# values are kept as the labelled / unlabelled / validation iterables used by
# the later cells; the fourth return value is discarded.
labelled, unlabelled, validation, _ = get_svhn(
    location="./",
    batch_size=1000,
    labels_per_class=100,
    extra=False,
)



Using downloaded and verified file: ./train_32x32.mat
Len of svhn train 73257
Using downloaded and verified file: ./test_32x32.mat


In [4]:
# Draw 100 class-conditional samples from the ADGM: ten latent draws for each
# of the ten digit classes.
# NOTE(review): 300 presumably matches the ADGM latent size given in its
# constructor dims — confirm.
z_dim = 300

# Follow the `cuda` flag from the first cell instead of assuming a GPU exists.
device = "cuda:0" if cuda else "cpu"

z = torch.randn(100, z_dim).to(device)

# One-hot label matrix: rows 0-9 select class 0, rows 10-19 class 1, and so on.
y = np.zeros((100, 10))
y[np.arange(100), np.arange(100) // 10] = 1
y = torch.tensor(y, dtype=torch.float).to(device)

# Pure inference: nothing is back-propagated through the samples, so skip
# autograd bookkeeping.
with torch.no_grad():
    x_mu = adgm.sample(z, y)




In [6]:
from tqdm import tqdm_notebook


def _accuracy(loader, classify):
    """Return the fraction of `loader.dataset` that `classify` labels correctly.

    Args:
        loader: iterable yielding `(x, y)` batches and exposing `.dataset`;
            `y` is reduced with an arg-max over dim 1, i.e. it is treated as
            one score per class for each row.
        classify: callable mapping a batch `x` to per-class logits.

    Returns:
        Number of correct predictions divided by `len(loader.dataset)`, as a float.
    """
    correct = 0.0
    # Evaluation only: no gradients are needed, so avoid building autograd graphs.
    with torch.no_grad():
        for x, y in tqdm_notebook(loader):
            if cuda:
                x, y = x.cuda(device=0), y.cuda(device=0)

            logits = classify(x)
            _, pred_idx = torch.max(logits, 1)
            _, lab_idx = torch.max(y, 1)
            # Cast to float before summing, as the original loop did.
            correct += torch.sum((pred_idx == lab_idx).float()).item()

    return correct / len(loader.dataset)


def _adgm_logits(x, n_passes=10):
    """Average ADGM logits over `n_passes` copies of the batch.

    The batch is tiled `n_passes` times along dim 0, classified in one call,
    and the logits are averaged across the copies.
    NOTE(review): presumably this smooths out sampling noise inside
    `adgm.classify` — confirm that the classifier is stochastic.
    """
    logits = adgm.classify(x.repeat(n_passes, 1))
    return logits.reshape(n_passes, -1, logits.shape[-1]).mean(0)


def _stacked_logits(x):
    """Encode the batch with the feature extractor, then classify the first encoder output."""
    encoded, _, _ = stacked.features.encoder(x)
    return stacked.dgm.classify(encoded)


adgm.eval()
accuracy = _accuracy(validation, _adgm_logits)
print("ADGM Accuracy {:.3f}\tError: {:.3f}".format(accuracy, 1 - accuracy))

stacked.features.eval()
stacked.dgm.eval()
accuracy = _accuracy(validation, _stacked_logits)
print("Stacked Accuracy {:.3f}\tError: {:.3f}".format(accuracy, 1 - accuracy))


HBox(children=(IntProgress(value=0, max=27), HTML(value='')))


ADGM Accuracy 0.589	Error: 0.411


HBox(children=(IntProgress(value=0, max=27), HTML(value='')))


Stacked Accuracy 0.346	Error: 0.654


In [13]:
from metrics import sample_from_classes, interpolation, cyclic_interpolation, save_samples

# Make sure every network is in evaluation mode before any figures are produced.
for net in (adgm, stacked.dgm, stacked.features):
    net.eval()

# Shared settings for the figure-generation cells below.
im_shape = [32, 32, 3]      # 32 x 32 images with 3 channels
classes_num, z_dim = 10, 100
labels_names = list(map(str, range(10)))  # "0" .. "9"



In [15]:
# One entry per model: (name passed to the metrics helpers, model, size passed
# to sample_from_classes).
# NOTE(review): the string is presumably an output-file prefix and the integer
# the latent size — confirm in metrics.py.
figure_runs = (
    ("adgm_svhn", adgm, 300),
    ("m1m2_svhn", stacked, 100),
)

# Each helper is run for both models before moving on to the next helper,
# which keeps the original call order.
for run_name, run_model, _latent in figure_runs:
    cyclic_interpolation(run_name, run_model, validation.dataset, im_shape, classes_num, labels_names)

for run_name, run_model, _latent in figure_runs:
    interpolation(run_name, run_model, validation.dataset, im_shape)

for run_name, run_model, latent in figure_runs:
    sample_from_classes(run_name, run_model, im_shape, latent, classes_num)

In [17]:
# Call save_samples once per model, ADGM first and then the stacked model.
# NOTE(review): 10000 is presumably the number of samples to write and the
# last argument the latent size — confirm in metrics.py.
for run_name, run_model, latent in (("adgm_svhn", adgm, 300), ("m1m2_svhn", stacked, 100)):
    save_samples(run_name, run_model, im_shape, 10000, classes_num, latent)

100%|██████████| 10/10 [00:03<00:00,  2.61it/s]
100%|██████████| 10/10 [00:04<00:00,  2.24it/s]
