In [1]:
%pip install pytdc==1.1.12
from tdc.resource.pinnacle import PINNACLE
from tdc.benchmark_group import scdti_group

pinnacle = PINNACLE()
group = scdti_group.SCDTIGroup()

%pip install networkx
# get full sc_ppi graph
import pandas as pd
import networkx as nx
from collections import defaultdict

ppi = pinnacle.get_ppi()
scproteins = pinnacle.get_keys()
# scproteins.head()
protein_to_cells = defaultdict(set)

for _, row in scproteins.iterrows():
    protein_to_cells[row['target']].add(row['cell type'])

# make graph
G = nx.Graph()
for _, row in ppi.iterrows():
    protein1 = row['Protein A']
    protein2 = row['Protein B']
    cell_types1 = protein_to_cells[protein1]
    cell_types2 = protein_to_cells[protein2]

    if protein1 not in G:
        G.add_node(protein1, cell_types=list(cell_types1))
    if protein2 not in G:
        G.add_node(protein2, cell_types=list(cell_types2))
    G.add_edge(protein1, protein2)

# summarize graph
# 1. Preview first few nodes and their attributes
def preview_nodes(G, n=5):
    print("First", n, "nodes and their attributes:")
    for i, (node, attr) in enumerate(G.nodes(data=True)):
        if i >= n: break
        print(f"Node: {node}")
        print(f"Attributes: {attr}\n")

# 2. Preview first few edges
def preview_edges(G, n=5):
    print("First", n, "edges:")
    for i, (node1, node2, attr) in enumerate(G.edges(data=True)):
        if i >= n: break
        print(f"Edge: {node1} -- {node2}")
        print(f"Attributes: {attr}\n")

# 3. Get basic graph statistics
def preview_graph_stats(G):
    print("Graph Statistics:")
    print(f"Number of nodes: {G.number_of_nodes()}")
    print(f"Number of edges: {G.number_of_edges()}")
    print(f"Is directed: {G.is_directed()}")
    print(f"Is weighted: any('weight' in d for u,v,d in G.edges(data=True))")

    # Sample of node attributes available
    if G.nodes():
        sample_node = list(G.nodes(data=True))[0]
        print(f"\nSample node attributes: {list(sample_node[1].keys())}")

# Combined preview
def preview_graph(G, n=5):
    print("=== Graph Preview ===\n")
    preview_graph_stats(G)
    print("\n=== Sample Nodes ===\n")
    preview_nodes(G, n)
    print("=== Sample Edges ===\n")
    preview_edges(G, n)

    # Preview cell types if they exist
    sample_cell_types = set()
    for _, attr in list(G.nodes(data=True))[:n]:
        if 'cell_types' in attr:
            sample_cell_types.update(attr['cell_types'])
    if sample_cell_types:
        print("\n=== Sample Cell Types ===")
        print(list(sample_cell_types)[:5])

preview_graph(G)

Collecting pytdc==1.1.12
  Downloading pytdc-1.1.12.tar.gz (151 kB)
  Preparing metadata (setup.py) ... [?25l- \ done
[?25hCollecting accelerate==0.33.0 (from pytdc==1.1.12)
  Downloading accelerate-0.33.0-py3-none-any.whl.metadata (18 kB)
Collecting dataclasses<1.0,>=0.6 (from pytdc==1.1.12)
  Downloading dataclasses-0.6-py3-none-any.whl.metadata (3.0 kB)
Collecting datasets<2.20.0 (from pytdc==1.1.12)
  Downloading datasets-2.19.2-py3-none-any.whl.metadata (19 kB)
Collecting evaluate==0.4.2 (from pytdc==1.1.12)
  Downloading evaluate-0.4.2-py3-none-any.whl.metadata (9.3 kB)
Collecting fuzzywuzzy<1.0,>=0.18.0 (from pytdc==1.1.12)
  Downloading fuzzywuzzy-0.18.0-py2.py3-none-any.whl.metadata (4.9 kB)
Collecting huggingface_hub<1.0,>=0.20.3 (from pytdc==1.1.12)
  Downloading huggingface_hub-0.32.1-py3-none-any.whl.metadata (14 kB)
Collecting numpy<2.0.0,>=1.26.4 (from pytdc==1.1.12)
  Downloading numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 

Matplotlib is building the font cache; this may take a moment.
Downloading...
100%|██████████| 2.57M/2.57M [00:00<00:00, 8.14MiB/s]
Loading...
Downloading...
100%|██████████| 185k/185k [00:00<00:00, 4.43MiB/s]
Loading...
Downloading...
100%|██████████| 202M/202M [00:10<00:00, 19.4MiB/s] 
Downloading...
100%|██████████| 20.5M/20.5M [00:01<00:00, 19.6MiB/s]
Found local copy...
Loading...
Found local copy...
Loading...
Found local copy...
Found local copy...


In [2]:
def get_train_dev_test(seed=1):
  group = scdti_group.SCDTIGroup()
  train_val = group.get_train_valid_split(seed=seed)
  assert "train" in train_val, "no training set"
  assert "val" in train_val, "no validation set"
  assert len(train_val["train"]) > 0, "no entries in training set"
  tst = group.get_test(seed)["test"]
  return train_val["train"], train_val["val"], tst

preview_graph(G)

=== Graph Preview ===

Graph Statistics:
Number of nodes: 15461
Number of edges: 207640
Is directed: False
Is weighted: any('weight' in d for u,v,d in G.edges(data=True))

Sample node attributes: ['cell_types']

=== Sample Nodes ===

First 5 nodes and their attributes:
Node: FLNC
Attributes: {'cell_types': ['pericyte cell', 'hepatocyte', 'connective tissue cell', 'fast muscle cell', 'skeletal muscle satellite stem cell', 'smooth muscle cell', 'ciliary body', 'cardiac muscle cell', 'mesothelial cell', 'cardiac endothelial cell', 'artery endothelial cell', 'cell of skeletal muscle', 'fibroblast', 'vascular associated smooth muscle cell', 'ocular surface cell', 'endothelial cell of lymphatic vessel', 'tendon cell', 'bronchial smooth muscle cell', 'endothelial cell of artery', 'myometrial cell', 'fibroblast of breast', 'lymphatic endothelial cell', 'gut endothelial cell', 'fibroblast of cardiac tissue', 'medullary thymic epithelial cell']}

Node: SGCD
Attributes: {'cell_types': ['alveolar 

In [3]:
%pip install torch torch-geometric torch-scatter torch-sparse

import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv
from torch_geometric.utils import from_networkx
import numpy as np

# https://arxiv.org/abs/2105.14491
# ICLR 2022
class GATv2Predictor(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels=64,
        num_layers=2,
        heads=4,
        dropout=0.2
    ):
        super().__init__()
        self.num_layers = num_layers
        self.convs = torch.nn.ModuleList()

        # First layer
        self.convs.append(
            GATv2Conv(
                in_channels,
                hidden_channels,
                heads=heads,
                dropout=dropout
            )
        )

        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(
                GATv2Conv(
                    hidden_channels * heads,
                    hidden_channels,
                    heads=heads,
                    dropout=dropout
                )
            )

        # Output layer
        self.convs.append(
            GATv2Conv(
                hidden_channels * heads,
                2,  # Binary classification
                heads=1,
                concat=False,
                dropout=dropout
            )
        )

        self.dropout = dropout
        self.optimizer = torch.optim.Adam(
            self.parameters(),
            lr=0.005,
            weight_decay=5e-4
        )

    def forward(self, x, edge_index):
        for i in range(self.num_layers - 1):
            x = self.convs[i](x, edge_index)
            x = F.elu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.convs[-1](x, edge_index)
        return F.log_softmax(x, dim=1)

class SingleCellGATv2Benchmark:
    def __init__(
        self,
        hidden_channels=64,
        num_layers=4,
        heads=4,
        dropout=0.2,
        lr=0.001,
        epochs=200,
        device=None,
        cell_type=None
    ):
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.heads = heads
        self.dropout = dropout
        self.lr = lr
        self.epochs = epochs
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.cell_type = cell_type
        self.model = None

    def convert_to_pyg(self, G, target_labels, node_mapping=None):
        """Convert networkx graph to PyG format with node features"""
        if node_mapping is None:
            node_mapping = {node: i for i, node in enumerate(G.nodes())}

        # Convert graph to PyG format
        pyg_graph = from_networkx(G)

        # Create node features based on cell types and neighbor cell types
        unique_cell_types = set()
        for node in G.nodes():
            unique_cell_types.update(G.nodes[node]['cell_types'])
        cell_type_to_idx = {ct: idx for idx, ct in enumerate(sorted(unique_cell_types))}
        num_cell_types = len(unique_cell_types)

        # Initialize features: [own_cell_types]
        num_nodes = G.number_of_nodes()
        features = torch.zeros((num_nodes, num_cell_types), dtype=torch.float)

        for i, node in enumerate(G.nodes()):
            # Own cell types
            for cell_type in G.nodes[node]['cell_types']:
                features[i, cell_type_to_idx[cell_type]] = 1

        pyg_graph.x = features

        # Add node labels
        labels = torch.tensor([target_labels.get(node, -1) for node in G.nodes()])
        pyg_graph.y = labels
        pyg_graph.train_mask = torch.tensor([label!=-1 for label in labels])

        return pyg_graph, node_mapping

    def train_epoch(self, model, data, optimizer=None, criterion=None):
        criterion = torch.nn.CrossEntropyLoss() if criterion is None else criterion
        optmizer = model.optimizer if optimizer is None else optimizer
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        return loss.item(), optimizer, criterion, out

    def fit(self, G, target_labels, train_mask=None, val_mask=None):
        # Convert graph to PyG format
        data, node_mapping = self.convert_to_pyg(G, target_labels)
        data = data.to(self.device)

        # train_mask = train_mask.to(self.device)
        # val_mask = val_mask.to(self.device)

        # invert node mapping
        imapping = {v:k for k,v in node_mapping.items()}

        # merge masks with null label masks
        data.train_mask = torch.tensor([(imapping[idx] in train_mask) and data.train_mask[idx] for idx in range(len(data.train_mask))])
        data.val_mask = torch.tensor([(imapping[idx] in val_mask) and data.train_mask[idx] for idx in range(len(data.train_mask))])

        # Initialize model
        self.model = GATv2Predictor(
            in_channels=data.x.size(1),
            hidden_channels=self.hidden_channels,
            num_layers=self.num_layers,
            heads=self.heads,
            dropout=self.dropout
        ).to(self.device)

        # Initialize optimizer
        optimizer = self.model.optimizer

        # Training loop
        best_val_loss = float('inf')
        patience = 10
        patience_counter = 0

        # loss fn
        criterion = torch.nn.CrossEntropyLoss()

        # train
        self.model.train()

        for epoch in range(self.epochs):
            train_loss, optimizer, criterion, out = self.train_epoch(self.model, data, optimizer=optimizer, criterion=criterion)

            # # Validation
            # self.model.eval()
            with torch.no_grad():
                val_loss = criterion(
                    self.model(data.x, data.edge_index)[data.val_mask],
                    data.y[data.val_mask]
                ).item()

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

            if epoch % 10 == 0:
                print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        return self

    def predict(self, G, nodes):
        """
        Predict target labels for given nodes

        Args:
            G: networkx graph
            nodes: list of nodes to predict for

        Returns:
            pd.DataFrame: DataFrame with columns ['node', 'predicted_label', 'prediction_probability']
        """
        self.model.eval()
        data, node_mapping = self.convert_to_pyg(G, {node: 0 for node in G.nodes()})
        data = data.to(self.device)

        with torch.no_grad():
            out = self.model(data.x, data.edge_index)
            pred = out.max(1)[1]
            pred_proba = torch.exp(out)[:, 1]  # Probability of positive class

            # Create DataFrame only for input nodes
            results_df = pd.DataFrame({
                'node': nodes,
                'predicted_label': pred.cpu().numpy()[[list(G.nodes()).index(node) for node in nodes]],
                'prediction_probability': pred_proba.cpu().numpy()[[list(G.nodes()).index(node) for node in nodes]]
            })

            return results_df

def save_to_azure_blob(dataframe, filename):
    from azure.storage.blob import BlobServiceClient
    from io import StringIO

    # Set up Azure Blob Storage
    connect_str = "DefaultEndpointsProtocol=https;AccountName=mlstudio4221270580;AccountKey=akCuz92UW+pOzfmdb48FO9aFPrNFXMNtsWVqTdtqdZ+INQEA2i3qo43DSx8/DvjcQLVNi5aT6ZS2+AStk0ZEnQ==;EndpointSuffix=core.windows.net"
    blob_service_client = BlobServiceClient.from_connection_string(connect_str)

    # Define container and file
    container_name = "gatv2"
    blob_name = filename

    # Upload file
    blob_client = blob_service_client.get_blob_client(container=container_name, blob=blob_name)
    csv_buffer = StringIO()
    dataframe.to_csv(csv_buffer, index=False)
    csv_buffer.seek(0)  # Reset the pointer to the beginning of the buffer

    blob_client.upload_blob(csv_buffer.getvalue(), overwrite=True)

    print(f"File uploaded to Azure Blob Storage as {blob_name}")

out = []
for i in range(1,11):
  train, dev, test = get_train_dev_test(seed=i)
  full = pd.concat([train, dev, test], axis=0, ignore_index=True)
  full_ra = full[full["disease"] == "RA"]
  full_ibd = full[full["disease"] == "IBD"]
  olabels_ra = {x["name"]:int(x["y"]) for _, x in full_ra.iterrows()}
  olabels_ibd = {x["name"]:int(x["y"]) for _, x in full_ibd.iterrows()}
  benchmark = SingleCellGATv2Benchmark()
  # handle RA
  print("building RA model")
  train_mask = train[train["disease"] == "RA"]["name"].values
  val_mask = dev[dev["disease"] == "RA"]["name"].values
  test_mask = test[test["disease"] == "RA"]["name"].values
  benchmark.fit(G, olabels_ra, train_mask=train_mask, val_mask=val_mask)
  print("making RA predictions")
  out_ra = benchmark.predict(G, test_mask)
  test_ra = test[test["disease"] == "RA"]
  test_ra["preds"] = out_ra["predicted_label"].values
  test_ra["preds_proba"] = out_ra["prediction_probability"].values
  # handle IBD
  print("building IBD model")
  train_mask = train[train["disease"] == "IBD"]["name"].values
  val_mask = dev[dev["disease"] == "IBD"]["name"].values
  test_mask = test[test["disease"] == "IBD"]["name"].values
  benchmark.fit(G, olabels_ibd, train_mask=train_mask, val_mask=val_mask)
  print("making IBD predictions")
  out_ibd = benchmark.predict(G, test_mask)
  test_ibd = test[test["disease"] == "IBD"]
  test_ibd["preds"] = out_ibd["predicted_label"].values
  test_ibd["preds_proba"] = out_ibd["prediction_probability"].values
  # construct preds map
  print("constructing predictions map")
  preds_map = defaultdict(int)
  for _, row in test_ra.iterrows():
    preds_map[("RA", row["cell_type_label"], row["name"])] = row["preds"]
  for _, row in test_ibd.iterrows():
    preds_map[("IBD", row["cell_type_label"], row["name"])] = row["preds"]
  test["preds"] = test.apply(lambda x: preds_map[("RA", x["cell_type_label"], x["name"])] if x["disease"] == "RA" else preds_map[("IBD", x["cell_type_label"], x["name"])], axis=1)
  # Save the DataFrame as a CSV file to blob store
  filename = f'pinnacle_GATv2_seed{i}_camera.csv'
  save_to_azure_blob(test, filename)
  print("calling evaluation group")
  res = group.evaluate(test, seed=i)
  out.append(res)
  print("done and appender. proceeding to next seed.")

out

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
Collecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
  Preparing metadata (setup.py) ... [?25l- done
[?25hCollecting torch-sparse
  Downloading torch_sparse-0.6.18.tar.gz (209 kB)
  Preparing metadata (setup.py) ... [?25l- done
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: torch-scatter, torch-sparse
  Building wheel for torch-scatter (setup.py) ... [?25l- \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - \ | / - done
[?25h  Created wheel for torch-scatter: filename=torch_scatter-2.1.2-cp310-cp310-linux_x86_64.whl size=3986552 sha256=248deec39b8fea68b61fa512027d69f485d2230c13fa2d61eb38c8cc3fe2671d
  Stored in directory: /home/azure

Found local copy...
Loading...
Found local copy...
Loading...
Found local copy...
Found local copy...
100%|██████████| 5.20M/5.20M [00:00<00:00, 24.0MiB/s]
Extracting zip file...
100%|██████████| 5.20M/5.20M [00:00<00:00, 24.2MiB/s]
Extracting zip file...
Done!
100%|██████████| 5.20M/5.20M [00:00<00:00, 16.3MiB/s]
Extracting zip file...
Done!
Downloading...
100%|██████████| 5.20M/5.20M [00:00<00:00, 20.4MiB/s]
Extracting zip file...
Done!
Downloading...
100%|██████████| 2.57M/2.57M [00:00<00:00, 10.1MiB/s]
Loading...
Downloading...
100%|██████████| 185k/185k [00:00<00:00, 4.15MiB/s]
Loading...
Downloading...
100%|██████████| 202M/202M [00:08<00:00, 23.1MiB/s] 
Downloading...
100%|██████████| 20.5M/20.5M [00:01<00:00, 20.5MiB/s]
Downloading...
100%|██████████| 5.26M/5.26M [00:00<00:00, 13.9MiB/s]
Extracting zip file...
Done!
100%|██████████| 5.26M/5.26M [00:00<00:00, 23.2MiB/s]
Extracting zip file...
100%|██████████| 5.26M/5.26M [00:00<00:00, 25.8MiB/s]
Extracting zip file...
Downloadin

[{'RA': 0.33, 'IBD': 0.38000000000000006},
 {'RA': 0.39, 'IBD': 0.33},
 {'RA': 0.35000000000000003, 'IBD': 0.3500000000000001},
 {'RA': 0.43999999999999995, 'IBD': 0.36000000000000004},
 {'RA': 0.36000000000000004, 'IBD': 0.3000000000000001},
 {'RA': 0.7483333333333333, 'IBD': 0.33000000000000007},
 {'RA': 0.3200000000000001, 'IBD': 0.2900000000000001},
 {'RA': 0.38, 'IBD': 0.32},
 {'RA': 0.37, 'IBD': 0.8191666666666666},
 {'RA': 0.3200000000000001, 'IBD': 0.3000000000000001}]