# Diffusion Equation

\begin{equation}
\frac{\partial g(x,t)}{\partial t} = \frac{\partial^2 g(x,t)}{\partial x^2}
\end{equation}

\begin{align*}
g(0,t) &= 0,\qquad t \geq 0 \\
g(1,t) &= 0,\qquad t \geq 0 \\
g(x,0) &= u(x),\qquad x\in [0,1],\qquad u(x) = \sin(\pi x)
\end{align*}

\begin{equation}
g(x,t) = \exp(-\pi^2 t)\sin(\pi x)
\end{equation}

In [13]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import scienceplots
# Apply the 'science', 'notebook' and 'grid' matplotlib style sheets to every
# figure created afterwards (these style names are presumably the ones
# registered by the `scienceplots` import -- confirm against that package).
plt.style.use(['science', 'notebook', 'grid'])

In [14]:
# Seed PyTorch's global random number generator so that repeated runs of the
# notebook draw the same random numbers.
torch.manual_seed(15)


def exact_solution(x, t):
    """Evaluate the closed-form solution g(x, t) = exp(-pi^2 t) * sin(pi x).

    Args:
        x: Tensor of spatial coordinates.
        t: Tensor of times; combined with ``x`` element-wise, so the two
            must be broadcast-compatible.

    Returns:
        Tensor holding g(x, t) for every (x, t) pair.
    """
    # Exponential decay in time of the single sine mode.
    temporal_decay = torch.exp(-(torch.pi ** 2) * t)
    # Spatial profile, identical to the initial condition u(x) = sin(pi x).
    spatial_profile = torch.sin(torch.pi * x)
    return temporal_decay * spatial_profile

In [15]:
class FCN(nn.Module):
    """Fully connected feed-forward network with Tanh activations.

    Layout: one ``Linear(N_INPUT, N_HIDDEN) + Tanh`` input stage, followed by
    ``N_LAYERS - 1`` ``Linear(N_HIDDEN, N_HIDDEN) + Tanh`` hidden stages and a
    final ``Linear(N_HIDDEN, N_OUTPUT)`` layer without activation.

    Args:
        N_INPUT: Number of input features.
        N_OUTPUT: Number of output features.
        N_HIDDEN: Width of every hidden layer.
        N_LAYERS: Total number of ``Linear + Tanh`` stages (input stage
            included).
    """

    def __init__(self, N_INPUT, N_OUTPUT, N_HIDDEN, N_LAYERS):
        super().__init__()
        activation = nn.Tanh

        # Input stage: project the raw features onto the hidden width.
        self.fcs = nn.Sequential(nn.Linear(N_INPUT, N_HIDDEN), activation())

        # Hidden stages: the input stage already counts as one layer, so only
        # N_LAYERS - 1 further Linear + Tanh pairs are stacked here.
        hidden_stages = []
        for _ in range(N_LAYERS - 1):
            stage = nn.Sequential(nn.Linear(N_HIDDEN, N_HIDDEN), activation())
            hidden_stages.append(stage)
        self.fch = nn.Sequential(*hidden_stages)

        # Output stage: plain linear read-out.
        self.fce = nn.Linear(N_HIDDEN, N_OUTPUT)

    def forward(self, x):
        """Apply the input stage, the hidden stages and the linear read-out to ``x``."""
        hidden = self.fch(self.fcs(x))
        return self.fce(hidden)

In [25]:
# ---------------------------------------------------------------------------
# Physics-informed training for the 1-D diffusion equation
#     dg/dt = d^2 g / dx^2,      x in [0, 1],  t in [0, 1]
#     g(0, t) = g(1, t) = 0,     g(x, 0) = u(x) = sin(pi x)
# The network approximates g, so it takes the pair (x, t) as input.
# ---------------------------------------------------------------------------
model = FCN(2, 1, 32, 4)


def _split_columns(points):
    """Split an (N, 2) tensor of (x, t) rows into two separate (N, 1) tensors."""
    return points[:, 0:1].clone(), points[:, 1:2].clone()


def _network_input(x, t):
    """Stack (N, 1) coordinate columns into the (N, 2) layout the model expects."""
    return torch.cat([x, t], dim=1)


x_axis = torch.linspace(0, 1, 30)
t_axis = torch.linspace(0, 1, 30)

# Initial-condition points: every x at t = 0, with target u(x) = sin(pi x).
x_initial = x_axis.view(-1, 1)
t_initial = torch.zeros_like(x_initial)
u_initial = torch.sin(torch.pi * x_initial)

# Dirichlet boundary points: x in {0, 1} paired with every sampled time.
x_boundary, t_boundary = _split_columns(
    torch.cartesian_prod(torch.linspace(0, 1, 2), t_axis))

# Collocation points for the PDE residual: the full (x, t) grid. They must
# track gradients because the residual differentiates the network output with
# respect to both coordinates.
x_physics, t_physics = _split_columns(torch.cartesian_prod(x_axis, t_axis))
x_physics.requires_grad_(True)
t_physics.requires_grad_(True)

# Dense evaluation grid, used only to monitor the error against the exact
# solution (it never enters the loss).
x_test, t_test = _split_columns(
    torch.cartesian_prod(torch.linspace(0, 1, 300), torch.linspace(0, 1, 300)))
u_exact = exact_solution(x_test, t_test)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_iter = 15001

# Relative weights of the boundary and PDE-residual terms in the total loss
# (the initial-condition term has weight 1).
lambda1, lambda2 = 1.0, 1.0

for i in range(num_iter):
    optimizer.zero_grad()

    # Initial condition: g(x, 0) = sin(pi x).
    u = model(_network_input(x_initial, t_initial))
    loss1 = torch.mean((u - u_initial) ** 2)

    # Boundary conditions: g(0, t) = g(1, t) = 0.
    u = model(_network_input(x_boundary, t_boundary))
    loss2 = torch.mean(u ** 2)

    # PDE residual: dg/dt - d^2 g / dx^2 = 0 at the collocation points.
    # Each output row depends only on its own input row, so differentiating
    # the summed output (grad_outputs of ones) yields per-point derivatives.
    # create_graph=True keeps the derivatives differentiable so the residual
    # can be back-propagated to the network weights.
    u = model(_network_input(x_physics, t_physics))
    dudt = torch.autograd.grad(u, t_physics, torch.ones_like(u), create_graph=True)[0]
    dudx = torch.autograd.grad(u, x_physics, torch.ones_like(u), create_graph=True)[0]
    d2udx2 = torch.autograd.grad(dudx, x_physics, torch.ones_like(dudx), create_graph=True)[0]
    loss3 = torch.mean((dudt - d2udx2) ** 2)

    loss = loss1 + lambda1 * loss2 + lambda2 * loss3
    loss.backward()
    optimizer.step()

    # Periodically report the loss and the worst-case error on the test grid.
    if i % 2500 == 0:
        with torch.no_grad():
            u_pred = model(_network_input(x_test, t_test))
        max_abs_diff = torch.max(torch.abs(u_pred - u_exact)).item()
        print(f"Training step {i}: loss = {loss.item():.3e}, "
              f"max absolute difference = {max_abs_diff:.5f}")