# Trying the model

## Update mechanism
The implementation uses semi-implicit Euler integration to update the next state based on the predicted accelerations:
$$\dot{\mathbf{p}}^{t+1}=\dot{\mathbf{p}}^t+\Delta t\cdot \ddot{\mathbf{p}}^t $$
$$\mathbf{p}^{t+1}=\mathbf{p}^t+\Delta t\cdot \dot{\mathbf{p}}^{t+1}$$
where $\Delta t=1$ for simplicity. 



In [2]:
%load_ext autoreload
%cd /workspace

/workspace


In [3]:
%%writefile open_gns/simulator.py

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import RadiusGraph
from open_gns.models import EncodeProcessDecode
from open_gns.normalizer import Normalizer

class Simulator():
    def __init__(self, *, positions, properties, velocities=None, device=None, R=0.08):
        # initialize the model
        self.R = R
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        checkpoint = torch.load('checkpoint_9_2.7949773842113843e-06.pt')
        input_size = 25
        model = EncodeProcessDecode(input_size).to(device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        self.model = model
        self.positions = positions.to(device)
        self.properties = properties.to(device)
        self.velocities = velocities if velocities is not None else torch.zeros((len(positions), 5*3))
        self.velocities = self.velocities.to(device)
        self.data = self.make_graph(positions, properties, self.velocities)
        self.norm = Normalizer(input_size, mask_cols=[3,4], device=device)

    
    def make_graph(self, positions, properties, velocities):
        d = torch.stack([
            positions[:,1],       # bottom
            positions[:,0],       # left
            positions[:,2],        # back
            1.2 - positions[:,0], # right
            0.4 - positions[:,2]   # front
        ], dim=1)
        d = torch.clamp(d, min=0, max=self.R)   
        x = torch.cat([positions, properties, velocities, d], 1)
        data = Data(x=x, pos=positions)
        find_edges = RadiusGraph(self.R)
        data = find_edges(data)
        return data
    
    def step(self, pos=None):
        # Predict accelerations
        data = self.data
        if pos is not None:
            data.x[:,:3] = pos
            data.pos = pos
        accelerations_ = self.model(self.norm(data.x), data.edge_index)
        velocities_ = data.x[:,17:20] + accelerations_ 
        positions_ = data.pos + velocities_
        print('p_t:', data.x[0], data.pos[0])
        print('a_t:', accelerations_[0])
        print('v_t:', data.x[0,17:20])
        print('v_t+1',velocities_[0])
        print('p_t+1', positions_[0])
        # Reconstruct data for next frame
        self.velocities = torch.cat([self.velocities[:,3:], velocities_], 1)
        self.data = self.make_graph(positions_, self.properties, self.velocities)
        return positions_, velocities_, accelerations_

Overwriting open_gns/simulator.py


In [4]:
# load
%autoreload 2
import torch
from open_gns.dataset import GNSDataset
from torch.nn import MSELoss
from open_gns.simulator import Simulator

mse = MSELoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load dataset
dataset = GNSDataset('./notebooks', split='test')
# Perform rollout using the simulator
rollout = dataset[0:143]
data = rollout[0]
data = data.to(device)
v = data.x[:,5:20]
sim = Simulator(positions=data.pos, velocities=v, properties=data.x[:,3:5], device=device)
assert torch.all(sim.data.edge_index == data.edge_index)
assert torch.all(sim.data.pos == data.pos)
positions = []
accelerations = []
velocities = []
acc_gt = []
pos_gt = []
vel_gt = []
for i, data in enumerate(rollout[1:]):
    pos_gt.append(data.pos)
    acc_gt.append(data.y)
    vel_gt.append(data.x[:,17:20])
    data = data.to(device)
    # Predict
    pos, vel, acc = sim.step()
    positions.append(pos.detach().cpu())
    accelerations.append(acc.detach().cpu())
    velocities.append(vel.detach().cpu())
    # Compare against dataset
    loss = mse(acc,data.y)
    print(loss.item())
    

p_t: tensor([6.0163e-01, 1.0000e-02, 1.2251e-01, 0.0000e+00, 1.0000e+00, 0.0000e+00,
        5.0000e-04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e-02, 8.0000e-02, 8.0000e-02, 8.0000e-02,
        8.0000e-02], device='cuda:0') tensor([0.6016, 0.0100, 0.1225], device='cuda:0')
a_t: tensor([ 2.7834e-06,  6.7952e-05, -2.9317e-06], device='cuda:0',
       grad_fn=<SelectBackward>)
v_t: tensor([0., 0., 0.], device='cuda:0')
v_t+1 tensor([ 2.7834e-06,  6.7952e-05, -2.9317e-06], device='cuda:0',
       grad_fn=<SelectBackward>)
p_t+1 tensor([0.6016, 0.0101, 0.1225], device='cuda:0', grad_fn=<SelectBackward>)
5.0308663048781455e-06
p_t: tensor([ 6.0163e-01,  1.0068e-02,  1.2250e-01,  0.0000e+00,  1.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
        

In [None]:
for p in sim.model.parameters():
    print(p.data.size())
print(sim.model)

In [5]:
%autoreload 2
from open_gns.utils import animate_rollout, animate_rollout_quiver

animate_rollout_quiver(positions, velocities)
animate_rollout_quiver(pos_gt, vel_gt)

VBox(children=(Figure(animation=200.0, camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion…

VBox(children=(Figure(animation=200.0, camera=PerspectiveCamera(fov=46.0, position=(0.0, 0.0, 2.0), quaternion…

In [None]:
%autoreload 2
import torch
print(torch)
print(torch.__version__)
print(data.edge_index.size())
for i in range(data.edge_index.size(1)):
    a,b = data.edge_index[:,i]
    print(f'({a})[{data.pos[a]}] -> ({b})[{data.pos[b]}]')
    print(torch.sum(torch.pow(data.pos[a] - data.pos[b], 2))/3.0)