In [86]:
import numpy as np
import pandas as pd
import torch
import torch_geometric.transforms as T
import os

from collections import defaultdict
from torch_geometric.data import HeteroData
import torch.nn as nn
import torch.optim as optim
from torch_geometric.nn import MessagePassing, GCNConv, HeteroConv, GraphConv, Linear
import torch.nn.functional as F
from torch import Tensor
import matplotlib.pyplot as plt

In [87]:
from sentence_transformers import SentenceTransformer

In [88]:
import pickle

In [89]:
from neo4j import GraphDatabase

URI = "neo4j://localhost:7687"
AUTH = ("test", "666666")

with GraphDatabase.driver(URI, auth=AUTH) as driver:

    driver.verify_connectivity()


def fetch_data(query, parameters=None, **kwargs):
    with driver.session(database="mtg-wilcox") as session:
        result = session.run(query, parameters, **kwargs)
        return pd.DataFrame([r.values() for r in result], columns=result.keys())

  driver.verify_connectivity()


In [91]:
class SequenceEncoder(object):
    # The 'SequenceEncoder' encodes raw column strings into embeddings using a sentence transformer.
    def __init__(self, model_name="all-MiniLM-L6-v2", device=None):
        self.device = device
        self.model = SentenceTransformer(model_name, device=device)

    @torch.no_grad()
    def __call__(self, df):
        x = self.model.encode(
            df.values,
            show_progress_bar=True,
            convert_to_tensor=True,
            device=self.device,
        )
        return x.cpu()


class GenresEncoder(object):
    # The 'GenreEncoder' splits the raw column strings by 'sep' and converts
    # individual elements to categorical labels.
    def __init__(self, sep="|"):
        self.sep = sep

    def __call__(self, df):
        genres = set(g for col in df.values for g in col.split(self.sep))
        mapping = {genre: i for i, genre in enumerate(genres)}

        x = torch.zeros(len(df), len(mapping))
        for i, col in enumerate(df.values):
            for genre in col.split(self.sep):
                x[i, mapping[genre]] = 1
        return x


class IdentityEncoder(object):
    # The 'IdentityEncoder' takes the raw column values and converts them to
    # PyTorch tensors.
    def __init__(self, dtype=None, is_list=False):
        self.dtype = dtype
        self.is_list = is_list

    def __call__(self, df):
        if self.is_list:
            return torch.stack([torch.tensor(el) for el in df.values])
        return torch.from_numpy(df.values).to(self.dtype)

In [92]:
def load_node(
    cypher, index_col, encoders=None, category_col=None, parameters=None, **kwargs
):
    # Execute the cypher query and retrieve data from Neo4j
    df = fetch_data(cypher, parameters, **kwargs)
    df.set_index(index_col, inplace=True)
    # Define node mapping
    mapping = {index: i for i, index in enumerate(df.index.unique())}
    # Define node features
    x = None
    if encoders is not None:
        xs = [encoder(df[col]) for col, encoder in encoders.items()]
        x = torch.cat(xs, dim=-1)

    y = None
    if category_col is not None:
        # Get unique categories and map to numerical labels
        categories = df[category_col].unique()
        category_to_idx = {cat: idx for idx, cat in enumerate(sorted(categories))}

        # Map category column to numerical labels
        y = df[category_col].map(category_to_idx).values
        y = torch.tensor(y, dtype=torch.long)  # length: n_nodes
        return x, y, mapping, category_to_idx
    return x, mapping

In [93]:
def load_edge(
    cypher,
    src_index_col,
    src_mapping,
    dst_index_col,
    dst_mapping,
    encoders=None,
    parameters=None,
    **kwargs
):
    # Execute the cypher query and retrieve data from Neo4j
    df = fetch_data(cypher, parameters, **kwargs)
    # Define edge index
    src = [src_mapping[index] for index in df[src_index_col]]
    dst = [dst_mapping[index] for index in df[dst_index_col]]
    edge_index = torch.tensor([src, dst])

    # Define edge features
    edge_attr = None
    if encoders is not None:
        edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]
        edge_attr = torch.cat(edge_attrs, dim=-1)

    return edge_index, edge_attr

In [94]:
ct_query = """
MATCH (ct:CellType)
RETURN ct.id AS cell_type_name_species, ct.broad_taxo_cs AS broad_taxo_cs
"""

ct_x, ct_y, ct_mapping, y_mapping = load_node(
    ct_query, index_col="cell_type_name_species", category_col="broad_taxo_cs"
)

# ct_x has no node features
# ct_mapping is just a dictionary mapping cell_type_name to index numbers

In [95]:
ct_x

In [96]:
ct_y

tensor([0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3])

In [97]:
ct_mapping

{'Astro_H.sapiens': 0,
 'Oligo_H.sapiens': 1,
 'VLMC_H.sapiens': 2,
 'Micro-PVM_H.sapiens': 3,
 'OPC_H.sapiens': 4,
 'Endo_H.sapiens': 5,
 'L5-6 NP_H.sapiens': 6,
 'L6 CT_H.sapiens': 7,
 'L6b_H.sapiens': 8,
 'L5 ET_H.sapiens': 9,
 'Pax6_H.sapiens': 10,
 'Vip_H.sapiens': 11,
 'Sncg_H.sapiens': 12,
 'Lamp5_Lhx6_H.sapiens': 13,
 'Lamp5_H.sapiens': 14,
 'Sst Chodl_H.sapiens': 15,
 'Pvalb_H.sapiens': 16,
 'Sst_H.sapiens': 17,
 'Chandelier_H.sapiens': 18,
 'L5 IT_H.sapiens': 19,
 'L4 IT_H.sapiens': 20,
 'L2-3 IT_H.sapiens': 21,
 'L6 IT Car3_H.sapiens': 22,
 'L6 IT_H.sapiens': 23,
 'Astro_M.mulatta': 24,
 'Oligo_M.mulatta': 25,
 'VLMC_M.mulatta': 26,
 'Micro-PVM_M.mulatta': 27,
 'OPC_M.mulatta': 28,
 'Endo_M.mulatta': 29,
 'L5-6 NP_M.mulatta': 30,
 'L6 CT_M.mulatta': 31,
 'L6b_M.mulatta': 32,
 'L5 ET_M.mulatta': 33,
 'Pax6_M.mulatta': 34,
 'Vip_M.mulatta': 35,
 'Sncg_M.mulatta': 36,
 'Lamp5_Lhx6_M.mulatta': 37,
 'Lamp5_M.mulatta': 38,
 'Sst Chodl_M.mulatta': 39,
 'Pvalb_M.mulatta': 40,
 'Sst_

In [None]:
with open(f"all_sp_heterodata/mtg_all_sp_wilcox_data_with_og.pkl", "wb") as f:
    pickle.dump(ct_mapping, f)

In [99]:
120 / 5

24.0

In [100]:
y_mapping

{'Astro': 0,
 'CGE-derived': 1,
 'Endo': 2,
 'IT': 3,
 'MGE-derived': 4,
 'Micro-PVM': 5,
 'Non-IT': 6,
 'OPC': 7,
 'Oligo': 8,
 'VLMC': 9}

In [101]:
# now we give gene nodes encoded features based on GOs, KEGGs, etc.

In [15]:
gene_query = """
MATCH (gene:Gene)
RETURN gene.id as gene_id, gene.gos as gos, gene.pfams as pfams, gene.description as description, gene.kegg_pathway as kegg_pathways
"""
gene_x, gene_mapping = load_node(
    gene_query,
    index_col="gene_id",
    encoders={
        "description": SequenceEncoder(),
        "gos": GenresEncoder(sep=","),
        "pfams": GenresEncoder(sep=","),
        "kegg_pathways": GenresEncoder(sep=","),
    },
)

Batches:   0%|          | 0/2893 [00:00<?, ?it/s]

In [16]:
gene_x.shape

torch.Size([92549, 29954])

In [17]:
gene_mapping

{'ARRB2_P.troglodytes': 0,
 'CRY2_P.troglodytes': 1,
 'ARRB1_P.troglodytes': 2,
 'CRY1_P.troglodytes': 3,
 'IGF1R_P.troglodytes': 4,
 'CAMK2D_P.troglodytes': 5,
 'MAPK8_P.troglodytes': 6,
 'HSP90AA1_P.troglodytes': 7,
 'FYN_P.troglodytes': 8,
 'AGO2_P.troglodytes': 9,
 'PRKCH_P.troglodytes': 10,
 'AP2A1_P.troglodytes': 11,
 'AP2S1_P.troglodytes': 12,
 'MAPK9_P.troglodytes': 13,
 'TGFBR1_P.troglodytes': 14,
 'SPAST_P.troglodytes': 15,
 'RAB7A_P.troglodytes': 16,
 'PAFAH1B1_P.troglodytes': 17,
 'SIRT4_P.troglodytes': 18,
 'SNX1_P.troglodytes': 19,
 'TGFBR2_P.troglodytes': 20,
 'RPA2_P.troglodytes': 21,
 'TMEM30A_P.troglodytes': 22,
 'EIF3J_P.troglodytes': 23,
 'HGS_P.troglodytes': 24,
 'PSMB5_P.troglodytes': 25,
 'TRIM5_P.troglodytes': 26,
 'ATG5_P.troglodytes': 27,
 'EIF3H_P.troglodytes': 28,
 'EIF3M_P.troglodytes': 29,
 'EIF3K_P.troglodytes': 30,
 'CACNA1D_P.troglodytes': 31,
 'RAB11B_P.troglodytes': 32,
 'FEN1_P.troglodytes': 33,
 'SNX5_P.troglodytes': 34,
 'POFUT1_P.troglodytes': 35,

In [18]:
len(gene_mapping)

92549

In [None]:
marker_query = """
MATCH (g:Gene)-[r:GeneWilcoxMarkerInCellType]->(ct:CellType)
WHERE r.avg_log2fc >= 4
RETURN g.id as gene_id, ct.id as cell_type_name_species, r.avg_log2fc as avg_log2fc
"""

edge_index, edge_weights = load_edge(
    marker_query,
    src_index_col="gene_id",
    src_mapping=gene_mapping,  # the two index mappings were used for this
    dst_index_col="cell_type_name_species",
    dst_mapping=ct_mapping,
    encoders={
        "avg_log2fc": IdentityEncoder(dtype=torch.float32)
    },  # remember to set the correct dtype for identity encoding the edge weight
)

In [None]:
edge_weights

tensor([4.0142, 4.0404, 4.0512,  ..., 6.7956, 7.1142, 7.1893])

In [None]:
edge_index

tensor([[85348, 85347, 85346,  ..., 23867, 75808, 74932],
        [    0,     0,     0,  ...,   119,   119,   119]])

In [None]:
edge_index.shape

torch.Size([2, 17520])

In [None]:
marker_query_2 = """
MATCH (g:Gene)-[r:GeneWilcoxMarkerInCellType]->(ct:CellType)
WHERE r.avg_log2fc >= 4 
RETURN g.id as gene_id, ct.id as cell_type_name_species, r.avg_log2fc as avg_log2fc
"""

edge_index_2, edge_weights_2 = load_edge(
    marker_query_2,
    src_index_col="cell_type_name_species",
    src_mapping=ct_mapping,  # the two index mappings were used for this
    dst_index_col="gene_id",
    dst_mapping=gene_mapping,
    encoders={
        "avg_log2fc": IdentityEncoder(dtype=torch.float32)
    },  # remember to set the correct dtype for identity encoding the edge weight
)

In [None]:
edge_index_2
# should be two rows-switched edge_index

tensor([[    0,     0,     0,  ...,   119,   119,   119],
        [85348, 85347, 85346,  ..., 23867, 75808, 74932]])

In [None]:
edge_weights_2

# should be the same as edge_weights 1

tensor([4.0142, 4.0404, 4.0512,  ..., 6.7956, 7.1142, 7.1893])

In [19]:
og_query = """
MATCH (og:OrthologousGroup)
WHERE og.eggnog_dataset_name = 'Mammalia'
RETURN og.orthologous_group_id as og_id
"""

og_x, og_mapping = load_node(og_query, index_col="og_id")

In [20]:
og_mapping

{'3J40A': 0,
 '3J7G6': 1,
 '3J605': 2,
 '3J1XG': 3,
 '3JEDX': 4,
 '3J2JM': 5,
 '3J5E0': 6,
 '3JAF8': 7,
 '3J90S': 8,
 '3JDI6': 9,
 '3J6VY': 10,
 '3J2VQ': 11,
 '3JDG2': 12,
 '3J4QZ': 13,
 '3J9JY': 14,
 '3J88D': 15,
 '3JEHT': 16,
 '3J704': 17,
 '3J2G1': 18,
 '3JF4R': 19,
 '3JFKB': 20,
 '3JBWT': 21,
 '3JCP3': 22,
 '3JE5B': 23,
 '3J916': 24,
 '3J1Q8': 25,
 '3JBVQ': 26,
 '3J4RQ': 27,
 '3JANP': 28,
 '3JAIF': 29,
 '3J9W3': 30,
 '3JE2P': 31,
 '3JBA6': 32,
 '3J66F': 33,
 '3JB0Q': 34,
 '3J7NT': 35,
 '3J842': 36,
 '3JC84': 37,
 '3J7JC': 38,
 '3J33G': 39,
 '3JEAP': 40,
 '3J2AP': 41,
 '3J3YN': 42,
 '3JDRG': 43,
 '3J2QU': 44,
 '3JBPA': 45,
 '3J72P': 46,
 '3JF16': 47,
 '3J5EU': 48,
 '3J9YB': 49,
 '3J4J3': 50,
 '3J6GF': 51,
 '3JCIK': 52,
 '3JHDM': 53,
 '3JACD': 54,
 '3J5DX': 55,
 '3J2E7': 56,
 '3J6NR': 57,
 '3J5A2': 58,
 '3J5HX': 59,
 '3JAG2': 60,
 '3JFCU': 61,
 '3JEBB': 62,
 '3J3CJ': 63,
 '3J5PB': 64,
 '3J8S2': 65,
 '3JB9G': 66,
 '3J6V6': 67,
 '3J2MI': 68,
 '3J915': 69,
 '3J436': 70,
 '3J5TK': 71,
 '

In [21]:
gene_og_query = """
MATCH (g:Gene)-[r:GeneInOrthologousGroup]->(og:OrthologousGroup)
WHERE og.eggnog_dataset_name = 'Mammalia' 
RETURN g.id as gene_id, og.orthologous_group_id as og_id
"""

edge_index_og, edge_weights_og = load_edge(
    gene_og_query,
    src_index_col="gene_id",
    src_mapping=gene_mapping,  # the two index mappings were used for this
    dst_index_col="og_id",
    dst_mapping=og_mapping,
)

In [22]:
edge_index_og

tensor([[43834, 42081,  8725,  ..., 59087, 59089, 59113],
        [    0,     0,     0,  ..., 14945, 14946, 14947]])

In [23]:
edge_weights_og

In [24]:
edge_index_og.shape

torch.Size([2, 49548])

In [25]:
gene_og_query = """
MATCH (g:Gene)-[r:GeneInOrthologousGroup]->(og:OrthologousGroup)
WHERE og.eggnog_dataset_name = 'Mammalia'
RETURN g.id as gene_id, og.orthologous_group_id as og_id
"""

edge_index_og_2, edge_weights_og_2 = load_edge(
    gene_og_query,
    src_index_col="og_id",
    src_mapping=og_mapping,  # the two index mappings were used for this
    dst_index_col="gene_id",
    dst_mapping=gene_mapping,
)
# again add reverse edge

In [26]:
edge_index_og_2

tensor([[    0,     0,     0,  ..., 14945, 14946, 14947],
        [43834, 42081,  8725,  ..., 59087, 59089, 59113]])

In [27]:
data = HeteroData()

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [6]:
#  The edge_label tensor holds the ground truth labels that you want the model to predict for specific edges.
#  Used for edge prediction tasks

In [7]:
# Add user node features for message passing:
data["cell_type"].x = torch.eye(len(ct_mapping), device=device)
data["cell_type"].y = ct_y
# Add movie node features
data["gene"].x = gene_x

data["orthologous_group"].x = torch.eye(len(og_mapping), device=device)
data

NameError: name 'ct_mapping' is not defined

In [8]:
# Add ratings between users and movies
data["gene", "is_wilcox_marker_of", "cell_type"].edge_index = edge_index
data["gene", "is_wilcox_marker_of", "cell_type"].edge_weights = edge_weights

data["cell_type", "rev_is_wilcox_marker_of", "gene"].edge_index = edge_index_2
data["cell_type", "rev_is_wilcox_marker_of", "gene"].edge_weights = edge_weights_2

data["gene", "is_in", "orthologous_group"].edge_index = edge_index_og

data["orthologous_group", "rev_is_in", "gene"].edge_index = edge_index_og_2

data.to(device, non_blocking=True)

NameError: name 'edge_index' is not defined

In [None]:
# I also need a reverse edge from cell type to genes
# just for the HGT model

In [None]:
data["cell_type"].x  # is just a diagonal matrix - no features yet

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

In [38]:
data.node_types

['cell_type', 'gene', 'orthologous_group']

In [None]:
torch.save(data, "mtg_all_sp_wilcox_data_with_og.pt")

In [None]:
data["cell_type"].y.shape

torch.Size([120])

In [41]:
ct_mapping

{'Astro_H.sapiens': 0,
 'Oligo_H.sapiens': 1,
 'VLMC_H.sapiens': 2,
 'Micro-PVM_H.sapiens': 3,
 'OPC_H.sapiens': 4,
 'Endo_H.sapiens': 5,
 'L5-6 NP_H.sapiens': 6,
 'L6 CT_H.sapiens': 7,
 'L6b_H.sapiens': 8,
 'L5 ET_H.sapiens': 9,
 'Pax6_H.sapiens': 10,
 'Vip_H.sapiens': 11,
 'Sncg_H.sapiens': 12,
 'Lamp5_Lhx6_H.sapiens': 13,
 'Lamp5_H.sapiens': 14,
 'Sst Chodl_H.sapiens': 15,
 'Pvalb_H.sapiens': 16,
 'Sst_H.sapiens': 17,
 'Chandelier_H.sapiens': 18,
 'L5 IT_H.sapiens': 19,
 'L4 IT_H.sapiens': 20,
 'L2-3 IT_H.sapiens': 21,
 'L6 IT Car3_H.sapiens': 22,
 'L6 IT_H.sapiens': 23,
 'Astro_M.mulatta': 24,
 'Oligo_M.mulatta': 25,
 'VLMC_M.mulatta': 26,
 'Micro-PVM_M.mulatta': 27,
 'OPC_M.mulatta': 28,
 'Endo_M.mulatta': 29,
 'L5-6 NP_M.mulatta': 30,
 'L6 CT_M.mulatta': 31,
 'L6b_M.mulatta': 32,
 'L5 ET_M.mulatta': 33,
 'Pax6_M.mulatta': 34,
 'Vip_M.mulatta': 35,
 'Sncg_M.mulatta': 36,
 'Lamp5_Lhx6_M.mulatta': 37,
 'Lamp5_M.mulatta': 38,
 'Sst Chodl_M.mulatta': 39,
 'Pvalb_M.mulatta': 40,
 'Sst_

In [None]:
split = {
    "train_idx": np.array(range(0, 96, 1)),
    "test_idx": np.array(range(96, 120, 1)),
}

In [43]:
np.array(range(0, 96, 1))

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
       68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
       85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])

In [44]:
np.array(range(96, 120, 1))

array([ 96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108,
       109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119])

In [None]:
# Now since i want to train on human and test on pt, i need to create the train val test split
for name in ["train", "test"]:
    idx = split[f"{name}_idx"]
    idx = torch.from_numpy(idx).to(torch.long)
    mask = torch.zeros(data["cell_type"].num_nodes, dtype=torch.bool)
    mask[idx] = True
    data["cell_type"][f"{name}_mask"] = mask

    # this train test split can happen on cluster before training so only need to save the data once

In [None]:
data["cell_type"]

{'x': tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]]), 'y': tensor([0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3]), 'train_mask': tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
      

In [47]:
ct_mapping

{'Astro_H.sapiens': 0,
 'Oligo_H.sapiens': 1,
 'VLMC_H.sapiens': 2,
 'Micro-PVM_H.sapiens': 3,
 'OPC_H.sapiens': 4,
 'Endo_H.sapiens': 5,
 'L5-6 NP_H.sapiens': 6,
 'L6 CT_H.sapiens': 7,
 'L6b_H.sapiens': 8,
 'L5 ET_H.sapiens': 9,
 'Pax6_H.sapiens': 10,
 'Vip_H.sapiens': 11,
 'Sncg_H.sapiens': 12,
 'Lamp5_Lhx6_H.sapiens': 13,
 'Lamp5_H.sapiens': 14,
 'Sst Chodl_H.sapiens': 15,
 'Pvalb_H.sapiens': 16,
 'Sst_H.sapiens': 17,
 'Chandelier_H.sapiens': 18,
 'L5 IT_H.sapiens': 19,
 'L4 IT_H.sapiens': 20,
 'L2-3 IT_H.sapiens': 21,
 'L6 IT Car3_H.sapiens': 22,
 'L6 IT_H.sapiens': 23,
 'Astro_M.mulatta': 24,
 'Oligo_M.mulatta': 25,
 'VLMC_M.mulatta': 26,
 'Micro-PVM_M.mulatta': 27,
 'OPC_M.mulatta': 28,
 'Endo_M.mulatta': 29,
 'L5-6 NP_M.mulatta': 30,
 'L6 CT_M.mulatta': 31,
 'L6b_M.mulatta': 32,
 'L5 ET_M.mulatta': 33,
 'Pax6_M.mulatta': 34,
 'Vip_M.mulatta': 35,
 'Sncg_M.mulatta': 36,
 'Lamp5_Lhx6_M.mulatta': 37,
 'Lamp5_M.mulatta': 38,
 'Sst Chodl_M.mulatta': 39,
 'Pvalb_M.mulatta': 40,
 'Sst_

In [None]:
torch.save(data, "mtg_all_sp_gg_test_wilcox_data_with_og.pt")

In [49]:
data

HeteroData(
  cell_type={
    x=[120, 120],
    y=[120],
    train_mask=[120],
    test_mask=[120],
  },
  gene={ x=[92549, 29954] },
  orthologous_group={ x=[14948, 14948] },
  (gene, is_wilcox_marker_of, cell_type)={
    edge_index=[2, 17520],
    edge_weights=[17520],
  },
  (cell_type, rev_is_wilcox_marker_of, gene)={
    edge_index=[2, 17520],
    edge_weights=[17520],
  },
  (gene, is_in, orthologous_group)={ edge_index=[2, 49548] },
  (orthologous_group, rev_is_in, gene)={ edge_index=[2, 49548] }
)

## From here I start to create the dataset that uses the actual ct_name as classes, to see how well the predictions can be


In [None]:
ct_name_query = """
MATCH (ct:CellType)
RETURN ct.id AS cell_type_name_species, ct.cell_type_name AS cell_type_name
"""

ct_name_x, ct_name_y, ct_name_mapping, y_name_mapping = load_node(
    ct_name_query, index_col="cell_type_name_species", category_col="cell_type_name"
)

In [59]:
ct_name_mapping

{'Astro_H.sapiens': 0,
 'Oligo_H.sapiens': 1,
 'VLMC_H.sapiens': 2,
 'Micro-PVM_H.sapiens': 3,
 'OPC_H.sapiens': 4,
 'Endo_H.sapiens': 5,
 'L5-6 NP_H.sapiens': 6,
 'L6 CT_H.sapiens': 7,
 'L6b_H.sapiens': 8,
 'L5 ET_H.sapiens': 9,
 'Pax6_H.sapiens': 10,
 'Vip_H.sapiens': 11,
 'Sncg_H.sapiens': 12,
 'Lamp5_Lhx6_H.sapiens': 13,
 'Lamp5_H.sapiens': 14,
 'Sst Chodl_H.sapiens': 15,
 'Pvalb_H.sapiens': 16,
 'Sst_H.sapiens': 17,
 'Chandelier_H.sapiens': 18,
 'L5 IT_H.sapiens': 19,
 'L4 IT_H.sapiens': 20,
 'L2-3 IT_H.sapiens': 21,
 'L6 IT Car3_H.sapiens': 22,
 'L6 IT_H.sapiens': 23,
 'Astro_M.mulatta': 24,
 'Oligo_M.mulatta': 25,
 'VLMC_M.mulatta': 26,
 'Micro-PVM_M.mulatta': 27,
 'OPC_M.mulatta': 28,
 'Endo_M.mulatta': 29,
 'L5-6 NP_M.mulatta': 30,
 'L6 CT_M.mulatta': 31,
 'L6b_M.mulatta': 32,
 'L5 ET_M.mulatta': 33,
 'Pax6_M.mulatta': 34,
 'Vip_M.mulatta': 35,
 'Sncg_M.mulatta': 36,
 'Lamp5_Lhx6_M.mulatta': 37,
 'Lamp5_M.mulatta': 38,
 'Sst Chodl_M.mulatta': 39,
 'Pvalb_M.mulatta': 40,
 'Sst_

In [None]:
data_name = HeteroData()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
#  The edge_label tensor holds the ground truth labels that you want the model to predict for specific edges.
#  Used for edge prediction tasks
# Add user node features for message passing:
data_name["cell_type"].x = torch.eye(len(ct_name_mapping), device=device)
data_name["cell_type"].y = ct_name_y
# Add movie node features
data_name["gene"].x = gene_x

data_name["orthologous_group"].x = torch.eye(len(og_mapping), device=device)
data_name

cpu


HeteroData(
  cell_type={
    x=[120, 120],
    y=[120],
  },
  gene={ x=[92549, 29954] },
  orthologous_group={ x=[14948, 14948] }
)

In [None]:
# Add ratings between users and movies
data_name["gene", "is_wilcox_marker_of", "cell_type"].edge_index = edge_index
data_name["gene", "is_wilcox_marker_of", "cell_type"].edge_weights = edge_weights

data_name["cell_type", "rev_is_wilcox_marker_of", "gene"].edge_index = edge_index_2
data_name["cell_type", "rev_is_wilcox_marker_of", "gene"].edge_weights = edge_weights_2

data_name["gene", "is_in", "orthologous_group"].edge_index = edge_index_og

data_name["orthologous_group", "rev_is_in", "gene"].edge_index = edge_index_og_2

data_name.to(device, non_blocking=True)
# I also need a reverse edge from cell type to genes
# just for the HGT model
data_name["cell_type"].x  # is just a diagonal matrix - no features yet

torch.save(data_name, "mtg_all_sp_wilcox_data_with_og_ct_name.pt")

In [None]:
data_name["cell_type"].y

tensor([ 0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,
         1,  6,  4,  3, 10,  9,  0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23,
        19, 13, 12, 21, 18, 20,  1,  6,  4,  3, 10,  9,  0, 16, 22, 14, 15,  2,
         7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,  1,  6,  4,  3, 10,  9,
         0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,
         1,  6,  4,  3, 10,  9,  0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23,
        19, 13, 12, 21, 18, 20,  1,  6,  4,  3, 10,  9])

In [None]:
split = {
    "train_idx": np.array(range(0, 96, 1)),
    "test_idx": np.array(range(96, 120, 1)),
}
np.array(range(0, 96, 1))
np.array(range(96, 120, 1))
# Now since i want to train on human and test on pt, i need to create the train val test split
for name in ["train", "test"]:
    idx = split[f"{name}_idx"]
    idx = torch.from_numpy(idx).to(torch.long)
    mask = torch.zeros(data["cell_type"].num_nodes, dtype=torch.bool)
    mask[idx] = True
    data_name["cell_type"][f"{name}_mask"] = mask
data_name["cell_type"]
ct_mapping
torch.save(data_name, "mtg_all_sp_gg_test_wilcox_data_with_og_ct_name.pt")

In [None]:
data_name["cell_type"]

{'x': tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]]), 'y': tensor([ 0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,
         1,  6,  4,  3, 10,  9,  0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23,
        19, 13, 12, 21, 18, 20,  1,  6,  4,  3, 10,  9,  0, 16, 22, 14, 15,  2,
         7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,  1,  6,  4,  3, 10,  9,
         0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,
         1,  6,  4,  3, 10,  9,  0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23,
        19, 13, 12, 21, 18, 20,  1,  6,  4,  3, 10,  9]), 'train_mask': tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  Tru

## From here I am making the graph using reduced number of species


In [102]:
species_lists = {
    "no_hs": ["P.troglodytes", "M.mulatta", "C.jacchus", "G.gorilla"],
    "no_pt": ["H.sapiens", "M.mulatta", "C.jacchus", "G.gorilla"],
    "no_mm": ["H.sapiens", "P.troglodytes", "C.jacchus", "G.gorilla"],
    "no_cj": ["H.sapiens", "P.troglodytes", "M.mulatta", "G.gorilla"],
    "no_gg": ["H.sapiens", "P.troglodytes", "M.mulatta", "C.jacchus"],
}

In [None]:
for case, species_list in species_lists.items():
    print(case)
    print(species_list)

    ct_sub_query = """
    MATCH (ct:CellType)
    WHERE ct.species_of_origin IN $species_list
    RETURN ct.id AS cell_type_name_species, ct.cell_type_name AS cell_type_name
    """

    ct_sub_x, ct_sub_y, ct_sub_mapping, y_sub_mapping = load_node(
        ct_sub_query,
        parameters={"species_list": species_list},
        index_col="cell_type_name_species",
        category_col="cell_type_name",
    )
    marker_sub_query = """
    MATCH (g:Gene)-[r:GeneWilcoxMarkerInCellType]->(ct:CellType)
    WHERE r.avg_log2fc >= 4 AND ct.species_of_origin IN $species_list
    RETURN g.id as gene_id, ct.id as cell_type_name_species, r.avg_log2fc as avg_log2fc
    """

    edge_index_sub, edge_weights_sub = load_edge(
        marker_sub_query,
        src_index_col="gene_id",
        src_mapping=gene_mapping,  # the two index mappings were used for this
        dst_index_col="cell_type_name_species",
        dst_mapping=ct_sub_mapping,
        encoders={
            "avg_log2fc": IdentityEncoder(dtype=torch.float32)
        },  # remember to set the correct dtype for identity encoding the edge weight
        parameters={"species_list": species_list},
    )

    marker_sub_query_2 = """
    MATCH (g:Gene)-[r:GeneWilcoxMarkerInCellType]->(ct:CellType)
    WHERE r.avg_log2fc >= 4 AND ct.species_of_origin IN $species_list
    RETURN g.id as gene_id, ct.id as cell_type_name_species, r.avg_log2fc as avg_log2fc
    """

    edge_index_sub_2, edge_weights_sub_2 = load_edge(
        marker_sub_query_2,
        src_index_col="cell_type_name_species",
        src_mapping=ct_sub_mapping,  # the two index mappings were used for this
        dst_index_col="gene_id",
        dst_mapping=gene_mapping,
        encoders={
            "avg_log2fc": IdentityEncoder(dtype=torch.float32)
        },  # remember to set the correct dtype for identity encoding the edge weight
        parameters={"species_list": species_list},
    )
    ct_sub_mapping
    data_sub = HeteroData()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    #  The edge_label tensor holds the ground truth labels that you want the model to predict for specific edges.
    #  Used for edge prediction tasks
    # Add user node features for message passing:
    data_sub["cell_type"].x = torch.eye(len(ct_sub_mapping), device=device)
    data_sub["cell_type"].y = ct_sub_y

    data_sub["gene"].x = gene_x

    data_sub["orthologous_group"].x = torch.eye(len(og_mapping), device=device)
    data_sub
    case
    # Add ratings between users and movies
    data_sub["gene", "is_wilcox_marker_of", "cell_type"].edge_index = edge_index_sub
    data_sub["gene", "is_wilcox_marker_of", "cell_type"].edge_weights = edge_weights_sub

    data_sub["cell_type", "rev_is_wilcox_marker_of", "gene"].edge_index = (
        edge_index_sub_2
    )
    data_sub["cell_type", "rev_is_wilcox_marker_of", "gene"].edge_weights = (
        edge_weights_sub_2
    )

    data_sub["gene", "is_in", "orthologous_group"].edge_index = edge_index_og

    data_sub["orthologous_group", "rev_is_in", "gene"].edge_index = edge_index_og_2

    data_sub.to(device, non_blocking=True)
    # I also need a reverse edge from cell type to genes
    # just for the HGT model
    data_sub["cell_type"].x  # is just a diagonal matrix - no features yet
    print(f"start writing {case} pytorch data, unsplit")
    torch.save(data_sub, f"mtg_all_sp_wilcox_data_with_og_ct_name_{case}.pt")
    print(f"finish writing {case} pytorch data, unsplit")

    with open(
        f"mtg_all_sp_wilcox_data_with_og_ct_name_{case}_ct_mapping.pkl", "wb"
    ) as f:
        pickle.dump(ct_sub_mapping, f)

    with open(
        f"mtg_all_sp_wilcox_data_with_og_ct_name_{case}_gene_mapping.pkl", "wb"
    ) as f:
        pickle.dump(gene_mapping, f)

    with open(
        f"mtg_all_sp_wilcox_data_with_og_ct_name_{case}_og_mapping.pkl", "wb"
    ) as f:
        pickle.dump(og_mapping, f)

no_hs
['P.troglodytes', 'M.mulatta', 'C.jacchus', 'G.gorilla']
cpu
start writing no_hs pytorch data, unsplit
finish writing no_hs pytorch data, unsplit
no_pt
['H.sapiens', 'M.mulatta', 'C.jacchus', 'G.gorilla']
cpu
start writing no_pt pytorch data, unsplit
finish writing no_pt pytorch data, unsplit
no_mm
['H.sapiens', 'P.troglodytes', 'C.jacchus', 'G.gorilla']
cpu
start writing no_mm pytorch data, unsplit
finish writing no_mm pytorch data, unsplit
no_cj
['H.sapiens', 'P.troglodytes', 'M.mulatta', 'G.gorilla']
cpu
start writing no_cj pytorch data, unsplit
finish writing no_cj pytorch data, unsplit
no_gg
['H.sapiens', 'P.troglodytes', 'M.mulatta', 'C.jacchus']
cpu
start writing no_gg pytorch data, unsplit
finish writing no_gg pytorch data, unsplit


In [43]:
data_sub["cell_type"].y

tensor([ 0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,
         1,  6,  4,  3, 10,  9,  0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23,
        19, 13, 12, 21, 18, 20,  1,  6,  4,  3, 10,  9,  0, 16, 22, 14, 15,  2,
         7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,  1,  6,  4,  3, 10,  9,
         0, 16, 22, 14, 15,  2,  7,  8, 11,  5, 17, 23, 19, 13, 12, 21, 18, 20,
         1,  6,  4,  3, 10,  9])

In [44]:
data_sub

HeteroData(
  cell_type={
    x=[96, 96],
    y=[96],
  },
  gene={ x=[92549, 29954] },
  orthologous_group={ x=[14948, 14948] },
  (gene, is_wilcox_marker_of, cell_type)={
    edge_index=[2, 13487],
    edge_weights=[13487],
  },
  (cell_type, rev_is_wilcox_marker_of, gene)={
    edge_index=[2, 13487],
    edge_weights=[13487],
  },
  (gene, is_in, orthologous_group)={ edge_index=[2, 49548] },
  (orthologous_group, rev_is_in, gene)={ edge_index=[2, 49548] }
)

In [None]:
species_origin = dict(
    pd.Series(list(ct_sub_mapping.keys())).replace(".*_", "", regex=True)
)

In [72]:
all_species = set(species_origin.values())

In [None]:
for species_test_now in all_species:
    print(species_test_now)

P.troglodytes
H.sapiens
C.jacchus
M.mulatta


In [82]:
split = {
    "train_idx": np.array(
        [k for k in species_origin.keys() if species_origin[k] == species_test_now]
    ),
    "test_idx": np.array(
        [k for k in species_origin.keys() if species_origin[k] != species_test_now]
    ),
}

In [84]:
split

{'train_idx': array([24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
        41, 42, 43, 44, 45, 46, 47]),
 'test_idx': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
        58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74,
        75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91,
        92, 93, 94, 95])}

## This is for 4 species but to predict cell type family


In [103]:
for case, species_list in species_lists.items():
    print(case)
    print(species_list)

    ct_sub_query = """
    MATCH (ct:CellType)
    WHERE ct.species_of_origin IN $species_list
    RETURN ct.id AS cell_type_name_species,  ct.broad_taxo_cs AS broad_taxo_cs
    """

    ct_sub_x, ct_sub_y, ct_sub_mapping, y_sub_mapping = load_node(
        ct_sub_query,
        parameters={"species_list": species_list},
        index_col="cell_type_name_species",
        category_col="broad_taxo_cs",
    )

    marker_sub_query = """
    MATCH (g:Gene)-[r:GeneWilcoxMarkerInCellType]->(ct:CellType)
    WHERE r.avg_log2fc >= 4 AND ct.species_of_origin IN $species_list
    RETURN g.id as gene_id, ct.id as cell_type_name_species, r.avg_log2fc as avg_log2fc
    """

    edge_index_sub, edge_weights_sub = load_edge(
        marker_sub_query,
        src_index_col="gene_id",
        src_mapping=gene_mapping,  # the two index mappings were used for this
        dst_index_col="cell_type_name_species",
        dst_mapping=ct_sub_mapping,
        encoders={
            "avg_log2fc": IdentityEncoder(dtype=torch.float32)
        },  # remember to set the correct dtype for identity encoding the edge weight
        parameters={"species_list": species_list},
    )

    marker_sub_query_2 = """
    MATCH (g:Gene)-[r:GeneWilcoxMarkerInCellType]->(ct:CellType)
    WHERE r.avg_log2fc >= 4 AND ct.species_of_origin IN $species_list
    RETURN g.id as gene_id, ct.id as cell_type_name_species, r.avg_log2fc as avg_log2fc
    """

    edge_index_sub_2, edge_weights_sub_2 = load_edge(
        marker_sub_query_2,
        src_index_col="cell_type_name_species",
        src_mapping=ct_sub_mapping,  # the two index mappings were used for this
        dst_index_col="gene_id",
        dst_mapping=gene_mapping,
        encoders={
            "avg_log2fc": IdentityEncoder(dtype=torch.float32)
        },  # remember to set the correct dtype for identity encoding the edge weight
        parameters={"species_list": species_list},
    )
    ct_sub_mapping
    data_sub = HeteroData()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    #  The edge_label tensor holds the ground truth labels that you want the model to predict for specific edges.
    #  Used for edge prediction tasks
    # Add user node features for message passing:
    data_sub["cell_type"].x = torch.eye(len(ct_sub_mapping), device=device)
    data_sub["cell_type"].y = ct_sub_y

    print(ct_sub_y, flush=True)

    data_sub["gene"].x = gene_x

    data_sub["orthologous_group"].x = torch.eye(len(og_mapping), device=device)
    data_sub
    case
    # Add ratings between users and movies
    data_sub["gene", "is_wilcox_marker_of", "cell_type"].edge_index = edge_index_sub
    data_sub["gene", "is_wilcox_marker_of", "cell_type"].edge_weights = edge_weights_sub

    data_sub["cell_type", "rev_is_wilcox_marker_of", "gene"].edge_index = (
        edge_index_sub_2
    )
    data_sub["cell_type", "rev_is_wilcox_marker_of", "gene"].edge_weights = (
        edge_weights_sub_2
    )

    data_sub["gene", "is_in", "orthologous_group"].edge_index = edge_index_og

    data_sub["orthologous_group", "rev_is_in", "gene"].edge_index = edge_index_og_2

    data_sub.to(device, non_blocking=True)
    # I also need a reverse edge from cell type to genes
    # just for the HGT model
    data_sub["cell_type"].x  # is just a diagonal matrix - no features yet
    print(f"start writing {case} pytorch data, unsplit")
    torch.save(data_sub, f"mtg_all_sp_wilcox_data_with_og_{case}.pt")
    print(f"finish writing {case} pytorch data, unsplit")

    with open(f"mtg_all_sp_wilcox_data_with_og_{case}_ct_mapping.pkl", "wb") as f:
        pickle.dump(ct_sub_mapping, f)

    with open(f"mtg_all_sp_wilcox_data_with_og_{case}_gene_mapping.pkl", "wb") as f:
        pickle.dump(gene_mapping, f)

    with open(f"mtg_all_sp_wilcox_data_with_og_{case}_og_mapping.pkl", "wb") as f:
        pickle.dump(og_mapping, f)

no_hs
['P.troglodytes', 'M.mulatta', 'C.jacchus', 'G.gorilla']
cpu
tensor([0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3])
start writing no_hs pytorch data, unsplit
finish writing no_hs pytorch data, unsplit
no_pt
['H.sapiens', 'M.mulatta', 'C.jacchus', 'G.gorilla']
cpu
tensor([0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3,
        0, 8, 9, 5, 7, 2, 6, 6, 6, 6, 1, 1, 1, 1, 1, 4, 4, 4, 4, 3, 3, 3, 3, 3])
start writing no_pt pytorch data, unsplit
finish writing no_pt pytorch data, unsplit
no_mm
['H.sapiens', 'P.troglodytes', 'C.jacchus', 'G.goril

## From here I try to predict different species
