In [1]:
import numpy as np
import pandas as pd

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import SGConv

from typing import List


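# Data

Download the GO essential edge list and build a small gene graph from it (`nrows=10` keeps this run tiny for a smoke test).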

In [2]:
from graph import download_graph_edge_list, create_graph

In [8]:
# Fetch the raw graph edge-list files into the shared data directory.
DATA_DIR = '../data'
download_graph_edge_list(target_dir=DATA_DIR)

In [3]:
# Build the gene graph from the GO edge list.
# NOTE(review): `nrows` presumably caps how many CSV rows are read and `topn` how many
# neighbours are kept per gene -- confirm against graph.create_graph.
GRAPH_CSV = '../data/go_essential_all/go_essential_all.csv'
edge_index, edge_weight, gene_list, gene2idx = create_graph(GRAPH_CSV, topn=10, nrows=10)

# Invariants the model cell below relies on; fail fast here rather than inside GNN.
assert edge_index.dim() == 2 and edge_index.shape[0] == 2, 'edge_index should be (2, num_edges)'
assert edge_index.shape[1] == edge_weight.shape[0], 'expected exactly one weight per edge'
assert len(gene2idx) == len(gene_list), 'gene2idx and gene_list disagree in size'
assert all(gene2idx[gene] == idx for idx, gene in enumerate(gene_list)), 'gene2idx must invert gene_list'
assert edge_index.numel() == 0 or int(edge_index.max()) < len(gene_list), 'edge_index refers to an unknown gene'

print(edge_index)
print(edge_weight)
print(gene_list)
print(gene2idx)

tensor([[7, 7, 7, 7, 7, 7, 7, 7, 7, 7],
        [7, 3, 9, 0, 6, 8, 2, 5, 4, 1]])
tensor([1.0000, 0.1304, 0.1333, 0.2500, 0.1250, 0.1739, 0.1250, 0.2222, 0.1176,
        0.1200])
['ACTR1B', 'AGL', 'ADAMTSL4', 'ACLY', 'AGA', 'AEBP1', 'ADAMTS20', 'A1BG', 'ADAMTS3', 'ACTR10']
{'ACTR1B': 0, 'AGL': 1, 'ADAMTSL4': 2, 'ACLY': 3, 'AGA': 4, 'AEBP1': 5, 'ADAMTS20': 6, 'A1BG': 7, 'ADAMTS3': 8, 'ACTR10': 9}


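# Model

Build a `GNN` over the gene graph above and run a single forward pass on random input as a smoke test.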

In [8]:
from model import GNN

# Smoke-test hyper-parameters.
SEQ_LEN = 3         # NOTE(review): appears to need to match the column count of `src` in the forward-pass cell
D_HID = 2
N_GNN_LAYERS = 3

gnn_model = GNN(
    genes=gene_list,
    seq_len=SEQ_LEN,
    d_hid=D_HID,
    edge_index=edge_index,
    edge_weight=edge_weight,
    gene2idx=gene2idx,
    n_gnn_layers=N_GNN_LAYERS,
)

In [17]:
# Sanity-check forward pass on random input.
# A dedicated, seeded generator makes the printed values reproducible under
# Restart & Run All without touching torch's global RNG state.
PERT_GENE = 'ADAMTS3'
rng = torch.Generator().manual_seed(0)

# One row per gene (rows are addressed via gene2idx); the 3 columns must equal
# the seq_len=3 given to GNN above.
src = torch.randn((len(gene_list), 3), generator=rng)
print(src[gene2idx[PERT_GENE], :])

with torch.no_grad():
    res = gnn_model(src, pert_gene=PERT_GENE)
print(res)

tensor([-1.5749,  2.1456,  0.9442])
tensor([[ 0.3612,  0.0961, -0.1318]])
