In [1]:
import numpy as np
import pandas as pd
import torch
import torch_geometric.transforms as T
import os

from collections import defaultdict
from torch_geometric.data import HeteroData
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import MessagePassing, GCNConv, HeteroConv, GraphConv, Linear
import torch.nn.functional as F
from torch import Tensor
import matplotlib.pyplot as plt

In [2]:
from sentence_transformers import SentenceTransformer

In [3]:
import pickle

In [4]:
from neo4j import GraphDatabase

# Connection settings for the local Neo4j instance. Each value can be
# overridden through an environment variable so credentials do not have to be
# edited into the notebook; the defaults are the values used originally.
URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687")
AUTH = (
    os.environ.get("NEO4J_USER", "test"),
    os.environ.get("NEO4J_PASSWORD", "666666"),
)

# Keep the driver open for the rest of the notebook: `fetch_data` opens its
# sessions on this module-level `driver`, so it must not be created inside a
# `with` block (which closes the driver as soon as the block exits).
# Call `driver.close()` once all queries have been run.
driver = GraphDatabase.driver(URI, auth=AUTH)
driver.verify_connectivity()


def fetch_data(query, parameters=None, **kwargs):
    """Run a Cypher query on the "mtg-wilcox" database and return a DataFrame.

    `parameters` and any extra keyword arguments are forwarded unchanged to
    `session.run`. The result has one row per returned record and one column
    per key reported by the query result.
    """
    with driver.session(database="mtg-wilcox") as session:
        result = session.run(query, parameters, **kwargs)
        # Materialise every record while the session is still open.
        rows = [record.values() for record in result]
        columns = result.keys()
    return pd.DataFrame(rows, columns=columns)

  driver.verify_connectivity()


In [5]:
class SequenceEncoder:
    """Turn a column of raw strings into sentence-transformer embeddings."""

    def __init__(self, model_name="all-MiniLM-L6-v2", device=None):
        # The same device is used to load the model and, later, to encode.
        self.device = device
        self.model = SentenceTransformer(model_name, device=device)

    @torch.no_grad()
    def __call__(self, df):
        """Encode `df.values` and return the embeddings as a CPU tensor."""
        embeddings = self.model.encode(
            df.values,
            device=self.device,
            convert_to_tensor=True,
            show_progress_bar=True,
        )
        # Always hand back a CPU tensor, whatever device the model ran on.
        return embeddings.cpu()


class GenresEncoder(object):
    """Multi-hot encode a column of `sep`-delimited label strings.

    Every raw value is split on `sep`; each distinct label gets one output
    column. Labels are numbered in sorted order so that the column layout is
    the same on every run (iterating a plain `set` of strings is not).
    """

    def __init__(self, sep="|"):
        self.sep = sep

    def __call__(self, df):
        """Return a float tensor of shape (len(df), number of distinct labels)."""
        # Split once and reuse for both the vocabulary and the encoding pass.
        rows = [value.split(self.sep) for value in df.values]
        labels = sorted({label for row in rows for label in row})
        mapping = {label: i for i, label in enumerate(labels)}

        x = torch.zeros(len(rows), len(mapping))
        for i, row in enumerate(rows):
            for label in row:
                x[i, mapping[label]] = 1
        return x


class IdentityEncoder:
    """Convert raw column values to a PyTorch tensor without re-encoding them.

    With `is_list=True` every element is turned into its own tensor and the
    results are stacked; otherwise the column's NumPy array is wrapped
    directly and cast to `dtype`. `dtype` is only applied in the latter case.
    """

    def __init__(self, dtype=None, is_list=False):
        self.dtype, self.is_list = dtype, is_list

    def __call__(self, df):
        values = df.values
        if not self.is_list:
            return torch.from_numpy(values).to(self.dtype)
        per_row = [torch.tensor(item) for item in values]
        return torch.stack(per_row)

In [6]:
def load_node(
    cypher, index_col, encoders=None, category_col=None, parameters=None, **kwargs
):
    """Load one node type from Neo4j.

    Args:
        cypher: Cypher query to run through `fetch_data`.
        index_col: Result column holding the node id; used to build `mapping`.
        encoders: Optional dict of column name -> callable turning that column
            into a tensor. Outputs are concatenated along the last dimension.
        category_col: Optional result column used as the class label.
        parameters, **kwargs: Forwarded to `fetch_data`.

    Returns:
        `(x, mapping)` when `category_col` is None, otherwise
        `(x, y, mapping, category_to_idx)`. `x` is None when no encoders are
        given; `mapping` maps node id -> row position; `y` is a long tensor of
        labels; `category_to_idx` maps each category (sorted) to its label.
    """
    # Execute the cypher query and retrieve data from Neo4j
    df = fetch_data(cypher, parameters, **kwargs).set_index(index_col)
    # Keep a single row per node id (first occurrence wins) so that the rows
    # of `x` / `y` line up with the positions assigned in `mapping`.
    df = df[~df.index.duplicated(keep="first")]

    # Define node mapping
    mapping = {index: i for i, index in enumerate(df.index)}

    # Define node features (an empty encoder dict means "no features")
    x = None
    if encoders:
        xs = [encoder(df[col]) for col, encoder in encoders.items()]
        x = torch.cat(xs, dim=-1)

    if category_col is None:
        return x, mapping

    # Number the categories in sorted order so labels are reproducible
    categories = sorted(df[category_col].unique())
    category_to_idx = {cat: idx for idx, cat in enumerate(categories)}
    y = torch.tensor(
        [category_to_idx[cat] for cat in df[category_col]], dtype=torch.long
    )  # length: n_nodes
    return x, y, mapping, category_to_idx

In [7]:
def load_edge(
    cypher,
    src_index_col,
    src_mapping,
    dst_index_col,
    dst_mapping,
    encoders=None,
    parameters=None,
    **kwargs
):
    """Load one edge type from Neo4j.

    Args:
        cypher: Cypher query to run through `fetch_data`.
        src_index_col / dst_index_col: Result columns holding the source and
            destination node ids.
        src_mapping / dst_mapping: Dicts translating those ids to node indices.
            An id missing from its mapping raises `KeyError`.
        encoders: Optional dict of column name -> callable turning that column
            into a tensor. Outputs are concatenated along the last dimension.
        parameters, **kwargs: Forwarded to `fetch_data`.

    Returns:
        `(edge_index, edge_attr)`: a long tensor of shape (2, n_edges) and the
        concatenated edge features, or None when no encoders are given.
    """
    # Execute the cypher query and retrieve data from Neo4j
    df = fetch_data(cypher, parameters, **kwargs)

    # Define edge index
    src = [src_mapping[index] for index in df[src_index_col]]
    dst = [dst_mapping[index] for index in df[dst_index_col]]
    # Force an integer dtype: with no edges the nested empty lists would
    # otherwise be inferred as float32, which is not a valid index tensor.
    edge_index = torch.tensor([src, dst], dtype=torch.long)

    # Define edge features (an empty encoder dict means "no features")
    edge_attr = None
    if encoders:
        edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]
        edge_attr = torch.cat(edge_attrs, dim=-1)

    return edge_index, edge_attr

In [None]:
# Load every CellType node. `ct.id` (returned as cell_type_name_species) is the
# node index column; `ct.cell_type_name` is used as the classification label.
ct_query = """
MATCH (ct:CellType)
RETURN ct.id AS cell_type_name_species, ct.cell_type_name AS cell_type_name
"""

ct_x, ct_y, ct_mapping, y_mapping = load_node(
    ct_query, index_col="cell_type_name_species", category_col="cell_type_name"
)

# ct_x is None: no encoders were passed, so there are no node features.
# ct_y holds one integer class label per cell-type node.
# ct_mapping maps each cell_type_name_species id to its node index.
# y_mapping maps each cell_type_name to its integer class label.

In [9]:
# No output expected: ct_x is None because load_node was called without encoders.
ct_x

In [10]:
# One integer class label per cell-type node (values index into y_mapping).
ct_y

tensor([ 0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,
         1,  6,  4,  3, 10,  9,  0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23,
        19, 13, 12, 21, 18, 20,  1,  6,  4,  3, 10,  9,  0, 16, 22, 14, 15,  2,
         7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,  1,  6,  4,  3, 10,  9,
         0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,
         1,  6,  4,  3, 10,  9,  0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23,
        19, 13, 12, 21, 18, 20,  1,  6,  4,  3, 10,  9])

In [11]:
# Cell-type id (cell type name with a species suffix) -> node index.
ct_mapping

{'Astro_H.sapiens': 0,
 'Oligo_H.sapiens': 1,
 'VLMC_H.sapiens': 2,
 'Micro-PVM_H.sapiens': 3,
 'OPC_H.sapiens': 4,
 'Endo_H.sapiens': 5,
 'L5-6 NP_H.sapiens': 6,
 'L6 CT_H.sapiens': 7,
 'L6b_H.sapiens': 8,
 'L5 ET_H.sapiens': 9,
 'Pax6_H.sapiens': 10,
 'Vip_H.sapiens': 11,
 'Sncg_H.sapiens': 12,
 'Lamp5_Lhx6_H.sapiens': 13,
 'Lamp5_H.sapiens': 14,
 'Sst Chodl_H.sapiens': 15,
 'Pvalb_H.sapiens': 16,
 'Sst_H.sapiens': 17,
 'Chandelier_H.sapiens': 18,
 'L5 IT_H.sapiens': 19,
 'L4 IT_H.sapiens': 20,
 'L2-3 IT_H.sapiens': 21,
 'L6 IT Car3_H.sapiens': 22,
 'L6 IT_H.sapiens': 23,
 'Astro_M.mulatta': 24,
 'Oligo_M.mulatta': 25,
 'VLMC_M.mulatta': 26,
 'Micro-PVM_M.mulatta': 27,
 'OPC_M.mulatta': 28,
 'Endo_M.mulatta': 29,
 'L5-6 NP_M.mulatta': 30,
 'L6 CT_M.mulatta': 31,
 'L6b_M.mulatta': 32,
 'L5 ET_M.mulatta': 33,
 'Pax6_M.mulatta': 34,
 'Vip_M.mulatta': 35,
 'Sncg_M.mulatta': 36,
 'Lamp5_Lhx6_M.mulatta': 37,
 'Lamp5_M.mulatta': 38,
 'Sst Chodl_M.mulatta': 39,
 'Pvalb_M.mulatta': 40,
 'Sst_

In [None]:
# Persist the cell-type id -> node index mapping so indices can be decoded later.
ct_mapping_path = (
    "all_sp_heterodata_only_gene_cell_edges/mtg_all_sp_wilcox_data_ct_mapping.pkl"
)
# open() does not create missing parent directories, so make sure the output
# folder exists before writing (needed on a fresh checkout).
os.makedirs(os.path.dirname(ct_mapping_path), exist_ok=True)
with open(ct_mapping_path, "wb") as f:
    pickle.dump(ct_mapping, f)

In [13]:
# Scratch check: 120 cell-type nodes / 5 = 24, the number of distinct labels in
# y_mapping (presumably 5 species with 24 cell types each -- TODO confirm).
120 / 5

24.0

In [14]:
# Cell type name -> integer class label (categories are sorted before numbering).
y_mapping

{'Astro': 0,
 'Chandelier': 1,
 'Endo': 2,
 'L2-3 IT': 3,
 'L4 IT': 4,
 'L5 ET': 5,
 'L5 IT': 6,
 'L5-6 NP': 7,
 'L6 CT': 8,
 'L6 IT': 9,
 'L6 IT Car3': 10,
 'L6b': 11,
 'Lamp5': 12,
 'Lamp5_Lhx6': 13,
 'Micro-PVM': 14,
 'OPC': 15,
 'Oligo': 16,
 'Pax6': 17,
 'Pvalb': 18,
 'Sncg': 19,
 'Sst': 20,
 'Sst Chodl': 21,
 'VLMC': 22,
 'Vip': 23}

In [None]:
# Persist the cell type name -> class label mapping so labels can be decoded later.
ct_y_mapping_path = (
    "all_sp_heterodata_only_gene_cell_edges/mtg_all_sp_wilcox_data_ct_y_mapping.pkl"
)
# open() does not create missing parent directories, so make sure the output
# folder exists before writing (needed on a fresh checkout).
os.makedirs(os.path.dirname(ct_y_mapping_path), exist_ok=True)
with open(ct_y_mapping_path, "wb") as f:
    pickle.dump(y_mapping, f)

In [None]:
# Load only the Gene node ids here; no node features are attached at this stage.
# ESM2 embeddings are to be added later because they are stored on the cluster.
gene_query = """
MATCH (gene:Gene)
RETURN gene.id as gene_id
"""
gene_x, gene_mapping = load_node(
    gene_query,
    index_col="gene_id",
)

In [17]:
# No output expected: gene_x is None because load_node was called without encoders.
gene_x

In [18]:
# Gene id (gene symbol with a species suffix) -> node index.
# NOTE(review): this displays the whole dict (92,549 entries according to the
# len() cell that follows); consider previewing only a few items instead.
gene_mapping

{'ARRB2_P.troglodytes': 0,
 'CRY2_P.troglodytes': 1,
 'ARRB1_P.troglodytes': 2,
 'CRY1_P.troglodytes': 3,
 'IGF1R_P.troglodytes': 4,
 'CAMK2D_P.troglodytes': 5,
 'MAPK8_P.troglodytes': 6,
 'HSP90AA1_P.troglodytes': 7,
 'FYN_P.troglodytes': 8,
 'AGO2_P.troglodytes': 9,
 'PRKCH_P.troglodytes': 10,
 'AP2A1_P.troglodytes': 11,
 'AP2S1_P.troglodytes': 12,
 'MAPK9_P.troglodytes': 13,
 'TGFBR1_P.troglodytes': 14,
 'SPAST_P.troglodytes': 15,
 'RAB7A_P.troglodytes': 16,
 'PAFAH1B1_P.troglodytes': 17,
 'SIRT4_P.troglodytes': 18,
 'SNX1_P.troglodytes': 19,
 'TGFBR2_P.troglodytes': 20,
 'RPA2_P.troglodytes': 21,
 'TMEM30A_P.troglodytes': 22,
 'EIF3J_P.troglodytes': 23,
 'HGS_P.troglodytes': 24,
 'PSMB5_P.troglodytes': 25,
 'TRIM5_P.troglodytes': 26,
 'ATG5_P.troglodytes': 27,
 'EIF3H_P.troglodytes': 28,
 'EIF3M_P.troglodytes': 29,
 'EIF3K_P.troglodytes': 30,
 'CACNA1D_P.troglodytes': 31,
 'RAB11B_P.troglodytes': 32,
 'FEN1_P.troglodytes': 33,
 'SNX5_P.troglodytes': 34,
 'POFUT1_P.troglodytes': 35,

In [19]:
# Number of distinct gene nodes.
len(gene_mapping)

92549

In [None]:
# Persist the gene id -> node index mapping so indices can be decoded later.
gene_mapping_path = (
    "all_sp_heterodata_only_gene_cell_edges/mtg_all_sp_wilcox_data_gene_mapping.pkl"
)
# open() does not create missing parent directories, so make sure the output
# folder exists before writing (needed on a fresh checkout).
os.makedirs(os.path.dirname(gene_mapping_path), exist_ok=True)
with open(gene_mapping_path, "wb") as f:
    pickle.dump(gene_mapping, f)

In [26]:
# Gene -> cell type marker edges. Only relationships with avg_log2fc >= 3 are
# kept; avg_log2fc is returned so it can be used as the edge weight.
marker_query = """
MATCH (g:Gene)-[r:GeneWilcoxMarkerInCellType]->(ct:CellType)
WHERE r.avg_log2fc >= 3
RETURN g.id as gene_id, ct.id as cell_type_name_species, r.avg_log2fc as avg_log2fc
"""

edge_index, edge_weights = load_edge(
    marker_query,
    src_index_col="gene_id",
    src_mapping=gene_mapping,  # translates gene ids to gene node indices
    dst_index_col="cell_type_name_species",
    dst_mapping=ct_mapping,  # translates cell-type ids to cell-type node indices
    encoders={
        "avg_log2fc": IdentityEncoder(dtype=torch.float32)
    },  # remember to set the correct dtype for identity encoding the edge weight
)

In [27]:
# avg_log2fc of every marker edge, as float32 (one value per edge).
edge_weights

tensor([3.0007, 3.0012, 3.0025,  ..., 6.7956, 7.1142, 7.1893])

In [28]:
# Row 0: gene node indices (source); row 1: cell-type node indices (target).
edge_index

tensor([[85449, 85448, 33421,  ..., 23867, 75808, 74932],
        [    0,     0,     0,  ...,   119,   119,   119]])

In [29]:
# (2, number of marker edges)
edge_index.shape

torch.Size([2, 32540])

In [30]:
# The reverse (cell type -> gene) relation contains exactly the same markers as
# the forward (gene -> cell type) relation loaded earlier, so derive it from the
# tensors already in memory instead of sending the identical query to Neo4j a
# second time. This also guarantees that both directions list the edges in the
# same order, which a second query without ORDER BY does not.
marker_query_2 = marker_query  # kept so existing references to this name still work

# Swap the two rows: row 0 becomes cell-type indices, row 1 gene indices.
edge_index_2 = edge_index.flip(0)
# Independent copy, so in-place edits to one direction do not affect the other.
edge_weights_2 = edge_weights.clone()

In [31]:
edge_index_2
# Expected: edge_index with its two rows swapped (cell-type indices first, gene indices second).

tensor([[    0,     0,     0,  ...,   119,   119,   119],
        [85449, 85448, 33421,  ..., 23867, 75808, 74932]])

In [32]:
edge_weights_2

# Expected: identical to edge_weights (same edges, opposite direction).

tensor([3.0007, 3.0012, 3.0025,  ..., 6.7956, 7.1142, 7.1893])

In [33]:
# Empty heterogeneous graph container; node and edge stores are filled in below.
data = HeteroData()

In [34]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [6]:
#  The edge_label tensor holds the ground truth labels that you want the model to predict for specific edges.
#  Used for edge prediction tasks

In [36]:
# Cell-type nodes: one-hot (identity) features plus the class labels.
data["cell_type"].x = torch.eye(len(ct_mapping), device=device)
data["cell_type"].y = ct_y
# Gene nodes: one-hot (identity) features.
# NOTE(review): this is a dense 92549 x 92549 matrix (see the shape printed
# below), i.e. roughly 34 GB assuming the default float32 dtype -- consider a
# sparse identity or a learnable embedding instead.
data["gene"].x = torch.eye(len(gene_mapping), device=device)
data

HeteroData(
  cell_type={
    x=[120, 120],
    y=[120],
  },
  gene={ x=[92549, 92549] }
)

In [None]:
# Gene -> cell type marker edges, with avg_log2fc stored as `edge_weights`.
data["gene", "is_wilcox_marker_of", "cell_type"].edge_index = edge_index
data["gene", "is_wilcox_marker_of", "cell_type"].edge_weights = edge_weights

# The same edges in the reverse direction (cell type -> gene).
data["cell_type", "rev_is_wilcox_marker_of", "gene"].edge_index = edge_index_2
data["cell_type", "rev_is_wilcox_marker_of", "gene"].edge_weights = edge_weights_2

data

HeteroData(
  cell_type={
    x=[120, 120],
    y=[120],
  },
  gene={ x=[92549, 92549] },
  (gene, is_wilcox_marker_of, cell_type)={
    edge_index=[2, 32540],
    edge_weights=[32540],
  },
  (cell_type, rev_is_wilcox_marker_of, gene)={
    edge_index=[2, 32540],
    edge_weights=[32540],
  }
)

In [None]:
# Move the graph's tensors to the selected device (CPU in the recorded run).
data.to(device, non_blocking=True)

HeteroData(
  cell_type={
    x=[120, 120],
    y=[120],
  },
  gene={ x=[92549, 92549] },
  (gene, is_wilcox_marker_of, cell_type)={
    edge_index=[2, 32540],
    edge_weights=[32540],
  },
  (cell_type, rev_is_wilcox_marker_of, gene)={
    edge_index=[2, 32540],
    edge_weights=[32540],
  }
)

In [None]:
# Serialise the assembled HeteroData object for later training runs.
heterodata_path = "all_sp_heterodata_only_gene_cell_edges/mtg_all_sp_wilcox_heterodata_only_gene_cell_edges.pt"
# torch.save does not create missing parent directories, so make sure the
# output folder exists before writing (needed on a fresh checkout).
os.makedirs(os.path.dirname(heterodata_path), exist_ok=True)
torch.save(
    data,
    heterodata_path,
)