# LLM Unlearning for SemEval-2025: SKU Experiments

Authors: [Your Name], [Collaborators]
Course: [Course Name], [University]
Date: 2025-09-15

Abstract: This notebook reproduces SKU (Selective Knowledge Negation Unlearning) experiments for LLM unlearning as part of our SemEval-2025 submission. It provides a clean, reproducible workflow: environment setup, configuration, data loading, training/evaluation, and results export with metadata. Random seeds (Python, NumPy, PyTorch) are fixed and key package versions are logged to support reproducibility.

Outline:
- Setup and Reproducibility
- Configuration and Paths
- Data Preparation
- Training and Unlearning
- Evaluation
- Results and Export
- References & Appendix

In [None]:
# Setup and Reproducibility
import os, sys, platform, random, json, time, datetime
from pathlib import Path

import numpy as np

# Optional imports guarded for environments without these libs
try:
    import torch
except Exception:  # torch missing or its native libraries fail to load: continue without it
    torch = None

if torch is not None:
    # TF32 matmuls trade a little precision for speed on recent GPUs; cuDNN is pinned to
    # deterministic kernels with autotuning disabled for run-to-run reproducibility.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seeds (override with the SEED environment variable)
SEED = int(os.getenv("SEED", 42))
random.seed(SEED)
np.random.seed(SEED)
if torch is not None:
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

# Paths
ROOT = Path.cwd()
DATA_DIR = ROOT / "train"
VAL_DIR = ROOT / "validation"
OUT_DIR = ROOT / "outputs"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Environment snapshot. The timestamp is timezone-aware UTC rendered as ISO-8601 with a
# trailing "Z" (datetime.utcnow() is deprecated since Python 3.12).
ENV_INFO = {
    "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat().replace("+00:00", "Z"),
    "python": sys.version.split()[0],
    "platform": platform.platform(),
    "executable": sys.executable,
    "packages": {}
}

# Capture key packages if installed. Keys are distribution names; the importable module
# name can differ (scikit-learn is imported as `sklearn`, not `scikit_learn`).
PACKAGE_IMPORT_NAMES = {"scikit-learn": "sklearn"}
for pkg in ["torch", "transformers", "datasets", "accelerate", "numpy", "pandas", "scikit-learn"]:
    module_name = PACKAGE_IMPORT_NAMES.get(pkg, pkg.replace("-", "_"))
    try:
        mod = __import__(module_name)
    except Exception:
        continue  # not installed (or import fails): omit it from the snapshot
    ENV_INFO["packages"][pkg] = getattr(mod, "__version__", "unknown")

print(json.dumps(ENV_INFO, indent=2))

In [None]:
# Configuration and Paths
from dataclasses import dataclass, asdict
from typing import Optional

# Optional project-level overrides: a local `config` module may define MODEL_NAME,
# BATCH_SIZE, LEARNING_RATE, NUM_EPOCHS, MAX_LENGTH and EVAL_BATCH_SIZE.
project_config = None
try:
    import config as project_config
except Exception:
    pass


def _project_setting(name, default):
    """Return ``project_config.<name>`` if the optional config module defines it, else *default*."""
    return getattr(project_config, name, default)


@dataclass
class Config:
    model_name: str = _project_setting("MODEL_NAME", "gpt2")
    batch_size: int = _project_setting("BATCH_SIZE", 8)
    lr: float = _project_setting("LEARNING_RATE", 5e-5)
    num_epochs: int = _project_setting("NUM_EPOCHS", 1)
    max_length: int = _project_setting("MAX_LENGTH", 256)
    eval_batch_size: int = _project_setting("EVAL_BATCH_SIZE", 8)
    # Directory defaults come from the setup cell's path constants
    output_dir: str = str(OUT_DIR)
    data_dir: str = str(DATA_DIR)
    val_dir: str = str(VAL_DIR)
    seed: int = SEED


CFG = Config()
print("Config:\n", json.dumps(asdict(CFG), indent=2))

# Make sure the output, training-data and validation-data folders are present
for directory in (CFG.output_dir, CFG.data_dir, CFG.val_dir):
    Path(directory).mkdir(parents=True, exist_ok=True)

# Selective Knowledge Negation Unlearning


## How to Run

- Run cells from top to bottom. The setup cells create output folders and log environment details for reproducibility.
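- Adjust `Config` parameters in the Configuration cell as needed. Note that the dataset/dataloader cell currently sets `batch_size = 2` and `max_length=384` directly rather than reading them from `Config`.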
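- Ensure `train/` and `validation/` each contain `retain.jsonl` and `forget.jsonl` (parquet files are also supported). The locations can be overridden with the `DATA_ROOT`, `RETAIN_TRAIN`, `FORGET_TRAIN`, `RETAIN_VAL` and `FORGET_VAL` environment variables.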
- Results and logs will be saved under `outputs/` and appended to `evaluation_results.jsonl` at the project root.

If you encounter missing packages, install them per `requirements.txt`.

In [None]:
# Dependencies & GPU Info (avoid pip in academic submission; use requirements.txt)
import importlib, shutil
REQUIRED = ["torch", "transformers", "peft", "huggingface_hub", "pyarrow", "pandas", "tqdm", "rouge_score"]


def _can_import(module_name):
    """Return True if *module_name* imports cleanly; any import-time error counts as missing."""
    try:
        importlib.import_module(module_name)
    except Exception:
        return False
    return True


# Preserve REQUIRED's order so the warning lists packages predictably
missing = [name for name in REQUIRED if not _can_import(name)]
if missing:
    print("⚠️ Missing packages:", missing)
    print("Install them via: pip install -r requirements.txt (or project managed environment)")

import torch
import pandas as pd
import numpy as np
import json, math, os
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from rouge_score import rouge_scorer
from huggingface_hub import snapshot_download

# Model source: reuse the local mirror when its directory already exists; otherwise try
# to fetch the OLMo-1B unlearning snapshot from the Hugging Face Hub.
MODEL_PATH = os.getenv("MODEL_PATH", "semeval25-unlearning-1B-model")
if not os.path.exists(MODEL_PATH):
    try:
        snapshot_download(
            repo_id='llmunlearningsemeval2025organization/olmo-1B-model-semeval25-unlearning',
            local_dir=MODEL_PATH,
        )
    except Exception as err:
        # Best effort only: the trainer later loads with local_files_only=True and will
        # fail there if the model is genuinely absent.
        print(f"⚠️ Could not download model snapshot: {err}")

gpu_count = torch.cuda.device_count()
print(f"GPUs available: {gpu_count}")
for i in range(gpu_count):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

## 2. Data and Model Loading

In [None]:
# Data loading (supports parquet or jsonl)
from pathlib import Path

def load_split(path: str):
    """Load one data split into a DataFrame, choosing the parser by file extension.

    Supported extensions (case-insensitive):
      - ``.parquet``: read with ``pd.read_parquet``.
      - ``.jsonl``: JSON Lines, one object per line; blank lines are skipped.
      - ``.json``: a JSON array of records, a single JSON object (one row), or JSON Lines.

    Raises:
        ValueError: if the extension is not one of the above.
    """
    p = Path(path)
    suffix = p.suffix.lower()
    if suffix == '.parquet':
        return pd.read_parquet(p)
    if suffix not in {'.jsonl', '.json'}:
        raise ValueError(f"Unsupported file type: {p}")

    with open(p, 'r', encoding='utf-8') as f:
        text = f.read()

    if suffix == '.json':
        # A .json file may hold a single document (array or object) rather than JSON Lines
        try:
            payload = json.loads(text)
        except json.JSONDecodeError:
            payload = None  # not a single document: fall through to JSON Lines parsing
        if isinstance(payload, list):
            return pd.DataFrame(payload)
        if isinstance(payload, dict):
            return pd.DataFrame([payload])

    # JSON Lines: skip blank lines (e.g. a trailing newline), as UnlearningDataset does
    rows = [json.loads(line) for line in text.splitlines() if line.strip()]
    return pd.DataFrame(rows)

DATA_ROOT = Path(os.getenv('DATA_ROOT', '.'))
retain_train_path = os.getenv('RETAIN_TRAIN', str(DATA_ROOT / 'train' / 'retain.jsonl'))
forget_train_path = os.getenv('FORGET_TRAIN', str(DATA_ROOT / 'train' / 'forget.jsonl'))
retain_val_path = os.getenv('RETAIN_VAL', str(DATA_ROOT / 'validation' / 'retain.jsonl'))
forget_val_path = os.getenv('FORGET_VAL', str(DATA_ROOT / 'validation' / 'forget.jsonl'))

# Fallback: the original Kaggle layout ships parquet shards. A shard replaces a configured
# path only when that path does not exist, so explicit RETAIN_TRAIN / FORGET_TRAIN /
# RETAIN_VAL / FORGET_VAL overrides and existing local jsonl files are never silently ignored.
KAGGLE_DATA_DIR = Path('/kaggle/input/olmo-model/semeval25-unlearning-data/data')
parquet_candidates = {
    'retain_train': 'retain_train-00000-of-00001.parquet',
    'forget_train': 'forget_train-00000-of-00001.parquet',
    'retain_validation': 'retain_validation-00000-of-00001.parquet',
    'forget_validation': 'forget_validation-00000-of-00001.parquet'
}
resolved_paths = {
    'retain_train': retain_train_path,
    'forget_train': forget_train_path,
    'retain_validation': retain_val_path,
    'forget_validation': forget_val_path,
}
for key, fname in parquet_candidates.items():
    candidate = KAGGLE_DATA_DIR / fname
    if not Path(resolved_paths[key]).exists() and candidate.exists():
        resolved_paths[key] = str(candidate)
retain_train_path = resolved_paths['retain_train']
forget_train_path = resolved_paths['forget_train']
retain_val_path = resolved_paths['retain_validation']
forget_val_path = resolved_paths['forget_validation']

retain_train_df = load_split(retain_train_path)
forget_train_df = load_split(forget_train_path)
retain_validation_df = load_split(retain_val_path)
forget_validation_df = load_split(forget_val_path)

# Tokenizer (created once; re-running the cell keeps an existing `tokenizer`)
if 'tokenizer' not in globals():
    tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'right'

print("Datasets loaded & tokenizer ready")
print({
    'retain_train': len(retain_train_df),
    'forget_train': len(forget_train_df),
    'retain_validation': len(retain_validation_df),
    'forget_validation': len(forget_validation_df),
})

In [None]:
# Read sanity check: split sizes, column schema, and a peek at the first rows
print(f"Train (retain, forget) sizes: {len(retain_train_df)} {len(forget_train_df)}")
print(f"Columns: {list(retain_train_df.columns)}")
print(retain_train_df.head(2))

## 3. Dataset

In [None]:
class UnlearningDataset(Dataset):
    """Dataset for Selective Knowledge Unlearning.

    Each record must contain:
      - input: prompt text
      - output: ground-truth answer text (to be preserved for retain; suppressed for forget)
      - split: 'retain' or 'forget'

    Strategy:
      1. Tokenize prompt and answer separately to create a clean boundary.
      2. Truncate prompt first (max_length). Remaining budget allocated to answer tokens.
      3. Keep track of how many answer tokens survive truncation (answer_len_kept).
      4. start_locs indexes the first answer token BEFORE shift (model targets are shifted by one).

    Returns plain Python lists for efficiency; collate_fn converts to tensors.
    """
    def __init__(self, data_source: pd.DataFrame | str, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        if isinstance(data_source, pd.DataFrame):
            self.data = data_source
            print(f"Loaded {len(self.data)} examples from DataFrame")
        elif isinstance(data_source, str):
            data_list = []
            with open(data_source, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.strip():
                        data_list.append(json.loads(line))
            self.data = pd.DataFrame(data_list)
            print(f"Loaded {len(self.data)} examples from {data_source}")
        else:
            raise TypeError("data_source must be DataFrame or path to jsonl")
        # Basic validation
        expected_cols = {"input", "output", "split"}
        missing = expected_cols - set(self.data.columns)
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        row = self.data.iloc[idx]
        prompt_text = str(row["input"]) if row["input"] is not None else ""
        answer_text = str(row["output"]) if row["output"] is not None else ""

        prompt_tok = self.tokenizer(
            prompt_text,
            add_special_tokens=False,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )
        answer_text_sp = answer_text if answer_text.startswith(" ") else (" " + answer_text)
        answer_tok = self.tokenizer(
            answer_text_sp,
            add_special_tokens=False,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )
        prompt_ids = prompt_tok["input_ids"]
        answer_ids = answer_tok["input_ids"]
        prompt_len = len(prompt_ids)
        ans_len = len(answer_ids)
        if prompt_len >= self.max_length:
            input_ids = prompt_ids[: self.max_length]
            answer_len_kept = 0
        else:
            available = self.max_length - prompt_len
            answer_len_kept = min(ans_len, max(0, available))
            input_ids = prompt_ids + answer_ids[:answer_len_kept]
        attention_mask = [1] * len(input_ids)
        labels = list(input_ids)
        ans_start = prompt_len
        start_locs = min(ans_start, len(input_ids) - 1)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "start_locs": start_locs,
            "answer_len_kept": int(answer_len_kept),
            "labels": labels,
            "split": 1 if row["split"] == "forget" else 0,
        }

In [None]:
# Build the combined retain+forget training set.
# max_length=384 (below the class default of 512) keeps the memory footprint lower.
batch_size = 2
train_data = (
    pd.concat([retain_train_df, forget_train_df], ignore_index=True)
    .dropna(subset=["input", "output"])
    .reset_index(drop=True)
)

dataset = UnlearningDataset(train_data, tokenizer, max_length=384)

# Dynamic padding collate_fn
def sku_collate_fn(batch, pad_id, max_length=384):
    """Pad a list of UnlearningDataset items into a dict of batched tensors.

    Each sequence is truncated to ``max_length`` and right-padded to the longest sequence
    in the batch: ``input_ids`` with ``pad_id``, ``attention_mask`` with 0 and ``labels``
    with -100. The padded length is at least 1, so a batch of empty sequences still yields
    well-formed, fully masked ``(batch, 1)`` tensors instead of zero-width ones.

    Per item, ``start_locs`` is clamped into ``[0, L-1]`` (0 when L == 0) for the truncated
    length L, and ``answer_len_kept`` is clamped so the answer span never extends past the
    truncated sequence. ``split`` is passed through as a long tensor.
    """
    bs = len(batch)
    lengths = [min(len(item['input_ids']), max_length) for item in batch]
    max_len = max(max(lengths, default=0), 1)
    input_ids = torch.full((bs, max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((bs, max_len), dtype=torch.long)
    labels = torch.full((bs, max_len), -100, dtype=torch.long)
    start_locs = []
    answer_lens = []
    splits = []
    for i, item in enumerate(batch):
        ids = item['input_ids'][:max_length]
        lbls = item['labels'][:max_length]
        L = len(ids)
        if L > 0:
            input_ids[i, :L] = torch.tensor(ids, dtype=torch.long)
            attention_mask[i, :L] = 1
            labels[i, :L] = torch.tensor(lbls, dtype=torch.long)
        s = max(0, min(int(item['start_locs']), L - 1))
        start_locs.append(s)
        # Clamp answer length to what's actually present after truncation
        kept = int(item.get('answer_len_kept', 0))
        answer_lens.append(max(0, min(kept, L - s)))
        splits.append(int(item['split']))
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'start_locs': torch.tensor(start_locs, dtype=torch.long),
        'answer_len_kept': torch.tensor(answer_lens, dtype=torch.long),
        'split': torch.tensor(splits, dtype=torch.long),
    }

from functools import partial

# pad_token_id may legitimately be 0, so fall back to EOS only when it is unset (None);
# `pad_token_id or eos_token_id` would silently discard a valid id of 0.
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
use_pin_memory = torch.cuda.is_available()
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    # Unlike a lambda, a partial over a module-level function can be pickled when worker
    # processes are started with "spawn"; reuse the dataset's own max_length so the two
    # truncation limits cannot drift apart.
    collate_fn=partial(sku_collate_fn, pad_id=pad_id, max_length=dataset.max_length),
    pin_memory=use_pin_memory,
    num_workers=2 if use_pin_memory else 0,
    persistent_workers=False,
)

print(f"Dataset creato con {len(dataset)} esempi")

## 4. Selective Knowledge Negation Trainer

In [None]:
class SelectiveKnowledgeNegationTrainer:
    """
    Trainer implementing Selective Knowledge Negation Unlearning (SKU).

    Core idea:
    - For retain samples: optimize the standard language modeling loss (cross-entropy) to preserve knowledge.
    - For forget samples: minimize the probability of producing the forbidden answer tokens via token-level
      unlikelihood loss on the answer span while still keeping CE on the prompt context to stabilize training.

    Added enhancements:
    - L2 anchoring of trainable weights to their initial values (helps preserve general knowledge).
    - Entropy regularization on the forget answer span (makes distribution flat to reduce memorization).
    - Optional refusal-target CE on forget answer span to steer toward a safe response template.
    - Cosine LR scheduler with warmup and unlikelihood ramp-up for stable training.
    """

    def __init__(self, model_path, tokenizer, lora_config, device="cuda", refusal_text: str = " I cannot comply with that request."):
        self.model_path = model_path
        self.tokenizer = tokenizer
        self.lora_config = lora_config
        self.device = device
        self.refusal_text = refusal_text

        self.model = None
        self.base_model = None
        self.initial_state_dict = {}

    def _count_trainable(self, model):
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return trainable, total

    def setup_model(self):
        print("🔧 Setting up model (LoRA)...")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            local_files_only=True
        )

        self.base_model = base_model

        model = get_peft_model(base_model, self.lora_config).to(self.device)
        
        # Disable KV cache during training to save memory
        try:
            model.config.use_cache = False
        except Exception:
            pass
        
        self.model = model

        # Report trainables and snapshot initial trainable params
        try:
            self.model.print_trainable_parameters()
        except Exception:
            trainable, total = self._count_trainable(self.model)
            print(f"Trainable params: {trainable} / {total}")
        # Warn if still zero
        tcount, _ = self._count_trainable(self.model)
        if tcount == 0:
            print("❗No trainable parameters detected. Training will be a no-op and backward will be skipped. Check target_modules.")
        self.initial_state_dict.clear()
        for name, p in self.model.named_parameters():
            if p.requires_grad:
                self.initial_state_dict[name] = p.data.clone()
        print("✅ Model setup completed")

    def _compute_span_masks_targets_precise(self, attention_mask, start_locs, answer_len_kept):
        B, T = attention_mask.shape
        device = attention_mask.device
        prompt_mask_tgt = torch.zeros((B, T), dtype=torch.bool, device=device)
        answer_mask_tgt = torch.zeros((B, T), dtype=torch.bool, device=device)
        for i in range(B):
            s = int(start_locs[i].item()) if torch.is_tensor(start_locs[i]) else int(start_locs[i])
            s = max(0, min(s, T - 1))
            L = int(answer_len_kept[i].item()) if torch.is_tensor(answer_len_kept[i]) else int(answer_len_kept[i])
            # shift-to-target alignment: target indices correspond to positions 1..T-1
            split_t = s - 1
            if split_t >= 0:
                # Prompt is [0 .. split_t-1]
                if split_t > 0:
                    prompt_mask_tgt[i, :split_t] = True
                # Answer is [split_t .. split_t+L-1], clipped to T
                if L > 0:
                    end_pos = min(T - 1, split_t + L - 1)
                    answer_mask_tgt[i, split_t : end_pos + 1] = True
            else:
                # No prompt tokens; all start as answer, but limit to L
                if L > 0:
                    end_pos = min(T - 1, L - 1)
                    answer_mask_tgt[i, : end_pos + 1] = True
        return prompt_mask_tgt, answer_mask_tgt

    def _cross_entropy_loss(self, logits, labels, loss_mask):
        vocab = logits.size(-1)
        # Use reshape instead of view to safely handle non-contiguous tensors
        logits_flat = logits.reshape(-1, vocab)
        labels_flat = labels.reshape(-1)
        mask_flat = loss_mask.reshape(-1)
        if mask_flat.sum() == 0:
            # Return a zero that's attached to the current graph if logits require grad; else plain zero
            return (logits.sum() * 0.0) if logits.requires_grad else torch.tensor(0.0, device=logits.device)
        return F.cross_entropy(
            logits_flat[mask_flat],
            labels_flat[mask_flat],
            reduction="mean",
        )

    def _unlikelihood_loss(self, logits, labels, loss_mask):
        """Unlikelihood loss with safe indexing: ignore invalid targets and prevent OOB gather."""
        V = logits.size(-1)
        device = logits.device
        # Build validity mask for target indices
        pad_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
        eos_id = self.tokenizer.eos_token_id
    
        labels_long = labels.long()
        valid_tokens = (labels_long >= 0) & (labels_long < V) & (labels_long != -100)
        if pad_id is not None:
            valid_tokens = valid_tokens & (labels_long != pad_id)
        if eos_id is not None:
            valid_tokens = valid_tokens & (labels_long != eos_id)
    
        # Replace invalid indices with 0 to keep gather in-bounds; they will be masked out later
        safe_idx = labels_long.clamp(min=0, max=V - 1)
    
        probs = F.softmax(logits, dim=-1)
        p_y_all = torch.gather(probs, dim=-1, index=safe_idx.unsqueeze(-1)).squeeze(-1)
    
        mask = (loss_mask & valid_tokens)
        if mask.sum().item() == 0:
            return (logits.sum() * 0.0) if logits.requires_grad else torch.tensor(0.0, device=device)
    
        p_y = p_y_all[mask].clamp(1e-6, 1 - 1e-6)
        ul = -torch.log(1.0 - p_y)
        return ul.mean()

    def _entropy_on_mask(self, logits, loss_mask):
        # Computes mean entropy H(p) over masked positions
        probs = F.softmax(logits, dim=-1)
        eps = 1e-8
        ent = -(probs * (probs.clamp_min(eps).log())).sum(dim=-1)
        ent = ent[loss_mask]
        if ent.numel() == 0:
            return (logits.sum() * 0.0) if logits.requires_grad else torch.tensor(0.0, device=logits.device)
        return ent.mean()

    def _l2_anchor(self):
        if not self.initial_state_dict:
            return torch.tensor(0.0, device=self.device)
        total = None
        for name, p in self.model.named_parameters():
            if p.requires_grad and name in self.initial_state_dict:
                diff = p - self.initial_state_dict[name].to(p.device, dtype=p.dtype)
                term = (diff * diff).sum()
                total = term if total is None else total + term
        if total is None:
            total = torch.tensor(0.0, device=self.device)
        return total / max(1, len(self.initial_state_dict))

    def _build_refusal_targets(self, answer_mask_tgt, forget_mask, Tm1):
        # answer_mask_tgt, forget_mask: [B, T-1] bool
        B = answer_mask_tgt.size(0)
        device = answer_mask_tgt.device
        # Tokenize refusal template once
        refusal_ids = self.tokenizer(
            self.refusal_text,
            add_special_tokens=False,
            padding=False,
            truncation=True,
            max_length=Tm1,
            return_tensors=None,
        )["input_ids"]
        if len(refusal_ids) == 0:
            # Fallback to EOS
            rid = self.tokenizer.eos_token_id
            refusal_ids = [rid if rid is not None else 0]
        R = len(refusal_ids)
        # Prepare targets of shape [B, T-1] filled with -100
        target_ref = torch.full((B, Tm1), -100, dtype=torch.long, device=device)
        # For each sample, fill the answer span positions with repeated refusal ids
        idxs = torch.arange(Tm1, device=device)
        for i in range(B):
            mask = (answer_mask_tgt[i] & forget_mask[i])
            L = int(mask.sum().item())
            if L <= 0:
                continue
            # Repeat/trim the refusal ids to L
            seq = (refusal_ids * ((L + R - 1) // R))[:L]
            target_positions = idxs[mask]
            target_ref[i, target_positions] = torch.tensor(seq, dtype=torch.long, device=device)
        return target_ref

    def train(
        self,
        dataloader,
        num_epochs=4,
        lr=1e-4,
        ce_weight_prompt=1.0,
        ul_weight_answer=1.0,
        ce_weight_retain=1.0,
        l2_anchor_weight=0,
        l2_after_weight=5e-6,
        curriculum_switch=2,
        entropy_weight_answer=0.0,
        refusal_weight=0.0,
        grad_clip=1.0,
        grad_accum_steps=1,
        use_mixed_precision=True,
        warmup_ratio=0.1,
        ul_ramp_ratio=0.2,
    ):
        assert self.model is not None, "Call setup_model() first"
        self.model.train()
        # Optimizer: prefer 8-bit if available
        optimizer = None
        try:
            import bitsandbytes as bnb
            optimizer = bnb.optim.PagedAdamW8bit(self.model.parameters(), lr=lr, weight_decay=0.01)
            print("🧮 Using 8-bit PagedAdamW optimizer (bitsandbytes)")
        except Exception:
            try:
                optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=0.01, fused=True)
                print("🧮 Using fused AdamW optimizer")
            except Exception:
                optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=0.01)
                print("🧮 Using standard AdamW optimizer")

        # Scheduler with warmup and cosine decay
        steps_per_epoch = max(1, (len(dataloader) + grad_accum_steps - 1) // grad_accum_steps)
        total_steps = num_epochs * steps_per_epoch
        warmup_steps = max(1, int(warmup_ratio * total_steps))
        ul_ramp_steps = max(1, int(ul_ramp_ratio * total_steps))
        try:
            from transformers import get_cosine_schedule_with_warmup
            scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
            print(f"📈 Using cosine scheduler with warmup ({warmup_steps}/{total_steps})")
        except Exception:
            def lr_lambda(step):
                if step < warmup_steps:
                    return float(step) / float(max(1, warmup_steps))
                progress = (step - warmup_steps) / float(max(1, total_steps - warmup_steps))
                # Cosine from 1 to 0
                return 0.5 * (1.0 + math.cos(math.pi * progress))
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
            print(f"📈 Using LambdaLR scheduler with warmup ({warmup_steps}/{total_steps})")

        # Mixed precision setup
        from contextlib import nullcontext
        device_type = 'cuda' if ('cuda' in str(self.device) and torch.cuda.is_available()) else ('mps' if ('mps' in str(self.device) and torch.backends.mps.is_available()) else None)
        param_dtype = next(self.model.parameters()).dtype if any(p.requires_grad for p in self.model.parameters()) else torch.float32
        use_amp = use_mixed_precision and device_type is not None
        amp_dtype = torch.bfloat16 if (device_type == 'cuda' and param_dtype == torch.bfloat16) else torch.float16
        scaler = torch.cuda.amp.GradScaler(enabled=(device_type == 'cuda' and amp_dtype == torch.float16 and use_amp))
        
        global_step = 0
        no_grad_batches = 0
        for epoch in range(num_epochs):
            epoch_losses = []
            cur_l2 = l2_after_weight if epoch >= curriculum_switch else l2_anchor_weight
            print(f"\n[Epoch {epoch+1}/{EPOCHS}] L2 anchor weight: {cur_l2}")
            with tqdm(total=len(dataloader), desc=f"SKU Epoch {epoch+1}") as pbar:
                optimizer.zero_grad(set_to_none=True)
                for step, batch in enumerate(dataloader):
                    input_ids = batch["input_ids"].to(self.device, non_blocking=True)
                    attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
                    labels_full = batch["labels"].to(self.device, non_blocking=True)
                    start_locs = batch["start_locs"].to(self.device, non_blocking=True)
                    answer_len_kept = batch["answer_len_kept"].to(self.device, non_blocking=True)
                    split = batch["split"].to(self.device, non_blocking=True)

                    with (torch.autocast(device_type=device_type, dtype=amp_dtype) if use_amp else nullcontext()):
                        outputs = self.model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            use_cache=False,
                            return_dict=True,
                        )
                        logits_full = outputs.logits  # [B, T, V]
                        logits = logits_full[:, :-1, :]
                        target = labels_full[:, 1:]
                        attn_tgt = attention_mask[:, 1:].bool()
                        # Precise masks using kept answer length
                        prompt_mask_tgt, answer_mask_tgt = self._compute_span_masks_targets_precise(attention_mask, start_locs, answer_len_kept)
                        prompt_mask_tgt = prompt_mask_tgt[:, :-1]
                        answer_mask_tgt = answer_mask_tgt[:, :-1]
                        retain_mask = (split == 0).unsqueeze(-1).expand_as(prompt_mask_tgt)
                        forget_mask = (split == 1).unsqueeze(-1).expand_as(prompt_mask_tgt)
                        valid_tgt = attn_tgt

                        # Core losses
                        retain_loss = self._cross_entropy_loss(
                            logits, target, loss_mask=(valid_tgt & retain_mask)
                        ) * ce_weight_retain
                        forget_prompt_loss = self._cross_entropy_loss(
                            logits, target, loss_mask=(valid_tgt & forget_mask & prompt_mask_tgt)
                        ) * ce_weight_prompt

                        # Ramp-up UL weight
                        ul_scale = min(1.0, float(global_step + 1) / float(max(1, ul_ramp_steps)))
                        forget_ul_loss = self._unlikelihood_loss(
                            logits, target, loss_mask=(valid_tgt & forget_mask & answer_mask_tgt)
                        ) * (ul_weight_answer * ul_scale)

                        # Entropy regularization on forget answer span (maximize entropy)
                        ent_loss = self._entropy_on_mask(
                            logits, loss_mask=(valid_tgt & forget_mask & answer_mask_tgt)
                        )
                        entropy_term = -entropy_weight_answer * ent_loss

                        # Optional refusal-target CE on forget answer span
                        refusal_term = torch.tensor(0.0, device=logits.device)
                        if refusal_weight > 0.0:
                            Tm1 = logits.size(1)
                            target_ref = self._build_refusal_targets(answer_mask_tgt, forget_mask, Tm1)
                            refusal_term = self._cross_entropy_loss(
                                logits, target_ref, loss_mask=(valid_tgt & forget_mask & answer_mask_tgt)
                            ) * refusal_weight

                        # L2 anchor on weights
                        l2_term = self._l2_anchor() * cur_l2

                        loss = retain_loss + forget_prompt_loss + forget_ul_loss + entropy_term + refusal_term + l2_term
                        loss_for_backward = loss / max(1, int(grad_accum_steps))

                    if not loss_for_backward.requires_grad:
                        no_grad_batches += 1
                        pbar.set_postfix({"Loss": f"{float(loss.detach().cpu()):.4f}", "note": "no-grad-batch"})
                        pbar.update(1)
                        continue

                    if scaler.is_enabled():
                        scaler.scale(loss_for_backward).backward()
                    else:
                        loss_for_backward.backward()

                    do_step = ((step + 1) % grad_accum_steps == 0) or (step + 1 == len(dataloader))
                    if do_step:
                        if scaler.is_enabled():
                            scaler.unscale_(optimizer)
                        if grad_clip is not None:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
                        if scaler.is_enabled():
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad(set_to_none=True)
                        global_step += 1

                    epoch_losses.append(float(loss.detach().cpu()))
                    pbar.set_postfix({
                        "Loss": f"{float(loss.detach().cpu()):.4f}",
                        "RetCE": f"{float(retain_loss.detach().cpu()):.3f}",
                        "FgtCE": f"{float(forget_prompt_loss.detach().cpu()):.3f}",
                        "FgtUL": f"{float(forget_ul_loss.detach().cpu()):.3f}",
                        "Ent": f"{float(entropy_term.detach().cpu()):.3f}",
                        "Ref": f"{float(refusal_term.detach().cpu()):.3f}",
                        "L2": f"{float(l2_term.detach().cpu()):.3f}",
                    })
                    pbar.update(1)

                    # Proactive cleanup to prevent fragmentation
                    del outputs, logits_full, logits, target, attn_tgt, prompt_mask_tgt, answer_mask_tgt, retain_mask, forget_mask, valid_tgt
                    if device_type == 'cuda' and ((step + 1) % 50 == 0):
                        torch.cuda.empty_cache()
            if device_type == 'cuda':
                torch.cuda.empty_cache()
            avg_loss = float(np.mean(epoch_losses)) if epoch_losses else 0.0
            if no_grad_batches:
                print(f"ℹ️ Epoch {epoch+1}: skipped {no_grad_batches} batches with no grad signal (check trainable params)")
                no_grad_batches = 0
            print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")
            ckpt_dir = f"enhanced_ckpts/epoch_{epoch+1}"
            os.makedirs(ckpt_dir, exist_ok=True)
            print(f"Saving checkpoint in: {ckpt_dir}")
            self.save_model(ckpt_dir)

            adapted_local = PeftModel.from_pretrained(self.base_model, ckpt_dir)
            sensitive_prompt = forget_validation_df.iloc[0]['input']
            sensitive_answer = forget_validation_df.iloc[0]['output']
            lp_base, Lspan = compute_span_logprob(base_model_local, tokenizer, sensitive_prompt, sensitive_answer)
            lp_adapt, _ = compute_span_logprob(adapted_local, tokenizer, sensitive_prompt, sensitive_answer)
            print(f"[Epoch {epoch+1}] Span log-prob base {lp_base:.2f} -> adapted {lp_adapt:.2f} Δ {lp_adapt-lp_base:.2f}")

    def save_model(self, output_dir: str):
        """Save the trained weights to ``output_dir``.

        First tries ``self.model.save_pretrained`` (for a PEFT-wrapped model this
        writes the adapter weights). If that raises, falls back to
        ``self.model.base_model.save_pretrained`` when such an attribute exists.
        Failures are reported with ``print`` and never raised, so a failed save
        does not abort the surrounding training loop.

        Args:
            output_dir: Target directory; created if it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)
        try:
            # If using PEFT, this saves the adapter weights
            self.model.save_pretrained(output_dir)
            print(f"💾 Saved PEFT adapters to {output_dir}")
            return
        except Exception as e:
            print(f"⚠️ Could not save PEFT adapters directly: {e}")

        fallback = getattr(self.model, "base_model", None)
        if fallback is None:
            # Without a fallback nothing is written; say so instead of failing silently.
            print(f"❌ Failed to save model: {type(self.model).__name__} has no 'base_model' to fall back to")
            return
        try:
            fallback.save_pretrained(output_dir)
            print(f"💾 Saved base model to {output_dir}")
        except Exception as e2:
            print(f"❌ Failed to save model: {e2}")

    def calculate_task_vector(self):
        """Return ``{param_name: current - initial}`` for trainable parameters.

        Only parameters that currently have ``requires_grad`` set and that have an
        entry in ``self.initial_state_dict`` are included. Each delta is detached
        and moved to CPU.
        """
        return {
            param_name: (param.data - self.initial_state_dict[param_name]).detach().cpu()
            for param_name, param in self.model.named_parameters()
            if param.requires_grad and param_name in self.initial_state_dict
        }

## 5. Setup Trainer and Training

In [None]:
print("Detecting linear submodules for targeted adaptation...")
# Load a temporary copy of the base model only to enumerate its nn.Linear leaf names.
base_tmp = AutoModelForCausalLM.from_pretrained(MODEL_PATH, local_files_only=True)
# Leaf name = last dotted component of each module path.
linear_names = [
    name.split('.')[-1]
    for name, module in base_tmp.named_modules()
    if isinstance(module, torch.nn.Linear)
]
unique_linear = sorted(set(linear_names))
# The probe copy is no longer needed; release it before the trainer loads its own model.
del base_tmp
import gc
gc.collect()
print(f"Found {len(unique_linear)} unique linear leaf names (showing first 25): {unique_linear[:25]}")

# Heuristic filter: keep typical projection/feed-forward names if they exist.
# NOTE(review): keys are substring matches, so the single letters ("q", "k", "v", "o")
# accept any name containing that letter; review whether that breadth is intended.
preferred = [n for n in unique_linear if any(k in n for k in ["q", "k", "v", "o", "proj", "gate", "up", "down", "w1", "w2", "fc", "linear"])]
# Fallback to all unique linear names if filter becomes too small
if len(preferred) < 4:
    preferred = unique_linear
print(f"Using {len(preferred)} target module names for LoRA: {preferred}")

auto_lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32,                 # larger rank for stronger capacity
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=preferred,
)

use_cuda = torch.cuda.is_available()

# Rebuild trainer with new config
sku_trainer = SelectiveKnowledgeNegationTrainer(
    model_path=MODEL_PATH,
    tokenizer=tokenizer,
    lora_config=auto_lora_config,
    device=("cuda" if use_cuda else ("mps" if torch.backends.mps.is_available() else "cpu")),
)
sku_trainer.setup_model()

# 1) Cast the whole model (bf16 on CUDA, fp32 otherwise) BEFORE promoting the
#    trainable params: casting afterwards would turn the LoRA weights back into bf16.
try:
    sku_trainer.model = sku_trainer.model.to(dtype=torch.bfloat16 if use_cuda else torch.float32)
except Exception as e:
    print(f"⚠️ Could not cast model dtype: {e}")

# 2) Keep trainable (LoRA) params in fp32 for better small-gradient resolution.
for n, p in sku_trainer.model.named_parameters():
    if p.requires_grad:
        p.data = p.data.float()

# Optional: gradient checkpointing to save activation memory.
if hasattr(sku_trainer.model, "gradient_checkpointing_enable"):
    try:
        sku_trainer.model.gradient_checkpointing_enable()
        # With a frozen base model the embeddings output does not require grad, so
        # checkpointed segments can leave the adapters without gradients; this hook fixes that.
        if hasattr(sku_trainer.model, "enable_input_require_grads"):
            sku_trainer.model.enable_input_require_grads()
    except Exception as e:
        print(f"⚠️ Gradient checkpointing not enabled: {e}")

print("Trainable parameter count:", sum(p.numel() for p in sku_trainer.model.parameters() if p.requires_grad))

In [None]:
FORGET_UPSAMPLE = 1  # replication factor for 'forget' rows; <= 1 disables upsampling (fractions such as 1.5 allowed)

# Rebuilds `dataset` / `dataloader` from `dataset.data`.
# NOTE: re-running this cell with FORGET_UPSAMPLE > 1 upsamples the already-upsampled data again.
if 'dataset' in globals():
    import pandas as pd
    base_df = dataset.data
    forget_df = base_df[base_df['split'] == 'forget']
    retain_df = base_df[base_df['split'] == 'retain']
    if FORGET_UPSAMPLE > 1 and len(forget_df) > 0:
        reps_int = int(FORGET_UPSAMPLE)
        frac_part = FORGET_UPSAMPLE - reps_int
        replicated = [forget_df] * reps_int
        if frac_part > 1e-6:
            # Seed from the notebook-wide SEED (not a literal) so this sample follows the run's seed.
            replicated.append(forget_df.sample(frac=frac_part, replace=True, random_state=SEED))
        extra_forget = pd.concat(replicated, ignore_index=True)
        # Only 'retain' and 'forget' rows are carried into the augmented frame.
        aug_df = pd.concat([retain_df, extra_forget], ignore_index=True)
        print(f"[Enhanced Training] Upsampled forget examples: {len(extra_forget)}")
    else:
        if FORGET_UPSAMPLE > 1:
            print("[Enhanced Training] No forget examples found; nothing to upsample.")
        aug_df = base_df
    dataset = UnlearningDataset(aug_df, tokenizer, max_length=384)
    dataloader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=lambda b: sku_collate_fn(b, pad_id, max_length=384),
        pin_memory=torch.cuda.is_available(),
        num_workers=2 if torch.cuda.is_available() else 0,
        persistent_workers=False,
    )
    print(f"[Enhanced Training] New dataset size {len(dataset)} | forget {sum(dataset.data['split']=='forget')} | retain {sum(dataset.data['split']=='retain')}")
else:
    print("[Enhanced Training] Dataset object not found; skipping upsample step.")

In [None]:
from dataclasses import dataclass

@dataclass
class TrainArgs:
    """Hyperparameters for ``sku_trainer.train``.

    Each field is forwarded to the ``train`` keyword noted beside it. Clearly
    invalid values (non-positive epochs/grad_accum/lr, ratios outside [0, 1],
    non-positive grad_clip) are rejected at construction time with ``ValueError``.
    """

    epochs: int = 5                  # -> num_epochs
    lr: float = 1e-3                 # -> lr
    ul_weight: float = 10.0          # -> ul_weight_answer
    refusal_weight: float = 0.5      # -> refusal_weight
    ce_prompt_weight: float = 0.3    # -> ce_weight_prompt
    ce_retain_weight: float = 1.0    # -> ce_weight_retain
    entropy_weight: float = 0.05     # -> entropy_weight_answer
    grad_accum: int = 4              # -> grad_accum_steps
    warmup: float = 0.08             # -> warmup_ratio
    ul_ramp: float = 0.15            # -> ul_ramp_ratio
    grad_clip: float = 1.0           # -> grad_clip (None disables clipping in train)
    use_mixed_precision: bool = True # -> use_mixed_precision
    l2_anchor: float = 0.0           # -> l2_anchor_weight
    curriculum_switch: int = 2       # -> curriculum_switch
    l2_after: float = 5e-6           # -> l2_after_weight

    def __post_init__(self):
        """Reject settings that would break the optimizer loop or LR schedule."""
        if self.epochs < 1:
            raise ValueError(f"epochs must be >= 1, got {self.epochs}")
        if self.grad_accum < 1:
            raise ValueError(f"grad_accum must be >= 1, got {self.grad_accum}")
        if self.lr <= 0:
            raise ValueError(f"lr must be > 0, got {self.lr}")
        for ratio_name in ("warmup", "ul_ramp"):
            ratio = getattr(self, ratio_name)
            if not 0.0 <= ratio <= 1.0:
                raise ValueError(f"{ratio_name} must be a ratio in [0, 1], got {ratio}")
        if self.grad_clip is not None and self.grad_clip <= 0:
            raise ValueError(f"grad_clip must be > 0 or None, got {self.grad_clip}")

TRAIN_ARGS = TrainArgs()
print(TRAIN_ARGS)

# Train SKU model
sku_trainer.train(
    dataloader=dataloader,
    num_epochs=TRAIN_ARGS.epochs,
    lr=TRAIN_ARGS.lr,
    ce_weight_prompt=TRAIN_ARGS.ce_prompt_weight,
    ul_weight_answer=TRAIN_ARGS.ul_weight,
    ce_weight_retain=TRAIN_ARGS.ce_retain_weight,
    l2_anchor_weight=TRAIN_ARGS.l2_anchor,
    l2_after_weight=TRAIN_ARGS.l2_after,
    curriculum_switch=TRAIN_ARGS.curriculum_switch,
    entropy_weight_answer=TRAIN_ARGS.entropy_weight,
    refusal_weight=TRAIN_ARGS.refusal_weight,
    grad_clip=TRAIN_ARGS.grad_clip,
    grad_accum_steps=TRAIN_ARGS.grad_accum,
    use_mixed_precision=TRAIN_ARGS.use_mixed_precision,
    warmup_ratio=TRAIN_ARGS.warmup,
    ul_ramp_ratio=TRAIN_ARGS.ul_ramp,
)

# Task vector (trainable params minus their initial snapshot), computed once and
# reused for both the norm report and the saved artifact.
task_vector = sku_trainer.calculate_task_vector()

# Quick delta check: global L2 norm of the task vector, i.e. sqrt of the squared
# entries summed over every tensor (summing per-tensor norms would overstate it).
with torch.no_grad():
    squared_sum = sum(float(t.float().pow(2).sum()) for t in task_vector.values())
total_norm = squared_sum ** 0.5
print(f"Δ (task-vector) total L2 norm: {total_norm:.4f}")

# Create results directory
os.makedirs('balanced_results', exist_ok=True)

# Save SKU model and the task vector
sku_trainer.save_model('balanced_results/balanced_model')
torch.save(task_vector, 'balanced_results/task_vector.pt')

print("✅ Results saved in balanced_results/")
print("- balanced_model/: SKU-trained model")
print("- task_vector.pt: Task vector for future applications")

# Optional quick A/B: decode one forget-validation prompt with the base weights, then
# again with the adapters saved in balanced_results/balanced_model attached, and print
# both (up to 64 new tokens, model default generation settings). Any failure -- missing
# globals, files, or downloads -- is reported and the check is skipped.
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import PeftModel

    prompt = forget_validation_df.iloc[0]['input']
    base_tok = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
    if base_tok.pad_token is None:
        base_tok.pad_token = base_tok.eos_token
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, local_files_only=True)
    inputs = base_tok(prompt, return_tensors='pt').to(base_model.device)

    def _decode_first_generation(lm):
        """Generate for the shared `inputs` and decode the first returned sequence."""
        with torch.no_grad():
            generated = lm.generate(**inputs, max_new_tokens=64)
        return base_tok.decode(generated[0], skip_special_tokens=True)

    # Decode with the plain base first; wrapping with PeftModel presumably modifies
    # base_model in place (TODO confirm), so the order matters.
    txt_base = _decode_first_generation(base_model)
    adapted = PeftModel.from_pretrained(base_model, 'balanced_results/balanced_model')
    txt_adapt = _decode_first_generation(adapted)

    print("--- A/B Quick Check ---")
    for label, text in (("Prompt:", prompt), ("Base  :", txt_base), ("Adapt :", txt_adapt)):
        print(label, text)
except Exception as e:
    print(f"A/B quick check skipped: {e}")

# 6. Evaluation



In [None]:
# Results export helper
import getpass

def append_result(record: dict, file_path: str = "evaluation_results.jsonl"):
    """Append one result record plus run metadata to a JSONL file.

    The record's keys are copied to the top level and a ``meta`` entry is added
    (replacing any ``meta`` key already in ``record``) containing the current
    user, a UTC ISO-8601 timestamp suffixed with ``Z``, ``SEED``,
    ``asdict(CFG)``, ``ENV_INFO`` and the notebook name.

    Args:
        record: JSON-serializable result fields.
        file_path: JSONL file to append to; created if missing.
    """
    try:
        user = getpass.getuser()
    except (KeyError, OSError, ImportError):
        # getuser() can fail in containers with no USER/LOGNAME and no passwd entry.
        user = "unknown"
    # Aware UTC time (utcnow() is deprecated), rendered in the same "...Z" format as before.
    timestamp = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat() + "Z"
    rec = {
        **record,
        "meta": {
            "user": user,
            "time": timestamp,
            "seed": SEED,
            "config": asdict(CFG),
            "env": ENV_INFO,
            "notebook": "SKU_copia_4.ipynb"
        }
    }
    with open(file_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(rec) + "\n")
    print(f"Appended results to {file_path}")

In [None]:
import types
from evaluation import inference, mia_attacks, compute_metrics

DEFAULT_CHECKPOINT = 'balanced_results/balanced_model'
LOCAL_VALIDATION_DIR = 'validation'


def run_evaluation(
    data_path: str = LOCAL_VALIDATION_DIR,
    checkpoint_path: str = DEFAULT_CHECKPOINT,
    output_dir: str = "eval_results",
    mia_data_path=None,
    mmlu_metrics_file_path=None,
    max_new_tokens: int = 256,
    batch_size: int = 16,
    debug: bool = False,
    compute_metrics_only: bool = False,
    seed: int = 42,
    keep_files: bool = False,
):
    """Evaluate a saved checkpoint with the project ``evaluation`` module.

    Checks that ``data_path`` (with ``forget.jsonl`` and ``retain.jsonl``) and
    ``checkpoint_path`` exist, seeds ``random``/``torch``/``numpy``, then, unless
    ``compute_metrics_only``, loads the base model from ``MODEL_PATH`` with the
    adapters in ``checkpoint_path`` (falling back to loading ``checkpoint_path``
    as a full model), moves it to the Accelerate device and runs ``inference``
    (and ``mia_attacks`` when ``mia_data_path`` is set). The main process finally
    calls ``compute_metrics``. The keyword arguments are bundled into a namespace
    passed to those functions.

    Every exception is caught, printed with a traceback, and swallowed; the
    function returns ``None``.
    """
    try:
        args = types.SimpleNamespace(
            data_path=data_path,
            checkpoint_path=checkpoint_path,
            output_dir=output_dir,
            mia_data_path=mia_data_path,
            mmlu_metrics_file_path=mmlu_metrics_file_path,
            max_new_tokens=max_new_tokens,
            batch_size=batch_size,
            debug=debug,
            compute_metrics_only=compute_metrics_only,
            seed=seed,
            keep_files=keep_files,
        )
        print(f"🔍 Paths:\n  data={data_path}\n  ckpt={checkpoint_path}\n  out={output_dir}")
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Data path not found: {data_path}")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
        for split_file in ["forget.jsonl", "retain.jsonl"]:
            if not os.path.exists(os.path.join(data_path, split_file)):
                raise FileNotFoundError(f"Missing {split_file} in {data_path}")

        from pathlib import Path as _P
        _P(output_dir).mkdir(parents=True, exist_ok=True)

        import random, torch, numpy as np
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        from accelerate import Accelerator
        accelerator = Accelerator()

        if not args.compute_metrics_only:
            from transformers import AutoModelForCausalLM, AutoTokenizer
            from peft import PeftModel
            base_tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
            if base_tokenizer.pad_token is None:
                base_tokenizer.pad_token = base_tokenizer.eos_token
            # NOTE(review): batched generation with a decoder-only model usually needs
            # left padding; confirm `inference` sets base_tokenizer.padding_side.
            base_model = AutoModelForCausalLM.from_pretrained(
                MODEL_PATH,
                local_files_only=True,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            )
            try:
                model = PeftModel.from_pretrained(base_model, args.checkpoint_path)
                print("✅ Loaded base + adapters")
            except Exception as e:
                print(f"⚠️ Adapter load failed ({e}); trying plain model")
                # NOTE: trust_remote_code executes code shipped with the checkpoint;
                # only point checkpoint_path at trusted sources.
                model = AutoModelForCausalLM.from_pretrained(
                    args.checkpoint_path,
                    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                    trust_remote_code=True,
                )
            model.eval()
            # from_pretrained is called without device_map, so the weights load on CPU;
            # move them to the device Accelerate selected (e.g. the GPU).
            model.to(accelerator.device)
            print("🚀 Inference...")
            inference(args, model, base_tokenizer)
            if args.mia_data_path:
                print("🔍 MIA attacks...")
                mia_attacks(args, model, base_tokenizer)
        if accelerator.is_main_process:
            print("📊 Metrics...")
            compute_metrics(args)
            print("✅ Done")
    except Exception as e:
        print(f"❌ Evaluation error: {e}")
        import traceback; traceback.print_exc()

print("🎯 Launching local evaluation (if validation data present)...")
# Evaluate only when the validation split files and the trained checkpoint are on disk.
_required_split_files = ["forget.jsonl", "retain.jsonl"]
_validation_ready = os.path.exists(LOCAL_VALIDATION_DIR) and all(
    os.path.exists(os.path.join(LOCAL_VALIDATION_DIR, fname)) for fname in _required_split_files
)
if not _validation_ready:
    print(f"❌ Validation directory {LOCAL_VALIDATION_DIR} with required jsonl files not found")
elif not os.path.exists(DEFAULT_CHECKPOINT):
    print(f"❌ Missing checkpoint at {DEFAULT_CHECKPOINT}")
else:
    run_evaluation()

In [None]:
# === 8. Diagnostics & Debugging for SKU Effectiveness ===
import torch, math
from collections import Counter
from transformers import AutoModelForCausalLM, AutoTokenizer

print("[Diagnostics] Starting SKU debugging block...")

# 1. Dataset split distribution & basic length stats
# Counts rows per 'split' label, then averages the per-item prompt length
# ('start_locs') and kept answer length ('answer_len_kept') over the first
# <=200 items and reports how many items kept 0 answer tokens.
if 'dataset' in globals():
    split_counts = Counter(dataset.data['split']) if hasattr(dataset, 'data') else {}
    print("Split counts:", split_counts)
    # Approx prompt / answer token stats (first 200 samples)
    # NOTE(review): treats 'start_locs' as the prompt length in tokens and
    # 'answer_len_kept' as the answer tokens surviving truncation -- confirm
    # against UnlearningDataset.__getitem__.
    prompt_lens = []
    answer_lens = []
    for i in range(min(200, len(dataset))):
        item = dataset[i]
        prompt_lens.append(int(item['start_locs']))
        answer_lens.append(int(item['answer_len_kept']))
    if prompt_lens:
        print(f"Avg prompt tokens: {sum(prompt_lens)/len(prompt_lens):.1f} | Avg kept answer tokens: {sum(answer_lens)/len(answer_lens):.1f}")
        zero_ans = sum(1 for x in answer_lens if x == 0)
        print(f"Samples with 0 answer tokens kept: {zero_ans}/{len(answer_lens)} ({100*zero_ans/len(answer_lens):.1f}%)")
else:
    print("Dataset not found in globals().")

# 2. Inspect one batch to see UL active positions
# Rebuilds the prompt/answer target masks for the first batch with the trainer's
# own helper and counts forget-answer target positions; a count of 0 means the
# unlikelihood term has nothing to act on for this batch.
if 'dataloader' in globals():
    first_batch = next(iter(dataloader))
    # Pull out the fields the mask helper needs (no device transfer happens here)
    input_ids = first_batch['input_ids']
    attention_mask = first_batch['attention_mask']
    start_locs = first_batch['start_locs']
    answer_len_kept = first_batch['answer_len_kept']
    split = first_batch['split']
    if 'sku_trainer' in globals():
        with torch.no_grad():
            prompt_mask_tgt, answer_mask_tgt = sku_trainer._compute_span_masks_targets_precise(attention_mask, start_locs, answer_len_kept)
            # Align with target length (T-1)
            # NOTE(review): assumes the helper returns [B, T] masks so that dropping the
            # last column lines up with attention_mask[:, 1:] below -- confirm.
            prompt_mask_tgt = prompt_mask_tgt[:, :-1]
            answer_mask_tgt = answer_mask_tgt[:, :-1]
            attn_tgt = attention_mask[:, 1:].bool()
            # Broadcast the per-row split label over positions; split==0 is treated as
            # retain and split==1 as forget. retain_mask is not used further below.
            retain_mask = (split == 0).unsqueeze(-1).expand_as(prompt_mask_tgt)
            forget_mask = (split == 1).unsqueeze(-1).expand_as(prompt_mask_tgt)
            # Positions where UL would apply: real (attended) targets inside forget answers.
            ul_active = (attn_tgt & forget_mask & answer_mask_tgt).sum().item()
            forget_answer_tokens = (forget_mask & answer_mask_tgt).sum().item()
            print(f"UL active positions in sample batch: {ul_active}")
            print(f"Total forget answer target positions in sample batch: {forget_answer_tokens}")
            if forget_answer_tokens == 0:
                print("WARNING: No forget answer tokens available; unlikelihood loss will be zero. Consider increasing max_length or shortening prompts.")
    else:
        print("sku_trainer not defined.")
else:
    print("Dataloader not found.")

# 3. Function to compute log-prob of a sensitive answer span before & after adapters

def compute_span_logprob(model, tokenizer, prompt, span_text, device=None):
    """Return ``(sum of log P(span tokens | preceding tokens), n_span_tokens)``.

    ``prompt + span_text`` is tokenized once with the tokenizer's default special
    tokens. The span starts right after the longest token prefix shared with the
    tokenization of ``prompt`` alone, so a BOS token added by the tokenizer is
    accounted for. If the tokenizer merges characters across the prompt/span
    boundary, the merged token is counted as part of the span. A span token at
    position 0 has no left context and is not scored.

    Args:
        model: causal LM whose output has ``.logits`` of shape [1, T, V].
        tokenizer: callable tokenizer returning ``input_ids`` (HF-style).
        prompt: conditioning text.
        span_text: text whose log-probability is measured.
        device: device for the inputs; defaults to the device of the model's
            first parameter, or CPU for a parameterless model.

    Returns:
        ``(log_prob_sum, n_tokens)``; ``(nan, 0)`` when the span has no tokens.
    """
    if not device:
        # Use the model's actual device even when all params are frozen (e.g. a loaded PeftModel).
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else 'cpu'
    model.eval()
    with torch.no_grad():
        tok = tokenizer(prompt + span_text, return_tensors='pt')
        for k in tok:
            tok[k] = tok[k].to(device)
        logits = model(**tok, use_cache=False, return_dict=True).logits  # [1, T, V]
        input_ids = tok['input_ids']
        seq_len = input_ids.size(1)

        # Span start = length of the common prefix with the prompt-only encoding
        # (same special-token settings as the full encoding).
        full_ids = input_ids[0].tolist()
        prompt_ids = list(tokenizer(prompt)['input_ids'])
        plen = 0
        for prompt_tok, full_tok in zip(prompt_ids, full_ids):
            if prompt_tok != full_tok:
                break
            plen += 1

        first = max(plen, 1)  # position 0 cannot be predicted
        if first >= seq_len:
            return float('nan'), 0
        # Causal shift: logits at position i score the token at position i + 1,
        # so span tokens [first, T) are scored by logits [first - 1, T - 1).
        pred_logits = logits[0, first - 1:seq_len - 1, :].float()
        targets = input_ids[0, first:seq_len]
        log_probs = torch.log_softmax(pred_logits, dim=-1)
        gathered = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        return gathered.sum().item(), int(targets.numel())

# 4. Compare the log-prob of the ground-truth answer before vs. after adapters
try:
    sensitive_prompt = forget_validation_df.iloc[0]['input']
    # Use ground-truth output field as sensitive answer to suppress
    sensitive_answer = forget_validation_df.iloc[0]['output']
    base_tok = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
    if base_tok.pad_token is None: base_tok.pad_token = base_tok.eos_token
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, local_files_only=True, torch_dtype=torch.float32)
    lp_base, Lspan = compute_span_logprob(base_model, base_tok, sensitive_prompt, sensitive_answer)
    print(f"Base span log-prob (sum over {Lspan} tokens): {lp_base:.2f}")
    adapted = None
    from peft import PeftModel
    # Candidate adapter directories, best first: the final saved adapters, then the
    # per-epoch training checkpoints from newest to oldest.
    adapter_dirs = ['balanced_results/balanced_model']
    if os.path.isdir('enhanced_ckpts'):
        epoch_nums = sorted(
            (int(d[len('epoch_'):]) for d in os.listdir('enhanced_ckpts')
             if d.startswith('epoch_') and d[len('epoch_'):].isdigit()),
            reverse=True,
        )
        adapter_dirs += [f'enhanced_ckpts/epoch_{n}' for n in epoch_nums]
    for adapter_dir in adapter_dirs:
        if not os.path.exists(adapter_dir):
            continue
        try:
            adapted = PeftModel.from_pretrained(base_model, adapter_dir)
            print(f"Comparing against adapters from {adapter_dir}")
            break
        except Exception as e:
            print(f"Could not load adapters from {adapter_dir}:", e)
    if adapted is not None:
        lp_adapt, _ = compute_span_logprob(adapted, base_tok, sensitive_prompt, sensitive_answer)
        print(f"Adapted span log-prob: {lp_adapt:.2f}")
        if math.isfinite(lp_base) and math.isfinite(lp_adapt):
            delta = lp_adapt - lp_base
            print(f"Δ log-prob (adapted - base): {delta:.2f} (negative desired for forgetting)")
            if delta > -0.5:
                print("Span probability not sufficiently reduced. Consider stronger ul_weight_answer, higher lr, or upsampling forget examples.")
    else:
        print("No adapted model directory found for probability comparison.")
except Exception as e:
    print("Span log-prob comparison skipped:", e)

# 5. Recommendations print based on quick heuristics
print("\n[Heuristic Recommendations]")
if 'split_counts' in locals() and split_counts:
    total = sum(split_counts.values())
    fgt = split_counts.get('forget', 0)
    if fgt / max(1,total) < 0.2:
        print("- Forget examples <20%: upsample forget or increase ul_weight_answer/refusal_weight.")
if 'answer_lens' in locals() and answer_lens:
    if sum(1 for x in answer_lens if x==0) / len(answer_lens) > 0.3:
        print("- Many samples lose answer tokens (truncation). Increase max_length or shorten prompts.")
# Quote the lr actually used, so "raise lr" never suggests a smaller value than the default.
_train_args = globals().get('TRAIN_ARGS')
_lr_hint = f" (currently {_train_args.lr:g})" if _train_args is not None else ""
print(f"- If Δ log-prob ~0, raise lr{_lr_hint}, set ul_weight_answer 6-8, set refusal_weight 0.3, temporarily disable l2_anchor.")
print("- Use sampling (top_p=0.9, temperature=0.8) for qualitative A/B instead of greedy only.")
print("[Diagnostics] Complete.")

## References & Appendix

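- Paper (SKU method): Z. Liu et al., *Towards Safer Large Language Models through Machine Unlearning* (Selective Knowledge negation Unlearning). arXiv:2402.10058
- Task vectors / task arithmetic: G. Ilharco et al., *Editing Models with Task Arithmetic*. arXiv:2212.04089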
- Codebase: This notebook builds on project modules in this repository (see `config.py`, `training_utils.py`, `evaluation_utils.py`).

Academic Integrity: This notebook is prepared for coursework submission. All external sources are cited; experiments are reproducible with fixed seeds and logged environment details.

In [None]:
# Reproducibility manifest: configuration, training hyperparameters, environment
# snapshot and split sizes, written to OUT_DIR (the notebook's outputs directory).
manifest = {
    'config': asdict(CFG),
    'train_args': asdict(TRAIN_ARGS),
    'env': ENV_INFO,
    'data_counts': {
        'retain_train': len(retain_train_df),
        'forget_train': len(forget_train_df),
        'retain_validation': len(retain_validation_df),
        'forget_validation': len(forget_validation_df),
    },
}
manifest_path = OUT_DIR / 'manifest.json'
# Recreate the directory in case it was removed after the setup cell ran.
manifest_path.parent.mkdir(parents=True, exist_ok=True)
with open(manifest_path, 'w', encoding='utf-8') as f:
    json.dump(manifest, f, indent=2)
print(f'Wrote {manifest_path}')