# Predict EE (Enantiomeric Excess)

*Zhongying Ru  zilla_ru@zju.edu.cn  Apr.29th*


### Input

SMILES: Simplified Molecular Input Line Entry System

- **SM1**
- **SM2**
- SM_metal (optional)
- SM_ligand
- SM_solvent (optional)

### Pipeline

1. SMILES string --RDKit--> molecule graph --GNN--> vector representation

2. concatenate the vector representations of the 5 input variables

3. fully-connected layers or CNNs (*todo*)


### Output

- predicted EE

### Loss Function

#### predict in an end-to-end manner
Measure the summed difference between the predicted EE and the true EE over the **training samples**.

### Train/Valid/Test set division

Either split the data randomly at approximately 8:1:1,
or
use the predefined train/test files.

In [5]:
# prepare packages

from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import dgl
from rdkit import Chem
from torch.utils.data import DataLoader
from dgllife.utils import mol_to_complete_graph, mol_to_bigraph
from dgllife.utils import atom_type_one_hot
from dgllife.utils import atom_degree_one_hot
from dgllife.utils import atom_formal_charge
from dgllife.utils import atom_num_radical_electrons
from dgllife.utils import atom_hybridization_one_hot
from dgllife.utils import atom_total_num_H_one_hot
from dgllife.utils import CanonicalAtomFeaturizer
from dgllife.utils import CanonicalBondFeaturizer
from dgllife.utils import ConcatFeaturizer
from dgllife.utils import BaseAtomFeaturizer
from dgllife.utils import BaseBondFeaturizer

In [6]:
# prepare data
import csv
from itertools import islice
# CSV inputs parsed by build_mol_dict: one header line, then rows of
# (SM1, SM2, metal, ligand, solvent, ee).
train_file = 'data/train_no_metal.csv'
test_file = 'data/test_no_metal.csv'
# Destination of the SMILES -> id dictionary written with np.save.
# NOTE(review): 'n0_metal' (digit zero) may be a typo for 'no_metal' -- confirm
# before renaming, since this path is used at runtime.
mol_id_dict_filename = 'data/n0_metal_sm_to_id.npy'

## Preprocess: map molecule SMILES to id

In [7]:
# load train & test files
# format (SM1, SM2, metal, ligand, solvent, ee)

# Global SMILES -> integer id mapping, shared by every file passed to build_mol_dict.
sm_to_id_dict = dict()
# sm_cnt is the next free id; the per-role counters record how many new SMILES
# were first seen in the SM1 / SM2 / metal / ligand / solvent column respectively.
sm_cnt, sm1_cnt, sm2_cnt, met_cnt, lig_cnt, sol_cnt = 0, 0, 0, 0, 0, 0


def build_mol_dict(fname):
    """Index the molecules of one CSV file and save its rows as integer ids.

    The first line of the file is skipped as a header. Every following
    non-blank row must hold exactly six fields:
    ``(SM1, SM2, metal, ligand, solvent, ee)``. Each SMILES string that is
    not yet in the global ``sm_to_id_dict`` receives the next free id, and
    the global counters are updated. An empty metal field is not registered
    and is encoded as ``-1``.

    The encoded rows ``(sm1_id, sm2_id, met_id, lig_id, sol_id, ee)`` are
    written with ``np.save`` next to the input, using the input path with
    its extension replaced by ``_id.npy``.

    Parameters
    ----------
    fname : str
        Path of the CSV file to read.

    Raises
    ------
    ValueError
        If a non-blank row does not have six fields or ``ee`` is not a number.
    """
    global sm_to_id_dict, sm_cnt, sm1_cnt, sm2_cnt, met_cnt, lig_cnt, sol_cnt
    # Local import: the notebook's import cells do not bring `os` into scope.
    import os

    def register(smiles):
        """Give an unseen SMILES the next id; return 1 if it was new, else 0."""
        global sm_cnt
        if smiles in sm_to_id_dict:
            return 0
        sm_to_id_dict[smiles] = sm_cnt
        sm_cnt += 1
        return 1

    id_id_ee_list = []
    # newline='' lets the csv module do its own newline handling.
    with open(fname, 'r', newline='') as f:
        reader = csv.reader(f)
        for row in islice(reader, 1, None):  # skip the header row
            if not row:
                # csv.reader yields [] for a blank line; ignore it rather than
                # failing on the unpacking below.
                continue
            sm1, sm2, met, lig, sol, ee = row
            # Registration order matters: ids are handed out in column order.
            sm1_cnt += register(sm1)
            sm2_cnt += register(sm2)
            if met != '':
                met_cnt += register(met)
            lig_cnt += register(lig)
            sol_cnt += register(sol)
            met_id = sm_to_id_dict[met] if met != '' else -1
            id_id_ee_list.append((sm_to_id_dict[sm1], sm_to_id_dict[sm2], met_id,
                                  sm_to_id_dict[lig], sm_to_id_dict[sol], float(ee)))
    # splitext instead of fname[:-4], so extensions of any length are handled.
    np.save(os.path.splitext(fname)[0] + '_id.npy', id_id_ee_list)
            
# Index the training file first, then the test file; both calls extend the same
# global sm_to_id_dict, so ids are assigned in that order.
build_mol_dict(train_file)
build_mol_dict(test_file)
# Report the global counters accumulated over both files.
print(f'sm_cnt = {sm_cnt}, sm1_cnt = {sm1_cnt}, sm2_cnt = {sm2_cnt}, \
metal_cnt = {met_cnt}, ligand_cnt = {lig_cnt}, solvent_cnt = {sol_cnt}')

# save `sm_to_id_dict` to file (np.save stores the dict as a pickled object array)
np.save(mol_id_dict_filename, sm_to_id_dict)
# to reload: np.load(mol_id_dict_filename, allow_pickle=True).item()
# (allow_pickle=True is required on NumPy >= 1.16.3)

sm_cnt = 407, sm1_cnt = 184, sm2_cnt = 43, metal_cnt = 0, ligand_cnt = 144, solvent_cnt = 36


## Model Settings

In [8]:
# Switches for the optional per-role inputs.
# NOTE(review): none of these flags is read by the code visible in this notebook;
# presumably they select which of the metal / ligand / solvent embeddings are
# used -- confirm against the training code.
use_metal_emb = False
use_ligand_emb = True
use_solvent_emb = False


## Load Train set

In [2]:
# input: SMILES strings of molecules
# intermediate output: molecule representation
import dgllife.model.gnn.attentivefp as AFP
AFP

Using backend: pytorch


<module 'dgllife.model.gnn.attentivefp' from '/Users/zilla/py/anaconda3/envs/dgl_lifesci/lib/python3.6/site-packages/dgllife/model/gnn/attentivefp.py'>

## Define an AttentiveFP model

In [None]:
class AFP_EE_Predictor(nn.Module):
    """
    an end-to-end model based on AttentiveFP for regression

    Every molecule graph is encoded with an AttentiveFP GNN and readout; the
    graph vectors of the molecules taking part in one reaction sample are
    concatenated and mapped to the prediction by ``EE_Predictor``.

    AttentiveFP is introduced in
    `Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
    Attention Mechanism. <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__

    AttentiveFP Parameters
    ----------
    node_feat_size : int
        Size for the input node features.
    edge_feat_size : int
        Size for the input edge features.
    num_layers : int
        Number of GNN layers. Default to 2.
    num_timesteps : int
        Times of updating the graph representations with GRU. Default to 2.
    graph_feat_size : int
        Size for the learned graph representations. Default to 200.


    Task-oriented Predictor Parameters
    ----------
    pred_input_size : int
        Input size of the prediction head. It must equal the length of the
        concatenated sample vector: 4 * graph_feat_size without a metal,
        5 * graph_feat_size with a metal. Default to 800.
    n_tasks : int
        Number of tasks, which is also the output size. Default to 1.
    dropout : float
        Probability for performing the dropout. Default to 0.
    """

    class EE_Predictor(nn.Module):
        """Prediction head: concatenated graph features -> dropout -> linear.

        Parameters
        ----------
        input_size : int
            Length of the concatenated feature vector of one sample.
        n_task : int
            Output size (named ``n_tasks`` on the enclosing class). Default to 1.
        dropout : float
            Probability for performing the dropout. Default to 0.
        """

        def __init__(self,
                     input_size,
                     n_task=1,
                     dropout=0.):
            # nn.Module must be initialised before sub-modules can be assigned.
            super().__init__()
            self.in_size = input_size
            self.layer = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(input_size, n_task)
            )

        def forward(self, g_feats, samples):
            """Assemble one feature row per sample and predict from it.

            Parameters
            ----------
            g_feats : float tensor of shape (G, F)
                One feature vector per graph; sample ids index its rows.
                Assumes row ``i`` belongs to the molecule with id ``i`` --
                TODO confirm against the code that batches the graphs.
            samples : iterable of 5-tuples
                ``(sm1_id, sm2_id, met_id, lig_id, sol_id)``. ``met_id == -1``
                means "no metal" and that entry is left out of the
                concatenation. Ids may be ints or integer-valued floats.

            Returns
            -------
            tensor of shape (len(samples), n_task)

            Raises
            ------
            ValueError
                If a sample's concatenated vector does not have ``input_size``
                entries (e.g. a sample with a metal and a 4-slot head).
            """
            rows = []
            for sm1_id, sm2_id, met_id, lig_id, sol_id in samples:
                # int() so ids loaded from a float array can still be used as indices.
                ids = [int(sm1_id), int(sm2_id), int(met_id), int(lig_id), int(sol_id)]
                if ids[2] == -1:
                    del ids[2]  # no metal: concatenate the other four only
                # torch.cat (not numpy) keeps the result in the autograd graph,
                # so the loss can back-propagate into the GNN.
                row = torch.cat([g_feats[idx] for idx in ids], dim=0)
                if row.shape[0] != self.in_size:
                    raise ValueError(
                        f'sample feature has size {row.shape[0]}, '
                        f'but the predictor expects {self.in_size}')
                rows.append(row)

            if rows:
                sample_feats = torch.stack(rows, dim=0)
            else:
                # Empty batch: keep dtype/device of g_feats and a well-formed shape.
                sample_feats = g_feats.new_zeros((0, self.in_size))
            return self.layer(sample_feats)

    def __init__(self,
                 node_feat_size,
                 edge_feat_size,
                 num_layers=2,
                 num_timesteps=2,
                 graph_feat_size=200,
                 pred_input_size=800, # 4*200 - no metal, 5*200 - metal
                 n_tasks=1,
                 dropout=0.):
        super().__init__()
        # Imported locally: the notebook only imports the attentivefp module
        # under the alias AFP, so these names are otherwise not in scope.
        from dgllife.model.gnn import AttentiveFPGNN
        from dgllife.model.readout import AttentiveFPReadout

        self.gnn = AttentiveFPGNN(node_feat_size=node_feat_size,
                                  edge_feat_size=edge_feat_size,
                                  num_layers=num_layers,
                                  graph_feat_size=graph_feat_size,
                                  dropout=dropout)
        self.readout = AttentiveFPReadout(feat_size=graph_feat_size,
                                          num_timesteps=num_timesteps,
                                          dropout=dropout)
        # The nested class is only reachable through the class/instance,
        # not as a bare name inside a method.
        self.predict = self.EE_Predictor(pred_input_size, n_tasks, dropout)

    def forward(self, g, samples, node_feats, edge_feats, get_node_weight=False):
        """Sample-level regression on top of graph-level representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
            Here, we treat all molecule graphs as a batch.

        samples : iterable of 5-tuples
            ``(sm1_id, sm2_id, met_id, lig_id, sol_id)`` per sample; the ids
            index the graphs in ``g`` and ``met_id == -1`` means no metal.

        node_feats : float32 tensor of shape (V, node_feat_size)
            Input node features. V for the number of nodes.

        edge_feats : float32 tensor of shape (E, edge_feat_size)
            Input edge features. E for the number of edges.

        get_node_weight : bool
            Whether to get the weights of atoms during readout. Default to False.

        Returns
        -------
        float32 tensor of shape (S, n_tasks)
            Prediction for every sample. S for the number of samples.

        node_weights : list of float32 tensor of shape (V, 1), optional
            This is returned when ``get_node_weight`` is ``True``.
            The list has a length ``num_timesteps`` and ``node_weights[i]``
            gives the node weights in the i-th update.
        """
        node_feats = self.gnn(g, node_feats, edge_feats)
        if get_node_weight:
            g_feats, node_weights = self.readout(g, node_feats, get_node_weight)
            return self.predict(g_feats, samples), node_weights
        else:
            g_feats = self.readout(g, node_feats, get_node_weight)
            return self.predict(g_feats, samples)


## Train and save the model

## Visualize the training process

## Load Test set

## Test and visualize