In [104]:
import numpy as np
import pandas as pd
import sys
from tqdm.notebook import tqdm
import networkx as nx
sys.path.append('..')

import torch
from torch.functional import F
import torch.nn as nn

from torch_geometric.data import Data, DataLoader, Dataset
from torch_geometric.utils import from_networkx, to_networkx, degree
from torch_geometric.nn import GATConv, GCNConv, global_add_pool, PNAConv, BatchNorm, CGConv, global_max_pool
from torch_geometric.utils.metric import accuracy, precision, f1_score
import torch_geometric.transforms as T

from models.graph_transformer.euclidean_graph_transformer import GraphTransformerEncoder

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.metrics import roc_auc_score, accuracy_score, average_precision_score, precision_score

from utils.data_gen import load_prot_embs, to_categorical, wcsv2graph, SNLDataset

# Compute device used by the rest of the notebook: the first CUDA GPU when one
# is available, otherwise the CPU, so the notebook can still be executed on a
# machine without a GPU instead of failing at the first `.to(dev)`.
dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [105]:
# Load the protein embeddings together with `global_dict`; `global_dict` is the
# object later handed to `wcsv2graph` and to every `SNLDataset`.
# NOTE(review): 512 is presumably the embedding dimensionality and `norm=False`
# presumably disables normalisation -- confirm against utils.data_gen.load_prot_embs.
prot_embs, global_dict = load_prot_embs(512, norm=False)

In [106]:
# Label table (grouped by `sig_id` further down to get one `moa_v1` per
# signature) and the table of weighted graph files (merged on `sig_id`).
labelled_ugraphs = pd.read_csv('../snac_data/graph_classification_all.csv')
weighted_df = pd.read_csv('../snac_data/file_info_weighted.csv')

# Pre-computed split tables: four validation sets and one test set.
# Only `val_set_4` and `test_set` are referenced in the cells visible here.
val_set_1 = pd.read_csv('../snac_data/splits/val_set_1.csv')
val_set_2 = pd.read_csv('../snac_data/splits/val_set_2.csv')
val_set_3 = pd.read_csv('../snac_data/splits/val_set_3.csv')
val_set_4 = pd.read_csv('../snac_data/splits/val_set_4.csv')
test_set = pd.read_csv('../snac_data/splits/test_set.csv')

In [107]:
# Build a single graph from the weighted file at row position 4 of
# `files_weighted`. The result is not used by the other cells visible here, so
# this appears to be a smoke test of `wcsv2graph`.
# NOTE(review): `[0,0,1]` is presumably a placeholder label vector -- confirm
# against utils.data_gen.wcsv2graph.
wsample_path = weighted_df.files_weighted.to_numpy()[4]
data = wcsv2graph(wsample_path, global_dict, [0,0,1])

In [108]:
# Collapse the labelled graphs to one MoA label per signature.
# `unique()` yields an array per `sig_id`; the flatten below only lines up with
# the rows of `usm` if every signature has exactly one distinct `moa_v1` value.
usm = pd.DataFrame(labelled_ugraphs.groupby('sig_id').moa_v1.unique()).reset_index()
usm_corr = np.array([np.array(i) for i in usm.moa_v1.to_numpy()]).reshape(-1)
usm['moa_v1'] = usm_corr

# Attach the label to every weighted-graph file, then select the validation
# and test rows by signature id.
X_df = pd.merge(weighted_df, usm, on='sig_id')
val_df = pd.merge(X_df, val_set_4, on='sig_id')
test_df = pd.merge(X_df, test_set, on='sig_id')

# Remove every test signature from the training frame with one vectorised
# membership mask instead of one full-frame filter per signature id.
# NOTE(review): signatures from `val_set_4` are not removed here, so the
# validation rows remain part of the training frame -- confirm this is intended.
X_df = X_df[~X_df['sig_id'].isin(test_set.sig_id)]

HBox(children=(FloatProgress(value=0.0, max=1031.0), HTML(value='')))




In [109]:
# Split each frame into graph-file paths (inputs) and MoA labels (targets).
# The validation/test frames expose the label as `moa_v1_x`, presumably because
# the split CSVs carry their own `moa_v1` column and the merge suffixed both.
X_train = X_df['files_weighted'].to_numpy()
y_train = X_df['moa_v1'].to_numpy()

X_val = val_df['files_weighted'].to_numpy()
y_val = val_df['moa_v1_x'].to_numpy()

X_test = test_df['files_weighted'].to_numpy()
y_test = test_df['moa_v1_x'].to_numpy()

In [110]:
# One-hot encode the MoA labels. The encoder is fitted on the labels of all
# three splits together, so every split is encoded with the same category set
# and column order.
y = np.concatenate([y_train, y_val, y_test])
le = OneHotEncoder().fit(y.reshape(-1, 1))

# Each label array is reshaped to a single column, encoded, and densified.
# Note that this rebinds y_train / y_val / y_test from label arrays to one-hot
# matrices: re-running this cell without re-running the cell that extracts the
# labels would feed the already-encoded matrices back into the encoder.
y_train, y_val, y_test = (
    le.transform(labels.reshape(len(labels), -1)).toarray()
    for labels in (y_train, y_val, y_test)
)

In [111]:
# Wrap every split (file paths + one-hot targets) in an SNLDataset; all three
# share the same `global_dict`.
train_data, val_data, test_data = (
    SNLDataset(paths, targets, global_dict)
    for paths, targets in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

# One graph per batch, 12 worker processes each. Train and validation batches
# are reshuffled; the test loader keeps dataset order.
train_loader = DataLoader(train_data, batch_size=1, shuffle=True, num_workers=12)
val_loader = DataLoader(val_data, batch_size=1, shuffle=True, num_workers=12)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=12)

# Eval

In [18]:
# Stack the collected predictions into (n_runs, n_test_samples, n_classes).
# The sample and class counts are taken from the one-hot test labels instead of
# being hard-coded, and the leading axis is inferred from the number of
# predictions, so the cell keeps working if the split or label set changes.
# NOTE(review): `preds` is not produced by any cell visible here; it must be
# populated before this cell runs. Presumably it holds the test-set outputs of
# several trained models, concatenated run after run -- confirm.
preds = np.array(preds)
p_reshaped = preds.reshape(-1, *y_test.shape)

In [19]:
# Average the stacked predictions over the leading axis (axis 0), leaving one
# score vector per test sample.
pred_mean = torch.tensor(p_reshaped).mean(dim=0)

In [20]:
# Class index with the highest averaged score for every test sample
# (a torch tensor) ...
pred_labels = pred_mean.argmax(dim=-1)
# ... and the class index encoded by each one-hot test label (a NumPy array).
true_labels = y_test.argmax(axis=-1)

In [21]:
# Store the true and predicted class indices on `test_set`, which is what the
# accuracy helpers called afterwards receive.
test_set['true'] = true_labels
# `pred_labels` is a torch tensor; convert it to a NumPy array before the
# assignment so the column holds plain integers on every pandas version rather
# than, potentially, one 0-dim tensor object per row.
test_set['predicted'] = pred_labels.cpu().numpy()

In [22]:
# Score the `true` / `predicted` columns stored on `test_set`.
# NOTE(review): `per_drug_acc` is neither defined nor imported in the cells
# visible here; its definition must run before this cell on a fresh kernel.
# Presumably it aggregates predictions per drug before computing accuracy -- confirm.
per_drug_acc(test_set)

0.375

In [23]:
# Score the `true` / `predicted` columns stored on `test_set`.
# NOTE(review): `per_sig_acc` is neither defined nor imported in the cells
# visible here; its definition must run before this cell on a fresh kernel.
# Presumably it computes accuracy over individual signatures -- confirm.
per_sig_acc(test_set)

0.1794871794871795