In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.signal import periodogram
from spectral_connectivity import Connectivity, Multitaper

from spectral_decomposition import spectrum

# ——— Monte Carlo bias of multitaper PSD estimates ————————————————————
# For each data duration: build the model ("true") PSD with `spectrum`,
# simulate `n_iter` signals from the same model, estimate each signal's PSD
# with a multitaper estimator for every TW in `TWs`, and record
#     bias = 10*log10(estimate) - (true dB-PSD interpolated onto the estimate's grid)
# The bias matrices (n_iter x n_frequencies, one per TW) are written to CSV and
# their distributions at `freqs_of_interest` are shown as violin plots.

# ——— Parameters ——————————————————————————————————————————————————
fs        = 1000.0                 # sampling rate for simulator and estimator
durations = [0.1, 0.5, 1.0]        # s
TWs       = [1, 2, 3]              # time-half-bandwidth products to compare
n_iter    = 10000                  # Monte Carlo draws per duration
seeds     = np.arange(n_iter)      # one random_state per draw
freqs_of_interest = [12, 30, 50, 80]  # Hz
out_dir   = './tws'                # directory for the bias CSV files


def _n_tapers(tw):
    """Return the number of tapers used for `tw`: K = 2*TW - 1 (1->1, 2->3, 3->5)."""
    return int(2 * tw - 1)


def _mt_connectivity(x, tw):
    """Return the `Connectivity` object of a multitaper transform of `x` at TW=`tw`."""
    return Connectivity.from_multitaper(
        Multitaper(x, sampling_frequency=fs,
                   time_halfbandwidth_product=tw, n_tapers=_n_tapers(tw))
    )


def _simulate(duration, seed, direct_estimate):
    """Call `spectrum` with the fixed model parameters used throughout this cell.

    Keeping the parameters in one place guarantees the "true" PSD and the
    Monte Carlo draws use the same model. `plot=False` is passed explicitly
    for both so no figure is produced per draw.
    """
    return spectrum(
        sampling_rate      = fs,
        duration           = duration,
        aperiodic_exponent = 2.0,
        aperiodic_offset   = 0.5,
        knee               = 200.0,
        peaks              = [{'freq':12,'amplitude':1.0,'sigma':2.0}],  # fresh list per call
        average_firing_rate= 0.0,
        random_state       = int(seed),
        direct_estimate    = direct_estimate,
        plot               = False,
    )


os.makedirs(out_dir, exist_ok=True)  # to_csv below fails if the directory is missing

for duration in durations:

    # ——— True PSD in dB (non-negative frequencies only) ——————————————
    res_th = _simulate(duration, seed=0, direct_estimate=False)
    f_th = res_th.frequency_domain.frequencies
    P_th = res_th.frequency_domain.combined_spectrum
    mask = f_th >= 0
    f_th, P_th = f_th[mask], P_th[mask]
    P_th_db = 10 * np.log10(P_th)

    # ——— Frequency grids from an all-zero signal of the same length —————
    # The grid is taken from an all-zero signal, i.e. assumed to be
    # data-independent. The true dB-PSD is therefore interpolated onto each
    # grid once here rather than once per Monte Carlo draw.
    N = int(fs * duration)
    f_mt       = {tw: _mt_connectivity(np.zeros(N), tw).frequencies for tw in TWs}
    P_th_mt_db = {tw: np.interp(f_mt[tw], f_th, P_th_db) for tw in TWs}

    # ——— allocate storage for bias in dB (n_iter x n_freqs per TW) ———————
    bias_mt_db = {tw: np.zeros((n_iter, len(f_mt[tw]))) for tw in TWs}

    # ——— Monte Carlo draws ————————————————————————————————————
    for i, seed in enumerate(seeds):
        sig = _simulate(duration, seed, direct_estimate=True).time_domain.combined_signal
        for tw in TWs:
            # squeeze() drops singleton axes of power(); assumed to leave a 1-D
            # (n_freqs,) array for a single 1-D signal — TODO confirm
            P_t = _mt_connectivity(sig, tw).power().squeeze()
            bias_mt_db[tw][i, :] = 10 * np.log10(P_t) - P_th_mt_db[tw]

    # ——— nearest grid index to each frequency of interest ———————————————
    idx_mt = {
        tw: [int(np.argmin(np.abs(f_mt[tw] - f))) for f in freqs_of_interest]
        for tw in TWs
    }

    # ——— store as csv ——————————————————————————————————————————
    for tw in TWs:
        columns = [f"Freq_{j}" for j in range(len(f_mt[tw]))]
        pd.DataFrame(bias_mt_db[tw], columns=columns).to_csv(
            os.path.join(out_dir, f'bias_mt_db TW={tw} {duration}s.csv'),
            index=False,
        )

    # ——— plot violins with means, and annotate mean±SD ———————————————
    _, axes = plt.subplots(1, len(TWs), figsize=(5 * len(TWs), 4),
                           sharey=True, squeeze=False)
    # On a log x-axis a fixed width in Hz makes violins look ever thinner as
    # frequency grows; a width proportional to the position keeps them even.
    widths = [0.4 * f for f in freqs_of_interest]
    for ax, tw in zip(axes[0], TWs):
        ax.violinplot(
            [bias_mt_db[tw][:, j] for j in idx_mt[tw]],
            positions=freqs_of_interest,
            widths=widths,
            showmeans=True,
        )
        ax.set_xscale('log')
        ax.minorticks_off()  # label only the frequencies of interest
        ax.set_xticks(freqs_of_interest)
        ax.set_xticklabels([str(f) for f in freqs_of_interest])
        ax.axhline(0, linestyle='--', color='k', linewidth=1)
        ax.set_title(f'TW = {tw} Bias (dB)\n{duration}s data')
        ax.set_xlabel('Frequency (Hz)')

    # Annotate only after every panel is drawn: with sharey=True a later panel
    # can still widen the shared y-range and misplace earlier annotations.
    ylim = axes[0, 0].get_ylim()
    y_text = ylim[1] - 0.15 * (ylim[1] - ylim[0])
    for ax, tw in zip(axes[0], TWs):
        for x, j in zip(freqs_of_interest, idx_mt[tw]):
            column = bias_mt_db[tw][:, j]
            ax.text(x, y_text, f'{column.mean():.2f}±{column.std():.2f}',
                    ha='center', fontsize=8)

    axes[0, 0].set_ylabel('Bias (dB)')
    plt.suptitle(f'Bias Distributions in dB at Selected Frequencies\n'
                 f'(mean±SD shown, N={n_iter} runs)')
    plt.tight_layout(rect=[0, 0, 1, 0.94])
    plt.show()
