In [1]:
import os
import gc
import json
import random
import re
import shutil
import logging
import sqlite3

import numpy as np
import pandas as pd
import nltk
import torch
import optuna
from datasets import Dataset
from transformers import (
    AutoTokenizer, BartForConditionalGeneration,
    Seq2SeqTrainer, Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq, EarlyStoppingCallback,
    TrainerCallback,
)
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer
from bert_score import score
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner

KeyboardInterrupt: 

In [None]:
# Setup logging: one root configuration for the whole notebook, then a named logger.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
# Seed every RNG the pipeline touches so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Optuna's sampling is seeded where the study is created (TPESampler(seed=...)).
# Deterministic cuDNN kernels: slower, but repeatable results.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Check CUDA
logger.info(f"CUDA available: {torch.cuda.is_available()}")
logger.info(f"CUDA version: {torch.version.cuda}")
logger.info(f"CUDNN version: {torch.backends.cudnn.version()}")

# NLTK resources: punkt/punkt_tab back nltk.word_tokenize and wordnet/omw-1.4 back
# METEOR in compute_metrics. nltk.download reports most failures through its return
# value instead of raising, so surface those as warnings rather than ignoring them.
try:
    for _nltk_resource in ("wordnet", "omw-1.4", "punkt", "punkt_tab"):
        if not nltk.download(_nltk_resource, quiet=True):
            logger.warning(
                f"NLTK download of '{_nltk_resource}' did not succeed; "
                "metrics may fail if it is not already installed"
            )
except Exception as e:
    logger.error(f"NLTK download failed: {e}")
    raise

# Directories. The project root defaults to the original location but can be
# overridden with the CSE499_PROJECT_ROOT environment variable on other machines.
PROJECT_ROOT = os.environ.get("CSE499_PROJECT_ROOT", r"D:\A_CSE499")
DATA_DIR = os.path.join(PROJECT_ROOT, "data")
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "outputLarge_B")
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)


In [None]:
# Load tokenizer and set device.
# Fail fast: training and generation below assume a GPU, so verify CUDA before
# spending time downloading/loading the tokenizer.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. Please check your PyTorch installation and NVIDIA drivers.")
device = torch.device("cuda:0")
# BART-large tokenizer shared by preprocessing, metrics and example generation.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

# Punctuation cleaning function
def fix_punctuation_spacing(text):
    """Normalise whitespace around punctuation in a piece of raw text.

    Non-string inputs (e.g. ``None`` or NaN) are returned unchanged.

    Rules, applied in order:
      * literal ``\\newline`` markers become spaces;
      * whitespace before ``, . ; : ! ?`` is removed, and a space is inserted after
        them when a word character follows -- except between two digits, so
        numbers such as ``3.50`` or ``1,000`` stay intact;
      * padding just inside parentheses, and inside matched pairs of double
        quotes, is removed (spaces outside the quotes are kept);
      * contractions split by tokenisation (``it 's``) are re-joined, while
        possessives such as ``boys' toys`` are left alone;
      * dash runs acting as punctuation (en/em dashes, ``--``, or a hyphen with
        spaces on both sides) become a spaced em dash; hyphens inside words are kept;
      * whitespace runs collapse to one space and the result is stripped.
    """
    if not isinstance(text, str):
        return text
    text = text.replace(r'\newline', ' ')
    text = re.sub(r'\s+([,.;:!?])', r'\1', text)
    # Space after punctuation, but never split digit-punct-digit (decimals, thousands, times).
    text = re.sub(r'(?<!\d)([,.;:!?])(?=\w)', r'\1 ', text)
    text = re.sub(r'(?<=\d)([,.;:!?])(?=[^\W\d])', r'\1 ', text)
    text = re.sub(r'\(\s+', '(', text)
    text = re.sub(r'\s+\)', ')', text)
    # Strip padding only *inside* a matched pair of double quotes; unpaired quotes stay.
    text = re.sub(r'"\s*([^"]*?)\s*"', r'"\1"', text)
    # Re-join split contractions ("it 's" -> "it's", "I ' m" -> "I'm").
    text = re.sub(r"\s*'\s*(?=(?:s|re|ve|ll|d|m|t)\b)", "'", text)
    # Dashes used as punctuation -> spaced em dash; word-internal hyphens are untouched.
    text = re.sub(r'\s*(?:[-–—]*[–—][-–—]*|-{2,})\s*|\s+-+\s+', ' — ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


In [None]:
# Preprocess function
_WHAT_IS_THE = "What is the"
_WHAT_IS_THE_RE = re.compile(r"What is the\b")
_WHAT_IS_THE_REPLACEMENT = "What can you tell about"


def _clean_text_field(value):
    """Return the punctuation-cleaned string for a raw cell, or "" when it is empty/None/NaN."""
    # NaN is truthy, so it needs an explicit check or it would become the text "nan".
    if not value or (isinstance(value, float) and np.isnan(value)):
        return ""
    return fix_punctuation_spacing(str(value))


def preprocess_function(batch, max_input_length=64, max_target_length=64):
    """Tokenise a batch of (context, question) rows for BART question generation.

    Intended for ``Dataset.map(..., batched=True)``: ``batch`` maps column names to
    lists and must contain ``'context'`` and ``'question'``.

    Rows whose cleaned context or question is empty are dropped, so the output can
    have fewer rows than the input batch. ``Dataset.map`` only accepts that when
    every input column is removed, so pass ``remove_columns=<dataset>.column_names``.

    Questions beginning with the words "What is the" get that prefix rewritten to
    "What can you tell about".

    Args:
        batch: batched rows with ``'context'`` and ``'question'`` lists.
        max_input_length: token length the context is truncated/padded to.
        max_target_length: token length the question labels are truncated/padded to.

    Returns:
        Mapping with ``input_ids``, ``attention_mask`` and ``labels`` (lists of ints).
        Label padding is replaced by -100 so the loss ignores it.
    """
    inputs, targets = [], []
    for context, question in zip(batch['context'], batch['question']):
        c_clean = _clean_text_field(context)
        q_clean = _clean_text_field(question)
        # Whole-word match so e.g. "What is theory ..." is not mangled; only the prefix changes.
        if _WHAT_IS_THE_RE.match(q_clean):
            q_clean = _WHAT_IS_THE_REPLACEMENT + q_clean[len(_WHAT_IS_THE):]
        if c_clean and q_clean:
            inputs.append(c_clean)
            targets.append(q_clean)

    # Plain Python lists (no return_tensors) so an all-dropped, empty batch is still valid.
    model_inputs = tokenizer(
        inputs,
        max_length=max_input_length,
        truncation=True,
        padding="max_length",
    )
    labels = tokenizer(
        text_target=targets,
        max_length=max_target_length,
        truncation=True,
        padding="max_length",
    )
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq] for seq in labels["input_ids"]
    ]
    return model_inputs

# Load and validate datasets
def validate_csv(file_path, required_columns):
    """Read ``file_path`` into a DataFrame after checking it exists and has the needed columns.

    Raises:
        FileNotFoundError: when ``file_path`` does not exist.
        ValueError: naming the first entry of ``required_columns`` absent from the header.
    """
    if os.path.exists(file_path):
        frame = pd.read_csv(file_path)
    else:
        message = f"CSV file not found: {file_path}"
        logger.error(message)
        raise FileNotFoundError(message)

    absent = [name for name in required_columns if name not in frame.columns]
    if absent:
        message = f"Column '{absent[0]}' missing in {file_path}"
        logger.error(message)
        raise ValueError(message)
    return frame

# Load full datasets
def _load_split(filename, text_column):
    """Validate one raw CSV under DATA_DIR and expose its passage column as 'context'."""
    frame = validate_csv(os.path.join(DATA_DIR, filename), [text_column, 'question'])
    if text_column == 'context':
        return frame
    return frame.rename(columns={text_column: 'context'})


# Loaded in the same order as before so the first missing file/column is reported first.
df_train_squad = _load_split("squad_train_v1.csv", 'context')
df_val_squad = _load_split("squad_validation_v1.csv", 'context')
df_train_mctest = _load_split("mctest_train.csv", 'story')
df_val_mctest = _load_split("mctest_validation.csv", 'story')
df_test_mctest = _load_split("mctest_test.csv", 'story')
df_train_FairytaleQA = _load_split("fairytaleqa_train.csv", 'content')
df_val_FairytaleQA = _load_split("fairytaleqa_validation.csv", 'content')
df_test_FairytaleQA = _load_split("fairytaleqa_test.csv", 'content')


In [None]:
# Combine and shuffle datasets
def _combine_and_sample(frames, limit):
    """Concatenate ``frames``, shuffle with SEED, and keep at most ``limit`` rows.

    Returns the combined DataFrame and the sampled ``Dataset``.
    """
    combined = pd.concat(frames, ignore_index=True)
    shuffled = Dataset.from_pandas(combined).shuffle(seed=SEED)
    return combined, shuffled.select(range(min(limit, len(combined))))


df_train, dataset_train = _combine_and_sample(
    [df_train_squad, df_train_mctest, df_train_FairytaleQA], 5000
)
df_val, dataset_val = _combine_and_sample(
    [df_val_squad, df_val_mctest, df_val_FairytaleQA], 1000
)
# No SQuAD test split is loaded, so the test set comes from MCTest and FairytaleQA only.
df_test, dataset_test = _combine_and_sample(
    [df_test_mctest, df_test_FairytaleQA], 500
)


In [None]:
# Tokenize datasets — cache directory for the tokenized train/val/test splits.
tokenized_dir = os.path.join(OUTPUT_DIR, "tokenized_datasets")
if not os.path.isdir(tokenized_dir):
    os.makedirs(tokenized_dir)

def save_tokenized_datasets(train_dataset, val_dataset, test_dataset):
    """Persist each tokenized split under ``tokenized_dir``, skipping splits already saved there."""
    splits = (
        ("train", train_dataset, "Saved tokenized training dataset"),
        ("val", val_dataset, "Saved tokenized validation dataset"),
        ("test", test_dataset, "Saved tokenized test dataset"),
    )
    for split_name, split_ds, done_message in splits:
        target = os.path.join(tokenized_dir, split_name)
        if os.path.exists(target):
            continue
        split_ds.save_to_disk(target)
        logger.info(done_message)

_REQUIRED_TOKENIZED_COLUMNS = ("input_ids", "attention_mask", "labels")


def _tokenize_split(raw_dataset):
    """Run preprocess_function over ``raw_dataset``, dropping every raw column.

    All raw columns must be removed (not just context/question) because
    preprocess_function may drop rows, and Dataset.map rejects a batched function
    that changes the row count while other columns remain.
    """
    return raw_dataset.map(
        preprocess_function,
        batched=True,
        batch_size=50,
        remove_columns=raw_dataset.column_names,
    )


def _load_or_tokenize_split(split_name, raw_dataset, label):
    """Load ``tokenized_dir/<split_name>`` if it has the model columns, else (re)build and save it."""
    path = os.path.join(tokenized_dir, split_name)
    if os.path.exists(path):
        cached = Dataset.load_from_disk(path)
        if all(col in cached.column_names for col in _REQUIRED_TOKENIZED_COLUMNS):
            logger.info(f"Loaded tokenized {label} dataset")
            return cached
        logger.warning(f"Invalid {split_name} dataset format, re-tokenizing...")
        # Drop the reference to the on-disk copy before overwriting its files.
        del cached
    tokenized = _tokenize_split(raw_dataset)
    tokenized.save_to_disk(path)
    logger.info(f"Tokenized and saved {label} dataset")
    return tokenized


def load_tokenized_datasets():
    """Return tokenized (train, val, test) datasets, using the on-disk cache when valid.

    NOTE: the cache is keyed only by directory name; it is not invalidated when the
    preprocessing or the sampled rows change — delete ``tokenized_dir`` to rebuild.
    """
    return (
        _load_or_tokenize_split("train", dataset_train, "training"),
        _load_or_tokenize_split("val", dataset_val, "validation"),
        _load_or_tokenize_split("test", dataset_test, "test"),
    )


In [None]:
# Load and assign tokenized datasets globally
processed_train_dataset, processed_val_dataset, processed_test_dataset = load_tokenized_datasets()

# Log dataset sizes for debugging
for _split_label, _split_ds in (
    ("Train", processed_train_dataset),
    ("Validation", processed_val_dataset),
    ("Test", processed_test_dataset),
):
    logger.info(f"{_split_label} dataset size: {len(_split_ds)}")

# Clean datasets: keep only the columns the model consumes.
_MODEL_COLUMNS = ("input_ids", "attention_mask", "labels")


def _keep_model_columns(ds):
    """Return ``ds`` without any column other than input_ids/attention_mask/labels."""
    return ds.remove_columns([name for name in ds.column_names if name not in _MODEL_COLUMNS])


processed_train_dataset = _keep_model_columns(processed_train_dataset)
processed_val_dataset = _keep_model_columns(processed_val_dataset)
processed_test_dataset = _keep_model_columns(processed_test_dataset)

In [None]:
# Evaluation Metrics
_METRIC_KEYS = ("bleu-1", "bleu-2", "bleu-3", "bleu-4", "rouge-l", "meteor", "bertscore")


def compute_metrics(eval_pred):
    """Score generated questions against reference questions for Seq2SeqTrainer.

    Args:
        eval_pred: ``(predictions, labels)`` token-id arrays; ``predictions`` may be a
            tuple whose first element holds the generated ids. -100 marks padding.

    Returns:
        dict with corpus BLEU-1..4 (smoothing method4), mean ROUGE-L F1, mean METEOR
        and mean BERTScore F1 (0.0, with a warning, if BERTScore fails). All zeros
        for an empty evaluation set.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    pad_id = tokenizer.pad_token_id
    # Map -100 to the pad id (dropped by skip_special_tokens) via np.where so the
    # caller's arrays are not mutated in place; clip guards against ids outside the vocab.
    predictions = np.where(np.asarray(predictions) == -100, pad_id, predictions)
    predictions = np.clip(predictions, 0, len(tokenizer) - 1)
    labels = np.where(np.asarray(labels) == -100, pad_id, labels)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    if not decoded_labels:
        logger.warning("compute_metrics received an empty evaluation set; returning zeros")
        return dict.fromkeys(_METRIC_KEYS, 0.0)
    n_examples = len(decoded_labels)

    # Tokenise once and reuse for both BLEU and METEOR.
    ref_tokens = [nltk.word_tokenize(ref) for ref in decoded_labels]
    pred_tokens = [nltk.word_tokenize(pred) for pred in decoded_preds]
    references = [[tokens] for tokens in ref_tokens]
    smoothie = SmoothingFunction().method4
    bleu_weights = {
        "bleu-1": (1, 0, 0, 0),
        "bleu-2": (0.5, 0.5, 0, 0),
        "bleu-3": (1 / 3, 1 / 3, 1 / 3, 0),  # must sum to 1 (previously 0.99)
        "bleu-4": (0.25, 0.25, 0.25, 0.25),
    }
    results = {
        name: corpus_bleu(references, pred_tokens, weights=weights, smoothing_function=smoothie)
        for name, weights in bleu_weights.items()
    }

    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    results["rouge-l"] = sum(
        scorer.score(ref, pred)['rougeL'].fmeasure
        for ref, pred in zip(decoded_labels, decoded_preds)
    ) / n_examples
    results["meteor"] = sum(
        meteor_score([ref], pred) for ref, pred in zip(ref_tokens, pred_tokens)
    ) / n_examples
    try:
        _, _, f1 = score(decoded_preds, decoded_labels, lang="en", verbose=False)
        results["bertscore"] = f1.mean().item()
    except Exception as e:
        # Best-effort metric, but say so instead of silently reporting 0.0.
        logger.warning(f"BERTScore computation failed, reporting 0.0: {e}")
        results["bertscore"] = 0.0
    return results


In [None]:
# Custom Early Stopping Callback
class CustomEarlyStoppingCallback(EarlyStoppingCallback):
    """Stop training once ``eval_loss`` stops improving by more than ``min_delta``.

    Unlike the parent class, this always tracks ``eval_loss`` (lower is better) and
    only requires ``metric_for_best_model`` when ``load_best_model_at_end`` is set.

    Args:
        early_stopping_patience: consecutive evaluations without sufficient
            improvement after which training is stopped.
        min_delta: minimum absolute decrease in ``eval_loss`` that counts as improvement.
    """

    def __init__(self, early_stopping_patience, min_delta=0.01):
        super().__init__(early_stopping_patience=early_stopping_patience)
        self.min_delta = min_delta
        self._reset_tracking()

    def _reset_tracking(self):
        """Forget the best loss seen so far and zero the no-improvement counter."""
        self.best_metric = float('inf')
        self.early_stopping_patience_counter = 0

    def on_train_begin(self, args, state, control, **kwargs):
        if args.load_best_model_at_end:
            assert args.metric_for_best_model is not None, (
                "EarlyStoppingCallback requires metric_for_best_model to be defined when load_best_model_at_end=True"
            )
        assert args.eval_strategy != "no", (
            "EarlyStoppingCallback requires eval_strategy to be 'steps' or 'epoch'"
        )
        # Start every train() call from a clean slate so a reused callback does not
        # carry a stale best loss / counter into a new run.
        self._reset_tracking()
        logger.info("Initialized CustomEarlyStoppingCallback")

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        # A missing eval_loss is treated as no improvement (except on the first eval).
        eval_loss = metrics.get('eval_loss', float('inf'))
        if self.best_metric == float('inf') or eval_loss < self.best_metric - self.min_delta:
            self.best_metric = eval_loss
            self.early_stopping_patience_counter = 0
        else:
            self.early_stopping_patience_counter += 1
        if self.early_stopping_patience_counter >= self.early_stopping_patience:
            logger.info(
                f"Early stopping triggered after {self.early_stopping_patience_counter} evaluations "
                f"without improvement (eval_loss={eval_loss}, best={self.best_metric})"
            )
            control.should_training_stop = True




In [None]:
# Save trial results helper
def save_trial_results(study, output_dir):
    """Write one row per trial of ``study`` plus a 'BEST' summary row to optuna_trials.csv.

    Rows already in the file from earlier runs are kept. New rows are appended by
    concatenating with the existing table and rewriting it, so columns stay aligned
    even when the hyper-parameter columns differ between runs. The BEST row is
    skipped (with a warning) if the study has no completed trial.

    Args:
        study: an ``optuna.Study``.
        output_dir: directory that holds ``optuna_trials.csv``.
    """
    trials_data = [
        {
            'trial_number': trial.number,
            'eval_loss': trial.value if trial.value is not None else float('inf'),
            'state': str(trial.state),
            **trial.params,
        }
        for trial in study.trials
    ]
    try:
        best_trial = study.best_trial
    except ValueError:
        # Optuna raises ValueError when no trial has completed.
        best_trial = None
        logger.warning("No completed trials; skipping BEST row in optuna_trials.csv")
    if best_trial is not None:
        trials_data.append({
            'trial_number': best_trial.number,
            'eval_loss': best_trial.value,
            'state': 'BEST',
            **best_trial.params,
        })

    output_path = os.path.join(output_dir, 'optuna_trials.csv')
    trials_df = pd.DataFrame(trials_data)
    if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
        previous = pd.read_csv(output_path)
        trials_df = pd.concat([previous, trials_df], ignore_index=True)
    trials_df.to_csv(output_path, index=False)
    logger.info("Saved trial results to optuna_trials.csv")

In [None]:
# Objective function for Optuna
class _OptunaPruningCallback(TrainerCallback):
    """Report each periodic ``eval_loss`` to an Optuna trial and prune during training.

    Raising ``TrialPruned`` from ``on_evaluate`` aborts ``trainer.train()``, so an
    unpromising trial stops as soon as the pruner decides, rather than after the
    last epoch has already been paid for.
    """

    def __init__(self, trial):
        self.trial = trial

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if not metrics or "eval_loss" not in metrics:
            return
        self.trial.report(metrics["eval_loss"], step=state.global_step)
        if self.trial.should_prune():
            raise optuna.exceptions.TrialPruned(
                f"Pruned at step {state.global_step} (eval_loss={metrics['eval_loss']})"
            )


def objective(trial):
    """Fine-tune facebook/bart-large with one sampled hyper-parameter set.

    Samples learning rate, weight decay, warmup steps and LR scheduler type, trains
    on ``processed_train_dataset`` and evaluates on ``processed_val_dataset``.
    Intermediate ``eval_loss`` values are reported to ``trial`` while training so
    the study's pruner can stop the run early.

    Returns:
        The final validation ``eval_loss`` (the value the study minimises).

    Raises:
        optuna.exceptions.TrialPruned: when the pruner stops the trial.
    """
    torch.cuda.empty_cache()
    gc.collect()

    # Load model
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
    model.to(device)
    model.gradient_checkpointing_enable()
    logger.info(f"Gradient checkpointing enabled: {model.is_gradient_checkpointing}")

    # Log VRAM usage
    if torch.cuda.is_available():
        vram_used = torch.cuda.memory_allocated(device) / 1024**3
        vram_total = torch.cuda.get_device_properties(device).total_memory / 1024**3
        logger.info(f"VRAM usage after model load: {vram_used:.2f}GB / {vram_total:.2f}GB")

    # Generation config
    generation_config = model.generation_config
    generation_config.no_repeat_ngram_size = 3
    generation_config.min_length = 5
    generation_config.max_length = 64
    generation_config.num_beams = 3

    # Suggest hyperparameters
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True)
    weight_decay = trial.suggest_float("weight_decay", 0.01, 0.1, log=True)
    warmup_steps = trial.suggest_int("warmup_steps", 100, 500, step=50)
    lr_scheduler_type = trial.suggest_categorical("lr_scheduler_type", ["linear", "cosine"])

    logger.info(f"Trial {trial.number} parameters: learning_rate={learning_rate}, "
                f"weight_decay={weight_decay}, warmup_steps={warmup_steps}, "
                f"lr_scheduler_type={lr_scheduler_type}")

    # Define output directory for this trial
    trial_output_dir = os.path.join(OUTPUT_DIR, f"trial_{trial.number}")
    os.makedirs(trial_output_dir, exist_ok=True)

    # Training arguments (fp16 here is the only mixed-precision switch needed).
    training_args = Seq2SeqTrainingArguments(
        output_dir=trial_output_dir,
        num_train_epochs=5,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=16,
        dataloader_num_workers=0,
        dataloader_pin_memory=torch.cuda.is_available(),
        lr_scheduler_type=lr_scheduler_type,
        learning_rate=learning_rate,
        warmup_steps=warmup_steps,
        remove_unused_columns=False,
        report_to=[],
        eval_strategy="steps",
        eval_steps=200,
        save_strategy="no",
        weight_decay=weight_decay,
        fp16=torch.cuda.is_available(),
        logging_strategy="steps",
        logging_steps=50,
        predict_with_generate=True,
        generation_max_length=64,
        generation_num_beams=3,
        load_best_model_at_end=False,
        group_by_length=True,
        skip_memory_metrics=True,
        disable_tqdm=True,
    )

    # Initialize trainer
    pruning_callback = _OptunaPruningCallback(trial)
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=processed_train_dataset,
        eval_dataset=processed_val_dataset,
        compute_metrics=None,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True),
        callbacks=[CustomEarlyStoppingCallback(early_stopping_patience=3), pruning_callback],
    )

    logger.info(f"Starting new training for trial {trial.number}")
    try:
        # No outer autocast: the Trainer applies fp16 autocast itself, and an outer
        # context would also wrap the backward pass, which PyTorch advises against.
        trainer.train()
        # The final evaluation is the objective value, not an intermediate report.
        trainer.remove_callback(pruning_callback)
        eval_results = trainer.evaluate()
        if torch.cuda.is_available():
            vram_used = torch.cuda.memory_allocated(device) / 1024**3
            logger.info(f"VRAM usage after training: {vram_used:.2f}GB / {vram_total:.2f}GB")
    except optuna.exceptions.TrialPruned:
        logger.info(f"Trial {trial.number} pruned by the Optuna pruner")
        raise
    except Exception as e:
        logger.error(f"Training failed for trial {trial.number}: {e}")
        raise
    finally:
        # Release this trial's model before the next trial loads a fresh copy.
        del trainer, model
        gc.collect()
        torch.cuda.empty_cache()

    return eval_results["eval_loss"]

In [None]:
# Main Optuna study
study_name = "bart_question_generation"
db_path = os.path.join(OUTPUT_DIR, "optuna_study.db")
storage_url = f"sqlite:///{db_path}"

# Start from a fresh study database when possible. If the file cannot be removed
# (e.g. locked by another process), fall back to reusing it via load_if_exists.
if os.path.exists(db_path):
    try:
        os.remove(db_path)
        logger.info("Deleted existing Optuna database to start fresh.")
    except OSError as e:
        logger.warning(f"Could not delete existing database {db_path}: {e}. Reusing existing database.")

try:
    study = optuna.create_study(
        study_name=study_name,
        storage=storage_url,
        direction="minimize",
        sampler=TPESampler(seed=SEED),  # Bayesian optimization, seeded like the rest of the notebook
        # NOTE(review): n_warmup_steps is compared with the `step` values the objective
        # reports (trainer global steps, evaluated every 200 steps), so 2 hardly delays
        # pruning — confirm whether a number of evaluations was intended.
        pruner=MedianPruner(n_warmup_steps=2),
        load_if_exists=True,
    )
    study.optimize(objective, n_trials=15)
    save_trial_results(study, OUTPUT_DIR)
    best_params = study.best_params
    with open(os.path.join(OUTPUT_DIR, "best_params.json"), "w") as f:
        json.dump(best_params, f, indent=4)
    logger.info(f"Best hyperparameters: {best_params}")
    logger.info(f"Best objective value (eval_loss): {study.best_value}")

    # Clean up non-best trials: keep only the best trial's output directory.
    best_trial_dir = os.path.join(OUTPUT_DIR, f"trial_{study.best_trial.number}")
    for trial_dir in os.listdir(OUTPUT_DIR):
        if trial_dir.startswith("trial_") and trial_dir != os.path.basename(best_trial_dir):
            try:
                shutil.rmtree(os.path.join(OUTPUT_DIR, trial_dir))
                logger.info(f"Deleted non-best trial directory: {trial_dir}")
            except Exception as e:
                logger.warning(f"Failed to delete trial directory {trial_dir}: {e}")
except Exception as e:
    logger.error(f"Optuna optimization or file saving failed: {e}")
    raise

In [None]:
# Final Training with Best Parameters
logger.info("Starting final training with best hyperparameters")

# Load best hyperparameters: use the in-memory study when this kernel ran the search,
# otherwise read best_params.json written by the search cell (e.g. after a restart).
if "study" in globals():
    best_params = study.best_params
else:
    with open(os.path.join(OUTPUT_DIR, "best_params.json")) as f:
        best_params = json.load(f)
logger.info(f"Using best hyperparameters: {best_params}")

# Load model
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
model.to(device)
model.gradient_checkpointing_enable()
logger.info(f"Gradient checkpointing enabled: {model.is_gradient_checkpointing}")

# Log VRAM usage before training
if torch.cuda.is_available():
    vram_used = torch.cuda.memory_allocated(device) / 1024**3
    vram_total = torch.cuda.get_device_properties(device).total_memory / 1024**3
    logger.info(f"VRAM usage before final training: {vram_used:.2f}GB / {vram_total:.2f}GB")

# Define training arguments with best parameters
training_args = Seq2SeqTrainingArguments(
    output_dir=os.path.join(OUTPUT_DIR, "final_model"),
    num_train_epochs=5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=16,
    dataloader_num_workers=0,
    dataloader_pin_memory=torch.cuda.is_available(),
    lr_scheduler_type=best_params["lr_scheduler_type"],
    learning_rate=best_params["learning_rate"],
    warmup_steps=best_params["warmup_steps"],
    weight_decay=best_params["weight_decay"],
    fp16=torch.cuda.is_available(),
    logging_strategy="steps",
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    predict_with_generate=True,
    generation_max_length=64,
    generation_num_beams=3,
    group_by_length=True,
    skip_memory_metrics=True,
    disable_tqdm=True,
)

# Initialize data collator
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True)

# Initialize trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_train_dataset,
    eval_dataset=processed_val_dataset,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    callbacks=[CustomEarlyStoppingCallback(early_stopping_patience=3)]
)

# Train the model. Mixed precision comes from fp16=... above; the Trainer applies
# autocast itself, so no outer autocast context (which would also wrap backward).
try:
    logger.info("Starting final training")
    trainer.train()
    logger.info("Final training completed successfully")
except Exception as e:
    logger.error(f"Final training failed: {e}")
    raise


# Log VRAM usage after training
if torch.cuda.is_available():
    vram_used = torch.cuda.memory_allocated(device) / 1024**3
    logger.info(f"VRAM usage after final training: {vram_used:.2f}GB / {vram_total:.2f}GB")

# Clean up memory
torch.cuda.empty_cache()
gc.collect()


In [None]:
# Save final model and tokenizer (model first, then tokenizer, into the same directory).
final_model_dir = os.path.join(OUTPUT_DIR, "final_model")
try:
    for _artifact in (model, tokenizer):
        _artifact.save_pretrained(final_model_dir)
    logger.info(f"Saved final model and tokenizer to {final_model_dir}")
except Exception as e:
    logger.error(f"Failed to save model or tokenizer: {e}")
    raise

# Evaluate on the held-out test split and persist the metrics as JSON.
logger.info("Evaluating on test set")
try:
    test_results = trainer.evaluate(processed_test_dataset)
    logger.info(f"Test set evaluation results: {test_results}")
    results_path = os.path.join(OUTPUT_DIR, "test_results.json")
    with open(results_path, "w") as f:
        json.dump(test_results, f, indent=4)
    logger.info(f"Saved test set results to {results_path}")
except Exception as e:
    logger.error(f"Test set evaluation failed: {e}")
    raise

# Display example predictions: a few random test rows with reference vs. generated question.
_example_generation_kwargs = {
    "max_length": 64,
    "num_beams": 3,
    "no_repeat_ngram_size": 3,
    "min_length": 5,
}
logger.info("Generating example predictions")
num_examples = min(5, len(processed_test_dataset))  # never sample more rows than exist
sample_indices = random.sample(range(len(processed_test_dataset)), num_examples)
if sample_indices:
    model.eval()  # generation must run without dropout
for idx in sample_indices:
    sample = processed_test_dataset[idx]
    context = tokenizer.decode(sample["input_ids"], skip_special_tokens=True)
    # Labels hold -100 where padding was masked out of the loss; keep real token ids only.
    valid_labels = [token for token in sample["labels"] if token >= 0]
    ground_truth_question = tokenizer.decode(valid_labels, skip_special_tokens=True)

    input_ids = torch.tensor(sample["input_ids"]).unsqueeze(0).to(device)
    attention_mask = torch.tensor(sample["attention_mask"]).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **_example_generation_kwargs,
        )
    predicted_question = tokenizer.decode(outputs[0], skip_special_tokens=True)

    logger.info(f"\nExample {idx}:")
    logger.info(f"Context: {context}")
    logger.info(f"Ground Truth Question: {ground_truth_question}")
    logger.info(f"Predicted Question: {predicted_question}")

# Clean up memory
torch.cuda.empty_cache()
gc.collect()