In [1]:
import os, time, json, itertools, datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import dimod
import networkx as nx
import minorminer
from math import erf, sqrt

from dwave.system import (
    DWaveCliqueSampler,
    DWaveSampler,
    FixedEmbeddingComposite
)
from dwave.embedding.chain_strength import uniform_torque_compensation

# tqdm (fallback if not installed)
try:
    from tqdm.auto import tqdm
except Exception:
    def tqdm(x, **kwargs):  # no-op fallback
        return x
        
SOLVER   = os.getenv("DWAVE_SOLVER", "Advantage2_system1.6")
ENDPOINT = 'https://cloud.dwavesys.com/sapi'
TOKEN = '---' # Please use DWAVE API TOKEN
SOLVER_FILTER = 'Advantage2_system1.6'
# -----------------------------
# Config: endpoint/token/solver
# -----------------------------
#ENDPOINT = os.getenv("DWAVE_API_ENDPOINT", "https://cloud.dwavesys.com/sapi")
#TOKEN    = os.getenv("DWAVE_API_TOKEN", None)   # export DWAVE_API_TOKEN=...
# Solver can be a name (str) or a filter (dict). Examples:
#   SOLVER = 'Advantage2_system1.6'
#   SOLVER = {'topology__type': 'zephyr'}
#SOLVER   = os.getenv("DWAVE_SOLVER", "Advantage2_system1.6")

# -----------------------------
# Logging
# -----------------------------
def log(msg):
    ts = datetime.datetime.now().strftime("%H:%M:%S")
    print(f"[{ts}] {msg}", flush=True)

# -----------------------------
# Small util: append progress row
# -----------------------------
def _append_progress_row(path, row_dict):
    """Append a row to CSV (header auto-created)."""
    df = pd.DataFrame([row_dict])
    header = not os.path.exists(path)
    df.to_csv(path, mode='a', header=header, index=False)

# -----------------------------
# Build L_list from CliqueSampler (no credits)
# -----------------------------
def get_L_list_from_clique(L_min=8, step=5, solver=SOLVER):
    if TOKEN is None:
        log("WARNING: DWAVE_API_TOKEN not set; CliqueSampler may fail to auth.")
    cs = DWaveCliqueSampler(endpoint=ENDPOINT, token=TOKEN, solver=solver)
    L_max = getattr(cs, "largest_clique_size", None)
    if not isinstance(L_max, int) or L_max <= 0:
        raise RuntimeError("largest_clique_size.")
    if L_min > L_max:
        L_min = L_max
    L_list = list(range(L_min, L_max + 1, step)) or [L_max]
    log(f"[L_max] {L_max}  |  [L_list] {L_list} (step={step})")
    return L_list, L_max

# -----------------------------
# QUBO → BQM (one per L)
# -----------------------------
def make_random_bqm(L, h_range=(-1.0,1.0), J_range=(-1.0,1.0),
                    density=1.0, seed=0):
    rng = np.random.default_rng(seed)
    Q = {}
    for i in range(L):
        Q[(i, i)] = rng.uniform(*h_range)
    for i in range(L):
        for j in range(i+1, L):
            if rng.random() < density:
                Q[(i, j)] = rng.uniform(*J_range)
    return dimod.BinaryQuadraticModel.from_qubo(Q)

# -----------------------------
# QPU & Embeddings
# -----------------------------
def get_qpu(solver=SOLVER):
    if TOKEN is None:
        log("WARNING: DWAVE_API_TOKEN not set; DWaveSampler may fail to auth.")
    return DWaveSampler(endpoint=ENDPOINT, token=TOKEN, solver=solver)

def find_clique_embedding(qpu, L, seed=0, verbose=True):
    """Manual K_L embedding on hardware graph."""
    hw = qpu.to_networkx_graph()
    K = nx.complete_graph(L)
    emb = minorminer.find_embedding(K.edges(), hw.edges(), random_seed=seed)
    if not emb:
        raise RuntimeError(f"Embedding failed for K_{L}")
    if verbose:
        lens = [len(c) for c in emb.values()]
        log(f"[Embedding-KL] L={L} chains={len(emb)} meanCL={np.mean(lens):.2f} maxCL={np.max(lens)}")
    return emb

def find_bqm_embedding(qpu, bqm, seed=0, verbose=True):
    """BQM sparsity-based embedding (shorter chains for sparse graphs)."""
    hw = qpu.to_networkx_graph()
    G = nx.Graph()
    G.add_nodes_from(bqm.variables)
    G.add_edges_from(bqm.quadratic.keys())
    emb = minorminer.find_embedding(G.edges(), hw.edges(), random_seed=seed)
    if not emb or any(v not in emb for v in bqm.variables):
        raise RuntimeError("BQM-based embedding failed")
    if verbose:
        lens = [len(c) for c in emb.values()]
        log(f"[Embedding-BQM] n={len(emb)} meanCL={np.mean(lens):.2f} maxCL={np.max(lens)}")
    return emb

def chain_lengths_from_embedding(emb):
    return np.array([len(c) for c in emb.values()], dtype=int)

# -----------------------------
# Solver param negotiation (anneal_time vs anneal_schedule)
# -----------------------------
def _solver_allowed_params(qpu):
    try:
        return set(qpu.parameters.keys())
    except Exception:
        return set()

def _apply_anneal_params(qpu, base_kwargs, anneal_time_us=None, anneal_schedule=None, verbose=False):
    """
    Return kwargs filtered/mapped to solver-supported params:
    - if anneal_schedule given and supported -> use it
    - else if anneal_time_us:
        * if 'anneal_time' supported -> pass as is
        * elif 'anneal_schedule' supported -> convert to linear schedule
    - drop unsupported keys
    """
    allowed = _solver_allowed_params(qpu)
    kw = dict(base_kwargs)

    if anneal_schedule is not None and 'anneal_schedule' in allowed:
        kw['anneal_schedule'] = anneal_schedule
    elif anneal_time_us is not None:
        if 'anneal_time' in allowed:
            kw['anneal_time'] = float(anneal_time_us)
        elif 'anneal_schedule' in allowed:
            kw['anneal_schedule'] = [(0.0, 0.0), (float(anneal_time_us), 1.0)]
            if verbose:
                log(f"[adapt] anneal_time_us={anneal_time_us} -> linear schedule (no 'anneal_time')")
        else:
            if verbose:
                log("[adapt] anneal_time unsupported; using default schedule.")
    return kw

# -----------------------------
# CBF summarizer
# -----------------------------
def summarize_cbf(sampleset):
    if 'chain_break_fraction' in sampleset.record.dtype.names:
        vec = sampleset.record['chain_break_fraction']
    else:
        info = getattr(sampleset, "info", {}) or {}
        emb_ctx = info.get("embedding_context", {})
        if "chain_break_fraction" in emb_ctx:
            vec = np.asarray(emb_ctx["chain_break_fraction"], dtype=float)
            if len(vec) != len(sampleset):
                vec = np.resize(vec, len(sampleset))
        else:
            vec = np.full(len(sampleset), np.nan, dtype=float)

    finite = np.isfinite(vec)
    if not finite.any():
        return dict(mean=0.0, std=0.0, prob_break=0.0, n=len(vec), vec=vec)
    vecf = vec[finite]
    mean = float(np.mean(vecf))
    std  = float(np.std(vecf, ddof=1)) if len(vecf) > 1 else 0.0
    pbr  = float(np.mean(vecf > 0.0))
    return dict(mean=mean, std=std, prob_break=pbr, n=len(vecf), vec=vecf)

# -----------------------------
# QA runs
# -----------------------------
def run_qa_with_embedding(bqm, emb, num_reads=2000, chain_strength=None,
                          anneal_time_us=None, anneal_schedule=None, verbose=False):
    """
    FixedEmbeddingComposite + QPU; default schedule unless anneal_time_us/schedule requested.
    Auto-maps anneal_time to anneal_schedule if solver lacks 'anneal_time'.
    """
    qpu = get_qpu()
    comp = FixedEmbeddingComposite(qpu, emb)

    comp_kwargs = dict(
        num_reads=num_reads,
        chain_strength=chain_strength,
        chain_break_fraction=True
    )
    solver_kwargs = _apply_anneal_params(qpu, {}, anneal_time_us, anneal_schedule, verbose)
    comp_kwargs.update(solver_kwargs)

    if verbose:
        log(f"[sample kwargs] {comp_kwargs}")

    t0 = time.time()
    ss = comp.sample(bqm, **comp_kwargs)
    dt = time.time() - t0

    stats = summarize_cbf(ss)
    E_vec = ss.record.energy
    return ss, stats, E_vec, dt

def run_qa_via_clique_sampler(bqm, num_reads=2000, chain_strength=None,
                              anneal_time_us=None, anneal_schedule=None, verbose=False):
    """
    DWaveCliqueSampler로 K_n 임베딩 자동 사용.
    체인 길이는 sampleset.info['embedding_context']['embedding']에서 복원.
    """
    qpu = get_qpu()  # for param negotiation
    cs = DWaveCliqueSampler(endpoint=ENDPOINT, token=TOKEN, solver=SOLVER)

    solver_kwargs = _apply_anneal_params(qpu, {}, anneal_time_us, anneal_schedule, verbose)

    t0 = time.time()
    ss = cs.sample(
        bqm,
        num_reads=num_reads,
        chain_strength=chain_strength,
        chain_break_fraction=True,
        **solver_kwargs
    )
    dt = time.time() - t0

    stats = summarize_cbf(ss)
    E_vec = ss.record.energy

    emb = ss.info.get("embedding_context", {}).get("embedding", {})
    lengths = np.array([len(chain) for chain in emb.values()], dtype=int) if emb else np.array([])

    if verbose:
        if len(lengths):
            log(f"[CliqueSampler] meanCL={np.mean(lengths):.2f}, maxCL={np.max(lengths)}")
        log(f"[sample kwargs] {dict(num_reads=num_reads, chain_strength=chain_strength, **solver_kwargs)}")

    return ss, stats, E_vec, dt, lengths

# -----------------------------
# One L: one BQM + selected embedding mode, R replicates
# -----------------------------
def measure_L_oneQUBO(L, *,
                      qubo_seed=0, emb_seed=0, n_reps=10, num_reads=2000,
                      density=1.0, chain_strength_mode="utc",
                      anneal_time_us=None, anneal_schedule=None,
                      sampler_mode="clique_sampler",   # 'clique_sampler' | 'fixed_embedding' | 'bqm_embedding'
                      out_dir=".", verbose=True,
                      # progress options
                      show_progress=False,
                      progress_csv=None,
                      checkpoint_every=1,
                      save_intermediate_hist=False,
                      hist_every=5):
    tL0 = time.time()
    bqm = make_random_bqm(L, density=density, seed=qubo_seed)

    # chain_strength
    if isinstance(chain_strength_mode, (int, float)):
        cs = float(chain_strength_mode)
    elif chain_strength_mode == "utc":
        cs = uniform_torque_compensation(bqm)
    else:
        raise ValueError("chain_strength_mode must be 'utc' 또는 float")

    per_rep, vecs, E_list = [], [], []
    lengths = np.array([], dtype=int)

    if sampler_mode == "clique_sampler":
        # replicate 0: also capture chain lengths
        ss0, stats0, E0, dt0, lengths0 = run_qa_via_clique_sampler(
            bqm, num_reads=num_reads, chain_strength=cs,
            anneal_time_us=anneal_time_us, anneal_schedule=anneal_schedule, verbose=verbose
        )
        lengths = lengths0 if len(lengths0) else np.array([], dtype=int)
        per_rep.append(dict(
            L=L, replicate=0,
            mean_cbf=stats0['mean'], std_cbf=stats0['std'],
            prob_break=stats0['prob_break'], time_sec=dt0, n_reads=len(ss0),
            mean_energy=float(np.mean(E0)),
            std_energy=float(np.std(E0, ddof=1)) if len(E0)>1 else 0.0,
            best_energy=float(np.min(E0)),
            anneal_time_us=float(anneal_time_us) if anneal_time_us is not None else np.nan
        ))
        vecs.append(stats0['vec']); E_list.append(E0)

        # progress CSV append
        if progress_csv and checkpoint_every and ((0+1) % checkpoint_every == 0):
            _append_progress_row(progress_csv, dict(
                L=L, replicate=0,
                mean_cbf=stats0['mean'], std_cbf=stats0['std'],
                prob_break=stats0['prob_break'], time_sec=dt0,
                mean_energy=float(np.mean(E0)), best_energy=float(np.min(E0)),
                chain_strength=float(cs),
                mean_chainlen=float(np.mean(lengths)) if len(lengths) else np.nan,
                max_chainlen=int(np.max(lengths)) if len(lengths) else -1,
                anneal_time_us=float(anneal_time_us) if anneal_time_us is not None else np.nan,
                elapsed_sec=time.time()-tL0
            ))

        it = tqdm(range(1, n_reps), disable=not show_progress, desc=f"reps L={L}")
        for r in it:
            ss, stats, E_vec, dt, _ = run_qa_via_clique_sampler(
                bqm, num_reads=num_reads, chain_strength=cs,
                anneal_time_us=anneal_time_us, anneal_schedule=anneal_schedule, verbose=False
            )
            per_rep.append(dict(
                L=L, replicate=r,
                mean_cbf=stats['mean'], std_cbf=stats['std'],
                prob_break=stats['prob_break'], time_sec=dt, n_reads=len(ss),
                mean_energy=float(np.mean(E_vec)),
                std_energy=float(np.std(E_vec, ddof=1)) if len(E_vec)>1 else 0.0,
                best_energy=float(np.min(E_vec)),
                anneal_time_us=float(anneal_time_us) if anneal_time_us is not None else np.nan
            ))
            vecs.append(stats['vec']); E_list.append(E_vec)

            try:
                it.set_postfix(meanCBF=f"{stats['mean']:.3f}",
                               pBreak=f"{stats['prob_break']:.3f}",
                               bestE=f"{np.min(E_vec):.3f}")
            except Exception:
                pass

            if progress_csv and checkpoint_every and ((r+1) % checkpoint_every == 0):
                _append_progress_row(progress_csv, dict(
                    L=L, replicate=r,
                    mean_cbf=stats['mean'], std_cbf=stats['std'],
                    prob_break=stats['prob_break'], time_sec=dt,
                    mean_energy=float(np.mean(E_vec)), best_energy=float(np.min(E_vec)),
                    chain_strength=float(cs),
                    mean_chainlen=float(np.mean(lengths)) if len(lengths) else np.nan,
                    max_chainlen=int(np.max(lengths)) if len(lengths) else -1,
                    anneal_time_us=float(anneal_time_us) if anneal_time_us is not None else np.nan,
                    elapsed_sec=time.time()-tL0
                ))

            if save_intermediate_hist and ((r+1) % hist_every == 0 or r == n_reps-1):
                vec_cat_sofar = np.concatenate(vecs) if vecs else np.array([])
                png = os.path.join(out_dir, f"cbf_hist_L{L}_sofar_R{r+1}.png")
                plot_histogram(vec_cat_sofar, f"CBF so far (L={L}, R={r+1})", out_png=png)

    elif sampler_mode in ("fixed_embedding", "bqm_embedding"):
        qpu = get_qpu()
        if sampler_mode == "fixed_embedding":
            emb = find_clique_embedding(qpu, L, seed=emb_seed, verbose=verbose)
        else:
            emb = find_bqm_embedding(qpu, bqm, seed=emb_seed, verbose=verbose)
        lengths = chain_lengths_from_embedding(emb)

        it = tqdm(range(n_reps), disable=not show_progress, desc=f"reps L={L}")
        for r in it:
            ss, stats, E_vec, dt = run_qa_with_embedding(
                bqm, emb, num_reads=num_reads, chain_strength=cs,
                anneal_time_us=anneal_time_us, anneal_schedule=anneal_schedule, verbose=False
            )
            per_rep.append(dict(
                L=L, replicate=r,
                mean_cbf=stats['mean'], std_cbf=stats['std'],
                prob_break=stats['prob_break'], time_sec=dt, n_reads=len(ss),
                mean_energy=float(np.mean(E_vec)),
                std_energy=float(np.std(E_vec, ddof=1)) if len(E_vec)>1 else 0.0,
                best_energy=float(np.min(E_vec)),
                anneal_time_us=float(anneal_time_us) if anneal_time_us is not None else np.nan
            ))
            vecs.append(stats['vec']); E_list.append(E_vec)

            try:
                it.set_postfix(meanCBF=f"{stats['mean']:.3f}",
                               pBreak=f"{stats['prob_break']:.3f}",
                               bestE=f"{np.min(E_vec):.3f}")
            except Exception:
                pass

            if progress_csv and checkpoint_every and ((r+1) % checkpoint_every == 0):
                _append_progress_row(progress_csv, dict(
                    L=L, replicate=r,
                    mean_cbf=stats['mean'], std_cbf=stats['std'],
                    prob_break=stats['prob_break'], time_sec=dt,
                    mean_energy=float(np.mean(E_vec)), best_energy=float(np.min(E_vec)),
                    chain_strength=float(cs),
                    mean_chainlen=float(np.mean(lengths)) if len(lengths) else np.nan,
                    max_chainlen=int(np.max(lengths)) if len(lengths) else -1,
                    anneal_time_us=float(anneal_time_us) if anneal_time_us is not None else np.nan,
                    elapsed_sec=time.time()-tL0
                ))

            if save_intermediate_hist and ((r+1) % hist_every == 0 or r == n_reps-1):
                vec_cat_sofar = np.concatenate(vecs) if vecs else np.array([])
                png = os.path.join(out_dir, f"cbf_hist_L{L}_sofar_R{r+1}.png")
                plot_histogram(vec_cat_sofar, f"CBF so far (L={L}, R={r+1})", out_png=png)

    else:
        raise ValueError("sampler_mode ∈ {'clique_sampler','fixed_embedding','bqm_embedding'}")

    # pooled stats
    vec_cat = np.concatenate(vecs) if vecs else np.array([])
    pooled = dict(
        mean=float(np.mean(vec_cat)) if len(vec_cat) else 0.0,
        std=float(np.std(vec_cat, ddof=1)) if len(vec_cat)>1 else 0.0,
        prob_break=float(np.mean(vec_cat>0.0)) if len(vec_cat) else 0.0,
        n=len(vec_cat)
    )
    E_all = np.concatenate(E_list) if E_list else np.array([])
    energy_pooled = dict(
        mean=float(np.mean(E_all)) if len(E_all) else 0.0,
        std=float(np.std(E_all, ddof=1)) if len(E_all)>1 else 0.0,
        best=float(np.min(E_all)) if len(E_all) else 0.0
    )

    if verbose:
        atxt = f" | AT={anneal_time_us}us" if anneal_time_us is not None else ""
        mcl = float(np.mean(lengths)) if len(lengths) else float('nan')
        xcl = int(np.max(lengths)) if len(lengths) else -1
        log(f"[L={L}] pooled_meanCBF={pooled['mean']:.4f}, pooled_pBreak={pooled['prob_break']:.3f}, "
            f"meanE={energy_pooled['mean']:.4f}, bestE={energy_pooled['best']:.4f} "
            f"| meanCL={mcl:.2f}, maxCL={xcl}{atxt} | elapsed={time.time()-tL0:.2f}s")

    return dict(
        L=L, qubo_seed=qubo_seed, emb_seed=emb_seed, chain_strength=float(cs),
        per_rep=per_rep, vecs=vecs, lengths=lengths, pooled=pooled,
        E_list=E_list, E_all=E_all, energy_pooled=energy_pooled,
        anneal_time_us=anneal_time_us
    )

# -----------------------------
# ICE model & fitting
# -----------------------------
def sigma_eff_sq(ell, sigma_h, sigma_c):
    return ell*(sigma_h**2) + max(0, ell-1)*(sigma_c**2)

def p_break_exact(ell, sigma_h, sigma_c, kappa):
    s2 = sigma_eff_sq(int(ell), sigma_h, sigma_c)
    if s2 <= 0: return 0.0
    x = kappa/(sqrt(2.0)*sqrt(s2))
    return 1.0 - erf(x)  # == erfc(x)

def cbf_from_lengths(lengths, sigma_h, sigma_c, kappa):
    if len(lengths)==0: return 0.0
    return float(np.mean([p_break_exact(ell, sigma_h, sigma_c, kappa) for ell in lengths]))

def fit_sigma_params(calib_rows,
                     sh_grid=np.linspace(0.005,0.08,16),
                     sc_grid=np.linspace(0.005,0.08,16),
                     k_grid =np.linspace(0.10,1.00,19),
                     verbose=True):
    best, best_err = None, float('inf')
    for sh, sc, k in itertools.product(sh_grid, sc_grid, k_grid):
        err = 0.0
        for row in calib_rows:
            pred = cbf_from_lengths(row["lengths"], sh, sc, k)
            err += (pred - row["cbf_obs"])**2
        if err < best_err:
            best_err = err
            best = (float(sh), float(sc), float(k))
    if verbose and best is not None:
        log(f"[fit] sigma_h={best[0]:.4f}, sigma_c={best[1]:.4f}, kappa={best[2]:.4f} | SSE={best_err:.6f}")
    table = []
    for row in calib_rows:
        pred = cbf_from_lengths(row["lengths"], best[0], best[1], best[2])
        table.append(dict(L=row["L"], cbf_obs=row["cbf_obs"], cbf_pred=float(pred)))
    return dict(sigma_h=best[0], sigma_c=best[1], kappa=best[2], sse=best_err, table=table)

# -----------------------------
# Plots
# -----------------------------
def plot_histogram(vec, title, out_png=None, bins=25):
    plt.figure(figsize=(6,4))
    plt.hist(vec, bins=bins, alpha=0.7, density=True)
    plt.title(title); plt.xlabel("chain_break_fraction"); plt.ylabel("density")
    plt.tight_layout()
    if out_png:
        plt.savefig(out_png, dpi=160); log(f"[saved] {out_png}")
    plt.close()

# -----------------------------
# Run all L, fit ICE, save outputs
# -----------------------------
def run_experiment_and_fit(L_list,
                           n_reps=10, num_reads=2000,
                           qubo_seed_base=0, emb_seed_base=0,
                           density=1.0, chain_strength_mode="utc",
                           out_dir="qa_ice_fit_hybrid_out",
                           sh_grid=None, sc_grid=None, k_grid=None,
                           make_hist=True, verbose=True,
                           anneal_time_us=None, anneal_schedule=None,
                           sampler_mode="clique_sampler",
                           # progress options
                           show_progress=True,
                           checkpoint_every=1,
                           save_intermediate_hist=False,
                           hist_every=5):
    os.makedirs(out_dir, exist_ok=True)
    perL_rows, raw_npz = [], {}

    # progress CSV (per replicate)
    progress_csv = os.path.join(out_dir, "progress_per_rep.csv")
    # live per-L summary (updated each L)
    perL_live_csv = os.path.join(out_dir, "summary_per_L_live.csv")

    iterable = tqdm(list(enumerate(L_list, 1)), total=len(L_list),
                    disable=not show_progress, desc="L sweep")

    for idx, L in iterable:
        log(f"\n=== [{idx}/{len(L_list)}] L={L} ===")
        try:
            res = measure_L_oneQUBO(
                L, qubo_seed=qubo_seed_base+L, emb_seed=emb_seed_base+L,
                n_reps=n_reps, num_reads=num_reads, density=density,
                chain_strength_mode=chain_strength_mode, out_dir=out_dir, verbose=verbose,
                anneal_time_us=anneal_time_us, anneal_schedule=anneal_schedule,
                sampler_mode=sampler_mode,
                # progress options
                show_progress=show_progress,
                progress_csv=progress_csv,
                checkpoint_every=checkpoint_every,
                save_intermediate_hist=save_intermediate_hist,
                hist_every=hist_every
            )
        except Exception as e:
            log(f"[skip] L={L} failed: {e.__class__.__name__}: {e}")
            continue

        # stash raws (CBF + Energy + chain lengths)
        for r, vec in enumerate(res["vecs"]):
            raw_npz[f"CBF_vec_L{L}_rep{r}"] = vec
        raw_npz[f"CBF_vec_L{L}_ALL"] = np.concatenate(res["vecs"]) if res["vecs"] else np.array([])
        raw_npz[f"chain_lengths_L{L}"] = res["lengths"]
        for r, E in enumerate(res["E_list"]):
            raw_npz[f"E_vec_L{L}_rep{r}"] = E
        raw_npz[f"E_vec_L{L}_ALL"] = res["E_all"]

        # per-L row
        mcl = float(np.mean(res["lengths"])) if len(res["lengths"]) else np.nan
        xcl = int(np.max(res["lengths"])) if len(res["lengths"]) else -1
        perL_rows.append(dict(
            L=L, qubo_seed=res["qubo_seed"], emb_seed=res["emb_seed"],
            mean_cbf_obs=res["pooled"]["mean"], std_cbf_obs=res["pooled"]["std"],
            prob_break_obs=res["pooled"]["prob_break"], n_total=res["pooled"]["n"],
            chain_strength=res["chain_strength"],
            mean_chainlen=mcl, max_chainlen=xcl,
            mean_energy_obs=res["energy_pooled"]["mean"], best_energy_obs=res["energy_pooled"]["best"],
            anneal_time_us=float(res["anneal_time_us"]) if res["anneal_time_us"] is not None else np.nan
        ))

        # live summary update
        df_live = pd.DataFrame(perL_rows).sort_values("L").reset_index(drop=True)
        df_live.to_csv(perL_live_csv, index=False)

        if make_hist:
            png = os.path.join(out_dir, f"cbf_hist_L{L}_R{n_reps}.png")
            plot_histogram(raw_npz[f"CBF_vec_L{L}_ALL"],
                           f"CBF Histogram (L={L}, one BQM, R={n_reps} runs)", out_png=png)

    # per-L table (final)
    df_L = pd.DataFrame(perL_rows).sort_values("L").reset_index(drop=True)
    ts = int(time.time())
    perL_csv = os.path.join(out_dir, f"summary_per_L_{ts}.csv")
    df_L.to_csv(perL_csv, index=False); log(f"[saved] {perL_csv}")

    # fit ICE
    calib = [{"L": int(row.L), "cbf_obs": float(row.mean_cbf_obs),
              "lengths": raw_npz[f"chain_lengths_L{int(row.L)}"]} for _, row in df_L.iterrows()]
    if sh_grid is None: sh_grid = np.linspace(0.005,0.08,16)
    if sc_grid is None: sc_grid = np.linspace(0.005,0.08,16)
    if k_grid  is None: k_grid  = np.linspace(0.10,1.00,19)
    fit = fit_sigma_params(calib, sh_grid, sc_grid, k_grid, verbose=True)

    # save fit table
    df_fit = pd.DataFrame(fit["table"]).sort_values("L").reset_index(drop=True)
    df_fit["cbf_err"] = (df_fit["cbf_pred"] - df_fit["cbf_obs"]).abs()
    fit_csv = os.path.join(out_dir, f"fit_obs_vs_pred_{ts}.csv")
    df_fit.to_csv(fit_csv, index=False); log(f"[saved] {fit_csv}")
    log(f"[fit] sigma_h={fit['sigma_h']:.4f}  sigma_c={fit['sigma_c']:.4f}  kappa={fit['kappa']:.4f}  SSE={fit['sse']:.6f}")

    # overlay plot
    plt.figure(figsize=(6,4))
    plt.plot(df_L["L"], df_L["mean_cbf_obs"], marker="o", label="Observed")
    plt.plot(df_fit["L"], df_fit["cbf_pred"], marker="o", label="Predicted")
    plt.xlabel("L (logical variables / chains)"); plt.ylabel("mean CBF")
    plt.title("Observed vs Predicted mean CBF vs L")
    plt.legend(); plt.tight_layout()
    overlay_png = os.path.join(out_dir, f"cbf_obs_pred_vs_L_{ts}.png")
    plt.savefig(overlay_png, dpi=160); log(f"[saved] {overlay_png}")
    plt.close()

    # raws + params
    raw_npz_path = os.path.join(out_dir, f"raw_vectors_{ts}.npz")
    np.savez(raw_npz_path, **raw_npz); log(f"[saved] {raw_npz_path}")
    with open(os.path.join(out_dir, f"fit_params_{ts}.json"), "w") as f:
        json.dump({"sigma_h": fit["sigma_h"], "sigma_c": fit["sigma_c"], "kappa": fit["kappa"],
                   "sse": fit["sse"]}, f, indent=2)

    return dict(df_L=df_L, df_fit=df_fit, fit=fit,
                perL_csv=perL_csv, fit_csv=fit_csv,
                overlay_png=overlay_png, raw_npz=raw_npz_path, out_dir=out_dir)
#ENDPOINT = os.getenv("DWAVE_API_ENDPOINT", "https://cloud.dwavesys.com/sapi")
#TOKEN    = os.getenv("DWAVE_API_TOKEN", None)   # set in your env: export DWAVE_API_TOKEN=...
#SOLVER_FILTER = {'topology__type': 'zephyr'}
# ============================================================
# Sweep helper: chain-strength sweep (single out_dir)
# ============================================================
def run_and_fit_chain_strength_sweep_flat(
    L_list,
    cs_values,
    *,
    n_reps=10,
    num_reads=2000,
    qubo_seed_base=0,
    emb_seed_base=0,
    density=1.0,
    anneal_time_us=20.0,
    out_root="qa_ice_fit_hybrid_CSsweep_AT20us_flat",
    sh_grid=None, sc_grid=None, k_grid=None,
    make_hist=True, verbose=True,
    sampler_mode="clique_sampler",
    # progress options
    show_progress=True, checkpoint_every=1,
    save_intermediate_hist=False, hist_every=5
):

    os.makedirs(out_root, exist_ok=True)
    combined_rows = []

    for cs in cs_values:
        cs_tag = f"cs{str(round(cs,2)).replace('.','p')}"
        log(f"=== [SWEEP] chain_strength={cs:.2f}, anneal_time={anneal_time_us}us ===")

        res = run_experiment_and_fit(
            L_list=L_list,
            n_reps=n_reps, num_reads=num_reads,
            qubo_seed_base=qubo_seed_base, emb_seed_base=emb_seed_base,
            density=density, chain_strength_mode=cs,
            out_dir=out_root,
            sh_grid=sh_grid, sc_grid=sc_grid, k_grid=k_grid,
            make_hist=make_hist, verbose=verbose,
            anneal_time_us=anneal_time_us, anneal_schedule=None,
            sampler_mode=sampler_mode,
            # progress
            show_progress=show_progress,
            checkpoint_every=checkpoint_every,
            save_intermediate_hist=save_intermediate_hist,
            hist_every=hist_every
        )

        df_one = pd.read_csv(res["perL_csv"])
        out_csv = os.path.join(out_root, f"summary_per_L_{cs_tag}.csv")
        df_one.to_csv(out_csv, index=False)
        log(f"[saved] {out_csv}")

        df_one["chain_strength"] = cs
        df_one["anneal_time_us"] = anneal_time_us
        combined_rows.append(df_one)

    df_all = pd.concat(combined_rows, ignore_index=True) if combined_rows else pd.DataFrame()
    combined_csv = os.path.join(out_root, "sweep_summary_per_L.csv")
    df_all.to_csv(combined_csv, index=False)
    log(f"[saved] {combined_csv}")

    return df_all

# -----------------------------
# Main (example)
# -----------------------------
if __name__ == "__main__":
    # 0) L candidates by clique size (no credits)
    L_list, L_max = get_L_list_from_clique(L_min=5, step=5, solver=SOLVER)

    # 1) Choose sampler/embedding mode:
    #    'clique_sampler' | 'fixed_embedding' | 'bqm_embedding'
    SAMPLER_MODE = "clique_sampler"  # default: use DWaveCliqueSampler

    # 2) Chain-strength sweep example (single out_dir)
    cs_values = np.round(np.arange(0.1, 2.51, 0.1), 2)

    df_all = run_and_fit_chain_strength_sweep_flat(
        L_list=L_list,
        cs_values=cs_values,
        n_reps=10,
        num_reads=2000,
        density=1.0,
        anneal_time_us=20.0,
        out_root=f"Test_{SAMPLER_MODE}",
        make_hist=True,
        verbose=True,
        sampler_mode=SAMPLER_MODE,
        # progress
        show_progress=True,
        checkpoint_every=1,
        save_intermediate_hist=False,
        hist_every=5
    )
    log("[done] flat chain strength sweep finished.")


  from .autonotebook import tqdm as notebook_tqdm


[19:17:46] [L_max] 104  |  [L_list] [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100] (step=5)
[19:17:46] === [SWEEP] chain_strength=0.10, anneal_time=20.0us ===


L sweep:   0%|                                                                                  | 0/20 [00:00<?, ?it/s]

[19:17:46] 
=== [1/20] L=5 ===
[19:17:49] [adapt] anneal_time_us=20.0 -> linear schedule (no 'anneal_time')
[19:17:52] [sample kwargs] {'num_reads': 2000, 'chain_strength': 0.1, 'anneal_schedule': [(0.0, 0.0), (20.0, 1.0)]}



reps L=5:   0%|                                                                                  | 0/9 [00:00<?, ?it/s][A
reps L=5:   0%|                                       | 0/9 [00:05<?, ?it/s, bestE=-0.962, meanCBF=0.000, pBreak=0.000][A
reps L=5:  11%|███▍                           | 1/9 [00:05<00:40,  5.06s/it, bestE=-0.962, meanCBF=0.000, pBreak=0.000][A
reps L=5:  11%|███▍                           | 1/9 [00:10<00:40,  5.06s/it, bestE=-0.962, meanCBF=0.000, pBreak=0.000][A
reps L=5:  22%|██████▉                        | 2/9 [00:10<00:36,  5.18s/it, bestE=-0.962, meanCBF=0.000, pBreak=0.000][A
reps L=5:  22%|██████▉                        | 2/9 [00:15<00:36,  5.18s/it, bestE=-0.962, meanCBF=0.000, pBreak=0.000][A
reps L=5:  33%|██████████▎                    | 3/9 [00:17<00:34,  5.67s/it, bestE=-0.962, meanCBF=0.000, pBreak=0.000][A
L sweep:   0%|                                                                                  | 0/20 [00:22<?, ?it/s]


KeyboardInterrupt: 