# Fitting forces in Zebrafish Embryogenesis
In this last example, real data from zebrafish embryogenesis, obtained by light-sheet microscopy, are used to infer the active forces acting in the tissue. In a strongly simplified model, we describe the tissue as a viscous fluid flowing on a sphere, which reduces the problem to two dimensions. The viscous flow is driven by an active force field, which can vary arbitrarily in space and time. To obtain this non-parametric description, the force at every position is computed by interpolating between a grid of points on the sphere at different times, where at each grid point a two-dimensional force vector can be chosen freely. The optimization searches for the force field on the interpolation grid that best reproduces the observed tissue movement.

In [None]:
import jax.numpy as jnp
import numpy as np
from jax import jit
from jax.scipy.ndimage import map_coordinates
from jax import tree_util
from matplotlib import pyplot as plt
from matplotlib.colors import SymLogNorm
from scipy.special import sph_harm

from adoptODE import dataset_adoptODE, train_adoptODE

In [None]:
# Define the 2D zebrafish system
def define_2d_zebrafish(**kwargs_sys):
  """Build a 2D model of viscous tissue flow on a sphere driven by a force field.

  The tangential velocity field lives on an ``N_bins_th`` x ``N_bins_phi``
  (theta, phi) grid. It is driven by a trainable force field that is stored on
  a coarse grid and linearly interpolated in time, theta and phi.

  Keyword Args (all read from ``kwargs_sys``):
    disc_time: Number of time sections of the force grid, which therefore has
      ``disc_time + 1`` time nodes spanning ``[0, t_evals[-1]]``.
    disc_th: Number of theta sections (``disc_th + 1`` theta nodes).
    disc_phi: Number of periodic phi nodes of the force grid.
    t_evals: Evaluation times; ``t_evals[-1]`` is used as ``t_max``.
    force_rescale: Factor multiplying the interpolated force.
    ths, phis: Polar / azimuthal coordinates of the velocity grid.
    N_bins_th, N_bins_phi: Shape of the velocity grid.
    force_loss: Length-2 weights penalizing large force-grid values, one per
      force component.
    ini_loss: Weight of the extra loss on the initial condition.
    t_shift_loss: Weight penalizing the per-measurement time shift.
    N_sys: Number of measurements fitted simultaneously.

  Returns:
    ``(eom, loss, gen_params, gen_y0, {'force': force_intp})``.
  """
  disc_time = kwargs_sys['disc_time']
  disc_theta = kwargs_sys['disc_th']
  disc_phi = kwargs_sys['disc_phi']
  t_max = kwargs_sys['t_evals'][-1]
  N_times = len(kwargs_sys['t_evals'])
  force_rescale_factor = kwargs_sys['force_rescale']
  ths = kwargs_sys['ths']
  phis = kwargs_sys['phis']
  N_bins_th = kwargs_sys['N_bins_th']
  N_bins_phi = kwargs_sys['N_bins_phi']
  spacing_theta = (ths[1] - ths[0])
  spacing_phi = (phis[1] - phis[0])
  # Per-component weights of the penalty on large force-grid values (see loss).
  force_loss = kwargs_sys['force_loss']
  ini_loss = kwargs_sys['ini_loss']
  t_shift_loss = kwargs_sys['t_shift_loss']

  @jit
  def extend_mean(vecs):
    """Pad ``vecs`` with one ghost row beyond each pole along axis 0.

    Spherical coordinates are singular at the poles. Each ghost row is built
    from the phi-average of the adjacent theta row, after rotating every
    (u, v) by its phi into a common frame, and is then rotated back to every
    phi. The sign convention is mirrored at the last row. The result has two
    more rows than ``vecs``.
    """
    dx = jnp.mean(vecs[0, :, 0] * np.cos(phis) - vecs[0, :, 1] * np.sin(phis))
    dy = jnp.mean(vecs[0, :, 0] * np.sin(phis) + vecs[0, :, 1] * np.cos(phis))
    dus, dvs = dx * np.cos(phis) + dy * np.sin(phis), -dx * np.sin(
        phis) + dy * np.cos(phis)
    dx2 = -jnp.mean(vecs[-1, :, 0] * np.cos(phis) -
                    vecs[-1, :, 1] * np.sin(phis))
    dy2 = -jnp.mean(vecs[-1, :, 0] * np.sin(phis) +
                    vecs[-1, :, 1] * np.cos(phis))
    dus2, dvs2 = -dx2 * np.cos(phis) - dy2 * np.sin(phis), -dx2 * np.sin(
        phis) + dy2 * np.cos(phis)
    return jnp.concatenate([
        jnp.stack([dus, dvs], axis=1).reshape(1, -1, 2), vecs,
        jnp.stack([dus2, dvs2], axis=1).reshape(1, -1, 2)
    ],
                           axis=0)

  # Central finite differences. Theta derivatives consume the two ghost rows
  # added by extend_mean and return only the interior rows. Phi is periodic
  # (jnp.roll), and [1:-1] drops the ghost rows again.
  def d_dtheta(u):
    return (u[2:] - u[:-2]) / (2 * spacing_theta)

  def d_dphi(u):
    return (jnp.roll(u, -1, axis=1) -
            jnp.roll(u, 1, axis=1))[1:-1] / (2 * spacing_phi)

  def d2_dtheta2(u):
    return (u[2:] - 2 * u[1:-1] + u[0:-2]) / (spacing_theta**2)

  def d2_dphi2(u):
    return (jnp.roll(u, -1, axis=1) - 2 * u +
            jnp.roll(u, 1, axis=1))[1:-1] / (spacing_phi**2)

  def d2_dphitheta(u):
    dudtheta = (u[2:] - u[:-2])
    return (jnp.roll(dudtheta, -1, axis=1) -
            jnp.roll(dudtheta, 1, axis=1)) / (4 * spacing_phi * spacing_theta)

  @jit
  def force_intp(t, params, iparams):
    """Interpolate the trainable force grid onto the velocity grid at time ``t``.

    The force is linear in time between the ``disc_time + 1`` time nodes and
    bilinear in (theta, phi). The optional per-measurement
    ``iparams['rot']`` (radians) rotates the field in phi. The optional
    ``iparams['t_shift']`` (multiplied by 1000) shifts it in time.

    Returns:
      Array of shape (N_bins_th, N_bins_phi, 2), scaled by ``force_rescale``.
    """
    force_on_grid = params['force_grid']
    rot = iparams['rot'] if 'rot' in iparams else 0
    t_shift = iparams['t_shift'] * 1000 if 't_shift' in iparams else 0
    t_here = t + t_shift
    T_per_sect = t_max / disc_time
    # Index of the time section. It is clipped so that times before 0 / after
    # t_max fall into the first / last section and are linearly extrapolated.
    t_sect = jnp.floor(
        jnp.clip(jnp.array(t_here / T_per_sect), 0,
                 disc_time - 1)).astype(int)
    # Offset measured from the start of that section. This keeps the
    # interpolation continuous: t_here == t_max lands exactly on the last time
    # node instead of wrapping back to the start of the last section.
    t_in_sect = t_here - t_sect * T_per_sect
    force_at_t = (t_in_sect / T_per_sect *
                  (force_on_grid[t_sect + 1] - force_on_grid[t_sect]) +
                  force_on_grid[t_sect])

    # Fractional force-grid coordinates of every velocity-grid point.
    # Theta spans the nodes 0..disc_theta. Phi covers one full period of
    # disc_phi nodes, which is what mode='wrap' wraps over on an axis of
    # length disc_phi (assumes JAX's index-modulo-size wrapping). A rotation
    # by `rot` radians shifts phi by disc_phi * rot / (2*pi) nodes.
    theta_coords = np.linspace(0, disc_theta, N_bins_th)
    phi_coords = (np.linspace(0, disc_phi, N_bins_phi + 1)[:-1] +
                  disc_phi * rot / (2 * np.pi))
    grid = jnp.array(jnp.meshgrid(theta_coords, phi_coords, indexing='ij'))
    force_u = map_coordinates(force_at_t[:, :, 0], grid, order=1, mode='wrap')
    force_v = map_coordinates(force_at_t[:, :, 1], grid, order=1, mode='wrap')
    return jnp.stack([force_u,
                      force_v], axis=2) * force_rescale_factor

  @jit
  def eom(y, t, params, iparams, exparams):
    """Return the time derivative of the velocity field ``y['velocity']``.

    ``u`` is the theta (polar) and ``v`` the phi (azimuthal) component. The
    right-hand side contains viscous terms with the constants ``D`` and
    ``D2``, advective terms, and the interpolated active force, on a sphere
    of radius ``exparams['R']``.
    """
    vel = y['velocity']
    vel_ext = extend_mean(vel)
    u_ext = vel_ext[:, :, 0]
    v_ext = vel_ext[:, :, 1]
    D = 10
    D2 = 10
    R = exparams['R']

    # sin/cos of theta broadcast to the (N_bins_th, N_bins_phi) grid.
    sin = jnp.tile(jnp.sin(ths), (N_bins_phi, 1)).transpose()
    cos = jnp.tile(jnp.cos(ths), (N_bins_phi, 1)).transpose()

    u = u_ext[1:-1]
    v = v_ext[1:-1]
    dudP = d_dphi(u_ext)
    dvdP = d_dphi(v_ext)
    dudT = d_dtheta(u_ext)
    dvdT = d_dtheta(v_ext)
    d2udP2 = d2_dphi2(u_ext)
    d2vdP2 = d2_dphi2(v_ext)
    d2udT2 = d2_dtheta2(u_ext)
    d2vdT2 = d2_dtheta2(v_ext)
    d2udPT = d2_dphitheta(u_ext)
    d2vdPT = d2_dphitheta(v_ext)

    force_t = force_intp(t, params, iparams)

    du_dt = (D + D2) / R**2 * (d2udT2 + cos / sin * dudT - u / sin**2) + (
        D / (R * sin)**2) * (d2udP2 - 2 * cos * dvdP) + (D2 / (R**2 * sin)) * (
            d2vdPT - cos / sin * dvdP) - u / R * dudT - v / (
                R * sin) * dudP + cos / (R * sin) * v * v + force_t[:, :, 0]

    # NOTE(review): the advective terms below (+ v/(R sin) dv/dphi and
    # + cot/R u v) have the opposite sign to the usual spherical
    # Navier-Stokes form (- ..., - ...), while du_dt follows it. Confirm
    # whether this is intended.
    dv_dt = (D + D2) / R**2 * (1 / sin**2 * d2vdP2) + (D / R**2) * (
        d2vdT2 + cos / sin * dvdT - v / sin**2 + 2 * cos / sin**2 * dudP) + (
            D2 /
            (R**2 * sin)) * (d2udPT + cos / sin * dudP) - u / R * dvdT + v / (
                R * sin) * dvdP + cos / (R * sin) * v * u + force_t[:, :, 1]

    return {'velocity': jnp.stack([du_dt, dv_dt], axis=2)}

  @jit
  def loss(ys, params, iparams, exparams, ys_target):
    """Loss of one measurement, given the simulated and observed trajectories.

    The loss is the sum of four terms:
      * a per-component weighted penalty on large force-grid values;
      * the squared deviation from the observed velocity where it is not NaN,
        plus a small penalty (1/100) on the velocity where no data exist;
      * an extra initial-condition term, weighted by ini_loss / N_times;
      * a penalty on the time shift ``iparams['t_shift']`` (0 if absent).
    """
    f_loss = jnp.mean(params['force_grid']**2 *
                      force_loss[np.newaxis, np.newaxis, :])
    valid = ~jnp.isnan(ys_target['velocity'])

    target_loss = jnp.mean(
        jnp.sum(valid *
                (ys['velocity'] - jnp.nan_to_num(ys_target['velocity']))**2 +
                (~valid) * ys['velocity']**2 / 100,
                axis=-1))

    # Counters the tendency to avoid forces early on by choosing larger
    # initial conditions.
    initial_loss = jnp.mean(
        jnp.sum(
            (ys['velocity'][0] - jnp.nan_to_num(ys_target['velocity'][0]))**2,
            axis=-1)) * ini_loss / N_times
    # Same optional-key handling as force_intp: no shift means no penalty.
    t_shift = iparams['t_shift'] if 't_shift' in iparams else 0.0
    t_shift_l = t_shift**2 * t_shift_loss
    return target_loss + f_loss + initial_loss + t_shift_l

  def gen_params():
    """Return zero-initialized ``(params, iparams, exparams)``.

    ``exparams['R']`` defaults to 350 for every system.
    """
    force_grid = np.zeros((disc_time + 1, disc_theta + 1, disc_phi, 2))
    iparams = {} if kwargs_sys['N_sys'] == 0 else {
        'rot': np.zeros(kwargs_sys['N_sys']),
        't_shift': np.zeros(kwargs_sys['N_sys'])
    }
    return {
        'force_grid': force_grid
    }, iparams, {
        'R': 350 * np.ones(kwargs_sys['N_sys'])
    }

  def gen_y0():
    """Random smooth initial velocity built from spherical harmonics, degree < 4."""
    # indexing='ij' yields (N_bins_th, N_bins_phi) arrays matching `vel`.
    th_grid, phi_grid = np.meshgrid(ths, phis, indexing='ij')
    vel = np.zeros((N_bins_th, N_bins_phi, 2), complex)
    for n in range(4):
      for m in range(-n, n + 1):
        # scipy's sph_harm takes (m, n, azimuthal angle, polar angle).
        ylm = sph_harm(m, n, phi_grid, th_grid)
        vel[:, :, 0] += ylm * np.dot(np.random.randn(2), np.array([1, 1j])) / 100
        vel[:, :, 1] += ylm * np.dot(np.random.randn(2), np.array([1, 1j])) / 100
    return {'velocity': np.real(vel)}

  return eom, loss, gen_params, gen_y0, {'force': force_intp}

In [None]:
# Load the prepared, binned measurement data.
velocity_field, densities = (
    np.load(path) for path in ('../data/Zebrafish/Fields.npy',
                               '../data/Zebrafish/Densities.npy'))
# allow_pickle=True unpickles the file contents, which can execute arbitrary
# code: only load trusted data files this way.
addInfos = np.load('../data/Zebrafish/AddInfos.pickle', allow_pickle=True)

In [None]:
# Select the measurements to fit simultaneously.
# Measurement 3 is a damaged embryo and is therefore excluded.

# For demonstration, only the first three measurements are used, which keeps
# computation time and memory consumption low.
data_to_take = np.arange(3)

# To use all available measurements except the damaged one (index 3):
# data_to_take = np.array([0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12])

In [None]:
# Prepare the dataset: evaluation times, per-measurement constants, system
# settings, fit targets, initial conditions and initial guesses for the
# per-measurement parameters.
t_evals = np.arange(addInfos['N_times']) * 180.0
exparams = {'R': addInfos['Radii'][data_to_take]}
kwargs_sys = dict(
    N_sys=len(data_to_take),
    # Resolution of the trainable force grid (time, theta, phi).
    disc_time=15,
    disc_th=10,
    disc_phi=15,
    t_evals=t_evals,
    force_rescale=1e-3,
    ths=addInfos['ths'],
    phis=np.linspace(0, 2 * np.pi, addInfos['N_bins_phi'] + 1)[:-1],
    N_bins_th=addInfos['N_bins_th'],
    N_bins_phi=addInfos['N_bins_phi'],
    # Loss weights: force magnitude (per component), initial condition and
    # time shift.
    force_loss=np.array([1e-2, 5e-1]),
    ini_loss=10,
    t_shift_loss=1e-5,
)
ys_target = {'velocity': velocity_field[data_to_take, ..., :2]}
# Initial condition: the first observed frame, with missing (NaN) values set to 0.
y0_train = {'velocity': np.nan_to_num(ys_target['velocity'][:, 0])}
# Initial guesses of the azimuthal rotation (radians) of every measurement;
# the damaged embryo (index 3) has none.
rots_guess = jnp.array([
    -3.6457386, -0.75646764, 0.43805844, np.nan, -5.2802362, -0.16981766,
    -1.3853836, -5.1078386, 0.0, -4.2, -1.8, -4.6, -0.7,
])[data_to_take]
iparams_train = dict(rot=rots_guess, t_shift=np.zeros(len(data_to_take)))

# The loss function is evaluated per system, so it cannot relate the iparams of
# different systems to each other. The multi_measurement_constraint, by
# contrast, is an arbitrary term added to the loss that sees the full iparams
# vectors of all systems and can therefore constrain them relative to each
# other. Here it penalizes the sum of all time shifts. Only relative shifts
# between the systems stay free, and a common (absolute) shift of all systems
# is suppressed.
@jit
def multi_meas_constraint(ys, params, iparams, exparams, ys_target):
  """Return 1e-4 * (sum of the per-system time shifts)**2."""
  total_shift = jnp.sum(iparams['t_shift'])
  return 1e-4 * total_shift**2


kwargs_adoptODE = dict(
    N_backups=10,
    epochs=100,  # For more accurate results use 1000 epochs or more
    lr=5e-3,
    lr_ip=2e-3,
    lr_y0=1e-3,
    lr_decay=1.0,
    lr_decay_ip=1.0,
    multi_measurement_constraint=multi_meas_constraint,
)

dataset = dataset_adoptODE(
    define_2d_zebrafish, ys_target, t_evals, kwargs_sys, kwargs_adoptODE,
    exparams=exparams, y0_train=y0_train, iparams_train=iparams_train,
    # The initial guesses are also passed as the reference "true" iparams.
    true_iparams=iparams_train,
)

In [None]:
# Train on the prepared dataset and unpack the four returned values.
(params_final, losses, errors,
 params_history) = train_adoptODE(dataset)

In [None]:
# Visualize the fitted force grid. Left: phi component at time node 4 over
# (theta, phi), rolled by 4 nodes in phi. Right: theta component averaged over
# phi, for the interior time nodes.
fig, ax = plt.subplots(1, 2, figsize=(7, 3))

force = dataset.params_train['force_grid']

# Raw strings: "\p" is not a valid Python escape sequence (SyntaxWarning on
# Python 3.12+). The resulting label text is unchanged.
theta_labels = ['0', r'$\frac{\pi}{3}$', r'$\frac{2\pi}{3}$', r'$\pi$']
y_tick_loc1 = np.linspace(-0.5, 10.5, 4)

ax[0].imshow(np.roll(force[4, ..., 1], 4, 1), cmap='coolwarm',
             norm=SymLogNorm(1e-4), aspect=15/11)
ax[0].set_yticks(y_tick_loc1)
ax[0].set_yticklabels(theta_labels)
# Four evenly spaced ticks, i.e. thirds of the full azimuthal range 0..2pi.
ax[0].set_xticks(np.linspace(0, 14.5, 4))
ax[0].set_xticklabels(['0', r'$\frac{2\pi}{3}$', r'$\frac{4\pi}{3}$', r'$2\pi$'])
ax[0].set_ylabel(r'$\theta$ (polar)')
ax[0].set_xlabel(r'$\phi$ (azimuthal)')
ax[0].set_title(r"$\phi$ Force at 5.5 hpf")

ax[1].imshow(np.mean(force[1:-1, ..., 0], axis=2).T, cmap='coolwarm',
             norm=SymLogNorm(1e-4), aspect=14/11)
ax[1].set_xticks(np.linspace(0, 13, 6))
ax[1].set_xticklabels([f"{t:.1f}" for t in np.linspace(5 + 9/60, 7.5 - 9/60, 6)])
ax[1].set_yticks(y_tick_loc1)
ax[1].set_yticklabels(theta_labels)
ax[1].set_ylabel(r'$\theta$ (polar)')
ax[1].set_xlabel('Hours past fertilization')
ax[1].set_title(r"$\theta$ Force");