# Experiments with Fisher

This whole script takes ~45mins to run with an H100

In [1]:
!uv pip install ipykernel jupyter

[2mUsing Python 3.12.12 environment at: /pvc/repos/open-r1_safety/openr1_v2[0m
[2mAudited [1m2 packages[0m [2min 17ms[0m[0m


## Setup 'Update' and 'Protect' texts

In [2]:
import math
import copy
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.notebook import tqdm # this makes tqdm.write() work with notebooks!
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW

from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
from datasets import load_dataset, load_from_disk

from trl.trainer.sft_trainer import DataCollatorForLanguageModeling
from copy import deepcopy

In [3]:
def add_generation_chat_template(tokenizer):
    if "qwen" in tokenizer.name_or_path.lower():
        # we have to use DataCollatorForLanguageModeling with completion_only_loss=True
        # however, for that tokenizer needs to have return_assistant_tokens_mask=True, and qwen decided against adding support for {% generation %} / {% endgeneration %} functionality
        # so we download a community qwen3 chat template that has it
        !wget -O all_assistant.jinja --no-check-certificate https://raw.githubusercontent.com/HarryMayne/qwen_3_chat_templates/refs/heads/main/all_assistant.jinja
        !mv all_assistant.jinja chat_templates/qwen_all_assistant.jinja
        with open('chat_templates/qwen_all_assistant.jinja', 'r') as f:
            tokenizer.chat_template = f.read()
    if "smollm2" in tokenizer.name_or_path.lower():
        with open('chat_templates/smollm2_all_assistant.jinja', 'r') as f:
            tokenizer.chat_template = f.read()
    if "llama" in tokenizer.name_or_path.lower():
        with open('chat_templates/llama3_all_assistant.jinja', 'r') as f:
            tokenizer.chat_template = f.read()
    return tokenizer

def load_or_preprocess_dataset(model_id, dataset_id, tokenizer, max_length=4096):
    local_ds_id = f"datasets/{model_id}/{dataset_id}"
    num_proc = 16
    if True:
        print(f"Dataset not found locally, processing and caching...")
        raw_dataset = load_dataset(dataset_id)["train"]
        def preprocess(example):
            tokenized = tokenizer.apply_chat_template(
                example["messages"],
                tokenize=True,
                return_assistant_tokens_mask=True,
                return_dict=True,
            )
            return {
                "input_ids": tokenized["input_ids"],
                "assistant_masks": tokenized["assistant_masks"],
            }
        
        tokenized_dataset = raw_dataset.map(preprocess, remove_columns=raw_dataset.column_names, num_proc=num_proc, desc="Tokenizing")
        def shorter_than(example):
            return len(example["input_ids"]) <= max_length
        final_dataset = tokenized_dataset.filter(shorter_than, num_proc=num_proc, desc=f"Filtering to max length {max_length}")
        print(f"Tokenized: {len(tokenized_dataset)}, After filtering: {len(final_dataset)}")
        final_dataset.save_to_disk(local_ds_id)
    return final_dataset

def create_dataloader(tokenized_dataset, batch_size):
    collator = DataCollatorForLanguageModeling(pad_token_id=tokenizer.pad_token_id, completion_only_loss=True)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
    )
    return dataloader

In [4]:
from transformers import set_seed
random_seed = 42
set_seed(42)

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # for batching
tokenizer = add_generation_chat_template(tokenizer)
print(tokenizer.chat_template)

batch_size = 1 #8 for smollm2, 1 for llama

Device: cuda
{{- bos_token }}
{%- if custom_tools is defined %}
    {%- set tools = custom_tools %}
{%- endif %}
{%- if not tools_in_user_message is defined %}
    {%- set tools_in_user_message = true %}
{%- endif %}
{%- if not date_string is defined %}
    {%- set date_string = "26 Jul 2024" %}
{%- endif %}
{%- if not tools is defined %}
    {%- set tools = none %}
{%- endif %}

{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
    {%- set system_message = messages[0]['content']|trim %}
    {%- set messages = messages[1:] %}
{%- else %}
    {%- set system_message = "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>
...
</think>
<answer>
...
</answer>" %}
{%- endif %}

{#- System message + builtin tools #}
{{- "<|start_header

In [6]:
update_id = "Neelectric/OpenR1-Math-220k_extended_Llama3_4096toks"
update_ds = load_or_preprocess_dataset(model_name, update_id, tokenizer, 1024)

Dataset not found locally, processing and caching...
Tokenized: 86158, After filtering: 3608


Saving the dataset (0/1 shards):   0%|          | 0/3608 [00:00<?, ? examples/s]

In [7]:
retain_id = "Neelectric/wildguardmix_Llama-3.1-8B-Instruct_4096toks"
retain_ds = load_or_preprocess_dataset(model_name, retain_id, tokenizer, 1024)

Dataset not found locally, processing and caching...
Tokenized: 86745, After filtering: 74766


Saving the dataset (0/2 shards):   0%|          | 0/74766 [00:00<?, ? examples/s]

In [8]:
full_length = len(update_ds) //2
print(full_length)
update_ds = update_ds.select(range(full_length))
retain_ds = retain_ds.select(range(full_length))
print(len(retain_ds))
update_ds = update_ds.shuffle(seed=random_seed)
retain_ds = retain_ds.shuffle(seed=random_seed)

1804
1804


In [9]:
# Train / test splits
num_train = int(0.8 * full_length)
print(num_train)
retain_train_ds = retain_ds.select(range(num_train))
retain_test_ds = retain_ds.select(range(num_train, full_length))
print(len(retain_train_ds))
print(len(retain_test_ds))

update_train_ds = update_ds.select(range(num_train))
update_test_ds = update_ds.select(range(num_train, full_length))
print(len(update_train_ds))
print(len(update_test_ds))

1443
1443
361
1443
361


In [10]:
# block_size = 64
# batch_size = 8

# retain_train_ds  = LineByLineLMDataset(retain_train, tokenizer, block_size)
# retain_test_ds   = LineByLineLMDataset(retain_test,  tokenizer, block_size)
# update_train_ds = LineByLineLMDataset(update_train, tokenizer, block_size)
# update_test_ds  = LineByLineLMDataset(update_test,  tokenizer, block_size)

# update_loader = DataLoader(update_train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# ba_loader  = DataLoader(retain_train_ds,  batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# len(update_train_ds), len(retain_train_ds)

update_loader = create_dataloader(update_train_ds, batch_size)
retain_loader = create_dataloader(update_test_ds, batch_size)

In [11]:
@torch.no_grad()
def eval_ppl(model, dataset, name, batch_size_eval=8, disable_tqdm=True):
    model.eval()
    loader = create_dataloader(dataset, batch_size_eval)
    total_loss = 0.0
    total_tokens = 0
    for batch in tqdm(loader, disable=disable_tqdm):
        batch = {k: v.to(device) for k, v in batch.items()}
        out = model(**batch)
        loss = out.loss * batch["attention_mask"].sum()
        total_loss += loss.item()
        total_tokens += batch["attention_mask"].sum().item()
    ppl = math.exp(total_loss / total_tokens)
    print(f"{name} perplexity: {ppl:.3f}")
    model.train()
    return ppl


## Before any optim: ppl on train and protect before fine-tuning

In [12]:
base_model = AutoModelForCausalLM.from_pretrained(model_name)
model = copy.deepcopy(base_model).to(device)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [13]:
eng_ppl = eval_ppl(model, update_test_ds, "Baseline update before optim", batch_size_eval=8, disable_tqdm=False)
ba_ppl = eval_ppl(model, retain_test_ds, "Baseline retain before optim", batch_size_eval=8, disable_tqdm=False)

  0%|          | 0/46 [00:00<?, ?it/s]

Baseline update before optim perplexity: 2.866


  0%|          | 0/46 [00:00<?, ?it/s]

Baseline retain before optim perplexity: 1.734


In [14]:
from torch.optim import AdamW

def run_baseline_adam(num_epochs=2, lr=1e-5):
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    model = copy.deepcopy(base_model).to(device)
    optimizer = AdamW(model.parameters(), lr=lr)

    print(f"Before opt:")
    update_ppl = eval_ppl(model, update_test_ds, "Update new")
    retain_ppl  = eval_ppl(model, retain_test_ds,  "Retain protected")

    print("=== Baseline Adam: train on Update only ===")
    for epoch in range(num_epochs):
        for step, batch in enumerate(update_loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            model.zero_grad()
            out = model(**batch)
            loss = out.loss
            loss.backward()
            optimizer.step()

            if (step + 1) % 100 == 0:
                print(f"[Epoch {epoch} Step {step+1}] loss_new = {loss.item():.4f}")

        print(f"Epoch {epoch} evaluation:")
        update_ppl = eval_ppl(model, update_test_ds, "Update new")
        retain_ppl  = eval_ppl(model, retain_test_ds,  "Retain protected")
    return model
# SmolLM2 report mentions they use lr=3e-4 throughout SFT (page 8 section 5.2) https://arxiv.org/pdf/2502.02737
baseline_model = run_baseline_adam(num_epochs=4, lr=3e-4)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [11]:
def estimate_fisher_on_retain(model, num_batches=200):
    model.eval()
    params = [p for p in model.parameters() if p.requires_grad]
    fisher = {p: torch.zeros_like(p.data) for p in params}

    # loader = DataLoader(retain_train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    loader = create_dataloader(retain_train_ds, batch_size)
    it = iter(loader)
    for i in tqdm(range(num_batches)):
        try:
            batch = next(it)
        except StopIteration:
            it = iter(loader)
            batch = next(it)
        batch = {k: v.to(device) for k, v in batch.items()}
        model.zero_grad()
        out = model(**batch)
        loss = out.loss
        loss.backward()
        for p in params:
            if p.grad is None:
                continue
            fisher[p] += p.grad.data.pow(2)
    for p in params:
        fisher[p] /= num_batches
    model.train()
    return fisher

def run_ewc(num_epochs=2, lr=5e-5, ewc_lambda=50.0):
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    model = copy.deepcopy(base_model).to(device)

    print("Estimating Fisher on Retain (protected) ...")
    fisher = estimate_fisher_on_retain(model, num_batches=100)
    theta0 = copy.deepcopy(model).to(device)

    optimizer = AdamW(model.parameters(), lr=lr)

    params = [p for p in model.parameters() if p.requires_grad]

    print("=== EWC: train on Update with Retain EWC penalty ===")
    for epoch in range(num_epochs):
        for step, batch in enumerate(update_loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            model.zero_grad()
            out = model(**batch)
            loss_new = out.loss

            ewc_loss = 0.0
            for p, p0 in zip(params, theta0.parameters()):
                ewc_loss = ewc_loss + (fisher[p] * (p - p0).pow(2)).sum()
            total_loss = loss_new + 0.5 * ewc_lambda * ewc_loss

            total_loss.backward()
            optimizer.step()

            if (step + 1) % 100 == 0:
                print(f"[Epoch {epoch} Step {step+1}] loss_new={loss_new.item():.4f}, ewc_loss={ewc_loss.item():.4f}")

        print(f"Epoch {epoch} evaluation:")
        update_ppl = eval_ppl(model, update_test_ds, "Update new (EWC)")
        ba_ppl  = eval_ppl(model, retain_test_ds,  "Retain protected (EWC)")
    return model

ewc_model = run_ewc(num_epochs=3, lr=3e-4, ewc_lambda=50.0)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Estimating Fisher on Retain (protected) ...


  0%|          | 0/100 [00:00<?, ?it/s]

=== EWC: train on Update with Retain EWC penalty ===


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.96 GiB. GPU 0 has a total capacity of 139.81 GiB of which 1.91 GiB is free. Process 453783 has 137.89 GiB memory in use. Of the allocated memory 135.57 GiB is allocated by PyTorch, and 1.60 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [16]:
def run_protected_adam(
    num_epochs=2,
    lr=5e-5,
    alpha_geom=1.0,
    beta_geom=10.0,
    gamma_exp=0.5,
    subset_update_every=5,
    rho_all=0.99,
    rho_sub=0.99,
):
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    model = copy.deepcopy(base_model).to(device)
    model.train()

    params = [p for p in model.parameters() if p.requires_grad]

    state = {}
    for p in params:
        state[p] = {
            "m": torch.zeros_like(p.data),
            "v_all": torch.zeros_like(p.data),
            "v_sub": torch.zeros_like(p.data),
        }

    beta1 = 0.9
    eps = 1e-6
    global_step = 0

    def protected_adam_step():
        nonlocal global_step
        global_step += 1
        for p in params:
            if p.grad is None:
                continue
            grad = p.grad.data
            s = state[p]

            # first moment
            s["m"].mul_(beta1).add_(grad, alpha=1 - beta1)

            # second moment on "all" (new Update) data
            s["v_all"].mul_(rho_all).addcmul_(grad, grad, value=1 - rho_all)

            v_all = s["v_all"]
            v_sub = s["v_sub"]
            v_protect = alpha_geom * v_all + beta_geom * v_sub

            m_hat = s["m"] / (1 - beta1**global_step)
            denom = (v_protect + eps).pow(gamma_exp)
            step = m_hat / denom
            p.data.add_(step, alpha=-lr)

    def update_subset_curvature():
        for p in params:
            if p.grad is None:
                continue
            grad = p.grad.data
            s = state[p]
            s["v_sub"].mul_(rho_sub).addcmul_(grad, grad, value=1 - rho_sub)

    retain_iter = iter(retain_loader)

    print("=== ProtectedAdam-γ: geometry shaped by Retain subset ===")
    print(f"alpha_geom={alpha_geom}, beta_geom={beta_geom}, gamma_exp={gamma_exp}")
    for epoch in range(num_epochs):
        for step, batch in enumerate(update_loader):
            # 1) Update batch: gradient for new task
            batch = {k: v.to(device) for k, v in batch.items()}
            model.zero_grad()
            out = model(**batch)
            loss = out.loss
            loss.backward()

            # 2) Take ProtectedAdam step (updates v_all + params)
            protected_adam_step()

            # 3) Occasionally update subset curvature using Retain
            if (step + 1) % subset_update_every == 0:
                try:
                    retain_batch = next(retain_iter)
                except StopIteration:
                    retain_iter = iter(retain_loader)
                    retain_batch = next(retain_iter)
                retain_batch = {k: v.to(device) for k, v in retain_batch.items()}
                model.zero_grad()
                retain_out = model(**retain_batch)
                retain_loss = retain_out.loss
                retain_loss.backward()
                update_subset_curvature()
                model.zero_grad()

            if (step + 1) % 100 == 0:
                print(f"[Epoch {epoch} Step {step+1}] loss_new = {loss.item():.4f}")

        print(f"Epoch {epoch} evaluation:")
        update_ppl = eval_ppl(model, update_test_ds, "Update new (ProtectedAdam-γ)")
        retain_ppl  = eval_ppl(model, retain_test_ds,  "Retain protected (ProtectedAdam-γ)")

    return model

protected_model = run_protected_adam(
    num_epochs=3,
    lr=3e-4,
    alpha_geom=1.0,
    beta_geom=10.0,   # strength of protected geometry
    gamma_exp=0.5,    # between 0.5 (Adam) and 1.0 (diag NGD)
    subset_update_every=5,
)


=== ProtectedAdam-γ: geometry shaped by Retain subset ===
alpha_geom=1.0, beta_geom=10.0, gamma_exp=0.5
[Epoch 0 Step 100] loss_new = 1.1522
Epoch 0 evaluation:
Update new (ProtectedAdam-γ) perplexity: 2.854
Retain protected (ProtectedAdam-γ) perplexity: 4.918
[Epoch 1 Step 100] loss_new = 0.9689
Epoch 1 evaluation:
Update new (ProtectedAdam-γ) perplexity: 2.734
Retain protected (ProtectedAdam-γ) perplexity: 5.008
[Epoch 2 Step 100] loss_new = 0.8218
Epoch 2 evaluation:
Update new (ProtectedAdam-γ) perplexity: 2.670
Retain protected (ProtectedAdam-γ) perplexity: 5.125


In [17]:
def run_protected_adam2(
    num_epochs=3,
    lr=5e-5,
    alpha_geom=1.0,
    beta_geom=10.0,
    gamma_exp=0.5,
    subset_update_every=5,
    rho_all=0.99,
    rho_sub=0.99,
):
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    model = copy.deepcopy(base_model).to(device)
    model.train()

    params = [p for p in model.parameters() if p.requires_grad]

    state = {}
    for p in params:
        state[p] = {
            "m": torch.zeros_like(p.data),
            "v_all": torch.zeros_like(p.data),
            "v_sub": torch.zeros_like(p.data),
        }

    beta1 = 0.9
    eps = 1e-6
    global_step = 0

    def protected_adam_step():
        nonlocal global_step
        global_step += 1

        # First pass: update moments, compute v_protect, and accumulate
        # the mean denominators for γ=0.5 (baseline) and γ=gamma_exp
        temp = {}
        sum_baseline = 0.0
        sum_gamma = 0.0
        count_tensors = 0

        for p in params:
            if p.grad is None:
                continue
            grad = p.grad.data
            s = state[p]

            # First moment
            s["m"].mul_(beta1).add_(grad, alpha=1 - beta1)

            # Second moment on "all" (new Update) data
            s["v_all"].mul_(rho_all).addcmul_(grad, grad, value=1 - rho_all)

            v_all = s["v_all"]
            v_sub = s["v_sub"]
            v_protect = alpha_geom * v_all + beta_geom * v_sub

            # Bias-corrected first moment (optional but keeps Adam-like behaviour)
            m_hat = s["m"] / (1 - beta1**global_step)

            denom_baseline = (v_protect + eps).pow(0.5)
            denom_gamma = (v_protect + eps).pow(gamma_exp)

            sum_baseline += denom_baseline.mean()
            sum_gamma += denom_gamma.mean()
            count_tensors += 1

            temp[p] = {
                "m_hat": m_hat,
                "v_protect": v_protect,
            }

        if count_tensors == 0:
            return

        # Renormalization factor so that average step size matches γ=0.5 case
        scale = (sum_baseline / sum_gamma).detach()

        # Second pass: apply update with renormalized step size
        for p in params:
            if p.grad is None or p not in temp:
                continue
            buf = temp[p]
            m_hat = buf["m_hat"]
            v_protect = buf["v_protect"]

            denom_gamma = (v_protect + eps).pow(gamma_exp)
            step = (m_hat / denom_gamma) * scale
            p.data.add_(step, alpha=-lr)

    def update_subset_curvature():
        for p in params:
            if p.grad is None:
                continue
            grad = p.grad.data
            s = state[p]
            s["v_sub"].mul_(rho_sub).addcmul_(grad, grad, value=1 - rho_sub)

    retain_iter = iter(retain_loader)

    print(f"Before opt:")
    update_ppl = eval_ppl(model, update_test_ds, "Update new (ProtectedAdam-γ)")
    retain_ppl  = eval_ppl(model, retain_test_ds,  "Retain protected (ProtectedAdam-γ)")


    print("=== ProtectedAdam-γ: geometry shaped by Retain subset ===")
    print(f"alpha_geom={alpha_geom}, beta_geom={beta_geom}, gamma_exp={gamma_exp}")
    for epoch in range(num_epochs):
        for step, batch in enumerate(update_loader):
            # 1) Update batch: gradient for new task
            batch = {k: v.to(device) for k, v in batch.items()}
            model.zero_grad()
            out = model(**batch)
            loss = out.loss
            loss.backward()

            # 2) Take ProtectedAdam step (updates v_all + params)
            protected_adam_step()

            # 3) Occasionally update subset curvature using Retain
            if (step + 1) % subset_update_every == 0:
                try:
                    retain_batch = next(retain_iter)
                except StopIteration:
                    retain_iter = iter(retain_loader)
                    retain_batch = next(retain_iter)
                retain_batch = {k: v.to(device) for k, v in retain_batch.items()}
                model.zero_grad()
                retain_out = model(**retain_batch)
                retain_loss = retain_out.loss
                retain_loss.backward()
                update_subset_curvature()
                model.zero_grad()

            if (step + 1) % 100 == 0:
                print(f"[Epoch {epoch} Step {step+1}] loss_new = {loss.item():.4f}")

        print(f"Epoch {epoch} evaluation:")
        update_ppl = eval_ppl(model, update_test_ds, "Update new (ProtectedAdam-γ)")
        retain_ppl  = eval_ppl(model, retain_test_ds,  "Retain protected (ProtectedAdam-γ)")

    return model


protected_model2 = run_protected_adam2(
    num_epochs=3,
    lr=3e-4,
    alpha_geom=1.0,
    beta_geom=10.0,   # strength of protected geometry
    gamma_exp=0.5,    # between 0.5 (Adam) and 1.0 (diag NGD)
    subset_update_every=5,
)


Before opt:
Update new (ProtectedAdam-γ) perplexity: 4.730
Retain protected (ProtectedAdam-γ) perplexity: 4.484
=== ProtectedAdam-γ: geometry shaped by Retain subset ===
alpha_geom=1.0, beta_geom=10.0, gamma_exp=0.5
[Epoch 0 Step 100] loss_new = 1.3416
Epoch 0 evaluation:
Update new (ProtectedAdam-γ) perplexity: 2.853
Retain protected (ProtectedAdam-γ) perplexity: 4.900
[Epoch 1 Step 100] loss_new = 1.0404
Epoch 1 evaluation:
Update new (ProtectedAdam-γ) perplexity: 2.734
Retain protected (ProtectedAdam-γ) perplexity: 4.990
[Epoch 2 Step 100] loss_new = 1.0139
Epoch 2 evaluation:
Update new (ProtectedAdam-γ) perplexity: 2.670
Retain protected (ProtectedAdam-γ) perplexity: 5.085


In [18]:
def run_replay(
    num_epochs=2,
    lr=5e-5,
    subset_update_every=5,
    replay_weight=1.0,   # λ: strength of Retain replay loss
):
    """
    Experience Replay baseline.

    - Optimizes Update CE loss every step.
    - Every `subset_update_every` steps, also optimizes Retain CE.
    - Total loss = CE_update + replay_weight * CE_Retain.
    - Uses plain AdamW.
    - No curvature, no shielding, no geometry.
    """

    # Load model
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    model = copy.deepcopy(base_model).to(device)
    model.train()

    optimizer = AdamW(model.parameters(), lr=lr)

    # Retain iterator for replay
    retain_iter = iter(retain_loader)

    print("=== Replay baseline: Update training + Retain replay ===")
    print(f"subset_update_every={subset_update_every}, replay_weight={replay_weight}")

    print("Before opt:")
    eval_ppl(model, update_test_ds, "Update new (replay)")
    eval_ppl(model, retain_test_ds,  "Retain protected (replay)")

    for epoch in range(num_epochs):
        for step, batch in enumerate(update_loader):
            batch = {k: v.to(device) for k, v in batch.items()}

            # Update forward/backward
            model.zero_grad()
            out = model(**batch)
            loss_new = out.loss
            total_loss = loss_new

            # Retain replay every N steps
            if (step + 1) % subset_update_every == 0:
                try:
                    ba_batch = next(retain_iter)
                except StopIteration:
                    retain_iter = iter(retain_loader)
                    ba_batch = next(retain_iter)
                ba_batch = {k: v.to(device) for k, v in ba_batch.items()}

                ba_out = model(**ba_batch)
                ba_loss = ba_out.loss

                total_loss = loss_new + replay_weight * ba_loss

            # Backprop + update
            total_loss.backward()
            optimizer.step()

            # Logging
            if (step + 1) % 100 == 0:
                if (step + 1) % subset_update_every == 0:
                    print(
                        f"[Epoch {epoch} Step {step+1}] "
                        f"loss_new={loss_new.item():.4f}, "
                        f"loss_replay={ba_loss.item():.4f}, "
                        f"total={total_loss.item():.4f}"
                    )
                else:
                    print(f"[Epoch {epoch} Step {step+1}] loss_new={loss_new.item():.4f}")

        # End epoch eval
        print(f"Epoch {epoch} evaluation:")
        eval_ppl(model, update_test_ds, "Update new (replay)")
        eval_ppl(model, retain_test_ds,  "Retain protected (replay)")

    return model


replay_model = run_replay(
    num_epochs=3,
    lr=3e-4,
    subset_update_every=5,
    replay_weight=1.0,
)


=== Replay baseline: Update training + Retain replay ===
subset_update_every=5, replay_weight=1.0
Before opt:
Update new (replay) perplexity: 4.728
Retain protected (replay) perplexity: 4.502
[Epoch 0 Step 100] loss_new=0.9579, loss_replay=0.9523, total=1.9102
Epoch 0 evaluation:
Update new (replay) perplexity: 2.200
Retain protected (replay) perplexity: 5.595
[Epoch 1 Step 100] loss_new=0.7064, loss_replay=0.7346, total=1.4410
Epoch 1 evaluation:
Update new (replay) perplexity: 1.863
Retain protected (replay) perplexity: 6.173
[Epoch 2 Step 100] loss_new=0.5434, loss_replay=0.5460, total=1.0894
Epoch 2 evaluation:
Update new (replay) perplexity: 1.609
Retain protected (replay) perplexity: 6.981


In [19]:
def estimate_fisher_retain(model, num_batches=200):
    model.eval()
    fisher = {
        name: torch.zeros_like(p.data)
        for name, p in model.named_parameters()
        if p.requires_grad
    }

    loader = create_dataloader(retain_train_ds, batch_size)
    it = iter(loader)

    for i in range(num_batches):
        try:
            batch = next(it)
        except StopIteration:
            it = iter(loader)
            batch = next(it)

        batch = {k: v.to(device) for k, v in batch.items()}
        model.zero_grad()
        out = model(**batch)
        loss = out.loss
        loss.backward()

        for name, p in model.named_parameters():
            if not p.requires_grad or p.grad is None:
                continue
            fisher[name] += p.grad.data.pow(2)

    for name in fisher:
        fisher[name] /= num_batches

    model.train()
    return fisher


In [20]:
def estimate_model_fisher_retain(model, num_batches=200, top_k=100):
    """
    Compute *model Fisher* diagonal using KL(p_ref || p_model),
    with optional top-K truncation of the reference distribution.

    top_k < 0  → use full distribution (no truncation)
    top_k > 0  → keep only top_k tokens in reference distribution
    """

    # Freeze reference model θ0
    ref_model = copy.deepcopy(model).eval().to(device)
    for p in ref_model.parameters():
        p.requires_grad = False

    model.eval()

    fisher = {
        name: torch.zeros_like(p.data)
        for name, p in model.named_parameters()
        if p.requires_grad
    }

    loader = create_dataloader(retain_train_ds, batch_size)
    it = iter(loader)

    for i in range(num_batches):
        try:
            batch = next(it)
        except StopIteration:
            it = iter(loader)
            batch = next(it)

        batch = {k: v.to(device) for k, v in batch.items()}

        # ---- 1. Reference distribution ----
        with torch.no_grad():
            ref_logits = ref_model(**batch).logits
            ref_probs_full = ref_logits.softmax(dim=-1)  # shape [B, T, V]

        # ---- 2. Possibly truncate to top-K ----
        if top_k is not None and top_k > 0:
            # Get top-K indices for each token
            top_vals, top_idx = torch.topk(ref_probs_full, k=top_k, dim=-1)
            # Renormalize probs over top-K
            ref_probs = top_vals / top_vals.sum(dim=-1, keepdim=True)
            # Make a tensor of zeros [B,T,V]
            ref_probs_k = torch.zeros_like(ref_probs_full)
            # Scatter top-K probabilities back into vocab dimension
            ref_probs_k.scatter_(-1, top_idx, ref_probs)
            ref_probs = ref_probs_k
        else:
            # use full distribution
            ref_probs = ref_probs_full

        # ---- 3. Model logits ----
        logits = model(**batch).logits
        log_probs = logits.log_softmax(dim=-1)

        # ---- 4. KL(p_ref || p_model) ----
        # KL per token: Σ_i q_i log(q_i/p_i)
        kl = (ref_probs * (ref_probs.log() - log_probs)).sum(dim=-1)
        loss = kl.mean()

        # ---- 5. Backprop = model Fisher at θ0 ----
        model.zero_grad()
        loss.backward()

        # ---- 6. Accumulate grad^2 ----
        for name, p in model.named_parameters():
            if not p.requires_grad or p.grad is None:
                continue
            fisher[name] += p.grad.data.pow(2)

    # Average
    for name in fisher:
        fisher[name] /= num_batches

    model.train()
    return fisher


In [21]:
def run_protected_adam_precomputed(
    num_epochs=2,
    lr=5e-5,
    alpha_geom=1.0,
    beta_geom=10.0,
    gamma_exp=0.5,
    rho_all=0.99,
    fisher_sub=None,   # dict[name -> tensor]
):
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    model = copy.deepcopy(base_model).to(device)
    model.train()

    # We'll work with named parameters for alignment
    named_params = [
        (name, p) for name, p in model.named_parameters()
        if p.requires_grad
    ]

    state = {}
    for name, p in named_params:
        if fisher_sub is not None and name in fisher_sub:
            v_sub_init = fisher_sub[name].clone().to(device)
        else:
            v_sub_init = torch.zeros_like(p.data)

        state[name] = {
            "m": torch.zeros_like(p.data),
            "v_all": torch.zeros_like(p.data),
            "v_sub": v_sub_init,
        }

    beta1 = 0.9
    eps = 1e-6
    global_step = 0

    def protected_adam_step():
        nonlocal global_step
        global_step += 1
        for name, p in named_params:
            if p.grad is None:
                continue
            grad = p.grad.data
            s = state[name]

            # first moment
            s["m"].mul_(beta1).add_(grad, alpha=1 - beta1)

            # second moment on "all" (new Update) data
            s["v_all"].mul_(rho_all).addcmul_(grad, grad, value=1 - rho_all)

            v_all = s["v_all"]
            v_sub = s["v_sub"]  # fixed precomputed Fisher
            v_protect = alpha_geom * v_all + beta_geom * v_sub

            m_hat = s["m"] / (1 - beta1**global_step)
            denom = (v_protect + eps).pow(gamma_exp)
            step = m_hat / denom
            p.data.add_(step, alpha=-lr)

    print("=== ProtectedAdam-γ with precomputed Retain Fisher ===")
    for epoch in range(num_epochs):
        for step, batch in enumerate(update_loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            model.zero_grad()
            out = model(**batch)
            loss = out.loss
            loss.backward()
            protected_adam_step()

        print(f"Epoch {epoch} evaluation:")
        eval_ppl(model, update_test_ds, "Update new (precomputed-Fisher)")
        eval_ppl(model, retain_test_ds,  "Retain protected (precomputed-Fisher)")

    return model


# 1. Make a base model for Fisher estimation
base_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)


# 2. Estimate Fisher on Retain ONCE
mfisher_retain = estimate_model_fisher_retain(base_model, num_batches=200)

# 3. Run Update finetuning using precomputed Fisher, no Retain batches
protected_model_pre_mfisher = run_protected_adam_precomputed(
    num_epochs=3,
    lr=1e-5,
    alpha_geom=1.0,
    beta_geom=10.0,
    gamma_exp=0.5,
    rho_all=0.99,
    fisher_sub=mfisher_retain,   # <- pass the dict here
)


# 2. Estimate Fisher on Retain ONCE
fisher_retain = estimate_fisher_retain(base_model, num_batches=200)

# 3. Run Update finetuning using precomputed Fisher, no Retain batches
protected_model_pre = run_protected_adam_precomputed(
    num_epochs=3,
    lr=3e-4,
    alpha_geom=1.0,
    beta_geom=10.0,
    gamma_exp=0.5,
    rho_all=0.99,
    fisher_sub=fisher_retain,   # <- pass the dict here
)


=== ProtectedAdam-γ with precomputed Retain Fisher ===
Epoch 0 evaluation:
Update new (precomputed-Fisher) perplexity: 3.732
Retain protected (precomputed-Fisher) perplexity: 4.396
Epoch 1 evaluation:
Update new (precomputed-Fisher) perplexity: 3.462
Retain protected (precomputed-Fisher) perplexity: 4.436
Epoch 2 evaluation:
Update new (precomputed-Fisher) perplexity: 3.329
Retain protected (precomputed-Fisher) perplexity: 4.441
=== ProtectedAdam-γ with precomputed Retain Fisher ===
Epoch 0 evaluation:
Update new (precomputed-Fisher) perplexity: 2.894
Retain protected (precomputed-Fisher) perplexity: 4.497
Epoch 1 evaluation:
Update new (precomputed-Fisher) perplexity: 2.758
Retain protected (precomputed-Fisher) perplexity: 4.599
Epoch 2 evaluation:
Update new (precomputed-Fisher) perplexity: 2.690
Retain protected (precomputed-Fisher) perplexity: 4.679


In [22]:
# does not work, ignore for now
def run_protected_adam_precomputed2(
    num_epochs=2,
    lr=5e-5,
    alpha_geom=1.0,      # scale for v_all (Adam geometry)
    beta_geom=10.0,      # strength of protection retainom v_sub
    gamma_exp=0.5,       # exponent applied only to normalized v_sub
    rho_all=0.99,
    fisher_sub=None,     # dict[name -> tensor], precomputed Fisher on Retain
):
    """
    Protected Adam with precomputed Fisher (additive version).

    - v_all: EMA of grad^2 on Update (new task), like Adam.
    - v_sub: fixed Fisher retainom Retain (protected capability), precomputed.
    - v_sub is normalized globally once to be dimensionless.

    Update (per-parameter i):
        v_all_i ← EMA of g_i^2
        v_sub_i ≈ Fisher_i

        v_sub_scaled_i = v_sub_i / global_mean(v_sub)

        base_rms_i   = sqrt(alpha_geom * v_all_i)
        protect_i    = beta_geom * (v_sub_scaled_i ** gamma_exp)

        denom_i = base_rms_i + protect_i + eps

        Δθ_i = -lr * m_hat_i / denom_i

    Properties:
      - If fisher_sub is None or beta_geom = 0 -> exactly Adam.
      - If v_sub is small -> denom ≈ base_rms -> Adam-like.
      - If v_sub is large -> extra additive penalty in denom -> stronger protection.
    """

    # 1) Start retainom base GPT-2
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    model = copy.deepcopy(base_model).to(device)
    model.train()

    # 2) Collect named parameters to align with fisher_sub[name]
    named_params = [
        (name, p) for name, p in model.named_parameters()
        if p.requires_grad
    ]

    # 3) Initialize state (m, v_all, v_sub)
    state = {}
    for name, p in named_params:
        v_sub_init = torch.zeros_like(p.data)
        if fisher_sub is not None and name in fisher_sub:
            v_sub_init = fisher_sub[name].clone().to(p.data.device)
        state[name] = {
            "m": torch.zeros_like(p.data),
            "v_all": torch.zeros_like(p.data),
            "v_sub": v_sub_init,
        }

    # 4) Compute a global mean of v_sub for normalization (dimensionless)
    if fisher_sub is not None:
        total_sum = 0.0
        total_count = 0
        for name, p in named_params:
            v_sub = state[name]["v_sub"]
            if v_sub.numel() > 0:
                total_sum += v_sub.sum().item()
                total_count += v_sub.numel()
        if total_count > 0:
            global_vsub_mean = total_sum / total_count
        else:
            global_vsub_mean = 1.0
    else:
        global_vsub_mean = 1.0

    beta1 = 0.9
    eps = 1e-8
    global_step = 0

    print("=== ProtectedAdam-precomputed2 (additive): Adam base + Fisher protection ===")
    print(
        f"alpha_geom={alpha_geom}, beta_geom={beta_geom}, "
        f"gamma_exp={gamma_exp}, rho_all={rho_all}, "
        f"global_vsub_mean={global_vsub_mean:.3e}"
    )

    for epoch in range(num_epochs):
        for step, batch in enumerate(update_loader):
            batch = {k: v.to(device) for k, v in batch.items()}

            # ----- 1) Forward/backward on Update (new task) -----
            model.zero_grad()
            out = model(**batch)
            loss = out.loss
            loss.backward()

            global_step += 1

            # ----- 2) Protected Adam step (additive Fisher term) -----
            with torch.no_grad():
                for name, p in named_params:
                    if p.grad is None:
                        continue

                    g = p.grad.data
                    s = state[name]

                    # First moment (Adam)
                    s["m"].mul_(beta1).add_(g, alpha=1 - beta1)

                    # Second moment on "all" (new Update) data (Adam-style)
                    s["v_all"].mul_(rho_all).addcmul_(g, g, value=1 - rho_all)

                    v_all = s["v_all"]
                    v_sub = s["v_sub"]

                    # Base Adam geometry: sqrt of v_all (scaled)
                    base_rms = (alpha_geom * v_all).sqrt()

                    # Normalized protective curvature retainom v_sub (dimensionless)
                    if fisher_sub is not None and beta_geom != 0.0 and global_vsub_mean > 0.0:
                        v_sub_scaled = v_sub / (global_vsub_mean + 1e-12)
                        v_sub_scaled = torch.clamp(v_sub_scaled, min=0.0)  # safety
                        protect_term = beta_geom * v_sub_scaled.pow(gamma_exp)
                    else:
                        protect_term = 0.0

                    m_hat = s["m"] / (1 - beta1**global_step)

                    # ADDITIVE protection: denom = base_rms + protective term
                    denom = base_rms + protect_term + eps
                    step_dir = m_hat / denom

                    p.data.add_(step_dir, alpha=-lr)

            if (step + 1) % 100 == 0:
                print(f"[Epoch {epoch} Step {step+1}] loss_new = {loss.item():.4f}")

        # ----- 3) Epoch-end evaluation -----
        print(f"Epoch {epoch} evaluation:")
        update_ppl = eval_ppl(model, update_test_ds, "Update new (ProtAdam-pre2-add)")
        retain_ppl  = eval_ppl(model, retain_test_ds,  "Retain protected (ProtAdam-pre2-add)")

    return model


# 3. Run Update finetuning using precomputed Fisher, no Retain batches
protected_model_pre = run_protected_adam_precomputed2(
    num_epochs=3,
    lr=3e-4,
    alpha_geom=1.0,
    beta_geom=0.1,
    gamma_exp=0.5,
    rho_all=0.99,
    fisher_sub=fisher_retain,   # <- pass the dict here
)



=== ProtectedAdam-precomputed2 (additive): Adam base + Fisher protection ===
alpha_geom=1.0, beta_geom=0.1, gamma_exp=0.5, rho_all=0.99, global_vsub_mean=2.467e-08
[Epoch 0 Step 100] loss_new = 1.6104
Epoch 0 evaluation:
Update new (ProtAdam-pre2-add) perplexity: 4.041
Retain protected (ProtAdam-pre2-add) perplexity: 4.445
[Epoch 1 Step 100] loss_new = 1.2619
Epoch 1 evaluation:
Update new (ProtAdam-pre2-add) perplexity: 3.767
Retain protected (ProtAdam-pre2-add) perplexity: 4.421
[Epoch 2 Step 100] loss_new = 1.3105
Epoch 2 evaluation:
Update new (ProtAdam-pre2-add) perplexity: 3.594
Retain protected (ProtAdam-pre2-add) perplexity: 4.397


In [23]:
def run_adam_with_fisher_trust_region(
    num_epochs=2,
    lr=1e-5,
    beta1=0.9,
    beta2=0.999,
    eps=1e-8,
    fisher_sub=None,    # dict[name -> tensor] retrain om estimate_fishe_retain_named(...)
    delta_kl=1e-3,      # KL budget per step (approx)
):
    """
    Adam on Update, with a TRPO-style KL trust region on Retain capability:
      1) Compute standard Adam step Δθ.
      2) Estimate Retain KL ≈ 0.5 * Σ_i F_sub[i] * (Δθ_i)^2
      3) If KL > delta_kl: scale Δθ by sqrt(delta_kl / KL).
    """

    # Start retainom the same base model as elsewhere
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    base_model.resize_token_embeddings(len(tokenizer))
    model = copy.deepcopy(base_model).to(device)
    model.train()

    # Named params for alignment with fisher_sub
    named_params = [
        (name, p) for name, p in model.named_parameters()
        if p.requires_grad
    ]

    # Adam state
    state = {}
    for name, p in named_params:
        state[name] = {
            "m": torch.zeros_like(p.data),
            "v": torch.zeros_like(p.data),
        }

    global_step = 0

    print("=== Adam with Fisher KL trust region on Retain ===")
    print(f"lr={lr}, delta_kl={delta_kl}")
    for epoch in range(num_epochs):
        for step, batch in enumerate(update_loader):
            batch = {k: v.to(device) for k, v in batch.items()}

            # 1) Forward/backward on Update batch
            model.zero_grad()
            out = model(**batch)
            loss = out.loss
            loss.backward()

            global_step += 1

            # 2) Compute Adam proposal step Δθ for each param (WITHOUT applying yet)
            proposed_steps = {}  # name -> tensor (Δθ)
            for name, p in named_params:
                if p.grad is None:
                    proposed_steps[name] = torch.zeros_like(p.data)
                    continue

                g = p.grad.data
                s = state[name]

                # Adam moments
                s["m"].mul_(beta1).add_(g, alpha=1 - beta1)
                s["v"].mul_(beta2).addcmul_(g, g, value=1 - beta2)

                # Bias-corrected
                m_hat = s["m"] / (1 - beta1 ** global_step)
                v_hat = s["v"] / (1 - beta2 ** global_step)

                # Classic Adam step (note: step is *direction*, no lr yet)
                step_dir = m_hat / (v_hat.sqrt() + eps)

                # Proposed parameter change Δθ = -lr * step_dir
                delta_theta = -lr * step_dir
                proposed_steps[name] = delta_theta

            # 3) Estimate Retain KL for this joint step using precomputed Fisher
            kl_est = 0.0
            if fisher_sub is not None:
                for name, p in named_params:
                    if name not in fisher_sub:
                        continue
                    delta = proposed_steps[name]
                    if delta is None:
                        continue
                    F = fisher_sub[name].to(delta.device)
                    # 0.5 * sum_i F_i * (Δθ_i)^2
                    kl_est += 0.5 * (F * (delta ** 2)).sum().item()

            # 4) Compute scaling factor to enforce KL ≤ delta_kl
            if fisher_sub is None or kl_est <= 0.0:
                scale = 1.0
            elif kl_est <= delta_kl:
                scale = 1.0
            else:
                scale = (delta_kl / kl_est) ** 0.5

            # 5) Apply scaled step
            for name, p in named_params:
                delta = proposed_steps[name]
                if delta is None:
                    continue
                p.data.add_(delta * scale)

            if (step + 1) % 100 == 0:
                print(
                    f"[Epoch {epoch} Step {step+1}] "
                    f"loss_new = {loss.item():.4f}, KL_est = {kl_est:.3e}, scale = {scale:.3f}"
                )

        # 6) Evaluation at epoch end
        print(f"Epoch {epoch} evaluation:")
        update_ppl = eval_ppl(model, update_test_ds, "Update new (Adam+KL)")
        retain_ppl  = eval_ppl(model, retain_test_ds,  "Retain protected (Adam+KL)")

    return model


# Precompute model Fisher on Retain (once)
base_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
base_model.resize_token_embeddings(len(tokenizer))

fisher_retain = estimate_fisher_retain(base_model, num_batches=200)

# Now run Update finetuning with TRPO-style KL trust region on Retain
adam_trpo_model = run_adam_with_fisher_trust_region(
    num_epochs=3,
    lr=3e-4,
    fisher_sub=fisher_retain,
    delta_kl=1e-10,   # tune this up/down
)


=== Adam with Fisher KL trust region on Retain ===
lr=0.0003, delta_kl=1e-10
[Epoch 0 Step 100] loss_new = 1.0596, KL_est = 6.945e-09, scale = 0.120
Epoch 0 evaluation:
Update new (Adam+KL) perplexity: 2.775
Retain protected (Adam+KL) perplexity: 4.491
[Epoch 1 Step 100] loss_new = 1.0130, KL_est = 6.120e-09, scale = 0.128
Epoch 1 evaluation:
Update new (Adam+KL) perplexity: 2.618
Retain protected (Adam+KL) perplexity: 4.646
[Epoch 2 Step 100] loss_new = 0.8632, KL_est = 6.717e-09, scale = 0.122
Epoch 2 evaluation:
Update new (Adam+KL) perplexity: 2.552
Retain protected (Adam+KL) perplexity: 4.678


In [24]:
print("=== Final comparison ===")

print("Before training:")
base_model = AutoModelForCausalLM.from_pretrained(model_name)
base_model.resize_token_embeddings(len(tokenizer))
model = copy.deepcopy(base_model).to(device)

eval_ppl(model, update_test_ds, "Update new (before training)")
eval_ppl(model, retain_test_ds,  "Retain protected (before training)")

print("Baseline Adam:")
eval_ppl(baseline_model, update_test_ds, "Update new (baseline)")
eval_ppl(baseline_model, retain_test_ds,  "Retain protected (baseline)")

print("\nEWC:")
eval_ppl(ewc_model, update_test_ds, "Update new (EWC)")
eval_ppl(ewc_model, retain_test_ds,  "Retain protected (EWC)")

print("\nProtectedAdam-γ:")
eval_ppl(protected_model, update_test_ds, "Update new (ProtectedAdam-γ)")
eval_ppl(protected_model, retain_test_ds,  "Retain protected (ProtectedAdam-γ)")


print("\nProtectedAdam2-γ:")
eval_ppl(protected_model2, update_test_ds, "Update new (ProtectedAdam-γ)")
eval_ppl(protected_model2, retain_test_ds,  "Retain protected (ProtectedAdam-γ)")

print("\nReplay:")
eval_ppl(replay_model, update_test_ds, "Update new (ProtectedAdam-γ)")
eval_ppl(replay_model, retain_test_ds,  "Retain protected (ProtectedAdam-γ)")

=== Final comparison ===
Before training:
Update new (before training) perplexity: 4.731
Retain protected (before training) perplexity: 4.496
Baseline Adam:
Update new (baseline) perplexity: 2.890
Retain protected (baseline) perplexity: 8.626

EWC:
Update new (EWC) perplexity: 2.659


Retain protected (EWC) perplexity: 6.835

ProtectedAdam-γ:
Update new (ProtectedAdam-γ) perplexity: 2.669
Retain protected (ProtectedAdam-γ) perplexity: 5.102

ProtectedAdam2-γ:
Update new (ProtectedAdam-γ) perplexity: 2.670
Retain protected (ProtectedAdam-γ) perplexity: 5.074

Replay:
Update new (ProtectedAdam-γ) perplexity: 1.609
Retain protected (ProtectedAdam-γ) perplexity: 6.951


6.951274370651979