In [62]:
import os
import sys
import torch
import pickle
import matplotlib.pyplot as plt
import numpy as np
import math
import random
import networkx as nx
import glob
import pandas as pd
import scipy 
import sklearn
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.utils import convert
from torch_geometric.data import InMemoryDataset, download_url, Data
from torch_geometric.loader import DataLoader
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool, global_max_pool

In [14]:
def get_label(file, labels):
    """Look up the label of a protein-pair graph file by its file name.

    The base name of ``file`` is split on the literal ``"and"`` into two
    protein names (any ``".gpickle"`` text is removed from both). The names
    are matched against the ``protein_1`` / ``protein_2`` columns of
    ``labels`` and the ``label`` value of the matching row is returned.

    Args:
        file: Path to the graph file.
        labels: DataFrame with ``protein_1``, ``protein_2`` and ``label``
            columns.

    Returns:
        Tuple ``(file, label)``: the unchanged input path and the matching
        row's label converted to ``int``.

    Raises:
        ValueError: If the file name does not split into exactly two names,
            or if ``labels`` does not hold exactly one matching row.
    """
    # basename honours the platform's path separator instead of assuming '/'.
    pair_name = os.path.basename(file)
    parts = pair_name.split("and")
    if len(parts) != 2:
        raise ValueError(
            f"Cannot split {pair_name!r} into two protein names on 'and'"
        )
    pair_1, pair_2 = (part.replace(".gpickle", "") for part in parts)

    is_pair = (labels.protein_1 == pair_1) & (labels.protein_2 == pair_2)
    matches = labels.loc[is_pair, "label"]
    if len(matches) != 1:
        raise ValueError(
            f"Expected exactly one label row for ({pair_1!r}, {pair_2!r}), "
            f"found {len(matches)}"
        )
    # int(Series) is deprecated in recent pandas; pull out the scalar instead.
    return file, int(matches.iloc[0])

def read_graphs(file_set):
    """Load every pickled graph listed in ``file_set``.

    Args:
        file_set: Iterable of paths to pickled graph files.

    Returns:
        List with one loaded object per path, in the order of ``file_set``.
    """
    return [_load_gpickle(path) for path in file_set]


def _load_gpickle(path):
    """Load one pickled graph, with or without ``nx.read_gpickle``."""
    # nx.read_gpickle was removed in NetworkX 3.0; use it only when present.
    reader = getattr(nx, "read_gpickle", None)
    if reader is not None:
        return reader(path)
    # SECURITY: unpickling can execute arbitrary code -- only load files you
    # trust. NOTE: this fallback reads plain (uncompressed) pickles only.
    with open(path, "rb") as handle:
        return pickle.load(handle)
    
def format_graphs(graphs, label=1):
    """Convert NetworkX graphs into labelled PyTorch Geometric data objects.

    Every graph is relabelled with integer node ids, has all of its edge
    attributes emptied, is converted with ``convert.from_networkx`` and gets
    ``y`` set to a one-element float tensor holding ``label``.

    Args:
        graphs: Iterable of NetworkX graphs.
        label: Value stored in ``y`` of every converted graph (default 1).

    Returns:
        List of converted data objects, in input order.
    """
    converted = []
    for graph in tqdm(graphs):
        relabelled = nx.convert_node_labels_to_integers(graph)
        # Empty each edge-attribute dict so no edge attributes are converted.
        for _src, _dst, edge_attrs in relabelled.edges(data=True):
            edge_attrs.clear()
        pyg_data = convert.from_networkx(relabelled, group_edge_attrs=None)
        pyg_data.y = torch.FloatTensor(np.array([label]))
        converted.append(pyg_data)
    return converted

def binary_acc(y_pred, y_test):
    """Compute the percentage accuracy of binary logits against targets.

    Args:
        y_pred: Tensor of raw logits.
        y_test: Tensor of 0/1 targets. It may have a different shape from
            ``y_pred`` (e.g. ``(N,)`` vs ``(N, 1)``) as long as it holds the
            same number of elements.

    Returns:
        Tuple ``(acc, y_pred_tag, probas)``: the accuracy in percent rounded
        to a whole number (0-dim tensor), the rounded 0/1 predictions, and
        the sigmoid probabilities. The last two have the shape of ``y_pred``.
    """
    probas = torch.sigmoid(y_pred)
    y_pred_tag = torch.round(probas)
    n_samples = y_test.shape[0]

    # Without this, (N, 1) predictions against (N,) targets broadcast to an
    # (N, N) comparison and the "accuracy" can exceed 100%.
    targets = y_test
    if targets.shape != y_pred_tag.shape and targets.numel() == y_pred_tag.numel():
        targets = targets.reshape(y_pred_tag.shape)

    correct_results_sum = (y_pred_tag == targets).sum().float()
    acc = correct_results_sum / n_samples
    acc = torch.round(acc * 100)
    return acc, y_pred_tag, probas

In [6]:
# Import the data
# NOTE: these are absolute, machine-specific locations -- adjust as needed.
graph_dir_path = '/mnt/mnemo5/sum02dean/sl_projects/GCN/GCN-STRING/src/scripts/graph_data'
labels_dir_path = '/mnt/mnemo5/sum02dean/sl_projects/GCN/GCN-STRING/src/scripts/graph_labels'

# glob returns paths in arbitrary order; sorting makes the file order (and the
# choice of label file below) the same on every run.
graph_files = sorted(glob.glob(os.path.join(graph_dir_path, '*')))
label_files = sorted(glob.glob(os.path.join(labels_dir_path, '*')))
if not label_files:
    raise FileNotFoundError(f"No label file found in {labels_dir_path!r}")

# Only the first label file (in sorted order) is read.
graph_labels = pd.read_csv(label_files[0])

# Create positive and negative sets of graph file paths
positives = []
negatives = []

for file in graph_files:
    obs, label = get_label(file, graph_labels)

    # Anything that is not labelled 1 is treated as a negative.
    if label == 1:
        positives.append(obs)
    else:
        negatives.append(obs)


In [7]:
# Balance the number of negatives with number of positives.
# A seeded generator makes the draw repeatable for a given ordering of
# `negatives`. Sampling without replacement requires at least as many
# negatives as positives, otherwise `choice` raises ValueError.
rng = np.random.default_rng(seed=42)
negatives = rng.choice(negatives, size=len(positives), replace=False)

In [11]:
# Read in the positive graphs first, then the negative graphs
pos_graphs, neg_graphs = (
    read_graphs(paths) for paths in (positives, negatives)
)


In [15]:
# Format graphs: positives are tagged with label 1, negatives with label 0
positive_graphs, negative_graphs = (
    format_graphs(graph_set, label=target)
    for graph_set, target in ((pos_graphs, 1), (neg_graphs, 0))
)

100%|██████████| 5456/5456 [06:01<00:00, 15.10it/s]  
100%|██████████| 5456/5456 [06:27<00:00, 14.09it/s] 


In [50]:
# Stack the node-feature matrices (`.x`) of all positive graphs into one frame.
# Building the per-graph frames first and concatenating once avoids the
# quadratic cost of growing `df_pos` inside the loop.
# np.asarray hands pandas a plain ndarray -- assumes `.x` is a CPU tensor or
# array-like; TODO confirm.
df_pos = pd.concat(
    [pd.DataFrame(np.asarray(graph.x)) for graph in tqdm(positive_graphs)],
    axis=0,
    ignore_index=True,
)

100%|██████████| 5455/5455 [02:49<00:00, 32.10it/s] 


In [53]:
# Stack the node-feature matrices (`.x`) of all negative graphs into one frame.
# Building the per-graph frames first and concatenating once avoids the
# quadratic cost of growing `df_neg` inside the loop.
# np.asarray hands pandas a plain ndarray -- assumes `.x` is a CPU tensor or
# array-like; TODO confirm.
df_neg = pd.concat(
    [pd.DataFrame(np.asarray(graph.x)) for graph in tqdm(negative_graphs)],
    axis=0,
    ignore_index=True,
)

100%|██████████| 5455/5455 [03:29<00:00, 26.10it/s] 


In [85]:
# Tag every node-feature row with the class of the graph set it came from.
# A scalar broadcasts to all rows, so the column length always matches the
# frame it is assigned to.
df_pos['label'] = 1
df_neg['label'] = 0


In [102]:
from sklearn.decomposition import PCA

# Stack positive rows on top of negative rows. Each frame keeps its own row
# index, so index values repeat in the combined frame.
c = pd.concat((df_pos, df_neg), axis=0)
# Keep the class column on its own for later use alongside the projection.
c_labels = c.loc[:, 'label']

In [104]:
# Project the node features onto their first two principal components.
# random_state pins the randomized SVD solver that PCA may select for large
# inputs, so repeated runs give the same projection.
pca = PCA(n_components=2, random_state=0)
# Exclude the class column by name instead of assuming it is the last column.
pcs = pca.fit_transform(c.drop(columns='label').to_numpy())
pcs = pd.DataFrame(pcs, columns=['PC1', 'PC2'])


KeyboardInterrupt: 

In [90]:
# Attach the class labels by position: row k of `pcs` gets the k-th label.
pcs = pcs.assign(label=c_labels.to_numpy())
pcs

Unnamed: 0,PC1,PC2,label
0,-104.978883,112.363586,1
1,-105.926855,24.809373,1
2,64.161726,50.862657,1
3,55.633954,11.575439,1
4,61.976700,55.284570,1
...,...,...,...
7014690,72.080888,29.896771,0
7014691,64.026486,17.281433,0
7014692,42.766979,8.665316,0
7014693,31.319823,3.628742,0


In [None]:
import seaborn as sns

# Scatter of PC1 against PC2, coloured by class label, on an explicit Axes.
fig, ax = plt.subplots(figsize=(16, 10))
sns.scatterplot(
    data=pcs,
    x="PC1",
    y="PC2",
    hue="label",
    legend="full",
    alpha=0.3,
    ax=ax,
)

plt.show()