In [None]:
import math, torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
import numpy as np, random, os, time, tarfile, gc
import copy
from PIL import Image

# Module-level compute device: CUDA when available, otherwise CPU.
# The helper functions and regularizer classes in this notebook read this global.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device", device)
# NOTE(review): global seeding is disabled, so anything drawn from the global RNGs
# (e.g. nn.Linear weight initialisation) differs between runs. PermutedMNIST draws
# its permutations from its own seeded generator and is unaffected.
# Uncomment the next line to make runs repeatable.
#torch.manual_seed(1); random.seed(1); np.random.seed(1)

device cuda


In [None]:
class PermutedMNIST:
    """Permuted-MNIST continual-learning benchmark.

    Every task uses the complete MNIST train/test split with one fixed,
    task-specific permutation applied to the 784 pixel positions. Inputs are
    flattened float tensors scaled by 1/255; targets are one-hot float vectors
    of length 10.

    Args:
        num_tasks: number of permutations (tasks) to create.
        seed: seed for the private generator that draws the permutations, so
            the task definitions do not depend on the global RNG state.

    Attributes:
        perms: list of `num_tasks` index tensors of length 784.
        tasks: list of `(train_dataset, test_dataset)` TensorDataset pairs.
        input_dim, n_classes, num_tasks: sizes consumed by the experiment code.
    """
    def __init__(self, num_tasks=10, seed=123):
        shuffler = torch.Generator().manual_seed(seed)
        self.perms = [torch.randperm(784, generator=shuffler) for _ in range(num_tasks)]

        # The raw uint8 image tensors (`.data`) are read directly below, which
        # bypasses torchvision transforms entirely, so no transform is attached;
        # flattening and scaling are done by hand instead.
        tr = datasets.MNIST("./data", train=True, download=True)
        te = datasets.MNIST("./data", train=False, download=True)
        x_tr = tr.data.float().view(-1, 784) / 255.
        x_te = te.data.float().view(-1, 784) / 255.
        y_tr = F.one_hot(tr.targets, 10).float()
        y_te = F.one_hot(te.targets, 10).float()

        # Column indexing with a permutation reorders the pixels of every image.
        self.tasks = [
            (TensorDataset(x_tr[:, p], y_tr), TensorDataset(x_te[:, p], y_te))
            for p in self.perms
        ]
        self.input_dim = 784
        self.n_classes = 10
        self.num_tasks = num_tasks

    def get_task(self, tid):
        """Return the `(train_dataset, test_dataset)` pair of task `tid`."""
        return self.tasks[tid]

class SplitMNIST:
    """Split-MNIST continual-learning benchmark.

    MNIST is divided into consecutive binary tasks, one per entry of `pairs`.
    Inputs are flattened float tensors scaled by 1/255. Targets are one-hot
    float vectors of length 2 where class index 1 marks the second digit of
    the pair.

    Attributes:
        pairs: digit pairs that define the tasks (class attribute).
        tasks: list of `(train_dataset, test_dataset)` TensorDataset pairs.
        input_dim, n_classes, num_tasks: sizes consumed by the experiment code.
    """
    pairs = [(0,1), (2,3), (4,5), (6,7), (8,9)]

    def __init__(self):
        # `.data` is read directly below, which bypasses torchvision transforms,
        # so no transform is attached; flattening and scaling are done by hand.
        tr = datasets.MNIST("./data", train=True, download=True)
        te = datasets.MNIST("./data", train=False, download=True)
        x_tr = tr.data.float().view(-1, 784) / 255.
        x_te = te.data.float().view(-1, 784) / 255.

        self.tasks = []
        for a, b in self.pairs:
            # Keep only the samples whose label is one of the two digits.
            msk_tr = (tr.targets == a) | (tr.targets == b)
            msk_te = (te.targets == a) | (te.targets == b)
            # Binary label: 0 for digit `a`, 1 for digit `b`.
            y_tr = F.one_hot((tr.targets[msk_tr] == b).long(), 2).float()
            y_te = F.one_hot((te.targets[msk_te] == b).long(), 2).float()
            self.tasks.append((TensorDataset(x_tr[msk_tr], y_tr), TensorDataset(x_te[msk_te], y_te)))

        self.input_dim = 784
        self.n_classes = 2
        # Derived from `pairs` so the count stays correct if `pairs` is overridden.
        self.num_tasks = len(self.pairs)

    def get_task(self, tid):
        """Return the `(train_dataset, test_dataset)` pair of task `tid`."""
        return self.tasks[tid]

In [None]:
def accuracy(model, loader, head=None):
    """Return the fraction of samples whose arg-max prediction matches the target.

    Args:
        model: callable as `model(x, head)`; switched to eval mode here.
        loader: iterable of `(inputs, one_hot_targets)` batches.
        head: forwarded to the model as its second argument.

    Returns:
        Number of correct predictions divided by the number of samples seen.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = model(inputs, head).argmax(1)
            expected = targets.argmax(1)  # targets are one-hot encoded
            hits += (predicted == expected).sum().item()
            seen += targets.size(0)
    return hits / seen

def gaussian_nll(pred, target, log_sig, reduction='sum'):
    """Negative log-likelihood of `target` under an isotropic Gaussian centred at `pred`.

    For each row the value is
    `0.5 * exp(-2*log_sig) * ||pred - target||^2 + 0.5 * D * (log(2*pi) + 2*log_sig)`
    where `D = pred.size(1)`.

    Args:
        pred: 2-D tensor of predictions, one row per sample.
        target: tensor broadcastable against `pred`.
        log_sig: log standard deviation shared by all output dimensions.
        reduction: 'sum' (default), 'mean', or 'none' (per-sample values).

    Returns:
        A scalar tensor for 'sum'/'mean', or a 1-D tensor for 'none'.

    Raises:
        ValueError: if `reduction` is not one of the supported names
            (previously this silently returned None).
    """
    if reduction not in ('sum', 'mean', 'none'):
        raise ValueError(
            f"reduction must be 'sum', 'mean' or 'none', got {reduction!r}"
        )

    inv_var = torch.exp(-2 * log_sig)
    sq_err  = (pred - target).pow(2).sum(dim=1)
    # Normalisation constant of a D-dimensional isotropic Gaussian.
    const = 0.5 * pred.size(1) * (math.log(2 * math.pi) + 2 * log_sig)
    nll   = 0.5 * inv_var * sq_err + const

    if reduction == 'mean':
        return nll.mean()
    if reduction == 'sum':
        return nll.sum()
    return nll

def rmse(model, loader, task_id=None):
    """Root-mean-square error of the model over every element of every target.

    Args:
        model: callable as `model(x, task_id)`; switched to eval mode here,
            matching what `accuracy` does before evaluating.
        loader: iterable of `(inputs, targets)` batches.
        task_id: forwarded to the model as its second argument.

    Returns:
        `sqrt(sum of squared errors / number of target elements)` as a float.

    Raises:
        ZeroDivisionError: if the loader yields no target elements.
    """
    # Evaluate in inference mode so layers such as dropout/batch-norm (if the
    # model has any) do not perturb the measurement.
    model.eval()
    se, n = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x, task_id)
            se += (pred - y).pow(2).sum().item()
            n  += y.numel()  # count elements, not rows: the error is averaged per output
    return math.sqrt(se / n)

In [None]:
class MLP(nn.Module):
    """Fully connected ReLU network with either one shared head or one head per task.

    Args:
        input_dim: size of the input features.
        hidden_layers: iterable with the width of each hidden layer.
        out_dim: size of the output of every head.
        multi_head: when True, build `num_tasks` separate linear heads
            (`self.heads`); otherwise a single one (`self.head`).
        num_tasks: number of heads created in multi-head mode.
    """
    def __init__(self, input_dim, hidden_layers, out_dim,
                 multi_head=False, num_tasks=10):
        super().__init__()
        # Consecutive entries of `widths` give the (in, out) size of each hidden layer.
        widths = [input_dim, *hidden_layers]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        self.encoder = nn.Sequential(*stack)
        feature_dim = widths[-1]

        self.multi_head = multi_head
        if not multi_head:
            self.head = nn.Linear(feature_dim, out_dim)
        else:
            task_heads = [nn.Linear(feature_dim, out_dim) for _ in range(num_tasks)]
            self.heads = nn.ModuleList(task_heads)

    def forward(self, x, task_id=None):
        """Encode `x` and apply the shared head, or head `task_id` in multi-head mode."""
        features = self.encoder(x)
        readout = self.heads[task_id] if self.multi_head else self.head
        return readout(features)

In [None]:
def train_task(model, reg, loader, task_id,
               epochs=20, lr=1e-3, print_every=5):
    """Train `model` on one task with Adam, minimising Gaussian NLL plus `reg.penalty()`.

    Args:
        model: network called as `model(x, task_id)`; must expose `log_sigma`.
        reg: regularizer providing `penalty()`. The optional hooks `begin_task()`
            and `accumulate_omega()` are invoked when present.
        loader: iterable of `(inputs, targets)` batches supporting `len()`.
        task_id: forwarded to the model and shown in the progress-bar label.
        epochs: number of passes over `loader`.
        lr: Adam learning rate.
        print_every: progress-bar postfix refresh interval, in batches.
    """
    # Optional per-task hook (implemented by SI).
    if hasattr(reg, "begin_task"):
        reg.begin_task()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()

    for epoch_idx in range(epochs):
        progress = tqdm(loader, desc=f"T{task_id} ep{epoch_idx+1}", leave=False)
        for batch_no, (inputs, targets) in enumerate(progress, 1):
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Same Parameter object the optimizer updates, so the value read for
            # the postfix below reflects the step just taken.
            log_sigma = model.log_sigma

            optimizer.zero_grad()
            data_term = gaussian_nll(model(inputs, task_id), targets, log_sigma)
            reg_term = reg.penalty()
            objective = data_term + reg_term
            objective.backward()
            optimizer.step()

            # Optional per-step hook (implemented by SI); runs after the update.
            if hasattr(reg, "accumulate_omega"):
                reg.accumulate_omega()

            if batch_no % print_every == 0 or batch_no == len(loader):
                progress.set_postfix(
                    ll=f"{data_term.item()}",
                    reg=f"{reg_term.item()}",
                    sig2=f"{math.exp(2*log_sigma.item())}",
                )
        progress.close()

In [None]:
def run_experiment(dataset, hidden=(100, 100), reg_type="EWC",
                   lam=100, epochs=20, lr=1e-3, batch=256, multi_head = False):
    """Train one model sequentially on all tasks of `dataset` and track test RMSE.

    After each task the regularizer statistics are updated and the model is
    evaluated on every task seen so far.

    Args:
        dataset: object exposing `input_dim`, `n_classes`, `num_tasks` and
            `get_task(tid) -> (train_dataset, test_dataset)`.
        hidden: widths of the hidden layers (any sequence; an immutable tuple
            is used as the default to avoid a shared mutable default).
        reg_type: one of "EWC", "LP", "SI", "LP_fancy".
        lam: regularization strength passed to the regularizer.
        epochs: training epochs per task.
        lr: Adam learning rate.
        batch: batch size for training and evaluation loaders.
        multi_head: use one output head per task instead of a shared head.

    Returns:
        A `(num_tasks, num_tasks)` tensor whose entry `[i, j]` (for `j <= i`) is
        the test RMSE on task `j` after training on task `i`; other entries are 0.

    Raises:
        KeyError: if `reg_type` is not a known regularizer name.
    """
    BASELINES = {"EWC": EWC, "LP": LP, "SI": SI, "LP_fancy": LP_fancy}
    if reg_type not in BASELINES:
        raise KeyError(f"unknown reg_type {reg_type!r}; expected one of {sorted(BASELINES)}")

    model = MLP(dataset.input_dim, list(hidden), dataset.n_classes, multi_head=multi_head, num_tasks=dataset.num_tasks).to(device)

    # Learnable log standard deviation of the Gaussian likelihood. Registering it
    # on the model makes it part of `model.parameters()`, so the optimizer trains it.
    model.register_parameter("log_sigma",torch.nn.Parameter(torch.tensor(-3.0, device=device)))

    reg = BASELINES[reg_type](model, lam)

    rmse_mat = torch.zeros(dataset.num_tasks, dataset.num_tasks)

    for tid in range(dataset.num_tasks):
        # The head index is only meaningful for multi-head models.
        head_id = tid if multi_head else None

        train_ds, _ = dataset.get_task(tid)
        loader = DataLoader(train_ds, batch)

        train_task(model, reg, loader, head_id, epochs=epochs, lr=lr)

        # Consolidate what was learned on this task before moving on.
        reg.update_stats(loader, head_id)

        # Evaluate on every task seen so far.
        row = []
        task_rmses = []
        for j in range(tid + 1):
            _, test_ds = dataset.get_task(j)
            test_loader = DataLoader(test_ds, batch)
            r = rmse(model, test_loader, j if multi_head else None)
            rmse_mat[tid, j] = r
            task_rmses.append(r)
            row.append(f"T{j}:{r}")

        # Pool per-task MSEs with an unweighted mean. This equals the overall RMSE
        # only when all test sets have the same size (true for Permuted MNIST).
        valid_rmses_tensor = torch.tensor([r**2 for r in task_rmses], device=device)
        overall_rmse = torch.sqrt(valid_rmses_tensor.mean()).item()
        print(f"{reg_type} Task {tid}  avg={overall_rmse}  |  " + " ".join(row))

    return rmse_mat

In [None]:
class EWC:
    """Elastic-weight-consolidation style quadratic penalty.

    After each task a snapshot of the flattened parameters and a diagonal
    importance vector are stored; the penalty sums one quadratic term per
    stored task.

    Args:
        model: network to regularize (moved to the module-level `device`).
        lam: penalty strength.
    """
    def __init__(self, model: nn.Module, lam: float = 400.0):
        self.model = model.to(device)
        self.lam = lam
        self.saved_means = []    # one flattened parameter snapshot per finished task
        self.saved_fishers = []  # matching diagonal importance vectors

    def flat_params(self):
        """Concatenate every model parameter into a single 1-D tensor."""
        pieces = []
        for param in self.model.parameters():
            pieces.append(param.view(-1))
        return torch.cat(pieces)

    def diag_fisher(self, loader, task_id=None):
        """Accumulate squared gradients of the summed batch NLL over `loader`.

        Note: the gradient of the *batch-summed* loss is squared, so each batch
        contributes the square of a sum rather than a sum of per-sample squares.
        The accumulated vector is divided by `len(loader.dataset)`.
        """
        accumulated = torch.zeros_like(self.flat_params())
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            self.model.zero_grad()
            batch_nll = gaussian_nll(self.model(inputs, task_id), targets, self.model.log_sigma, reduction='sum')
            batch_nll.backward()

            squared = []
            for param in self.model.parameters():
                # Parameters that did not take part in the forward pass have no gradient.
                grad = torch.zeros_like(param) if param.grad is None else param.grad.detach()
                squared.append(grad.view(-1) ** 2)
            accumulated += torch.cat(squared)

        return accumulated / len(loader.dataset)

    def penalty(self):
        """Return `0.5 * lam * sum_tasks sum_i F_i * (theta_i - mu_i)^2` (zero before any task)."""
        if len(self.saved_means) == 0:
            return torch.tensor(0.0, device=device)

        current = self.flat_params()
        accumulated = 0.0
        for anchor, importance in zip(self.saved_means, self.saved_fishers):
            drift = current - anchor
            accumulated = accumulated + (importance * drift ** 2).sum()
        return 0.5 * self.lam * accumulated

    def update_stats(self, loader, task_id=None):
        """Store the current parameters and their diagonal importance for the finished task."""
        importance = self.diag_fisher(loader, task_id).detach()
        snapshot = self.flat_params().detach().clone()

        self.saved_means.append(snapshot)
        self.saved_fishers.append(importance)

In [None]:
class SI:
    """Synaptic-intelligence style regularizer based on a per-step path integral.

    Usage per task: `begin_task()` once, `accumulate_omega()` after every
    optimizer step, `update_stats(...)` once the task is finished.

    Args:
        model: network to regularize.
        lam: penalty strength.
        damping: constant added to the squared total displacement to keep the
            importance update finite for parameters that barely moved.
    """
    def __init__(self, model, lam=0.1, damping=1e-1):
        self.model   = model
        self.lam     = lam
        self.damping = damping

        # Consolidated state carried across tasks.
        self.big_omega   = None   # accumulated per-parameter importance
        self.star_params = None   # parameters at the end of the last finished task

        # Per-task helpers, valid only between begin_task() and update_stats().
        self.prev_params  = None  # parameters at task start
        self.last_params  = None  # parameters after the previous optimizer step
        self.omega_accum  = None  # running sum of -grad * step

    def params(self):
        """Return the trainable parameters of the model, in model order."""
        return [p for p in self.model.parameters() if p.requires_grad]

    def begin_task(self):
        """Snapshot the parameters and reset the running path integral."""
        self.prev_params  = [p.detach().clone() for p in self.params()]
        self.last_params  = [p.detach().clone() for p in self.params()]
        self.omega_accum  = [torch.zeros_like(p) for p in self.params()]

    def accumulate_omega(self):
        """Add `-grad * (theta_now - theta_before_step)` for the step just taken.

        Must be called after the optimizer step. `p.grad` is whatever the most
        recent backward pass left there; parameters without a gradient are skipped.
        """
        for w_acc, p, last in zip(self.omega_accum, self.params(), self.last_params):
            if p.grad is not None:
                delta = p.detach() - last
                w_acc += (-p.grad.detach()) * delta
                last.copy_(p.detach())

    def update_stats(self, loader, task_id=None):
        """Fold the finished task's path integral into `big_omega` and store the anchor.

        `loader` and `task_id` are accepted for interface parity with the other
        regularizers and are not used.

        Raises:
            RuntimeError: if `begin_task()` was not called for this task.
        """
        if self.prev_params is None or self.omega_accum is None:
            raise RuntimeError(
                "SI.update_stats() called without a preceding begin_task(); "
                "no path integral has been accumulated for this task"
            )

        if self.big_omega is None:
            self.big_omega = [torch.zeros_like(w) for w in self.omega_accum]

        for big_O, w_acc, p, p_start in zip(self.big_omega, self.omega_accum, self.params(), self.prev_params):
            delta_total = p.detach() - p_start
            update = w_acc / (delta_total.pow(2) + self.damping)
            # Negative contributions are discarded so importances never decrease.
            big_O += torch.clamp(update, min=0.0)

        self.star_params = [p.detach().clone() for p in self.params()]

        # Invalidate the per-task helpers until the next begin_task().
        self.prev_params  = None
        self.last_params  = None
        self.omega_accum  = None

    def penalty(self):
        """Return `0.5 * lam * sum_i Omega_i * (theta_i - theta*_i)^2` (zero before any task)."""
        if self.big_omega is None or self.star_params is None:
            return torch.tensor(0.0, device=self.params()[0].device)

        total = 0.0
        for big_O, p, p_star in zip(self.big_omega, self.params(), self.star_params):
            total += (big_O * (p - p_star).pow(2)).sum()
        return 0.5 * self.lam * total


In [None]:
class LP:
    """Quadratic penalty with a running diagonal precision around the latest parameters.

    Unlike a per-task list of anchors, only the most recent parameter snapshot
    is kept while the diagonal precision is summed across tasks.

    Args:
        model: network to regularize.
        lam: penalty strength.
    """
    def __init__(self, model, lam=1.0):
        self.model = model
        self.lam = lam
        self.prev_mu = None    # flattened parameters after the last finished task
        self.precision = None  # accumulated diagonal precision

    def flat_params(self):
        """Concatenate every model parameter into a single 1-D tensor."""
        pieces = []
        for param in self.model.parameters():
            pieces.append(param.view(-1))
        return torch.cat(pieces)

    def diag_hessian(self, loader, task_id):
        """Diagonal curvature proxy: squared gradients of the summed batch NLL.

        Each batch contributes the element-wise square of the gradient of its
        summed loss; the total is divided by `len(loader.dataset)`.
        """
        accumulated = torch.zeros_like(self.flat_params())

        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            self.model.zero_grad()
            batch_nll = gaussian_nll(self.model(inputs, task_id), targets, self.model.log_sigma, reduction='sum')
            batch_nll.backward()

            squared = []
            for param in self.model.parameters():
                # Parameters unused in the forward pass have no gradient.
                grad = torch.zeros_like(param) if param.grad is None else param.grad
                squared.append(grad.view(-1).pow(2))
            accumulated += torch.cat(squared)

        return accumulated / len(loader.dataset)

    def update_stats(self, loader, task_id=None):
        """Snapshot the current parameters and add this task's diagonal precision."""
        self.prev_mu = self.flat_params().detach().clone()
        task_precision = self.diag_hessian(loader, task_id).detach()
        if self.precision is not None:
            self.precision += task_precision
        else:
            self.precision = task_precision

    def penalty(self):
        """Return `0.5 * lam * sum_i precision_i * (theta_i - mu_i)^2` (zero before any task)."""
        if self.prev_mu is None:
            return torch.tensor(0., device=device)
        drift = self.flat_params() - self.prev_mu
        weighted = self.precision * drift.pow(2)
        return 0.5 * self.lam * weighted.sum()

In [None]:
class LP_fancy:
    """Quadratic penalty with a Kronecker-factored precision per linear layer.

    For every `nn.Linear` layer the weight precision is represented by a pair
    `(Q, B)`: `Q` is an input-side factor of shape (in, in) built from the
    layer's input activations and `B` is an output-side factor of shape
    (out, out) built from the weight gradient. Biases use a diagonal precision.
    Only `nn.Linear` modules are regularized.

    Args:
        model: network to regularize.
        lam: penalty strength.
        eps: value added to the diagonal of every factor (and to the bias
            precision) to keep them well conditioned.
    """
    def __init__(self, model, lam=1.0, eps=1e-3):
        self.model  = model
        self.lam    = lam
        self.eps    = eps

        # State carried across tasks.
        self.mu_prev      = None  # list of (weight, bias) snapshots per linear layer
        self.prec_weights = None  # list of (Q, B) factor pairs per linear layer
        self.prec_biases  = None  # list of diagonal bias precisions per linear layer

    @staticmethod
    def linear_layers(model):
        """Return all `nn.Linear` modules of `model` in `model.modules()` order."""
        return [m for m in model.modules() if isinstance(m, nn.Linear)]

    def collect_activs(self, x, task_id):
        """Run one forward pass on the whole batch and record each linear layer's input.

        Returns:
            A list aligned with `linear_layers(self.model)`. An entry is the
            detached input tensor of that layer, or None when the layer did not
            take part in the forward pass (e.g. the head of another task).
        """
        layers = self.linear_layers(self.model)
        captured = {}

        def remember(module, inputs, output):
            # Key by module so the result cannot get out of step with `layers`.
            captured[module] = inputs[0].detach()

        handles = [m.register_forward_hook(remember) for m in layers]
        try:
            with torch.no_grad():
                model_device = next(self.model.parameters()).device
                self.model(x.to(model_device), task_id)
        finally:
            # Always detach the hooks, even if the forward pass raises.
            for handle in handles:
                handle.remove()
        return [captured.get(m) for m in layers]

    def kfac_blocks(self, loader, task_id):
        """Estimate the `(Q, B)` factors and bias precisions from `loader`.

        Per batch, `Q` receives `a^T a / n` from the layer inputs `a` and `B`
        receives `g g^T / n` where `g` is the weight gradient of the summed
        batch NLL. Both are averaged over batches and shifted by `eps * I`.
        Layers that are unused for this task receive only the `eps` shift.

        Raises:
            ValueError: if `loader` yields no batches.
        """
        layers = self.linear_layers(self.model)
        num_batches = len(loader)
        if num_batches == 0:
            raise ValueError("LP_fancy.kfac_blocks() needs a non-empty loader")

        def zeros_for(layer, size):
            return torch.zeros(size, size, device=layer.weight.device, dtype=layer.weight.dtype)

        Q_sum  = [zeros_for(m, m.in_features) for m in layers]
        B_sum  = [zeros_for(m, m.out_features) for m in layers]
        bias_h = [torch.zeros_like(m.bias) for m in layers]

        for x, y in loader:
            x, y = x.to(device), y.to(device)
            self.model.zero_grad()
            loss = gaussian_nll(self.model(x, task_id), y, self.model.log_sigma, reduction='sum')
            loss.backward()

            activs = self.collect_activs(x, task_id)

            for idx, m in enumerate(layers):
                a = activs[idx]
                # Skip layers that were not part of this task's forward/backward pass.
                if a is None or m.weight.grad is None:
                    continue

                Q_sum[idx] += a.t() @ a / a.size(0)

                g = m.weight.grad.detach()
                B_sum[idx] += (g @ g.t()) / x.size(0)

                if m.bias.grad is not None:
                    bias_h[idx] += (m.bias.grad.detach()**2) / x.size(0)

        # Adding eps * I once after averaging equals adding it to every batch term.
        kfac_blocks = []
        for Q, B in zip(Q_sum, B_sum):
            Q = Q / num_batches + self.eps * torch.eye(Q.size(0), device=Q.device, dtype=Q.dtype)
            B = B / num_batches + self.eps * torch.eye(B.size(0), device=B.device, dtype=B.dtype)
            kfac_blocks.append((Q, B))
        bias_blocks = [b / num_batches + self.eps for b in bias_h]

        return kfac_blocks, bias_blocks

    def penalty(self):
        """Return `0.5 * lam * sum_layers [ sum(B dW Q * dW) + sum(bias_prec * db^2) ]`."""
        if self.mu_prev is None:
            return torch.tensor(0.0, device=next(self.model.parameters()).device)

        total = 0.0
        for (Q, B), (W_star, b_star), m, bias_prec in zip(self.prec_weights, self.mu_prev, self.linear_layers(self.model), self.prec_biases):

            dW = m.weight - W_star
            # Element-wise product then sum evaluates the Kronecker quadratic form
            # without materialising the full (in*out) x (in*out) matrix.
            total += (B @ dW @ Q * dW).sum()

            db = m.bias - b_star
            total += (bias_prec * db.pow(2)).sum()

        return 0.5 * self.lam * total

    def update_stats(self, loader, task_id=None):
        """Snapshot the layer parameters and add this task's factors to the running ones.

        The factors of successive tasks are summed factor-wise (`Q += Q_new`,
        `B += B_new`), in place.
        """
        self.mu_prev = [(m.weight.detach().clone(), m.bias.detach().clone()) for m in self.linear_layers(self.model)]

        blocks, bias_blocks = self.kfac_blocks(loader, task_id)

        if self.prec_weights is None:
            self.prec_weights = blocks
            self.prec_biases  = bias_blocks
            return

        for (Q_old, B_old), (Q_new, B_new) in zip(self.prec_weights, blocks):
            Q_old += Q_new
            B_old += B_new

        for i in range(len(self.prec_biases)):
            self.prec_biases[i] += bias_blocks[i]


In [None]:
# Run the continual-learning experiment on 10 Permuted-MNIST tasks using the
# diagonal "LP" regularizer, one epoch per task.
perm = PermutedMNIST(num_tasks=10)
# NOTE(review): despite its name, `acc_ewc` holds the matrix returned by
# run_experiment (called `rmse_mat` there) for reg_type="LP"; it is neither an
# accuracy nor an EWC result. Consider renaming once all later uses are known.
acc_ewc = run_experiment(perm,
                         hidden=[100, 100],
                         reg_type="LP",
                         lam=100,
                         epochs=1,
                         batch=256,
                         lr = 1e-4)

TNone ep1:   0%|          | 0/235 [00:00<?, ?it/s]

LP Task 0  avg=0.18981243669986725  |  T0:0.1898124414317644


TNone ep1:   0%|          | 0/235 [00:00<?, ?it/s]

LP Task 1  avg=0.19708625972270966  |  T0:0.20504885126353486 T1:0.18878810323656595


TNone ep1:   0%|          | 0/235 [00:00<?, ?it/s]

LP Task 2  avg=0.20193132758140564  |  T0:0.21881347071136228 T1:0.19428038162889044 T2:0.19158436870364318


TNone ep1:   0%|          | 0/235 [00:00<?, ?it/s]

LP Task 3  avg=0.21214042603969574  |  T0:0.24429119167664495 T1:0.20569968740312472 T2:0.19584289180749584 T3:0.19917141425753623


TNone ep1:   0%|          | 0/235 [00:00<?, ?it/s]

LP Task 4  avg=0.217952698469162  |  T0:0.25866898935379595 T1:0.21601839531610373 T2:0.20534806174345058 T3:0.20487228197176555 T4:0.19950648390240974


TNone ep1:   0%|          | 0/235 [00:00<?, ?it/s]

LP Task 5  avg=0.21869894862174988  |  T0:0.26378811048361034 T1:0.22322139668674626 T2:0.20547656439364234 T3:0.20455596740669568 T4:0.20291455834747138 T5:0.20573127233530217


TNone ep1:   0%|          | 0/235 [00:00<?, ?it/s]

LP Task 6  avg=0.22130106389522552  |  T0:0.2705490180604564 T1:0.22959463566850313 T2:0.21054059911090225 T3:0.2074979964699593 T4:0.2098757427422818 T5:0.20768038530357705 T6:0.20578358956341594


TNone ep1:   0%|          | 0/235 [00:00<?, ?it/s]

LP Task 7  avg=0.22664503753185272  |  T0:0.28020909584442727 T1:0.23912127306158024 T2:0.21950294523935207 T3:0.21414700428125205 T4:0.21765622355727277 T5:0.21231309244175975 T6:0.21066804940588776 T7:0.2106539981700747


TNone ep1:   0%|          | 0/235 [00:00<?, ?it/s]

LP Task 8  avg=0.2305169701576233  |  T0:0.29584623076213856 T1:0.25176927703725743 T2:0.22290624401884554 T3:0.21658418279874847 T4:0.22116822614122392 T5:0.21173506895566255 T6:0.21439251917662647 T7:0.2162958273867431 T8:0.21033035016652227


TNone ep1:   0%|          | 0/235 [00:00<?, ?it/s]

LP Task 9  avg=0.23424573242664337  |  T0:0.29944598254357463 T1:0.26492422623801676 T2:0.233489629253786 T3:0.21754772511114553 T4:0.22444346959570674 T5:0.21608675076947342 T6:0.21613213725133637 T7:0.21805158795218307 T8:0.21461846705709972 T9:0.22276527898472206
