# imports

In [None]:
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, Subset

import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from copy import deepcopy

# utils

In [None]:
def class_accuracy(pred: torch.Tensor, true: torch.Tensor) -> float:
    """
    Percentage of predictions that match the ground-truth class labels.

    Both tensors are cast to int before comparison, so float-encoded labels
    (e.g. 3.0) compare equal to integer predictions (3).

    Args:
        pred: class predictions made by a model
        true: ground-truth classes of the same samples
    Returns:
        Accuracy in percent (0-100). Raises ZeroDivisionError when `true` is empty.
    """
    n_correct = torch.eq(pred.int(), true.int()).sum().item()
    return n_correct * 100 / len(true)


def kl_divergence(z_posterior_means, z_posterior_log_std, z_prior_mean=0.0, z_prior_log_std=0.0):
    """
    Per-row KL(q || p) between diagonal Gaussians q = N(mu_q, sigma_q^2) and a
    prior p = N(mu_p, sigma_p^2) that has the same scalar mean / log-std in every
    dimension.

    Per dimension this is
        0.5 * (mu_q - mu_p)^2 / sigma_p^2 - 0.5
        + (log sigma_p - log sigma_q)
        + 0.5 * sigma_q^2 / sigma_p^2,
    and the result is summed over dim 1 (one value per row).

    Args:
        z_posterior_means: posterior means mu_q
        z_posterior_log_std: posterior log standard deviations log sigma_q
        z_prior_mean: scalar prior mean mu_p
        z_prior_log_std: scalar prior log standard deviation log sigma_p
    """
    prior_mu = torch.full_like(z_posterior_means, z_prior_mean)
    prior_log_sigma = torch.full_like(z_posterior_log_std, z_prior_log_std)

    # 1 / sigma_p^2 expressed through the log-std
    inv_prior_var = torch.exp(-2 * prior_log_sigma)

    per_dim = (
        0.5 * (z_posterior_means - prior_mu).pow(2) * inv_prior_var - 0.5
        + (prior_log_sigma - z_posterior_log_std)
        + 0.5 * torch.exp(2 * z_posterior_log_std - 2 * prior_log_sigma)
    )
    return per_dim.sum(dim=1)


def bernoulli_log_likelihood(x_observed, x_reconstructed, epsilon=1e-8) -> torch.Tensor:
    """
    Bernoulli log likelihood of observed data under reconstructed pixel
    probabilities, summed over dim 1 (one value per batch row).

    `x_reconstructed` is read as the probability that a pixel is on; `epsilon`
    keeps both logarithms away from log(0). Any NaN in the "pixel off" term
    (e.g. from a probability above 1 + epsilon) is replaced by `epsilon`.
    """
    on_term = x_observed * (x_reconstructed + epsilon).log()
    off_term = (1 - x_observed) * (1 - x_reconstructed + epsilon).log()
    off_term = off_term.masked_fill(torch.isnan(off_term), epsilon)

    return (on_term + off_term).sum(dim=1)


def normal_with_reparameterization(means: torch.Tensor, log_stds: torch.Tensor, device='cpu') -> torch.Tensor:
    """
    Draw one sample from N(means, exp(log_stds)^2) via the reparameterization
    trick: means + exp(log_stds) * eps with eps ~ N(0, 1) shaped like `means`.

    `device` is accepted for interface compatibility but not used; the noise
    is created like `means` (same device and dtype).
    """
    stds = log_stds.exp()
    noise = torch.randn_like(means)
    return means + stds * noise


def concatenate_flattened(tensor_list) -> torch.Tensor:
    """
    Flatten every tensor in `tensor_list` to 1-D and join them end to end, in
    list order. Zero-dimensional tensors contribute a single element.
    """
    flat_pieces = [piece.flatten() for piece in tensor_list]
    return torch.cat(flat_pieces)


def task_subset(data: Dataset, task_ids: torch.Tensor, task: int) -> Subset:
    """
    Return the samples of `data` that belong to `task`.

    Args:
        data: the full dataset
        task_ids: 1-D tensor holding, for each index of `data`, the task that
            sample belongs to (assumed to have len(data) entries)
        task: the task whose samples should be kept
    Returns:
        A `torch.utils.data.Subset` view of `data` restricted to the indices
        where `task_ids == task`, in their original order.
    """
    idx_list = torch.arange(0, len(task_ids))[task_ids == task]
    return Subset(data, idx_list)

In [None]:
import json
import os
import torch
import numpy as np
from PIL import Image

# Output locations used by the I/O helpers in this cell. Each is joined to a
# filename by plain string concatenation, hence the trailing '/'. Being
# relative paths, they resolve against the current working directory.
OUT_DIR = 'experiments/'  # JSON results written by write_as_json
MODEL_DIR = 'models/'  # models written by save_model / read by load_model
IMAGE_DIR = 'images/'  # images (plus a '.npy' copy) from save_generated_image


def write_as_json(filename, data):
    """
    Dumps the given data into OUT_DIR + filename using JSON formatting. The
    file and any missing parent directories are created if they do not exist;
    an existing file is overwritten.

    Args:
        filename: path of the file to write, relative to OUT_DIR
        data: JSON-serialisable data to dump
    """
    path = OUT_DIR + filename
    directory = os.path.dirname(path)
    # An empty `directory` means "current directory", which needs no creation
    # (os.makedirs('') would raise). exist_ok tolerates another process creating
    # the directory between the existence check and the makedirs call.
    if directory and not os.path.exists(directory):
        print('creating ...')
        os.makedirs(directory, exist_ok=True)

    with open(path, "w") as f:
        json.dump(data, f)


def save_model(model, filename):
    """
    Serialise `model` with torch.save to MODEL_DIR + filename.

    The parent directory of the target path is created if missing. This
    includes MODEL_DIR itself and any sub-directories named in `filename`,
    matching how write_as_json handles OUT_DIR.

    Args:
        model: the object to save (passed to torch.save as-is)
        filename: path of the file to write, relative to MODEL_DIR
    """
    path = MODEL_DIR + filename
    directory = os.path.dirname(path)
    if directory and not os.path.exists(directory):
        print('creating ...')
        os.makedirs(directory, exist_ok=True)

    torch.save(model, path)


def load_model(filename):
    """
    Load an object previously written by save_model from MODEL_DIR + filename.

    Raises:
        FileNotFoundError: if the file does not exist (the message names the path).
    """
    path = MODEL_DIR + filename
    if not os.path.isfile(path):
        raise FileNotFoundError(f"model file not found: {path}")
    # NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
    # files. Recent PyTorch releases default to weights_only=True, which refuses
    # whole pickled nn.Module objects such as those written by save_model;
    # confirm the behaviour against the installed torch version.
    return torch.load(path)


def save_generated_image(data: np.ndarray, filename: str):
    """
    Save `data * 255` as an RGB image at IMAGE_DIR + filename, and the same
    scaled array as a numpy file at IMAGE_DIR + filename + '.npy'.

    IMAGE_DIR is created if it does not yet exist. Non-RGB images built from
    the array are converted to RGB before saving.
    """
    image_dir = os.path.dirname(IMAGE_DIR)
    if not os.path.exists(image_dir):
        print('creating ...')
        os.makedirs(image_dir)

    scaled = data * 255
    img = Image.fromarray(scaled)
    img = img if img.mode == 'RGB' else img.convert('RGB')

    target = IMAGE_DIR + filename
    img.save(target)
    np.save(target + '.npy', scaled)

In [None]:
import numpy as np


class Flatten:
    """
    Callable transform that turns an image-like input (e.g. a PIL image or a
    nested sequence) into a new 1-D float32 numpy array in row-major order.
    """

    def __call__(self, sample):
        # np.array copies, so the result never aliases `sample`.
        pixels = np.array(sample, dtype=np.float32)
        return pixels.ravel()


class Scale:
    """
    Callable transform that divides its input by `max_value` (255 by default),
    mapping 8-bit pixel values into [0, 1].
    """

    def __init__(self, max_value=255):
        self.max_value = max_value

    def __call__(self, sample):
        divisor = self.max_value
        return sample / divisor


class Permute:
    """
    Callable transform that reorders the elements of a (flattened) sample by
    indexing it with a fixed permutation chosen at construction time.
    """

    def __init__(self, permutation):
        self.permutation = permutation

    def __call__(self, sample):
        order = self.permutation
        return sample[order]

# run_task

In [None]:
"""
Utilities that abstract the low-level details of experiments, such as standard train-and-eval loops.
"""

import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm


def run_point_estimate_initialisation(model, data, epochs, task_ids, batch_size,
                                      device, lr, task_idx=0, y_transform=None,
                                      multiheaded=True):
    """
    Optimise the model's posterior means with `model.point_estimate_loss` on the
    samples of task `task_idx`, as a starting point for variational training.

    A fresh Adam optimizer (learning rate `lr`) is used. Batches of size
    `batch_size` are drawn in dataset order and moved to `device`; labels are
    passed through `y_transform(y, task_idx)` when it is given. The head used
    is `task_idx` for multi-headed models and 0 otherwise. Mutates `model` in
    place and returns None.
    """
    print("Obtaining point estimate for posterior initialisation")

    head = 0
    if multiheaded:
        head = task_idx

    # a dedicated optimizer for this initialisation phase
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # restrict training to the current task's samples
    loader = DataLoader(task_subset(data, task_ids, task_idx), batch_size)

    for _ in tqdm(range(epochs), 'Epochs: '):
        for x, y_true in loader:
            optimizer.zero_grad()
            x, y_true = x.to(device), y_true.to(device)

            if y_transform is not None:
                y_true = y_transform(y_true, task_idx).to(device)

            model.point_estimate_loss(x, y_true, head=head).backward()
            optimizer.step()


def _stack_as_float_tensor(items):
    """
    Stack a list of samples (arrays or scalars) into one float32 tensor, the
    same dtype the legacy `torch.Tensor(list)` constructor produced.

    Going through a single numpy array avoids torch's very slow
    element-by-element conversion of a list of numpy arrays.
    """
    import numpy as np  # local import keeps this cell self-contained

    return torch.tensor(np.asarray(items), dtype=torch.float32)


def _train_on_task(model, loss_fn, optimizer, train_loader, task_size, task_idx,
                   head, epochs, device, y_transform, summary_writer):
    """Run `epochs` passes of `loss_fn` over `train_loader`, optionally logging the per-sample loss."""
    for epoch in tqdm(range(epochs), 'Epochs: '):
        epoch_loss = 0.0
        n_seen = 0
        for x, y_true in train_loader:
            optimizer.zero_grad()
            x = x.to(device)
            y_true = y_true.to(device)

            if y_transform is not None:
                y_true = y_transform(y_true, task_idx)

            loss = loss_fn(x, y_true, head, task_size)
            epoch_loss += len(x) * loss.item()
            n_seen += len(x)

            loss.backward()
            optimizer.step()

        if summary_writer is not None:
            # Average over the samples actually trained on: the coreset points
            # were removed from this loader. max() guards an empty loader.
            summary_writer.add_scalars(
                "loss", {"TASK_" + str(task_idx): epoch_loss / max(n_seen, 1)}, epoch)


def _evaluate_seen_tasks(model, optimizer, coreset, test_data, test_task_ids,
                         task_idx, epochs, device, y_transform, multiheaded,
                         train_full_coreset):
    """Coreset-train, then test on tasks 0..task_idx; return (per-task accuracies, weighted mean)."""
    # train using the full coreset of all tasks seen so far
    if train_full_coreset:
        model_cs_trained = coreset.coreset_train(
            model, optimizer, list(range(task_idx + 1)), epochs,
            device, y_transform=y_transform, multiheaded=multiheaded)

    task_accuracies = []
    tot_right = 0
    tot_tested = 0

    for test_task_idx in range(task_idx + 1):
        if not train_full_coreset:
            # otherwise coreset-train separately on each task before testing it
            model_cs_trained = coreset.coreset_train(
                model, optimizer, test_task_idx, epochs,
                device, y_transform=y_transform, multiheaded=multiheaded)

        head = test_task_idx if multiheaded else 0
        task_data = task_subset(test_data, test_task_ids, test_task_idx)

        # A single pass over the subset, so dataset transforms run once per sample.
        xs, ys = [], []
        for x_i, y_i in task_data:
            xs.append(x_i)
            ys.append(y_i)
        x = _stack_as_float_tensor(xs).to(device)
        y_true = _stack_as_float_tensor(ys).to(device)

        if y_transform is not None:
            y_true = y_transform(y_true, test_task_idx)

        # Evaluation only: don't build an autograd graph over the sampled forward passes.
        with torch.no_grad():
            y_pred = model_cs_trained.prediction(x, head).to(device)

        acc = class_accuracy(y_pred, y_true)
        print("After task {} perfomance on task {} is {}"
              .format(task_idx, test_task_idx, acc))

        tot_right += acc * len(task_data)
        tot_tested += len(task_data)
        task_accuracies.append(acc)

    mean_accuracy = tot_right / tot_tested
    print("Mean accuracy:", mean_accuracy)

    return task_accuracies, mean_accuracy


def _run_task_impl(loss_name, model, train_data, train_task_ids, test_data,
                   test_task_ids, task_idx, coreset, epochs, batch_size, device,
                   lr, y_transform, multiheaded, train_full_coreset,
                   summary_writer):
    """Shared body of run_task / run_task_scale; `loss_name` selects the model's loss method."""
    print('TASK ', task_idx)

    # separate optimizer for each task
    optimizer = optim.Adam(model.parameters(), lr=lr)

    head = task_idx if multiheaded else 0

    # obtain correct subset of data for training, and set some aside for the coreset
    task_data = task_subset(train_data, train_task_ids, task_idx)
    non_coreset_data = coreset.select(task_data, task_id=task_idx)
    train_loader = DataLoader(non_coreset_data, batch_size)

    _train_on_task(model, getattr(model, loss_name), optimizer, train_loader,
                   len(task_data), task_idx, head, epochs, device, y_transform,
                   summary_writer)

    # after training, prepare for new task by copying posteriors into priors
    model.reset_for_new_task(head)

    return _evaluate_seen_tasks(model, optimizer, coreset, test_data,
                                test_task_ids, task_idx, epochs, device,
                                y_transform, multiheaded, train_full_coreset)


def run_task(model, train_data, train_task_ids, test_data, test_task_ids,
             task_idx, coreset, epochs, batch_size, save_as, device, lr,
             y_transform=None, multiheaded=True, train_full_coreset=True,
             summary_writer=None):
    """
    Train `model` on task `task_idx` with `model.vcl_loss`, then evaluate it on
    every task seen so far.

    Steps:
      1. Remove a coreset from the task's training data via `coreset.select`.
      2. Train on the remaining points.
      3. Copy posteriors into priors.
      4. Train a copy of the model on the coreset: once on all seen tasks, or
         per tested task when `train_full_coreset` is False.
      5. Measure the accuracy on each task's test split.

    `save_as` is accepted for interface compatibility but currently unused.

    Returns:
        (list of per-task accuracies in percent, test-size-weighted mean accuracy)
    """
    return _run_task_impl(
        "vcl_loss", model, train_data, train_task_ids, test_data, test_task_ids,
        task_idx, coreset, epochs, batch_size, device, lr, y_transform,
        multiheaded, train_full_coreset, summary_writer)


def run_task_scale(model, train_data, train_task_ids, test_data, test_task_ids,
             task_idx, coreset, epochs, batch_size, save_as, device, lr,
             y_transform=None, multiheaded=True, train_full_coreset=True,
             summary_writer=None):
    """
    Same as `run_task`, but trains with `model.vcl_loss_factor` (the loss whose
    KL and log-likelihood terms are weighted by the model's beta and alpha).
    Coreset training still uses whatever loss the coreset implementation calls.

    `save_as` is accepted for interface compatibility but currently unused.

    Returns:
        (list of per-task accuracies in percent, test-size-weighted mean accuracy)
    """
    return _run_task_impl(
        "vcl_loss_factor", model, train_data, train_task_ids, test_data,
        test_task_ids, task_idx, coreset, epochs, batch_size, device, lr,
        y_transform, multiheaded, train_full_coreset, summary_writer)






In [None]:
import torch
import torch.utils.data as data
from copy import deepcopy
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from random import shuffle


# DiscriminativeVCL

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


TRAIN_NUM_SAMPLES = 50  # default number of sampled forward passes in the training losses
TEST_NUM_SAMPLES = 50  # default number of sampled forward passes averaged by `prediction`
EPSILON = 1e-8  # added to prior variances in the KL term to avoid division by zero


class DiscriminativeVCL(nn.Module):
    """
    A Bayesian multi-head neural network which updates its parameters using
    variational inference.

    Every weight and bias has an independent Gaussian posterior, stored as a
    mean and a log-variance. Hidden layers are shared by all tasks; each task
    uses one of `n_heads` output layers. Priors are registered buffers and
    posteriors trainable parameters. Both are also reachable through the tuple
    attributes `prior` / `head_prior` and `posterior` / `head_posterior`, laid
    out as ((w_means, w_log_vars), (b_means, b_log_vars)), one list entry per
    layer (or per head).
    """

    def __init__(self, in_size: int, out_size: int, layer_width: int,
                 n_hidden_layers: int, n_heads: int, initial_posterior_var: float):
        super().__init__()
        self.input_size = in_size
        self.out_size = out_size
        self.n_hidden_layers = n_hidden_layers
        self.layer_width = layer_width
        self.n_heads = n_heads

        print("Number of heads:", n_heads)

        self.prior, self.posterior = None, None
        self.head_prior, self.head_posterior = None, None

        self._init_variables(initial_posterior_var)
        # alpha weights the log-likelihood and beta the KL term in vcl_loss_factor
        self.initial_alpha = 1.0
        self.initial_beta = 1.0
        # Initialize current alpha and beta to the initial values
        self.alpha = self.initial_alpha
        self.beta = self.initial_beta

    def adjust_scaling_factors(self, mean_accuracy, new_task_accuracy):
        """
        Set alpha/beta by comparing the new task's accuracy to the mean accuracy.

        With d = new_task_accuracy - mean_accuracy:
            alpha = initial_alpha * clamp(1 - d / 10, 0.5, 1.5)
            beta  = initial_beta  * clamp(1 + d / 10, 0.5, 1.5)
        The accuracies are presumably percentages as returned by class_accuracy;
        confirm at the call sites.
        """
        perform_diff = new_task_accuracy - mean_accuracy
        alpha_adjustment_rate = 1.0 - perform_diff / 10
        beta_adjustment_rate = 1.0 + perform_diff / 10

        # Ensure the rates are within sensible bounds
        alpha_adjustment_rate = max(0.5, min(alpha_adjustment_rate, 1.5))
        beta_adjustment_rate = max(0.5, min(beta_adjustment_rate, 1.5))

        self.alpha = self.initial_alpha * alpha_adjustment_rate
        self.beta = self.initial_beta * beta_adjustment_rate
        print('alpha', self.alpha)
        print('beta', self.beta)

    def adjust_scaling_factors_2(self, new_task_idx, mean_accuracy, new_task_accuracy, task_similarity_matrix):
        """
        Same update as `adjust_scaling_factors`.

        `new_task_idx` and `task_similarity_matrix` are accepted for interface
        compatibility but are currently ignored.
        """
        self.adjust_scaling_factors(mean_accuracy, new_task_accuracy)

    def to(self, *args, **kwargs):
        """
        Move/cast the module, then re-point the prior tuples at the registered buffers.

        The prior tensors are registered as buffers but are read through the
        tuple attributes `prior` / `head_prior`. `nn.Module.to` replaces the
        buffer objects, which would leave the tuples holding stale tensors. If
        the tuples were merely moved separately (as before), they would become
        copies detached from the buffers. `reset_for_new_task` would then
        update tensors that `state_dict()` never sees.
        """
        module = super().to(*args, **kwargs)
        module._relink_prior_references(lambda t: t.to(*args, **kwargs))
        return module

    def _relink_prior_references(self, move):
        """
        Make each prior-tuple entry the registered buffer of the same name.

        Entries without a registered buffer are converted with `move` instead.
        """
        (prior_w_means, prior_w_log_vars), (prior_b_means, prior_b_log_vars) = self.prior
        (head_w_means, head_w_log_vars), (head_b_means, head_b_log_vars) = self.head_prior
        groups = (
            ("prior_w_means_", prior_w_means),
            ("prior_w_log_vars_", prior_w_log_vars),
            ("prior_b_means_", prior_b_means),
            ("prior_b_log_vars_", prior_b_log_vars),
            ("head_prior_w_means_", head_w_means),
            ("head_prior_w_log_vars_", head_w_log_vars),
            ("head_prior_b_means_", head_b_means),
            ("head_prior_b_log_vars_", head_b_log_vars),
        )
        for prefix, tensors in groups:
            for i, tensor in enumerate(tensors):
                name = prefix + str(i)
                # index assignment keeps list length, so enumerating is safe
                tensors[i] = self._buffers[name] if name in self._buffers else move(tensor)

    def forward(self, x, head):
        """
        One stochastic forward pass: sample all weights/biases from the
        posterior, apply the shared ReLU layers, then the output layer `head`.
        """
        # sample layer parameters from posterior distribution
        (w_means, w_log_vars), (b_means, b_log_vars) = self.posterior
        (head_w_means, head_w_log_vars), (head_b_means, head_b_log_vars) = self.head_posterior
        sampled_layers = self._sample_parameters(w_means, b_means, w_log_vars, b_log_vars)
        sampled_head_layers = self._sample_parameters(head_w_means, head_b_means, head_w_log_vars, head_b_log_vars)

        # Apply each layer with its sampled weights and biases
        for weight, bias in sampled_layers:
            x = F.relu(x @ weight + bias)

        head_weight, head_bias = list(sampled_head_layers)[head]
        x = x @ head_weight + head_bias

        return x

    def vcl_loss(self, x, y, head, task_size, num_samples=TRAIN_NUM_SAMPLES) -> torch.Tensor:
        """
        Negative per-datum ELBO: KL(posterior || prior) / task_size minus the
        Monte-Carlo mean log-likelihood of (x, y) under head `head`.
        """
        # Both terms live on the model's device; no host round-trip is needed.
        return self._calculate_kl_term(head) / task_size - self._log_prob(x, y, head, num_samples)

    def vcl_loss2(self, x, y, head, task_size, num_samples=TRAIN_NUM_SAMPLES,
                  similarity_score=1.0, adaptability_factor=None) -> torch.Tensor:
        """
        Like `vcl_loss`, but the KL term is weighted by
        1 + (1 - similarity_score) * adaptability_factor.

        If `adaptability_factor` is not given, a module-level
        `adaptability_factor` is used, which is how the original version read
        it. NameError is raised when neither exists.
        """
        if adaptability_factor is None:
            adaptability_factor = globals().get("adaptability_factor")
            if adaptability_factor is None:
                raise NameError(
                    "vcl_loss2 needs an adaptability_factor: pass it explicitly "
                    "or define a module-level `adaptability_factor`")
        kl_term = self._calculate_kl_term(head) / task_size
        log_prob = self._log_prob(x, y, head, num_samples)
        dynamic_kl_weight = 1.0 + (1.0 - similarity_score) * adaptability_factor
        return dynamic_kl_weight * kl_term - log_prob

    def vcl_loss_factor(self, x, y, head, task_size, num_samples=TRAIN_NUM_SAMPLES):
        """`vcl_loss` with the KL term scaled by `self.beta` and the log-likelihood by `self.alpha`."""
        kl_term = self._calculate_kl_term(head) / task_size
        log_prob = self._log_prob(x, y, head, num_samples)
        return self.beta * kl_term - self.alpha * log_prob

    def point_estimate_loss(self, x, y, head=0):
        """
        Returns a loss defined in terms of a simplified forward pass that
        doesn't use sampling, and so uses the posterior means but not the
        variances. Used as part of model initialisation to optimise the
        posterior means to point-estimates for the first head.
        """
        (w_means, _), (b_means, _) = self.posterior
        (head_w_means, _), (head_b_means, _) = self.head_posterior

        for weight, bias in zip(w_means, b_means):
            x = F.relu(x @ weight + bias)

        x = x @ head_w_means[head] + head_b_means[head]

        return nn.CrossEntropyLoss()(x, y)

    def prediction(self, x, head, num_samples=TEST_NUM_SAMPLES):
        """
        Predict class indices for `x` with output head `head`.

        Averages the softmax outputs of `num_samples` stochastic forward passes
        and returns a 1-D tensor of length len(x) holding, per row, the argmax
        class index in [0, self.out_size). The result is on the CPU because the
        buffer of outputs is allocated there.
        """
        outputs = torch.empty(num_samples, len(x), self.out_size)
        # argmax is not differentiable, so there is no reason to record a graph
        with torch.no_grad():
            for i in range(num_samples):
                outputs[i] = nn.Softmax(dim=1)(self.forward(x, head))

        predictions = outputs.mean(dim=0)

        return torch.argmax(predictions, dim=1)

    def reset_for_new_task(self, head):
        """
        Called after completion of a task, to reset state for the next task:
        copies the shared-layer posteriors and head `head`'s posterior into the
        corresponding priors (in place).
        """
        # Set the value of the prior to be the current value of the posterior
        (prior_w_means, prior_w_log_vars), (prior_b_means, prior_b_log_vars) = self.prior
        (post_w_means, post_w_log_vars), (post_b_means, post_b_log_vars) = self.posterior
        for i in range(self.n_hidden_layers):
            prior_w_means[i].data.copy_(post_w_means[i].data)
            prior_w_log_vars[i].data.copy_(post_w_log_vars[i].data)
            prior_b_means[i].data.copy_(post_b_means[i].data)
            prior_b_log_vars[i].data.copy_(post_b_log_vars[i].data)

        # set the value of the head prior to be the current value of the posterior
        (head_prior_w_means, head_prior_w_log_vars), (head_prior_b_means, head_prior_b_log_vars) = self.head_prior
        (head_posterior_w_means, head_posterior_w_log_vars), (head_posterior_b_means, head_posterior_b_log_vars) = self.head_posterior
        head_prior_w_means[head].data.copy_(head_posterior_w_means[head].data)
        head_prior_w_log_vars[head].data.copy_(head_posterior_w_log_vars[head].data)
        head_prior_b_means[head].data.copy_(head_posterior_b_means[head].data)
        head_prior_b_log_vars[head].data.copy_(head_posterior_b_log_vars[head].data)

    def _calculate_kl_term(self, head):
        """
        Calculates and returns the KL divergence of the new posterior and the previous
        iteration's posterior (the current prior), over the shared layers plus
        head `head`. See equation L3, slide 14.
        """
        # Prior
        ((prior_w_means, prior_w_log_vars), (prior_b_means, prior_b_log_vars)) = self.prior
        ((head_prior_w_means, head_prior_w_log_vars),
         (head_prior_b_means, head_prior_b_log_vars)) = self.head_prior

        prior_means = concatenate_flattened(
            prior_w_means + head_prior_w_means[head:head+1] +
            prior_b_means + head_prior_b_means[head:head+1])
        prior_log_vars = concatenate_flattened(
            prior_w_log_vars + head_prior_w_log_vars[head:head+1] +
            prior_b_log_vars + head_prior_b_log_vars[head:head+1])
        prior_vars = torch.exp(prior_log_vars)

        # Posterior
        ((post_w_means, post_w_log_vars), (post_b_means, post_b_log_vars)) = self.posterior
        ((head_post_w_means, head_post_w_log_vars),
         (head_post_b_means, head_post_b_log_vars)) = self.head_posterior

        post_means = concatenate_flattened(
            post_w_means + head_post_w_means[head:head+1] +
            post_b_means + head_post_b_means[head:head+1])
        post_log_vars = concatenate_flattened(
            post_w_log_vars + head_post_w_log_vars[head:head+1] +
            post_b_log_vars + head_post_b_log_vars[head:head+1])
        post_vars = torch.exp(post_log_vars)

        # Gaussian KL per parameter:
        # var_q/var_p + (mu_p - mu_q)^2/var_p - 1 + log var_p - log var_q (halved below)
        kl_elementwise = \
            post_vars / (prior_vars + EPSILON) + \
            torch.pow(prior_means - post_means, 2) / (prior_vars + EPSILON) \
            - 1 + prior_log_vars - post_log_vars

        # Sum KL over all parameters
        return 0.5 * kl_elementwise.sum()

    def _log_prob(self, x, y, head, num_samples):
        """Mean log-likelihood of y (class indices) over `num_samples` sampled forward passes."""
        outputs = []
        for i in range(num_samples):
            outputs.append(self.forward(x, head))

        # outputs are stacked sample-major, so the labels are tiled to match
        return - nn.CrossEntropyLoss()(torch.cat(outputs), y.repeat(num_samples).view(-1))

    def _sample_parameters(self, w_means, b_means, w_log_vars, b_log_vars):
        """Sample one weight/bias per layer: mean + N(0, 1) * exp(0.5 * log_var)."""
        # sample weights and biases from normal distributions
        sampled_weights, sampled_bias = [], []
        for layer_n in range(len(w_means)):
            w_epsilons = torch.randn_like(w_means[layer_n])
            b_epsilons = torch.randn_like(b_means[layer_n])
            sampled_weights.append(w_means[layer_n] + w_epsilons * torch.exp(0.5 * w_log_vars[layer_n]))
            sampled_bias.append(b_means[layer_n] + b_epsilons * torch.exp(0.5 * b_log_vars[layer_n]))
        return zip(sampled_weights, sampled_bias)

    def _init_variables(self, initial_posterior_var):
        """
        Initializes the model's prior and posterior weights / biases to their initial
        values. This method is called once on model creation. The model prior is registered
        as a persistent part of the model state which should not be modified, while the
        initial posterior is registered as a model parameter to be optimized.

        To avoid negative variances, we do not store parameter variances directly; instead
        we store the logarithm of each variance, and apply the exponential as needed in the
        forward pass.
        """
        # The initial prior over the parameters has zero mean, unit variance (i.e. log variance 0)
        prior_w_means = [torch.zeros(self.input_size, self.layer_width)] + \
                        [torch.zeros(self.layer_width, self.layer_width) for _ in range(self.n_hidden_layers - 1)]
        prior_w_log_vars = [torch.zeros_like(t) for t in prior_w_means]
        prior_b_means = [torch.zeros(self.layer_width) for _ in range(self.n_hidden_layers)]
        prior_b_log_vars = [torch.zeros_like(t) for t in prior_b_means]

        self.prior = ((prior_w_means, prior_w_log_vars), (prior_b_means, prior_b_log_vars))

        head_prior_w_means = [torch.zeros(self.layer_width, self.out_size) for _ in range(self.n_heads)]
        head_prior_w_log_vars = [torch.zeros_like(t) for t in head_prior_w_means]
        head_prior_b_means = [torch.zeros(self.out_size) for _ in range(self.n_heads)]
        head_prior_b_log_vars = [torch.zeros_like(t) for t in head_prior_b_means]

        self.head_prior = ((head_prior_w_means, head_prior_w_log_vars), (head_prior_b_means, head_prior_b_log_vars))

        empty_parameter_like = lambda t: nn.Parameter(torch.empty_like(t, requires_grad=True))

        posterior_w_means = [empty_parameter_like(t) for t in prior_w_means]
        posterior_w_log_vars = [empty_parameter_like(t) for t in prior_w_log_vars]
        posterior_b_means = [empty_parameter_like(t) for t in prior_b_means]
        posterior_b_log_vars = [empty_parameter_like(t) for t in prior_b_log_vars]

        self.posterior = ((posterior_w_means, posterior_w_log_vars), (posterior_b_means, posterior_b_log_vars))

        head_posterior_w_means = [empty_parameter_like(t) for t in head_prior_w_means]
        head_posterior_w_log_vars = [empty_parameter_like(t) for t in head_prior_w_log_vars]
        head_posterior_b_means = [empty_parameter_like(t) for t in head_prior_b_means]
        head_posterior_b_log_vars = [empty_parameter_like(t) for t in head_prior_b_log_vars]

        self.head_posterior = \
            ((head_posterior_w_means, head_posterior_w_log_vars),
             (head_posterior_b_means, head_posterior_b_log_vars))

        # posterior means ~ N(0, 0.1^2); posterior log-variances = log(initial_posterior_var)
        for t in posterior_w_means + posterior_b_means + head_posterior_w_means + head_posterior_b_means:
            torch.nn.init.normal_(t, mean=0, std=0.1)

        for t in posterior_w_log_vars + posterior_b_log_vars + head_posterior_w_log_vars + head_posterior_b_log_vars:
            torch.nn.init.constant_(t, math.log(initial_posterior_var))

        # The tuple entries above are the very tensor objects registered here.
        for i in range(self.n_hidden_layers):
            self.register_buffer("prior_w_means_" + str(i), prior_w_means[i])
            self.register_buffer("prior_w_log_vars_" + str(i), prior_w_log_vars[i])
            self.register_buffer("prior_b_means_" + str(i), prior_b_means[i])
            self.register_buffer("prior_b_log_vars_" + str(i), prior_b_log_vars[i])

        for i in range(self.n_heads):
            self.register_buffer("head_prior_w_means_" + str(i), head_prior_w_means[i])
            self.register_buffer("head_prior_w_log_vars_" + str(i), head_prior_w_log_vars[i])
            self.register_buffer("head_prior_b_means_" + str(i), head_prior_b_means[i])
            self.register_buffer("head_prior_b_log_vars_" + str(i), head_prior_b_log_vars[i])

        for i in range(self.n_hidden_layers):
            self.register_parameter("posterior_w_means_" + str(i), posterior_w_means[i])
            self.register_parameter("posterior_w_log_vars_" + str(i), posterior_w_log_vars[i])
            self.register_parameter("posterior_b_means_" + str(i), posterior_b_means[i])
            self.register_parameter("posterior_b_log_vars_" + str(i), posterior_b_log_vars[i])

        for i in range(self.n_heads):
            self.register_parameter("head_posterior_w_means_" + str(i), head_posterior_w_means[i])
            self.register_parameter("head_posterior_w_log_vars_" + str(i), head_posterior_w_log_vars[i])
            self.register_parameter("head_posterior_b_means_" + str(i), head_posterior_b_means[i])
            self.register_parameter("head_posterior_b_log_vars_" + str(i), head_posterior_b_log_vars[i])

    def _mean_posterior_variance(self):
        """
        Return the mean posterior variance for logging purposes.
        Excludes the head layer.
        """
        ((_, posterior_w_log_vars), (_, posterior_b_log_vars)) = self.posterior
        posterior_log_vars = torch.cat([torch.reshape(t, (-1,)) for t in posterior_w_log_vars] + posterior_b_log_vars)
        posterior_vars     = torch.exp(posterior_log_vars)
        return torch.mean(posterior_vars).item()




# coreset

In [None]:
import torch
import torch.utils.data as data
from copy import deepcopy
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from random import shuffle
import torch
from torch.utils.data import Dataset, Subset, ConcatDataset, TensorDataset
import random

class Coreset():
    """
    Base class for the coreset. This version of the class has no coreset
    (`select` keeps nothing), but subclasses will replace the select method.

    Attributes:
        size: number of points a subclass should keep per task
        coreset: dataset of stored points, or None while empty
        coreset_task_ids: task id of each stored point, or None while empty
        lr: learning rate given to the optimizer built in `coreset_train`
    """

    def __init__(self, size=0, lr=0.001):
        self.size = size
        self.coreset = None
        self.coreset_task_ids = None
        self.lr = lr

    def select(self, d: data.Dataset, task_id: int):
        """
        Given a torch dataset, will choose k datapoints.  Will then update
        the coreset with these datapoints.
        Returns: the subset that was not selected as a torch dataset (here, all of `d`).
        """

        return d

    def coreset_train(self, m, old_optimizer, tasks, epochs, device,
                      y_transform=None, multiheaded=True, batch_size=256):
        """
        Returns a new model, trained on the coreset.  The returned model will
        be a deep copy, except when coreset is empty (when it will be identical)

        tasks can be either a list, in which case the coreset will be trained
        on all tasks in the list, or an integer, in which case it will be
        trained on only that task. The caller's list is never modified.
        """

        if self.coreset is None:
            return m

        model = deepcopy(m)

        optimizer = optim.Adam(model.parameters(), lr=self.lr)
        # NOTE(review): load_state_dict also restores the param-group settings of
        # `old_optimizer` (including its learning rate), so `self.lr` is
        # effectively overridden here — confirm this is intended.
        optimizer.load_state_dict(old_optimizer.state_dict())

        # Copy into a private list: it is shuffled in place every epoch below.
        tasks = [tasks] if isinstance(tasks, int) else list(tasks)

        # create dict of train_loaders
        train_loaders = {
            task_idx: data.DataLoader(
                task_subset(self.coreset, self.coreset_task_ids, task_idx),
                batch_size
            )
            for task_idx in tasks
        }

        print('CORESET TRAIN')
        for _ in tqdm(range(epochs), 'Epochs: '):
            # Randomize order of training tasks
            shuffle(tasks)
            for task_idx in tasks:
                head = task_idx if multiheaded else 0

                for batch in train_loaders[task_idx]:
                    optimizer.zero_grad()
                    x, y_true = batch
                    x = x.to(device)
                    y_true = y_true.to(device)

                    if y_transform is not None:
                        y_true = y_transform(y_true, task_idx)

                    # the whole coreset size is used as the dataset size for the KL scaling
                    loss = model.vcl_loss(x, y_true, head, len(self.coreset))
                    loss.backward()
                    optimizer.step()

        return model



class RandomCoreset(Coreset):
    """Coreset that stores up to `size` uniformly random points from each task."""

    def __init__(self, size):
        super().__init__(size)

    def select(self, d : data.Dataset, task_id : int):
        """
        Randomly move up to `self.size` points of `d` into the coreset, tagged
        with `task_id`, and return the remaining points as a dataset.

        When `d` has fewer than `self.size` points, all of them go into the
        coreset. Previously, random_split raised because the requested lengths
        did not sum to len(d).
        """
        n_coreset = min(self.size, len(d))
        new_cs_data, non_cs = data.random_split(d, [n_coreset, len(d) - n_coreset])

        # Split x from y so the points can be stored in a TensorDataset next to
        # their task ids. One pass means dataset transforms run once per sample.
        xs, ys = [], []
        for x, y in new_cs_data:
            xs.append(x)
            ys.append(y)
        new_cs_x = torch.tensor(xs)
        new_cs_y = torch.tensor(ys)

        new_cs = data.TensorDataset(new_cs_x, new_cs_y)
        new_task_ids = torch.full((len(new_cs_data),), task_id)

        if self.coreset is None:
            self.coreset = new_cs
            self.coreset_task_ids = new_task_ids
        else:
            self.coreset = data.ConcatDataset((self.coreset, new_cs))
            self.coreset_task_ids = torch.cat((self.coreset_task_ids, new_task_ids))

        return non_cs

# ---------------------------------------------------------------------------------------------------------

import torch
from torch.utils.data import Dataset, ConcatDataset, DataLoader, random_split


class RandomCoreset_resample(Coreset):
    """
    Class-balanced coreset that is redrawn from scratch for every task.

    Each call to ``select`` discards the previous coreset and draws
    ``size // n_classes`` random samples of every class seen so far (the
    classes of tasks ``0..task_id``) from the full ``dataset`` given at
    construction time.

    Args:
        size: total coreset budget once every class has been seen.
        dataset: the full training set; items are (x, label) pairs.
        lr: stored as ``self.lr``; not used inside this class.
        label_to_task_mapping: maps each class label to the task it belongs
            to. Defaults to the Split-MNIST pairing (0,1)->0, ..., (8,9)->4.
    """

    _DEFAULT_LABEL_TO_TASK = {
        0: 0, 1: 0,
        2: 1, 3: 1,
        4: 2, 5: 2,
        6: 3, 7: 3,
        8: 4, 9: 4,
    }

    def __init__(self, size, dataset, lr=0.001, label_to_task_mapping=None):
        super().__init__(size)
        self.size = size
        self.coreset_task_ids = torch.tensor([], dtype=torch.long)
        self.dataset = dataset
        self.lr = lr
        if label_to_task_mapping is None:
            label_to_task_mapping = self._DEFAULT_LABEL_TO_TASK
        self.label_to_task_mapping = dict(label_to_task_mapping)

        # Materialise the dataset once, in a single pass (so its transform runs
        # once per sample), keeping every sample and label as a tensor.
        self.all_data_x = []
        self.all_data_y = []
        for x, y in dataset:
            self.all_data_x.append(x if isinstance(x, torch.Tensor) else torch.tensor(x))
            self.all_data_y.append(
                y if isinstance(y, torch.Tensor) else torch.tensor(y, dtype=torch.long)
            )
        # Aliases for code that reads these attributes.
        self.new_x = self.all_data_x
        self.new_y = self.all_data_y

        # Label -> dataset indices (in dataset order), so select() does not
        # rescan the whole dataset for every class.
        self._indices_by_label = {}
        for idx, y in enumerate(self.all_data_y):
            self._indices_by_label.setdefault(int(y), []).append(idx)

    def _gather(self, indices):
        """Stack the samples at ``indices`` into an (x, y) tensor pair; empty-safe."""
        if not indices:
            if not self.all_data_x:
                return torch.empty(0), torch.empty((0,), dtype=torch.long)
            x_ref = self.all_data_x[0]
            return (torch.empty((0, *x_ref.shape), dtype=x_ref.dtype),
                    torch.empty((0,), dtype=torch.long))
        x = torch.stack([self.all_data_x[i] for i in indices])
        y = torch.tensor([self.all_data_y[i] for i in indices], dtype=torch.long)
        return x, y

    def select(self, d: Dataset, task_id):
        """
        Redraw the coreset for all classes seen up to ``task_id``.

        Args:
            d: the current task's data. Not used: samples are drawn from the
                ``dataset`` given to the constructor so the coreset can be
                re-balanced across every class seen so far.
            task_id: index of the task being trained.
        Returns:
            TensorDataset of (x, raw label) for the current task's samples that
            are not in the new coreset.
        """
        mapping = self.label_to_task_mapping
        current_classes = [label for label, task in mapping.items() if task <= task_id]

        print('Current classes are:', current_classes)

        # The budget is split evenly over *all* classes, so the coreset grows
        # as classes are seen and reaches `size` after the last task.
        samples_per_class = self.size // len(mapping)

        selected_indices = []
        for class_id in current_classes:
            class_indices = self._indices_by_label.get(class_id, [])
            selected_indices.extend(
                random.sample(class_indices, min(samples_per_class, len(class_indices)))
            )
        print(len(selected_indices))

        coreset_x, coreset_y = self._gather(selected_indices)
        self.coreset = TensorDataset(coreset_x, coreset_y)
        # Tag each sample with the task its label belongs to, so the per-task
        # coreset loaders (built from coreset_task_ids) train it on its own
        # head rather than lumping all classes under the current task.
        self.coreset_task_ids = torch.tensor(
            [mapping[int(y)] for y in coreset_y], dtype=torch.long
        )
        print("len(self.coreset)", len(self.coreset))

        # Return only the current task's remaining samples; earlier tasks are
        # represented solely through the coreset (as in RandomCoreset, which
        # returns the rest of the current task's data).
        selected = set(selected_indices)
        non_selected_indices = [
            idx
            for label, task in mapping.items() if task == task_id
            for idx in self._indices_by_label.get(label, [])
            if idx not in selected
        ]
        non_coreset = TensorDataset(*self._gather(non_selected_indices))
        print('non_coreset', len(non_coreset))

        return non_coreset

# experiment

In [None]:
import torch
import torch.optim as optim
import numpy as np
from torchvision.datasets import MNIST, CIFAR100
from torchvision.transforms import Compose
from torch.utils.data import ConcatDataset
from tensorboardX import SummaryWriter
import os
from datetime import datetime

## method1: split_mnist

In [None]:


MNIST_FLATTENED_DIM = 28 * 28  # MNIST images are 28x28, flattened to a vector
LR = 0.001
INITIAL_POSTERIOR_VAR = 1e-3

# Fall back to the CPU when no GPU is available instead of failing on the
# first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on device", device)


def split_mnist_original():
    """
    Run the 'Split MNIST' experiment from the VCL paper.

    Each of the five tasks is a binary classification between a pair of MNIST
    digits ((0, 1), (2, 3), ..., (8, 9)), learned by a multi-headed
    DiscriminativeVCL model with a RandomCoreset of CORESET_SIZE samples per task.

    Returns:
        ``all_accuracies``: entry ``i`` (for ``i < N_TASKS``) lists the accuracy
        on task ``i`` after training each task from ``i`` onwards; the final
        entry lists the mean accuracy after each task. The list is also printed.
    """
    N_CLASSES = 2
    LAYER_WIDTH = 256
    N_HIDDEN_LAYERS = 2
    N_TASKS = 5
    MULTIHEADED = True
    CORESET_SIZE = 200
    EPOCHS = 10
    BATCH_SIZE = 50000
    TRAIN_FULL_CORESET = True

    transform = Compose([Flatten(), Scale()])

    all_accuracies = [[] for _ in range(N_TASKS)]
    mean_accuracies = []

    # download dataset
    mnist_train = MNIST(root="data", train=True, download=True, transform=transform)
    mnist_test = MNIST(root="data", train=False, download=True, transform=transform)

    model = DiscriminativeVCL(
        in_size=MNIST_FLATTENED_DIM, out_size=N_CLASSES,
        layer_width=LAYER_WIDTH, n_hidden_layers=N_HIDDEN_LAYERS,
        n_heads=(N_TASKS if MULTIHEADED else 1),
        initial_posterior_var=INITIAL_POSTERIOR_VAR
    ).to(device)

    coreset = RandomCoreset(size=CORESET_SIZE)

    # Digits (2k, 2k + 1) form task k.
    label_to_task_mapping = {
        0: 0, 1: 0,
        2: 1, 3: 1,
        4: 2, 5: 2,
        6: 3, 7: 3,
        8: 4, 9: 4,
    }

    def to_task_ids(dataset):
        # Read the raw `targets` when available instead of iterating the
        # dataset, which would run the image transform on every sample.
        # int() accepts both Python ints and 0-d tensors as labels.
        if hasattr(dataset, "targets"):
            labels = dataset.targets.tolist()
        else:
            labels = [y for _, y in dataset]
        return torch.Tensor([label_to_task_mapping[int(y)] for y in labels])

    train_task_ids = to_task_ids(mnist_train)
    test_task_ids = to_task_ids(mnist_test)

    summary_logdir = os.path.join("logs", "disc_s_mnist", datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)

    def binarize_pair(y, task):
        # Binary target per task: 1 for the odd digit of the pair, 0 otherwise.
        return (y == (2 * task + 1)).long()

    try:
        run_point_estimate_initialisation(model=model, data=mnist_train,
                                          epochs=EPOCHS, batch_size=BATCH_SIZE,
                                          device=device, multiheaded=MULTIHEADED,
                                          lr=LR, task_ids=train_task_ids,
                                          y_transform=binarize_pair)

        for task_idx in range(N_TASKS):
            current_task_accuracies, current_mean = run_task(
                model=model, train_data=mnist_train, train_task_ids=train_task_ids,
                test_data=mnist_test, test_task_ids=test_task_ids, coreset=coreset,
                task_idx=task_idx, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR,
                save_as="disc_s_mnist", device=device, multiheaded=MULTIHEADED,
                y_transform=binarize_pair, train_full_coreset=TRAIN_FULL_CORESET,
                summary_writer=writer
            )
            mean_accuracies.append(current_mean)
            for i, acc in enumerate(current_task_accuracies):
                all_accuracies[i].append(acc)
    finally:
        # Flush and release the TensorBoard writer even if training fails.
        writer.close()

    all_accuracies.append(mean_accuracies)
    print("All Task Accuracies:", all_accuracies)
    return all_accuracies



Running on device cuda


In [None]:
# Run split MNIST with the original per-task random coreset (method 1).
# Experiment functions are looked up by name, so this cell also works when the
# cells defining the other experiments (further down) have not been run yet.
_EXPERIMENT_FUNCTION_NAMES = {
    'disc_s_mnist': 'split_mnist_original',
    'disc_s_mnist2': 'split_mnist_corset2',
    'disc_s_mnist_order_similar': 'split_mnist_order_similar',
    'disc_s_mnist_order_dissimilar': 'split_mnist_order_dissimilar',
}
EXP_OPTIONS = {
    key: globals()[fn_name]
    for key, fn_name in _EXPERIMENT_FUNCTION_NAMES.items()
    if fn_name in globals()
}

# Set the experiment you want to run here: one of the EXP_OPTIONS keys, or 'all'
experiment = 'disc_s_mnist'

# Run the selected experiment(s)
if experiment == 'all':
    for name, run_experiment in EXP_OPTIONS.items():
        print(f"Running {name}")
        run_experiment()
elif experiment in EXP_OPTIONS:
    print(f"Running {experiment}")
    EXP_OPTIONS[experiment]()
else:
    print(f"Experiment '{experiment}' not found. Available options are: {list(EXP_OPTIONS.keys())}")


Running disc_s_mnist
Number of heads: 5
Obtaining point estimate for posterior initialisation


Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.88it/s]


TASK  0


Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.65it/s]


CORESET TRAIN


Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 17.46it/s]


After task 0 perfomance on task 0 is 99.90543735224587
Mean accuracy: 99.90543735224587
TASK  1


Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.75it/s]


CORESET TRAIN


Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  8.42it/s]


After task 1 perfomance on task 0 is 99.90543735224587
After task 1 perfomance on task 1 is 95.5435847208619
Mean accuracy: 97.76280971854703
TASK  2


Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.89it/s]


CORESET TRAIN


Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  5.49it/s]


After task 2 perfomance on task 0 is 99.90543735224587
After task 2 perfomance on task 1 is 95.34769833496571
After task 2 perfomance on task 2 is 97.2785485592316
Mean accuracy: 97.54601226993866
TASK  3


Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.74it/s]


CORESET TRAIN


Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  4.05it/s]


After task 3 perfomance on task 0 is 99.90543735224587
After task 3 perfomance on task 1 is 95.64152791381
After task 3 perfomance on task 2 is 97.2785485592316
After task 3 perfomance on task 3 is 99.49647532729104
Mean accuracy: 98.10402893850568
TASK  4


Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.78it/s]


CORESET TRAIN


Epochs: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  3.18it/s]


After task 4 perfomance on task 0 is 99.90543735224587
After task 4 perfomance on task 1 is 95.5435847208619
After task 4 perfomance on task 2 is 97.86552828175027
After task 4 perfomance on task 3 is 99.29506545820745
After task 4 perfomance on task 4 is 95.86485123550176
Mean accuracy: 97.71
All Task Accuracies: [[99.90543735224587, 99.90543735224587, 99.90543735224587, 99.90543735224587, 99.90543735224587], [95.5435847208619, 95.34769833496571, 95.64152791381, 95.5435847208619], [97.2785485592316, 97.2785485592316, 97.86552828175027], [99.49647532729104, 99.29506545820745], [95.86485123550176], [99.90543735224587, 97.76280971854703, 97.54601226993866, 98.10402893850568, 97.71]]


## method2: random sample coreset (class-balanced, resampled every task)

In [None]:
def split_mnist_corset2():
    """
    Run 'Split MNIST' (VCL paper setup) with a class-balanced, resampled coreset.

    Each of the five tasks is a binary classification between a pair of MNIST
    digits ((0, 1), (2, 3), ..., (8, 9)), learned by a multi-headed
    DiscriminativeVCL model. Unlike ``split_mnist_original``, the coreset is a
    RandomCoreset_resample drawn from the whole training set before every task.

    Returns:
        ``all_accuracies``: entry ``i`` (for ``i < N_TASKS``) lists the accuracy
        on task ``i`` after training each task from ``i`` onwards; the final
        entry lists the mean accuracy after each task. The list is also printed.
    """
    N_CLASSES = 2
    LAYER_WIDTH = 256
    N_HIDDEN_LAYERS = 2
    N_TASKS = 5
    MULTIHEADED = True
    CORESET_SIZE = 200
    EPOCHS = 120
    BATCH_SIZE = 50000
    TRAIN_FULL_CORESET = True

    transform = Compose([Flatten(), Scale()])

    all_accuracies = [[] for _ in range(N_TASKS)]
    mean_accuracies = []

    # download dataset
    mnist_train = MNIST(root="data", train=True, download=True, transform=transform)
    mnist_test = MNIST(root="data", train=False, download=True, transform=transform)

    model = DiscriminativeVCL(
        in_size=MNIST_FLATTENED_DIM, out_size=N_CLASSES,
        layer_width=LAYER_WIDTH, n_hidden_layers=N_HIDDEN_LAYERS,
        n_heads=(N_TASKS if MULTIHEADED else 1),
        initial_posterior_var=INITIAL_POSTERIOR_VAR
    ).to(device)

    coreset = RandomCoreset_resample(size=CORESET_SIZE, dataset=mnist_train)

    # Digits (2k, 2k + 1) form task k.
    label_to_task_mapping = {
        0: 0, 1: 0,
        2: 1, 3: 1,
        4: 2, 5: 2,
        6: 3, 7: 3,
        8: 4, 9: 4,
    }

    def to_task_ids(dataset):
        # Read the raw `targets` when available instead of iterating the
        # dataset, which would run the image transform on every sample.
        # int() accepts both Python ints and 0-d tensors as labels.
        if hasattr(dataset, "targets"):
            labels = dataset.targets.tolist()
        else:
            labels = [y for _, y in dataset]
        return torch.Tensor([label_to_task_mapping[int(y)] for y in labels])

    train_task_ids = to_task_ids(mnist_train)
    test_task_ids = to_task_ids(mnist_test)

    summary_logdir = os.path.join("logs", "disc_s_mnist_2", datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)

    def binarize_pair(y, task):
        # Binary target per task: 1 for the odd digit of the pair, 0 otherwise.
        return (y == (2 * task + 1)).long()

    try:
        run_point_estimate_initialisation(model=model, data=mnist_train,
                                          epochs=EPOCHS, batch_size=BATCH_SIZE,
                                          device=device, multiheaded=MULTIHEADED,
                                          lr=LR, task_ids=train_task_ids,
                                          y_transform=binarize_pair)

        for task_idx in range(N_TASKS):
            current_task_accuracies, current_mean = run_task(
                model=model,
                train_data=mnist_train,
                train_task_ids=train_task_ids,
                test_data=mnist_test,
                test_task_ids=test_task_ids,
                coreset=coreset,
                task_idx=task_idx,
                epochs=EPOCHS,
                batch_size=BATCH_SIZE, lr=LR,
                save_as="disc_s_mnist_2", device=device, multiheaded=MULTIHEADED,
                y_transform=binarize_pair, train_full_coreset=TRAIN_FULL_CORESET,
                summary_writer=writer
            )
            mean_accuracies.append(current_mean)
            for i, acc in enumerate(current_task_accuracies):
                all_accuracies[i].append(acc)
    finally:
        # Flush and release the TensorBoard writer even if training fails.
        writer.close()

    all_accuracies.append(mean_accuracies)
    print("All Task Accuracies:", all_accuracies)
    return all_accuracies

In [None]:
# Run split MNIST with the resampled, class-balanced coreset (method 2).
# Experiment functions are looked up by name, so this cell also works when the
# cells defining the other experiments have not been run yet.
_EXPERIMENT_FUNCTION_NAMES = {
    'disc_s_mnist': 'split_mnist_original',
    'disc_s_mnist2': 'split_mnist_corset2',
    'disc_s_mnist_order_similar': 'split_mnist_order_similar',
    'disc_s_mnist_order_dissimilar': 'split_mnist_order_dissimilar',
}
EXP_OPTIONS = {
    key: globals()[fn_name]
    for key, fn_name in _EXPERIMENT_FUNCTION_NAMES.items()
    if fn_name in globals()
}

# Set the experiment you want to run here: one of the EXP_OPTIONS keys, or 'all'
experiment = 'disc_s_mnist2'

# Run the selected experiment(s)
if experiment == 'all':
    for name, run_experiment in EXP_OPTIONS.items():
        print(f"Running {name}")
        run_experiment()
elif experiment in EXP_OPTIONS:
    print(f"Running {experiment}")
    EXP_OPTIONS[experiment]()
else:
    print(f"Experiment '{experiment}' not found. Available options are: {list(EXP_OPTIONS.keys())}")


Running disc_s_mnist2
Number of heads: 5
Obtaining point estimate for posterior initialisation


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:59<00:00,  2.03it/s]


TASK  0
Current classes are: [0, 1]
40
len(self.coreset) 40
non_coreset 12625


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:21<00:00,  5.50it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:06<00:00, 17.29it/s]


After task 0 perfomance on task 0 is 99.90543735224587
Mean accuracy: 99.90543735224587
TASK  1
Current classes are: [0, 1, 2, 3]
80
len(self.coreset) 80
non_coreset 24674


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:41<00:00,  2.93it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:07<00:00, 17.14it/s]


After task 1 perfomance on task 0 is 99.95271867612293
After task 1 perfomance on task 1 is 96.91478942213516
Mean accuracy: 98.4604281934087
TASK  2
Current classes are: [0, 1, 2, 3, 4, 5]
120
len(self.coreset) 120
non_coreset 35897


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:58<00:00,  2.04it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:07<00:00, 16.95it/s]


After task 2 perfomance on task 0 is 99.95271867612293
After task 2 perfomance on task 1 is 96.32713026444662
After task 2 perfomance on task 2 is 93.64994663820704
Mean accuracy: 96.76670535566241
TASK  3
Current classes are: [0, 1, 2, 3, 4, 5, 6, 7]
160
len(self.coreset) 160
non_coreset 48040


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:17<00:00,  1.55it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:07<00:00, 17.05it/s]


After task 3 perfomance on task 0 is 99.95271867612293
After task 3 perfomance on task 1 is 96.13124387855044
After task 3 perfomance on task 2 is 94.61045891141943
After task 3 perfomance on task 3 is 93.95770392749245
Mean accuracy: 96.24547835848821
TASK  4
Current classes are: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
200
len(self.coreset) 200
non_coreset 59800


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:39<00:00,  1.21it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:07<00:00, 16.96it/s]


After task 4 perfomance on task 0 is 99.90543735224587
After task 4 perfomance on task 1 is 94.22135161606268
After task 4 perfomance on task 2 is 96.05122732123799
After task 4 perfomance on task 3 is 95.36757301107754
After task 4 perfomance on task 4 is 87.49369641956632
Mean accuracy: 94.66
All Task Accuracies: [[99.90543735224587, 99.95271867612293, 99.95271867612293, 99.95271867612293, 99.90543735224587], [96.91478942213516, 96.32713026444662, 96.13124387855044, 94.22135161606268], [93.64994663820704, 94.61045891141943, 96.05122732123799], [93.95770392749245, 95.36757301107754], [87.49369641956632], [99.90543735224587, 98.4604281934087, 96.76670535566241, 96.24547835848821, 94.66]]


In [None]:
# Binary targets for the reordered Split-MNIST experiments.
def binarize_y(y, task, task_to_labels):
    """
    Turn raw digit labels into binary targets for a single task.

    Args:
        y: tensor of raw class labels.
        task: index of the task whose label pair defines the binary problem.
        task_to_labels: maps each task to its pair of labels.
    Returns:
        Long tensor that is 1 where ``y`` equals the task's second label and 0
        everywhere else.
    """
    # Unpacking (rather than indexing) also insists on exactly two labels.
    _, positive_label = task_to_labels[task]
    is_positive = y == positive_label
    return is_positive.long()

# reorder
def split_mnist_order_similar():
    """
    Run Split MNIST with tasks built from visually similar digit pairs.

    Tasks are (2, 3), (5, 6), (8, 9), (0, 4), (1, 7); each is a binary
    classification (via ``binarize_y``) learned by a multi-headed
    DiscriminativeVCL model with a RandomCoreset of CORESET_SIZE samples per task.

    Returns:
        ``all_accuracies``: entry ``i`` (for ``i < N_TASKS``) lists the accuracy
        on task ``i`` after training each task from ``i`` onwards; the final
        entry lists the mean accuracy after each task. The list is also printed.
    """
    N_CLASSES = 2  # out_size of every head: one output per digit of the task's pair
    LAYER_WIDTH = 256
    N_HIDDEN_LAYERS = 2
    N_TASKS = 5
    MULTIHEADED = True
    CORESET_SIZE = 200
    EPOCHS = 120
    BATCH_SIZE = 50000

    TRAIN_FULL_CORESET = True

    # Distinct name so logs/saved results don't collide with 'disc_s_mnist'.
    EXPERIMENT_NAME = "disc_s_mnist_order_similar"

    transform = Compose([Flatten(), Scale()])

    all_accuracies = [[] for _ in range(N_TASKS)]
    mean_accuracies = []

    # download dataset
    mnist_train = MNIST(root="data", train=True, download=True, transform=transform)
    mnist_test = MNIST(root="data", train=False, download=True, transform=transform)

    model = DiscriminativeVCL(
        in_size=MNIST_FLATTENED_DIM, out_size=N_CLASSES,
        layer_width=LAYER_WIDTH, n_hidden_layers=N_HIDDEN_LAYERS,
        n_heads=(N_TASKS if MULTIHEADED else 1),
        initial_posterior_var=INITIAL_POSTERIOR_VAR
    ).to(device)

    coreset = RandomCoreset(size=CORESET_SIZE)

    # Similar digits share a task.
    label_to_task_mapping = {
        2: 0, 3: 0,
        5: 1, 6: 1,
        8: 2, 9: 2,
        0: 3, 4: 3,
        7: 4, 1: 4,
    }

    # task -> its sorted label pair; the larger label becomes the positive class.
    task_to_labels = {
        task: sorted(label for label, t in label_to_task_mapping.items() if t == task)
        for task in range(N_TASKS)
    }

    def to_task_ids(dataset):
        # Read the raw `targets` when available instead of iterating the
        # dataset, which would run the image transform on every sample.
        if hasattr(dataset, "targets"):
            labels = dataset.targets.tolist()
        else:
            labels = [y for _, y in dataset]
        return torch.Tensor([label_to_task_mapping[int(y)] for y in labels])

    train_task_ids = to_task_ids(mnist_train)
    test_task_ids = to_task_ids(mnist_test)

    summary_logdir = os.path.join("logs", EXPERIMENT_NAME, datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)

    def bin_y(y, task):
        # Each task is a binary problem over its own pair of digits.
        return binarize_y(y, task, task_to_labels)

    try:
        run_point_estimate_initialisation(model=model, data=mnist_train,
                                          epochs=EPOCHS, batch_size=BATCH_SIZE,
                                          device=device, multiheaded=MULTIHEADED,
                                          lr=LR, task_ids=train_task_ids,
                                          y_transform=bin_y)

        for task_idx in range(N_TASKS):
            current_task_accuracies, current_mean = run_task(
                model=model, train_data=mnist_train, train_task_ids=train_task_ids,
                test_data=mnist_test, test_task_ids=test_task_ids, coreset=coreset,
                task_idx=task_idx, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR,
                save_as=EXPERIMENT_NAME, device=device, multiheaded=MULTIHEADED,
                y_transform=bin_y, train_full_coreset=TRAIN_FULL_CORESET,
                summary_writer=writer
            )
            mean_accuracies.append(current_mean)
            for i, acc in enumerate(current_task_accuracies):
                all_accuracies[i].append(acc)
    finally:
        # Flush and release the TensorBoard writer even if training fails.
        writer.close()

    all_accuracies.append(mean_accuracies)
    print("All Task Accuracies:", all_accuracies)
    return all_accuracies

## method3: similar within task

In [None]:


# Run split MNIST with similar digits grouped into the same task (method 3).
# Experiment functions are looked up by name, so this cell also works when the
# cells defining the other experiments (e.g. the dissimilar ordering below)
# have not been run yet.
_EXPERIMENT_FUNCTION_NAMES = {
    'disc_s_mnist': 'split_mnist_original',
    'disc_s_mnist2': 'split_mnist_corset2',
    'disc_s_mnist_order_similar': 'split_mnist_order_similar',
    'disc_s_mnist_order_dissimilar': 'split_mnist_order_dissimilar',
}
EXP_OPTIONS = {
    key: globals()[fn_name]
    for key, fn_name in _EXPERIMENT_FUNCTION_NAMES.items()
    if fn_name in globals()
}

# Set the experiment you want to run here: one of the EXP_OPTIONS keys, or 'all'
experiment = 'disc_s_mnist_order_similar'

# Run the selected experiment(s)
if experiment == 'all':
    for name, run_experiment in EXP_OPTIONS.items():
        print(f"Running {name}")
        run_experiment()
elif experiment in EXP_OPTIONS:
    print(f"Running {experiment}")
    EXP_OPTIONS[experiment]()
else:
    print(f"Experiment '{experiment}' not found. Available options are: {list(EXP_OPTIONS.keys())}")


Running disc_s_mnist_order
Number of heads: 5
Obtaining point estimate for posterior initialisation


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:54<00:00,  2.19it/s]


TASK  0


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:04<00:00,  1.86it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:06<00:00, 17.92it/s]


After task 0 perfomance on task 0 is 99.75514201762978
Mean accuracy: 99.75514201762978
TASK  1


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:00<00:00,  1.97it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:13<00:00,  8.58it/s]


After task 1 perfomance on task 0 is 99.51028403525955
After task 1 perfomance on task 1 is 98.21621621621621
Mean accuracy: 98.89516957862281
TASK  2


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:03<00:00,  1.88it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:21<00:00,  5.50it/s]


After task 2 perfomance on task 0 is 98.87365328109696
After task 2 perfomance on task 1 is 98.27027027027027
After task 2 perfomance on task 2 is 97.68028240040343
Mean accuracy: 98.28085106382979
TASK  3


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:03<00:00,  1.89it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:29<00:00,  4.05it/s]


After task 3 perfomance on task 0 is 99.31439764936337
After task 3 perfomance on task 1 is 98.16216216216216
After task 3 perfomance on task 2 is 98.23499747856782
After task 3 perfomance on task 3 is 99.84709480122324
Mean accuracy: 98.90264131683043
TASK  4


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:10<00:00,  1.71it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:37<00:00,  3.19it/s]


After task 4 perfomance on task 0 is 99.41234084231147
After task 4 perfomance on task 1 is 98.16216216216216
After task 4 perfomance on task 2 is 95.96570852244075
After task 4 perfomance on task 3 is 99.84709480122324
After task 4 perfomance on task 4 is 99.4914470642626
Mean accuracy: 98.6
All Task Accuracies: [[99.75514201762978, 99.51028403525955, 98.87365328109696, 99.31439764936337, 99.41234084231147], [98.21621621621621, 98.27027027027027, 98.16216216216216, 98.16216216216216], [97.68028240040343, 98.23499747856782, 95.96570852244075], [99.84709480122324, 99.84709480122324], [99.4914470642626], [99.75514201762978, 98.89516957862281, 98.28085106382979, 98.90264131683043, 98.6]]


## method4: dissimilar within task

In [None]:
# reorder
def split_mnist_order_dissimilar():
    """
    Run Split MNIST with a re-grouped ("dissimilar") digit pairing per task.

    Each of the five tasks is a binary classification between two digits,
    grouped as {2, 8}, {5, 9}, {6, 7}, {0, 4}, {1, 3} for tasks 0-4. The two
    labels of each task are stored sorted ascending in ``task_to_labels`` and
    mapped to {0, 1} targets by ``binarize_y`` (defined elsewhere in the
    notebook). Tasks are trained sequentially with ``run_task`` and a random
    coreset.

    Prints the per-task accuracy history, with the list of mean accuracies
    after each task appended as the final entry.
    """

    N_CLASSES = 2  # TODO does it make sense to do binary classification with out_size=2 ?
    LAYER_WIDTH = 256
    N_HIDDEN_LAYERS = 2
    N_TASKS = 5
    MULTIHEADED = True
    CORESET_SIZE = 200
    EPOCHS = 120
    BATCH_SIZE = 50000

    TRAIN_FULL_CORESET = True

    transform = Compose([Flatten(), Scale()])

    # all_accuracies[i] collects task i's accuracy after every later task.
    all_accuracies = [[] for _ in range(N_TASKS)]
    mean_accuracies = []

    # download dataset
    mnist_train = MNIST(root="data", train=True, download=True, transform=transform)
    mnist_test = MNIST(root="data", train=False, download=True, transform=transform)

    model = DiscriminativeVCL(
        in_size=MNIST_FLATTENED_DIM, out_size=N_CLASSES,
        layer_width=LAYER_WIDTH, n_hidden_layers=N_HIDDEN_LAYERS,
        n_heads=(N_TASKS if MULTIHEADED else 1),
        initial_posterior_var=INITIAL_POSTERIOR_VAR
    ).to(device)

    coreset = RandomCoreset(size=CORESET_SIZE)

    # "Dissimilar" grouping: digit label -> task index (two digits per task).
    label_to_task_mapping = {
        2: 0, 8: 0,
        5: 1, 9: 1,
        6: 2, 7: 2,
        0: 3, 4: 3,
        3: 4, 1: 4,
    }

    # Invert the mapping; sorting fixes which label of each pair comes second.
    task_to_labels = {task: [] for task in range(N_TASKS)}
    for label, task in label_to_task_mapping.items():
        task_to_labels[task].append(label)
    for labels in task_to_labels.values():
        labels.sort()

    # int() accepts both Python ints and 0-d tensors; a tensor used directly as
    # a dict key would not match the int keys and would raise KeyError.
    train_task_ids = torch.Tensor([label_to_task_mapping[int(y)] for _, y in mnist_train])
    test_task_ids = torch.Tensor([label_to_task_mapping[int(y)] for _, y in mnist_test])

    summary_logdir = os.path.join("logs", "disc_s_mnist", datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)

    def bin_y(y, task):
        # Bind this experiment's label grouping into the y_transform callback.
        return binarize_y(y, task, task_to_labels)

    try:
        run_point_estimate_initialisation(model=model, data=mnist_train,
                                          epochs=EPOCHS, batch_size=BATCH_SIZE,
                                          device=device, multiheaded=MULTIHEADED,
                                          lr=LR, task_ids=train_task_ids,
                                          y_transform=bin_y)

        for task_idx in range(N_TASKS):
            current_task_accuracies, current_mean = run_task(
                model=model, train_data=mnist_train, train_task_ids=train_task_ids,
                test_data=mnist_test, test_task_ids=test_task_ids, coreset=coreset,
                task_idx=task_idx, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR,
                save_as="disc_s_mnist", device=device, multiheaded=MULTIHEADED,
                y_transform=bin_y, train_full_coreset=TRAIN_FULL_CORESET,
                summary_writer=writer
            )
            mean_accuracies.append(current_mean)
            for i, acc in enumerate(current_task_accuracies):
                all_accuracies[i].append(acc)
        all_accuracies.append(mean_accuracies)
        print("All Task Accuracies:", all_accuracies)
    finally:
        # Close the summary writer even if training raises part-way through.
        writer.close()

In [None]:
# Registry of runnable experiments: name -> zero-argument entry point.
EXP_OPTIONS = {
    'disc_s_mnist': split_mnist_original,
    'disc_s_mnist2': split_mnist_corset2,
    'disc_s_mnist_order_similar': split_mnist_order_similar,
    'disc_s_mnist_order_dissimilar': split_mnist_order_dissimilar,
}

# Set the experiment you want to run here: any key of EXP_OPTIONS, or 'all'
# to run every registered experiment in order.
experiment = 'disc_s_mnist_order_dissimilar'

# Run the selected experiment(s)
if experiment == 'all':
    for exp_name, exp_func in EXP_OPTIONS.items():
        print(f"Running {exp_name}")
        exp_func()
elif experiment in EXP_OPTIONS:
    print(f"Running {experiment}")
    EXP_OPTIONS[experiment]()
else:
    print(f"Experiment '{experiment}' not found. Available options are: {list(EXP_OPTIONS.keys())}")


Running disc_s_mnist_order
Number of heads: 5
Obtaining point estimate for posterior initialisation


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:53<00:00,  2.22it/s]


TASK  0


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:03<00:00,  1.88it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:06<00:00, 19.23it/s]


After task 0 perfomance on task 0 is 99.2023928215354
Mean accuracy: 99.2023928215354
TASK  1


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:01<00:00,  1.96it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:13<00:00,  8.67it/s]


After task 1 perfomance on task 0 is 99.10269192422732
After task 1 perfomance on task 1 is 98.73750657548659
Mean accuracy: 98.92500639877143
TASK  2


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:05<00:00,  1.82it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:21<00:00,  5.60it/s]


After task 2 perfomance on task 0 is 99.05284147557327
After task 2 perfomance on task 1 is 98.79011046817465
After task 2 perfomance on task 2 is 99.49647532729104
Mean accuracy: 99.11759714916002
TASK  3


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:03<00:00,  1.88it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:28<00:00,  4.23it/s]


After task 3 perfomance on task 0 is 99.15254237288136
After task 3 perfomance on task 1 is 98.52709100473434
After task 3 perfomance on task 2 is 99.29506545820745
After task 3 perfomance on task 3 is 99.64322120285424
Mean accuracy: 99.15977084659453
TASK  4


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:10<00:00,  1.70it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:34<00:00,  3.48it/s]


After task 4 perfomance on task 0 is 98.70388833499501
After task 4 perfomance on task 1 is 98.47448711204629
After task 4 perfomance on task 2 is 99.29506545820745
After task 4 perfomance on task 3 is 99.59225280326197
After task 4 perfomance on task 4 is 99.44055944055944
Mean accuracy: 99.11
All Task Accuracies: [[99.2023928215354, 99.10269192422732, 99.05284147557327, 99.15254237288136, 98.70388833499501], [98.73750657548659, 98.79011046817465, 98.52709100473434, 98.47448711204629], [99.49647532729104, 99.29506545820745, 99.29506545820745], [99.64322120285424, 99.59225280326197], [99.44055944055944], [99.2023928215354, 98.92500639877143, 99.11759714916002, 99.15977084659453, 99.11]]


## method5: scaling factor

In [None]:

MNIST_FLATTENED_DIM = 28 * 28  # model input size after flattening a 28x28 image
LR = 0.001  # learning rate passed to the training routines
INITIAL_POSTERIOR_VAR = 1e-3  # passed to DiscriminativeVCL(initial_posterior_var=...)

# Prefer the GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA (a bare "cuda" device fails as soon as a model is moved to it).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on device", device)

def split_mnist_new():
    """
    Run the 'Split MNIST' experiment from the VCL paper with scaling factors.

    Each of the five tasks is a binary classification between a pair of
    digits: {0, 1}, {2, 3}, {4, 5}, {6, 7}, {8, 9}. For task ``t`` the odd
    digit ``2 * t + 1`` is the positive class. Tasks are trained with
    ``run_task_scale`` and a random coreset. After each task,
    ``model.adjust_scaling_factors`` is called with the mean accuracy returned
    by ``run_task_scale`` and the accuracy on the newest task.

    Prints the per-task accuracy history, with the list of mean accuracies
    after each task appended as the final entry.
    """

    N_CLASSES = 2
    LAYER_WIDTH = 256
    N_HIDDEN_LAYERS = 2
    N_TASKS = 5
    MULTIHEADED = True
    CORESET_SIZE = 200
    EPOCHS = 120
    BATCH_SIZE = 50000
    TRAIN_FULL_CORESET = True

    transform = Compose([Flatten(), Scale()])

    # all_accuracies[i] collects task i's accuracy after every later task.
    all_accuracies = [[] for _ in range(N_TASKS)]
    mean_accuracies = []

    # download dataset
    mnist_train = MNIST(root="data", train=True, download=True, transform=transform)
    mnist_test = MNIST(root="data", train=False, download=True, transform=transform)

    model = DiscriminativeVCL(
        in_size=MNIST_FLATTENED_DIM, out_size=N_CLASSES,
        layer_width=LAYER_WIDTH, n_hidden_layers=N_HIDDEN_LAYERS,
        n_heads=(N_TASKS if MULTIHEADED else 1),
        initial_posterior_var=INITIAL_POSTERIOR_VAR
    ).to(device)

    coreset = RandomCoreset(size=CORESET_SIZE)

    # Standard Split MNIST grouping: digits (2t, 2t + 1) form task t.
    label_to_task_mapping = {
        0: 0, 1: 0,
        2: 1, 3: 1,
        4: 2, 5: 2,
        6: 3, 7: 3,
        8: 4, 9: 4,
    }

    # int() accepts both Python ints and 0-d tensors, so the task ids are built
    # for either target type instead of being left undefined for other types.
    train_task_ids = torch.Tensor([label_to_task_mapping[int(y)] for _, y in mnist_train])
    test_task_ids = torch.Tensor([label_to_task_mapping[int(y)] for _, y in mnist_test])

    summary_logdir = os.path.join("logs", "disc_s_mnist", datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)

    def bin_y(y, task):
        # Each task is a binary classification for a different pair of digits;
        # the odd digit of the pair is the positive class.
        return (y == (2 * task + 1)).long()

    try:
        run_point_estimate_initialisation(model=model, data=mnist_train,
                                          epochs=EPOCHS, batch_size=BATCH_SIZE,
                                          device=device, multiheaded=MULTIHEADED,
                                          lr=LR, task_ids=train_task_ids,
                                          y_transform=bin_y)

        for task_idx in range(N_TASKS):
            current_task_accuracies, current_mean = run_task_scale(
                model=model, train_data=mnist_train, train_task_ids=train_task_ids,
                test_data=mnist_test, test_task_ids=test_task_ids, coreset=coreset,
                task_idx=task_idx, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR,
                save_as="disc_s_mnist", device=device, multiheaded=MULTIHEADED,
                y_transform=bin_y, train_full_coreset=TRAIN_FULL_CORESET,
                summary_writer=writer
            )
            # Update the model's scaling factors from the mean accuracy and
            # the accuracy on the task just learned.
            current_last_accuracy = current_task_accuracies[-1]
            model.adjust_scaling_factors(current_mean, current_last_accuracy)

            mean_accuracies.append(current_mean)
            for i, acc in enumerate(current_task_accuracies):
                all_accuracies[i].append(acc)
        all_accuracies.append(mean_accuracies)
        print("All Task Accuracies:", all_accuracies)
    finally:
        # Close the summary writer even if training raises part-way through.
        writer.close()

Running on device cuda


In [None]:
# random sample corset
EXP_OPTIONS = {
    'disc_s_mnist': split_mnist_original,
    'disc_s_mnist2': split_mnist_corset2,
    'disc_s_mnist_order_similar': split_mnist_order_similar,
    'disc_s_mnist_order_dissimilar': split_mnist_order_dissimilar,
    'disc_s_mnist_order_scale': split_mnist_new
}

# Set the experiment you want to run here
experiment = 'disc_s_mnist_order_scale'  # Options: 'disc_p_mnist', 'disc_s_mnist', 'disc_s_n_mnist', or 'all' to run all experiments

# Run the selected experiment(s)
if experiment == 'all':
    for exp_name, exp_func in EXP_OPTIONS.items():
        print(f"Running {exp_name}")
        exp_func()
else:
    if experiment in EXP_OPTIONS:
        print(f"Running {experiment}")
        EXP_OPTIONS[experiment]()
    else:
        print(f"Experiment '{experiment}' not found. Available options are: {list(EXP_OPTIONS.keys())}")


Running disc_s_mnist_order_scale
Number of heads: 5
Obtaining point estimate for posterior initialisation


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:58<00:00,  2.06it/s]


TASK  0


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:09<00:00,  1.73it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:06<00:00, 17.66it/s]


After task 0 perfomance on task 0 is 99.95271867612293
Mean accuracy: 99.95271867612293
alpha 1.0
beta 1.0
TASK  1


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:05<00:00,  1.83it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:14<00:00,  8.49it/s]


After task 1 perfomance on task 0 is 99.95271867612293
After task 1 perfomance on task 1 is 97.94319294809011
Mean accuracy: 98.96560019244647
alpha 1.1022407244356358
beta 0.8977592755643642
TASK  2


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:01<00:00,  1.95it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:21<00:00,  5.55it/s]


After task 2 perfomance on task 0 is 99.95271867612293
After task 2 perfomance on task 1 is 97.35553379040157
After task 2 perfomance on task 2 is 99.51974386339381
Mean accuracy: 98.93881611673022
alpha 0.9419072253336409
beta 1.058092774666359
TASK  3


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:06<00:00,  1.80it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:28<00:00,  4.20it/s]


After task 3 perfomance on task 0 is 99.90543735224587
After task 3 perfomance on task 1 is 97.25759059745347
After task 3 perfomance on task 2 is 99.41302027748132
After task 3 perfomance on task 3 is 99.69788519637463
Mean accuracy: 99.06448796307846
alpha 0.9366602766703835
beta 1.0633397233296165
TASK  4


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:04<00:00,  1.86it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:34<00:00,  3.47it/s]


After task 4 perfomance on task 0 is 99.90543735224587
After task 4 perfomance on task 1 is 96.42507345739472
After task 4 perfomance on task 2 is 99.09284951974387
After task 4 perfomance on task 3 is 98.99295065458207
After task 4 perfomance on task 4 is 97.52899646999495
Mean accuracy: 98.39
alpha 1.086100353000505
beta 0.913899646999495
All Task Accuracies: [[99.95271867612293, 99.95271867612293, 99.95271867612293, 99.90543735224587, 99.90543735224587], [97.94319294809011, 97.35553379040157, 97.25759059745347, 96.42507345739472], [99.51974386339381, 99.41302027748132, 99.09284951974387], [99.69788519637463, 98.99295065458207], [97.52899646999495], [99.95271867612293, 98.96560019244647, 98.93881611673022, 99.06448796307846, 98.39]]


## method6: scaling factor + dissimilar task order

In [None]:

MNIST_FLATTENED_DIM = 28 * 28  # model input size after flattening a 28x28 image
LR = 0.001  # learning rate passed to the training routines
INITIAL_POSTERIOR_VAR = 1e-3  # passed to DiscriminativeVCL(initial_posterior_var=...)

# Prefer the GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA (a bare "cuda" device fails as soon as a model is moved to it).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on device", device)

def binarize_y(y, task, task_to_labels):
    """
    Convert raw digit labels into binary targets for one task.

    Args:
        y: tensor of digit labels (``.long()`` is called on the comparison,
            so a tensor is expected).
        task: index of the task whose label pair is used.
        task_to_labels: dict mapping a task index to exactly two labels; the
            second entry is treated as the positive class.
    Returns:
        A long tensor holding 1 where ``y`` equals the task's second label
        and 0 elsewhere.
    """
    _negative_label, positive_label = task_to_labels[task]
    is_positive = y == positive_label
    return is_positive.long()

# reorder

def split_mnist_factor_dissimilar():
    """
    Run Split MNIST with the "dissimilar" digit pairing and scaling factors.

    Each of the five tasks is a binary classification between two digits,
    grouped as {2, 8}, {5, 9}, {6, 7}, {0, 4}, {1, 3} for tasks 0-4. The pair
    for each task is stored sorted ascending in ``task_to_labels``. Targets
    are produced by ``binarize_y``, which treats the second (larger) label of
    the pair as the positive class.

    Tasks are trained with ``run_task_scale`` and a random coreset. After each
    task, ``model.adjust_scaling_factors`` is called with the mean accuracy
    returned by ``run_task_scale`` and the accuracy on the newest task.

    Prints the per-task accuracy history, with the list of mean accuracies
    after each task appended as the final entry.
    """

    N_CLASSES = 2
    LAYER_WIDTH = 256
    N_HIDDEN_LAYERS = 2
    N_TASKS = 5
    MULTIHEADED = True
    CORESET_SIZE = 200
    EPOCHS = 120
    BATCH_SIZE = 50000
    TRAIN_FULL_CORESET = True

    transform = Compose([Flatten(), Scale()])

    # all_accuracies[i] collects task i's accuracy after every later task.
    all_accuracies = [[] for _ in range(N_TASKS)]
    mean_accuracies = []

    # download dataset
    mnist_train = MNIST(root="data", train=True, download=True, transform=transform)
    mnist_test = MNIST(root="data", train=False, download=True, transform=transform)

    model = DiscriminativeVCL(
        in_size=MNIST_FLATTENED_DIM, out_size=N_CLASSES,
        layer_width=LAYER_WIDTH, n_hidden_layers=N_HIDDEN_LAYERS,
        n_heads=(N_TASKS if MULTIHEADED else 1),
        initial_posterior_var=INITIAL_POSTERIOR_VAR
    ).to(device)

    coreset = RandomCoreset(size=CORESET_SIZE)

    # "Dissimilar" grouping: digit label -> task index (two digits per task).
    label_to_task_mapping = {
        2: 0, 8: 0,
        5: 1, 9: 1,
        6: 2, 7: 2,
        0: 3, 4: 3,
        3: 4, 1: 4,
    }

    # Invert the mapping; sorting fixes which label of each pair comes second.
    task_to_labels = {task: [] for task in range(N_TASKS)}
    for label, task in label_to_task_mapping.items():
        task_to_labels[task].append(label)
    for labels in task_to_labels.values():
        labels.sort()

    # int() accepts both Python ints and 0-d tensors; a tensor used directly as
    # a dict key would not match the int keys and would raise KeyError.
    train_task_ids = torch.Tensor([label_to_task_mapping[int(y)] for _, y in mnist_train])
    test_task_ids = torch.Tensor([label_to_task_mapping[int(y)] for _, y in mnist_test])

    summary_logdir = os.path.join("logs", "disc_s_mnist", datetime.now().strftime('%b%d_%H-%M-%S'))
    writer = SummaryWriter(summary_logdir)

    def bin_y(y, task):
        # Each task is a binary classification for a different pair of digits;
        # bind this experiment's grouping into the y_transform callback.
        return binarize_y(y, task, task_to_labels)

    try:
        run_point_estimate_initialisation(model=model, data=mnist_train,
                                          epochs=EPOCHS, batch_size=BATCH_SIZE,
                                          device=device, multiheaded=MULTIHEADED,
                                          lr=LR, task_ids=train_task_ids,
                                          y_transform=bin_y)

        for task_idx in range(N_TASKS):
            current_task_accuracies, current_mean = run_task_scale(
                model=model, train_data=mnist_train, train_task_ids=train_task_ids,
                test_data=mnist_test, test_task_ids=test_task_ids, coreset=coreset,
                task_idx=task_idx, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR,
                save_as="disc_s_mnist", device=device, multiheaded=MULTIHEADED,
                y_transform=bin_y, train_full_coreset=TRAIN_FULL_CORESET,
                summary_writer=writer
            )
            # Update the model's scaling factors from the mean accuracy and
            # the accuracy on the task just learned.
            current_last_accuracy = current_task_accuracies[-1]
            model.adjust_scaling_factors(current_mean, current_last_accuracy)

            mean_accuracies.append(current_mean)
            for i, acc in enumerate(current_task_accuracies):
                all_accuracies[i].append(acc)
        all_accuracies.append(mean_accuracies)
        print("All Task Accuracies:", all_accuracies)
    finally:
        # Close the summary writer even if training raises part-way through.
        writer.close()

Running on device cuda


In [None]:
# Experiments selectable in this cell (each uses a randomly sampled coreset).
# Earlier experiments are commented out; uncomment an entry to re-enable it.
EXP_OPTIONS = {
#     'disc_s_mnist': split_mnist_original,
#     'disc_s_mnist2': split_mnist_corset2,
#     'disc_s_mnist_order_similar': split_mnist_order_similar,
#     'disc_s_mnist_order_dissimilar': split_mnist_order_dissimilar,
#     'disc_s_mnist_order_scale': split_mnist_new,
    'disc_s_mnist_order_scale_dis': split_mnist_factor_dissimilar,
}

# Set the experiment you want to run here: any key of EXP_OPTIONS, or 'all'
# to run every registered experiment in order.
experiment = 'disc_s_mnist_order_scale_dis'

# Run the selected experiment(s)
if experiment == 'all':
    for exp_name, exp_func in EXP_OPTIONS.items():
        print(f"Running {exp_name}")
        exp_func()
elif experiment in EXP_OPTIONS:
    print(f"Running {experiment}")
    EXP_OPTIONS[experiment]()
else:
    print(f"Experiment '{experiment}' not found. Available options are: {list(EXP_OPTIONS.keys())}")


Running disc_s_mnist_order_scale_dis
Number of heads: 5
Obtaining point estimate for posterior initialisation


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:56<00:00,  2.12it/s]


TASK  0


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:04<00:00,  1.85it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:06<00:00, 17.59it/s]


After task 0 perfomance on task 0 is 99.45164506480559
Mean accuracy: 99.45164506480559
alpha 1.0
beta 1.0
TASK  1


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:03<00:00,  1.88it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:13<00:00,  8.77it/s]


After task 1 perfomance on task 0 is 99.3519441674975
After task 1 perfomance on task 1 is 98.15886375591793
Mean accuracy: 98.77143588431021
alpha 1.0612572128392272
beta 0.9387427871607728
TASK  2


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:07<00:00,  1.78it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:20<00:00,  5.81it/s]


After task 2 perfomance on task 0 is 98.85343968095712
After task 2 perfomance on task 1 is 98.36927932667017
After task 2 perfomance on task 2 is 99.34541792547834
Mean accuracy: 98.86305786526387
alpha 0.9517639939785525
beta 1.0482360060214475
TASK  3


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:04<00:00,  1.85it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:28<00:00,  4.22it/s]


After task 3 perfomance on task 0 is 98.90329012961116
After task 3 perfomance on task 1 is 98.26407154129406
After task 3 perfomance on task 2 is 99.39577039274924
After task 3 perfomance on task 3 is 99.59225280326197
Mean accuracy: 99.04519414385742
alpha 0.9452941340595444
beta 1.0547058659404556
TASK  4


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [01:11<00:00,  1.68it/s]


CORESET TRAIN


Epochs: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 120/120 [00:35<00:00,  3.37it/s]


After task 4 perfomance on task 0 is 99.40179461615155
After task 4 perfomance on task 1 is 98.26407154129406
After task 4 perfomance on task 2 is 99.29506545820745
After task 4 perfomance on task 3 is 99.64322120285424
After task 4 perfomance on task 4 is 99.39393939393939
Mean accuracy: 99.21
alpha 0.9816060606060603
beta 1.0183939393939396
All Task Accuracies: [[99.45164506480559, 99.3519441674975, 98.85343968095712, 98.90329012961116, 99.40179461615155], [98.15886375591793, 98.36927932667017, 98.26407154129406, 98.26407154129406], [99.34541792547834, 99.39577039274924, 99.29506545820745], [99.59225280326197, 99.64322120285424], [99.39393939393939], [99.45164506480559, 98.77143588431021, 98.86305786526387, 99.04519414385742, 99.21]]
