In [2]:
import numpy as np 
import pandas as pd
import re
import os
import sys
sys.path.append('..')

import torch
from torch_geometric.data import DataLoader, Dataset
from torch_geometric.nn import Set2Set
from torch_geometric.utils import from_networkx

import networkx as nx
from tqdm.notebook import tqdm

from models.graph_transformer.euclidean_graph_transformer import GraphTransformerEncoder, MultipleOptimizer
from models.deep_graph_infomax.infomax import SNInfomax
from utils.data_gen import load_prot_embs, to_categorical

In [3]:
# Load the pretrained protein embeddings (`prot_embs`, later handed to the encoder as
# `pretrained_weights`) and `global_dict` (attached to graph nodes as `global_idx`).
# NOTE(review): 512 presumably is the embedding width (it matches n_hid=512 used for the
# encoder); the meaning of the `False` flag is not visible here — confirm in utils.data_gen.
prot_embs, global_dict = load_prot_embs(512, False)

In [4]:
def ucsv2graph_infomax(fname, global_dict, sig_one_hot=None, y=None):
    """
    Build an unweighted, directed graph object from an edge-list CSV.

    The CSV is expected to provide the columns ``node1``, ``node2``, ``sign``,
    ``activity1`` and ``activity2``. Every row becomes a directed edge
    ``node1 -> node2`` carrying ``sign``; the activity columns become the
    per-node ``acts`` attribute (first occurrence of a node wins).

    Args:
        fname: path of the CSV file to read.
        global_dict: mapping node -> value, stored on nodes as ``global_idx``.
        sig_one_hot: optional vector stored as ``sig`` and repeated once per
            node as ``node_sig``.
        y: optional target stored as ``y``.

    Returns:
        The object produced by ``from_networkx`` with ``acts`` (long) and
        ``sign`` (float) passed through ``to_categorical(..., 2)`` after
        negative values were set to 0, plus an all-ones ``weight`` per edge.
    """
    edges = pd.read_csv(fname)

    graph = nx.from_pandas_edgelist(
        edges,
        source='node1',
        target='node2',
        edge_attr=['sign'],
        create_using=nx.DiGraph(),
    )

    # One (node, act) frame per edge endpoint, stacked source-first.
    endpoint_frames = []
    for node_col, act_col in (('node1', 'activity1'), ('node2', 'activity2')):
        frame = edges[[node_col, act_col]]
        frame.columns = ['node', 'act']
        endpoint_frames.append(frame)

    node_acts = pd.concat(endpoint_frames)
    node_acts = node_acts.drop_duplicates('node').set_index('node')
    node_acts['acts'] = node_acts[['act']].apply(lambda row: np.hstack(row), axis=1)
    acts_by_node = node_acts.drop(['act'], axis=1)['acts'].to_dict()

    nx.set_node_attributes(graph, global_dict, 'global_idx')
    nx.set_node_attributes(graph, acts_by_node, 'acts')

    graph_data = from_networkx(graph)

    # Clamp negatives to 0 so both attributes have exactly two classes.
    graph_data.acts[graph_data.acts < 0] = 0
    graph_data.acts = to_categorical(graph_data.acts, 2).reshape(-1, 2).long()

    graph_data.sign[graph_data.sign < 0] = 0
    graph_data.sign = to_categorical(graph_data.sign, 2).reshape(-1, 2).float()

    # Unweighted graph: every edge gets weight 1.
    graph_data.weight = torch.ones(graph_data.num_edges)

    if sig_one_hot is not None:
        graph_data.sig = torch.tensor(sig_one_hot)
        graph_data.node_sig = graph_data.sig.view(1, -1).repeat(graph_data.num_nodes, 1)

    if y is not None:
        graph_data.y = torch.tensor(y)

    return graph_data

In [5]:
class SNDatasetInfomax(Dataset):
    """Dataset that converts edge-list CSV files to graphs on access.

    Args:
        fnames: indexable collection of CSV paths, one graph per entry.
        global_dict: mapping forwarded to ``ucsv2graph_infomax``.
        sig_id_one_hot: optional indexable collection aligned with ``fnames``;
            entry ``idx`` is forwarded as ``sig_one_hot`` for graph ``idx``.
    """

    def __init__(self, fnames, global_dict, sig_id_one_hot=None):
        super().__init__()
        self.fnames = fnames
        self.gd = global_dict
        self.sig_id = sig_id_one_hot

    def len(self):
        """Number of graph files in the dataset."""
        return len(self.fnames)

    def get(self, idx):
        """Read file ``idx`` and return its graph."""
        fname = self.fnames[idx]
        if self.sig_id is None:
            return ucsv2graph_infomax(fname, self.gd)
        return ucsv2graph_infomax(fname, self.gd, self.sig_id[idx])

In [9]:
# Table of use-case graphs; rows 5, 40 and 47 are excluded before the
# combined graph file paths are extracted.
uc_path = 'data/use_cases/use_case_graphs.csv'
use_cases = pd.read_csv(uc_path).drop(index=[5, 40, 47])
path_list = use_cases['files_combined'].to_numpy()

In [10]:
# Wrap the use-case graph files in a dataset (no signature ids are passed) and
# iterate them one graph per batch, using 12 loader workers.
test_data = SNDatasetInfomax(path_list, global_dict)
loader = DataLoader(test_data, batch_size=1, num_workers = 12)

In [11]:
# Convert the first use-case file directly, outside the loader.
# NOTE(review): `data` is not referenced by any later cell visible here — presumably a
# quick check of the conversion; confirm before removing.
data = ucsv2graph_infomax(path_list[0], global_dict)

In [12]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(dev)`.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
# Graph-level readout over 512-wide node features.
# NOTE(review): per the PyG docs the second argument is the number of processing steps and
# Set2Set emits 2 * in_channels features, which presumably is why `emb_csv` allocates
# 1024 columns — confirm.
summarizer = Set2Set(512, 3)
# Node encoder initialised from the pretrained protein embeddings and moved to `dev`.
encoder = GraphTransformerEncoder(n_layers=1,
                                  n_heads=4,
                                  n_hid=512,
                                  pretrained_weights=prot_embs).to(dev)

In [15]:
# Assemble the infomax model from the encoder and the readout, then move it to `dev`.
# NOTE(review): `summarizer` was not moved to `dev` when created; presumably SNInfomax
# registers it as a submodule so this `.to(dev)` covers it — confirm.
model = SNInfomax(hidden_channels=512,
                  encoder=encoder,
                  summary=summarizer,
                  semi=False).to(dev)

In [16]:
# Restore the pretrained weights. `map_location=dev` remaps the stored tensors to
# the device in use, so a checkpoint saved on a GPU also loads on another device.
model.load_state_dict(torch.load('embeddings/deep_graph_infomax/unsupervised/DGI_JSD_512_seq_uniform_un.pt',
                                 map_location=dev))

<All keys matched successfully>

In [17]:
@torch.no_grad()
def emb_csv(model, loader, n_graphs=None, emb_dim=1024):
    """Embed every graph yielded by ``loader`` with ``model``.

    Each batch is encoded with ``model.encoder`` and pooled per graph with
    ``model.summary``; the pooled vectors are written into consecutive rows of
    the returned matrix, in loader order.

    Args:
        model: object exposing ``eval()``, ``encoder(batch)`` and
            ``summary(node_embeddings, batch_index)``.
        loader: iterable of batches that carry a ``batch`` index tensor and a
            ``to(device)`` method.
        n_graphs: number of rows to allocate. Defaults to
            ``len(loader.dataset)`` when the loader exposes a dataset, and
            otherwise to the length of the notebook-level ``path_list``.
        emb_dim: width of one pooled embedding (1024 by default).

    Returns:
        ``np.ndarray`` of shape ``(n_graphs, emb_dim)``. Rows that belong to a
        batch which failed stay all-zero so callers can filter them out.
    """
    import warnings  # local import keeps this cell self-contained

    if n_graphs is None:
        dataset = getattr(loader, 'dataset', None)
        # Fall back to the notebook global only for loaders without a dataset.
        n_graphs = len(dataset) if dataset is not None else len(path_list)

    model.eval()
    embeddings = np.zeros((n_graphs, emb_dim))

    idx = 0
    for graph in tqdm(loader):
        # Number of graphs in this batch, as a plain int usable for slicing.
        bs = int(graph.batch.max()) + 1
        try:
            s = model.encoder(graph.to(dev))
            pooled = model.summary(s, graph.batch)
            embeddings[idx:idx + bs] = pooled.reshape(bs, -1).cpu().numpy()
        except Exception as exc:
            # Best effort: keep going, but report which rows were left at zero.
            warnings.warn(
                f'emb_csv: skipping graphs {idx}..{idx + bs - 1}: {exc!r}')
        idx += bs

    return embeddings

In [18]:
# Embed all use-case graphs; graphs that fail inside `emb_csv` are left as all-zero rows.
embs = emb_csv(model, loader)

  0%|          | 0/1986 [00:00<?, ?it/s]

In [19]:
def upl_f(x):
    """Remove the use-case directory prefix from a graph path."""
    # Dots are escaped so only a literal '..' matches, not any two characters.
    return re.sub(r'\.\./snac_data/use_case_graphs/', '', x)


def upl_f_csv(x):
    """Remove a trailing '.csv' extension."""
    # Escaped and anchored: an unescaped '.csv' would also strip e.g. 'xcsv' mid-path.
    return re.sub(r'\.csv$', '', x)


def upl_f_emb(x):
    """Turn the '/graph' path component into an '_emb' name part."""
    return re.sub('/graph', '_emb', x)


# One clean embedding name per graph file, in the same order as `path_list`.
upl = np.array([upl_f_emb(upl_f_csv(upl_f(p))) for p in path_list])

In [20]:
# Column labels: the name column followed by one string label per embedding dimension.
cols = ['emb'] + [str(i) for i in range(1024)]

In [21]:
# df1: one column holding the cleaned embedding names; df2: one row per graph embedding.
# Both get the default integer index, so they line up row by row when concatenated.
df1 = pd.DataFrame(upl)
df2 = pd.DataFrame(embs)

In [22]:
# Put names and embeddings side by side, then label the columns 'emb', '0', '1', ...
df = pd.concat([df1, df2], axis=1)
df.columns = cols

In [23]:
# Graphs that failed in `emb_csv` keep an all-zero embedding row; drop exactly those.
# Checking every embedding column (not only '0') keeps a valid row whose first
# component happens to be 0.
df = df[(df.drop(columns='emb') != 0).any(axis=1)]

In [449]:
# Persist the filtered embedding table; with pandas defaults the row index is written too.
# NOTE(review): this cell's execution count (449) is out of sequence with its neighbours —
# re-run the notebook top to bottom before trusting the saved file.
df.to_csv('embeddings/deep_graph_infomax/unsupervised/DGI_JSD_512_seqveq_testcases.csv')

In [24]:
# Show the filtered embedding table as the cell output.
df

Unnamed: 0,emb,0,1,2,3,4,5,6,7,8,...,1014,1015,1016,1017,1018,1019,1020,1021,1022,1023
0,GSE96649_BELINOSTAT_emb_44,0.007429,0.003274,0.067972,0.026227,-0.012666,-0.004330,-0.002837,-0.022832,-0.000316,...,0.222896,0.212524,-0.217339,0.092297,0.141280,0.649693,-0.145343,0.096458,0.426202,0.309243
1,GSE96649_BELINOSTAT_emb_14,0.008200,0.003626,0.063266,0.026176,-0.012074,-0.004717,-0.005060,-0.024497,-0.002120,...,0.237001,0.207553,-0.208679,0.096283,0.130156,0.679872,-0.150913,0.078107,0.444935,0.342346
2,GSE96649_BELINOSTAT_emb_96,0.022013,0.005497,0.036963,0.031281,-0.002337,0.006641,0.001852,-0.043626,-0.004464,...,0.082040,0.190898,-0.026745,0.188483,0.071921,0.523984,-0.101918,0.099437,0.546916,0.303597
3,GSE96649_BELINOSTAT_emb_15,0.013621,0.008024,0.059604,0.040085,0.006622,0.014660,-0.002493,-0.025604,0.001317,...,0.180352,0.245094,-0.173990,0.162177,0.178941,0.559412,-0.103725,0.055944,0.360395,0.292528
4,GSE96649_BELINOSTAT_emb_18,0.016261,0.008828,0.059656,0.036430,0.005563,0.015797,-0.002767,-0.019779,0.001351,...,0.209292,0.258217,-0.185290,0.160515,0.192327,0.550498,-0.084916,0.045485,0.372187,0.296094
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1981,GSE19638_DOXORUBICIN_emb_8,-0.025240,0.000181,-0.010573,0.073867,0.013747,-0.014697,0.003689,0.007094,-0.006415,...,0.385207,0.683621,0.417258,-0.030003,-0.019724,-0.072981,0.030870,0.222163,0.114944,0.226626
1982,GSE19638_DOXORUBICIN_emb_85,-0.013367,0.002868,0.012296,0.088381,0.015524,-0.010955,0.003487,0.009042,0.001549,...,0.380582,0.662486,0.263034,0.017619,-0.121847,-0.144883,0.011316,0.277096,0.103885,0.204205
1983,GSE19638_DOXORUBICIN_emb_11,-0.025290,0.000288,-0.010169,0.075082,0.014108,-0.015093,0.003341,0.007447,-0.006045,...,0.382313,0.685066,0.415945,-0.028740,-0.018843,-0.071200,0.027523,0.222210,0.115262,0.223777
1984,GSE19638_DOXORUBICIN_emb_86,-0.026671,0.000122,-0.002587,0.061708,0.016876,-0.021163,0.005231,0.007961,-0.003718,...,0.478525,0.549034,0.322026,-0.012508,-0.029314,-0.075459,0.009253,0.251418,0.058822,0.187436
