In [None]:
import deepxde as dde
import matplotlib
matplotlib.use('nbagg')
import matplotlib.pyplot as plt
import torch
import numpy as np
import math
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# Sanity check: can PyTorch see a CUDA device in this kernel?
cuda_available = torch.cuda.is_available()
print(cuda_available)

In [None]:
# Reaction coefficients: beta scales the I*S/tot infection term and gamma the
# I removal term in pde_system.
beta, gamma = 0.1, 0.04
# Diffusion coefficients for S, I and R. S and R share one value; I diffuses
# slightly more slowly (1e-4 < 1.3e-4).
Di = 1e-4
Ds = Dr = 1e-4 + 3e-5
# Total population; pde_system enforces S + I + R = tot.
tot = 1

In [None]:
def boundary(_, on_initial):
    """Selector with the ``(x, on_initial)`` signature used for deepxde ICs.

    Ignores the point coordinates and returns the ``on_initial`` flag
    unchanged, so only points on the initial time slice are selected.
    """
    is_initial = on_initial
    return is_initial

In [None]:
def func(x, center=0.5, width=0.2):
    """Gaussian bump in the spatial coordinate; used as the initial I(x, 0).

    Computes ``exp(-(x0 - center)**2 / (2 * width**2)) / sqrt(2*pi)`` where
    ``x0 = x[..., 0:1]`` is the first (spatial) column of ``x``.

    The prefactor is ``1/sqrt(2*pi)``, not ``1/(width*sqrt(2*pi))``. This is
    therefore not a normalized probability density. Its peak value is
    ``1/sqrt(2*pi) ~= 0.399`` for any ``width``.

    Args:
        x: array whose last axis holds the (x, t) coordinates; only column 0
            is used.
        center: location of the peak (default 0.5, the original value).
        width: standard-deviation-like width (default 0.2, the original value).

    Returns:
        Array with the same shape as ``x[..., 0:1]``.
    """
    x0 = x[..., 0:1]
    return np.exp(-((x0 - center) ** 2) / (2 * width ** 2)) / math.sqrt(2 * np.pi)

In [None]:
def func1(x, r0=0.05):
    """Initial susceptible profile S(x, 0) = 1 - r0 - func(x).

    The profile is chosen so that, with I(x, 0) = func(x) and R(x, 0) = r0,
    the three initial fields sum to exactly 1 at every point.

    Args:
        x: array whose last axis holds the (x, t) coordinates; only column 0
            is used (``func`` does the slicing).
        r0: initial recovered fraction. The default of 0.05 matches the
            original hard-coded value.

    Returns:
        Array with the same shape as ``x[..., 0:1]``.
    """
    return 1 - r0 - func(x[..., 0:1])

In [None]:
def pde_system(x, y):
    """PDE residuals for the diffusive SIR model, passed to dde.data.TimePDE.

    Columns of ``x`` are (space, time). Columns of ``y`` are (S, I, R).
    Four residuals are returned, each driven towards zero in training:

        S_t + beta*I*S/tot - Ds*S_xx
        I_t - beta*I*S/tot + gamma*I - Di*I_xx
        R_t - gamma*I - Dr*R_xx
        S + I + R - tot        (algebraic total-population constraint)

    NOTE(review): summing the three PDEs gives
    N_t = Ds*S_xx + Di*I_xx + Dr*R_xx for N = S + I + R. With N == tot and
    Ds == Dr, this reduces to N_t = (Di - Ds)*I_xx. That is non-zero whenever
    Di != Ds, as with the current constants. The fourth residual therefore
    slightly conflicts with the first three; confirm this is intended.
    """
    S, I, R = y[:, 0:1], y[:, 1:2], y[:, 2:3]
    # Time derivatives (input column 1), one per output component.
    dS_dt, dI_dt, dR_dt = (dde.grad.jacobian(y, x, i=c, j=1) for c in range(3))
    # Second spatial derivatives (input column 0), one per output component.
    dS_dxx, dI_dxx, dR_dxx = (
        dde.grad.hessian(y, x, i=0, j=0, component=c) for c in range(3)
    )
    infection = beta * I * S / tot
    removal = gamma * I
    return [
        dS_dt + infection - Ds * dS_dxx,
        dI_dt - infection + removal - Di * dI_dxx,
        dR_dt - removal - Dr * dR_dxx,
        S + I + R - tot,
    ]

In [None]:
def output_transform(x, y):
    """Map raw network outputs to strictly positive values: ``y*y + 1e-6``.

    The input coordinates ``x`` are unused. The 1e-6 floor keeps every
    predicted field value above zero.
    """
    squared = y * y
    return squared + 1e-6

In [None]:
# Space-time domain: x in [0, 1], t in [0, 100].
timdomain = dde.geometry.TimeDomain(0, 100)
geom = dde.geometry.Interval(0, 1)
geomtime = dde.geometry.GeometryXTime(geom, timdomain)

# Initial conditions at t = 0:
#   S0 = func1 = 1 - 0.05 - func,  I0 = func,  R0 = 0.05.
# Together they sum to 1 (= tot). `boundary` is the (x, on_initial) selector
# defined above; it is identical to the inline lambdas previously repeated here.
ic1 = dde.icbc.IC(geomtime, func1, boundary, component=0)
ic2 = dde.icbc.IC(geomtime, func, boundary, component=1)
# Return an (n, 1) array, like func/func1, rather than a bare scalar.
ic3 = dde.icbc.IC(
    geomtime, lambda x: np.full_like(x[:, 0:1], 0.05), boundary, component=2
)

# NOTE(review): no spatial boundary conditions at x = 0 / x = 1 are imposed.
# If a closed domain is intended, consider adding zero-flux
# dde.icbc.NeumannBC terms for S, I and R.
data = dde.data.TimePDE(
    geomtime, pde_system, [ic1, ic2, ic3], 3000, num_initial=100, num_test=100
)

# Network inputs are (x, t); outputs are (S, I, R).
layer_size = [2, 20, 80, 256, 120, 40, 3]
activation = "elu"
initializer = "Glorot normal"
net = dde.nn.FNN(layer_size, activation, initializer)
# output_transform squares the outputs and adds 1e-6, keeping S, I, R strictly positive.
net.apply_output_transform(output_transform)

In [None]:
# Training schedule: Adam with a fixed learning rate.
LEARNING_RATE = 1e-3
N_ITERATIONS = 50_000

model = dde.Model(data, net)
model.compile("adam", lr=LEARNING_RATE)
losshistory, train_state = model.train(iterations=N_ITERATIONS)
# Optional second stage (not run here). On backends other than jax, the Adam
# solution can be fine-tuned further, e.g.:
#   model.compile("L-BFGS"); losshistory, train_state = model.train()

In [None]:
# Evaluation grid: x in [0, 1) with step 0.01, t in [0, 100) with step 0.5.
X = torch.arange(0, 1, 0.01)
T = torch.arange(0, 100, 0.5)
# "ij" is torch's historical default: X varies along dim 0 and T along dim 1.
# Passing it explicitly keeps that layout and silences torch's warning about
# the missing `indexing` argument.
X, T = torch.meshgrid(X, T, indexing="ij")

In [None]:
# Pair every grid point into an (x, t) row along a new last axis.
inputs = torch.stack([X, T], dim=-1)
print(inputs.shape)

In [None]:
# Evaluate the trained network on the (x, t) grid.
# - Move the inputs to the device the network's parameters live on; the
#   hard-coded `.cuda()` fails on CPU-only machines.
# - Skip autograd bookkeeping, since the values are only used for plotting.
device = next(net.parameters()).device
with torch.no_grad():
    Z = net(inputs.to(device))

In [None]:
# Shape of the predictions: grid dims followed by the 3 output components.
print(Z.size())

In [None]:
from matplotlib import cm
from matplotlib.ticker import LinearLocator

In [None]:
%matplotlib notebook
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
surf = ax.plot_surface(X.detach().cpu().numpy(), T.detach().cpu().numpy(), Z[:,:,1].detach().cpu().numpy(),cmap=cm.viridis,linewidth=10, antialiased=False)
plt.show()

In [None]:
# Report which matplotlib backend ended up active in this kernel.
backend_name = matplotlib.get_backend()
print(backend_name)