In [None]:
import os, sys
sys.path.append('../../')
from os.path import abspath, dirname
import zarr
import z5py
import numpy as np
import pandas as pd
from glob import glob 
from skimage.measure import regionprops
from skimage.io import imread, imsave
from scipy import stats
from scipy.stats import skewnorm, lognorm
from scipy.optimize import minimize
import itertools
import collections
from natsort import natsorted

from easi_fish import n5_metadata_utils as n5mu
from easi_fish import roi_prop, spot, intensity
import warnings
warnings.filterwarnings('ignore')

import importlib
# Re-execute the project modules so edits to easi_fish take effect without a
# kernel restart (same reload order as before: spot, roi_prop, intensity).
for _module in (spot, roi_prop, intensity):
    importlib.reload(_module)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# High-DPI inline figures. IPython >= 7.23 deprecates
# IPython.display.set_matplotlib_formats in favour of the matplotlib_inline
# helper; the deprecation warning is hidden by the global warnings filter,
# so prefer the new location and fall back only for older installs.
try:
    from matplotlib_inline.backend_inline import set_matplotlib_formats
except ImportError:
    from IPython.display import set_matplotlib_formats
set_matplotlib_formats('retina')

sns.set_style('white', rc={'axes.grid': True})
sns.set_context('talk')

Spot counts for cells with highly expressed genes (dense spots)
1. Measure total intensity of every ROI after bleed-through correction and background subtraction.
2. Calculate the number of spots from the total intensity, using the unit-spot intensity.
3. Correlate the number of spots (from AirLocalize) with the total fluorescence intensity per voxel in each ROI, and determine a cutoff.
   - Spot count > cutoff: use the spot count derived from total fluorescence intensity;
   - Spot count < cutoff: use the spot count from AirLocalize.

   Note: the intensity-based counting (steps 1–2) is currently commented out in the code below; only AirLocalize spot counts are computed.

### On units
- All images are in pixel units.
- The ROI metadata files (outputs) are in physical units (µm, pre-expansion).
- Spot files are in µm (post-expansion).

In [None]:
## input
# Dataset root; each sub-directory of <ddir>/outputs is one tile. Can be
# overridden with the EASIFISH_DDIR environment variable (default unchanged).
ddir = os.environ.get('EASIFISH_DDIR', '/u/home/f/f7xiesnm/project-zipursky/easifish/273LU')
# Pure-Python equivalent of `mkdir -p`; avoids interpolating the path into a shell command.
os.makedirs(os.path.join(ddir, 'proc'), exist_ok=True)
# Only directories are tiles; stray files (e.g. .DS_Store) would otherwise
# fail the spot-file asserts in the per-tile loop.
outputs_root = os.path.join(ddir, 'outputs')
tiles = natsorted([
    name for name in os.listdir(outputs_root)
    if os.path.isdir(os.path.join(outputs_root, name))
])
# tiles.remove('tile4')
tiles

In [None]:

for tile in tiles:
    # Per-tile pipeline: collect the valid ROIs from the segmentation mask, save
    # their properties, and count spots per ROI for every channel listed below.
    # Outputs go to <ddir>/proc/<tile>_v1: roi_all.csv, roi.csv, spotcount.csv
    # and, when registered rounds are given, bad_roi_list.npy.
    output_dir = ddir + f'/proc/{tile}_v1'

    fix_round = tile
    mov_rounds = [] # 'r1v3', 'r2v3', 'r4v3', 'r5v3']
    # Channels to count per round. Flattened in order, these must line up
    # one-to-one with fx_spots and intn_threshs below (checked by an assert).
    round_channels = collections.OrderedDict({
        tile: ('c0', 'c1', 'c4'),
    })

    dapi_channel = 'c3'          # channel of the segmentation / registration images
    lb_scale = 's3'              # n5 scale level used for segmentation
    lb_res = [1.84,1.84,1.68]    # label voxel size; presumably um (x, y, z) at lb_scale -- confirm
    ex = 2                       # expansion factor passed to roi_prop_v2

    # images
    subpath     = f'/{dapi_channel}/{lb_scale}'        # '/c3/s3'
    # Warped dataset inside each registration container; only read when reg_dirs is non-empty.
    subpath_reg = f'/{dapi_channel}_reg/{lb_scale}'    # '/c3_reg/s3'
    fix_dir  =  ddir + f"/outputs/{fix_round}/stitching/export.n5"
    lb_dir  =   ddir + f"/outputs/{fix_round}/segmentation/{fix_round}-c3.tif"
    reg_dirs = [
        # ddir + f"/outputs/r1v3/registration/r1v3-to-r3v3/warped",
        # ddir + f"/outputs/r2v3/registration/r2v3-to-r3v3/warped",
        # ddir + f"/outputs/r4v3/registration/r4v3-to-r3v3/warped",
        # ddir + f"/outputs/r5v3/registration/r5v3-to-r3v3/warped",
        ]

    fx_spots = [
        ddir + f'/outputs/{tile}/spots/spots_c0.txt',
        ddir + f'/outputs/{tile}/spots/spots_c1.txt',
        ddir + f'/outputs/{tile}/spots/spots_c4.txt',
    ]

    # Keep only spots whose column 3 (presumably intensity) exceeds this value.
    intn_threshs = [20]*len(fx_spots)

    for f in fx_spots:
        assert os.path.isfile(f), f"missing spot file: {f}"

    # Pair every (round, channel) with its spot file and threshold up front, so a
    # mismatch fails here instead of silently mislabelling spotcount columns.
    rc_pairs = [(r, c) for r, chs in round_channels.items() for c in chs]
    assert len(rc_pairs) == len(fx_spots) == len(intn_threshs), \
        "round_channels, fx_spots and intn_threshs must have matching lengths"

    ## output
    out_badroi = os.path.join(output_dir, 'bad_roi_list.npy')
    out_allroi = os.path.join(output_dir, "roi_all.csv") 
    out_roi = os.path.join(output_dir, "roi.csv") 
    out_spots = os.path.join(output_dir, "spotcount.csv")

    # output dir -- created before anything (e.g. the bleed-through step) writes into it
    if not os.path.isdir(output_dir):
        print(output_dir)
        os.makedirs(output_dir, exist_ok=True)

    REMOVE_BLEEDTHRU = False

    # remove bleed through!
    if REMOVE_BLEEDTHRU:

        bleed_thru_epsilon = 1
        c_qry = 'c0'
        c_ref = 'c4'

        f_ref = ddir+f'/outputs/{fix_round}/spots/spots_{c_ref}.txt'
        f_qry = ddir+f'/outputs/{fix_round}/spots/spots_{c_qry}.txt'
        f_qry_kept = os.path.join(output_dir, f'kept_spots_{fix_round}_{c_qry}.txt')

        ref_dots = np.loadtxt(f_ref, delimiter=',')
        qry_dots = np.loadtxt(f_qry, delimiter=',')
        qry_kept, qry_removed = spot.remove_bleed_thru_spots(ref_dots, qry_dots, epsilon=bleed_thru_epsilon)

        # save 
        np.savetxt(f_qry_kept, qry_kept, delimiter=",")

        ### TODO - replace f_qry with f_qry_kept in fx_spots; the filtered file is
        ### currently written but not used for counting.

    # image size in pixel (x, y, z)
    grid = n5mu.read_voxel_grid(fix_dir, subpath)
    # voxel resolution in µm (x, y, z) (post-expansion)
    vox  = n5mu.read_voxel_spacing(fix_dir, subpath)
    # image size in physical space (x, y, z) (post-expansion)
    size = grid*vox
    print('subpath: ', subpath)
    print('voxel size: ', vox)
    print('image size (pixel): ', grid)
    print('image size (um post-ex): ', size)

    # get segmentation mask
    lb = imread(lb_dir)
    # roi = np.max(lb) # this is only correct if this lb is uncropped
    roi_labels = np.unique(lb[lb!=0])  # distinct non-background labels (computed once, reused below)
    roi = len(roi_labels)
    print(lb.shape)
    print('num roi: ', roi)

    # Get list of good ROIs  
    if len(reg_dirs) > 0:
        ### Make sure to only include ROIs that are intact and in the overlapping regions across all rounds of FISH
        ### remove any unregistered parts
        # True where every warped round has data. Array order is (z, y, x), hence grid[::-1];
        # assumes lb has the same shape as the s3 images -- confirm.
        # bool instead of the float64 default: same semantics, 1/8 of the memory.
        mask = np.ones(grid[::-1], dtype=bool)
        for reg_dir in reg_dirs:
            # NOTE(review): zarr.N5Store is a zarr-python 2.x API -- confirm the pinned zarr version.
            reg = zarr.open(store=zarr.N5Store(reg_dir), mode='r')     
            img2 = reg[subpath_reg][...]
            print("image loaded")
            mask[img2==0] = False
        print("mask generated")

        bad_roi = np.unique(lb[~mask])
        # drop label 0 (extracellular space); also safe when bad_roi is empty
        bad_roi = bad_roi[bad_roi != 0]
        np.save(out_badroi, bad_roi)
        print("# of ROIs rejected:", len(bad_roi))

    else:
        bad_roi = np.array([], dtype=int)

    # get cell locations (in um - pre-expansion) from segmentation mask
    roi_meta_all = roi_prop.roi_prop_v2(lb, lb_res, ex)
    roi_meta_all.to_csv(out_allroi)

    roi_meta = roi_meta_all.set_index('roi').copy()
    roi_meta = roi_meta.loc[roi_meta.index.difference(bad_roi)]
    roi_meta.to_csv(out_roi)

    # count spots for every cell (roi)
    # all labels
    lb_id = np.hstack([[0], roi_labels]) # include 0 - noncell
    # selected cells
    lb_id_selected = roi_meta.index.values

    # prep
    spotcount = pd.DataFrame(index=lb_id_selected, dtype=int)
    # spotcount_intn = pd.DataFrame(index=lb_id_selected, dtype=float)

    for (r, c), f_spots, intn_th in zip(rc_pairs, fx_spots, intn_threshs):
        print(r, c, f_spots)

        # spots; ndmin=2 keeps a single-spot file 2-D so column indexing works
        spots_rc = np.loadtxt(f_spots, delimiter=',', ndmin=2)
        print(len(spots_rc))

        # filter
        filter_cond = spots_rc[:,3] > intn_th
        spots_rc = spots_rc[filter_cond]
        print(len(spots_rc))

        # count spots
        res = spot.spot_counts_worker(lb, spots_rc, lb_res,
                                      lb_id=lb_id, 
                                      remove_noncell=True, 
                                      selected_roi_list=lb_id_selected,
                                      )
        spotcount[f"{r}_{c}"] = res 

        # TODO: intensity-based counts (markdown steps 1-2) are disabled:
        # res = spot.get_spot_counts_from_intn(f_intns, f_spots, roi_meta, lb_res)
        # spotcount_intn[f'{r}_{c}'] = res

    # save results
    spotcount.to_csv(out_spots)
    # spotcount_intn.to_csv(out_spots_intn)