In [None]:
import math, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import numpy as np, random, os, time, tarfile, gc, heapq
import copy
from PIL import Image

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device", device)

# Seed every RNG the notebook touches (torch, Python's `random`, NumPy's legacy global RNG).
SEED = 1
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [None]:
class PermutedMNIST:
    """Permuted-MNIST benchmark: task ``k`` is MNIST with its 784 pixels reordered by permutation ``k``.

    Inputs are flattened images scaled to [0, 1]; labels are one-hot over 10 classes.

    Args:
        num_tasks: number of tasks (permutations) to build.
        seed: seed of a private ``torch.Generator`` used for the permutations, so they neither
            depend on nor advance the global torch RNG.
        root: directory MNIST is downloaded to / loaded from (default ``"./data"``).
    """
    def __init__(self, num_tasks=10, seed=123, root="./data"):
        self.perms = self._make_perms(num_tasks, seed)
        # Only the raw .data/.targets tensors are read below. torchvision applies a dataset's
        # `transform` in __getitem__ only, so passing one here would have no effect.
        tr = datasets.MNIST(root, train=True, download=True)
        te = datasets.MNIST(root, train=False, download=True)
        x_tr = tr.data.float().view(-1, 784) / 255.
        x_te = te.data.float().view(-1, 784) / 255.
        y_tr = F.one_hot(tr.targets, 10).float()
        y_te = F.one_hot(te.targets, 10).float()
        # The same permutation is applied to the train and test split of a task.
        self.tasks = [(TensorDataset(x_tr[:, p], y_tr), TensorDataset(x_te[:, p], y_te)) for p in self.perms]
        self.input_dim = 784
        self.n_classes = 10
        self.num_tasks = num_tasks

    @staticmethod
    def _make_perms(num_tasks, seed, dim=784):
        """Return ``num_tasks`` permutations of ``range(dim)`` drawn from a generator seeded with ``seed``."""
        shuffler = torch.Generator().manual_seed(seed)
        return [torch.randperm(dim, generator=shuffler) for _ in range(num_tasks)]

    def get_task(self, tid):
        """Return the ``(train_ds, test_ds)`` TensorDataset pair of task ``tid``."""
        return self.tasks[tid]

class SplitMNIST:
    """Split-MNIST benchmark: one binary task per digit pair in ``pairs``.

    For pair ``(a, b)`` only images of digits ``a`` and ``b`` are kept; digit ``a`` becomes
    class 0 and digit ``b`` class 1 (one-hot over 2 classes). Inputs are flattened images
    scaled to [0, 1].

    Args:
        root: directory MNIST is downloaded to / loaded from (default ``"./data"``).
    """
    pairs = [(0,1), (2,3), (4,5), (6,7), (8,9)]

    def __init__(self, root="./data"):
        # Only the raw .data/.targets tensors are read, so no torchvision transform is needed
        # (a dataset's `transform` is applied in __getitem__ only).
        tr = datasets.MNIST(root, train=True, download=True)
        te = datasets.MNIST(root, train=False, download=True)
        x_tr = tr.data.float().view(-1, 784) / 255.
        x_te = te.data.float().view(-1, 784) / 255.
        self.tasks = [
            (self._binary_split(x_tr, tr.targets, a, b), self._binary_split(x_te, te.targets, a, b))
            for a, b in self.pairs
        ]
        self.input_dim = 784
        self.n_classes = 2
        # Derived from `pairs` so a subclass with different pairs stays consistent.
        self.num_tasks = len(self.pairs)

    @staticmethod
    def _binary_split(x, targets, a, b):
        """Keep rows whose label is ``a`` or ``b``; relabel ``a`` -> 0, ``b`` -> 1 as one-hot floats."""
        mask = (targets == a) | (targets == b)
        y = F.one_hot((targets[mask] == b).long(), 2).float()
        return TensorDataset(x[mask], y)

    def get_task(self, tid):
        """Return the ``(train_ds, test_ds)`` TensorDataset pair of task ``tid``."""
        return self.tasks[tid]

In [None]:
class PlainMLP(nn.Module):
    """Deterministic MLP: ``nn.Linear`` layers with a ReLU between consecutive layers.

    Args:
        dims: layer widths ``[in, hidden..., out]``; must contain at least two entries.

    Raises:
        ValueError: if ``dims`` has fewer than two entries (there would be no layer to build).
    """
    def __init__(self, dims):
        super().__init__()
        dims = list(dims)
        if len(dims) < 2:
            raise ValueError(f"PlainMLP needs at least an input and an output width, got dims={dims}")
        layers = []
        for din, dout in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(din, dout))
            layers.append(nn.ReLU())
        layers.pop()  # no activation after the output layer
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class BayesianLinear(nn.Module):
    """Linear layer with a factorised (mean-field) Gaussian posterior over weights and biases.

    Every weight/bias has a trainable mean (``*_mu``, initialised N(0, 0.1^2)) and log-variance
    (``*_logvar``, initialised to -6). The prior is stored as non-trainable buffers ``pw_*`` /
    ``pb_*``, initially N(0, prior_var), and can be replaced by the current posterior with
    ``update_prior``.
    """
    def __init__(self, in_f, out_f, prior_var=1.0):
        super().__init__()
        start_logvar = -6.0
        self.w_mu = nn.Parameter(torch.empty(out_f, in_f))
        self.w_logvar = nn.Parameter(torch.full((out_f, in_f), start_logvar))
        self.b_mu = nn.Parameter(torch.empty(out_f))
        self.b_logvar = nn.Parameter(torch.full((out_f,), start_logvar))
        # Weight means are drawn before bias means (keeps the RNG consumption order).
        for mean in (self.w_mu, self.b_mu):
            nn.init.normal_(mean, 0, 0.1)

        prior_logvar = math.log(prior_var)
        self.register_buffer("pw_mu", torch.zeros_like(self.w_mu))
        self.register_buffer("pw_logvar", torch.full_like(self.w_mu, prior_logvar))
        self.register_buffer("pb_mu", torch.zeros_like(self.b_mu))
        self.register_buffer("pb_logvar", torch.full_like(self.b_mu, prior_logvar))

    def sample(self):
        """Draw one (weight, bias) pair with the reparameterisation mu + std * eps."""
        w_noise = torch.randn_like(self.w_mu)
        b_noise = torch.randn_like(self.b_mu)
        w_std = torch.exp(0.5 * self.w_logvar)
        b_std = torch.exp(0.5 * self.b_logvar)
        return self.w_mu + w_std * w_noise, self.b_mu + b_std * b_noise

    def forward(self, x, sample=True):
        """Affine map with freshly sampled parameters, or with the posterior means if ``sample`` is False."""
        if sample:
            weight, bias = self.sample()
        else:
            weight, bias = self.w_mu, self.b_mu
        return F.linear(x, weight, bias)

    def helper_kl(self, m, lv, m0, lv0):
        """Summed KL( N(m, exp(lv)) || N(m0, exp(lv0)) ) over all elements."""
        var = torch.exp(lv)
        var0 = torch.exp(lv0)
        sq_diff = (m - m0).pow(2)
        return 0.5 * ((lv0 - lv) + (var + sq_diff) / var0 - 1).sum()

    def kl(self):
        """KL from this layer's posterior to its stored prior (weights + biases)."""
        kl_weight = self.helper_kl(self.w_mu, self.w_logvar, self.pw_mu, self.pw_logvar)
        kl_bias = self.helper_kl(self.b_mu, self.b_logvar, self.pb_mu, self.pb_logvar)
        return kl_weight + kl_bias

    def update_prior(self):
        """Overwrite the prior buffers with a copy of the current posterior parameters."""
        for prior_t, post_t in ((self.pw_mu, self.w_mu),
                                (self.pw_logvar, self.w_logvar),
                                (self.pb_mu, self.b_mu),
                                (self.pb_logvar, self.b_logvar)):
            prior_t.data.copy_(post_t.data)

class BayesianMLP(nn.Module):
    """MLP of ``BayesianLinear`` layers: shared hidden stack (ReLU) feeding one of several output heads.

    Args:
        in_dim: input feature size.
        hidden: iterable of hidden-layer widths (may be empty).
        out_dim: output width of the heads created here.
        heads: number of heads created up front.
        prior_var: prior variance of every layer, including heads added later via ``add_head``.
    """
    def __init__(self, in_dim, hidden, out_dim, heads=1, prior_var=1.0):
        super().__init__()
        self.hidden = nn.ModuleList()
        last = in_dim
        for h in hidden:
            self.hidden.append(BayesianLinear(last, h, prior_var))
            last = h
        self.heads = nn.ModuleList([BayesianLinear(last, out_dim, prior_var) for _ in range(heads)])
        self.out_dim = out_dim
        # Remembered so add_head() works without hidden layers and reuses the same prior variance
        # (previously new heads silently fell back to BayesianLinear's default prior_var=1.0).
        self.feat_dim = last
        self.prior_var = prior_var

    def add_head(self, out_dim):
        """Append a new output head of width ``out_dim`` on the model's device, with prior N(0, prior_var)."""
        head = BayesianLinear(self.feat_dim, out_dim, self.prior_var)
        ref_param = next(self.parameters(), None)
        if ref_param is not None:
            head.to(ref_param.device)
        self.heads.append(head)

    def forward(self, x, head_id=0, sample=True):
        """Run the hidden stack (ReLU after each layer) then head ``head_id``; ``sample=False`` uses means."""
        for l in self.hidden:
            x = torch.relu(l(x, sample))
        return self.heads[head_id](x, sample)

    def kl(self):
        """Total KL(posterior || prior) over all hidden layers and all heads."""
        return sum(l.kl() for l in self.hidden) + sum(h.kl() for h in self.heads)

    def update_prior(self):
        """Make the current posterior of every layer and head its new prior."""
        for l in self.hidden:
            l.update_prior()
        for h in self.heads:
            h.update_prior()

In [None]:
class VCLReservoir:
    """Variational continual-learning trainer with a weighted-reservoir coreset.

    For each task ``t`` (see ``fit``) the BayesianMLP is trained on the task's full training set
    with the previous posterior as prior ("propagation"), optionally fine-tuned on the reservoir
    coreset, evaluated on tasks ``0..t``, and then rolled back to the propagated posterior, which
    becomes the prior for the next task.

    The reservoir is filled with A-Res weighted reservoir sampling: every streamed point gets key
    ``u ** (1 / w)`` (``u ~ U(0, 1)``, ``w`` from ``calculate_weights``) and the ``coreset_size``
    points with the largest keys are kept.
    """
    def __init__(self, in_dim, hidden, n_classes,
                 single_head=True, lr=1e-3, mc=10,
                 prior_var=1.0, coreset_size=100,
                 coreset_epochs=20,
                 weighting_scheme="uniform"):
        """
        Args:
            in_dim: input feature size.
            hidden: hidden-layer widths of the BayesianMLP.
            n_classes: output width of the initial head.
            single_head: share head 0 across all tasks; otherwise one head per task.
            lr: Adam learning rate.
            mc: Monte-Carlo weight samples per minibatch for the expected log-likelihood.
            prior_var: initial prior variance of every weight.
            coreset_size: reservoir capacity (0 disables coreset fine-tuning).
            coreset_epochs: reservoir fine-tuning epochs before each evaluation.
            weighting_scheme: per-task sample weight for reservoir sampling
                ("uniform", "task_power_2", "geom_09").
        """
        self.in_dim, self.hidden_sizes = in_dim, list(hidden)
        self.n_classes = n_classes
        self.single_head = single_head
        self.lr, self.mc = lr, mc
        self.prior_var = prior_var
        self.coreset_size = coreset_size
        self.coreset_epochs = coreset_epochs
        self.weighting_scheme = weighting_scheme

        # Reservoir storage, preallocated on `device`; only the first `current_fill` rows are valid.
        feature_dim = in_dim
        self.reservoir_x = torch.zeros((coreset_size, feature_dim), device=device, dtype=torch.float)
        self.reservoir_y = torch.zeros(coreset_size, device=device, dtype=torch.long)
        self.reservoir_task_ids = torch.zeros(coreset_size, device=device, dtype=torch.long)
        self.reservoir_heap = []  # min-heap of (key, slot)
        self.current_fill = 0
        self.min_key_in_heap = float('inf')

        init_heads = 1
        initial_out_dim = n_classes
        self.model = BayesianMLP(in_dim=self.in_dim, hidden=self.hidden_sizes, out_dim=initial_out_dim, heads=init_heads, prior_var=self.prior_var).to(device)

        self.acc_hist = []  # one list of per-task accuracies per evaluate() call

        print(f"Coreset Size {self.coreset_size}, w_function '{self.weighting_scheme}', coreset_epochs {self.coreset_epochs}")

    def calculate_weights(self, task_id, num_samples):
        """Return a length-``num_samples`` tensor of reservoir weights for points of ``task_id``.

        Unknown schemes print a message and fall back to uniform weights.
        """
        if self.weighting_scheme == "uniform":
            return torch.ones(num_samples, device=device)
        elif self.weighting_scheme == "task_power_2": #debug
            weight_val = 2.0**task_id
            return torch.full((num_samples,), fill_value=weight_val, device=device)
        elif self.weighting_scheme == "geom_09":
            weight_val = 0.9**task_id
            return torch.full((num_samples,), fill_value=weight_val, device=device)
        else:
            print(f"not implemented yet")
            return torch.ones(num_samples, device=device)

    def update_reservoir(self, x_batch_gpu, y_batch_gpu, task_id):
        """Stream a batch of points (all from ``task_id``) through the weighted reservoir.

        Expects the tensors to already be on ``device``. One-hot labels are stored as class
        indices. While the reservoir is not full every point is kept; afterwards a point replaces
        the smallest-key entry only if its key is larger.
        """
        n_batch = x_batch_gpu.size(0)
        if n_batch == 0:
            return

        # One device->host copy; calling .item() per element would sync the GPU n_batch times.
        weights = self.calculate_weights(task_id, n_batch).tolist()

        if y_batch_gpu.ndim > 1 and y_batch_gpu.shape[1] > 1:
            y_batch_gpu = y_batch_gpu.argmax(dim=1)

        for i, w_i in enumerate(weights):
            rand_i = torch.rand(1).item()
            if rand_i < 1e-9: rand_i = 1e-9  # keep u strictly positive
            key_i = rand_i ** (1.0 / w_i)

            if self.current_fill < self.coreset_size:
                slot = self.current_fill
                heapq.heappush(self.reservoir_heap, (key_i, slot))
                self.current_fill += 1
            elif key_i > self.min_key_in_heap:
                # Evict the smallest-key point and reuse its slot.
                slot = self.reservoir_heap[0][1]
                heapq.heapreplace(self.reservoir_heap, (key_i, slot))
            else:
                continue

            self.reservoir_x[slot] = x_batch_gpu[i]
            self.reservoir_y[slot] = y_batch_gpu[i]
            self.reservoir_task_ids[slot] = task_id
            self.min_key_in_heap = self.reservoir_heap[0][0]

    # debug
    def print_reservoir_proportions(self):
        """Print how many of the filled reservoir slots belong to each task."""
        valid_task_ids = self.reservoir_task_ids[:self.current_fill]
        unique_tasks, counts = torch.unique(valid_task_ids, sorted=True, return_counts=True)
        props = []
        total_selected = self.current_fill
        for tid, count in zip(unique_tasks.tolist(), counts.tolist()):
            prop = count / total_selected * 100
            props.append(f"T{tid}={prop}% ({count})")

        print(f"Reservoir Composition {total_selected}/{self.coreset_size} points: {', '.join(props)}")

    @torch.no_grad()
    def evaluate(self, dset, t, model_to_eval):
        """Print and record (in ``acc_hist``) the test accuracy of ``model_to_eval`` on tasks ``0..t``.

        Uses posterior means (``sample=False``); in multi-head mode task ``k`` is scored with head
        ``k``. A task whose evaluation raises is reported and recorded as NaN.
        """
        model_to_eval.eval()
        accs = []

        for task_idx in range(t + 1):
            try:
                test_ds = dset.get_task(task_idx)[1]
                xt, yt = test_ds.tensors[0].to(device), test_ds.tensors[1].to(device)
                loader_test = DataLoader(TensorDataset(xt, yt), batch_size=1024)
                task_correct, task_total = 0, 0
                head_id_eval = 0 if self.single_head else task_idx

                for xb_test, yb_test in loader_test:
                    preds = model_to_eval(xb_test, head_id_eval, sample=False).argmax(1)
                    # Accept one-hot or class-index labels.
                    targets = yb_test.argmax(1) if yb_test.ndim > 1 and yb_test.shape[1] > 1 else yb_test
                    task_correct += (preds == targets).sum().item()
                    task_total += xb_test.size(0)

                accs.append(task_correct / task_total)
            except Exception as e:
                # Best effort: report the failure and keep evaluating the remaining tasks.
                print(f"Error 111 task {task_idx}: {e}")
                accs.append(float('nan'))

        acc_str = ", ".join([f"T{i}={acc}" for i, acc in enumerate(accs)])
        avg_acc = sum(accs) / len(accs) if accs else -233213.0
        print(f"After Task {t} Eval: Avg={avg_acc} | [{acc_str}]")
        self.acc_hist.append(accs)

    def _sample_nll(self, xb, yb, heads):
        """Summed cross-entropy of one stochastic forward pass (fresh weight sample).

        ``heads`` is either an int (whole batch through that head) or a per-row tensor of head
        ids (rows are grouped by head, one forward pass per group).
        """
        if isinstance(heads, int):
            return F.cross_entropy(self.model(xb, heads, sample=True), yb, reduction='sum')
        nll = 0.0
        for h in torch.unique(heads).tolist():
            rows = heads == h
            nll = nll + F.cross_entropy(self.model(xb[rows], h, sample=True), yb[rows], reduction='sum')
        return nll

    def _train(self, opt, x, y, epochs, batch_size, desc, head_id=0, task_ids=None):
        """Minimise the negative ELBO on ``(x, y)`` for ``epochs`` epochs with ``opt``.

        Per minibatch the loss is the per-example cross-entropy averaged over ``self.mc`` weight
        samples plus ``model.kl() / len(x)``. ``batch_size=None`` means full batch. Rows are
        scored by ``head_id``, or by ``task_ids[i]`` when per-row task ids are given.
        """
        n = x.size(0)
        bs = n if batch_size is None else min(batch_size, n)
        tensors = (x, y) if task_ids is None else (x, y, task_ids)
        loader = DataLoader(TensorDataset(*tensors), batch_size=bs, shuffle=True)

        self.model.train()
        for _epoch in tqdm(range(epochs), desc=desc, leave=False):
            for batch in loader:
                xb, yb = batch[0], batch[1]
                heads = head_id if task_ids is None else batch[2]
                opt.zero_grad()
                nll = 0.0
                for _ in range(self.mc):
                    nll += self._sample_nll(xb, yb, heads)
                kl = self.model.kl()
                loss = (nll / xb.size(0) / self.mc) + (kl / n)
                loss.backward()
                opt.step()

    def fit(self, dset, epochs=120, batch_size=None):
        """Run continual training over every task of ``dset``.

        Per task ``t``:
          1. multi-head and ``t > 0``: add a head sized to task ``t``'s label width, rebuild Adam;
          2. stream task ``t``'s training set into the reservoir;
          3. propagation: train on the full task-``t`` training set for ``epochs`` epochs;
          4. snapshot model + optimizer, set prior := propagated posterior, fine-tune on the
             reservoir for ``coreset_epochs`` epochs (each point through its own task's head in
             multi-head mode) and evaluate on tasks ``0..t``;
          5. roll back model and optimizer to the snapshot and make the propagated posterior the
             prior for the next task.

        Args:
            dset: exposes ``num_tasks`` and ``get_task(t) -> (train_ds, test_ds)``, whose splits
                provide ``.tensors`` (presumably ``TensorDataset``s).
            epochs: propagation epochs per task.
            batch_size: minibatch size; ``None`` (default) means full batch.
        """
        opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        for t in range(dset.num_tasks):
            task_start_time = time.time()

            if not self.single_head and t > 0:
                current_task_out_dim = dset.get_task(t)[0].tensors[1].shape[1]
                self.model.add_head(out_dim=current_task_out_dim)
                print(f"New head {t} for task {t} it's multihead")
                opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

            train_ds_t = dset.get_task(t)[0]
            x_task_full = train_ds_t.tensors[0].to(device)
            y_task_full = train_ds_t.tensors[1].to(device)
            if y_task_full.ndim > 1 and y_task_full.shape[1] > 1:
                y_task_full = y_task_full.argmax(dim=1)  # one-hot -> class index

            self.update_reservoir(x_task_full, y_task_full, t)
            self.print_reservoir_proportions()

            if x_task_full.size(0) > 0:
                head_id = 0 if self.single_head else t
                self._train(opt, x_task_full, y_task_full, epochs, batch_size,
                            desc=f"Task {t} main loop", head_id=head_id)
            else:
                print(f"Task {t}: no noncorset points check out")

            # Snapshot BOTH the model and the Adam moments: the coreset fine-tune below is
            # throwaway, and restoring only the weights would leak its optimizer state into the
            # next task's training.
            state_after_propagation = copy.deepcopy(self.model.state_dict())
            opt_state_after_propagation = copy.deepcopy(opt.state_dict())

            if self.coreset_size > 0 and self.current_fill > 0 and self.coreset_epochs > 0:
                # Coreset fine-tuning uses the propagated posterior as its prior.
                self.model.update_prior()
                n_core = self.current_fill
                # Multi-head: route each reservoir point through its own task's head, not head 0.
                core_task_ids = None if self.single_head else self.reservoir_task_ids[:n_core]
                self._train(opt, self.reservoir_x[:n_core], self.reservoir_y[:n_core],
                            self.coreset_epochs, batch_size,
                            desc=f"Coreset {t} loop", task_ids=core_task_ids)
            else:
                print(f"Task {t} no coreset")

            # self.model holds exactly the weights to evaluate (coreset-tuned if that ran). Evaluating
            # it in place avoids rebuilding a model whose heads all assume width n_classes.
            self.evaluate(dset, t, self.model)

            self.model.load_state_dict(state_after_propagation)
            opt.load_state_dict(opt_state_after_propagation)
            self.model.update_prior()

            task_end_time = time.time()
            print(f"Task {t} Complete took: {task_end_time - task_start_time}s")


In [None]:
# Single-head VCL, 2x100 hidden units, 2000-point uniform reservoir, on 20-task Permuted MNIST.
vcl_trainer = VCLReservoir(
    in_dim=784,
    hidden=[100, 100],
    n_classes=10,
    single_head=True,
    lr=1e-3,
    mc=3,
    prior_var=1.0,
    coreset_size=2000,
    coreset_epochs=30,
    weighting_scheme="uniform",
)

permuted_mnist_data = PermutedMNIST(num_tasks=20, seed=100)
vcl_trainer.fit(permuted_mnist_data, epochs=100, batch_size=256)

In [None]:
# Note: the run above is not exactly the configuration used for the paper, so the numbers differ slightly, but they are consistent with the paper's reported results.