<a href="https://colab.research.google.com/github/Deepu-Sharma/Review/blob/main/JJ_single_and_array.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
from scipy.integrate import solve_ivp
from scipy.signal import find_peaks
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

# Silence every Python warning so notebook output stays uncluttered.
# NOTE(review): this also hides numerical/solver warnings; narrow the filter
# when debugging.
warnings.filterwarnings("ignore")

# --- Plot appearance ---
# Seaborn theme first, then matplotlib rcParams overrides layered on top.
sns.set(style="whitegrid", font_scale=1.2)
plt.rcParams.update({
    'font.family': 'DejaVu Sans',
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'legend.fontsize': 10,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
})

# --- Physical Constants ---
HBAR = 1.0545718e-34              # Reduced Planck's constant (J·s)
ELECTRON_CHARGE = 1.60217662e-19   # Elementary charge (C)
FLUX_QUANTUM_HALF = HBAR / (2 * ELECTRON_CHARGE)  # Reduced flux quantum Φ0/(2π) (V·s)
FLUX_QUANTUM = HBAR * np.pi / ELECTRON_CHARGE      # Magnetic flux quantum Φ0 = h/(2e) (Wb)
# 2e/ħ is the *angular* Josephson constant (dφ/dt = (2e/ħ)·V), in rad/(V·s).
# The frequency form K_J = 2e/h, in Hz/V, is this value divided by 2π.
JOSEPHSON_CONSTANT = 2 * ELECTRON_CHARGE / HBAR     # rad/(V·s)

# --- Josephson Junction Parameters ---
CRITICAL_CURRENT = 1e-6    # Critical current (A)
CAPACITANCE = 1e-12        # Capacitance (F)
SHUNT_RESISTANCE = 50      # Shunt resistance (Ω)

# --- Simulation Parameters ---
SIMULATION_TIME = 20e-9    # Extended simulation duration: 20 ns
dt = 5e-14                 # Time step (s)
TIME_SPAN = (0, SIMULATION_TIME)
# SIMULATION_TIME/dt intervals need one more point than intervals for the grid
# spacing to be exactly dt. round() guards against the ratio landing just
# below an integer in floating point, where int() would truncate.
TIME_EVALUATION_POINTS = np.linspace(0, SIMULATION_TIME,
                                     int(round(SIMULATION_TIME / dt)) + 1)

# --- Define a Ramped Bias Current ---
# Bias current ramps linearly from 0, reaching 5×I_c at SIMULATION_TIME.
RAMP_RATE = (5 * CRITICAL_CURRENT) / SIMULATION_TIME  # in A/s

def bias_current(t):
    """
    Linearly ramped bias current.

    Input:
      t : time (s); scalar or NumPy array
    Returns:
      Bias current (A), i.e. RAMP_RATE * t (zero at t = 0)
    """
    current = t * RAMP_RATE
    return current

# --- Stochastic Noise Function ---
# Noise amplitude raised to 1e-6 A, i.e. the same magnitude as I_c.
noise_amp = 1e-6
# One Gaussian sample per point of a fixed 10000-point grid spanning the run.
# noise_func interpolates between samples, so the ODE solver sees a
# continuous, piecewise-linear forcing signal.
# NOTE(review): uses NumPy's global RNG without a seed, so runs are not
# reproducible unless np.random.seed(...) is called earlier.
noise_time_grid = np.linspace(0, SIMULATION_TIME, 10000)
noise_values = np.random.normal(0.0, noise_amp, noise_time_grid.shape)


def noise_func(t):
    """Return the pre-sampled noise at time t (linear interpolation on the grid)."""
    return np.interp(t, noise_time_grid, noise_values)

# --- Enhanced RCSJ Model with Nonlinear Adaptation and Noise ---
def rcsj_adapt_nonlin(t, state, critical_current, capacitance, shunt_resistance,
                      flux_quantum_half, tau_adapt, k_adapt):
    """
    Enhanced RCSJ model that includes stochastic noise and nonlinear adaptation.

    State Variables:
      state[0] = φ        (phase, rad)
      state[1] = V        (voltage, V)
      state[2] = I_adapt  (adaptation current, A)

    Equations:
      dφ/dt       = V / flux_quantum_half

      dV/dt       = ( I_bias(t) + I_noise(t) - I_adapt
                      - critical_current*sin(φ) - V/shunt_resistance ) / capacitance

      dI_adapt/dt = ( k_adapt * tanh(V) - I_adapt ) / tau_adapt

    I_bias(t) comes from the module-level bias_current(t), and I_noise(t) from
    noise_func(t). Both are currents (A), so the noise enters the current
    balance and is divided by the capacitance like every other current term.
    Added directly to dV/dt, an amplitude of ~1e-6 would be ~1e-12 of the other
    terms and have no effect.

    The tanh term is meant as a saturating, synapse-like adaptation. Note that
    its argument is V in volts, so for µV-scale voltages tanh(V) ≈ V.

    Returns:
      [dφ/dt, dV/dt, dI_adapt/dt] as a list (the form solve_ivp accepts).
    """
    phi, voltage, I_adapt = state
    Ib = bias_current(t)
    I_noise = noise_func(t)

    dphi_dt = voltage / flux_quantum_half
    dvoltage_dt = (Ib + I_noise - I_adapt - critical_current * np.sin(phi)
                   - voltage / shunt_resistance) / capacitance
    dI_adapt_dt = (k_adapt * np.tanh(voltage) - I_adapt) / tau_adapt
    return [dphi_dt, dvoltage_dt, dI_adapt_dt]

# --- Adaptation Parameters ---
tau_adapt = 1e-9  # Adaptation time constant (s)
k_adapt = 0.05    # Adaptation strength (A/V)

# --- Solve the Differential Equations ---
initial_state = [0, 0, 0]  # Initial phase, voltage, and adaptation current
solution = solve_ivp(rcsj_adapt_nonlin, TIME_SPAN, initial_state, method='RK45',
                     t_eval=TIME_EVALUATION_POINTS,
                     args=(CRITICAL_CURRENT, CAPACITANCE, SHUNT_RESISTANCE,
                           FLUX_QUANTUM_HALF, tau_adapt, k_adapt),
                     atol=1e-12, rtol=1e-9)
# solve_ivp does not raise when integration fails. It returns success=False,
# and the trajectories stop at the failure time. Report this explicitly;
# otherwise every analysis below silently runs on a truncated simulation.
if not solution.success:
    print(f"WARNING: solve_ivp did not finish (status {solution.status}): {solution.message}")

# Retrieve simulation results
time = solution.t                         # Time array (s)
phi = solution.y[0]                       # Phase (rad)
voltage = solution.y[1]                   # Voltage (V)
I_adapt = solution.y[2]                   # Adaptation current (A)
bias_values = bias_current(time)          # Deterministic bias (A)

# --- Compute Voltage Derivative ---
voltage_derivative = np.gradient(voltage, time)

# --- Spike Detection via Voltage Derivative ---
# Only consider samples where the bias has reached the critical current.
active_mask = bias_values >= CRITICAL_CURRENT
active_time = time[active_mask]
active_deriv = voltage_derivative[active_mask]
active_bias = bias_values[active_mask]

if active_deriv.size > 0:
    # A spike is a dV/dt peak reaching 75% of the largest dV/dt in the active
    # window. Neighbouring peaks must be at least 20 samples apart.
    deriv_threshold = 0.75 * np.max(active_deriv)
    peak_indices, _ = find_peaks(active_deriv, height=deriv_threshold, distance=20)
else:
    # Integer dtype is required here. np.array([]) defaults to float64, and a
    # float array cannot be used to index active_time / active_bias below.
    peak_indices = np.array([], dtype=int)

spike_times = active_time[peak_indices]
bias_at_spikes = active_bias[peak_indices]

# --- Exclude early transients (spikes at or before 5 ns) ---
t_transient = 5e-9
steady_mask = spike_times > t_transient
steady_spike_times, steady_bias = spike_times[steady_mask], bias_at_spikes[steady_mask]

# Interval statistics need at least two steady-state spikes. With fewer, every
# derived array is left empty so later code can test `.size`.
if len(steady_spike_times) <= 1:
    spike_intervals, instantaneous_frequency, bias_for_frequency, mid_spike_times = (
        np.array([]) for _ in range(4)
    )
else:
    spike_intervals = np.diff(steady_spike_times)
    instantaneous_frequency = 1 / spike_intervals  # in Hz
    # Each interval is tagged with the bias at the spike that closes it.
    bias_for_frequency = steady_bias[1:]
    mid_spike_times = (steady_spike_times[:-1] + steady_spike_times[1:]) / 2.0

# NOTE(review): this is the bias at the first spike *after* t_transient, not
# necessarily the switching threshold if spiking began before 5 ns.
estimated_threshold_current = (
    bias_current(steady_spike_times[0]) if steady_spike_times.size > 0 else np.nan
)

# --- Additional Derived Measures (NaN when there are no intervals) ---
if spike_intervals.size == 0:
    mean_isi = std_isi = mean_freq = np.nan
else:
    mean_isi = np.mean(spike_intervals)
    std_isi = np.std(spike_intervals)
    mean_freq = np.mean(instantaneous_frequency)

# Effective drive: bias minus adaptation current.
I_effective = bias_values - I_adapt

# --- Print Simulation Statistics ---
# Collect the report lines first, then emit them in one print call.
_summary_lines = [
    "Enhanced JJ (Neuron-inspired) with Nonlinear Adaptation and Stochastic Noise:",
    f"Simulation Time: {SIMULATION_TIME*1e9:.1f} ns",
    f"Max dV/dt: {np.max(np.abs(voltage_derivative)):.3e} V/s",
    f"Estimated Threshold Current (steady-state, first spike): {estimated_threshold_current:.3e} A",
    f"Number of Steady-State Detected Spikes: {len(steady_spike_times)}",
]
if instantaneous_frequency.size:
    _summary_lines += [
        f"Average Instantaneous Frequency: {mean_freq:.3e} Hz",
        f"ISI: Mean = {mean_isi*1e9:.3f} ns, Std = {std_isi*1e9:.3f} ns",
    ]
else:
    _summary_lines.append("No steady-state spikes detected for frequency analysis.")
_summary_lines.append(f"Effective Drive at final time: {I_effective[-1]*1e6:.3f} μA")
print("\n".join(_summary_lines))

# --- Plotting ---
# Figure 1: main time series, eight stacked panels sharing a time axis in ns.
fig1, axes1 = plt.subplots(8, 1, figsize=(10, 32), dpi=150, facecolor='white')
time_ns = time * 1e9


def _finish_time_panel(ax, ylabel, title):
    """Apply the x/y labels, title, dashed grid and legend shared by every Figure-1 panel."""
    ax.set_xlabel('Time (ns)')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend(loc='upper left')


# 1. Input Bias Current
axes1[0].plot(time_ns, bias_values*1e6, color='royalblue', linewidth=2, label='Bias Current')
_finish_time_panel(axes1[0], 'Current (μA)', 'Ramped Input Bias Current')

# 2. Adaptation Current. The subscript must sit inside $...$ for matplotlib
#    mathtext to render it; a bare 'I_{adapt}' is displayed literally.
axes1[1].plot(time_ns, I_adapt*1e6, color='teal', linewidth=2, label='Adaptation Current')
_finish_time_panel(axes1[1], r'$I_{adapt}$ (μA)', 'Adaptation Current Evolution')

# 3. Effective Drive (Bias - Adaptation)
axes1[2].plot(time_ns, I_effective*1e6, color='darkmagenta', linewidth=2, label='Effective Drive')
_finish_time_panel(axes1[2], 'Effective Current (μA)', 'Effective Drive (Bias - Adaptation)')

# 4. Overlay: Bias, Adaptation, and Effective Drive
for _series, _colour, _label in ((bias_values, 'royalblue', 'Bias'),
                                 (I_adapt, 'teal', 'Adaptation'),
                                 (I_effective, 'darkmagenta', 'Effective Drive')):
    axes1[3].plot(time_ns, _series*1e6, color=_colour, linewidth=2, label=_label)
_finish_time_panel(axes1[3], 'Current (μA)', 'Overlay: Bias, Adaptation, and Effective Drive')

# 5. Unwrapped Phase Dynamics
unwrapped_phase = np.unwrap(phi)
axes1[4].plot(time_ns, unwrapped_phase, color='forestgreen', linewidth=2, label='Unwrapped Phase')
_finish_time_panel(axes1[4], 'Phase (rad)', 'Unwrapped Phase Dynamics')

# 6. Voltage Trace
axes1[5].plot(time_ns, voltage*1e6, color='crimson', linewidth=1.5, label='Voltage')
_finish_time_panel(axes1[5], 'Voltage (μV)', 'Voltage Trace')

# 7. Voltage Derivative, with the detected steady-state spikes marked
axes1[6].plot(time_ns, voltage_derivative, color='darkorange', linewidth=1.5, label='dV/dt')
if len(steady_spike_times) > 0:
    axes1[6].plot(steady_spike_times*1e9, np.interp(steady_spike_times, time, voltage_derivative),
                  'o', color='gold', markersize=5, label='Detected Spikes')
_finish_time_panel(axes1[6], 'dV/dt (V/s)', 'Voltage Derivative (Spike Detection)')

# 8. Flux quanta: |φ| / 2π, the number of 2π phase windings
flux_quanta_over_time = np.abs(unwrapped_phase) / (2 * np.pi)
axes1[7].plot(time_ns, flux_quanta_over_time, color='purple', linewidth=2, label='Flux Quanta')
_finish_time_panel(axes1[7], 'Flux Quanta', 'Flux Quanta Emission')

plt.tight_layout()
plt.show()

# --- Figure 2: Frequency and ISI Analysis ---
fig2, axes2 = plt.subplots(2, 2, figsize=(12, 10), dpi=150, facecolor='white')


def _no_data(ax, message):
    """Write a centred placeholder message on an axis with nothing to plot."""
    ax.text(0.5, 0.5, message, horizontalalignment='center', verticalalignment='center')


# f–I Curve (Steady-State): frequency against the bias at the closing spike
axes2[0,0].scatter(bias_for_frequency*1e6, instantaneous_frequency*1e-9,
                    color='magenta', s=50, zorder=3)
axes2[0,0].set_xlabel('Bias at Spike (μA)')
axes2[0,0].set_ylabel('Frequency (GHz)')
axes2[0,0].set_title('f–I Curve (Steady-State Spikes)')
axes2[0,0].grid(True, linestyle='--', alpha=0.7)

# Instantaneous Frequency vs. Time (plotted at interval midpoints)
if mid_spike_times.size:
    axes2[0,1].plot(mid_spike_times*1e9, instantaneous_frequency*1e-9, 'o-', color='darkred')
    axes2[0,1].set_xlabel('Time (ns)')
    axes2[0,1].set_ylabel('Frequency (GHz)')
    axes2[0,1].set_title('Instantaneous Frequency vs. Time')
    axes2[0,1].grid(True, linestyle='--', alpha=0.7)
else:
    _no_data(axes2[0,1], 'No Frequency Data')

# ISI Histogram (Steady-State)
if spike_intervals.size:
    axes2[1,0].hist(spike_intervals*1e9, bins=50, color='green', alpha=0.7)
    axes2[1,0].set_xlabel('Inter-Spike Interval (ns)')
    axes2[1,0].set_ylabel('Count')
    axes2[1,0].set_title('ISI Histogram (Steady-State)')
    axes2[1,0].grid(True, linestyle='--', alpha=0.7)
else:
    _no_data(axes2[1,0], 'No ISI Data')

# Instantaneous frequency vs. effective drive (bias - adaptation). The drive is
# sampled at the spike closing each interval, matching bias_for_frequency
# (steady_bias[1:]), so this panel complements rather than repeats the f–I
# curve in the top-left panel.
if instantaneous_frequency.size:
    drive_at_spikes = np.interp(steady_spike_times[1:], time, I_effective)
    axes2[1,1].scatter(drive_at_spikes*1e6, instantaneous_frequency*1e-9,
                        color='blue', s=50, zorder=3)
    axes2[1,1].set_xlabel('Effective Drive at Spike (μA)')
    axes2[1,1].set_ylabel('Frequency (GHz)')
    axes2[1,1].set_title('Frequency vs. Effective Drive (Steady-State)')
    axes2[1,1].grid(True, linestyle='--', alpha=0.7)
else:
    _no_data(axes2[1,1], 'No Data')

plt.tight_layout()
plt.show()


In [None]:
# --- Figure 3: Applied Current (Bias) vs Output Voltage ---
# Parametric plot of the junction voltage against the instantaneous ramped bias.
fig3 = plt.figure(figsize=(10, 6), dpi=300, facecolor='white')
ax3 = fig3.gca()
ax3.plot(bias_values * 1e6, voltage * 1e6, color='darkblue', linewidth=1.5)
ax3.set_xlabel('Applied Current (μA)')
ax3.set_ylabel('Output Voltage (μV)')
ax3.set_title('Bias Current vs Output Voltage')
ax3.grid(True, linestyle='--', alpha=0.7)
fig3.tight_layout()
plt.show()


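**4×4 Josephson Junction Array**

The simulator below handles a general N×N grid of junctions with nearest-neighbour coupling; this section applies it to a 4×4 array.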

In [None]:
# Re-apply the seaborn theme and matplotlib font settings for the array section.
sns.set(style="whitegrid", font_scale=1.2)
for _rc_key, _rc_value in {
    'font.family': 'DejaVu Sans',
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'legend.fontsize': 10,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
}.items():
    plt.rcParams[_rc_key] = _rc_value

# ================= Optional Accelerators =================
# numba (JIT compilation) and joblib (parallel execution) are both optional.
# Each flag records whether its import succeeded, so later code can fall back
# to pure Python / sequential execution.
try:
    from numba import njit
except ImportError:
    numba_available = False
else:
    numba_available = True

try:
    from joblib import Parallel, delayed
except ImportError:
    parallel_available = False
else:
    parallel_available = True

# ================= Two-Pointer Helper Functions for STDP =================
# Each rule is written once, in plain Python, and compiled with numba when it
# is available, so the accelerated and fallback paths cannot drift apart.
# Both helpers require post_arr and pre_arr sorted ascending (find_peaks
# indices into a monotonic time array are). The lower pointer j only moves
# forward, so a call costs O(len(post) + len(pre) + pairs within the window).

def _compute_stdp_two_pointer_py(post_arr, pre_arr, window, A_plus, tau_plus, A_minus, tau_minus):
    """
    Sum pairwise STDP contributions between two sorted spike trains.

    For every pair with |post - pre| <= window, with lag = post - pre:
      lag > 0 (pre before post): + A_plus  * exp(-lag / tau_plus)
      lag < 0 (post before pre): - A_minus * exp( lag / tau_minus)
      lag == 0:                  no contribution
    """
    result = 0.0
    n_post = post_arr.shape[0]
    n_pre = pre_arr.shape[0]
    j = 0
    for i in range(n_post):
        # Skip pre spikes that fall before this post spike's window.
        while j < n_pre and pre_arr[j] < post_arr[i] - window:
            j += 1
        k = j
        while k < n_pre and pre_arr[k] <= post_arr[i] + window:
            lag = post_arr[i] - pre_arr[k]
            if lag > 0:
                result += A_plus * np.exp(-lag / tau_plus)
            elif lag < 0:
                result -= A_minus * np.exp(lag / tau_minus)
            k += 1
    return result


def _compute_hebbian_two_pointer_py(post_arr, pre_arr, window):
    """Count (post, pre) spike pairs with |post - pre| <= window."""
    count = 0
    n_post = post_arr.shape[0]
    n_pre = pre_arr.shape[0]
    j = 0
    for i in range(n_post):
        while j < n_pre and pre_arr[j] < post_arr[i] - window:
            j += 1
        k = j
        while k < n_pre and pre_arr[k] <= post_arr[i] + window:
            count += 1
            k += 1
    return count


if numba_available:
    compute_stdp_two_pointer = njit(_compute_stdp_two_pointer_py)
    compute_hebbian_two_pointer = njit(_compute_hebbian_two_pointer_py)
else:
    compute_stdp_two_pointer = _compute_stdp_two_pointer_py
    compute_hebbian_two_pointer = _compute_hebbian_two_pointer_py

# ================= Numba-Accelerated ODE Dynamics =================
# Defined only when numba imported successfully. The simulator reaches these
# functions only through use_numba, which is False without numba.
if numba_available:
    @njit
    def rcsj_dynamics_numba(t, state, N, critical_current, capacitance,
                            shunt_resistance, simulation_time, noise_base_amp,
                            noise_grid, noise_values, tau_adapt, k_adapt,
                            ramp_rate, FLUX_QUANTUM_HALF,
                            w_above, w_below, w_left, w_right):
        """
        Compiled right-hand side of the N×N array ODE.

        `state` is the flat vector [phi, V, I_adapt], each an N*N block in
        row-major order. Returns d(state)/dt in the same layout. Voltages
        are clipped to ±2 mV, synaptic and effective currents to ±20 µA, the
        ramped bias saturates at 2·I_c, and the interpolated noise is clipped
        to ±noise_base_amp. The input `state` is not modified.
        """
        n2 = N * N
        phi = state[0:n2].reshape(N, N)
        # .copy() is essential. reshape() of a contiguous slice is a view, so
        # clipping in place would write the clipped voltages back into the
        # caller's `state` (the LSODA solver's y, or the Euler integrator's
        # state vector). That would also diverge from the pure-Python path,
        # where np.clip returns a new array.
        V = state[n2:2 * n2].reshape(N, N).copy()
        I_adapt = state[2 * n2:].reshape(N, N)

        # Clip voltage to ±2 mV.
        for i in range(N):
            for j in range(N):
                if V[i, j] < -2e-3:
                    V[i, j] = -2e-3
                elif V[i, j] > 2e-3:
                    V[i, j] = 2e-3

        # Nearest-neighbour synaptic current. Each neighbour's voltage is
        # scaled by the weight stored at the receiving junction (i, j).
        I_syn = np.zeros((N, N))
        if N > 1:
            for i in range(1, N):
                for j in range(N):
                    I_syn[i, j] += w_above[i, j] * V[i-1, j]
            for i in range(N-1):
                for j in range(N):
                    I_syn[i, j] += w_below[i, j] * V[i+1, j]
            for i in range(N):
                for j in range(1, N):
                    I_syn[i, j] += w_left[i, j] * V[i, j-1]
            for i in range(N):
                for j in range(N-1):
                    I_syn[i, j] += w_right[i, j] * V[i, j+1]
        # Clip I_syn to ±20 µA.
        for i in range(N):
            for j in range(N):
                if I_syn[i, j] < -2e-5:
                    I_syn[i, j] = -2e-5
                elif I_syn[i, j] > 2e-5:
                    I_syn[i, j] = 2e-5

        # Linearly ramped bias, saturating at twice the critical current.
        bias = ramp_rate * t
        if bias > critical_current * 2:
            bias = critical_current * 2

        # Effective current, clipped to ±20 µA.
        I_eff = np.zeros((N, N))
        for i in range(N):
            for j in range(N):
                I_eff[i, j] = bias + I_syn[i, j] - I_adapt[i, j]
                if I_eff[i, j] < -2e-5:
                    I_eff[i, j] = -2e-5
                elif I_eff[i, j] > 2e-5:
                    I_eff[i, j] = 2e-5

        # Noise: pre-sampled values, interpolated at t and modulated by ±30%
        # over one period of the simulation, then clipped to ±noise_base_amp.
        amp = noise_base_amp * (1 + 0.3 * np.sin(2 * np.pi * t / simulation_time))
        noise = np.interp(t, noise_grid, noise_values) * amp
        if noise < -noise_base_amp:
            noise = -noise_base_amp
        elif noise > noise_base_amp:
            noise = noise_base_amp

        # Derivatives.
        # NOTE(review): `noise` is added directly to dV/dt (V/s) rather than
        # divided by the capacitance. Confirm whether noise_amp in the config
        # is meant as a current or as a voltage rate.
        dphi_dt = V / FLUX_QUANTUM_HALF
        dV_dt = (I_eff - critical_current * np.sin(phi) - V / shunt_resistance) / capacitance + noise
        dI_adapt_dt = (k_adapt * np.tanh(V / 1e-6) - I_adapt) / tau_adapt

        # Flatten back into the [phi, V, I_adapt] row-major layout.
        result = np.empty(state.size)
        for i in range(N):
            for j in range(N):
                k = i * N + j
                result[k] = dphi_dt[i, j]
                result[n2 + k] = dV_dt[i, j]
                result[2 * n2 + k] = dI_adapt_dt[i, j]
        return result

    # ---------------- Explicit Euler Integrator (Numba compiled) ----------------
    @njit
    def euler_integrator(state0, t0, dt, nsteps, N, critical_current, capacitance,
                         shunt_resistance, simulation_time, noise_base_amp,
                         noise_grid, noise_values, tau_adapt, k_adapt,
                         ramp_rate, FLUX_QUANTUM_HALF,
                         w_above, w_below, w_left, w_right):
        """
        Fixed-step explicit Euler integration of rcsj_dynamics_numba.

        Returns an array of shape (nsteps + 1, state0.size): row 0 is state0,
        and row k is the state after k steps of size dt starting at t0.
        """
        state = state0.copy()
        sol = np.empty((nsteps + 1, state0.size))
        sol[0, :] = state0
        t = t0
        for i in range(1, nsteps + 1):
            dstate = rcsj_dynamics_numba(t, state, N, critical_current, capacitance,
                                         shunt_resistance, simulation_time, noise_base_amp,
                                         noise_grid, noise_values, tau_adapt, k_adapt,
                                         ramp_rate, FLUX_QUANTUM_HALF,
                                         w_above, w_below, w_left, w_right)
            state = state + dt * dstate
            sol[i, :] = state
            t += dt
        return sol

# ================= The Simulation Class =================
class JJArraySimulator:
    """
    Simulates a 2D array of Josephson Junctions.

    This version includes:
      • A choice of integration methods: either LSODA or a fully compiled explicit Euler integrator.
      • Two-pointer algorithms for STDP weight updates.
      • Parallel implementations for weight update and chaos analysis.
      • Tuned parameters that, for example, (a) reduce the noise and (b) slow the adaptation dynamics.
        These changes lessen fluctuations in inter-spike intervals and, in our tests, produce a lower Lyapunov value.
    """
    HBAR = 1.0545718e-34
    ELECTRON_CHARGE = 1.60217662e-19
    FLUX_QUANTUM_HALF = HBAR / (2 * ELECTRON_CHARGE)

    def __init__(self, config, verbose=True, use_numba=False, use_parallel=False):
        self.N = config['N']
        self.verbose = verbose
        self.use_numba = use_numba and numba_available
        self.use_parallel = use_parallel and parallel_available
        self.integration_method = config.get("integration_method", "lsoda")  # "euler" or "lsoda"
        self.state_size = 3 * self.N * self.N
        self.critical_current = config['critical_current']
        self.capacitance = config['capacitance']
        self.shunt_resistance = config['shunt_resistance']
        self.simulation_time = config['simulation_time']
        self.dt = config['dt']
        # For LSODA, create a time grid.
        self.time_points = np.linspace(0, self.simulation_time, int(self.simulation_time/self.dt)+1)

        # Precompute noise arrays.
        self.noise_base_amp = config['noise_amp']
        self.noise_grid = np.linspace(0, self.simulation_time, 10000)
        self.noise_values = np.random.normal(0.0, self.noise_base_amp, self.noise_grid.shape)

        self.tau_adapt = config['tau_adapt']
        self.k_adapt = config['k_adapt']
        self.learning_params = config['learning_params']

        self.weights = self.initialize_weights(config['initial_weight'])
        self.ramp_rate = config['ramp_rate']

    def initialize_weights(self, initial_weight):
        return tuple(np.full((self.N, self.N), initial_weight) for _ in range(4))

    def rcsj_dynamics(self, t, state):
        # Pure Python fallback (if Numba not used).
        phi = state[0:self.N*self.N].reshape(self.N, self.N)
        V = state[self.N*self.N: 2*self.N*self.N].reshape(self.N, self.N)
        I_adapt = state[2*self.N*self.N:].reshape(self.N, self.N)
        V = np.clip(V, -2e-3, 2e-3)
        w_above, w_below, w_left, w_right = self.weights
        I_syn = np.zeros((self.N, self.N))
        if self.N > 1:
            I_syn[1:,:] += w_above[1:,:] * V[:-1,:]
            I_syn[:-1,:] += w_below[:-1,:] * V[1:,:]
            I_syn[:,1:] += w_left[:,1:] * V[:,:-1]
            I_syn[:,:-1] += w_right[:,:-1] * V[:,1:]
        I_syn = np.clip(I_syn, -2e-5, 2e-5)
        bias = np.clip(self.ramp_rate*t, 0, self.critical_current*2)
        I_eff = np.clip(bias + I_syn - I_adapt, -2e-5, 2e-5)
        amp = self.noise_base_amp*(1+0.3*np.sin(2*np.pi*t/self.simulation_time))
        noise = np.clip(np.interp(t, self.noise_grid, self.noise_values)*amp,
                        -self.noise_base_amp, self.noise_base_amp)
        dphi_dt = V/self.FLUX_QUANTUM_HALF
        dV_dt = (I_eff - self.critical_current*np.sin(phi) - V/self.shunt_resistance)/self.capacitance + noise
        dI_adapt_dt = (self.k_adapt*np.tanh(V/1e-6)-I_adapt)/self.tau_adapt
        return np.concatenate([dphi_dt.flatten(), dV_dt.flatten(), dI_adapt_dt.flatten()])

    def run_simulation(self):
        """Run the simulation using the chosen integration method."""
        if self.integration_method == "euler":
            self.run_simulation_euler()
        else:
            self.run_simulation_lsoda()

    def run_simulation_lsoda(self):
        """Integration using LSODA (with relaxed tolerances)."""
        initial_state = np.zeros(self.state_size)
        initial_state[0:self.N*self.N] = np.random.uniform(-0.2, 0.2, self.N*self.N)
        if self.use_numba:
            def dynamics_jit(t, state):
                return rcsj_dynamics_numba(
                    t, state, self.N, self.critical_current, self.capacitance,
                    self.shunt_resistance, self.simulation_time, self.noise_base_amp,
                    self.noise_grid, self.noise_values, self.tau_adapt, self.k_adapt,
                    self.ramp_rate, self.FLUX_QUANTUM_HALF,
                    self.weights[0], self.weights[1], self.weights[2], self.weights[3]
                )
            dynamics = dynamics_jit
        else:
            dynamics = self.rcsj_dynamics
        solution = solve_ivp(
            dynamics, (0, self.simulation_time), initial_state, method="LSODA",
            t_eval=self.time_points, atol=1e-8, rtol=1e-6, max_step=self.dt*10
        )
        if self.verbose:
            print(f"LSODA status: {solution.status}, Message: {solution.message}")
            print(f"Function evaluations: {solution.nfev}")
        self.time = solution.t
        sol = solution.y
        self.phi = sol[0:self.N*self.N].reshape(self.N, self.N, -1)
        self.V = sol[self.N*self.N:2*self.N*self.N].reshape(self.N, self.N, -1)
        self.I_adapt = sol[2*self.N*self.N:].reshape(self.N, self.N, -1)

    def run_simulation_euler(self):
        """Integration using an explicit Euler method fully compiled with Numba."""
        nsteps = int(self.simulation_time / self.dt)
        initial_state = np.zeros(self.state_size)
        initial_state[0:self.N*self.N] = np.random.uniform(-0.2, 0.2, self.N*self.N)
        if self.use_numba:
            sol = euler_integrator(
                initial_state, 0.0, self.dt, nsteps, self.N, self.critical_current, self.capacitance,
                self.shunt_resistance, self.simulation_time, self.noise_base_amp, self.noise_grid,
                self.noise_values, self.tau_adapt, self.k_adapt, self.ramp_rate, self.FLUX_QUANTUM_HALF,
                self.weights[0], self.weights[1], self.weights[2], self.weights[3]
            )
        else:
            sol = np.empty((nsteps+1, self.state_size))
            sol[0] = initial_state
            state = initial_state.copy()
            t = 0.0
            for i in range(1, nsteps+1):
                dstate = self.rcsj_dynamics(t, state)
                state = state + self.dt * dstate
                sol[i] = state
                t += self.dt
        self.time = np.linspace(0, self.simulation_time, nsteps+1)
        self.phi = sol[:, 0:self.N*self.N].T.reshape(self.N, self.N, -1)
        self.V = sol[:, self.N*self.N:2*self.N*self.N].T.reshape(self.N, self.N, -1)
        self.I_adapt = sol[:, 2*self.N*self.N:].T.reshape(self.N, self.N, -1)

    def detect_spikes(self, threshold_factor=0.05, min_distance=8):
        """Detect spikes in each neuron's voltage trace using adaptive thresholds."""
        spike_times = [[[] for _ in range(self.N)] for _ in range(self.N)]
        total_spikes = 0
        for i in range(self.N):
            for j in range(self.N):
                V_ij = self.V[i, j, :]
                std_v = np.std(V_ij)
                threshold = max(threshold_factor * std_v, 1e-7)
                peaks, _ = find_peaks(V_ij, height=threshold, distance=min_distance)
                spike_times[i][j] = self.time[peaks]
                total_spikes += len(peaks)
                if self.verbose:
                    print(f"Neuron ({i},{j}): {len(peaks)} spikes")
        if self.verbose:
            print(f"Total spikes: {total_spikes}")
        return spike_times

    # --------------- Weight Update (STDP) ---------------
    def update_weights(self, spike_times, rule="stdp"):
        if self.use_parallel:
            return self._update_weights_parallel(spike_times, rule)
        else:
            return self._update_weights_sequential(spike_times, rule)

    def _update_weights_sequential(self, spike_times, rule="stdp"):
        w_above, w_below, w_left, w_right = [w.copy() for w in self.weights]
        params = self.learning_params[rule]
        window = params.get("window", max(params.get("tau_plus",0), params.get("tau_minus",0))*5)
        post_limit = 100
        for i in range(self.N):
            for j in range(self.N):
                post_spikes = spike_times[i][j]
                if len(post_spikes)==0:
                    continue
                post_arr = np.array(post_spikes[:min(post_limit, len(post_spikes))])
                for ni, nj, w_matrix, valid in [
                    (i-1, j, w_above, i>0),
                    (i+1, j, w_below, i<self.N-1),
                    (i, j-1, w_left, j>0),
                    (i, j+1, w_right, j<self.N-1)
                ]:
                    if not valid: continue
                    pre_spikes = spike_times[ni][nj]
                    if len(pre_spikes)==0: continue
                    pre_arr = np.array(pre_spikes)
                    if rule=="stdp":
                        Delta_w = compute_stdp_two_pointer(post_arr, pre_arr, window,
                                                           params["A_plus"], params["tau_plus"],
                                                           params["A_minus"], params["tau_minus"])
                    elif rule=="hebbian":
                        Delta_w = compute_hebbian_two_pointer(post_arr, pre_arr, window)
                    elif rule=="anti_hebbian":
                        Delta_w = - compute_hebbian_two_pointer(post_arr, pre_arr, window)
                    w_matrix[i,j] = np.clip(w_matrix[i,j] + params["eta"] * Delta_w, 0, 1.0)
        return w_above, w_below, w_left, w_right

    def _update_weights_parallel(self, spike_times, rule="stdp"):
        params = self.learning_params[rule]
        window = params.get("window", max(params.get("tau_plus",0), params.get("tau_minus",0))*5)
        post_limit = 100
        N = self.N
        def update_single(i, j):
            post_spikes = spike_times[i][j]
            Delta = {"above": 0.0, "below": 0.0, "left": 0.0, "right": 0.0}
            if len(post_spikes)==0:
                return {(i, j): Delta}
            post_arr = np.array(post_spikes[:min(post_limit, len(post_spikes))])
            if i>0:
                pre_arr = np.array(spike_times[i-1][j])
                if pre_arr.size>0:
                    if rule=="stdp":
                        Delta["above"] = compute_stdp_two_pointer(post_arr, pre_arr, window,
                                                                  params["A_plus"], params["tau_plus"],
                                                                  params["A_minus"], params["tau_minus"])
                    elif rule=="hebbian":
                        Delta["above"] = compute_hebbian_two_pointer(post_arr, pre_arr, window)
                    elif rule=="anti_hebbian":
                        Delta["above"] = - compute_hebbian_two_pointer(post_arr, pre_arr, window)
            if i < N-1:
                pre_arr = np.array(spike_times[i+1][j])
                if pre_arr.size>0:
                    if rule=="stdp":
                        Delta["below"] = compute_stdp_two_pointer(post_arr, pre_arr, window,
                                                                  params["A_plus"], params["tau_plus"],
                                                                  params["A_minus"], params["tau_minus"])
                    elif rule=="hebbian":
                        Delta["below"] = compute_hebbian_two_pointer(post_arr, pre_arr, window)
                    elif rule=="anti_hebbian":
                        Delta["below"] = - compute_hebbian_two_pointer(post_arr, pre_arr, window)
            if j>0:
                pre_arr = np.array(spike_times[i][j-1])
                if pre_arr.size>0:
                    if rule=="stdp":
                        Delta["left"] = compute_stdp_two_pointer(post_arr, pre_arr, window,
                                                                 params["A_plus"], params["tau_plus"],
                                                                 params["A_minus"], params["tau_minus"])
                    elif rule=="hebbian":
                        Delta["left"] = compute_hebbian_two_pointer(post_arr, pre_arr, window)
                    elif rule=="anti_hebbian":
                        Delta["left"] = - compute_hebbian_two_pointer(post_arr, pre_arr, window)
            if j < N-1:
                pre_arr = np.array(spike_times[i][j+1])
                if pre_arr.size>0:
                    if rule=="stdp":
                        Delta["right"] = compute_stdp_two_pointer(post_arr, pre_arr, window,
                                                                  params["A_plus"], params["tau_plus"],
                                                                  params["A_minus"], params["tau_minus"])
                    elif rule=="hebbian":
                        Delta["right"] = compute_hebbian_two_pointer(post_arr, pre_arr, window)
                    elif rule=="anti_hebbian":
                        Delta["right"] = - compute_hebbian_two_pointer(post_arr, pre_arr, window)
            return {(i,j): Delta}
        tasks = [(i,j) for i in range(N) for j in range(N)]
        parallel_results = Parallel(n_jobs=-1)(delayed(update_single)(i,j) for (i,j) in tasks)
        updates = {}
        for res in parallel_results:
            updates.update(res)
        w_above, w_below, w_left, w_right = [w.copy() for w in self.weights]
        eta = params["eta"]
        for i in range(N):
            for j in range(N):
                cell = updates.get((i,j), {"above":0.0, "below":0.0, "left":0.0, "right":0.0})
                if i>0:
                    w_above[i,j] = np.clip(w_above[i,j] + eta * cell["above"], 0, 1.0)
                if i < N-1:
                    w_below[i,j] = np.clip(w_below[i,j] + eta * cell["below"], 0, 1.0)
                if j>0:
                    w_left[i,j] = np.clip(w_left[i,j] + eta * cell["left"], 0, 1.0)
                if j < N-1:
                    w_right[i,j] = np.clip(w_right[i,j] + eta * cell["right"], 0, 1.0)
        return w_above, w_below, w_left, w_right

    # --------------- Chaos Analysis (ISI Metrics) ---------------
    def analyze_chaos(self, spike_times):
        if self.use_parallel:
            return self._analyze_chaos_parallel(spike_times)
        else:
            return self._analyze_chaos_sequential(spike_times)

    def _analyze_chaos_sequential(self, spike_times):
        isi_means = np.full((self.N, self.N), np.nan)
        isi_stds = np.full((self.N, self.N), np.nan)
        isi_cvs = np.full((self.N, self.N), np.nan)
        lyapunov_approx = np.full((self.N, self.N), np.nan)
        for i in range(self.N):
            for j in range(self.N):
                spikes = spike_times[i][j]
                if len(spikes)>1:
                    sp_arr = np.array(spikes)
                    isi = np.diff(sp_arr)
                    mean_isi = np.mean(isi)
                    std_isi = np.std(isi)
                    cv = std_isi/mean_isi if mean_isi>0 else 0
                    isi_means[i,j] = mean_isi*1e12
                    isi_stds[i,j] = std_isi*1e12
                    isi_cvs[i,j] = cv
                    if len(isi)>2:
                        delta_isi = np.abs(np.diff(isi))
                        norm_delta = np.clip(delta_isi/(mean_isi+1e-15), 0, 10)
                        lyap = np.mean(np.log1p(norm_delta))/(mean_isi*1e9+1e-15)
                        lyapunov_approx[i,j] = lyap
                if self.verbose:
                    print(f"Neuron ({i},{j}): CV = {isi_cvs[i,j]:.3f}, Lyapunov = {lyapunov_approx[i,j]:.3f}")
        mean_cv = np.nanmean(isi_cvs) if np.any(~np.isnan(isi_cvs)) else 0
        mean_lyap = np.nanmean(lyapunov_approx) if np.any(~np.isnan(lyapunov_approx)) else 0
        if self.verbose:
            print(f"Average ISI CV: {mean_cv:.3f}, Average Lyapunov: {mean_lyap:.3f}")
        return isi_means, isi_stds, isi_cvs, mean_cv, mean_lyap

    def _analyze_chaos_parallel(self, spike_times):
        """
        Compute the same ISI statistics as the sequential path, one ``Parallel`` task per neuron.

        Parameters
        ----------
        spike_times : nested sequence
            ``spike_times[i][j]`` holds the spike times of neuron (i, j). Assumed to be
            seconds, ascending. TODO confirm.

        Returns
        -------
        tuple
            ``(isi_means, isi_stds, isi_cvs, mean_cv, mean_lyap)``:

            * ``isi_means`` and ``isi_stds`` are (N, N) arrays scaled by 1e12.
            * ``isi_cvs`` is std/mean (0.0 when the mean ISI is not positive).
            * Cells with fewer than two spikes are NaN.
            * ``mean_cv`` and ``mean_lyap`` are NaN-ignoring grid averages, or 0 if all NaN.
        """
        n = self.N

        def cell_stats(row, col):
            # Returns (row, col, mean*1e12, std*1e12, cv, lyapunov proxy) for one neuron.
            train = spike_times[row][col]
            if len(train) <= 1:
                return (row, col, np.nan, np.nan, np.nan, np.nan)
            intervals = np.diff(np.array(train))
            avg = np.mean(intervals)
            spread = np.std(intervals)
            ratio = spread / avg if avg > 0 else 0.0
            lyap = np.nan
            if len(intervals) > 2:
                # Capped relative ISI changes; log-mean divided by the mean ISI in ns.
                jumps = np.clip(np.abs(np.diff(intervals)) / (avg + 1e-15), 0, 10)
                lyap = np.mean(np.log1p(jumps)) / (avg * 1e9 + 1e-15)
            return (row, col, avg * 1e12, spread * 1e12, ratio, lyap)

        jobs = (delayed(cell_stats)(row, col) for row in range(n) for col in range(n))
        results = Parallel(n_jobs=-1)(jobs)

        isi_means, isi_stds, isi_cvs, lyapunov_approx = (
            np.full((n, n), np.nan) for _ in range(4)
        )
        for row, col, avg_scaled, spread_scaled, ratio, lyap in results:
            isi_means[row, col] = avg_scaled
            isi_stds[row, col] = spread_scaled
            isi_cvs[row, col] = ratio
            lyapunov_approx[row, col] = lyap

        mean_cv = 0 if np.all(np.isnan(isi_cvs)) else np.nanmean(isi_cvs)
        mean_lyap = 0 if np.all(np.isnan(lyapunov_approx)) else np.nanmean(lyapunov_approx)
        if self.verbose:
            for row in range(n):
                for col in range(n):
                    print(f"Neuron ({row},{col}): CV: {isi_cvs[row,col]:.3f}, Lyapunov: {lyapunov_approx[row,col]:.3f}")
            print(f"Average ISI CV: {mean_cv:.3f}, Average Lyapunov: {mean_lyap:.3f}")
        return isi_means, isi_stds, isi_cvs, mean_cv, mean_lyap

    # --------------- Spike Frequency Encoding ---------------
    def compute_frequency_encoding(self, spike_times):
        frequencies = np.zeros((self.N, self.N))
        for i in range(self.N):
            for j in range(self.N):
                spikes = spike_times[i][j]
                if len(spikes)>1:
                    duration = spikes[-1]-spikes[0]
                    frequencies[i,j] = (len(spikes)-1)/duration if duration>0 else 0
        return frequencies

    # ------------------- Plotting -------------------
    def plot_results(self, spike_times, weights_stdp, frequencies, isi_means, isi_stds, isi_cvs,
                     dpi=1200):
        """
        Render and save the two summary figures for a finished simulation.

        Figure 1 is saved as 'jj_array_results.png' and has five stacked panels:

        * the voltage trace of neuron (0, 0);
        * a spike raster of every neuron;
        * the 'from above' STDP weight map;
        * the spike-frequency map;
        * a histogram of the per-neuron ISI coefficients of variation.

        Figure 2 is saved as 'isi_heatmaps.png' and holds annotated heatmaps of mean ISI,
        ISI std dev and ISI CV.

        Parameters
        ----------
        spike_times : nested sequence
            ``spike_times[i][j]`` holds the spike times of neuron (i, j), in the same units as
            ``self.time``. That is seconds, since ``self.time`` is shown in ns via a 1e9 factor.
        weights_stdp : sequence of arrays
            Weight matrices. Only ``weights_stdp[0]`` (the 'from above' weights) is drawn.
        frequencies : (N, N) array
            Spike frequencies in Hz, displayed in GHz.
        isi_means, isi_stds : (N, N) arrays
            ISI mean and std dev in ps.
        isi_cvs : (N, N) array
            ISI coefficient of variation. NaN cells are left out of the histogram.
        dpi : int, optional
            Resolution of both figures. The default of 1200 matches the previously
            hard-coded value. At 1200 dpi, figure 1 renders at 12000 x 24000 pixels,
            so lower it for quick looks.
        """
        i, j = 0, 0  # neuron whose voltage trace is shown

        # ---------------- Figure 1: overview ----------------
        fig, axes = plt.subplots(5, 1, figsize=(10, 20), dpi=dpi, facecolor='white')

        axes[0].plot(self.time * 1e9, self.V[i, j, :] * 1e6,
                     color='crimson', label=f'Voltage ({i},{j})')
        axes[0].set_xlabel('Time (ns)')
        axes[0].set_ylabel('Voltage (μV)')
        axes[0].set_title(f'Voltage Trace (Neuron {i},{j})')
        axes[0].grid(True)
        axes[0].legend()

        # Spike raster in ns, so it matches the 'Time (ns)' label and the trace above.
        # Converting directly also avoids the old self.time[searchsorted(...)] lookup. That
        # lookup left values in seconds and raised IndexError for spikes past self.time[-1].
        raster_ns = [np.asarray(spike_times[r][c], dtype=float) * 1e9
                     for r in range(self.N) for c in range(self.N)]
        axes[1].eventplot(raster_ns, colors='black')
        axes[1].set_xlabel('Time (ns)')
        axes[1].set_ylabel('Neuron Index')
        axes[1].set_title('Spike Raster Plot')
        axes[1].grid(True)

        im = axes[2].imshow(weights_stdp[0], cmap='hot', interpolation='nearest')
        axes[2].set_title('STDP Weights (From Above)')
        axes[2].set_xlabel('j')
        axes[2].set_ylabel('i')
        fig.colorbar(im, ax=axes[2], label='Weight')

        im_freq = axes[3].imshow(frequencies * 1e-9, cmap='viridis', interpolation='nearest')
        axes[3].set_title('Spike Frequencies (GHz)')
        axes[3].set_xlabel('j')
        axes[3].set_ylabel('i')
        fig.colorbar(im_freq, ax=axes[3], label='Frequency (GHz)')

        cvs = np.asarray(isi_cvs, dtype=float)
        isi_variability = cvs[~np.isnan(cvs)]  # neurons with < 2 spikes have no CV
        axes[4].hist(isi_variability, bins=20, color='purple', alpha=0.7)
        axes[4].set_xlabel('ISI Coefficient of Variation')
        axes[4].set_ylabel('Count')
        axes[4].set_title('ISI Variability (Chaos Indicator)')
        axes[4].grid(True)

        fig.tight_layout()
        fig.savefig('jj_array_results.png')
        plt.close(fig)

        # ---------------- Figure 2: ISI heatmaps ----------------
        fig, axes = plt.subplots(1, 3, figsize=(18, 6), dpi=dpi)
        heatmap_specs = (
            (isi_means, "Blues", "Mean ISI (ps)"),
            (isi_stds, "Oranges", "ISI Std Dev (ps)"),
            (isi_cvs, "coolwarm", "ISI CV"),
        )
        for ax, (data, cmap, title) in zip(axes, heatmap_specs):
            sns.heatmap(data, annot=True, fmt=".2f", cmap=cmap, ax=ax)
            ax.set_title(title)
            ax.set_xlabel("Neuron Column")
            ax.set_ylabel("Neuron Row")
        fig.tight_layout()
        fig.savefig('isi_heatmaps.png')
        plt.close(fig)

# ================= Main Script =================
# Array, junction and run settings passed to JJArraySimulator.
# Built from dict(...) keywords; the keys, values and key order are the same as a literal.
config = dict(
    N=4,                         # the array is an N x N grid of junctions
    critical_current=5e-6,       # A; 5x the 1e-6 A used in the single-junction setup (original note: "Larger voltage.")
    capacitance=1e-11,           # F
    shunt_resistance=30,         # ohm; lower than the single-junction 50 ohm (original note: "Increase slight damping.")
    initial_weight=0.7,          # strong initial coupling
    simulation_time=30e-9,       # s
    dt=1e-13,                    # s, fixed integration step
    noise_amp=1e-7,              # A; reduced noise amplitude for lower variability
    tau_adapt=5e-9,              # s; slower adaptation, intended to reduce chaos
    k_adapt=0.05,
    # A/s; the bias reaches 5e-6 A at t = simulation_time.
    # NOTE(review): that is only 1x critical_current, whereas the single-junction setup
    # ramps to 5x I_c. Confirm this is intended.
    ramp_rate=5e-6 / 30e-9,
    learning_params=dict(
        stdp=dict(A_plus=0.01, A_minus=0.01, tau_plus=5e-10, tau_minus=5e-10, eta=0.01),
        hebbian=dict(window=5e-10, eta=0.01),
        anti_hebbian=dict(window=5e-10, eta=0.01),
    ),
    integration_method="euler",  # fixed-step Euler, chosen for speed
)

# Quiet run with the Numba and parallel options enabled.
simulator = JJArraySimulator(config, verbose=False, use_numba=True, use_parallel=True)
simulator.run_simulation()

# Post-processing pipeline:
#   spikes -> STDP weights -> ISI / chaos metrics -> firing rates -> saved figures.
spike_times = simulator.detect_spikes()
weights_stdp = simulator.update_weights(spike_times, rule="stdp")
isi_means, isi_stds, isi_cvs, chaos_cv, chaos_lyap = simulator.analyze_chaos(spike_times)
frequencies = simulator.compute_frequency_encoding(spike_times)
simulator.plot_results(spike_times, weights_stdp, frequencies, isi_means, isi_stds, isi_cvs)

print(f"Simulation completed. Chaos Metrics - CV: {chaos_cv:.3f}, Lyapunov: {chaos_lyap:.3f}")


In [None]:
# Shared font sizes for every label, tick, title and cell annotation in this figure.
label_fontsize, tick_fontsize, title_fontsize, annot_fontsize = 24, 20, 24, 20

# One row of three panels: mean ISI, ISI std dev and ISI CV.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(20, 6), dpi=1200)

def set_all_fonts(ax):
    """Apply the shared font sizes to an axes' title, axis labels and tick labels.

    The title and labels keep their current text and are re-set at the module-level
    ``title_fontsize`` / ``label_fontsize``. Both tick axes use ``tick_fontsize``.
    """
    for read_text, write_text, size in (
        (ax.get_title, ax.set_title, title_fontsize),
        (ax.get_xlabel, ax.set_xlabel, label_fontsize),
        (ax.get_ylabel, ax.set_ylabel, label_fontsize),
    ):
        write_text(read_text(), fontsize=size)
    ax.tick_params(axis="both", labelsize=tick_fontsize)

# Panel 1: mean inter-spike interval (ps).
heatmap1 = sns.heatmap(
    isi_means, ax=axes[0], cmap="Blues",
    annot=True, fmt=".2f", annot_kws={"fontsize": annot_fontsize},
    cbar_kws={"aspect": 20},
)
axes[0].set(title="Mean ISI (ps)", xlabel="Neuron Column", ylabel="Neuron Row")
set_all_fonts(axes[0])
heatmap1_colorbar = heatmap1.collections[0].colorbar  # colorbar attached to the heatmap mesh
heatmap1_colorbar.ax.tick_params(labelsize=tick_fontsize)

# Panel 2: ISI standard deviation (ps).
heatmap2 = sns.heatmap(
    isi_stds, ax=axes[1], cmap="Oranges",
    annot=True, fmt=".2f", annot_kws={"fontsize": annot_fontsize},
    cbar_kws={"aspect": 20},
)
axes[1].set(title="ISI Std Dev (ps)", xlabel="Neuron Column", ylabel="Neuron Row")
set_all_fonts(axes[1])
heatmap2_colorbar = heatmap2.collections[0].colorbar
heatmap2_colorbar.ax.tick_params(labelsize=tick_fontsize)

# Panel 3: ISI coefficient of variation.
heatmap3 = sns.heatmap(
    isi_cvs, ax=axes[2], cmap="coolwarm",
    annot=True, fmt=".2f", annot_kws={"fontsize": annot_fontsize},
    cbar_kws={"aspect": 20},
)
axes[2].set(title="ISI CV", xlabel="Neuron Column", ylabel="Neuron Row")
set_all_fonts(axes[2])
heatmap3_colorbar = heatmap3.collections[0].colorbar
heatmap3_colorbar.ax.tick_params(labelsize=tick_fontsize)

# Save with a transparent background, then release the (large, 1200 dpi) figure.
fig.tight_layout()
fig.savefig('isi_heatmaps_large_fonts.png', transparent=True)
plt.close(fig)
