In [1]:
import matplotlib.pyplot as plt
import jax.numpy as np
import jax
from jax import jit, value_and_grad, grad
import numpy as onp
import torch
import torch.nn as nn

In [2]:
@jit
def compute_potential_point(coord, target_coord):
    """Quadratic potential energy pulling `coord` towards `target_coord`.

    Returns 20 * sum((coord - target_coord) ** 2), summed over every element,
    so it is a scalar and can be differentiated with `grad`.
    """
    offset = coord - target_coord
    squared_distance = np.sum(offset ** 2)
    return 20 * squared_distance

@jit
def compute_rolling_friction_force(velocity, mass, radius, f, g=9.8):
    """Friction force opposing the motion, of magnitude mass * g * radius * f / radius.

    NOTE(review): `radius` multiplies and then divides, so it cancels and the
    magnitude is effectively mass * g * f. The sign is taken per component of
    `velocity` (np.sign), not along the velocity direction, so a diagonal
    velocity receives a larger total friction than an axis-aligned one.
    """
    direction = np.sign(velocity)
    return -(direction * mass * g * radius * f / radius)

@jit
def compute_acceleration(potential_force, friction_force, mass):
    """Newton's second law: a = (F_potential + F_friction) / mass."""
    total_force = potential_force + friction_force
    return total_force / mass

@jit
def get_new_cv(current_coordinate, current_velocity, acceleration, dt):
    """Advance one semi-implicit Euler step.

    The velocity is updated first, and the *updated* velocity is then used to
    move the coordinate.

    Returns:
        (new_coordinate, new_velocity)
    """
    velocity_next = current_velocity + dt * acceleration
    coordinate_next = current_coordinate + dt * velocity_next
    return coordinate_next, velocity_next

@jit
def run_sim(coordinate_init, velocity_init, target_coordinate, constants):
    """Simulate a ball for a fixed 0.2 s horizon in a quadratic potential.

    Args:
        coordinate_init: starting position (same shape as velocity_init).
        velocity_init: starting velocity.
        target_coordinate: centre of the quadratic potential
            (compute_potential_point) that pulls the ball.
        constants: dict providing at least 'mass', 'radius' and 'f'.

    Returns:
        (coordinate, velocity, trajectory): the state after n_steps steps, and
        the list of the n_steps positions recorded *before* each step (the
        final coordinate is not appended to trajectory).
    """
    sim_time = 0.2
    n_steps = 20
    dt = sim_time / n_steps
    # The potential never changes, so build its gradient function once.
    potential_grad = grad(compute_potential_point)
    mass = constants['mass']

    coordinate = coordinate_init
    velocity = velocity_init
    trajectory = []
    # Plain Python loop: unrolled at trace time under jit. Only the step count
    # matters -- no per-step time value is needed.
    for _ in range(n_steps):
        trajectory.append(coordinate)
        l2_force = -potential_grad(coordinate, target_coordinate)
        friction_force = compute_rolling_friction_force(velocity,
                                                        mass,
                                                        constants['radius'],
                                                        constants['f'])
        acceleration = compute_acceleration(l2_force, friction_force, mass)
        coordinate, velocity = get_new_cv(coordinate, velocity, acceleration, dt)
    return coordinate, velocity, trajectory

@jit
def compute_loss(coordinate_init, velocity_init, target_coordinate, attractor, constants):
    """L1 distance between the simulated final position and `target_coordinate`.

    The ball is simulated by run_sim in the potential centred on `attractor`;
    only the final coordinate is scored (velocity and trajectory are ignored).
    """
    final_coord = run_sim(coordinate_init, velocity_init, attractor, constants)[0]
    return np.abs(final_coord - target_coordinate).sum()

@jit
def compute_loss_sequential(coordinate_init, velocity_list, target_coordinate, attractor, constants):
    """L1 distance to the target after chaining one simulation per velocity.

    Each entry of `velocity_list` is the initial velocity of a fresh run_sim
    call that starts where the previous run ended; the velocity at the end of
    each run is discarded. Any non-empty sequence of velocities is accepted.

    Raises:
        ValueError: if `velocity_list` is empty.
    """
    # len() is static under jit because a list argument is a pytree container.
    if len(velocity_list) == 0:
        raise ValueError("velocity_list must contain at least one velocity")
    coordinate = coordinate_init
    for velocity in velocity_list:
        coordinate, _, _ = run_sim(coordinate, velocity, attractor, constants)
    return np.sum(np.abs(coordinate - target_coordinate))


In [3]:
from collections import namedtuple


In [4]:
# Physical constants of the ball. Units are presumably SI (the friction
# helper defaults to g = 9.8) -- NOTE(review): confirm.
constants = {}
constants['radius'] = 0.05
constants['ro'] = 1000.  # density: mass = volume * ro below
constants['volume'] = 4 * np.pi * (constants['radius'] ** 3) / 3  # sphere volume
constants['mass'] = constants['volume'] * constants['ro']
constants['f'] = 0.007  # friction coefficient passed to compute_rolling_friction_force
# Namedtuple *type* with one field per constant (this is the class, not an instance).
const = namedtuple('Constants', list(constants.keys()))
# Batch of 3 balls: one row per ball, (x, y) columns.
target_coordinate = np.array([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]])
coordinate_init = np.array([[0.2, 0.4], [0.2, 0.4], [0.2, 0.4]])
velocity_init = np.array([[1., 0.], [1., 0.1], [1., 0.]])
attractor = np.array([[0., 0.], [0., 0.], [0., 0.]])

In [5]:
# Simulate the 3-ball batch in one call: vmap maps over the leading axis of the
# position, velocity and potential-centre arrays; the constants dict is shared
# by every ball (in_axes entry None).
test = jax.vmap(run_sim, in_axes=(0, 0, 0, None))(
    coordinate_init,
    velocity_init,
    target_coordinate,
    constants,
)

In [6]:
# Display the batched run_sim result: (final coordinates, final velocities, list of per-step coordinates).
test

(DeviceArray([[0.677606  , 0.52097744],
              [0.677606  , 0.5321716 ],
              [0.677606  , 0.52097744]], dtype=float32),
 DeviceArray([[2.441649 , 0.8532523],
              [2.441649 , 0.8399593],
              [2.441649 , 0.8532523]], dtype=float32),
 [DeviceArray([[0.2, 0.4],
               [0.2, 0.4],
               [0.2, 0.4]], dtype=float32),
  DeviceArray([[0.21228497, 0.40076396],
               [0.21228497, 0.4017571 ],
               [0.21228497, 0.40076396]], dtype=float32),
  DeviceArray([[0.22676106, 0.40227914],
               [0.22676106, 0.40425783],
               [0.22676106, 0.40227914]], dtype=float32),
  DeviceArray([[0.24331768, 0.404534  ],
               [0.24331768, 0.40748313],
               [0.24331768, 0.404534  ]], dtype=float32),
  DeviceArray([[0.26182836, 0.40751132],
               [0.26182836, 0.41140833],
               [0.26182836, 0.40751132]], dtype=float32),
  DeviceArray([[0.28215167, 0.41118833],
               [0.28215167, 0.416

In [7]:
target_coordinate = np.array([[0.9, 0.5], [0.9, 0.5], [0.9, 0.5]])
coordinate_init = np.array([[0.2, 0.4], [0.2, 0.4], [0.5, 0.4]])
velocity_init = np.array([[1., 0.1], [1., 0.], [0., 0.]])
attractor = np.array([[0., 0.], [0., 0.], [0., 0.]])

# Per-ball loss: batch over the leading axis of the four array arguments and
# share `constants` across the batch.
vmapped_loss = jax.vmap(compute_loss, in_axes=(0, 0, 0, 0, None))


def batch_loss(coordinate_init, velocity_init, target_coordinate, attractor, constants):
    """Sum of the per-ball L1 losses over the batch (all five arguments required)."""
    return np.sum(vmapped_loss(coordinate_init, velocity_init, target_coordinate, attractor, constants))


# Returns (total loss, d(total loss)/d(velocity_init)); argnums=1 selects velocity_init.
v_g_loss = value_and_grad(batch_loss, argnums=1)

In [8]:
v_g_loss(coordinate_init, velocity_init, target_coordinate, attractor, constants)

(DeviceArray(4.421179, dtype=float32),
 DeviceArray([[-0.11271466, -0.11271466],
              [-0.11271466, -0.11271466],
              [-0.11271466, -0.11271466]], dtype=float32))

In [20]:
@jit
def func(a, b):
    return a * b ** 2, (b, a)

In [21]:
# has_aux=True: func returns (value, aux); only value is differentiated, w.r.t. `a`.
value_and_grad(func, argnums=0, has_aux=True)(3., 2.)

((DeviceArray(12., dtype=float32),
  (DeviceArray(2., dtype=float32), DeviceArray(3., dtype=float32))),
 DeviceArray(4., dtype=float32))

In [9]:
# Reuse the same batched initial velocity for each of the 5 chained runs.
velocity_list = [velocity_init] * 5
value_and_grad(compute_loss_sequential, argnums=1)(
    coordinate_init, velocity_list, target_coordinate, attractor, constants
)

(DeviceArray(4.008059, dtype=float32),
 [DeviceArray([[-0.00026383, -0.00026383],
               [-0.00026383, -0.00026383],
               [-0.00026383, -0.00026383]], dtype=float32),
  DeviceArray([[0.00119946, 0.00119946],
               [0.00119946, 0.00119946],
               [0.00119946, 0.00119946]], dtype=float32),
  DeviceArray([[-0.0054532, -0.0054532],
               [-0.0054532, -0.0054532],
               [-0.0054532, -0.0054532]], dtype=float32),
  DeviceArray([[0.02479224, 0.02479224],
               [0.02479224, 0.02479224],
               [0.02479224, 0.02479224]], dtype=float32),
  DeviceArray([[-0.11271466, -0.11271466],
               [-0.11271466, -0.11271466],
               [-0.11271466, -0.11271466]], dtype=float32)])

In [None]:
# Replicate the constant values once per ball. A leading `*` here made this a
# bare starred expression, which is a SyntaxError. Note that all three entries
# are the same list object.
[list(constants.values())] * 3

In [None]:
# Column 0 is drawn from [-1, 0) and column 1 from [0, 2); the bounds broadcast over 3 rows.
onp.random.uniform(low=(-1, 0), high=(0, 2), size=(3, 2))

In [None]:
# `const` is the namedtuple class; instantiate it with the constant values so
# this yields a numeric (3, n_constants) array instead of an object array of
# three references to the class itself.
onp.array([const(**constants)] * 3)

In [None]:
class Controller(nn.Module):
    """MLP policy mapping per-ball features to an initial velocity.

    The defaults reproduce the original architecture:
    Linear(5, 20) -> ReLU -> Linear(20, 50) -> ReLU -> Linear(50, 2).
    Five input features match the [direction (2), position (2), distance (1)]
    vector built in the training loop.

    Args:
        in_features: size of the last input dimension.
        hidden_sizes: widths of the hidden layers, each followed by ReLU.
        out_features: size of the output (2 for a 2-D velocity).
    """

    def __init__(self, in_features=5, hidden_sizes=(20, 50), out_features=2):
        super().__init__()
        layers = []
        width = in_features
        for hidden in hidden_sizes:
            layers += [nn.Linear(width, hidden), nn.ReLU()]
            width = hidden
        layers.append(nn.Linear(width, out_features))
        # Same attribute name and Sequential layout as before, so existing
        # state_dict keys (controller.0/2/4.*) keep loading with the defaults.
        self.controller = nn.Sequential(*layers)

    def forward(self, x):
        """Map an (..., in_features) tensor to (..., out_features)."""
        return self.controller(x)

In [None]:
# Seed torch so the controller's initial weights are reproducible across restarts.
torch.manual_seed(0)
ctrl = Controller()
opt = torch.optim.Adam(ctrl.parameters())

In [None]:
from time import time

# Seed NumPy so the sampled start positions (and the logged losses) are reproducible.
onp.random.seed(0)

for step in range(1000):
    step_start = time()

    # Sample a batch of 3 start positions uniformly in [-1, 1)^2.
    coordinate_init = np.array(onp.random.uniform(-1., 1., size=(3, 2)))
    # Distance from the target and unit direction (target -> start) per ball.
    dist = onp.linalg.norm(coordinate_init - target_coordinate, axis=1).reshape(-1, 1)
    direction = (coordinate_init - target_coordinate) / dist

    # Controller input per ball: [direction (2), position (2), distance (1)].
    net_inp = torch.cat([torch.from_numpy(onp.array(o)) for o in [direction, coordinate_init, dist]], dim=1)

    # The controller proposes an initial velocity; JAX differentiates the
    # simulated loss w.r.t. that velocity, and the gradient is fed back into
    # torch via backward(gradient) -- a manual chain rule across frameworks.
    controller_out = ctrl(net_inp)
    velocity_init = np.array(controller_out.detach().cpu().numpy())
    # v_g_loss takes five arguments; `constants` was previously missing here.
    loss_val, v_grad = v_g_loss(coordinate_init, velocity_init, target_coordinate, attractor, constants)
    opt.zero_grad()
    controller_out.backward(torch.from_numpy(onp.array(v_grad)))
    opt.step()
    if step % 50 == 0:
        print(time() - step_start, loss_val, velocity_init, v_grad)


In [None]:
# Inspect the last velocity gradient from training as a torch tensor (host copy).
torch.from_numpy(onp.array(v_grad, copy=True))

In [None]:
# Evaluate the trained controller on a single start position. Note (-1.2, -0.4)
# lies outside the [-1, 1) range sampled during training.
target_coordinate = np.array([[0.9, 0.5]])
coordinate_init = np.array([[-1.2, -0.4]])
attractor = np.array([[0., 0.]])

# Same [direction, position, distance] features as in the training loop.
dist = onp.linalg.norm(coordinate_init - target_coordinate, axis=1).reshape(-1, 1)
direction = (coordinate_init - target_coordinate) / dist
net_inp = torch.cat([torch.from_numpy(onp.array(o)) for o in [direction, coordinate_init, dist]], dim=1)

# Inference only: don't build an autograd graph.
with torch.no_grad():
    controller_out = ctrl(net_inp)
velocity_init = np.array(controller_out.cpu().numpy())


In [None]:
# run_sim returns (final coordinate, final velocity, trajectory); unpacking only
# two values raised "too many values to unpack".
final_coordinate, final_velocity, trajectory = run_sim(coordinate_init, velocity_init, attractor, constants)

In [None]:
# Stack the recorded positions into a (steps, balls, 2) array and keep the first ball's (x, y) path.
traj = onp.asarray(trajectory)[:, 0]

In [None]:
fig, ax = plt.subplots()

ax.plot(traj[:, 0], traj[:, 1], label='trajectory')
ax.scatter(attractor[0, 0], attractor[0, 1], c='b', label='attractor')
ax.scatter(target_coordinate[0, 0], target_coordinate[0, 1], c='r', label='target')
ax.scatter(coordinate_init[0, 0], coordinate_init[0, 1], c='g', label='init')
ax.set(title='Ball trajectory from controller-chosen initial velocity', xlabel='x', ylabel='y')
# Equal scaling so distances and directions in the plane are not distorted.
ax.set_aspect('equal')
# Legend on the axes (not the figure) so it stays inside the plotting area.
ax.legend()
# plt.show() works with the inline backend; fig.show() can warn on non-GUI backends.
plt.show()