# 2023-05-16 d7-glucose uptake

In [None]:
from dask.distributed import Client

client = Client(processes=False)
client

In [None]:
import glob
import re
import os
from pathlib import Path
from aicsimageio.readers import TiffGlobReader

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from mpl_interactions import hyperslicer
import napari
import microutil as mu
import scipy.ndimage as ndi
from scipy.stats import linregress
from scipy.spatial import cKDTree
from skimage.morphology import disk
from skimage.draw import disk as draw_disk
from skimage.registration import phase_cross_correlation
from skimage.filters import sobel
from tqdm.autonotebook import tqdm
from skimage.registration import phase_cross_correlation
from srs_tools import BackgroundEstimator

%matplotlib widget
plt.style.use('paper')

In [None]:
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

labels_cmap = plt.cm.viridis.copy()
labels_cmap.set_under(alpha=0)

fig_path = "/Users/johnrussell/Documents/figures/2023-06-figs-paper/"
savefigs = False

## Load and preprocess data

In [None]:
# def indexer_check(idxr):
#     for col in idxr:
#         consecutive = np.arange(idxr[col].nunique())
#         found = np.sort(idxr[col].unique())
#         assert np.allclose(consecutive, found), f"Found invalid values in column {col}"

# experiment_paths = [Path(x) for x in glob.glob("/Volumes/Seagate Backup Plus Drive/2023-05-16/dh229*/")]

# experiment_paths = [e for e in experiment_paths if not str(e).endswith(".zarr")]

# full_idxr = pd.DataFrame()
# for i,e in enumerate(experiment_paths):
#     files = [x for x in glob.glob(str(e/"*.tif")) if "LUT" not in x]
#     indexer = pd.DataFrame([re.findall("(\d+)", Path(x).name) for x in files], columns=list('SLZC')).astype(int)
#     indexer['Q'] = pd.Series([re.findall("fluo|srs", Path(x).name)[0] for x in files],dtype='category').cat.codes
#     indexer['filenames'] = pd.Series(files, copy=True)

#     scale = indexer.groupby('S')["Q"].nunique().sum()

#     indexer['T']=indexer.groupby(list('SQ'))["L"].transform(lambda x: (x-x.min())//scale)
#     indexer['I'] = i
#     indexer_check(indexer[list('SQTCZ')])
#     full_idxr = pd.concat([full_idxr,indexer])

# full_idxr['C'] = full_idxr['C']+full_idxr['Q']
# full_idxr = full_idxr.drop(columns='Q')

# readers = []
# for i in range(2):
#     idxr = full_idxr.loc[full_idxr['I']==i]
#     readers.append(TiffGlobReader(idxr['filenames'], idxr[list('STCZ')]))

# ds0 = readers[1].get_xarray_dask_stack(scene_character='S')

## Metadata

In [None]:
# import xml.etree.ElementTree as ET

# meta_path = "/Volumes/Seagate Backup Plus Drive/2023-05-16/dh229/*_Properties.xml"

# files = sorted(glob.glob(meta_path))

# indexer = full_idxr.loc[full_idxr['I']==0].drop(columns='I')
# ldm_lookup = indexer.groupby('L')[['S','T']].min()

# timestamps = []
# dim_data = []
# for s,file in enumerate(files):
#     parsed = ET.parse(file)
#     ldm = int(re.findall("(\d+)_Properties",Path(file).name)[0])
    
#     for x in parsed.iter("TimeStamp"):
#         d = x.attrib
#         d["ldm"] = ldm
#         s,t = ldm_lookup.loc[ldm].values
#         d['S'] = s
#         d['T'] = t
#         timestamps.append(d)
    
#     for x in parsed.iter("DimensionDescription"):
#         d = x.attrib
#         d["ldm"] = ldm
#         dim_data.append(d)
        
# timestamps = pd.DataFrame(timestamps)
# dim_data = pd.DataFrame(dim_data).set_index(['ldm','DimID']).unstack('DimID')

# timestamps['datetime'] = timestamps.apply(
#     lambda df: pd.to_datetime(df['Date'] + " " + df['Time']) + pd.to_timedelta(int(df['MiliSeconds']), unit="ms"),axis=1)

# timestamps = timestamps.sort_values('datetime')

# timestamps.to_hdf("/Users/johnrussell/Data/2023-05-16-dh229_timestamps.h5", key='df')

# dim_data.to_hdf("/Users/johnrussell/Data/2023-05-16-dh229_dimdata`.h5", key='df')

## Analysis

In [None]:
timestamps = pd.read_hdf("/Users/johnrussell/Data/2023-05-16/2023-05-16-timestamps.h5", key='df')
timestamps['RelativeTime'] = timestamps['RelativeTime'].astype('f4')
dim_data = pd.read_hdf("/Users/johnrussell/Data/2023-05-16/2023-05-16-dimdata.h5", key='df')

times = timestamps.groupby(list("ST"))['RelativeTime'].min()/60
dt = timestamps.groupby(['S','T']).datetime.min().groupby('S').diff()
dt_min = dt.mean().total_seconds()/60

zpos = timestamps[['S','T','ldm']].set_index('ldm').join(dim_data.loc[:,('Origin','Z')].rename("Z")).groupby(['S','T']).last().astype('f4')

In [None]:
ds = xr.open_zarr("/Users/johnrussell/Data/2023-05-16/dh229.zarr/").isel(T=slice(36))

In [None]:
t_data = dt_min*ds['T'].values #+28.75 

In [None]:
srs_squash = ds.images.isel(C=2).mean('Z', dtype='f4')#.load()

In [None]:
srs_max = ds.images.isel(C=2).max('Z').astype('f4')

In [None]:
fluo = ds.images.isel(C=0).max('Z')#.load()

In [None]:
test_srs = srs_squash.isel(T=35).load().data
test_fluo = fluo.isel(T=35).load().data

In [None]:
shift = [phase_cross_correlation(s,f)[0] for s,f in zip(test_srs, test_fluo)]

In [None]:
fluo.load();
ds.fmax_labels.load();

In [None]:
aligned_fluo = xr.concat([f.shift(dict(zip('YX', s.astype(int))), fill_value=0) for f,s in zip(fluo, shift)], dim='S')
aligned_mask = xr.concat([f.shift(dict(zip('YX', s.astype(int))), fill_value=0) for f,s in zip(ds.fmax_labels, shift)], dim='S')

In [None]:
# Track cells once and cache the result: if the dataset has no `labels`
# variable yet, run btrack on each S slice of the aligned fluorescence masks,
# then append the tracked labels to the zarr store (mode='a') so later runs of
# the notebook load them instead of re-tracking.
# NOTE(review): the meaning of gogogo_btrack's positional args (config file,
# 10, scratch path) is defined in microutil — confirm there.
# NOTE(review): `ds` was opened with isel(T=slice(36)); if the store holds more
# timepoints, appending a T=36 variable may conflict — confirm.
if 'labels' not in ds:
    tracked = xr.zeros_like(aligned_mask)
    for i,s in enumerate(aligned_mask):
        # one S slice (field of view) per tracking call
        tracked.data[i] = mu.btrack.gogogo_btrack(s.data, "cell_config.json", 10, "/Users/johnrussell/tmp/tracks/h5")
    
    ds['labels'] = tracked
    ds[['labels']].to_zarr("/Users/johnrussell/Data/2023-05-16/dh229.zarr/", mode='a')
    

In [None]:
# be = BackgroundEstimator(srs_squash, tracked)
be =  BackgroundEstimator(srs_squash, ds.labels)
be.make_cv_labels()
# be.sigma_scan(n_samples=5)
# be.sigma_opt.load();
be.sigma_opt = xr.DataArray([2,66], dims=['k']) #optimum for mean projection
be.background_estimate;

# be_max =  BackgroundEstimator(srs_max, ds.labels)
# be_max.make_cv_labels()
# be_max.sigma_opt = xr.DataArray([2,66], dims=['k']) #optimum for mean projection
# be_max.background_estimate;

In [None]:
srs_bsub = srs_squash-be.background_estimate
# max_sub = srs_max-be_max.background_estimate

In [None]:
srs_bsub.load();

In [None]:
# bkgd_avg = be.background_estimate.mean(list('SYX'))
# bkgd_std = be.background_estimate.std(list('SYX'))

# plt.figure()
# plt.errorbar(bkgd_avg['T'].values, bkgd_avg, yerr = bkgd_std, fmt='o', capsize=3)

In [None]:
cell = ['S', 'CellID'] #shorthand for use in groupby

In [None]:
ds.labels.load();

In [None]:
srs_avgs = (mu.single_cell.average(ds,srs_bsub)
            .to_series()
            .unstack('T')
            .dropna(how='all')
            .stack('T', dropna=False))

# max_avgs = (mu.single_cell.average(ds,max_sub)
#             .to_series()
#             .unstack('T')
#             .dropna(how='all')
#             .stack('T', dropna=False))

In [None]:
coms = (mu.single_cell.center_of_mass(ds)
            .to_series()
            .unstack(['com','T'])
            .dropna(how='all')
            .stack('T', dropna=False))

In [None]:
areas = (mu.single_cell.area(ds)
            .to_series()
            .reorder_levels(['S','CellID','T'])
            .loc[coms.index])

In [None]:
disp = np.sqrt((coms.groupby(['S','CellID']).diff()**2).sum(1, skipna=False))

In [None]:
first_time = coms.dropna().reset_index('T').groupby(cell).first().set_index('T', append=True).index

In [None]:
disp.loc[first_time]=0

**Goal**

Select cells that have 15 consecutive observations with displacements less than 7

- A global shift happens between t=0 and t=1, so ignore that step in the check
  - otherwise, do timepoint-to-timepoint global registration and then keep that step in
- Some movements as large as 8 or 9 seem to be real
- Some of the tricky tracking errors, where the cell label walks away into the daughter cell, involve quite small displacements...

*BUT* I think the 15-consecutive-observation requirement is a more stringent cut anyway, i.e. many of the tracking errors do not last that long.

In [None]:
cell_max_disps = disp.loc[pd.IndexSlice[:,:,2:]].groupby(['S','CellID']).max().dropna()

In [None]:
fig, ax = plt.subplots(1,2,sharex=True, figsize=(9,4))
ax[0].hist(disp.loc[pd.IndexSlice[:,:,2:]].dropna().values, bins=45, range=[0,15], )
ax[0].set_title("All displacements")
ax[1].hist(cell_max_disps, bins=45, range=[0,15],)
ax[1].set_title("Cell Maximum displacements")
for a in ax:
    a.semilogy();

In [None]:
# x = disp.loc[(0, 150)]
# df = disp.loc[(0,15)]

In [None]:
def check_consecutive(x, N=15):
    """Flag observations that belong to a long unbroken run of non-NaN values.

    Parameters
    ----------
    x : pandas.Series
        One trajectory; NaN marks a missing observation.
    N : int, default 15
        Minimum number of consecutive non-NaN entries a run must contain.

    Returns
    -------
    pandas.Series of bool
        Same index as ``x``. True where ``x`` is observed and lies in a run
        of at least ``N`` consecutive non-NaN values, False elsewhere
        (including at every NaN).
    """
    observed = x.notna()
    # Every NaN bumps the counter, so the running NaN count labels each run.
    run_id = (~observed).cumsum()
    # Number of observed (non-NaN) entries in the run each position belongs to.
    run_length = observed.groupby(run_id).transform("sum")
    return run_length.ge(N) & observed

In [None]:
# Boolean mask over `disp`: True for entries that belong to a run of at least
# `nconsec` consecutive non-NaN displacements within their (S, CellID) track.
nconsec = 15
consecutive_cells = disp.groupby(['S','CellID'], group_keys=False).apply(check_consecutive, N=nconsec)

In [None]:
disp_consecutive = disp.loc[consecutive_cells].unstack('T')

In [None]:
# Histograms of each cell's largest frame-to-frame displacement, for all cells
# (left) and for cells with a long consecutive observation run (right).
fig, ax = plt.subplots(1,2,sharex=True, figsize=(6.52,3.2))
ax[0].hist(cell_max_disps.values, bins=45, range=[0,15], )
ax[0].set_title("All cells")
ax[1].hist(disp_consecutive.max(axis=1).values, bins=45, range=[0,15],)
ax[1].set_title("Continuously observed cells")
# Cell counts go on the y axis; the loop below sets the shared x label.
ax[0].set_ylabel("$N_{cells}$")
fig.suptitle("Maximum single-cell displacements")
for a in ax:
    a.semilogy()
    a.set_xlabel("Displacement (pixels)")
if savefigs:
    plt.savefig(fig_path+"max_disps.png")

In [None]:
big_jump = (disp_consecutive.loc[:,2:].groupby(['S','CellID']).max()>8).stack('T')

In [None]:
def check_labels_from_multiindex(labels, index):
    """Build a label image that keeps only the cells listed in ``index``.

    Parameters
    ----------
    labels : xarray.DataArray
        Label images, indexed positionally by field of view along the first
        dimension (``labels[s]`` / ``labels.data[s]``) and sized by ``'S'``.
    index : pandas.MultiIndex
        First level is the field of view ``S``; the remaining level holds the
        ``CellID`` values to keep. Each CellID is matched against label value
        ``CellID + 1`` (assumed: label 0 is background — TODO confirm against
        mu.single_cell).

    Returns
    -------
    xarray.DataArray
        Same shape and dtype as ``labels``. Pixels of the selected cells keep
        their label value; every other pixel is 0, including whole fields of
        view with no selected cells.
    """
    check_labels = xr.zeros_like(labels)

    for s in range(labels.sizes['S']):
        try:
            checks = index.get_loc_level(s)[1].values + 1
        except KeyError:
            # no selected cells in this field of view: leave it all zeros
            continue

        frame = labels[s]
        mask = xr.DataArray(np.isin(labels.data[s], checks), dims=frame.dims)
        # Fill unselected pixels with 0 (background), not NaN. NaN cannot be
        # stored in an integer label array and would be cast to a garbage value
        # on assignment into `check_labels`.
        check_labels[s] = frame.where(mask, 0)
    return check_labels

In [None]:
jump_idx = big_jump.loc[big_jump].reset_index('T').index.unique()

In [None]:
# check = check_labels_from_multiindex(ds.labels, jump_idx)

# v = napari.Viewer()
# il= v.add_image(srs_bsub)
# fl = v.add_image(aligned_fluo)
# # ll = v.add_labels(ds.labels)
# cl = v.add_labels(check)

In [None]:
#big_jump_cell_times = big_jump.loc[big_jump].reset_index('T')['T']

In [None]:
# This doesnt work any more

# srs_with_jumps = srs_avgs.loc[consecutive_cells].unstack('T').loc[jump_idx]
# pts = srs_with_jumps.loc[jump_idx].max(1)

# fig, ax = plt.subplots()
# ax.plot(srs_with_jumps.values.T, 'k', linewidth=1, alpha=0.2);
# ax.plot(jump_idx.get_level_values('T'), pts.values, 'rx')
# plt.ylim(-3, 38);

In [None]:
no_jumps = big_jump.groupby(cell).sum()==0
# no_jump_idx = no_jumps.loc[no_jumps].index

In [None]:
consecutive_coms = coms.loc[consecutive_cells]
central_cells = ((consecutive_coms['X'].between(5,507))&(consecutive_coms['Y'].between(5,507))).groupby(cell).all()

In [None]:
srs_clean = srs_avgs.loc[consecutive_cells].unstack('T').loc[no_jumps&central_cells]
# max_clean = max_avgs.loc[consecutive_cells].unstack('T').loc[no_jumps&central_cells]
# raw_clean = raw_avgs.loc[consecutive_cells].unstack('T').loc[no_jumps&central_cells]

In [None]:
consecutive_cells.unstack('T').sum().max()

In [None]:
central_cells.sum()

In [None]:
srs_clean.shape[0]

In [None]:
# srs_clean.to_csv("/Users/johnrussell/Data/2023-05-16/2023-05-16-srs.csv")
# areas_clean.to_csv("/Users/johnrussell/Data/2023-05-16/2023-05-16-area.csv")

In [None]:
areas_clean = areas.loc[consecutive_cells].unstack('T').loc[no_jumps&central_cells]

In [None]:
# first_times_clean = areas_clean.stack('T').reset_index('T').groupby(cell).first().set_index('T', append=True).index
first = srs_clean.stack('T').reset_index('T').groupby(cell).first().set_index('T', append=True).iloc[:,0]

In [None]:
all_first = srs_avgs.reset_index('T').dropna().groupby(cell).first().set_index('T', append=True).iloc[:,0]

In [None]:
plt.figure()
plt.plot(t_data, areas_clean.values.T, 'k', alpha=0.025);
# plt.errorbar(t_data, areas_clean.median().values, yerr=(areas_clean.quantile([0.25,0.75])-areas_clean.median()).abs().values, fmt='mo', capsize=3)
plt.ylim([0,400])
plt.title("Single Cell Area Trajectories")
plt.ylabel("Area (pixels)")
plt.xlabel("Time (Minutes)")
if savefigs:
    plt.savefig(fig_path+"cell_areas.png");

In [None]:
growth_cells = areas_clean.loc[(areas_clean.stack('T').loc[first.index]<100).droplevel("T")]

In [None]:
# Area trajectories restricted to cells whose area at first observation was
# below 100 px (`growth_cells`).
plt.figure()
plt.plot(t_data, growth_cells.values.T, 'k', alpha=0.025);
# plt.errorbar(t_data, areas_clean.median().values, yerr=(areas_clean.quantile([0.25,0.75])-areas_clean.median()).abs().values, fmt='mo', capsize=3)
plt.ylim([0,400])
plt.title("Single Cell Area Trajectories (initial area < 100 px)")
plt.ylabel("Area (pixels)")
plt.xlabel("Time (Minutes)")
if savefigs:
    # Distinct filename: "cell_areas.png" is already written by the all-cells figure.
    plt.savefig(fig_path+"growth_cell_areas.png");

**Note:** `fit_all_exponentials` should perhaps take arguments for a time offset and for whether to use absolute or relative t values.

In [None]:
# srs_jumps = (srs_clean.diff(axis=1).abs()>15).any(1)
# check_idx = srs_jumps.loc[srs_jumps].index

In [None]:
# time_count = srs_avgs.groupby(['S','CellID']).count()

# present = (~srs_avgs.unstack(['S','CellID']).isna()).astype(int)
# appears_once = (present.diff()==1).astype(int).sum()<2
# disappears_once = (present.diff()==-1).astype(int).sum()<2

# keep = srs_avgs.unstack('T').loc[(appears_once & disappears_once & (time_count>15)), :35]

In [None]:
fig, ax = plt.subplots(1,5, sharex=True, sharey=True, figsize=(6.52,3))
for s in range(5):
    traces = ax[s].plot(srs_clean.loc[s].values.T, 'k', linewidth=1, alpha=0.05)
    ax[s].plot(5*(zpos.loc[s]-zpos.loc[s].mean())-10, '.-', label='Scaled Z')
    
ylim = ax[0].get_ylim()
ax[0].plot(20,100,'k', label='Cell Traces')
ax[0].set_ylim(ylim)
ax[0].legend();
plt.suptitle("\"Sawtoothing\" corresponds with changes in Z")
if savefigs:
    plt.savefig(fig_path+"sawtooth_abs_z.png")

In [None]:
# fig, ax = plt.subplots(1,5, sharex=True, sharey=True, figsize=(10,3))
# for s in range(5):
#     traces = ax[s].plot(max_clean.loc[s].values.T, 'k', linewidth=1, alpha=0.05)
#     ax[s].plot(5*(zpos.loc[s]-zpos.loc[s].mean())-10, '.-', label='Scaled Z postion')
    
# ylim = ax[0].get_ylim()
# ax[0].plot(20,100,'k', label='Cell Traces')
# ax[0].set_ylim(ylim)
# ax[0].legend()
# plt.suptitle("Sawtooth seems to correspond to changes in Z (max projections)")

In [None]:
savefigs

In [None]:
# Frame-to-frame change in mean SRS per field of view (top) next to the
# frame-to-frame change in Z position per field of view (bottom).
dz = zpos.groupby('S').diff()['Z'].loc[(slice(None), slice(srs_clean.columns.max()))]
dsrs = srs_clean.diff(axis=1)#/dt_min

fig, ax = plt.subplots(2,1, sharex=True, )
for s in range(5):
    ax[0].plot(t_data, dsrs.loc[s].mean().values, color=colors[s])
    ax[1].plot(t_data, dz.loc[s].values, color=colors[s])
ax[0].plot(t_data, dsrs.mean().values, 'k--', label="Mean")
ax[0].legend(loc='lower left')
# ax[0].plot(xplt, dy, 'k')
ax[1].set_xlabel("Time (minutes)")
# Raw strings: "\D" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
ax[0].set_ylabel(r"$\Delta$ SRS (a.u.)")
ax[1].set_ylabel(r"$\Delta$ Z (µm)")
ax[0].set_title("Z-oscillations correspond with \"Sawtoothing\"")
if savefigs: plt.savefig(fig_path+"sawtooth_deltas.png")

### Additive correction - obsolete

In [None]:
# diff = srs_clean.diff(axis=1)
# test_diff = diff.loc[0]
# test_dz = dz.loc[0]

# dz_gb = diff.stack('T').unstack('CellID').groupby(dz.fillna(0)).mean()

# plt.figure()
# plt.errorbar(dz_gb.index.values, dz_gb.mean(1), yerr=dz_gb.std(1), fmt='o', capsize=3)
# plt.xlabel("$\Delta Z$")
# plt.ylabel("$\Delta SRS$")

# fig, ax = plt.subplots(3,1)
# ax[0].plot(test_diff.std().values)
# # ax[0].plot((mcorr*test_diff).std().values, label='mult. corr.')
# ax[0].set_ylabel("$\sigma_{\Delta SRS}$")

# ax[1].plot(test_dz)
# ax[1].set_ylabel("$\Delta$ Z (µm)")

# ax[2].plot(test_diff.mean().values)
# ax[2].set_ylabel("$\Delta SRS$")

# fig.suptitle("Z-changes do not change the variance")
# if savefigs:
#     plt.savefig(fig_path+"dz_var.png"))

# a,k=fit[0]
# dy = a*k*np.exp(-k*xplt)
# y_pred = exp_approach(t_data, a,k)
# dy_pred = a*k*np.exp(-k*t_data)*dt_min

# res_dsrs = (srs_clean.diff(axis=1)- dy_pred).groupby('S').mean()

# fig, ax = plt.subplots(1,2, sharey=True)
# ax[0].plot(t_data, dsrs.groupby('S').mean().values.T)
# ax[0].plot(t_data,np.zeros_like(t_data),'k')
# ax[1].plot(t_data, res_dsrs.values.T)
# ax[1].plot(t_data,np.zeros_like(t_data),'k')

# fig, ax = plt.subplots(1,2, sharex=True, sharey=True)
# ax[0].plot(test_diff.values.T, 'k', linewidth=1, alpha=0.1)
# # ax[0].plot(5*(zpos.loc[s]-zpos.loc[s].mean())-10, '.-',
# # ax[1].plot(test_corrected_diff.values.T, 'k', linewidth=1, alpha=0.1);
# ax[1].plot((test_diff - test_diff.mean()).values.T, 'k', linewidth=1, alpha=0.1);
# # ax[0].axhline(5, color='r')
# # ax[0].axhline(-5, color='b')
# # ax[1].axhline(5, color='r')
# # ax[1].axhline(-5, color='b')
# ax[0].set_title("Raw")
# ax[1].set_title("Corrected")
# ax[0].set_ylabel("$\Delta SRS$")
# ax[0].set_xlabel("Time index")
# ax[1].set_xlabel("Time index")
# ax[0].plot(5*test_dz.values-10,'-o')
# ax[1].plot(5*test_dz.values-10,'-o')
# if savefigs:
#     plt.savefig(fig_path+"test_fov_corr.png")

# fig, ax = plt.subplots(5,2, sharex=True, sharey=True)
# for s in range(5):
#     ax[s,0].plot(diff.loc[s].values.T, 'k', linewidth=1, alpha=0.02)
#     ax[s,1].plot((diff.loc[s]-diff.loc[s].mean()).values.T, 'k', linewidth=1, alpha=0.02);
#     ax[s,0].axhline(5, color='r')
#     ax[s,0].axhline(-5, color='b')
#     ax[s,1].axhline(5, color='r')
#     ax[s,1].axhline(-5, color='b')
#     ax[s,0].set_ylabel("$\Delta SRS$")
#     # ax[s,2].plot(5*dz.loc[s].fillna(0).values,'-o')
#     # ax[s,1].plot(5*dz.loc[s].values-10,'-o')
# ax[0,0].set_ylim([-10,10])
# ax[0,0].set_title("Raw")
# ax[0,1].set_title("Corrected")
# # ax[0,2].set_title("$\Delta Z$")
# ax[s,0].set_xlabel("Time index")
# ax[s,1].set_xlabel("Time index")

# diff = (srs_clean.diff(axis=1) - dy_pred)

# offset = (diff.groupby('S').transform(lambda df: df-df.mean())+dy_pred).stack('T', dropna=False)
# # offset.loc[first_times_clean] = srs_clean.stack('T').loc[first_times_clean]
# offset.loc[first.index] = first

# srs_recon = offset.unstack('T').cumsum(axis=1)

### Multiplicative correction

In [None]:
k = np.random.randn((ds.sizes['T'])).astype('f4')
k0 = np.zeros(ds.sizes['T']).astype('f4')

In [None]:
from scipy.optimize import minimize

In [None]:
def d2loss_with_reg(k, x, alpha=1):
    """Normalized second-difference loss of rescaled traces plus an L2 penalty on ``k``.

    The traces are rescaled per time point by ``1 + k``. The data term is the
    NaN-ignoring mean of the squared second difference along the last axis,
    each divided by the squared [1, 2, 1]-weighted sum of the scale factors.
    Only interior time points contribute; there are no boundary terms.

    Parameters
    ----------
    k : 1-D ndarray
        Per-time-point multiplicative corrections (the scale is ``1 + k``).
    x : ndarray
        Traces with time along the last axis; NaNs are ignored in the data term.
    alpha : float, default 1
        Weight of the ``mean(k**2)`` penalty.

    Returns
    -------
    float
        Data term plus ``alpha * mean(k**2)``.
    """
    scale = 1 + k
    curvature_sq = np.square(np.diff(scale * x, n=2, axis=-1))
    norm_sq = np.square(np.convolve(scale, [1, 2, 1], mode='valid'))
    data_term = np.nanmean(curvature_sq / norm_sq)
    return data_term + alpha * np.mean(np.square(k))

In [None]:
def d2loss_with_reg_bcs(k, x, alpha=1.):
    """Second-difference smoothness loss with one-sided boundary terms, plus an L2 penalty.

    Same objective as ``d2loss_with_reg``, except that the first and last time
    points also contribute. They use the one-sided second differences
    ``y[0] - 2*y[1] + y[2]`` and ``y[-1] - 2*y[-2] + y[-3]``, normalized by the
    matching one-sided weights of the scale factors.

    Parameters
    ----------
    k : 1-D ndarray
        Per-time-point multiplicative corrections (the scale is ``1 + k``);
        needs at least 3 entries.
    x : 2-D ndarray
        Traces with cells along axis 0 and time along axis 1; NaNs are
        ignored in the data term.
    alpha : float, default 1.
        Weight of the ``mean(k**2)`` penalty.

    Returns
    -------
    float
        NaN-ignoring mean of squared normalized second differences, plus
        ``alpha * mean(k**2)``.
    """
    s = 1 + k
    y = s * x
    d = np.zeros_like(y)
    d[:, 1:-1] = -2 * y[:, 1:-1] + y[:, 2:] + y[:, :-2]
    d[:, 0] = y[:, 0] - 2 * y[:, 1] + y[:, 2]
    # Mirror of the left-boundary stencil: it must reach back to y[-3]. Using
    # y[-2] twice would collapse this to y[-1] - y[-2] and penalize any slope.
    d[:, -1] = y[:, -1] - 2 * y[:, -2] + y[:, -3]
    d = d * d
    w = np.zeros_like(s)
    w[1:-1] = 2 * s[1:-1] + s[2:] + s[:-2]
    w[0] = s[0] + 2 * s[1] + s[2]
    w[-1] = s[-1] + 2 * s[-2] + s[-3]
    w = w * w
    return np.nanmean(d / w) + alpha * np.mean(k * k)

In [None]:
# def d1loss_with_reg(k, alpha=1):
#     d = np.diff((1+k)*test, n=1, axis=1)
#     return np.nanmean(d*d) + alpha*np.mean(k*k)

# def central_d1loss_with_reg(k, alpha=1):
#     y = np.pad((1+k)*test, 1, mode='edge')
#     d = np.roll(y,1, axis=-1)-np.roll(y,-1, axis=-1)
#     d = d[:,1:-1]
#     return np.nanmean(d*d) + alpha*np.mean(k*k)

In [None]:
# Regularization scan for the multiplicative "sawtooth" correction. For each
# field of view, fit a per-timepoint correction k by minimizing
# d2loss_with_reg_bcs over a log-spaced grid of alpha. Record the fitted k,
# the final loss and the corrected mean trace for every alpha.
na = 9
nt = srs_clean.shape[1]
ns = 5  # fields of view S = 0..4 (hard-coded)
alphas = np.logspace(-2,2, num=na)
avgs = np.zeros((ns, nt, na))   # corrected mean trace per (S, T, alpha)
kopts = np.zeros((ns, nt, na))  # fitted k per (S, T, alpha)
tloss = np.zeros((ns, na))      # final training loss per (S, alpha)

for s in tqdm(range(ns)):
    y = srs_clean.loc[s].values
    # Hold out 50 randomly chosen cells (unseeded, so the split differs each
    # run; requires >= 50 cells per field of view).
    # NOTE(review): `val` is never used below — held-out cells are excluded
    # from the fit but not evaluated.
    val_mask = np.zeros(y.shape[0],dtype='bool')
    val_mask[np.random.choice(np.arange(y.shape[0]),size=50, replace=False)]=1
    val = y[val_mask]
    train = y[~val_mask]

    for i,a in enumerate(tqdm(alphas, leave=False)):
        # k0 has length ds.sizes['T']; assumes that equals nt — TODO confirm
        res = minimize(d2loss_with_reg_bcs, k0, (train, a))
        kopts[s,:,i] = res.x
        tloss[s,i] = res.fun
        avgs[s,:,i] = np.nanmean((1+res.x)*train, axis=0)

In [None]:
viridis_colors = plt.cm.viridis(np.linspace(0,1,alphas.shape[0]))
fig, axs = plt.subplots(1,5, sharex=True, sharey=True)
for s, ax in zip(range(ns), axs):
    for x, c, a in zip(avgs[s].T, viridis_colors, alphas):
        ax.plot(x, color=c, label=f"{a:0.2e}")
    ax.plot(srs_clean.loc[s].mean().values, 'k', label='Original')
fig.suptitle('Second Derivative Correction')
# axs[-1].legend(title="Regularization parameter", loc='upper right',bbox_to_anchor=(1.35, 1.05))
if savefigs: plt.savefig(fig_path+"wt_d2_correction.png")

In [None]:
# Keep the correction fitted with alphas[4] (= 1.0 for np.logspace(-2, 2, 9)).
# Columns are labelled 0..nt-1 positionally; `(1+kopt)*srs_clean` aligns on
# these labels, so this assumes srs_clean's T labels are also 0..nt-1 — TODO confirm.
kopt= pd.DataFrame(kopts[...,4], index=pd.Index(range(kopts.shape[0]), name='S'), columns=pd.Index(range(kopts.shape[1]), name='T'))

In [None]:
# Mean SRS per field of view before (left) and after (right) applying the
# fitted multiplicative correction (1 + kopt).
fig, ax = plt.subplots(1,2,sharex=True, sharey=True,figsize=(6.52,3))
for i in range(5):
    ax[0].plot(t_data, srs_clean.loc[i].mean().values, color=colors[i])
    ax[1].plot(t_data, (1+kopt.loc[i].values)*srs_clean.loc[i].mean().values, color=colors[i])
ax[0].set_title("Raw SRS Intensity")
ax[1].set_title("SRS Intensity with smoothing")
ax[0].set_xlabel("Time (minutes)")
ax[1].set_xlabel("Time (minutes)")
ax[0].set_ylabel("SRS Intensity (a.u.)")
if savefigs: plt.savefig(fig_path+"sawtooth_corr.png")

In [None]:
# Three stacked panels per field of view: change in mean SRS, change in Z,
# and the fitted multiplicative scale 1 + k.
fig, ax = plt.subplots(3,1, sharex=True, )
for s in range(5):
    ax[0].plot(t_data, dsrs.loc[s].mean().values, color=colors[s])
    ax[1].plot(t_data, dz.loc[s].values, color=colors[s])
    ax[2].plot(t_data, 1+kopt.loc[s], color=colors[s])
ax[0].plot(t_data, dsrs.mean().values, 'k--', label="Mean")
ax[0].legend(loc='lower left')
# ax[0].plot(xplt, dy, 'k')
# With sharex the time label belongs on the bottom (third) panel.
ax[2].set_xlabel("Time (minutes)")
# Raw strings: "\D" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
ax[0].set_ylabel(r"$\Delta SRS$")
ax[1].set_ylabel(r"$\Delta$ Z (µm)")
ax[2].set_ylabel("$1 + k$")
ax[0].set_title("Z-oscillations correspond with \"Sawtoothing\"")
if savefigs:
    plt.savefig(fig_path+"wt_sawtooth_deltas_weights.png")

In [None]:
srs_recon = (1+kopt)*srs_clean

In [None]:
all_recon = (1+kopt)*(srs_avgs.unstack('T'))

In [None]:
working_dir = "/Users/johnrussell/Data/2023-05-16/"
ename = 'dh229'

### Inspect rescaled data

In [None]:
fig, ax = plt.subplots(1,2,sharex=True, sharey=True, figsize=(6.52,3.2))
ax[0].plot(t_data, srs_clean.values.T, 'k', alpha=0.01);
ax[0].set_title('Raw')

ax[1].plot(t_data, srs_recon.values.T, 'k', alpha=0.01);
ax[1].set_title('Corrected')
ax[0].set_ylabel("SRS (a.u.)")
fig.suptitle("Second Derivative Smoothing Reduces Sawtoothing")
for a in ax:
    a.set_xlabel("Time (minutes)");
    a.set_ylim([-10,40])
if savefigs:
    plt.savefig(fig_path+"saw_corr_before_after.png")

## Curve fitting

In [None]:
from scipy.optimize import curve_fit

In [None]:
def exp_approach(x, a, k):
    """Saturating exponential ``a * (1 - exp(-k * x))``.

    Equals 0 at ``x = 0`` and, for ``k > 0``, approaches the plateau ``a``.
    """
    approach = 1 - np.exp(-k * x)
    return a * approach

In [None]:
def exp_approach2(x, a, k, t):
    """Saturating exponential with a time offset: ``a * (1 - exp(-k * (x - t)))``."""
    shifted = x - t
    return a * (1 - np.exp(-k * shifted))

In [None]:
def exp_approach3(x, a, k, b):
    """Saturating exponential on a constant baseline: ``a * (1 - exp(-k * x)) + b``."""
    rise = a * (1 - np.exp(-k * x))
    return rise + b

In [None]:
def exp_decay(x, a, b, k):
    """Exponential decay ``a * exp(-k * x)``.

    ``b`` is accepted but has no effect on the result.
    NOTE(review): the baseline term ``+ b`` was commented out in the original —
    confirm whether it is meant to stay disabled.
    """
    del b  # unused: baseline offset is disabled
    return a * np.exp(-k * x)

In [None]:
df = srs_recon
# xdata = dt_min*df.columns.values + 30
y = df.mean().values
y_offset = y[0]
y = y-y_offset
sigma = df.sem().values

fit = curve_fit(exp_approach, t_data, y, sigma=sigma, p0=[20, 0.1], bounds=(0,np.inf))

a,k = fit[0]
eqn = f"Exponential fit: \n{a:0.1f} (1-exp[-{k:0.3g}t])"
xplt = np.linspace(0, 550, 250)
plt.figure()
plt.plot(xplt, exp_approach(xplt, *fit[0])+y_offset, zorder=0, linewidth=3, alpha=0.75, label=eqn)
# plt.plot(t_data, y+y_offset, 'k.', zorder=-1)
plt.errorbar(t_data, y+y_offset,yerr=sigma, fmt='k.', capsize=3, zorder=-1)
plt.title(f"Wild Type Net Glucose Uptake")
plt.legend()
plt.xlabel("Time (minutes)")
plt.ylabel("SRS Intensity (a.u.)")
if savefigs:
    plt.savefig(fig_path+"bulk_fit.png")

In [None]:
1-(1/np.e)

In [None]:
k*60

In [None]:
np.sqrt(fit[1][1,1])

In [None]:
# srs_total = srs_recon*areas_clean

# fig,ax = plt.subplots(1,2, sharex=True,figsize=(6.52, 3.2))
# ax[0].plot(t_data, srs_recon.values.T, 'k', linewidth=1, alpha=0.02)
# ax[0].set_xlabel("Time (minutes)")
# ax[0].set_ylabel("SRS Intensity (a.u.)")
# ax[0].set_title("Averages")
# ax[0].set_ylim([-10,40])
# ax[1].plot(t_data, srs_total.values.T, 'k', linewidth=1, alpha=0.02)
# ax[1].set_xlabel("Time (minutes)")
# # ax[1].set_ylabel("SRS Intensity (a.u.)")
# ax[1].set_title("Totals")
# ax[1].set_ylim([-2000, 8000])
# fig.suptitle("Single Cell Traces");
# if savefigs:
#     plt.savefig(fig_path+"sc_area_total.png")

## Single Cell Analysis

In [None]:
def fit_exp(x, y, p0):
    """Fit ``exp_approach`` to one trace; return parameters and their variances.

    Parameters
    ----------
    x, y : array-like
        Sample times and values of a single trace.
    p0 : sequence of float
        Initial guess ``(a, k)``. Both parameters are bounded to ``[0, inf)``.

    Returns
    -------
    ndarray
        ``[a, k, var_a, var_k]``: the fitted parameters followed by the
        diagonal of the covariance matrix (variances, not standard
        deviations). If ``curve_fit`` raises ``RuntimeError`` (the
        least-squares fit failed), an all-NaN array of length
        ``2 * len(p0)`` is returned instead.

    Notes
    -----
    ``absolute_sigma=True`` is passed without ``sigma``. The covariance is
    therefore computed as if every point had unit uncertainty, not rescaled by
    the residual variance.
    """
    try:
        popt, pcov = curve_fit(exp_approach, x, y, p0=p0, bounds=(0, np.inf), absolute_sigma=True)
    except RuntimeError:
        return np.full(2 * len(p0), np.nan)
    return np.concatenate((popt, np.diag(pcov)))

In [None]:
# slower to use dask delayed by ~50%
# Fit exp_approach independently to every cleaned, corrected single-cell trace.
# Each trace is shifted so its first observed point sits at (0, 0); the shifts
# are stored as x0 (time) and y0 (SRS) next to the fit results.
df = srs_recon

# sig_a / sig_k hold the diagonal of the fit covariance from fit_exp (variances).
all_srs_params = pd.DataFrame(index=df.index, columns=['a','k', 'sig_a', 'sig_k','x0','y0'], dtype='f4')

for i, (idx, s) in enumerate(tqdm(df.iterrows(), total=len(df))):
    y = s.values
    mask = ~np.isnan(y)  # drop timepoints where this cell was not observed
    y = y[mask]
    y_offset = y[0]
    y = y-y_offset
    x = t_data[mask]
    x_offset = x[0]
    x = x - x_offset
    # initial guess: parameters of the population-mean fit (`fit` from the bulk-fit cell)
    out = fit_exp(x,y,fit[0])
    all_srs_params.loc[idx] = [*out, x_offset, y_offset]

In [None]:
unconverged = all_srs_params.loc[all_srs_params.isna().any(axis=1)].index
srs_params = all_srs_params.dropna()

In [None]:
# outfile = working_dir+f"{ename}_srs_tables.h5"
# srs_clean.to_hdf(outfile, key='raw')
# srs_recon.to_hdf(outfile, key='recon')
# areas_clean.to_hdf(outfile, key='area')
# srs_params.to_hdf(outfile, key='params')

In [None]:
# fig, ax = plt.subplots(1,2,sharex=True, sharey=True, figsize=(9,4))
# ax[0].plot(t_data, srs_recon.loc[srs_params.index].values.T, 'k', alpha=0.02);
# ax[0].set_title('converged')
# ax[0].set_ylim([-10,40])

# ax[1].plot(t_data, srs_recon.loc[unconverged].values.T, 'k', alpha=0.025);
# ax[1].set_title('unconverged')

In [None]:
pcts = np.linspace(0,1, 101)
qs = srs_params.quantile(pcts)

In [None]:
plt.figure()
tc = 3
plt.plot(pcts, qs.sig_a.values, '-o', markersize=2, label='sig_a')
plt.plot(pcts, qs.sig_k.values,'-o', markersize=2,label='sig_k')
plt.legend()
plt.semilogy();

In [None]:
vqs = (srs_params['a']*srs_params['k']).quantile(pcts)

In [None]:
plt.figure()

plt.plot(pcts, qs.a.values, '-o', markersize=2, label='a')
plt.plot(pcts, qs.k.values,'-o', markersize=2,label='k')
plt.plot(pcts, qs.sig_a.values, '-o', markersize=2, label='sig_a')
plt.plot(pcts, qs.sig_k.values,'-o', markersize=2,label='sig_k')
plt.plot(pcts, vqs.values, '-o', markersize=2, label='v')
plt.legend()
plt.semilogy();

In [None]:
low_q = (srs_params<qs.iloc[2]).any(axis=1)
low_q = low_q.loc[low_q]
print(len(low_q))

In [None]:
high_q = (srs_params>qs.iloc[-2]).any(axis=1)
high_q = high_q.loc[high_q]
print(len(high_q))

In [None]:
mid_q = ((srs_params>qs.iloc[2]).all(axis=1))&((srs_params<qs.iloc[-2]).all(axis=1))
mid_q = mid_q.loc[mid_q]

In [None]:
fig, ax = plt.subplots(1,3, sharey=True)
ax[0].plot(srs_recon.loc[low_q.index].values.T, 'k', alpha=0.05)
ax[1].plot(srs_recon.loc[mid_q.index].values.T, 'k', alpha=0.05)
ax[2].plot(srs_recon.loc[high_q.index].values.T, 'k', alpha=0.05);

In [None]:
k_sig_k =np.sqrt(srs_params['sig_k'])/ srs_params['k']

In [None]:
plt.figure()
_ = plt.hist(k_sig_k.loc[np.isfinite(k_sig_k)].values, bins=100, range=(0,5))

In [None]:
# Traces grouped by the variance of the fitted k: exactly zero, above the 99th
# percentile, and within the central quantile range.
df = srs_recon.loc[srs_params.index]
fig, ax = plt.subplots(1,3,sharex=True, sharey=True, figsize=(10,3))
ax[0].plot(t_data, df.loc[srs_params.index].loc[srs_params['sig_k']==0].values.T, 'k', alpha=0.2);
ax[0].set_title(r'$\sigma_k=0$')

ax[1].plot(t_data, df.loc[srs_params['sig_k']>qs.loc[0.99, 'sig_k']].values.T, 'k', alpha=0.2);
ax[1].set_title(r'$\sigma_k -> inf$')

# Quantile range used for the selection; the title is built from the same
# values so the two cannot drift apart.
lo_q, hi_q = 0.02, 0.99
ax[2].plot(t_data, df.loc[srs_params['sig_k'].between(*qs.loc[[lo_q, hi_q],'sig_k'])].values.T, 'k', alpha=0.02);
ax[2].set_title(rf'$\sigma_k \in [{100*lo_q:.0f}\%, {100*hi_q:.0f}\%]$')

In [None]:
# Overlay traces whose fitted k has zero variance (with nonzero variance of a)
# against traces whose k variance exceeds 1.
df = srs_recon.loc[srs_params.index]
fig, ax = plt.subplots()
# l1 = plt.plot(t_data, df.loc[(srs_params['sig_k']==0)&(srs_params['sig_a']==0)].values.T, color=colors[0]);
l2=  plt.plot(t_data, df.loc[(srs_params['sig_k']==0)&(srs_params['sig_a']>0)].values.T, color=colors[1]);
l3=  plt.plot(t_data, df.loc[(srs_params['sig_k']>1)].values.T, color=colors[2]);
# One label per drawn group. The sigma_a == 0 group (l1) is not plotted, so its
# label must not be passed either, or every label shifts onto the wrong group.
plt.legend([l[0] for l in (l2, l3)], [r'$\sigma_k=0$ $\sigma_a>0$', r'$\sigma_k>1$']);
plt.title("Traces with unconverged exponential parameters")

In [None]:
sig_k_sel = srs_params['sig_k'].between(*qs.loc[[0.02, 0.98],'sig_k'])#.between(1e-16, 1e4)
sig_a_sel = srs_params['sig_a'].between(*qs.loc[[0.02, 0.98],'sig_a'])#.between(1e-16,1e5)#.between(1e-16,10)

In [None]:
hrange =(0, srs_params['k'].quantile([0.98]).item())

In [None]:
w = 1./srs_params.loc[sig_k_sel,'sig_k']
# w = w*w
w /= w.sum()

In [None]:
fig, ax = plt.subplots(1,2, figsize=(6.52, 2.5))
df = srs_params.loc[sig_k_sel&sig_a_sel]
hranges = ([0,0.05], [0,50])
for i,x in enumerate('ka'):
    # w = 1./df[f'sig_{x}']
    # w /= w.sum()
    
    counts, bins, patches = ax[i].hist(df[x].values, bins=100, range=hranges[i])
    print(x, bins[np.argmax(counts)], df[x].mean())
    ax[i].axvline(df[x].median(), color='k', alpha=0.5, label=f'Median = {df[x].median():0.3g}')
    ax[i].legend()
ax[0].set_xlabel(r"Uptake Rate  $\left(min^{-1}\right)$")
ax[1].set_xlabel(r"Amplitude $\left( a.u. \right)$")

ax[0].set_ylabel("Cell counts")

fig.suptitle("Estimated Single-Cell Metabolic Parameters")
if savefigs:
    plt.savefig(fig_path+"sc_params.png")

In [None]:
plt.figure()
plt.plot(srs_recon.loc[srs_recon.index.intersection(srs_params.loc[(np.sqrt(srs_params.sig_k)/srs_params.k)>=1].index)].values.T, 'k', alpha=0.1);

In [None]:
# Histograms of fitted k and a for cells whose relative uncertainty in k,
# sqrt(sig_k) / k, is at most 1.
fig, ax = plt.subplots(1,2, figsize=(6.52, 2.5))
df = srs_params.loc[(np.sqrt(srs_params.sig_k)/srs_params.k)<=1]
hranges = ([0,0.05], [0,50])
for i,x in enumerate('ka'):
    # w = 1./df[f'sig_{x}']
    # w /= w.sum()

    counts, bins, patches = ax[i].hist(df[x].values, bins=100, range=hranges[i])
    print(x, bins[np.argmax(counts)], df[x].mean())
    ax[i].axvline(df[x].median(), color='k', alpha=0.5, label=f'Median = {df[x].median():0.3g}')
    ax[i].legend()
ax[0].set_xlabel(r"Uptake Rate  $\left(min^{-1}\right)$")
ax[1].set_xlabel(r"Amplitude $\left( a.u. \right)$")

ax[0].set_ylabel("Cell counts")

fig.suptitle("Estimated Single-Cell Metabolic Parameters")
if savefigs:
    # Distinct filename so this does not overwrite the quantile-selected "sc_params.png".
    plt.savefig(fig_path+"sc_params_rel_err.png")

In [None]:
df = df.drop((0,96))

In [None]:
bigk = df.loc[df.k>0.025]
smallk = df.loc[df.k<0.005]

In [None]:
xy_big= bigk.groupby('x0')['y0']
xy_small= smallk.groupby('x0')['y0']

In [None]:
xy_big.get_group(0)

In [None]:
plt.figure()
plt.errorbar(xy_big.mean().index, xy_big.mean(),yerr=xy_big.sem(), fmt='.', capsize=3)
plt.errorbar(xy_small.mean().index, xy_small.mean(),yerr=xy_small.sem(), fmt='.', capsize=3)

In [None]:
plt.figure()
plt.hist(df['x0'])

In [None]:
# Histogram of the product A*k of the fitted parameters.
plt.figure()
# `bins`, not `bin`, so the builtin bin() is not shadowed for later cells.
counts, bins, patches = plt.hist((srs_params['a']*srs_params['k']).values, bins=100, range=(0,0.5))
# Raw string: "\c" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
plt.title(r"Net Glucose Uptake Rates $\left( A\cdot k\right)$")

In [None]:
params_by_age = df.join(first.reset_index('T')['T']).groupby('T').mean()
params_by_age_sem = df.join(first.reset_index('T')['T']).groupby('T').sem()

In [None]:
# fig, ax = plt.subplots(1,2)
# ax[0].errorbar(t_data[:22], params_by_age['k'], params_by_age_sem['k'], fmt='o', capsize=3)
# ax[0].set_ylabel("k")
# ax[0].set_xlabel("Time of birth")
# plt.semilogy();

In [None]:
bkgd_avgs = srs_bsub.where(ds.labels==0).mean(list('YX'))

In [None]:
bkgd_stds = srs_bsub.where(ds.labels==0).std(list('YX'))

In [None]:
print(bkgd_avgs.min().item(), bkgd_avgs.max().item())

In [None]:
print(bkgd_stds.min().item(), bkgd_stds.max().item())

In [None]:
bad = srs_recon.loc[(srs_recon<-1*bkgd_stds.mean().item()).any(axis=1)]

In [None]:
plt.figure()
plt.plot(bad.values.T, 'k', alpha=0.2);

In [None]:
bad

In [None]:
srs_params.join(pd.Series(index=bad.index, name='x'), how='inner')

In [None]:
srs_recon.loc[srs_recon[0]<-10]

In [None]:
m = check_labels_from_multiindex(ds.labels, bad.index)

In [None]:
# Inspect the flagged cells in napari: background-subtracted SRS, aligned
# fluorescence, and the label mask m built from `bad`.
v = napari.Viewer()
sl = v.add_image(srs_bsub)
fl = v.add_image(aligned_fluo)
ml = v.add_labels(m)

In [None]:
# Individual SRS traces of the significant-fit cells, one panel per quintile
# of the fitted uptake rate k.
params = srs_params.loc[(sig_k_sel&sig_a_sel)]
traces = srs_recon.loc[params.index]
fig, ax = plt.subplots(1,5, sharey=True)
q = params['k'].quantile(np.linspace(0,1,6))
# include_lowest=True: pd.cut intervals are open on the left, so without it the
# cell whose k equals the 0% quantile edge is assigned NaN and dropped.
for i,(r,x) in enumerate(traces.groupby(pd.cut(params['k'], q.values, include_lowest=True))):
    ax[i].plot(x.values.T, 'k', alpha=0.05)
    ax[i].set_title(f"{0.2*i:0.1f}-{0.2*(i+1):0.1f}")

In [None]:
# Background-only SRS image: cell pixels (labels != 0) become NaN.
nanbkgd = srs_bsub.where(ds.labels==0)

In [None]:
# Load the estimator's cv_labels into memory (presumably a lazy array — TODO confirm).
be.cv_labels.load();

In [None]:
# Standard deviation of srs_bsub within the cv_labels regions (presumably one
# value per label — see mu.single_cell.standard_dev).
# NOTE: overwrites the pixel-level bkgd_stds from an earlier cell.
bkgd_stds = mu.single_cell.standard_dev(be.cv_labels.to_dataset(name='labels'), srs_bsub)

In [None]:
# Flatten to a pandas Series and drop missing entries.
bkgd_sds = bkgd_stds.to_series().dropna()

In [None]:
# Mean SRS trace (± SEM) of significant-fit cells split into quintiles of the
# fitted uptake rate k.
params = srs_params.loc[(sig_k_sel&sig_a_sel)]
# Cell (0, 96) is excluded by hand, as elsewhere in the notebook; errors='ignore'
# keeps the cell runnable when that cell is not among the significant fits.
traces = srs_recon.loc[params.index].drop((0,96), errors='ignore')
fig, ax = plt.subplots(figsize=(6.52,3))
ax.set_aspect(9, anchor='C');
pct = np.linspace(0,1,6)
q = (params['k']).quantile(pct)
inds = []
# include_lowest=True: pd.cut intervals are open on the left, so without it the
# cell with the minimum k falls outside every bin and is silently dropped.
for i,(r,x) in enumerate(traces.groupby(pd.cut(params['k'], q.values, include_lowest=True))):
    inds.append(x.index)
    ax.errorbar(t_data, x.mean().values, yerr=x.sem(), fmt='-o', capsize=3, markersize=3, label=f"{100*pct[i]:0.0f}-{100*pct[i+1]:0.0f}%");
    # ax[i].set_title(r);
ax.legend(title="Percentiles of $k$")
plt.title("Net glucose uptake by $k$-Percentiles")
plt.xlabel("Time (minutes)")
plt.ylabel("SRS (a.u.)")
# plt.savefig(fig_path+"kquantiles.png")
# m = bkgd_avgs.groupby('T').mean().values#nanbkgd.mean(list("SYX")).data
# s = bkgd_avgs.groupby('T').std().values#nand.std(list("SYX")).data
# plt.fill_between(t_data, m+s, m-s, color='k', alpha=0.5)

In [None]:
# Bulk fit over all significant-fit cells: average their reconstructed traces,
# subtract the first time point, fit exp_approach with the per-time-point SEM
# as sigma, and plot the fitted curve (offset added back) over the data.
# NOTE: df, y, sigma, fit and xplt are notebook globals. df (a parameter table
# in other cells) is reassigned to traces here, and xplt is read by later
# fitting cells.
params = srs_params.loc[(sig_k_sel&sig_a_sel)]
df = srs_recon.loc[params.index]
# xdata = dt_min*df.columns.values + 30
y = df.mean().values
y_offset = y[0]
y = y-y_offset
sigma = df.sem().values

# bounds=(0, inf) keeps both fitted parameters (amplitude a, rate k) non-negative.
fit = curve_fit(exp_approach, t_data, y, sigma=sigma, p0=[20, 0.05], bounds=(0,np.inf))

a,k = fit[0]
eqn = f"Exponential fit: \n{a:0.1f} (1-exp[-{k:0.2g}t])"
# Plotting grid on the time axis (minutes).
xplt = np.linspace(0, 550, 250)
plt.figure()
plt.plot(xplt, exp_approach(xplt, *fit[0])+y_offset, zorder=0, linewidth=3, alpha=0.75, label=eqn)
# plt.plot(t_data, y+y_offset, 'k.', zorder=-1)
plt.errorbar(t_data, y+y_offset,yerr=sigma, fmt='k.', capsize=3, zorder=-1)
plt.title(f"Wild Type Net Glucose Uptake")
plt.legend()
plt.ylabel("SRS Intensity (a.u.)")
plt.xlabel("Time (minutes)")
if savefigs:
    plt.savefig(fig_path+"bulk_fit.png")

In [None]:
# m = check_labels_from_multiindex(ds.labels,inds[0])

# v = napari.Viewer()
# sl = v.add_image(srs_bsub)
# ml = v.add_labels(m)

In [None]:
# Summary of the fitted uptake rate k: mean and median over all fits, mean over
# the significant fits, and a weighted sum over the significant fits.
# NOTE(review): `w` is defined elsewhere; "inv var mean" is only a mean if w is
# normalized (sums to 1) and aligned with srs_params' index — TODO confirm.
print("mean", srs_params['k'].mean())
print("median", srs_params['k'].median())
print("selection mean", srs_params.loc[sig_k_sel&sig_a_sel,'k'].mean())
print("inv var mean", (srs_params.loc[sig_k_sel&sig_a_sel,'k']*w).sum())

In [None]:
# Histograms of the fitted rate k and amplitude A for cells that FAILED the
# significance selection (complement of sig_k_sel & sig_a_sel), with medians.
fig, ax = plt.subplots(1,2, figsize=(6.52, 2.5))
df = srs_params.loc[~(sig_k_sel&sig_a_sel)]
hranges = ([0,0.05], [0,40])
for i,x in enumerate('ka'):
    # w = 1./df[f'sig_{x}']
    # w /= w.sum()
    med = df[x].median()
    ax[i].hist(df[x].values, bins=100, range=hranges[i])
    ax[i].axvline(med, color='k', alpha=0.5, label=f'Median = {med:0.3g}')
    ax[i].legend()
ax[0].set_xlabel(r"Uptake Rate  $\left(min^{-1}\right)$")
ax[1].set_xlabel(r"Amplitude $\left( a.u. \right)$")

ax[0].set_ylabel("Cell counts")

# The title marks these as the rejected fits, so the saved figure
# (bad_sc_params.png) is not mistaken for the accepted-fit figure.
fig.suptitle("Estimated Single-Cell Metabolic Parameters (rejected fits)")
if savefigs:
    plt.savefig(fig_path+"bad_sc_params.png")

In [None]:
# Net uptake rate A*k for the significant fits: histogram and median.
# NOTE: reassigns the notebook-global df.
df = srs_params.loc[(sig_k_sel&sig_a_sel)]
p = df['a']*df['k']
plt.figure()
# x = p.loc[p>0.0002]
plt.hist(p.values, bins=100, range=(0,0.6));
print(p.median())

In [None]:
# Pick out cells with a large (k > 0.02) or small (k < 0.0025) fitted uptake
# rate, take their reconstructed SRS traces, and count how many of those cells
# have a non-missing value at each time point.
big_k = srs_params.index[(srs_params['k'] > 0.02).to_numpy()]
big_k_traces = srs_recon.loc[big_k]
n_big_k = big_k_traces.count()

small_k = srs_params.index[(srs_params['k'] < 0.0025).to_numpy()]
small_k_traces = srs_recon.loc[small_k]
n_small_k = small_k_traces.count()

In [None]:
# Mean trace of the big-k and small-k groups; error bars are the standard
# deviation across cells (not the SEM).
plt.figure()
plt.errorbar(t_data, big_k_traces.mean(), yerr = big_k_traces.std(), capsize=3, fmt='-o')
plt.errorbar(t_data, small_k_traces.mean(), yerr = small_k_traces.std(), capsize=3, fmt='-o')

In [None]:
# Number of cells with a value at each time point in each group.
plt.figure()
plt.plot(t_data, n_big_k)
plt.plot(t_data, n_small_k)

In [None]:
# Overlay the individual SRS traces of the big-, small- and medium-k groups.
fig, ax = plt.subplots(1,3,sharex=True, sharey=True)
ax[0].plot(big_k_traces.values.T, 'k', alpha=0.05)
ax[0].set_title("Big k")
ax[1].plot(small_k_traces.values.T,'k', alpha=0.05)
ax[1].set_title("Small k")
# Medium k is the gap between the small (< 0.0025) and big (> 0.02) thresholds.
# Series.between is inclusive on both ends, so the three groups cover every cell
# with a finite k without overlap. The former upper bound of 0.2 also swept in
# every big-k cell. srs_recon is used, as in the other two panels.
medium_mask = srs_params['k'].between(0.0025, 0.02)
ax[2].plot(srs_recon.loc[srs_params.index[medium_mask.to_numpy()]].values.T, 'k', alpha=0.05)
ax[2].set_title("Medium k")

In [None]:
from srs_tools.util import check_labels_from_multiindex

In [None]:
# big_k_labels = check_labels_from_multiindex(ds.labels, big_k)
# small_k_labels = check_labels_from_multiindex(ds.labels, small_k)

# v = napari.Viewer()
# il = v.add_image(aligned_fluo)
# big_l = v.add_labels(big_k_labels>0, color={1:'red'}, name="big k", opacity=0.5)
# small_l = v.add_labels(small_k_labels>0, color={1:'blue'}, name='small k', opacity=0.5)

## Age dependence

In [None]:
# Birth frame T of each cell: the T level of `first`'s index as a Series
# indexed by the remaining levels.
tser = first.reset_index('T')['T']

In [None]:
# Same for all_first (presumably every tracked cell, not only the fitted subset — TODO confirm).
all_t0 = all_first.reset_index('T')['T']

In [None]:
# Density of SRS at time column 35 over all cells: oldest (birth T < 5) versus
# newest (20 <= T <= 25).
plt.figure()
# Edges go to hist_bins rather than `bins`. `bins` holds the birth-time bin
# edges used by the age-grouping cells (and t_bins), and re-running this cell
# must not overwrite them.
counts, hist_bins, patches = plt.hist(all_recon.loc[all_t0<5,35], bins=80, range=(0,40),density=True, alpha=0.6, label="Oldest")
plt.hist(all_recon.loc[all_t0.between(20,25),35], bins=hist_bins, density=True, alpha=0.6, label="Newest")
plt.legend()

In [None]:
# Density of SRS at time column 35 for the fitted cells: oldest (birth T < 5)
# versus newest (15 <= T <= 21).
plt.figure()
xold= srs_recon.loc[tser<5, 35].values
xnew = srs_recon.loc[tser.between(15,21), 35].values
# hist_bins instead of `bins`, so the birth-time bin edges used by the
# age-grouping cells are not overwritten when this cell is re-run.
counts, hist_bins, patches = plt.hist(xold, bins=80, alpha=0.6, density=True, label="Oldest")
plt.hist(xnew, bins=hist_bins, alpha=0.6, density=True, label="Newest");
# The labels above are only displayed once a legend is drawn.
plt.legend()

In [None]:
# 25th and 75th percentiles of the newest and oldest groups, NaNs removed.
print(np.quantile(xnew[~np.isnan(xnew)], [0.25, 0.75]))
print(np.quantile(xold[~np.isnan(xold)], [0.25, 0.75]))

In [None]:
# Number of non-NaN values in the oldest group.
nold = xold.shape[0]-np.isnan(xold).sum()

In [None]:
# Number of non-NaN values in the newest group.
nnew = xnew.shape[0]-np.isnan(xnew).sum()

In [None]:
# Mean, median and standard error of the mean for the oldest and newest
# groups, ignoring NaN entries.
for z in [xold , xnew]:
    n_valid = np.count_nonzero(~np.isnan(z))
    # ddof=1 so this SEM matches pandas .sem(), used everywhere else here.
    print(np.nanmean(z), np.nanmedian(z), np.nanstd(z, ddof=1)/np.sqrt(n_valid))

In [None]:
# Split cells into ngroups equal-width bins of birth frame T with edges from
# -0.5 to Tmax - nconsec - 0.5 (nconsec is defined elsewhere), then take the
# mean, SEM and count of the reconstructed traces within each bin.
# Cells with T outside the bins get category code -1 and form their own group.
Tmax = 36
ngroups = 3
width = (Tmax - nconsec)/ngroups
bins = width*np.arange(ngroups+1)-0.5
cut = pd.cut(tser, bins)
gb = srs_recon.groupby(cut.cat.codes)
by_age=gb.mean()
by_age_sem=gb.sem()
by_age_count = gb.count()

In [None]:
# plt.figure()
# plt.plot(srs_recon.loc[all_srs_params['sig_k']==0].values.T, 'k', alpha=0.1);

# plt.figure()
# plt.plot(srs_recon.loc[all_srs_params['sig_k']==0].values.T, 'k', alpha=0.1);

In [None]:
# Mean reconstructed trace of each birth-time group, with SEM error bars
# taken from the by_age_sem table computed alongside by_age.
plt.figure()
for t in range(ngroups):
    plt.errorbar(
        t_data,
        by_age.loc[t].values,
        yerr=by_age_sem.loc[t].values,
        fmt='.',
        capsize=3,
    )

In [None]:
# Fit the three-parameter model exp_approach2 (a, k, t per the legend title)
# to each birth-time group's mean trace and plot fit over data.
# NOTE(review): sigma is computed but not passed to curve_fit, so the fit is
# unweighted. xplt comes from an earlier cell.
plt.figure()
params = pd.DataFrame(index=range(ngroups),columns=['a','k','t'])
for t in range(ngroups):
    y = by_age.loc[t].values
    mask = ~np.isnan(y)
    x = t_data[mask]
    y = y[mask]
    sigma = by_age_sem.loc[t].values[mask]
    fit = curve_fit(exp_approach2, x, y, p0=[20, 0.01,50], bounds=(0,np.inf))
    params.loc[t] = fit[0]
    plt.plot(x,y, '.', color=colors[t])
    plt.plot(xplt ,exp_approach2(xplt, *fit[0]), color=colors[t], label=" - ".join([f"{f:0.3f}" for f in fit[0]]))
plt.legend(title="a       -     k     -     t")
plt.axhline(0,color='k', linestyle=':')
plt.ylim(-5,25)

In [None]:
# Plotting grid spanning the measured time points (reassigns the global xplt).
xplt = np.linspace(t_data.min(), t_data.max(), 101)

In [None]:
# Mean of dt in minutes (presumably the frame interval — TODO confirm).
dt.mean().total_seconds()/60

In [None]:
# Birth-time bin edges converted to minutes. The +0.5 undoes the -0.5 offset
# applied when bins was built.
t_bins = dt.mean().total_seconds()/60*(bins+0.5)

In [None]:
# For each birth-time group, fit A(1 - exp(-k t)) to the mean trace, with time
# and SRS measured from the group's first non-NaN point and the per-time-point
# SEM passed to curve_fit as sigma. The fitted (A, k) go into `params` and the
# full curve_fit outputs into `fits`.
plt.figure()
params = pd.DataFrame(index=range(ngroups),columns=['a','k'])
fits = []
for t in range(ngroups):
    y = by_age.loc[t].values
    mask = ~np.isnan(y)
    x = t_data[mask]
    x_offset = x[0]
    x = x - x_offset
    y = y[mask]
    y_offset = y[0]
    y = y-y_offset
    fit = curve_fit(exp_approach, x, y, sigma=by_age_sem.loc[t][mask], p0=[20, 0.01], bounds=(0,np.inf))
    fits.append(fit)
    # Record (A, k) in the table, as the other per-group fitting cells do.
    params.loc[t] = fit[0]
    plt.errorbar(x+x_offset,y+y_offset, yerr=by_age_sem.loc[t][mask], fmt='.', capsize=2, color=colors[t])
    xx = xplt[xplt>=x_offset]
    # The equation in the label carries the time variable t, matching the bulk-fit label.
    plt.plot(xx , exp_approach(xx-x_offset, *fit[0])+y_offset, color=colors[t], label=f"{t_bins[t]:0.1f} ≤"+" $T_{Birth}$ "+f"< {t_bins[t+1]:0.1f} \n {fit[0][0]:0.1f}(1-exp[-{fit[0][1]:0.3g}t])", alpha=0.6)
# plt.legend(title="    A --- k --- $y_0$")
plt.xlabel("Time (minutes)")
plt.ylabel("SRS Intensity (a.u.)")
plt.title("Net glucose uptake in cells grouped by age")
plt.legend()
if savefigs: plt.savefig(fig_path+"uptake_by_age.png")
# plt.axhline(0,color='k', linestyle=':')
# plt.ylim(0,25)

In [None]:
# Product of each group's fitted parameters, A*k: the initial slope of
# A(1 - exp(-k t)). np.prod replaces np.product, which is deprecated and
# removed in NumPy 2.0.
for f in fits:
    print(np.prod(f[0]))

In [None]:
# Fitted uptake rate k of each birth-time group (± one standard deviation from
# the curve_fit covariance diagonal) against the centre of the group's bin.
k = [f[0][1] for f in fits]
s = [np.sqrt(f[1][1,1]) for f in fits]
# True bin midpoints. The former t_bins[:-1] + mean(t_bins[:2]) only equals
# this when t_bins[0] == 0.
bin_centers = 0.5*(np.asarray(t_bins[:-1]) + np.asarray(t_bins[1:]))
plt.figure()
plt.errorbar(bin_centers, k, yerr=s, fmt='-o', capsize=3)
plt.xlabel("Birth time (minutes)")
plt.ylabel(r"$k$ $\left(min^{-1}\right)$")

In [None]:
# Birth-time group code of each cell, dropping cells outside every bin (code -1).
codes = cut.cat.codes[cut.cat.codes>-1].rename("codes")

In [None]:
# k of the significant-fit cells next to their group code; the inner join
# keeps only cells present in both.
ps = pd.concat([srs_params.loc[sig_k_sel&sig_a_sel,'k'], codes], axis=1, join='inner')

In [None]:
# Density of single-cell k per birth-time group, with each group's median as a
# vertical line. The loop is sized by ngroups rather than a hard-coded 3.
plt.figure()
for i in range(ngroups):
    x = ps.loc[ps['codes']==i, 'k']
    # Explicit colour so the histogram and its median line always match.
    plt.hist(x.values, bins=51, range=(0,0.05), density=True, alpha=0.6, color=colors[i])
    plt.axvline(x.median(), color=colors[i])

$$ y = a(1-e^{-k(t-t_0)})$$

In [None]:
# Number of cells with a non-missing cleaned SRS value at each time point.
cell_count = (~srs_clean.isna()).sum()

In [None]:
# Number of srs_avgs entries per frame T.
all_count = srs_avgs.groupby('T').count()

In [None]:
# Linear fit of log2(cell count) against time (minutes), over all time points.
lr = linregress(t_data, np.log2(all_count.values))

In [None]:
# Same fit restricted to the first 20 time points.
lr2 = linregress(t_data[:20], np.log2(all_count.values)[:20])

In [None]:
# Doubling time (minutes per doubling) from the early-time fit.
print(1/lr2.slope)

In [None]:
# Doubling time from the full-range fit.
print(1/lr.slope)

In [None]:
# Cell counts over time on a log axis, with the full-range fit
# 2**(slope*t + intercept) in black.
plt.figure()
plt.plot(t_data, cell_count.values, '-.')
plt.plot(t_data, all_count.values.astype(int), '.')
plt.plot(t_data, np.exp(np.log(2)*(lr.slope*t_data + lr.intercept)), 'k')
plt.semilogy()

In [None]:
# Unweighted fit of exp_approach (a, k) to each birth-time group's mean trace
# on the absolute time axis (no offsets), plotted over the data.
# NOTE: xplt comes from an earlier cell.
plt.figure()
# Table and loop are both sized by ngroups, the number of birth-time groups in
# by_age, so every fitted group gets a row and no row is left empty.
params = pd.DataFrame(index=range(ngroups),columns=['a','k'])
for t in range(ngroups):
    y = by_age.loc[t].values
    mask = ~np.isnan(y)
    x = t_data[mask]
    y = y[mask]
    fit = curve_fit(exp_approach, x, y, p0=[20, 0.01], bounds=(0,np.inf))
    params.loc[t] = fit[0]
    plt.plot(x,y, '.', color=colors[t])
    plt.plot(xplt ,exp_approach(xplt, *fit[0]), color=colors[t], label=" - ".join([f"{f:0.3f}" for f in fit[0]]))
plt.legend(title="    a       -     k   ")
plt.axhline(0,color='k', linestyle=':')
plt.ylim(-5,25)

In [None]:
# 2.5th and 97.5th percentiles of `times` (defined elsewhere) within each frame T.
q = times.groupby('T').quantile([0.025, 0.975])

In [None]:
# Move the quantile level into columns and transpose. Assuming `times` is a
# Series, this gives a 2 x n_T array (row 0 = 2.5%, row 1 = 97.5%).
errs = q.unstack(level=-1).values.T

## Area growth rates

In [None]:
# Disabled: per-cell exponential fit of area traces, and its rank correlation
# with the SRS uptake rate k.
# slower to use dask delayed by ~50%
# df = growth_cells

# all_area_params = pd.DataFrame(index=df.index, columns=['a','k', 't', 'sig_a', 'sig_k', 'sig_t'], dtype='f4')

# for i, (idx, s) in enumerate(tqdm(df.iterrows(), total=len(df))):
#     y = s.values
#     mask = ~np.isnan(y)
#     y = y[mask]
#     x = t_data[mask]
#     out = fit_exp(x,y,[200,0.05,x[0]])
#     all_area_params.loc[idx] = out

# # unconverged = all_params.loc[all_params.isna().any(1)].index
# area_params = all_area_params.dropna()

# NOTE(review): rename targets fixed — area k is 'k_area', SRS k is 'k_srs'.
# combo_rates = pd.concat([area_params['k'].rename('k_area'), srs_params['k'].rename('k_srs')],axis=1, join='inner')

# combo_rates.corr(method='spearman')

In [None]:
# srs_total_mean = srs_bsub.where(ds.labels>0).mean(list('YX'))
# srs_scaled_avg = (1+kopt.values)*srs_total_mean

In [None]:
# Mean reconstructed SRS trace over all cells in each FOV (index level S).
srs_scaled_avg = srs_recon.groupby('S').mean()

In [None]:
# Cell areas in wide form: the T index level becomes columns.
a0 = areas.unstack('T')

In [None]:
# Number of cells with positive area at each T, per FOV S.
cell_counts = (a0>0).groupby('S').sum()

In [None]:
# Per-FOV comparison of growth and SRS uptake.
# Left panel: cell number over time, with a log-linear fit whose slope is the
# exponential growth rate.
# Right panel: the FOV-mean SRS trace (minimum subtracted), with its fit_exp rate k.
fig, ax = plt.subplots(1,2)
area_slope = []
srs_slope = []
log = np.log
for i in range(5):
    # a = vol_unit*vfactor*(total_area.data[i]**(3./2))
    a = cell_counts.values[i]
    # a = total_area.data[i]
    # a = volumes.values[i]
    lr_a = linregress(t_data, log(a))
    area_slope.append(lr_a.slope)
    ax[0].plot(t_data, a, color=colors[i], label=f"{lr_a.slope:0.3g}±{lr_a.stderr:0.3g}")
    # s = srs_total_mean.data[i]
    # s = srs_scaled_avg.data[i]
    s = srs_scaled_avg.values[i]
    s = s-s.min()

    # s = srs_recon.loc[i].mean().values
    # lr_s = linregress(t_data, np.log(s))
    fit_s = fit_exp(t_data, s, [20,0.01])
    # srs_slope.append(lr_s.slope)
    # NOTE(review): fit_s presumably holds (A, k, sig_A, sig_k), so index 1 is
    # k and -1 its uncertainty — TODO confirm against fit_exp.
    srs_slope.append(fit_s[1])
    # ax[1].plot(t_data, s, color=colors[i], label=f"{lr_s.slope:0.3g}±{lr_s.stderr:0.3g}")
    ax[1].plot(t_data, s, color=colors[i], label=f"{fit_s[1]:0.3g}±{fit_s[-1]:0.3g}")
for a in ax:
    a.legend(title="Slope ± std. err. $\\left(min^{-1}\\right)$")
    a.semilogy()
    a.set_xlabel("Time (minutes)")
ax[0].set_title("Number of cells (per FOV - log scale)")
# The left panel plots cell_counts. The former "Approximate Volume" label was a
# leftover from the commented-out volume variant.
ax[0].set_ylabel("Number of cells")
ax[1].set_title("Average Cellular SRS (per FOV)")
ax[1].set_ylabel("SRS intensity (a.u.)")

In [None]:
# For each FOV, the ratio of the SRS uptake rate (srs_slope) to the growth
# rate of cell number (area_slope). Prints both rates and the ratio.
r=[]
for x,y in zip(srs_slope,area_slope):
    r.append(x/y)
    print(x, y, x/y)

In [None]:
# Per-FOV uptake/growth ratios.
print(r)

In [None]:
# Mean ratio across FOVs.
np.mean(r)

In [None]:
# Standard error of the mean ratio across FOVs. ddof=1 matches pandas .sem()
# used elsewhere, and len(r) replaces the hard-coded FOV count of 5.
np.std(r, ddof=1)/np.sqrt(len(r))

In [None]:
# Mean ± SEM of the per-FOV uptake/growth ratio. The SEM uses ddof=1 and the
# actual number of FOVs, len(r).
print(f"{np.mean(r):0.3g} ± {np.std(r, ddof=1)/np.sqrt(len(r)):0.3g}")

In [None]:
# Mean reconstructed SRS trace of each FOV.
plt.figure()
plt.plot(srs_scaled_avg.values.T)

In [None]:
# HDF5 output file (machine-specific absolute path).
outfile = "/Users/johnrussell/Data/2023-05-16/dh229_srs_tables.h5"

In [None]:
# Write each result table to outfile under its own key.
srs_clean.to_hdf(outfile, key='raw')
srs_recon.to_hdf(outfile, key='recon')
areas_clean.to_hdf(outfile, key='area')
srs_params.to_hdf(outfile, key='params')