# Phenotype Data

## Library Import

In [None]:
# import json
# import os

# import matplotlib.pyplot as plt
# import numpy as np
# import pandas as pd
# import scanpy as sc
# import seaborn as sns
import torch

# import torch_geometric
from matplotlib import colors as mcolors

In [None]:
# import argparse
import json
import os

import numpy as np
import pandas as pd
import scanpy as sc
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Diagnosis labels matched against the "DX.name" column of the cell table.
# The order matters: callers slice CANCER_TYPES[:3] to obtain the three
# diagnoses that filter_for_cancer_types accepts.
CANCER_TYPES = [
    "Adenocarcinoma",
    "Squamous cell carcinoma",
    "Adeno squamous cell carcinoma",
    "Control",
    "Large cell carcinoma",
    "Large cell neuroendocrine carcinoma",
    "NSCLC",
    "Mesotheliom",
    "Basaloides Ca",
    "NA",
]

# Names of annotation columns that are presumably usable as prediction targets.
# NOTE(review): the visible cells pass explicit target lists and never read
# TARGETS — confirm whether it is still needed.
TARGETS = [
    "Relapse",
    "Ev.O",
    "Grade",
    "Stage",
    "T.new",
    "N",
    "M.new",
    "R",
    "Chemo",
    "Radio",
    "DFS",
    "OS",
    "LN.Met",
    "Dist.Met",
]

# Factor that store_regions multiplies cell coordinates with before writing
# coordinates.csv.
# NOTE(review): presumably converts IMC coordinates to the CODEX scale —
# confirm where 2.75 comes from.
IMC_TO_CODEX_SCALE_FACTOR = 2.75

## Helper Functions

## NSCLC Dataset

In [None]:
# Read the annotated h5ad file and display its summary in the notebook.
# NOTE(review): load_adata reads the same file again and no later visible cell
# uses `adata_raw`, so this cell is only for inspection.
adata_raw = sc.read_h5ad("../data/phenotype/sce_all_annotated.h5ad")
adata_raw

In [None]:
def filter_for_cancer_types(cell_information, cancer_types):
    """Keep only the cells whose diagnosis is one of ``cancer_types``.

    Args:
        cell_information: DataFrame with one row per cell and a "DX.name"
            column holding the diagnosis.
        cancer_types: Iterable of diagnosis names. Every entry must be one of
            the first three entries of ``CANCER_TYPES``.

    Returns:
        The rows of ``cell_information`` whose "DX.name" is in ``cancer_types``.

    Raises:
        ValueError: If ``cancer_types`` contains a diagnosis outside the
            supported set.
    """
    # materialize once so a one-shot iterable is not exhausted by the check
    cancer_types = list(cancer_types)

    # explicit check instead of `assert`, which is removed under `python -O`
    supported = CANCER_TYPES[:3]
    unsupported = sorted(set(cancer_types) - set(supported))
    if unsupported:
        raise ValueError(
            f"Unsupported cancer types {unsupported}; expected a subset of {supported}."
        )

    return cell_information.loc[cell_information["DX.name"].isin(cancer_types)]


def load_adata(adata_path, cancer_types, targets):
    """Read an annotated h5ad file and derive a cleaned per-cell table.

    Args:
        adata_path: Path of the h5ad file to read.
        cancer_types: Diagnoses to keep; forwarded to filter_for_cancer_types.
        targets: Column names; cells with a missing value in any of them are
            dropped.

    Returns:
        Tuple ``(adata, cell_information_clean)``: the object returned by
        ``sc.read_h5ad`` (unfiltered) and the cleaned copy of its ``obs``
        table.
    """
    adata = sc.read_h5ad(adata_path)

    # the sentinel -2147483648 is turned into NaN
    obs = adata.obs.replace(-2147483648, np.nan)

    # keep the three carcinoma diagnoses, then narrow down to the requested ones
    carcinomas = [
        "Adenocarcinoma",
        "Squamous cell carcinoma",
        "Adeno squamous cell carcinoma",
    ]
    obs = obs.loc[obs["DX.name"].isin(carcinomas)]
    obs = filter_for_cancer_types(obs, cancer_types)

    # express the cell area relative to its 95th percentile
    area = obs["Area"]
    obs["Area"] = area / area.quantile(0.95)

    # lower-case the annotation levels (e.g. "T cell" -> "t cell") and drop
    # cells that lack any of them
    annotation_columns = ["cell_category", "cell_type", "cell_subtype"]
    for column in annotation_columns:
        obs[column] = obs[column].str.lower()
    obs = obs.dropna(subset=annotation_columns)

    # recode the categories of LN.Met, Dist.Met and NeoAdj as 0 / 1
    binary_codes = {
        "LN.Met": {"No LN Metastases": 0, "LN Metastases": 1},
        "Dist.Met": {"No Dist. Metastases": 0, "Dist. Metastases": 1},
        "NeoAdj": {"NoNeoAdjuvantTherapy": 0, "NeoAdjuvantTherapy": 1},
    }
    for column, codes in binary_codes.items():
        obs[column] = obs[column].cat.rename_categories(codes)

    # cells without a value for every requested target are removed last
    return adata, obs.dropna(subset=targets)

In [None]:
# Input file and the target columns; load_adata drops every cell that has a
# missing value in one of these columns.
adata_path = "../data/phenotype/sce_all_annotated.h5ad"
target_names = ["Relapse", "Ev.O", "LN.Met"]

# CANCER_TYPES[:3] selects "Adenocarcinoma", "Squamous cell carcinoma" and
# "Adeno squamous cell carcinoma".
adata, cell_information_clean = load_adata(adata_path, CANCER_TYPES[:3], target_names)

In [None]:
cell_information_clean

In [None]:
cell_information_clean["Tma_ac"]

In [None]:
def separate_regions(adata, cell_information, targets):
    """Collect per-region cell arrays and region-level target values.

    Cells are grouped by their "Tma_ac" value; every group becomes one region.

    Args:
        adata: Object indexable by cell ids that exposes ``var_names`` and a
            "c_counts_asinh_scaled" layer supporting ``toarray()``.
        cell_information: Per-cell DataFrame indexed by cell id with the
            columns "Tma_ac", "Center_X", "Center_Y", "cell_subtype", "Area"
            and all target columns.
        targets: Names of the target columns.

    Returns:
        Tuple ``(regions, targets, biomarker_names)`` where ``regions`` maps a
        region name to a dict with the keys "cell_ids", "coordinates",
        "cell_types", "sizes" and "biomarkers", ``targets`` maps each target
        name to a list with one value per region (in region order), and
        ``biomarker_names`` is ``adata.var_names.values``.

    Raises:
        AssertionError: If a target takes more than one distinct value within
            a region.
    """
    region_records = {}
    values_per_target = {target_name: [] for target_name in targets}

    for region_name, region_cells in cell_information.groupby("Tma_ac", observed=True):
        ids = region_cells.index.values
        record = {
            "cell_ids": ids,
            "coordinates": region_cells[["Center_X", "Center_Y"]].values,
            "cell_types": region_cells["cell_subtype"].values,
            "sizes": region_cells["Area"].values,
            # expression values of this region's cells as a dense array
            "biomarkers": adata[ids].layers["c_counts_asinh_scaled"].toarray(),
        }

        # a target must be constant across all cells of a region
        for target_name, collected in values_per_target.items():
            distinct_values = region_cells[target_name].unique()
            assert len(distinct_values) == 1
            collected.append(distinct_values[0])

        region_records[region_name] = record

    return region_records, values_per_target, adata.var_names.values

In [None]:
regions, targets, biomarker_names = separate_regions(adata, cell_information_clean, target_names)

In [None]:
def train_valid_split(cell_information, target_columns, valid_fraction):
    """Split regions into train and validation sets at patient level.

    Patients are split with stratified sampling on the concatenated string
    values of ``target_columns``; every region ("Tma_ac") is then assigned to
    the set of its patient.

    Args:
        cell_information: Per-cell DataFrame with the columns "Patient_ID",
            "Tma_ac" and all ``target_columns``. It is not modified.
        target_columns: Columns whose values define the stratification class.
        valid_fraction: Fraction of patients assigned to the validation set
            (passed to ``train_test_split`` as ``test_size``).

    Returns:
        Tuple ``(train_regions, valid_regions)`` of lists of "Tma_ac" values.
    """
    import warnings

    # Stratification key per cell, kept in a local Series so the caller's
    # DataFrame does not gain a "stratify" column as a side effect.
    stratify_key = cell_information[target_columns].astype(str).agg("".join, axis=1)
    patient_df = (
        cell_information[["Patient_ID"]].assign(stratify=stratify_key).drop_duplicates()
    )

    # A patient with several target combinations would otherwise be listed
    # more than once and could end up in both sets; keep one row per patient.
    conflicting = patient_df["Patient_ID"].duplicated()
    if conflicting.any():
        warnings.warn(
            f"{patient_df.loc[conflicting, 'Patient_ID'].nunique()} patient(s) have more "
            "than one combination of target values; the first one is used for "
            "stratification.",
            stacklevel=2,
        )
        patient_df = patient_df.loc[~conflicting]

    # split patients into train and valid sets using stratified sampling
    train_patients, valid_patients = train_test_split(
        list(patient_df["Patient_ID"]),
        test_size=valid_fraction,
        stratify=list(patient_df["stratify"]),
        random_state=44,
    )

    # map the patient split back to regions
    patient_ids = cell_information["Patient_ID"]
    train_regions = list(
        cell_information.loc[patient_ids.isin(train_patients), "Tma_ac"].unique()
    )
    valid_regions = list(
        cell_information.loc[patient_ids.isin(valid_patients), "Tma_ac"].unique()
    )

    return train_regions, valid_regions

In [None]:
valid_fraction = 0.2
output_dir = "../data/phenotype/regions"

# patient-level split, stratified on Relapse and Grade
train_regions, valid_regions = train_valid_split(
    cell_information_clean, ["Relapse", "Grade"], valid_fraction
)

# write both region lists as JSON next to the region folders
os.makedirs(output_dir, exist_ok=True)
for split_file, split_regions in (
    ("train_regions_grade.json", train_regions),
    ("valid_regions_grade.json", valid_regions),
):
    with open(os.path.join(output_dir, split_file), "w") as f:
        json.dump(split_regions, f)

In [None]:
def store_regions(regions, targets, biomarker_names, output_dir):
    """Write every region and the region-level targets to CSV files.

    For each region a folder ``output_dir/<region name>`` is created holding
    "coordinates.csv", "cell_types.csv", "cell_sizes.csv" and
    "expression.csv"; each file carries a "CELL_ID" column. The targets are
    written to ``output_dir/targets.csv`` with upper-cased column names and a
    "REGION_ID" column.

    Args:
        regions: Mapping of region name to a dict with the keys "cell_ids",
            "coordinates", "cell_types", "sizes" and "biomarkers". It is not
            modified.
        targets: Mapping of target name to a list with one value per region,
            in the iteration order of ``regions``.
        biomarker_names: Column names for the biomarker expression table.
        output_dir: Directory that receives the region folders and
            targets.csv.
    """
    for name, region in tqdm(list(regions.items())):
        region_dir = os.path.join(output_dir, name)
        os.makedirs(region_dir, exist_ok=True)
        cell_ids = region["cell_ids"]

        # Scale into a new array: an in-place `*=` would rescale the stored
        # coordinates on every call and fails for integer arrays.
        cell_coords = region["coordinates"] * IMC_TO_CODEX_SCALE_FACTOR

        tables = {
            "coordinates.csv": pd.DataFrame(cell_coords, columns=["X", "Y"]),
            "cell_types.csv": pd.DataFrame(region["cell_types"], columns=["CELL_TYPE"]),
            "cell_sizes.csv": pd.DataFrame(region["sizes"], columns=["SIZE"]),
            "expression.csv": pd.DataFrame(region["biomarkers"], columns=biomarker_names),
        }
        # every table gets the cell ids as its last column
        for file_name, table in tables.items():
            table["CELL_ID"] = cell_ids
            table.to_csv(os.path.join(region_dir, file_name), index=False)

    # store target values as csv, columns are target keys, rows are regions
    targets_df = pd.DataFrame(targets)
    targets_df.columns = [col.upper() for col in targets_df.columns]
    targets_df["REGION_ID"] = list(regions.keys())
    targets_df.to_csv(os.path.join(output_dir, "targets.csv"), index=False)

In [None]:
store_regions(regions, targets, biomarker_names, output_dir)

## Graphs

In [None]:
# Load one stored graph for inspection and display it.
# weights_only=False lets torch.load unpickle arbitrary Python objects, so this
# must only be used on trusted files.
# NOTE(review): presumably a torch_geometric Data object — later cells read
# `.x`, `.edge_attr` and `.region_id` from it.
example_graph = torch.load("../data/phenotype/raw/175A_119.0.gpt", weights_only=False)
example_graph

In [None]:
example_graph.edge_attr[:, 1].unique()

In [None]:
cell_types = example_graph.x[:, 0].long()

In [None]:
# Read the region-level targets table; this rebinds the global `targets`.
targets_path = os.path.join("../data/phenotype/raw", "targets.csv")
if os.path.exists(targets_path):
    targets = pd.read_csv(targets_path)
else:
    raise FileNotFoundError(f"File {targets_path} does not exist. Please provide the targets.")

In [None]:
targets.loc[targets["REGION_ID"] == "175A_119"]

In [None]:
graph_tasks = ["RELAPSE", "EV.O", "LN.MET"]

# The targets table must hold exactly one row for this graph's region;
# anything else would otherwise surface as an opaque tensor shape error.
region_rows = targets.loc[targets["REGION_ID"] == example_graph.region_id]
if len(region_rows) != 1:
    raise ValueError(
        f"Expected exactly one targets row for region {example_graph.region_id}, "
        f"found {len(region_rows)}."
    )

# append the targets to the graph, one float32 entry per task
example_graph.y = torch.zeros((len(graph_tasks)), dtype=torch.float32)
for i, task in enumerate(graph_tasks):
    if task not in targets.columns:
        raise ValueError(f"Task {task} not found in targets.")
    # y is already float32, so no separate cast is needed after the assignment
    example_graph.y[i] = float(region_rows[task].iloc[0])

In [None]:
example_graph.y

In [None]:
json_path = "../data/phenotype/json_files"

# locations of the three metadata files inside json_path
ct_mapping_path, ct_freq_path, biomarkers_path = (
    os.path.join(json_path, json_name)
    for json_name in ("cell_type_mapping.json", "cell_type_freq.json", "biomarkers.json")
)

# parse the files in the order mapping, frequencies, biomarkers
loaded_json = []
for metadata_path in (ct_mapping_path, ct_freq_path, biomarkers_path):
    with open(metadata_path, "r") as f:
        loaded_json.append(json.load(f))

# NOTE(review): `biormarkers` looks like a typo of "biomarkers"; the name is
# kept because a later cell displays it.
cell_type_mapping, cell_type_freq, biormarkers = loaded_json

In [None]:
cell_type_freq

In [None]:
cell_type_mapping

In [None]:
biormarkers

## Testing Code

In [None]:
import copy
import os

import rootutils
import torch
from torch_geometric.data import Batch, Data

rootutils.setup_root(os.getcwd(), indicator=".project-root", pythonpath=True)

from src.data.cellular_graph_datamodule import (
    CellularGraphDataModule,
    CellularGraphDataset,
)

In [None]:
# Settings that the next cell passes unchanged, as keyword arguments of the
# same names, to CellularGraphDataModule.
# `json_path` and `graph_tasks` rebind globals assigned in earlier cells; the
# paths here point below "../data/phenotype/nsclc", unlike the earlier cells.
mode = "finetuning"
data_dir = "../data/phenotype/nsclc/raw"
processed_dir = "../data/phenotype/nsclc/processed"
batch_size = 64
num_workers = 4
pin_memory = False
json_path = "../data/phenotype/nsclc/json_files"
subgraph_size = 3
num_iterations = 50000
# features available on nodes / edges ...
node_features = [
    "cell_type",
    "size",
    "biomarker_expression",
    "neighborhood_composition",
    "center_coord",
]
edge_features = ["edge_type", "distance"]
# ... and the subsets selected for use
node_features_to_use = ["cell_type", "size"]
edge_features_to_use = ["edge_type"]
sampling_avoid_unassigned = True
unassigned_cell_type = "Unassigned"
graph_tasks = ["RELAPSE"]
redo_preprocess = False
seed = 44

In [None]:
# collect the notebook settings under the keyword names the data module expects
datamodule_config = {
    "mode": mode,
    "data_dir": data_dir,
    "processed_dir": processed_dir,
    "batch_size": batch_size,
    "num_workers": num_workers,
    "pin_memory": pin_memory,
    "json_path": json_path,
    "subgraph_size": subgraph_size,
    "num_iterations": num_iterations,
    "node_features": node_features,
    "edge_features": edge_features,
    "node_features_to_use": node_features_to_use,
    "edge_features_to_use": edge_features_to_use,
    "sampling_avoid_unassigned": sampling_avoid_unassigned,
    "unassigned_cell_type": unassigned_cell_type,
    "graph_tasks": graph_tasks,
    "redo_preprocess": redo_preprocess,
    "seed": seed,
}
data_module = CellularGraphDataModule(**datamodule_config)

In [None]:
data_module

In [None]:
data_module.prepare_data()

In [None]:
data_module.setup()

In [None]:
# build the three loaders (train, val, test) and expose their datasets
loaders = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
    data_module.test_dataloader(),
)
train_loader, val_loader, test_loader = loaders
trainset, valset, testset = (loader.dataset for loader in loaders)

In [None]:
from torchmetrics.classification import (
    BinaryAccuracy,
    BinaryAUROC,
    BinaryConfusionMatrix,
    BinaryF1Score,
    BinaryPrecision,
    BinaryRecall,
)

from src.models.bgrl_phenotype_module import BGRLPhenotypeLitModule
from src.models.components.bgrl import BGRL
from src.models.components.bgrl_projector import BGRLProjector
from src.models.components.gnn import GNN, GNN_pred

In [None]:
# Encoder instance; it is passed to BGRL as `encoder` in a later cell.
# emb_dim=256 equals the `input_size` given to BGRLProjector.
gnn = GNN(
    num_layer=3,
    num_node_type=30,
    num_feat=76,
    emb_dim=256,
)

In [None]:
# Prediction model built with the same hyperparameters as `gnn`; the evaluation
# loop further down calls it on batches from the test loader.
# NOTE(review): constructed independently of `gnn`, so presumably it does not
# share weights with it — confirm.
gnn_pred = GNN_pred(
    num_layer=3,
    num_node_type=30,
    num_feat=76,
    emb_dim=256,
)

In [None]:
# Projector passed to BGRL below; input_size=256 equals the `emb_dim` of `gnn`.
projector = BGRLProjector(
    input_size=256,
    output_size=256,
    hidden_size=512,
)

In [None]:
# Combine the encoder and the projector created in the cells above.
bgrl = BGRL(
    encoder=gnn,
    projector=projector,
)

In [None]:
# Collect predictions for up to 10 batches of the *test* loader (the list keeps
# the name `val_outputs`, which the following cells read).
val_outputs = []

# inference mode, so layers such as dropout or batch norm behave deterministically
gnn_pred.eval()

iterator = iter(test_loader)
# zip stops after 10 batches or when the loader is exhausted, whichever comes
# first, instead of raising StopIteration on a short loader
for i, data in zip(range(10), iterator):
    region_id = data.region_id[0][0]

    # run the prediction model; the prediction for the graph is the mean over
    # all of its outputs in this batch
    with torch.no_grad():
        res = gnn_pred(data)
        y_pred = res[0].flatten().mean()

    # get ground truth labels (graph level, same for all subgraphs)
    y_true = data.y[0]

    val_outputs.append(
        {
            "y_pred": y_pred,
            "y_true": y_true,
            "region_id": region_id,
        }
    )

val_outputs

In [None]:
# Gather the per-batch predictions and labels into two float tensors.
# NOTE(review): torch.tensor on a list of tensors assumes every entry holds a
# single element — confirm that `y_true` (data.y[0]) is a scalar per batch.
y_preds = torch.tensor([d["y_pred"] for d in val_outputs], dtype=torch.float)
y_trues = torch.tensor([d["y_true"] for d in val_outputs], dtype=torch.float)

In [None]:
logits = y_preds
# torchmetrics expects integer ground-truth labels; BinaryAUROC rejects a
# floating-point target tensor, so the float labels are cast to long here
labels = y_trues.long()

# turn logits into probabilities and hard 0/1 predictions at threshold 0.5
probs = torch.sigmoid(logits)
preds = (probs >= 0.5).int()

# compute metrics: AUROC on probabilities, the others on hard predictions
metrics = {
    "auroc": BinaryAUROC()(probs, labels).item(),
    "accuracy": BinaryAccuracy()(preds, labels).item(),
    "f1": BinaryF1Score()(preds, labels).item(),
    "precision": BinaryPrecision()(preds, labels).item(),
    "recall": BinaryRecall()(preds, labels).item(),
}

# compute balanced accuracy from the confusion matrix [[tn, fp], [fn, tp]];
# the small epsilon avoids division by zero when a class is absent
cm = BinaryConfusionMatrix()(preds, labels)
tn, fp, fn, tp = cm.flatten()
sensitivity = tp / (tp + fn + 1e-8)
specificity = tn / (tn + fp + 1e-8)
metrics["balanced_accuracy"] = ((sensitivity + specificity) / 2).item()

In [None]:
metrics