In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_active_states(n, l, m, occ, cutoff=1e-3, mode="scatter", max_l_panels=6):
    """
    Plot active hydrogenic states above a given occupation threshold.

    A state is "active" when its occupation is strictly greater than
    ``cutoff``. The figure is displayed with ``plt.show()``; nothing is
    returned.

    Parameters
    ----------
    n, l, m : array_like
        Quantum numbers, one entry per state (all the same length).
        Lists are accepted; they are converted with ``np.asarray``.
    occ : array_like
        Occupation number of each state (same length as ``n``).
    cutoff : float
        Threshold for marking a state as active (``occ > cutoff``).
    mode : {"scatter", "grid"}
        "scatter" draws a single (n, l) view restricted to m > 0;
        "grid" draws one (n, m) panel per distinct integer l value.
    max_l_panels : int
        Grid mode only: at most this many panels are drawn, for the
        smallest l values; panels for larger l values are omitted.

    Raises
    ------
    ValueError
        If ``mode`` is not "scatter" or "grid", or if ``mode`` is "grid"
        and ``max_l_panels`` is less than 1.
    """
    # Validate arguments before doing any work so misuse fails fast.
    if mode not in ("scatter", "grid"):
        raise ValueError("mode must be 'scatter' or 'grid'.")
    if mode == "grid" and max_l_panels < 1:
        raise ValueError("max_l_panels must be >= 1.")

    # Boolean-mask indexing below needs ndarrays, so accept lists too.
    n, l, m, occ = (np.asarray(values) for values in (n, l, m, occ))

    mask = occ > cutoff
    n, l, m = n[mask], l[mask], m[mask]

    if mode == "scatter":
        # Only keep m > 0
        mask_m = m > 0
        plt.figure(figsize=(7, 5))
        plt.scatter(n[mask_m], l[mask_m], marker="o", c="blue", alpha=0.7)
        plt.xlabel("Principal quantum number n")
        plt.ylabel("Orbital angular momentum l")
        plt.title(f"Active states (m > 0, occ > {cutoff})")
        plt.grid(True, alpha=0.3)
        plt.show()
        return

    # Grid mode: one panel per distinct integer l. The same integer view of
    # l is used both to choose the panels and to select each panel's points.
    l_int = l.astype(int)
    # np.unique returns the distinct values sorted ascending.
    unique_ls = np.unique(l_int).tolist()[:max_l_panels]

    if not unique_ls:
        # Nothing survived the cutoff; plt.subplots(1, 0) would raise, so
        # show a placeholder panel instead.
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.text(0.5, 0.5, "No active states", ha="center", va="center",
                transform=ax.transAxes)
        ax.set_axis_off()
    else:
        # squeeze=False always yields a 2-D array of Axes, even for 1 panel.
        fig, axes = plt.subplots(1, len(unique_ls),
                                 figsize=(3 * len(unique_ls), 5),
                                 sharey=True, squeeze=False)
        for idx, (ax, l_val) in enumerate(zip(axes[0], unique_ls)):
            mask_l = l_int == l_val
            ax.scatter(n[mask_l], m[mask_l], marker="o", c="red")
            ax.set_title(f"l = {l_val}")
            ax.set_xlabel("n")
            if idx == 0:
                ax.set_ylabel("m")
            ax.grid(True, alpha=0.3)

    fig.suptitle(f"Active states (occ > {cutoff})")
    plt.show()


# Example usage with mock data
if __name__ == "__main__":
    # A local Generator keeps the demo reproducible without mutating
    # NumPy's global random state (as np.random.seed would).
    rng = np.random.default_rng(123)
    n = rng.integers(1, 15, size=200)  # principal quantum number, 1..14
    l = rng.integers(0, n)             # per state: 0 <= l < n
    m = rng.integers(-l, l + 1)        # per state: -l <= m <= l
    occ = rng.lognormal(mean=-2, sigma=1, size=n.size)

    # Render the same mock data in both plotting modes.
    plot_active_states(n, l, m, occ, cutoff=0.05, mode="scatter")
    plot_active_states(n, l, m, occ, cutoff=0.05, mode="grid")