# Selective Knowledge Negation Unlearning


## 1. Setup e Import

In [None]:
!pip install rouge-score
import torch
import pandas as pd
import numpy as np
import json
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from rouge_score import rouge_scorer

# Configuration: locations of the base model checkpoint and of the dataset
MODEL_PATH = "/kaggle/input/olmo-model/semeval25-unlearning-1B-model"
DATA_PATH = "/kaggle/input/olmo-model/semeval25-unlearning-data"

# Report every CUDA device visible to PyTorch (prints only the count when there is none)
gpu_count = torch.cuda.device_count()
print(f"GPUs disponibili: {gpu_count}")
for gpu_index in range(gpu_count):
    gpu_name = torch.cuda.get_device_name(gpu_index)
    print(f"GPU {gpu_index}: {gpu_name}")

## 2. Caricamento Dati e Modelli

In [None]:
# Load the four dataset splits (retain/forget x train/validation) from parquet
retain_train_df = pd.read_parquet(f"{DATA_PATH}/data/retain_train-00000-of-00001.parquet", engine='pyarrow')
retain_validation_df = pd.read_parquet(f"{DATA_PATH}/data/retain_validation-00000-of-00001.parquet", engine='pyarrow')
forget_train_df = pd.read_parquet(f"{DATA_PATH}/data/forget_train-00000-of-00001.parquet", engine='pyarrow')
forget_validation_df = pd.read_parquet(f"{DATA_PATH}/data/forget_validation-00000-of-00001.parquet", engine='pyarrow')

# Export every split as JSONL (one JSON record per line).
# os.makedirs replaces the IPython-only `!mkdir -p` shell escape so this cell is
# valid Python and also works outside a notebook / on systems without `mkdir -p`.
os.makedirs('train', exist_ok=True)
os.makedirs('validation', exist_ok=True)
retain_train_df.to_json('train/retain.jsonl', orient='records', lines=True)
forget_train_df.to_json('train/forget.jsonl', orient='records', lines=True)
retain_validation_df.to_json('validation/retain.jsonl', orient='records', lines=True)
forget_validation_df.to_json('validation/forget.jsonl', orient='records', lines=True)

# Tokenizer; fall back to EOS as the padding token when none is defined,
# because the dataset pads every example to a fixed length
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Dataset salvati e tokenizer caricato")

## 3. Dataset

In [None]:
class UnlearningDataset(Dataset):
    """Torch dataset yielding tokenized "prompt + answer" examples for unlearning.

    Each record must provide the fields ``input`` (prompt), ``output`` (answer)
    and ``split`` (the string ``"forget"`` marks samples to unlearn; any other
    value is treated as retain).

    Args:
        data_source: Either a ``pandas.DataFrame`` holding the records, or the
            path (``str`` / ``os.PathLike``) of a JSONL file with one record
            per line.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` lists.
        max_length: Every example is truncated/padded to this many tokens.

    Raises:
        TypeError: If ``data_source`` is neither a DataFrame nor a path.
    """

    def __init__(self, data_source, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length

        if isinstance(data_source, pd.DataFrame):
            self.data = data_source
            print(f"Caricati {len(self.data)} esempi dal DataFrame")
        elif isinstance(data_source, (str, os.PathLike)):
            with open(data_source, 'r', encoding='utf-8') as f:
                # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads
                records = [json.loads(line) for line in f if line.strip()]
            self.data = pd.DataFrame(records)
            print(f"Caricati {len(self.data)} esempi da {data_source}")
        else:
            # Fail immediately: otherwise self.data would be missing and the
            # error would only surface later as an opaque AttributeError.
            raise TypeError(
                "data_source must be a pandas DataFrame or a path to a JSONL file, "
                f"got {type(data_source).__name__}"
            )

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize example ``idx``.

        Returns:
            dict with
            - ``input_ids`` / ``attention_mask``: ``[max_length]`` long tensors
              of the padded "prompt answer" sequence;
            - ``labels``: a copy of ``input_ids`` (padding is NOT replaced by
              -100; consumers are expected to mask it with ``attention_mask``);
            - ``start_locs``: index of the last prompt token in ``input_ids``,
              i.e. the position whose next-token prediction is the first
              answer token;
            - ``split``: 1 for forget samples, 0 otherwise.
        """
        row = self.data.iloc[idx]  # single row lookup instead of three
        prompt_text = row["input"]
        answer_text = row["output"]

        # Tokenize the prompt alone to find where the answer begins
        prompt_tok = self.tokenizer(
            prompt_text,
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )

        # Tokenize prompt + answer for the model inputs and labels
        combined_tok = self.tokenizer(
            f"{prompt_text} {answer_text}",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors=None,
        )

        # With a left-padding tokenizer the real tokens are preceded by pad
        # tokens, so the prompt boundary must be offset by the first attended
        # position. The offset is 0 for right padding (unchanged behavior).
        attention = combined_tok["attention_mask"]
        pad_offset = next((pos for pos, flag in enumerate(attention) if flag), 0)
        start_locs = pad_offset + len(prompt_tok["input_ids"]) - 1

        input_ids = torch.tensor(combined_tok["input_ids"], dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": torch.tensor(attention, dtype=torch.long),
            "start_locs": start_locs,
            "labels": input_ids.clone(),
            "split": 1 if row["split"] == "forget" else 0,
        }

In [None]:
# Merge retain and forget training records into one shuffled training stream
batch_size = 4
train_data = pd.concat([retain_train_df, forget_train_df], ignore_index=True)
dataset = UnlearningDataset(train_data, tokenizer)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

num_examples = len(dataset)
print(f"Dataset creato con {num_examples} esempi")

## 4. Selective Knowledge Negation Trainer

In [None]:
class SelectiveKnowledgeNegationTrainer:
    """
    Trainer implementing Selective Knowledge Negation Unlearning (SKU).

    Core idea:
    - For retain samples: optimize the standard language modeling loss (cross-entropy) to preserve knowledge.
    - For forget samples: minimize the probability of producing the forbidden answer tokens via token-level
      unlikelihood loss on the answer span while still keeping CE on the prompt context to stabilize training.

    All losses are computed in the next-token frame: the logits at position t are scored against the
    token at position t + 1 (see ``_compute_losses``).

    Workflow: ``setup_model()`` -> ``train(...)`` -> ``save_model(...)`` / ``calculate_task_vector()``.
    """

    def __init__(self, model_path, tokenizer, lora_config, device="cuda:0"):
        """Store the configuration; the model itself is only loaded by ``setup_model()``."""
        self.model_path = model_path
        self.tokenizer = tokenizer
        self.lora_config = lora_config
        self.device = device

        self.model = None
        self.initial_state_dict = {}

    def setup_model(self):
        """Load base model and wrap with LoRA adapters for efficient finetuning."""
        print("🔧 Setting up model (LoRA)...")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            local_files_only=True
        )
        self.model = get_peft_model(base_model, self.lora_config).to(self.device)
        self.model.print_trainable_parameters()

        # Snapshot the initial trainable weights so a task vector can be computed later.
        # Rebuilding the dict (instead of updating it) drops stale entries if setup is re-run.
        self.initial_state_dict = {
            name: p.detach().clone()
            for name, p in self.model.named_parameters()
            if p.requires_grad
        }
        print("✅ Model setup completed")

    @staticmethod
    def _shift_labels_for_ce(labels):
        """Return a copy of ``labels``.

        Kept for backward compatibility. Despite its name it performs no shift;
        the next-token shift is applied in ``_compute_losses``.
        """
        return labels.clone()

    def _compute_mask_spans(self, input_ids, start_locs):
        """
        Build boolean masks for the prompt and answer spans of each sample.

        Args:
            input_ids: [B, T] tensor; only its shape and device are used.
            start_locs: per-sample index where the answer span starts (tensor or sequence of ints).
                Values are clamped to [0, T - 1].

        Returns:
            (prompt_mask, answer_mask): [B, T] bool tensors; ``answer_mask`` is True at
            positions >= start, ``prompt_mask`` is its complement.
        """
        batch, seq_len = input_ids.shape
        device = input_ids.device
        if seq_len == 0:
            empty = torch.zeros((batch, 0), dtype=torch.bool, device=device)
            return empty, empty.clone()

        starts = torch.as_tensor(start_locs, device=device).long().reshape(-1, 1)
        starts = starts.clamp(min=0, max=seq_len - 1)
        positions = torch.arange(seq_len, device=device).unsqueeze(0)  # [1, T]
        # Vectorised equivalent of: answer_mask[i, s:] = True; prompt_mask[i, :s] = True
        answer_mask = positions >= starts
        prompt_mask = ~answer_mask
        return prompt_mask, answer_mask

    def _cross_entropy_loss(self, logits, labels, loss_mask):
        """
        Mean token-level cross-entropy over the positions where ``loss_mask`` is True.

        logits: [B, T, V], labels: [B, T], loss_mask: [B, T] (bool).
        Returns a constant 0 tensor when no position is selected.
        """
        if not bool(loss_mask.any()):
            return logits.new_tensor(0.0)
        # Boolean indexing (instead of .view) also works on non-contiguous slices.
        return F.cross_entropy(logits[loss_mask], labels[loss_mask], reduction="mean")

    def _unlikelihood_loss(self, logits, labels, loss_mask):
        """
        Token-level unlikelihood loss from "The Unlikelihood Training Objective" (Welleck et al.).

        For target token y the loss is -log(1 - p(y)), averaged over the positions where
        ``loss_mask`` is True; p(y) is clamped to (eps, 1 - eps) for numerical stability.
        Returns a constant 0 tensor when no position is selected.
        """
        if not bool(loss_mask.any()):
            return logits.new_tensor(0.0)
        # Normalise only the selected rows: avoids a full [B, T, V] softmax.
        log_probs = F.log_softmax(logits[loss_mask], dim=-1)  # [N, V]
        targets = labels[loss_mask].unsqueeze(-1)  # [N, 1]
        p_y = log_probs.gather(-1, targets).squeeze(-1).exp()  # [N]
        eps = 1e-6
        p_y = p_y.clamp(min=eps, max=1 - eps)
        return -torch.log1p(-p_y).mean()

    def _compute_losses(self, logits, labels, attention_mask, start_locs, split,
                        ce_weight_prompt=1.0, ul_weight_answer=1.0, ce_weight_retain=1.0):
        """
        Compute the three weighted SKU loss terms in the next-token frame.

        A causal LM's logits at position t are the prediction for token t + 1, so logits
        [:, :-1] are paired with labels [:, 1:]. In this frame ``start_locs`` (the index
        of the last prompt token) is the first position that predicts an answer token.

        Args:
            logits: [B, T, V] model outputs.
            labels: [B, T] token ids.
            attention_mask: [B, T], non-zero on real (non-padding) tokens.
            start_locs: [B] index of the last prompt token of each sample.
            split: [B] tensor, 0 = retain, 1 = forget.

        Returns:
            (retain_loss, forget_prompt_loss, forget_ul_loss) as scalar tensors.
        """
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        attended = attention_mask.bool()
        # A position is supervised only when both the input token and its target are real tokens.
        valid_tokens = attended[:, :-1] & attended[:, 1:]

        prompt_mask, answer_mask = self._compute_mask_spans(shift_labels, start_locs)
        is_retain = (split == 0).unsqueeze(-1)  # [B, 1], broadcast over T - 1
        is_forget = (split == 1).unsqueeze(-1)

        # Retain: CE on every valid token
        retain_loss = self._cross_entropy_loss(
            shift_logits, shift_labels, valid_tokens & is_retain
        ) * ce_weight_retain

        # Forget: CE only on the prompt to keep fluency
        forget_prompt_loss = self._cross_entropy_loss(
            shift_logits, shift_labels, valid_tokens & is_forget & prompt_mask
        ) * ce_weight_prompt

        # Forget: unlikelihood on the answer tokens to negate specific knowledge
        forget_ul_loss = self._unlikelihood_loss(
            shift_logits, shift_labels, valid_tokens & is_forget & answer_mask
        ) * ul_weight_answer

        return retain_loss, forget_prompt_loss, forget_ul_loss

    def train(self, dataloader, num_epochs=4, lr=1e-4, ce_weight_prompt=1.0, ul_weight_answer=1.0, ce_weight_retain=1.0, grad_clip=1.0):
        """
        Train with SKU:
        - For retain (split==0): standard CE on all valid token positions.
        - For forget (split==1): CE on prompt tokens (to keep context fluency) + Unlikelihood on answer tokens.
        Weights allow tuning the relative contributions.

        Args:
            dataloader: yields dict batches with ``input_ids``, ``attention_mask``, ``labels``,
                ``start_locs`` and ``split`` tensors.
            num_epochs: number of passes over ``dataloader``.
            lr: AdamW learning rate.
            grad_clip: max gradient norm; ``None`` disables clipping.

        Raises:
            RuntimeError: if ``setup_model()`` has not been called.
        """
        if self.model is None:
            raise RuntimeError("Call setup_model() first")

        self.model.train()
        # Only the trainable (adapter) parameters need optimizer state and clipping.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=0.01)

        for epoch in range(num_epochs):
            epoch_losses = []
            with tqdm(total=len(dataloader), desc=f"SKU Epoch {epoch+1}") as pbar:
                for batch in dataloader:
                    input_ids = batch["input_ids"].to(self.device)
                    attention_mask = batch["attention_mask"].to(self.device)
                    labels = batch["labels"].to(self.device)
                    start_locs = batch["start_locs"].to(self.device)
                    split = batch["split"].to(self.device)  # 0 retain, 1 forget

                    # `labels` is deliberately not passed: the model's built-in loss is
                    # never used, so computing it would only waste time and memory.
                    outputs = self.model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        output_hidden_states=False,
                        return_dict=True,
                    )

                    retain_loss, forget_prompt_loss, forget_ul_loss = self._compute_losses(
                        outputs.logits,  # [B, T, V]
                        labels,
                        attention_mask,
                        start_locs,
                        split,
                        ce_weight_prompt=ce_weight_prompt,
                        ul_weight_answer=ul_weight_answer,
                        ce_weight_retain=ce_weight_retain,
                    )
                    loss = retain_loss + forget_prompt_loss + forget_ul_loss

                    # If no token was supervised, every term is a constant and
                    # backward() would raise; skip the update for such a batch.
                    if loss.requires_grad:
                        optimizer.zero_grad()
                        loss.backward()
                        if grad_clip is not None:
                            torch.nn.utils.clip_grad_norm_(trainable_params, grad_clip)
                        optimizer.step()

                    epoch_losses.append(loss.item())
                    pbar.set_postfix({
                        "Loss": f"{loss.item():.4f}",
                        "RetCE": f"{retain_loss.item():.3f}",
                        "FgtCE": f"{forget_prompt_loss.item():.3f}",
                        "FgtUL": f"{forget_ul_loss.item():.3f}",
                    })
                    pbar.update(1)

            avg_loss = float(np.mean(epoch_losses)) if epoch_losses else 0.0
            print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")

    def save_model(self, save_path):
        """Save model adapters and tokenizer."""
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    def calculate_task_vector(self):
        """Return {name: trained - initial} for every trainable parameter snapshotted by ``setup_model()``."""
        return {
            name: p.detach() - self.initial_state_dict[name]
            for name, p in self.model.named_parameters()
            if p.requires_grad and name in self.initial_state_dict
        }


## 5. Setup Trainer and Training

In [None]:
# LoRA adapter settings (a single model is trained); adapters are attached
# to the q/v/k/o attention projection modules
lora_settings = dict(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)
lora_config = LoraConfig(**lora_settings)

# Train on the first GPU when one is available, otherwise on the CPU
if torch.cuda.is_available():
    training_device = "cuda:0"
else:
    training_device = "cpu"

# Build the SKU trainer, then load the base model and attach the adapters
sku_trainer = SelectiveKnowledgeNegationTrainer(
    model_path=MODEL_PATH,
    tokenizer=tokenizer,
    lora_config=lora_config,
    device=training_device,
)
sku_trainer.setup_model()

In [None]:
# Hyper-parameters for the SKU run
train_kwargs = dict(
    dataloader=dataloader,
    num_epochs=4,
    lr=1e-4,
    ce_weight_prompt=1.0,   # CE weight on the prompt of forget samples
    ul_weight_answer=1.0,   # unlikelihood weight on the answer of forget samples
    ce_weight_retain=1.0,   # CE weight for retain samples
    grad_clip=1.0,
)
sku_trainer.train(**train_kwargs)

# Checkpoint the model as it stands after the last epoch
sku_trainer.save_model("sku_model_epoch_last")

In [None]:
# Persist the final model (adapters + tokenizer)
final_model_dir = "studentmodel_final"
sku_trainer.save_model(final_model_dir)

## 6. Save Results and Task Vector

In [None]:
# Create results directory
os.makedirs('balanced_results', exist_ok=True)

# Save SKU model
sku_trainer.save_model('balanced_results/balanced_model')

# Calculate and save the task vector (trained minus initial trainable weights).
# The tensors are moved to the CPU before serialising so the file can be loaded
# on a machine without the training device, without needing `map_location`.
task_vector = sku_trainer.calculate_task_vector()
cpu_task_vector = {name: delta.detach().cpu() for name, delta in task_vector.items()}
torch.save(cpu_task_vector, 'balanced_results/task_vector.pt')

print("✅ Results saved in balanced_results/")
print("- balanced_model/: SKU-trained model")
print("- task_vector.pt: Task vector for future applications")

## 7. Evaluation



In [None]:

import importlib
import types

# `evaluation` is an optional helper module. It is reloaded after import so
# that edits to it are picked up when this cell is re-run. When it cannot be
# imported the name is bound to None (instead of being left undefined) and a
# warning is printed, so the problem is visible here rather than surfacing
# later as a NameError.
evaluation = None
try:
    import evaluation
    importlib.reload(evaluation)
except ImportError as import_error:
    print(f"⚠️ Could not import the 'evaluation' module: {import_error}")

def run_evaluation(
    data_path,
    checkpoint_path,
    output_dir="eval_results",
    mia_data_path=None,
    mmlu_metrics_file_path=None,
    max_new_tokens=256,
    batch_size=25,
    debug=False,
    compute_metrics_only=False,
    seed=42,
    keep_files=False,
    base_model_path=None,
):
    """Run the `evaluation` module's inference / MIA / metrics steps on a checkpoint.

    The arguments are packed into an argparse-like namespace and handed to
    ``evaluation.inference``, ``evaluation.mia_attacks`` (only when
    ``mia_data_path`` is given) and ``evaluation.compute_metrics``.

    Args:
        data_path: Directory that must contain ``forget.jsonl`` and ``retain.jsonl``.
        checkpoint_path: Directory with the saved model (PEFT adapter or full
            model) and its tokenizer.
        output_dir: Output directory, created if missing; ``None`` means the
            current working directory.
        compute_metrics_only: Skip model loading/inference and only compute metrics.
        seed: Seed applied to ``random``, ``torch`` and ``numpy``.
        base_model_path: Base model used when the checkpoint is a PEFT adapter.
            Defaults to the module-level ``MODEL_PATH``.
        mia_data_path, mmlu_metrics_file_path, max_new_tokens, batch_size,
        debug, keep_files: Forwarded unchanged to the `evaluation` functions.

    Returns:
        None. Best-effort by design: any error is printed together with its
        traceback instead of being raised.
    """
    import random
    from pathlib import Path

    try:
        # Build an args object similar to the one argparse would produce
        args = types.SimpleNamespace(
            data_path=data_path,
            checkpoint_path=checkpoint_path,
            output_dir=output_dir,
            mia_data_path=mia_data_path,
            mmlu_metrics_file_path=mmlu_metrics_file_path,
            max_new_tokens=max_new_tokens,
            batch_size=batch_size,
            debug=debug,
            compute_metrics_only=compute_metrics_only,
            seed=seed,
            keep_files=keep_files,
        )

        # Check that the required inputs exist
        print(f"🔍 Verificando paths...")
        print(f"  Data path: {data_path}")
        print(f"  Checkpoint path: {checkpoint_path}")
        print(f"  Output dir: {output_dir}")

        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Data path not found: {data_path}")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
        if not os.path.exists(os.path.join(data_path, 'forget.jsonl')):
            raise FileNotFoundError(f"forget.jsonl not found in {data_path}")
        if not os.path.exists(os.path.join(data_path, 'retain.jsonl')):
            raise FileNotFoundError(f"retain.jsonl not found in {data_path}")

        # Normalise the output path (as in the original script)
        if args.output_dir is None:
            args.output_dir = os.getcwd()
        else:
            args.output_dir = args.output_dir.rstrip('/')
            Path(args.output_dir).mkdir(parents=True, exist_ok=True)

        # Imported here so a missing module fails with a clear ImportError
        # before any model is loaded.
        import evaluation

        # Seed every RNG used downstream
        import numpy as np
        import torch
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        from accelerate import Accelerator
        accelerator = Accelerator()

        if not args.compute_metrics_only:
            from transformers import AutoModelForCausalLM, AutoTokenizer
            from peft import PeftModel

            print(f"📥 Loading model from {args.checkpoint_path}...")

            if base_model_path is None:
                base_model_path = MODEL_PATH  # default: the notebook's base model

            try:
                # First try to load the checkpoint as a PEFT (LoRA) adapter on top of the base model
                base_model = AutoModelForCausalLM.from_pretrained(
                    base_model_path,
                    local_files_only=True,
                    torch_dtype=torch.bfloat16
                )
                model = PeftModel.from_pretrained(base_model, args.checkpoint_path)
                print("✅ Loaded as PEFT model")
            except Exception as peft_error:
                # `except Exception` rather than a bare `except`, so that
                # KeyboardInterrupt/SystemExit are not swallowed; the reason
                # for the fallback is reported instead of being hidden.
                print(f"⚠️ PEFT loading failed ({peft_error}); falling back to a regular model")
                # NOTE(review): trust_remote_code=True executes code shipped with
                # the checkpoint; only use it with checkpoints you trust.
                model = AutoModelForCausalLM.from_pretrained(
                    args.checkpoint_path,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True
                )
                print("✅ Loaded as regular model")

            tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            print("🚀 Starting inference...")
            evaluation.inference(args, model, tokenizer)

            if args.mia_data_path is not None:
                print("🔍 Starting MIA attacks...")
                evaluation.mia_attacks(args, model, tokenizer)

        if accelerator.is_main_process:
            print("📊 Computing metrics...")
            evaluation.compute_metrics(args)
            print("✅ Evaluation completed!")

    except Exception as e:
        # Deliberate best-effort boundary: report the failure, do not crash the notebook
        print(f"❌ Error during evaluation: {e}")
        import traceback
        traceback.print_exc()

# === Step 4: run the evaluation ===
print("🎯 Starting evaluation process...")

# Check the required inputs up front so a missing file yields a targeted hint
validation_ready = os.path.exists("validation/forget.jsonl") and os.path.exists("validation/retain.jsonl")
checkpoint_ready = os.path.exists("balanced_results/balanced_model/")

if not validation_ready:
    print("❌ Validation files not found")
    print("   Expected: validation/forget.jsonl and validation/retain.jsonl")
    print("   Make sure the data processing completed successfully")
elif not checkpoint_ready:
    print("❌ Model checkpoint not found at balanced_results/balanced_model/")
    print("   Make sure the training completed successfully")
else:
    run_evaluation(
        data_path="validation/",  # relative folder holding forget.jsonl and retain.jsonl
        checkpoint_path="balanced_results/balanced_model/",  # relative folder holding the model weights
        output_dir="eval_results",
        debug=True  # enable debug output
    )