# Finetune MAE Trajectories for Polymer Tasks

This notebook summarizes the finetuning behaviour of the dynamic pretrain suite.
For each PolyInfo (PI) property, we track the mean absolute error (MAE) reported after
finetuning models that were pretrained with progressively longer pretrain task sequences.
The goal is to highlight how performance evolves as the number of pretrain tasks grows and,
in particular, as the target task is introduced into the pretrain stages.

In [1]:
import json
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Directory that holds one sub-directory per run; every later cell reads from here.
ARTIFACT_ROOT = Path("../artifacts/radonpy_polyinfo_pretrain_finetune_runs_251206")
# Glob pattern used to pick the run directories inside ARTIFACT_ROOT.
RUN_PATTERN = "run*"

# Alternative (currently disabled) task list, kept for reference:
# FINETUNE_TASKS = [
#     "density",
#     "Rg",
#     "r2",
#     "self-diffusion",
#     "Cp",
#     "Cv",
#     "linear_expansion",
#     "refractive_index",
#     "tg",
# ]

# Finetune properties to report. The order of this list is the order in which
# properties are considered by the analysis cells further down.
FINETUNE_TASKS = [
    "Tg",
    "Thermal decomposition temp-Temperature",
    "Tm",
    "Density",
    "Tensile stress at break",
    "Elongation at break",
    "Gas permeability coefficient O2",
    "Water absorption",
    "Dielectric const AC freq-all",
    "Tensile modulus",
    "Solubility parameter",
    "Gas permeability coefficient CO2",
    "Heat of fusion",
    "Gas permeability coefficient N2",
    "Refractive index",
    "Linear expansion coefficient",
    "Dielectric const AC freq-low",
    "Volume expansion coefficient",
    "Tensile storage modulus freq-all",
    "Crystallization temp",
    "Tensile stress at yield",
    "Contact angle",
    "Surface tension",
    "Tensile storage modulus loss tangent freq-all",
    "Cp",
    "Tensile storage modulus freq-middle",
    "Tensile loss modulus freq-all",
    "Dielectric const AC freq-high",
    "Dielectric const AC freq-middle",
    "Vicat softening temp",
    "Brittleness temp",
    "Water vapor transmission",
    "Hansen parameter delta-d",
    "Hansen parameter delta-p",
    "Heat of crystallization",
    "Hansen parameter delta-h",
    "Shear storage modulus freq-all",
    "Elongation at yield",
    "Tensile storage modulus freq-high",
    "Tensile storage modulus loss tangent freq-high",
    "Shear storage modulus loss tangent freq-all",
    "Tensile storage modulus loss tangent freq-middle",
    "HDT",
    "Shear storage modulus freq-low",
    "Tensile storage modulus freq-low",
    "Tensile loss modulus freq-high",
    "Thermal conductivity",
    "Shear storage modulus loss tangent freq-low",
    "Stress-optical coefficient",
    "Dielectric const DC",
    "Flexural modulus",
    "Flexural storage modulus freq-all",
    "Tensile loss modulus freq-middle",
    "Shore hardness",
    "Flexural storage modulus freq-low",
    "Growth rate of crystal",
    "Tensile storage modulus loss tangent freq-low",
    "Flexural stress at break",
    "Tensile loss modulus freq-low",
    "Charpy impact",
    "Compressive modulus",
]

# Fail fast with a clear message instead of silently producing an empty analysis.
if not ARTIFACT_ROOT.exists():
    raise FileNotFoundError(f"Expected artifact directory at {ARTIFACT_ROOT}")

In [3]:
def _read_json(path: Path):
    """Parse the UTF-8 JSON file at ``path`` and return the decoded object."""
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)


def _parse_stage_index(stage_dir_name: str):
    """Extract ``N`` from a ``pretrain_stage<N>[_suffix]`` directory name, or None if it is not an integer."""
    try:
        return int(stage_dir_name.split("pretrain_stage")[1].split("_")[0])
    except (IndexError, ValueError):
        return None


def load_finetune_records(artifact_root: Path, run_pattern: "str | None" = None) -> pd.DataFrame:
    """Collect the finetune MAE of every run / pretrain stage / finetune directory.

    Files read, relative to ``artifact_root``::

        <run>/pretrain_stage<N>[_...]/prediction/metrics.json          (``task_sequence``)
        <run>/pretrain_stage<N>[_...]/finetune/<dir>/prediction/metrics.json   (``metrics``)

    Stages or finetune directories that lack a metrics file, have an unparsable
    stage number, or carry no usable ``mae`` value are skipped.

    Parameters
    ----------
    artifact_root:
        Directory containing the run directories.
    run_pattern:
        Glob pattern selecting run directories. ``None`` (the default) falls
        back to the module-level ``RUN_PATTERN``, which is the original behaviour.

    Returns
    -------
    pd.DataFrame
        One row per finetune metrics file with columns ``run``, ``stage``,
        ``property``, ``mae``, ``task_sequence`` and ``property_in_pretrain``,
        sorted by ``property``, ``run``, ``stage`` with a fresh integer index.
        A ``mae`` stored as NaN in the JSON is kept as NaN (only a missing or
        non-numeric value is skipped).

    Raises
    ------
    RuntimeError
        If no finetune record could be collected at all.
    """
    pattern = RUN_PATTERN if run_pattern is None else run_pattern
    records = []
    for run_dir in sorted(artifact_root.glob(pattern)):
        if not run_dir.is_dir():
            continue
        for stage_dir in sorted(run_dir.glob("pretrain_stage*")):
            pretrain_metrics_path = stage_dir / "prediction" / "metrics.json"
            if not pretrain_metrics_path.exists():
                continue
            pretrain_payload = _read_json(pretrain_metrics_path)
            stage_sequence = []
            if isinstance(pretrain_payload, dict):
                stage_sequence = pretrain_payload.get("task_sequence") or []
            stage_index = _parse_stage_index(stage_dir.name)
            if stage_index is None:
                continue
            finetune_root = stage_dir / "finetune"
            if not finetune_root.exists():
                continue
            for finetune_stage in sorted(finetune_root.iterdir()):
                metrics_path = finetune_stage / "prediction" / "metrics.json"
                if not metrics_path.exists():
                    continue
                payload = _read_json(metrics_path)
                # ``or {}`` also covers an explicit JSON null under "metrics".
                metrics_block = payload.get("metrics") or {} if isinstance(payload, dict) else {}
                if not isinstance(metrics_block, dict) or not metrics_block:
                    continue
                # Each finetune metrics.json contains a single task entry
                task_name, metric_values = next(iter(metrics_block.items()))
                if not isinstance(metric_values, dict):
                    # Malformed entry: skip it rather than abort the whole scan.
                    continue
                mae = metric_values.get("mae")
                if mae is None:
                    continue
                try:
                    mae_value = float(mae)
                except (TypeError, ValueError):
                    continue
                records.append(
                    {
                        "run": run_dir.name,
                        "stage": stage_index,
                        "property": task_name,
                        "mae": mae_value,
                        "task_sequence": stage_sequence,
                        "property_in_pretrain": task_name in stage_sequence,
                    }
                )
    frame = pd.DataFrame.from_records(records)
    if frame.empty:
        raise RuntimeError("No finetune metrics found in the artifact tree.")
    # Return a new sorted frame instead of mutating in place.
    return frame.sort_values(["property", "run", "stage"]).reset_index(drop=True)


# Build the tidy finetune table: one row per finetune metrics file found under ARTIFACT_ROOT.
finetune_df = load_finetune_records(ARTIFACT_ROOT)
# Preview the first rows only.
finetune_df.head()

Unnamed: 0,run,stage,property,mae,task_sequence,property_in_pretrain
0,run01,1,Brittleness temp,32.253574,[TC_dihed],False
1,run01,2,Brittleness temp,30.5644,"[TC_dihed, sp_ele_long]",False
2,run01,3,Brittleness temp,31.926353,"[TC_dihed, sp_ele_long, dielectric_const_dc]",False
3,run01,4,Brittleness temp,34.134763,"[TC_dihed, sp_ele_long, dielectric_const_dc, T...",False
4,run01,5,Brittleness temp,34.771431,"[TC_dihed, sp_ele_long, dielectric_const_dc, T...",False


In [4]:
# Display the frame (the rendered table elides the middle rows).
# NOTE(review): the saved output below shows missing MAE values in the last rows,
# so downstream statistics have to tolerate NaN.
finetune_df

Unnamed: 0,run,stage,property,mae,task_sequence,property_in_pretrain
0,run01,1,Brittleness temp,32.253574,[TC_dihed],False
1,run01,2,Brittleness temp,30.564400,"[TC_dihed, sp_ele_long]",False
2,run01,3,Brittleness temp,31.926353,"[TC_dihed, sp_ele_long, dielectric_const_dc]",False
3,run01,4,Brittleness temp,34.134763,"[TC_dihed, sp_ele_long, dielectric_const_dc, T...",False
4,run01,5,Brittleness temp,34.771431,"[TC_dihed, sp_ele_long, dielectric_const_dc, T...",False
...,...,...,...,...,...,...
33890,run13,36,Water vapor transmission,,"[sp_ele_long, qm_dipole_monomer1, self-diffusi...",False
33891,run13,37,Water vapor transmission,,"[sp_ele_long, qm_dipole_monomer1, self-diffusi...",False
33892,run13,38,Water vapor transmission,,"[sp_ele_long, qm_dipole_monomer1, self-diffusi...",False
33893,run13,39,Water vapor transmission,,"[sp_ele_long, qm_dipole_monomer1, self-diffusi...",False


In [5]:
sns.set_theme(style="whitegrid", context="talk")
# Compute the observed property names once instead of calling ``.unique()`` for
# every entry of FINETUNE_TASKS; the list keeps the FINETUNE_TASKS order.
observed_properties = set(finetune_df["property"].unique())
properties = [prop for prop in FINETUNE_TASKS if prop in observed_properties]
runs = sorted(finetune_df["run"].unique())
# One colour per run. NOTE: "tab10" has 10 colours; with more runs seaborn cycles
# the palette, so runs beyond the tenth share a colour with an earlier run.
palette = dict(zip(runs, sns.color_palette("tab10", n_colors=len(runs))))

In [6]:
max_stage = finetune_df["stage"].max()
TREND_DEGREE = 1  # 1 for linear, 2 for quadratic
SLOPE_EPS = 1e-8  # guard for zero division when normalizing slopes

# Create output directory
output_dir = Path("plots/finetune_mae_trajectories")
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Saving plots to {output_dir.resolve()}")


def _stage_mae_stats(frame: pd.DataFrame, property_name: str) -> pd.DataFrame:
    """Return the per-stage mean and std of MAE across runs for one property.

    The result has columns ``stage``, ``mean`` and ``std``, ordered by stage.
    Stages whose mean is NaN (no run reported a finite MAE) are dropped so that
    the polynomial fits and the first/last-stage comparison never see NaN.
    A NaN std (single contributing run) is replaced by 0.
    """
    subset = frame[frame["property"] == property_name]
    stats = subset.groupby("stage")["mae"].agg(["mean", "std"]).reset_index()
    stats = stats.dropna(subset=["mean"]).reset_index(drop=True)
    stats["std"] = stats["std"].fillna(0)
    return stats


# First pass: compute normalized linear trend slope and percentage change for sort order.
# Each entry stays a (property, norm_slope, slope, pct_change) tuple because later
# cells unpack it in that shape; the per-stage stats are cached for the plotting pass.
property_trend_stats = []
stage_stats_by_property = {}
for property_name in properties:
    stats_df = _stage_mae_stats(finetune_df, property_name)
    if len(stats_df) < 2:
        # A slope needs at least two stages with a valid mean.
        continue
    # Linear slope for sorting (degree-1 fit regardless of plot degree)
    slope = np.polyfit(stats_df["stage"], stats_df["mean"], deg=1)[0]
    initial_mae = stats_df["mean"].iloc[0]
    final_mae = stats_df["mean"].iloc[-1]
    # Same zero guard for both normalized quantities.
    norm_denominator = max(abs(initial_mae), SLOPE_EPS)
    pct_change = (final_mae - initial_mae) / norm_denominator * 100
    norm_slope = slope / norm_denominator
    property_trend_stats.append((property_name, norm_slope, slope, pct_change))
    stage_stats_by_property[property_name] = stats_df

# Sort by normalized slope ascending (more negative first)
property_trend_stats.sort(key=lambda entry: entry[1])

# Second pass: generate plots with sorted filenames
for rank, (property_name, norm_slope, slope, pct_change) in enumerate(property_trend_stats, start=1):
    stats_df = stage_stats_by_property[property_name]

    fig, ax1 = plt.subplots(figsize=(7.5, 6), dpi=150)

    # Plot Mean + Std
    line_handle, = ax1.plot(
        stats_df["stage"],
        stats_df["mean"],
        marker="o",
        markersize=6,
        color=palette[runs[0]] if runs else "tab:blue",
        linewidth=2.5,
        label="Mean MAE"
    )
    ax1.fill_between(
        stats_df["stage"],
        stats_df["mean"] - stats_df["std"],
        stats_df["mean"] + stats_df["std"],
        color=line_handle.get_color(),
        alpha=0.2,
        label="Std Dev"
    )

    # Plot Trend Line (1st or 2nd degree); a degree-d fit needs more than d points.
    if len(stats_df) > TREND_DEGREE:
        coeffs = np.polyfit(stats_df["stage"], stats_df["mean"], deg=TREND_DEGREE)
        trend_poly = np.poly1d(coeffs)
        stage_grid = np.linspace(stats_df["stage"].min(), stats_df["stage"].max(), 200)
        ax1.plot(
            stage_grid,
            trend_poly(stage_grid),
            color="black",
            linewidth=2.0,
            linestyle="--",
            label=f"{'Linear' if TREND_DEGREE==1 else 'Quadratic'} Trend"
        )

    # Keep the property name's own capitalisation: str.title() would turn
    # "HDT" into "Hdt" and "CO2" into "Co2".
    pretty_name = property_name.replace("_", " ")
    title_main = pretty_name
    title_stats = f"(MAE $\\rightarrow$ {pct_change:+.1f}% | norm slope {norm_slope:+.3f})"
    ax1.set_title(f"{title_main}\n{title_stats}", fontsize=16, pad=15)

    ax1.legend(loc='upper right', fontsize=10)

    # Shared x-range across all figures so the panels are comparable.
    ax1.set_xlim(0.5, max_stage + 0.5)
    ax1.set_xlabel("# of Pretrain Properties", fontsize=14)
    ax1.set_ylabel("MAE", fontsize=14)
    ax1.tick_params(axis="both", labelsize=12)
    ax1.grid(True, which="major", linewidth=0.6, alpha=0.4)

    for spine in ax1.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1.2)
        spine.set_color("black")

    fig.tight_layout()

    # Save figure with rank and normalized slope in filename
    safe_property_name = property_name.replace("/", "_").replace(" ", "_")
    filename = f"{rank:02d}_nslope{norm_slope:+07.4f}_pct{pct_change:+06.2f}_{safe_property_name}.png"
    save_path = output_dir / filename
    fig.savefig(save_path, bbox_inches="tight")
    # Close explicitly: dozens of figures are created in this loop.
    plt.close(fig)

print(f"Saved {len(property_trend_stats)} plots, sorted by normalized slope asc")

Saving plots to /Users/liuchang/projects/foundation_model/notebooks/plots/finetune_mae_trajectories
Saved 61 plots, sorted by normalized slope asc


In [8]:
# Summary: list the properties whose normalized slope falls outside the
# [NEG_THRESHOLD, POS_THRESHOLD] band, in the order of property_trend_stats.
NEG_THRESHOLD = -0.001
POS_THRESHOLD = 0.001

neg_props = []
pos_props = []
for trend_entry in property_trend_stats:
    prop_label, slope_value = trend_entry[0], trend_entry[1]
    if slope_value < NEG_THRESHOLD:
        neg_props.append(prop_label)
    if slope_value > POS_THRESHOLD:
        pos_props.append(prop_label)

# One header plus a numbered list per group; the second header starts with a
# blank line to separate the two groups.
for header, members in (
    (f"Properties with norm slope < {NEG_THRESHOLD} (Count: {len(neg_props)}):", neg_props),
    (f"\nProperties with norm slope > {POS_THRESHOLD} (Count: {len(pos_props)}):", pos_props),
):
    print(header)
    for position, member in enumerate(members, start=1):
        print(f"{position}. {member}")

Properties with norm slope < -0.001 (Count: 31):
1. Growth rate of crystal
2. Crystallization temp
3. Tensile loss modulus freq-middle
4. Elongation at yield
5. Hansen parameter delta-p
6. Shear storage modulus loss tangent freq-all
7. Tensile loss modulus freq-high
8. Heat of fusion
9. Hansen parameter delta-h
10. Tg
11. Contact angle
12. Tensile storage modulus loss tangent freq-middle
13. Charpy impact
14. Shear storage modulus loss tangent freq-low
15. Thermal decomposition temp-Temperature
16. Heat of crystallization
17. Dielectric const AC freq-middle
18. Compressive modulus
19. Density
20. Tm
21. HDT
22. Flexural stress at break
23. Dielectric const AC freq-high
24. Tensile storage modulus freq-high
25. Shear storage modulus freq-all
26. Volume expansion coefficient
27. Brittleness temp
28. Water absorption
29. Stress-optical coefficient
30. Shear storage modulus freq-low
31. Tensile storage modulus loss tangent freq-high

Properties with norm slope > 0.001 (Count: 6):
1. Tensil

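Each saved figure isolates a single property. The solid line traces the mean MAE across runs at each pretrain stage, the shaded band marks ±1 standard deviation across runs, and a dashed trendline (linear by default) summarises the overall trajectory. Figures are ranked by the normalized slope of that trend, most negative (largest relative improvement) first.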