In [1]:
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import optax


In [2]:
tmp_vel = jnp.array([1.0, 5.0])

tmp_other_vel = jnp.array([[2.0, 6.0], [3.0, 7.0], [4.0, 8.0]])

In [3]:
tmp_vel - tmp_other_vel

Array([[-1., -1.],
       [-2., -2.],
       [-3., -3.]], dtype=float32)

In [4]:
from dataloader import PeopleTPNextDatasetJAX, data_loader_tp_next, tp_to_linear

In [5]:
key = jax.random.key(0)
T, N = 8, 4
positions = jax.random.normal(key, (T, N, 2))
velocities = jax.random.normal(jax.random.split(key)[1], (T, N, 2))

# Random access
ds = PeopleTPNextDatasetJAX(positions, velocities)
print("len(ds) =", len(ds))  # (T-1) * N
k = tp_to_linear(3, 2, N)    # (t=3, i=2), valid since t <= T-2
s = ds[k]
print("sample:", s["pos"].shape, s["others_pos"].shape, s["next_vel"].shape)

# Batched iteration
loader = data_loader_tp_next(positions, velocities, batch_size=5, rng_key=key, shuffle=True)
batch = next(iter(loader))
print("batch pos:", batch["pos"].shape)            # (5, 2)
print("batch others_pos:", batch["others_pos"].shape)  # (5, 3, 2)
print("batch next_vel:", batch["next_vel"].shape)  # (5, 2)


len(ds) = 28
sample: (2,) (3, 2) (2,)
batch pos: (5, 2)
batch others_pos: (5, 3, 2)
batch next_vel: (5, 2)


In [6]:
batch.keys()

dict_keys(['pos', 'vel', 'next_vel', 'person_index', 'time_index', 'others_pos', 'others_vel'])

In [7]:
from functions import ForceNet, TrueForceNet

In [8]:
goal_velocities = jnp.zeros((N, 2))
pedestrian_hidden_sizes = [64, 64]
goal_hidden_sizes = [64, 64]


In [9]:
model = ForceNet(jr.PRNGKey(0), goal_velocities, pedestrian_hidden_sizes, goal_hidden_sizes)
opt = optax.adam(learning_rate=1e-3, b1=0.9, b2=0.999)
opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))

In [10]:
model(2, jnp.array([1.0,1.0]), jnp.array([1.0,2.0]), jnp.array([-1.0,3.0]))

Array([-0.12352696,  0.37512487], dtype=float32)

In [None]:
def single_loss_fn(model, pedestrian_idx, position, other_positions, velocity, other_velocities, y_velocity, dt):
    rel_disp = position - other_positions
    rel_vel = velocity - other_velocities
    f = jax.vmap(model.pedestrian_force, in_axes=(0, 0))(rel_disp, rel_vel)
    goal_f = model.goal_force(pedestrian_idx, velocity)
    return jnp.linalg.norm((goal_f + jnp.sum(f, axis=0))*dt + velocity - y_velocity)**2

def batch_loss_fn(model, pedestrian_indices, positions, other_positions, velocities, other_velocities, y_velocities, dt):
    loss_fn = jax.vmap(single_loss_fn, in_axes=(None, 0, 0, 0, 0, 0, 0, None))
    return jnp.mean(loss_fn(model, pedestrian_indices, positions, other_positions, velocities, other_velocities, y_velocities, dt))

@eqx.filter_jit
def make_step(model, pedestrian_indices, positions, other_positions, velocities, other_velocities, y_velocities, dt, opt_state, opt_update):
    loss_fn = eqx.filter_value_and_grad(batch_loss_fn)
    loss, grads = loss_fn(model, pedestrian_indices, positions, other_positions, velocities, other_velocities, y_velocities, dt)
    updates, opt_state = opt_update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

@eqx.filter_jit
def eval_step(model, pedestrian_indices, positions, other_positions, velocities, other_velocities, y_velocities, dt):
    loss = batch_loss_fn(model, pedestrian_indices, positions, other_positions, velocities, other_velocities, y_velocities, dt)
    return loss

In [12]:
loss, model, opt_state = make_step(
    model,
    batch["person_index"],
    batch["pos"],
    batch["others_pos"],
    batch["vel"],
    batch["others_vel"],
    batch["next_vel"],
    0.01,
opt_state, opt.update)

In [13]:
loss

Array(6.945366, dtype=float32)

In [14]:
tmp = jnp.ones((49,2))
jnp.sum(tmp, axis=0)

Array([49., 49.], dtype=float32)

In [15]:
goal_velocities = jnp.load("v_star.npy")
   
model = TrueForceNet(goal_velocities, tau=0.5, A=8.0, d0=0.7, B=1.2)

dataset = jnp.load("pedestrians.npz")
print(dataset['dt'])
positions = dataset["positions"]
velocities = dataset["velocities"]
# Split into train and eval sets along the first dimension (time)
num_timesteps = positions.shape[0]
num_eval = int(num_timesteps * 0.05)
num_train = num_timesteps - num_eval

# Shuffle time indices
train_idx = jnp.arange(num_train)
eval_idx = jnp.arange(num_train, num_timesteps)

# Apply split
train_positions = positions[train_idx]
train_velocities = velocities[train_idx]
eval_positions = positions[eval_idx]
eval_velocities = velocities[eval_idx]


train_loader = data_loader_tp_next(train_positions, train_velocities,
            batch_size=1,
            rng_key=jr.PRNGKey(0),
            shuffle=True, drop_last=True)

0.05


In [16]:
opt = optax.adam(learning_rate=1e-3, b1=0.9, b2=0.999)
opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))

In [17]:
batch = next(iter(train_loader))
batch["pos"].shape
batch["vel"].shape
batch["next_vel"].shape
batch["others_pos"].shape
batch["others_vel"].shape
print(batch["pos"])
print(batch["vel"])
print(batch["next_vel"])
print(batch["person_index"])
print(batch["time_index"])

[[17.189068  -3.6377563]]
[[ 1.4240818  -0.00475567]]
[[ 1.4233329  -0.00557549]]
[18]
[356]


In [22]:
batch["pos"]

Array([[17.189068 , -3.6377563]], dtype=float32)

In [21]:
print(single_loss_fn(model, batch["person_index"], batch["pos"], batch["others_pos"], batch["vel"], batch["others_vel"], batch["next_vel"], dataset["dt"]))
print(batch_loss_fn(model, batch["person_index"], batch["pos"], batch["others_pos"], batch["vel"], batch["others_vel"], batch["next_vel"], dataset["dt"]))
print(make_step(model, batch["person_index"], batch["pos"], batch["others_pos"], batch["vel"], batch["others_vel"], batch["next_vel"], dataset["dt"], opt_state, opt.update))

0.02255247
2.1684043e-19
(Array(2.1684043e-19, dtype=float32), TrueForceNet(goal_velocities=f32[50,2], tau=0.5, A=8.0, d0=0.7, B=1.2), (ScaleByAdamState(count=Array(1, dtype=int32), mu=TrueForceNet(goal_velocities=f32[50,2], tau=None, A=None, d0=None, B=None), nu=TrueForceNet(goal_velocities=f32[50,2], tau=None, A=None, d0=None, B=None)), EmptyState()))


In [19]:
person_index = 0
time_index = 235
pos = train_positions[time_index, person_index]
others_pos = train_positions[time_index, 1:] if person_index == 0 else train_positions[time_index, 0:1]
vel = train_velocities[time_index, person_index]
others_vel = train_velocities[time_index, 1:] if person_index == 0 else train_velocities[time_index, 0:1]
next_vel = train_velocities[time_index+1, person_index]

print(pos)
print(others_pos)
print(vel)
print(others_vel)
print(next_vel)


[ 4.7591434 -4.0447936]
[[ -1.6872896    0.34555712]
 [  5.7773037    2.655088  ]
 [  1.8499109    0.04773311]
 [  3.3504868   17.168934  ]
 [ 16.461939    -3.1173956 ]
 [ -4.763045     2.9769816 ]
 [ -8.962392    -5.5123854 ]
 [ -2.804605    -7.5884314 ]
 [  2.635891     2.7238073 ]
 [  5.917557   -14.037876  ]
 [ -5.5558944   -6.363122  ]
 [ -9.332295    10.579169  ]
 [ -6.9737906  -10.361043  ]
 [  4.28889     11.245201  ]
 [  1.3777896   -5.840718  ]
 [ 12.403202    -2.909596  ]
 [ -2.801527    11.485667  ]
 [  8.368222    -3.590843  ]
 [ -0.9838379  -13.467561  ]
 [ -1.1920987   -4.738559  ]
 [ -8.039457     2.55964   ]
 [ 10.008805   -14.0938    ]
 [  1.0443714  -19.772448  ]
 [ -1.2362013    3.3829339 ]
 [  3.6638246    5.699161  ]
 [-17.380713     5.3679996 ]
 [ -3.7335882    8.671069  ]
 [  4.985122    -8.2943    ]
 [-19.447084   -14.522283  ]
 [  1.6592189  -11.022123  ]
 [ -1.3687954   -9.83473   ]
 [  0.3646395   -2.581844  ]
 [ -3.900082   -16.199055  ]
 [  4.966606    -0.

In [20]:
print(single_loss_fn(model, person_index, pos, others_pos, vel, others_vel, next_vel, dataset["dt"]))
# print(batch_loss_fn(model, person_index, pos, others_pos, vel, others_vel, next_vel, dataset["dt"]))
# print(make_step(model, person_index, pos, others_pos, vel, others_vel, next_vel, dataset["dt"], opt_state, opt.update))

2.220446e-16
