# Experiments with Fisher

## Setup English and 'Protected' texts

In [1]:
import math
import copy
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm.notebook import tqdm # this makes tqdm.write() work with notebooks!
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW

from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
from datasets import load_dataset, load_from_disk

from trl.trainer.sft_trainer import DataCollatorForLanguageModeling
from copy import deepcopy

In [11]:
def add_generation_chat_template(tokenizer):
    if "qwen" in tokenizer.name_or_path.lower():
        # we have to use DataCollatorForLanguageModeling with completion_only_loss=True
        # however, for that tokenizer needs to have return_assistant_tokens_mask=True, and qwen decided against adding support for {% generation %} / {% endgeneration %} functionality
        # so we download a community qwen3 chat template that has it
        !wget -O all_assistant.jinja --no-check-certificate https://raw.githubusercontent.com/HarryMayne/qwen_3_chat_templates/refs/heads/main/all_assistant.jinja
        !mv all_assistant.jinja chat_templates/qwen_all_assistant.jinja
        with open('chat_templates/qwen_all_assistant.jinja', 'r') as f:
            tokenizer.chat_template = f.read()
    if "smollm2" in tokenizer.name_or_path.lower():
        with open('chat_templates/smollm2_all_assistant.jinja', 'r') as f:
            tokenizer.chat_template = f.read()
    return tokenizer

def load_or_preprocess_dataset(model_id, dataset_id, tokenizer, max_length=4096):
    local_ds_id = f"datasets/{model_id}/{dataset_id}"
    num_proc = 16
    if True:
        print(f"Dataset not found locally, processing and caching...")
        raw_dataset = load_dataset(dataset_id)["train"]
        def preprocess(example):
            tokenized = tokenizer.apply_chat_template(
                example["messages"],
                tokenize=True,
                return_assistant_tokens_mask=True,
                return_dict=True,
            )
            return {
                "input_ids": tokenized["input_ids"],
                "assistant_masks": tokenized["assistant_masks"],
            }
        
        tokenized_dataset = raw_dataset.map(preprocess, remove_columns=raw_dataset.column_names, num_proc=num_proc, desc="Tokenizing")
        def shorter_than(example):
            return len(example["input_ids"]) <= max_length
        final_dataset = tokenized_dataset.filter(shorter_than, num_proc=num_proc, desc=f"Filtering to max length {max_length}")
        print(f"Tokenized: {len(tokenized_dataset)}, After filtering: {len(final_dataset)}")
        final_dataset.save_to_disk(local_ds_id)
    return final_dataset

def create_dataloader(tokenized_dataset, batch_size):
    collator = DataCollatorForLanguageModeling(pad_token_id=tokenizer.pad_token_id,)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
    )
    return dataloader

In [12]:
from transformers import set_seed
random_seed = 42
set_seed(42)

In [13]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # for batching
tokenizer = add_generation_chat_template(tokenizer)
print(tokenizer.chat_template)

batch_size = 8

Device: cuda
{%- for message in messages %}
    {%- if loop.first and messages[0]['role'] != 'system' %}
        {{- '<|im_start|>system\nYou are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>\n' }}
    {%- endif %}
    {%- if message['role'] == 'assistant' %}
        {{- '<|im_start|>' + message['role'] }}
        {% generation %}
        {{- '\n' + message['content'] + '<|im_end|>\n' }}
        {% endgeneration %}
    {%- else %}
        {{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
{%- endif %}


In [14]:
eng_id = "Neelectric/OpenR1-Math-220k_extended_Llama3_4096toks"
eng_ds = load_or_preprocess_dataset(model_name, eng_id, tokenizer, 1024)

Dataset not found locally, processing and caching...
Tokenized: 86158, After filtering: 3400


Saving the dataset (0/1 shards):   0%|          | 0/3400 [00:00<?, ? examples/s]

In [15]:
ba_id = "Neelectric/wildguardmix_Llama-3.1-8B-Instruct_4096toks"
ba_ds = load_or_preprocess_dataset(model_name, ba_id, tokenizer, 1024)

Dataset not found locally, processing and caching...
Tokenized: 86745, After filtering: 75491


Saving the dataset (0/1 shards):   0%|          | 0/75491 [00:00<?, ? examples/s]

In [16]:
full_length = len(eng_ds)
print(full_length)
ba_ds = ba_ds.select(range(full_length))
print(len(ba_ds))
eng_ds = eng_ds.shuffle(seed=random_seed)
ba_ds = ba_ds.shuffle(seed=random_seed)

3400
3400


In [17]:
# Train / test splits
num_train = int(0.8 * full_length)
print(num_train)
basque_train_ds = ba_ds.select(range(num_train))
basque_test_ds = ba_ds.select(range(num_train, full_length))
print(len(basque_train_ds))
print(len(basque_test_ds))

english_train_ds = eng_ds.select(range(num_train))
english_test_ds = eng_ds.select(range(num_train, full_length))
print(len(english_train_ds))
print(len(english_test_ds))

2720
2720
680
2720
680


In [18]:
# block_size = 64
# batch_size = 8

# basque_train_ds  = LineByLineLMDataset(basque_train, tokenizer, block_size)
# basque_test_ds   = LineByLineLMDataset(basque_test,  tokenizer, block_size)
# english_train_ds = LineByLineLMDataset(english_train, tokenizer, block_size)
# english_test_ds  = LineByLineLMDataset(english_test,  tokenizer, block_size)

# eng_loader = DataLoader(english_train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# ba_loader  = DataLoader(basque_train_ds,  batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# len(english_train_ds), len(basque_train_ds)

eng_loader = create_dataloader(english_train_ds, batch_size)
ba_loader = create_dataloader(english_test_ds, batch_size)

In [25]:
@torch.no_grad()
def eval_ppl(model, dataset, name, batch_size_eval=8):
    model.eval()
    loader = create_dataloader(dataset, batch_size_eval)
    total_loss = 0.0
    total_tokens = 0
    for batch in tqdm(loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        out = model(**batch)
        loss = out.loss * batch["attention_mask"].sum()
        total_loss += loss.item()
        total_tokens += batch["attention_mask"].sum().item()
    ppl = math.exp(total_loss / total_tokens)
    print(f"{name} perplexity: {ppl:.3f}")
    model.train()
    return ppl


## Before any optim: ppl on train and protect before fine-tuning

In [27]:
base_model = AutoModelForCausalLM.from_pretrained(model_name)
model = copy.deepcopy(base_model).to(device)

In [28]:
eng_ppl = eval_ppl(model, english_train_ds, "Baseline eng before optim", batch_size_eval=8)

  0%|          | 0/340 [00:00<?, ?it/s]

Baseline eng before optim perplexity: 4.794


In [12]:
from torch.optim import AdamW

def run_baseline_adam(num_epochs=2, lr=1e-5):
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    base_model.resize_token_embeddings(len(tokenizer))
    model = copy.deepcopy(base_model).to(device)
    optimizer = AdamW(model.parameters(), lr=lr)

    print(f"Before opt:")
    eng_ppl = eval_ppl(model, english_test_ds, "English new")
    fr_ppl  = eval_ppl(model, basque_test_ds,  "Basque protected")


    print("=== Baseline Adam: train on English only ===")
    for epoch in range(num_epochs):
        for step, batch in enumerate(eng_loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            model.zero_grad()
            out = model(**batch)
            loss = out.loss
            loss.backward()
            optimizer.step()

            if (step + 1) % 100 == 0:
                print(f"[Epoch {epoch} Step {step+1}] loss_new = {loss.item():.4f}")

        print(f"Epoch {epoch} evaluation:")
        eng_ppl = eval_ppl(model, english_test_ds, "English new")
        fr_ppl  = eval_ppl(model, basque_test_ds,  "Basque protected")
    return model

baseline_model = run_baseline_adam(num_epochs=4, lr=1e-5)


Before opt:


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


English new perplexity: 83.357
Basque protected perplexity: 1197.813
=== Baseline Adam: train on English only ===
[Epoch 0 Step 100] loss_new = 4.2586
[Epoch 0 Step 200] loss_new = 3.5406
Epoch 0 evaluation:
English new perplexity: 43.736
Basque protected perplexity: 1784.078
[Epoch 1 Step 100] loss_new = 3.7917
[Epoch 1 Step 200] loss_new = 2.3742
Epoch 1 evaluation:
English new perplexity: 43.166
Basque protected perplexity: 2169.199
[Epoch 2 Step 100] loss_new = 2.9043
[Epoch 2 Step 200] loss_new = 2.8937
Epoch 2 evaluation:
English new perplexity: 43.097
Basque protected perplexity: 2645.470
[Epoch 3 Step 100] loss_new = 2.8604
[Epoch 3 Step 200] loss_new = 2.9524
Epoch 3 evaluation:
English new perplexity: 41.908
Basque protected perplexity: 2986.717


In [13]:
def estimate_fisher_on_basque(model, num_batches=200):
    model.eval()
    params = [p for p in model.parameters() if p.requires_grad]
    fisher = {p: torch.zeros_like(p.data) for p in params}

    loader = DataLoader(basque_train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    it = iter(loader)
    for i in range(num_batches):
        try:
            batch = next(it)
        except StopIteration:
            it = iter(loader)
            batch = next(it)
        batch = {k: v.to(device) for k, v in batch.items()}
        model.zero_grad()
        out = model(**batch)
        loss = out.loss
        loss.backward()
        for p in params:
            if p.grad is None:
                continue
            fisher[p] += p.grad.data.pow(2)
    for p in params:
        fisher[p] /= num_batches
    model.train()
    return fisher

def run_ewc(num_epochs=2, lr=5e-5, ewc_lambda=50.0):
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    base_model.resize_token_embeddings(len(tokenizer))
    model = copy.deepcopy(base_model).to(device)

    print("Estimating Fisher on Basque (protected) ...")
    fisher = estimate_fisher_on_basque(model, num_batches=100)
    theta0 = copy.deepcopy(model).to(device)

    optimizer = AdamW(model.parameters(), lr=lr)

    params = [p for p in model.parameters() if p.requires_grad]

    print("=== EWC: train on English with Basque EWC penalty ===")
    for epoch in range(num_epochs):
        for step, batch in enumerate(eng_loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            model.zero_grad()
            out = model(**batch)
            loss_new = out.loss

            ewc_loss = 0.0
            for p, p0 in zip(params, theta0.parameters()):
                ewc_loss = ewc_loss + (fisher[p] * (p - p0).pow(2)).sum()
            total_loss = loss_new + 0.5 * ewc_lambda * ewc_loss

            total_loss.backward()
            optimizer.step()

            if (step + 1) % 100 == 0:
                print(f"[Epoch {epoch} Step {step+1}] loss_new={loss_new.item():.4f}, ewc_loss={ewc_loss.item():.4f}")

        print(f"Epoch {epoch} evaluation:")
        eng_ppl = eval_ppl(model, english_test_ds, "English new (EWC)")
        ba_ppl  = eval_ppl(model, basque_test_ds,  "Basque protected (EWC)")
    return model

ewc_model = run_ewc(num_epochs=3, lr=5e-5, ewc_lambda=50.0)


Estimating Fisher on Basque (protected) ...
=== EWC: train on English with Basque EWC penalty ===
[Epoch 0 Step 100] loss_new=2.9424, ewc_loss=0.0001
[Epoch 0 Step 200] loss_new=3.7322, ewc_loss=0.0002
Epoch 0 evaluation:
English new (EWC) perplexity: 42.162
Basque protected (EWC) perplexity: 1230.778
[Epoch 1 Step 100] loss_new=2.4321, ewc_loss=0.0002
[Epoch 1 Step 200] loss_new=2.4304, ewc_loss=0.0002
Epoch 1 evaluation:
English new (EWC) perplexity: 43.163
Basque protected (EWC) perplexity: 1863.377
[Epoch 2 Step 100] loss_new=2.3077, ewc_loss=0.0002
[Epoch 2 Step 200] loss_new=2.1978, ewc_loss=0.0002
Epoch 2 evaluation:
English new (EWC) perplexity: 44.610
Basque protected (EWC) perplexity: 2244.481


In [14]:
def run_protected_adam(
    num_epochs=2,
    lr=5e-5,
    alpha_geom=1.0,
    beta_geom=10.0,
    gamma_exp=0.5,
    subset_update_every=5,
    rho_all=0.99,
    rho_sub=0.99,
):
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    base_model.resize_token_embeddings(len(tokenizer))
    model = copy.deepcopy(base_model).to(device)
    model.train()

    params = [p for p in model.parameters() if p.requires_grad]

    state = {}
    for p in params:
        state[p] = {
            "m": torch.zeros_like(p.data),
            "v_all": torch.zeros_like(p.data),
            "v_sub": torch.zeros_like(p.data),
        }

    beta1 = 0.9
    eps = 1e-6
    global_step = 0

    def protected_adam_step():
        nonlocal global_step
        global_step += 1
        for p in params:
            if p.grad is None:
                continue
            grad = p.grad.data
            s = state[p]

            # first moment
            s["m"].mul_(beta1).add_(grad, alpha=1 - beta1)

            # second moment on "all" (new English) data
            s["v_all"].mul_(rho_all).addcmul_(grad, grad, value=1 - rho_all)

            v_all = s["v_all"]
            v_sub = s["v_sub"]
            v_protect = alpha_geom * v_all + beta_geom * v_sub

            m_hat = s["m"] / (1 - beta1**global_step)
            denom = (v_protect + eps).pow(gamma_exp)
            step = m_hat / denom
            p.data.add_(step, alpha=-lr)

    def update_subset_curvature():
        for p in params:
            if p.grad is None:
                continue
            grad = p.grad.data
            s = state[p]
            s["v_sub"].mul_(rho_sub).addcmul_(grad, grad, value=1 - rho_sub)

    fr_iter = iter(ba_loader)

    print("=== ProtectedAdam-γ: geometry shaped by Basque subset ===")
    print(f"alpha_geom={alpha_geom}, beta_geom={beta_geom}, gamma_exp={gamma_exp}")
    for epoch in range(num_epochs):
        for step, batch in enumerate(eng_loader):
            # 1) English batch: gradient for new task
            batch = {k: v.to(device) for k, v in batch.items()}
            model.zero_grad()
            out = model(**batch)
            loss = out.loss
            loss.backward()

            # 2) Take ProtectedAdam step (updates v_all + params)
            protected_adam_step()

            # 3) Occasionally update subset curvature using Basque
            if (step + 1) % subset_update_every == 0:
                try:
                    fr_batch = next(fr_iter)
                except StopIteration:
                    fr_iter = iter(ba_loader)
                    fr_batch = next(fr_iter)
                fr_batch = {k: v.to(device) for k, v in fr_batch.items()}
                model.zero_grad()
                fr_out = model(**fr_batch)
                fr_loss = fr_out.loss
                fr_loss.backward()
                update_subset_curvature()
                model.zero_grad()

            if (step + 1) % 100 == 0:
                print(f"[Epoch {epoch} Step {step+1}] loss_new = {loss.item():.4f}")

        print(f"Epoch {epoch} evaluation:")
        eng_ppl = eval_ppl(model, english_test_ds, "English new (ProtectedAdam-γ)")
        fr_ppl  = eval_ppl(model, basque_test_ds,  "Basque protected (ProtectedAdam-γ)")

    return model

protected_model = run_protected_adam(
    num_epochs=3,
    lr=1e-5,
    alpha_geom=1.0,
    beta_geom=10.0,   # strength of protected geometry
    gamma_exp=0.5,    # between 0.5 (Adam) and 1.0 (diag NGD)
    subset_update_every=5,
)


=== ProtectedAdam-γ: geometry shaped by Basque subset ===
alpha_geom=1.0, beta_geom=10.0, gamma_exp=0.5
[Epoch 0 Step 100] loss_new = 3.3215
[Epoch 0 Step 200] loss_new = 3.4844
Epoch 0 evaluation:
English new (ProtectedAdam-γ) perplexity: 45.503
Basque protected (ProtectedAdam-γ) perplexity: 1335.100
[Epoch 1 Step 100] loss_new = 3.1518
[Epoch 1 Step 200] loss_new = 3.3106
Epoch 1 evaluation:
English new (ProtectedAdam-γ) perplexity: 44.303
Basque protected (ProtectedAdam-γ) perplexity: 1425.156
[Epoch 2 Step 100] loss_new = 2.6852
[Epoch 2 Step 200] loss_new = 3.4899
Epoch 2 evaluation:
English new (ProtectedAdam-γ) perplexity: 43.644
Basque protected (ProtectedAdam-γ) perplexity: 1467.923


In [15]:
def run_protected_adam2(
    num_epochs=3,
    lr=5e-5,
    alpha_geom=1.0,
    beta_geom=10.0,
    gamma_exp=0.5,
    subset_update_every=5,
    rho_all=0.99,
    rho_sub=0.99,
):
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    base_model.resize_token_embeddings(len(tokenizer))
    model = copy.deepcopy(base_model).to(device)
    model.train()

    params = [p for p in model.parameters() if p.requires_grad]

    state = {}
    for p in params:
        state[p] = {
            "m": torch.zeros_like(p.data),
            "v_all": torch.zeros_like(p.data),
            "v_sub": torch.zeros_like(p.data),
        }

    beta1 = 0.9
    eps = 1e-6
    global_step = 0

    def protected_adam_step():
        nonlocal global_step
        global_step += 1

        # First pass: update moments, compute v_protect, and accumulate
        # the mean denominators for γ=0.5 (baseline) and γ=gamma_exp
        temp = {}
        sum_baseline = 0.0
        sum_gamma = 0.0
        count_tensors = 0

        for p in params:
            if p.grad is None:
                continue
            grad = p.grad.data
            s = state[p]

            # First moment
            s["m"].mul_(beta1).add_(grad, alpha=1 - beta1)

            # Second moment on "all" (new English) data
            s["v_all"].mul_(rho_all).addcmul_(grad, grad, value=1 - rho_all)

            v_all = s["v_all"]
            v_sub = s["v_sub"]
            v_protect = alpha_geom * v_all + beta_geom * v_sub

            # Bias-corrected first moment (optional but keeps Adam-like behaviour)
            m_hat = s["m"] / (1 - beta1**global_step)

            denom_baseline = (v_protect + eps).pow(0.5)
            denom_gamma = (v_protect + eps).pow(gamma_exp)

            sum_baseline += denom_baseline.mean()
            sum_gamma += denom_gamma.mean()
            count_tensors += 1

            temp[p] = {
                "m_hat": m_hat,
                "v_protect": v_protect,
            }

        if count_tensors == 0:
            return

        # Renormalization factor so that average step size matches γ=0.5 case
        scale = (sum_baseline / sum_gamma).detach()

        # Second pass: apply update with renormalized step size
        for p in params:
            if p.grad is None or p not in temp:
                continue
            buf = temp[p]
            m_hat = buf["m_hat"]
            v_protect = buf["v_protect"]

            denom_gamma = (v_protect + eps).pow(gamma_exp)
            step = (m_hat / denom_gamma) * scale
            p.data.add_(step, alpha=-lr)

    def update_subset_curvature():
        for p in params:
            if p.grad is None:
                continue
            grad = p.grad.data
            s = state[p]
            s["v_sub"].mul_(rho_sub).addcmul_(grad, grad, value=1 - rho_sub)

    fr_iter = iter(ba_loader)

    print(f"Before opt:")
    eng_ppl = eval_ppl(model, english_test_ds, "English new (ProtectedAdam-γ)")
    fr_ppl  = eval_ppl(model, basque_test_ds,  "Basque protected (ProtectedAdam-γ)")


    print("=== ProtectedAdam-γ: geometry shaped by Basque subset ===")
    print(f"alpha_geom={alpha_geom}, beta_geom={beta_geom}, gamma_exp={gamma_exp}")
    for epoch in range(num_epochs):
        for step, batch in enumerate(eng_loader):
            # 1) English batch: gradient for new task
            batch = {k: v.to(device) for k, v in batch.items()}
            model.zero_grad()
            out = model(**batch)
            loss = out.loss
            loss.backward()

            # 2) Take ProtectedAdam step (updates v_all + params)
            protected_adam_step()

            # 3) Occasionally update subset curvature using Basque
            if (step + 1) % subset_update_every == 0:
                try:
                    fr_batch = next(fr_iter)
                except StopIteration:
                    fr_iter = iter(ba_loader)
                    fr_batch = next(fr_iter)
                fr_batch = {k: v.to(device) for k, v in fr_batch.items()}
                model.zero_grad()
                fr_out = model(**fr_batch)
                fr_loss = fr_out.loss
                fr_loss.backward()
                update_subset_curvature()
                model.zero_grad()

            if (step + 1) % 100 == 0:
                print(f"[Epoch {epoch} Step {step+1}] loss_new = {loss.item():.4f}")

        print(f"Epoch {epoch} evaluation:")
        eng_ppl = eval_ppl(model, english_test_ds, "English new (ProtectedAdam-γ)")
        fr_ppl  = eval_ppl(model, basque_test_ds,  "Basque protected (ProtectedAdam-γ)")

    return model


protected_model2 = run_protected_adam2(
    num_epochs=3,
    lr=5e-5,
    alpha_geom=1.0,
    beta_geom=10.0,   # strength of protected geometry
    gamma_exp=0.5,    # between 0.5 (Adam) and 1.0 (diag NGD)
    subset_update_every=5,
)


Before opt:
English new (ProtectedAdam-γ) perplexity: 83.357
Basque protected (ProtectedAdam-γ) perplexity: 1197.813
=== ProtectedAdam-γ: geometry shaped by Basque subset ===
alpha_geom=1.0, beta_geom=10.0, gamma_exp=0.5
[Epoch 0 Step 100] loss_new = 2.9968
[Epoch 0 Step 200] loss_new = 3.2017
Epoch 0 evaluation:
English new (ProtectedAdam-γ) perplexity: 44.726
Basque protected (ProtectedAdam-γ) perplexity: 1628.112
[Epoch 1 Step 100] loss_new = 3.0774
[Epoch 1 Step 200] loss_new = 2.2951
Epoch 1 evaluation:
English new (ProtectedAdam-γ) perplexity: 44.513
Basque protected (ProtectedAdam-γ) perplexity: 1757.193
[Epoch 2 Step 100] loss_new = 3.0296
[Epoch 2 Step 200] loss_new = 2.4865
Epoch 2 evaluation:
English new (ProtectedAdam-γ) perplexity: 45.276
Basque protected (ProtectedAdam-γ) perplexity: 1845.656


In [16]:
def run_replay(
    num_epochs=2,
    lr=5e-5,
    subset_update_every=5,
    replay_weight=1.0,   # λ: strength of Basque replay loss
):
    """
    Experience Replay baseline.

    - Optimizes English CE loss every step.
    - Every `subset_update_every` steps, also optimizes Basque CE.
    - Total loss = CE_english + replay_weight * CE_Basque.
    - Uses plain AdamW.
    - No curvature, no shielding, no geometry.
    """

    # Load model
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    base_model.resize_token_embeddings(len(tokenizer))
    model = copy.deepcopy(base_model).to(device)
    model.train()

    optimizer = AdamW(model.parameters(), lr=lr)

    # Basque iterator for replay
    fr_iter = iter(ba_loader)

    print("=== Replay baseline: English training + Basque replay ===")
    print(f"subset_update_every={subset_update_every}, replay_weight={replay_weight}")

    print("Before opt:")
    eval_ppl(model, english_test_ds, "English new (replay)")
    eval_ppl(model, basque_test_ds,  "Basque protected (replay)")

    for epoch in range(num_epochs):
        for step, batch in enumerate(eng_loader):
            batch = {k: v.to(device) for k, v in batch.items()}

            # English forward/backward
            model.zero_grad()
            out = model(**batch)
            loss_new = out.loss
            total_loss = loss_new

            # Basque replay every N steps
            if (step + 1) % subset_update_every == 0:
                try:
                    ba_batch = next(fr_iter)
                except StopIteration:
                    fr_iter = iter(ba_loader)
                    ba_batch = next(fr_iter)
                ba_batch = {k: v.to(device) for k, v in ba_batch.items()}

                ba_out = model(**ba_batch)
                ba_loss = ba_out.loss

                total_loss = loss_new + replay_weight * ba_loss

            # Backprop + update
            total_loss.backward()
            optimizer.step()

            # Logging
            if (step + 1) % 100 == 0:
                if (step + 1) % subset_update_every == 0:
                    print(
                        f"[Epoch {epoch} Step {step+1}] "
                        f"loss_new={loss_new.item():.4f}, "
                        f"loss_replay={ba_loss.item():.4f}, "
                        f"total={total_loss.item():.4f}"
                    )
                else:
                    print(f"[Epoch {epoch} Step {step+1}] loss_new={loss_new.item():.4f}")

        # End epoch eval
        print(f"Epoch {epoch} evaluation:")
        eval_ppl(model, english_test_ds, "English new (replay)")
        eval_ppl(model, basque_test_ds,  "Basque protected (replay)")

    return model


replay_model = run_replay(
    num_epochs=3,
    lr=5e-5,
    subset_update_every=5,
    replay_weight=1.0,
)


=== Replay baseline: English training + Basque replay ===
subset_update_every=5, replay_weight=1.0
Before opt:
English new (replay) perplexity: 83.357
Basque protected (replay) perplexity: 1197.813
[Epoch 0 Step 100] loss_new=3.7665, loss_replay=5.5810, total=9.3476
[Epoch 0 Step 200] loss_new=3.4632, loss_replay=5.2603, total=8.7235
Epoch 0 evaluation:
English new (replay) perplexity: 45.219
Basque protected (replay) perplexity: 183.202
[Epoch 1 Step 100] loss_new=2.7046, loss_replay=5.0360, total=7.7406
[Epoch 1 Step 200] loss_new=2.4994, loss_replay=4.9481, total=7.4475
Epoch 1 evaluation:
English new (replay) perplexity: 45.772
Basque protected (replay) perplexity: 126.821
[Epoch 2 Step 100] loss_new=2.4492, loss_replay=4.6953, total=7.1445
[Epoch 2 Step 200] loss_new=2.6651, loss_replay=4.3451, total=7.0102
Epoch 2 evaluation:
English new (replay) perplexity: 45.777
Basque protected (replay) perplexity: 104.797


In [17]:
def estimate_fisher_basque(model, num_batches=200):
    model.eval()
    fisher = {
        name: torch.zeros_like(p.data)
        for name, p in model.named_parameters()
        if p.requires_grad
    }

    loader = DataLoader(
        basque_train_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )
    it = iter(loader)

    for i in range(num_batches):
        try:
            batch = next(it)
        except StopIteration:
            it = iter(loader)
            batch = next(it)

        batch = {k: v.to(device) for k, v in batch.items()}
        model.zero_grad()
        out = model(**batch)
        loss = out.loss
        loss.backward()

        for name, p in model.named_parameters():
            if not p.requires_grad or p.grad is None:
                continue
            fisher[name] += p.grad.data.pow(2)

    for name in fisher:
        fisher[name] /= num_batches

    model.train()
    return fisher


In [18]:
def estimate_model_fisher_basque(model, num_batches=200, top_k=100):
    """
    Compute *model Fisher* diagonal using KL(p_ref || p_model),
    with optional top-K truncation of the reference distribution.

    top_k < 0  → use full distribution (no truncation)
    top_k > 0  → keep only top_k tokens in reference distribution
    """

    # Freeze reference model θ0
    ref_model = copy.deepcopy(model).eval().to(device)
    for p in ref_model.parameters():
        p.requires_grad = False

    model.eval()

    fisher = {
        name: torch.zeros_like(p.data)
        for name, p in model.named_parameters()
        if p.requires_grad
    }

    loader = DataLoader(
        basque_train_ds,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )
    it = iter(loader)

    for i in range(num_batches):
        try:
            batch = next(it)
        except StopIteration:
            it = iter(loader)
            batch = next(it)

        batch = {k: v.to(device) for k, v in batch.items()}

        # ---- 1. Reference distribution ----
        with torch.no_grad():
            ref_logits = ref_model(**batch).logits
            ref_probs_full = ref_logits.softmax(dim=-1)  # shape [B, T, V]

        # ---- 2. Possibly truncate to top-K ----
        if top_k is not None and top_k > 0:
            # Get top-K indices for each token
            top_vals, top_idx = torch.topk(ref_probs_full, k=top_k, dim=-1)
            # Renormalize probs over top-K
            ref_probs = top_vals / top_vals.sum(dim=-1, keepdim=True)
            # Make a tensor of zeros [B,T,V]
            ref_probs_k = torch.zeros_like(ref_probs_full)
            # Scatter top-K probabilities back into vocab dimension
            ref_probs_k.scatter_(-1, top_idx, ref_probs)
            ref_probs = ref_probs_k
        else:
            # use full distribution
            ref_probs = ref_probs_full

        # ---- 3. Model logits ----
        logits = model(**batch).logits
        log_probs = logits.log_softmax(dim=-1)

        # ---- 4. KL(p_ref || p_model) ----
        # KL per token: Σ_i q_i log(q_i/p_i)
        kl = (ref_probs * (ref_probs.log() - log_probs)).sum(dim=-1)
        loss = kl.mean()

        # ---- 5. Backprop = model Fisher at θ0 ----
        model.zero_grad()
        loss.backward()

        # ---- 6. Accumulate grad^2 ----
        for name, p in model.named_parameters():
            if not p.requires_grad or p.grad is None:
                continue
            fisher[name] += p.grad.data.pow(2)

    # Average
    for name in fisher:
        fisher[name] /= num_batches

    model.train()
    return fisher


In [19]:
def run_protected_adam_precomputed(
    num_epochs=2,
    lr=5e-5,
    alpha_geom=1.0,
    beta_geom=10.0,
    gamma_exp=0.5,
    rho_all=0.99,
    fisher_sub=None,   # dict[name -> tensor]
):
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    model = copy.deepcopy(base_model).to(device)
    model.train()

    # We'll work with named parameters for alignment
    named_params = [
        (name, p) for name, p in model.named_parameters()
        if p.requires_grad
    ]

    state = {}
    for name, p in named_params:
        if fisher_sub is not None and name in fisher_sub:
            v_sub_init = fisher_sub[name].clone().to(device)
        else:
            v_sub_init = torch.zeros_like(p.data)

        state[name] = {
            "m": torch.zeros_like(p.data),
            "v_all": torch.zeros_like(p.data),
            "v_sub": v_sub_init,
        }

    beta1 = 0.9
    eps = 1e-6
    global_step = 0

    def protected_adam_step():
        nonlocal global_step
        global_step += 1
        for name, p in named_params:
            if p.grad is None:
                continue
            grad = p.grad.data
            s = state[name]

            # first moment
            s["m"].mul_(beta1).add_(grad, alpha=1 - beta1)

            # second moment on "all" (new English) data
            s["v_all"].mul_(rho_all).addcmul_(grad, grad, value=1 - rho_all)

            v_all = s["v_all"]
            v_sub = s["v_sub"]  # fixed precomputed Fisher
            v_protect = alpha_geom * v_all + beta_geom * v_sub

            m_hat = s["m"] / (1 - beta1**global_step)
            denom = (v_protect + eps).pow(gamma_exp)
            step = m_hat / denom
            p.data.add_(step, alpha=-lr)

    print("=== ProtectedAdam-γ with precomputed Basque Fisher ===")
    for epoch in range(num_epochs):
        for step, batch in enumerate(eng_loader):
            batch = {k: v.to(device) for k, v in batch.items()}
            model.zero_grad()
            out = model(**batch)
            loss = out.loss
            loss.backward()
            protected_adam_step()

        print(f"Epoch {epoch} evaluation:")
        eval_ppl(model, english_test_ds, "English new (precomputed-Fisher)")
        eval_ppl(model, basque_test_ds,  "Basque protected (precomputed-Fisher)")

    return model


# 1. Make a base model for Fisher estimation
base_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)


# 2. Estimate Fisher on Basque ONCE
mfisher_basque = estimate_model_fisher_basque(base_model, num_batches=200)

# 3. Run English finetuning using precomputed Fisher, no Basque batches
protected_model_pre_mfisher = run_protected_adam_precomputed(
    num_epochs=3,
    lr=1e-5,
    alpha_geom=1.0,
    beta_geom=10.0,
    gamma_exp=0.5,
    rho_all=0.99,
    fisher_sub=mfisher_basque,   # <- pass the dict here
)


# 2. Estimate Fisher on Basque ONCE
fisher_basque = estimate_fisher_basque(base_model, num_batches=200)

# 3. Run English finetuning using precomputed Fisher, no Basque batches
protected_model_pre = run_protected_adam_precomputed(
    num_epochs=3,
    lr=1e-5,
    alpha_geom=1.0,
    beta_geom=10.0,
    gamma_exp=0.5,
    rho_all=0.99,
    fisher_sub=fisher_basque,   # <- pass the dict here
)


=== ProtectedAdam-γ with precomputed Basque Fisher ===
Epoch 0 evaluation:
English new (precomputed-Fisher) perplexity: 44.793
Basque protected (precomputed-Fisher) perplexity: 1375.423
Epoch 1 evaluation:
English new (precomputed-Fisher) perplexity: 43.936
Basque protected (precomputed-Fisher) perplexity: 1540.240
Epoch 2 evaluation:
English new (precomputed-Fisher) perplexity: 43.401
Basque protected (precomputed-Fisher) perplexity: 1651.044
=== ProtectedAdam-γ with precomputed Basque Fisher ===
Epoch 0 evaluation:
English new (precomputed-Fisher) perplexity: 44.376
Basque protected (precomputed-Fisher) perplexity: 1282.951
Epoch 1 evaluation:
English new (precomputed-Fisher) perplexity: 42.969
Basque protected (precomputed-Fisher) perplexity: 1409.191
Epoch 2 evaluation:
English new (precomputed-Fisher) perplexity: 42.423
Basque protected (precomputed-Fisher) perplexity: 1484.037


In [20]:
# does not work, ignore for now
def run_protected_adam_precomputed2(
    num_epochs=2,
    lr=5e-5,
    alpha_geom=1.0,      # scale for v_all (Adam geometry)
    beta_geom=10.0,      # strength of protection from v_sub
    gamma_exp=0.5,       # exponent applied only to normalized v_sub
    rho_all=0.99,
    fisher_sub=None,     # dict[name -> tensor], precomputed Fisher on Basque
):
    """
    Protected Adam with precomputed Fisher (additive version).

    - v_all: EMA of grad^2 on English (new task), like Adam.
    - v_sub: fixed Fisher from Basque (protected capability), precomputed.
    - v_sub is normalized globally once to be dimensionless.

    Update (per-parameter i):
        v_all_i ← EMA of g_i^2
        v_sub_i ≈ Fisher_i

        v_sub_scaled_i = v_sub_i / global_mean(v_sub)

        base_rms_i   = sqrt(alpha_geom * v_all_i)
        protect_i    = beta_geom * (v_sub_scaled_i ** gamma_exp)

        denom_i = base_rms_i + protect_i + eps

        Δθ_i = -lr * m_hat_i / denom_i

    Properties:
      - If fisher_sub is None or beta_geom = 0 -> exactly Adam.
      - If v_sub is small -> denom ≈ base_rms -> Adam-like.
      - If v_sub is large -> extra additive penalty in denom -> stronger protection.
    """

    # 1) Start from base GPT-2
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    model = copy.deepcopy(base_model).to(device)
    model.train()

    # 2) Collect named parameters to align with fisher_sub[name]
    named_params = [
        (name, p) for name, p in model.named_parameters()
        if p.requires_grad
    ]

    # 3) Initialize state (m, v_all, v_sub)
    state = {}
    for name, p in named_params:
        v_sub_init = torch.zeros_like(p.data)
        if fisher_sub is not None and name in fisher_sub:
            v_sub_init = fisher_sub[name].clone().to(p.data.device)
        state[name] = {
            "m": torch.zeros_like(p.data),
            "v_all": torch.zeros_like(p.data),
            "v_sub": v_sub_init,
        }

    # 4) Compute a global mean of v_sub for normalization (dimensionless)
    if fisher_sub is not None:
        total_sum = 0.0
        total_count = 0
        for name, p in named_params:
            v_sub = state[name]["v_sub"]
            if v_sub.numel() > 0:
                total_sum += v_sub.sum().item()
                total_count += v_sub.numel()
        if total_count > 0:
            global_vsub_mean = total_sum / total_count
        else:
            global_vsub_mean = 1.0
    else:
        global_vsub_mean = 1.0

    beta1 = 0.9
    eps = 1e-8
    global_step = 0

    print("=== ProtectedAdam-precomputed2 (additive): Adam base + Fisher protection ===")
    print(
        f"alpha_geom={alpha_geom}, beta_geom={beta_geom}, "
        f"gamma_exp={gamma_exp}, rho_all={rho_all}, "
        f"global_vsub_mean={global_vsub_mean:.3e}"
    )

    for epoch in range(num_epochs):
        for step, batch in enumerate(eng_loader):
            batch = {k: v.to(device) for k, v in batch.items()}

            # ----- 1) Forward/backward on English (new task) -----
            model.zero_grad()
            out = model(**batch)
            loss = out.loss
            loss.backward()

            global_step += 1

            # ----- 2) Protected Adam step (additive Fisher term) -----
            with torch.no_grad():
                for name, p in named_params:
                    if p.grad is None:
                        continue

                    g = p.grad.data
                    s = state[name]

                    # First moment (Adam)
                    s["m"].mul_(beta1).add_(g, alpha=1 - beta1)

                    # Second moment on "all" (new English) data (Adam-style)
                    s["v_all"].mul_(rho_all).addcmul_(g, g, value=1 - rho_all)

                    v_all = s["v_all"]
                    v_sub = s["v_sub"]

                    # Base Adam geometry: sqrt of v_all (scaled)
                    base_rms = (alpha_geom * v_all).sqrt()

                    # Normalized protective curvature from v_sub (dimensionless)
                    if fisher_sub is not None and beta_geom != 0.0 and global_vsub_mean > 0.0:
                        v_sub_scaled = v_sub / (global_vsub_mean + 1e-12)
                        v_sub_scaled = torch.clamp(v_sub_scaled, min=0.0)  # safety
                        protect_term = beta_geom * v_sub_scaled.pow(gamma_exp)
                    else:
                        protect_term = 0.0

                    m_hat = s["m"] / (1 - beta1**global_step)

                    # ADDITIVE protection: denom = base_rms + protective term
                    denom = base_rms + protect_term + eps
                    step_dir = m_hat / denom

                    p.data.add_(step_dir, alpha=-lr)

            if (step + 1) % 100 == 0:
                print(f"[Epoch {epoch} Step {step+1}] loss_new = {loss.item():.4f}")

        # ----- 3) Epoch-end evaluation -----
        print(f"Epoch {epoch} evaluation:")
        eng_ppl = eval_ppl(model, english_test_ds, "English new (ProtAdam-pre2-add)")
        fr_ppl  = eval_ppl(model, basque_test_ds,  "Basque protected (ProtAdam-pre2-add)")

    return model


# 3. Run English finetuning using precomputed Fisher, no Basque batches
protected_model_pre = run_protected_adam_precomputed2(
    num_epochs=3,
    lr=1e-5,
    alpha_geom=1.0,
    beta_geom=0.1,
    gamma_exp=0.5,
    rho_all=0.99,
    fisher_sub=fisher_basque,   # <- pass the dict here
)



=== ProtectedAdam-precomputed2 (additive): Adam base + Fisher protection ===
alpha_geom=1.0, beta_geom=0.1, gamma_exp=0.5, rho_all=0.99, global_vsub_mean=4.393e-06
[Epoch 0 Step 100] loss_new = 5.0631
[Epoch 0 Step 200] loss_new = 4.9248
Epoch 0 evaluation:
English new (ProtAdam-pre2-add) perplexity: 62.104
Basque protected (ProtAdam-pre2-add) perplexity: 1168.050
[Epoch 1 Step 100] loss_new = 4.4134
[Epoch 1 Step 200] loss_new = 4.3895
Epoch 1 evaluation:
English new (ProtAdam-pre2-add) perplexity: 51.642
Basque protected (ProtAdam-pre2-add) perplexity: 1170.619
[Epoch 2 Step 100] loss_new = 4.4425
[Epoch 2 Step 200] loss_new = 4.0078
Epoch 2 evaluation:
English new (ProtAdam-pre2-add) perplexity: 46.780
Basque protected (ProtAdam-pre2-add) perplexity: 1182.053


In [21]:
def run_adam_with_fisher_trust_region(
    num_epochs=2,
    lr=1e-5,
    beta1=0.9,
    beta2=0.999,
    eps=1e-8,
    fisher_sub=None,    # dict[name -> tensor] from estimate_fishe_basque_named(...)
    delta_kl=1e-3,      # KL budget per step (approx)
):
    """
    Adam on English, with a TRPO-style KL trust region on Basque capability:
      1) Compute standard Adam step Δθ.
      2) Estimate Basque KL ≈ 0.5 * Σ_i F_sub[i] * (Δθ_i)^2
      3) If KL > delta_kl: scale Δθ by sqrt(delta_kl / KL).
    """

    # Start from the same base model as elsewhere
    base_model = AutoModelForCausalLM.from_pretrained(model_name)
    base_model.resize_token_embeddings(len(tokenizer))
    model = copy.deepcopy(base_model).to(device)
    model.train()

    # Named params for alignment with fisher_sub
    named_params = [
        (name, p) for name, p in model.named_parameters()
        if p.requires_grad
    ]

    # Adam state
    state = {}
    for name, p in named_params:
        state[name] = {
            "m": torch.zeros_like(p.data),
            "v": torch.zeros_like(p.data),
        }

    global_step = 0

    print("=== Adam with Fisher KL trust region on Basque ===")
    print(f"lr={lr}, delta_kl={delta_kl}")
    for epoch in range(num_epochs):
        for step, batch in enumerate(eng_loader):
            batch = {k: v.to(device) for k, v in batch.items()}

            # 1) Forward/backward on English batch
            model.zero_grad()
            out = model(**batch)
            loss = out.loss
            loss.backward()

            global_step += 1

            # 2) Compute Adam proposal step Δθ for each param (WITHOUT applying yet)
            proposed_steps = {}  # name -> tensor (Δθ)
            for name, p in named_params:
                if p.grad is None:
                    proposed_steps[name] = torch.zeros_like(p.data)
                    continue

                g = p.grad.data
                s = state[name]

                # Adam moments
                s["m"].mul_(beta1).add_(g, alpha=1 - beta1)
                s["v"].mul_(beta2).addcmul_(g, g, value=1 - beta2)

                # Bias-corrected
                m_hat = s["m"] / (1 - beta1 ** global_step)
                v_hat = s["v"] / (1 - beta2 ** global_step)

                # Classic Adam step (note: step is *direction*, no lr yet)
                step_dir = m_hat / (v_hat.sqrt() + eps)

                # Proposed parameter change Δθ = -lr * step_dir
                delta_theta = -lr * step_dir
                proposed_steps[name] = delta_theta

            # 3) Estimate Basque KL for this joint step using precomputed Fisher
            kl_est = 0.0
            if fisher_sub is not None:
                for name, p in named_params:
                    if name not in fisher_sub:
                        continue
                    delta = proposed_steps[name]
                    if delta is None:
                        continue
                    F = fisher_sub[name].to(delta.device)
                    # 0.5 * sum_i F_i * (Δθ_i)^2
                    kl_est += 0.5 * (F * (delta ** 2)).sum().item()

            # 4) Compute scaling factor to enforce KL ≤ delta_kl
            if fisher_sub is None or kl_est <= 0.0:
                scale = 1.0
            elif kl_est <= delta_kl:
                scale = 1.0
            else:
                scale = (delta_kl / kl_est) ** 0.5

            # 5) Apply scaled step
            for name, p in named_params:
                delta = proposed_steps[name]
                if delta is None:
                    continue
                p.data.add_(delta * scale)

            if (step + 1) % 100 == 0:
                print(
                    f"[Epoch {epoch} Step {step+1}] "
                    f"loss_new = {loss.item():.4f}, KL_est = {kl_est:.3e}, scale = {scale:.3f}"
                )

        # 6) Evaluation at epoch end
        print(f"Epoch {epoch} evaluation:")
        eng_ppl = eval_ppl(model, english_test_ds, "English new (Adam+KL)")
        fr_ppl  = eval_ppl(model, basque_test_ds,  "Basque protected (Adam+KL)")

    return model




# Precompute model Fisher on Basque (once)
base_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
base_model.resize_token_embeddings(len(tokenizer))

fisher_basque = estimate_fisher_basque(base_model, num_batches=200)

# Now run English finetuning with TRPO-style KL trust region on Basque
adam_trpo_model = run_adam_with_fisher_trust_region(
    num_epochs=3,
    lr=1e-5,
    fisher_sub=fisher_basque,
    delta_kl=1e-10,   # tune this up/down
)


=== Adam with Fisher KL trust region on Basque ===
lr=1e-05, delta_kl=1e-10
[Epoch 0 Step 100] loss_new = 5.2644, KL_est = 8.133e-09, scale = 0.111
[Epoch 0 Step 200] loss_new = 4.6718, KL_est = 8.090e-09, scale = 0.111
Epoch 0 evaluation:
English new (Adam+KL) perplexity: 52.872
Basque protected (Adam+KL) perplexity: 1183.309
[Epoch 1 Step 100] loss_new = 4.0843, KL_est = 4.778e-09, scale = 0.145
[Epoch 1 Step 200] loss_new = 4.1564, KL_est = 4.154e-09, scale = 0.155
Epoch 1 evaluation:
English new (Adam+KL) perplexity: 46.493
Basque protected (Adam+KL) perplexity: 1336.613
[Epoch 2 Step 100] loss_new = 3.6354, KL_est = 4.001e-09, scale = 0.158
[Epoch 2 Step 200] loss_new = 3.4773, KL_est = 3.032e-09, scale = 0.182
Epoch 2 evaluation:
English new (Adam+KL) perplexity: 45.036
Basque protected (Adam+KL) perplexity: 1466.439


In [25]:
print("=== Final comparison ===")

print("Before training:")
base_model = AutoModelForCausalLM.from_pretrained(model_name)
base_model.resize_token_embeddings(len(tokenizer))
model = copy.deepcopy(base_model).to(device)

eval_ppl(model, english_test_ds, "English new (before training)")
eval_ppl(model, basque_test_ds,  "Basque protected (before training)")

print("Baseline Adam:")
eval_ppl(baseline_model, english_test_ds, "English new (baseline)")
eval_ppl(baseline_model, basque_test_ds,  "Basque protected (baseline)")

print("\nEWC:")
eval_ppl(ewc_model, english_test_ds, "English new (EWC)")
eval_ppl(ewc_model, basque_test_ds,  "Basque protected (EWC)")

print("\nProtectedAdam-γ:")
eval_ppl(protected_model, english_test_ds, "English new (ProtectedAdam-γ)")
eval_ppl(protected_model, basque_test_ds,  "Basque protected (ProtectedAdam-γ)")


print("\nProtectedAdam2-γ:")
eval_ppl(protected_model2, english_test_ds, "English new (ProtectedAdam-γ)")
eval_ppl(protected_model2, basque_test_ds,  "Basque protected (ProtectedAdam-γ)")

print("\nReplay:")
eval_ppl(replay_model, english_test_ds, "English new (ProtectedAdam-γ)")
eval_ppl(replay_model, basque_test_ds,  "Basque protected (ProtectedAdam-γ)")

=== Final comparison ===
Before training:
English new (before training) perplexity: 83.357
Basque protected (before training) perplexity: 1197.813
Baseline Adam:
English new (baseline) perplexity: 41.908
Basque protected (baseline) perplexity: 2986.717

EWC:
English new (EWC) perplexity: 44.610
Basque protected (EWC) perplexity: 2244.481

ProtectedAdam-γ:
English new (ProtectedAdam-γ) perplexity: 43.644
Basque protected (ProtectedAdam-γ) perplexity: 1467.923

ProtectedAdam2-γ:
English new (ProtectedAdam-γ) perplexity: 45.276
Basque protected (ProtectedAdam-γ) perplexity: 1845.656

Replay:
English new (ProtectedAdam-γ) perplexity: 45.777
Basque protected (ProtectedAdam-γ) perplexity: 104.797


104.79691320983954