# Trying the model

## Update mechanism
The simulator advances the state with semi-implicit (symplectic) Euler integration: the predicted acceleration first updates the velocity, and the *updated* velocity is then used to update the position:
$$\dot{\mathbf{p}}^{t+1}=\dot{\mathbf{p}}^t+\Delta t\cdot \ddot{\mathbf{p}}^t $$
$$\mathbf{p}^{t+1}=\mathbf{p}^t+\Delta t\cdot \dot{\mathbf{p}}^{t+1}$$
where $\Delta t=1$ for simplicity, so each update reduces to a plain addition.



In [None]:
# Load IPython's autoreload extension; later cells call `%autoreload 2` so
# edits to the open_gns package are picked up without restarting the kernel.
%load_ext autoreload
# Work from the repository root so the relative paths used below
# (open_gns/simulator.py, ./notebooks, the checkpoint file) resolve.
%cd /workspace

In [None]:
%%writefile open_gns/simulator.py

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import RadiusGraph
from open_gns.models import EncodeProcessDecode

class Simulator():
    """Roll out a trained ``EncodeProcessDecode`` model as a particle simulator.

    Each node's feature vector is ``[position, properties, velocity history]``
    (concatenated in that order by ``make_graph``). The model predicts one
    acceleration per node, and ``step`` integrates it with semi-implicit Euler
    and ``dt = 1``.

    Keyword Args:
        positions: per-node positions; they occupy the first 3 feature columns.
        properties: per-node static properties, placed after the positions.
        velocities: velocity history whose last 3 columns are the most recent
            velocity. Defaults to zeros of shape ``(N, 15)``.
            NOTE(review): presumably 5 frames x 3 components — confirm.
        device: target device; defaults to CUDA when available, else CPU.
        R: connectivity radius passed to ``RadiusGraph``.
        checkpoint_path: checkpoint whose ``'model_state_dict'`` is loaded.
    """

    def __init__(self, *, positions, properties, velocities=None, device=None, R=0.08,
                 checkpoint_path='checkpoint_5_2.9207797146117526e-06.pt'):
        self.R = R
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        input_size = 20  # node feature width the checkpoint was trained with
        model = EncodeProcessDecode(input_size).to(self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        self.model = model
        self.positions = positions.to(self.device)
        self.properties = properties.to(self.device)
        if velocities is None:
            velocities = torch.zeros((len(positions), 5 * 3), device=self.device)
        self.velocities = velocities.to(self.device)
        # Build from the device-resident tensors so torch.cat never mixes devices.
        self.data = self.make_graph(self.positions, self.properties, self.velocities)

    def make_graph(self, positions, properties, velocities):
        """Return a ``Data`` graph with ``x = [positions, properties, velocities]``.

        Edges come from ``RadiusGraph(self.R)`` applied to ``positions``.
        """
        x = torch.cat([positions, properties, velocities], 1)
        data = Data(x=x, pos=positions)
        find_edges = RadiusGraph(self.R)
        return find_edges(data)

    def step(self, pos=None):
        """Advance the simulation by one frame.

        Args:
            pos: optional positions that replace the current ones before the
                prediction (e.g. to teacher-force with ground truth). The
                radius graph is rebuilt around them.

        Returns:
            ``(positions, velocities, accelerations)`` for the new frame.
        """
        if pos is None:
            data = self.data
        else:
            # Rebuild rather than patching x in place: data.pos and the edges
            # must follow the overridden positions too.
            data = self.make_graph(pos.to(self.device), self.properties, self.velocities)
        # Inference only. Without no_grad, each step's outputs feed the next
        # step's inputs, so the autograd graph would grow with the rollout.
        with torch.no_grad():
            accelerations_ = self.model(data.x, data.edge_index)
        # Semi-implicit Euler with dt = 1: velocity first, then position.
        velocities_ = data.x[:, -3:] + accelerations_
        positions_ = data.pos + velocities_
        # Drop the oldest velocity frame (3 columns) and append the newest.
        self.velocities = torch.cat([self.velocities[:, 3:], velocities_], 1)
        self.positions = positions_
        self.data = self.make_graph(positions_, self.properties, self.velocities)
        return positions_, velocities_, accelerations_

In [None]:
# load
%autoreload 2
import torch
from open_gns.dataset import GNSDataset
from torch.nn import MSELoss
from open_gns.simulator import Simulator

mse = MSELoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load dataset
dataset = GNSDataset('./notebooks', split='test')
# Perform rollout using the simulator
rollout = dataset[0:143]
data = rollout[0]
data = data.to(device)
v = data.x[:,-15:]
sim = Simulator(positions=data.pos, velocities=v, properties=data.x[:,3:5], device=device)
assert torch.all(sim.data.x == data.x)
assert torch.all(sim.data.edge_index == data.edge_index)
assert torch.all(sim.data.pos == data.pos)
positions = []
accelerations = []
acc_gt = []
pos_gt = []
for i, data in enumerate(rollout[1:]):
    pos_gt.append(data.pos)
    acc_gt.append(data.y)
    data = data.to(device)
    # Predict
    pos, vel, acc = sim.step(data.pos)
    positions.append(pos.detach().cpu())
    accelerations.append(acc.detach().cpu())
    # Compare against dataset
    loss = mse(acc,data.y)
    print(loss.item())
    

In [None]:
# Inspect the loaded network: the shape of every parameter tensor, then its module tree.
for weight in sim.model.parameters():
    print(weight.shape)
print(sim.model)

In [None]:
# Animate predicted vs. ground-truth rollouts as quiver plots.
%autoreload 2
from open_gns.utils import animate_rollout, animate_rollout_quiver

# Predicted positions with the model's accelerations.
animate_rollout_quiver(positions, accelerations)
# Ground-truth positions with the dataset's `y` values.
# NOTE(review): Jupyter auto-displays only a cell's last expression. If
# animate_rollout_quiver returns an animation object instead of showing it,
# the first animation is not displayed — confirm.
animate_rollout_quiver(pos_gt, acc_gt)

In [None]:
%autoreload 2
import torch
print(torch)
print(torch.__version__)
print(data.edge_index.size())
for i in range(data.edge_index.size(1)):
    a,b = data.edge_index[:,i]
    print(f'({a})[{data.pos[a]}] -> ({b})[{data.pos[b]}]')
    print(torch.sum(torch.pow(data.pos[a] - data.pos[b], 2))/3.0)