In [7]:
import numpy as np
from scipy import signal
from scipy.integrate import solve_ivp
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
from enum import Enum
import matplotlib.pyplot as plt

# Basic parameter classes and enums
class CompType(Enum):
    AIS = "axon_initial_segment"
    SOMA = "soma"
    DEND = "dendrite"
    NODE = "ranvier_node"

@dataclass
class ChannelParams:
    """Channel parameters including calcium dependence"""
    g_max: float
    E_rev: float
    v_half_m: float
    k_m: float
    tau_m: float
    v_half_h: Optional[float] = None
    k_h: Optional[float] = None
    tau_h: Optional[float] = None
    ca_dependent: bool = False
    ca_half: Optional[float] = None
    k_ca: Optional[float] = None
    resurgent: bool = False
    res_tau: Optional[float] = None
    res_voltage: Optional[float] = None

@dataclass
class SynapticInput:
    """Synaptic input parameters"""
    g_max: float
    E_rev: float
    tau_rise: float
    tau_decay: float
    n_synapses: int = 1
    firing_rate: float = 0.0

class CalciumDynamics:
    """Calcium handling with compartment-specific buffering"""
    def __init__(self, comp_type: CompType):
        self.comp_type = comp_type
        self.init_buffer_properties()
        self.init_coupling_domains()
        self.Ca_fast = 0.06
        self.Ca_slow = 0.06
        
    def init_buffer_properties(self):
        """Initialize calcium parameters with stronger buffering"""
        self.buffer_density = {
            CompType.DEND: 2e-8,
            CompType.SOMA: 5e-8,
            CompType.AIS: 5e-8,
            CompType.NODE: 5e-7
        }[self.comp_type]
        
        self.params = {
            'ca_base': 0.05,  # Slightly lower baseline
            'ca_tau_fast': 0.5,  # Faster calcium dynamics
            'ca_tau_slow': 20.0,  # Faster slow component
            'kappa_fast': 200.0 * self.buffer_density,  # Stronger buffering
            'kappa_slow': 400.0 * self.buffer_density,
            'pump_rate': 1.0,  # Increased pump rate
            'conversion_factor': 0.05  # Reduced calcium entry
        }
        
    def init_coupling_domains(self):
        """Initialize calcium channel-to-KCa coupling domains"""
        self.coupling_domains = {
            'Cav2_1': {  # P/Q-type selective coupling
                'KCa1_1': 1.0,  # Strong coupling to BK
                'KCa2_2': 0.8,  # Strong coupling to SK
                'KCa3_1': 0.6   # Moderate coupling to IK
            },
            'Cav3_1': {  # T-type coupling
                'KCa1_1': 0.2,
                'KCa2_2': 0.3,
                'KCa3_1': 0.2
            }
        }

    def update_calcium(self, I_ca: Dict[str, float], dt: float) -> Tuple[float, float]:
        """Update calcium concentrations"""
        # Total calcium current
        I_ca_total = sum(I_ca.values())
        
        # Calcium flux
        ca_influx = -I_ca_total * self.params['conversion_factor']
        
        # Fast buffer updates
        dCa_fast = (ca_influx - (self.Ca_fast - self.params['ca_base']) / 
                   self.params['ca_tau_fast']) / self.params['kappa_fast']
        
        # Slow buffer updates
        dCa_slow = (ca_influx - (self.Ca_slow - self.params['ca_base']) / 
                   self.params['ca_tau_slow']) / self.params['kappa_slow']
        
        # Update concentrations
        self.Ca_fast += dCa_fast * dt
        self.Ca_slow += dCa_slow * dt
        
        # Ensure non-negative concentrations
        self.Ca_fast = max(self.Ca_fast, self.params['ca_base'])
        self.Ca_slow = max(self.Ca_slow, self.params['ca_base'])
        
        return self.Ca_fast, self.Ca_slow

    def get_domain_calcium(self, channel: str) -> float:
        """Get calcium concentration for specific channel domains"""
        if channel in self.coupling_domains:
            base_ca = self.Ca_fast + self.Ca_slow
            couplings = self.coupling_domains[channel]
            return sum(c * base_ca for c in couplings.values()) / len(couplings)
        return self.Ca_fast + self.Ca_slow

class IonicChannels:
    """Channel management and current computation"""
    def __init__(self, comp_type: CompType):
        self.comp_type = comp_type
        self.channels = {}
        self.gate_states = {}
        self.init_channels()
        
    def init_channels(self):
        """Initialize all channel types"""
        self.add_sodium_channels()
        self.add_potassium_channels()
        self.add_calcium_channels()
        
    def add_sodium_channels(self):
        """Add Nav1.6 channels with resurgent component"""
        base_nav_density = 1.0 if self.comp_type == CompType.AIS else 0.3
        
        # Regular Nav1.6
        self.channels['Nav1_6'] = ChannelParams(
            g_max=base_nav_density,
            E_rev=60.0,
            v_half_m=-40.0,
            k_m=6.0,
            tau_m=0.02,
            v_half_h=-65.0,
            k_h=-6.0,
            tau_h=0.5
        )
        
        # Resurgent component
        self.channels['Nav1_6R'] = ChannelParams(  # Changed name to avoid underscore
            g_max=0.2 * base_nav_density,
            E_rev=60.0,
            v_half_m=-45.0,
            k_m=5.0,
            tau_m=1.0,
            v_half_h=-60.0,
            k_h=-8.0,
            tau_h=5.0,
            resurgent=True,
            res_tau=2.0,
            res_voltage=-40.0
        )

    def add_potassium_channels(self):
        """Add K+ channels with adjusted conductances"""
        # Increased conductances for proper repolarization
        k_channels = {
            'Kv1_1': (0.2, -35.0, 9.0, 0.5),   # Faster
            'Kv1_5': (0.2, -30.0, 10.0, 1.0),
            'Kv3_3': (0.3, -20.0, 10.0, 0.3),  # Increased and faster
            'Kv3_4': (0.3, -15.0, 8.0, 0.2),   # Increased and faster
            'Kv4_3': (0.2, -50.0, 14.0, 2.0)
        }
        
        for channel, (g, v_m, k_m, tau) in k_channels.items():
            self.add_channel(channel, ChannelParams(
                g_max=g,
                E_rev=-77.0,
                v_half_m=v_m,
                k_m=k_m,
                tau_m=tau
            ))
    def add_calcium_channels(self):
        """Add calcium channels with proper densities"""
        # P/Q-type (Cav2.1)
        self.add_channel('Cav2_1', ChannelParams(
            g_max=0.002,
            E_rev=120.0,
            v_half_m=-25.0,
            k_m=8.0,
            tau_m=0.5,
            v_half_h=-40.0,
            k_h=-6.0,
            tau_h=5.0
        ))
        
        # T-type channels
        t_type_configs = {
            'Cav3_1': (-50.0, 7.0, -65.0, -5.0),
            'Cav3_2': (-45.0, 6.5, -70.0, -4.5),
            'Cav3_3': (-55.0, 7.5, -75.0, -5.5)
        }
        
        for channel, (v_m, k_m, v_h, k_h) in t_type_configs.items():
            self.add_channel(channel, ChannelParams(
                g_max=0.0005,
                E_rev=120.0,
                v_half_m=v_m,
                k_m=k_m,
                tau_m=1.0,
                v_half_h=v_h,
                k_h=k_h,
                tau_h=10.0
            ))
            
        # Calcium-activated K+ channels
        ca_k_channels = {
            'KCa1_1': (0.01, 1.0, 0.2),  # BK
            'KCa2_2': (0.004, 0.3, 0.2),  # SK
            'KCa3_1': (0.004, 0.5, 0.2)   # IK
        }
        
        for channel, (g, ca_half, k_ca) in ca_k_channels.items():
            self.add_channel(channel, ChannelParams(
                g_max=g,
                E_rev=-77.0,
                v_half_m=0.0,
                k_m=1.0,
                tau_m=2.0,
                ca_dependent=True,
                ca_half=ca_half,
                k_ca=k_ca
            ))
    def add_channel(self, name: str, params: ChannelParams):
        """Add a channel and initialize its states"""
        self.channels[name] = params
        self.gate_states[f"{name}_m"] = 0.0
        if params.v_half_h is not None:
            self.gate_states[f"{name}_h"] = 1.0
        if params.resurgent:
            self.gate_states[f"{name}_res"] = 0.0
        if params.ca_dependent:
            self.gate_states[f"{name}_ca"] = 0.0

class WaveAnalysis:
    """AM/FM wave analysis"""
    def __init__(self):
        self.sampling_rate = 10000  # Hz
        self.carrier_band = (30, 80)  # Hz for simple spikes
        self.modulation_band = (0.5, 8)  # Hz for envelope
        
    def analyze_modulation(self, signal_data: np.ndarray, dt: float) -> Dict:
        """Analyze AM/FM components of a signal"""
        # Compute carrier frequency components
        freqs, times, Sxx = signal.spectrogram(
            signal_data,
            fs=1/dt,
            nperseg=1024,
            noverlap=512
        )
        
        # Extract envelope
        analytic_signal = signal.hilbert(signal_data)
        amplitude_envelope = np.abs(analytic_signal)
        
        # Compute modulation frequency
        env_freqs, env_times, env_Sxx = signal.spectrogram(
            amplitude_envelope,
            fs=1/dt,
            nperseg=2048,
            noverlap=1024
        )
        
        # Compute modulation index
        mod_index = np.ptp(amplitude_envelope) / (2 * np.mean(amplitude_envelope))
        
        return {
            'carrier_freqs': freqs,
            'carrier_times': times,
            'carrier_power': Sxx,
            'envelope': amplitude_envelope,
            'mod_freqs': env_freqs,
            'mod_times': env_times,
            'mod_power': env_Sxx,
            'mod_index': mod_index
        }

class Compartment:
    """Base compartment class"""
    def __init__(self, comp_type: CompType, volume: float):
        self.type = comp_type
        self.volume = volume
        self.Cm = 1.0  # µF/cm²
        
        # Initialize mechanisms
        self.ionic_channels = IonicChannels(comp_type)
        self.calcium_dynamics = CalciumDynamics(comp_type)
        
        # State variables
        self.V = -65.0
        self.time = 0.0
        self.last_spike_time = -1000.0
        
        # Add leak conductance
        self.g_leak = 0.002  # S/cm²
        self.E_leak = -65  # mV

def compute_currents(self) -> Dict[str, float]:
    """Compute all ionic currents"""
    currents = {}
    
    # Channel currents
    for channel, params in self.ionic_channels.channels.items():
        try:
            # Check available states for debugging
            available_states = self.ionic_channels.gate_states.keys()
            if f"{channel}_m" not in available_states:
                print(f"Channel {channel} states not found. Available states: {available_states}")
                continue
                
            # Basic activation
            I = params.g_max * self.ionic_channels.gate_states[f"{channel}_m"]
            
            # Inactivation if present
            if params.v_half_h is not None and f"{channel}_h" in available_states:
                I *= self.ionic_channels.gate_states[f"{channel}_h"]
                
            # Calcium dependence
            if params.ca_dependent:
                ca_level = self.calcium_dynamics.get_domain_calcium(channel)
                I *= 1.0 / (1.0 + (params.ca_half/ca_level)**params.k_ca)
                
            # Resurgent gate if present
            if params.resurgent and f"{channel}_res" in available_states:
                I *= self.ionic_channels.gate_states[f"{channel}_res"]
                
            # Driving force
            I *= (self.V - params.E_rev)
            currents[channel] = I
            
        except KeyError as e:
            print(f"Error processing channel {channel}")
            print(f"Available states: {list(self.ionic_channels.gate_states.keys())}")
            raise
    
        # Leak current
        currents['leak'] = self.g_leak * (self.V - self.E_leak)
        
        return currents
    def update_calcium_pumps(self, dt: float):
        """Enhanced calcium extrusion"""
        # PMCA pump
        pmca_max = 0.5  # μM/ms
        pmca_km = 0.1   # μM
        
        total_ca = self.calcium_dynamics.Ca_fast + self.calcium_dynamics.Ca_slow
        base_ca = self.calcium_dynamics.params['ca_base']
        
        pmca_current = pmca_max * (total_ca - base_ca)**2 / (pmca_km**2 + (total_ca - base_ca)**2)
        
        # NCX exchanger
        ncx_max = 1.0   # μM/ms
        ncx_km = 1.0    # μM
        ncx_current = ncx_max * (total_ca - base_ca) / (ncx_km + (total_ca - base_ca))
        
        # Total extrusion
        total_extrusion = (pmca_current + ncx_current) * dt
        
        # Apply to both components proportionally
        fast_fraction = self.calcium_dynamics.Ca_fast / total_ca
        self.calcium_dynamics.Ca_fast -= total_extrusion * fast_fraction
        self.calcium_dynamics.Ca_slow -= total_extrusion * (1 - fast_fraction)
    def update(self, dt: float, I_ext: float = 0.0) -> bool:
        """Update compartment state"""
        self.time += dt
        
        try:
            # Compute currents
            currents = self.compute_currents()
            I_total = sum(currents.values()) - I_ext
            
            # Update voltage
            dV = -I_total / self.Cm
            self.V += dV * dt
            
            # Update calcium
            I_ca = {ch: curr for ch, curr in currents.items() if ch.startswith('Ca')}
            self.calcium_dynamics.update_calcium(I_ca, dt)
            
            # Update channel gates
            self.update_gates(dt)
            
            # Check for spike
            spike = False
            if self.V > 20 and (self.time - self.last_spike_time) > 2.0:
                self.last_spike_time = self.time
                spike = True
                
            return spike
            
        except Exception as e:
            print(f"Error in compartment update: {e}")
            raise

    def update_gates(self, dt: float):
        """Update all channel gates"""
        for channel, params in self.ionic_channels.channels.items():
            # Voltage-dependent activation
            m_inf = 1.0 / (1.0 + np.exp(-(self.V - params.v_half_m) / params.k_m))
            self.ionic_channels.gate_states[f"{channel}_m"] += (
                (m_inf - self.ionic_channels.gate_states[f"{channel}_m"]) / params.tau_m * dt
            )
            
            # Inactivation if present
            if params.v_half_h is not None:
                h_inf = 1.0 / (1.0 + np.exp(-(self.V - params.v_half_h) / params.k_h))
                self.ionic_channels.gate_states[f"{channel}_h"] += (
                    (h_inf - self.ionic_channels.gate_states[f"{channel}_h"]) / params.tau_h * dt
                )
            
            # Resurgent gate if present
            if params.resurgent:
                res_inf = 1.0 / (1.0 + np.exp(-(self.V - params.res_voltage) / 5.0))
                self.ionic_channels.gate_states[f"{channel}_res"] += (
                    (res_inf - self.ionic_channels.gate_states[f"{channel}_res"]) / params.res_tau * dt
                )

class DendriticCompartment(Compartment):
    """Dendritic compartment with wave processing"""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.wave_analyzer = WaveAnalysis()
        self.init_wave_properties()
        
    def init_wave_properties(self):
        """Initialize wave propagation properties"""
        self.wave_props = {
            'propagation_speed': 0.5,  # m/s
            'decay_length': 200e-6,    # 200 µm
            'spatial_freq': 100,       # waves/m
            'temporal_freq': 40        # Hz
        }
        
    def process_waves(self, inputs: np.ndarray, positions: np.ndarray) -> np.ndarray:
        """Process dendritic inputs with wave properties"""
        t = np.arange(len(inputs)) * self.dt
        x = positions
        
        k = 2 * np.pi * self.wave_props['spatial_freq']
        omega = 2 * np.pi * self.wave_props['temporal_freq']
        
        # Standing waves
        standing_waves = np.outer(np.sin(omega * t), np.sin(k * x))
        
        # Traveling waves
        v = self.wave_props['propagation_speed']
        traveling_waves = np.sin(k * (x[:, None] - v * t))
        
        # Apply decay
        decay = np.exp(-x / self.wave_props['decay_length'])
        total_waves = (standing_waves + traveling_waves) * decay[:, None]
        
        return inputs * np.mean(total_waves, axis=0)

class PurkinjeCell:
    """Complete Purkinje cell model"""
    def __init__(self):
        # Create compartments
        self.compartments = {
            'ais': Compartment(CompType.AIS, volume=1e-12),
            'soma': Compartment(CompType.SOMA, volume=5e-12),
            'dend': DendriticCompartment(CompType.DEND, volume=20e-12),
            'node1': Compartment(CompType.NODE, volume=0.5e-12)
        }
        
        # Coupling resistances (MΩ)
        self.coupling_resistances = {
            ('ais', 'soma'): 2.0,    # Reduced resistance
            ('soma', 'dend'): 10.0,  # Adjusted for better propagation
            ('ais', 'node1'): 15.0   # Adjusted for axonal conduction
        }
        
        self.dt = 0.1  # ms
        self.time = 0.0

    def compute_coupling_current(self, comp1: str, comp2: str) -> float:
        """Compute current flow between compartments"""
        key = tuple(sorted([comp1, comp2]))
        if key in self.coupling_resistances:
            V1 = self.compartments[comp1].V
            V2 = self.compartments[comp2].V
            R = self.coupling_resistances[key]
            return (V2 - V1) / R
        return 0.0

    def update(self, dt: float, I_ext: float = 0.0):
        """Update all compartments"""
        self.time += dt
        
        # Compute coupling currents first
        coupling_currents = {}
        for comp1 in self.compartments:
            coupling_currents[comp1] = sum(
                self.compute_coupling_current(comp1, comp2)
                for comp2 in self.compartments
                if comp2 != comp1
            )
        
        # Update each compartment
        spikes = {}
        for comp_name, comp in self.compartments.items():
            # Add coupling current to external current
            I_total = I_ext + coupling_currents[comp_name]
            
            # Update compartment
            spike = comp.update(dt, I_total)
            spikes[comp_name] = spike
        
        return spikes

    def run_simulation(self, duration: float, I_ext: float = 0.0) -> Dict:
        """Run simulation for specified duration"""
        n_steps = int(duration / self.dt)
        
        # Initialize results dictionary
        results = {
            't': np.arange(n_steps) * self.dt,
            'V': {name: np.zeros(n_steps) for name in self.compartments},
            'Ca': {name: np.zeros(n_steps) for name in self.compartments},
            'spikes': {name: [] for name in self.compartments}
        }
        
        # Run simulation
        for i in range(n_steps):
            spikes = self.update(self.dt, I_ext)
            
            # Store results
            for name, comp in self.compartments.items():
                results['V'][name][i] = comp.V
                results['Ca'][name][i] = (comp.calcium_dynamics.Ca_fast + 
                                        comp.calcium_dynamics.Ca_slow)
                if spikes[name]:
                    results['spikes'][name].append(i * self.dt)
        
        return results

def plot_results(results: Dict):
    """Plot simulation results"""
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
    
    # Plot membrane potentials
    for name in results['V']:
        ax1.plot(results['t'], results['V'][name], label=name)
        if name in results['spikes']:
            spike_times = results['spikes'][name]
            if spike_times:
                ax1.plot(spike_times, [20] * len(spike_times), 'o', 
                        color='red', markersize=4)
    
    ax1.set_ylabel('Membrane Potential (mV)')
    ax1.legend()
    ax1.grid(True)
    
    # Plot calcium concentrations
    for name in results['Ca']:
        ax2.plot(results['t'], results['Ca'][name], label=name)
    
    ax2.set_xlabel('Time (ms)')
    ax2.set_ylabel('Ca²⁺ Concentration (µM)')
    ax2.legend()
    ax2.grid(True)
    
    plt.tight_layout()
    plt.show()

# Example usage
if __name__ == "__main__":
    cell = PurkinjeCell()
    results = cell.run_simulation(duration=100.0, I_ext=0.5)
    plot_results(results)

AttributeError: 'Compartment' object has no attribute 'update'