In [None]:
import math, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import numpy as np, random, os, time, tarfile, gc
import copy
from PIL import Image

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device", device)

# Seed every RNG the notebook touches (torch, Python's `random`, NumPy's legacy global)
# so that a Restart & Run All reproduces the same numbers.
SEED = 1
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [None]:
class PermutedMNIST:
    """Permuted-MNIST benchmark: each task applies its own fixed pixel permutation to MNIST.

    Args:
        num_tasks: number of tasks (one random permutation of the 784 pixels each).
        seed: seed of the private generator that draws the permutations.

    ``self.tasks[t]`` is a ``(train, test)`` pair of ``TensorDataset``s holding flattened
    images scaled to [0, 1] (``(N, 784)``) and one-hot labels over the 10 digits.
    """
    def __init__(self, num_tasks=10, seed=123):
        rng = torch.Generator().manual_seed(seed)
        self.perms = []
        for _ in range(num_tasks):
            self.perms.append(torch.randperm(784, generator=rng))

        flatten = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda img: img.view(-1))])
        train_set = datasets.MNIST("./data", True,  download=True, transform=flatten)
        test_set = datasets.MNIST("./data", False, download=True, transform=flatten)

        # The raw uint8 `.data` tensors are read directly (the transform is bypassed) and scaled to [0, 1].
        train_x = train_set.data.float().view(-1, 784) / 255.
        test_x = test_set.data.float().view(-1, 784) / 255.
        train_y = F.one_hot(train_set.targets, 10).float()
        test_y = F.one_hot(test_set.targets, 10).float()

        # Labels are shared by all tasks; only the pixel order changes.
        self.tasks = []
        for perm in self.perms:
            self.tasks.append((TensorDataset(train_x[:, perm], train_y),
                               TensorDataset(test_x[:, perm], test_y)))

        self.input_dim = 784
        self.n_classes = 10
        self.num_tasks = num_tasks

    def get_task(self, tid):
        """Return the ``(train, test)`` dataset pair of task ``tid``."""
        return self.tasks[tid]

class SplitMNIST:
    """Split-MNIST benchmark: five binary tasks 0/1, 2/3, 4/5, 6/7, 8/9.

    ``self.tasks[t]`` is a ``(train, test)`` pair of ``TensorDataset``s with flattened images in
    [0, 1] and one-hot labels of width 2, where class 1 means "the second digit of the pair".
    """
    pairs = [(0,1), (2,3), (4,5), (6,7), (8,9)]

    def __init__(self):
        flatten = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda img: img.view(-1))])
        train_set = datasets.MNIST("./data", True,  download=True, transform=flatten)
        test_set = datasets.MNIST("./data", False, download=True, transform=flatten)
        # Raw uint8 tensors scaled to [0, 1]; the transform itself is never invoked.
        train_x = train_set.data.float().view(-1, 784) / 255.
        test_x = test_set.data.float().view(-1, 784) / 255.

        def binary_subset(x, targets, first, second):
            # keep only the two digits and relabel them 0 (first) / 1 (second)
            keep = (targets == first) | (targets == second)
            labels = F.one_hot((targets[keep] == second).long(), 2).float()
            return TensorDataset(x[keep], labels)

        self.tasks = [
            (binary_subset(train_x, train_set.targets, first, second),
             binary_subset(test_x, test_set.targets, first, second))
            for first, second in self.pairs
        ]
        self.input_dim = 784
        self.n_classes = 2
        self.num_tasks = 5

    def get_task(self, tid):
        """Return the ``(train, test)`` dataset pair of task ``tid``."""
        return self.tasks[tid]

In [None]:
class SplitNotMNIST:
    """Split notMNIST: up to five binary tasks over the letters A-J (A/F, B/G, C/H, D/I, E/J).

    Args:
        path: location of ``notMNIST_small.tar.gz``. If the folder ``path`` minus ``.tar.gz``
            does not exist yet, the archive is extracted into the directory containing ``path``.
            That folder must hold one sub-folder per letter ``A``..``J``; only 28x28 images are kept.
        num_tasks: number of binary tasks to build; clamped to the 5 available pairs.
        seed: seed of the 90/10 train/test shuffle.

    ``self.tasks[t]`` is a ``(train, test)`` pair of ``TensorDataset``s with flattened images in
    [0, 1] and one-hot labels of width 2 (class 0 = first letter of the pair, class 1 = second).
    """
    # Letter indices A=0 ... J=9; pairs are A/F, B/G, C/H, D/I, E/J.
    _pairs = [(0,5), (1,6), (2,7), (3,8), (4,9)]

    def __init__(self, path="notMNIST_small.tar.gz", num_tasks=5, seed=0):
        base_folder = os.path.splitext(os.path.splitext(path)[0])[0]
        if not os.path.isdir(base_folder):
            self._extract(path)

        xs, ys = self._load_images(base_folder)

        # Seeded 90/10 train/test split.
        N = xs.size(0)
        perm = torch.randperm(N, generator=torch.Generator().manual_seed(seed))
        split = int(0.9 * N)
        x_tr, y_tr = xs[perm[:split]], ys[perm[:split]]
        x_te, y_te = xs[perm[split:]], ys[perm[split:]]
        del xs, ys; gc.collect()

        self.tasks = []
        for a, b in self._pairs[:num_tasks]:
            msk_tr = (y_tr == a) | (y_tr == b)
            msk_te = (y_te == a) | (y_te == b)
            # Label 1 <=> second letter of the pair. `.long()` keeps a valid dtype for one_hot
            # even when a mask selects nothing (a list-built tensor would be float and fail).
            ytr_1h = F.one_hot((y_tr[msk_tr] == b).long(), 2).float()
            yte_1h = F.one_hot((y_te[msk_te] == b).long(), 2).float()
            self.tasks.append((TensorDataset(x_tr[msk_tr], ytr_1h), TensorDataset(x_te[msk_te], yte_1h)))

        if num_tasks > len(self.tasks):
            print(f"SplitNotMNIST: only {len(self.tasks)} tasks available, requested {num_tasks}")
        # Must equal len(self.tasks): consumers iterate range(num_tasks) and call get_task on each.
        self.num_tasks = len(self.tasks)
        self.input_dim = 784
        self.n_classes = 2

    @staticmethod
    def _extract(path):
        """Extract the gzip'd tar archive at ``path`` into the directory that contains it."""
        dest = os.path.dirname(path) or "."
        with tarfile.open(path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # Extraction filters (Python 3.12+ and security backports) reject absolute
                # paths and '..' members in the archive.
                tar.extractall(dest, filter="data")
            else:
                tar.extractall(dest)

    @staticmethod
    def _load_images(base_folder):
        """Read every 28x28 image under ``base_folder/<A..J>/``.

        Returns ``(xs, ys)``: a float32 tensor ``(N, 784)`` scaled to [0, 1] and an int64
        tensor of letter indices (A=0 ... J=9). Hidden files are ignored; files that cannot be
        decoded or are not 28x28 are skipped (best effort) and their count is printed.

        Raises:
            RuntimeError: if no usable image was found.
        """
        xs, ys, skipped = [], [], 0
        for idx in range(10):
            folder = os.path.join(base_folder, chr(ord("A") + idx))
            for fname in os.listdir(folder):
                if fname.startswith("."):
                    continue
                img_path = os.path.join(folder, fname)
                try:
                    with Image.open(img_path) as img:
                        arr = np.asarray(img.convert("L"), dtype=np.float32) / 255.0
                except Exception:
                    skipped += 1
                    continue
                if arr.shape != (28, 28):
                    skipped += 1
                    continue
                xs.append(arr.reshape(-1))
                ys.append(idx)
        if skipped:
            print(f"SplitNotMNIST: skipped {skipped} unreadable or non-28x28 files")
        if not xs:
            raise RuntimeError(f"no usable images found under {base_folder!r}")
        return torch.tensor(np.stack(xs)), torch.tensor(ys, dtype=torch.long)

    def get_task(self, tid):
        """Return the ``(train, test)`` dataset pair of task ``tid``."""
        return self.tasks[tid]

In [None]:
class PlainMLP(nn.Module):
    """Deterministic MLP.

    ``dims = [d0, d1, ..., dk]`` builds Linear(d0, d1), ReLU, ..., Linear(d_{k-1}, dk), with no
    activation after the final layer. Fewer than two entries raises IndexError.
    """
    def __init__(self, dims):
        super().__init__()
        modules = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            modules += [nn.Linear(d_in, d_out), nn.ReLU()]
        del modules[-1]  # output layer stays linear
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)

class BayesianLinear(nn.Module):
    """Linear layer with a factorised Gaussian posterior over its weights and bias.

    Posterior: N(w_mu, exp(w_logvar)) and N(b_mu, exp(b_logvar)), sampled with the
    reparameterisation trick. The prior lives in non-trainable buffers (saved in the
    state_dict), starts as N(0, prior_var) and is replaced by the posterior in ``update_prior``.
    """
    def __init__(self, in_f, out_f, prior_var=1.0):
        super().__init__()
        # Posterior: means drawn from N(0, 0.1^2), log-variances start at -6.
        self.w_mu = nn.Parameter(torch.empty(out_f, in_f))
        self.w_logvar = nn.Parameter(torch.full((out_f, in_f), -6.0))
        self.b_mu = nn.Parameter(torch.empty(out_f))
        self.b_logvar = nn.Parameter(torch.full((out_f,), -6.0))
        nn.init.normal_(self.w_mu, 0, 0.1)
        nn.init.normal_(self.b_mu, 0, 0.1)
        # Prior: zero mean, log(prior_var) log-variance.
        log_prior_var = math.log(prior_var)
        self.register_buffer("pw_mu", torch.zeros_like(self.w_mu))
        self.register_buffer("pw_logvar", torch.full_like(self.w_mu, log_prior_var))
        self.register_buffer("pb_mu", torch.zeros_like(self.b_mu))
        self.register_buffer("pb_logvar", torch.full_like(self.b_mu, log_prior_var))

    def sample(self):
        """Draw one (weight, bias) pair: mu + std * eps with eps ~ N(0, 1)."""
        w_std = torch.exp(0.5 * self.w_logvar)
        b_std = torch.exp(0.5 * self.b_logvar)
        eps_w = torch.randn_like(self.w_mu)
        eps_b = torch.randn_like(self.b_mu)
        return self.w_mu + w_std * eps_w, self.b_mu + b_std * eps_b

    def forward(self, x, sample=True):
        """Apply the layer with sampled weights, or with the posterior means if ``sample`` is False."""
        if sample:
            w, b = self.sample()
        else:
            w, b = self.w_mu, self.b_mu
        return F.linear(x, w, b)

    def helper_kl(self, m, lv, m0, lv0):
        """KL( N(m, e^lv) || N(m0, e^lv0) ), summed over all elements."""
        var, var0 = torch.exp(lv), torch.exp(lv0)
        per_element = (lv0 - lv) + (var + (m - m0) ** 2) / var0 - 1
        return 0.5 * per_element.sum()

    def kl(self):
        """KL of the whole layer's posterior from its prior (weights + bias)."""
        weight_kl = self.helper_kl(self.w_mu, self.w_logvar, self.pw_mu, self.pw_logvar)
        bias_kl = self.helper_kl(self.b_mu, self.b_logvar, self.pb_mu, self.pb_logvar)
        return weight_kl + bias_kl

    def update_prior(self):
        """Copy the current posterior into the prior buffers."""
        for prior, posterior in ((self.pw_mu, self.w_mu), (self.pw_logvar, self.w_logvar),
                                 (self.pb_mu, self.b_mu), (self.pb_logvar, self.b_logvar)):
            prior.data.copy_(posterior.data)

class BayesianMLP(nn.Module):
    """Bayesian MLP: a ReLU trunk of ``BayesianLinear`` layers shared by one or more output heads.

    Args:
        in_dim: number of input features.
        hidden: widths of the hidden layers (may be empty).
        out_dim: output width of the heads created here.
        heads: number of heads created up front.
        prior_var: variance of the initial zero-mean prior of every layer, including heads
            appended later with ``add_head``.
    """
    def __init__(self, in_dim, hidden, out_dim, heads=1, prior_var=1.0):
        super().__init__()
        self.prior_var = prior_var
        self.hidden = nn.ModuleList()
        last = in_dim
        for h in hidden:
            self.hidden.append(BayesianLinear(last, h, prior_var))
            last = h
        # Width of the features every head consumes (== in_dim when there is no hidden layer).
        self.head_in_dim = last
        self.heads = nn.ModuleList([BayesianLinear(last, out_dim, prior_var) for _ in range(heads)])
        self.out_dim = out_dim

    def add_head(self, out_dim):
        """Append a new head with the model's prior N(0, prior_var), on the model's device."""
        head = BayesianLinear(self.head_in_dim, out_dim, self.prior_var)
        ref = next(self.parameters(), None)
        if ref is not None:
            head.to(ref.device)
        self.heads.append(head)

    def forward(self, x, head_id=0, sample=True):
        """Run the trunk (ReLU after each hidden layer) and head ``head_id``; ``sample`` is forwarded to every layer."""
        for l in self.hidden:
            x = torch.relu(l(x, sample))
        return self.heads[head_id](x, sample)

    def kl(self):
        """Total KL of all hidden layers and all heads."""
        return sum(l.kl() for l in self.hidden) + sum(h.kl() for h in self.heads)

    def update_prior(self):
        """Make the current posterior of every layer its prior."""
        for l in self.hidden: l.update_prior()
        for h in self.heads:  h.update_prior()


In [None]:
class VCL:
    """Variational Continual Learning trainer (mean-field Bayesian MLP + optional coresets).

    For every task t, ``fit``:
      1. optionally removes a coreset from the task's training data (``coreset_selection``),
      2. trains on the remaining points with the previous posterior as prior ("propagation"),
      3. optionally fine-tunes on all stored coresets, then evaluates on tasks 0..t,
      4. rolls the model back to the propagated posterior and makes it the next prior, so
         coreset information never leaks into later priors.

    Args:
        in_dim: input features.
        hidden: hidden layer widths.
        n_classes: output width of the initial head.
        single_head: share head 0 across tasks if True; otherwise add one head per task.
        lr: Adam learning rate.
        mc: number of Monte-Carlo weight samples per minibatch.
        prior_var: variance of the initial N(0, prior_var) prior.
        coreset_size: points held out per task (<= 0 disables coresets).
        coreset_method: "random" or "k_center".
        coreset_epochs: coreset fine-tuning epochs before each evaluation (0 disables it).
    """

    def __init__(self, in_dim, hidden, n_classes,
                 single_head=True, lr=1e-3, mc=10,
                 prior_var=1.0, coreset_size=0,
                 coreset_method="random",
                 coreset_epochs=20):

        self.in_dim, self.hidden_sizes = in_dim, list(hidden)
        self.n_classes = n_classes
        self.single_head = single_head
        self.lr, self.mc = lr, mc
        self.prior_var = prior_var
        self.coreset_size = coreset_size
        self.coreset_method = coreset_method
        self.coreset_epochs = coreset_epochs

        # Starts with one head; in multi-head mode fit() appends one per later task.
        self.model = BayesianMLP(in_dim=self.in_dim, hidden=self.hidden_sizes, out_dim=self.n_classes,
                                 heads=1, prior_var=self.prior_var).to(device)
        self.acc_hist = []  # one list of per-task test accuracies per evaluate() call
        # Per-task coresets kept on the CPU; index i is task i's coreset (None if it has none).
        self.core_x, self.core_y = [], []

        print(f"VCLTrainer: Single Head: {self.single_head}, Coreset Size: {self.coreset_size}, coreset_method = {self.coreset_method}")

    def coreset_selection(self, x_full, y_full):
        """Split one task's training data into a coreset and the remaining points.

        Returns ``(core_x, core_y, rest_x, rest_y)``. The coreset half is ``(None, None)`` when
        coresets are disabled; the remainder half is ``(None, None)`` when the coreset would
        take the whole task.

        Raises:
            ValueError: if a coreset is requested with an unknown ``coreset_method``.
        """
        n_samples = x_full.size(0)
        if self.coreset_size <= 0:
            return None, None, x_full, y_full
        if self.coreset_size >= n_samples:
            print("Coreset size > train size")
            return x_full, y_full, None, None

        if self.coreset_method == "random":
            core_idx = torch.randperm(n_samples, device=x_full.device)[:self.coreset_size]
        elif self.coreset_method == "k_center":
            core_idx = self._k_center_indices(x_full)
        else:
            raise ValueError(f"unknown coreset_method {self.coreset_method!r}; expected 'random' or 'k_center'")

        non_core_mask = torch.ones(n_samples, dtype=torch.bool, device=x_full.device)
        non_core_mask[core_idx] = False
        return x_full[core_idx], y_full[core_idx], x_full[non_core_mask], y_full[non_core_mask]

    def _k_center_indices(self, features):
        """Greedy k-center: from a random start, repeatedly pick the point farthest from all chosen centres."""
        n_samples = features.size(0)
        dev = features.device
        min_dist = torch.full((n_samples,), float('inf'), device=dev, dtype=features.dtype)
        current = torch.randint(0, n_samples, (1,), device=dev).item()
        chosen = [current]
        min_dist[current] = -1.0  # a chosen point can never win the argmax again
        for _ in range(1, self.coreset_size):
            dist_sq = torch.sum((features - features[current].unsqueeze(0)) ** 2, dim=1)
            min_dist = torch.minimum(min_dist, dist_sq)
            current = torch.argmax(min_dist).item()
            chosen.append(current)
            min_dist[current] = -1.0
        return torch.tensor(chosen, dtype=torch.long, device=dev)

    @staticmethod
    def _make_loader(x, y, batch_size, shuffle):
        """DataLoader over (x, y); ``batch_size=None`` means a single full batch."""
        n = x.size(0)
        bs = n if batch_size is None else max(1, min(batch_size, n))
        return DataLoader(TensorDataset(x, y), batch_size=bs, shuffle=shuffle)

    def _elbo_step(self, opt, xb, yb, head_id, n_data):
        """One optimizer step on the per-sample negative ELBO; returns the loss value.

        loss = CE averaged over the batch and ``self.mc`` weight samples + KL(q || prior) / n_data,
        where ``n_data`` is the size of the dataset the minibatch was drawn from.
        """
        opt.zero_grad()
        targets = yb.argmax(1) if yb.ndim > 1 else yb
        nll = 0.0
        for _ in range(self.mc):
            logits = self.model(xb, head_id, sample=True)
            nll = nll + F.cross_entropy(logits, targets, reduction='sum')
        loss = nll / (xb.size(0) * self.mc) + self.model.kl() / n_data
        loss.backward()
        opt.step()
        return loss.item()

    def _coreset_finetune(self, t, batch_size):
        """Fine-tune ``self.model`` on the stored coresets, using the current posterior as prior."""
        self.model.update_prior()  # the coreset step is a Bayesian update starting from the propagated posterior
        self.model.train()
        # Separate optimizer so these steps do not touch the propagation optimizer's state.
        opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        # (head_id, loader, dataset size) triples, built once instead of every epoch.
        if self.single_head:
            cx = torch.cat([c for c in self.core_x if c is not None]).to(device)
            cy = torch.cat([c for c in self.core_y if c is not None]).to(device)
            groups = [(0, self._make_loader(cx, cy, batch_size, shuffle=True), cx.size(0))]
        else:
            groups = []
            for task_idx, (cx, cy) in enumerate(zip(self.core_x, self.core_y)):
                if cx is None or cx.size(0) == 0:
                    continue
                cx, cy = cx.to(device), cy.to(device)
                groups.append((task_idx, self._make_loader(cx, cy, batch_size, shuffle=True), cx.size(0)))

        for _ in tqdm(range(self.coreset_epochs), desc=f"Task {t} reminder", leave=False):
            for head_id, loader, n_core in groups:
                for xb, yb in loader:
                    # KL is normalised by the coreset size: the coreset is the data of this update.
                    self._elbo_step(opt, xb, yb, head_id, n_core)

    @torch.no_grad()
    def evaluate(self, dset, t, model_to_eval):
        """Test accuracy of ``model_to_eval`` (posterior-mean weights) on tasks 0..t.

        Prints a summary and appends the per-task list to ``self.acc_hist``. Empty test sets
        are recorded as NaN and excluded from the printed average.
        """
        model_to_eval.eval()
        accs = []
        for task_idx in range(t + 1):
            test_ds = dset.get_task(task_idx)[1]
            if len(test_ds) == 0:
                accs.append(float('nan'))
                continue
            xt, yt = test_ds.tensors[0].to(device), test_ds.tensors[1].to(device)
            head_id = 0 if self.single_head else task_idx
            correct = 0
            for xb, yb in DataLoader(TensorDataset(xt, yt), batch_size=1024):
                preds = model_to_eval(xb, head_id, sample=False).argmax(1)
                targets = yb.argmax(1) if yb.ndim > 1 else yb
                correct += (preds == targets).sum().item()
            accs.append(correct / xt.size(0))

        acc_str = ", ".join(f"T{i}={acc}" for i, acc in enumerate(accs))
        valid = [a for a in accs if not math.isnan(a)]
        avg_acc = sum(valid) / len(valid) if valid else float('nan')
        print(f"After Task {t} Eval: Avg={avg_acc} | [{acc_str}]")
        self.acc_hist.append(accs)

    def fit(self, dset, epochs=120, batch_size=None):
        """Train sequentially on every task of ``dset``, evaluating after each one.

        ``dset`` must provide ``num_tasks`` and ``get_task(t) -> (train_ds, test_ds)`` where both
        are TensorDatasets of (inputs, one-hot targets). ``batch_size=None`` trains full-batch.
        """
        for t in range(dset.num_tasks):
            task_start_time = time.time()
            train_ds_t = dset.get_task(t)[0]

            if not self.single_head and t > 0:
                self.model.add_head(out_dim=train_ds_t.tensors[1].shape[1])
                print(f"New head {t} for task {t} it's multihead")
            # Fresh Adam state per task in both head modes: the objective changes with each new prior.
            opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

            if len(train_ds_t) == 0:
                print(f"empty training set in task{t}")
                # Keep core_x/core_y indices aligned with task ids (= head ids in multi-head mode).
                self.core_x.append(None); self.core_y.append(None)
                self.evaluate(dset, t, self.model)
                continue

            x_task_full = train_ds_t.tensors[0].to(device)
            y_task_full = train_ds_t.tensors[1].to(device)
            core_x_t, core_y_t, x_rest, y_rest = self.coreset_selection(x_task_full, y_task_full)
            # Always append (possibly None) so index i of core_x is task i.
            self.core_x.append(None if core_x_t is None else core_x_t.cpu())
            self.core_y.append(None if core_y_t is None else core_y_t.cpu())

            # Propagation: train on the non-coreset points only.
            if x_rest is not None and x_rest.size(0) > 0:
                n_rest = x_rest.size(0)
                loader = self._make_loader(x_rest, y_rest, batch_size, shuffle=True)
                head_id = 0 if self.single_head else t
                self.model.train()
                epoch_loss = 0.0
                for _ in tqdm(range(epochs), desc=f"Task {t} learning", leave=False):
                    epoch_loss = 0.0
                    for xb, yb in loader:
                        epoch_loss += self._elbo_step(opt, xb, yb, head_id, n_rest) * xb.size(0)
                print(f"Task {t}: learning finished. Avg Loss: {epoch_loss / n_rest}")
            else:
                print(f"Task {t}: no noncorset points check out")

            # Snapshot the propagated posterior so coreset fine-tuning can be undone.
            state_after_propagation = copy.deepcopy(self.model.state_dict())

            has_coreset = any(c is not None and c.size(0) > 0 for c in self.core_x)
            if self.coreset_size > 0 and self.coreset_epochs > 0 and has_coreset:
                self._coreset_finetune(t, batch_size)
            else:
                print(f"Task {t}: No coreset data available")

            self.evaluate(dset, t, self.model)

            # Don't propagate coreset info: restore, then the propagated posterior becomes the next prior.
            self.model.load_state_dict(state_after_propagation)
            self.model.update_prior()

            print(f"Task {t} Complete took: {time.time() - task_start_time}s")


In [None]:
# Permuted MNIST, single shared head. 200 random points per task are carved out as a coreset,
# but coreset_epochs=0 means they are removed from training and never replayed.
perm_trainer_paper = VCL(
    in_dim=784,
    hidden=[100, 100],
    n_classes=10,
    single_head=True,
    lr=1e-3,
    mc=3,
    coreset_size=200,
    coreset_epochs=0,
    prior_var=1.0,
    coreset_method="random",
)

permuted_mnist_data = PermutedMNIST(num_tasks=10, seed=100)
perm_trainer_paper.fit(permuted_mnist_data, epochs=100, batch_size=256)

In [None]:
# Same Permuted MNIST setup, now with 100 epochs of fine-tuning on the random coresets
# before every evaluation.
perm_trainer_paper = VCL(
    in_dim=784,
    hidden=[100, 100],
    n_classes=10,
    single_head=True,
    lr=1e-3,
    mc=3,
    coreset_size=200,
    coreset_epochs=100,
    prior_var=1.0,
    coreset_method="random",
)

permuted_mnist_data = PermutedMNIST(num_tasks=10, seed=100)
perm_trainer_paper.fit(permuted_mnist_data, epochs=100, batch_size=256)

In [None]:
# Split MNIST, one binary head per task (single_head=False). The coreset points are held out
# but not replayed (coreset_epochs=0). The variable name is kept for later cells.
perm_trainer_paper = VCL(
    in_dim=784,
    hidden=[256, 256],
    n_classes=2,
    single_head=False,
    lr=1e-3,
    mc=3,
    coreset_size=200,
    coreset_epochs=0,
    prior_var=1.0,
    coreset_method="random",
)

split_mnist = SplitMNIST()
perm_trainer_paper.fit(split_mnist, epochs=100, batch_size=7000)