In [1]:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Define the model
class FiniteDifferenceOperator(nn.Module):
    """Learnable zero-sum finite-difference stencil followed by a pointwise MLP.

    The stencil spans ``Nleft`` nodes to the left and ``Nright`` nodes to the
    right of each grid node. It is applied to a 1-D periodic grid function; the
    resulting value at every node is then fed through an MLP independently.

    Args:
        Nleft: number of stencil nodes to the left of the center node.
        Nright: number of stencil nodes to the right of the center node.
        num_layers: number of hidden (Linear + ReLU) layers in the MLP; must be >= 1.
            ``num_layers=2, hidden_dim=64`` gives the originally hard-coded network.
        hidden_dim: width of every hidden layer of the MLP.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, Nleft, Nright, num_layers, hidden_dim):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # Number of total nodes in the finite difference stencil
        self.Nstencil = Nleft + Nright + 1
        self.Nleft = Nleft
        self.Nright = Nright

        # Raw learnable stencil weights. The last entry is replaced in
        # _effective_stencil so that the weights sum to zero; it is kept in the
        # parameter only so the stored shape stays (Nstencil,).
        self.stencil = nn.Parameter(torch.randn(self.Nstencil))

        # MLP modelling the nonlinearity: scalar in, Nstencil values out per node.
        layers = [nn.Linear(1, hidden_dim), nn.ReLU()]
        for _ in range(num_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, self.Nstencil))
        self.mlp = nn.Sequential(*layers)

    def _effective_stencil(self):
        """Return the stencil with its last weight set so all weights sum to zero.

        Built out-of-place with torch.cat: writing into ``self.stencil`` (a leaf
        Parameter that requires grad) raises a RuntimeError under autograd.
        """
        head = self.stencil[:-1]
        return torch.cat([head, -head.sum().reshape(1)])

    def _apply_stencil(self, x):
        """Apply the zero-sum stencil to the 1-D periodic grid function ``x``.

        Returns a tensor of shape (x.shape[0],) whose entry i is the dot product
        of the stencil with the wrapped neighbourhood of node i.

        Raises:
            ValueError: if ``x`` has fewer than 2 nodes (the wrap period would be 0).
        """
        n_nodes = x.shape[0]
        if n_nodes < 2:
            raise ValueError(f"x must contain at least 2 grid nodes, got {n_nodes}")

        # NOTE(review): indices wrap modulo (n_nodes - 1). This treats x[-1] as a
        # duplicate of x[0] (a periodic grid sampled including both endpoints), so
        # x[-1] is never read and the output at the last node equals the output at
        # node 0. Use `% n_nodes` instead if x holds n_nodes distinct periodic samples.
        period = n_nodes - 1
        offsets = torch.arange(self.Nstencil, device=x.device) - self.Nleft
        centers = torch.arange(n_nodes, device=x.device)
        # (n_nodes, Nstencil) neighbour indices; tensor `%` follows Python's sign
        # convention, so negative offsets wrap to the end of the grid.
        indices = (centers.unsqueeze(1) + offsets.unsqueeze(0)) % period
        return (x[indices] * self._effective_stencil()).sum(dim=1)

    def forward(self, x, h):
        """Apply the learned stencil to ``x`` and evaluate the MLP pointwise.

        Args:
            x: grid function, expected 1-D of shape (N,) (the stencil output is
                unsqueezed to (N, 1) to feed the MLP's 1-feature input).
            h: grid spacing; currently unused by this method.

        Returns:
            Tensor of shape (N, Nstencil): the MLP output at every grid node.
        """
        # Goal - build up D^* grad(N[D x])

        # Step 1 - apply the zero-sum stencil D to x (periodic BC)
        dh_x = self._apply_stencil(x)

        # Step 2 - evaluate the MLP on D x, one node at a time.
        # TODO: the gradient of the MLP and the adjoint stencil D^* from the stated
        # goal are not applied yet; the raw MLP output is returned.
        return self.mlp(dh_x.unsqueeze(1))