In [21]:
import numpy as np
import torch
import torch.nn as nn
import os
from torch.utils.data import DataLoader,Dataset
import plotly.express as px
import plotly.graph_objects as go

VER='test'
A=100
T=96
x_shape=(100,50)
BATCH_SIZE = 100
LEARNING_RATE = 1e-4
EPOCHS = 100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DEVICE = 'cpu'
SEED=24# 设置随机数种子



def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     torch.backends.cudnn.deterministic = True
setup_seed(SEED)

def para_init(x):
    para=[]
    for i in range(8):
        para.append(torch.rand_like(x[:,1]).to(DEVICE))
    return para

class custom_dataset(Dataset):#需要继承data.Dataset
    def __init__(self,size,shape):
        self.size=size
        self.shape=shape
    def __getitem__(self, index):
        a=torch.rand(self.shape[0])*self.size[0]
        t=torch.rand(self.shape[1])*self.size[1]
        return torch.concat((torch.meshgrid(a,t)[0].reshape(-1,1),torch.meshgrid(a,t)[1].reshape(-1,1)),dim=1)[index]
    def __len__(self):
        return self.shape[0]*self.shape[1]

class Net(nn.Module):
    def __init__(self,input,output):
        super(Net, self).__init__()
        self.fc=nn.Sequential(
            nn.Linear(input,64),
            nn.Tanh(),
            nn.Linear(64,64),
            nn.Tanh(),
            nn.Linear(64,64),
            nn.Tanh(),
            nn.Linear(64,64),
            nn.Tanh(),
            nn.Linear(64,64),
            nn.Tanh(),
            nn.Linear(64,64),
            nn.Tanh(),
            nn.Linear(64,output),
            nn.Softplus()
        )
 
    def forward(self, x):
        u = self.fc(x)
        return u

def f(u):
    s,i,r,lam=list(u.reshape((4,-1)))
    return torch.stack((-(d+(v+theta)*p)*s-s*lam*torch.sum(i)+delta*r,
        s*lam*torch.sum(i)-(d+gamma)*i,
        (v+theta)*p*s+gamma*i-(d+delta)*r),dim=1)

def ode_loss(u,x):
    dudx=torch.zeros_like(u[:,:-1])
    for i in range(u.shape[1]-1):
        dudx[:,i]=torch.sum(torch.autograd.grad(u[:,i], x,grad_outputs=torch.ones_like(u[:,0]),create_graph=True)[0],axis=1)
    odeloss=loss_fn(dudx,f(u))
    return odeloss

def boundary(x):
    boundary_t=model(torch.stack((torch.zeros_like(x[:,0]),x[:,1]),dim=1))[:,:3]
    boundary_a=model(torch.stack((x[:,1],torch.zeros_like(x[:,0])),dim=1))[:,:3]
    return loss_fn(torch.concat((boundary_t,boundary_a),dim=1),torch.stack((b-m,m,torch.zeros_like(m),s0,i0,r0),dim=1))

In [3]:
x=torch.rand(x_shape).requires_grad_(True).to(DEVICE)
para=para_init(x)
d,v,p,theta,delta,b,m,gamma=para
s0=torch.rand_like(x[:,1]).to(DEVICE)
i0=torch.rand_like(x[:,1]).to(DEVICE)
r0=torch.rand_like(x[:,1]).to(DEVICE)



model=Net(2,4)
model.to(DEVICE)
loss_fn =nn.MSELoss()
trainLoader=DataLoader(custom_dataset((A,T),x_shape),batch_size=BATCH_SIZE,shuffle=True)

optim=torch.optim.Adam(model.parameters(),lr=LEARNING_RATE)
start_epoch=0

# checkpoint = torch.load("./checkpoint/%s/ckpt_best_1000.pth"%VER)  # 加载断点
# model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
# start_epoch = checkpoint['epoch']  # 设置开始的epoch
# optim.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数

# 训练
epoch_train = []  # 存放训练误差
os.makedirs('./checkpoint/%s'%VER,exist_ok=True)
for epoch in range(start_epoch,EPOCHS):
    for x in trainLoader:        # 加载训练集
        x_trainData= x.requires_grad_(True).to(DEVICE)
        beta=model(x_trainData)[:,2].reshape(-1)
        u = model(x_trainData)             # forward
        oloss=ode_loss(u,x_trainData)
        bloss=boundary(x_trainData)
        loss=oloss+bloss  # loss
        optim.zero_grad()                          # gradient to zero
        loss.backward()                            # backward
        optim.step()                               # gradient descent
 
    epoch_train.append(loss.item())
    if (epoch+1) % 10 == 0:
        print("epoch:%d,train loss:%10.5e,boundary loss:%10.3e,ODE loss:%10.3e" % (epoch+1,loss.item(),bloss.item(),oloss.item()))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


epoch:10,train loss:9.95652e-01,boundary loss: 1.473e-01,ODE loss: 8.483e-01
epoch:20,train loss:3.99764e-01,boundary loss: 1.558e-01,ODE loss: 2.439e-01
epoch:30,train loss:2.73534e-01,boundary loss: 1.618e-01,ODE loss: 1.118e-01
epoch:40,train loss:2.32755e-01,boundary loss: 1.656e-01,ODE loss: 6.717e-02
epoch:50,train loss:2.15129e-01,boundary loss: 1.683e-01,ODE loss: 4.681e-02
epoch:60,train loss:2.04301e-01,boundary loss: 1.702e-01,ODE loss: 3.414e-02
epoch:70,train loss:1.98349e-01,boundary loss: 1.713e-01,ODE loss: 2.708e-02
epoch:80,train loss:1.93894e-01,boundary loss: 1.724e-01,ODE loss: 2.153e-02
epoch:90,train loss:1.91936e-01,boundary loss: 1.724e-01,ODE loss: 1.957e-02
epoch:100,train loss:1.88621e-01,boundary loss: 1.708e-01,ODE loss: 1.784e-02


In [23]:
fig = go.Figure(data=[go.Mesh3d(x=x.cpu().detach().numpy()[:,0],y=x.cpu().detach().numpy()[:,1],z=u.cpu().detach().numpy()[:,0])])
fig.show()