In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
from matplotlib import gridspec
rc('text', usetex=False)
plt.rcParams.update({'font.size': 8})
import scipy
from scipy.interpolate import Rbf, interp1d, griddata
from scipy.signal import find_peaks
from scipy.misc import derivative
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares
import os
import time

In [None]:
def fun(x, gamma, alpha, n):
    """Build the right-hand side of the advected three-gene repressor model.

    On the 1-D mesh ``x`` each species obeys

        dp_i/dt = alpha / (1 + p_j**n) - gamma*p_i - mu*p_i - vel*dp_i/dx

    with p1 repressed by p3, p2 by p1 and p3 by p2, growth (dilution) rate
    ``mu = exp(-x)`` and velocity ``vel = 0.5*(1 - exp(-x))``.  The spatial
    derivative is a first-order backward difference; the first node uses the
    forward difference, which equals the backward difference at node 1.

    Parameters
    ----------
    x : 1-D array
        Mesh coordinates; assumed uniformly spaced (the mean spacing is used).
    gamma : float
        Degradation rate.
    alpha : float
        Maximal production rate.
    n : float
        Hill coefficient of repression.

    Returns
    -------
    callable
        ``step(t, y)``: ``y`` holds ``len(x) * 3`` values reshapeable to
        ``(len(x), 3)`` (species on the last axis); returns the flattened time
        derivative.  ``t`` is unused (autonomous system) but kept so the
        function can be passed to ``solve_ivp``.

    Raises
    ------
    ValueError
        If the mesh has fewer than two points (no spacing can be defined).
    """
    x = np.asarray(x, dtype=float)
    if x.size < 2:
        raise ValueError("fun needs a mesh with at least two points")
    nx = x.size
    # Column vectors so they broadcast across the three species.
    mu = np.exp(-x)[:, None]
    vel = 0.5 * (1 - np.exp(-x))[:, None]
    dx = np.mean(np.diff(x))

    def step(t, y):
        p = np.asarray(y, dtype=float).reshape((nx, 3))

        # Backward difference everywhere except node 0, which takes the
        # forward difference (identical to the node-1 backward difference).
        dpdx = np.empty_like(p)
        dpdx[1:] = np.diff(p, axis=0) / dx
        dpdx[0] = dpdx[1]

        # Cyclic repression: column i is repressed by column i-1 (p1<-p3, p2<-p1, p3<-p2).
        repressor = np.roll(p, 1, axis=1)
        dpdt = alpha / (1 + repressor**n) - gamma * p - mu * p - vel * dpdx
        return dpdt.ravel()

    return step

In [None]:
def make_kymo_euler(alpha, gamma, n, nx, t0, tmax):
    """Integrate the model with explicit (forward) Euler and return a kymograph.

    The mesh has ``nx`` nodes on ``[0, (tmax - t0) / 2]``.  The time step is
    half the mesh spacing.  Species 2 starts at 5 everywhere; the others start at 0.

    Parameters
    ----------
    alpha, gamma, n : float
        Model parameters forwarded to ``fun``.
    nx : int
        Number of mesh nodes.
    t0, tmax : float
        Start and end of the integration window.

    Returns
    -------
    ndarray, shape (nx, nt, 3)
        Concentrations indexed by (position, time step, species), where
        ``nt = int((tmax - t0) // dt)``.

    Raises
    ------
    ValueError
        If the window is too short to hold even the initial time point.
    """
    L = (tmax - t0) / 2
    x = np.linspace(0, L, nx)
    dx = np.diff(x).mean()
    # The velocity in fun is 0.5*(1 - exp(-x)) < 0.5, so dt = dx/2 keeps
    # vel*dt/dx below 0.25.
    dt = dx * 0.5
    nt = int((tmax - t0) // dt)
    if nt < 1:
        raise ValueError("time window too short for a single Euler step")
    dydt = fun(x, gamma, alpha, n)

    y = np.zeros((nx, 3, nt))
    y[:, 1, 0] = 5
    for k in range(1, nt):
        # Forward Euler: the RHS is evaluated at the time of the previous state.
        t_prev = t0 + (k - 1) * dt
        increment = dydt(t_prev, y[:, :, k - 1]).reshape((nx, 3)) * dt
        y[:, :, k] = y[:, :, k - 1] + increment

    # Reorder to (position, time, species) as a fresh contiguous array.
    return y.transpose(0, 2, 1).copy()

In [None]:
def make_kymo(alpha, gamma, n, nx, nt, t0, tmax):
    """Integrate the model with ``solve_ivp`` (LSODA) and return a kymograph.

    The mesh has ``nx`` nodes on ``[0, (tmax - t0) / 2]``.  Species 2 starts
    at 5 everywhere and the others at 0.  The solution is sampled at ``nt``
    evenly spaced times in ``[t0, tmax]``.

    Parameters
    ----------
    alpha, gamma, n : float
        Model parameters forwarded to ``fun``.
    nx : int
        Number of mesh nodes.
    nt : int
        Number of output times.
    t0, tmax : float
        Integration window.

    Returns
    -------
    ndarray, shape (nx, nt, 3)
        Concentrations indexed by (position, output time, species).

    Raises
    ------
    RuntimeError
        If ``solve_ivp`` reports that the integration failed.
    """
    L = (tmax - t0) / 2
    x = np.linspace(0, L, nx)
    y0 = np.zeros((nx, 3))
    y0[:, 1] = 5
    t_eval = np.linspace(t0, tmax, nt)
    res = solve_ivp(fun(x, gamma, alpha, n), t_span=(t0, tmax), y0=y0.ravel(),
                    t_eval=t_eval, method='LSODA')
    if not res.success:
        # A failed solve returns fewer than nt columns, which would otherwise
        # show up as an unhelpful reshape error.
        raise RuntimeError(f"solve_ivp failed: {res.message}")
    # Rows of res.y follow the flattened (node, species) layout of y0.
    sol = res.y.reshape((nx, 3, nt))
    return sol.transpose(0, 2, 1).copy()

In [None]:
kymo = make_kymo_euler(alpha=1e4, gamma=2, n=2, nx=500, t0=0, tmax=48)
# Concentration profiles of the three species along the mesh at time step 40.
plt.figure()
plt.plot(kymo[:, 40, :])
plt.xlabel('Mesh node')
plt.ylabel('Concentration')
plt.legend(['p1', 'p2', 'p3']);

In [None]:
#for t in range(100):
#    sol[t*500/L:,:,t] = np.inf
def norm_kymo(kymo):
    """Scale each channel (last axis) of a kymograph so its maximum is 1.

    Works for any number of channels.  A channel whose maximum is exactly 0
    is left unscaled instead of turning into NaN through 0/0.  NaN entries
    propagate to that channel's maximum, as with ``ndarray.max``.

    Parameters
    ----------
    kymo : array_like, at least 2-D
        Kymograph with channels on the last axis, e.g. (position, time, channel).

    Returns
    -------
    ndarray of float
        A new array with the same shape; the input is not modified.
    """
    kymo = np.asarray(kymo, dtype=float)
    # Per-channel maximum over every axis except the last.
    peak = kymo.max(axis=tuple(range(kymo.ndim - 1)), keepdims=True)
    peak = np.where(peak == 0, 1.0, peak)
    return kymo / peak

In [None]:
def map_kymo(kymo):
    """Re-index a mesh-frame kymograph onto a fixed radial grid.

    At time step ``t`` the radius index is ``r = (t * nx) // nt``.  Mesh node
    ``xx < r`` is written to radial index ``r - xx``.  Every other entry is NaN.

    NOTE(review): radial index 0 is never written (it stays NaN at all times);
    this may be an off-by-one (``r - 1 - xx``?) -- confirm the intent.

    Parameters
    ----------
    kymo : array_like, shape (nx, nt, ...)
        Kymograph with mesh position on axis 0 and time on axis 1.

    Returns
    -------
    ndarray
        Same shape as ``kymo``, floating dtype (the input's dtype if it is
        already floating, else float64), NaN outside the colony.
    """
    kymo = np.asarray(kymo)
    nx, nt = kymo.shape[:2]
    dtype = kymo.dtype if np.issubdtype(kymo.dtype, np.floating) else float
    rkymo = np.full(kymo.shape, np.nan, dtype=dtype)
    for t in range(nt):
        r = (t * nx) // nt
        # Radial indices r, r-1, ..., 1 receive mesh nodes 0, 1, ..., r-1.
        rkymo[r:0:-1, t] = kymo[:r, t]
    return rkymo

In [None]:
rkymo = map_kymo(norm_kymo(kymo))
plt.figure(figsize=(60/25.4, 30/25.4))  # 60 mm x 30 mm
plt.imshow(rkymo, aspect='auto')
plt.xlabel('Time')
# Plain text: inside '$...$' mathtext drops the spaces and italicises the label.
plt.ylabel('Rad. pos.')
plt.xticks([])
plt.yticks([]);

In [None]:
def residuals(data, L, nx, nt):
    """Build a residual function for least-squares fitting of the model.

    Parameters
    ----------
    data : ndarray, shape (nx, nt, 3)
        Target kymograph, normalised with ``norm_kymo``.
    L : float
        Mesh length of the simulation.
    nx, nt : int
        Number of mesh nodes and output times (must match ``data``).

    Returns
    -------
    callable
        ``func(params)`` with ``params = (log10(alpha), gamma, n)``, returning
        ``data - model`` flattened.
    """
    def func(params):
        log_alpha, gamma, n = params
        alpha = 10**log_alpha
        # make_kymo takes a time window and uses a mesh of length
        # (tmax - t0) / 2, so a length-L mesh needs t0=0, tmax=2*L.
        model = norm_kymo(make_kymo(alpha, gamma, n, nx=nx, nt=nt, t0=0, tmax=2 * L))
        return data.ravel() - model.ravel()
    return func

In [None]:
# Synthetic target: alpha=1e4, gamma=0.3, n=2 on a length-10 mesh with
# 200 nodes and 100 output times (make_kymo's mesh length is tmax/2 when t0=0).
data = norm_kymo(make_kymo(1e4, 0.3, 2, nx=200, nt=100, t0=0, tmax=2 * 10))
res = least_squares(
    residuals(data, 10, 200, 100),
    [0, 0.3, 2],  # initial guess for (log10(alpha), gamma, n)
    method='lm',  # Levenberg-Marquardt; does not support bounds
)

In [None]:
a, g, n = res.x
a = 10**a  # the fit is parameterised by log10(alpha)
print(a, g, n)
# Same grid as the target: length-10 mesh (tmax = 2 * 10), 200 nodes, 100 times.
k = make_kymo(a, g, n, nx=200, nt=100, t0=0, tmax=2 * 10)
plt.figure()
plt.imshow(norm_kymo(k), aspect='auto')
plt.title('Fitted model')
plt.figure()
plt.imshow(data, aspect='auto')  # data is already normalised
plt.title('Target data');

In [None]:
from numpy.fft import fft2, fftshift
# map_kymo leaves NaN outside the colony, and a single NaN makes the whole
# FFT NaN, so zero-fill those positions first.
frkymo = fftshift(fft2(np.nan_to_num(rkymo[:, :, 0], nan=0.0)))

In [None]:
# fftshift moves the zero-frequency term to index N // 2 on each axis; show
# a +/-HALF window around it (the grid is not square, so derive both centres).
HALF = 20
cy, cx = (size // 2 for size in frkymo.shape)
plt.imshow(np.absolute(frkymo[max(cy - HALF, 0):cy + HALF, max(cx - HALF, 0):cx + HALF]));

In [None]:
def compute_wave_length(unmapped_kymo, speed, tscale=1, debug=False,
                        edge_speed=0.5, prominence=0.1):
    """Estimate the wavelength from the oscillation period at mesh node 0.

    For each channel, peaks are found in the time trace
    ``unmapped_kymo[:, 0, c]``.  The mean peak spacing times ``tscale`` gives
    the period ``T``, and the wavelength is ``(speed + edge_speed) * T``.
    Channels with fewer than two peaks are skipped.

    Parameters
    ----------
    unmapped_kymo : ndarray, shape (nt, nx, channels)
        Mesh-frame kymograph with time on axis 0.  It is presumably normalised
        so that ``prominence`` is meaningful.
    speed : float
        Wave speed, e.g. from ``compute_wave_speed``.
    tscale : float
        Time per row.
    debug : bool
        Plot the kymograph and each channel's trace with its detected peaks.
    edge_speed : float
        Speed added to ``speed``.  The default 0.5 is the previously hard-coded
        value; presumably the colony-edge speed (TODO confirm).
    prominence : float
        Minimum peak prominence passed to ``find_peaks``.

    Returns
    -------
    float
        Mean wavelength over channels, or NaN if no channel has two peaks.
    """
    if debug:
        plt.figure()
        plt.imshow(unmapped_kymo, aspect='auto')

    wave_lengths = []
    for c in range(unmapped_kymo.shape[-1]):
        signal = unmapped_kymo[:, 0, c]
        pks, _ = find_peaks(signal, prominence=prominence)
        if debug:
            plt.figure()
            plt.plot(signal)
            plt.plot(pks, signal[pks], '+')
        if len(pks) > 1:
            mean_period = np.mean(np.diff(pks)) * tscale
            wave_lengths.append((speed + edge_speed) * mean_period)

    if not wave_lengths:
        # Avoid np.mean([]) and its "mean of empty slice" warning.
        return np.nan
    return np.mean(wave_lengths)

def _radial_peaks(nkymo, channel, prominence=0.1):
    """Trimmed peak positions along axis 1 for rows 1.. of one channel (NaNs dropped)."""
    rows = []
    for i in range(1, nkymo.shape[0]):
        nk = nkymo[i, 1:-1, channel]
        nk = nk[~np.isnan(nk)]  # keep only positions inside the colony
        pks, _ = find_peaks(nk, prominence=prominence)
        # Drop peaks too close to either end of the valid range.
        rows.append([p for p in pks if 1 < p < len(nk) - 5])
    return rows


def _track_speeds(pks_list, dt, tscale, rscale, max_dist=10):
    """Speed of each peak, matched to the nearest peak at or below it dt rows later (NaN if none within max_dist)."""
    speeds = []
    for i in range(len(pks_list) - dt):
        next_pks = pks_list[i + dt]
        for pk in pks_list[i]:
            # Only peaks at the same or lower position; the distance cap
            # avoids wrap-around matches.
            dists = [pk - q for q in next_pks if q <= pk and pk - q < max_dist]
            if dists:
                speeds.append(min(dists) * rscale / (dt * tscale))
            else:
                speeds.append(np.nan)
    return speeds


def compute_wave_speed(nkymo, dt=1, tscale=1, rscale=1, debug=False):
    """Estimate the mean speed of wave crests in a mapped kymograph.

    For every channel, peaks are located along axis 1 of each row except
    row 0.  Each peak is matched to the closest peak at the same or a lower
    position ``dt`` rows later, within 10 positions.  Displacements are
    converted with ``rscale / (dt * tscale)``.  Peaks are matched only within
    the same channel.

    Parameters
    ----------
    nkymo : ndarray, shape (nt, nx, channels)
        Normalised, mapped kymograph with time on axis 0 and radial position
        on axis 1 (e.g. ``map_kymo(...).transpose(1, 0, 2)``).  NaN marks
        positions outside the colony.
    dt : int
        Row offset used for tracking; must be >= 1.
    tscale : float
        Time per row.
    rscale : float
        Length per radial position.
    debug : bool
        Plot the kymograph, the detected peaks and a line of the mean slope.

    Returns
    -------
    float
        Mean of the non-NaN speeds, or NaN if 5 or fewer matches were attempted.

    Raises
    ------
    ValueError
        If ``dt < 1``.
    """
    if dt < 1:
        raise ValueError("dt must be a positive number of rows")
    if debug:
        plt.figure()
        plt.imshow(nkymo)

    speeds = []
    for channel in range(nkymo.shape[-1]):
        pks_list = _radial_peaks(nkymo, channel)
        if debug:
            for row, pks in enumerate(pks_list, start=1):
                for p in pks:
                    plt.plot(p, row, '.w')
        # Track inside the channel so rows from different channels are never paired.
        speeds.extend(_track_speeds(pks_list, dt, tscale, rscale))

    # Require enough matches for a meaningful estimate.
    if len(speeds) > 5:
        mean_speed = np.nanmean(speeds)
    else:
        mean_speed = np.nan

    if debug:
        nt, nx = nkymo.shape[:2]
        x1 = nt * mean_speed / rscale * tscale
        y1 = nx / mean_speed * rscale / tscale
        if x1 < nx:
            plt.plot([0, x1], [nt, 0], 'w')
        elif y1 < nt:
            plt.plot([0, nx], [y1, 0], 'w')

    return mean_speed

In [None]:
# Sweep alpha at fixed gamma and record the wave speed and wavelength.
T_MAX = 24           # tmax passed to make_kymo_euler (t0 = 0)
NX = 250             # mesh nodes
L_MESH = T_MAX / 2   # make_kymo_euler's mesh length is (tmax - t0) / 2
vels = []
wavelengths = []
gammas = np.linspace(0, 2, 6, endpoint=True)  # for an alternative gamma sweep
alphas = np.logspace(1, 6, 12, endpoint=True)
gamma = 0.5
for alpha in alphas:
    kymo = make_kymo_euler(alpha=alpha, gamma=gamma, n=2, nx=NX, t0=0, tmax=T_MAX)
    n_pos, n_time = kymo.shape[:2]
    tscale = T_MAX / n_time   # time per step
    rscale = L_MESH / n_pos   # length per radial index
    rkymo = map_kymo(norm_kymo(kymo))
    vp = compute_wave_speed(rkymo.transpose(1, 0, 2), dt=12,
                            tscale=tscale, rscale=rscale, debug=False)
    wavelength = compute_wave_length(norm_kymo(kymo.transpose(1, 0, 2)), vp,
                                     tscale=tscale)
    vels.append(vp)
    wavelengths.append(wavelength)
print(kymo.shape)

In [None]:
for item in (kymo.shape, vels, wavelengths):
    print(item)
# NaN mask of channel 1 of the last mapped kymograph (positions outside the colony).
plt.imshow(np.isnan(rkymo[:, :, 1]))

In [None]:
#plt.plot(gammas, vels, '.-')
# alphas comes from a base-10 logspace, so use a log10 axis that shows alpha directly.
fig, ax = plt.subplots()
ax.semilogx(alphas, wavelengths, '.-')
ax.set_xlabel(r'$\alpha$')
ax.set_ylabel('Wavelength')
ax.set_ylim([0, 6])
print(alphas)

In [None]:
# compute_wave_speed expects time on axis 0, so transpose the (radius, time,
# channel) map.  Scales come from the actual grid of the last sweep run
# (tmax=24, mesh length 24/2=12) rather than assuming 250 steps.
vp = compute_wave_speed(rkymo.transpose(1, 0, 2), dt=1,
                        tscale=24 / rkymo.shape[1], rscale=12 / rkymo.shape[0],
                        debug=False)
print(vp)

In [None]:
# compute_wave_length reads time along axis 0 and uses a prominence threshold
# meant for normalised data, while kymo is (position, time, channel) and
# unnormalised; the time step comes from the actual grid (tmax=24).
l = compute_wave_length(norm_kymo(kymo.transpose(1, 0, 2)), vp,
                        tscale=24 / kymo.shape[1], debug=False)

In [None]:
# Display the wavelength estimated by compute_wave_length.
l