In [None]:
# %load_ext autoreload
# %autoreload 2

import json
import os
from collections import Counter
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from llm_synthesis.utils.style_utils import get_cmap, set_style


def load_result_json(file_path: str) -> list[dict[str, Any]]:
    """Load and parse a ``result.json`` file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The dict entries of the file's top-level JSON list. An empty list is
        returned (and the problem printed) when the file cannot be read, is
        not valid JSON, or its top-level value is not a list. Callers can
        therefore treat every failure as "no data" and keep going.
    """
    try:
        with open(file_path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError also covers
        # json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error loading {file_path}: {e}")
        return []

    if not isinstance(data, list):
        # Callers iterate the result as a list of synthesis dicts. A top-level
        # object would otherwise be iterated key-by-key and fail later.
        print(
            f"Error loading {file_path}: expected a JSON list, "
            f"got {type(data).__name__}"
        )
        return []

    # Drop stray non-dict entries so callers can rely on dict access.
    return [entry for entry in data if isinstance(entry, dict)]


def extract_synthesis_data(synthesis: dict[str, Any]) -> dict[str, Any]:
    """Extract evaluation scores and metadata from one synthesis record.

    Args:
        synthesis: One entry of a ``result.json`` list. Scores are read from
            ``synthesis["evaluation"]["scores"]``. Method and compound type
            are read from ``synthesis["synthesis"]``, and the material from
            ``synthesis["material"]``. Any of these may be missing or ``None``.

    Returns:
        A dict with these keys:

        - ``"scores"``: every ``*_score`` entry converted to float.
        - ``"synthesis_method"``, ``"target_compound_type"``, ``"material"``:
          each ``None`` when absent.

        Score values that cannot be converted to float (e.g. ``None`` from a
        failed evaluation) are skipped instead of raising. One malformed
        record therefore does not abort a whole analysis run.
    """
    data: dict[str, Any] = {
        "scores": {},
        "synthesis_method": None,
        "target_compound_type": None,
        "material": None,
    }

    evaluation = synthesis.get("evaluation")
    scores = evaluation.get("scores") if isinstance(evaluation, dict) else None
    if isinstance(scores, dict):
        for key, value in scores.items():
            if not key.endswith("_score"):
                continue
            try:
                data["scores"][key] = float(value)
            except (TypeError, ValueError):
                continue  # unparseable score: leave it out

    details = synthesis.get("synthesis")
    if isinstance(details, dict):
        data["synthesis_method"] = details.get("synthesis_method")
        data["target_compound_type"] = details.get("target_compound_type")

    data["material"] = synthesis.get("material")
    return data


def process_subdirectory(subdir_path: str) -> dict[str, Any]:
    """Summarise the ``result.json`` stored in one result subdirectory.

    Only syntheses that carry at least one ``*_score`` value contribute.

    Args:
        subdir_path: Directory expected to contain ``result.json``.

    Returns:
        ``{}`` when the file yields no scored syntheses. Otherwise a dict with:

        - the directory's base name;
        - the NaN-aware mean of each score type over the scored syntheses;
        - the number of scored syntheses;
        - parallel lists of their synthesis methods, target compound types
          and materials.
    """
    records = load_result_json(os.path.join(subdir_path, "result.json"))
    if not records:
        return {}

    scored = [
        extracted
        for extracted in map(extract_synthesis_data, records)
        if extracted["scores"]
    ]
    if not scored:
        return {}

    score_dicts = [entry["scores"] for entry in scored]
    score_names = {name for score_dict in score_dicts for name in score_dict}
    mean_by_name = {
        name: np.nanmean(
            [score_dict.get(name, np.nan) for score_dict in score_dicts]
        )
        for name in score_names
    }

    return {
        "subdir_name": os.path.basename(subdir_path),
        "scores": mean_by_name,
        "synthesis_count": len(score_dicts),
        "synthesis_methods": [entry["synthesis_method"] for entry in scored],
        "target_compound_types": [
            entry["target_compound_type"] for entry in scored
        ],
        "materials": [entry["material"] for entry in scored],
    }


def analyze_metadata(results: list[dict[str, Any]]) -> dict[str, Any]:
    """Count how often each synthesis method and target compound type occurs.

    Falsy results (e.g. ``{}`` for skipped directories) and falsy entries
    (e.g. ``None`` or ``""``) are ignored.

    Returns:
        ``{"synthesis_methods": {method: count},
        "target_compound_types": {compound_type: count}}``.
    """
    method_counter: Counter = Counter()
    type_counter: Counter = Counter()
    for result in results:
        if not result:
            continue
        method_counter.update(
            method for method in result.get("synthesis_methods", []) if method
        )
        type_counter.update(
            kind for kind in result.get("target_compound_types", []) if kind
        )
    return {
        "synthesis_methods": dict(method_counter),
        "target_compound_types": dict(type_counter),
    }


def calculate_scores_by_category(
    my_dir: str, results: list[dict[str, Any]], category_field: str
) -> dict[str, dict[str, float]]:
    """Average every score type per value of ``category_field``.

    Each non-empty result's ``result.json`` is re-read from
    ``my_dir/<subdir_name>``. Only syntheses with a truthy category value
    and at least one score are grouped.

    Args:
        my_dir: Parent directory of the result subdirectories.
        results: Output of ``process_subdirectory`` per subdirectory; falsy
            entries are skipped.
        category_field: Key of the extracted record to group by, e.g.
            ``"synthesis_method"``.

    Returns:
        ``{category_value: {score_name: mean}}``, with NaN means omitted.
    """
    grouped: dict[str, list[dict[str, float]]] = {}
    for result in results:
        if not result:
            continue
        result_path = os.path.join(my_dir, result["subdir_name"], "result.json")
        for record in load_result_json(result_path):
            extracted = extract_synthesis_data(record)
            label = extracted.get(category_field)
            if not (label and extracted["scores"]):
                continue
            grouped.setdefault(label, []).append(extracted["scores"])

    averages: dict[str, dict[str, float]] = {}
    for label, score_dicts in grouped.items():
        names = {name for score_dict in score_dicts for name in score_dict}
        per_name = {}
        for name in names:
            mean = np.nanmean(
                [score_dict.get(name, np.nan) for score_dict in score_dicts]
            )
            if not np.isnan(mean):
                per_name[name] = mean
        averages[label] = per_name
    return averages


# =============================================================================
# Plotting Functions (Updated)
# =============================================================================


def plot_scores_by_category(
    scores_by_category: dict[str, dict[str, float]],
    category_name: str,
    counts: dict[str, int],
    cmap,
    output_dir=None,
    ylim: tuple[float, float] = (0, 5),
):
    """Draw one bar chart per score type, with bars coloured by entry count.

    Args:
        scores_by_category: ``{category: {score_name: mean_score}}``. Each
            ``score_name`` becomes one figure.
        category_name: Field name used in titles, labels and saved file names
            (underscores are displayed as spaces).
        counts: ``{category: number_of_entries}`` driving bar colour.
            Categories missing from it are coloured as a count of 0.
        cmap: Matplotlib colormap for the bars and the colour bar.
        output_dir: Optional directory. If given, each figure is saved there
            as ``{category_name}_{score}_plot_colored.png`` and then closed.
            If ``None`` (default), figures are left open for interactive or
            notebook display.
        ylim: y-axis limits; defaults to ``(0, 5)``.
    """
    if not counts or not scores_by_category:
        return

    df = pd.DataFrame.from_dict(scores_by_category, orient="index")

    vmin, vmax = min(counts.values()), max(counts.values())
    if vmin == vmax:
        # Normalize maps everything to 0 when vmin == vmax; widen by one so
        # the colour bar spans a non-empty interval.
        vmax = vmin + 1
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    pretty_category = category_name.replace("_", " ").title()

    for score_column in df.columns:
        # Categories lacking this score are NaN after from_dict. Drop them
        # rather than drawing empty bar slots.
        sorted_scores = df[score_column].dropna().sort_values(ascending=False)
        if sorted_scores.empty:
            continue

        fig, ax = plt.subplots()
        bar_colors = [
            cmap(norm(counts.get(cat, 0))) for cat in sorted_scores.index
        ]
        sorted_scores.plot(kind="bar", ax=ax, color=bar_colors)

        ax.set_title(
            f"Average {score_column.replace('_', ' ').title()} by {pretty_category}"
        )
        ax.set_xlabel(pretty_category)
        ax.set_ylabel("Average Score")
        ax.set_ylim(*ylim)
        # Rotate this axes' labels directly instead of relying on pyplot state.
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax)
        cbar.set_label("Number of Entries")

        fig.tight_layout()
        if output_dir is not None:
            fig.savefig(
                os.path.join(
                    output_dir, f"{category_name}_{score_column}_plot_colored.png"
                )
            )
            plt.close(fig)


def plot_average_of_other_scores(
    scores_by_category: dict[str, dict[str, float]],
    category_name: str,
    counts: dict[str, int],
    cmap,
    output_dir=None,
    ylim: tuple[float, float] = (0, 5),
):
    """Plot, per category, the mean of every score type except ``overall_score``.

    Args:
        scores_by_category: ``{category: {score_name: mean_score}}``.
        category_name: Field name used in the title, label and saved file name
            (underscores are displayed as spaces).
        counts: ``{category: number_of_entries}`` driving bar colour.
            Categories missing from it are coloured as a count of 0.
        cmap: Matplotlib colormap for the bars and the colour bar.
        output_dir: Optional directory. If given, the figure is saved there as
            ``{category_name}_average_of_other_scores_plot_colored.png`` and
            then closed. If ``None`` (default), it is left open for display.
        ylim: y-axis limits; defaults to ``(0, 5)``.
    """
    if not counts or not scores_by_category:
        return

    df = pd.DataFrame.from_dict(scores_by_category, orient="index")
    score_columns = [col for col in df.columns if col != "overall_score"]
    if not score_columns:
        return

    # pandas skips NaN in the row-wise mean, so a category missing some score
    # types is averaged over the ones it has. A category with none of them
    # yields NaN and is dropped rather than drawn as an empty bar slot.
    sorted_scores = (
        df[score_columns]
        .mean(axis=1)
        .dropna()
        .sort_values(ascending=False)
        .rename("average_of_other_scores")
    )
    if sorted_scores.empty:
        return

    vmin, vmax = min(counts.values()), max(counts.values())
    if vmin == vmax:
        # Normalize maps everything to 0 when vmin == vmax; widen by one so
        # the colour bar spans a non-empty interval.
        vmax = vmin + 1
    norm = plt.Normalize(vmin=vmin, vmax=vmax)
    pretty_category = category_name.replace("_", " ").title()

    fig, ax = plt.subplots()
    bar_colors = [cmap(norm(counts.get(cat, 0))) for cat in sorted_scores.index]
    sorted_scores.plot(kind="bar", ax=ax, color=bar_colors)

    ax.set_title(f"Average of Other Scores by {pretty_category}")
    ax.set_xlabel(pretty_category)
    ax.set_ylabel("Average Score")
    ax.set_ylim(*ylim)
    # Rotate this axes' labels directly instead of relying on pyplot state.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label("Number of Entries")

    fig.tight_layout()
    if output_dir is not None:
        fig.savefig(
            os.path.join(
                output_dir,
                f"{category_name}_average_of_other_scores_plot_colored.png",
            )
        )
        plt.close(fig)


# =============================================================================
# Main Execution
# =============================================================================


def main(
    my_dir: str,
    output_file: str = "synthesis_evaluation_results.csv",
    cmap=None,
):
    """Process every result subdirectory of ``my_dir`` and plot score averages.

    Steps:

    1. Summarise each immediate subdirectory's ``result.json`` with
       ``process_subdirectory``.
    2. Write the per-subdirectory averages to ``output_file`` as CSV.
    3. Draw bar charts of the average scores grouped by synthesis method and
       by target compound type.

    Args:
        my_dir: Directory whose subdirectories each hold a ``result.json``.
        output_file: CSV path for the per-subdirectory averages. Pass ``None``
            or ``""`` to skip writing it.
        cmap: Colormap forwarded to the plotting functions.
    """
    # isdir rather than exists: a regular file here would make listdir raise.
    if not os.path.isdir(my_dir):
        print(f"Error: Directory {my_dir} does not exist!")
        return

    # Sorted so the CSV row order and plot ordering ties are reproducible.
    subdirs = sorted(
        d for d in os.listdir(my_dir) if os.path.isdir(os.path.join(my_dir, d))
    )
    if not subdirs:
        print(f"No subdirectories found in {my_dir}")
        return

    results = [process_subdirectory(os.path.join(my_dir, s)) for s in subdirs]
    metadata = analyze_metadata(results)

    if output_file:
        rows = [
            {
                "subdir_name": r["subdir_name"],
                "synthesis_count": r["synthesis_count"],
                **r["scores"],
            }
            for r in results
            if r
        ]
        if rows:
            pd.DataFrame(rows).to_csv(output_file, index=False)
            print(f"Saved per-directory averages to {output_file}")

    # (field to group by, metadata key holding that field's entry counts)
    for category_field, counts_key in (
        ("synthesis_method", "synthesis_methods"),
        ("target_compound_type", "target_compound_types"),
    ):
        category_scores = calculate_scores_by_category(
            my_dir, results, category_field
        )
        if not category_scores:
            continue
        counts = metadata[counts_key]
        plot_scores_by_category(category_scores, category_field, counts, cmap)
        plot_average_of_other_scores(
            category_scores, category_field, counts, cmap
        )

    print("Plot generation complete.")


# Results directory to analyse. Set the SYNTHESIS_RESULTS_DIR environment
# variable to point elsewhere instead of editing this machine-specific path.
MY_DIR = os.environ.get(
    "SYNTHESIS_RESULTS_DIR",
    "/Users/magdalenalederbauer/Code/lematerial-llm-synthesis/results/single_run/2025-07-31/14-14-06/results",
)
# Set the plot style
set_style("manuscript")

cmap = get_cmap()

# Run the main analysis and plotting function
main(my_dir=MY_DIR, cmap=cmap)

In [None]:
def main(
    my_dir: str,
    output_file: str = "synthesis_evaluation_results.csv",
    cmap=None,
):
    """Process every result subdirectory of ``my_dir`` and plot score averages.

    Steps:

    1. Summarise each immediate subdirectory's ``result.json`` with
       ``process_subdirectory``.
    2. Write the per-subdirectory averages to ``output_file`` as CSV.
    3. Draw bar charts of the average scores grouped by synthesis method and
       by target compound type.

    Args:
        my_dir: Directory whose subdirectories each hold a ``result.json``.
        output_file: CSV path for the per-subdirectory averages. Pass ``None``
            or ``""`` to skip writing it.
        cmap: Colormap forwarded to the plotting functions.
    """
    # isdir rather than exists: a regular file here would make listdir raise.
    if not os.path.isdir(my_dir):
        print(f"Error: Directory {my_dir} does not exist!")
        return

    # Sorted so the CSV row order and plot ordering ties are reproducible.
    subdirs = sorted(
        d for d in os.listdir(my_dir) if os.path.isdir(os.path.join(my_dir, d))
    )
    if not subdirs:
        print(f"No subdirectories found in {my_dir}")
        return

    results = [process_subdirectory(os.path.join(my_dir, s)) for s in subdirs]
    metadata = analyze_metadata(results)

    if output_file:
        rows = [
            {
                "subdir_name": r["subdir_name"],
                "synthesis_count": r["synthesis_count"],
                **r["scores"],
            }
            for r in results
            if r
        ]
        if rows:
            pd.DataFrame(rows).to_csv(output_file, index=False)
            print(f"Saved per-directory averages to {output_file}")

    # (field to group by, metadata key holding that field's entry counts)
    for category_field, counts_key in (
        ("synthesis_method", "synthesis_methods"),
        ("target_compound_type", "target_compound_types"),
    ):
        category_scores = calculate_scores_by_category(
            my_dir, results, category_field
        )
        if not category_scores:
            continue
        counts = metadata[counts_key]
        plot_scores_by_category(category_scores, category_field, counts, cmap)
        plot_average_of_other_scores(
            category_scores, category_field, counts, cmap
        )

    print("Plot generation complete.")


# Results directory to analyse. Set the SYNTHESIS_RESULTS_DIR environment
# variable to point elsewhere instead of editing this machine-specific path.
MY_DIR = os.environ.get(
    "SYNTHESIS_RESULTS_DIR",
    "/Users/magdalenalederbauer/Code/lematerial-llm-synthesis/results/single_run/2025-07-31/14-14-06/results",
)
# Set the plot style
set_style("manuscript")

cmap = get_cmap()

# Run the main analysis and plotting function
main(my_dir=MY_DIR, cmap=cmap)

In [None]:
# TODO: add error bars to the score plots
# TODO: implement a radar chart
# TODO: add new evaluation samples
# TODO: implement a synthesis match score