In [3]:
import numpy as np
import pandas as pd
import pickle
import math
import sys
sys.path.append('..')

import torch
import torch.nn as nn
from torch.functional import F

import torch_geometric
from torch_geometric.data import Data, DataLoader, Dataset
from torch_geometric.utils import add_self_loops, degree, to_dense_adj,remove_self_loops, to_networkx
from torch_geometric.nn import Set2Set
from torch_sparse import SparseTensor

import networkx as nx
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer, LabelEncoder, OneHotEncoder
from sklearn.metrics import accuracy_score, roc_auc_score
from tqdm.notebook import tqdm

from models.graph_transformer.euclidean_graph_transformer import GraphTransformerEncoder
from models.deep_graph_infomax.infomax import SNInfomax
from models.graph_transformer.autoencoder_base import DeepSNEM, LinearDecoder, FermiDiracDecoder
from utils.data_gen import load_prot_embs, wcsv2graph, load_prot_embs_go, SNDatasetInfomax, ucsv2graph_infomax

import re
import gc

# Release GPU memory cached by PyTorch's allocator before the run starts.
torch.cuda.empty_cache()
# Allow cuDNN to benchmark and select algorithms for the input shapes it sees.
torch.backends.cudnn.benchmark = True
# Kept as the last expression so the cell displays whether CUDA is usable.
torch.cuda.is_available()

True

# Load the protein index and the sample graph file lists

In [4]:
# Build the protein-name -> integer-index lookup from the unique-protein table.
unique_prots = 'data/prot_embeddings/new_features/proteins.csv'
unique_df = pd.read_csv(unique_prots)

# A protein's index is its position in the `proteins` column; if a name were
# repeated, its last position wins (same as assigning keys one by one).
global_dict = {
    name: position
    for position, name in enumerate(unique_df.proteins.to_numpy())
}

In [5]:
# Table of sample graph files: drop the 'Unnamed: 0' column (presumably a
# saved index) and relabel what remains as 'files_combined'.
unweighted_fnames = 'data/graph_info_df/samples_all.csv'
u_fnames_all = pd.read_csv(unweighted_fnames).drop(columns='Unnamed: 0')
u_fnames_all.columns = ['files_combined']
u_path_list = u_fnames_all['files_combined'].to_numpy()
usample = u_path_list[80]  # a single path picked out by position

unweighted_total = '../snac_data/file_info.csv'
u_total = pd.read_csv(unweighted_total)

# Auxiliary tables, read straight into their final names.
moa_fnames = pd.read_csv('../snac_data/graph_classification_all.csv')
mapping = pd.read_csv('../snac_data/sig_mapping.csv')

# Create the appropriate data loader

In [6]:
# One-hot encode the `sigs_g` column of the full dataset table.
u_fnames = pd.read_csv('data/graph_info_df/full_dataset.csv')
u_path_list = u_fnames['files_combined'].to_numpy()  # NOTE(review): overwritten below
oh = OneHotEncoder()
labels = oh.fit_transform(u_fnames['sigs_g'].to_numpy().reshape(-1, 1)).toarray()

# The paths actually handed to the dataset come from samples_all.csv.
samples_all = pd.read_csv('data/graph_info_df/samples_all.csv')
u_path_list = samples_all['path_list'].values

train_data = SNDatasetInfomax(u_path_list, global_dict)

# No shuffling: later cells pair the embeddings with u_path_list by position.
u_loader = DataLoader(
    train_data,
    batch_size=16,
    num_workers=12,
    shuffle=False,
)

# Test the Graph Transformer

In [8]:
# Everything below is placed on the CUDA device.
dev = torch.device('cuda')

# NOTE(review): SIZE and EMB_DIM are not referenced in this cell; the literal
# 512 is repeated below instead — confirm whether they are used elsewhere.
SIZE = 512
EMB_DIM = 512

# Load the protein embeddings; only element 0 of the result is used, as the
# encoder's pretrained weights.
prot_embs = load_prot_embs_go(512, norm=False)
# NOTE(review): `summarizer` is not passed to anything here (the encoder gets
# summarizer=None and the model is given its own Set2Set) — looks unused.
summarizer = Set2Set(512, 3)
enc = GraphTransformerEncoder(n_layers=4, n_heads=4, n_hid=512, 
                            pretrained_weights=prot_embs[0], summarizer=None).to(dev)

# Infomax model (from models.deep_graph_infomax) wrapping the encoder, with a
# fresh Set2Set as its graph-level summary function.
model = SNInfomax(hidden_channels=512, encoder=enc,
                                     summary=Set2Set(512, 3), semi=False).to(dev)

In [9]:
# Restore the trained weights. `map_location=dev` remaps every stored tensor
# onto the device used in this session, so the checkpoint still loads when it
# was saved from a different device (e.g. another GPU index).
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
# Kept as the last expression so the key-matching report is displayed.
model.load_state_dict(
    torch.load(
        'embeddings/deep_graph_infomax/unsupervised/DGI_JSD_512_seqveq_uniform_un.pt',
        map_location=dev,
    )
)

<All keys matched successfully>

In [10]:
@torch.no_grad()
def emb_csv(model, loader, device=None):
    """Compute one pooled embedding per graph for every batch in ``loader``.

    Each batch is run through ``model.encoder`` and the result is pooled per
    graph with ``model.summary``. Rows are written in the order the loader
    yields graphs, so the loader must not shuffle if rows are to be matched
    against the dataset order.

    Args:
        model: Module exposing ``encoder(batch)`` and
            ``summary(node_repr, batch_index)``; switched to eval mode here.
        loader: Iterable of batched graphs (each with a ``batch`` index
            tensor and a ``to(device)`` method) that also exposes
            ``.dataset``; ``len(loader.dataset)`` fixes the number of rows.
        device: Device to run on. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        ``np.ndarray`` (float64) of shape ``(len(loader.dataset), emb_dim)``
        where ``emb_dim`` is taken from the first pooled batch. Rows never
        reached by the loader stay zero. With an empty loader the result has
        zero columns.
    """
    model.eval()
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    total_graphs = len(loader.dataset)
    embeddings = None
    row = 0
    for graph in tqdm(loader):
        # Plain int so slice bounds and the running offset are Python
        # integers rather than 0-d tensors.
        n_graphs = int(graph.batch.max()) + 1
        graph = graph.to(device)
        node_repr = model.encoder(graph)
        # reshape (instead of squeeze) keeps the graph axis even when the
        # batch holds a single graph.
        pooled = model.summary(node_repr, graph.batch).reshape(n_graphs, -1).cpu().numpy()
        if embeddings is None:
            # Width is read from the model output instead of being hard-coded.
            embeddings = np.zeros((total_graphs, pooled.shape[1]))
        embeddings[row:row + n_graphs] = pooled
        row += n_graphs

    if embeddings is None:
        return np.zeros((total_graphs, 0))
    return embeddings

In [11]:
# Embed every graph served by u_loader (one row per graph, in loader order).
embs = emb_csv(model, u_loader)

HBox(children=(FloatProgress(value=0.0, max=4324.0), HTML(value='')))




In [12]:
def upl_f(x):
    """Remove every 'graphs_combined/' occurrence from the path ``x``."""
    return re.sub('graphs_combined/', '', x)


def upl_f_csv(x):
    """Strip a trailing '.csv' extension from ``x``.

    The dot is escaped and the pattern anchored at the end: an unescaped
    '.csv' also matches any character followed by 'csv' anywhere in the
    path (e.g. the 'ycsv' in 'mycsv_run') and would corrupt such names.
    """
    return re.sub(r'\.csv$', '', x)


def upl_f_emb(x):
    """Replace every '/graph' in ``x`` with '_emb'."""
    return re.sub('/graph', '_emb', x)

# Derive one identifier per graph path by chaining the three rewrites in
# their original order: upl_f, then upl_f_csv, then upl_f_emb.
upl = np.array([upl_f_emb(upl_f_csv(upl_f(path))) for path in u_path_list])

In [13]:
# Header for the export: an identifier column followed by one column per
# embedding dimension, labelled '0' through '1023'.
cols = ['emb'] + [str(dim) for dim in range(1024)]

In [14]:
# Wrap the identifiers and the embedding matrix as frames so they can be
# joined column-wise.
df1 = pd.DataFrame(data=upl)
df2 = pd.DataFrame(data=embs)

In [15]:
# Side-by-side join (identifier column first, embedding columns after), then
# relabel all columns with the prepared header.
df = pd.concat((df1, df2), axis='columns')
df.columns = cols

In [None]:
# Persist the embedding table next to the checkpoint.
# NOTE(review): to_csv writes the row index by default, so the file gains an
# unnamed leading column — confirm downstream readers expect that.
df.to_csv('embeddings/deep_graph_infomax/unsupervised/DGI_JSD_512_seqveq_uniform_un_l4.csv')