# Run script: cell tracking with Trackastra and post-processing of the results

Use kernel `trackastra` on host `dlhost12`.

Resources

In [33]:
# Show GPU status. `!` is an IPython shell escape: this line only works inside a notebook.
!nvidia-smi

import torch
from trackastra.model import Trackastra
from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks
from trackastra.data import example_data_bacteria, example_data_hela, example_data_fluo_3d

# Run Trackastra on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

Tue Aug 20 09:19:16 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.89.02    Driver Version: 525.89.02    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    On   | 00000000:03:00.0 Off |                  Off |
| 30%   48C    P2   119W / 300W |   5243MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    On   | 00000000:04:00.0 Off |                  Off |
| 47%   76C    P2   260W / 300W |  24937MiB / 49140MiB |     92%      Default |
|       

## Functions: Trackastra R1 input preparation

In [2]:
import os
import cv2
import numpy as np
from skimage import measure
from skimage.segmentation import watershed
from scipy.ndimage import distance_transform_edt, center_of_mass
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from scipy.spatial.distance import cdist
from PIL import Image
import re
import pandas as pd

def numerical_sort(value):
    """
    Sort key that orders names by the first run of digits they contain.

    Works for names such as 'Image_000012.png' (-> 12) or '12_htert_Run' (-> 12).
    A name without any digit is returned unchanged, as the original string.
    """
    first_number = re.search(r'\d+', value)
    if first_number is None:
        return value
    return int(first_number.group())

def load_images_from_directory(directory, filtered_img_list = None):
    """
    Load PNG frames from `directory` as a stack of 8-bit grayscale arrays.

    Parameters:
    - directory: str, folder holding the frame PNGs.
    - filtered_img_list: optional list of frame names *without* the ".png"
      extension. When non-empty, only these frames are loaded. When None or
      empty, every "*.png" in `directory` is loaded, except macOS "._"
      resource-fork files.
      (For cellenONE runs the filter used was: names ending in
      "_htert_Run.png" that do not contain "Printed".)

    Returns:
    - numpy array built from the per-file 2-D arrays, in numerical_sort order;
    - the matching filenames with ".png" removed.
    """
    if filtered_img_list:
        candidates = [name + ".png" for name in filtered_img_list]
    else:
        candidates = [
            name for name in os.listdir(directory)
            if name.endswith(".png") and not name.startswith("._")
        ]
    filenames = sorted(candidates, key = numerical_sort)
    print(f"filenames are {filenames}")

    frames = [
        np.array(Image.open(os.path.join(directory, name)).convert('L'))
        for name in filenames
    ]
    stem_names = [name.replace(".png", "") for name in filenames]
    return np.array(frames), stem_names


def expand_image(img, mode, factor=3,):
    """
    Upscale a PIL image by an integer factor.

    Parameters:
    - img: PIL image to resize (anything exposing `.size` and `.resize`).
    - mode: "images" for intensity images (Lanczos resampling), or "masks" for
      mask images (nearest-neighbour, so no new pixel values are invented).
    - factor: int, scale applied to both width and height (default 3).

    Returns:
    - the resized image.

    Raises:
    - ValueError: if `mode` is neither "images" nor "masks".
    """
    # Validate up front. Previously an unknown mode skipped both branches and
    # surfaced as an UnboundLocalError on the result variable.
    if mode not in ("images", "masks"):
        raise ValueError(f"mode must be 'images' or 'masks', got {mode!r}")

    original_width, original_height = img.size
    new_size = (original_width * factor, original_height * factor)

    if mode == "images":
        resample = Image.Resampling.LANCZOS
    else:
        resample = Image.Resampling.NEAREST
    return img.resize(new_size, resample)

def remove_empty_frame(imgs, masks):
    """
    Drop every frame whose mask contains no non-zero pixel.

    Parameters:
    - imgs: array of frames indexed along axis 0.
    - masks: array of label masks indexed along axis 0, aligned with `imgs`.

    Returns:
    - (imgs_new, masks_new, ind_to_remove): the filtered image and mask stacks,
      plus the list of removed frame indices. Later cells use this list to map
      results back to the original frame numbering.
    """
    empty_frames = [idx for idx, frame_mask in enumerate(masks) if not frame_mask.any()]

    kept_imgs = np.delete(imgs, empty_frames, axis = 0)
    kept_masks = np.delete(masks, empty_frames, axis = 0)

    assert kept_imgs.shape == kept_masks.shape

    return kept_imgs, kept_masks, empty_frames


def mask_to_bbox(mask):
    """
    Return the bounding box of the pixels equal to 255 in `mask`.

    :param numpy.ndarray mask: 2-D mask; only pixels equal to 255 count as foreground.
    :return: [x, y, w, h]. (x, y) is the top-left foreground pixel. w and h are the
        differences between the extreme column/row indices, i.e. one less than
        the pixel extent, so y + h is the last foreground row.
    :rtype: list[int]
    :raises ValueError: if no pixel equals 255 (min/max of an empty array).
    """
    ys, xs = np.nonzero(mask == 255)
    left, right = xs.min(), xs.max()
    top, bottom = ys.min(), ys.max()
    return [left, top, right - left, bottom - top]


def load_masks_from_directory(directory, img_shape, fix_overlap=False, overlap_threshold = 0.6, top_threshold = 0.005, bottom_threshold = 0.995):
    """
    Load per-object mask PNGs and combine them into one label image per frame.

    Expected layout (as read by the code below):
    - `directory` holds one sub-directory per frame whose name starts with "Image_".
    - Each sub-directory holds one PNG per detected object.
    - Frames are processed in numerical_sort order.
    - Every frame directory produces exactly one label image, even an empty one.

    Object labels are unique across the whole stack: `current_object_index`
    starts at 1 and keeps increasing from frame to frame. Where two objects of the
    same frame share pixels, the object painted later overwrites the earlier label.

    Parameters:
    - directory: str, path to the directory containing the per-frame mask folders.
    - img_shape: tuple, shape of one frame; used to allocate each label image.
    - fix_overlap: bool. If True, drop detections near the top/bottom edge and
      merge detections that overlap by at least `overlap_threshold` into one object.
    - overlap_threshold: float, minimum fraction of the *smaller* mask that must
      also be covered by the other mask for the pair to be merged.
    - top_threshold / bottom_threshold: float, fractions of the mask height. With
      fix_overlap=True a detection is kept only if its bbox satisfies
      y >= top_threshold * H and y + h <= bottom_threshold * H, where (y, h) come
      from mask_to_bbox.

    Returns:
    - masks: numpy array of int32 label images, shape (n_frames, *img_shape).

    Raises:
    - ValueError: if a mask that is used does not contain exactly two distinct
      values. With fix_overlap=True, a mask with no 255 pixel already fails
      inside mask_to_bbox.
    """
    masks = []
    current_object_index = 1
    # frames = sorted([frame for frame in os.listdir(directory) if os.path.isdir(os.path.join(directory, frame)) and frame.endswith("_htert_Run")], key = numerical_sort) # for cellenONE
    frames = sorted([frame for frame in os.listdir(directory) if os.path.isdir(os.path.join(directory, frame)) and frame.startswith("Image_")], key = numerical_sort)
    print(f"The frames are {frames}")

    for frame in frames: #sorted(os.listdir(directory), key = numerical_sort):  # Loop through the frame folders
        frame_dir = os.path.join(directory, frame)
        # Always true here because `frames` is already filtered the same way;
        # kept so the cellenONE variant below can be swapped in.
        if os.path.isdir(frame_dir) and frame.startswith("Image_"):  # The directory needs to start with Image for normal runs 
        # if os.path.isdir(frame_dir) and frame.endswith("_htert_Run"): # for cellenONE
            frame_mask = np.zeros(img_shape, dtype=np.int32)

            if fix_overlap:
                curr_masks = []
                
                for filename in sorted(os.listdir(frame_dir)):
                    if filename.endswith(".png") and not filename.startswith("._"):  # Loop through the png files 
                        mask_path = os.path.join(frame_dir, filename)
                        mask = Image.open(mask_path).convert('L')
                        mask_array = np.array(mask)
                        bbox = mask_to_bbox(mask_array)
                        _, y, _, h = bbox

                        # Keep only detections whose vertical extent lies inside the
                        # [top_threshold, bottom_threshold] band of the image height.
                        if y >= top_threshold * mask_array.shape[0] and (y + h) <= bottom_threshold * mask_array.shape[0]:  # Ignore detections close to top and bottom thresholds
                            if len(np.unique(mask_array)) != 2:
                                raise ValueError("something is up", np.unique(mask_array))
                            curr_masks.append(mask_array)
                        #else:
                            #print(f"Mask {filename} ignored due to top/bottom threshold")


                #print(f"Number of masks in the current frame: {len(curr_masks)}")

                # Create an overlap matrix (symmetric adjacency: 1 = pair should be merged)
                overlap_matrix = np.zeros((len(curr_masks), len(curr_masks)), dtype=int)
                for i in range(len(curr_masks)):
                    for j in range(i + 1, len(curr_masks)):
                        mask_i = curr_masks[i]
                        mask_j = curr_masks[j]

                        # Ensure masks are binary
                        mask1_binary = (mask_i == 255)
                        mask2_binary = (mask_j == 255)

                        # Calculate the size of each mask
                        size1 = np.sum(mask1_binary)
                        size2 = np.sum(mask2_binary)

                        # Identify the smaller and larger masks
                        if size1 < size2:
                            smaller_mask = mask1_binary
                            larger_mask = mask2_binary
                            smaller_size = size1
                        else:
                            smaller_mask = mask2_binary
                            larger_mask = mask1_binary
                            smaller_size = size2

                        # Overlap is measured relative to the smaller mask, so a small
                        # detection lying mostly inside a larger one counts as a duplicate.
                        overlap = np.sum(smaller_mask & larger_mask)
                        overlap_percentage = overlap / smaller_size
                        #print(f"overalp between mask {i} and maks {j} is {overlap_percentage} with threshold being {overlap_threshold}")

                        if overlap_percentage >= overlap_threshold:
                            overlap_matrix[i, j] = 1
                            overlap_matrix[j, i] = 1

                #print(f"Overlap matrix:\n{overlap_matrix}")

                # Cluster overlapping objects. Connected components make merging
                # transitive: A~B and B~C put A, B and C in one cluster even if A
                # and C do not overlap.
                sparse_matrix = csr_matrix(overlap_matrix)
                n_components, labels = connected_components(csgraph=sparse_matrix, directed=False, return_labels=True)

                # Group masks by their component labels to form clusters
                clusters = [[] for _ in range(n_components)]
                for mask_index, component_label in enumerate(labels):
                    clusters[component_label].append(mask_index)

                #print(f"Clusters: {clusters}")

                for c in range(len(clusters)):  # Loop through each cluster and merge them
                    masks_in_cluster = [curr_masks[j] for j in clusters[c]]

                    # Create a combined mask (union of all members; any non-zero pixel counts)
                    combined_mask = np.zeros_like(masks_in_cluster[0], dtype=np.int32)
                    for mask in masks_in_cluster:
                        combined_mask[mask > 0] = 1

                    # Label the combined mask. Later clusters overwrite earlier ones
                    # where they share pixels below the merge threshold.
                    frame_mask[combined_mask > 0] = current_object_index
                    current_object_index += 1

                    #print(f"Processed cluster {c} with {len(masks_in_cluster)} masks.")

            else:  # If not fixing overlaps: every PNG becomes its own object
                for filename in sorted(os.listdir(frame_dir)):
                    if filename.endswith(".png") and not filename.startswith("._"):  # Loop through the png files  _htert_Run
                        mask_path = os.path.join(frame_dir, filename)
                        mask = Image.open(mask_path).convert('L')
                        mask_array = np.array(mask)
                        if len(np.unique(mask_array)) != 2:
                            raise ValueError("something is up", np.unique(mask_array))
                        # Assign the current object index to the mask pixels
                        frame_mask[(mask_array == 255)] = current_object_index
                        current_object_index += 1
                        #print(f"Processed mask {filename} with index {current_object_index - 1}")

            masks.append(frame_mask)
            # if np.all(frame_mask == 0): print(f"The frame that has all zero is {frame}")
            #print(f"Added frame mask for {frame}, current number of masks: {len(masks)}")

    #print(f"Total frames processed: {len(masks)}")
    return np.array(masks)


## Main()

In [50]:
# Define the main directory
chip = "A138856A" # "A118880" #"A138974A" # "A138856A"
run = "10dropRun4" # "PrintRun_Apr1223_1311" #'htert_20230822_131349_843.Run'
main_img_directory = f"/projects/steiflab/archive/data/imaging/{chip}/NozzleImages/{run}"
#main_img_directory = f"/projects/steiflab/archive/data/imaging/{chip}/CellenONEImages/{run}" # for cellenONE
main_mask_directory = f"/projects/steiflab/scratch/leli/{chip}/{run}/rcnn_output_masks"
# NOTE(review): relative path, resolved against the notebook's working directory.
# Later cells read it back from /projects/steiflab/scratch/leli/trackastra/{out_folder},
# so this presumably runs with that directory as cwd. Confirm before reusing.
out_folder = f'{chip}/{run}/tracked'

# Load images
imgs, img_names = load_images_from_directory(main_img_directory)

# Load masks (merge overlapping detections, drop ones touching the top/bottom edge)
masks = load_masks_from_directory(main_mask_directory, imgs[0].shape, fix_overlap = True, overlap_threshold = 0.5)

print("Images shape:", imgs.shape)
print("Masks shape:", masks.shape)

# Ensure the shape matches the required format: (time, y, x)
# (a no-op when both stacks are already 3-D)
imgs = imgs.reshape(-1, imgs.shape[1], imgs.shape[2])
masks = masks.reshape(-1, masks.shape[1], masks.shape[2])

# Drop frames without any detection. ind_to_remove (indices into the loaded stack)
# is reused below and in later cells to map tracking output back to image names.
# NOTE(review): this assumes image files and mask frame folders are one-to-one;
# remove_empty_frame asserts that the shapes match after deletion.
imgs, masks, ind_to_remove = remove_empty_frame(imgs, masks)

print("Images shape:", imgs.shape)
print("Masks shape:", masks.shape)


# Load a pretrained model
# (or load one from a local folder instead:)
# model = Trackastra.from_folder('path/my_model_folder/', device=device)
model = Trackastra.from_pretrained("general_2d", device=device)

# Track the cells
track_graph = model.track(imgs, masks, mode="greedy")  # or mode="ilp", or "greedy_nodiv"

# Write to cell tracking challenge format
ctc_tracks, masks_tracked = graph_to_ctc(
      track_graph,
      masks,
      outdir=out_folder,
)


## create a file that connects tiffs with the images
# The i-th sorted tif is paired with the i-th image name that survived remove_empty_frame.
tifs = sorted([t for t in os.listdir(out_folder) if t.endswith(".tif")])
img_names_new = np.delete(img_names, ind_to_remove, axis = 0)
link_file = pd.DataFrame({"tifs": tifs, "imgs": img_names_new})
link_file.to_csv(os.path.join(out_folder, "tif_to_img.csv"), index=False)


filenames are ['Image_000001.png', 'Image_000002.png', 'Image_000003.png', 'Image_000004.png', 'Image_000005.png', 'Image_000006.png', 'Image_000007.png', 'Image_000008.png', 'Image_000009.png', 'Image_000010.png', 'Image_000011.png', 'Image_000012.png', 'Image_000013.png', 'Image_000014.png', 'Image_000015.png', 'Image_000016.png', 'Image_000017.png', 'Image_000018.png', 'Image_000019.png', 'Image_000020.png', 'Image_000021.png', 'Image_000022.png', 'Image_000023.png', 'Image_000024.png', 'Image_000025.png', 'Image_000026.png', 'Image_000027.png', 'Image_000028.png', 'Image_000029.png', 'Image_000030.png', 'Image_000031.png', 'Image_000032.png', 'Image_000033.png', 'Image_000034.png', 'Image_000035.png', 'Image_000036.png', 'Image_000037.png', 'Image_000038.png', 'Image_000039.png', 'Image_000040.png', 'Image_000041.png', 'Image_000042.png', 'Image_000043.png', 'Image_000044.png', 'Image_000045.png', 'Image_000046.png', 'Image_000047.png', 'Image_000048.png', 'Image_000049.png', 'Imag

INFO:trackastra.model.model:Loading model state from /home/leli/.trackastra/.models/general_2d/model.pt
INFO:trackastra.model.model_api:Predicting weights for candidate graph


Images shape: (1818, 313, 192)
Masks shape: (1818, 313, 192)
/home/leli/.trackastra/.models/general_2d already downloaded, skipping.
Using device cuda


INFO:trackastra.data.wrfeat:Extracting features from 1818 detections
INFO:trackastra.data.wrfeat:Using single process for feature extraction
Extracting features: 100%|██████████| 1818/1818 [00:07<00:00, 249.38it/s]
INFO:trackastra.model.model_api:Building windows
Building windows: 100%|██████████| 1815/1815 [00:00<00:00, 27441.72it/s]
INFO:trackastra.model.model_api:Predicting windows
Computing associations: 100%|██████████| 1815/1815 [00:37<00:00, 47.77it/s]
INFO:trackastra.model.model_api:Running greedy tracker
INFO:trackastra.tracking.tracking:Build candidate graph with delta_t=1
INFO:trackastra.tracking.tracking:Added 3016 vertices, 2942 edges                             
INFO:trackastra.tracking.tracking:Running greedy tracker
Greedily matched edges:  98%|█████████▊| 2886/2942 [00:00<00:00, 111346.03it/s]
Converting graph to CTC results: 100%|██████████| 176/176 [00:00<00:00, 3349.46it/s]
Saving masks: 100%|██████████| 1818/1818 [00:08<00:00, 225.30it/s]


## Functions: visualization helpers

In [4]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import tifffile as tiff

# Function to create a video from the saved PNGs
def create_video(output_dir, output_video, num_frames, width, height, fps=5, n_img = 2, frames_to_process = None):
    """
    Stitch the rendered `man_trackNNNN.png` frames in `output_dir` into an MP4.

    Parameters:
    - output_dir: str, folder containing the PNGs.
    - output_video: str, path of the .mp4 file to create.
    - num_frames: int, frame indices 0 .. num_frames-1 are considered.
    - width, height: int, size of ONE panel. The video frame holds `n_img` panels
      side by side, separated by 10-pixel gaps.
    - fps: int, frame rate of the output video.
    - n_img: int, number of panels per rendered PNG.
    - frames_to_process: optional iterable of frame indices; when given, only
      those frames are written.

    Raises:
    - ValueError: if an expected PNG is missing, cannot be decoded, or does not
      match the video frame size.
    """
    frame_size = (width * n_img + 10 * (n_img - 1), height)
    # A set gives O(1) membership instead of scanning a list once per frame.
    selected = None if frames_to_process is None else set(frames_to_process)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_video, fourcc, fps, frame_size)
    try:
        for frame_idx in range(num_frames):
            if selected is not None and frame_idx not in selected:
                continue

            frame_path = os.path.join(output_dir, f'man_track{frame_idx:04d}.png')
            if not os.path.isfile(frame_path):
                raise ValueError(f"This file does not exists {frame_path}")

            frame = cv2.imread(frame_path)
            if frame is None:
                raise ValueError(f"Could not decode image {frame_path}")
            # OpenCV's writer does not raise on a size mismatch; such frames are
            # simply not written. Fail loudly instead.
            if (frame.shape[1], frame.shape[0]) != frame_size:
                raise ValueError(
                    f"Frame {frame_path} has size {(frame.shape[1], frame.shape[0])}, "
                    f"expected {frame_size}"
                )
            video_writer.write(frame)
    finally:
        # Always finalize the container, even when a frame check fails.
        video_writer.release()
    print(f'Video saved as {output_video}')

import os
import cv2
import numpy as np
import tifffile as tiff

def process_frames(imgs, tracking_dir, output_dir, frames_to_process = None):
    """
    Render side-by-side PNGs: raw frame | frame with tracked outlines and track IDs.

    Parameters:
    - imgs: array of frames indexed along axis 0, or a path to a .npy file holding
      one. Frame i is paired with the i-th tif of `tracking_dir` in sorted order.
    - tracking_dir: str, folder with the tracking label tifs.
    - output_dir: str, where `<tif name>.png` files are written (created if missing).
    - frames_to_process: optional iterable of 0-based tif indices to render;
      None renders all of them.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX

    # Only treat `imgs` as a path when it actually is one. Calling os.path.isfile
    # on an ndarray makes CPython's os.stat interpret the array buffer as a
    # (deprecated) bytes path, copying the whole stack just to return False.
    if isinstance(imgs, (str, bytes, os.PathLike)) and os.path.isfile(imgs):
        imgs = np.load(imgs)

    tif_files = sorted(
        file for file in os.listdir(tracking_dir)
        if not file.startswith("._") and file.endswith("tif")
    )

    os.makedirs(output_dir, exist_ok=True)

    # A set gives O(1) membership instead of scanning a list once per frame.
    selected = None if frames_to_process is None else set(frames_to_process)

    for i, file in enumerate(tif_files):
        if selected is not None and i not in selected:
            continue

        tif_path = os.path.join(tracking_dir, file)
        tif = tiff.imread(tif_path)
        if tif is None:
            print(f"Error reading label image: {tif_path}")
            continue

        # Stretch the raw frame to the full 8-bit range, then make it 3-channel
        # so coloured outlines can be drawn on a copy.
        img = cv2.normalize(imgs[i], None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        original_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        annotated_img = original_img.copy()

        for label in np.unique(tif):
            if label == 0:  # Skip the background
                continue

            mask = (tif == label).astype(np.uint8) * 255
            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            for contour in contours:
                # OpenCV colours are BGR: (57, 255, 20) is neon green.
                cv2.drawContours(annotated_img, [contour], -1, (57, 255, 20), 1)
                x, y, w, h = cv2.boundingRect(contour)
                # Put the track ID above the box, but keep the baseline inside the
                # image for objects touching the top edge.
                text_y = max(y - 10, 10)
                # (0, 0, 255) in BGR is red.
                cv2.putText(annotated_img, str(label), (x, text_y), font, 0.5, (0, 0, 255), 1, cv2.LINE_AA)

        # 10-pixel white separator column between the two panels
        white_space = np.full((original_img.shape[0], 10, 3), 255, dtype=np.uint8)
        combined_img = cv2.hconcat([original_img, white_space, annotated_img])

        output_path = os.path.join(output_dir, file.replace(".tif", ".png"))
        if not cv2.imwrite(output_path, combined_img):
            print(f"Error saving image: {output_path}")
        else:
            print(f"Saved image for frame: {i} at {output_path}")


import os
import numpy as np
import cv2
import tifffile as tiff
import matplotlib.pyplot as plt

# Predefined colors with names
# Predefined colors with names, in RGB order (the canvas is shown with matplotlib).
# Colours are assigned to labels by position, so every entry must be distinct:
# 'green' previously duplicated 'lime' (0, 255, 0), making two labels look identical.
color_map = {
    'red': (255, 0, 0),
    'green': (0, 128, 0),
    'blue': (0, 0, 255),
    'cyan': (0, 255, 255),
    'magenta': (255, 0, 255),
    'yellow': (255, 255, 0),
    'orange': (255, 165, 0),
    'purple': (128, 0, 128),
    'pink': (255, 192, 203),
    'lime': (0, 255, 0)
}

def get_color_name(index):
    """
    Get the color name and RGB values for a position, cycling through color_map.

    Parameters:
    - index: int, position of the color; wraps around modulo len(color_map).

    Returns:
    - tuple: (color_name, color_rgb), the name of the color and its RGB tuple.
    """
    names = list(color_map)
    color_name = names[index % len(names)]
    return color_name, color_map[color_name]

def display_colored_labels(path, labels):
    """
    Display an image with each requested label painted in its own colour.

    Parameters:
    - path: str, a directory of TIFF label images or a single TIFF file. For a
      directory, every "*.tif" file is overlaid onto one canvas in sorted filename
      order; macOS "._" files are skipped.
    - labels: list of int, label values to colour; the i-th label gets
      get_color_name(i).

    Output:
    - Prints the label -> colour legend once and shows the composite with
      matplotlib, or prints a message if no TIFF file was processed.
    """
    final_image = None

    def process_tif_file(file_path):
        nonlocal final_image
        tif = tiff.imread(file_path)

        # Size the canvas from the first TIFF that is read
        if final_image is None:
            final_image = np.zeros((tif.shape[0], tif.shape[1], 3), dtype=np.uint8)

        for i, label in enumerate(labels):
            _, color_rgb = get_color_name(i)
            final_image[tif == label] = color_rgb

    if os.path.isdir(path):
        # Sorted: later files overwrite earlier pixels, so order must be deterministic.
        # "._" files are skipped, as everywhere else in this notebook.
        for file in sorted(os.listdir(path)):
            if file.endswith(".tif") and not file.startswith("._"):
                process_tif_file(os.path.join(path, file))
    else:
        process_tif_file(path)

    if final_image is None:
        print("No TIFF files found or processed.")
        return

    # Legend printed once, rather than once per processed file.
    for i, label in enumerate(labels):
        color_name, color_rgb = get_color_name(i)
        print(f"Label {label} is colored with {color_name} (RGB: {color_rgb})")

    plt.figure(figsize=(10, 10))
    plt.title("Colored Labels")
    plt.imshow(final_image)
    plt.axis('off')
    plt.show()






## Plot()

In [51]:
# Directory containing tracking results (TIFF files and text file)
tracking_dir = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}"
print(f"{os.path.isdir(tracking_dir)}")
# Create an output directory for PNGs
output_dir = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}_imgs/"
os.makedirs(output_dir, exist_ok=True)

# Number of raw "Image_*.png" frames, including the ones remove_empty_frame dropped
total_frame_num = len([file for file in os.listdir(main_img_directory) if file.startswith("Image_") and file.endswith(".png")])
# total_frame_num = len([filename for filename in os.listdir(main_img_directory) if filename.endswith("_htert_Run.png") and not filename.startswith("._") and "Printed" not in filename])
print(f"total frame number is {total_frame_num}")

# Process each frame
# NOTE(review): i+1 assumes the raw frames are numbered contiguously from Image_000001.
act_rcnn_inds = [i+1 for i in range(total_frame_num) if i not in ind_to_remove] # 1-based frame numbers of the kept frames, i.e. the matching R-CNN frame indices
# One tracking tif must exist per kept frame
assert len([file for file in os.listdir(tracking_dir) if not file.startswith("._") and file.endswith("tif")]) == len(act_rcnn_inds)

start = 3274
end = 3757
frames_to_process = [i for i, act in enumerate(act_rcnn_inds) if act >= start and act <= end] # 0-based tif indices whose R-CNN frame number lies in [start, end]
print(frames_to_process)
# NOTE: frames_to_process is computed above but None is passed, so every frame is rendered.
process_frames(imgs, tracking_dir, output_dir, frames_to_process = None)
print(f"Process frames done!")

'''# Create a video from the saved PNGs
print(f"These are the tiffs that should be in the ground truth")
output_video = f'/projects/steiflab/scratch/leli/trackastra/{out_folder}_imgs/tracked_video_val.mp4'
height, width = imgs.shape[1], imgs.shape[2]  # Get height and width from images
create_video(output_dir, output_video, imgs.shape[0], width, height, fps=3, frames_to_process = frames_to_process)'''

# Full-length video from all rendered PNGs (panel size = one frame of `imgs`)
output_video = f'/projects/steiflab/scratch/leli/trackastra/{out_folder}_imgs/tracked_video_full.mp4'
height, width = imgs.shape[1], imgs.shape[2]
create_video(output_dir, output_video, imgs.shape[0], width, height, fps=3, frames_to_process = None)

True
total frame number is 2776
[]
Saved image for frame: 0 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_imgs/man_track0000.png
Saved image for frame: 1 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_imgs/man_track0001.png
Saved image for frame: 2 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_imgs/man_track0002.png
Saved image for frame: 3 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_imgs/man_track0003.png
Saved image for frame: 4 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_imgs/man_track0004.png
Saved image for frame: 5 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_imgs/man_track0005.png
Saved image for frame: 6 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_imgs/man_track0006.png
Saved image for frame: 7 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_imgs/man_tra

## gen_MOTA_df()

In [None]:
import pandas as pd 
import sys 

# Post-processed tracking output analysed below, relative to /projects/steiflab/scratch/leli/trackastra/.
# Overrides the out_folder set in the Main() cell.
out_folder = "postprocessing/tracked_1.0_pp"

def get_bbox(mask):
    """
    Compute the tight bounding box of the pixels equal to 1 in `mask`.

    :param numpy.ndarray mask: 2-D mask whose foreground pixels equal 1.
    :return: [x, y, w, h]. (x, y) is the top-left foreground pixel and w/h are the
        inclusive pixel extents. Returns [0, 0, 0, 0], after printing a notice,
        when the mask has no foreground.
    :rtype: list[int]
    """
    ys, xs = np.nonzero(mask == 1)

    if not (ys.size and xs.size):
        print("No pixels found with the specified value.")
        return [0, 0, 0, 0]

    left, top = xs.min(), ys.min()
    width = xs.max() - left + 1
    height = ys.max() - top + 1
    return [left, top, width, height]


# Build a per-(frame, track) table. For every tracked object in every tif, find the
# R-CNN detection mask (including the extra "merged_" masks) with the highest IoU,
# and record it together with the tracked object's bounding box.

# Read the track table written next to the tifs (columns named Index, Start, End, Mother)
df = pd.read_csv(f"/projects/steiflab/scratch/leli/trackastra/{out_folder}/man_track.txt", sep=r'\s+', header=None, names=['Index', 'Start', 'End', 'Mother'])
to_remove_list = df[df['Mother'] != 0]['Index'].tolist()

seg_dir = "/projects/steiflab/scratch/leli/A138974A/PrintRun_Apr1223_1311/rcnn_output_masks"
# NOTE(review): seg_dir points at chip A138974A, while ind_to_remove comes from the
# Main() cell run. Confirm both describe the same run. 3757 is a hard-coded frame count.
removed_frames = set(ind_to_remove)
seg_ind = [i+1 for i in range(3757) if i not in removed_frames]
seg_name = [f'Image_{i:06d}' for i in seg_ind]
merged_seg_name = [f'merged_Image_{i:06d}' for i in seg_ind if os.path.isdir(os.path.join(seg_dir, f'merged_Image_{i:06d}'))]
print(f"merged image directories are: {merged_seg_name} should be eight in total")
merged_seg_set = set(merged_seg_name)

tracking_out_dir = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}"
# Sorted so the i-th tif pairs with the i-th kept frame: os.listdir order is arbitrary.
track_name = sorted(file for file in os.listdir(tracking_out_dir) if not file.startswith("._") and file.endswith(".tif"))
if len(track_name) != len(seg_name):
    print(f"Warning: {len(track_name)} tracking tifs but {len(seg_name)} segmentation frames; zip() ignores the surplus.")

tracking_columns = ["Folder", "trackastra", "Track_ID", "N objects", "Paths", "iou", "x", "y", "w", "h"]
tracking_rows = []  # one dict per (frame, track); turned into a DataFrame once at the end

for t, s in zip(track_name, seg_name):
    debug = (t == 'man_track0823.tif' and s == 'Image_001080')
    tif = tiff.imread(os.path.join(tracking_out_dir, t))
    if tif is None:
        print(f"Error reading label image: {t}")
        continue

    # Relative paths (under seg_dir) of the candidate detection masks for this frame.
    # Named mask_paths so the global `masks` stack from the Main() cell is not clobbered.
    mask_paths = [os.path.join(s, file) for file in os.listdir(os.path.join(seg_dir, s)) if not file.startswith("._") and file.endswith(".png")]
    if f"merged_{s}" in merged_seg_set:
        mask_paths += [os.path.join(f"merged_{s}", file) for file in os.listdir(os.path.join(seg_dir, f"merged_{s}")) if not file.startswith("._") and file.endswith(".png")]

    # Read each candidate once per frame (not once per label). Keep only masks that
    # were decoded and contain foreground (255).
    candidates = []
    for m in mask_paths:
        mask_img = cv2.imread(os.path.join(seg_dir, m), cv2.IMREAD_GRAYSCALE)
        if mask_img is not None and np.any(mask_img == 255):
            candidates.append((m, (mask_img == 255).astype(np.uint8)))

    unique_labels = np.unique(tif)
    if debug:
        print(f"The unique vals are {unique_labels} and there are masks that can be considered: {mask_paths}")

    for label in unique_labels:
        if label == 0:  # Skip the background
            continue

        mask1 = (tif == label).astype(np.uint8)
        x, y, w, h = get_bbox(mask1)

        # Start from an explicit "no match" row. A label that overlaps no detection
        # is then recorded with iou 0. Previously it re-used the previous label's
        # entry (creating duplicates) or raised NameError on the very first label.
        new_entry = {"Folder": s,
                     "trackastra": t,
                     "Track_ID": label,
                     "N objects": len(unique_labels) - 1,
                     "Paths": None,
                     "iou": 0.0,
                     "x": x,
                     "y": y,
                     "w": w,
                     "h": h,
                     }
        curr_iou = 0.00

        for m, mask2 in candidates:
            intersection = np.sum(mask1 & mask2)
            union = np.sum(mask1 | mask2)
            iou = intersection / union if union != 0 else 0

            if debug:
                print(f"here for mask: {m} we got iou: {iou}")

            if iou > curr_iou:
                if debug:
                    print(f"Now the new entry has iou: {iou} when the threshold is {curr_iou}")
                new_entry["Paths"] = m
                new_entry["iou"] = iou
                curr_iou = iou

        if new_entry["Paths"] is None:
            print(f"No detection overlaps track {label} in {t} ({s}); recorded with iou 0.")
        if debug:
            print(f"here for label {label} the new entry is {new_entry}")

        tracking_rows.append(new_entry)

# Build the table once: concatenating inside the loop is quadratic.
tracking_df = pd.DataFrame(tracking_rows, columns=tracking_columns)
print(tracking_df.head(10))
tracking_df.to_csv(f"/projects/steiflab/scratch/leli/trackastra/{out_folder}/tracking_df.csv", index = False)
print(f"The min overlap is {np.min(tracking_df['iou'])} and the max overlap is: {np.max(tracking_df['iou'])} and the averge is {np.mean(tracking_df['iou'])}")

# Assert that the DataFrame contains only unique combinations of Folder, trackastra, and Track_ID
assert tracking_df.duplicated(subset=["Folder", "trackastra", "Track_ID"]).sum() == 0, "The DataFrame contains duplicate combinations of Folder, trackastra, and Track_ID."

# If no assertion error is raised, print a confirmation
print("\nThe DataFrame contains unique combinations of Folder, trackastra, and Track_ID.")

The IoU between each tracked object and its best-matching detection should be close to 1 (i.e. about 100%), so let's take a closer look at the entries that fall noticeably below that.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os
import shutil
import pandas as pd
import matplotlib.pyplot as plt
import os
import shutil
from PIL import Image, ImageOps, ImageDraw, ImageFont

def plot_iou_histogram(df):
    """
    Show a 30-bin histogram of the values in the 'iou' column of `df`.

    :param df: pandas DataFrame containing an 'iou' column.
    """
    axis_text = {'title': 'Distribution of IOU', 'xlabel': 'IOU', 'ylabel': 'Frequency'}

    plt.figure(figsize=(10, 6))
    iou_values = df['iou']
    plt.hist(iou_values, bins=30, edgecolor='k', alpha=0.7)
    plt.title(axis_text['title'])
    plt.xlabel(axis_text['xlabel'])
    plt.ylabel(axis_text['ylabel'])
    plt.grid(True)
    plt.show()

def copy_files_below_threshold(df, raw_directory, pp_directory, output_directory, threshold=0.95):
    """Export side-by-side comparison images for rows whose IOU is at or below a threshold.

    For each selected row, the raw overlay ``<raw_directory>/<Folder>.png`` (left) and the
    post-processed tracking image ``<pp_directory>/<trackastra with .tif -> .png>`` (right)
    are padded onto a white canvas with a caption showing the row's Track_ID and IOU. The
    result is saved as ``<output_directory>/<Folder>_<trackastra with .tif -> .png>``.

    :param df: pandas DataFrame with columns 'Folder', 'trackastra', 'Track_ID' and 'iou'.
    :param raw_directory: str, directory holding the raw overlay PNGs (named after 'Folder').
    :param pp_directory: str, directory holding the post-processed tracking PNGs.
    :param output_directory: str, destination directory; created (with parents) if missing.
    :param threshold: float, rows with ``iou <= threshold`` are exported (NaN IOUs are skipped,
        since NaN never compares <= anything).
    """
    filtered_df = df[df['iou'] <= threshold]
    os.makedirs(output_directory, exist_ok=True)

    caption_height = 50  # white strip above the images that holds the caption

    for _, row in filtered_df.iterrows():
        src_folder = os.path.join(raw_directory, row['Folder'] + ".png")
        pp_name = row['trackastra'].replace(".tif", ".png")
        src_trackastra = os.path.join(pp_directory, pp_name)
        track_id = row["Track_ID"]
        perc = row["iou"]

        # Open inside context managers so file handles are released immediately;
        # convert() loads the pixel data, so the RGB copies stay valid after closing.
        with Image.open(src_folder) as raw_img:
            img1 = raw_img.convert('RGB')
        with Image.open(src_trackastra) as pp_img:
            img2 = pp_img.convert('RGB')

        # Pad both images to a common size on a white background
        new_width = max(img1.width, img2.width)
        new_height = max(img1.height, img2.height)
        new_img1 = Image.new('RGB', (new_width, new_height), (255, 255, 255))
        new_img2 = Image.new('RGB', (new_width, new_height), (255, 255, 255))
        new_img1.paste(img1, (0, 0))
        new_img2.paste(img2, (0, 0))

        # Place them side by side, below the caption strip
        combined_img = Image.new('RGB', (new_width * 2, new_height + caption_height), (255, 255, 255))
        combined_img.paste(new_img1, (0, caption_height))
        combined_img.paste(new_img2, (new_width, caption_height))

        # Horizontally centred caption with the Track_ID and IOU
        draw = ImageDraw.Draw(combined_img)
        font = ImageFont.load_default()
        text = f"Track_ID: {track_id}   IOU: {perc}"
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        draw.text(((combined_img.width - text_width) // 2, 10), text, font=font, fill=(0, 0, 0))

        combined_img.save(os.path.join(output_directory, row['Folder'] + "_" + pp_name))

        print(f"Copied {row['Folder']} to {output_directory}")

# Define directories
# raw_directory: raw overlay PNGs, one per Folder; pp_directory: post-processed tracking PNGs;
# output_directory: destination for the side-by-side comparison images.
raw_directory = "/projects/steiflab/scratch/leli/A138974A/PrintRun_Apr1223_1311/rcnn_output_overlayed"
pp_directory = "/projects/steiflab/scratch/leli/trackastra/postprocessing/tracked_1.0_imgs_pp"
output_directory = "/projects/steiflab/scratch/leli/trackastra/postprocessing/low_iou_imgs"

# Plot IOU histogram
plot_iou_histogram(tracking_df)

# Export side-by-side comparisons for rows with IOU at or below 70% (threshold=0.7)
copy_files_below_threshold(tracking_df, raw_directory, pp_directory, output_directory, threshold=0.7)


## Functions: MOTA 

In [None]:
import pandas as pd
import numpy as np
from scipy.optimize import linear_sum_assignment

def calculate_cost_matrix(tracking_sequences, ground_truth_sequences):
    """Build the tracking-to-ground-truth assignment cost matrix.

    ``cost[i, j]`` is the size of the symmetric difference between the set of paths in
    tracking sequence ``i`` and ground-truth sequence ``j``. This is the number of
    detections that disagree if the two tracks are matched. The matrix is intended for
    ``scipy.optimize.linear_sum_assignment``.

    :param tracking_sequences: iterable of path collections, one per predicted track
        (e.g. the Series from ``groupby('Track_ID')['Paths'].apply(list)``).
    :param ground_truth_sequences: iterable of path collections, one per ground-truth track.
    :return: float ndarray of shape (n_tracking, n_ground_truth).
    """
    # Convert each sequence to a set once, instead of once per matrix cell.
    track_sets = [set(seq) for seq in tracking_sequences]
    gt_sets = [set(seq) for seq in ground_truth_sequences]

    cost_matrix = np.zeros((len(track_sets), len(gt_sets)))
    for i, track_set in enumerate(track_sets):
        for j, gt_set in enumerate(gt_sets):
            cost_matrix[i, j] = len(track_set ^ gt_set)  # Symmetric difference

    return cost_matrix

def get_sequence_pairs(tracking_sequences, ground_truth_sequences, row_ind, col_ind):
    """Pair matched tracking / ground-truth sequences and gather unmatched GT paths.

    :param tracking_sequences: pandas Series of path lists, one per predicted track.
    :param ground_truth_sequences: pandas Series of path lists, one per ground-truth track.
    :param row_ind: positional indices into ``tracking_sequences`` (assignment rows).
    :param col_ind: positional indices into ``ground_truth_sequences`` (assignment columns).
    :return: ``(pairs, missing)``. ``pairs`` is a list of ``(tracking_seq, ground_truth_seq)``
        tuples in assignment order. ``missing`` is the set of ground-truth paths absent from
        the tracking sequence they were matched to.
    """
    pairs = [
        (tracking_sequences.iloc[t_pos], ground_truth_sequences.iloc[g_pos])
        for t_pos, g_pos in zip(row_ind, col_ind)
    ]

    missing = set()
    for predicted, expected in pairs:
        missing |= set(expected) - set(predicted)

    return pairs, missing

def classify_errors_by_parent_and_children(track_info, track_ids):
    """Split track ids into root tracks (``Mother == 0``) and child tracks.

    :param track_info: DataFrame with ``Index`` (track id) and ``Mother`` columns.
    :param track_ids: iterable of track ids to classify.
    :return: ``(parent_ids, child_ids)``. These are the subsets of ``track_ids`` whose
        ``Mother`` is 0 and whose ``Mother`` is non-zero, respectively. Ids absent from
        ``track_info`` appear in neither.
    """
    is_root = track_info['Mother'] == 0
    root_ids = set(track_info.loc[is_root, 'Index'])
    child_ids = set(track_info.loc[~is_root, 'Index'])

    requested = set(track_ids)
    return requested & root_ids, requested & child_ids

def get_last_two_components(path):
    """Return the last two '/'-separated components of *path*, e.g. ``'dir/file.png'``.

    Paths with fewer than two components are returned unchanged. Non-string values
    (NaN/None for missing paths in the tracking table) are passed through as-is rather
    than raising ``AttributeError``. This lets callers still count them with ``isna()``.

    :param path: a '/'-separated path string, or a missing value.
    :return: the shortened path, or *path* itself when it is not a string.
    """
    if not isinstance(path, str):
        return path
    return '/'.join(path.split('/')[-2:])


## MOTA() analysis

In [None]:
import pandas as pd
import numpy as np
from scipy.optimize import linear_sum_assignment

# Load data
ground_truth = pd.read_csv('/projects/steiflab/scratch/leli/A138974A/PrintRun_Apr1223_1311/ground_truth/object/gt_object.csv', dtype={"Folder": str, "Paths": str})
# Raw string for the regex separator: '\s' is not a valid escape in a plain string
# literal (SyntaxWarning since Python 3.12).
man_track = pd.read_csv(f'/projects/steiflab/scratch/leli/trackastra/{out_folder}/man_track.txt', sep=r'\s+', header=None, names=['Index', 'Start', 'End', 'Mother'])
tracking_data = tracking_df.copy()

# Shorten the paths to their last two components so both tables use comparable keys
tracking_data['Paths'] = [get_last_two_components(path) for path in tracking_data['Paths']]
ground_truth['Paths'] = [get_last_two_components(path) for path in ground_truth['Paths']]

# Extract numerical representation from 'Folder' column (first run of digits), used to
# restrict the tracking output to the frame window covered by the ground truth
tracking_data['numerical_representation'] = tracking_data['Folder'].str.extract(r'(\d+)').astype(int)
ground_truth['numerical_representation'] = ground_truth['Folder'].str.extract(r'(\d+)').astype(int)

# Find minimum and maximum values
min_value = ground_truth['numerical_representation'].min()
max_value = ground_truth['numerical_representation'].max()
print(f"The ground truth starts at {min_value} and stops at {max_value}")
# Filter rows in tracking_data to the ground-truth window
tracking_data = tracking_data[(tracking_data['numerical_representation'] >= min_value) & (tracking_data['numerical_representation'] <= max_value)]

# Drop the intermediate columns
tracking_data = tracking_data.drop(columns=['numerical_representation'])
ground_truth = ground_truth.drop(columns=['numerical_representation'])

# NOTE(review): this rebinds the notebook-global tracking_df to the windowed, non-missing
# subset; cells re-run after this one (e.g. the IOU histogram) will see the filtered table.
tracking_df = tracking_data.dropna(subset=['Paths'])

# One list of paths per track
tracking_sequences = tracking_df.groupby('Track_ID')['Paths'].apply(list)
ground_truth_sequences = ground_truth.groupby('Track_ID')['Paths'].apply(list)

# Compute cost matrix
cost_matrix = calculate_cost_matrix(tracking_sequences, ground_truth_sequences)

# Find optimal assignment (minimum-cost one-to-one matching; with a rectangular matrix the
# surplus tracks on the larger side stay unmatched)
row_ind, col_ind = linear_sum_assignment(cost_matrix)

sequence_pairs, missing_from_tracking_seq = get_sequence_pairs(tracking_sequences, ground_truth_sequences, row_ind, col_ind)

tracking_paths = set(tracking_df['Paths'])
ground_truth_paths = set(ground_truth['Paths'])

# Calculate MoT metrics based on this assignment
# FP: detections without a path (missing assignments) plus paths absent from the ground truth
false_positives = tracking_data['Paths'].isna().sum() + len(tracking_paths - ground_truth_paths)
# FN: ground-truth paths that never appear in the tracking output
false_negatives = len(ground_truth_paths - tracking_paths)

# Calculate ID switches based on assignment: ground-truth paths missing from their matched
# track but present in some other track
id_switches = len(missing_from_tracking_seq & tracking_paths)

missing_assignments = tracking_data['Paths'].isna().sum()

# Get unique track IDs for false positives, false negatives, and ID switches
fp_track_ids = list(set(tracking_df[tracking_df['Paths'].isin(tracking_paths - ground_truth_paths)]['Track_ID']))
fn_track_ids = list(set(ground_truth[ground_truth['Paths'].isin(ground_truth_paths - tracking_paths)]['Track_ID']))
id_switch_track_ids = list(set(tracking_df[tracking_df['Paths'].isin(missing_from_tracking_seq & tracking_paths)]['Track_ID']))

# Print results; MOTA = 1 - (IDSW + FP + FN) / number of ground-truth observations
print(f"Missing assignments: {missing_assignments}")
print(f"ID Switches: {id_switches}")
print(f"False Positives: {false_positives}")
print(f"False Negatives: {false_negatives}")
print(f"Total Observations: {len(ground_truth_paths)}")
print(f'MOTA: {1-((id_switches + false_positives + false_negatives))/len(ground_truth_paths)}')

# Disabled diagnostics (kept as a string literal): breakdown of the FP / FN / ID-switch
# tracks into child tracks from man_track.
'''print(f"Unique Track IDs for False Positives: {fp_track_ids}")
print(f"Unique Track IDs for False Negatives: {fn_track_ids}")
print(f"Unique Track IDs for ID Switches: {id_switch_track_ids}")

# Filter for track IDs greater than or equal to 369 and where the mother is non-zero
filtered_tracks = man_track[(man_track['Index'] >= 369) & (man_track['Mother'] != 0)]

# Get the list of track IDs
track_ids = filtered_tracks['Index'].unique().tolist()

print("\n")
print(f"Child Tracks: {track_ids}")
print("\n")
print(f"These ones are False positive and they are child tracks: {[i for i in fp_track_ids if i in track_ids]} and these account for {len([i for i in fp_track_ids if i in track_ids]) / len(fp_track_ids)} of the entire FP")
print("\n")
print(f"These ones are False negatives and they are child tracks: {[i for i in fn_track_ids if i in track_ids]} and these account for {len([i for i in fn_track_ids if i in track_ids]) / len(fn_track_ids)} of the entire FN")
print("\n")
print(f"These ones are ID Switches and they are child tracks: {[i for i in id_switch_track_ids if i in track_ids]} and these account for {len([i for i in id_switch_track_ids if i in track_ids]) / len(id_switch_track_ids)} of the entire IS")
print("\n")
print(f"These are tracks that have both FP and IS: {set(fp_track_ids) & set(id_switch_track_ids)} and amongst those, these are child tracks: {[i for i in list(set(fp_track_ids) & set(id_switch_track_ids)) if i in track_ids]} and this accounts for {len([i for i in list(set(fp_track_ids) & set(id_switch_track_ids)) if i in track_ids]) / len(set(fp_track_ids) & set(id_switch_track_ids))}")
'''


### So here we can see that sub-tracks are causing an issue

## R1_processing RF approach: Note that the top and bottom thresholds have been loosely optimized; 0.005 works best

### Classification Model for Post-Processing Child Tracks

#### Overview
This document outlines the rationale and approach for a classification model designed to post-process child tracks in object tracking data. The model classifies child tracks into three categories to determine the appropriate action for each track. The three categories are: "Merge with Parent," "Keep Separate," and "Discard Track."

#### Classification Classes

##### 1. Merge with Parent --> CLASS 0
**Description:**  
This classifies child tracks that should be merged back into their parent tracks. The decision is based on a high overlap ratio (e.g., more than 70%) between the parent and child tracks. This scenario typically occurs when the child track is a temporary split from the parent due to noise or minor tracking inaccuracies.

**Example:**  
Track 1 (parent) temporarily splits into track 2 (child) due to noise. Since track 2 overlaps with track 1 by 80%, it should be merged back into track 1.


##### 2. Keep Separate --> CLASS 1
**Description:**  
This classifies child tracks that represent legitimate new objects and should remain separate from their parent tracks. This decision is made when the child track has a consistent trajectory.

**Example:**  
Track 1 (parent) splits into track 2 (child) due to a new object emerging. Track 2 should be kept separate from track 1.



##### 3. Discard Track --> CLASS 2
**Description:**  
This classifies child tracks that are short-lived and represent noise rather than valid objects. These tracks are typically discarded if they last for a very short duration (e.g., less than or equal to three frames).

**Example:**  
Track 1 (parent) is temporarily interrupted by a noise artifact that creates track 2 (child). Track 2 should be discarded.




### Step 1: create data: 
#### This is a list of features we can generate: 
**num_frames**: The number of frames the track spans

**parent_dist**: Euclidean distance between the centroids of the parent track's last frame and the child track's first frame.

**size**: The number of pixels this object has.

**size_avg**: The average number of pixels this object has. 

**size_std**: The standard deviation of the object's pixel counts across its frames.

**size_diff**: The diff between the max and the min sizes.

**size_max**: The maximum number of pixels the object has across its frames.

**size_change**: Ratio of the size (number of pixels) of the child track's first frame to the parent track's last frame.

**shape_change**: Measures of shape dissimilarity between the parent track's last frame and the child track's first frame (e.g., using Hu moments or contour similarity).

**num_obj_nearby_parent**: Number of tracks in the vicinity of the parent track's last frame.

**num_obj_nearby**: Number of tracks in the vicinity of the current track's first frame.

**closest_dist**: The distance from the child track to the nearest other track. If there are no other tracks, a large value is returned.

**avg_parent_delta_direction**: The average of each consecutive parent pair's θ=arctan2(Δy,Δx), if only one frame, pi is returned. 

**total_parent_delta_direction**: The θ=arctan2(Δy,Δx) between the parent's first and last frames, if only one frame, pi is returned. 

**avg_delta_direction**: The average of θ=arctan2(Δy,Δx) over each consecutive pair of the track's own frames; if there is only one frame, pi is returned.

**total_delta_direction**: The θ=arctan2(Δy,Δx) between the track's own first and last frames; if there is only one frame, pi is returned.



## Manual Target 

In [None]:
import pandas as pd

# Manually labelled training targets for child tracks, using the classes defined above:
# 0 = merge with parent, 1 = keep separate, 2 = discard track.
# Data as a list of (Track_ID, Target) tuples
data = [
    (9, 0), (10, 0), (12, 0), (13, 0), (15, 0), (16, 0), (20, 0), (21, 0),
    (40, 0), (41, 0), (42, 0), (43, 0), (52, 2), (53, 2), (58, 0), (59, 0),
    (68, 0), (69, 2), (71, 1), (72, 1), (80, 0), (81, 0), (85, 0), (86, 0),
    (87, 0), (88, 0), (90, 1), (91, 1), (102, 0), (103, 2), (104, 0), (105, 2),
    (122, 2), (123, 0), (135, 0), (136, 2), (168, 1), (169, 1), (179, 1), (180, 2),
    (181, 1), (182, 2), (185, 0), (186, 0), (192, 1), (193, 2), (194, 0), (195, 2),
    (214, 1), (215, 1), (218, 0), (219, 0), (221, 2), (222, 1), (231, 1), (232, 1),
    (236, 0), (237, 0), (259, 0), (260, 0), (275, 0), (276, 0), (281, 0), (282, 2),
    (285, 0), (286, 0), (299, 0), (300, 2), (305, 0), (306, 0), (307, 2), (308, 2),
    (313, 2), (314, 2), (319, 1), (320, 1), (328, 0), (332, 2), (333, 0),
    (334, 1), (335, 0), (339, 1), (340, 2), (342, 0), (343, 0), (353, 0), (354, 0),
    (355, 0), (356, 2), (360, 0), (361, 0), (367, 1), (368, 1), (388, 0), (389, 0),
    (390, 0), (391, 0), (392, 0), (393, 2), (394, 0), (395, 2), (401, 0), (402, 2),
    (403, 2), (404, 0), (405, 0), (406, 2), (426, 0), (427, 0)
]

# Convert to DataFrame (one row per labelled track)
df = pd.DataFrame(data, columns=['Track_ID', 'Target'])

print(df)


## Construct_feats()

In [6]:
import os
import numpy as np
import pandas as pd
import cv2
from skimage import measure
from scipy.spatial import distance
from scipy.ndimage import center_of_mass
import tifffile as tiff

def calculate_direction_change(frames, track_id):
    """Summarise the heading of one labelled object across a sequence of frames.

    Each frame is reduced to the centroid (``scipy.ndimage.center_of_mass``) of the pixels
    equal to ``track_id``. Headings are then measured between consecutive centroids and
    between the first and last centroid.

    NOTE: centroids are in array-index order (row, col), so each heading is
    ``arctan2(d_col, d_row)``, i.e. measured from the row axis. The average is a plain
    arithmetic mean of angles, not a circular mean. If the track is absent from a frame,
    its centroid is presumably NaN and propagates into the result (TODO confirm whether
    that can occur).

    :param frames: sequence of label images (presumably 2D; TODO confirm for 3D data).
    :param track_id: label value of the object to follow.
    :return: ``(avg_direction, total_direction)`` in radians, or ``(pi, pi)`` when fewer
        than two frames are given.
    """
    centroids = [center_of_mass(frame == track_id) for frame in frames]
    if len(centroids) < 2:
        return np.pi, np.pi

    steps = np.diff(centroids, axis=0)
    if steps.ndim == 1:
        steps = steps.reshape(-1, 2)

    headings = np.arctan2(steps[:, 1], steps[:, 0])
    if len(headings) == 0:
        avg_heading = np.pi
    else:
        avg_heading = np.mean(headings)

    first, last = centroids[0], centroids[-1]
    overall_heading = np.arctan2(last[1] - first[1], last[0] - first[0])

    return avg_heading, overall_heading

def calculate_features(track_info_file, tif_directory, track_ids):
    """Compute hand-crafted per-track features for the child-track classifier.

    Reads the whitespace-separated track table *track_info_file* (columns Track_ID, Start,
    End, Parent; Parent 0 means no parent). For every id in *track_ids* it loads the label
    images ``man_trackNNNN.tif`` from *tif_directory* for the frames that the track, and
    its parent if any, span.

    :param track_info_file: path to the ``man_track.txt``-style track table.
    :param tif_directory: directory containing the ``man_track%04d.tif`` label images.
    :param track_ids: iterable of track ids to featurise. Each must exist in the table,
        otherwise ``IndexError`` is raised.
    :return: dict of equal-length lists keyed by feature name (plus ``Track_ID``), ready
        for ``pd.DataFrame``. Parent-dependent features are NaN (size_change,
        shape_change), pi (parent directions) or 0 for tracks without a parent.

    NOTE(review): some features differ from their descriptions in the notebook text.
    They are kept unchanged because previously saved training data (df.csv) was
    presumably produced with this exact behaviour:
      * ``parent_dist`` is filled with ``closest_dist``; no parent-to-child distance is computed.
      * ``closest_dist`` also considers the background label 0 as a candidate. It is 0
        (not a large value) when the first frame contains a single label value.
      * ``size_change`` divides the child's *average* size, not its first-frame size.
      * ``siblings_count`` includes the track itself.
    """
    # Load the track table. The separator is a regex, so use a raw string: '\s' is an
    # invalid escape in a plain literal (SyntaxWarning on Python 3.12+).
    track_info = pd.read_csv(track_info_file, sep=r'\s+', header=None, names=['Track_ID', 'Start', 'End', 'Parent'], dtype={'Track_ID': int, 'Start': int, 'End': int, 'Parent': int})

    # Column-oriented accumulator: one list per feature, one entry per track
    features_dict = {
        'Track_ID': [],
        'num_frames': [],
        'parent_dist': [],
        'size_avg': [],
        'size_std': [],
        'size_diff': [],
        'size_max': [],
        'size_change': [],
        'shape_change': [],
        'num_obj_nearby_parent': [],
        'num_obj_nearby': [],
        'closest_dist': [],
        'avg_parent_delta_direction': [],
        'total_parent_delta_direction': [],
        'avg_delta_direction': [],
        'total_delta_direction': [],
        'is_parent': [],
        'has_grandparent': [],
        'siblings_count': [],
        'avg_sibling_dist': [],
        'sibling_dist_diff': [],
        'num_parent_frames': [],
        'first_frame_y': [],
    }

    for i in track_ids:

        track_id, start_frame, end_frame, parent_id = track_info.loc[track_info['Track_ID'] == i].values[0]

        # Load the label image of every frame the track spans (End is inclusive)
        track_frames = [tiff.imread(os.path.join(tif_directory, f'man_track{frame:04d}.tif')) for frame in range(start_frame, end_frame + 1)]

        # Calculate the number of frames
        num_frames = len(track_frames)

        # Per-frame pixel counts of the track, and summary statistics over them
        sizes = [np.sum(frame == track_id) for frame in track_frames]
        size_avg = np.mean(sizes)
        size_std = np.std(sizes) if num_frames > 1 else 0
        size_max = np.max(sizes)
        size_diff = abs(np.max(sizes) - np.min(sizes))

        # Size ratio: the child's average size over the parent's size in its last frame
        if parent_id != 0:
            parent_start_frame = track_info[track_info['Track_ID'] == parent_id]['Start'].values[0]
            parent_end_frame = track_info[track_info['Track_ID'] == parent_id]['End'].values[0]
            parent_frame = tiff.imread(os.path.join(tif_directory, f'man_track{parent_end_frame:04d}.tif'))
            parent_size = np.sum(parent_frame == parent_id)
            size_change = size_avg / parent_size if parent_size != 0 else 0
        else:
            size_change = np.nan

        # Shape change: L1 distance between the Hu moments of the child's first-frame mask and
        # the parent's last-frame mask (first external contour of each; zeros if none found)
        contours, _ = cv2.findContours((track_frames[0] == track_id).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        hu_moments = cv2.HuMoments(cv2.moments(contours[0])).flatten() if contours else np.zeros(7)

        if parent_id != 0:
            parent_contours, _ = cv2.findContours((parent_frame == parent_id).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            parent_hu_moments = cv2.HuMoments(cv2.moments(parent_contours[0])).flatten() if parent_contours else np.zeros(7)
            shape_change = np.sum(np.abs(hu_moments - parent_hu_moments))
        else:
            shape_change = np.nan

        # Neighbourhood radius: 15% of the larger dimension of the first frame
        radius = 0.15 * max(track_frames[0].shape)

        def count_nearby_objects(frame, track_id):
            # Count labels (excluding background 0 and the track itself) whose centroid lies
            # within `radius` of the track's centroid.
            centroids = [center_of_mass(frame == obj_id) for obj_id in np.unique(frame) if obj_id != track_id and obj_id != 0]
            track_centroid = center_of_mass(frame == track_id)
            nearby_objects = [centroid for centroid in centroids if distance.euclidean(centroid, track_centroid) <= radius]
            return len(nearby_objects)

        if parent_id != 0:
            num_obj_nearby_parent = count_nearby_objects(parent_frame, parent_id)
        else:
            num_obj_nearby_parent = 0

        num_obj_nearby = count_nearby_objects(track_frames[0], track_id)

        # Distance from the track's first-frame centroid to the nearest other label.
        # NOTE(review): unlike count_nearby_objects, the candidates include background label 0.
        all_other_tracks = np.unique(track_frames[0])
        if len(all_other_tracks) > 1:
            distances = [distance.euclidean(center_of_mass(track_frames[0] == other_id), center_of_mass(track_frames[0] == track_id))
                         for other_id in all_other_tracks if other_id != track_id]
            closest_dist = min(distances)
        else:
            closest_dist = 0

        # Heading features for the parent (pi when there is no parent) and for the track itself
        if parent_id != 0:
            parent_frames = [tiff.imread(os.path.join(tif_directory, f'man_track{frame:04d}.tif')) for frame in range(parent_start_frame, parent_end_frame + 1)]
            avg_parent_delta_direction, total_parent_delta_direction = calculate_direction_change(parent_frames, parent_id)
        else:
            avg_parent_delta_direction = total_parent_delta_direction = np.pi

        avg_delta_direction, total_delta_direction = calculate_direction_change(track_frames, track_id)

        # Check if the current track is a parent
        is_parent = int(track_info['Parent'].isin([track_id]).any())

        # Check if the current track has a grandparent
        has_grandparent = int(parent_id != 0 and track_info.loc[track_info['Track_ID'] == parent_id, 'Parent'].values[0] != 0)

        # Tracks sharing this track's parent (the track itself included); 0 for root tracks
        siblings = track_info[track_info['Parent'] == parent_id]
        siblings_count = len(siblings) if parent_id != 0 else 0

        # Mean and spread (max - min) of the distance to siblings that are present in the
        # same frames as the track
        if siblings_count > 1:
            sibling_distances = []
            for frame in track_frames:
                for sibling_id in siblings['Track_ID']:
                    if sibling_id != track_id and sibling_id in frame:
                        sibling_centroid = center_of_mass(frame == sibling_id)
                        track_centroid = center_of_mass(frame == track_id)
                        sibling_distances.append(distance.euclidean(sibling_centroid, track_centroid))
            avg_sibling_dist = np.mean(sibling_distances) if sibling_distances else 0
            sibling_dist_diff = (np.max(sibling_distances) - np.min(sibling_distances)) if sibling_distances else 0
        else:
            avg_sibling_dist = sibling_dist_diff = 0

        # Calculate the number of parent frames
        num_parent_frames = len(parent_frames) if parent_id != 0 else 0

        # Row (y) coordinate of the track's centroid in its first frame
        first_frame_y = center_of_mass(track_frames[0] == track_id)[0]

        # Append the features to the dictionary
        features_dict['Track_ID'].append(track_id)
        features_dict['num_frames'].append(num_frames)
        # NOTE(review): duplicates closest_dist (see docstring); kept for compatibility with saved data
        features_dict['parent_dist'].append(closest_dist)
        features_dict['size_avg'].append(size_avg)
        features_dict['size_std'].append(size_std)
        features_dict['size_max'].append(size_max)
        features_dict['size_diff'].append(size_diff)
        features_dict['size_change'].append(size_change)
        features_dict['shape_change'].append(shape_change)
        features_dict['num_obj_nearby_parent'].append(num_obj_nearby_parent)
        features_dict['num_obj_nearby'].append(num_obj_nearby)
        features_dict['closest_dist'].append(closest_dist)
        features_dict['avg_parent_delta_direction'].append(avg_parent_delta_direction)
        features_dict['total_parent_delta_direction'].append(total_parent_delta_direction)
        features_dict['avg_delta_direction'].append(avg_delta_direction)
        features_dict['total_delta_direction'].append(total_delta_direction)
        features_dict['is_parent'].append(is_parent)
        features_dict['has_grandparent'].append(has_grandparent)
        features_dict['siblings_count'].append(siblings_count)
        features_dict['avg_sibling_dist'].append(avg_sibling_dist)
        features_dict['sibling_dist_diff'].append(sibling_dist_diff)
        features_dict['num_parent_frames'].append(num_parent_frames)
        features_dict['first_frame_y'].append(first_frame_y)

    return features_dict

# Disabled (kept as a string literal): regenerates the feature table from the tracking
# output, attaches the manual targets and writes postprocessing/df.csv. The next cell
# reads that CSV instead of recomputing features.
'''# Example usage
track_info_file = f'/projects/steiflab/scratch/leli/trackastra/{out_folder}/man_track.txt'
tif_directory = f'/projects/steiflab/scratch/leli/trackastra/{out_folder}'
track_ids = df['Track_ID']

features_dict = calculate_features(track_info_file, tif_directory, track_ids)
features_dict["Target"] = df["Target"]

features_df = pd.DataFrame(features_dict)
features_df.to_csv(f'/projects/steiflab/scratch/leli/trackastra/postprocessing/df.csv' , index = False)

features_df.head(20)'''

'# Example usage\ntrack_info_file = f\'/projects/steiflab/scratch/leli/trackastra/{out_folder}/man_track.txt\'\ntif_directory = f\'/projects/steiflab/scratch/leli/trackastra/{out_folder}\'\ntrack_ids = df[\'Track_ID\']\n\nfeatures_dict = calculate_features(track_info_file, tif_directory, track_ids)\nfeatures_dict["Target"] = df["Target"]\n\nfeatures_df = pd.DataFrame(features_dict)\nfeatures_df.to_csv(f\'/projects/steiflab/scratch/leli/trackastra/postprocessing/df.csv\' , index = False)\n\nfeatures_df.head(20)'

In [None]:
from sklearn.compose import make_column_transformer, ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline

#df = features_df.copy()
# Load the saved feature table and split off the manual target column.
df = pd.read_csv('/projects/steiflab/scratch/leli/trackastra/postprocessing/df.csv')
X = df.drop(columns=['Target'])
y = df["Target"]

numeric_feats = ["num_frames",
                "parent_dist", 
                "size_avg", 
                "size_std", 
                "size_diff",
                "size_max", 
                "size_change", 
                "shape_change", 
                "num_obj_nearby_parent", 
                "num_obj_nearby", 
                "closest_dist", 
                "avg_parent_delta_direction",
                "total_parent_delta_direction", 
                "avg_delta_direction", 
                "total_delta_direction", 
                "siblings_count",
                "avg_sibling_dist",
                "sibling_dist_diff", 
                "num_parent_frames", 
                "first_frame_y",
                ]
categorical_feats = ["is_parent", 
                    "has_grandparent",
                    ]
drop_feats = ["Track_ID",]

# Guard: every column of X must be assigned to exactly one of the lists above.
assert X.shape[1] == len(numeric_feats) + len(categorical_feats) + len(drop_feats)

# Define the transformers for numerical and categorical data
numerical_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('scaler', StandardScaler())
])

categorical_transformer = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='most_frequent')),
    ('onehot', OneHotEncoder(handle_unknown='ignore'))
])

# Tree models: numeric features are passed through unscaled and un-imputed.
# NOTE(review): size_change / shape_change are NaN for tracks without a parent, so this relies
# on the tree estimators accepting NaN input (native missing-value support only exists in
# recent scikit-learn releases) — confirm the installed version.
rf_ct = make_column_transformer(
    ("passthrough", numeric_feats),  # For random forest just keep it as it is
    (categorical_transformer, categorical_feats),  # OHE on categorical features
    ("drop", drop_feats),  # drop the drop features
)

# All other models: median-impute then standardise the numeric features.
ct = make_column_transformer(
    (numerical_transformer, numeric_feats),  # median-impute + standard-scale numeric features
    (categorical_transformer, categorical_feats),  # OHE on categorical features
    ("drop", drop_feats),  # drop the drop features
)

print(f"The random forest ct is: \n {rf_ct}")
print(f"The ct is: \n {ct}")


In [None]:
import pandas as pd
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split, RandomizedSearchCV, cross_validate
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neural_network import MLPClassifier
import lightgbm as lgb
import xgboost as xgb
import catboost as cb
from sklearn.pipeline import Pipeline
from sklearn.metrics import make_scorer, accuracy_score, precision_score, recall_score, f1_score
import joblib
from sklearn.tree import plot_tree

# Output directory for the selected model (presumably written with joblib later in the
# notebook — not visible here).
best_model_dir = '/projects/steiflab/scratch/leli/trackastra/postprocessing/'

# Split the data into training and testing datasets (80/20, not stratified by Target)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Define the models (default hyperparameters; tuned below via param_distributions)
models = {
    'Dummy': DummyClassifier(strategy='most_frequent'),  # majority-class baseline
    'Logistic Regression': LogisticRegression(),
    'Random Forest': RandomForestClassifier(),
    'Decision Tree': DecisionTreeClassifier(),
    'SVM': SVC(),
    'k-NN': KNeighborsClassifier(),
    'Naive Bayes': GaussianNB(),
    'Gradient Boosting': GradientBoostingClassifier(),
    # NOTE(review): `use_label_encoder` is deprecated in newer XGBoost releases and may only
    # trigger an "unused parameter" warning there — confirm against the installed version.
    'XGBoost': xgb.XGBClassifier(use_label_encoder=False, eval_metric='mlogloss'),
    'LightGBM': lgb.LGBMClassifier(),
    'CatBoost': cb.CatBoostClassifier(verbose=0),
}

# Define the scoring metrics used by cross_validate in evaluate_model.
# Precision/recall/F1 are support-weighted over the classes; zero_division=0 scores undefined
# values (e.g. a class that is never predicted) as 0.
scoring = {
    'accuracy': make_scorer(accuracy_score),
    'precision': make_scorer(precision_score, average='weighted', zero_division=0),
    'recall': make_scorer(recall_score, average='weighted', zero_division=0),
    'f1': make_scorer(f1_score, average='weighted', zero_division=0)
}

# Define the models and their hyperparameter search space.
# Keys must match the names in `models`; the 'classifier__' prefix addresses the 'classifier'
# step of the Pipeline built in evaluate_model. 'Dummy' has no entry because evaluate_model
# fits it without a search.
param_distributions = {
    'Logistic Regression': {
        'classifier__C': np.logspace(-4, 4, 50),
        'classifier__solver': ['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'],
        'classifier__max_iter': [200, 500, 800, 1000]
    },
    'Random Forest': {
        'classifier__n_estimators': [50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1500],
        'classifier__max_depth': [None, 10, 20, 30, 40, 50, 60, 70, 100, 150],
    },
    'Decision Tree': {
        'classifier__max_depth': [None, 10, 20, 30, 40, 50],
        'classifier__min_samples_split': [2, 5, 10],
        'classifier__min_samples_leaf': [1, 2, 4],
    },
    'SVM': {
        'classifier__C': np.logspace(-3, 2, 50),
        'classifier__kernel': ['linear', 'poly', 'rbf', 'sigmoid'],
        'classifier__gamma': ['scale', 'auto'] + list(np.logspace(-3, 2, 50))
    },
    'k-NN': {
        'classifier__n_neighbors': list(range(1, 31)),
        'classifier__weights': ['uniform', 'distance'],
        'classifier__metric': ['euclidean', 'manhattan', 'chebyshev', 'minkowski']
    },
    'Naive Bayes': {
        'classifier__var_smoothing': np.logspace(-9, -1, 100)
    },
    'Gradient Boosting': {
        'classifier__n_estimators': [50, 100, 200, 300, 400, 500, 600],
        'classifier__learning_rate': np.logspace(-4, 0, 50),
        'classifier__max_depth': [3, 4, 5, 6, 7, 8, 9, 10]
    },
    'XGBoost': {
        'classifier__n_estimators': [50, 100, 200, 300, 400, 500, 600],
        'classifier__learning_rate': np.logspace(-4, 0, 50),
        'classifier__max_depth': [3, 4, 5, 6, 7, 8, 9, 10],
        'classifier__subsample': [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    },
    # NOTE(review): evaluate_model skips the random search for 'LightGBM', so this grid is
    # currently unused. If re-enabled, early_stopping_rounds needs a validation set, and
    # max_iter appears to be a LightGBM alias of n_estimators — confirm before use.
    'LightGBM': {
        'classifier__n_estimators': [50, 100, 200],  # Reduced range
        'classifier__learning_rate': np.logspace(-3, 0, 5),  # Reduced from 50 to 5
        'classifier__num_leaves': [20, 31, 50],  # Reduced range
        'classifier__min_gain_to_split': [0, 0.1, 0.5],  # Reduced range
        'classifier__max_depth': [10, 20, 30],  # Reduced range
        'classifier__max_bin': [255, 128],  # Reduced range
        'classifier__early_stopping_rounds': [10],  # Focused on a single value
        'classifier__max_iter': [100, 200, 300]  # Reduced range for quicker searches
    }, 
    'CatBoost': {
        'classifier__iterations': [50, 100, 200],
        'classifier__depth': [4, 5, 6, 7, 8, 9, 10],
        'classifier__learning_rate': np.logspace(-4, 0, 50)
    },
}

# Function to evaluate the model and store the results
def evaluate_model(name, model, X_train, y_train, X_test, y_test, preprocessor, cv=5, n_iter=20):
    """
    Cross-validate a classifier, optionally tune it, and score it on a held-out set.

    The classifier is wrapped in ``Pipeline(preprocessor -> classifier)`` and
    its cross-validation results (using the global ``scoring`` defined
    elsewhere in the notebook) are printed. ``"Dummy"`` and ``"LightGBM"``
    are then fit as-is; every other model is tuned with
    ``RandomizedSearchCV`` over ``param_distributions[name]``.

    Parameters:
    - name: str, model name; also the key into ``param_distributions``.
    - model: unfitted sklearn-compatible classifier.
    - X_train, y_train: training features and labels.
    - X_test, y_test: held-out features and labels used for the final scores.
    - preprocessor: transformer placed in front of the classifier.
    - cv: int, number of cross-validation folds (default 5).
    - n_iter: int, number of random-search candidates (default 20).

    Returns:
    - tuple (name, accuracy, precision, recall, f1, best_model); precision,
      recall and F1 are support-weighted averages, ``best_model`` is the
      fitted pipeline.
    """
    model_pipeline = Pipeline(steps=[('preprocessor', preprocessor), ('classifier', model)])
    cv_results = cross_validate(model_pipeline, X_train, y_train, cv=cv, scoring=scoring)

    print(f"\nCross-validation results for {name}:")
    print(pd.DataFrame(cv_results))

    if name in ("Dummy", "LightGBM"):
        # These models are not tuned: fit the pipeline with its current parameters.
        model_pipeline.fit(X_train, y_train)
        best_model = model_pipeline
    else:
        # Perform random search for hyperparameter tuning
        random_search = RandomizedSearchCV(
            model_pipeline,
            param_distributions=param_distributions[name],
            n_iter=n_iter,
            cv=cv,
            verbose=1,
            random_state=123,
            n_jobs=-1,
        )
        random_search.fit(X_train, y_train)
        best_model = random_search.best_estimator_

    # Score every model the same way. zero_division=0 is now applied to the tuned
    # models too (the default, "warn", also yields 0 but emits a warning).
    y_pred = best_model.predict(X_test)
    test_accuracy = accuracy_score(y_test, y_pred)
    test_precision = precision_score(y_test, y_pred, average='weighted', zero_division=0)
    test_recall = recall_score(y_test, y_pred, average='weighted', zero_division=0)
    test_f1 = f1_score(y_test, y_pred, average='weighted', zero_division=0)

    return name, test_accuracy, test_precision, test_recall, test_f1, best_model

# Evaluate each model and store the results
test_performance = {
    "model": [],
    'accuracy': [],
    'precision': [],
    'recall': [],
    'f1': [],
}
test_model = {}

# joblib.dump does not create missing folders, so make sure the models/
# sub-directory exists before the loop writes into it.
os.makedirs(os.path.join(best_model_dir, "models"), exist_ok=True)

for name, model in models.items():
    # Tree models use their own column transformer (rf_ct); all others use ct.
    preprocessor = rf_ct if name in ('Decision Tree', "Random Forest") else ct
    name, test_accuracy, test_precision, test_recall, test_f1, best_model = evaluate_model(name, model, X_train, y_train, X_test, y_test, preprocessor = preprocessor)
    test_performance['model'].append(name)
    test_performance['accuracy'].append(test_accuracy)
    test_performance['precision'].append(test_precision)
    test_performance['recall'].append(test_recall)
    test_performance['f1'].append(test_f1)

    test_model[name] = best_model

    # save the optimized model
    model_path = os.path.join(best_model_dir, "models", f"{name}_model.pkl")
    joblib.dump(best_model, model_path)
    print(f"Model '{name}' saved to '{model_path}'")

    # The CSV is rewritten after every model, so it always holds the results
    # gathered so far.
    pd.DataFrame(test_performance).to_csv(os.path.join(best_model_dir, "test_performance.csv"), index = False)
    print("saved the test perpformance")


test_df = pd.DataFrame(test_performance)

print("Best model based on accuracy:", test_df.loc[test_df['accuracy'].idxmax()]['model'])
print("Best model based on precision:",  test_df.loc[test_df['precision'].idxmax()]['model'])
print("Best model based on recall:",  test_df.loc[test_df['recall'].idxmax()]['model'])
print("Best model based on F1 score:",  test_df.loc[test_df['f1'].idxmax()]['model'])




### At this point the models are trained and saved. The next step is to apply their predictions to the TIFF files.

The best-performing models to try are SVM and, perhaps, CatBoost and Gradient Boosting.

## Function: gen_pp_res()

In [7]:
import os
import shutil
import pandas as pd
import numpy as np
from skimage import io
from skimage.measure import label, regionprops
from scipy.ndimage import label
import joblib
import tifffile as tiff
from skimage.morphology import binary_erosion
from scipy.ndimage import distance_transform_edt
from skimage.segmentation import watershed

def connect_objects_localized(mask1, mask2, kernel_size=5, iterations=1):
    """
    Combine two masks and erode the components that touch both of them.

    The masks are merged with ``np.maximum``. The merged foreground is split
    into components (``label`` seeds a watershed on the negated distance
    transform). Every component that overlaps foreground in *both* ``mask1``
    and ``mask2`` is marked. The marked region is morphologically eroded with
    a ``kernel_size`` x ``kernel_size`` square footprint. Pixels surviving the
    erosion are set to 255; all other pixels keep their merged-mask value.

    Parameters:
    - mask1: numpy array, first mask (any nonzero pixel is foreground).
    - mask2: numpy array, second mask, same shape as ``mask1``.
    - kernel_size: int, side length of the square erosion footprint.
    - iterations: int, number of erosion passes (values < 1 behave like 1).

    Returns:
    - connected_objects: numpy array, the merged mask with the eroded
      connecting components painted as 255.
    """
    # Boolean foreground: a bitwise `&` against raw label values silently
    # drops even-valued labels (True & 2 == 0).
    fg1 = np.asarray(mask1) > 0
    fg2 = np.asarray(mask2) > 0

    combined_mask = np.maximum(mask1, mask2)
    binary_mask = (combined_mask > 0)
    distance = distance_transform_edt(binary_mask)
    markers, _ = label(binary_mask)
    labels = watershed(-distance, markers, mask=binary_mask)

    connection_mask = np.zeros(binary_mask.shape, dtype=bool)
    for label_val in np.unique(labels):
        if label_val == 0:
            continue
        component_mask = (labels == label_val)
        if np.any(component_mask & fg1) and np.any(component_mask & fg2):
            connection_mask[component_mask] = True

    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    eroded = connection_mask
    for _ in range(max(1, iterations)):
        eroded = binary_erosion(eroded, kernel)
    eroded_connection = eroded.astype(np.uint8) * 255

    connected_objects = np.where(eroded_connection > 0, eroded_connection, combined_mask)
    return connected_objects

def get_centroid_y(mask, label_value):
    """
    Return one centroid coordinate of the region labelled ``label_value``.

    The value is ``centroid[1]`` from ``regionprops``, i.e. the coordinate
    along the mask's second axis (the column index of a 2-D array).
    NOTE(review): for row-major images that axis is usually called x, despite
    the function name; the index was kept deliberately by the original author.

    Returns ``None`` when no region carries ``label_value``.
    """
    matching = (region for region in regionprops(mask) if region.label == label_value)
    found = next(matching, None)
    return None if found is None else found.centroid[1]

def update_track_info_across_frame(old_track_info_df, track_info_df, frame, frame_number):
    """
    Fold the labels present in one frame into a track-info table.

    For every nonzero label in ``frame``:
    - if it already has a row in ``track_info_df``, that row's ``Start``/``End``
      are widened so the span includes ``frame_number``;
    - otherwise a row ``(label, frame_number, frame_number, parent)`` is
      appended. ``parent`` is copied from ``old_track_info_df`` when the label
      is listed there, and is 0 otherwise.

    Existing rows are updated in place on the DataFrame passed in until the
    first row is appended; from then on the concatenated copy is updated.
    The resulting DataFrame is returned.
    """
    present = np.unique(frame)
    for tid in present[present != 0]:
        hits = track_info_df['Track_ID'] == tid
        matched = track_info_df[hits]

        if matched.empty:
            if tid in old_track_info_df['Track_ID'].values:
                parent = old_track_info_df.loc[old_track_info_df['Track_ID'] == tid, 'Parent'].values[0]
            else:
                parent = 0
            addition = pd.DataFrame({
                'Track_ID': [tid],
                'Start': [frame_number],
                'End': [frame_number],
                'Parent': [parent],
            })
            track_info_df = pd.concat([track_info_df, addition], ignore_index=True)
            continue

        first_seen = min(int(matched['Start'].values[0]), frame_number)
        last_seen = max(int(matched['End'].values[0]), frame_number)
        track_info_df.loc[hits, 'Start'] = first_seen
        track_info_df.loc[hits, 'End'] = last_seen

    return track_info_df

import os
import numpy as np
import tifffile as tiff
import matplotlib.pyplot as plt
import cv2

def color_labels(tif, labels, colors):
    """
    Paint selected labels of a label image with solid colours.

    The image is cast to ``uint8`` and expanded to three identical channels
    with ``cv2.COLOR_GRAY2BGR`` to form the background. Each label in
    ``labels`` is then filled with the colour at the same position in
    ``colors``. Pairing stops at the shorter sequence, so labels without a
    colour stay grey.

    Parameters:
    - tif: numpy array, the label image (presumably 2-D).
    - labels: iterable of int, the labels to paint.
    - colors: iterable of 3-tuples, one colour per label.

    Returns:
    - numpy array, the three-channel image with painted labels.
    """
    canvas = cv2.cvtColor(tif.astype(np.uint8), cv2.COLOR_GRAY2BGR)
    for target, shade in zip(labels, colors):
        canvas[tif == target] = shade
    return canvas

def display_colored_images(frame, labels_to_color, title):
    """
    Show ``frame`` in an 8x8-inch figure with ``labels_to_color`` painted.

    Colours are assigned in order: red, green, blue, yellow, magenta, orange,
    purple. ``color_labels`` pairs labels with colours, so any labels past
    the seventh stay grey.
    """
    plt.figure(figsize=(8, 8))
    palette = [
        (255, 0, 0),      # Red
        (0, 255, 0),      # Green
        (0, 0, 255),      # Blue
        (255, 255, 0),    # Yellow
        (255, 0, 255),    # Magenta
        (255, 165, 0),    # Orange
        (128, 0, 128),    # Purple
    ]
    painted = color_labels(frame, labels_to_color, palette)
    plt.imshow(painted)
    plt.title(title)
    plt.tight_layout()
    plt.show()

def postprocess_frame(frame, track_info, classification, min_size=100):
    """
    Apply per-track actions to one labelled frame, in place.

    Tracks are grouped by ``track_info['Root']``. A group is processed only if
    at least one of its track ids appears in ``frame``. Within a group, the
    root itself and any track whose ``Parent`` is 0 are skipped. Every other
    track is handled as follows:

    - fewer than ``min_size`` pixels in this frame (including 0, i.e. absent):
      its pixels are set to 0;
    - ``classification.loc[track_id, 'action'] == 1``: kept unchanged;
    - action 2: its pixels are set to 0;
    - action 0: queued for merging;
    - any other action value: left as is.

    Queued objects are ordered by ``get_centroid_y`` and unioned one at a
    time. Each union is morphologically closed with square kernels of side
    0, 5, 10 and 15 px, stopping once it forms a single 8-connected
    component. The original masks are then OR-ed back in, and the final
    union is painted with the root's label.

    Parameters:
    - frame: integer label image; modified in place.
    - track_info: DataFrame with at least Track_ID, Parent and Root columns.
    - classification: DataFrame indexed by Track_ID with an ``action``
      column (assumed from the ``.loc[track_id, 'action']`` lookup).
    - min_size: int, minimum pixel count for a sub-track object to be kept.

    Returns:
    - the modified ``frame``.
    """
    groups = track_info.groupby('Root')['Track_ID'].apply(list).to_dict()
    unique_labels = np.unique(frame)
    unique_labels = unique_labels[unique_labels != 0]
    present = set(unique_labels.tolist())

    # Keep only lineages with at least one member visible in this frame.
    groups = {root: tracks for root, tracks in groups.items() if any(track in present for track in tracks)}

    for root, group in groups.items():
        print(f"This group has root {root} and contains {group}")

        objs_to_be_merged = []

        for track_id in group:
            track_row = track_info.loc[track_info['Track_ID'] == track_id]
            print(f"iteration: {track_id} where it is \n {track_row}")
            # Root tracks and parentless tracks are never modified.
            if track_id == root or track_row['Parent'].values[0] == 0:
                continue
            print(f"iteration: {track_id}")
            mask = (frame == track_id)
            print(f"This objects {track_id} has a size of {np.sum(mask)}")

            # Remove small objects (tracks absent from this frame have an empty mask and land here too)
            if np.sum(mask) < min_size:
                frame[mask] = 0
                print(f"Object with track id: {track_id} is too small so we remove all together")
                continue

            action = classification.loc[track_id, 'action']

            if action == 1:
                print(f"Object with track id: {track_id} has action 1")
                continue  # Keep as is

            elif action == 2:
                print(f"Object with track id: {track_id} has action 2")
                frame[mask] = 0  # Remove

            elif action == 0:
                print(f"Object with track id: {track_id} has action 0")
                centroid_y = get_centroid_y(frame, track_id)
                objs_to_be_merged.append((track_id, centroid_y))

        if objs_to_be_merged:
            objs_to_be_merged.sort(key=lambda x: x[1])
            print(f"These are the objects that will be merged {objs_to_be_merged}")

            combined_mask = (frame == objs_to_be_merged[0][0])
            for next_id, _ in objs_to_be_merged[1:]:
                next_mask = (frame == next_id)

                # morphologyEx also alters pixels of the original masks (not only the
                # gap between them), so keep copies and OR them back in afterwards.
                old_mask1 = combined_mask.copy()
                old_mask2 = next_mask.copy()

                # Grow the closing kernel (0, 5, 10, 15 px) until the union is one component.
                # NOTE(review): kernel_size 0 passes an empty kernel; OpenCV documents an
                # empty element as a default 3x3 one for erode/dilate — confirm for morphologyEx.
                for kernel_size in range(4):
                    temp = np.logical_or(combined_mask, next_mask).astype(np.uint8)
                    combined_mask = cv2.morphologyEx(temp, cv2.MORPH_CLOSE, np.ones((kernel_size * 5, kernel_size * 5), np.uint8))
                    num_labels, _, _, _ = cv2.connectedComponentsWithStats(combined_mask, connectivity=8)
                    if num_labels - 1 == 1:
                        break

                combined_mask = np.logical_or(combined_mask, old_mask1).astype(np.uint8)
                combined_mask = np.logical_or(combined_mask, old_mask2).astype(np.uint8)

                num_labels, _, _, _ = cv2.connectedComponentsWithStats(combined_mask, connectivity=8)
                if num_labels - 1 > 1:
                    print("Failed to connect objects after multiple attempts.")
                else:
                    print(f"Succeeded to connect objects with kernel size {kernel_size}")

            # do the final painting of pixels
            frame[combined_mask > 0] = root

    return frame

import pandas as pd
from sklearn.linear_model import LinearRegression

def predict_next_centroids(centroids, centroids_frame_with_prediction, predict_this_frame):
    """
    Extrapolate a centroid to a given frame using per-axis linear fits.

    Two independent least-squares lines (coordinate versus frame number) are
    fitted with ``LinearRegression``: one for the first coordinate (``y``) and
    one for the second (``x``). Both are evaluated at ``predict_this_frame``.

    Parameters:
    - centroids: sequence of (y, x) pairs, one per observed frame.
    - centroids_frame_with_prediction: sequence of frame numbers aligned with
      ``centroids``.
    - predict_this_frame: number, the frame at which the fits are evaluated.

    Returns:
    - tuple (pred_y, pred_x), the predicted coordinates.
    """
    table = pd.DataFrame(centroids, columns=['y', 'x'])
    table['frame'] = centroids_frame_with_prediction

    frames = table['frame'].values.reshape(-1, 1)
    query = np.array([predict_this_frame]).reshape(-1, 1)

    predictions = []
    for axis in ('y', 'x'):
        regressor = LinearRegression()
        regressor.fit(frames, table[axis].values)
        predictions.append(regressor.predict(query)[0])

    pred_y, pred_x = predictions
    return pred_y, pred_x

def remove_track(track_info_file, target_tif_dir, to_be_removed_id):
    """
    Delete a track from the label TIFFs and from the track-info file.

    Pixels equal to ``to_be_removed_id`` are set to 0 in every existing
    ``man_track{frame:04d}.tif`` between the track's Start and End frames
    (inclusive), and each such file is rewritten. The track's row is then
    dropped, and ``track_info_file`` is overwritten in the same
    space-separated, header-less format.

    Parameters:
    - track_info_file: str, whitespace-separated file with columns
      Track_ID, Start, End, Parent.
    - target_tif_dir: str, directory holding the ``man_track*.tif`` frames.
    - to_be_removed_id: int, the Track_ID to delete.

    Returns:
    - pandas.DataFrame, the track info without the removed track.

    Raises:
    - IndexError if ``to_be_removed_id`` is not in the track file.
    """
    # Raw string: '\s' in a plain literal is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    track_info = pd.read_csv(track_info_file, sep=r'\s+', names=['Track_ID', 'Start', 'End', 'Parent'])
    _, start_frame, end_frame, _ = track_info.loc[track_info['Track_ID'] == to_be_removed_id].values[0]

    for frame_number in range(start_frame, end_frame + 1):
        frame_path = os.path.join(target_tif_dir, f'man_track{frame_number:04d}.tif')
        if os.path.exists(frame_path):
            frame = tiff.imread(frame_path)
            frame[frame == to_be_removed_id] = 0
            tiff.imwrite(frame_path, frame)
            print(f"Removed track {to_be_removed_id} from frame {frame_number}")

    new_track_info = track_info[track_info['Track_ID'] != to_be_removed_id]
    new_track_info.to_csv(track_info_file, sep=' ', index=False, header=False)

    return new_track_info

def diverge_track(track_info_file, target_tif_dir, to_be_split_id, new_id, diverging_start_frame):
    """
    Split a track in two at ``diverging_start_frame``.

    In every existing ``man_track{frame:04d}.tif`` from
    ``diverging_start_frame`` to the track's End frame (inclusive), pixels of
    ``to_be_split_id`` are relabelled to ``new_id``. A row
    ``(new_id, diverging_start_frame, End, to_be_split_id)`` is added, so the
    new track is a child of the original. The original track's End becomes
    ``diverging_start_frame - 1``. ``track_info_file`` is overwritten
    (space-separated, no header).

    Parameters:
    - track_info_file: str, whitespace-separated file with columns
      Track_ID, Start, End, Parent.
    - target_tif_dir: str, directory holding the ``man_track*.tif`` frames.
    - to_be_split_id: int, the track to split.
    - new_id: int, the id given to the second part.
    - diverging_start_frame: int, first frame of the new track; must satisfy
      ``Start < diverging_start_frame <= End``.

    Returns:
    - pandas.DataFrame, the updated track info.

    Raises:
    - ValueError if ``diverging_start_frame`` is outside that range.
    - IndexError if ``to_be_split_id`` is not in the track file.
    """
    # Raw string avoids the invalid '\s' escape-sequence warning.
    track_info = pd.read_csv(track_info_file, sep=r'\s+', names=['Track_ID', 'Start', 'End', 'Parent'])
    _, start_frame, end_frame, _ = track_info.loc[track_info['Track_ID'] == to_be_split_id].values[0]

    if diverging_start_frame <= start_frame or diverging_start_frame > end_frame:
        raise ValueError("Diverging start frame must be within the original track's range.")

    # Iterate through the frames and update the track
    for frame_number in range(diverging_start_frame, end_frame + 1):
        frame_path = os.path.join(target_tif_dir, f'man_track{frame_number:04d}.tif')
        if os.path.exists(frame_path):
            frame = tiff.imread(frame_path)
            frame[frame == to_be_split_id] = new_id
            tiff.imwrite(frame_path, frame)
            print(f"Diverge frame {frame_number}: track {to_be_split_id} -> {new_id}")

    # Create a new row for the new track
    new_track_row = pd.DataFrame({
        'Track_ID': [new_id],
        'Start': [diverging_start_frame],
        'End': [end_frame],
        'Parent': [to_be_split_id]
    })

    # pd.concat is the public replacement for DataFrame.append (removed in
    # pandas 2.0); DataFrame._append is private API and may change.
    new_track_info = pd.concat([track_info, new_track_row], ignore_index=True)
    temp = new_track_info['Track_ID'].values
    print(f"Added {new_id} in the track info so now {new_id} is in new_track_info {new_id in temp}")

    # Update the end frame of the original track
    new_track_info.loc[new_track_info['Track_ID'] == to_be_split_id, 'End'] = diverging_start_frame - 1

    new_track_info.to_csv(track_info_file, sep=' ', index=False, header=False)

    return new_track_info

def merge_track(track_info_file, target_tif_dir, to_be_merged_ids):
    """
    Merge two tracks, absorbing the later-starting one into the earlier one.

    Of the two ids in ``to_be_merged_ids``, the one with the strictly earlier
    Start becomes *alpha* (kept); the other is *beta*. On equal starts the
    second id is alpha.

    1. In existing frames ``alpha_end + 1 .. beta_end`` every beta pixel is
       relabelled to alpha, and beta's pixel count per frame is recorded.
    2. A lower size bound ``median - 2 * MAD`` of those counts is computed
       (no bound when there were no such frames).
    3. In existing overlap frames ``beta_start .. alpha_end``, beta objects
       below the bound are erased and the others are relabelled to alpha.
    4. Alpha's End becomes ``max(alpha_end, beta_end)``. Beta's row is
       dropped when no frames remain for it. Otherwise its span becomes
       ``(last processed overlap frame + 1) .. alpha_end``.

    ``track_info_file`` is overwritten (space-separated, no header).
    NOTE(review): rows whose Parent is beta are not re-parented when beta is
    dropped — confirm whether that is acceptable downstream.

    Parameters:
    - track_info_file: str, whitespace-separated file with columns
      Track_ID, Start, End, Parent.
    - target_tif_dir: str, directory holding the ``man_track*.tif`` frames.
    - to_be_merged_ids: pair of track ids.

    Returns:
    - pandas.DataFrame, the updated track info.
    """
    # Raw string avoids the invalid '\s' escape-sequence warning.
    track_info = pd.read_csv(track_info_file, sep=r'\s+', names=['Track_ID', 'Start', 'End', 'Parent'])
    new_track_info = track_info.copy()

    # Determine alpha and beta tracks based on start frames
    first_id, second_id = to_be_merged_ids
    first_start = track_info.loc[track_info['Track_ID'] == first_id, 'Start'].values[0]
    second_start = track_info.loc[track_info['Track_ID'] == second_id, 'Start'].values[0]
    alpha_id, beta_id = (first_id, second_id) if first_start < second_start else (second_id, first_id)

    alpha_id, alpha_start_frame, alpha_end_frame, _ = track_info.loc[track_info['Track_ID'] == alpha_id].values[0]
    beta_id, beta_start_frame, beta_end_frame, _ = track_info.loc[track_info['Track_ID'] == beta_id].values[0]

    # For frames from alpha_end_frame + 1 to beta_end_frame, all beta track objects are relabeled to alpha.
    sizes_after_alpha_end = []
    for frame_number in range(alpha_end_frame + 1, beta_end_frame + 1):
        frame_path = os.path.join(target_tif_dir, f'man_track{frame_number:04d}.tif')
        if os.path.exists(frame_path):
            frame = tiff.imread(frame_path)
            beta_mask = (frame == beta_id).astype(np.uint8)
            sizes_after_alpha_end.append(np.sum(beta_mask))
            frame[beta_mask > 0] = alpha_id  # Relabel beta to alpha
            tiff.imwrite(frame_path, frame)
            print(f"Merge frame {frame_number}: track {beta_id} -> {alpha_id}")

    if sizes_after_alpha_end:
        sizes = np.asarray(sizes_after_alpha_end, dtype=float)
        median = np.median(sizes)
        mad = np.median(np.abs(sizes - median))
        lower_bound = median - 2 * mad
    else:
        # No post-alpha frames: the old NaN bound made every `size < bound`
        # test False; -inf keeps that (nothing erased) without the
        # empty-slice RuntimeWarning.
        lower_bound = -np.inf

    # assume the beta starts at the same frame but ends at (include) alpha_end_frame
    new_beta_start_frame = int(beta_start_frame)
    new_beta_end_frame = int(alpha_end_frame)

    # Overlap frames (empty range when beta starts after alpha ends).
    for frame_number in range(beta_start_frame, alpha_end_frame + 1):
        frame_path = os.path.join(target_tif_dir, f'man_track{frame_number:04d}.tif')
        if os.path.exists(frame_path):
            frame = tiff.imread(frame_path)
            beta_mask = (frame == beta_id).astype(np.uint8)
            size_beta = np.sum(beta_mask)
            if size_beta < lower_bound:
                frame[beta_mask > 0] = 0  # Remove small objects
                print(f"Removed small object in frame {frame_number} with size {size_beta}")
            else:
                frame[beta_mask > 0] = alpha_id  # Merge beta into alpha
                print(f"Merged frame {frame_number}: track {beta_id} -> {alpha_id} (This is before alpha end frame)")
            new_beta_start_frame = frame_number + 1
            tiff.imwrite(frame_path, frame)

    # Update track_info DataFrame
    new_track_info.loc[new_track_info['Track_ID'] == alpha_id, 'End'] = max(alpha_end_frame, beta_end_frame)
    if new_beta_start_frame > new_beta_end_frame:
        print(f"Remove track ID: {beta_id}")
        new_track_info = new_track_info[new_track_info['Track_ID'] != beta_id]
    else:
        print(f"Putting in new start and end frames {new_beta_start_frame}, {new_beta_end_frame}")
        new_track_info.loc[new_track_info['Track_ID'] == beta_id, 'Start'] = new_beta_start_frame
        new_track_info.loc[new_track_info['Track_ID'] == beta_id, 'End'] = new_beta_end_frame

    new_track_info.to_csv(track_info_file, sep=' ', index=False, header=False)

    temp = new_track_info["Track_ID"].values
    print(f"confirmation that the new track info file do not have beta id {beta_id} {beta_id not in temp}")

    return new_track_info

import numpy as np
from scipy.spatial.distance import cdist

def maj_object_within_radius(frame, point, radius):
    """
    Return the label owning the most pixels within ``radius`` of ``point``.

    Parameters:
    - frame: numpy array, label image where 0 is background.
    - point: (row, col) pair, i.e. in the same axis order as
      ``np.where(frame != 0)``. It is not (x, y).
    - radius: float, inclusive Euclidean radius in pixels.

    Returns:
    - the nonzero label with the largest pixel count inside the disc, or 0
      if the disc contains no foreground pixel. Ties go to the smallest label.
    """
    object_coords = np.column_stack(np.nonzero(frame))
    if object_coords.shape[0] == 0:
        return 0
    # Same row-major order as object_coords, so the two arrays stay aligned.
    object_labels = frame[frame != 0]

    distances = cdist([point], object_coords, metric='euclidean')[0]
    labels_within_radius = object_labels[distances <= radius]
    if labels_within_radius.size == 0:
        return 0

    # np.unique returns sorted labels and argmax takes the first maximum,
    # so ties resolve to the smallest label.
    unique_labels, counts = np.unique(labels_within_radius, return_counts=True)
    return unique_labels[np.argmax(counts)]



## Function track_again() 

This step performs tracking again, but on the filtered segmentations.

Scenarios to filter out:
    Objects that do not move

In [8]:
import pandas as pd
import os
import tifffile as tiff
def load_tif_masks_from_directory(directory, img_shape, small_non_moving_tracks):
    """
    Load label TIFFs as a stack, dropping excluded tracks and renumbering the rest.

    Files are read in sorted filename order. AppleDouble files (``._*``) and
    files not ending in ``.tif`` are skipped. In each frame, labels listed in
    ``small_non_moving_tracks`` become 0. Every other nonzero label gets a new
    id from a counter that starts at 1 and keeps increasing across frames,
    so each object in the stack has a distinct id.

    Parameters:
    - directory: str, folder containing the label TIFFs.
    - img_shape: unused; kept for interface compatibility.
    - small_non_moving_tracks: iterable of track ids to drop.

    Returns:
    - numpy array stacking the relabelled frames (first axis = file order).
      NOTE(review): frames keep their TIFF dtype, so the counter must stay
      below that dtype's maximum.
    """
    excluded = set(small_non_moving_tracks)
    masks = []
    current_object_index = 1
    for file in sorted(os.listdir(directory)):
        if file.startswith("._") or not file.endswith(".tif"):
            continue
        tif = tiff.imread(os.path.join(directory, file))

        # Build the output from the untouched frame. Relabelling in place let
        # a freshly assigned id collide with a label not yet processed, which
        # merged two objects into one.
        relabelled = np.zeros_like(tif)
        for label in np.unique(tif):  # np.unique is already sorted
            if label == 0 or label in excluded:
                continue
            relabelled[tif == label] = current_object_index
            current_object_index += 1

        masks.append(relabelled)

    return np.array(masks)

def calculate_movement_and_size(tif_directory, track_info):
    """
    Compute a robust per-frame displacement and a median size for every track.

    For each row of ``track_info`` (columns Track_ID, Start, End), the
    existing frames ``man_track{n:04d}.tif`` from Start to End (inclusive)
    are read. In frames where the track has pixels, its centroid (mean pixel
    coordinate) and pixel count are recorded. The Euclidean distance to the
    centroid from the previous frame in which *this* track was found is also
    collected.

    With more than 5 distances, values at or above ``Q3 + 1.5 * IQR`` are
    dropped. The track's movement is the median of the remaining distances
    and its size is the median pixel count; each is 0 when nothing was
    measured.

    Parameters:
    - tif_directory: str, directory holding the ``man_track*.tif`` frames.
    - track_info: DataFrame with Track_ID, Start and End columns.

    Returns:
    - (track_movements, track_sizes): two dicts keyed by Track_ID.
    """
    track_movements = {}
    track_sizes = {}

    for _, row in track_info.iterrows():
        track_id = row['Track_ID']
        start_frame = int(row['Start'])
        end_frame = int(row['End'])

        movement = []
        size = []
        # Reset per track. Previously a track missing from its own Start frame
        # measured its first step against the preceding track's last centroid
        # (or hit an undefined name on the very first track).
        prev_centroid = None

        for frame_num in range(start_frame, end_frame + 1):
            frame_file = os.path.join(tif_directory, f"man_track{frame_num:04d}.tif")
            if not os.path.exists(frame_file):
                continue
            frame = tiff.imread(frame_file)

            # Pixels whose value equals the track id belong to the track.
            track_pixels = np.argwhere(frame == track_id)
            if len(track_pixels) == 0:
                continue
            centroid = np.mean(track_pixels, axis=0)
            size.append(len(track_pixels))
            if prev_centroid is not None:
                movement.append(np.linalg.norm(centroid - prev_centroid))
            prev_centroid = centroid

        # calc the median distance travelled, excluding large outlier jumps
        filtered_movement = list(movement)
        if len(filtered_movement) > 5:
            q1, q3 = np.percentile(filtered_movement, [25, 75])
            upper_bound = q3 + 1.5 * (q3 - q1)
            filtered_movement = [m for m in filtered_movement if m < upper_bound]

        median_movement = np.median(filtered_movement) if filtered_movement else 0
        median_size = np.median(size) if size else 0
        track_movements[track_id] = median_movement
        track_sizes[track_id] = median_size

        print(f"The track {track_id} has movement: {median_movement} with size {median_size}")

    return track_movements, track_sizes

def find_small_non_moving_tracks(track_info_file, tif_directory, movement_threshold=2.0, size_threshold=173):
    """
    Find tracks whose typical frame-to-frame movement is at most a threshold.

    NOTE: despite the name, only the movement criterion is applied; the size
    criterion is currently disabled, and ``size_threshold`` is accepted for
    compatibility but ignored.

    Parameters:
    - track_info_file: str, path to the whitespace-separated track info file
      (columns Track_ID, Start, End, Parent).
    - tif_directory: str, path to the directory containing the TIFF frames.
    - movement_threshold: float, tracks whose median movement (as computed
      by ``calculate_movement_and_size``) is <= this value are returned.
    - size_threshold: int, unused (see note above).

    Returns:
    - sorted list of the matching track IDs.
    """
    # Raw string avoids the invalid '\s' escape-sequence warning.
    track_info = pd.read_csv(track_info_file, sep=r'\s+', names=['Track_ID', 'Start', 'End', 'Parent'])
    track_movements, _track_sizes = calculate_movement_and_size(tif_directory, track_info)

    # To also require small size, add `and _track_sizes[track_id] <= size_threshold`.
    return sorted(
        track_id
        for track_id, movement in track_movements.items()
        if movement <= movement_threshold
    )


## run track again()

In [52]:
# Inputs of this second tracking pass come from the first pass's "tracked" folder.
# NOTE(review): chip, run, main_img_directory, load_images_from_directory,
# remove_empty_frame and device are presumably defined in earlier cells.
out_folder = f'{chip}/{run}/tracked'
track_info_file = f'/projects/steiflab/scratch/leli/trackastra/{out_folder}/man_track.txt'
tif_directory = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}"
linkfile_directory = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}/tif_to_img.csv"
# Output folder of this pass. It is a relative path, so graph_to_ctc below writes
# relative to the current working directory, while the next cell reads from
# /projects/steiflab/scratch/leli/trackastra/{out_folder} — presumably the notebook
# runs from that directory; confirm.
out_folder = f'{chip}/{run}/tracked_again'

# get the track ids that should be removed (median movement <= threshold)
small_non_moving_tracks = find_small_non_moving_tracks(track_info_file, tif_directory)
print(f"the small no moving objects are : {small_non_moving_tracks}")

# Load images, in the order listed by the first pass's TIFF-to-image link file
linkfile = pd.read_csv(linkfile_directory)
imgs, img_names = load_images_from_directory(main_img_directory, list(linkfile['imgs']))

# Load masks: tracks in small_non_moving_tracks are erased, the rest renumbered
masks = load_tif_masks_from_directory(tif_directory, imgs[0].shape, small_non_moving_tracks)

print("Images shape:", imgs.shape)
print("Masks shape:", masks.shape)

# Ensure the shape matches the required format: (time, y, x)
imgs = imgs.reshape(-1, imgs.shape[1], imgs.shape[2])
masks = masks.reshape(-1, masks.shape[1], masks.shape[2])

# Drop frames from both stacks (presumably frames with no objects left).
# ind_to_remove is kept so the image names can be aligned further down.
imgs, masks, ind_to_remove = remove_empty_frame(imgs, masks)

print("Images shape:", imgs.shape)
print("Masks shape:", masks.shape)


# Load a pretrained model
# or from a local folder
# model = Trackastra.from_folder('path/my_model_folder/', device=device)
model = Trackastra.from_pretrained("general_2d", device=device)

# Track the cells
track_graph = model.track(imgs, masks, mode="greedy")  # or mode="ilp", or "greedy_nodiv"

# Write to cell tracking challenge format (man_track*.tif + man_track.txt in out_folder)
ctc_tracks, masks_tracked = graph_to_ctc(
    track_graph,
    masks,
    outdir=out_folder,
)

## create a file that connects tiffs with the images
# Sorted TIFF names are paired positionally with the image names that survived
# remove_empty_frame; both lists must have the same length.
tifs = sorted([t for t in os.listdir(out_folder) if t.endswith(".tif") and not t.startswith("._")])
img_names_new = np.delete(img_names, ind_to_remove, axis = 0)
link_file = pd.DataFrame({"tifs": tifs, "imgs": img_names_new})
link_file.to_csv(os.path.join(out_folder, "tif_to_img.csv"), index=False)



The track 1 has movement: 7.889370353614154 with size 903.0
The track 2 has movement: 5.356137229006466 with size 203.0
The track 3 has movement: 0 with size 731.0
The track 4 has movement: 20.641785799384078 with size 73.5
The track 5 has movement: 0 with size 743.0
The track 6 has movement: 42.911658692921435 with size 122.5
The track 7 has movement: 56.792095470187434 with size 215.0
The track 8 has movement: 10.85884174003499 with size 63.5
The track 9 has movement: 53.06840660754753 with size 121.0
The track 10 has movement: 4.00817382601755 with size 274.0
The track 11 has movement: 5.430485969846325 with size 427.5
The track 12 has movement: 5.126065570198625 with size 410.0
The track 13 has movement: 5.722166105188759 with size 633.0
The track 14 has movement: 16.091773785506753 with size 185.0
The track 15 has movement: 4.872362608777653 with size 816.5
The track 16 has movement: 7.211442400707882 with size 103.0
The track 17 has movement: 8.042793992524444 with size 102.0
The

INFO:trackastra.model.model:Loading model state from /home/leli/.trackastra/.models/general_2d/model.pt


Images shape: (1818, 313, 192)
Masks shape: (1818, 313, 192)
Images shape: (1817, 313, 192)
Masks shape: (1817, 313, 192)
/home/leli/.trackastra/.models/general_2d already downloaded, skipping.


INFO:trackastra.model.model_api:Predicting weights for candidate graph


Using device cuda


INFO:trackastra.data.wrfeat:Extracting features from 1817 detections
INFO:trackastra.data.wrfeat:Using single process for feature extraction
Extracting features: 100%|██████████| 1817/1817 [00:07<00:00, 256.59it/s]
INFO:trackastra.model.model_api:Building windows
Building windows: 100%|██████████| 1814/1814 [00:00<00:00, 23791.26it/s]
INFO:trackastra.model.model_api:Predicting windows
Computing associations: 100%|██████████| 1814/1814 [00:37<00:00, 48.75it/s]
INFO:trackastra.model.model_api:Running greedy tracker
INFO:trackastra.tracking.tracking:Build candidate graph with delta_t=1
INFO:trackastra.tracking.tracking:Added 2985 vertices, 2902 edges                             
INFO:trackastra.tracking.tracking:Running greedy tracker
Greedily matched edges:  99%|█████████▊| 2865/2902 [00:00<00:00, 85608.41it/s]
Converting graph to CTC results: 100%|██████████| 146/146 [00:00<00:00, 2672.21it/s]
Saving masks: 100%|██████████| 1817/1817 [00:06<00:00, 300.25it/s]


## plot()

In [53]:
# Directory containing tracking results (TIFF files and text file)
tracking_dir = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}"
print(f"{os.path.isdir(tracking_dir)}")
# Create an output directory for PNGs
output_dir = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}_imgs/"
os.makedirs(output_dir, exist_ok=True)

# Number of raw frames: files named Image_*.png in the image directory
total_frame_num = len([file for file in os.listdir(main_img_directory) if file.startswith("Image_") and file.endswith(".png")])
#total_frame_num = len([filename for filename in os.listdir(main_img_directory) if filename.endswith("_htert_Run.png") and not filename.startswith("._") and "Printed" not in filename])
print(f"total frame number is {total_frame_num}")

# Process each frame
# 1-based raw-frame numbers of the frames kept by remove_empty_frame; used to map back to the rcnn results
act_rcnn_inds = [i+1 for i in range(total_frame_num) if i not in ind_to_remove] # here we are trying to find the corresponding frame index that matches with the rcnn results
#assert len([file for file in os.listdir(tracking_dir) if not file.startswith("._") and file.endswith("tif")]) == len(act_rcnn_inds)

# Raw-frame window of interest (inclusive).
# NOTE(review): with the 2776 frames reported in the output, this window is
# empty, so frames_to_process prints [].
start = 3274
end = 3757
frames_to_process = [i for i, act in enumerate(act_rcnn_inds) if act >= start and act <= end] # if the act rcnn index is in range we keep its 0-based position, which is used to index the tif files later
print(frames_to_process)
# NOTE(review): frames_to_process is computed above but None is passed here.
process_frames(imgs, tracking_dir, output_dir, frames_to_process = None)
print(f"Process frames done!")

# Disabled code kept as a bare string literal (windowed validation video).
'''# Create a video from the saved PNGs
print(f"These are the tiffs that should be in the ground truth")
output_video = f'/projects/steiflab/scratch/leli/trackastra/{out_folder}_imgs/tracked_video_val.mp4'
height, width = imgs.shape[1], imgs.shape[2]  # Get height and width from images
create_video(output_dir, output_video, imgs.shape[0], width, height, fps=3, frames_to_process = frames_to_process)'''

# Video of all processed frames at 3 fps
output_video = f'/projects/steiflab/scratch/leli/trackastra/{out_folder}_imgs/tracked_video_full.mp4'
height, width = imgs.shape[1], imgs.shape[2]
create_video(output_dir, output_video, imgs.shape[0], width, height, fps=3, frames_to_process = None)

True
total frame number is 2776
[]
Saved image for frame: 0 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_imgs/man_track0000.png
Saved image for frame: 1 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_imgs/man_track0001.png
Saved image for frame: 2 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_imgs/man_track0002.png
Saved image for frame: 3 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_imgs/man_track0003.png
Saved image for frame: 4 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_imgs/man_track0004.png
Saved image for frame: 5 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_imgs/man_track0005.png
Saved image for frame: 6 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_imgs/man_track0006.png
Saved image for frame: 7 at /projects/steiflab/scratch/leli/trackastr

## gen_pp_main() for sub tracks

In [54]:
import shutil  # used below; previously only imported in a later cell
import sys
import warnings

import joblib

# Paths for this run (`out_folder` is set in an earlier cell).
source_tif_dir = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}"
target_tif_dir = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}_postprocessed"
track_info_file = f'/projects/steiflab/scratch/leli/trackastra/{out_folder}/man_track.txt'
classifier_model_path = '/projects/steiflab/scratch/leli/trackastra/postprocessing/models/Random Forest_model.pkl'
example_csv_path = "/projects/steiflab/scratch/leli/trackastra/postprocessing/df.csv"

# Start from a fresh copy of the raw tracking output so that re-running this
# cell never post-processes already post-processed frames. copytree creates
# target_tif_dir itself, so no separate makedirs is needed.
if os.path.exists(target_tif_dir):
    shutil.rmtree(target_tif_dir)
shutil.copytree(source_tif_dir, target_tif_dir)

# Load the trained action classifier.
# NOTE(review): joblib.load unpickles the file -- only load trusted model files.
classifier = joblib.load(classifier_model_path)

# Read the whitespace-separated track table (no header row). A raw string is
# used for the regex so '\s' is not treated as an (invalid) string escape.
track_info = pd.read_csv(track_info_file, sep=r'\s+', names=['Track_ID', 'Start', 'End', 'Parent'])
new_track_info = pd.DataFrame(columns=track_info.columns)

# here we are finding the root track of all
def find_root(track_id):
    parent = track_info.loc[track_info['Track_ID'] == track_id, 'Parent'].values[0]
    if parent == 0:
        return track_id
    else:
        return find_root(parent)

track_info['Root'] = track_info['Track_ID'].apply(find_root)

# Prepare features and predict an action class for every sub-track
# (tracks that have a parent).
track_info_file = f'/projects/steiflab/scratch/leli/trackastra/{out_folder}/man_track.txt'
tif_directory = f'/projects/steiflab/scratch/leli/trackastra/{out_folder}'
track_ids = track_info.loc[track_info['Parent'] != 0, 'Track_ID']
features_dict = calculate_features(track_info_file, tif_directory, track_ids)
if not features_dict['Track_ID']:
    print("there are no sub tracks!!!")
    sys.exit()

examples = pd.DataFrame(features_dict)
# NOTE(review): the Track_ID column is still present when predicting, so it is
# passed to the classifier as a feature. Confirm the model was trained with it;
# otherwise drop it before calling predict.
examples["action"] = classifier.predict(examples)
examples = examples.set_index('Track_ID')


def _frame_number(name):
    """Parse the frame index out of a 'man_trackNNNN.tif' filename."""
    return int(name.replace('man_track', '').replace('.tif', ''))


# Process frames in numeric frame order. os.listdir order is arbitrary, and the
# Start/End bookkeeping in update_track_info_across_frame is fed one frame at a
# time (it may depend on that order -- TODO confirm).
frame_files = sorted(
    (name for name in os.listdir(target_tif_dir)
     if not name.startswith("._") and name.endswith(".tif")),
    key=_frame_number,
)
for filename in frame_files:
    frame_path = os.path.join(target_tif_dir, filename)
    frame = tiff.imread(frame_path)
    print(f"start processing frame: {filename} #####################################")
    processed_frame = postprocess_frame(frame, track_info, examples)
    tiff.imwrite(frame_path, processed_frame)
    print(f"complete processing frame: {filename}")

    frame_num = _frame_number(filename)
    new_track_info = update_track_info_across_frame(track_info, new_track_info, processed_frame, frame_num)

# Write the rebuilt track table once, after all frames are processed. It used
# to be rewritten after every frame; nothing reads the file inside the loop.
if frame_files:
    new_track_info.to_csv(os.path.join(target_tif_dir, "man_track.txt"), sep=' ', index=False, header=False)

print("Sub track Processing completed.")


start processing frame: man_track0000.tif #####################################
This group has root 1 and contains [1]
iteration: 1 where it is 
    Track_ID  Start  End  Parent  Root
0         1      0   20       0     1
complete processing frame: man_track0000.tif
start processing frame: man_track0001.tif #####################################
This group has root 1 and contains [1]
iteration: 1 where it is 
    Track_ID  Start  End  Parent  Root
0         1      0   20       0     1
complete processing frame: man_track0001.tif
start processing frame: man_track0002.tif #####################################
This group has root 1 and contains [1]
iteration: 1 where it is 
    Track_ID  Start  End  Parent  Root
0         1      0   20       0     1
This group has root 2 and contains [2]
iteration: 2 where it is 
    Track_ID  Start  End  Parent  Root
1         2      2   31       0     2
complete processing frame: man_track0002.tif
start processing frame: man_track0003.tif ###############

## plot()

In [55]:
# Directory containing the post-processed tracking results (TIFF files and man_track.txt)
tracking_dir = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}_postprocessed"
print(f"{os.path.isdir(tracking_dir)}")
# Output directory for the rendered PNGs
output_dir = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}_postprocessed_imgs"
os.makedirs(output_dir, exist_ok=True)

# Map each kept frame back to its 1-based R-CNN frame index.
# NOTE(review): ind_to_remove (from an earlier cell) presumably lists the
# 0-based frame indices excluded before tracking -- confirm.
total_frame_num = len([file for file in os.listdir(main_img_directory) if file.startswith("Image_") and file.endswith(".png")])
removed_frame_set = set(ind_to_remove)  # O(1) membership instead of a list scan per frame
act_rcnn_inds = [i + 1 for i in range(total_frame_num) if i not in removed_frame_set]

start = 3274
end = 3757
# 0-based positions of frames whose R-CNN index is in [start, end]; only printed
# for reference -- the rendering below covers every frame.
frames_to_process = [i for i, act in enumerate(act_rcnn_inds) if start <= act <= end]
print(frames_to_process)

process_frames(imgs, tracking_dir, output_dir, frames_to_process=None)
print("Process frames done!")

# Full-length video. height/width are taken from imgs here instead of relying
# on values left over from an earlier cell.
height, width = imgs.shape[1], imgs.shape[2]
output_video = f'/projects/steiflab/scratch/leli/trackastra/{out_folder}_postprocessed_imgs/tracked_video_full.mp4'
create_video(output_dir, output_video, imgs.shape[0], width, height, fps=3, frames_to_process=None)


True
[]
Saved image for frame: 0 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0000.png
Saved image for frame: 1 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0001.png
Saved image for frame: 2 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0002.png
Saved image for frame: 3 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0003.png
Saved image for frame: 4 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0004.png
Saved image for frame: 5 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0005.png
Saved image for frame: 6 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0006.pn

Saved image for frame: 24 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0024.png
Saved image for frame: 25 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0025.png
Saved image for frame: 26 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0026.png
Saved image for frame: 27 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0027.png
Saved image for frame: 28 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0028.png
Saved image for frame: 29 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0029.png
Saved image for frame: 30 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_imgs/man_track0030.png

## gen_pp_main() for all tracks

In [56]:
import warnings
import joblib
import shutil

# Main-track post-processing. This pass starts from the output of the
# sub-track pass and checks every track for four problems:
#   1. tracks that are too short,
#   2. upward jumps or pauses (the track is split into a new one),
#   3. label switches (the track is merged with the label that continues it),
#   4. tracks with negligible overall downward movement.

source_tif_dir = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}_postprocessed"
target_tif_dir = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}_postprocessed_2.0"
track_info_file = f'/projects/steiflab/scratch/leli/trackastra/{out_folder}_postprocessed_2.0/man_track.txt'
TRACK_COLUMNS = ['Track_ID', 'Start', 'End', 'Parent']

# Work on a fresh copy so that re-running the cell is repeatable; copytree
# creates target_tif_dir itself.
if os.path.exists(target_tif_dir):
    shutil.rmtree(target_tif_dir)
shutil.copytree(source_tif_dir, target_tif_dir)


def _read_track_info(path):
    """Read a whitespace-separated man_track.txt (no header) into a DataFrame."""
    return pd.read_csv(path, sep=r'\s+', names=TRACK_COLUMNS)


def _frame_number(name):
    """Parse the frame index out of a 'man_trackNNNN.tif' filename."""
    return int(name.replace('man_track', '').replace('.tif', ''))


print("Initiate Main track Processing ... ")
track_info = _read_track_info(track_info_file)
track_ids = sorted(track_info['Track_ID'])

# Largest original track ID; tracks split off by diverge_track get fresh IDs above it.
new_track_label = np.max(track_info['Track_ID'].values)

while track_ids:

    print(f"Check for track ID: {track_ids[0]}")

    # remove/diverge/merge rewrite man_track.txt, so re-read it on every iteration.
    track_info = _read_track_info(track_info_file)
    rows = track_info.loc[track_info['Track_ID'] == track_ids[0]].values
    if len(rows) == 0:
        # Removed or merged away by an earlier step; previously this crashed with IndexError.
        print(f"Track ID {track_ids[0]} no longer exists, skipping")
        track_ids = track_ids[1:]
        continue
    track_id, start_frame, end_frame, parent_id = rows[0]

    #### Check 1: short tracks -- remove tracks with End - Start <= 2 (at most 3 frames).
    if end_frame - start_frame <= 2:
        remove_track(track_info_file, target_tif_dir, to_be_removed_id=track_ids[0])
        track_ids = track_ids[1:]
        continue

    #### Check 2: tracks that move up, or that pause, are split (diverged) into a new track.
    centroids = []
    centroids_frame = []
    skipped_frames = []
    for frame_number in range(start_frame, end_frame + 1):
        frame_path = os.path.join(target_tif_dir, f'man_track{frame_number:04d}.tif')
        frame = tiff.imread(frame_path)
        binary_mask = (frame == track_id).astype(np.uint8)

        # Track absent from this frame: record the gap and keep going.
        # (.any() rather than counting unique values, which misfired when the
        # object covered the entire frame.)
        if not binary_mask.any():
            skipped_frames.append(frame_number)
            continue

        centroid = regionprops(binary_mask)[0].centroid

        if len(centroids) >= 5:  # enough history for a robust estimate
            y_changes = np.diff([c[0] for c in centroids])
            median_change = np.median(y_changes)
            mad = np.median(np.abs(y_changes - median_change))
            threshold = 2.5 * mad  # 2.5 is the usual MAD multiplier; tunable

            # Objects only move down, so y should increase from frame to frame.
            # An upward jump appears as a y change at or below the robust lower
            # bound (median - threshold).
            if centroid[0] - centroids[-1][0] <= median_change - threshold:
                # Case 2: moved up -> split the remainder into a new track and
                # queue it so that it is checked later.
                new_track_label = new_track_label + 1
                diverge_track(track_info_file, target_tif_dir, to_be_split_id=track_id, new_id=int(new_track_label), diverging_start_frame=frame_number)
                track_ids.append(int(new_track_label))
                break

            elif len(skipped_frames) >= 2:
                # Case 2.5: the track has been missing from >= 2 frames so far
                # -> split and queue as above.
                new_track_label = new_track_label + 1
                diverge_track(track_info_file, target_tif_dir, to_be_split_id=track_id, new_id=int(new_track_label), diverging_start_frame=frame_number)
                track_ids.append(int(new_track_label))
                break

        centroids.append(centroid)
        centroids_frame.append(frame_number)

    #### Check 3: label switching. Extrapolate the centroid into the next 3
    # frames. If the same other label sits under the prediction in >= 2 of
    # them, the object continued under that label, so the two are merged.
    centroids_with_prediction = centroids.copy()  # grows with predicted centroids
    centroids_frame_with_prediction = centroids_frame.copy()

    if len(centroids_with_prediction) >= 5:
        covered_by = []
        last_frame = max(centroids_frame_with_prediction)

        for frame_number in range(last_frame + 1, last_frame + 4):
            frame_path = os.path.join(target_tif_dir, f'man_track{frame_number:04d}.tif')
            if not os.path.exists(frame_path):
                continue
            frame = np.array(tiff.imread(frame_path))

            curr_c = predict_next_centroids(centroids_with_prediction, centroids_frame_with_prediction, predict_this_frame=frame_number)

            if 0 <= int(curr_c[0]) < frame.shape[0] and 0 <= int(curr_c[1]) < frame.shape[1]:
                maj_label = maj_object_within_radius(frame, curr_c, radius=3.5)
                covered_by.append(maj_label)
                if maj_label != 0:
                    # Feed the prediction back so the next frame extrapolates from it.
                    centroids_with_prediction.append(curr_c)
                    centroids_frame_with_prediction.append(frame_number)
            else:
                # The old code also ran regionprops on a stale (possibly empty)
                # mask here, which could raise IndexError; its result was unused.
                print("prediction went out of bound")
                covered_by.append(0)

        if len(covered_by) == 3:
            non_zero_values = [x for x in covered_by if x > 0]
            for label in set(non_zero_values):
                if non_zero_values.count(label) >= 2:
                    # merge_track decides which ID survives, so argument order does not matter.
                    new = merge_track(track_info_file, target_tif_dir, to_be_merged_ids=(label, track_id))

                    # Use .values: `in` on a pandas Series tests the index, not the values.
                    label_removed = label not in new["Track_ID"].values
                    print(f"The label {label} is not in the track info {label_removed} and the track id {track_id}")
                    if label_removed and label in track_ids:
                        track_ids.remove(label)

    #### Check 4: remove short tracks (2-4 centroids) whose net downward y movement is <= 5 px.
    if 1 < len(centroids) < 5:
        overall_y_movement = centroids[-1][0] - centroids[0][0]
        if overall_y_movement <= 5:
            remove_track(track_info_file, target_tif_dir, to_be_removed_id=track_ids[0])

    # Invariant: the work queue never holds duplicate IDs.
    assert len(track_ids) == len(np.unique(track_ids))

    track_ids = track_ids[1:]


print("Main track Processing completed.")

print("initiate updating tracking info csv at the very end to correct and ensure the track info final version")

# Rebuild the track table from the final frames, in numeric frame order
# (os.listdir order is arbitrary), and write it once at the end.
track_info = _read_track_info(track_info_file)
new_track_info = pd.DataFrame(columns=track_info.columns)
frame_files = sorted(
    (name for name in os.listdir(target_tif_dir)
     if not name.startswith("._") and name.endswith(".tif")),
    key=_frame_number,
)
for filename in frame_files:
    frame = tiff.imread(os.path.join(target_tif_dir, filename))
    new_track_info = update_track_info_across_frame(track_info, new_track_info, frame, _frame_number(filename))
if frame_files:
    new_track_info.to_csv(os.path.join(target_tif_dir, "man_track.txt"), sep=' ', index=False, header=False)



Initiate Main track Processing ... 
Check for track ID: 1
Check for track ID: 2
Check for track ID: 3
Removed track 3 from frame 23
Removed track 3 from frame 24
Check for track ID: 4
Removed track 4 from frame 25
Removed track 4 from frame 26
Check for track ID: 5
Check for track ID: 6
Removed track 6 from frame 34
Removed track 6 from frame 35
Removed track 6 from frame 36
Check for track ID: 7
Diverge frame 41: track 7 -> 147
Diverge frame 42: track 7 -> 147
Diverge frame 43: track 7 -> 147
Diverge frame 44: track 7 -> 147
Diverge frame 45: track 7 -> 147
Diverge frame 46: track 7 -> 147
Diverge frame 47: track 7 -> 147
Diverge frame 48: track 7 -> 147
Diverge frame 49: track 7 -> 147
Diverge frame 50: track 7 -> 147
Diverge frame 51: track 7 -> 147
Diverge frame 52: track 7 -> 147
Diverge frame 53: track 7 -> 147
Diverge frame 54: track 7 -> 147
Diverge frame 55: track 7 -> 147
Diverge frame 56: track 7 -> 147
Diverge frame 57: track 7 -> 147
Diverge frame 58: track 7 -> 147
Diverg

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Merged frame 253: track 28 -> 26 (This is before alpha end frame)
Merged frame 254: track 28 -> 26 (This is before alpha end frame)
Merged frame 255: track 28 -> 26 (This is before alpha end frame)
Merged frame 256: track 28 -> 26 (This is before alpha end frame)
Merged frame 257: track 28 -> 26 (This is before alpha end frame)
Merged frame 258: track 28 -> 26 (This is before alpha end frame)
Merged frame 259: track 28 -> 26 (This is before alpha end frame)
Merged frame 260: track 28 -> 26 (This is before alpha end frame)
Merged frame 261: track 28 -> 26 (This is before alpha end frame)
Merged frame 262: track 28 -> 26 (This is before alpha end frame)
Merged frame 263: track 28 -> 26 (This is before alpha end frame)
Merged frame 264: track 28 -> 26 (This is before alpha end frame)
Merged frame 265: track 28 -> 26 (This is before alpha end frame)
Merged frame 266: track 28 -> 26 (This is before alpha end frame)
Merged frame 267: track 28 -> 26 (This is before alpha end frame)
Merged fra

After merging, render the frames and the video again to inspect the result.

In [58]:
# Render the final (post-processed 2.0) tracking result as PNG frames plus a
# full-length video.
tracking_dir = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}_postprocessed_2.0"
print(str(os.path.isdir(tracking_dir)))

output_dir = f"/projects/steiflab/scratch/leli/trackastra/{out_folder}_postprocessed_2.0_imgs"
os.makedirs(output_dir, exist_ok=True)

# Count the raw "Image_*.png" frames, then map each kept frame to its 1-based
# R-CNN index (indices in ind_to_remove are skipped).
total_frame_num = sum(
    1 for name in os.listdir(main_img_directory)
    if name.startswith("Image_") and name.endswith(".png")
)
act_rcnn_inds = [idx + 1 for idx in range(total_frame_num) if idx not in ind_to_remove]

start = 3274
end = 3757
# Positions in the frame sequence whose R-CNN index falls in [start, end];
# printed for reference only -- every frame is rendered below.
frames_to_process = [pos for pos, rcnn_idx in enumerate(act_rcnn_inds) if start <= rcnn_idx <= end]
print(frames_to_process)

process_frames(imgs, tracking_dir, output_dir, frames_to_process=None)
print("Process frames done!")

# To render only the selected range, pass frames_to_process to create_video
# and write to a separate "_val" video instead.
height, width = imgs.shape[1:3]
output_video = f'/projects/steiflab/scratch/leli/trackastra/{out_folder}_postprocessed_2.0_imgs/tracked_video_full.mp4'
create_video(output_dir, output_video, imgs.shape[0], width, height, fps=3, frames_to_process=None)


True
[]
Saved image for frame: 0 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_2.0_imgs/man_track0000.png
Saved image for frame: 1 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_2.0_imgs/man_track0001.png
Saved image for frame: 2 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_2.0_imgs/man_track0002.png
Saved image for frame: 3 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_2.0_imgs/man_track0003.png
Saved image for frame: 4 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_2.0_imgs/man_track0004.png
Saved image for frame: 5 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocessed_2.0_imgs/man_track0005.png
Saved image for frame: 6 at /projects/steiflab/scratch/leli/trackastra/A138856A/10dropRun4/tracked_again_postprocess