In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt


def exact_solution(x):
    """Return the closed-form solution of y'' + y = x^2 at ``x``.

    The expression is ``A*sin(x) + 2*cos(x) + x^2 - 2`` with
    ``A = (2*sin(1) - 1) / cos(1)``; it evaluates to 0 at ``x = 0`` and its
    derivative evaluates to 1 at ``x = 1``.

    Args:
        x: Scalar or NumPy array of evaluation points.

    Returns:
        The solution evaluated element-wise at ``x``.
    """
    # Coefficient of the sine term in the homogeneous part.
    sin_coeff = (2 * np.sin(1) - 1) / np.cos(1)
    homogeneous = sin_coeff * np.sin(x) + 2 * np.cos(x)
    # The polynomial x^2 - 2 is the particular solution of the forced equation.
    return homogeneous + x**2 - 2


class FCN(nn.Module):
    """Fully connected network with tanh activations.

    The model is three sub-modules applied in sequence:

    * ``fcs`` - a linear map from ``N_INPUT`` to ``N_HIDDEN`` followed by tanh;
    * ``fch`` - ``N_LAYERS - 1`` blocks, each a linear map from ``N_HIDDEN`` to
      ``N_HIDDEN`` followed by tanh;
    * ``fce`` - a final linear map from ``N_HIDDEN`` to ``N_OUTPUT`` with no
      activation.

    Args:
        N_INPUT: Number of input features.
        N_OUTPUT: Number of output features.
        N_HIDDEN: Width of every hidden layer.
        N_LAYERS: Number of tanh-activated layers (``fcs`` plus the blocks
            in ``fch``).
    """

    def __init__(self, N_INPUT, N_OUTPUT, N_HIDDEN, N_LAYERS):
        super().__init__()

        def dense_tanh(in_features, out_features):
            # One linear layer followed by a tanh non-linearity.
            return nn.Sequential(nn.Linear(in_features, out_features), nn.Tanh())

        # Layers are created in the order fcs -> fch -> fce so that seeded
        # parameter initialisation is reproducible.
        self.fcs = dense_tanh(N_INPUT, N_HIDDEN)
        self.fch = nn.Sequential(
            *(dense_tanh(N_HIDDEN, N_HIDDEN) for _ in range(N_LAYERS - 1))
        )
        self.fce = nn.Linear(N_HIDDEN, N_OUTPUT)

    def forward(self, x):
        """Apply the input block, the hidden stack and the output layer to ``x``."""
        return self.fce(self.fch(self.fcs(x)))




# Physics-informed neural network (PINN) for the boundary-value problem
#     y''(x) + y(x) = x^2,   y(0) = 0,   y'(1) = 1,   x in [0, 1].
# The network maps x -> y(x); the loss penalises violations of the two
# boundary conditions and of the ODE at a fixed set of collocation points.
pinn = FCN(1, 1, 32, 3)

# Boundary points, each shaped (1, 1). The point at x = 1 must track gradients
# because the boundary condition there is on the derivative dy/dx.
x_boundary_0 = torch.tensor(0.).view(-1, 1).requires_grad_(True)  # y(0) = 0
x_boundary_1 = torch.tensor(1.).view(-1, 1).requires_grad_(True)  # dy/dx(1) = 1

# Collocation points where the ODE residual is evaluated; gradients with
# respect to these inputs are needed for du/dx and d²u/dx².
x_physics = torch.linspace(0, 1, 30).view(-1, 1).requires_grad_(True)  # (30, 1)

optimiser = torch.optim.Adam(pinn.parameters(), lr=1e-3)

# Loss weights. They do not change during training, so they are defined once
# here rather than on every iteration. The y(0) term carries an implicit
# weight of 1.
lambda_bc = 1e-1
lambda_physics = 1e-4

# Evaluation grid and reference curve used by every progress plot; both are
# independent of the training step.
x_test = torch.linspace(0, 1, 300).view(-1, 1)
u_exact = exact_solution(x_test.numpy())

# Training steps after which a progress figure is drawn.
plot_steps = {0, 1000, 2500, 5000, 10000, 15000}

for i in range(15001):
    optimiser.zero_grad()

    # Boundary condition y(0) = 0.
    u_0 = pinn(x_boundary_0)  # y(0)
    loss_bc_0 = (torch.squeeze(u_0) - 0)**2

    # Boundary condition dy/dx(1) = 1. create_graph=True keeps the derivative
    # differentiable so that loss.backward() can propagate through it.
    u_1 = pinn(x_boundary_1)  # y(1)
    dudx_1 = torch.autograd.grad(u_1, x_boundary_1, torch.ones_like(u_1), create_graph=True)[0]
    loss_bc_1 = (torch.squeeze(dudx_1) - 1)**2

    # ODE residual y'' + y - x^2 at the collocation points.
    u = pinn(x_physics)  # (30, 1)
    dudx = torch.autograd.grad(u, x_physics, torch.ones_like(u), create_graph=True)[0]  # du/dx
    d2udx2 = torch.autograd.grad(dudx, x_physics, torch.ones_like(dudx), create_graph=True)[0]  # d²u/dx²
    physics_residual = d2udx2 + u - x_physics**2
    loss_physics = torch.mean(physics_residual**2)

    loss = loss_bc_0 + lambda_bc * loss_bc_1 + lambda_physics * loss_physics
    loss.backward()
    optimiser.step()

    if i in plot_steps:
        # Plotting needs values only, so skip building an autograd graph.
        with torch.no_grad():
            u_test = pinn(x_test)

        plt.figure(figsize=(6, 2.5))
        # Markers on the x-axis: collocation points (green) and the two
        # boundary points (red at x = 0, blue at x = 1).
        plt.scatter(x_physics.detach()[:, 0],
                    torch.zeros_like(x_physics)[:, 0], s=20, lw=0, color="tab:green", alpha=0.6)
        plt.scatter(x_boundary_0.detach()[:, 0],
                    torch.zeros_like(x_boundary_0)[:, 0], s=20, lw=0, color="tab:red", alpha=0.6)
        plt.scatter(x_boundary_1.detach()[:, 0],
                    torch.zeros_like(x_boundary_1)[:, 0], s=20, lw=0, color="tab:blue", alpha=0.6)
        plt.plot(x_test[:, 0], u_test[:, 0], label="PINN solution", color="tab:green")
        plt.plot(x_test[:, 0], u_exact, label="Exact solution", color="tab:gray", linestyle='--', alpha=0.8)
        plt.title(f"Training step {i}")
        plt.legend()
        plt.show()
