Skip to content

Commit

Permalink
feat: merged with tapir_contours
Browse files Browse the repository at this point in the history
  • Loading branch information
ElpadoCan committed Jul 2, 2023
2 parents f25706c + a2e9f6d commit cd8035c
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 80 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ cellacdc/scripts/test1.py
cellacdc/scripts/correct_shift_X_old.py
cellacdc/scripts/correct_shift_X_old2.py
cellacdc/scripts/correct_shift_X_single_old.py
cellacdc/_tests
cellacdc/_version.py
cellacdc/.qt_for_python
cellacdc/metrics/*
Expand All @@ -50,7 +51,7 @@ cellacdc/_test_all_icons.py
cellacdc/test1.py
cellacdc/test_download_model.py
cellacdc/test_segm.npy
cellacdc/_profile/\spline_to_obj/regression.ipynb
cellacdc/_profile/spline_to_obj/regression.ipynb
cellacdc/deprecated
cellacdc/bioformats/jars/.old_bioformats_package.jar
cellacdc/java
Expand Down
3 changes: 3 additions & 0 deletions cellacdc/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def np_replace_values(arr, old_values, new_values):
return arr

def nearest_nonzero_2D(a, y, x, max_dist=None):
value = a[round(y), round(x)]
if value > 0:
return value
r, c = np.nonzero(a)
dist = ((r - y)**2 + (c - x)**2)
if max_dist is not None:
Expand Down
184 changes: 105 additions & 79 deletions cellacdc/trackers/TAPIR/TAPIR_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from cellacdc import printl
from cellacdc.transformation import resize_lab
from cellacdc.core import nearest_nonzero_2D
from cellacdc.core import nearest_nonzero_2D, get_obj_contours

from ..CellACDC import CellACDC_tracker

Expand All @@ -24,6 +24,12 @@
class SizesToResize:
values = np.arange(256, 1025, 128)

class TrackingInputs:
values = ['Intensity image', 'Segmented objects']

class PointsToTrack:
values = ['Centroids', 'Contours']

class tracker:
def __init__(
self, model_checkpoint_path: os.PathLike=TAPIR_CHECKPOINT_PATH
Expand All @@ -42,7 +48,9 @@ def track(
max_distance=5, save_napari_tracks=False,
use_visibile_information=True, export_to=None,
signals=None, export_to_extension='.csv',
track_segmented_objects=False
tracking_input: TrackingInputs='Intensity image',
which_points_to_track: PointsToTrack='Centroids',
number_of_points_per_object: int=8
):

if video_grayscale.ndim == 4:
Expand All @@ -51,6 +59,7 @@ def track(
raise TypeError(msg)

self._use_visibile_information = use_visibile_information
self._which_points_to_track = which_points_to_track
self.segm_video = segm_video
self.max_dist = max_distance
num_frames = len(video_grayscale)
Expand All @@ -75,12 +84,14 @@ def track(

frames_rgb = self._get_frames_to_track(
reversed_resized_frames, reversed_resized_segm,
track_segmented_objects
tracking_input
)
query_points = self._initialize_query_points(
reversed_resized_segm, track_segmented_objects
query_points, tracks_start_frames = self._initialize_query_points(
reversed_resized_segm, tracking_input, which_points_to_track,
number_of_points_per_object
)

self.tracks_start_frames = tracks_start_frames

# import matplotlib.pyplot as plt
# plt.imshow(frames_rgb[0])
# plt.plot(query_points[:,2], query_points[:,1], 'r.')
Expand All @@ -91,9 +102,7 @@ def track(
self.state
)

tracked_video = self._apply_tracks(
self.reversed_tracks, self.reversed_visibles
)
tracked_video = self._apply_tracks()

if save_napari_tracks:
self._save_napari_tracks(export_to)
Expand All @@ -104,9 +113,9 @@ def track(

def _get_frames_to_track(
self, reversed_resized_frames, reversed_resized_segm,
track_segmented_objects
tracking_input
):
if track_segmented_objects:
if tracking_input == 'Segmented objects':
frames = np.zeros(reversed_resized_segm.shape, dtype=np.float32)
for frame_i, lab in enumerate(reversed_resized_segm):
rp = skimage.measure.regionprops(lab)
Expand All @@ -129,8 +138,7 @@ def _save_napari_tracks(self, export_to):
df = pd.DataFrame(data=napari_tracks, columns=['ID', 'T', 'Y', 'X'])
df.to_csv(napari_tracks_path, index=False)

def _save_tracks(self, export_to):
print('Saving tracks...')
def _build_tracks_table(self):
tracks = self.reversed_tracks[:, ::-1]
visibles = self.reversed_visibles[:, ::-1]
resized_segm = self.reversed_resized_segm[::-1]
Expand All @@ -139,13 +147,9 @@ def _save_tracks(self, export_to):
xx = []
yy = []
visibles_li = []
Y, X = resized_segm.shape[-2:]
segm_IDs = []
for tr, track in enumerate(tqdm(tracks, ncols=100)):
x, y = track[-1]
y_int, x_int = round(y), round(x)
y_int = max(0, min(y_int, Y-1))
x_int = max(0, min(x_int, X-1))
track_ID = resized_segm[-1, y_int, x_int]
track_ID = self._get_track_ID(resized_segm, track)
for frame_i, (x, y) in enumerate(track):
yc = y*self.resize_ratio_height
xc = x*self.resize_ratio_width
Expand All @@ -155,33 +159,36 @@ def _save_tracks(self, export_to):
xx.append(xc)
yy.append(yc)
visibles_li.append(visible)
segm_ID = nearest_nonzero_2D(
resized_segm[frame_i], y, x, max_dist=self.max_dist
)
segm_IDs.append(segm_ID)
df = pd.DataFrame({
'frame_i': frames,
'track_ID': track_ID,
'track_ID': segm_IDs,
'segm_ID': track_IDs,
'y_point': yy,
'x_point': xx,
'visible': visibles_li
}).set_index(['frame_i', 'track_ID']).sort_index()
df.to_csv(export_to)
return df

def _save_tracks(self, export_to):
print('Saving tracks...')
self.df_tracks.to_csv(export_to)

def to_napari_tracks(self, use_centroids=False):
print('Building napari tracks data...')
napari_tracks = []
num_frames = len(self.reversed_resized_segm)
Y, X = self.reversed_resized_segm.shape[-2:]
resized_segm = self.reversed_resized_segm[::-1]
for tr, track in enumerate(tqdm(self.reversed_tracks, ncols=100)):
x, y = track[0]
y_int, x_int = round(y), round(x)
y_int = max(0, min(y_int, Y-1))
x_int = max(0, min(x_int, X-1))
track_ID = self.reversed_resized_segm[0, y_int, x_int]
track_ID = self._get_track_ID(resized_segm, track[::-1])
for reversed_frame_i, (x, y) in enumerate(track):
visible = self.reversed_visibles[tr, reversed_frame_i]
if not visible and self._use_visibile_information:
continue
y_int, x_int = round(y), round(x)
y_int = max(0, min(y_int, Y-1))
x_int = max(0, min(x_int, X-1))
self._append_napari_point(
napari_tracks, y, x, num_frames, reversed_frame_i,
track_ID, use_centroids=use_centroids
Expand All @@ -207,51 +214,43 @@ def _append_napari_point(
xc = x*self.resize_ratio_width
napari_tracks.append((track_ID, frame_i, yc, xc))

def _apply_tracks(self, reversed_tracks, reversed_visibles):
def _get_track_ID(self, resized_segm, track, max_dist=None):
Y, X = resized_segm.shape[-2:]
x, y = track[-1]
# frame_i = self.tracks_start_frames[(round(y), round(x))]
# I still don't know how to get the start frame of each track
# because TAPIR returns a float even for the initialized query
# point of each track
frame_i = -1
y_int, x_int = round(y), round(x)
y_int = max(0, min(y_int, Y-1))
x_int = max(0, min(x_int, X-1))
track_ID = resized_segm[frame_i, y_int, x_int]
return track_ID

def _apply_tracks(self):
print('Applying tracks data...')

# Restore correct order (we tracked backwards)
tracks = reversed_tracks[:, ::-1]
resized_segm = self.reversed_resized_segm[::-1]
visibles = reversed_visibles[:, ::-1]
self.df_tracks = self._build_tracks_table()
self.df_tracks = self.df_tracks[self.df_tracks.visible>0]

# Iterate tracks and determine tracked IDs
old_IDs_tracks = {}
tracked_IDs_tracks = {}
Y, X = resized_segm.shape[-2:]
for tr, track in enumerate(tqdm(tracks, ncols=100)):
# Get the track ID from last frame (we track in reverse)
x, y = track[-1]
y0, x0 = round(y), round(x)
y0 = max(0, min(y0, Y-1))
x0 = max(0, min(x0, X-1))
tracked_ID = resized_segm[-1, y0, x0]
for frame_i, (x, y) in enumerate(track):
if frame_i == 0:
continue

visible = visibles[tr, frame_i]
if not visible and self._use_visibile_information:
continue
y_int, x_int = round(y), round(x)
y_int = max(0, min(y_int, Y-1))
x_int = max(0, min(x_int, X-1))
idxs = (frame_i, y_int, x_int)
oldID = resized_segm[idxs]
if oldID == 0:
oldID = nearest_nonzero_2D(
resized_segm[frame_i], y, x, max_dist=self.max_dist
)

if oldID == 0:
continue

if frame_i not in old_IDs_tracks:
old_IDs_tracks[frame_i] = [oldID]
tracked_IDs_tracks[frame_i] = [tracked_ID]
else:
old_IDs_tracks[frame_i].append(oldID)
tracked_IDs_tracks[frame_i].append(tracked_ID)
for (frame_i, track_ID), df in self.df_tracks.groupby(level=(0,1)):
if track_ID == 0:
continue

oldID = df['segm_ID'].mode().iloc[0]
if oldID == 0:
continue

if frame_i not in old_IDs_tracks:
old_IDs_tracks[frame_i] = [oldID]
tracked_IDs_tracks[frame_i] = [track_ID]
else:
old_IDs_tracks[frame_i].append(oldID)
tracked_IDs_tracks[frame_i].append(track_ID)

tracked_video = self.segm_video.copy()
for frame_i in old_IDs_tracks.keys():
Expand All @@ -271,23 +270,50 @@ def _apply_tracks(self, reversed_tracks, reversed_visibles):
return tracked_video

def _initialize_query_points(
self, reversed_resized_segm, track_segmented_objects
self, reversed_resized_segm, tracking_input,
which_points_to_track, number_of_points_per_object
):
first_lab = reversed_resized_segm[0]
first_lab_rp = skimage.measure.regionprops(first_lab)
num_objs = len(first_lab_rp)
query_points = np.zeros((num_objs, 3), dtype=int)
tracks_start_frames = {}
if which_points_to_track == 'Centroids':
query_points = np.zeros((num_objs, 3), dtype=int)
else:
all_contours = []
for o, obj in enumerate(first_lab_rp):
if track_segmented_objects:
obj_edt = distance_transform_edt(obj.image)
argmax = np.argmax(obj_edt)
yc_loc, xc_loc = np.unravel_index(argmax, obj_edt.shape)
ymin, xmin, _, _ = obj.bbox
yc, xc = yc_loc+ymin, xc_loc+xmin
if which_points_to_track == 'Centroids':
if tracking_input == 'Segmented objects':
# Track the center of the edt of the object
# since edt is also the input image
obj_edt = distance_transform_edt(obj.image)
argmax = np.argmax(obj_edt)
yc_loc, xc_loc = np.unravel_index(argmax, obj_edt.shape)
ymin, xmin, _, _ = obj.bbox
yc, xc = yc_loc+ymin, xc_loc+xmin
else:
# Track the centroid of the object
yc, xc = obj.centroid
query_points[o, 1:] = int(yc), int(xc)
tracks_start_frames[tuple(query_points[0][1:])] = 0
else:
yc, xc = obj.centroid
query_points[o, 1:] = int(yc), int(xc)
return query_points
contours = get_obj_contours(obj)[:-1]
if number_of_points_per_object > 1:
num_points = len(contours)
if number_of_points_per_object < num_points:
step = num_points // number_of_points_per_object
contours = contours[::step]
all_contours.append(contours)
for x, y in contours:
tracks_start_frames[(y, x)] = 0
if which_points_to_track == 'Contours':
all_contours = np.concatenate(all_contours)
nrows = len(all_contours)
query_points = np.zeros((nrows, 3), dtype=int)
query_points[:, 2] = all_contours[:,0]
query_points[:, 1] = all_contours[:,1]

return query_points, tracks_start_frames

def url_help():
return 'https://deepmind-tapir.github.io/'

0 comments on commit cd8035c

Please sign in to comment.