# Unlearning DualTeacher


## 1. Setup e Import

In [1]:
!pip install rouge-score
import torch
import pandas as pd
import numpy as np
import json
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from rouge_score import rouge_scorer

# Configurazioni
MODEL_PATH = "/kaggle/input/olmo-model/semeval25-unlearning-1B-model"
DATA_PATH = "/kaggle/input/olmo-model/semeval25-unlearning-data"

print(f"GPUs disponibili: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

Collecting rouge-score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge-score
  Building wheel for rouge-score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge-score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=6a1623a469f08851a1994092de20c4c54ce9c65bfd3beef88799342baee9d2b4
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge-score
Installing collected packages: rouge-score
Successfully installed rouge-score-0.1.2


2025-08-16 15:54:35.712752: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755359676.074921      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755359676.176076      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


GPUs disponibili: 2
GPU 0: Tesla T4
GPU 1: Tesla T4


## 2. Caricamento Dati e Modelli

In [2]:
# Caricamento dataset
retain_train_df = pd.read_parquet(f"{DATA_PATH}/data/retain_train-00000-of-00001.parquet", engine='pyarrow')
retain_validation_df = pd.read_parquet(f"{DATA_PATH}/data/retain_validation-00000-of-00001.parquet", engine='pyarrow')
forget_train_df = pd.read_parquet(f"{DATA_PATH}/data/forget_train-00000-of-00001.parquet", engine='pyarrow')
forget_validation_df = pd.read_parquet(f"{DATA_PATH}/data/forget_validation-00000-of-00001.parquet", engine='pyarrow')

# Salvataggio in formato JSONL
!mkdir -p train validation
retain_train_df.to_json('train/retain.jsonl', orient='records', lines=True)
forget_train_df.to_json('train/forget.jsonl', orient='records', lines=True)
retain_validation_df.to_json('validation/retain.jsonl', orient='records', lines=True)
forget_validation_df.to_json('validation/forget.jsonl', orient='records', lines=True)

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Dataset salvati e tokenizer caricato")

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

Dataset salvati e tokenizer caricato


## 3. Dataset

In [None]:
class UnlearningDataset(Dataset):
    def __init__(self, data_source, tokenizer, max_length=256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        if isinstance(data_source, pd.DataFrame):
            self.data = data_source
            print(f"Caricati {len(self.data)} esempi dal DataFrame")
        elif isinstance(data_source, str):
            data_list = []
            with open(data_source, 'r', encoding='utf-8') as f:
                for line in f:
                    item = json.loads(line.strip())
                    data_list.append(item)
            self.data = pd.DataFrame(data_list)
            print(f"Caricati {len(self.data)} esempi da {data_source}")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        prompt = self.tokenizer(
            self.data.iloc[idx]["input"],
            padding="max_length",
            truncation=True, 
            max_length=512, 
            return_tensors=None
        )
        labels = self.tokenizer(
            f"{self.data.iloc[idx]['input']} {self.data.iloc[idx]['output']}",
            padding="max_length",
            truncation=True, 
            max_length=512, 
            return_tensors=None
        )
        attention_mask = prompt["attention_mask"]
        start_locs = self.tokenizer(self.data.iloc[idx]["input"])

        return {
            "input_ids": torch.tensor(prompt["input_ids"]),
            "attention_mask": torch.tensor(attention_mask),
            "start_locs": len(start_locs["input_ids"]) - 1,
            "labels": torch.tensor(labels["input_ids"]),
            "split": 1 if self.data.iloc[idx]["split"] == "forget" else 0,
        }

In [None]:
# Create dataset and dataloader
batch_size = 4
train_data = pd.concat([retain_train_df, forget_train_df], ignore_index=True)
dataset = UnlearningDataset(train_data, tokenizer)
dataloader = DataLoader(dataset, batch_size, shuffle=True)

print(f"Dataset creato con {len(dataset)} esempi")

Caricati 2248 esempi dal DataFrame
Dataset creato con 2248 esempi


## 4. DualTeacher Trainer Class

In [None]:
class DualTeacherTrainer:
    def __init__(self, model_path, tokenizer, teacher_lora_config, student_lora_config, device_map=None):
        self.model_path = model_path
        self.tokenizer = tokenizer
        self.teacher_lora_config = teacher_lora_config
        self.student_lora_config = student_lora_config
        self.device_map = device_map or {"student": "cuda:0", "teacher": "cuda:1"}
        
        self.good_teacher = None
        self.student_model = None
        self.initial_state_dict = {}
        
    def setup_models(self):
        """Initialize and setup both teacher and student models"""
        print("🔧 Setting up models...")
        
        # Setup good teacher with LoRA (for training)
        base_model = AutoModelForCausalLM.from_pretrained(self.model_path, local_files_only=True)
        self.good_teacher = get_peft_model(base_model, self.teacher_lora_config)
        self.good_teacher = self.good_teacher.to(self.device_map["teacher"])
        self.good_teacher.print_trainable_parameters()
        
        # Setup student model with LoRA
        base_model_student = AutoModelForCausalLM.from_pretrained(self.model_path, local_files_only=True)
        self.student_model = get_peft_model(base_model_student, self.student_lora_config)
        self.student_model = self.student_model.to(self.device_map["student"])
        self.student_model.print_trainable_parameters()
        
        # Save initial state for task vector calculation
        for name, param in self.student_model.named_parameters():
            if param.requires_grad:
                self.initial_state_dict[name] = param.data.clone()
        
        print("✅ Models setup completed")
    
    def merge_and_unload_teacher(self):
        """Convert good teacher from LoRA to base model after training"""
        print("🔄 Converting good teacher to base model...")
        
        # Merge LoRA weights into base model
        merged_model = self.good_teacher.merge_and_unload()
        
        # Replace the LoRA model with the merged base model
        self.good_teacher = merged_model.to(self.device_map["teacher"])
        
        # Set to eval mode and freeze
        self.good_teacher.eval()
        for param in self.good_teacher.parameters():
            param.requires_grad = False
            
        print("✅ Good teacher converted to base model and frozen")
        
    def create_bad_teacher_logits(self, good_teacher_outputs):
        """Create bad teacher logits for unlearning"""
        device = self.device_map["teacher"]
        return torch.normal(
            mean=0, std=1, size=good_teacher_outputs.logits.shape
        ).to(device) + torch.ones(good_teacher_outputs.logits.shape[-1]).to(device)
    
    def compute_kl_divergence(self, batch):
        """Compute KL divergence loss for dual teacher training"""
        student_device = self.device_map["student"]
        teacher_device = self.device_map["teacher"]
        
        # Student forward pass (with LoRA)
        student_outputs = self.student_model(
            batch["input_ids"].to(student_device),
            attention_mask=batch["attention_mask"].to(student_device),
            labels=batch["labels"].to(student_device),
        )
        
        # Teacher forward pass (base model, no LoRA)
        with torch.no_grad():
            good_teacher_outputs = self.good_teacher(
                batch["input_ids"].to(teacher_device),
                attention_mask=batch["attention_mask"].to(teacher_device),
                labels=batch["labels"].to(teacher_device),
            )
        
        # Create bad teacher and combine
        l = torch.unsqueeze(batch["split"], -1)
        l = torch.unsqueeze(l, -1)
        
        bad_teacher = self.create_bad_teacher_logits(good_teacher_outputs)
        prob_p = torch.nn.functional.softmax(bad_teacher, -1)
        prob_f = torch.nn.functional.softmax(good_teacher_outputs.logits, -1)
        prob_q = torch.nn.functional.softmax(student_outputs.logits, -1).to(teacher_device)
        
        out_teacher = (1 - l.to(teacher_device)) * prob_f + l.to(teacher_device) * prob_p
        
        loss = (out_teacher * (torch.log(out_teacher + 1e-12) - torch.log(prob_q + 1e-12))).sum(-1).mean()
        
        return loss
    
    def train_good_teacher(self, dataloader, num_epochs=2, lr=1e-4):
        """Train the good teacher on retain samples using LoRA"""
        print("🚀 Training good teacher with LoRA...")
        
        self.good_teacher.train()
        optimizer = torch.optim.Adam(self.good_teacher.parameters(), lr=lr)
        
        for epoch in range(num_epochs):
            print(f"📅 Epoca {epoch + 1}/{num_epochs} - Good Teacher Training")
            
            epoch_losses = []
            retain_batches_processed = 0
            
            with tqdm(total=len(dataloader), desc=f"Good Teacher Epoca {epoch+1}") as pbar:
                for batch in dataloader:
                    # Filter retain samples only
                    split = batch['split']
                    retain_mask = (split == 0)
                    
                    if not retain_mask.any():
                        pbar.update(1)
                        continue
                    
                    # Extract retain samples
                    input_ids = batch['input_ids'][retain_mask].to(self.device_map["teacher"])
                    attention_mask = batch['attention_mask'][retain_mask].to(self.device_map["teacher"])
                    labels = batch['labels'][retain_mask].to(self.device_map["teacher"])
                    
                    if input_ids.size(0) == 0:
                        pbar.update(1)
                        continue
                    
                    optimizer.zero_grad()
                    outputs = self.good_teacher(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    loss = outputs.loss
                    loss.backward()
                    optimizer.step()
                    
                    epoch_losses.append(loss.item())
                    retain_batches_processed += 1
                    pbar.update(1)
                    
                    if retain_batches_processed % 100 == 0:
                        pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
            
            if epoch_losses:
                avg_loss = np.mean(epoch_losses)
                print(f"📊 Good Teacher Epoca {epoch+1} - Loss medio: {avg_loss:.4f}")
        
        print("✅ Good teacher training completed")
        
        # Convert to base model after training
        self.merge_and_unload_teacher()
        
    def train_student(self, dataloader, num_epochs=4, lr=1e-4):
        """Train student model with dual teacher approach (student uses LoRA, teacher is base model)"""
        print("🚀 Training student with LoRA against base model teacher...")
        
        self.student_model.train()
        optimizer = torch.optim.AdamW(self.student_model.parameters(), lr=lr, weight_decay=0.01)
        
        for epoch in range(num_epochs):
            epoch_losses = []
            
            with tqdm(total=len(dataloader), desc=f"Student Epoca {epoch+1}") as pbar:
                for batch in dataloader:
                    optimizer.zero_grad()
                    loss = self.compute_kl_divergence(batch)
                    loss.backward()
                    optimizer.step()
                    
                    epoch_losses.append(loss.item())
                    pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
                    pbar.update(1)
            
            avg_loss = np.mean(epoch_losses)
            print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")
            
            # Save model after each epoch
            self.save_model(f"studentmodel_epoch_{epoch+1}")
        
        print("✅ Student training completed")
        
    def save_model(self, save_path):
        """Save student model and tokenizer"""
        self.student_model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)
        
    def save_good_teacher(self, save_path):
        """Save good teacher base model"""
        self.good_teacher.save_pretrained(save_path)
        
    def calculate_task_vector(self):
        """Calculate task vector from initial to final state"""
        task_vector = {}
        for name, param in self.student_model.named_parameters():
            if param.requires_grad and name in self.initial_state_dict:
                task_vector[name] = param.data - self.initial_state_dict[name]
        return task_vector

## 5. Setup Trainer and Training

In [6]:
# Configure LoRA for teacher and student
teacher_lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    bias="none",
)

student_lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
)

# Initialize trainer
trainer = DualTeacherTrainer(
    model_path=MODEL_PATH,
    tokenizer=tokenizer,
    teacher_lora_config=teacher_lora_config,
    student_lora_config=student_lora_config,
    device_map={"student": "cuda:0", "teacher": "cuda:1"}
)

# Setup models
trainer.setup_models()

🔧 Setting up models...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 8,388,608 || all params: 1,288,175,616 || trainable%: 0.6512


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 4,194,304 || all params: 1,283,981,312 || trainable%: 0.3267
✅ Models setup completed


In [None]:
# Train good teacher (will be converted to base model after training)
trainer.train_good_teacher(dataloader, num_epochs=2, lr=1e-4)

# Save good teacher
trainer.save_good_teacher("good_teacher_trained")

🚀 Training good teacher with LoRA...
📅 Epoca 1/2 - Good Teacher Training


Good Teacher Epoca 1:  85%|████████▍ | 475/562 [13:53<03:07,  2.16s/it, Loss=7.1021]

In [None]:
# Train student with dual teacher approach
# Student uses LoRA, good teacher is now base model
trainer.train_student(dataloader, num_epochs=4, lr=1e-4)

# Save final model
trainer.save_model("studentmodel_final")

## 6. Save Results and Task Vector

In [None]:
# Create results directory
os.makedirs('balanced_results', exist_ok=True)

# Save student model
trainer.save_model('balanced_results/balanced_model')

# Calculate and save task vector
task_vector = trainer.calculate_task_vector()
torch.save(task_vector, 'balanced_results/task_vector.pt')

print("✅ Results saved in balanced_results/")
print("- balanced_model/: Student model with unlearning")
print("- task_vector.pt: Task vector for future applications")

# 7. Evaluation



In [None]:

import sys
import types

try:
    import evaluation
    import importlib
    importlib.reload(evaluation)
except ImportError:
    pass

def run_evaluation(
    data_path,
    checkpoint_path,
    output_dir="eval_results",
    mia_data_path=None,
    mmlu_metrics_file_path=None,
    max_new_tokens=256,
    batch_size=25,
    debug=False,
    compute_metrics_only=False,
    seed=42,
    keep_files=False,
):
    try:
        # Costruiamo un oggetto args simile a quello di argparse
        args = types.SimpleNamespace(
            data_path=data_path,
            checkpoint_path=checkpoint_path,
            output_dir=output_dir,
            mia_data_path=mia_data_path,
            mmlu_metrics_file_path=mmlu_metrics_file_path,
            max_new_tokens=max_new_tokens,
            batch_size=batch_size,
            debug=debug,
            compute_metrics_only=compute_metrics_only,
            seed=seed,
            keep_files=keep_files,
        )

        # Verifica che i file esistano
        print(f"🔍 Verificando paths...")
        print(f"  Data path: {data_path}")
        print(f"  Checkpoint path: {checkpoint_path}")
        print(f"  Output dir: {output_dir}")
        
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Data path not found: {data_path}")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
        if not os.path.exists(os.path.join(data_path, 'forget.jsonl')):
            raise FileNotFoundError(f"forget.jsonl not found in {data_path}")
        if not os.path.exists(os.path.join(data_path, 'retain.jsonl')):
            raise FileNotFoundError(f"retain.jsonl not found in {data_path}")

        # Normalizza i path (come nello script originale)
        from pathlib import Path
        if args.output_dir is None:
            args.output_dir = os.getcwd()
        else:
            args.output_dir = args.output_dir.rstrip('/')
            Path(args.output_dir).mkdir(parents=True, exist_ok=True)

        # Lancia direttamente le funzioni
        import random, torch, numpy as np
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        from accelerate import Accelerator
        accelerator = Accelerator()

        if not args.compute_metrics_only:
            from transformers import AutoModelForCausalLM, AutoTokenizer
            from peft import PeftModel, LoraConfig
            
            print(f"📥 Loading model from {args.checkpoint_path}...")
            
            # Carica il modello PEFT (LoRA) se è salvato come tale
            try:
                # Prima prova a caricare come modello PEFT
                base_model_path = MODEL_PATH  # Usa il path del modello base
                base_model = AutoModelForCausalLM.from_pretrained(
                    base_model_path, 
                    local_files_only=True,
                    torch_dtype=torch.bfloat16
                )
                model = PeftModel.from_pretrained(base_model, args.checkpoint_path)
                print("✅ Loaded as PEFT model")
            except:
                # Se fallisce, prova a caricare come modello normale
                model = AutoModelForCausalLM.from_pretrained(
                    args.checkpoint_path,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True
                )
                print("✅ Loaded as regular model")
            
            tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            print("🚀 Starting inference...")
            evaluation.inference(args, model, tokenizer)
            
            if args.mia_data_path is not None:
                print("🔍 Starting MIA attacks...")
                evaluation.mia_attacks(args, model, tokenizer)

        if accelerator.is_main_process:
            print("📊 Computing metrics...")
            evaluation.compute_metrics(args)
            print("✅ Evaluation completed!")

    except Exception as e:
        print(f"❌ Error during evaluation: {e}")
        import traceback
        traceback.print_exc()

# === Step 4: Esegui evaluation ===
print("🎯 Starting evaluation process...")

# Verifica che i file esistano prima di iniziare
if os.path.exists("validation/forget.jsonl") and os.path.exists("validation/retain.jsonl"):
    if os.path.exists("balanced_results/balanced_model/"):
        run_evaluation(
            data_path="validation/",  # cartella relativa con forget.jsonl e retain.jsonl
            checkpoint_path="balanced_results/balanced_model/",  # cartella relativa con i pesi del modello
            output_dir="eval_results",
            debug=True  # Attiva debug per vedere cosa succede
        )
    else:
        print("❌ Model checkpoint not found at balanced_results/balanced_model/")
        print("   Make sure the training completed successfully")
else:
    print("❌ Validation files not found")
    print("   Expected: validation/forget.jsonl and validation/retain.jsonl")
    print("   Make sure the data processing completed successfully")