In [1]:
import os
import json
import random
import numpy as np
import pandas as pd
import torch
import time
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
import transformers
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    TrainerCallback,
    TrainerState,
    TrainerControl,
    get_cosine_schedule_with_warmup
)
from transformers.trainer_utils import set_seed
import evaluate
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
import bert_score
from collections import Counter
import nltk
import warnings
from typing import List, Tuple
from tqdm import tqdm
from datetime import datetime
import logging
from tabulate import tabulate
from transformers.utils.logging import set_verbosity_error
import gc
import optuna  # --- NEW ---

warnings.filterwarnings("ignore")

# Setup logging for VS Code
def setup_logging():
    """Setup logging configuration"""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s | %(levelname)s | %(message)s',
        datefmt='%H:%M:%S',
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler('training.log', encoding='utf-8')
        ]
    )
    return logging.getLogger(__name__)

logger = setup_logging()
#set_verbosity_error()

# Ensure NLTK data is available
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

class QuestionGenerationDataset(Dataset):
    def __init__(self, contexts, questions, tokenizer, max_length=512):
        self.contexts = contexts
        self.questions = questions
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.contexts)

    def __getitem__(self, idx):
        context = self.contexts[idx]
        question = self.questions[idx]

        input_text = f"Generate a question from context: {context}"

        input_encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding=False,
            truncation=True,
            return_tensors="pt"
        )

        target_encoding = self.tokenizer(
            question,
            max_length=128,
            padding=False,
            truncation=True,
            return_tensors="pt"
        )

        return {
            "input_ids": input_encoding.input_ids.flatten(),
            "attention_mask": input_encoding.attention_mask.flatten(),
            "labels": target_encoding.input_ids.flatten()
        }

class DataProcessor:
    def __init__(self, data_dir, tokenizer):
        self.data_dir = data_dir
        self.tokenizer = tokenizer

    def load_fairytale_qa(self, file_path: str) -> Tuple[List[str], List[str]]:
        contexts, questions = [], []
        if os.path.exists(file_path):
            logger.info(f"Loading FairytaleQA from: {file_path}")
            try:
                df = pd.read_csv(file_path, encoding='utf-8')
                for _, row in df.iterrows():
                    if pd.notna(row.get('context', '')) and pd.notna(row.get('question', '')):
                        context, question = str(row['context']).strip(), str(row['question']).strip()
                        if context and question and len(self.tokenizer.encode(context)) < 512:
                            contexts.append(context)
                            questions.append(question)
                logger.info(f"Loaded {len(contexts)} samples from FairytaleQA")
            except Exception as e:
                logger.error(f"Error reading {file_path}: {e}")
        else:
            logger.warning(f"File not found: {file_path}")
        return contexts, questions

    def load_mctest(self, file_path: str) -> Tuple[List[str], List[str]]:
        contexts, questions = [], []
        if os.path.exists(file_path):
            logger.info(f"Loading MCTest from: {file_path}")
            try:
                df = pd.read_csv(file_path, encoding='utf-8')
                for _, row in df.iterrows():
                    if pd.notna(row.get('context', '')) and pd.notna(row.get('question', '')):
                        context, question = str(row['context']).strip(), str(row['question']).strip()
                        if context and question and len(self.tokenizer.encode(context)) < 512:
                            contexts.append(context)
                            questions.append(question)
                logger.info(f"Loaded {len(contexts)} samples from MCTest")
            except Exception as e:
                logger.error(f"Error reading {file_path}: {e}")
        else:
            logger.warning(f"File not found: {file_path}")
        return contexts, questions

    def load_squad(self, file_path: str) -> Tuple[List[str], List[str]]:
        contexts, questions = [], []
        if os.path.exists(file_path):
            logger.info(f"Loading SQuAD from: {file_path}")
            try:
                df = pd.read_csv(file_path, encoding='utf-8')
                for _, row in df.iterrows():
                    if pd.notna(row.get('context', '')) and pd.notna(row.get('question', '')):
                        context, question = str(row['context']).strip(), str(row['question']).strip()
                        if context and question and len(self.tokenizer.encode(context)) < 512:
                            contexts.append(context)
                            questions.append(question)
                logger.info(f"Loaded {len(contexts)} samples from SQuAD")
            except Exception as e:
                logger.error(f"Error reading {file_path}: {e}")
        else:
            logger.warning(f"File not found: {file_path}")
        return contexts, questions

    def load_all_datasets(self) -> Tuple[List[str], List[str], List[str], List[str]]:
        all_train_contexts, all_train_questions = [], []
        all_val_contexts, all_val_questions = [], []

        datasets = {
            "FairytaleQA": (self.load_fairytale_qa, ("FairytaleQA_train.csv", "FairytaleQA_validation.csv")),
            "MCTest": (self.load_mctest, ("mctest_train.csv", "mctest_validation.csv")),
            "SQuAD": (self.load_squad, ("squad_train_v1.csv", "squad_validation_v1.csv")),
        }

        for name, (loader_fn, (train_file, val_file)) in datasets.items():
            train_c, train_q = loader_fn(os.path.join(self.data_dir, train_file))
            val_c, val_q = loader_fn(os.path.join(self.data_dir, val_file))
            all_train_contexts.extend(train_c)
            all_train_questions.extend(train_q)
            all_val_contexts.extend(val_c)
            all_val_questions.extend(val_q)

        logger.info(f"Loaded {len(all_train_contexts)} training samples and {len(all_val_contexts)} validation samples")

        train_combined = list(zip(all_train_contexts, all_train_questions))
        val_combined = list(zip(all_val_contexts, all_val_questions))
        random.shuffle(train_combined)
        random.shuffle(val_combined)

        if train_combined:
            all_train_contexts, all_train_questions = zip(*train_combined)
        if val_combined:
            all_val_contexts, all_val_questions = zip(*val_combined)

        return (list(all_train_contexts), list(all_train_questions),
                list(all_val_contexts), list(all_val_questions))

class AdvancedEvaluationMetrics:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.rouge_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
        self.smoothing = SmoothingFunction().method1

    def compute_bleu(self, references, predictions):
        bleu_scores = {f"bleu_{i}": [] for i in range(1, 5)}
        for ref, pred in zip(references, predictions):
            ref_tokens, pred_tokens = ref.split(), pred.split()
            for i in range(1, 5):
                weights = [1/i] * i
                score = sentence_bleu([ref_tokens], pred_tokens, weights=weights, smoothing_function=self.smoothing)
                bleu_scores[f"bleu_{i}"].append(score)
        return {k: np.mean(v) for k, v in bleu_scores.items()}

    def compute_rouge_l(self, references, predictions):
        scores = [self.rouge_scorer.score(ref, pred)['rougeL'].fmeasure for ref, pred in zip(references, predictions)]
        return {"rouge_l": np.mean(scores)}

    def compute_meteor(self, references, predictions):
        try:
            meteor = evaluate.load("meteor")
            return {"meteor": meteor.compute(predictions=predictions, references=references)["meteor"]}
        except Exception:
            return {"meteor": 0.0}

    def compute_bert_score(self, references, predictions):
        try:
            P, R, F1 = bert_score.score(predictions, references, lang="en", verbose=False)
            return {"bert_score": F1.mean().item()}
        except Exception:
            return {"bert_score": 0.0}

    # --- MODIFICATION 1: SELF-BLEU FIXED ---
    def compute_self_bleu(self, predictions):
        if len(predictions) < 2: return {"self_bleu": 0.0}
        scores = []
        for i, pred in enumerate(predictions):
            others = predictions[:i] + predictions[i+1:]
            pred_tokens = pred.split()
            # The fix: Using all other sentences, not just the first 10
            other_tokens = [other.split() for other in others]
            if other_tokens:
                score = sentence_bleu(other_tokens, pred_tokens, smoothing_function=self.smoothing)
                scores.append(score)
        return {"self_bleu": np.mean(scores) if scores else 0.0}

    def compute_distinct_n(self, predictions, n):
        all_ngrams = [tuple(tokens[i:i+n]) for pred in predictions for tokens in [pred.split()] for i in range(len(tokens)-n+1)]
        if not all_ngrams: return 0.0
        return len(set(all_ngrams)) / len(all_ngrams)

    def compute_all_metrics(self, references, predictions):
        if not references or not predictions:
            return {m: 0.0 for m in ["bleu_1", "bleu_2", "bleu_3", "bleu_4", "rouge_l", "meteor", "bert_score", "self_bleu", "distinct_1", "distinct_2"]}
        
        metrics = {}
        logger.info("Computing quality and diversity metrics...")
        metrics.update(self.compute_bleu(references, predictions))
        metrics.update(self.compute_rouge_l(references, predictions))
        metrics.update(self.compute_meteor(references, predictions))
        metrics.update(self.compute_bert_score(references, predictions))
        metrics.update(self.compute_self_bleu(predictions))
        metrics["distinct_1"] = self.compute_distinct_n(predictions, 1)
        metrics["distinct_2"] = self.compute_distinct_n(predictions, 2)
        return metrics

class DiverseDecoder:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def diverse_generate(self, input_ids, attention_mask, num_return_sequences=1):
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=128,
            num_beams=8,
            num_return_sequences=num_return_sequences,
            num_beam_groups=4,
            diversity_penalty=1.5,
            no_repeat_ngram_size=2,
            early_stopping=True,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id
        )
        return outputs

class MetricsLogger:
    def __init__(self):
        self.evaluation_history = []

    def log_epoch_progress(self, epoch, loss, learning_rate):
        logger.info(f"Epoch {epoch:>.2f} | Loss: {loss:.4f} | LR: {learning_rate:.2e}")

    def log_evaluation(self, metrics, epoch=None, step=None):
        eval_record = {'epoch': epoch, 'step': step, 'timestamp': datetime.now(), **metrics}
        self.evaluation_history.append(eval_record)
        self.display_metrics_table(metrics, epoch)

    def display_metrics_table(self, metrics, epoch=None, step=None):
        print("\n" + "="*80)
        print(f"EVALUATION RESULTS - Epoch {epoch}")
        print("="*80)
        
        quality = {k: v for k, v in metrics.items() if any(x in k for x in ['bleu', 'rouge', 'meteor', 'bert_score'])}
        diversity = {k: v for k, v in metrics.items() if any(x in k for x in ['self_bleu', 'distinct'])}
        other = {k: v for k, v in metrics.items() if k not in quality and k not in diversity}
        
        for title, metric_dict in [("QUALITY", quality), ("DIVERSITY", diversity), ("OTHER", other)]:
            if metric_dict:
                print(f"\n{title} METRICS:")
                table_data = [[k.replace('eval_', '').replace('_', '-').upper(), f"{v:.4f}"] for k, v in metric_dict.items()]
                print(tabulate(table_data, headers=['Metric', 'Score'], tablefmt='grid'))
        print("="*80 + "\n")

    def display_training_summary(self):
        if not self.evaluation_history: return
        print("\n" + "="*80)
        print("TRAINING SUMMARY")
        print("="*80)

        eval_logs = [log for log in self.evaluation_history if 'eval_bleu_4' in log]
        if not eval_logs: return
        
        best_bleu4 = max(eval_logs, key=lambda x: x.get('eval_bleu_4', 0))
        best_rouge = max(eval_logs, key=lambda x: x.get('eval_rouge_l', 0))
        best_meteor = max(eval_logs, key=lambda x: x.get('eval_meteor', 0))
        
        summary_data = [
            ['Best BLEU-4', f"{best_bleu4.get('eval_bleu_4', 0):.4f}", f"Epoch {best_bleu4.get('epoch', 'N/A')}"],
            ['Best ROUGE-L', f"{best_rouge.get('eval_rouge_l', 0):.4f}", f"Epoch {best_rouge.get('epoch', 'N/A')}"],
            ['Best METEOR', f"{best_meteor.get('eval_meteor', 0):.4f}", f"Epoch {best_meteor.get('epoch', 'N/A')}"],
        ]
        print(tabulate(summary_data, headers=['Metric', 'Best Score', 'Achieved At'], tablefmt='grid'))
        print("="*80 + "\n")

class CustomLoggingCallback(TrainerCallback):
    def __init__(self, metrics_logger=None):
        self.metrics_logger = metrics_logger if metrics_logger is not None else MetricsLogger()

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        if state.is_world_process_zero and logs:
            eval_metrics = {k: v for k, v in logs.items() if k.startswith("eval_") and isinstance(v, (int, float))}
            if eval_metrics:
                self.metrics_logger.log_evaluation(
                    metrics=eval_metrics,
                    epoch=int(logs.get('epoch', state.epoch or 0)),
                    step=state.global_step
                )
            elif any(key in logs for key in ['train_loss', 'loss']) and 'learning_rate' in logs:
                loss = logs.get('train_loss', logs.get('loss', 0))
                self.metrics_logger.log_epoch_progress(
                    epoch=logs.get('epoch', state.epoch or 0),
                    loss=loss,
                    learning_rate=logs['learning_rate']
                )

    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.is_world_process_zero:
            self.metrics_logger.display_training_summary()

# --- MODIFICATION 2, PART 1: ADD NEW CALLBACK CLASS ---
class AverageTrainLossLogger(TrainerCallback):
    def __init__(self):
        self.epoch_train_losses = []

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Reset the list of losses at the start of each epoch."""
        self.epoch_train_losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """On each log step, if it's a training log, append the loss."""
        if 'loss' in logs and 'learning_rate' in logs:
            self.epoch_train_losses.append(logs['loss'])

    def on_epoch_end(self, args, state, control, **kwargs):
        """At the end of the epoch, calculate and log the average."""
        if self.epoch_train_losses:
            avg_epoch_loss = np.mean(self.epoch_train_losses)
            logger.info(f"===== Average Training Loss for Epoch {int(state.epoch)}: {avg_epoch_loss:.4f} =====")

# --- NEW: CALLBACK FOR OPTUNA PRUNING ---
class OptunaPruningCallback(TrainerCallback):
    """
    A TrainerCallback to report evaluation metrics to Optuna for pruning.
    """
    def __init__(self, trial: optuna.Trial):
        self.trial = trial

    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics: dict, **kwargs):
        # We report the metric we're optimizing for (bleu_4)
        if "eval_bleu_4" in metrics:
            metric_value = metrics["eval_bleu_4"]
            self.trial.report(metric_value, state.global_step)
            
            # Check if the trial should be pruned
            if self.trial.should_prune():
                logger.warning(f"Trial {self.trial.number} pruned at step {state.global_step} with BLEU-4: {metric_value}.")
                raise optuna.TrialPruned()
        elif "eval_loss" in metrics:
            # Fallback to loss if bleu_4 isn't available for some reason
            self.trial.report(metrics["eval_loss"], state.global_step)
            if self.trial.should_prune():
                logger.warning(f"Trial {self.trial.number} pruned at step {state.global_step} with Loss: {metrics['eval_loss']}.")
                raise optuna.TrialPruned()

def log_gpu_memory(prefix=""):
    """Log current GPU memory usage"""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        logger.info(f"{prefix} GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
    else:
        logger.info(f"{prefix} GPU not available")

def log_model_state(model, prefix=""):
    """Log current model state"""
    logger.info(f"{prefix} Model - Training mode: {model.training}, Device: {next(model.parameters()).device}")

# --- MODIFIED: CustomTrainer now accepts an `is_tuning_trial` flag ---
class CustomTrainer(Trainer):
    def __init__(self, evaluator, metrics_logger=None, is_tuning_trial=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.evaluator = evaluator
        self.decoder = DiverseDecoder(self.model, self.processing_class)
        self.metrics_logger = metrics_logger or MetricsLogger()
        self.is_tuning_trial = is_tuning_trial  # <-- Store the flag

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        # --- MODIFIED: Check if we are in a tuning trial ---
        if self.is_tuning_trial:
            logger.info(f"=== TUNING TRIAL EVALUATION (BLEU-4 Only) ===")
        else:
            logger.info("=== BEFORE EVALUATION ===")
            log_gpu_memory("BEFORE EVAL")
            log_model_state(self.model, "BEFORE EVAL")
        
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        self.model.eval()
        
        predictions, references = [], []
        total_loss = 0.0
        num_batches = 0
        
        epoch = self.state.epoch if self.state.epoch is not None else 0
        if not self.is_tuning_trial:
             logger.info("Starting evaluation on full validation set...")
        
        eval_desc = f"Evaluating Epoch {int(epoch)}"
        if self.is_tuning_trial:
            eval_desc = f"Tuning Trial {self.state.trial_number} Eval" if hasattr(self.state, 'trial_number') else "Tuning Trial Eval"
            
        eval_progress = tqdm(eval_dataloader, desc=eval_desc, unit="batch", disable=False) # Disable tqdm in tuning

        try:
            for batch_idx, batch in enumerate(eval_progress):
                batch_device = {k: v.to(self.args.device) for k, v in batch.items()}
                
                with torch.no_grad():
                    outputs = self.model(**batch_device)
                    total_loss += outputs.loss.item()
                    num_batches += 1
                    
                    inputs = {k: v for k, v in batch_device.items() if k != "labels"}
                    generated_ids = self.decoder.diverse_generate(inputs["input_ids"], inputs["attention_mask"])
                    batch_predictions = self.processing_class.batch_decode(generated_ids, skip_special_tokens=True)
                    predictions.extend(batch_predictions)
                    
                    labels = batch["labels"].cpu().numpy()
                    labels = np.where(labels != -100, labels, self.processing_class.pad_token_id)
                    batch_references = self.processing_class.batch_decode(labels, skip_special_tokens=True)
                    references.extend(batch_references)
                    
                    if (batch_idx + 1) % 20 == 0:
                        torch.cuda.empty_cache()
                        
        except Exception as e:
            logger.error(f"Error during evaluation: {e}")
            return {f"{metric_key_prefix}_loss": float('inf')}

        # --- MODIFIED: Conditional metric computation ---
        metrics = {}
        if self.is_tuning_trial:
            # For tuning, ONLY compute loss and BLEU-4
            if not self.is_tuning_trial: logger.info(f"Tuning trial: Computing only BLEU-4 for {len(predictions)} predictions...")
            bleu_scores = self.evaluator.compute_bleu(references, predictions)
            metrics = {
                f"{metric_key_prefix}_bleu_4": bleu_scores.get('bleu_4', 0.0)
            }
        else:
            # Full evaluation for the final run
            logger.info(f"Computing metrics for {len(predictions)} predictions...")
            metrics = self.evaluator.compute_all_metrics(references, predictions)
            metrics = {f"{metric_key_prefix}_{k}": v for k, v in metrics.items()}
        # --- END OF MODIFICATION ---
        
        avg_loss = total_loss / max(num_batches, 1)
        metrics[f"{metric_key_prefix}_loss"] = avg_loss
        
        self.log(metrics)
        
        del predictions, references, batch_device
        torch.cuda.empty_cache()
        gc.collect()
        
        self.model.train()
        
        if not self.is_tuning_trial:
            logger.info("=== AFTER EVALUATION ===")
            log_gpu_memory("AFTER EVAL")
            log_model_state(self.model, "AFTER EVAL")
            logger.info(f"Evaluation completed successfully. Average loss: {avg_loss:.4f}")
            
        return metrics

def generate_sample_predictions(model, tokenizer, eval_contexts, eval_questions, num_samples=20):
    logger.info("Generating sample predictions...")
    indices = random.sample(range(len(eval_contexts)), min(num_samples, len(eval_contexts)))
    model.eval()
    decoder = DiverseDecoder(model, tokenizer)
    
    print("\n" + "="*100)
    print("SAMPLE PREDICTIONS - 20 Context-Question Pairs")
    print("="*100)
    
    for i, idx in enumerate(indices, 1):
        context, actual_question = eval_contexts[idx], eval_questions[idx]
        input_text = f"Generate a question from context: {context}"
        input_encoding = tokenizer(input_text, max_length=512, truncation=True, return_tensors="pt").to(model.device)
        
        with torch.no_grad():
            generated_ids = decoder.diverse_generate(input_encoding["input_ids"], input_encoding["attention_mask"])
            predicted_question = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        
        print(f"\n--- SAMPLE {i:2d} ---")
        print(f"CONTEXT: {context[:200]}{'...' if len(context) > 200 else ''}")
        print(f"ACTUAL:   {actual_question}")
        print(f"PREDICTED: {predicted_question}")
        print("-" * 80)
    print("="*100 + "\n")

def setup_training_args(output_dir, num_train_epochs=5, train_dataset_size=0):
    effective_batch_size = 8 * 2
    steps_per_epoch = max(1, train_dataset_size // effective_batch_size)
    
    logger.info(f"Training Configuration: Dataset size: {train_dataset_size:,}, Effective batch size: {effective_batch_size}, Steps per epoch: {steps_per_epoch}")
    
    return TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=2,
        torch_compile=True,
        learning_rate=5e-5,  # This will be overridden by Optuna
        weight_decay=0.01, # This will be overridden by Optuna
        warmup_steps=min(500, steps_per_epoch), # This will be overridden by Optuna
        logging_steps=max(10, steps_per_epoch // 10),
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_bleu_4",
        greater_is_better=True,
        bf16=torch.cuda.is_bf16_supported(),
        dataloader_pin_memory=True,
        dataloader_num_workers=0,
        remove_unused_columns=False,
        gradient_checkpointing=True,
        lr_scheduler_type="cosine", # This will be overridden by Optuna
        save_total_limit=3, # Will use 1 for tuning
        report_to="none",
        seed=42,
        data_seed=42,
        group_by_length=True,
    )

# --- NEW: Globals to hold data for Optuna trials ---
# This avoids reloading data for every trial
g_tokenizer = None
g_model_name = "google/flan-t5-base"
g_train_dataset = None
g_eval_dataset = None
g_data_collator = None
g_evaluator = None
g_eval_contexts = None
g_eval_questions = None

# --- NEW: Optuna Objective Function ---
def objective(trial: optuna.Trial):
    global g_tokenizer, g_model_name, g_train_dataset, g_eval_dataset, g_data_collator, g_evaluator
    
    logger.info(f"--- Starting Optuna Trial {trial.number} ---")
    
    # --- 1. Define Search Space ---
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-4, log=True)
    weight_decay = trial.suggest_float("weight_decay", 0.0, 0.1)
    lr_scheduler_type = trial.suggest_categorical("lr_scheduler_type", ["linear", "cosine", "constant"])
    
    # Calculate steps_per_epoch for warmup suggestion
    effective_batch_size = 8 * 2
    steps_per_epoch = max(1, len(g_train_dataset) // effective_batch_size)
    warmup_steps = trial.suggest_int("warmup_steps", 100, steps_per_epoch) 
    
    # --- USE SMALL EVAL SUBSET FOR TUNING (e.g., 10% or 5000 samples max) ---
    eval_subset_size = min(5000, int(len(g_eval_dataset) * 0.3))  # Use 10% of eval data, max 5000
    eval_indices = random.sample(range(len(g_eval_dataset)), eval_subset_size)
    eval_subset = torch.utils.data.Subset(g_eval_dataset, eval_indices)
    
    logger.info(f"Trial {trial.number}: Using {len(g_train_dataset)} train samples and {eval_subset_size} eval samples (subset for speed)")
    
    # --- 2. Configure Training ---
    model = T5ForConditionalGeneration.from_pretrained(g_model_name)
    if g_tokenizer.pad_token is not None and g_tokenizer.pad_token_id > g_tokenizer.vocab_size:
         model.resize_token_embeddings(len(g_tokenizer))
    
    output_dir = f"./optuna-trials/trial_{trial.number}"
    
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=8,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=2,
        torch_compile=True,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
        lr_scheduler_type=lr_scheduler_type,
        logging_steps=max(10, steps_per_epoch // 10),
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_bleu_4",
        greater_is_better=True,
        bf16=torch.cuda.is_bf16_supported(),
        dataloader_pin_memory=True,
        dataloader_num_workers=0,
        remove_unused_columns=False,
        gradient_checkpointing=True,
        save_total_limit=1,
        report_to="none",
        seed=42,
        data_seed=42,
        group_by_length=True,
        logging_dir=f"{output_dir}/logs",
    )
    
    pruning_callback = OptunaPruningCallback(trial)
    trial_data_collator = DataCollatorForSeq2Seq(tokenizer=g_tokenizer, model=model, padding=True, max_length=512, label_pad_token_id=-100)

    trainer = CustomTrainer(
        evaluator=g_evaluator,
        metrics_logger=None,
        is_tuning_trial=True,
        model=model,
        args=training_args,
        train_dataset=g_train_dataset,  # Full training data
        eval_dataset=eval_subset,        # <--- SUBSET for fast evaluation
        tokenizer=g_tokenizer,
        data_collator=trial_data_collator,
        callbacks=[
            pruning_callback,
            EarlyStoppingCallback(early_stopping_patience=4)
        ]
    )
    
    trainer.state.trial_number = trial.number
    
    # --- 3. Train ---
    try:
        trainer.train()
    except optuna.TrialPruned:
        logger.info(f"Trial {trial.number} was pruned.")
        del model, trainer, trial_data_collator
        torch.cuda.empty_cache()
        gc.collect()
        return 0.0
    except Exception as e:
        logger.error(f"Error in trial {trial.number}: {e}")
        del model, trainer, trial_data_collator
        torch.cuda.empty_cache()
        gc.collect()
        return 0.0

    # --- 4. Report & Return ---
    best_metric = trainer.state.best_metric
    logger.info(f"--- Finished Optuna Trial {trial.number} | Best BLEU-4: {best_metric} ---")
    
    del model, trainer, trial_data_collator
    torch.cuda.empty_cache()
    gc.collect()
    
    return best_metric if best_metric is not None else 0.0


# --- MODIFIED: Main function now orchestrates setup, tuning, and final training ---
def main():
    global g_tokenizer, g_model_name, g_train_dataset, g_eval_dataset, g_data_collator, g_evaluator, g_eval_contexts, g_eval_questions
    
    set_seed(42)
    logger.info("Starting T5 Question Generation Training")
    
    data_dir = "E:/A_CSE499/data" # <-- Set your data directory
    output_dir = "./t5-flan-question-generation-tuned" # For the *final* model

    # --- 1. SETUP (Done ONCE) ---
    logger.info(f"Loading model and tokenizer: {g_model_name}...")
    g_tokenizer = T5Tokenizer.from_pretrained(g_model_name)
    
    # We load a dummy model here just to resize embeddings if needed
    # and to create the *global* data collator.
    # The *actual* model for training will be loaded in the objective/final run.
    dummy_model = T5ForConditionalGeneration.from_pretrained(g_model_name)
    if g_tokenizer.pad_token is None:
        g_tokenizer.add_special_tokens({"pad_token": "<pad>"})
        dummy_model.resize_token_embeddings(len(g_tokenizer))
    
    logger.info("Loading and processing datasets...")
    data_processor = DataProcessor(data_dir, g_tokenizer)
    train_contexts, train_questions, g_eval_contexts, g_eval_questions = data_processor.load_all_datasets()
    
    if not train_contexts:
        logger.error("No training data loaded! Please check data directory.")
        return
        
    g_train_dataset = QuestionGenerationDataset(train_contexts, train_questions, g_tokenizer)
    g_eval_dataset = QuestionGenerationDataset(g_eval_contexts, g_eval_questions, g_tokenizer)
    
    # This collator is used by the objective function, but needs a model
    g_data_collator = DataCollatorForSeq2Seq(tokenizer=g_tokenizer, model=dummy_model, padding=True, max_length=512, label_pad_token_id=-100)
    del dummy_model # We don't need this anymore
    
    g_evaluator = AdvancedEvaluationMetrics(g_tokenizer)
    
    # --- 2. OPTUNA TUNING ---
    logger.info("=== STARTING HYPERPARAMETER TUNING ===")
    
    # Use TPE Sampler (Bayesian)
    sampler = optuna.samplers.TPESampler(seed=42)
    # Prune trials that are performing worse than the median
    pruner = optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=3) # Prune after 1 epoch (n_warmup_steps=1)
    
    # Create study with checkpointing
    study = optuna.create_study(
        study_name="t5-qg-tuning",
        direction="maximize", # We want to maximize BLEU-4
        sampler=sampler,
        pruner=pruner,
        storage="sqlite:///t5_qg_tuning.db", # This file enables resuming
        load_if_exists=True # Resume from checkpoint
    )
    
    n_trials = 10
    n_completed_trials = len([t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE])
    
    if n_completed_trials >= n_trials:
        logger.info(f"Study already has {n_completed_trials} completed trials. Skipping optimization.")
    else:
        n_remaining_trials = n_trials - n_completed_trials
        logger.info(f"Resuming study. {n_completed_trials} trials complete, running {n_remaining_trials} more.")
        try:
            study.optimize(objective, n_trials=n_remaining_trials) # e.g., 12h timeout
        except KeyboardInterrupt:
            logger.warning("Optuna optimization interrupted.")
    
    logger.info("=== TUNING COMPLETE ===")
    best_trial = study.best_trial
    logger.info(f"Best trial number: {best_trial.number}")
    logger.info(f"Best BLEU-4: {best_trial.value:.4f}")
    logger.info(f"Best hyperparameters: {json.dumps(best_trial.params, indent=2)}")

    # --- 3. FINAL TRAINING ---
    logger.info("=== STARTING FINAL TRAINING WITH BEST HYPERPARAMETERS ===")
    
    best_params = best_trial.params
    
    # Load the final model to be trained
    model = T5ForConditionalGeneration.from_pretrained(g_model_name)
    if g_tokenizer.pad_token is not None and g_tokenizer.pad_token_id > g_tokenizer.vocab_size:
         model.resize_token_embeddings(len(g_tokenizer))
         
    shared_metrics_logger = MetricsLogger()
    custom_logging_callback = CustomLoggingCallback(shared_metrics_logger)
    avg_loss_callback = AverageTrainLossLogger()
    
    # Use the original setup_training_args
    training_args = setup_training_args(
        output_dir,
        num_train_epochs=8, # From original config
        train_dataset_size=len(g_train_dataset)
    )
    
    # Override args with best params from the study
    training_args.learning_rate = best_params["learning_rate"]
    training_args.weight_decay = best_params["weight_decay"]
    training_args.warmup_steps = best_params["warmup_steps"]
    training_args.lr_scheduler_type = best_params["lr_scheduler_type"]
    
    # Re-create the data collator with the *final* model instance
    final_data_collator = DataCollatorForSeq2Seq(tokenizer=g_tokenizer, model=model, padding=True, max_length=512, label_pad_token_id=-100)

    trainer = CustomTrainer(
        evaluator=g_evaluator,
        metrics_logger=shared_metrics_logger,
        is_tuning_trial=False, # <--- IMPORTANT: Run full evaluation
        model=model,
        args=training_args,
        train_dataset=g_train_dataset,
        eval_dataset=g_eval_dataset,
        tokenizer=g_tokenizer,
        data_collator=final_data_collator,
        callbacks=[
            custom_logging_callback,
            avg_loss_callback,
            EarlyStoppingCallback(early_stopping_patience=4)
        ]
    )
    
    logger.info("=== FINAL TRAINING START ===")
    log_gpu_memory("FINAL TRAINING START")
    log_model_state(model, "FINAL TRAINING START")
    
    logger.info("Starting final training...")
    try:
        trainer.train()
    except KeyboardInterrupt:
        logger.warning("Final training interrupted. Saving current model...")
        trainer.save_model(os.path.join(output_dir, "interrupted"))
        return
    
    logger.info("Saving final model...")
    trainer.save_model(os.path.join(output_dir, "final"))
    g_tokenizer.save_pretrained(os.path.join(output_dir, "final"))
    
    logger.info("Performing final comprehensive evaluation...")
    final_metrics = trainer.evaluate()
    
    generate_sample_predictions(model, g_tokenizer, g_eval_contexts, g_eval_questions)
    
    print("\n" + "*"*40)
    print("TRAINING COMPLETED SUCCESSFULLY!")
    print("*"*40)
    
    final_results = [[k.replace('eval_', '').upper(), f"{v:.4f}"] for k, v in sorted(final_metrics.items()) if isinstance(v, (int, float))]
    if final_results:
        print("\nFINAL EVALUATION RESULTS (using best params):")
        print(tabulate(final_results, headers=['Metric', 'Final Score'], tablefmt='fancy_grid'))
    
    logger.info("Training pipeline completed successfully!")

if __name__ == "__main__":
    main()

16:28:02 | INFO | Starting T5 Question Generation Training
16:28:02 | INFO | Loading model and tokenizer: google/flan-t5-base...
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
16:28:05 | INFO | Loading and processing datasets...
16:28:05 | INFO | Loading FairytaleQA from: E:/A_CSE499/data\FairytaleQA_train.csv
Token indices sequence length is longer than the specified maximum sequence length for this model (1224 > 512). Running this sequence through the model will result in indexing errors
16:28:11 | INFO | Loaded 7573 samples from FairytaleQA
16:28:11 | INFO | Loading Fairytale

Epoch,Training Loss,Validation Loss


16:36:44 | INFO | Epoch 0.10 | Loss: 1.9151 | LR: 1.17e-05
16:41:34 | INFO | Epoch 0.20 | Loss: 1.8208 | LR: 2.35e-05
16:46:20 | INFO | Epoch 0.30 | Loss: 1.8171 | LR: 3.53e-05
16:51:06 | INFO | Epoch 0.40 | Loss: 1.7906 | LR: 3.96e-05
16:55:51 | INFO | Epoch 0.50 | Loss: 1.7709 | LR: 3.96e-05
17:00:38 | INFO | Epoch 0.60 | Loss: 1.7694 | LR: 3.95e-05
17:05:24 | INFO | Epoch 0.70 | Loss: 1.7214 | LR: 3.94e-05
17:10:10 | INFO | Epoch 0.80 | Loss: 1.7391 | LR: 3.93e-05
17:14:56 | INFO | Epoch 0.90 | Loss: 1.7407 | LR: 3.91e-05
17:19:42 | INFO | Epoch 1.00 | Loss: 1.7228 | LR: 3.89e-05
17:20:20 | INFO | ===== Average Training Loss for Epoch 1: 1.7808 =====
17:20:20 | INFO | === BEFORE EVALUATION ===
17:20:20 | INFO | BEFORE EVAL GPU Memory - Allocated: 2.80GB, Reserved: 6.38GB
17:20:20 | INFO | BEFORE EVAL Model - Training mode: True, Device: cuda:0
17:20:28 | INFO | Starting evaluation on full validation set...
Evaluating Epoch 1:   0%|          | 0/5461 [00:00<?, ?batch/s]Passing a tupl


EVALUATION RESULTS - Epoch 1

QUALITY METRICS:
+------------+---------+
| Metric     |   Score |
| BLEU-1     |  0.1962 |
+------------+---------+
| BLEU-2     |  0.1132 |
+------------+---------+
| BLEU-3     |  0.0783 |
+------------+---------+
| BLEU-4     |  0.0596 |
+------------+---------+
| ROUGE-L    |  0.2578 |
+------------+---------+
| METEOR     |  0.2579 |
+------------+---------+
| BERT-SCORE |  0.8896 |
+------------+---------+
| SELF-BLEU  |  0.9927 |
+------------+---------+

DIVERSITY METRICS:
+------------+---------+
| Metric     |   Score |
| SELF-BLEU  |  0.9927 |
+------------+---------+
| DISTINCT-1 |  0.066  |
+------------+---------+
| DISTINCT-2 |  0.1489 |
+------------+---------+

OTHER METRICS:
+----------+---------+
| Metric   |   Score |
| LOSS     |  1.5683 |
+----------+---------+



19:54:25 | INFO | === AFTER EVALUATION ===
19:54:25 | INFO | AFTER EVAL GPU Memory - Allocated: 2.82GB, Reserved: 3.05GB
19:54:25 | INFO | AFTER EVAL Model - Training mode: True, Device: cuda:0
19:54:25 | INFO | Evaluation completed successfully. Average loss: 1.5683
19:59:29 | INFO | Epoch 1.10 | Loss: 1.6547 | LR: 3.87e-05
20:04:16 | INFO | Epoch 1.20 | Loss: 1.6417 | LR: 3.84e-05
20:09:01 | INFO | Epoch 1.30 | Loss: 1.6573 | LR: 3.81e-05
20:13:47 | INFO | Epoch 1.40 | Loss: 1.6408 | LR: 3.78e-05
20:18:32 | INFO | Epoch 1.50 | Loss: 1.6559 | LR: 3.74e-05
20:23:17 | INFO | Epoch 1.60 | Loss: 1.6440 | LR: 3.70e-05
20:28:04 | INFO | Epoch 1.70 | Loss: 1.6348 | LR: 3.66e-05
20:32:50 | INFO | Epoch 1.80 | Loss: 1.6348 | LR: 3.62e-05
20:37:36 | INFO | Epoch 1.90 | Loss: 1.6281 | LR: 3.57e-05
20:42:21 | INFO | Epoch 2.00 | Loss: 1.6379 | LR: 3.52e-05
20:42:28 | INFO | ===== Average Training Loss for Epoch 2: 1.6430 =====
20:42:28 | INFO | === BEFORE EVALUATION ===
20:42:28 | INFO | BEFORE E


EVALUATION RESULTS - Epoch 2

QUALITY METRICS:
+------------+---------+
| Metric     |   Score |
| BLEU-1     |  0.1928 |
+------------+---------+
| BLEU-2     |  0.1112 |
+------------+---------+
| BLEU-3     |  0.0769 |
+------------+---------+
| BLEU-4     |  0.0585 |
+------------+---------+
| ROUGE-L    |  0.2567 |
+------------+---------+
| METEOR     |  0.2569 |
+------------+---------+
| BERT-SCORE |  0.8892 |
+------------+---------+
| SELF-BLEU  |  0.9929 |
+------------+---------+

DIVERSITY METRICS:
+------------+---------+
| Metric     |   Score |
| SELF-BLEU  |  0.9929 |
+------------+---------+
| DISTINCT-1 |  0.0672 |
+------------+---------+
| DISTINCT-2 |  0.1536 |
+------------+---------+

OTHER METRICS:
+----------+---------+
| Metric   |   Score |
| LOSS     |  1.5383 |
+----------+---------+



23:17:15 | INFO | === AFTER EVALUATION ===
23:17:15 | INFO | AFTER EVAL GPU Memory - Allocated: 2.82GB, Reserved: 3.04GB
23:17:15 | INFO | AFTER EVAL Model - Training mode: True, Device: cuda:0
23:17:15 | INFO | Evaluation completed successfully. Average loss: 1.5383
23:22:15 | INFO | Epoch 2.10 | Loss: 1.5758 | LR: 3.47e-05
23:27:01 | INFO | Epoch 2.20 | Loss: 1.5641 | LR: 3.41e-05
23:31:48 | INFO | Epoch 2.30 | Loss: 1.5517 | LR: 3.36e-05
23:36:35 | INFO | Epoch 2.40 | Loss: 1.5555 | LR: 3.30e-05
23:41:21 | INFO | Epoch 2.50 | Loss: 1.5503 | LR: 3.23e-05
23:46:07 | INFO | Epoch 2.60 | Loss: 1.5529 | LR: 3.17e-05
23:50:54 | INFO | Epoch 2.70 | Loss: 1.5498 | LR: 3.10e-05
23:55:41 | INFO | Epoch 2.80 | Loss: 1.5511 | LR: 3.04e-05
00:00:27 | INFO | Epoch 2.90 | Loss: 1.5492 | LR: 2.97e-05
00:05:14 | INFO | Epoch 3.00 | Loss: 1.5538 | LR: 2.90e-05
00:05:24 | INFO | ===== Average Training Loss for Epoch 3: 1.5554 =====
00:05:24 | INFO | === BEFORE EVALUATION ===
00:05:24 | INFO | BEFORE E


EVALUATION RESULTS - Epoch 3

QUALITY METRICS:
+------------+---------+
| Metric     |   Score |
| BLEU-1     |  0.1941 |
+------------+---------+
| BLEU-2     |  0.1125 |
+------------+---------+
| BLEU-3     |  0.0782 |
+------------+---------+
| BLEU-4     |  0.0594 |
+------------+---------+
| ROUGE-L    |  0.2575 |
+------------+---------+
| METEOR     |  0.2581 |
+------------+---------+
| BERT-SCORE |  0.889  |
+------------+---------+
| SELF-BLEU  |  0.9933 |
+------------+---------+

DIVERSITY METRICS:
+------------+---------+
| Metric     |   Score |
| SELF-BLEU  |  0.9933 |
+------------+---------+
| DISTINCT-1 |  0.0664 |
+------------+---------+
| DISTINCT-2 |  0.1531 |
+------------+---------+

OTHER METRICS:
+----------+---------+
| Metric   |   Score |
| LOSS     |    1.53 |
+----------+---------+



02:40:40 | INFO | === AFTER EVALUATION ===
02:40:40 | INFO | AFTER EVAL GPU Memory - Allocated: 2.82GB, Reserved: 3.03GB
02:40:40 | INFO | AFTER EVAL Model - Training mode: True, Device: cuda:0
02:40:40 | INFO | Evaluation completed successfully. Average loss: 1.5300
02:45:37 | INFO | Epoch 3.10 | Loss: 1.4986 | LR: 2.82e-05
02:50:22 | INFO | Epoch 3.20 | Loss: 1.4919 | LR: 2.75e-05
02:55:08 | INFO | Epoch 3.30 | Loss: 1.4794 | LR: 2.67e-05
02:59:54 | INFO | Epoch 3.40 | Loss: 1.4839 | LR: 2.60e-05
03:04:41 | INFO | Epoch 3.50 | Loss: 1.4989 | LR: 2.52e-05
03:09:27 | INFO | Epoch 3.60 | Loss: 1.4929 | LR: 2.44e-05
03:14:13 | INFO | Epoch 3.70 | Loss: 1.4992 | LR: 2.36e-05
03:18:59 | INFO | Epoch 3.80 | Loss: 1.4908 | LR: 2.28e-05
03:23:46 | INFO | Epoch 3.89 | Loss: 1.4754 | LR: 2.20e-05
03:28:33 | INFO | Epoch 3.99 | Loss: 1.4719 | LR: 2.12e-05
03:28:47 | INFO | ===== Average Training Loss for Epoch 4: 1.4883 =====
03:28:47 | INFO | === BEFORE EVALUATION ===
03:28:47 | INFO | BEFORE E


EVALUATION RESULTS - Epoch 4

QUALITY METRICS:
+------------+---------+
| Metric     |   Score |
| BLEU-1     |  0.1951 |
+------------+---------+
| BLEU-2     |  0.1133 |
+------------+---------+
| BLEU-3     |  0.0791 |
+------------+---------+
| BLEU-4     |  0.0604 |
+------------+---------+
| ROUGE-L    |  0.2581 |
+------------+---------+
| METEOR     |  0.2576 |
+------------+---------+
| BERT-SCORE |  0.8893 |
+------------+---------+
| SELF-BLEU  |  0.9926 |
+------------+---------+

DIVERSITY METRICS:
+------------+---------+
| Metric     |   Score |
| SELF-BLEU  |  0.9926 |
+------------+---------+
| DISTINCT-1 |  0.0665 |
+------------+---------+
| DISTINCT-2 |  0.1537 |
+------------+---------+

OTHER METRICS:
+----------+---------+
| Metric   |   Score |
| LOSS     |  1.5295 |
+----------+---------+



06:03:44 | INFO | === AFTER EVALUATION ===
06:03:44 | INFO | AFTER EVAL GPU Memory - Allocated: 2.82GB, Reserved: 3.04GB
06:03:44 | INFO | AFTER EVAL Model - Training mode: True, Device: cuda:0
06:03:44 | INFO | Evaluation completed successfully. Average loss: 1.5295
06:08:36 | INFO | Epoch 4.09 | Loss: 1.4464 | LR: 2.04e-05
06:13:22 | INFO | Epoch 4.19 | Loss: 1.4328 | LR: 1.96e-05
06:18:07 | INFO | Epoch 4.29 | Loss: 1.4370 | LR: 1.88e-05
06:22:52 | INFO | Epoch 4.39 | Loss: 1.4418 | LR: 1.80e-05
06:27:39 | INFO | Epoch 4.49 | Loss: 1.4365 | LR: 1.72e-05
06:32:26 | INFO | Epoch 4.59 | Loss: 1.4422 | LR: 1.64e-05
06:37:12 | INFO | Epoch 4.69 | Loss: 1.4372 | LR: 1.56e-05
06:41:58 | INFO | Epoch 4.79 | Loss: 1.4343 | LR: 1.48e-05
06:46:44 | INFO | Epoch 4.89 | Loss: 1.4321 | LR: 1.40e-05
06:51:31 | INFO | Epoch 4.99 | Loss: 1.4404 | LR: 1.32e-05
06:51:49 | INFO | ===== Average Training Loss for Epoch 5: 1.4381 =====
06:51:49 | INFO | === BEFORE EVALUATION ===
06:51:49 | INFO | BEFORE E


EVALUATION RESULTS - Epoch 5

QUALITY METRICS:
+------------+---------+
| Metric     |   Score |
| BLEU-1     |  0.197  |
+------------+---------+
| BLEU-2     |  0.1148 |
+------------+---------+
| BLEU-3     |  0.0801 |
+------------+---------+
| BLEU-4     |  0.061  |
+------------+---------+
| ROUGE-L    |  0.2605 |
+------------+---------+
| METEOR     |  0.2604 |
+------------+---------+
| BERT-SCORE |  0.8898 |
+------------+---------+
| SELF-BLEU  |  0.9918 |
+------------+---------+

DIVERSITY METRICS:
+------------+---------+
| Metric     |   Score |
| SELF-BLEU  |  0.9918 |
+------------+---------+
| DISTINCT-1 |  0.0669 |
+------------+---------+
| DISTINCT-2 |  0.1536 |
+------------+---------+

OTHER METRICS:
+----------+---------+
| Metric   |   Score |
| LOSS     |  1.5261 |
+----------+---------+



09:25:15 | INFO | === AFTER EVALUATION ===
09:25:15 | INFO | AFTER EVAL GPU Memory - Allocated: 2.81GB, Reserved: 3.04GB
09:25:15 | INFO | AFTER EVAL Model - Training mode: True, Device: cuda:0
09:25:15 | INFO | Evaluation completed successfully. Average loss: 1.5261
09:30:05 | INFO | Epoch 5.09 | Loss: 1.4006 | LR: 1.25e-05
09:34:50 | INFO | Epoch 5.19 | Loss: 1.4038 | LR: 1.17e-05
09:39:36 | INFO | Epoch 5.29 | Loss: 1.3998 | LR: 1.10e-05
09:44:23 | INFO | Epoch 5.39 | Loss: 1.3964 | LR: 1.03e-05
09:49:09 | INFO | Epoch 5.49 | Loss: 1.4050 | LR: 9.57e-06
09:53:56 | INFO | Epoch 5.59 | Loss: 1.4030 | LR: 8.89e-06
09:58:41 | INFO | Epoch 5.69 | Loss: 1.4150 | LR: 8.22e-06
10:03:28 | INFO | Epoch 5.79 | Loss: 1.4080 | LR: 7.57e-06
10:08:14 | INFO | Epoch 5.89 | Loss: 1.3948 | LR: 6.95e-06
10:13:00 | INFO | Epoch 5.99 | Loss: 1.4047 | LR: 6.34e-06
10:13:23 | INFO | ===== Average Training Loss for Epoch 6: 1.4031 =====
10:13:23 | INFO | === BEFORE EVALUATION ===
10:13:23 | INFO | BEFORE E


EVALUATION RESULTS - Epoch 6

QUALITY METRICS:
+------------+---------+
| Metric     |   Score |
| BLEU-1     |  0.1996 |
+------------+---------+
| BLEU-2     |  0.1165 |
+------------+---------+
| BLEU-3     |  0.0813 |
+------------+---------+
| BLEU-4     |  0.0621 |
+------------+---------+
| ROUGE-L    |  0.2627 |
+------------+---------+
| METEOR     |  0.2614 |
+------------+---------+
| BERT-SCORE |  0.8903 |
+------------+---------+
| SELF-BLEU  |  0.9921 |
+------------+---------+

DIVERSITY METRICS:
+------------+---------+
| Metric     |   Score |
| SELF-BLEU  |  0.9921 |
+------------+---------+
| DISTINCT-1 |  0.0662 |
+------------+---------+
| DISTINCT-2 |  0.1515 |
+------------+---------+

OTHER METRICS:
+----------+---------+
| Metric   |   Score |
| LOSS     |  1.5264 |
+----------+---------+



12:46:36 | INFO | === AFTER EVALUATION ===
12:46:36 | INFO | AFTER EVAL GPU Memory - Allocated: 2.81GB, Reserved: 3.02GB
12:46:36 | INFO | AFTER EVAL Model - Training mode: True, Device: cuda:0
12:46:36 | INFO | Evaluation completed successfully. Average loss: 1.5264
12:51:22 | INFO | Epoch 6.09 | Loss: 1.3749 | LR: 5.76e-06
12:56:08 | INFO | Epoch 6.19 | Loss: 1.3699 | LR: 5.20e-06
13:00:53 | INFO | Epoch 6.29 | Loss: 1.3810 | LR: 4.66e-06
13:05:40 | INFO | Epoch 6.39 | Loss: 1.3793 | LR: 4.15e-06
13:10:26 | INFO | Epoch 6.49 | Loss: 1.3982 | LR: 3.67e-06
13:15:13 | INFO | Epoch 6.59 | Loss: 1.3821 | LR: 3.21e-06
13:20:01 | INFO | Epoch 6.69 | Loss: 1.3757 | LR: 2.78e-06
13:24:47 | INFO | Epoch 6.79 | Loss: 1.3836 | LR: 2.38e-06
13:29:33 | INFO | Epoch 6.89 | Loss: 1.3799 | LR: 2.01e-06
13:34:19 | INFO | Epoch 6.99 | Loss: 1.3925 | LR: 1.67e-06
13:34:45 | INFO | ===== Average Training Loss for Epoch 7: 1.3817 =====
13:34:45 | INFO | === BEFORE EVALUATION ===
13:34:45 | INFO | BEFORE E


EVALUATION RESULTS - Epoch 7

QUALITY METRICS:
+------------+---------+
| Metric     |   Score |
| BLEU-1     |  0.1998 |
+------------+---------+
| BLEU-2     |  0.117  |
+------------+---------+
| BLEU-3     |  0.0819 |
+------------+---------+
| BLEU-4     |  0.0624 |
+------------+---------+
| ROUGE-L    |  0.263  |
+------------+---------+
| METEOR     |  0.2626 |
+------------+---------+
| BERT-SCORE |  0.8902 |
+------------+---------+
| SELF-BLEU  |  0.9922 |
+------------+---------+

DIVERSITY METRICS:
+------------+---------+
| Metric     |   Score |
| SELF-BLEU  |  0.9922 |
+------------+---------+
| DISTINCT-1 |  0.0669 |
+------------+---------+
| DISTINCT-2 |  0.1534 |
+------------+---------+

OTHER METRICS:
+----------+---------+
| Metric   |   Score |
| LOSS     |  1.5286 |
+----------+---------+



16:08:33 | INFO | === AFTER EVALUATION ===
16:08:33 | INFO | AFTER EVAL GPU Memory - Allocated: 2.81GB, Reserved: 3.03GB
16:08:33 | INFO | AFTER EVAL Model - Training mode: True, Device: cuda:0
16:08:33 | INFO | Evaluation completed successfully. Average loss: 1.5286
16:13:14 | INFO | Epoch 7.09 | Loss: 1.3737 | LR: 1.36e-06
16:18:00 | INFO | Epoch 7.19 | Loss: 1.3689 | LR: 1.08e-06
16:22:46 | INFO | Epoch 7.29 | Loss: 1.3756 | LR: 8.33e-07
16:27:35 | INFO | Epoch 7.39 | Loss: 1.3672 | LR: 6.16e-07
16:32:21 | INFO | Epoch 7.49 | Loss: 1.3725 | LR: 4.31e-07
16:37:07 | INFO | Epoch 7.59 | Loss: 1.3803 | LR: 2.79e-07
16:41:54 | INFO | Epoch 7.69 | Loss: 1.3695 | LR: 1.60e-07
16:46:41 | INFO | Epoch 7.79 | Loss: 1.3717 | LR: 7.36e-08
16:51:26 | INFO | Epoch 7.89 | Loss: 1.3694 | LR: 2.03e-08
16:56:12 | INFO | Epoch 7.99 | Loss: 1.3727 | LR: 1.89e-10
16:56:42 | INFO | ===== Average Training Loss for Epoch 8: 1.3721 =====
16:56:42 | INFO | === BEFORE EVALUATION ===
16:56:42 | INFO | BEFORE E


EVALUATION RESULTS - Epoch 8

QUALITY METRICS:
+------------+---------+
| Metric     |   Score |
| BLEU-1     |  0.1991 |
+------------+---------+
| BLEU-2     |  0.1168 |
+------------+---------+
| BLEU-3     |  0.0817 |
+------------+---------+
| BLEU-4     |  0.0622 |
+------------+---------+
| ROUGE-L    |  0.2625 |
+------------+---------+
| METEOR     |  0.262  |
+------------+---------+
| BERT-SCORE |  0.89   |
+------------+---------+
| SELF-BLEU  |  0.9929 |
+------------+---------+

DIVERSITY METRICS:
+------------+---------+
| Metric     |   Score |
| SELF-BLEU  |  0.9929 |
+------------+---------+
| DISTINCT-1 |  0.0667 |
+------------+---------+
| DISTINCT-2 |  0.1528 |
+------------+---------+

OTHER METRICS:
+----------+---------+
| Metric   |   Score |
| LOSS     |  1.5317 |
+----------+---------+



19:30:42 | INFO | === AFTER EVALUATION ===
19:30:42 | INFO | AFTER EVAL GPU Memory - Allocated: 2.82GB, Reserved: 3.03GB
19:30:42 | INFO | AFTER EVAL Model - Training mode: True, Device: cuda:0
19:30:42 | INFO | Evaluation completed successfully. Average loss: 1.5317
There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].



TRAINING SUMMARY
+--------------+--------------+---------------+
| Metric       |   Best Score | Achieved At   |
| Best BLEU-4  |       0.0624 | Epoch 7       |
+--------------+--------------+---------------+
| Best ROUGE-L |       0.263  | Epoch 7       |
+--------------+--------------+---------------+
| Best METEOR  |       0.2626 | Epoch 7       |
+--------------+--------------+---------------+



19:31:05 | INFO | Saving final model...
19:31:10 | INFO | Performing final comprehensive evaluation...
19:31:10 | INFO | === BEFORE EVALUATION ===
19:31:10 | INFO | BEFORE EVAL GPU Memory - Allocated: 2.80GB, Reserved: 3.03GB
19:31:10 | INFO | BEFORE EVAL Model - Training mode: True, Device: cuda:0
19:31:18 | INFO | Starting evaluation on full validation set...
Evaluating Epoch 8: 100%|██████████| 5461/5461 [1:35:23<00:00,  1.05s/batch]
21:06:41 | INFO | Computing metrics for 10922 predictions...
21:06:41 | INFO | Computing quality and diversity metrics...
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\Adeptus_Mechanicus\AppData\Roaming\nltk_data.
[nltk_data]     ..
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Adeptus_Mechanicus\AppData\Roaming\nltk_data.
[nltk_data]     ..
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     C:


EVALUATION RESULTS - Epoch 8

QUALITY METRICS:
+------------+---------+
| Metric     |   Score |
| BLEU-1     |  0.2005 |
+------------+---------+
| BLEU-2     |  0.1177 |
+------------+---------+
| BLEU-3     |  0.0824 |
+------------+---------+
| BLEU-4     |  0.0629 |
+------------+---------+
| ROUGE-L    |  0.2636 |
+------------+---------+
| METEOR     |  0.2633 |
+------------+---------+
| BERT-SCORE |  0.8903 |
+------------+---------+
| SELF-BLEU  |  0.9925 |
+------------+---------+

DIVERSITY METRICS:
+------------+---------+
| Metric     |   Score |
| SELF-BLEU  |  0.9925 |
+------------+---------+
| DISTINCT-1 |  0.0668 |
+------------+---------+
| DISTINCT-2 |  0.153  |
+------------+---------+

OTHER METRICS:
+----------+---------+
| Metric   |   Score |
| LOSS     |  1.5307 |
+----------+---------+



22:04:38 | INFO | === AFTER EVALUATION ===
22:04:38 | INFO | AFTER EVAL GPU Memory - Allocated: 2.81GB, Reserved: 3.03GB
22:04:38 | INFO | AFTER EVAL Model - Training mode: True, Device: cuda:0
22:04:38 | INFO | Evaluation completed successfully. Average loss: 1.5307
22:04:38 | INFO | Generating sample predictions...



SAMPLE PREDICTIONS - 20 Context-Question Pairs

--- SAMPLE  1 ---
CONTEXT: Gasquet (1908) claimed that the Latin name atra mors (Black Death) for the 14th-century epidemic first appeared in modern times in 1631 in a book on Danish history by J.I. Pontanus: "Vulgo & ab effect...
ACTUAL:   What is the Latin name for Black Death?
PREDICTED: In what year did Gasquet claim the Latin name atra mors first appear in modern times?
--------------------------------------------------------------------------------

--- SAMPLE  2 ---
CONTEXT: West is one of the best-selling artists of all time, having sold more than 32 million albums and 100 million digital downloads worldwide. He has won a total of 21 Grammy Awards, making him one of the ...
ACTUAL:   How many Grammy Awards has Kanye West won?
PREDICTED: How many albums has Kanye sold worldwide?
--------------------------------------------------------------------------------

--- SAMPLE  3 ---
CONTEXT: Galicia was spared the worst of the fighting 

22:04:55 | INFO | Training pipeline completed successfully!



--- SAMPLE 20 ---
CONTEXT: One use of the term "computer security" refers to technology that is used to implement secure operating systems. In the 1980s the United States Department of Defense (DoD) used the "Orange Book" stand...
ACTUAL:   What is an example of a system that meets EAL6?
PREDICTED: What does EAL6 stand for?
--------------------------------------------------------------------------------


****************************************
TRAINING COMPLETED SUCCESSFULLY!
****************************************

FINAL EVALUATION RESULTS (using best params):
╒════════════╤═══════════════╕
│ Metric     │   Final Score │
╞════════════╪═══════════════╡
│ EPOCH      │        8      │
├────────────┼───────────────┤
│ BERT_SCORE │        0.8903 │
├────────────┼───────────────┤
│ BLEU_1     │        0.2005 │
├────────────┼───────────────┤
│ BLEU_2     │        0.1177 │
├────────────┼───────────────┤
│ BLEU_3     │        0.0824 │
├────────────┼───────────────┤
│ BLEU_4     │        0.0