In [1]:
!pip install rdkit
!pip install torch_geometric



In [2]:
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

2.6.0+cu124
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [3]:
import torch
print("CUDA Available:", torch.cuda.is_available())

if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("CUDA Version:", torch.version.cuda)
    print("Torch Version:", torch.__version__)


CUDA Available: True
GPU: Tesla T4
CUDA Version: 12.4
Torch Version: 2.6.0+cu124


In [4]:
from operator import index
import torch
from collections import defaultdict
from sklearn.model_selection import StratifiedShuffleSplit
from rdkit import Chem
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle
import os

In [5]:
df_drugbank = pd.read_csv("drugbank.csv")
df_drugbank

Unnamed: 0,Drug1_ID,Drug1,Drug2_ID,Drug2,Y
0,DB04571,CC1=CC2=CC3=C(OC(=O)C=C3C)C(C)=C2O1,DB00460,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...,1
1,DB00855,NCC(=O)CCC(O)=O,DB00460,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...,1
2,DB09536,O=[Ti]=O,DB00460,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...,1
3,DB01600,CC(C(O)=O)C1=CC=C(S1)C(=O)C1=CC=CC=C1,DB00460,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...,1
4,DB09000,CC(CN(C)C)CN1C2=CC=CC=C2SC2=C1C=C(C=C2)C#N,DB00460,COC(=O)CCC1=C2NC(\C=C3/N=C(/C=C4\N\C(=C/C5=N/C...,1
...,...,...,...,...,...
191803,DB00437,OC1=NC=NC2=C1C=NN2,DB00492,CCC(=O)O[C@@H](O[P@](=O)(CCCCC1=CC=CC=C1)CC(=O...,86
191804,DB00437,OC1=NC=NC2=C1C=NN2,DB09477,[H][C@@](C)(N[C@@]([H])(CCC1=CC=CC=C1)C(O)=O)C...,86
191805,DB00437,OC1=NC=NC2=C1C=NN2,DB00790,[H][C@]12C[C@H](N(C(=O)[C@H](C)N[C@@H](CCC)C(=...,86
191806,DB00415,[H][C@]12SC(C)(C)[C@@H](N1C(=O)[C@H]2NC(=O)[C@...,DB00437,OC1=NC=NC2=C1C=NN2,86


## Helper Functions

In [6]:
def one_of_k_encoding(k, all_values):
  if k not in all_values:
    raise ValueError(f"{k} is not a valid value in all values: {all_values}")
  return [k==e for e in all_values]

def one_of_k_encoding_unk(k, all_values):
  if k not in all_values:
    k = all_values[-1]
  return list(map(lambda s:k == s, all_values))

In [7]:
def save_data(data, filename, dirname="data/preprocessed", dataset="drugbank"):
  save_path = os.path.join(dirname, dataset)
  if not os.path.exists(save_path):
    os.makedirs(save_path)
  file_path = os.path.join(save_path, filename)

  with open(file_path, 'wb') as f:
    pickle.dump(data, f)

  print(f"Data saved in {file_path}")

## Encode Chemical Properties

In [8]:
def atom_features(atom, atom_symbols, explicit_H=True, use_chirality=False):
  features = []
  features += one_of_k_encoding_unk(atom.GetSymbol(), atom_symbols + ['Unknown'])
  features += one_of_k_encoding(atom.GetDegree(), list(range(11)))
  features += one_of_k_encoding_unk(atom.GetImplicitValence(), list(range(7)))
  features += [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()]

  features += one_of_k_encoding_unk(atom.GetHybridization(), [
      Chem.rdchem.HybridizationType.SP,
      Chem.rdchem.HybridizationType.SP2,
      Chem.rdchem.HybridizationType.SP3,
      Chem.rdchem.HybridizationType.SP3D,
      Chem.rdchem.HybridizationType.SP3D2
  ])

  features += [atom.GetIsAromatic()]

  if explicit_H:
    features += one_of_k_encoding_unk(atom.GetTotalNumHs(), list(range(5)))

  features = np.array(features, dtype=np.float32)
  # print(f"Atom features: {features}")
  return torch.from_numpy(features)


def edge_features(bond):
  bond_type = bond.GetBondType()
  return torch.tensor([
      bond_type == Chem.rdchem.BondType.SINGLE,
      bond_type == Chem.rdchem.BondType.DOUBLE,
      bond_type == Chem.rdchem.BondType.TRIPLE,
      bond_type == Chem.rdchem.BondType.AROMATIC,
      bond.GetIsConjugated(),
      bond.IsInRing()
  ]).long()

In [9]:
def generate_drug_data(molecule, atom_symbols):
    edge_tuples = [
        (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(), *edge_features(bond))
        for bond in molecule.GetBonds()
    ]
    edge_tensor = torch.LongTensor(edge_tuples)

    if edge_tensor.numel() > 0:
        edge_index = edge_tensor[:, :2]
        edge_features_tensor = edge_tensor[:, 2:].float()
        edge_index = torch.cat([edge_index, edge_index[:, [1, 0]]], dim=0)
        edge_features_tensor = torch.cat([edge_features_tensor] * 2, dim=0)
    else:
        edge_index = torch.LongTensor([])
        edge_features_tensor = torch.FloatTensor([])

    atom_features_list = [
        (atom.GetIdx(), atom_features(atom, atom_symbols))
        for atom in molecule.GetAtoms()
    ]
    atom_features_list.sort()
    _, atom_features_sorted = zip(*atom_features_list)
    atom_features_tensor = torch.stack(atom_features_sorted)

    line_graph_edge_index = torch.LongTensor([])
    if edge_index.numel() != 0:
        src, dst = edge_index[:, 0], edge_index[:, 1]
        conn_matrix = (dst.unsqueeze(1) == src.unsqueeze(0)) & (src.unsqueeze(1) != dst.unsqueeze(0))
        line_graph_edge_index = conn_matrix.nonzero(as_tuple=False).T

    return atom_features_tensor, edge_index.T, edge_features_tensor, line_graph_edge_index


def load_drug_mol_data(df_drugbank):
    drug_smiles_dict = {}

    for id1, smiles1, id2, smiles2, _ in zip(
        df_drugbank['Drug1_ID'], df_drugbank['Drug1'],
        df_drugbank['Drug2_ID'], df_drugbank['Drug2'],
        df_drugbank['Y']
    ):
        drug_smiles_dict[id1] = smiles1
        drug_smiles_dict[id2] = smiles2

    drug_mol_pairs = []
    atom_symbols = []
    for drug_id, smiles in drug_smiles_dict.items():
        molecule = Chem.MolFromSmiles(smiles.strip())
        if molecule:
            drug_mol_pairs.append((drug_id, molecule))
            atom_symbols.extend(atom.GetSymbol() for atom in molecule.GetAtoms())

    atom_symbols = list(set(atom_symbols))

    drug_data = {
        id: generate_drug_data(mol, atom_symbols)
        for id, mol in tqdm(drug_mol_pairs, desc='Processing drugs')
    }
    # print(drug_data)
    save_data(drug_data, 'drug_data.pkl', dirname="data/preprocessed", dataset="drugbank")
    return drug_data


In [18]:
mol = Chem.MolFromSmiles('CCO')
atom_symbols = ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I']
atom = mol.GetAtomWithIdx(0)
print(atom)
atom_feat = atom_features(atom, atom_symbols)
print("Atom Feature Vector Shape:", atom_feat.shape)

bond = mol.GetBondWithIdx(0)
bond_feat = edge_features(bond)
print("Bond Feature Vector:", bond_feat)

<rdkit.Chem.rdchem.Atom object at 0x7e2dc4df3f40>
Atom Feature Vector Shape: torch.Size([41])
Bond Feature Vector: tensor([1, 0, 0, 0, 0, 0])


In [10]:
def load_data_statistics(all_tuples):
  print("Loading Data Statics")
  stats = dict()
  stats["ALL_TRUE_H_WITH_TR"] = defaultdict(list)
  stats["ALL_TRUE_T_WITH_HR"] = defaultdict(list)
  stats["FREQ_REL"] = defaultdict(int)
  stats["ALL_H_WITH_R"] = defaultdict(dict)
  stats["ALL_T_WITH_R"] = defaultdict(dict)
  stats["ALL_TAIL_PER_HEAD"] = {}
  stats["ALL_HEAD_PER_TAIL"] = {}

  for head, tail, rel in tqdm(all_tuples, desc="Processing Data Stats"):
    stats["ALL_TRUE_H_WITH_TR"][(tail, rel)].append(head)
    stats["ALL_TRUE_T_WITH_HR"][(head, rel)].append(tail)
    stats["FREQ_REL"][rel] += 1
    stats["ALL_H_WITH_R"][rel][head] = 1
    stats["ALL_T_WITH_R"][rel][tail] = 1

  for key in stats['ALL_TRUE_H_WITH_TR']:
    stats["ALL_TRUE_H_WITH_TR"][key] = np.unique(stats["ALL_TRUE_H_WITH_TR"][key])
  for key in stats['ALL_TRUE_T_WITH_HR']:
    stats["ALL_TRUE_T_WITH_HR"][key] = np.unique(stats["ALL_TRUE_T_WITH_HR"][key])

  for rel in stats["FREQ_REL"]:
    head_set = stats["ALL_H_WITH_R"][rel]
    tail_set = stats['ALL_T_WITH_R'][rel]
    stats["ALL_H_WITH_R"][rel] = np.array(list(head_set.keys()))
    stats["ALL_T_WITH_R"][rel] = np.array(list(tail_set.keys()))
    stats["ALL_HEAD_PER_TAIL"][rel] = stats["FREQ_REL"][rel]/ len(stats["ALL_T_WITH_R"][rel])
    stats["ALL_TAIL_PER_HEAD"][rel] = stats["FREQ_REL"][rel] / len(stats["ALL_H_WITH_R"][rel])
  print("Collected data stats")
  print(stats)
  return stats


In [11]:
def _corrupt_ent(existing_positives, max_num, drug_ids, random_state):
  corrupted = []
  while len(corrupted) < max_num:
    candidates = random_state.choice(drug_ids, (max_num - len(corrupted))*2, replace=False)
    invalid = np.concatenate([existing_positives, corrupted], axis=0)
    mask = np.isin(candidates, invalid, assume_unique=True, invert=True)
    corrupted.extend(candidates[mask])

  corrupted = np.array(corrupted)[:max_num]
  return corrupted

In [12]:
def _normal_batch(head, tail, rel, neg_size, stats, drug_ids, random_state):
  prob = stats["ALL_TAIL_PER_HEAD"][rel] / (
      stats["ALL_TAIL_PER_HEAD"][rel] + stats["ALL_HEAD_PER_TAIL"][rel]
  )
  num_head_neg, num_tail_neg = 0,0

  for _ in range(neg_size):
    if random_state.random() < prob:
      num_head_neg += 1
    else:
      num_tail_neg +=1

  return (_corrupt_ent(stats["ALL_TRUE_H_WITH_TR"][tail, rel], num_head_neg, drug_ids, random_state),
          _corrupt_ent(stats["ALL_TRUE_T_WITH_HR"][head, rel], num_tail_neg, drug_ids, random_state))

In [13]:
def generate_pair_triplets(df_drugbank, neg_ent =1, seed=42, dirname="data/preprocessed", dataset="drugbank"):
  with open(f'{dirname}/{dataset}/drug_data.pkl', 'rb') as f:
    drug_data = pickle.load(f)
    drug_ids = list(drug_data.keys())

  pos_triplets = []

  for id1, id2, relation in zip(df_drugbank['Drug1_ID'], df_drugbank['Drug2_ID'], df_drugbank['Y']):
    if id1 not in drug_ids or id2 not in drug_ids:
      continue
    if dataset=='drugbank':
      relation -=1
    pos_triplets.append([id1, id2, relation])

  if not pos_triplets:
    raise ValueError("NO TRIPLETS FOUND; VALUES ARE WRONG")

  pos_triplets = np.array(pos_triplets)


  data_stats = load_data_statistics(pos_triplets)
  drug_ids = np.array(drug_ids)
  random_state = np.random.RandomState(seed)


  neg_triplets = []
  for head, tail, rel in tqdm(pos_triplets, desc="Generating Negative Triplets"):
    if dataset == 'drugbank':
      neg_heads, neg_tails = _normal_batch(head, tail, rel, neg_ent, data_stats, drug_ids, random_state)
      temp_neg = [f"{neg_h}$h" for neg_h in neg_heads] +\
                 [f"{neg_t}$t" for neg_t in neg_tails]
    else:
      existing_drug_ids = np.unique(np.concatenate([
          data_stats["ALL_TRUE_T_WITH_HR"][(head, rel)],
          data_stats["ALL_TRUE_H_WITH_TR"][(tail, rel)]
      ]))
      temp_neg = _corrupt_ent(existing_drug_ids, neg_ent, drug_ids, random_state)
    neg_triplets.append("_".join(map(str, temp_neg[:neg_ent])))

  df = pd.DataFrame({
      'Drug1_ID': pos_triplets[:, 0],
      'Drug2_ID': pos_triplets[:, 1],
      'Y': pos_triplets[:, 2],
      'Neg Samples': neg_triplets
      })

  filename = f'{dirname}/{dataset}/pair_pos_neg_triplets.csv'
  df.to_csv(filename, index=False)
  print(f"\nData saved as {filename}!")

  save_data(data_stats, 'data_statistics.pkl', dirname="data/preprocessed", dataset="drugbank")

## Split Dataset

In [14]:
def split_data(class_col, seed, test_ratio, n_folds, dirname="data/preprocessed", dataset="drugbank"):
  filename = os.path.join(dirname, dataset, "pair_pos_neg_triplets.csv")
  output_prefix = os.path.splitext(filename)[0]

  df = pd.read_csv(filename)

  splitter = StratifiedShuffleSplit(n_splits=n_folds, test_size=test_ratio, random_state=seed)
  for fold_idx, (train_idx, test_idx) in enumerate(splitter.split(X=df, y=df[class_col])):
    print(f"Generating fold {fold_idx}")
    train_df = df.iloc[train_idx]
    test_df = df.iloc[test_idx]

    train_file = f"{output_prefix}_train_fold{fold_idx}.csv"
    test_file = f"{output_prefix}_test_fold{fold_idx}.csv"
    train_df.to_csv(train_file, index=False)
    test_df.to_csv(test_file, index=False)

    print(f"Train data saved as {train_file}")
    print(f"Test data saved as {test_file}")

In [15]:
drug_data = load_drug_mol_data(df_drugbank)

[20:05:57] SMILES Parse Error: syntax error while parsing: OC1=CC=CC(=C1)C-1=C2\CCC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC(O)=CC=C1)C1=CC(O)=CC=C1)\C1=CC(O)=CC=C1
[20:05:57] SMILES Parse Error: check for mistakes around position 76:
[20:05:57] C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC(O)=CC=C1)C
[20:05:57] ~~~~~~~~~~~~~~~~~~~~^
[20:05:57] SMILES Parse Error: Failed parsing SMILES 'OC1=CC=CC(=C1)C-1=C2\CCC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC(O)=CC=C1)C1=CC(O)=CC=C1)\C1=CC(O)=CC=C1' for input: 'OC1=CC=CC(=C1)C-1=C2\CCC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC(O)=CC=C1)C1=CC(O)=CC=C1)\C1=CC(O)=CC=C1'
  return atom_features_tensor, edge_index.T, edge_features_tensor, line_graph_edge_index
Processing drugs: 100%|██████████| 1705/1705 [00:03<00:00, 497.52it/s]


Data saved in data/preprocessed/drugbank/drug_data.pkl


In [16]:
generate_pair_triplets(df_drugbank, neg_ent=1, seed=42)

Loading Data Statics


Processing Data Stats: 100%|██████████| 191798/191798 [00:00<00:00, 340280.33it/s]


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
       'DB00681', 'DB00735', 'DB00738', 'DB00793', 'DB00826', 'DB00834',
       'DB00857', 'DB00877', 'DB00936', 'DB01007', 'DB01026', 'DB01034',
       'DB01072', 'DB01091', 'DB01099', 'DB01110', 'DB01127', 'DB01141',
       'DB01153', 'DB01157', 'DB01167', 'DB01188', 'DB01243', 'DB01254',
       'DB01263', 'DB01422', 'DB02513', 'DB02703', 'DB03793', 'DB04794',
       'DB06290', 'DB06717', 'DB06820', 'DB08820', 'DB08933', 'DB08943',
       'DB08958', 'DB09031', 'DB09040', 'DB09041', 'DB09048', 'DB09063',
       'DB09073', 'DB09118', 'DB09330'], dtype='<U7'), (np.str_('DB00508'), np.str_('72')): array(['DB00205', 'DB00250', 'DB00254', 'DB00358', 'DB00440', 'DB00468',
       'DB00608', 'DB00613', 'DB00664', 'DB00908', 'DB01087', 'DB01103',
       'DB01117', 'DB01131', 'DB01218', 'DB01299', 'DB01611', 'DB06697',
       'DB06708', 'DB09274'], dtype='<U7'), (np.str_('DB09143'), np.str_('72')): array(['DB00091', 'DB00176', 'DB

Generating Negative Triplets: 100%|██████████| 191798/191798 [00:22<00:00, 8662.67it/s] 



Data saved as data/preprocessed/drugbank/pair_pos_neg_triplets.csv!
Data saved in data/preprocessed/drugbank/data_statistics.pkl


In [17]:
split_data('Y', seed=42, test_ratio=0.2, n_folds=3)

Generating fold 0
Train data saved as data/preprocessed/drugbank/pair_pos_neg_triplets_train_fold0.csv
Test data saved as data/preprocessed/drugbank/pair_pos_neg_triplets_test_fold0.csv
Generating fold 1
Train data saved as data/preprocessed/drugbank/pair_pos_neg_triplets_train_fold1.csv
Test data saved as data/preprocessed/drugbank/pair_pos_neg_triplets_test_fold1.csv
Generating fold 2
Train data saved as data/preprocessed/drugbank/pair_pos_neg_triplets_train_fold2.csv
Test data saved as data/preprocessed/drugbank/pair_pos_neg_triplets_test_fold2.csv


# geting datasets

In [18]:
import torch
from torch_geometric.data import Batch, Data
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import Dataset, DataLoader
import pickle
import pandas as pd
import numpy as np
import math

In [19]:
NUM_FEATURES = None
NUM_EDGE_FEATURES = None
biparticle_edge_dict = dict()
drug_num_node_indices = dict()

In [20]:
def total_num_rel():
  return 86

def split_train_valid(data, fold, val_ratio=0.2):
  cv_split = StratifiedShuffleSplit(n_splits=2, test_size=val_ratio, random_state=fold)
  pos_triplets, neg_samples = data
  train_idx, val_idx = next(iter(cv_split.split(X=pos_triplets, y=pos_triplets[:, 2])))
  return (pos_triplets[train_idx], neg_samples[train_idx]), (pos_triplets[val_idx], neg_samples[val_idx])

In [21]:
def load_split(split_name, dirname="data/preprocessed"):
  path = os.path.join(dirname, "drugbank", f"pair_pos_neg_triplets_{split_name}.csv")
  df = pd.read_csv(path)
  pos_triplets = [(d1, d2, r) for d1, d2, r in zip(df['Drug1_ID'], df['Drug2_ID'], df['Y'])]
  neg_samples = [[str(s) for s in neg.split('_')] for neg in df['Neg Samples']]
  return np.array(pos_triplets), np.array(neg_samples)

In [22]:
def load_ddi_data_fold(fold, batch_size=32, data_size_ratio=1.0, valid_ratio=0.2, dirname="data/preprocessed"):
  global NUM_FEATURES, NUM_EDGE_FEATURES, drug_num_node_indices

  dataset_name = "drugbank"
  drug_data_path = os.path.join(dirname, dataset_name, "drug_data.pkl")
  with open(drug_data_path, 'rb') as f:
    raw_drug_data = pickle.load(f)

  sample = next(iter(raw_drug_data.values()))
  NUM_FEATURES, NUM_EDGE_FEATURES = sample[0].shape[1], sample[2].shape[1]

  all_drug_data = {
      drug_id: CustomData(x=data[0], edge_index=data[1], edge_feats=data[2], line_graph_edge_index=data[3])
      for drug_id, data in raw_drug_data.items()
  }

  drug_num_node_indices = {
      drug_id: torch.zeros(data.x.size(0)).long() for drug_id, data in all_drug_data.items()
  }

  train_split = load_split(f"train_fold{fold}", dirname)
  test_split = load_split(f"test_fold{fold}", dirname)
  train_split, val_split = split_train_valid(train_split, fold, valid_ratio)

  train_data = DrugDataset(train_split, all_drug_data, seed=fold, ratio=data_size_ratio)
  val_data = DrugDataset(val_split, all_drug_data, seed=fold)
  test_data = DrugDataset(test_split, all_drug_data, seed=fold)

  print(f"\nFold {fold} - Train: {len(train_data)}, Val: {len(val_data)}, Test: {len(test_data)}")

  return (
        DrugDataLoader(train_data, batch_size=batch_size, shuffle=True),
        DrugDataLoader(val_data, batch_size=batch_size),
        DrugDataLoader(test_data, batch_size=batch_size),
        NUM_FEATURES,
        NUM_EDGE_FEATURES
    )

In [39]:
class DrugDataset(Dataset):
  def __init__(self, pos_neg_pairs, all_drug_data, ratio=1.0, seed=0):
    self.rng = np.random.RandomState(seed)
    self.all_drug_data = all_drug_data
    self.drug_ids = list(all_drug_data.keys())

    pos_triplets, neg_samples = pos_neg_pairs
    self.pair_triplets = [
        (pos, neg) for pos, neg in zip(pos_triplets, neg_samples)
        if pos[0] in self.drug_ids and pos[1] in self.drug_ids
    ]

    if ratio <1.0:
      self.rng.shuffle(self.pair_triplets)
      limit = math.ceil(len(self.pair_triplets)*ratio)
      self.pair_triplets = self.pair_triplets[:limit]

  def __len__(self):
    return len(self.pair_triplets)
  def __getitem__(self, index):
    return self.pair_triplets[index]

  def collate_fn(self, batch):
    node_map = {}
    self.node_seqs = []
    unique_pairs = []
    combo_seen = {}

    pos_indices, neg_indices, labels = [], [], []
    drug_feats = []

    self.node_i_seqs, self.node_j_seqs = [], []

    for (h, t, r), negs in batch:
      h_id, t_id = str(h), str(t)
      h_idx, h_len = self._add_node(h_id, node_map, drug_feats, self.node_seqs)
      t_idx, t_len = self._add_node(t_id, node_map, drug_feats, self.node_seqs)

      pos_combo = self._add_pair((h_idx, t_idx), (h_id, t_id), (h_len, t_len), combo_seen, unique_pairs, self.node_i_seqs, self.node_j_seqs)
      pos_indices.append(pos_combo)
      labels.append(int(r))

      for neg in negs:
        neg_id, role = neg.split("$")
        neg_idx, neg_len = self._add_node(neg_id, node_map, drug_feats, self.node_seqs)

        if role.lower() == "h":
          pair = self._add_pair((neg_idx, t_idx), (neg_id, t_id), (neg_len, t_len), combo_seen, unique_pairs, self.node_i_seqs, self.node_j_seqs)
        else:
          pair = self._add_pair((h_idx, neg_idx), (h_id, neg_id), (h_len, neg_len), combo_seen, unique_pairs, self.node_i_seqs, self.node_j_seqs)

        neg_indices.append(pair)

    return (
      Batch.from_data_list(drug_feats, follow_batch=['edge_index']),
      Batch.from_data_list(unique_pairs, follow_batch=['edge_index']),
      torch.LongTensor(labels),
      torch.LongTensor(pos_indices + neg_indices),
      torch.cat(self.node_j_seqs),
      torch.cat(self.node_i_seqs),
    )

  def _add_node(self, drug_id, node_map, feat_list, seqs):
    if drug_id not in node_map:
      idx = len(node_map)
      node_map[drug_id] = idx
      feat_list.append(self.all_drug_data[drug_id])
      offset = seqs[-1][-1] + 1 if seqs else 0
      seqs.append(torch.arange(self.all_drug_data[drug_id].x.size(0)) + offset)
    return node_map[drug_id], self.all_drug_data[drug_id].x.size(0)

  def _add_pair(self, idx_pair, id_pair, size_pair, seen, pair_list, i_seqs, j_seqs):
    if idx_pair not in seen:
      idx = len(seen)
      seen[idx_pair] = idx

      edge_idx = biparticle_edge_dict.get(id_pair)
      if edge_idx is None:
        i = torch.arange(size_pair[1]).repeat(size_pair[0])
        j = torch.arange(size_pair[0]).repeat_interleave(size_pair[1])
        edge_idx = torch.stack([j, i])
        biparticle_edge_dict[id_pair] = edge_idx

      pair_list.append(PairData(drug_num_node_indices[id_pair[0]], drug_num_node_indices[id_pair[1]], edge_idx))
      self.node_j_seqs.append(self.node_seqs[idx_pair[0]])
      self.node_i_seqs.append(self.node_seqs[idx_pair[1]])

    return seen[idx_pair]

In [24]:
class DrugDataLoader(DataLoader):
  def __init__(self, dataset, **kwargs):
    super().__init__(dataset, collate_fn=dataset.collate_fn, **kwargs)

In [25]:
class PairData(Data):
  def __init__(self, j_indices, i_indices, edge_index):
    super().__init__()
    self.j_indices = j_indices
    self.i_indices = i_indices
    self.edge_index = edge_index

  def __inc__(self, key, value, *args, **kwargs):
    if key == "edge_index":
      return torch.tensor([[self.j_indices.size(0)], [self.i_indices.size(0)]])
    if key in {"i_indices", "j_indices"}:
      return 1
    return super().__inc__(key, value, *args, **kwargs)

In [26]:
class CustomData(Data):
  def __inc__(self, key, value, *args, **kwargs):
    if key == "line_graph_edge_index":
      return self.edge_index.size(1) if self.edge_index.nelement() else 0
    return super().__inc__(key, value, *args, **kwargs)

In [27]:
train_loader, val_loader, test_loader, num_features, num_edge_features = load_ddi_data_fold(
    fold=0, batch_size=32, data_size_ratio=1.0
)


Fold 0 - Train: 122750, Val: 30688, Test: 38360


# Model

In [28]:
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.utils import degree
from torch_scatter import scatter

class DropoutIfNeeded(nn.Module):
    def __init__(self, p=0.0):
        super().__init__()
        self.dropout = nn.Dropout(p) if p > 0 else nn.Identity()

    def forward(self, x):
        return self.dropout(x)

class MultiHeadAttention(nn.Module):
    def __init__(self, in_dim, heads):
        super().__init__()
        self.heads = heads
        self.in_dim = in_dim
        self.dim_per_head = in_dim // heads
        assert in_dim % heads == 0, "in_dim must be divisible by number of heads"

        self.query = nn.Linear(in_dim, in_dim, bias=False)
        self.key = nn.Linear(in_dim, in_dim, bias=False)
        self.scale = self.dim_per_head ** -0.5

    def forward(self, x_i, x_j):
        Q = self.query(x_i).view(-1, self.heads, self.dim_per_head)
        K = self.key(x_j).view(-1, self.heads, self.dim_per_head)
        attn_scores = (Q * K).sum(-1) * self.scale
        attn_weights = torch.sigmoid(attn_scores)  # for interpretability
        return attn_weights.mean(dim=1), attn_weights  # Return mean and per-head attention

class GatedMessagePassingLayer(nn.Module):
    def __init__(self, node_dim, edge_dim, message_steps, dropout=0.0, heads=4):
        super().__init__()
        self.node_dim = node_dim
        self.message_steps = message_steps
        self.dropout = DropoutIfNeeded(dropout)
        self.attn = MultiHeadAttention(node_dim, heads)
        self.edge_embedding = nn.Linear(edge_dim, node_dim)
        self.lstm = nn.LSTMCell(node_dim, node_dim)  # LSTM for updating

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        edge_attr = self.edge_embedding(data.edge_feats)

        h, c = x, torch.zeros_like(x)
        row, col = edge_index

        for _ in range(self.message_steps):
            attn, _ = self.attn(x[row], x[col])
            edge_weight = attn.unsqueeze(-1) * edge_attr

            msg = x[row] * edge_weight
            src, dst = data.line_graph_edge_index
            msg_agg = scatter(msg[src], dst, dim=0, dim_size=msg.size(0), reduce='add')
            msg = msg + msg_agg

            update = scatter(msg, col, dim=0, dim_size=x.size(0), reduce='add')
            h, c = self.lstm(update, (h, c))  # LSTM update step
            x = self.dropout(h)

        return x

class GmpnnCSNetDrugBank(nn.Module):
    def __init__(self, node_feature_dim, edge_feature_dim, hidden_dim, rel_types, message_steps, dropout=0.0, heads=4):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.expanded_dim = hidden_dim * 2

        self.node_encoder = nn.Sequential(
            nn.Linear(node_feature_dim, hidden_dim),
            DropoutIfNeeded(dropout),
            nn.PReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            DropoutIfNeeded(dropout)
        )

        self.message_passing = GatedMessagePassingLayer(
            node_dim=hidden_dim,
            edge_dim=edge_feature_dim,
            message_steps=message_steps,
            dropout=dropout,
            heads=heads
        )

        self.drug_i_projection = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
        self.drug_j_projection = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))

        self.interaction_bias = nn.Parameter(torch.zeros(hidden_dim))

        self.relation_embeddings = nn.Embedding(rel_types, hidden_dim)
        nn.init.xavier_uniform_(self.drug_i_projection)
        nn.init.xavier_uniform_(self.drug_j_projection)


    def forward(self, batch, return_repr=False):
        drug_data, unique_drug_pair, rels, drug_pair_indices, node_j_for_pairs, node_i_for_pairs = batch


        drug_data.x = self.node_encoder(drug_data.x)
        drug_data.x = self.message_passing(drug_data)
        x_j = drug_data.x[node_j_for_pairs]
        x_i = drug_data.x[node_i_for_pairs]
        i_proj = x_i[unique_drug_pair.edge_index[1]] @ self.drug_i_projection
        j_proj = x_j[unique_drug_pair.edge_index[0]] @ self.drug_j_projection

        pair_repr = i_proj * j_proj

        pair_repr = scatter(pair_repr, unique_drug_pair.edge_index_batch, reduce='add', dim=0)[drug_pair_indices]

        if return_repr:
            return pair_repr  # for contrastive learning use case

        p_scores, n_scores = self.compute_interaction_scores(pair_repr, rels)
        return p_scores, n_scores

    def compute_interaction_scores(self, pair_repr, rels):
        batch_size = len(rels)
        neg_samples_per_pos = (len(pair_repr) - batch_size) // batch_size

        all_rels = torch.cat([
            rels,
            torch.repeat_interleave(rels, neg_samples_per_pos, dim=0)
        ], dim=0)

        rel_embeddings = self.relation_embeddings(all_rels)
        scores = (pair_repr * rel_embeddings).sum(dim=-1)

        pos_scores = scores[:batch_size].unsqueeze(-1)
        neg_scores = scores[batch_size:].view(batch_size, -1, 1)
        return pos_scores, neg_scores


In [29]:
model = GmpnnCSNetDrugBank(
    node_feature_dim=128,
    edge_feature_dim=32,
    hidden_dim=64,
    rel_types=86,
    message_steps=2
)


## Train on Fold

In [30]:
from torch import nn
import torch.nn.functional as f

In [47]:
class SigmoidLoss(nn.Module):
    def forward(self, p_scores, n_scores):
        p_scores = p_scores.view(-1)
        n_scores = n_scores.view(-1)

        scores = torch.cat([p_scores, n_scores], dim=0)
        labels = torch.cat([
            torch.ones_like(p_scores),
            torch.zeros_like(n_scores)
        ], dim=0)

        loss = F.binary_cross_entropy_with_logits(scores, labels)

        p_loss = F.binary_cross_entropy_with_logits(p_scores, torch.ones_like(p_scores))
        n_loss = F.binary_cross_entropy_with_logits(n_scores, torch.zeros_like(n_scores))

        return loss, p_loss, n_loss


In [32]:
from operator import le
from sklearn import metrics
from collections import defaultdict
import json
import numpy as np


def do_compute_metrics(probas_pred, target):
    pred = (probas_pred >= 0.5).astype(int)
    acc = metrics.accuracy_score(target, pred)
    auroc = metrics.roc_auc_score(target, probas_pred)
    f1_score = metrics.f1_score(target, pred)
    precision = metrics.precision_score(target, pred)
    recall = metrics.recall_score(target, pred)
    p, r, t = metrics.precision_recall_curve(target, probas_pred)
    int_ap = metrics.auc(r, p)
    ap= metrics.average_precision_score(target, probas_pred)

    return acc, auroc, f1_score, precision, recall, int_ap, ap

In [46]:
# @title
from datetime import datetime
import numpy as np
import torch
from torch import optim
import time
from tqdm import tqdm



dataset_name = 'drugbank'
fold_i = 0
dropout = 0.2
n_iter = 3
TOTAL_NUM_RELS = total_num_rel()
batch_size = 512
data_size_ratio = 1
device = 'cuda' if torch.cuda.is_available() else 'cpu'
hid_feats = 64
rel_total = TOTAL_NUM_RELS
lr = 1e-3
weight_decay = 5e-4
n_epochs = 2
kge_feats = 64

def do_compute(model, batch, device):

        batch = [t.to(device) for t in batch]
        p_score, n_score = model(batch)
        assert p_score.ndim == 2
        assert n_score.ndim == 3
        probas_pred = np.concatenate([torch.sigmoid(p_score.detach()).cpu().mean(dim=-1), torch.sigmoid(n_score.detach()).mean(dim=-1).view(-1).cpu()])
        ground_truth = np.concatenate([np.ones(p_score.shape[0]), np.zeros(n_score.shape[:2]).reshape(-1)])

        return p_score, n_score, probas_pred, ground_truth


def run_batch(model, optimizer, data_loader, epoch_i, desc, loss_fn, device):
        total_loss = 0
        loss_pos = 0
        loss_neg = 0
        probas_pred = []
        ground_truth = []

        for batch in tqdm(data_loader, desc= f'{desc} Epoch {epoch_i}'):
            p_score, n_score, batch_probas_pred, batch_ground_truth = do_compute(model, batch, device)

            probas_pred.append(batch_probas_pred)
            ground_truth.append(batch_ground_truth)

            loss, loss_p, loss_n = loss_fn(p_score, n_score)
            if model.training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            loss_pos += loss_p.item()
            loss_neg += loss_n.item()
        total_loss /= len(data_loader)
        loss_pos /= len(data_loader)
        loss_neg /= len(data_loader)

        probas_pred = np.concatenate(probas_pred)
        ground_truth = np.concatenate(ground_truth)

        return total_loss, do_compute_metrics(probas_pred, ground_truth)


def print_metrics(loss, acc, auroc, f1_score, precision, recall, int_ap, ap):
    print(f'loss: {loss:.4f}, acc: {acc:.4f}, roc: {auroc:.4f}, f1: {f1_score:.4f}, ', end='')
    print(f'p: {precision:.4f}, r: {recall:.4f}, int-ap: {int_ap:.4f}, ap: {ap:.4f}')

    return f1_score


def train(model, train_data_loader, val_data_loader, test_data_loader, loss_fn, optimizer, n_epochs, device, scheduler):
    for epoch_i in range(1, n_epochs+1):
        start = time.time()
        model.train()

        ## Training
        train_loss, train_metrics = run_batch(model, optimizer, train_data_loader, epoch_i, 'train', loss_fn, device)
        if scheduler:
            scheduler.step()

        model.eval()
        with torch.no_grad():
            ## Validation
            if val_data_loader:
                val_loss, val_metrics = run_batch(model, optimizer, val_data_loader, epoch_i, 'val', loss_fn, device)

            ## Test Set Evaluation
            if test_data_loader:
                test_loss, test_metrics = run_batch(model, optimizer, test_data_loader, epoch_i, 'test', loss_fn, device)

        print(f'\n#### Epoch time {time.time() - start:.4f}s')
        print_metrics(train_loss, *train_metrics)

        if val_data_loader:
            print('#### Validation')
            print_metrics(val_loss, *val_metrics)

        if test_data_loader:
            print('#### Test')
            print_metrics(test_loss, *test_metrics)



train_data_loader, val_data_loader, test_data_loader, NUM_FEATURES, NUM_EDGE_FEATURES = \
    load_ddi_data_fold(
    fold=0, batch_size=32, data_size_ratio=1.0)

GmpnnNet = GmpnnCSNetDrugBank if dataset_name == 'drugbank' else GmpnnCSNetDrugBank

model = GmpnnNet(NUM_FEATURES, NUM_EDGE_FEATURES, hid_feats, rel_total, n_iter, dropout)
loss_fn = SigmoidLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.96 ** (epoch))

time_stamp = f'{datetime.now()}'.replace(':', '_')


model.to(device=device)
print(f'Training on {device}.')
print(f'Starting fold_{fold_i} at', datetime.now())
train(model, train_data_loader, val_data_loader, test_data_loader, loss_fn, optimizer, n_epochs, device, scheduler)


Fold 0 - Train: 122750, Val: 30688, Test: 38360
Training on cuda.
Starting fold_0 at 2025-04-16 20:37:58.036537


train Epoch 1: 100%|██████████| 3836/3836 [01:58<00:00, 32.31it/s]
val Epoch 1: 100%|██████████| 959/959 [00:21<00:00, 45.08it/s]
test Epoch 1: 100%|██████████| 1199/1199 [00:26<00:00, 46.08it/s]



#### Epoch time 166.3346s
loss: 0.6654, acc: 0.6049, roc: 0.6483, f1: 0.6163, p: 0.5991, r: 0.6344, int-ap: 0.6295, ap: 0.6295
#### Validation
loss: 0.6275, acc: 0.6629, roc: 0.7195, f1: 0.6979, p: 0.6322, r: 0.7788, int-ap: 0.6942, ap: 0.6942
#### Test
loss: 0.6257, acc: 0.6617, roc: 0.7206, f1: 0.6972, p: 0.6310, r: 0.7788, int-ap: 0.6967, ap: 0.6967


train Epoch 2: 100%|██████████| 3836/3836 [01:56<00:00, 32.84it/s]
val Epoch 2: 100%|██████████| 959/959 [00:20<00:00, 47.09it/s]
test Epoch 2: 100%|██████████| 1199/1199 [00:26<00:00, 45.89it/s]


#### Epoch time 163.6497s
loss: 0.6074, acc: 0.6712, roc: 0.7333, f1: 0.6818, p: 0.6606, r: 0.7043, int-ap: 0.7122, ap: 0.7122
#### Validation
loss: 0.5877, acc: 0.6891, roc: 0.7598, f1: 0.6792, p: 0.7015, r: 0.6582, int-ap: 0.7349, ap: 0.7349
#### Test
loss: 0.5863, acc: 0.6887, roc: 0.7610, f1: 0.6788, p: 0.7011, r: 0.6579, int-ap: 0.7344, ap: 0.7344



