In [2]:
# @title Imports & Data
import itertools
import math
import time

import numpy as np
import torch
import torchvision

import keras
from keras.datasets import mnist

from tensorflow.keras.utils import to_categorical

import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.utils.data
from torch.distributions import Normal

# Seed every RNG the notebook draws from so that a fresh "Restart & Run All"
# is repeatable: torch drives weight initialisation and the sampling noise,
# NumPy's global RNG drives BatchWrapper.shuffle (np.random.shuffle).
torch.manual_seed(123)
np.random.seed(123)

print(torch.cuda.is_available())
# Use the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# NOTE(review): cudnn.benchmark lets cuDNN auto-select kernels, which can make
# GPU runs non-bit-identical; set it to False if exact reproducibility matters.
torch.backends.cudnn.benchmark = True


def load_mnist(digits = None, conv = False):
    """
    Load MNIST through keras and return (X_train, X_test, Y_train, Y_test).

    Images are flattened to 784 float32 values and divided by 255; labels are
    one-hot encoded over 10 classes.

    digits: optional iterable of class indices. When given, only samples of
            those classes are kept, grouped class by class in the order the
            digits are listed.
    conv:   when True the images are reshaped to (n, 28, 28, 1).
    """
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

    # Flatten each image and rescale the pixel values.
    X_train = train_images.reshape(60000, 784).astype('float32') / 255
    X_test = test_images.reshape(10000, 784).astype('float32') / 255

    # One-hot label matrices.
    Y_train = to_categorical(train_labels, 10)
    Y_test = to_categorical(test_labels, 10)

    if digits is not None:
        # Gather the row indices of every requested class.
        train_rows, test_rows = [], []
        for digit in digits:
            train_rows.extend(np.where(Y_train[:, digit] == 1)[0])
            test_rows.extend(np.where(Y_test[:, digit] == 1)[0])
        X_train, Y_train = X_train[train_rows], Y_train[train_rows]
        X_test, Y_test = X_test[test_rows], Y_test[test_rows]

    if conv:
        X_train = X_train.reshape(len(X_train), 28, 28, 1)
        X_test = X_test.reshape(len(X_test), 28, 28, 1)

    return X_train, X_test, Y_train, Y_test

def single_digit_loader(X, label, b_size=10):
    """
    Generator yielding (inputs, labels) batches of exactly `b_size` rows.

    X:      numpy array whose first axis indexes samples; trailing axes are
            flattened into one feature axis.
    label:  integer class id attached to every row of every batch.
    b_size: number of rows per batch.

    Full batches are yielded in order. If the sample count is not a multiple
    of `b_size`, one extra batch is yielded that holds the leftover rows padded
    with rows taken (cyclically) from the start of X. No extra batch is
    produced when the sample count divides evenly, and nothing is yielded for
    an empty X.
    """
    N = X.shape[0]
    # Flatten per-sample features; the result must be assigned, reshape is not in-place.
    X = X.reshape(N, int(np.prod(X.shape[1:])))
    n_full = N // b_size

    for i in range(n_full):
        rows = X[i * b_size:(i + 1) * b_size, :]
        yield (torch.from_numpy(rows), torch.from_numpy(np.full(b_size, label, dtype=int)))

    if N % b_size != 0:
        tail = list(range(n_full * b_size, N))
        # Wrap around with a modulo so the padding never indexes past the data,
        # even when more than N padding rows are needed.
        pad = [k % N for k in range(b_size - len(tail))]
        yield (torch.from_numpy(X[tail + pad, :]), torch.from_numpy(np.full(b_size, label, dtype=int)))



def create_mnist_single_digit_loaders(b_size=10, train_data=True):
    """
    Build one BatchWrapper per digit 0..9 and return them as a list.

    b_size:     batch size handed to every BatchWrapper.
    train_data: True wraps the training split, False wraps the test split.

    Reads the module-level flags `degenerate_dataset` (every training row is
    replaced by a copy of the first one) and `scale_down_090` (only the first
    90% of each digit's training rows are kept).
    """
    def _loader_for(digit):
        X_train, X_test, _, _ = load_mnist(digits = [digit])
        if degenerate_dataset:
            first_row = X_train[0, :].reshape(1, -1)
            X_train = first_row.repeat(X_train.shape[0], axis=0)
        if not train_data:
            return BatchWrapper(X_test, digit, b_size)
        if scale_down_090:
            N_train = int(X_train.shape[0] * 0.9)
        else:
            N_train = X_train.shape[0]
        return BatchWrapper(X_train[:N_train], digit, b_size)

    return [_loader_for(digit) for digit in range(10)]

True
Using device: cuda:0


In [3]:
# @title Evaluation Function (Test LL)

def IS_estimate(x, task_model, K):
    """
    Importance-sampling estimate of the log-likelihood of a batch.

    x:          batch of inputs, viewed as rows of 28*28 values.
    task_model: object exposing `enc`, `sampler`, `dec_head` and `dec_shared`.
    K:          number of latent samples drawn per input.

    Returns (mean, variance) of the per-input estimates as Python floats.
    """
    flat_x = x.view(-1, 28 ** 2)
    # The whole batch is tiled K times, so row k * N + n is the k-th copy of input n.
    tiled_x = flat_x.repeat([K, 1]).to(device=device)
    assert(tiled_x.size()[0] < 6000)

    n_inputs = flat_x.size()[0]
    enc_out = task_model.enc(tiled_x)
    mu_qz, log_sig_qz = Zs_to_mu_sig(enc_out)
    z = task_model.sampler(enc_out)
    recon_mu = task_model.dec_shared(task_model.dec_head(z))
    log_lik = log_bernoulli(tiled_x, recon_mu)

    # log q(z|x) - log p(z) for the drawn samples.
    log_prior = log_gaussian_prob(z)
    log_posterior = log_gaussian_prob(z, mu_qz, log_sig_qz)
    log_ratio = log_posterior - log_prior

    # Log-mean-exp over the K samples, shifted by the per-input maximum for stability.
    log_w = torch.reshape(log_lik - log_ratio, (K, n_inputs))
    log_w_max = log_w.max(dim=0).values
    shifted = log_w - log_w_max
    log_norm = torch.log(torch.clamp(torch.exp(shifted).mean(0), 1e-9, np.inf))

    test_ll = log_norm + log_w_max
    test_ll_mean = test_ll.mean().item()
    test_ll_var = ((test_ll - test_ll_mean) ** 2).mean().item()

    return test_ll_mean, test_ll_var


class Evaluation:
    """
    Callable that scores a task model on a loader with IS_estimate.

    should_print: when True a one-line summary is printed per call.
    K:            number of importance samples forwarded to IS_estimate.
    """

    def __init__(self, should_print=True, K=100):
        self.K = K
        self.should_print = should_print

    def __call__(self, task_id, task_model, loader):
        """
        Average the IS_estimate mean and variance over all batches of `loader`.

        `loader` must support len() and integer indexing returning
        (inputs, labels). Returns (mean test log-likelihood,
        sqrt(mean variance / number of inputs seen)).
        """
        n_seen = 0
        ll_total = 0.0
        var_total = 0.0
        started = time.time()
        n_batches = len(loader)
        for idx in range(n_batches):
            inputs, _ = loader[idx]
            n_seen += len(inputs)
            ll_mean, ll_var = IS_estimate(inputs, task_model, self.K)
            ll_total += ll_mean / n_batches
            var_total += ll_var / n_batches
        elapsed = time.time() - started
        ste = np.sqrt(var_total / n_seen)
        if self.should_print:
            print("task %d test_ll=%.2f, ste=%.2f, time=%.2f" \
                  % (task_id, ll_total, ste, elapsed))
        return (ll_total, ste)


In [4]:
# @title Sampling Function

# Same expression as the `device` bound in the imports cell; rebinding it here
# produces an equal value and lets the helpers below resolve `device` on their own.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def KL_div_gaussian(mu_p, log_sig_p, mu_q, log_sig_q):
    """
    Element-wise KL[p || q] between diagonal Gaussians given by their means and
    log standard deviations, summed over every axis except the first.
    """
    inv_var_q = torch.exp(-2 * log_sig_q)
    kl = 0.5 * (mu_p - mu_q) ** 2 * inv_var_q - 0.5
    # Log-ratio of the standard deviations.
    kl.add_(log_sig_q - log_sig_p)
    # Ratio of the variances.
    kl.add_(0.5 * torch.exp(2 * log_sig_p - 2 * log_sig_q))
    non_batch_axes = list(range(1, kl.dim()))
    return torch.sum(kl, dim=non_batch_axes)

def KL_div_gaussian_from_standard_normal(mu_q, log_sig_q):
    """
    KL between the Gaussian (mu_q, log_sig_q) and a standard normal.

    A zero mean with a zero log-sigma encodes N(0, 1); the single-element
    tensors broadcast against inputs of any dimensionality.
    """
    standard_mu = torch.zeros(1, device=device)
    standard_log_sig = torch.zeros(1, device=device)
    return KL_div_gaussian(mu_q, log_sig_q, standard_mu, standard_log_sig)

def Zs_to_mu_sig(Zs_params):
    """
    Split encoder output of shape (batch, 2 * dimZ) along axis 1.

    Returns (mu, log_sigma): the first half of the columns and the rest.
    """
    half = Zs_params.shape[1] // 2
    return Zs_params[:, :half], Zs_params[:, half:]

# Probabilities are clamped into this (low, high) interval before taking a log.
forced_interval = (1e-9, 1.0)

def log_bernoulli(X, Mu_Reconstructed_X):
    """
    Bernoulli log-likelihood of X under the decoder output Mu_Reconstructed_X.

    Fractional values are accepted. Both the probability and its complement are
    projected into 'forced_interval' for numerical stability. The result is
    summed over every axis except the first.
    """
    low, high = forced_interval
    log_on = torch.log(torch.clamp(Mu_Reconstructed_X, low, high))
    log_off = torch.log(torch.clamp((1.0 - Mu_Reconstructed_X), low, high))
    per_element = X * log_on + (1 - X) * log_off

    n_rows = per_element.size()[0]
    return per_element.view(n_rows, -1).sum(dim=1)

def log_gaussian_prob(x, mu=None, log_sig=None):
    """
    Log-density of x under a diagonal Gaussian, summed over all but the first axis.

    x:       tensor whose first axis is the batch axis.
    mu:      mean; defaults to zero.
    log_sig: log standard deviation; defaults to zero (unit variance).

    The defaults are created on x's device at call time. Building them in the
    signature would bind them to whatever the global `device` was when this
    cell ran, and fail for inputs living on another device.
    """
    if mu is None:
        mu = torch.zeros(1, device=x.device)
    if log_sig is None:
        log_sig = torch.zeros(1, device=x.device)
    logprob = -(0.5 * math.log(2 * math.pi) + log_sig) \
              - 0.5 * ((x - mu) / torch.exp(log_sig)) ** 2
    return torch.sum(logprob.view(logprob.size()[0], -1), dim=1)  # sum all but first dim

def log_P_y_GIVEN_x(Xs, enc, sample_and_decode, NumLogPSamples=100):
    """
    Returns logP(Y|X), KL(Z||Normal(0,1))

    The reconstruction term is the average of `NumLogPSamples` Bernoulli
    log-likelihoods, each computed from a fresh call to `sample_and_decode`.
    """
    # enc(Xs) is deterministic, so it is evaluated once and reused for every draw.
    enc_out = enc(Xs)
    kl_z = KL_div_gaussian_from_standard_normal(*Zs_to_mu_sig(enc_out))
    per_draw = (log_bernoulli(Xs, sample_and_decode(enc_out)) / NumLogPSamples
                for _ in range(NumLogPSamples))
    logp = sum(per_draw, 0.0)
    return logp, kl_z


In [5]:
# @title Encoder
class mlp_layer(nn.Module):
    """Deterministic fully connected layer followed by a caller-supplied activation."""

    def __init__(self, d_in, d_out, activation):
        """
        Activation is a function (eg. torch.nn.functional.sigmoid/relu)
        """
        super().__init__()
        self.mu = nn.Linear(d_in, d_out).to(device=device)
        with torch.no_grad():
            self._init_weights(d_in, d_out)
        self.activation = activation

    def forward(self, x):
        """Apply the linear map then the activation; dumps the parameters when `weight_print` is set."""
        if weight_print:
            print("Weights of ENC ", self.mu.weight)
            print("bias of ENC ", self.mu.bias)
        pre_activation = self.mu(x)
        return self.activation(pre_activation)

    def _init_weights(self, input_size, output_size, constant=1.0):
        """Uniform weights in [-s, s] with s = constant * sqrt(6 / (fan_in + fan_out)); zero bias."""
        half_width = constant * np.sqrt(6.0 / (input_size + output_size))
        assert (output_size > 0)
        nn.init.uniform_(self.mu.weight, -half_width, half_width)
        nn.init.zeros_(self.mu.bias)

    @property
    def d_out(self):
        """Number of output features of the linear map."""
        return self.mu.out_features

    @property
    def d_in(self):
        """Number of input features of the linear map."""
        return self.mu.in_features

In [6]:
# @title Bayesian Decoder
class bayesian_mlp_layer(mlp_layer):
    """
    Fully connected layer with a factorised Gaussian over its weights and bias.

    `mu` (inherited) holds the means and `log_sigma` the log standard
    deviations. While `sampling` is True every forward pass draws fresh
    parameters; otherwise the means are used directly.
    """

    def __init__(self, d_in, d_out, activation):
        """
        Activation is a function (eg. torch.nn.functional.sigmoid/relu)
        """
        # mu is initialized the same as non-Bayesian mlp
        super().__init__(d_in, d_out, activation)
        self.log_sigma = nn.Linear(d_in, d_out).to(device=device)
        with torch.no_grad():
            self._init_log_sigma()

        weight_shape = self.mu.weight.shape
        bias_shape = self.mu.bias.shape
        self.w_standard_normal_sampler = Normal(torch.zeros(weight_shape, device=device), torch.ones(weight_shape, device=device))
        self.b_standard_normal_sampler = Normal(torch.zeros(bias_shape, device=device), torch.ones(bias_shape, device=device))

        self.sampling = True

    def forward(self, x):
        """Forward pass with sampled parameters (sampling on) or with the means (sampling off)."""
        if weight_print:
            print("Weights of mu DEC ", self.mu.weight)
            print("bias of mu DEC ", self.mu.bias)
            print("Weights of log_sig DEC ", self.log_sigma.weight)
            print("bias of log_sig DEC ", self.log_sigma.bias)

        if not self.sampling:
            return super().forward(x)

        # Reparameterised draw: mean + noise * std, weights first and bias second.
        weight_noise = torch.randn_like(self.mu.weight)
        sampled_W = self.mu.weight + weight_noise * torch.exp(self.log_sigma.weight)
        bias_noise = torch.randn_like(self.mu.bias)
        sampled_b = self.mu.bias + bias_noise * torch.exp(self.log_sigma.bias)
        pre_activation = torch.einsum('ij,bj->bi', sampled_W, x) + sampled_b
        return self.activation(pre_activation)

    def _init_log_sigma(self):
        """Start every log standard deviation at -6."""
        for param in (self.log_sigma.weight, self.log_sigma.bias):
            nn.init.constant_(param, -6.0)

    def get_posterior(self):
        """Return [(weight mean, weight log-sigma), (bias mean, bias log-sigma)]."""
        weight_pair = (self.mu.weight, self.log_sigma.weight)
        bias_pair = (self.mu.bias, self.log_sigma.bias)
        return [weight_pair, bias_pair]

class NormalSamplingLayer(nn.Module):
    """
    Draws one Gaussian sample per row from a vector laid out as
    [mean (d_out columns) | log standard deviation (remaining columns)].
    """

    def __init__(self, d_out):
        super().__init__()
        self.d_out = d_out

    def forward(self, mu_log_sigma_vec):
        """Return mean + eps * exp(log_sigma) with eps drawn from a standard normal."""
        mean = mu_log_sigma_vec[:, :self.d_out]
        log_std = mu_log_sigma_vec[:, self.d_out:]
        eps = torch.randn_like(mean)
        return mean + eps * torch.exp(log_std)

class SharedDecoder(nn.Module):
    """
    Stack of bayesian_mlp_layer modules shared by every task.

    Besides the layers it keeps `prior`: one (mean, log-sigma) pair of constant
    tensors per posterior pair, used by KL_from_prior. `prior` is a plain list
    attribute, not a registered buffer.
    """

    def __init__(self, dims, activations):
        """
        dims:        layer widths; layer i maps dims[i] -> dims[i + 1].
        activations: one activation function per layer.
        """
        super().__init__()
        self.net = nn.Sequential(*[bayesian_mlp_layer(dims[i], dims[i + 1], activations[i])
                                   for i in range(len(activations))])
        self._init_prior()

    def forward(self, Xs):
        return self.net(Xs)

    def _get_posterior(self):
        """Flat list of every layer's (mean, log-sigma) parameter pairs, in layer order."""
        return list(itertools.chain.from_iterable(
            layer.get_posterior() for layer in self.net.children()))

    def _init_prior(self):
        """
        Initialize a constant tensor that corresponds to a prior distribution over all the weights
        which is standard normal
        """
        # zeros_like yields plain tensors (no gradient tracking) on each parameter's own device.
        self.prior = [(torch.zeros_like(mu), torch.zeros_like(log_sig))
                      for mu, log_sig in self._get_posterior()]

    def update_prior(self):
        """
        Copy the current posterior to a constant tensor, which will be used as prior for the next task
        """
        posterior = self._get_posterior()
        self.prior = [(mu.clone().detach(), log_sig.clone().detach()) for mu, log_sig in posterior]
        # Reset the posterior's log-sigma to -6 for the next task.
        with torch.no_grad():
            for _, log_sig in posterior:
                log_sig.fill_(-6.0)

    def KL_from_prior(self):
        """Sum of KL[posterior || prior] over every parameter tensor, as a 0-dim tensor."""
        KL = torch.zeros((), device=device)
        for (post_mu, post_log_sig), (prior_mu, prior_log_sig) in zip(self._get_posterior(), self.prior):
            # KL_div_gaussian keeps the first axis, so add a dummy batch axis and drop it again.
            term = KL_div_gaussian(post_mu.unsqueeze(0), post_log_sig.unsqueeze(0),
                                   prior_mu.unsqueeze(0), prior_log_sig.unsqueeze(0))
            KL = KL + term.squeeze()
        return KL

    @property
    def d_in(self):
        """Input width of the first layer (nn.Sequential itself has no d_in attribute)."""
        return self.net[0].d_in

    @property
    def d_out(self):
        """Output width of the last layer (nn.Sequential itself has no d_out attribute)."""
        return self.net[-1].d_out

In [None]:
# @title Experiment 1 (Mean-Field Bayesian VAE)


# Dataset tag; not referenced by the code visible in this notebook.
DATASET = 'mnist'

# Train=True trains a fresh TaskModel per task in main(); False loads saved ones.
Train = True
# Number of tasks: sizes the results array in main() and caps the task loop.
max_task = 10

# Debug switches.
weight_print = False  # layers print their parameters on every forward pass
data_print = False    # PrintLayer prints the tensors passing through it
loss_print = False    # only read by commented-out code; train_model sets it to True

# Data switches read by create_mnist_single_digit_loaders.
scale_down_090 = True       # keep only the first 90% of each digit's training rows
degenerate_dataset = False  # replace every training row by a copy of the first one

dimX = 28 * 28   # flattened image size
dimH = 500       # hidden layer width
dimZ = 50        # latent size
batch_size = 50
n_epochs = 400

# Encoder
# Each dimZ * 2 output is split into a mean half and a log-sigma half by Zs_to_mu_sig.
# `lambda x: x` is an identity activation; F.relu_ is the in-place ReLU.
enc_dims = [dimX, dimH, dimZ * 2, dimH, dimZ * 2]
enc_activations = [F.relu_, lambda x: x, F.relu_, lambda x: x]

# Sample here

# Private decoder (Head)
dec_head_dims = [dimZ, dimH, dimZ]
dec_head_activations = [F.relu_, lambda x: x]

# Sample here

# Shared decoder
# The final sigmoid keeps the outputs in (0, 1), as log_bernoulli expects probabilities.
dec_shared_dims = [dimZ, dimH, dimX]
dec_shared_activations = [F.relu_, torch.sigmoid]

# BayesianVAE
class TaskModel(nn.Module):
    """
    Per-task VAE: a private encoder and private decoder head in front of a
    decoder that is shared between all tasks.

    The order in which submodules are assigned in __init__ determines the order
    of parameters() and of the state_dict keys, so it must not be changed
    without invalidating saved checkpoints.
    """

    def __init__(self, enc_dims_activations, dec_head_dims_activations, dec_shared, learning_rate=1e-4):
        """
        enc_dims_activations:      (dims, activations) pair for the encoder.
        dec_head_dims_activations: (dims, activations) pair for the private decoder head.
        dec_shared:                decoder instance shared across TaskModels.
        learning_rate:             learning rate of the Adam optimizer.
        """
        super().__init__()
        # Define Encoder
        my_enc_dims, my_enc_activations = enc_dims_activations

        self.enc = nn.Sequential(*[mlp_layer(my_enc_dims[i], my_enc_dims[i + 1], my_enc_activations[i])
                                   for i in range(len(my_enc_activations))])

        # Define private decoder (head)
        my_dec_head_dims, my_dec_head_activations = dec_head_dims_activations

        self.dec_head = nn.Sequential(
            *[bayesian_mlp_layer(my_dec_head_dims[i], my_dec_head_dims[i + 1], my_dec_head_activations[i])
              for i in range(len(my_dec_head_activations))])

        # Define shared decoder
        self.dec_shared = dec_shared
        # NOTE(review): `printer` is not used by any method visible in this class.
        self.printer = PrintLayer()

        # Combine components
        # Both pipelines reuse the same dec_head / dec_shared objects, so their
        # parameters appear under several state_dict prefixes.
        self.sampler = NormalSamplingLayer(my_dec_head_dims[0])
        self.sample_and_decode = nn.Sequential(*[self.sampler, self.dec_head, self.dec_shared])
        self.decode = nn.Sequential(*[self.dec_head, self.dec_shared])

        # update just before training
        self.DatasetSize = None

        # Guards from retraining- train only once
        self.TrainGuard = True

        # Built over self.parameters(), i.e. every submodule registered above,
        # including the shared decoder.
        self.optimizer = self._create_optimizer(learning_rate)

    def set_sampling(self, sampling):
        """Set the `sampling` flag on every layer of the decoder head and of the shared decoder."""
        for model in self.dec_head.children():
            model.sampling = sampling
        for model in self.dec_shared.net.children():
            model.sampling = sampling

    def save_model(self, path):
        """Write this model's state_dict to `path`."""
        torch.save(self.state_dict(), path)

    @classmethod
    def load_model(cls, path, uninitialized_instance):
        """
        Load the state_dict stored at `path` into `uninitialized_instance`,
        switch it to eval mode and return it.

        NOTE(review): torch.load may unpickle arbitrary objects depending on the
        torch version/arguments; only load checkpoints from trusted sources.
        """
        uninitialized_instance.load_state_dict(torch.load(path,map_location=device))
        uninitialized_instance.eval()
        return uninitialized_instance

    def forward(self, Xs):
        """
        Return the negative ELBO for the batch `Xs`.

        ELBO = mean reconstruction log-likelihood - mean KL of the latent
        - KL of the shared decoder from its prior divided by DatasetSize.
        DatasetSize must be set (train_model does this) before calling.
        """
        logp, kl_z = log_P_y_GIVEN_x(Xs, self.enc, self.sample_and_decode)
        kl_shared_dec_Qt_2_PREV_Qt = self.dec_shared.KL_from_prior()

        # We ignore the kl(private dec || Normal(0,1) ) like the authors did
        logp_mean = torch.mean(logp)
        kl_z_mean = torch.mean(kl_z)
        # The weight KL is counted once per dataset, so spread it over the samples.
        kl_Qt_normalized = (kl_shared_dec_Qt_2_PREV_Qt / self.DatasetSize)
        ELBO = logp_mean - kl_z_mean - kl_Qt_normalized
        #if loss_print:
        #    print("Log_like", "\tKL Z","\tKL Qt vs prev Qt")
        #    print('%.2f\t%.2f\t%.2f' % (logp_mean, kl_z_mean, kl_Qt_normalized))

        return -ELBO

    def _create_optimizer(self, learning_rate):
        """Adam over all parameters registered on this module."""
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

    def _update_prior(self):
        """Promote the shared decoder's current posterior to its prior."""
        self.dec_shared.update_prior()
        # no other priors should be updated. they are trained once.
        return

    def train_model(self, n_epochs, task_trainloader, DatasetSize):
        """
        Train on one task; may be called only once per instance (asserted).

        n_epochs:         number of passes over the loader.
        task_trainloader: object with shuffle(), len() and integer indexing
                          returning (inputs, labels).
        DatasetSize:      divisor for the shared-decoder KL term in forward().

        After the last epoch the shared decoder's prior is updated and
        DatasetSize is reset to None.
        """
        # We don't intend a TaskModel to be trained more than once
        assert (self.TrainGuard)
        self.TrainGuard = False

        self.DatasetSize = DatasetSize

        # loop over the dataset multiple times
        for epoch in range(n_epochs):
            print("starting epoch " + str(epoch))
            task_trainloader.shuffle()
            running_loss = 0.0
            for i in range(len(task_trainloader)):
                # NOTE(review): this overwrites the module-level `loss_print` flag on
                # every step; the flag is only read by commented-out code in forward().
                global loss_print
                loss_print = True #(i % 20 == 19)
                # get the inputs
                inputs, labels = task_trainloader[i]
                #Migrate to device (gpu if possible)
                inputs = inputs.to(device=device)
                # step
                self.optimizer.zero_grad()
                # loss = self(inputs.view(-1, self.enc.d_in))
                # NOTE(review): the input width is hard-coded to 28 ** 2 here.
                loss = self(inputs.view(-1, 28 ** 2))
                loss.backward()
                self.optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % 20 == 19:  # print the running average every 20 mini-batches
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss/20))
                    running_loss = 0.0
        # Training is finished: store the learned posterior as the prior, so the
        # next task's KL term is measured against it.
        self._update_prior()
        self.DatasetSize = None


class PrintLayer(nn.Module):
    """
    Identity module for debugging: when the global `data_print` flag is set it
    prints a call counter plus the shape and value of the first element.
    """

    def __init__(self):
        super().__init__()
        self.count = 0

    def forward(self, x):
        if not data_print:
            return x
        self.count += 1
        print("PRINTER LAYER")
        print(self.count)
        first = x[0]
        print(first.shape)
        print(first)
        return x


# In[135]:

class BatchWrapper:
    """
    Indexable source of fixed-size (inputs, labels) batches for one class.

    X:      numpy array whose first axis indexes samples; the remaining axes are
            flattened. A float32 copy is stored, so the caller's array is never
            shuffled.
    label:  integer class id attached to every row of every batch.
    b_size: number of rows in every batch.

    Every batch has exactly `b_size` rows. If the sample count is not a multiple
    of `b_size`, the final batch holds the leftover rows padded with rows taken
    (cyclically) from the start of the data.
    """

    def __init__(self, X, label, b_size):
        self.N = X.shape[0]
        # Flatten the per-sample axes without relying on a global feature count.
        n_features = int(np.prod(X.shape[1:]))
        self.X = X.reshape(self.N, n_features).astype('float32')
        self.label = label
        self.b_size = b_size

    def shuffle(self):
        """Shuffle the stored rows in place using NumPy's global RNG."""
        np.random.shuffle(self.X)

    def flat_size(self):
        """Number of stored samples."""
        return self.X.shape[0]

    def __len__(self):
        """Number of batches: ceil(N / b_size)."""
        return -(-self.N // self.b_size)

    def __getitem__(self, i):
        """Return batch `i` as (float32 inputs tensor, integer labels tensor)."""
        assert(i >= 0 and i < len(self))
        start = i * self.b_size
        stop = start + self.b_size
        # Every batch, padded or not, carries the wrapper's own label.
        labels = torch.from_numpy(np.full(self.b_size, self.label, dtype=int))
        if stop <= self.N:
            # Complete batch, including the last one when N divides evenly.
            return (torch.from_numpy(self.X[start:stop, :]), labels)
        tail = list(range(start, self.N))
        # The modulo keeps the padding in range even when it exceeds N rows.
        pad = [k % self.N for k in range(self.b_size - len(tail))]
        return (torch.from_numpy(self.X[tail + pad, :]), labels)

def path(after, i):
    """Checkpoint file for task `i`'s parameters as saved after training task `after`."""
    return f'./BestParamsBackup/after_task_{after!s}_params_for_task_{i!s}.pt'

def load_models(after):
    """
    Rebuild the TaskModels for tasks 0..after (inclusive) from the checkpoints
    written after task `after`. All of them share one new SharedDecoder.
    """
    dec_shared = SharedDecoder(dec_shared_dims, dec_shared_activations)

    def _restore(task_id):
        blank_model = TaskModel((enc_dims, enc_activations), (dec_head_dims, dec_head_activations), dec_shared)
        return TaskModel.load_model(path(after, task_id), blank_model)

    return [_restore(task_id) for task_id in range(after + 1)]

def generate_pictures(task_models, n_pics=100):
    """
    Decode `n_pics` images per task model and hand them to plot_images.

    The decoder input is all zeros, i.e. mean 0 and log-sigma 0, so the sampling
    layer draws the latents with unit standard deviation. Weight sampling in the
    Bayesian layers is switched off while decoding and back on afterwards.
    """
    n_models = len(task_models)
    with torch.no_grad():
        for task_id, task_model in enumerate(task_models):
            task_model.set_sampling(False)
            latent_params = torch.zeros(n_pics, dimZ * 2, device=device)
            pics = task_model.sample_and_decode(latent_params)
            task_model.set_sampling(True)
            file_tag = 'after_task_' + str(n_models) + '_task_' + str(task_id)
            plot_images(pics.cpu(), (28, 28), './figs/', file_tag)

def generate_all_picture_row(task_models, n_pics=10):
    """
    Decode one image per task model into an (n_pics, dimX) array.

    Row `k` holds the image of the k-th model; rows without a model stay zero.
    Weight sampling in the Bayesian layers is switched off while decoding and
    back on afterwards.
    """
    images = np.zeros([n_pics, dimX])
    with torch.no_grad():
        for slot, model in enumerate(task_models):
            model.set_sampling(False)
            decoded = model.sample_and_decode(torch.zeros(1, dimZ * 2, device=device))
            model.set_sampling(True)
            images[slot] = decoded.cpu()
    return images

def main():
    """
    Run the continual-learning experiment, one task per digit.

    For each of the first `max_task` tasks a TaskModel is trained (or, when
    `Train` is False, all models up to that task are loaded from disk). Every
    model seen so far is then evaluated on its own test loader.
    results[e, t, s] holds evaluator e's score for task s measured after task t;
    the array is printed after every task.
    """
    result_path = "./results/MNIST_Torch_cla_results"  # used only by the disabled np.save below
    dec_shared = SharedDecoder(dec_shared_dims, dec_shared_activations)

    task_loaders = zip(create_mnist_single_digit_loaders(batch_size),
                       create_mnist_single_digit_loaders(batch_size, train_data=False))

    test_loaders = []
    models = []

    evaluators = [Evaluation()]

    results = np.zeros((len(evaluators), max_task, max_task))

    # A task corresponds to a digit
    for task_id, (train_loader, test_loader) in enumerate(task_loaders):
        # Valid task ids are 0 .. max_task - 1; `results` has no row for max_task itself.
        if task_id >= max_task:
            break
        print("starting task " + str(task_id))
        if (Train):
            task_model = TaskModel((enc_dims, enc_activations), (dec_head_dims, dec_head_activations), dec_shared)
            models.append(task_model)
            task_model.train_model(n_epochs, train_loader, train_loader.flat_size())
        else:
            models = load_models(task_id)
            task_model = models[-1]
        test_loaders.append(test_loader)
        # Disable gradient calculation during evaluation
        with torch.no_grad():
            # Disabled: checkpointing of every model after each task.
            #if (Train):
            #    for i, model in enumerate(models):
            #        model.save_model(path(task_id, i))

            # Disabled: picture generation.
            #generate_pictures(models)
            #tmp = generate_all_picture_row(models)
            #if task_id == 0:
            #    all_pic = tmp
            #else:
            #    all_pic = np.concatenate([all_pic, tmp], 0)

            for test_task_id, loader in enumerate(test_loaders):
                for eval_idx, evaluator in enumerate(evaluators):
                    results[eval_idx, task_id, test_task_id], _ = evaluator(test_task_id, models[test_task_id], loader)

        print(results)

    #np.save(result_path, results)
    #with torch.no_grad():
    #    plot_images(all_pic, (28, 28), './figs/', 'all_models_after_each_tasks')


# In[137]:

# Run the whole experiment; with Train = True this trains one model per task.
main()

# In[ ]: