# SAM4EM: Zero-shot Multi-particle Segmentation and Tracking in LPTEM

A workflow for particle segmentation and tracking in liquid phase transmission electron microscopy

Before starting: install SAM2 and the additional required packages by following the instructions on the project's GitHub page.

Run the cell below to start. **Change notebook_name to your notebook's name**.

*If using Colab*, also **change folder_path to the directory in your Drive where your notebook is located**.

In [1]:
notebook_name = "SAM4EM.ipynb" # replace with your notebook name
folder_path = '/content/drive/MyDrive/SAM4EM' # replace with the folder containing your notebook, if using Colab

import ipywidgets as widgets
from ipywidgets import GridspecLayout, HBox
from IPython.display import display, clear_output, Javascript
from IPython import get_ipython
import nbformat

# Load the notebook
def load_notebook():
    """Read the notebook file named by ``notebook_name`` as an nbformat v4 object.

    The file is decoded as UTF-8; undecodable bytes are replaced instead of
    raising, so a stray bad byte does not block loading.
    """
    with open(notebook_name, encoding='utf-8', errors='replace') as notebook_file:
        parsed_notebook = nbformat.read(notebook_file, as_version=4)
    return parsed_notebook
def run_tagged_cells(tag):
    """Execute, in notebook order, every code cell whose metadata tags include ``tag``.

    The notebook is re-read from disk through ``load_notebook()`` on every
    call, and each matching cell's source is run in the current IPython shell.

    Args:
        tag: Cell tag to look for in ``cell.metadata["tags"]``.

    If no code cell carries the tag, a warning is printed. Without it the
    callers would go on to report that setup finished although nothing ran.
    """
    shell = get_ipython()
    found_any = False
    for cell in load_notebook().cells:
        if cell.cell_type != "code" or tag not in cell.metadata.get("tags", []):
            continue
        found_any = True
        print(f"Running cell with tag: {tag}")
        shell.run_cell(cell.source)
    if not found_any:
        print(f"Warning: no code cell tagged '{tag}' was found in {notebook_name}")
def scroll_to_tag(tag):
    """Scroll the browser view to the first notebook cell tagged ``tag``.

    Emits a JavaScript snippet that walks ``Jupyter.notebook.get_cells()`` and
    animates the page to the first cell whose tags include ``tag``.

    Args:
        tag: Cell tag to scroll to.

    NOTE(review): the snippet relies on the ``Jupyter.notebook`` object and
    jQuery ``$`` being defined in the front end; presumably that is only the
    classic Notebook UI -- confirm for JupyterLab/Colab.
    """
    import json  # local import keeps this cell self-contained

    # Encode the tag as a JS string literal so quotes or backslashes in a tag
    # cannot break (or inject into) the generated script.
    tag_literal = json.dumps(tag)
    display(Javascript(f"""
        var cells = Jupyter.notebook.get_cells();
        for (var i = 0; i < cells.length; i++) {{
            var tags = cells[i].metadata.tags || [];
            if (tags.includes({tag_literal})) {{
                $('html, body').animate({{scrollTop: $(cells[i].element).offset().top}}, 500);
                break;
            }}
        }}
    """))


# Module 1: Segmentation with SAM2

## Environment and Function Dependencies Set-Up

Run the below cell and select whether you are working in Google Colab or JupyterLab/Jupyter Notebook.

If you are using Google Colab, enable the permissions on the pop-up window so that the notebook can access your data in Google Drive.

In [None]:
def setup_colab():
    """Prepare a Google Colab session for this workflow.

    Mounts Google Drive (forcing a remount), changes the working directory to
    ``folder_path``, then runs every notebook cell tagged ``setup_colab``.
    """
    print("Setting up Google Colab workflow requirements...")

    # Imported lazily: google.colab only exists inside a Colab runtime.
    import os
    from google.colab import drive as colab_drive

    colab_drive.mount('/content/drive', force_remount=True)
    os.chdir(folder_path)

    run_tagged_cells("setup_colab")
    print("Google Colab setup is complete")
def setup_jupyterlab():
    """Run every notebook cell tagged ``setup_jupyterlab``, printing progress messages."""
    print("Setting up JupyterLab workflow requirements...")
    run_tagged_cells("setup_jupyterlab")
    print("JupyterLab setup is complete")


def on_setup_colab_click(b):
    """Button callback: clear the ``output`` widget, then run the Colab setup inside it.

    ``b`` is the clicked button (unused).
    """
    with output:
        clear_output()
        setup_colab()
def on_setup_jupyterlab_click(b):
    """Button callback: clear the ``output`` widget, then run the JupyterLab setup inside it.

    ``b`` is the clicked button (unused).
    """
    with output:
        clear_output()
        setup_jupyterlab()

# One button per supported environment; clicking runs the matching setup routine.
button_setup_colab = widgets.Button(description="Google Colab", style={'button_color': 'bisque'})
button_setup_jupyterlab = widgets.Button(description="JupyterLab", style={'button_color': 'lightblue'})

# Shared output area that the click handlers above write into.
output = widgets.Output()


button_setup_colab.on_click(on_setup_colab_click)
button_setup_jupyterlab.on_click(on_setup_jupyterlab_click)

# Stack the two buttons above the output area.
display(widgets.VBox([button_setup_colab, button_setup_jupyterlab, output]))

Set up SAM2 and its functions below

In [None]:
# sets up SAM2 and loads functions
def setup_SAM2_functions():
    """Run the cells that load SAM2 and its helper functions.

    Always runs the cells tagged ``setup_SAM2_functions``; when
    ``using_colab`` is truthy, the ``setup_SAM2_functions_colab`` cells are
    run afterwards as well.
    """
    print("Setting up SAM2 module...")
    run_tagged_cells("setup_SAM2_functions")
    # Evaluated only after the common cells have run, as before.
    extra_tags = ("setup_SAM2_functions_colab",) if using_colab else ()
    for extra_tag in extra_tags:
        run_tagged_cells(extra_tag)
    print("SAM2 module setup is complete")
def on_setup_SAM2_functions_click(b):
    """Button callback: run the SAM2 setup with its messages captured in ``output``.

    ``b`` is the clicked button (unused).
    """
    with output:
        # wait=True defers clearing until new output arrives
        clear_output(wait=True)
        setup_SAM2_functions()

output = widgets.Output()

button_setup_SAM2_functions = widgets.Button(description="Set up SAM2 module", style={'button_color': 'steelblue', 'text_color':'white'})
button_setup_SAM2_functions.on_click(on_setup_SAM2_functions_click)

display(widgets.VBox([button_setup_SAM2_functions, output]))

## Select the data and output paths

Set **experiment_dir** to your path that contains a folder called "**jpeg_images**" that has your experimental JPEG images.
The **experiment_name** will be used to name outputted files.
To change the default output file save location, change **output_dir**

The video should be stored as a list of JPEG frames, with filenames of <frame_index>.jpeg, where frame_index is the frame number padded with 0s to be 6 digits long. The JPEG frames can be extracted from a video using ffmpeg

In [None]:
# Output files will be named beginning with experiment_name
experiment_name = 'example' # This should be the name of the folder within "Data"

# experiment_dir contains the folders where your data and outputs are stored
experiment_dir = os.path.join('Data', experiment_name)

# By default, the output files (segmented mask array and images, positional analyses, etc)
# will be stored in a folder within experiment_dir called output
output_dir = create_unique_folder(os.path.join(experiment_dir, 'output'))

# video_dir contains the JPEG frames of your video
video_dir = os.path.join(experiment_dir, 'jpeg_images')

# Get the frame names from the directory, ordered by their integer frame index
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# Fail with a clear message instead of an IndexError further down
if not frame_names:
    raise FileNotFoundError(f"No JPEG frames found in {video_dir}")

# take a look at the first video frame
frame_idx = 0
first_frame = Image.open(os.path.join(video_dir, frame_names[frame_idx]))
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
plt.imshow(first_frame)

# PIL reports size as (width, height); unpack in that order so that `height`
# and `width` are not swapped for non-square frames.
width, height = first_frame.size

**Define your custom colormap here**, if you have one. A default colormap is provided below

In [6]:
from matplotlib.colors import ListedColormap
colors_custom = [
    "#E69F00",  # Orange
    "#D55E00",   # Red-orange
    "#56B4E9",  # Sky Blue
    "#009E73",  # Green
    "#0072B2", # Blue
    "#CC79A7", # Pink
    "#F0E442", # Yellow
    "#B3D100",  # Yellow-green
    "#9E1B32"  # Dark red
]

cmap_custom = ListedColormap(colors_custom, name="cmap_custom")
cmap = cmap_custom

## Initialize the inference state
This cell loads JPEG frames and stores their pixels in inference_state

In [None]:
inference_state = predictor.init_state(video_path=video_dir)

## Segmentation
We can segment objects using either box prompts, points prompts, or a combination of the two. ***You must mark all particles you want to track on each frame you choose to interact with.***

**Box**: In this method, we can draw bounding boxes around particles. In the pop-up window, click on the two diagonal corners of the box you would like to draw. Then, exit out of the window.

**Points**: Alternatively, you can click on the particle to identify it. You may add multiple positive and/or negative points. Points can be used in conjunction with a box to refine the particle mask.
- *Positive points*: Mark a point(s) where the particle is located
- *Negative points*: Mark a point(s) where the mask incorrectly identifies pixels as part of the particle

If you have run any previous tracking, reset the predictor (inference_state) first

In [None]:
# Reset Predictions button
reset_predictions_button = widgets.Button(
    description='Reset All Predictions',
    button_style='danger',
    tooltip='Click to reset inference state. You will lose all predictions',
    icon='trash can',
    layout=widgets.Layout(width='550px')
)

def on_reset_predictions_click(b):
    """Button callback: discard all SAM2 predictions and prompt bookkeeping.

    Resets the predictor's ``inference_state`` and re-initialises the globals
    ``old_frame_idx``, ``old_obj_id`` (both to -1) and ``prompts`` (to an
    empty dict). ``b`` is the clicked button (unused).
    """
    global old_frame_idx, old_obj_id, prompts
    predictor.reset_state(inference_state)
    old_frame_idx = old_obj_id = -1
    prompts = {}
    print("Inference state has been reset")

reset_predictions_button.on_click(on_reset_predictions_click)
display(reset_predictions_button)

#### Select Inputs
**Enter the frame number** you want to select particles on. For each particle you would like to track, first **enter the particle's unique Particle ID**. Then **select your input mode**: a bounding box or a positive/negative point(s). **Click Run Action**, then follow the instructions in the pop-up window.

*If you are using Colab*, you must rerun this entire cell each time you are adding a new prompt, and then follow the standard directions. Also click the "Run additional step (Colab only)" button in the next cell after EACH time you add prompts

In [None]:
ann_obj_id = None
ann_frame_idx = None
input_mode = None
point_type = None

## setting particle ID and frame index
text_input_id = widgets.Text(
    description='Enter particle ID :',
    placeholder='Enter a unique integer',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='550px')
)
text_input_frame = widgets.Text(
    description='Enter Frame Number:',
    placeholder='Enter an integer 0,..N',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='550px')
)

inputs_vbox = widgets.VBox([text_input_id, text_input_frame])
reset_predictions_button.layout = widgets.Layout(width='300px', height = '62px')
inputs_and_reset_hbox = widgets.HBox([inputs_vbox, reset_predictions_button])

confirm_button = widgets.Button(
    description='Set ID & Frame',
    button_style='success',  # Green button
    tooltip='Click to set the particle ID and Frame #',
    icon='check',
    layout=widgets.Layout(width='550px')
)
def on_confirm_click(b):
    """Button callback: read the particle ID and frame number from the text boxes.

    On success the globals ``ann_obj_id`` and ``ann_frame_idx`` are updated
    together and echoed in ``output``. If either field is not an integer, a
    message is shown and neither global is changed.

    ``b`` is the clicked button (unused).
    """
    global ann_obj_id, ann_frame_idx
    with output:
        output.clear_output()
        try:
            # Parse both fields before touching the globals, so a bad frame
            # number cannot leave a new ID paired with the previous frame.
            new_obj_id = int(text_input_id.value)
            new_frame_idx = int(text_input_frame.value)
        except ValueError:
            print("Please enter a valid integer.")
            return
        ann_obj_id, ann_frame_idx = new_obj_id, new_frame_idx
        print(f"Annotation Object ID set to: {ann_obj_id}")
        print(f"Frame Number set to: {ann_frame_idx}")

confirm_button.on_click(on_confirm_click)


## Functions to run prompts
def action_boxprompt():
    """Run the box-prompt cells for the current particle ID and frame.

    Messages are captured in ``output``. On Colab the ``get_points_colab``
    cells are run; otherwise the ``add_box_prompt`` cells.
    """
    with output:
        output.clear_output(wait=True)
        print(f"Running box prompt with particle ID {ann_obj_id} on frame {ann_frame_idx}...")
        prompt_tag = "get_points_colab" if using_colab else "add_box_prompt"
        run_tagged_cells(prompt_tag)
def action_pointprompt():
    """Run the point-prompt cells for the current particle ID and frame.

    Messages are captured in ``output``. On Colab the ``get_points_colab``
    cells are run; otherwise the ``add_points_prompt`` cells.
    """
    with output:
        output.clear_output(wait=True)
        print(f"Running points prompt with particle ID {ann_obj_id} on frame {ann_frame_idx}...")
        print(f'Point type: {point_type}')
        prompt_tag = "get_points_colab" if using_colab else "add_points_prompt"
        run_tagged_cells(prompt_tag)

## Actions for buttons to select input mode and point type
def on_select_input_mode(mode):
    """Record the chosen input mode and show or hide the point-type buttons.

    Args:
        mode: ``'box'`` or ``'points'``; stored in the global ``input_mode``.
    """
    global input_mode
    input_mode = mode
    # The positive/negative selector is shown only for point prompts
    point_buttons_display = 'flex' if input_mode == "points" else 'none'
    with output:
        print(f'Input mode set to {input_mode}')
        point_type_buttons.layout.display = point_buttons_display

def on_select_point_type(point_type_selection):
    """Record whether the next points are positive or negative.

    Args:
        point_type_selection: ``'positive'`` sets the global ``point_type``
            to True; any other value sets it to False.
    """
    global point_type
    point_type = point_type_selection == 'positive'
    with output:
        print(f'Point type set to {point_type_selection}')

output = widgets.Output()

# Buttons
box_button = widgets.Button(description = 'Box', style={'button_color': '#f5c87a'}, icon='crop', layout=widgets.Layout(width='273px'))
points_button = widgets.Button(description='Points', style={'button_color': '#f5c87a'}, icon='crosshairs', layout=widgets.Layout(width='273px'))

positive_button = widgets.Button(description='Positive Points', style={'button_color': '#f7e4a3'}, icon='plus', layout=widgets.Layout(width='134.5px'))
negative_button = widgets.Button(description='Negative Points', style={'button_color': '#f7e4a3'}, icon='eraser', layout=widgets.Layout(width='134.5px'))
point_type_buttons = widgets.HBox([positive_button, negative_button], layout=widgets.Layout(width='300px', margin='0 25% 0 277px'))

box_button.on_click(lambda b: on_select_input_mode('box'))
points_button.on_click(lambda b: on_select_input_mode('points'))

positive_button.on_click(lambda b: on_select_point_type("positive"))
negative_button.on_click(lambda b: on_select_point_type("negative"))

# Run button
run_action_button = widgets.Button(
    description='Run Action',
    button_style='primary',
    tooltip='Click to run the selected input mode (box or points)',
    icon='play',
    layout=widgets.Layout(width='550px')
)
def on_run_action_click(b):
    """Button callback: validate the current selections, then run the chosen prompt.

    Requires a particle ID and frame number, an input mode (``'box'`` or
    ``'points'``) and, for points, a point type. If something is missing, the
    corresponding hint is printed inside the ``output`` widget, like every
    other message in this cell; otherwise the matching prompt action runs.

    ``b`` is the clicked button (unused).
    """
    problem = None
    if ann_obj_id is None or ann_frame_idx is None:
        problem = "Please set the particle ID and frame number first."
    elif input_mode not in ('box', 'points'):
        problem = "Please select input mode (box or points)"
    elif input_mode == 'points' and point_type is None:
        problem = "Please select point type"

    if problem is not None:
        # Print inside the output widget so the hint is shown with the UI
        # rather than wherever the front end routes stray callback output.
        with output:
            print(problem)
        return

    if input_mode == 'box':
        action_boxprompt()
    else:
        action_pointprompt()

run_action_button.on_click(on_run_action_click)
# NOTE(review): display='none' hides the reset button itself, not just its tab
# stop; it is still placed in inputs_and_reset_hbox below -- confirm intent.
reset_predictions_button.layout.display = 'none'


# The point-type buttons stay hidden until the "Points" mode is selected.
point_type_buttons.layout.display = 'none'
input_mode_buttons = widgets.HBox([box_button, points_button])


# reset_predictions_button.layout = widgets.Layout(width='300px', margin='0 25% 0 277px')
# display(reset_predictions_button)
#
# Assemble the whole prompt UI, top to bottom.
display(widgets.VBox([inputs_and_reset_hbox, confirm_button, input_mode_buttons, point_type_buttons, run_action_button, output]))

In [10]:
# You only need to run this cell if using Colab
output = widgets.Output()
output.clear_output()
# Run button 2
run_action_colab_button = widgets.Button(
    description='Run additional step (Colab only)',
    button_style='primary',
    tooltip='Click to run if using Colab after adding a prompt',
    icon='play',
    layout=widgets.Layout(width='550px')
)
def on_run_action_colab_click(b):
    """Button callback (Colab): close open figures, clear ``output``, run the ``add_prompts_colab`` cells.

    ``b`` is the clicked button (unused).
    """
    plt.close('all')
    output.clear_output()
    run_tagged_cells("add_prompts_colab")

run_action_colab_button.on_click(on_run_action_colab_click)
if using_colab:
    display(widgets.VBox([run_action_colab_button, output]))

### Propagate Masklets through Video

After selecting all objects you would like to track, run the below cell to propagate the masks through the entire video. Enter the desired **visualization frame stride** (how often you would like the masks to be printed in this notebook during propagation) and **save frame stride** (how often you want to save a back-up of the segmented masks during propagation).

The video_segments dictionary will be saved as a .pkl file to your experimental_dir every save_frame_stride frames, in the case of the kernel crashing or disconnecting during long runtimes. The files will contain _temp in their filename and can be deleted after propagation is finished.

After propagation is completed, a .pkl file called video_segments_final will be saved to your experimental_dir. Additionally, two folders called "masks" (pngs of the masks on a grey background) and "combined_results" (pngs of the original image and the mask side-by-side) will be created in experimental_dir

In [None]:
def action_propagate():
    """Announce the run, then execute every notebook cell tagged ``propagate``."""
    print("Propagate masklets through all frames...")
    run_tagged_cells("propagate")

text_input_visframe = widgets.Text(
    description='Enter frame stride for visualization :',
    placeholder='Enter an integer 1,..,N',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='550px')
)
text_input_savframe = widgets.Text(
    description='Enter frame stride for saving:',
    placeholder='Enter an integer 1,..,N',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='550px')
)
confirm_button = widgets.Button(
    description='Set how often you want to see and save the mask outputs while it is propagating',
    button_style='success',  # Green button
    tooltip='Click to set the Frame #',
    icon='check',
    layout=widgets.Layout(width='550px')
)
output = widgets.Output()
vis_frame_stride = None
save_frame_stride = None

def on_confirm_click(b):
    """Button callback: read the visualization and save frame strides.

    On success the globals ``vis_frame_stride`` and ``save_frame_stride`` are
    updated together and echoed in ``output``. If either field is not an
    integer, or is smaller than 1, a message is shown and neither global is
    changed.

    ``b`` is the clicked button (unused).
    """
    global vis_frame_stride, save_frame_stride
    with output:
        output.clear_output()
        try:
            # Parse both fields first so a bad second value cannot leave the
            # two strides out of step with what the user entered.
            new_vis_stride = int(text_input_visframe.value)
            new_save_stride = int(text_input_savframe.value)
        except ValueError:
            print("Please enter a valid integer.")
            return
        # The inputs ask for 1..N; a stride below 1 cannot mean "every k-th frame".
        if new_vis_stride < 1 or new_save_stride < 1:
            print("Please enter a positive integer (1 or greater).")
            return
        vis_frame_stride, save_frame_stride = new_vis_stride, new_save_stride
        print(f"Visualization frame stride set to: {vis_frame_stride}")
        print(f"Save frame stride set to: {save_frame_stride}")

confirm_button.on_click(on_confirm_click)

run_propagate_button = widgets.Button(
    description='Propagate segmentation',
    button_style='info',
    tooltip='Click to propagate segmentation throughout the video',
    icon='play',
    layout=widgets.Layout(width='550px')
)
def on_run_propagate_click(b):
    """Button callback: start propagation once both frame strides are set.

    If either stride is still unset, the hint is printed inside the
    ``output`` widget, like the confirmation messages of this cell.
    Propagation itself runs exactly as before.

    ``b`` is the clicked button (unused).
    """
    if vis_frame_stride is None or save_frame_stride is None:
        with output:
            print("Please set the frame numbers first.")
        return
    action_propagate()

run_propagate_button.on_click(on_run_propagate_click)
display(widgets.VBox([text_input_visframe, text_input_savframe, confirm_button, run_propagate_button, output]))

### Loading Data Back In to Obtain Processed Results

This section is only required *if the propagation ended before finishing* (i.e. kernel crashed) and the masks_array and/or the images of the masks and combined results did not save.
- If you only have a video_segments file: Select the "Load in video segments..." option in the menu. Check both checkboxes to save the mask array and the images.
- If you have a mask array file but no images of the results: Select the "Load in masks..." option in the menu. Check the second checkbox ("Save images...")

**Set the following paths**: *segments_path*, *video_dir*, and *masks_array_path* (if loading in a mask array). Set *experiment_name*, *experiment_dir* and *output_dir* if you have not already.

In [16]:
# # Set experiment_name and experiment_dir again if necessary. Outputs will save to output_dir
experiment_name = 'example'
experiment_dir = os.path.join('Data', experiment_name)

output_dir = os.path.join(experiment_dir, 'output')

# Change video_dir to location of the original JPEG images
video_dir = os.path.join(experiment_dir, 'jpeg_images')

# Change segments_path to your filename
segments_path = os.path.join(output_dir, f'{experiment_name}_video_segments_final.pkl')

# If you have masks_array already saved, change masks_array_path to the filename
masks_array_path = os.path.join(output_dir, f'{experiment_name}_masks.npy')

In [None]:
global_vars = {"output_dir": output_dir, "experiment_name": experiment_name}

output = widgets.Output()

dropdown = widgets.Dropdown(
    options = ['Select action', 'Load in video segments from .pkl file', 'Load in numpy masks array from file'],
    style = {'description_width': 'initial'},
    layout=widgets.Layout(width='550px')
)
save_result_images_checkbox = widgets.Checkbox(
    value = False,
    description = 'Save images of masks and combined results to stitch later',
    style = {'description_width': 'initial'},
    tooltip = 'Results will save to two folders (masks and combined_results) created in output_dir',
    layout=widgets.Layout(width='550px')
)
save_masks_array_checkbox = widgets.Checkbox(
    value = False,
    description = 'Save mask array for future tracking analysis',
    tooltip = 'Array will save to output_dir',
    style = {'description_width': 'initial'},
    layout=widgets.Layout(width='550px'),
)
run_button = widgets.Button(
    description = "Run",
    icon = 'check',
    layout=widgets.Layout(width='550px'),
    button_style = 'success'
)

def on_run_button_click(b):
    """Button callback: reload saved results and optionally regenerate outputs.

    Depending on the dropdown selection:

    * "Load in numpy masks array from file": runs the ``load_masks_array``
      cells and, if the images checkbox is ticked, the ``save_result_images``
      cells.
    * "Load in video segments from .pkl file": runs the
      ``load_video_segments`` cells, then ``save_masks_array`` and/or
      ``save_result_images`` according to the checkboxes.

    Any other selection prints a hint. All messages go to ``output``.
    ``b`` is the clicked button (unused).
    """
    selected_action = dropdown.value

    # Get checkbox states
    save_result_images = save_result_images_checkbox.value
    save_masks_array = save_masks_array_checkbox.value

    with output:
        output.clear_output()
        if selected_action == 'Load in numpy masks array from file':
            print("Loading masks array from file...")
            run_tagged_cells("load_masks_array")
            if save_result_images:
                print("Saving results as images...")
                # Previously only the message was printed and no images were
                # written. NOTE(review): assumes the save_result_images cells
                # can work from the loaded masks array -- confirm.
                run_tagged_cells("save_result_images")
            print("Done")

        elif selected_action == 'Load in video segments from .pkl file':
            print("Loading video segments from file...")
            run_tagged_cells("load_video_segments")
            if save_masks_array:
                print("Saving masks array file...")
                run_tagged_cells("save_masks_array")
            if save_result_images:
                print("Saving results as images...")
                run_tagged_cells("save_result_images")
            print("Done")

        else:
            print("Please select a valid action.")

run_button.on_click(on_run_button_click)

display(widgets.VBox([dropdown, save_masks_array_checkbox, save_result_images_checkbox, run_button, output]))


### Optional visualization: Visualize masks from selected frames by changing frames_to_visualize

In [None]:
frames_to_visualize = [0, 5]
# visualize_masks(video_dir, frame_names, video_segments, frames_to_visualize=frames_to_visualize, background_color=0.3, return_mask_array=False, cmap=cmap_custom)
visualize_combined_masks(combined_masks, frames_to_visualize=frames_to_visualize, height=height, width=width, background_color=0.3, cmap = cmap)

## Export Results as Animated Video

Run the ffmpeg command below or in your terminal to stitch together combined_results into a video **at your fps**

In [None]:
original_dir = os.getcwd()
os.chdir(os.path.join(output_dir, 'combined_results'))

# change 80 to your fps
!ffmpeg -s 1920x1080 -r 80 -f image2 -i %06d.png -vcodec mpeg4 -q:v 5 -pix_fmt yuv420p -vf scale=1280:-2 ../animation.mp4

os.chdir(original_dir)

# Module 2: Particle Tracking
Analyze the motion of the objects you have segmented

### Particle Tracking Module Set-up

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.measure import label, regionprops
from skimage import measure, exposure
import os
from PIL import Image, ImageFilter, ImageDraw
from matplotlib import cm
import scipy.stats
import collections
import matplotlib.colors as clr
import matplotlib.pylab as pl
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import ListedColormap
from scipy.stats import norm
from scipy.stats import gamma
import statistics
from matplotlib.colors import Normalize, LinearSegmentedColormap

run_tagged_cells("setup_tracking_functions")

**Enter the following information below** regardless of whether you are using the Trajectory or Analysis section of this module: frames per second, height and width of your video in pixels, pixel size in your video (in nanometers), and folder paths

In [30]:
fps = 80
pixel_size = 2 # nm/px

experiment_name = 'example'
experiment_dir = os.path.join('Data', experiment_name)
output_dir = os.path.join(experiment_dir, 'output')

In [33]:
if using_colab == False:
  plt.rcParams['font.family'] = 'sans-serif'
  plt.rcParams['font.sans-serif'] = 'Arial'
# custom cmap
colors = [
    "#E69F00",  # Orange
    "#D55E00",   # Red-orange
    "#56B4E9",  # Sky Blue
    "#009E73",  # Green
    "#0072B2", # Blue
    "#CC79A7", # Pink
    "#F0E442", # Yellow
    "#B3D100",  # Yellow-green
    "#9E1B32"  # Dark red
]
cmap_custom = ListedColormap(colors, name='cmap_custom')

### Step 1: Trajectories
This section finds the centroid (x, y) and angle of each object in each frame of your mask array and outputs the trajectories (in nm) in a .csv file.

**Set the path to your masks array file (.npy)** to load it in. The array shape should be (frames, num_objects, height, width).
Skip the first cell if masks_array was already calculated during this kernel run

In [None]:
# Or, directly change masks_array_path to the path where the array is saved
masks_array_path = os.path.join(output_dir, f'{experiment_name}_masks.npy')
masks_array = np.load(masks_array_path)

frames, num_objects, h, w = masks_array.shape
print(masks_array.shape)

In [None]:
# Use regionprops on each mask in each frame to gather positional data
# Measure every mask in every frame with regionprops and collect one row per
# (frame, object): centroid in nm and orientation in degrees.
data = []
plt.figure(figsize=(8, 6))

for frame_number, mask_i in enumerate(masks_array):
    plt.text(15, 60, f'Frame {frame_number}', color='white', fontsize=12,
             bbox=dict(facecolor='black', alpha=0.5))

    # Draw the sum of this frame's object masks as a gray background
    plt.imshow(np.sum(mask_i, axis=0), cmap='gray')

    for mask_index, mask in enumerate(mask_i):
        labeled_mask = label(mask)  # connected components of this object's mask
        regions = regionprops(labeled_mask)
        if not regions:
            # Empty mask in this frame: keep the row, with NaN placeholders
            data.append([frame_number, mask_index, np.nan, np.nan, np.nan])
            continue

        # Use the largest connected component of the mask
        region = max(regions, key=lambda r: r.area)
        centroid = region.centroid
        y_pixels = centroid[0]
        x_pixels = centroid[1]
        x = x_pixels * pixel_size  # convert to nm
        y = y_pixels * pixel_size  # convert to nm
        angle = np.rad2deg(region.orientation)
        plt.plot(x_pixels, y_pixels, 'ro', label=f'Object {mask_index}', markersize=3)
        data.append([frame_number, mask_index, x, y, angle])


# Export results to a CSV whose name does not clash with an existing file
data = np.array(data)
df = pd.DataFrame(data, columns=["Frame", "Object", "x", "y", "Angle"])
trajectories_file = create_unique_file_path(os.path.join(output_dir, f'{experiment_name}_trajectories_SAM2.csv'))
df.to_csv(trajectories_file, index=False)

# Label the axes and show the accumulated figure
plt.xlabel("x (pixels)")
plt.ylabel("y (pixels)")
plt.show()

### Step 2: Motion Analysis

This section performs analyses (tMSD, distribution of displacements, autocorrelation, trajectory) and generates figures that save as PDFs to your output directory.
***Ensure that you have set the fps and pixel size in the "Particle Tracking Module Set-up" section above***

**Set the path to your trajectories file (.csv)**.

In [35]:
trajectories_file = os.path.join(output_dir, f'{experiment_name}_trajectories_SAM2.csv')
save_path = output_dir

Select which analyses you would like to perform:

In [None]:
df=pd.read_csv(trajectories_file)
df=df.dropna()

height_nm = height*pixel_size
width_nm = width*pixel_size

title_label = widgets.Label(value="Select the analyses you would like to perform:", style={'font_size': '20px'})
tmsd_checkbox = widgets.Checkbox(value=False, description='tMSD', layout=widgets.Layout(width='550px'))
trajectory_checkbox = widgets.Checkbox(value=False, description='Trajectory')
displacement_checkbox = widgets.Checkbox(value=False, description='Distribution of Displacements')
autocorrelation_checkbox = widgets.Checkbox(value=False, description='Autocorrelation')

run_button = widgets.Button(
    description = "Run",
    icon = 'check',
    layout=widgets.Layout(width='550px'),
    button_style = 'success'
)

output = widgets.Output()

def on_run_button_click(b):
    """Button callback: run each ticked analysis, in a fixed order.

    For every ticked checkbox a progress message is printed and the cells
    with the matching tag (``tMSD``, ``trajectory``, ``dist_disp``,
    ``acorr``) are run. Output is captured in ``output``.
    ``b`` is the clicked button (unused).
    """
    analyses = (
        (tmsd_checkbox, "Calculating tMSD...", "tMSD"),
        (trajectory_checkbox, "Calculating Trajectory...", "trajectory"),
        (displacement_checkbox, "Calculating Distribution of Displacements...", "dist_disp"),
        (autocorrelation_checkbox, "Calculating Autocorrelation...", "acorr"),
    )
    with output:
        output.clear_output()
        for checkbox, message, tag in analyses:
            if checkbox.value:
                print(message)
                run_tagged_cells(tag)

run_button.on_click(on_run_button_click)

widgets.VBox([title_label, tmsd_checkbox, trajectory_checkbox, displacement_checkbox, autocorrelation_checkbox, run_button, output])


# Dependencies

## Setup cells

#### Environment Set-up

In [None]:
using_colab = True
if using_colab:
    from google.colab import drive
    drive.mount('/content/drive')
    import os
    os.chdir(folder_path)
    from google.colab import output
    output.enable_custom_widget_manager()

    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam2.git'

    # !mkdir -p videos
    # !wget -P videos https://dl.fbaipublicfiles.com/segment_anything_2/assets/bedroom.zip
    # !unzip -d videos videos/bedroom.zip

    !mkdir -p ../checkpoints/
    !wget -P ../checkpoints/ https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

In [None]:
using_colab = False

In [None]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from PIL import Image
import pickle
from matplotlib.colors import ListedColormap

In [None]:
# select the device for computation
# Choose the compute backend: CUDA first, then Apple MPS, otherwise CPU
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # Enter bfloat16 autocast once and never exit, so it stays on for the whole notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # Enable TF32 on GPUs with compute capability >= 8 (Ampere and newer)
    # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

In [None]:
# Functions to create a unique folder or file path to avoid overwriting
def create_unique_folder(base_folder):
    """Create a folder without reusing an existing path, and return its path.

    ``base_folder`` is used as-is when free; otherwise ``_1``, ``_2``, ... is
    appended until an unused path is found. The chosen folder is created
    (including missing parents).
    """
    candidate = base_folder
    suffix = 0
    while os.path.exists(candidate):
        suffix += 1
        candidate = f"{base_folder}_{suffix}"

    os.makedirs(candidate)
    return candidate

def create_unique_file_path(base_path):
    """Return a file path that does not exist yet; nothing is created.

    ``base_path`` is returned unchanged when free; otherwise ``_1``, ``_2``,
    ... is inserted before the extension until an unused path is found.
    """
    stem, extension = os.path.splitext(base_path)
    candidate = base_path
    suffix = 0
    while os.path.exists(candidate):
        suffix += 1
        candidate = f"{stem}_{suffix}{extension}"
    return candidate

#### Loading SAM2

In [None]:
from sam2.build_sam import build_sam2_video_predictor

sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

#### Segmentation and Visualization Functions

In [None]:
def show_mask(mask, ax, obj_id=None, random_color=False, cmap = "Paired"):
    """Overlay a mask on ``ax`` as a semi-transparent coloured image.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to (height, width, 1) before colouring.
        ax: Axes-like object providing ``imshow``.
        obj_id: Object ID; colour index ``obj_id - 1`` is used (index 0 when
            None).
        random_color: If True, use a random RGB colour with alpha 0.6 and
            ignore ``cmap``.
        cmap: Colormap name or object passed to ``plt.get_cmap``.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        palette = plt.get_cmap(cmap)

        # In "Paired", indices 2 and 3 are both green; replace index 2 with purple
        if cmap == "Paired":
            rgba_table = palette(np.arange(palette.N))
            rgba_table[2] = mcolors.to_rgba((0.6510, 0.4588, 0.8353))
            palette = mcolors.ListedColormap(rgba_table)

        cmap_idx = 0 if obj_id is None else obj_id - 1
        # Keep the colormap's RGB but force alpha to 0.8
        color = np.array([*palette(cmap_idx)[:3], 0.8])

    h, w = mask.shape[-2:]
    ax.imshow(mask.reshape(h, w, 1) * color.reshape(1, 1, -1))

def show_points(coords, labels, ax, marker_size=200):
    """Scatter prompt points on ``ax``: label 1 as green stars, label 0 as red stars.

    ``coords`` is indexed as ``coords[mask][:, 0]`` / ``[:, 1]`` for x / y, and
    ``labels`` is compared element-wise against 1 and 0.
    """
    # Positive points are drawn first, then negative points
    for label_value, star_color in ((1, 'green'), (0, 'red')):
        chosen = coords[labels == label_value]
        ax.scatter(
            chosen[:, 0],
            chosen[:, 1],
            color=star_color,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Draw ``box`` (indexed as x0, y0, x1, y1) on ``ax`` as an unfilled green rectangle."""
    x_start, y_start, x_end, y_end = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_start, y_start),
        x_end - x_start,
        y_end - y_start,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)

In [None]:
# Handles user click inputs (Jupyter)
def get_click_coordinates(image, input_mode, point_type = True, plotting_args = None, cmap = "Paired"):
    """
    Collect mouse-click coordinates on an image shown in a separate Tk window (Jupyter).

    The image is displayed with the interactive ``tk`` backend. Every click inside
    the axes is recorded and marked with a red 'x'. Collection ends when the window
    is closed or the Enter key is pressed, after which the backend is switched
    back to ``inline``.

    Args:
        image: path/file of an image (opened with PIL) or an np.ndarray.
        input_mode: 'box' or 'points' (only used to choose the window title)
        point_type: True if defining a positive point, False if defining a negative point
            (only used to choose the window title)
        plotting_args: None, or a list of [prompts, out_mask_logits, out_obj_ids (np.array)];
            when given, the existing masks are overlaid on the image
        cmap: colormap passed on to show_mask for the overlaid masks

    Returns:
        np.ndarray built from the list of clicked (x, y) tuples, in data coordinates.
    """

    # Switch to an interactive backend so the figure opens in its own window
    %matplotlib tk
    if not isinstance(image, np.ndarray):
        image = Image.open(image)
        image_array = np.array(image)
    else:
        image_array = image

    # Display the image
    fig, ax = plt.subplots()
    ax.imshow(image_array, cmap='gray')
    # The title tells the user what to click; only lowercase 'points' / 'box' are matched here
    if input_mode == 'points':
        if point_type == True:
            ax.set_title("Select particle point(s), then close this window")
        if point_type == False:
            ax.set_title("Select negative point(s) to refine particle mask, then close this window")
    if input_mode == 'box':
        ax.set_title("Select two diagonal corners of the bounding box around the particle")

    # Overlay the masks already predicted so the user can refine them
    if plotting_args is not None:
        prompts, out_mask_logits, out_obj_ids = plotting_args
        for i, out_obj_id in enumerate(out_obj_ids):
            # show_points(*prompts[out_obj_id], plt.gca())
            show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id, cmap=cmap)

    # List to store coordinates of clicks
    clicked_coordinates = []

    # Define the click event function
    def onclick(event):
        # Only store coordinates if the click is within the image plot area
        if event.xdata is not None and event.ydata is not None:
            x, y = (event.xdata), (event.ydata)
            clicked_coordinates.append((x, y))

            # Show a red marker at the clicked location
            ax.plot(x, y, 'rx', markersize=6, markeredgewidth = 1)  # 'rx' draws a red x marker
            fig.canvas.draw()  # Redraw the figure to update the marker

    # Define the key event function to detect "Enter" key
    def on_key(event):
        nonlocal enter_pressed
        if event.key == 'enter':
            enter_pressed = True  # Set the flag to break out of the loop

    # Connect the click event and key event to the figure
    cid_click = fig.canvas.mpl_connect('button_press_event', onclick)
    cid_key = fig.canvas.mpl_connect('key_press_event', on_key)

    # Flag to check if "Enter" was pressed
    enter_pressed = False

    # Display the plot and wait for user input in a loop
    plt.show(block=False)

    # Poll GUI events until the window is closed or Enter is pressed
    while plt.fignum_exists(fig.number) and not enter_pressed:
        fig.canvas.flush_events()  # Keep the GUI responsive

    # Disconnect events after finishing
    fig.canvas.mpl_disconnect(cid_click)
    fig.canvas.mpl_disconnect(cid_key)

    # Restore the inline backend for the rest of the notebook
    %matplotlib inline

    # Return the collected coordinates
    return np.array(clicked_coordinates)

In [None]:
# Handles user click inputs (Colab)
from ipywidgets import Button, VBox, Output
global clicked_coordinates
def get_click_coordinates(image, input_mode, point_type = True, plotting_args = None, cmap = "Paired"):
    """
    Handles user click inputs in Colab. Changes the global variable clicked_coordinates

    Args:
        image: image as an image file or an np.array.
        input_mode: 'box' or 'points'
        point_type: True if defining a positive point, False if defining a negative point
        plotting args: list of [prompts, out_mask_logits, out_obj_ids(np.arrau)]
    """
    %matplotlib widget
    if not isinstance(image, np.ndarray):
        image = Image.open(image)
        image_array = np.array(image)
    else:
        image_array = image

    # Display the image
    fig, ax = plt.subplots()
    ax.imshow(image_array, cmap='gray')
    if input_mode == 'points':
        if point_type == True:
            ax.set_title("Select particle point(s), then press the Done button")
        if point_type == False:
            ax.set_title("Select negative point(s) to refine particle mask, then press the Done button")
    if input_mode == 'box':
        ax.set_title("Select two diagonal corners of the bounding box around the particle, then press the Done button")

    if plotting_args is not None:
        prompts, out_mask_logits, out_obj_ids = plotting_args
        for i, out_obj_id in enumerate(out_obj_ids):
            # show_points(*prompts[out_obj_id], plt.gca())
            show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id, cmap=cmap)

    # List to store coordinates of clicks
    global clicked_coordinates
    clicked_coordinates = []

    output = Output()

    def onclick(event):
      x, y = (event.xdata), (event.ydata)
      global clicked_coordinates
      clicked_coordinates.append((x, y))
      with output:
          print(f"Clicked at: x={event.xdata:.2f}, y={event.ydata:.2f}")
          ax.plot(x, y, 'rx', markersize=6, markeredgewidth = 1)
          fig.canvas.draw()

    def on_done_clicked(b):
      with output:
          print("Done! Here are the collected points:")
          print(clicked_coordinates)
      plt.close()

    # Create a "done" button
    done_button = Button(description="Done", button_style='success')
    done_button.on_click(on_done_clicked)



    # Define the key event function to detect "Enter" key
    def on_key(event):
        nonlocal enter_pressed
        if event.key == 'enter':
            enter_pressed = True

    # Connect the click event and key event to the figure
    cid_click = fig.canvas.mpl_connect('button_press_event', onclick)
    cid_key = fig.canvas.mpl_connect('key_press_event', on_key)

    # Flag to check if "Enter" was pressed
    enter_pressed = False
    display(VBox([done_button, output]))


    plt.show()

    %matplotlib inline

    # # Return the collected coordinates
    # return np.array(clicked_coordinates)

In [None]:
def visualize_masks(
    video_dir, frame_names, video_segments, frames_to_visualize=None, vis_frame_stride=1, background_color=None, return_mask_array=False, visualize=True, cmap = None
):
    """
    Visualize masks for specified frames using SAM2 outputted video segments.

    Args:
        video_dir (str): Directory containing the video frames.
        frame_names (list): List of frame file names.
        video_segments: Per-frame masks, indexed by frame index; each entry is a
            dictionary mapping object id -> mask.
        frames_to_visualize (int, list, optional): Specific frame(s) to process.
            If None, process all frames with the given stride.
        vis_frame_stride (int): Stride for processing frames (if `frames_to_visualize` is None).
        background_color: If defined, masks will be shown on a grey background (value 0 to 1)
            instead of on the original image.
        return_mask_array: If true, return an array with each mask for the processed frames
        visualize: If true, show the masks with matplotlib
        cmap: Colormap passed to show_mask. If None, 'Paired' is used.

    Returns:
        mask_array (if return_mask_array is true): array of shape (frames, num_objects, height, width);
        otherwise None.

    Raises:
        ValueError: if `frames_to_visualize` is not None, an int, or a list.
    """

    # Determine frames to process
    if frames_to_visualize is None:
        frames_to_process = range(0, len(frame_names), vis_frame_stride)
    elif isinstance(frames_to_visualize, int):
        frames_to_process = [frames_to_visualize]
    elif isinstance(frames_to_visualize, list):
        frames_to_process = frames_to_visualize
    else:
        raise ValueError("frames_to_visualize must be None, an int, or a list of ints.")

    # Plotting resources are only prepared when something will actually be drawn,
    # so collecting a mask array (visualize=False) does not read any image from disk.
    background = None
    if visualize:
        if background_color is not None:
            # Load the first image to determine dimensions of the grey background
            with Image.open(os.path.join(video_dir, frame_names[0])) as first_image:
                image_height, image_width = np.array(first_image).shape[:2]
            background = np.full((image_height, image_width), fill_value=background_color, dtype=float)

        if cmap is None:
            cmap = plt.get_cmap('Paired')

    all_masks = []
    for out_frame_idx in frames_to_process:
        if out_frame_idx < 0 or out_frame_idx >= len(frame_names):
            print(f"Skipping invalid frame index: {out_frame_idx}")
            continue

        frame_segments = video_segments[out_frame_idx]

        if return_mask_array:
            # One squeezed mask per object, in the dictionary's iteration order
            all_masks.append([np.squeeze(out_mask) for out_mask in frame_segments.values()])

        if visualize:
            plt.figure(figsize=(6, 4))
            plt.title(f"Frame {out_frame_idx}")

            # Show the frame (original image or grey background)
            if background is None:
                frame_path = os.path.join(video_dir, frame_names[out_frame_idx])
                with Image.open(frame_path) as frame_image:
                    plt.imshow(frame_image)
            else:
                plt.imshow(background, cmap = 'gray', vmin = 0, vmax = 1)

            # Overlay the masks
            for out_obj_id, out_mask in frame_segments.items():
                show_mask(out_mask, plt.gca(), obj_id=out_obj_id, cmap = cmap)

            plt.show()

    if return_mask_array:
        return np.array(all_masks)

In [None]:
def save_masks_and_results(video_dir, frame_names, video_segments, vis_frame_stride=1, background_color=0.3, return_mask_array=False, output_dir = None, cmap = None):
    """
    Save mask-only images and side-by-side (original | masks) images, each in a new unique folder.

    Two folders are created under `output_dir`: 'masks' and 'combined_results'
    (a numeric suffix is appended if they already exist). Files are named by the
    zero-padded frame index, e.g. 000012.png.

    Args:
        video_dir (str): Directory containing the video frames.
        frame_names (list): List of frame file names.
        video_segments: Per-frame masks, indexed by frame index; each entry is a
            dictionary mapping object id -> mask.
        vis_frame_stride (int): Stride for saving frames. Defaults to 1.
        background_color (float): Gray background intensity. Defaults to 0.3.
        return_mask_array: Will return all_masks (shape = frames, obj, h, w) if true.
            Masks of ALL frames are collected, independent of `vis_frame_stride`.
        output_dir: Directory where the folders are created. Defaults to `video_dir`.
        cmap: Defines colors of masks (indexed by obj_id). If none defined, 'Paired' will be used
    Returns:
        mask_array (if return_mask_array is true): array of shape (frames, num_objects, height, width);
        otherwise None.

    """
    all_masks = []

    if output_dir is None:
        output_dir = video_dir

    masks_folder = create_unique_folder(os.path.join(output_dir, 'masks'))
    combined_folder = create_unique_folder(os.path.join(output_dir, 'combined_results'))

    # Load the first image to determine dimensions (closed again right away)
    with Image.open(os.path.join(video_dir, frame_names[0])) as first_image:
        image_height, image_width = np.array(first_image).shape[:2]

    # Background for mask visualization
    background = np.full((image_height, image_width), fill_value=background_color, dtype=float)

    # Colormap
    if cmap is None:
        cmap = plt.get_cmap('Paired')

    # Frames for which figures are written
    frames_to_save = range(0, len(frame_names), vis_frame_stride)

    # Iterate through frames
    for out_frame_idx in range(len(frame_names)):
        frame_segments = video_segments[out_frame_idx]

        if return_mask_array:
            # One squeezed mask per object, in the dictionary's iteration order
            all_masks.append([np.squeeze(out_mask) for out_mask in frame_segments.values()])

        if out_frame_idx not in frames_to_save:
            continue

        fig, axes = plt.subplots(1, 2, figsize=(10, 6))

        # Display original image on left subplot
        with Image.open(os.path.join(video_dir, frame_names[out_frame_idx])) as original_image:
            axes[0].imshow(original_image, cmap='gray')
        axes[0].axis('off')

        # Display colored masks on a grey background on right subplot
        axes[1].imshow(background, cmap='gray', vmin=0, vmax=1)
        for out_obj_id, out_mask in frame_segments.items():
            show_mask(out_mask, axes[1], obj_id=out_obj_id, cmap=cmap)
        axes[1].axis('off')

        # Create file paths
        file_name = f"{str(out_frame_idx).zfill(6)}.png"
        combined_image_path = os.path.join(combined_folder, file_name)
        mask_image_path = os.path.join(masks_folder, file_name)

        # Save the combined figure
        fig.subplots_adjust(wspace=0)
        fig.savefig(combined_image_path, bbox_inches='tight', pad_inches=0)

        # Save only the right subplot (mask on grey background) by cropping to its extent
        bbox = axes[1].get_window_extent().transformed(fig.dpi_scale_trans.inverted())
        fig.savefig(mask_image_path, bbox_inches=bbox, pad_inches=0)

        # Close this specific figure so figures do not accumulate across frames
        plt.close(fig)

        if out_frame_idx % 20 == 0:
            print(f"Combined image saved: {combined_image_path}")
            print(f"Mask image saved: {mask_image_path}")

    if return_mask_array:
        return np.array(all_masks)

The functions below work with 'combined masks': one image per frame that contains the masks of all objects. This is in contrast to masks_array, where each object in a frame has its own individual array.

In [None]:
# Create combined masks for each frame: shape (frames, num_obj, h, w) --> (frames, h, w)
def put_per_obj_mask(masks_array, height, width):
    """Combine per-object masks into a single labelled mask for each frame.

    Args:
        masks_array: Array (or nested list) of shape (frames, num_objects, height, width).
            A 3-D input of shape (frames, height, width) is treated as a single object.
        height (int): Height of each mask.
        width (int): Width of each mask.

    Returns:
        np.ndarray of shape (frames, height, width). Pixels are 0 for background
        and ``object index + 1`` where that object's mask is non-zero. Where masks
        overlap, the object with the highest index wins. The dtype is uint8 for up
        to 255 objects and a wider unsigned type otherwise.

    Raises:
        ValueError: if the input is neither 3-D nor 4-D (and not empty).
    """
    masks = np.asarray(masks_array)

    # No frames at all: nothing to combine
    if masks.shape[0] == 0:
        return np.zeros((0, height, width), dtype=np.uint8)

    if masks.ndim == 3:
        # Single-object data without an object axis: add one
        masks = masks[:, np.newaxis]
    elif masks.ndim != 4:
        raise ValueError(
            f"masks_array must have 3 or 4 dimensions, got shape {masks.shape}"
        )

    num_frames, num_objects = masks.shape[:2]
    # Smallest unsigned type that can hold the largest label (uint8 up to 255 objects)
    label_dtype = np.min_scalar_type(num_objects)
    combined_masks = np.zeros((num_frames, height, width), dtype=label_dtype)

    # Later objects overwrite earlier ones where masks overlap
    for obj_idx in range(num_objects):
        combined_masks[masks[:, obj_idx] > 0] = obj_idx + 1

    return combined_masks

# Shows a single frame's combined mask
def show_combined_mask(mask, ax, obj_id=None, cmap="Paired"):
    """Draw one object's mask on ``ax`` using the same colour lookup as show_mask (alpha 0.9).

    Args:
        mask: Array whose last two dimensions are (height, width).
        ax: Matplotlib axes to draw on.
        obj_id: Object id; colour index ``obj_id - 1`` is used (index 0 if None).
        cmap: Colormap name or object used to look up the object colour.
    """
    palette = plt.get_cmap(cmap)

    # In "Paired", indices 2 and 3 are both green: swap index 2 for the purple of Paired
    if cmap == "Paired":
        swatches = palette(np.arange(palette.N))
        swatches[2] = mcolors.to_rgba((0.6510, 0.4588, 0.8353))
        palette = mcolors.ListedColormap(swatches)

    # First colour when no obj_id is given, otherwise obj_id mapped to a 0-based index
    slot = obj_id - 1 if obj_id is not None else 0
    rgba = np.array([*palette(slot)[:3], .9])

    rows, cols = mask.shape[-2:]
    # Broadcast (h, w, 1) * (1, 1, 4) -> RGBA image that is zero wherever the mask is zero
    ax.imshow(mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1))


# Displays the combined masks for the specified frames
def visualize_combined_masks(
    combined_masks,
    frames_to_visualize=None,
    vis_frame_stride=1,
    height=None,
    width=None,
    background_color=0.3,
    cmap="Paired"
):
    """
    Visualize combined masks for specified frames.

    Args:
        combined_masks (np.ndarray): Combined masks array with shape (frames, height, width).
        frames_to_visualize (int, list, optional): Specific frame(s) to visualize.
            If None, visualize all frames with the given stride.
        vis_frame_stride (int): Stride for visualizing frames (if `frames_to_visualize` is None).
        height (int, optional): Height of the background. If None, taken from `combined_masks`.
        width (int, optional): Width of the background. If None, taken from `combined_masks`.
        background_color (float): Gray background intensity (0 to 1). Defaults to 0.3.
            Pass None to draw the masks without a background.
        cmap (str): Colormap to use for object colors. Defaults to "Paired".

    Returns:
        None

    Raises:
        ValueError: if `frames_to_visualize` is not None, an int, or a list.
    """
    # Determine frames to visualize
    if frames_to_visualize is None:
        frames_to_process = range(0, combined_masks.shape[0], vis_frame_stride)
    elif isinstance(frames_to_visualize, int):
        frames_to_process = [frames_to_visualize]
    elif isinstance(frames_to_visualize, list):
        frames_to_process = frames_to_visualize
    else:
        raise ValueError("frames_to_visualize must be None, an int, or a list of ints.")

    background = None
    if background_color is not None:
        # Fall back to the mask dimensions so the defaults (height=None, width=None) work
        if height is None:
            height = combined_masks.shape[-2]
        if width is None:
            width = combined_masks.shape[-1]
        background = np.full((height, width), fill_value=background_color, dtype=float)

    # Visualize each specified frame
    for frame_idx in frames_to_process:
        if frame_idx < 0 or frame_idx >= combined_masks.shape[0]:
            print(f"Skipping invalid frame index: {frame_idx}")
            continue

        plt.figure(figsize=(6, 4))
        plt.title(f"Combined Mask - Frame {frame_idx}")
        ax = plt.gca()

        if background is not None:
            ax.imshow(background, cmap="gray", vmin=0, vmax=1)

        # Overlay one binary mask per label value; label 0 is the background
        frame_labels = combined_masks[frame_idx]
        for obj_id in np.unique(frame_labels):
            if obj_id != 0:
                obj_mask = (frame_labels == obj_id).astype(np.uint8)
                show_combined_mask(obj_mask, ax, obj_id=obj_id, cmap=cmap)

        plt.axis("off")
        plt.show()

## Segmentation - Module 1

####  Prompt Inputs - getting user input from clicks

In [None]:
# Code for adding prompts (Jupyter)
# Collects clicks for the current object/frame, sends them to the predictor as a box
# or as points, and shows the resulting masks on the annotated frame.
# Reads variables that this cell does not define (input_mode, point_type, ann_frame_idx,
# ann_obj_id, old_frame_idx, old_obj_id, prompts, inference_state, cmap_custom, ...);
# presumably they are set by earlier cells -- confirm.
if input_mode == "box" or input_mode == "Box":
    if ann_frame_idx == old_frame_idx: # there should be associated predictions to show
        clicked_corners = get_click_coordinates(os.path.join(video_dir, frame_names[ann_frame_idx]), input_mode = input_mode, plotting_args = [None, out_mask_logits, out_obj_ids], cmap=cmap_custom)
    else:
        clicked_corners = get_click_coordinates(os.path.join(video_dir, frame_names[ann_frame_idx]), input_mode = input_mode, plotting_args = None, cmap=cmap_custom)

    # The box is the extent of all clicked points, so the click order does not matter.
    # NOTE(review): with no clicks, clicked_corners is empty and the indexing below raises.
    x_min, x_max = np.min(clicked_corners[:,0]), np.max(clicked_corners[:,0])
    y_min, y_max = np.min(clicked_corners[:,1]), np.max(clicked_corners[:,1])

    box = np.array([x_min, y_min, x_max, y_max], dtype = np.float32)

    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        box=box,
    )

if input_mode == 'points' or input_mode == "Points":
    # Start a fresh point list when the object or the frame changed since the last run
    if ann_obj_id != old_obj_id or ann_frame_idx != old_frame_idx:
        points = np.array([])
        labels = np.array([])
        old_obj_id = ann_obj_id

    # If any predictions have been made on this frame, then show them. Otherwise, just show the image
    if ann_frame_idx == old_frame_idx: # there should be associated predictions to show
        new_point = get_click_coordinates(os.path.join(video_dir, frame_names[ann_frame_idx]), input_mode = input_mode, point_type = point_type, plotting_args = [prompts, out_mask_logits, out_obj_ids], cmap=cmap_custom)
    else:
        new_point = get_click_coordinates(os.path.join(video_dir, frame_names[ann_frame_idx]), input_mode = input_mode, point_type = point_type, plotting_args = None, cmap=cmap_custom)

    num_inputs = len(new_point)

    # Accumulate the new clicks with earlier ones; label 1 for positive, 0 for negative points
    points = new_point if points.shape[0] == 0 else np.vstack((points, new_point))
    if labels.shape[0] == 0:
        labels = np.ones(num_inputs) if point_type else np.zeros(num_inputs)
    else:
        labels = np.append(labels, np.ones(num_inputs) if point_type else np.zeros(num_inputs))

    prompts[ann_obj_id] = points, labels
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )

# Remember the annotated frame so the next run can overlay these predictions
old_frame_idx = ann_frame_idx

# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
for i, out_obj_id in enumerate(out_obj_ids):
    # show_box(box, plt.gca())
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id, cmap=cmap_custom)

In [None]:
# Code for getting click input in Colab - the coordinates are appended to the global variable clicked_coordinates
# The Colab get_click_coordinates returns None, so clicked_corners / new_point assigned here
# are None; the following cell rebuilds them from clicked_coordinates.
if input_mode == "box" or input_mode == "Box":
    if ann_frame_idx == old_frame_idx: # there should be associated predictions to show
        clicked_corners = get_click_coordinates(os.path.join(video_dir, frame_names[ann_frame_idx]), input_mode = input_mode, plotting_args = [None, out_mask_logits, out_obj_ids], cmap=cmap_custom)
    else:
        clicked_corners = get_click_coordinates(os.path.join(video_dir, frame_names[ann_frame_idx]), input_mode = input_mode, plotting_args = None, cmap=cmap_custom)


if input_mode == 'points' or input_mode == "Points":
    # Start a fresh point list when the object or the frame changed since the last run
    if ann_obj_id != old_obj_id or ann_frame_idx != old_frame_idx:
        points = np.array([])
        labels = np.array([])
        old_obj_id = ann_obj_id

    # If any predictions have been made on this frame, then show them. Otherwise, just show the image
    if ann_frame_idx == old_frame_idx: # there should be associated predictions to show
        new_point = get_click_coordinates(os.path.join(video_dir, frame_names[ann_frame_idx]), input_mode = input_mode, point_type = point_type, plotting_args = [prompts, out_mask_logits, out_obj_ids], cmap=cmap_custom)
    else:
        new_point = get_click_coordinates(os.path.join(video_dir, frame_names[ann_frame_idx]), input_mode = input_mode, point_type = point_type, plotting_args = None, cmap=cmap_custom)

In [None]:
# Code for adding prompts from click inputs in Colab
# Reads the clicks stored in the global clicked_coordinates, sends them to the predictor
# as a box or as points, and shows the resulting masks on the annotated frame.
if input_mode == "box" or input_mode == "Box":
    clicked_corners = np.array(clicked_coordinates)
    plt.close('all')

    # The box is the extent of all clicked points, so the click order does not matter.
    # NOTE(review): with no clicks, clicked_corners is empty and the indexing below raises.
    x_min, x_max = np.min(clicked_corners[:,0]), np.max(clicked_corners[:,0])
    y_min, y_max = np.min(clicked_corners[:,1]), np.max(clicked_corners[:,1])

    box = np.array([x_min, y_min, x_max, y_max], dtype = np.float32)

    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        box=box,
    )

if input_mode == 'points' or input_mode == "Points":
    new_point = np.array(clicked_coordinates)
    plt.close('all')

    # Start a fresh point list when the object or the frame changed since the last run
    if ann_obj_id != old_obj_id or ann_frame_idx != old_frame_idx:
        points = np.array([])
        labels = np.array([])
        old_obj_id = ann_obj_id

    num_inputs = len(new_point)

    # Accumulate the new clicks with earlier ones; label 1 for positive, 0 for negative points
    points = new_point if points.shape[0] == 0 else np.vstack((points, new_point))
    if labels.shape[0] == 0:
        labels = np.ones(num_inputs) if point_type else np.zeros(num_inputs)
    else:
        labels = np.append(labels, np.ones(num_inputs) if point_type else np.zeros(num_inputs))

    prompts[ann_obj_id] = points, labels
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )

# Remember the annotated frame so the next run can overlay these predictions
old_frame_idx = ann_frame_idx

# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
for i, out_obj_id in enumerate(out_obj_ids):
    # show_box(box, plt.gca())
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id, cmap=cmap_custom)

#### Propagation

In [None]:
# run propagation throughout the video and collect the results in a dictionary

video_segments = {}  # video_segments contains the per-frame segmentation results
i = 0  # counts the frames yielded so far; drives the visualization and checkpoint strides
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    # Threshold the logits at 0 to get one boolean mask per object id.
    # The `i` inside the comprehension is local to it and does not change the counter above.
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

    if i%vis_frame_stride == 0:
        visualize_masks(video_dir, frame_names, video_segments, frames_to_visualize=[out_frame_idx], cmap=cmap_custom)

    # Periodic checkpoint of everything segmented so far. Each checkpoint gets a new unique
    # file name, so earlier temporary files are kept rather than overwritten.
    if i%save_frame_stride==0:
        print(output_dir)
        segments_path_temp = create_unique_file_path(os.path.join(output_dir, 'video_segments_temp.pkl'))
        print(segments_path_temp)
        with open(segments_path_temp, 'wb') as f:
            pickle.dump(video_segments, f)
        print(f'{out_frame_idx} Video segments saved successfully!')
    i+=1


# after propagation is done, save video_segments to output_dir
segments_path = create_unique_file_path(os.path.join(output_dir, f'{experiment_name}_video_segments_final.pkl'))
with open(segments_path, 'wb') as f:
    pickle.dump(video_segments, f)
print("Video segments saved successfully!")

# Save masks and results to folder
masks_array = save_masks_and_results(video_dir, frame_names, video_segments, return_mask_array=True, output_dir = output_dir, cmap = cmap_custom)
print(f'Shape of masks_array = {masks_array.shape}')

# Save masks_array to output_dir for future tracking analysis
masks_array_save_path = create_unique_file_path(os.path.join(output_dir, f'{experiment_name}_masks.npy'))
np.save(masks_array_save_path, masks_array)

# Create combined_masks array for visualization
# NOTE(review): height and width are not defined in this cell; presumably set by an earlier cell -- confirm.
combined_masks = put_per_obj_mask(masks_array, height=height, width=width)

#### Loading Data (video_segments or mask_array) back in to output results (mask_array and/or segmented images)

In [None]:
# Load in video segments from segments_path
# NOTE(review): pickle.load can execute arbitrary code; only load files you created yourself.
with open(segments_path, 'rb') as f:
    video_segments = pickle.load(f)
print(f'{len(video_segments)} video segments loaded successfully!')

# get the frame names again so they can be inputs to visualization functions
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
# Sort numerically; this requires the file names (without extension) to be integers
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# Take the first N frames, where N = length of video_segments (i.e. if video_segments only contains a subset of the dataset)
frame_names = frame_names[:len(video_segments)]


# Create mask_array (visualize=False: nothing is plotted, only the masks are collected)
masks_array = visualize_masks(video_dir, frame_names, video_segments, background_color=0.3, visualize=False, return_mask_array=True, cmap=cmap_custom)

# assign number of frames, objects, and height and width of images
if len(masks_array.shape) == 4: # if multiple particles
    num_frames, num_objects, height, width = masks_array.shape
else: # if 1 particle
    num_objects=1
    num_frames, height, width = masks_array.shape
print(f'Number of frames: {num_frames}')
print(f'Number of objects: {num_objects}')
print(f'Image height: {height}')
print(f'Image width: {width}')

# Merge the per-object masks into one labelled image per frame
combined_masks = put_per_obj_mask(masks_array, height=height, width=width)

In [None]:
# Load in masks_array from masks_array_path
# NOTE(review): masks_array_path is not defined in this cell; presumably set by the user in an earlier cell -- confirm.
masks_array = np.load(masks_array_path)

# assign number of frames, objects, and height and width of images
if len(masks_array.shape) == 4: # if multiple particles
    num_frames, num_objects, height, width = masks_array.shape
else: # if 1 particle
    num_objects=1
    num_frames, height, width = masks_array.shape
print(f'Number of frames: {num_frames}')
print(f'Number of objects: {num_objects}')
print(f'Image height: {height}')
print(f'Image width: {width}')

# Merge the per-object masks into one labelled image per frame
combined_masks = put_per_obj_mask(masks_array, height=height, width=width)

In [None]:
# Save mask array within output_dir (a numeric suffix is added if the file already exists)
masks_array_save_path = create_unique_file_path(os.path.join(output_dir, f'{experiment_name}_masks.npy'))
np.save(masks_array_save_path, masks_array)
print(f'Masks saved to {masks_array_save_path}')

In [None]:
# Save masks and results to folder
# Writes the mask-only and side-by-side images under output_dir; return_mask_array=False, so nothing is returned.
save_masks_and_results(video_dir, frame_names, video_segments, return_mask_array=False, output_dir = output_dir, cmap = cmap_custom)


## Particle Motion Analysis Functions - Module 2

TMSD

In [None]:
def get_msd(x):
    """Return the time-averaged mean squared displacement of a 1-D trajectory.

    Entry ``k`` of the result is the average of ``(x[t + lag] - x[t])**2`` over
    all ``t`` for ``lag = k + 1``, with lags running from 1 to ``len(x) - 1``.
    """
    return np.array([
        np.average((x[lag:] - x[:-lag]) ** 2)
        for lag in range(1, len(x))
    ])
def get_2d_msd(x, y):
    """Return the 2-D mean squared displacement as the sum of the per-axis MSDs."""
    msd_x = get_msd(x)
    msd_y = get_msd(y)
    return msd_x + msd_y

In [None]:
fig, ax = plt.subplots(figsize=(6,4))

# Iterate through each unique object in the dataset
for obj in df['Object'].unique():

    df_obj = df[df['Object'] == obj]
    x = np.array(df_obj['x'])
    y = np.array(df_obj['y'])
    frame = np.array(df_obj['Frame'])
    angle = np.array(df_obj['Angle'])
    # NOTE(review): this is the MSD of the radial distance sqrt(x^2 + y^2) from the origin,
    # not the 2-D MSD that get_2d_msd(x, y) computes -- confirm which one is intended.
    msd1 = get_msd(np.sqrt(x**2 + y**2))

    # Lag times 1/fps, 2/fps, ... (one entry per row of df_obj)
    time = [1/fps]
    for i in range(1, len(df_obj)):
        time.append(time[i-1] + 1/fps)
    angle_time=time
    # get_msd returns len(x) - 1 lags, so trim the time axis to match
    time = time[:len(x)-1]

    # Label every series so that plt.legend() below has entries to show
    ax.scatter(time, msd1, lw=3, color=cmap_custom(int(obj)), label=f'Object {obj}')
ax.set_yscale('log')
ax.set_xscale('log')
plt.legend()
plt.title(f'MSD', fontsize=20)
ax.set_xlabel(r'$\tau$ (s)',fontsize=20)
ax.set_ylabel(r'$ \overline{\delta r^2(\tau)} $ (nm$^2$)',fontsize=20)
plt.yticks(fontsize=20)
plt.xticks(fontsize=20)

tMSD_file = f'{experiment_name}_tMSD.pdf'
plt.savefig(save_path+'/'+tMSD_file, format ="pdf", bbox_inches='tight')


Trajectory

In [None]:
# Function to create a colormap from white to a target color
def create_colormap(color):
    """Return a LinearSegmentedColormap named "custom_cmap" that runs from white to ``color``."""
    return LinearSegmentedColormap.from_list("custom_cmap", ["white", color])

# Create one white-to-color colormap per entry of `colors`
# NOTE(review): `colors` is not defined in this cell; presumably one color per object -- confirm.
cmaps = [create_colormap(color) for color in colors]

fig, ax = plt.subplots(figsize=(6,5))

# zip stops at the shorter sequence, so objects beyond len(cmaps) would not be plotted
for obj, cmap_i in zip(df['Object'].unique(), cmaps):

    df_obj = df[df['Object'] == obj]
    x = np.array(df_obj['x'])
    y = np.array(df_obj['y'])
    frame = np.array(df_obj['Frame'])
    angle = np.array(df_obj['Angle'])

    # Time stamps 0, 1/fps, ..., (n-1)/fps, one per row of df_obj
    t = np.linspace(0, (len(df_obj) - 1) / fps, len(df_obj))

    # Points are colored by time; each connecting segment takes the color of its starting point
    scat = ax.scatter(x,y,c=t, s=3, zorder=1, cmap=cmap_i)
    for i in range(len(x) - 1):
        color = scat.to_rgba(t[i])
        ax.plot([x[i], x[i+1]], [y[i], y[i+1]], color=color, lw=2, zorder=0)

#fig.colorbar(scat).set_label(label='t(s)',size=20, rotation=0)
# `scat` is the scatter of the last object in the loop, so the color bar uses that object's colormap
cbar = plt.colorbar(scat, ax=ax)  # Add color bar with specific axes
cbar.set_label('t(s)', size=20, rotation=0, labelpad=15)

ax.set_xlabel('x', fontsize =20)
ax.set_ylabel('y', fontsize=20)
plt.gca().invert_yaxis()
plt.title(f'Trajectories', fontsize=20)
# plt.ylim([0, Nc])
# plt.xlim([0,Nl])
ax.set_aspect('equal', 'datalim')
plt.tight_layout()
# Fix the view to the full field of view; the y limits are given in reversed order (height_nm down to 0)
plt.xlim((0, width_nm))
plt.ylim((height_nm,0))

traj_file = f'{experiment_name}_Trajectory.pdf'
plt.savefig(save_path+'/'+traj_file, format ="pdf", bbox_inches='tight')


Distribution of Displacements

In [None]:
def displacement(x):
    """Return the list of successive differences ``x[i] - x[i-1]`` for ``i = 1 .. len(x) - 1``."""
    return [current - previous for previous, current in zip(x[:-1], x[1:])]

def r2(x,y):
    """Return the Euclidean distance from the origin, ``sqrt(x**2 + y**2)``.

    Despite the name, the value returned is r, not r squared.
    """
    return np.sqrt(x**2 + y**2)

def plot_hist(ax, r, title="", text="", color="", bins=92):
    """Plot the peak-normalised distribution of step-to-step changes in ``r``.

    ``r`` is differenced inside this function (``np.diff``), so pass the
    position-like series itself rather than pre-computed displacements.
    The histogram of those differences is scaled so that its maximum is 1
    and drawn as a line through the bin centres.

    Parameters:
        ax: object with a Matplotlib-style ``plot`` method to draw on.
        r: array-like series whose successive differences are histogrammed.
        title: accepted for backward compatibility; currently not used.
        text: legend label for the curve.
        color: line colour. An empty string (the default) lets Matplotlib
            pick the next colour of its cycle, because "" itself is not a
            valid colour specification.
        bins: forwarded to ``np.histogram`` (default 92).

    Returns:
        ``ax``, to allow chaining. Nothing is drawn when ``r`` has fewer
        than two samples, since there is no difference to histogram.
    """
    steps = np.diff(r)
    if steps.size == 0:
        return ax

    if isinstance(color, str) and color == "":
        color = None

    counts, edges = np.histogram(steps, bins=bins, density=True)
    counts = counts / np.max(counts)
    centers = (edges[:-1] + edges[1:]) / 2
    ax.plot(centers, counts, color=color, linewidth=4, label=text)
    return ax

In [None]:
fig, ax = plt.subplots(figsize=(6, 4))

for obj in df['Object'].unique():
    df_obj = df[df['Object'] == obj]
    x = np.array(df_obj['x'])
    y = np.array(df_obj['y'])

    # r2 gives the distance from the coordinate origin for every sample.
    r = r2(x, y)

    # plot_hist differences its input itself (np.diff), so it must receive
    # r, not displacement(r): passing pre-computed displacements would
    # histogram the second differences instead of the displacements.
    plot_hist(ax, r, title="", text="", color=cmap_custom(int(obj)))

plt.xlabel('Δr', fontsize=20)
plt.ylabel('Normalized PDF', fontsize=20)
plt.title('Distribution of Displacements', fontsize=20)
plt.yticks(fontsize=20)
plt.xticks(fontsize=20)
plt.xlim([-50, 50])

distdisp_file = f'{experiment_name}_Distribution of Displacements.pdf'
plt.savefig(save_path + '/' + distdisp_file, format="pdf", bbox_inches='tight')

Autocorrelation

In [None]:
def acf(x, length=200):
    """Return the autocorrelation of the successive differences of ``x``.

    ``x`` is differenced first; entry ``k`` of the result is then the
    Pearson correlation (``np.corrcoef``) between the differenced series
    and itself shifted by ``k`` samples. Entry 0 is 1 by definition.

    Parameters:
        x: array-like series of samples (lists are accepted).
        length: number of lags to return, including lag 0.

    Returns:
        Float array of ``max(length, 1)`` values. Lags that leave fewer than
        two overlapping samples cannot be correlated and are NaN; they are
        filled in directly rather than computed, so they raise no NumPy
        warnings.
    """
    steps = np.diff(np.asarray(x, dtype=float), axis=0)

    n_lags = max(length, 1)
    result = np.full(n_lags, np.nan)
    result[0] = 1.0

    # A correlation coefficient needs at least two overlapping samples.
    max_lag = min(n_lags - 1, len(steps) - 2)
    for lag in range(1, max_lag + 1):
        result[lag] = np.corrcoef(steps[:-lag], steps[lag:])[0, 1]
    return result

In [None]:
plt.figure(figsize=(6, 4))

max_lags = 50  # number of lags shown on the plot

for obj in df['Object'].unique():
    df_obj = df[df['Object'] == obj]
    x = np.array(df_obj['x'])
    y = np.array(df_obj['y'])

    r = r2(x, y)

    # Only the first `max_lags` lags are plotted, so compute no more than
    # that. A lag needs at least two overlapping differences, which limits
    # short tracks to len(r) - 2 usable entries (lag 0 included).
    n_lags = max(1, min(max_lags, len(r) - 2))
    z_acorr = acf(np.asarray(r), length=n_lags)

    # The axis is labelled in seconds, so convert lag indices (frames) to
    # time using the acquisition frame rate.
    lag_times = np.arange(len(z_acorr)) / fps
    plt.plot(lag_times, z_acorr, color=cmap_custom(int(obj)))

plt.title('Autocorrelation', fontsize=20)
plt.yticks(fontsize=20)
plt.xticks(fontsize=20)
plt.xlabel(r'lag time $\tau$ (s)', fontsize=20)
plt.ylabel(r'$C_v(\tau)$', fontsize=20)
aCorr_file = f'{experiment_name}_Autocorrelation.pdf'
plt.savefig(save_path + '/' + aCorr_file, format="pdf", bbox_inches="tight")
