In [None]:
# %% [markdown]
# ## OECT Noise Figures
# 
# Here we visualise the electrical-noise budget of a single OECT pixel and show how
# subtracting the CTRL channel suppresses the shared low-frequency drift and the
# common-mode noise (each pixel's own thermal and 1/f flicker noise is not cancelled).

# %%
import pathlib

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import yaml
from scipy import signal

from src.mc_receiver.oect import oect_current, differential_channels, generate_correlated_noise
from src.mc_receiver.binding import bernoulli_binding

# Load config using the same approach as earlier in the notebook
cfg_file = pathlib.Path("../config/default.yaml")
with cfg_file.open("r") as cfg_stream:
    config_oect = yaml.safe_load(cfg_stream)

# Some numeric scalars (e.g. exponent literals such as 1e-6) can be loaded by
# PyYAML as strings, so coerce every known numeric top-level key to float.
numeric_fields = [
    'gate_area_m2', 'gamma_scale_theta', 'hooge_alpha', 'N_apt',
    'temperature_K', 'alpha', 'clearance_rate', 'T_release_ms',
    'gamma_shape_k', 'gm_S', 'C_tot_F', 'rho_corr', 'thermal_T',
    'K_d_Hz', 'monte_carlo_trials', 'time_window_s', 'dt_s',
    'R_ch_Ohm', 'alpha_H', 'N_c'
]
config_oect.update({
    key: float(config_oect[key])
    for key in numeric_fields
    if key in config_oect and isinstance(config_oect[key], str)
})

# Apply the same coercion inside every neurotransmitter parameter block
nt_numeric_fields = ['D_m2_s', 'lambda', 'k_on_M_s', 'k_off_s', 'q_eff_e']
for nt_params in config_oect.get('neurotransmitters', {}).values():
    for key in nt_numeric_fields:
        if key in nt_params and isinstance(nt_params[key], str):
            nt_params[key] = float(nt_params[key])

# Generate a 300-s trace with a constant 1e6 bound sites so only noise remains
T_TRACE_S = 300.0            # trace duration (s)
N_BOUND_SITES = int(1e6)     # constant bound-site count (a count, not a concentration)

dt = config_oect['dt_s']
# round() guards against float division landing just below an integer
# (int() would otherwise truncate and drop a sample).
n_samples = int(round(T_TRACE_S / dt))
bound_sites = np.full(n_samples, N_BOUND_SITES)

curr_da = oect_current(bound_sites, "DA", config_oect, seed=100)
curr_ctrl = oect_current(bound_sites, "SERO", config_oect, seed=200)  # use SERO params for CTRL baseline

# Welch PSD settings
fs = 1 / dt
# scipy.signal.welch already falls back to len(x) (with a UserWarning) when
# nperseg exceeds the signal length; capping here gives the same PSD quietly.
nperseg = min(16384, n_samples)

def welch(x, sample_rate=None, segment_len=None):
    """Welch PSD estimate using the notebook's default sampling settings.

    Parameters
    ----------
    x : array_like
        Time-domain signal (in this notebook, a current trace).
    sample_rate : float, optional
        Sampling frequency in Hz. Defaults to the module-level ``fs``.
    segment_len : int, optional
        Welch segment length in samples. Defaults to the module-level ``nperseg``.

    Returns
    -------
    freqs, psd : ndarray, ndarray
        Frequency grid (Hz) and PSD from ``scipy.signal.welch`` with the
        per-segment mean removed (``detrend='constant'``).
    """
    # Fall back to the notebook-wide settings so existing welch(x) calls are unchanged
    if sample_rate is None:
        sample_rate = fs
    if segment_len is None:
        segment_len = nperseg
    return signal.welch(x, fs=sample_rate, nperseg=segment_len, detrend='constant')

# Welch PSD of the total DA current, then of each noise component on the same grid
freq, psd_total = welch(curr_da['total'])
psd_white, psd_flicker, psd_drift = (
    welch(curr_da[component])[1] for component in ('thermal', 'flicker', 'drift')
)

# Figure 4a – pixel noise budget
FIG_DIR = pathlib.Path('../results/figures')
FIG_DIR.mkdir(parents=True, exist_ok=True)   # savefig fails if the folder is missing

fig, ax = plt.subplots(figsize=(9, 6))
ax.loglog(freq, psd_white, label='Thermal (white)')
ax.loglog(freq, psd_flicker, label='Flicker 1/f')
ax.loglog(freq, psd_drift, label='Drift 1/f²')
ax.loglog(freq, psd_total, 'k--', linewidth=2, label='Total')

# Reference slope guides (−1 and −2) drawn over the same 0.02–0.2 Hz span;
# they carry no label, so they stay out of the legend.
f_ref1 = f_ref2 = np.array([0.02, 0.2])
ax.loglog(f_ref1, 1e-23 * (f_ref1 / 0.02) ** -1, color='0.6', lw=1)
ax.text(0.022, 4e-23, '–1', fontsize=9)
ax.loglog(f_ref2, 3e-22 * (f_ref2 / 0.02) ** -2, color='0.6', lw=1)
ax.text(0.022, 6e-22, '–2', fontsize=9)

ax.set_xlim(0.01, 100)
ax.set_ylim(1e-30, 1e-20)
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('PSD (A²/Hz)')
ax.set_title('OECT Pixel Noise Budget (DA channel)')
ax.legend()
ax.grid(alpha=0.3, which='both')
fig.tight_layout()
fig.savefig(FIG_DIR / 'oect_noise_breakdown.png', dpi=300)
plt.show()

# ----  shared-drift + common-mode noise model  ----
rho_target = 0.8          # assumed common-mode correlation between pixels

# 1. Shared 1/f² drift → identical for all channels (one realisation reused)
shared_drift = curr_da['drift']

# 2. Per-pixel noise (thermal + flicker)
noise_da = curr_da['thermal'] + curr_da['flicker']
noise_ctrl = curr_ctrl['thermal'] + curr_ctrl['flicker']
# NOTE(review): SERO reuses the DA thermal/flicker realisation (flicker scaled
# by 0.8), so it is NOT independent of the DA pixel noise. Draw a separate
# realisation if the differential output depends on the SERO channel.
noise_SERO = curr_da['thermal'] + curr_da['flicker'] * 0.8

# 3. Extra common-mode variation (temperature, pH): correlated Gaussian rows
#    from generate_correlated_noise, scaled by 1e-10 A (100 pA).
cm_noise = generate_correlated_noise(n_samples, rho_target, seed=777) * 1.0e-10

# 4. Assemble pixel currents. Only DA carries signal; CTRL and SERO are noise-only,
#    so their signal terms are omitted rather than multiplied by 0.0.
i_da = curr_da['signal'] + shared_drift + noise_da + cm_noise[0]
i_ctrl = shared_drift + noise_ctrl + cm_noise[1]
i_SERO = shared_drift + noise_SERO + cm_noise[2]

# Differential (DA − CTRL)
diff = differential_channels(i_da, i_SERO, i_ctrl, rho_target)['diff_da']
f2, psd_before = welch(i_da)
_, psd_after = welch(diff)

# Figure 4b – effect of CTRL subtraction
FIG_DIR = pathlib.Path('../results/figures')
FIG_DIR.mkdir(parents=True, exist_ok=True)   # savefig fails if the folder is missing

fig, ax = plt.subplots(figsize=(9, 6))
ax.loglog(f2, psd_before, 'r-', label='Before subtraction')
ax.loglog(f2, psd_after, 'b-', label='After subtraction')

# Highlight the low-frequency "drift" band over the full plot height. The label
# uses data-x / axes-fraction-y coordinates so it stays inside the visible
# y-range (a data y of 2e-26 would sit below the 1e-25 lower limit).
ax.axvspan(0.01, 0.1, color='grey', alpha=0.12, zorder=0)
ax.text(0.011, 0.03, 'drift band', rotation=90, fontsize=9,
        va='bottom', transform=ax.get_xaxis_transform())

ax.set_xlim(0.01, 100)
ax.set_ylim(1e-25, 1e-20)
ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('PSD (A²/Hz)')
ax.set_title('Noise PSD: impact of CTRL channel subtraction')
ax.legend()
ax.grid(alpha=0.3, which='both')
fig.tight_layout()
fig.savefig(FIG_DIR / 'oect_differential_psd.png', dpi=300)
plt.show()

# Quantitative reduction at the non-DC PSD bin nearest 0.03 Hz. A fixed bin
# index would land at a different frequency whenever dt or nperseg changes.
f_target_hz = 0.03
idx_low = 1 + int(np.argmin(np.abs(f2[1:] - f_target_hz)))
reduction_db = 10 * np.log10(psd_before[idx_low] / psd_after[idx_low])
print(f"Low-freq noise reduction ≈ {reduction_db:.1f} dB at {f2[idx_low]:.3f} Hz")

# %% [markdown]
# ### Supplementary – noise-reduction versus common-mode correlation ρ
# 
# To gauge robustness, we vary the common-mode correlation coefficient ρ between
# pixels and measure the noise reduction in the drift band, at the PSD bin nearest 0.05 Hz.

# %%
rho_vals = np.linspace(0, 0.99, 7)           # 7 evenly spaced values: 0, 0.165, …, 0.99
reduction = []
# Non-DC PSD bin nearest 0.05 Hz (drift-dominated); bin 0 (DC) is excluded so a
# coarse frequency grid can never select it.
idx_005Hz = 1 + int(np.argmin(np.abs(freq[1:] - 0.05)))

for rho in rho_vals:
    # Loop-local names (_rho suffix) so the Figure 4b variables
    # (i_da, i_ctrl, diff, psd_before, psd_after) are not overwritten.
    cm = generate_correlated_noise(n_samples, rho, seed=int(rho*1000)) * 1.0e-10
    i_da_rho = curr_da['signal'] + shared_drift + noise_da + cm[0]
    i_ctrl_rho = shared_drift + noise_ctrl + cm[1]
    # Build the SERO channel the same way as for Figure 4b instead of reusing DA
    i_sero_rho = shared_drift + noise_SERO + cm[2]
    diff_rho = differential_channels(i_da_rho, i_sero_rho, i_ctrl_rho, rho)['diff_da']

    _, psd_before_rho = welch(i_da_rho)
    _, psd_after_rho = welch(diff_rho)
    reduction.append(10 * np.log10(psd_before_rho[idx_005Hz] / psd_after_rho[idx_005Hz]))

fig, ax = plt.subplots(figsize=(5, 3.5))
ax.plot(rho_vals, reduction, 'o-')
ax.set_xlabel('Correlation $\\rho$')
ax.set_ylabel('Noise reduction at 0.05 Hz (dB)')
ax.set_title('Differential drift-noise suppression')
ax.grid(True)

# Place ticks on integers so the '%d' labels are exact; with fractional tick
# positions '%d' truncates (2.5 → "2") and yields duplicated, misleading labels.
ax.yaxis.set_major_locator(mtick.MaxNLocator(integer=True))
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%d'))

fig.tight_layout()
FIG_DIR = pathlib.Path('../results/figures')
FIG_DIR.mkdir(parents=True, exist_ok=True)   # savefig fails if the folder is missing
fig.savefig(FIG_DIR / 'oect_noise_reduction_vs_rho.png', dpi=300)
plt.show()

print(f"Maximum noise reduction: {max(reduction):.1f} dB at ρ = {rho_vals[np.argmax(reduction)]:.2f}")