In [None]:
# -*- coding: utf-8 -*-
"""
2_FineTuning_Results.ipynb

This notebook visualizes the training and validation loss curves for each fine-tuning method.
It aims to show the learning progress and stability of different approaches during training.
"""

# Import necessary libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import glob
import json

from src.config import MODELS_DIR, FIGURES_DIR, REPORT_TITLE

# Per-method training-log locations: each fine-tuning method's output folder
# under MODELS_DIR, followed by its "runs" subdirectory. Adjust the table below
# if logs live in a different structure.
# NOTE(review): with the Hugging Face Trainer, `runs/` usually holds TensorBoard
# event files, while `trainer_state.json` (what extract_logs searches for) is
# written to the output directory and its `checkpoint-*` folders — confirm these
# paths match the training scripts.
_LOG_SUBDIRS = {
    "Full Fine-tuning": "finetuned_full",
    "LoRA": "finetuned_lora",
    "QLoRA": "finetuned_qlora",
    "Adapter (IA3)": "finetuned_adapter",
    "Prompt-tuning": "finetuned_prompt_tuning",
}
LOG_DIRS = {
    method: os.path.join(MODELS_DIR, subdir, "runs")
    for method, subdir in _LOG_SUBDIRS.items()
}

# Function to extract logs (simplified, might need refinement based on actual log format)
def extract_logs(log_dir: str) -> pd.DataFrame:
    # Assuming logs are in `trainer_state.json` or similar format in the latest run directory
    # This is a simplified approach and might need adjustment.
    # A more robust solution would involve parsing event files or a custom logger.
    log_files = glob.glob(os.path.join(log_dir, "**/trainer_state.json"), recursive=True)
    if not log_files:
        print(f"No trainer_state.json found in {log_dir}. Trying other log formats...")
        # Fallback to a simpler log parsing if trainer_state.json is not available
        # For example, if logs are printed to console and redirected to a file
        # This part is highly dependent on how training logs are actually saved.
        return pd.DataFrame() # Return empty if no known log format found
    
    # Get the latest log file if multiple exist
    latest_log_file = max(log_files, key=os.path.getctime)
    
    try:
        with open(latest_log_file, 'r', encoding='utf-8') as f:
            trainer_state = json.load(f)
        
        # Extract relevant metrics (loss, eval_loss, epoch, step)
        log_history = trainer_state.get('log_history', [])
        df = pd.DataFrame(log_history)
        return df
    except Exception as e:
        print(f"Error reading log file {latest_log_file}: {e}")
        return pd.DataFrame()

all_training_logs = {}
for model_type, path in LOG_DIRS.items():
    if os.path.exists(path):
        logs = extract_logs(path)
        if not logs.empty:
            all_training_logs[model_type] = logs
            print(f"Extracted logs for {model_type}: {len(logs)} entries.")
        else:
            print(f"No valid logs extracted for {model_type}.")
    else:
        print(f"Log directory not found for {model_type}: {path}")


# --- Visualization: Training & Validation Loss ---
# For each fine-tuning method, draw a solid training-loss curve and a dashed
# validation-loss curve on a shared epoch axis. The figure is saved under
# FIGURES_DIR and then displayed.
if all_training_logs:
    print("\nGenerating Training and Validation Loss Plots...")
    plt.figure(figsize=(15, 8))
    has_labeled_lines = False

    for model_type, logs_df in all_training_logs.items():
        if 'epoch' not in logs_df.columns:
            print(f"Skipping {model_type}: logs have no 'epoch' column to plot against.")
            continue

        # log_history mixes training rows (carrying `loss`) and evaluation rows
        # (carrying `eval_loss`). Each curve keeps only its own rows. Plotting the
        # raw column would leave NaNs that split the line into disconnected fragments.
        eval_style = {}
        if 'loss' in logs_df.columns:
            train_logs = logs_df.dropna(subset=['epoch', 'loss'])
            if not train_logs.empty:
                (train_line,) = plt.plot(train_logs['epoch'], train_logs['loss'],
                                         label=f'{model_type} Training Loss', alpha=0.7)
                # Reuse the training color so each method's two curves are easy to pair.
                eval_style['color'] = train_line.get_color()
                has_labeled_lines = True

        # Note: eval_loss might not be logged for every step, only at evaluation points.
        if 'eval_loss' in logs_df.columns:
            eval_logs = logs_df.dropna(subset=['epoch', 'eval_loss'])
            if not eval_logs.empty:
                plt.plot(eval_logs['epoch'], eval_logs['eval_loss'],
                         label=f'{model_type} Validation Loss', linestyle='--', **eval_style)
                has_labeled_lines = True

    plt.title(f'Training and Validation Loss Across Fine-tuning Methods ({REPORT_TITLE})', fontsize=16)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    if has_labeled_lines:  # avoids matplotlib's "No artists with labels" warning
        plt.legend()
    # A non-zero alpha is required; alpha=0.0 would make the grid invisible.
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    os.makedirs(FIGURES_DIR, exist_ok=True)  # savefig does not create missing directories
    output_path = os.path.join(FIGURES_DIR, "training_validation_loss_curves.png")
    plt.savefig(output_path)
    print(f"Training and Validation Loss Plots saved to {output_path}")
    plt.show()
else:
    print("No training logs found to generate plots.")

print("Fine-tuning results visualization setup complete. Run this notebook after training scripts.")
