In [None]:
# -*- coding: utf-8 -*-
"""
4_Sample_Predictions.ipynb

This notebook prints sample text-summarization predictions from each fine-tuned TinyLlama model next to the reference summaries.
It provides a qualitative comparison of summary quality across the different fine-tuning methods.
"""

# Import necessary libraries
import os
import random

import pandas as pd
import torch

from src.config import RAW_DATA_DIR
from src.prepare_data import load_raw_data
from src.predict import load_model_for_prediction, generate_summary

# --- Load Test Data ---
print("Loading raw test dataset...")
test_df = load_raw_data("test")
print("Raw test dataset loaded successfully.")

# --- Select Sample Articles ---
# Sample indices are row *positions* in test_df (0 .. len(test_df) - 1).
# Set USE_RANDOM_SAMPLES = True to draw NUM_RANDOM_SAMPLES rows reproducibly
# instead of using the fixed example indices below.
USE_RANDOM_SAMPLES = False
NUM_RANDOM_SAMPLES = 3
RANDOM_SEED = 42

if USE_RANDOM_SAMPLES:
    # A dedicated Random instance gives the same draw as
    # `random.seed(RANDOM_SEED); random.sample(...)` without reseeding the
    # global RNG. Clamp k so a small split cannot make sample() raise.
    sample_indices = random.Random(RANDOM_SEED).sample(
        range(len(test_df)), min(NUM_RANDOM_SAMPLES, len(test_df))
    )
else:
    sample_indices = [0, 1, 2]  # Example indices

# Fail fast with a clear message rather than deep inside the prediction loop.
out_of_range = [i for i in sample_indices if not 0 <= i < len(test_df)]
if out_of_range:
    raise ValueError(
        f"Sample indices {out_of_range} are out of range for a test set "
        f"of {len(test_df)} rows."
    )

print(f"\nSelected sample indices for prediction: {sample_indices}")

# --- Define Models to Evaluate ---
# Each entry is passed verbatim to `load_model_for_prediction`, so the names
# must match the model types accepted by `src/predict.py`.
MODEL_TYPES = ["full", "lora", "qlora", "adapter", "prompt_tuning"]

# model_type -> loaded model (None when loading failed), and
# model_type -> tokenizer (populated only on success). Caching both lets the
# prediction loop reuse each model instead of reloading it for every sample.
# NOTE(review): every successfully loaded model stays resident at the same
# time; if memory is tight, consider loading/evaluating one model at a time.
loaded_models, loaded_tokenizers = {}, {}

print("\nLoading all fine-tuned models for prediction...")
for model_type in MODEL_TYPES:
    try:
        model, tokenizer = load_model_for_prediction(model_type)
    except Exception as e:
        # Best-effort: report the failure and mark the model as unavailable
        # (None) so the prediction loop can skip it instead of crashing.
        print(f"Error loading {model_type} model: {e}")
        loaded_models[model_type] = None
    else:
        loaded_models[model_type] = model
        loaded_tokenizers[model_type] = tokenizer
        print(f"Loaded {model_type} model successfully.")


# --- Generate and Display Predictions ---
# For each selected sample, print a preview of the article, its reference
# summary, and the summary generated by every successfully loaded model.
# Rows are looked up positionally (.iloc) because the sample indices are row
# positions (random samples are drawn from range(len(test_df))); label-based
# .loc would pick wrong rows or raise on a non-default index.
ARTICLE_PREVIEW_CHARS = 500  # number of article characters shown per sample

print("\n--- Generating and Displaying Sample Predictions ---")
for idx in sample_indices:
    row = test_df.iloc[idx]
    article = row['article']
    reference_summary = row['summary']

    # Only add an ellipsis when the article was actually cut short.
    preview = article[:ARTICLE_PREVIEW_CHARS]
    if len(article) > ARTICLE_PREVIEW_CHARS:
        preview += "..."

    print(f"\n===== Sample {idx} =====")
    print(f"Original Article (truncated): {preview}")
    print(f"Reference Summary: {reference_summary}")

    for model_type in MODEL_TYPES:
        # Distinct name so the module-level `model` left by the loading
        # cell is not overwritten.
        summarizer = loaded_models.get(model_type)
        if summarizer is None:
            print(f"\n--- {model_type.upper()} Model (not loaded, skipping prediction) ---")
            continue

        print(f"\n--- {model_type.upper()} Generated Summary ---")
        try:
            # Pair each model with the tokenizer it was loaded with.
            generated_summary = generate_summary(
                summarizer, loaded_tokenizers[model_type], article
            )
        except Exception as e:
            # Best-effort: one model failing (e.g. out of memory) should not
            # abort the comparison for the remaining models and samples.
            print(f"Error generating summary with {model_type} model: {e}")
            continue
        print(f"Generated Summary: {generated_summary}")
    print("\n" + "=" * 40)

print("Sample predictions generation complete. Review the outputs to compare summary quality.")
