# Script to generate the plots for Fig. 2 (NeurIPS AI4Mat)


In [None]:
import json
import os

from llm_synthesis.utils.style_utils import get_cmap, get_palette, set_style

# Fetch the shared colour map and palette from the project's style utilities
# and apply the project plotting style. What these helpers return/configure is
# defined in llm_synthesis.utils.style_utils (not visible in this file).
cmap = get_cmap()
palette = get_palette()
set_style()

In [None]:
from datasets import load_dataset

# Load the LeMat-Synth dataset (presumably from the Hugging Face Hub, which may
# require network access on first use) and display its summary.
ds = load_dataset("LeMaterial/LeMat-Synth")
ds

In [None]:
# Convert the "sample_for_evaluation" entry of the dataset to a pandas
# DataFrame; the following cells all operate on this frame.
df = ds["sample_for_evaluation"].to_pandas()

# Preview the first rows.
df.head()

In [None]:
# Inspect the distinct raw values of paper_published_date.
df["paper_published_date"].unique()

In [None]:
# Keep only the first four characters of paper_published_date, overwriting the
# column in place.
# NOTE(review): presumably the raw values are strings that start with the
# four-digit year (e.g. ISO dates) — confirm against the dataset.
df["paper_published_date"] = df["paper_published_date"].str[:4]

In [None]:
# Classify every paper by the host found in its URL:
#   contains "arxiv.org"    -> "arxiv"
#   contains "chemrxiv.org" -> "chemrxiv"
#   anything else           -> "omg24"
def _source_from_url(url):
    """Map a paper URL to its source label using substring checks."""
    if "arxiv.org" in url:
        return "arxiv"
    if "chemrxiv.org" in url:
        return "chemrxiv"
    return "omg24"


df["source"] = df["paper_url"].apply(_source_from_url)

# Number of rows per source.
df.groupby("source").size()

In [None]:
# Preview the four columns used for the dataset statistics.
df[["material_category", "synthesis_method", "paper_published_date", "source"]]

In [None]:
# Write the statistics columns to a CSV file in the current working directory
# (without the DataFrame index).
export_columns = [
    "material_category",
    "synthesis_method",
    "paper_published_date",
    "source",
]

output_dir = os.getcwd()
file_path = os.path.join(output_dir, "dataset_statistics_with_source.csv")

df[export_columns].to_csv(file_path, index=False)

In [None]:
# Build a colour lookup for material categories from the project palette.
# Categories are sorted so the category -> colour pairing is deterministic.
# zip() stops at the shorter input, so categories beyond the number of colours
# that `palette` yields get no entry in material_colors.
unique_materials = sorted(df["material_category"].unique())
unique_synthesis = sorted(df["synthesis_method"].unique())
material_colors = dict(zip(unique_materials, palette))
# Display the distinct synthesis-method labels.
unique_synthesis

In [None]:
# Collapse variant spellings and sub-types of synthesis methods onto a single
# abbreviation each. Matching is exact (whole cell value, case-sensitive).
synthesis_aliases = {
    "iCVD": "CVD",
    "MOCVD": "CVD",
    "pulsed laser deposition": "PLD",
    "molecular beam epitaxy": "MBE",
    "filtered vacuum arc deposition (FVAD)": "FVAD",
    "filtered vacuum arc deposition": "FVAD",
    "atomic layer deposition": "ALD",
}

for raw_label, canonical_label in synthesis_aliases.items():
    is_variant = df["synthesis_method"] == raw_label
    df.loc[is_variant, "synthesis_method"] = canonical_label

# Recompute and display the distinct labels after renaming.
unique_synthesis = sorted(df["synthesis_method"].unique())
unique_synthesis

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

set_style()

# Number of most common material categories that keep their own identity in the
# bar plots; all remaining categories are folded into "other".
N_TOP_MATERIALS = 7

# value_counts() sorts by frequency (descending), so head() yields the most
# common categories.
material_counts = df["material_category"].value_counts()
# NOTE: the list keeps its historical name `top_5_materials` because other
# code in this cell refers to it; it holds the top N_TOP_MATERIALS categories.
top_5_materials = material_counts.head(N_TOP_MATERIALS).index.tolist()

print("📊 MATERIAL CATEGORY ANALYSIS:")
print(f"   Total categories: {len(material_counts)}")
print(f"   Category counts: {dict(material_counts)}")
# Report the number of categories actually selected instead of a fixed "5".
print(f"   Top {len(top_5_materials)} for bar plots: {top_5_materials}")


# Bar-plot grouping: frequent categories keep their name, the rest become "other"
def group_materials_for_bars(category):
    """Return *category* if it is in ``top_5_materials``, otherwise ``"other"``.

    ``top_5_materials`` is read from module scope at call time.
    """
    return category if category in top_5_materials else "other"


def _set2_color_map(labels):
    """Pair each label, in order, with a colour from seaborn's "Set2" palette."""
    return dict(zip(labels, sns.color_palette("Set2", len(labels))))


# Add the bar-plot grouping column (frequent categories kept, rest -> "other").
df["material_category_grouped"] = df["material_category"].apply(
    group_materials_for_bars
)

# Sorted label lists and a colour lookup for the original (ungrouped)
# categories.
unique_materials = sorted(df["material_category"].unique())
unique_synthesis = sorted(df["synthesis_method"].unique())
material_colors = _set2_color_map(unique_materials)

# Colour lookup for the grouped categories; this is the one used by the
# stacked bar plots.
unique_materials_grouped = sorted(df["material_category_grouped"].unique())
material_colors_grouped = _set2_color_map(unique_materials_grouped)


# Source keys as stored in df["source"], with their display titles.
sources = ["arxiv", "chemrxiv", "omg24"]
source_titles = ["ArXiv", "ChemRxiv", "OMG24"]

# Heatmap axes use every material category / synthesis method found in the
# whole dataset so that all heatmaps share identical rows and columns. Both
# lists are independent of the source, so they are computed once up front.
all_materials = sorted(df["material_category"].unique())
all_synthesis = sorted(df["synthesis_method"].unique())

# Number of categories that keep their own bar colour. Derived from the list
# that drives the grouping so the plot labels cannot drift from the data.
n_top_materials = len(top_5_materials)

# For every source: one stacked bar plot (papers per year by grouped material
# category), one heatmap per year (material category x synthesis method), and
# a printed summary.
for source, title in zip(sources, source_titles):
    print(f"\n{'=' * 80}")
    print(f"VISUALIZING: {title.upper()}")
    print("=" * 80)

    # Filter data for this source
    source_data = df[df["source"] == source]

    if source_data.empty:
        print(f"No data for {source}")
        continue

    # Missing dates are dropped: they cannot be ordered against year strings,
    # and the groupby below ignores them as well.
    years = sorted(source_data["paper_published_date"].dropna().unique())

    # =======================================================================
    # 1. BAR PLOT: Count vs Year, colored by grouped Material Category
    # =======================================================================
    print(
        f"\n📊 BAR PLOT: {title} Papers Over Time by Material Category "
        f"(Top {n_top_materials})"
    )

    fig, ax = plt.subplots(figsize=(8, 6))

    # Rows = year, columns = grouped category, values = number of papers.
    bar_data = (
        source_data.groupby(
            ["paper_published_date", "material_category_grouped"]
        )
        .size()
        .unstack(fill_value=0)
    )

    # Colours are looked up per column so a category has the same colour in
    # every source's plot.
    bar_data.plot(
        kind="bar",
        stacked=True,
        ax=ax,
        color=[material_colors_grouped[col] for col in bar_data.columns],
        edgecolor="black",
        linewidth=0.8,
    )

    ax.set_title(
        f"{title}: Research Papers by Material Category Over Time\n"
        f"(Top {n_top_materials} Most Common Categories)",
        fontsize=14,
        fontweight="bold",
        pad=20,
    )
    ax.set_xlabel("Year", fontsize=12)
    ax.set_ylabel("Number of Papers", fontsize=12)
    ax.tick_params(axis="x", rotation=0)
    ax.grid(axis="y", alpha=0.3)
    ax.legend(
        title="Material Category", bbox_to_anchor=(1.05, 1), loc="upper left"
    )

    # Label each bar segment with its count; empty segments get no label.
    for container in ax.containers:
        ax.bar_label(
            container,
            label_type="center",
            fontweight="bold",
            labels=[str(int(v)) if v > 0 else "" for v in container.datavalues],
        )

    plt.tight_layout()
    plt.show()
    # Release the figure so figures do not accumulate across iterations.
    plt.close(fig)

    # =======================================================================
    # 2. HEATMAPS: Material Category vs Synthesis Method for Each Year
    # =======================================================================
    print(f"\n🔥 HEATMAPS: {title} Material vs Synthesis Method by Year")

    for year in years:
        year_data = source_data[source_data["paper_published_date"] == year]

        if year_data.empty:
            continue

        # Rows = material category, columns = synthesis method, values = count.
        heatmap_data = (
            year_data.groupby(["material_category", "synthesis_method"])
            .size()
            .unstack(fill_value=0)
        )

        # Pad with zero rows/columns so every heatmap has the same dimensions.
        heatmap_data = heatmap_data.reindex(
            index=all_materials, columns=all_synthesis, fill_value=0
        )

        fig, ax = plt.subplots(figsize=(8, 6))

        sns.heatmap(
            heatmap_data,
            # Cell annotations are disabled; `fmt` only takes effect if
            # `annot=True` is passed.
            fmt="d",
            cmap="Blues",
            ax=ax,
            cbar_kws={"label": "Number of Papers"},
            linewidths=0.5,
            square=False,
            vmin=0,
            # Floor of 3 so sparse years do not saturate the colour map. The
            # upper bound is this year's own maximum, so the colour scale can
            # differ between years.
            vmax=max(3, heatmap_data.values.max()),
        )

        ax.set_title(
            f"{title} {year}: Material Category vs Synthesis Method\n"
            f"({len(year_data)} papers total)",
            fontsize=14,
            fontweight="bold",
        )
        ax.set_xlabel("Synthesis Method", fontsize=12)
        ax.set_ylabel("Material Category", fontsize=12)

        # Rotate x-axis labels for better readability
        plt.setp(ax.get_xticklabels(), rotation=90, ha="right")
        plt.setp(ax.get_yticklabels(), rotation=0)

        plt.tight_layout()
        plt.show()
        plt.close(fig)

        # Print summary for this year
        total_papers = len(year_data)
        print(f"  📋 {year}: {total_papers} papers total")

        material_breakdown = year_data["material_category"].value_counts()
        synthesis_breakdown = year_data["synthesis_method"].value_counts()

        print(f"     Materials: {dict(material_breakdown)}")
        print(f"     Synthesis: {dict(synthesis_breakdown)}")

    # =======================================================================
    # 3. SUMMARY STATS
    # =======================================================================
    print(f"\n📈 SUMMARY: {title}")
    print("-" * 40)
    print(f"Total papers: {len(source_data)}")
    print(f"Years covered: {min(years)} - {max(years)}")
    print(
        f"Material categories (original): {len(source_data['material_category'].unique())}"
    )
    print(
        f"Material categories (bar plot): {len(source_data['material_category_grouped'].unique())}"
    )
    print(f"Synthesis methods: {len(source_data['synthesis_method'].unique())}")

    print("\n📊 Material breakdown (bar plot grouping):")
    grouped_breakdown = source_data["material_category_grouped"].value_counts()
    for category, count in grouped_breakdown.items():
        print(f"   {category}: {count} papers")

# Closing banner.
banner = "=" * 80
print(f"\n{banner}")
print("🎉 ALL VISUALIZATIONS COMPLETE!")
print(banner)

# Paper count per source across the whole DataFrame.
print("\n📊 OVERALL DATA SUMMARY:")
print("=" * 50)
for source in sources:
    source_count = int((df["source"] == source).sum())
    print(f"{source.upper()}: {source_count} papers")

In [None]:
from datasets import load_dataset

# Load the "arxiv" split of the "full" configuration of the papers dataset.
# `load_dataset` selects a dataset configuration (labelled "subset" in the Hub
# UI) through its `name` argument; it has no `subset` parameter, and unknown
# keyword arguments are forwarded to the builder configuration.
# NOTE(review): assumes the configuration is literally called "full" — confirm
# against the dataset card.
ds_paper = load_dataset(
    "LeMaterial/LeMat-Synth-Papers", name="full", split="arxiv"
)

df_paper = ds_paper.to_pandas()

# Preview the first rows.
df_paper.head()

In [None]:
# List the columns available in the evaluation DataFrame.
df.columns

In [None]:
# Local folder with one sub-directory per annotated paper. Each sub-directory
# is expected to hold `result.json` (LLM output) and, optionally,
# `result_human.json` (human annotation).
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
annotation_folder = (
    "/Users/magdalenalederbauer/Code/lematerial-llm-synthesis/annotations"
)


def _read_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


for subdir in os.listdir(annotation_folder):
    # The folder name is the paper id with "/" replaced by "." for old-style
    # identifiers (e.g. "cond-mat.0001001" -> "cond-mat/0001001").
    paper_id = subdir.replace("cond-mat.", "cond-mat/")

    synthesis_procedures_of_paper = df_paper[df_paper["id"] == paper_id]
    if synthesis_procedures_of_paper.empty:
        # os.listdir also returns entries that are not paper folders (stray
        # files, unknown ids); skip them instead of failing on `.values[0]`.
        continue

    url_of_paper = (
        synthesis_procedures_of_paper["pdf_url"]
        .values[0]
        .replace("https://", "")
    )
    # Literal substring match: URLs contain regex metacharacters such as ".",
    # so they must not be interpreted as a pattern. Missing URLs never match.
    matched_lemat_synth_entry = df[
        df["paper_url"].str.contains(url_of_paper, regex=False, na=False)
    ]
    if matched_lemat_synth_entry.empty:
        continue
    if (
        matched_lemat_synth_entry["synthesized_material"].values[0]
        == "No materials synthesized"
    ):
        continue

    result_llm = os.path.join(annotation_folder, subdir, "result.json")
    result_human = os.path.join(annotation_folder, subdir, "result_human.json")

    llm_ontology = _read_json(result_llm)
    try:
        human_ontology = _read_json(result_human)
    except FileNotFoundError:
        # No human annotation: use one empty dict per LLM entry so the two
        # lists can still be zipped.
        human_ontology = [{} for _ in llm_ontology]

    for idx, (item_llm, item_human) in enumerate(
        zip(llm_ontology, human_ontology)
    ):
        mat_name = item_llm["material"]
        synthesis = item_llm["synthesis"]
        evaluation_llm = item_llm["evaluation"]
        evaluation_human = item_human["evaluation"] if item_human else None
        # Write the annotation values into the idx-th matched row.
        # NOTE(review): `matched_lemat_synth_entry` is a boolean-mask selection
        # of `df`, so this assignment most likely does not propagate to `df`,
        # and assigning a dict through `.iloc` may not align on column names —
        # confirm the intended target and columns.
        try:
            matched_lemat_synth_entry.iloc[idx] = {
                "synthesized_material": mat_name,
                "synthesis": synthesis,
                "synthesis_extraction_performance_llm": evaluation_llm,
                "synthesis_extraction_performance_human": evaluation_human,
            }
        except Exception as exc:
            # Dump everything needed to diagnose the failure, then stop
            # processing this paper.
            print(f"Error filling row {idx} of {paper_id}")
            print(f"{type(exc).__name__}: {exc}")
            print(matched_lemat_synth_entry)
            print(mat_name)
            print(synthesis)
            print(evaluation_llm)
            print(evaluation_human)
            break

In [None]:
# Debug: show the keys of the first entry of the most recently loaded
# result.json (llm_ontology is left over from the annotation loop).
llm_ontology[0].keys()

In [None]:
# Debug: show the rows matched for the last paper handled by the annotation
# loop.
matched_lemat_synth_entry