---
# Cellpose segmentation and BTrack tracking
---

Author: Ryan Corbyn,   
Date: 06/01/2023.

-----
#### Script Description:  
This program segments and tracks cells in phase-contrast time-lapse data.  

The script can handle time-lapse data stored as either: 
1. A folder containing the time-lapse images as individual single frames. In this case the naming convention must be chosen carefully so that the images are read in time order.
2. A single .tif file that contains all of the time-lapse images saved as a 3D image, with the image shape being (time, height, width).

**A single stack image must be saved as .tif file type.**
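When frames are read from a folder in file-name order, a plain alphabetical sort puts `frame_10.tif` before `frame_2.tif`. Zero-padding the frame numbers (`frame_002.tif`, `frame_010.tif`) avoids this. As an alternative, the sketch below shows a natural sort key that compares embedded numbers by value; `natural_sort_key` is a hypothetical helper, not part of this script.

```python
import re

def natural_sort_key(file_name):
    '''Split a file name into text and integer chunks so that numbers
    compare by value (2 < 10) rather than character by character.'''
    # re.split with a capture group alternates text / digit chunks,
    # so text is always compared with text and numbers with numbers.
    return [int(chunk) if chunk.isdigit() else chunk.lower()
            for chunk in re.split(r'(\d+)', file_name)]

files = ['frame_10.tif', 'frame_2.tif', 'frame_1.tif']
print(sorted(files))                        # ['frame_1.tif', 'frame_10.tif', 'frame_2.tif']
print(sorted(files, key=natural_sort_key))  # ['frame_1.tif', 'frame_2.tif', 'frame_10.tif']
```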

The user is required to input the format of their imaging files at the start of the script. 

Once the image location has been selected, the image data is read into the program and the cells in each frame are segmented using the cellpose library: https://cellpose.readthedocs.io/en/latest/index.html. The user can select either one of the default cellpose segmentation models or a custom retrained model. If using a retrained model, it may first need to be loaded into cellpose through the GUI:
https://cellpose.readthedocs.io/en/latest/models.html

Once the cells have been segmented, the segmentation masks are combined into a cell-mask image stack. This mask stack is then used to track cell movement with the BTrack library: https://btrack.readthedocs.io/en/latest/. The quality of the cell tracks generated by BTrack can be improved by refining:
1) The number of frames that an object is tracked over. 
2) The number of frames in which an object is permitted to go undetected. 
3) The maximum distance (in pixels) that an object may move from its position in the previous frame before it is considered a new object. 
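The third parameter is a distance in pixels, so it is easiest to choose from the fastest movement you expect between two frames. The sketch below converts an assumed maximum cell speed (microns per minute) into a pixel distance using the image resolution and the frame interval. `max_step_in_pixels` is a hypothetical helper, and the example speed and 30-minute frame interval are assumptions. How the value is passed to BTrack (e.g. a search-radius or distance-threshold setting in `cell_config.json`) depends on the BTrack version, so check the BTrack documentation.

```python
import math

def max_step_in_pixels(max_speed_um_per_min, minutes_per_frame, microns_per_pixel):
    '''Largest distance (in pixels, rounded up) that a cell moving at
    max_speed_um_per_min could cover between two consecutive frames.'''
    max_step_um = max_speed_um_per_min * minutes_per_frame
    return math.ceil(max_step_um / microns_per_pixel)

# Hypothetical example: cells move at most 1 micron/min, frames are 30 min apart,
# and the images have 0.635 microns per pixel (the value used in Cell 2).
print(max_step_in_pixels(1.0, 30, 0.635))  # 48
```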

Once tracking is complete, the tracks can be filtered to include only objects tracked for at least a user-defined minimum number of frames. This excludes from the analysis any objects/cells that appear in the field of view for only a few frames of the time-lapse measurement. 

From the cell tracks, the total displacement, trajectory, total distance travelled, speed and directionality of each object are calculated. 
The results of this analysis are saved into an analysis folder that is generated within the same folder that contains the raw image data.  


-----
#### Required inputs from the User:  
This script requires that the user manually include: 
1. The total measurement time in minutes - Cell 2.
2. The Image resolution used when recording the images - Cell 2. 
3. The minimum number of frames an object must be tracked for to be included in the analysis - Cell 2. 
4. How the image data is stored, either:
    - A single time-lapse stack image ('single_file') - Cell 2. 
    - A series of single frame .tif files saved within a folder ('folder_of_images') - Cell 2. 
5. The directory that houses the experiment images - Cell 11. 
6. The segmentation model for the analysis - Cell 12.
-------

## Cell 1
---
### Import the dependencies for the script
---

In [None]:
# Import cellpose and other dependencies
from cellpose import models, io, core, plot 
import btrack
import napari

# Import libraries for data analysis
import numpy as np 
from PIL import Image
import tifffile as tf
import pandas as pd
import os

# Import dependencies for user interface/input
from tqdm import tqdm
import tkinter as tk
from tkinter import filedialog

## Cell 2
---
### User inputs
---

In the following cell, the user needs to input: 
1) The total time of the experiment in minutes. 
2) The resolution of the images to be analysed. 
3) The minimum track length for an object to be included in the analysis.
4) How the time-lapse data is stored, either as:
    - A series of individual images within a folder: 'folder_of_images'
    - A single stack image containing all the time-lapse data: 'single_file'

In [None]:
experiment_time = 16*30 # total measurement time in minutes. 
image_resolution = 0.635 # microns per pixel. 

# Select the minimum number of frames that an object must be tracked for. 
min_tracks = 3

# Select the input type for the experiment data. 
# input_type = 'folder_of_images'
input_type = 'single_file'

### Cell 3 
- A function to get the names of all image files in the imaging folder selected in Cell 11.
- Returns a list of image file names. 

In [None]:
def get_image_file_names(folder_path): 
    '''Get the names of all image files (.tif, .png, .jpg) in a folder,
    sorted by file name.

    Only the top level of the folder is searched, so output sub-folders
    (e.g. 'cellpose_mask') are not picked up, and the names are sorted so
    that a suitable naming convention gives the frames in time order.'''
    
    # File extensions to include (compared in lower case, so .TIF also matches).
    extensions = ('.tif', '.png', '.jpg')
    
    list_of_files = []
    # Loop through the entries in the directory specified in "folder_path".
    for file in os.listdir(folder_path):
        # Save the name if the entry is a file with one of the image extensions.
        if os.path.isfile(os.path.join(folder_path, file)) and file.lower().endswith(extensions):
            list_of_files.append(file)
                
    return(sorted(list_of_files))

### Cell 4
- A function that runs cellpose segmentation on a single image frame. 
- Returns the cell masks and flows generated by cellpose. 

In [None]:
def get_cellpose_analysis(image, channel, cellpose_model):
    '''Segment a single image frame with the given cellpose model.
    Returns the label mask and the flows.'''
    
    masks, flows, styles = cellpose_model.eval(image, diameter=None, channels = channel)
    
    return(masks, flows)

### Cell 5
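- A function to perform cell tracking on the cellpose masks using the BTrack library. 
- Returns the tracks in napari format: a NumPy array of track data, a dictionary of track properties and a dictionary describing the track graph (parent-child links). 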

In [None]:
def find_Btracks(image_masks, configer_file, script_fold):
    '''Find the tracks of the objects in the cellpose mask stack
    generated from the image stack. 
    This method makes use of the BTrack package.'''
    
    FEATURES = [
      "area",
      "major_axis_length",
      "minor_axis_length",
       "eccentricity",
      "solidity",
                ]

    objects = btrack.utils.segmentation_to_objects(image_masks, properties=tuple(FEATURES))
    
    with btrack.BayesianTracker() as tracker:
        # configure the tracker using a config file
        tracker.configure_from_file(configer_file)

        tracker.features = FEATURES

        # append the objects to be tracked
        tracker.append(objects) 

        # set the volume (Z axis volume is set very large for 2D data).
        # The X and Y limits assume images no larger than 2000 x 2000 pixels.
        tracker.volume=((0, 2000), (0, 2000), (-1e5, 1e5))

        # track them (in interactive mode)
        tracker.track(step_size=100)

        # generate hypotheses and run the global optimizer
        tracker.optimize()

        # store the data in an HDF5 file
        tracker.export(os.path.join(script_fold, 'tracks.h5'), obj_type='obj_type_1')

        # get the data in a format for napari
        data, properties, graph = tracker.to_napari(ndim=2)
    
    return(data, properties, graph)

### Cell 6
- A function to filter the tracks by the minimum track length defined in Cell 2. 
- Returns pandas dataframes of the filtered tracks and their properties. 

In [None]:
def filter_tracks(tracks, props, min_track_len, indicies):
    '''Filter the tracks to keep only objects tracked for at least
    min_track_len frames. indicies holds the inclusive [start, end] row
    indices of each track (see particle_index). Returns the filtered tracks
    and track properties as pandas dataframes, in track order.'''
    
    track_parts = []
    prop_parts = []
    
    for start, end in indicies:
        # end is inclusive, so the track has end - start + 1 rows.
        if end - start + 1 >= min_track_len: 
            track_parts.append(tracks.iloc[start:end + 1])
            prop_parts.append(props.iloc[start:end + 1])
    
    if not track_parts:
        # No track is long enough: return empty dataframes with the same columns.
        return(tracks.iloc[0:0], props.iloc[0:0])
    
    filtered_tracks = pd.concat(track_parts, ignore_index = True)
    filtered_props = pd.concat(prop_parts, ignore_index = True)
        
    return(filtered_tracks, filtered_props)
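For comparison, the same length filter can be written with pandas alone, counting the rows of each track ID instead of computing row indices first. This is a sketch, assuming a tracks dataframe with the 'Track' column built in Cell 16 and a row-aligned properties dataframe; `filter_tracks_groupby` is a hypothetical name and keeps tracks with at least `min_track_len` rows.

```python
import pandas as pd

def filter_tracks_groupby(tracks, props, min_track_len):
    '''Keep the rows of tracks that have at least min_track_len rows.
    props must be row-aligned with tracks (same length and order).'''
    # Map each row's track ID to the number of rows that ID has.
    track_lengths = tracks['Track'].map(tracks['Track'].value_counts())
    keep = (track_lengths >= min_track_len).to_numpy()
    return (tracks[keep].reset_index(drop=True),
            props[keep].reset_index(drop=True))

# Hypothetical example: track 3 has a single point and is removed.
tracks = pd.DataFrame({'Track': [1, 1, 1, 2, 2, 3],
                       'Frame': [0, 1, 2, 0, 1, 0],
                       'Y_pos': [0.0, 1.0, 2.0, 5.0, 5.0, 9.0],
                       'X_pos': [0.0, 0.0, 0.0, 1.0, 2.0, 9.0]})
props = pd.DataFrame({'area': [10, 11, 12, 20, 21, 30]})
kept_tracks, kept_props = filter_tracks_groupby(tracks, props, 2)
print(kept_tracks['Track'].unique())  # [1 2]
```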

### Cell 7
- A function to find the indices at which the track ID changes within the cell tracking dataframe. 
- Returns a NumPy array of inclusive [start, end] row indices for each track in the cell tracking dataframe.

In [None]:
def particle_index(tracks):
    '''Find the rows where the track ID in the tracks dataframe changes.
    Assumes the rows are grouped by track. Returns an (n_tracks, 2) array
    of inclusive [start, end] row indices for each track.'''
    
    # Set the initial track ID.
    particle = tracks['Track'].iloc[0]
    # The first track starts at row 0. 
    index_score = [0]

    # Loop to find the indices at which the track ID changes. 
    for i in range(tracks.shape[0]):
        if tracks['Track'].iloc[i] != particle:
            # End of the previous track and start of the next one.
            index_score.append(i-1)
            particle = tracks['Track'].iloc[i]
            index_score.append(i)

    # The last track ends at the final row. 
    index_score.append(i)
    
    # Convert index score to a 2D numpy array of [start, end] pairs. 
    index_score = np.reshape(index_score, (len(index_score) // 2, 2))
    
    return(index_score)
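The same [start, end] pairs can be found without a Python loop by looking for the rows where the track ID differs from the previous row. This is a sketch of a vectorised equivalent; `particle_index_vectorised` is a hypothetical name, and like `particle_index` it assumes the rows are grouped by track. It would be called as `particle_index_vectorised(df_data['Track'])`.

```python
import numpy as np

def particle_index_vectorised(track_ids):
    '''Return an (n_tracks, 2) array of inclusive [start, end] row indices
    for each run of identical track IDs in a grouped 1D sequence.'''
    track_ids = np.asarray(track_ids)
    # Rows where the ID differs from the previous row start a new track.
    starts = np.concatenate(([0], np.flatnonzero(np.diff(track_ids)) + 1))
    # Each track ends one row before the next starts; the last ends at the final row.
    ends = np.concatenate((starts[1:] - 1, [len(track_ids) - 1]))
    return np.column_stack((starts, ends))

print(particle_index_vectorised([4, 4, 4, 7, 7, 9]))
# [[0 2]
#  [3 4]
#  [5 5]]
```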

### Cell 8
- A function to measure the displacement of a cell over the time-lapse measurement and its trajectory (direction of travel). 
- Returns: 
    1. A list of displacements (in pixels) for all tracked objects.
    2. A list of the trajectories (in degrees) for the tracked cells. 
    
Note: Displacement is defined as the difference in the x, y co-ordinates between the initial start position and the final end position of the cell track. 
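A worked example of the displacement and trajectory calculation, on a hypothetical three-frame track with positions in pixels. The angle is measured from the +x axis towards +y; because image rows increase downwards, a positive angle points below the x axis when the image is displayed. Multiplying the displacement by `image_resolution` converts it to microns.

```python
import numpy as np

# Hypothetical track: (x, y) positions in pixels at each frame.
x = np.array([10.0, 12.0, 13.0])
y = np.array([20.0, 20.0, 24.0])

# Only the first and last points are used for displacement.
dx = x[-1] - x[0]   # 3.0 pixels
dy = y[-1] - y[0]   # 4.0 pixels
displacement = np.hypot(dx, dy)              # 5.0 pixels
trajectory = np.degrees(np.arctan2(dy, dx))  # about 53.13 degrees

print(displacement, round(trajectory, 2))  # 5.0 53.13
```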

In [None]:
def get_traj_and_displacement(tracks, indicies): 
    '''Find the trajectory (degrees) and the displacement (pixels) of each
    track found by the BTrack algorithm. indicies holds the inclusive
    [start, end] row indices of each track.'''
    
    # initialise variables. 
    distance_moved = []
    trajectory = []
    
    # for all the tracks identified
    for l in range(indicies.shape[0]):
        # Change in position between the first and last point of the track.
        x_distance = tracks['X_pos'].iloc[indicies[l, 1]] - tracks['X_pos'].iloc[indicies[l, 0]] 
        y_distance = tracks['Y_pos'].iloc[indicies[l, 1]] - tracks['Y_pos'].iloc[indicies[l, 0]] 
        
        # calculate the displacement and the trajectory angle in degrees. 
        distance_moved.append( np.sqrt( np.power(x_distance, 2) + np.power(y_distance, 2) ) )
        trajectory.append( np.arctan2(y_distance, x_distance) * 180/np.pi)
    
    return(distance_moved, trajectory)

### Cell 9
- A function to calculate the total distance travelled and the speed of cell movement. 
- Returns NumPy arrays of the distance travelled by each cell (in microns) and its speed (in microns per minute). 

Note: Distance is defined as the total path length travelled by the cell through all frames. This is distinct from displacement, as it includes every frame-to-frame movement between the start and end of a cell track.  
Speed is defined as distance travelled / length of time tracked. 
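A worked example of the distance and speed definitions, using a hypothetical four-frame track that moves 3 px right, 4 px down and 3 px back left. The 0.635 microns per pixel matches Cell 2; the 30-minute frame interval is an assumption for illustration. Here the time tracked is taken as the number of frame intervals (points minus one) times the frame interval.

```python
import numpy as np

microns_per_pixel = 0.635   # image_resolution from Cell 2
minutes_per_frame = 30.0    # hypothetical frame interval

# Hypothetical track: (x, y) positions in pixels at four consecutive frames.
x = np.array([0.0, 3.0, 3.0, 0.0])
y = np.array([0.0, 0.0, 4.0, 4.0])

# Distance: sum of the frame-to-frame steps (3 + 4 + 3 = 10 pixels).
distance_um = np.sum(np.hypot(np.diff(x), np.diff(y))) * microns_per_pixel
# Displacement: straight line from the first to the last point (4 pixels).
displacement_um = np.hypot(x[-1] - x[0], y[-1] - y[0]) * microns_per_pixel
# Four points span three frame intervals.
speed_um_per_min = distance_um / (minutes_per_frame * (len(x) - 1))

print(round(distance_um, 3), round(displacement_um, 3), round(speed_um_per_min, 4))
```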

In [None]:
def distance_and_speed(tracks, ind, res, time_frame):
    '''Calculate the total distance (microns) and the speed
    (microns per minute) of each track.

    ind holds the inclusive [start, end] row indices of each track,
    res is the image resolution in microns per pixel and time_frame is
    the time between frames in minutes.'''

    all_rs_distance = []
    speed = []

    for start, end in ind:
        # end is inclusive, so include it in the slice.
        x = tracks['X_pos'].iloc[start:end + 1].to_numpy(dtype=float)
        y = tracks['Y_pos'].iloc[start:end + 1].to_numpy(dtype=float)

        # travel = SQRT( (Delta X)^2 + (Delta Y)^2 ), summed over every frame-to-frame step.
        rs_distance = np.sum(np.hypot(np.diff(x), np.diff(y))) * res
        all_rs_distance.append(rs_distance)

        # n points span n - 1 frame intervals.
        time_tracked = time_frame * (x.shape[0] - 1)
        speed.append(rs_distance / time_tracked if time_tracked > 0 else np.nan)

    all_rs_distance = np.round(np.array(all_rs_distance), 3)

    speed = np.round(np.array(speed), 3) 
    
    return(all_rs_distance, speed)

### Cell 10 
- A function to calculate the directionality of cell movement. 
- Returns a NumPy array of directionality values for all cell tracks. 

In [None]:
def directionality_calc(distance_travelled, displace):
    '''Calculate the directness of a cell's motion: displacement divided
    by the total distance travelled (both in the same units). This is a very 
    simple calculation and is taken from the ibidi Chemotaxis Tool manual. 
    As of 12/01/2023, this information can be found at the following URL: 
    https://ibidi.com/img/cms/products/software/chemotaxis_tool/Manual_ChemotaxisTool_2_0_eng.pdf
    on page 15. '''
    
    direction = np.array(displace)/np.array(distance_travelled)
    
    return(direction)
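Directionality (directness) is displacement divided by the total distance travelled, so it is 1 for a perfectly straight path and approaches 0 for a meandering one; both quantities must be in the same units. The hypothetical `directness` helper below computes it directly from a track's pixel positions, where the pixel size cancels out. A track that never moves gives 0/0, which NumPy returns as NaN.

```python
import numpy as np

def directness(x, y):
    '''Displacement divided by path length for one track (unitless).
    Both are computed from the same pixel positions, so units cancel.'''
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    path_length = np.sum(np.hypot(np.diff(x), np.diff(y)))
    displacement = np.hypot(x[-1] - x[0], y[-1] - y[0])
    return displacement / path_length

print(directness([0, 1, 2, 3], [0, 0, 0, 0]))  # 1.0 (straight line)
print(directness([0, 3, 3, 0], [0, 0, 4, 4]))  # 0.4 (10 px path, 4 px displacement)
```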

## Cell 11
---
### Select the folder or file containing the raw images 
---
Running the following cell opens a file dialog to select the images to analyse: either the folder of images, or the time-lapse .tif file for a single stack. 

#### Important Note:  
This script accepts either: 
1. A folder containing many single frames from a time-lapse experiment.
2. A single file containing the whole time-lapse image stack as a .tif file. 

Both options are included below; the branch that runs is selected by the `input_type` variable set in Cell 2.  
- Where a cell offers options that are not selected automatically (e.g. the model choice in Cell 12), comment out the lines you do not need. 
- A block comment can be applied (on Windows) by selecting the code you wish to comment/uncomment and pressing "Ctrl + /". 

In [None]:
# Create a hidden Tk root window so that only the file dialog is shown.
root = tk.Tk()
root.attributes("-topmost", True)
root.withdraw() # Stops a second window opening

if input_type == 'folder_of_images': 
    ################
    # For a folder containing lots of images.
    # Select the folder containing images. 
    SourceFolder = filedialog.askdirectory(title = 'Select Folder Containing images')
    # Get the names of all image files in the folder. 
    file_list = get_image_file_names(SourceFolder)
    ################
    print(SourceFolder, len(file_list), 'image files found')
    
else:
    #################
    ## For selecting a single time-lapse image as a .tif.
    # Select the time-lapse image file. 
    SourceFolder = filedialog.askopenfilename(title = 'Select time-lapse image file')
    # For a single file, file_list holds the path of the time-lapse .tif file. 
    file_list = SourceFolder 
    print(SourceFolder)

## Cell 12
--- 
### Select the Cellpose model
---
Two options here (comment out the option you do not need in the cell below): 
1. Select a custom (e.g. retrained) cellpose model file using a file dialog.
2. Use one of the built-in cellpose models by name, e.g. 'livecell'. 


In [None]:
# Select user defined segmentation model. 
root = tk.Tk()
root.withdraw()
model_dir = tk.filedialog.askopenfilename(title = "Select Cellpose model")

# Output the names of all the pre-trained segmentation models that come as standard with cellpose. 
# models.MODEL_NAMES 

# # Select the pre-defined segmentation model.
# model_dir = 'livecell'

print(model_dir)

## Cell 14
--- 
### Run the cellpose segmentation on every frame of the time-lapse. 
---

In [None]:
# Set the channels for grey-scale images. 
channel = [0, 0]
# initialise the variable. 
image_stack = []
stack_masks = []

# If the input was a single time-lapse image stack.  
if input_type == 'single_file': 
    # Image-stack already exists. 
    image_stack = tf.imread(file_list)
    loop_counter = image_stack.shape[0]
else:
    loop_counter = len(file_list)

# Load the segmentation model: a file path is loaded as a custom model,
# anything else is treated as the name of a built-in cellpose model.
if os.path.isfile(model_dir):
    model = models.CellposeModel(pretrained_model = model_dir)
else:
    model = models.CellposeModel(model_type = model_dir)

for j in tqdm( range( loop_counter ) ): 
    if input_type == 'single_file': 
        # Get a single frame from the image_stack. 
        phase_image_arr = np.array(image_stack[j, ...])
    else: 
        # Select a single file name. 
        image_file = file_list[j]
        # Read the image with PIL and convert it to a numpy array.  
        with Image.open(os.path.join(SourceFolder, image_file)) as phase_image:
            phase_image_arr = np.array(phase_image)

        # Add the frame to the image stack. 
        image_stack.append(phase_image_arr)
        
    # Get the mask and the flow from the selected image. 
    masks, flow = get_cellpose_analysis(phase_image_arr, channel, model)
    # Save the image masks. 
    stack_masks.append(masks)

stack_masks = np.array(stack_masks)
print('done')

## Cell 15
--- 
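### Use BTrack to track the cell masks
--- 
Note: Running this cell prints tracking information in pink text. This is normal output from the BTrack algorithm, not an error message. The BTrack configuration file `cell_config.json` must be saved in the current working directory (normally the folder the notebook is run from). 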

In [None]:
# The BTrack configuration file must be in the current working directory.
script_folder = os.getcwd()
configer_file = os.path.join(script_folder, 'cell_config.json')
# Find the tracks for the individual cells in the image stack. 
cell_tracks, properties, graph = find_Btracks(stack_masks, configer_file, script_folder)

## Cell 16
---
### Analyse the results of the BTrack algorithm
---

In [None]:
# Get the tracks data into a pandas dataframe. 
# The napari-format track data from BTrack has columns [ID, t, y, x],
# so column 2 holds the y position and column 3 the x position.
data_arr = np.array(cell_tracks)
df_data = pd.DataFrame({'Track': data_arr[:, 0].astype(int), 'Frame': data_arr[:, 1].astype(int),
                        'Y_pos': data_arr[:, 2], 'X_pos': data_arr[:, 3]})

# Find the indices which describe a change in the object track. 
index = particle_index(df_data)

# Filter the tracks by the minimum track length. 
filtered_tracks, filtered_props = filter_tracks(df_data, pd.DataFrame(properties), min_tracks, index)

# Find the indices which describe a change in the object track. 
index_2 = particle_index(filtered_tracks)

# Find the displacement (pixels) and trajectory (degrees) of every track in the stack. 
displacement, trajectory = get_traj_and_displacement(filtered_tracks, index_2)

# Time between frames in minutes. 
time_per_frame = experiment_time / stack_masks.shape[0]

# Total distance travelled (microns) and speed (microns per minute). 
distance, speed = distance_and_speed(filtered_tracks, index_2, image_resolution, time_per_frame)
# Directionality = displacement / distance, with both in microns. 
directionality = directionality_calc(distance, np.array(displacement) * image_resolution)

## Cell 17
---
### View the images, segmentation masks and filtered tracks in napari
---

In [None]:
# Create a napari window. 
viewer = napari.Viewer()

# Add the phase-contrast images to the napari viewer. 
viewer.add_image( np.array(image_stack) )
# Add the segmentation masks to the napari viewer. 
viewer.add_labels(stack_masks)
# Add the filtered tracks to the napari viewer. An empty graph is passed,
# so no parent-child (division) links are drawn between tracks.
graph = {}
viewer.add_tracks(filtered_tracks, properties=dict(filtered_props), graph=graph)


## Cell 18
---
### Plot the trajectory and displacement on a polar plot
---

In [None]:
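import matplotlib.pyplot as plt

# Plot the resultant analysis on a polar plot. Polar axes expect angles in
# radians, so the trajectories are converted (tick labels still show degrees).
fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
ax.plot(np.deg2rad(trajectory), np.array(displacement)*image_resolution, 'x', color = 'r')
ax.set_title('Cell trajectory (degrees) and displacement (microns)')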

## Cell 19
---
### Set up the save folder. 
---

In [None]:
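# Create the save path next to the raw data: inside the image folder, or in
# the folder that contains the single time-lapse .tif file.
data_folder = os.path.dirname(SourceFolder) if input_type == 'single_file' else SourceFolder
save_path = os.path.join(data_folder, 'Cellpose_and_BTrack_analysis')
# Create the folder if it does not exist.
os.makedirs(save_path, exist_ok=True)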

## Cell 20
----- 
### Save the track information and the track analysis as .csv files and save the Polar plot as a .pdf.
----

In [None]:
# Create a pandas data frame for the analysis of each track.
# 'Track' holds the BTrack ID of each track, matching the IDs in the tracks .csv file.
dict_track_info = {'Track': filtered_tracks['Track'].iloc[index_2[:, 0]].to_numpy(), 
                  'Displacement (Microns)': np.array(displacement) * image_resolution, 
                  'Trajectory (Degrees)': trajectory, 
                  'Total Distance travelled (Microns)': distance,
                  'Speed (Microns / Minutes)': speed, 
                  'Directionality': directionality}

# Generate a pandas dataframe of the tracking data. 
track_info = pd.DataFrame(dict_track_info)

# Save the tracks to csv. 
filtered_tracks.to_csv(os.path.join(save_path, 'Tracks_data_(min_track_length_' + str(min_tracks) + 
                       '_tracks).csv'), index=False)

# Save the track analysis to csv. 
track_info.to_csv(os.path.join(save_path, 'Analysis_of_Tracks_data_(min_track_length_' + str(min_tracks) + 
                  '_tracks).csv'), index=False)

# Save the polar plot. 
fig.savefig(os.path.join(save_path, 'Tracks_Polar_Plot.pdf'), bbox_inches = 'tight')

## Cell 21
--- 
### Save the segmentation masks as individual .tif files.  
---

In [None]:
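# Save the masks as individual frames in a 'cellpose_mask' folder next to the raw data. 
if input_type == 'single_file':
    data_folder = os.path.dirname(SourceFolder)
    # Name each frame after the time-lapse file and its frame number.
    base_name = os.path.splitext(os.path.basename(SourceFolder))[0]
    mask_names = [base_name + '_frame_' + str(i).zfill(4) for i in range(stack_masks.shape[0])]
else:
    data_folder = SourceFolder
    # Name each mask after its source image (without the extension).
    mask_names = [os.path.splitext(file_name)[0] for file_name in file_list]

save_mask_path = os.path.join(data_folder, 'cellpose_mask')
os.makedirs(save_mask_path, exist_ok=True)

for mask_name, im in zip(mask_names, stack_masks):
    tf.imwrite(os.path.join(save_mask_path, mask_name + "_cellpose_masks.tif"), im)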

## Cell 22
---
### Save the images as a stack
---

In [None]:
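# Save the image stack as a .tif, named after the time-lapse file or the image file names.
if input_type == 'single_file':
    stack_name = os.path.splitext(os.path.basename(SourceFolder))[0]
else:
    stack_name = image_file[0:11]
tf.imwrite(os.path.join(save_path, stack_name + "_image_stack.tif"), np.array(image_stack))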

## Cell 23
---
### Save the segmentation masks as a .tif stack.
---

In [None]:
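# Save the masks as a .tif file, named after the time-lapse file or the image file names.
if input_type == 'single_file':
    stack_name = os.path.splitext(os.path.basename(SourceFolder))[0]
else:
    stack_name = image_file[0:11]
tf.imwrite(os.path.join(save_path, stack_name + "_cellpose_masks.tif"), stack_masks)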