### One node

In [1]:
import sys

sys.path.insert(1, "/home/INT/lima.v/projects/phase_amplitude_encoding")

In [2]:
import os
import jax
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from frites.core import copnorm_nd
from hoi.core import get_mi
from tqdm import tqdm

from src.models import simulate

  from .autonotebook import tqdm as notebook_tqdm


#### Simulation parameters

In [3]:
ntrials = 50
dt = 1e-4
fsamp = 1 / dt
time = np.arange(-3, 3, 1 / fsamp)
beta = 1e-4
Npoints = len(time)
decim = 20
f = 40

In [4]:
simulate?

[31mSignature:[39m
simulate(
    A: numpy.ndarray,
    g: float,
    f: float,
    a: float,
    fs: float,
    eta: float,
    T: float,
    Iext: numpy.ndarray = [38;5;28;01mNone[39;00m,
    seed: int = [32m0[39m,
    device: str = [33m'cpu'[39m,
    decim: int = [32m1[39m,
    stim_mode: str = [33m'amp'[39m,
)
[31mDocstring:[39m
Simulates a network of coupled oscillators with external stimulation.

Parameters:
----------
A : np.ndarray
    Adjacency matrix representing network connectivity.
g : float
    Coupling strength parameter.
f : float
    Natural frequency of oscillators.
a : float
    Nonlinear parameter influencing oscillator dynamics.
fs : float
    Sampling frequency.
eta : float
    Noise intensity.
T : float
    Total simulation time in discrete steps.
Iext : np.ndarray, optional
    External input to the oscillators (default is None, meaning no input).
seed : int, optional
    Random seed for noise generation (default is 0).
device : str, optional
    Co

In [5]:
                                           #  A    g     f    a   fs  eta  T    I seed dev   dec   stim
simulate_vmap = jax.vmap(simulate, in_axes=(None, None, None, 0, None, 0, None, 0, 0, None, None, None))

In [6]:
a = np.linspace(-5, 5, 50)
beta = np.linspace(1e-5, 3, 50)

pars = np.array(np.meshgrid(a, beta)).T.reshape(-1, 2)


ntrials = pars.shape[0]

seeds = np.random.randint(0, 50000, ntrials)

In [7]:
t_input = np.arange(0, 0.3, dt)  # Extend to 500 ms

def alpha_function(t, t_peak=0.1, tau=0.05):  # Increase tau
    """Alpha function envelope with t_peak as an additional parameter."""
    envelope = (t / t_peak) * np.exp(-t / tau)
    envelope[t < 0] = 0  # Ensure no negative time values
    return envelope / np.max(envelope)  # Normalize

envelope = alpha_function(t_input - 0.03, t_peak=5., tau=0.03)  # Increase tau

freq_start = 80  # Hz
freq_end = 1  # Hz
freq_decay = freq_start * np.exp(-30 * t_input) + freq_end  # Slower decay

phase = np.cumsum(2 * np.pi * freq_decay * dt)  # Integrate frequency to get phase
input = np.sin(phase) * envelope  # Modulate with envelope

In [8]:
Iext = np.zeros((1, Npoints))
Iext[0, (time >= 0) & (time <= 0.3)] = input
Amplitudes = np.linspace(0, .1, ntrials)
CS = Amplitudes[..., None, None] * Iext
seeds = np.random.randint(0, 10000, ntrials)

In [9]:
ntrials

2500

In [10]:
CS.shape

(2500, 1, 60000)

In [None]:
R1, R2 = [], []


for i in tqdm(range(100)):
    seeds = np.random.randint(0, 10000, ntrials)
    
    #  A         g  f    a                    fs      eta         T      I  seed   dev   dec   stim
    x = simulate_vmap(np.zeros(1), 0, f, pars[:, 0][:, None], fsamp, pars[:, 1], Npoints, CS, seeds, "cpu", decim, "both")

    R1 += [ (x * np.conj(x)).real ]
    R2 += [ np.unwrap(np.angle(x)) ]

  3%|█▏                                       | 3/100 [03:38<1:55:45, 71.61s/it]

In [23]:
from frites.core import copnorm_nd, gccmi_nd_ccc, gcmi_1d_cc, gcmi_nd_cc

for i in tqdm(range(ntrials)):

    x = out[i]
    
    R1 = (x * np.conj(x)).real
    R2 = np.unwrap(np.angle(x))
    
    # Stims across trials
    stim = np.expand_dims(Amplitudes, axis=(0, 1))
    #stim = np.tile(stim, (len(x))).squeeze().T
    
    
    I_S_1 = gcmi_nd_cc(stim, R1, traxis=0)
    I_S_2 = gcmi_nd_cc(stim, R2, traxis=0)
    R_12 = np.minimum(I_S_1, I_S_2)
    
    plt.plot(x.times.values, I_S_1 - R_12)
    plt.plot(x.times.values, I_S_2 - R_12)

  0%|                                                  | 0/2500 [00:00<?, ?it/s]


AssertionError: 

In [28]:
Amplitudes.shape

(2500,)

In [25]:
gcmi_nd_cc(stim, R1)

AssertionError: 