In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from model import PhiNN
from helpers import jump_function


In [None]:
# Checkpoint to load: a state_dict written by the training run.
# The path is relative to the notebook's working directory.
MODEL_DIR = "out/model_training"
MODEL_RUN = "model3113898_20231018_190742_49"
model_fpath = f"{MODEL_DIR}/{MODEL_RUN}"

In [None]:
# PhiNN hyperparameters. These presumably need to match the configuration of the
# training run that produced the checkpoint — confirm against the training script.
NCELLS = 100   # passed to PhiNN as `ncells`
SIGMA = 1e-3   # passed to PhiNN as `sigma` (presumably a noise scale — confirm)


def f_signal(t, p):
    """Signal function handed to PhiNN.

    Splits the parameter vector ``p`` along its last axis into ``p[..., 0]``,
    ``p[..., 1:3]`` and ``p[..., 3:]`` and forwards them, together with ``t``,
    to ``helpers.jump_function``. Presumably these are the jump time and the
    pre-/post-jump signal values — confirm against ``helpers.jump_function``.
    """
    return jump_function(t, p[..., 0], p[..., 1:3], p[..., 3:])


model = PhiNN(
    ndim=2, nsig=2, f_signal=f_signal,
    ncells=NCELLS,
    sigma=SIGMA,
)

# map_location="cpu": a checkpoint saved during a GPU run still loads on a
# CPU-only machine. weights_only=True: only tensors/plain containers are
# unpickled, which is all a state_dict needs (requires torch >= 1.13).
state_dict = torch.load(model_fpath, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval();  # trailing ';' suppresses the module repr in the cell output

In [None]:
# State space: a regular GRID_N x GRID_N grid over [-GRID_LIM, GRID_LIM]^2,
# flattened to one (x, y) point per row so it can be fed to the model.
GRID_LIM = 10
GRID_N = 100

x = np.linspace(-GRID_LIM, GRID_LIM, GRID_N)
y = np.linspace(-GRID_LIM, GRID_LIM, GRID_N)
xs, ys = np.meshgrid(x, y)

# Rows follow the C-order ravel of xs/ys (x varies fastest), so any per-point
# model output can be reshaped back to xs.shape for plotting.
z_np = np.column_stack([xs.ravel(), ys.ravel()])

# requires_grad=True keeps z a differentiable leaf tensor (kept from the
# original; presumably used for gradient-based quantities downstream — confirm).
z = torch.tensor(z_np, dtype=torch.float32, requires_grad=True)

In [None]:
# Signal parameters consumed by f_signal, which slices them as
# [p0 | p1, p2 | p3, p4] -> jump_function(t, p0, (p1, p2), (p3, p4)).
# Presumably: jump at t=5, signal (0, 0) before and (1, 0) after — confirm
# against helpers.jump_function.
signal_params = torch.tensor([5, 0, 0, 1, 0], dtype=torch.float32)

In [None]:
# Evaluate the model's vector field on every grid point for the chosen signal
# (the leading 0 is presumably the time t — confirm against PhiNN.f).
f = model.f(0, z, signal_params)

# .cpu() keeps the NumPy conversion working even if model/z are moved to a GPU.
f_arr = f.detach().cpu().numpy()

# Column 0 / column 1 of f, one value per grid point (same row order as z).
u, v = f_arr.T

# Magnitude |f|; np.hypot avoids overflow/underflow in the intermediate squares.
norms = np.hypot(u, v)

In [None]:
# Quiver plot of f over the state-space grid. Drawing every one of the
# xs.size grid points makes the plot unreadable, so only every
# QUIVER_STRIDE-th point along each axis is shown; arrow colour encodes |f|.
QUIVER_STRIDE = 5

# f_arr rows follow the C-order ravel of xs/ys, so reshaping to xs.shape
# puts each vector back at its grid location.
u_grid = f_arr[:, 0].reshape(xs.shape)
v_grid = f_arr[:, 1].reshape(xs.shape)
speed_grid = np.hypot(u_grid, v_grid)
step = slice(None, None, QUIVER_STRIDE)

fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8))
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
ax.set_title("$f(x,y)$")
quiv = ax.quiver(
    xs[step, step], ys[step, step],
    u_grid[step, step], v_grid[step, step],
    speed_grid[step, step],
)
fig.colorbar(quiv, ax=ax, label="$|f(x,y)|$")
plt.show()