# Unlearning DualTeacher

## 1. Setup and Imports

In [None]:
!pip install rouge-score
import torch
import pandas as pd
import numpy as np
import json
import random
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import shutil

# Configuration paths (Kaggle input datasets and the writable working directory)
MODEL_PATH = "/kaggle/input/olmo-model/semeval25-unlearning-1B-model"
DATA_PATH = "/kaggle/input/olmo-model/semeval25-unlearning-data"
MIA_VAL_PATH  = "/kaggle/input/mia-dataset-val"
MIA_TRAIN_PATH  = "/kaggle/input/mia-dataset"
GOOD_TEACHER_PATH = "/kaggle/input/good-teacher"

STUDENT_TRAINED = "/kaggle/input/student-trained"
STUDENT_PATH  = "/kaggle/working/studentmodel_final"
Path(STUDENT_PATH).mkdir(parents=True, exist_ok=True)

# Copy pre-trained student files if available.
# The source directory is optional: when it is missing the copy is skipped
# instead of raising from iterdir(). Sub-directories are copied recursively
# because shutil.copyfile only handles regular files.
dir_path = Path(STUDENT_TRAINED)
if dir_path.is_dir():
    for entry in dir_path.iterdir():
        destination = Path(STUDENT_PATH) / entry.name
        if entry.is_dir():
            shutil.copytree(entry, destination, dirs_exist_ok=True)
        else:
            shutil.copyfile(entry, destination)
else:
    print(f"No pre-trained student found at {STUDENT_TRAINED}; skipping copy")

# Report the visible CUDA devices (the trainer below defaults to cuda:0 and cuda:1)
print(f"Available GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

## 2. Data and Model Loading

In [None]:
# Load the retain/forget train and validation splits from parquet files
retain_train_df = pd.read_parquet(f"{DATA_PATH}/data/retain_train-00000-of-00001.parquet", engine='pyarrow')
retain_validation_df = pd.read_parquet(f"{DATA_PATH}/data/retain_validation-00000-of-00001.parquet", engine='pyarrow')
forget_train_df = pd.read_parquet(f"{DATA_PATH}/data/forget_train-00000-of-00001.parquet", engine='pyarrow')
forget_validation_df = pd.read_parquet(f"{DATA_PATH}/data/forget_validation-00000-of-00001.parquet", engine='pyarrow')

# Save datasets as JSONL format for evaluation scripts.
# The output directories are created with pathlib rather than a shell escape so
# the cell also runs outside IPython and on non-POSIX systems.
for output_dir in ("train", "validation"):
    Path(output_dir).mkdir(parents=True, exist_ok=True)
retain_train_df.to_json('train/retain.jsonl', orient='records', lines=True)
forget_train_df.to_json('train/forget.jsonl', orient='records', lines=True)
retain_validation_df.to_json('validation/retain.jsonl', orient='records', lines=True)
forget_validation_df.to_json('validation/forget.jsonl', orient='records', lines=True)

# Initialize tokenizer; fall back to EOS as the padding token when none is defined,
# because the dataset pads every sample to a fixed length
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Datasets saved and tokenizer loaded")

## 3. Dataset

In [None]:
class UnlearningDataset(Dataset):
    """
    Custom PyTorch Dataset for machine unlearning tasks.

    Each record provides an "input" and an "output" text (and optionally a
    "split" of "retain" or "forget"). The two texts are joined with a single
    space, tokenized, and padded/truncated to ``max_length``.

    Args:
        data_source: a pandas DataFrame, or a path (``str`` or ``Path``) to a
            JSONL file with one JSON object per line. Blank lines are skipped.
        tokenizer: callable tokenizer returning "input_ids" and "attention_mask"
            tensors when called with ``return_tensors="pt"``.
        max_length: fixed length every sample is padded/truncated to.

    Raises:
        TypeError: if ``data_source`` is neither a DataFrame nor a path.
    """

    def __init__(self, data_source, tokenizer, max_length=256):
        self.tokenizer = tokenizer
        self.max_length = max_length

        if isinstance(data_source, pd.DataFrame):
            self.data = data_source
            print(f"Loaded {len(self.data)} samples from DataFrame")
        elif isinstance(data_source, (str, Path)):
            with open(data_source, 'r', encoding='utf-8') as f:
                # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads
                records = [json.loads(line) for line in f if line.strip()]
            self.data = pd.DataFrame(records)
            print(f"Loaded {len(self.data)} samples from {data_source}")
        else:
            # Fail early: otherwise self.data would be missing and the error
            # would only surface later in __len__/__getitem__
            raise TypeError(
                f"data_source must be a pandas DataFrame or a path to a JSONL file, "
                f"got {type(data_source).__name__}"
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Tokenize sample ``idx``.

        Returns a dict with:
            input_ids / attention_mask: 1-D tensors of length ``max_length``.
            start_locs: number of tokens in the input text alone, capped at
                ``max_length`` so it never points past the truncated sequence.
            labels: copy of ``input_ids`` with padding positions set to -100 so
                that a loss using -100 as ignore index skips them.
            split: 1 for "forget" samples, 0 otherwise (missing split -> retain).
        """
        item = self.data.iloc[idx]
        input_text = item["input"]
        output_text = item["output"]

        # Combine input and output for training
        tokenized = self.tokenizer(
            f"{input_text} {output_text}",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        input_ids = tokenized["input_ids"].squeeze(0)
        attention_mask = tokenized["attention_mask"].squeeze(0)

        # Length of the input alone marks where the output begins
        prompt_length = self.tokenizer(
            input_text,
            return_tensors="pt"
        )["input_ids"].squeeze(0).size(0)

        # Mask padding via the attention mask (not the pad id) so a genuine EOS
        # token is still supervised when pad_token == eos_token
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "start_locs": min(prompt_length, self.max_length),
            "labels": labels,
            "split": 1 if item.get("split", "retain") == "forget" else 0
        }

In [None]:
# Build the four datasets first (retain/forget x train/validation), then wrap
# each one in a DataLoader that shares a single batch size.
batch_size = 4

retain_train_dataset = UnlearningDataset(retain_train_df, tokenizer)
forget_train_dataset = UnlearningDataset(forget_train_df, tokenizer)
retain_val_dataset = UnlearningDataset(retain_validation_df, tokenizer)
forget_val_dataset = UnlearningDataset(forget_validation_df, tokenizer)

# Training loaders reshuffle their samples on every pass
retain_train_dataloader = DataLoader(retain_train_dataset, batch_size=batch_size, shuffle=True)
forget_train_dataloader = DataLoader(forget_train_dataset, batch_size=batch_size, shuffle=True)

# Validation loaders keep a fixed order so evaluation is consistent across epochs
retain_val_dataloader = DataLoader(retain_val_dataset, batch_size=batch_size, shuffle=False)
forget_val_dataloader = DataLoader(forget_val_dataset, batch_size=batch_size, shuffle=False)

## 4. DualTeacher Trainer Class

In [None]:
class DualTeacherTrainer:
    """
    Implementation of the Dual Teacher approach for machine unlearning using mixed training.
    
    This class implements a dual-teacher distillation strategy where:
    1. A "good teacher" is trained on retain data to preserve important knowledge
    2. A "bad teacher" generates uniform/noisy outputs to forget unwanted knowledge  
    3. A student model learns from both teachers in mixed batches
    """
    
    def __init__(self, model_path, tokenizer, teacher_lora_config, student_lora_config, device_map=None):
        self.model_path = model_path
        self.tokenizer = tokenizer
        self.teacher_lora_config = teacher_lora_config
        self.student_lora_config = student_lora_config
        self.device_map = device_map or {"student": "cuda:0", "teacher": "cuda:1"}
        
        self.good_teacher = None
        self.student_model = None
        
        # Training tracking
        self.best_val_loss = float('inf')
        self.best_epoch = 0
        
    def setup_models(self, skip_teacher_setup=False):
        """
        Initialize the student model and, unless skipped, the good teacher.

        Each role gets its OWN copy of the base model. ``get_peft_model``
        injects the LoRA layers into the model it is given, so wrapping a single
        base model twice would make teacher and student share the same weights
        and adapters, and moving the student to its device would move the
        teacher along with it.

        Args:
            skip_teacher_setup: when true, only the student is created (use
                ``load_teacher`` to attach a pre-trained teacher afterwards).
        """
        print("Setting up models...")

        if not skip_teacher_setup:
            # Setup good teacher with LoRA on a dedicated base model instance
            teacher_base = AutoModelForCausalLM.from_pretrained(self.model_path, local_files_only=True)
            self.good_teacher = get_peft_model(teacher_base, self.teacher_lora_config)
            self.good_teacher = self.good_teacher.to(self.device_map["teacher"])
            self.good_teacher.print_trainable_parameters()

        # Setup student model with LoRA on its own base model instance
        student_base = AutoModelForCausalLM.from_pretrained(self.model_path, local_files_only=True)
        self.student_model = get_peft_model(student_base, self.student_lora_config)
        self.student_model = self.student_model.to(self.device_map["student"])
        self.student_model.print_trainable_parameters()

        print("Models setup completed")
        
    def create_bad_teacher_logits(self, good_teacher_logits):
        """
        Create bad teacher logits using uniform distribution + gaussian noise.
        This matches the simple approach from DualTeacher_seq.
        """
        uniform_logits = torch.zeros_like(good_teacher_logits)
        noisy_logits = uniform_logits + 0.1 * torch.randn_like(good_teacher_logits)
        return noisy_logits

    def compute_retain_only_loss(self, batch):
        """Compute loss for retain samples using standard KL divergence"""
        student_device = self.device_map["student"]
        teacher_device = self.device_map["teacher"]
    
        input_ids_student = batch["input_ids"].to(student_device)
        attention_mask_student = batch["attention_mask"].to(student_device)
    
        student_logits = self.student_model(input_ids_student, attention_mask=attention_mask_student).logits
        student_log_probs = torch.nn.functional.log_softmax(student_logits, dim=-1).to(teacher_device)
    
        # Teacher forward pass
        input_ids_teacher = batch["input_ids"].to(teacher_device)
        attention_mask_teacher = batch["attention_mask"].to(teacher_device)
    
        with torch.no_grad():
            good_teacher_logits = self.good_teacher(input_ids_teacher, attention_mask=attention_mask_teacher).logits
            good_teacher_probs = torch.nn.functional.softmax(good_teacher_logits, dim=-1)
    
        # Standard KL divergence loss
        retain_kl = torch.nn.functional.kl_div(
            student_log_probs,
            good_teacher_probs,
            reduction="batchmean",
            log_target=False
        )

        # Entropy regularization for confidence
        entropy_loss = -(student_log_probs.exp() * student_log_probs).sum(-1).mean()
        
        return 3.0 * retain_kl - 0.1 * entropy_loss

    def compute_forget_only_loss(self, batch):
        """Compute loss for forget samples using bad teacher (uniform/noisy)"""
        student_device = self.device_map["student"]
        teacher_device = self.device_map["teacher"]
    
        input_ids_student = batch["input_ids"].to(student_device)
        attention_mask_student = batch["attention_mask"].to(student_device)
    
        student_logits = self.student_model(input_ids_student, attention_mask=attention_mask_student).logits
        student_log_probs = torch.nn.functional.log_softmax(student_logits, dim=-1).to(teacher_device)
    
        # Generate bad teacher guidance
        input_ids_teacher = batch["input_ids"].to(teacher_device)
        attention_mask_teacher = batch["attention_mask"].to(teacher_device)
    
        with torch.no_grad():
            good_teacher_logits = self.good_teacher(input_ids_teacher, attention_mask=attention_mask_teacher).logits
            bad_teacher_logits = self.create_bad_teacher_logits(good_teacher_logits)
            bad_teacher_probs = torch.nn.functional.softmax(bad_teacher_logits, dim=-1)
    
        # Standard KL divergence with bad teacher
        forget_kl = torch.nn.functional.kl_div(
            student_log_probs,
            bad_teacher_probs,
            reduction="batchmean",
            log_target=False
        )

        # Entropy regularization for uncertainty
        entropy_loss = -(student_log_probs.exp() * student_log_probs).sum(-1).mean()
        
        return 2.0 * forget_kl + 0.1 * entropy_loss

    def compute_kl_divergence(self, batch):
        """Mixed training method using standard KL divergence"""
        student_device = self.device_map["student"]
        teacher_device = self.device_map["teacher"]
    
        input_ids_student = batch["input_ids"].to(student_device)
        attention_mask_student = batch["attention_mask"].to(student_device)
        split = batch["split"].float().to(student_device)
    
        # Student forward pass
        student_logits = self.student_model(input_ids_student, attention_mask=attention_mask_student).logits
        student_log_probs = torch.nn.functional.log_softmax(student_logits, dim=-1).to(teacher_device)
    
        # Teacher forward pass
        input_ids_teacher = batch["input_ids"].to(teacher_device)
        attention_mask_teacher = batch["attention_mask"].to(teacher_device)
    
        with torch.no_grad():
            good_teacher_logits = self.good_teacher(input_ids_teacher, attention_mask=attention_mask_teacher).logits
            bad_teacher_logits = self.create_bad_teacher_logits(good_teacher_logits)
            good_teacher_probs = torch.nn.functional.softmax(good_teacher_logits, dim=-1)
            bad_teacher_probs = torch.nn.functional.softmax(bad_teacher_logits, dim=-1)
    
        # Create masks for retain/forget samples
        retain_mask = (split <= 0.5).to(teacher_device)
        forget_mask = (split > 0.5).to(teacher_device)
    
        total_loss = 0.0
        
        # Process retain samples with good teacher
        if retain_mask.any():
            retain_kl = torch.nn.functional.kl_div(
                student_log_probs[retain_mask],
                good_teacher_probs[retain_mask.bool()],
                reduction="batchmean",
                log_target=False
            )
            total_loss += 3.0 * retain_kl
    
        # Process forget samples with bad teacher
        if forget_mask.any():
            forget_kl = torch.nn.functional.kl_div(
                student_log_probs[forget_mask],
                bad_teacher_probs[forget_mask.bool()],
                reduction="batchmean",
                log_target=False
            )
            total_loss += 2.0 * forget_kl
    
        # Add entropy regularization
        entropy_loss = -(student_log_probs.exp() * student_log_probs).sum(-1).mean()
    
        return total_loss + 0.2 * entropy_loss

    def validate_retain_only(self, retain_val_dataloader):
        """Validate only on retain samples"""
        self.student_model.eval()
        val_losses = []
        
        with torch.no_grad():
            for batch in retain_val_dataloader:
                loss = self.compute_retain_only_loss(batch)
                val_losses.append(loss.item())
        
        self.student_model.train()
        return np.mean(val_losses) if val_losses else float('inf')

    def validate_forget_only(self, forget_val_dataloader):
        """Validate only on forget samples"""
        self.student_model.eval()
        val_losses = []
        
        with torch.no_grad():
            for batch in forget_val_dataloader:
                loss = self.compute_forget_only_loss(batch)
                val_losses.append(loss.item())
        
        self.student_model.train()
        return np.mean(val_losses) if val_losses else float('inf')
      
    def _interleave_and_shuffle_batches(self, retain_dataloader, forget_dataloader, seed=None):
        """
        Collect batches from both dataloaders and interleave + shuffle for mixed training
        """
        retain_batches = [b for b in retain_dataloader]
        forget_batches = [b for b in forget_dataloader]
        combined = []
        
        # Simple interleaving
        i = 0
        while i < max(len(retain_batches), len(forget_batches)):
            if i < len(retain_batches):
                combined.append(retain_batches[i])
            if i < len(forget_batches):
                combined.append(forget_batches[i])
            i += 1
            
        # Shuffle to mix training order
        if seed is not None:
            random.Random(seed).shuffle(combined)
        else:
            random.shuffle(combined)
        return combined

    def train_student_mixed(self, retain_dataloader, forget_dataloader,
                            retain_val_dataloader=None, forget_val_dataloader=None,
                            num_epochs=6, lr=1e-5, val_freq=1, patience=3):
        """
        Train the student on retain and forget batches mixed within each epoch.

        Every epoch re-collects the batches of both loaders, shuffles them with
        the epoch index as seed, and takes one AdamW step per batch on
        ``compute_kl_divergence``. When both validation loaders are given,
        validation runs every ``val_freq`` epochs; the best checkpoint is saved
        as "studentmodel_best_val_mixed" and training stops after ``patience``
        validation checks without improvement. A per-epoch checkpoint
        "studentmodel_epoch_{n}_mixed" is saved after each completed epoch.

        NOTE(review): the combined metric is
        ``0.7 * retain + 0.3 / (1 + forget)`` and lower is treated as better,
        so a LARGER forget validation loss improves the score - confirm this
        is the intended direction.
        NOTE(review): on early stopping the loop breaks before the per-epoch
        checkpoint of that epoch is written.
        """
        print("Training student with MIXED dual-teacher approach...")
        self.student_model.train()
        optimizer = torch.optim.AdamW(self.student_model.parameters(), lr=lr, weight_decay=0.01)

        has_val_loaders = retain_val_dataloader is not None and forget_val_dataloader is not None
        checks_without_gain = 0

        for epoch in range(num_epochs):
            epoch_number = epoch + 1
            step_losses = []
            mixed_batches = self._interleave_and_shuffle_batches(retain_dataloader, forget_dataloader, seed=epoch)

            with tqdm(total=len(mixed_batches), desc=f"Student Epoch {epoch_number} (Mixed)") as pbar:
                for batch in mixed_batches:
                    optimizer.zero_grad()
                    loss = self.compute_kl_divergence(batch)
                    loss.backward()
                    optimizer.step()

                    loss_value = loss.item()
                    step_losses.append(loss_value)
                    pbar.update(1)
                    pbar.set_postfix({"loss": f"{loss_value:.4f}"})

            avg_epoch_loss = np.mean(step_losses) if step_losses else 0.0
            print(f"Student Epoch {epoch_number} - Average Loss: {avg_epoch_loss:.4f}")

            # Validation (only when both loaders exist and this epoch is on schedule)
            if has_val_loaders and epoch_number % val_freq == 0:
                print("Running validation...")
                retain_val_loss = self.validate_retain_only(retain_val_dataloader)
                forget_val_loss = self.validate_forget_only(forget_val_dataloader)

                # Combined validation metric
                forget_score = 1.0 / (1.0 + forget_val_loss)
                combined_val_loss = 0.7 * retain_val_loss + 0.3 * forget_score
                print(f"Validation - Retain: {retain_val_loss:.4f}, Forget: {forget_val_loss:.4f}, Combined: {combined_val_loss:.4f}")

                if combined_val_loss >= self.best_val_loss:
                    checks_without_gain += 1
                    print(f"No improvement for {checks_without_gain} validation checks")
                else:
                    self.best_val_loss = combined_val_loss
                    self.best_epoch = epoch_number
                    checks_without_gain = 0
                    print(f"New best validation loss: {combined_val_loss:.4f}")
                    self.save_model("studentmodel_best_val_mixed")

                if checks_without_gain >= patience:
                    print(f"Early stopping at epoch {epoch_number}")
                    break

            self.save_model(f"studentmodel_epoch_{epoch_number}_mixed")

        print("Student mixed training completed")
        print(f"Best validation loss: {self.best_val_loss:.4f} at epoch {self.best_epoch}")
        
    def train_good_teacher(self, dataloader, val_dataloader=None, num_epochs=2, lr=1e-4, save_path="good_teacher_adapter", val_freq=1):
        """
        Fine-tune the good teacher on the retain rows of ``dataloader``.

        For each batch only the rows whose ``split`` equals 0 are kept; batches
        without such rows are skipped. The kept rows are fed to the teacher
        together with their ``labels`` and the returned loss is optimized with
        AdamW. After the last epoch the teacher is saved to ``save_path``.

        ``val_dataloader`` and ``val_freq`` are accepted but not read by this
        method.
        """
        print("Training good teacher with LoRA...")

        teacher_device = self.device_map["teacher"]
        self.good_teacher.to(teacher_device)
        self.good_teacher.train()
        optimizer = torch.optim.AdamW(self.good_teacher.parameters(), lr=lr)

        for epoch in range(num_epochs):
            print(f"Epoch {epoch + 1}/{num_epochs} - Good Teacher Training")

            step_losses = []
            trained_batches = 0

            with tqdm(total=len(dataloader), desc=f"Good Teacher Epoch {epoch+1}") as pbar:
                for batch in dataloader:
                    # Keep retain rows only; a batch without any is skipped
                    retain_rows = batch['split'] == 0
                    if not retain_rows.any():
                        pbar.update(1)
                        continue

                    model_inputs = {
                        field: batch[field][retain_rows].to(teacher_device)
                        for field in ("input_ids", "attention_mask", "labels")
                    }

                    optimizer.zero_grad()
                    loss = self.good_teacher(**model_inputs).loss
                    loss.backward()
                    optimizer.step()

                    loss_value = loss.item()
                    step_losses.append(loss_value)
                    trained_batches += 1
                    pbar.update(1)

                    # Refresh the displayed loss only every 100 trained batches
                    if trained_batches % 100 == 0:
                        pbar.set_postfix({"Loss": f"{loss_value:.4f}"})

            if step_losses:
                avg_loss = np.mean(step_losses)
                print(f"Good Teacher Epoch {epoch+1} - Average Loss: {avg_loss:.4f}")

        print("Good teacher training completed")
        self.save_good_teacher(save_path)
        print(f"Good teacher adapter saved to {save_path}")
        
    def save_model(self, save_path):
        """Save student model and tokenizer"""
        self.student_model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)
        
    def save_good_teacher(self, save_path):
        """Save only the LoRA adapter of the good teacher"""
        self.good_teacher.save_pretrained(save_path)

    def load_teacher(self, teacher_path):
        """
        Load a pre-trained good teacher from ``teacher_path`` and freeze it.

        The model is put in eval mode, gradients are disabled for all of its
        parameters, and it is moved to the teacher device.
        """
        self.good_teacher = AutoModelForCausalLM.from_pretrained(teacher_path)
        teacher = self.good_teacher

        # Freeze the teacher: inference-mode layers and no gradient tracking
        teacher.eval()
        teacher.requires_grad_(False)

        teacher.to(self.device_map["teacher"])
        print("Good teacher loaded and frozen")

## 5. Setup Trainer and Training

In [None]:
# Switch for the optional teacher-training step: the setup cell below only calls
# trainer.setup_models() and trainer.train_good_teacher(...) when this is True.
train_good_teacher = False

In [None]:
# Configure LoRA for teacher and student models (the same settings are used for both)
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    bias="none",
)

# Initialize trainer with simplified configuration
trainer = DualTeacherTrainer(
    model_path=MODEL_PATH,
    tokenizer=tokenizer,
    teacher_lora_config=lora_config,
    student_lora_config=lora_config,
    device_map={"student": "cuda:0", "teacher": "cuda:1"}
)

# Setup models for teacher training (if needed)
if train_good_teacher:
    trainer.setup_models()
    # Train good teacher to preserve knowledge on retain data. The retain
    # loaders are used here; the names `train_dataloader` / `val_dataloader`
    # used previously are not defined anywhere in this notebook.
    trainer.train_good_teacher(retain_train_dataloader, retain_val_dataloader, num_epochs=5)

In [None]:
# Create only the student here; the good teacher is not built from scratch
# because a pre-trained one is loaded right after.
trainer.setup_models(skip_teacher_setup=True)

# Attach the pre-trained good teacher
trainer.load_teacher(GOOD_TEACHER_PATH)

# Run the mixed dual-teacher training: retain and forget batches are combined
# within every epoch, with validation after each epoch and early stopping.
trainer.train_student_mixed(
    retain_dataloader=retain_train_dataloader,
    forget_dataloader=forget_train_dataloader,
    retain_val_dataloader=retain_val_dataloader,
    forget_val_dataloader=forget_val_dataloader,
    num_epochs=6,
    lr=1e-5,
    val_freq=1,
    patience=3,
)