# In [None]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import norm
from tqdm import tqdm
import tifffile as tiff
import gc
import warnings

# -------------------------------------------
# Imports and Utility Function Definitions
# -------------------------------------------
import matplotlib.pyplot as plt

from napatrackmater.Trackvector import (
    TrackVector, SHAPE_FEATURES, DYNAMIC_FEATURES, SHAPE_DYNAMIC_FEATURES
)

warnings.filterwarnings("ignore", category=RuntimeWarning)

# -------------------------------------------
# Helper Functions
# -------------------------------------------

def calculate_position_time(frame, time_interval):
    """Convert a frame index into elapsed acquisition time.

    The spacing between frames is piecewise constant: each frame counts as
    ``1.49 * time_interval`` up to and including frame 119, as
    ``2.00 * time_interval`` from there up to and including frame 219, and as
    ``2.99 * time_interval`` afterwards.

    Args:
        frame: Frame index.
        time_interval: Base interval the three factors are multiplied with
            (presumably seconds, since callers divide the result by 3600 to
            obtain hours).

    Returns:
        Elapsed time at ``frame`` in the unit of ``time_interval``.
    """
    # First regime: a single constant spacing.
    if frame <= 119:
        return frame * 1.49 * time_interval

    # Second regime: the full first regime plus the frames beyond 119.
    if frame <= 219:
        elapsed_first = 119 * 1.49 * time_interval
        elapsed_since_119 = (frame - 119) * (2.00 * time_interval)
        return (elapsed_first + elapsed_since_119)

    # Third regime: both earlier regimes in full plus the frames beyond 219.
    elapsed_first = 119 * 1.49 * time_interval
    elapsed_second = 100 * 2.00 * time_interval
    elapsed_since_219 = (frame - 219) * (2.99 * time_interval)
    return (elapsed_first + elapsed_second + elapsed_since_219)

def get_nucleus_centroid(nuc_label, t, full_nuc_array):
    """Locate the centre of one labelled nucleus in a single time frame.

    Args:
        nuc_label: Label value to look for; converted with ``int``.
        t: Index of the frame along the first axis; converted with ``int``.
        full_nuc_array: Label array indexed as ``[t]`` to give a 3D volume
            whose axes are unpacked as (z, y, x).

    Returns:
        ``(mean_x, mean_y, mean_z)`` over all voxels equal to ``nuc_label``,
        or ``None`` when the label does not occur in that frame.
    """
    label_id = int(nuc_label)
    frame_idx = int(t)

    # Coordinates of every voxel carrying the requested label.
    zs, ys, xs = np.nonzero(full_nuc_array[frame_idx] == label_id)
    if len(zs) == 0:
        return None
    return np.mean(xs), np.mean(ys), np.mean(zs)

def get_membrane_centroid(mem_label, t, full_mem_array):
    """Locate the centre of one labelled membrane object in a single time frame.

    Args:
        mem_label: Label value to look for; converted with ``int``.
        t: Index of the frame along the first axis; converted with ``int``.
        full_mem_array: Label array indexed as ``[t]`` to give a 3D volume
            whose axes are unpacked as (z, y, x).

    Returns:
        ``(mean_x, mean_y, mean_z)`` over all voxels equal to ``mem_label``,
        or ``None`` when the label does not occur in that frame.
    """
    label_id = int(mem_label)
    frame_idx = int(t)

    # Coordinates of every voxel carrying the requested label.
    zs, ys, xs = np.nonzero(full_mem_array[frame_idx] == label_id)
    if len(zs) == 0:
        return None
    return np.mean(xs), np.mean(ys), np.mean(zs)

def calculate_distance(row, centroid, coord_cols=('x', 'y', 'z')):
    """Return the Euclidean distance between a row's coordinates and a centroid.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) indexed by the names
            in ``coord_cols``.
        centroid: Sequence indexed by position; entry ``i`` is paired with
            ``coord_cols[i]``.
        coord_cols: Names of the coordinate entries to read from ``row``.

    Returns:
        Square root of the summed squared per-axis differences.
    """
    squared_total = 0
    for axis, column in enumerate(coord_cols):
        delta = row[column] - centroid[axis]
        squared_total = squared_total + delta ** 2
    return np.sqrt(squared_total)

def determine_cell_type(group):
    """Collapse the ``cell_type`` values of a group into a single label.

    Rules, in order:
      * exactly one distinct value: that value, or 'unknown' if it is missing;
      * both 'basal' and 'goblet' present: 'unknown' (conflict);
      * only one of 'basal' / 'goblet' present: that one;
      * anything else: 'unknown'.

    Args:
        group: Object whose ``'cell_type'`` entry is iterable (e.g. the
            DataFrame of one groupby group).

    Returns:
        The resolved cell type.
    """
    observed = set(group['cell_type'])

    if len(observed) == 1:
        (only_value,) = observed
        return only_value if pd.notna(only_value) else 'unknown'

    has_basal = 'basal' in observed
    has_goblet = 'goblet' in observed
    if has_basal == has_goblet:
        # Either a basal/goblet conflict or neither of the two is present.
        return 'unknown'
    return 'basal' if has_basal else 'goblet'

# -------------------------------------------
# Directory and File Setup
# -------------------------------------------
# Names shared by the input and output locations used throughout the script.
channel = 'nuclei'
tracking_directory = 'nuclei_membrane_tracking'
save_dir = tracking_directory  # outputs are written next to the tracking inputs
data_frames_dir = os.path.join(tracking_directory, 'dataframes/')

# Create both folders up front; folders that already exist are left untouched.
for _output_dir in (save_dir, data_frames_dir):
    Path(_output_dir).mkdir(parents=True, exist_ok=True)

# Per-channel results table inside the dataframes folder.
save_file = os.path.join(data_frames_dir, f'results_dataframe_{channel}.csv')

# -------------------------------------------
# Load DataFrames and Images
# -------------------------------------------
# Shape/dynamics table of the selected channel; its first CSV column is the index.
global_shape_dynamic_dataframe = pd.read_csv(save_file, index_col=0)

# Elapsed time per row in hours, derived from the frame number in column 't'.
_frame_numbers = global_shape_dynamic_dataframe['t'].to_list()
global_shape_dynamic_dataframe['t_hours'] = [
    calculate_position_time(frame, 100) / 3600 for frame in _frame_numbers
]

# Mastodon exports (spots, links, branch spots); all three use the same reader options.
_mastodon_read_options = {'encoding': 'latin-1', 'low_memory': False}
man_spots2 = pd.read_csv(
    'nuclei_membrane_tracking/MastodonTable_modeltesting-Spot.csv',
    **_mastodon_read_options
)
man_links2 = pd.read_csv(
    'nuclei_membrane_tracking/MastodonTable_modeltesting-Link.csv',
    **_mastodon_read_options
)
man_branches2 = pd.read_csv(
    'nuclei_membrane_tracking/MastodonTable_modeltesting-BranchSpot.csv',
    **_mastodon_read_options
)

# -------------------------------------------
# Preprocess Spot, Link, and Branch DataFrames
# -------------------------------------------
# Discard the first two rows of every Mastodon table.
# NOTE(review): presumably these are the extra header rows of the export — confirm.
man_spots2, man_links2, man_branches2 = (
    table.iloc[2:,] for table in (man_spots2, man_links2, man_branches2)
)

# Convert spot columns 1..7 and link columns 1..6 (by position) to float.
for col_pos in range(1, 8):
    man_spots2.iloc[:, col_pos] = pd.to_numeric(man_spots2.iloc[:, col_pos]).astype('float')
for col_pos in range(1, 7):
    man_links2.iloc[:, col_pos] = pd.to_numeric(man_links2.iloc[:, col_pos]).astype('float', errors='ignore')

# Give the three position columns explicit axis names.
_position_names = {
    "Spot position": "POSITION_X",
    "Spot position.1": "POSITION_Y",
    "Spot position.2": "POSITION_Z",
}
man_spots2 = man_spots2.rename(columns=_position_names)

# Attach to every spot the 'Link target IDs' value of the link row whose
# 'Link target IDs.1' equals the spot ID; the script treats that value as the
# spot's source (predecessor) ID. Spots without such a link keep a missing value.
_link_endpoints = man_links2[['Link target IDs', 'Link target IDs.1']]
merged1 = man_spots2.merge(
    _link_endpoints, left_on='ID', right_on='Link target IDs.1', how="left"
)
merged1 = merged1.drop(columns=['Link target IDs.1'])
merged1 = merged1.rename(columns={"Link target IDs": "Spot source ID"})

concatenated_df = merged1
concatenated_branches = man_branches2

# -------------------------------------------
# Data Type Conversion and Track Relabeling
# -------------------------------------------
# Identifier columns become plain integers; the source ID stays nullable
# because spots without an incoming link have no source.
for _int_column in ('ID', 'Spot frame', 'Spot track ID'):
    concatenated_df[_int_column] = pd.to_numeric(concatenated_df[_int_column]).astype('int')
concatenated_df['Spot source ID'] = pd.to_numeric(
    concatenated_df['Spot source ID'], errors='coerce'
).astype('Int32')

# Renumber track IDs as 0..n-1 in order of first appearance.
unique_track_ids = concatenated_df['Spot track ID'].unique()
new_track_id_mapping = dict(zip(unique_track_ids, range(len(unique_track_ids))))
concatenated_df['Spot track ID relabelled'] = concatenated_df['Spot track ID'].map(new_track_id_mapping)

# A spot is flagged as dividing when its ID occurs more than once as a source ID.
_has_source = ~(pd.isna(concatenated_df['Spot source ID']))
_source_seen_before = concatenated_df.duplicated(subset=['Spot source ID'], keep='first')
duplicated_source_ids = concatenated_df.loc[
    _source_seen_before & _has_source, 'Spot source ID'
].values
concatenated_df['Dividing'] = concatenated_df['ID'].isin(duplicated_source_ids)

# From here on the relabelled IDs are the ones iterated over.
unique_track_ids = concatenated_df['Spot track ID relabelled'].unique()

# Branch bookkeeping columns as integers.
for _int_column in ('Branch N spots', 'Branch depth'):
    concatenated_branches[_int_column] = pd.to_numeric(concatenated_branches[_int_column]).astype('int')

# -------------------------------------------
# Build Tracklets Dictionary
# -------------------------------------------
# Maps a tracklet key (string) to the list of spot IDs on that tracklet,
# ordered from the earliest collected spot to the spot carrying the branch label.
tracklets_dict = {}
for track_id in tqdm(unique_track_ids):
    # Translate the relabelled ID back to the original track ID and collect
    # all spots of that track.
    og_track_id = concatenated_df[concatenated_df['Spot track ID relabelled'] == track_id]['Spot track ID'].values[0]
    track_df = concatenated_df[concatenated_df['Spot track ID'] == og_track_id]
    # Branch rows whose 'Label' matches the label of a spot in this track.
    track_branches = concatenated_branches[concatenated_branches['Label'].isin(track_df['Label'].values)]
    # Renumber the rows 0..n-1 so that .loc[i] below addresses the i-th branch.
    track_branches.reset_index(inplace=True)
    for i in range(0, len(track_branches)):
        tracklet_id = i
        branch = track_branches.loc[i]
        branch_label = branch['Label']
        branch_spots_len = branch['Branch N spots']
        generation = branch['Branch depth']
        # Start at the spot whose label equals the branch label.
        # NOTE(review): presumably this is the last spot of the branch, since
        # the walk below only moves backwards through source IDs — confirm.
        current_id = track_df.loc[track_df['Label'] == branch_label, 'ID'].values[0]
        current_source_id = track_df.loc[track_df['Label'] == branch_label, 'Spot source ID'].values[0]
        lineage_ids = [current_id]
        # Follow the source IDs backwards for at most 'Branch N spots' - 1
        # steps, prepending each predecessor; stop early at a spot without a source.
        for _ in range(branch_spots_len - 1):
            next_id = track_df.loc[track_df['ID'] == current_source_id, 'ID'].values[0]
            next_source_id = track_df.loc[track_df['ID'] == current_source_id, 'Spot source ID'].values[0]
            lineage_ids.insert(0, next_id)
            if not pd.isna(next_source_id):
                current_source_id = next_source_id
            else:
                break
        # Key = relabelled track ID, branch depth and branch position
        # concatenated without a separator.
        # NOTE(review): without a separator different triples can yield the
        # same key (e.g. 1|12|3 and 11|2|3 both give "1123"), in which case the
        # later tracklet overwrites the earlier one — confirm this cannot occur.
        track_str = str(track_id) + str(generation) + str(tracklet_id)
        tracklets_dict[track_str] = lineage_ids

# Invert the dictionary: spot ID -> tracklet key. A spot listed in several
# tracklets keeps the key of the tracklet inserted last.
id_to_tracklet = {id_: tracklet_id for tracklet_id, ids in tracklets_dict.items() for id_ in ids}
concatenated_df['Track ID'] = concatenated_df['ID'].map(id_to_tracklet)
# NOTE(review): int() raises for spots that are in no tracklet (missing
# 'Track ID') — this line assumes every spot was mapped; confirm.
concatenated_df['Track ID numeric'] = concatenated_df['Track ID'].apply(int)

# -------------------------------------------
# Coordinate Conversion and Image Extraction
# -------------------------------------------
# Convert the Mastodon positions to array coordinates by dividing by the
# hard-coded factors 0.691 (x, y) and 2 (z).
# NOTE(review): presumably these are the voxel sizes of the acquisition — confirm.
concatenated_df['X_orig'] = concatenated_df['POSITION_X'].astype('float') / 0.691
concatenated_df['Y_orig'] = concatenated_df['POSITION_Y'].astype('float') / 0.691
concatenated_df['Z_orig'] = concatenated_df['POSITION_Z'].astype('float') / 2

# Free what can be freed before loading the full label time lapse.
gc.collect()
full_nuc_array = tiff.imread('seg_nuclei_timelapses/timelapse_sixth_dataset.tif')

# Spots whose radius is exactly 5.0 are excluded.
# NOTE(review): presumably 5.0 marks spots that were not fitted/edited — confirm.
concatenated_df['Spot radius'] = concatenated_df['Spot radius'].astype(float)
# Explicit copy: a column is added below, and writing into a filtered slice
# would trigger pandas' chained-assignment warning.
concatenated_df2 = concatenated_df[concatenated_df['Spot radius'] != 5.0].copy()

# Look up the segmentation label under every spot with one integer-array
# indexing call on (t, z, y, x) instead of a Python-level call per row.
# Coordinates are rounded to the nearest voxel; an out-of-range index raises
# IndexError exactly as the per-row lookup did.
_t_idx = concatenated_df2['Spot frame'].astype(int).to_numpy()
_z_idx = concatenated_df2['Z_orig'].round(0).astype(int).to_numpy()
_y_idx = concatenated_df2['Y_orig'].round(0).astype(int).to_numpy()
_x_idx = concatenated_df2['X_orig'].round(0).astype(int).to_numpy()
_spot_labels = full_nuc_array[_t_idx, _z_idx, _y_idx, _x_idx]
if np.issubdtype(_spot_labels.dtype, np.integer):
    # The row-wise lookup produced an int64 column; keep that dtype.
    _spot_labels = _spot_labels.astype(np.int64)
concatenated_df2['nuc_label'] = _spot_labels

# Label 0 is dropped, i.e. spots that do not sit on a labelled object.
concatenated_df3 = concatenated_df2[concatenated_df2['nuc_label'] != 0].copy()

# -------------------------------------------
# Manual Annotation and Cell Type Assignment
# -------------------------------------------
# Manually labelled cells; the crop keeps only the 'basal' and 'goblet' rows.
gt_df = pd.read_csv('manual_labeled_cells_sixth_dataset_updated.csv')
gt_df_crop = gt_df.loc[
    gt_df['cell_type'].isin(['basal', 'goblet']),
    ['Centroid.X', 'Centroid.Y', 'Centroid.Z', 'cell_type']
].copy()  # explicit copy so that adding 'Label' does not write into a slice of gt_df

# Segmentation label under each cropped centroid in the last frame
# (coordinates truncated to integers).
_last_frame_labels = full_nuc_array[-1]
gt_df_crop['Label'] = _last_frame_labels[
    gt_df_crop['Centroid.Z'].astype(int).to_numpy(),
    gt_df_crop['Centroid.Y'].astype(int).to_numpy(),
    gt_df_crop['Centroid.X'].astype(int).to_numpy(),
]

# Join the spots of frame 359 with the manual table on the nucleus label.
# NOTE(review): the join uses gt_df (and therefore a 'Label' column that must
# already exist in the CSV), not gt_df_crop whose 'Label' was just computed and
# which is otherwise unused — confirm which table is intended.
man_spots_ann_final = concatenated_df3[concatenated_df3['Spot frame'] == 359]
concatenated_merged_final = pd.merge(
    man_spots_ann_final, gt_df[['Label', 'cell_type']],
    left_on='nuc_label', right_on='Label', how='left'
)

# One cell type per track, resolved by determine_cell_type.
mapping_dict = concatenated_merged_final.groupby('Spot track ID relabelled').apply(determine_cell_type).to_dict()

# Columns whose value "1" marks a manually assigned cell type, in priority
# order: for a row with several marks the first entry wins.
# NOTE(review): presumably these are Mastodon tag columns — confirm.
_manual_celltype_columns = [
    ('celltypes', 'mcc'),
    ('celltypes.3', 'ssc'),
    ('celltypes.5', 'ic'),
    ('celltypes.7', 'basal'),
    ('celltypes.9', 'goblet'),
    ('celltypes.13', 'basal'),
    ('celltypes.16', 'goblet'),
]

# Start from the per-track cell type and let manual marks override it.
# Applying the overrides from lowest to highest priority leaves the
# highest-priority mark in place. Series.mask keeps missing values missing,
# so rows without any information end up as 'unknown' below.
_cell_type = concatenated_df3['Spot track ID relabelled'].map(mapping_dict)
for _column, _label in reversed(_manual_celltype_columns):
    _cell_type = _cell_type.mask(concatenated_df3[_column] == "1", _label)
concatenated_df3['cell_type'] = _cell_type.fillna('unknown')

# A row is 'manual' when ANY of the mark columns is set, 'automatic' otherwise.
# Every column that can set the cell type is checked here, including
# 'celltypes.16' (previously 'celltypes.9' was tested twice instead, so rows
# typed via 'celltypes.16' were reported as 'automatic').
_manual_columns = [column for column, _ in _manual_celltype_columns]
_is_manual = concatenated_df3[_manual_columns].eq("1").any(axis=1)
concatenated_df3['annotation'] = np.where(_is_manual, 'manual', 'automatic')

# Elapsed time per spot in hours.
concatenated_df3['t_hours'] = [
    calculate_position_time(t, 100) / 3600 for t in concatenated_df3['Spot frame'].to_list()
]

# -------------------------------------------
# Prepare Final DataFrames for Export
# -------------------------------------------
# Columns kept in the exported spot table, in output order.
_export_columns = [
    'ID', 'Spot frame', 't_hours', 'POSITION_X', 'POSITION_Y', 'POSITION_Z',
    'Spot track ID', 'Spot source ID', 'Spot track ID relabelled', 'Dividing',
    'Track ID', 'Track ID numeric', 'X_orig', 'Y_orig', 'Z_orig', 'nuc_label',
    'cell_type', 'annotation',
]
concatenated_df4 = concatenated_df3[_export_columns]

# Only the rows annotated as 'manual' are written out.
_is_manual_row = concatenated_df4['annotation'] == 'manual'
concatenated_df4_manual = concatenated_df4.loc[_is_manual_row]
concatenated_df4_manual.to_csv('all_manually_labelled_tracks_sixth_dataset.csv')

# -------------------------------------------
# Assign Nucleus Labels to Global DataFrame
# -------------------------------------------
# Look up the segmentation label under every row of the global table with a
# single integer-array indexing call on (t, z, y, x).
# The previous row-wise DataFrame.apply depended on the rows being object
# dtype: for an all-numeric frame each row is cast to float and NumPy rejects
# float indices. Explicit integer arrays avoid that and need no temporary
# index columns. Coordinates are truncated to integers; an out-of-range index
# raises IndexError.
_t_idx = global_shape_dynamic_dataframe['t'].astype('int').to_numpy()
_z_idx = global_shape_dynamic_dataframe['z'].astype('int').to_numpy()
_y_idx = global_shape_dynamic_dataframe['y'].astype('int').to_numpy()
_x_idx = global_shape_dynamic_dataframe['x'].astype('int').to_numpy()
_nuc_labels = full_nuc_array[_t_idx, _z_idx, _y_idx, _x_idx]
if np.issubdtype(_nuc_labels.dtype, np.integer):
    # The row-wise lookup produced an int64 column; keep that dtype.
    _nuc_labels = _nuc_labels.astype(np.int64)
global_shape_dynamic_dataframe['nuc_label'] = _nuc_labels

# -------------------------------------------
# Remove Duplicates by Closest to Centroid (Nucleus)
# -------------------------------------------
def _keep_closest_to_centroid(frame, time_col, label_col, coord_cols, label_array):
    """Make every (time, label) pair unique by keeping the row nearest the centroid.

    For each (time, label) combination that occurs on more than one row of
    ``frame``, the centroid of that label is computed from ``label_array`` and
    the label of every row except the one closest to it is set to NaN.
    ``frame`` is modified in place.

    Args:
        frame: DataFrame to clean. Rows are addressed by index label, so the
            index is assumed to be unique.
        time_col: Name of the column holding the frame number.
        label_col: Name of the column holding the segmentation label.
        coord_cols: Names of the x, y, z coordinate columns, in that order.
        label_array: Label array passed on to ``get_nucleus_centroid``.
    """
    is_shared = frame.duplicated(subset=[time_col, label_col], keep=False)
    shared_pairs = frame.loc[is_shared, [time_col, label_col]].drop_duplicates().values

    # NaN cannot be stored in an integer column. Cast once and explicitly
    # instead of relying on the implicit upcast during .loc assignment, which
    # recent pandas versions deprecate.
    if len(shared_pairs) and not pd.api.types.is_float_dtype(frame[label_col]):
        frame[label_col] = frame[label_col].astype(float)

    for t_value, label_value in tqdm(shared_pairs):
        centroid = get_nucleus_centroid(label_value, t_value, label_array)
        if centroid is None:
            continue
        candidates = frame[
            (frame[time_col] == t_value) & (frame[label_col] == label_value)
        ]
        # Kept as a separate Series: adding a column to the filtered slice
        # would trigger pandas' chained-assignment warning.
        distances = candidates.apply(
            lambda row: calculate_distance(row, centroid, coord_cols), axis=1
        )
        closest_idx = distances.idxmin()
        frame.loc[candidates.index.difference([closest_idx]), label_col] = np.nan


# Global shape/dynamics table: coordinates are in columns x, y, z.
# NOTE(review): this table is read with index_col=0; the clean-up addresses
# rows by index label and therefore assumes that index is unique — confirm.
_keep_closest_to_centroid(
    global_shape_dynamic_dataframe, 't', 'nuc_label', ('x', 'y', 'z'), full_nuc_array
)

# Mastodon spots: coordinates are the converted X_orig, Y_orig, Z_orig columns.
_keep_closest_to_centroid(
    concatenated_df3, 'Spot frame', 'nuc_label', ('X_orig', 'Y_orig', 'Z_orig'), full_nuc_array
)

# Keep only spots that still own a label and export the columns needed downstream.
filtered_concatenated_df = concatenated_df3[pd.notna(concatenated_df3['nuc_label'])]
filtered_concatenated_df2 = filtered_concatenated_df[[
    'ID', 'Spot frame', 'POSITION_X', 'POSITION_Y', 'POSITION_Z',
    'Spot source ID', 'Spot track ID relabelled', 'Track ID', 'Track ID numeric',
    'cell_type', 'X_orig', 'Y_orig', 'Z_orig', 'nuc_label', 'annotation'
]]
filtered_concatenated_df2.to_csv('for_Ziwei_new2.csv')

# -------------------------------------------
# Merge with Global Shape/Dynamic DataFrame
# -------------------------------------------
# Right join on (frame, nucleus label): every manually tracked spot is kept,
# with the shape/dynamics features attached where a matching row exists.
global_shape_dynamic_dataframe_merge = global_shape_dynamic_dataframe.merge(
    filtered_concatenated_df2,
    how='right',
    left_on=['t', 'nuc_label'],
    right_on=['Spot frame', 'nuc_label'],
)

# Collapse rows that agree on position, frame and the left-hand 'Track ID'
# (suffixed '_x' by the merge), keeping the first occurrence.
_dedup_keys = ['x', 'y', 'z', 't', 'Track ID_x']
global_shape_dynamic_dataframe_merge = (
    global_shape_dynamic_dataframe_merge
    .drop_duplicates(subset=_dedup_keys, keep='first')
    .reset_index(drop=True)
)

# Recompute the elapsed time from the spot frame, which is present on every row.
_spot_frames = global_shape_dynamic_dataframe_merge['Spot frame'].to_list()
global_shape_dynamic_dataframe_merge['t_hours'] = [
    calculate_position_time(frame, 100) / 3600 for frame in _spot_frames
]

global_shape_dynamic_dataframe_merge.to_csv(
    'nuclei_membrane_tracking/nuclei_manual_dataset_withid.csv'
)

# -------------------------------------------
# Membrane Data Processing
# -------------------------------------------
# Release the nuclei label array before loading the membrane one.
del full_nuc_array
gc.collect()

nuc_mem_overlap = pd.read_csv('mem_nuc_overlap_table.csv')
global_shape_dynamic_dataframe_membrane = pd.read_csv(
    'nuclei_membrane_tracking/membrane_results_df.csv'
)
full_mem_array = tiff.imread('seg_membrane_timelapses_fixed/timelapse_sixth_dataset.tif')

# Look up the membrane label under every row with a single integer-array
# indexing call on (t, z, y, x). The previous row-wise DataFrame.apply depended
# on the rows being object dtype (for an all-numeric frame the indices arrive
# as floats, which NumPy rejects) and needed temporary index columns.
# Coordinates are truncated to integers; an out-of-range index raises IndexError.
_t_idx = global_shape_dynamic_dataframe_membrane['t'].astype('int').to_numpy()
_z_idx = global_shape_dynamic_dataframe_membrane['z'].astype('int').to_numpy()
_y_idx = global_shape_dynamic_dataframe_membrane['y'].astype('int').to_numpy()
_x_idx = global_shape_dynamic_dataframe_membrane['x'].astype('int').to_numpy()
_mem_labels = full_mem_array[_t_idx, _z_idx, _y_idx, _x_idx]
if np.issubdtype(_mem_labels.dtype, np.integer):
    # The row-wise lookup produced an int64 column; keep that dtype.
    _mem_labels = _mem_labels.astype(np.int64)
global_shape_dynamic_dataframe_membrane['mem_label'] = _mem_labels

# Drop rows with label 0. Explicit copy: labels are overwritten below, and
# writing into a filtered slice would trigger pandas' chained-assignment warning.
global_shape_dynamic_dataframe_membrane2 = global_shape_dynamic_dataframe_membrane[
    global_shape_dynamic_dataframe_membrane['mem_label'] != 0
].copy()

# (t, mem_label) combinations that occur on more than one row.
_is_shared = global_shape_dynamic_dataframe_membrane2.duplicated(
    subset=['t', 'mem_label'], keep=False
)
_shared_pairs = global_shape_dynamic_dataframe_membrane2.loc[
    _is_shared, ['t', 'mem_label']
].drop_duplicates().values

# NaN cannot be stored in an integer column. Cast once and explicitly instead
# of relying on the implicit upcast during .loc assignment, which recent
# pandas versions deprecate.
if len(_shared_pairs) and not pd.api.types.is_float_dtype(
    global_shape_dynamic_dataframe_membrane2['mem_label']
):
    global_shape_dynamic_dataframe_membrane2['mem_label'] = (
        global_shape_dynamic_dataframe_membrane2['mem_label'].astype(float)
    )

# For every shared label keep only the row closest to the label's centroid;
# the other rows lose their label (NaN).
for t_value, mem_label_value in tqdm(_shared_pairs):
    centroid = get_membrane_centroid(mem_label_value, t_value, full_mem_array)
    if centroid is None:
        continue
    subset_df = global_shape_dynamic_dataframe_membrane2[
        (global_shape_dynamic_dataframe_membrane2['t'] == t_value) &
        (global_shape_dynamic_dataframe_membrane2['mem_label'] == mem_label_value)
    ]
    # Kept as a separate Series rather than a new column on the slice.
    distances = subset_df.apply(
        lambda row: calculate_distance(row, centroid, ('x', 'y', 'z')), axis=1
    )
    closest_idx = distances.idxmin()
    global_shape_dynamic_dataframe_membrane2.loc[
        subset_df.index.difference([closest_idx]), 'mem_label'
    ] = np.nan

# Attach the nucleus label paired with each (t, mem_label) in the overlap table.
global_shape_dynamic_dataframe_membrane2 = global_shape_dynamic_dataframe_membrane2.merge(
    nuc_mem_overlap[['mem_label', 'nuc_label', 't']], on=['t', 'mem_label'], how='left'
)

# Right join on (frame, nucleus label): every manually tracked spot is kept,
# with membrane features attached where a matching row exists.
global_shape_dynamic_dataframe_merge_membrane = pd.merge(
    global_shape_dynamic_dataframe_membrane2,
    filtered_concatenated_df2[[
        'Spot frame', 'POSITION_X', 'POSITION_Y', 'POSITION_Z',
        'Spot track ID relabelled', 'Track ID', 'Track ID numeric',
        'cell_type', 'X_orig', 'Y_orig', 'Z_orig', 'nuc_label', 'annotation'
    ]],
    left_on=['t', 'nuc_label'], right_on=['Spot frame', 'nuc_label'], how='right'
)
global_shape_dynamic_dataframe_merge_membrane.to_csv(
    'nuclei_membrane_tracking/membrane_manual_dataset.csv'
)