In [3]:
import numpy as np
import pandas as pd

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import SGConv
import networkx as nx

from typing import List
import requests, gzip, shutil, tarfile
from pathlib import Path

# Data

In [314]:
def download_graph_edge_list(target_dir):
    
    path = Path(target_dir)
    
    url = 'https://dataverse.harvard.edu/api/access/datafile/6934319'
    r = requests.get(url)

    with open(path.joinpath("go.tar.gz"), "wb") as f:
        f.write(r.content)
    f.close()

    f = tarfile.open(path.joinpath('go.tar.gz'))
    f.extractall(path) 
    f.close()
    

def create_graph(filepath, topn=10, nrows=None):
    
    go_ = pd.read_csv(filepath, nrows=nrows)

    go = go_.groupby('target').apply(lambda x: x.nlargest(topn + 1, ['importance'])).reset_index(drop = True)

    gene_list = list(set(go.source.tolist() + go.target.tolist()))
    gene2idx = {g:i for i,g in enumerate(gene_list)}

    G = nx.from_pandas_edgelist(go, source='source',target='target', edge_attr=['importance'], create_using=nx.DiGraph())

    edge_index_ = [(gene2idx[e[0]], gene2idx[e[1]]) for e in G.edges]
    edge_index = torch.tensor(edge_index_, dtype=torch.long).T

    edge_attr = nx.get_edge_attributes(G, 'importance') 
    edge_weight = torch.Tensor(np.array([edge_attr[e] for e in G.edges]))

    return edge_index, edge_weight, gene_list, gene2idx


In [None]:
# download graph data
download_graph_edge_list(target_dir = '../data')

In [569]:
# create graph
edge_index, edge_weight, gene_list, gene2idx = create_graph('../data/go_essential_all/go_essential_all.csv', topn=10, nrows=10)
print(edge_index)
print(edge_weight)
print(gene_list)
print(gene2idx)

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 9, 0, 8, 2, 3, 4, 6, 7, 5]])
tensor([1.0000, 0.1304, 0.1333, 0.2500, 0.1250, 0.1739, 0.1250, 0.2222, 0.1176,
        0.1200])
['ACTR10', 'A1BG', 'ADAMTS20', 'ADAMTS3', 'ADAMTSL4', 'AGL', 'AEBP1', 'AGA', 'ACTR1B', 'ACLY']
{'ACTR10': 0, 'A1BG': 1, 'ADAMTS20': 2, 'ADAMTS3': 3, 'ADAMTSL4': 4, 'AGL': 5, 'AEBP1': 6, 'AGA': 7, 'ACTR1B': 8, 'ACLY': 9}


# Model

In [448]:
class MLP(nn.Module):

    def __init__(
        self, 
        sizes: List[int] = None, 
        batch_norm: bool = True
    ):

        super().__init__()
        self.sizes = sizes
        self.batch_norm = batch_norm
        
        layers = []
        for i in range(len(self.sizes) - 1):
            
            in_size = self.sizes[i]
            out_size = self.sizes[i + 1]
            
            layers.extend([
                nn.Linear(in_size, out_size),
                nn.BatchNorm1d(out_size) if self.batch_norm else None,
                nn.ReLU()
            ])

        self.layers = nn.Sequential(*[l for l in layers if l is not None])
        
    def forward(
        self, 
        x: Tensor
    ):
        return self.layers(x)

In [565]:
class GNN(nn.Module):

    def __init__(
        self, 
        genes: int,
        seq_len: int, # n_cell (this is a fixed value for all genes)
        d_hid: int,
        edge_index: Tensor,
        edge_weight: Tensor,
        gene2idx: dict,
        n_gnn_layers: int,
        device: str = 'cpu',
    ):

        super().__init__()   
        self.genes = genes
        self.n_genes = len(genes)
        self.seq_len = seq_len
        self.d_hid = d_hid if d_hid < 8*(seq_len//32) else 8*(seq_len//32)
        self.d_emb = seq_len + 1
        self.edge_index = edge_index
        self.edge_weight = edge_weight
        self.gene2idx = gene2idx
        self.device = device

        ### GNN layers
        self.n_gnn_layers = n_gnn_layers
        self.edge_index = edge_index
        self.edge_weight = edge_weight.to(self.device)
        self.gnn_layers = torch.nn.ModuleList()
        for i in range(1, self.n_gnn_layers+1):
            self.gnn_layers.append(SGConv(self.d_emb, self.d_emb, 1))
            
        ### attention layer
        self.attn = nn.MultiheadAttention(self.d_emb, 1, dropout=0.1, bias=True)

        # decoder layer
        self.decoder = MLP([self.d_emb, d_hid*2, d_hid*2, d_hid*4], batch_norm=False)
        self.output = nn.Linear(d_hid*4, seq_len)
        
        def get_cell_emb(self, emb):
            self.attn(emb)
        

    def forward(
        self, 
        src: Tensor,
        pert_gene: str, # TODO: batch
    ):
       
        pert_idx = self.gene2idx[pert_gene]
        ## pertubation embedding
        pert_emb = torch.zeros((self.n_genes, 1)) # n_gene, 1
        pert_emb[pert_idx, :] = 1 

        ## node feature
        emb = torch.concat((src, pert_emb), dim=1) # n_gene, seq_len+1

        ## augment global perturbation embedding with GNN
        for i, gnn in enumerate(self.gnn_layers):
            emb = gnn(emb, self.edge_index, self.edge_weight)
            if i < self.n_gnn_layers - 1:
                emb = emb.relu()

        ## attention
        attn_output, attn_weights = self.attn(emb, emb, emb)        
        cell_emb = torch.matmul(attn_weights[pert_idx, :], emb).reshape(1, -1)
        
        ## output
        return self.output(self.decoder(cell_emb))


In [566]:
model = GNN(
    genes = gene_list, 
    seq_len = 3, 
    d_hid = 2, 
    edge_index = edge_index, 
    edge_weight = edge_weight, 
    gene2idx = gene2idx, 
    n_gnn_layers = 3,
   )

model(src, pert_gene='ADAMTS3')

tensor([[-0.1156, -0.0120,  0.3801]], grad_fn=<AddmmBackward0>)

In [564]:
src[gene2idx['ADAMTS3'], :]

tensor([ 0.7878, -1.3893, -0.4456])