This is quite the complex data splitting procedure.
The data is split into holdout data, training, validation, and testing.
The training and validation data only contains single-cells that have ground truth labels at the terminal time point.
While testing and holdout data contains cells that do and no not have ground truth labels at the terminal time point.

In [1]:
import itertools
import pathlib

import numpy as np
import pandas as pd
from scipy.spatial.distance import euclidean
from sklearn.model_selection import train_test_split

try:
    cfg = get_ipython().config
    in_notebook = True
except NameError:
    in_notebook = False
if in_notebook:
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm

In [2]:
# save the profiles as cleaned data
cleaned_sc_profile_file_path = pathlib.Path(
    "../results/cleaned_timelapse_profiles.parquet"
).resolve(strict=True)
# save the endpoint profiles as cleaned data
cleaned_endpoint_sc_profile_file_path = pathlib.Path(
    "../results/cleaned_endpoint_profiles.parquet"
).resolve(strict=True)
ground_truth_ids_file_path = pathlib.Path(
    "../results/ground_truth_ids.parquet"
).resolve(strict=True)
sc_profile_df = pd.read_parquet(cleaned_sc_profile_file_path)

endpoint_sc_profile_df = pd.read_parquet(cleaned_endpoint_sc_profile_file_path)
ground_truth_ids = pd.read_parquet(ground_truth_ids_file_path)

# drop array columnms
array_columns = [
    "Metadata_x",
    "Metadata_y",
]
sc_profile_df = sc_profile_df.drop(columns=array_columns)
endpoint_sc_profile_df = endpoint_sc_profile_df.drop(columns=array_columns)

In [3]:
endpoint_sc_profile_df.head()

Unnamed: 0,Metadata_plate,Metadata_Well,Metadata_number_of_singlecells,Metadata_compound,Metadata_dose,Metadata_control,Metadata_ImageNumber,Metadata_FOV,Metadata_Time,Metadata_Cells_Number_Object_Number,...,Nuclei_Texture_Correlation_AnnexinV_3_02_256,Nuclei_Texture_Correlation_AnnexinV_3_03_256,Nuclei_Texture_Correlation_DNA_3_02_256,Nuclei_Texture_DifferenceVariance_AnnexinV_3_01_256,Nuclei_Texture_InverseDifferenceMoment_AnnexinV_3_03_256,Nuclei_Texture_InverseDifferenceMoment_DNA_3_03_256,Nuclei_Texture_SumAverage_AnnexinV_3_00_256,Nuclei_Texture_SumAverage_DNA_3_01_256,Metadata_Well_FOV,Metadata_sc_unique_track_id
0,1,C-09,153,Staurosporine,39.06,positive,1,2,13,1.0,...,-0.832951,-0.666071,0.643979,-1.267717,-1.853163,0.47445,-1.437757,0.347586,C-09_0002,C-09_0002_18
1,1,C-09,153,Staurosporine,39.06,positive,1,2,13,3.0,...,-0.471301,-0.589558,-1.321561,0.652602,0.668286,0.260828,0.034058,0.566404,C-09_0002,C-09_0002_32
2,1,C-09,153,Staurosporine,39.06,positive,1,2,13,6.0,...,-0.832951,-0.666071,0.643979,-1.267717,-1.853163,0.47445,-1.437757,0.347586,C-09_0002,C-09_0002_52
3,1,C-09,153,Staurosporine,39.06,positive,1,2,13,8.0,...,0.222415,-0.872616,-1.172975,-0.369571,0.427535,0.220951,0.193142,0.647287,C-09_0002,C-09_0002_89
4,1,C-09,153,Staurosporine,39.06,positive,1,2,13,9.0,...,-0.172902,-0.928359,-0.854807,0.018437,0.463694,0.208027,0.118983,0.660762,C-09_0002,C-09_0002_162


In [4]:
sc_profile_df.head()

Unnamed: 0,Metadata_plate,Metadata_Well,Metadata_number_of_singlecells,Metadata_compound,Metadata_dose,Metadata_control,Metadata_ImageNumber,Metadata_FOV,Metadata_Time,Metadata_Cells_Number_Object_Number,...,channel_DNA_cls_feature_98_scDINO,channel_DNA_cls_feature_99_scDINO,channel_DNA_cls_feature_9_scDINO,Metadata_Image_FileName_CL_488_1_crop,Metadata_Image_FileName_CL_488_2_crop,Metadata_Image_FileName_CL_561_crop,Metadata_Image_FileName_DNA_crop,Metadata_parent_path,Metadata_sc_unique_track_id,Metadata_Well_FOV
0,1,C-09,168,Staurosporine,39.06,positive,1,2,0,7,...,-0.12352,2.401852,1.516202,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,C-09_0002_5,C-09_0002
1,1,C-09,168,Staurosporine,39.06,positive,1,2,0,9,...,-0.835988,-0.264486,0.153676,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,C-09_0002_6,C-09_0002
2,1,C-09,168,Staurosporine,39.06,positive,1,2,0,10,...,-0.359857,0.659583,0.537619,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,C-09_0002_7,C-09_0002
3,1,C-09,168,Staurosporine,39.06,positive,1,2,0,11,...,0.211796,0.443178,1.129714,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,C-09_0002_8,C-09_0002
4,1,C-09,168,Staurosporine,39.06,positive,1,2,0,12,...,-1.694061,-0.22899,0.648714,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,/home/lippincm/4TB_A/live_cell_timelapse_apopt...,C-09_0002_9,C-09_0002


In [5]:
# there are 10 doses, with three wells each
# one well is needed for each dose for training
# select one well per dose
test_wells = []
for dose in sc_profile_df["Metadata_dose"].unique():
    wells = sc_profile_df[sc_profile_df["Metadata_dose"] == dose][
        "Metadata_Well"
    ].tolist()
    selected_well = np.random.choice(wells, 1)[0]
    print(f"Selected well {selected_well} for dose {dose}")
    test_wells.append(str(selected_well))

train_wells = sc_profile_df[~sc_profile_df["Metadata_Well"].isin(test_wells)][
    "Metadata_Well"
].tolist()

Selected well E-09 for dose 39.06
Selected well E-06 for dose 4.88
Selected well C-07 for dose 9.77
Selected well E-03 for dose 0.61
Selected well E-05 for dose 2.44
Selected well D-11 for dose 156.25
Selected well E-08 for dose 19.53
Selected well E-04 for dose 1.22
Selected well D-10 for dose 78.13
Selected well E-02 for dose 0.0


In [6]:
train_test_wells = pd.DataFrame(
    {
        "data_split": ["train"] * len(np.unique(train_wells))
        + ["test"] * len(np.unique(test_wells)),
        "Metadata_Well": np.unique(train_wells).tolist()
        + np.unique(test_wells).tolist(),
    }
)

In [7]:
# drop NaN rows
before_shape = sc_profile_df.shape
print(f"sc_profile shape before dropping NaNs: {before_shape}")
sc_profile_df = sc_profile_df.dropna()
print(f"sc_profile shape after dropping NaNs: {sc_profile_df.shape}")
print(f"Dropped {before_shape[0] - sc_profile_df.shape[0]} rows with NaNs")
# same for endpoint profile
sc_endpoint_profile_before_shape = endpoint_sc_profile_df.shape
print(
    f"sc_endpoint_profile shape before dropping NaNs: {sc_endpoint_profile_before_shape}"
)
endpoint_sc_profile_df = endpoint_sc_profile_df.dropna()
print(f"sc_endpoint_profile shape after dropping NaNs: {endpoint_sc_profile_df.shape}")
print(
    f"Dropped {sc_endpoint_profile_before_shape[0] - endpoint_sc_profile_df.shape[0]} rows with NaNs"
)

sc_profile shape before dropping NaNs: (188065, 2378)
sc_profile shape after dropping NaNs: (185502, 2378)
Dropped 2563 rows with NaNs
sc_endpoint_profile shape before dropping NaNs: (9918, 543)
sc_endpoint_profile shape after dropping NaNs: (9364, 543)
Dropped 554 rows with NaNs


In [8]:
endpoint_sc_profile_df.drop(
    columns=["Metadata_coordinates_x", "Metadata_coordinates_y"], inplace=True
)

In [9]:
endpoint_sc_profile_df.drop_duplicates(
    subset=["Metadata_sc_unique_track_id"], inplace=True
)

In [10]:
# if the track id is not in the sc_profile_df, drop it from the endpoint profile
before_shape = endpoint_sc_profile_df.shape
endpoint_sc_profile_df = endpoint_sc_profile_df[
    endpoint_sc_profile_df["Metadata_sc_unique_track_id"].isin(
        sc_profile_df["Metadata_sc_unique_track_id"].unique()
    )
]
# if the track id is not in the endpoint profile, drop it from the sc_profile_df
sc_profile_before_shape = sc_profile_df.shape
sc_profile_df = sc_profile_df[
    sc_profile_df["Metadata_sc_unique_track_id"].isin(
        endpoint_sc_profile_df["Metadata_sc_unique_track_id"].unique()
    )
]
print(
    f"Dropped {before_shape[0] - endpoint_sc_profile_df.shape[0]} rows from endpoint profile that were not in sc_profile"
)
print(
    f"Dropped {sc_profile_before_shape[0] - sc_profile_df.shape[0]} rows from sc_profile that were not in endpoint profile"
)

Dropped 440 rows from endpoint profile that were not in sc_profile
Dropped 109069 rows from sc_profile that were not in endpoint profile


In [11]:
# add a ground truth column to the sc_profile dataframe based on if the track id is in the endpoint profile
sc_profile_df["Metadata_ground_truth_present"] = (
    sc_profile_df["Metadata_sc_unique_track_id"]
    .isin(ground_truth_ids["Metadata_sc_unique_track_id"])
    .astype(bool)
)

In [12]:
sc_profile_df["Metadata_ground_truth_present"].unique()

array([False,  True])

In [13]:
sc_profile_df_final_timepoint = sc_profile_df.loc[
    sc_profile_df["Metadata_Time"] == sc_profile_df["Metadata_Time"].max()
]

### Cells that have a single cell ground truth

In [14]:
sc_profile_df_gt = sc_profile_df_final_timepoint[
    sc_profile_df_final_timepoint["Metadata_ground_truth_present"] == True
].copy()
sc_profile_df_no_gt = sc_profile_df[
    sc_profile_df["Metadata_ground_truth_present"] == False
].copy()

In [15]:
train_sc_w_ground_truth_df = sc_profile_df_gt[
    sc_profile_df_gt["Metadata_Well"].isin(train_wells)
].copy()
test_sc_w_ground_truth_df = sc_profile_df_gt[
    sc_profile_df_gt["Metadata_Well"].isin(test_wells)
].copy()
print(f"train_sc_w_ground_truth_df shape: {train_sc_w_ground_truth_df.shape}")
print(f"test_sc_w_ground_truth_df shape: {test_sc_w_ground_truth_df.shape}")

train_sc_w_ground_truth_df shape: (4994, 2379)
test_sc_w_ground_truth_df shape: (2216, 2379)


In [16]:
# undersample training set and move excess cells to testing set
min_cells_per_dose = train_sc_w_ground_truth_df["Metadata_dose"].value_counts().min()

# Get the sampled training data and the remaining cells
train_sampled_list = []
train_excess_list = []

for dose in train_sc_w_ground_truth_df["Metadata_dose"].unique():
    dose_data = train_sc_w_ground_truth_df[
        train_sc_w_ground_truth_df["Metadata_dose"] == dose
    ]
    sampled = dose_data.sample(min_cells_per_dose, random_state=0)
    excess = dose_data.drop(sampled.index)

    train_sampled_list.append(sampled)
    train_excess_list.append(excess)

# Combine the results
train_sc_w_ground_truth_df = pd.concat(train_sampled_list, ignore_index=True)
train_excess_df = pd.concat(train_excess_list, ignore_index=True)

# Add the excess training cells to the test set
test_sc_w_ground_truth_df = pd.concat(
    [test_sc_w_ground_truth_df, train_excess_df], ignore_index=True
)

train_sc_w_ground_truth_df["Metadata_data_split"] = "train"
test_sc_w_ground_truth_df["Metadata_data_split"] = "test"
sc_profile_df_no_gt["Metadata_data_split"] = "test_no_gt"

In [17]:
output_dict = {
    "sc_id": [],
    "data_split": [],
}
output_dict["sc_id"].append(
    train_sc_w_ground_truth_df["Metadata_sc_unique_track_id"].tolist()
)
output_dict["data_split"].append(
    train_sc_w_ground_truth_df["Metadata_data_split"].tolist()
)
output_dict["sc_id"].append(
    test_sc_w_ground_truth_df["Metadata_sc_unique_track_id"].tolist()
)
output_dict["data_split"].append(
    test_sc_w_ground_truth_df["Metadata_data_split"].tolist()
)
output_dict["sc_id"].append(sc_profile_df_no_gt["Metadata_sc_unique_track_id"].tolist())
output_dict["data_split"].append(sc_profile_df_no_gt["Metadata_data_split"].tolist())
for key in output_dict:
    output_dict[key] = list(itertools.chain.from_iterable(output_dict[key]))

In [18]:
ids_and_splits_df = pd.DataFrame.from_dict(output_dict)

In [19]:
train_df = sc_profile_df[
    sc_profile_df["Metadata_sc_unique_track_id"].isin(
        ids_and_splits_df.loc[ids_and_splits_df["data_split"] == "train", "sc_id"]
    )
].copy()
test_df = sc_profile_df[
    sc_profile_df["Metadata_sc_unique_track_id"].isin(
        ids_and_splits_df.loc[ids_and_splits_df["data_split"] == "test", "sc_id"]
    )
].copy()
test_no_gt_df = sc_profile_df_no_gt[
    sc_profile_df_no_gt["Metadata_sc_unique_track_id"].isin(
        ids_and_splits_df.loc[ids_and_splits_df["data_split"] == "test_no_gt", "sc_id"]
    )
].copy()

train_df_terminal_time = endpoint_sc_profile_df[
    endpoint_sc_profile_df["Metadata_sc_unique_track_id"].isin(
        ids_and_splits_df.loc[ids_and_splits_df["data_split"] == "train", "sc_id"]
    )
].copy()
test_df_terminal_time = endpoint_sc_profile_df[
    endpoint_sc_profile_df["Metadata_sc_unique_track_id"].isin(
        ids_and_splits_df.loc[ids_and_splits_df["data_split"] == "test", "sc_id"]
    )
].copy()
test_no_gt_df_terminal_time = endpoint_sc_profile_df[
    endpoint_sc_profile_df["Metadata_sc_unique_track_id"].isin(
        ids_and_splits_df.loc[ids_and_splits_df["data_split"] == "test_no_gt", "sc_id"]
    )
].copy()


print(f"train_df shape: {train_df.shape}")
print(f"test_df shape: {test_df.shape}")
print(f"test_no_gt_df shape: {test_no_gt_df.shape}")
print(f"train_df_terminal_time shape: {train_df_terminal_time.shape}")
print(f"test_df_terminal_time shape: {test_df_terminal_time.shape}")
print(f"test_no_gt_df_terminal_time shape: {test_no_gt_df_terminal_time.shape}")

print(
    f"Number of unique cells in training set: {train_df['Metadata_sc_unique_track_id'].nunique()}"
)
print(
    f"Number of unique cells in testing set: {test_df['Metadata_sc_unique_track_id'].nunique()}"
)
print(
    f"Number of unique cells in testing set without ground truth: {test_no_gt_df['Metadata_sc_unique_track_id'].nunique()}"
)
print(
    f"Number of unique cells in training set at terminal time: {train_df_terminal_time['Metadata_sc_unique_track_id'].nunique()}"
)
print(
    f"Number of unique cells in testing set at terminal time: {test_df_terminal_time['Metadata_sc_unique_track_id'].nunique()}"
)
print(
    f"Number of unique cells in testing set without ground truth at terminal time: {test_no_gt_df_terminal_time['Metadata_sc_unique_track_id'].nunique()}"
)
print(f"Train data size: {train_df.shape}")
print(f"Test data size: {test_df.shape}")
print(
    f"Train size at last timepoint: {train_df.loc[train_df['Metadata_Time'] == train_df['Metadata_Time'].max()].shape}"
)
print(
    f"Test size at last timepoint: {test_df.loc[test_df['Metadata_Time'] == test_df['Metadata_Time'].max()].shape}"
)

train_df shape: (23720, 2379)
test_df shape: (49544, 2379)
test_no_gt_df shape: (3169, 2380)
train_df_terminal_time shape: (2300, 541)
test_df_terminal_time shape: (4910, 541)
test_no_gt_df_terminal_time shape: (496, 541)
Number of unique cells in training set: 2300
Number of unique cells in testing set: 4910
Number of unique cells in testing set without ground truth: 496
Number of unique cells in training set at terminal time: 2300
Number of unique cells in testing set at terminal time: 4910
Number of unique cells in testing set without ground truth at terminal time: 496
Train data size: (23720, 2379)
Test data size: (49544, 2379)
Train size at last timepoint: (2300, 2379)
Test size at last timepoint: (4910, 2379)


In [20]:
assert (
    train_df.loc[
        train_df["Metadata_Time"] == train_df["Metadata_Time"].max(),
        "Metadata_sc_unique_track_id",
    ].nunique()
    == train_df_terminal_time["Metadata_sc_unique_track_id"].nunique()
)
test_df_terminal_time = test_df_terminal_time.loc[
    test_df_terminal_time["Metadata_sc_unique_track_id"].isin(
        test_df.loc[
            test_df["Metadata_Time"] == test_df["Metadata_Time"].max(),
            "Metadata_sc_unique_track_id",
        ].unique()
    )
]
assert (
    test_df.loc[
        test_df["Metadata_Time"] == test_df["Metadata_Time"].max(),
        "Metadata_sc_unique_track_id",
    ].nunique()
    == test_df_terminal_time["Metadata_sc_unique_track_id"].nunique()
)

In [21]:
# save each split as parquet files
train_file_path = pathlib.Path("../results/train_sc_profile.parquet").resolve()
test_file_path = pathlib.Path("../results/test_sc_profile.parquet").resolve()
test_no_gt_file_path = pathlib.Path(
    "../results/test_no_gt_sc_profile.parquet"
).resolve()
train_terminal_time_file_path = pathlib.Path(
    "../results/train_sc_profile_terminal_time.parquet"
).resolve()
test_terminal_time_file_path = pathlib.Path(
    "../results/test_sc_profile_terminal_time.parquet"
).resolve()
test_no_gt_terminal_time_file_path = pathlib.Path(
    "../results/test_no_gt_sc_profile_terminal_time.parquet"
).resolve()

train_df.to_parquet(train_file_path, index=False)
test_df.to_parquet(test_file_path, index=False)
test_no_gt_df.to_parquet(test_no_gt_file_path, index=False)
train_df_terminal_time.to_parquet(train_terminal_time_file_path, index=False)
test_df_terminal_time.to_parquet(test_terminal_time_file_path, index=False)
test_no_gt_df_terminal_time.to_parquet(test_no_gt_terminal_time_file_path, index=False)