This is quite the complex data splitting procedure.
The data is split into holdout data, training, validation, and testing.
The training and validation data only contains single-cells that have ground truth labels at the terminal time point.
The testing and holdout data contain cells that do and do not have ground truth labels at the terminal time point.

In [1]:
import itertools
import pathlib

import numpy as np
import pandas as pd
from scipy.spatial.distance import euclidean
from sklearn.model_selection import train_test_split

# Detect whether an IPython shell is driving this code so that the matching
# tqdm progress-bar flavour is imported.
try:
    cfg = get_ipython().config
except NameError:
    # get_ipython is not defined when run as a plain Python script
    in_notebook = False
else:
    in_notebook = True

if not in_notebook:
    from tqdm import tqdm
else:
    from tqdm.notebook import tqdm

In [2]:
# Resolve the three input tables up front; strict=True raises immediately if
# any of the files is missing.
sc_file_path = pathlib.Path(
    "../results/cleaned_sc_profile.parquet"
).resolve(strict=True)
sc_endpoint_file_path = pathlib.Path(
    "../results/cleaned_endpoint_sc_profile.parquet"
).resolve(strict=True)
train_test_wells_file_path = pathlib.Path(
    "../../5.bulk_timelapse_model/data_splits/train_test_wells.parquet"
).resolve(strict=True)

# Single-cell profile, endpoint single-cell profile, and the well-level split
# assignment read from the bulk model's data_splits directory.
sc_profile = pd.read_parquet(sc_file_path)
sc_endpoint_profile = pd.read_parquet(sc_endpoint_file_path)
train_test_wells = pd.read_parquet(train_test_wells_file_path)

for frame_name, frame in (
    ("sc_profile", sc_profile),
    ("sc_endpoint_profile", sc_endpoint_profile),
):
    print(f"{frame_name} shape: {frame.shape}")

sc_profile shape: (145489, 2447)
sc_endpoint_profile shape: (31332, 284)


In [3]:
# preview the first rows of the endpoint single-cell profile
sc_endpoint_profile.head()

Unnamed: 0,Metadata_plate,Metadata_Well,Metadata_number_of_singlecells,Metadata_compound,Metadata_dose,Metadata_control,Metadata_ImageNumber,Metadata_FOV,Metadata_Time,Metadata_Cells_Number_Object_Number,...,Cells_Texture_Correlation_AnnexinV_3_00_256,Cells_Texture_Correlation_AnnexinV_3_01_256,Cells_Texture_Correlation_AnnexinV_3_02_256,Cells_Texture_Correlation_AnnexinV_3_03_256,Cells_Texture_Correlation_DNA_3_03_256,Cells_Texture_DifferenceVariance_AnnexinV_3_03_256,Cells_Texture_DifferenceVariance_DNA_3_01_256,Cells_Texture_SumAverage_DNA_3_03_256,Metadata_Well_FOV,Metadata_sc_unique_track_id
0,1,C-04,161,Staurosporine,1.22,test,1,1,13.0,13.0,...,0.110307,-0.666442,-0.495497,-0.298797,1.069224,0.420328,0.998139,0.617502,C-04_0001,C-04_0001_13
1,1,C-04,161,Staurosporine,1.22,test,1,1,13.0,14.0,...,1.284977,1.281368,1.150719,0.845906,-0.986121,-1.601364,0.732015,0.622838,C-04_0001,C-04_0001_33
2,1,C-04,161,Staurosporine,1.22,test,1,1,13.0,15.0,...,0.648434,0.237933,0.55686,0.53401,1.069224,-0.387579,0.998139,0.617502,C-04_0001,C-04_0001_22
3,1,C-04,161,Staurosporine,1.22,test,1,1,13.0,23.0,...,-1.239068,-0.90782,-1.282647,-0.858384,1.069224,0.881355,0.998139,0.617502,C-04_0001,C-04_0001_19
4,1,C-04,161,Staurosporine,1.22,test,1,1,13.0,24.0,...,-1.112063,-0.763913,-1.145026,-0.772991,1.069224,1.355669,0.998139,0.617502,C-04_0001,C-04_0001_141


In [4]:
# preview the first rows of the single-cell profile
sc_profile.head()

Unnamed: 0,Metadata_plate,Metadata_Well,Metadata_number_of_singlecells,Metadata_compound,Metadata_dose,Metadata_control,Metadata_ImageNumber,Metadata_FOV,Metadata_Time,Metadata_Cells_Number_Object_Number,...,channel_DNA_cls_feature_94_scDINO,channel_DNA_cls_feature_95_scDINO,channel_DNA_cls_feature_96_scDINO,channel_DNA_cls_feature_97_scDINO,channel_DNA_cls_feature_98_scDINO,channel_DNA_cls_feature_99_scDINO,channel_DNA_cls_feature_9_scDINO,Metadata_sc_unique_track_id,Metadata_Well_FOV,Metadata_sc_unique_track_id_count
0,1,C-02,178,Staurosporine,0.0,negative,1,1,0.0,98.0,...,0.312717,1.018808,-0.202042,0.147753,-0.357406,-0.140129,1.510029,C-02_0001_17,C-02_0001,12
1,1,C-02,178,Staurosporine,0.0,negative,1,1,0.0,108.0,...,0.410329,0.872185,1.059767,-1.079649,0.219456,0.469064,-0.973322,C-02_0001_18,C-02_0001,12
2,1,C-02,178,Staurosporine,0.0,negative,1,1,0.0,123.0,...,-0.765297,0.035978,-0.131003,1.354683,-0.020082,2.435365,3.352972,C-02_0001_19,C-02_0001,13
3,1,C-02,178,Staurosporine,0.0,negative,1,1,0.0,127.0,...,-0.323113,-1.792923,0.823824,-0.372862,1.516754,0.890084,0.934903,C-02_0001_20,C-02_0001,6
4,1,C-02,178,Staurosporine,0.0,negative,1,1,0.0,15.0,...,0.942032,0.615786,-1.282878,1.124272,-0.04327,-1.727197,0.671195,C-02_0001_5,C-02_0001,1


In [5]:
# A cell counts as having ground truth when its track id also occurs in the
# endpoint profile.
endpoint_track_ids = sc_endpoint_profile["Metadata_sc_unique_track_id"]
track_in_endpoint = sc_profile["Metadata_sc_unique_track_id"].isin(endpoint_track_ids)
sc_profile["Metadata_ground_truth_present"] = track_in_endpoint.astype(bool)
sc_profile

Unnamed: 0,Metadata_plate,Metadata_Well,Metadata_number_of_singlecells,Metadata_compound,Metadata_dose,Metadata_control,Metadata_ImageNumber,Metadata_FOV,Metadata_Time,Metadata_Cells_Number_Object_Number,...,channel_DNA_cls_feature_95_scDINO,channel_DNA_cls_feature_96_scDINO,channel_DNA_cls_feature_97_scDINO,channel_DNA_cls_feature_98_scDINO,channel_DNA_cls_feature_99_scDINO,channel_DNA_cls_feature_9_scDINO,Metadata_sc_unique_track_id,Metadata_Well_FOV,Metadata_sc_unique_track_id_count,Metadata_ground_truth_present
0,1,C-02,178,Staurosporine,0.0,negative,1,0001,0.0,98.0,...,1.018808,-0.202042,0.147753,-0.357406,-0.140129,1.510029,C-02_0001_17,C-02_0001,12,False
1,1,C-02,178,Staurosporine,0.0,negative,1,0001,0.0,108.0,...,0.872185,1.059767,-1.079649,0.219456,0.469064,-0.973322,C-02_0001_18,C-02_0001,12,False
2,1,C-02,178,Staurosporine,0.0,negative,1,0001,0.0,123.0,...,0.035978,-0.131003,1.354683,-0.020082,2.435365,3.352972,C-02_0001_19,C-02_0001,13,True
3,1,C-02,178,Staurosporine,0.0,negative,1,0001,0.0,127.0,...,-1.792923,0.823824,-0.372862,1.516754,0.890084,0.934903,C-02_0001_20,C-02_0001,6,False
4,1,C-02,178,Staurosporine,0.0,negative,1,0001,0.0,15.0,...,0.615786,-1.282878,1.124272,-0.043270,-1.727197,0.671195,C-02_0001_5,C-02_0001,1,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
145484,1,E-11,80,Staurosporine,156.25,test,13,0004,12.0,44.0,...,1.920718,-0.880146,-0.591598,0.857347,0.676845,-1.641834,E-11_0004_68,E-11_0004,9,True
145485,1,E-11,80,Staurosporine,156.25,test,13,0004,12.0,45.0,...,3.210849,0.098172,-0.448395,0.705242,-1.350891,-0.106632,E-11_0004_71,E-11_0004,10,False
145486,1,E-11,80,Staurosporine,156.25,test,13,0004,12.0,50.0,...,2.862780,-1.492234,-0.273674,0.266194,1.571618,1.251849,E-11_0004_73,E-11_0004,7,False
145487,1,E-11,80,Staurosporine,156.25,test,13,0004,12.0,52.0,...,3.682511,-1.105393,-0.869184,-1.709636,1.173656,0.243441,E-11_0004_72,E-11_0004,7,False


At this point the dataset consists of two subsets, which will be split into the following data splits:
- Single-cells that have a single-cell ground truth
    - Holdout wells: 1/3 of wells
    - Train: 80% of all cells with a single-cell ground truth from the non-holdout wells
    - Validation: 10% of all cells with a single-cell ground truth from the non-holdout wells
    - Test: 10% of all cells with a single-cell ground truth from the non-holdout wells
- Single-cells that do not have a single-cell ground truth
    - Holdout wells: 1/3 of wells
    - Test: 100% of all cells without a single-cell ground truth from the non-holdout wells


### hold out wells regardless of ground truth

In [6]:
# Accumulator for the per-subset records: one (initially empty) list per field.
# Each later step appends a whole list of values for one subset of cells.
index_data_split_and_ground_truth_dict = {
    field: [] for field in ("index", "data_split", "ground_truth")
}

In [7]:
# Give every cell the split label of the well it belongs to.
well_to_split = train_test_wells.set_index("Metadata_Well")["data_split"]
sc_profile["Metadata_data_split"] = sc_profile["Metadata_Well"].map(well_to_split)
# Wells labelled "test" in the well-level table are held out entirely here.
sc_profile.loc[
    sc_profile["Metadata_data_split"] == "test", "Metadata_data_split"
] = "well_holdout"

# Build the mask once so the two partitions are complementary by construction.
# Wells that are absent from train_test_wells map to NaN and therefore fall
# into the non-holdout partition.
is_holdout = sc_profile["Metadata_data_split"] == "well_holdout"
holdout_df = sc_profile.loc[is_holdout]
# get the non holdout wells
non_holdout_wells = sc_profile.loc[~is_holdout]

# Record row label, split label and ground-truth flag for every held-out cell.
index_data_split_and_ground_truth_dict["index"].append(holdout_df.index.tolist())
index_data_split_and_ground_truth_dict["data_split"].append(
    holdout_df["Metadata_data_split"].tolist()
)
index_data_split_and_ground_truth_dict["ground_truth"].append(
    holdout_df["Metadata_ground_truth_present"].tolist()
)

print(f"non_holdout_wells shape: {non_holdout_wells.shape}")
print(f"holdout_df shape: {holdout_df.shape}")
assert len(holdout_df) > 0, "No holdout wells found in sc_profile dataframe"
assert len(non_holdout_wells) > 0, "No non-holdout wells found in sc_profile dataframe"
# The lengths compared here are row (cell) counts, not well counts.
assert len(non_holdout_wells) + len(holdout_df) == len(sc_profile), (
    "The number of cells in holdout and non-holdout wells does not equal "
    "the total number of cells in the sc_profile dataframe"
)

sc_profile shape after mapping data_split: (98001, 2449)
holdout_df shape: (47488, 2449)


### Cells that have a single cell ground truth

In [8]:
# Partition the non-holdout cells by whether their track has an endpoint
# (ground truth) measurement; copies keep later column edits independent.
has_ground_truth = non_holdout_wells["Metadata_ground_truth_present"]
cell_w_ground_truth_df = non_holdout_wells.loc[has_ground_truth].copy()
cell_wout_ground_truth_df = non_holdout_wells.loc[~has_ground_truth].copy()

for frame_name, frame in (
    ("cell_w_ground_truth_df", cell_w_ground_truth_df),
    ("cell_wout_ground_truth_df", cell_wout_ground_truth_df),
):
    print(f"{frame_name} shape: {frame.shape}")

cell_w_ground_truth_df shape: (15775, 2449)
cell_wout_ground_truth_df shape: (82226, 2449)


##

In [9]:
# One representative row per track id (first non-null value of each column).
grouped = cell_w_ground_truth_df.groupby("Metadata_sc_unique_track_id")
train_groups = grouped.first().reset_index()

# Row labels of every cell whose track id is among the representative tracks.
# No tracks are set aside here, so this selects all tracked cells of the
# non-holdout wells.
belongs_to_train_track = cell_w_ground_truth_df["Metadata_sc_unique_track_id"].isin(
    train_groups["Metadata_sc_unique_track_id"]
)
train_indices = cell_w_ground_truth_df[belongs_to_train_track].index

# Materialise the training subset and stamp its split metadata.
train_sc_w_ground_truth_df = cell_w_ground_truth_df.loc[train_indices].copy()
train_sc_w_ground_truth_df["Metadata_data_split"] = "train"
train_sc_w_ground_truth_df["Metadata_ground_truth_present"] = True

print(f"train_sc_w_ground_truth_df shape: {train_sc_w_ground_truth_df.shape[0]}")

# Append this subset's row labels, split labels and ground-truth flags.
for record_key, values in (
    ("index", train_sc_w_ground_truth_df.index),
    ("data_split", train_sc_w_ground_truth_df["Metadata_data_split"]),
    ("ground_truth", train_sc_w_ground_truth_df["Metadata_ground_truth_present"]),
):
    index_data_split_and_ground_truth_dict[record_key].append(values.tolist())

train_sc_w_ground_truth_df shape: 15775


#### Non tracked cells

In [10]:
# Cells from non-holdout wells that lack ground truth are labelled
# "well_holdout" as well.
# NOTE(review): the printed name below says "test" while the stored label is
# "well_holdout" - confirm this is the intended label.
cell_wout_ground_truth_df["Metadata_data_split"] = "well_holdout"

# Append this subset's row labels, split labels and ground-truth flags.
for record_key, values in (
    ("index", cell_wout_ground_truth_df.index),
    ("data_split", cell_wout_ground_truth_df["Metadata_data_split"]),
    ("ground_truth", cell_wout_ground_truth_df["Metadata_ground_truth_present"]),
):
    index_data_split_and_ground_truth_dict[record_key].append(values.tolist())

print(f"test_sc_wo_ground_truth_df shape: {cell_wout_ground_truth_df.shape[0]}")

test_sc_wo_ground_truth_df shape: 82226


### Fetch the indices from each ground truth and data split and add the status back to sc_profile

In [11]:
# Every value currently holds one list per recorded subset; collapse each into
# a single flat list. NOTE: not idempotent - re-running this cell on values
# that are already flat will not work.
for record_key, nested_values in index_data_split_and_ground_truth_dict.items():
    if not isinstance(nested_values, list):
        continue
    index_data_split_and_ground_truth_dict[record_key] = [
        value for chunk in nested_values for value in chunk
    ]

# One row per cell: original row label, split label, ground-truth flag.
data_split_data_df = pd.DataFrame(index_data_split_and_ground_truth_dict)
print(f"data_split_data_df shape: {data_split_data_df.shape}")
print(f"sc profile shape: {sc_profile.shape}")
# every cell of sc_profile must have been recorded exactly once
assert data_split_data_df.shape[0] == sc_profile.shape[0]

data_split_data_df shape: (145489, 3)
sc profile shape: (145489, 2449)


In [12]:
# Order the records by the original sc_profile row label and renumber the rows
# from 0; the "index" column keeps the original label as the first column.
data_split_data_df = data_split_data_df.sort_values(by=["index"]).reset_index(
    drop=True
)
data_split_data_df.head()

Unnamed: 0,index,data_split,ground_truth
0,0,well_holdout,False
1,1,well_holdout,False
2,2,well_holdout,True
3,3,well_holdout,False
4,4,well_holdout,False


In [13]:
# Attach the recorded split label and ground-truth flag to every cell.
# The "index" column of data_split_data_df holds the sc_profile row label each
# record belongs to, so align on that label explicitly instead of relying on
# both frames sharing the same default 0..n-1 index.
split_records_by_row = data_split_data_df.set_index("index", drop=False).rename_axis(
    None
)
sc_profile_with_data_splits_df = pd.concat(
    [sc_profile, split_records_by_row],
    axis=1,
)
# The provisional columns written on sc_profile are superseded by the recorded
# ones, which then take over the Metadata_* names.
sc_profile_with_data_splits_df = sc_profile_with_data_splits_df.drop(
    columns=["Metadata_data_split", "Metadata_ground_truth_present"]
).rename(
    columns={
        "data_split": "Metadata_data_split",
        "ground_truth": "Metadata_ground_truth_present",
    }
)

In [14]:
# Final breakdown of the cells by split label and ground-truth availability.
split_label = sc_profile_with_data_splits_df["Metadata_data_split"]
gt_flag = sc_profile_with_data_splits_df["Metadata_ground_truth_present"]
is_train = split_label == "train"
is_well_holdout = split_label == "well_holdout"
with_gt = gt_flag == True  # noqa: E712 - explicit comparison kept on purpose
without_gt = gt_flag == False  # noqa: E712

train_gt = sc_profile_with_data_splits_df[is_train & with_gt].copy()
holdout_w_gt = sc_profile_with_data_splits_df[is_well_holdout & with_gt].copy()
holdout_wo_gt = sc_profile_with_data_splits_df[is_well_holdout & without_gt].copy()

# Sanity checks: no rows were gained or lost, and the three groups together
# account for every cell.
n_cells = sc_profile_with_data_splits_df.shape[0]
assert n_cells == sc_profile.shape[0]
assert n_cells == (
    train_gt.shape[0] + holdout_w_gt.shape[0] + holdout_wo_gt.shape[0]
)

In [15]:
# Report the number of cells in each group.
for group_name, group_df in (
    ("train_gt", train_gt),
    ("holdout_w_gt", holdout_w_gt),
    ("holdout_wo_gt", holdout_wo_gt),
):
    print(f"{group_name} shape: {group_df.shape[0]}")

train_gt shape: 15775
holdout_w_gt shape: 6899
holdout_wo_gt shape: 122815


In [16]:
# preview the per-cell split records (row label, split label, ground-truth flag)
data_split_data_df.head()

Unnamed: 0,index,data_split,ground_truth
0,0,well_holdout,False
1,1,well_holdout,False
2,2,well_holdout,True
3,3,well_holdout,False
4,4,well_holdout,False


## Now split the data into all splits including X, Y, and metadata
Each of the splits will be saved so that it can be loaded for hyperparameter optimization, training, and testing.

In [17]:
# Reload both profiles from disk so the X/y construction below starts from
# unmodified tables; strict=True raises immediately if a file is missing.
sc_file_path = pathlib.Path(
    "../results/cleaned_sc_profile.parquet"
).resolve(strict=True)
sc_endpoint_file_path = pathlib.Path(
    "../results/cleaned_endpoint_sc_profile.parquet"
).resolve(strict=True)

sc_profile = pd.read_parquet(sc_file_path)
sc_endpoint_profile = pd.read_parquet(sc_endpoint_file_path)

for frame_name, frame in (
    ("sc_profile", sc_profile),
    ("sc_endpoint_profile", sc_endpoint_profile),
):
    print(f"{frame_name} shape: {frame.shape}")

sc_profile shape: (145489, 2447)
sc_endpoint_profile shape: (31332, 284)


In [18]:
# Attach the split label and ground-truth flag to the freshly loaded profile.
# The "index" column of data_split_data_df stores the sc_profile row label each
# record belongs to, so join on that label explicitly rather than relying on
# both frames sharing the same default 0..n-1 index.
split_labels = data_split_data_df.set_index("index")[
    ["ground_truth", "data_split"]
].rename(
    columns={
        "ground_truth": "Metadata_ground_truth",
        "data_split": "Metadata_data_split",
    }
)
# Left join: keeps every sc_profile row in its original order.
sc_profile = sc_profile.join(split_labels)

In [19]:
# Keep only the cells imaged at the last time point.
sc_profile["Metadata_Time"] = sc_profile["Metadata_Time"].astype("float64")
last_timepoint = sc_profile["Metadata_Time"].max()
# Filter and drop rows containing any NaN in one chained, non-inplace
# expression: calling dropna(inplace=True) on a boolean-filtered slice can
# raise pandas' SettingWithCopyWarning.
sc_profile = sc_profile.loc[sc_profile["Metadata_Time"] == last_timepoint].dropna()
print(f"sc_profile shape after dropping NaN: {sc_profile.shape}")
sc_endpoint_profile = sc_endpoint_profile.dropna()
print(f"sc_endpoint_profile shape after dropping NaN: {sc_endpoint_profile.shape}")

sc_profile shape after dropping NaN: (10274, 2449)
sc_endpoint_profile shape after dropping NaN: (29832, 284)


In [20]:
# Target feature(s) kept in the y data. Hard-coded for now; intended to be
# replaced by an argument or config value later.
# An alternative target would be Cytoplasm_Intensity_MaxIntensity_AnnexinV.
selected_y_features = ["Cytoplasm_Intensity_IntegratedIntensity_AnnexinV"]

# Every column whose name contains "Metadata_" is carried along with the target.
metadata_y_features = [
    column for column in sc_endpoint_profile.columns if "Metadata_" in column
]
y_columns = metadata_y_features + selected_y_features
sc_endpoint_profile = sc_endpoint_profile[y_columns]

In [21]:
# Boolean masks over the last-timepoint cells, combined below into the three
# X subsets.
in_train = sc_profile["Metadata_data_split"] == "train"
in_well_holdout = sc_profile["Metadata_data_split"] == "well_holdout"
with_gt = sc_profile["Metadata_ground_truth"] == True  # noqa: E712
without_gt = sc_profile["Metadata_ground_truth"] == False  # noqa: E712

train_gt_X = sc_profile.loc[in_train & with_gt]
holdout_w_gt_X = sc_profile.loc[in_well_holdout & with_gt]
holdout_wo_gt_X = sc_profile.loc[in_well_holdout & without_gt]

print(f"total cells: {sc_profile.shape[0]}")
for subset_name, subset in (
    ("train_gt_X", train_gt_X),
    ("holdout_w_gt_X", holdout_w_gt_X),
    ("holdout_wo_gt_X", holdout_wo_gt_X),
):
    print(f"{subset_name} shape: {subset.shape}")

total cells: 10274
train_gt_X shape: (1363, 2449)
holdout_w_gt_X shape: (587, 2449)
holdout_wo_gt_X shape: (8324, 2449)


In [22]:
# Share of the last-timepoint cells that have ground truth (train + holdout).
total_gt_cells = train_gt_X.shape[0] + holdout_w_gt_X.shape[0]
gt_percentage = total_gt_cells / sc_profile.shape[0] * 100
print(f"cells with ground truth: {total_gt_cells}")
print(f"Percentage of cells with ground truth: {gt_percentage:.2f}%")

cells with ground truth: 1950
Percentage of cells with ground truth: 18.98%


In [23]:
# Unique track ids present in each ground-truth X subset.
train_df_x_Metadata_sc_unique_track_id = train_gt_X[
    "Metadata_sc_unique_track_id"
].unique()
holdout_w_df_x_Metadata_sc_unique_track_id = holdout_w_gt_X[
    "Metadata_sc_unique_track_id"
].unique()

print(
    f"train_df_x_Metadata_sc_unique_track_id shape: {train_df_x_Metadata_sc_unique_track_id.shape}"
)
print(
    f"holdout_w_df_x_Metadata_sc_unique_track_id shape: {holdout_w_df_x_Metadata_sc_unique_track_id.shape}"
)

# A track must never appear in both train and holdout.
shared_track_ids = set(train_df_x_Metadata_sc_unique_track_id) & set(
    holdout_w_df_x_Metadata_sc_unique_track_id
)
assert not shared_track_ids, "train and holdout track ids are overlapping"

train_df_x_Metadata_sc_unique_track_id shape: (1363,)
holdout_w_df_x_Metadata_sc_unique_track_id shape: (587,)


In [24]:
def _rows_for_tracks(frame, track_ids):
    """Return the rows of ``frame`` whose track id is in ``track_ids``.

    Only the first row of each track id is kept.
    """
    in_tracks = frame["Metadata_sc_unique_track_id"].isin(track_ids)
    return frame.loc[in_tracks].drop_duplicates("Metadata_sc_unique_track_id")


# y: endpoint rows whose track also exists in the corresponding X subset.
train_gt_y = _rows_for_tracks(
    sc_endpoint_profile, train_df_x_Metadata_sc_unique_track_id
)
holdout_gt_y = _rows_for_tracks(
    sc_endpoint_profile, holdout_w_df_x_Metadata_sc_unique_track_id
)

# X: keep only the tracks that actually have an endpoint row.
# NOTE(review): X and y keep the row order of their own source tables, so they
# are matched by track id, not by position - confirm downstream code pairs
# them by track id.
train_gt_X = _rows_for_tracks(train_gt_X, train_gt_y["Metadata_sc_unique_track_id"])
holdout_w_gt_X = _rows_for_tracks(
    holdout_w_gt_X, holdout_gt_y["Metadata_sc_unique_track_id"]
)

In [25]:
print(f"train_y_gt shape: {train_gt_y.shape}, train_gt_X shape: {train_gt_X.shape}")
print(
    f"holdout_y_gt shape: {holdout_gt_y.shape}, holdout_gt_X shape: {holdout_w_gt_X.shape}"
)

# No X row may be shared between the train and holdout subsets.
shared_rows = set(train_gt_X.index) & set(holdout_w_gt_X.index)
assert not shared_rows, "train and holdout data are overlapping"

# Each X subset must have exactly as many samples as its y counterpart.
assert len(train_gt_X) == len(train_gt_y), "train gt X and y shapes are not the same"
assert len(holdout_w_gt_X) == len(
    holdout_gt_y
), "holdout gt X and y shapes are not the same"

train_y_gt shape: (1302, 25), train_gt_X shape: (1302, 2449)
holdout_y_gt shape: (561, 25), holdout_gt_X shape: (561, 2449)


In [26]:
# Separate the Metadata_* columns from the feature columns of every subset.
metadata_X_cols = [col for col in train_gt_X.columns if "Metadata_" in col]
metadata_y_cols = [col for col in train_gt_y.columns if "Metadata_" in col]

# Rebind instead of dropping in place: the frames are derived from filtered
# slices, where in-place mutation can raise pandas' SettingWithCopyWarning.
# train
train_gt_X_metadata = train_gt_X[metadata_X_cols]
train_gt_X = train_gt_X.drop(columns=metadata_X_cols)
train_gt_y_metadata = train_gt_y[metadata_y_cols]
train_gt_y = train_gt_y.drop(columns=metadata_y_cols)

# holdout with ground truth
holdout_w_gt_X_metadata = holdout_w_gt_X[metadata_X_cols]
holdout_w_gt_X = holdout_w_gt_X.drop(columns=metadata_X_cols)
holdout_w_gt_y_metadata = holdout_gt_y[metadata_y_cols]
holdout_gt_y = holdout_gt_y.drop(columns=metadata_y_cols)

# Row labels of every subset. X labels refer to rows of sc_profile, y labels to
# rows of sc_endpoint_profile. Cells without ground truth have no y, so only
# their X labels are recorded.
train_gt_X_metadata_index = train_gt_X_metadata.index
train_gt_y_metadata_index = train_gt_y_metadata.index

holdout_w_gt_X_metadata_index = holdout_w_gt_X_metadata.index
holdout_w_gt_y_metadata_index = holdout_w_gt_y_metadata.index

holdout_wo_gt_X_metadata_index = holdout_wo_gt_X.index

# Parallel lists: entry i of each list describes the same subset.
dict_of_index_lists = {
    "index_lists": [
        train_gt_X_metadata_index.tolist(),
        train_gt_y_metadata_index.tolist(),
        holdout_w_gt_X_metadata_index.tolist(),
        holdout_w_gt_y_metadata_index.tolist(),
        holdout_wo_gt_X_metadata_index.tolist(),
    ],
    "data_split": [
        "train_gt",
        "train_gt",
        "holdout_gt",
        "holdout_gt",
        "holdout_wo_gt",
    ],
    "data_x_or_y": [
        "X",
        "y",
        "X",
        "y",
        "X",
    ],
}

# Long format: one row per (row label, split, X-or-y) combination.
dict_for_df = {
    "index": [],
    "data_split": [],
    "data_x_or_y": [],
}
for index_list, split_name, x_or_y in zip(
    dict_of_index_lists["index_lists"],
    dict_of_index_lists["data_split"],
    dict_of_index_lists["data_x_or_y"],
):
    # extend directly so no separate flattening pass is needed afterwards
    dict_for_df["index"].extend(index_list)
    dict_for_df["data_split"].extend([split_name] * len(index_list))
    dict_for_df["data_x_or_y"].extend([x_or_y] * len(index_list))

# create a dataframe from the dictionary
index_data_split_and_x_y_df = pd.DataFrame.from_dict(
    dict_for_df,
    orient="columns",
)
index_data_split_and_x_y_df.head()

Unnamed: 0,index,data_split,data_x_or_y
0,7260,train_gt,X
1,7261,train_gt,X
2,7263,train_gt,X
3,7264,train_gt,X
4,7265,train_gt,X


In [27]:
# list the distinct split labels that ended up in the index table
index_data_split_and_x_y_df["data_split"].unique()

array(['train_gt', 'holdout_gt', 'holdout_wo_gt'], dtype=object)

In [None]:
# Persist the index table (row label, split label, X-or-y marker) as parquet.
# The profile data itself is not written; the stored row labels identify the
# rows of each split in the source tables.
data_split_file_path = pathlib.Path("../results/data_splits.parquet").resolve()
index_data_split_and_x_y_df.to_parquet(data_split_file_path, index=False)

: 