In [1]:
from pinnsform import *
from torchviz import make_dot
from dataclasses import dataclass

@dataclass
class Mesh:
    """A set of collocation points plus the per-coordinate tensors behind it.

    ``full`` is the point tensor handed to the model; ``part`` is the list of
    tensors that ``full`` was concatenated from in ``meshify``. ``df`` indexes
    ``part`` to pick the coordinate it differentiates with respect to.
    """

    # Concatenated point tensor that the network is evaluated on.
    full: torch.Tensor
    # Source tensors of ``full``, one entry per split produced by ``meshify``.
    part: list

def meshify(l, requires_grad = False):
    """Build the Cartesian-product mesh of per-dimension coordinate lists.

    Args:
        l: sequence of 1-D coordinate sequences, one per dimension.
        requires_grad: whether each per-coordinate column tensor in ``part`` is
            created with ``requires_grad`` (needed so ``torch.autograd.grad``
            can differentiate with respect to a single coordinate).

    Returns:
        Mesh whose ``part`` holds one float32 column tensor of shape (P, 1) per
        dimension (P = product of the input lengths), and whose ``full`` is the
        column-wise concatenation of those tensors, shape (P, len(l)).

    Raises:
        ValueError: if ``l`` is empty or any coordinate list in it is empty.
    """
    points = np.array(list(product(*l)), dtype=np.float64)
    if points.ndim != 2 or points.shape[1] == 0 or points.shape[0] == 0:
        raise ValueError("meshify needs at least one non-empty coordinate list")
    # One leaf tensor per coordinate (previously hard-coded to 2 splits, which
    # broke 1-D and >2-D meshes). ``full`` is built from these leaves, so
    # gradients of model(full) can be taken w.r.t. each coordinate separately.
    parts = [
        torch.tensor(points[:, dim:dim + 1], dtype=torch.float32, requires_grad=requires_grad)
        for dim in range(points.shape[1])
    ]
    full = torch.cat(parts, dim=1)
    return Mesh(full, parts)

def generate_mesh_object(point_counts, domain, requires_grad = True, skew = None):
    """Create a grid mesh on a box domain plus its per-dimension boundary meshes.

    Args:
        point_counts: number of grid points per dimension. A bare integer
            (Python or NumPy) is treated as a 1-D problem, in which case
            ``domain`` is a single ``(low, high)`` interval.
        domain: sequence of ``(low, high)`` bounds, one per dimension.
        requires_grad: forwarded to ``meshify`` for every mesh created.
        skew: optional per-dimension exponent applied to the unit grid
            ``linspace(0, 1, count)`` before it is mapped onto ``domain``. A
            single number is applied to every dimension; a falsy value
            disables skewing.

    Returns:
        ``(full_mesh, border_meshes)``, where ``full_mesh`` covers every grid
        point and ``border_meshes[i]`` is a ``(low_mesh, high_mesh)`` pair with
        dimension ``i`` pinned to ``domain[i][0]`` and ``domain[i][1]``.

    Raises:
        ValueError: if ``point_counts`` and ``domain`` differ in length.
    """
    if isinstance(point_counts, (int, np.integer)):
        point_counts = [point_counts]
        domain = [domain]
    if len(point_counts) != len(domain):
        raise ValueError(
            f"point_counts has {len(point_counts)} entries but domain has {len(domain)}"
        )

    unit_grids = [np.linspace(0, 1, count) for count in point_counts]
    if skew:
        if np.isscalar(skew):
            # Broadcast a single exponent to every dimension.
            skew = [skew] * len(unit_grids)
        unit_grids = [points ** skew[i] for i, points in enumerate(unit_grids)]
    full_list = [
        domain[i][0] + (domain[i][1] - domain[i][0]) * points
        for i, points in enumerate(unit_grids)
    ]

    full_mesh = meshify(full_list, requires_grad)

    border_meshes = []
    for i in range(len(full_list)):
        # Shallow copies: only slot i is replaced by the pinned boundary value.
        low_axes = list(full_list)
        high_axes = list(full_list)
        low_axes[i] = [float(domain[i][0])]
        high_axes[i] = [float(domain[i][1])]
        border_meshes.append((meshify(low_axes, requires_grad), meshify(high_axes, requires_grad)))

    return full_mesh, border_meshes

def f(model, mesh, of = None):
    """Evaluate ``model`` on a mesh (or on a raw point tensor).

    Args:
        model: callable applied to the point tensor.
        mesh: a ``Mesh`` (its ``full`` tensor is used) or a tensor of points
            that is passed to ``model`` unchanged. Previously any non-``Mesh``
            input raised UnboundLocalError.
        of: optional output-column index. When given, the model output is split
            into width-1 columns along dim 1 and that column is returned
            (negative indices count from the end).

    Returns:
        The full model output, or the selected single column.
    """
    points = mesh.full if isinstance(mesh, Mesh) else mesh
    output = model(points)
    if of is None:
        return output
    return torch.split(output, 1, 1)[of]

def df(model, mesh, of = 0, wrt = 0, order = 1):
    """Repeated derivative of a model output w.r.t. one mesh coordinate.

    Args:
        model: network evaluated through ``f``.
        mesh: ``Mesh`` whose ``part`` tensors require grad (e.g. built with
            ``requires_grad=True``); ``torch.autograd.grad`` fails otherwise.
        of: output-column index forwarded to ``f`` (``None`` keeps all columns).
        wrt: index into ``mesh.part`` of the coordinate to differentiate against.
        order: number of successive differentiations (0 returns ``f`` itself).

    Returns:
        The ``order``-th derivative. Each step back-propagates a ones tensor
        shaped like the current output. This yields per-point derivatives only
        when every output row depends solely on its own input row, as it does
        for a row-wise MLP such as ``BASE_MODEL``.
    """
    derivative = f(model, mesh, of)
    respect_to = mesh.part[wrt]
    for _ in range(order):
        derivative = torch.autograd.grad(
            derivative,
            respect_to,
            # Must match the shape of the tensor being differentiated, not the
            # input; the old ones_like(respect_to) only worked for (P, 1) outputs.
            grad_outputs=torch.ones_like(derivative),
            # Keep the graph so the result can be differentiated again
            # (higher orders) and back-propagated into the model weights.
            create_graph=True,
            retain_graph=True,
        )[0]
    return derivative

In [9]:
class BASE_MODEL(nn.Module):
    """Small fully connected network: Linear -> Tanh -> Linear.

    The layer sizes are constructor parameters. Their defaults reproduce the
    original 2 -> 10 -> 1 architecture, so existing code and saved state dicts
    (keys ``layers.0.*`` / ``layers.2.*``) keep working.
    """

    def __init__(self, in_features=2, hidden_features=10, out_features=1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.Tanh(),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, x):
        """Apply the network to ``x``, whose last dimension is ``in_features``."""
        return self.layers(x)

In [16]:
# 5 x 5 grid on [0, pi] x [0, 2]. requires_grad=True (and no torch.no_grad())
# is required: the residual cell differentiates the model output with respect
# to the mesh coordinates.
mesh, boundaries = generate_mesh_object((5, 5), ((0, np.pi), (0, 2)), requires_grad=True)

In [20]:
# Lower boundary mesh of dimension 1: every point has its second coordinate
# pinned to domain[1][0] (0 here), as the displayed output shows.
boundaries[1][0].full

tensor([[0.0000, 0.0000],
        [0.7854, 0.0000],
        [1.5708, 0.0000],
        [2.3562, 0.0000],
        [3.1416, 0.0000]], grad_fn=<CatBackward0>)

In [11]:
# Seed before constructing the network so its random initial weights, and
# everything computed from them, are reproducible on Restart & Run All.
torch.manual_seed(0)
model = BASE_MODEL()

In [13]:
# Residual r = du/dx_1 - 0.5 * d^2u/dx_0^2 on the interior mesh
# (presumably a heat/diffusion equation with coordinate 1 as time - confirm).
pde_residue = df(model, mesh, wrt=1, order=1) - 0.5 * df(model, mesh, wrt=0, order=2)
print(pde_residue)
# Passing the named parameters labels the weight/bias nodes in the graph.
graph = make_dot(pde_residue, params=dict(model.named_parameters()))
graph.save("computation_graph_residue.dot")

tensor([[0.0210],
        [0.0220],
        [0.0232],
        [0.0257],
        [0.0301],
        [0.0524],
        [0.0480],
        [0.0411],
        [0.0342],
        [0.0293],
        [0.0796],
        [0.0704],
        [0.0570],
        [0.0423],
        [0.0297],
        [0.0998],
        [0.0873],
        [0.0696],
        [0.0498],
        [0.0316],
        [0.1115],
        [0.0974],
        [0.0779],
        [0.0558],
        [0.0345]], grad_fn=<SubBackward0>)


'computation_graph_residue.dot'