In [1]:
import torch
from torch import nn

import dgl

import cloudpickle

import matplotlib.pyplot as plt
from ipywidgets import interact

from tqdm import trange
import csv

import numpy as np
import os

In [2]:
plt.rcParams['figure.figsize'] = [5,4]
plt.rcParams['font.size'] = 16
plt.rcParams['savefig.bbox'] = 'tight'
plt.rcParams['lines.markersize'] = 3
plt.rcParams['lines.linewidth'] = 3
plt.rcParams['svg.fonttype'] = 'none'
plt.rcParams["axes.formatter.use_mathtext"] = True
plt.rcParams['font.family'] = 'sans-serif'

In [3]:
d = '20240127231416'

learning_dir = 'spmNonParam2DFullODE_scale_'+d
simulation_dir = 'spmSimulate_'+d #'spmSimulate_20230713163922' # 

device = 'cuda:0'

In [4]:
torch.no_grad()

<torch.autograd.grad_mode.no_grad at 0x709fe60c6a90>

In [5]:
with open(os.path.join(simulation_dir, 'Spring_SDE_model.pt'), mode='rb') as f:
    simulation_model = cloudpickle.load(f)
print(simulation_model)

print(simulation_model.state_dict())

print('c : ', simulation_model.dynamicGNDEmodule.calc_module.sp.c())
print('r_c : ', simulation_model.dynamicGNDEmodule.calc_module.sp.r_c())

dynamicGSDEwrapper(
  (dynamicGNDEmodule): dynamicGNDEmodule(
    (calc_module): interactionModule(
      (sp): springPotential()
      (distanceCalc): euclidDistance_nonPeriodic()
    )
    (edgeRefresher): edgeRefresh(
      (edgeConditionModule): radiusgraphEdge(
        (scoreCalcModule): distanceSigmoid()
        (distanceCalc): euclidDistance_nonPeriodic()
        (distance2edge): distance2edge_batched()
      )
      (scorePostProcessModule): pAndLogit2KLdiv()
      (scoreIntegrationModule): scoreListModule()
    )
  )
  (ndataInOutModule): multiVariableNdataInOut()
  (derivativeInOutModule): multiVariableNdataInOut()
  (noiseInOutModule): singleVariableNdataInOut()
)
OrderedDict([('dynamicGNDEmodule.calc_module.gamma', tensor(0.1000)), ('dynamicGNDEmodule.calc_module.sigma', tensor(0.0010)), ('dynamicGNDEmodule.calc_module.sp.logc', tensor(-6.9078, dtype=torch.float64)), ('dynamicGNDEmodule.calc_module.sp.logr_c', tensor(0., dtype=torch.float64))])
c :  tensor(0.0010, dtype=tor

In [6]:
with open(os.path.join(learning_dir, 'Spring_nonParametric2Dfull_learned_model.pt'), mode='rb') as f:
    learned_model = cloudpickle.load(f)
print(learned_model)

print(learned_model.state_dict())

print('c : ', learned_model.dynamicGNDEmodule.calc_module.sp.c())
print('r_c : ', learned_model.dynamicGNDEmodule.calc_module.sp.r_c())

dynamicGSDEwrapper(
  (dynamicGNDEmodule): dynamicGNDEmodule(
    (calc_module): interactionModule_nonParametric_2Dfull(
      (sp): springPotential()
      (distanceCalc): euclidDistance_nonPeriodic()
      (fNN): Sequential(
        (Linear0): Linear(in_features=2, out_features=128, bias=True)
        (ELU0): ELU(alpha=1.0)
        (Linear1): Linear(in_features=128, out_features=128, bias=True)
        (ELU1): ELU(alpha=1.0)
        (Linear2): Linear(in_features=128, out_features=128, bias=True)
        (ELU2): ELU(alpha=1.0)
        (Linear3): Linear(in_features=128, out_features=2, bias=True)
        (Scaling): scalingLayer()
      )
      (f2NN): Sequential(
        (Linear0): Linear(in_features=2, out_features=128, bias=True)
        (ELU0): ELU(alpha=1.0)
        (Linear1): Linear(in_features=128, out_features=128, bias=True)
        (ELU1): ELU(alpha=1.0)
        (Linear2): Linear(in_features=128, out_features=128, bias=True)
        (ELU2): ELU(alpha=1.0)
        (Linear3): Li

In [7]:
L = 5

In [8]:
x = torch.load(os.path.join(simulation_dir, 'Spring_SDE_traj.pt'))
print(x.shape)
print(x)

torch.Size([51, 6, 100, 4])
tensor([[[[ 3.2671e+00,  3.2636e+00, -7.3349e-04, -6.0501e-04],
          [ 2.9669e+00,  3.0053e+00,  1.5974e-04, -9.0202e-05],
          [ 4.0723e+00,  1.7493e+00,  8.1706e-04,  2.3798e-04],
          ...,
          [ 1.2402e+00,  1.7038e+00,  6.2573e-04, -1.1100e-04],
          [ 9.2754e-01,  1.2214e+00,  9.7612e-04, -6.8387e-04],
          [ 3.3541e+00,  3.2866e+00,  6.9277e-04,  9.1953e-04]],

         [[ 7.6422e-01,  3.6269e+00,  4.1272e-04, -8.3670e-04],
          [ 3.1199e+00,  4.4714e+00,  5.2266e-04,  6.9256e-04],
          [ 2.3672e+00,  2.3538e+00,  4.2866e-04, -8.3589e-05],
          ...,
          [ 1.7215e+00,  3.7209e+00,  3.4044e-04,  9.0351e-04],
          [ 4.6922e+00,  2.5542e+00,  3.1684e-04, -5.1171e-04],
          [ 7.9484e-01,  4.4883e+00, -7.2740e-04, -1.5847e-04]],

         [[ 4.0412e+00,  2.3882e-01, -4.1157e-05,  1.3304e-05],
          [ 5.7814e-02,  4.1064e+00,  8.0486e-04, -4.1012e-04],
          [ 3.2180e+00,  4.2426e+00, -9.25

In [9]:
print(x.shape)

torch.Size([51, 6, 100, 4])


In [10]:
t = torch.load(os.path.join(simulation_dir, 'Spring_SDE_t_eval.pt'))
print(t)

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
        28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41.,
        42., 43., 44., 45., 46., 47., 48., 49., 50.])


In [11]:
t_min = t.min()
t_max = t.max()

In [12]:
N_batch = x.shape[1]
N_t = x.shape[0]

In [13]:
N_particles = x.shape[2]

In [14]:
@interact(n=(0, N_batch-1), t = (0, N_t-1))
def f(n, t):
    fig, ax = plt.subplots()
    ax.plot(x[t, n, :, 0], x[t, n, :, 1], 'o')
    ax.quiver(x[t, n, :, 0], x[t, n, :, 1], x[t, n, :, 2], x[t, n, :, 3])
    ax.grid(True)
    ax.set_aspect('equal')
    ax.set_xlim([0, L])
    ax.set_ylim([0, L])
    plt.show()

interactive(children=(IntSlider(value=2, description='n', max=5), IntSlider(value=25, description='t', max=50)…

In [14]:
def make_graph(t, ns):
    gs = [dgl.graph((torch.tensor([], dtype=int), torch.tensor([], dtype=int)), num_nodes=N_particles)] * len(ns)
    for n, g in zip(ns, gs):
        #g.add_nodes(N_particles)
        g.ndata['x'] = x[t, n, :, :2]
        g.ndata['v'] = x[t, n, :, 2:]
    gs = dgl.batch(gs)
    return gs

In [15]:
def pred_force_in_graph(g):
    learned_model.deleteGraph()
    learned_model.dynamicGNDEmodule.edgeRefresher.edgeConditionModule(g)
    g = learned_model.dynamicGNDEmodule.edgeRefresher.createEdge(g)
    g.apply_edges(learned_model.dynamicGNDEmodule.calc_module.calc_message)
    g.apply_nodes(lambda nodes: {'f2_pred': learned_model.dynamicGNDEmodule.calc_module.f2NN(nodes.data['v'])})
    return g.edata['m'].detach().reshape(-1, 2), g.ndata['f2_pred'].detach().reshape(-1, 2), torch.stack(g.edges(), dim=1).detach()

In [16]:
def true_force_in_graph(g):
    simulation_model.deleteGraph()
    simulation_model.dynamicGNDEmodule.edgeRefresher.edgeConditionModule(g)
    g = simulation_model.dynamicGNDEmodule.edgeRefresher.createEdge(g)
    g.apply_edges(simulation_model.dynamicGNDEmodule.calc_module.calc_message)
    g.apply_nodes(lambda nodes: {'f2_true': -simulation_model.dynamicGNDEmodule.calc_module.gamma * nodes.data['v']})
    return g.edata['m'].detach().reshape(-1, 2), g.ndata['f2_true'].detach().reshape(-1, 2), torch.stack(g.edges(), dim=1).detach()

In [17]:
true_force = []
pred_force = []
true_force2 = []
pred_force2 = []
true_edges = []
pred_edges = []
N_minibatch = 50

learned_model.to(device)
simulation_model.to(device)


for t in trange(N_t):
  for n in range(0, N_batch, N_minibatch):
    ns = np.arange(n, min(n+N_minibatch, N_batch))
    g = make_graph(t, ns).to(device)
    tf, tf2, tf_edges = true_force_in_graph(g)
    tf_nt = torch.zeros_like(tf_edges)
    tf_nt[..., 0] = t
    tf_nt[..., 1] = n
    tf_edges = torch.cat((tf_edges, tf_nt), dim=-1)
    true_force.append(tf.cpu())
    true_edges.append(tf_edges.cpu())
    true_force2.append(tf2.cpu())
    simulation_model.deleteGraph()
    #print(true_force)
    g = make_graph(t, ns).to(device)
    pf, pf2, pf_edges = pred_force_in_graph(g)
    pf_nt = torch.zeros_like(pf_edges)
    pf_nt[..., 0] = t
    pf_nt[..., 1] = n
    pf_edges = torch.cat((pf_edges, pf_nt), dim=-1)
    pred_force.append(pf.cpu())
    pred_edges.append(pf_edges.cpu())
    pred_force2.append(pf2.cpu())
    #print(pred_force)
    learned_model.deleteGraph()
pred_force = torch.concat(pred_force, dim=0)
true_force = torch.concat(true_force, dim=0)
pred_edges = torch.concat(pred_edges, dim=0)
true_edges = torch.concat(true_edges, dim=0)
pred_force2 = torch.concat(pred_force2, dim=0)
true_force2 = torch.concat(true_force2, dim=0)

100%|███████████████████████████████████████████| 51/51 [00:02<00:00, 19.20it/s]


In [18]:
if true_edges.shape[0] != pred_edges.shape[0]:
    flg_diffedge = True
else: 
    flg_diffedge = (torch.sort(true_edges, dim=0)[0] != torch.sort(pred_edges, dim=0)[0]).any()

print(flg_diffedge)

if flg_diffedge:
    true_pred_edges, i_true_pred, n_true_pred = torch.unique(torch.cat([true_edges, pred_edges], dim=0), sorted=False, return_inverse=True, return_counts=True, dim=0)
    edges_only_true = n_true_pred == 1 and i_true_pred < true_edges.shape[0]
    edges_only_pred = n_true_pred == 1 and i_true_pred >= true_edges.shape[0]
    edges_both = n_true_pred == 2
    true_force_all = torch.zeros([true_pred_edges.shape[0], 2], dtype=true_force.dtype, device=true_force.device)
    pred_force_all = torch.zeros([true_pred_edges.shape[0], 2], dtype=pred_force.dtype, device=pred_force.device)
    true_force_all[i_true_pred[edges_only_true]] = true_force
    pred_force_all[i_true_pred[edges_only_pred]] = pred_force
    true_force_all[i_true_pred[edges_both]] = true_force
    pred_force_all[i_true_pred[edges_both]] = pred_force
else:
    true_force_all = true_force
    pred_force_all = pred_force

tensor(False)


In [19]:
error_force = pred_force_all - true_force_all

In [20]:
error_force

tensor([[ 2.1684e-05,  9.2796e-05],
        [-5.4859e-05, -1.0713e-04],
        [ 8.9630e-05,  1.1568e-04],
        ...,
        [-2.3264e-05,  3.4932e-05],
        [ 4.2051e-05,  3.3540e-05],
        [ 6.9070e-05,  3.6964e-05]])

In [21]:
true_force

tensor([[ 3.3527e-03,  2.2349e-03],
        [-8.3817e-04,  4.0591e-03],
        [ 1.7345e-03,  3.6007e-04],
        ...,
        [-3.9703e-04,  1.0440e-03],
        [-4.8414e-04, -4.3469e-04],
        [ 3.3065e-04,  5.8676e-05]])

In [22]:
pred_force

tensor([[ 3.3743e-03,  2.3277e-03],
        [-8.9303e-04,  3.9519e-03],
        [ 1.8241e-03,  4.7575e-04],
        ...,
        [-4.2030e-04,  1.0790e-03],
        [-4.4209e-04, -4.0115e-04],
        [ 3.9972e-04,  9.5640e-05]])

In [23]:
ratio_force = pred_force / true_force

In [24]:
ratio_force

tensor([[1.0065, 1.0415],
        [1.0655, 0.9736],
        [1.0517, 1.3213],
        ...,
        [1.0586, 1.0335],
        [0.9131, 0.9228],
        [1.2089, 1.6300]])

In [25]:
norm_force = torch.norm(true_force_all, dim=-1, keepdim=True)

In [26]:
norm_force.min()

tensor(2.5034e-09)

In [27]:
norm_error = torch.norm(error_force, dim=-1, keepdim=True)

In [28]:
norm_error.max()

tensor(0.0021)

In [29]:
normalized_MSE = (norm_error**2).mean() / (norm_force**2).mean()
normalized_MSE

tensor(0.0233)

In [30]:
normalized_MAE = abs(error_force).mean() / abs(true_force).mean()
normalized_MAE

tensor(0.1129)

In [31]:
abs(error_force).max()

tensor(0.0021)

In [32]:
norm_force2 = torch.norm(true_force2, dim=-1, keepdim=True)

In [33]:
error_force2 = pred_force2 - true_force2

In [34]:
norm_error2 = torch.norm(error_force2, dim=-1, keepdim=True)

In [35]:
normalized_MSE2 = (norm_error2**2).mean() / (norm_force2**2).mean()
normalized_MSE2

tensor(0.0537)

In [36]:
normalized_MAE2 = abs(error_force2).mean() / abs(true_force2).mean()
normalized_MAE2

tensor(0.3771)

In [37]:
error_data = {'normalizedMSE_spring': normalized_MSE.item(),
              'normalizedMAE_spring': normalized_MAE.item(),
              'normalizedMSE_friction': normalized_MSE2.item(),
              'normalizedMAE_friction': normalized_MAE2.item()}

In [38]:
np.save(os.path.join(learning_dir, 'pred_error_in_dataset.npy'),
        error_data)

In [39]:
with open(os.path.join(learning_dir, 'pred_error_in_dataset.csv'),'w',encoding='utf-8') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames = error_data.keys())
    writer.writeheader()
    writer.writerows([error_data])