# Split each plate from the batch data into training, testing, and holdout data

In [1]:
import pathlib
import random

import pprint

import pandas as pd
from sklearn.model_selection import train_test_split

## Set paths and variables

In [2]:
# Set random state for the whole notebook to ensure reproducibility
random_state=0
random.seed(random_state)

# Path to directory with feature selected profiles
path_to_feature_selected_data = pathlib.Path(
    "../3.preprocessing_features/data/single_cell_profiles/"
).resolve(strict=True)

# Find all feature selected parquet files
feature_selected_files = list(path_to_feature_selected_data.glob("*_feature_selected.parquet"))

# Make directory for split data
output_dir = pathlib.Path("./data")
output_dir.mkdir(exist_ok=True)

## Load in feature selected data

In [3]:
# Load in all feature selected files as dataframes
feature_selected_dfs_dict = {
    pathlib.Path(file).stem.split('_')[0]: pd.read_parquet(file) for file in feature_selected_files
}

pprint.pprint(feature_selected_dfs_dict, indent=4)

{   'localhost240926150001':       Metadata_WellRow  Metadata_WellCol  Metadata_heart_number  \
0                    B                 2                      7   
1                    B                 2                      7   
2                    B                 2                      7   
3                    B                 2                      7   
4                    B                 2                      7   
...                ...               ...                    ...   
16561                G                11                     19   
16562                G                11                     19   
16563                G                11                     19   
16564                G                11                     19   
16565                G                11                     19   

      Metadata_cell_type Metadata_heart_failure_type Metadata_treatment  \
0                healthy                        None               DMSO   
1               

## For each dataframe, take only the DMSO cells and split 70/30 for training and testing

In [4]:
# Set the ratio of the test data to 30% (training data will be 70%)
test_ratio = 0.30

for plate, df in feature_selected_dfs_dict.items():
    # Filter only the rows with DMSO treatment for model training and testing
    DMSO_df = df[df.Metadata_treatment == "DMSO"]
    print(f"Plate: {plate} contains {DMSO_df.shape[0]} DMSO profiles")

    # Split data into training and test sets
    train_df, test_df = train_test_split(
        DMSO_df,
        test_size=test_ratio,
        stratify=DMSO_df[["Metadata_cell_type"]],
        random_state=random_state,
    )

    # Print the shapes of the training and testing data
    print(f"Training data shape: {train_df.shape}")
    print(f"Testing data shape: {test_df.shape}")

    # Save training and test data
    train_df.to_parquet(output_dir / f"{plate}_train.parquet")
    test_df.to_parquet(output_dir / f"{plate}_test.parquet")

Plate: localhost240928120001 contains 1573 DMSO profiles
Training data shape: (1101, 641)
Testing data shape: (472, 641)
Plate: localhost240927060001 contains 1526 DMSO profiles
Training data shape: (1068, 652)
Testing data shape: (458, 652)
Plate: localhost240927120001 contains 1514 DMSO profiles
Training data shape: (1059, 684)
Testing data shape: (455, 684)
Plate: localhost240926150001 contains 2022 DMSO profiles
Training data shape: (1415, 657)
Testing data shape: (607, 657)


## Combine the 4 plates together using the common morphology features

In [5]:
# Assuming output_dir is already defined
train_files = list(output_dir.glob("*_train.parquet"))
test_files = list(output_dir.glob("*_test.parquet"))

# Load files
train_dfs = [pd.read_parquet(f) for f in train_files]
test_dfs = [pd.read_parquet(f) for f in test_files]
all_dfs = train_dfs + test_dfs

# Get intersection of feature columns (excluding Metadata_) across the dataframes
common_features = set.intersection(*[
    set(df.columns[~df.columns.str.startswith("Metadata_")]) for df in all_dfs
])
print(len(common_features), "common features across all dataframes")

# Use metadata columns from first df
metadata_cols = [col for col in all_dfs[0].columns if col.startswith("Metadata_")]
all_cols = metadata_cols + sorted(common_features)

# Reindex with consistent columns
train_dfs = [df.reindex(columns=all_cols) for df in train_dfs]
test_dfs = [df.reindex(columns=all_cols) for df in test_dfs]

# Merge and save
combined_train_df = pd.concat(train_dfs, ignore_index=True)
combined_test_df = pd.concat(test_dfs, ignore_index=True)

combined_train_df.to_parquet(output_dir / "combined_batch1_train.parquet", index=False)
combined_test_df.to_parquet(output_dir / "combined_batch1_test.parquet", index=False)

print("Train shape:", combined_train_df.shape)
print("Test shape:", combined_test_df.shape)

# Print on dataframe to verify
combined_train_df.head()

474 common features across all dataframes
Train shape: (9286, 494)
Test shape: (3984, 494)


Unnamed: 0,Metadata_WellRow,Metadata_WellCol,Metadata_heart_number,Metadata_cell_type,Metadata_heart_failure_type,Metadata_treatment,Metadata_Pathway,Metadata_Nuclei_Location_Center_X,Metadata_Nuclei_Location_Center_Y,Metadata_Cells_Location_Center_X,...,Nuclei_Texture_InfoMeas2_PM_3_03_256,Nuclei_Texture_InverseDifferenceMoment_Hoechst_3_00_256,Nuclei_Texture_InverseDifferenceMoment_Hoechst_3_01_256,Nuclei_Texture_InverseDifferenceMoment_Hoechst_3_02_256,Nuclei_Texture_InverseDifferenceMoment_Hoechst_3_03_256,Nuclei_Texture_InverseDifferenceMoment_PM_3_00_256,Nuclei_Texture_InverseDifferenceMoment_PM_3_01_256,Nuclei_Texture_InverseDifferenceMoment_PM_3_02_256,Nuclei_Texture_InverseDifferenceMoment_PM_3_03_256,Nuclei_Texture_SumEntropy_PM_3_01_256
0,B,2,7,healthy,,DMSO,,586.110856,272.986789,585.812571,...,-0.097983,0.686213,0.0992,0.346359,0.410444,0.592052,0.660238,0.824522,0.602509,-0.493653
1,E,2,19,failing,dilated_cardiomyopathy,DMSO,,544.012459,1015.857049,540.169864,...,-1.011017,-1.032992,-0.807072,0.540309,-0.993311,0.464286,0.844515,0.388414,0.18837,-0.576802
2,E,2,19,failing,dilated_cardiomyopathy,DMSO,,767.384844,203.437593,777.096195,...,-0.769505,-1.469561,0.443859,-1.336991,-1.940503,-0.463749,0.336024,-0.436134,-1.128984,0.109277
3,E,2,19,failing,dilated_cardiomyopathy,DMSO,,937.835383,725.915984,892.558337,...,0.38065,0.876283,-0.108488,-0.453235,0.588815,0.258823,0.417707,-0.923615,-1.161158,0.552393
4,E,2,19,failing,dilated_cardiomyopathy,DMSO,,940.06644,420.07155,924.417976,...,-0.848936,0.75623,1.225858,0.875755,0.919894,0.74193,1.372527,0.573394,-0.009191,-0.26512
