# All figures for Chapter 2

In [None]:
from dask.distributed import Client
# In-process dask client (processes=False: workers run as threads inside this kernel)
# used by the dask-backed xarray/zarr computations below.
client = Client(processes=False)
# Show the client as the cell output.
client

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import microutil as mu
import scipy.ndimage as ndi
from scipy.stats import linregress
from skimage.morphology import disk
from skimage.segmentation import expand_labels
from skimage.filters import sobel
from skimage.registration import phase_cross_correlation
from srs_tools import BackgroundEstimator

%matplotlib widget
plt.style.use("paper")

colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
labels_cmap = plt.cm.viridis.copy()
labels_cmap.set_under(alpha=0)

mask_cmap = plt.cm.gray_r.copy()
mask_cmap.set_under(alpha=0)

rlabels = plt.cm.Reds.copy()
rlabels.set_under(alpha=0)

blabels = plt.cm.Blues.copy()
blabels.set_under(alpha=0)

olabels = plt.cm.Oranges.copy()
olabels.set_under(alpha=0)

figpath = "/Users/johnrussell/Documents/figures/2023-06-figs-paper/"
savefigs=False

## IO and analysis

### Blank Agarose Dataset

In [None]:
# Load the blank-agarose dataset fully into memory.
blank_ds = xr.open_zarr("/Users/johnrussell/Data/2023-04-02/d7g_agarose.zarr").load()
# Z-average of the last channel (the SRS channel, as plotted in "Blank SRS Image" below).
blank_ds['squash'] = blank_ds.images.isel(C=-1).mean('Z', dtype='f4')
# Per-FOV region to ignore, taken from channel 1 of ignore_mask.
blank_ds['mask'] = blank_ds.ignore_mask.isel(C=1)

In [None]:
blank_ds['single_labels'] = xr.DataArray(BackgroundEstimator._make_cv_labels(blank_ds.mask.any('S').astype('u2'), N=24), dims=list('YX'))

In [None]:
expanded = expand_labels(blank_ds['single_labels'].squeeze().data, distance=3)
expanded = (blank_ds['mask'].astype(bool).any('S').data==0)*expanded

In [None]:
neighborhood = expanded-blank_ds['single_labels'].squeeze().data

In [None]:
# Multi-FOV background estimate for the blank dataset.
# Each FOV is normalized by its mean over unmasked pixels. FOV s is then predicted as
# the pixelwise median of the *other* FOVs' normalized images, rescaled by FOV s's own
# scale. Leaving FOV s out of its own median keeps the target pixels from leaking into
# the prediction, matching the "Estimate from other FOVs" comparison below.
# (A z-score normalization, (x - mean) / std, was tried earlier as an alternative.)
nan_masked = blank_ds['squash'].where(blank_ds['mask']==0)
scales = nan_masked.mean(list('YX'))
scores = (blank_ds['squash']/scales)
loo_median = xr.concat(
    [scores.isel(S=[j for j in range(scores.sizes['S']) if j != s]).median('S')
     for s in range(scores.sizes['S'])],
    dim='S',
)
# Same dimension order as median('S') * scales, i.e. with S last.
bkgd_est = (loo_median*scales).transpose(..., 'S')

In [None]:
# Per-region, per-FOV mean intensity of the observed image (shared_avgs) and of the
# multi-FOV estimate (est_avgs). After unstacking S, the columns are FOVs, renumbered
# 0..n_fov-1 from the actual number of columns instead of a hard-coded 10. Rows with
# any missing FOV are dropped.
shared_avgs = (mu.single_cell.average(blank_ds['single_labels'].squeeze().to_dataset(name='labels'), blank_ds['squash'])
               .to_dataframe(name='avg')
               .unstack('S')
               .dropna())
shared_avgs.columns = list(range(shared_avgs.shape[1]))
est_avgs = (mu.single_cell.average(blank_ds['single_labels'].squeeze().to_dataset(name='labels'), bkgd_est)
               .to_dataframe(name='avg')
               .unstack('S')
               .dropna())
est_avgs.columns = list(range(est_avgs.shape[1]))

In [None]:
rmse = np.sqrt(((shared_avgs-est_avgs)**2).mean(1))

In [None]:
neighbor_labels = xr.Dataset({'labels':xr.DataArray(neighborhood, dims=list('YX'))})

In [None]:
neighbor_avgs = (mu.single_cell.average(neighbor_labels, blank_ds['squash'])
                 .to_dataframe(name='avg')
                 .unstack('S')
                 .dropna())
neighbor_avgs.columns = list(range(10))

In [None]:
neighbor_rmse = np.sqrt(((shared_avgs-neighbor_avgs)**2).mean(1))

In [None]:
# Full BackgroundEstimator on the blank dataset; a singleton T axis is inserted at
# position 1 (NOTE(review): presumably the estimator expects (S, T, Y, X) — confirm).
blank_be = BackgroundEstimator(blank_ds['squash'].expand_dims('T', axis=1), blank_ds['mask'].expand_dims('T', axis=1))

In [None]:
blank_be.make_cv_labels()
# The sigma scan is skipped and sigma_opt is fixed by hand
# (NOTE(review): presumably copied from an earlier sigma_scan() run — keep in sync).
# blank_be.sigma_scan()
# blank_be.sigma_opt.load();
blank_be.sigma_opt = xr.DataArray([2,42],dims=['k'])
print(blank_be.sigma_opt)

In [None]:
true_avgs = (mu.single_cell.average(blank_be.cv_labels.to_dataset(name='labels'), blank_ds.squash)
             .to_series()
             .dropna())
pred_avgs = (mu.single_cell.average(blank_be.cv_labels.to_dataset(name='labels'), blank_be.background_estimate.squeeze())
             .to_series()
             .dropna())

In [None]:
final_avgs = (mu.single_cell.average(blank_ds['single_labels'].squeeze().to_dataset(name='labels'), blank_be.background_estimate.squeeze())
               .to_dataframe(name='avg')
               .unstack('S')
               .dropna())
final_avgs.columns = list(range(10))

In [None]:
final_rmse = np.sqrt(((final_avgs-shared_avgs)**2).mean(1))

### Scan-speed test: cell mixture

In [None]:
# Scan-speed dataset. Here the T axis indexes scan speed (it is renamed to 'speed' in
# the per-cell tables below), not time.
scanspeed_ds = xr.open_zarr("/Users/johnrussell/Data/2023-02-27/cell_scanspeed.zarr/")
# Z-averaged SRS image (last channel) for every FOV and scan speed.
scanspeed_ds['squash'] = scanspeed_ds.images.isel(C=-1).mean('Z', dtype='f4').load()
# Z-averaged fluorescence (channel 0) from one scan-speed index (T=3).
scanspeed_ds['fluo'] = scanspeed_ds.images.isel(C=0, T=3).mean('Z').load()
scanspeed_ds.labels.load()
# Cross-validation labels computed independently for each 2-D (Y, X) label image.
scanspeed_ds['cv_labels'] = xr.apply_ufunc(BackgroundEstimator._make_cv_labels, scanspeed_ds['labels'], vectorize=True,
            input_core_dims=[list("YX")],
            output_core_dims=[list("YX")],
            dask="parallelized",
            output_dtypes=["u2"],
            dask_gufunc_kwargs={"allow_rechunk": True})
# Single brightfield plane (channel 1, Z=2), registered against the fluorescence below.
scanspeed_ds['bf'] = scanspeed_ds.images.isel(C=1, T=3, Z=2)

In [None]:
# Register the fluorescence to the brightfield/label frame via Sobel edge maps, then
# classify each cell as D7G-labeled from the fraction of its pixels above threshold.
bf_edges = xr.apply_ufunc(sobel, scanspeed_ds['bf'], input_core_dims=[list('YX')], output_core_dims=[list('YX')], vectorize=True, dask='parallelized').load()

# Keep an unshifted copy of the fluorescence in the dataset. Re-running this cell then
# registers from the original data instead of shifting an already-shifted image again.
if 'fluo_raw' not in scanspeed_ds:
    scanspeed_ds['fluo_raw'] = scanspeed_ds['fluo'].copy()
fluo_edges = xr.apply_ufunc(sobel, scanspeed_ds['fluo_raw'], input_core_dims=[list('YX')], output_core_dims=[list('YX')], vectorize=True, dask='parallelized').load()

fluo_aligned_labels = xr.zeros_like(scanspeed_ds.labels)
fluo_registered = scanspeed_ds['fluo_raw'].copy()
for i in range(scanspeed_ds.sizes['I']):
    # Shift registering the brightfield edges (moving) onto the fluorescence edges (reference).
    shift, _, _ = phase_cross_correlation(fluo_edges[i].data, bf_edges[i].data)
    offsets = shift.astype(int)
    # Labels moved into the fluorescence frame...
    fluo_aligned_labels[i] = scanspeed_ds.labels[i].shift(dict(zip('YX', offsets))).fillna(0).astype(int)
    # ...and fluorescence moved into the label frame. The fluorescence stays float, so
    # the Z-averaged intensities are not truncated to integers.
    fluo_registered[i] = scanspeed_ds['fluo_raw'][i].shift(dict(zip('YX', -offsets))).fillna(0)
scanspeed_ds['fluo'] = fluo_registered

# Dilated label mask in the fluorescence frame. The structuring element is a 2-D disk in
# the middle plane of a cube, so it dilates only within each (Y, X) image. This requires
# labels to be 3-D. NOTE(review): fluo_dilated is not used elsewhere in this notebook.
n = 3
d = scanspeed_ds.labels.ndim
structure= np.zeros(d*(2*n+1,), dtype='bool')
structure[n]= disk(n)
fluo_dilated = ndi.binary_dilation(fluo_aligned_labels.data>0, structure=structure, iterations=3)

# Initial background estimate of the fluorescence (labels>0 passed as the cell mask).
fluo_init_bkgd = xr.apply_ufunc(BackgroundEstimator._make_initial_estimate,
                                scanspeed_ds['fluo'], scanspeed_ds['labels']>0,
                                input_core_dims=[list("IYX"),list("IYX")], output_core_dims=[list('IYX')])

# Background-subtracted fluorescence, clipped at 0 and scaled so each image's max is 1.
fluo_bsub = np.clip(scanspeed_ds['fluo']-fluo_init_bkgd, a_min=0, a_max=np.inf)
fluo_bsub /= fluo_bsub.max(list('YX'))

fluo_thresh = mu.calc_thresholds(fluo_bsub)
# fluo_mask = (fluo_bsub> 0.5*fluo_thresh)*(fluo_aligned_labels>0)
# Pixels inside cells whose fluorescence exceeds a quarter of the threshold.
fluo_mask = (fluo_bsub> 0.25*fluo_thresh)*(scanspeed_ds.labels>0)

# fluo_avgs = mu.single_cell.average(fluo_aligned_labels.to_dataset(), fluo_mask).to_series().dropna()
# Per-cell mean of the boolean mask (presumably the fraction of fluorescent pixels per cell).
fluo_avgs = mu.single_cell.average(scanspeed_ds.labels.to_dataset(), fluo_mask).to_series().dropna()

# A cell is called D7G-labeled when that fraction exceeds 0.25.
fluo_series = (fluo_avgs>0.25)

In [None]:
# NOTE(review): check_labels_from_multiindex is not used anywhere in this notebook.
from srs_tools.util import check_labels_from_multiindex

# Visual check of the D7G classification for one FOV.
iplt = 2
# Label ids of D7G-positive cells in FOV iplt (+1: presumably per-cell results are
# indexed from 0 while label 0 is background — confirm).
idx = fluo_series.loc[fluo_series].loc[iplt].index+1

# f = scanspeed_ds['fluo'].isel(I=iplt).data
f = fluo_bsub.isel(I=iplt).data
m = np.isin(scanspeed_ds.labels.isel(I=iplt), idx)

# Left: per-pixel fluorescence mask. Right: cells classified as D7G-labeled.
fig, ax = plt.subplots(1,2, sharex=True, sharey=True)
ax[0].imshow(f, cmap='gray')
ax[0].imshow(fluo_mask.isel(I=iplt).data, vmin=0.5, cmap=blabels)
ax[1].imshow(f, cmap='gray')
ax[1].imshow(m, cmap=rlabels, vmin=0.5, alpha=0.3)

In [None]:
def _tidy_scanspeed(per_cell):
    """Turn a per-cell scan-speed result into a tidy DataFrame.

    The value column is named 'srs'. The original T coordinate is kept as a 'speed'
    column and replaced in the index by an integer speed index T = log2(speed) - 3.
    Each row gets its D7G classification from the global ``fluo_series``. Rows with
    any missing value are dropped.
    """
    return (per_cell
            .to_series()
            .rename('srs')
            .dropna()
            .reset_index('T')
            .rename(columns={'T':'speed'})
            .assign(T=lambda df: np.log2(df['speed']).astype(int)-3)
            .set_index('T', append=True)
            .assign(d7g=fluo_series)
            .dropna()
            .astype({'d7g':bool}))


# Per-cell mean SRS over the segmented cells, and mean / standard deviation over the CV regions.
scanspeed_raw_avgs = _tidy_scanspeed(mu.single_cell.average(scanspeed_ds, 'squash', label_name='labels'))
scanspeed_cv_avgs = _tidy_scanspeed(mu.single_cell.average(scanspeed_ds, 'squash', label_name='cv_labels'))
scanspeed_cv_sds = _tidy_scanspeed(mu.single_cell.standard_dev(scanspeed_ds, 'squash', label_name='cv_labels'))

In [None]:
# Background estimator for the scan-speed data at one speed index (T=2): the FOV axis I
# is renamed to S and a singleton T axis is inserted at position 1.
scanspeed_be = BackgroundEstimator(
    scanspeed_ds.squash.isel(T=2).drop_vars('T').rename({'I':'S'}).expand_dims('T', axis=1),
    scanspeed_ds.labels.rename({'I':'S'}).expand_dims('T', axis=1),)

In [None]:
scanspeed_be.make_cv_labels()
# The sigma scan is skipped and sigma_opt is fixed by hand
# (NOTE(review): presumably from an earlier sigma_scan() run — confirm).
# scanspeed_be.sigma_scan()
# scanspeed_be.sigma_opt.load();
scanspeed_be.sigma_opt = xr.DataArray([10,10], dims='k')
print(scanspeed_be.sigma_opt.data)
scanspeed_be.background_estimate.load();

In [None]:
# Background-subtracted SRS image at speed index T=2 (FOV axis renamed to S).
scanspeed_bsub = (scanspeed_ds.squash.isel(T=2).rename({'I':'S'}).squeeze()-scanspeed_be.background_estimate).squeeze()

In [None]:
# Per-cell mean of the background-subtracted SRS, joined with the D7G classification.
scanspeed_bsub_avgs = (mu.single_cell.average(scanspeed_be.labels.squeeze().to_dataset(name='labels'),scanspeed_bsub,)
              .to_series()
              .rename('srs')
              .dropna()
              .to_frame() 
              .assign(d7g=fluo_series)
              .dropna()
              .astype({'d7g':bool}))

### Demo timelapse frame

In [None]:
# Timelapse used to illustrate the background-estimation steps on a single frame.
ds= xr.open_zarr("/Users/johnrussell/Data/2023-02-27/dh224_analysis.zarr")

# Z-averaged channel-2 image and labels at T=32; a singleton T axis is re-inserted at
# position 1 for the estimator.
demo_imgs = ds.images.isel(T=32, C=2).mean('Z').load()
demo_labels = ds.labels.isel(T=32).load()
demo_be = BackgroundEstimator(demo_imgs.expand_dims('T', axis=1), demo_labels.expand_dims('T', axis=1))
demo_be.make_cv_labels()
# The sigma scan is skipped and sigma_opt is fixed by hand
# (NOTE(review): presumably from an earlier sigma_scan() run — confirm).
# demo_be.sigma_scan()
# demo_be.sigma_opt.load();
demo_be.sigma_opt = xr.DataArray([ 2., 82.], dims='k')
print(demo_be.sigma_opt)
# demo_be.background_estimate.load();

In [None]:
# FOV 8 of the demo frame and its dilated mask, used in the step-by-step figure.
demo_ds = demo_be.to_dataset().isel(S=8).squeeze()
demo_dilate = demo_be.dilated_mask.isel(S=8).squeeze()

In [None]:
# Two FOVs (4 and 8) used to illustrate filling one image from another.
pair = demo_be.to_dataset().isel(S=[4,8])[['images', 'labels']]
pair_be = BackgroundEstimator(pair.images, pair.labels)
pair_be.make_dilated_mask()
# Images with the dilated cell mask set to NaN.
nan_mask = pair.images.where(pair_be.dilated_mask==0).squeeze()
init = pair_be.initial_estimate.squeeze()

### DH245 Timelapse

In [None]:
# DH245 timelapse, dropping the first time point.
dh245_ds = xr.open_zarr("/Users/johnrussell/Data/2023-04-25/dh245_timelapse.zarr/").isel(T=slice(1,None))

In [None]:
# Z-averaged SRS image (last channel).
dh245_ds['srs_squash'] = dh245_ds.images.isel(C=-1).mean('Z', dtype='f4')

In [None]:
dh245_be = BackgroundEstimator(dh245_ds.srs_squash, dh245_ds.labels)
dh245_be.make_cv_labels()
# The sigma scan is skipped and sigma_opt is fixed by hand
# (NOTE(review): presumably from an earlier sigma_scan(n_samples=5) run — confirm).
# dh245_be.sigma_scan(n_samples=5)
# dh245_be.sigma_opt.load();
dh245_be.sigma_opt = xr.DataArray([10,34], dims=['k'])
print(dh245_be.sigma_opt)

In [None]:
# Compute the background estimate into memory.
dh245_be.background_estimate.load();

In [None]:
dh245_ds.srs_squash.load();

In [None]:
# Per-time-point means over FOVs and pixels: predicted background outside the dilated
# cell mask (bkgd) and observed signal inside labeled cells (bulk).
# NOTE(review): `~` assumes dilated_mask is boolean; on an integer mask it is a bitwise NOT.
bkgd = dh245_be.background_estimate.where(~dh245_be.dilated_mask).mean(list('SYX')).load()
bulk = dh245_ds.srs_squash.where(dh245_ds.labels>0).mean(list('SYX')).load()

In [None]:
# Time axis in minutes assuming a fixed 14.2-minute frame interval
# (NOTE(review): confirm against the acquisition metadata).
t245 = 14.2*np.arange(bulk.shape[0])

In [None]:
# Observed signal outside the dilated cell mask, for comparison with the prediction.
bkgd_obs = dh245_ds.srs_squash.where(~dh245_be.dilated_mask).mean(list('SYX')).load()

In [None]:
# raw_avgs = (mu.single_cell.average(ds, srs)
#             .to_series()
#             .dropna())

## Figures

In [None]:
# figure 1: one blank SRS field of view next to its pixel-intensity histogram
example_fov = blank_ds['squash'].isel(S=1).data
fig, ax = plt.subplots(1, 2, figsize=(6.52, 2.5))
img_ax, hist_ax = ax

img_ax.imshow(example_fov, cmap='gray')
img_ax.axis('off')
img_ax.set_title("Blank SRS Image")

# 255 unit-wide bins centered on the integer intensities 1..255.
hist_ax.hist(example_fov.ravel(), bins=255, range=(0.5, 255.5))
hist_ax.set_ylabel("Number of pixels")
hist_ax.set_xlabel("Intensity")
hist_ax.set_title("Pixel Intensity Distribution")
if savefigs:
    plt.savefig(figpath + "image_with_hist.png")

In [None]:
# Autocorrelation plots?
# y = blank_ds.squash.isel(S=0).mean('X').data
# y = (y-y.mean())/(y.std())
# x = blank_ds.squash.isel(S=0).mean('Y').data
# x = (x-x.mean())/x.std()

# a = np.zeros_like(y)
# for i in range(512):
#      a[i] = np.mean(x*np.roll(x, i))

# cx = np.correlate(x,x,mode='same')/512
# cy = np.correlate(y,y,mode='same')/512

# plt.figure()
# plt.plot(a)

# plt.figure()
# plt.plot(np.arange(256),cx[256:])
# plt.plot(np.arange(256),cy[256:])

In [None]:
# time varying background: mean cell signal and predicted background per time point
fig, ax = plt.subplots()
for trace, trace_label in ((bulk, 'Average Cell Signal'), (bkgd, 'Background')):
    ax.plot(t245, trace, label=trace_label)
ax.set_title("Changes in Background Over Time")
ax.legend()
ax.set_ylabel("SRS Intensity (a.u.)")
ax.set_xlabel("Time (minutes)")
if savefigs:
    plt.savefig(figpath + "time_varying_bkgd.png")

In [None]:
# Local mean vs. local standard deviation in the CV regions at each scan speed (left),
# and the average noise level vs. scan frequency with a sqrt(f) reference (right).
df = pd.concat([scanspeed_cv_avgs.srs.rename('mean'), scanspeed_cv_sds.srs.rename('sd')], axis=1)
# Scan frequencies in ascending order. Position k matches speed index T == k
# (T = log2(speed) - 3), assuming the speeds are consecutive powers of two from 8 Hz.
t = np.sort(scanspeed_cv_avgs.speed.unique())
n_speeds = t.shape[0]

fig, ax = plt.subplots(1,2,sharey=True,layout='constrained', figsize=(6.52,3))
for i in range(n_speeds):
    # Draw from the highest speed index down. colors[i] presumably matches the scatter,
    # which takes the next color from the cycle.
    k = n_speeds - 1 - i
    d = df.loc[pd.IndexSlice[:,:,k]]
    ax[0].plot('mean','sd', ".", data=d, markersize=2, label=f"{t[k]:d} Hz")
    lr = linregress(d['mean'].values.squeeze(), d['sd'].values.squeeze())
    # Regression line drawn over the central 98% of the x range.
    xplt = np.linspace(d['mean'].quantile(0.01), d['mean'].quantile(0.99))
    ax[0].plot(xplt, lr.slope*xplt+lr.intercept, '--', color=colors[i], linewidth=1, zorder=-1,)
ax[0].legend(loc='center right', markerscale=5)#, bbox_to_anchor=(1.5,0.99));
ax[0].set_xlabel("Local Average Intensity (a.u.)")
ax[0].set_ylabel("Local Noise (log scale - a.u.)")

# Log-log axes; sharey=True also puts the left panel's y axis on a log scale.
ax[1].loglog();
ax[1].errorbar(t,df.groupby('T').sd.mean().values, yerr=df.groupby('T').sd.std().values,capsize=3, fmt='ko')#,label="Average Noise Level $\pm \sigma$")
# sqrt(f) reference anchored to the second speed's mean noise.
scale = df.groupby('T').sd.mean().iloc[1]/np.sqrt(t[1])
# Raw string: "\s" in a plain string literal is an invalid escape sequence.
ax[1].plot(t, scale*np.sqrt(t), 'k--', alpha=0.5, zorder=-1, label=r"$\sqrt{f}$")
ax[1].set_xlabel("Scan Frequency (log scale)")
ax[1].legend(loc='lower right')

plt.suptitle("Noise is Independent of Background")
if savefigs: plt.savefig(figpath+"raw_noise_scan_speed.png")

In [None]:
# Test regions (left) and their neighborhood rings (right) on a blank FOV, with the
# union of the ignore masks shaded.
fig,ax = plt.subplots(1,2, figsize=(6.52,3))
overlay_panels = (
    (blank_ds['single_labels'].squeeze(), "Aligned test regions"),
    (neighborhood, "Local neighborhoods"),
)
for panel_ax, (panel_overlay, panel_title) in zip(ax, overlay_panels):
    panel_ax.imshow(blank_ds['squash'].isel(S=0).data, cmap='gray')
    panel_ax.imshow(blank_ds['mask'].any('S').data, cmap=mask_cmap, vmin=0.5, interpolation='none')
    panel_ax.imshow(panel_overlay, cmap=labels_cmap, vmin=0.5, interpolation='none')
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])
    panel_ax.set_title(panel_title, fontsize=12)
if savefigs: plt.savefig(figpath+"region_donuts_overlay.png")

In [None]:
# Distributions of per-region RMSE for the local-neighborhood and multi-FOV estimates;
# dashed lines mark each median.
fig, ax = plt.subplots()
counts, bins, patches = ax.hist(neighbor_rmse, bins=100, density=True, range=(0,4), alpha=0.75,
                                label=f"Neighborhood estimate (median={neighbor_rmse.median():0.2f})")
ax.axvline(neighbor_rmse.median(), linestyle='--', color=colors[0])
# Reuse the first histogram's bin edges so the two are directly comparable.
ax.hist(rmse, bins=bins, alpha=0.75, density=True, label=f"Estimate from other FOVs (median={rmse.median():0.2f})")
ax.axvline(rmse.median(), linestyle='--', color=colors[1])
ax.set_xlabel("RMSE")
ax.set_ylabel("Density")
ax.legend()
ax.set_title("Multi-FOV vs. local estimation")
if savefigs: plt.savefig(figpath+"estimate_comparison.png")

In [None]:
# FOV shown in the scan-speed figures below.
iplt = 2

In [None]:
# Cell label image of that FOV.
all_labels = scanspeed_ds.labels.isel(I=iplt).data

In [None]:
# Pixels of cells classified as D7G-labeled (+1: presumably maps the per-cell index to
# the label id — confirm).
# d7g_mask = np.isin(all_labels,scanspeed_raw_avgs.loc[scanspeed_raw_avgs.d7g].loc[(iplt, slice(None), tplt)].index.values+1)
d7g_mask = np.isin(all_labels,fluo_series.loc[fluo_series].loc[iplt].index+1)

In [None]:
# Pixels of cells classified as unlabeled.
# unl_mask = np.isin(all_labels,scanspeed_raw_avgs.loc[~scanspeed_raw_avgs.d7g].loc[(iplt, slice(None), tplt)].index.values+1)
unl_mask = np.isin(all_labels,fluo_series.loc[~fluo_series].loc[iplt].index+1)

In [None]:
# Inner outlines: each mask XOR its 2-iteration binary erosion. Erosion acts on the
# union mask, so touching cells of the same class share one outline.
# d7g_rings = expand_labels(d7g_mask, distance=2) ^ d7g_mask
d7g_rings = d7g_mask^ndi.binary_erosion(d7g_mask, iterations=2)
# unl_rings = expand_labels((all_labels*(~d7g_mask))>0, distance=2) ^ ((all_labels*(~d7g_mask))>0)
unl_rings = unl_mask^ndi.binary_erosion(unl_mask, iterations=2)

In [None]:
# Preview of the top row of the raw SRS figure: SRS, fluorescence and brightfield crops
# of FOV iplt at speed index tplt, with D7G (blue) and unlabeled (orange) outlines.
iplt = 2
tplt = 2
xsel = slice(327, 490)
ysel = slice(62, 196)
# Brightfield display plane: channel 1, Z=7, cropped.
bfshow = scanspeed_ds.images.isel(I=iplt, T=tplt, C=1, Z=7, X=xsel, Y=ysel).load().data
# NOTE(review): df is not used in this cell.
df = scanspeed_raw_avgs.loc[pd.IndexSlice[:,:,tplt]]
fig, ax = plt.subplot_mosaic("ABC", figsize=(6.52,3), sharex=True, sharey=True)
ax['A'].imshow(scanspeed_ds['squash'].isel(I=iplt, T=tplt, X=xsel, Y=ysel).data, cmap='gray')
ax['A'].axis('off')
ax['A'].set_title("SRS")

ax['B'].imshow(scanspeed_ds['fluo'].isel(I=iplt, X=xsel, Y=ysel).data, cmap='gray')
ax['B'].axis('off')
ax['B'].set_title("Fluorescence")

ax['C'].imshow(bfshow, cmap='gray')
ax['C'].axis('off')
ax['C'].set_title("Brightfield")

# Same outlines on every panel, cropped to the displayed region.
for x in 'ABC':
    ax[x].imshow(d7g_rings[(ysel, xsel)], cmap=blabels, vmin=0.5, interpolation='none')
    ax[x].imshow(unl_rings[(ysel, xsel)], cmap=olabels, vmin=0.5, interpolation='none')


In [None]:
# iplt = 2
# tplt = 2
# bfshow = scanspeed_ds.images.isel(I=iplt, T=tplt, C=1, Z=7).load().data
# df = scanspeed_raw_avgs.loc[pd.IndexSlice[:,:,tplt]]
# ax['A'].imshow(scanspeed_ds['squash'].isel(I=iplt, T=tplt).data, cmap='gray')
# ax['A'].axis('off')
# ax['A'].set_title("SRS")

# ax['B'].imshow(scanspeed_ds['fluo'].isel(I=iplt).data, cmap='gray')
# ax['B'].axis('off')
# ax['B'].set_title("Fluorescence")

# ax['C'].imshow(bfshow, cmap='gray')
# ax['C'].axis('off')
# ax['C'].set_title("Brightfield")

# for x in 'ABC':
#     ax[x].imshow(d7g_rings, cmap=blabels, vmin=0.5, interpolation='none')
#     ax[x].imshow(unl_rings, cmap=olabels, vmin=0.5, interpolation='none')

In [None]:
# Raw SRS figure. Top: cropped SRS / fluorescence / brightfield views of FOV iplt at
# speed index tplt, with D7G (blue) and unlabeled (orange) outlines. Bottom:
# histograms of per-cell raw SRS intensity for the two groups at that speed, all FOVs.
fig, ax = plt.subplot_mosaic("ABC;DDD", figsize=(6.52,6))
iplt = 2
tplt = 2
xsel = slice(327, 490)
ysel = slice(62, 196)
bfshow = scanspeed_ds.images.isel(I=iplt, T=tplt, C=1, Z=7, X=xsel, Y=ysel).load().data
df = scanspeed_raw_avgs.loc[pd.IndexSlice[:,:,tplt]]
ax['A'].imshow(scanspeed_ds['squash'].isel(I=iplt, T=tplt, X=xsel, Y=ysel).data, cmap='gray')
ax['A'].axis('off')
ax['A'].set_title("SRS")

ax['B'].imshow(scanspeed_ds['fluo'].isel(I=iplt, X=xsel, Y=ysel).data, cmap='gray')
ax['B'].axis('off')
ax['B'].set_title("Fluorescence")

ax['C'].imshow(bfshow, cmap='gray')
ax['C'].axis('off')
ax['C'].set_title("Brightfield")

for x in 'ABC':
    ax[x].imshow(d7g_rings[(ysel, xsel)], cmap=blabels, vmin=0.5, interpolation='none')
    ax[x].imshow(unl_rings[(ysel, xsel)], cmap=olabels, vmin=0.5, interpolation='none')
hrange = (100,160)
ax['D'].hist(df.loc[df.d7g]['srs'].values, bins=50, range=hrange, histtype='step', density=True, linewidth=2, label="D7G Labeled")
ax['D'].hist(df.loc[~df.d7g]['srs'].values, bins=50, range=hrange, histtype='step', density=True, linewidth=2, label='Unlabeled')
# density=True normalizes each histogram to unit area. The y-axis is therefore a
# probability density (bin width 1.2 here), not a fraction of cells.
ax['D'].set_ylabel("Density")
ax['D'].set_xlabel("SRS Intensity (a.u.)")
ax['D'].legend(loc='upper left')#, bbox_to_anchor=(1.5,0.95))
ax['D'].set_title("Cellular SRS Intensities");
if savefigs:
    plt.savefig(figpath+"raw_srs_fluo_hist_demo.png")

In [None]:
# Step-by-step illustration of background estimation on one demo FOV. All panels share
# the input image's intensity range.
vmin = demo_ds.images.min().item()
vmax = demo_ds.images.max().item()
fig, ax = plt.subplots(2,3, figsize=(6.52,5), sharex=True, sharey=True)
for a in ax.ravel(): a.axis('off')
# Panel titles use the same "N. " numbering throughout.
ax[0,0].imshow(demo_ds.images.data, cmap='gray', vmin=vmin, vmax=vmax)
ax[0,0].set_title("1. Input image")

ax[0,1].imshow(demo_ds.images.data, cmap='gray', vmin=vmin, vmax=vmax)
ax[0,1].imshow(demo_ds.labels.data, cmap=rlabels, vmin=0.1, interpolation='none')
ax[0,1].set_title("2. Input masks")

ax[0,2].imshow(demo_ds.images, cmap='gray', vmin=vmin, vmax=vmax)
ax[0,2].imshow(demo_ds.labels, cmap=rlabels, vmin=.1, interpolation='none')
ax[0,2].imshow(demo_ds.cv_labels, cmap=blabels, vmin=.1, interpolation='none')
ax[0,2].set_title("3. CV masks")

# Image with the dilated cell mask blanked out.
# NOTE(review): `~` assumes demo_dilate is boolean; on an integer mask it is a bitwise NOT.
ax[1,0].imshow(demo_ds.images.where(~demo_dilate).data, cmap='gray', vmin=vmin, vmax=vmax)
ax[1,0].set_title("4. Dilate")

ax[1,1].imshow(demo_ds.initial_estimate,  cmap='gray', vmin=vmin, vmax=vmax)
ax[1,1].set_title("5. Initialize")

ax[1,2].imshow(demo_ds.background_estimate, cmap='gray',  vmin=vmin, vmax=vmax)
ax[1,2].set_title("6. Smooth")
if savefigs: plt.savefig(figpath+"landfill_steps.png")

In [None]:
# Illustration of the initialization step on two FOVs: masked pixels of image 2 are
# filled from image 1. Remaining gaps are filled by the estimator's initial estimate.
vmin = pair.images.min().item()
vmax = pair.images.max().item()

fig, ax= plt.subplots(2,2, figsize=(6.52,7), sharex=True, sharey=True)
for a in ax.ravel(): a.axis('off')
for i in range(2):
    ax[0,i].imshow(nan_mask.data[i], cmap='gray', vmin=vmin, vmax=vmax)
    ax[0,i].set_title(f"Masked Image {i+1}")
# Per-FOV scale: median of the unmasked pixels. Named pair_scales so the global
# `scales` from the blank-dataset analysis is not overwritten.
pair_scales = nan_mask.median(list('YX'))
# Pixelwise median over the two FOVs of the scale-normalized images, rescaled to image 2.
# NaN (masked) pixels are skipped, so where image 2 is masked the value comes from image 1.
scale_im = pair_scales[-1]*((nan_mask/pair_scales).median(list('S')))
show = pair.images[-1].squeeze().copy()
# Keep image 2's own pixels outside the dilated mask; fill inside it from scale_im.
show = show.where(pair_be.dilated_mask[-1]==0, scale_im)
ax[1,0].imshow(show, cmap='gray', vmin=vmin, vmax=vmax )
ax[1,0].set_title("Image 2 filled with Image 1")
ax[1,1].imshow(init[-1], cmap='gray', vmin=vmin, vmax=vmax)
ax[1,1].set_title("Fill remaining with KNN average");
# Guarded like every other figure: writing unconditionally fails wherever figpath is absent.
if savefigs:
    plt.savefig(figpath+"landfill_init.png", dpi=200)

In [None]:
# Left: predicted vs. true mean intensity of each CV region of the blank dataset; the
# red dotted line is y = x. Right: per-region RMSE distributions for the neighborhood,
# other-FOV and final estimates, all using the first histogram's bin edges.
fig, ax = plt.subplots(1,2, figsize=(6.52,3))
ax[0].set_box_aspect(1)
ax[0].plot(true_avgs.values, pred_avgs.values, 'k.', markersize=1)
x = [50,140]
ax[0].plot(x,x, 'r:', zorder=-1)
ax[0].set_xlabel("True Intensity (a.u.)")
ax[0].set_ylabel("Predicted Intensity (a.u.)")
ax[0].set_title("Accuracy of Background Estimation")
ax[0].set_xlim(45, 145)
ax[0].set_ylim(45, 145)

counts, bins, patches = ax[1].hist(neighbor_rmse, bins=100, density=True, range=(0,4), alpha=0.45, label=f"Neighborhood \n(median={neighbor_rmse.median():0.2f})")
ax[1].axvline(neighbor_rmse.median(), linestyle='--', color=colors[0])
ax[1].hist(rmse, bins=bins, alpha=0.45, density=True, label=f"Other FOVs\n(median={rmse.median():0.2f})");
ax[1].axvline(rmse.median(), linestyle='--', color=colors[1])
# The final estimate is drawn more opaque (alpha 0.75 vs 0.45) so it stands out.
ax[1].hist(final_rmse, bins=bins, alpha=0.75, density=True,label=f"Final\n(median={final_rmse.median():0.2f})")
ax[1].axvline(final_rmse.median(), linestyle='--', color=colors[2])
ax[1].legend(loc='upper right')#, bbox_to_anchor=(1.1, 0.99));
ax[1].set_xlabel("RMSE")
ax[1].set_ylabel("Density")
ax[1].set_title("Prediction Errors")
if savefigs: plt.savefig(figpath+"accuracy.png")

In [None]:
# Background-subtracted SRS figure. Top: cropped raw SRS, corrected SRS and brightfield
# of FOV iplt with D7G (blue) and unlabeled (orange) outlines. Bottom: histograms of
# per-cell background-subtracted SRS for the two groups (all FOVs, speed index T=2).
iplt = 2
tplt = 2
xsel = slice(327, 490)
ysel = slice(62, 196)
bfshow = scanspeed_ds.images.isel(I=iplt, T=tplt, C=1, Z=7).load().data
df = scanspeed_bsub_avgs#.loc[iplt]

fig, ax = plt.subplot_mosaic("ABC;DDD", figsize=(6.52,6))
ax['A'].imshow(scanspeed_ds['squash'].isel(I=iplt, T=tplt, X=xsel, Y=ysel).data, cmap='gray')
ax['A'].axis('off')
ax['A'].set_title("SRS")

ax['B'].imshow(scanspeed_bsub.isel(S=iplt, X=xsel, Y=ysel).data, cmap='gray')
ax['B'].axis('off')
ax['B'].set_title("SRS Corrected")

ax['C'].imshow(bfshow[(ysel,xsel)], cmap='gray')
ax['C'].axis('off')
ax['C'].set_title("Brightfield")

for x in 'ABC':
    ax[x].imshow(d7g_rings[(ysel, xsel)], cmap=blabels, vmin=0.5, interpolation='none')
    ax[x].imshow(unl_rings[(ysel, xsel)], cmap=olabels, vmin=0.5, interpolation='none')

hrange = (-10,50)
ax['D'].hist(df.loc[df.d7g]['srs'].values, bins=50, range=hrange, histtype='step', density=True, linewidth=2, label="D7G Labeled")
ax['D'].hist(df.loc[~df.d7g]['srs'].values, bins=50, range=hrange, histtype='step', density=True, linewidth=2, label='Unlabeled')
# density=True normalizes each histogram to unit area. The y-axis is therefore a
# probability density (bin width 1.2 here), not a fraction of cells.
ax['D'].set_ylabel("Density")
ax['D'].set_xlabel("SRS Intensity (a.u.)")
ax['D'].legend(loc='upper left')#, bbox_to_anchor=(1.5,0.95))
ax['D'].set_title("Background Subtracted Cellular SRS Intensities");
if savefigs:
    plt.savefig(figpath+"bsub_srs_fluo_hist_demo.png")

In [None]:
# DH245 timelapse: mean cell signal, observed and predicted background, and the
# background-corrected cell signal over time.
# Constant added to the corrected signal for display (NOTE(review): presumably chosen by
# hand so it plots near the raw traces — confirm).
corrected_offset = 63
fig, ax = plt.subplots()
ax.plot(t245, bulk, label='Cell Signal')
ax.plot(t245, bkgd_obs, label='Observed Background')
ax.plot(t245, bulk - bkgd + corrected_offset, label='Corrected Signal')
ax.plot(t245, bkgd, ':', label='Predicted Background')
ax.legend()
ax.set_ylabel("SRS Intensity (a.u.)")
ax.set_xlabel("Time (minutes)")
ax.set_title("Accounting for Time-Varying Background")
if savefigs:
    plt.savefig(figpath + "time_varying_bkgd_sub.png")