# Data Processing and Feature Engineering for Spatial Transcriptomics

> **Objective:** This notebook processes spatial transcriptomics data to prepare node features and edge lists for a dynamic graph-based machine learning model.

### Main Workflow Steps:
1.  **Configuration:** Centralizes all imports, file paths, and key parameters.
2.  **Node Feature Generation:** Processes gene expression data to create a feature vector for each cell (node).
3.  **Edge Inference with Palantir:** Uses Palantir's diffusion-map utilities to compute cell-cell transition probabilities on one sample (`P14_1_3`); these probabilities form the connections (edges) in the graph.
4.  **Edge Filtering:** Analyzes the inferred edge weights and keeps only edges whose transition probability exceeds a fixed threshold.
5.  **Final Data Assembly:** Combines the filtered edges with cell metadata (time, type) to produce the final CSV and feature files ready for model training.

### 1. Configuration and Setup

**Objective:** To import all necessary libraries and define all input/output paths and key parameters in a centralized location. This makes the notebook easier to configure and maintain.

**Inputs:**
-   The required input files (the spatial `.h5ad` file and the cell and cell-type ID maps) can be obtained by running the **data_preparation.ipynb** notebook.

**Outputs:**
-   Defined and created output directories for storing results.
-   Variables containing paths and parameters for use throughout the notebook.

In [None]:
import anndata as ad
import pandas as pd
import numpy as np
import scanpy as sc
import pickle
import os
from palantir.utils import run_diffusion_maps, determine_multiscale_space, compute_kernel

# --- Input Paths ---
# Main data file for all steps
H5AD_SPATIAL_PATH = './data/spatiotemporal_mouse/mouse_spatial.h5ad'
# Cell to numeric ID mapping, derived from the main spatial file
CELL_MAP_PATH = './data/spatiotemporal_mouse/cell_map_spatial.csv'
# CellType to numeric ID mapping
CELLTYPE_MAP_PATH = "./data/spatiotemporal_mouse/celltype_map_spatial.csv"

# --- Output Base Path ---
# All results will be saved relative to this directory
OUTPUT_BASE_DIR = './result/spatiotemporal_mouse/'

# --- Create Output Directories ---
OUTPUT_FEATURES_DIR = os.path.join(OUTPUT_BASE_DIR, 'features')
OUTPUT_EDGES_DIR = os.path.join(OUTPUT_BASE_DIR, 'edges')
OUTPUT_FINAL_DIR = os.path.join(OUTPUT_BASE_DIR, 'final_data')

os.makedirs(OUTPUT_FEATURES_DIR, exist_ok=True)
os.makedirs(OUTPUT_EDGES_DIR, exist_ok=True)
os.makedirs(OUTPUT_FINAL_DIR, exist_ok=True)

# --- Output File Paths ---
HVG_GENE_INDEX_PATH = os.path.join(OUTPUT_FEATURES_DIR, 'HVG_1000_gene_index.csv')
NODE_FEATURES_PKL_PATH = os.path.join(OUTPUT_FEATURES_DIR, 'node_features.pkl')
RAW_EDGES_PATH = os.path.join(OUTPUT_EDGES_DIR, 'palantir_raw_edges.csv')
FILTERED_EDGES_PATH = os.path.join(OUTPUT_EDGES_DIR, 'palantir_filtered_edges.csv')
FINAL_CSV_PATH = os.path.join(OUTPUT_FINAL_DIR, 'ml_ready_data.csv')
FINAL_FEAT_PATH = os.path.join(OUTPUT_FINAL_DIR, 'ml_ready_features.npy')

# --- Key Parameters ---
N_TOP_GENES = 1200
FINAL_GENE_COUNT = 1000
PCA_N_COMPS = 50
PALANTIR_KNN = 30
PALANTIR_ALPHA = 10
EDGE_WEIGHT_THRESHOLD = 0.1
RANDOM_STATE = 42

print("Configuration complete. Output directories are set up.")

Configuration complete. Output directories are set up.


### 2. Node Feature Generation from Gene Expression

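**Objective:** This step processes the entire spatial transcriptomics dataset (all samples) to generate a feature vector for every cell: its log-normalized expression of 1,000 genes, drawn from the highly variable genes plus a fixed list of required genes. These features are used later by the machine learning model.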
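
Merging gene lists through a Python `set` discards their order, so truncating the merged list to a fixed budget can keep a different subset on each run and can silently drop required genes. A minimal sketch of an order-preserving merge follows; `select_genes` is a hypothetical helper, and it assumes the HVG list has already been ranked (for example by normalized dispersion).

```python
def select_genes(required, ranked_hvgs, budget):
    """Return up to `budget` gene IDs: required genes first, then HVGs in rank order.

    dict.fromkeys keeps first-seen insertion order (Python 3.7+) and drops
    duplicates, so a gene that is both required and highly variable appears once.
    """
    if budget < 0:
        raise ValueError("budget must be non-negative")
    merged = dict.fromkeys([*required, *ranked_hvgs])
    return list(merged)[:budget]


# Usage with toy gene IDs: "g2" is both required and highly variable.
print(select_genes(["g9", "g2"], ["g1", "g2", "g3", "g4"], budget=4))
# prints ['g9', 'g2', 'g1', 'g3']
```

Placing the required genes first means they survive truncation as long as there are no more of them than the budget.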


In [12]:
# --- Load Data ---
adata = ad.read_h5ad(H5AD_SPATIAL_PATH)
cell_map_df = pd.read_csv(CELL_MAP_PATH)
cell_name_to_id = dict(zip(cell_map_df['Cell'], cell_map_df['ID']))

print("--- Initial Data ---")
print(adata)

# --- Preprocessing and Gene Selection ---
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=N_TOP_GENES, flavor='cell_ranger')

# --- Ensure Required Genes are Included ---
required_genes = ["Igf2", "Plagl1", "Tcf12", "Wt1", "Srebf1"]
required_genes_upper = [g.upper() for g in required_genes]
gene_id_upper = adata.var["gene_id"].astype(str).str.upper()
required_gene_ids = adata.var[gene_id_upper.isin(required_genes_upper)].index.tolist()
print(f"✅ Required genes found in data: {required_gene_ids}")

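# Combine required genes with HVGs. Required genes go first so truncation never drops
# them; HVGs follow in order of normalized dispersion. (Iteration order over a set is
# not guaranteed, so a set-based union would make the selection non-reproducible.)
hvg_gene_ids = (adata.var[adata.var['highly_variable']]
                .sort_values('dispersions_norm', ascending=False)
                .index.tolist())
combined_gene_ids = list(dict.fromkeys(required_gene_ids + hvg_gene_ids))
final_gene_ids = combined_gene_ids[:FINAL_GENE_COUNT]
adata_hvg = adata[:, final_gene_ids].copy()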
print(f"Total genes selected for features: {adata_hvg.n_vars}")

# --- Save Gene Index ---
gene_symbols = adata_hvg.var.loc[final_gene_ids, 'gene_id'].fillna("").astype(str).tolist()
hvg_df = pd.DataFrame({
    'Gene_ID': final_gene_ids,
    'Gene_Symbol': gene_symbols,
    'HVG_Index': range(len(final_gene_ids))
})
hvg_df.to_csv(HVG_GENE_INDEX_PATH, index=False)
print(f"✅ Gene index file has been saved in the 'features' folder.")

# --- Format and Save Node Features ---
adata_df = pd.DataFrame(
    adata_hvg.X.toarray() if not isinstance(adata_hvg.X, np.ndarray) else adata_hvg.X,
    index=adata_hvg.obs_names,
    columns=adata_hvg.var_names
)

expression_data = {}
for cell_name, row in adata_hvg.obs.iterrows():
    time = row['time']
    if pd.isna(time) or cell_name not in cell_name_to_id:
        continue
    cell_id = cell_name_to_id[cell_name]
    expr_vector = adata_df.loc[cell_name].values.astype(float).tolist()
    expression_data[int(cell_id)] = {int(time): expr_vector}

with open(NODE_FEATURES_PKL_PATH, 'wb') as f:
    pickle.dump(expression_data, f)

print(f"✅ Node features dictionary created with {len(expression_data)} entries.")
print(f"✅ Node features file has been saved in the 'features' folder.")

--- Initial Data ---
AnnData object with n_obs × n_vars = 48909 × 23397
    obs: 'orig.ident', 'nCount_Spatial', 'nFeature_Spatial', 'percent.mt', 'integrated_snn_res.0.6', 'seurat_clusters', 'batch', 'nCount_SCT', 'nFeature_SCT', 'SCT_snn_res.0.4', 'SCT_snn_res.0.6', 'SCT_snn_res.0.8', 'SCT_snn_res.1', 'SCT_snn_res.1.2', 'cell_annotion', 'sub_cell', 'imagerow', 'imagecol', 'sample', 'x', 'y', 'main_celltype', 'group', 'time'
    var: 'gene_id'
    uns: 'batch_colors', 'hvg', 'log1p', 'main_celltype_colors', 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap'
    obsp: 'connectivities', 'distances'
✅ Required genes found in data: ['9734', '13383', '16325', '16842', '18555']
Total genes selected for features: 1000
✅ Gene index file has been saved in the 'features' folder.
✅ Node features dictionary created with 48909 entries.
✅ Node features file has been saved in the 'features' folder.
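
The pickle stores a nested dictionary of the form `{cell_id: {time: expression_vector}}`, with exactly one timepoint per cell. A model loader will often want a dense matrix instead. The sketch below uses a hypothetical helper, `features_to_matrix` (not part of the original pipeline), that stacks the vectors in ascending cell-ID order and returns the IDs and timepoints alongside them.

```python
import numpy as np


def features_to_matrix(expression_data):
    """Convert {cell_id: {time: vector}} into (ids, times, X) arrays.

    Rows of X follow ascending cell ID. Assumes one timepoint per cell,
    which is how the dictionary is built in the cell above.
    """
    ids = np.array(sorted(expression_data), dtype=np.int64)
    times = np.empty(len(ids), dtype=np.int64)
    rows = []
    for k, cell_id in enumerate(ids):
        # Unpacking a one-item view raises ValueError if a cell has != 1 timepoint.
        ((time, vector),) = expression_data[int(cell_id)].items()
        times[k] = time
        rows.append(vector)
    X = np.asarray(rows, dtype=np.float32)
    return ids, times, X


ids, times, X = features_to_matrix({5: {16: [1.0, 2.0]}, 2: {1: [3.0, 4.0]}})
print(ids.tolist(), times.tolist(), X.tolist())
# prints [2, 5] [1, 16] [[3.0, 4.0], [1.0, 2.0]]
```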


### 3. Edge Inference with Palantir

**Objective:** To infer cell-cell relationships (edges) from a representative sample. The complete dataset is loaded and subset to a single sample (`P14_1_3`), and the computationally intensive Palantir analysis is run on this smaller subset (20,639 of 48,909 cells). Because the subset is drawn from the same `.h5ad` file used in Step 2, its cell names match those in the cell ID map, so the resulting edges can later be joined to the node features.
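
`run_diffusion_maps` and `determine_multiscale_space` are imported from `palantir.utils`. As a rough guide to what the multiscale step computes, the sketch below reimplements it under the assumption that it follows Palantir's published formulation: the first (trivial) eigenvector is dropped, the number of components is chosen at the largest eigengap, and each remaining eigenvector is weighted by `lam / (1 - lam)`, the sum of `lam**t` over all diffusion times t = 1, 2, .... The name `multiscale_embedding` is hypothetical; consult the installed Palantir version for its exact rules.

```python
import numpy as np


def multiscale_embedding(eigvals, eigvecs, n_eigs=None):
    """Palantir-style multiscale space (a sketch of the assumed formulation).

    eigvals: 1-D array of diffusion-map eigenvalues in descending order, where
             eigvals[0] is the trivial eigenvalue (1 for a Markov matrix).
    eigvecs: (n_cells, n_components) array of the matching eigenvectors.
    """
    eigvals = np.asarray(eigvals, dtype=float)
    eigvecs = np.asarray(eigvecs, dtype=float)
    gaps = eigvals[:-1] - eigvals[1:]
    if n_eigs is None:
        # Eigengap heuristic: cut at the largest drop between consecutive eigenvalues.
        n_eigs = int(np.argmax(gaps)) + 1
        if n_eigs < 3:
            # Too few components would survive; use the second-largest gap instead.
            n_eigs = int(np.argsort(gaps)[-2]) + 1
    use = np.arange(1, n_eigs)  # skip the trivial first eigenvector
    lam = eigvals[use]
    # lam / (1 - lam) = lam + lam**2 + lam**3 + ...: every diffusion time at once.
    return eigvecs[:, use] * (lam / (1.0 - lam))
```

Weighting by `lam / (1 - lam)` lets slowly decaying components (eigenvalues near 1) dominate, so distances in this space reflect long-range diffusion structure rather than a single arbitrary time step.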

In [16]:
# --- Load Full Dataset ---
adata_full = sc.read_h5ad(H5AD_SPATIAL_PATH)
print("Full dataset loaded:")
print(adata_full)

# --- Filter for the Target Sample ---
# Let's inspect the available samples to ensure we use the correct identifier
available_samples = adata_full.obs['sample'].unique().tolist()
print(f"Available samples in dataset: {available_samples}")

# Define the sample we want to analyze with Palantir
TARGET_SAMPLE = 'P14_1_3'
if TARGET_SAMPLE not in available_samples:
    raise ValueError(f"Target sample '{TARGET_SAMPLE}' not found in the dataset!")

# Create the subset for Palantir analysis; .copy() materializes it so in-place preprocessing does not act on a view of adata_full.
adata_palantir = adata_full[adata_full.obs['sample'] == TARGET_SAMPLE].copy()
print(f"\n--- Created subset for Palantir using sample '{TARGET_SAMPLE}' ---")
print(adata_palantir)

# --- Perform standard single-cell preprocessing on the subset ---
sc.pp.normalize_total(adata_palantir, target_sum=1e4)
sc.pp.log1p(adata_palantir)
sc.pp.highly_variable_genes(adata_palantir, n_top_genes=2000)
# .copy() turns the HVG slice into a real AnnData (not a view) before in-place scaling
adata_palantir = adata_palantir[:, adata_palantir.var['highly_variable']].copy()
sc.pp.scale(adata_palantir)
sc.tl.pca(adata_palantir, n_comps=PCA_N_COMPS)
print(f"\nData subset preprocessed. PCA computed with {PCA_N_COMPS} components.")

# --- Run Diffusion Maps and Compute Transition Matrix ---
pca_df = pd.DataFrame(adata_palantir.obsm["X_pca"], index=adata_palantir.obs_names)
dm_res = run_diffusion_maps(pca_df, n_components=30)
ms_data = determine_multiscale_space(dm_res)

# Compute the affinity kernel and normalize it
kernel = compute_kernel(ms_data, knn=PALANTIR_KNN, alpha=PALANTIR_ALPHA)
row_sums = np.array(kernel.sum(axis=1)).flatten()
row_sums[row_sums == 0] = 1 # Avoid division by zero
trans_probs = kernel.multiply(1.0 / row_sums[:, None])
trans_probs = trans_probs.tocsr()

# --- Construct and Save Edge List ---
trans_probs.eliminate_zeros()
coo = trans_probs.tocoo()  # read (row, col, weight) triples directly from the sparse matrix
rows, cols, weights = coo.row, coo.col, coo.data
cell_names = np.array(pca_df.index)
edges_df = pd.DataFrame({
    "source_cell": cell_names[rows],
    "target_cell": cell_names[cols],
    "weight": weights
})
edges_df.to_csv(RAW_EDGES_PATH, index=False)

print("\n--- Palantir: Results ---")
print(f"Generated {len(edges_df)} raw edges.")
print("Sample of raw edges:")
print(edges_df.head())
print(f"\n✅ Raw edges file has been successfully saved in the 'edges' folder.")

Full dataset loaded:
AnnData object with n_obs × n_vars = 48909 × 23397
    obs: 'orig.ident', 'nCount_Spatial', 'nFeature_Spatial', 'percent.mt', 'integrated_snn_res.0.6', 'seurat_clusters', 'batch', 'nCount_SCT', 'nFeature_SCT', 'SCT_snn_res.0.4', 'SCT_snn_res.0.6', 'SCT_snn_res.0.8', 'SCT_snn_res.1', 'SCT_snn_res.1.2', 'cell_annotion', 'sub_cell', 'imagerow', 'imagecol', 'sample', 'x', 'y', 'main_celltype', 'group', 'time'
    var: 'gene_id'
    uns: 'batch_colors', 'hvg', 'log1p', 'main_celltype_colors', 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap'
    obsp: 'connectivities', 'distances'
Available samples in dataset: ['E20_2_1', 'P1_1_1', 'P4_2_1', 'P14_1_3']

--- Created subset for Palantir using sample 'P14_1_3' ---
AnnData object with n_obs × n_vars = 20639 × 23397
    obs: 'orig.ident', 'nCount_Spatial', 'nFeature_Spatial', 'percent.mt', 'integrated_snn_res.0.6', 'seurat_clusters', 'batch', 'nCount_SCT', 'nFeature_SCT', 'SCT_snn_res.0.4', 'SCT_snn_res.0.6', 'SCT_snn_r

Data subset preprocessed. PCA computed with 50 components.

--- Palantir: Results ---
Generated 761042 raw edges.
Sample of raw edges:
             source_cell            target_cell    weight
0  p14-1.rds:BIN.13310_3  p14-1.rds:BIN.17414_3  0.048961
1  p14-1.rds:BIN.13310_3  p14-1.rds:BIN.26453_3  0.005772
2  p14-1.rds:BIN.13310_3  p14-1.rds:BIN.19162_3  0.027242
3  p14-1.rds:BIN.13310_3  p14-1.rds:BIN.26454_3  0.032049
4  p14-1.rds:BIN.13310_3  p14-1.rds:BIN.10347_3  0.010856

✅ Raw edges file has been successfully saved in the 'edges' folder.
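
The normalization in this cell divides each row of the affinity kernel K by its row sum, giving a row-stochastic transition matrix P = D⁻¹K, where D is the diagonal matrix of row sums; each cell's outgoing edge weights therefore sum to 1. The self-contained sketch below reproduces that step with SciPy on a toy kernel. `kernel_to_edges` is a hypothetical helper, not a Palantir function.

```python
import numpy as np
import scipy.sparse as sp


def kernel_to_edges(kernel, names):
    """Row-normalize an affinity kernel (P = D^-1 K) and return its edge list.

    Returns (source_names, target_names, weights) for every nonzero entry of P.
    Rows whose kernel sum is 0 keep no outgoing edges.
    """
    K = sp.csr_matrix(kernel, dtype=float)
    row_sums = np.asarray(K.sum(axis=1)).ravel()
    row_sums[row_sums == 0] = 1.0  # avoid division by zero for isolated cells
    P = (sp.diags(1.0 / row_sums) @ K).tocsr()  # scales row r by 1 / row_sums[r]
    P.eliminate_zeros()
    P = P.tocoo()
    names = np.asarray(names)
    return names[P.row], names[P.col], P.data


src, dst, w = kernel_to_edges(np.array([[0, 2, 2], [1, 0, 0], [0, 0, 0]]), ["a", "b", "c"])
print(sorted(zip(src.tolist(), dst.tolist(), w.tolist())))
# prints [('a', 'b', 0.5), ('a', 'c', 0.5), ('b', 'a', 1.0)]
```

Because every row sums to 1, the average weight is roughly the number of cells divided by the number of edges. For the run above that is about 20,639 / 761,042 ≈ 0.027 (about 37 edges per cell), so `EDGE_WEIGHT_THRESHOLD = 0.1` keeps only edges nearly four times stronger than average.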


### 4. Edge Weight Analysis and Filtering

**Objective:** To analyze the distribution of the inferred edge weights and discard weak connections, keeping only edges whose weight is strictly greater than `EDGE_WEIGHT_THRESHOLD`. This reduces noise and focuses the downstream analysis on the more confident cell-cell transitions.

In [17]:
# --- Load Raw Edges ---
raw_edges_df = pd.read_csv(RAW_EDGES_PATH)

# --- Analyze Edge Weight Distribution ---
print("--- Edge Weight Analysis ---")
print(f"Total number of raw edges: {len(raw_edges_df)}")
thresholds = [0.001, 0.005, 0.01, 0.05, 0.1, 0.5]
for t in thresholds:
    count = (raw_edges_df["weight"] > t).sum()
    print(f"Number of edges with weight > {t}: {count}")
    
# --- Filter Edges by Weight ---
filtered_edges_df = raw_edges_df[raw_edges_df["weight"] > EDGE_WEIGHT_THRESHOLD]
filtered_edges_df.to_csv(FILTERED_EDGES_PATH, index=False)

print(f"\nFiltering edges with weight > {EDGE_WEIGHT_THRESHOLD}")
print(f"Number of edges remaining after filtering: {len(filtered_edges_df)}")
print("\nSample of filtered edges:")
print(filtered_edges_df.head())
print(f"\n✅ Filtered edges file has been successfully saved in the 'edges' folder.")

--- Edge Weight Analysis ---
Total number of raw edges: 761042
Number of edges with weight > 0.001: 716719
Number of edges with weight > 0.005: 533790
Number of edges with weight > 0.01: 397297
Number of edges with weight > 0.05: 99471
Number of edges with weight > 0.1: 40660
Number of edges with weight > 0.5: 1464

Filtering edges with weight > 0.1
Number of edges remaining after filtering: 40660

Sample of filtered edges:
              source_cell            target_cell    weight
8   p14-1.rds:BIN.13310_3   p14-1.rds:BIN.3440_3  0.140821
20  p14-1.rds:BIN.13310_3  p14-1.rds:BIN.45330_3  0.116539
48  p14-1.rds:BIN.27358_3  p14-1.rds:BIN.29966_3  0.101190
61  p14-1.rds:BIN.27358_3   p14-1.rds:BIN.7459_3  0.192855
70  p14-1.rds:BIN.27358_3   p14-1.rds:BIN.4540_3  0.126427

✅ Filtered edges file has been successfully saved in the 'edges' folder.
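
A global threshold can remove every outgoing edge of some cells, which then never appear as sources in the final data. The sketch below (hypothetical helper `threshold_report`) extends the threshold scan with the number and fraction of source cells that keep at least one edge. This is one way to judge how much of the graph a given `EDGE_WEIGHT_THRESHOLD` preserves.

```python
import pandas as pd


def threshold_report(edges: pd.DataFrame, thresholds) -> pd.DataFrame:
    """For each threshold t, count edges with weight > t and the sources that keep >= 1."""
    n_sources = edges["source_cell"].nunique()
    records = []
    for t in thresholds:
        kept = edges.loc[edges["weight"] > t, "source_cell"]
        n_kept_sources = kept.nunique()
        records.append({
            "threshold": t,
            "edges_kept": len(kept),
            "sources_kept": n_kept_sources,
            "sources_kept_frac": n_kept_sources / n_sources if n_sources else 0.0,
        })
    return pd.DataFrame(records)


# Toy weights: cell "c" has only a weak edge, so it drops out as a source at t = 0.1.
toy = pd.DataFrame({
    "source_cell": ["a", "a", "b", "c"],
    "target_cell": ["b", "c", "a", "a"],
    "weight": [0.6, 0.4, 1.0, 0.05],
})
print(threshold_report(toy, [0.01, 0.1, 0.5]))
```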


### 5. Final Data Assembly for Machine Learning

**Objective:** To combine the filtered edge list with cell metadata (timepoints, cell type labels) and format it into the final structure required by the machine learning model. Because the edges were generated from a subset of the main spatial data, every source and target cell name can be mapped to a numeric ID through the cell map.
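
The mapping in the next cell builds several dictionaries and applies a per-row Python function. For comparison, the same table can be built with vectorized `Series.map` lookups, as sketched below. `build_event_table` is a hypothetical helper, and it assumes `cell_meta` is a DataFrame indexed by unique cell names with `time` and `main_celltype` columns (for example `adata_meta.obs`). As in the notebook, `ts` and `label` are taken from the source cell.

```python
import pandas as pd


def build_event_table(edges, cell_to_id, cell_meta, celltype_to_id):
    """Vectorized (u, i, ts, label) table; ts and label come from the source cell.

    edges:     DataFrame with 'source_cell' and 'target_cell' columns (cell names)
    cell_meta: DataFrame indexed by unique cell names, with 'time' and 'main_celltype'
    """
    # astype(object) avoids categorical dtypes leaking into the lookup chain.
    source_type = edges["source_cell"].map(cell_meta["main_celltype"].astype(object))
    out = pd.DataFrame({
        "u": edges["source_cell"].map(cell_to_id),
        "i": edges["target_cell"].map(cell_to_id),
        "ts": edges["source_cell"].map(cell_meta["time"]),
        "label": source_type.map(celltype_to_id),
    })
    # Any failed lookup becomes NaN (making the column float); drop, then cast back.
    out = out.dropna().astype({"u": "int64", "i": "int64", "label": "int64"})
    return out.reset_index(drop=True)


edges = pd.DataFrame({"source_cell": ["a", "b", "x"], "target_cell": ["b", "a", "a"]})
meta = pd.DataFrame({"time": [16, 16], "main_celltype": ["T1", "T2"]}, index=["a", "b"])
print(build_event_table(edges, {"a": 1, "b": 2}, meta, {"T1": 0, "T2": 1}))
```

The unknown cell `"x"` in the toy input has no ID, so its edge is dropped, mirroring the `dropna()` step in the notebook.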

In [18]:
# --- Load All Necessary Data and Mappings ---
edges_final_df = pd.read_csv(FILTERED_EDGES_PATH)
cell_id_map = pd.read_csv(CELL_MAP_PATH)
celltype_mapping = pd.read_csv(CELLTYPE_MAP_PATH)
adata_meta = ad.read_h5ad(H5AD_SPATIAL_PATH)

print("--- Final Assembly: Loaded Data ---")
print(f"Loaded {len(edges_final_df)} filtered edges.")
print(f"Loaded {len(cell_id_map)} cell ID mappings.")
print(f"Loaded {len(celltype_mapping)} cell type mappings.")

# --- Create Mapping Dictionaries ---
cell_to_id = dict(zip(cell_id_map["Cell"], cell_id_map["ID"]))
id_to_cell = {v: k for k, v in cell_to_id.items()}
# Use the 'main_celltype' column from the main spatial data file
cell_to_celltype = adata_meta.obs["main_celltype"].to_dict()
celltype_to_id = dict(zip(celltype_mapping["Celltype"], celltype_mapping["ID"]))
cell_id_to_time = {
    cid: adata_meta.obs.loc[cname, "time"]
    for cid, cname in id_to_cell.items()
    if cname in adata_meta.obs.index
}

# --- Map IDs, Labels, and Timestamps to the Edge List ---
def map_label_from_id(cell_id):
    cell_name = id_to_cell.get(cell_id)
    cell_type = cell_to_celltype.get(cell_name)
    return celltype_to_id.get(cell_type, np.nan) if cell_type else np.nan

edges_final_df["u"] = edges_final_df["source_cell"].map(cell_to_id)
edges_final_df["i"] = edges_final_df["target_cell"].map(cell_to_id)
edges_final_df["label"] = edges_final_df["u"].map(map_label_from_id)
edges_final_df["ts"] = edges_final_df["u"].map(cell_id_to_time)

# --- Clean and Prepare the DataFrame ---
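df_out = edges_final_df[["u", "i", "ts", "label"]].dropna().copy()
# Missing lookups make these columns float; cast the IDs back to integers after dropping them
df_out = df_out.astype({"u": int, "i": int, "label": int})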

# Shuffle, then stable-sort by time: edges stay in chronological order while edges
# sharing a timestamp appear in random order
df_out = (df_out.sample(frac=1, random_state=RANDOM_STATE)
          .sort_values("ts", kind="stable")
          .reset_index(drop=True))

print("\n--- Data after Mapping and Cleaning ---")
print(f"Dataframe contains {len(df_out)} valid edges.")
print(df_out.head())

# --- Ensure Every Label Appears in the Validation and Test Splits ---
if not df_out.empty:
    labels_needed = set(celltype_to_id.values())
    max_attempts = 100
    success = False
    df_final = df_out # Initialize df_final

    for attempt in range(max_attempts):
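        # Reshuffle only within each timestamp so the chronological order is preserved
        shuffled_df = (df_out.sample(frac=1, random_state=RANDOM_STATE + attempt)
                       .sort_values("ts", kind="stable")
                       .reset_index(drop=True))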
        n = len(shuffled_df)
        split_test = int(n * 0.85)
        split_val = int(n * 0.70)
        
        if split_val < split_test and split_test < n:
            valid_labels = set(shuffled_df.iloc[split_val:split_test]["label"].unique())
            test_labels = set(shuffled_df.iloc[split_test:]["label"].unique())
            
            if labels_needed.issubset(valid_labels) and labels_needed.issubset(test_labels):
                df_final = shuffled_df.copy()
                success = True
                print(f"\n✅ Success: Found a shuffle that balances labels after {attempt+1} attempts.")
                break

    if not success:
        print("\n⚠️ Warning: Could not find a perfect shuffle to balance labels. Using default.")
else:
    df_final = df_out

# --- Save Final CSV and Feature Files ---
df_final["idx"] = range(1, len(df_final) + 1)
df_final.to_csv(FINAL_CSV_PATH, index=False)
feat_array = np.ones((df_final.shape[0], 1), dtype=np.float32)
np.save(FINAL_FEAT_PATH, feat_array)

print(f"\n✅ Final CSV data has been successfully saved in the 'final_data' folder.")
print(f"✅ Feature matrix has been successfully saved in the 'final_data' folder.")
print(f"✅ Final files contain {df_final.shape[0]} rows.")

--- Final Assembly: Loaded Data ---
Loaded 40660 filtered edges.
Loaded 48909 cell ID mappings.
Loaded 3 cell type mappings.

--- Data after Mapping and Cleaning ---
Dataframe contains 40660 valid edges.
       u      i  ts  label
0  48335  34913  16      2
1  34175  48164  16      2
2  39638  35648  16      2
3  46097  35276  16      2
4  40042  39669  16      2

✅ Success: Found a shuffle that balances labels after 1 attempts.

✅ Final CSV data has been successfully saved in the 'final_data' folder.
✅ Feature matrix has been successfully saved in the 'final_data' folder.
✅ Final files contain 40660 rows.
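
The validation and test slices checked above are positional: rows `[0.70n, 0.85n)` and `[0.85n, n)`. Many temporal-graph training pipelines expect edges in non-decreasing `ts` order, and a shuffle that randomizes only among edges sharing a timestamp keeps that order intact. The sketch below shows both pieces with hypothetical helpers `chronological_shuffle` and `splits_cover_labels`; whether the downstream model actually requires chronological order is an assumption to check against its data loader.

```python
import pandas as pd


def chronological_shuffle(df: pd.DataFrame, seed: int) -> pd.DataFrame:
    """Randomly permute rows, then stable-sort by 'ts'.

    The stable sort keeps the random order among rows with equal 'ts', so the
    result is chronological but shuffled within each timestamp.
    """
    return (df.sample(frac=1, random_state=seed)
              .sort_values("ts", kind="stable")
              .reset_index(drop=True))


def splits_cover_labels(df: pd.DataFrame, labels, val_start=0.70, test_start=0.85) -> bool:
    """True if every label appears in both the validation and the test slice."""
    n = len(df)
    a, b = int(n * val_start), int(n * test_start)
    needed = set(labels)
    return (needed <= set(df["label"].iloc[a:b])
            and needed <= set(df["label"].iloc[b:]))


toy = pd.DataFrame({"ts": [3, 1, 2, 1, 3, 2], "label": [0, 1, 2, 0, 1, 2]})
print(chronological_shuffle(toy, seed=0)["ts"].tolist())
# prints [1, 1, 2, 2, 3, 3]
```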
