In [2]:
from copy import deepcopy
import gzip
import pickle
from torch import nn, optim
from torch.nn import functional as F
import torch
import numpy as np
from tqdm import tqdm

In [3]:
# Run on the first CUDA device when one is present, otherwise fall back to the
# CPU so the notebook still executes on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only open a copy of mnist.pkl.gz that comes from a trusted source.
with gzip.open("mnist.pkl.gz", "rb") as f:
    train_set, val_set, test_set = pickle.load(f, encoding="latin1")

BATCH_SIZE = 512

# The training and validation splits are merged into one training set.
train_x = np.append(train_set[0], val_set[0], axis=0)
train_y = np.append(train_set[1], val_set[1], axis=0)

# From here on `train_set` / `test_set` are rebound from raw (x, y) pairs to
# TensorDatasets.
train_set = torch.utils.data.TensorDataset(torch.Tensor(train_x), torch.LongTensor(train_y))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_set = torch.utils.data.TensorDataset(torch.Tensor(test_set[0]), torch.LongTensor(test_set[1]))
# Evaluation does not depend on sample order, so the test loader is not shuffled.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

# Problem dimensions used by the model-building code in later cells.
num_features = train_x.shape[1]
# Largest label + 1, converted to a plain Python int.
num_outputs = int(train_y.max()) + 1

In [4]:
# Returns TensorDatasets (not DataLoaders) for the training and test data with
# randomly permuted feature columns.
def generate_permuted_datasets(path="mnist.pkl.gz"):
    """Load the pickled dataset and apply one random feature permutation.

    Args:
        path: Location of the gzip-compressed pickle holding the
            ``(train, val, test)`` splits, each an ``(x, y)`` pair.

    Returns:
        ``(train_set, test_set)`` as ``TensorDataset`` objects. The training
        set is the concatenation of the train and validation splits. The same
        column permutation, drawn from NumPy's global RNG, is applied to both
        the training and the test features.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only pass paths to trusted data.
    with gzip.open(path, "rb") as f:
        train_split, val_split, test_split = pickle.load(f, encoding="latin1")

    train_x = np.append(train_split[0], val_split[0], axis=0)
    train_y = np.append(train_split[1], val_split[1], axis=0)
    test_x = test_split[0]
    test_y = test_split[1]

    # Shuffle feature columns. The column count comes from the data itself
    # rather than from a notebook-level global.
    rand_idxs = np.random.permutation(train_x.shape[1])
    train_x = train_x[:, rand_idxs]
    test_x = test_x[:, rand_idxs]

    train_set = torch.utils.data.TensorDataset(torch.Tensor(train_x), torch.LongTensor(train_y))
    test_set = torch.utils.data.TensorDataset(torch.Tensor(test_x), torch.LongTensor(test_y))

    return train_set, test_set

In [5]:
class Parametrization:
    """Bundle of the four tensors describing one layer's weights and biases.

    Holds a mean tensor and a tensor stored under a ``log_*_var`` name for the
    weights and for the biases. Every tensor passed in is moved to the
    notebook-level ``device`` with ``.to(device)`` before being stored.
    """

    def __init__(self, w_mean, log_w_var, b_mean, log_b_var):
        # Move in the same order as the constructor arguments, then bind.
        on_device = [t.to(device) for t in (w_mean, log_w_var, b_mean, log_b_var)]
        self.w_mean, self.log_w_var, self.b_mean, self.log_b_var = on_device

class VCLLayer(nn.Module):
    """Linear layer with a factorised Gaussian over its weights and biases.

    The trainable parameters are the means (``w_mean``, ``b_mean``) and the
    log-variances (``log_w_var``, ``log_b_var``). ``prior`` keeps a frozen
    snapshot with the same four fields, which the enclosing model reads when
    computing its KL term.

    Args:
        input_dim: Accepted for interface symmetry; not used. The parameter
            shapes are taken from ``prior``.
        output_dim: Accepted for interface symmetry; not used.
        prior: Object exposing ``w_mean``, ``log_w_var``, ``b_mean`` and
            ``log_b_var`` tensors. They initialise both the prior snapshot and
            the trainable parameters.
    """

    _FIELDS = ("w_mean", "log_w_var", "b_mean", "log_b_var")

    def __init__(self, input_dim, output_dim, prior: "Parametrization"):
        super(VCLLayer, self).__init__()
        # Private copy of the prior, detached so that no gradients are
        # tracked for tensors that are never optimised.
        self.prior = deepcopy(prior)
        for name in self._FIELDS:
            setattr(self.prior, name, getattr(self.prior, name).detach())
        # Clone so the trainable parameters never share storage with the
        # caller's tensors (in-place optimiser steps would otherwise silently
        # overwrite them).
        self.w_mean = nn.Parameter(prior.w_mean.detach().clone())
        self.log_w_var = nn.Parameter(prior.log_w_var.detach().clone())
        self.b_mean = nn.Parameter(prior.b_mean.detach().clone())
        self.log_b_var = nn.Parameter(prior.log_b_var.detach().clone())

    def forward(self, x):
        """Return one stochastic pre-activation sample per row of ``x``.

        Uses the local reparameterization trick: for independent Gaussian
        weights the pre-activation ``x @ W + b`` is itself Gaussian with mean
        ``x @ w_mean + b_mean`` and variance ``x**2 @ w_var + b_var``, so the
        noise is drawn directly in activation space. Every row therefore gets
        independent noise. Drawing a single weight matrix per call would give
        identical outputs for identical input rows, which makes stacking
        several copies of a batch to obtain Monte Carlo samples pointless.
        """
        act_mean = torch.matmul(x, self.w_mean) + self.b_mean
        act_var = torch.matmul(x.square(), self.log_w_var.exp()) + self.log_b_var.exp()
        noise = torch.randn_like(act_mean)
        return act_mean + act_var.sqrt() * noise

    def update_priors(self):
        """Replace the prior snapshot with the current parameter values."""
        for name in self._FIELDS:
            # detach + clone: the snapshot must neither track gradients nor
            # follow later in-place updates of the parameters.
            setattr(self.prior, name, getattr(self, name).detach().clone())

    def restore_from_priors(self):
        """Reset the trainable parameters to the values in the prior snapshot.

        The values are copied in place, so the ``nn.Parameter`` objects keep
        their identity and the method works whether the snapshot holds plain
        tensors or parameters.
        """
        with torch.no_grad():
            for name in self._FIELDS:
                getattr(self, name).copy_(getattr(self.prior, name))

In [6]:
class VCLNN(nn.Module):
    """Fully connected network of ``VCLLayer`` blocks separated by ReLUs.

    Args:
        input_dim: Input width handed to the first layer.
        hidden_dim: Width handed to every intermediate layer.
        output_dim: Output width handed to the last layer.
        layer_priors: One prior per layer, in order. The number of entries
            decides the depth (three entries reproduce the original
            layer / ReLU / layer / ReLU / layer stack).
        num_samples: How many copies of each batch ``forward`` stacks.
        kl_strength: Multiplier applied to the KL term in ``loss``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, layer_priors: "list[Parametrization]", num_samples, kl_strength):
        super(VCLNN, self).__init__()
        self.num_samples = num_samples
        self.kl_strength = kl_strength
        modules = []
        last_idx = len(layer_priors) - 1
        for idx, prior in enumerate(layer_priors):
            if idx > 0:
                modules.append(nn.ReLU())
            in_dim = input_dim if idx == 0 else hidden_dim
            out_dim = output_dim if idx == last_idx else hidden_dim
            modules.append(VCLLayer(in_dim, out_dim, prior))
        self.layers = nn.Sequential(*modules)

    def _variational_layers(self):
        """Return every module of ``self.layers`` that is not a ReLU."""
        return [module for module in self.layers if not isinstance(module, nn.ReLU)]

    @staticmethod
    def _gaussian_kl(cur_means, cur_log_vars, prior_means, prior_log_vars):
        """Summed KL(current || prior) between element-wise Gaussians.

        Both distributions are given as a mean and a log-variance tensor.
        """
        var_div = prior_log_vars - cur_log_vars
        mean_div = (cur_log_vars.exp() + (prior_means - cur_means).square()) / prior_log_vars.exp()
        return 0.5 * (mean_div + var_div - 1).sum()

    def forward(self, x):
        """Stack ``num_samples`` copies of the batch and run them through the layers.

        The output has ``num_samples * len(x)`` rows, ordered copy by copy.
        """
        x = x.repeat([self.num_samples, 1, 1]).flatten(0, 1)
        return self.layers(x)

    def loss(self, predictions, targets):
        """Cross-entropy plus the scaled KL divergence to the layer priors.

        Args:
            predictions: Logits as returned by ``forward``.
            targets: Class indices, one per row of ``predictions``.

        Returns:
            ``cross_entropy + kl_strength * KL / len(targets)``.
        """
        pred_loss = F.cross_entropy(predictions, targets)
        # Compute KL divergence over the weights and biases of every layer.
        kl_div = 0.0
        for cur_layer in self._variational_layers():
            kl_div = kl_div + self._gaussian_kl(
                cur_layer.w_mean, cur_layer.log_w_var,
                cur_layer.prior.w_mean, cur_layer.prior.log_w_var)
            kl_div = kl_div + self._gaussian_kl(
                cur_layer.b_mean, cur_layer.log_b_var,
                cur_layer.prior.b_mean, cur_layer.prior.log_b_var)

        # NOTE(review): the KL term is divided by the number of target rows
        # (which includes the num_samples replication), not by the size of the
        # training set - confirm this weighting is intended.
        return pred_loss + self.kl_strength * kl_div / (len(targets))

    def update_priors(self):
        """Ask every layer to snapshot its current parameters as its prior."""
        for layer in self._variational_layers():
            layer.update_priors()

    def restore_from_priors(self):
        """Ask every layer to reset its parameters from its prior snapshot."""
        for layer in self._variational_layers():
            layer.restore_from_priors()

In [7]:
class BaseLayer(nn.Module):
    """Deterministic linear layer initialised from the means of a prior.

    Args:
        input_dim: Accepted for interface symmetry; not used. The parameter
            shapes are taken from ``prior``.
        output_dim: Accepted for interface symmetry; not used.
        prior: Object exposing ``w_mean`` and ``b_mean`` tensors that provide
            the initial weight matrix and bias.
    """

    def __init__(self, input_dim, output_dim, prior: "Parametrization"):
        super(BaseLayer, self).__init__()
        # Clone so that training this layer does not overwrite the caller's
        # prior tensors through shared storage.
        self.w = nn.Parameter(prior.w_mean.detach().clone())
        self.b = nn.Parameter(prior.b_mean.detach().clone())

    def forward(self, x):
        """Return ``x @ w + b``."""
        return torch.matmul(x, self.w) + self.b

class BaseNN(nn.Module):
    """Deterministic network of ``BaseLayer`` blocks separated by ReLUs.

    Args:
        input_dim: Input width handed to the first layer.
        hidden_dim: Width handed to every intermediate layer.
        output_dim: Output width handed to the last layer.
        layer_priors: One prior per layer, in order; its means initialise the
            layer. The number of entries decides the depth (three entries
            reproduce the original layer / ReLU / layer / ReLU / layer stack).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, layer_priors: "list[Parametrization]"):
        super(BaseNN, self).__init__()
        modules = []
        last_idx = len(layer_priors) - 1
        for idx, prior in enumerate(layer_priors):
            if idx > 0:
                modules.append(nn.ReLU())
            in_dim = input_dim if idx == 0 else hidden_dim
            out_dim = output_dim if idx == last_idx else hidden_dim
            modules.append(BaseLayer(in_dim, out_dim, prior))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the logits."""
        return self.layers(x)

    def loss(self, predictions, targets):
        """Cross-entropy between the logits and the integer class targets."""
        return F.cross_entropy(predictions, targets)

    def get_weights(self, init_variance):
        """Export the current weights as one ``Parametrization`` per layer.

        Args:
            init_variance: Constant that fills the tensors passed as the
                ``log_w_var`` / ``log_b_var`` arguments of ``Parametrization``.

        Returns:
            A list with one ``Parametrization`` per non-ReLU layer. The means
            are detached copies of the layer's ``w`` and ``b``, so they neither
            track gradients nor change if this model is trained further.
        """
        params = []
        for cur_l in self.layers:
            if isinstance(cur_l, nn.ReLU):
                continue
            w_m = cur_l.w.detach().clone()
            b_m = cur_l.b.detach().clone()
            w_v = torch.zeros(w_m.shape) + init_variance
            b_v = torch.zeros(b_m.shape) + init_variance
            params.append(Parametrization(w_m, w_v, b_m, b_v))
        return params

In [8]:
def generate_weights(input_dim, hidden_dim, output_dim, num_layers, init_variance):
    """Create randomly initialised ``Parametrization`` objects for a layer stack.

    The first layer takes ``input_dim`` inputs and the last one produces
    ``output_dim`` outputs; every other dimension is ``hidden_dim``.

    Args:
        input_dim: Number of inputs of the first layer.
        hidden_dim: Width of the intermediate layers.
        output_dim: Number of outputs of the last layer.
        num_layers: How many ``Parametrization`` objects to create.
        init_variance: Constant filling the tensors passed as the second and
            fourth constructor arguments (``log_w_var`` / ``log_b_var``).

    Returns:
        A list of ``num_layers`` ``Parametrization`` objects whose means are
        drawn from a zero-mean normal with standard deviation 0.1.
    """
    final_idx = num_layers - 1
    layer_params = []
    for layer_idx in range(num_layers):
        fan_in = input_dim if layer_idx == 0 else hidden_dim
        fan_out = output_dim if layer_idx == final_idx else hidden_dim

        # Constant-filled spread tensors for the weights and the biases.
        w_spread = torch.zeros(fan_in, fan_out) + init_variance
        b_spread = torch.zeros(fan_out) + init_variance

        # Draw the weight means first and the bias means second.
        w_init = torch.normal(mean=torch.zeros(fan_in, fan_out), std=0.1)
        b_init = torch.normal(mean=torch.zeros(fan_out), std=0.1)

        layer_params.append(Parametrization(w_init, w_spread, b_init, b_spread))
    return layer_params

In [9]:
def train(model, data_loader, loss_fn, epochs, learning_rate=0.001, sampling=False):
    """Optimise ``model`` with Adam for a fixed number of epochs.

    Args:
        model: Module to train; it is switched to training mode.
        data_loader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning the
            scalar to minimise.
        epochs: Number of full passes over ``data_loader``.
        learning_rate: Adam step size.
        sampling: When true, the targets are tiled ``model.num_samples`` times
            so they line up with a model whose forward pass stacks that many
            copies of the batch.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Batches are moved to wherever the model's parameters live, so the
    # function does not depend on a notebook-level device variable.
    model_device = next(model.parameters()).device
    model.train()
    for _ in range(epochs):
        for inputs, targets in data_loader:
            inputs = inputs.to(model_device)
            targets = targets.to(model_device)
            if sampling:
                targets = targets.repeat([model.num_samples, 1]).flatten()
            optimizer.zero_grad()
            preds = model(inputs)
            loss = loss_fn(preds, targets)
            loss.backward()
            optimizer.step()

def test(model, data_loader, sampling=False):
    model.eval()
    hits = 0
    num_samples = 1
    if sampling:
        num_samples = model.num_samples
    for inputs, targets in data_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        if sampling:
            targets = targets.repeat([model.num_samples, 1]).flatten()
        preds = model(inputs)
        preds = F.softmax(model(inputs), dim=1)
        class_preds = preds.argmax(dim=1)
        hits += (class_preds == targets).sum()
    return hits/(len(data_loader.dataset) * num_samples)

In [10]:
def split_coreset(dataset, coreset_size):
    """Randomly split ``dataset`` into a remainder and a coreset.

    Args:
        dataset: A ``TensorDataset`` (or any dataset that returns a tuple of
            tensors when indexed with an index tensor).
        coreset_size: Number of samples to move into the coreset. Values
            larger than ``len(dataset)`` put every sample in the coreset.

    Returns:
        ``(remainder_set, core_set)`` as two ``TensorDataset`` objects holding
        the same tensors as ``dataset``, in the same order.

    Raises:
        ValueError: If ``coreset_size`` is negative.
    """
    if coreset_size < 0:
        raise ValueError("coreset_size must be non-negative")
    perm_idxs = torch.randperm(len(dataset))
    coreset_idxs = perm_idxs[:coreset_size]
    remainder_idxs = perm_idxs[coreset_size:]
    # Index the dataset once per split and forward every tensor it holds.
    remainder_set = torch.utils.data.TensorDataset(*dataset[remainder_idxs])
    core_set = torch.utils.data.TensorDataset(*dataset[coreset_idxs])
    return remainder_set, core_set

def perform_vcl(num_tasks, num_epochs, coreset_size=0, num_samples=100, init_variance=-6.0, pre_training=True, kl_strength=1.0, hidden_dim=100):
    """Run one continual-learning experiment over a sequence of permuted tasks.

    Each task is a fresh random feature permutation of the dataset. For every
    task the model is trained on the task data (minus an optional coreset),
    its priors are updated, it is optionally fine-tuned on all coresets seen
    so far, evaluated on every task seen so far, and finally reset to the
    priors taken before the coreset fine-tuning.

    Args:
        num_tasks: Number of permuted tasks to learn in sequence.
        num_epochs: Epochs for every training phase (pre-training, task
            training and coreset training).
        coreset_size: Samples withheld per task as a coreset; 0 disables
            coresets.
        num_samples: Forwarded to ``VCLNN``.
        init_variance: Forwarded to ``generate_weights`` and, with
            pre-training, to ``BaseNN.get_weights``.
        pre_training: When true, a ``BaseNN`` is first trained on the first
            task and its weights initialise the ``VCLNN`` priors.
        kl_strength: Forwarded to ``VCLNN``.
        hidden_dim: Width of the hidden layers.

    Returns:
        A ``(num_tasks, num_tasks + 1)`` tensor of accuracies in percent. Row
        ``t`` holds, after learning task ``t``, the accuracy on all seen test
        sets combined (column 0) and on each seen task ``i`` (column
        ``i + 1``); columns of tasks not seen yet stay 0.
    """
    train_set, test_set = generate_permuted_datasets()
    test_sets = [test_set]

    priors = generate_weights(num_features, hidden_dim, num_outputs, 3, init_variance)
    if pre_training:
        base_model = BaseNN(num_features, hidden_dim, num_outputs, layer_priors=priors).to(device)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
        # Accuracy does not depend on sample order, so evaluation loaders are
        # not shuffled.
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)
        train(base_model, train_loader, base_model.loss, epochs=num_epochs)
        acc = test(base_model, test_loader)
        print(f"Base model acc: {acc*100:.4}%")

        priors = base_model.get_weights(init_variance)

    vclm = VCLNN(num_features, hidden_dim, num_outputs, layer_priors=priors, num_samples=num_samples, kl_strength=kl_strength).to(device)

    task_results = torch.zeros((num_tasks, num_tasks+1))
    coresets = []
    for t in range(num_tasks):
        # The first task reuses the data generated above; later tasks draw a
        # new permutation here, so nothing is generated after the last task.
        if t > 0:
            train_set, test_set = generate_permuted_datasets()
            test_sets.append(test_set)

        if coreset_size > 0:
            train_set, coreset = split_coreset(train_set, coreset_size)
            coresets.append(coreset)

        train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
        train(vclm, train_loader, vclm.loss, epochs=num_epochs, sampling=True)
        accs = torch.zeros((num_tasks+1,))

        # Snapshot the priors, then train the model on the coresets if used
        vclm.update_priors()
        if coreset_size > 0:
            coreset_dataset = torch.utils.data.ConcatDataset(coresets)
            coreset_loader = torch.utils.data.DataLoader(coreset_dataset, batch_size=BATCH_SIZE, shuffle=True)
            train(vclm, coreset_loader, vclm.loss, epochs=num_epochs, sampling=True)

        # Test all tasks together
        combined_test_set = torch.utils.data.ConcatDataset(test_sets)
        test_loader = torch.utils.data.DataLoader(combined_test_set, batch_size=BATCH_SIZE, shuffle=False)
        accs[0] = test(vclm, test_loader, sampling=True).item()*100.0
        # Test each task one by one
        for i, seen_test_set in enumerate(test_sets):
            test_loader = torch.utils.data.DataLoader(seen_test_set, batch_size=BATCH_SIZE, shuffle=False)
            accs[i+1] = test(vclm, test_loader, sampling=True).item()*100.0
        print(f"Task {t} accuracies: {accs}")
        task_results[t] = accs
        # After testing, restore model weights from the priors to those not trained via coresets
        vclm.restore_from_priors()

    return task_results


In [11]:
def average_runs(num_runs, num_tasks, num_epochs, coreset_size, num_samples=100, init_variance=-6.0, pre_training=True, kl_strength=1.0):
    """Repeat ``perform_vcl`` several times and average the combined accuracy.

    Args:
        num_runs: Number of independent repetitions; must be at least 1.
        num_tasks, num_epochs, coreset_size, num_samples, init_variance,
        pre_training, kl_strength: Forwarded unchanged to ``perform_vcl``.

    Returns:
        ``(mean_acc, all_results)`` where ``all_results`` stacks the result of
        every run along a new first dimension and ``mean_acc`` is the mean
        over runs of column 0 of each run's result.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")

    per_run = [
        perform_vcl(num_tasks, num_epochs, coreset_size, num_samples=num_samples, pre_training=pre_training, init_variance=init_variance, kl_strength=kl_strength)
        for _ in range(num_runs)
    ]
    # Stack once instead of growing a tensor with torch.cat inside the loop.
    vcl_results = torch.stack(per_run)

    return vcl_results[:,:,0].mean(dim=0), vcl_results

In [13]:
# Single benchmark: 10 runs of 10 tasks, 30 epochs per training phase and a
# coreset of 200 samples per task; prints the per-task mean of the combined
# accuracy across runs.
acc, all_res = average_runs(
    num_runs=10,
    num_tasks=10,
    num_epochs=30,
    coreset_size=200,
    num_samples=100,
    init_variance=-6.0,
    pre_training=True,
    kl_strength=1.0,
)
print(acc)