In [1]:
import math
import torch_scatter
import torch
import torch_geometric as pyg
import torch.nn.functional as F
import os
import numpy as np
import json
from neuralop.models import FNO
from torch import nn
from models.neuraloperator.neuralop.layers.mlp import MLP as NeuralOpMLP
from models.neuraloperator.neuralop.layers.embeddings import PositionalEmbedding
from models.neuraloperator.neuralop.layers.integral_transform import IntegralTransform
from models.neuraloperator.neuralop.layers.neighbor_search import NeighborSearch
import random
from random import randint
from dataloader import preprocess
#from models.giorom2d import PhysicsEngine
from models.giorom3d_large import PhysicsEngine
import yaml
#from Baselines.GAT import PhysicsEngine

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Training hyper-parameters; in this notebook only "noise" is read (by the rollout cell).
params = {
    "epoch": 1000,
    "batch_size": 4,
    "lr": 1e-4,
    "noise": 0.0003,
    "save_interval": 1000,
    "eval_interval": 10000,
    "rollout_interval": 20000,
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_config_path = 'configs/giorom3d_large.yaml'
ckpt_path = '/home/csuser/Documents/Neural Operator/saved_models/giorom3d_large_nclaw_Sand.pt'

with open(model_config_path, 'r') as f:
    model_config = yaml.full_load(f)
simulator = PhysicsEngine(device, **model_config)

# Read the checkpoint once; map_location lets a GPU-saved file load on a CPU-only box.
ckpt = torch.load(ckpt_path, map_location=device)
weights = ckpt['model']

# Copy only checkpoint tensors whose key and shape both match the model;
# every other model parameter keeps its freshly initialised value.
model_dict = dict(simulator.state_dict())
ckpt_dict = {}
mismatch = False
for k, v in weights.items():
    if k not in model_dict:
        print("Checkpoint key not in model! Skipping %s (size %s)..." % (k, str(v.size())))
        mismatch = True
    elif model_dict[k].size() != v.size():
        print("Size mismatch while loading! %s != %s Skipping %s..." % (str(model_dict[k].size()), str(v.size()), k))
        mismatch = True
    else:
        ckpt_dict[k] = v
if len(model_dict) > len(ckpt_dict):
    # Some model parameters were not restored from the checkpoint.
    mismatch = True
model_dict.update(ckpt_dict)
simulator.load_state_dict(model_dict)

# Honour the selected device instead of unconditionally calling .cuda().
simulator = simulator.to(device)
total_params = sum(p.numel() for p in simulator.parameters())
print(f"Number of parameters: {total_params}")

Number of parameters: 2006574


In [3]:
def rollout(model, data, metadata, noise_std):
    """Autoregressively roll ``model`` forward over one trajectory.

    The first ``model.window_size + 1`` frames of ``data["position"]`` seed the
    history. Each later frame is predicted from the most recent window: the
    model's output is de-normalised into an acceleration and integrated as
    ``v_new = (x_t - x_{t-1}) + a``, ``x_new = x_t + v_new``, until the rollout
    is as long as ``data["position"]``.

    Args:
        model: module exposing ``window_size``; called on the graph built by
            ``preprocess`` and expected to return a normalised acceleration.
        data: dict with ``"position"`` (indexed by time along dim 0) and
            ``"particle_type"``.
        metadata: dataset metadata; ``"acc_mean"`` / ``"acc_std"`` de-normalise
            the model output.
        noise_std: training-noise std folded into the de-normalisation scale.

    Returns:
        Positions with particles on dim 0 and time on dim 1 (the caller
        permutes back to time-major).
    """
    device = next(model.parameters()).device
    model.eval()
    window_size = model.window_size + 1
    total_time = data["position"].size(0)
    particle_type = data["particle_type"]

    seed = data["position"][:window_size]
    if total_time <= window_size:
        # Nothing to predict: return the seed frames, particle-major.
        return seed.permute(1, 0, 2)

    # Loop-invariant de-normalisation constants, built once instead of per step.
    acc_scale = torch.sqrt(torch.tensor(metadata["acc_std"]) ** 2 + noise_std ** 2)
    acc_mean = torch.tensor(metadata["acc_mean"])

    # History kept as a list of per-frame tensors: appending is O(1), whereas
    # concatenating onto the full trajectory every step copied it each time.
    frames = list(seed.unbind(0))

    with torch.no_grad():
        for _ in range(total_time - window_size):
            window = torch.stack(frames[-window_size:], dim=1)
            graph = preprocess(particle_type, window, None, metadata, 0.0)
            graph = graph.to(device)
            # Bring the prediction back to the CPU, where the positions live.
            acceleration = model(graph).cpu()
            acceleration = acceleration * acc_scale + acc_mean

            recent_position = frames[-1]
            recent_velocity = recent_position - frames[-2]
            new_velocity = recent_velocity + acceleration
            frames.append(recent_position + new_velocity)

    return torch.stack(frames, dim=1)

In [4]:
from collections import OrderedDict
from dgl.geometry import farthest_point_sampler
from scipy.spatial import Delaunay
class RolloutDataset(pyg.data.Dataset):
    """Full trajectories for autoregressive rollout evaluation.

    ``data_path`` is read with ``torch.load`` and must provide the keys
    ``'particle_type'``, ``'position'``, ``'n_particles_per_example'`` and
    ``'output'``. Each is indexed per trajectory, and ``get`` converts entries
    with ``torch.from_numpy``.

    Args:
        data_path: path to the rollout ``.pt`` file.
        split: unused; kept so existing call sites keep working.
        window_length: stored as ``self.window_length``; not used by ``get``.
        random_sample: if True, ``get`` keeps only ``mesh_size`` particles.
        sampling_strategy: ``'random'`` (uniform subset) or ``'fps'`` (farthest
            point sampling on the first frame). Only used when
            ``random_sample`` is True.
        mesh_size: number of particles kept when sub-sampling.
        metadata_dir: directory containing ``metadata.json``. Defaults to the
            directory of ``data_path``.

    Raises:
        ValueError: if ``sampling_strategy`` is not ``'random'`` or ``'fps'``.
    """

    def __init__(self, data_path, split, window_length=7, random_sample=False,
                 sampling_strategy='random', mesh_size=200, metadata_dir=None):
        super().__init__()
        if sampling_strategy not in ('random', 'fps'):
            raise ValueError(
                f"Unknown sampling_strategy {sampling_strategy!r}; expected 'random' or 'fps'"
            )

        self.data_path = data_path
        if metadata_dir is None:
            # metadata.json is expected next to the rollout file unless overridden.
            metadata_dir = os.path.dirname(data_path)
        with open(os.path.join(metadata_dir, "metadata.json")) as f:
            self.metadata = json.load(f)
        self.window_length = window_length
        self.random_sample = random_sample
        # Read by get() whenever random_sample is True.
        self.sampling_strategy = sampling_strategy
        self.mesh_size = mesh_size

        # NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
        # files. On PyTorch versions where weights_only defaults to True, a file
        # holding numpy arrays may need weights_only=False; confirm for your install.
        dataset = torch.load(data_path)
        self.particle_type = dataset['particle_type']
        self.position = dataset['position']
        self.n_particles_per_example = dataset['n_particles_per_example']
        self.outputs = dataset['output']
        # Indices chosen by the most recent 'random' sub-sample in get().
        self.points = None
        # Size of the last axis of a stored trajectory.
        self.dim = self.position[0].shape[2]

    def len(self):
        """Number of trajectories in the file."""
        return len(self.position)

    def get(self, idx):
        """Return trajectory ``idx`` as ``{"particle_type": ..., "position": ...}``.

        The stored position array is permuted on dims ``(1, 0, 2)``, so the
        returned ``position`` is indexed ``[step, particle, coord]`` (matching
        how the rollout cell compares ``shape[1]`` with the particle count).
        When ``random_sample`` is True, the same particle subset is applied to
        both entries.
        """
        particle_type = torch.from_numpy(self.particle_type[idx])
        position_seq = torch.permute(torch.from_numpy(self.position[idx]), dims=(1, 0, 2))

        if self.random_sample:
            if self.sampling_strategy == 'random':
                self.points = sorted(random.sample(range(0, particle_type.shape[0]), self.mesh_size))
                point_idx = self.points
            elif self.sampling_strategy == 'fps':
                # Sample over the particles of the first frame: input (1, N, coord),
                # and [0] takes the single batch's index vector.
                init_pos = position_seq[0].unsqueeze(0)
                point_idx = farthest_point_sampler(init_pos, self.mesh_size)[0]
            else:
                raise ValueError(
                    f"Unknown sampling_strategy {self.sampling_strategy!r}; expected 'random' or 'fps'"
                )
            particle_type = particle_type[point_idx]
            # Sub-sample the particle axis (dim 1) for every time step.
            position_seq = position_seq[:, point_idx]

        return {"particle_type": particle_type, "position": position_seq}

rollout_data_path = '/home/csuser/Documents/new_dataset/nclaw_Sand/rollout.pt'
# With random_sample=False, get() is deterministic, so a single load serves
# both the model input and the ground truth (the file was previously loaded twice).
rollout_dataset = RolloutDataset(rollout_data_path, "train", random_sample=False)
rollout_dataset_gt = rollout_dataset
simulator.eval()
sim_id = 1
rollout_data = rollout_dataset[sim_id]
n_particles = rollout_data['particle_type'].shape[0]
if rollout_data['position'].shape[1] != n_particles:
    # Drop trailing particles that have no particle_type entry.
    rollout_data['position'] = rollout_data['position'][:, :n_particles]
print(rollout_data['position'].shape)
print(rollout_data['particle_type'].shape)
# A fresh get() call, so the ground truth keeps every particle (not truncated).
rollout_data_gt = rollout_dataset_gt[sim_id]

rollout_out = rollout(simulator, rollout_data, rollout_dataset.metadata, params["noise"])
rollout_out = rollout_out.permute(1, 0, 2)
# Squared displacement summed over coordinates, averaged over steps and particles.
loss = ((rollout_out - rollout_data["position"]) ** 2).sum(dim=-1).mean()
print("Rollout Loss: ", loss)
torch.save(rollout_out, f'GIOROM_Sand_{sim_id}.pt')
torch.save(rollout_data_gt, f'GIOROM_Sand_{sim_id}_gt.pt')

torch.Size([320, 1711, 3])
torch.Size([1711])
Rollout Loss:  tensor(0.0023)


In [5]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

In [6]:
# Matplotlib colour for each particle-type id. Insertion order matters: it is
# the order in which visualize_prepare creates (and therefore stacks) the lines.
TYPE_TO_COLOR = dict([
    (3, "black"),
    (0, "green"),
    (7, "magenta"),
    (6, "gold"),
    (5, "blue"),
])


def visualize_prepare(ax, particle_type, position, metadata):
    """Set up one animation panel and create an empty line per particle type.

    Axis limits come from the first two entries of ``metadata["bounds"]``.
    Ticks are hidden and the aspect ratio is fixed at 1. ``particle_type`` is
    accepted for call-site symmetry but is not used here.

    Returns:
        ``(ax, position, points)``, where ``position`` is passed through
        unchanged and ``points`` maps each type id in ``TYPE_TO_COLOR`` to its
        initially empty ``Line2D``.
    """
    bounds = metadata["bounds"]
    x_bounds, y_bounds = bounds[0], bounds[1]
    ax.set_xlim(x_bounds[0], x_bounds[1])
    ax.set_ylim(y_bounds[0], y_bounds[1])
    for hide_ticks in (ax.set_xticks, ax.set_yticks):
        hide_ticks([])
    ax.set_aspect(1.0)

    lines_by_type = {}
    for type_id, colour in TYPE_TO_COLOR.items():
        created = ax.plot([], [], "o", ms=2, color=colour)
        lines_by_type[type_id] = created[0]
    return ax, position, lines_by_type


def visualize_pair(particle_type, position_pred, position_gt, metadata):
    """Animate the ground-truth trajectory (left) next to a predicted one (right).

    Args:
        particle_type: per-particle type ids describing ``position_pred``.
        position_pred: predicted positions indexed ``[step, particle, coord]``;
            only coords 0 and 1 are drawn.
        position_gt: ground-truth positions with the same layout. It may hold a
            different number of particles than ``particle_type`` describes.
        metadata: dataset metadata; ``"bounds"`` sets the axis limits.

    Returns:
        A blitted ``matplotlib.animation.FuncAnimation`` covering the time
        steps present in both trajectories.
    """
    print(position_pred.shape)
    print(position_gt.shape)
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    def _types_for(n_particles):
        # Use the real per-particle types when they line up with this panel;
        # otherwise colour every particle as the type of particle 0.
        if n_particles == len(particle_type):
            return particle_type
        return torch.ones(size=(n_particles,)) * particle_type[0]

    # Each panel carries its own type vector, so its colour mask no longer
    # depends on the other panel's particle count.
    plot_info = []
    for ax, position in ((axes[0], position_gt), (axes[1], position_pred)):
        types = _types_for(position.shape[1])
        _, position, points = visualize_prepare(ax, types, position, metadata)
        plot_info.append((position, points, types))
    axes[0].set_title("Ground truth")
    axes[1].set_title("Prediction")

    # Close the figure so the notebook shows only the animation, not a static copy.
    plt.close(fig)

    def update(step_i):
        artists = []
        for position, points, types in plot_info:
            for type_, line in points.items():
                mask = types == type_
                line.set_data(position[step_i, mask, 0], position[step_i, mask, 1])
                # With blit=True only returned artists are redrawn, so every
                # line is returned, not just the last one per panel.
                artists.append(line)
        return artists

    # Never index past the end of the shorter trajectory.
    n_frames = min(position_gt.shape[0], position_pred.shape[0])
    return animation.FuncAnimation(fig, update, frames=np.arange(0, n_frames), interval=20, blit=True)

In [7]:

# Side-by-side animation of ground truth (left) and the model rollout (right),
# rendered inline as an HTML5 video.
anim = visualize_pair(
    rollout_data["particle_type"],
    rollout_out,
    rollout_data_gt["position"],
    rollout_dataset.metadata,
)
HTML(anim.to_html5_video())

torch.Size([320, 1711, 3])
torch.Size([320, 1711, 3])
