In [36]:
import torch
import torch.nn as nn
from rdkit import Chem
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect
from rdkit.DataStructs.cDataStructs import ConvertToNumpyArray
import numpy as np
import deepchem as dc
import torch_geometric

atorvastatin_smiles = 'O=C(O)C[C@H](O)C[C@H](O)CCn2c(c(c(c2c1ccc(F)cc1)c3ccccc3)C(=O)Nc4ccccc4)'
atorvastatin = Chem.MolFromSmiles(atorvastatin_smiles)

fing_print = GetMorganFingerprintAsBitVect(atorvastatin, radius = 2, nBits = 2048)

fp_array = np.zeros((1, ))

ConvertToNumpyArray(fing_print, fp_array)
print(fp_array)
print(fp_array.shape)
print(sum(fp_array))

[0. 1. 0. ... 0. 0. 0.]
(2048,)
54.0


In [37]:
def get_atom_features(mol):
    ato_number = []
    num_h = []
    
    for atom in mol.GetAtoms(): 
        ato_number.append(atom.GetAtomicNum())
        num_h.append(atom.GetTotalNumHs(includeNeighbors = True))
        
    return torch.tensor([ato_number, num_h]).t()

In [124]:
def get_edge_index(mol): 
    r, c = [], []
    
    for bond in mol.GetBonds():
        st, end = bond.GetBeginAtomIdx(),  bond.GetEndAtomIdx()
        r += [st, end]
        c += [end, st]
        
    return torch.tensor([r, c], dtype = torch.long)

In [125]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

def prepare_dataloader(mol_list, batch_size=3):
    data_list = []

    for i, mol in enumerate(mol_list):
        
        x = get_atom_features(mol)
        edge_index = get_edge_index(mol)

        data = torch_geometric.data.data.Data(edge_index=edge_index, x=x)
        data_list.append(data)

    return DataLoader(data_list, batch_size=batch_size, shuffle=False), data_list

In [126]:
smiles_list = ['Cc1cc(c(C)n1c2ccc(F)cc2)S(=O)(=O)NCC(=O)N',
'CN(CC(=O)N)S(=O)(=O)c1c(C)n(c(C)c1S(=O)(=O)N(C)CC(=O)N)c2ccc(F)cc2',
'Fc1ccc(cc1)n2cc(COC(=O)CBr)nn2',
'CCOC(=O)COCc1cn(nn1)c2ccc(F)cc2',
'COC(=O)COCc1cn(nn1)c2ccc(F)cc2',
'Fc1ccc(cc1)n2cc(COCC(=O)OCc3cn(nn3)c4ccc(F)cc4)nn2']

mol_list = [Chem.MolFromSmiles(smi) for smi in smiles_list]
dloader, dList = prepare_dataloader(mol_list)
print(dList)
print(dloader)

[Data(x=[22, 2], edge_index=[2, 46]), Data(x=[32, 2], edge_index=[2, 66]), Data(x=[18, 2], edge_index=[2, 38]), Data(x=[20, 2], edge_index=[2, 42]), Data(x=[19, 2], edge_index=[2, 40]), Data(x=[31, 2], edge_index=[2, 68])]
<torch_geometric.loader.dataloader.DataLoader object at 0x000001A8F9302208>


In [127]:
for batch in dloader:
    break
    
print(batch)

DataBatch(x=[72, 2], edge_index=[2, 150], batch=[72], ptr=[4])


In [128]:
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_add
from torch_geometric.utils import add_self_loops, degree

class NeuralLoop(MessagePassing):
    def __init__(self, atom_features, fp_size):
        super(NeuralLoop, self).__init__(aggr='add')
        self.H = nn.Linear(atom_features, atom_features)
        self.W = nn.Linear(atom_features, fp_size)
        
    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
    
    def message(self, x_j, edge_index, size):
        return x_j 
    
    def update(self, v):
        v = v.type('torch.FloatTensor')
        updated_atom_features = self.H(v).sigmoid()
        updated_fingerprint = self.W(updated_atom_features).softmax(dim=-1)
        
        return updated_atom_features, updated_fingerprint # shape [N, atom_features]
    
class NeuralFP(nn.Module):
    def __init__(self, atom_features=52, fp_size=2048):
        super(NeuralFP, self).__init__()
        
        self.atom_features = 52
        self.fp_size = 2048
        
        self.loop1 = NeuralLoop(atom_features=atom_features, fp_size=fp_size)
        self.loop2 = NeuralLoop(atom_features=atom_features, fp_size=fp_size)
        self.loops = nn.ModuleList([self.loop1, self.loop2])
        
    def forward(self, data):
        fingerprint = torch.zeros((data.batch.shape[0], self.fp_size), dtype=torch.float)
        
        out = data.x
        for idx, loop in enumerate(self.loops):
            updated_atom_features, updated_fingerprint = loop(out, data.edge_index)
            out = updated_atom_features
            fingerprint += updated_fingerprint
            
        return scatter_add(fingerprint, data.batch, dim=0)

In [129]:
neural_fp = NeuralFP(atom_features = 2, fp_size = 2048)
fps = neural_fp(batch)
print(fps.shape)

torch.Size([3, 2048])


In [192]:
_, (train, valid, test), _ = dc.molnet.load_bace_regression(featurizer = 'Raw')
bs = 4
print(train.X)
train_loader, _ = prepare_dataloader(train.X, batch_size = bs)
valid_loader, _ = prepare_dataloader(valid.X, batch_size = bs)
test_loader, _ = prepare_dataloader(test.X, batch_size = bs)

train_labels_loader = torch.utils.data.DataLoader(train.y, batch_size = bs)
valid_labels_loader = torch.utils.data.DataLoader(valid.y, batch_size = bs)
test_labels_loader = torch.utils.data.DataLoader(test.y, batch_size = bs)



[<rdkit.Chem.rdchem.Mol object at 0x000001A8F92F72D0>
 <rdkit.Chem.rdchem.Mol object at 0x000001A8F92F7298>
 <rdkit.Chem.rdchem.Mol object at 0x000001A8F92F7260> ...
 <rdkit.Chem.rdchem.Mol object at 0x000001A8F92AA378>
 <rdkit.Chem.rdchem.Mol object at 0x000001A8F92AA458>
 <rdkit.Chem.rdchem.Mol object at 0x000001A8F92AA538>]


In [193]:
print(len(valid_labels_loader))

38


In [194]:
import torch.nn.functional as F

class MLP_regressor(nn.Module): 
    def __init__(self, atom_features = 2, fp_size = 2048, hidden_size = 100): 
        super(MLP_regressor, self).__init__()
        self.neural_fp = neural_fp
        self.lin1 = nn.Linear(fp_size, hidden_size)
        self.lin2 = nn.Linear(hidden_size, 1)
        self.dropout = nn.Dropout(p=0.3)
        
    def forward(self, batch): 
        fp = self.neural_fp(batch)
        hidden = F.relu(self.dropout(self.lin1(fp)))
        out = F.relu(self.lin2(hidden))
        return out

In [195]:
device = 'cuda'

In [196]:
def train_step(batch, labels, reg): 
    out = reg(batch)
    loss = F.mse_loss(out, labels.to(torch.float), reduction = 'mean')
    loss.backward()
    return loss

In [197]:
def valid_step(batch, labels, reg): 
    out = reg(batch)
    loss = F.mse_loss(out, labels.to(torch.float), reduction = 'mean')
    # loss.backward()
    return loss

In [198]:
def train_fn(train_loader, train_labels_loader, reg, opt): 
    reg.train()
    total_loss = 0
    for idx, (batch, labels) in enumerate(zip(train_loader, train_labels_loader)):
        loss = train_step(batch, labels, reg)
        total_loss += loss.item()
        
    torch.nn.utils.clip_grad_norm_(reg.parameters(), 1)
    opt.step()
    opt.zero_grad()
    return total_loss/len(train_loader)

In [199]:
def valid_fn(valid_loader, valid_labels_loader, reg): 
    reg.train()
    total_loss = 0
    with torch.no_grad():
        for idx, (batch, labels) in enumerate(zip(valid_loader, valid_labels_loader)):
            loss = valid_step(batch, labels, reg)
            total_loss += loss.item()
    
    total_loss /= len(valid_loader)
    return total_loss

In [200]:
reg = MLP_regressor(atom_features = 2, fp_size = 2048, hidden_size = 100)
optimizer = torch.optim.SGD(reg.parameters(), lr = 0.001, weight_decay = 0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 100)

total_epochs = 1000

history = dict()
tr_l, va_l = list(), list()
for epoch in range(1, total_epochs+1): 
    train_loss = train_fn(train_loader, train_labels_loader, reg, opt= optimizer)
    valid_loss = valid_fn(valid_loader, valid_labels_loader, reg) 
    
    scheduler.step(valid_loss)
    
    if epoch % 10 == 0:
        print(f'Epoch:{epoch}   train_loss: {train_loss}   valid_loss: {valid_loss}')
        
    tr_l.append(train_loss)
    va_l.append(valid_loss)


Epoch:10   train_loss: 1.0028532254709417   valid_loss: 0.314499016242304
Epoch:20   train_loss: 1.000763944407046   valid_loss: 0.32039252947135155
Epoch:30   train_loss: 0.9989624869765363   valid_loss: 0.32534536729049013
Epoch:40   train_loss: 0.9987429508688881   valid_loss: 0.3307444569233523
Epoch:50   train_loss: 0.9973115894940402   valid_loss: 0.33397750943664795
Epoch:60   train_loss: 0.9958178968377804   valid_loss: 0.3380616707325221
Epoch:70   train_loss: 0.9951895318876202   valid_loss: 0.34095786681064055
Epoch:80   train_loss: 0.9938792334362349   valid_loss: 0.34409069778676465
Epoch:90   train_loss: 0.9929891601320134   valid_loss: 0.3444661681181682
Epoch:100   train_loss: 0.9915923232208053   valid_loss: 0.3461741656122429
Epoch:110   train_loss: 0.99160590018845   valid_loss: 0.3450323944497014
Epoch:120   train_loss: 0.9915192742069734   valid_loss: 0.3455415155999378
Epoch:130   train_loss: 0.9920001698180683   valid_loss: 0.3446881283579339
Epoch:140   train_lo

In [205]:
# summarize history for accuracy
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import style
style.use('fivethirtyeight')
sns.set(style='whitegrid',color_codes=True)

# summarize history for loss
plt.figure(figsize = (20,4))
plt.plot(tr_l)
plt.plot(va_l)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.savefig('./loss_1.png')