In [None]:
from pathlib import Path
from trackmatexml import TrackmateXML
import seaborn as sns
from matplotlib import pyplot as plt
import xarray as xr
import numpy as np
import pandas as pd
import os
import re
from collections import defaultdict
import pickle
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt

# makes figures look better in Jupyter
import matplotlib as mpl

sns.set_context('talk')
sns.set_style("ticks")

# Embed TrueType (Type 42) fonts rather than Type 3 in PDF/PS exports
# (commonly done so text stays editable in vector-graphics editors).
mpl.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

In [None]:
# load the data
# NOTE(review): pickle.load can execute arbitrary code - only load pickles this project produced.
output_dir = '/Volumes/salmonella/users/madison/2024_DIMM_MultirepAnalysis/combined_datasets'
plot_output = '/Volumes/salmonella/users/madison/2024_DIMM_MultirepAnalysis/LineagePlotting/Final_Plots'

with open(os.path.join(output_dir, 'all_datasets_dict.pkl'), mode='rb') as pickle_file:
    ds_all_dict = pickle.load(pickle_file)  # unique lineage ID -> xarray Dataset

all_metadata_df = pd.read_csv(os.path.join(output_dir, 'all_metadata.csv'))


In [None]:
# check everything looks ok
# (was `all_meta_data_df`, an undefined name -> NameError; the metadata frame is `all_metadata_df`)
print(len(all_metadata_df))
print(len(ds_all_dict))
# last expression so Jupyter actually renders the preview
all_metadata_df.head()

In [None]:
# optional subsetting if you need to (e.g. to try a function on a single lineage)
keys_to_include = {'20240809__XY05_crop3__Track_157'}
subset_dict = dict((key, ds_all_dict[key]) for key in keys_to_include)
subset_dict['20240809__XY05_crop3__Track_157']

In [None]:
##HARDCODED INFO BASED ON PRIOR ANALYSES IN BACKGROUND AND T=0 INTENSITIES
#Do not change, use these #'s in any downstream analysis
# experiment (date) -> value added to the raw median intensity in add_processed_data_to_dataset_dict
gfp_background_dict = {'20240809':0, '20240918':99.11, '20241028':66.29}
ruby_background_dict = {'20240809':6.7898, '20240918':155.8976, '20241028':0}
gfp_threshold = 1515.608
pixel_conversion = 0.13
# NOTE(review): the GFP trace is transformed with np.log1p (log_trans=True), but this threshold uses
# np.log; the difference here is log(1 + 1/1515.6) ~ 7e-4 - confirm that is intended.
log_gfp_th = np.log(gfp_threshold)
IDs_to_drop = ['20241028__XY24_crop2__Track_5', '20240918__XY07_crop1__Track_2'] 
# remove any keys that shouldn't be included in dataset: one here was a perfect duplicate lineage
# (likely made two of the same crop), the other had 3 top cells die and no mother div.
# (loop variable renamed so it no longer shadows the builtin `id` in the notebook namespace)
for lineage_id in IDs_to_drop:
    ds_all_dict.pop(lineage_id, None)


In [None]:
def add_processed_data_to_dataset_dict(dataset_dict, metadata_df, data_var_to_modify, new_var_name, 
                                    interpolate = True, smooth = True, log_trans = False, 
                                    fluor_background_dict = None, window_size = 3, save_dict=False,
                                    max_gap=1, save_dir=None):
    """
    Add a processed copy of one data variable to every dataset in ``dataset_dict`` (in place).

    For each lineage dataset the steps are:
      1. optionally add a per-experiment background value
         (``fluor_background_dict[experiment]``, 0 when the experiment is missing);
      2. optionally linearly interpolate NaN gaps along ``time``;
      3. optionally smooth with a trailing (``center=False``) rolling mean of ``window_size`` frames.
    The result is stored as ``ds[new_var_name]``. With ``log_trans=True`` a ``log1p`` copy is also
    stored as ``ds[new_var_name + '_transformed']``.

    Parameters
    ----------
    dataset_dict : dict
        unique lineage ID -> xarray Dataset; every dataset is modified in place.
    metadata_df : pandas.DataFrame
        Not used here; kept so existing calls keep working.
    data_var_to_modify : str
        Data variable to process (e.g. 'GFP_median_intensity').
    new_var_name : str
        Name of the new data variable.
    interpolate, smooth, log_trans : bool
        Toggle the individual processing steps.
    fluor_background_dict : dict or None
        experiment name -> value added to the data. ``None`` skips this step.
    window_size : int
        Size of the moving window used for smoothing.
    save_dict : bool
        If True, pickle ``dataset_dict`` to ``<save_dir>/all_datasets_dict.pkl``. Default False -
        it is probably best to test things out first; you can always resave outside the function.
    max_gap : int
        Forwarded to ``interpolate_na``. NOTE(review): xarray measures a gap as the index distance
        between the valid points on either side, so one missing frame is a gap of 2 and the default
        ``max_gap=1`` fills nothing. Pass 2 to fill single-frame gaps.
    save_dir : str or None
        Directory to save into; defaults to the notebook-level ``output_dir``.

    Returns
    -------
    dict
        The same ``dataset_dict`` object, for chaining.
    """
    for unique_id, ds in dataset_dict.items():
        print(f"Running analysis on lineage: {unique_id}")

        data = ds[data_var_to_modify]

        experiment_name = ds.coords['experiment'].item()
        print(f"Experiment name for {unique_id}: '{experiment_name}'")
        if fluor_background_dict is not None:
            background_value_to_add = fluor_background_dict.get(experiment_name, 0)
            print(f"{unique_id} background = {background_value_to_add}")
            data = data + background_value_to_add

        if interpolate:
            # fill_value=None: no extrapolation beyond the first/last valid points
            data = data.interpolate_na(dim='time', method='linear', fill_value=None,
                                       max_gap=max_gap, use_coordinate=False)
        if smooth:
            data = data.rolling(time=window_size, center=False, min_periods=1).mean()

        ds[new_var_name] = data
        if log_trans:
            ds[new_var_name + '_transformed'] = xr.apply_ufunc(np.log1p, ds[new_var_name])

    if save_dict:
        # Save the dictionary that was actually processed (previously the global ds_all_dict
        # was pickled regardless of what was passed in).
        target_dir = output_dir if save_dir is None else save_dir
        with open(os.path.join(target_dir, 'all_datasets_dict.pkl'), 'wb') as f:
            pickle.dump(dataset_dict, f)

    return dataset_dict

In [None]:
def pixeltoum_conversion(dataset_dict, metadata_df, data_var_to_modify, new_var_name, pixel_conversion, save_dict=False, save_dir=None):
    """
    Convert a size-based metric from pixels to um for every dataset (in place).

    Parameters
    ----------
    dataset_dict : dict
        unique lineage ID -> dataset; ``ds[new_var_name] = ds[data_var_to_modify] * pixel_conversion``
        is written into each one.
    metadata_df : pandas.DataFrame
        Not used here; kept so existing calls keep working.
    data_var_to_modify : str
        Variable measured in pixels (e.g. 'feret_diameter').
    new_var_name : str
        Name of the converted variable.
    pixel_conversion : float
        Multiplier from pixels to um.
    save_dict : bool
        If True, pickle ``dataset_dict`` to ``<save_dir>/all_datasets_dict.pkl``.
    save_dir : str or None
        Directory to save into; defaults to the notebook-level ``output_dir``.

    Returns
    -------
    dict
        The same ``dataset_dict`` object.
    """
    for unique_id, ds in dataset_dict.items():
        print(f"Running analysis on lineage: {unique_id}")
        ds[new_var_name] = ds[data_var_to_modify] * pixel_conversion

    if save_dict:
        # Save the dictionary that was actually converted (previously the global ds_all_dict).
        target_dir = output_dir if save_dir is None else save_dir
        with open(os.path.join(target_dir, 'all_datasets_dict.pkl'), 'wb') as f:
            pickle.dump(dataset_dict, f)

    return dataset_dict

In [None]:
pixeltoum_conversion(
    ds_all_dict,
    all_metadata_df,
    data_var_to_modify='feret_diameter',
    new_var_name='feret_diameter_um',
    pixel_conversion=pixel_conversion,
)

In [None]:
# GFP channel: default processing with the GFP background values, plus a log1p-transformed copy
add_processed_data_to_dataset_dict(
    ds_all_dict,
    all_metadata_df,
    'GFP_median_intensity',
    'GFP_median_intensity_processed',
    fluor_background_dict=gfp_background_dict,
    log_trans=True,
    window_size=3,
    save_dict=False,
)

In [None]:
# TRITC (ruby) channel: default processing with the ruby background values, plus a log1p-transformed copy
add_processed_data_to_dataset_dict(
    ds_all_dict,
    all_metadata_df,
    'TRITC_median_intensity',
    'TRITC_median_intensity_processed',
    fluor_background_dict=ruby_background_dict,
    log_trans=True,
    window_size=3,
    save_dict=False,
)

In [None]:
def extract_mother_variables(dataset_dict, metadata, timepoints, variables=None):
    """ 
    Build a long-format DataFrame with the mother cell's time series for every lineage.

    Parameters
    ----------
    dataset_dict : dict
        unique lineage ID -> xarray Dataset for each position.
    metadata : pandas.DataFrame
        Metadata for each lineage; looked up via its 'unique_ID' column.
    timepoints : array-like
        Time points of the traces. Since they are the same for every lineage, it makes more sense
        to supply them than to extract them each time.
    variables : list of str or None
        Data variables to extract, e.g. ['GFP_median_intensity', 'Area']. None extracts none.

    Returns
    -------
    pandas.DataFrame
        One row per (lineage, time point) with columns unique_ID, experiment, position,
        mother_cell, time and one column per variable. Empty (with those columns) when no
        mother cell was found.
    """
    variables = [] if variables is None else list(variables)
    base_columns = ["unique_ID", "experiment", "position", "mother_cell", "time"]

    all_mothers = []
    for unique_id, ds in dataset_dict.items():
        print(f"Running analysis on lineage: {unique_id}")
        # find the matching metadata row by ID rather than by splitting the key name
        ds_metadata = metadata.loc[metadata["unique_ID"] == unique_id].iloc[0]

        # mother cell: based on trackmate it should always be the cell that has parent = 0
        mother_cell = ds.parents.where(ds.parents == 0, drop=True).cell.values
        if len(mother_cell) == 0:
            # shouldn't happen; (the old f-string had a stray comma and printed a tuple)
            print(f"No mother cell found in {ds_metadata.experiment}, {ds_metadata.position}.")
            continue
        mother_cell_idx = mother_cell[0]

        mc_data_dict = {
            "unique_ID": ds_metadata.unique_ID,
            "experiment": ds_metadata.experiment,
            "position": ds_metadata.position,
            "mother_cell": mother_cell_idx,
            "time": timepoints
        }
        for var in variables:
            # squeeze to a 1D time series
            mc_data_dict[var] = np.squeeze(ds[var].sel(cell=mother_cell_idx).values)

        all_mothers.append(pd.DataFrame(mc_data_dict))

    if not all_mothers:
        # pd.concat([]) raises ValueError; return an empty frame with the expected columns instead
        return pd.DataFrame(columns=base_columns + variables)
    return pd.concat(all_mothers, ignore_index=True)

In [None]:
# use this to select out data on mother cell
# every lineage shares the same time axis, so read it from the first dataset in ds_all_dict
first_key = next(iter(ds_all_dict))
first_dataset = ds_all_dict[first_key]
timepoints = first_dataset['time'].values

variables_to_extract = [
    "GFP_median_intensity_processed_transformed",
    'TRITC_median_intensity_processed_transformed',
    'feret_diameter_um',
]
mothers_df = extract_mother_variables(ds_all_dict, all_metadata_df, timepoints, variables=variables_to_extract)
mothers_df.to_csv(os.path.join(plot_output, 'mothers_df.csv'))

In [None]:
def slope_calculations(dataset_dict, metadata, data_var_threshold, slope_threshold=5, duration_thresh=5, data_var=None, lifetime_min=72):
    """ 
    ##Final slope calculations for reporter
    Take a given dataset variable and identify, per cell, when it starts and stops increasing and
    when it crosses the positivity threshold. Keep in mind whether you use log transformed data or not.

    - data_var_threshold is where you are considering something 'positive' (ie the GFP+ thresh)
    - slope_threshold is the minimum frame-to-frame slope it must cross to be considered an increase
    - duration_thresh is how many consecutive frames the slope must stay below slope_threshold
      to be considered the end of an increase
    - data_var is the name of the data variable to analyse (required)
    - lifetime_min is just a value for categorization at the end - ie you don't want to call something
      non responsive if it didn't exist for this amount of time (counted in frames with data).

    Slopes come from ``diff(dim='time')``, which labels each difference with the later of its two
    frames; start_inc, end_inc and time_of_max_slope all use that labelling.

    Returns one row per cell with the detection results plus the derived columns switch_like,
    becomes_positive, duration, magnitude_inc, category and long_enough_time.

    Raises
    ------
    ValueError
        If data_var is not given.
    """
    if data_var is None:
        raise ValueError("data_var must name the data variable to analyse")

    all_slope_data = []
    for unique_id, ds in dataset_dict.items():
        print(f"Running analysis on lineage: {unique_id}")
        ds_metadata = metadata.loc[metadata["unique_ID"] == unique_id].iloc[0]

        slope = ds[data_var].diff(dim='time')
        # missing slopes (cell not present) must never count as "above threshold"
        slope_clean = slope.fillna(-np.inf)
        above = slope_clean > slope_threshold

        cell_records = []

        for cell in ds.cell.values:
            cell_data = ds[data_var].sel(cell=cell)
            parent = ds['parents'].sel(cell=cell).values.item()
            lifetime = np.count_nonzero(~np.isnan(cell_data))

            # threshold crossing ("GFP on") and peak statistics
            positive = cell_data > data_var_threshold
            crosses_gfp_th = positive.any().item()
            gfp_idx = positive.idxmax(dim='time').values if crosses_gfp_th else np.nan
            max_var = cell_data.max(dim='time').values
            max_var_time = cell_data.idxmax(dim='time').values
            post_max_data = cell_data.sel(time=slice(max_var_time, None))
            min_post_max = post_max_data.min(dim='time').values
            min_time_after_max = post_max_data.idxmin(dim='time').values

            # first frame with data and the cell's y position there
            first_time = cell_data.notnull().idxmax(dim="time").values
            appearance_position_value = ds['centroid_y'].sel(time=first_time, cell=cell).values

            # start of increase: first frame whose slope exceeds slope_threshold
            above_cell = above.sel(cell=cell)
            crosses_slope_th = above_cell.any().item()
            start_idx = above_cell.idxmax(dim='time').values if crosses_slope_th else np.nan
            if not np.isnan(start_idx):
                start_idx_var_value = cell_data.sel(time=start_idx).values
                start_idx_position_value = ds['centroid_y'].sel(time=start_idx, cell=cell).values
            else: 
                start_idx_var_value = np.nan
                start_idx_position_value = np.nan

            if not np.isnan(gfp_idx) and gfp_idx != first_time:
                # becomes positive during its lifetime -> look for the end of the increase
                below = slope.sel(cell=cell) < slope_threshold
                # trailing window: True on the last frame of a run of duration_thresh low slopes
                rolling_sum = below.rolling(time=duration_thresh).sum()
                sustained_below = rolling_sum == duration_thresh
                end_candidates = sustained_below & (ds['time'] >= start_idx)
                # Only report an end if such a run exists at/after the start. Previously idxmax on
                # an all-False run returned the first frame at/after start_inc, i.e. a fake
                # zero-length increase.
                if crosses_slope_th and bool(end_candidates.any()):
                    end_idx = end_candidates.idxmax(dim='time').values
                    end_idx_var_value = cell_data.sel(time=end_idx).values
                else:
                    end_idx = np.nan
                    end_idx_var_value = np.nan

                max_slope = slope_clean.sel(cell=cell).max(dim='time').values
                # idxmax returns the slope's own time label (same labelling as start_inc); the old
                # ds['time'].isel(time=argmax) was one frame early because diff drops the first frame
                time_of_max_slope = slope_clean.sel(cell=cell).idxmax(dim='time').values
                born_on = False

            elif not np.isnan(gfp_idx) and gfp_idx == first_time:
                # already positive when it first appears
                start_idx = np.nan
                start_idx_var_value = np.nan
                start_idx_position_value = np.nan
                end_idx = np.nan
                end_idx_var_value = np.nan
                gfp_idx = np.nan
                max_slope = np.nan
                time_of_max_slope = np.nan
                born_on = True

            else: 
                # never crosses the positivity threshold
                start_idx = np.nan
                start_idx_var_value = np.nan
                start_idx_position_value = np.nan
                end_idx = np.nan
                end_idx_var_value = np.nan
                gfp_idx = np.nan
                max_slope = np.nan
                time_of_max_slope = np.nan
                born_on = False

            cell_records.append({
                "unique_ID": ds_metadata.unique_ID,
                "experiment": ds_metadata.experiment,
                "position": ds_metadata.position,
                "track": ds_metadata.track,
                "cell_id": cell,
                "parent": parent,
                "max_slope": max_slope,
                "appearance_time" : first_time,
                "time_of_max_slope": time_of_max_slope,
                "start_inc": start_idx,
                "start_idx_var_value": start_idx_var_value, 
                "y_pos_start_inc": start_idx_position_value,
                "y_pos_appearance": appearance_position_value,
                "gfp_on": gfp_idx,
                "end_inc": end_idx,
                "end_idx_var_value": end_idx_var_value, 
                "born_on": born_on,
                "lifetime": lifetime, 
                "slope_th": slope_threshold, 
                "max_gfp": max_var, 
                "max_gfp_time": max_var_time,
                "min_post_max": min_post_max,
                "min_time_post_max": min_time_after_max
            })

        all_slope_data.append(pd.DataFrame(cell_records))

    df_slope_all = pd.concat(all_slope_data, ignore_index=True)
    # switch-like: the threshold crossing happens inside the detected increase
    df_slope_all['switch_like'] = (df_slope_all['start_inc'] < df_slope_all['gfp_on']) & (df_slope_all['gfp_on'] < df_slope_all['end_inc']) 
    df_slope_all['becomes_positive'] = ((df_slope_all['born_on'] == False) & (df_slope_all['gfp_on'].notna()))
    df_slope_all["duration"] = df_slope_all["end_inc"] - df_slope_all["start_inc"]
    df_slope_all['magnitude_inc'] = df_slope_all['end_idx_var_value'] - df_slope_all['start_idx_var_value']
    
    df_slope_all.loc[df_slope_all['born_on'] == True, 'category'] = None
    df_slope_all.loc[((df_slope_all['switch_like'] == True) & (df_slope_all['lifetime'] > lifetime_min)), 'category'] = 'switch'
    df_slope_all.loc[((df_slope_all['switch_like'] == False) & (df_slope_all['becomes_positive'] == True) & (df_slope_all['lifetime'] > lifetime_min)), 'category'] = 'gradual_on'
    df_slope_all.loc[((df_slope_all['switch_like'] == False) & (df_slope_all['becomes_positive'] == False) & (df_slope_all['born_on'] == False) & (df_slope_all['lifetime'] > lifetime_min)), 'category'] = 'failed_response'
    df_slope_all.loc[(df_slope_all['lifetime'] > lifetime_min), 'long_enough_time'] = True
    df_slope_all.loc[(df_slope_all['lifetime'] <= lifetime_min), 'long_enough_time'] = False

    return df_slope_all

In [None]:
# categorise every cell using a 0.04 slope threshold on the log1p-transformed GFP trace
slope_df_04 = slope_calculations(
    ds_all_dict,
    all_metadata_df,
    log_gfp_th,
    slope_threshold=0.04,
    duration_thresh=5,
    data_var='GFP_median_intensity_processed_transformed',
    lifetime_min=72,
)

In [None]:
## Code to add the parent start increase time to the slope df

# (unique_ID, cell_id) -> start_inc
lookup = slope_df_04.set_index(['unique_ID', 'cell_id'])['start_inc']

missing_keys = []


def safe_parent_lookup(row):
    """Return the parent's start_inc for this row; record the key and return None if absent."""
    parent_key = (row['unique_ID'], row['parent'])  # assumes 'parent' holds the parent's cell_id
    parent_start = lookup.get(parent_key)
    if parent_start is None:
        missing_keys.append(parent_key)
    return parent_start


slope_df_04['parent_start_inc'] = slope_df_04.apply(safe_parent_lookup, axis=1)

unique_missing = set(missing_keys)
print(f"Missing {len(unique_missing)} parent lookups (e.g., parent not found):")
print(unique_missing)
slope_df_04.to_csv(os.path.join(plot_output, 'slope_df_final.csv'))

In [None]:
def constitutive_slope_calculations(dataset_dict, metadata, data_var=None):
    """ 
    Final ruby summary stat calculations.

    For every cell of every lineage: lifetime (frames with data), mean and max of ``data_var``
    (with the time of the max), and the most negative frame-to-frame slope with its time.
    Note there's no filtering here - assumes you're filtering/categorizing on the basis of the GFP
    dataframe (ie merging the two eventually and dropping any cells not around for long enough etc).

    Parameters
    ----------
    dataset_dict : dict
        unique lineage ID -> xarray Dataset.
    metadata : pandas.DataFrame
        Metadata looked up via its 'unique_ID' column.
    data_var : str
        Data variable to summarise (required).

    Returns
    -------
    pandas.DataFrame
        One row per cell.

    Raises
    ------
    ValueError
        If data_var is not given.
    """
    if data_var is None:
        raise ValueError("data_var must name the data variable to analyse")

    all_slope_data = []

    for unique_id, ds in dataset_dict.items():
        print(f"Running analysis on lineage: {unique_id}")
        ds_metadata = metadata.loc[metadata["unique_ID"] == unique_id].iloc[0]

        slope = ds[data_var].diff(dim='time')
        cell_records = []
        
        for cell in ds.cell.values:
            cell_data = ds[data_var].sel(cell=cell)
            lifetime = np.count_nonzero(~np.isnan(cell_data))
            max_var = cell_data.max(dim='time').values
            mean_var = cell_data.mean(dim='time').values
            max_var_time = cell_data.idxmax(dim='time').values
        
            valid_slope = slope.sel(cell=cell).dropna(dim='time')
        
            if valid_slope.size > 0:
                min_slope = valid_slope.min(dim='time').values
                # idxmin returns the time label directly. diff() labels each slope with the later
                # frame (what the old "+1" aimed for). The old isel(argmin + 1) used a position in
                # the NaN-dropped array against the full time axis, which was wrong for any cell
                # that appears after the first frame.
                time_of_min_slope = valid_slope.idxmin(dim='time').values
            else:
                min_slope = np.nan
                time_of_min_slope = np.nan

            cell_records.append({
                "unique_ID": ds_metadata.unique_ID,
                "experiment": ds_metadata.experiment,
                "position": ds_metadata.position,
                "track": ds_metadata.track,
                "cell_id": cell,
                "min_slope_ruby": min_slope,
                "time_of_min_slope_ruby": time_of_min_slope,
                "lifetime": lifetime, 
                "mean_ruby_lifetime" : mean_var,
                "max_ruby": max_var, 
                "max_ruby_time": max_var_time,
            })

        all_slope_data.append(pd.DataFrame(cell_records))

    df_slope_all = pd.concat(all_slope_data, ignore_index=True)
    
    return df_slope_all

In [None]:
# per-cell summary features of the constitutive (TRITC / ruby) channel
ruby_feature_df = constitutive_slope_calculations(
    ds_all_dict,
    all_metadata_df,
    data_var='TRITC_median_intensity_processed_transformed',
)
ruby_feature_df.to_csv(os.path.join(plot_output, 'ruby_features_final.csv'))

In [None]:
def make_df_gen(dataset_dict, metadata, length_variable, frame_interval=5):
    """ 
    Build a per-generation (cell cycle) table of lengths for every cell.

    A cell's cycles are delimited by the frames where its 'daughters' value is non-zero and
    non-NaN. Cycle 0 starts at the cell's first frame with length data. Each cycle ends one frame
    (``frame_interval``) before the next one starts, and the last cycle ends at the cell's last
    frame with data. Lengths are read at the frame nearest to each start/end time.

    Parameters
    ----------
    dataset_dict : dict
        unique lineage ID -> xarray Dataset for each position.
    metadata : pandas.DataFrame
        Metadata for each lineage, looked up via its 'unique_ID' column.
    length_variable : str
        Data variable holding cell length (e.g. 'feret_diameter_um').
    frame_interval : int or float
        Spacing between frames in time units (5 min in these experiments); change it if you
        imaged at a different interval.

    Returns
    -------
    pandas.DataFrame
        One row per (cell, cycle). Cells whose length is all NaN are skipped.
    """
    all_gen_data = []
    for unique_id, ds in dataset_dict.items():
        print(f"Running analysis on lineage: {unique_id}")
        # find the matching metadata row by ID rather than by splitting the key name
        ds_metadata = metadata.loc[metadata["unique_ID"] == unique_id].iloc[0]

        cell_records = []
        for cell_id in ds.cell.values:
            length_data = ds[length_variable].sel(cell=cell_id)
            present = length_data.notnull()
            if present.sum().item() == 0:
                print(f"Skipping cell {cell_id} — all NaN")
                continue

            # frames where daughters appear mark the divisions
            daughter_data = ds['daughters'].sel(cell=cell_id)
            parent = ds['parents'].sel(cell=cell_id).values.item()
            daughter_values = daughter_data.values
            division_indices = np.where((daughter_values != 0) & (~np.isnan(daughter_values)))[0]
            daughter_appearance_times = daughter_data['time'].values[division_indices]

            first_time = present.idxmax(dim="time").values.item()
            last_time = present[::-1].idxmax(dim="time").values.item()

            # cycle start times as coordinate values, not indices
            cycle_times = [first_time] + list(daughter_appearance_times)
            length_times = length_data['time'].values

            for cycle, start_time in enumerate(cycle_times):
                if cycle < len(cycle_times) - 1:
                    # one frame before the next cycle starts
                    end_time = cycle_times[cycle + 1] - frame_interval
                else:
                    end_time = last_time

                start_index = np.argmin(np.abs(length_times - start_time))
                end_index = np.argmin(np.abs(length_times - end_time))
                start_length = length_data.isel(time=start_index).values.item()
                end_length = length_data.isel(time=end_index).values.item()

                cell_records.append({
                    "unique_ID": ds_metadata.unique_ID,
                    "experiment": ds_metadata.experiment,
                    "channel_width": ds_metadata.channel_width,
                    "last_valid_cell_time": last_time,
                    "first_valid_cell_time": first_time,
                    "position": ds_metadata.position,
                    "track": ds_metadata.track,
                    "cell_id": cell_id,
                    "parent":parent,
                    "cycle": cycle,
                    "start_time": start_time,
                    "start_length": start_length,
                    "end_time": end_time,
                    "end_length": end_length
                })

        all_gen_data.append(pd.DataFrame(cell_records))

    df_all_gen = pd.concat(all_gen_data, ignore_index=True)
    return df_all_gen

In [None]:
# one row per cell cycle, using the um-converted feret diameter as cell length
df_gen = make_df_gen(ds_all_dict, all_metadata_df, length_variable='feret_diameter_um')


In [None]:
# per-cycle growth summaries
df_gen['cycle_duration'] = df_gen['end_time'] - df_gen['start_time']
df_gen['total_growth'] = df_gen['end_length'] - df_gen['start_length']
# A zero-length cycle (e.g. divisions in consecutive frames, so end_time == start_time) used to give
# +/-inf here; an elongation rate is undefined for it, so report NaN instead.
df_gen['avg_elong_rate'] = df_gen['total_growth'] / df_gen['cycle_duration'].replace(0, np.nan)
df_gen.to_csv(plot_output + '/df_gen.csv')

In [None]:
# Compare how many cells fall in each category across a range of slope thresholds.
# NOTE(review): the slope_df_005 ... slope_df_11 frames this cell listed were never created in this
# notebook (only slope_df_04 is), so it raised NameError on a fresh kernel. They are built here
# with the same settings as slope_df_04. The thresholds are inferred from the old variable names
# (slope_df_005 -> 0.005, slope_df_01 -> 0.01, ...) - TODO confirm.
comparison_slope_thresholds = [0.005, 0.01, 0.03, 0.04, 0.05, 0.07, 0.09, 0.11]
dfs = [
    slope_df_04 if slope_th == 0.04 else slope_calculations(
        ds_all_dict, all_metadata_df, log_gfp_th, slope_threshold=slope_th, duration_thresh=5,
        data_var='GFP_median_intensity_processed_transformed', lifetime_min=72)
    for slope_th in comparison_slope_thresholds
]

combined_df = pd.concat(dfs, ignore_index=True)

grouped = combined_df.groupby(['slope_th', 'category']).size().reset_index(name='count')

# Pivot the table to get categories as columns
pivot_df = grouped.pivot(index='slope_th', columns='category', values='count').fillna(0)

# Plot as stacked bar
ax = pivot_df.plot(kind='bar', stacked=True, figsize=(10, 6), colormap='tab20b')
ax.set_xlabel('slope threshold')
ax.set_ylabel('Count')
ax.legend(title='Category', bbox_to_anchor=(1.05, 1), loc='upper left')
ax.figure.tight_layout()
ax.figure.savefig(os.path.join(plot_output, 'slope_cat_comparison.pdf'), bbox_inches='tight', transparent=True)

In [None]:
# Stacked bars of every observed combination of the boolean flags, per slope threshold
bool_cols = ['born_on', 'becomes_positive', 'long_enough_time', 'switch_like']
combined_df = pd.concat(dfs, ignore_index=True)

# one string label per row, e.g. "False-True-True-False"
combined_df['combo'] = combined_df[bool_cols].astype(str).agg('-'.join, axis=1)

grouped = (
    combined_df
    .groupby(['slope_th', 'combo'])
    .size()
    .reset_index(name='count')
)
pivot_df = grouped.pivot(index='slope_th', columns='combo', values='count').fillna(0)

ax = pivot_df.plot(kind='bar', stacked=True, figsize=(12, 6))
ax.set_xlabel('Slope Threshold')
ax.set_ylabel('Count')
ax.legend(title='Born On, Becomes Positive, Around Long Enough, Switch Like', loc='best', bbox_to_anchor=(1, 1))
ax.figure.savefig(os.path.join(plot_output, 'slope_comp_allbooleans.pdf'), bbox_inches='tight', transparent=True)

In [None]:
# Getting select positions for representative trace plots
switch_keys_to_include = { '20240809__XY10_crop7__Track_11', '20240809__XY04_crop3__Track_3', '20240809__XY14_crop2__Track_7' }
gradual_keys_to_include = { '20240918__XY12_crop9__Track_25', '20241028__XY06_crop5__Track_39', '20241028__XY17_crop1__Track_0' }
non_res_keys_to_include = {'20240918__XY08_crop4__Track_10', '20240918__XY10_crop5__Track_12', '20241028__XY08_crop3__Track_5'}


def _select_lineages(keys):
    """Return the part of ds_all_dict holding only ``keys`` (KeyError if one is missing)."""
    return {key: ds_all_dict[key] for key in keys}


switch_dict = _select_lineages(switch_keys_to_include)
gradual_dict = _select_lineages(gradual_keys_to_include)
non_res_dict = _select_lineages(non_res_keys_to_include)
sample_dicts = [switch_dict, gradual_dict, non_res_dict]

In [None]:
from matplotlib import colormaps as cm
from matplotlib.ticker import MultipleLocator


# Layout and figure config
rows, cols = 1, 3
figsize = (15, 3)
trace_var = "GFP_median_intensity_processed_transformed"
tick_interval = 120  # x-tick spacing in the units of ds.time (labelled as minutes)

# Determine global y-limits across all sample_dicts (mother cells only) so panels share a scale
ymin, ymax = np.inf, -np.inf
for sample_dict in sample_dicts:
    for ds in sample_dict.values():
        for cell in ds.cell.values:
            parent_value = ds.parents.sel(cell=cell)
            if parent_value == 0:
                y_vals = ds[trace_var].sel(cell=cell).values
                ymin = min(ymin, np.nanmin(y_vals))
                ymax = max(ymax, np.nanmax(y_vals))

# NOTE(review): this used to write to the hard-coded '/Users/Madison/Desktop/0527_test.pdf';
# it now goes next to the other figures in plot_output.
representative_pdf_path = os.path.join(plot_output, '0527_test.pdf')

with PdfPages(representative_pdf_path) as pdf:
    # Figures are created lazily. The old code opened a new figure right after saving a full page,
    # and with exactly rows*cols samples that empty figure was never saved or closed.
    fig, axes = None, None
    plot_idx = 0

    for sample_dict in sample_dicts:
        if fig is None:
            fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
            axes = axes.flatten()
        ax = axes[plot_idx]

        # Assign a unique color per unique_id
        unique_ids = sorted(sample_dict.keys())
        cmap = plt.colormaps['viridis_r'].resampled(len(unique_ids))
        color_map = {uid: cmap(i) for i, uid in enumerate(unique_ids)}

        for unique_id in unique_ids:
            ds = sample_dict[unique_id]
            for cell in ds.cell.values:
                parent_value = ds.parents.sel(cell=cell)
                if parent_value == 0:
                    var_values = ds[trace_var].sel(cell=cell)
                    ax.plot(ds.time, var_values, color=color_map[unique_id], alpha=0.8)

        # Axes formatting
        ax.set_title(f"Sample {plot_idx + 1}", fontsize=8)
        ax.set_xlabel("Time (mins)", fontsize=6)
        ax.set_ylabel(trace_var, fontsize=6)
        ax.tick_params(labelsize=6)
        ax.set_ylim([ymin, ymax])
        time_values = np.asarray(ds.time.values).astype(float)  # assumes consistent time across samples
        ax.set_xticks(np.arange(time_values.min(), time_values.max() + tick_interval, tick_interval))
        plot_idx += 1

        if plot_idx == rows * cols:
            fig.tight_layout()
            pdf.savefig(fig)
            plt.close(fig)
            fig, axes = None, None
            plot_idx = 0

    # Save a partially filled last page, removing its unused panels
    if fig is not None:
        for i in range(plot_idx, rows * cols):
            fig.delaxes(axes[i])
        fig.tight_layout()
        pdf.savefig(fig)
        plt.close(fig)

In [None]:
def _save_pdf_page(pdf, fig, axes, n_used):
    """Remove the unused panels of ``fig``, write it as one page of ``pdf`` and close it."""
    for unused_ax in axes[n_used:]:
        fig.delaxes(unused_ax)
    fig.tight_layout()
    pdf.savefig(fig)
    plt.close(fig)


def plot_lineages_to_pdf_with_switch_detection(dataset_dict, metadata, firstcross_df, var_to_plot="GFP_median_intensity_processed_transformed",  pdf_filename="plots.pdf",  plots_per_page=16):
    """
    Plot every lineage into a multi-page PDF with the slope-detection results overlaid.

    Each panel shows all cells of one lineage: the cell with parent == 0 in green, others in grey.
    For each cell with a row in ``firstcross_df``, its 'start_inc', 'gfp_on' and 'end_inc' times are
    marked (circles for cell 1, crosses for the other cells). Useful for checking the output of
    slope_calculations on every lineage.

    Parameters
    ----------
    dataset_dict : dict
        unique lineage ID -> xarray Dataset.
    metadata : pandas.DataFrame
        Not used; kept so existing calls keep working.
    firstcross_df : pandas.DataFrame
        Output of slope_calculations (needs unique_ID, cell_id, start_inc, gfp_on, end_inc).
    var_to_plot : str
        Data variable to plot.
    pdf_filename : str
        Output PDF path.
    plots_per_page : int
        Panels per page. The grid is the smallest near-square layout that fits
        (16 -> 4 x 4 on a 16 x 16 inch page, as before).

    Raises
    ------
    ValueError
        If plots_per_page < 1.
    """
    if plots_per_page < 1:
        raise ValueError("plots_per_page must be at least 1")
    cols = int(np.ceil(np.sqrt(plots_per_page)))
    rows = int(np.ceil(plots_per_page / cols))
    figsize = (4 * cols, 4 * rows)

    # (column in firstcross_df, colour, marker, label) for cell 1 and for every other cell
    cell1_markers = [
        ('start_inc', 'forestgreen', 'o', 'Start'),
        ('gfp_on', 'black', '.', 'TH'),
        ('end_inc', 'magenta', 'o', 'End'),
    ]
    other_markers = [
        ('start_inc', 'darkgreen', 'X', 'Start'),
        ('gfp_on', 'gray', '.', 'TH'),
        ('end_inc', 'purple', 'X', 'End'),
    ]

    with PdfPages(pdf_filename) as pdf:
        # Create figures lazily so no empty, unclosed figure is left behind when the number of
        # lineages is an exact multiple of plots_per_page.
        fig, axes = None, None
        plot_idx = 0

        for unique_id, ds in dataset_dict.items():
            if fig is None:
                fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
                axes = axes.flatten()
            ax = axes[plot_idx]
            print(unique_id)
            time_values = ds.time.values
            lineage_switches = firstcross_df[firstcross_df['unique_ID'] == unique_id]

            for cell in ds.cell.values:
                var_values = ds[var_to_plot].sel(cell=cell)
                is_root = bool(ds.parents.sel(cell=cell) == 0)
                ax.plot(ds.time, var_values,
                        color='limegreen' if is_root else 'lightgrey',
                        alpha=1 if is_root else 0.5,
                        linewidth=2 if is_root else 0.75)

                # Highlight switch points (start, threshold crossing, end)
                switch_info = lineage_switches[lineage_switches['cell_id'] == cell]
                if switch_info.empty:
                    # previously the `cell != 1` branch still ran here and raised IndexError on iloc[0]
                    continue
                row = switch_info.iloc[0]
                # NOTE(review): markers key on cell id 1 while the line colour keys on parent == 0;
                # these agree only if the mother is always cell 1 - confirm.
                if cell == 1:
                    markers, scatter_style = cell1_markers, dict(s=50, linewidth=2)
                else:
                    markers, scatter_style = other_markers, dict(alpha=0.5, linewidth=0.75, s=50)
                for switch_col, color, marker, label in markers:
                    switch_t = row[switch_col]
                    if pd.notna(switch_t) and switch_t in time_values:
                        y_val = var_values.sel(time=switch_t)
                        ax.scatter(switch_t, y_val.values, color=color, marker=marker, label=label, **scatter_style)

            ax.set_title(f"{unique_id}", fontsize=8)
            ax.set_xlabel("Time (mins)", fontsize=6)
            ax.set_ylabel(var_to_plot, fontsize=6)
            ax.tick_params(labelsize=6)
            plot_idx += 1

            # If page is full, save and reset
            if plot_idx == plots_per_page:
                _save_pdf_page(pdf, fig, axes, plot_idx)
                fig, axes = None, None
                plot_idx = 0

        # Save a partially filled last page
        if fig is not None:
            _save_pdf_page(pdf, fig, axes, plot_idx)

    print(f"Saved plots to {pdf_filename}")

In [None]:
# one panel per lineage with start / threshold / end markers from slope_df_04
plot_lineages_to_pdf_with_switch_detection(
    ds_all_dict,
    all_metadata_df,
    slope_df_04,
    var_to_plot="GFP_median_intensity_processed_transformed",
    pdf_filename=os.path.join(plot_output, "final_lineages_with_switch.pdf"),
    plots_per_page=16,
)