# Load images and attempt tracking
----

Author: Ryan Corbyn  
Date: 15/11/2023  

Here I will load the images from a folder into a numpy array, segment the cells with Cellpose, track them with btrack, and then display the results in Napari to see what I am dealing with. 

---
### Import dependencies

In [None]:
import tifffile as tf 
import napari 

import numpy as np
import pandas as pd
import tqdm
import tkinter as tk
from tkinter import filedialog
import btrack


from cellpose import core, models, io, metrics
# from cellpose.contrib import openvino_utils
import os

----
### Select the image file to analyse. 

In [None]:
# Open a folder-selection dialog for the directory that holds the images
# to analyse.
root = tk.Tk()
try:
    root.attributes("-topmost", True)
    root.withdraw()  # Stops a second (empty) window opening
    images_folder = filedialog.askdirectory(title='Select image Folder')
finally:
    # Release the hidden Tk root once the dialog has closed. Without this
    # every run of the cell leaves another Tk interpreter alive, and the
    # first one stays the default root that later dialogs attach to.
    root.destroy()
print(images_folder)

# Supporting files are expected in a 'required_files' folder inside the
# current working directory.
script_folder = os.getcwd()
required_files = script_folder + '/required_files/'

---- 
### Select the file containing the plate layout

In [None]:
# Open a file-selection dialog for the file that contains the plate layout.
root = tk.Tk()
try:
    root.attributes("-topmost", True)
    root.withdraw()  # Stops a second (empty) window opening
    plate_layout_loc = filedialog.askopenfilename(
        title='Select the File containing the plate layout')
finally:
    # Release the hidden Tk root once the dialog has closed so that no
    # stale Tk interpreter is left behind by this cell.
    root.destroy()
print(plate_layout_loc)

---
### Select the Cellpose model. 

In [None]:
# Open a file-selection dialog for the Cellpose model used for
# segmentation; the dialog starts in the required_files folder.
root = tk.Tk()
try:
    root.attributes("-topmost", True)
    root.withdraw()  # Stops a second (empty) window opening
    seg_model = filedialog.askopenfilename(title='Select Cellpose file',
                                           initialdir=required_files)
finally:
    # Release the hidden Tk root once the dialog has closed so that no
    # stale Tk interpreter is left behind by this cell.
    root.destroy()
print(seg_model)

--- 
### Input total experiment time in hours. 

In [None]:
# Duration of the whole experiment. Edit total_experiment_time by hand
# for each data set; the value in minutes is derived from it.
MINUTES_PER_HOUR = 60
total_experiment_time = 17  # hours
total_time_in_minute = MINUTES_PER_HOUR * total_experiment_time

---- 
### Get file names

In [None]:
def get_file_names(folder):
    '''Return the names of the files stored directly inside ``folder``.

    Only regular files are returned: sub-directories are left out because
    every returned name is later opened as an image. The names are sorted
    so that the processing order does not depend on the arbitrary order
    that ``os.listdir`` yields.
    '''
    return sorted(
        entry for entry in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, entry))
    )

----
### Perform the BTracks Algorithm 

In [None]:
def find_Btracks(image_masks, configer_file, loc, save_loc, im_name):
    '''Track the segmented cells through a mask stack with btrack.

    Parameters
    ----------
    image_masks : segmentation masks handed to
        ``btrack.utils.segmentation_to_objects``.
    configer_file : file name of the btrack configuration file.
    loc : folder that contains ``configer_file``.
    save_loc : folder the HDF5 track file is written to.
    im_name : prefix of the exported ``<im_name>_tracks.h5`` file.

    Returns
    -------
    The ``(data, properties, graph)`` tuple produced by
    ``tracker.to_napari(ndim=2)``.
    '''

    # Region properties measured for every object and used as features
    # by the tracker.
    FEATURES = [
        "area",
        "major_axis_length",
        "minor_axis_length",
        "eccentricity",
        "solidity",
    ]

    objects = btrack.utils.segmentation_to_objects(image_masks, properties=tuple(FEATURES))

    with btrack.BayesianTracker() as tracker:
        # configure the tracker using a config file
        tracker.configure_from_file(os.path.join(loc, configer_file))

        tracker.features = FEATURES

        # append the objects to be tracked
        tracker.append(objects)

        # set the volume (Z axis volume is set very large for 2D data)
        # NOTE(review): the 2000 x 2000 extent is hard-coded; confirm it
        # covers the size of the images being tracked.
        tracker.volume = ((0, 2000), (0, 2000), (-1e5, 1e5))

        # track them (in interactive mode)
        tracker.track(step_size=100)

        # generate hypotheses and run the global optimizer
        tracker.optimize()

        # store the data in an HDF5 file
        tracker.export(os.path.join(save_loc, im_name + '_tracks.h5'), obj_type='obj_type_1')
        print(loc)

        # get the data in a format for napari
        data, properties, graph = tracker.to_napari(ndim=2)

    return (data, properties, graph)

---
### Get the image data

In [None]:
def get_image_data(im_name):
    '''Read the image file ``im_name`` with tifffile and return its
    contents as a numpy array.'''
    return np.array(tf.imread(im_name))

----
### Cellpose segmentation

In [None]:
def segment_cells(segment_model, im_dat):
    '''Segment every frame of an image stack with Cellpose.

    ``segment_model`` is handed to ``models.CellposeModel`` as
    ``model_type`` and ``im_dat`` is indexed along its first axis, one
    entry per frame. The masks of all frames are returned stacked in a
    single numpy array.
    '''
    # Cellpose model without the size model, built once for all frames.
    # NOTE(review): the caller passes a file path chosen in a dialog;
    # Cellpose documents ``pretrained_model`` for model paths - confirm
    # that ``model_type`` is the intended argument.
    network = models.CellposeModel(gpu=True, model_type=segment_model)
    # model = openvino_utils.to_openvino(model)

    frame_count = im_dat.shape[0]
    per_frame_masks = []
    for frame_index in tqdm.tqdm(range(frame_count)):
        # Only the masks are kept; flows and styles are discarded.
        frame_masks, _flows, _styles = network.eval(
            im_dat[frame_index, ...], channels=[0, 0], diameter=None)
        per_frame_masks.append(frame_masks)

    return np.array(per_frame_masks)

----
### Perform tracking using BTracks.

In [None]:
def track_cells(cellpose_masks, file):
    '''Run the btrack algorithm on a stack of Cellpose masks.

    ``file`` is the image name used as the prefix of the exported track
    file. The tracker configuration ('cell_config.json') is read from the
    ``required_files`` folder and the track file is written to its
    'Tracks' sub-folder.

    Returns ``(tracks_df, cell_tracks)``: the track table as a pandas
    dataframe and the raw napari-format array it was built from.
    '''

    configer_file = 'cell_config.json'

    save_folder = required_files + '/Tracks/'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_folder, exist_ok=True)

    # Find the tracks for the individual cells in the image stack.
    cell_tracks, _properties, _graph = find_Btracks(cellpose_masks, configer_file, required_files, save_folder, file)

    # Convert output to a dataframe. btrack documents the napari track
    # array as ordered ID, T, Y, X for 2D data, so the third column is the
    # y coordinate and the fourth the x coordinate.
    tracks_df = pd.DataFrame(cell_tracks, columns=['Track ID', 'Frame', 'y position', 'x position'])

    return (tracks_df, cell_tracks)

--- 
### Filter the tracking data and calculate the track properties. 

In [None]:
def filter_track_data(tracks_df, frame_interval=None, min_track_length=2):
    '''Filter the btrack tracks and calculate per-track motion statistics.

    Parameters
    ----------
    tracks_df : dataframe with the columns 'Track ID', 'Frame',
        'x position' and 'y position'. The rows of one track must be
        stored contiguously and ordered by frame.
    frame_interval : time between two consecutive frames in minutes.
        When None the notebook-level ``frame_time`` value is used.
    min_track_length : minimum number of frame intervals a track has to
        span (last frame minus first frame) to be kept. Default 2.

    Returns
    -------
    A dataframe with one row per kept track:
      * 'Track Number' - the track ID.
      * 'Total Frames Tracked' - last frame minus first frame.
      * 'Total Distance cell travels (Pixels)' - path length, i.e. the
        sum of the frame-to-frame step lengths.
      * 'Total Cell Speed (Pixels/Minute)' - mean frame-to-frame speed.
      * 'Standard Deviation Total Cell Speed (Pixels/Minute)'.
      * 'Total Displacement cell travels (Pixels)' - straight-line
        distance between the first and the last position.
      * 'Mean cell Velocity (Pixels/minute)' - displacement divided by
        the time the track spans.
    '''
    if frame_interval is None:
        # Fall back to the value the run cell computes for the current stack.
        frame_interval = frame_time

    columns = [
        'Track Number',
        'Total Frames Tracked',
        'Total Distance cell travels (Pixels)',
        'Total Cell Speed (Pixels/Minute)',
        'Standard Deviation Total Cell Speed (Pixels/Minute)',
        'Total Displacement cell travels (Pixels)',
        'Mean cell Velocity (Pixels/minute)',
    ]

    # Work on plain arrays so that indexing is positional and does not
    # depend on the dataframe index labels.
    track_ids = tracks_df['Track ID'].to_numpy()
    frames = tracks_df['Frame'].to_numpy()
    x_vals = tracks_df['x position'].to_numpy(dtype=float)
    y_vals = tracks_df['y position'].to_numpy(dtype=float)

    if len(track_ids) == 0:
        return pd.DataFrame(columns=columns)

    # Positions where the track ID changes mark the start of a new track.
    # The end of the table closes the final track, so it is not dropped.
    boundaries = np.flatnonzero(track_ids[1:] != track_ids[:-1]) + 1
    starts = np.concatenate(([0], boundaries))
    stops = np.concatenate((boundaries, [len(track_ids)]))

    rows = []
    for start, stop in zip(starts, stops):
        # Ignore single frame tracks.
        if stop - start < 2:
            continue

        # Number of frame intervals the track spans.
        total_track_length = int(frames[stop - 1] - frames[start])
        if total_track_length < min_track_length:
            continue

        # Step length between each pair of consecutive positions.
        step_lengths = np.hypot(np.diff(x_vals[start:stop]),
                                np.diff(y_vals[start:stop]))
        # Speed of the cell for every step.
        speed = step_lengths / frame_interval

        # Straight-line displacement between first and last position.
        displacement = np.hypot(x_vals[stop - 1] - x_vals[start],
                                y_vals[stop - 1] - y_vals[start])
        # Mean velocity over the time this track was actually observed,
        # not over the whole experiment.
        velocity = displacement / (total_track_length * frame_interval)

        # Total path length travelled by the cell.
        distance = np.sum(step_lengths)

        rows.append({
            'Track Number': int(track_ids[start]),
            'Total Frames Tracked': total_track_length,
            'Total Distance cell travels (Pixels)': distance,
            'Total Cell Speed (Pixels/Minute)': np.mean(speed),
            'Standard Deviation Total Cell Speed (Pixels/Minute)': np.std(speed),
            'Total Displacement cell travels (Pixels)': displacement,
            'Mean cell Velocity (Pixels/minute)': velocity,
        })

    # Build the dataframe once instead of concatenating inside the loop.
    return pd.DataFrame(rows, columns=columns)

---- 
### Save data

In [None]:
def save_data(data_folder, file_name, drug, mask, all_tracks, filtered_df, well, tracker):
    '''Save the results of one image stack.

    Three files are written to
    ``<parent of data_folder>/0.Analysis results/<well>/``; each name
    starts with the first 15 characters of ``file_name`` followed by
    ``drug``:
      * ``..._cp_masks.tif`` - the cellpose masks as a .tif stack.
      * ``..._all_tracks.csv`` - the raw tracks data.
      * ``..._filtered_track_analysis.csv`` - the filtered track data.

    ``tracker`` is accepted for compatibility with existing callers but
    is not used.
    '''

    # Create a folder to save the data in.
    save_folder = os.path.join(os.path.dirname(data_folder), '0.Analysis results', well)
    os.makedirs(save_folder, exist_ok=True)

    # The treatment comes from a spreadsheet cell and may not be a string
    # (a number, or NaN for an empty cell), so convert it explicitly
    # instead of failing on the concatenation.
    prefix = os.path.join(save_folder, file_name[0:15] + str(drug))

    tf.imwrite(prefix + '_cp_masks.tif', mask)

    all_tracks.to_csv(prefix + '_all_tracks.csv', index=False)

    filtered_df.to_csv(prefix + '_filtered_track_analysis.csv', index=False)
    

---
# Run script 

In [None]:
# Get file names from folder.
file_names = get_file_names(images_folder)

# Read the plate layout from the file. plate_layout_loc already holds the
# full path returned by the file dialog, so it is read directly.
plate_layout = pd.read_excel(plate_layout_loc, header=None)

# For each file
for file_name in file_names:

    # Find the well the data comes from (characters 9-10 of the file name).
    well = file_name[9:11]

    # Find out what drug was used in the experiment. This is done before
    # the slow segmentation and tracking so that a well missing from the
    # plate layout is reported straight away.
    matching_rows = np.where(plate_layout[0] == well)[0]
    if len(matching_rows) == 0:
        raise ValueError(
            "Well '" + well + "' (from file '" + file_name
            + "') was not found in the first column of the plate layout.")
    treatment = plate_layout[1].iloc[matching_rows[0]]

    # Load the image data.
    image_data = get_image_data(os.path.join(images_folder, file_name))

    # Calculate the time period between image frames. At least two
    # frames are needed to define an interval.
    if image_data.shape[0] < 2:
        raise ValueError(
            "File '" + file_name + "' holds fewer than two frames, "
            "so the time between frames cannot be calculated.")
    frame_time = total_time_in_minute / (image_data.shape[0] - 1)

    # Use cellpose to generate the cell segmentation masks.
    cell_mask = segment_cells(seg_model, image_data)

    # Use the BTrack module to find cell tracks. The extension is removed
    # with splitext so that '.tif' and '.tiff' names are both handled.
    cell_tracks_df, tracks = track_cells(cell_mask, os.path.splitext(file_name)[0])

    # Filter the tracks data.
    filtered_data = filter_track_data(cell_tracks_df)

    # Save all the data.
    save_data(images_folder, file_name, treatment, cell_mask, cell_tracks_df, filtered_data, well, tracks)

--- 
### View data through Napari

In [None]:
# Open a napari window showing the image stack, the segmentation masks
# and the tracks left over from the last file processed by the run cell.
viewer = napari.Viewer()
viewer.add_image(image_data)

label_layer_data = np.array(cell_mask)
track_layer_data = np.array(tracks)
viewer.add_labels(label_layer_data)
viewer.add_tracks(track_layer_data)