#### Imports

In [None]:
import os
import re
import sys
import numpy as np
import matplotlib.pyplot as plt

import networkx as nx

In [None]:
sys.path.append('..')

In [None]:
from src.utils.image_utils import load_czi_images, enhance_cell_image_contrast
from src.utils.plot_utils import show_3d_segmentation_overlay
from src.track import *

#### Functions

#### Inputs

In [None]:
czi_file_path = '/home/dafei/data/MS2/gRNA2_12.03.25-st-13-II---.czi'
seg_maps_dir = '/home/dafei/output/MS2/3d_cell_segmentation/gRNA2_12.03.25-st-13-II---/masks'

In [None]:
image_data = load_czi_images(czi_file_path)

In [None]:
def extract_time_number(filename):
    """Extract time number from filename like z_stack_t5_seg_masks.npz"""
    try:
        # Find the pattern t{number}_
        match = re.search(r't(\d+)_', filename)
        return int(match.group(1)) if match else 0
    except:
        return 0

# Sort files by time number
files = sorted([os.path.join(seg_maps_dir, x) for x in os.listdir(seg_maps_dir) 
                if x.startswith('z_stack_t') and x.endswith('_seg_masks.npz')], 
               key=extract_time_number)

In [None]:
files


In [None]:
z_stack_seg_mask_t0 = np.load(files[0])['masks']
z_stack_t0 = image_data[0, 0, 1, :, :, :, 0]

In [None]:
show_3d_segmentation_overlay(z_stack_t0, z_stack_seg_mask_t0)

In [None]:
z_stack_seg_mask_t1 = np.load(files[1])['masks']
z_stack_t1 = image_data[0, 1, 1, :, :, :, 0]

In [None]:
show_3d_segmentation_overlay(z_stack_t1, z_stack_seg_mask_t1)

In [None]:
z_stack_seg_mask_t2 = np.load(files[2])['masks']
z_stack_t2 = image_data[0, 2, 1, :, :, :, 0]

In [None]:
show_3d_segmentation_overlay(z_stack_t2, z_stack_seg_mask_t2)

Tracking cells using adjacency graph

In [None]:
# Compute cell locations for the first frame
centers1 = get_cell_centers(z_stack_seg_mask_t0)
labels1 = np.unique(z_stack_seg_mask_t0)
g1 = compute_cell_location(centers=centers1, labels=labels1)

# Compute cell locations for the second frame
centers2 = get_cell_centers(z_stack_seg_mask_t1)
labels2 = np.unique(z_stack_seg_mask_t1)
g2 = compute_cell_location(centers=centers2, labels=labels2)

# Compute cell locations for the third frame
centers3 = get_cell_centers(z_stack_seg_mask_t2)
labels3 = np.unique(z_stack_seg_mask_t2)
g3 = compute_cell_location(centers=centers3, labels=labels3)

In [None]:
matches_1_t0_0 = match_points_between_frames(g1, g2, z_stack_seg_mask_t0, z_stack_seg_mask_t1,distance_threshold=np.sqrt(3))
matches_2_t1_1 = match_points_between_frames(g2, g3, z_stack_seg_mask_t1, z_stack_seg_mask_t2,distance_threshold=np.sqrt(3))

In [None]:
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')

for label_t_plus1 in range(1,449):
    label_t = matches_1_t0_0[label_t_plus1] if label_t_plus1 in matches_1_t0_0 else None
    if label_t is None:
        continue  # Skip if no match found
    # Get positions for the current label in both frames
    row_mask_t1 =centers2[:, 0] == label_t_plus1
    matching_rows1 = centers2[row_mask_t1]
    pos1 = matching_rows1[0][1:]

    row_mask_t2 = centers1[:, 0] == label_t
    matching_rows2 = centers1[row_mask_t2]
    pos2 = matching_rows2[0][1:]
    
    ax.scatter(*pos1, color='blue', label=f'Cell {label_t_plus1} in Frame 2',s=4)
    ax.scatter(*pos2, color='red', label=f'Cell {label_t} in Frame 1' if pos2 is not None else 'No match', s=4)
    ax.plot([pos1[0], pos2[0]], [pos1[1], pos2[1]], [pos1[2], pos2[2]], 'black', linewidth=1, linestyle='--' if pos2 is not None else 'None')
    ax.set_title('Cell Tracking Between Frames')
    ax.set_xlim([0, z_stack_seg_mask_t1.shape[1]])
    ax.set_ylim([0, z_stack_seg_mask_t1.shape[2]])
    ax.set_zlim([0, z_stack_seg_mask_t1.shape[0]])
#ax.legend()
ax.grid(True)
plt.show()

In [None]:
label_t_plus1 = 1
label_t = matches_1_t0_0[label_t_plus1] if label_t_plus1 in matches_1_t0_0 else None

cell_mask_t_plus1 = (z_stack_seg_mask_t1 == label_t_plus1).astype(np.uint8)
cell_mask_t = (z_stack_seg_mask_t0 == label_t).astype(np.uint8)

show_3d_segmentation_overlay(z_stack_t1, cell_mask_t_plus1)
show_3d_segmentation_overlay(z_stack_t0, cell_mask_t)

In [69]:
#init tracklets
tracklets = {}
for i,(label_t_plus1, label_t) in enumerate(matches_1_t0_0.items()):
    tracklets[i] = [int(label_t),int(label_t_plus1)]

In [72]:
def find_key_for_last_tracklet_value_optimized(tracklets, matches_dict, tracklet_id):
    """
    Optimized version using next() with generator expression for early termination.
    """
    last_value = tracklets[tracklet_id][-1]
    
    # Use next() with generator for immediate return on first match
    return next((key for key, value in matches_dict.items() if value == last_value), -1)

In [73]:
max_id = max(tracklets.keys())
new_tracklets = {}

# First, extend existing tracklets
for id, labels in tracklets.items():
    key = find_key_for_last_tracklet_value_optimized(tracklets, matches_2_t1_1, id)
    if key != -1:  # Only append if a match was found
        labels.append(int(matches_2_t1_1[key]))

# Then, create new tracklets for unmatched cells in matches_2_t1_1
used_keys = set()
for id, labels in tracklets.items():
    if len(labels) >= 2:  # Check if tracklet was extended
        used_keys.add(labels[-2])  # Add the second-to-last value (the key that was found)

# Create new tracklets for unused matches
for key, value in matches_2_t1_1.items():
    if key not in used_keys:
        max_id += 1
        new_tracklets[max_id] = [-1, int(key), int(value)]

# Add new tracklets
tracklets.update(new_tracklets)

In [74]:
len(tracklets)

426

In [75]:
tracklets

{0: [146, 1, 1],
 1: [1, 2, 2],
 2: [148, 3, 3],
 3: [2, 4, 4],
 4: [3, 5, 5],
 5: [4, 6, 6],
 6: [5, 7, 7],
 7: [6, 8, 8],
 8: [7, 9, 9],
 9: [154, 10, 10],
 10: [8, 11, 11],
 11: [9, 13, 13],
 12: [11, 14, 14],
 13: [12, 17, 17],
 14: [13, 19, 19],
 15: [14, 20, 20],
 16: [15, 22, 22],
 17: [16, 23, 23],
 18: [17, 24, 24],
 19: [19, 25, 25],
 20: [18, 26, 26],
 21: [20, 27, 27],
 22: [22, 28, 28],
 23: [21, 29, 29],
 24: [24, 30, 30],
 25: [23, 31, 31],
 26: [25, 32, 32],
 27: [26, 33, 33],
 28: [27, 34, 34],
 29: [28, 35, 35],
 30: [30, 37, 37],
 31: [31, 38, 38],
 32: [32, 40, 40],
 33: [34, 41, 41],
 34: [35, 42, 42],
 35: [36, 43, 43],
 36: [37, 44, 44],
 37: [40, 46, 46],
 38: [41, 47, 47],
 39: [39, 48, 48],
 40: [43, 50, 50],
 41: [44, 51, 51],
 42: [45, 52, 52],
 43: [46, 53],
 44: [47, 54, 54],
 45: [48, 55, 55],
 46: [49, 56, 56],
 47: [50, 57, 57],
 48: [51, 58, 58],
 49: [53, 59, 59],
 50: [54, 60, 60],
 51: [55, 61, 61],
 52: [56, 62],
 53: [57, 63, 63],
 54: [59, 65, 65