# ND2 native stitch pipeline
This notebook processes a single ND2 file, which must be a multi-position time-lapse dataset.

## Key details:
- Sequential positions are individual tile images belonging to discrete, rectangular & non-overlapping tile regions (panoramas). These are to be identified, stitched, and processed in this notebook.
- No Z-interpretation, maxIP is executed before (NIS) or during pipeline execution.
- During feature extraction (spot detection & intensity extraction) tile images are considered independently. Features are later taken only for the best representation of each cell in tile image overlap regions.

## Key pipeline components:
- ND2 parsing by [nd2](https://github.com/tlambert03/nd2)
  - Dataset layout
  - Metadata
- (optional) flatfield illumination correction
- Tile region identification using stage coordinates and pixel size (metadata)
- Tile region stitching by either:
  - Hard overlap according to stage coordinates, no registration
  - Image registration with [MIST](https://github.com/usnistgov/MIST)
- Segmentation of nuclei (& cells) by [CellPose](https://github.com/MouseLand/cellpose)
- Cell tracking by either:
  - [TrackMate](https://github.com/trackmate-sc/TrackMate)
  - [TrackAstra](https://github.com/weigertlab/trackastra/)
- Spot detection by either:
  - [SpotiFlow](https://github.com/weigertlab/spotiflow)
  - [DeepBlink](https://github.com/bbquercus/deepblink/) **! deprecated here, use alternative env & notebook**
- Cell representation in tile image overlap
- Cell intensity and morphology feature extraction
- Data visualization with Napari
- Optional video rendering

In [None]:
import os, sys
import numpy as np
import pandas as pd

from tqdm.notebook import tqdm
from tqdm.contrib.itertools import product

import nd2
from tifffile import imread, imwrite
import imagej
from skimage import measure

import napari
import matplotlib.pyplot as plt

# Locate the 'src' folder that sits one level above the working directory
# and put it on sys.path so the pipeline modules below can be imported.
from pathlib import Path
path_imports = f"{Path.cwd().resolve().parents[0]}/src/"
sys.path.append(path_imports)

# import external function
import importlib
import MetadataParserND2
import TileRegion
import FlatDarkField
import LoopDimensions
import SegmentationCellpose
import LinkPointToObject
import FeatureExtraction

In [None]:
# path of FIJI
# Local Fiji installation folder; passed to imagej.init() for MIST stitching.
path_imagej_mist = '/data/Fiji.app'
# Only a warning is printed: execution continues when the folder is missing.
if not os.path.exists(path_imagej_mist):
    print('Path invalid!')

In [None]:
# Trackmate imagej macro path
# The script (trackmate_init.py in the src folder) is read and executed
# through pyimagej in the TrackMate tracking section.
path_script_Trackmate = path_imports + "trackmate_init.py"
# Only a warning is printed: execution continues when the file is missing.
if not os.path.exists(path_script_Trackmate):
    print('Path invalid!')

# Data path definitions

In [None]:
# Dataset to process
# Path of the .nd2 file (placeholder value, edit before running). The output
# folders are created next to this file.
path_nd2 = '/dummy.nd2'

In [None]:
# Folder with flatfield correction data
# File names are appended by plain string concatenation, so keep the
# trailing slash.
path_ff_base = '/dummy/flatfield/'

In [None]:
# Analysis folder setup: every output folder lives next to the nd2 file
path_base = os.path.dirname(path_nd2)
# dataset name = file name up to the first dot
name = os.path.basename(path_nd2).split('.')[0]

path_export_tile = f"{path_base}/Export/"
path_export_stitch = f"{path_base}/Stitch/"
path_analysis_segmentation = f"{path_base}/Masks/"
path_analysis_tracks = f"{path_base}/Tracks/"
path_analysis_spots = f"{path_base}/Spots/"
path_features = f"{path_base}/Features/"

# create the folders; existing ones are left untouched
for _path_out in (path_export_tile,
                  path_export_stitch,
                  path_analysis_segmentation,
                  path_analysis_tracks,
                  path_analysis_spots,
                  path_features):
    os.makedirs(_path_out, exist_ok=True)

# Load nd2

In [None]:
# User toggles
# flip_x: negate the stage X coordinates before tile regions are searched
flip_x = False
# flatfield: apply flat/dark field correction; switched off automatically
# further below when the reference images do not fit the data
flatfield = True

In [None]:
# Parse metadata
# Returns the dimension sizes, the channel list, the stage position table and
# the pixel size (implementation lives in MetadataParserND2).
sizes, channels, positions, px_size = MetadataParserND2.get_nd2_meta(path_nd2)

# first loop of positional redundant dimensions (T/Z) lists unique stage positions
positions_first = positions[(positions['iT'] == 0) & (positions['iZ'] == 0)].copy()
# NOTE(review): the ND2File handle is never closed; presumably it has to stay
# open for the lazy dask array to read from it - confirm.
data_img = nd2.ND2File(path_nd2).to_dask()

# plain dict so that .get() and key order can be used further on
sizes = dict(sizes)
print(sizes)
print(channels)
print(data_img.shape)
# last expression of the cell: displayed as notebook output
positions_first.tail()

In [None]:
# (Optional) bind image position names (incl well codes) to positions_first df
# Required for the optional "Well codes" export, which reads the 'name' column.
positions_first['name'] = MetadataParserND2.get_nd2_tile_names(path_nd2)

In [None]:
# Parse metadata, show camera settings
# Informational only: nothing printed here is used later in the pipeline.
print(MetadataParserND2.get_nd2_camera_settings(path_nd2))
print(MetadataParserND2.get_nd2_misc(path_nd2))

In [None]:
# Timeframe to use for stitching, default: half of timelapse length
# A dataset without a T axis is treated as a single frame (T = 1).
# Note that Python's round() rounds exact halves to the nearest even integer.
t_stitch = int(round(sizes.get('T', 1)/2))
print(t_stitch)

In [None]:
# For single channel data place a dummy channel axis
# Both the image array and the sizes dict are replaced by the helper's output
# so that later code can always rely on a 'C' entry.
if 'C' not in sizes:
    data_img, sizes = MetadataParserND2.fix_nd2_single_channel(data_img, sizes)

In [None]:
# Flatten Z dimension by maxIP
# Only executed when the dataset really has more than one Z plane; image
# array and sizes dict are both replaced by the helper's output.
if sizes.get('Z', 1) > 1:
    data_img, sizes = MetadataParserND2.proj_nd2_max(data_img, sizes)

# Channel usage

In [None]:
# Stitch
# Index (into the channel list) of the channel used for stitching
channelsOI_stitch = 0
print("Stitch: " + str(np.array(channels)[channelsOI_stitch]))

In [None]:
# Segmentation (nuclei, cyto optional)
# Both values are indices into the channel list.
channelsOI_cellpose_nucl = 0
channelsOI_cellpose_cell = None #None to disable

print("Cellpose NUCL: ", str(np.array(channels)[channelsOI_cellpose_nucl]))
# identity test against the None singleton (same test as used by the
# cell-segmentation step); channel index 0 still counts as "enabled"
if channelsOI_cellpose_cell is not None:
    print("Cellpose CELL: ", str(np.array(channels)[channelsOI_cellpose_cell]))

In [None]:
# Spot detection
# Boolean mask over the channel list: True marks a channel in which spots are
# detected. The mask needs one entry per channel to index the list below.
channelsOI_spots = [True, False]
# label no longer names DeepBlink: spot detection in this notebook runs Spotiflow
print("Spot detection: " + str(np.array(channels)[channelsOI_spots]))

# Flat field loading

In [None]:
# disable flatfield correction
# Empty placeholders keep the names defined: later calls pass img_ff and
# img_df along together with the flatfield flag.
if not flatfield:
    img_ff = []
    img_df = []

In [None]:
# camera metadata (sensor crop)
# The context manager closes the file on exit, so no explicit close() call
# is needed inside the block.
with nd2.ND2File(path_nd2) as nd2file:
    nd2meta = nd2file.unstructured_metadata()['ImageMetadataSeqLV|0']['SLxPictureMetadata']['PicturePlanes']['SampleSetting']
# sample settings are keyed by stringified integers: take the highest one
max_key = max(int(key) for key in nd2meta)
crop_cam = nd2meta[str(max_key)]['CameraSetting']['ROI']
print(crop_cam)
print(MetadataParserND2.get_nd2_camera_settings(path_nd2))
print(channels)

In [None]:
# Cut the full-frame flatfield / darkfield references down to the sensor
# region recorded in the nd2 metadata.
ff_custom = True

if ff_custom:
    crop_y, crop_x = crop_cam['Top'], crop_cam['Left']
    crop_w = crop_cam['Right'] - crop_cam['Left']
    crop_h = crop_cam['Bottom'] - crop_cam['Top']
    # one window (rows, columns) shared by all reference images
    crop_window = (slice(crop_y, crop_y + crop_h), slice(crop_x, crop_x + crop_w))

    # flatfield images are scaled to a mean of 1 after cropping
    img_ff_BFP = imread(path_ff_base + 'BFP.tiff')[crop_window]
    img_ff_BFP = img_ff_BFP / img_ff_BFP.mean()

    img_ff_GFP = imread(path_ff_base + 'GFP.tiff')[crop_window]
    img_ff_GFP = img_ff_GFP / img_ff_GFP.mean()

    img_ff_mCh = imread(path_ff_base + 'mCherry.tiff')[crop_window]
    img_ff_mCh = img_ff_mCh / img_ff_mCh.mean()

    # the darkfield image is cropped only, not scaled
    img_df = imread(path_ff_base + 'dark.tiff')[crop_window]

In [None]:
# Load and parse ff and df images
if flatfield:
    if not ff_custom:
        # NOTE(review): the path_ff_* / path_df variables are not defined
        # anywhere in this notebook; set them before using this branch.
        img_ff_BFP = imread(path_ff_BFP)
        img_ff_GFP = imread(path_ff_GFP)
        img_ff_mCh = imread(path_ff_mCh)
        img_ff_TMR = imread(path_ff_TMR)
        img_ff_Cy5 = imread(path_ff_Cy5)
        img_df = imread(path_df).astype(np.uint16)

    # HERE SPECIFY ff channels to use, consistent to channel list specified above
    # (reasonable default channel setup assumed here)
    if len(channels) == 3:
        img_ff = np.stack([img_ff_mCh, img_ff_GFP, img_ff_BFP])
    elif len(channels) == 2:
        img_ff = np.stack([img_ff_GFP, img_ff_BFP])
    else:
        # no default stack for this channel count: without this branch img_ff
        # would stay undefined and the check below would raise a NameError
        print("No default FF channel setup for " + str(len(channels)) + " channel(s), for now no FF correction!")
        flatfield = False
        img_ff = []
        img_df = []

    if flatfield and ((not img_ff[0].shape == img_df.shape == data_img.shape[-2:]) or (not sizes.get('C', 1) == img_ff.shape[0])):
        # check if dim X and Y are same as ff/df X and Y, and same number of images provides as channels
        print("Wrong dimensions, fix DF and FF, for now no FF correction!")
        flatfield = False

In [None]:
# show example
# Visual check of the GFP flatfield image; requires img_ff_GFP to have been
# loaded by one of the flatfield cells.
plt.imshow(img_ff_GFP)
plt.show()

# Stitching
Two options:
a. Stage coordinates hard paste
b. MIST compute optimal stitching using image data

## Find tile region(s)

In [None]:
# Mirror the stage X coordinates (needed for some microscopes).
# NOTE(review): only positions_first is mirrored here; the full `positions`
# table keeps its original sign - confirm this is intended.
if flip_x:
    positions_first['X'] = -positions_first['X']

In [None]:
# Identify the tile regions; tr_init holds one entry per tile region and is
# iterated by the stitching, segmentation and tracking steps.
# overlap_percent is passed on to MIST.
tr_init, overlap_percent = TileRegion.find_tile_region_start(sizes, px_size, positions)
print(tr_init)
print(overlap_percent)

In [None]:
# plot stage positions
# Colour follows the dataframe index of each position.
plt.scatter(positions_first['X'], positions_first['Y'], c = positions_first.index)
plt.show()

In [None]:
# Assign every stage position to a tile region; the resulting table is
# joined on its 'P' column (and grouped by 'TR') further below.
positions_tr_all = TileRegion.find_tile_region_all(tr_init, sizes, positions_first, px_size)
print(positions_tr_all.tail())

### (Optional) Well codes

In [None]:
# Export well code per TR (don't run if you have Meghan data)
# Requires the optional 'name' column on positions_first.
# First position of every tile region; .copy() makes the result an
# independent frame so the column assignment below does not act on a slice
# (avoids pandas' SettingWithCopyWarning).
tileregion_names = positions_tr_all.merge(positions_first[['iP', 'name']], left_on = 'P', right_on = 'iP').groupby('TR').head(1).copy()
# well code = part of the position name before the first underscore
tileregion_names['TR_name'] = tileregion_names['name'].str.split('_').str.get(0)

tileregion_names[['TR', 'TR_name']].to_csv(path_export_stitch + "/" + name + "_TRNames.csv", sep = ";", decimal = ".", index = False)

## A. Place tiles by stage coordinate

In [None]:
# Combine tile-region membership with the stage coordinates of each position
# ('iP' is renamed to 'P' so both tables share the join key).
positions_first_join = pd.merge(positions_tr_all,
                                positions_first.rename(columns={'iP': 'P'}),
                                how = 'inner', on = 'P')

# Tile layout from stage coordinates only; stored as <name>_TRLayout.csv so
# later sections can reload it in a fresh session.
positions_join = TileRegion.perform_stitching_stage(tr_init, positions_first_join, px_size)
positions_join.to_csv(path_export_stitch + "/" + name + "_TRLayout.csv", sep = ";", decimal = ".", index = False)

In [None]:
# plot tileregion result
# Visual check of the stage-coordinate layout computed above.
TileRegion.plot_tileregion_layout(positions_join, positions_first, tr_init)

## B. Compute optimal tile position by MIST

In [None]:
# temp export tile images to tiff
# Uses the stitch channel and the stitch time frame chosen above; the
# flatfield flag and reference images are passed along to the helper.
TileRegion.export_tile_img_for_stitching(data_img, name, path_export_tile, sizes,
                                         tr_init, positions_tr_all, channelsOI_stitch, t_stitch,
                                         flatfield, img_ff, img_df)

In [None]:
# imagej init
# Start ImageJ from the local Fiji installation (legacy layer enabled).
ij_mist = imagej.init(path_imagej_mist, add_legacy = True)
TileRegion.perform_stitching_mist(ij_mist, name, path_export_tile, tr_init, positions_tr_all, overlap_percent, px_size)
# release memory from virtual java/imagej instance!
# Dispose the instance created in this cell; the TrackMate instance
# (ij_trac) does not exist yet at this point of the notebook.
ij_mist.dispose()

In [None]:
# retrieve MIST result(s)
# Overwrites positions_join (the stage-coordinate layout) with the MIST layout.
positions_join = TileRegion.parse_mist_result(name, path_export_tile, path_export_stitch, tr_init, positions_tr_all)

In [None]:
# plot tileregion result
# Visual check of the layout retrieved from MIST.
TileRegion.plot_tileregion_layout(positions_join, positions_first, tr_init)

## Create stitch image from tiles

In [None]:
# load tile positions table (only when it is not in the current session)
# locals() has to be called: testing membership on the function object
# itself raises a TypeError.
if 'positions_join' not in locals():
    positions_join = pd.read_csv(path_export_stitch + "/" + name + "_TRLayout.csv", sep = ";", decimal = ".")
print(positions_join.tail())

In [None]:
# Build the stitched image(s): the result holds one array per tile region.
# Flatfield settings are passed along; blending and verbose output are off.
stitched_images = TileRegion.stitch_tile_images(data_img, tr_init, sizes, positions_join, flatfield, img_ff, img_df, verbose = False, blend = False)

In [None]:
# Write one .npy file per tile region (np.save appends the extension);
# the running index becomes the TR number in the file name.
for i, arr in enumerate(stitched_images):
    np.save(f"{path_export_stitch}{name}_TR{i}_img_stitched", arr)

# Segmentation (single cell)

In [None]:
# load calculated TR layout (if not in current session)
# locals() has to be called: testing membership on the function object
# itself raises a TypeError.
if 'positions_join' not in locals():
    positions_join = pd.read_csv(path_export_stitch + "/" + name + "_TRLayout.csv", sep = ";", decimal = ".")
    # rebuild the tile-region list from the stored layout
    tr_init = positions_join['TR'].unique()
print(positions_join.tail())

In [None]:
# load stitched image data (if not in current session)
# locals() has to be called: testing membership on the function object
# itself raises a TypeError.
if 'stitched_images' not in locals():
    stitched_images = []
    for tr in tqdm(range(len(tr_init))):
        # memory-mapped read-only: data is read from disk on access
        stitched_images.append(np.load(path_export_stitch + name + "_TR" + str(tr) + "_img_stitched.npy", mmap_mode='r'))

In [None]:
# exec nuclear segmentation
cellpose_model_nucl = SegmentationCellpose.load_model_cellpose(name = 'nuclei_denoise')

# Run the mask function over every tile region using the nuclear channel;
# the result is iterated below as one mask array per tile region.
results_cellpose_nucl = LoopDimensions.loop_tileregion_np(
    img = stitched_images,
    func = SegmentationCellpose.create_mask,
    tr_init = tr_init,
    channels_oi = [channelsOI_cellpose_nucl],
    sizes = sizes,
    cellpose_model = cellpose_model_nucl,
    resample = False)

# Store every mask twice: .npy for this notebook, 16-bit tiff for TrackMate
for i, arr in enumerate(results_cellpose_nucl):
    np.save(path_analysis_segmentation + name + "_TR" + str(i) + "_mask_nucl", arr)
    # residual save duplicate to tiff for TrackMate import in ImageJ
    imwrite(path_analysis_segmentation + name + "_TR" + str(i) + "_mask_nucl.tif", data = arr.astype('uint16'))

In [None]:
# exec cellular/cytoplasm segmentation on nuclear and cyto channel
# Skipped entirely when no cell channel was selected (None).
if channelsOI_cellpose_cell is not None:
    cellpose_model_cell = SegmentationCellpose.load_model_cellpose(name = 'cyto3_denoise')

    # channels_oi lists the cell channel first and the nuclear channel second.
    # NOTE(review): channels = [1, 2] and diam = 180 are hard-coded here;
    # presumably the Cellpose channel order and object diameter in pixels -
    # confirm against SegmentationCellpose.create_mask.
    results_cellpose_cell = LoopDimensions.loop_tileregion_np(
        img = stitched_images,
        func = SegmentationCellpose.create_mask,
        tr_init = tr_init,
        channels_oi = [channelsOI_cellpose_cell, channelsOI_cellpose_nucl],
        sizes = sizes,
        channels = [1, 2],
        diam = 180,
        cellpose_model = cellpose_model_cell,
        resample = False)

    # one .npy file per tile region
    for i, arr in enumerate(results_cellpose_cell):
        np.save(path_analysis_segmentation + name + "_TR" + str(i) + "_mask_cell", arr)

# Tracking (nuclear)

## A. TrackMate
! You will probably have to restart the kernel here: pyimagej has a bug where multiple ImageJ environments in one session conflict.

In [None]:
# IJ init (might error first exec)
# Second ImageJ instance, built from pinned Maven coordinates instead of the
# local Fiji folder; used only for TrackMate.
ij_trac = imagej.init([
    'net.imagej:imagej:2.5.0',
    'sc.fiji:TrackMate:7.9.2',
    'sc.fiji:Feature_Detection:2.0.3',
    'ome:bioformats_package:6.11.0'])

In [None]:
# Tracking
# Read the TrackMate script once; the context manager closes the file handle.
with open(path_script_Trackmate) as script_file:
    script = script_file.read()
# Run the script on every file in the mask folder whose name ends in "tif"
for file in tqdm([f for f in os.listdir(path_analysis_segmentation) if str(f).endswith("tif")]):
    ij_trac.py.run_script("python", script, {"path_mask": path_analysis_segmentation + file, "path_out": path_analysis_tracks, "batchmode": True})
# release memory from virtual java/imagej instance!
ij_trac.dispose()

## B. TrackAstra

In [None]:
import torch
# use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# trackastra
from trackastra.utils import normalize
from trackastra.model import Trackastra
from trackastra.tracking import graph_to_napari_tracks, graph_to_ctc

In [None]:
# load calculated TR layout (if not in current session)
if 'positions_join' not in locals():
    positions_join = pd.read_csv(path_export_stitch + "/" + name + "_TRLayout.csv", sep = ";", decimal = ".")
    # the tile-region column is named 'TR' (upper case) in the layout table,
    # as in every other reload cell of this notebook
    tr_init = positions_join['TR'].unique()

In [None]:
# load stitched image data (if not in current session)
if 'stitched_images' not in locals():
    stitched_images = []
    for tr in tqdm(range(len(tr_init))):
        # memory-mapped read-only: data is read from disk on access
        stitched_images.append(np.load(path_export_stitch + name + "_TR" + str(tr) + "_img_stitched.npy", mmap_mode='r'))

In [None]:
# load nuclear segmentation masks (if not in current session)
if 'results_cellpose_nucl' not in locals():
    results_cellpose_nucl = []
    for tr in tqdm(range(len(tr_init))):
        # one mask file per tile region, loaded fully into memory
        results_cellpose_nucl.append(np.load(path_analysis_segmentation + name + "_TR" + str(tr) + "_mask_nucl.npy"))

In [None]:
# Track nuclei with TrackAstra, one tile region at a time.
# Load a pretrained model once: it does not depend on the tile region, so
# reloading it in every iteration is wasted work.
model = Trackastra.from_pretrained("general_2d", device=device)

for tr in tqdm(range(len(tr_init))):
    # nuclear channel of the stitched image and the matching label masks
    img = stitched_images[tr][:, 0, channelsOI_cellpose_nucl]
    labels = results_cellpose_nucl[tr].squeeze()

    # Normalize your images (frame by frame)
    img_norm = np.stack([normalize(x) for x in img])

    # Track the cells
    track_graph = model.track(img_norm, labels, mode="greedy")  # or mode="ilp", or "greedy_nodiv"

    # Convert track graph to tables (Visualise in napari)
    napari_tracks, napari_tracks_graph, _ = graph_to_napari_tracks(track_graph)

    # fetch object id from mask for tracks df
    maskid_df = []
    for frame, mask in enumerate(labels):
        # rows of the track table that belong to this frame
        subset = napari_tracks[:, 1] == frame
        coords = napari_tracks[subset, 2:4]
        indexes = np.where(subset)
        ids = LinkPointToObject.PointToMaskID(coords, mask)
        # keep the row number so the ids can be joined back onto the tracks
        df = pd.DataFrame({'mask_id': ids, 'index': indexes[0]}, columns=['mask_id', 'index'])
        maskid_df.append(df)

    maskid_df = pd.concat(maskid_df)
    maskid_df.set_index('index', inplace = True)
    track_df = pd.DataFrame(napari_tracks, columns=['track_id', 'frame', 'y', 'x']).join(maskid_df)

    # write to file
    track_df.to_csv(path_analysis_tracks + name + "_TR" + str(tr) + "_tracks_TrackAstra.csv", index = False)

    print(track_df.head())

# Spot detection
! Executed on single tile images

In [None]:
import SpotSpotiflow
# reload so that edits to the module are picked up without a kernel restart
importlib.reload(SpotSpotiflow)

name_model_spotiflow = 'general' # published model
model_spotiflow = SpotSpotiflow.load_model(name_model_spotiflow)
# Spot detection runs on the individual tile images (data_img), not on the
# stitched image; the flatfield settings are handed over as a dict.
results_spotiflow = LoopDimensions.loop_tiles_pd(data_img = data_img, sizes = sizes,
                                                 func = SpotSpotiflow.spot_detection_pd,
                                                 model = model_spotiflow,
                                                 channels_oi = channelsOI_spots,
                                                 dict_flatfield = {'flatfield': flatfield,
                                                                   'flatfield_func': FlatDarkField.ffdf,
                                                                   'img_ff': img_ff, 'img_df': img_df
                                                 })

# the model name is part of the file name
results_spotiflow.to_csv(path_or_buf = path_analysis_spots + name + "_all_spots_spotiflow_" + name_model_spotiflow + ".csv", sep = ";", decimal = ".", index = False)

# Cell representation tile overlap

In [None]:
# load calculated TR layout (if not in current session)
if 'positions_join' not in locals():
    positions_join = pd.read_csv(path_export_stitch + "/" + name + "_TRLayout.csv", sep = ";", decimal = ".")
    # rebuild the tile-region list from the stored layout
    tr_init = positions_join['TR'].unique()
print(positions_join.tail())

In [None]:
# load segmentation maps (if not in current session)
# locals() has to be called: testing membership on the function object
# itself raises a TypeError.
if 'results_cellpose_nucl' not in locals():
    results_cellpose_nucl = []
    # files are numbered by the running tile-region index
    for i, arr in enumerate(tr_init):
        results_cellpose_nucl.append(np.load(path_analysis_segmentation + name + "_TR" + str(i) + "_mask_nucl.npy"))
        print(results_cellpose_nucl[i].shape)

In [None]:
# identify in which tile each cell is represented most completely
object_overlap_max = []
for TR, TR_start in enumerate(tqdm(tr_init)):
    results_cellpose_TR = results_cellpose_nucl[TR].squeeze()
    positions_join_TR = positions_join[positions_join['TR'] == TR]
    #print(results_cellpose_TR.shape)
    for i_slice, mask_slice in enumerate(tqdm(results_cellpose_TR)):
        # Calculate the area of each object in mask of stitched image
        # NOTE(review): the code below assumes labels run 1..n_obj without
        # gaps (props_TR then has exactly n_obj rows) - confirm for the masks.
        n_obj = mask_slice.max()
        props_TR = measure.regionprops_table(mask_slice, properties = ['label', 'area'])

        # initiate empty array for storing values (rows: objects, cols: tiles)
        # uint32 instead of uint16: pixel areas above 65535 would wrap around
        # and corrupt both the argmax and the overlap fraction
        object_overlap_current = np.zeros(shape = (n_obj, positions_join_TR.shape[0]), dtype = np.uint32)

        # Calculate area for each object in mask of each tile (crop from stitch)
        for index, row in positions_join_TR.iterrows():
            x_px_start = row['x_px']
            y_px_start = row['y_px']

            # crop mask
            mask_slice_tile = mask_slice[y_px_start : y_px_start + sizes['Y'], x_px_start : x_px_start + sizes['X']]

            # calculate area of each object in tile
            props_tile = measure.regionprops_table(mask_slice_tile, properties = ['label', 'area'])
            # label k is stored in row k-1; the column is the tile's 'index' value
            object_overlap_current[props_tile['label']-1, row['index']] = props_tile['area'].astype(np.uint32)

        # over all FOVs (stored in cols) determine where area representation for each object is largest
        object_overlap_current_max = object_overlap_current.argmax(axis = 1)
        # calculate fraction of overlap relative to stitched mask
        object_overlap_current_max_fraction = np.amax(object_overlap_current, axis = 1) / props_TR['area']

        # translate TR index to P in whole dataset
        object_overlap_current_max_join = pd.merge(positions_join_TR, pd.DataFrame({'index': object_overlap_current_max, 'overlap': object_overlap_current_max_fraction, 'IDCell': range(1, n_obj+1)}), on = 'index')

        # combine and append data
        object_overlap_current_max_df = pd.DataFrame({
            'IDCell': object_overlap_current_max_join['IDCell'],
            'position_max': object_overlap_current_max_join['P'],
            'T': i_slice,
            'TR': TR,
            'overlap': object_overlap_current_max_join['overlap'],
        })
        object_overlap_max.append(object_overlap_current_max_df)
object_overlap_max = pd.concat(object_overlap_max, axis = 0)

object_overlap_max.to_csv(path_or_buf = path_analysis_segmentation + name + "_mask_tile_overlap.csv", sep = ";", decimal = ".", index = False)

In [None]:
# preview result
# One row per object, time frame and tile region, with the position holding
# the largest part of the object and the corresponding overlap fraction.
print(object_overlap_max)

In [None]:
# plot, here you expect mostly random/even distribution
# Histogram of the position that represents each object best.
plt.hist(object_overlap_max['position_max'], bins = 100)
plt.show()

In [None]:
# plot, here you expect right skewed data (cells are mostly found with 100% overlap in discrete tiles versus stitch)
# Log-scaled counts so the few partially covered objects stay visible.
plt.hist(object_overlap_max['overlap'], bins = 100)
plt.yscale('log')
plt.show()

# Spot association & filtering
Bind spots to segmented objects (nuclei) & optimal object overlap with tiles

In [None]:
# load calculated TR layout (if not in current session)
if 'positions_join' not in locals():
    # path_export_stitch already ends with a slash, so no separator is added
    positions_join = pd.read_csv(path_export_stitch + name + "_TRLayout.csv", sep = ";", decimal = ".")
    # rebuild the tile-region list from the stored layout
    tr_init = positions_join['TR'].unique()

In [None]:
# load segmentation data (if not in current session)
if 'results_cellpose_nucl' not in locals():
    results_cellpose_nucl = []
    # files are numbered by the running tile-region index
    for i, arr in enumerate(tr_init):
        results_cellpose_nucl.append(np.load(path_analysis_segmentation + name + "_TR" + str(i) + "_mask_nucl.npy"))
        print(results_cellpose_nucl[i].shape)

In [None]:
# load spots data (if not in current session)
# Fall back to the published 'general' model name when the detection cell
# was not run in this session.
if 'name_model_spotiflow' not in locals():
    name_model_spotiflow = 'general'
# Always define the source tag: the export cells further down use it in
# their file names, also when the spots are still in memory.
name_source_spots = 'spotiflow_' + name_model_spotiflow
if 'results_spotiflow' not in locals():
    # same file name as written by the spot detection cell
    results_spotiflow = pd.read_csv(path_analysis_spots + name + "_all_spots_" + name_source_spots + ".csv", sep = ";", decimal = ".")
results_spots = results_spotiflow

In [None]:
# load object tile overlap data (if not in current session)
if 'object_overlap_max' not in locals():
    # file written by the "Cell representation tile overlap" section
    object_overlap_max = pd.read_csv(path_analysis_segmentation + name + "_mask_tile_overlap.csv", sep = ";", decimal = ".")
print(object_overlap_max.tail())

In [None]:
# pad spot XY by tile coordinates by tile position
# Spots were detected per tile; adding the tile's pixel offset moves them
# into the coordinate frame of the stitched image.
results_spots_pad = results_spots.merge(positions_join, on = 'P')
results_spots_pad['X_padded'] = results_spots_pad['X'] + results_spots_pad['x_px']
results_spots_pad['Y_padded'] = results_spots_pad['Y'] + results_spots_pad['y_px']
# integer index columns are needed for the comparisons and merges below
results_spots_pad['T'] = results_spots_pad['T'].astype(int)
results_spots_pad['C'] = results_spots_pad['C'].astype(int)
results_spots_pad['P'] = results_spots_pad['P'].astype(int)
# last expression of the cell: displayed as notebook output
results_spots_pad.tail()

In [None]:
# find segmentation mask object ID for all spots
maskIDs = []

for TR, _ in enumerate(tqdm(tr_init)):
    # for each TR
    # (loop variable is named i_slice so the builtin `slice` is not shadowed
    # in the notebook namespace)
    for i_slice, mask in enumerate(tqdm(results_cellpose_nucl[TR].squeeze())):
        # for each timeslice
        spots = results_spots_pad[(results_spots_pad['T'] == i_slice) & (results_spots_pad['TR'] == TR)][['Y_padded','X_padded']]
        maskID = LinkPointToObject.PointToMaskID(spots.values, mask)
        # keep the dataframe index next to each ID for the join below
        maskID = np.column_stack((maskID, spots.index))
        maskIDs.append(maskID)
maskIDs = pd.DataFrame(np.concatenate(maskIDs), columns = ['ID', 'index']).set_index('index')

# bind mask id to spot dataframe
results_spots_assigned = results_spots_pad.join(maskIDs)

results_spots_assigned.to_csv(path_or_buf = path_analysis_spots + name + "_all_spots_assigned_" + name_source_spots + ".csv", sep=";", decimal = ".", index = False)

In [None]:
# preview: padded spot table with the mask object ID ('ID') joined on
print(results_spots_assigned.head())

In [None]:
# how many spots are now not assigned to cells?
# Counts ID == 0 (True) versus any other ID (False).
# NOTE(review): this treats 0 as "no object"; presumably the background
# label of the mask - confirm against LinkPointToObject.PointToMaskID.
spots_unmatched_n = (results_spots_assigned['ID'] == 0).value_counts()
print(spots_unmatched_n)
print(results_spots_assigned['P'].unique())

In [None]:
# join with max object tile overlap
# The inner join keeps a spot only when it was detected in the position that
# represents its object best (P == position_max) at that time point.
print(results_spots_assigned.shape)
results_spots_assigned_filter = results_spots_assigned.merge(object_overlap_max, left_on = ['T', 'P', 'ID', 'TR'], right_on = ['T', 'position_max', 'IDCell', 'TR'], how = 'inner')
results_spots_assigned_filter = results_spots_assigned_filter.drop(['position_max', 'overlap'], axis = 1)
# shape after filtering, for comparison with the shape printed above
print(results_spots_assigned_filter.shape)

results_spots_assigned_filter.to_csv(path_or_buf = path_analysis_spots + name + "_all_spots_assigned_filter_" + name_source_spots + ".csv", sep = ";", decimal = ".", index = False)

In [None]:
# preview of the filtered spot table and the positions that still occur in it
print(results_spots_assigned_filter.head())
print(results_spots_assigned_filter['P'].unique())

# Feature extraction single-cell (morphology & intensity)
! Executed on single tile images

In [None]:
importlib.reload(FeatureExtraction)
# axis names in dataset order, used to find the T and P entry of a block index
keys = list(sizes)

# Optional extra features (per-object percentile subtraction). The toggle is
# not set anywhere else in this notebook, so default to off when undefined.
if 'features_extra_percentile_subtract' not in locals():
    features_extra_percentile_subtract = False

results_features = []
for inds in product(*map(range, data_img.blocks.shape)):
    # fetch current position in dataset
    i_T = inds[keys.index('T')]
    i_P = inds[keys.index('P')]
    XY_px = positions_join.loc[positions_join['P'] == i_P, ['y_px', 'x_px']].values.squeeze()
    TR = positions_join.loc[positions_join['P'] == i_P, 'TR'].values.squeeze() #0

    chunk = data_img.blocks[inds].squeeze()
    # crop mask to tile
    mask_TR = results_cellpose_nucl[TR].squeeze()
    mask = mask_TR[i_T, XY_px[0]: XY_px[0] + sizes['Y'], XY_px[1]: XY_px[1] + sizes['X']]

    # NOTE(review): img is only (re)assigned for 3-dimensional chunks; for any
    # other chunk the image of the previous iteration would be reused.
    if chunk.ndim == 3:
        img = chunk.compute().squeeze()
        if flatfield:
            # convert once, before the per-channel correction
            img = img.astype(np.float32)
            for C in range(img.shape[0]):
                img[C] = FlatDarkField.ffdf(img[C], img_ff[C].squeeze(), img_df)
        # channel axis last for the feature extraction
        img = np.transpose(img, axes=(1, 2, 0))

    if features_extra_percentile_subtract:
        # NOTE(review): `im` is not imported in this notebook; import the
        # module that provides img_object_percentile_subtract before enabling.
        img_mod = np.zeros_like(img)
        for iC, name_channel in enumerate(channels):
            img_mod[:,:,iC] = im.img_object_percentile_subtract(img[:,:,iC], mask, percent=20)
        # modified channels are appended behind the original ones
        img = np.concatenate((img, img_mod), axis = -1)

    data_features = FeatureExtraction.extract_intensity_features_img(img, mask)

    # TODO bind channel names (and mod name) to colnames, instead of -0/-1 suffixes
    # print(data_features.columns)

    data_features['P'] = i_P
    data_features['T'] = i_T
    data_features['TR'] = TR
    results_features.append(data_features)
results_features = pd.concat(results_features)

# store to file
results_features.to_csv(path_features + name + "_features_cell.csv", index = False)
results_features.head()

In [None]:
# join with max object tile overlap
# The inner join keeps an object's features only from the position that
# represents it best (P == position_max); the 'overlap' column is kept.
results_features_filter = results_features.merge(object_overlap_max, left_on = ['T', 'P', 'label', 'TR'], right_on = ['T', 'position_max', 'IDCell', 'TR'], how='inner')
results_features_filter = results_features_filter.drop(['position_max'], axis = 1)

results_features_filter.to_csv(path_features + name + "_features_cell_filter.csv", index = False)
# last expression of the cell: displayed as notebook output
results_features_filter.head()

# Visualization (Napari)

## Load

In [None]:
# pixel size (if not in current session); used as layer scale in napari
# locals() has to be called: testing membership on the function object
# itself raises a TypeError.
if 'px_size' not in locals():
    px_size = MetadataParserND2.get_nd2_pxsize(path_nd2)

In [None]:
# load calculated TR layout (if not in current session)
if 'positions_join' not in locals:
    positions_join = pd.read_csv(path_export_stitch + "/" + name + "_TRLayout.csv", sep = ";", decimal = ".")
    tr_init = positions_join['TR'].unique()
print(positions_join.tail())

In [None]:
# load stitched image data (if not in current session)
if 'stitched_images' not in locals:
    stitched_images = []
    for tr in tqdm(range(len(tr_init))):
        stitched_images.append(np.load(path_export_stitch + name + "_TR" + str(tr) + "_img_stitched.npy")) #, mmap_mode='r'

In [None]:
# load segmentation data (if not in current session)
if 'results_cellpose_nucl' not in locals():
    results_cellpose_nucl = []
    for i, arr in enumerate(tr_init):
        results_cellpose_nucl.append(np.load(path_analysis_segmentation + name + "_TR" + str(i) + "_mask_nucl.npy"))
        print(results_cellpose_nucl[i].shape)

## Init

In [None]:
# open the napari window; all layers below are added to this viewer
viewer = napari.Viewer()

In [None]:
# Viewer cosmetics: axis names and a scale bar in micrometres.
# NOTE(review): three axis labels are set here while the point layers below
# are created with ndim = 4 - confirm the labels match the displayed data.
viewer.dims.axis_labels = ['T', 'Y', 'X']
viewer.scale_bar.visible = True
viewer.scale_bar.font_size = 20
viewer.scale_bar.unit = "um"

## Image stitch

In [None]:
# which tile-region image to take from the set?
tr = 0
# fetch
stitched_image = stitched_images[tr]
stitched_image.shape

In [None]:
# load stitch to napari
# Show the stitched image of the selected tile region: masks and padded
# spots live in stitched coordinates, the raw tile stack (data_img) does not.
# NOTE(review): channel_axis = 3 assumes the channel is the 4th axis of the
# stitched array - confirm against stitched_image.shape printed above.
viewer.add_image(stitched_image,
                 channel_axis = 3,
                 name = channels,
                 blending = 'additive',
                 gamma = 1,
                 scale = (px_size, px_size))

## Segmentation

In [None]:
# fetch segmentation map
result_cellpose_nucl = results_cellpose_nucl[tr]
print(result_cellpose_nucl.shape)

In [None]:
# view segmentation map
viewer.add_labels(result_cellpose_nucl,
                  scale = (px_size, px_size))

## Tracks

In [None]:
# load tracking data, highlight few tracks
# Tracks of the tile region selected with `tr` (the previously used name
# n_img is not defined anywhere in this notebook).
path_overnight_tracking_all = path_analysis_tracks + name + "_TR" + str(tr) + "_mask_nucl_all.csv"

data_tracks_csv = pd.read_csv(path_overnight_tracking_all, delimiter = ",")
# napari tracks layout: track id, time, then spatial coordinates
data_tracks_numpy = data_tracks_csv.loc[:, ['IDTrack', 't', 'y', 'x']].to_numpy()

# insert a constant 0 column after the time column so the tracks have the
# same number of dimensions as the point layers (T, 0, Y, X)
data_tracks_numpy = np.insert(data_tracks_numpy, 2, np.array([0]), axis = 1)

In [None]:
# view tracks
viewer.add_tracks(
    data_tracks_numpy,
    name = "nuclei_tracks_all",
    colormap = 'hsv',
    tail_length = 20,
    scale = (px_size, px_size))

## Spots

### Unassigned

In [None]:
# load spots data (if not in current session)
if 'results_spotiflow' not in locals():
    # file written by the spot detection section ('general' model)
    results_spotiflow = pd.read_csv(path_analysis_spots + name + "_all_spots_spotiflow_general.csv", sep = ";", decimal = ".")
    # NOTE: only set when the spots are loaded from file
    name_source_spots = 'spotiflow_general'
results_spots = results_spotiflow

In [None]:
# pad spot XY by tile coordinates
# Spots were detected per tile; adding the tile's pixel offset moves them
# into the coordinate frame of the stitched image.
results_spots_pad = results_spots.merge(positions_join, on = 'P')
results_spots_pad['X_padded'] = results_spots_pad['X'] + results_spots_pad['x_px']
results_spots_pad['Y_padded'] = results_spots_pad['Y'] + results_spots_pad['y_px']
results_spots_pad['T'] = results_spots_pad['T'].astype(int)
results_spots_pad['C'] = results_spots_pad['C'].astype(int)
results_spots_pad['P'] = results_spots_pad['P'].astype(int)
print(results_spots_pad.tail())

In [None]:
# visualize all points
viewer.add_points(
    np.insert(results_spots_pad[['T', 'Y_padded', 'X_padded']].values, 1, np.array([0]), axis = 1),
    name = "spots_ALL",
    scale = (px_size, px_size),
    size = 15,
    face_color = '#ffffff00',
    edge_color = 'white',
    edge_width = 0.1,
    visible = True,
    ndim = 4
)

In [None]:
# visualize all points, color by detection probability
# The 'spot_prob' column is attached as layer feature 'prob' and referenced
# by name for the outline colour.
# NOTE(review): a two-colour cycle is used for what looks like a continuous
# probability; presumably a colormap was intended - confirm.
data_spots_features = {'prob': results_spots_pad['spot_prob']}
face_color_cycle = ['green', 'red']

viewer.add_points(
    np.insert(results_spots_pad[['T', 'Y_padded', 'X_padded']].values, 1, np.array([0]), axis = 1),
    name = "spots_ALL_color_prob",
    features = data_spots_features,
    scale = (px_size, px_size),
    size = 15,
    face_color = '#ffffff00',
    edge_color = 'prob',
    edge_color_cycle = face_color_cycle,
    edge_width = 0.1,
    visible = True,
    ndim = 4
    )


### Assigned & Filtered

In [None]:
# Always reloads the filtered spot table from file (hard-coded to the
# 'spotiflow_general' source), even when it is still in memory.
results_spots_assigned_filter = pd.read_csv(path_analysis_spots + name + "_all_spots_assigned_filter_spotiflow_general.csv", sep = ";", decimal = ".")

In [None]:
# visualize nucleus assigned and deduplicated spots
# Only spots of the tile region selected with `tr` are shown (the previously
# used name n_img is not defined anywhere in this notebook). A constant 0
# column is inserted after T so every point is (T, 0, Y, X).
# The layer gets its own name instead of reusing "spots_ALL", which is
# already taken by the unfiltered spots layer.
viewer.add_points(
    np.insert(results_spots_assigned_filter[results_spots_assigned_filter['TR'] == tr][['T', 'Y_padded', 'X_padded']].values, 1, np.array([0]), axis = 1),
    name = "spots_assigned_filtered",
    scale = (px_size, px_size),
    size = 15,
    opacity=.2,
    face_color = '#ffffff00',
    border_color = 'white',
    border_width = 0.1,
    visible = True,
    ndim = 4
)