# Selective Knowledge Negation Unlearning


## 1. Setup e Import

In [None]:
!pip install rouge-score transformers peft huggingface_hub pyarrow
import torch
import pandas as pd
import numpy as np
import json
import os
import math
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from rouge_score import rouge_scorer

# Configuration
MODEL_PATH = "semeval25-unlearning-1B-model"

# Report how many CUDA devices are visible and what each one is called
gpu_count = torch.cuda.device_count()
print(f"GPUs disponibili: {gpu_count}")
for gpu_idx in range(gpu_count):
    print(f"GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")

## 2. Caricamento Dati e Modelli

In [None]:
from huggingface_hub import snapshot_download

# Helpers to load datasets from either JSONL or Parquet

def load_table(path: str) -> pd.DataFrame:
    """Read a dataset file from disk into a DataFrame.

    ``.jsonl`` / ``.json`` files are read as JSON Lines (one record per line);
    ``.parquet`` files are read with the pandas Parquet reader.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the file extension is not one of the supported ones.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File non trovato: {path}")

    _, suffix = os.path.splitext(path)
    suffix = suffix.lower()

    if suffix in (".jsonl", ".json"):
        return pd.read_json(path, lines=True)
    if suffix == ".parquet":
        # requires pyarrow or fastparquet
        return pd.read_parquet(path)
    raise ValueError(f"Formato file non supportato: {suffix}")

# Canonical schema
schema_cols = ["input", "output", "split"]


def _stringify_keep_missing(col: pd.Series) -> pd.Series:
    """Cast non-missing values of ``col`` to ``str`` and leave missing values missing."""
    as_obj = col.astype(object)
    # `where` keeps entries whose condition is True (the missing ones) and
    # takes the stringified value everywhere else.
    return as_obj.where(as_obj.isna(), as_obj.astype(str))


# Try to coerce to expected schema
def ensure_schema(df: pd.DataFrame, path: str) -> pd.DataFrame:
    """Project ``df`` onto the canonical ``input`` / ``output`` / ``split`` columns.

    The first matching alternative column name is used for ``input`` and
    ``output``; a column that cannot be found is filled with NaN. ``split`` is
    copied when present, otherwise it is ``"forget"`` when the file name
    contains "forget" and ``"retain"`` in every other case.

    Present values are converted to strings, while missing values are kept as
    missing (not turned into the literal text ``"nan"``) so that a later
    ``dropna`` can actually remove incomplete rows.

    Args:
        df: Raw table as loaded from disk.
        path: Source file path, only used to infer ``split`` from its name.

    Returns:
        A new DataFrame with exactly the columns in ``schema_cols``.
    """
    # map common alt column names
    input_candidates = ["input", "prompt", "question", "instruction", "query", "source", "text"]
    output_candidates = ["output", "answer", "response", "target", "completion"]

    # choose first available
    in_col = next((c for c in input_candidates if c in df.columns), None)
    out_col = next((c for c in output_candidates if c in df.columns), None)

    # Start from the source index so scalar fills broadcast to every row
    df2 = pd.DataFrame(index=df.index)
    df2["input"] = df[in_col] if in_col else np.nan
    df2["output"] = df[out_col] if out_col else np.nan

    # split: keep if exists, else infer from filename, else retain
    if "split" in df.columns:
        df2["split"] = df["split"]
    else:
        base = os.path.basename(path).lower()
        df2["split"] = "forget" if "forget" in base else "retain"

    # ensure types are strings where relevant, without hiding missing values
    for col in ("input", "output"):
        df2[col] = _stringify_keep_missing(df2[col])
    return df2[schema_cols]

# Data loading: locations of the train/validation files for both splits
retain_train_path = 'train/retain.jsonl'  # if it exists
forget_train_path = 'train/forget_train-00000-of-00001.parquet'  # new format
retain_val_path = 'validation/retain.jsonl'
forget_val_path = 'validation/forget.jsonl'

# Load with a fallback: a missing file yields an empty DataFrame with the canonical columns

def safe_load(path):
    """Load ``path`` and normalise it to the canonical schema.

    Returns an empty DataFrame with the ``schema_cols`` columns when the file
    does not exist; any other error is propagated to the caller.
    """
    try:
        return ensure_schema(load_table(path), path)
    except FileNotFoundError:
        return pd.DataFrame(columns=schema_cols)

# Load the four dataset splits (missing files become empty DataFrames)
retain_train_df = safe_load(retain_train_path)
forget_train_df = safe_load(forget_train_path)
retain_validation_df = safe_load(retain_val_path)
forget_validation_df = safe_load(forget_val_path)

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
if tokenizer.pad_token is None:
    # Reuse EOS as the padding token when the tokenizer defines none
    tokenizer.pad_token = tokenizer.eos_token

# Prefer right padding for causal LM efficiency
tokenizer.padding_side = 'right'

# Download the checkpoint into MODEL_PATH so the directory the trainer loads
# from and the download target can never drift apart.
snapshot_download(repo_id='llmunlearningsemeval2025organization/olmo-1B-model-semeval25-unlearning', local_dir=MODEL_PATH)

print("Dataset salvati e tokenizer caricato")
print({
    'retain_train': len(retain_train_df),
    'forget_train': len(forget_train_df),
    'retain_validation': len(retain_validation_df),
    'forget_validation': len(forget_validation_df),
})

## 3. Dataset

In [None]:
# Sanity check of the loaded data.
# `train_data` is only built in a later cell, so preview a local concatenation
# of the two training frames here instead of referencing it.
print("Train (retain, forget) sizes:", len(retain_train_df), len(forget_train_df))
preview_df = pd.concat([retain_train_df, forget_train_df], ignore_index=True)
print("Columns:", list(preview_df.columns))
print(preview_df.head(2))

In [None]:
class UnlearningDataset(Dataset):
    """Prompt/answer dataset that yields un-padded token lists for SKU training.

    Each item concatenates the tokenized prompt (``input``) and answer
    (``output``) up to ``max_length`` tokens and records where the answer
    starts and how many answer tokens were kept.

    Args:
        data_source: A pandas DataFrame, or a path to a ``.jsonl`` / ``.json``
            (JSON Lines) or ``.parquet`` file, with ``input``, ``output`` and
            ``split`` columns.
        tokenizer: Callable tokenizer returning a mapping with ``"input_ids"``.
        max_length: Maximum number of tokens per example.

    Raises:
        ValueError: if a file path has an unsupported extension.
        TypeError: if ``data_source`` is neither a DataFrame nor a string.
    """

    def __init__(self, data_source, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        if isinstance(data_source, pd.DataFrame):
            self.data = data_source
            print(f"Caricati {len(self.data)} esempi dal DataFrame")
        elif isinstance(data_source, str):
            # Support for .jsonl and .parquet files
            ext = os.path.splitext(data_source)[1].lower()
            if ext in [".jsonl", ".json"]:
                with open(data_source, 'r', encoding='utf-8') as f:
                    # Blank lines are skipped instead of crashing json.loads
                    data_list = [json.loads(line.strip()) for line in f if line.strip()]
                self.data = pd.DataFrame(data_list)
            elif ext in [".parquet"]:
                self.data = pd.read_parquet(data_source)
            else:
                raise ValueError(f"Formato file non supportato: {ext}")
            # Make sure the required columns exist
            for col in ["input", "output", "split"]:
                if col not in self.data.columns:
                    self.data[col] = np.nan
            print(f"Caricati {len(self.data)} esempi da {data_source}")
        else:
            # Fail here rather than leaving `self.data` undefined
            raise TypeError(
                f"data_source must be a pandas DataFrame or a file path, got {type(data_source).__name__}"
            )

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize example ``idx``.

        Returns a dict with ``input_ids``, ``attention_mask`` and ``labels``
        (Python lists of equal length), ``start_locs`` (index of the first
        answer token, clamped to the last position), ``answer_len_kept``
        (answer tokens that survived truncation) and ``split`` (1 for
        "forget", 0 otherwise).
        """
        row = self.data.iloc[idx]  # fetch the row once
        prompt_text = str(row["input"])  # coerce to string
        answer_text = str(row["output"])

        # Tokenize prompt and answer separately to get a robust boundary
        prompt_tok = self.tokenizer(
            prompt_text,
            add_special_tokens=False,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )
        # Ensure a leading space for the answer for stable tokenization after the prompt
        answer_text_sp = answer_text if answer_text.startswith(" ") else (" " + answer_text)
        answer_tok = self.tokenizer(
            answer_text_sp,
            add_special_tokens=False,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )

        prompt_ids = prompt_tok["input_ids"]
        answer_ids = answer_tok["input_ids"]

        # Compute how many answer tokens survive after truncation
        prompt_len = len(prompt_ids)
        ans_len = len(answer_ids)
        if prompt_len >= self.max_length:
            # No room for answer tokens
            input_ids = prompt_ids[: self.max_length]
            answer_len_kept = 0
        else:
            available = self.max_length - prompt_len
            answer_len_kept = min(ans_len, max(0, available))
            input_ids = prompt_ids + answer_ids[:answer_len_kept]

        attention_mask = [1] * len(input_ids)
        labels = list(input_ids)

        # Start index of the first answer token (before shift)
        ans_start = prompt_len
        start_locs = min(ans_start, len(input_ids) - 1)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "start_locs": start_locs,
            "answer_len_kept": int(answer_len_kept),
            "labels": labels,
            "split": 1 if row["split"] == "forget" else 0,
        }

In [None]:
# Create the training dataset
# Reduce max_length slightly to lower memory footprint
batch_size = 2

# Concatenate train retain + forget (Parquet), then drop rows that lack
# input/output and renumber the index
train_data = (
    pd.concat([retain_train_df, forget_train_df], ignore_index=True)
    .dropna(subset=["input", "output"])
    .reset_index(drop=True)
)

dataset = UnlearningDataset(train_data, tokenizer, max_length=384)

# Dynamic padding collate_fn
def sku_collate_fn(batch, pad_id, max_length=384):
    """Pad a list of dataset items to the longest (clipped) sequence in the batch.

    Args:
        batch: Items with ``input_ids``, ``attention_mask``, ``labels`` (lists)
            and ``start_locs``, ``answer_len_kept``, ``split`` (ints).
        pad_id: Token id used to pad ``input_ids``.
        max_length: Sequences are clipped to this many tokens.

    Returns:
        Dict of long tensors: ``input_ids`` (padded with ``pad_id``),
        ``attention_mask`` (padded with 0), ``labels`` (padded with -100),
        plus per-example ``start_locs``, ``answer_len_kept`` and ``split``.
    """
    n_rows = len(batch)
    # Width of the padded batch; 1 when the batch itself is empty
    width = max((min(len(ex['input_ids']), max_length) for ex in batch), default=1)

    padded_ids = torch.full((n_rows, width), pad_id, dtype=torch.long)
    padded_mask = torch.zeros((n_rows, width), dtype=torch.long)
    padded_labels = torch.full((n_rows, width), -100, dtype=torch.long)

    for row, ex in enumerate(batch):
        ids = ex['input_ids'][:max_length]
        mask = ex['attention_mask'][:max_length]
        padded_ids[row, :len(ids)] = torch.tensor(ids)
        padded_mask[row, :len(mask)] = torch.tensor(mask)
        padded_labels[row, :len(ids)] = torch.tensor(ex['labels'][:max_length])

    # Per-example scalars are clamped so they stay valid after clipping
    clamped_starts = [min(ex['start_locs'], max_length - 1) for ex in batch]
    clamped_answer_lens = [min(ex['answer_len_kept'], max_length) for ex in batch]
    split_flags = [ex['split'] for ex in batch]

    return {
        'input_ids': padded_ids,
        'attention_mask': padded_mask,
        'labels': padded_labels,
        'start_locs': torch.tensor(clamped_starts, dtype=torch.long),
        'answer_len_kept': torch.tensor(clamped_answer_lens, dtype=torch.long),
        'split': torch.tensor(split_flags, dtype=torch.long),
    }

## 4. Selective Knowledge Negation Trainer

In [None]:
class SelectiveKnowledgeNegationTrainer:
    """
    Trainer implementing Selective Knowledge Negation Unlearning (SKU).

    Core idea:
    - For retain samples: optimize the standard language modeling loss (cross-entropy) to preserve knowledge.
    - For forget samples: minimize the probability of producing the forbidden answer tokens via token-level
      unlikelihood loss on the answer span while still keeping CE on the prompt context to stabilize training.

    Added enhancements:
    - L2 anchoring of trainable weights to their initial values (helps preserve general knowledge).
    - Entropy regularization on the forget answer span (makes distribution flat to reduce memorization).
    - Optional refusal-target CE on forget answer span to steer toward a safe response template.
    - Cosine LR scheduler with warmup and unlikelihood ramp-up for stable training.
    """

    def __init__(self, model_path, tokenizer, lora_config, device="cuda:0", refusal_text: str = " I cannot comply with that request."):
        """Store the configuration; the model itself is built by ``setup_model()``.

        Args:
            model_path: Local directory of the base causal LM checkpoint.
            tokenizer: Tokenizer used for pad/eos ids and the refusal template.
            lora_config: PEFT ``LoraConfig`` applied to the base model.
            device: Device string the model and batches are moved to.
            refusal_text: Template whose tokens are used as refusal targets.
        """
        self.model_path = model_path
        self.tokenizer = tokenizer
        self.lora_config = lora_config
        self.device = device
        self.refusal_text = refusal_text

        self.model = None
        # name -> clone of each trainable parameter taken in setup_model()
        self.initial_state_dict = {}

    def _count_trainable(self, model):
        """Return ``(trainable, total)`` parameter element counts of ``model``."""
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return trainable, total

    def setup_model(self):
        """Load the base model from ``model_path``, wrap it with LoRA and snapshot trainable weights."""
        print("🔧 Setting up model (LoRA)...")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            local_files_only=True
        )

        model = get_peft_model(base_model, self.lora_config).to(self.device)

        # Disable KV cache during training to save memory (best effort)
        try:
            model.config.use_cache = False
        except Exception:
            pass

        self.model = model

        # Report trainables and snapshot initial trainable params
        try:
            self.model.print_trainable_parameters()
        except Exception:
            trainable, total = self._count_trainable(self.model)
            print(f"Trainable params: {trainable} / {total}")
        # Warn if still zero
        tcount, _ = self._count_trainable(self.model)
        if tcount == 0:
            print("❗No trainable parameters detected. Training will be a no-op and backward will be skipped. Check target_modules.")
        self.initial_state_dict.clear()
        for name, p in self.model.named_parameters():
            if p.requires_grad:
                self.initial_state_dict[name] = p.data.clone()
        print("✅ Model setup completed")

    def _compute_span_masks_targets_precise(self, attention_mask, start_locs, answer_len_kept):
        """Build prompt/answer boolean masks in shifted-target coordinates.

        Target index ``t`` corresponds to the token at position ``t + 1`` of the
        input, so an answer starting at position ``s`` begins at target index
        ``s - 1``.

        Args:
            attention_mask: ``[B, T]`` tensor; only its shape and device are used.
            start_locs: Per-example index of the first answer token.
            answer_len_kept: Per-example number of answer tokens present.

        Returns:
            ``(prompt_mask_tgt, answer_mask_tgt)``, both ``[B, T]`` bool tensors.
        """
        B, T = attention_mask.shape
        device = attention_mask.device
        prompt_mask_tgt = torch.zeros((B, T), dtype=torch.bool, device=device)
        answer_mask_tgt = torch.zeros((B, T), dtype=torch.bool, device=device)
        for i in range(B):
            s = int(start_locs[i].item()) if torch.is_tensor(start_locs[i]) else int(start_locs[i])
            s = max(0, min(s, T - 1))
            L = int(answer_len_kept[i].item()) if torch.is_tensor(answer_len_kept[i]) else int(answer_len_kept[i])
            # shift-to-target alignment: target indices correspond to positions 1..T-1
            split_t = s - 1
            if split_t >= 0:
                # Prompt is [0 .. split_t-1]
                if split_t > 0:
                    prompt_mask_tgt[i, :split_t] = True
                # Answer is [split_t .. split_t+L-1], clipped to T
                if L > 0:
                    end_pos = min(T - 1, split_t + L - 1)
                    answer_mask_tgt[i, split_t : end_pos + 1] = True
            else:
                # No prompt tokens; all start as answer, but limit to L
                if L > 0:
                    end_pos = min(T - 1, L - 1)
                    answer_mask_tgt[i, : end_pos + 1] = True
        return prompt_mask_tgt, answer_mask_tgt

    def _cross_entropy_loss(self, logits, labels, loss_mask):
        """Mean cross-entropy over the positions selected by ``loss_mask``.

        Returns a zero scalar (attached to the graph when ``logits`` requires
        grad) if the mask selects nothing.
        """
        vocab = logits.size(-1)
        # Use reshape instead of view to safely handle non-contiguous tensors
        logits_flat = logits.reshape(-1, vocab)
        labels_flat = labels.reshape(-1)
        mask_flat = loss_mask.reshape(-1)
        if mask_flat.sum() == 0:
            # Return a zero that's attached to the current graph if logits require grad; else plain zero
            return (logits.sum() * 0.0) if logits.requires_grad else torch.tensor(0.0, device=logits.device)
        return F.cross_entropy(
            logits_flat[mask_flat],
            labels_flat[mask_flat],
            reduction="mean",
        )

    def _unlikelihood_loss(self, logits, labels, loss_mask):
        """Mean token-level unlikelihood ``-log(1 - p(label))`` over masked positions.

        Positions whose label is -100, the pad id or the eos id are excluded.
        Masking happens before the softmax/gather so that ignore labels (-100)
        are never used as gather indices, which would be out of bounds.

        Returns a zero scalar (attached to the graph when ``logits`` requires
        grad) if no position remains.
        """
        # Exclude pad/eos targets from UL (noisy gradients)
        pad_id = self.tokenizer.pad_token_id
        eos_id = self.tokenizer.eos_token_id
        if pad_id is None:
            # Explicit None check: a legitimate pad id of 0 must not fall through
            pad_id = eos_id
        valid_tokens = labels != -100
        if pad_id is not None:
            valid_tokens = valid_tokens & (labels != pad_id)
        if eos_id is not None:
            valid_tokens = valid_tokens & (labels != eos_id)
        keep = loss_mask & valid_tokens
        if not bool(keep.any()):
            # Return a zero that's attached to the current graph if logits require grad; else plain zero
            return (logits.sum() * 0.0) if logits.requires_grad else torch.tensor(0.0, device=logits.device)
        # Softmax only over the selected rows, in float32 so the eps clamp is meaningful
        probs = F.softmax(logits[keep].float(), dim=-1)
        p_y = probs.gather(-1, labels[keep].unsqueeze(-1)).squeeze(-1)
        eps = 1e-6
        p_y = p_y.clamp_min(eps).clamp_max(1 - eps)
        return (-torch.log(1.0 - p_y)).mean()

    def _entropy_on_mask(self, logits, loss_mask):
        """Mean entropy H(p) of the predicted distribution over masked positions.

        Returns a zero scalar (attached to the graph when ``logits`` requires
        grad) if the mask selects nothing.
        """
        if not bool(loss_mask.any()):
            return (logits.sum() * 0.0) if logits.requires_grad else torch.tensor(0.0, device=logits.device)
        # Only the selected rows go through the full-vocabulary softmax
        probs = F.softmax(logits[loss_mask].float(), dim=-1)
        eps = 1e-8
        ent = -(probs * (probs.clamp_min(eps).log())).sum(dim=-1)
        return ent.mean()

    def _l2_anchor(self):
        """Sum of squared deviations of trainable params from their snapshot, divided by the snapshot count."""
        if not self.initial_state_dict:
            return torch.tensor(0.0, device=self.device)
        total = None
        for name, p in self.model.named_parameters():
            if p.requires_grad and name in self.initial_state_dict:
                # The snapshot may differ in device/dtype from the live parameter
                diff = p - self.initial_state_dict[name].to(p.device, dtype=p.dtype)
                term = (diff * diff).sum()
                total = term if total is None else total + term
        if total is None:
            total = torch.tensor(0.0, device=self.device)
        return total / max(1, len(self.initial_state_dict))

    def _build_refusal_targets(self, answer_mask_tgt, forget_mask, Tm1):
        """Build ``[B, Tm1]`` targets holding the refusal token ids on forget answer spans.

        Positions outside ``answer_mask_tgt & forget_mask`` are -100. The
        refusal template is tokenized on every call and repeated or trimmed to
        the length of each span.
        """
        # answer_mask_tgt, forget_mask: [B, T-1] bool
        B = answer_mask_tgt.size(0)
        device = answer_mask_tgt.device
        refusal_ids = self.tokenizer(
            self.refusal_text,
            add_special_tokens=False,
            padding=False,
            truncation=True,
            max_length=Tm1,
            return_tensors=None,
        )["input_ids"]
        if len(refusal_ids) == 0:
            # Fallback to EOS
            rid = self.tokenizer.eos_token_id
            refusal_ids = [rid if rid is not None else 0]
        R = len(refusal_ids)
        # Prepare targets of shape [B, T-1] filled with -100
        target_ref = torch.full((B, Tm1), -100, dtype=torch.long, device=device)
        # For each sample, fill the answer span positions with repeated refusal ids
        idxs = torch.arange(Tm1, device=device)
        for i in range(B):
            mask = (answer_mask_tgt[i] & forget_mask[i])
            L = int(mask.sum().item())
            if L <= 0:
                continue
            # Repeat/trim the refusal ids to L
            seq = (refusal_ids * ((L + R - 1) // R))[:L]
            target_positions = idxs[mask]
            target_ref[i, target_positions] = torch.tensor(seq, dtype=torch.long, device=device)
        return target_ref

    def train(
        self,
        dataloader,
        num_epochs=4,
        lr=1e-4,
        ce_weight_prompt=1.0,
        ul_weight_answer=1.0,
        ce_weight_retain=1.0,
        l2_anchor_weight=1e-5,
        entropy_weight_answer=0.0,
        refusal_weight=0.0,
        grad_clip=1.0,
        grad_accum_steps=1,
        use_mixed_precision=True,
        warmup_ratio=0.1,
        ul_ramp_ratio=0.2,
    ):
        """Run the SKU training loop on ``self.model``.

        Args:
            dataloader: Yields dict batches with ``input_ids``, ``attention_mask``,
                ``labels``, ``start_locs``, ``answer_len_kept`` and ``split``
                tensors (``split`` is 1 for forget, 0 for retain).
            num_epochs: Number of passes over ``dataloader``.
            lr: Peak learning rate.
            ce_weight_prompt: Weight of CE on forget-sample prompt tokens.
            ul_weight_answer: Weight of unlikelihood on forget answer tokens.
            ce_weight_retain: Weight of CE on retain samples.
            l2_anchor_weight: Weight of the L2 anchor term.
            entropy_weight_answer: Weight of the (maximised) entropy term.
            refusal_weight: Weight of refusal-target CE; 0 skips that term.
            grad_clip: Max gradient norm, or ``None`` to disable clipping.
            grad_accum_steps: Batches accumulated per optimizer step.
            use_mixed_precision: Use autocast when the device is cuda or mps.
            warmup_ratio: Fraction of optimizer steps used for LR warmup.
            ul_ramp_ratio: Fraction of optimizer steps over which the
                unlikelihood weight ramps linearly to its full value.
        """
        assert self.model is not None, "Call setup_model() first"
        self.model.train()
        # Optimizer: prefer 8-bit if available
        optimizer = None
        try:
            import bitsandbytes as bnb
            optimizer = bnb.optim.PagedAdamW8bit(self.model.parameters(), lr=lr, weight_decay=0.01)
            print("🧮 Using 8-bit PagedAdamW optimizer (bitsandbytes)")
        except Exception:
            try:
                optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=0.01, fused=True)
                print("🧮 Using fused AdamW optimizer")
            except Exception:
                optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr, weight_decay=0.01)
                print("🧮 Using standard AdamW optimizer")

        # Scheduler with warmup and cosine decay
        steps_per_epoch = max(1, (len(dataloader) + grad_accum_steps - 1) // grad_accum_steps)
        total_steps = num_epochs * steps_per_epoch
        warmup_steps = max(1, int(warmup_ratio * total_steps))
        ul_ramp_steps = max(1, int(ul_ramp_ratio * total_steps))
        try:
            from transformers import get_cosine_schedule_with_warmup
            scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
            print(f"📈 Using cosine scheduler with warmup ({warmup_steps}/{total_steps})")
        except Exception:
            def lr_lambda(step):
                if step < warmup_steps:
                    return float(step) / float(max(1, warmup_steps))
                progress = (step - warmup_steps) / float(max(1, total_steps - warmup_steps))
                # Cosine from 1 to 0
                return 0.5 * (1.0 + math.cos(math.pi * progress))
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
            print(f"📈 Using LambdaLR scheduler with warmup ({warmup_steps}/{total_steps})")

        # Mixed precision setup
        from contextlib import nullcontext
        device_type = 'cuda' if ('cuda' in str(self.device) and torch.cuda.is_available()) else ('mps' if ('mps' in str(self.device) and torch.backends.mps.is_available()) else None)
        param_dtype = next(self.model.parameters()).dtype if any(p.requires_grad for p in self.model.parameters()) else torch.float32
        use_amp = use_mixed_precision and device_type is not None
        amp_dtype = torch.bfloat16 if (device_type == 'cuda' and param_dtype == torch.bfloat16) else torch.float16
        # Loss scaling is only enabled for float16 autocast on CUDA
        scaler = torch.cuda.amp.GradScaler(enabled=(device_type == 'cuda' and amp_dtype == torch.float16 and use_amp))

        global_step = 0
        no_grad_batches = 0
        for epoch in range(num_epochs):
            epoch_losses = []
            with tqdm(total=len(dataloader), desc=f"SKU Epoch {epoch+1}") as pbar:
                optimizer.zero_grad(set_to_none=True)
                for step, batch in enumerate(dataloader):
                    input_ids = batch["input_ids"].to(self.device, non_blocking=True)
                    attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
                    labels_full = batch["labels"].to(self.device, non_blocking=True)
                    start_locs = batch["start_locs"].to(self.device, non_blocking=True)
                    answer_len_kept = batch["answer_len_kept"].to(self.device, non_blocking=True)
                    split = batch["split"].to(self.device, non_blocking=True)

                    with (torch.autocast(device_type=device_type, dtype=amp_dtype) if use_amp else nullcontext()):
                        outputs = self.model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            use_cache=False,
                            return_dict=True,
                        )
                        logits_full = outputs.logits  # [B, T, V]
                        # Next-token shift: logits at position t predict the token at t + 1
                        logits = logits_full[:, :-1, :]
                        target = labels_full[:, 1:]
                        attn_tgt = attention_mask[:, 1:].bool()
                        # Precise masks using kept answer length
                        prompt_mask_tgt, answer_mask_tgt = self._compute_span_masks_targets_precise(attention_mask, start_locs, answer_len_kept)
                        prompt_mask_tgt = prompt_mask_tgt[:, :-1]
                        answer_mask_tgt = answer_mask_tgt[:, :-1]
                        retain_mask = (split == 0).unsqueeze(-1).expand_as(prompt_mask_tgt)
                        forget_mask = (split == 1).unsqueeze(-1).expand_as(prompt_mask_tgt)
                        valid_tgt = attn_tgt

                        # Core losses
                        retain_loss = self._cross_entropy_loss(
                            logits, target, loss_mask=(valid_tgt & retain_mask)
                        ) * ce_weight_retain
                        forget_prompt_loss = self._cross_entropy_loss(
                            logits, target, loss_mask=(valid_tgt & forget_mask & prompt_mask_tgt)
                        ) * ce_weight_prompt

                        # Ramp-up UL weight
                        ul_scale = min(1.0, float(global_step + 1) / float(max(1, ul_ramp_steps)))
                        forget_ul_loss = self._unlikelihood_loss(
                            logits, target, loss_mask=(valid_tgt & forget_mask & answer_mask_tgt)
                        ) * (ul_weight_answer * ul_scale)

                        # Entropy regularization on forget answer span (maximize entropy)
                        ent_loss = self._entropy_on_mask(
                            logits, loss_mask=(valid_tgt & forget_mask & answer_mask_tgt)
                        )
                        entropy_term = -entropy_weight_answer * ent_loss

                        # Optional refusal-target CE on forget answer span
                        refusal_term = torch.tensor(0.0, device=logits.device)
                        if refusal_weight > 0.0:
                            Tm1 = logits.size(1)
                            target_ref = self._build_refusal_targets(answer_mask_tgt, forget_mask, Tm1)
                            refusal_term = self._cross_entropy_loss(
                                logits, target_ref, loss_mask=(valid_tgt & forget_mask & answer_mask_tgt)
                            ) * refusal_weight

                        # L2 anchor on weights
                        l2_term = self._l2_anchor() * l2_anchor_weight

                        loss = retain_loss + forget_prompt_loss + forget_ul_loss + entropy_term + refusal_term + l2_term
                        loss_for_backward = loss / max(1, int(grad_accum_steps))

                    # A loss with no grad means nothing trainable contributed; skip backward
                    if not loss_for_backward.requires_grad:
                        no_grad_batches += 1
                        pbar.set_postfix({"Loss": f"{float(loss.detach().cpu()):.4f}", "note": "no-grad-batch"})
                        pbar.update(1)
                        continue

                    if scaler.is_enabled():
                        scaler.scale(loss_for_backward).backward()
                    else:
                        loss_for_backward.backward()

                    # Step on every grad_accum_steps-th batch and on the last batch of the epoch
                    do_step = ((step + 1) % grad_accum_steps == 0) or (step + 1 == len(dataloader))
                    if do_step:
                        if scaler.is_enabled():
                            scaler.unscale_(optimizer)
                        if grad_clip is not None:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)
                        if scaler.is_enabled():
                            scaler.step(optimizer)
                            scaler.update()
                        else:
                            optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad(set_to_none=True)
                        global_step += 1

                    epoch_losses.append(float(loss.detach().cpu()))
                    pbar.set_postfix({
                        "Loss": f"{float(loss.detach().cpu()):.4f}",
                        "RetCE": f"{float(retain_loss.detach().cpu()):.3f}",
                        "FgtCE": f"{float(forget_prompt_loss.detach().cpu()):.3f}",
                        "FgtUL": f"{float(forget_ul_loss.detach().cpu()):.3f}",
                        "Ent": f"{float(entropy_term.detach().cpu()):.3f}",
                        "Ref": f"{float(refusal_term.detach().cpu()):.3f}",
                        "L2": f"{float(l2_term.detach().cpu()):.3f}",
                    })
                    pbar.update(1)

                    # Proactive cleanup to prevent fragmentation
                    del outputs, logits_full, logits, target, attn_tgt, prompt_mask_tgt, answer_mask_tgt, retain_mask, forget_mask, valid_tgt
                    if device_type == 'cuda' and ((step + 1) % 50 == 0):
                        torch.cuda.empty_cache()
            if device_type == 'cuda':
                torch.cuda.empty_cache()
            avg_loss = float(np.mean(epoch_losses)) if epoch_losses else 0.0
            if no_grad_batches:
                print(f"ℹ️ Epoch {epoch+1}: skipped {no_grad_batches} batches with no grad signal (check trainable params)")
                no_grad_batches = 0
            print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")

    def save_model(self, output_dir: str):
        """Save LoRA adapters (preferred) or fallback to saving base model weights."""
        os.makedirs(output_dir, exist_ok=True)
        try:
            # If using PEFT, this saves the adapter weights
            self.model.save_pretrained(output_dir)
            print(f"💾 Saved PEFT adapters to {output_dir}")
        except Exception as e:
            print(f"⚠️ Could not save PEFT adapters directly: {e}")
            try:
                # Fallback: try saving base model
                if hasattr(self.model, "base_model"):
                    self.model.base_model.save_pretrained(output_dir)
                    print(f"💾 Saved base model to {output_dir}")
            except Exception as e2:
                print(f"❌ Failed to save model: {e2}")

    def calculate_task_vector(self):
        """Compute delta between current trainable params and their initial snapshot.

        Returns:
            Dict mapping parameter name to a detached CPU tensor ``current - initial``.
        """
        delta = {}
        for name, p in self.model.named_parameters():
            if p.requires_grad and name in self.initial_state_dict:
                # Align devices the same way _l2_anchor does before subtracting
                initial = self.initial_state_dict[name].to(p.device)
                delta[name] = (p.data - initial).detach().cpu()
        return delta

## 5. Setup Trainer and Training

In [None]:
# Configure LoRA (single model)
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    # Use a safe default to attach adapters to all Linear layers
    target_modules="all-linear",
)

# Initialize SKU trainer
device_sel = "cuda:0" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
sku_trainer = SelectiveKnowledgeNegationTrainer(
    model_path=MODEL_PATH,
    tokenizer=tokenizer,
    lora_config=lora_config,
    device=device_sel,
)

# Setup model
sku_trainer.setup_model()

# Guard: ensure there are trainable parameters
trainable_params = sum(p.numel() for p in sku_trainer.model.parameters() if p.requires_grad)
print(f"Trainable parameter count: {trainable_params}")
assert trainable_params > 0, "No trainable parameters found; LoRA did not attach."

# Optional: improve stability/perf
if hasattr(sku_trainer.model, "gradient_checkpointing_enable"):
    try:
        sku_trainer.model.gradient_checkpointing_enable()
        # With a frozen base model the embedding output does not require grad,
        # so checkpointed blocks may return tensors without grad and the loss
        # ends up with nothing to backpropagate. Making the inputs require
        # grad keeps the adapter weights reachable.
        if hasattr(sku_trainer.model, "enable_input_require_grads"):
            sku_trainer.model.enable_input_require_grads()
    except Exception as e:
        # Best effort: training still works without checkpointing, but say so
        print(f"⚠️ Gradient checkpointing setup incomplete: {e}")
try:
    # bfloat16 on CUDA, float32 elsewhere
    sku_trainer.model = sku_trainer.model.to(dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32)
except Exception as e:
    print(f"⚠️ Could not cast model dtype: {e}")

In [None]:
# Build the training DataLoader. No earlier cell defines `dataloader`, so it is
# created here from the dataset, batch size and collate function set up above.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda items: sku_collate_fn(
        items, pad_id=tokenizer.pad_token_id, max_length=dataset.max_length
    ),
)

# Train SKU model
sku_trainer.train(
    dataloader=dataloader,
    num_epochs=6,
    lr=1e-4,
    ce_weight_prompt=0.5,        # reduce to avoid reinforcing harmful answers
    ul_weight_answer=7.0,        # stronger UL
    ce_weight_retain=1.0,        # preserve retain knowledge
    l2_anchor_weight=0,          # 0 disables the L2 anchor term
    entropy_weight_answer=0.05,  # push to higher entropy on forbidden span
    refusal_weight=0.3,          # steer toward refusal template
    grad_clip=1.0,
    grad_accum_steps=4,          # keep effective batch size
    use_mixed_precision=True,    # enable autocast
    warmup_ratio=0.08,
    ul_ramp_ratio=0.15,
)

# Quick delta check: task-vector L2 norm
with torch.no_grad():
    delta = sku_trainer.calculate_task_vector()
    # L2 norm of the whole task vector: sqrt of the summed squares over all
    # tensors (adding up per-tensor norms would overstate it).
    total_sq = sum(float(t.float().pow(2).sum()) for t in delta.values())
    total_norm = math.sqrt(total_sq)
    print(f"Δ (task-vector) total L2 norm: {total_norm:.4f}")

# Save an intermediate checkpoint
sku_trainer.save_model("sku_model_epoch_last")

# Optional quick A/B: generate one forget sample before vs after adapters to verify effect
# Any failure (e.g. an empty validation frame or a missing checkpoint) is
# reported and skipped rather than aborting the notebook.
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import PeftModel
    # First prompt of the forget validation split
    prompt = forget_validation_df.iloc[0]['input']
    base_tok = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
    if base_tok.pad_token is None:
        base_tok.pad_token = base_tok.eos_token
    # A fresh copy of the base checkpoint, separate from the trainer's model
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, local_files_only=True)
    inputs = base_tok(prompt, return_tensors='pt').to(base_model.device)
    # Baseline generation is done before the saved adapters are loaded below
    with torch.no_grad():
        out_base = base_model.generate(**inputs, max_new_tokens=64)
    txt_base = base_tok.decode(out_base[0], skip_special_tokens=True)

    # Load the adapters saved to "sku_model_epoch_last" on top of the base model
    adapted = PeftModel.from_pretrained(base_model, "sku_model_epoch_last")
    with torch.no_grad():
        out_adapt = adapted.generate(**inputs, max_new_tokens=64)
    txt_adapt = base_tok.decode(out_adapt[0], skip_special_tokens=True)

    print("--- A/B Quick Check ---")
    print("Prompt:", prompt)
    print("Base  :", txt_base)
    print("Adapt :", txt_adapt)
except Exception as e:
    print(f"A/B quick check skipped: {e}")

## 6. Save Results and Task Vector

In [None]:
# Create results directory
results_dir = 'balanced_results'
os.makedirs(results_dir, exist_ok=True)

# Save SKU model
sku_trainer.save_model(f'{results_dir}/balanced_model')

# Calculate and save task vector
task_vector = sku_trainer.calculate_task_vector()
torch.save(task_vector, f'{results_dir}/task_vector.pt')

# Summary of what was written
for summary_line in (
    "✅ Results saved in balanced_results/",
    "- balanced_model/: SKU-trained model",
    "- task_vector.pt: Task vector for future applications",
):
    print(summary_line)

## 7. Evaluation



In [None]:
import types

# Import the project-local `evaluation` module and reload it so that edits made
# to it on disk are picked up when this cell is re-run.
try:
    import evaluation
    import importlib
    importlib.reload(evaluation)
except ImportError as import_err:
    # Stay non-fatal (the rest of the notebook may still be useful), but say so:
    # run_evaluation() below needs this module and would otherwise fail later
    # with an unrelated-looking NameError.
    print(f"⚠️ Could not import 'evaluation' module ({import_err}); run_evaluation() will not work until it is importable.")

def run_evaluation(
    data_path,
    checkpoint_path,
    output_dir="eval_results",
    mia_data_path=None,
    mmlu_metrics_file_path=None,
    max_new_tokens=256,
    batch_size=25,
    debug=False,
    compute_metrics_only=False,
    seed=42,
    keep_files=False,
):
    """Run the `evaluation` module's pipeline from inside the notebook.

    The keyword arguments are packed into an argparse-like namespace and handed
    to `evaluation.inference`, optionally `evaluation.mia_attacks` (when
    `mia_data_path` is given) and `evaluation.compute_metrics`.

    `data_path` must be a directory containing `forget.jsonl` and
    `retain.jsonl`; `checkpoint_path` must exist. Unless `compute_metrics_only`
    is set, the base model at MODEL_PATH is loaded and the checkpoint is applied
    as PEFT adapters, falling back to loading the checkpoint as a regular model.

    Any exception is caught, reported with a traceback, and swallowed; the
    function always returns None.
    """
    try:
        # Namespace mimicking what argparse would produce for the original script
        args = types.SimpleNamespace(
            data_path=data_path,
            checkpoint_path=checkpoint_path,
            output_dir=output_dir,
            mia_data_path=mia_data_path,
            mmlu_metrics_file_path=mmlu_metrics_file_path,
            max_new_tokens=max_new_tokens,
            batch_size=batch_size,
            debug=debug,
            compute_metrics_only=compute_metrics_only,
            seed=seed,
            keep_files=keep_files,
        )

        # Report, then verify, every path the run depends on
        print("🔍 Verificando paths...")
        print(f"  Data path: {data_path}")
        print(f"  Checkpoint path: {checkpoint_path}")
        print(f"  Output dir: {output_dir}")

        for label, required_path in (("Data path", data_path), ("Checkpoint path", checkpoint_path)):
            if not os.path.exists(required_path):
                raise FileNotFoundError(f"{label} not found: {required_path}")
        for required_file in ('forget.jsonl', 'retain.jsonl'):
            if not os.path.exists(os.path.join(data_path, required_file)):
                raise FileNotFoundError(f"{required_file} not found in {data_path}")

        # Normalize the output directory (same as the original script)
        from pathlib import Path
        if args.output_dir is not None:
            args.output_dir = args.output_dir.rstrip('/')
            Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        else:
            args.output_dir = os.getcwd()

        # Seed every RNG in use
        import random
        import torch
        import numpy as np
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        from accelerate import Accelerator
        accelerator = Accelerator()

        if not args.compute_metrics_only:
            from transformers import AutoModelForCausalLM, AutoTokenizer
            from peft import PeftModel

            print(f"📥 Loading base model from {MODEL_PATH} and adapters from {args.checkpoint_path}...")

            # Always build tokenizer from the base model to keep special tokens consistent
            base_tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
            if base_tokenizer.pad_token is None:
                base_tokenizer.pad_token = base_tokenizer.eos_token

            # bf16 when CUDA is available, fp32 otherwise (used for both load paths)
            load_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

            # Load base model, then try to attach the PEFT adapters
            base_model = AutoModelForCausalLM.from_pretrained(
                MODEL_PATH,
                local_files_only=True,
                torch_dtype=load_dtype,
            )
            try:
                model = PeftModel.from_pretrained(base_model, args.checkpoint_path)
                print("✅ Loaded base + PEFT adapters")
            except Exception as e:
                # Fallback: treat the checkpoint as a full model
                print(f"⚠️ PEFT load failed ({e}); trying to load as a regular model")
                model = AutoModelForCausalLM.from_pretrained(
                    args.checkpoint_path,
                    torch_dtype=load_dtype,
                    trust_remote_code=True,
                )

            # Sanity: check adapters are present and active
            has_peft = hasattr(model, 'peft_config') and len(model.peft_config) > 0
            print(f"PEFT active: {has_peft}")
            if has_peft:
                print(f"Adapters: {list(model.peft_config.keys())}")
            model.eval()

            print("🚀 Starting inference...")
            evaluation.inference(args, model, base_tokenizer)

            if args.mia_data_path is not None:
                print("🔍 Starting MIA attacks...")
                evaluation.mia_attacks(args, model, base_tokenizer)

        # Metrics are computed on the main process only
        if accelerator.is_main_process:
            print("📊 Computing metrics...")
            evaluation.compute_metrics(args)
            print("✅ Evaluation completed!")

    except Exception as e:
        print(f"❌ Error during evaluation: {e}")
        import traceback
        traceback.print_exc()

# === Step 4: Run evaluation ===
print("🎯 Starting evaluation process...")

# Check the inputs up front so that a missing file produces a clear hint
validation_ready = all(
    os.path.exists(required)
    for required in ("validation/forget.jsonl", "validation/retain.jsonl")
)
if not validation_ready:
    print("❌ Validation files not found")
    print("   Expected: validation/forget.jsonl and validation/retain.jsonl")
    print("   Make sure the data processing completed successfully")
elif not os.path.exists("balanced_results/balanced_model/"):
    print("❌ Model checkpoint not found at balanced_results/balanced_model/")
    print("   Make sure the training completed successfully")
else:
    run_evaluation(
        data_path="validation/",  # relative folder holding forget.jsonl and retain.jsonl
        checkpoint_path="balanced_results/balanced_model/",  # relative folder holding the model weights
        output_dir="eval_results",
        debug=True,  # enable debug output to see what happens
    )

In [None]:
# Patch: safe unlikelihood loss to avoid CUDA index OOB in gather
import torch
import torch.nn.functional as F

# This assumes the class SelectiveKnowledgeNegationTrainer is already defined above

def _ul_safe(self, logits, labels, loss_mask):
    """Unlikelihood loss with safe indexing: ignore invalid targets and prevent OOB gather."""
    V = logits.size(-1)
    device = logits.device
    # Build validity mask for target indices
    pad_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
    eos_id = self.tokenizer.eos_token_id

    labels_long = labels.long()
    valid_tokens = (labels_long >= 0) & (labels_long < V) & (labels_long != -100)
    if pad_id is not None:
        valid_tokens = valid_tokens & (labels_long != pad_id)
    if eos_id is not None:
        valid_tokens = valid_tokens & (labels_long != eos_id)

    # Replace invalid indices with 0 to keep gather in-bounds; they will be masked out later
    safe_idx = labels_long.clamp(min=0, max=V - 1)

    probs = F.softmax(logits, dim=-1)
    p_y_all = torch.gather(probs, dim=-1, index=safe_idx.unsqueeze(-1)).squeeze(-1)

    mask = (loss_mask & valid_tokens)
    if mask.sum().item() == 0:
        return (logits.sum() * 0.0) if logits.requires_grad else torch.tensor(0.0, device=device)

    p_y = p_y_all[mask].clamp(1e-6, 1 - 1e-6)
    ul = -torch.log(1.0 - p_y)
    return ul.mean()

# Install the safe implementation on the trainer class, if that class exists
try:
    _trainer_cls = SelectiveKnowledgeNegationTrainer
except NameError:
    print("⚠️ Could not find SelectiveKnowledgeNegationTrainer to patch. Define the class first and re-run this cell.")
else:
    setattr(_trainer_cls, "_unlikelihood_loss", _ul_safe)
    print("✅ Patched SelectiveKnowledgeNegationTrainer._unlikelihood_loss with safe implementation")

In [None]:
# Safe LoRA re-init to ensure trainable adapters attach
from peft import LoraConfig, TaskType

# LoRA configuration that targets every linear layer ("all-linear")
lora_config_safe = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules="all-linear",
)

# Recreate trainer to use the safe config (CUDA first, then MPS, then CPU)
sku_trainer = SelectiveKnowledgeNegationTrainer(
    model_path=MODEL_PATH,
    tokenizer=tokenizer,
    lora_config=lora_config_safe,
    device=("cuda:0" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
)

sku_trainer.setup_model()

# Optional training stability tweaks. Both are best-effort, but a failure is
# reported instead of being silently ignored, so the user knows which
# configuration is actually in effect.
if hasattr(sku_trainer.model, "gradient_checkpointing_enable"):
    try:
        sku_trainer.model.gradient_checkpointing_enable()
    except Exception as gc_err:
        print(f"⚠️ Could not enable gradient checkpointing: {gc_err}")
try:
    # bf16 when CUDA is available, fp32 otherwise
    sku_trainer.model = sku_trainer.model.to(dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32)
except Exception as cast_err:
    print(f"⚠️ Could not cast model dtype: {cast_err}")

print("✅ Trainer re-initialized with 'all-linear' LoRA target. Proceed to run the training cell again.")

In [None]:
# === 8. Diagnostics & Debugging for SKU Effectiveness ===
import torch, math
from collections import Counter
from transformers import AutoModelForCausalLM, AutoTokenizer

print("[Diagnostics] Starting SKU debugging block...")

# 1. Dataset split distribution & basic length stats
if 'dataset' in globals():
    # Count rows per split label; empty when the dataset exposes no `.data`
    split_counts = Counter(dataset.data['split']) if hasattr(dataset, 'data') else {}
    print("Split counts:", split_counts)
    # Approx prompt / answer token stats (first 200 samples)
    prompt_lens = []
    answer_lens = []
    for i in range(min(200, len(dataset))):
        item = dataset[i]
        # 'start_locs' is reported below as the prompt token count and
        # 'answer_len_kept' as the kept answer token count
        prompt_lens.append(int(item['start_locs']))
        answer_lens.append(int(item['answer_len_kept']))
    if prompt_lens:
        print(f"Avg prompt tokens: {sum(prompt_lens)/len(prompt_lens):.1f} | Avg kept answer tokens: {sum(answer_lens)/len(answer_lens):.1f}")
        # Samples with no answer tokens kept cannot contribute to the answer loss
        zero_ans = sum(1 for x in answer_lens if x == 0)
        print(f"Samples with 0 answer tokens kept: {zero_ans}/{len(answer_lens)} ({100*zero_ans/len(answer_lens):.1f}%)")
else:
    print("Dataset not found in globals().")

# 2. Inspect one batch to see UL active positions
if 'dataloader' in globals():
    first_batch = next(iter(dataloader))
    # Pull out only the tensors needed for the mask computation (left on the
    # device the dataloader produced them on)
    input_ids = first_batch['input_ids']
    attention_mask = first_batch['attention_mask']
    start_locs = first_batch['start_locs']
    answer_len_kept = first_batch['answer_len_kept']
    split = first_batch['split']
    if 'sku_trainer' in globals():
        with torch.no_grad():
            prompt_mask_tgt, answer_mask_tgt = sku_trainer._compute_span_masks_targets_precise(attention_mask, start_locs, answer_len_kept)
            # Align with target length (T-1)
            prompt_mask_tgt = prompt_mask_tgt[:, :-1]
            answer_mask_tgt = answer_mask_tgt[:, :-1]
            attn_tgt = attention_mask[:, 1:].bool()
            # Per-sample split flags broadcast over the sequence dimension;
            # this cell treats split == 0 as retain and split == 1 as forget.
            # retain_mask is not used further in this cell.
            retain_mask = (split == 0).unsqueeze(-1).expand_as(prompt_mask_tgt)
            forget_mask = (split == 1).unsqueeze(-1).expand_as(prompt_mask_tgt)
            # Positions that are attended, belong to a forget sample and lie in the answer span
            ul_active = (attn_tgt & forget_mask & answer_mask_tgt).sum().item()
            forget_answer_tokens = (forget_mask & answer_mask_tgt).sum().item()
            print(f"UL active positions in sample batch: {ul_active}")
            print(f"Total forget answer target positions in sample batch: {forget_answer_tokens}")
            if forget_answer_tokens == 0:
                print("WARNING: No forget answer tokens available; unlikelihood loss will be zero. Consider increasing max_length or shortening prompts.")
    else:
        print("sku_trainer not defined.")
else:
    print("Dataloader not found.")

# 3. Function to compute log-prob of a sensitive answer span before & after adapters

def compute_span_logprob(model, tokenizer, prompt, span_text, device=None):
    """Return (sum of log P(span token | prefix), number of span tokens).

    `prompt + span_text` is tokenized as one sequence and scored with a single
    forward pass; only the tokens after the prompt contribute to the sum.

    Args:
        model: causal LM called as model(**encoding, use_cache=False, return_dict=True).
        tokenizer: callable tokenizer returning a mapping with 'input_ids'.
        prompt: conditioning text.
        span_text: text whose log-probability is measured.
        device: where to run; defaults to the device of the model's first
            parameter (or 'cpu' if the model has no parameters).

    Returns:
        (log_prob_sum, n_tokens); (nan, 0) when the span yields no scorable token.

    NOTE(review): the prompt length is measured without special tokens, so the
    boundary assumes the full encoding does not prepend any and that the prompt
    tokens are a prefix of the full encoding — confirm for the tokenizer in use.
    """
    if device is None:
        # Follow the model's own placement; the previous rule fell back to 'cpu'
        # for any model without trainable parameters, even one living on GPU.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else 'cpu'
    model.eval()
    with torch.no_grad():
        tok = tokenizer(prompt + span_text, return_tensors='pt')
        tok = {k: v.to(device) for k, v in tok.items()}
        outputs = model(**tok, use_cache=False, return_dict=True)
        logits = outputs.logits  # [1, T, V]
        input_ids = tok['input_ids']
        # Number of prompt tokens = input index of the first span token
        plen = len(tokenizer(prompt, add_special_tokens=False)['input_ids'])
        # Shift for causal LM: logits at position i predict input token i + 1
        target_ids = input_ids[:, 1:]
        logits_shifted = logits[:, :-1, :]
        # Span tokens sit at input indices [plen, T-1], i.e. shifted indices
        # [plen-1, T-2]. Starting at plen would skip the first span token.
        start = max(plen - 1, 0)
        end = target_ids.size(1)
        if start >= end:
            return float('nan'), 0
        log_probs = torch.log_softmax(logits_shifted[0, start:end, :].float(), dim=-1)
        tgt_tokens = target_ids[0, start:end]
        gathered = log_probs.gather(-1, tgt_tokens.unsqueeze(-1)).squeeze(-1)
        return gathered.sum().item(), end - start

# 4. Compare the log-probability of the ground-truth answer for one forget
#    sample under the base model and under the adapted model. Best-effort:
#    any failure is reported and the comparison is skipped.
try:
    sensitive_prompt = forget_validation_df.iloc[0]['input']
    # Use ground-truth output field as sensitive answer to suppress
    sensitive_answer = forget_validation_df.iloc[0]['output']
    base_tok = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
    if base_tok.pad_token is None: base_tok.pad_token = base_tok.eos_token
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, local_files_only=True, torch_dtype=torch.float32)
    # Base score is taken BEFORE any adapter is attached to base_model below
    lp_base, Lspan = compute_span_logprob(base_model, base_tok, sensitive_prompt, sensitive_answer)
    print(f"Base span log-prob (sum over {Lspan} tokens): {lp_base:.2f}")
    adapted = None
    from peft import PeftModel
    # Prefer the intermediate checkpoint; the balanced_results directory is only
    # tried when the first directory does not exist (not when loading it fails)
    if os.path.exists('sku_model_epoch_last'):
        try:
            adapted = PeftModel.from_pretrained(base_model, 'sku_model_epoch_last')
        except Exception as e:
            print("Could not load adapters from sku_model_epoch_last:", e)
    elif os.path.exists('balanced_results/balanced_model'):
        try:
            adapted = PeftModel.from_pretrained(base_model, 'balanced_results/balanced_model')
        except Exception as e:
            print("Could not load adapters from balanced_results/balanced_model:", e)
    if adapted is not None:
        lp_adapt, _ = compute_span_logprob(adapted, base_tok, sensitive_prompt, sensitive_answer)
        print(f"Adapted span log-prob: {lp_adapt:.2f}")
        if math.isfinite(lp_base) and math.isfinite(lp_adapt):
            # NOTE: rebinds the notebook-level name `delta` to a float
            delta = lp_adapt - lp_base
            print(f"Δ log-prob (adapted - base): {delta:.2f} (negative desired for forgetting)")
            # -0.5 is the heuristic threshold used here for "reduced enough"
            if delta > -0.5:
                print("Span probability not sufficiently reduced. Consider stronger ul_weight_answer, higher lr, or upsampling forget examples.")
    else:
        # Also reached when a directory existed but its adapters failed to load
        print("No adapted model directory found for probability comparison.")
except Exception as e:
    print("Span log-prob comparison skipped:", e)

# 5. Recommendations print based on quick heuristics
print("\n[Heuristic Recommendations]")
# At notebook top level locals() is the global namespace, so these checks test
# whether the earlier diagnostics steps defined the variables
if 'split_counts' in locals() and split_counts:
    total = sum(split_counts.values())
    fgt = split_counts.get('forget', 0)
    # max(1, total) guards against division by zero
    if fgt / max(1,total) < 0.2:
        print("- Forget examples <20%: upsample forget or increase ul_weight_answer/refusal_weight.")
if 'answer_lens' in locals() and answer_lens:
    # More than 30% of inspected samples kept zero answer tokens
    if sum(1 for x in answer_lens if x==0) / len(answer_lens) > 0.3:
        print("- Many samples lose answer tokens (truncation). Increase max_length or shorten prompts.")
print("- If Δ log-prob ~0, raise lr (e.g., 1e-4), set ul_weight_answer 6-8, set refusal_weight 0.3, temporarily disable l2_anchor.")
print("- Use sampling (top_p=0.9, temperature=0.8) for qualitative A/B instead of greedy only.")
print("[Diagnostics] Complete.")

In [None]:
# === 8.1 Enhanced LoRA Re-init with Auto-Detected Linear Modules ===
import torch, types
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType

print("[Enhanced LoRA] Detecting linear submodules for targeted adaptation...")
# A throwaway copy of the base model, used only to discover module names
base_tmp = AutoModelForCausalLM.from_pretrained(MODEL_PATH, local_files_only=True)
# Collect the leaf name (last dotted component) of every nn.Linear submodule
linear_names = [
    name.split('.')[-1]
    for name, module in base_tmp.named_modules()
    if isinstance(module, torch.nn.Linear)
]
unique_linear = sorted(set(linear_names))
print(f"Found {len(unique_linear)} unique linear leaf names (showing first 25): {unique_linear[:25]}")

# The probe model is no longer needed: release it before the trainer loads its
# own copy, so two full sets of weights are not kept in memory.
del base_tmp
import gc
gc.collect()

# Heuristic filter: keep typical projection/feed-forward names if they exist.
# NOTE: this is a plain substring match, so the single-letter keys ("q", "k",
# "v", "o") match any leaf name containing that letter.
preferred = [n for n in unique_linear if any(k in n for k in ["q", "k", "v", "o", "proj", "gate", "up", "down", "w1", "w2", "fc", "linear"])]
# Fallback to all unique linear names if filter becomes too small
if len(preferred) < 4:
    preferred = unique_linear
print(f"Using {len(preferred)} target module names for LoRA: {preferred}")

auto_lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32,                 # larger rank for stronger capacity
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=preferred,
)

# Rebuild trainer with new config (CUDA first, then MPS, then CPU)
sku_trainer = SelectiveKnowledgeNegationTrainer(
    model_path=MODEL_PATH,
    tokenizer=tokenizer,
    lora_config=auto_lora_config,
    device=("cuda:0" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")),
)
sku_trainer.setup_model()

# Force LoRA params to fp32 for better small-gradient resolution
for n, p in sku_trainer.model.named_parameters():
    if p.requires_grad:
        p.data = p.data.float()

print("[Enhanced LoRA] Trainable parameter count:", sum(p.numel() for p in sku_trainer.model.parameters() if p.requires_grad))
print("Proceed to run the enhanced training cell (8.2).")

In [None]:
# === 8.2 Enhanced Training Pass (Stronger Forgetting) ===
from math import ceil
import types

print("[Enhanced Training] Ready.")

# Hyperparameters tuned for stronger unlearning pressure.
# They are read by train_with_curriculum() and passed to sku_trainer.train().
EPOCHS = 5
LR = 1e-4
UL_WEIGHT = 7.0            # passed as ul_weight_answer
REFUSAL_WEIGHT = 0.3       # passed as refusal_weight
CE_PROMPT_WEIGHT = 0.5     # passed as ce_weight_prompt
CE_RETAIN_WEIGHT = 1.0     # passed as ce_weight_retain
L2_ANCHOR = 0.0            # l2_anchor_weight before the curriculum switch
ENTROPY_WEIGHT = 0.05      # passed as entropy_weight_answer
GRAD_ACCUM = 4             # passed as grad_accum_steps
WARMUP = 0.08              # warmup_ratio, applied to the first epoch only
UL_RAMP = 0.15             # passed as ul_ramp_ratio
CURRICULUM_SWITCH = 2      # first epoch index (0-based) that uses L2_AFTER
L2_AFTER = 5e-6            # l2_anchor_weight from the curriculum switch onward
FORGET_UPSAMPLE = 1.5      # target multiple of the forget set size

if 'dataset' in globals():
    import pandas as pd
    base_df = dataset.data
    forget_df = base_df[base_df['split'] == 'forget']
    retain_df = base_df[base_df['split'] == 'retain']
    if FORGET_UPSAMPLE > 1 and len(forget_df) > 0:
        # Integer part -> whole copies of the forget set; fractional part -> an
        # extra random sample (with replacement, fixed seed) of that fraction
        reps_int = int(FORGET_UPSAMPLE)
        frac_part = FORGET_UPSAMPLE - reps_int
        replicated = [forget_df]*reps_int
        if frac_part > 1e-6:
            replicated.append(forget_df.sample(frac=frac_part, replace=True, random_state=42))
        extra_forget = pd.concat(replicated, ignore_index=True)
        aug_df = pd.concat([retain_df, extra_forget], ignore_index=True)
        print(f"[Enhanced Training] Upsampled forget examples: {len(extra_forget)}")
    else:
        aug_df = base_df
    # NOTE: `dataset` and `dataloader` are rebound here, so re-running this cell
    # upsamples the already-augmented data again
    dataset = UnlearningDataset(aug_df, tokenizer, max_length=384)
    dataloader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        # `pad_id` is read from the notebook's global namespace when a batch is collated
        collate_fn=lambda b: sku_collate_fn(b, pad_id, max_length=384),
        pin_memory=torch.cuda.is_available(),
        num_workers=2 if torch.cuda.is_available() else 0,
        persistent_workers=False,
    )
    print(f"[Enhanced Training] New dataset size {len(dataset)} | forget {sum(dataset.data['split']=='forget')} | retain {sum(dataset.data['split']=='retain')}")
else:
    print("[Enhanced Training] Dataset object not found; skipping upsample step.")

# Bound train() of the current trainer, captured for use inside the curriculum loop
orig_train_method = sku_trainer.train

def train_with_curriculum(self):
    """Run EPOCHS single-epoch training passes with a per-epoch L2 anchor weight.

    Epochs before CURRICULUM_SWITCH use L2_ANCHOR, later epochs use L2_AFTER.
    Warmup is applied in the first epoch only. After each epoch the adapters
    are saved to enhanced_ckpts/epoch_<n> and a best-effort span log-prob
    diagnostic (base vs. adapted) is printed for the first forget validation row.
    """
    for epoch in range(EPOCHS):
        cur_l2 = L2_AFTER if epoch >= CURRICULUM_SWITCH else L2_ANCHOR
        print(f"\n[Epoch {epoch+1}/{EPOCHS}] L2 anchor weight: {cur_l2}")
        orig_train_method(
            dataloader=dataloader,
            num_epochs=1,
            lr=LR,
            ce_weight_prompt=CE_PROMPT_WEIGHT,
            ul_weight_answer=UL_WEIGHT,
            ce_weight_retain=CE_RETAIN_WEIGHT,
            l2_anchor_weight=cur_l2,
            entropy_weight_answer=ENTROPY_WEIGHT,
            refusal_weight=REFUSAL_WEIGHT,
            grad_clip=1.0,
            grad_accum_steps=GRAD_ACCUM,
            use_mixed_precision=True,
            warmup_ratio=WARMUP if epoch == 0 else 0.0,
            ul_ramp_ratio=UL_RAMP,
        )
        # Save epoch checkpoint
        ckpt_dir = f"enhanced_ckpts/epoch_{epoch+1}"
        os.makedirs(ckpt_dir, exist_ok=True)
        self.save_model(ckpt_dir)
        # Span probability diagnostic (best-effort)
        try:
            from peft import PeftModel
            base_tok = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
            if base_tok.pad_token is None: base_tok.pad_token = base_tok.eos_token
            base_model_local = AutoModelForCausalLM.from_pretrained(MODEL_PATH, local_files_only=True, torch_dtype=torch.float32)
            sensitive_prompt = forget_validation_df.iloc[0]['input']
            sensitive_answer = forget_validation_df.iloc[0]['output']
            # Score the plain base model BEFORE attaching adapters: PEFT wraps
            # the very same model instance, so scoring `base_model_local`
            # afterwards would include the adapters and make Δ come out as ~0.
            lp_base, Lspan = compute_span_logprob(base_model_local, base_tok, sensitive_prompt, sensitive_answer)
            adapted_local = PeftModel.from_pretrained(base_model_local, ckpt_dir)
            lp_adapt, _ = compute_span_logprob(adapted_local, base_tok, sensitive_prompt, sensitive_answer)
            print(f"[Epoch {epoch+1}] Span log-prob base {lp_base:.2f} -> adapted {lp_adapt:.2f} Δ {lp_adapt-lp_base:.2f}")
        except Exception as e:
            print(f"[Epoch {epoch+1}] Span diagnostic skipped: {e}")

# Expose the curriculum loop as a bound method on this trainer instance
sku_trainer.train_curriculum = types.MethodType(train_with_curriculum, sku_trainer)
print("Call sku_trainer.train_curriculum() to launch enhanced training.")

In [None]:
# === 8.3 Deep Debug: Adapter Change & Gradient Inspection ===
import torch, math

if 'sku_trainer' not in globals():
    print("sku_trainer not found. Run the setup cells first.")
else:
    model = sku_trainer.model
    changed = []
    total_l2 = 0.0   # accumulates SQUARED L2 over all adapter tensors
    total_params = 0
    # Identify LoRA matrices (common naming: lora_A, lora_B) and compare each
    # against the trainer's initial snapshot
    for name, p in model.named_parameters():
        if p.requires_grad and ('lora_A' in name or 'lora_B' in name):
            init = sku_trainer.initial_state_dict.get(name)
            if init is None:
                continue
            diff = (p.detach() - init.to(p.device, dtype=p.dtype))
            l2 = diff.pow(2).sum().sqrt().item()
            mean_abs = diff.abs().mean().item()
            changed.append((name, p.numel(), l2, mean_abs))
            total_l2 += diff.pow(2).sum().item()
            total_params += p.numel()
    if not changed:
        print("No LoRA parameter deltas detected (initial snapshot may refer to different model instance).")
    else:
        changed.sort(key=lambda x: x[2], reverse=True)
        print(f"Top 5 adapter deltas (L2 norm):")
        for name, n, l2, mean_abs in changed[:5]:
            print(f"  {name}: L2={l2:.4e} | mean|Δ|={mean_abs:.4e} | elems={n}")
        global_l2 = math.sqrt(total_l2) if total_l2>0 else 0.0
        print(f"Global adapter Δ L2: {global_l2:.4e} over {total_params} params")

    # If global_l2 ~ 0, training produced effectively no update -> verify gradient flow on one batch
    if total_l2 < 1e-6:
        if 'dataloader' not in globals():
            print("Adapter updates near zero, but no dataloader is defined; skipping gradient probe.")
        else:
            print("Adapter updates near zero; sampling one batch to check gradient flow...")
            b = next(iter(dataloader))
            # Move tensors only; leave any non-tensor batch entries untouched
            for k in b:
                if torch.is_tensor(b[k]):
                    b[k] = b[k].to(sku_trainer.device)
            was_training = model.training
            model.train()
            for name, p in model.named_parameters():
                if p.grad is not None:
                    p.grad = None
            out = model(input_ids=b['input_ids'], attention_mask=b['attention_mask'], use_cache=False, return_dict=True)
            logits = out.logits[:, :-1]
            target = b['labels'][:, 1:]
            attn_tgt = b['attention_mask'][:, 1:].bool()
            # Simple CE over attended target positions to probe gradients
            vocab = logits.size(-1)
            loss_probe = torch.nn.functional.cross_entropy(logits.reshape(-1, vocab)[attn_tgt.reshape(-1)], target.reshape(-1)[attn_tgt.reshape(-1)])
            loss_probe.backward()
            grad_sum = 0.0
            grad_ct = 0
            for name, p in model.named_parameters():
                if p.requires_grad and ('lora_A' in name or 'lora_B' in name):
                    if p.grad is not None:
                        gnorm = p.grad.detach().pow(2).sum().sqrt().item()
                        if gnorm > 0:
                            grad_sum += gnorm
                            grad_ct += 1
            # Leave no trace of the probe: drop its gradients so they cannot
            # leak into a later optimizer step, and restore the previous mode
            for name, p in model.named_parameters():
                p.grad = None
            model.train(was_training)
            print(f"Adapter grad probes: {grad_ct} matrices with cumulative grad L2 {grad_sum:.4e}")
            if grad_ct == 0 or grad_sum < 1e-6:
                print("No gradient reaching LoRA layers. Possible causes: target_modules mismatch, model wrapped after snapshot, or parameters frozen.")
            else:
                print("Gradients flow, but saved deltas were zero -> likely saving/loading different adapter instance.")

    # Sanity: list adapter directory files for last epoch if present
    import os

    def _epoch_num(dirname):
        """Numeric suffix of an 'epoch_<n>' name; -1 when it is not a number."""
        suffix = dirname.rsplit('_', 1)[-1]
        return int(suffix) if suffix.isdigit() else -1

    last_ckpt = None
    if os.path.isdir('enhanced_ckpts'):
        epochs = [d for d in os.listdir('enhanced_ckpts') if d.startswith('epoch_')]
        if epochs:
            # Sort numerically: a plain string sort puts epoch_10 before epoch_2
            epochs.sort(key=_epoch_num)
            last_ckpt = f"enhanced_ckpts/{epochs[-1]}"
    if last_ckpt:
        print(f"Listing adapter files in {last_ckpt}:")
        try:
            for f in os.listdir(last_ckpt):
                print("  -", f, os.path.getsize(os.path.join(last_ckpt,f)), "bytes")
        except Exception as e:
            print("Could not list files:", e)
    else:
        print("No enhanced_ckpts directory found.")

print("[Deep Debug Complete]")

In [None]:
# === 8.4 Multi-Sample Forget Span Probability Evaluation ===
import torch, math, statistics
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

print("[8.4] Evaluating average forget-answer log-prob before vs after adapters (first 50 samples)...")

N_SAMPLES = 50
if 'forget_validation_df' not in globals():
    print("forget_validation_df missing.")
else:
    base_tok = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
    if base_tok.pad_token is None: base_tok.pad_token = base_tok.eos_token
    base_model_eval = AutoModelForCausalLM.from_pretrained(MODEL_PATH, local_files_only=True, torch_dtype=torch.float32)

    def span_lp(model, prompt, answer):
        """Return (sum of answer-token log-probs given the prompt, token count)."""
        with torch.no_grad():
            toks_prompt = base_tok(prompt, add_special_tokens=False)
            plen = len(toks_prompt['input_ids'])
            toks_full = base_tok(prompt + answer, return_tensors='pt')
            out = model(**toks_full, use_cache=False, return_dict=True)
            logits = out.logits[:, :-1]
            targets = toks_full['input_ids'][:, 1:]
            # answer positions start at plen (input index), so target indices start at plen-1
            start = max(plen-1, 0)
            end = targets.size(1)
            if start >= end:
                return float('nan'), 0
            lps = torch.log_softmax(logits[0, start:end, :], dim=-1)
            tgt = targets[0, start:end]
            gather = lps[range(end-start), tgt]
            return gather.sum().item(), end-start

    def _epoch_num(path):
        """Numeric suffix of an '.../epoch_<n>' path; -1 when it is not a number."""
        suffix = path.rsplit('_', 1)[-1]
        return int(suffix) if suffix.isdigit() else -1

    n_eval = min(N_SAMPLES, len(forget_validation_df))
    eval_rows = [forget_validation_df.iloc[i] for i in range(n_eval)]

    # 1) Score the plain base model FIRST. PEFT wraps the same model instance,
    #    so once adapters are attached `base_model_eval` no longer behaves as
    #    the base model and base/adapted scores would come out identical.
    base_raw = [span_lp(base_model_eval, row['input'], row['output'])[0] for row in eval_rows]

    # 2) Locate a checkpoint and attach the adapters.
    adapted_model_eval = None
    # Prefer last enhanced checkpoint
    pref_ckpt = 'enhanced_ckpts/epoch_5'
    if not os.path.isdir(pref_ckpt):
        # fallback to the highest-numbered epoch dir (numeric, not string, order)
        import glob
        cands = sorted(glob.glob('enhanced_ckpts/epoch_*'), key=_epoch_num)
        if cands:
            pref_ckpt = cands[-1]
    if os.path.isdir(pref_ckpt):
        try:
            adapted_model_eval = PeftModel.from_pretrained(base_model_eval, pref_ckpt)
            adapted_model_eval.eval()
            print(f"Loaded adapters from {pref_ckpt}")
        except Exception as e:
            print("Adapter load failed:", e)
    else:
        print("No adapter checkpoint directory found for evaluation.")

    # 3) Score the adapted model on the same rows.
    adapt_raw = []
    if adapted_model_eval is not None:
        adapt_raw = [span_lp(adapted_model_eval, row['input'], row['output'])[0] for row in eval_rows]

    # Keep only samples where BOTH scores are finite, so that every delta
    # compares the same sample under the two models.
    paired = [(b, a) for b, a in zip(base_raw, adapt_raw) if math.isfinite(b) and math.isfinite(a)]
    base_scores = [b for b, _ in paired]
    adapt_scores = [a for _, a in paired]

    if paired:
        deltas = [a-b for a,b in zip(adapt_scores, base_scores)]
        print(f"Samples compared: {len(deltas)}")
        print(f"Mean base span log-prob: {statistics.mean(base_scores):.2f}")
        print(f"Mean adapted span log-prob: {statistics.mean(adapt_scores):.2f}")
        print(f"Mean Δ (adapt-base): {statistics.mean(deltas):.2f}")
        print(f"Median Δ: {statistics.median(deltas):.2f}")
        worse = sum(1 for d in deltas if d>0)
        better = sum(1 for d in deltas if d<0)
        print(f"Count Δ<0 (desired): {better} | Δ>0: {worse}")
        if statistics.mean(deltas) > -0.5:
            print("=> Forgetting signal weak. Consider raising UL_WEIGHT, reducing CE_PROMPT_WEIGHT further, or adding KL to refusal.")
    else:
        print("Insufficient comparable scores.")

print("[8.4] Done.")

# === 8.5 Instrumented Training Wrapper (Logs UL loss & Answer Mask Coverage) ===
import time, types

def make_instrumented_train(original_train):
    """Build a `train_instrumented(self, *args, **kwargs)` wrapper.

    `original_train` is the trainer's own, already bound, train method. The
    returned function temporarily intercepts `self._unlikelihood_loss` to record
    each call's loss value and the number of active mask positions, runs
    `original_train` with the given arguments, then prints the averages and the
    elapsed time.
    """
    def train_instrumented(self, *args, **kwargs):
        print("[Instrumented] Starting instrumented pass...")
        # The dataloader must be passed by keyword
        if kwargs.get('dataloader') is None:
            raise ValueError('Provide dataloader explicitly to train_instrumented')
        batch_stats = []
        # Remember whether the hook already lived on the instance, so the
        # object can be put back exactly as it was found
        had_instance_attr = '_unlikelihood_loss' in self.__dict__
        orig_ul = self._unlikelihood_loss

        def tracked_ul(logits, labels, loss_mask):
            ul_val = orig_ul(logits, labels, loss_mask)
            batch_stats.append(('ul', float(ul_val.detach().cpu()), int(loss_mask.sum().item())))
            return ul_val

        self._unlikelihood_loss = tracked_ul
        start_time = time.time()
        try:
            # original_train is bound already: passing `self` again would shift
            # every positional argument and collide with the dataloader keyword
            original_train(*args, **kwargs)
        finally:
            if had_instance_attr:
                self._unlikelihood_loss = orig_ul
            else:
                # Drop the temporary instance attribute so lookups fall back to
                # the class again (keeps later class-level patches effective)
                del self._unlikelihood_loss
        # Summaries
        ul_vals = [v for tag, v, _ in batch_stats if tag == 'ul']
        ul_tokens = [c for tag, _, c in batch_stats if tag == 'ul']
        if ul_vals:
            print(f"[Instrumented] Avg UL loss: {sum(ul_vals)/len(ul_vals):.4f} | Avg UL active tokens: {sum(ul_tokens)/len(ul_tokens):.1f}")
        print(f"[Instrumented] Duration: {time.time()-start_time:.1f}s")

    return train_instrumented


if 'sku_trainer' in globals():
    original_train = sku_trainer.train
    sku_trainer.train_instrumented = types.MethodType(make_instrumented_train(original_train), sku_trainer)
    print("[8.5] sku_trainer.train_instrumented available. Use it with the same args as sku_trainer.train().")
else:
    print("sku_trainer not found; skip instrumented wrapper.")

print("[8.5] Ready.")

In [None]:
# === 8.4b Corrected Isolated Span Probability Evaluation ===
"""
The previous 8.4 evaluation loaded adapters into (or alongside) the same base model instance prior to computing
"base" scores, making base and adapted probabilities identical (Δ = 0). This cell:
 1. Loads a FRESH base model (base_model_plain) for baseline scores.
 2. Loads a SEPARATE fresh base model and then applies adapters (base_model_with_adapters -> adapted_model).
 3. Computes log P(answer | prompt) over answer tokens only, for N samples.
 4. Reports mean / median Δ (adapt - base). Negative values indicate successful forgetting.

Memory note: Loads the 1B model twice; if OOM, set LOAD_TWICE=False to reuse one instance (but must score base BEFORE attaching adapters) .
"""
import math, statistics, torch, gc
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

N_SAMPLES = 50      # upper bound on the number of forget-validation rows scored
LOAD_TWICE = True   # set False if memory constrained
CKPT_DIR = 'enhanced_ckpts/epoch_5'  # adapter checkpoint directory to evaluate

def load_base():
    """Return a freshly loaded causal LM from MODEL_PATH.

    The weights are read from local files only (no hub download) and are
    requested in float32.
    """
    load_kwargs = {"local_files_only": True, "torch_dtype": torch.float32}
    return AutoModelForCausalLM.from_pretrained(MODEL_PATH, **load_kwargs)

def shared_prefix_len(seq_a, seq_b):
    """Return how many leading items the two sequences have in common."""
    count = 0
    for left, right in zip(seq_a, seq_b):
        if left != right:
            break
        count += 1
    return count


def answer_span_logprob(model, tok, prompt, answer):
    """Sum the log-probabilities the model assigns to the answer tokens.

    Args:
        model: callable causal LM; called as ``model(**encoding, use_cache=False,
            return_dict=True)`` and expected to return an object with ``.logits``.
        tok: tokenizer callable returning a mapping with ``'input_ids'``.
        prompt: conditioning text.
        answer: continuation whose tokens are scored.

    Returns:
        ``(log_prob_sum, n_answer_tokens)``, or ``None`` when no answer token
        can be scored (e.g. empty answer).
    """
    encoding = tok(prompt + answer, return_tensors='pt')
    full_ids = encoding['input_ids'][0]
    # Tokenize the prompt with the SAME settings as the full text and take the
    # common prefix: this stays aligned whether or not the tokenizer adds
    # special tokens, and when a token straddles the prompt/answer boundary.
    prompt_ids = tok(prompt)['input_ids']
    answer_start = shared_prefix_len(prompt_ids, full_ids.tolist())
    # The very first token has no preceding position to be predicted from.
    answer_start = max(answer_start, 1)
    if answer_start >= full_ids.size(0):
        return None
    out = model(**encoding, use_cache=False, return_dict=True)
    # Logits at position t predict token t+1, so the answer tokens
    # full_ids[answer_start:] are predicted by positions answer_start-1 .. T-2.
    # (The previous slice also trimmed the end and lost the last answer token.)
    log_probs = torch.log_softmax(out.logits[0, answer_start - 1:-1, :], dim=-1)
    span_ids = full_ids[answer_start:]
    token_log_probs = log_probs.gather(-1, span_ids.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.sum().item(), span_ids.size(0)


def score_forget_rows(model, tok, frame, limit):
    """Score the first ``limit`` rows of ``frame`` (columns 'input' / 'output').

    Rows without a scorable answer token are skipped. The skip depends only on
    tokenization, so two passes with the same tokenizer yield aligned lists.

    Returns:
        ``(scores, lengths)``: summed answer log-prob and answer token count
        for every scored row.
    """
    scores, lengths = [], []
    with torch.no_grad():
        for i in range(min(limit, len(frame))):
            row = frame.iloc[i]
            result = answer_span_logprob(model, tok, row['input'], row['output'])
            if result is None:
                continue
            scores.append(result[0])
            lengths.append(result[1])
    return scores, lengths


print('[8.4b] Starting isolated evaluation ...')
if 'forget_validation_df' not in globals():
    print('Missing forget_validation_df.')
else:
    base_tok = AutoTokenizer.from_pretrained('allenai/OLMo-1B-0724-hf')
    if base_tok.pad_token is None:
        base_tok.pad_token = base_tok.eos_token

    # Baseline pass: always scored BEFORE any adapter touches a model instance.
    base_model_plain = load_base().eval()
    base_scores, answers_lens = score_forget_rows(
        base_model_plain, base_tok, forget_validation_df, N_SAMPLES
    )

    if LOAD_TWICE:
        # Release the baseline instance, then load an independent copy for the adapters.
        del base_model_plain
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        base_model_with_adapters = load_base()
    else:
        # One-pass strategy: reuse the instance whose baseline scores are already collected.
        base_model_with_adapters = base_model_plain

    adapted_model = PeftModel.from_pretrained(base_model_with_adapters, CKPT_DIR).eval()
    adapt_scores, _ = score_forget_rows(
        adapted_model, base_tok, forget_validation_df, N_SAMPLES
    )

    if base_scores and adapt_scores and len(base_scores) == len(adapt_scores):
        deltas = [a - b for a, b in zip(adapt_scores, base_scores)]
        mean_delta = statistics.mean(deltas)
        print(f"Samples compared: {len(deltas)}")
        print(f"Mean base span log-prob: {statistics.mean(base_scores):.2f}")
        print(f"Mean adapted span log-prob: {statistics.mean(adapt_scores):.2f}")
        print(f"Mean Δ (adapt-base): {mean_delta:.2f}")
        print(f"Median Δ: {statistics.median(deltas):.2f}")
        neg = sum(1 for d in deltas if d < 0)
        pos = sum(1 for d in deltas if d > 0)
        print(f"Count Δ<0 (desired): {neg} | Δ>0: {pos}")
        if mean_delta > -0.5:
            print('=> Forgetting still weak OR sensitive spans not strongly represented; consider: higher UL_WEIGHT (8-10), add KL to refusal, extra epochs focused only on forget batch sampling.')
    else:
        print('Insufficient comparable scores or mismatch.')

print('[8.4b] Done.')