# Split each plate from the batch data into training, testing, and holdout data

In [1]:
import pathlib
import random

import pprint

import pandas as pd
from sklearn.model_selection import train_test_split

## Set paths and variables

In [2]:
# Set random state for the whole notebook to ensure reproducibility
random.seed(0)

# Path to directory with feature selected profiles
path_to_feature_selected_data = pathlib.Path(
    "../3.preprocessing_features/data/single_cell_profiles/"
).resolve(strict=True)

# Find all feature selected parquet files
feature_selected_files = list(path_to_feature_selected_data.glob("*_feature_selected.parquet"))

# Make directory for split data
output_dir = pathlib.Path("./data")
output_dir.mkdir(exist_ok=True)

## Load in feature selected data

In [3]:
# Load in all feature selected files as dataframes
feature_selected_dfs_dict = {
    pathlib.Path(file).stem.split('_')[0]: pd.read_parquet(file) for file in feature_selected_files
}

pprint.pprint(feature_selected_dfs_dict, indent=4)

{   'localhost240926150001':       Metadata_WellRow  Metadata_WellCol  Metadata_heart_number  \
0                    B                 2                      7   
1                    B                 2                      7   
2                    B                 2                      7   
3                    B                 2                      7   
4                    B                 2                      7   
...                ...               ...                    ...   
16561                G                11                     19   
16562                G                11                     19   
16563                G                11                     19   
16564                G                11                     19   
16565                G                11                     19   

      Metadata_cell_type Metadata_heart_failure_type Metadata_treatment  \
0                healthy                        None               DMSO   
1               

## For each dataframe, take only the DMSO cells and split 70/30 for training and testing

In [4]:
# Set random state as 0 (same as the rest of the notebook)
random_state = 0

# Set the ratio of the test data to 30% (training data will be 70%)
test_ratio = 0.30

for plate, df in feature_selected_dfs_dict.items():
    # Filter only the rows with DMSO treatment for model training and testing
    DMSO_df = df[df.Metadata_treatment == "DMSO"]
    print(f"Plate: {plate} contains {DMSO_df.shape[0]} DMSO profiles")

    # Split data into training and test sets
    train_df, test_df = train_test_split(
        DMSO_df,
        test_size=test_ratio,
        stratify=DMSO_df[["Metadata_cell_type"]],
        random_state=random_state,
    )

    # Print the shapes of the training and testing data
    print(f"Training data shape: {train_df.shape}")
    print(f"Testing data shape: {test_df.shape}")

    # Save training and test data
    train_df.to_parquet(output_dir / f"{plate}_train.parquet")
    test_df.to_parquet(output_dir / f"{plate}_test.parquet")

Plate: localhost240928120001 contains 1573 DMSO profiles
Training data shape: (1101, 641)
Testing data shape: (472, 641)
Plate: localhost240927060001 contains 1526 DMSO profiles
Training data shape: (1068, 652)
Testing data shape: (458, 652)
Plate: localhost240927120001 contains 1514 DMSO profiles
Training data shape: (1059, 684)
Testing data shape: (455, 684)
Plate: localhost240926150001 contains 2022 DMSO profiles
Training data shape: (1415, 657)
Testing data shape: (607, 657)


## Combine the 4 plates together using the common morphology features

In [5]:
# Identify common feature columns (excluding Metadata_ columns)
common_features = set.intersection(*[
    set(df.columns[~df.columns.str.startswith("Metadata_")]) 
    for df in feature_selected_dfs_dict.values()
])

# Concat all filtered dataframes while keeping Metadata columns
combined_df = pd.concat(
    [df[list(common_features) + [col for col in df.columns if col.startswith("Metadata_")]] 
     for df in feature_selected_dfs_dict.values()],
    ignore_index=True
)

print(combined_df.shape)
combined_df.head()

(54610, 494)


Unnamed: 0,Cells_Texture_Correlation_PM_3_03_256,Cells_Intensity_IntegratedIntensityEdge_PM,Cells_Texture_Correlation_Hoechst_3_02_256,Cells_RadialDistribution_RadialCV_PM_4of4,Nuclei_Intensity_IntegratedIntensityEdge_Mitochondria,Nuclei_Texture_Correlation_Hoechst_3_03_256,Cytoplasm_Texture_InfoMeas1_Hoechst_3_00_256,Nuclei_AreaShape_Zernike_3_1,Nuclei_Intensity_MADIntensity_PM,Cells_AreaShape_Zernike_6_0,...,Metadata_Cells_Location_Center_Y,Metadata_Image_Count_Cells,Metadata_ImageNumber,Metadata_Plate,Metadata_Well,Metadata_Cells_Number_Object_Number,Metadata_Cytoplasm_Parent_Cells,Metadata_Cytoplasm_Parent_Nuclei,Metadata_Nuclei_Number_Object_Number,Metadata_Site
0,0.558955,1.586534,0.761087,-0.08355,-0.935522,0.073294,-0.342113,0.023656,0.119329,-0.543098,...,268.256144,9,2,localhost240928120001,B02,1,1,6,6,f04
1,0.153794,1.57517,0.349604,-1.163928,-0.403699,1.236566,0.191088,-0.889207,0.178704,-0.953293,...,204.845356,4,3,localhost240928120001,B02,1,1,3,3,f10
2,-0.895822,1.739962,-0.340866,-1.091112,-0.506543,0.061749,0.918389,0.899627,0.404541,-0.990414,...,333.104966,9,5,localhost240928120001,B02,1,1,3,3,f14
3,0.486792,0.505123,0.317325,-0.826691,-0.806719,-0.703684,-0.304492,-0.702015,0.330322,1.859774,...,88.534307,18,7,localhost240928120001,B02,1,1,3,3,f18
4,-0.544021,0.796886,0.493012,-0.475095,0.246624,1.46137,1.714124,1.796197,2.563246,-0.192049,...,96.955249,10,1,localhost240928120001,B02,1,1,4,4,f00


## Filter combined dataframe for only DMSO treated cells and split into training and testing dataframes

In [6]:
# Set random state as 0 (same as the rest of the notebook)
random_state = 0

# Set the ratio of the test data to 30% (training data will be 70%)
test_ratio = 0.30

# Filter only the DMSO treated cells
DMSO_combined_df = combined_df[combined_df.Metadata_treatment == "DMSO"]

# Split the combined data into training and test
training_data, testing_data = train_test_split(
    DMSO_combined_df,
    test_size=test_ratio,
    stratify=DMSO_combined_df[["Metadata_cell_type"]],
    random_state=random_state,
)

# View shapes and example output
print("The testing data contains", testing_data.shape[0], "single-cells.")
print("The training data contains", training_data.shape[0], "single-cells.")

# Save training and test data
training_data.to_parquet(output_dir / "combined_batch1_train.parquet")
testing_data.to_parquet(output_dir / "combined_batch1_test.parquet")

testing_data.head()

The testing data contains 1991 single-cells.
The training data contains 4644 single-cells.


Unnamed: 0,Cells_Texture_Correlation_PM_3_03_256,Cells_Intensity_IntegratedIntensityEdge_PM,Cells_Texture_Correlation_Hoechst_3_02_256,Cells_RadialDistribution_RadialCV_PM_4of4,Nuclei_Intensity_IntegratedIntensityEdge_Mitochondria,Nuclei_Texture_Correlation_Hoechst_3_03_256,Cytoplasm_Texture_InfoMeas1_Hoechst_3_00_256,Nuclei_AreaShape_Zernike_3_1,Nuclei_Intensity_MADIntensity_PM,Cells_AreaShape_Zernike_6_0,...,Metadata_Cells_Location_Center_Y,Metadata_Image_Count_Cells,Metadata_ImageNumber,Metadata_Plate,Metadata_Well,Metadata_Cells_Number_Object_Number,Metadata_Cytoplasm_Parent_Cells,Metadata_Cytoplasm_Parent_Nuclei,Metadata_Nuclei_Number_Object_Number,Metadata_Site
20878,0.643089,-0.987517,0.212148,-0.02332,-0.747686,0.998117,0.256104,-0.750681,-0.495985,-0.575258,...,695.687922,16,754,localhost240927060001,E11,13,13,18,18,f10
39718,0.515605,-0.20037,1.380977,0.836508,0.23717,0.785418,0.013624,-0.565253,0.685659,0.984653,...,735.260532,22,153,localhost240926150001,B11,15,15,24,24,f08
26235,0.661925,-0.151658,0.968283,-0.347531,-0.251971,-0.852752,-0.054869,-0.840579,0.334713,0.651846,...,823.755518,16,142,localhost240927120001,B11,14,14,24,24,f19
19207,0.260102,0.005521,-0.014662,-1.28144,0.171277,-0.625947,-1.738076,-0.631204,-0.421141,-0.208801,...,635.996658,11,624,localhost240927060001,E05,5,5,14,14,f12
25233,1.133778,2.631057,1.673325,1.032913,4.973389,2.30592,-0.835579,0.740312,1.890899,-0.53251,...,606.757209,23,3,localhost240927120001,B02,14,14,22,22,f10
