In [2]:
import warnings

import numpy as np
from ase.io import read
from pyace import create_multispecies_basis_config
from pyace.activelearning import compute_B_projections
from quests.entropy import perfect_entropy

# ------------------------------------------
# 1) make_ace_config: build the pyace basis config (element list passed explicitly, e.g. elements=["C"])
# ------------------------------------------
def make_ace_config(params: dict, elements, rcut: float = 5.5, dcut: float = 0.2):
    """Assemble the basis-configuration dict passed to ``create_multispecies_basis_config``.

    Every species pair shares one ("ALL") embedding, bond, and function spec.

    Parameters
    ----------
    params : dict
        Must provide ``"nrad"`` and ``"lmax"``. They are stored unchanged (not
        copied) as ``nradmax_by_orders`` and ``lmax_by_orders``.
    elements
        Element symbols for the basis, e.g. ``["C"]``. Stored as given.
    rcut : float
        Radial cutoff. Also used as the single SBessel radial parameter.
    dcut : float
        Cutoff smoothing width stored under ``bonds["ALL"]["dcut"]``.

    Returns
    -------
    dict
        Keys ``elements``, ``embeddings``, ``bonds``, ``functions`` and
        ``deltaSplineBins``, in that order.
    """
    # Finnis-Sinclair shifted/scaled embedding with a single density channel.
    embedding_spec = dict(
        npot="FinnisSinclairShiftedScaled",
        fs_parameters=[3.0, 0.8],
        ndensity=1,
    )
    # Spherical-Bessel radial basis, parameterized by the cutoff itself.
    bond_spec = dict(
        radbase="SBessel",
        radparameters=[rcut],
        rcut=rcut,
        dcut=dcut,
    )
    function_spec = dict(
        nradmax_by_orders=params["nrad"],
        lmax_by_orders=params["lmax"],
    )

    config = {"elements": elements}
    config["embeddings"] = {"ALL": embedding_spec}
    config["bonds"] = {"ALL": bond_spec}
    config["functions"] = {"ALL": function_spec}
    config["deltaSplineBins"] = 5e-5
    return config

def pilot_bandwidth(X, rng=None, max_pts=2000):
    """
    Pilot h0 from the median pairwise *raw* Euclidean distance.

    Works even when feature scales differ (distance will be dominated by
    large-scale dims). At most ``max_pts`` rows are subsampled, and at most
    20000 random pairs are drawn, so no full O(n^2) distance matrix is built.
    The median distance ``med`` is mapped to ``h0 = med / sqrt(2 * d)``.

    If that value is not finite and positive (fewer than two points, no
    distinct pair sampled, or all sampled points coincide), the fallback
    ``n ** (-1 / (d + 4))`` is used instead.

    Parameters
    ----------
    X : array_like
        Samples as rows, shape ``(n, d)``. A 1-D array is treated as ``n``
        samples of a single feature.
    rng : numpy.random.Generator, optional
        Randomness source for subsampling and pair selection. Defaults to a
        fresh ``np.random.default_rng(0)`` on every call, so repeated calls
        on the same data give the same h0.
    max_pts : int
        Maximum number of rows used for the pairwise-distance estimate.

    Returns
    -------
    float
        The pilot bandwidth (also printed).

    Raises
    ------
    ValueError
        If ``X`` contains no samples.
    """
    if rng is None:
        # Create the generator per call. A Generator default argument is built
        # once at definition time, and its state would be shared across calls.
        rng = np.random.default_rng(0)

    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X[:, None]
    n = len(X)
    if n == 0:
        raise ValueError("pilot_bandwidth needs at least one sample")

    idx = np.arange(n)
    if n > max_pts:
        idx = rng.choice(n, size=max_pts, replace=False)
    Y = X[idx]

    m = len(Y)
    # Sample random pairs (no full O(n^2) matrix). Self-pairs are discarded
    # afterwards. The draw sequence matches earlier versions, so the same seed
    # still gives the same h0.
    k = min(20000, m * (m - 1) // 2)
    if k > 0:
        i1 = rng.integers(0, m, size=k)
        i2 = rng.integers(0, m, size=k)
        mask = i1 != i2
        dists = np.linalg.norm(Y[i1[mask]] - Y[i2[mask]], axis=1)
    else:
        dists = np.empty(0)
    # Avoid np.median's "empty slice" warning. NaN routes to the fallback below.
    med = np.median(dists) if dists.size else np.nan

    d = X.shape[1]
    # Same mapping as before, but now in RAW space
    h0 = med / np.sqrt(2.0 * max(d, 1))
    if not np.isfinite(h0) or h0 <= 0:
        h0 = n ** (-1.0 / (max(d, 1) + 4))
    print(f"pilot bandwidth {h0}")
    return float(h0)

def evaluate_entropy_loss(X, H_star, h, batch_size=10000):
    """Squared mismatch between the entropy of ``X`` at bandwidth ``h`` and a target.

    The entropy is obtained from ``quests.entropy.perfect_entropy`` and printed
    together with the bandwidth.

    Returns
    -------
    tuple
        ``(loss, entropy)`` with ``loss = (entropy - H_star) ** 2``.
    """
    entropy_value = perfect_entropy(X, h=h, batch_size=batch_size)
    print(f"entropy {entropy_value} bandwidth {h}")
    residual = entropy_value - H_star
    return residual ** 2, entropy_value

def coarse_log_grid_bracket(X, H_star, h0, width_factor=100.0, num=25, batch_size=10000):
    """Scan a log10-spaced bandwidth grid and bracket the entropy-loss minimum.

    The squared entropy error is evaluated at ``num`` points spaced evenly in
    log10(h) over ``[h0 / width_factor, h0 * width_factor]``.

    Returns
    -------
    (a, fa), (b, fb), (c, fc), vals
        ``b`` is the log10 bandwidth with the smallest loss on the grid.
        ``a`` and ``c`` are its immediate grid neighbours, clamped at the grid
        ends, so ``a == b`` or ``b == c`` when the minimum sits on the boundary.
        ``vals`` is the full scan as a list of ``(log10_h, loss, entropy)``
        tuples.

    Warns
    -----
    RuntimeWarning
        If the best grid point is the first or last one. The target entropy
        may then lie outside the scanned range, and a larger ``width_factor``
        is needed.
    """
    print("Starting scan...")
    lo = np.log10(h0 / width_factor)
    hi = np.log10(h0 * width_factor)
    vals = []
    for t in np.linspace(lo, hi, num):
        f, Hval = evaluate_entropy_loss(X, H_star, 10.0 ** t, batch_size=batch_size)
        vals.append((t, f, Hval))

    # np.argmin returns the first NaN if any loss is NaN; never let NaN win.
    losses = np.array([v[1] for v in vals], dtype=float)
    losses[np.isnan(losses)] = np.inf
    best_i = int(np.argmin(losses))
    last = len(vals) - 1

    if last > 0 and best_i in (0, last):
        warnings.warn(
            f"Entropy-loss minimum is on the edge of the scanned grid "
            f"(h = {10.0 ** vals[best_i][0]:.6g}, range "
            f"[{10.0 ** lo:.6g}, {10.0 ** hi:.6g}]); the target entropy may be "
            f"unreachable in this range. Consider increasing width_factor.",
            RuntimeWarning,
            stacklevel=2,
        )

    # Neighbours of the argmin. Because b is the grid argmin, fb <= fa and
    # fb <= fc hold by construction.
    a, fa, _ = vals[max(0, best_i - 1)]
    b, fb, _ = vals[best_i]
    c, fc, _ = vals[min(last, best_i + 1)]
    return (a, fa), (b, fb), (c, fc), vals

def golden_section_search_log10(X, H_star, a, b, c, max_iter=60, tol=1e-3, batch_size=10000):
    """Golden-section minimization of the entropy loss over t = log10(h) in [a, c].

    The interval shrinks until its width drops below ``tol`` or ``max_iter``
    iterations have run. ``b`` is accepted for call-site symmetry with the
    bracket, but it is not used by the search.

    Returns
    -------
    tuple
        ``(t_best, 10 ** t_best, loss_best, entropy_best)`` for the better of
        the two final interior probes.
    """
    print("Starting search...")
    inv_phi = (np.sqrt(5.0) - 1.0) / 2.0

    def probe(t):
        # One loss evaluation at log10-bandwidth t, packed as (t, loss, entropy).
        loss, entropy = evaluate_entropy_loss(X, H_star, 10.0 ** t, batch_size=batch_size)
        return t, loss, entropy

    lo, hi = a, c
    # Both interior abscissae are placed from the initial interval before any
    # evaluation; the lower one is probed first.
    t_inner_lo = hi - inv_phi * (hi - lo)
    t_inner_hi = lo + inv_phi * (hi - lo)
    p_lo = probe(t_inner_lo)
    p_hi = probe(t_inner_hi)

    for _ in range(max_iter):
        if abs(hi - lo) < tol:
            break
        if p_lo[1] > p_hi[1]:
            # Minimum lies right of the lower probe: drop [lo, p_lo].
            lo = p_lo[0]
            p_lo = p_hi
            p_hi = probe(lo + inv_phi * (hi - lo))
        else:
            # Minimum lies left of the upper probe: drop [p_hi, hi].
            hi = p_hi[0]
            p_hi = p_lo
            p_lo = probe(hi - inv_phi * (hi - lo))

    t_best, loss_best, entropy_best = p_lo if p_lo[1] < p_hi[1] else p_hi
    return t_best, 10.0 ** t_best, loss_best, entropy_best

def optimize_bandwidth_entropy(X, H_star, batch_size=10000, grid_width=100.0, grid_pts=25,
                               max_iter=60, tol=1e-3):
    """
    RAW-space optimization: NO standardization.

    Finds the bandwidth h whose entropy (via ``evaluate_entropy_loss``) best
    matches ``H_star``:

    1. A pilot bandwidth comes from ``pilot_bandwidth``.
    2. A log10 grid spanning a factor ``grid_width`` either side of it is
       scanned.
    3. The grid bracket is refined by golden-section search in log10(h).

    Parameters
    ----------
    X : array_like
        Descriptor matrix, used as-is.
    H_star : float
        Target entropy.
    batch_size : int
        Forwarded to every entropy evaluation.
    grid_width, grid_pts
        Half-width factor and number of points of the coarse log10 grid.
    max_iter, tol
        Iteration cap and log10-interval tolerance of the golden-section
        search.

    Returns
    -------
    h_best : float
        The selected bandwidth.
    report : dict
        Pilot value, search bounds, best point, target, errors, and the full
        coarse scan under ``"scan"`` as ``(log10_h, loss, entropy)`` tuples.
    """
    X = np.asarray(X, dtype=float)
    h0 = pilot_bandwidth(X)
    (a, _), (b, fb), (c, _), scan = coarse_log_grid_bracket(
        X, H_star, h0, width_factor=grid_width, num=grid_pts, batch_size=batch_size
    )
    t_best, h_best, f_best, H_best = golden_section_search_log10(
        X, H_star, a, b, c, max_iter=max_iter, tol=tol, batch_size=batch_size
    )
    # The golden search only probes interior points of [a, c]. Keep the grid
    # minimum if the search never beat it (e.g. a non-unimodal loss).
    if fb < f_best:
        H_grid = next(Hv for tv, _, Hv in scan if tv == b)
        t_best, h_best, f_best, H_best = b, 10.0 ** b, fb, H_grid

    report = {
        "h0": h0,
        "log10_bounds": (a, c),
        "grid_points": grid_pts,
        "best_log10h": t_best,
        "best_h": h_best,
        "best_entropy": H_best,
        "target_entropy": H_star,
        "abs_error": abs(H_best - H_star),
        "squared_error": f_best,
        "scan": scan,
    }
    return h_best, report

# ------------------------------------------
# 3) Apply the entropy-matched bandwidth optimization to the Graphite run
# ------------------------------------------
data_name = "Graphite"
entropy_label = 5.6085074467370095  # target entropy H* the bandwidth is tuned to match

# Each sweep entry is one ACE basis setting (nrad / lmax per body order).
# Alternative presets:
# sweep = [{"nrad": [8, 4, 2], "lmax": [8, 6, 2]}]
# sweep = [{"nrad": [4], "lmax": [4]}]
# Explicit list: a stray trailing comma previously made this a 1-tuple by accident.
sweep = [{"nrad": [12, 8, 4], "lmax": [12, 10, 6]}]
elements = ["C"]  # carbon for GAP-20 graphite

path = f"/home/grethel/dev/quests/examples/gap20/{data_name}.xyz"
frames_list = read(path, index=":")

for params in sweep:
    data_basis_config = make_ace_config(params=params, elements=elements)
    data_basis = create_multispecies_basis_config(data_basis_config)

    # Compute ACE descriptors (array of shape [N, D])
    descriptor_ace = compute_B_projections(data_basis, frames_list)[0]
    descriptor_ace = np.asarray(descriptor_ace, dtype=float)

    # Tune bandwidth to match entropy_label
    h_opt, opt_report = optimize_bandwidth_entropy(
        descriptor_ace, H_star=entropy_label, batch_size=10000, grid_width=100.0, grid_pts=25
    )

    print(f"\n=== Bandwidth optimization ({data_name}) ===")
    print(f"params           : {params}")
    print(f"pilot h0         : {opt_report['h0']:.6g}")
    print(f"search log10 span: [{opt_report['log10_bounds'][0]:.3f}, {opt_report['log10_bounds'][1]:.3f}]")
    print(f"best h           : {opt_report['best_h']:.6g}")
    print(f"H(best h)        : {opt_report['best_entropy']:.9f}")
    print(f"target H*        : {opt_report['target_entropy']:.9f}")
    print(f"|H - H*|         : {opt_report['abs_error']:.6g}")


pilot bandwidth 2.830657501330006
Starting scan...
entropy 8.468691334896384 bandwidth 0.02830657501330006
entropy 8.468691319395695 bandwidth 0.04154837007341101
entropy 8.468691312200878 bandwidth 0.06098466716464337
entropy 8.468691303333802 bandwidth 0.08951324980043922
entropy 8.468689204703207 bandwidth 0.13138748249955604
entropy 8.468605413170035 bandwidth 0.19285045058755582
entropy 8.467339428017734 bandwidth 0.2830657501330006
entropy 8.458077474960147 bandwidth 0.4154837007341099
entropy 8.431036872210534 bandwidth 0.6098466716464337
entropy 8.394698849183962 bandwidth 0.8951324980043922
entropy 8.35868415565763 bandwidth 1.31387482499556
entropy 8.31650697826479 bandwidth 1.9285045058755583
entropy 8.254876054246767 bandwidth 2.830657501330006
entropy 8.164969883421916 bandwidth 4.154837007341099
entropy 7.987101221361748 bandwidth 6.098466716464333
entropy 7.475972579342057 bandwidth 8.951324980043923
entropy 6.3642512915556475 bandwidth 13.1387482499556
entropy 5.0222280