# ML-ImmuneProfiler

## Single-Cell RNA-seq Cell Type Classification: Notebook 2

## Dataset preparation and splitting for training and testing (utility script)

- Date: 2025-04-17

This notebook implements a comprehensive data processing pipeline for single-cell RNA sequencing analysis, including:

1. **Environment Setup**: Imports necessary libraries and configures warning suppressions
2. **Directory Structure**: Establishes file paths for data, results, models, and plots
3. **Data Loading**: Loads preprocessed gene expression data with the top 100 highly variable genes
4. **Dataset Preparation**: 
    - Converts data to pandas DataFrames
    - Implements a function to prepare supervised learning datasets
    - Encodes cell type labels and performs stratified train/validation/test splits (70%/15%/15%)
    - Calculates class weights for handling imbalanced cell populations
5. **Validation**: Verifies dataset integrity by checking class consistency across splits
6. **Persistence**: Saves the processed dataset for use in subsequent analysis notebooks

This notebook serves as a utility script to ensure standardized data preparation for all machine learning models used in this project.


In [None]:
# ------ Module Import ------ #

# Default imports
import os
import pickle
import joblib
from collections import Counter

# Warning suppression
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
warnings.filterwarnings("ignore", category=FutureWarning, module="seaborn")
warnings.filterwarnings("ignore", message="n_jobs value .* overridden to .* by setting random_state.*")

# External imports
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import LabelEncoder




Set up file paths and the random seed

In [3]:
# Global Variables
RAND_SEED: int = 16         # it's my birthday :)

DATA_DIR: str = os.path.join(os.path.dirname(os.getcwd()), 'Data')
RESULTS_DIR: str = os.path.join(os.path.dirname(os.getcwd()), 'Results')
PLOT_DIR: str = os.path.join(os.path.dirname(os.getcwd()), 'Plots')
MODEL_DIR: str = os.path.join(os.path.dirname(os.getcwd()), 'Models')
# Create directories if they do not exist
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(PLOT_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

## Data Processing Pipeline: From Preprocessed Data to Machine-Learning-Ready Datasets

This section describes the data loading and preparation steps for our single-cell RNA sequencing analysis:

1. **Dataset Loading**: We load preprocessed gene expression data containing the top 100 highly variable genes (HVGs) from pickle files along with corresponding metadata.

2. **Data Organization**: The data is converted into pandas DataFrames (`df_100hgv` for expression values and `df_100hgv_meta` for cell type annotations).

3. **Data Preparation**: The `prepare_supervised_dataset()` function:
    - Separates features (gene expression) and labels (cell types)
    - Encodes cell type labels numerically
    - Performs stratified train/validation/test splits (70%/15%/15%)
    - Ensures class consistency across splits
    - Calculates class weights to handle imbalanced cell populations
    - Returns a comprehensive dictionary with all dataset components

4. **Dataset Verification**: The integrity of the splits is verified by:
    - Confirming consistent cell type representation across all splits
    - Analyzing class distribution ratios
    - Checking for potential cross-validation issues
    - Ensuring sufficient samples per class in each split

5. **Dataset Exploration**: The final dataset contains:
    - 53,825 training samples
    - 11,534 validation samples
    - 11,534 test samples
    - 14 distinct cell types with varying representations
    - 100 gene features per sample
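
The 70%/15%/15% split is produced in two stages (test first, then validation out of the remainder), so the second stage needs a rescaled fraction. The sketch below works through that arithmetic for the 76,893 cells in this dataset. `two_stage_split_sizes` is a hypothetical helper, not part of the notebook, and it assumes each held-out set is sized by rounding `fraction * n` up, which is how scikit-learn is understood to treat a fractional `test_size`.

```python
import math

def two_stage_split_sizes(n_samples: int, test_size: float = 0.15, val_size: float = 0.15):
    """Return (n_train, n_val, n_test) for a test split followed by a validation split.

    Assumption: each held-out set is sized as ceil(fraction * n).
    """
    n_test = math.ceil(test_size * n_samples)
    n_temp = n_samples - n_test
    # The validation set is carved out of the remaining 85%, so rescale the fraction
    val_fraction = val_size / (1 - test_size)
    n_val = math.ceil(val_fraction * n_temp)
    n_train = n_temp - n_val
    return n_train, n_val, n_test

sizes = two_stage_split_sizes(76893)
print(sizes)  # (53825, 11534, 11534)
```

These are the same sizes listed above, which suggests the rescaled fraction `val_size / (1 - test_size)` is doing what is intended.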



In [34]:
# Use context managers so the file handles are closed after loading
with open(os.path.join(DATA_DIR, 'supervised_data_100hvg.pkl'), 'rb') as f:
    dataset_100hgv = pickle.load(f)
with open(os.path.join(DATA_DIR, 'supervised_data_100hvg_metadata.pkl'), 'rb') as f:
    dataset_100hgv_metadata = pickle.load(f)

In [35]:
# Convert the dataset to a pandas DataFrame
df_100hgv = pd.DataFrame(dataset_100hgv)
df_100hgv_meta = pd.DataFrame(dataset_100hgv_metadata)


In [36]:
df_100hgv.head()

| PPBP | LYZ | HLA-DRA | S100A9 | CD74 | GNLY | NKG7 | CTSS | S100A8 | S100A4 | CCL5 | ... | KLRD1 | NPC2 | NAP1L1 | TSPO | ANXA2 | DUSP6 | LY6E | TKT | NOSIP | CellType |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| -0.067185 | 4.764285 | 2.957022 | 5.596175 | 2.148582 | -0.07227 | -0.047454 | 4.362393 | 4.969619 | 3.909285 | -0.640449 | ... | -0.090622 | 1.951062 | 0.308107 | 2.554441 | 0.78935 | 2.337746 | 0.82807 | 2.484555 | 1.105315 | CD14 Mono |
| -0.090013 | 4.147419 | 3.447765 | 3.366856 | 3.755852 | 0.005502 | -0.040562 | 3.817119 | 2.211017 | 3.731036 | 0.063988 | ... | 0.035326 | 2.260142 | 1.248083 | 1.380529 | 1.897829 | 2.36006 | 1.897827 | 2.581988 | 0.788886 | CD14 Mono |
| 0.01369 | 4.705282 | 2.494352 | 5.02175 | 2.587237 | -0.030905 | 0.0133 | 4.060295 | 4.222051 | 2.973074 | 0.286722 | ... | -0.005076 | 1.169653 | 1.795462 | 1.560957 | 1.345841 | 2.453415 | 0.918939 | 2.257665 | 0.656441 | CD14 Mono |
| 0.05879 | 4.741841 | 0.769762 | 5.458511 | 1.660862 | 0.656104 | -0.062104 | 3.239459 | 4.821338 | 3.688125 | 0.423983 | ... | 0.008045 | 0.84213 | 0.536949 | 2.514377 | 1.60481 | 0.564361 | 1.885159 | 1.595583 | 1.253817 | CD14 Mono |
| 0.005776 | 5.081289 | 2.117163 | 5.781392 | 2.343518 | -0.046718 | 0.114207 | 3.887106 | 5.049965 | 3.536385 | 0.008269 | ... | -0.031938 | 0.888212 | 1.286644 | 1.963565 | 1.736806 | 2.420097 | 0.39423 | 2.42513 | -0.022348 | CD14 Mono |


In [37]:
df_100hgv_meta

|  | dataset_name | creation_date | num_features | num_samples | num_classes | feature_type | class_distribution |
|---|---|---|---|---|---|---|---|
| B intermediate | supervised_training_data_100hvg | 2025-04-17 08:38:28 | 100 | 76893 | 14 | Highly Variable Genes (HVGs) | 3483 |
| B memory | supervised_training_data_100hvg | 2025-04-17 08:38:28 | 100 | 76893 | 14 | Highly Variable Genes (HVGs) | 2193 |
| B naive | supervised_training_data_100hvg | 2025-04-17 08:38:28 | 100 | 76893 | 14 | Highly Variable Genes (HVGs) | 7656 |
| CD14 Mono | supervised_training_data_100hvg | 2025-04-17 08:38:28 | 100 | 76893 | 14 | Highly Variable Genes (HVGs) | 9670 |
| CD16 Mono | supervised_training_data_100hvg | 2025-04-17 08:38:28 | 100 | 76893 | 14 | Highly Variable Genes (HVGs) | 6351 |
| CD4 Naive | supervised_training_data_100hvg | 2025-04-17 08:38:28 | 100 | 76893 | 14 | Highly Variable Genes (HVGs) | 7114 |
| CD4 TCM | supervised_training_data_100hvg | 2025-04-17 08:38:28 | 100 | 76893 | 14 | Highly Variable Genes (HVGs) | 9250 |
| CD4 TEM | supervised_training_data_100hvg | 2025-04-17 08:38:28 | 100 | 76893 | 14 | Highly Variable Genes (HVGs) | 1335 |
| CD8 Naive | supervised_training_data_100hvg | 2025-04-17 08:38:28 | 100 | 76893 | 14 | Highly Variable Genes (HVGs) | 9072 |
| CD8 TCM | supervised_training_data_100hvg | 2025-04-17 08:38:28 | 100 | 76893 | 14 | Highly Variable Genes (HVGs) | 2400 |


In [40]:
def prepare_supervised_dataset(
    df: pd.DataFrame,
    label_col: str = 'CellType',
    meta_cols: list = ['CellType'],
    test_size: float = 0.15,
    val_size: float = 0.15,
    random_state: int = RAND_SEED
):
    """
    Prepares a supervised learning dataset from single-cell data.
    
    Returns:
        Dictionary containing:
            - X_train, X_val, X_test: pd.DataFrame (float32 expression values)
            - y_train, y_val, y_test: np.ndarray (encoded)
            - meta_train, meta_val, meta_test: pd.DataFrame
            - label_encoder: fitted LabelEncoder
            - class_weights: dict
            - feature_genes: List of genes used
    """
    print(" Starting supervised dataset preparation...")

    # Step 1: Define expression and label columns
    expr_cols = [col for col in df.columns if col not in meta_cols]
    X_expr = df[expr_cols].astype(np.float32)
    y_raw = df[label_col]

    # Step 2: We'll use all the available HVGs (already selected in the dataframe)
    print(f" Using {len(expr_cols)} genes in the dataset...")
    X = X_expr

    # Step 3: Encode labels
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y_raw)
    label_names = label_encoder.classes_
    print(f" Found {len(label_names)} unique cell types.")

    # Step 4: Stratified splits
    print(" Splitting into train/val/test sets...")
    X_temp, X_test, y_temp, y_test, meta_temp, meta_test = train_test_split(
        X, y_encoded, df[meta_cols], test_size=test_size,
        stratify=y_encoded, random_state=random_state
    )
    val_fraction = val_size / (1 - test_size)
    X_train, X_val, y_train, y_val, meta_train, meta_val = train_test_split(
        X_temp, y_temp, meta_temp, test_size=val_fraction,
        stratify=y_temp, random_state=random_state
    )

    # Step 5: Keep only classes present in all three splits (fix for cuML)
    print(" Aligning common classes across splits (cuML compatibility)...")
    common_classes = np.intersect1d(
        np.intersect1d(np.unique(y_train), np.unique(y_val)), np.unique(y_test)
    )
    mask_train = np.isin(y_train, common_classes)
    mask_val = np.isin(y_val, common_classes)
    mask_test = np.isin(y_test, common_classes)

    X_train, y_train, meta_train = X_train[mask_train], y_train[mask_train], meta_train.iloc[mask_train]
    X_val, y_val, meta_val = X_val[mask_val], y_val[mask_val], meta_val.iloc[mask_val]
    X_test, y_test, meta_test = X_test[mask_test], y_test[mask_test], meta_test.iloc[mask_test]

    # Step 6: Compute balanced class weights from the filtered training labels
    class_weights_array = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
    class_weights = dict(zip(np.unique(y_train), class_weights_array))

    # Step 7: Display summary
    print("\n Final dataset sizes:")
    print(f"  Train: {X_train.shape[0]} samples")
    print(f"  Val:   {X_val.shape[0]} samples")
    print(f"  Test:  {X_test.shape[0]} samples")

    print("\n Class distribution in Train:")
    for cls, count in sorted(Counter(y_train).items()):
        print(f"  {label_names[cls]:<25}: {count} cells")

    print("\n Class Weights:")
    for cls, weight in class_weights.items():
        print(f"  {label_names[cls]:<25}: {weight:.2f}")

    return {
        "X_train": X_train, "X_val": X_val, "X_test": X_test,
        "y_train": y_train, "y_val": y_val, "y_test": y_test,
        "meta_train": meta_train, "meta_val": meta_val, "meta_test": meta_test,
        "label_encoder": label_encoder,
        "class_weights": class_weights,
        "feature_genes": list(expr_cols)
    }
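
Step 5 of the function filters each split down to a set of shared classes. With this dataset every class survives, so the effect of the masks is not visible in the output. The toy example below uses made-up labels in which one class is missing from the test split, and intersects all three splits. `keep_common_classes` is a hypothetical helper written for illustration only.

```python
from functools import reduce

import numpy as np

def keep_common_classes(*label_arrays):
    """Return the shared classes and one boolean mask per input label array."""
    common = reduce(np.intersect1d, [np.unique(y) for y in label_arrays])
    masks = [np.isin(y, common) for y in label_arrays]
    return common, masks

# Toy labels (made up): class 2 is absent from the test split
y_train = np.array([0, 0, 1, 1, 2, 2])
y_val = np.array([0, 1, 2])
y_test = np.array([0, 1, 1])

common, (m_train, m_val, m_test) = keep_common_classes(y_train, y_val, y_test)
print(common)            # classes present in all three splits
print(y_train[m_train])  # training labels after dropping class 2
```

The same boolean masks can be applied to the feature matrix and metadata of each split so that rows stay aligned with their labels.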


In [41]:
spv_dataset = prepare_supervised_dataset(
    df_100hgv,
    label_col='CellType',
    meta_cols=['CellType'],
    test_size=0.15,
    val_size=0.15,
    random_state=RAND_SEED
)

joblib.dump(spv_dataset, os.path.join(DATA_DIR, 'spv_split_dataset_100hvg.pkl'))
print(f"Saved dataset to {os.path.join(DATA_DIR, 'spv_split_dataset_100hvg.pkl')}")

 Starting supervised dataset preparation...
 Using 100 genes in the dataset...
 Found 14 unique cell types.
 Splitting into train/val/test sets...
 Aligning common classes across splits (cuML compatibility)...

 Final dataset sizes:
  Train: 53825 samples
  Val:   11534 samples
  Test:  11534 samples

 Class distribution in Train:
  B intermediate           : 2438 cells
  B memory                 : 1535 cells
  B naive                  : 5360 cells
  CD14 Mono                : 6769 cells
  CD16 Mono                : 4445 cells
  CD4 Naive                : 4980 cells
  CD4 TCM                  : 6475 cells
  CD4 TEM                  : 935 cells
  CD8 Naive                : 6350 cells
  CD8 TCM                  : 1680 cells
  CD8 TEM                  : 6045 cells
  Dendritic cells          : 1838 cells
  NK                       : 4259 cells
  Treg                     : 716 cells

 Class Weights:
  B intermediate           : 1.58
  B memory                 : 2.50
  B naive               
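
The class weights printed above follow the `'balanced'` heuristic, `n_samples / (n_classes * count)`. As a check, the sketch below recomputes the two weights visible in the output from the training counts printed in the same cell. It uses plain Python rather than calling `compute_class_weight`, and the variable names are illustrative.

```python
# Training-set counts copied from the "Class distribution in Train" output
train_counts = {
    "B intermediate": 2438, "B memory": 1535, "B naive": 5360,
    "CD14 Mono": 6769, "CD16 Mono": 4445, "CD4 Naive": 4980,
    "CD4 TCM": 6475, "CD4 TEM": 935, "CD8 Naive": 6350,
    "CD8 TCM": 1680, "CD8 TEM": 6045, "Dendritic cells": 1838,
    "NK": 4259, "Treg": 716,
}

n_samples = sum(train_counts.values())  # 53825 training cells
n_classes = len(train_counts)           # 14 cell types

# 'balanced' heuristic: rarer classes receive proportionally larger weights
class_weights = {
    name: n_samples / (n_classes * count) for name, count in train_counts.items()
}

for name in ("B intermediate", "B memory"):
    print(f"{name:<25}: {class_weights[name]:.2f}")
```

This prints 1.58 and 2.50, matching the first two weights in the cell output.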

## Load the Dataset and Verify Integrity

This section reloads the split dataset saved in the previous step and runs a set of integrity checks on it. The code:

1. **Loads the dataset** from the joblib-serialized file `spv_split_dataset_100hvg.pkl`, which holds the stratified train/validation/test splits and their supporting components:
    - Feature matrices (X_train, X_val, X_test) containing gene expression data
    - Target vectors (y_train, y_val, y_test) with encoded cell type labels
    - Metadata dataframes with original cell type names
    - Label encoder for mapping between numeric and text labels
    - Class weights to handle imbalanced data distributions
    - Feature genes list (100 highly variable genes)

2. **Verifies dataset integrity** through:
    - Checking class consistency across all splits
    - Analyzing class distribution proportions across splits
    - Identifying potential cross-validation issues
    - Ensuring sufficient samples per class
    - Validating the stratified splitting ratios (70%/15%/15%)

3. **Examines dataset characteristics** including:
    - Dataset shapes and dimensions
    - Class encoding mappings
    - Class distributions and imbalances
    - Feature gene information
    - Data quality checks (e.g., detecting NaN values)
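
One way the data quality checks listed above might look is sketched below. It assumes the dictionary layout returned by `prepare_supervised_dataset()` (`X_train`, `y_train`, and so on). `check_split_quality` is a hypothetical helper, and the toy dictionary stands in for `spv_dataset` with made-up values.

```python
import numpy as np

def check_split_quality(dataset: dict) -> dict:
    """Report NaN/inf counts and feature/label length agreement for each split."""
    report = {}
    for split in ("train", "val", "test"):
        X = np.asarray(dataset[f"X_{split}"], dtype=np.float32)
        y = np.asarray(dataset[f"y_{split}"])
        report[split] = {
            "n_nan": int(np.isnan(X).sum()),
            "n_inf": int(np.isinf(X).sum()),
            "lengths_match": X.shape[0] == y.shape[0],
        }
    return report

# Toy stand-in for spv_dataset (values are made up; one NaN planted in train)
toy = {
    "X_train": np.array([[0.1, 1.2], [np.nan, 0.4]]), "y_train": np.array([0, 1]),
    "X_val": np.array([[0.3, 0.9]]), "y_val": np.array([0]),
    "X_test": np.array([[2.0, 0.0]]), "y_test": np.array([1]),
}
for split, stats in check_split_quality(toy).items():
    print(split, stats)
```

Because `np.asarray` accepts both DataFrames and arrays, the same helper should work whether the feature matrices are stored as pandas or NumPy objects.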



In [None]:
# load the dataset
spv_dataset = joblib.load(os.path.join(DATA_DIR, 'spv_split_dataset_100hvg.pkl'))

# verify the integrity of cell types in test, train and validation sets
def verify_cell_types_integrity(spv_dataset):
    """
    Verify the integrity of cell types in the train, validation, and test sets.
    """
    print("Verifying cell types integrity...")
    
    # Check unique classes in each split
    for split in ['train', 'val', 'test']:
        y_split = spv_dataset[f'y_{split}']
        unique_classes = np.unique(y_split)
        print(f"Unique classes in {split} set: {unique_classes}")
    
    # Check for common classes across all splits
    train_classes = set(np.unique(spv_dataset['y_train']))
    val_classes = set(np.unique(spv_dataset['y_val']))
    test_classes = set(np.unique(spv_dataset['y_test']))
    
    common_classes = train_classes.intersection(val_classes, test_classes)
    print(f"Common classes across all splits: {sorted(list(common_classes))}")
    
    # Check class distributions
    label_encoder = spv_dataset['label_encoder']
    class_names = label_encoder.classes_
    
    print("\nClass distribution ratios (% of total):")
    print(f"{'Class':<15} {'Train %':<10} {'Val %':<10} {'Test %':<10} {'Total %':<10}")
    print("-" * 55)
    
    total_train = len(spv_dataset['y_train'])
    total_val = len(spv_dataset['y_val'])
    total_test = len(spv_dataset['y_test'])
    total_all = total_train + total_val + total_test
    
    for i, class_name in enumerate(class_names):
        train_count = np.sum(spv_dataset['y_train'] == i)
        val_count = np.sum(spv_dataset['y_val'] == i)
        test_count = np.sum(spv_dataset['y_test'] == i)
        total_count = train_count + val_count + test_count
        
        train_ratio = train_count / total_train * 100
        val_ratio = val_count / total_val * 100
        test_ratio = test_count / total_test * 100
        total_ratio = total_count / total_all * 100
        
        print(f"{class_name:<15} {train_ratio:>8.2f}% {val_ratio:>8.2f}% {test_ratio:>8.2f}% {total_ratio:>8.2f}%")
    
    # Check dataset split ratios
    print("\nDataset split ratios:")
    print(f"Train: {total_train / total_all * 100:.2f}% ({total_train} samples)")
    print(f"Val:   {total_val / total_all * 100:.2f}% ({total_val} samples)")
    print(f"Test:  {total_test / total_all * 100:.2f}% ({total_test} samples)")
    
    # Check for potential cross-validation issues
    print("\nPotential CV issues assessment:")
    
    # Check if any classes are missing from any split
    if len(train_classes) != len(class_names) or len(val_classes) != len(class_names) or len(test_classes) != len(class_names):
        print("Warning: Some classes are missing from one or more splits!")
        print(f"Missing from train: {sorted(list(set(range(len(class_names))) - train_classes))}")
        print(f"Missing from val: {sorted(list(set(range(len(class_names))) - val_classes))}")
        print(f"Missing from test: {sorted(list(set(range(len(class_names))) - test_classes))}")
    else:
        print("✓ All classes are represented in each split")
    
    # Check for classes with fewer than 5 samples in any split
    class_balance_issues = []
    for i, class_name in enumerate(class_names):
        train_count = np.sum(spv_dataset['y_train'] == i)
        val_count = np.sum(spv_dataset['y_val'] == i)
        test_count = np.sum(spv_dataset['y_test'] == i)
        
        if train_count < 5 or val_count < 5 or test_count < 5:
            class_balance_issues.append(f"{class_name} (train={train_count}, val={val_count}, test={test_count})")
    
    if class_balance_issues:
        print("Warning: Some classes have very few samples in some splits:")
        for issue in class_balance_issues:
            print(f"  - {issue}")
    else:
        print("✓ All classes have sufficient samples in each split")
    
    # Check split ratios per class
    print("\nSplit ratios per class:")
    print(f"{'Class':<15} {'Train %':<10} {'Val %':<10} {'Test %':<10}")
    print("-" * 45)
    
    for i, class_name in enumerate(class_names):
        train_count = np.sum(spv_dataset['y_train'] == i)
        val_count = np.sum(spv_dataset['y_val'] == i)
        test_count = np.sum(spv_dataset['y_test'] == i)
        class_total = train_count + val_count + test_count
        
        class_train_ratio = train_count / class_total * 100
        class_val_ratio = val_count / class_total * 100
        class_test_ratio = test_count / class_total * 100
        
        print(f"{class_name:<15} {class_train_ratio:>8.2f}% {class_val_ratio:>8.2f}% {class_test_ratio:>8.2f}%")
    
    print("\nIntegrity verification complete!")

verify_cell_types_integrity(spv_dataset)

Verifying cell types integrity...
Unique classes in train set: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13]
Unique classes in val set: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13]
Unique classes in test set: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13]
Common classes across all splits: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]

Class distribution ratios (% of total):
Class           Train %    Val %      Test %     Total %   
-------------------------------------------------------
B intermediate      4.53%     4.53%     4.53%     4.53%
B memory            2.85%     2.85%     2.85%     2.85%
B naive             9.96%     9.95%     9.95%     9.96%
CD14 Mono          12.58%    12.57%    12.58%    12.58%
CD16 Mono           8.26%     8.26%     8.26%     8.26%
CD4 Naive           9.25%     9.25%     9.25%     9.25%
CD4 TCM            12.03%    12.03%    12.03%    12.03%
CD4 TEM             1.74%     1.73%     1.73%     1.74%
CD8 Naive          11.80%    11.80%    11.80%    11.80%
CD8 TCM 