In [30]:
import os
import sys
import random
import time
import argparse
import numpy as np
import matplotlib.pyplot as plt

import multiprocessing as mp

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from models import *
from data import *
#from data import PhysicsFleXDataset, collate_fn

from utils import count_parameters

In [31]:
parser = argparse.ArgumentParser()
parser.add_argument('--pstep', type=int, default=2)
parser.add_argument('--n_rollout', type=int, default=0)
parser.add_argument('--time_step', type=int, default=0)
parser.add_argument('--time_step_clip', type=int, default=0)
parser.add_argument('--dt', type=float, default=1./60.)
parser.add_argument('--nf_relation', type=int, default=300)
parser.add_argument('--nf_particle', type=int, default=200)
parser.add_argument('--nf_effect', type=int, default=200)
parser.add_argument('--env', default='')
parser.add_argument('--train_valid_ratio', type=float, default=0.9)
parser.add_argument('--outf', default='files')
parser.add_argument('--dataf', default='data')
parser.add_argument('--num_workers', type=int, default=10)
parser.add_argument('--gen_data', type=int, default=0)
parser.add_argument('--gen_stat', type=int, default=0)
parser.add_argument('--log_per_iter', type=int, default=1000)
parser.add_argument('--ckp_per_iter', type=int, default=10000)
parser.add_argument('--eval', type=int, default=0)
parser.add_argument('--verbose_data', type=int, default=1)
parser.add_argument('--verbose_model', type=int, default=1)

parser.add_argument('--n_instance', type=int, default=0)
parser.add_argument('--n_stages', type=int, default=0)
parser.add_argument('--n_his', type=int, default=0)

parser.add_argument('--n_epoch', type=int, default=1000)
parser.add_argument('--beta1', type=float, default=0.9)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--forward_times', type=int, default=2)

parser.add_argument('--resume_epoch', type=int, default=0)
parser.add_argument('--resume_iter', type=int, default=0)

# shape state:
# [x, y, z, x_last, y_last, z_last, quat(4), quat_last(4)]
parser.add_argument('--shape_state_dim', type=int, default=14)

# object attributes:
parser.add_argument('--attr_dim', type=int, default=0)

# object state:
parser.add_argument('--state_dim', type=int, default=0)
parser.add_argument('--position_dim', type=int, default=0)

# relation attr:
parser.add_argument('--relation_dim', type=int, default=0)

_StoreAction(option_strings=['--relation_dim'], dest='relation_dim', nargs=None, const=None, default=0, type=<class 'int'>, choices=None, help=None, metavar=None)

In [32]:
args = parser.parse_args("--env SingleHair --gen_data 0".split())

In [33]:
phases_dict = dict()

args.n_rollout = 50
args.num_workers = 5
args.gen_stat = 1

args.dataf = 'data'

# object states:
# [x, y, z, xdot, ydot, zdot]
args.state_dim = 6
args.position_dim = 3

# object attr:
# [rigid]
args.attr_dim = 1

# relation attr:
# [none]
args.relation_dim = 2

args.time_step = 600
args.time_step_clip = 100
args.n_instance = 1
args.n_stages = 1

args.neighbor_radius = 0.08

phases_dict["instance_idx"] = [0, 31]
phases_dict["root_num"] = [[]]
phases_dict["instance"] = ['solid']
phases_dict["material"] = ['solid']

data_names = ['positions', 'velocities','hair_idx']
verbose = 0
phase = 'train'
data_dir = os.path.join(args.dataf, phase)
stat_path = os.path.join(args.dataf, 'stat.h5')
n_rollout = 10



info = {
    'env': args.env,
    'root_num': phases_dict['root_num'],
    'thread_idx': 0,
    'data_dir': data_dir,
    'data_names': data_names,
    'n_rollout': n_rollout // args.num_workers,
    'n_instance': args.n_instance,
    'time_step': args.time_step,
    'time_step_clip': args.time_step_clip,
    'dt': args.dt,
    'shape_state_dim': args.shape_state_dim}

info['env_idx'] = 11

args.outf = 'dump_SingleHair/' + args.outf

args.outf = args.outf + '_' + args.env
args.dataf = 'data/' + args.dataf + '_' + args.env
print (args.dataf)
os.system('mkdir -p ' + args.outf)
os.system('mkdir -p ' + args.dataf)

data/data_SingleHair


0

In [34]:

def gen_PyFleX(info):

    env, root_num = info['env'], info['root_num']
    thread_idx, data_dir, data_names = info['thread_idx'], info['data_dir'], info['data_names']
    n_rollout, n_instance = info['n_rollout'], info['n_instance']
    time_step, time_step_clip = info['time_step'], info['time_step_clip']
    shape_state_dim, dt = info['shape_state_dim'], info['dt']

    env_idx = info['env_idx'] # =11

    np.random.seed(round(time.time() * 1000 + thread_idx) % 2**32)
    
    stats = [init_stat(3), init_stat(3)]

    import pyflex
    pyflex.init()

    for i in range(n_rollout):

        if i % 10 == 0:
            print("%d / %d" % (i, n_rollout))

        rollout_idx = thread_idx * n_rollout + i
        rollout_dir = os.path.join(data_dir, str(rollout_idx))
        os.system('mkdir -p ' + rollout_dir)
        
        # scene_params: [len(box) at dim x,len(box) at dim y,len(box) at dim z, num_hair per circle, num_circle]
        cap_size = [0.1,1.5]
        N_hairs = 1


        scene_params = np.array(cap_size)

        pyflex.set_scene(env_idx, scene_params, thread_idx)
        n_particles = pyflex.get_n_particles()
        n_shapes = 1
        N_particles_per_hair = int(n_particles/N_hairs)
        idx_begins = np.arange(N_hairs)*N_particles_per_hair
        idx_hairs = [[i,i+N_particles_per_hair-1] for i in idx_begins]

        positions = np.zeros((time_step, n_particles+ n_shapes, 3), dtype=np.float32)
        velocities = np.zeros((time_step, n_particles+ n_shapes, 3), dtype=np.float32)
   #     shape_position = np.zeros((time_step, n_shapes, 3), dtype=np.float32)
    #    shape_velocities = np.zeros((time_step, n_shapes, 3), dtype=np.float32)

        for j in range(time_step_clip):
         #   p_clip = pyflex.get_positions().reshape(-1, 4)[:, :3]
         #   shape_p_clip = pyflex.get_shape_states()[:3].reshape(-1,3)
            p_clip = np.concatenate([pyflex.get_positions().reshape(-1, 4)[:, :3],pyflex.get_shape_states()[:3].reshape(-1,3)],axis = 0)
            pyflex.step()

        for j in range(time_step):
            positions[j, :n_particles] = pyflex.get_positions().reshape(-1, 4)[:, :3]
            for k in range(n_shapes):
                 positions[j, n_particles + k] = pyflex.get_shape_states()[:3]
        #    shape_position[j] = pyflex.get_shape_states()[:3].reshape(-1,3)
        #    shape_prevposition = pyflex.get_shape_states()[3:6].reshape(-1,3)
            if j == 0:
                velocities[j] = (positions[j] - p_clip) / dt
           #     shape_velocities[j] = (shape_position[j] - shape_p_clip)/dt
            else:
                velocities[j] = (positions[j] - positions[j - 1]) / dt
          #      shape_velocities[j] = (shape_position[j] - shape_position[j-1])/dt

            pyflex.step()
            data = [positions[j], velocities[j], idx_hairs]
            store_data(data_names, data, os.path.join(rollout_dir, str(j) + '.h5'))
        
        # change dtype for more accurate stat calculation
        # only normalize positions and velocities
        datas = [positions.astype(np.float64), velocities.astype(np.float64)]

        for j in range(len(stats)): 
            # here j = 2, refers to positions and velocities
            stat = init_stat(stats[j].shape[0]) 
            # stat= np.zeros((3,3))
            stat[:, 0] = np.mean(datas[j], axis=(0, 1))[:]
            stat[:, 1] = np.std(datas[j], axis=(0, 1))[:]
            stat[:, 2] = datas[j].shape[0] * datas[j].shape[1] # time_step*n_particles
            stats[j] = combine_stat(stats[j], stat)

    pyflex.clean()

    return stats

In [35]:
class PhysicsFleXDataset(Dataset):

    def __init__(self, args, phase, phases_dict, verbose):
        self.args = args
        self.phase = phase
        self.phases_dict = phases_dict
        self.verbose = verbose
        self.data_dir = os.path.join(self.args.dataf, phase)
        self.stat_path = os.path.join(self.args.dataf, 'stat.h5')

        os.system('mkdir -p ' + self.data_dir)

        #    self.data_names = ['positions', 'velocities', 'shape_quats', 'clusters', 'scene_params']
        self.data_names = ['positions', 'velocities','hair_idx']

        ratio = self.args.train_valid_ratio
        if phase == 'train':
            self.n_rollout = int(self.args.n_rollout * ratio)
        elif phase == 'valid':
            self.n_rollout = self.args.n_rollout - int(self.args.n_rollout * ratio)
        else:
            raise AssertionError("Unknown phase")

    def __len__(self):
        return self.n_rollout * (self.args.time_step - 1)

    def load_data(self, name):
        self.stat = load_data(self.data_names[:2], self.stat_path)
        for i in range(len(self.stat)):
            self.stat[i] = self.stat[i][-self.args.position_dim:, :]
            # print(self.data_names[i], self.stat[i].shape)

    def gen_data(self, name):
        # if the data hasn't been generated, generate the data
        print("Generating data ... n_rollout=%d, time_step=%d" % (self.n_rollout, self.args.time_step))

        infos = []
        for i in range(self.args.num_workers):
            info = {
                'env': self.args.env,
                'root_num': self.phases_dict['root_num'],
                'thread_idx': i,
                'data_dir': self.data_dir,
                'data_names': self.data_names,
                'n_rollout': self.n_rollout // self.args.num_workers,
                'n_instance': self.args.n_instance,
                'time_step': self.args.time_step,
                'time_step_clip': self.args.time_step_clip,
                'dt': self.args.dt,
                'shape_state_dim': self.args.shape_state_dim}

            info['env_idx'] = 11
            infos.append(info)

        cores = self.args.num_workers
        pool = mp.Pool(processes=cores)
        data = pool.map(gen_PyFleX, infos)

        print("Training data generated, warpping up stats ...")

        if self.phase == 'train' and self.args.gen_stat:
            # positions [x, y, z], velocities[xdot, ydot, zdot]
            self.stat = [init_stat(3), init_stat(3)]
            for i in range(len(data)):
                for j in range(len(self.stat)):
                    self.stat[j] = combine_stat(self.stat[j], data[i][j])
            store_data(self.data_names[:2], self.stat, self.stat_path)
        else:
            print("Loading stat from %s ..." % self.stat_path)
            self.stat = load_data(self.data_names[:2], self.stat_path)

    def __getitem__(self, idx):
        idx_rollout = idx // (self.args.time_step - 1)
        idx_timestep = idx % (self.args.time_step - 1)

        # ignore the first frame for env RiceGrip
        if self.args.env == 'RiceGrip' and idx_timestep == 0:
            idx_timestep = np.random.randint(1, self.args.time_step - 1)

        data_path = os.path.join(self.data_dir, str(idx_rollout), str(idx_timestep) + '.h5')
        data_nxt_path = os.path.join(self.data_dir, str(idx_rollout), str(idx_timestep + 1) + '.h5')

        data = load_data(self.data_names, data_path)

        '''
        vel_his = []
        for i in range(self.args.n_his):
            path = os.path.join(self.data_dir, str(idx_rollout), str(max(1, idx_timestep - i - 1)) + '.h5')
            data_his = load_data(self.data_names, path)
            vel_his.append(data_his[1])

        data[1] = np.concatenate([data[1]] + vel_his, 1)
        '''
        
        attr, state, relations, n_particles, n_shapes = \
                prepare_input(data, self.stat, self.args, self.phases_dict, self.verbose)

        ### label
        data_nxt = normalize(load_data(self.data_names, data_nxt_path), self.stat)

        label = torch.FloatTensor(data_nxt[1][:n_particles])

        return attr, state, relations, n_particles, n_shapes, label

In [36]:
def prepare_input(data, stat, args, phases_dict, verbose=0, var=False):
    '''
    for a single hair
    '''
    positions, velocities, hairs_idx = data
    n_shapes = 1
    hairs_idx_begin = [idx[0] for idx in hairs_idx]
    n_particles = positions.shape[0] - n_shapes
    R = 0.12
    
    ### object attributes
    #   dim 10: [rigid, fluid, root_0, root_1, gripper_0, gripper_1, mass_inv,
    #            clusterStiffness, clusterPlasticThreshold, cluasterPlasticCreep]
    #   here we only consider the hairs but not the gripper, attr_dim = 1, attr = 0 for hair, attr = 1 for shapes
    attr = np.zeros((n_particles+n_shapes, args.attr_dim))
    
    ### construct relations
    Rr_idxs = []        # relation receiver idx list
    Rs_idxs = []        # relation sender idx list
    Ras = []            # relation attributes list
    values = []         # relation value list (should be 1)
    node_r_idxs = []    # list of corresponding receiver node idx
    node_s_idxs = []    # list of corresponding sender node idx
    psteps = []         # propagation steps
    
    ##### add env specific graph components
    ### specify for shapes
    rels = []
    vals = []
    
    for i in range(n_shapes):
        attr[n_particles+i, 0] = 1
        dis = np.linalg.norm(positions[:n_particles,:2]-positions[n_particles+i,:2],axis = 1)
        if verbose:
            print ("dis:", dis)
        nodes_rel = np.nonzero(dis <= R)[0]
        # for relation between hair nodes and a gripper, we note it as 1
        gripper = np.ones(nodes_rel.shape[0], dtype=np.int) * (n_particles+i)
        rels += [np.stack([nodes_rel, gripper, np.ones(nodes_rel.shape[0])], axis=1)]
        if verbose:
            print ("dis val:", dis[nodes_rel])
      #  vals += [np.ones(nodes_rel.shape[0], dtype=np.int)]
        vals += [dis[nodes_rel]]
        
    
    ##### add relations between leaf particles
    ## here we only consider the relations in a hair: the relation between a node and the nodes nearby
    ## simple case for one hair, TEMPORARY 2 rels for one link
    nodes_p = np.arange(n_particles-1)
    val = np.linalg.norm(positions[1:n_particles]-positions[:n_particles-1],axis = 1)
    R1 = np.stack([nodes_p,nodes_p+1, np.zeros(n_particles-1)],axis = 1)
    R2 = np.stack([nodes_p+1,nodes_p, np.zeros(n_particles-1)],axis = 1)
    rels += [np.concatenate([R1,R2],axis = 0)]
    vals += [val,val]
    
    rels = np.concatenate(rels, 0)
    vals = np.concatenate(vals, 0)
    
  #  print (vals.shape)
 #   print (rels.shape)
    
    
    if rels.shape[0] > 0:
        if verbose:
            print("Relations neighbor", rels.shape)
        Rr_idxs.append(torch.LongTensor([rels[:, 0], np.arange(rels.shape[0])]))
        Rs_idxs.append(torch.LongTensor([rels[:, 1], np.arange(rels.shape[0])]))
        # Ra: relation attributes
    #    Ra = np.zeros((rels.shape[0], args.relation_dim))  
        Ra = rels[:,2].reshape([-1,1])
        Ras.append(torch.FloatTensor(Ra))
        # values could be changed
     #   values.append(torch.FloatTensor([1] * rels.shape[0]))
      #  values.append(rels[:,2])
        #### for hairs: values equals to the length of this segment
        values.append(torch.FloatTensor(vals))
        node_r_idxs.append(np.arange(n_particles))
        node_s_idxs.append(np.arange(n_particles + n_shapes))
        psteps.append(args.pstep)
        
        
    if verbose:
        print("Attr shape (after hierarchy building):", attr.shape)
        print("Object attr:", np.sum(attr, axis=0))
        print("Particle attr:", np.sum(attr[:n_particles], axis=0))
        print("Shape attr:", np.sum(attr[n_particles:n_particles+n_shapes], axis=0))
        print("Roots attr:", np.sum(attr[n_particles+n_shapes:], axis=0))
        
        
    ### normalize data
    data = [positions, velocities]
    data = normalize(data, stat, var)
    positions, velocities = data[0], data[1]

    if verbose:
        print("Particle positions stats")
        print(positions.shape)
        print(np.min(positions[:n_particles], 0))
        print(np.max(positions[:n_particles], 0))
        print(np.mean(positions[:n_particles], 0))
        print(np.std(positions[:n_particles], 0))

        show_vel_dim = 6 if args.env == 'RiceGrip' else 3
        print("Velocities stats")
        print(velocities.shape)
        print(np.mean(velocities[:n_particles, :show_vel_dim], 0))
        print(np.std(velocities[:n_particles, :show_vel_dim], 0))
        
    state = torch.FloatTensor(np.concatenate([positions, velocities], axis=1))
    attr = torch.FloatTensor(attr)
    relations = [Rr_idxs, Rs_idxs, values, Ras, node_r_idxs, node_s_idxs, psteps]

    return attr, state, relations, n_particles, n_shapes#, instance_idx



In [37]:
datasets = {phase: PhysicsFleXDataset(
    args, phase, phases_dict, verbose=False) for phase in ['train', 'valid']}
datasets['train'].load_data(args.env)
datasets['valid'].load_data(args.env)


In [38]:
dataloaders = {x: torch.utils.data.DataLoader(
    datasets[x], batch_size=args.batch_size,
    shuffle=True if x == 'train' else False,
    num_workers=args.num_workers,
    collate_fn=collate_fn)
    for x in ['train', 'valid']}

In [39]:
args.relation_dim = 2

In [100]:
attr, state, relations, n_particles, n_shapes, label = datasets['train'].__getitem__(200)

In [64]:
relations
# relations = [Rr_idxs, Rs_idxs, values, Ras, node_r_idxs, node_s_idxs, psteps]

[[tensor([[18, 19,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,
           16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,  1,  2,  3,  4,
            5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
           23, 24, 25, 26, 27, 28, 29, 30],
          [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
           18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
           36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
           54, 55, 56, 57, 58, 59, 60, 61]])],
 [tensor([[31, 31,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
           17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,  0,  1,  2,  3,
            4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
           22, 23, 24, 25, 26, 27, 28, 29],
          [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
           18, 19, 20, 21, 22, 23, 

In [66]:
relations[1]

[tensor([[31, 31,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
          17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,  0,  1,  2,  3,
           4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
          22, 23, 24, 25, 26, 27, 28, 29],
         [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
          18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
          36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
          54, 55, 56, 57, 58, 59, 60, 61]])]

In [74]:
def pre_traite(data, verbose):
    attr, state, rels, n_particles, n_shapes, label = data
    Ra, node_r_idx, node_s_idx, pstep = rels[3], rels[4], rels[5], rels[6]

    Rr, Rs = [], []
    Values = []

    for j in range(len(rels[0])):
        Rr_idx, Rs_idx, values = rels[0][j], rels[1][j], rels[2][j]
        V = torch.ones(values.shape)
        Values.append(values)
        Rr.append(torch.sparse.FloatTensor(
            Rr_idx, V, torch.Size([node_r_idx[j].shape[0], Ra[j].size(0)])))
        Rs.append(torch.sparse.FloatTensor(
            Rs_idx, V, torch.Size([node_s_idx[j].shape[0], Ra[j].size(0)])))

    s = 0
    Rrp = Rr[s].t()
    Rsp = Rs[s].t()

    # receiver_attr, sender_attr
    attr_r = attr[node_r_idx[s]]
    attr_s = attr[node_s_idx[s]]
    attr_r_rel = Rrp.mm(attr_r)
    attr_s_rel = Rsp.mm(attr_s)

    # receiver_state, sender_state
    state_r = state[node_r_idx[s]]
    state_s = state[node_s_idx[s]]
    state_r_rel = Rrp.mm(state_r)
    state_s_rel = Rsp.mm(state_s)
    state_diff = state_r_rel - state_s_rel
    if verbose:
        print (attr_r_rel.shape)
        print (attr_s_rel.shape)
        print (state_r_rel.shape)
        print (state_s_rel.shape)
        print (Values[s].reshape([-1,1]).shape)
        print (Ra[s].shape)

    return Rr,Rs, torch.cat([attr_r_rel, attr_s_rel, state_r_rel, state_s_rel,Values[s].reshape([-1,1]), Ra[s]],1)

In [75]:
args.relation_dim = 2
state_dim = args.state_dim
attr_dim = args.attr_dim
relation_dim = args.relation_dim

data = datasets['train'].__getitem__(200)
attr, state, rels, n_particles, n_shapes, label = data
Rr,Rs,relations = pre_traite(data, verbose = 1)
relation_function = RelationEncoder(2*attr_dim + 2*state_dim+relation_dim, 100, 30)
effect = relation_function(relations)

torch.Size([62, 1])
torch.Size([62, 1])
torch.Size([62, 6])
torch.Size([62, 6])
torch.Size([62, 1])
torch.Size([62, 1])


In [76]:
rels[1]

[tensor([[31, 31,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
          17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,  0,  1,  2,  3,
           4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
          22, 23, 24, 25, 26, 27, 28, 29],
         [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
          18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
          36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
          54, 55, 56, 57, 58, 59, 60, 61]])]

In [35]:
effect.shape

torch.Size([62, 30])

In [77]:
Rr[s]

tensor(indices=tensor([[18, 19,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
                        12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
                        26, 27, 28, 29,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10,
                        11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
                        25, 26, 27, 28, 29, 30],
                       [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13,
                        14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
                        28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
                        42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
                        56, 57, 58, 59, 60, 61]]),
       values=tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 

In [78]:
Rs[s]

tensor(indices=tensor([[31, 31,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12,
                        13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
                        27, 28, 29, 30,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
                        10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                        24, 25, 26, 27, 28, 29],
                       [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13,
                        14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
                        28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
                        42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
                        56, 57, 58, 59, 60, 61]]),
       values=tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 

In [37]:
s = 0
effect_r = Rr[s].mm(effect)
effect_r.shape

torch.Size([31, 30])

In [101]:
from models import *

class ParticleEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(ParticleEncoder, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.ReLU()
        )

    def forward(self, x):
        '''
        Args:
            x: [n_particles, input_size]
        Returns:
            [n_particles, output_size]
        '''
        # print(x.size())
        return self.model(x)

class RelationEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RelationEncoder, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.ReLU()
        )

    def forward(self, x):
        '''
        Args:
            x: [n_relations, input_size]
        Returns:
            [n_relations, output_size]
        '''
        return self.model(x)
    
    
class IntNet(nn.Module):
    def __init__(self, args, stat, phases_dict, residual=False, use_gpu=False):
        super(IntNet, self).__init__()
        self.args = args
        state_dim = args.state_dim
        attr_dim = args.attr_dim
        relation_dim = args.relation_dim
        nf_particle = 100#args.nf_particle
        nf_relation = 100 #args.nf_relation
        nf_effect = 15 #args.nf_effect
        self.nf_effect = nf_effect
        self.stat = stat
        self.use_gpu = use_gpu
        self.residual = residual
        
        #object function
        #input: state & attribute of the receiver, sum of its effect
        self.object_function = ParticleEncoder(attr_dim + state_dim+nf_effect, nf_particle, args.position_dim)
        
        #relation function
        #input: state & attribute of the sender & receiver, state & attribute of the relation
        #output: the effect of the relation
        self.relation_function = RelationEncoder(2*attr_dim + 2*state_dim+relation_dim, nf_relation, nf_effect)
            
    def forward(self, attr, state, Rr, Rs, Ra,Values, n_particles, node_r_idx, node_s_idx, pstep,
                instance_idx, phases_dict, verbose=0):
        s = 0
        Rrp = Rr[s].t()
        Rsp = Rs[s].t()

        # receiver_attr, sender_attr
        attr_r = attr[node_r_idx[s]]
        attr_s = attr[node_s_idx[s]]
        attr_r_rel = Rrp.mm(attr_r)
        attr_s_rel = Rsp.mm(attr_s)

        # receiver_state, sender_state
        state_r = state[node_r_idx[s]]
        state_s = state[node_s_idx[s]]
        state_r_rel = Rrp.mm(state_r)
        state_s_rel = Rsp.mm(state_s)
        state_diff = state_r_rel - state_s_rel
        if verbose:
            print (attr_r_rel.shape)
            print (attr_s_rel.shape)
            print (state_r_rel.shape)
            print (state_s_rel.shape)
            print (Values[s].reshape([-1,1]).shape)
            print (Ra[s].shape)
        
        
        effect = self.relation_function(torch.cat([attr_r_rel, attr_s_rel, state_r_rel, state_s_rel,Values[s].reshape([-1,1]), Ra[s]],1))
        
        effect_r = Rr[s].mm(effect) #(31*nf_effect)
        if verbose:
            print (effect_r.shape)
        pred = self.object_function(torch.cat([attr_r, state_r,effect_r], 1))
        return pred
        
            

In [102]:
use_gpu = True
args.lr = 0.001
model = IntNet(args, datasets['train'].stat, phases_dict, residual=True, use_gpu=use_gpu)
print("Number of parameters: %d" % count_parameters(model))
# criterion
criterionMSE = nn.MSELoss()

# optimizer
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=3, verbose=True)

if use_gpu:
    model = model.cuda()
    criterionMSE = criterionMSE.cuda()

st_epoch = args.resume_epoch if args.resume_epoch > 0 else 0
best_valid_loss = np.inf

model.train(phase=='train')

Number of parameters: 46218


IntNet(
  (object_function): ParticleEncoder(
    (model): Sequential(
      (0): Linear(in_features=22, out_features=100, bias=True)
      (1): ReLU()
      (2): Linear(in_features=100, out_features=100, bias=True)
      (3): ReLU()
      (4): Linear(in_features=100, out_features=100, bias=True)
      (5): ReLU()
      (6): Linear(in_features=100, out_features=3, bias=True)
      (7): ReLU()
    )
  )
  (relation_function): RelationEncoder(
    (model): Sequential(
      (0): Linear(in_features=16, out_features=100, bias=True)
      (1): ReLU()
      (2): Linear(in_features=100, out_features=100, bias=True)
      (3): ReLU()
      (4): Linear(in_features=100, out_features=100, bias=True)
      (5): ReLU()
      (6): Linear(in_features=100, out_features=15, bias=True)
      (7): ReLU()
    )
  )
)

In [103]:
phase = 'train'
instance_idx = [0, 31]
args.n_epoch = 500
psteps = 0

for epoch in range(args.n_epoch):
    model.train(phase=='train')

    losses = 0.
    for i, data in enumerate(dataloaders[phase]):
#         print ('i:',i)

        attr, state, rels, n_particles, n_shapes, label = data
        Ra, node_r_idx, node_s_idx, pstep = rels[3], rels[4], rels[5], rels[6]

        Rr, Rs = [], []
        Values = []
        for j in range(len(rels[0])):
            Rr_idx, Rs_idx, values = rels[0][j], rels[1][j], rels[2][j]
            V = torch.ones(values.shape)
            Values.append(values)
            Rr.append(torch.sparse.FloatTensor(
                Rr_idx, V, torch.Size([node_r_idx[j].shape[0], Ra[j].size(0)])))
            Rs.append(torch.sparse.FloatTensor(
                Rs_idx, V, torch.Size([node_s_idx[j].shape[0], Ra[j].size(0)])))

        data = [attr, state, Rr, Rs, Ra,Values, label]
        

            # st_time = time.time()
        with torch.set_grad_enabled(phase=='train'):
            if use_gpu:
                instance_idx = torch.Tensor(instance_idx)#.cuda()
                for d in range(len(data)):
                    if type(data[d]) == list:
                        for t in range(len(data[d])):
                            data[d][t] = Variable(data[d][t].cuda())
                    else:
                        data[d] = Variable(data[d].cuda())
            else:
                for d in range(len(data)):
                    if type(data[d]) == list:
                        for t in range(len(data[d])):
                            data[d][t] = Variable(data[d][t])
                    else:
                        data[d] = Variable(data[d])

            attr, state, Rr, Rs, Ra, Values, label = data
            
            pstep = 3

            predicted = model(
                attr, state, Rr, Rs, Ra,Values, n_particles,
                node_r_idx, node_s_idx, pstep,
                instance_idx, phases_dict, 0)
            # print('Time forward', time.time() - st_time)

       #     print(predicted.shape)
       #     print(label.shape)

        loss = criterionMSE(predicted, label)
        losses += np.sqrt(loss.item())

        if phase == 'train':
            if i % 5 == 0:
                # update parameters every args.forward_times
          #      print ('update!')
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
          #      print ('done')


        if i % 10000 == 0:
            n_relations = 0
            for j in range(len(Ra)):
                n_relations += Ra[j].size(0)
            print('%s [%d/%d][%d/%d] n_relations: %d, Loss: %.6f, Agg: %.6f' %
                  (phase, epoch, args.n_epoch, i, len(dataloaders[phase]),
                   n_relations, np.sqrt(loss.item()), losses / (i + 1)))

      #  if phase == 'train' and i > 0 and i % args.ckp_per_iter == 0:
    if epoch % 10 == 0 :
        torch.save(model.state_dict(), '%s/IntNet_epoch_%d.pth' % (args.outf, epoch))

    losses /= len(dataloaders[phase])
    print('%s [%d/%d] Loss: %.4f' %
          (phase, epoch, args.n_epoch, losses))

train [0/500][0/26955] n_relations: 62, Loss: 0.982606, Agg: 0.982606
train [0/500][10000/26955] n_relations: 60, Loss: 0.150784, Agg: 0.755603
train [0/500][20000/26955] n_relations: 61, Loss: 1.120278, Agg: 0.751585
train [0/500] Loss: 0.7494
train [1/500][0/26955] n_relations: 62, Loss: 0.986954, Agg: 0.986954
train [1/500][10000/26955] n_relations: 62, Loss: 0.460955, Agg: 0.740314
train [1/500][20000/26955] n_relations: 60, Loss: 0.798448, Agg: 0.743218
train [1/500] Loss: 0.7448
train [2/500][0/26955] n_relations: 62, Loss: 0.818719, Agg: 0.818719
train [2/500][10000/26955] n_relations: 61, Loss: 0.887499, Agg: 0.741781
train [2/500][20000/26955] n_relations: 60, Loss: 0.518647, Agg: 0.739868
train [2/500] Loss: 0.7426
train [3/500][0/26955] n_relations: 60, Loss: 0.979256, Agg: 0.979256
train [3/500][10000/26955] n_relations: 60, Loss: 0.895848, Agg: 0.742951
train [3/500][20000/26955] n_relations: 60, Loss: 0.507019, Agg: 0.740993
train [3/500] Loss: 0.7420
train [4/500][0/2695

KeyboardInterrupt: 

In [40]:
class DPINet2(nn.Module):
    def __init__(self, args, stat, phases_dict, residual=False, use_gpu=False):
        super(DPINet2, self).__init__()

        self.args = args

        state_dim = args.state_dim
        attr_dim = args.attr_dim
        relation_dim = args.relation_dim
        nf_particle = 100#args.nf_particle
        nf_relation = 150 #args.nf_relation
        nf_effect = 150 #args.nf_effect

        self.nf_effect = nf_effect

        self.stat = stat
        self.use_gpu = use_gpu
        self.residual = residual 

        # (1) particle attr (2) state
        self.particle_encoder = ParticleEncoder(attr_dim + state_dim, nf_particle, nf_effect)

        # (1) sender attr (2) receiver attr (3) state receiver (4) state_diff (5) relation attr
        self.relation_encoder = RelationEncoder( 2 * attr_dim + 2 * state_dim + relation_dim, nf_relation, nf_relation)

        # (1) relation encode (2) sender effect (3) receiver effect
        self.relation_propagator = Propagator(nf_relation + 2 * nf_effect, nf_effect)

        # (1) particle encode (2) particle effect
        self.particle_propagator = Propagator(2 * nf_effect, nf_effect)

        # (1) set particle effect
        self.particle_predictor = ParticlePredictor(nf_effect, nf_effect, args.position_dim)
        
        
    def forward(self, attr, state, Rr, Rs, Ra, Values, n_particles, node_r_idx, node_s_idx, pstep,
                instance_idx, phases_dict, verbose=0):
        
        # calculate particle encoding
        if self.use_gpu:
            particle_effect = Variable(torch.zeros((attr.size(0), self.nf_effect)).cuda())
        else:
            particle_effect = Variable(torch.zeros((attr.size(0), self.nf_effect)))
            
        s = 0
        Rrp = Rr[s].t()
        Rsp = Rs[s].t()

        # receiver_attr, sender_attr
        attr_r = attr[node_r_idx[s]]
        attr_s = attr[node_s_idx[s]]
        attr_r_rel = Rrp.mm(attr_r)
        attr_s_rel = Rsp.mm(attr_s)

        # receiver_state, sender_state
        state_r = state[node_r_idx[s]]
        state_s = state[node_s_idx[s]]
        state_r_rel = Rrp.mm(state_r)
        state_s_rel = Rsp.mm(state_s)
        state_diff = state_r_rel - state_s_rel
        
        particle_encode = self.particle_encoder(torch.cat([attr_r, state_r], 1))
        relation_encode = self.relation_encoder(
                torch.cat([attr_r_rel, attr_s_rel, state_r_rel, state_s_rel, Values[s].reshape([-1,1]), Ra[s]], 1))
        
        for i in range(pstep):
            effect_p_r = particle_effect[node_r_idx[s]]
            effect_p_s = particle_effect[node_s_idx[s]]

            receiver_effect = Rrp.mm(effect_p_r)
            sender_effect = Rsp.mm(effect_p_s)
            

            # calculate relation effect
            effect_rel = self.relation_propagator(
                torch.cat([relation_encode, receiver_effect, sender_effect], 1))

            # calculate particle effect by aggregating relation effect
            effect_p_r_agg = Rr[s].mm(effect_rel)

            # calculate particle effect
            effect_p = self.particle_propagator(
                torch.cat([particle_encode, effect_p_r_agg], 1),
                res=effect_p_r)
            particle_effect[node_r_idx[s]] = effect_p
            
        pred = self.particle_predictor(particle_effect[:31])
        return pred

In [21]:
from models import *

class DPINet(nn.Module):
    def __init__(self, args, stat, phases_dict, residual=False, use_gpu=False):
        super(DPINet, self).__init__()

        self.args = args

        state_dim = args.state_dim
        attr_dim = args.attr_dim
        relation_dim = args.relation_dim
        nf_particle = 100#args.nf_particle
        nf_relation = 150 #args.nf_relation
        nf_effect = 150 #args.nf_effect

        self.nf_effect = nf_effect

        self.stat = stat
        self.use_gpu = use_gpu
        self.residual = residual 

        # (1) particle attr (2) state
        self.particle_encoder = ParticleEncoder(attr_dim + state_dim, nf_particle, nf_effect)

        # (1) sender attr (2) receiver attr (3) state receiver (4) state_diff (5) relation attr
        self.relation_encoder = RelationEncoder( 2 * attr_dim + 2 * state_dim + relation_dim, nf_relation, nf_relation)

        # (1) relation encode (2) sender effect (3) receiver effect
        self.relation_propagator = Propagator(nf_relation + 2 * nf_effect, nf_effect)

        # (1) particle encode (2) particle effect
        self.particle_propagator = Propagator(2 * nf_effect, nf_effect)

        # (1) set particle effect
        self.particle_predictor = ParticlePredictor(nf_effect, nf_effect, args.position_dim)
        
        
    def forward(self, attr, state, Rr, Rs, Ra, n_particles, node_r_idx, node_s_idx, pstep,
                instance_idx, phases_dict, verbose=0):
        
        # calculate particle encoding
        if self.use_gpu:
            particle_effect = Variable(torch.zeros((attr.size(0), self.nf_effect)).cuda())
        else:
            particle_effect = Variable(torch.zeros((attr.size(0), self.nf_effect)))
            
        s = 0
        Rrp = Rr[s].t()
        Rsp = Rs[s].t()

        # receiver_attr, sender_attr
        attr_r = attr[node_r_idx[s]]
        attr_s = attr[node_s_idx[s]]
        attr_r_rel = Rrp.mm(attr_r)
        attr_s_rel = Rsp.mm(attr_s)

        # receiver_state, sender_state
        state_r = state[node_r_idx[s]]
        state_s = state[node_s_idx[s]]
        state_r_rel = Rrp.mm(state_r)
        state_s_rel = Rsp.mm(state_s)
        state_diff = state_r_rel - state_s_rel
        
        particle_encode = self.particle_encoder(torch.cat([attr_r, state_r], 1))
        relation_encode = self.relation_encoder(
                torch.cat([attr_r_rel, attr_s_rel, state_r_rel, state_s_rel, Ra[s]], 1))
        
        for i in range(pstep):
            effect_p_r = particle_effect[node_r_idx[s]]
            effect_p_s = particle_effect[node_s_idx[s]]

            receiver_effect = Rrp.mm(effect_p_r)
            sender_effect = Rsp.mm(effect_p_s)
            

            # calculate relation effect
            effect_rel = self.relation_propagator(
                torch.cat([relation_encode, receiver_effect, sender_effect], 1))

            # calculate particle effect by aggregating relation effect
            effect_p_r_agg = Rr[s].mm(effect_rel)

            # calculate particle effect
            effect_p = self.particle_propagator(
                torch.cat([particle_encode, effect_p_r_agg], 1),
                res=effect_p_r)
            particle_effect[node_r_idx[s]] = effect_p
            
        pred = self.particle_predictor(particle_effect[:31])
        return pred

In [41]:
args.pstep = 3
pstep = 3
use_gpu = True
model = DPINet2(args, datasets['train'].stat, phases_dict, residual=True, use_gpu=use_gpu)
print("Number of parameters: %d" % count_parameters(model))
# criterion
criterionMSE = nn.MSELoss()

# optimizer
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.8, patience=3, verbose=True)

if use_gpu:
    model = model.cuda()
    criterionMSE = criterionMSE.cuda()

st_epoch = args.resume_epoch if args.resume_epoch > 0 else 0
best_valid_loss = np.inf

model.train(phase=='train')

Number of parameters: 222353


DPINet2(
  (particle_encoder): ParticleEncoder(
    (model): Sequential(
      (0): Linear(in_features=7, out_features=100, bias=True)
      (1): ReLU()
      (2): Linear(in_features=100, out_features=150, bias=True)
      (3): ReLU()
    )
  )
  (relation_encoder): RelationEncoder(
    (model): Sequential(
      (0): Linear(in_features=16, out_features=150, bias=True)
      (1): ReLU()
      (2): Linear(in_features=150, out_features=150, bias=True)
      (3): ReLU()
      (4): Linear(in_features=150, out_features=150, bias=True)
      (5): ReLU()
    )
  )
  (relation_propagator): Propagator(
    (linear): Linear(in_features=450, out_features=150, bias=True)
    (relu): ReLU()
  )
  (particle_propagator): Propagator(
    (linear): Linear(in_features=300, out_features=150, bias=True)
    (relu): ReLU()
  )
  (particle_predictor): ParticlePredictor(
    (linear_0): Linear(in_features=150, out_features=150, bias=True)
    (linear_1): Linear(in_features=150, out_features=150, bias=True)
 

In [10]:
phase = 'train'
for i, data in enumerate(dataloaders[phase]):
    attr, state, rels, n_particles, n_shapes, label = data
    print ("attr", attr)
    print ("state", state)
    print ("label", label)
    break

Relations neighborRelations neighbor  (60, 3)Relations neighbor(60, 3)

Attr shape (after hierarchy building):  (32, 1)(60, 3)

Object attr:Attr shape (after hierarchy building):  [1.](32, 1)

Attr shape (after hierarchy building):Object attr:  (32, 1)
[1.]Object attr:
 Particle attr: [1.][0.]
Particle attr:
Particle attr:  Shape attr:[0.] [0.]

[1.]Shape attr:Shape attr:
  [1.]Roots attr:
[1.]Roots attr: 
 [0.]Roots attr:[0.]

 Particle positions statsParticle positions stats[0.]


(32, 3)(32, 3)
Particle positions stats[ 0.39366985 -1.797818   -0.21294536]

(32, 3)
[ 0.39366985 -1.728426   -0.21294536]
[-3.53065289 -1.35067389 -0.21294536]
[ 0.39366985  1.72843141 -0.21294536]

[0.39366985 1.72843141 2.5299532 ]
[-1.51864847  0.1531911   1.14925076]
[1.20698741 0.91447686 0.78803271]
[ 0.39366985  1.72843141 -0.21294536][ 0.39366985 -0.00682503 -0.21294536]
Velocities stats[1.66533454e-16 1.03047548e+00 1.66533454e-16]
Velocities stats

(32, 3)
(32, 3)

[ 0.39366985 -0.05010454 -0.21

In [12]:
Ra, node_r_idx, node_s_idx, pstep = rels[3], rels[4], rels[5], rels[6]

In [13]:
for j in range(len(rels[0])):
    Rr_idx, Rs_idx, values = rels[0][j], rels[1][j], rels[2][j]

In [14]:
values

tensor([0.1097, 0.1093, 0.1104, 0.1099, 0.1098, 0.1097, 0.1095, 0.1094, 0.1093,
        0.1091, 0.1090, 0.1089, 0.1088, 0.1087, 0.1086, 0.1085, 0.1084, 0.1083,
        0.1082, 0.1082, 0.1081, 0.1080, 0.1080, 0.1079, 0.1078, 0.1078, 0.1077,
        0.1076, 0.1076, 0.1075, 0.1097, 0.1093, 0.1104, 0.1099, 0.1098, 0.1097,
        0.1095, 0.1094, 0.1093, 0.1091, 0.1090, 0.1089, 0.1088, 0.1087, 0.1086,
        0.1085, 0.1084, 0.1083, 0.1082, 0.1082, 0.1081, 0.1080, 0.1080, 0.1079,
        0.1078, 0.1078, 0.1077, 0.1076, 0.1076, 0.1075])

In [None]:
phase = 'train'
instance_idx = [0, 31]
args.n_epoch = 500
psteps = 3

for epoch in range(args.n_epoch):
    model.train(phase=='train')

    losses = 0.
    for i, data in enumerate(dataloaders[phase]):
#         print ('i:',i)

        attr, state, rels, n_particles, n_shapes, label = data
        Ra, node_r_idx, node_s_idx, pstep = rels[3], rels[4], rels[5], rels[6]
        
        '''
        Rr, Rs = [], []
        for j in range(len(rels[0])):
            Rr_idx, Rs_idx, values = rels[0][j], rels[1][j], rels[2][j]
            Rr.append(torch.sparse.FloatTensor(
                Rr_idx, values, torch.Size([node_r_idx[j].shape[0], Ra[j].size(0)])))
            Rs.append(torch.sparse.FloatTensor(
                Rs_idx, values, torch.Size([node_s_idx[j].shape[0], Ra[j].size(0)])))

        data = [attr, state, Rr, Rs, Ra, label]
        '''
        Rr, Rs = [], []
        Values = []

        for j in range(len(rels[0])):
            Rr_idx, Rs_idx, values = rels[0][j], rels[1][j], rels[2][j]
            V = torch.ones(values.shape)
            Values.append(values)
            Rr.append(torch.sparse.FloatTensor(
                Rr_idx, V, torch.Size([node_r_idx[j].shape[0], Ra[j].size(0)])))
            Rs.append(torch.sparse.FloatTensor(
                Rs_idx, V, torch.Size([node_s_idx[j].shape[0], Ra[j].size(0)])))
        data = [attr, state, Rr, Rs, Ra, Values, label]


            # st_time = time.time()
        with torch.set_grad_enabled(phase=='train'):
            if use_gpu:
                for d in range(len(data)):
                    if type(data[d]) == list:
                        for t in range(len(data[d])):
                            data[d][t] = Variable(data[d][t].cuda())
                    else:
                        data[d] = Variable(data[d].cuda())
            else:
                for d in range(len(data)):
                    if type(data[d]) == list:
                        for t in range(len(data[d])):
                            data[d][t] = Variable(data[d][t])
                    else:
                        data[d] = Variable(data[d])

            attr, state, Rr, Rs, Ra, Values, label = data
            
            pstep = 3

            predicted = model(
                attr, state, Rr, Rs, Ra, Values, n_particles,
                node_r_idx, node_s_idx, pstep,
                instance_idx, phases_dict, 0)
            # print('Time forward', time.time() - st_time)

       #     print(predicted.shape)
       #     print(label.shape)

        loss = criterionMSE(predicted, label)
        losses += np.sqrt(loss.item())

        if phase == 'train':
            if i % 5 == 0:
                # update parameters every args.forward_times
          #      print ('update!')
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
          #      print ('done')


        if i % 5000 == 0:
            n_relations = 0
            for j in range(len(Ra)):
                n_relations += Ra[j].size(0)
            print('%s [%d/%d][%d/%d] n_relations: %d, Loss: %.6f, Agg: %.6f' %
                  (phase, epoch, args.n_epoch, i, len(dataloaders[phase]),
                   n_relations, np.sqrt(loss.item()), losses / (i + 1)))

      #  if phase == 'train' and i > 0 and i % args.ckp_per_iter == 0:
    if epoch % 10 == 0 :
        torch.save(model.state_dict(), '%s/DPINet2_epoch_%d.pth' % (args.outf, epoch))

    losses /= len(dataloaders[phase])
    print('%s [%d/%d] Loss: %.4f' %
          (phase, epoch, args.n_epoch, losses))

train [0/500][0/26955] n_relations: 62, Loss: 1.243727, Agg: 1.243727
train [0/500][5000/26955] n_relations: 60, Loss: 0.290006, Agg: 0.416934
train [0/500][10000/26955] n_relations: 60, Loss: 0.084720, Agg: 0.330442
train [0/500][15000/26955] n_relations: 60, Loss: 0.284340, Agg: 0.292242
train [0/500][20000/26955] n_relations: 60, Loss: 0.113742, Agg: 0.270082
train [0/500][25000/26955] n_relations: 60, Loss: 0.201741, Agg: 0.254929
train [0/500] Loss: 0.2500
train [1/500][0/26955] n_relations: 60, Loss: 0.311449, Agg: 0.311449
train [1/500][5000/26955] n_relations: 60, Loss: 0.075497, Agg: 0.189254
train [1/500][10000/26955] n_relations: 62, Loss: 0.203056, Agg: 0.187008
train [1/500][15000/26955] n_relations: 60, Loss: 0.242343, Agg: 0.183270
train [1/500][20000/26955] n_relations: 60, Loss: 0.322139, Agg: 0.181136
train [1/500][25000/26955] n_relations: 62, Loss: 0.129658, Agg: 0.179342
train [1/500] Loss: 0.1789
train [2/500][0/26955] n_relations: 62, Loss: 0.169144, Agg: 0.16914

train [17/500][20000/26955] n_relations: 60, Loss: 0.253782, Agg: 0.095544
train [17/500][25000/26955] n_relations: 62, Loss: 0.064974, Agg: 0.095276
train [17/500] Loss: 0.0957
train [18/500][0/26955] n_relations: 60, Loss: 0.398861, Agg: 0.398861
train [18/500][5000/26955] n_relations: 60, Loss: 0.203829, Agg: 0.094305
train [18/500][10000/26955] n_relations: 60, Loss: 0.084420, Agg: 0.094757
train [18/500][15000/26955] n_relations: 60, Loss: 0.042291, Agg: 0.093893
train [18/500][20000/26955] n_relations: 62, Loss: 0.063040, Agg: 0.094722
train [18/500][25000/26955] n_relations: 60, Loss: 0.043475, Agg: 0.093892
train [18/500] Loss: 0.0936
train [19/500][0/26955] n_relations: 62, Loss: 0.077937, Agg: 0.077937
train [19/500][5000/26955] n_relations: 60, Loss: 0.045727, Agg: 0.093145
train [19/500][10000/26955] n_relations: 60, Loss: 0.287035, Agg: 0.097234
train [19/500][15000/26955] n_relations: 62, Loss: 0.129779, Agg: 0.095882
train [19/500][20000/26955] n_relations: 62, Loss: 0.0

train [35/500][0/26955] n_relations: 62, Loss: 0.044710, Agg: 0.044710
train [35/500][5000/26955] n_relations: 62, Loss: 0.044979, Agg: 0.080887
train [35/500][10000/26955] n_relations: 60, Loss: 0.027383, Agg: 0.077277
train [35/500][15000/26955] n_relations: 60, Loss: 0.213023, Agg: 0.077033
train [35/500][20000/26955] n_relations: 60, Loss: 0.114183, Agg: 0.076299
train [35/500][25000/26955] n_relations: 60, Loss: 0.038596, Agg: 0.075611
train [35/500] Loss: 0.0757
train [36/500][0/26955] n_relations: 60, Loss: 0.045826, Agg: 0.045826
train [36/500][5000/26955] n_relations: 60, Loss: 0.032341, Agg: 0.079709
train [36/500][10000/26955] n_relations: 60, Loss: 0.033588, Agg: 0.078184
train [36/500][15000/26955] n_relations: 60, Loss: 0.139470, Agg: 0.078106
train [36/500][20000/26955] n_relations: 60, Loss: 0.070638, Agg: 0.077045
train [36/500][25000/26955] n_relations: 62, Loss: 0.054146, Agg: 0.076592
train [36/500] Loss: 0.0762
train [37/500][0/26955] n_relations: 60, Loss: 0.04787

train [52/500][15000/26955] n_relations: 62, Loss: 0.039414, Agg: 0.064418
train [52/500][20000/26955] n_relations: 60, Loss: 0.020232, Agg: 0.065055
train [52/500][25000/26955] n_relations: 62, Loss: 0.032849, Agg: 0.064704
train [52/500] Loss: 0.0648
train [53/500][0/26955] n_relations: 60, Loss: 0.043366, Agg: 0.043366
train [53/500][5000/26955] n_relations: 60, Loss: 0.070238, Agg: 0.064291
train [53/500][10000/26955] n_relations: 62, Loss: 0.038079, Agg: 0.066060
train [53/500][15000/26955] n_relations: 60, Loss: 0.066210, Agg: 0.065903
train [53/500][20000/26955] n_relations: 60, Loss: 0.168970, Agg: 0.065179
train [53/500][25000/26955] n_relations: 60, Loss: 0.172400, Agg: 0.065043
train [53/500] Loss: 0.0650
train [54/500][0/26955] n_relations: 62, Loss: 0.049428, Agg: 0.049428
train [54/500][5000/26955] n_relations: 61, Loss: 0.018727, Agg: 0.064794
train [54/500][10000/26955] n_relations: 62, Loss: 0.033700, Agg: 0.064794
train [54/500][15000/26955] n_relations: 60, Loss: 0.0

train [69/500] Loss: 0.0614
train [70/500][0/26955] n_relations: 60, Loss: 0.221725, Agg: 0.221725
train [70/500][5000/26955] n_relations: 62, Loss: 0.037041, Agg: 0.058634
train [70/500][10000/26955] n_relations: 62, Loss: 0.025179, Agg: 0.059944
train [70/500][15000/26955] n_relations: 61, Loss: 0.052782, Agg: 0.060839
train [70/500][20000/26955] n_relations: 62, Loss: 0.036139, Agg: 0.060328
train [70/500][25000/26955] n_relations: 60, Loss: 0.031018, Agg: 0.060457
train [70/500] Loss: 0.0608
train [71/500][0/26955] n_relations: 62, Loss: 0.023650, Agg: 0.023650
train [71/500][5000/26955] n_relations: 62, Loss: 0.028776, Agg: 0.059446
train [71/500][10000/26955] n_relations: 61, Loss: 0.029073, Agg: 0.059830
train [71/500][15000/26955] n_relations: 60, Loss: 0.106853, Agg: 0.060497
train [71/500][20000/26955] n_relations: 60, Loss: 0.019886, Agg: 0.060409
train [71/500][25000/26955] n_relations: 60, Loss: 0.242241, Agg: 0.060455
train [71/500] Loss: 0.0604
train [72/500][0/26955] n_

train [87/500][10000/26955] n_relations: 60, Loss: 0.017583, Agg: 0.059264
train [87/500][15000/26955] n_relations: 62, Loss: 0.089049, Agg: 0.058390
train [87/500][20000/26955] n_relations: 60, Loss: 0.158759, Agg: 0.058337
train [87/500][25000/26955] n_relations: 60, Loss: 0.166332, Agg: 0.058173
train [87/500] Loss: 0.0583
train [88/500][0/26955] n_relations: 60, Loss: 0.022304, Agg: 0.022304
train [88/500][5000/26955] n_relations: 60, Loss: 0.189249, Agg: 0.055623
train [88/500][10000/26955] n_relations: 60, Loss: 0.021881, Agg: 0.056717
train [88/500][15000/26955] n_relations: 62, Loss: 0.029683, Agg: 0.057984
train [88/500][20000/26955] n_relations: 60, Loss: 0.024740, Agg: 0.058427
train [88/500][25000/26955] n_relations: 60, Loss: 0.029886, Agg: 0.058682
train [88/500] Loss: 0.0584
train [89/500][0/26955] n_relations: 62, Loss: 0.028717, Agg: 0.028717
train [89/500][5000/26955] n_relations: 62, Loss: 0.101012, Agg: 0.057766
train [89/500][10000/26955] n_relations: 62, Loss: 0.0

train [104/500][20000/26955] n_relations: 60, Loss: 0.020059, Agg: 0.055981
train [104/500][25000/26955] n_relations: 60, Loss: 0.053927, Agg: 0.056098
train [104/500] Loss: 0.0563
train [105/500][0/26955] n_relations: 60, Loss: 0.165756, Agg: 0.165756
train [105/500][5000/26955] n_relations: 60, Loss: 0.291647, Agg: 0.055590
train [105/500][10000/26955] n_relations: 60, Loss: 0.021151, Agg: 0.054615
train [105/500][15000/26955] n_relations: 62, Loss: 0.026808, Agg: 0.055216
train [105/500][20000/26955] n_relations: 61, Loss: 0.026570, Agg: 0.055293
train [105/500][25000/26955] n_relations: 61, Loss: 0.030680, Agg: 0.055538
train [105/500] Loss: 0.0555
train [106/500][0/26955] n_relations: 60, Loss: 0.019114, Agg: 0.019114
train [106/500][5000/26955] n_relations: 60, Loss: 0.228044, Agg: 0.056228
train [106/500][10000/26955] n_relations: 60, Loss: 0.026788, Agg: 0.057151
train [106/500][15000/26955] n_relations: 62, Loss: 0.033556, Agg: 0.055653
train [106/500][20000/26955] n_relations

train [121/500][25000/26955] n_relations: 60, Loss: 0.136736, Agg: 0.053848
train [121/500] Loss: 0.0538
train [122/500][0/26955] n_relations: 62, Loss: 0.019280, Agg: 0.019280
train [122/500][5000/26955] n_relations: 60, Loss: 0.033211, Agg: 0.053824
train [122/500][10000/26955] n_relations: 60, Loss: 0.042274, Agg: 0.053800
train [122/500][15000/26955] n_relations: 60, Loss: 0.015130, Agg: 0.053019
train [122/500][20000/26955] n_relations: 60, Loss: 0.011021, Agg: 0.052380
train [122/500][25000/26955] n_relations: 61, Loss: 0.014242, Agg: 0.052733
train [122/500] Loss: 0.0529
train [123/500][0/26955] n_relations: 61, Loss: 0.029007, Agg: 0.029007
train [123/500][5000/26955] n_relations: 61, Loss: 0.016425, Agg: 0.054461
train [123/500][10000/26955] n_relations: 60, Loss: 0.019086, Agg: 0.053739
train [123/500][15000/26955] n_relations: 60, Loss: 0.021916, Agg: 0.053764
train [123/500][20000/26955] n_relations: 60, Loss: 0.016747, Agg: 0.053298
train [123/500][25000/26955] n_relations

train [138/500] Loss: 0.0527
train [139/500][0/26955] n_relations: 60, Loss: 0.059305, Agg: 0.059305
train [139/500][5000/26955] n_relations: 61, Loss: 0.031060, Agg: 0.053244
train [139/500][10000/26955] n_relations: 62, Loss: 0.052701, Agg: 0.052671
train [139/500][15000/26955] n_relations: 62, Loss: 0.021914, Agg: 0.052314
train [139/500][20000/26955] n_relations: 62, Loss: 0.022933, Agg: 0.052466
train [139/500][25000/26955] n_relations: 60, Loss: 0.013731, Agg: 0.052461
train [139/500] Loss: 0.0523
train [140/500][0/26955] n_relations: 60, Loss: 0.025632, Agg: 0.025632
train [140/500][5000/26955] n_relations: 60, Loss: 0.071388, Agg: 0.054461
train [140/500][10000/26955] n_relations: 62, Loss: 0.017808, Agg: 0.054388
train [140/500][15000/26955] n_relations: 60, Loss: 0.176714, Agg: 0.053593
train [140/500][20000/26955] n_relations: 62, Loss: 0.026602, Agg: 0.053348
train [140/500][25000/26955] n_relations: 61, Loss: 0.036878, Agg: 0.052842
train [140/500] Loss: 0.0529
train [141/

train [156/500][5000/26955] n_relations: 61, Loss: 0.033789, Agg: 0.054444
train [156/500][10000/26955] n_relations: 61, Loss: 0.065897, Agg: 0.051025
train [156/500][15000/26955] n_relations: 62, Loss: 0.122191, Agg: 0.051002
train [156/500][20000/26955] n_relations: 62, Loss: 0.011418, Agg: 0.051057
train [156/500][25000/26955] n_relations: 60, Loss: 0.027665, Agg: 0.050978
train [156/500] Loss: 0.0511
train [157/500][0/26955] n_relations: 60, Loss: 0.015665, Agg: 0.015665
train [157/500][5000/26955] n_relations: 60, Loss: 0.044558, Agg: 0.050721
train [157/500][10000/26955] n_relations: 62, Loss: 0.025860, Agg: 0.050927
train [157/500][15000/26955] n_relations: 60, Loss: 0.013451, Agg: 0.051265
train [157/500][20000/26955] n_relations: 61, Loss: 0.041829, Agg: 0.051029
train [157/500][25000/26955] n_relations: 60, Loss: 0.025657, Agg: 0.051012
train [157/500] Loss: 0.0509
train [158/500][0/26955] n_relations: 60, Loss: 0.012194, Agg: 0.012194
train [158/500][5000/26955] n_relations:

train [173/500][10000/26955] n_relations: 62, Loss: 0.010856, Agg: 0.049457
train [173/500][15000/26955] n_relations: 60, Loss: 0.072271, Agg: 0.049512
train [173/500][20000/26955] n_relations: 60, Loss: 0.017280, Agg: 0.049840
train [173/500][25000/26955] n_relations: 60, Loss: 0.025823, Agg: 0.049592
train [173/500] Loss: 0.0498
train [174/500][0/26955] n_relations: 62, Loss: 0.019328, Agg: 0.019328
train [174/500][5000/26955] n_relations: 60, Loss: 0.014603, Agg: 0.048463
train [174/500][10000/26955] n_relations: 60, Loss: 0.041879, Agg: 0.050110
train [174/500][15000/26955] n_relations: 60, Loss: 0.028683, Agg: 0.050338
train [174/500][20000/26955] n_relations: 60, Loss: 0.018394, Agg: 0.050053
train [174/500][25000/26955] n_relations: 60, Loss: 0.031273, Agg: 0.050112
train [174/500] Loss: 0.0502
train [175/500][0/26955] n_relations: 60, Loss: 0.026372, Agg: 0.026372
train [175/500][5000/26955] n_relations: 62, Loss: 0.016370, Agg: 0.049594
train [175/500][10000/26955] n_relations

train [190/500][15000/26955] n_relations: 60, Loss: 0.018101, Agg: 0.049870
train [190/500][20000/26955] n_relations: 60, Loss: 0.198587, Agg: 0.049688
train [190/500][25000/26955] n_relations: 60, Loss: 0.016194, Agg: 0.049833
train [190/500] Loss: 0.0495
train [191/500][0/26955] n_relations: 60, Loss: 0.029935, Agg: 0.029935
train [191/500][5000/26955] n_relations: 60, Loss: 0.302958, Agg: 0.045666
train [191/500][10000/26955] n_relations: 62, Loss: 0.017773, Agg: 0.046996
train [191/500][15000/26955] n_relations: 61, Loss: 0.020640, Agg: 0.047349
train [191/500][20000/26955] n_relations: 60, Loss: 0.063414, Agg: 0.048311
train [191/500][25000/26955] n_relations: 60, Loss: 0.024963, Agg: 0.048749
train [191/500] Loss: 0.0489
train [192/500][0/26955] n_relations: 61, Loss: 0.042458, Agg: 0.042458
train [192/500][5000/26955] n_relations: 61, Loss: 0.015408, Agg: 0.048870
train [192/500][10000/26955] n_relations: 60, Loss: 0.011963, Agg: 0.049527
train [192/500][15000/26955] n_relations

train [207/500][20000/26955] n_relations: 60, Loss: 0.149560, Agg: 0.047715
train [207/500][25000/26955] n_relations: 60, Loss: 0.013261, Agg: 0.048313
train [207/500] Loss: 0.0482
train [208/500][0/26955] n_relations: 62, Loss: 0.017130, Agg: 0.017130
train [208/500][5000/26955] n_relations: 60, Loss: 0.103731, Agg: 0.047820
train [208/500][10000/26955] n_relations: 60, Loss: 0.038398, Agg: 0.047865
train [208/500][15000/26955] n_relations: 61, Loss: 0.013141, Agg: 0.048553
train [208/500][20000/26955] n_relations: 60, Loss: 0.022888, Agg: 0.048039
train [208/500][25000/26955] n_relations: 61, Loss: 0.009812, Agg: 0.048200
train [208/500] Loss: 0.0481
train [209/500][0/26955] n_relations: 60, Loss: 0.020995, Agg: 0.020995
train [209/500][5000/26955] n_relations: 60, Loss: 0.026082, Agg: 0.047248
train [209/500][10000/26955] n_relations: 60, Loss: 0.285039, Agg: 0.047562
train [209/500][15000/26955] n_relations: 62, Loss: 0.022292, Agg: 0.048268
train [209/500][20000/26955] n_relations

train [224/500][25000/26955] n_relations: 60, Loss: 0.014354, Agg: 0.045711
train [224/500] Loss: 0.0456
train [225/500][0/26955] n_relations: 60, Loss: 0.014746, Agg: 0.014746
train [225/500][5000/26955] n_relations: 62, Loss: 0.011614, Agg: 0.045805
train [225/500][10000/26955] n_relations: 60, Loss: 0.022315, Agg: 0.047090
train [225/500][15000/26955] n_relations: 60, Loss: 0.013232, Agg: 0.046114
train [225/500][20000/26955] n_relations: 60, Loss: 0.246915, Agg: 0.045683
train [225/500][25000/26955] n_relations: 62, Loss: 0.014047, Agg: 0.045295
train [225/500] Loss: 0.0453
train [226/500][0/26955] n_relations: 60, Loss: 0.015777, Agg: 0.015777
train [226/500][5000/26955] n_relations: 61, Loss: 0.017777, Agg: 0.045979
train [226/500][10000/26955] n_relations: 62, Loss: 0.044132, Agg: 0.045249
train [226/500][15000/26955] n_relations: 60, Loss: 0.011730, Agg: 0.044852
train [226/500][20000/26955] n_relations: 60, Loss: 0.122309, Agg: 0.044917
train [226/500][25000/26955] n_relations

train [241/500] Loss: 0.0407
train [242/500][0/26955] n_relations: 62, Loss: 0.030691, Agg: 0.030691
train [242/500][5000/26955] n_relations: 60, Loss: 0.014163, Agg: 0.041095
train [242/500][10000/26955] n_relations: 60, Loss: 0.071802, Agg: 0.040874
train [242/500][15000/26955] n_relations: 60, Loss: 0.011595, Agg: 0.039946
train [242/500][20000/26955] n_relations: 62, Loss: 0.014045, Agg: 0.040295
train [242/500][25000/26955] n_relations: 60, Loss: 0.023333, Agg: 0.040443
train [242/500] Loss: 0.0402
train [243/500][0/26955] n_relations: 62, Loss: 0.028820, Agg: 0.028820
train [243/500][5000/26955] n_relations: 61, Loss: 0.017726, Agg: 0.040166
train [243/500][10000/26955] n_relations: 60, Loss: 0.029863, Agg: 0.039157
train [243/500][15000/26955] n_relations: 60, Loss: 0.028241, Agg: 0.040455
train [243/500][20000/26955] n_relations: 62, Loss: 0.021627, Agg: 0.041077
train [243/500][25000/26955] n_relations: 60, Loss: 0.187487, Agg: 0.041244
train [243/500] Loss: 0.0412
train [244/

train [259/500][5000/26955] n_relations: 60, Loss: 0.016429, Agg: 0.038058
train [259/500][10000/26955] n_relations: 60, Loss: 0.027463, Agg: 0.034754
train [259/500][15000/26955] n_relations: 60, Loss: 0.074908, Agg: 0.033298
train [259/500][20000/26955] n_relations: 60, Loss: 0.121217, Agg: 0.033631
train [259/500][25000/26955] n_relations: 60, Loss: 0.179425, Agg: 0.033481
train [259/500] Loss: 0.0331
train [260/500][0/26955] n_relations: 60, Loss: 0.008222, Agg: 0.008222
train [260/500][5000/26955] n_relations: 62, Loss: 0.024431, Agg: 0.030455
train [260/500][10000/26955] n_relations: 60, Loss: 0.014171, Agg: 0.031365
train [260/500][15000/26955] n_relations: 60, Loss: 0.134920, Agg: 0.031723
train [260/500][20000/26955] n_relations: 62, Loss: 0.023597, Agg: 0.031097
train [260/500][25000/26955] n_relations: 60, Loss: 0.013950, Agg: 0.031103
train [260/500] Loss: 0.0312
train [261/500][0/26955] n_relations: 62, Loss: 0.018453, Agg: 0.018453
train [261/500][5000/26955] n_relations:

train [276/500][10000/26955] n_relations: 62, Loss: 0.014292, Agg: 0.026213
train [276/500][15000/26955] n_relations: 60, Loss: 0.093754, Agg: 0.027452
train [276/500][20000/26955] n_relations: 62, Loss: 0.036281, Agg: 0.027184
train [276/500][25000/26955] n_relations: 62, Loss: 0.012465, Agg: 0.026244
train [276/500] Loss: 0.0268
train [277/500][0/26955] n_relations: 61, Loss: 0.023160, Agg: 0.023160
train [277/500][5000/26955] n_relations: 60, Loss: 0.029631, Agg: 0.029271
train [277/500][10000/26955] n_relations: 61, Loss: 0.016560, Agg: 0.026618
train [277/500][15000/26955] n_relations: 60, Loss: 0.059414, Agg: 0.028059
train [277/500][20000/26955] n_relations: 60, Loss: 0.031333, Agg: 0.027375
train [277/500][25000/26955] n_relations: 62, Loss: 0.019563, Agg: 0.028245
train [277/500] Loss: 0.0281
train [278/500][0/26955] n_relations: 62, Loss: 0.009077, Agg: 0.009077
train [278/500][5000/26955] n_relations: 60, Loss: 0.015813, Agg: 0.023002
train [278/500][10000/26955] n_relations

## Evaluation

In [39]:
stat_path = stat_path = os.path.join(args.dataf, 'stat.h5')
stat = load_data(data_names[:2], stat_path)

In [40]:
use_gpu = torch.cuda.is_available()
model = IntNet(args, stat, phases_dict, residual=True, use_gpu=use_gpu)

In [41]:
model_file = os.path.join(args.outf, 'IntNet_epoch_%d.pth' % (300))

In [42]:
print("Loading network from %s" % model_file)
model.load_state_dict(torch.load(model_file))
model.eval()

criterionMSE = nn.MSELoss()

if use_gpu:
    model.cuda()

Loading network from dump_SingleHair/files_SingleHair/IntNet_epoch_300.pth


In [43]:
idx = 0



for step in range(300 - 1):
    data_path = os.path.join(args.dataf, 'valid', str(idx), str(step) + '.h5')
    data_nxt_path = os.path.join(args.dataf, 'valid', str(idx), str(step + 1) + '.h5')

    data = load_data(data_names, data_path)
    data_nxt = load_data(data_names, data_nxt_path)
    velocities_nxt = data_nxt[1]

    if step == 0:
        positions, velocities, hairs_idx = data
        n_shapes = 1
        scene_params = 11
        count_nodes = positions.shape[0]
        n_particles = count_nodes - n_shapes
        p_gt = np.zeros((args.time_step - 1, n_particles + n_shapes, args.position_dim))
        s_gt = np.zeros((args.time_step - 1, n_shapes, args.shape_state_dim))
        v_nxt_gt = np.zeros((args.time_step - 1, n_particles + n_shapes, args.position_dim))

        p_pred = np.zeros((args.time_step - 1, n_particles + n_shapes, args.position_dim))

    p_gt[step] = positions[:, -args.position_dim:]
    v_nxt_gt[step] = velocities_nxt[:, -args.position_dim:]
    
    s_gt[step, :, :3] = positions[n_particles:, :3]
    s_gt[step, :, 3:6] = p_gt[max(0, step-1), n_particles:, :3]
    s_gt[step, :, 6:] = np.array( [[0.,0.70710677,0,0.70710677,0,0.70710677, 0, 0.70710677]])
    positions = positions + velocities_nxt * args.dt


In [44]:
step = 0
mass = np.zeros((n_particles, 1))
p = np.concatenate([p_gt[step, :n_particles], mass], 1)
p

array([[0.        , 4.        , 0.        , 0.        ],
       [0.        , 3.88348913, 0.        , 0.        ],
       [0.        , 3.76832294, 0.        , 0.        ],
       [0.        , 3.64844704, 0.        , 0.        ],
       [0.        , 3.53010917, 0.        , 0.        ],
       [0.        , 3.4117341 , 0.        , 0.        ],
       [0.        , 3.29377675, 0.        , 0.        ],
       [0.        , 3.17614532, 0.        , 0.        ],
       [0.        , 3.05889988, 0.        , 0.        ],
       [0.        , 2.94205379, 0.        , 0.        ],
       [0.        , 2.82562876, 0.        , 0.        ],
       [0.        , 2.70964026, 0.        , 0.        ],
       [0.        , 2.59410095, 0.        , 0.        ],
       [0.        , 2.47902036, 0.        , 0.        ],
       [0.        , 2.36440516, 0.        , 0.        ],
       [0.        , 2.25025892, 0.        , 0.        ],
       [0.        , 2.13658237, 0.        , 0.        ],
       [0.        , 2.02337432,

In [9]:
import pyflex
pyflex.init()

cap_size = [0.1,1.5]
N_hairs = 1
env_idx = 12
scene_params = np.array(cap_size)
pyflex.set_scene(env_idx, scene_params, 0)

for step in range(args.time_step - 1):
    pyflex.set_shape_states(s_gt[step])

    mass = np.zeros((n_particles, 1))
    p = np.concatenate([p_gt[step, :n_particles], mass], 1)

    pyflex.set_positions(p)
    pyflex.render()
    
pyflex.clean()

In [50]:
idx = 0
data_path = os.path.join(args.dataf, 'valid', str(idx), '0.h5')
data = load_data(data_names, data_path)
instance_idx = [0, 31]
p_pred = np.zeros((args.time_step - 1, n_particles, args.position_dim))

In [51]:
p_pred.shape

(599, 31, 3)

In [47]:
data[0].shape

(32, 3)

In [52]:
for step in range(args.time_step - 1):
    p_pred[step] = data[0][:n_particles]
    attr, state, rels, n_particles, n_shapes = prepare_input(data, stat, args, phases_dict, 0)
    Ra, node_r_idx, node_s_idx, pstep = rels[3], rels[4], rels[5], rels[6]
    
    '''
    Rr, Rs = [], []
    for j in range(len(rels[0])):
        Rr_idx, Rs_idx, values = rels[0][j], rels[1][j], rels[2][j]
        Rr.append(torch.sparse.FloatTensor(
            Rr_idx, values, torch.Size([node_r_idx[j].shape[0], Ra[j].size(0)])))
        Rs.append(torch.sparse.FloatTensor(
            Rs_idx, values, torch.Size([node_s_idx[j].shape[0], Ra[j].size(0)])))

    buf = [attr, state, Rr, Rs, Ra]
    '''
    
    Rr, Rs = [], []
    Values = []

    for j in range(len(rels[0])):
        Rr_idx, Rs_idx, values = rels[0][j], rels[1][j], rels[2][j]
        V = torch.ones(values.shape)
        Values.append(values)
        Rr.append(torch.sparse.FloatTensor(
            Rr_idx, V, torch.Size([node_r_idx[j].shape[0], Ra[j].size(0)])))
        Rs.append(torch.sparse.FloatTensor(
            Rs_idx, V, torch.Size([node_s_idx[j].shape[0], Ra[j].size(0)])))
    buf = [attr, state, Rr, Rs, Ra, Values]



        # st_time = time.time()
    with torch.set_grad_enabled(phase=='train'):
        if use_gpu:
            for d in range(len(buf)):
                if type(buf[d]) == list:
                    for t in range(len(buf[d])):
                        buf[d][t] = Variable(buf[d][t].cuda())
                else:
                    buf[d] = Variable(buf[d].cuda())
        else:
            for d in range(len(buf)):
                if type(buf[d]) == list:
                    for t in range(len(buf[d])):
                        buf[d][t] = Variable(buf[d][t])
                else:
                    buf[d] = Variable(buf[d])

        attr, state, Rr, Rs, Ra, Values = buf

        pstep = 3

        predicted = model(
            attr, state, Rr, Rs, Ra,Values, n_particles,
            node_r_idx, node_s_idx, pstep,
            instance_idx, phases_dict, 0)
        
        vels = denormalize([predicted.data.cpu().numpy()], [stat[1]])[0]
        data[0][:n_particles] += vels * args.dt
        data[1][:n_particles] = vels

In [59]:
p_pred[150]

array([[-1.07669001e-02,  4.04614782e+00, -3.90244204e-05],
       [-6.76894039e-02,  4.22975588e+00,  1.81299620e-05],
       [-2.97123622e-02,  3.93414545e+00, -3.90244204e-05],
       [-6.85890540e-02,  4.03060198e+00, -2.39675301e-05],
       [-2.74079274e-02,  3.82771707e+00, -3.67006214e-05],
       [ 9.85515217e-05,  4.01023293e+00, -9.90541776e-06],
       [ 1.20577678e-01,  3.79373288e+00, -3.73473667e-05],
       [ 3.81798476e-01,  4.10228491e+00,  1.53673245e-04],
       [ 1.09315419e+00,  4.84573030e+00,  4.43061406e-04],
       [ 1.03009713e+00,  4.75158024e+00,  4.04473249e-04],
       [ 1.00497949e+00,  4.72040558e+00,  2.42982758e-04],
       [ 9.34327304e-01,  4.62511778e+00, -3.90244204e-05],
       [ 8.58312905e-01,  4.53815174e+00, -3.90244204e-05],
       [ 7.84302413e-01,  4.45870733e+00, -3.90244204e-05],
       [ 7.21382141e-01,  4.38916063e+00, -3.90244204e-05],
       [ 6.35276198e-01,  4.31076717e+00, -3.90244204e-05],
       [ 3.54287386e-01,  3.96515274e+00

In [None]:
import pyflex
pyflex.init()

recs = []

for idx in range(3):

    print("Rollout %d / %d" % (idx, 3))

    # ground truth
    for step in range(args.time_step - 1):
        data_path = os.path.join(args.dataf, 'valid', str(infos[idx]), str(step) + '.h5')
        data_nxt_path = os.path.join(args.dataf, 'valid', str(infos[idx]), str(step + 1) + '.h5')

        data = load_data(data_names, data_path)
        data_nxt = load_data(data_names, data_nxt_path)
        velocities_nxt = data_nxt[1]

        if step == 0:
            if args.env == 'BoxBath':
                positions, velocities, clusters = data
                n_shapes = 0
                scene_params = np.zeros(1)
            elif args.env == 'FluidFall':
                positions, velocities = data
                n_shapes = 0
                scene_params = np.zeros(1)
            elif args.env == 'RiceGrip':
                positions, velocities, shape_quats, clusters, scene_params = data
                n_shapes = shape_quats.shape[0]
            elif args.env == 'FluidShake':
                positions, velocities, shape_quats, scene_params = data
                n_shapes = shape_quats.shape[0]
            else:
                raise AssertionError("Unsupported env")

            count_nodes = positions.shape[0]
            n_particles = count_nodes - n_shapes
            print("n_particles", n_particles)
            print("n_shapes", n_shapes)

            p_gt = np.zeros((args.time_step - 1, n_particles + n_shapes, args.position_dim))
            s_gt = np.zeros((args.time_step - 1, n_shapes, args.shape_state_dim))
            v_nxt_gt = np.zeros((args.time_step - 1, n_particles + n_shapes, args.position_dim))

            p_pred = np.zeros((args.time_step - 1, n_particles + n_shapes, args.position_dim))

        p_gt[step] = positions[:, -args.position_dim:]
        v_nxt_gt[step] = velocities_nxt[:, -args.position_dim:]

        # print(step, np.sum(np.abs(v_nxt_gt[step, :args.n_particles])))

        if args.env == 'RiceGrip' or args.env == 'FluidShake':
            s_gt[step, :, :3] = positions[n_particles:, :3]
            s_gt[step, :, 3:6] = p_gt[max(0, step-1), n_particles:, :3]
            s_gt[step, :, 6:10] = data[2]
            s_gt[step, :, 10:] = data[2]

        positions = positions + velocities_nxt * args.dt

    # model rollout
    data_path = os.path.join(args.dataf, 'valid', str(infos[idx]), '0.h5')
    data = load_data(data_names, data_path)

    for step in range(args.time_step - 1):
        if step % 10 == 0:
            print("Step %d / %d" % (step, args.time_step - 1))

        p_pred[step] = data[0]

        if args.env == 'RiceGrip' and step == 0:
            data[0] = p_gt[step + 1].copy()
            data[1] = np.concatenate([v_nxt_gt[step]] * (args.n_his + 1), 1)
            continue

        # st_time = time.time()
        attr, state, rels, n_particles, n_shapes, instance_idx = \
                prepare_input(data, stat, args, phases_dict, args.verbose_data)

        Ra, node_r_idx, node_s_idx, pstep = rels[3], rels[4], rels[5], rels[6]

        Rr, Rs = [], []
        for j in range(len(rels[0])):
            Rr_idx, Rs_idx, values = rels[0][j], rels[1][j], rels[2][j]
            Rr.append(torch.sparse.FloatTensor(
                Rr_idx, values, torch.Size([node_r_idx[j].shape[0], Ra[j].size(0)])))
            Rs.append(torch.sparse.FloatTensor(
                Rs_idx, values, torch.Size([node_s_idx[j].shape[0], Ra[j].size(0)])))

        buf = [attr, state, Rr, Rs, Ra]

        with torch.set_grad_enabled(False):
            if use_gpu:
                for d in range(len(buf)):
                    if type(buf[d]) == list:
                        for t in range(len(buf[d])):
                            buf[d][t] = Variable(buf[d][t].cuda())
                    else:
                        buf[d] = Variable(buf[d].cuda())
            else:
                for d in range(len(buf)):
                    if type(buf[d]) == list:
                        for t in range(len(buf[d])):
                            buf[d][t] = Variable(buf[d][t])
                    else:
                        buf[d] = Variable(buf[d])

            attr, state, Rr, Rs, Ra = buf
            # print('Time prepare input', time.time() - st_time)

            # st_time = time.time()
            vels = model(
                attr, state, Rr, Rs, Ra, n_particles,
                node_r_idx, node_s_idx, pstep,
                instance_idx, phases_dict, args.verbose_model)
            # print('Time forward', time.time() - st_time)

            # print(vels)

            if args.debug:
                data_nxt_path = os.path.join(args.dataf, 'valid', str(infos[idx]), str(step + 1) + '.h5')
                data_nxt = normalize(load_data(data_names, data_nxt_path), stat)
                label = Variable(torch.FloatTensor(data_nxt[1][:n_particles]).cuda())
                # print(label)
                loss = np.sqrt(criterionMSE(vels, label).item())
                print(loss)

        vels = denormalize([vels.data.cpu().numpy()], [stat[1]])[0]

        if args.env == 'RiceGrip' or args.env == 'FluidShake':
            vels = np.concatenate([vels, v_nxt_gt[step, n_particles:]], 0)
        data[0] = data[0] + vels * args.dt

        if args.env == 'RiceGrip':
            # shifting the history
            # positions, restPositions
            data[1][:, args.position_dim:] = data[1][:, :-args.position_dim]
        data[1][:, :args.position_dim] = vels

        if args.debug:
            data[0] = p_gt[step + 1].copy()
            data[1][:, :args.position_dim] = v_nxt_gt[step]

    ##### render for the ground truth
    pyflex.set_scene(env_idx, scene_params, 0)

    if args.env == 'RiceGrip':
        halfEdge = np.array([0.15, 0.8, 0.15])
        center = np.array([0., 0., 0.])
        quat = np.array([1., 0., 0., 0.])
        pyflex.add_box(halfEdge, center, quat)
        pyflex.add_box(halfEdge, center, quat)
    elif args.env == 'FluidShake':
        x, y, z, dim_x, dim_y, dim_z, box_dis_x, box_dis_z = scene_params
        boxes = calc_box_init_FluidShake(box_dis_x, box_dis_z, height, border)

        x_box = x + (dim_x-1)/2.*0.055

        for box_idx in range(len(boxes) - 1):
            halfEdge = boxes[box_idx][0]
            center = boxes[box_idx][1]
            quat = boxes[box_idx][2]
            pyflex.add_box(halfEdge, center, quat)


    for step in range(args.time_step - 1):
        if args.env == 'RiceGrip':
            pyflex.set_shape_states(s_gt[step])
        elif args.env == 'FluidShake':
            pyflex.set_shape_states(s_gt[step, :-1])

        mass = np.zeros((n_particles, 1))
        if args.env == 'RiceGrip':
            p = np.concatenate([p_gt[step, :n_particles, -3:], mass], 1)
        else:
            p = np.concatenate([p_gt[step, :n_particles], mass], 1)

        pyflex.set_positions(p)
        pyflex.render(capture=0)

    ##### render for the predictions
    pyflex.set_scene(env_idx, scene_params, 0)

    if args.env == 'RiceGrip':
        pyflex.add_box(halfEdge, center, quat)
        pyflex.add_box(halfEdge, center, quat)
    elif args.env == 'FluidShake':
        for box_idx in range(len(boxes) - 1):
            halfEdge = boxes[box_idx][0]
            center = boxes[box_idx][1]
            quat = boxes[box_idx][2]
            pyflex.add_box(halfEdge, center, quat)

    for step in range(args.time_step - 1):
        if args.env == 'RiceGrip':
            pyflex.set_shape_states(s_gt[step])
        elif args.env == 'FluidShake':
            pyflex.set_shape_states(s_gt[step, :-1])

        mass = np.zeros((n_particles, 1))
        if args.env == 'RiceGrip':
            p = np.concatenate([p_pred[step, :n_particles, -3:], mass], 1)
        else:
            p = np.concatenate([p_pred[step, :n_particles], mass], 1)

        pyflex.set_positions(p)
        pyflex.render(capture=0)

pyflex.clean()