# GEARS Graphs Creation

### Packages

In [1]:
import torch
import numpy as np
import pandas as pd
import networkx as nx
from tqdm import tqdm
import scanpy as sc
import pickle
import sys, os
import requests
from torch_geometric.data import Data
from zipfile import ZipFile
import tarfile
import matplotlib.pyplot as plt

## Co-Expression

In [2]:
# Load the gene-name -> integer-id mapping used to index the graph nodes.
# NOTE(review): this is the file written by the "Create Gene Mappings" cell of the
# Gene-Ontology pipeline further down, so that cell must have been run at least once.
with open("/scratch/jeremy/data/graphs/raw/gene_to_id.pkl", 'rb') as handle:
    node_map = pickle.load(handle)

In [11]:
# Rebuild the co-expression graph tensors from the cached edge-list CSV.
# NOTE(review): GeneSimNetwork is defined in the "GEARS Graph Class" cell further down;
# on a fresh kernel that cell has to be run before this one.
edge_list = pd.read_csv("/scratch/jeremy/data/graphs/mca_GEARS_kidney_lung_train/0.2_20_co_expression_network.csv")
co_expr_network = GeneSimNetwork(edge_list, node_map)
# Integer node-id pairs, one column per edge.
G_coexpress = co_expr_network.edge_index
# The 'importance' value of each edge.
G_coexpress_weight = co_expr_network.edge_weight

In [12]:
# Sanity check: edge index is laid out as (2, number of edges).
G_coexpress.shape

torch.Size([2, 46092])

In [13]:
# Sanity check: one weight per edge, so the length matches the edge count above.
G_coexpress_weight.shape

torch.Size([46092])

### GEARS Method

In [6]:
def np_pearson_cor(x, y):
    """Pearson correlation between every column of ``x`` and every column of ``y``.

    Both inputs are centred column-wise (``axis=0``). Entry ``[i, j]`` of the
    result is the correlation of column ``i`` of ``x`` with column ``j`` of
    ``y``. A zero-variance column makes the denominator zero, so the
    corresponding entries come out as NaN (the clamping below leaves NaN as is).
    """
    x_centered = x - x.mean(axis=0)
    y_centered = y - y.mean(axis=0)
    x_sumsq = (x_centered * x_centered).sum(axis=0)
    y_sumsq = (y_centered * y_centered).sum(axis=0)
    cross_products = np.matmul(x_centered.transpose(), y_centered)
    denominator = np.sqrt(np.outer(x_sumsq, y_sumsq))
    # Floating-point rounding can push values marginally outside [-1, 1]; clamp them back.
    return np.clip(cross_products / denominator, -1.0, 1.0)

def get_coexpression_network_from_train(adata, threshold, k, data_path,
                                        data_name):
    """Build (or load from cache) the co-expression edge list of ``adata``.

    The result is cached as
    ``<data_path>/<data_name>/<threshold>_<k>_co_expression_network.csv``; if
    that file already exists it is read back and returned without looking at
    ``adata``.

    Otherwise the absolute Pearson correlation between all pairs of columns of
    ``adata.X`` is computed (NaN treated as 0). For every column the ``k + 1``
    most-correlated columns are kept (the column itself is normally among them),
    and pairs whose value is strictly greater than ``threshold`` become edges.

    Args:
        adata: object exposing ``.X`` (dense array or anything with a
            ``toarray()`` method) and ``.var`` (a DataFrame whose index holds
            the names of the columns of ``.X``).
        threshold (float): minimum absolute correlation (exclusive) of an edge.
        k (int): number of neighbours per column; ``k + 1`` candidates are taken.
        data_path (str): root directory of the cache.
        data_name (str): sub-directory of ``data_path`` for this dataset.

    Returns:
        pd.DataFrame: columns ``source``, ``target``, ``importance``.
    """
    out_dir = os.path.join(data_path, data_name)
    fname = os.path.join(out_dir,
                         str(threshold) + '_' + str(k) + '_co_expression_network.csv')

    if os.path.exists(fname):
        return pd.read_csv(fname)

    # Read names straight from the index: reset_index()["index"] raises KeyError
    # as soon as the index carries a name.
    gene_names = np.asarray(adata.var.index)

    X_tr = adata.X
    if hasattr(X_tr, "toarray"):
        # Sparse matrices need densifying; a dense X has no toarray().
        X_tr = X_tr.toarray()
    X_tr = np.asarray(X_tr)

    corr = np_pearson_cor(X_tr, X_tr)
    corr[np.isnan(corr)] = 0  # zero-variance columns yield NaN
    corr = np.abs(corr)

    # Indices of the k + 1 largest entries per row (ascending), then gather the
    # matching values instead of sorting the whole matrix a second time.
    top_idx = np.argsort(corr)[:, -(k + 1):]
    top_val = np.take_along_axis(corr, top_idx, axis=1)

    n_top = top_idx.shape[1]
    sources = gene_names[top_idx.ravel()]
    targets = np.repeat(gene_names, n_top)  # row i is the target of its n_top picks
    importance = top_val.ravel().astype(np.float64)
    keep = importance > threshold

    # Explicit column names keep the frame (and the cached CSV header) well-formed
    # even when no pair passes the threshold.
    df_co_expression = pd.DataFrame({'source': sources[keep],
                                     'target': targets[keep],
                                     'importance': importance[keep]},
                                    columns=['source', 'target', 'importance'])

    os.makedirs(out_dir, exist_ok=True)
    df_co_expression.to_csv(fname, index = False)
    return df_co_expression

### Pipeline

In [11]:
# Configuration of the co-expression pipeline.
# NOTE(review): `k` is also read by the Gene-Ontology pipeline cell further down
# (nlargest(k + 1, ...)), and `data_path` is rebound in the "Download GEARS GO" section.
train_data_path = "/scratch/jeremy/data/train_Kidney_Lung_GEARS.h5ad"
# Edges are kept only when their absolute correlation is strictly above this value.
threshold = 0.2
# For each gene the k + 1 highest-correlation columns are considered.
k = 20
# The cached CSV lives in <data_path>/<data_name>/.
data_path = "/scratch/jeremy/data/graphs/"
data_name = "mca_GEARS_kidney_lung_train"

In [12]:
# Load the training data from the .h5ad file configured in the cell above
train_data = sc.read_h5ad(train_data_path)

In [13]:
# Create Co-Expression Graph: edge DataFrame with columns source / target / importance.
# If the CSV for this (threshold, k) already exists under data_path/data_name it is
# read back instead of being recomputed.
co_express_df = get_coexpression_network_from_train(train_data,threshold,k,data_path,data_name)

### GEARS Graph Class

In [3]:
class GeneSimNetwork():
    """Directed, weighted gene-similarity graph exposed as PyTorch tensors.

    Args:
        edge_list (pd.DataFrame): one row per edge, with columns ``source``,
            ``target`` and ``importance``.
        node_map (dict): gene name -> integer node id. Every gene that appears
            in an edge must be a key.
        gene_list: optional iterable of gene names; names that are not already
            graph nodes are added as isolated nodes. When omitted, the sorted
            node names of the graph are used.

    Attributes:
        edge_list: the DataFrame passed in.
        G (nx.DiGraph): the graph built from ``edge_list``.
        gene_list: see ``gene_list`` above.
        node_map: the mapping passed in.
        edge_index (torch.Tensor): ``(2, num_edges)`` long tensor of node ids,
            row 0 = sources, row 1 = targets.
        edge_weight (torch.Tensor): ``(num_edges,)`` float tensor holding the
            ``importance`` of each edge, in the same order as ``edge_index``.

    Raises:
        KeyError: an edge references a gene that is missing from ``node_map``.
    """
    def __init__(self, edge_list:pd.DataFrame, node_map:dict, gene_list=None):
        # Create Graph
        self.edge_list = edge_list
        self.G = nx.from_pandas_edgelist(self.edge_list, source='source',
                        target='target', edge_attr=['importance'],
                        create_using=nx.DiGraph())
        # Save gene list
        # `is None` rather than `== None`: the latter is evaluated element-wise
        # (and then fails inside `if`) when gene_list is a NumPy array.
        if gene_list is None:
            self.gene_list = sorted(self.G.nodes)
        else:
            self.gene_list = gene_list
        # Nodes already in the graph are left untouched by add_nodes_from.
        self.G.add_nodes_from(self.gene_list)

        # Convert data to tensor
        self.node_map = node_map
        edges = list(self.G.edges)
        missing = sorted({gene for edge in edges for gene in edge
                          if gene not in node_map}, key=str)
        if missing:
            raise KeyError(
                f"{len(missing)} gene(s) used in edge_list are missing from "
                f"node_map, e.g. {missing[:5]}")
        id_pairs = [(node_map[src], node_map[dst]) for src, dst in edges]
        # Reshaping to (num_edges, 2) first keeps the result at (2, 0) for an
        # empty graph, where a bare `.T` would leave a 1-D tensor.
        self.edge_index = (torch.tensor(id_pairs, dtype=torch.long)
                           .reshape(len(id_pairs), 2).t().contiguous())
        importance = [weight for _, _, weight in self.G.edges(data='importance')]
        self.edge_weight = torch.tensor(importance, dtype=torch.float)

In [45]:
# Load the gene-name -> integer-id mapping.
# NOTE(review): same file and same variable as the first cell of the "Co-Expression"
# section; one of the two loads is redundant.
with open("/scratch/jeremy/data/graphs/raw/gene_to_id.pkl", 'rb') as handle:
    node_map = pickle.load(handle)

In [46]:
# Wrap the co-expression edge list as a graph with edge_index / edge_weight tensors.
# Every gene in co_express_df has to be a key of node_map, otherwise this raises KeyError.
sim_network = GeneSimNetwork(co_express_df, node_map)

## Gene-Ontology

### GEARS Methods

In [54]:
def dataverse_download(url, save_path, timeout=60):
    """dataverse download helper with progress bar

    Does nothing (apart from printing a message) when ``save_path`` already
    exists. Otherwise the body of ``url`` is streamed into a temporary
    ``<save_path>.part`` file that is renamed to ``save_path`` only after the
    download completed, so an interrupted download is never mistaken for a
    local copy on the next call.

    Args:
        url (str): the url of the dataset
        save_path (str): the path to save the dataset
        timeout (float): seconds ``requests`` waits for the server to connect
            or to send data before giving up.

    Raises:
        requests.HTTPError: the server answered with an error status.
        requests.RequestException: connection failure or timeout.
    """

    if os.path.exists(save_path):
        print('Found local copy...')
        return

    print("Downloading...")
    partial_path = save_path + '.part'
    block_size = 1024
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this an HTML error page would be saved as the dataset.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            with open(partial_path, 'wb') as file, \
                    tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the incomplete file, then re-raise.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    os.replace(partial_path, save_path)

        
def zip_data_download_wrapper(url, save_path, data_path):
    """Download ``<save_path>.zip`` from ``url`` and unpack it into ``data_path``.

    Nothing is downloaded or extracted when ``save_path`` already exists.

    Args:
        url (str): location of the zip archive.
        save_path (str): path whose existence marks the data as already present;
            the archive itself is stored next to it with a ``.zip`` suffix.
        data_path (str): directory the archive is extracted into.
    """
    if os.path.exists(save_path):
        print('Found local copy...')
        return

    archive_path = save_path + '.zip'
    dataverse_download(url, archive_path)
    print('Extracting zip file...')
    with ZipFile(archive_path, 'r') as archive:
        archive.extractall(path = data_path)
    print("Done!")
        
def tar_data_download_wrapper(url, save_path, data_path):
    """Download ``<save_path>.tar.gz`` from ``url`` and unpack it into ``data_path``.

    Nothing is downloaded or extracted when ``save_path`` already exists.

    Args:
        url (str): location of the tar archive.
        save_path (str): path whose existence marks the data as already present;
            the archive itself is stored next to it with a ``.tar.gz`` suffix.
        data_path (str): directory the archive is extracted into.
    """
    if os.path.exists(save_path):
        print('Found local copy...')
        return

    archive_path = save_path + '.tar.gz'
    dataverse_download(url, archive_path)
    print('Extracting tar file...')
    print(archive_path)
    with tarfile.open(archive_path) as tar:
        # The archive comes from the network: where the interpreter supports
        # extraction filters, use 'data' so members cannot escape data_path
        # (absolute paths, "..", links pointing outside).
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path= data_path, filter='data')
        else:
            tar.extractall(path= data_path)
    print("Done!")

### Download GEARS GO

In [None]:
# Download directory and source URL of the GEARS gene-ontology archive.
# NOTE(review): this rebinds `data_path`, which the co-expression pipeline cell above
# set to "/scratch/jeremy/data/graphs/"; re-running that pipeline after this cell
# would make it look for its cache under "raw/" instead.
data_path = "/scratch/jeremy/data/graphs/raw/"
server_path = 'https://dataverse.harvard.edu/api/access/datafile/6934319'

In [None]:
# Fetch and unpack the archive into data_path, unless <data_path>/go_essential_all already exists.
tar_data_download_wrapper(server_path, os.path.join(data_path, 'go_essential_all'),data_path)

### Pipeline

In [142]:
# Input paths (read by the following cells: AnnData file, GO edge list CSV, gene2go pickle)
mca_path = "/scratch/jeremy/data/mouse_cell_atlas_processed.h5ad"
GEARS_go_path = "/scratch/jeremy/data/graphs/raw/GEARS_basic_GO.csv"
gene2go_path = "/scratch/jeremy/data/graphs/raw/gene2go_all.pkl"
# Output paths (written by the following cells)
# gene_to_id_path is the same file the Co-Expression section loads as `node_map`.
gene_to_id_path = "/scratch/jeremy/data/graphs/raw/gene_to_id.pkl"
id_to_gene_path = "/scratch/jeremy/data/graphs/raw/id_to_gene.pkl"
mca_go_graph_path = "/scratch/jeremy/data/graphs/raw/mca_go_graph.csv"

In [113]:
# Load MCA gene list
# Gene names are upper-cased before being intersected with the GO graph genes below
# (presumably to match the naming used in the GEARS GO file -- verify).
mca_data = sc.read_h5ad(mca_path)
# The names must come from the object loaded just above; `sc_data` is not defined
# anywhere in this notebook and only worked through leftover kernel state.
mca_genes = {g.upper() for g in mca_data.var_names}
print(f"There are {len(mca_genes)} genes in the MCA dataset.")

There are 22959 genes in the MCA dataset.


In [114]:
# Load Gene-Ontology graph used by default in GEARS
df_jaccard = pd.read_csv(GEARS_go_path)
# For every target gene keep its k + 1 rows with the largest 'importance'.
# NOTE(review): `k` is not defined in this section; it comes from the co-expression
# pipeline configuration cell (k = 20), so that cell has to be run first.
df_out = df_jaccard.groupby('target').apply(lambda x: x.nlargest(k + 1,['importance'])).reset_index(drop = True)
# Only the 'source' column is collected: a gene that appears solely as a target is not counted.
GEARS_GO_genes = set(df_out.source.values)
print(f"There are {len(GEARS_GO_genes)} genes in the GEARS Graph Ontology.")

There are 9671 genes in the GEARS Graph Ontology.


In [115]:
# Load Gene - Ontology ID mapping
# NOTE(review): pickle.load executes arbitrary code from the file; only use it on trusted data.
# `gene2go` is only counted here; no later cell shown in this notebook reads it.
with open(gene2go_path, 'rb') as f:
    gene2go = pickle.load(f)
print(f"There are {len(gene2go)} genes in the Gene2GO mapping.")

There are 67832 genes in the Gene2GO mapping.


There are several ways we could go starting from here:
- Combine the Gene Ontology with the gene2go mapping to build a broader graph than the one from GEARS (roughly twice the gene coverage).
- Use the GEARS GO graph but keep all genes, adding the extra ones as isolated nodes (no edges). This may introduce a strong bias, so it is probably not a good option.
- Use the GEARS GO graph and drop the MCA genes that are not present in it, restricting the analysis to the shared genes.

For now, we will try the last option as a first step.

In [136]:
# MCA restricted gene set: genes present both in the MCA data and in the GEARS GO graph
final_gene_set = mca_genes & GEARS_GO_genes
print(f"There are {len(final_gene_set)} genes in the final gene set.")
# Create Gene Mappings: ids follow the sorted order of the gene names,
# and id_to_gene is the exact inverse of gene_to_id
gene_to_id = {gene: idx for idx, gene in enumerate(sorted(final_gene_set))}
id_to_gene = {idx: gene for gene, idx in gene_to_id.items()}
# Save Gene Mappings
with open(gene_to_id_path, 'wb') as handle:
    pickle.dump(gene_to_id, handle)
with open(id_to_gene_path, 'wb') as handle:
    pickle.dump(id_to_gene, handle)

There are 8179 genes in the final gene set.


In [143]:
# Saving MCA GO graph: keep the edges whose two endpoints are both in the final gene set
# NOTE(review): the CSV is written with the DataFrame index as its first column, unlike the
# co-expression CSV (index = False) -- confirm which layout downstream readers expect.
mca_go_graph = df_out[df_out["source"].isin(final_gene_set)
                      & df_out["target"].isin(final_gene_set)]
mca_go_graph.to_csv(mca_go_graph_path)

## PyG Data Methods

In [55]:
# Interactive help lookup for the PyG `Data` class (IPython `?` syntax, not plain Python).
# NOTE(review): leftover exploration -- consider removing it before sharing the notebook.
Data?

---