In [1]:
import numpy as np
import pandas as pd
import sys
from tqdm.notebook import tqdm
import networkx as nx
sys.path.append('..')

import torch
from torch.functional import F
import torch.nn as nn

from torch_geometric.data import Data, DataLoader, Dataset
from torch_geometric.utils import from_networkx, to_networkx, degree, dense_to_sparse
from torch_geometric.nn import GATConv, GCNConv, global_add_pool, PNAConv, BatchNorm, CGConv, global_max_pool
from torch_geometric.utils.metric import accuracy, precision, f1_score
import torch_geometric.transforms as T

from models.graph_transformer.euclidean_graph_transformer import GraphTransformerEncoder

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.metrics import roc_auc_score, accuracy_score, average_precision_score, precision_score

from utils.data_gen import load_prot_embs, to_categorical, wcsv2graph, SNLDataset
from utils.chem import rdkit2graph, tensorise_smiles

# Use the first GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

OSError: /home/rootlocus/anaconda3/envs/torch/lib/python3.7/site-packages/torch_sparse/_version.so: undefined symbol: _ZN3c105ErrorC1ENS_14SourceLocationERKSs

# Protein Loading and Preprocessing

In [2]:
# Load the protein embeddings together with `global_dict`, the object that is
# handed to wcsv2graph and to the dataset constructors further down.
# NOTE(review): 1024 is presumably the embedding size and norm=False presumably
# disables normalisation -- confirm against utils.data_gen.load_prot_embs.
prot_embs, global_dict = load_prot_embs(1024, norm=False)

In [3]:
# Signature-level tables: `labelled_ugraphs` carries the `moa_v1` label per
# `sig_id`; `weighted_df` carries the weighted-graph file path and drug SMILES.
labelled_ugraphs, weighted_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../snac_data/graph_classification_all.csv',
        '../snac_data/file_info_weighted.csv',
    )
)

# Pre-computed split tables (four validation sets and one test set).
val_set_1, val_set_2, val_set_3, val_set_4, test_set = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../snac_data/splits/val_set_1.csv',
        '../snac_data/splits/val_set_2.csv',
        '../snac_data/splits/val_set_3.csv',
        '../snac_data/splits/val_set_4.csv',
        '../snac_data/splits/test_set.csv',
    )
)

In [4]:
# Smoke-test the graph builder on one weighted-network file (position 200).
# The third argument is the label slot; [0, 0, 1] is presumably a dummy
# one-hot label -- confirm against utils.data_gen.wcsv2graph.
wsample_path = weighted_df['files_weighted'].to_numpy()[200]
data = wcsv2graph(wsample_path, global_dict, [0, 0, 1])

In [70]:
# Collapse the label table to one row per signature. `unique()` yields an array
# per sig_id; flattening with reshape(-1) only lines up with `usm` when every
# signature has exactly one distinct `moa_v1` value (otherwise the assignment
# below fails on a length mismatch).
usm = pd.DataFrame(labelled_ugraphs.groupby('sig_id').moa_v1.unique()).reset_index()
usm_corr = np.array([np.array(i) for i in usm.moa_v1.to_numpy()]).reshape(-1)
usm['moa_v1'] = usm_corr

# Attach the label to every weighted-graph file, then carve out the
# validation and test frames by joining against the split tables.
X_df = pd.merge(weighted_df, usm, on='sig_id')
val_df = pd.merge(X_df, val_set_1, on='sig_id')
test_df = pd.merge(X_df, test_set, on='sig_id')

# Drop every test signature from the training frame in one vectorised pass
# (the previous per-signature loop re-filtered the whole frame each iteration).
# NOTE(review): signatures from val_set_1 are NOT removed here, so they remain
# in the training frame -- confirm whether that overlap is intended.
X_df = X_df[~X_df['sig_id'].isin(test_set.sig_id)]

HBox(children=(FloatProgress(value=0.0, max=1031.0), HTML(value='')))




In [83]:
# Graph file paths and MoA labels per split. In val_df / test_df the label is
# read from `moa_v1_x` (the merge with the split tables suffixed the column).
X_train, y_train = X_df.files_weighted.to_numpy(), X_df.moa_v1.to_numpy()
X_val, y_val = val_df.files_weighted.to_numpy(), val_df.moa_v1_x.to_numpy()
X_test, y_test = test_df.files_weighted.to_numpy(), test_df.moa_v1_x.to_numpy()

# Drug SMILES per split, row-aligned with the file paths above.
X_train_drugs = X_df.rdkit.to_numpy()
X_val_drugs = val_df.rdkit.to_numpy()
# Bug fix: this previously read from val_df, pairing test graphs with
# validation-set drugs.
X_test_drugs = test_df.rdkit.to_numpy()

# Each sample needs exactly one drug; catch misaligned splits early.
assert len(X_train_drugs) == len(X_train)
assert len(X_val_drugs) == len(X_val)
assert len(X_test_drugs) == len(X_test)

# Fit the one-hot encoder on the labels of all three splits so every split
# shares the same class columns.
le = OneHotEncoder()
y = np.concatenate([y_train, y_val, y_test])
le = le.fit(y.reshape(-1,1))
y_train = le.transform(y_train.reshape(len(y_train),-1)).toarray()
y_val = le.transform(y_val.reshape(len(y_val),-1)).toarray()
y_test = le.transform(y_test.reshape(len(y_test), -1)).toarray()

train_data = SNLDataset(X_train, y_train, global_dict)
val_data = SNLDataset(X_val, y_val, global_dict)
test_data = SNLDataset(X_test, y_test, global_dict)

# One graph per batch; only the test loader keeps a fixed iteration order.
train_loader = DataLoader(train_data, batch_size=1, num_workers=12, shuffle=True)
val_loader = DataLoader(val_data, batch_size=1, num_workers=12, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1, num_workers=12)

# Drugs Loading and Preprocessing

In [10]:
# RDKit is treated as optional at import time: when it is missing, `Chem` is
# set to None, and smiles2graph (which calls Chem.MolFromSmiles) cannot be used.
try:
    from rdkit import Chem
except ImportError:
    Chem = None

# Vocabularies for the nine categorical atom features. smiles2graph encodes
# each feature as the position of the atom's value inside the matching list,
# so both the key order and the order within each list are significant.
x_map = dict(
    atomic_num=list(range(119)),
    chirality=[
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
    ],
    degree=list(range(11)),
    formal_charge=list(range(-5, 7)),
    num_hs=list(range(9)),
    num_radical_electrons=list(range(5)),
    hybridization=[
        'UNSPECIFIED',
        'S',
        'SP',
        'SP2',
        'SP3',
        'SP3D',
        'SP3D2',
        'OTHER',
    ],
    is_aromatic=[False, True],
    is_in_ring=[False, True],
)

# Vocabularies for the categorical bond features. smiles2graph looks up each
# bond property by its position in the matching list, so list order matters.
e_map = dict(
    bond_type=['misc', 'SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'],
    stereo=[
        'STEREONONE',
        'STEREOZ',
        'STEREOE',
        'STEREOCIS',
        'STEREOTRANS',
        'STEREOANY',
    ],
    is_conjugated=[False, True],
)

In [14]:
# Peek at the SMILES string stored in the `rdkit` column under index label 200.
mol_sample = weighted_df['rdkit'][200]
mol_sample

'Brc1c(NC2=NCCN2)ccc2nccnc12'

In [96]:
def deg_distr(data, min_len=5):
    """Return the histogram of node in-degrees of a graph.

    Entry ``k`` of the result is the number of nodes that appear exactly ``k``
    times in ``data.edge_index[1]``.

    The previous implementation accumulated into a fixed 5-element buffer and
    raised a shape-mismatch error as soon as any node had degree >= 5; the
    histogram now grows to fit the largest degree present.

    Args:
        data: Object exposing ``edge_index`` (integer tensor whose row 1 holds
            the node id each edge points to) and ``num_nodes``.
        min_len: Minimum length of the returned histogram (default 5, matching
            the original fixed size).

    Returns:
        1-D long tensor of length ``max(min_len, max_degree + 1)``.
    """
    # Per-node in-degree: how often each node id occurs among the edge targets.
    node_degrees = torch.bincount(data.edge_index[1], minlength=data.num_nodes)
    # Histogram over those degrees, padded with zeros up to `min_len` bins.
    return torch.bincount(node_degrees, minlength=min_len)

def smiles2graph(smiles):
    """Convert a SMILES string into a ``Data`` molecular graph.

    Atoms become nodes; every bond contributes two directed edges
    (begin->end and end->begin) that share the same attributes.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        ``Data`` with

        * ``x``: long tensor of shape (num_atoms, 9); each column is the index
          of the atom's property value in the corresponding ``x_map`` list
          (atomic number, chirality, degree, formal charge, number of Hs,
          radical electrons, hybridization, aromaticity, ring membership).
        * ``edge_index``: long tensor of shape (2, 2 * num_bonds), ordered by
          source node and then target node.
        * ``edge_attr``: long tensor with one row per directed edge: a one-hot
          encoding of the bond type over ``e_map['bond_type']`` followed by the
          ``e_map`` indices of the bond's stereo tag and conjugation flag
          (7 columns with the current ``e_map``).

    Raises:
        ImportError: If RDKit is not available (``Chem`` is None).
        ValueError: If RDKit cannot parse ``smiles``, or if an atom/bond
            property value is not listed in ``x_map`` / ``e_map``.
    """
    if Chem is None:
        raise ImportError('smiles2graph requires RDKit (rdkit.Chem) to be installed.')

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail loudly
        # here instead of crashing on `None.GetAtoms()` a few lines later.
        raise ValueError(f'RDKit could not parse SMILES string: {smiles!r}')

    # Node features: one row of categorical indices per atom.
    atom_rows = []
    for atom in mol.GetAtoms():
        atom_rows.append([
            x_map['atomic_num'].index(atom.GetAtomicNum()),
            x_map['chirality'].index(str(atom.GetChiralTag())),
            x_map['degree'].index(atom.GetTotalDegree()),
            x_map['formal_charge'].index(atom.GetFormalCharge()),
            x_map['num_hs'].index(atom.GetTotalNumHs()),
            x_map['num_radical_electrons'].index(atom.GetNumRadicalElectrons()),
            x_map['hybridization'].index(str(atom.GetHybridization())),
            x_map['is_aromatic'].index(atom.GetIsAromatic()),
            x_map['is_in_ring'].index(atom.IsInRing()),
        ])

    x = torch.tensor(atom_rows, dtype=torch.long).view(-1, 9)

    # Edge features: one-hot bond type + stereo index + conjugation index.
    n_bond_types = len(e_map['bond_type'])
    edge_indices, edge_attrs = [], []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()

        bond_feat = [0] * n_bond_types
        bond_feat[e_map['bond_type'].index(str(bond.GetBondType()))] = 1
        bond_feat.append(e_map['stereo'].index(str(bond.GetStereo())))
        bond_feat.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))

        # Store the bond in both directions with identical attributes.
        edge_indices += [[i, j], [j, i]]
        edge_attrs += [bond_feat, bond_feat]

    edge_index = torch.tensor(edge_indices)
    edge_index = edge_index.t().to(torch.long).view(2, -1)
    edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, n_bond_types + 2)

    if edge_index.numel() > 0:
        # Sort edges by (source, target) so the edge order is deterministic.
        perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
        edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

    return Data(x=x, edge_attr=edge_attr, edge_index=edge_index)

In [97]:
class NetNDrugs(Dataset):
    """Dataset that pairs each weighted-network file with its drug molecule.

    Sample ``idx`` is built lazily from ``fnames[idx]``, ``y[idx]`` and
    ``rdkit[idx]``; the three sequences are expected to be index-aligned.
    """

    def __init__(self, fnames, rdkit,  y, global_dict):
        """Store the per-sample inputs; nothing is loaded until ``get``.

        Args:
            fnames: Sequence of graph file paths passed to ``wcsv2graph``.
            rdkit: Sequence of SMILES strings passed to ``smiles2graph``.
            y: Sequence of labels, one per file.
            global_dict: Object forwarded unchanged to ``wcsv2graph``.
        """
        super().__init__()
        self.rdkit = rdkit
        self.y = y
        self.gd = global_dict
        self.fnames = fnames

    def len(self):
        """Return the number of samples (one per entry of ``fnames``)."""
        return len(self.fnames)

    def get(self, idx):
        """Return the ``(network_graph, drug_graph)`` pair for sample ``idx``."""
        network = wcsv2graph(self.fnames[idx], self.gd, self.y[idx])
        molecule = smiles2graph(self.rdkit[idx])
        return network, molecule

In [101]:
# Rebuild the three splits as paired (network graph, drug graph) datasets.
# NOTE(review): this rebinds train_data / val_data / test_data, which were
# SNLDataset objects; the loaders created earlier still wrap those older
# objects unless they are re-created after this cell.
train_data = NetNDrugs(fnames=X_train, rdkit=X_train_drugs, y=y_train, global_dict=global_dict)
val_data = NetNDrugs(fnames=X_val, rdkit=X_val_drugs, y=y_val, global_dict=global_dict)
test_data = NetNDrugs(fnames=X_test, rdkit=X_test_drugs, y=y_test, global_dict=global_dict)

In [104]:
# Display the dataset repr as a quick check of the test split's size.
test_data

NetNDrugs(1031)