In [1]:
import torch
cuda = torch.cuda.is_available()
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import sys
sys.path.append("../../semi-supervised")

In [2]:
from models import ProdLDADeepGenerativeModel

In [3]:
y_dim = 10
z_dim = 32
h_dim = [256, 128]

num_topics = y_dim
a = 1.0
prior_mean = np.log(a) - np.mean(np.log(a))
prior_var = (((1.0 / a) * (1 - (2.0 / num_topics))) + (1.0 / (num_topics * num_topics)) * np.sum((1.0 / a)))

In [4]:
model = ProdLDADeepGenerativeModel([784, y_dim, z_dim, h_dim], prior_mean, prior_var)

In [5]:
model

ProdLDADeepGenerativeModel(
  (encoder): Encoder(
    (hidden): ModuleList(
      (0): Linear(in_features=794, out_features=256, bias=True)
      (1): Linear(in_features=256, out_features=128, bias=True)
    )
    (sample): GaussianSample(
      (mu): Linear(in_features=128, out_features=32, bias=True)
      (log_var): Linear(in_features=128, out_features=32, bias=True)
    )
    (mean_norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (logvar_norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (decoder): Decoder(
    (activation): Softmax()
    (drop): Dropout(p=0.2)
    (hidden): ModuleList(
      (0): Linear(in_features=42, out_features=784, bias=True)
    )
    (norm): BatchNorm1d(784, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (output_activation): Softmax()
  )
  (classifier): Classifier(
    (dense): Linear(in_features=784, out_features=256, bias=True)
    (logits): Linear(in

In [6]:
from datautils import get_mnist

# Only use 10 labelled examples per class
# The rest of the data is unlabelled.
labelled, unlabelled, validation = get_mnist(location="./", batch_size=64, labels_per_class=10)
alpha = 0.1 * len(unlabelled) / len(labelled)

def binary_cross_entropy(r, x):
    return -torch.sum(x * torch.log(r + 1e-8) + (1 - x) * torch.log(1 - r + 1e-8), dim=-1)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.999))

In [7]:
from itertools import cycle
from inference import SVI, ImportanceWeightedSampler

# You can use importance weighted samples [Burda, 2015] to get a better estimate
# on the log-likelihood.
sampler = ImportanceWeightedSampler(mc=1, iw=1)

if cuda: model = model.cuda()
elbo = SVI(model, likelihood=binary_cross_entropy, sampler=sampler)

In [249]:
from torch.autograd import Variable

def train_semi_supervised(labelled, unlabelled, validation, cuda, epochs=4):
    for epoch in range(epochs):
        model.train()
        total_loss, accuracy = (0, 0)
        for (x, y), (u, _) in zip(cycle(labelled), unlabelled):
            x, y, u = torch.from_numpy(x).float(), torch.from_numpy(y).float(), torch.from_numpy(u).float()
            # Wrap in variables
            x, y, u = Variable(x), Variable(y), Variable(u)

            if cuda:
                # They need to be on the same device and be synchronized.
                x, y = x.cuda(device=0), y.cuda(device=0)
                u = u.cuda(device=0)

            L = -elbo(x, y)
            U = -elbo(u)

            # Add auxiliary classification loss q(y|x)
            logits = model.classify(x)

            # Regular cross entropy
            classication_loss = torch.sum(y * torch.log(logits + 1e-8), dim=1).mean()

            J_alpha = L - alpha * classication_loss + U

            J_alpha.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += J_alpha.data.item()
            accuracy += torch.mean((torch.max(logits, 1)[1].data == torch.max(y, 1)[1].data).float())

        if epoch % 1 == 0:
            model.eval()
            m = len(unlabelled)
            print("Epoch: {}".format(epoch))
            print("[Train]\t\t J_a: {:.2f}, accuracy: {:.2f}".format(total_loss / m, accuracy / m))

            total_loss, accuracy = (0, 0)
            for x, y in validation:
                x, y = Variable(x), Variable(y)

                if cuda:
                    x, y = x.cuda(device=0), y.cuda(device=0)

                L = -elbo(x, y)
                U = -elbo(x)

                logits = model.classify(x)
                classication_loss = -torch.sum(y * torch.log(logits + 1e-8), dim=1).mean()

                J_alpha = L + alpha * classication_loss + U

                total_loss += J_alpha.data.item()

                _, pred_idx = torch.max(logits, 1)
                _, lab_idx = torch.max(y, 1)
                accuracy += torch.mean((torch.max(logits, 1)[1].data == torch.max(y, 1)[1].data).float())

            m = len(validation)
            print("[Validation]\t J_a: {:.2f}, accuracy: {:.2f}".format(total_loss / m, accuracy / m))

# Generate MC samples

In [274]:
def mc_samples(num_mc_samples, model, x_batch):
    x_batch = torch.from_numpy(x_batch).float()
    if cuda:
        x_batch = x_batch.cuda()
    model.train()
    mc_samples_ = [model.classify(x_batch).cpu().detach().numpy() for _ in range(num_mc_samples)]
    return np.array(mc_samples_)

In [275]:
def bald_acq(mc_samples):
    expected_entropy = - np.mean(np.sum(mc_samples * np.log(mc_samples + 1e-10), axis=-1), axis=0)  # [batch size]
    expected_p = np.mean(mc_samples, axis=0)
    entropy_expected_p = - np.sum(expected_p * np.log(expected_p + 1e-10), axis=-1)  # [batch size]
    BALD_acq = entropy_expected_p - expected_entropy
    
    return BALD_acq

In [276]:
def query_new_data(num_data, num_mc_samples, model, unlabelled_data):
    mc_samples_ = mc_samples(num_mc_samples, model, unlabelled_data)
    bald_acq_ = bald_acq(mc_samples_)
    
    return bald_acq_.argsort()[::-1][:num_data]

In [291]:
NUM_MC_SAMPLES = 100
NUM_QUERY = 60

In [278]:
train_labelled = list(labelled)
train_unlabelled = list(unlabelled)

# Active Learning

In [279]:
import torchvision.datasets as datasets

In [280]:
mnist_train = datasets.MNIST(root=".", train=True, download=True, transform=None)

In [281]:
mnist_train

Dataset MNIST
    Number of datapoints: 60000
    Split: train
    Root Location: .
    Transforms (if any): None
    Target Transforms (if any): None

In [282]:
mnist_train_data = mnist_train.train_data.reshape(60000, -1).numpy().astype(np.float64)
mnist_train_label = mnist_train.train_labels.numpy().astype(np.int64)
one_hot_label = np.zeros([mnist_train_label.shape[0], 10])
one_hot_label[np.arange(mnist_train_label.shape[0]), mnist_train_label] = 1
mnist_train_label = one_hot_label

In [283]:
mnist_train_data.shape, mnist_train_label.shape

((60000, 784), (60000, 10))

In [284]:
num_labelled = 300
labelled_idx = np.random.choice(mnist_train_data.shape[0], num_labelled, replace=False)

In [285]:
mask = np.zeros(mnist_train_data.shape[0])
mask[labelled_idx] = 1
mask = mask.astype(np.bool)

In [286]:
labelled = mnist_train_data[mask], mnist_train_label[mask]
unlabelled = mnist_train_data[~mask], mnist_train_label[~mask]

In [287]:
def create_batch(data, batch_size=64):
    x, y = data
    batch_idx = np.random.choice(x.shape[0], batch_size, replace=False)

    return x[batch_idx], y[batch_idx]

In [288]:
def create_data_sets(labelled, unlabelled, batch_size):
    num_labelled = labelled[0].shape[0]
    num_unlabelled = unlabelled[0].shape[0]

    train_labelled = [create_batch(labelled, batch_size) for _ in range(num_labelled // batch_size)]
    train_unlabelled = [create_batch(labelled, batch_size) for _ in range(num_unlabelled // batch_size)]
    
    return train_labelled, train_unlabelled

In [289]:
def rearange_datasets(labelled, unlabelled, new_data):
    labelled_x, labelled_y = labelled
    unlabelled_x, unlabelled_y = unlabelled
    
    new_data_x, new_data_y = unlabelled_x[new_data], unlabelled_y[new_data]
    
    new_labelled_x = np.append(labelled_x, new_data_x, axis=0)
    new_labelled_y = np.append(labelled_y, new_data_y, axis=0)
    
    new_unlabelled_x = np.delete(unlabelled_x, new_data, axis=0)
    new_unlabelled_y = np.delete(unlabelled_y, new_data, axis=0)
    
    return (new_labelled_x, new_labelled_y), (new_unlabelled_x, new_unlabelled_y)

In [292]:
batch_size=60
for i in range(2):
    train_labelled, train_unlabelled = create_data_sets(labelled, unlabelled, batch_size)
    print(labelled[0].shape[0], unlabelled[0].shape[0])
    train_semi_supervised(train_labelled, train_unlabelled, validation, cuda, epochs=4)

    new_data = query_new_data(NUM_QUERY, NUM_MC_SAMPLES, model, unlabelled[0])
    labelled, unlabelled = rearange_datasets(labelled, unlabelled, new_data)

310 59690
Epoch: 0
[Train]		 J_a: 302192.51, accuracy: 0.62
[Validation]	 J_a: 1422.20, accuracy: 0.60
Epoch: 1
[Train]		 J_a: 302083.98, accuracy: 0.62
[Validation]	 J_a: 1420.13, accuracy: 0.60
Epoch: 2
[Train]		 J_a: 301933.31, accuracy: 0.62
[Validation]	 J_a: 1419.01, accuracy: 0.60
Epoch: 3
[Train]		 J_a: 301906.13, accuracy: 0.62
[Validation]	 J_a: 1422.71, accuracy: 0.60
370 59630
Epoch: 0
[Train]		 J_a: 298963.00, accuracy: 0.65
[Validation]	 J_a: 1416.88, accuracy: 0.62
Epoch: 1
[Train]		 J_a: 298923.97, accuracy: 0.66
[Validation]	 J_a: 1416.59, accuracy: 0.62
Epoch: 2
[Train]		 J_a: 298921.29, accuracy: 0.66
[Validation]	 J_a: 1416.11, accuracy: 0.61
Epoch: 3
[Train]		 J_a: 298905.61, accuracy: 0.66
[Validation]	 J_a: 1415.63, accuracy: 0.62
