In [None]:
import logging
from pathlib import Path
from dataclasses import dataclass
from typing import Iterable, Optional, Callable, Sequence, Tuple, Dict

import numpy as np
import matplotlib.pyplot as plt

from labcore.data.datadict import DataDict
from labcore.data.datadict_storage import datadict_from_hdf5
from labcore.measurement.storage import run_and_save_sweep

from qcui_measurement.protocols.base import ProtocolOperation
from qcui_measurement.protocols.parameters import (
    Repetition, Steps,
    StartReadoutFrequency, EndReadoutFrequency,
    ReadoutGain, ReadoutLength, Delay,
)
from qcui_measurement.protocols.operations.res_spec import ResonatorSpectroscopy
from qcui_measurement.qick.single_transmon_v2 import FreqSweepProgram

logger = logging.getLogger(__name__)


# --------------------------- helpers & containers -----------------------------
# Helpers that remove the linear phase from a trace and fit a mixture of two hanger (notch) resonances, rather than the standard single complex hanger model.
@dataclass
class FluxList:
    """Named collection of flux-bias set points."""
    # Label identifying the flux axis.
    name: str
    # Flux-bias values; any iterable of numbers.
    value: Iterable[float]

def _unwrap_and_remove_linear_phase(freq: np.ndarray, sig: np.ndarray) -> Tuple[np.ndarray, float]:
    """Unwrap phase and remove best-fit linear slope (cable delay)."""
    # unwrap and linear fit: angle ≈ φ0 + 2π τ f, so it becomes continuous
    ph = np.unwrap(np.angle(sig))
    slope = np.polyfit(freq, ph, 1)[0]
    sig_unw = sig * np.exp(-1j * slope * freq)
    return sig_unw, slope/(2*np.pi)  # return τ estimate in seconds if freq in Hz

def _hanger_single(f: np.ndarray, f0, Ql, Qc, theta, a, phi0, tau) -> np.ndarray:
    x = (f - f0) / f0
    notch = 1.0 - (Ql/Qc) * np.exp(1j*theta) / (1.0 + 2j*Ql*x)
    return a * np.exp(1j*(phi0 + 2*np.pi*f*tau)) * notch

def _hanger_double_mix(f: np.ndarray, params: Dict[str, float]) -> np.ndarray:
    """Weighted sum of two hanger responses that share every line parameter.

    ``params`` must provide ``f0g, f0e, Ql, Qc, theta, a, phi0, tau, pg``.
    The response centred at ``f0g`` is weighted by ``pg`` and the one centred
    at ``f0e`` by ``1 - pg``.
    """
    centre_g = params["f0g"]
    centre_e = params["f0e"]
    # parameters common to both branches, in _hanger_single argument order
    shared = tuple(params[key] for key in ("Ql", "Qc", "theta", "a", "phi0", "tau"))
    weight_g = params["pg"]

    branch_g = _hanger_single(f, centre_g, *shared)
    branch_e = _hanger_single(f, centre_e, *shared)
    return weight_g * branch_g + (1.0 - weight_g) * branch_e

def _fit_double_hanger(freq: np.ndarray, sig_cplx: np.ndarray):
    f = np.asarray(freq, float)
    yC = np.asarray(sig_cplx, complex)

    # -------- de-embed linear phase (done in your pipeline, but safe here) --------
    ph = np.unwrap(np.angle(yC))
    b1, b0 = np.polyfit(f, ph, 1)         # ph ≈ b0 + b1 f
    y = yC * np.exp(-1j * (b0 + b1*f))    # unwrapped/de-sloped
    # NOTE: we fit in this frame; we won't refit tau here.

    # -------- initial guesses via local minima on smoothed |S| --------
    mag = np.abs(y)
    win = max(5, len(f)//80)
    mag_s = np.convolve(mag, np.ones(win)/win, mode="same")

    # local minima: indices i with mag_s[i] < neighbors
    locmins = np.where((mag_s[1:-1] < mag_s[:-2]) & (mag_s[1:-1] < mag_s[2:]))[0] + 1
    if len(locmins) == 0:
        locmins = np.array([int(np.argmin(mag_s))])
    # sort by depth
    locmins = locmins[np.argsort(mag_s[locmins])]

    # pick the deepest as ground; the closest *distinct* other min as excited
    i0 = locmins[0]
    # distinct = at least a few bins away
    min_sep_bins = max(5, len(f)//200)
    candidates = [j for j in locmins[1:] if abs(j - i0) >= min_sep_bins]
    j0 = candidates[0] if candidates else (i0 + min_sep_bins if i0 + min_sep_bins < len(f) else i0 - min_sep_bins)

    f0g0, f0e0 = np.sort([f[i0], f[j0]])

    # -------- line-term seeds (gentle) --------
    a0 = float(np.median(mag))
    theta0 = 0.0
    phi00 = 0.0
    tau0 = 0.0
    Ql0, Qc0 = 1.5e4, 3.0e4
    pg0 = 0.7

    # -------- bounds with a narrow f0e window around its seed --------
    # estimate natural linewidth ~ f0 / Ql0; allow a wide multiple
    linewidth = f0g0 / Ql0
    df_win = max(10.0*linewidth, 1.5e6)   # e.g. +/- 1.5 MHz minimum window
    lb = dict(f0g=f.min(),  f0e=f0e0 - 4*df_win, Ql=500,  Qc=500,
              theta=-np.pi, a=0.0, phi0=-np.pi, tau=-1e-6, pg=0.0)
    ub = dict(f0g=f.max(),  f0e=f0e0 + 4*df_win, Ql=5e5, Qc=5e6,
              theta= np.pi, a=10*np.max(mag), phi0= np.pi, tau= 1e-6, pg=1.0)

    params0 = dict(f0g=f0g0, f0e=f0e0, Ql=Ql0, Qc=Qc0, theta=theta0, a=a0, phi0=phi00, tau=tau0, pg=pg0)

    # -------- model & packing --------
    def _hanger_single(ff, f0, Ql, Qc, theta, a, phi0, tau):
        x = (ff - f0)/f0
        notch = 1.0 - (Ql/Qc)*np.exp(1j*theta)/(1.0 + 2j*Ql*x)
        return a*np.exp(1j*(phi0 + 2*np.pi*ff*tau))*notch

    def _hanger_double_mix(ff, p):
        sg = _hanger_single(ff, p["f0g"], p["Ql"], p["Qc"], p["theta"], p["a"], p["phi0"], p["tau"])
        se = _hanger_single(ff, p["f0e"], p["Ql"], p["Qc"], p["theta"], p["a"], p["phi0"], p["tau"])
        return p["pg"]*sg + (1.0 - p["pg"])*se

    keys = ["f0g","f0e","Ql","Qc","theta","a","phi0","tau","pg"]
    def pack(p):   return np.array([p[k] for k in keys], float)
    def unpack(v): return {k: float(x) for k,x in zip(keys, v)}

    v0  = pack(params0)
    vlb = pack(lb)
    vub = pack(ub)

    # -------- residual with weights (emphasize dips) --------
    # weight ~ 1/(mag_s^2) but bounded to avoid extremes
    w = 1.0/np.maximum(mag_s**2, 1e-4)
    w = w/np.max(w)
    def resid(v):
        p = unpack(v)
        s = _hanger_double_mix(f, p)
        r = np.stack([np.real(s) - np.real(y), np.imag(s) - np.imag(y)], axis=1)
        r = (r.T * w).T     # apply weights pointwise to both Re/Im
        return r.ravel()

    # -------- optimize (SciPy if present, else projected gradient) --------
    try:
        from scipy.optimize import least_squares
        res = least_squares(resid, v0, bounds=(vlb, vub), xtol=1e-12, ftol=1e-12, gtol=1e-12, max_nfev=4000)
        vf = res.x
    except Exception:
        vf = v0.copy()
        lr = 2e-6
        eps = np.array([5e3, 5e3, 50.0, 50.0, 1e-3, 1e-3, 1e-3, 1e-9, 1e-2])
        for _ in range(800):
            r0 = resid(vf); base = np.dot(r0, r0)
            g = np.zeros_like(vf)
            for k in range(len(vf)):
                vtmp = vf.copy(); vtmp[k] = np.clip(vtmp[k] + eps[k], vlb[k], vub[k])
                r2 = resid(vtmp); g[k] = (np.dot(r2,r2) - base)/eps[k]
            vf -= lr*g
            vf = np.minimum(np.maximum(vf, vlb), vub)

    p_fit = unpack(vf)
    # enforce ordering convention
    if p_fit["f0g"] > p_fit["f0e"]:
        p_fit["f0g"], p_fit["f0e"] = p_fit["f0e"], p_fit["f0g"]
        p_fit["pg"] = 1.0 - p_fit["pg"]

    # derived Qi and a simple SNR proxy
    Qi = (p_fit["Ql"]*p_fit["Qc"])/max(p_fit["Qc"] - p_fit["Ql"], 1e-6)
    rr = resid(vf).reshape(-1,2)
    noise = np.std(np.hypot(rr[:,0], rr[:,1]))
    amp = np.max(mag) - np.min(mag)
    snr = float(abs(amp)/(4*max(noise, 1e-12)))
    return p_fit, {"Qi": Qi}, snr

from scipy.signal import find_peaks  # at top once

# Fast magnitude-only dip counter (no fitting)
def _count_dips_fast(freq_Hz: np.ndarray, sig_row: np.ndarray) -> int:
    """Count prominent dips in the magnitude of one trace, without fitting.

    The magnitude is boxcar-smoothed and dips are detected as peaks of the
    negated curve.  A dip must have a prominence of at least 15 % of the
    smoothed peak-to-peak range and a minimum spacing to its neighbours.
    """
    flattened, _ = _unwrap_and_remove_linear_phase(freq_Hz, sig_row)
    magnitude = np.abs(flattened)
    n_pts = len(magnitude)

    # odd-length boxcar, at least 5 samples wide
    width = max(5, n_pts // 200)
    if width % 2 == 0:
        width += 1
    smoothed = np.convolve(magnitude, np.ones(width) / width, mode="same")

    dip_idx, _props = find_peaks(
        -smoothed,
        prominence=0.15 * np.ptp(smoothed),
        distance=max(5, n_pts // 150),
    )
    return int(len(dip_idx))

# Decide once whether to use single or double for the whole dataset
def _decide_model(f_Hz: np.ndarray, sig2d: np.ndarray,
                  sample_every: int = 10,
                  frac_single_threshold: float = 0.75) -> str:
    """Choose ``"single"`` or ``"double"`` for the whole map.

    Every ``sample_every``-th row is inspected with the fast dip counter.
    ``"single"`` is returned when the fraction of inspected rows with at most
    one dip reaches ``frac_single_threshold``; otherwise ``"double"``.
    """
    sampled_rows = np.arange(0, len(sig2d), sample_every)
    dip_counts = np.array([_count_dips_fast(f_Hz, sig2d[row]) for row in sampled_rows])
    frac_single = np.mean(dip_counts <= 1)
    if frac_single >= frac_single_threshold:
        return "single"
    return "double"

# Find linear slope. Fit y ≈ a*(x-x0) + b on [x0-halfwidth, x0+halfwidth]; return slope a
def _local_linear_slope(x, y, x0, halfwidth):
    lo, hi = x0 - halfwidth, x0 + halfwidth
    mask = (x >= lo) & (x <= hi)
    xx, yy = x[mask], y[mask]
    if len(xx) < 3:
        return np.nan
    u = xx - x0
    a, b = np.polyfit(u, yy, 1)   # slope, intercept
    return float(a)

# Symmetry score at center c: score = RMS_{j} [ y(c+Δ_j) - y(c-Δ_j) ] / scale
# where Δ_j are n_pairs offsets in (0, halfwidth], and y is linearly interpolated
def _symmetry_score(x, y, c, halfwidth, n_pairs, eps=1e-12):
    # build offsets that are available on both sides within data range
    Δmax = min(c - x[0], x[-1] - c, halfwidth)
    if Δmax <= 0:
        return np.inf
    Δ = np.linspace(Δmax/n_pairs, Δmax, n_pairs)

    yp = np.interp(c + Δ, x, y)
    ym = np.interp(c - Δ, x, y)

    r = yp - ym
    rms = np.sqrt(np.mean(r*r))

    # robust scale for normalization (MAD over data)
    med = np.median(y)
    mad = np.median(np.abs(y - med)) + eps
    return float(rms / mad)


# ------------------------------ main protocol --------------------------------

class ResonatorSpectroscopyVsFlux(ProtocolOperation):
    """
    Resonator spectroscopy vs flux with per-row hanger fitting.

    One frequency sweep is run per flux set point and the rows are stacked
    into a 2D complex map.  ``analyze`` picks one model for the whole map
    (single notch, or two-branch g/e mixture) and fits every row.

    Frequency window:
      - the global start/end readout frequency parameters are registered as
        inputs;
      - per-flux ``start_freq_list`` / ``end_freq_list`` are validated and
        stored, but the measurement loop does not apply them yet.
    """

    def __init__(
        self,
        params,
        *,
        flux_list: FluxList,
        repetitions: Optional[Repetition] = None,
        steps: Steps,
        start_freq: StartReadoutFrequency,
        end_freq: EndReadoutFrequency,
        readout_gain: ReadoutGain,
        length: ReadoutLength,
        delay: Delay,
        set_flux: Optional[Callable[[float], None]] = None,
        # optional per-flux frequency windows (MHz). If provided, must match flux_list length.
        start_freq_list: Optional[Sequence[float]] = None,
        end_freq_list:   Optional[Sequence[float]] = None,
        name: Optional[str] = None,
    ):
        """Register inputs/outputs and initialise empty data and fit caches.

        Raises
        ------
        ValueError
            If ``start_freq_list`` or ``end_freq_list`` is given with a shape
            different from the flux list.
        """
        super().__init__(name=name)

        # I/O (aligned with other ops)
        self._register_inputs(
            repetitions or Repetition(params),
            steps,
            start_freq, end_freq,
            readout_gain, length, delay,
        )
        self._register_outputs(readout_if=StartReadoutFrequency(params))

        # controls
        self.flux_values = np.asarray(list(flux_list.value), float)
        self.set_flux = set_flux
        self.start_freq_list = np.asarray(start_freq_list, float) if start_freq_list is not None else None
        self.end_freq_list   = np.asarray(end_freq_list,   float) if end_freq_list   is not None else None
        for label, window in (("start_freq_list", self.start_freq_list),
                              ("end_freq_list", self.end_freq_list)):
            if window is not None and window.shape != self.flux_values.shape:
                raise ValueError(
                    f"{label} has shape {window.shape}, expected {self.flux_values.shape} "
                    "(one entry per flux value)"
                )

        # data cache
        self.model_choice: Optional[str] = None
        self.independents = {"flux": [], "frequencies": []}
        self.dependents = {"signal": []}  # shape (Nphi, Nf) complex
        self.data_loc: Optional[Path] = None

        # fit results (one entry per flux row, filled by analyze())
        self.fr_g:  list[float] = []
        self.fr_e:  list[float] = []
        self.Ql:    list[float] = []
        self.Qc:    list[float] = []
        self.Qi:    list[float] = []
        self.pg:    list[float] = []
        self.snr:   list[float] = []

        self.figure_paths: list[Path] = []
        self.condition = "Success if dataset collected and fits converge."

    # ------------------------------- measure ---------------------------------

    def _measure_quick(self) -> Path:
        """Run one sweep per flux value, stack the rows and save a 2D dataset.

        Returns the directory that holds the combined ``data.ddh5``.

        Raises
        ------
        RuntimeError
            If the flux list is empty, so no sweep was run.
        ValueError
            If the individual sweeps do not share one frequency-grid shape.
        """
        logger.info("Starting resonator spectroscopy vs flux measurement")

        if self.start_freq_list is not None or self.end_freq_list is not None:
            logger.warning(
                "Per-flux frequency windows were supplied but are not applied "
                "by this measurement loop."
            )

        n_flux = len(self.flux_values)
        rows = []
        last_loc: Optional[Path] = None

        for i, phi in enumerate(self.flux_values):
            logger.debug("[%d/%d] Set flux → %s", i + 1, n_flux, phi)
            if self.set_flux is not None:
                self.set_flux(phi)

            sweep = FreqSweepProgram()

            # (Optional) configure program's sweep window here if your driver exposes it.
            # If your FreqSweepProgram reads parameters from Start/EndReadoutFrequency only,
            # you can instead run multiple small sweeps by setting params before calling.

            file_label = f"{self.name or 'res_spec_vs_flux'}_phi={phi:.6f}"
            loc, _ = run_and_save_sweep(sweep, "data", file_label)
            last_loc = Path(loc)

            dd = datadict_from_hdf5(last_loc / "data.ddh5")
            freqs = np.asarray(dd["freq"]["values"])    # MHz
            sig   = np.asarray(dd["signal"]["values"])  # complex, shape (Nf,)
            rows.append((phi, freqs, sig))

        if not rows:
            raise RuntimeError("No data collected.")

        # pack into a single DataDict on a common frequency grid
        freqs_ref = rows[0][1]
        sig_ref_shape = rows[0][2].shape
        for phi, freqs, sig in rows:
            if freqs.shape != freqs_ref.shape or sig.shape != sig_ref_shape:
                raise ValueError(
                    f"Sweep at flux {phi} has a different grid shape than the first sweep; "
                    "rows cannot be stacked."
                )
        flux_vals = np.array([r[0] for r in rows])
        sig_matrix = np.vstack([r[2] for r in rows])  # (Nphi, Nf)

        out = DataDict(
            freq  = dict(values=freqs_ref),
            flux  = dict(values=flux_vals),
            signal= dict(values=sig_matrix, axes=["flux", "freq"]),
        )

        # the combined file goes next to the per-flux sweep folders
        final_dir = (last_loc.parent if last_loc else Path(".")).resolve()
        final_path = final_dir / "data.ddh5"
        out.save(final_path)
        logger.info("Saved 2D dataset → %s", final_path)

        # cache
        self.independents["frequencies"] = freqs_ref
        self.independents["flux"]        = flux_vals
        self.dependents["signal"]        = sig_matrix
        self.data_loc = final_dir
        return final_dir

    def _load_data_quick(self):
        """Reload the cached arrays from ``<data_loc>/data.ddh5``.

        Raises
        ------
        RuntimeError
            If ``data_loc`` has not been set.
        """
        if self.data_loc is None:
            raise RuntimeError("No data location set; run the measurement or set data_loc first.")
        path = Path(self.data_loc) / "data.ddh5"
        dd = datadict_from_hdf5(path)
        self.independents["frequencies"] = np.asarray(dd["freq"]["values"])
        self.independents["flux"]        = np.asarray(dd["flux"]["values"])
        self.dependents["signal"]        = np.asarray(dd["signal"]["values"])

    # -------------------------------- analyze --------------------------------

    def analyze(self):
        """Fit the spectroscopy map and produce fr_g / fr_e (fr_e=nan for single)."""
        # np.size works for both the initial empty list and a cached ndarray;
        # plain truthiness raises on arrays with more than one element.
        if np.size(self.independents.get("frequencies", [])) == 0:
            self._load_data_quick()

        freqs_MHz: np.ndarray = np.asarray(self.independents["frequencies"])  # MHz
        flux_vals: np.ndarray = np.asarray(self.independents["flux"])
        sig2d: np.ndarray     = np.asarray(self.dependents["signal"])         # (Nphi, Nf)

        f_Hz = freqs_MHz * 1e6

        # Decide model once for the whole dataset
        self.model_choice = _decide_model(f_Hz, sig2d)
        logger.info("[res_spec_vs_flux] model_choice = %s", self.model_choice)

        # reset outputs so repeated calls do not accumulate rows
        for results in (self.fr_g, self.fr_e, self.Qc, self.Ql, self.Qi, self.pg, self.snr):
            results.clear()

        if self.model_choice == "single":
            # single-notch per row → keep API stable by writing fr_e = nan
            for phi, sig_row in zip(flux_vals, sig2d):
                # Same as res_spec vs gain
                ret = ResonatorSpectroscopy.add_mag_and_unwind_and_fit(
                    freqs_MHz, sig_row, label=f"Φ={phi:.6g}"
                )

                params = ret.fit_result.params  # lmfit.Parameters object

                # missing parameters become nan instead of raising
                f0 = params["f_0"].value if "f_0" in params else np.nan
                Ql = params["Q_l"].value if "Q_l" in params else np.nan
                Qc = params["Q_c"].value if "Q_c" in params else np.nan
                Qi = params["Q_i"].value if "Q_i" in params else (
                    (Ql * Qc) / (Qc - Ql) if (Qc > Ql) else np.nan
                )

                # Append results (keep structure same as double fit)
                self.fr_g.append(f0)         # in MHz
                self.fr_e.append(np.nan)     # single-notch: no excited branch
                self.Ql.append(Ql)
                self.Qc.append(Qc)
                self.Qi.append(Qi)
                self.pg.append(np.nan)
                self.snr.append(float(ret.snr))
        else:
            # double-hanger for every row
            for phi, sig_row in zip(flux_vals, sig2d):
                sig_unw, _ = _unwrap_and_remove_linear_phase(f_Hz, sig_row)
                p_fit, derived, snr = _fit_double_hanger(f_Hz, sig_unw)

                self.fr_g.append(p_fit["f0g"]/ 1e6)  # change to MHz
                self.fr_e.append(p_fit["f0e"]/ 1e6)  # change to MHz
                self.Ql.append(p_fit["Ql"])
                self.Qc.append(p_fit["Qc"])
                self.Qi.append(derived["Qi"])
                self.pg.append(p_fit["pg"])
                self.snr.append(snr)

    # -------------------------------- evaluate -------------------------------

    def evaluate(self) -> bool:
        """Simple pass/fail placeholder: all SNRs above a minimal threshold.

        Writes a short report to ``self.report_output``.  Returns ``False``
        when no fits are available.
        """
        if not self.snr:
            self.report_output = ["No fits yet. Did you call analyze()?"]
            return False
        snr_thr = 2.0
        ok = bool(np.all(np.asarray(self.snr) >= snr_thr))
        msg = (
            f"## Resonator Spectroscopy vs Flux\n"
            f"Flux points: {len(self.snr)}\n"
            f"SNR min/median/max: {np.min(self.snr):.2f} / {np.median(self.snr):.2f} / {np.max(self.snr):.2f}\n"
            f"Pass threshold: {snr_thr}\n"
        )
        self.report_output = [msg]
        return ok


    def find_zero_flux_point(
        self,
        halfwidth_sym: float | None = None,  # window half-width around candidate (same units as flux)
        n_pairs: int = 25,                    # mirrored pairs for symmetry score
        halfwidth_slope: float | None = None, # window half-width for local slope fit
        slope_tol: float | None = None,       # flatness threshold on |slope|
        score_rel_drop: float = 0.25,         # how deep a local min must be vs neighbors
        max_candidates: int = 8
    ):
        """
        Find centers where the curve (f0 in single or f_r,g in double) is most even and locally flat.
        Uses self.fr_g only (fr_g == f0 in single model).
        Returns dict with: picked (list of {center, score, slope}), candidates, status, notes.

        ``status`` is ``"bad_input"`` when the fit results do not line up with
        the flux axis, fewer than 7 finite points remain, or the flux values
        span no range.  Otherwise it is ``"ok"`` or
        ``"no_flat_even_center_found"``.
        """
        # x: flux/current; y: frequency track to use
        x = np.asarray(self.independents["flux"], float)
        y = np.asarray(self.fr_g, float)   # <-- single: f0; double: f_r,g

        if x.shape != y.shape:
            return dict(picked=[], candidates=[], status="bad_input",
                        notes=["flux and fr_g lengths differ; run analyze() first"])

        # keep finite points, ordered by flux (span and interpolation need increasing x)
        m = np.isfinite(x) & np.isfinite(y)
        x, y = x[m], y[m]
        order = np.argsort(x, kind="stable")
        x, y = x[order], y[order]
        N = len(x)
        if N < 7:
            return dict(picked=[], candidates=[], status="bad_input", notes=["too few finite points"])

        # defaults for windows
        span = x[-1] - x[0]
        if span <= 0:
            return dict(picked=[], candidates=[], status="bad_input", notes=["flux values span no range"])
        if halfwidth_sym is None:
            halfwidth_sym = 0.10 * span
        if halfwidth_slope is None:
            halfwidth_slope = 0.05 * span

        # data-driven slope tolerance if not provided
        if slope_tol is None:
            med_abs_slope = np.nanmedian(np.abs(np.gradient(y, x)))
            slope_tol = 0.25 * (med_abs_slope if np.isfinite(med_abs_slope) else 1.0)

        # symmetry score at each x[k]
        scores = np.array(
            [_symmetry_score(x, y, x[k], halfwidth_sym, n_pairs) for k in range(N)], float
        )

        # a candidate must undercut its 5-point neighbourhood mean by score_rel_drop
        winsize = 5
        halfw = winsize // 2
        candidates = []
        for k in range(halfw, N - halfw):
            local_avg = float(np.mean(scores[k - halfw:k + halfw + 1]))
            if scores[k] <= (1.0 - score_rel_drop) * local_avg:
                c = float(x[k])
                slope = _local_linear_slope(x, y, c, halfwidth_slope)
                candidates.append(dict(center=c, score=float(scores[k]), slope=float(slope)))

        # keep best few by score, then enforce flatness
        candidates = sorted(candidates, key=lambda d: d["score"])[:max_candidates]
        picked = [c for c in candidates if np.isfinite(c["slope"]) and abs(c["slope"]) <= slope_tol]

        notes = [
            f"model={self.model_choice or 'unknown'} (using fr_g: f0 in single, f_r,g in double)",
            f"halfwidth_sym={halfwidth_sym:.3g}",
            f"halfwidth_slope={halfwidth_slope:.3g}",
            f"slope_tol={slope_tol:.3g}",
            f"score_rel_drop={score_rel_drop}",
            f"n_pairs={n_pairs}",
            f"span={span:.3g}, N={N}",
        ]
        status = "ok" if picked else "no_flat_even_center_found"
        return dict(picked=picked, candidates=candidates, status=status, notes=notes)



In [None]:
# Test cell for the protocol above
params = ...  # given by initial guess modules and should contain steps, delay, gain
fluxes = FluxList(name="PhiPhi0", value=np.linspace(0, 2*np.pi, 101))

def set_flux(phi):
    """Placeholder hook: send phi to your DAC/flux-bias line."""
    ...

op = ResonatorSpectroscopyVsFlux(
    params,
    flux_list=fluxes,
    repetitions=Repetition(params),
    steps=Steps(params),
    start_freq=StartReadoutFrequency(params),
    end_freq=EndReadoutFrequency(params),
    readout_gain=ReadoutGain(params),
    length=ReadoutLength(params),
    delay=Delay(params),
    set_flux=set_flux,
    # optional tighter per-flux scan windows:
    # start_freq_list=[...], end_freq_list=[...],
    name="res_spec_vs_flux"
)

data_dir = op._measure_quick()                # collect sweeps → data.ddh5 (2D)
op.analyze()                                  # pick single/double model, then fit every flux row
ok = op.evaluate()                            # QA gate; see op.report_output[0] (may want to change SNR threshold)
# The method must be called: the bare attribute is a bound method and the
# subscripts below would fail on it.
zero_flux_point = op.find_zero_flux_point()   # find zero flux points
print(op.fr_g[:5], op.fr_e[:5], ok)
print(zero_flux_point["status"])
for c in zero_flux_point["picked"]:
    print(f"center ≈ {c['center']:.6g}, score={c['score']:.3e}, slope={c['slope']:.3e}")


# Test only fits function for double hanger model
import logging
from pathlib import Path
from dataclasses import dataclass
from typing import Iterable, Optional, Callable, Sequence, Tuple, Dict

import numpy as np
import matplotlib.pyplot as plt
@dataclass
class FluxList:
    """Named collection of flux-bias set points."""
    # Label identifying the flux axis.
    name: str
    # Flux-bias values; any iterable of numbers.
    value: Iterable[float]

def _unwrap_and_remove_linear_phase(freq: np.ndarray, sig: np.ndarray) -> Tuple[np.ndarray, float]:
    """Unwrap phase and remove best-fit linear slope (cable delay)."""
    # unwrap and linear fit: angle ≈ φ0 + 2π τ f, so it becomes continuous
    ph = np.unwrap(np.angle(sig))
    slope = np.polyfit(freq, ph, 1)[0]
    sig_unw = sig * np.exp(-1j * slope * freq)
    return sig_unw, slope/(2*np.pi)  # return τ estimate in seconds if freq in Hz

def _hanger_single(f: np.ndarray, f0, Ql, Qc, theta, a, phi0, tau) -> np.ndarray:
    x = (f - f0) / f0
    notch = 1.0 - (Ql/Qc) * np.exp(1j*theta) / (1.0 + 2j*Ql*x)
    return a * np.exp(1j*(phi0 + 2*np.pi*f*tau)) * notch

def _hanger_double_mix(f: np.ndarray, params: Dict[str, float]) -> np.ndarray:
    """Weighted sum of two hanger responses that share every line parameter.

    ``params`` must provide ``f0g, f0e, Ql, Qc, theta, a, phi0, tau, pg``.
    The response centred at ``f0g`` is weighted by ``pg`` and the one centred
    at ``f0e`` by ``1 - pg``.
    """
    centre_g = params["f0g"]
    centre_e = params["f0e"]
    # parameters common to both branches, in _hanger_single argument order
    shared = tuple(params[key] for key in ("Ql", "Qc", "theta", "a", "phi0", "tau"))
    weight_g = params["pg"]

    branch_g = _hanger_single(f, centre_g, *shared)
    branch_e = _hanger_single(f, centre_e, *shared)
    return weight_g * branch_g + (1.0 - weight_g) * branch_e

def _fit_double_hanger(freq: np.ndarray, sig_cplx: np.ndarray):
    f = np.asarray(freq, float)
    yC = np.asarray(sig_cplx, complex)

    # -------- de-embed linear phase (done in your pipeline, but safe here) --------
    ph = np.unwrap(np.angle(yC))
    b1, b0 = np.polyfit(f, ph, 1)         # ph ≈ b0 + b1 f
    y = yC * np.exp(-1j * (b0 + b1*f))    # unwrapped/de-sloped
    # NOTE: we fit in this frame; we won't refit tau here.

    # -------- initial guesses via local minima on smoothed |S| --------
    mag = np.abs(y)
    win = max(5, len(f)//80)
    mag_s = np.convolve(mag, np.ones(win)/win, mode="same")

    # local minima: indices i with mag_s[i] < neighbors
    locmins = np.where((mag_s[1:-1] < mag_s[:-2]) & (mag_s[1:-1] < mag_s[2:]))[0] + 1
    if len(locmins) == 0:
        locmins = np.array([int(np.argmin(mag_s))])
    # sort by depth
    locmins = locmins[np.argsort(mag_s[locmins])]

    # pick the deepest as ground; the closest *distinct* other min as excited
    i0 = locmins[0]
    # distinct = at least a few bins away
    min_sep_bins = max(5, len(f)//200)
    candidates = [j for j in locmins[1:] if abs(j - i0) >= min_sep_bins]
    j0 = candidates[0] if candidates else (i0 + min_sep_bins if i0 + min_sep_bins < len(f) else i0 - min_sep_bins)

    f0g0, f0e0 = np.sort([f[i0], f[j0]])

    # -------- line-term seeds (gentle) --------
    a0 = float(np.median(mag))
    theta0 = 0.0
    phi00 = 0.0
    tau0 = 0.0
    Ql0, Qc0 = 1.5e4, 3.0e4
    pg0 = 0.7

    # -------- bounds with a narrow f0e window around its seed --------
    # estimate natural linewidth ~ f0 / Ql0; allow a wide multiple
    linewidth = f0g0 / Ql0
    df_win = max(10.0*linewidth, 1.5e6)   # e.g. +/- 1.5 MHz minimum window
    lb = dict(f0g=f.min(),  f0e=f0e0 - 4*df_win, Ql=500,  Qc=500,
              theta=-np.pi, a=0.0, phi0=-np.pi, tau=-1e-6, pg=0.0)
    ub = dict(f0g=f.max(),  f0e=f0e0 + 4*df_win, Ql=5e5, Qc=5e6,
              theta= np.pi, a=10*np.max(mag), phi0= np.pi, tau= 1e-6, pg=1.0)

    params0 = dict(f0g=f0g0, f0e=f0e0, Ql=Ql0, Qc=Qc0, theta=theta0, a=a0, phi0=phi00, tau=tau0, pg=pg0)

    # -------- model & packing --------
    def _hanger_single(ff, f0, Ql, Qc, theta, a, phi0, tau):
        x = (ff - f0)/f0
        notch = 1.0 - (Ql/Qc)*np.exp(1j*theta)/(1.0 + 2j*Ql*x)
        return a*np.exp(1j*(phi0 + 2*np.pi*ff*tau))*notch

    def _hanger_double_mix(ff, p):
        sg = _hanger_single(ff, p["f0g"], p["Ql"], p["Qc"], p["theta"], p["a"], p["phi0"], p["tau"])
        se = _hanger_single(ff, p["f0e"], p["Ql"], p["Qc"], p["theta"], p["a"], p["phi0"], p["tau"])
        return p["pg"]*sg + (1.0 - p["pg"])*se

    keys = ["f0g","f0e","Ql","Qc","theta","a","phi0","tau","pg"]
    def pack(p):   return np.array([p[k] for k in keys], float)
    def unpack(v): return {k: float(x) for k,x in zip(keys, v)}

    v0  = pack(params0)
    vlb = pack(lb)
    vub = pack(ub)

    # -------- residual with weights (emphasize dips) --------
    # weight ~ 1/(mag_s^2) but bounded to avoid extremes
    w = 1.0/np.maximum(mag_s**2, 1e-4)
    w = w/np.max(w)
    def resid(v):
        p = unpack(v)
        s = _hanger_double_mix(f, p)
        r = np.stack([np.real(s) - np.real(y), np.imag(s) - np.imag(y)], axis=1)
        r = (r.T * w).T     # apply weights pointwise to both Re/Im
        return r.ravel()

    # -------- optimize (SciPy if present, else projected gradient) --------
    try:
        from scipy.optimize import least_squares
        res = least_squares(resid, v0, bounds=(vlb, vub), xtol=1e-12, ftol=1e-12, gtol=1e-12, max_nfev=4000)
        vf = res.x
    except Exception:
        vf = v0.copy()
        lr = 2e-6
        eps = np.array([5e3, 5e3, 50.0, 50.0, 1e-3, 1e-3, 1e-3, 1e-9, 1e-2])
        for _ in range(800):
            r0 = resid(vf); base = np.dot(r0, r0)
            g = np.zeros_like(vf)
            for k in range(len(vf)):
                vtmp = vf.copy(); vtmp[k] = np.clip(vtmp[k] + eps[k], vlb[k], vub[k])
                r2 = resid(vtmp); g[k] = (np.dot(r2,r2) - base)/eps[k]
            vf -= lr*g
            vf = np.minimum(np.maximum(vf, vlb), vub)

    p_fit = unpack(vf)
    # enforce ordering convention
    if p_fit["f0g"] > p_fit["f0e"]:
        p_fit["f0g"], p_fit["f0e"] = p_fit["f0e"], p_fit["f0g"]
        p_fit["pg"] = 1.0 - p_fit["pg"]

    # derived Qi and a simple SNR proxy
    Qi = (p_fit["Ql"]*p_fit["Qc"])/max(p_fit["Qc"] - p_fit["Ql"], 1e-6)
    rr = resid(vf).reshape(-1,2)
    noise = np.std(np.hypot(rr[:,0], rr[:,1]))
    amp = np.max(mag) - np.min(mag)
    snr = float(abs(amp)/(4*max(noise, 1e-12)))
    return p_fit, {"Qi": Qi}, snr

import numpy as np
import matplotlib.pyplot as plt

# Read the arrays inside a context manager so the .npz archive is closed
# once they have been pulled out.
with np.load("s12_fixed_flux_double_circle.npz") as d:
    f = d["freqs_Hz"]
    Sg, Se, Smix = d["Sg"], d["Se"], d["Smix"]

# Complex-plane circles
plt.figure(figsize=(6,6))
plt.plot(Sg.real, Sg.imag, label="Sg(f)")
plt.plot(Se.real, Se.imag, label="Se(f)")
plt.plot(Smix.real, Smix.imag, label="Smix(f)", linewidth=2)
plt.gca().set_aspect("equal", "box")
plt.xlabel("Re{S12}")
plt.ylabel("Im{S12}")
plt.legend()
plt.tight_layout()
plt.show()

# Magnitude/phase vs frequency for the mixture
plt.figure()
plt.plot(f/1e9, np.abs(Smix))
plt.xlabel("GHz")
plt.ylabel("|S12|")
plt.show()

plt.figure()
plt.plot(f/1e9, np.unwrap(np.angle(Smix)))
plt.xlabel("GHz")
plt.ylabel("arg(S12)")
plt.show()


import numpy as np
import matplotlib.pyplot as plt

# ------------------------------------------------------------
# 1) Load the fixed-flux dataset (Smix is the thing to fit)
# ------------------------------------------------------------
# The context manager closes the .npz archive after the arrays are read.
with np.load("s12_fixed_flux_double_circle.npz") as data:   # adjust path if needed
    f_Hz  = data["freqs_Hz"]              # (Nf,)
    Smix  = data["Smix"]                  # (Nf,), complex

# ------------------------------------------------------------
# 2) Same pre-fit helper you used elsewhere
# ------------------------------------------------------------
def _unwrap_and_remove_linear_phase(freq, sig):
    ph = np.unwrap(np.angle(sig))
    b1, b0 = np.polyfit(freq, ph, 1)          # ph ~ b0 + b1*f
    sig_unw = sig * np.exp(-1j * (b0 + b1*freq))
    tau_est = b1 / (2*np.pi)                  # seconds
    return sig_unw, tau_est

# ------------------------------------------------------------
# 3) Import your fitter (make sure it's on PYTHONPATH)
#    from your_module import _fit_double_hanger
#    (If it's already defined in your notebook, skip this import.)
# ------------------------------------------------------------
# from resonator_fit.double_hanger import _fit_double_hanger

# ------------------------------------------------------------
# 4) Fit
# ------------------------------------------------------------
sig_unw, tau_est = _unwrap_and_remove_linear_phase(f_Hz, Smix)
p_fit, derived, snr = _fit_double_hanger(f_Hz, sig_unw)

fr_g = p_fit["f0g"]     # Hz
fr_e = p_fit["f0e"]     # Hz

print(f"fr_g  = {fr_g:.3f} Hz  ({fr_g/1e9:.6f} GHz)")
print(f"fr_e  = {fr_e:.3f} Hz  ({fr_e/1e9:.6f} GHz)")
print(f"Ql    = {p_fit['Ql']:.3f},  Qc = {p_fit['Qc']:.3f},  Qi (derived) = {derived['Qi']:.3f}")
print(f"pg    = {p_fit['pg']:.3f},  theta = {p_fit['theta']:.3f}, a = {p_fit['a']:.6f}")
print(f"tau (unwrap est) ≈ {tau_est*1e9:.3f} ns,  SNR proxy = {snr:.3f}")


# model generated from fitted params in the *unwrapped* frame
S_model_unw = _hanger_double_mix(f_Hz, p_fit)

# Align a global complex gain by least squares against the de-rotated data.
# A ratio of means is unreliable here: the mean of the raw trace shrinks
# whenever its phase winds across the band, which would mis-scale the overlay.
model_power = np.vdot(S_model_unw, S_model_unw)
scale = np.vdot(S_model_unw, sig_unw) / model_power if model_power != 0 else 1.0
S_model = S_model_unw * scale

plt.figure(figsize=(7.5,4.5))
plt.plot(f_Hz/1e9, np.abs(Smix),  label="data |S12|", alpha=0.7)
plt.plot(f_Hz/1e9, np.abs(S_model), label="fit  |S12|", lw=2)
plt.xlabel("Frequency (GHz)")
plt.ylabel("|S12|")
plt.legend()
plt.tight_layout()
plt.show()