In [1]:
# IMPORTS

import numpy as onp
import jax.numpy as np
import optax
from jax import random
from jax import jit
from jax import vmap
from jax import grad
from jax import value_and_grad
from jax import lax
from jax import config
config.update("jax_enable_x64", True)
config.update('jax_debug_nans', True)
config.update('jax_traceback_filtering', 'off')

from jax_md import space, smap, energy, minimize, quantity, simulate, partition, util
from jax_md.util import f32

from collections import namedtuple

vectorize = np.vectorize

from functools import partial
from simulator.utils import ttc_force, _normalize
from simulator.force import general_force_generator
from simulator.render import render
from simulator.dynamics import pedestrian, PedestrianState, StraightWall

In [2]:
# POSITIONS AND VELOCITIES
# Grid sizes, constants handed to ttc_force (R, K, T_0), the speed used for
# v_2, and the number of optimisation steps.
POS_NUMS = 2
ANGLE_NUMS = 2
R = 0.1
K = 1.5
T_0 = 3.0
V_2_MAG = 1.0
CYCLE_NUMS = 1000

# Sample positions sit on the x-axis: x runs from 0.3 to 2.0, y is always zero.
positions = np.stack(
    (np.linspace(0.3, 2.0, POS_NUMS), np.zeros(POS_NUMS)),
    axis=1,
)

# Angles are stored as fractions of pi, so 0 .. 1/2 covers a quarter turn.
angles = np.linspace(0.0, 0.5, ANGLE_NUMS)

# First velocity is zero for every angle; the second has magnitude V_2_MAG
# and points along each sampled angle.
v_1 = np.zeros((ANGLE_NUMS, 2))
v_2 = V_2_MAG * np.stack(
    (np.cos(onp.pi * angles), np.sin(onp.pi * angles)),
    axis=1,
)

In [3]:
# PARAMS INIT
# Starting point for the fit: two all-ones (10, 10, 10) weight tensors and
# two scalar parameters, both initialised to 10.
paral_weights = np.ones((10, 10, 10))
perpen_weights = np.ones_like(paral_weights)
v_0 = 10.0
d_0 = 10.0

def _loss_fn(params, pos, v1, v2):
    """Squared error between the generated force and the reference TTC force.

    Args:
        params: mapping with keys 'paral', 'perpen', 'v0' and 'd0'; the values
            are forwarded, in that order, to ``general_force_generator``.
        pos, v1, v2: a single sample, passed unchanged to both the generated
            force function and ``ttc_force``.

    Returns:
        Sum of squared components of (generated force - reference force),
        i.e. the squared Euclidean norm of the residual.
    """
    force_fn = general_force_generator(
        params['paral'], params['perpen'], params['v0'], params['d0'])
    residual = force_fn(pos, v1, v2) - ttc_force(pos, v1, v2, R, K, T_0)
    # Sum the squares directly instead of squaring np.linalg.norm: the norm
    # goes through a sqrt whose derivative is infinite at zero, so the gradient
    # turns into NaN (and trips jax_debug_nans) whenever the residual is
    # exactly zero. Assumes the forces are real-valued -- TODO confirm.
    return np.sum(residual ** 2)

def loss_fn(params, pos, v1, v2):
    """Total loss: sum of ||F_pred - F||^2 over all (position, velocity) sets.

    ``pos`` is mapped along its leading axis; ``v1`` and ``v2`` are mapped
    together along their leading axes; ``params`` is shared by every sample.
    """
    # Inner map: step through v1 and v2 in lockstep for one fixed position.
    over_velocities = vmap(_loss_fn, in_axes=(None, None, 0, 0))
    # Outer map: repeat the inner map for every row of pos.
    over_positions = vmap(over_velocities, in_axes=(None, 0, None, None))
    per_sample = over_positions(params, pos, v1, v2)
    return per_sample.sum()

In [4]:
# OPTIMIZATION
# Adam optimiser configured with the learning rate below.
start_learning_rate = 0.1
optimizer = optax.adam(learning_rate=start_learning_rate)

# PARAMS
# Everything being fitted lives in one dict so optax treats it as one pytree.
params = dict(
    paral=paral_weights,
    perpen=perpen_weights,
    d0=d_0,
    v0=v_0,
)
opt_state = optimizer.init(params)

In [5]:
# Evaluate the loss once on the full sample grid before any update step.
loss_fn(params, positions, v_1, v_2)



Array(2.54108976e+41, dtype=float64)

In [6]:
# UPDATE LOOP
# Build the differentiated function once, outside the loop. It returns the
# loss and its gradient with respect to params from a single evaluation, so
# the forward pass is no longer run twice per step.
loss_and_grad = value_and_grad(loss_fn)

# NOTE: params and opt_state are rebound on every step, so re-running this
# cell continues from the current values instead of restarting the fit.
for i in range(CYCLE_NUMS):
  print(f"Current update loop: {i}/{CYCLE_NUMS}")
  # The reported loss is the one at the current params, before this update.
  loss, grads = loss_and_grad(params, positions, v_1, v_2)
  print(f"Current loss: {loss}")
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)


Current update loop: 0/1000
Current loss: 2.5410897621699358e+41
Current update loop: 1/1000
Current loss: 1.615593515834843e+41
Current update loop: 2/1000
Current loss: 1.1488173994291905e+41
Current update loop: 3/1000
Current loss: 1.0704131864393428e+41
Current update loop: 4/1000
Current loss: 7.828423765735825e+40
Current update loop: 5/1000
Current loss: 1.072952334892542e+41
Current update loop: 6/1000
Current loss: 7.583370539510201e+40
Current update loop: 7/1000
Current loss: 4.955658073126851e+40
Current update loop: 8/1000
Current loss: 2.9065403806286126e+40
Current update loop: 9/1000
Current loss: 1.4426209100003102e+40
Current update loop: 10/1000
Current loss: 5.3657556679993155e+39
Current update loop: 11/1000
Current loss: 1.294229636573991e+39
Current update loop: 12/1000
Current loss: 1.3115193476344079e+39
Current update loop: 13/1000
Current loss: 3.8342065806760303e+39
Current update loop: 14/1000
Current loss: 7.676349215646842e+39
Current update loop: 15/100

KeyboardInterrupt: 

In [None]:
def fn_1(x):
    """Return ``x`` increased by one."""
    incremented = x + 1
    return incremented

def fn_2(x):
    """Return ``x`` decreased by one."""
    decremented = x - 1
    return decremented

# Gather the two toy functions in a list and create a 0-d integer array.
# NOTE(review): neither A nor x is used in the visible part of this cell --
# confirm whether a trailing expression was lost.
A = [fn_1, fn_2]
x = np.array(0)


Array([[0, 2],
       [1, 4]], dtype=int64)

In [None]:
# Scratch check: build a small dict and print it.
test = {"a": 9, "b": "Nope"}
print(test)

{'a': 9, 'b': 'Nope'}


AttributeError: 'dict' object has no attribute 'a'