In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from tqdm.auto import tqdm

sns.set_theme(style="whitegrid")

# Bayesian Neural Network

In [None]:
def kl_normal(mu1, sigma1, mu2, sigma2):
    """Elementwise KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ).

    All arguments are tensors (broadcastable); ``sigma1``/``sigma2`` are standard deviations.
    """
    var1 = sigma1 ** 2
    var2 = sigma2 ** 2
    variance_ratio = var1 / var2
    mean_shift = (mu1 - mu2) ** 2 / var2
    log_term = torch.log(var2 / var1)
    return 0.5 * (variance_ratio + mean_shift - 1 + log_term)

def extrude(bs, x):
    """Broadcast ``x`` along a new leading dimension of size ``bs`` (a view; no data is copied)."""
    target_shape = (bs,) + tuple(x.shape)
    return x.unsqueeze(0).expand(*target_shape)

In [None]:
class BayesianLinear(nn.Module):
    """Linear layer with a factorised Gaussian variational posterior over its weights.

    When ``bias=True`` the bias is folded into the weight matrix as an extra input row: the
    input gets a trailing column of ones and every weight tensor has shape
    ``(in_features + 1, out_features)``. Each entry has a trainable mean (``weight_mean``) and
    log-variance (``weight_logvar``) and a Gaussian prior held in the buffers
    ``prior_mean`` / ``prior_logvar``. The prior given at construction is also kept in
    ``original_prior_mean`` / ``original_prior_logvar`` for :meth:`strengthen`.

    Args:
        in_features: size of each input sample (excluding the bias column).
        out_features: size of each output sample.
        prior_weight_mean: scalar prior mean of every weight.
        prior_weight_logvar: scalar prior log-variance of every weight.
        prior_bias_mean: scalar prior mean of the bias row; required when ``bias=True``.
        prior_bias_logvar: scalar prior log-variance of the bias row; required when ``bias=True``.
        bias: whether to add a bias row.
        init_weight_var: initial posterior variance of every weight and bias entry.
    """

    def __init__(self, in_features, out_features, prior_weight_mean, prior_weight_logvar, prior_bias_mean=None, prior_bias_logvar=None, *, bias=True, init_weight_var=1e-4):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.has_bias = bias

        n_rows = in_features + int(self.has_bias)
        # Weights are laid out as (inputs, outputs), the transpose of nn.Linear, so the true
        # fan-in is dimension 0, which kaiming_normal_ calls "fan_out".
        self.weight_mean = nn.Parameter(nn.init.kaiming_normal_(torch.Tensor(n_rows, out_features), mode="fan_out"))
        self.weight_logvar = nn.Parameter(torch.Tensor(n_rows, out_features).fill_(np.log(init_weight_var)))

        prior_mean = torch.Tensor(in_features, out_features).fill_(prior_weight_mean)
        prior_logvar = torch.Tensor(in_features, out_features).fill_(prior_weight_logvar)
        if bias:
            assert prior_bias_mean is not None and prior_bias_logvar is not None
            prior_mean = torch.cat((prior_mean, torch.Tensor(1, out_features).fill_(prior_bias_mean)))
            # The bias row gets its own prior log-variance (not the posterior's initial variance).
            prior_logvar = torch.cat((prior_logvar, torch.Tensor(1, out_features).fill_(prior_bias_logvar)))
        # Current prior; replaced by update_prior() after each task.
        self.register_buffer("prior_mean", prior_mean.clone())
        self.register_buffer("prior_logvar", prior_logvar.clone())
        # Construction-time prior; strengthen() always tempers against this one.
        self.register_buffer("original_prior_mean", prior_mean.clone())
        self.register_buffer("original_prior_logvar", prior_logvar.clone())

    def strengthen(self, t):
        """Replace the posterior q by the tempered Gaussian q^t * p0^(1-t), in place.

        ``p0`` is the construction-time prior. If q is proportional to p0 times the
        likelihood, this is p0 times the likelihood raised to ``t``. For each entry, with
        posterior N(mu_1, s1^2) and prior N(mu_2, s2^2):

        * s1^2 is first clipped to at most s2^2.
        * The result has precision t/s1^2 + (1-t)/s2^2 and the matching precision-weighted
          mean. With the clip, the precision stays positive for any t >= 0.
        * The resulting variance is capped at 100, and replaced by 100 if it is not
          positive.

        ``t = 1`` keeps the mean and only applies the variance clip.
        """
        print("strengthening")
        alpha = t
        beta = 1 - t
        with torch.no_grad():
            mu_1 = self.weight_mean
            sigma_1_sq = torch.exp(self.weight_logvar)
            mu_2 = self.original_prior_mean
            sigma_2_sq = torch.exp(self.original_prior_logvar)
            sigma_1_sq = torch.min(sigma_1_sq, sigma_2_sq)
            denom = alpha * sigma_2_sq + beta * sigma_1_sq
            mu = (alpha * mu_1 * sigma_2_sq + beta * mu_2 * sigma_1_sq) / denom
            sigma_sq = (sigma_1_sq * sigma_2_sq) / denom
            sigma_sq = torch.where(sigma_sq > 0, torch.clamp(sigma_sq, max=100.0), torch.ones_like(sigma_sq) * 100.0)
            # Write in place so the Parameter objects (and any references to them) survive.
            self.weight_mean.copy_(mu)
            self.weight_logvar.copy_(torch.log(sigma_sq))

    def update_prior(self):
        """Lock in the current posterior as the prior for the next task (detached copies)."""
        self.prior_mean = self.weight_mean.detach().clone()
        self.prior_logvar = self.weight_logvar.detach().clone()

    def forward(self, x, sample=True):
        """Apply the layer to a batch ``x`` of shape (batch, in_features).

        With ``sample=True``, an independent weight matrix is drawn for every example via the
        reparameterisation ``mean + eps * std``. With ``sample=False``, the posterior mean
        weights are used. Returns a (batch, out_features) tensor.
        """
        if self.has_bias:
            ones = torch.ones(x.shape[0], 1, device=x.device)
            x = torch.cat((x, ones), dim=1)
        if sample:
            bs = x.shape[0]
            weight_eps = torch.randn((bs, self.in_features + int(self.has_bias), self.out_features), device=x.device)
            # (in, out) mean and std broadcast against the (batch, in, out) noise.
            weight = self.weight_mean + weight_eps * torch.exp(0.5 * self.weight_logvar)
            return torch.einsum("bij,bi->bj", weight, x)
        weight = self.weight_mean
        return torch.einsum("ij,bi->bj", weight, x)

    def kl_divergence(self):
        """Sum over all entries of KL(posterior || current prior)."""
        return kl_normal(self.weight_mean, torch.exp(0.5 * self.weight_logvar), self.prior_mean, torch.exp(0.5 * self.prior_logvar)).sum()


In [None]:
class BNN(nn.Module):
    """Sequential container mixing ``BayesianLinear`` layers with ordinary modules.

    ``BayesianLinear`` layers receive the ``sample`` flag; every other layer is called with
    the activations only. Subclasses must implement :meth:`predict` and :meth:`loglik`.
    """

    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def _bayesian_layers(self):
        """Return an iterator over the ``BayesianLinear`` layers, in order."""
        return (layer for layer in self.layers if isinstance(layer, BayesianLinear))

    def forward(self, x, sample=True):
        """Run ``x`` through all layers, forwarding ``sample`` to the Bayesian ones."""
        for layer in self.layers:
            if isinstance(layer, BayesianLinear):
                x = layer(x, sample=sample)
            else:
                x = layer(x)
        return x

    def strengthen(self, t):
        """Call ``strengthen(t)`` on every Bayesian layer."""
        for layer in self._bayesian_layers():
            layer.strengthen(t)

    def predict(self, x, sample=False):
        """Map inputs to predictions; must be overridden by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement predict()")

    def loglik(self, output, y):
        """Per-example log-likelihood of ``y`` given network ``output``; must be overridden."""
        raise NotImplementedError(f"{type(self).__name__} must implement loglik()")

    def kl_divergence(self):
        """Total KL(posterior || prior) summed over all Bayesian layers (0.0 if there are none)."""
        kl = 0.0
        for layer in self._bayesian_layers():
            kl += layer.kl_divergence()
        return kl

    def update_prior(self):
        """Make every Bayesian layer's current posterior its new prior."""
        for layer in self._bayesian_layers():
            layer.update_prior()


### Testing/Debugging

In [None]:
layer = BayesianLinear(2, 4, 0, 0, 1, 0)

# Only a cell's last expression is rendered, so show the intermediate checks explicitly.
display(layer.prior_mean)
display(layer.kl_divergence())

layer.forward(torch.randn(2, 2), sample=True)

In [None]:
weight_mean, weight_logvar = 0, -2
bias_mean, bias_logvar = 0, -2

# Three BayesianLinear+ReLU hidden blocks (2 -> 100 -> 100 -> 100), then a Bayesian read-out to 1.
network = BNN(
    [
        module
        for n_in, n_out in ((2, 100), (100, 100), (100, 100))
        for module in (BayesianLinear(n_in, n_out, weight_mean, weight_logvar, bias_mean, bias_logvar), nn.ReLU())
    ]
    + [BayesianLinear(100, 1, weight_mean, weight_logvar, bias_mean, bias_logvar)]
)

network.forward(torch.randn(10, 2), sample=False)

In [None]:
# Total KL(posterior || prior) summed over the toy network's BayesianLinear layers.
network.kl_divergence()

In [None]:
# Sanity check: detach().clone() yields an independent copy, so editing it leaves `a` untouched.
a = torch.Tensor([[1, 2], [3, 4]]).requires_grad_()
b = a.detach().clone()
b[0][0] = 3
assert a[0, 0].item() == 1 and b[0, 0].item() == 3
a

# Load data

In [None]:
import torch
from torchvision import datasets, transforms

class MNISTData:
    """MNIST flattened to (N, dim) float tensors, with a seeded pixel permutation per task.

    ``train_X``/``test_X`` hold the images converted with ``transforms.ToTensor`` and
    flattened. ``train_y``/``test_y`` hold the integer labels. Tensors stay on the CPU and
    are moved to ``self.device`` (set via :meth:`to`) only when a permuted split is
    requested.
    """

    def __init__(self):
        transform = transforms.Compose([
            transforms.ToTensor()
        ])

        train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

        train_images = torch.stack([img for img, _ in train_dataset])
        self.train_y = torch.tensor([label for _, label in train_dataset])

        test_images = torch.stack([img for img, _ in test_dataset])
        self.test_y = torch.tensor([label for _, label in test_dataset])

        self.train_X = train_images.view(train_images.size(0), -1)
        self.test_X = test_images.view(test_images.size(0), -1)

        self.dim = self.train_X.size(1)
        self.num_classes = 10

        self.device = "cpu"

    def generate_permutation(self, seed):
        """Return the pixel permutation for ``seed`` (deterministic in ``seed``).

        A private generator is used, so this yields the same permutation as
        ``torch.manual_seed(seed); torch.randperm(dim)``. It does not reset the global RNG
        that weight initialisation, weight sampling and data shuffling draw from.
        """
        generator = torch.Generator().manual_seed(seed)
        return torch.randperm(self.dim, generator=generator)

    def permuted_train(self, seed):
        """Training split with pixels permuted by ``generate_permutation(seed)``, on ``self.device``."""
        perm = self.generate_permutation(seed)
        return self.train_X[:, perm].to(self.device), self.train_y.to(self.device)

    def permuted_test(self, seed):
        """Test split with the same permutation as ``permuted_train(seed)``, on ``self.device``."""
        perm = self.generate_permutation(seed)
        return self.test_X[:, perm].to(self.device), self.test_y.to(self.device)

    def to(self, device):
        """Set the device that permuted splits are moved to; returns ``self`` for chaining."""
        self.device = device
        return self


class MNISTDataRegression(MNISTData):
    """MNIST variant whose labels are one-hot vectors, for Gaussian regression on the labels."""

    def __init__(self):
        super().__init__()
        # Swap the integer labels loaded by MNISTData for their one-hot encodings.
        self.train_y, self.test_y = map(self._to_one_hot, (self.train_y, self.test_y))

    def _to_one_hot(self, labels):
        """Encode a 1-D integer label tensor as a (len(labels), num_classes) default-dtype matrix."""
        return F.one_hot(labels, num_classes=self.num_classes).to(torch.get_default_dtype())

### Testing/Debugging

In [None]:
# mnist_data = MNISTData()

In [None]:
# mnist_data.train_X.shape

In [None]:
# mnist_data.permuted_train(0)

In [None]:
# plt.imshow(mnist_data.train_X[0].reshape(28, 28), cmap="gray")

In [None]:
# plt.imshow(mnist_data.permuted_train(0)[0][0].reshape(28, 28), cmap="gray")

In [None]:
# mnist_data_regression = MNISTDataRegression()

In [None]:
# mnist_data_regression.permuted_train(0)

# Training loop

In [None]:
# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Load MNIST once (downloading to ./data if needed); permuted splits are moved to `device` on request.
d = MNISTData().to(device)

In [None]:
def loss_func(f):
    """Decorator adding a ``reduction`` keyword ("mean", "sum" or "none") to an elementwise loss.

    The wrapped ``f`` returns a tensor of per-element losses. The wrapper reduces it with
    ``torch.mean``/``torch.sum``, or returns it unchanged for "none". Raises ``ValueError``
    for any other ``reduction``, before ``f`` is evaluated.
    """
    import functools

    reducers = {"mean": torch.mean, "sum": torch.sum, "none": lambda loss: loss}

    @functools.wraps(f)
    def _loss_func(*args, reduction="mean", **kwargs):
        if reduction not in reducers:
            raise ValueError(f"invalid reduction {reduction}")
        return reducers[reduction](f(*args, **kwargs))
    return _loss_func

@loss_func
def bnn_loss_func(log_likelihood, bnn, nt):
    """Per-datapoint negative ELBO for a minibatch.

    Over a dataset of ``nt`` points the ELBO is
    ``sum_{i=1}^{nt} E_q[log p(y_i | theta, x_i)] - KL(q || prior)``. Spreading the KL evenly
    over the data gives each example the loss ``-log_likelihood_i + KL / nt``.

    Args:
        log_likelihood: per-example log p(y_i | theta, x_i) for the batch.
        bnn: model exposing ``kl_divergence()``.
        nt: number of datapoints in the full training set of the current task.

    Returns:
        Per-example losses, reduced according to the ``reduction`` keyword (default "mean").
    """
    return -log_likelihood + bnn.kl_divergence() / nt


In [None]:
# torch.Tensor([1.0, 2.0]) + 3.0

In [None]:
def train_bnn(bnn, optim, train_data, test_data, *, train_batch_size, test_batch_size, metric,
              n_epochs=10, verbose=True, train_variational=True):
    """Train ``bnn`` on ``train_data`` for ``n_epochs``, evaluating on ``test_data`` after each epoch.

    Args:
        bnn: model exposing ``forward(x, sample)``, ``loglik``, ``predict(x, sample)`` and
            ``kl_divergence``.
        optim: optimizer over ``bnn``'s parameters; stepped once per training batch.
        train_data, test_data: tuples of tensors passed to ``TensorDataset``, i.e. ``(X, y)``.
        train_batch_size, test_batch_size: minibatch sizes; training batches are shuffled.
        metric: ``metric(preds, y) -> number``. It is summed over test batches and divided
            by the number of test points.
        n_epochs: number of passes over ``train_data``.
        verbose: print a one-line summary after every epoch.
        train_variational: if True, sample weights and minimise the per-datapoint negative
            ELBO (``bnn_loss_func``). If False, use the mean weights and minimise the mean
            negative log-likelihood only.

    Returns:
        ``bnn``, trained in place.
    """
    train_dataset = TensorDataset(*train_data)
    train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)

    test_dataset = TensorDataset(*test_data)
    test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

    nt = len(train_dataset)

    progress_bar = tqdm(total=len(train_dataloader))

    for epoch in range(n_epochs):
        # Training phase
        progress_bar.reset(total=len(train_dataloader))
        progress_bar.set_description(f"Training epoch {epoch}/{n_epochs}")

        bnn.train()
        train_batch_count = 0
        train_epoch_loss = 0
        for X_batch, y_batch in train_dataloader:
            # Sampled weights for the variational objective; mean weights for the MLE warm-start.
            model_output = bnn.forward(X_batch, sample=train_variational)
            log_likelihood = bnn.loglik(model_output, y_batch)

            if train_variational:
                loss = bnn_loss_func(log_likelihood, bnn, nt)
            else:
                loss = -torch.mean(log_likelihood)
            loss.backward()
            optim.step()
            optim.zero_grad()

            # Progress-bar diagnostics only; no autograd graph needed.
            with torch.no_grad():
                logprob = torch.mean(log_likelihood).item()
                kl_penalty = bnn.kl_divergence().item() / nt

            train_epoch_loss += loss.item()
            train_batch_count += 1
            progress_bar.set_postfix(loss=loss.item(), avg_loss=train_epoch_loss/train_batch_count, logprob=logprob, kl_penalty=kl_penalty)
            progress_bar.update()
        train_loss = train_epoch_loss / train_batch_count if train_batch_count else float("nan")

        # Testing phase
        progress_bar.reset(total=len(test_dataloader))
        progress_bar.set_description(f"Testing epoch {epoch}/{n_epochs}")

        bnn.eval()
        total_datapoints = 0
        total_metric = 0
        with torch.no_grad():
            for X_batch, y_batch in test_dataloader:
                preds = bnn.predict(X_batch, sample=False)
                total_metric += metric(preds, y_batch)
                total_datapoints += len(y_batch)
                progress_bar.set_postfix(metric=total_metric/total_datapoints)
                progress_bar.update()
        test_acc = total_metric / total_datapoints if total_datapoints else float("nan")
        if verbose:
            print(f"Epoch {epoch}: training loss: {train_loss:.4f}, test metric: {test_acc:.4f}")

    progress_bar.close()
    return bnn

In [None]:
def eval_bnn(bnn, test_data, *, test_batch_size, metric):
    """Average ``metric`` of ``bnn``'s mean-weight predictions over ``test_data``.

    Args:
        bnn: model exposing ``predict(x, sample)`` and ``eval()``.
        test_data: tuple of tensors passed to ``TensorDataset``, i.e. ``(X, y)``.
        test_batch_size: evaluation minibatch size.
        metric: ``metric(preds, y) -> number``, summed over batches.

    Returns:
        Summed metric divided by the number of test points.
    """
    test_dataset = TensorDataset(*test_data)
    test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
    bnn.eval()
    total_datapoints = 0
    total_metric = 0
    # Pure evaluation: skip building autograd graphs.
    with torch.no_grad():
        for X_batch, y_batch in test_dataloader:
            preds = bnn.predict(X_batch, sample=False)
            total_metric += metric(preds, y_batch)
            total_datapoints += len(y_batch)
    test_acc = total_metric/total_datapoints
    return test_acc

In [None]:
def train_permuted_mnist(
    bnn, optim, mnist_data: MNISTData,
    coreset_size: int,
    *,
    train_batch_size=256, test_batch_size=1000,
    n_epochs=100, start_epochs=None, n_tasks=10,
    coreset_only=False,
    metric,
    use_strengthening=False,
    strengthen_ratio=1.0,
    telemetry=False,
):
    """Run variational continual learning with a random coreset on permuted MNIST.

    For each task ``t`` (pixel permutation seeded with ``t``):

    1. Task 0 only: warm-start the weight means by maximum-likelihood training
       (``train_variational=False``) for ``n_epochs``.
    2. Move ``coreset_size`` random training points into the coreset of all tasks so far.
    3. Unless ``coreset_only``, train on the remaining points against the current prior
       (``start_epochs`` epochs on task 0, ``n_epochs`` afterwards).
    4. Snapshot the model, make the posterior the prior, and fine-tune on the coreset.
       Evaluate on the union of all test sets and on each task's test set, then restore the
       snapshot so the coreset fine-tune is used only for prediction.
    5. Optionally ``strengthen(strengthen_ratio)``, then make the posterior the prior for the
       next task.

    ``optim`` is a zero-argument factory returning a fresh optimizer; it is called for every
    training phase. With ``telemetry=True``, ``bnn.run_telemetry`` is called after the
    coreset phase.

    Returns:
        DataFrame with one row per task: ``n_tasks``, ``acc`` (metric on all test sets seen
        so far) and ``task_k`` (metric on task ``k``'s test set).
    """
    import copy

    if start_epochs is None:
        start_epochs = n_epochs
    dim = mnist_data.dim

    test_X = torch.Tensor(0, dim).to(device)
    test_y = torch.Tensor(0).long().to(device)
    coreset_X = torch.Tensor(0, dim).to(device)
    coreset_y = torch.Tensor(0).long().to(device)

    records = []

    progress_bar = tqdm(range(n_tasks))
    for task in progress_bar:
        train_X_task, train_y_task = mnist_data.permuted_train(task)
        test_X_task, test_y_task = mnist_data.permuted_test(task)

        if task == 0:
            # initialise variational parameters (NOT priors) by training on means only
            train_bnn(bnn, optim(), (train_X_task, train_y_task), (test_X_task, test_y_task), train_batch_size=train_batch_size, test_batch_size=test_batch_size, n_epochs=n_epochs, verbose=False, train_variational=False, metric=metric)

        # Split this task's training data into coreset and non-coreset points.
        coreset_indices = np.random.choice(len(train_X_task), size=coreset_size, replace=False)
        non_coreset_indices = np.setdiff1d(np.arange(len(train_X_task)), coreset_indices)
        coreset_X = torch.cat((coreset_X, train_X_task[coreset_indices]))
        coreset_y = torch.cat((coreset_y, train_y_task[coreset_indices]))
        train_X_task = train_X_task[non_coreset_indices]
        train_y_task = train_y_task[non_coreset_indices]
        test_X = torch.cat((test_X, test_X_task))
        test_y = torch.cat((test_y, test_y_task))

        progress_bar.set_description("Training non-coreset")
        if not coreset_only and train_X_task.size(0) > 0:
            epochs = n_epochs if task != 0 else start_epochs
            train_bnn(bnn, optim(), (train_X_task, train_y_task), (test_X, test_y), train_batch_size=train_batch_size, test_batch_size=test_batch_size, n_epochs=epochs, verbose=False, metric=metric)

        # state_dict() holds references to the live tensors. Without a deep copy, the in-place
        # optimizer updates during coreset training would also overwrite the snapshot, and
        # load_state_dict() below would not undo the coreset fine-tune.
        state_dict = copy.deepcopy(bnn.state_dict())

        progress_bar.set_description("Training coreset")
        if coreset_X.size(0) > 0:
            bnn.update_prior()
            train_bnn(bnn, optim(), (coreset_X, coreset_y), (test_X, test_y), train_batch_size=train_batch_size, test_batch_size=test_batch_size, n_epochs=n_epochs, verbose=False, metric=metric)

        if telemetry:
            bnn.run_telemetry(task, train_X_task, train_y_task, len(train_X_task))

        all_task_acc = eval_bnn(bnn, (test_X, test_y), test_batch_size=test_batch_size, metric=metric)
        record = {"n_tasks": task + 1, "acc": all_task_acc}
        for seen_task in range(task + 1):
            record[f"task_{seen_task}"] = eval_bnn(bnn, mnist_data.permuted_test(seen_task), test_batch_size=test_batch_size, metric=metric)
        records.append(record)

        # Drop the coreset fine-tune: propagate the pre-coreset posterior to the next task.
        bnn.load_state_dict(state_dict)

        if use_strengthening:
            bnn.strengthen(strengthen_ratio)
        bnn.update_prior()

    if not records:
        return pd.DataFrame(columns=["n_tasks", "acc"]).astype({"n_tasks": int, "acc": float})
    return pd.DataFrame(records)

In [None]:
weight_mean, weight_logvar = 0, 0
bias_mean, bias_logvar = 0, 0

hidden_size = 100
class MnistBNN(BNN):
    """Bayesian MLP classifier for MNIST: three ReLU hidden layers and a log-softmax output.

    Layer sizes and priors are read from the notebook globals ``d``, ``hidden_size``,
    ``weight_mean``, ``weight_logvar``, ``bias_mean`` and ``bias_logvar`` when an instance
    is created.
    """

    def __init__(self):
        prior = (weight_mean, weight_logvar, bias_mean, bias_logvar)
        widths = [d.dim, hidden_size, hidden_size, hidden_size]
        modules = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            modules.extend([BayesianLinear(n_in, n_out, *prior), nn.ReLU()])
        modules.extend([BayesianLinear(hidden_size, d.num_classes, *prior), nn.LogSoftmax(dim=1)])
        super().__init__(nn.ModuleList(modules))

    def predict(self, x, sample=False):
        """Return the arg-max class index for each row of ``x``."""
        log_probs = self.forward(x, sample=sample)
        return log_probs.argmax(dim=1)

    def loglik(self, output, y):
        """Pick each example's log-probability of its label (assumes ``y`` holds 1-D class indices)."""
        return output.gather(1, y.unsqueeze(1))

# mnist_bnn = MnistBNN().to(device)

# metric = lambda preds, y: torch.sum(preds == y).item()

# optim = lambda: torch.optim.AdamW(mnist_bnn.parameters(), weight_decay=0, lr=5e-3)

# permuted_mnist_results = train_permuted_mnist(mnist_bnn, optim, d, coreset_size=0, n_epochs=2, start_epochs=2, n_tasks=10, metric=metric, use_strengthening=False, strengthen_ratio=1.0, telemetry=True)
# display(permuted_mnist_results)

# for ratio in tqdm([1]):
#     mnist_bnn = MnistBNN().to(device)

#     metric = lambda preds, y: torch.sum(preds == y).item()

#     optim = lambda: torch.optim.AdamW(mnist_bnn.parameters(), weight_decay=0, lr=5e-3)

#     permuted_mnist_results = train_permuted_mnist(mnist_bnn, optim, d, coreset_size=200, n_epochs=20, start_epochs=100, n_tasks=10, metric=metric, use_strengthening=True, strengthen_ratio=ratio)
#     display(permuted_mnist_results)

#     permuted_mnist_results.to_csv(f"./drive/MyDrive/permuted_mnist_coreset_200_epochs_20-100_lr_5e-3_init_var_1e-4_strengthen_{ratio}_fix_for_real.csv")

In [None]:
# permuted_mnist_results
# permuted_mnist_results.to_csv(f"./drive/MyDrive/permuted_mnist_coreset_0_epochs_20-100_lr_5e-3_init_var_1e-4_strengthen_1_fix_for_real.csv")

## Regression

In [None]:
d_reg = MNISTDataRegression().to(device)

In [None]:
weight_mean = 0
weight_logvar = 0
bias_mean = 0
bias_logvar = 0

hidden_size = 100
class MnistRegressionBNN(BNN):
    """Bayesian MLP regressing one-hot MNIST labels with a heteroscedastic Gaussian likelihood.

    The last layer outputs ``2 * output_dim`` values per input. The first ``output_dim`` are
    predictive means and the rest are predictive log-variances. Layer sizes and priors are
    read from the notebook globals ``d_reg``, ``hidden_size``, ``weight_mean``,
    ``weight_logvar``, ``bias_mean`` and ``bias_logvar`` when an instance is created.
    """

    def __init__(self):
        output_dim = d_reg.num_classes

        super().__init__(
            nn.ModuleList([
                BayesianLinear(d_reg.dim, hidden_size, weight_mean, weight_logvar, bias_mean, bias_logvar),
                nn.ReLU(),
                BayesianLinear(hidden_size, hidden_size, weight_mean, weight_logvar, bias_mean, bias_logvar),
                nn.ReLU(),
                BayesianLinear(hidden_size, hidden_size, weight_mean, weight_logvar, bias_mean, bias_logvar),
                nn.ReLU(),
                BayesianLinear(hidden_size, output_dim * 2, weight_mean, weight_logvar, bias_mean, bias_logvar),
            ])
        )
        # Set after Module.__init__ so attribute handling is fully initialised.
        self.output_dim = output_dim
        self.telemetry = []

    def predict(self, x, sample=False):
        """Return the predictive means (first ``output_dim`` outputs) for ``x``."""
        output = self.forward(x, sample=sample)
        return output[:, :self.output_dim]

    def loglik(self, output, y):
        """Per-example Gaussian log-likelihood of targets ``y``, summed over output dimensions.

        log N(y | mu, s^2) = -0.5 * (log s^2 + log(2*pi)) - 0.5 * (y - mu)^2 / s^2, with 1e-8
        added to s^2 in the quadratic term for numerical stability.
        """
        means = output[:, :self.output_dim]
        logvars = output[:, self.output_dim:]
        # The Gaussian normaliser contributes -0.5*log(2*pi) per dimension.
        loglik = -0.5 * (logvars + np.log(2 * np.pi)) - 0.5 * (y - means) ** 2 / (1e-8 + torch.exp(logvars))
        return loglik.sum(dim=1)

    def run_telemetry(self, task, X_train, y_train, nt):
        """Append summary statistics for ``task`` to ``self.telemetry``.

        Runs the network over ``X_train`` in batches of 1000, once with mean weights
        (``sample=False``) and once with sampled weights (``sample=True``). Records the task
        index, the total KL to the current prior, ``nt``, and the 0.01 to 0.99 quantiles of:

        * the predicted log-variances from both passes;
        * all variational weight log-variances.

        ``y_train`` is unused. The module is left in eval mode.
        """
        self.eval()
        batch_size = 1000
        percentiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]

        def _quantiles(values):
            # values: numpy array of any shape; flattened before taking quantiles.
            return pd.Series(values.reshape(-1)).quantile(percentiles).to_dict()

        logvar_chunks = []
        logvar_sample_chunks = []
        with torch.no_grad():
            for start in range(0, X_train.shape[0], batch_size):
                X_batch = X_train[start:start + batch_size]
                logvar_chunks.append(self.forward(X_batch, sample=False)[:, self.output_dim:].cpu())
                logvar_sample_chunks.append(self.forward(X_batch, sample=True)[:, self.output_dim:].cpu())
            kl_div = self.kl_divergence().item()

        # An empty X_train yields no batches; use empty tensors so the quantiles become NaN.
        logvars = torch.cat(logvar_chunks) if logvar_chunks else torch.empty(0)
        logvars_sample = torch.cat(logvar_sample_chunks) if logvar_sample_chunks else torch.empty(0)

        weight_logvars = np.concatenate([
            layer.weight_logvar.detach().cpu().numpy().flatten()
            for layer in self.layers
            if isinstance(layer, BayesianLinear)
        ])

        self.telemetry.append({
            "task": task,
            "kl_divergence": kl_div,
            "nt": nt,
            "logvar_percentiles": _quantiles(logvars.numpy()),
            "logvar_sample_percentiles": _quantiles(logvars_sample.numpy()),
            "weight_logvar_percentiles": _quantiles(weight_logvars),
        })



# for ratio in tqdm([0.8, 1, 1.1, 1.2, 1.3, 1.5, 2]):
#     mnist_reg_bnn = MnistRegressionBNN().to(device)

#     metric = lambda preds, y: torch.mean((preds - y) ** 2, axis=1).sum().item()

#     reg_optim = lambda: torch.optim.AdamW(mnist_reg_bnn.parameters(), weight_decay=0, lr=3e-4)

#     permuted_mnist_reg_results = train_permuted_mnist(mnist_reg_bnn, reg_optim, d_reg, coreset_size=0, n_epochs=20, start_epochs=100, n_tasks=10, metric=metric, use_strengthening=True, strengthen_ratio=ratio)
#     display(permuted_mnist_reg_results)

#     permuted_mnist_reg_results.to_csv(f"./drive/MyDrive/no_coreset_strengthen_reg/permuted_mnist_reg_coreset_0_epochs_20-100_lr_3e-4_init_var_1e-4_strengthen_{ratio}_fix_for_real.csv")

# MSE
# Seed the global torch and numpy RNGs so the run is repeatable.
torch.manual_seed(0)
np.random.seed(0)
mnist_reg_bnn = MnistRegressionBNN().to(device)

# Per-example MSE summed over the batch; the train/eval loops divide by the number of examples.
metric = lambda preds, y: torch.mean((preds - y) ** 2, dim=1).sum().item()

reg_optim = lambda: torch.optim.AdamW(mnist_reg_bnn.parameters(), weight_decay=0, lr=3e-4)

permuted_mnist_reg_results = train_permuted_mnist(mnist_reg_bnn, reg_optim, d_reg, coreset_size=200, n_epochs=20, start_epochs=100, n_tasks=10, metric=metric, telemetry=True)
display(permuted_mnist_reg_results)

# permuted_mnist_reg_results.to_csv("./drive/MyDrive/permuted_mnist_reg_coreset_200_epochs_20-100_lr_1e-3_init_var_1e-4.csv")

In [None]:
pd.set_option("display.max_columns", 100)

df = pd.DataFrame(mnist_reg_bnn.telemetry)
# Expand each dict-valued percentile column into one flat column per percentile.
for prefix, source_col in (("logvar", "logvar_percentiles"),
                           ("logvar_sample", "logvar_sample_percentiles"),
                           ("weight_logvar", "weight_logvar_percentiles")):
    for percentile in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]:
        df[f"{prefix}_{percentile}"] = df[source_col].apply(lambda x: x[percentile])
df = df.drop(columns=["logvar_percentiles", "logvar_sample_percentiles", "weight_logvar_percentiles"])
df

In [None]:
# Filename records the run's settings: coreset 200, 20 epochs (100 on the first task), lr 3e-4, init var 1e-4.
df.to_csv("./drive/MyDrive/permuted_mnist_reg_coreset_200_epochs_20-100_lr_3e-4_init_var_1e-4_telemetry.csv")

# Baselines

## Laplace Propagation

In [None]:
def train_laplace(model, optim, train_data, test_data, *, train_batch_size, test_batch_size, metric,
              lambda_reg=0.01, n_epochs=10, verbose=True, si=False):
    """Train one Laplace-propagation task, then let the model absorb its curvature.

    Each epoch:

    * minimises the mean negative log-likelihood plus ``model.reg_term(lambda_reg, nt)``
      scaled by the batch size;
    * evaluates ``metric`` on ``test_data``.

    After training, calls ``model.finish_task(*train_data)``.

    Args:
        model: module exposing ``forward``, ``loglik``, ``reg_term``, ``predict`` and
            ``finish_task``, plus ``update_omega_values`` when ``si=True``.
        optim: optimizer over ``model``'s parameters.
        train_data, test_data: tuples of tensors passed to ``TensorDataset``, i.e. ``(X, y)``.
        train_batch_size, test_batch_size: minibatch sizes; training batches are shuffled.
        metric: ``metric(preds, y) -> number``. It is summed over test batches and divided
            by the number of test points.
        lambda_reg: strength of the quadratic penalty towards the previous solution.
        n_epochs: number of passes over ``train_data``.
        verbose: show test progress and print per-epoch summaries.
        si: call ``model.update_omega_values()`` after every optimizer step.

    Returns:
        ``model``, trained in place.
    """
    train_dataset = TensorDataset(*train_data)
    train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)

    test_dataset = TensorDataset(*test_data)
    test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

    nt = len(train_dataset)

    progress_bar = tqdm(total=len(train_dataloader))

    for epoch in range(n_epochs):
        # Training phase
        progress_bar.reset(total=len(train_dataloader))
        progress_bar.set_description(f"Training epoch {epoch}/{n_epochs}")

        model.train()
        train_batch_count = 0
        train_epoch_loss = 0

        for X_batch, y_batch in train_dataloader:
            model_output = model.forward(X_batch)
            log_likelihood = model.loglik(model_output, y_batch)

            reg_term = model.reg_term(lambda_reg, nt)

            # NOTE(review): reg_term already divides by nt. Multiplying by the batch size makes the
            # penalty weight depend on the batch size (and differ on a smaller final batch).
            # Kept as-is because lambda_reg values were presumably tuned with this scaling.
            loss = torch.mean(-log_likelihood) + reg_term * X_batch.size(0)

            loss.backward()
            optim.step()
            if si:
                model.update_omega_values()
            optim.zero_grad()

            logprob = torch.mean(log_likelihood).item()

            train_epoch_loss += loss.item()
            train_batch_count += 1

            progress_bar.set_postfix(loss=loss.item(),
                                    avg_loss=train_epoch_loss/train_batch_count,
                                    logprob=logprob,
                                    reg_penalty=reg_term.item() * X_batch.size(0))
            progress_bar.update()

        train_loss = train_epoch_loss / train_batch_count if train_batch_count else float("nan")

        # Testing phase
        if verbose:
            progress_bar.reset(total=len(test_dataloader))
            progress_bar.set_description(f"Testing epoch {epoch}/{n_epochs}")

        model.eval()
        total_datapoints = 0
        total_metric = 0

        # Evaluation only: no autograd graph needed.
        with torch.no_grad():
            for X_batch, y_batch in test_dataloader:
                preds = model.predict(X_batch)
                total_metric += metric(preds, y_batch)
                total_datapoints += len(y_batch)

                if verbose:
                    progress_bar.set_postfix(metric=total_metric/total_datapoints)
                    progress_bar.update()

        test_metric_val = total_metric / total_datapoints if total_datapoints else float("nan")
        if verbose:
            print(f"Epoch {epoch}: training loss: {train_loss:.4f}, test metric: {test_metric_val:.4f}")

    if verbose:
        print("Computing Hessian and updating precision matrix... ", end="")

    model.eval()

    # Needs gradients (curvature estimate), so it stays outside any no_grad block.
    model.finish_task(*train_data)

    if verbose:
        print("done")

    return model

In [None]:
def eval_laplace(model, test_data, *, test_batch_size, metric):
    """Average ``metric`` of ``model.predict`` over ``test_data``.

    Args:
        model: module exposing ``predict(x)`` and ``eval()``.
        test_data: tuple of tensors passed to ``TensorDataset``, i.e. ``(X, y)``.
        test_batch_size: evaluation minibatch size.
        metric: ``metric(preds, y) -> number``, summed over batches.

    Returns:
        Summed metric divided by the number of test points.
    """
    test_dataset = TensorDataset(*test_data)
    test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)
    model.eval()
    total_datapoints = 0
    total_metric = 0
    # Pure evaluation: skip building autograd graphs.
    with torch.no_grad():
        for X_batch, y_batch in test_dataloader:
            preds = model.predict(X_batch)
            total_metric += metric(preds, y_batch)
            total_datapoints += len(y_batch)
    test_acc = total_metric/total_datapoints
    return test_acc

In [None]:
def train_permuted_mnist_laplace(
    laplace, optim, mnist_data: MNISTData,
    *,
    train_batch_size=256, test_batch_size=1000,
    n_epochs=100, n_tasks=10,
    lambda_reg=1e-2,
    metric, si=False,
):
    """Run Laplace propagation on permuted MNIST (no coreset).

    For each task ``t`` (pixel permutation seeded with ``t``), train with
    :func:`train_laplace` using a fresh optimizer from the zero-argument factory ``optim``.
    Then evaluate on the union of all test sets so far and on each task's test set.

    Returns:
        DataFrame with one row per task: ``n_tasks``, ``acc`` (metric on all test sets seen
        so far) and ``task_k`` (metric on task ``k``'s test set).
    """
    dim = mnist_data.dim

    test_X = torch.Tensor(0, dim).to(device)
    test_y = torch.Tensor(0).long().to(device)

    records = []

    progress_bar = tqdm(range(n_tasks))
    for task in progress_bar:
        train_X_task, train_y_task = mnist_data.permuted_train(task)
        test_X_task, test_y_task = mnist_data.permuted_test(task)

        # Accumulate test sets so "acc" covers every task seen so far.
        test_X = torch.cat((test_X, test_X_task))
        test_y = torch.cat((test_y, test_y_task))

        progress_bar.set_description("Training")
        train_laplace(laplace, optim(), (train_X_task, train_y_task), (test_X, test_y), lambda_reg=lambda_reg, train_batch_size=train_batch_size, test_batch_size=test_batch_size, n_epochs=n_epochs, verbose=True, metric=metric, si=si)

        all_task_acc = eval_laplace(laplace, (test_X, test_y), test_batch_size=test_batch_size, metric=metric)
        record = {"n_tasks": task + 1, "acc": all_task_acc}
        for seen_task in range(task + 1):
            record[f"task_{seen_task}"] = eval_laplace(laplace, mnist_data.permuted_test(seen_task), test_batch_size=test_batch_size, metric=metric)
        records.append(record)

    if not records:
        return pd.DataFrame(columns=["n_tasks", "acc"]).astype({"n_tasks": int, "acc": float})
    return pd.DataFrame(records)

### Train LP

In [None]:
hidden_size = 100

class MnistLaplace(nn.Module):
    """MLP classifier for MNIST trained with Laplace propagation (diagonal precision).

    ``theta_prev`` holds the anchor weights of the quadratic penalty (zeros before the first
    task). ``Sigma_inv`` holds the accumulated diagonal precision (ones initially). Both are
    plain dicts keyed by parameter name, not registered buffers. They are therefore not part
    of ``state_dict`` and are moved by the overridden :meth:`to`. Layer sizes come from the
    notebook globals ``d`` and ``hidden_size``.
    """

    def __init__(self):
        super().__init__()

        self.network = nn.Sequential(
            nn.Linear(d.dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, d.num_classes),
            nn.LogSoftmax(dim=1)
        )

        # Diagonal prior precision, initialised to the identity.
        self.Sigma_inv = {name: torch.ones_like(param) for name, param in self.named_parameters()}
        # Anchor for the quadratic penalty; zeros before the first task.
        self.theta_prev = {name: torch.zeros_like(param) for name, param in self.named_parameters()}

        self.output_dim = d.num_classes

    def to(self, *args, **kwargs):
        """Move/cast the module and the ``theta_prev``/``Sigma_inv`` tensors; returns ``self``.

        Accepts device and/or dtype arguments as for ``nn.Module.to``. The dict tensors are
        not registered buffers, so they are converted explicitly with ``Tensor.to``.
        """
        super().to(*args, **kwargs)
        for name in self.theta_prev:
            self.theta_prev[name] = self.theta_prev[name].to(*args, **kwargs)
        for name in self.Sigma_inv:
            self.Sigma_inv[name] = self.Sigma_inv[name].to(*args, **kwargs)
        return self

    def forward(self, x):
        """Log-probabilities over classes, shape (batch, num_classes)."""
        return self.network(x)

    def predict(self, x):
        """Arg-max class index for each row of ``x``."""
        output = self.forward(x)
        return torch.argmax(output, dim=1)

    def loglik(self, output, y):
        """Each example's log-probability of its label (assumes 1-D class indices ``y``)."""
        return torch.gather(output, dim=1, index=y.unsqueeze(1))

    def reg_term(self, lambda_reg, nt):
        """``0.5 * lambda_reg * sum((theta - theta_prev)^2 * Sigma_inv) / nt`` over all parameters."""
        reg_term = 0.0
        for name, param in self.named_parameters():
            diff = param - self.theta_prev[name]
            reg_term += 0.5 * lambda_reg * (diff * self.Sigma_inv[name] * diff).sum() / nt
        return reg_term

    def finish_task(self, X_train, y_train):
        """Add this task's diagonal curvature estimate to ``Sigma_inv`` and re-anchor at the current weights."""
        self.eval()

        hessian_diag = self.hessian_approx(X_train, y_train, n_samples=2000)

        for (name, _), hess in zip(self.named_parameters(), hessian_diag):
            self.Sigma_inv[name] = self.Sigma_inv[name] + hess

        for name, param in self.named_parameters():
            self.theta_prev[name] = param.detach().clone()

    def hessian_approx(self, X, y, n_samples=200):
        """Estimate the diagonal Hessian of the data-set negative log-likelihood.

        Uses the empirical Fisher: squared per-example gradients of the log-likelihood,
        averaged over ``n_samples`` examples drawn without replacement from ``X``/``y`` and
        multiplied by the data-set size ``nt``. ``n_samples`` is capped at ``nt``.

        Returns:
            List of tensors in ``self.parameters()`` order, each shaped like its parameter
            (all zeros when ``X`` is empty).
        """
        params = list(self.parameters())
        nt = X.size(0)
        n_samples = min(n_samples, nt)
        if n_samples == 0:
            return [torch.zeros_like(param) for param in params]

        sampled_indices = np.random.choice(nt, size=n_samples, replace=False)
        sample_X = X[sampled_indices]
        sample_y = y[sampled_indices]

        loglik = self.loglik(self.forward(sample_X), sample_y)
        sq_grads = [torch.zeros_like(param) for param in params]
        for i in range(n_samples):
            # Keep the shared forward graph alive until the last per-example backward pass.
            retain_graph = i < n_samples - 1
            grads = torch.autograd.grad(loglik[i], params, retain_graph=retain_graph)
            for sq_grad, grad in zip(sq_grads, grads):
                sq_grad += grad ** 2

        return [sq_grad / n_samples * nt for sq_grad in sq_grads]

    def hessian_exact(self, x, y):
        """Exact diagonal of the Hessian of the negative log-likelihood w.r.t. every parameter.

        Takes one extra backward pass per parameter entry, so it is only practical for tiny
        models.
        """
        output = self.forward(x)

        log_likelihood = self.loglik(output, y)
        nll = -log_likelihood.sum()

        params = list(self.parameters())

        grads = torch.autograd.grad(nll, params, create_graph=True)

        diag_hessian = []

        for param_idx, (param, grad) in enumerate(zip(params, grads)):
            diag_h = torch.zeros_like(param)

            param_flat = param.view(-1)
            grad_flat = grad.view(-1)

            for i in range(param_flat.size(0)):
                # Free the graph only after the very last second-derivative pass.
                retain_graph = (param_idx < len(params) - 1) or (i < param_flat.size(0) - 1)
                hess_elem = torch.autograd.grad(grad_flat[i], param, retain_graph=retain_graph)[0].view(-1)[i]
                diag_h.view(-1)[i] = hess_elem

            diag_hessian.append(diag_h)

        return diag_hessian


# mnist_laplace = MnistLaplace().to(device)

# metric = lambda preds, y: torch.sum(preds == y).item()

# optim = lambda: torch.optim.AdamW(mnist_laplace.parameters(), weight_decay=0, lr=5e-3)

# # train_laplace(mnist_laplace, optim(), d.permuted_train(0), d.permuted_test(0), train_batch_size=256, test_batch_size=1000, metric=metric, lambda_reg=0.1, n_epochs=20, verbose=True)

# permuted_mnist_results_lp = train_permuted_mnist_laplace(mnist_laplace, optim, d, lambda_reg=0.1, n_epochs=20, n_tasks=10, metric=metric)
# display(permuted_mnist_results_lp)

# permuted_mnist_results_lp.to_csv("./drive/MyDrive/permuted_mnist_LP_lambda_0.1_epochs_20_lr_5e-3_hidden_100_approx_2000.csv")

In [None]:
hidden_size = 100

class MnistLaplaceRegression(nn.Module):
    """MLP regressor trained with a diagonal Laplace-approximation penalty across tasks.

    The last layer emits ``2 * d_reg.num_classes`` values per example: the first
    ``output_dim`` columns are predictive means and the remaining ``output_dim`` columns
    are predictive log-variances of an independent Gaussian per output.

    Continual-learning state (plain dicts keyed by parameter name, moved by ``to``):
        Sigma_inv:  diagonal precision of the Gaussian weight prior; starts at ones and
                    accumulates a curvature estimate after every task.
        theta_prev: prior mean; zeros before the first task, then the weights at the
                    end of the most recent task.
    """

    def __init__(self):
        super().__init__()

        self.network = nn.Sequential(
            nn.Linear(d.dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, d_reg.num_classes * 2)  # Output means and logvars
        )

        # Unit prior precision (identity) until the first task has been consolidated.
        self.Sigma_inv = {name: torch.ones_like(param) for name, param in self.named_parameters()}
        # Prior mean centred at zero until the first task has been consolidated.
        self.theta_prev = {name: torch.zeros_like(param) for name, param in self.named_parameters()}

        self.output_dim = d_reg.num_classes

    def to(self, device):
        """Move the network and the (non-buffer) prior tensors to ``device``; returns self."""
        super().to(device)
        for name in self.theta_prev:
            self.theta_prev[name] = self.theta_prev[name].to(device)
        for name in self.Sigma_inv:
            self.Sigma_inv[name] = self.Sigma_inv[name].to(device)
        return self

    def forward(self, x):
        return self.network(x)

    def predict(self, x):
        """Return the predicted means, i.e. the first ``output_dim`` output columns."""
        output = self.forward(x)
        return output[:, :self.output_dim]

    def loglik(self, output, y):
        """Per-example Gaussian log-likelihood, summed over the output dimensions.

        Args:
            output: network output; columns ``[:output_dim]`` are means and
                ``[output_dim:]`` are log-variances.
            y: regression targets, broadcast against the means.

        Returns:
            A tensor with one log-likelihood value per row of ``output``.
        """
        means = output[:, :self.output_dim]
        logvars = output[:, self.output_dim:]
        # log N(y | mu, s^2) = -0.5 * (log(2*pi) + log s^2) - 0.5 * (y - mu)^2 / s^2.
        # (Previously the log(2*pi) term was not halved, offsetting every value by a constant.)
        # The 1e-8 guards against division by an underflowed variance.
        loglik = -0.5 * (logvars + np.log(2 * np.pi)) - 0.5 * (y - means) ** 2 / (1e-8 + torch.exp(logvars))
        return loglik.sum(dim=1)

    def reg_term(self, lambda_reg, nt):
        """Laplace penalty: 0.5 * lambda_reg / nt * sum_j Sigma_inv_j * (theta_j - theta_prev_j)^2."""
        reg_term = 0.0
        for name, param in self.named_parameters():
            diff = param - self.theta_prev[name]
            reg_term += 0.5 * lambda_reg * (diff * self.Sigma_inv[name] * diff).sum() / nt
        return reg_term

    def finish_task(self, X_train, y_train):
        """Consolidate the task just trained on.

        Adds a diagonal curvature estimate (from 2000 sampled training examples) to
        ``Sigma_inv`` and snapshots the current weights as the new prior mean.
        Leaves the module in eval mode.
        """
        self.eval()

        hessian_diag = self.hessian_approx(X_train, y_train, n_samples=2000)

        for (name, _), hess in zip(self.named_parameters(), hessian_diag):
            self.Sigma_inv[name] = self.Sigma_inv[name] + hess

        for name, param in self.named_parameters():
            self.theta_prev[name] = param.detach().clone()

    def hessian_approx(self, X, y, n_samples=200):
        """Diagonal curvature estimate of the total NLL from squared per-example gradients.

        Draws ``n_samples`` rows of ``X``/``y`` without replacement, squares the gradient
        of each example's log-likelihood, averages over the sample and rescales by the
        dataset size ``nt`` (an empirical-Fisher estimate).

        Returns:
            One tensor per entry of ``self.parameters()``, same shape as that parameter.
        """
        nt = X.size(0)
        assert n_samples <= nt

        sampled_indices = np.random.choice(nt, size=n_samples, replace=False)
        sample_X = X[sampled_indices]
        sample_y = y[sampled_indices]

        params = list(self.parameters())

        loglik = self.loglik(self.forward(sample_X), sample_y)
        sq_grads = [torch.zeros_like(param) for param in params]
        for i in range(n_samples):
            # Keep the shared forward graph until the last per-example backward pass.
            retain_graph = i < n_samples - 1
            grads = torch.autograd.grad(loglik[i], params, retain_graph=retain_graph)
            for param_idx, grad in enumerate(grads):
                sq_grads[param_idx] += grad ** 2

        return [sq_grad / n_samples * nt for sq_grad in sq_grads]

    def hessian_exact(self, x, y):
        """Exact diagonal of the Hessian of the negative log-likelihood w.r.t. the parameters.

        One backward pass per scalar parameter, so only practical for tiny models.
        Returns one tensor per entry of ``self.parameters()``, same shape as that parameter.
        """
        output = self.forward(x)

        log_likelihood = self.loglik(output, y)
        nll = -log_likelihood.sum()

        params = list(self.parameters())

        # create_graph=True so the gradients themselves can be differentiated.
        grads = torch.autograd.grad(nll, params, create_graph=True)

        diag_hessian = []

        for param_idx, (param, grad) in enumerate(zip(params, grads)):
            diag_h = torch.zeros_like(param)

            param_flat = param.view(-1)
            grad_flat = grad.view(-1)

            for i in range(param_flat.size(0)):
                # Free the graph only after the very last second-order call.
                retain_graph = (param_idx < len(params) - 1) or (i < param_flat.size(0) - 1)
                hess_elem = torch.autograd.grad(grad_flat[i], param, retain_graph=retain_graph)[0].view(-1)[i]
                diag_h.view(-1)[i] = hess_elem

            diag_hessian.append(diag_h)

        return diag_hessian


# mnist_laplace_reg = MnistLaplaceRegression().to(device)

# metric = lambda preds, y: torch.mean((preds - y) ** 2, axis=1).sum().item()

# optim = lambda: torch.optim.AdamW(mnist_laplace_reg.parameters(), weight_decay=0, lr=1e-3)

# permuted_mnist_results_lp_reg = train_permuted_mnist_laplace(mnist_laplace_reg, optim, d_reg, lambda_reg=0, n_epochs=20, n_tasks=10, metric=metric)
# display(permuted_mnist_results_lp_reg)

# permuted_mnist_results_lp_reg.to_csv("./drive/MyDrive/permuted_mnist_naive_epochs_20_lr_1e-3_hidden_100.csv")

### Train EWC

In [None]:
hidden_size = 100

class MnistEWC(nn.Module):
    """Multi-class MNIST classifier regularised with Elastic Weight Consolidation (EWC).

    After each task ``finish_task`` appends a weight snapshot to ``theta_list`` and a
    diagonal Fisher estimate to ``fisher_list`` (one dict per task, keyed by parameter
    name). ``reg_term`` penalises movement away from every stored snapshot, weighted
    by the corresponding Fisher diagonal.
    """

    def __init__(self):
        super().__init__()

        self.network = nn.Sequential(
            nn.Linear(d.dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, d.num_classes),
            nn.LogSoftmax(dim=1)
        )

        self.theta_list = []
        self.fisher_list = []

        self.output_dim = d.num_classes
        self.task_count = 0

    def to(self, device):
        """Move the network and every stored snapshot / Fisher tensor to ``device``; returns self."""
        super().to(device)
        for store in (self.theta_list, self.fisher_list):
            for per_task in store:
                for name in per_task:
                    per_task[name] = per_task[name].to(device)
        return self

    def forward(self, x):
        # Log-probabilities (LogSoftmax over dim 1).
        return self.network(x)

    def predict(self, x):
        """Return the arg-max class index for each row of ``x``."""
        return torch.argmax(self.forward(x), dim=1)

    def loglik(self, output, y):
        """Log-probability of the true class ``y``; result has one column (shape (batch, 1))."""
        return torch.gather(output, dim=1, index=y.unsqueeze(1))

    def reg_term(self, lambda_reg, nt):
        """EWC penalty over all finished tasks.

        Returns 0.5 * lambda_reg / nt * sum_i sum_j F_ij * (theta_j - theta*_ij)^2, where
        ``theta*_i`` and ``F_i`` are the weights and Fisher diagonal stored after task i.
        """
        # Allocate on the parameters' own device instead of a notebook-global `device`.
        reg_term = torch.tensor(0.0, device=next(self.parameters()).device)

        for i in range(self.task_count):
            theta_i = self.theta_list[i]
            fisher_i = self.fisher_list[i]

            task_reg = 0.0
            for name, param in self.named_parameters():
                if name in theta_i:
                    diff = param - theta_i[name]
                    task_reg += (diff * fisher_i[name] * diff).sum()

            reg_term += task_reg

        # Apply lambda and normalize by number of training examples
        return 0.5 * lambda_reg * reg_term / nt

    def finish_task(self, X_train, y_train):
        """Store a weight snapshot and a Fisher estimate (2000 sampled examples) for this task.

        Leaves the module in eval mode.
        """
        self.eval()

        current_theta = {name: param.detach().clone() for name, param in self.named_parameters()}
        self.theta_list.append(current_theta)

        fisher_diag = self.fisher_approx(X_train, y_train, n_samples=2000)
        current_fisher = {name: fisher for (name, _), fisher in zip(self.named_parameters(), fisher_diag)}
        self.fisher_list.append(current_fisher)

        self.task_count += 1

    def fisher_approx(self, X, y, n_samples=200):
        """Empirical-Fisher diagonal: nt * mean of squared per-example log-likelihood gradients.

        Uses ``n_samples`` rows drawn without replacement from ``X``/``y`` with their
        observed labels. Returns one tensor per entry of ``self.parameters()``.
        """
        nt = X.size(0)
        assert n_samples <= nt

        sampled_indices = np.random.choice(nt, size=n_samples, replace=False)
        sample_X = X[sampled_indices]
        sample_y = y[sampled_indices]

        params = list(self.parameters())

        loglik = self.loglik(self.forward(sample_X), sample_y)
        sq_grads = [torch.zeros_like(param) for param in params]
        for i in range(n_samples):
            # Keep the shared forward graph until the last per-example backward pass.
            retain_graph = i < n_samples - 1
            grads = torch.autograd.grad(loglik[i], params, retain_graph=retain_graph)
            for param_idx, grad in enumerate(grads):
                sq_grads[param_idx] += grad ** 2

        return [sq_grad / n_samples * nt for sq_grad in sq_grads]

# mnist_ewc = MnistEWC().to(device)

# metric = lambda preds, y: torch.sum(preds == y).item()

# optim = lambda: torch.optim.AdamW(mnist_ewc.parameters(), lr=2e-3, weight_decay=0)

# permuted_mnist_results_ewc = train_permuted_mnist_laplace(mnist_ewc, optim, d, lambda_reg=100, n_epochs=20, n_tasks=10, metric=metric)
# display(permuted_mnist_results_ewc)

# permuted_mnist_results_ewc.to_csv("./drive/MyDrive/permuted_mnist_EWC_lambda_100_epochs_20_lr_2e-3_hidden_100_approx_2000.csv")

In [None]:
hidden_size = 100

class MnistEWCRegression(nn.Module):
    """MLP regressor with Gaussian outputs regularised with Elastic Weight Consolidation.

    The last layer emits ``2 * d_reg.num_classes`` values per example: the first
    ``output_dim`` columns are means and the rest log-variances. After each task
    ``finish_task`` appends a weight snapshot to ``theta_list`` and a diagonal Fisher
    estimate to ``fisher_list`` (one dict per task, keyed by parameter name).
    """

    def __init__(self):
        super().__init__()

        self.network = nn.Sequential(
            nn.Linear(d.dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, d_reg.num_classes * 2)
        )

        self.theta_list = []
        self.fisher_list = []

        self.output_dim = d_reg.num_classes
        self.task_count = 0  # Number of tasks completed

    def to(self, device):
        """Move the network and every stored snapshot / Fisher tensor to ``device``; returns self."""
        super().to(device)
        for store in (self.theta_list, self.fisher_list):
            for per_task in store:
                for name in per_task:
                    per_task[name] = per_task[name].to(device)
        return self

    def forward(self, x):
        return self.network(x)

    def predict(self, x):
        """Return the predicted means, i.e. the first ``output_dim`` output columns."""
        output = self.forward(x)
        return output[:, :self.output_dim]

    def loglik(self, output, y):
        """Per-example Gaussian log-likelihood, summed over the output dimensions.

        Columns ``[:output_dim]`` of ``output`` are means, ``[output_dim:]`` log-variances.
        """
        means = output[:, :self.output_dim]
        logvars = output[:, self.output_dim:]
        # log N(y | mu, s^2) = -0.5 * (log(2*pi) + log s^2) - 0.5 * (y - mu)^2 / s^2.
        # (Previously the log(2*pi) term was not halved.) 1e-8 guards a zero variance.
        loglik = -0.5 * (logvars + np.log(2 * np.pi)) - 0.5 * (y - means) ** 2 / (1e-8 + torch.exp(logvars))
        return loglik.sum(dim=1)

    def reg_term(self, lambda_reg, nt):
        """EWC penalty 0.5 * lambda_reg / nt * sum_i sum_j F_ij * (theta_j - theta*_ij)^2 over finished tasks."""
        # Allocate on the parameters' own device instead of a notebook-global `device`.
        reg_term = torch.tensor(0.0, device=next(self.parameters()).device)

        for i in range(self.task_count):
            theta_i = self.theta_list[i]
            fisher_i = self.fisher_list[i]

            task_reg = 0.0
            for name, param in self.named_parameters():
                if name in theta_i:
                    diff = param - theta_i[name]
                    task_reg += (diff * fisher_i[name] * diff).sum()

            reg_term += task_reg

        return 0.5 * lambda_reg * reg_term / nt

    def finish_task(self, X_train, y_train):
        """Store a weight snapshot and a Fisher estimate (2000 sampled examples) for this task.

        Leaves the module in eval mode.
        """
        self.eval()

        current_theta = {name: param.detach().clone() for name, param in self.named_parameters()}
        self.theta_list.append(current_theta)

        fisher_diag = self.fisher_approx(X_train, y_train, n_samples=2000)
        current_fisher = {name: fisher for (name, _), fisher in zip(self.named_parameters(), fisher_diag)}
        self.fisher_list.append(current_fisher)

        self.task_count += 1

    def fisher_approx(self, X, y, n_samples=200):
        """Empirical-Fisher diagonal: nt * mean of squared per-example log-likelihood gradients.

        Uses ``n_samples`` rows drawn without replacement from ``X``/``y``.
        Returns one tensor per entry of ``self.parameters()``.
        """
        nt = X.size(0)
        assert n_samples <= nt

        sampled_indices = np.random.choice(nt, size=n_samples, replace=False)
        sample_X = X[sampled_indices]
        sample_y = y[sampled_indices]

        params = list(self.parameters())

        loglik = self.loglik(self.forward(sample_X), sample_y)
        sq_grads = [torch.zeros_like(param) for param in params]
        for i in range(n_samples):
            # Keep the shared forward graph until the last per-example backward pass.
            retain_graph = i < n_samples - 1
            grads = torch.autograd.grad(loglik[i], params, retain_graph=retain_graph)
            for param_idx, grad in enumerate(grads):
                sq_grads[param_idx] += grad ** 2

        return [sq_grad / n_samples * nt for sq_grad in sq_grads]

# mnist_ewc_reg = MnistEWCRegression().to(device)

# metric = lambda preds, y: torch.mean((preds - y) ** 2, axis=1).sum().item()

# optim = lambda: torch.optim.AdamW(mnist_ewc_reg.parameters(), lr=1e-3, weight_decay=0)

# permuted_mnist_results_ewc_reg = train_permuted_mnist_laplace(mnist_ewc_reg, optim, d_reg, lambda_reg=100, n_epochs=20, n_tasks=10, metric=metric)
# display(permuted_mnist_results_ewc_reg)
# permuted_mnist_results_ewc_reg.to_csv("./drive/MyDrive/permuted_mnist_reg_EWC_lambda_100_epochs_20_lr_1e-3_hidden_100_approx_2000.csv")

## Train SI

In [None]:
hidden_size = 100

class MnistSI(nn.Module):
    """Multi-class MNIST classifier regularised with Synaptic Intelligence (SI).

    Per-parameter state (plain dicts keyed by parameter name, moved by ``to``):
        omega:            running path integral -sum grad * delta_theta for the current task.
        prev_params:      weights after the previous optimisation step.
        prev_task_params: weights at the end of the previous task (anchor of the penalty).
        param_importance: importance accumulated over all finished tasks.

    Args:
        xi: damping added to the squared per-task displacement in ``finish_task``.
    """

    def __init__(self, xi: float):
        super().__init__()

        self.network = nn.Sequential(
            nn.Linear(d.dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, d.num_classes),
            nn.LogSoftmax(dim=1)
        )

        self.param_importance = {}
        self.omega = {}
        self.prev_params = {}
        self.prev_task_params = {}

        self._init_importance_params()

        self.output_dim = d.num_classes
        self.task_count = 0
        self.xi = xi

    def _init_importance_params(self):
        """(Re)initialise the per-parameter SI state from the current weights.

        Importance already present in ``param_importance`` is kept.
        """
        for name, param in self.named_parameters():
            current = param.detach()
            self.omega[name] = torch.zeros_like(current)
            self.prev_params[name] = current.clone()
            self.prev_task_params[name] = current.clone()
            self.param_importance.setdefault(name, torch.zeros_like(current))

    def to(self, device):
        """Move the network and all SI state tensors to ``device``; returns self."""
        super().to(device)
        for name in self.param_importance:
            self.param_importance[name] = self.param_importance[name].to(device)

        for name in self.omega:
            self.omega[name] = self.omega[name].to(device)
            self.prev_params[name] = self.prev_params[name].to(device)
            self.prev_task_params[name] = self.prev_task_params[name].to(device)

        return self

    def forward(self, x):
        # Log-probabilities (LogSoftmax over dim 1).
        return self.network(x)

    def predict(self, x):
        """Return the arg-max class index for each row of ``x``."""
        return torch.argmax(self.forward(x), dim=1)

    def loglik(self, output, y):
        """Log-probability of the true class ``y``; result has one column (shape (batch, 1))."""
        return torch.gather(output, dim=1, index=y.unsqueeze(1))

    def update_omega_values(self):
        """Accumulate the SI path integral for the optimisation step just taken.

        For each parameter with a gradient: omega -= grad * (theta_now - prev_params),
        then prev_params <- theta_now. Presumably meant to run right after
        ``optimizer.step()`` while ``.grad`` still holds that step's gradient.
        NOTE(review): confirm the call site in the training loop.
        """
        for name, param in self.named_parameters():
            if param.grad is None:
                continue
            current = param.detach().clone()
            self.omega[name] -= param.grad * (current - self.prev_params[name])
            self.prev_params[name] = current

    def reg_term(self, lambda_reg, nt):
        """SI penalty 0.5 * lambda_reg / nt * sum_j Omega_j * (theta_j - theta_task_start_j)^2.

        Args:
            lambda_reg: regularisation strength.
            nt: number of training examples (divides the penalty).
        """
        reg_term = 0.0

        for name, param in self.named_parameters():
            # Guard on the dicts actually indexed below.
            if name in self.prev_task_params and name in self.param_importance:
                param_change = param - self.prev_task_params[name]
                reg_term += (self.param_importance[name] * param_change**2).sum()

        return 0.5 * lambda_reg * reg_term / nt

    def finish_task(self, X_train, y_train):
        """Fold this task's path integral into ``param_importance`` and start a new task.

        importance += omega / ((theta_now - prev_task_params)^2 + xi); the anchor
        ``prev_task_params`` then moves to the current weights and ``omega`` is reset.
        ``X_train`` / ``y_train`` are unused; they keep the signature shared with the
        other continual-learning models. Leaves the module in eval mode.
        """
        self.eval()

        for name, param in self.named_parameters():
            current = param.detach().clone()
            displacement = current - self.prev_task_params[name]
            self.param_importance[name] += self.omega[name] / (displacement**2 + self.xi)
            self.prev_task_params[name] = current
            self.omega[name] = torch.zeros_like(current)

        self.task_count += 1

# for lambda_reg in [0.01, 0.1, 0.5]:
#     mnist_si = MnistSI(xi=1e-3).to(device)

#     metric = lambda preds, y: torch.sum(preds == y).item()

#     optim = lambda: torch.optim.AdamW(mnist_si.parameters(), lr=2e-3, weight_decay=0)

#     permuted_mnist_results_si = train_permuted_mnist_laplace(mnist_si, optim, d, lambda_reg=lambda_reg, n_epochs=20, n_tasks=10, metric=metric, si=True)
#     display(permuted_mnist_results_si)

#     permuted_mnist_results_si.to_csv(f"./drive/MyDrive/permuted_mnist_SI_epochs_20_lr_2e-3_lambda_{lambda_reg}_xi_1e-3.csv")

In [None]:
hidden_size = 100

class MnistSIRegression(nn.Module):
    """MLP regressor with Gaussian outputs regularised with Synaptic Intelligence (SI).

    The last layer emits ``2 * d_reg.num_classes`` values per example: the first
    ``output_dim`` columns are means and the rest log-variances.

    Per-parameter state (plain dicts keyed by parameter name, moved by ``to``):
        omega:            running path integral -sum grad * delta_theta for the current task.
        prev_params:      weights after the previous optimisation step.
        prev_task_params: weights at the end of the previous task (anchor of the penalty).
        param_importance: importance accumulated over all finished tasks.

    Args:
        xi: damping added to the squared per-task displacement in ``finish_task``.
    """

    def __init__(self, xi: float):
        super().__init__()

        self.network = nn.Sequential(
            nn.Linear(d.dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, d_reg.num_classes * 2)
        )

        self.param_importance = {}
        self.omega = {}
        self.prev_params = {}
        self.prev_task_params = {}

        self._init_importance_params()

        self.output_dim = d_reg.num_classes
        self.task_count = 0
        self.xi = xi

    def _init_importance_params(self):
        """(Re)initialise the per-parameter SI state from the current weights.

        Importance already present in ``param_importance`` is kept.
        """
        for name, param in self.named_parameters():
            current = param.detach()
            self.omega[name] = torch.zeros_like(current)
            self.prev_params[name] = current.clone()
            self.prev_task_params[name] = current.clone()
            self.param_importance.setdefault(name, torch.zeros_like(current))

    def to(self, device):
        """Move the network and all SI state tensors to ``device``; returns self."""
        super().to(device)
        for name in self.param_importance:
            self.param_importance[name] = self.param_importance[name].to(device)

        for name in self.omega:
            self.omega[name] = self.omega[name].to(device)
            self.prev_params[name] = self.prev_params[name].to(device)
            self.prev_task_params[name] = self.prev_task_params[name].to(device)

        return self

    def forward(self, x):
        return self.network(x)

    def predict(self, x):
        """Return the predicted means, i.e. the first ``output_dim`` output columns."""
        output = self.forward(x)
        return output[:, :self.output_dim]

    def loglik(self, output, y):
        """Per-example Gaussian log-likelihood, summed over the output dimensions.

        Columns ``[:output_dim]`` of ``output`` are means, ``[output_dim:]`` log-variances.
        """
        means = output[:, :self.output_dim]
        logvars = output[:, self.output_dim:]
        # log N(y | mu, s^2) = -0.5 * (log(2*pi) + log s^2) - 0.5 * (y - mu)^2 / s^2.
        # (Previously the log(2*pi) term was not halved.) 1e-8 guards a zero variance.
        loglik = -0.5 * (logvars + np.log(2 * np.pi)) - 0.5 * (y - means) ** 2 / (1e-8 + torch.exp(logvars))
        return loglik.sum(dim=1)

    def update_omega_values(self):
        """Accumulate the SI path integral for the optimisation step just taken.

        For each parameter with a gradient: omega -= grad * (theta_now - prev_params),
        then prev_params <- theta_now. Presumably meant to run right after
        ``optimizer.step()`` while ``.grad`` still holds that step's gradient.
        NOTE(review): confirm the call site in the training loop.
        """
        for name, param in self.named_parameters():
            if param.grad is None:
                continue
            current = param.detach().clone()
            self.omega[name] -= param.grad * (current - self.prev_params[name])
            self.prev_params[name] = current

    def reg_term(self, lambda_reg, nt):
        """SI penalty 0.5 * lambda_reg / nt * sum_j Omega_j * (theta_j - theta_task_start_j)^2.

        Args:
            lambda_reg: regularisation strength.
            nt: number of training examples (divides the penalty).
        """
        reg_term = 0.0

        for name, param in self.named_parameters():
            # Guard on the dicts actually indexed below.
            if name in self.prev_task_params and name in self.param_importance:
                param_change = param - self.prev_task_params[name]
                reg_term += (self.param_importance[name] * param_change**2).sum()

        return 0.5 * lambda_reg * reg_term / nt

    def finish_task(self, X_train, y_train):
        """Fold this task's path integral into ``param_importance`` and start a new task.

        importance += omega / ((theta_now - prev_task_params)^2 + xi); the anchor
        ``prev_task_params`` then moves to the current weights and ``omega`` is reset.
        ``X_train`` / ``y_train`` are unused (shared signature). Leaves the module in eval mode.
        """
        self.eval()

        for name, param in self.named_parameters():
            current = param.detach().clone()
            displacement = current - self.prev_task_params[name]
            self.param_importance[name] += self.omega[name] / (displacement**2 + self.xi)
            self.prev_task_params[name] = current
            self.omega[name] = torch.zeros_like(current)

        self.task_count += 1

# Train the SI regression model on 10 permuted-MNIST regression tasks and save the results.
mnist_si_reg = MnistSIRegression(xi=1e-3).to(device)


def metric(preds, y):
    # Per-example mean squared error over output dimensions, summed over the batch.
    return torch.mean((preds - y) ** 2, dim=1).sum().item()


def optim():
    # Fresh optimiser over the (global) model's parameters for each task.
    return torch.optim.AdamW(mnist_si_reg.parameters(), lr=1e-3, weight_decay=0)


permuted_mnist_results_si_reg = train_permuted_mnist_laplace(
    mnist_si_reg, optim, d_reg,
    lambda_reg=0.1, n_epochs=20, n_tasks=10, metric=metric, si=True,
)
display(permuted_mnist_results_si_reg)

permuted_mnist_results_si_reg.to_csv("./drive/MyDrive/permuted_mnist_reg_SI_epochs_20_lr_1e-3_lambda_0.1_xi_1e-3.csv")


In [None]:

# Re-save the SI regression results table (plain string path; no interpolation needed).
permuted_mnist_results_si_reg.to_csv("./drive/MyDrive/permuted_mnist_reg_SI_epochs_20_lr_1e-3_lambda_0.1_xi_1e-3.csv")

# Split MNIST

In [None]:
class SplitMNISTData:
    """Split-MNIST benchmark: a sequence of binary tasks, each built from two MNIST digits.

    Task ``t`` keeps only the images whose label is one of ``self.tasks[t]`` and relabels
    them 0 (first digit of the pair) / 1 (second digit). Images are flattened to float
    vectors with values in [0, 1].

    ``permuted_train`` / ``permuted_test`` reuse the method names of the permuted-MNIST
    data object so the same training loops work; here their ``seed`` argument is the
    task index.

    Args:
        root: directory where torchvision stores / downloads MNIST.
        tasks: sequence of (digit_a, digit_b) pairs; defaults to the standard
            (0, 1), (2, 3), (4, 5), (6, 7), (8, 9) split.
    """

    DEFAULT_TASKS = ((0, 1), (2, 3), (4, 5), (6, 7), (8, 9))

    def __init__(self, root='./data', tasks=None):
        train_dataset = datasets.MNIST(root=root, train=True, download=True)
        test_dataset = datasets.MNIST(root=root, train=False, download=True)

        self.train_X, self.train_y = self._flatten(train_dataset)
        self.test_X, self.test_y = self._flatten(test_dataset)

        self.dim = self.train_X.size(1)
        self.num_classes = 2

        self.tasks = list(self.DEFAULT_TASKS if tasks is None else tasks)

        self.device = "cpu"

    @staticmethod
    def _flatten(dataset):
        """Return (images flattened to (N, H*W) floats in [0, 1], int64 labels) for an MNIST dataset."""
        # `data` holds the raw uint8 images and `targets` the labels. Casting to the default
        # float dtype and dividing by 255 is exactly what transforms.ToTensor() does, but it
        # avoids decoding every image through PIL one at a time.
        images = torch.as_tensor(dataset.data).to(torch.get_default_dtype()).div(255)
        labels = torch.as_tensor(dataset.targets, dtype=torch.long).clone()
        return images.reshape(images.size(0), -1), labels

    def _get_task_data(self, task_id, train=True):
        """Return (X, y) for task ``task_id`` on ``self.device``; y is 1 for the task's second digit."""
        digit1, digit2 = self.tasks[task_id]

        X = self.train_X if train else self.test_X
        y = self.train_y if train else self.test_y

        mask = (y == digit1) | (y == digit2)
        task_X = X[mask]
        task_y = y[mask]

        binary_y = (task_y == digit2).long()

        return task_X.to(self.device), binary_y.to(self.device)

    def permuted_train(self, seed):
        """Training split of task ``seed`` (named for compatibility with the permuted-MNIST API)."""
        return self._get_task_data(seed, train=True)

    def permuted_test(self, seed):
        """Test split of task ``seed`` (named for compatibility with the permuted-MNIST API)."""
        return self._get_task_data(seed, train=False)

    def to(self, device):
        """Set the device task tensors are moved to on retrieval; returns self."""
        self.device = device
        return self

In [None]:
# Build the split-MNIST task container; task tensors are moved to `device` when retrieved.
split_d = SplitMNISTData()
split_d = split_d.to(device)

In [None]:
# Prior and variational-initialisation hyperparameters shared by every BayesianLinear layer below.
weight_mean = 0
weight_logvar = 0
bias_mean = 0
bias_logvar = 0
init_weight_var = 1e-6

flag = True

hidden_size = 256


class MnistBinaryBNN(BNN):
    """Bayesian MLP for binary split-MNIST tasks: three ReLU hidden layers and a single output logit.

    Every layer is a ``BayesianLinear`` built from the module-level prior hyperparameters.
    """

    def __init__(self):
        def bayes_linear(n_in, n_out):
            # One variational layer using the shared prior / initialisation settings.
            return BayesianLinear(
                n_in, n_out,
                weight_mean, weight_logvar, bias_mean, bias_logvar,
                init_weight_var=init_weight_var,
            )

        layers = [
            bayes_linear(split_d.dim, hidden_size), nn.ReLU(),
            bayes_linear(hidden_size, hidden_size), nn.ReLU(),
            bayes_linear(hidden_size, hidden_size), nn.ReLU(),
            bayes_linear(hidden_size, 1),
        ]
        super().__init__(nn.ModuleList(layers))

    def predict(self, x, sample=False):
        """Hard 0/1 labels: 1 wherever the output logit is positive."""
        logits = self.forward(x, sample=sample)
        return (logits > 0).long().squeeze()

    def loglik(self, output, y):
        """Bernoulli log-likelihood log sigmoid(+/-logit) of labels ``y`` in {0, 1}."""
        logits = output.squeeze()
        signed_logits = torch.where(y == 1, logits, -logits)
        return F.logsigmoid(signed_logits)

# mnist_binary_bnn = MnistBinaryBNN().to(device)

# metric = lambda preds, y: torch.sum(preds == y).item()

# optim = lambda: torch.optim.Adam(mnist_binary_bnn.parameters(), lr=1e-3)

# split_mnist_results = train_permuted_mnist(mnist_binary_bnn, optim, split_d, coreset_size=200, coreset_only=False, n_epochs=50, n_tasks=5, metric=metric)
# display(split_mnist_results)

# split_mnist_results.to_csv("./drive/MyDrive/split_mnist_coreset_200_epochs_50_lr_1e-3_init_var_1e-4.csv")

In [None]:
hidden_size = 100

class SplitMnistLaplace(nn.Module):
    """Binary (single-logit) MLP for split MNIST with a diagonal Laplace-approximation penalty.

    Continual-learning state (plain dicts keyed by parameter name, moved by ``to``):
        Sigma_inv:  diagonal prior precision; starts at ones and accumulates a curvature
                    estimate after every task.
        theta_prev: prior mean; zeros before the first task, then the weights at the
                    end of the most recent task.
    """

    def __init__(self):
        super().__init__()

        self.network = nn.Sequential(
            nn.Linear(d.dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

        self.Sigma_inv = {name: torch.ones_like(param) for name, param in self.named_parameters()}
        self.theta_prev = {name: torch.zeros_like(param) for name, param in self.named_parameters()}

        # NOTE(review): taken from the permuted-MNIST data object `d`, although this network
        # has a single logit output; kept unchanged in case callers read it.
        self.output_dim = d.num_classes

    def to(self, device):
        """Move the network and the (non-buffer) prior tensors to ``device``; returns self."""
        super().to(device)
        for name in self.theta_prev:
            self.theta_prev[name] = self.theta_prev[name].to(device)
        for name in self.Sigma_inv:
            self.Sigma_inv[name] = self.Sigma_inv[name].to(device)
        return self

    def forward(self, x):
        return self.network(x)

    def predict(self, x):
        """Hard 0/1 labels: 1 wherever the output logit is positive."""
        # squeeze(-1) drops only the logit axis, so a batch of one stays 1-D
        # (a bare squeeze() would collapse it to a 0-d tensor).
        return (self.forward(x) > 0).long().squeeze(-1)

    def loglik(self, output, y):
        """Bernoulli log-likelihood log sigmoid(+/-logit) of labels ``y`` in {0, 1}, one value per example."""
        logits = output.squeeze(-1)
        return F.logsigmoid(torch.where(y == 1, logits, -logits))

    def reg_term(self, lambda_reg, nt):
        """Laplace penalty 0.5 * lambda_reg / nt * sum_j Sigma_inv_j * (theta_j - theta_prev_j)^2."""
        reg_term = 0.0
        for name, param in self.named_parameters():
            diff = param - self.theta_prev[name]
            reg_term += 0.5 * lambda_reg * (diff * self.Sigma_inv[name] * diff).sum() / nt
        return reg_term

    def finish_task(self, X_train, y_train):
        """Add a curvature estimate (2000 sampled examples) to ``Sigma_inv`` and snapshot the weights.

        Leaves the module in eval mode.
        """
        self.eval()

        hessian_diag = self.hessian_approx(X_train, y_train, n_samples=2000)

        for (name, _), hess in zip(self.named_parameters(), hessian_diag):
            self.Sigma_inv[name] = self.Sigma_inv[name] + hess

        for name, param in self.named_parameters():
            self.theta_prev[name] = param.detach().clone()

    def hessian_approx(self, X, y, n_samples=200):
        """Empirical-Fisher diagonal: nt * mean of squared per-example log-likelihood gradients.

        Uses ``n_samples`` rows drawn without replacement from ``X``/``y``.
        Returns one tensor per entry of ``self.parameters()``.
        """
        nt = X.size(0)
        assert n_samples <= nt

        sampled_indices = np.random.choice(nt, size=n_samples, replace=False)
        sample_X = X[sampled_indices]
        sample_y = y[sampled_indices]

        params = list(self.parameters())

        loglik = self.loglik(self.forward(sample_X), sample_y)
        sq_grads = [torch.zeros_like(param) for param in params]
        for i in range(n_samples):
            # Keep the shared forward graph until the last per-example backward pass.
            retain_graph = i < n_samples - 1
            grads = torch.autograd.grad(loglik[i], params, retain_graph=retain_graph)
            for param_idx, grad in enumerate(grads):
                sq_grads[param_idx] += grad ** 2

        return [sq_grad / n_samples * nt for sq_grad in sq_grads]


# split_mnist_laplace = SplitMnistLaplace().to(device)

# metric = lambda preds, y: torch.sum(preds == y).item()

# optim = lambda: torch.optim.AdamW(split_mnist_laplace.parameters(), weight_decay=0, lr=2e-3)

# # train_laplace(mnist_laplace, optim(), d.permuted_train(0), d.permuted_test(0), train_batch_size=256, test_batch_size=1000, metric=metric, lambda_reg=0.1, n_epochs=20, verbose=True)

# split_mnist_results_lp = train_permuted_mnist_laplace(split_mnist_laplace, optim, split_d, lambda_reg=0.1, n_epochs=50, n_tasks=5, metric=metric)
# display(split_mnist_results_lp)

# split_mnist_results_lp.to_csv("./drive/MyDrive/split_mnist_LP_lambda_0.1_epochs_50_lr_2e-3_hidden_100_approx_2000.csv")

In [None]:
hidden_size = 100

class SplitMnistEWC(nn.Module):
    """Binary (single-logit) MLP for split MNIST regularised with Elastic Weight Consolidation.

    After each task ``finish_task`` appends a weight snapshot to ``theta_list`` and a
    diagonal Fisher estimate to ``fisher_list`` (one dict per task, keyed by parameter name).
    """

    def __init__(self):
        super().__init__()

        self.network = nn.Sequential(
            nn.Linear(d.dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

        self.theta_list = []
        self.fisher_list = []

        # NOTE(review): taken from the permuted-MNIST data object `d`, although this network
        # has a single logit output; kept unchanged in case callers read it.
        self.output_dim = d.num_classes
        self.task_count = 0

    def to(self, device):
        """Move the network and every stored snapshot / Fisher tensor to ``device``; returns self."""
        super().to(device)
        for store in (self.theta_list, self.fisher_list):
            for per_task in store:
                for name in per_task:
                    per_task[name] = per_task[name].to(device)
        return self

    def forward(self, x):
        return self.network(x)

    def predict(self, x):
        """Hard 0/1 labels: 1 wherever the output logit is positive."""
        # squeeze(-1) drops only the logit axis, so a batch of one stays 1-D.
        return (self.forward(x) > 0).long().squeeze(-1)

    def loglik(self, output, y):
        """Bernoulli log-likelihood log sigmoid(+/-logit) of labels ``y`` in {0, 1}, one value per example."""
        logits = output.squeeze(-1)
        return F.logsigmoid(torch.where(y == 1, logits, -logits))

    def reg_term(self, lambda_reg, nt):
        """EWC penalty 0.5 * lambda_reg / nt * sum_i sum_j F_ij * (theta_j - theta*_ij)^2 over finished tasks."""
        # Allocate on the parameters' own device instead of a notebook-global `device`.
        reg_term = torch.tensor(0.0, device=next(self.parameters()).device)

        for i in range(self.task_count):
            theta_i = self.theta_list[i]
            fisher_i = self.fisher_list[i]

            task_reg = 0.0
            for name, param in self.named_parameters():
                if name in theta_i:
                    diff = param - theta_i[name]
                    task_reg += (diff * fisher_i[name] * diff).sum()

            reg_term += task_reg

        return 0.5 * lambda_reg * reg_term / nt

    def finish_task(self, X_train, y_train):
        """Store a weight snapshot and a Fisher estimate (2000 sampled examples) for this task.

        Leaves the module in eval mode.
        """
        self.eval()

        current_theta = {name: param.detach().clone() for name, param in self.named_parameters()}
        self.theta_list.append(current_theta)

        fisher_diag = self.fisher_approx(X_train, y_train, n_samples=2000)
        current_fisher = {name: fisher for (name, _), fisher in zip(self.named_parameters(), fisher_diag)}
        self.fisher_list.append(current_fisher)

        self.task_count += 1

    def fisher_approx(self, X, y, n_samples=200):
        """Empirical-Fisher diagonal: nt * mean of squared per-example log-likelihood gradients.

        Uses ``n_samples`` rows drawn without replacement from ``X``/``y``.
        Returns one tensor per entry of ``self.parameters()``.
        """
        nt = X.size(0)
        assert n_samples <= nt

        sampled_indices = np.random.choice(nt, size=n_samples, replace=False)
        sample_X = X[sampled_indices]
        sample_y = y[sampled_indices]

        params = list(self.parameters())

        loglik = self.loglik(self.forward(sample_X), sample_y)
        sq_grads = [torch.zeros_like(param) for param in params]
        for i in range(n_samples):
            # Keep the shared forward graph until the last per-example backward pass.
            retain_graph = i < n_samples - 1
            grads = torch.autograd.grad(loglik[i], params, retain_graph=retain_graph)
            for param_idx, grad in enumerate(grads):
                sq_grads[param_idx] += grad ** 2

        return [sq_grad / n_samples * nt for sq_grad in sq_grads]

# split_mnist_ewc = SplitMnistEWC().to(device)

# metric = lambda preds, y: torch.sum(preds == y).item()

# optim = lambda: torch.optim.AdamW(split_mnist_ewc.parameters(), lr=2e-3, weight_decay=0)

# split_mnist_results_ewc = train_permuted_mnist_laplace(split_mnist_ewc, optim, split_d, lambda_reg=100, n_epochs=50, n_tasks=5, metric=metric)
# display(split_mnist_results_ewc)

# split_mnist_results_ewc.to_csv("./drive/MyDrive/split_mnist_EWC_lambda_0.1_epochs_50_lr_2e-3_hidden_100_approx_2000.csv")

In [None]:
hidden_size = 100  # width of each hidden layer of the split-MNIST SI network defined next

class SplitMnistSI(nn.Module):
    def __init__(self, xi: float):
        super().__init__()

        self.network = nn.Sequential(
            nn.Linear(d.dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

        self.param_importance = {}
        self.omega = {}
        self.prev_params = {}
        self.prev_task_params = {}

        self._init_importance_params()

        self.output_dim = d.num_classes
        self.task_count = 0
        self.xi = xi

    def _init_importance_params(self):
        for name, param in self.named_parameters():
            self.omega[name] = torch.zeros_like(param.data)
            self.prev_params[name] = param.data.clone()
            self.prev_task_params[name] = param.data.clone()

            if name not in self.param_importance:
                self.param_importance[name] = torch.zeros_like(param.data)

    def to(self, device):
        super().to(device)
        for name in self.param_importance:
            self.param_importance[name] = self.param_importance[name].to(device)

        for name in self.omega:
            self.omega[name] = self.omega[name].to(device)
            self.prev_params[name] = self.prev_params[name].to(device)
            self.prev_task_params[name] = self.prev_task_params[name].to(device)

        return self

    def forward(self, x):
        return self.network(x)

    def predict(self, x):
        output = self.forward(x)
        return (output > 0).long().squeeze()

    def loglik(self, output, y):
        output = output.squeeze()
        return F.logsigmoid(torch.where(y == 1, output, -output))

    def update_omega_values(self):
        """
        Update omega values using current gradients and parameter changes.
        Call this during training at each optimization step.
        """
        for name, param in self.named_parameters():
            if param.grad is not None:
                delta = param.data.clone() - self.prev_params[name]

                self.omega[name] -= param.grad * delta

                self.prev_params[name] = param.data.clone()

    def reg_term(self, lambda_reg, nt):
        """
        Calculate regularization term for Synaptic Intelligence.
        lambda_reg: Regularization strength
        nt: Number of training examples
        """
        reg_term = 0.0

        for name, param in self.named_parameters():
            if name in self.prev_params and name in self.param_importance:
                param_change = param - self.prev_task_params[name]
                reg_term += (self.param_importance[name] * param_change**2).sum()

        return 0.5 * lambda_reg * reg_term / nt

    def finish_task(self, X_train, y_train):
        """
        Calculate final parameter importance after finishing a task.
        normalization: Factor to normalize importance (e.g., final loss value)
        """
        self.eval()

        for name, param in self.named_parameters():
            delta = param.data.clone() - self.prev_task_params[name]
            importance = self.omega[name] / (delta**2 + self.xi)
            self.param_importance[name] += importance
            self.prev_task_params[name] = param.data.clone()
            self.omega[name] = torch.zeros_like(param.data)

        self.task_count += 1

# Sweep the SI regularization strength on Split MNIST (split_d, 5 tasks, 50 epochs each).
# Each lambda gets a freshly initialized model.
# NOTE(review): no seed is set, so initializations differ between lambdas.

metric = lambda preds, y: torch.sum(preds == y).item()  # number of correct predictions

for lambda_reg in [0.1, 0.5, 1, 10]:
    split_mnist_si = SplitMnistSI(xi=1e-3).to(device)

    # Called by the trainer. It binds to the model created in this iteration.
    optim = lambda: torch.optim.AdamW(split_mnist_si.parameters(), lr=2e-3, weight_decay=0)

    # The variable keeps its historical name in case later cells read it.
    # It holds Split-MNIST results.
    permuted_mnist_results_si = train_permuted_mnist_laplace(split_mnist_si, optim, split_d, lambda_reg=lambda_reg, n_epochs=50, n_tasks=5, metric=metric, si=True)
    display(permuted_mnist_results_si)

    # These are Split-MNIST results, so save them under a split_mnist_* name
    # (matching the EWC run's naming) rather than a permuted_mnist_* one.
    permuted_mnist_results_si.to_csv(f"./drive/MyDrive/split_mnist_SI_epochs_50_lr_2e-3_lambda_{lambda_reg}_xi_1e-3.csv")