In [1]:
import numpy as np 
import pandas as pd
import re
import os
import sys
sys.path.append('..')

import torch
from torch_geometric.data import DataLoader, Dataset
from torch_geometric.nn import Set2Set
from torch_geometric.utils import from_networkx

import networkx as nx
from tqdm.notebook import tqdm

from models.graph_transformer.euclidean_graph_transformer import GraphTransformerEncoder, MultipleOptimizer
from models.deep_graph_infomax.infomax import SNInfomax
from utils.data_gen import load_prot_embs, to_categorical

In [2]:
# Load the pretrained protein embeddings (`prot_embs`, later passed to the
# encoder as `pretrained_weights`) and `global_dict`, which is stored on every
# graph node as the `global_idx` attribute.
# NOTE(review): meaning of the positional arguments (512, False) is defined in
# utils.data_gen.load_prot_embs -- 512 presumably matches the encoder's hidden
# size; confirm and name them.
prot_embs, global_dict = load_prot_embs(512, False)

In [3]:
def ucsv2graph_infomax(fname, global_dict, sig_one_hot=None, y=None):
    """Build an unweighted directed PyG graph from one edge-list CSV.

    The CSV must provide the columns ``node1``, ``node2``, ``sign``,
    ``activity1`` and ``activity2``.

    Args:
        fname: path of the CSV edge list.
        global_dict: mapping node name -> value stored on each node as the
            ``global_idx`` attribute.
        sig_one_hot: optional per-graph vector; stored as ``data.sig`` and
            repeated once per node as ``data.node_sig``.
        y: optional target, stored as ``data.y``.

    Returns:
        The ``from_networkx`` data object with ``acts`` (long, reshaped to
        ``(-1, 2)``), ``sign`` (float, reshaped to ``(-1, 2)``) and
        ``weight`` (a ones tensor with one entry per edge).
    """
    edges = pd.read_csv(fname)

    graph = nx.from_pandas_edgelist(
        edges,
        source='node1',
        target='node2',
        edge_attr=['sign'],
        create_using=nx.DiGraph(),
    )

    # Stack the (node, activity) pairs of both edge endpoints; the first
    # occurrence of each node decides its activity.
    endpoint_pairs = pd.concat([
        edges[['node1', 'activity1']].rename(columns={'node1': 'node', 'activity1': 'act'}),
        edges[['node2', 'activity2']].rename(columns={'node2': 'node', 'activity2': 'act'}),
    ])
    first_act = endpoint_pairs.drop_duplicates('node').set_index('node')['act']
    # Each activity is stored as a length-1 array.
    node_acts = {node: np.hstack([act]) for node, act in first_act.items()}

    nx.set_node_attributes(graph, global_dict, 'global_idx')
    nx.set_node_attributes(graph, node_acts, 'acts')

    data = from_networkx(graph)

    # Negative activities / signs are mapped to 0 before the 2-class encoding.
    data.acts[data.acts < 0] = 0
    data.acts = to_categorical(data.acts, 2).reshape(-1, 2).long()

    data.sign[data.sign < 0] = 0
    data.sign = to_categorical(data.sign, 2).reshape(-1, 2).float()

    # Unweighted graph: every edge gets weight 1.
    data.weight = torch.ones(data.num_edges)

    if sig_one_hot is not None:
        data.sig = torch.tensor(sig_one_hot)
        # One copy of the signature vector per node.
        data.node_sig = data.sig.view(1, -1).repeat(data.num_nodes, 1)

    if y is not None:
        data.y = torch.tensor(y)

    return data

In [4]:
class SNDatasetInfomax(Dataset):
    """Map-style PyG dataset that builds one graph per CSV path on access.

    Args:
        fnames: sequence of CSV paths, one graph each.
        global_dict: forwarded to ``ucsv2graph_infomax`` as ``global_dict``.
        sig_id_one_hot: optional sequence aligned with ``fnames``; element
            ``idx`` is forwarded as ``sig_one_hot`` for graph ``idx``.
    """

    def __init__(self, fnames, global_dict, sig_id_one_hot=None):
        super().__init__()
        self.fnames = fnames
        self.gd = global_dict
        self.sig_id = sig_id_one_hot

    def len(self):
        """Number of graphs (one per file name)."""
        return len(self.fnames)

    def get(self, idx):
        """Parse file ``idx`` into a graph; this method does no caching."""
        signature = None if self.sig_id is None else self.sig_id[idx]
        return ucsv2graph_infomax(self.fnames[idx], self.gd, signature)

In [5]:
uc_path = 'data/use_cases/use_case_graphs.csv'
# NOTE(review): rows 5, 40 and 47 are excluded without a recorded reason --
# document why (e.g. malformed graphs?) so the exclusion is reproducible.
use_cases = (
    pd.read_csv(uc_path)
    .drop(columns='Unnamed: 0')  # presumably an index column left by an earlier to_csv
    .drop(index=[5, 40, 47])
)
path_list = use_cases['path'].to_numpy()
path_list.shape

(1986,)

In [6]:
# One graph per batch; graphs are parsed lazily from the CSVs by 12 loader workers.
test_data = SNDatasetInfomax(fnames=path_list, global_dict=global_dict)
loader = DataLoader(dataset=test_data, batch_size=1, num_workers=12)

In [7]:
# Build the first graph eagerly -- presumably a quick smoke test that CSV
# parsing works before the DataLoader is used.
data = ucsv2graph_infomax(fname=path_list[0], global_dict=global_dict)

In [416]:
# Run on the GPU when one is available; fall back to CPU instead of failing
# on machines without CUDA.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [417]:
# Set2Set readout over 512-dim node states with 3 processing steps.
# NOTE(review): the summarizer is not moved to `dev` here; presumably it is
# moved along with the model that wraps it -- confirm SNInfomax registers it.
summarizer = Set2Set(in_channels=512, processing_steps=3)
encoder = GraphTransformerEncoder(
    n_layers=1,
    n_heads=4,
    n_hid=512,
    pretrained_weights=prot_embs,
)
encoder = encoder.to(dev)

In [418]:
# Unsupervised (semi=False) Infomax wrapper around the encoder and readout.
model = SNInfomax(
    hidden_channels=512,
    encoder=encoder,
    summary=summarizer,
    semi=False,
)
model = model.to(dev)

In [419]:
# map_location places the checkpoint tensors on `dev` regardless of which
# device they were saved from.
model.load_state_dict(torch.load(
    'embeddings/deep_graph_infomax/unsupervised/DGI_JSD_512_seq_uniform_un.pt',
    map_location=dev,
))

<All keys matched successfully>

In [427]:
@torch.no_grad()
def emb_csv(model, loader, emb_dim=1024, device=None):
    """Embed every graph yielded by ``loader`` into one fixed-size vector.

    Each batch goes through ``model.encoder`` and then ``model.summary``
    (pooled per graph via ``graph.batch``). Rows are filled in loader order.

    Args:
        model: module exposing ``encoder(graph)`` and ``summary(x, batch)``.
        loader: iterable of batched graphs whose ``.dataset`` length is the
            total number of graphs.
        emb_dim: width of one summary vector. The default of 1024 is the
            value previously hard-coded (presumably 2 * 512 from the Set2Set
            readout -- TODO confirm).
        device: device batches are moved to; defaults to the device of the
            model's first parameter.

    Returns:
        ``np.ndarray`` of shape ``(len(loader.dataset), emb_dim)``. Rows for
        batches whose forward pass raised are left all-zero, and a warning
        reports how many batches failed.
    """
    import warnings

    model.eval()
    if device is None:
        device = next(model.parameters()).device
    embeddings = np.zeros((len(loader.dataset), emb_dim))
    failures = []

    row = 0
    for graph in tqdm(loader):
        # Graphs in this batch, as a plain int so slicing and the running
        # offset do not become 0-d tensors.
        n_graphs = int(graph.batch.max()) + 1
        try:
            graph = graph.to(device)
            node_repr = model.encoder(graph)
            pooled = model.summary(node_repr, graph.batch)
            embeddings[row:row + n_graphs] = pooled.squeeze().cpu().numpy()
        except Exception as err:
            # Best effort: keep the zero rows and continue, but record why.
            failures.append((row, repr(err)))
        row += n_graphs

    if failures:
        first_row, first_err = failures[0]
        warnings.warn(
            f'emb_csv: {len(failures)} batch(es) failed out of {row} graphs; '
            f'their rows are left as zeros (first failure at row {first_row}: '
            f'{first_err})'
        )
    return embeddings

In [428]:
# One embedding row per use-case graph; rows of graphs that failed stay all-zero.
embs = emb_csv(model=model, loader=loader)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1986.0), HTML(value='')))




In [434]:
def upl_f(x):
    """Remove the literal '../snac_data/use_case_graphs/' directory prefix.

    Literal match: the previous regex treated each '.' as "any character".
    """
    return x.replace('../snac_data/use_case_graphs/', '')


def upl_f_csv(x):
    """Remove a trailing '.csv' extension.

    The dot is escaped and the match anchored; the previous unescaped pattern
    also deleted any character followed by 'csv' anywhere in the path.
    """
    return re.sub(r'\.csv$', '', x)


def upl_f_emb(x):
    """Replace every '/graph' with '_emb'."""
    return x.replace('/graph', '_emb')


# Embedding id for each use-case path, in path_list order.
upl = np.array([upl_f_emb(upl_f_csv(upl_f(p))) for p in path_list])

In [438]:
# Output column names: the embedding id, then one column per dimension '0'..'1023'.
cols = ['emb'] + [str(dim) for dim in range(1024)]

In [439]:
# Ids (single column) and embedding matrix as frames with fresh RangeIndexes,
# so the column-wise concat below pairs them row by row.
df1, df2 = pd.DataFrame(upl), pd.DataFrame(embs)

In [440]:
# Side-by-side join on the shared RangeIndex, then name the columns.
df = pd.concat((df1, df2), axis='columns')
df.columns = cols

In [448]:
# emb_csv leaves all-zero rows for graphs whose forward pass failed; drop
# exactly those rows. Testing only column '0' would also drop a valid
# embedding whose first component happens to be 0.
df = df[df.drop(columns='emb').ne(0).any(axis=1)]

In [449]:
# Persist ids + embeddings (the frame index is written as the first column).
df.to_csv(path_or_buf='embeddings/deep_graph_infomax/unsupervised/DGI_JSD_512_seqveq_testcases.csv')