In [45]:
# -*- coding: utf-8 -*-
import math
from typing import List, Tuple, Dict, Any
import pandas as pd
from arc import Caesium, Rubidium

State = Tuple[int, int, float]  # (n, l, j)

# -------------------------
# Helpers
# -------------------------
def make_atom(name: str):
    return Caesium() if name.lower() in ("cs", "caesium", "cesium") else Rubidium()

def valid_lj(l: int, j: float) -> bool:
    return abs(j - (l + 0.5)) < 1e-9 or abs(j - (l - 0.5)) < 1e-9

def energy_eV(atom, s):
    try:
        return atom.getEnergy(*s)  # eV
    except Exception:
        return None

def is_downward(atom, upper: State, lower: State) -> bool:
    Eu = energy_eV(atom, upper)
    El = energy_eV(atom, lower)
    return (Eu is not None) and (El is not None) and (El < Eu)

def trans_rate(atom, upper: State, lower: State, temperature_K: float = 0.0) -> float:
    """ARC transition rate A_{upper->lower} [s^-1], includes BBR at T if provided."""
    n1, l1, j1 = upper
    n2, l2, j2 = lower
    if not (valid_lj(l1, j1) and valid_lj(l2, j2)):
        return 0.0
    try:
        A = atom.getTransitionRate(n1, l1, j1, n2, l2, j2, temperature=temperature_K)
        return max(0.0, A)
    except Exception:
        return 0.0

def trans_wavelength_m(atom, upper: State, lower: State) -> float:
    n1, l1, j1 = upper
    n2, l2, j2 = lower
    try:
        lam = atom.getTransitionWavelength(n1, l1, j1, n2, l2, j2)  # [m], sign may encode direction
        return abs(lam)
    except Exception:
        return float("nan")

def trans_freq_hz(atom, upper: State, lower: State) -> float:
    n1, l1, j1 = upper
    n2, l2, j2 = lower
    try:
        f = atom.getTransitionFrequency(n1, l1, j1, n2, l2, j2)  # [Hz]
        return abs(f)
    except Exception:
        return float("nan")

def e1_downward_neighbors(atom, upper: State,
                          n_floor: int = 1,
                          n_ceiling: int = None) -> List[State]:
    """
    Generate E1-allowed *downward* neighbors from 'upper'.
    Ensures n' > l' (i.e., n_min_allowed = max(n_floor, l'+1)).
    """
    n, l, j = upper
    if n_ceiling is None:
        n_ceiling = max(n + 40, 40)

    cands: List[State] = []
    for lp in (l - 1, l + 1):
        if lp < 0:
            continue
        for jp in (lp - 0.5, lp + 0.5):
            if jp <= 0 or not valid_lj(lp, jp):
                continue
            # ★ 핵심: n' 하한은 반드시 l'+1 이상이어야 함
            n_min_allowed = max(n_floor, int(lp) + 1)
            for np in range(n_min_allowed, n_ceiling + 1):
                lower = (np, lp, jp)
                if is_downward(atom, upper, lower):
                    cands.append(lower)

    # de-duplicate
    uniq, seen = [], set()
    for s in cands:
        if s not in seen:
            seen.add(s); uniq.append(s)
    return uniq

# -------------------------
# Core API
# -------------------------
def channel_lifetime(
    atom_name: str,
    upper: State,
    lower: State,
    temperature_K: float = 0.0,
    n_floor: int = 1,
    n_ceiling: int | None = None,
) -> Tuple[Dict[str, Any], pd.DataFrame]:
    """
    Compute:
      - A_ij: transition rate upper->lower [s^-1] at T
      - A_total: total decay rate out of upper (sum over all downward E1 channels) [s^-1]
      - tau_total = 1/A_total  (mean time to first decay event)
      - branching_ratio = A_ij / A_total
      - tau_channel_naive = 1/A_ij  (not physical 'first-photon time', but useful)
    Also returns a DataFrame of all downward channels from 'upper'.
    """
    atom = make_atom(atom_name)

    # Basic sanity
    if not (valid_lj(upper[1], upper[2]) and valid_lj(lower[1], lower[2])):
        raise ValueError("Invalid (l,j) pair: j must be l±1/2 for both upper and lower.")

    if not is_downward(atom, upper, lower):
        raise ValueError("Given lower state is not below upper (E_lower >= E_upper).")

    # Specific channel
    A_ij = trans_rate(atom, upper, lower, temperature_K=temperature_K)

    # Enumerate all downward channels & sum
    neigh = e1_downward_neighbors(atom, upper, n_floor=n_floor, n_ceiling=n_ceiling)
    rows = []
    A_total = 0.0
    for s2 in neigh:
        A = trans_rate(atom, upper, s2, temperature_K=temperature_K)
        if A <= 0.0:
            continue
        A_total += A
        lam = trans_wavelength_m(atom, upper, s2)
        f_hz = trans_freq_hz(atom, upper, s2)
        rows.append({
            "lower": f"{s2[0]}{'SPDFGH'[s2[1]] if s2[1] < 6 else 'L'+str(s2[1])}_{int(2*s2[2])}/2",
            "lambda_nm": lam*1e9 if math.isfinite(lam) else float("nan"),
            "A_s^-1": A,  # Einstein A
        })

    df = pd.DataFrame(rows)
    if not df.empty:
        df["branching_ratio"] = df["A_s^-1"] / A_total if A_total > 0 else 0.0
        df.sort_values("A_s^-1", ascending=False, inplace=True, ignore_index=True)

    tau_total = (1.0 / A_total) if A_total > 0 else float("inf")
    tau_channel_naive = (1.0 / A_ij) if A_ij > 0 else float("inf")
    branching_ratio = (A_ij / A_total) if A_total > 0 else 0.0

    result = {
        "upper": upper,
        "lower": lower,
        "temperature_K": temperature_K,
        "A_ij_s^-1": A_ij,
        "A_total_s^-1": A_total,
        "tau_total_s": tau_total,
        "branching_ratio": branching_ratio,
        "tau_channel_naive_s": tau_channel_naive,
        "lambda_m(upper->lower)": trans_wavelength_m(atom, upper, lower),
        "f_THz(upper->lower)": trans_freq_hz(atom, upper, lower) / 1e12,
        "note": (
            "tau_total is the mean time to first decay event. "
            "tau_channel_naive = 1/A_ij is a single-channel time constant (interpret with BR)."
        ),
    }

    return result, df


# -------------------------
#  usage
# -------------------------
# 
# 
# 
if __name__ == "__main__":
    # 예: Cs 21P3/2 → 20D5/2 @ 300 K
    atom_name = "Cs"
    atom = make_atom(atom_name)
    upper = (15, 2, 1.5)   # 21P3/2
    lower = (14, 2, 2.5)   # 20D5/2
    Eu, El = atom.getEnergy(*upper), atom.getEnergy(*lower)   # eV, 이온화=0, 결합상태는 음수
    print("=== optical selection rule ok? ===")
    print("E_upper?", Eu, "E_lower?", El, "downward?", El < Eu)
    print("\n")
    
    res, table = channel_lifetime(atom_name, upper, lower, temperature_K=325.0, n_floor=1, n_ceiling=60)

    print("=== Channel summary ===")
    for k, v in res.items():
        print(f"{k}: {v}")

    print("\n=== All downward E1 channels (top 10 by A) ===")
    print(table.head(10).to_string(index=False))

=== optical selection rule ok? ===
E_upper? -0.08673570754215366 E_lower? -0.10227953629204023 downward? True


=== Channel summary ===
upper: (15, 2, 1.5)
lower: (14, 2, 2.5)
temperature_K: 325.0
A_ij_s^-1: 0.0
A_total_s^-1: 3448314.805565088
tau_total_s: 2.8999672488896394e-07
branching_ratio: 0.0
tau_channel_naive_s: inf
lambda_m(upper->lower): 7.976425913345518e-05
f_THz(upper->lower): 3.7584810698035978
note: tau_total is the mean time to first decay event. tau_channel_naive = 1/A_ij is a single-channel time constant (interpret with BR).

=== All downward E1 channels (top 10 by A) ===
  lower     lambda_nm       A_s^-1  branching_ratio
 5P_1/2    138.193811 2.230081e+06         0.646716
 5P_3/2    146.262903 3.502522e+05         0.101572
 6P_1/2    512.068719 3.472827e+05         0.100711
 7P_1/2   1118.374421 1.074685e+05         0.031166
 6P_3/2    527.020611 6.741054e+04         0.019549
 8P_1/2   2000.775665 4.815658e+04         0.013965
 4P_3/2    204.175772 4.729337e+04     