# In [None]:
from scipy.optimize import minimize_scalar
import math
import numpy as np
import json
from ase.io import read
from pyace import create_multispecies_basis_config
from pyace.activelearning import compute_B_projections
from quests.entropy import perfect_entropy

# Reference ("ground-truth") entropy per dataset; the keys double as the
# dataset names used to locate the matching .xyz files further down.
H_star = {
    "Graphene": 4.245179458166078,
    "Diamond": 4.318381910272738,
    "Graphite": 5.6085074467370095,
    "Nanotubes": 7.0282707526691715,
    "Fullerenes": 8.67911004440742,
    "Defects": 9.531933892473084,
    "Surfaces": 9.823139796211981,
    "Liquid": 11.61485589283075,
    "Amorphous_Bulk": 12.183809856122803,
}

# Per-dataset weights of the squared entropy error; uniform (1.0) for now.
W = dict.fromkeys(H_star, 1.0)

# Value forwarded as `batch_size` to every perfect_entropy call.
batch_size = 10_000

def optimize_bandwidth_global(descriptor_dict, H_star, W, h0=None, span=5.0, tol=1e-3, maxit=60):
    """
    Fit a single bandwidth h by minimizing sum_d w_d * (H_d(h) - H*_d)^2 over u = log(h).

    descriptor_dict: {name: np.ndarray (N_atoms, N_feat)}
    H_star:          {name: float} ground-truth entropies; datasets without an
                     entry are ignored by the fit
    W:               {name: float} weights (missing keys count as 1.0; None = all 1.0)
    h0:              initial guess for h (>0). If None/invalid, use scale heuristic.
    span:            multiplicative span around h0 for the search interval,
                     i.e. [h0/span, h0*span]; must be > 1
    tol, maxit:      absolute tolerance on u and iteration cap for the optimizer
    returns: (h_opt, report_dict)
    raises:  ValueError if span <= 1 or if no dataset has a reference entropy
    """
    if not np.isfinite(span) or span <= 1.0:
        raise ValueError(f"span must be a finite number > 1, got {span!r}")

    # Resolve once which datasets take part in the fit; this is loop-invariant
    # and lets us fail loudly instead of "optimizing" a constant zero loss.
    weights = W if W is not None else {}
    targets = [
        (X, H_star[name], weights.get(name, 1.0))
        for name, X in descriptor_dict.items()
        if name in H_star
    ]
    if not targets:
        raise ValueError("No dataset in descriptor_dict has a reference entropy in H_star.")

    # If no usable h0, build a simple scale from descriptors (median L2 distance
    # between two random halves of a small subset of each dataset).
    if h0 is None or not np.isfinite(h0) or h0 <= 0:
        samples = []
        for X in descriptor_dict.values():
            if X is None or X.size == 0:
                continue
            n = min(len(X), 64)
            if n < 2:
                continue
            idx = np.random.choice(len(X), size=n, replace=False)
            Xi = X[idx]
            # Use two halves of equal length; for odd n the last row is dropped,
            # otherwise the subtraction has mismatched shapes.
            n2 = n // 2
            diffs = Xi[:n2] - Xi[n2:2 * n2]
            if diffs.size:
                samples.append(np.median(np.linalg.norm(diffs, axis=1)))
        h0 = np.median(samples) if samples else 1.0
        if not np.isfinite(h0) or h0 <= 0:
            h0 = 1.0

    # Work in log-space to keep h > 0; the interval is symmetric around u0.
    u0 = math.log(h0)
    u_lo = math.log(h0 / span)
    u_hi = math.log(h0 * span)

    # Each loss evaluation runs perfect_entropy on every dataset, so remember
    # values already computed (the bracket points are requested twice).
    evaluated = {}

    def L_of_u(u):
        u = float(u)
        if u not in evaluated:
            h = math.exp(u)
            loss = 0.0
            for X, target, weight in targets:
                diff = perfect_entropy(X, h=h, batch_size=batch_size) - target
                loss += weight * (diff * diff)
            evaluated[u] = loss
        return evaluated[u]

    # Brent only accepts a 3-point bracket when the middle point is strictly
    # lower than both ends; an arbitrary initial guess does not guarantee that.
    f_lo, f_mid, f_hi = L_of_u(u_lo), L_of_u(u0), L_of_u(u_hi)
    if f_mid < f_lo and f_mid < f_hi:
        method = "brent"
        res = minimize_scalar(
            L_of_u,
            bracket=(u_lo, u0, u_hi),
            method="brent",
            options={"xtol": tol, "maxiter": maxit},
        )
    else:
        # No valid bracket: search the same interval with the bounded method.
        method = "bounded"
        res = minimize_scalar(
            L_of_u,
            bounds=(u_lo, u_hi),
            method="bounded",
            options={"xatol": tol, "maxiter": maxit},
        )
    u_opt = res.x
    h_opt = float(math.exp(u_opt))

    # Build a small report
    report = {
        "success": bool(res.success),
        "message": str(res.message),
        # Iteration count when the result provides one, else the evaluation count.
        "nit": int(res.get("nit", res.nfev)),
        "nfev": int(res.nfev),
        "u_opt": float(u_opt),
        "h0": float(h0),
        "u0": float(u0),
        "bracket": [float(u_lo), float(u0), float(u_hi)],
        "final_loss": float(res.fun),
        "method": method,
    }
    return h_opt, report

# --------- Bandwidth section meant to be dropped into the sweep loop ---------
# NOTE: `params`, `make_ace_config`, `fcc_strain_heuristic` and `results_path`
# are not defined in this cell; they are expected to come from the surrounding
# sweep code.

# Step 1: read every dataset's frames a single time so they can be reused.
frames_cache = {}
for name in H_star:
    frames_cache[name] = read(f"/home/grethel/dev/quests/examples/gap20/{name}.xyz", index=":")

# ... from here on the code belongs inside `for params in sweep:` ...
# Carbon ("data") basis for the current tier.
data_basis_config = make_ace_config(params=params, elements=["C"])
data_basis = create_multispecies_basis_config(data_basis_config)

# Descriptors of each dataset under that basis (first item returned by
# compute_B_projections).
descriptor_dict = {
    name: compute_B_projections(data_basis, frames)[0]
    for name, frames in frames_cache.items()
}

# Initial bandwidth guess from the FCC strain heuristic on an Au basis
# (["Au", "Ag"] would be the multi-element alternative).
fcc_basis_config = make_ace_config(params=params, elements=["Au"])
fcc_basis = create_multispecies_basis_config(fcc_basis_config)
h0 = fcc_strain_heuristic(fcc_basis, supercell=3)

# Step 2: minimize the global loss; widen `span` (3-10) if the fit struggles.
h_opt, opt_report = optimize_bandwidth_global(
    descriptor_dict, H_star, W, h0=h0, span=5.0, tol=1e-3, maxit=60
)

# Step 3: entropies at the fitted bandwidth, plus the signed error H - H*
# for every dataset that has a reference value.
entropy_at_opt = {}
entropy_err = {}
entry = {
    "basis_config": data_basis_config,
    "elements_data": ["C"],
    "bandwidth_init": float(h0),
    "bandwidth_opt": float(h_opt),
    "optimizer": opt_report,
    "entropy": entropy_at_opt,
    "entropy_error": entropy_err,
}

for name, X in descriptor_dict.items():
    H = perfect_entropy(X, h=h_opt, batch_size=batch_size)
    entropy_at_opt[name] = float(H)
    if name in H_star:
        entropy_err[name] = float(H - H_star[name])

# One JSON document per line, appended to the results file.
with open(results_path, "a") as f:
    f.write(f"{json.dumps(entry)}\n")
    f.flush()

# In [None]:
## !/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
import math
import numpy as np

from ase.io import read
from ase.build import bulk, make_supercell

from scipy.optimize import minimize_scalar

from pyace import create_multispecies_basis_config
from pyace.activelearning import compute_B_projections

from quests.entropy import perfect_entropy  # and diversity if you need it

# -------------------------- User Config ---------------------------------

# GAP-20 dataset names to include in the global fit. main() reads
# "<GAP20_DIR>/<name>.xyz" for every entry of this list.
DATASETS = [
    "Graphene",
    "Diamond",
    "Graphite",
    "Nanotubes",
    "Fullerenes",
    "Defects",
    "Surfaces",
    "Liquid",
    "Amorphous_Bulk",
]
# Smaller alternative selection (kept for quick experiments):
# DATASETS = [
#     "Graphene",
#     "Graphite",
#     "Nanotubes"
# ]

# Ground-truth entropy H* for each dataset (REPLACE these with real values).
# A dataset without an entry here is skipped by the loss and gets no
# "entropy_error" in the output.
H_STAR = {
    "Graphene":        4.245179458166078,
    "Diamond":         4.318381910272738,
    "Graphite":        5.6085074467370095,
    "Nanotubes":       7.0282707526691715,
    "Fullerenes":      8.67911004440742,
    "Defects":         9.531933892473084,
    "Surfaces":        9.823139796211981,
    "Liquid":          11.61485589283075,
    "Amorphous_Bulk":  12.183809856122803,
}

# Optional per-dataset weights of the squared entropy error
# (the loss falls back to 1.0 if a key is missing)
W = {k: 1.0 for k in H_STAR}

# Directory holding the GAP-20 XYZ files
GAP20_DIR = "/home/grethel/dev/quests/examples/gap20"

# Output JSONL file (one JSON document per sweep tier).
# NOTE: the parent directory is created right here, i.e. as a side effect of
# importing/running this module, not only when main() is called.
RESULTS_PATH = "/home/grethel/dev/quests/sweep_results/sweep_global_brent.jsonl"
os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)

# Value passed as `batch_size` to perfect_entropy
BATCH_SIZE = 10_000

# FCC heuristic (initial guess) probe settings.
# The whole tuple is used as the element list of the probe basis; only the
# first element is passed as `species` when building the FCC crystal.
HEURISTIC_ELEMENTS = ("Au",)   # or ("Au","Ag") if you want multi-component probe
HEURISTIC_SUPERCELL = 3        # supercell multiplier for the FCC probe

# Fidelity sweep: each entry is one tier and is handed to make_ace_config as
# `params` ("nrad" -> nradmax_by_orders, "lmax" -> lmax_by_orders).
# Full ladder of tiers, currently disabled:
# SWEEP = [
#     {"nrad": [4],           "lmax": [4]},             # F0
#     {"nrad": [6, 3],        "lmax": [6, 3]},          # F1
#     {"nrad": [8, 4, 2],     "lmax": [8, 6, 2]},       # F2
#     {"nrad": [10, 6, 3],    "lmax": [10, 8, 4]},      # F3
#     {"nrad": [12, 8, 4],    "lmax": [12, 10, 6]},     # F4
# ]
# Active sweep: highest tier only.
SWEEP = [
    {"nrad": [12, 8, 4],    "lmax": [12, 10, 6]},     # F4
]

# ------------------------------------------------------------------------


def make_ace_config(params: dict, elements: list, rcut: float = 5.5, dcut: float = 0.2, ndensity: int = 1):
    """Assemble the python-ACE basis configuration dict for one (nrad, lmax) tier.

    Args:
        params: mapping with the keys "nrad" and "lmax"; the values are stored
            as "nradmax_by_orders" and "lmax_by_orders" respectively.
        elements: element list, stored unchanged under "elements".
        rcut: used both as "rcut" and as the single entry of "radparameters".
        dcut: stored as "dcut" of the bond specification.
        ndensity: converted with int() before being stored.

    Returns:
        A new dict with the sections "elements", "embeddings", "bonds",
        "functions" and "deltaSplineBins"; every section applies to "ALL".

    Raises:
        KeyError: if params lacks "nrad" or "lmax".
    """
    embedding_spec = {
        "npot": "FinnisSinclairShiftedScaled",
        "fs_parameters": [3.0, 0.8],
        "ndensity": int(ndensity),  # coerced so the stored value is always an int
    }
    bond_spec = dict(radbase="SBessel", radparameters=[rcut], rcut=rcut, dcut=dcut)
    function_spec = {
        "nradmax_by_orders": params["nrad"],
        "lmax_by_orders": params["lmax"],
    }

    config = {"elements": elements}
    config["embeddings"] = {"ALL": embedding_spec}
    config["bonds"] = {"ALL": bond_spec}
    config["functions"] = {"ALL": function_spec}
    config["deltaSplineBins"] = 5e-5
    return config


def fcc_strain_heuristic(basis, species="Au", a=3.58, supercell=1, strain=0.99):
    """Estimate a descriptor length scale from a uniformly strained FCC crystal.

    A cubic FCC cell of `species` with lattice constant `a` is built
    (replicated `supercell` times along each axis when supercell > 1), then a
    copy is made whose cell is multiplied by `strain` with the atoms scaled
    along. Both structures are projected onto `basis`.

    Returns:
        The L2 norm, as a Python float, of the difference between the
        descriptor rows of atom 0 before and after the strain.

    Raises:
        ValueError: if either structure yields no descriptor rows.
    """
    reference = bulk(species, "fcc", a=a, cubic=True)
    if supercell and supercell > 1:
        # Diagonal integer transformation = same replication along all axes.
        replication = int(supercell) * np.eye(3, dtype=int)
        reference = make_supercell(reference, replication)

    strained = reference.copy()
    strained.set_cell(strain * strained.cell, scale_atoms=True)

    desc_reference = compute_B_projections(basis, [reference])[0]  # (N_atoms, N_feat)
    desc_strained = compute_B_projections(basis, [strained])[0]
    if 0 in (desc_reference.shape[0], desc_strained.shape[0]):
        raise ValueError("Empty descriptors in FCC heuristic; check basis elements and rcut/dcut.")

    # Only atom 0 is compared between the two structures.
    shift = desc_reference[0] - desc_strained[0]
    return float(np.linalg.norm(shift))


def _default_scale_from_descriptors(descriptor_dict):
    """Fallback to build a positive initial h if heuristic fails or not provided."""
    samples = []
    for X in descriptor_dict.values():
        if X is None or X.size == 0:
            continue
        n = min(len(X), 64)
        if n < 2:
            continue
        idx = np.random.choice(len(X), size=n, replace=False)
        Xi = X[idx]
        n2 = n // 2
        diffs = Xi[:n2] - Xi[n2:2*n2]
        if diffs.size > 0:
            samples.append(np.median(np.linalg.norm(diffs, axis=1)))
    h0 = np.median(samples) if samples else 1.0
    if not np.isfinite(h0) or h0 <= 0:
        h0 = 1.0
    return float(h0)


def optimize_bandwidth_global(descriptor_dict, H_star, W, h0=None,
                              span=5.0, tol=1e-3, maxit=200):
    """
    Global bandwidth fit: minimize L(h) = sum_d w_d [H_d(h) - H*_d]^2.

    The search runs over u = log(h) (so h stays positive) with scipy's
    'bounded' scalar minimizer on [log(h0/span), log(h0*span)]; no bracketing
    minimum is required.

    Args:
        descriptor_dict: {name: array (N_atoms, N_feat)} descriptors per dataset.
        H_star: {name: float} reference entropies; datasets without an entry
            do not contribute to the loss.
        W: {name: float} weights; missing names count as 1.0, None means all 1.0.
        h0: initial bandwidth (> 0). If None, non-finite or non-positive, a
            scale derived from the descriptors is used instead.
        span: multiplicative half-width of the search interval; must be > 1.
        tol: absolute tolerance on u ("xatol").
        maxit: iteration cap handed to the optimizer.

    Returns:
        (h_opt, report) where report holds the optimizer status, evaluation
        count, optimum in log-space, the h0 actually used, the bounds, the
        final loss and the method name.

    Raises:
        ValueError: if span is not a finite number > 1, or if no dataset in
            descriptor_dict has a reference entropy in H_star.
    """
    # span <= 1 would give an empty or inverted interval.
    if not np.isfinite(span) or span <= 1.0:
        raise ValueError(f"span must be a finite number > 1, got {span!r}")

    # Decide once which datasets take part in the fit. Without any, the loss is
    # identically zero and the "optimum" would be an arbitrary point.
    weights = W if W is not None else {}
    targets = [
        (X, H_star[name], weights.get(name, 1.0))
        for name, X in descriptor_dict.items()
        if name in H_star
    ]
    if not targets:
        raise ValueError("No dataset in descriptor_dict has a reference entropy in H_star.")

    # fallback scale if no good h0
    def _default_scale():
        samples = []
        for X in descriptor_dict.values():
            if X is None or X.size == 0:
                continue
            n = min(len(X), 64)
            if n < 2:
                continue
            idx = np.random.choice(len(X), size=n, replace=False)
            Xi = X[idx]
            # two halves of equal length (last row dropped for odd n)
            n2 = n // 2
            diffs = Xi[:n2] - Xi[n2:2*n2]
            if diffs.size:
                samples.append(np.median(np.linalg.norm(diffs, axis=1)))
        val = np.median(samples) if samples else 1.0
        return float(val if np.isfinite(val) and val > 0 else 1.0)

    if h0 is None or not np.isfinite(h0) or h0 <= 0:
        h0 = _default_scale()

    # bounds in log-space, symmetric around u0
    w = math.log(span)
    u0 = math.log(h0)
    u_lo, u_hi = u0 - w, u0 + w

    def L_of_u(u):
        h = math.exp(u)
        loss = 0.0
        for X, target, weight in targets:
            d = perfect_entropy(X, h=h, batch_size=BATCH_SIZE) - target
            loss += weight * (d * d)
        return loss

    res = minimize_scalar(
        L_of_u,
        bounds=(u_lo, u_hi),
        method="bounded",
        options={"xatol": tol, "maxiter": maxit},
    )
    u_opt = float(res.x)
    h_opt = float(math.exp(u_opt))
    report = {
        "success": bool(res.success),
        "message": str(res.message),
        "nfev": int(res.nfev),
        "u_opt": u_opt,
        "h0": float(h0),
        "bounds": [float(u_lo), float(u_hi)],
        "final_loss": float(res.fun),
        "method": "bounded",
    }
    return h_opt, report


def _stopped_at_bound(report, tol):
    """Return True when the optimum in `report` sits at an edge of its search bounds.

    The bounded minimizer is run with an absolute tolerance `tol` on u, so a
    minimum lying on or beyond a bound is reported within roughly that
    tolerance of the bound rather than exactly on it. The edge test therefore
    scales with `tol` instead of using a tiny fixed epsilon.
    """
    u_lo, u_hi = report["bounds"]
    u_opt = report["u_opt"]
    edge_tol = 2.0 * tol
    return (u_opt - u_lo) <= edge_tol or (u_hi - u_opt) <= edge_tol


def main():
    """Run the fidelity sweep and write one JSON line per tier to RESULTS_PATH.

    For every tier in SWEEP: build the carbon basis, compute descriptors for
    all DATASETS, get an initial bandwidth from the FCC strain heuristic
    (falling back to a descriptor-based scale if it fails), fit the global
    bandwidth, retry with a much wider search interval when the first fit
    fails or ends at an interval edge, then record entropies and errors.
    """
    fit_tol = 1e-3
    fit_maxit = 60

    # Preload frames for all datasets (so we can reuse across tiers)
    frames_cache = {
        name: read(os.path.join(GAP20_DIR, f"{name}.xyz"), index=":")
        for name in DATASETS
    }

    # Mode "w" starts a fresh results file on every run; within the run one
    # JSON line is written per tier.
    with open(RESULTS_PATH, "w") as fout:
        for params in SWEEP:
            print(f"\n>>> Tier params: {params}")

            # Build data basis (carbon) ONCE per tier and cache descriptors per dataset
            data_basis_config = make_ace_config(params=params, elements=["C"])
            data_basis = create_multispecies_basis_config(data_basis_config)

            descriptor_dict = {}
            for name, frames in frames_cache.items():
                X = compute_B_projections(data_basis, frames)[0]
                descriptor_dict[name] = X
                print(f"  {name}: descriptors shape = {X.shape}")

            # Initial guess from FCC heuristic using HEURISTIC_ELEMENTS
            heuristic_basis_config = make_ace_config(params=params, elements=list(HEURISTIC_ELEMENTS))
            heuristic_basis = create_multispecies_basis_config(heuristic_basis_config)
            try:
                h0 = fcc_strain_heuristic(
                    heuristic_basis,
                    species=HEURISTIC_ELEMENTS[0],
                    a=3.58,
                    supercell=HEURISTIC_SUPERCELL,
                    strain=0.99,
                )
            except Exception as e:
                # Best effort: the heuristic is only a starting point, so any
                # failure falls back to the descriptor-based scale (h0=None).
                print(f"  Heuristic failed ({e}); falling back to descriptor scale.")
                h0 = None

            print(f"  Initial bandwidth guess h0 = {h0 if h0 is not None else '(auto)'}")

            # Optimize global bandwidth
            h_opt, opt_report = optimize_bandwidth_global(
                descriptor_dict=descriptor_dict,
                H_star=H_STAR,
                W=W,
                h0=h0,
                span=5.0,
                tol=fit_tol,
                maxit=fit_maxit,
            )
            # An optimum at the interval edge means the true minimum probably
            # lies outside [h0/span, h0*span]; search again on a wider interval.
            if not opt_report["success"] or _stopped_at_bound(opt_report, fit_tol):
                print("Span failed, trying different span...")
                h_opt, opt_report = optimize_bandwidth_global(
                    descriptor_dict=descriptor_dict,
                    H_star=H_STAR,
                    W=W,
                    h0=h0,
                    span=100.0,
                    tol=fit_tol,
                    maxit=fit_maxit,
                )

            print(f"  Optimized bandwidth h* = {h_opt:.6f} | loss = {opt_report['final_loss']:.6e}")

            # Evaluate entropies at the optimized bandwidth and compute errors
            entropy_at_opt = {}
            entropy_err = {}
            for name, X in descriptor_dict.items():
                H = perfect_entropy(X, h=h_opt, batch_size=BATCH_SIZE)
                entropy_at_opt[name] = float(H)
                if name in H_STAR:
                    entropy_err[name] = float(H - H_STAR[name])

            # Write JSONL entry (per tier)
            entry = {
                "basis_config": data_basis_config,    # the basis used for the datasets
                "elements_data": ["C"],
                "tier_params": params,
                "bandwidth_init": float(opt_report["h0"]),  # h0 actually used by the fit
                "bandwidth_opt": float(h_opt),
                "optimizer": opt_report,
                "entropy": entropy_at_opt,
                "entropy_error": entropy_err,
            }
            fout.write(json.dumps(entry) + "\n")
            # Flush per tier so finished tiers survive a crash in a later one.
            fout.flush()

    print(f"\n✅ Done. Results written to: {RESULTS_PATH}")


if __name__ == "__main__":
    main()
