# In [None]:
import jittor as jt
import jittor.nn as nn
jt.flags.use_cuda = 1
from jittornode import odeint

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

# ===== SciPy section =====
def ode_rhs(t, y_flat):
    """Evaluate dy/dt = (y ** 3) @ A for ``scipy.integrate.solve_ivp``.

    Args:
        t: Current time; unused because the right-hand side does not depend on it.
        y_flat: Flat state vector with 2 entries, as passed in by solve_ivp.

    Returns:
        A flat array with 2 entries holding the time derivative of the state.
    """
    coeffs = np.array([[-0.1, 2.0], [-2.0, -0.1]])
    # View the state as a single (1, 2) row so cube-then-matmul mirrors the
    # batched (1, 2) formulation used on the Jittor side.
    row = y_flat.reshape(1, 2)
    derivative = np.matmul(np.power(row, 3), coeffs)
    return derivative.ravel()

# Integration window, initial state and the shared output time grid.
t0, t1 = 0.0, 1.0
y0 = np.array([2.0, 0.0])
t_eval = np.linspace(t0, t1, 1000)

# Tight tolerances so the SciPy run can serve as the reference solution.
sol = solve_ivp(ode_rhs, t_span=(t0, t1), y0=y0, t_eval=t_eval, method="RK45", rtol=1e-8, atol=1e-10)
# solve_ivp reports failure via `success`/`message` instead of raising, and
# on failure `sol.y` only covers the steps reached so far. Fail loudly here
# rather than producing misleading error metrics further down.
if not sol.success:
    raise RuntimeError(f"SciPy reference solve failed: {sol.message}")

# ===== Jittor section =====
# Same initial state as the SciPy run, written as a single (1, 2) row.
true_y0 = jt.array([[2.0, 0.0]])
# Take the number of samples from the SciPy grid so both solutions are
# evaluated at the same time points and can be compared element-wise.
t_jt = jt.linspace(t0, t1, len(t_eval))
# Same coefficient matrix that ode_rhs uses on the SciPy side.
true_A = jt.array([[-0.1, 2.0], [-2.0, -0.1]])

# Define the ODE vector field for the Jittor solver.
class ODEFunc(nn.Module):
    """Right-hand side dy/dt = (y ** 3) @ true_A, evaluated with Jittor ops."""

    def execute(self, t, y):
        """Return the time derivative for state ``y``.

        ``t`` is accepted for the solver's call signature but is not used.
        ``y`` is presumably shaped like ``true_y0`` (1, 2) — TODO confirm
        against the odeint implementation.
        """
        cubed = y**3
        return jt.matmul(cubed, true_A)


# Integrate with Jittor's fixed-step RK4 on t_jt, then drop axis 1 so the
# result can be compared with SciPy's output. Assumes odeint returns
# (n_times, 1, 2) for the (1, 2) initial state — TODO confirm.
jt_solution = odeint(
    ODEFunc(),
    true_y0,
    t_jt,
    method="rk4",
)
jt_solution_np = jt_solution.squeeze(1).numpy()

# ===== Error analysis =====
# solve_ivp stores states as (n_states, n_times); transpose to (n_times, n_states).
scipy_solution_np = sol.y.T
# Guard against silent NumPy broadcasting when the two runs disagree in length.
if jt_solution_np.shape != scipy_solution_np.shape:
    raise ValueError(
        f"Solution shapes differ: Jittor {jt_solution_np.shape} "
        f"vs SciPy {scipy_solution_np.shape}"
    )

diff = jt_solution_np - scipy_solution_np
abs_error = np.abs(diff)
# Euclidean norm of the error vector at each time point.
l2_error = np.linalg.norm(diff, axis=1)

print("最大绝对误差:", abs_error.max())
print("平均L2误差:", np.mean(l2_error))

# ===== Visual comparison =====
# One panel per state component, SciPy (dashed) overlaid with Jittor.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for i, ax in enumerate(axes):
    ax.plot(t_eval, sol.y[i], label=f"SciPy y[{i}]", linestyle="--")
    ax.plot(t_eval, jt_solution_np[:, i], label=f"Jittor y[{i}]", alpha=0.8)
    ax.legend()
    ax.set_title(f"y[{i}] Comparison")
    # Apply the grid per axis; a single plt.grid() call only affects the
    # current (last) subplot.
    ax.grid(True)

fig.tight_layout()
plt.show()