In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Code taken from:
# https://www.blopig.com/blog/2022/02/how-to-turn-a-smiles-string-into-a-molecular-graph-for-pytorch-geometric/

# general tools
import numpy as np

# RDkit
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix

import pandas as pd
# PyTorch and PyTorch Geometric
import torch
from torch_geometric.data import Data
from torch.utils.data import DataLoader
from tqdm import tqdm
import random

from torch.nn import Linear, MSELoss
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn import global_mean_pool

In [3]:
# Step 1: Atom Featurisation
def one_hot_encoding(x, permitted_list):
    """
    One-hot encode x against permitted_list.

    A value that does not occur in permitted_list is treated as the list's final
    entry, so the last slot acts as a catch-all bucket. The result is a list of
    0/1 integers with one position per permitted element.
    """
    target = x if x in permitted_list else permitted_list[-1]

    encoding = []
    for candidate in permitted_list:
        encoding.append(1 if target == candidate else 0)
    return encoding

In [4]:
def get_atom_features(atom, use_chirality = True, hydrogens_implicit = True):
    """
    Compute the feature vector of a single RDKit atom and return it as a 1d numpy array.

    The vector is the concatenation, in this order, of:
      - one-hot element symbol (symbols not in the list fall into the trailing "Unknown"
        slot; an "H" slot is prepended when hydrogens_implicit is False)
      - one-hot atom degree (atom.GetDegree())
      - one-hot formal charge
      - one-hot hybridisation
      - ring-membership flag and aromaticity flag
      - shifted/scaled atomic mass, van der Waals radius and covalent radius
      - one-hot chiral tag (only when use_chirality is True)
      - one-hot total hydrogen count (only when hydrogens_implicit is True)

    Both switches are checked with identity comparisons against True/False, so only
    genuine booleans toggle them.
    """

    # elements that receive a dedicated one-hot slot
    element_symbols = ["C","N","O","S","F","Si","P","Cl","Br","Mg","Na","Ca","Fe","As","Al","I", "B","V","K",
                       "Tl","Yb","Sb","Sn","Ag","Pd","Co","Se","Ti","Zn", "Li","Ge","Cu","Au","Ni","Cd",
                       "In","Mn","Zr","Cr","Pt","Hg","Pb","Unknown"]
    if hydrogens_implicit is False:
        element_symbols = ["H"] + element_symbols

    # shared buckets for the two count-style encodings (degree and hydrogen count)
    count_buckets = [0, 1, 2, 3, 4, "MoreThanFour"]

    periodic_table = Chem.GetPeriodicTable()
    atomic_number = atom.GetAtomicNum()

    # categorical descriptors
    features = []
    features.extend(one_hot_encoding(str(atom.GetSymbol()), element_symbols))
    features.extend(one_hot_encoding(int(atom.GetDegree()), count_buckets))
    features.extend(one_hot_encoding(int(atom.GetFormalCharge()), [-3, -2, -1, 0, 1, 2, 3, "Extreme"]))
    features.extend(one_hot_encoding(str(atom.GetHybridization()),
                                     ["S", "SP", "SP2", "SP3", "SP3D", "SP3D2", "OTHER"]))

    # binary flags
    features.append(int(atom.IsInRing()))
    features.append(int(atom.GetIsAromatic()))

    # continuous descriptors, shifted and scaled with fixed constants
    features.append(float((atom.GetMass() - 10.812)/116.092))
    features.append(float((periodic_table.GetRvdw(atomic_number) - 1.5)/0.6))
    features.append(float((periodic_table.GetRcovalent(atomic_number) - 0.64)/0.76))

    # optional blocks are appended last, so they never shift the positions above
    if use_chirality is True:
        features.extend(one_hot_encoding(
            str(atom.GetChiralTag()),
            ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW", "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"]))

    if hydrogens_implicit is True:
        features.extend(one_hot_encoding(int(atom.GetTotalNumHs()), count_buckets))

    return np.array(features)

In [5]:
# Step 2: bond Featurisation

def get_bond_features(bond,
                      use_stereochemistry = True):
    """
    Compute the feature vector of a single RDKit bond and return it as a 1d numpy array.

    Layout: one-hot bond type (single, double, triple, aromatic), a conjugation flag,
    a ring-membership flag and, only when use_stereochemistry is True, a one-hot
    encoding of the bond's stereo label.

    Note: one_hot_encoding maps values outside its list to the last entry, so any other
    bond type lands in the aromatic slot and any other stereo label in "STEREONONE".
    """

    bond_types = Chem.rdchem.BondType
    known_bond_types = [bond_types.SINGLE, bond_types.DOUBLE,
                        bond_types.TRIPLE, bond_types.AROMATIC]

    # start from the bond-type one-hot and grow the vector in place
    features = one_hot_encoding(bond.GetBondType(), known_bond_types)
    features.append(int(bond.GetIsConjugated()))
    features.append(int(bond.IsInRing()))

    if use_stereochemistry is True:
        stereo_labels = ["STEREOZ", "STEREOE", "STEREOANY", "STEREONONE"]
        features.extend(one_hot_encoding(str(bond.GetStereo()), stereo_labels))

    return np.array(features)

In [6]:
def create_pytorch_geometric_graph_data_list_from_smiles_and_labels(x_smiles, y):
    """
    Turn SMILES strings and their labels into labeled PyTorch Geometric graphs.

    Inputs:

    x_smiles = [smiles_1, smiles_2, ....] ... a list of SMILES strings
    y = [y_1, y_2, ...] ... a list of numerical labels for the SMILES strings (such as associated pKi values)

    The two inputs are paired positionally with zip, so surplus entries of the longer
    one are ignored.

    Outputs:

    data_list = [G_1, G_2, ...] ... a list of torch_geometric.data.Data objects which represent labeled molecular
    graphs that can readily be used for machine learning. Each object carries
    x (n_nodes, n_node_features), edge_index (2, n_edges), edge_attr (n_edges, n_edge_features)
    and y (1,), where every bond contributes two directed edges.

    Raises:

    ValueError ... if RDKit cannot parse one of the SMILES strings
    """

    # The feature widths do not depend on the input molecule, so probe them once on a
    # throwaway molecule instead of re-parsing and re-featurising it for every SMILES.
    unrelated_smiles = "O=O"
    unrelated_mol = Chem.MolFromSmiles(unrelated_smiles)
    n_node_features = len(get_atom_features(unrelated_mol.GetAtomWithIdx(0)))
    n_edge_features = len(get_bond_features(unrelated_mol.GetBondBetweenAtoms(0,1)))

    data_list = []

    for (smiles, y_val) in zip(x_smiles, y):

        # convert SMILES to RDKit mol object; RDKit signals a parse failure by returning None
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"RDKit could not parse SMILES string: {smiles!r}")

        # get graph dimensions (each bond is stored as two directed edges)
        n_nodes = mol.GetNumAtoms()
        n_edges = 2*mol.GetNumBonds()

        # construct node feature matrix X of shape (n_nodes, n_node_features)
        X = np.zeros((n_nodes, n_node_features))

        for atom in mol.GetAtoms():
            X[atom.GetIdx(), :] = get_atom_features(atom)

        X = torch.tensor(X, dtype = torch.float)

        # construct edge index array E of shape (2, n_edges) from the non-zero
        # entries of the (symmetric) adjacency matrix
        (rows, cols) = np.nonzero(GetAdjacencyMatrix(mol))
        torch_rows = torch.from_numpy(rows.astype(np.int64))
        torch_cols = torch.from_numpy(cols.astype(np.int64))
        E = torch.stack([torch_rows, torch_cols], dim = 0)

        # construct edge feature array EF of shape (n_edges, n_edge_features);
        # row k describes the bond behind edge k of E
        EF = np.zeros((n_edges, n_edge_features))

        for (k, (i,j)) in enumerate(zip(rows, cols)):

            EF[k] = get_bond_features(mol.GetBondBetweenAtoms(int(i),int(j)))

        EF = torch.tensor(EF, dtype = torch.float)

        # construct label tensor
        y_tensor = torch.tensor(np.array([y_val]), dtype = torch.float)

        # construct Pytorch Geometric data object and append to data list
        data_list.append(Data(x = X, edge_index = E, edge_attr = EF, y = y_tensor))

    return data_list

In [7]:
# Read the drug table from CSV, using the first CSV column as the DataFrame index.
# Later cells read its "smiles" and "auc_per_drug" columns.
# NOTE(review): this is an absolute path inside one user's home directory, so the notebook
# only runs where the file exists at exactly this location — consider a configurable data dir.
smiles_df = pd.read_csv("/Users/niklaskiermeyer/Desktop/Codespace/DruxAI/data/preprocessed/drug_smiles_data.csv",
                        index_col=0)

In [8]:
# Build a small preview dataset of 10 graphs.
# Row positions are sampled once and both columns are read from those same rows, so every
# SMILES string keeps the AUC label of its own row. Sampling the two columns with separate
# random.sample calls would pair each molecule with the label of an unrelated drug.
preview_rows = random.sample(range(len(smiles_df)), 10)
preview_df = smiles_df.iloc[preview_rows]

data_list = create_pytorch_geometric_graph_data_list_from_smiles_and_labels(
    preview_df["smiles"].to_list(),
    preview_df["auc_per_drug"].to_list())

In [9]:
# Show the list of graph Data objects built from the 10 sampled rows.
data_list

[Data(x=[27, 79], edge_index=[2, 58], edge_attr=[58, 10], y=[1]),
 Data(x=[23, 79], edge_index=[2, 46], edge_attr=[46, 10], y=[1]),
 Data(x=[16, 79], edge_index=[2, 30], edge_attr=[30, 10], y=[1]),
 Data(x=[31, 79], edge_index=[2, 64], edge_attr=[64, 10], y=[1]),
 Data(x=[14, 79], edge_index=[2, 26], edge_attr=[26, 10], y=[1]),
 Data(x=[56, 79], edge_index=[2, 122], edge_attr=[122, 10], y=[1]),
 Data(x=[19, 79], edge_index=[2, 40], edge_attr=[40, 10], y=[1]),
 Data(x=[23, 79], edge_index=[2, 48], edge_attr=[48, 10], y=[1]),
 Data(x=[27, 79], edge_index=[2, 58], edge_attr=[58, 10], y=[1]),
 Data(x=[40, 79], edge_index=[2, 88], edge_attr=[88, 10], y=[1])]

In [10]:
class SimpleGNN(torch.nn.Module):
    """
    Graph-level regressor: one GCN layer, ReLU, mean pooling over nodes, linear head.

    Parameters:
    in_channels ... width of the node feature vectors (default 79, the width of the
                    node features shown in the Data objects above)
    hidden_channels ... output width of the graph convolution (default 128)
    """

    def __init__(self, in_channels=79, hidden_channels=128):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.fc = Linear(hidden_channels, 1)

    def forward(self, data):
        """
        Run the model on a batched graph object exposing x, edge_index and batch.

        Returns a tensor with one row per graph in the batch and a single column.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # 1st Graph Convolution
        x = self.conv1(x, edge_index)
        x = torch.relu(x)

        # Global Pooling (mean): collapse node embeddings to one vector per graph
        x = global_mean_pool(x, batch)

        return self.fc(x)

In [11]:
class SimpleGAT(torch.nn.Module):
    """
    Graph-level regressor: one multi-head GAT layer, ReLU, mean pooling over nodes, linear head.

    Parameters:
    in_channels ... width of the node feature vectors (default 79, the width of the
                    node features shown in the Data objects above)
    hidden_channels ... output width of each attention head (default 128)
    heads ... number of attention heads (default 8)
    dropout ... dropout value handed to GATConv (default 0.6)
    """

    def __init__(self, in_channels=79, hidden_channels=128, heads=8, dropout=0.6):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        # the head outputs are concatenated by GATConv, so the linear head sees
        # hidden_channels * heads inputs
        self.fc = Linear(hidden_channels * heads, 1)

    def forward(self, data):
        """
        Run the model on a batched graph object exposing x, edge_index and batch.

        Returns a tensor with one row per graph in the batch and a single column.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # 1st Graph Attention Layer
        x = self.conv1(x, edge_index)
        x = torch.relu(x)

        # Global Pooling (mean): collapse node embeddings to one vector per graph
        x = global_mean_pool(x, batch)

        return self.fc(x)

In [12]:
from torch_geometric.data import Batch

def custom_collate(batch):
    """
    Collate function for the DataLoader: merge a list of graph Data objects
    into a single Batch via Batch.from_data_list.
    """
    merged_batch = Batch.from_data_list(batch)
    return merged_batch

In [14]:
# Training data creation
# Row positions are sampled once and both columns are read from those same rows, so every
# SMILES string keeps the AUC label of its own row. Sampling the two columns with separate
# random.sample calls would train the model on labels that belong to unrelated drugs.
train_rows = random.sample(range(len(smiles_df)), 10000)
train_df = smiles_df.iloc[train_rows]

data_list = create_pytorch_geometric_graph_data_list_from_smiles_and_labels(
    train_df["smiles"].to_list(),
    train_df["auc_per_drug"].to_list()
    )

dataloader = DataLoader(dataset=data_list, batch_size=2, collate_fn=custom_collate)

# Model creation
gnn = SimpleGAT().to("cpu")

# Loss function
loss_function = MSELoss()

# Optimizer
optimiser = torch.optim.Adam(gnn.parameters(), lr=1e-3)

# Training loop
for epoch in tqdm(range(100)):
    gnn.train()
    print(f"Epoch: {epoch}")

    # iterate the loader directly: the batch index was unused, and binding it to `_`
    # overwrites IPython's "last output" variable
    for batch in dataloader:
        # Move data to device
        batch = batch.to("cpu")

        # Forward pass: output has one row per graph and a single column
        output = gnn(batch)

        # Compute loss. batch.y is already a float tensor with one label per graph
        # (each graph stores a float y of shape (1,)), so it is used as-is; wrapping it in
        # torch.tensor(...) again only copies it and raises a copy-construct UserWarning.
        loss = loss_function(output[:, 0], batch.y)

        # Backpropagation
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

In [56]:
# NOTE(review): this echoes the entire training list (built from 10000 sampled rows); the
# captured output below is cut off mid-entry. Displaying a slice, e.g. data_list[:5],
# would be enough to check the shapes.
data_list

[Data(x=[38, 79], edge_index=[2, 84], edge_attr=[84, 10], y=[1]),
 Data(x=[31, 79], edge_index=[2, 68], edge_attr=[68, 10], y=[1]),
 Data(x=[29, 79], edge_index=[2, 64], edge_attr=[64, 10], y=[1]),
 Data(x=[31, 79], edge_index=[2, 66], edge_attr=[66, 10], y=[1]),
 Data(x=[16, 79], edge_index=[2, 36], edge_attr=[36, 10], y=[1]),
 Data(x=[18, 79], edge_index=[2, 38], edge_attr=[38, 10], y=[1]),
 Data(x=[28, 79], edge_index=[2, 62], edge_attr=[62, 10], y=[1]),
 Data(x=[28, 79], edge_index=[2, 62], edge_attr=[62, 10], y=[1]),
 Data(x=[33, 79], edge_index=[2, 72], edge_attr=[72, 10], y=[1]),
 Data(x=[36, 79], edge_index=[2, 80], edge_attr=[80, 10], y=[1]),
 Data(x=[18, 79], edge_index=[2, 38], edge_attr=[38, 10], y=[1]),
 Data(x=[23, 79], edge_index=[2, 50], edge_attr=[50, 10], y=[1]),
 Data(x=[22, 79], edge_index=[2, 50], edge_attr=[50, 10], y=[1]),
 Data(x=[46, 79], edge_index=[2, 102], edge_attr=[102, 10], y=[1]),
 Data(x=[13, 79], edge_index=[2, 26], edge_attr=[26, 10], y=[1]),
 Data(x=

In [59]:
# NOTE(review): the captured output below is a 10-column tensor, which matches the
# edge_attr width of the graphs above rather than the repr of a Data object — confirm
# which expression actually produced it.
data_list[0]

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 1., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 1., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 1., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 1., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 1., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 1., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 1., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 1., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 1., 0., 0., 0., 1.],
        [1., 0., 0., 0., 1., 1., 0., 0., 0., 1.],
        [0., 0., 0., 1., 1., 1., 0., 0., 0., 1.],
        [1., 0., 0., 0., 1., 1., 0., 0., 0., 1.],
        [0., 1., 0., 0., 1., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 1., 1., 0., 0., 0., 1.],
        [0., 1., 0., 0., 1., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 1., 1., 0., 0., 0., 1.],
