# In [5]:
import numpy as np
import scipy.special as sp
import matplotlib.pyplot as plt

def CRLB_for_CRT(s,tr,td,Nsc,Ppr=1,ne=100,plot=True):
    
    """
    Calculates the Cramer-Rao Lower Bound on timing resolution for a scintillation detector.
    
    Parameters:
        s (float) (ns): Effective transit time spread in terms of a standard deviation.
        tr (array-like) (ns): Array of rise time constants for light-emitting mechanisms. The ith element of tr corresponds to the ith elements of td and Ppr. Must have length equal to td and Ppr.
        td (array-like) (ns): Array of decay time constants for light-emitting mechanisms.
        Nsc (int): Number of detected scintillation photons.
        Ppr (array-like, optional): Probabilities for each light-emitting mechanism. Must sum to 1. (Default is 1).
        ne (int, optional): Number of order statistics to compute (default is 100).
        plot (bool, optional): Whether to plot results (default is True).
    
    Returns:

        CRT_ordered (array) (ps): Calculated values for nth timestamp.
        CRT_cumulative (array) (ps): Calculated values for smallest n timestamps.
        Lower_bound (float) (ps): Lower bound for CRT values.
    """
    dt=0.001
    t_start = -dt
    t_end = 10
    t0=0
    t_tr = 4.4332 * s
    t = np.arange(t_start, t_end + dt, dt)
    nt = len(t)


    if Ppr is None:
        Ppr = 1

    t = t if t.ndim > 1 else t.T
    tr = np.array(tr)  
    td = np.array(td) 
    Ppr = np.array(Ppr)  
    tr_td_match =  [a == b for a, b in zip(tr, td)]
    if True in tr_td_match:
        raise NotImplementedError('Case tr == td is not implemented.')
        
        
    if len(tr) != len(td) or len(tr) != len(Ppr):
        raise ValueError("tr, td, and Ppr must have the same length.")
    if Nsc <= 0 or ne <= 0:
        raise ValueError("Nsc and ne must be positive integers.")
    if np.any(td <= 0):
        raise ValueError("td must contain only positive floats.")  
    if np.any(Ppr <= 0):
        raise ValueError("Ppr must contain only positive floats.")  
    if np.any(tr <= 0):
        raise ValueError("tr must contain only positive floats.")  
    if s <= 0:
        raise ValueError("s must be a positive float.") 
        
    if np.sum(Ppr) != 1:
        raise ValueError("Sum of Ppr must be 1.")

        
    

    ad = lambda t, td_i: 0.5 * np.exp(-(t - t0 - t_tr) / td_i + s**2 / (2 * td_i**2)) * (sp.erf((t - t0 - t_tr - s**2 / td_i) / (s * np.sqrt(2))) + sp.erf((t_tr + s**2 / td_i) / (s * np.sqrt(2))))
    ar = lambda t, tr_i: 0.5 * np.exp(-(t - t0 - t_tr) / tr_i + s**2 / (2 * tr_i**2)) * (sp.erf((t - t0 - t_tr - s**2 / tr_i) / (s * np.sqrt(2))) + sp.erf((t_tr + s**2 / tr_i) / (s * np.sqrt(2))))
    g = lambda t: 0.5 * (sp.erf((t - t0 - t_tr) / (s * np.sqrt(2))) + sp.erf(t_tr / (s * np.sqrt(2))))


    
    C = 2 / (1 + sp.erf(t_tr / (s * np.sqrt(2))))


    def p_ptr(t):
        """
        Calculates the probability density function value at a given time t.
        """
        p_ptrs = []
        
        for i in range(len(tr)):
            p_ptrs.append(np.concatenate([
            np.zeros(t[t < t0].shape),  # Zero values for t < t0
            C * (1 / (td[i] - tr[i]) * ad(t[t >= t0],td[i]) - 1 / (td[i] - tr[i]) * ar(t[t >= t0],tr[i]))  # Calculation for t >= t0
            ]))

        p_ptrs = [ai * bi for ai, bi in zip(p_ptrs, Ppr)] # Weighs p_ptrs by Ppr
        p_ptr = np.sum(p_ptrs, axis=0)

        
        return p_ptr
    
    def P_ptr(t):
        """
        Calculates the cumulative distribution function at a given time t.
        """

        P_ptrs=[]


        for i in range(len(tr)):

            P_ptrs.append(np.concatenate([
            np.zeros(t[t < t0].shape),  # Zero values for t < t0
            C * (
                g(t[t >= t0]) +
                1 * tr[i] / (td[i] - tr[i]) * ar(t[t >= t0],tr[i]) -
                1 * td[i] / (td[i] - tr[i]) * ad(t[t >= t0],td[i])
            )  # Calculation for t >= t0
        ]))

        P_ptrs = [ai * bi for ai, bi in zip(P_ptrs, Ppr)] # Weighs p_ptrs by Ppr
        P_ptr = np.sum(P_ptrs, axis=0)

        return P_ptr




    # Hazard Function and Order Statistics
    h = lambda t: p_ptr(t) / (1 - P_ptr(t))
    Fn = lambda n, Nsc, t: sp.betainc(n, Nsc - n + 1, P_ptr(t))
    fn = lambda n, Nsc, t: 1 / sp.beta(n, Nsc - n + 1) * P_ptr(t)**(n - 1) * (1 - P_ptr(t))**(Nsc - n) * p_ptr(t)


    with np.errstate(divide='ignore', invalid='ignore'):
        ptr_t = (p_ptr(t + dt / 2) + p_ptr(t - dt / 2)) / 2
        dptr_dt = (p_ptr(t + dt / 2) - p_ptr(t - dt / 2)) / dt
        h_t = (h(t + dt / 2) + h(t - dt / 2)) / 2
        d_ln_h_dt = np.nan_to_num(dptr_dt / ptr_t) + h_t

        
    I1_nNsc = np.zeros(ne)
    In = np.zeros(ne)
    fn_t = np.zeros((ne, nt))

    for n in range(1, ne + 1):
        Fn_t = (Fn(n, Nsc - 1, t + dt / 2) + Fn(n, Nsc - 1, t - dt / 2)) / 2
        fn_t[n - 1, :] = (fn(n, Nsc, t + dt / 2) + fn(n, Nsc, t - dt / 2)) / 2
        fn_t[n - 1, np.isnan(fn_t[n - 1, :])] = 0

        dfn_dt = (fn(n, Nsc, t + dt / 2) - fn(n, Nsc, t - dt / 2)) / dt
        dfn_dt[np.isnan(dfn_dt)] = 0

        with np.errstate(divide='ignore', invalid='ignore'):
            dump = np.nan_to_num(dfn_dt / fn_t[n - 1, :])

        I1_nNsc[n - 1] = Nsc * np.sum(d_ln_h_dt**2 * (1 - Fn_t) * ptr_t) * dt
        In[n - 1] = np.sum(dump**2 * fn_t[n - 1, :]) * dt

    P = 1
    t_end = max(2 * max(td), t_end)
    while P > 1e-3:
        t_end += 0.1
        P = 1 - P_ptr(t_end)

    t_long = np.arange(-dt, t_end + dt, dt)
    with np.errstate(divide='ignore', invalid='ignore'):
        dump = ((p_ptr(t_long + dt / 2) - p_ptr(t_long - dt / 2)) / dt)**2 / ((p_ptr(t_long + dt / 2) + p_ptr(t_long - dt / 2)) / 2)
    dump[np.isnan(dump)] = 0
    I1_11 = np.sum(dump) * dt

    # Calculate variances
    t_reshaped = t[:, np.newaxis]  # Reshape t for proper broadcasting
    E = np.sum(t_reshaped * fn_t[:ne, :].T, axis=0) * dt 
    var = np.sum((t_reshaped - E) ** 2 * fn_t[:ne, :].T, axis=0) * dt
    
    CRT_ordered = np.sqrt(2 / In) * 1000 * 2.355
    CRT_cumulative = np.sqrt(2 / I1_nNsc) * 1000 * 2.355
    Lower_bound = (np.sqrt(2 / (Nsc * I1_11)) * 1000 * 2.355)

    if plot:
        plt.plot(CRT_ordered, 'o', label='nth timestamp')
        plt.plot(CRT_cumulative, 'x', label='Smallest n timestamps')
        plt.plot([1, ne], [Lower_bound] * 2, 'r', label='Lower Bound')

        plt.xlabel('Photon Number')
        plt.ylabel('CRT (ps)')
        plt.legend()
        plt.show()
        
    
    return CRT_ordered, CRT_cumulative, Lower_bound


