In [1]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']= 'TRUE'
import numpy as np
from scipy.integrate import solve_ivp
import math
from scipy import interpolate

# Background tables loaded from disk: Nv are the sample points, Sv the values
# that solve() later reads back as `sigma` through the spline built below.
Nv = np.load('Nv.npy')
print(Nv)
Sv = np.load('Sv.npy')
print(Sv)

# Cubic (default k=3) interpolating spline through (Nv, Sv), wrapped in a
# 1-tuple so it can be handed straight to solve_ivp(..., args=param).
param = (
    interpolate.InterpolatedUnivariateSpline(Nv, Sv),
)

def solve(N_span, ini_state, spline=None):
    """Integrate the (z, x, y) system over ``N_span`` and return its dense solution.

    Right-hand side, with sigma(N) read from ``spline``::

        dz/dN = 1.5*sigma - 1
        dx/dN = -3*x + 1.5*sigma*x - exp(z)*y
        dy/dN = -y   + 1.5*sigma*y + exp(z)*x

    Integration is done with LSODA using the analytic Jacobian, and stops early
    (terminal event) when ``z - 5`` goes from negative to positive.

    Parameters
    ----------
    N_span : sequence of two floats
        ``(N_start, N_end)`` interval passed to ``solve_ivp``.
    ini_state : sequence of three floats
        Initial ``(z, x, y)``.
    spline : callable, optional
        sigma(N); its return value must be convertible with ``float``.
        Defaults to the module-level background spline ``param[0]``.

    Returns
    -------
    scipy.integrate.OdeSolution
        Dense interpolant; ``sol(N)`` gives ``[z, x, y]``. ``sol.t_min`` /
        ``sol.t_max`` bound the covered interval, which is shorter than
        ``N_span`` when the terminal event fired.

    Raises
    ------
    RuntimeError
        If the integrator reports failure (``result.success`` is False), so a
        partial solution is never returned silently.
    """
    if spline is None:
        # Backward-compatible default: the global background spline.
        spline = param[0]

    def tensor(N, state, spline):
        # RHS of the ODE system.
        z, x, y = state
        sigma = float(spline(N))
        ez = math.exp(z)
        dz = 1.5 * sigma - 1
        dx = -3 * x + 1.5 * sigma * x - ez * y
        dy = -y + 1.5 * sigma * y + ez * x
        return [dz, dx, dy]

    def subhorizon(N, state, spline):
        # Event function: zero when z == 5.
        return state[0] - 5
    subhorizon.terminal = True   # stop integrating at the first crossing
    subhorizon.direction = 1     # only upward crossings of z = 5

    def jacobian(N, state, spline):
        # Analytic d(RHS)/d(z, x, y); rows follow (dz, dx, dy).
        z, x, y = state
        sigma = float(spline(N))
        ez = math.exp(z)
        return np.array(
            (
                (0, 0, 0),
                (-ez * y, -3 + 1.5 * sigma, -ez),
                (ez * x, ez, -1 + 1.5 * sigma),
            ),
            dtype=np.float64,
        )

    result = solve_ivp(
        tensor, N_span, ini_state,
        method='LSODA',
        dense_output=True,
        events=[subhorizon],
        args=(spline,),
        # NOTE(review): atol of 1e-20 on x, y presumably reflects very small
        # amplitudes for those components — confirm against the data.
        rtol=1e-6, atol=[1e-10, 1e-20, 1e-20],
        jac=jacobian,
    )
    if not result.success:
        raise RuntimeError(
            f"solve_ivp failed for N_span={N_span!r}, ini_state={ini_state!r}: "
            f"{result.message}"
        )
    return result.sol

import json
import torch
from torch.utils.data import Dataset
import numpy as np
from sklearn.preprocessing import StandardScaler

class JSONDataset(Dataset):
    """Dataset of ODE trajectories described by documents in a JSON file.

    Each document must provide ``'N_span'`` (two numbers) and ``'ini_state'``
    (passed to ``solve``). For item ``idx`` the dataset integrates every
    document of that item with ``solve`` and samples the dense solution at each
    point of the module-level grid ``Nv`` that lies inside the solved interval.
    Every sample point becomes one row:

        input  = [N, N_span_start, N_span_end, *ini_state]
        output = sol(N)   # [z, x, y]

    NOTE: ``solve`` is re-run on every access; nothing is cached.

    Parameters
    ----------
    batch_size : int, default 1
        Number of consecutive documents concatenated into one item.
    path : str, default "solve.ivp2.json"
        JSON file holding a list of documents.
    scaler : already-fitted sklearn-style scaler, optional
        If given, outputs are mapped with ``scaler.transform`` so every item
        shares the same statistics. If omitted, a ``StandardScaler`` is
        re-fitted on every item (legacy behaviour). NOTE: that standardizes
        each item with its own mean/std, so targets are not comparable across
        items and cannot be inverted consistently afterwards.
    """

    def __init__(self, batch_size=1, path="solve.ivp2.json", scaler=None):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size
        with open(path, 'r', encoding='utf-8') as f:
            self.documents = json.load(f)
        # A caller-supplied scaler is assumed to be fitted already.
        self._scaler_is_fitted = scaler is not None
        self.scaler = scaler if scaler is not None else StandardScaler()

    def __len__(self):
        # Number of items = ceil(len(documents) / batch_size).
        return (len(self.documents) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        start_idx = idx * self.batch_size
        end_idx = min(start_idx + self.batch_size, len(self.documents))
        batch_documents = self.documents[start_idx:end_idx]

        # Per-document blocks of input rows / output rows.
        input_blocks = []
        output_blocks = []

        for document in batch_documents:
            N_span_start, N_span_end = document['N_span']
            sol = solve(document['N_span'], document['ini_state'])
            N_effective = Nv[(Nv >= sol.t_min) & (Nv <= sol.t_max)]
            if N_effective.size == 0:
                continue

            # Conditioning columns, identical for every sample point.
            conditions = [N_span_start, N_span_end] + list(document['ini_state'])
            inputs = np.column_stack(
                [N_effective, np.tile(conditions, (N_effective.size, 1))]
            )
            # OdeSolution on an array returns (n_state, n_points); transpose to rows.
            outputs = np.asarray(sol(N_effective)).T

            input_blocks.append(inputs)
            output_blocks.append(outputs)

        if not input_blocks:
            raise ValueError(
                f"item {idx}: no point of Nv lies inside the solved interval "
                f"of documents [{start_idx}, {end_idx})"
            )

        train_inputs = np.concatenate(input_blocks)
        train_outputs = np.concatenate(output_blocks)

        if self._scaler_is_fitted:
            train_outputs = self.scaler.transform(train_outputs)
        else:
            train_outputs = self.scaler.fit_transform(train_outputs)

        # Convert the numpy arrays to float32 torch tensors.
        train_inputs = torch.tensor(train_inputs, dtype=torch.float32)
        train_outputs = torch.tensor(train_outputs, dtype=torch.float32)

        return train_inputs, train_outputs


import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from IPython.display import clear_output
from sklearn.preprocessing import StandardScaler

# -------------------------
# Step 1: build the dataset and the DataLoader
# -------------------------
# JSONDataset() reads its default JSON file ("solve.ivp2.json").
dataset = JSONDataset()

# Each dataset item is already a whole trajectory whose length varies from item
# to item (e.g. 2382 vs 1740 points in the printed shapes), so PyTorch's default
# collate cannot stack more than one item per batch without a custom collate_fn.
batch_size = 1
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

[0.000e+00 1.000e-02 2.000e-02 ... 6.738e+01 6.739e+01 6.740e+01]
[1.         1.         1.         ... 0.32904425 0.32245421 0.31593402]


In [2]:
from tqdm import tqdm
import torch.nn as nn

# Smoke test for the DataLoader: walk through one epoch and print the shape of
# every (inputs, outputs) pair without training anything.
epochs = 1
for epoch in range(epochs):
    batch_bar = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}", leave=False)
    for batch_inputs, batch_outputs in batch_bar:
        print(1)  # marker line, one per item
        for shown in (batch_inputs, batch_outputs):
            print(shown.shape)

Epoch 0/1:   0%|          | 2/10000 [00:00<29:44,  5.60it/s]

1
torch.Size([1, 2382, 6])
torch.Size([1, 2382, 3])
1
torch.Size([1, 1740, 6])
torch.Size([1, 1740, 3])
1
torch.Size([1, 1443, 6])
torch.Size([1, 1443, 3])


Epoch 0/1:   0%|          | 6/10000 [00:00<22:50,  7.29it/s]

1
torch.Size([1, 2382, 6])
torch.Size([1, 2382, 3])
1
torch.Size([1, 1092, 6])
torch.Size([1, 1092, 3])
1
torch.Size([1, 774, 6])
torch.Size([1, 774, 3])


Epoch 0/1:   0%|          | 8/10000 [00:01<20:17,  8.21it/s]

1
torch.Size([1, 1358, 6])
torch.Size([1, 1358, 3])
1
torch.Size([1, 1616, 6])
torch.Size([1, 1616, 3])
1


Epoch 0/1:   0%|          | 9/10000 [00:01<20:26,  8.15it/s]

torch.Size([1, 1072, 6])
torch.Size([1, 1072, 3])


Epoch 0/1:   0%|          | 10/10000 [00:01<25:26,  6.54it/s]

1
torch.Size([1, 1464, 6])
torch.Size([1, 1464, 3])
1
torch.Size([1, 1768, 6])
torch.Size([1, 1768, 3])


Epoch 0/1:   0%|          | 12/10000 [00:01<26:18,  6.33it/s]

1
torch.Size([1, 2383, 6])
torch.Size([1, 2383, 3])


Epoch 0/1:   0%|          | 15/10000 [00:02<23:57,  6.94it/s]

1
torch.Size([1, 2382, 6])
torch.Size([1, 2382, 3])
1
torch.Size([1, 1375, 6])
torch.Size([1, 1375, 3])
1
torch.Size([1, 1443, 6])
torch.Size([1, 1443, 3])


Epoch 0/1:   0%|          | 18/10000 [00:02<19:21,  8.59it/s]

1
torch.Size([1, 1947, 6])
torch.Size([1, 1947, 3])
1
torch.Size([1, 1792, 6])
torch.Size([1, 1792, 3])
1
torch.Size([1, 1719, 6])
torch.Size([1, 1719, 3])


Epoch 0/1:   0%|          | 19/10000 [00:02<24:50,  6.70it/s]

1
torch.Size([1, 2383, 6])
torch.Size([1, 2383, 3])
1
torch.Size([1, 1650, 6])
torch.Size([1, 1650, 3])


Epoch 0/1:   0%|          | 23/10000 [00:03<22:44,  7.31it/s]

1
torch.Size([1, 2383, 6])
torch.Size([1, 2383, 3])
1
torch.Size([1, 1792, 6])
torch.Size([1, 1792, 3])
1
torch.Size([1, 1633, 6])
torch.Size([1, 1633, 3])


Epoch 0/1:   0%|          | 24/10000 [00:03<26:34,  6.25it/s]

1
torch.Size([1, 2383, 6])
torch.Size([1, 2383, 3])


Epoch 0/1:   0%|          | 26/10000 [00:03<28:27,  5.84it/s]

1
torch.Size([1, 2383, 6])
torch.Size([1, 2383, 3])
1
torch.Size([1, 1050, 6])
torch.Size([1, 1050, 3])
1
torch.Size([1, 1467, 6])
torch.Size([1, 1467, 3])


                                                             
KeyboardInterrupt



In [12]:

# -------------------------
# Step 2: define the model
# -------------------------
class ODE_Network(nn.Module):
    """Fully connected regressor mapping a 6-feature input row to 3 outputs.

    Widths: 6 -> 128 -> 256 -> 512 -> 128 -> 64 -> 3, with ReLU after every
    hidden layer and no activation on the output layer. The attribute names
    (fc1, fc21..fc24, fc3) are the keys of the saved state_dict, so keep them.
    """

    def __init__(self):
        super().__init__()
        # Construction order is preserved so default random init draws from the
        # RNG in the same sequence.
        self.fc1 = nn.Linear(6, 128)
        self.fc21 = nn.Linear(128, 256)
        self.fc22 = nn.Linear(256, 512)
        self.fc23 = nn.Linear(512, 128)
        self.fc24 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 3)

    def forward(self, x):
        """Run the hidden stack with ReLU, then the linear output head."""
        hidden = x
        for layer in (self.fc1, self.fc21, self.fc22, self.fc23, self.fc24):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)


import torch.nn as nn
from tqdm import tqdm

# Seed the global torch RNG: makes weight init and the DataLoader's shuffle
# order (it draws from the global generator when none is given) reproducible.
SEED = 0
LEARNING_RATE = 5e-4
torch.manual_seed(SEED)

model = ODE_Network()

# -------------------------
# Step 3: loss function and optimizer
# -------------------------
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# -------------------------
# Step 4: training loop
# -------------------------
epochs = 100
losses = []  # mean training loss of each epoch, one entry per epoch
n_items = len(dataset)
if n_items == 0:
    raise ValueError("dataset is empty; nothing to train on")

epoch_bar = tqdm(range(epochs), desc="Training")
for epoch in epoch_bar:
    model.train()
    epoch_loss = 0.0
    batch_bar = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}", leave=False)
    for batch_inputs, batch_outputs in batch_bar:
        outputs = model(batch_inputs)
        loss = criterion(outputs, batch_outputs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Weight by the DataLoader batch dimension so the epoch value is a
        # mean over dataset items.
        epoch_loss += batch_loss * batch_inputs.size(0)
        batch_bar.set_postfix(loss=batch_loss)

    epoch_loss /= n_items
    losses.append(epoch_loss)
    epoch_bar.set_postfix(epoch_loss=epoch_loss)

torch.save(model.state_dict(), 'ode_model_multi_test.pth')
