<a href="https://colab.research.google.com/github/Royal4224/ENM_5320/blob/main/Code/PyTorchFDM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Define the model (in our case, its y = A*x + b)
class FiniteDifferenceOperator(nn.Module):
    def __init__(self,Nleft,Nright):
        super(FiniteDifferenceOperator, self).__init__()
        # Number of total nodes in finite difference stencils
        self.Nstencil = Nleft + Nright + 1
        self.Nleft = Nleft
        self.Nright = Nright

        # Initialize with random coefficients
        self.stencil = torch.nn.Parameter(torch.randn(self.Nstencil))
        # self.stencil = torch.from_numpy((0.5/dx)*np.array([1,0,-1]))
        # a unit test to see if learnable stencil performs like hard-coded stencil
        # from finite difference example


    def forward(self, x):
        # Apply the finite difference stencil to the gridfunction x, assuming periodic BC
        N_nodes = x.shape[0]
        f_out = torch.zeros_like(x)
        for i in range(N_nodes):
          for jj in range(self.Nstencil):
            j = jj - self.Nleft
            j_withbc = (i+j)%(N_nodes-1)
            f_out[i] += self.stencil[jj]*x[j_withbc]

        return f_out


# Parameters
L = 2.0*np.pi  # Length of the periodic domain
T = np.pi   # Total time
nx = 50  # Number of spatial points (the last one duplicates the first)
nt = 50  # Number of time steps

# Fix the RNG so the random stencil initialization is reproducible.
torch.manual_seed(0)

# Discretization
x = np.linspace(0, L, nx)  # includes both endpoints 0 and L
# np.linspace(0, L, nx) has spacing L/(nx-1), not L/nx; this also matches the
# operator's periodic wrap modulo (nx-1).
dx = L / (nx - 1)
dt = T / nt
u = torch.from_numpy(np.sin(2 * np.pi * x / L))  # Initial condition
Dx = FiniteDifferenceOperator(1,1)  # Finite difference operator w a neighbor on either side

def uexact(x, t):
    """Return the traveling wave sin(2*pi*(x - t)/L) as a torch tensor.

    ``x`` is the NumPy grid and ``t`` the time; ``L`` is the module-level
    domain length.
    """
    phase = 2 * np.pi * (x - t) / L
    return torch.from_numpy(np.sin(phase))

# Define the optimizer so that it optimizes over the stencil parameters.
optimizer = optim.Adam(Dx.parameters(), lr=0.01)

# Training data from the exact solution. It does not depend on the stencil,
# so build it once instead of every epoch. Column n holds the snapshot at
# t_n = n*dt; the target is the forward-difference time derivative.
times = [n * dt for n in range(nt)]
u_now = torch.stack([uexact(x, t) for t in times], dim=1)        # (nx, nt)
u_next = torch.stack([uexact(x, t + dt) for t in times], dim=1)  # (nx, nt)
dudt_exact = (u_next - u_now) / dt

num_epochs = 1000
for epoch in range(num_epochs):
    # The stencil acts along the node axis (dim 0), so all snapshots are
    # handled in one call. Loss = per-snapshot MSE over nodes, summed over time.
    dudt_learned = Dx(u_now) / dx
    loss = torch.mean((dudt_exact - dudt_learned) ** 2, dim=0).sum()

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}', Dx.stencil.detach().numpy())

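# Plot the results (disabled): explicit time stepping with a hard-coded central-difference stencil, plotted against the exact solution every 10 steps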
# for n in range(nt):
#     u_new = np.zeros_like(u)
#     for i in range(nx):
#         u_new[i] = u[i] + ( 0.5 * dt / dx) * (u[(i+1) % (nx-1)] - u[(i-1) % (nx-1)])
#     for i in range(nx):
#         u[i] = u_new[i]
#     if n % 10 == 0:
#       tn = n*dt
#       uexact = np.sin(-2 * np.pi * (x-tn) / L)
#       plt.plot(x, u,'--',label='Numerical')
#       plt.plot(x, uexact, label='Exact')


Epoch [1/1000], Loss: 934.7374 [-1.6706849 -0.5000951  1.66178  ]
Epoch [11/1000], Loss: 576.9005 [-1.5739075 -0.4039905  1.7568829]
Epoch [21/1000], Loss: 479.4702 [-1.4953773  -0.33122113  1.8204918 ]
Epoch [31/1000], Loss: 477.7966 [-1.4498906  -0.30265248  1.8234059 ]
Epoch [41/1000], Loss: 459.9402 [-1.4304941  -0.30993524  1.7802165 ]
Epoch [51/1000], Loss: 441.6493 [-1.4137815 -0.3239203  1.7272108]
Epoch [61/1000], Loss: 427.4574 [-1.384915   -0.32605287  1.6845527 ]
Epoch [71/1000], Loss: 411.6237 [-1.3462449 -0.3183057  1.6501122]
Epoch [81/1000], Loss: 395.7258 [-1.3061723  -0.31053922  1.6136045 ]
Epoch [91/1000], Loss: 379.6728 [-1.268007   -0.30666578  1.5713279 ]
Epoch [101/1000], Loss: 363.4683 [-1.2298437  -0.30408102  1.5264331 ]
Epoch [111/1000], Loss: 347.2924 [-1.1898882  -0.30025446  1.4818002 ]
Epoch [121/1000], Loss: 331.2041 [-1.1485986  -0.29552168  1.4372696 ]
Epoch [131/1000], Loss: 315.2977 [-1.1069998  -0.29100356  1.3918873 ]
Epoch [141/1000], Loss: 299.6