In [3]:
import numpy as np
import sympy as sp
import matplotlib.pyplot as plt

In [2]:
from math import factorial

def coefficients(m, alpha):
    """Finite-difference weights for the m-th derivative on stencil offsets ``alpha``.

    Solves the moment conditions

        sum_j w_j * alpha_j**k = m! * [k == m],   k = 0 .. p-1,  p = len(alpha),

    so that ``sum_j w_j * f(x + alpha_j * h) / h**m`` approximates ``f^(m)(x)``.

    Parameters
    ----------
    m : int
        Derivative order; must satisfy ``0 <= m < len(alpha)``.
    alpha : sequence of numbers
        Stencil offsets in units of the grid spacing; must be distinct.

    Returns
    -------
    numpy.ndarray, shape (p,), float
        Weights ``w_j`` in the same order as ``alpha``.

    Raises
    ------
    ValueError
        If ``m`` is outside ``[0, len(alpha))``.
    numpy.linalg.LinAlgError
        If the offsets are not distinct (singular system).
    """
    # Work in float: integer powers alpha**k overflow int64 silently for wide stencils.
    alpha = np.asarray(alpha, dtype=float)
    p = len(alpha)
    if not 0 <= m < p:
        raise ValueError(f"derivative order m={m} requires 0 <= m < len(alpha)={p}")
    # Row k holds alpha_j**k (row 0 is all ones).
    A = np.vander(alpha, p, increasing=True).T
    rhs = np.zeros(p)
    rhs[m] = factorial(m)
    # Solving directly is cheaper and numerically better than forming inv(A).
    return np.linalg.solve(A, rhs)

In [8]:
from dataclasses import dataclass

@dataclass
class Diff:
    """Finite-difference matrix for the m-th derivative on a uniform 1-D grid.

    Builds an ``(n, n)`` matrix ``M`` (already divided by ``h**m``) so that
    ``M @ y`` approximates the m-th derivative of ``n`` samples ``y`` spaced
    ``h`` apart.

    Each row uses a ``p``-point stencil from ``coefficients``:
      * interior rows use offsets ``-l .. p-1-l`` with ``l = p // 2``
        (symmetric for odd ``p``);
      * the first ``l`` and last rows use one-sided stencils over the first /
        last ``p`` grid points, so no row reaches outside the grid.

    Fields: ``m`` derivative order (must be < ``p``), ``n`` number of grid
    points (must satisfy ``n >= 2*p``), ``p`` stencil width, ``h`` grid spacing.
    """
    m: int
    n: int
    p: int
    h: float

    def __post_init__(self):
        assert self.n >= 2*self.p
        l = self.p // 2
        d = self.p % 2  # 1 for odd p: one fewer trailing boundary row is needed
        offsets = np.arange(self.p)
        # Row r: weights for the stencil evaluated at position r of a p-point window.
        C = np.stack([coefficients(self.m, offsets - r) for r in range(self.p)])
        interior = np.zeros((self.n - 2*l, self.n))
        for i in range(self.n - 2*l):
            interior[i, i:i+self.p] = C[l]
        # Boundary rows: one-sided stencils over the first / last p points.
        # Use self.n (not a notebook-global `n`) so the operator is self-contained.
        head = np.pad(C[:l], [(0, 0), (0, self.n - self.p)])
        tail = np.pad(C[l+d:], [(0, 0), (self.n - self.p, 0)])
        self.M = np.vstack([head, interior, tail])/(self.h**self.m)

    def __call__(self, y):
        """Apply the operator: return ``self.M.dot(y)`` (``y`` has leading length ``n``)."""
        return self.M.dot(y)

In [124]:
class Lattice:
    """Axis-aligned regular lattice built from one ``np.arange`` per axis.

    Each positional argument is a ``(start, stop, step)`` triple; axis ``i`` is
    ``np.arange(start_i, stop_i, step_i)`` (``stop`` excluded, subject to the
    usual floating-point ``arange`` caveats).
    """

    def __init__(self, *setup):
        self.ranges = [np.arange(xl, xr, dx) for xl, xr, dx in setup]

    @property
    def shape(self):
        """Number of points along each axis."""
        return tuple(map(len, self.ranges))

    @property
    def grid(self):
        """Coordinates of every point, shape ``self.shape + (ndim,)``.

        Matrix ('ij') indexing makes ``grid[i, j, ...]`` the point
        ``(ranges[0][i], ranges[1][j], ...)``, consistent with ``shape``,
        ``at`` and ``window`` (the default 'xy' swaps the first two axes).
        """
        return np.stack(np.meshgrid(*self.ranges, indexing='ij'), axis=-1)

    def loc(self, x, axis=0):
        """Index of the point on ``axis`` nearest to coordinate ``x`` (first on ties)."""
        return np.abs(self.ranges[axis] - x).argmin()

    def _map_axes(self, items, convert):
        """Apply ``convert(item, axis)`` to each item that addresses a lattice axis.

        ``Ellipsis`` and ``None`` pass through unchanged. Items before ``...``
        address axes counted from the front, items after it address axes counted
        from the back (numpy's indexing rule). ``None`` adds a new axis and so
        does not consume a lattice axis.
        """
        out = list(items)
        ell = next((i for i, v in enumerate(items) if v is Ellipsis), len(items))
        axis = 0
        for i in range(ell):
            if items[i] is not None:
                out[i] = convert(items[i], axis)
                axis += 1
        axis = len(self.ranges) - 1
        for i in range(len(items) - 1, ell, -1):
            if items[i] is not None and items[i] is not Ellipsis:
                out[i] = convert(items[i], axis)
                axis -= 1
        return tuple(out)

    def at(self, *locs):
        """Index tuple selecting the points nearest the given coordinates.

        Each coordinate becomes ``loc(x, axis)``; ``...`` and ``None`` are kept so
        the result can be used directly as ``array[lattice.at(...)]``.
        """
        return self._map_axes(locs, lambda x, axis: self.loc(x, axis=axis))

    def window(self, *lims):
        """Slice tuple for ``(lo, hi)`` coordinate pairs, one per addressed axis.

        Each pair becomes ``slice(loc(lo), loc(hi))``; the point nearest ``hi`` is
        excluded, as in a Python slice. ``...`` and ``None`` are kept.
        """
        return self._map_axes(
            lims,
            lambda lim, axis: slice(self.loc(lim[0], axis=axis), self.loc(lim[1], axis=axis)),
        )

In [115]:
# Cubic lattice: the same (start, stop, step) = (-1, 1, 0.1) along each of three axes.
lat = Lattice((-1, 1, 0.1), (-1, 1, 0.1), (-1, 1, 0.1))
lat.grid.shape

(20, 20, 20, 3)

In [47]:
# Seeded generator so this random test field is reproducible on Restart & Run All.
y = np.random.default_rng(0).random(lat.shape)

In [76]:
# Slice of `y` nearest coordinate 0.1 on the trailing axis (`...` keeps the leading axes).
# Indexing with the returned tuple directly avoids `y[*...]`, which needs Python 3.11+.
y[lat.at(..., 0.1)].shape

(20, 20)

In [123]:
# Sub-block of `y` between the points nearest the given bounds on the first and last axes
# (upper bound excluded). A plain tuple index avoids `y[*...]`, which needs Python 3.11+.
y[lat.window((-0.1, 0.15), ..., (-1, 0))].shape

(3, 20, 10)