In [153]:
import numpy as np
import torch
import mmengine
import pickle
from sklearn.cluster import KMeans

In [154]:
# Path of the pickle loaded into `info`. To load the other file instead,
# point info_path at '../data/CODA_test.pkl'.
info_path = '../data/CODA_motion_modes.pkl'
info = mmengine.load(file=info_path)

In [None]:
# Load `train_info` from its pickle file. The context manager guarantees the
# file handle is closed even when unpickling raises.
# NOTE(security): pickle.load can execute arbitrary code while loading; only
# use it on files you produced yourself or otherwise trust.
train_info_path = '../data/CODA_train.pkl'
with open(train_info_path, 'rb') as f:
    train_info = pickle.load(f)

In [24]:
# Normalise every trajectory in train_info and cluster the future parts into
# `n_clusters` motion modes with KMeans.
obs_len = 8            # steps taken from data[0]; step obs_len - 1 becomes the origin
pred_len = 12          # defined for later use; not referenced in this cell
n_clusters = 50        # number of KMeans clusters (= number of motion modes)
max_samples = 100000   # only the first max_samples trajectories are used
kmeans_seed = 0        # fixed seed so the cluster centres are reproducible

results = []
for data in train_info:
    # Stop early instead of normalising samples that would be discarded anyway.
    if len(results) >= max_samples:
        break
    # Join the two position arrays of the record along the time axis.
    traj = np.concatenate([np.array(data[0]), np.array(data[1])], axis=0)
    # Translate so that the step at index obs_len - 1 sits at the origin.
    traj = traj - traj[obs_len - 1]
    # Direction of the first step relative to the new origin.
    ref = traj[0]
    angle = np.arctan2(ref[1], ref[0])
    rot_mat = np.array([[np.cos(angle), -np.sin(angle)],
                        [np.sin(angle), np.cos(angle)]])
    # Row vectors times rot_mat.T applies rot_mat to every point, i.e. a
    # counter-clockwise rotation by `angle`.
    # NOTE(review): `angle` is already the direction of `ref`, so after this
    # rotation `ref` points at 2 * angle rather than along the x-axis. If the
    # intent is to align `ref` with the x-axis, this should be
    # np.dot(traj, rot_mat). Behaviour is left unchanged here; confirm against
    # the code that consumes the motion modes before switching.
    traj = np.dot(traj, rot_mat.T)
    results.append(traj)
results = np.array(results)

# Flatten the steps after the first obs_len of each trajectory into one
# feature vector per sample.
cluster_data = results[:, obs_len:].reshape(results.shape[0], -1)
clf = KMeans(n_clusters=n_clusters, random_state=kmeans_seed).fit(cluster_data)
# Fold each flat cluster centre back into a (steps, 2) trajectory.
motion_modes = clf.cluster_centers_.reshape(n_clusters, -1, 2)

In [33]:
# Number of normalised trajectories held in `results`.
results.shape[0]

100000

In [42]:
# Track, over all samples, the extreme mean positions of the steps after the
# first obs_len, together with the sample index that attains each extreme.
# Slot layout for both lists: [min col 0, min col 1, max col 0, max col 1].
pos_range = [np.inf, np.inf, -np.inf, -np.inf]
pos_idx = [-1, -1, -1, -1]
for sample_idx, sample in enumerate(results):
    centroid = sample[obs_len:].mean(axis=0)
    for col in (0, 1):
        value = centroid[col]
        min_slot = col
        max_slot = col + 2
        # Strict comparisons: on ties the earliest sample keeps the slot.
        if value < pos_range[min_slot]:
            pos_range[min_slot] = value
            pos_idx[min_slot] = sample_idx
        if value > pos_range[max_slot]:
            pos_range[max_slot] = value
            pos_idx[max_slot] = sample_idx

In [45]:
# Raw record of the sample with the smallest mean future column-0 position.
# pos_idx[0] is that sample's index (55834 in the recorded run); `results` is
# built in train_info order, so the same index addresses the raw record.
train_info[pos_idx[0]]

(array([[ 575376.06, -509707.1 ],
        [ 575376.06, -509707.1 ],
        [ 575376.06, -509707.1 ],
        [ 575376.06, -509707.1 ],
        [ 575376.06, -509707.1 ],
        [ 575376.06, -509707.1 ],
        [ 571229.44, -510410.44],
        [ 567343.44, -510973.84]], dtype=float32),
 array([[ 559475.3 , -503173.12],
        [ 558330.75, -495851.66],
        [ 554147.44, -491298.16],
        [ 549872.2 , -486682.72],
        [ 541826.1 , -483952.75],
        [ 533105.3 , -483064.16],
        [ 527820.6 , -480244.2 ],
        [ 522496.22, -477364.84],
        [ 522239.72, -470708.34],
        [ 515234.38, -461643.8 ],
        [ 502435.8 , -455782.94],
        [ 489720.03, -449943.06]], dtype=float32),
 array([[[ 1.0000000e+09,  1.0000000e+09],
         [ 1.0000000e+09,  1.0000000e+09],
         [ 1.0000000e+09,  1.0000000e+09],
         [ 1.0000000e+09,  1.0000000e+09],
         [ 1.0000000e+09,  1.0000000e+09],
         [-9.3581192e+01,  3.3630646e+02],
         [-9.3484955e+01,  3

In [44]:
# Normalised trajectory of the sample with the smallest mean future column-0
# position. pos_idx[0] is that sample's index (55834 in the recorded run).
results[pos_idx[0]]

array([[  7737.2383,   2502.5723],
       [  7737.2383,   2502.5723],
       [  7737.2383,   2502.5723],
       [  7737.2383,   2502.5723],
       [  7737.2383,   2502.5723],
       [  7737.2383,   2502.5723],
       [  3750.7966,   1161.8721],
       [     0.    ,      0.    ],
       [ -8987.235 ,   6479.8306],
       [-11258.331 ,  13533.628 ],
       [-16099.9   ,  17379.883 ],
       [-21041.93  ,  21272.998 ],
       [-29415.035 ,  22716.262 ],
       [-38167.81  ,  22235.52  ],
       [-43827.26  ,  24197.838 ],
       [-49535.203 ,  26212.621 ],
       [-50825.492 ,  32747.904 ],
       [-59157.35  ,  40610.52  ],
       [-72712.66  ,  44406.145 ],
       [-86182.92  ,  48193.92  ]], dtype=float32)

In [43]:
# Display the extreme mean positions and the indices of the samples that
# attain them. Both lists use the slot layout
# [min col 0, min col 1, max col 0, max col 1].
pos_range, pos_idx

([-40600.926, -120.61984, 16719.006, 55540.035], [55834, 87964, 56105, 55897])

In [32]:
# Display the cluster centre at index 40 as a (steps, 2) array. Cluster
# indices are assigned by KMeans and carry no meaning of their own.
motion_modes[40]

array([[-2.4663570e+02,  6.6555825e+03],
       [-4.2536381e+01,  1.1740108e+04],
       [ 1.1838752e+02,  1.7078488e+04],
       [-6.0090210e+03,  2.4203889e+04],
       [-5.2982822e+03,  3.0262061e+04],
       [-1.6728844e+04,  3.5501480e+04],
       [-2.7927158e+04,  4.0860746e+04],
       [-3.3818582e+04,  4.7137102e+04],
       [-3.9626605e+04,  5.3457750e+04],
       [-2.1309232e+04,  6.0437129e+04],
       [-1.7823463e+04,  6.5404715e+04],
       [-2.5725193e+04,  7.2588164e+04]], dtype=float32)

In [None]:
# Unlabelled scratch values, kept for reference. Commented out because bare
# space-separated numbers are not valid Python and would abort Run All.
# 310033 125171 310168 125190 308938 114493