In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
from scipy.optimize import least_squares

# === 1. Load Data ===
# Read the spectra workbook. Header whitespace is stripped so that later
# lookups by column name ('Wavenumber', curve names) match exactly.
df = pd.read_excel("./ir_experiment.xlsx")
df.columns = df.columns.str.strip()
# Every column except the x axis ('Wavenumber') and 'FA' is offered as a curve.
abs_cols = list(filter(lambda col: col not in ('Wavenumber', 'FA'), df.columns))
# Shared x axis used by plotting and fitting.
x = df['Wavenumber'].values

# === 2. Model Functions ===
def gaussian(x, c, w, a):
    """Evaluate an unnormalised Gaussian peak.

    Parameters
    ----------
    x : float or ndarray
        Abscissa value(s); arrays are evaluated element-wise.
    c : float
        Peak centre.
    w : float
        Width (the sigma of the Gaussian); must be non-zero.
    a : float
        Peak height, reached at ``x == c``.

    Returns
    -------
    float or ndarray
        ``a * exp(-(x - c)**2 / (2 * w**2))``.
    """
    offset = x - c
    return a * np.exp(-offset**2 / (2 * w**2))

def multi_with_baseline(x, *p):
    """Sum of Gaussian peaks on top of a straight-line baseline.

    ``p`` is laid out as ``(c1, w1, a1, ..., cN, wN, aN, b0, b1)``: consecutive
    (centre, width, amplitude) triples followed by the baseline intercept
    ``b0`` and slope ``b1``. The number of peaks is ``(len(p) - 2) // 3``.

    Returns ``sum_i gaussian(x, ci, wi, ai) + b0 + b1 * x``.
    """
    n_peaks = (len(p) - 2) // 3
    total = 0
    for k in range(n_peaks):
        total = total + gaussian(x, *p[3 * k:3 * k + 3])
    intercept, slope = p[-2], p[-1]
    return total + intercept + slope * x

def residuals_var(vp, x, y, vidx, full):
    """Return fit residuals when only some model parameters are free.

    ``full`` is a template parameter vector for ``multi_with_baseline``. It is
    copied with ``.copy()``, and the entries at positions ``vidx`` are
    overwritten, in order, by the values in ``vp``. The return value is
    ``multi_with_baseline(x, *params) - y``.
    """
    trial = full.copy()
    for slot, position in enumerate(vidx):
        trial[position] = vp[slot]
    model = multi_with_baseline(x, *trial)
    return model - y

# === 3. Shared Controls ===
control_layout = widgets.Layout(width='360px')
style = {"description_width": "initial"}

curve_dd = widgets.Dropdown(options=abs_cols, description="Curve:", layout=control_layout, style=style)


def _axis_slider(label, value, lo, hi):
    """Build one of the 4-decimal sliders that set the plot/fit window."""
    return widgets.FloatSlider(value=value, min=lo, max=hi, step=0.0001,
                               description=label, layout=control_layout,
                               readout_format='.4f')


# x window (also the fit range) and y display limits.
min_w = _axis_slider("X min", 400, 400, 4000)
max_w = _axis_slider("X max", 4000, 400, 4000)
min_a = _axis_slider("Y min", 0, 0, 15000)
max_a = _axis_slider("Y max", 15000, 0, 15000)


def _preset_x(_):
    """Button handler: set the x window to 1500-1800 (minimum first)."""
    min_w.value = 1500
    max_w.value = 1800


def _preset_y(_):
    """Button handler: set the y limits to 3000-11000 (minimum first)."""
    min_a.value = 3000
    max_a.value = 11000


btn_set_x = widgets.Button(description="Set X 1500–1800", layout=widgets.Layout(width='140px'))
btn_set_x.on_click(_preset_x)
btn_set_y = widgets.Button(description="Set Y 3000–11000", layout=widgets.Layout(width='140px'))
btn_set_y.on_click(_preset_y)

# === 4. Manual Fit Controls ===
Nmax = 5
peak_count = widgets.ToggleButtons(options=list(range(1, Nmax + 1)),
                                   description="Peaks:", layout=widgets.Layout(width='auto'),
                                   style=style)

# (label prefix, min, max, initial value) for each peak's centre, width and amplitude.
_PEAK_PARAM_SPECS = (("C", 400, 4000, 1600), ("W", 1, 1000, 100), ("A", 0, 15000, 1000))


def _param_slider(label, value, lo, hi):
    """Create a 4-decimal parameter slider using the shared control layout."""
    return widgets.FloatSlider(value=value, min=lo, max=hi, step=0.0001,
                               description=label, layout=control_layout,
                               readout_format='.4f')


def _lock_button(tooltip):
    """Create the narrow icon-only lock toggle shown beside a slider."""
    return widgets.ToggleButton(value=False, icon='lock', description='',
                                tooltip=tooltip, layout=widgets.Layout(width='40px'))


# Flat per-peak lists: indices 3*i, 3*i+1, 3*i+2 hold peak i+1's centre,
# width and amplitude slider (and the matching lock toggle).
sliders, locks = [], []
# ratios[i] is peak i+1's amplitude divided by peak 1's amplitude; slot 0
# stays None because peak 1 is the reference.
ratios, ratio_locks = [None] * Nmax, [None] * Nmax

for peak in range(Nmax):
    number = peak + 1
    for prefix, lo, hi, start in _PEAK_PARAM_SPECS:
        sliders.append(_param_slider(f"{prefix}{number}", start, lo, hi))
        locks.append(_lock_button(f"Lock {prefix}{number}"))
    if peak:
        initial_ratio = sliders[3 * peak + 2].value / sliders[2].value
        ratios[peak] = _param_slider(f"R{number}", initial_ratio, 0.1, 10)
        ratio_locks[peak] = _lock_button(f"Lock R{number}")

# Linear baseline b0 + b1 * x.
baseline_b0 = _param_slider("b0", 0, -10000, 10000)
lock_b0 = _lock_button("Lock b0")
baseline_b1 = _param_slider("b1", 0, -100, 100)
lock_b1 = _lock_button("Lock b1")

manual_box = widgets.VBox(layout=widgets.Layout(margin='10px 0 0 0'))


def update_manual(_=None):
    """Rebuild ``manual_box`` for the number of peaks selected in ``peak_count``.

    Each visible peak contributes three slider+lock rows (C, W, A). Every peak
    after the first also gets its ratio row. The two baseline rows come last.
    The optional argument (the observer's change notification) is ignored.
    """
    rows = []
    for peak in range(peak_count.value):
        first = 3 * peak
        rows.extend(widgets.HBox([sliders[k], locks[k]]) for k in range(first, first + 3))
        if peak:
            rows.append(widgets.HBox([ratios[peak], ratio_locks[peak]]))
    rows.append(widgets.HBox([baseline_b0, lock_b0]))
    rows.append(widgets.HBox([baseline_b1, lock_b1]))
    manual_box.children = tuple(rows)


peak_count.observe(update_manual, names='value')
update_manual()

# === 5. Safe Ratio-Linked Callback ===
# Re-entrancy guard: while on_amp_change writes amplitude sliders itself, the
# notifications those writes trigger must not start another round of updates.
ratio_updating = False


def on_amp_change(change):
    """Keep ratio-locked peak amplitudes consistent when an amplitude slider moves.

    Only amplitude sliders (``sliders[3*i + 2]``) are acted on:

    * Moving the first peak's amplitude A1 sets every visible ratio-locked peak j
      to ``A1 * R_j``.
    * Moving the amplitude of a later, ratio-locked peak i back-computes
      ``A1 = A_i / R_i``. It writes that to the first peak, then rescales the
      other visible ratio-locked peaks.
    * Moving a later peak whose ratio is not locked does nothing here.

    Parameters
    ----------
    change
        Observer notification with ``name``, ``owner`` and ``new`` attributes.
    """
    global ratio_updating
    if change.name != 'value' or ratio_updating:
        return
    idx = sliders.index(change.owner)
    if idx % 3 != 2:
        return
    peak = idx // 3
    n_visible = peak_count.value
    if peak > 0 and not ratio_locks[peak].value:
        return
    ratio_updating = True
    try:
        if peak == 0:
            a1 = change.new
        else:
            # Ratio sliders have min 0.1, so this division is always defined.
            a1 = change.new / ratios[peak].value
            sliders[2].value = a1
        for j in range(1, n_visible):
            if ratio_locks[j].value and j != peak:
                sliders[3 * j + 2].value = a1 * ratios[j].value
    finally:
        # Always release the guard. If a write above raises, leaving it True
        # would make every later amplitude change silently skip ratio linking.
        ratio_updating = False


for i in range(Nmax):
    sliders[3 * i + 2].observe(on_amp_change, names='value')

# === 6. Plotting ===
plot_out = widgets.Output(layout=widgets.Layout(width='800px', height='600px'))


def redraw(_=None):
    """Re-render the selected curve, each manual peak, the baseline and their sum.

    Only samples inside the X min/X max window are drawn. For peaks after the
    first whose ratio is locked, the amplitude drawn is ``A1 * R_i`` instead of
    that peak's own slider value. The optional argument (the observer's change
    notification) is ignored.
    """
    with plot_out:
        clear_output(wait=True)
        curve = df[curve_dd.value].values
        lo, hi = min_w.value, max_w.value
        in_window = (x >= lo) & (x <= hi)
        x_win, y_win = x[in_window], curve[in_window]

        n_peaks = peak_count.value
        params = []
        for k in range(n_peaks):
            centre, width, amp = (s.value for s in sliders[3 * k:3 * k + 3])
            if k and ratio_locks[k].value:
                amp = params[2] * ratios[k].value
            params.extend((centre, width, amp))
        params.extend((baseline_b0.value, baseline_b1.value))
        intercept, slope = params[-2], params[-1]

        _, ax = plt.subplots(figsize=(12, 8))
        ax.plot(x_win, y_win, label='Exp', lw=2)
        for k in range(n_peaks):
            ax.plot(x_win, gaussian(x_win, *params[3 * k:3 * k + 3]), '--', label=f'P{k+1}')
        ax.plot(x_win, intercept + slope * x_win, ':', label='Baseline')
        ax.plot(x_win, multi_with_baseline(x_win, *params), '-', label='Total')
        ax.set_xlim(lo, hi)
        ax.set_ylim(min_a.value, max_a.value)
        ax.legend()
        plt.show()


# Any control change triggers a full redraw.
watchers = ([curve_dd, min_w, max_w, min_a, max_a, peak_count, baseline_b0, baseline_b1]
            + sliders
            + [r for r in ratios if r is not None])
for w in watchers:
    w.observe(redraw, names='value')
redraw()

# === 7. Auto Fit + Full Summary ===
auto_count = widgets.ToggleButtons(options=list(range(1, Nmax + 1)),
                                   description="Auto Peaks:", layout=widgets.Layout(width='auto'),
                                   style=style)

# Default auto-fit bounds on every peak's centre (c), width (w) and amplitude (a).
_AUTO_BOUND_DEFAULTS = (('cmin', 400), ('cmax', 4000), ('wmin', 1), ('wmax', 1000),
                        ('amin', 0), ('amax', 15000))
bounds_w = {}
for _key, _default in _AUTO_BOUND_DEFAULTS:
    bounds_w[_key] = widgets.FloatText(value=_default, description=_key, layout=control_layout)

export = widgets.Text(value="fit_results.csv", description="CSV Path:", layout=control_layout)
auto_btn = widgets.Button(description="Run Auto Fit", button_style='success')
out_sum = widgets.Output()
# Latest one-row auto-fit summary; empty until a fit succeeds.
df_sum = pd.DataFrame()

def run_auto(_=None):
    N = auto_count.value
    y = df[curve_dd.value].values
    xf, yf = x[(x>=min_w.value)&(x<=max_w.value)], y[(x>=min_w.value)&(x<=max_w.value)]
    if len(xf) < N+2:
        with out_sum: clear_output(); print("❌ X-range too narrow"); return
    centers = np.linspace(xf.min(), xf.max(), N+2)[1:-1]
    p0, lb, ub = [], [], []
    for i, cg in enumerate(centers):
        idx0 = 3*i
        c0 = sliders[idx0].value if locks[idx0].value else cg
        w0 = sliders[idx0+1].value if locks[idx0+1].value else (bounds_w['wmin'].value + bounds_w['wmax'].value)/2
        a0 = np.max(yf)/N
        if locks[idx0+2].value:
            a0 = sliders[idx0+2].value
        if i > 0 and ratio_locks[i].value:
            a0 = sliders[2].value * ratios[i].value
        p0 += [c0, w0, a0]
        lb += [bounds_w['cmin'].value, bounds_w['wmin'].value, bounds_w['amin'].value]
        ub += [bounds_w['cmax'].value, bounds_w['wmax'].value, bounds_w['amax'].value]
    p0 += [baseline_b0.value, baseline_b1.value]; lb += [-10000, -100]; ub += [10000, 100]
    vidx = [i for i in range(3*N) if not locks[i].value]
    if not lock_b0.value: vidx.append(3*N)
    if not lock_b1.value: vidx.append(3*N+1)
    p0_var = [p0[i] for i in vidx]
    lb_var  = [lb[i] for i in vidx]; ub_var = [ub[i] for i in vidx]
    try:
        res = least_squares(residuals_var, p0_var, bounds=(lb_var, ub_var),
                            args=(xf, yf, vidx, p0), max_nfev=10000)
        full = p0.copy()
        for k, idx in enumerate(vidx):
            full[idx] = res.x[k]
        peak_count.value = N
        for i in range(N):
            for j in range(3):
                idx = 3*i + j
                if not locks[idx].value:
                    sliders[idx].value = full[idx]
        for i in range(1, N):
            if not ratio_locks[i].value:
                ratios[i].value = full[3*i+2] / full[2]
        if not lock_b0.value: baseline_b0.value = full[-2]
        if not lock_b1.value: baseline_b1.value = full[-1]
        # full summary
        d = {'N': N, 'RSS': np.sum(res.fun**2)}
        for i in range(N):
            d[f'peak{i+1}_center'] = full[3*i]
            d[f'peak{i+1}_width']  = full[3*i+1]
            d[f'peak{i+1}_amp']    = full[3*i+2]
        d['b0'], d['b1'] = full[-2], full[-1]
        global df_sum; df_sum = pd.DataFrame([d])
        with out_sum: clear_output(); display(df_sum)
    except Exception as e:
        with out_sum: clear_output(); print("❌ Auto fit error:", e)

auto_btn.on_click(run_auto)
save_btn = widgets.Button(description="Save CSV")


def _save_summary(_=None):
    """Button handler: write the latest auto-fit summary ``df_sum`` to ``export.value``.

    Messages are printed inside ``out_sum``, because plain ``print`` calls in a
    widget callback are not shown in the notebook. Saving before any
    successful fit, or to an unwritable path, is reported instead of silently
    writing an empty file or failing unseen.
    """
    path = export.value
    with out_sum:
        if df_sum.empty:
            print("❌ Nothing to save: run the auto fit first")
            return
        try:
            df_sum.to_csv(path, index=False)
        except OSError as e:
            print("❌ Could not save CSV:", e)
            return
        print(f"📄 Saved to {path}")


save_btn.on_click(_save_summary)

# === 8. Layout & Display ===
# Left column, top to bottom: curve selector, plot-window sliders with their
# preset buttons, manual peak editor, then auto-fit controls and results.
_window_rows = [
    widgets.HBox([min_w, btn_set_x]),
    widgets.HBox([max_w, btn_set_y]),
    widgets.HBox([min_a]),
    widgets.HBox([max_a]),
]
_auto_rows = [auto_count, *bounds_w.values(), export, auto_btn, save_btn, out_sum]
controls = widgets.VBox([curve_dd, *_window_rows, peak_count, manual_box, *_auto_rows],
                        layout=widgets.Layout(width='380px'))

# Controls on the left, live plot on the right.
ui = widgets.HBox([controls, plot_out], layout=widgets.Layout(margin='10px'))
display(ui)


HBox(children=(VBox(children=(Dropdown(description='Curve:', layout=Layout(width='360px'), options=('0 wt', '1…