<a href="https://colab.research.google.com/github/natrask/ENM5320/blob/main/Code/CNN_stencil.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Define the model (in our case, its y = A*x + b)
class FiniteDifferenceOperator(nn.Module):
    def __init__(self,Nleft,Nright):
        super(FiniteDifferenceOperator, self).__init__()
        # Number of total nodes in finite difference stencils
        self.Nstencil = Nleft + Nright + 1
        self.Nleft = Nleft
        self.Nright = Nright

        # Initialize with random coefficients
        self.stencil = torch.nn.Parameter(torch.randn(self.Nstencil, dtype=torch.float64))
        # self.stencil = torch.from_numpy((0.5/dx)*np.array([1,0,-1]))

        # Initialize CNN layer
        self.cnn = nn.Conv1d(1, 1, kernel_size=self.Nstencil, bias=False, padding='same', dtype=torch.float64)  # padding='same' for periodic BC
        self.cnn.weight.data = self.stencil.view(1, 1, -1)  # Reshape stencil to match weight shape



    def forward(self, x):
        # Apply the finite difference stencil to the gridfunction x, assuming periodic BC
        N_nodes = x.shape[0]
        f_out = torch.zeros_like(x)
        for i in range(N_nodes):
          # Wrap indices periodically using the modulo operator (%)
          indices = [(i + j - self.Nleft) % (N_nodes-1) for j in range(self.Nstencil)]

          # Grab solution at indices
          xstencil = x[indices].unsqueeze(0).unsqueeze(0)
          cnn_output = self.cnn(xstencil).squeeze().squeeze()

          # Apply stencil
          f_out[i] = torch.sum(cnn_output * xstencil)

        return f_out


# Parameters
L = 2.0*np.pi  # Length of the domain
T = np.pi   # Total time
nx = 50  # Number of spatial points
nt = 50  # Number of time steps

# Discretization
dx = L / nx
dt = T / nt
x = np.linspace(0, L, nx)
u = torch.from_numpy(np.sin(2 * np.pi * x / L))  # Initial condition
uexact = np.sin(2 * np.pi * (x-T) / L)  # Exact solution
Dx = FiniteDifferenceOperator(1,1)  # Finite difference operator w a neighbor on either side

def uexact(x,t):
  return torch.from_numpy(np.sin(2 * np.pi * (x-t) / L))
un = uexact(x,0)
Dx(un)
# plt.plot(Dx(un).detach().numpy())


tensor([-0.0036, -0.0337, -0.1221, -0.2629, -0.4470, -0.6623, -0.8947, -1.1291,
        -1.3501, -1.5432, -1.6958, -1.7980, -1.8430, -1.8279, -1.7537, -1.6253,
        -1.4509, -1.2422, -1.0126, -0.7773, -0.5516, -0.3502, -0.1864, -0.0709,
        -0.0112, -0.0112, -0.0709, -0.1864, -0.3502, -0.5516, -0.7773, -1.0126,
        -1.2422, -1.4509, -1.6253, -1.7537, -1.8279, -1.8430, -1.7980, -1.6958,
        -1.5432, -1.3501, -1.1291, -0.8947, -0.6623, -0.4470, -0.2629, -0.1221,
        -0.0337, -0.0036], dtype=torch.float64, grad_fn=<CopySlices>)

In [27]:
Dx(un)

tensor([-0.0036, -0.0337, -0.1221, -0.2629, -0.4470, -0.6623, -0.8947, -1.1291,
        -1.3501, -1.5432, -1.6958, -1.7980, -1.8430, -1.8279, -1.7537, -1.6253,
        -1.4509, -1.2422, -1.0126, -0.7773, -0.5516, -0.3502, -0.1864, -0.0709,
        -0.0112, -0.0112, -0.0709, -0.1864, -0.3502, -0.5516, -0.7773, -1.0126,
        -1.2422, -1.4509, -1.6253, -1.7537, -1.8279, -1.8430, -1.7980, -1.6958,
        -1.5432, -1.3501, -1.1291, -0.8947, -0.6623, -0.4470, -0.2629, -0.1221,
        -0.0337, -0.0036], dtype=torch.float64, grad_fn=<CopySlices>)