In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
import torch
import os
from tqdm.notebook import tqdm
import itertools

# Lorenz system

In [None]:
def lorenz(xyz, sigma=10.0, rho=28.0, beta=8.0 / 3.0):
    """Evaluate the Lorenz vector field at a point.

    Parameters
    ----------
    xyz : array_like
        Anything whose first axis unpacks into three coordinates (x, y, z).
        If each coordinate is itself an array, the field is evaluated
        elementwise over it.
    sigma, rho, beta : float
        Lorenz system parameters.

    Returns
    -------
    np.ndarray
        Stacked time derivatives [dx/dt, dy/dt, dz/dt].
    """
    px, py, pz = xyz
    return np.array([
        sigma * (py - px),
        px * (rho - pz) - py,
        px * py - beta * pz,
    ])

def multi_node_lorenz(t, state, N, K, C):
    """Right-hand side for N diffusively coupled Lorenz oscillators.

    The signature follows ``scipy.integrate.solve_ivp``'s ``fun(t, y, *args)``.
    ``t`` is unused because the dynamics do not depend on time explicitly.

    Parameters
    ----------
    t : float
        Current time (ignored).
    state : array_like of length 3 * N
        Flattened (N, 3) array of node states, one (x, y, z) row per node.
    N : int
        Number of nodes.
    K : float
        Coupling strength.
    C : array_like of shape (N, N)
        Coupling weights; C[i, j] scales how strongly node j pulls node i.

    Returns
    -------
    np.ndarray of length 3 * N
        d(state)/dt, flattened in the same (N, 3) order as ``state``.
    """
    X = np.asarray(state).reshape(N, 3)
    C = np.asarray(C)

    # lorenz() unpacks its first axis and is elementwise, so feeding the
    # (3, N) transpose evaluates every node's intrinsic dynamics in one call.
    intrinsic = lorenz(X.T).T

    # sum_j C[i, j] * (X[j] - X[i]) == (C @ X)[i] - (sum_j C[i, j]) * X[i].
    # Same quantity as the per-pair loop, up to floating-point rounding.
    coupling = C @ X - C.sum(axis=1, keepdims=True) * X

    return (intrinsic + K * coupling).ravel()

def lorenz_system(N = 10):

    K = 0.1    # coupling strength
    C = np.ones((N, N)) - np.eye(N)

    np.random.seed(0)
    initial_state = np.random.rand(N, 3)

    t_span = [0, 30]
    t_eval = np.linspace(t_span[0], t_span[1], 5000)

    sol = solve_ivp(multi_node_lorenz, t_span, initial_state.flatten(), t_eval=t_eval, args=(N, K, C))

    X_all = sol.y.reshape(N, 3, -1)  # shape (N, 3, T)

    return X_all # (N, 3, T)



# save data

In [None]:
N = 20
trial_num = 10
data_save_dir = "./Lorenz_data"

os.makedirs(data_save_dir, exist_ok=True)

for N in N_list:
    print("N:", N)

    for i in tqdm(range(trial_num), desc="dataset", position=0):

        X_all = lorenz_system(N) # [N, 3, T]

        states_time = torch.from_numpy(X_all).float()
        states_time = states_time.permute(2, 0, 1) # [N, 3, T] -> [T, N, 3]

        data_path = os.path.join(data_save_dir, "Lorenz_data_n%s_trial0%s.pt" % (N, i+1))

        dataset = {
            "states_time": states_time,
        }

        torch.save(dataset, data_path)
