In [1]:
import numpy as np
import torch
import torch.autograd
import torch.nn as nn
import torch.nn.init
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from copy import deepcopy
import math
import os
from datetime import datetime

import torch.optim as optim
from torch.utils.data import ConcatDataset
from torch.utils.tensorboard import SummaryWriter
import torch.utils.data as torchdata
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST,CIFAR10
from torchvision.transforms import Compose, Lambda,ToTensor

import matplotlib.pyplot as plt

from tqdm import tqdm
from copy import deepcopy
from random import shuffle

import json

from PIL import Image

# Utils


In [2]:
def class_accuracy(pred: torch.Tensor, true: torch.Tensor) -> float:
    pred = pred.to('cuda')
    true = true.to('cuda')
    return 100 * (pred.int() == true.int()).sum().item() / len(true)


def kl_divergence(z_posterior_means, z_posterior_log_std, z_prior_mean=0.0, z_prior_log_std=0.0):
    """Row-wise KL(q || p) between diagonal Gaussians parameterised by log std.

    q has means ``z_posterior_means`` and log standard deviations
    ``z_posterior_log_std``; p has the scalar mean ``z_prior_mean`` and scalar
    log std ``z_prior_log_std`` broadcast to the same shape. Per element:

        log(s_p / s_q) + (s_q^2 + (m_q - m_p)^2) / (2 s_p^2) - 1/2

    Returns:
        The element-wise KL summed over dimension 1.
    """
    prior_mu = torch.full_like(z_posterior_means, z_prior_mean)
    prior_log_sigma = torch.full_like(z_posterior_log_std, z_prior_log_std)

    inv_prior_var = torch.exp(-2 * prior_log_sigma)
    squared_gap = (z_posterior_means - prior_mu) ** 2
    var_ratio = torch.exp(2 * z_posterior_log_std - 2 * prior_log_sigma)

    terms = (0.5 * squared_gap * inv_prior_var - 0.5
             + (prior_log_sigma - z_posterior_log_std)
             + 0.5 * var_ratio)
    return terms.sum(dim=1)


def bernoulli_log_likelihood(x_observed, x_reconstructed, epsilon=1e-8) -> torch.Tensor:
    """Bernoulli log-likelihood of ``x_observed`` under probabilities ``x_reconstructed``.

    Computes ``x*log(p + eps) + (1 - x)*log(1 - p + eps)`` element-wise and sums
    over dimension 1. NaNs in the second term (e.g. from the log of a negative
    number when p > 1 + eps) are replaced by ``epsilon``.
    """
    log_p = torch.log(x_reconstructed + epsilon) * x_observed
    log_not_p = torch.log(1 - x_reconstructed + epsilon) * (1 - x_observed)
    log_not_p[torch.isnan(log_not_p)] = epsilon
    return (log_p + log_not_p).sum(dim=1)


def normal_with_reparameterization(means: torch.Tensor, log_stds: torch.Tensor, device='cuda') -> torch.Tensor:
    """Draw one reparameterised sample ``means + exp(log_stds) * N(0, 1)``.

    The noise comes from ``randn_like(means)`` and therefore matches the device
    and dtype of ``means``; the ``device`` argument is accepted but unused.
    """
    stds = log_stds.exp()
    noise = torch.randn_like(means)
    return means + stds * noise


def concatenate_flattened(tensor_list) -> torch.Tensor:
    """Flatten every tensor in ``tensor_list`` and join them into one 1-D tensor."""
    flat_parts = [part.reshape(-1) for part in tensor_list]
    return torch.cat(flat_parts)


def task_subset(data: torchdata.Dataset, task_ids: torch.Tensor, task: int) -> torchdata.Subset:
    """Return the view of ``data`` whose items are tagged with ``task``.

    Args:
        data: dataset indexed in parallel with ``task_ids``.
        task_ids: 1-D tensor holding the task id of every item of ``data``
            (integer or float ids both compare correctly).
        task: task id to select.

    Returns:
        A ``torch.utils.data.Subset`` over the positions ``i`` where
        ``task_ids[i] == task``, in increasing order.
    """
    positions = torch.nonzero(task_ids == task, as_tuple=True)[0]
    return torchdata.Subset(data, positions)

class Flatten:
    """Transform turning a sample (e.g. a PIL image) into a new 1-D float32 numpy array."""

    def __init__(self):
        pass

    def __call__(self, sample):
        # np.array copies, so the result never aliases the caller's sample.
        as_array = np.array(sample, dtype=np.float32)
        return as_array.ravel()


class Scale:
    """Transform dividing every sample by ``max_value`` (255 maps 8-bit pixels into [0, 1])."""

    def __init__(self, max_value=255):
        self.max_value = max_value

    def __call__(self, sample):
        divisor = self.max_value
        return sample / divisor


class Permute:
    """Transform reordering a sample's entries by a fixed index permutation."""

    def __init__(self, permutation):
        self.permutation = permutation

    def __call__(self, sample):
        order = self.permutation
        return sample[order]


# Output locations relative to the working directory. Filenames are appended
# by plain string concatenation, so each path keeps its trailing slash.
OUT_DIR, MODEL_DIR, IMAGE_DIR = 'out/experiments/', 'out/models/', 'out/images/'


def write_as_json(filename, data):
    """Serialise ``data`` as JSON into ``OUT_DIR + filename``.

    The parent directory of the target file is created first (printing
    'creating ...') when it does not exist yet; an existing file is overwritten.
    """
    target = OUT_DIR + filename
    parent = os.path.dirname(target)
    if not os.path.exists(parent):
        print('creating ...')
        os.makedirs(parent)

    with open(target, "w") as handle:
        json.dump(data, handle)


def save_model(model, filename):
    """Pickle ``model`` with ``torch.save`` into ``MODEL_DIR + filename``.

    The parent directory of the target file (not only ``MODEL_DIR``) is created
    when missing. Filenames with a sub-directory such as ``'run1/model.pth'``
    therefore work, matching how ``write_as_json`` handles its paths.
    """
    path = MODEL_DIR + filename
    parent = os.path.dirname(path)
    if parent and not os.path.exists(parent):
        print('creating ...')
        os.makedirs(parent, exist_ok=True)

    torch.save(model, path)


def load_model(filename):
    """Load an object previously written by ``save_model`` from ``MODEL_DIR + filename``.

    Raises:
        FileNotFoundError: if the file does not exist; the message names the path.

    Security: this unpickles arbitrary Python objects. Only load files you
    produced yourself.
    """
    import inspect

    path = MODEL_DIR + filename
    if not os.path.isfile(path):
        raise FileNotFoundError(path)

    # save_model pickles the whole nn.Module. Newer PyTorch releases default
    # torch.load to weights_only=True, which rejects such pickles, so opt out
    # explicitly whenever the installed version knows the keyword.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        return torch.load(path, weights_only=False)
    return torch.load(path)


def save_generated_image(data: np.ndarray, filename: str):
    """Save ``data`` as an RGB image at ``IMAGE_DIR + filename``, plus a ``.npy`` copy.

    Args:
        data: image array, presumably scaled to [0, 1] (it is multiplied by
            255). Assumed 2-D (grayscale) or H x W x C; TODO confirm with callers.
        filename: output name; its extension selects the image format. The
            255-scaled array is additionally stored at
            ``IMAGE_DIR + filename + '.npy'``.
    """
    if not os.path.exists(os.path.dirname(IMAGE_DIR)):
        print('creating ...')
        os.makedirs(os.path.dirname(IMAGE_DIR))

    data = data * 255
    # PIL cannot build images from float64 H x W x C arrays, and float 2-D
    # arrays become 32-bit float 'F' images. Convert to clipped uint8 pixels first.
    pixels = np.clip(np.rint(data), 0, 255).astype(np.uint8)
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        # A single trailing channel is not a shape PIL accepts; drop it.
        pixels = pixels[..., 0]
    image = Image.fromarray(pixels)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image.save(IMAGE_DIR + filename)
    np.save(IMAGE_DIR + filename + str('.npy'), data)

# Experiment Utils

In [3]:

def run_point_estimate_initialisation(model, data, epochs, task_ids, batch_size,
                                      device, lr, task_idx=0, y_transform=None,
                                      multiheaded=True):
    """Pre-train ``model`` on one task with ``point_estimate_loss`` before VCL starts.

    Args:
        model: network exposing ``point_estimate_loss(x, y, head_idx=...)``.
        data: dataset of ``(x, y)`` pairs, indexed in parallel with ``task_ids``.
        epochs: number of passes over the task's data.
        task_ids: per-item task ids for ``data``.
        batch_size: mini-batch size (batches are not shuffled).
        device: device every batch is moved to.
        lr: Adam learning rate.
        task_idx: task whose items are used (default 0).
        y_transform: optional ``f(y, task_idx)`` applied to the labels.
        multiheaded: when True train head ``task_idx``; otherwise head 0.
    """
    print("Obtaining point estimate for posterior initialisation")

    head = task_idx if multiheaded else 0
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loader = DataLoader(task_subset(data, task_ids, task_idx), batch_size)

    for _ in tqdm(range(epochs), 'Epochs: '):
        for x, y_true in loader:
            optimizer.zero_grad()
            x, y_true = x.to(device), y_true.to(device)
            if y_transform is not None:
                y_true = y_transform(y_true, task_idx)

            model.point_estimate_loss(x, y_true, head_idx=head).backward()
            optimizer.step()


def run_task(model, task_idx,
             train_data, train_task_ids, test_data, test_task_ids,
             coreset, epochs, batch_size, save_as, device, lr,num_tasks,
             y_transform=None, multiheaded=True, train_full_coreset=True,
             summary_writer=None):
    """Train on one VCL task, then evaluate on every task seen so far.

    Steps:
      1. ``coreset.select`` removes this task's coreset points. The rest are
         trained on with ``model.vcl_loss``.
      2. ``model.reset_for_new_task`` turns the current posterior into the
         prior for the next task.
      3. A copy of the model is fine-tuned on the coreset, and test accuracy is
         measured for tasks ``0..task_idx``. With ``train_full_coreset`` one
         copy covers all seen tasks; otherwise each test task gets its own copy.

    Accuracies go to ``write_as_json(save_as + '/accuracy.txt', ...)`` and the
    model to ``save_model``.

    Returns:
        np.ndarray of length ``num_tasks``: accuracy (percent) of every task
        evaluated so far, NaN for tasks not seen yet.
    """
    print('TASK ', task_idx)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    head = task_idx if multiheaded else 0

    task_data = task_subset(train_data, train_task_ids, task_idx)
    non_coreset_data = coreset.select(task_data, task_id=task_idx)
    train_loader = DataLoader(non_coreset_data, batch_size)
    # The likelihood term only covers the non-coreset points, so the KL scaling
    # and the logged per-example epoch loss both use that count.
    n_train = len(non_coreset_data)

    for epoch in tqdm(range(epochs), 'Epochs: '):
        epoch_loss = 0
        for x, y_true in train_loader:
            optimizer.zero_grad()
            x = x.to(device)
            y_true = y_true.to(device)
            if y_transform is not None:
                y_true = y_transform(y_true, task_idx)

            loss = model.vcl_loss(x, y_true, head, n_train)
            epoch_loss += len(x) * loss.item()

            loss.backward()
            optimizer.step()

        if summary_writer is not None:
            summary_writer.add_scalars(
                "loss", {"TASK_" + str(task_idx): epoch_loss / max(n_train, 1)}, epoch)

    model.reset_for_new_task(head)

    if train_full_coreset:
        model_cs_trained = coreset.coreset_train(
            model, optimizer, list(range(task_idx + 1)), epochs,
            device, y_transform=y_transform, multiheaded=multiheaded)

    np_task_accuracies = np.full(num_tasks, np.nan)
    task_accuracies = [-1] * num_tasks
    tot_right = 0
    tot_tested = 0
    for test_task_idx in range(task_idx + 1):
        if not train_full_coreset:
            model_cs_trained = coreset.coreset_train(
                model, optimizer, test_task_idx, epochs,
                device, y_transform=y_transform, multiheaded=multiheaded)

        head = test_task_idx if multiheaded else 0
        test_subset = task_subset(test_data, test_task_ids, test_task_idx)

        x = torch.Tensor(np.array([xs for xs, _ in test_subset])).to(device)
        y_true = torch.Tensor([ys for _, ys in test_subset])
        if y_transform is not None:
            y_true = y_transform(y_true, test_task_idx)

        # Evaluation only: don't build an autograd graph for the whole test set
        # (several Monte-Carlo forward passes over every test example).
        with torch.no_grad():
            y_pred = model_cs_trained.prediction(x, head)

        acc = class_accuracy(y_pred, y_true)
        print("After task {} perfomance on task {} is {}"
              .format(task_idx, test_task_idx, acc))
        tot_right += acc * len(test_subset)
        tot_tested += len(test_subset)
        np_task_accuracies[test_task_idx] = acc
        task_accuracies[test_task_idx] = acc

    mean_accuracy = tot_right / tot_tested if tot_tested else float('nan')
    print("Mean accuracy:", mean_accuracy)

    if summary_writer is not None:
        task_accuracies_dict = dict(zip(["TASK_" + str(i) for i in range(task_idx + 1)], task_accuracies))
        summary_writer.add_scalars("test_accuracy", task_accuracies_dict, task_idx + 1)
        summary_writer.add_scalar("mean_accuracy", mean_accuracy, task_idx + 1)

    write_as_json(save_as + '/accuracy.txt', task_accuracies)
    save_model(model, save_as + '_model_task_' + str(task_idx) + '.pth')
    return np_task_accuracies

# Coreset Class

In [37]:



class Coreset:
    """Base coreset: stores no points, so training data passes through untouched.

    Subclasses override ``select`` to move points out of each task's data into
    ``self.coreset`` (with matching ``self.coreset_task_ids``).
    ``coreset_train`` then fine-tunes a copy of the model on those points.
    """

    def __init__(self, size=0, lr=0.001):
        self.size = size              # number of points to keep per task
        self.coreset = None           # Dataset of stored (x, y) pairs, or None
        self.coreset_task_ids = None  # 1-D tensor of task ids, parallel to coreset
        self.lr = lr                  # learning rate used for coreset fine-tuning

    def select(self, d: torchdata.Dataset, task_id: int):
        """Return the part of ``d`` to train on; the base class keeps everything."""
        return d

    def coreset_train(self, m, old_optimizer, tasks, epochs, device,
                      y_transform=None, multiheaded=True, batch_size=256):
        """Fine-tune a deep copy of ``m`` on the stored coreset points.

        Args:
            m: model exposing ``vcl_loss(x, y, head, n)``; it is not modified.
            old_optimizer: optimizer whose state (e.g. Adam moments) seeds the
                new one; its learning rate is replaced by ``self.lr``.
            tasks: a task id or an iterable of task ids; it is not mutated.
            epochs: passes over the selected coreset points.
            device: device every batch is moved to.
            y_transform: optional ``f(y, task_idx)`` applied to the labels.
            multiheaded: when True use head ``task_idx``; otherwise head 0.
            batch_size: mini-batch size.

        Returns:
            The fine-tuned copy, or ``m`` itself when no coreset is stored.
        """
        if self.coreset is None:
            return m

        model = deepcopy(m)

        optimizer = optim.Adam(model.parameters(), lr=self.lr)
        optimizer.load_state_dict(old_optimizer.state_dict())
        # load_state_dict also restores the old optimizer's hyper-parameters;
        # re-apply the coreset learning rate so ``self.lr`` actually takes effect.
        for group in optimizer.param_groups:
            group['lr'] = self.lr

        # Work on a private list so shuffling never reorders the caller's list.
        tasks = [tasks] if isinstance(tasks, int) else list(tasks)

        train_loaders = {
            task_idx: torchdata.DataLoader(
                task_subset(self.coreset, self.coreset_task_ids, task_idx),
                batch_size
            )
            for task_idx in tasks
        }

        print('CORESET TRAIN')
        for _ in tqdm(range(epochs), 'Epochs: '):
            shuffle(tasks)
            for task_idx in tasks:
                head = task_idx if multiheaded else 0

                for x, y_true in train_loaders[task_idx]:
                    optimizer.zero_grad()
                    x = x.to(device)
                    y_true = y_true.to(device)

                    if y_transform is not None:
                        y_true = y_transform(y_true, task_idx)

                    loss = model.vcl_loss(x, y_true, head, len(self.coreset))
                    loss.backward()
                    optimizer.step()

        return model

class RandomCoreset(Coreset):
    """Coreset holding a uniformly random subset of every task's data."""

    def __init__(self, size, lr=0.001):
        super().__init__(size, lr)

    def select(self, d: torchdata.Dataset, task_id: int):
        """Move up to ``self.size`` random points of ``d`` into the coreset.

        Args:
            d: dataset of ``(x, y)`` pairs for one task.
            task_id: id recorded for every point added to the coreset.

        Returns:
            The remaining (non-coreset) points of ``d`` as a Subset.
        """
        # random_split needs the lengths to sum to len(d). Clamp so a task
        # smaller than the coreset size goes entirely into the coreset instead
        # of raising.
        n_coreset = min(self.size, len(d))
        new_cs_data, non_cs = torchdata.random_split(d, [n_coreset, len(d) - n_coreset])

        new_cs_x = torch.tensor(np.array([x for x, _ in new_cs_data]))
        new_cs_y = torch.tensor(np.array([y for _, y in new_cs_data]))

        new_cs = torchdata.TensorDataset(new_cs_x, new_cs_y)
        new_task_ids = torch.full((len(new_cs_data),), task_id)

        if self.coreset is None:
            self.coreset = new_cs
            self.coreset_task_ids = new_task_ids
        else:
            self.coreset = torchdata.ConcatDataset((self.coreset, new_cs))
            self.coreset_task_ids = torch.cat((self.coreset_task_ids, new_task_ids))

        return non_cs
class kCenterCoreset(Coreset):
    """Coreset picked by greedy k-center selection in (flattened) input space.

    The first point of a task is always chosen. Each further point is the one
    farthest (Euclidean distance) from its nearest already-chosen point.
    """

    def __init__(self, size):
        super().__init__(size)

    def select(self, d, task_id):
        """Move up to ``self.size`` k-center points of ``d`` into the coreset.

        Args:
            d: dataset of ``(x, y)`` pairs for one task.
            task_id: id recorded for every point added to the coreset.

        Returns:
            A TensorDataset with the points of ``d`` that were not selected,
            in their original order (``d`` itself when it is empty).
        """
        if len(d) == 0:
            return d

        xs, ys = zip(*[(item[0], item[1]) for item in d])
        # One row per point, each keeping its own sample shape. Distances are
        # computed on a flattened 2-D view so update_distance's dim=1 is valid.
        x = torch.as_tensor(np.stack([np.asarray(v, dtype=np.float32) for v in xs]))
        y = torch.Tensor(np.array(ys))
        n_points = len(x)
        flat_x = x.reshape(n_points, -1)

        # Asking for more centres than points would only repeat index 0.
        n_centres = min(self.size, n_points)
        dists = torch.full((n_points,), float('inf'))
        chosen = []
        current_id = 0
        for _ in range(n_centres):
            dists = self.update_distance(dists, flat_x, current_id)
            chosen.append(current_id)
            current_id = int(torch.argmax(dists))
        idx = torch.tensor(chosen, dtype=torch.long)

        # Boolean mask instead of an O(N * k) `i not in idx` scan.
        keep = torch.ones(n_points, dtype=torch.bool)
        keep[idx] = False
        non_cs = torchdata.TensorDataset(x[keep], y[keep])

        new_cs = torchdata.TensorDataset(x[idx], y[idx])
        new_task_ids = torch.full((len(chosen),), task_id)

        if self.coreset is None:
            self.coreset = new_cs
            self.coreset_task_ids = new_task_ids
        else:
            self.coreset = torchdata.ConcatDataset((self.coreset, new_cs))
            self.coreset_task_ids = torch.cat((self.coreset_task_ids, new_task_ids))

        return non_cs

    def update_distance(self, dists, X, current_id):
        """Lower each entry of ``dists`` to its distance from ``X[current_id]``.

        ``X`` must be 2-D (one row per point). Returns a new tensor.
        """
        new_dists = torch.norm(X - X[current_id], dim=1)
        return torch.min(dists, new_dists)


# VCL Layer

In [5]:


class MeanFieldGaussianLinear(torch.nn.Module):
    """Bayesian linear layer with a fully factorised (mean-field) Gaussian posterior.

    Every weight and bias has its own posterior mean and log-variance (trainable
    parameters) and its own prior mean and log-variance (buffers, initially
    N(0, 1)). ``reset_for_next_task`` copies the posterior into the prior.

    Args:
        in_features: input size.
        out_features: output size.
        initial_posterior_variance: variance every posterior starts with.
    """

    def __init__(self, in_features, out_features, initial_posterior_variance=1e-3):
        super().__init__()
        self.epsilon = 1e-8  # keeps the KL's divisions by the prior variance finite
        self.in_features = in_features
        self.out_features = out_features
        self.ipv = initial_posterior_variance
        self.register_buffer('prior_W_means', torch.zeros(out_features, in_features))
        self.register_buffer('prior_W_log_vars', torch.zeros(out_features, in_features))
        self.register_buffer('prior_b_means', torch.zeros(out_features))
        self.register_buffer('prior_b_log_vars', torch.zeros(out_features))

        self.posterior_W_means = torch.nn.Parameter(torch.empty_like(self._buffers['prior_W_means'], requires_grad=True))
        self.posterior_b_means = torch.nn.Parameter(torch.empty_like(self._buffers['prior_b_means'], requires_grad=True))
        self.posterior_W_log_vars = torch.nn.Parameter(torch.empty_like(self._buffers['prior_W_log_vars'], requires_grad=True))
        self.posterior_b_log_vars = torch.nn.Parameter(torch.empty_like(self._buffers['prior_b_log_vars'], requires_grad=True))

        self._initialize_posteriors()

    def forward(self, x, sample_parameters=True):
        """Apply the layer with sampled weights, or with the posterior means."""
        if sample_parameters:
            w, b = self._sample_parameters()
            return F.linear(x, w, b)
        else:
            return F.linear(x, self.posterior_W_means, self.posterior_b_means)

    def reset_for_next_task(self):
        """Copy the current posterior into the prior buffers."""
        self._buffers['prior_W_means'].data.copy_(self.posterior_W_means.data)
        self._buffers['prior_W_log_vars'].data.copy_(self.posterior_W_log_vars.data)
        self._buffers['prior_b_means'].data.copy_(self.posterior_b_means.data)
        self._buffers['prior_b_log_vars'].data.copy_(self.posterior_b_log_vars.data)

    def kl_divergence(self) -> torch.Tensor:
        """Closed-form KL(posterior || prior) summed over all weights and biases."""
        prior_means = torch.cat((self._buffers['prior_W_means'].view(-1),
                                self._buffers['prior_b_means'].view(-1)))
        prior_log_vars = torch.cat((self._buffers['prior_W_log_vars'].view(-1),
                                    self._buffers['prior_b_log_vars'].view(-1)))
        prior_vars = torch.exp(prior_log_vars)

        posterior_means = torch.cat((self.posterior_W_means.view(-1),
                                    self.posterior_b_means.view(-1)))
        posterior_log_vars = torch.cat((self.posterior_W_log_vars.view(-1),
                                        self.posterior_b_log_vars.view(-1)))
        posterior_vars = torch.exp(posterior_log_vars)

        kl_elementwise = (posterior_vars / (prior_vars + self.epsilon) +
                          ((prior_means - posterior_means) ** 2) / (prior_vars + self.epsilon) -
                          1 + prior_log_vars - posterior_log_vars)

        return 0.5 * kl_elementwise.sum()

    def count_parameters(self):
        """Number of trainable scalars (posterior means and log-variances)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_statistics(self) -> dict:
        """Parameter count and mean posterior means/variances (as tensors)."""
        statistics = {
            'total_params': self.count_parameters(),
            'average_w_mean': torch.mean(self.posterior_W_means),
            'average_b_mean': torch.mean(self.posterior_b_means),
            'average_w_var': torch.mean(torch.exp(self.posterior_W_log_vars)),
            'average_b_var': torch.mean(torch.exp(self.posterior_b_log_vars))
        }

        return statistics

    def extra_repr(self):
        # The layer always has a (stochastic) bias. There is no ``self.bias``
        # attribute, so the repr states it directly; referencing one would make
        # repr(module) raise AttributeError.
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, True
        )

    def _sample_parameters(self):
        """Reparameterised draw: mean + N(0, 1) * exp(0.5 * log_var)."""
        w_epsilons = torch.randn_like(self.posterior_W_means)
        b_epsilons = torch.randn_like(self.posterior_b_means)

        w_std_dev = torch.exp(0.5 * self.posterior_W_log_vars)
        b_std_dev = torch.exp(0.5 * self.posterior_b_log_vars)

        w = self.posterior_W_means + w_epsilons * w_std_dev
        b = self.posterior_b_means + b_epsilons * b_std_dev

        return w, b

    def _initialize_posteriors(self):
        """Means: N(0, 0.1) weights, U(-0.1, 0.1) biases; log-variances: log(ipv)."""
        torch.nn.init.normal_(self.posterior_W_means, mean=0, std=0.1)
        torch.nn.init.uniform_(self.posterior_b_means, -0.1, 0.1)
        torch.nn.init.constant_(self.posterior_W_log_vars, math.log(self.ipv))
        torch.nn.init.constant_(self.posterior_b_log_vars, math.log(self.ipv))


# VCL Model

In [6]:
class DiscriminativeVCL(nn.Module):
    """Multi-head Bayesian MLP classifier for variational continual learning.

    A stack of shared ``MeanFieldGaussianLinear`` layers (ReLU after each)
    feeds one ``MeanFieldGaussianLinear`` head per task.

    Args:
        x_dim: input feature size.
        h_dim: width of the last shared layer (input size of every head).
        y_dim: number of classes per head.
        n_heads: number of task heads (at least 1).
        shared_h_dims: widths of the shared hidden layers placed before ``h_dim``.
        initial_posterior_variance: initial variance of every posterior.
        mc_sampling_n: Monte-Carlo weight samples per stochastic forward pass.
        device: stored for backward compatibility; working tensors are created
            on the device of the inputs/parameters instead.
    """

    def __init__(self, x_dim, h_dim, y_dim, n_heads=1, shared_h_dims=(100, 100),
                 initial_posterior_variance=1e-6, mc_sampling_n=10, device='cuda'):
        super().__init__()
        if n_heads < 1:
            raise ValueError('Network requires at least one head.')

        self.x_dim = x_dim
        self.h_dim = h_dim
        self.y_dim = y_dim
        self.n_heads = n_heads
        self.ipv = initial_posterior_variance
        self.mc_sampling_n = mc_sampling_n
        self.device = device

        shared_dims = [x_dim] + list(shared_h_dims) + [h_dim]

        self.shared_layers = nn.ModuleList([
            MeanFieldGaussianLinear(shared_dims[i], shared_dims[i + 1], self.ipv)
            for i in range(len(shared_dims) - 1)
        ])
        self.heads = nn.ModuleList([
            MeanFieldGaussianLinear(self.h_dim, self.y_dim, self.ipv) for _ in range(n_heads)
        ])

        self.softmax = nn.Softmax(dim=1)

    def _n_passes(self, sample_parameters):
        """One pass per MC sample, or a single pass with the posterior means."""
        return self.mc_sampling_n if sample_parameters else 1

    def _logits(self, x, head_idx, sample_parameters):
        """One pass through the shared layers (ReLU) and head ``head_idx``."""
        h = x
        for layer in self.shared_layers:
            h = F.relu(layer(h, sample_parameters=sample_parameters))
        return self.heads[head_idx](h, sample_parameters=sample_parameters)

    def _mc_log_probs(self, x, head_idx, sample_parameters):
        """Class log-probabilities averaged over the weight samples.

        Feeding this to ``nll_loss`` gives the Monte-Carlo estimate of the
        expected negative log-likelihood used by the VCL objective.
        """
        n_passes = self._n_passes(sample_parameters)
        total = 0
        for _ in range(n_passes):
            total = total + F.log_softmax(self._logits(x, head_idx, sample_parameters), dim=1)
        return total / n_passes

    def forward(self, x, head_idx, sample_parameters=True):
        """Class probabilities from head ``head_idx``, averaged over weight samples."""
        n_passes = self._n_passes(sample_parameters)
        y_out = torch.zeros(x.size(0), self.y_dim, device=x.device)
        for _ in range(n_passes):
            y_out = y_out + self.softmax(self._logits(x, head_idx, sample_parameters))
        # Divide by the passes actually made so every row sums to 1, also in
        # the deterministic (sample_parameters=False) case.
        return y_out / n_passes

    def vcl_loss(self, x, y, head_idx, task_size):
        """Negative ELBO per training example.

        KL(posterior || prior) of the shared layers and head ``head_idx``
        divided by ``task_size``, plus the Monte-Carlo negative log-likelihood
        of ``y``. NLL loss expects log-probabilities, so it is given averaged
        log-softmax outputs rather than probabilities.
        """
        log_probs = self._mc_log_probs(x, head_idx, sample_parameters=True)
        return self._kl_divergence(head_idx) / task_size + F.nll_loss(log_probs, y.long())

    def point_estimate_loss(self, x, y, head_idx):
        """Negative log-likelihood of ``y`` using the posterior means only."""
        log_probs = self._mc_log_probs(x, head_idx, sample_parameters=False)
        return F.nll_loss(log_probs, y.long())

    def prediction(self, x, task):
        """Most probable class per row of ``x`` for head ``task``."""
        return torch.argmax(self(x, task), dim=1)

    def reset_for_new_task(self, head_idx):
        """Copy posteriors into priors for the shared layers and head ``head_idx``."""
        for layer in self.shared_layers:
            layer.reset_for_next_task()

        self.heads[head_idx].reset_for_next_task()

    def get_statistics(self):
        """Per-layer statistics plus parameter-count-weighted model averages."""
        averaged_keys = ('average_w_mean', 'average_b_mean', 'average_w_var', 'average_b_var')
        layer_statistics = []
        model_statistics = {'total_params': 0}
        model_statistics.update({key: 0 for key in averaged_keys})

        for layer in list(self.shared_layers) + list(self.heads):
            stats = layer.get_statistics()
            layer_statistics.append(stats)
            layer_n_params = stats['total_params']
            model_statistics['total_params'] += layer_n_params
            for key in averaged_keys:
                model_statistics[key] += stats[key] * layer_n_params

        for key in averaged_keys:
            model_statistics[key] /= model_statistics['total_params']

        return layer_statistics, model_statistics

    def _kl_divergence(self, head_idx) -> torch.Tensor:
        """KL of the shared layers plus head ``head_idx``, as a 1-element tensor."""
        head = self.heads[head_idx]
        kl_divergence = torch.zeros(1, device=head.posterior_W_means.device)

        for layer in self.shared_layers:
            kl_divergence = torch.add(kl_divergence, layer.kl_divergence())

        return torch.add(kl_divergence, head.kl_divergence())

    def _mean_posterior_variance(self):
        """Mean posterior variance over every weight and bias of the model."""
        layers = list(self.shared_layers) + list(self.heads)
        posterior_log_vars = torch.cat([
            t.reshape(-1)
            for layer in layers
            for t in (layer.posterior_W_log_vars, layer.posterior_b_log_vars)
        ])
        return torch.exp(posterior_log_vars).mean().item()


# Discriminative Experiments

In [7]:
MNIST_FLATTENED_DIM = 28 * 28   # MNIST images are 28x28 and fed to the MLP flattened
LR = 0.001                      # learning rate passed to the training helpers
INITIAL_POSTERIOR_VAR = 1e-3    # initial variance of every mean-field posterior

# Prefer the first GPU, but fall back to the CPU on machines without CUDA
# instead of failing on the first .to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Running on device", device)

def permuted_mnist(coresetAlgo,coreset_size=20):
    """Run VCL on 10 permuted-MNIST tasks with a single shared head.

    Each task applies its own fixed random pixel permutation. To keep runs
    short, every task uses a random 1/20 of the MNIST train and test images.

    Args:
        coresetAlgo: "Random" or "K-Center".
        coreset_size: number of coreset points kept per task.

    Returns:
        (avgAcc, accuracies_list): ``accuracies_list[t]`` is the per-task
        accuracy array returned by ``run_task`` after training task ``t``;
        ``avgAcc`` is its NaN-ignoring mean over training steps, per task.
    """
    N_CLASSES = 10
    LAYER_WIDTH = 100
    N_HIDDEN_LAYERS = 2
    N_TASKS = 10
    MULTIHEADED = False
    CORESET_SIZE = coreset_size
    EPOCHS = 100
    BATCH_SIZE = 256
    TRAIN_FULL_CORESET = True
    SUBSAMPLE_DIVISOR = 20  # keep 1/SUBSAMPLE_DIVISOR of each task's images
    coresetDict = {"Random":RandomCoreset, "K-Center":kCenterCoreset}

    transforms = [Compose([Flatten(), Scale(), Permute(torch.randperm(MNIST_FLATTENED_DIM))]) for _ in range(N_TASKS)]

    model = DiscriminativeVCL(
        x_dim=MNIST_FLATTENED_DIM, y_dim=N_CLASSES,
        h_dim=LAYER_WIDTH, shared_h_dims = tuple(LAYER_WIDTH for _ in range(N_HIDDEN_LAYERS)),
        n_heads=(N_TASKS if MULTIHEADED else 1),
        initial_posterior_variance=INITIAL_POSTERIOR_VAR
    ).to(device)

    coreset = coresetDict[coresetAlgo](size=CORESET_SIZE)

    def subsampled_tasks(train):
        """Concatenate one subsampled MNIST split per permutation, plus task ids."""
        task_sets = []
        for transform in transforms:
            full = MNIST(root="data", train=train, download=True, transform=transform)
            keep = torch.randperm(len(full))[:len(full) // SUBSAMPLE_DIVISOR]
            task_sets.append(Subset(full, keep))
        # Subsample inside each task, then label in concatenation order, so
        # item i really uses the permutation recorded in task_ids[i]. Drawing
        # one random subset across the whole concatenation and labelling it in
        # sequential blocks would mix every permutation into every "task".
        task_ids = torch.cat([torch.full((len(s),), t) for t, s in enumerate(task_sets)])
        return ConcatDataset(task_sets), task_ids

    mnist_train, train_task_ids = subsampled_tasks(train=True)
    mnist_test, test_task_ids = subsampled_tasks(train=False)

    summary_logdir = os.path.join("logs", "disc_p_mnist", datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)

    run_point_estimate_initialisation(model=model, data=mnist_train,
                                      epochs=EPOCHS, batch_size=BATCH_SIZE,
                                      device=device, lr=LR,
                                      multiheaded=MULTIHEADED,
                                      task_ids=train_task_ids)
    accuracies_list = []
    for task in range(N_TASKS):
        accuracies = run_task(
            model=model, train_data=mnist_train, train_task_ids=train_task_ids,
            test_data=mnist_test, test_task_ids=test_task_ids, task_idx=task,
            coreset=coreset, epochs=EPOCHS, batch_size=BATCH_SIZE,
            device=device, lr=LR, num_tasks=N_TASKS, save_as="disc_p_mnist",
            multiheaded=MULTIHEADED, train_full_coreset=TRAIN_FULL_CORESET,
            summary_writer=writer
        )
        accuracies_list.append(accuracies)
    accuracies_stack = np.stack(accuracies_list)
    avgAcc = np.nanmean(accuracies_stack, axis=0)

    writer.close()

    return avgAcc,accuracies_list


def split_mnist(coresetAlgo,coreset_size=4):
    """Run VCL on split MNIST: five binary tasks with one output head each.

    Task t contains digits 2t and 2t+1. ``binarize_y`` relabels them per task
    (digit 2t+1 -> 1, digit 2t -> 0).

    Args:
        coresetAlgo: "Random" or "K-Center".
        coreset_size: number of coreset points kept per task.

    Returns:
        (avgAcc, accuracies_list): ``accuracies_list[t]`` is the per-task
        accuracy array returned by ``run_task`` after training task ``t``;
        ``avgAcc`` is its NaN-ignoring mean over training steps, per task.
    """
    N_CLASSES = 2
    LAYER_WIDTH = 256
    N_HIDDEN_LAYERS = 2
    N_TASKS = 5
    MULTIHEADED = True
    CORESET_SIZE = coreset_size
    EPOCHS = 100
    BATCH_SIZE = 256
    TRAIN_FULL_CORESET = True
    coresetDict = {"Random":RandomCoreset, "K-Center":kCenterCoreset}

    transform = Compose([Flatten(), Scale()])

    full_train = MNIST(root="data", train=True, download=True, transform=transform)
    full_test = MNIST(root="data", train=False, download=True, transform=transform)
    train_order = torch.randperm(len(full_train))
    test_order = torch.randperm(len(full_test))
    mnist_train = Subset(full_train, train_order)
    mnist_test = Subset(full_test, test_order)

    model = DiscriminativeVCL(
        x_dim=MNIST_FLATTENED_DIM, y_dim=N_CLASSES,
        h_dim=LAYER_WIDTH, shared_h_dims = tuple(LAYER_WIDTH for _ in range(N_HIDDEN_LAYERS)),
        n_heads=(N_TASKS if MULTIHEADED else 1),
        initial_posterior_variance=INITIAL_POSTERIOR_VAR
    ).to(device)

    coreset = coresetDict[coresetAlgo](size=CORESET_SIZE)
    label_to_task_mapping = {
        0: 0, 1: 0,
        2: 1, 3: 1,
        4: 2, 5: 2,
        6: 3, 7: 3,
        8: 4, 9: 4,
    }

    # Read labels from ``.targets`` in the same shuffled order as the subsets,
    # instead of iterating the datasets (which would decode and transform
    # every image just to read its label).
    train_labels = torch.as_tensor(full_train.targets)[train_order].tolist()
    test_labels = torch.as_tensor(full_test.targets)[test_order].tolist()
    train_task_ids = torch.tensor([label_to_task_mapping[int(y)] for y in train_labels])
    test_task_ids = torch.tensor([label_to_task_mapping[int(y)] for y in test_labels])

    summary_logdir = os.path.join("logs", "disc_s_mnist", datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)

    binarize_y = lambda y, task: (y == (2 * task + 1)).long()

    run_point_estimate_initialisation(model=model, data=mnist_train,
                                      epochs=EPOCHS, batch_size=BATCH_SIZE,
                                      device=device, multiheaded=MULTIHEADED,
                                      lr=LR, task_ids=train_task_ids,
                                      y_transform=binarize_y)

    accuracies_list = []
    for task_idx in range(N_TASKS):
        accuracies = run_task(
            model=model, train_data=mnist_train, train_task_ids=train_task_ids,
            test_data=mnist_test, test_task_ids=test_task_ids, coreset=coreset,
            task_idx=task_idx, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR, num_tasks=N_TASKS,
            save_as="disc_s_mnist", device=device, multiheaded=MULTIHEADED,
            y_transform=binarize_y, train_full_coreset=TRAIN_FULL_CORESET,
            summary_writer=writer
        )
        accuracies_list.append(accuracies)
    accuracies_stack = np.stack(accuracies_list)
    avgAcc = np.nanmean(accuracies_stack, axis=0)

    writer.close()
    return avgAcc,accuracies_list

Running on device cuda:0


# Conv Layer and Model

In [48]:
class MeanFieldGaussianConv2d(torch.nn.Module):
    """Bayesian 2-D convolution with a fully factorised Gaussian posterior.

    Mirrors ``MeanFieldGaussianLinear``: every kernel weight and bias has a
    posterior mean and log-variance (parameters) and a prior mean and
    log-variance (buffers, initially N(0, 1)).

    Args:
        in_channels: input channels.
        out_channels: output channels.
        kernel_size: side length of the square kernel.
        initial_posterior_variance: variance every posterior starts with.
        padding: zero-padding on each side. The default of 1 preserves the
            spatial size for ``kernel_size=3``; use ``kernel_size // 2`` for
            other odd kernels.
    """

    def __init__(self, in_channels, out_channels, kernel_size, initial_posterior_variance=1e-3, padding=1):
        super().__init__()
        self.epsilon = 1e-8  # keeps the KL's divisions by the prior variance finite
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ipv = initial_posterior_variance
        self.kernel_size = kernel_size
        self.padding = padding
        self.groups = 1

        self.register_buffer('prior_W_means', torch.zeros(out_channels, in_channels//self.groups, kernel_size,kernel_size))
        self.register_buffer('prior_W_log_vars', torch.zeros(out_channels, in_channels//self.groups, kernel_size,kernel_size))
        self.register_buffer('prior_b_means', torch.zeros(out_channels))
        self.register_buffer('prior_b_log_vars', torch.zeros(out_channels))

        self.posterior_W_means = torch.nn.Parameter(torch.empty_like(self._buffers['prior_W_means'], requires_grad=True))
        self.posterior_b_means = torch.nn.Parameter(torch.empty_like(self._buffers['prior_b_means'], requires_grad=True))
        self.posterior_W_log_vars = torch.nn.Parameter(torch.empty_like(self._buffers['prior_W_log_vars'], requires_grad=True))
        self.posterior_b_log_vars = torch.nn.Parameter(torch.empty_like(self._buffers['prior_b_log_vars'], requires_grad=True))

        self._initialize_posteriors()

    def forward(self, x, sample_parameters=True):
        """Convolve with sampled weights, or with the posterior means."""
        if sample_parameters:
            w, b = self._sample_parameters()
            return F.conv2d(input=x, weight=w, bias=b, groups=self.groups, padding=self.padding)
        else:
            return F.conv2d(input=x, weight=self.posterior_W_means, bias=self.posterior_b_means,
                            groups=self.groups, padding=self.padding)

    def reset_for_next_task(self):
        """Copy the current posterior into the prior buffers."""
        self._buffers['prior_W_means'].data.copy_(self.posterior_W_means.data)
        self._buffers['prior_W_log_vars'].data.copy_(self.posterior_W_log_vars.data)
        self._buffers['prior_b_means'].data.copy_(self.posterior_b_means.data)
        self._buffers['prior_b_log_vars'].data.copy_(self.posterior_b_log_vars.data)

    def kl_divergence(self) -> torch.Tensor:
        """Closed-form KL(posterior || prior) summed over all weights and biases."""
        prior_means = torch.cat((self._buffers['prior_W_means'].view(-1),
                                self._buffers['prior_b_means'].view(-1)))
        prior_log_vars = torch.cat((self._buffers['prior_W_log_vars'].view(-1),
                                    self._buffers['prior_b_log_vars'].view(-1)))
        prior_vars = torch.exp(prior_log_vars)

        posterior_means = torch.cat((self.posterior_W_means.view(-1),
                                    self.posterior_b_means.view(-1)))
        posterior_log_vars = torch.cat((self.posterior_W_log_vars.view(-1),
                                        self.posterior_b_log_vars.view(-1)))
        posterior_vars = torch.exp(posterior_log_vars)

        kl_elementwise = (posterior_vars / (prior_vars + self.epsilon) +
                          ((prior_means - posterior_means) ** 2) / (prior_vars + self.epsilon) -
                          1 + prior_log_vars - posterior_log_vars)

        return 0.5 * kl_elementwise.sum()

    def count_parameters(self):
        """Number of trainable scalars (posterior means and log-variances)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_statistics(self) -> dict:
        """Parameter count and mean posterior means/variances (as tensors)."""
        statistics = {
            'total_params': self.count_parameters(),
            'average_w_mean': torch.mean(self.posterior_W_means),
            'average_b_mean': torch.mean(self.posterior_b_means),
            'average_w_var': torch.mean(torch.exp(self.posterior_W_log_vars)),
            'average_b_var': torch.mean(torch.exp(self.posterior_b_log_vars))
        }

        return statistics

    def extra_repr(self):
        # The layer always has a bias and no ``self.bias`` attribute exists, so
        # the repr states it directly instead of raising AttributeError.
        return 'in_channels={}, out_channels={}, kernel_size={}, padding={}, bias={}'.format(
            self.in_channels, self.out_channels, self.kernel_size, self.padding, True
        )

    def _sample_parameters(self):
        """Reparameterised draw: mean + N(0, 1) * exp(0.5 * log_var)."""
        w_epsilons = torch.randn_like(self.posterior_W_means)
        b_epsilons = torch.randn_like(self.posterior_b_means)

        w_std_dev = torch.exp(0.5 * self.posterior_W_log_vars)
        b_std_dev = torch.exp(0.5 * self.posterior_b_log_vars)

        w = self.posterior_W_means + w_epsilons * w_std_dev
        b = self.posterior_b_means + b_epsilons * b_std_dev

        return w, b

    def _initialize_posteriors(self):
        """Means: N(0, 0.1) weights, U(-0.1, 0.1) biases; log-variances: log(ipv)."""
        torch.nn.init.normal_(self.posterior_W_means, mean=0, std=0.1)
        torch.nn.init.uniform_(self.posterior_b_means, -0.1, 0.1)
        torch.nn.init.constant_(self.posterior_W_log_vars, math.log(self.ipv))
        torch.nn.init.constant_(self.posterior_b_log_vars, math.log(self.ipv))


In [49]:
class ConvVCL(nn.Module):
    """Multi-head convolutional Variational Continual Learning (VCL) classifier.

    A stack of mean-field Gaussian conv layers (each followed by ReLU and 2x2
    max-pooling) is shared by all tasks; every task has its own two-layer
    mean-field Gaussian head (``flattened features -> 256 -> y_dim``).

    Args:
        x_dim: spatial side length of the square input images.
        x_channels: number of input channels.
        y_dim: number of output classes per head.
        n_heads: number of task-specific heads (must be >= 1).
        shared_h_channels: output channels of the successive shared conv layers.
        kernel_size: kernel size of every shared conv layer.
        initial_posterior_variance: initial posterior variance passed to every layer.
        mc_sampling_n: number of weight samples averaged in ``forward`` when
            ``sample_parameters`` is True (must be >= 1).
        device: device the task heads are moved to at construction; also the
            fallback device when the model has no parameters.
    """

    # Added inside log() so a predicted probability of exactly 0 does not give -inf.
    _LOG_EPS = 1e-8

    def __init__(self, x_dim, x_channels, y_dim, n_heads=1, shared_h_channels=(32, 64, 128), kernel_size=3,
                 initial_posterior_variance=1e-6, mc_sampling_n=10, device='cuda'):
        super().__init__()
        if n_heads < 1:
            raise ValueError('Network requires at least one head.')
        if mc_sampling_n < 1:
            raise ValueError('mc_sampling_n must be at least 1.')

        self.x_dim = x_dim
        self.x_channels = x_channels
        self.y_dim = y_dim
        self.n_heads = n_heads
        self.kernel_size = kernel_size
        self.ipv = initial_posterior_variance
        self.mc_sampling_n = mc_sampling_n
        self.device = device

        shared_channels = [x_channels] + list(shared_h_channels)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.shared_layers = nn.ModuleList([
            MeanFieldGaussianConv2d(c_in, c_out, self.kernel_size, initial_posterior_variance=self.ipv)
            for c_in, c_out in zip(shared_channels[:-1], shared_channels[1:])
        ])
        self.output_flat_length = self.get_output_shape()
        # Heads are registered as nested ModuleLists (not plain Python lists) so
        # that .parameters(), .to(device) and state_dict() include them. With a
        # plain list, an optimizer built from model.parameters() never updates
        # the heads, and checkpoints silently drop them.
        self.heads = nn.ModuleList([
            nn.ModuleList([
                MeanFieldGaussianLinear(self.output_flat_length, 256, self.ipv).to(self.device),
                MeanFieldGaussianLinear(256, self.y_dim, self.ipv).to(self.device),
            ])
            for _ in range(n_heads)
        ])
        self.softmax = nn.Softmax(dim=1)

    def _param_device(self):
        """Device of the model's first parameter, or ``self.device`` if there are none."""
        first = next(self.parameters(), None)
        return first.device if first is not None else torch.device(self.device)

    def _all_layers(self):
        """Yield the shared conv layers, then every layer of every head, in order."""
        yield from self.shared_layers
        for head in self.heads:
            yield from head

    def get_output_shape(self):
        """Return the flattened per-sample length produced by the shared conv stack.

        Pushes one random dummy input through the shared layers (posterior means,
        no gradient tracking). Despite the name, the result is an int.
        """
        probe = torch.rand((1, self.x_channels, self.x_dim, self.x_dim), device=self._param_device())
        with torch.no_grad():
            for layer in self.shared_layers:
                probe = self.pool(F.relu(layer(probe, sample_parameters=False)))
        return math.prod(probe.shape)

    def forward(self, x, head_idx, sample_parameters=True):
        """Return class probabilities (not log-probabilities) from head ``head_idx``.

        With ``sample_parameters=True``, the softmax outputs of ``mc_sampling_n``
        weight samples are averaged. Otherwise a single deterministic pass is
        made. Either way each row sums to 1.
        """
        n_passes = self.mc_sampling_n if sample_parameters else 1
        head = self.heads[head_idx]
        y_out = torch.zeros((x.size(0), self.y_dim), device=x.device)

        for _ in range(n_passes):
            h = x
            for layer in self.shared_layers:
                h = self.pool(F.relu(layer(h, sample_parameters=sample_parameters)))
            h = h.view(h.size(0), self.output_flat_length)
            h = F.relu(head[0](h, sample_parameters=sample_parameters))
            h = head[1](h, sample_parameters=sample_parameters)
            y_out = y_out + self.softmax(h)

        # Divide by the number of passes actually made; dividing by mc_sampling_n
        # unconditionally would shrink deterministic outputs by 1/mc_sampling_n.
        return y_out / n_passes

    def _nll(self, probs, y):
        """Mean negative log-likelihood of integer labels ``y`` under probabilities ``probs``."""
        # forward() returns probabilities, but nll_loss expects log-probabilities.
        # Passing raw probabilities would optimise -mean(p_y) instead of the cross-entropy.
        return F.nll_loss(torch.log(probs + self._LOG_EPS), y.long())

    def vcl_loss(self, x, y, head_idx, task_size):
        """Per-example negative ELBO: KL(posterior || prior) / task_size + mean NLL of ``y``."""
        return self._kl_divergence(head_idx) / task_size + self._nll(self(x, head_idx), y)

    def point_estimate_loss(self, x, y, head_idx):
        """Mean NLL of ``y`` using a single deterministic pass (no weight sampling)."""
        return self._nll(self(x, head_idx, sample_parameters=False), y)

    def prediction(self, x, task):
        """Predicted class index per example: argmax of the MC-averaged probabilities of head ``task``."""
        return torch.argmax(self(x, task), dim=1)

    def reset_for_new_task(self, head_idx):
        """Call ``reset_for_next_task`` on every shared layer and on both layers of head ``head_idx``."""
        for layer in self.shared_layers:
            layer.reset_for_next_task()
        for layer in self.heads[head_idx]:
            layer.reset_for_next_task()

    def get_statistics(self):
        """Collect per-layer posterior statistics and parameter-weighted model averages.

        Returns:
            ``(layer_statistics, model_statistics)``. ``layer_statistics`` lists each
            layer's ``get_statistics()`` dict (shared layers first, then all heads).
            In ``model_statistics``, every average is weighted by that layer's
            ``total_params``.
        """
        averaged_keys = ('average_w_mean', 'average_b_mean', 'average_w_var', 'average_b_var')
        layer_statistics = []
        model_statistics = {'total_params': 0}
        model_statistics.update({key: 0 for key in averaged_keys})

        for layer in self._all_layers():
            stats = layer.get_statistics()
            layer_statistics.append(stats)
            layer_n_params = stats['total_params']
            model_statistics['total_params'] += layer_n_params
            for key in averaged_keys:
                model_statistics[key] += stats[key] * layer_n_params

        total = model_statistics['total_params']
        if total:
            for key in averaged_keys:
                model_statistics[key] /= total

        return layer_statistics, model_statistics

    def _kl_divergence(self, head_idx) -> torch.Tensor:
        """Sum of the KL terms of all shared layers and of head ``head_idx`` (shape ``(1,)``)."""
        total = torch.zeros(1, device=self._param_device())
        for layer in list(self.shared_layers) + list(self.heads[head_idx]):
            total = total + layer.kl_divergence()
        return total

    def _mean_posterior_variance(self):
        """Mean posterior variance over every weight and bias of every layer (a float).

        NOTE(review): assumes every layer (including the MeanFieldGaussianLinear
        heads) exposes ``posterior_W_log_vars`` / ``posterior_b_log_vars`` like
        the conv layer does — confirm.
        """
        with torch.no_grad():
            log_vars = []
            for layer in self._all_layers():
                log_vars.append(layer.posterior_W_log_vars.reshape(-1))
                log_vars.append(layer.posterior_b_log_vars.reshape(-1))
            return torch.exp(torch.cat(log_vars)).mean().item()

# Convolutional VCL (CVCL) experiments: Split MNIST and Split CIFAR-10

In [50]:
def split_mnist_CVCL(coresetAlgo, coreset_size=4, epochs=2):
    """Run the Split-MNIST continual-learning benchmark with a convolutional VCL model.

    Digits are grouped into ``N_TASKS`` binary tasks {0,1}, {2,3}, ..., {8,9}.
    Each task has its own head, and within a task the odd digit is class 1.

    Args:
        coresetAlgo: ``"Random"`` or ``"K-Center"``, the coreset selection algorithm.
        coreset_size: ``size`` passed to the coreset constructor.
        epochs: epochs for the point-estimate initialisation and for each task.

    Returns:
        ``(avgAcc, accuracies_list)``. ``accuracies_list`` holds what ``run_task``
        returned after each task. ``avgAcc`` is their NaN-ignoring mean over axis 0.

    Relies on notebook globals defined elsewhere: ``device``, ``LR``,
    ``INITIAL_POSTERIOR_VAR``, ``Scale``, ``RandomCoreset``, ``kCenterCoreset``,
    ``run_point_estimate_initialisation`` and ``run_task``.
    """
    N_CLASSES = 2
    N_TASKS = 5
    MULTIHEADED = True
    CORESET_SIZE = coreset_size
    BATCH_SIZE = 64
    TRAIN_FULL_CORESET = True
    coresetDict = {"Random": RandomCoreset, "K-Center": kCenterCoreset}

    transform = Compose([ToTensor(), Scale()])

    mnist_train = MNIST(root="data", train=True, download=True, transform=transform)
    mnist_test = MNIST(root="data", train=False, download=True, transform=transform)
    # Shuffle each split once; the Subsets still contain every example.
    mnist_train = Subset(mnist_train, torch.randperm(len(mnist_train)))
    mnist_test = Subset(mnist_test, torch.randperm(len(mnist_test)))

    model = ConvVCL(x_channels=1,
        x_dim=28, y_dim=N_CLASSES,
        n_heads=(N_TASKS if MULTIHEADED else 1),
        initial_posterior_variance=INITIAL_POSTERIOR_VAR
    ).to(device)

    coreset = coresetDict[coresetAlgo](size=CORESET_SIZE)

    # Labels {2t, 2t+1} belong to task t (0->0, 1->0, 2->1, ..., 9->4).
    label_to_task_mapping = {label: label // N_CLASSES for label in range(N_CLASSES * N_TASKS)}

    def task_ids_of(subset):
        """Float tensor with the task id of every example of ``subset``, in Subset order."""
        # Read labels from the underlying dataset's ``targets`` instead of iterating
        # the Subset, which would decode and transform every image just for its label.
        targets = subset.dataset.targets
        return torch.Tensor([label_to_task_mapping[int(targets[int(i)])] for i in subset.indices])

    train_task_ids = task_ids_of(mnist_train)
    test_task_ids = task_ids_of(mnist_test)

    def binarize_y(y, task):
        """Binary target for ``task``: 1 for the odd label (2*task + 1), else 0."""
        return (y == (2 * task + 1)).long()

    summary_logdir = os.path.join("logs", "disc_s_mnist", datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)
    try:
        run_point_estimate_initialisation(model=model, data=mnist_train,
                                          epochs=epochs, batch_size=BATCH_SIZE,
                                          device=device, multiheaded=MULTIHEADED,
                                          lr=LR, task_ids=train_task_ids,
                                          y_transform=binarize_y)

        accuracies_list = []
        for task_idx in range(N_TASKS):
            accuracies = run_task(
                model=model, train_data=mnist_train, train_task_ids=train_task_ids,
                test_data=mnist_test, test_task_ids=test_task_ids, coreset=coreset,
                task_idx=task_idx, epochs=epochs, batch_size=BATCH_SIZE, lr=LR, num_tasks=N_TASKS,
                save_as="disc_s_mnist", device=device, multiheaded=MULTIHEADED,
                y_transform=binarize_y, train_full_coreset=TRAIN_FULL_CORESET,
                summary_writer=writer
            )
            accuracies_list.append(accuracies)
    finally:
        # Close the writer even if training fails.
        writer.close()

    avgAcc = np.nanmean(np.stack(accuracies_list), axis=0)
    return avgAcc, accuracies_list

In [51]:
def split_cifar_CVCL(coresetAlgo, coreset_size=4, epochs=1):
    """Run the Split-CIFAR-10 continual-learning benchmark with a convolutional VCL model.

    Classes are grouped into ``N_TASKS`` binary tasks {0,1}, {2,3}, ..., {8,9}.
    Each task has its own head, and within a task the odd class id is class 1.

    Args:
        coresetAlgo: ``"Random"`` or ``"K-Center"``, the coreset selection algorithm.
        coreset_size: ``size`` passed to the coreset constructor.
        epochs: epochs for the point-estimate initialisation and for each task.

    Returns:
        ``(avgAcc, accuracies_list)``. ``accuracies_list`` holds what ``run_task``
        returned after each task. ``avgAcc`` is their NaN-ignoring mean over axis 0.

    Relies on notebook globals defined elsewhere: ``device``, ``LR``,
    ``INITIAL_POSTERIOR_VAR``, ``Scale``, ``RandomCoreset``, ``kCenterCoreset``,
    ``run_point_estimate_initialisation`` and ``run_task``.
    """
    N_CLASSES = 2
    N_TASKS = 5
    MULTIHEADED = True
    CORESET_SIZE = coreset_size
    BATCH_SIZE = 64
    TRAIN_FULL_CORESET = True
    # Experiment name for logs and saved results; previously this reused the
    # MNIST name "disc_s_mnist", so CIFAR runs were mixed with MNIST runs.
    EXPERIMENT_NAME = "disc_s_cifar"
    coresetDict = {"Random": RandomCoreset, "K-Center": kCenterCoreset}

    transform = Compose([ToTensor(), Scale()])

    cifar_train = CIFAR10(root="data", train=True, download=True, transform=transform)
    cifar_test = CIFAR10(root="data", train=False, download=True, transform=transform)
    # Shuffle each split once; the Subsets still contain every example.
    cifar_train = Subset(cifar_train, torch.randperm(len(cifar_train)))
    cifar_test = Subset(cifar_test, torch.randperm(len(cifar_test)))

    model = ConvVCL(x_channels=3,
        x_dim=32, y_dim=N_CLASSES,
        n_heads=(N_TASKS if MULTIHEADED else 1),
        initial_posterior_variance=INITIAL_POSTERIOR_VAR
    ).to(device)

    coreset = coresetDict[coresetAlgo](size=CORESET_SIZE)

    # Labels {2t, 2t+1} belong to task t (0->0, 1->0, 2->1, ..., 9->4).
    label_to_task_mapping = {label: label // N_CLASSES for label in range(N_CLASSES * N_TASKS)}

    def task_ids_of(subset):
        """Float tensor with the task id of every example of ``subset``, in Subset order."""
        # Read labels from the underlying dataset's ``targets`` instead of iterating
        # the Subset, which would decode and transform every image just for its label.
        targets = subset.dataset.targets
        return torch.Tensor([label_to_task_mapping[int(targets[int(i)])] for i in subset.indices])

    train_task_ids = task_ids_of(cifar_train)
    test_task_ids = task_ids_of(cifar_test)

    def binarize_y(y, task):
        """Binary target for ``task``: 1 for the odd label (2*task + 1), else 0."""
        return (y == (2 * task + 1)).long()

    summary_logdir = os.path.join("logs", EXPERIMENT_NAME, datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)
    try:
        run_point_estimate_initialisation(model=model, data=cifar_train,
                                          epochs=epochs, batch_size=BATCH_SIZE,
                                          device=device, multiheaded=MULTIHEADED,
                                          lr=LR, task_ids=train_task_ids,
                                          y_transform=binarize_y)

        accuracies_list = []
        for task_idx in range(N_TASKS):
            accuracies = run_task(
                model=model, train_data=cifar_train, train_task_ids=train_task_ids,
                test_data=cifar_test, test_task_ids=test_task_ids, coreset=coreset,
                task_idx=task_idx, epochs=epochs, batch_size=BATCH_SIZE, lr=LR, num_tasks=N_TASKS,
                save_as=EXPERIMENT_NAME, device=device, multiheaded=MULTIHEADED,
                y_transform=binarize_y, train_full_coreset=TRAIN_FULL_CORESET,
                summary_writer=writer
            )
            accuracies_list.append(accuracies)
    finally:
        # Close the writer even if training fails.
        writer.close()

    avgAcc = np.nanmean(np.stack(accuracies_list), axis=0)
    return avgAcc, accuracies_list