In [1]:
import pandas as pd
import numpy as np
from rdkit import Chem
import dgl
import torch
import networkx as nx
import matplotlib.pyplot as plt

In [2]:
# Keep only the columns needed downstream and tag every row with a
# "rowID,colID" string identifying its (ordered) drug pair.
data = pd.read_csv('Synergy_ALMANAC_int.csv')[
    ['drug_row', 'drug_col', 'DepMap_ID', 'synergy_loewe']
]
data = data.assign(
    drug_pair=data['drug_row'].astype(str) + ',' + data['drug_col'].astype(str)
)
data

Unnamed: 0,drug_row,drug_col,DepMap_ID,synergy_loewe,drug_pair
0,67,66,19,-10.659082,6766
1,32,92,19,-8.796808,3292
2,52,3,19,-5.172253,523
3,29,19,19,2.140359,2919
4,66,5,19,-5.825567,665
...,...,...,...,...,...
236185,1,3,24,-38.087402,13
236186,33,3,24,-24.851390,333
236187,42,3,24,-3.214656,423
236188,72,3,24,-41.788745,723


In [3]:
def atom_features(atom):
    """Return the boolean feature vector (numpy array) for an RDKit atom.

    The vector is the concatenation of:
      * a one-hot of the element symbol over a fixed 44-entry vocabulary
        (symbols outside it map to 'Unknown'),
      * a one-hot of ``atom.GetDegree()`` over 0..10 (raises if outside),
      * a one-hot of ``atom.GetTotalNumHs()`` over 0..10 (values outside map to 10),
      * a one-hot of ``atom.GetImplicitValence()`` over 0..10 (values outside map to 10),
      * the aromaticity flag.
    """
    element_vocab = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As',
                     'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se',
                     'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr',
                     'Pt', 'Hg', 'Pb', 'Unknown']
    small_ints = list(range(11))  # 0..10 inclusive

    element_part = one_of_k_encoding_unk(atom.GetSymbol(), element_vocab)
    degree_part = one_of_k_encoding(atom.GetDegree(), small_ints)
    hydrogen_part = one_of_k_encoding_unk(atom.GetTotalNumHs(), small_ints)
    valence_part = one_of_k_encoding_unk(atom.GetImplicitValence(), small_ints)
    aromatic_part = [atom.GetIsAromatic()]

    return np.array(element_part + degree_part + hydrogen_part + valence_part + aromatic_part)


def one_of_k_encoding(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set``.

    Returns a list of booleans, one per entry of ``allowable_set``, that is True
    where the entry equals ``x``. Raises ``Exception`` if ``x`` is not a member.
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))


def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x``; inputs outside ``allowable_set`` map to its last element.

    Returns a list of booleans, one per entry of ``allowable_set``.
    """
    effective = x if x in allowable_set else allowable_set[-1]
    return [effective == candidate for candidate in allowable_set]

def smile_to_graph(smile):
    """Convert a SMILES string into a bidirected DGL molecular graph.

    Nodes are atoms. Every RDKit bond becomes a pair of opposite directed
    edges. ``g.ndata['feat']`` holds one float64 row per atom: the
    ``atom_features`` vector divided by its own sum.

    No maximum atom count is enforced.

    Args:
        smile: SMILES string.

    Returns:
        A ``dgl.DGLGraph``, or ``None`` if RDKit cannot parse ``smile`` or the
        molecule contains a dummy/wildcard atom (atomic number 0).
    """
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        # RDKit signals an unparsable SMILES by returning None.
        return None
    # Reject wildcard atoms up front. GetAtomicNum must be *called*: comparing
    # the bound method itself to 0 is always False.
    if any(atom.GetAtomicNum() == 0 for atom in mol.GetAtoms()):
        return None

    num_nodes = mol.GetNumAtoms()
    bonds = list(mol.GetBonds())
    edges_src = [bond.GetBeginAtomIdx() for bond in bonds]
    edges_dst = [bond.GetEndAtomIdx() for bond in bonds]

    g = dgl.graph((edges_src, edges_dst), num_nodes=num_nodes)
    # Add the reverse of every bond so the graph is symmetric.
    g = dgl.to_bidirected(g)

    # 78 = 44 element symbols + 11 degree + 11 total-H + 11 implicit valence
    # + 1 aromatic flag, i.e. the length of atom_features().
    features = np.zeros([num_nodes, 78])
    for i, atom in enumerate(mol.GetAtoms()):
        feature = atom_features(atom)
        # Each one-hot section contributes exactly one True, so the sum is > 0.
        features[i, :] = feature / sum(feature)
    g.ndata['feat'] = torch.from_numpy(features).double()
    return g

In [4]:
Smiles = pd.read_table("Drug_SMILE.txt",sep="\t",names=['smile','drug'])
drug_map = np.load('./Preprocessed/reg/Drug_map.npy', allow_pickle=True)
drug_map = drug_map.item()

# Build the drug -> SMILES lookup once instead of scanning the whole table for
# every drug. Keep the first row for duplicated drug names, which is what the
# previous `.values[0]` lookup returned.
smile_by_drug = (
    Smiles.drop_duplicates(subset='drug')
    .set_index('drug')['smile']
    .to_dict()
)

# One graph per drug, appended in drug_map iteration order. Later cells index
# this list with the integer drug IDs from `data`, so this presumably relies on
# drug_map's values being 0..n-1 in insertion order -- TODO confirm.
drug = []
for key in drug_map:
    if key not in smile_by_drug:
        raise KeyError(f"no SMILES found in Drug_SMILE.txt for drug {key!r}")
    graph = smile_to_graph(smile_by_drug[key])
    if graph is None:
        # Fail here with a clear message rather than much later in the pair loop.
        raise ValueError(f"could not build a molecular graph for drug {key!r}")
    drug.append(graph)

In [5]:
# Map each unique "rowID,colID" pair string to a dense integer index; the
# order of pair_graph below matches these indices.
Pair_map = {pair: idx for idx, pair in enumerate(data['drug_pair'].unique())}


def graph_center(g):
    """Return the ID of the node with the highest betweenness centrality in ``g``.

    Ties are broken in favour of the first such node in networkx's node order.
    """
    centrality = nx.betweenness_centrality(g.to_networkx())
    return max(centrality, key=centrality.get)


def build_pair_graph(graph_a, graph_b, center_a, center_b):
    """Merge two drug graphs into one DGL graph linked at their centre atoms.

    Nodes of ``graph_b`` are renumbered to follow those of ``graph_a``. Edges are
    added in this order: A's edges, B's edges, center_a -> center_b,
    center_b -> center_a. Node features are A's rows stacked above B's rows.

    Args:
        graph_a, graph_b: per-drug DGL graphs carrying ``ndata['feat']``.
        center_a, center_b: node IDs local to ``graph_a`` / ``graph_b``.

    Returns:
        The combined ``dgl.DGLGraph``.
    """
    offset = graph_a.num_nodes()
    src_a, dst_a = graph_a.edges()
    src_b, dst_b = graph_b.edges()

    pair_g = dgl.graph((src_a, dst_a), num_nodes=offset + graph_b.num_nodes())
    pair_g.add_edges(src_b + offset, dst_b + offset)
    # Connect the two molecules through their most central atoms, both ways.
    pair_g.add_edges(center_a, center_b + offset)
    pair_g.add_edges(center_b + offset, center_a)
    # torch.cat replaces the deprecated np.row_stack round-trip; the inputs are
    # float64 tensors, so the result stays float64.
    pair_g.ndata['feat'] = torch.cat((graph_a.ndata['feat'], graph_b.ndata['feat']), dim=0)
    return pair_g


# Betweenness centrality depends only on the drug, so compute it once per drug
# instead of once for every pair the drug appears in.
drug_center = {}
pair_graph = []
for key in Pair_map:
    drugs_num = key.split(',')
    drugA_num, drugB_num = int(drugs_num[0]), int(drugs_num[1])
    for num in (drugA_num, drugB_num):
        if num not in drug_center:
            drug_center[num] = graph_center(drug[num])
    pair_graph.append(build_pair_graph(
        drug[drugA_num], drug[drugB_num],
        drug_center[drugA_num], drug_center[drugB_num],
    ))

In [6]:
# Replace each "rowID,colID" pair string with its integer index from Pair_map.
# Series.map does one hash lookup per row; DataFrame.replace with a dict of
# every unique pair is much slower on a frame of this size.
pair_index = data['drug_pair'].map(Pair_map)
# Pair_map was built from data['drug_pair'].unique(), so any miss means the
# column was already converted (e.g. this cell was re-run) or data changed.
assert pair_index.notna().all(), "drug_pair values missing from Pair_map"
data = data.assign(drug_pair=pair_index)
data.to_csv('./Preprocessed/reg/data_to_split.csv')

In [7]:
# Persist the preprocessing artifacts. These are Python objects (a dict and
# lists of graphs), which np.save pickles; reading them back therefore needs
# np.load(..., allow_pickle=True).
for out_path, artifact in (
    ('./Preprocessed/reg/Pair_map.npy', Pair_map),
    ('./Preprocessed/reg/Pair_graph.npy', pair_graph),
    ('./Preprocessed/reg/Drug_graph.npy', drug),
):
    np.save(out_path, artifact)