-
Notifications
You must be signed in to change notification settings - Fork 91
/
forecastTrajectories.py
150 lines (127 loc) · 6.43 KB
/
forecastTrajectories.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import sys
try:
sys.path.remove('/usr/local/lib/python2.7/dist-packages/Theano-0.6.0-py2.7.egg')
except:
print 'Theano 0.6.0 version not found'
import numpy as np
import argparse
import theano
import os
from theano import tensor as T
from neuralmodels.utils import permute
from neuralmodels.loadcheckpoint import *
from neuralmodels.costs import softmax_loss, euclidean_loss
from neuralmodels.models import *
from neuralmodels.predictions import OutputMaxProb, OutputSampleFromDiscrete
from neuralmodels.layers import *
from neuralmodels.updates import Adagrad,RMSprop,Momentum,Adadelta
import cPickle
import pdb
import socket as soc
import copy
import readCRFgraph as graph
import time
from unNormalizeData import unNormalizeData
from convertToSingleVec import convertToSingleVec
# Module-level NumPy generator seeded with a fixed value, so its draws repeat
# across runs. (A `global` declaration at module scope is a no-op, so none is
# used here.)
# NOTE(review): not referenced elsewhere in this script as far as visible.
rng = np.random.RandomState(1234567890)
# Command-line interface for the forecasting script.
parser = argparse.ArgumentParser(
    description='Forecast H3.6m motion from a trained model checkpoint.')
parser.add_argument('--checkpoint', type=str, default='checkpoint',
                    help='checkpoint path prefix; file names are appended to it '
                         'directly, so end it with a path separator')
# NOTE(review): the default 'malik' matches none of the modes this script
# handles, so a run without --forecast performs no forecast.
parser.add_argument('--forecast', type=str, default='malik',
                    help='model to forecast with: srnn, lstm3lr, erd or dracell')
# The options below are copied onto the processdata (poseDataset) module
# before its runall() is called.
parser.add_argument('--motion_prefix', type=int, default=50)
parser.add_argument('--motion_suffix', type=int, default=100)
parser.add_argument('--temporal_features', type=int, default=0)
parser.add_argument('--full_skeleton', type=int, default=1)
parser.add_argument('--dataset_prefix', type=str, default='')
parser.add_argument('--train_for', type=str, default='final')
parser.add_argument('--drop_features', type=int, default=0)
# A single id; it is wrapped in a one-element list before reaching processdata.
parser.add_argument('--drop_id', type=int, default=9)
args = parser.parse_args()
'''Loads H3.6m dataset'''
print 'Loading H3.6m'
sys.path.insert(0,'CRFProblems/H3.6m')
import processdata as poseDataset
poseDataset.T = 150
poseDataset.delta_shift = 100
poseDataset.num_forecast_examples = 24
poseDataset.motion_prefix = args.motion_prefix
poseDataset.motion_suffix = args.motion_suffix
poseDataset.temporal_features = args.temporal_features
poseDataset.full_skeleton = args.full_skeleton
poseDataset.dataset_prefix = args.dataset_prefix
poseDataset.crf_file = './CRFProblems/H3.6m/crf'
poseDataset.train_for = args.train_for
poseDataset.drop_features = args.drop_features
poseDataset.drop_id = [args.drop_id]
poseDataset.runall()
print '**** H3.6m Loaded ****'
new_idx = poseDataset.new_idx
featureRange = poseDataset.nodeFeaturesRanges
path = args.checkpoint
if not os.path.exists(path):
print 'Checkpoint path does not exist. Exiting!!'
sys.exit()
crf_file = './CRFProblems/H3.6m/crf'
if args.forecast == 'srnn':
path_to_checkpoint = '{0}checkpoint'.format(path)
print "Using checkpoint at: ",path_to_checkpoint
if os.path.exists(path_to_checkpoint):
[nodeNames,nodeList,nodeFeatureLength,nodeConnections,edgeList,edgeListComplete,edgeFeatures,nodeToEdgeConnections,trX,trY,trX_validation,trY_validation,trX_forecasting,trY_forecasting,trX_forecast_nodeFeatures] = graph.readCRFgraph(poseDataset)
print trX_forecast_nodeFeatures.keys()
print 'Loading the model (this takes long, can take upto 25 minutes)'
model = loadDRA(path_to_checkpoint)
print 'Loaded S-RNN from ',path_to_checkpoint
t0 = time.time()
trY_forecasting = model.convertToSingleVec(trY_forecasting,new_idx,featureRange)
fname = 'ground_truth_forecast'
model.saveForecastedMotion(trY_forecasting,path,fname)
trX_forecast_nodeFeatures_ = model.convertToSingleVec(trX_forecast_nodeFeatures,new_idx,featureRange)
fname = 'motionprefix'
model.saveForecastedMotion(trX_forecast_nodeFeatures_,path,fname)
forecasted_motion = model.predict_sequence(trX_forecasting,trX_forecast_nodeFeatures,sequence_length=trY_forecasting.shape[0],poseDataset=poseDataset,graph=graph)
forecasted_motion = model.convertToSingleVec(forecasted_motion,new_idx,featureRange)
fname = 'forecast'
model.saveForecastedMotion(forecasted_motion,path,fname)
skel_err = np.mean(np.sqrt(np.sum(np.square((forecasted_motion - trY_forecasting)),axis=2)),axis=1)
err_per_dof = skel_err / trY_forecasting.shape[2]
fname = 'forecast_error'
model.saveForecastError(skel_err,err_per_dof,path,fname)
t1 = time.time()
del model
elif args.forecast == 'lstm3lr' or args.forecast == 'erd':
path_to_checkpoint = '{0}checkpoint.{1}'.format(path,iteration)
if os.path.exists(path_to_checkpoint):
print "Loading the model {0} (this may take sometime)".format(args.forecast)
model = load(path_to_checkpoint)
print 'Loaded the model from ',path_to_checkpoint
trX_forecasting,trY_forecasting = poseDataset.getMalikTrajectoryForecasting()
fname = 'ground_truth_forecast'
model.saveForecastedMotion(trY_forecasting,path,fname)
fname = 'motionprefix'
model.saveForecastedMotion(trX_forecasting,path,fname)
forecasted_motion = model.predict_sequence(trX_forecasting,sequence_length=trY_forecasting.shape[0])
fname = 'forecast'
model.saveForecastedMotion(forecasted_motion,path,fname)
skel_err = np.mean(np.sqrt(np.sum(np.square((forecasted_motion - trY_forecasting)),axis=2)),axis=1)
err_per_dof = skel_err / trY_forecasting.shape[2]
fname = 'forecast_error'
model.saveForecastError(skel_err,err_per_dof,path,fname)
del model
elif args.forecast == 'dracell':
path_to_checkpoint = '{0}checkpoint.{1}'.format(path,iteration)
if os.path.exists(path_to_checkpoint):
[nodeNames,nodeList,nodeFeatureLength,nodeConnections,edgeList,edgeListComplete,edgeFeatures,nodeToEdgeConnections,trX,trY,trX_validation,trY_validation,trX_forecasting,trY_forecasting,trX_forecast_nodeFeatures] = graph.readCRFgraph(poseDataset,noise=0.7,forecast_on_noisy_features=True)
print trX_forecast_nodeFeatures.keys()
print 'Loading the model'
model = loadDRA(path_to_checkpoint)
print 'Loaded DRA: ',path_to_checkpoint
t0 = time.time()
trY_forecasting = model.convertToSingleVec(trY_forecasting,new_idx,featureRange)
trX_forecast_nodeFeatures_ = model.convertToSingleVec(trX_forecast_nodeFeatures,new_idx,featureRange)
fname = 'motionprefixlong'
model.saveForecastedMotion(trX_forecast_nodeFeatures_,path,fname)
cellstate = model.predict_cell(trX_forecasting,trX_forecast_nodeFeatures,sequence_length=trY_forecasting.shape[0],poseDataset=poseDataset,graph=graph)
fname = 'forecast_celllong_{0}'.format(iteration)
model.saveCellState(cellstate,path,fname)
t1 = time.time()
del model