<a href="https://colab.research.google.com/github/apester/PINN/blob/main/Simle_ODE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# ============================================================
# 1. Hyperparameters
# ============================================================
n_hidden = 20                  # width of each of the hidden layers
n_collocation = 100            # collocation points sampled per training step
n_epochs = 8_000               # number of optimizer steps
learning_rate = 0.001          # step size handed to the optimizer
device = torch.device("cpu")   # use "cuda" instead to run on a GPU

# Seed both NumPy and PyTorch random number generators for reproducibility.
np.random.seed(0)
torch.manual_seed(0)

# ============================================================
# 2. Define the neural network with 2 hidden layers (sigmoid)
# ============================================================
class OdeNet(nn.Module):
    """Fully connected network N(x) with one input and one output.

    Architecture: Linear(1 -> n_hidden) -> Sigmoid ->
    Linear(n_hidden -> n_hidden) -> Sigmoid -> Linear(n_hidden -> 1).
    The output layer has no activation.
    """

    def __init__(self, n_hidden):
        """Build the layer stack; ``n_hidden`` is the width of both hidden layers."""
        super().__init__()
        layers = [
            nn.Linear(1, n_hidden),         # input -> first hidden layer
            nn.Sigmoid(),
            nn.Linear(n_hidden, n_hidden),  # first -> second hidden layer
            nn.Sigmoid(),
            nn.Linear(n_hidden, 1),         # second hidden layer -> output
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the raw network output."""
        output = self.net(x)
        return output


# Build the network, then place its parameters on the selected device.
model = OdeNet(n_hidden)
model = model.to(device)

# ============================================================
# 3. Trial solution y_trial(x) = 1 + x * N(x; theta)
#    (BC y(0) = 1 is built in)
# ============================================================
def y_trial(x):
    """
    Evaluate the trial solution y_trial(x) = 1 + x * model(x).

    x: tensor of shape (N, 1)

    The x factor makes the network term vanish at x = 0, so
    y_trial(0) = 1 for any network parameters.
    """
    network_output = model(x)
    return 1.0 + x * network_output


# ============================================================
# 4. Compute derivative dy/dx using autograd
# ============================================================
def dy_dx(x):
    """
    Differentiate the trial solution with respect to its input via autograd.

    x must have requires_grad=True.
    Returns a tensor with the same shape as x.
    """
    trial_values = y_trial(x)
    # Back-propagating a tensor of ones sums the outputs; with one output
    # per input row this gives the derivative at each point.
    ones = torch.ones_like(trial_values)
    # create_graph=True makes the derivative itself differentiable, so a
    # loss built from it can be back-propagated to the model parameters.
    (derivative,) = torch.autograd.grad(
        trial_values,
        x,
        grad_outputs=ones,
        create_graph=True,
    )
    return derivative


# ============================================================
# 5. Define the loss function (mean squared residual)
#    Residual R(x) = dy/dx - e^x
# ============================================================
def residual_loss(x):
    """
    Mean squared ODE residual R(x) = dy/dx - e^x over the given points.

    x: collocation points in [0,1], shape (N,1)

    Note: x is switched to requires_grad=True in place so that autograd
    can differentiate the trial solution with respect to it.
    """
    x.requires_grad_(True)
    # The residual is zero wherever the trial solution satisfies the ODE.
    residual = dy_dx(x) - torch.exp(x)
    return residual.pow(2).mean()


# ============================================================
# 6. Training loop
# ============================================================
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(1, n_epochs + 1):
    # Clear gradients accumulated during the previous step.
    optimizer.zero_grad()

    # Draw a fresh batch of collocation points, uniform on [0, 1).
    x_collocation = torch.rand(n_collocation, 1, device=device)

    # Evaluate the residual loss, back-propagate, and update the parameters.
    loss = residual_loss(x_collocation)
    loss.backward()
    optimizer.step()

    # Report progress once every 1000 epochs.
    if not epoch % 1000:
        print(f"Epoch {epoch:5d}, Loss = {loss.item():.6e}")

# ============================================================
# 7. Evaluate the trained model on a grid and compare to exact
# ============================================================
# 200 evenly spaced test points on [0, 1], arranged as a column.
x_test = torch.linspace(0, 1, 200).reshape(-1, 1).to(device)

# No gradients are needed for evaluation.
with torch.no_grad():
    y_pred = y_trial(x_test)
    y_exact = torch.exp(x_test)

# Move everything to host memory as flat NumPy arrays for plotting.
x_np, y_pred_np, y_exact_np = (
    tensor.cpu().numpy().flatten() for tensor in (x_test, y_pred, y_exact)
)

# Pointwise absolute error between the prediction and the exact solution.
error_np = np.absolute(y_pred_np - y_exact_np)

# ============================================================
# 8. Plot the solution and the error
# ============================================================
# One figure with two side-by-side panels: solution (left), error (right).
fig, (ax_solution, ax_error) = plt.subplots(1, 2, figsize=(10, 4))

# Left panel: exact solution against the network approximation.
ax_solution.plot(x_np, y_exact_np, label="Exact $e^x$")
ax_solution.plot(x_np, y_pred_np, "--", label="NN approximation")
ax_solution.set_xlabel("x")
ax_solution.set_ylabel("y(x)")
ax_solution.set_title("Solution of dy/dx = e^x, y(0)=1")
ax_solution.legend()
ax_solution.grid(True)

# Right panel: pointwise absolute error.
ax_error.plot(x_np, error_np)
ax_error.set_xlabel("x")
ax_error.set_ylabel("|y_NN(x) - e^x|")
ax_error.set_title("Absolute Error")
ax_error.grid(True)

plt.tight_layout()
plt.show()
