In [12]:
# !pip install torch_geometric

In [13]:
# from google.colab import drive
# drive.mount('/content/drive')

In [14]:
# !ls /content/drive/MyDrive/all_data

In [15]:
from sklearn.model_selection import StratifiedShuffleSplit
from rdkit import Chem
import pandas as pd
from collections import defaultdict
from tqdm import tqdm
import numpy as np
import os
import pickle
import torch
from torch_geometric.data import Data


In [16]:
# import pandas as pd
# df= pd.read_csv("/content/drive/My Drive/all_data/twosides.csv")

In [17]:
# df

In [18]:
# !pip install PyTDC


In [19]:
import pandas as pd

df_drugbank = pd.read_csv("./datasets/drugbank.csv")
df_drugbank  # Display first few rows


Unnamed: 0,Drug1_ID,Drug1,Drug2_ID,Drug2,Y
0,DB04571,CC1=CC2=CC3=C(OC(=O)C=C3C)C(C)=C2O1,DB00460,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...,1
1,DB00855,NCC(=O)CCC(O)=O,DB00460,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...,1
2,DB09536,O=[Ti]=O,DB00460,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...,1
3,DB01600,CC(C(O)=O)C1=CC=C(S1)C(=O)C1=CC=CC=C1,DB00460,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...,1
4,DB09000,CC(CN(C)C)CN1C2=CC=CC=C2SC2=C1C=C(C=C2)C#N,DB00460,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...,1
...,...,...,...,...,...
191803,DB00437,OC1=NC=NC2=C1C=NN2,DB00492,CCC(=O)O[C@@H](O[P@](=O)(CCCCC1=CC=CC=C1)CC(=O...,86
191804,DB00437,OC1=NC=NC2=C1C=NN2,DB09477,[H][C@@](C)(N[C@@]([H])(CCC1=CC=CC=C1)C(O)=O)C...,86
191805,DB00437,OC1=NC=NC2=C1C=NN2,DB00790,[H][C@]12C[C@H](N(C(=O)[C@H](C)N[C@@H](CCC)C(=...,86
191806,DB00415,[H][C@]12SC(C)(C)[C@@H](N1C(=O)[C@H]2NC(=O)[C@...,DB00437,OC1=NC=NC2=C1C=NN2,86


In [20]:
df_drugbank.isnull().sum()

Drug1_ID    0
Drug1       0
Drug2_ID    0
Drug2       0
Y           0
dtype: int64

In [21]:
from operator import index
import torch
from collections import defaultdict
from sklearn.model_selection import StratifiedShuffleSplit
from rdkit import Chem
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle
import os

In [22]:
def one_of_k_encoding(k, possible_values):
    if k not in possible_values:
        raise ValueError(f"{k} is not a valid value in {possible_values}")
    return [k == e for e in possible_values]


def one_of_k_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))

In [23]:
def save_data(data, filename, dirname="data/preprocessed", dataset="drugbank"):
    save_path = os.path.join(dirname, dataset)

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    filepath = os.path.join(save_path, filename)

    with open(filepath, 'wb') as f:
        pickle.dump(data, f)

    print(f'\nData saved as {filepath}!')


In [None]:
def atom_features(atom, atom_symbols, explicit_H=True, use_chirality=False):

    results = one_of_k_encoding_unk(atom.GetSymbol(), atom_symbols + ['Unknown']) + \
            one_of_k_encoding(atom.GetDegree(),[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + \
            one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) + \
                [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] + \
            one_of_k_encoding_unk(atom.GetHybridization(), [
                Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
                Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.
                                    SP3D, Chem.rdchem.HybridizationType.SP3D2
                ]) + [atom.GetIsAromatic()]
    # In case of explicit hydrogen(QM8, QM9), avoid calling `GetTotalNumHs`
    if explicit_H:
        results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(),
                                                [0, 1, 2, 3, 4])
    # if use_chirality:
    #     try:
    #         results = results + one_of_k_encoding_unk(
    #         atom.GetProp('_CIPCode'),
    #         ['R', 'S']) + [atom.HasProp('_ChiralityPossible')]
    #     except:
    #         results = results + [False, False
    #                         ] + [atom.HasProp('_ChiralityPossible')]

    results = np.array(results).astype(np.float32)

    return torch.from_numpy(results)


def edge_features(bond):
    bond_type = bond.GetBondType()
    return torch.tensor([
        bond_type == Chem.rdchem.BondType.SINGLE,
        bond_type == Chem.rdchem.BondType.DOUBLE,
        bond_type == Chem.rdchem.BondType.TRIPLE,
        bond_type == Chem.rdchem.BondType.AROMATIC,
        bond.GetIsConjugated(),
        bond.IsInRing()]).long()

In [25]:
def generate_drug_data(mol_graph, atom_symbols):

    edge_list = torch.LongTensor([(b.GetBeginAtomIdx(), b.GetEndAtomIdx(), *edge_features(b)) for b in mol_graph.GetBonds()])
    edge_list, edge_feats = (edge_list[:, :2], edge_list[:, 2:].float()) if len(edge_list) else (torch.LongTensor([]), torch.FloatTensor([]))
    edge_list = torch.cat([edge_list, edge_list[:, [1, 0]]], dim=0) if len(edge_list) else edge_list
    edge_feats = torch.cat([edge_feats]*2, dim=0) if len(edge_feats) else edge_feats

    features = [(atom.GetIdx(), atom_features(atom, atom_symbols)) for atom in mol_graph.GetAtoms()]
    features.sort()
    _, features = zip(*features)
    features = torch.stack(features)

    line_graph_edge_index = torch.LongTensor([])
    if edge_list.nelement() != 0:
        conn = (edge_list[:, 1].unsqueeze(1) == edge_list[:, 0].unsqueeze(0)) & (edge_list[:, 0].unsqueeze(1) != edge_list[:, 1].unsqueeze(0))
        line_graph_edge_index = conn.nonzero(as_tuple=False).T

    new_edge_index = edge_list.T

    return features, new_edge_index, edge_feats, line_graph_edge_index


def load_drug_mol_data(df_drugbank):

    data = df_drugbank
    drug_id_mol_tup = []
    symbols = list()
    drug_smile_dict = {}

    for id1, smiles1, id2, smiles2, relation in zip(data['Drug1_ID'], data['Drug1'], data['Drug2_ID'], data['Drug2'], data['Y']):
        drug_smile_dict[id1] = smiles1
        drug_smile_dict[id2] = smiles2

    for id, smiles in drug_smile_dict.items():
        mol =  Chem.MolFromSmiles(smiles.strip())
        if mol is not None:
            drug_id_mol_tup.append((id, mol))
            symbols.extend(atom.GetSymbol() for atom in mol.GetAtoms())

    symbols = list(set(symbols))
    drug_data = {id: generate_drug_data(mol, symbols) for id, mol in tqdm(drug_id_mol_tup, desc='Processing drugs')}
    save_data(drug_data, 'drug_data.pkl', dirname="data/preprocessed", dataset="drugbank")
    return drug_data


In [26]:
def generate_pair_triplets(df_drugbank, neg_ent =1, seed=42, dirname="data/preprocessed", dataset="drugbank"):
    pos_triplets = []
    drug_ids = []

    with open(f'{dirname}/{dataset.lower()}/drug_data.pkl', 'rb') as f:
        drug_ids = list(pickle.load(f).keys())
        data = df_drugbank
    for id1, id2, relation in zip(data['Drug1_ID'], data['Drug2_ID'],  data['Y']):
        if ((id1 not in drug_ids) or (id2 not in drug_ids)): continue
        # Drugbank dataset is 1-based index, need to substract by 1
        if dataset in ('drugbank', ):
            relation -= 1
        pos_triplets.append([id1, id2, relation])

    if len(pos_triplets) == 0:
        raise ValueError('All tuples are invalid.')

    pos_triplets = np.array(pos_triplets)
    data_statistics = load_data_statistics(pos_triplets)
    drug_ids = np.array(drug_ids)

    random_state = np.random.RandomState(seed)

    neg_samples = []
    for pos_item in tqdm(pos_triplets, desc='Generating Negative sample'):
        temp_neg = []
        h, t, r = pos_item[:3]

        if dataset == 'drugbank':
            neg_heads, neg_tails = _normal_batch(h, t, r, neg_ent, data_statistics, drug_ids, random_state)
            temp_neg = [str(neg_h) + '$h' for neg_h in neg_heads] + \
                        [str(neg_t) + '$t' for neg_t in neg_tails]
        else:
            existing_drug_ids = np.asarray(list(set(
                np.concatenate([data_statistics["ALL_TRUE_T_WITH_HR"][(h, r)], data_statistics["ALL_TRUE_H_WITH_TR"][(h, r)]], axis=0)
                )))
            temp_neg = _corrupt_ent(existing_drug_ids, neg_ent, drug_ids, random_state)

        neg_samples.append('_'.join(map(str, temp_neg[:neg_ent])))

    df = pd.DataFrame({'Drug1_ID': pos_triplets[:, 0],
                        'Drug2_ID': pos_triplets[:, 1],
                        'Y': pos_triplets[:, 2],
                        'Neg samples': neg_samples})
    filename = f'{dirname}/{dataset}/pair_pos_neg_triplets.csv'
    df.to_csv(filename, index=False)
    print(f'\nData saved as {filename}!')
    save_data(data_statistics, 'data_statistics.pkl', dirname="data/preprocessed", dataset="drugbank")


def load_data_statistics(all_tuples):

    print('Loading data statistics ...')
    statistics = dict()
    statistics["ALL_TRUE_H_WITH_TR"] = defaultdict(list)
    statistics["ALL_TRUE_T_WITH_HR"] = defaultdict(list)
    statistics["FREQ_REL"] = defaultdict(int)
    statistics["ALL_H_WITH_R"] = defaultdict(dict)
    statistics["ALL_T_WITH_R"] = defaultdict(dict)
    statistics["ALL_TAIL_PER_HEAD"] = {}
    statistics["ALL_HEAD_PER_TAIL"] = {}

    for h, t, r in tqdm(all_tuples, desc='Getting data statistics'):
        statistics["ALL_TRUE_H_WITH_TR"][(t, r)].append(h)
        statistics["ALL_TRUE_T_WITH_HR"][(h, r)].append(t)
        statistics["FREQ_REL"][r] += 1.0
        statistics["ALL_H_WITH_R"][r][h] = 1
        statistics["ALL_T_WITH_R"][r][t] = 1

    for t, r in statistics["ALL_TRUE_H_WITH_TR"]:
        statistics["ALL_TRUE_H_WITH_TR"][(t, r)] = np.array(list(set(statistics["ALL_TRUE_H_WITH_TR"][(t, r)])))
    for h, r in statistics["ALL_TRUE_T_WITH_HR"]:
        statistics["ALL_TRUE_T_WITH_HR"][(h, r)] = np.array(list(set(statistics["ALL_TRUE_T_WITH_HR"][(h, r)])))

    for r in statistics["FREQ_REL"]:
        statistics["ALL_H_WITH_R"][r] = np.array(list(statistics["ALL_H_WITH_R"][r].keys()))
        statistics["ALL_T_WITH_R"][r] = np.array(list(statistics["ALL_T_WITH_R"][r].keys()))
        statistics["ALL_HEAD_PER_TAIL"][r] = statistics["FREQ_REL"][r] / len(statistics["ALL_T_WITH_R"][r])
        statistics["ALL_TAIL_PER_HEAD"][r] = statistics["FREQ_REL"][r] / len(statistics["ALL_H_WITH_R"][r])

    print('getting data statistics done!')

    return statistics


def _corrupt_ent(positive_existing_ents, max_num, drug_ids, random_state):
    corrupted_ents = []
    while len(corrupted_ents) < max_num:
        candidates = random_state.choice(drug_ids, (max_num - len(corrupted_ents)) * 2, replace=False)
        invalid_drug_ids = np.concatenate([positive_existing_ents, corrupted_ents], axis=0)
        mask = np.isin(candidates, invalid_drug_ids, assume_unique=True, invert=True)
        corrupted_ents.extend(candidates[mask])

    corrupted_ents = np.array(corrupted_ents)[:max_num]
    return corrupted_ents


def _normal_batch( h, t, r, neg_size, data_statistics, drug_ids, random_state):
    neg_size_h = 0
    neg_size_t = 0
    prob = data_statistics["ALL_TAIL_PER_HEAD"][r] / (data_statistics["ALL_TAIL_PER_HEAD"][r] +
                                                            data_statistics["ALL_HEAD_PER_TAIL"][r])
    # prob = 2
    for i in range(neg_size):
        if random_state.random() < prob:
            neg_size_h += 1
        else:
            neg_size_t +=1

    return (_corrupt_ent(data_statistics["ALL_TRUE_H_WITH_TR"][t, r], neg_size_h, drug_ids, random_state),
            _corrupt_ent(data_statistics["ALL_TRUE_T_WITH_HR"][h, r], neg_size_t, drug_ids, random_state))


In [27]:
def split_data(class_name, seed, test_ratio, n_folds, dirname="data/preprocessed", dataset="drugbank"):
    filename = f'{dirname}/{dataset}/pair_pos_neg_triplets.csv'
    df = pd.read_csv(filename)
    seed = seed
    class_name = class_name
    test_size_ratio = test_ratio
    n_folds = n_folds
    save_to_filename = os.path.splitext(filename)[0]
    cv_split = StratifiedShuffleSplit(n_splits=n_folds, test_size=test_size_ratio, random_state=seed)
    for fold_i, (train_index, test_index) in enumerate(cv_split.split(X=df, y=df[class_name])):
        print(f'Fold {fold_i} generated!')
        train_df = df.iloc[train_index]
        test_df = df.iloc[test_index]
        train_df.to_csv(f'{save_to_filename}_train_fold{fold_i}.csv', index=False)
        print(f'{save_to_filename}_train_fold{fold_i}.csv', 'saved!')
        test_df.to_csv(f'{save_to_filename}_test_fold{fold_i}.csv', index=False)
        print(f'{save_to_filename}_test_fold{fold_i}.csv', 'saved!')

In [28]:
drug_data = load_drug_mol_data(df_drugbank)

[10:31:46] SMILES Parse Error: syntax error while parsing: OC1=CC=CC(=C1)C-1=C2\CCC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC(O)=CC=C1)C1=CC(O)=CC=C1)\C1=CC(O)=CC=C1
[10:31:46] SMILES Parse Error: check for mistakes around position 76:
[10:31:46] C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC(O)=CC=C1)C
[10:31:46] ~~~~~~~~~~~~~~~~~~~~^
[10:31:46] SMILES Parse Error: Failed parsing SMILES 'OC1=CC=CC(=C1)C-1=C2\CCC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC(O)=CC=C1)C1=CC(O)=CC=C1)\C1=CC(O)=CC=C1' for input: 'OC1=CC=CC(=C1)C-1=C2\CCC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC(O)=CC=C1)C1=CC(O)=CC=C1)\C1=CC(O)=CC=C1'
  new_edge_index = edge_list.T
Processing drugs: 100%|██████████| 1705/1705 [00:03<00:00, 567.56it/s]



Data saved as data/preprocessed\drugbank\drug_data.pkl!


In [29]:
generate_pair_triplets(df_drugbank, neg_ent=1, seed=42)

Loading data statistics ...


Getting data statistics: 100%|██████████| 191798/191798 [00:00<00:00, 439621.09it/s]


getting data statistics done!


Generating Negative sample: 100%|██████████| 191798/191798 [00:17<00:00, 11137.84it/s]



Data saved as data/preprocessed/drugbank/pair_pos_neg_triplets.csv!

Data saved as data/preprocessed\drugbank\data_statistics.pkl!


In [30]:
split_data('Y', seed=42, test_ratio=0.2, n_folds=3)

Fold 0 generated!
data/preprocessed/drugbank/pair_pos_neg_triplets_train_fold0.csv saved!
data/preprocessed/drugbank/pair_pos_neg_triplets_test_fold0.csv saved!
Fold 1 generated!
data/preprocessed/drugbank/pair_pos_neg_triplets_train_fold1.csv saved!
data/preprocessed/drugbank/pair_pos_neg_triplets_test_fold1.csv saved!
Fold 2 generated!
data/preprocessed/drugbank/pair_pos_neg_triplets_train_fold2.csv saved!
data/preprocessed/drugbank/pair_pos_neg_triplets_test_fold2.csv saved!


# getting datasets

In [31]:
import math
import torch
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Batch, Data
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
import pickle

In [32]:
NUM_FEATURES = None
NUM_EDGE_FEATURES = None
bipartite_edge_dict = dict()
drug_num_node_indices = dict()

In [33]:
import pickle
def total_num_rel():
    """Returns the total number of relations for DrugBank dataset."""
    return 86

def split_train_valid(data, fold, val_ratio=0.2):
    """Splits the dataset into training and validation sets."""
    cv_split = StratifiedShuffleSplit(n_splits=2, test_size=val_ratio, random_state=fold)
    pos_triplets, neg_samples = data
    train_index, val_index = next(iter(cv_split.split(X=pos_triplets, y=pos_triplets[:, 2])))

    train_tup = (pos_triplets[train_index], neg_samples[train_index])
    val_tup = (pos_triplets[val_index], neg_samples[val_index])

    return train_tup, val_tup

def load_ddi_data_fold(fold, batch_size=32, data_size_ratio=1.0, valid_ratio=0.2, dirname="data/preprocessed"):
    """Loads DrugBank dataset and prepares PyTorch DataLoaders."""
    global NUM_FEATURES, NUM_EDGE_FEATURES, drug_num_node_indices

    dataset_name = "drugbank"
    print(f'Loading {dataset_name}...')

    # Load processed drug data
    drug_data_file = f'{dirname}/{dataset_name}/drug_data.pkl'
    print('\nLoading processed drug data...')
    with open(drug_data_file, 'rb') as f:
        all_drug_data = pickle.load(f)

    # Extract feature dimensions
    NUM_FEATURES, _, NUM_EDGE_FEATURES = next(iter(all_drug_data.values()))[:3]
    NUM_FEATURES, NUM_EDGE_FEATURES = NUM_FEATURES.shape[1], NUM_EDGE_FEATURES.shape[1]

    # Convert data to CustomData format
    all_drug_data = {
        drug_id: CustomData(x=data[0], edge_index=data[1], edge_feats=data[2], line_graph_edge_index=data[3])
        for drug_id, data in all_drug_data.items()
    }

    # Speed up training with precomputed indices
    drug_num_node_indices = {
        drug_id: torch.zeros(data.x.size(0)).long() for drug_id, data in all_drug_data.items()
    }

    # Load train/validation/test splits
    train_tup = load_split(f'train_fold{fold}', dirname)
    train_tup, val_tup = split_train_valid(train_tup, fold)
    test_tup = load_split(f'test_fold{fold}', dirname)

    print(f'{train_tup[1].shape[1]} negative samples on fold {fold}')

    # Create dataset objects
    train_data = DrugDataset(train_tup, all_drug_data, seed=fold)
    val_data = DrugDataset(val_tup, all_drug_data, seed=fold)
    test_data = DrugDataset(test_tup, all_drug_data, seed=fold)

    print(f"\nTraining on {len(train_data)} samples, validating on {len(val_data)}, and testing on {len(test_data)} samples.")

    # Create DataLoaders
    train_loader = DrugDataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DrugDataLoader(val_data, batch_size=batch_size)
    test_loader = DrugDataLoader(test_data, batch_size=batch_size)

    return train_loader, val_loader, test_loader, NUM_FEATURES, NUM_EDGE_FEATURES

def load_split(split_name, dirname="data/preprocessed"):
    """Loads dataset splits for DrugBank."""
    filename = f'{dirname}/drugbank/pair_pos_neg_triplets_{split_name}.csv'
    print(f'\nLoading {filename}...')

    df = pd.read_csv(filename)
    pos_triplets = [(d1, d2, r) for d1, d2, r in zip(df['Drug1_ID'], df['Drug2_ID'], df['Y'])]
    neg_samples = [[str(e) for e in neg_s.split('_')] for neg_s in df['Neg samples']]

    return np.array(pos_triplets), np.array(neg_samples)

# ===========================
# Dataset Classes
# ===========================

class DrugDataset(Dataset):
    """Custom PyTorch Dataset for DrugBank."""

    def __init__(self, pos_neg_tuples, all_drug_data, ratio=1.0, seed=0):
        self.pair_triplets = []
        self.ratio = ratio
        self.drug_ids = list(all_drug_data.keys())
        self.all_drug_data = all_drug_data
        self.rng = np.random.RandomState(seed)

        for pos_item, neg_list in zip(*pos_neg_tuples):
            if (pos_item[0] in self.drug_ids) and (pos_item[1] in self.drug_ids):
                self.pair_triplets.append((pos_item, neg_list))

        if ratio != 1.0:
            self.rng.shuffle(self.pair_triplets)
            limit = math.ceil(len(self.pair_triplets) * ratio)
            self.pair_triplets = self.pair_triplets[:limit]

    def collate_fn(self, batch):
        old_id_to_new_batch_id = {}
        batch_drug_feats = []
        self.node_ind_seqs = []
        self.node_i_ind_seqs_for_pair = []
        self.node_j_ind_seqs_for_pair = []

        combo_indices_pos = []
        combo_indices_neg = []
        already_in_combo = {}
        rels = []
        batch_unique_pairs= []

        for ind, (pos_item, neg_list) in enumerate(batch):
            h, t, r = pos_item[:3]
            idx_h, h_num_nodes = self._get_new_batch_id_and_num_nodes(h, old_id_to_new_batch_id, batch_drug_feats)
            idx_t, t_num_nodes = self._get_new_batch_id_and_num_nodes(t, old_id_to_new_batch_id, batch_drug_feats)
            combo_idx = self._get_combo_index((idx_h, idx_t), (h, t), already_in_combo, batch_unique_pairs, (h_num_nodes, t_num_nodes))
            combo_indices_pos.append(combo_idx)

            rels.append(int(r))

            for neg_s in neg_list:
                s = neg_s.split('$')
                neg_idx, neg_num_nodes = self._get_new_batch_id_and_num_nodes(s[0], old_id_to_new_batch_id, batch_drug_feats)
                if ('h' == s[1].lower()):
                        combo_idx = self._get_combo_index((neg_idx, idx_t), (s[0], t), already_in_combo, batch_unique_pairs, (neg_num_nodes, t_num_nodes))
                else:
                    combo_idx = self._get_combo_index((idx_h, neg_idx), (h, s[0]), already_in_combo, batch_unique_pairs, (h_num_nodes, neg_num_nodes))

                combo_indices_neg.append(combo_idx)

        batch_drug_data = Batch.from_data_list(batch_drug_feats, follow_batch=['edge_index'])
        batch_drug_pair_indices = torch.LongTensor(combo_indices_pos + combo_indices_neg)
        batch_unique_drug_pair = Batch.from_data_list(batch_unique_pairs, follow_batch=['edge_index'])
        node_j_for_pairs = torch.cat(self.node_j_ind_seqs_for_pair)
        node_i_for_pairs = torch.cat(self.node_i_ind_seqs_for_pair)
        rels = torch.LongTensor(rels)

        return batch_drug_data, batch_unique_drug_pair, rels, batch_drug_pair_indices, node_j_for_pairs, node_i_for_pairs

    def _get_new_batch_id_and_num_nodes(self, old_id, old_id_to_new_batch_id, batch_drug_feats):
        new_id = old_id_to_new_batch_id.get(old_id, -1)
        num_nodes = self.all_drug_data[old_id].x.size(0)
        if new_id == - 1:
            new_id = len(old_id_to_new_batch_id)
            old_id_to_new_batch_id[old_id] = new_id
            batch_drug_feats.append(self.all_drug_data[old_id])
            start = (self.node_ind_seqs[-1][-1] + 1) if len(self.node_ind_seqs) else 0
            self.node_ind_seqs.append(torch.arange(num_nodes) + start)

        return new_id, num_nodes

    def _get_combo_index(self, combo, old_combo, already_in_combo, unique_pairs, num_nodes):
        idx = already_in_combo.get(combo, -1)
        if idx == -1:
            idx = len(already_in_combo)
            already_in_combo[combo] = idx
            pair_edge_index = bipartite_edge_dict.get(old_combo)
            if pair_edge_index is None:
                index_j = torch.arange(num_nodes[0]).repeat_interleave(num_nodes[1])
                index_i = torch.arange(num_nodes[1]).repeat(num_nodes[0])
                pair_edge_index = torch.stack([index_j, index_i])
                bipartite_edge_dict[old_combo] = pair_edge_index

            j_num_indices, i_num_indices = drug_num_node_indices[old_combo[0]], drug_num_node_indices[old_combo[1]]
            unique_pairs.append(PairData(j_num_indices, i_num_indices, pair_edge_index))
            self.node_j_ind_seqs_for_pair.append(self.node_ind_seqs[combo[0]])
            self.node_i_ind_seqs_for_pair.append(self.node_ind_seqs[combo[1]])

        return idx

    def __len__(self):
        return len(self.pair_triplets)

    def __getitem__(self, index):
        return self.pair_triplets[index]

class DrugDataLoader(DataLoader):
    """Custom DataLoader for DrugBank."""
    def __init__(self, data, **kwargs):
        super().__init__(data, collate_fn=data.collate_fn, **kwargs)

class PairData(Data):

    def __init__(self, j_indices, i_indices, pair_edge_index):
        super().__init__()
        self.i_indices = i_indices
        self.j_indices = j_indices
        self.edge_index = pair_edge_index
        self.num_nodes = None

    def __inc__(self, key, value, *args, **kwargs):
    # In case of "TypeError: __inc__() takes 3 positional arguments but 4 were given"
    # Replace with "def __inc__(self, key, value, *args, **kwargs)"
        if key == 'edge_index':
            return torch.tensor([[self.j_indices.shape[0]], [self.i_indices.shape[0]]])
        if key in ('i_indices', 'j_indices'):
            return 1
        return super().__inc__(self, key, value, args, kwargs)
            # In case of "TypeError: __inc__() takes 3 positional arguments but 4 were given"
            # Replace with "return super().__inc__(self, key, value, args, kwargs)"



class CustomData(Data):
    def __inc__(self, key, value, *args, **kwargs):  # Accepts additional arguments
        if key == 'line_graph_edge_index':
            return self.edge_index.size(1) if self.edge_index.nelement() != 0 else 0
        return super().__inc__(key, value, *args, **kwargs)  # Pass extra args



In [34]:
train_loader, val_loader, test_loader, num_features, num_edge_features = load_ddi_data_fold(
    fold=0, batch_size=32, data_size_ratio=1.0
)


Loading drugbank...

Loading processed drug data...

Loading data/preprocessed/drugbank/pair_pos_neg_triplets_train_fold0.csv...

Loading data/preprocessed/drugbank/pair_pos_neg_triplets_test_fold0.csv...
1 negative samples on fold 0

Training on 122750 samples, validating on 30688, and testing on 38360 samples.


# Model

In [35]:
# import torch
# !pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
# !pip install git+https://github.com/pyg-team/pytorch_geometric.git

In [36]:
# Import necessary libraries
import torch
from torch import nn
from torch_geometric.nn.inits import glorot
from torch_geometric.utils import degree
from torch_scatter import scatter

# Define CustomDropout
class CustomDropout(nn.Module):
    def __init__(self, p):
        super().__init__()
        self.dropout = (lambda x: x) if p == 0 else nn.Dropout(p)

    def forward(self, input):
        return self.dropout(input)

# Define GmpnnBlock
class GmpnnBlock(nn.Module):
    def __init__(self, edge_feats, n_feats, n_iter, dropout):
        super().__init__()
        self.n_feats = n_feats
        self.n_iter = n_iter
        self.dropout = dropout
        self.snd_n_feats = n_feats * 2

        self.w_i = nn.Parameter(torch.Tensor(self.n_feats, self.n_feats))
        self.w_j = nn.Parameter(torch.Tensor(self.n_feats, self.n_feats))
        self.a = nn.Parameter(torch.Tensor(1, self.n_feats))
        self.bias = nn.Parameter(torch.zeros(self.n_feats))

        self.edge_emb = nn.Sequential(nn.Linear(edge_feats, self.n_feats))

        self.lin1 = nn.Sequential(nn.BatchNorm1d(n_feats), nn.Linear(n_feats, self.snd_n_feats))
        self.lin2 = nn.Sequential(nn.BatchNorm1d(self.snd_n_feats), CustomDropout(self.dropout), nn.PReLU(), nn.Linear(self.snd_n_feats, self.snd_n_feats))
        self.lin3 = nn.Sequential(nn.BatchNorm1d(self.snd_n_feats), CustomDropout(self.dropout), nn.PReLU(), nn.Linear(self.snd_n_feats, self.snd_n_feats))
        self.lin4 = nn.Sequential(nn.BatchNorm1d(self.snd_n_feats), CustomDropout(self.dropout), nn.PReLU(), nn.Linear(self.snd_n_feats, self.snd_n_feats))

        glorot(self.w_i)
        glorot(self.w_j)
        glorot(self.a)

        self.sml_mlp = nn.Sequential(nn.PReLU(), nn.Linear(self.n_feats, self.n_feats))

    def forward(self, data):
        edge_index = data.edge_index
        edge_feats = data.edge_feats
        edge_feats = self.edge_emb(edge_feats)
        deg = degree(edge_index[1], data.x.size(0), dtype=data.x.dtype)

        alpha_i = (data.x @ self.w_i)
        alpha_j = (data.x @ self.w_j)
        alpha = alpha_i[edge_index[1]] + alpha_j[edge_index[0]] + self.bias
        alpha = self.sml_mlp(alpha)

        alpha = (alpha * edge_feats).sum(-1)
        alpha = alpha / (deg[edge_index[0]])
        edge_weights = torch.sigmoid(alpha)

        edge_attr = data.x[edge_index[0]] * edge_weights.unsqueeze(-1)

        out = edge_attr
        for _ in range(self.n_iter):
            out = scatter(out[data.line_graph_edge_index[0]], data.line_graph_edge_index[1], dim_size=edge_attr.size(0), dim=0, reduce='add')
            out = edge_attr + (out * edge_weights.unsqueeze(-1))

        x = data.x + scatter(out, edge_index[1], dim_size=data.x.size(0), dim=0, reduce='add')
        x = self.mlp(x)

        return x

    def mlp(self, x):
        x = self.lin1(x)
        x = (self.lin3(self.lin2(x)) + x) / 2
        x = (self.lin4(x) + x) / 2
        return x

# Define GmpnnCSNetDrugBank
class GmpnnCSNetDrugBank(nn.Module):
    def __init__(self, in_feats, edge_feats, hid_feats, rel_total, n_iter, dropout=0):
        super().__init__()
        self.in_feats = in_feats
        self.hid_feats = hid_feats
        self.rel_total = rel_total
        self.n_iter = n_iter
        self.dropout = dropout
        self.snd_hid_feats = hid_feats * 2

        self.mlp = nn.Sequential(
            nn.Linear(in_feats, hid_feats),
            CustomDropout(self.dropout),
            nn.PReLU(),
            nn.Linear(hid_feats, hid_feats),
            nn.BatchNorm1d(hid_feats),
            CustomDropout(self.dropout),
            nn.PReLU(),
            nn.Linear(hid_feats, hid_feats),
            nn.BatchNorm1d(hid_feats),
            CustomDropout(self.dropout),
        )

        self.propagation_layer = GmpnnBlock(edge_feats, self.hid_feats, self.n_iter, dropout)

        self.i_pro = nn.Parameter(torch.zeros(self.snd_hid_feats, self.hid_feats))
        self.j_pro = nn.Parameter(torch.zeros(self.snd_hid_feats, self.hid_feats))
        self.bias = nn.Parameter(torch.zeros(self.hid_feats))

        self.rel_embs = nn.Embedding(self.rel_total, self.hid_feats)

        glorot(self.i_pro)
        glorot(self.j_pro)

    def forward(self, batch):
        drug_data, unique_drug_pair, rels, drug_pair_indices, node_j_for_pairs, node_i_for_pairs = batch
        drug_data.x = self.mlp(drug_data.x)

        new_feats = self.propagation_layer(drug_data)
        drug_data.x = new_feats
        x_j = drug_data.x[node_j_for_pairs]
        x_i = drug_data.x[node_i_for_pairs]

        pair_repr = ((x_i[unique_drug_pair.edge_index[1]] @ self.i_pro) * (x_j[unique_drug_pair.edge_index[0]] @ self.j_pro))
        pair_repr = scatter(pair_repr, unique_drug_pair.edge_index_batch, reduce='add', dim=0)[drug_pair_indices]

        p_scores, n_scores = self.compute_score(pair_repr, rels)
        return p_scores, n_scores

    def compute_score(self, pair_repr, rels):
        batch_size = len(rels)
        neg_n = (len(pair_repr) - batch_size) // batch_size
        rels = torch.cat([rels, torch.repeat_interleave(rels, neg_n, dim=0)], dim=0)
        rels = self.rel_embs(rels)
        scores = (pair_repr * rels).sum(-1)
        p_scores, n_scores = scores[:batch_size].unsqueeze(-1), scores[batch_size:].view(batch_size, -1, 1)
        return p_scores, n_scores


In [37]:
model = GmpnnCSNetDrugBank(in_feats=128, edge_feats=32, hid_feats=64, rel_total=86, n_iter=2)
print(model)


GmpnnCSNetDrugBank(
  (mlp): Sequential(
    (0): Linear(in_features=128, out_features=64, bias=True)
    (1): CustomDropout()
    (2): PReLU(num_parameters=1)
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): CustomDropout()
    (6): PReLU(num_parameters=1)
    (7): Linear(in_features=64, out_features=64, bias=True)
    (8): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): CustomDropout()
  )
  (propagation_layer): GmpnnBlock(
    (edge_emb): Sequential(
      (0): Linear(in_features=32, out_features=64, bias=True)
    )
    (lin1): Sequential(
      (0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Linear(in_features=64, out_features=128, bias=True)
    )
    (lin2): Sequential(
      (0): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): CustomDropout

# Train on Fold

In [38]:
import torch
from torch import nn
import torch.nn.functional as F



class SigmoidLoss(nn.Module):

    def forward(self, p_scores, n_scores):
        p_loss = - F.logsigmoid(p_scores).mean()
        n_loss = - F.logsigmoid(-n_scores).mean()

        return (p_loss + n_loss) / 2, p_loss, n_loss

In [39]:
from operator import le
from sklearn import metrics
from collections import defaultdict
import json
import numpy as np


def do_compute_metrics(probas_pred, target):
    pred = (probas_pred >= 0.5).astype(int)
    acc = metrics.accuracy_score(target, pred)
    auroc = metrics.roc_auc_score(target, probas_pred)
    f1_score = metrics.f1_score(target, pred)
    precision = metrics.precision_score(target, pred)
    recall = metrics.recall_score(target, pred)
    p, r, t = metrics.precision_recall_curve(target, probas_pred)
    int_ap = metrics.auc(r, p)
    ap= metrics.average_precision_score(target, probas_pred)

    return acc, auroc, f1_score, precision, recall, int_ap, ap

In [None]:
from datetime import datetime
import numpy as np
import torch
from torch import optim
import time
from tqdm import tqdm



dataset_name = 'drugbank'
fold_i = 0
dropout = 0.2
n_iter = 3
TOTAL_NUM_RELS = total_num_rel()
batch_size = 512
data_size_ratio = 1
device = 'cuda' if torch.cuda.is_available() else 'cpu'
hid_feats = 64
rel_total = TOTAL_NUM_RELS
lr = 1e-3
weight_decay = 5e-4
n_epochs = 20
kge_feats = 64

def do_compute(model, batch, device):

        batch = [t.to(device) for t in batch]
        p_score, n_score = model(batch)
        assert p_score.ndim == 2
        assert n_score.ndim == 3
        probas_pred = np.concatenate([torch.sigmoid(p_score.detach()).cpu().mean(dim=-1), torch.sigmoid(n_score.detach()).mean(dim=-1).view(-1).cpu()])
        ground_truth = np.concatenate([np.ones(p_score.shape[0]), np.zeros(n_score.shape[:2]).reshape(-1)])

        return p_score, n_score, probas_pred, ground_truth


def run_batch(model, optimizer, data_loader, epoch_i, desc, loss_fn, device):
        total_loss = 0
        loss_pos = 0
        loss_neg = 0
        probas_pred = []
        ground_truth = []

        for batch in tqdm(data_loader, desc= f'{desc} Epoch {epoch_i}'):
            p_score, n_score, batch_probas_pred, batch_ground_truth = do_compute(model, batch, device)

            probas_pred.append(batch_probas_pred)
            ground_truth.append(batch_ground_truth)

            loss, loss_p, loss_n = loss_fn(p_score, n_score)
            if model.training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            loss_pos += loss_p.item()
            loss_neg += loss_n.item()
        total_loss /= len(data_loader)
        loss_pos /= len(data_loader)
        loss_neg /= len(data_loader)

        probas_pred = np.concatenate(probas_pred)
        ground_truth = np.concatenate(ground_truth)

        return total_loss, do_compute_metrics(probas_pred, ground_truth)


def print_metrics(loss, acc, auroc, f1_score, precision, recall, int_ap, ap):
    print(f'loss: {loss:.4f}, acc: {acc:.4f}, roc: {auroc:.4f}, f1: {f1_score:.4f}, ', end='')
    print(f'p: {precision:.4f}, r: {recall:.4f}, int-ap: {int_ap:.4f}, ap: {ap:.4f}')

    return f1_score


def train(model, train_data_loader, val_data_loader, test_data_loader, loss_fn, optimizer, n_epochs, device, scheduler):
    for epoch_i in range(1, n_epochs+1):
        start = time.time()
        model.train()
        
        ## Training
        train_loss, train_metrics = run_batch(model, optimizer, train_data_loader, epoch_i, 'train', loss_fn, device)
        if scheduler:
            scheduler.step()

        model.eval()
        with torch.no_grad():
            ## Validation
            if val_data_loader:
                val_loss, val_metrics = run_batch(model, optimizer, val_data_loader, epoch_i, 'val', loss_fn, device)
            
            ## Test Set Evaluation
            if test_data_loader:
                test_loss, test_metrics = run_batch(model, optimizer, test_data_loader, epoch_i, 'test', loss_fn, device)

        print(f'\n#### Epoch time {time.time() - start:.4f}s')
        print_metrics(train_loss, *train_metrics)

        if val_data_loader:
            print('#### Validation')
            print_metrics(val_loss, *val_metrics)

        if test_data_loader:
            print('#### Test')
            print_metrics(test_loss, *test_metrics)



train_data_loader, val_data_loader, test_data_loader, NUM_FEATURES, NUM_EDGE_FEATURES = \
    load_ddi_data_fold(
    fold=0, batch_size=32, data_size_ratio=1.0)

GmpnnNet = GmpnnCSNetDrugBank if dataset_name == 'drugbank' else GmpnnCSNetDrugBank

model = GmpnnNet(NUM_FEATURES, NUM_EDGE_FEATURES, hid_feats, rel_total, n_iter, dropout)
loss_fn = SigmoidLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.96 ** (epoch))

time_stamp = f'{datetime.now()}'.replace(':', '_')


model.to(device=device)
print(f'Training on {device}.')
print(f'Starting fold_{fold_i} at', datetime.now())
train(model, train_data_loader, val_data_loader, test_data_loader, loss_fn, optimizer, n_epochs, device, scheduler)

Loading drugbank...

Loading processed drug data...

Loading data/preprocessed/drugbank/pair_pos_neg_triplets_train_fold0.csv...

Loading data/preprocessed/drugbank/pair_pos_neg_triplets_test_fold0.csv...
1 negative samples on fold 0

Training on 122750 samples, validating on 30688, and testing on 38360 samples.
Training on cuda.
Starting fold_0 at 2025-02-28 16:11:09.772483


train Epoch 1: 100%|██████████| 3836/3836 [04:31<00:00, 14.10it/s]  
val Epoch 1: 100%|██████████| 959/959 [00:58<00:00, 16.44it/s]
test Epoch 1: 100%|██████████| 1199/1199 [01:16<00:00, 15.61it/s]



#### Epoch time 407.5823s
loss: 2.1496, acc: 0.5063, roc: 0.5074, f1: 0.5079, p: 0.5062, r: 0.5095, int-ap: 0.5113, ap: 0.5061
#### Validation
loss: 0.7009, acc: 0.5347, roc: 0.5523, f1: 0.5294, p: 0.5354, r: 0.5236, int-ap: 0.5401, ap: 0.5402
#### Test
loss: 0.6994, acc: 0.5306, roc: 0.5498, f1: 0.5242, p: 0.5314, r: 0.5172, int-ap: 0.5379, ap: 0.5379


train Epoch 2: 100%|██████████| 3836/3836 [03:56<00:00, 16.22it/s]
val Epoch 2: 100%|██████████| 959/959 [00:58<00:00, 16.38it/s]
test Epoch 2: 100%|██████████| 1199/1199 [01:11<00:00, 16.71it/s]



#### Epoch time 367.2161s
loss: 0.7241, acc: 0.5359, roc: 0.5499, f1: 0.5490, p: 0.5339, r: 0.5649, int-ap: 0.5404, ap: 0.5404
#### Validation
loss: 0.6769, acc: 0.5928, roc: 0.6327, f1: 0.6551, p: 0.5682, r: 0.7735, int-ap: 0.6053, ap: 0.6054
#### Test
loss: 0.6748, acc: 0.5927, roc: 0.6334, f1: 0.6550, p: 0.5681, r: 0.7734, int-ap: 0.6079, ap: 0.6079


train Epoch 3: 100%|██████████| 3836/3836 [04:04<00:00, 15.69it/s] 
val Epoch 3: 100%|██████████| 959/959 [00:58<00:00, 16.49it/s]
test Epoch 3: 100%|██████████| 1199/1199 [01:13<00:00, 16.28it/s]



#### Epoch time 376.7321s
loss: 0.6628, acc: 0.6035, roc: 0.6437, f1: 0.6172, p: 0.5967, r: 0.6391, int-ap: 0.6247, ap: 0.6247
#### Validation
loss: 0.6530, acc: 0.6247, roc: 0.6749, f1: 0.6792, p: 0.5931, r: 0.7947, int-ap: 0.6416, ap: 0.6416
#### Test
loss: 0.6523, acc: 0.6244, roc: 0.6743, f1: 0.6786, p: 0.5930, r: 0.7932, int-ap: 0.6425, ap: 0.6425


train Epoch 4: 100%|██████████| 3836/3836 [03:56<00:00, 16.25it/s]
val Epoch 4: 100%|██████████| 959/959 [01:08<00:00, 13.91it/s]
test Epoch 4: 100%|██████████| 1199/1199 [00:49<00:00, 24.04it/s]



#### Epoch time 355.3553s
loss: 0.6390, acc: 0.6340, roc: 0.6871, f1: 0.6434, p: 0.6274, r: 0.6602, int-ap: 0.6663, ap: 0.6663
#### Validation
loss: 0.6262, acc: 0.6519, roc: 0.7096, f1: 0.6768, p: 0.6316, r: 0.7289, int-ap: 0.6870, ap: 0.6870
#### Test
loss: 0.6254, acc: 0.6512, roc: 0.7092, f1: 0.6745, p: 0.6322, r: 0.7229, int-ap: 0.6892, ap: 0.6892


train Epoch 5: 100%|██████████| 3836/3836 [04:53<00:00, 13.09it/s]
val Epoch 5: 100%|██████████| 959/959 [01:03<00:00, 15.07it/s]
test Epoch 5: 100%|██████████| 1199/1199 [02:15<00:00,  8.85it/s]



#### Epoch time 492.7232s
loss: 0.6182, acc: 0.6561, roc: 0.7168, f1: 0.6656, p: 0.6477, r: 0.6845, int-ap: 0.6974, ap: 0.6974
#### Validation
loss: 0.6146, acc: 0.6621, roc: 0.7227, f1: 0.6980, p: 0.6310, r: 0.7810, int-ap: 0.7025, ap: 0.7025
#### Test
loss: 0.6149, acc: 0.6621, roc: 0.7232, f1: 0.6969, p: 0.6319, r: 0.7768, int-ap: 0.7032, ap: 0.7032


train Epoch 6: 100%|██████████| 3836/3836 [16:16<00:00,  3.93it/s]   
val Epoch 6: 100%|██████████| 959/959 [00:47<00:00, 20.40it/s]
test Epoch 6:  79%|███████▉  | 953/1199 [05:04<00:42,  5.81it/s]