## MODEL - Simulating Cavity Expansion for *Alasaadi et al.*

### Notes

- **Model**
    - Use `simple_cell_sim` particle modeling engine to simulate cavity expansion and its effect on surrounding cells
    - Place cells in ring-shaped band around a single large central "cell" (the cavity)
    - Fix outer-most ring of "cells" in place to represent the stiff vitelline membrane
    - After an initial burn-in, gradually increase the size of the cavity
    
    
- **Technical**
    - Ensure that a `Figures` folder and a `Movies` folder exist in the parent directory (outputs are written to `../Figures` and `../Movies`) before setting `save_figures` or `save_movie` to `True`

### Prep

In [None]:
### Imports

import numpy as np
import matplotlib.pyplot as plt

from itertools import combinations_with_replacement as combr
from matplotlib import animation, rcParams
from ipywidgets import interact, fixed

import sys; sys.path.insert(0, '..')
from simple_cell_sim import simulation as scsim
from simple_cell_sim import potential_funcs as pfs
from simple_cell_sim import force_funcs as ffs

In [None]:
### Overall parameters

# Saving of outputs (both switches are off by default, so nothing is written to disk)
save_figures = False  # If True, the key-time-point figure is saved to ../Figures as PNG and PDF
save_movie = False    # If True, the animation is saved to ../Movies/cavity_expansion.mp4

### Cavity expansion model

In [None]:
### Functions used to generate disk/ring-like initial conditions

def generate_ring_pts(n, r):
    """Return `n` equally spaced points on an origin-centered circle of radius `r`.

    n <- number of points
    r <- radius of ring

    The result is an array of shape (n, 2) whose columns are ordered (y, x);
    the first point lies at angle 0, i.e. at (y, x) = (0, r).
    """
    # Sample n+1 angles over the closed interval [0, 2*pi] and drop the last
    # one, which would coincide with the first point at angle 0.
    theta = np.linspace(0.0, 2.0*np.pi, n+1)[:-1]
    return np.stack([r * np.sin(theta), r * np.cos(theta)], axis=1)


def generate_eqdist_disk_pts(n_rings):
    """Generate approximately equidistant points inside a unit disk.

    Based on holoborodko.com/pavel/2015/07/23/generating-equidistant-points-on-unit-disk.

    The points are arranged on `n_rings` concentric rings whose radii are
    linearly spaced from 0 to 1. The innermost "ring" is the single point at
    the origin, so `n_rings=1` yields only the origin.

    n_rings <- number of rings including the origin (positive int or NumPy integer)

    Returns an array of shape (n_points, 2) with columns ordered (y, x); the
    origin is always the first row.

    Raises ValueError if `n_rings` is not a positive integer.
    """

    # Guard (NumPy integers are accepted too, e.g. values taken from arrays)
    if (not isinstance(n_rings, (int, np.integer))) or (n_rings <= 0):
        raise ValueError("n_rings must be a positive integer.")

    # Prep: ring radii, and the origin as the first "ring"
    radii = np.linspace(0.0, 1.0, n_rings)
    rings = [np.zeros((1, 2))]

    # For each step away from the center...
    for k, r in enumerate(radii[1:], start=1):

        # Get number of points at that distance (k is the ring index; the
        # point count depends only on k, not on the actual radius)
        n = int(np.round(np.pi / np.arcsin(1.0 / (2*k))))

        # Get point positions
        rings.append(generate_ring_pts(n, r))

    # Done (concatenate once instead of growing the array ring by ring)
    return np.concatenate(rings, axis=0)


# Testing: visual check of the generated disks for n_rings = 1..5 (one panel each)
fig, ax = plt.subplots(1, 5, figsize=(9, 2), sharex=True, sharey=True)
for panel, n_test in zip(ax, range(1, 6)):
    disk_pts = generate_eqdist_disk_pts(n_test)
    ys, xs = disk_pts.T  # Columns are ordered (y, x)
    panel.scatter(xs, ys, s=10)
    panel.set_aspect('equal', 'box')
    panel.set_title('n_rings={:d}'.format(n_test))
plt.tight_layout()
plt.show()

In [None]:
### Defining the initial condition

# Build a unit disk of points and scale it so the outermost ring has radius n_rings
n_rings = 9
pos = n_rings * generate_eqdist_disk_pts(n_rings)

# Keep only the origin and the outer rings (i.e. remove the inner rings)
radii = np.linalg.norm(pos, axis=1)
keep  = (radii == 0.0) | (radii > (n_rings - 4))
pos, radii = pos[keep], radii[keep]
n_cells = pos.shape[0]

# Set cell identities: 1 = cells, 2 = outer "cells" (vitelline membrane), 0 = cavity
cell_states = np.where(np.isclose(radii, n_rings), 2, 1)
cell_states[0] = 0  # The origin is the first point and acts as the cavity

# Show result: positions colored by identity (left), pairwise distance histogram (right)
fig, ax = plt.subplots(1, 2, figsize=(6, 3))
scatter_ax, hist_ax = ax
scatter_ax.scatter(pos[:, 1], pos[:, 0], c=cell_states, s=10)
scatter_ax.axis('equal')
scatter_ax.set_xlabel('x')
scatter_ax.set_ylabel('y')
hist_ax.hist(scsim.get_dists(pos)[-1].ravel(), bins=20, alpha=0.7)
hist_ax.set_xlabel('dist')
hist_ax.set_ylabel('count')
plt.tight_layout()
plt.show()

In [None]:
### Setup of other parameters

# Simulation params
steps   = 10000
delta_t = 0.001

# General model params
rnd_stdev = 0.1  # Sigma of normal distribution of random noise forces
rnd_bound = 0.3  # Random noise forces will be within [-rnd_bound, rnd_bound]

# Cavity expansion params
cavity_dist0 =  5.5  # Initial dist0 of the cavity
cavity_distE =  7.5  # Final dist0 of the cavity
burn_in      = int(0.3*steps)  # Number of steps before cavity expansion
expansion    = int(0.5*steps)  # Number of steps during cavity expansion

# Get all combinations (pairs) of cell states (with replacement)
cell_state_pairs = csps = list(combr(np.unique(cell_states), 2))
print("Cell state pairs:", cell_state_pairs)


def _state_pair_mask(states, pairs):
    """Return a symmetric boolean (n, n) mask that is True at [i, j] if the
    states of particles i and j form one of the given (unordered) pairs."""
    pair_mask = np.zeros((states.shape[0], states.shape[0]), dtype=bool)
    for state_a, state_b in pairs:
        pair_mask |= (states == state_a)[:, None] & (states == state_b)[None, :]
        pair_mask |= (states == state_b)[:, None] & (states == state_a)[None, :]
    return pair_mask


# Prep force term dict
# NOTE: Each force term is a *list* (not a tuple) because the simulation loop
#       updates the cavity's dist0 (index 1) and max_range (index 3) in place.
force_terms = {}

# Define cavity compression force term
# (applies to every state pair that involves the cavity, i.e. state 0; selected
#  by membership rather than by position in `csps` so it does not depend on
#  how many cell states there are)
state_mask = _state_pair_mask(cell_states, [csp for csp in csps if 0 in csp])
force_terms['cavity'] = [
    ffs.f_Hooke,           # force_func
    [cavity_dist0, 50.0],  # force_params (here [dist0, k])
    0.0,                   # min_range
    cavity_dist0,          # max_range (here dist0)
    state_mask,            # state_mask
    None,                  # rnd_stdev
    None,                  # rnd_bound
]

# Define all cellular force terms (uniformly)
# (applies to every state pair that does not involve the cavity)
cell_dist0    = 1.2
cell_maxRange = 2.0
state_mask = _state_pair_mask(cell_states, [csp for csp in csps if 0 not in csp])
force_terms['cells'] = [
    ffs.f_anharmonic,                    # force_func
    [cell_dist0, -0.4, 2.0, 6.0, 3.0],   # force_params (here [dist0, pot0, m, e1, e2])
    0.0,                                 # min_range
    cell_maxRange,                       # max_range (here cell_maxRange, not dist0)
    state_mask,                          # state_mask
    rnd_stdev,                           # rnd_stdev
    rnd_bound,                           # rnd_bound
]

# Report force terms
for force_key, force_term in force_terms.items():
    print('\nForce term %s:' %force_key)
    for ft in force_term:
        print(' ', ft)

# Plot potentials and force terms
fig, ax = plt.subplots(1, 2, figsize=(8,2))
for force_key, force_term in force_terms.items():
    force_func, force_params, min_range, max_range = force_term[:4]
    # Look up the potential matching the force function by name
    # (f_<name> -> pot_<name>); only the leading 'f' is replaced.
    pf = getattr(pfs, force_func.__name__.replace('f', 'pot', 1))
    dists = np.linspace(min_range, max_range, 200)
    ax[0].plot(dists, pf(dists, *force_params), label='force_term '+force_key, alpha=0.8)
    ax[1].plot(dists, force_func(dists, *force_params), alpha=0.8)
ax[0].set_xlabel('dist'); ax[1].set_xlabel('dist')
ax[0].set_ylabel('potential'); ax[1].set_ylabel('force'); 
ax[0].legend(fontsize=6)
ax[0].set_ylim(-1, 1); ax[1].set_ylim(-1, 1)
plt.tight_layout()
plt.show()

In [None]:
### Run simulation

import warnings

# Prep outputs
pos_out    = np.empty((steps+1,) + pos.shape)  # Positions at every time point (incl. initial)
pos_out[0] = pos
force_out  = np.empty((steps,) + pos.shape)    # Second output of scsim.timestep for every step
time_pts   = np.zeros(steps+1)

# Prep cavity schedule: dist0 stays at cavity_dist0 during the burn-in, then
# grows by one increment per step for exactly `expansion` steps, so that it
# reaches cavity_distE (unless the run ends before the expansion is complete)
cavity_increment = (cavity_distE - cavity_dist0) / expansion
increments_done  = np.clip(np.arange(steps+1) - burn_in, 0, expansion)
cavity_dist0s    = cavity_dist0 + cavity_increment * increments_done

# Reset the cavity force term, which is modified in place during the run;
# without this, re-running this cell would start from the expanded cavity
force_terms['cavity'][1][0] = cavity_dist0
force_terms['cavity'][3]    = cavity_dist0

# Cavity and outer "cells" (vitelline membrane) are held in place
fixed_cells = (cell_states == 0) | (cell_states == 2)

# Run the simulation (any warning raised during the run is turned into an error)
with warnings.catch_warnings():
    warnings.simplefilter("error")
    
    for step in range(steps):
        
        # Get updated positions
        pos_out[step+1], force_out[step] = scsim.timestep(
            pos_out[step], list(force_terms.values()), delta_t)
        time_pts[step+1] = time_pts[step] + delta_t
        
        # Force cavity and outer "cells" (vitelline membrane) to remain stationary
        pos_out[step+1, fixed_cells] = pos_out[step, fixed_cells]
        
        # Apply the scheduled cavity size for the next step
        # (both the dist0 parameter and the max_range of the cavity term)
        force_terms['cavity'][1][0] = cavity_dist0s[step+1]
        force_terms['cavity'][3]    = cavity_dist0s[step+1]

In [None]:
### Visualize results

@interact(step=(0, steps-1, 1))
def show_model(step=0):
    """Scatter plot of all particle positions at time step `step`, colored by cell state."""
    
    # Columns of the position array are ordered (y, x)
    ys, xs = pos_out[step].T
    
    plt.figure(figsize=(6,6))
    plt.scatter(xs, ys, c=cell_states, cmap='viridis', s=10, zorder=100)
    
    # Optional: force_out[step] could be overlaid here with plt.quiver
    # (force_out has one entry per step, hence the slider limit of steps-1)
    
    plt.title('step %i - time %.2f' % (step, step * delta_t))
    plt.show()
    
# Show cavity_dist0 over the course of the simulation
_, dist_ax = plt.subplots(figsize=(7, 2))
dist_ax.plot(cavity_dist0s, lw=1)
dist_ax.set_xlabel('time step')
dist_ax.set_ylabel('cavity_dist0')
plt.show()

# ->> Works nicely!

### Optimized visualization

In [None]:
### Simple measure of cell density

# Get inner cell tissue area based on expanding cavity
# (The outer radius is the largest distance of any particle from the cavity
#  center at t=0; using the largest y coordinate instead would underestimate
#  it whenever no particle lies exactly on the y axis.)
cavity_center = pos_out[0, cell_states==0][0]
outer_radius = np.linalg.norm(pos_out[0] - cavity_center, axis=1).max()
outer_area   = np.pi * outer_radius**2.0
cavity_radii = cavity_dist0s - cell_dist0/2  # Cavity dist0 minus half the cell dist0
cavity_areas = np.pi * cavity_radii**2.0
inner_areas  = outer_area - cavity_areas

# Measure cell density by dividing by inner cell tissue area
# (one value per time point; only particles with state 1 are counted)
densities_div = (cell_states==1).sum() / inner_areas

In [None]:
### Stylized visualization function

def show_model_stylized(step, ax=None, show_quadrant_only=False,
                        show_time=True, show_rchange=True, 
                        show_density=True, show_ticklabels=True):
    """Draw a stylized view of the simulation state at time point `step`.

    The cavity is drawn as a red disk, the outer "cells" (vitelline membrane)
    as a closed dotted line, and the cells as gray circles that are connected
    by a line if they are closer than `cell_maxRange`.

    Reads the simulation results and parameters from the notebook's global
    namespace (`pos_out`, `cell_states`, `cavity_dist0s`, `cavity_dist0`,
    `cell_dist0`, `cell_maxRange`, `delta_t`, `densities_div`).

    step               <- index of the time point to show
    ax                 <- axes to draw into (cleared first); if None, a new
                          5x5 figure is created
    show_quadrant_only <- if True, show only the top-right quadrant
    show_time          <- if True, add step and time to the title
    show_rchange       <- if True, add the cavity expansion factor to the title
    show_density       <- if True, add the cell density (in % of the initial
                          density) to the title
    show_ticklabels    <- if False, hide tick labels and draw inset ticks on
                          all four sides
    """
    
    # Prep
    if ax is None:
        plt.figure(figsize=(5,5))
        ax = plt.gca()
    
    # Clear previous plot, if any (needed for anim)
    ax.cla()
    
    # Cavity (center is unpacked into scalars; positions are stored as (y, x))
    cavity_radius = cavity_dist0s[step] - cell_dist0/2
    cavity_y, cavity_x = pos_out[step, cell_states==0][0]
    ax.add_patch(
        plt.Circle(
            (cavity_x, cavity_y),
            radius=cavity_radius, facecolor='#AC0003', alpha=0.6
        )
    )   
    
    # Outer "cells" (vitelline membrane); first point is repeated to close the loop
    outer_rdy = pos_out[step, cell_states==2, :]
    outer_rdy = np.concatenate([outer_rdy, outer_rdy[0, :].reshape(1, -1)])
    ax.plot(outer_rdy[:, 1], outer_rdy[:, 0], '.-', color='0.4', ms=5, lw=1)
    
    # Cells
    for cell in pos_out[step, cell_states==1]:
        ax.add_patch(
            plt.Circle(
                (cell[1], cell[0]), radius=cell_dist0/2.0 * 0.75, 
                facecolor='lightgray', edgecolor='0.3', alpha=1.0
            )
        )
        
    # Add connections between cells within max_range
    # (only the upper triangle of the distance matrix is used, so that every
    #  pair is drawn exactly once and no cell is connected to itself)
    _, _, dists = scsim.get_dists(pos_out[step])
    is_cell = cell_states == 1
    linked  = (dists < cell_maxRange) & is_cell[:, None] & is_cell[None, :]
    for c, n in np.argwhere(np.triu(linked, k=1)):
        ax.plot(
            [pos_out[step, c, 1], pos_out[step, n, 1]], 
            [pos_out[step, c, 0], pos_out[step, n, 0]],
            lw=1.0, c='0.3', alpha=0.8, zorder=-1)
    
    # Remove tick labels (& adjust the ticks)
    if not show_ticklabels:
        
        # No labels
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        
        # Ticks inset & added on right/top
        ax.tick_params(axis="both", direction="in")
        ax.xaxis.set_ticks_position('both')
        ax.yaxis.set_ticks_position('both')    
    
    # Set axis aspect and ranges
    # (ranges are taken over all time points so the view is the same for every step)
    ax.set_aspect('equal', 'box')
    xmin, xmax = pos_out[:, :, 1].min(), pos_out[:, :, 1].max()
    ymin, ymax = pos_out[:, :, 0].min(), pos_out[:, :, 0].max()
    if show_quadrant_only:
        ax.set_xlim(0.0, xmax + 0.01 * (xmax-xmin))
        ax.set_ylim(0.0, ymax + 0.01 * (ymax-ymin))
    else:
        ax.set_xlim(xmin - 0.1 * (xmax-xmin), xmax + 0.1 * (xmax-xmin))
        ax.set_ylim(ymin - 0.1 * (ymax-ymin), ymax + 0.1 * (ymax-ymin))
    
    # Set title
    title_str = ''
    if show_time:
        title_str += 'step: %i; time: %.2f\n' % (step, step * delta_t)
    if show_rchange:
        # Current cavity radius relative to the initial cavity radius
        cavity_radius0 = cavity_dist0 - cell_dist0/2
        expand_factor  = cavity_radius / cavity_radius0
        title_str += 'expansion factor: %.2f\n' % expand_factor
    if show_density:
        title_str += 'cell density: %.0f%%' % (densities_div[step] * 100 / densities_div[0])
    ax.set_title(title_str, fontsize=10)

In [None]:
### Stylized visualization of key time points

# Optional alternative showing top-right quadrant only
show_quadrant_only = sqo = False
#show_quadrant_only = sqo = True

# One panel each for the end of the burn-in, the middle of the expansion
# phase, and the last time point
fig, ax = plt.subplots(1, 3, figsize=(10, 4), sharex=True, sharey=True)
key_steps = (burn_in, burn_in+expansion//2, steps)
for key_step, panel in zip(key_steps, ax):
    show_model_stylized(key_step, panel, sqo)

# Cosmetics
plt.tight_layout()

# Saving (PNG and PDF; the quadrant view gets its own file names)
if save_figures:
    if sqo:
        png_path = '../Figures/cavity_expansion_quad.png'
        pdf_path = '../Figures/cavity_expansion_quad.pdf'
    else:
        png_path = '../Figures/cavity_expansion.png'
        pdf_path = '../Figures/cavity_expansion.pdf'
    plt.savefig(png_path, dpi=300)
    plt.savefig(pdf_path, transparent=True)
    
# Done
plt.show()

In [None]:
### Interactive stylized visualization

# No figure is created here: because ax is fixed to None, show_model_stylized
# opens its own figure on every call, so a figure made in this cell would stay
# empty. The slider covers all steps+1 stored time points (0..steps).
interact(
    show_model_stylized, step=(0, steps, 1), 
    ax=fixed(None), show_time=fixed(True), show_ticklabels=fixed(True))
plt.show()

### Write out as movie

In [None]:
### Generate & save movie

if save_movie:

    # Render at a higher resolution only while the movie is generated.
    # rc_context restores the previous rcParams when the block is left, even
    # if saving fails, instead of resetting figure.dpi to a hard-coded value.
    with plt.rc_context({'figure.dpi': 200}):

        # Initialize figure
        fig = plt.figure(figsize=(5,5))
        ax = fig.gca()

        # Animation function (redraws the whole axes for time point `tp`)
        def animate(tp):
            show_model_stylized(
                tp, ax, show_time=False, show_ticklabels=False)
            return (fig,)

        # Construct animation
        # (skips the first 4/5 of the burn-in and uses every 6th time point)
        subsampling = 6
        cut_burn_in = int(4/5 * burn_in)
        anim = animation.FuncAnimation(
            fig, animate, 
            frames=range(cut_burn_in, steps+1, subsampling), 
            interval=1, blit=True)

        # Initialize writer
        # (alternative: animation.writers['ffmpeg'])
        Writer = animation.writers['ffmpeg_file']
        writer = Writer(fps=60, metadata=dict(artist='JMHartmann'))

        # Write movie
        # (alternative target: '../Movies/cavity_expansion.avi')
        anim.save('../Movies/cavity_expansion.mp4', writer=writer)
    
else:
    print("Movie saving disabled.")