# Unlearning with TOFU (Inverted Hinge Loss + Fisher Information + FILA)


## 1. Setup and Imports

In [None]:
# Install required packages
%pip install -q rouge-score torchmetrics transformers huggingface_hub

import os
import json
import shutil
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from rouge_score import rouge_scorer
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from huggingface_hub import snapshot_download

# Paths and configuration for local environment
MODEL_PATH = "semeval25-unlearning-1B-model"
DATA_PATH = "semeval25-unlearning-data"
STUDENT_PATH = "studentmodel_final"
Path(STUDENT_PATH).mkdir(parents=True, exist_ok=True)

# HuggingFace token.
# SECURITY: never hard-code an access token in a notebook; any token that was
# ever committed in this cell must be treated as leaked and revoked.
# Export it before launching instead:  export HF_TOKEN=hf_...
# Get your token from: https://huggingface.co/settings/tokens
# When the variable is unset (or empty) None is passed, which leaves credential
# resolution to huggingface_hub (presumably the token stored by
# `huggingface-cli login`, if any - verify for your installed version).
HF_TOKEN = os.environ.get("HF_TOKEN") or None

# Download model and data if not already present
if not os.path.exists(MODEL_PATH):
    print("Downloading model...")
    snapshot_download(
        repo_id='llmunlearningsemeval2025organization/olmo-1B-model-semeval25-unlearning',
        token=HF_TOKEN,
        local_dir=MODEL_PATH
    )

if not os.path.exists(DATA_PATH):
    print("Downloading dataset...")
    snapshot_download(
        repo_id='llmunlearningsemeval2025organization/semeval25-unlearning-dataset-public',
        token=HF_TOKEN,
        local_dir=DATA_PATH,
        repo_type="dataset"
    )

# Report visible GPUs
print(f"Available GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

## 2. Data and Model Loading

In [None]:
# Read the four parquet splits (retain/forget x train/validation) from the
# local dataset snapshot.
retain_train_df = pd.read_parquet(
    f"{DATA_PATH}/data/retain_train-00000-of-00001.parquet", engine='pyarrow'
)
retain_validation_df = pd.read_parquet(
    f"{DATA_PATH}/data/retain_validation-00000-of-00001.parquet", engine='pyarrow'
)
forget_train_df = pd.read_parquet(
    f"{DATA_PATH}/data/forget_train-00000-of-00001.parquet", engine='pyarrow'
)
forget_validation_df = pd.read_parquet(
    f"{DATA_PATH}/data/forget_validation-00000-of-00001.parquet", engine='pyarrow'
)

# Mirror every split to JSONL so evaluation scripts can read them without
# relying on shell commands.
for _out_dir in ('train', 'validation'):
    Path(_out_dir).mkdir(parents=True, exist_ok=True)

_jsonl_exports = [
    (retain_train_df, 'train/retain.jsonl'),
    (forget_train_df, 'train/forget.jsonl'),
    (retain_validation_df, 'validation/retain.jsonl'),
    (forget_validation_df, 'validation/forget.jsonl'),
]
for _frame, _target in _jsonl_exports:
    _frame.to_json(_target, orient='records', lines=True)

# Tokenizer: fall back to the EOS token for padding when no pad token is defined
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Datasets saved and tokenizer loaded")

## 3. Dataset

In [None]:
class UnlearningDataset(Dataset):
    """Dataset for unlearning with combined input and output text.

    - Expects items with keys: 'input', 'output', optional 'split' ('retain' or 'forget').
    - Returns tokenized tensors and the start index of the output (start_locs).
    - ``labels`` is a copy of ``input_ids`` in which every padding position
      (attention_mask == 0) is replaced by -100, so that losses filtering on
      -100 do not score padding.

    Args:
        data_source: a pandas DataFrame, or the path (str) of a JSONL file with
            one JSON object per line.
        tokenizer: callable tokenizer returning 'input_ids' and 'attention_mask'.
        max_length: length every combined sequence is padded/truncated to.

    Raises:
        TypeError: if ``data_source`` is neither a DataFrame nor a str.
    """
    def __init__(self, data_source, tokenizer, max_length=256):
        self.tokenizer = tokenizer
        self.max_length = max_length

        if isinstance(data_source, pd.DataFrame):
            self.data = data_source
            print(f"Loaded {len(self.data)} samples from DataFrame")
        elif isinstance(data_source, str):
            with open(data_source, 'r', encoding='utf-8') as f:
                # Blank lines (e.g. a trailing newline) are skipped instead of
                # crashing json.loads.
                records = [json.loads(line) for line in f if line.strip()]
            self.data = pd.DataFrame(records)
            print(f"Loaded {len(self.data)} samples from {data_source}")
        else:
            # Fail at construction time rather than with an AttributeError on
            # the first __len__/__getitem__ call.
            raise TypeError(
                "data_source must be a pandas DataFrame or a path to a JSONL file, "
                f"got {type(data_source).__name__}"
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]
        input_text = item["input"]
        output_text = item["output"]

        # Single tokenization of concatenated input and output
        combined = f"{input_text} {output_text}"
        tokenized = self.tokenizer(
            combined,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        input_ids = tokenized["input_ids"].squeeze(0)
        attention_mask = tokenized["attention_mask"].squeeze(0)

        # Labels are a separate tensor (not a view of input_ids) so padding can
        # be masked out without corrupting the model inputs.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        # Tokenize only the input to compute the start index of the output.
        # NOTE(review): this length is not truncated, so it can exceed
        # max_length for very long inputs.
        prompt_ids = self.tokenizer(
            input_text,
            return_tensors="pt"
        )["input_ids"].squeeze(0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "start_locs": prompt_ids.size(0),  # index where output begins
            "labels": labels,
            "split": 1 if item.get("split", "retain") == "forget" else 0
        }


In [None]:
# Datasets and dataloaders for every (retain | forget) x (train | validation) split
batch_size = 4

# Training: one dataset per split, reshuffled on every pass
retain_train_dataset = UnlearningDataset(retain_train_df, tokenizer)
forget_train_dataset = UnlearningDataset(forget_train_df, tokenizer)

retain_train_dataloader = DataLoader(
    retain_train_dataset, batch_size=batch_size, shuffle=True
)
forget_train_dataloader = DataLoader(
    forget_train_dataset, batch_size=batch_size, shuffle=True
)

# Validation: fixed order so metrics are comparable between epochs
retain_val_dataset = UnlearningDataset(retain_validation_df, tokenizer)
forget_val_dataset = UnlearningDataset(forget_validation_df, tokenizer)
retain_val_dataloader = DataLoader(
    retain_val_dataset, batch_size=batch_size, shuffle=False
)
forget_val_dataloader = DataLoader(
    forget_val_dataset, batch_size=batch_size, shuffle=False
)

## 4. TOFU Trainer Class (Inverted Hinge Loss + FILA)

In [None]:
import torch.nn.functional as F
from functools import reduce
from torchmetrics.classification import MulticlassHingeLoss
from torchmetrics.utilities.data import to_onehot
from torchmetrics.metric import Metric
from torchmetrics.functional.classification.confusion_matrix import _multiclass_confusion_matrix_format
from torchmetrics.functional.classification.hinge import (
    _multiclass_hinge_loss_arg_validation, 
    _multiclass_hinge_loss_tensor_validation,
    _hinge_loss_compute
)

def _custom_multiclass_hinge_loss_update(
    preds,
    target,
    alpha,
    squared,
    multiclass_mode = "crammer-singer"
):
    """Accumulate the inverted hinge measure for a batch of predictions.

    The per-sample measure is ``clamp(alpha + margin, min=0)``: it grows with
    the margin, i.e. with the target-class score relative to the other classes.

    Args:
        preds: 2-D score tensor indexed as [sample, class]; softmax over dim 1
            is applied unless every value already lies in [0, 1].
        target: class indices, one per sample.
        alpha: offset added to the margin before clamping.
        squared: when true, the clamped measures are squared.
        multiclass_mode: "crammer-singer" uses (target score - best non-target
            score) per sample; any other value uses a signed per-class margin
            (+score for the target class, -score for every other class).

    Returns:
        Tuple ``(measures summed over dim 0, number of samples as a tensor)``.
    """
    already_probabilities = torch.all((preds >= 0) * (preds <= 1))
    if not already_probabilities:
        preds = preds.softmax(1)

    n_classes = max(2, preds.shape[1])
    is_target = to_onehot(target, n_classes).bool()

    if multiclass_mode == "crammer-singer":
        # Highest score among the non-target classes of every sample
        best_other = torch.max(preds[~is_target].view(preds.shape[0], -1), dim=1)[0]
        margin = preds[is_target] - best_other
    else:
        margin = torch.where(is_target, preds, -preds)

    measures = torch.clamp(alpha + margin, 0)
    if squared:
        measures = measures.pow(2)

    n_samples = torch.tensor(is_target.shape[0], device=is_target.device)
    return measures.sum(dim=0), n_samples

def multiclass_hinge_loss(
    preds,
    target,
    num_classes,
    alpha = 1.0,
    squared = False,
    multiclass_mode = "crammer-singer",
    ignore_index = None,
    validate_args = True,
):
    """Functional entry point for the custom (inverted) multiclass hinge loss.

    Runs the torchmetrics argument/tensor validators when ``validate_args`` is
    true, formats the inputs with the torchmetrics confusion-matrix helper
    (keeping ``preds`` as scores, ``convert_to_labels=False``), accumulates the
    measures with ``_custom_multiclass_hinge_loss_update`` and reduces them with
    ``_hinge_loss_compute``.

    Args:
        preds: score tensor passed to the update step.
        target: class-index tensor.
        num_classes: number of classes, used by the validators.
        alpha: offset added to the margin (default 1.0).
        squared: square the clamped measures when true.
        multiclass_mode: margin definition forwarded to the update step.
        ignore_index: forwarded to the torchmetrics validators/formatter.
        validate_args: toggle the validation calls.

    Returns:
        Whatever ``_hinge_loss_compute(measures, total)`` returns.
    """
    if validate_args:
        _multiclass_hinge_loss_arg_validation(num_classes, squared, multiclass_mode, ignore_index)
        _multiclass_hinge_loss_tensor_validation(preds, target, num_classes, ignore_index)

    formatted_preds, formatted_target = _multiclass_confusion_matrix_format(
        preds, target, ignore_index, convert_to_labels=False
    )
    summed_measures, sample_count = _custom_multiclass_hinge_loss_update(
        formatted_preds,
        formatted_target,
        alpha,
        squared,
        multiclass_mode,
    )
    return _hinge_loss_compute(summed_measures, sample_count)

class TOFUTrainer:
    """LoRA unlearning trainer: Inverted Hinge Loss on forget data, CE on retain data.

    The base causal LM is loaded from ``model_path`` and wrapped with LoRA
    adapters described by ``lora_config``. When ``importance_file`` is given,
    the adapters are initialised with FILA (see ``apply_fila``).

    Args:
        model_path: local directory of the base model.
        tokenizer: tokenizer instance (stored; not used by the methods below).
        lora_config: PEFT LoRA configuration.
        device: device string the model and batches are moved to.
        importance_file: optional path of a file written by
            ``measure_fisher_information``; enables FILA when truthy.
    """
    def __init__(self, model_path, tokenizer, lora_config, device="cuda:0", importance_file=None):
        self.model_path = model_path
        self.tokenizer = tokenizer
        self.lora_config = lora_config
        self.device = device
        self.importance_file = importance_file

        self.model = None
        self.initial_state_dict = {}

        # Validation tracking
        self.best_val_loss = float('inf')
        self.best_epoch = 0
        # Snapshot of the trainable parameters at the best validation epoch
        # (None until validation has run at least once).
        self.best_model_state = None

        # Loss weights
        self.forget_weight = 1.0
        self.retain_weight = 1.0

    def get_module_by_name(self, module, access_string):
        """Resolve a dotted attribute path (e.g. 'a.b.c') starting from ``module``."""
        names = access_string.split(sep='.')
        return reduce(getattr, names, module)

    def setup_model(self):
        """Load the base model, attach LoRA adapters and optionally apply FILA.

        Also records a copy of every trainable parameter in
        ``initial_state_dict`` for later task-vector computation.
        """
        print("Setting up model...")

        # Load base model
        base_model = AutoModelForCausalLM.from_pretrained(self.model_path, local_files_only=True)

        # Setup model with LoRA
        self.model = get_peft_model(base_model, self.lora_config)
        self.model = self.model.to(self.device)
        self.model.print_trainable_parameters()

        # Apply FILA if importance file is provided
        if self.importance_file:
            self.apply_fila()

        # Save initial state for task vector calculation
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.initial_state_dict[name] = param.data.clone()

        print("Model setup completed")

    def apply_fila(self):
        """Apply Fisher Information Weighted LoRA Adaptation (FILA).

        For every importance entry whose name contains a LoRA target module,
        the base weight is row-weighted by the square root of the summed
        forget/retain importance ratio, a rank-``r`` factorisation is computed,
        the LoRA A/B matrices are set from it, and the scaled product B @ A is
        subtracted from the base weight. Layers that cannot be processed are
        reported and skipped (best effort).
        """
        print(f'Loading importance file from {self.importance_file}')
        # NOTE(security): torch.load unpickles the file - only load importance
        # files that you produced yourself.
        imp_file = torch.load(self.importance_file, map_location='cpu')

        f_cnt = imp_file['f_cnt']
        r_cnt = imp_file['r_cnt']
        importance_f = imp_file['importance_f']
        importance_r = imp_file['importance_r']

        # Calculate importance ratio: forget/retain (1e-5 avoids division by zero)
        importances = {
            n: torch.div(importance_f[n] / f_cnt, 1e-5 + (importance_r[n] / r_cnt))
            for n in importance_f
        }

        # Get LoRA target modules
        lora_targets = self.lora_config.target_modules

        for old_name, importance in importances.items():
            if not any(target_name in old_name for target_name in lora_targets):
                continue

            name = old_name.replace("module.", '')
            prefix = 'base_model.model.' + name.replace(".weight", '')

            try:
                lora_A_module = self.get_module_by_name(self.model, prefix + '.lora_A')
                lora_B_module = self.get_module_by_name(self.model, prefix + '.lora_B')
                base_layer_module = self.get_module_by_name(self.model, prefix + '.base_layer')
                scaling_module = self.get_module_by_name(self.model, prefix + '.scaling')
                scale = scaling_module['default']

                base_weight = base_layer_module.weight.data
                dtype = base_weight.dtype
                W = base_weight.to(torch.float32)

                # Solve row-wise weighted low-rank approximation
                row_importance = importance.sum(dim=1).sqrt().to(W.device)  # row-wise sum
                U, S, V = torch.svd_lowrank(row_importance[:, None] * W, q=self.lora_config.r)

                S = S / scale
                sqrt_S = torch.sqrt(S)

                new_lora_A = (V * sqrt_S).t()
                # Undo the row weighting (1e-5 guards rows with zero importance)
                new_lora_B = (1 / (row_importance + 1e-5))[:, None] * (U * sqrt_S)
                new_residual = base_weight - scale * new_lora_B @ new_lora_A

                lora_A_module['default'].weight.data = new_lora_A.contiguous().to(dtype)
                lora_B_module['default'].weight.data = new_lora_B.contiguous().to(dtype)
                base_layer_module.weight.data = new_residual.contiguous().to(dtype)

                print(f"Applied FILA to {name}")
            except Exception as e:
                # Deliberate best effort: a layer that cannot be initialised is
                # reported and left with its default LoRA initialisation.
                print(f"Could not apply FILA to {name}: {e}")
                continue

    def _prepare_batch(self, batch):
        """Move a batch to ``self.device`` and mask padding out of the labels.

        Returns ``(input_ids, attention_mask, labels)`` where every label at a
        position with ``attention_mask == 0`` is -100. The caller's tensors are
        not modified.
        """
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        labels = batch["labels"].to(self.device)
        # Without this, padding positions would be scored by both losses:
        # the hinge loss filters on -100 and so does the model's own CE loss.
        labels = labels.masked_fill(attention_mask == 0, -100)
        return input_ids, attention_mask, labels

    def compute_inverted_hinge_loss(self, batch):
        """Compute Inverted Hinge Loss (IHL) for forget samples.

        Logits at position t are scored against the label at position t+1;
        positions whose label is -100 (padding) are excluded.
        """
        input_ids, attention_mask, labels = self._prepare_batch(batch)

        # Only the logits are needed, so no labels are passed and the model
        # does not compute an unused cross-entropy loss.
        outputs = self.model(input_ids, attention_mask=attention_mask)

        scores = outputs.logits
        vocab_size = scores.size(-1)
        shift_logits = scores[..., :-1, :].contiguous().view(-1, vocab_size)  # [BN, V]
        shift_labels = labels[..., 1:].contiguous().view(-1)  # [BN,]

        valid = shift_labels != -100  # ignore pad tokens
        if not torch.any(valid):
            # Nothing to score: return a zero that is still attached to the graph
            return shift_logits.sum() * 0.0

        return multiclass_hinge_loss(
            shift_logits[valid, :],
            shift_labels[valid],
            vocab_size,
        )

    def compute_retain_loss(self, batch):
        """Compute standard cross-entropy loss for retain samples (padding masked)."""
        input_ids, attention_mask, labels = self._prepare_batch(batch)

        # Forward pass; the model computes the loss from the masked labels
        outputs = self.model(input_ids, labels=labels, attention_mask=attention_mask)
        return outputs.loss

    def compute_mixed_loss(self, forget_batch, retain_batch):
        """Compute combined loss using IHL for forget and CE for retain.

        Returns ``(total_loss, forget_loss, retain_loss)`` where the total is
        the weighted sum using ``forget_weight`` and ``retain_weight``.
        """
        forget_loss = self.compute_inverted_hinge_loss(forget_batch)
        retain_loss = self.compute_retain_loss(retain_batch)

        total_loss = self.forget_weight * forget_loss + self.retain_weight * retain_loss

        return total_loss, forget_loss, retain_loss

    def train_epoch(self, forget_dataloader, retain_dataloader, optimizer):
        """Train for one epoch and return the mean total/forget/retain losses."""
        self.model.train()
        epoch_losses = []
        epoch_forget_losses = []
        epoch_retain_losses = []

        # Batches are consumed in pairs, so an epoch ends when the SHORTER
        # loader is exhausted (the longer one is truncated, not cycled).
        min_batches = min(len(forget_dataloader), len(retain_dataloader))
        paired_batches = zip(forget_dataloader, retain_dataloader)

        for forget_batch, retain_batch in tqdm(paired_batches, total=min_batches, desc="Training"):
            optimizer.zero_grad()

            total_loss, forget_loss, retain_loss = self.compute_mixed_loss(forget_batch, retain_batch)

            total_loss.backward()
            optimizer.step()

            epoch_losses.append(total_loss.item())
            epoch_forget_losses.append(forget_loss.item())
            epoch_retain_losses.append(retain_loss.item())

        return {
            'total_loss': np.mean(epoch_losses),
            'forget_loss': np.mean(epoch_forget_losses),
            'retain_loss': np.mean(epoch_retain_losses)
        }

    def validate(self, forget_val_dataloader, retain_val_dataloader, max_batches=10):
        """Validate on at most ``max_batches`` paired batches.

        Returns a dict of mean losses; every value is ``inf`` when no batch was
        evaluated. The model is put back into train mode afterwards.
        """
        self.model.eval()
        val_losses = []
        val_forget_losses = []
        val_retain_losses = []

        with torch.no_grad():
            # Validate on a subset to save time
            for i, (forget_batch, retain_batch) in enumerate(zip(forget_val_dataloader, retain_val_dataloader)):
                if i >= max_batches:  # Limit validation batches
                    break

                total_loss, forget_loss, retain_loss = self.compute_mixed_loss(forget_batch, retain_batch)

                val_losses.append(total_loss.item())
                val_forget_losses.append(forget_loss.item())
                val_retain_losses.append(retain_loss.item())

        self.model.train()
        return {
            'val_total_loss': np.mean(val_losses) if val_losses else float('inf'),
            'val_forget_loss': np.mean(val_forget_losses) if val_forget_losses else float('inf'),
            'val_retain_loss': np.mean(val_retain_losses) if val_retain_losses else float('inf')
        }

    def train(self, forget_dataloader, retain_dataloader,
              forget_val_dataloader=None, retain_val_dataloader=None,
              num_epochs=3, lr=1e-5, patience=3):
        """Main training loop with optional validation-based early stopping.

        When both validation loaders are given, the trainable parameters of the
        best-validation epoch are snapshotted and restored at the end; training
        stops after ``patience`` consecutive epochs without improvement.
        """
        print("Training with TOFU approach (IHL + Fisher Information + FILA)...")

        self.model.train()
        # Frozen base weights never receive gradients, so only the trainable
        # (adapter) parameters are handed to the optimizer.
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=0.01)

        patience_counter = 0

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}")

            # Train
            train_metrics = self.train_epoch(forget_dataloader, retain_dataloader, optimizer)

            print(f"Train - Total: {train_metrics['total_loss']:.4f}, "
                  f"Forget (IHL): {train_metrics['forget_loss']:.4f}, "
                  f"Retain: {train_metrics['retain_loss']:.4f}")

            # Validate
            if forget_val_dataloader and retain_val_dataloader:
                val_metrics = self.validate(forget_val_dataloader, retain_val_dataloader)

                print(f"Val - Total: {val_metrics['val_total_loss']:.4f}, "
                      f"Forget: {val_metrics['val_forget_loss']:.4f}, "
                      f"Retain: {val_metrics['val_retain_loss']:.4f}")

                # Early stopping
                if val_metrics['val_total_loss'] < self.best_val_loss:
                    self.best_val_loss = val_metrics['val_total_loss']
                    self.best_epoch = epoch
                    patience_counter = 0
                    # Save best model state: detached copies of the trainable
                    # parameters only - frozen weights do not change during
                    # training, and detaching keeps the snapshot out of autograd.
                    self.best_model_state = {
                        name: param.detach().clone()
                        for name, param in self.model.named_parameters()
                        if param.requires_grad
                    }
                else:
                    patience_counter += 1

                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        print(f"Training completed. Best epoch: {self.best_epoch+1}")

        # Load best model if we have it
        if self.best_model_state is not None:
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    if name in self.best_model_state:
                        param.copy_(self.best_model_state[name])
            print("Loaded best model weights")

    def save_model(self, save_path):
        """Save the trained model with ``save_pretrained`` into ``save_path``."""
        Path(save_path).mkdir(parents=True, exist_ok=True)
        self.model.save_pretrained(save_path)
        print(f"Model saved to {save_path}")

    def compute_task_vector(self):
        """Compute task vector (current parameters minus those recorded at setup)."""
        task_vector = {}
        for name, param in self.model.named_parameters():
            if name in self.initial_state_dict:
                task_vector[name] = param.data - self.initial_state_dict[name]
        return task_vector

## 5. Setup Trainer and Training

In [None]:
# LoRA configuration for TOFU training: adapters on every attention and MLP projection
_lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=_lora_target_modules,
)

# Build the trainer. FILA stays off while importance_file is None; pass the
# path of a Fisher information file to switch it on.
tofu_trainer = TOFUTrainer(
    model_path=MODEL_PATH,
    tokenizer=tokenizer,
    lora_config=lora_config,
    device="cuda:0",
    importance_file=None,  # Set to Fisher information file path to enable FILA
)

# Load the base model and attach the adapters
tofu_trainer.setup_model()

print("TOFU Trainer initialized successfully!")
for _feature_line in (
    "Features enabled:",
    "- Inverted Hinge Loss (IHL) for forget samples",
    "- Cross-entropy loss for retain samples",
    "- LoRA fine-tuning",
):
    print(_feature_line)

_fila_enabled = bool(tofu_trainer.importance_file)
print(
    "- FILA (Fisher Information weighted LoRA)"
    if _fila_enabled
    else "- FILA disabled (no importance file provided)"
)

In [None]:
# Run TOFU training: Inverted Hinge Loss on forget data + cross-entropy on retain data
print("Starting TOFU training...")

_train_settings = dict(
    forget_dataloader=forget_train_dataloader,
    retain_dataloader=retain_train_dataloader,
    forget_val_dataloader=forget_val_dataloader,
    retain_val_dataloader=retain_val_dataloader,
    num_epochs=3,
    lr=1e-5,
    patience=2,
)
tofu_trainer.train(**_train_settings)

# Persist the trained adapters next to the student model directory
save_path = STUDENT_PATH + "_tofu"
tofu_trainer.save_model(save_path)

print(f"TOFU training completed! Model saved to {save_path}")

# Task vector: per-parameter difference from the weights recorded at setup
task_vector = tofu_trainer.compute_task_vector()
print(f"Task vector computed with {len(task_vector)} parameter differences")

In [None]:
def measure_fisher_information(model, forget_dataloader, retain_dataloader, save_path=None, max_steps=10):
    """
    Measure Fisher Information for FILA implementation
    Based on TOFU's measure_importance.py

    For every weight whose parameter name contains the name of a Linear module
    (``lm_head`` excluded), accumulates ``grad ** 2 * token_count`` over the
    forget batches and, separately, over the retain batches. Padding positions
    (attention_mask == 0) are masked to -100 so they are neither scored by the
    loss nor counted.

    All parameters are temporarily switched to ``requires_grad=True`` and the
    model is put in train mode; both are restored before returning, and any
    gradients present on entry are cleared.

    Args:
        model: causal LM exposing ``.device`` and returning an object with
            ``.loss`` when called with ``labels``.
        forget_dataloader: batches with 'input_ids', 'attention_mask', 'labels'.
        retain_dataloader: same format, retain split.
        save_path: optional file path; the result is written with ``torch.save``.
        max_steps: maximum number of batches per split (None = use all).

    Returns:
        dict with keys 'f_cnt', 'r_cnt' (token counts) and 'importance_f',
        'importance_r' (name -> accumulated importance, kept on CPU).
    """
    print("Measuring Fisher Information...")

    # Find all linear layer names for importance measurement
    def find_all_linear_names(model):
        cls = torch.nn.Linear
        lora_module_names = set()
        for name, module in model.named_modules():
            if isinstance(module, cls):
                names = name.split('.')
                lora_module_names.add(names[0] if len(names) == 1 else names[-1])
        if 'lm_head' in lora_module_names:
            lora_module_names.remove('lm_head')
        return list(lora_module_names)

    # Accumulate squared gradients of one split into `importance`; returns the token count
    def _accumulate(dataloader, importance, desc):
        token_total = 0
        for step, batch in enumerate(tqdm(dataloader, desc=desc)):
            if max_steps is not None and step >= max_steps:  # Limit to prevent long computation
                break

            input_ids = batch["input_ids"].to(model.device)
            attention_mask = batch["attention_mask"].to(model.device)
            labels = batch["labels"].to(model.device)
            # Padding must not contribute to the loss or to the token count
            labels = labels.masked_fill(attention_mask == 0, -100)

            output = model(input_ids, labels=labels, attention_mask=attention_mask)
            output.loss.backward()

            cnt = torch.sum(labels != -100)
            for n, param in model.named_parameters():
                if param.grad is None:
                    continue
                if n in importance:
                    importance[n] += (param.grad.pow(2) * cnt).detach().cpu()
                # Reset so the next batch starts from a clean gradient
                param.grad = None
            # Keep the count on CPU, like the importances it will be divided into
            token_total += cnt.detach().cpu()
        return token_total

    # Remember the caller's configuration so it can be restored afterwards
    was_training = model.training
    original_flags = [(param, param.requires_grad) for param in model.parameters()]

    try:
        # Force all parameters to require gradients
        model.train()
        for param, _ in original_flags:
            param.requires_grad = True
        # Gradients left over from earlier work must not leak into the first batch
        model.zero_grad(set_to_none=True)

        # Find target modules
        target_modules = find_all_linear_names(model)
        print(f"Target modules for importance: {target_modules}")

        # Initialize importance tracking
        importance_f = {}
        importance_r = {}
        for name, _param in model.named_parameters():
            if 'weight' in name and any(t in name for t in target_modules):
                importance_f[name] = 0
                importance_r[name] = 0

        # Measure importance on forget samples
        print("Measuring importance on forget samples...")
        f_cnt = _accumulate(forget_dataloader, importance_f, "Forget importance")

        # Measure importance on retain samples
        print("Measuring importance on retain samples...")
        r_cnt = _accumulate(retain_dataloader, importance_r, "Retain importance")
    finally:
        # Undo the temporary changes even if a batch failed
        for param, flag in original_flags:
            param.requires_grad = flag
        model.train(was_training)

    # Package results
    importances = {
        'f_cnt': f_cnt,
        'r_cnt': r_cnt,
        'importance_f': importance_f,
        'importance_r': importance_r
    }

    # Save if path provided
    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(importances, save_path)
        print(f"Fisher information saved to {save_path}")

    print("Fisher information measurement completed!")
    return importances

# Example usage to create Fisher information for FILA:
# fisher_info = measure_fisher_information(
#     tofu_trainer.model, 
#     forget_train_dataloader, 
#     retain_train_dataloader,
#     save_path="fisher_importance.pt"
# )

## 6. Evaluation

In [None]:
# Evaluate the TOFU-trained checkpoint on the validation splits
print("Starting evaluation with TOFU trained model...")

# Both validation files must be on disk before evaluation is attempted
_required_files = ("validation/forget.jsonl", "validation/retain.jsonl")
if not all(os.path.exists(_path) for _path in _required_files):
    for _msg in (
        "❌ Validation files not found",
        "   Expected: validation/forget.jsonl and validation/retain.jsonl",
        "   Make sure the data processing completed successfully",
    ):
        print(_msg)
else:
    tofu_model_path = STUDENT_PATH + "_tofu"
    if not os.path.exists(tofu_model_path):
        print(f"❌ TOFU model checkpoint not found at {tofu_model_path}")
        print("   Make sure the TOFU training completed successfully")
    else:
        run_evaluation(
            data_path="validation/",
            checkpoint_path=tofu_model_path,
            output_dir="eval_results_tofu",
            debug=False,
        )
        print("✅ TOFU model evaluation completed!")

_banner = "=" * 50

# Optional: Compare with baseline
print("\n" + _banner)
print("COMPARISON SUMMARY")
print(_banner)
for _line in (
    "✅ TOFU Implementation Features:",
    "  - Inverted Hinge Loss (IHL) for forget samples",
    "  - Fisher Information measurement capability",
    "  - FILA (Fisher Information weighted LoRA) support",
    "  - Standard LoRA fine-tuning",
    "  - Early stopping with validation",
    "\n🔄 Replaced Dual Teacher approach with TOFU methods",
    "📊 Evaluation results saved in eval_results_tofu/",
):
    print(_line)

# Short how-to for enabling FILA in a later run
print("\n" + _banner)
print("FISHER INFORMATION & FILA USAGE")
print(_banner)
for _line in (
    "To enable FILA in future runs:",
    "1. First measure Fisher information:",
    "   fisher_info = measure_fisher_information(model, forget_dl, retain_dl, 'fisher.pt')",
    "2. Then initialize trainer with importance file:",
    "   trainer = TOFUTrainer(..., importance_file='fisher.pt')",
    "3. FILA will automatically apply weighted LoRA initialization",
):
    print(_line)