In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from functools import partial
from field import lin_interp, linear_ma, gradient

In [6]:
# @jax.jit
def step(
        params,
        grid_star : jnp.ndarray,
        dx : float,
        rate : float):
        
    pos = params["pos"]
    weight = params["weight"]
    grid_size = grid_star.shape[0]
        
    grid = linear_ma(pos, weight, grid_size, dx)

    # define the desired change 
    field = (grid - grid_star)

    # get the gradient grid
    grad_field = gradient(field, dx)

    grad_pos_x = lin_interp(
        pos, 
        grad_field[0])
    
    grad_pos_y = lin_interp(
        pos, 
        grad_field[1])
    
    grad_pos_z = lin_interp(
        pos, 
        grad_field[2])
    
    grad_pos = jnp.stack([
        grad_pos_x, grad_pos_y, grad_pos_z])

    params["pos"] = pos + rate * grad_pos

    return params


In [37]:
def mse(params, grid_star : jnp.ndarray):
    pos = params["pos"]
    weight = params["weight"]
    grid_size = grid_star.shape[0]
        
    grid = linear_ma(pos, weight, grid_size, 1)

    return jnp.mean((grid_star - grid)**2)

def optimize(
        params, 
        grid_star : jnp.ndarray, 
        dx : float, 
        n_steps : int,
        rate : float,
        heat : float) -> jnp.ndarray:
    
    for i in range(n_steps):
        p = i * 1 / n_steps
        params = step(params, grid_star, dx, rate)
        params["pos"] = params["pos"] +\
            jax.random.normal(jax.random.PRNGKey(0), params["pos"].shape) * (1-p) * heat
        print(f"mse : {mse(params, grid_star)}")

    return params


In [38]:
grid_size = 32
N = grid_size ** 3
dx = 1

key = jax.random.key(0)

key_pos, key_weight = jax.random.split(key)

# create particles
pos = jax.random.uniform(key_pos, (3, N))
# weight = jax.random.uniform(key_weight, (N,)) / N * (grid_size ** 3)
weight = jnp.ones((N, )) / N * (grid_size ** 3)

# assign particles 
grid_star = linear_ma(pos, weight, grid_size, dx)

# no matter what N or grid_size is -> ~0.5
print(jnp.mean(grid_star))

1.0


In [41]:
key = jax.random.key(1)

key_pos, key_weight = jax.random.split(key)

pos = jax.random.uniform(key_pos, (3, N)) 
# weight = jax.random.uniform(key_weight, (N, )) / N * (grid_size ** 3)
weight = jnp.ones((N, )) / N * (grid_size ** 3)

params_init = {
    'pos' : pos,
    'weight' : weight}

params = optimize(params_init, grid_star, dx, 1000, 0.001, 0.001)

mse : 1.8157958984375
mse : 1.66522216796875
mse : 1.56671142578125
mse : 1.4921875
mse : 1.43548583984375
mse : 1.3887939453125
mse : 1.3546142578125
mse : 1.3233642578125
mse : 1.30181884765625
mse : 1.2742919921875
mse : 1.25848388671875
mse : 1.25421142578125
mse : 1.23431396484375
mse : 1.2315673828125
mse : 1.21807861328125
mse : 1.204345703125
mse : 1.1947021484375
mse : 1.2010498046875
mse : 1.1865234375
mse : 1.18505859375
mse : 1.18072509765625
mse : 1.182861328125
mse : 1.186767578125
mse : 1.18701171875
mse : 1.19097900390625
mse : 1.1834716796875
mse : 1.17425537109375
mse : 1.1748046875
mse : 1.17828369140625
mse : 1.18267822265625
mse : 1.171142578125
mse : 1.17388916015625
mse : 1.1767578125
mse : 1.18505859375
mse : 1.17034912109375
mse : 1.16680908203125
mse : 1.1767578125
mse : 1.174560546875
mse : 1.1614990234375
mse : 1.16595458984375
mse : 1.15655517578125
mse : 1.15985107421875
mse : 1.1473388671875
mse : 1.15478515625
mse : 1.14617919921875
mse : 1.1427612304687

In [46]:
print(params["pos"])

grid = linear_ma(params["pos"], params["weight"], grid_size, 0.1)

# define the desired change 
field = (grid - grid_star)**2


# plot side by side
plt.figure()
plt.subplot(1, 3, 1)
plt.imshow(grid_star[grid_size // 2, :, :])
plt.subplot(1, 3, 2)
plt.imshow(grid[grid_size // 2, :, :])
plt.subplot(1, 3, 3)
plt.imshow(field[grid_size // 2, :, :])
plt.show()

[[ 0.32640833  1.2600455   0.36404896 ...  0.5622338   0.09272025
   0.9424865 ]
 [ 0.4644098  -0.436663    0.2614224  ... -0.02800144  1.012751
   0.6078537 ]
 [ 0.00967982  0.27013174  0.40404823 ...  0.1913875   0.46287608
   0.21914636]]
