# DualTeacher Unlearning


## 1. Setup and Imports

In [None]:
# Install required packages for Kaggle environment (safe if already installed)
%pip install -q rouge-score

import os
import json
import shutil
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from rouge_score import rouge_scorer
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType, PeftModel

# Paths and configuration for Kaggle environment
MODEL_PATH = "/kaggle/input/olmo-model/semeval25-unlearning-1B-model"
DATA_PATH = "/kaggle/input/olmo-model/semeval25-unlearning-data"
MIA_VAL_PATH = "/kaggle/input/mia-dataset-val"
MIA_TRAIN_PATH = "/kaggle/input/mia-dataset"
GOOD_TEACHER_PATH = "/kaggle/input/good-teacher"

STUDENT_TRAINED = "/kaggle/input/student-trained"
STUDENT_PATH = "/kaggle/working/studentmodel_final"
Path(STUDENT_PATH).mkdir(parents=True, exist_ok=True)

# Copy any pre-trained student artifacts from the input dir into the working dir.
# Re-running simply overwrites with identical copies. A missing input dir is skipped
# (instead of crashing in iterdir()), and sub-directories are copied recursively
# (shutil.copyfile would raise IsADirectoryError on them).
dir_path = Path(STUDENT_TRAINED)
if dir_path.is_dir():
    for entry in dir_path.iterdir():
        target = Path(STUDENT_PATH) / entry.name
        if entry.is_dir():
            shutil.copytree(entry, target, dirs_exist_ok=True)
        else:
            shutil.copyfile(entry, target)
else:
    print(f"No pre-trained student found at {STUDENT_TRAINED}; skipping copy")

# Report visible GPUs in Kaggle runtime
print(f"Available GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

## 2. Data and Model Loading

In [None]:
# Read the four parquet splits (retain/forget x train/validation) from the Kaggle input
retain_train_df = pd.read_parquet(f"{DATA_PATH}/data/retain_train-00000-of-00001.parquet", engine='pyarrow')
retain_validation_df = pd.read_parquet(f"{DATA_PATH}/data/retain_validation-00000-of-00001.parquet", engine='pyarrow')
forget_train_df = pd.read_parquet(f"{DATA_PATH}/data/forget_train-00000-of-00001.parquet", engine='pyarrow')
forget_validation_df = pd.read_parquet(f"{DATA_PATH}/data/forget_validation-00000-of-00001.parquet", engine='pyarrow')

# Mirror every split to JSONL under train/ and validation/ so the evaluation
# script can read them without shell commands
for split_dir in ('train', 'validation'):
    Path(split_dir).mkdir(parents=True, exist_ok=True)

jsonl_exports = (
    (retain_train_df, 'train/retain.jsonl'),
    (forget_train_df, 'train/forget.jsonl'),
    (retain_validation_df, 'validation/retain.jsonl'),
    (forget_validation_df, 'validation/forget.jsonl'),
)
for split_frame, jsonl_path in jsonl_exports:
    split_frame.to_json(jsonl_path, orient='records', lines=True)

# Tokenizer; reuse EOS as the padding token when the tokenizer defines none
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Datasets saved and tokenizer loaded")

## 3. Dataset

In [None]:
class UnlearningDataset(Dataset):
    """Dataset of (input, output) pairs tokenized as one causal-LM sequence.

    Each row needs ``input`` and ``output`` text fields and may carry an optional
    ``split`` field (``'forget'`` marks a forget sample; anything else, or a
    missing value, is treated as retain).

    Args:
        data_source: a ``pandas.DataFrame``, or a path (``str`` / ``os.PathLike``)
            to a JSONL file with one JSON record per line.
        tokenizer: Hugging Face-style tokenizer callable returning ``input_ids``
            and ``attention_mask`` tensors.
        max_length: combined sequences are truncated / padded to this length.

    Raises:
        TypeError: if ``data_source`` is neither a DataFrame nor a path.
    """

    def __init__(self, data_source, tokenizer, max_length=256):
        self.tokenizer = tokenizer
        self.max_length = max_length

        if isinstance(data_source, pd.DataFrame):
            self.data = data_source
            print(f"Loaded {len(self.data)} samples from DataFrame")
        elif isinstance(data_source, (str, os.PathLike)):
            with open(data_source, 'r', encoding='utf-8') as f:
                # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads
                records = [json.loads(line) for line in f if line.strip()]
            self.data = pd.DataFrame(records)
            print(f"Loaded {len(self.data)} samples from {data_source}")
        else:
            # Previously fell through silently, leaving self.data unset
            raise TypeError(
                "data_source must be a pandas DataFrame or a path to a JSONL file, "
                f"got {type(data_source).__name__}"
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return tokenized tensors for row ``idx``.

        Keys: ``input_ids`` / ``attention_mask`` (length ``max_length``),
        ``labels`` (``input_ids`` with padding set to -100), ``start_locs``
        (index of the first output token, clamped to ``max_length``) and
        ``split`` (1 for forget, 0 for retain).
        """
        item = self.data.iloc[idx]
        input_text = item["input"]
        output_text = item["output"]

        # Single tokenization of concatenated input and output
        tokenized = self.tokenizer(
            f"{input_text} {output_text}",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = tokenized["input_ids"].squeeze(0)
        attention_mask = tokenized["attention_mask"].squeeze(0)

        # Prompt length = index where the output begins; clamp so it never points
        # past the end of a truncated sequence.
        prompt_len = self.tokenizer(input_text, return_tensors="pt")["input_ids"].size(-1)
        start_loc = min(prompt_len, self.max_length)

        # Padding must not contribute to the LM loss: -100 is the ignore index
        # (same convention as the evaluation script's target masking). Clone so
        # labels do not alias input_ids.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "start_locs": start_loc,
            "labels": labels,
            "split": 1 if item.get("split", "retain") == "forget" else 0,
        }


In [None]:
# Wrap each split in an UnlearningDataset and a DataLoader of this batch size
batch_size = 4

# Training: retain and forget kept apart (trained in separate phases), reshuffled each epoch
retain_train_dataset = UnlearningDataset(retain_train_df, tokenizer)
forget_train_dataset = UnlearningDataset(forget_train_df, tokenizer)
retain_train_dataloader = DataLoader(retain_train_dataset, batch_size=batch_size, shuffle=True)
forget_train_dataloader = DataLoader(forget_train_dataset, batch_size=batch_size, shuffle=True)

# Validation: fixed order so per-epoch losses are comparable
retain_val_dataset = UnlearningDataset(retain_validation_df, tokenizer)
forget_val_dataset = UnlearningDataset(forget_validation_df, tokenizer)
retain_val_dataloader = DataLoader(retain_val_dataset, batch_size=batch_size, shuffle=False)
forget_val_dataloader = DataLoader(forget_val_dataset, batch_size=batch_size, shuffle=False)

## 4. DualTeacher Trainer Class

In [None]:
class DualTeacherTrainer:
    """Unlearning by distillation from a "good" and a "bad" teacher.

    A LoRA student (on ``device_map["student"]``) is trained to match the good
    teacher's next-token distribution on retain samples, and a noisy
    near-uniform "bad teacher" distribution on forget samples. The good teacher
    lives on ``device_map["teacher"]`` and all losses are computed there.
    """

    def __init__(self, model_path, tokenizer, teacher_lora_config, student_lora_config, device_map=None):
        self.model_path = model_path
        self.tokenizer = tokenizer
        self.teacher_lora_config = teacher_lora_config
        self.student_lora_config = student_lora_config
        self.device_map = device_map or {"student": "cuda:0", "teacher": "cuda:1"}

        self.good_teacher = None
        self.student_model = None
        # Clone of the student's trainable weights taken right after setup
        self.initial_state_dict = {}

        # Validation tracking (lower combined validation loss is better)
        self.best_val_loss = float('inf')
        self.best_epoch = 0

        # Sequential training parameters
        self.forget_weight_multiplier = 0.5  # Reduce forget influence to prevent catastrophic forgetting

    # ------------------------------------------------------------------ setup

    def _load_base_model(self):
        """Load a fresh copy of the base causal LM from ``self.model_path`` (local files only)."""
        return AutoModelForCausalLM.from_pretrained(self.model_path, local_files_only=True)

    def setup_models(self, skip_teacher_setup=False):
        """Create the LoRA student and, unless skipped, a trainable LoRA good teacher.

        Teacher and student each get their own base-model instance:
        ``get_peft_model`` injects adapter layers into the model it is given, so
        sharing one instance would make both wrappers share weights and the
        student's ``.to()`` would also move the teacher.

        Raises:
            RuntimeError: if fewer than two CUDA devices are visible.
        """
        # Ensure two GPUs are available (Kaggle dual GPU runtime)
        if torch.cuda.device_count() < 2:
            raise RuntimeError("This notebook expects 2 GPUs. Please enable a 2-GPU accelerator in Kaggle settings.")
        print("Setting up models...")

        if not skip_teacher_setup:
            # Good teacher with LoRA (for training)
            self.good_teacher = get_peft_model(self._load_base_model(), self.teacher_lora_config)
            self.good_teacher = self.good_teacher.to(self.device_map["teacher"])
            self.good_teacher.print_trainable_parameters()

        # Student model with LoRA
        self.student_model = get_peft_model(self._load_base_model(), self.student_lora_config)
        self.student_model = self.student_model.to(self.device_map["student"])
        self.student_model.print_trainable_parameters()

        # Save initial state for task vector calculation
        for name, param in self.student_model.named_parameters():
            if param.requires_grad:
                self.initial_state_dict[name] = param.data.clone()

        print("Models setup completed")

    # ---------------------------------------------------------- loss helpers

    def create_bad_teacher_logits(self, good_teacher_logits):
        """Create bad-teacher logits as noisy values around a uniform distribution.

        Returns zeros (uniform after softmax) plus N(0, 0.1^2) noise, with the
        same shape/device/dtype as ``good_teacher_logits``.
        """
        uniform_logits = torch.zeros_like(good_teacher_logits)
        return uniform_logits + 0.1 * torch.randn_like(good_teacher_logits)

    def _student_and_teacher_outputs(self, batch):
        """Run the student and the (no-grad) good teacher on ``batch``.

        Returns ``(student_log_probs, good_teacher_logits, token_mask)``, all on the
        teacher device; ``token_mask`` is the attention mask as floats
        (1 = real token, 0 = padding).
        """
        student_device = self.device_map["student"]
        teacher_device = self.device_map["teacher"]

        student_logits = self.student_model(
            batch["input_ids"].to(student_device),
            attention_mask=batch["attention_mask"].to(student_device),
        ).logits
        student_log_probs = F.log_softmax(student_logits, dim=-1).to(teacher_device)

        attention_mask_teacher = batch["attention_mask"].to(teacher_device)
        with torch.no_grad():
            good_teacher_logits = self.good_teacher(
                batch["input_ids"].to(teacher_device),
                attention_mask=attention_mask_teacher,
            ).logits

        token_mask = attention_mask_teacher.to(student_log_probs.dtype)
        return student_log_probs, good_teacher_logits, token_mask

    @staticmethod
    def _masked_kl(student_log_probs, target_logits, token_mask):
        """KL(target || student), summed over vocab and real tokens, divided by batch size.

        Same scaling as ``kl_div(..., reduction="batchmean")`` on (B, T, V) inputs
        (sum over tokens, mean over batch), but padded positions contribute nothing.
        """
        target_log_probs = F.log_softmax(target_logits, dim=-1)
        per_token = F.kl_div(
            student_log_probs, target_log_probs, reduction="none", log_target=True
        ).sum(dim=-1)
        return (per_token * token_mask).sum() / student_log_probs.size(0)

    @staticmethod
    def _masked_entropy(student_log_probs, token_mask):
        """Mean per-token entropy of the student over real (unpadded) tokens."""
        per_token = -(student_log_probs.exp() * student_log_probs).sum(dim=-1)
        return (per_token * token_mask).sum() / token_mask.sum().clamp(min=1.0)

    # ----------------------------------------------------------------- losses

    def compute_retain_only_loss(self, batch):
        """Loss for a retain batch: 3 * KL(good teacher || student) + 0.1 * entropy."""
        student_log_probs, good_teacher_logits, token_mask = self._student_and_teacher_outputs(batch)

        retain_kl = self._masked_kl(student_log_probs, good_teacher_logits, token_mask)
        entropy = self._masked_entropy(student_log_probs, token_mask)

        # Adding entropy to a minimised loss penalises it -> encourages confidence.
        return 3.0 * retain_kl + 0.1 * entropy

    def compute_forget_only_loss(self, batch):
        """Loss for a forget batch: 2 * KL(bad teacher || student) - 0.1 * entropy."""
        student_log_probs, good_teacher_logits, token_mask = self._student_and_teacher_outputs(batch)
        with torch.no_grad():
            bad_teacher_logits = self.create_bad_teacher_logits(good_teacher_logits)

        forget_kl = self._masked_kl(student_log_probs, bad_teacher_logits, token_mask)
        entropy = self._masked_entropy(student_log_probs, token_mask)

        # Subtracting entropy rewards it -> encourages uncertainty on forget data.
        return 2.0 * forget_kl - 0.1 * entropy

    def compute_kl_divergence(self, batch):
        """Original mixed training method for batches holding both retain and forget rows.

        ``batch["split"]`` selects the target per sample (0 -> good teacher,
        1 -> bad teacher). KL is averaged over selected samples and all token
        positions, and a 0.2-weighted unmasked entropy term is added.
        """
        teacher_device = self.device_map["teacher"]
        student_log_probs, good_teacher_logits, _ = self._student_and_teacher_outputs(batch)

        with torch.no_grad():
            bad_teacher_logits = self.create_bad_teacher_logits(good_teacher_logits)
            good_teacher_probs = F.softmax(good_teacher_logits, dim=-1)
            bad_teacher_probs = F.softmax(bad_teacher_logits, dim=-1)

        # Masks retain/forget (per sample)
        split = batch["split"].float().to(teacher_device)
        retain_mask = split <= 0.5
        forget_mask = split > 0.5

        total_loss = 0.0
        if retain_mask.any():
            retain_kl = F.kl_div(
                student_log_probs[retain_mask],
                good_teacher_probs[retain_mask],
                reduction="none",
                log_target=False,
            ).sum(dim=-1).mean()  # sum over vocabulary, mean over samples and tokens
            total_loss += 3.0 * retain_kl  # retain_weight

        if forget_mask.any():
            forget_kl = F.kl_div(
                student_log_probs[forget_mask],
                bad_teacher_probs[forget_mask],
                reduction="none",
                log_target=False,
            ).sum(dim=-1).mean()
            total_loss += 1.5 * forget_kl  # forget_weight

        # Student entropy regularization
        entropy_loss = -(student_log_probs.exp() * student_log_probs).sum(-1).mean()
        return total_loss + 0.2 * entropy_loss

    # ------------------------------------------------------------- validation

    def validate_retain_only(self, retain_val_dataloader):
        """Mean retain loss over ``retain_val_dataloader`` (inf if it is empty)."""
        self.student_model.eval()
        val_losses = []

        with torch.no_grad():
            for batch in retain_val_dataloader:
                val_losses.append(self.compute_retain_only_loss(batch).item())

        self.student_model.train()
        return np.mean(val_losses) if val_losses else float('inf')

    def validate_forget_only(self, forget_val_dataloader):
        """Mean forget loss over ``forget_val_dataloader`` (inf if it is empty)."""
        self.student_model.eval()
        val_losses = []

        with torch.no_grad():
            for batch in forget_val_dataloader:
                val_losses.append(self.compute_forget_only_loss(batch).item())

        self.student_model.train()
        return np.mean(val_losses) if val_losses else float('inf')

    def validate_good_teacher(self, val_dataloader):
        """Mean LM loss of the good teacher on the retain rows of ``val_dataloader``.

        Called by ``train_good_teacher``. Batches without retain rows are skipped;
        returns inf when no retain rows are seen. Restores train mode afterwards.
        """
        teacher_device = self.device_map["teacher"]
        self.good_teacher.eval()
        val_losses = []

        with torch.no_grad():
            for batch in val_dataloader:
                retain_mask = batch['split'] == 0
                if not retain_mask.any():
                    continue
                outputs = self.good_teacher(
                    input_ids=batch['input_ids'][retain_mask].to(teacher_device),
                    attention_mask=batch['attention_mask'][retain_mask].to(teacher_device),
                    labels=batch['labels'][retain_mask].to(teacher_device),
                )
                val_losses.append(outputs.loss.item())

        self.good_teacher.train()
        return float(np.mean(val_losses)) if val_losses else float('inf')

    # --------------------------------------------------------------- training

    def train_student_sequential(self, retain_dataloader, forget_dataloader,
                               retain_val_dataloader=None, forget_val_dataloader=None,
                               num_epochs=6, lr=1e-5, val_freq=1, patience=3):
        """Train the student with a retain phase followed by a forget phase each epoch.

        Saves a checkpoint every epoch and the best-validation model. Stops early
        after ``patience`` validation checks without improvement.
        """
        print("Training student with SEQUENTIAL dual-teacher approach...")

        self.student_model.train()
        # Only LoRA parameters are trainable; frozen base weights never receive grads
        trainable_params = [p for p in self.student_model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=0.01)

        patience_counter = 0

        for epoch in range(num_epochs):
            retain_losses = []
            forget_losses = []

            # Calculate total batches for progress bar
            total_batches = len(retain_dataloader) + len(forget_dataloader)

            with tqdm(total=total_batches, desc=f"Student Epoch {epoch+1} (Sequential)") as pbar:

                # Phase 1: Train on retain batches
                pbar.set_description(f"Epoch {epoch+1} - Phase 1: Retain")
                for batch in retain_dataloader:
                    optimizer.zero_grad()
                    loss = self.compute_retain_only_loss(batch)
                    loss.backward()
                    optimizer.step()

                    retain_losses.append(loss.item())
                    pbar.update(1)
                    pbar.set_postfix({"Retain Loss": f"{loss.item():.4f}"})

                # Phase 2: Pure forget training (with reduced weight)
                pbar.set_description(f"Epoch {epoch+1} - Phase 2: Forget")
                for batch in forget_dataloader:
                    optimizer.zero_grad()
                    loss = self.compute_forget_only_loss(batch) * self.forget_weight_multiplier
                    loss.backward()
                    optimizer.step()

                    forget_losses.append(loss.item() / self.forget_weight_multiplier)  # Store unweighted for logging
                    pbar.update(1)
                    pbar.set_postfix({"Forget Loss": f"{loss.item():.4f}"})

            # Logging
            avg_retain_loss = np.mean(retain_losses) if retain_losses else 0.0
            avg_forget_loss = np.mean(forget_losses) if forget_losses else 0.0

            print(f"Student Epoch {epoch+1}")
            print(f"   └─ Retain: {avg_retain_loss:.4f}, Forget: {avg_forget_loss:.4f} (×{self.forget_weight_multiplier})")

            # Validation if validation dataloaders are provided
            if retain_val_dataloader is not None and forget_val_dataloader is not None and (epoch + 1) % val_freq == 0:
                print("Running validation...")
                retain_val_loss = self.validate_retain_only(retain_val_dataloader)
                forget_val_loss = self.validate_forget_only(forget_val_dataloader)

                # Both validation losses are the objectives the student minimises in
                # training, so lower is better for each. (A 1 / (1 + forget_loss) score
                # would reward a student *further* from the bad teacher, and it is
                # undefined for forget losses <= -1.)
                retain_score = retain_val_loss
                forget_score = forget_val_loss

                combined_val_loss = 0.7 * retain_score + 0.3 * forget_score
                print(f"Validation - Retain: {retain_val_loss:.4f}, Forget: {forget_val_loss:.4f}, Combined: {combined_val_loss:.4f}")

                # Check if this is the best model so far
                if combined_val_loss < self.best_val_loss:
                    self.best_val_loss = combined_val_loss
                    self.best_epoch = epoch + 1
                    patience_counter = 0
                    print(f"New best validation loss: {combined_val_loss:.4f}")

                    # Save best model
                    self.save_model("studentmodel_best_val_sequential")
                else:
                    patience_counter += 1
                    print(f"No improvement for {patience_counter} validation checks")

                # Early stopping
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1} (no improvement for {patience} validations)")
                    break

            # Save model after each epoch
            self.save_model(f"studentmodel_epoch_{epoch+1}_sequential")

        print("Student sequential training completed")
        print(f"Best validation loss: {self.best_val_loss:.4f} at epoch {self.best_epoch}")

    def train_good_teacher(self, dataloader, val_dataloader=None, num_epochs=2, lr=1e-4, save_path="good_teacher_adapter", val_freq=1):
        """Fine-tune the LoRA good teacher with its LM loss on retain rows only.

        Optionally validates every ``val_freq`` epochs, then saves only the LoRA
        adapter to ``save_path``.
        """
        print("Training good teacher with LoRA...")

        teacher_device = self.device_map["teacher"]
        self.good_teacher.to(teacher_device)
        self.good_teacher.train()
        optimizer = torch.optim.AdamW(self.good_teacher.parameters(), lr=lr)

        for epoch in range(num_epochs):
            print(f"Epoch {epoch + 1}/{num_epochs} - Good Teacher Training")

            epoch_losses = []
            retain_batches_processed = 0

            with tqdm(total=len(dataloader), desc=f"Good Teacher Epoch {epoch+1}") as pbar:
                for batch in dataloader:
                    # Filter retain samples only
                    retain_mask = batch['split'] == 0
                    if not retain_mask.any():
                        pbar.update(1)
                        continue

                    input_ids = batch['input_ids'][retain_mask].to(teacher_device)
                    attention_mask = batch['attention_mask'][retain_mask].to(teacher_device)
                    labels = batch['labels'][retain_mask].to(teacher_device)

                    optimizer.zero_grad()
                    outputs = self.good_teacher(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    loss = outputs.loss

                    loss.backward()
                    optimizer.step()

                    epoch_losses.append(loss.item())
                    retain_batches_processed += 1
                    pbar.update(1)

                    if retain_batches_processed % 100 == 0:
                        pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

            if epoch_losses:
                avg_loss = np.mean(epoch_losses)
                print(f"Good Teacher Epoch {epoch+1} - Average Loss: {avg_loss:.4f}")

                # Validation if validation dataloader is provided
                if val_dataloader is not None and (epoch + 1) % val_freq == 0:
                    print("Running validation...")
                    val_loss = self.validate_good_teacher(val_dataloader)
                    print(f"Validation Loss: {val_loss:.4f}")

        print("Good teacher training completed")

        # Save only LoRA adapter
        self.save_good_teacher(save_path)
        print(f"Good teacher adapter saved to {save_path}")

    # -------------------------------------------------------------------- I/O

    def save_good_teacher(self, save_path):
        """Save only the LoRA adapter of the good teacher"""
        self.good_teacher.save_pretrained(save_path)

    def save_model(self, save_path):
        """Save student model and tokenizer"""
        self.student_model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    def load_teacher(self, GOOD_TEACHER_PATH):
        """Load a pre-trained good teacher, freeze all its parameters and move it to the teacher device."""
        self.good_teacher = AutoModelForCausalLM.from_pretrained(GOOD_TEACHER_PATH)
        self.good_teacher.eval()  # freeze it
        for param in self.good_teacher.parameters():
            param.requires_grad = False
        self.good_teacher.to(self.device_map["teacher"])
        print("Good teacher loaded and frozen")

## 5. Trainer Setup and Training

In [None]:
# When True, the next cell builds a LoRA good teacher and fine-tunes it on retain data;
# when False, that step is skipped.
train_good_teacher: bool = False

In [None]:
# Configure LoRA for teacher and student: rank-16 adapters on the q/k/v/o projections
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    bias="none",
)


# Initialize trainer: student on cuda:0, teacher on cuda:1
trainer = DualTeacherTrainer(
    model_path=MODEL_PATH,
    tokenizer=tokenizer,
    teacher_lora_config=lora_config,
    student_lora_config=lora_config,
    device_map={"student": "cuda:0", "teacher": "cuda:1"}
)

if train_good_teacher:
    trainer.setup_models()
    # Fine-tune the good teacher on retain data; train_good_teacher() saves only the
    # LoRA adapter (to "good_teacher_adapter" by default).
    # NOTE: the next cell still loads the teacher from GOOD_TEACHER_PATH, not this adapter.
    trainer.train_good_teacher(retain_train_dataloader, retain_val_dataloader, num_epochs=5)
    

In [None]:
# Build only the LoRA student here; the teacher comes from the pre-trained checkpoint
trainer.setup_models(skip_teacher_setup=True)

# Frozen good teacher loaded from the Kaggle input and placed on the teacher GPU
trainer.load_teacher(GOOD_TEACHER_PATH)

# Sequential distillation: each epoch runs a retain phase, then a forget phase,
# followed by validation on both splits
trainer.train_student_sequential(
    retain_dataloader=retain_train_dataloader,
    forget_dataloader=forget_train_dataloader,
    retain_val_dataloader=retain_val_dataloader,
    forget_val_dataloader=forget_val_dataloader,
    num_epochs=6,
)

## 6. Evaluation

In [None]:
with open("evaluation.py", "w") as f:
    f.write(r"""
# Official evaluation script provided by the task organizers

import os
import sys
import json
import glob
import math
import torch
import random
import shutil
import argparse
import datasets
import numpy as np
import pandas as pd

from tqdm import tqdm
from pathlib import Path
from accelerate import Accelerator
from collections import defaultdict
from statistics import mean, harmonic_mean
from rouge_score import rouge_scorer
from torch.utils.data import DataLoader
from sklearn.metrics import roc_curve, auc
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

def get_args_and_verify():
    parser = argparse.ArgumentParser(description="Script to run inference and evaluation")
    parser.add_argument('--data_path', help="Path to unlearning dataset containing jsonl files")
    parser.add_argument('--checkpoint_path', help="Path to model checkpoint")
    parser.add_argument('--output_dir', required=False, default=None, help="Path to store inference files and evaluation results")
    parser.add_argument('--mia_data_path', required=False, default=None, help="Path to member and nonmember jsonl files for MIA attack")
    parser.add_argument('--mmlu_metrics_file_path', required=False, default=None, help="Path to metrics.json file generated by MMLU")
    parser.add_argument('--max_new_tokens', required=False, type=int, default=256, help='Maximum number of tokens to generate')
    parser.add_argument('--batch_size', required=False, type=int, default=25, help='Batch size for inference')
    parser.add_argument('--debug', required=False, default=False, action='store_true', help='Print detailed messages')
    parser.add_argument('--compute_metrics_only', required=False, default=False, action='store_true', help='Skip inference and compute metrics from inference files')
    parser.add_argument('--seed', required=False, default=42, help='Random seed for experiments')
    parser.add_argument('--keep_files', required=False, default=False, action='store_true', help='Retain intermediate files')
    args = parser.parse_args()

    if args.compute_metrics_only:
        args.keep_files = True

    if args.output_dir is None:
        args.output_dir = os.getcwd()
    else:
        args.output_dir = args.output_dir.rstrip('/')
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Verify data files exist
    assert(os.path.exists(args.data_path))
    assert(os.path.exists(os.path.join(args.data_path, 'forget.jsonl')))
    assert(os.path.exists(os.path.join(args.data_path, 'retain.jsonl')))

    # If specified, verify data files exist
    if args.mia_data_path is not None:
        assert(os.path.exists(args.mia_data_path))
        assert(os.path.exists(os.path.join(args.mia_data_path, 'member.jsonl')))
        assert(os.path.exists(os.path.join(args.mia_data_path, 'nonmember.jsonl')))
    else: # Add warning if MIA path not provided
        print("WARNING: MIA data path not provided. Final evaluation includes MIA; please rerun with this option for accurate performance. Proceeding for now.")
        
    if args.mmlu_metrics_file_path is not None:
        assert(os.path.exists(args.mmlu_metrics_file_path))
    else: # Add warning if MMLU file not provided
        print("WARNING: MMLU metrics file not provided, your final evaluation metric includes MMLU aggregate performance so please run this test to get the accurate performance. Proceeding for now")

    return args

def inference(args, model, tokenizer):
    forget_file = args.data_path + 'forget.jsonl'
    retain_file = args.data_path + 'retain.jsonl'

    accelerator = Accelerator()
    model.to(accelerator.device)

    for split, train_file in [('retain', retain_file), ('forget', forget_file)]:
        data_files = {}
        dataset_args = {}
        if train_file is not None:
            data_files["train"] = train_file
        raw_datasets = datasets.load_dataset(
            "json",
            data_files=data_files,
            **dataset_args,
        )
        train_dataset = raw_datasets["train"]

        output_dic = defaultdict(lambda :{'id': [], 'task': [], 'input': [], 'expected_output': [], 'model_output': [], 'nll': []})

        with accelerator.split_between_processes(train_dataset, apply_padding=True) as data:
            for idx in tqdm(range(len(data['input']))):
                question, answer = data["input"][idx], data["output"][idx]
                output_dic[accelerator.process_index]['id'].append(data["id"][idx])
                output_dic[accelerator.process_index]['task'].append(data["task"][idx])
                output_dic[accelerator.process_index]['input'].append(data["input"][idx])
                output_dic[accelerator.process_index]['expected_output'].append(data["output"][idx])
                input_ids = tokenizer(
                    question,
                    return_tensors='pt'
                ).input_ids.to(model.device)

                combined_input_ids = tokenizer(
                    question+answer,
                    return_tensors='pt'
                ).input_ids.to(model.device)
                combined_target_ids = combined_input_ids.clone()
                combined_target_ids[:,:len(input_ids[0])] = -100

                with torch.no_grad():
                    # Create attention mask to avoid warnings
                    attention_mask = torch.ones_like(input_ids)
                    out = model.generate(
                        input_ids, 
                        attention_mask=attention_mask,
                        max_new_tokens=args.max_new_tokens, 
                        do_sample=False, 
                        use_cache=True, 
                        pad_token_id=tokenizer.eos_token_id
                    )
                    output_ids = out[:, len(input_ids[0]):]
                    output = tokenizer.batch_decode(
                        output_ids,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=True)[0]
                    output_dic[accelerator.process_index]['model_output'].append(output)

                    # For Perplexity
                    out = model(combined_input_ids, labels=combined_target_ids)
                    if args.debug:
                        print(tokenizer.batch_decode(
                            torch.argmax(
                                torch.nn.functional.softmax(
                                    out.logits.clone().detach(),
                                    dim=2),
                                dim=2)[:, len(input_ids[0]):],
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=True)[0])
                    neg_log_likelihood = out.loss.item()
                    output_dic[accelerator.process_index]['nll'].append(neg_log_likelihood)

        accelerator.wait_for_everyone()
        
        if args.debug:
            print([len(value) for value in output_dic[accelerator.process_index].values()])
        output_df = pd.DataFrame.from_dict(output_dic[accelerator.process_index])
        
        output_file_name = f"{args.output_dir}/{split}_{accelerator.process_index}.csv"
        if args.debug:
            print('Saving to: ', output_file_name)
        output_df.to_csv(output_file_name, index=False)

def mia_attacks(args, model, tokenizer):
    member_file = args.mia_data_path + 'member.jsonl'
    nonmember_file = args.mia_data_path + 'nonmember.jsonl'

    accelerator = Accelerator()
    model.to(accelerator.device)

    for dataset, train_file in [('member', member_file), ('nonmember', nonmember_file)]:
        data_files = {}
        dataset_args = {}
        if train_file is not None:
            data_files["train"] = train_file
        raw_datasets = datasets.load_dataset(
            "json",
            data_files=data_files,
            **dataset_args,
        )
        train_dataset = raw_datasets["train"]

        output_dic = defaultdict(lambda :{'id': [], 'nll': []})

        with accelerator.split_between_processes(train_dataset, apply_padding=True) as data:
            for idx in tqdm(range(len(data['document']))):
                document = data["document"][idx]
                output_dic[accelerator.process_index]['id'].append(data["id"][idx])
                input_ids = tokenizer(
                    document,
                    return_tensors='pt'
                ).input_ids.to(model.device)

                target_ids = input_ids.clone()

                with torch.no_grad():
                    out = model(input_ids, labels=target_ids)
                    neg_log_likelihood = out.loss.item()
                    output_dic[accelerator.process_index]['nll'].append(neg_log_likelihood)

        accelerator.wait_for_everyone()
        
        output_df = pd.DataFrame.from_dict(output_dic[accelerator.process_index])
        
        results_dir = os.path.join(args.output_dir, 'mia_results')
        Path(results_dir).mkdir(parents=True, exist_ok=True)
        output_file_name = f"{results_dir}/{dataset}_{accelerator.process_index}.csv"
        if args.debug:
            print('Saving to: ', output_file_name)
        output_df.to_csv(output_file_name, index=False)

def compute_auc(member_loss, nonmember_loss):
    assert not np.any(np.isnan(member_loss))
    assert not np.any(np.isnan(nonmember_loss))
    combined_loss = member_loss + nonmember_loss 
    combined_loss = -1 * np.array(combined_loss)
    combined_labels = len(member_loss) * [1] + len(nonmember_loss) * [0]
    fp, tp, _ = roc_curve(combined_labels, combined_loss)

    auc_score = float(auc(fp, tp))

    return auc_score

def compute_metrics(args):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)

    results = {}
    aggregate_scores_list = []
    for split in ['forget', 'retain']:
        files = glob.glob(args.output_dir + '/{}_*.csv'.format(split))
        if len(files) == 0:
            print("[ERROR] Missing inference files, rerun script with inference first")
            return  # sys.exit(1) throws a long traceback so just return for now
        df_list = [pd.read_csv(f) for f in files]
        if not args.keep_files:
            _ = [os.remove(f) for f in files]
        df = pd.concat(df_list, ignore_index=True)

        df['regurgitation-score-rouge-1'] = None
        df['regurgitation-score'] = None
        df['knowledge-score'] = None
        ground_truths = df['expected_output'].tolist()
        gen_outputs = df['model_output'].tolist()

        for i, (gen, gt) in enumerate(zip(gen_outputs, ground_truths)):
            if df.loc[i, 'id'][:-1].endswith('sc'):
                rouge_scores = scorer.score(str(gt), str(gen))
                df.loc[i, 'regurgitation-score-rouge-1'] = rouge_scores['rouge1'].recall
                df.loc[i, 'regurgitation-score'] = rouge_scores['rougeL'].recall
            elif df.loc[i, 'id'][:-1].endswith('qa'):
                 df.loc[i, 'knowledge-score'] = int(str(gt).strip().lower() == str(gen).strip().lower())

        results[split+'-set'] = {'overall-regurgitation-score': np.mean(df['regurgitation-score']), 'overall-knowledge-score': np.mean(df['knowledge-score'])}
        split_aggregate_scores_dict = df.groupby('task')[['regurgitation-score', 'knowledge-score']].mean().to_dict(orient='index')
        results[split+'-set'].update(split_aggregate_scores_dict)
        split_aggregate_score_values = [float(val) for inner in split_aggregate_scores_dict.values() for val in inner.values()]
        if split == 'forget':
            split_aggregate_score_values = [(1 - val) for val in split_aggregate_score_values]

        aggregate_scores_list.extend(split_aggregate_score_values)

    if args.mia_data_path is not None:
        mia_results_dir = os.path.join(args.output_dir, 'mia_results')
        mia_results = {}
        for dataset in ['member', 'nonmember']:
            files = glob.glob(mia_results_dir + '/{}_*.csv'.format(dataset))
            if len(files) == 0:
                print("[ERROR] Missing mia files, rerun script with inference first")
                return  # sys.exit(1) throws a long traceback so just return for no
            df_list = [pd.read_csv(f) for f in files]
            df = pd.concat(df_list, ignore_index=True)
            mia_results[dataset] = df['nll'].tolist()
        
        if not args.keep_files:
            shutil.rmtree(mia_results_dir)

        auc = compute_auc(mia_results['member'], mia_results['nonmember'])
        # Best MIA rates we can get are ~0.5. 
        # Scores close to 1 suggest under-unlearning
        # Scores close to 0 suggest over-unlearning
        results['mia_loss_acc'] = auc
#        aggregate_scores_list.append(1 - auc) 

    if args.mmlu_metrics_file_path is not None:
        with open(args.mmlu_metrics_file_path) as inptr:
            mmlu_scores = json.loads(inptr.read())
        results['mmlu_average'] = mmlu_scores['average_acc']
#        aggregate_scores_list.append(mmlu_scores['average_acc'])
    
    results['aggregated-terms'] = aggregate_scores_list

    task_aggregate = harmonic_mean(aggregate_scores_list)
    results['aggregate-score'] = -1

    results['harmonic-mean-task-aggregate'] = task_aggregate

    # Need MMLU and MIA scores to compute the aggregate
    if 'mmlu_average' in results and 'mia_loss_acc' in results:
        if results['mmlu_average'] < 0.371:
            # MMLU score should not drop below 75% of pre-unlearning preformance
            print(f"[WARNING] The MMLU average for the provided checkpoint is below threshold. If this happens your model may not be considered in final challenge ranking.")

        mia_final_score = 1 - abs(results['mia_loss_acc'] - 0.5)*2
        results['mia_final_score'] = mia_final_score
        results['aggregate-score'] = mean([task_aggregate, results['mmlu_average'], mia_final_score])

    metrics_file = os.path.join(args.output_dir, 'evaluation_results.jsonl')
    with open(metrics_file, 'w') as outptr:
        outptr.write(json.dumps(results))

def main():
    args = get_args_and_verify()

    # Set random seed
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.debug:
        print('Evaluating Checkpoint at {}'.format(args.checkpoint_path))

    checkpoint_path = args.checkpoint_path

    # Set up accelerator
    accelerator = Accelerator()
    if not args.compute_metrics_only:
        model = AutoModelForCausalLM.from_pretrained(checkpoint_path, torch_dtype=torch.bfloat16, trust_remote_code = True) # .to('cuda')

        tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
        tokenizer.pad_token = tokenizer.eos_token
    
        inference(args, model, tokenizer)

        if args.mia_data_path is not None:
            mia_attacks(args, model, tokenizer)

    if accelerator.is_main_process:
        compute_metrics(args)

if __name__ == '__main__':
    main()""")
import sys
import types

# Import the evaluation module used by run_evaluation() below; reload it so a
# copy imported earlier in this kernel session picks up the latest source
# (presumably just written by the preceding string literal — confirm).
try:
    import evaluation
    import importlib
    importlib.reload(evaluation)
except ImportError as err:
    # run_evaluation() calls evaluation.inference/compute_metrics; report the
    # reason now instead of failing later with an unexplained NameError.
    print(f"[WARNING] Could not import the evaluation module: {err}")

def run_evaluation(
    data_path,
    checkpoint_path,
    output_dir="eval_results",
    mia_data_path=None,  # e.g. MIA_TRAIN_PATH to enable the MIA step
    mia_data_val_path=MIA_VAL_PATH,
    mmlu_metrics_file_path=None,
    max_new_tokens=256,
    batch_size=25,
    debug=False,
    compute_metrics_only=False,
    seed=42,
    keep_files=True,
):
    """Evaluate a checkpoint with the ``evaluation`` module from inside the notebook.

    Mirrors the module's command-line ``main()`` without argparse. It builds an
    argparse-like namespace and verifies the input paths. Unless
    ``compute_metrics_only`` is set, it then loads the model and runs inference,
    plus the MIA attack when ``mia_data_path`` is given. Finally it computes
    metrics on the main process. Any exception is printed with its traceback
    instead of propagating, so a failed evaluation does not abort the notebook.

    Args:
        data_path: Directory that must contain ``forget.jsonl`` and ``retain.jsonl``.
        checkpoint_path: Directory with either a PEFT/LoRA adapter (attached to
            the base model at ``MODEL_PATH``) or a full model.
        output_dir: Directory for results; ``None`` means the current directory.
        mia_data_path: Directory with ``member.jsonl``/``nonmember.jsonl``.
            ``None`` skips the MIA attack and MIA metrics.
        mia_data_val_path: Stored on ``args``; not read by the code shown here.
        mmlu_metrics_file_path: Optional MMLU results JSON read by compute_metrics.
        max_new_tokens: Generation length used by ``evaluation.inference``.
        batch_size: Stored on ``args`` for the evaluation module.
        debug: Enables the evaluation module's extra debug printing.
        compute_metrics_only: Skip model loading/inference and only score
            existing output files.
        seed: Seed for ``random``, NumPy and PyTorch.
        keep_files: If False, compute_metrics deletes the intermediate CSVs.
    """
    try:
        # Build an argparse-like args object, as evaluation.main() would receive.
        args = types.SimpleNamespace(
            data_path=data_path,
            checkpoint_path=checkpoint_path,
            output_dir=output_dir,
            mia_data_path=mia_data_path,
            mia_data_val_path=mia_data_val_path,
            mmlu_metrics_file_path=mmlu_metrics_file_path,
            max_new_tokens=max_new_tokens,
            batch_size=batch_size,
            debug=debug,
            compute_metrics_only=compute_metrics_only,
            seed=seed,
            keep_files=keep_files,
        )

        # Verify file paths exist
        print("Verifying paths...")
        print(f"  Data path: {data_path}")
        print(f"  Checkpoint path: {checkpoint_path}")
        print(f"  Output dir: {output_dir}")

        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Data path not found: {data_path}")
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint path not found: {checkpoint_path}")
        if not os.path.exists(os.path.join(data_path, 'forget.jsonl')):
            raise FileNotFoundError(f"forget.jsonl not found in {data_path}")
        if not os.path.exists(os.path.join(data_path, 'retain.jsonl')):
            raise FileNotFoundError(f"retain.jsonl not found in {data_path}")

        if args.mia_data_path is not None:
            if not args.compute_metrics_only:
                for mia_file in ('member.jsonl', 'nonmember.jsonl'):
                    if not os.path.exists(os.path.join(args.mia_data_path, mia_file)):
                        raise FileNotFoundError(f"{mia_file} not found in {args.mia_data_path}")
            # mia_attacks() builds file names as mia_data_path + 'member.jsonl',
            # so guarantee a trailing separator (MIA_TRAIN_PATH has none).
            args.mia_data_path = os.path.join(args.mia_data_path, '')

        # Normalize paths (as in the original script)
        from pathlib import Path
        if args.output_dir is None:
            args.output_dir = os.getcwd()
        else:
            args.output_dir = args.output_dir.rstrip('/')
            Path(args.output_dir).mkdir(parents=True, exist_ok=True)

        # Seed every RNG the evaluation code may touch.
        import random
        import numpy as np
        import torch
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        from accelerate import Accelerator
        accelerator = Accelerator()

        if not args.compute_metrics_only:
            from transformers import AutoModelForCausalLM, AutoTokenizer
            from peft import PeftModel

            print(f"Loading model from {args.checkpoint_path}...")

            model = None
            # PEFT saves adapters with an adapter_config.json; only then is it
            # worth loading the base model to attach the adapter to.
            if os.path.exists(os.path.join(args.checkpoint_path, 'adapter_config.json')):
                try:
                    base_model = AutoModelForCausalLM.from_pretrained(
                        MODEL_PATH,
                        local_files_only=True,
                        torch_dtype=torch.bfloat16,
                    )
                    model = PeftModel.from_pretrained(base_model, args.checkpoint_path)
                    print("Loaded as PEFT model")
                except Exception as err:
                    # Best-effort fallback as before, but report why, and drop
                    # the base-model reference before loading another model.
                    print(f"PEFT load failed ({err}); trying a regular model load")
                    model = base_model = None

            if model is None:
                model = AutoModelForCausalLM.from_pretrained(
                    args.checkpoint_path,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                )
                print("Loaded as regular model")

            try:
                tokenizer = AutoTokenizer.from_pretrained(args.checkpoint_path)
            except (OSError, ValueError) as err:
                # An adapter-only checkpoint may not include tokenizer files;
                # fall back to the base model's tokenizer in that case.
                print(f"Tokenizer not loadable from checkpoint ({err}); using {MODEL_PATH}")
                tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            print("Starting inference...")
            evaluation.inference(args, model, tokenizer)

            # compute_metrics() expects MIA result files whenever mia_data_path
            # is set, so the attack must run in that case.
            if args.mia_data_path is not None:
                print("Starting MIA attacks...")
                evaluation.mia_attacks(args, model, tokenizer)

        if accelerator.is_main_process:
            print("Computing metrics...")
            evaluation.compute_metrics(args)
            print("Evaluation completed!")

    except Exception as e:
        # Notebook boundary: report the failure without killing the kernel cell.
        print("Error during evaluation:", str(e))
        import traceback
        traceback.print_exc()

# === Step 4: Run evaluation ===
print("Starting evaluation process...")

# Single source of truth for the paths, so the existence checks and the
# error messages below can never disagree.
EVAL_DATA_DIR = "validation/"  # relative folder with forget.jsonl and retain.jsonl
EVAL_CHECKPOINT_DIR = "studentmodel_best_val_sequential/"  # relative folder with the model weights

eval_forget_file = os.path.join(EVAL_DATA_DIR, "forget.jsonl")
eval_retain_file = os.path.join(EVAL_DATA_DIR, "retain.jsonl")

# Check files exist before starting
if os.path.exists(eval_forget_file) and os.path.exists(eval_retain_file):
    if os.path.exists(EVAL_CHECKPOINT_DIR):
        run_evaluation(
            data_path=EVAL_DATA_DIR,
            checkpoint_path=EVAL_CHECKPOINT_DIR,
            output_dir="eval_results",
            debug=False,  # set to True for the evaluation module's debug prints
        )
    else:
        print(f"Model checkpoint not found at {EVAL_CHECKPOINT_DIR}")
        print("   Make sure the training completed successfully")
else:
    print("Validation files not found")
    print(f"   Expected: {eval_forget_file} and {eval_retain_file}")
    print("   Make sure the data processing completed successfully")