# %% [markdown]
# # Tri-Channel OECT MC: Diffusion Model Sanity Check
#
# This notebook verifies the diffusion model implementation and investigates
# the propagation delay calculations that are critical to the paper.

# %%
# Environment diagnostic: record which interpreter, search path and key package
# versions this notebook ran with, so results can be reproduced later.
import sys
import subprocess

print("=== Python Environment Diagnostic ===")
print(f"Python executable: {sys.executable}")
print(f"Python version: {sys.version}")
print("\nPython path entries:")
for i, path in enumerate(sys.path):
    print(f"  {i}: {path}")

print("\n=== Checking pip list ===")
# Run pip through the kernel's own interpreter so the listing matches this environment.
result = subprocess.run([sys.executable, "-m", "pip", "list"], capture_output=True, text=True)
if result.returncode != 0:
    print(f"pip list failed (exit code {result.returncode}):\n{result.stderr.strip()}")

# Distribution names (lower-case) of the packages this notebook depends on.
KEY_PACKAGES = {'numpy', 'matplotlib', 'scipy', 'pandas', 'pyyaml'}
print("Looking for our key packages:")
for line in result.stdout.splitlines():
    columns = line.split()
    # Match the package-name column exactly so that e.g. "numpydoc" or
    # "matplotlib-inline" are not reported as key packages.
    if columns and columns[0].lower() in KEY_PACKAGES:
        print(f"  {line}")

# %% [markdown]
# ## Setup and Imports

# %%
import sys
from pathlib import Path
from typing import Dict, Any, Tuple, Union

import numpy as np
import matplotlib.pyplot as plt
import yaml

# Add parent directory to path for imports. Guarded so that re-running this
# cell does not keep appending duplicate '..' entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

from src.mc_channel.transport import (
    greens_function_3d,
    finite_burst_concentration
)
from src.constants import (
    get_nt_params,
    validate_system_parameters,
    UM_TO_M,
    MS_TO_S
)
# Analysis helpers: peak finding, propagation metrics, and the comparison
# against the manuscript-reported delays.
from src.analysis_utils import (
    find_peak_concentration,
    calculate_propagation_metrics,
    verify_propagation_delays,
)

# Later cells save figures to ../results/figures and data to ../results/data;
# create both up front so a fresh checkout does not fail with FileNotFoundError.
for output_dir in (Path('../results/figures'), Path('../results/data')):
    output_dir.mkdir(parents=True, exist_ok=True)

# Configure matplotlib for publication-quality figures
plt.rcParams.update({
    'figure.dpi': 100,
    'savefig.dpi': 300,
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 16,
    'legend.fontsize': 12,
})

# %% [markdown]
# ## Load Configuration (Parameter Validation Currently Disabled)

# %%
# Load configuration
config_path = Path('../config/default.yaml')
if config_path.exists():
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
else:
    # Use default configuration if file doesn't exist
    config = {
        'temperature_K': 310.0,
        'alpha': 0.20,
        'clearance_rate': 0.01,
        'neurotransmitters': {
            'DA': {
                'D_m2_s': 7.6e-10,
                'lambda': 1.7,
                'k_on_M_s': 5e4,
                'k_off_s': 1.5,
                'q_eff_e': 0.6
            },
            'SERO': {
                'D_m2_s': 9.1e-10,
                'lambda': 1.5,
                'k_on_M_s': 3e4,
                'k_off_s': 0.9,
                'q_eff_e': 0.2
            }
        },
        'T_release_ms': 10,
        'burst_shape': 'rect',
        'gamma_shape_k': 2.0,
        'gamma_scale_theta': 5e-3,
        'gate_area_m2': 4e-8
    }

# Convert string values to numeric types if they exist
numeric_fields = [
    'gate_area_m2', 'gamma_scale_theta', 'hooge_alpha', 'N_apt',
    'temperature_K', 'alpha', 'clearance_rate', 'T_release_ms',
    'gamma_shape_k', 'gm_S', 'C_tot_F', 'rho_corr', 'thermal_T',
    'K_d_Hz', 'monte_carlo_trials', 'time_window_s', 'dt_s'
]

for field in numeric_fields:
    if field in config and isinstance(config[field], str):
        config[field] = float(config[field])

# Convert list elements that should be numeric
if 'Nm_range' in config and isinstance(config['Nm_range'], list):
    config['Nm_range'] = [float(x) if isinstance(x, str) else x for x in config['Nm_range']]

if 'distances_um' in config and isinstance(config['distances_um'], list):
    config['distances_um'] = [float(x) if isinstance(x, str) else x for x in config['distances_um']]

# Convert neurotransmitter parameters
for nt_name, nt_params in config.get('neurotransmitters', {}).items():
    nt_numeric_fields = ['D_m2_s', 'lambda', 'k_on_M_s', 'k_off_s', 'q_eff_e']
    for field in nt_numeric_fields:
        if field in nt_params and isinstance(nt_params[field], str):
            nt_params[field] = float(nt_params[field])

# Validate parameters
#validation = validate_system_parameters(config)
#print("System Parameter Validation:")
#print(f"Valid: {validation['valid']}")
#print(f"Damköhler numbers: {validation['damkohler_numbers']}")
#for warning in validation['warnings']:
    #print(f"  {warning}")

# %% [markdown]
# ## 1. Green's Function Behavior
# 
# First, let's visualize how the Green's function behaves for a single molecule
# release, comparing DA and SERO at different times.

# %%
# Radial grid and snapshot times at which the single-release Green's function is sampled.
r_vec = np.linspace(0, 500e-6, 200)  # 0 to 500 μm
times = [0.01, 0.1, 1.0, 5.0]  # seconds

fig, axes = plt.subplots(1, 2, figsize=(12, 5))


def _greens_profile(t, D, lam):
    """Return greens_function_3d evaluated at every radius in r_vec for time t."""
    return [
        greens_function_3d(r, t, D, lam, config['alpha'], config['clearance_rate'])
        for r in r_vec
    ]


# One panel per neurotransmitter, one curve per snapshot time.
for nt_type, ax in zip(['DA', 'SERO'], axes):
    nt_params = get_nt_params(config, nt_type)
    D, lam = nt_params['D_m2_s'], nt_params['lambda']

    for t in times:
        ax.plot(r_vec * 1e6, _greens_profile(t, D, lam), label=f't = {t}s')

    ax.set_xlabel('Distance (μm)')
    ax.set_ylabel("Green's function (m⁻³)")
    ax.set_title(f"{nt_type} Green's Function Evolution")
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../results/figures/greens_function_comparison.png')
plt.show()

# %% [markdown]
# ## 2. Concentration Profiles for Finite Release
# 
# Now let's examine concentration profiles resulting from finite-duration release,
# which is more realistic than instantaneous release.

# %%
# Parameters for concentration profile
distances_um = [50, 100, 200]
Nm = 1e4  # 10,000 molecules (typical synaptic vesicle)
t_vec = np.linspace(0, 20, 400)  # 0-20 s window, 400 samples

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Compare burst shapes (rectangular vs gamma). Each shape is applied to a
# shallow copy of `config`, so an exception mid-loop cannot leave the shared
# config with the wrong burst shape for later cells.
for row, burst_shape in enumerate(['rect', 'gamma']):
    config_burst = {**config, 'burst_shape': burst_shape}

    for col, nt_type in enumerate(['DA', 'SERO']):
        ax = axes[row, col]

        for d_um in distances_um:
            d_m = d_um * UM_TO_M
            c_profile = finite_burst_concentration(Nm, d_m, t_vec, config_burst, nt_type)

            # Find peak for annotation
            c_peak, t_peak = find_peak_concentration(c_profile, t_vec)
            # A maximum on the final sample means the true peak lies beyond the
            # simulated window, so the reported peak time is only a lower bound.
            if np.isclose(t_peak, t_vec[-1]):
                print(f"WARNING: {nt_type} ({burst_shape}) at {d_um} μm peaks at the end of "
                      f"the {t_vec[-1]:g} s window; the true peak time is later.")

            # Convert to nM for easier reading
            c_profile_nM = c_profile * 1e9
            c_peak_nM = c_peak * 1e9

            # Give the peak marker the colour of its curve; a bare ax.plot call
            # would take the next colour in the property cycle instead.
            (line,) = ax.plot(t_vec, c_profile_nM, label=f'{d_um} μm (peak: {t_peak:.2f}s)')
            ax.plot(t_peak, c_peak_nM, 'o', markersize=6, color=line.get_color())

        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Concentration (nM)')
        ax.set_title(f'{nt_type} - {burst_shape.capitalize()} Burst')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, 10)

plt.tight_layout()
plt.savefig('../results/figures/concentration_profiles.png')
plt.show()

# Remaining analyses use rectangular bursts regardless of the loaded config
config['burst_shape'] = 'rect'

# %% [markdown]
# ## 3. Propagation Delay Analysis
# 
# This is crucial for understanding the reported delays of 6.3s (DA) and 4.7s (SERO)
# at 100 μm. Let's investigate what contributes to these delays.

# %%
# Detailed analysis at 100 μm
distance_m = 100e-6
t_vec_fine = np.linspace(0, 15, 1000)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Calculate and plot for both neurotransmitters
results_100um = {}
for nt_type, color in zip(['DA', 'SERO'], ['blue', 'red']):
    c_profile = finite_burst_concentration(Nm, distance_m, t_vec_fine, config, nt_type)
    c_peak, t_peak = find_peak_concentration(c_profile, t_vec_fine)

    # Store results (reused by the scaling plot and the summary/JSON export)
    results_100um[nt_type] = {
        'peak_time': t_peak,
        'peak_conc_nM': c_peak * 1e9
    }

    # Absolute concentration profile with the peak time marked
    ax1.plot(t_vec_fine, c_profile * 1e9, color=color, label=f'{nt_type}')
    ax1.axvline(t_peak, color=color, linestyle='--', alpha=0.5)
    ax1.text(t_peak + 0.2, c_peak * 1e9 * 0.8, f'{t_peak:.2f}s',
             color=color, fontsize=10)

    # Peak-normalised profile to compare shapes. Skipped when nothing arrives,
    # because dividing by a zero peak would turn the whole curve into NaN/inf.
    if c_peak > 0:
        ax2.plot(t_vec_fine, c_profile / c_peak, color=color, label=f'{nt_type}')
    else:
        print(f"WARNING: {nt_type} peak concentration is {c_peak}; normalized profile skipped.")

# Half-maximum reference line, drawn once rather than once per neurotransmitter
ax2.axhline(0.5, color='gray', linestyle=':', alpha=0.5)

ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Concentration (nM)')
ax1.set_title('Concentration Profiles at 100 μm')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_xlim(0, 10)

ax2.set_xlabel('Time (s)')
ax2.set_ylabel('Normalized Concentration')
ax2.set_title('Normalized Profiles (Shape Comparison)')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_xlim(0, 10)

plt.tight_layout()
plt.savefig('../results/figures/propagation_delay_analysis.png')
plt.show()

print("\nPeak times at 100 μm:")
print(f"DA: {results_100um['DA']['peak_time']:.3f} s (reported: 6.3 s)")
print(f"SERO: {results_100um['SERO']['peak_time']:.3f} s (reported: 4.7 s)")

# %% [markdown]
# ## 4. Factors Contributing to Propagation Delay

# %%
# Analyze different contributions to delay at 100 μm by sweeping one transport
# parameter at a time while holding the rest of `config` fixed.
distance_m = 100e-6


def _plot_param_sweep(ax, base_config, nt_type, key, values, label_fmt):
    """Plot concentration-vs-time curves at `distance_m` for several values of one config key.

    Uses the notebook-level ``Nm``, ``distance_m`` and ``t_vec_fine``.

    Args:
        ax: Matplotlib axes to draw on.
        base_config: Configuration dict. Each value is applied to a shallow
            copy, so ``base_config`` itself is never modified.
        nt_type: Neurotransmitter key ('DA' or 'SERO').
        key: Top-level config key to override (e.g. 'clearance_rate').
        values: Values assigned to ``key``; one curve is drawn per value.
        label_fmt: ``str.format`` template with ``{value}`` and ``{t_peak}`` fields.
    """
    for value in values:
        swept_config = {**base_config, key: value}
        c_profile = finite_burst_concentration(Nm, distance_m, t_vec_fine,
                                               swept_config, nt_type)
        c_peak, t_peak = find_peak_concentration(c_profile, t_vec_fine)
        ax.plot(t_vec_fine, c_profile * 1e9,
                label=label_fmt.format(value=value, t_peak=t_peak))


def _finish_axes(ax, title):
    """Apply the shared labels, legend, grid and 0-10 s window to a sweep panel."""
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Concentration (nM)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xlim(0, 10)


fig, axes = plt.subplots(2, 2, figsize=(12, 10))

for i, nt_type in enumerate(['DA', 'SERO']):
    nt_params = get_nt_params(config, nt_type)
    D = nt_params['D_m2_s']
    lam = nt_params['lambda']

    # Factor 1: pure diffusion. t_diff = r²λ²/D is the characteristic diffusion
    # time for effective diffusivity D/λ². For an instantaneous point source in
    # unbounded 3D with no clearance, concentration peaks at t_diff / 6; it is
    # drawn as a reference on the clearance panel.
    t_diff = distance_m**2 * lam**2 / D
    t_peak_ideal = t_diff / 6

    # Factor 2: clearance rate
    ax = axes[0, i]
    _plot_param_sweep(ax, config, nt_type, 'clearance_rate', [0, 0.001, 0.01, 0.1],
                      "k' = {value} s⁻¹ (peak: {t_peak:.2f}s)")
    ax.axvline(t_peak_ideal, color='k', linestyle=':', alpha=0.6,
               label=f'r²λ²/(6D) = {t_peak_ideal:.2f}s')
    _finish_axes(ax, f'{nt_type}: Effect of Clearance Rate')

    # Factor 3: extracellular volume fraction
    ax = axes[1, i]
    _plot_param_sweep(ax, config, nt_type, 'alpha', [0.1, 0.2, 0.3],
                      "α = {value} (peak: {t_peak:.2f}s)")
    _finish_axes(ax, f'{nt_type}: Effect of Volume Fraction')

plt.tight_layout()
plt.savefig('../results/figures/delay_factors_analysis.png')
plt.show()

# %% [markdown]
# ## 5. Comprehensive Propagation Metrics

# %%
# Full propagation metrics at 100 μm for the manuscript parameters, printed
# next to the reported-vs-calculated comparison from verify_propagation_delays().
print("Propagation Metrics at 100 μm:\n")

results_comparison = verify_propagation_delays(config)

for nt_type in ('DA', 'SERO'):
    print(f"\n{nt_type}:")
    metrics = calculate_propagation_metrics(config, Nm, 100e-6, nt_type)
    comparison = results_comparison[nt_type]

    report = [
        f"  Peak concentration: {metrics['peak_concentration_M']*1e9:.2f} nM",
        f"  Time to peak: {metrics['time_to_peak_s']:.3f} s",
        f"  FWHM: {metrics['fwhm_s']:.3f} s",
        f"  10-90% rise time: {metrics['rise_time_10_90_s']:.3f} s",
        f"  Characteristic diffusion time: {metrics['t_diff_characteristic_s']*1000:.1f} ms",
        f"  Delay factor: {metrics['delay_factor']:.1f}x",
        "  \nManuscript comparison:",
        f"    Reported delay: {comparison['reported_delay_s']} s",
        f"    Percent difference: {comparison['percent_difference']:.1f}%",
    ]
    print("\n".join(report))

# %% [markdown]
# ## 6. Distance Scaling Analysis

# %%
# How does propagation delay scale with distance?
scan_distances_um = np.logspace(1, 2.5, 20)  # 10 to ~316 μm, log-spaced

# Time to peak (s) per neurotransmitter at each scan distance
delays_s = {
    nt_type: [
        calculate_propagation_metrics(config, Nm, d_um * UM_TO_M, nt_type)['time_to_peak_s']
        for d_um in scan_distances_um
    ]
    for nt_type in ('DA', 'SERO')
}

# Empirical exponent n in t_peak ∝ r^n from a straight-line fit in log-log space.
# Non-positive or non-finite delays are excluded because log() is undefined there.
log_r = np.log(scan_distances_um)
for nt_type, delays in delays_s.items():
    delays_arr = np.asarray(delays, dtype=float)
    valid = np.isfinite(delays_arr) & (delays_arr > 0)
    if np.count_nonzero(valid) >= 2:
        slope = np.polyfit(log_r[valid], np.log(delays_arr[valid]), 1)[0]
        print(f"{nt_type}: time to peak ∝ r^{slope:.2f} "
              f"(log-log fit over {scan_distances_um[0]:.0f}-{scan_distances_um[-1]:.0f} μm)")

# Plot scaling
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Linear scale; dots mark the Section 3 peak times at 100 μm
ax1.plot(scan_distances_um, delays_s['DA'], 'b-', label='DA', linewidth=2)
ax1.plot(scan_distances_um, delays_s['SERO'], 'r-', label='SERO', linewidth=2)
ax1.plot(100, results_100um['DA']['peak_time'], 'bo', markersize=8)
ax1.plot(100, results_100um['SERO']['peak_time'], 'ro', markersize=8)
ax1.set_xlabel('Distance (μm)')
ax1.set_ylabel('Time to Peak (s)')
ax1.set_title('Propagation Delay vs Distance')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Log-log scale to check power law
ax2.loglog(scan_distances_um, delays_s['DA'], 'b-', label='DA', linewidth=2)
ax2.loglog(scan_distances_um, delays_s['SERO'], 'r-', label='SERO', linewidth=2)

# Reference slopes anchored at the first DA point so they overlay the data
# (fall back to 1 ms if that point is not positive).
d_ref = scan_distances_um
t_anchor = delays_s['DA'][0] if delays_s['DA'][0] > 0 else 1e-3
ax2.loglog(d_ref, t_anchor * (d_ref / d_ref[0])**2, 'k--', alpha=0.5, label='∝ r²')
ax2.loglog(d_ref, t_anchor * (d_ref / d_ref[0])**1.5, 'k:', alpha=0.5, label='∝ r^1.5')

ax2.set_xlabel('Distance (μm)')
ax2.set_ylabel('Time to Peak (s)')
ax2.set_title('Log-Log Scaling Analysis')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../results/figures/distance_scaling_analysis.png')
plt.show()

# %% [markdown]
# ## 7. Summary and Conclusions
# 
# Based on this analysis, we can draw several conclusions about the propagation delays:

# %%
import json

# Manuscript-reported time-to-peak values at 100 μm (s), used for comparison and export
REPORTED_DELAYS_100UM_S = {'DA': 6.3, 'SERO': 4.7}
da_params = get_nt_params(config, 'DA')
sero_params = get_nt_params(config, 'SERO')

print("SUMMARY OF FINDINGS:\n")
print("1. Calculated vs Reported Delays at 100 μm:")
for nt_type in ('DA', 'SERO'):
    print(f"   - {nt_type}: {results_100um[nt_type]['peak_time']:.3f} s (calculated) "
          f"vs {REPORTED_DELAYS_100UM_S[nt_type]} s (reported)")

# Report the parameter values actually used (from `config`), not hard-coded copies
print("\n2. Key Factors Affecting Delay:")
print(f"   - Finite release duration ({config['T_release_ms']:g} ms)")
print(f"   - Tortuosity (DA: {da_params['lambda']:g}, SERO: {sero_params['lambda']:g})")
print(f"   - Clearance rate ({config['clearance_rate']:g} s⁻¹)")
print(f"   - Restricted volume fraction (α = {config['alpha']:g})")

print("\n3. Scaling Behavior:")
print("   - Delays scale approximately as r^n where n ≈ 1.5-2")
print("   - Not pure r² due to clearance and finite release effects")

print("\n4. Implications for Symbol Period:")
max_delay_s = max(results_100um['DA']['peak_time'], results_100um['SERO']['peak_time'])
if max_delay_s < 1.0:
    print("   - Calculated delays suggest Ts could be shorter than 20s")
    print("   - However, biological constraints (burst intervals) may dominate")
else:
    print("   - Delays support the choice of Ts = 20s")
    print("   - Provides sufficient margin for complete molecular clearance")

# Save key results. Values are cast to built-in float so json can serialise
# numpy scalar types (e.g. float32) that the helpers may return.
results_to_save = {
    'calculated_delays_100um': {
        nt: float(results_100um[nt]['peak_time']) for nt in ('DA', 'SERO')
    },
    'reported_delays_100um': dict(REPORTED_DELAYS_100UM_S),
    'peak_concentrations_nM': {
        nt: float(results_100um[nt]['peak_conc_nM']) for nt in ('DA', 'SERO')
    },
}

output_path = Path('../results/data/propagation_analysis_results.json')
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w') as f:
    json.dump(results_to_save, f, indent=2)

print("\n\nResults saved to ../results/data/propagation_analysis_results.json")