# Solving an ODE

$$
\begin{aligned}
y'(x) &= -2y(x)+2x^2+2x\\
y(0) &= 1 \\
x &\in [0,\ 0.5]
\end{aligned}
$$

The exact solution, used below to measure the network's error:
$$
y(x) = e^{-2x}+x^2
$$

In [1]:
# install pytorch
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
import numpy as np


In [2]:
# Collocation points: 100 evenly spaced samples of x on [0, 0.5].
# requires_grad=True so autograd can later differentiate the network output w.r.t. x.
x = torch.linspace(0, 0.5, 100, requires_grad=True)
print(x.shape)

# nn.Linear works on the last dimension, so turn the 1-D grid into a (100, 1) column.
x = x[:, None]
print(x.shape)

torch.Size([100])
torch.Size([100, 1])


In [3]:
# Seed PyTorch's RNG so the weight initialisation (and therefore the training
# curve and final error) is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# A small MLP that approximates the ODE solution y(x).
# It maps a last dimension of size 1 (x) to a last dimension of size 1 (y);
# in this notebook x is passed as an (N, 1) column.
mlp = nn.Sequential(
    nn.Linear(1, 10),
    nn.Tanh(),
    nn.Linear(10, 20),
    nn.Tanh(),
    nn.Linear(20, 1),
)

In [4]:
# Adam over every weight and bias of the network; Adam materialises the
# parameter iterator itself, so no explicit list() is needed.
optimizer = optim.Adam(params=mlp.parameters(), lr=1e-3)

In [5]:
# Define an order of the NN, i.e., dy/dx
def dy_dx(y, x):
    return torch.autograd.grad(
        y, x, grad_outputs=torch.ones_like(y), create_graph=True
    )[0]

In [6]:
# Train the network so it satisfies both the ODE and the initial condition.
# NOTE(review): the logged loss is still falling at epoch 290 (~0.44) and the test
# MAE is ~8e-2, so 500 epochs at lr=1e-3 looks under-trained; raise N_EPOCHS for
# a closer fit.
N_EPOCHS = 500
LOG_EVERY = 10

# Initial condition y(X0) = Y0, enforced at an explicit point instead of relying
# on the first collocation point happening to be x = 0.
X0, Y0 = 0.0, 1.0
x0 = torch.full((1, 1), X0)

losses = []

for i in range(N_EPOCHS):
    # Call the module itself (not .forward) so any registered hooks run.
    y = mlp(x)
    y_p = dy_dx(y, x)

    # ODE residual for y' = -2y + 2x^2 + 2x; identically zero for the exact solution.
    residual = y_p - (-2 * y + 2 * x**2 + 2 * x)

    # Scalar initial-condition violation y(X0) - Y0.
    initial = mlp(x0)[0, 0] - Y0

    # 0-d loss: mean squared ODE residual + squared initial-condition error.
    loss = (residual**2).mean() + initial**2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss.item())

    if i % LOG_EVERY == 0:
        print("Epoch %3d: Current loss: %.2e" % (i, losses[-1]))

Epoch   0: Current loss: 2.88e+00
Epoch  10: Current loss: 1.83e+00
Epoch  20: Current loss: 1.13e+00
Epoch  30: Current loss: 7.45e-01
Epoch  40: Current loss: 5.92e-01
Epoch  50: Current loss: 5.49e-01
Epoch  60: Current loss: 5.32e-01
Epoch  70: Current loss: 5.18e-01
Epoch  80: Current loss: 5.07e-01
Epoch  90: Current loss: 5.01e-01
Epoch 100: Current loss: 4.97e-01
Epoch 110: Current loss: 4.93e-01
Epoch 120: Current loss: 4.89e-01
Epoch 130: Current loss: 4.86e-01
Epoch 140: Current loss: 4.83e-01
Epoch 150: Current loss: 4.80e-01
Epoch 160: Current loss: 4.77e-01
Epoch 170: Current loss: 4.74e-01
Epoch 180: Current loss: 4.71e-01
Epoch 190: Current loss: 4.68e-01
Epoch 200: Current loss: 4.65e-01
Epoch 210: Current loss: 4.62e-01
Epoch 220: Current loss: 4.60e-01
Epoch 230: Current loss: 4.57e-01
Epoch 240: Current loss: 4.54e-01
Epoch 250: Current loss: 4.51e-01
Epoch 260: Current loss: 4.48e-01
Epoch 270: Current loss: 4.45e-01
Epoch 280: Current loss: 4.42e-01
Epoch 290: Cur

In [7]:
# Evaluate the trained network on a test grid and compare with the exact solution.
x_test = torch.linspace(0, 0.5, 31).reshape(-1, 1)

# The exact symbolic solution y(x) = exp(-2x) + x^2, as a NumPy array.
exact = (torch.exp(-2 * x_test) + x_test**2).numpy()

# The PINN prediction; no gradients are needed for evaluation.
with torch.no_grad():
    predict = mlp(x_test).numpy()

# Pointwise error. Both operands are NumPy arrays, so the result type does not
# depend on torch/NumPy operator dispatch.
error = exact - predict

MAE = float(np.abs(error).mean())

print("Mean Absolute Error: %.2e" % MAE)

Mean Absolute Error: 8.23e-02


In [None]:

fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# Panel 1: exact solution vs. network prediction on the test grid.
axs[0].plot(x_test, exact, "g-", label="Exact")
axs[0].plot(x_test, predict, "r.", label="Predict")
axs[0].set(title="Solution", xlabel="$x$", ylabel="$y(x)$")

# Panel 2: pointwise error exact - predict (not the ODE residual minimised in training).
axs[1].plot(x_test, error, "b", label="Error")
axs[1].set(title="Prediction error", xlabel="$x$", ylabel=r"$y(x) - \hat{y}(x)$")

# Panel 3: training loss per epoch on a log10 scale.
axs[2].plot(np.log10(losses), "c", label="Loss")
axs[2].set(title="Training loss", xlabel="Epoch", ylabel=r"$\log_{10}$(loss)")

for ax in axs:
    ax.legend()

fig.tight_layout()
plt.show()