# Unlearning DualTeacher


## 1. Setup and Imports

In [None]:

import json
from pathlib import Path

import numpy as np
import pandas as pd
import torch

from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType

# Input locations inside the Kaggle environment
GOOD_TEACHER_PATH = "/kaggle/input/good-teacher"
MIA_TRAIN_PATH = "/kaggle/input/mia-dataset"
MIA_VAL_PATH = "/kaggle/input/mia-dataset-val"
DATA_PATH = "/kaggle/input/olmo-model/semeval25-unlearning-data"
MODEL_PATH = "/kaggle/input/olmo-model/semeval25-unlearning-1B-model"


# Print how many CUDA devices PyTorch can see, then the name of each one
print(f"Available GPUs: {torch.cuda.device_count()}")
for i, gpu_name in enumerate(
    map(torch.cuda.get_device_name, range(torch.cuda.device_count()))
):
    print(f"GPU {i}: {gpu_name}")

  Preparing metadata (setup.py) ... done
  Building wheel for rouge-score (setup.py) ... done
Note: you may need to restart the kernel to use updated packages.


2025-09-14 09:23:17.577867: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757841797.769665      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757841797.822378      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Available GPUs: 2
GPU 0: Tesla T4
GPU 1: Tesla T4


## 2. Data and Model Loading

In [2]:
# Read the four retain/forget parquet splits from the Kaggle input directory
def _read_split(split_name):
    """Load `<split_name>-00000-of-00001.parquet` from DATA_PATH/data with pyarrow."""
    return pd.read_parquet(f"{DATA_PATH}/data/{split_name}-00000-of-00001.parquet", engine='pyarrow')


retain_train_df = _read_split("retain_train")
retain_validation_df = _read_split("retain_validation")
forget_train_df = _read_split("forget_train")
forget_validation_df = _read_split("forget_validation")

# Write every split out as JSONL (one record per line) for evaluation scripts;
# the output folders are created first so no shell commands are needed.
for out_dir in ('train', 'validation'):
    Path(out_dir).mkdir(parents=True, exist_ok=True)

for frame, target in (
    (retain_train_df, 'train/retain.jsonl'),
    (forget_train_df, 'train/forget.jsonl'),
    (retain_validation_df, 'validation/retain.jsonl'),
    (forget_validation_df, 'validation/forget.jsonl'),
):
    frame.to_json(target, orient='records', lines=True)

# Tokenizer: fall back to the EOS token for padding when no pad token is defined
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Datasets saved and tokenizer loaded")

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

Datasets saved and tokenizer loaded


## 3. Dataset

In [3]:
class UnlearningDataset(Dataset):
    """
    Custom PyTorch Dataset for machine unlearning tasks.

    Each record's ``input`` and ``output`` texts are joined with a single space,
    tokenized, and padded/truncated to ``max_length``. The prompt is tokenized a
    second time on its own to find where the output starts in the sequence.

    Args:
        data_source: A pandas DataFrame, or a path (``str`` or ``pathlib.Path``)
            to a JSONL file. Records need ``input`` and ``output`` fields; an
            optional ``split`` field marks ``"forget"`` samples.
        tokenizer: HuggingFace-style tokenizer callable used for text encoding
        max_length: Maximum sequence length for tokenization

    Raises:
        TypeError: If ``data_source`` is neither a DataFrame nor a path.

    Returns (per item):
        Dictionary containing:
        - input_ids: Tokenized input + output sequence
        - attention_mask: Attention mask for the sequence
        - start_locs: Index where output begins, capped at ``max_length``
        - labels: Copy of input_ids with padded positions set to -100 so that
          language-modeling losses skip them
        - split: Binary indicator (0=retain, 1=forget)
    """
    def __init__(self, data_source, tokenizer, max_length=256):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Handle both DataFrame and file path inputs
        if isinstance(data_source, pd.DataFrame):
            self.data = data_source
            print(f"Loaded {len(self.data)} samples from DataFrame")
        elif isinstance(data_source, (str, Path)):
            with open(data_source, 'r', encoding='utf-8') as f:
                # Blank lines (e.g. a trailing newline) are not JSON records
                records = [json.loads(line) for line in f if line.strip()]
            self.data = pd.DataFrame(records)
            print(f"Loaded {len(self.data)} samples from {data_source}")
        else:
            # Fail here rather than with an AttributeError on first access
            raise TypeError(
                "data_source must be a pandas DataFrame or a path to a JSONL file, "
                f"got {type(data_source).__name__}"
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]
        input_text = item["input"]
        output_text = item["output"]

        # Combine input and output for full sequence training
        combined = f"{input_text} {output_text}"
        tokenized = self.tokenizer(
            combined,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        input_ids = tokenized["input_ids"].squeeze(0)
        attention_mask = tokenized["attention_mask"].squeeze(0)

        # Labels get their own storage, and padded positions are set to the
        # ignore index so a language-modeling loss is not computed on padding.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        # Length of the prompt alone = index where the output starts.
        # The prompt is tokenized without truncation, so cap the index at
        # max_length to keep it inside the (possibly truncated) sequence.
        prompt_len = self.tokenizer(
            input_text,
            return_tensors="pt"
        )["input_ids"].shape[-1]
        start_loc = min(prompt_len, self.max_length)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "start_locs": start_loc,  # Output start index
            "labels": labels,
            "split": 1 if item.get("split", "retain") == "forget" else 0  # Binary split indicator
        }

In [4]:
# Wrap every split in an UnlearningDataset and a DataLoader
batch_size = 4

# Retain and forget data stay in separate datasets so each split can be
# trained with its own strategy
retain_train_dataset = UnlearningDataset(retain_train_df, tokenizer)
forget_train_dataset = UnlearningDataset(forget_train_df, tokenizer)
retain_val_dataset = UnlearningDataset(retain_validation_df, tokenizer)
forget_val_dataset = UnlearningDataset(forget_validation_df, tokenizer)

# Training loaders reshuffle every epoch
retain_train_dataloader = DataLoader(retain_train_dataset, batch_size=batch_size, shuffle=True)
forget_train_dataloader = DataLoader(forget_train_dataset, batch_size=batch_size, shuffle=True)

# Validation loaders keep a fixed order so evaluation is consistent
retain_val_dataloader = DataLoader(retain_val_dataset, batch_size=batch_size, shuffle=False)
forget_val_dataloader = DataLoader(forget_val_dataset, batch_size=batch_size, shuffle=False)

Loaded 1136 samples from DataFrame
Loaded 1112 samples from DataFrame
Loaded 278 samples from DataFrame
Loaded 254 samples from DataFrame


## 4. DualTeacher Trainer Class

In [None]:
class DualTeacherTrainer:
    """
    Implementation of the Dual Teacher approach for machine unlearning.
    
    This class implements a dual-teacher distillation strategy where:
    1. A "good teacher" is trained on retain data to preserve important knowledge
    2. A "bad teacher" generates noisy/uniform outputs to forget unwanted knowledge  
    3. A student model learns from both teachers using different loss functions
    
    The approach uses sequential training where retain and forget samples are
    processed in separate phases to better balance the competing objectives.
    """
    
    def __init__(self, model_path, tokenizer, teacher_lora_config, student_lora_config, device_map=None):
        """
        Initialize the DualTeacher trainer.
        
        Args:
            model_path (str): Path to base model for initialization
            tokenizer: HuggingFace tokenizer
            teacher_lora_config: LoRA configuration for teacher model
            student_lora_config: LoRA configuration for student model  
            device_map (dict): GPU assignment for models (default: student=cuda:0, teacher=cuda:1)
        """
        self.model_path = model_path
        self.tokenizer = tokenizer
        self.teacher_lora_config = teacher_lora_config
        self.student_lora_config = student_lora_config
        self.device_map = device_map or {"student": "cuda:0", "teacher": "cuda:1"}
        
        # Model instances
        self.good_teacher = None
        self.student_model = None
        
        # Training tracking variables
        self.best_val_loss = float('inf')
        self.best_epoch = 0
        
        # Sequential training parameters
        self.forget_weight_multiplier = 0.5  # Reduce forget influence to prevent catastrophic forgetting
        
        
    def setup_models(self, skip_teacher_setup=False):
        """
        Initialize and setup both teacher and student models with LoRA adapters.

        Each role gets its own freshly loaded base model. ``get_peft_model``
        injects the adapter layers into the model object it receives, so
        wrapping one shared base instance twice would leave teacher and student
        pointing at the same weights (with both adapters stacked), and the last
        ``.to(...)`` call would move both of them to the same device.

        Args:
            skip_teacher_setup (bool): If True, only setup student model and
                leave ``self.good_teacher`` untouched

        Raises:
            RuntimeError: If fewer than two CUDA devices are visible.
        """
        # Ensure two GPUs are available for optimal performance
        if torch.cuda.device_count() < 2:
            raise RuntimeError("This implementation expects 2 GPUs for optimal performance.")

        print("Setting up models...")

        # Setup good teacher with LoRA (for knowledge preservation)
        if not skip_teacher_setup:
            teacher_base = AutoModelForCausalLM.from_pretrained(self.model_path, local_files_only=True)
            self.good_teacher = get_peft_model(teacher_base, self.teacher_lora_config)
            self.good_teacher = self.good_teacher.to(self.device_map["teacher"])
            self.good_teacher.print_trainable_parameters()

        # Setup student model with LoRA (main model to be unlearned) on a
        # separate copy of the base weights
        student_base = AutoModelForCausalLM.from_pretrained(self.model_path, local_files_only=True)
        self.student_model = get_peft_model(student_base, self.student_lora_config)
        self.student_model = self.student_model.to(self.device_map["student"])
        self.student_model.print_trainable_parameters()

        print("Models setup completed")
        
    
    def create_bad_teacher_logits(self, good_teacher_logits):
        """
        Create bad-teacher logits as noisy values around uniform distribution.
        
        This simulates a teacher that provides poor/confusing guidance for 
        forget samples, helping the student unlearn unwanted knowledge.
        
        Args:
            good_teacher_logits (torch.Tensor): Logits from good teacher
            
        Returns:
            torch.Tensor: Noisy uniform logits for bad teacher
        """
        # Start with uniform logits (all vocabulary equally likely)
        uniform_logits = torch.zeros_like(good_teacher_logits)
        # Add Gaussian noise to create variability
        noisy_logits = uniform_logits + 0.1 * torch.randn_like(good_teacher_logits)
        return noisy_logits
        
    
    def compute_retain_only_loss(self, batch):
        """
        Compute loss for retain samples only.
        
        Uses KL divergence to align student outputs with good teacher,
        plus entropy regularization to encourage confident predictions.
        
        Args:
            batch (dict): Batch of retain samples
            
        Returns:
            torch.Tensor: Combined loss for retain samples
        """
        # Device assignments
        student_device = self.device_map["student"]
        teacher_device = self.device_map["teacher"]
    
        # Student forward pass
        input_ids_student = batch["input_ids"].to(student_device)
        attention_mask_student = batch["attention_mask"].to(student_device)
        student_logits = self.student_model(input_ids_student, attention_mask=attention_mask_student).logits
        student_log_probs = torch.nn.functional.log_softmax(student_logits, dim=-1).to(teacher_device)
    
        # Teacher forward pass (no gradients needed)
        input_ids_teacher = batch["input_ids"].to(teacher_device)
        attention_mask_teacher = batch["attention_mask"].to(teacher_device)
    
        with torch.no_grad():
            good_teacher_logits = self.good_teacher(input_ids_teacher, attention_mask=attention_mask_teacher).logits
            good_teacher_probs = torch.nn.functional.softmax(good_teacher_logits, dim=-1)
    
        # KL divergence loss: align student with good teacher
        retain_kl = torch.nn.functional.kl_div(
            student_log_probs,
            good_teacher_probs,
            reduction="batchmean",  # Average over batch and sequence
            log_target=False
        )

        # Entropy regularization: encourage confident predictions
        entropy_loss = -(student_log_probs.exp() * student_log_probs).sum(-1).mean()
        
        return 3.0 * retain_kl - 0.1 * entropy_loss  # Negative entropy encourages confidence
    
    def compute_forget_only_loss(self, batch):
        """
        Compute loss for forget samples only.
        
        Uses KL divergence to align student with bad teacher (uniform/noisy),
        plus entropy regularization to encourage uncertain/random predictions.
        
        Args:
            batch (dict): Batch of forget samples
            
        Returns:
            torch.Tensor: Combined loss for forget samples  
        """
        # Device assignments
        student_device = self.device_map["student"]
        teacher_device = self.device_map["teacher"]
    
        # Student forward pass
        input_ids_student = batch["input_ids"].to(student_device)
        attention_mask_student = batch["attention_mask"].to(student_device)
        student_logits = self.student_model(input_ids_student, attention_mask=attention_mask_student).logits
        student_log_probs = torch.nn.functional.log_softmax(student_logits, dim=-1).to(teacher_device)
    
        # Generate bad teacher guidance (no gradients needed)
        input_ids_teacher = batch["input_ids"].to(teacher_device)
        attention_mask_teacher = batch["attention_mask"].to(teacher_device)
    
        with torch.no_grad():
            good_teacher_logits = self.good_teacher(input_ids_teacher, attention_mask=attention_mask_teacher).logits
            bad_teacher_logits = self.create_bad_teacher_logits(good_teacher_logits)
            bad_teacher_probs = torch.nn.functional.softmax(bad_teacher_logits, dim=-1)
    
        # KL divergence loss: align student with bad teacher
        forget_kl = torch.nn.functional.kl_div(
            student_log_probs,
            bad_teacher_probs,
            reduction="batchmean",
            log_target=False
        )

        # Entropy regularization: encourage uncertain predictions
        entropy_loss = -(student_log_probs.exp() * student_log_probs).sum(-1).mean()
        
        return 2.0 * forget_kl + 0.1 * entropy_loss  # Positive entropy encourages uncertainty

    def validate_retain_only(self, retain_val_dataloader):
        """Validate student model on retain samples only."""
        self.student_model.eval()
        val_losses = []
        
        with torch.no_grad():
            for batch in retain_val_dataloader:
                loss = self.compute_retain_only_loss(batch)
                val_losses.append(loss.item())
        
        self.student_model.train()
        return np.mean(val_losses) if val_losses else float('inf')

    def validate_forget_only(self, forget_val_dataloader):
        """Validate student model on forget samples only."""
        self.student_model.eval()
        val_losses = []
        
        with torch.no_grad():
            for batch in forget_val_dataloader:
                loss = self.compute_forget_only_loss(batch)
                val_losses.append(loss.item())
        
        self.student_model.train()
        return np.mean(val_losses) if val_losses else float('inf')
      
        
    def train_student_sequential(self, retain_dataloader, forget_dataloader,
                                 retain_val_dataloader=None, forget_val_dataloader=None,
                                 num_epochs=6, lr=1e-5, val_freq=1, patience=3):
        """
        Train student model using sequential dual-teacher approach.

        Every epoch runs two phases over the full loaders:
        1. Retain phase: minimise ``compute_retain_only_loss`` (good teacher)
        2. Forget phase: minimise ``compute_forget_only_loss`` scaled by
           ``self.forget_weight_multiplier`` (bad teacher)

        When both validation loaders are given, a combined validation loss
        ``0.7 * retain + 0.3 * forget`` is computed every ``val_freq`` epochs.
        Both parts are the objectives minimised during training, so lower is
        better for each. The best model is saved to
        ``studentmodel_best_val_sequential``, and training stops after
        ``patience`` validation checks without improvement. A checkpoint
        ``studentmodel_epoch_<n>_sequential`` is written after every epoch that
        is not cut short by early stopping.

        Args:
            retain_dataloader: DataLoader for retain samples
            forget_dataloader: DataLoader for forget samples
            retain_val_dataloader: Validation data for retain samples
            forget_val_dataloader: Validation data for forget samples
            num_epochs (int): Number of training epochs
            lr (float): Learning rate
            val_freq (int): Validation frequency (every N epochs)
            patience (int): Early stopping patience
        """
        print("Training student with SEQUENTIAL dual-teacher approach...")

        self.student_model.train()
        optimizer = torch.optim.AdamW(self.student_model.parameters(), lr=lr, weight_decay=0.01)

        patience_counter = 0
        has_validation = retain_val_dataloader is not None and forget_val_dataloader is not None

        for epoch in range(num_epochs):
            retain_losses = []
            forget_losses = []

            # Calculate total batches for progress tracking
            total_batches = len(retain_dataloader) + len(forget_dataloader)

            with tqdm(total=total_batches, desc=f"Student Epoch {epoch+1} (Sequential)") as pbar:

                # Phase 1: Training on retain samples (preserve knowledge)
                pbar.set_description(f"Epoch {epoch+1} - Phase 1: Retain")
                for batch in retain_dataloader:
                    optimizer.zero_grad()
                    loss = self.compute_retain_only_loss(batch)
                    loss.backward()
                    optimizer.step()

                    retain_losses.append(loss.item())
                    pbar.update(1)
                    pbar.set_postfix({"Retain Loss": f"{loss.item():.4f}"})

                # Phase 2: Training on forget samples (unlearn knowledge)
                pbar.set_description(f"Epoch {epoch+1} - Phase 2: Forget")
                for batch in forget_dataloader:
                    optimizer.zero_grad()
                    # Apply reduced weight to prevent catastrophic forgetting
                    raw_loss = self.compute_forget_only_loss(batch)
                    weighted_loss = raw_loss * self.forget_weight_multiplier
                    weighted_loss.backward()
                    optimizer.step()

                    # Log the unweighted loss directly; dividing the weighted
                    # value back would fail for a multiplier of 0
                    forget_losses.append(raw_loss.item())
                    pbar.update(1)
                    pbar.set_postfix({"Forget Loss": f"{weighted_loss.item():.4f}"})

            # Epoch summary logging
            avg_retain_loss = np.mean(retain_losses) if retain_losses else 0.0
            avg_forget_loss = np.mean(forget_losses) if forget_losses else 0.0

            print(f"Student Epoch {epoch+1}")
            print(f"   └─ Retain: {avg_retain_loss:.4f}, Forget: {avg_forget_loss:.4f} (×{self.forget_weight_multiplier})")

            # Validation and early stopping
            if has_validation and (epoch + 1) % val_freq == 0:
                print("Running validation...")
                retain_val_loss = self.validate_retain_only(retain_val_dataloader)
                forget_val_loss = self.validate_forget_only(forget_val_dataloader)

                # Combined validation metric for early stopping.
                # The forget loss measures the distance to the bad teacher, the
                # very quantity the forget phase minimises, so (like the retain
                # loss) a lower value is better and it must not be inverted.
                combined_val_loss = 0.7 * retain_val_loss + 0.3 * forget_val_loss

                print(f"Validation - Retain: {retain_val_loss:.4f}, Forget: {forget_val_loss:.4f}, Combined: {combined_val_loss:.4f}")

                # Check for improvement
                if combined_val_loss < self.best_val_loss:
                    self.best_val_loss = combined_val_loss
                    self.best_epoch = epoch + 1
                    patience_counter = 0
                    print(f"New best validation score: {combined_val_loss:.4f}")

                    # Save best performing model
                    self.save_model("studentmodel_best_val_sequential")
                else:
                    patience_counter += 1
                    print(f"No improvement for {patience_counter} validation checks")

                # Early stopping check
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1} (no improvement for {patience} validations)")
                    break

            # Save checkpoint after each epoch
            self.save_model(f"studentmodel_epoch_{epoch+1}_sequential")

        print("Student sequential training completed")
        print(f"Best validation score: {self.best_val_loss:.4f} at epoch {self.best_epoch}")

    def train_good_teacher(self, dataloader, val_dataloader=None, num_epochs=2, lr=1e-4, save_path="good_teacher_adapter", val_freq=1):
        """
        Train the good teacher model on retain samples using LoRA fine-tuning.

        Only rows with ``split == 0`` (retain) are used; batches without any
        retain row are skipped. The model's own language-modeling loss is
        optimised, with labels at padded positions set to -100 so padding does
        not contribute. After training, the teacher is saved via
        ``save_good_teacher(save_path)``.

        Args:
            dataloader: Training data (will filter for retain samples only)
            val_dataloader: Validation data (optional). When given, the mean
                retain loss on it is printed every ``val_freq`` epochs.
            num_epochs (int): Number of training epochs
            lr (float): Learning rate
            save_path (str): Path to save LoRA adapter
            val_freq (int): Validation frequency (every N epochs)
        """
        print("Training good teacher with LoRA...")

        teacher_device = self.device_map["teacher"]

        def retain_inputs(batch):
            """Return (input_ids, attention_mask, labels) for the retain rows, or None if there are none."""
            retain_mask = batch['split'] == 0
            if not retain_mask.any():
                return None
            input_ids = batch['input_ids'][retain_mask].to(teacher_device)
            attention_mask = batch['attention_mask'][retain_mask].to(teacher_device)
            labels = batch['labels'][retain_mask].to(teacher_device)
            # -100 is the ignore index of the LM loss: never train on padding
            labels = labels.masked_fill(attention_mask == 0, -100)
            return input_ids, attention_mask, labels

        self.good_teacher.to(teacher_device)
        self.good_teacher.train()
        optimizer = torch.optim.AdamW(self.good_teacher.parameters(), lr=lr)

        for epoch in range(num_epochs):
            print(f"Epoch {epoch + 1}/{num_epochs} - Good Teacher Training")

            epoch_losses = []
            retain_batches_processed = 0

            with tqdm(total=len(dataloader), desc=f"Good Teacher Epoch {epoch+1}") as pbar:
                for batch in dataloader:
                    inputs = retain_inputs(batch)
                    if inputs is None:
                        pbar.update(1)
                        continue
                    input_ids, attention_mask, labels = inputs

                    # Standard language modeling training
                    optimizer.zero_grad()
                    outputs = self.good_teacher(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    loss = outputs.loss

                    loss.backward()
                    optimizer.step()

                    epoch_losses.append(loss.item())
                    retain_batches_processed += 1
                    pbar.update(1)

                    if retain_batches_processed % 100 == 0:
                        pbar.set_postfix({"Loss": f"{loss.item():.4f}"})

            if epoch_losses:
                avg_loss = np.mean(epoch_losses)
                print(f"Good Teacher Epoch {epoch+1} - Average Loss: {avg_loss:.4f}")

            # Optional validation pass on the retain rows of val_dataloader
            if val_dataloader is not None and val_freq > 0 and (epoch + 1) % val_freq == 0:
                self.good_teacher.eval()
                val_losses = []
                with torch.no_grad():
                    for batch in val_dataloader:
                        inputs = retain_inputs(batch)
                        if inputs is None:
                            continue
                        input_ids, attention_mask, labels = inputs
                        outputs = self.good_teacher(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                        val_losses.append(outputs.loss.item())
                self.good_teacher.train()

                if val_losses:
                    print(f"Good Teacher Epoch {epoch+1} - Validation Loss: {np.mean(val_losses):.4f}")

        print("Good teacher training completed")

        # Save LoRA adapter only (more efficient than full model)
        self.save_good_teacher(save_path)
        print(f"Good teacher adapter saved to {save_path}")
        
    def save_good_teacher(self, save_path):
        """Save the LoRA adapter of the good teacher model."""
        self.good_teacher.save_pretrained(save_path)
        
    def save_model(self, save_path):
        """Save student model and tokenizer to specified path."""
        self.student_model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    def load_teacher(self, teacher_path):
        """
        Load a good teacher from disk and freeze it.

        The model is switched to eval mode, every parameter has gradients
        disabled, and it is moved to the teacher device from ``self.device_map``.

        Args:
            teacher_path (str): Location passed to ``AutoModelForCausalLM.from_pretrained``
        """
        teacher = AutoModelForCausalLM.from_pretrained(teacher_path)
        self.good_teacher = teacher

        # Inference only: eval mode and no gradient tracking on any weight
        teacher.eval()
        for weight in teacher.parameters():
            weight.requires_grad = False

        teacher.to(self.device_map["teacher"])
        print("Good teacher loaded and frozen")

## 5. Setup Trainer and Training

In [6]:
# LoRA (Low-Rank Adaptation) configuration for both teacher and student models
# LoRA allows efficient fine-tuning by training only small adapter layers
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,       # Causal language modeling task
    inference_mode=False,               # Training mode (not inference)
    r=16,                              # Rank of adaptation (controls adapter size)
    lora_alpha=32,                     # Scaling factor (typically 2x rank)
    lora_dropout=0.1,                  # Dropout for regularization
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # Transformer attention modules to adapt
    bias="none",                       # Don't adapt bias parameters
)

# Initialize the DualTeacher trainer with GPU configuration
# Uses dual-GPU setup for optimal performance (teacher on GPU1, student on GPU0)
trainer = DualTeacherTrainer(
    model_path=MODEL_PATH,
    tokenizer=tokenizer,
    teacher_lora_config=lora_config,
    student_lora_config=lora_config,
    device_map={"student": "cuda:0", "teacher": "cuda:1"}
)

# Training configuration flags
train_good_teacher = False  # Set to True if you want to train teacher from scratch

# Teacher training section (optional - only if training teacher from scratch)
if train_good_teacher:
    trainer.setup_models()
    # Train good teacher on retain data to preserve knowledge.
    # The retain loaders are the ones defined in this notebook; the teacher is
    # trained on retain samples only.
    trainer.train_good_teacher(retain_train_dataloader, retain_val_dataloader, num_epochs=5)

In [None]:
# Build only the student here: the teacher comes from a saved checkpoint,
# so its LoRA setup is skipped
trainer.setup_models(skip_teacher_setup=True)

# Load the pre-trained good teacher; load_teacher freezes it and puts it in
# eval mode so it can serve as the distillation target
trainer.load_teacher(GOOD_TEACHER_PATH)

# Sequential dual-teacher training of the student. Each epoch has two phases:
# - Retain phase: follow the good teacher to preserve knowledge
# - Forget phase: follow the bad teacher to unlearn unwanted knowledge
trainer.train_student_sequential(
    retain_dataloader=retain_train_dataloader,     # retain training batches
    forget_dataloader=forget_train_dataloader,     # forget training batches
    retain_val_dataloader=retain_val_dataloader,   # retain validation batches
    forget_val_dataloader=forget_val_dataloader,   # forget validation batches
    num_epochs=6,                                  # total training epochs
)