In [1]:
# Code adapted from Machine Learning Engineering (Cornell Tech 2025)
import torch
import numpy as np
import random

def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

In [2]:
from google.colab import drive
import sys
import os

# --- 1. Mount Drive ---
drive.mount('/content/drive')

# --- 2. Define Paths ---
# Path to the source code (loaders.py) - REMAINS ON DRIVE
DRIVE_CODE_PATH = '/content/drive/MyDrive/GoogleColab/dataforsptransformer/'

# Path to the zipped data file on Drive
ZIP_SOURCE_PATH = os.path.join(DRIVE_CODE_PATH, 'crc_markers.zip')

# Local disk folder where the FAST images will be unzipped
FAST_DATA_PATH = '/content/fast_data/'

# --- 3. Unzip Data (Performance Fix) ---
if not os.path.exists(FAST_DATA_PATH):
    print(f"üöÄ Unzipping data from Drive to fast local disk: {FAST_DATA_PATH}")
    !mkdir -p "$FAST_DATA_PATH"
    # The -q flag silences the output. -d sets the destination directory.
    !unzip -q "$ZIP_SOURCE_PATH" -d "$FAST_DATA_PATH"

    print("‚úÖ Data transfer complete. Starting new batch load test.")
else:
    print("Fast data directory already exists.")


# --- 4. Set Final Variables ---
# PROJECT_DIR for the rest of your notebook now points to the FAST images
PROJECT_DIR = FAST_DATA_PATH

# Add the Drive path for Python to find 'loaders.py' and other modules
if DRIVE_CODE_PATH not in sys.path:
    sys.path.append(DRIVE_CODE_PATH)
    print(f"‚úÖ Added {DRIVE_CODE_PATH} to Python system path.")

Mounted at /content/drive
üöÄ Unzipping data from Drive to fast local disk: /content/fast_data/
‚úÖ Data transfer complete. Starting new batch load test.
‚úÖ Added /content/drive/MyDrive/GoogleColab/dataforsptransformer/ to Python system path.


In [5]:
from google.colab import files

# This will open a 'Choose Files' button in your output cell
# Choose train_helpers.py
uploaded = files.upload()

for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(uploaded[fn])))

Saving data_cleaning.py to data_cleaning.py
User uploaded file "data_cleaning.py" with length 5434 bytes


In [6]:
!python3 data_cleaning.py

In [7]:
import pandas as pd

df = pd.read_csv("/content/fast_data/CRC_clusters_neighborhoods_markers_cleaned.csv")
df.head()

Unnamed: 0,CellID,ClusterID,EventID,File Name,Region,TMA_AB,TMA_12,Index in File,groups,patients,...,CD8+ICOS+,CD8+Ki67+,CD8+PD-1+,Treg-ICOS+,Treg-Ki67+,Treg-PD-1+,neighborhood number final,neighborhood name,GraphID,KNN
0,0,10668,0,reg001_A,reg001,A,1,0,1,1,...,0,0,0,0,0,0,9.0,Granulocyte enriched,reg001_A 1 reg001 1,"[79342, 215276, 86261, 19386, 85435]"
1,1,10668,4,reg001_A,reg001,A,1,4,1,1,...,0,0,0,0,0,0,4.0,Macrophage enriched,reg001_A 1 reg001 1,"[209687, 147136, 231069, 61712, 50718]"
2,2,10668,5,reg001_A,reg001,A,1,5,1,1,...,0,0,0,0,0,0,3.0,Immune-infiltrated stroma,reg001_A 1 reg001 1,"[116916, 169947, 195850, 32707, 214032]"
3,3,10668,6,reg001_A,reg001,A,1,6,1,1,...,0,0,0,0,0,0,3.0,Immune-infiltrated stroma,reg001_A 1 reg001 1,"[7427, 215280, 163828, 223298, 93734]"
4,4,10668,30,reg001_A,reg001,A,1,30,1,1,...,0,0,0,0,0,0,4.0,Macrophage enriched,reg001_A 1 reg001 1,"[158193, 153016, 27846, 87820, 97886]"


In [8]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m63.7/63.7 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/1.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.3/1.3 MB[0m [31m87.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successf

In [9]:
df.columns

Index(['CellID', 'ClusterID', 'EventID', 'File Name', 'Region', 'TMA_AB',
       'TMA_12', 'Index in File', 'groups', 'patients', 'spots',
       'CD44 - stroma:Cyc_2_ch_2', 'FOXP3 - regulatory T cells:Cyc_2_ch_3',
       'CD8 - cytotoxic T cells:Cyc_3_ch_2',
       'p53 - tumor suppressor:Cyc_3_ch_3',
       'GATA3 - Th2 helper T cells:Cyc_3_ch_4',
       'CD45 - hematopoietic cells:Cyc_4_ch_2', 'T-bet - Th1 cells:Cyc_4_ch_3',
       'beta-catenin - Wnt signaling:Cyc_4_ch_4', 'HLA-DR - MHC-II:Cyc_5_ch_2',
       'PD-L1 - checkpoint:Cyc_5_ch_3', 'Ki67 - proliferation:Cyc_5_ch_4',
       'CD45RA - naive T cells:Cyc_6_ch_2', 'CD4 - T helper cells:Cyc_6_ch_3',
       'CD21 - DCs:Cyc_6_ch_4', 'MUC-1 - epithelia:Cyc_7_ch_2',
       'CD30 - costimulator:Cyc_7_ch_3', 'CD2 - T cells:Cyc_7_ch_4',
       'Vimentin - cytoplasm:Cyc_8_ch_2', 'CD20 - B cells:Cyc_8_ch_3',
       'LAG-3 - checkpoint:Cyc_8_ch_4', 'Na-K-ATPase - membranes:Cyc_9_ch_2',
       'CD5 - T cells:Cyc_9_ch_3', 'IDO-1 - metaboli

## Prepare Graph Data

### Subtask:
Process the `df` DataFrame to extract all necessary components for graph construction: node features (`icols`), spatial coordinates ('X:X', 'Y:Y'), node identifiers ('CellID'), graph identifiers ('GraphID'), and parse the 'KNN' column to identify neighbors. Also, define a node-level target label, for example, using the 'ClusterID' column.


In [16]:
icols = ['CD44 - stroma:Cyc_2_ch_2', 'FOXP3 - regulatory T cells:Cyc_2_ch_3',
'CD8 - cytotoxic T cells:Cyc_3_ch_2', 'p53 - tumor suppressor:Cyc_3_ch_3',
'GATA3 - Th2 helper T cells:Cyc_3_ch_4', 'CD45 - hematopoietic cells:Cyc_4_ch_2',
'T-bet - Th1 cells:Cyc_4_ch_3', 'beta-catenin - Wnt signaling:Cyc_4_ch_4',
'HLA-DR - MHC-II:Cyc_5_ch_2', 'PD-L1 - checkpoint:Cyc_5_ch_3',
'Ki67 - proliferation:Cyc_5_ch_4', 'CD45RA - naive T cells:Cyc_6_ch_2',
'CD4 - T helper cells:Cyc_6_ch_3', 'CD21 - DCs:Cyc_6_ch_4', 'MUC-1 - epithelia:Cyc_7_ch_2',
'CD30 - costimulator:Cyc_7_ch_3', 'CD2 - T cells:Cyc_7_ch_4', 'Vimentin - cytoplasm:Cyc_8_ch_2',
'CD20 - B cells:Cyc_8_ch_3', 'LAG-3 - checkpoint:Cyc_8_ch_4', 'Na-K-ATPase - membranes:Cyc_9_ch_2',
'CD5 - T cells:Cyc_9_ch_3', 'IDO-1 - metabolism:Cyc_9_ch_4', 'Cytokeratin - epithelia:Cyc_10_ch_2',
'CD11b - macrophages:Cyc_10_ch_3', 'CD56 - NK cells:Cyc_10_ch_4', 'aSMA - smooth muscle:Cyc_11_ch_2',
'BCL-2 - apoptosis:Cyc_11_ch_3', 'CD25 - IL-2 Ra:Cyc_11_ch_4', 'CD11c - DCs:Cyc_12_ch_3',
'PD-1 - checkpoint:Cyc_12_ch_4', 'Granzyme B - cytotoxicity:Cyc_13_ch_2',
'EGFR - signaling:Cyc_13_ch_3', 'VISTA - costimulator:Cyc_13_ch_4',
'CD15 - granulocytes:Cyc_14_ch_2', 'ICOS - costimulator:Cyc_14_ch_4',
'Synaptophysin - neuroendocrine:Cyc_15_ch_3', 'GFAP - nerves:Cyc_16_ch_2',
'CD7 - T cells:Cyc_16_ch_3', 'CD3 - T cells:Cyc_16_ch_4', 'Chromogranin A - neuroendocrine:Cyc_17_ch_2',
'CD163 - macrophages:Cyc_17_ch_3', 'CD45RO - memory cells:Cyc_18_ch_3', 'CD68 - macrophages:Cyc_18_ch_4',
'CD31 - vasculature:Cyc_19_ch_3', 'Podoplanin - lymphatics:Cyc_19_ch_4', 'CD34 - vasculature:Cyc_20_ch_3',
'CD38 - multifunctional:Cyc_20_ch_4', 'CD138 - plasma cells:Cyc_21_ch_3', 'KNN', 'GraphID', 'X:X', 'Y:Y', 'Z:Z' ,'size:size']

In [17]:
non_feature_cols = ['KNN', 'GraphID', 'CellID', 'X:X', 'Y:Y', 'Z:Z', 'size:size']

# Filter the global icols variable from cell XUyeyUmnxvMn
# Note: The icols variable in the notebook context (Variable #19) is the one from XUyeyUmnxvMn
feature_cols = [col for col in icols if col not in non_feature_cols]

# Create 'node_features' column using the filtered feature_cols
df['node_features'] = df[feature_cols].values.tolist()

# Create 'pos' column with 'X:X' and 'Y:Y' coordinates
df['pos'] = df[['X:X', 'Y:Y']].values.tolist()

# Display the first few rows of the updated df DataFrame
df.head()

Unnamed: 0,CellID,ClusterID,EventID,File Name,Region,TMA_AB,TMA_12,Index in File,groups,patients,...,CD8+PD-1+,Treg-ICOS+,Treg-Ki67+,Treg-PD-1+,neighborhood number final,neighborhood name,GraphID,KNN,node_features,pos
0,0,10668,0,reg001_A,reg001,A,1,0,1,1,...,0,0,0,0,9.0,Granulocyte enriched,reg001_A 1 reg001 1,"[79342, 215276, 86261, 19386, 85435]","[1.843590736, 17.39870644, 0.0, 59.39188385, 3...","[77, 589]"
1,1,10668,4,reg001_A,reg001,A,1,4,1,1,...,0,0,0,0,4.0,Macrophage enriched,reg001_A 1 reg001 1,"[209687, 147136, 231069, 61712, 50718]","[30.28452492, 18.37573814, 74.69523621, 271.42...","[106, 826]"
2,2,10668,5,reg001_A,reg001,A,1,5,1,1,...,0,0,0,0,3.0,Immune-infiltrated stroma,reg001_A 1 reg001 1,"[116916, 169947, 195850, 32707, 214032]","[139.4885101, 249.7469788, 85.55697632, 705.36...","[107, 545]"
3,3,10668,6,reg001_A,reg001,A,1,6,1,1,...,0,0,0,0,3.0,Immune-infiltrated stroma,reg001_A 1 reg001 1,"[7427, 215280, 163828, 223298, 93734]","[20.59688568, 81.75975799999999, 0.0, 0.0, 34....","[98, 564]"
4,4,10668,30,reg001_A,reg001,A,1,30,1,1,...,0,0,0,0,4.0,Macrophage enriched,reg001_A 1 reg001 1,"[158193, 153016, 27846, 87820, 97886]","[67.32872772, 122.1954727, 11.02828407, 325.07...","[217, 329]"


## Construct PyTorch Geometric Data Objects

### Subtask:
Group the prepped data by `GraphID`. For each unique `GraphID`, construct a `torch_geometric.data.Data` object. This involves: creating a mapping of `CellID`s to internal node indices for the current graph; extracting node features (`x`) using the now correctly filtered `icols` list; building the `edge_index` from the parsed `KNN` lists (ensuring neighbors belong to the same graph); calculating `edge_attr` (Euclidean distances between connected nodes using 'X:X' and 'Y:Y' coordinates); and assigning node labels (`y`) based on the chosen target column (e.g., 'ClusterID').


## Prepare Node-Level Data

### Subtask:
Confirm that the `df` DataFrame has the `node_features` column (with only the specified marker features from `feature_cols`), the `pos` column (with 'X:X' and 'Y:Y' coordinates), and the 'KNN' column parsed as actual Python lists. This re-uses the correctly prepared columns from previous steps.

#### Instructions
1. Import the `ast` module for safe evaluation of string literals.
2. Inspect the data type of the 'KNN' column in the `df` DataFrame to confirm if it's currently a string. For example, use `df['KNN'].dtype`.
3. If the 'KNN' column is of object type (meaning it contains strings), apply `ast.literal_eval` to each element in the 'KNN' column to convert the string representations of lists into actual Python lists.
4. Print the data type of the 'KNN' column again and display the first few entries of the modified 'KNN' column to verify the conversion.

In [19]:
import ast

# 1. Inspect the data type of the 'KNN' column
print(f"Original 'KNN' column dtype: {df['KNN'].dtype}")

# 2. If it's an object (likely strings), apply ast.literal_eval
if df['KNN'].dtype == 'object':
    # Ensure only valid list strings are processed to avoid errors with NaN or non-string values
    df['KNN'] = df['KNN'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
    print("Converted 'KNN' column from string to Python list.")
else:
    print("'KNN' column is already in the correct format (Python lists or equivalent).")

# 3. Print the new data type and display the first few entries
print(f"New 'KNN' column dtype: {df['KNN'].dtype}")
print("First 5 entries of 'KNN' column after processing:")
print(df['KNN'].head().tolist())

Original 'KNN' column dtype: object
Converted 'KNN' column from string to Python list.
New 'KNN' column dtype: object
First 5 entries of 'KNN' column after processing:
[[79342, 215276, 86261, 19386, 85435], [209687, 147136, 231069, 61712, 50718], [116916, 169947, 195850, 32707, 214032], [7427, 215280, 163828, 223298, 93734], [158193, 153016, 27846, 87820, 97886]]


## Perform Node-Level Train/Val/Test Split

### Subtask:
Randomly shuffle all unique 'CellID's from the `df`. Then, split these 'CellID's into training, validation, and test sets according to the specified ratios (70/15/15). These sets of 'CellID's will define the center nodes for the ego-graphs in each respective dataset split.


In [20]:
import random

# 1. Extract all unique CellIDs
all_node_ids = df['CellID'].unique().tolist()

# 2. Set a random seed for reproducibility
random.seed(42)

# 3. Randomly shuffle the all_node_ids list in-place
random.shuffle(all_node_ids)

# 4. Define the split ratios
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15

total_unique_nodes = len(all_node_ids)

# 5. Calculate the number of CellIDs for each split
train_split_size = int(total_unique_nodes * train_ratio)
val_split_size = int(total_unique_nodes * val_ratio)
# Ensure test_split_size covers the remainder to account for rounding
test_split_size = total_unique_nodes - train_split_size - val_split_size

# 6. Divide the shuffled all_node_ids list into three new lists
train_node_ids = all_node_ids[:train_split_size]
val_node_ids = all_node_ids[train_split_size : train_split_size + val_split_size]
test_node_ids = all_node_ids[train_split_size + val_split_size :]

# 7. Print the total number of unique CellIDs and the size of each resulting split
print(f"Total unique CellIDs: {total_unique_nodes}")
print(f"Train node IDs split size: {len(train_node_ids)}")
print(f"Validation node IDs split size: {len(val_node_ids)}")
print(f"Test node IDs split size: {len(test_node_ids)}")

Total unique CellIDs: 258385
Train node IDs split size: 180869
Validation node IDs split size: 38757
Test node IDs split size: 38759


## Construct Ego-Graphs for Each Split

### Subtask:
For each 'CellID' in the `train_node_ids`, `val_node_ids`, and `test_node_ids` lists, create a `torch_geometric.data.Data` object representing an ego-graph. Each ego-graph will have the current 'CellID' as its center node (index 0). Its node features (`x`), position (`pos`), and label (`y`) will be extracted. The ego-graph will also include its immediate neighbors (from the 'KNN' column) as additional nodes, along with edges and Euclidean distance-based edge attributes connecting the center node to its neighbors. Collect these ego-graph `Data` objects into `train_dataset`, `val_dataset`, and `test_dataset`.


In [21]:
import torch
from torch_geometric.data import Data
from scipy.spatial.distance import euclidean

# 1. Create a dictionary mapping CellID to its row for efficient lookup
df_indexed = df.set_index('CellID')
print("Created df_indexed for efficient CellID lookup.")

# 2. Define the create_ego_graph function
def create_ego_graph(center_node_id, df_indexed, feature_cols, target_col='ClusterID'):
    # a. Retrieve the center node's data
    center_node_data = df_indexed.loc[center_node_id]

    # b. Get center node's features, position, and label
    center_x = torch.tensor(center_node_data['node_features'], dtype=torch.float).unsqueeze(0)
    center_pos = torch.tensor(center_node_data['pos'], dtype=torch.float).unsqueeze(0)
    center_y = torch.tensor([center_node_data[target_col]], dtype=torch.long)

    # c. Identify neighbors and filter out invalid ones
    neighbors_original_ids = center_node_data['KNN']
    valid_neighbors_original_ids = [nid for nid in neighbors_original_ids if nid in df_indexed.index]

    # d. Create a list of all unique node IDs in the ego-graph (center + valid neighbors)
    ego_graph_node_ids = [center_node_id] + valid_neighbors_original_ids

    # e. Create a mapping from ego-graph node IDs to local indices
    local_id_map = {node_id: i for i, node_id in enumerate(ego_graph_node_ids)}

    # f. Populate x_ego, pos_ego, y_ego tensors for all nodes in the ego-graph
    x_ego_list = [center_x]
    pos_ego_list = [center_pos]
    y_ego_list = [center_y] # y_ego is primarily the center node's label

    for neighbor_id in valid_neighbors_original_ids:
        neighbor_data = df_indexed.loc[neighbor_id]
        x_ego_list.append(torch.tensor(neighbor_data['node_features'], dtype=torch.float).unsqueeze(0))
        pos_ego_list.append(torch.tensor(neighbor_data['pos'], dtype=torch.float).unsqueeze(0))
        y_ego_list.append(torch.tensor([neighbor_data[target_col]], dtype=torch.long)) # Append neighbor labels too

    x_ego = torch.cat(x_ego_list, dim=0)
    pos_ego = torch.cat(pos_ego_list, dim=0)
    # y_ego for ego-graph is usually just the center node's label for node classification,
    # but here we'll keep all node labels for completeness if needed later.
    # However, for a single ego-graph, often 'y' refers to the center node's property.
    # Let's keep it consistent by having `y` refer to the center node's label.
    # If multiple labels are needed, it would be `y_all_nodes` or similar.
    # For this task, we'll assign the center node's label to the ego-graph's 'y'.
    y_ego_graph_label = center_y

    # g. Construct edge_index: from center node to each neighbor
    edge_indices = []
    edge_attributes = []

    center_local_idx = local_id_map[center_node_id]
    center_pos_coord = pos_ego[center_local_idx]

    for neighbor_id in valid_neighbors_original_ids:
        neighbor_local_idx = local_id_map[neighbor_id]
        neighbor_pos_coord = pos_ego[neighbor_local_idx]

        # Add edge (center -> neighbor)
        edge_indices.append([center_local_idx, neighbor_local_idx])
        # Add edge (neighbor -> center) to make it undirected in adjacency matrix sense
        edge_indices.append([neighbor_local_idx, center_local_idx])

        # h. Calculate edge_attr (Euclidean distance) for center-neighbor edges
        distance = euclidean(center_pos_coord, neighbor_pos_coord)
        edge_attributes.append([distance])
        edge_attributes.append([distance]) # For the reverse edge

    if len(edge_indices) > 0:
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attributes, dtype=torch.float)
    else:
        # If no neighbors, define empty tensors for edges
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 1), dtype=torch.float) # Assuming 1 feature for edge_attr

    # i. Create and return a torch_geometric.data.Data object
    data = Data(x=x_ego, edge_index=edge_index, edge_attr=edge_attr, y=y_ego_graph_label, pos=pos_ego,
                center_node_id=torch.tensor(center_node_id, dtype=torch.long))
    return data

print("Defined create_ego_graph function.")

# 3. Initialize empty lists for datasets
train_ego_dataset = []
val_ego_dataset = []
test_ego_dataset = []

# 4. Iterate and populate datasets
print("Generating train ego-graphs...")
for node_id in train_node_ids:
    train_ego_dataset.append(create_ego_graph(node_id, df_indexed, feature_cols))

print("Generating validation ego-graphs...")
for node_id in val_node_ids:
    val_ego_dataset.append(create_ego_graph(node_id, df_indexed, feature_cols))

print("Generating test ego-graphs...")
for node_id in test_node_ids:
    test_ego_dataset.append(create_ego_graph(node_id, df_indexed, feature_cols))

# 5. Print the sizes of the new datasets
print("\nEgo-Graph Dataset Sizes:")
print(f"  Train ego-graphs: {len(train_ego_dataset)}")
print(f"  Validation ego-graphs: {len(val_ego_dataset)}")
print(f"  Test ego-graphs: {len(test_ego_dataset)}")


Created df_indexed for efficient CellID lookup.
Defined create_ego_graph function.
Generating train ego-graphs...
Generating validation ego-graphs...
Generating test ego-graphs...

Ego-Graph Dataset Sizes:
  Train ego-graphs: 180869
  Validation ego-graphs: 38757
  Test ego-graphs: 38759


In [22]:
def train_epoch(model, loader, criterion, optimizer, device, mask_ratio=0.15):
    model.train()
    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        # Create a mask for node features
        mask = torch.rand(data.x.shape) < mask_ratio
        masked_x = data.x.clone()
        # Replace masked features with zeros or a special token (here, zeros for simplicity)
        masked_x[mask] = 0

        # Forward pass with masked features
        output = model(masked_x, data.edge_index, data.edge_attr, data.batch)

        # Calculate loss only for masked features
        loss = criterion(output[mask], data.x[mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, criterion, device, mask_ratio=0.15):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)

            # Create a mask for node features (same as train)
            mask = torch.rand(data.x.shape) < mask_ratio
            masked_x = data.x.clone()
            masked_x[mask] = 0

            # Forward pass
            output = model(masked_x, data.edge_index, data.edge_attr, data.batch)

            # Calculate loss only for masked features
            loss = criterion(output[mask], data.x[mask])
            total_loss += loss.item()
    return total_loss / len(loader)

print("Defined train_epoch and evaluate functions.")

Defined train_epoch and evaluate functions.


In [26]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv

print("Modules imported successfully.")

Modules imported successfully.


In [27]:
class GraphTransformer(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads, edge_dim, num_layers=2, dropout_rate=0.5):
        super(GraphTransformer, self).__init__()
        self.dropout_rate = dropout_rate

        self.conv_layers = nn.ModuleList()
        # First TransformerConv layer
        self.conv_layers.append(TransformerConv(in_channels, hidden_channels, heads=heads, edge_dim=edge_dim, dropout=dropout_rate))

        # Additional TransformerConv layers (if num_layers > 1)
        for _ in range(num_layers - 1):
            self.conv_layers.append(TransformerConv(hidden_channels * heads, hidden_channels, heads=heads, edge_dim=edge_dim, dropout=dropout_rate))

        # Final linear layer
        self.lin = nn.Linear(hidden_channels * heads, out_channels)

    def forward(self, x, edge_index, edge_attr, batch=None):
        for i, conv_layer in enumerate(self.conv_layers):
            x = conv_layer(x, edge_index, edge_attr)
            x = F.relu(x) # Apply activation function
            if i < len(self.conv_layers) - 1: # Apply dropout to intermediate layers
                x = F.dropout(x, p=self.dropout_rate, training=self.training)

        # Pass through the final linear layer
        x = self.lin(x)
        return x

# Instantiate the model with example parameters
in_channels = len(feature_cols) # Node feature dimension
hidden_channels = 64
out_channels = len(feature_cols) # Output dimension matches input feature dimension for a reconstruction task example
heads = 4 # Number of attention heads
edge_dim = 1 # Edge attribute dimension (Euclidean distance)
num_layers = 2 # Number of TransformerConv layers

model = GraphTransformer(in_channels, hidden_channels, out_channels, heads, edge_dim, num_layers)

print(f"GraphTransformer model created with {num_layers} layers.")
print(model)


GraphTransformer model created with 2 layers.
GraphTransformer(
  (conv_layers): ModuleList(
    (0): TransformerConv(49, 64, heads=4)
    (1): TransformerConv(256, 64, heads=4)
  )
  (lin): Linear(in_features=256, out_features=49, bias=True)
)


In [29]:
import torch.nn as nn
import torch.optim as optim

# Instantiate the Mean Squared Error (MSE) loss function
criterion = nn.MSELoss()

# Define a learning rate
learning_rate = 1e-3

# Instantiate the Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

print("Loss function (criterion) and optimizer have been defined.")

Loss function (criterion) and optimizer have been defined.


In [30]:
import torch

# Check for GPU and set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device) # Move model to the selected device

# Define the number of epochs
epochs = 50

# Initialize DataLoaders for ego-graphs
batch_size = 32 # Assuming batch_size is defined from previous steps

train_ego_loader = DataLoader(train_ego_dataset, batch_size=batch_size, shuffle=True)
val_ego_loader = DataLoader(val_ego_dataset, batch_size=batch_size, shuffle=False)
test_ego_loader = DataLoader(test_ego_dataset, batch_size=batch_size, shuffle=False)

print(f"Training on: {device}")
print("Starting training and evaluation loop...")

for epoch in range(1, epochs + 1):
    train_loss = train_epoch(model, train_ego_loader, criterion, optimizer, device)
    val_loss = evaluate(model, val_ego_loader, criterion, device)

    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

print("Training and evaluation loop finished.")

# Optionally, run evaluation on the test set after training
test_loss = evaluate(model, test_ego_loader, criterion, device)
print(f'Test Loss: {test_loss:.4f}')

Training on: cuda
Starting training and evaluation loop...
Epoch: 001, Train Loss: 402471.7387, Val Loss: 374594.5991
Epoch: 002, Train Loss: 371664.9501, Val Loss: 349728.1523
Epoch: 003, Train Loss: 356608.5665, Val Loss: 344695.1324
Epoch: 004, Train Loss: 358464.5033, Val Loss: 338585.4149
Epoch: 005, Train Loss: 351820.2312, Val Loss: 332859.1095
Epoch: 006, Train Loss: 349812.9335, Val Loss: 325071.1947
Epoch: 007, Train Loss: 342862.2491, Val Loss: 326380.2203
Epoch: 008, Train Loss: 347715.5770, Val Loss: 340676.7844
Epoch: 009, Train Loss: 340715.7379, Val Loss: 325037.9451
Epoch: 010, Train Loss: 338340.6157, Val Loss: 335009.7777
Epoch: 011, Train Loss: 337164.2671, Val Loss: 324600.4767
Epoch: 012, Train Loss: 339205.7936, Val Loss: 325999.4618
Epoch: 013, Train Loss: 338924.9218, Val Loss: 326837.9510
Epoch: 014, Train Loss: 336175.3510, Val Loss: 322304.8555
Epoch: 015, Train Loss: 335511.9681, Val Loss: 322050.0723
Epoch: 016, Train Loss: 336451.3761, Val Loss: 318059.98

KeyboardInterrupt: 

### Explanation of Graph Transformer Model Inputs

The `GraphTransformer` model defined above, and more generally, many PyTorch Geometric models, expect specific input tensors to represent the graph structure and its features. Here's a detailed breakdown of the inputs:

1.  **`x` (Node Features)**:
    *   **Purpose**: This tensor represents the features associated with each node in the graph. In our context, these are the various marker expressions for each cell.
    *   **Expected Shape**: `[num_nodes, num_node_features]`, where `num_nodes` is the total number of nodes (cells) in the graph (or batch of graphs), and `num_node_features` is the dimensionality of each node's feature vector. For our model, `num_node_features` is `len(feature_cols)`, which is 62.
    *   **Expected Data Type**: `torch.Tensor` with `dtype=torch.float`.

2.  **`edge_index` (Graph Connectivity)**:
    *   **Purpose**: This tensor defines the connections (edges) between nodes in the graph. It's typically represented in a sparse COO (coordinate) format, indicating source and target nodes for each edge.
    *   **Expected Shape**: `[2, num_edges]`, where the first row contains the indices of the source nodes and the second row contains the indices of the target nodes for `num_edges` edges.
    *   **Expected Data Type**: `torch.Tensor` with `dtype=torch.long`.

3.  **`edge_attr` (Edge Attributes)**:
    *   **Purpose**: This tensor represents attributes associated with each edge. In our case, this is the Euclidean distance between connected cells.
    *   **Expected Shape**: `[num_edges, num_edge_features]`, where `num_edges` is the total number of edges and `num_edge_features` is the dimensionality of each edge's attribute vector. For our model, `num_edge_features` is 1 (the Euclidean distance).
    *   **Expected Data Type**: `torch.Tensor` with `dtype=torch.float`.

4.  **`batch` (Batch Vector)**:
    *   **Purpose**: When processing a batch of multiple graphs simultaneously (which is common in graph neural networks to leverage GPU parallelism), the `batch` tensor is used to identify which graph each node belongs to. It maps each node to its respective graph in the batched graph structure.
    *   **Expected Shape**: `[num_nodes]`, where `num_nodes` is the total number of nodes across all graphs in the batch.
    *   **Expected Data Type**: `torch.Tensor` with `dtype=torch.long`.
    *   **Note**: This `batch` vector is typically generated automatically by PyTorch Geometric's `DataLoader` when it aggregates multiple `Data` objects into a single `Batch` object for forwarding through the model. The model receives a single large graph that is a disconnected union of all individual graphs in the batch, and the `batch` vector helps the `TransformerConv` layers correctly perform computations within each individual graph.

In [31]:
val_loss = evaluate(model, val_ego_loader, criterion, device)
print(f'Validation Loss: {val_loss:.4f}')

Validation Loss: 333298.7409
