In [1]:
import math
import torch_scatter
import torch
import torch_geometric as pyg
import torch.nn.functional as F
import os
import numpy as np
import json
from neuralop.models import FNO
from torch import nn
from models.neuraloperator.neuralop.layers.mlp import MLP as NeuralOpMLP
from models.neuraloperator.neuralop.layers.embeddings import PositionalEmbedding
from models.neuraloperator.neuralop.layers.integral_transform import IntegralTransform
from models.neuraloperator.neuralop.layers.neighbor_search import NeighborSearch
import random
from random import randint
from dataloader import preprocess
from models.giorom2d import PhysicsEngine as LearnedSimulator
#from GAT import preprocess, LearnedSimulator

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



In [2]:
params = {
    "epoch": 1000,
    "batch_size": 4,
    "lr": 1e-4,
    "noise": 3e-4,
    "save_interval": 1000,
    "eval_interval": 10000,
    "rollout_interval": 20000,
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
simulator = LearnedSimulator(device)
ckpt = torch.load('/home/csuser/Documents/Neural Operator/saved_models/giorom2d_WaterDropSmall_sampled.pt')
weights = torch.load('/home/csuser/Documents/Neural Operator/saved_models/giorom2d_WaterDropSmall_sampled.pt')['model']

model_dict = simulator.state_dict()
ckpt_dict = {}

#print(simulator.keys())
model_dict = dict(model_dict)

for k, v in weights.items():
    k2 = k[0:]
    #print(k2)
    if k2 in model_dict:
        #print(k2)
        if model_dict[k2].size() == v.size():
            ckpt_dict[k2] = v
        else:
            print("Size mismatch while loading! %s != %s Skipping %s..."%(str(model_dict[k2].size()), str(v.size()), k2))
            mismatch = True
    else:
        print("Model Dict not in Saved Dict! %s != %s Skipping %s..."%(2, str(v.size()), k2))
        mismatch = True
if len(simulator.state_dict().keys()) > len(ckpt_dict.keys()):
    mismatch = True
model_dict.update(ckpt_dict)
simulator.load_state_dict(model_dict)


#simulator.load_state_dict(weights['model'])
simulator = simulator.cuda()


In [3]:
def rollout(model, data, metadata, noise_std):
    device = next(model.parameters()).device
    model.eval()
    window_size = model.window_size + 1
    total_time = data["position"].size(0)
    #print("Total Time = ", total_time)
    
    traj = data["position"][:window_size]
    #print("TRAJ SHAPE = ", traj.shape)
    traj = traj.permute(1, 0, 2)
    particle_type = data["particle_type"]


    for time in range(total_time - window_size):
        with torch.no_grad():
            #print("PARTICLE TYPE = ", particle_type.shape)
            #print("TRAJECTORY = ", traj.shape)
            graph = preprocess(particle_type, traj[:, -window_size:], None, metadata, 0.0, radius=0.050, graph_type='radius')
            graph = graph.to(device)
            acceleration = model(graph).cpu()
            acceleration = acceleration * torch.sqrt(torch.tensor(metadata["acc_std"]) ** 2 + noise_std ** 2) + torch.tensor(metadata["acc_mean"])

            recent_position = traj[:, -1]
            recent_velocity = recent_position - traj[:, -2]
            new_velocity = recent_velocity + acceleration
            new_position = recent_position + new_velocity
            traj = torch.cat((traj, new_position.unsqueeze(1)), dim=1)
    return traj

In [4]:
def avg_velocity(rollout):
    velocity_seq = rollout[:,1:] - rollout[:, :-1]
    print(velocity_seq.shape)
    return velocity_seq.numpy().max()
    

In [5]:
def avg_spacing(rollout):
    spacing_vectors = []
    for i in range(1000):
        distances = torch.cdist(rollout[i], rollout[i]).flatten()
        spacing_vectors.append(distances[distances.nonzero()].min())
    return np.asarray(spacing_vectors).mean()

In [6]:
from collections import OrderedDict
from dgl.geometry import farthest_point_sampler
from scipy.spatial import Delaunay
class RolloutDataset(pyg.data.Dataset):
    def __init__(self, data_path, split, window_length=7, sampling=False, sampling_strategy='random', graph_type='radius', mesh_size=170, radius=None):
        super().__init__()
        
        # load data from the disk
        self.data_path = data_path
        with open(os.path.join('/home/csuser/Documents/new_dataset/WaterDropSmall/', "metadata.json")) as f:
            self.metadata = json.load(f)
        #with open(os.path.join(data_path, f"{split}_offset.json")) as f:
        #    self.offset = json.load(f)
        #self.offset = {int(k): v for k, v in self.offset.items()}
        self.window_length = window_length
        self.sampling = sampling
        self.sampling_strategy = sampling_strategy
        self.graph_type = graph_type
        self.mesh_size=  mesh_size
        self.radius = radius
        #self.particle_type = np.memmap(os.path.join(data_path, f"{split}_particle_type.dat"), dtype=np.int64, mode="r")
        #self.position = np.memmap(os.path.join(data_path, f"{split}_position.dat"), dtype=np.float32, mode="r")
        dataset = torch.load(data_path)
        self.particle_type = dataset['particle_type']
        self.position = dataset['position']
        self.n_particles_per_example = dataset['n_particles_per_example']
        self.outputs = dataset['output']

        if(self.sampling == True or self.sampling==False):
            #mesh_size =  np.random.randint(int(0.30*360), int(0.45*360))
            mesh_size = 120
            while(mesh_size %10 !=0):
                mesh_size += 1
            
            #points = list(range(0, 360, 4))
            points = sorted(random.sample(range(0, 360), mesh_size))
        self.points = points
        #for traj in self.offset.values():
        #    self.dim = traj["position"]["shape"][2]
        #    break
        self.dim = self.position[0].shape[2]
    def len(self):
        return len(self.position)
    
    def get(self, idx):
        #traj = self.offset[idx]
        #size = traj["position"]["shape"][1]
        #time_step = traj["position"]["shape"][0]
        #particle_type = self.particle_type[traj["particle_type"]["offset"]: traj["particle_type"]["offset"] + size].copy()
        #particle_type = torch.from_numpy(particle_type)
        #position = self.position[traj["position"]["offset"]: traj["position"]["offset"] + time_step * size * self.dim].copy()
        #position.resize(traj["position"]["shape"])
        #position = torch.from_numpy(position)

        particle_type = torch.from_numpy(self.particle_type[idx])
        position_seq = torch.from_numpy(self.position[idx])
        position_seq = torch.permute(position_seq, dims=(1,0,2))
        
        target_position = torch.from_numpy(self.outputs[idx])
        if(self.sampling):
            if(self.sampling_strategy == 'random'):
                self.points = sorted(random.sample(range(0, particle_type.shape[0]), self.mesh_size))
                particle_type = particle_type[self.points]
                position_seq = position_seq.permute(1,0,2)
                position_seq = position_seq[self.points]
                position_seq = position_seq.permute(1,0,2)
                target_position = target_position[self.points]
            elif(self.sampling_strategy == 'fps'):
                init_pos = position_seq.permute(1, 0, 2)[0].unsqueeze(0)
                point_idx = farthest_point_sampler(init_pos, self.mesh_size)[0]
                particle_type = particle_type[point_idx]
                position_seq = position_seq[point_idx]
                target_position = target_position[point_idx]
        data = {"particle_type": particle_type, "position": position_seq}
        return data

rollout_dataset = RolloutDataset('/home/csuser/Documents/new_dataset/WaterDropSmall/rollout.pt', "train", sampling=True, sampling_strategy='random', graph_type='radius',radius=0.040)
#print(len(rollout_dataset))
simulator.eval()
sim_id = 0
rollout_data = rollout_dataset[sim_id]
temp = rollout_data['position'][0]

rollout_out = rollout(simulator, rollout_data, rollout_dataset.metadata, params["noise"])
#print(avg_velocity(rollout_out))
rollout_out = rollout_out.permute(1, 0, 2)
loss = (rollout_out - rollout_data["position"]) ** 2
loss = loss.sum(dim=-1).mean()
print("Rollout Loss: ", loss)
print(rollout_out[1:,:].shape)
print(rollout_out.shape)

torch.save(rollout_out, f'GIOROM_Sand_sim_{sim_id}.pt')
torch.save(rollout_data, f'GIOROM_Sand_sim_{sim_id}_gt.pt')
#print(avg_spacing(rollout_out))


Rollout Loss:  tensor(0.0506)
torch.Size([999, 170, 2])
torch.Size([1000, 170, 2])


In [7]:
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

In [8]:
TYPE_TO_COLOR = {
    3: "black",
    0: "green",
    7: "magenta",
    6: "gold",
    5: "#14b1f5",
}


def visualize_prepare(ax, particle_type, position, metadata):
    bounds = metadata["bounds"]
    ax.set_xlim(bounds[0][0], bounds[0][1])
    ax.set_ylim(bounds[1][0], bounds[1][1])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect(1.0)
    points = {type_: ax.plot([], [], "o", ms=2, color=color)[0] for type_, color in TYPE_TO_COLOR.items()}
    return ax, position, points


def visualize_pair(particle_type, position_pred, position_gt, metadata):
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    plot_info = [
        visualize_prepare(axes[0], particle_type, position_gt, metadata),
        visualize_prepare(axes[1], particle_type, position_pred, metadata),
    ]
    axes[0].set_title("Ground truth")
    axes[1].set_title("Prediction")

    plt.close()
    def update(step_i):
        outputs = []
        for _, position, points in plot_info:
            for type_, line in points.items():
                mask = particle_type == type_
                line.set_data(position[step_i, mask, 0], position[step_i, mask, 1])
            outputs.append(line)
        return outputs

    return animation.FuncAnimation(fig, update, frames=np.arange(0, position_gt.size(0)), interval=10, blit=True)

In [9]:
anim = visualize_pair(rollout_data["particle_type"], rollout_out, rollout_data["position"], rollout_dataset.metadata)
writer = animation.writers['ffmpeg'](fps=30)
anim.save('disc_4.mp4',writer=writer,dpi=200)
HTML(anim.to_html5_video())