In [None]:
import math, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import numpy as np, random, os, time, tarfile, gc
import copy
from PIL import Image

# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device", device)

# Seed the three RNG sources this notebook imports (torch, random, numpy).
torch.manual_seed(1)
random.seed(1)
np.random.seed(1)

In [None]:
class PermutedMNIST:
    """Permuted-MNIST task sequence, fully materialised in memory.

    Every task applies one fixed permutation of the 784 flattened pixel
    positions to both the training and the test split. Pixel values are
    divided by 255 and labels are stored as float one-hot vectors of length 10.

    Attributes set by the constructor:
        perms: list of ``num_tasks`` index permutations (one per task).
        tasks: list of ``(train_dataset, test_dataset)`` ``TensorDataset`` pairs.
        input_dim, n_classes, num_tasks: 784, 10 and the requested task count.
    """

    def __init__(self, num_tasks=10, seed=123):
        """Load MNIST from ``./data`` (downloading if needed) and build all tasks.

        Args:
            num_tasks: how many permutations / tasks to create.
            seed: seed for the private generator that draws the permutations,
                so they do not depend on the global torch RNG state.
        """
        rng = torch.Generator().manual_seed(seed)
        self.perms = []
        for _ in range(num_tasks):
            self.perms.append(torch.randperm(784, generator=rng))

        flatten = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
        train_set = datasets.MNIST("./data", True, download=True, transform=flatten)
        test_set = datasets.MNIST("./data", False, download=True, transform=flatten)

        # The tensors are read straight from ``.data`` / ``.targets``; the
        # flattening and scaling are therefore done by hand here.
        train_x = train_set.data.float().view(-1, 784) / 255.
        test_x = test_set.data.float().view(-1, 784) / 255.
        train_y = F.one_hot(train_set.targets, 10).float()
        test_y = F.one_hot(test_set.targets, 10).float()

        self.tasks = []
        for perm in self.perms:
            # Column indexing reorders the pixels of every image identically.
            train_task = TensorDataset(train_x[:, perm], train_y)
            test_task = TensorDataset(test_x[:, perm], test_y)
            self.tasks.append((train_task, test_task))

        self.input_dim = 784
        self.n_classes = 10
        self.num_tasks = num_tasks

    def get_task(self, tid):
        """Return the ``(train_dataset, test_dataset)`` pair of task ``tid``."""
        return self.tasks[tid]

class SplitMNIST:
    """Split-MNIST task sequence: five binary digit-pair problems.

    Task ``i`` keeps only the images whose label is one of ``pairs[i]`` and
    relabels them as a two-class one-hot vector (index 1 for the second digit
    of the pair, index 0 otherwise). Pixel values are divided by 255.
    """

    # Digit pairs defining the tasks, in task order.
    pairs = [(0,1), (2,3), (4,5), (6,7), (8,9)]

    def __init__(self):
        """Load MNIST from ``./data`` (downloading if needed) and build all tasks."""
        flatten = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
        train_set = datasets.MNIST("./data", True, download=True, transform=flatten)
        test_set = datasets.MNIST("./data", False, download=True, transform=flatten)

        # Tensors are taken from ``.data`` directly, so flatten and scale by hand.
        train_x = train_set.data.float().view(-1, 784) / 255.
        test_x = test_set.data.float().view(-1, 784) / 255.
        train_labels, test_labels = train_set.targets, test_set.targets

        self.tasks = []
        for first, second in self.pairs:
            keep_train = (train_labels == first) | (train_labels == second)
            keep_test = (test_labels == first) | (test_labels == second)

            # Binary target: 1 when the digit is the second member of the pair.
            train_y = F.one_hot((train_labels[keep_train] == second).long(), 2).float()
            test_y = F.one_hot((test_labels[keep_test] == second).long(), 2).float()

            train_task = TensorDataset(train_x[keep_train], train_y)
            test_task = TensorDataset(test_x[keep_test], test_y)
            self.tasks.append((train_task, test_task))

        self.input_dim = 784
        self.n_classes = 2
        self.num_tasks = 5

    def get_task(self, tid):
        """Return the ``(train_dataset, test_dataset)`` pair of task ``tid``."""
        return self.tasks[tid]

In [None]:
class PlainMLP(nn.Module):
    """Deterministic multilayer perceptron: Linear layers separated by ReLU.

    ``dims`` lists the layer widths from input to output, so consecutive
    entries define one ``nn.Linear`` each. The last Linear layer is not
    followed by an activation.
    """

    def __init__(self, dims):
        super().__init__()
        stack = []
        for fan_in, fan_out in zip(dims[:-1], dims[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Remove the trailing ReLU so the output layer stays linear.
        del stack[-1]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.net(x)

class BayesianLinear(nn.Module):
    """Linear layer with an independent Gaussian posterior per weight and bias.

    The posterior means (``w_mu``, ``b_mu``) and log-variances (``w_logvar``,
    ``b_logvar``) are trainable parameters. The Gaussian prior is held in the
    buffers ``pw_mu``/``pw_logvar``/``pb_mu``/``pb_logvar`` so that it travels
    with the module's ``state_dict`` and can be replaced by ``update_prior``.
    """

    def __init__(self, in_f, out_f, prior_var=1.0):
        """Create the layer.

        Args:
            in_f: number of input features.
            out_f: number of output features.
            prior_var: variance of the initial zero-mean Gaussian prior.
        """
        super().__init__()
        w_shape, b_shape = (out_f, in_f), (out_f,)

        # Posterior: log-variances start at -6, means are drawn from N(0, 0.1^2).
        self.w_mu = nn.Parameter(torch.empty(w_shape))
        self.w_logvar = nn.Parameter(torch.full(w_shape, -6.0))
        self.b_mu = nn.Parameter(torch.empty(b_shape))
        self.b_logvar = nn.Parameter(torch.full(b_shape, -6.0))
        nn.init.normal_(self.w_mu, mean=0, std=0.1)
        nn.init.normal_(self.b_mu, mean=0, std=0.1)

        # Prior: zero mean, log-variance log(prior_var) for every entry.
        log_prior_var = math.log(prior_var)
        self.register_buffer("pw_mu", torch.zeros(w_shape))
        self.register_buffer("pw_logvar", torch.full(w_shape, log_prior_var))
        self.register_buffer("pb_mu", torch.zeros(b_shape))
        self.register_buffer("pb_logvar", torch.full(b_shape, log_prior_var))

    def _sample(self):
        """Draw one (weight, bias) pair via the reparameterisation mu + std * eps."""
        noise_w = torch.randn_like(self.w_mu)
        noise_b = torch.randn_like(self.b_mu)
        std_w = torch.exp(0.5 * self.w_logvar)
        std_b = torch.exp(0.5 * self.b_logvar)
        return self.w_mu + std_w * noise_w, self.b_mu + std_b * noise_b

    def forward(self, x, sample=True):
        """Apply the layer; use sampled weights if ``sample`` else the posterior means."""
        if sample:
            weight, bias = self._sample()
        else:
            weight, bias = self.w_mu, self.b_mu
        return F.linear(x, weight, bias)

    def helper_kl(self, m, lv, m0, lv0):
        """Return KL(N(m, exp(lv)) || N(m0, exp(lv0))) summed over all elements."""
        v = torch.exp(lv)
        v0 = torch.exp(lv0)
        per_element = (lv0 - lv) + (v + (m - m0) ** 2) / v0 - 1
        return 0.5 * per_element.sum()

    def kl(self):
        """Return the KL divergence from posterior to prior for weights plus bias."""
        weight_kl = self.helper_kl(self.w_mu, self.w_logvar, self.pw_mu, self.pw_logvar)
        bias_kl = self.helper_kl(self.b_mu, self.b_logvar, self.pb_mu, self.pb_logvar)
        return weight_kl + bias_kl

    def update_prior(self):
        """Overwrite the prior buffers with the current posterior values."""
        pairs = (
            (self.pw_mu, self.w_mu),
            (self.pw_logvar, self.w_logvar),
            (self.pb_mu, self.b_mu),
            (self.pb_logvar, self.b_logvar),
        )
        # Plain value copy: the prior must not become part of the autograd graph.
        with torch.no_grad():
            for prior_buf, posterior in pairs:
                prior_buf.copy_(posterior)

class BayesianMLP(nn.Module):
    """Stack of ``BayesianLinear`` layers with ReLU and one or more output heads.

    ``hidden`` holds the shared hidden layers; ``heads`` holds the output
    layers, selected per forward pass with ``head_id``.
    """

    def __init__(self, in_dim, hidden, out_dim, heads=1, prior_var=1.0):
        """Build the network.

        Args:
            in_dim: number of input features.
            hidden: iterable of hidden-layer widths (may be empty).
            out_dim: number of outputs of every initial head.
            heads: number of output heads to create up front.
            prior_var: initial prior variance passed to every layer.
        """
        super().__init__()
        self.hidden = nn.ModuleList()
        last = in_dim
        for h in hidden:
            self.hidden.append(BayesianLinear(last, h, prior_var))
            last = h
        self.heads = nn.ModuleList([BayesianLinear(last, out_dim, prior_var) for _ in range(heads)])
        self.out_dim = out_dim
        # Remembered so heads added later match the width and prior of the
        # network they are attached to.
        self._feature_dim = last
        self._prior_var = prior_var

    def _layer_device(self):
        """Return the device of the first Bayesian layer, or None if there is none.

        The device is read from a layer rather than from ``self.parameters()``
        because a parameter registered directly on this module (as the trainer
        in this notebook does with ``log_sigma``) is yielded first by
        ``parameters()`` and may live on a different device than the layers.
        """
        layers = [*self.hidden, *self.heads]
        return layers[0].w_mu.device if layers else None

    def add_head(self, out_dim, prior_var=None):
        """Append a new output head with ``out_dim`` outputs.

        Args:
            out_dim: number of outputs of the new head.
            prior_var: prior variance for the new head; defaults to the value
                the network was constructed with.
        """
        if prior_var is None:
            prior_var = self._prior_var
        target_device = self._layer_device()
        # ``_feature_dim`` equals ``in_dim`` when there are no hidden layers,
        # so this also works for a network without a hidden stack.
        head = BayesianLinear(self._feature_dim, out_dim, prior_var)
        if target_device is not None:
            head.to(target_device)
        self.heads.append(head)

    def forward(self, x, head_id=0, sample=True):
        """Run the hidden stack with ReLU, then the head selected by ``head_id``.

        ``sample`` is forwarded to every layer: True draws weights from the
        posterior, False uses the posterior means.
        """
        for l in self.hidden:
            x = torch.relu(l(x, sample))
        return self.heads[head_id](x, sample)

    def kl(self):
        """Return the summed KL divergence of all hidden layers and all heads."""
        return sum(l.kl() for l in self.hidden) + sum(h.kl() for h in self.heads)

    def update_prior(self):
        """Copy the current posterior into the prior of every layer and head."""
        for l in self.hidden:
            l.update_prior()
        for h in self.heads:
            h.update_prior()


In [None]:
class VCLGlobalSigma:
    """Variational continual learning trainer with one learned noise scale.

    A ``BayesianMLP`` is fitted task by task with a Gaussian likelihood on the
    targets. The likelihood's log standard deviation ``log_sigma`` is a single
    trainable scalar registered on the model. After every task the posterior is
    copied into the prior with ``model.update_prior()``. Optionally a random
    coreset is withheld from each task and replayed before evaluation; the
    replay only affects the evaluated model and is rolled back afterwards.

    NOTE(review): ``single_head`` is stored but never consulted; every forward
    pass uses head 0.
    """

    def __init__(self, in_dim, hidden, n_classes,
                 single_head=True, lr=1e-4, mc=3,
                 init_log_sigma=-3.0,
                 beta=1, prior_var=1.0,
                 coreset_size=0, coreset_epochs=0):
        """Create the model and training configuration.

        Args:
            in_dim: number of input features.
            hidden: iterable of hidden-layer widths.
            n_classes: number of model outputs.
            single_head: stored only (see class note).
            lr: Adam learning rate.
            mc: number of Monte Carlo weight samples per batch (>= 1).
            init_log_sigma: initial value of the log noise standard deviation.
            beta: multiplier applied to the KL term of the loss.
            prior_var: initial prior variance of every Bayesian layer.
            coreset_size: samples withheld per task (<= 0 disables coresets).
            coreset_epochs: epochs of coreset replay before each evaluation
                (0 disables the replay phase).

        Raises:
            ValueError: if ``mc`` is smaller than 1.
        """
        # Checked up front: the loss averages over ``mc`` samples, so 0 would
        # otherwise only surface as a ZeroDivisionError deep inside ``fit``.
        if mc < 1:
            raise ValueError(f"mc must be >= 1, got {mc}")

        self.in_dim, self.hidden_sizes = in_dim, list(hidden)
        self.n_classes = n_classes
        self.single_head = single_head
        self.beta = beta
        self.lr, self.mc = lr, mc
        self.prior_var = prior_var
        self.print_freq = 5  # print debug statistics every this many epochs
        self.coreset_size = coreset_size
        self.coreset_epochs = coreset_epochs

        self.model = BayesianMLP(in_dim=self.in_dim, hidden=self.hidden_sizes, out_dim=self.n_classes,
                                 heads=1, prior_var=self.prior_var).to(device)

        # Created directly on the model's device so every parameter the
        # optimizer sees lives in one place. Registering it on the model puts
        # it into ``model.parameters()`` and ``model.state_dict()``.
        self.log_sigma = nn.Parameter(torch.tensor(float(init_log_sigma), device=device))
        self.model.register_parameter("log_sigma", self.log_sigma)

        # One ``(per_task_rmses, overall_rmse)`` entry per call to ``evaluate``.
        self.rmse_hist = []
        # Per-task coreset tensors, kept on the CPU between tasks.
        self.core_x, self.core_y = [], []

        init_sig2 = torch.exp(2 * self.log_sigma).item()
        print(f"sigma init={init_sig2} " + f"beta = {self.beta}, lr={self.lr}, coreset size of {self.coreset_size}")

    def coreset_selection(self, x_full, y_full):
        """Split a task's data into a random coreset and the remainder.

        Selection is uniformly random. (Original author's note: another
        selection method performed about the same on permuted MNIST.)

        Returns:
            ``(core_x, core_y, rest_x, rest_y)``. With ``coreset_size <= 0`` the
            coreset entries are None and the full data is returned as the
            remainder. With ``coreset_size >= len(x_full)`` everything becomes
            coreset and the remainder entries are None.
        """
        n_samples = x_full.size(0)
        if self.coreset_size <= 0:
            return None, None, x_full, y_full  # coresets disabled
        if self.coreset_size >= n_samples:
            print("Coreset size > train size")
            return x_full, y_full, None, None

        perm = torch.randperm(n_samples, device=x_full.device)
        core_idx = perm[:self.coreset_size]
        non_core_mask = torch.ones(n_samples, dtype=torch.bool, device=x_full.device)
        non_core_mask[core_idx] = False

        core_x, core_y = x_full[core_idx], y_full[core_idx]
        non_core_x, non_core_y = x_full[non_core_mask], y_full[non_core_mask]

        return core_x, core_y, non_core_x, non_core_y

    @torch.no_grad()
    def evaluate(self, dset, t, model_to_eval):
        """Compute test RMSE on tasks ``0..t`` using posterior-mean weights.

        Appends ``(per_task_rmses, overall_rmse)`` to ``self.rmse_hist`` and
        prints a summary line. A task with an empty test set contributes NaN
        to the per-task list; if no task has any test data the overall value
        is NaN as well. ``model_to_eval`` is put into eval mode.
        """
        model_to_eval.eval()
        rmses = []
        overall_sq_sum, overall_n = 0.0, 0
        for task_idx in range(t + 1):
            _, test_ds = dset.get_task(task_idx)
            if len(test_ds) == 0:
                rmses.append(float('nan'))
                continue
            loader = DataLoader(test_ds, batch_size=1024, shuffle=False)
            sq_sum, n_elem = 0.0, 0
            for xb_test, yb_test in loader:
                xb_test, yb_test = xb_test.to(device), yb_test.to(device)
                pred = model_to_eval(xb_test, head_id=0, sample=False)
                # One reduction and one host sync per batch.
                sq_sum += (pred - yb_test).pow(2).sum().item()
                n_elem += yb_test.numel()
            rmses.append(math.sqrt(sq_sum / n_elem) if n_elem > 0 else 0.0)
            overall_sq_sum += sq_sum
            overall_n += n_elem
        overall = math.sqrt(overall_sq_sum / overall_n) if overall_n > 0 else float('nan')
        self.rmse_hist.append((rmses, overall))
        tasks_str = ", ".join(f"T{i}={r}" for i, r in enumerate(rmses))
        print(f"After Task {t}: full RMSE={overall} | [{tasks_str}]")

    @staticmethod
    def _gaussian_nll(se_sum, n_elems, log_sigma):
        """Gaussian negative log-likelihood of ``n_elems`` residuals.

        ``se_sum`` is the summed squared error and ``log_sigma`` the log
        standard deviation shared by all residuals.
        """
        const_term = 0.5 * n_elems * math.log(2 * math.pi)
        return const_term + n_elems * log_sigma + 0.5 * torch.exp(-2 * log_sigma) * se_sum

    def _batch_loss(self, xb, yb, head_id=0):
        """Return ``(loss, nll, kl, mse)`` for one minibatch.

        ``nll`` is the Monte Carlo average (over ``self.mc`` weight samples) of
        the batch's summed Gaussian NLL, ``kl`` the model's KL term, ``mse`` a
        detached tensor with the mean squared error over samples and elements.
        """
        n_rows = xb.size(0)
        out_dim = yb.size(1) if yb.ndim > 1 else 1
        n_elems = n_rows * out_dim

        nll_total = 0.0
        se_total = 0.0
        for _ in range(self.mc):
            out = self.model(xb, head_id=head_id, sample=True)
            se_sum = (out - yb).pow(2).sum()
            nll_total = nll_total + self._gaussian_nll(se_sum, n_elems, self.log_sigma)
            # Kept as a tensor so no host sync happens unless it is printed.
            se_total = se_total + se_sum.detach()

        nll = nll_total / self.mc
        mse = se_total / (self.mc * n_elems)
        kl = self.model.kl()

        # NOTE(review): the KL term is divided by the batch size. The usual
        # minibatch ELBO divides it by the number of training samples, which
        # would weight it far less. Left unchanged because ``beta`` may have
        # been tuned against this scaling — confirm.
        loss = nll / n_rows + self.beta * kl / n_rows
        return loss, nll, kl, mse

    def _train_phase(self, x, y, opt, epochs, batch_size, desc, task_id, head_id=0, verbose=False):
        """Optimise the model on ``(x, y)`` for ``epochs`` shuffled passes.

        Shared by the propagation and the coreset phase of ``fit``. With
        ``verbose`` the first batch of every ``print_freq``-th epoch prints
        loss statistics.
        """
        loader = DataLoader(TensorDataset(x, y), batch_size=min(batch_size, x.size(0)), shuffle=True)

        for ep in tqdm(range(epochs), desc=desc, leave=False):
            self.model.train()
            for batch_idx, (xb, yb) in enumerate(loader):
                xb, yb = xb.to(device), yb.to(device)
                opt.zero_grad()
                loss, nll, kl, mse = self._batch_loss(xb, yb, head_id)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                opt.step()

                if verbose and batch_idx == 0 and ep % self.print_freq == 0:
                    nll_s = nll.item() / xb.size(0)
                    kl_s = kl.item() / xb.size(0)
                    # The NLL can be zero or negative for small sigma; avoid
                    # crashing the run over a debug statistic.
                    ratio = (self.beta * kl_s) / nll_s if nll_s != 0 else float('nan')
                    sig2 = torch.exp(2 * self.log_sigma).item()
                    print(f"Task {task_id} Ep {ep} " + f"L={loss.item()} NLL={nll_s} + b*KL={self.beta * kl_s}) "
                          + f" KL/NLL={ratio}  RMSE={math.sqrt(mse.item())}  var={sig2}")

    def fit(self, dset, epochs=50, batch_size=256):
        """Train sequentially on every task of ``dset`` and evaluate after each.

        Per task: withhold a coreset (if enabled), train on the remainder,
        optionally fine-tune on all coresets seen so far, evaluate on tasks
        ``0..t``, undo the fine-tuning, then promote the posterior to prior.

        Args:
            dset: object exposing ``num_tasks`` and ``get_task(t)`` returning
                ``(train_dataset, test_dataset)``; the train dataset must
                expose ``.tensors`` as ``(inputs, targets)``.
            epochs: training epochs per task on the non-coreset data.
            batch_size: maximum minibatch size for both training phases.
        """
        start_fit_time = time.time()
        head_id = 0

        opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)

        for t in range(dset.num_tasks):
            tr_ds, _ = dset.get_task(t)
            x, y = tr_ds.tensors[0].to(device), tr_ds.tensors[1].to(device)

            cx, cy, x_nc, y_nc = self.coreset_selection(x, y)
            if cx is not None:
                self.core_x.append(cx.cpu())
                self.core_y.append(cy.cpu())

            if x_nc is not None and x_nc.size(0) > 0:
                self._train_phase(x_nc, y_nc, opt, epochs, batch_size,
                                  desc=f"Task {t} learning", task_id=t,
                                  head_id=head_id, verbose=True)
                print(f"Task {t} non coreset done")
            else:
                print(f"Task {t}: error 123")

            # Snapshot taken only when coreset fine-tuning actually runs, so
            # the fine-tuned weights can be rolled back after evaluation.
            propagation_state = None

            if self.coreset_epochs > 0 and self.core_x:
                cx_all = torch.cat(self.core_x).to(device)
                cy_all = torch.cat(self.core_y).to(device)
                if cx_all.size(0) > 0:
                    propagation_state = copy.deepcopy(self.model.state_dict())
                    self.model.update_prior()
                    # Separate optimizer: the fine-tuning is temporary, so its
                    # Adam moment estimates must not carry into the next task.
                    core_opt = torch.optim.Adam(self.model.parameters(), lr=self.lr)
                    self._train_phase(cx_all, cy_all, core_opt, self.coreset_epochs, batch_size,
                                      desc=f"Task {t} rem", task_id=t,
                                      head_id=head_id, verbose=False)
                    print(f"Task {t} coreset done.")
                else:
                    print(f"Task {t} no coreset error tr43")
            else:
                print(f"Task {t} no coreset")

            # ``self.model`` already holds exactly the state to evaluate.
            self.evaluate(dset, t, self.model)

            if propagation_state is not None:
                # Restores weights, log_sigma and the pre-fine-tuning prior.
                self.model.load_state_dict(propagation_state)

            self.model.update_prior()

        end_fit_time = time.time()
        print(f"took {end_fit_time - start_fit_time} s")

In [None]:
# Permuted-MNIST run: ten tasks, two hidden layers of 100 units each.
# NOTE(review): coreset_size=200 with coreset_epochs=0 withholds 200 samples per
# task from training, but the replay phase in fit() only runs when
# coreset_epochs > 0, so those samples are never trained on — confirm intended.
trainer = VCLGlobalSigma(
    in_dim=784,
    hidden=[100, 100],
    n_classes=10,
    single_head=True,
    init_log_sigma=-1.5,
    beta=1.0,
    lr=1e-4,
    mc=3,
    prior_var=1.0,
    coreset_size=200,
    coreset_epochs=0,
)

permuted_mnist_data = PermutedMNIST(num_tasks=10)
trainer.fit(permuted_mnist_data, epochs=50, batch_size=256)