In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
import torch
import os
from tqdm.notebook import tqdm
import itertools

# Hindmarsh-Rose (HR) coupled-neuron system

In [None]:
def hindmarsh_rose(xyz, a=1.0, b=3.0, c=1.0, d=5.0, s=4.0, r=0.006, x_R=-1.6, I=3.25):
    """Vector field of a single Hindmarsh-Rose neuron.

        dx/dt = y - a*x**3 + b*x**2 - z + I
        dy/dt = c - d*x**2 - y
        dz/dt = r * (s*(x - x_R) - z)

    Parameters
    ----------
    xyz : sequence or ndarray whose first dimension has length 3
        The state (x, y, z). It is unpacked along its first axis, so an array
        of shape (3, ...) is evaluated elementwise over all trailing entries.
    a, b, c, d, s, r, x_R, I : float
        Model parameters; ``I`` is the constant input term added to dx/dt.

    Returns
    -------
    ndarray
        (dx/dt, dy/dt, dz/dt) stacked along a new leading axis of length 3.
    """
    x, y, z = xyz
    # x**2 appears in both the x and y equations; compute it once.
    x_sq = x**2
    x_dot = y - a * x**3 + b * x_sq - z + I
    y_dot = c - d * x_sq - y
    z_dot = r * (s * (x - x_R) - z)
    return np.array((x_dot, y_dot, z_dot))

def multi_node_hr(t, state, N, K, C):
    """Right-hand side of N Hindmarsh-Rose nodes with linear diffusive coupling.

    The signature follows the ``fun(t, y, *args)`` convention of
    ``scipy.integrate.solve_ivp``; ``t`` is not used by the dynamics.

    Parameters
    ----------
    t : float
        Current time (unused).
    state : array_like, shape (3 * N,)
        Flattened node states laid out node-major:
        ``[x_0, y_0, z_0, x_1, y_1, z_1, ...]``.
    N : int
        Number of nodes.
    K : float
        Global coupling strength.
    C : array_like, shape (N, N)
        Coupling matrix; ``C[i, j]`` weights the pull of node j on node i.

    Returns
    -------
    ndarray, shape (3 * N,)
        Time derivative of ``state`` in the same flattened layout.

    Notes
    -----
    Node i receives ``K * sum_j C[i, j] * (s_j - s_i)``, applied to all three
    state variables. This is computed as ``C @ s - rowsum(C) * s`` rather than
    with a Python double loop, since the solver calls this function many
    times per trajectory.
    """
    nodes = np.asarray(state).reshape(N, 3)  # [N, 3]
    C = np.asarray(C)

    # hindmarsh_rose unpacks its argument along the first axis and uses only
    # elementwise arithmetic, so the (3, N) transpose evaluates every node in
    # a single call; transpose back to [N, 3].
    local = hindmarsh_rose(nodes.T).T

    # sum_j C[i, j] * (s_j - s_i) == (C @ s)_i - (sum_j C[i, j]) * s_i
    coupling = C @ nodes - C.sum(axis=1, keepdims=True) * nodes  # [N, 3]

    # flatten() defaults to C order, restoring the node-major layout.
    return (local + K * coupling).flatten()

def hr_system(N = 10):
    K = 0.1  
    C = np.ones((N, N)) - np.eye(N)

    np.random.seed(0)
    initial_state = np.random.rand(N, 3) * 0.5

    t_span = [0, 500]
    t_eval = np.linspace(t_span[0], t_span[1], 5000)

    sol = solve_ivp(multi_node_hr, t_span, initial_state.flatten(), t_eval=t_eval, args=(N, K, C))

    X_all = sol.y.reshape(N, 3, -1)  # shape (N, 3, T)

    return X_all # (N, 3, T)


# save data

In [None]:
N = 20
trial_num = 10
data_save_dir = "./HR_data"

os.makedirs(data_save_dir, exist_ok=True)

for i in tqdm(range(trial_num), desc="dataset", position=0):

    X_all = hr_system(N) # [N, 3, T]

    states_time = torch.from_numpy(X_all).float()
    states_time = states_time.permute(2, 0, 1) # [N, 3, T] -> [T, N, 3]

    data_path = os.path.join(data_save_dir, "HR_data_n%s_trial0%s.pt" % (N, i+1))

    dataset = {
        "states_time": states_time,
    }

    torch.save(dataset, data_path)
