#### Goal: extract all features for each mother-cell lineage into xarray format

In [None]:
from pathlib import Path
from trackmatexml import TrackmateXML
from matplotlib import pyplot as plt
import xarray as xr
import numpy as np
import pandas as pd
import os
import re
from collections import defaultdict
import pickle
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt

In [None]:
# directory to save the combined xarray datasets in - change this before running
output_dir = ''
os.makedirs(output_dir, exist_ok=True)

# location of the ferets diameter csv (regionprops output)
ferets_csv_location = ''

# tracking metadata from the google doc (downloaded as a csv)
tracked_data_log = pd.read_csv('')

# keep only positions whose xml correction status contains "Done" (case-insensitive; missing status counts as not done)
# NOTE(review): this is a substring match, so a status such as "Not done" would also be kept - confirm the values used in the doc
status_is_done = tracked_data_log["MgM XML file correction status"].str.contains("Done", case=False, na=False)
filtered_tracked_data_log = tracked_data_log[status_is_done]

In [None]:
# First, define a function that renames spot headers from the XML files (e.g. replacing CH2 with GFP)

def find_and_replace(input_string, replacements):
    """
    Identifies dictionary keys in the input_string and replaces them with corresponding values.

    All keys are matched in a single left-to-right pass, so inserted replacement values are never
    re-scanned. When one key is a prefix of another (e.g. 'CH1' and 'CH10'), the longest key wins.

    Parameters:
    - input_string (str): The string to process.
    - replacements (dict): A dictionary where keys are substrings to find and values are replacements.

    Returns:
    - str: The modified string with replacements applied (unchanged if replacements is empty).
    - list: The keys that were found and replaced, in order of occurrence (repeats included).
    """
    # an empty dict would compile to the pattern '' which matches everywhere and then fails the
    # dictionary lookup, so there is nothing to do
    if not replacements:
        return input_string, []

    # regex alternation takes the FIRST alternative that matches, not the longest, so try longer
    # keys first (sorted is stable, so equal-length keys keep their dict order)
    ordered_keys = sorted(replacements, key=len, reverse=True)
    pattern = re.compile('|'.join(re.escape(key) for key in ordered_keys))

    # List to store keys that were found and replaced
    found_keys = []

    # Function to replace matched keys with corresponding values
    def replace_match(match):
        matched_key = match.group(0)
        found_keys.append(matched_key)
        return replacements[matched_key]

    # Perform the substitution
    result_string = pattern.sub(replace_match, input_string)

    return result_string, found_keys

# map the channel labels TrackMate uses (CH1..CH4) to the actual channel names
channel_dict = dict(
    CH1='MASKS',
    CH2='GFP',
    CH3='TRITC',
    CH4='PHASE',
)

In [None]:
parent_data_path = '/Volumes/salmonella/users/madison/'
frame_interval = 5 # mins
expt_dates = ['20240809', '20240918', '20241028']
# glob pattern identifying the correct xml file to load, as there is usually >1 per folder
xml_file_label = "*FINAL_mgm.xml"
# spot ID given to a manually added "fake" cell (used to fix division issues) - we always use this same number
fake_spot_id_to_remove = 100000
# decimal places used to round centroids before matching TrackMate spots to the regionprops ferets data;
# avoids mismatches caused by differences at e.g. the 10th decimal place (4 was chosen somewhat arbitrarily)
decimal_places = 4
# list to store xarray datasets for each position
datasets_list = []

# (cell, time) variables stored in each lineage dataset: output variable name -> key in spotheader_data_dict
lineage_features = {
    "GFP_median_intensity": "MEDIAN_INTENSITY_GFP",
    "TRITC_median_intensity": "MEDIAN_INTENSITY_TRITC",
    "Area": "AREA",
    "feret_diameter": "feret_diameter",
    "centroid_x": "POSITION_X",
    "centroid_y": "POSITION_Y",
}

# the ferets diameter csv is the same file for every position, so read it once instead of once per position
ferets_data_df = pd.read_csv(ferets_csv_location)


def _log_row(date_expt, position):
    """Return the first row of filtered_tracked_data_log for this experiment date and position (folder) name."""
    matches = filtered_tracked_data_log[(filtered_tracked_data_log['Date'] == date_expt)
                                        & (filtered_tracked_data_log['File Name'] == position)]
    return matches.iloc[0]


def _is_yes(value):
    """Case- and whitespace-insensitive check for a 'yes' entry in the google doc (a missing value counts as no)."""
    return str(value).strip().lower() == 'yes'


def _valid_frames(frames, num_frames):
    """Boolean mask of (zero-based) frames inside the analysed window [0, num_frames).

    Both bounds matter: a negative frame would otherwise be used as a numpy index and silently
    wrap around to the end of the time axis.
    """
    return (frames >= 0) & (frames < num_frames)


def _to_cell_time_grid(template, frames_per_cell, values_per_cell, num_frames):
    """Scatter per-spot values onto a copy of template (row = cell, column = frame), dropping out-of-window frames."""
    grid = template.copy()
    for row, (frames, values) in enumerate(zip(frames_per_cell, values_per_cell)):
        keep = _valid_frames(frames, num_frames)
        grid[row, frames[keep]] = np.asarray(values)[keep]
    return grid


for expt in expt_dates:
    print(f"Processing data for {expt}:")
    date_expt = int(expt)  # date in google docs is int64 not string
    # create the path to the data
    data_path = Path(parent_data_path) / f"{expt}_DIMM" / 'Channel_Crops'
    subfolders_to_process = set(
        filtered_tracked_data_log.loc[filtered_tracked_data_log['Date'] == date_expt, 'File Name']
    )
    subfolder_paths = [d for d in data_path.iterdir() if d.is_dir() and d.name in subfolders_to_process]

    for subfolder in subfolder_paths:
        log_row = _log_row(date_expt, subfolder.name)

        # per-position metadata from the google doc
        mc_lineage_trackname = 'Track_' + str(int(log_row['Mother_Lineage_Trackname_03_08']))
        channel_width = float(log_row['channel_width'])
        # analysed window in TrackMate frame numbers (inclusive at both ends)
        start_frame = int(log_row['Transition Frame 0 in TRACKMATE'])
        end_frame = int(log_row['TrackmateEndFrame'])
        num_frames = end_frame - start_frame + 1

        # extract xml file - say so instead of silently skipping when the match is ambiguous or missing
        xml_files = list(subfolder.glob(xml_file_label))
        if len(xml_files) != 1:
            print(f"  Skipping {subfolder.name}: expected 1 file matching {xml_file_label}, found {len(xml_files)}")
            continue
        tmxml = TrackmateXML()
        # glob already returns the full path, so use it directly
        tmxml.loadfile(xml_files[0])

        # Step 0: get data for mother cell lineage
        tracks = tmxml.analysetrack(mc_lineage_trackname, duplicate_split=False, break_split=False)
        # trackmatexml adds the parent's last spot to the start of each daughter track, so drop the
        # first spotid from every track that is not the parent
        for track in tracks:
            if track.parent != 0:
                track.spotids = track.spotids[1:]

        # Step 1: check if a fake cell was added and if yes - remove it
        if _is_yes(log_row['Division issue?']):
            for track in tracks:
                track.spotids = track.spotids[track.spotids != fake_spot_id_to_remove]

        # Step 2: extract all data for the spotIDs, renaming spot headers to use the channel names
        # (the rename only depends on the header, so do it once per header, not once per track)
        spotheader_data_dict = {}
        for spot_header in tmxml.spotheader:
            renamed_header, _ = find_and_replace(spot_header, channel_dict)
            spotheader_data_dict[renamed_header] = [tmxml.getproperty(track.spotids, spot_header) for track in tracks]

        # Step 3: if a frame had to be added, read the mother cell's spot for that frame from spots.csv and
        # prepend it to the first track's arrays (cells and parents come straight from the tracks and don't change)
        if _is_yes(log_row['Needs frame appended']):
            extra_spot_data = pd.read_csv(subfolder / 'spots.csv')
            # the first 3 rows probably need to be deleted in all csvs, but only drop those whose LABEL does not contain 'ID'
            is_real_spot = extra_spot_data.iloc[:3]['LABEL'].str.contains('ID', case=False, na=False)
            extra_spot_data = pd.concat([extra_spot_data.iloc[:3][is_real_spot], extra_spot_data.iloc[3:]]).reset_index(drop=True)
            # convert columns to int/float explicitly (presumably read as text because of the header rows)
            for column in extra_spot_data.columns:
                if column == 'ID':
                    extra_spot_data[column] = pd.to_numeric(extra_spot_data[column], errors='coerce').astype('int32')
                elif column != 'LABEL':
                    extra_spot_data[column] = pd.to_numeric(extra_spot_data[column], errors='coerce').astype('float64')
            # the mother cell should have the lowest y
            # NOTE: 20241028 XY19_crop5 is a known exception (mother cell died so the second cell was tracked)
            extra_spot_data = extra_spot_data.sort_values(by='POSITION_Y').reset_index(drop=True)
            # .copy() so the edits below act on an independent frame rather than a slice of extra_spot_data
            extra_mc_data = extra_spot_data.iloc[[0]].copy()
            # these columns are not in the full trackmate data
            extra_mc_data = extra_mc_data.drop(columns=['MANUAL_SPOT_COLOR', 'LABEL', 'TRACK_ID'])
            # the single-timepoint export has frame 0; use the transition frame from the google doc instead
            extra_mc_data['FRAME'] = start_frame
            # ROI_N_POINTS is not included in single-timepoint data from TrackMate
            extra_mc_data['ROI_N_POINTS'] = np.nan
            # rename columns the same way as the spot headers in step 2 so they can be matched
            extra_mc_data = extra_mc_data.rename(columns=lambda col: find_and_replace(col, channel_dict)[0])
            # prepend the extra value to the first track's array for every matching key
            # NOTE(review): keys absent from spots.csv are not extended, leaving that track's arrays misaligned
            for key, per_track in spotheader_data_dict.items():
                if key in extra_mc_data.columns:
                    per_track[0] = np.insert(per_track[0], 0, extra_mc_data[key].iloc[0])

        # Step 4: add ferets diameter data extracted using regionprops
        # .copy() so rounding the centroids modifies an independent frame, not a slice of ferets_data_df
        ferets_data_df_subset = ferets_data_df[(ferets_data_df["Expt"] == date_expt)
                                               & (ferets_data_df["Position_crop"] == subfolder.name)].copy()
        ferets_data_df_subset["POSITION_X"] = ferets_data_df_subset["POSITION_X"].round(decimal_places)
        ferets_data_df_subset["POSITION_Y"] = ferets_data_df_subset["POSITION_Y"].round(decimal_places)
        lookup = {(spot.POSITION_X, spot.POSITION_Y, spot.frame): spot.feret_diameter_max
                  for spot in ferets_data_df_subset.itertuples(index=False)}
        # ferets were only calculated on the relevant frames, so spots outside them get nan; those frames are
        # dropped when the data is restricted to the analysed window below
        spotheader_data_dict["feret_diameter"] = [
            np.array([lookup.get((round(x, decimal_places), round(y, decimal_places), t), np.nan)
                      for x, y, t in zip(xs, ys, ts)])
            for xs, ys, ts in zip(spotheader_data_dict["POSITION_X"],
                                  spotheader_data_dict["POSITION_Y"],
                                  spotheader_data_dict["FRAME"])
        ]

        # Step 5: convert frames to int and make them zero-based relative to start_frame
        spotheader_data_dict['FRAME'] = [arr.astype(int) - start_frame for arr in spotheader_data_dict['FRAME']]

        # Step 6: insert a time column based on frame interval (not stored, but might be useful)
        spotheader_data_dict['Time (mins)'] = [frames * frame_interval for frames in spotheader_data_dict['FRAME']]

        # Step 7: per-track (not per-timepoint) values - unaffected by the frame added in step 3
        cells = [track.cell for track in tracks]
        parents = [track.parent for track in tracks]

        # Step 8: drop tracks that start after the analysed window (or have no spots left at all)
        keep_rows = [i for i, frames in enumerate(spotheader_data_dict['FRAME'])
                     if frames.size > 0 and frames[0] <= num_frames - 1]
        spotheader_data_dict = {key: [per_track[i] for i in keep_rows]
                                for key, per_track in spotheader_data_dict.items()}
        cells = [cells[i] for i in keep_rows]
        parents = [parents[i] for i in keep_rows]

        # Step 9: (cell, time) template - 0 where the cell is present in the window, nan elsewhere
        holder_array = np.full((len(cells), num_frames), np.nan, dtype=np.float32)
        for row, frames in enumerate(spotheader_data_dict['FRAME']):
            holder_array[row, frames[_valid_frames(frames, num_frames)]] = 0

        # Step 10: build the (cell, time) data arrays
        data_vars = {}
        for var_name, key in lineage_features.items():
            if key in spotheader_data_dict:
                grid = _to_cell_time_grid(holder_array, spotheader_data_dict['FRAME'],
                                          spotheader_data_dict[key], num_frames)
            else:
                print(f"  WARNING: {key} not found for {subfolder.name}; filling {var_name} with NaN")
                grid = np.full_like(holder_array, np.nan)
            data_vars[var_name] = xr.DataArray(grid, dims=("cell", "time"), name=var_name)

        # Step 11: daughters - in the parent's row, at the frame a daughter first appears, store the daughter's cell ID.
        # Rows were removed in step 8, so look rows up by cell ID rather than assuming row == cell - 1
        row_of_cell = {cell: row for row, cell in enumerate(cells)}
        daughters = holder_array.copy()
        for cell, parent, frames in zip(cells, parents, spotheader_data_dict['FRAME']):
            appearance_time = int(frames[0])
            parent_row = row_of_cell.get(parent)
            if parent != 0 and parent_row is not None and 0 <= appearance_time < num_frames:
                daughters[parent_row, appearance_time] = cell
        data_vars["parents"] = xr.DataArray(parents, dims="cell", name="parents")
        data_vars["daughters"] = xr.DataArray(daughters, dims=("cell", "time"), name="daughters")

        # Step 12: combine all data arrays into a dataset for this lineage, with metadata stored as coords
        timepoints = np.arange(num_frames) * frame_interval
        lineage_ds = xr.Dataset(
            data_vars=data_vars,
            coords={"experiment": expt, "position": subfolder.name, "track": mc_lineage_trackname,
                    "cell": cells, "time": timepoints, "channel_width": channel_width},
        )
        # save the dataset next to the position's data and keep it for the combined output
        lineage_ds.to_netcdf(subfolder / 'lineage_dataset.nc')
        datasets_list.append(lineage_ds)

In [None]:
# collect every lineage dataset into one dictionary and build a dataframe describing them
def datasets_to_dict(datasets_list):
    """
    Index lineage datasets by a unique key and tabulate their metadata.

    Parameters:
    - datasets_list (list): lineage datasets, each carrying scalar "experiment", "position",
      "track" and "channel_width" coords.

    Returns:
    - dict: maps "<experiment>__<position>__<track>" (parts joined by two underscores) to its dataset;
      a later dataset with the same key replaces an earlier one.
    - pd.DataFrame: one row per input dataset (duplicates included) with columns
      unique_ID, experiment, position, track, channel_width.
    """
    coord_names = ("experiment", "position", "track", "channel_width")
    all_dataset_dict = {}
    metadata_rows = []

    for dataset in datasets_list:
        meta = {name: dataset.coords[name].item() for name in coord_names}
        unique_id = f"{meta['experiment']}__{meta['position']}__{meta['track']}"
        all_dataset_dict[unique_id] = dataset
        metadata_rows.append({"unique_ID": unique_id, **meta})

    return all_dataset_dict, pd.DataFrame(metadata_rows)

In [None]:
# combine all data: one dict of per-lineage datasets keyed by experiment__position__track,
# plus a metadata dataframe (experiment, position, track, channel width) with one row per dataset
ds_all_dict, all_metadata_df = datasets_to_dict(datasets_list=datasets_list)

In [None]:
# save the combined dictionary (pickled) and the metadata dataframe (csv) in output_dir
pickle_path = os.path.join(output_dir, 'all_datasets_dict_0530.pkl')
metadata_csv_path = os.path.join(output_dir, 'all_metadata_0530.csv')

with open(pickle_path, 'wb') as pickle_file:
    pickle.dump(ds_all_dict, pickle_file)

all_metadata_df.to_csv(metadata_csv_path, index=False)

In [None]:
# inspect the daughters array of the first lineage dataset
datasets_list[0]["daughters"]