In [1]:
import torch
from eqn_sol import *
import scipy.sparse
from scipy.sparse.linalg import expm
from torch_geometric.utils import from_scipy_sparse_matrix
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## test PropogateSignal
Note: The result is A.T @ x, not A @ x. (doesn't matter if we use symmetric laplacian)

In [2]:
x = torch.rand(5, 3, 4)
A = scipy.sparse.random(5, 5, density=0.1, format='csc')

y1 = torch.tensor((A.T @ x.flatten(start_dim=1, end_dim=-1))).view(*x.size())

propogator = PropogateSignal()
edge_index, edge_weight = from_scipy_sparse_matrix(A)
y2 = propogator(x, edge_index, edge_weight)

y1 == y2

tensor([[[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]],

        [[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]],

        [[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]],

        [[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]],

        [[True, True, True, True],
         [True, True, True, True],
         [True, True, True, True]]])

In [3]:
B = torch.tensor(A.toarray(), dtype=torch.float)
edge_index, edge_weight = from_tensor(B)
y3 = propogator(x, edge_index, edge_weight)

torch.abs(y1 - y3).max().item()

1.9418180730035317e-08

## test HeatEqnSol

In [4]:
x = torch.rand(10, 20, 4, dtype=torch.float).to(device)
A = scipy.sparse.random(10, 10, density=0.1, format='csc', dtype=np.float32)

B = expm(- 3 * A)
y1 = (torch.tensor(B.toarray(), dtype=torch.float).to(device).T @ x.flatten(start_dim=1, end_dim=-1)).view(*x.size())

sol = HeatEqnSolSparseExpm(A, device).to(device)
y2 = sol(x, 3)

(y1 - y2).abs().max().item()

9.5367431640625e-07

## test WaveEqnSol

In [5]:
x = torch.rand(10, 20, 4, dtype=torch.float).to(device)
A = scipy.sparse.random(10, 10, density=0.1, format='csc', dtype=np.float32)
A = A + A.T
A = A.toarray()
A = A + np.diag(np.ones(A.shape[0]))
d = A.sum(axis=0)
D = np.diag(d)
Dinvhf = np.diag(1/np.sqrt(d))
laplacian = Dinvhf @ (D - A) @ Dinvhf
# Compute all eigenvalues and eigenvectors using eigsh
laplacian = torch.tensor(laplacian, dtype=torch.float)

sol = WaveEqnSol(laplacian, device, eps=1e-5).to(device)


In [6]:
torch.abs(sol.V @ torch.diag(sol.sqrtlam ** 2) @ sol.V.T - laplacian.to(device)).max().item()

1.7472812885444e-07

In [7]:
y2 = sol(x, x, 4.)