This notebook performs a fairly complex data splitting procedure.
The data is split into holdout, training, validation, and testing sets.
The training and validation data only contain single cells that have ground truth labels at the terminal time point.
The testing and holdout data contain cells that do and cells that do not have ground truth labels at the terminal time point.

In [1]:
import itertools
import pathlib

import numpy as np
import pandas as pd
from scipy.spatial.distance import euclidean
from sklearn.model_selection import train_test_split

# Probe for an IPython session: referencing `get_ipython` raises NameError when
# that name is not defined, which is treated here as "not in a notebook".
try:
    cfg = get_ipython().config
except NameError:
    in_notebook = False
else:
    in_notebook = True

# Import the tqdm flavour that matches the detected environment.
if in_notebook:
    from tqdm.notebook import tqdm
else:
    from tqdm import tqdm

In [2]:
# Resolve the three input parquet files; strict=True raises immediately if a
# path does not exist.
sc_file_path, sc_endpoint_file_path, train_test_wells_file_path = (
    pathlib.Path(relative_path).resolve(strict=True)
    for relative_path in (
        "../results/cleaned_sc_profile.parquet",
        "../results/cleaned_endpoint_sc_profile.parquet",
        "../../5.bulk_timelapse_model/data_splits/train_test_wells.parquet",
    )
)

# Read the single-cell profile, the endpoint single-cell profile, and the
# well-level train/test assignment table.
sc_profile, sc_endpoint_profile, train_test_wells = (
    pd.read_parquet(parquet_path)
    for parquet_path in (
        sc_file_path,
        sc_endpoint_file_path,
        train_test_wells_file_path,
    )
)
print(f"sc_profile shape: {sc_profile.shape}")
print(f"sc_endpoint_profile shape: {sc_endpoint_profile.shape}")

sc_profile shape: (182804, 2376)
sc_endpoint_profile shape: (11340, 368)


In [3]:
# Preview the first rows of the endpoint single-cell profile.
sc_endpoint_profile.head(n=5)

Unnamed: 0,Metadata_plate,Metadata_Well,Metadata_number_of_singlecells,Metadata_compound,Metadata_dose,Metadata_control,Metadata_ImageNumber,Metadata_FOV,Metadata_Time,Metadata_Cells_Number_Object_Number,...,Cells_Texture_Correlation_AnnexinV_3_01_256,Cells_Texture_Correlation_AnnexinV_3_02_256,Cells_Texture_Correlation_AnnexinV_3_03_256,Cells_Texture_Correlation_DNA_3_02_256,Cells_Texture_DifferenceVariance_AnnexinV_3_02_256,Cells_Texture_InverseDifferenceMoment_AnnexinV_3_03_256,Cells_Texture_SumAverage_AnnexinV_3_00_256,Cells_Texture_SumAverage_DNA_3_01_256,Metadata_Well_FOV,Metadata_sc_unique_track_id
0,1,C-09,69,Staurosporine,39.06,positive,1,2,13.0,23,...,-0.105321,0.258335,0.245637,-0.370941,0.063992,0.382514,-0.135229,2.375118,C-09_0002,C-09_0002_129
1,1,C-09,69,Staurosporine,39.06,positive,1,2,13.0,25,...,-0.650287,-0.535347,-0.911682,-0.218577,0.378465,0.509536,-0.265312,3.65782,C-09_0002,C-09_0002_39
2,1,C-09,69,Staurosporine,39.06,positive,1,2,13.0,36,...,-0.153755,0.307227,-0.233445,-0.14766,-0.076404,0.218955,-0.069362,2.324912,C-09_0002,C-09_0002_66
3,1,C-09,69,Staurosporine,39.06,positive,1,2,13.0,38,...,0.270074,0.500432,0.445759,-0.1628,-0.105987,0.30409,-0.0258,2.870185,C-09_0002,C-09_0002_65
4,1,C-09,69,Staurosporine,39.06,positive,1,2,13.0,42,...,-0.312933,0.471374,-0.05758,-1.426125,-0.291901,0.369278,-0.142464,0.237744,C-09_0002,C-09_0002_75


In [4]:
# Preview the first rows of the single-cell profile.
sc_profile.head(n=5)

Unnamed: 0,Metadata_plate,Metadata_Well,Metadata_number_of_singlecells,Metadata_compound,Metadata_dose,Metadata_control,Metadata_ImageNumber,Metadata_FOV,Metadata_Time,Metadata_Cells_Number_Object_Number,...,channel_DNA_cls_feature_94_scDINO,channel_DNA_cls_feature_95_scDINO,channel_DNA_cls_feature_96_scDINO,channel_DNA_cls_feature_97_scDINO,channel_DNA_cls_feature_98_scDINO,channel_DNA_cls_feature_99_scDINO,channel_DNA_cls_feature_9_scDINO,Metadata_sc_unique_track_id,Metadata_Well_FOV,Metadata_sc_unique_track_id_count
0,1,C-02,180,Staurosporine,0.0,negative,1,1,0.0,101,...,0.313944,1.126927,-0.143103,0.241127,-0.293259,-0.283715,1.434163,C-02_0001_17,C-02_0001,13
1,1,C-02,180,Staurosporine,0.0,negative,1,1,0.0,111,...,0.10275,0.845704,0.08393,-1.990931,-0.030848,-1.033722,-0.942127,C-02_0001_18,C-02_0001,13
2,1,C-02,180,Staurosporine,0.0,negative,1,1,0.0,11,...,0.810937,0.30094,-0.22878,1.782329,0.153739,-0.763335,0.725093,C-02_0001_5,C-02_0001,1
3,1,C-02,180,Staurosporine,0.0,negative,1,1,0.0,128,...,-0.711263,0.067196,-0.149771,1.40565,0.063245,2.16211,3.187469,C-02_0001_19,C-02_0001,13
4,1,C-02,180,Staurosporine,0.0,negative,1,1,0.0,132,...,-0.25113,-1.851114,0.669517,-0.439855,1.576201,0.747753,0.895601,C-02_0001_20,C-02_0001,6


In [5]:
# A row counts as having ground truth when its track id also occurs in the
# endpoint profile; store that as a boolean column on sc_profile.
endpoint_track_ids = sc_endpoint_profile["Metadata_sc_unique_track_id"]
track_in_endpoint = sc_profile["Metadata_sc_unique_track_id"].isin(endpoint_track_ids)
sc_profile["Metadata_ground_truth_present"] = track_in_endpoint.astype(bool)
sc_profile

Unnamed: 0,Metadata_plate,Metadata_Well,Metadata_number_of_singlecells,Metadata_compound,Metadata_dose,Metadata_control,Metadata_ImageNumber,Metadata_FOV,Metadata_Time,Metadata_Cells_Number_Object_Number,...,channel_DNA_cls_feature_95_scDINO,channel_DNA_cls_feature_96_scDINO,channel_DNA_cls_feature_97_scDINO,channel_DNA_cls_feature_98_scDINO,channel_DNA_cls_feature_99_scDINO,channel_DNA_cls_feature_9_scDINO,Metadata_sc_unique_track_id,Metadata_Well_FOV,Metadata_sc_unique_track_id_count,Metadata_ground_truth_present
0,1,C-02,180,Staurosporine,0.0,negative,1,0001,0.0,101,...,1.126927,-0.143103,0.241127,-0.293259,-0.283715,1.434163,C-02_0001_17,C-02_0001,13,False
1,1,C-02,180,Staurosporine,0.0,negative,1,0001,0.0,111,...,0.845704,0.083930,-1.990931,-0.030848,-1.033722,-0.942127,C-02_0001_18,C-02_0001,13,False
2,1,C-02,180,Staurosporine,0.0,negative,1,0001,0.0,11,...,0.300940,-0.228780,1.782329,0.153739,-0.763335,0.725093,C-02_0001_5,C-02_0001,1,False
3,1,C-02,180,Staurosporine,0.0,negative,1,0001,0.0,128,...,0.067196,-0.149771,1.405650,0.063245,2.162110,3.187469,C-02_0001_19,C-02_0001,13,False
4,1,C-02,180,Staurosporine,0.0,negative,1,0001,0.0,132,...,-1.851114,0.669517,-0.439855,1.576201,0.747753,0.895601,C-02_0001_20,C-02_0001,6,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
182799,1,E-11,97,Staurosporine,156.25,test,13,0004,12.0,61,...,1.538514,-1.256247,-1.167025,2.543944,0.492066,-0.437865,E-11_0004_80,E-11_0004,11,False
182800,1,E-11,97,Staurosporine,156.25,test,13,0004,12.0,6,...,-0.289426,0.895888,-2.646345,1.653147,-0.763507,-1.567225,E-11_0004_11,E-11_0004,11,False
182801,1,E-11,97,Staurosporine,156.25,test,13,0004,12.0,62,...,2.786090,-0.056308,-0.942568,-0.867108,-0.909181,-0.240636,E-11_0004_147,E-11_0004,3,False
182802,1,E-11,97,Staurosporine,156.25,test,13,0004,12.0,64,...,2.720270,-0.602779,1.213263,-1.253546,-0.766996,2.020763,E-11_0004_86,E-11_0004,9,False


At this point there are two subsets of the dataset, which will be split into the following data splits:
- Single-cells that have a single cell ground truth
    - Holdout wells: 1/3 of wells
    - Train: 80% of all cells with a single cell ground truth from the non-holdout wells
    - Validation: 10% of all cells with a single cell ground truth from the non-holdout wells
    - Test: 10% of all cells with a single cell ground truth from the non-holdout wells
- Single-cells that do not have a single cell ground truth
    - Holdout wells: 1/3 of wells
    - Test: 100% of all cells without a single cell ground truth from the non-holdout wells


### hold out wells regardless of ground truth

In [6]:
# Accumulator for the split bookkeeping: each key collects one list per data
# subset (appended by later cells) and is flattened once all subsets are recorded.
index_data_split_and_ground_truth_dict = {
    record_key: [] for record_key in ("index", "data_split", "ground_truth")
}

In [7]:
# Map each row to the data split of its well using the well-level split table.
well_to_split = train_test_wells.set_index("Metadata_Well")["data_split"]
sc_profile["Metadata_data_split"] = sc_profile["Metadata_Well"].map(well_to_split)
# NOTE(review): wells that are absent from train_test_wells map to NaN and so end
# up in the non-holdout subset below -- confirm every well is expected to be listed.

# Wells labelled "test" in the well-level table become the held-out wells here.
sc_profile.loc[sc_profile["Metadata_data_split"] == "test", "Metadata_data_split"] = (
    "well_holdout"
)
holdout_mask = sc_profile["Metadata_data_split"] == "well_holdout"
holdout_df = sc_profile.loc[holdout_mask]

# Record the holdout rows (index label, split label, ground-truth flag).
index_data_split_and_ground_truth_dict["index"].append(holdout_df.index.tolist())
index_data_split_and_ground_truth_dict["data_split"].append(
    holdout_df["Metadata_data_split"].tolist()
)
index_data_split_and_ground_truth_dict["ground_truth"].append(
    holdout_df["Metadata_ground_truth_present"].tolist()
)

# get the non holdout wells
non_holdout_wells = sc_profile.loc[~holdout_mask]
# The label names the frame that is actually reported (previously it said
# "sc_profile" while printing the non-holdout subset's shape).
print(f"non_holdout_wells shape: {non_holdout_wells.shape}")
print(f"holdout_df shape: {holdout_df.shape}")

sc_profile shape after mapping data_split: (123348, 2378)
holdout_df shape: (59456, 2378)


### Cells that have a single cell ground truth

In [8]:
# Partition the non-holdout rows on the boolean ground-truth flag; copies are
# taken so later column assignments do not touch non_holdout_wells.
ground_truth_flag = non_holdout_wells["Metadata_ground_truth_present"]
cell_w_ground_truth_df = non_holdout_wells.loc[ground_truth_flag == True].copy()
cell_wout_ground_truth_df = non_holdout_wells.loc[ground_truth_flag == False].copy()

print(f"cell_w_ground_truth_df shape: {cell_w_ground_truth_df.shape}")
print(f"cell_wout_ground_truth_df shape: {cell_wout_ground_truth_df.shape}")

cell_w_ground_truth_df shape: (6432, 2378)
cell_wout_ground_truth_df shape: (116916, 2378)


#### Train, validation, and test split of cells with ground truth

In [9]:
# Split the ground-truth rows 80/10/10 (train/val/test), stratified by well:
# first set aside 20% of the rows, then halve that 20% into test and val.
# NOTE(review): rows are split individually and stratified only by well, so rows
# sharing a Metadata_sc_unique_track_id may land in different splits -- confirm
# this is intended.
train_sc_w_ground_truth_df, test_sc_w_ground_truth_df = train_test_split(
    cell_w_ground_truth_df,
    test_size=0.2,
    stratify=cell_w_ground_truth_df["Metadata_Well"],
    random_state=0,
)
test_sc_w_ground_truth_df, val_sc_w_ground_truth_df = train_test_split(
    test_sc_w_ground_truth_df,
    test_size=0.5,
    stratify=test_sc_w_ground_truth_df["Metadata_Well"],
    random_state=0,
)

# Label every subset with its split name and mark ground truth as present.
labelled_splits = (
    ("train", train_sc_w_ground_truth_df),
    ("val", val_sc_w_ground_truth_df),
    ("test", test_sc_w_ground_truth_df),
)
for split_label, split_df in labelled_splits:
    split_df["Metadata_data_split"] = split_label
    split_df["Metadata_ground_truth_present"] = True

print(f"train_sc_w_ground_truth_df shape: {train_sc_w_ground_truth_df.shape[0]}")
print(f"val_sc_w_ground_truth_df shape: {val_sc_w_ground_truth_df.shape[0]}")
print(f"test_sc_w_ground_truth_df shape: {test_sc_w_ground_truth_df.shape[0]}")

# Sanity checks: the subsets cover every ground-truth row and match the
# intended 80/10/10 proportions after rounding to two decimals.
n_ground_truth_rows = cell_w_ground_truth_df.shape[0]
assert (
    sum(subset_df.shape[0] for _, subset_df in labelled_splits)
    == n_ground_truth_rows
)
for (_, split_df), expected_fraction in zip(labelled_splits, (0.8, 0.1, 0.1)):
    assert np.round(split_df.shape[0] / n_ground_truth_rows, 2) == expected_fraction

# add to records
for _, split_df in labelled_splits:
    index_data_split_and_ground_truth_dict["index"].append(split_df.index.tolist())
    index_data_split_and_ground_truth_dict["data_split"].append(
        split_df["Metadata_data_split"].tolist()
    )
    index_data_split_and_ground_truth_dict["ground_truth"].append(
        split_df["Metadata_ground_truth_present"].tolist()
    )

train_sc_w_ground_truth_df shape: 5145
val_sc_w_ground_truth_df shape: 644
test_sc_w_ground_truth_df shape: 643


#### Cells without a single cell ground truth (non-tracked cells)

In [10]:
# All non-holdout rows without ground truth are assigned to the test split.
cell_wout_ground_truth_df["Metadata_data_split"] = "test"
# add to records
for record_key, record_values in (
    ("index", cell_wout_ground_truth_df.index),
    ("data_split", cell_wout_ground_truth_df["Metadata_data_split"]),
    ("ground_truth", cell_wout_ground_truth_df["Metadata_ground_truth_present"]),
):
    index_data_split_and_ground_truth_dict[record_key].append(record_values.tolist())
print(f"test_sc_wo_ground_truth_df shape: {cell_wout_ground_truth_df.shape[0]}")

test_sc_wo_ground_truth_df shape: 116916


### Fetch the indices from each ground truth and data split and add the status back to sc_profile

In [11]:
# Flatten the per-subset lists recorded under each key into one flat list per key.
# The flattened lists go into a new dictionary rather than overwriting the record
# dictionary, so re-running this cell does not attempt to flatten already-flat
# lists (which would raise a TypeError).
flat_split_records = {
    record_key: list(itertools.chain.from_iterable(recorded_chunks))
    for record_key, recorded_chunks in index_data_split_and_ground_truth_dict.items()
}
data_split_data_df = pd.DataFrame.from_dict(
    flat_split_records,
    orient="columns",
)
# Every sc_profile row must be recorded, and recorded exactly once.
assert data_split_data_df.shape[0] == sc_profile.shape[0]
assert data_split_data_df[
    "index"
].is_unique, "a row index was recorded in more than one data split"

In [12]:
# Order the label rows by the recorded "index" value and renumber the rows
# 0..n-1; the "index" column itself stays in place as the first column.
data_split_data_df = data_split_data_df.sort_values(by=["index"]).reset_index(
    drop=True
)
data_split_data_df.head()

Unnamed: 0,index,data_split,ground_truth
0,0,well_holdout,False
1,1,well_holdout,False
2,2,well_holdout,False
3,3,well_holdout,False
4,4,well_holdout,False


In [13]:
# Add the data_split and ground truth columns to sc_profile by matching the row
# labels recorded in the "index" column against sc_profile's own index, rather
# than relying on both frames sharing the same positional 0..n-1 index.
split_labels_df = data_split_data_df.set_index("index", drop=False).rename_axis(None)
assert (
    split_labels_df.index.is_unique
), "each sc_profile row must be labelled exactly once"
assert sc_profile.index.isin(
    split_labels_df.index
).all(), "every sc_profile row needs a split label"

sc_profile_with_data_splits_df = pd.concat(
    [
        # drop the provisional columns; the final labels replace them below
        sc_profile.drop(
            columns=["Metadata_data_split", "Metadata_ground_truth_present"]
        ),
        # put the labels in sc_profile's row order before the column-wise concat
        split_labels_df.reindex(sc_profile.index),
    ],
    axis=1,
).rename(
    columns={
        "data_split": "Metadata_data_split",
        "ground_truth": "Metadata_ground_truth_present",
    }
)

In [14]:
# final breakdown of the data: one frame per (data split, ground truth) combination
split_column = sc_profile_with_data_splits_df["Metadata_data_split"]
gt_column = sc_profile_with_data_splits_df["Metadata_ground_truth_present"]
with_gt = gt_column == True
without_gt = gt_column == False

train_gt = sc_profile_with_data_splits_df.loc[(split_column == "train") & with_gt].copy()
val_gt = sc_profile_with_data_splits_df.loc[(split_column == "val") & with_gt].copy()
test_gt = sc_profile_with_data_splits_df.loc[(split_column == "test") & with_gt].copy()
test_wo_gt = sc_profile_with_data_splits_df.loc[
    (split_column == "test") & without_gt
].copy()
holdout_w_gt = sc_profile_with_data_splits_df.loc[
    (split_column == "well_holdout") & with_gt
].copy()
holdout_wo_gt = sc_profile_with_data_splits_df.loc[
    (split_column == "well_holdout") & without_gt
].copy()

# assertion time :)
# The labelled frame keeps every sc_profile row, and the six subsets account
# for all of those rows.
n_labelled_rows = sc_profile_with_data_splits_df.shape[0]
assert n_labelled_rows == sc_profile.shape[0]
breakdown_frames = (
    train_gt,
    val_gt,
    test_gt,
    test_wo_gt,
    holdout_w_gt,
    holdout_wo_gt,
)
assert n_labelled_rows == sum(subset.shape[0] for subset in breakdown_frames)

In [15]:
# Report the number of rows in each (data split, ground truth) subset.
print(f"train_gt shape: {train_gt.shape[0]}")
print(f"val_gt shape: {val_gt.shape[0]}")
print(f"test_gt shape: {test_gt.shape[0]}")
print(f"test_wo_gt shape: {test_wo_gt.shape[0]}")
print(f"holdout_w_gt shape: {holdout_w_gt.shape[0]}")
print(f"holdout_wo_gt shape: {holdout_wo_gt.shape[0]}")

train_gt shape: 5145
val_gt shape: 644
test_gt shape: 643
test_wo_gt shape: 116916
holdout_w_gt shape: 3220
holdout_wo_gt shape: 56236


In [16]:
# Save the per-row split table (index, data_split, ground_truth) as parquet;
# the DataFrame's own row index is not written.
data_split_file_path = pathlib.Path("../results/data_splits.parquet").resolve()
data_split_data_df.to_parquet(data_split_file_path, index=False)