In [None]:
import json
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path

# Bookkeeping columns skipped when metrics are auto-selected: they describe
# progress / schedule rather than model quality.
_EXCLUDED_DEFAULT_COLUMNS = frozenset({'current_steps', 'total_steps', 'epoch', 'percentage', 'lr'})


def _load_trainer_log(log_file_path):
    """Read a JSONL file into a DataFrame with one row per non-blank line."""
    with open(log_file_path, 'r', encoding='utf-8') as f:
        records = [json.loads(line) for line in f if line.strip()]
    return pd.DataFrame(records)


def _default_metrics(df):
    """Return numeric, non-bookkeeping columns of *df* that hold at least one value."""
    numeric = df.select_dtypes(include='number')
    # Columns that are all-NaN (e.g. metrics that only appear on entries removed
    # by filter_by_name) would only yield empty subplots, so skip them.
    return [
        col for col in numeric.columns
        if col not in _EXCLUDED_DEFAULT_COLUMNS and numeric[col].notna().any()
    ]


def plot_training_metrics(log_file_path, metrics_to_plot=None, filter_by_name=None):
    """
    Plot training metrics from a trainer_log.jsonl file, one subplot per metric.

    Args:
        log_file_path: Path to the trainer_log.jsonl file (str or Path). Blank
            lines are ignored; every other line must be a JSON object.
        metrics_to_plot: Metric name or list of metric names to plot
            (e.g., ['loss', 'accuracy', 'eval_loss']). When None, every numeric
            column except 'current_steps', 'total_steps', 'epoch', 'percentage'
            and 'lr' that has at least one value is plotted.
        filter_by_name: If truthy, keep only entries whose 'name' field equals
            this value (e.g., 'loss', 'train_loss', 'eval_dpo_eval_loss').

    Each metric is plotted only at the entries where it is present, against
    'current_steps' when that column exists and against the row index otherwise.
    Prints "No metrics to plot" and returns early when there is nothing to draw.

    Raises:
        KeyError: if filter_by_name is given but the log has no 'name' field.
    """
    df = _load_trainer_log(log_file_path)

    if filter_by_name:
        if 'name' not in df.columns:
            raise KeyError(f"cannot filter by name={filter_by_name!r}: log has no 'name' field")
        df = df[df['name'] == filter_by_name]

    if metrics_to_plot is None:
        metrics_to_plot = _default_metrics(df)
    elif isinstance(metrics_to_plot, str):
        # A bare string would otherwise be iterated character by character.
        metrics_to_plot = [metrics_to_plot]
    else:
        metrics_to_plot = list(metrics_to_plot)

    num_metrics = len(metrics_to_plot)
    if num_metrics == 0:
        print("No metrics to plot")
        return

    has_steps = 'current_steps' in df.columns
    x_label = 'Steps' if has_steps else 'Entry'

    # squeeze=False always yields a 2-D array, so no single-subplot special case.
    _, axes = plt.subplots(num_metrics, 1, figsize=(12, 4 * num_metrics), squeeze=False)

    for ax, metric in zip(axes[:, 0], metrics_to_plot):
        if metric not in df.columns:
            ax.text(0.5, 0.5, f'Metric "{metric}" not found',
                    transform=ax.transAxes, ha='center', va='center')
            ax.set_title(f'{metric} (not found)')
            continue

        # Select x and y with the same mask so each value stays paired with the
        # step of the row it came from (train/eval entries are interleaved).
        present = df[metric].notna()
        y_data = df.loc[present, metric]
        x_data = df.loc[present, 'current_steps'] if has_steps else y_data.index

        ax.plot(x_data, y_data, marker='o', linewidth=2, markersize=4)
        ax.set_xlabel(x_label)
        ax.set_ylabel(metric)
        ax.set_title(f'{metric} vs {x_label}')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Example usage:
# Uncomment and modify the path to your trainer_log.jsonl file

# Plot all available metrics
# plot_training_metrics('saves/qwen2_5-0_5b/full/dpo_train_ai_flipped/trainer_log.jsonl')

# Plot specific metrics
# plot_training_metrics('saves/qwen2_5-0_5b/full/dpo_train_ai_flipped/trainer_log.jsonl', 
#                       metrics_to_plot=['loss', 'accuracy'])

# Plot metrics for specific name
# plot_training_metrics('saves/qwen2_5-0_5b/full/dpo_train_ai_flipped_robust/trainer_log.jsonl',
#                       filter_by_name='eval_dpo_eval_loss',
#                       metrics_to_plot=['eval_chosen_rewards', 'eval_rejected_rewards', 'eval_rewards_accuracies'])
