<a href="https://colab.research.google.com/github/HarshaSatyavardhan/ot-cigin/blob/main/ot_on_cigin.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/devalab/CIGIN.git

Cloning into 'CIGIN'...
remote: Enumerating objects: 98, done.[K
remote: Counting objects: 100% (98/98), done.[K
remote: Compressing objects: 100% (76/76), done.[K
remote: Total 98 (delta 46), reused 52 (delta 18), pack-reused 0[K
Unpacking objects: 100% (98/98), 4.15 MiB | 7.88 MiB/s, done.


In [2]:
!pip install rdkit-pypi
!pip install dgl-cu100
!pip install dgl-cu113 dglgo -f https://data.dgl.ai/wheels/repo.html
!pip install mlflow
!pip install torch==1.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting rdkit-pypi
  Downloading rdkit_pypi-2022.9.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m33.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: rdkit-pypi
Successfully installed rdkit-pypi-2022.9.5
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
[31mERROR: Could not find a version that satisfies the requirement dgl-cu100 (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for dgl-cu100[0m[31m
[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.dgl.ai/wheels/repo.html
Collecting dgl-cu113
  Downloading https://data.dgl.ai/wheels/dgl_cu113-0.9.1.post1-cp39-cp39-manylinux1_x86_64.whl (239.2 MB)
[2K     

In [3]:
!pip install pandas

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [4]:
!rm -rf /content/sample_data
!mv /content/CIGIN/* ./

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from dgl import DGLGraph, heterograph
from dgl.nn.pytorch import Set2Set, NNConv, GATConv
from rdkit import Chem, RDLogger,rdBase
from rdkit.Chem import rdMolDescriptors as rdDesc
import numpy as np
import warnings
import pandas as pd
# rdkit imports
from rdkit import RDLogger
from rdkit import rdBase
from rdkit import Chem

# torch imports
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch

#dgl imports
import dgl


lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
rdBase.DisableLog('rdApp.error')
warnings.filterwarnings("ignore")

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


In [6]:
class GatherModel(nn.Module):
    def __init__(self,
                 node_input_dim=42,
                 edge_input_dim=10,
                 node_hidden_dim=42,
                 edge_hidden_dim=128,
                 num_step_message_passing=6,
                 gather="mpnn",
                 n_heads=3):
        super(GatherModel, self).__init__()
        self.num_step_message_passing = num_step_message_passing
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        self.gather = gather
        self.set2set = Set2Set(node_hidden_dim, 2, 1) 
        if self.gather == "mpnn":
        	self.message_layer = nn.Linear(2 * node_hidden_dim, node_hidden_dim)
	        edge_network = nn.Sequential(
	            nn.Linear(edge_input_dim, edge_hidden_dim), nn.ReLU(),
	            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim))
	        self.conv = NNConv(in_feats=node_hidden_dim,
	                           out_feats=node_hidden_dim,
	                           edge_func=edge_network,
	                           aggregator_type='sum',
                               residual=True
                                )
        	self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)
        elif self.gather == "gat":
        	self.n_heads = n_heads  
        	self.gat =  GATConv(node_hidden_dim,node_hidden_dim,self.n_heads)

    def forward(self, g, n_feat, e_feat):

        init = n_feat.clone()
        out = F.relu(self.lin0(n_feat))
        if self.gather == "mpnn":
            h = out.unsqueeze(0)                           
            for i in range(self.num_step_message_passing):
                m = torch.relu(self.conv(g, out, e_feat))
                out = self.message_layer(torch.cat([m, out],dim=1))
            return out + init



class CIGINModel(nn.Module):
    
    def __init__(self,
                 node_input_dim=42,
                 edge_input_dim=10,
                 node_hidden_dim=42,
                 edge_hidden_dim=42,
                 num_step_message_passing=8,
                 interaction='dot',
                 gather='mpnn'):
        super(CIGINModel, self).__init__()
        
        self.node_input_dim = node_input_dim
        self.node_hidden_dim =  node_hidden_dim
        self.edge_input_dim = edge_input_dim
        self.edge_hidden_dim = edge_hidden_dim
        self.num_step_message_passing = num_step_message_passing
        self.gather = gather
        self.interaction = interaction

        self.solute_gather = GatherModel(self.node_input_dim,self.edge_input_dim,
                              self.node_hidden_dim,self.edge_input_dim,
                              self.num_step_message_passing, 
                              self.gather, 3)
        self.solvent_gather = GatherModel(self.node_input_dim,self.edge_input_dim,
                              self.node_hidden_dim,self.edge_input_dim,
                              self.num_step_message_passing, 
                              self.gather, 3)

        self.fc1 = nn.Linear(8*self.node_hidden_dim,256)
        self.fc2 = nn.Linear(256,128)
        self.fc3 = nn.Linear(128,1)
        
        self.imap = nn.Linear(80,1)
        self.num_step_set2set=2
        self.num_layer_set2set=1
        self.set2set_solute = Set2Set(2*node_hidden_dim, self.num_step_set2set, self.num_layer_set2set)
        self.set2set_solvent = Set2Set(2*node_hidden_dim, self.num_step_set2set, self.num_layer_set2set)

    # def solu_feature(self):
    #   solute_H = torch.ones(self.solute_features.shape[0], self.solute_features.shape[1]) / self.solute_features.shape[1]
    #   return solute_H

    # def solv_features(self):
    #   solvent_H = torch.ones(self.solvent_features.shape[0], self.solvent_features.shape[1]) / self.solvent_features.shape[1]
    #   return solvent_H
        
    def forward(self, data):

        solute = data[0]
        solvent = data[1]

        solute_features = self.solute_gather(solute, solute.ndata['x'].float(), solute.edata['w'].float())
        solvent_features = self.solvent_gather(solvent, solvent.ndata['x'].float(), solvent.edata['w'].float())


        if 'dot' not in self.interaction:
            X1 = solute_features.unsqueeze(0)
            Y1= solvent_features.unsqueeze(1)
            X2 = X1.repeat(solvent_features.shape[0],1,1)
            Y2 = Y1.repeat(1,solute_features.shape[0],1)
            Z = torch.cat([X2,Y2],-1)

            if self.interaction == 'general':
                interaction_map = self.imap(Z).squeeze(2)
            if self.interaction == 'tanh-general':
                interaction_map = torch.tanh(self.imap(Z)).squeeze(2)

            ret_interaction_map = torch.clone(interaction_map)

        elif 'dot' in self.interaction :
            interaction_map = torch.mm(solute_features, solvent_features.t())
            if 'scaled' in self.interaction:
                interaction_map = interaction_map/(np.sqrt(self.node_hidden_dim))

            ret_interaction_map = torch.clone(interaction_map)
            interaction_map = torch.tanh(interaction_map)
        
        solvent_prime = torch.mm(interaction_map.t(), solute_features)
        solute_prime = torch.mm(interaction_map, solvent_features)

        solute_features = torch.cat((solute_features, solute_prime), dim=1)
        solvent_features = torch.cat((solvent_features, solvent_prime), dim=1)
        
        
        solute_features = self.set2set_solute(solute, solute_features)
        solvent_features = self.set2set_solvent(solvent, solvent_features)

        final_features = torch.cat((solute_features,solvent_features),1)
        predictions = torch.relu(self.fc1(final_features))
        predictions = torch.relu(self.fc2(predictions))
        predictions =  self.fc3(predictions)
        return predictions, ret_interaction_map




def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise Exception("input {0} not in allowable set{1}:".format(
            x, allowable_set))
    return list(map(lambda s: x == s, allowable_set))

def one_of_k_encoding_unk(x, allowable_set):
    
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1]        
    return list(map(lambda s: x == s, allowable_set))

def get_atom_features(atom, stereo, features, explicit_H=False):

    """
    Method that computes atom level features from rdkit atom object
    :param atom: rdkit atom object
    :return: atom features, 1d numpy array
    """
    # todo: take list  of all possible atoms
    possible_atoms = ['C','N','O','S','F','P','Cl','Br','I','Si']
    atom_features  = one_of_k_encoding_unk(atom.GetSymbol(),possible_atoms)
    atom_features += one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1])
    atom_features += one_of_k_encoding_unk(atom.GetNumRadicalElectrons(), [0, 1])
    atom_features += one_of_k_encoding(atom.GetDegree(),[0, 1, 2, 3, 4, 5, 6]) 
    atom_features += one_of_k_encoding_unk(atom.GetFormalCharge(), [-1, 0, 1])
    atom_features += one_of_k_encoding_unk(atom.GetHybridization(), [
                Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
                Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D])
    atom_features += [int(i) for i in list("{0:06b}".format(features))]

    #todo: add aromacity,acceptor,donor and chirality
    if not explicit_H:
        atom_features += one_of_k_encoding_unk(atom.GetTotalNumHs(),[0, 1, 2, 3, 4])

    try:
        atom_features += one_of_k_encoding_unk(stereo,['R', 'S']) 
        atom_features += [atom.HasProp('_ChiralityPossible')]
    except Exception as e:
        
        atom_features +=  [False, False
                          ] + [atom.HasProp('_ChiralityPossible')]
        
    return np.array(atom_features)

def get_bond_features(bond):
    
    """
    Method that computes bond level features from rdkit bond object
    :param bond: rdkit bond object
    :return: bond features, 1d numpy array
    """
    
    bond_type = bond.GetBondType()
    bond_feats = [
      bond_type == Chem.rdchem.BondType.SINGLE, bond_type == Chem.rdchem.BondType.DOUBLE,
      bond_type == Chem.rdchem.BondType.TRIPLE, bond_type == Chem.rdchem.BondType.AROMATIC,
      bond.GetIsConjugated(),
      bond.IsInRing()
    ]
    bond_feats += one_of_k_encoding_unk(str(bond.GetStereo()),["STEREONONE", "STEREOANY", "STEREOZ", "STEREOE"])

    return np.array(bond_feats)

def get_graph_from_smile(molecule):
    
    """
    Method that constructs a molecular graph with nodes being the atoms
    and bonds being the edges.
    :param molecule: SMILE sequence
    :return: DGL graph object, Node features and Edge features
    """

    G = DGLGraph()
    molecule = Chem.MolFromSmiles(molecule)
    features = rdDesc.GetFeatureInvariants(molecule)
    
    stereo = Chem.FindMolChiralCenters(molecule)
    chiral_centers = [0]* molecule.GetNumAtoms()
    for i in stereo:
        chiral_centers[i[0]] = i[1]
        
    G.add_nodes(molecule.GetNumAtoms())
    node_features = []
    edge_features = []
    for i in range(molecule.GetNumAtoms()):

        atom_i = molecule.GetAtomWithIdx(i)
        atom_i_features =  get_atom_features(atom_i,chiral_centers[i],features[i])
        node_features.append(atom_i_features)
        
        for j in range(molecule.GetNumAtoms()):
            bond_ij = molecule.GetBondBetweenAtoms(i, j)
            if bond_ij is not None:
                G.add_edge(i,j)
                bond_features_ij = get_bond_features(bond_ij)
                edge_features.append(bond_features_ij)
                
    G.ndata['x'] = torch.FloatTensor(node_features)
    G.edata['w'] = torch.FloatTensor(edge_features)
    return G



def get_len_matrix(len_list):
    len_list = np.array(len_list)
    max_nodes = np.sum(len_list)
    curr_sum = 0
    len_matrix = []
    for l in len_list:
        curr = np.zeros(max_nodes)
        curr[curr_sum:curr_sum+l] = 1
        len_matrix.append(curr)
        curr_sum += l
    return np.array(len_matrix)
    
class Dataclass(Dataset):
    def __init__(self,dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # solute_file = 'mol_files/'+self.dataset.loc[idx]['FileHandle'] +'.mol'
        # solute = Chem.MolFromMolFile(solute_file) 
        # solute=Chem.MolToSmiles(solute)
        solute = self.dataset.loc[idx]['SoluteSMILES']
        mol = Chem.MolFromSmiles(solute)
        mol = Chem.AddHs(mol)
        solute = Chem.MolToSmiles(mol)
        solute_graph = get_graph_from_smile(solute)
        
        solvent = self.dataset.loc[idx]['SolventSMILES']
        mol = Chem.MolFromSmiles(solvent)
        mol = Chem.AddHs(mol)
        solvent = Chem.MolToSmiles(mol)
        
        solvent_graph = get_graph_from_smile(solvent)
        ddi_value = self.dataset.loc[idx]['DeltaGsolv']
        return [solute_graph, solvent_graph, ddi_value]

model= CIGINModel().to(device)
model.eval()

CIGINModel(
  (solute_gather): GatherModel(
    (lin0): Linear(in_features=42, out_features=42, bias=True)
    (set2set): Set2Set(
      n_iters=2
      (lstm): LSTM(84, 42)
    )
    (message_layer): Linear(in_features=84, out_features=42, bias=True)
    (conv): NNConv(
      (edge_func): Sequential(
        (0): Linear(in_features=10, out_features=10, bias=True)
        (1): ReLU()
        (2): Linear(in_features=10, out_features=1764, bias=True)
      )
      (res_fc): Identity()
    )
    (gru): GRU(42, 42)
  )
  (solvent_gather): GatherModel(
    (lin0): Linear(in_features=42, out_features=42, bias=True)
    (set2set): Set2Set(
      n_iters=2
      (lstm): LSTM(84, 42)
    )
    (message_layer): Linear(in_features=84, out_features=42, bias=True)
    (conv): NNConv(
      (edge_func): Sequential(
        (0): Linear(in_features=10, out_features=10, bias=True)
        (1): ReLU()
        (2): Linear(in_features=10, out_features=1764, bias=True)
      )
      (res_fc): Identity()
  

In [7]:
def collate(samples):
    solute_graphs, solvent_graphs, labels = map(list, zip(*samples))
    solute_graphs = dgl.batch(solute_graphs)
    solvent_graphs = dgl.batch(solvent_graphs)
    solute_len_matrix = get_len_matrix(solute_graphs.batch_num_nodes())
    solvent_len_matrix = get_len_matrix(solvent_graphs.batch_num_nodes())
    return solute_graphs, solvent_graphs, solute_len_matrix, solvent_len_matrix, labels

In [8]:
project_name = 'cigin'
interaction = 'dot'
max_epochs = 10
batch_size = 128

In [10]:
train_df = pd.read_csv('/content/CIGIN_V2/data/train.csv', sep=";")
valid_df = pd.read_csv('/content/CIGIN_V2/data/valid.csv', sep=";")

train_dataset = Dataclass(train_df)
valid_dataset = Dataclass(valid_df)

train_loader = DataLoader(train_dataset, collate_fn=collate, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, collate_fn=collate, batch_size=128)

In [11]:
model = CIGINModel(interaction=interaction)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, patience=5, mode='min', verbose=True)

In [12]:
from tqdm import tqdm
import torch
import numpy as np

loss_fn = torch.nn.MSELoss()
mae_loss_fn = torch.nn.L1Loss()

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

def get_metrics(model, data_loader):
    valid_outputs = []
    valid_labels = []
    valid_loss = []
    valid_mae_loss = []
    for solute_graphs, solvent_graphs, solute_lens, solvent_lens, labels in tqdm(data_loader):
        outputs, i_map = model(
            [solute_graphs.to(device), solvent_graphs.to(device), torch.tensor(solute_lens).to(device),
             torch.tensor(solvent_lens).to(device)])
        loss = loss_fn(outputs, torch.tensor(labels).to(device).float())
        mae_loss = mae_loss_fn(outputs, torch.tensor(labels).to(device).float())
        valid_outputs += outputs.cpu().detach().numpy().tolist()
        valid_loss.append(loss.cpu().detach().numpy())
        valid_mae_loss.append(mae_loss.cpu().detach().numpy())
        valid_labels += labels

    loss = np.mean(np.array(valid_loss).flatten())
    mae_loss = np.mean(np.array(valid_mae_loss).flatten())
    return loss, mae_loss


def train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name):
    best_val_loss = 100
    for epoch in range(max_epochs):
        model.train()
        running_loss = []
        tq_loader = tqdm(train_loader)
        o = {}
        for samples in tq_loader:
            optimizer.zero_grad()
            outputs, interaction_map = model(
                [samples[0].to(device), samples[1].to(device), torch.tensor(samples[2]).to(device),
                 torch.tensor(samples[3]).to(device)])
            l1_norm = torch.norm(interaction_map, p=2) * 1e-4
            loss = loss_fn(outputs, torch.tensor(samples[4]).to(device).float()) + l1_norm
            loss.backward()
            optimizer.step()
            loss = loss - l1_norm
            running_loss.append(loss.cpu().detach())
            tq_loader.set_description(
                "Epoch: " + str(epoch + 1) + "  Training loss: " + str(np.mean(np.array(running_loss))))
        model.eval()
        val_loss, mae_loss = get_metrics(model, valid_loader)
        scheduler.step(val_loss)
        print("Epoch: " + str(epoch + 1) + "  train_loss " + str(np.mean(np.array(running_loss))) + " Val_loss " + str(
            val_loss) + " MAE Val_loss " + str(mae_loss))
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "/content/runs/best_model.tar")

In [14]:
train(max_epochs, model, optimizer, scheduler, train_loader, valid_loader, project_name)

Epoch: 1  Training loss: 1.6648848: 100%|██████████| 1/1 [00:00<00:00,  7.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.96it/s]


Epoch: 1  train_loss 1.6648848 Val_loss 1.2317414 MAE Val_loss 0.95316017


Epoch: 2  Training loss: 1.2317415: 100%|██████████| 1/1 [00:00<00:00,  7.86it/s]
100%|██████████| 1/1 [00:00<00:00, 12.53it/s]


Epoch: 2  train_loss 1.2317415 Val_loss 1.6051244 MAE Val_loss 1.1040303


Epoch: 3  Training loss: 1.6051244: 100%|██████████| 1/1 [00:00<00:00,  7.49it/s]
100%|██████████| 1/1 [00:00<00:00, 12.47it/s]


Epoch: 3  train_loss 1.6051244 Val_loss 1.4120902 MAE Val_loss 1.0456204


Epoch: 4  Training loss: 1.4120902: 100%|██████████| 1/1 [00:00<00:00,  7.80it/s]
100%|██████████| 1/1 [00:00<00:00, 13.84it/s]


Epoch: 4  train_loss 1.4120902 Val_loss 1.1998557 MAE Val_loss 0.9066779


Epoch: 5  Training loss: 1.1998554: 100%|██████████| 1/1 [00:00<00:00,  7.66it/s]
100%|██████████| 1/1 [00:00<00:00, 12.75it/s]


Epoch: 5  train_loss 1.1998554 Val_loss 1.3612971 MAE Val_loss 1.0300183


Epoch: 6  Training loss: 1.3612971: 100%|██████████| 1/1 [00:00<00:00,  7.78it/s]
100%|██████████| 1/1 [00:00<00:00, 12.69it/s]


Epoch: 6  train_loss 1.3612971 Val_loss 1.3032413 MAE Val_loss 1.0039868


Epoch: 7  Training loss: 1.3032413: 100%|██████████| 1/1 [00:00<00:00,  8.03it/s]
100%|██████████| 1/1 [00:00<00:00, 12.72it/s]


Epoch: 7  train_loss 1.3032413 Val_loss 1.2118188 MAE Val_loss 0.9358868


Epoch: 8  Training loss: 1.2118188: 100%|██████████| 1/1 [00:00<00:00,  8.02it/s]
100%|██████████| 1/1 [00:00<00:00, 12.96it/s]


Epoch: 8  train_loss 1.2118188 Val_loss 1.2039841 MAE Val_loss 0.9187584


Epoch: 9  Training loss: 1.203984: 100%|██████████| 1/1 [00:00<00:00,  7.35it/s]
100%|██████████| 1/1 [00:00<00:00, 13.07it/s]


Epoch: 9  train_loss 1.203984 Val_loss 1.2364683 MAE Val_loss 0.9572525


Epoch: 10  Training loss: 1.2364682: 100%|██████████| 1/1 [00:00<00:00,  7.96it/s]
100%|██████████| 1/1 [00:00<00:00, 11.14it/s]

Epoch    11: reducing learning rate of group 0 to 1.0000e-04.
Epoch: 10  train_loss 1.2364682 Val_loss 1.2452238 MAE Val_loss 0.96423787



