## 1. Introduction - Overview of problem and approach

## 2. Disease and Background - Visualizing the Target

## 3. Dataset - Finding relevant compounds with activities

### General Dataset Preparation

In [None]:
!conda install -c rdkit rdkit -y
!git clone https://github.com/tmacdou4/2019-nCov.git

In [None]:
import os
os.listdir("/kaggle/working/2019-nCov")

In [None]:
! grep AID /kaggle/working/2019-nCov/Data/SARS_C3_Assays.txt > /kaggle/working/2019-nCov/Data/SARS_C3_Assays_AID_only.csv

In [None]:
! sed -i 's/[^0-9]//g' /kaggle/working/2019-nCov/Data/SARS_C3_Assays_AID_only.csv
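The two shell cells above keep every line of `SARS_C3_Assays.txt` that mentions "AID" and then delete all non-digit characters, which leaves one assay ID per line. Below is a minimal pure-Python sketch of the same filter. `extract_aids` is a hypothetical helper, and the sample lines are made up because the layout of the text file is not shown here.

```python
import re


def extract_aids(lines):
    """Keep lines containing "AID" and strip every non-digit character,
    mirroring `grep AID | sed 's/[^0-9]//g'`."""
    aids = []
    for line in lines:
        if "AID" in line:
            digits = re.sub(r"[^0-9]", "", line)
            if digits:  # skip lines with no digits instead of emitting a blank line
                aids.append(digits)
    return aids


# Usage with made-up lines (the real file format is an assumption):
sample = ["AID: 1859", "Target: 3C-like proteinase", "AID", "AID 1890"]
print(extract_aids(sample))  # → ['1859', '1890']
```

One difference is deliberate: a line that contains "AID" but no digits is skipped, whereas the `sed` pipeline leaves a blank line that `csv.reader` later returns as an empty row. Like `sed`, the sketch concatenates every digit on a line, so `AID 1859 (2009)` would become `18592009`.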

In [None]:
#Imports
import rdkit
from rdkit.Chem import AllChem as Chem
from rdkit.DataStructs import cDataStructs
import numpy as np
import pandas as pd
from rdkit.Chem.Draw import IPythonConsole
import matplotlib.pyplot as plt
import os
import time
import pickle
import csv
from rdkit.Chem import QED
import random
import json
from sklearn.preprocessing import StandardScaler

In [None]:
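DATA_DIR = "/kaggle/working/2019-nCov/Data"

def get_assays(assay_path, assay_pickle_path):
    """Download each PubChem assay table listed in assay_path (one AID per line)
    and pickle the resulting list of DataFrames."""
    with open(str(assay_path)) as f:
        AIDs = [row for row in csv.reader(f) if row]  # skip blank lines
    assay_file = f"{DATA_DIR}/assay.csv"
    assays = []
    for AID in AIDs:
        # Remove the previous download so a failed request cannot leave a stale file behind
        if os.path.exists(assay_file):
            os.remove(assay_file)
        # On a Mac (no wget by default) use instead:
        # os.system(f'curl https://pubchem.ncbi.nlm.nih.gov/rest/pug/assay/aid/{AID[0]}/csv -o {assay_file}')
        os.system(f'wget https://pubchem.ncbi.nlm.nih.gov/rest/pug/assay/aid/{AID[0]}/csv -O {assay_file}')
        # Write and check the same absolute path; skip missing or empty downloads
        if os.path.exists(assay_file) and os.stat(assay_file).st_size != 0:
            assays.append(pd.read_csv(assay_file))

    with open(str(assay_pickle_path), "wb") as f:
        pickle.dump(assays, f)

def get_mols_for_assays(assays_no_mol_path, assays_with_mol_path):
    """Fetch an RDKit Mol for every PubChem CID in each assay table and
    insert them as a "Mol Object" column."""
    with open(str(assays_no_mol_path), "rb") as f:
        assays = pickle.load(f)
    cmp_file = f"{DATA_DIR}/cmp.sdf"
    for assay in assays:
        if len(assay) != 1:
            cids = list(assay[['PUBCHEM_CID']].values.astype("int32").squeeze())
            # The first rows of a PubChem assay CSV are description rows with no CID;
            # their NaN typically becomes a negative number when cast to int32,
            # so count these leading rows
            nan_counter = 0
            for cid in cids:
                if cid < 0:
                    nan_counter += 1
                else:
                    break
            cids = cids[nan_counter:]
            mols = []
            for CID in cids:
                if os.path.exists(cmp_file):
                    os.remove(cmp_file)
                # On a Mac: os.system(f'curl https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{CID}/sdf -o {cmp_file}')
                os.system(f'wget https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{CID}/sdf -O {cmp_file}')
                if os.path.exists(cmp_file) and os.stat(cmp_file).st_size != 0:
                    mols.append(Chem.SDMolSupplier(cmp_file)[0])
                else:
                    mols.append(None)

            # Pad with None so the new column lines up with the description rows
            mols = [None] * nan_counter + mols

            assay.insert(3, "Mol Object", mols)

    with open(str(assays_with_mol_path), "wb") as f:
        pickle.dump(assays, f)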
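Both functions shell out to `wget` (or `curl` on a Mac) and then check for an empty file to detect failures. A possible portable alternative, sketched below assuming only the standard library is available, uses `urllib.request.urlretrieve` with the same PubChem PUG REST URL patterns as the cell above. `assay_csv_url`, `compound_sdf_url`, and `download` are hypothetical helpers, not part of the original code.

```python
import urllib.request
from pathlib import Path

# URL patterns taken from the wget calls in the notebook
PUG_REST = "https://pubchem.ncbi.nlm.nih.gov/rest/pug"


def assay_csv_url(aid):
    """URL of the full data table for one PubChem assay (AID)."""
    return f"{PUG_REST}/assay/aid/{aid}/csv"


def compound_sdf_url(cid):
    """URL of the SDF record for one PubChem compound (CID)."""
    return f"{PUG_REST}/compound/cid/{cid}/sdf"


def download(url, dest):
    """Fetch url into dest; return True only if a non-empty file was written."""
    dest = Path(dest)
    dest.parent.mkdir(parents=True, exist_ok=True)
    try:
        urllib.request.urlretrieve(url, dest)
    except OSError:  # urllib.error.URLError and HTTPError both subclass OSError
        return False
    return dest.exists() and dest.stat().st_size > 0


# Hypothetical usage (performs a network request):
# ok = download(assay_csv_url(aid), "/kaggle/working/2019-nCov/Data/assay.csv")
```

Because `urlretrieve` raises on HTTP errors such as a 404, a failed request is reported as `False` directly instead of relying on an empty file being left behind, and no OS-specific command is needed.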

In [None]:
get_assays("/kaggle/working/2019-nCov/Data/SARS_C3_Assays_AID_only.csv", 
           "/kaggle/working/2019-nCov/Data/sars/sars_assays_no_mol.pkl")

In [None]:
get_mols_for_assays("/kaggle/working/2019-nCov/Data/sars/sars_assays_no_mol.pkl",
                    "/kaggle/working/2019-nCov/Data/sars/sars_assays.pkl")

In [None]:
#This does an HTTP GET for EVERY compound and takes a long time. It is probably better to just load the pickled datasets in the next cell.
# get_assays("/kaggle/working/2019-nCov/Data/MERS_Protease_Assays_AID_only.csv",
#            "/kaggle/working/2019-nCov/Data/mers/mers_assays_no_mol.pkl")
# get_mols_for_assays("/kaggle/working/2019-nCov/Data/mers/mers_assays_no_mol.pkl",
#                     "/kaggle/working/2019-nCov/Data/mers/mers_assays.pkl")
# get_assays("/kaggle/working/2019-nCov/Data/NS3_Protease_Assays_AID_only.csv",
#            "/kaggle/working/2019-nCov/Data/ns3/ns3_assays_no_mol.pkl")
# get_mols_for_assays("/kaggle/working/2019-nCov/Data/ns3/ns3_assays_no_mol.pkl",
#                     "/kaggle/working/2019-nCov/Data/ns3/ns3_assays.pkl")
# get_assays("/kaggle/working/2019-nCov/Data/HIV_Protease_Assays_AID_only.csv",
#            "/kaggle/working/2019-nCov/Data/hiv/hiv_assays_no_mol.pkl")
# get_mols_for_assays("/kaggle/working/2019-nCov/Data/hiv/hiv_assays_no_mol.pkl",
#                     "/kaggle/working/2019-nCov/Data/hiv/hiv_assays.pkl")

In [None]:
#This data structure is a dictionary, keyed by target, of lists of DataFrames (one per assay).
assays = {}
assays["sars"] = pickle.load(open("/kaggle/working/2019-nCov/Data/sars/sars_assays.pkl", "rb"))
assays["mers"] = pickle.load(open("/kaggle/working/2019-nCov/Data/mers/mers_assays.pkl", "rb"))
assays["ns3"] = pickle.load(open("/kaggle/working/2019-nCov/Data/ns3/ns3_assays.pkl", "rb"))
assays["hiv"] = pickle.load(open("/kaggle/working/2019-nCov/Data/hiv/hiv_assays.pkl", "rb"))

It is worth describing the different kinds of bioactivity an assay can report. Depending on what was relevant to the scientists involved in a study, different quantities are measured, and the most important thing when building this dataset is not to mix them up. We will focus on IC50, the concentration of a compound at which 50% inhibition is observed. It is normally reported as a micromolar (µM) concentration, and the lower the value, the better the compound inhibits the protein. It is important not to be tempted by the "Activity" reported in some assays, which is normally a percentage describing how strongly the compound inhibits the protein at a single given concentration. We stick with IC50 because it is information-rich: a single IC50 value is typically derived from a whole series of such single-concentration measurements. IC50 values are also more directly comparable across assays, since we don't need to standardize the test concentration.
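To make the link between single-concentration "Activity" values and IC50 concrete, the sketch below assumes a simple two-parameter Hill model (0% inhibition at zero concentration, 100% at saturation). This is a common but not universal description of dose-response data. Under that assumption inhibition is exactly 50% at the IC50, and the IC50 can be recovered from a series of percent-inhibition measurements by interpolating in logit/log space. Both function names are illustrative; real assays usually fit a four-parameter curve rather than interpolating.

```python
import math


def hill_inhibition(conc_um, ic50_um, hill=1.0):
    """Percent inhibition under a Hill model with bottom 0 % and top 100 %."""
    return 100.0 * conc_um**hill / (ic50_um**hill + conc_um**hill)


def estimate_ic50(concs_um, inhibitions_pct):
    """Estimate IC50 from a dose-response series by logit/log interpolation.

    Exact for data that follow the Hill model above. Assumes concentrations are
    sorted ascending and inhibitions lie strictly between 0 and 100.
    """
    points = list(zip(concs_um, inhibitions_pct))
    for (c1, p1), (c2, p2) in zip(points, points[1:]):
        if p1 < 50.0 <= p2:
            x1, x2 = math.log(c1), math.log(c2)
            y1 = math.log(p1 / (100.0 - p1))  # logit of the fraction inhibited
            y2 = math.log(p2 / (100.0 - p2))
            x0 = x1 - y1 * (x2 - x1) / (y2 - y1)  # where the logit crosses 0 (50 %)
            return math.exp(x0)
    return None  # 50 % inhibition is not bracketed by the tested range


concs = [0.1, 1.0, 10.0, 100.0]  # µM, made-up dose series
inhib = [hill_inhibition(c, ic50_um=2.0) for c in concs]
print(round(estimate_ic50(concs, inhib), 6))  # → 2.0
```

Each entry of `inhib` is the kind of single-concentration "Activity" value described above; only the whole series pins down the IC50.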

For this report we will focus on the "PubChem Standard Value" column, which holds the activity expressed in a standardized metric (named in the "Standard Type" column). We will further narrow the data to only the metrics we want.
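As a dependency-free illustration of this selection rule, the sketch below treats each row as a dict keyed by the column names the notebook uses ("Standard Type", "PubChem Standard Value"), keeps only the metric types the notebook accepts (IC50, Ki, Kd, IC90), and applies the same < 0.1 cutoff used later. `filter_records` is a hypothetical helper. Unlike the notebook, which reads the Standard Type from the last row of each assay, it checks every row individually.

```python
ACCEPTED_TYPES = {"IC50", "Ki", "Kd", "IC90"}  # metric types kept by the notebook


def filter_records(records, threshold=0.1, accepted_types=ACCEPTED_TYPES):
    """Keep rows whose metric type is accepted and whose standard value is below threshold."""
    kept = []
    for rec in records:
        if rec.get("Standard Type") not in accepted_types:
            continue
        try:
            value = float(rec["PubChem Standard Value"])
        except (KeyError, TypeError, ValueError):
            continue  # missing or non-numeric value
        if value < threshold:  # NaN compares False, so NaN rows are dropped too
            kept.append(rec)
    return kept


rows = [  # made-up example rows
    {"Standard Type": "IC50", "PubChem Standard Value": "0.05"},
    {"Standard Type": "IC50", "PubChem Standard Value": "3.2"},
    {"Standard Type": "EC50", "PubChem Standard Value": "0.01"},
    {"Standard Type": "Ki", "PubChem Standard Value": None},
    {"Standard Type": "Kd", "PubChem Standard Value": "nan"},
    {"Standard Type": "Ki", "PubChem Standard Value": 0.002},
]
print(len(filter_records(rows)))  # → 2
```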

In [None]:
# #This removes all the assays that do not have a column called "PubChem Standard Value"
# for a in ["sars", "mers", "ns3", "hiv"]:
#     print("Length of",str(a),"before removing")
#     print(len(assays[a]))
#     assays[a] = np.array(assays[a])
#     bad_list = []
#     good_list = []
#     for i in range(len(assays[a])):
#         ic50_cols = [col for col in assays[a][i].columns if 'PubChem Standard Value' in col]
#         if not ic50_cols:
#             bad_list.append(i)
#         else:
#             good_list.append(int(i))

#     bad_list = np.array(bad_list)
#     good_list = np.array(good_list, dtype='int32')

#     assays[a] = assays[a][good_list]
#     print("Length of",str(a),"after removing")
#     print(len(assays[a]))

In [None]:
# #Remove unnecessary columns
# for a in ["sars", "mers", "ns3", "hiv"]:
#     for i in range(len(assays[a])):
#         assays[a][i] = assays[a][i][["Mol Object", "PubChem Standard Value", "Standard Type"]]

In [None]:
# #Look at what different kinds of metrics were used
# for a in ["sars", "mers", "ns3", "hiv"]:
#     for i in range(len(assays[a])):
#         print(assays[a][i][["Standard Type"]].values[-1])

In [None]:
# #Concatenate all of the DataFrames in the dictionary into a single list,
# #keeping only the accepted metric types. We lose track of which target each assay was for
# all_dfs = []
# for a in ["sars", "mers", "ns3", "hiv"]:
#     for i in range(len(assays[a])):
#         if assays[a][i][["Standard Type"]].values[-1][0] in {"IC50", "Ki", "Kd", "IC90"}:
#             all_dfs.append(assays[a][i])

In [None]:
# #Remove header info and concatenate them
# for i in range(len(all_dfs)):
#     all_dfs[i] = all_dfs[i].iloc[4:]
# final_df = pd.concat(all_dfs)

In [None]:
# #Keep only the compounds with values below 0.1 (100 nM if the values are in µM), so all remaining compounds are relatively active
# final_df['PubChem Standard Value'] = final_df['PubChem Standard Value'].astype(float)
# df = final_df[final_df["PubChem Standard Value"] < 0.1]

# pickle.dump(df, open("/kaggle/working/2019-nCov/Data/final_df.pkl", "wb"))

### Method-Specific Preparation

Now we move on to preparing the dataset for use in both the predictive model and the generative model.

In [None]:
df = pickle.load(open("/kaggle/working/2019-nCov/Data/final_df.pkl", "rb"))

In [None]:
df.insert(3, 'smiles', [Chem.MolToSmiles(x) for x in df[['Mol Object']].values[:,0]], True)
df.insert(4, 'qed', [QED.qed(x) for x in df[['Mol Object']].values[:,0]], True)

In [None]:
salt_indexes = []
for i in range(len(df)):
    if "." in df[["smiles"]].values[i][0]:
        salt_indexes.append(i)

In [None]:
df = df.reset_index()
df = df.drop(df.index[salt_indexes])

#This is the input format for the generative model (the CGVAE):
#SMILES strings and QED values, with validation indices defined in a separate JSON file
df[["smiles", "qed"]].to_csv("/kaggle/working/2019-nCov/Data/250k_rndm_zinc_drugs_clean_3.csv", index=False)
new_valid_idx = random.sample(range(len(df)), int(len(df)*0.1))
new_valid_idx.sort()
with open("/kaggle/working/2019-nCov/Data/valid_idx_zinc.json", 'w') as f:
    json.dump(new_valid_idx, f)
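The two steps in the cell above can be sketched without pandas. The first drops multi-fragment SMILES, where a "." separates disconnected fragments such as a salt counter-ion. The second samples 10% of the row positions as sorted validation indices. Passing a seed to `random.Random` is an assumption added here for reproducibility; the notebook itself calls the unseeded global `random.sample`, so its split changes on every run. Both helper names are hypothetical.

```python
import json
import random


def drop_multi_fragment(smiles_list):
    """Drop SMILES containing '.', i.e. salts or mixtures of disconnected fragments."""
    return [s for s in smiles_list if "." not in s]


def sample_valid_idx(n_rows, frac=0.1, seed=None):
    """Sorted random sample of int(n_rows * frac) row positions for validation."""
    rng = random.Random(seed)
    idx = rng.sample(range(n_rows), int(n_rows * frac))
    idx.sort()
    return idx


smiles = drop_multi_fragment(["CCO", "CC(=O)[O-].[Na+]", "c1ccccc1O"])
print(smiles)  # → ['CCO', 'c1ccccc1O']
print(json.dumps(sample_valid_idx(30, seed=0)))  # three sorted indices in [0, 30)
```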

In [None]:
#The values are very small (all below 0.1), so it's more effective to work in log space: log_std = -log10(value)
df.insert(2, 'log_std', [-np.log10(x) for x in df[['PubChem Standard Value']].values[:,0]], True)
#plt.hist(df[['log_std']].values[:,0], bins=25)
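For reference, the conventional pIC50 is -log10 of the IC50 in molar units. Assuming, as stated earlier, that the standard values are in micromolar, the `log_std` column is pIC50 (or pKi/pKd for those metric types) shifted by a constant 6. The shift preserves the ranking and is removed anyway by the mean-centering in the next cell. The helper names below are illustrative.

```python
import math


def neg_log10(value_um):
    """The transform used to build `log_std`: -log10 of the raw value (assumed µM)."""
    return -math.log10(value_um)


def pic50(value_um):
    """Conventional pIC50: -log10 of the concentration expressed in mol/L."""
    return -math.log10(value_um * 1e-6)


# The 0.1 µM (100 nM) cutoff corresponds to log_std > 1, i.e. pIC50 > 7
for v in (0.1, 0.01, 0.001):
    print(v, round(neg_log10(v), 3), round(pic50(v), 3))
```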

In [None]:
#Scale the values to have 0 mean and unit variance
scaler = StandardScaler()
scaler.fit(df[['log_std']].values[:,0].reshape(-1, 1))
df.insert(2, 'log_std_scaled', scaler.transform(df[['log_std']].values[:,0].reshape(-1, 1)), True)
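`StandardScaler` subtracts the mean and divides by the population standard deviation (ddof=0). A minimal standard-library sketch of fitting, transforming, and inverting that scaling is below; the inverse is what turns the predictive model's outputs back into `log_std` units. The function names mirror scikit-learn's methods for readability but are hypothetical standalone helpers. The fallback to a scale of 1.0 for zero variance matches scikit-learn's handling of constant features.

```python
import statistics


def fit_scaler(values):
    """Return (mean, scale) as StandardScaler would: population std, or 1.0 if it is zero."""
    mean = statistics.fmean(values)
    scale = statistics.pstdev(values)  # ddof=0, like StandardScaler
    return mean, (scale if scale > 0 else 1.0)


def transform(values, mean, scale):
    return [(v - mean) / scale for v in values]


def inverse_transform(scaled, mean, scale):
    """Map scaled values (e.g. model predictions) back to log_std units."""
    return [z * scale + mean for z in scaled]


log_std = [1.2, 2.0, 2.5, 3.1]  # made-up -log10 values
mean, scale = fit_scaler(log_std)
z = transform(log_std, mean, scale)
print([round(v, 3) for v in inverse_transform(z, mean, scale)])  # → [1.2, 2.0, 2.5, 3.1]
```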

In [None]:
#Format the data for the predictive model
data = df[["smiles", "log_std_scaled"]].values
np.random.shuffle(data)
train, valid, test = np.split(data, [int(.8*data.shape[0]), int(.9*data.shape[0])])
train = np.insert(train, 0, [None]*train.shape[0], 1)
valid = np.insert(valid, 0, [None]*valid.shape[0], 1)
test = np.insert(test, 0, [None]*test.shape[0], 1)
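The shuffle-and-split above yields roughly 80% training, 10% validation, and 10% test rows, with cut points at int(0.8 * n) and int(0.9 * n). A list-based sketch of the same logic follows. `split_80_10_10` and `add_leading_none` are hypothetical names, the optional seed is an assumption (the notebook uses the unseeded `np.random.shuffle`), and the leading `None` field simply reproduces what the notebook's `np.insert` calls put in front of each row.

```python
import random


def split_80_10_10(rows, seed=None):
    """Shuffle rows and cut them at int(0.8 * n) and int(0.9 * n), as np.split does above."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)
    n = len(rows)
    cut1, cut2 = int(0.8 * n), int(0.9 * n)
    return rows[:cut1], rows[cut1:cut2], rows[cut2:]


def add_leading_none(rows):
    """Prepend a None field to every row, like np.insert(arr, 0, [None] * n, 1)."""
    return [(None, *row) for row in rows]


data = [("C" * (i + 1), i / 10) for i in range(20)]  # made-up (smiles, score) pairs
train_rows, valid_rows, test_rows = split_80_10_10(data, seed=0)
print(len(train_rows), len(valid_rows), len(test_rows))  # → 16 2 2
```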

In [None]:
#Save the data for the predictive model
pd.DataFrame(train).to_csv("/kaggle/working/2019-nCov/Data/protease_train.csv.gz",
                           index=False, compression='gzip', sep='\t')
pd.DataFrame(valid).to_csv("/kaggle/working/2019-nCov/Data/protease_valid.csv.gz",
                           index=False, compression='gzip', sep='\t')
pd.DataFrame(test).to_csv("/kaggle/working/2019-nCov/Data/protease_test.csv.gz",
                          index=False, compression='gzip', sep='\t')
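With `index=False` and the default integer column labels, each file written above should hold a header row of column labels followed by tab-separated rows whose first field is empty (pandas writes `None` as an empty string), then the SMILES and the scaled value. The sketch below reads such a file back with only the standard library as a quick sanity check before training. `read_split` is a hypothetical helper, and the layout is inferred from the `to_csv` call rather than checked against the model's own loader.

```python
import csv
import gzip


def read_split(path):
    """Read one protease_*.csv.gz file into (smiles, scaled_value) tuples.

    Assumes the layout produced by the to_csv call above: a header row of
    column labels, then tab-separated rows of (empty, smiles, value).
    """
    with gzip.open(path, "rt", newline="") as f:
        reader = csv.reader(f, delimiter="\t")
        next(reader)  # skip the header row ("0", "1", "2")
        return [(row[1], float(row[2])) for row in reader]


# Usage:
# train_rows = read_split("/kaggle/working/2019-nCov/Data/protease_train.csv.gz")
```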

## 4. Bioactivity Prediction - Edge Memory Neural Network

This approach is detailed in the paper "Building Attention and Edge Convolution Neural Networks for Bioactivity and Physical-Chemical Property Prediction", available here: https://chemrxiv.org/articles/Building_Attention_and_Edge_Convolution_Neural_Networks_for_Bioactivity_and_Physical-Chemical_Property_Prediction/9873599, with code available here: https://github.com/edvardlindelof/graph-neural-networks-for-drug-discovery. Other than moving it all into this notebook, the code is only lightly modified from that repository.

The authors propose several new network architectures designed to learn more effectively from molecular graph data. Their implementation is written in PyTorch, and I got it working with the requirements described in the install.sh script.
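A recurring building block in these architectures is attention over a node's neighbours with the padding slots masked out. `aggregate_message` adds a large negative number (`BIG_NEGATIVE = -1e6`) to the energies of non-existent neighbours before the softmax, so they receive essentially zero weight. The pure-Python sketch below shows that idea for a single node, simplified to one scalar energy per neighbour (the notebook's `AttentionENNS2V` computes one energy per message dimension). `masked_attention_aggregate` is an illustrative name, not part of the original code.

```python
import math

BIG_NEGATIVE = -1e6  # same constant as in aggregate_message


def masked_attention_aggregate(messages, energies, mask):
    """Attention-weighted sum of neighbour messages for a single node.

    messages: one vector (list of floats) per neighbour slot
    energies: one scalar attention energy per neighbour slot
    mask:     1 if the neighbour slot is real, 0 if it is padding
    """
    masked = [e + (1 - m) * BIG_NEGATIVE for e, m in zip(energies, mask)]
    top = max(masked)  # subtract the max for a numerically stable softmax
    exps = [math.exp(e - top) for e in masked]
    total = sum(exps)
    weights = [x / total for x in exps]
    dim = len(messages[0])
    aggregated = [sum(w * msg[k] for w, msg in zip(weights, messages)) for k in range(dim)]
    return aggregated, weights


msg, w = masked_attention_aggregate(
    messages=[[1.0, 0.0], [0.0, 1.0], [100.0, 100.0]],  # third slot is padding
    energies=[0.0, 0.0, 5.0],
    mask=[1, 1, 0],
)
print(msg, w)  # → [0.5, 0.5] [0.5, 0.5, 0.0]
```

Even though the padding slot has the highest raw energy and a large message, the mask removes it completely from the aggregate.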

Outline of this section: first we train the model on the dataset assembled in section 3. We then use the trained model to screen commercially available molecule libraries, which are far too large to run docking studies on every compound. In this way, the machine learning model serves to "thin the herd", identifying a manageable set of candidates for docking studies.
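The screening step amounts to scoring every library compound and keeping the top-ranked few for docking. Because the training target is the standardized -log10 of the activity value, a higher prediction means a more potent compound. In the sketch, `predict` stands in for the trained model wrapped as a SMILES-to-score function, which is an assumption; `heapq.nlargest` keeps only k candidates in memory, so a large library can be streamed from disk.

```python
import heapq


def shortlist_for_docking(smiles_iter, predict, k=100):
    """Score each molecule with `predict` and return the k highest as (score, smiles).

    `predict` is a hypothetical SMILES -> float wrapper around the trained model;
    higher scores mean more potent because the target is -log10 of the activity value.
    """
    scored = ((predict(s), s) for s in smiles_iter)
    return heapq.nlargest(k, scored)


# Toy stand-in for the model: longer SMILES score higher
print(shortlist_for_docking(["C", "CCO", "CCCCC", "CC"], predict=len, k=2))
# → [(5, 'CCCCC'), (3, 'CCO')]
```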

### Training the model

In [None]:
##################  aggregation.py
import torch
from torch import nn


class AggregationMPNN(nn.Module):

    def __init__(self, node_features, edge_features, message_size, message_passes, out_features):
        super(AggregationMPNN, self).__init__()
        self.node_features = node_features
        self.edge_features = edge_features
        self.message_size = message_size
        self.message_passes = message_passes
        self.out_features = out_features

    # nodes (total number of nodes in batch, number of features)
    # node_neighbours (total number of nodes in batch, max node degree, number of features)
    # edges (total number of nodes in batch, max node degree, number of edge features)
    # mask (total number of nodes in batch, max node degree); an element is 1 if the corresponding neighbour exists
    def aggregate_message(self, nodes, node_neighbours, edges, mask):
        raise NotImplementedError

    # inputs are "batches" of shape (maximum number of nodes in batch, number of features)
    def update(self, nodes, messages):
        raise NotImplementedError

    # inputs are "batches" of same shape as the nodes passed to update
    # node_mask is same shape as inputs and is 1 if elements corresponding exists, otherwise 0
    def readout(self, hidden_nodes, input_nodes, node_mask):
        raise NotImplementedError

    def forward(self, adjacency, nodes, edges):
        edge_batch_batch_indices, edge_batch_node_indices, edge_batch_neighbour_indices = adjacency.nonzero().unbind(-1)

        node_batch_batch_indices, node_batch_node_indices = adjacency.sum(-1).nonzero().unbind(-1)
        node_batch_adj = adjacency[node_batch_batch_indices, node_batch_node_indices, :]

        node_batch_size = node_batch_batch_indices.shape[0]
        node_degrees = node_batch_adj.sum(-1).long()
        max_node_degree = node_degrees.max()
        node_batch_node_neighbours = torch.zeros(node_batch_size, max_node_degree, self.node_features)
        node_batch_edges = torch.zeros(node_batch_size, max_node_degree, self.edge_features)

        node_batch_neighbour_neighbour_indices = torch.cat([torch.arange(i) for i in node_degrees])

        edge_batch_node_batch_indices = torch.cat(
            [i * torch.ones(degree) for i, degree in enumerate(node_degrees)]
        ).long()

        node_batch_node_neighbour_mask = torch.zeros(node_batch_size, max_node_degree)

        if next(self.parameters()).is_cuda:
            node_batch_node_neighbours = node_batch_node_neighbours.cuda()
            node_batch_edges = node_batch_edges.cuda()
            node_batch_neighbour_neighbour_indices = node_batch_neighbour_neighbour_indices.cuda()
            edge_batch_node_batch_indices = edge_batch_node_batch_indices.cuda()
            node_batch_node_neighbour_mask = node_batch_node_neighbour_mask.cuda()

        node_batch_node_neighbour_mask[edge_batch_node_batch_indices, node_batch_neighbour_neighbour_indices] = 1

        node_batch_edges[edge_batch_node_batch_indices, node_batch_neighbour_neighbour_indices, :] = \
            edges[edge_batch_batch_indices, edge_batch_node_indices, edge_batch_neighbour_indices, :]

        hidden_nodes = nodes.clone()

        for i in range(self.message_passes):
            node_batch_nodes = hidden_nodes[node_batch_batch_indices, node_batch_node_indices, :]
            node_batch_node_neighbours[edge_batch_node_batch_indices, node_batch_neighbour_neighbour_indices, :] = \
                hidden_nodes[edge_batch_batch_indices, edge_batch_neighbour_indices, :]

            messages = self.aggregate_message(
                node_batch_nodes, node_batch_node_neighbours.clone(), node_batch_edges, node_batch_node_neighbour_mask
            )
            hidden_nodes[node_batch_batch_indices, node_batch_node_indices, :] = self.update(node_batch_nodes, messages)

        node_mask = (adjacency.sum(-1) != 0)  # .unsqueeze(-1).expand_as(nodes)
        output = self.readout(hidden_nodes, nodes, node_mask)
        return output


#############  aggregation_mpnn_implementation.py


class AttentionENNS2V(AggregationMPNN):

    def __init__(self, node_features, edge_features, message_size, message_passes, out_features,
                 enn_depth=3, enn_hidden_dim=200, enn_dropout_p=0,
                 att_depth=3, att_hidden_dim=200, att_dropout_p=0,
                 s2v_lstm_computations=12, s2v_memory_size=50,
                 out_depth=1, out_hidden_dim=200, out_dropout_p=0):
        super(AttentionENNS2V, self).__init__(
            node_features, edge_features, message_size, message_passes, out_features
        )
        self.enn = FeedForwardNetwork(
            edge_features, [enn_hidden_dim] * enn_depth, node_features * message_size, dropout_p=enn_dropout_p
        )
        self.att_enn = FeedForwardNetwork(
            node_features + edge_features, [att_hidden_dim] * att_depth, message_size, dropout_p=att_dropout_p
        )
        self.gru = nn.GRUCell(input_size=message_size, hidden_size=node_features, bias=False)
        self.s2v = Set2Vec(node_features, s2v_lstm_computations, s2v_memory_size)
        self.out_nn = FeedForwardNetwork(
            s2v_memory_size * 2, [out_hidden_dim] * out_depth, out_features, dropout_p=out_dropout_p, bias=False
        )

    def aggregate_message(self, nodes, node_neighbours, edges, mask):
        BIG_NEGATIVE = -1e6
        max_node_degree = node_neighbours.shape[1]

        enn_output = self.enn(edges)
        matrices = enn_output.view(-1, max_node_degree, self.message_size, self.node_features)
        message_terms = torch.matmul(matrices, node_neighbours.unsqueeze(-1)).squeeze()

        att_enn_output = self.att_enn(torch.cat([edges, node_neighbours], dim=2))
        energies = att_enn_output.view(-1, max_node_degree, self.message_size)
        energy_mask = (1 - mask).float() * BIG_NEGATIVE
        weights = torch.softmax(energies + energy_mask.unsqueeze(-1), dim=1)

        return (weights * message_terms).sum(1)

    def update(self, nodes, messages):
        return self.gru(messages, nodes)

    def readout(self, hidden_nodes, input_nodes, node_mask):
        graph_embeddings = self.s2v(hidden_nodes, input_nodes, node_mask)
        return self.out_nn(graph_embeddings)


class AttentionGGNN(AggregationMPNN):

    def __init__(self, node_features, edge_features, message_size, message_passes, out_features,
                 msg_depth=4, msg_hidden_dim=200, msg_dropout_p=0.0,
                 att_depth=3, att_hidden_dim=200, att_dropout_p=0,
                 gather_width=100,
                 gather_att_depth=3, gather_att_hidden_dim=100, gather_att_dropout_p=0.0,
                 gather_emb_depth=3, gather_emb_hidden_dim=100, gather_emb_dropout_p=0.0,
                 out_depth=2, out_hidden_dim=100, out_dropout_p=0.0, out_layer_shrinkage=1.0):
        super(AttentionGGNN, self).__init__(node_features, edge_features, message_size, message_passes, out_features)

        self.msg_nns = nn.ModuleList()
        self.att_nns = nn.ModuleList()
        for _ in range(edge_features):
            self.msg_nns.append(
                FeedForwardNetwork(node_features, [msg_hidden_dim] * msg_depth, message_size, dropout_p=msg_dropout_p,
                                   bias=False)
            )
            self.att_nns.append(
                FeedForwardNetwork(node_features, [att_hidden_dim] * att_depth, message_size, dropout_p=att_dropout_p,
                                   bias=False)
            )
        self.gru = nn.GRUCell(input_size=message_size, hidden_size=node_features, bias=False)
        self.gather = GraphGather(
            node_features, gather_width,
            gather_att_depth, gather_att_hidden_dim, gather_att_dropout_p,
            gather_emb_depth, gather_emb_hidden_dim, gather_emb_dropout_p
        )
        out_layer_sizes = [  # example: depth 5, dim 50, shrinkage 0.5 => out_layer_sizes [50, 42, 35, 30, 25]
            round(out_hidden_dim * (out_layer_shrinkage ** (i / (out_depth - 1 + 1e-9)))) for i in range(out_depth)
        ]
        self.out_nn = FeedForwardNetwork(gather_width, out_layer_sizes, out_features, dropout_p=out_dropout_p)

    def aggregate_message(self, nodes, node_neighbours, edges, node_neighbour_mask):
        energy_mask = (node_neighbour_mask == 0).float() * 1e6
        # each element of embeddings_masked_per_edge (and energies_masked_per_edge) is a
        # (batch_size, max_n_neighbours, message_size)-shape tensor with 0s in every row except
        # the ones whose edge type matches the list index
        # the intuitive way of writing this uses a torch.stack along the batch dimension and is immensely slow
        embeddings_masked_per_edge = [
            edges[:, :, i].unsqueeze(-1) * self.msg_nns[i](node_neighbours) for i in range(self.edge_features)
        ]
        embedding = sum(embeddings_masked_per_edge)
        energies_masked_per_edge = [
            edges[:, :, i].unsqueeze(-1) * self.att_nns[i](node_neighbours) for i in range(self.edge_features)
        ]
        energies = sum(energies_masked_per_edge) - energy_mask.unsqueeze(-1)
        attention = torch.softmax(energies, dim=1)
        return torch.sum(attention * embedding, dim=1)

    def update(self, nodes, messages):
        return self.gru(messages, nodes)

    def readout(self, hidden_nodes, input_nodes, node_mask):
        graph_embeddings = self.gather(hidden_nodes, input_nodes, node_mask)
        return self.out_nn(graph_embeddings)


################  emn.py

class EMN(nn.Module):

    def __init__(self, edge_features, edge_embedding_size, message_passes, out_features):
        super(EMN, self).__init__()
        self.edge_features = edge_features
        self.edge_embedding_size = edge_embedding_size
        self.message_passes = message_passes
        self.out_features = out_features

    def preprocess_edges(self, nodes, node_neighbours, edges):
        raise NotImplementedError

    # (total number of edges in batch, edge_features) and (total number of edges in batch, max_node_degree, edge_features)
    def propagate_edges(self, edges, ingoing_edge_memories, ingoing_edges_mask):
        raise NotImplementedError

    def readout(self, hidden_nodes, input_nodes, node_mask):
        raise NotImplementedError

    # adjacency (N, n_nodes, n_nodes); edges (N, n_nodes, n_nodes, edge_features)
    def forward(self, adjacency, nodes, edges):
        # indices for finding edges in batch
        edges_b_idx, edges_n_idx, edges_nhb_idx = adjacency.nonzero().unbind(-1)

        n_edges = edges_n_idx.shape[0]
        adj_of_edge_batch_indices = adjacency.clone().long()
        r = torch.arange(n_edges) + 1  # +1 to distinguish the index 0 from 'empty' elements, subtracted few lines down
        if next(self.parameters()).is_cuda:
            r = r.cuda()
        adj_of_edge_batch_indices[edges_b_idx, edges_n_idx, edges_nhb_idx] = r

        ingoing_edges_eb_idx = (torch.cat([
            row[row.nonzero()] for row in adj_of_edge_batch_indices[edges_b_idx, edges_nhb_idx, :]
        ]) - 1).squeeze()

        edge_degrees = adjacency[edges_b_idx, edges_nhb_idx, :].sum(-1).long()
        ingoing_edges_igeb_idx = torch.cat([i * torch.ones(d) for i, d in enumerate(edge_degrees)]).long()
        ingoing_edges_ige_idx = torch.cat([torch.arange(i) for i in edge_degrees]).long()

        batch_size = adjacency.shape[0]
        n_nodes = adjacency.shape[1]
        max_node_degree = adjacency.sum(-1).max().int()
        edge_memories = torch.zeros(n_edges, self.edge_embedding_size)
        ingoing_edge_memories = torch.zeros(n_edges, max_node_degree, self.edge_embedding_size)
        ingoing_edges_mask = torch.zeros(n_edges, max_node_degree)
        if next(self.parameters()).is_cuda:
            edge_memories = edge_memories.cuda()
            ingoing_edge_memories = ingoing_edge_memories.cuda()
            ingoing_edges_mask = ingoing_edges_mask.cuda()

        edge_batch_nodes = nodes[edges_b_idx, edges_n_idx, :]
        edge_batch_neighbours = nodes[edges_b_idx, edges_nhb_idx, :]
        edge_batch_edges = edges[edges_b_idx, edges_n_idx, edges_nhb_idx, :]
        edge_batch_edges = self.preprocess_edges(edge_batch_nodes, edge_batch_neighbours, edge_batch_edges)

        # remove h_ji's influence on h_ij (an edge should not receive its own reverse edge as input)
        ingoing_edges_nhb_idx = edges_nhb_idx[ingoing_edges_eb_idx]
        ingoing_edges_receiving_edge_n_idx = edges_n_idx[ingoing_edges_igeb_idx]
        not_same_idx = (ingoing_edges_receiving_edge_n_idx != ingoing_edges_nhb_idx).nonzero()
        ingoing_edges_eb_idx = ingoing_edges_eb_idx[not_same_idx].squeeze()
        ingoing_edges_ige_idx = ingoing_edges_ige_idx[not_same_idx].squeeze()
        ingoing_edges_igeb_idx = ingoing_edges_igeb_idx[not_same_idx].squeeze()

        ingoing_edges_mask[ingoing_edges_igeb_idx, ingoing_edges_ige_idx] = 1

        for i in range(self.message_passes):
            ingoing_edge_memories[ingoing_edges_igeb_idx, ingoing_edges_ige_idx, :] = \
                edge_memories[ingoing_edges_eb_idx, :]
            edge_memories = self.propagate_edges(edge_batch_edges, ingoing_edge_memories.clone(), ingoing_edges_mask)

        node_mask = (adjacency.sum(-1) != 0)

        node_sets = torch.zeros(batch_size, n_nodes, max_node_degree, self.edge_embedding_size)
        if next(self.parameters()).is_cuda:
            node_sets = node_sets.cuda()

        edge_batch_edge_memory_indices = torch.cat(
            [torch.arange(row.sum()) for row in adjacency.view(-1, n_nodes)]
        ).long()

        node_sets[edges_b_idx, edges_n_idx, edge_batch_edge_memory_indices, :] = edge_memories
        graph_sets = node_sets.sum(2)
        output = self.readout(graph_sets, graph_sets, node_mask)

        return output


#####################  emn_implementation.py

class EMNImplementation(EMN):

    def __init__(self, node_features, edge_features, message_passes, out_features,
                 edge_embedding_size,
                 edge_emb_depth=3, edge_emb_hidden_dim=150, edge_emb_dropout_p=0.0,
                 att_depth=3, att_hidden_dim=80, att_dropout_p=0.0,
                 msg_depth=3, msg_hidden_dim=80, msg_dropout_p=0.0,
                 gather_width=100,
                 gather_att_depth=3, gather_att_hidden_dim=100, gather_att_dropout_p=0.0,
                 gather_emb_depth=3, gather_emb_hidden_dim=100, gather_emb_dropout_p=0.0,
                 out_depth=2, out_hidden_dim=100, out_dropout_p=0, out_layer_shrinkage=1.0):
        super(EMNImplementation, self).__init__(
            edge_features, edge_embedding_size, message_passes, out_features
        )
        self.embedding_nn = FeedForwardNetwork(
            node_features * 2 + edge_features, [edge_emb_hidden_dim] * edge_emb_depth, edge_embedding_size,
            dropout_p=edge_emb_dropout_p
        )

        self.emb_msg_nn = FeedForwardNetwork(
            edge_embedding_size, [msg_hidden_dim] * msg_depth, edge_embedding_size, dropout_p=msg_dropout_p
        )
        self.att_msg_nn = FeedForwardNetwork(
            edge_embedding_size, [att_hidden_dim] * att_depth, edge_embedding_size, dropout_p=att_dropout_p
        )

        # self.extra_gru_layer = nn.Linear(edge_embedding_size, edge_embedding_size, bias=False)
        self.gru = nn.GRUCell(edge_embedding_size, edge_embedding_size, bias=False)
        self.gather = GraphGather(
            edge_embedding_size, gather_width,
            gather_att_depth, gather_att_hidden_dim, gather_att_dropout_p,
            gather_emb_depth, gather_emb_hidden_dim, gather_emb_dropout_p
        )
        out_layer_sizes = [  # example: depth 5, dim 50, shrinkage 0.5 => out_layer_sizes [50, 42, 35, 30, 25]
            round(out_hidden_dim * (out_layer_shrinkage ** (i / (out_depth - 1 + 1e-9)))) for i in range(out_depth)
        ]
        self.out_nn = FeedForwardNetwork(gather_width, out_layer_sizes, out_features, dropout_p=out_dropout_p)

    def preprocess_edges(self, nodes, node_neighbours, edges):
        cat = torch.cat([nodes, node_neighbours, edges], dim=1)
        return torch.tanh(self.embedding_nn(cat))

    def propagate_edges(self, edges, ingoing_edge_memories, ingoing_edges_mask):
        BIG_NEGATIVE = -1e6
        energy_mask = ((1 - ingoing_edges_mask).float() * BIG_NEGATIVE).unsqueeze(-1)

        cat = torch.cat([edges.unsqueeze(1), ingoing_edge_memories], dim=1)
        embeddings = self.emb_msg_nn(cat)

        edge_energy = self.att_msg_nn(edges)
        ing_memory_energies = self.att_msg_nn(ingoing_edge_memories) + energy_mask
        energies = torch.cat([edge_energy.unsqueeze(1), ing_memory_energies], dim=1)
        attention = torch.softmax(energies, dim=1)

        # set aggregation of the set of the given edge feature and ingoing edge memories
        message = (attention * embeddings).sum(dim=1)
        return self.gru(message)  # nn.GRUCell keeps no state: with no hidden state passed it starts from zeros and returns the new edge memory

    def readout(self, hidden_nodes, input_nodes, node_mask):
        graph_embeddings = self.gather(hidden_nodes, input_nodes, node_mask)
        return self.out_nn(graph_embeddings)


################  graph_features.py

# this is DeepChem's source file with unused code removed, for faster import
import numpy as np
from rdkit import Chem  # note: rebinds Chem (previously AllChem) to rdkit.Chem for the rest of the notebook


def one_of_k_encoding(x, allowable_set):
    if x not in allowable_set:
        raise Exception("input {0} not in allowable set{1}:".format(
            x, allowable_set))
    return list(map(lambda s: x == s, allowable_set))


def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))


def get_intervals(l):
    """For list of lists, gets the cumulative products of the lengths"""
    intervals = len(l) * [0]
    # Initialize with 1
    intervals[0] = 1
    for k in range(1, len(l)):
        intervals[k] = (len(l[k]) + 1) * intervals[k - 1]

    return intervals


def safe_index(l, e):
    """Gets the index of e in l, providing an index of len(l) if not found"""
    try:
        return l.index(e)
    except ValueError:  # list.index raises ValueError when e is absent
        return len(l)


possible_atom_list = [
    'C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Mg', 'Na', 'Br', 'Fe', 'Ca', 'Cu',
    'Mc', 'Pd', 'Pb', 'K', 'I', 'Al', 'Ni', 'Mn'
]
possible_numH_list = [0, 1, 2, 3, 4]
possible_valence_list = [0, 1, 2, 3, 4, 5, 6]
possible_formal_charge_list = [-3, -2, -1, 0, 1, 2, 3]
possible_hybridization_list = [
    Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2
]
possible_number_radical_e_list = [0, 1, 2]
possible_chirality_list = ['R', 'S']

reference_lists = [
    possible_atom_list, possible_numH_list, possible_valence_list,
    possible_formal_charge_list, possible_number_radical_e_list,
    possible_hybridization_list, possible_chirality_list
]

intervals = get_intervals(reference_lists)


def get_feature_list(atom):
    features = 6 * [0]
    features[0] = safe_index(possible_atom_list, atom.GetSymbol())
    features[1] = safe_index(possible_numH_list, atom.GetTotalNumHs())
    features[2] = safe_index(possible_valence_list, atom.GetImplicitValence())
    features[3] = safe_index(possible_formal_charge_list, atom.GetFormalCharge())
    features[4] = safe_index(possible_number_radical_e_list,
                             atom.GetNumRadicalElectrons())
    features[5] = safe_index(possible_hybridization_list, atom.GetHybridization())
    return features


def features_to_id(features, intervals):
    """Convert list of features into index using spacings provided in intervals"""
    id = 0
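    # zip stops at the shorter list: get_feature_list returns 6 features while
    # intervals has 7 entries (it includes chirality), so indexing features[k]
    # for every k in intervals would raise an IndexError
    for feature, interval in zip(features, intervals):
        id += feature * interval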

    # Allow 0 index to correspond to null molecule 1
    id = id + 1
    return id


def atom_to_id(atom):
    """Return a unique id corresponding to the atom type"""
    features = get_feature_list(atom)
    return features_to_id(features, intervals)


def atom_features(atom,
                  bool_id_feat=False,
                  explicit_H=False,
                  use_chirality=False):
    if bool_id_feat:
        return np.array([atom_to_id(atom)])
    else:
        results = one_of_k_encoding_unk(
            atom.GetSymbol(),
            [
                'C',
                'N',
                'O',
                'S',
                'F',
                'Si',
                'P',
                'Cl',
                'Br',
                'Mg',
                'Na',
                'Ca',
                'Fe',
                'As',
                'Al',
                'I',
                'B',
                'V',
                'K',
                'Tl',
                'Yb',
                'Sb',
                'Sn',
                'Ag',
                'Pd',
                'Co',
                'Se',
                'Ti',
                'Zn',
                'H',  # H?
                'Li',
                'Ge',
                'Cu',
                'Au',
                'Ni',
                'Cd',
                'In',
                'Mn',
                'Zr',
                'Cr',
                'Pt',
                'Hg',
                'Pb',
                'Unknown'
            ]) + one_of_k_encoding(atom.GetDegree(),
                                   [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + \
                  one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) + \
                  [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] + \
                  one_of_k_encoding_unk(atom.GetHybridization(), possible_hybridization_list) + \
                  [atom.GetIsAromatic()]
        # With explicit hydrogens (e.g. QM8, QM9), skip the `GetTotalNumHs` one-hot
        if not explicit_H:
            results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(),
                                                      [0, 1, 2, 3, 4])
        if use_chirality:
            try:
                results = results + one_of_k_encoding_unk(
                    atom.GetProp('_CIPCode'),
                    ['R', 'S']) + [atom.HasProp('_ChiralityPossible')]
            except KeyError:
                # the atom has no CIP label, so both R and S flags are False
                results = results + [False, False] + [atom.HasProp('_ChiralityPossible')]

        return np.array(results)
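

# Sketch (illustrative only): how the length of the atom_features vector adds up, together with
# a plain-Python version of the "unknown" one-hot fallback. It assumes one_of_k_encoding_unk
# (defined earlier in the notebook) maps values outside the list onto its last entry, as in
# DeepChem; one_hot_with_unknown and atom_feature_length are hypothetical helpers.
def one_hot_with_unknown(value, allowable):
    """One-hot encode `value`; anything not in `allowable` goes to the last slot."""
    if value not in allowable:
        value = allowable[-1]
    return [value == option for option in allowable]


def atom_feature_length(explicit_H=False, use_chirality=False):
    """Number of entries atom_features returns when bool_id_feat=False."""
    length = (44    # element symbol, including 'Unknown'
              + 11  # degree 0..10
              + 7   # implicit valence 0..6
              + 2   # formal charge and radical electrons (raw integers)
              + 5   # hybridization SP..SP3D2
              + 1)  # aromatic flag
    if not explicit_H:
        length += 5  # total hydrogen count 0..4
    if use_chirality:
        length += 3  # R/S one-hot plus the _ChiralityPossible flag
    return length

# With the defaults this gives 75, matching node_features=75 in the commented-out training example.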


####################    modules.py
import math

import torch
from torch import nn


class GraphGather(nn.Module):
    r"""The GGNN readout function
    """

    def __init__(self, node_features, out_features,
                 att_depth=2, att_hidden_dim=100, att_dropout_p=0.0,
                 emb_depth=2, emb_hidden_dim=100, emb_dropout_p=0.0):
        super(GraphGather, self).__init__()

        # denoted i and j in GGNN, MPNN and PotentialNet papers
        self.att_nn = FeedForwardNetwork(
            node_features * 2, [att_hidden_dim] * att_depth, out_features, dropout_p=att_dropout_p, bias=False
        )
        self.emb_nn = FeedForwardNetwork(
            node_features, [emb_hidden_dim] * emb_depth, out_features, dropout_p=emb_dropout_p, bias=False
        )

    def forward(self, hidden_nodes, input_nodes, node_mask):
        cat = torch.cat([hidden_nodes, input_nodes], dim=2)
        energy_mask = (node_mask == 0).float() * 1e6
        energies = self.att_nn(cat) - energy_mask.unsqueeze(-1)
        attention = torch.sigmoid(energies)
        # attention = torch.softmax(energies, dim=1)
        embedding = self.emb_nn(hidden_nodes)
        return torch.sum(attention * embedding, dim=1)
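

# Sketch (illustrative only, plain Python, one output feature): the GraphGather readout weights
# each node's embedding by a sigmoid gate and sums over nodes. Padding nodes get 1e6 subtracted
# from their gate logit, so their gate is effectively 0 and they add nothing to the sum. The
# function names are hypothetical; the real module works on batched torch tensors.
import math


def stable_sigmoid(x):
    """Logistic function that avoids overflow for very negative inputs."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def gated_sum_readout(gate_logits, embeddings, node_mask):
    """Sum of sigmoid(gate) * embedding over nodes, with padding nodes masked out."""
    total = 0.0
    for logit, embedding, is_real in zip(gate_logits, embeddings, node_mask):
        masked_logit = logit - (0.0 if is_real else 1e6)
        total += stable_sigmoid(masked_logit) * embedding
    return total

# Because the gate is a sigmoid rather than a softmax (the softmax line is commented out above),
# the weights do not have to sum to 1 across nodes.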


class Set2Vec(nn.Module):
    r"""The readout function of MPNN paper's best network
    """

    # used to set attention terms to 0 when passing energies to softmax
    # tf code uses same trick
    BIG_NEGATIVE = -1e6

    def __init__(self, node_features, lstm_computations, memory_size):
        super(Set2Vec, self).__init__()

        self.lstm_computations = lstm_computations
        self.memory_size = memory_size

        self.embedding_matrix = nn.Linear(node_features * 2, self.memory_size, bias=False)
        self.lstm = nn.LSTMCell(self.memory_size, self.memory_size, bias=False)

    def forward(self, hidden_output_nodes, input_nodes, node_mask):
        batch_size = input_nodes.shape[0]
        device = input_nodes.device
        # (node_mask == 0) works for both bool and float masks
        energy_mask = (node_mask == 0).float() * self.BIG_NEGATIVE

        # LSTM input and state start at zero, on the same device as the inputs
        lstm_input = torch.zeros(batch_size, self.memory_size, device=device)

        cat = torch.cat([hidden_output_nodes, input_nodes], dim=2)
        memory = self.embedding_matrix(cat)

        hidden_state = torch.zeros(batch_size, self.memory_size, device=device)
        cell_state = torch.zeros(batch_size, self.memory_size, device=device)

        for i in range(self.lstm_computations):
            query, cell_state = self.lstm(lstm_input, (hidden_state, cell_state))
            # dot product query x memory
            energies = (query.view(batch_size, 1, self.memory_size) * memory).sum(dim=-1)
            attention = torch.softmax(energies + energy_mask, dim=1)
            read = (attention.unsqueeze(-1) * memory).sum(dim=1)

            hidden_state = query
            lstm_input = read

        cat = torch.cat([query, read], dim=1)
        return cat
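

# Sketch (illustrative only, plain Python, single graph): one attention read of Set2Vec. The
# LSTM output `query` is dotted with every node's memory vector, padding nodes receive the
# BIG_NEGATIVE offset so that softmax gives them ~0 weight, and the read vector is the
# attention-weighted sum of memories. masked_attention_read is a hypothetical helper.
import math


def masked_attention_read(query, memory, node_mask, big_negative=-1e6):
    energies = [sum(q * m for q, m in zip(query, row)) + (0.0 if is_real else big_negative)
                for row, is_real in zip(memory, node_mask)]
    top = max(energies)  # subtract the max for a numerically stable softmax
    weights = [math.exp(e - top) for e in energies]
    norm = sum(weights)
    attention = [w / norm for w in weights]
    read = [sum(a * row[d] for a, row in zip(attention, memory)) for d in range(len(memory[0]))]
    return attention, read

# In Set2Vec this step repeats lstm_computations times, with `read` fed back as the next LSTM
# input, and the final [query, read] concatenation is the graph embedding.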


class FeedForwardNetwork(nn.Module):
    r"""Convenience class to create network composed of linear layers with an activation function
    applied between them

    Args:
        in_features: size of each input sample
        hidden_layer_sizes: list of hidden layer sizes
        out_features: size of each output sample
        activation: 'SELU' or 'ReLU'
        bias: If set to False, the layers will not learn an additive bias.
            Default: ``False``
    """

    def __init__(self, in_features, hidden_layer_sizes, out_features, activation='SELU', bias=False, dropout_p=0.0):
        super(FeedForwardNetwork, self).__init__()

        if activation == 'SELU':
            Activation = nn.SELU
            Dropout = nn.AlphaDropout
            init_constant = 1.0
        elif activation == 'ReLU':
            Activation = nn.ReLU
            Dropout = nn.Dropout
            init_constant = 2.0

        layer_sizes = [in_features] + hidden_layer_sizes + [out_features]

        layers = []
        for i in range(len(layer_sizes) - 2):
            layers.append(Dropout(dropout_p))
            layers.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1], bias))
            layers.append(Activation())
        layers.append(Dropout(dropout_p))
        layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1], bias))

        self.seq = nn.Sequential(*layers)

        for i in range(1, len(layers), 3):
            # std = sqrt(c / fan_in): LeCun normal (c = 1, recommended in the SELU paper) or He normal (c = 2, ReLU)
            nn.init.normal_(layers[i].weight, std=math.sqrt(init_constant / layers[i].weight.size(1)))

    def forward(self, input):
        return self.seq(input)

    # Overriding extra_repr would also print self.seq (hard to read), so __repr__ is replaced instead
    def __repr__(self):
        ffnn = type(self).__name__
        in_features = self.seq[1].in_features
        hidden_layer_sizes = [linear.out_features for linear in self.seq[1:-1:3]]
        out_features = self.seq[-1].out_features
        if len(self.seq) > 2:
            activation = str(self.seq[2])
        else:
            activation = 'None'
        bias = self.seq[1].bias is not None
        dropout_p = self.seq[0].p
        return '{}(in_features={}, hidden_layer_sizes={}, out_features={}, activation={}, bias={}, dropout_p={})'.format(
            ffnn, in_features, hidden_layer_sizes, out_features, activation, bias, dropout_p
        )
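

# Sketch (illustrative only, plain Python): the layer list FeedForwardNetwork builds is
# [Dropout, Linear, Activation] for every hidden layer followed by [Dropout, Linear], so the
# Linear layers sit at indices 1, 4, 7, ..., which is what range(1, len(layers), 3) visits.
# Each weight is drawn with std = sqrt(c / fan_in), with c = 1 for SELU (LeCun normal) and
# c = 2 for ReLU (He normal). ffn_layer_plan is a hypothetical helper that only lists the plan.
import math


def ffn_layer_plan(in_features, hidden_layer_sizes, out_features, activation='SELU'):
    init_constant = {'SELU': 1.0, 'ReLU': 2.0}[activation]
    sizes = [in_features] + list(hidden_layer_sizes) + [out_features]
    plan = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        plan.append(('Dropout',))
        plan.append(('Linear', fan_in, fan_out, math.sqrt(init_constant / fan_in)))
        plan.append((activation,))
    plan.pop()  # no activation after the output layer
    return plan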


#############  molgraph_dataset.py

import gzip
import rdkit
from rdkit import Chem
from rdkit.Chem.rdchem import BondType
from torch.utils import data

from collections import defaultdict


class MolGraphDataset(data.Dataset):
    r"""For datasets consisting of SMILES strings and target values.

    Expects a csv file formatted as:
    comment,smiles,targetName1,targetName2
    Some Comment,CN=C=O,0,1
    ,CC(=O)NCCC1=CNc2c1cc(OC)cc2,1,1

    Args:
        path
        prediction: set to True if dataset contains no target values
    """

    def __init__(self, path, prediction=False):
        with gzip.open(path, 'r') as file:
            self.header_cols = file.readline().decode('utf-8')[:-2].split('\t')
        n_cols = len(self.header_cols)

        self.target_names = self.header_cols[2:]
        self.comments = np.genfromtxt(path, delimiter='\t', skip_header=1, usecols=[0], dtype=np.str, comments=None)
        # comments=None because default is "#", that some smiles contain
        self.smiles = np.genfromtxt(path, delimiter='\t', skip_header=1, usecols=[1], dtype=np.str, comments=None)
        if prediction:
            self.targets = np.empty((len(self.smiles), n_cols - 2))  # may be used to figure out number of targets etc
        else:
            self.targets = np.genfromtxt(path, delimiter='\t', skip_header=1, usecols=range(2, n_cols),
                                         comments=None).reshape(-1, n_cols - 2)

    def __getitem__(self, index):
        adjacency, nodes, edges = smile_to_graph(self.smiles[index])
        targets = self.targets[index, :]
        return (adjacency, nodes, edges), targets

    def __len__(self):
        return len(self.smiles)
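

# Sketch (illustrative only): the same gzip-compressed, tab-separated layout read with the
# standard library instead of np.genfromtxt. The csv module does not treat '#' as a comment
# marker (the problem comments=None works around above) and strips '\n' or '\r\n' line endings
# itself. read_molgraph_tsv is a hypothetical helper; it accepts a path or a binary file object,
# since gzip.open accepts both.
import csv
import gzip


def read_molgraph_tsv(path_or_file):
    with gzip.open(path_or_file, 'rt', newline='') as handle:
        rows = list(csv.reader(handle, delimiter='\t'))
    header, body = rows[0], rows[1:]
    target_names = header[2:]
    comments = [row[0] for row in body]
    smiles = [row[1] for row in body]
    targets = [[float(value) for value in row[2:]] for row in body]
    return target_names, comments, smiles, targets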


from rdkit import RDLogger

# only log RDKit messages at ERROR level or above (e.g. while parsing many SMILES)
rdLogger = RDLogger.logger()
rdLogger.setLevel(RDLogger.ERROR)


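def smile_to_graph(smile):
    molecule = Chem.MolFromSmiles(smile)
    if molecule is None:  # MolFromSmiles returns None for SMILES it cannot parse
        raise ValueError('RDKit could not parse SMILES: {}'.format(smile))
    n_atoms = molecule.GetNumAtoms()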
    atoms = [molecule.GetAtomWithIdx(i) for i in range(n_atoms)]

    adjacency = Chem.rdmolops.GetAdjacencyMatrix(molecule)
    node_features = np.array([atom_features(atom) for atom in atoms])

    n_edge_features = 4
    edge_features = np.zeros([n_atoms, n_atoms, n_edge_features])
    for bond in molecule.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        bond_type = BONDTYPE_TO_INT[bond.GetBondType()]
        edge_features[i, j, bond_type] = 1
        edge_features[j, i, bond_type] = 1

    return adjacency, node_features, edge_features


# rdkit GetBondType() result -> edge feature index; other bond types (e.g. dative) fall back to 0
BONDTYPE_TO_INT = defaultdict(
    lambda: 0,
    {
        BondType.SINGLE: 0,
        BondType.DOUBLE: 1,
        BondType.TRIPLE: 2,
        BondType.AROMATIC: 3
    }
)
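

# Sketch (illustrative only, plain Python): smile_to_graph writes each bond into both (i, j)
# and (j, i), so the adjacency matrix and the edge tensor are symmetric and every bond appears
# once per direction. Here the bonds are given by hand as (i, j, bond_type_index) with indices
# 0..3 as in BONDTYPE_TO_INT; bonds_to_edge_features is a hypothetical helper that needs no RDKit.
def bonds_to_edge_features(n_atoms, bonds, n_edge_features=4):
    adjacency = [[0] * n_atoms for _ in range(n_atoms)]
    edges = [[[0] * n_edge_features for _ in range(n_atoms)] for _ in range(n_atoms)]
    for i, j, bond_type in bonds:
        adjacency[i][j] = adjacency[j][i] = 1
        edges[i][j][bond_type] = edges[j][i][bond_type] = 1
    return adjacency, edges

# Usage: ClCCOC=C (one of the test compounds further down) has four single bonds and one double
# bond between atoms 4 and 5:
# bonds_to_edge_features(6, [(0, 1, 0), (1, 2, 0), (2, 3, 0), (3, 4, 0), (4, 5, 1)])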


class MolGraphDatasetSubset(MolGraphDataset):
    r"""Takes a subset of MolGraphDataset.

    The "Subset" class of pytorch does not allow column selection
    """

    def __init__(self, path, indices=None, columns=None):
        super(MolGraphDatasetSubset, self).__init__(path)
        if indices:
            self.smiles = self.smiles[indices]
            self.targets = self.targets[indices]
        if columns:
            self.target_names = [self.target_names[col] for col in columns]
            self.targets = self.targets[:, columns]


# `data` is a list of ((adjacency, node_features, edge_features), targets) tuples, i.e. what
# MolGraphDataset.__getitem__ returns; DataLoader passes such a list when this function is
# given as collate_fn. Every graph is zero-padded to the largest graph in the batch.
def molgraph_collate_fn(data):
    n_samples = len(data)
    (adjacency_0, node_features_0, edge_features_0), targets_0 = data[0]
    n_nodes_largest_graph = max(map(lambda sample: sample[0][0].shape[0], data))
    n_node_features = node_features_0.shape[1]
    n_edge_features = edge_features_0.shape[2]
    n_targets = len(targets_0)

    adjacency_tensor = torch.zeros(n_samples, n_nodes_largest_graph, n_nodes_largest_graph)
    node_tensor = torch.zeros(n_samples, n_nodes_largest_graph, n_node_features)
    edge_tensor = torch.zeros(n_samples, n_nodes_largest_graph, n_nodes_largest_graph, n_edge_features)
    target_tensor = torch.zeros(n_samples, n_targets)

    for i in range(n_samples):
        (adjacency, node_features, edge_features), target = data[i]
        n_nodes = adjacency.shape[0]

        adjacency_tensor[i, :n_nodes, :n_nodes] = torch.Tensor(adjacency)
        node_tensor[i, :n_nodes, :] = torch.Tensor(node_features)
        edge_tensor[i, :n_nodes, :n_nodes, :] = torch.Tensor(edge_features)

        target_tensor[i] = torch.Tensor(target)

    return adjacency_tensor, node_tensor, edge_tensor, target_tensor
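

# Sketch (illustrative only, plain Python): molgraph_collate_fn zero-pads every graph to the
# size of the largest graph in the batch. Padded rows of the adjacency matrix sum to zero,
# which is how SummationMPNN later tells real atoms from padding. pad_square and
# collate_adjacency are hypothetical helpers that handle adjacency matrices only.
def pad_square(matrix, size, fill=0):
    n = len(matrix)
    padded_rows = [list(row) + [fill] * (size - n) for row in matrix]
    return padded_rows + [[fill] * size for _ in range(size - n)]


def collate_adjacency(adjacencies):
    largest = max(len(adjacency) for adjacency in adjacencies)
    return [pad_square(adjacency, largest) for adjacency in adjacencies]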


#################  summation_mpnn.py

class SummationMPNN(nn.Module):
    r"""Abstract MPNN class, ExampleMPNN demonstrates how to extend it

    Args:
        node_features (int)
        edge_features (int)
        message_size (int)
        message_passes (int)
        out_features (int)
    """

    def __init__(self, node_features, edge_features, message_size, message_passes, out_features):
        super(SummationMPNN, self).__init__()
        self.node_features = node_features
        self.edge_features = edge_features
        self.message_size = message_size
        self.message_passes = message_passes
        self.out_features = out_features

    # inputs are "batches" of shape (total number of edges in batch, number of features)
    def message_terms(self, nodes, node_neighbours, edges):
        raise NotImplementedError

    # inputs are "batches" of shape (maximum number of nodes in batch, number of features)
    def update(self, nodes, messages):
        raise NotImplementedError

    # inputs are "batches" of same shape as the nodes passed to update
    # node_mask is same shape as inputs and is 1 if elements corresponding exists, otherwise 0
    def readout(self, hidden_nodes, input_nodes, node_mask):
        raise NotImplementedError

    def forward(self, adjacency, nodes, edges):
        edge_batch_batch_indices, edge_batch_node_indices, edge_batch_neighbour_indices = adjacency.nonzero().unbind(-1)
        node_batch_batch_indices, node_batch_node_indices = adjacency.sum(-1).nonzero().unbind(-1)

        same_batch = node_batch_batch_indices.view(-1, 1) == edge_batch_batch_indices
        same_node = node_batch_node_indices.view(-1, 1) == edge_batch_node_indices
        # element_ij = 1 if edge_batch_edges[j] is connected with node_batch_nodes[i], else 0
        message_summation_matrix = (same_batch * same_node).float()

        edge_batch_edges = edges[edge_batch_batch_indices, edge_batch_node_indices, edge_batch_neighbour_indices, :]
        hidden_nodes = nodes.clone()
        node_batch_nodes = hidden_nodes[node_batch_batch_indices, node_batch_node_indices, :]

        for i in range(self.message_passes):
            edge_batch_nodes = hidden_nodes[edge_batch_batch_indices, edge_batch_node_indices, :]
            edge_batch_neighbours = hidden_nodes[edge_batch_batch_indices, edge_batch_neighbour_indices, :]

            message_terms = self.message_terms(edge_batch_nodes, edge_batch_neighbours, edge_batch_edges)
            # the summation in eq. 1 of the MPNN paper (Gilmer et al., "Neural Message Passing for Quantum Chemistry") happens here
            messages = torch.matmul(message_summation_matrix, message_terms)
            node_batch_nodes = self.update(node_batch_nodes, messages)

            hidden_nodes[node_batch_batch_indices, node_batch_node_indices, :] = node_batch_nodes

        node_mask = (adjacency.sum(-1) != 0)  # (batch size, max nodes), True for real nodes
        output = self.readout(hidden_nodes, nodes, node_mask)
        return output
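

# Sketch (illustrative only, plain Python): how SummationMPNN.forward turns per-edge message
# terms into per-node messages. Every nonzero adjacency entry (b, i, j) is one directed edge
# into node i from neighbour j; the summation matrix has a 1 where an edge's receiving node
# equals the row's node, so multiplying it with the edge terms sums each node's incoming terms.
# Note that a node with no bonds has an all-zero adjacency row and is therefore treated like
# padding. sum_messages_per_node is a hypothetical helper.
def sum_messages_per_node(adjacency_batch, edge_terms):
    edges = [(b, i, j) for b, adjacency in enumerate(adjacency_batch)
             for i, row in enumerate(adjacency) for j, value in enumerate(row) if value]
    nodes = [(b, i) for b, adjacency in enumerate(adjacency_batch)
             for i, row in enumerate(adjacency) if any(row)]
    summation_matrix = [[1.0 if (b, i) == (eb, ei) else 0.0 for eb, ei, _ in edges]
                        for b, i in nodes]
    width = len(edge_terms[0])
    messages = [[sum(weight * term[d] for weight, term in zip(row, edge_terms)) for d in range(width)]
                for row in summation_matrix]
    return nodes, edges, messages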


##############  summation_mpnn_implementaion.py


class ENNS2V(SummationMPNN):

    def __init__(self, node_features, edge_features, message_size, message_passes, out_features,
                 enn_depth=4, enn_hidden_dim=200, enn_dropout_p=0,
                 s2v_lstm_computations=12, s2v_memory_size=50,
                 out_depth=1, out_hidden_dim=200, out_dropout_p=0):
        super(ENNS2V, self).__init__(node_features, edge_features, message_size, message_passes, out_features)

        self.enn = FeedForwardNetwork(
            edge_features, [enn_hidden_dim] * enn_depth, node_features * message_size, dropout_p=enn_dropout_p
        )
        self.gru = nn.GRUCell(input_size=message_size, hidden_size=node_features, bias=False)
        self.s2v = Set2Vec(node_features, s2v_lstm_computations, s2v_memory_size)
        self.out_nn = FeedForwardNetwork(
            s2v_memory_size * 2, [out_hidden_dim] * out_depth, out_features, dropout_p=out_dropout_p, bias=False
        )

    def message_terms(self, nodes, node_neighbours, edges):
        enn_output = self.enn(edges)
        matrices = enn_output.view(-1, self.message_size, self.node_features)
        msg_terms = torch.matmul(matrices, node_neighbours.unsqueeze(-1)).squeeze(-1)
        return msg_terms

    def update(self, nodes, messages):
        return self.gru(messages, nodes)

    def readout(self, hidden_nodes, input_nodes, node_mask):
        graph_embeddings = self.s2v(hidden_nodes, input_nodes, node_mask)
        return self.out_nn(graph_embeddings)
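

# Sketch (illustrative only, plain Python, one edge): in ENNS2V the edge network maps an edge's
# features to node_features * message_size numbers, which are reshaped row-major (as .view does)
# into a message_size x node_features matrix A(e); the message term is A(e) times the
# neighbour's hidden state. Here the edge network output is given directly;
# edge_network_message is a hypothetical helper.
def edge_network_message(enn_output, neighbour_state, message_size, node_features):
    if len(enn_output) != message_size * node_features:
        raise ValueError('enn_output must have message_size * node_features entries')
    matrix = [enn_output[r * node_features:(r + 1) * node_features] for r in range(message_size)]
    return [sum(a * h for a, h in zip(row, neighbour_state)) for row in matrix]

# The summed messages (length message_size) then drive a GRUCell whose hidden size is
# node_features, which is why the GRU in ENNS2V is built with those two sizes.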


class GGNN(SummationMPNN):

    def __init__(self, node_features, edge_features, message_size, message_passes, out_features,
                 msg_depth=4, msg_hidden_dim=200, msg_dropout_p=0.0,
                 gather_width=100,
                 gather_att_depth=3, gather_att_hidden_dim=100, gather_att_dropout_p=0.0,
                 gather_emb_depth=3, gather_emb_hidden_dim=100, gather_emb_dropout_p=0.0,
                 out_depth=2, out_hidden_dim=100, out_dropout_p=0.0, out_layer_shrinkage=1.0):
        super(GGNN, self).__init__(node_features, edge_features, message_size, message_passes, out_features)

        self.msg_nns = nn.ModuleList()
        for _ in range(edge_features):
            self.msg_nns.append(
                FeedForwardNetwork(node_features, [msg_hidden_dim] * msg_depth, message_size, dropout_p=msg_dropout_p,
                                   bias=False)
            )
        self.gru = nn.GRUCell(input_size=message_size, hidden_size=node_features, bias=False)
        self.gather = GraphGather(
            node_features, gather_width,
            gather_att_depth, gather_att_hidden_dim, gather_att_dropout_p,
            gather_emb_depth, gather_emb_hidden_dim, gather_emb_dropout_p
        )
        out_layer_sizes = [  # example: depth 5, dim 50, shrinkage 0.5 => out_layer_sizes [50, 42, 35, 30, 25]
            round(out_hidden_dim * (out_layer_shrinkage ** (i / (out_depth - 1 + 1e-9)))) for i in range(out_depth)
        ]
        self.out_nn = FeedForwardNetwork(gather_width, out_layer_sizes, out_features, dropout_p=out_dropout_p)

    def message_terms(self, nodes, node_neighbours, edges):
        # terms_masked_per_edge holds one (edge_batch_size, message_size) tensor per edge type, which is 0
        # in every row whose edge is not of that type; the intuitive way of writing this (a torch.stack
        # along the batch dimension) is far slower
        edges_v = edges.view(-1, self.edge_features, 1)
        node_neighbours_v = edges_v * node_neighbours.view(-1, 1, self.node_features)
        terms_masked_per_edge = [
            edges_v[:, i, :] * self.msg_nns[i](node_neighbours_v[:, i, :]) for i in range(self.edge_features)
        ]
        return sum(terms_masked_per_edge)

    def update(self, nodes, messages):
        return self.gru(messages, nodes)

    def readout(self, hidden_nodes, input_nodes, node_mask):
        graph_embeddings = self.gather(hidden_nodes, input_nodes, node_mask)
        return self.out_nn(graph_embeddings)
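

# Sketch (illustrative only): the geometric shrinkage GGNN uses for its output layer sizes,
# out_hidden_dim * out_layer_shrinkage ** (i / (out_depth - 1)). The tiny 1e-9 in the
# denominator only avoids a division by zero when out_depth == 1. shrinking_layer_sizes is a
# hypothetical helper that repeats the list comprehension from GGNN.__init__.
def shrinking_layer_sizes(out_hidden_dim, out_depth, out_layer_shrinkage):
    return [
        round(out_hidden_dim * (out_layer_shrinkage ** (i / (out_depth - 1 + 1e-9))))
        for i in range(out_depth)
    ]

# Usage: shrinking_layer_sizes(50, 5, 0.5) reproduces the [50, 42, 35, 30, 25] example in the
# GGNN comment, and the GGNN defaults (100, 2, 1.0) give [100, 100].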


############### gnn_test_case.py

import unittest
from torch import optim, nn

# This script checks that the MPNN implementations can be initialized and trained for a few
# iterations without crashing, and that they are invariant to node order, padding size and
# the order of samples in a batch.
#
# Padding invariance matters because it shows whether node vectors full of 0s (corresponding to
# non-existent nodes) affect the output.
#
# Node order invariance matters because the atom ordering of a molecule is arbitrary and should
# not change the prediction.
#
# To test a newly implemented MPNN, subclass GNNTestCase and override its `net` attribute
# (see ExampleMPNNTestCase, ENNS2VTestCase and GGNNTestCase below).

# The compounds below are 1) small enough that the tensors are readable and 2) of different
# sizes, so that molgraph_collate_fn's padding is exercised.


COMPOUNDS = ['OC(=O)[C@@H]1CCN1', 'BrC1=NC=CC=C1', 'ClCCOC=C']
BATCH_SIZE = len(COMPOUNDS)
DUMMY_ADJ, DUMMY_NODES, DUMMY_EDGES, DUMMY_TARGET = \
    molgraph_collate_fn(list(map(lambda smile: (smile_to_graph(smile), [1]), COMPOUNDS)))
NODE_FEATURES = 5
DUMMY_NODES = DUMMY_NODES[:, :, :NODE_FEATURES]  # dropping all node features except a few
EDGE_FEATURES = DUMMY_EDGES.shape[3]
OUT_FEATURES = DUMMY_TARGET.shape[1]


class DummyMPNN(SummationMPNN):

    def __init__(self, node_features, edge_features, message_size, message_passes, out_features):
        super(DummyMPNN, self).__init__(node_features, edge_features, message_size, message_passes, out_features)

        # breaks padding invariance, unless node_mask is used properly
        # self.readout_layer = nn.Linear(NODE_FEATURES, 1, bias=True)
        self.readout_layer = nn.Linear(NODE_FEATURES, 1, bias=False)

    def message_terms(self, nodes, node_neighbours, edges):
        message_terms = nodes + node_neighbours
        return message_terms

    def update(self, nodes, messages):
        return messages

    def readout(self, hidden_nodes, input_nodes, node_mask):
        output = self.readout_layer(hidden_nodes).sum(dim=1)
        return output
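

# Sketch (illustrative only, plain Python): why DummyMPNN's readout uses bias=False. Its
# readout applies Linear(NODE_FEATURES, 1) to every node and sums over nodes, including the
# all-zero padding nodes. Without a bias a zero node contributes 0, so extra padding changes
# nothing; with a bias every padding node adds the bias once. Reordering the nodes never
# changes a sum. linear_sum_readout is a hypothetical helper.
def linear_sum_readout(node_vectors, weights, bias=0.0):
    return sum(sum(w * x for w, x in zip(weights, node)) + bias for node in node_vectors)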


class GNNTestCase(unittest.TestCase):
    NODE_FEATURES = NODE_FEATURES
    EDGE_FEATURES = EDGE_FEATURES
    OUT_FEATURES = OUT_FEATURES

    # keep number of weights down to make this run fast
    MESSAGE_SIZE = 5
    MESSAGE_PASSES = 2

    net = DummyMPNN(NODE_FEATURES, EDGE_FEATURES, MESSAGE_SIZE, MESSAGE_PASSES, OUT_FEATURES)

    @classmethod
    def setUpClass(cls):
        optimizer = optim.Adam(cls.net.parameters(), lr=0.0005)
        criterion = nn.MSELoss()
        cls.net.train()
        for i in range(10):
            cls.net.zero_grad()
            output = cls.net(DUMMY_ADJ, DUMMY_NODES, DUMMY_EDGES)
            loss = criterion(output, DUMMY_TARGET)
            loss.backward()
            optimizer.step()

    def test_padding_invariance(self):
        padded_dim_size = DUMMY_ADJ.shape[1] + 5
        padded_adj = torch.zeros(BATCH_SIZE, padded_dim_size, padded_dim_size)
        padded_adj[:, :DUMMY_ADJ.shape[1], :DUMMY_ADJ.shape[2]] = DUMMY_ADJ
        padded_nodes = torch.zeros(BATCH_SIZE, padded_dim_size, NODE_FEATURES)
        padded_nodes[:, :DUMMY_NODES.shape[1], :] = DUMMY_NODES
        padded_edges = torch.zeros(BATCH_SIZE, padded_dim_size, padded_dim_size, EDGE_FEATURES)
        padded_edges[:, :DUMMY_EDGES.shape[1], :DUMMY_EDGES.shape[2], :] = DUMMY_EDGES

        with torch.no_grad():
            self.net.eval()
            normal_output = self.net(DUMMY_ADJ, DUMMY_NODES, DUMMY_EDGES)
            extra_padding_output = self.net(padded_adj, padded_nodes, padded_edges)
            # outputs count as equal within a relative tolerance of 1e-5 (0.001%); they are not always
            # bit-identical because floating-point operations can be carried out in a different order
            self.assertTrue(np.allclose(normal_output, extra_padding_output, rtol=1e-5))

    def test_sample_order_invariance(self):
        permutation = [1, 2, 0]
        shuffled_adj = DUMMY_ADJ[permutation, :, :]
        shuffled_nodes = DUMMY_NODES[permutation, :, :]
        shuffled_edges = DUMMY_EDGES[permutation, :, :, :]

        with torch.no_grad():
            self.net.eval()
            output = self.net(DUMMY_ADJ, DUMMY_NODES, DUMMY_EDGES)
            shuffling_after_prop_output = output[permutation]
            shuffling_before_prop_output = self.net(shuffled_adj, shuffled_nodes, shuffled_edges)
            # outputs count as equal within a relative tolerance of 1e-5 (0.001%); they are not always
            # bit-identical because floating-point operations can be carried out in a different order
            self.assertTrue(np.allclose(shuffling_after_prop_output, shuffling_before_prop_output, rtol=1e-5))

    def test_node_order_invariance(self):
        shuffled_adj = torch.zeros_like(DUMMY_ADJ)
        shuffled_nodes = torch.zeros_like(DUMMY_NODES)
        shuffled_edges = torch.zeros_like(DUMMY_EDGES)
        for i in range(BATCH_SIZE):
            n_real_nodes = (DUMMY_ADJ[i, :, :].sum(dim=1) != 0).sum().item()
            perm = np.random.permutation(n_real_nodes).reshape(1, -1)
            perm_t = perm.transpose()
            shuffled_adj[i, :n_real_nodes, :n_real_nodes] = DUMMY_ADJ[i, perm, perm_t]
            shuffled_nodes[i, :n_real_nodes] = DUMMY_NODES[i, perm, :]
            shuffled_edges[i, :n_real_nodes, :n_real_nodes, :] = DUMMY_EDGES[i, perm, perm_t, :]

        with torch.no_grad():
            self.net.eval()
            normal_output = self.net(DUMMY_ADJ, DUMMY_NODES, DUMMY_EDGES)
            shuffling_output = self.net(shuffled_adj, shuffled_nodes, shuffled_edges)
            # outputs count as equal within a relative tolerance of 1e-5 (0.001%); they are not always
            # bit-identical because floating-point operations can be carried out in a different order
            self.assertTrue(np.allclose(normal_output, shuffling_output, rtol=1e-5))


#################example.py


class ExampleAttentionMPNN(AggregationMPNN):

    def __init__(self, node_features, edge_features, out_features, message_passes=3):
        super(ExampleAttentionMPNN, self).__init__(node_features, edge_features, node_features, message_passes,
                                                   out_features)

        self.message_att_weight = nn.Linear(node_features, 1)
        self.message_emb_weight = nn.Linear(node_features, node_features)
        self.out_weight = nn.Linear(node_features, out_features)

    def aggregate_message(self, nodes, node_neighbours, edges, mask):
        neighbourhood = torch.cat([nodes.unsqueeze(1), node_neighbours], dim=1)

        # the node itself is always part of its own neighbourhood
        neighbourhood_mask = torch.cat(
            [torch.ones((mask.shape[0], 1), dtype=mask.dtype, device=mask.device), mask], dim=1
        )
        energy_mask = (neighbourhood_mask == 0).float() * 1e6

        energies = self.message_att_weight(neighbourhood) - energy_mask.unsqueeze(-1)
        attention = torch.softmax(energies, dim=1)
        embedding = self.message_emb_weight(neighbourhood)
        messages = torch.sum(attention * embedding, dim=1)
        return messages

    def update(self, nodes, messages):
        hidden_nodes = torch.selu(messages)
        return hidden_nodes

    def readout(self, hidden_nodes, input_nodes, node_mask):
        graph_embedding = torch.sum(hidden_nodes, dim=1)
        output = self.out_weight(graph_embedding)
        return output


# if __name__ == '__main__':
#     from torch.utils.data import DataLoader
#
#     print('loading data')
#     train_dataset = MolGraphDataset('toydata/piece-of-esol.csv.gz')
#     train_dataloader = DataLoader(train_dataset, batch_size=50, shuffle=True, collate_fn=molgraph_collate_fn)
#
#     print('instantiating ExampleAttentionMPNN')
#     # 75 node features and 4 edge features are what MolGraphDataset produces; 1 output for the ESOL target
#     net = ExampleAttentionMPNN(node_features=75, edge_features=4, out_features=1)
#     optimizer = optim.Adam(net.parameters(), lr=2e-5)
#     criterion = nn.MSELoss()
#
#     print('starting training')
#     for epoch in range(10):
#         for i_batch, batch in enumerate(train_dataloader):
#             adjacency, nodes, edges, target = batch
#             optimizer.zero_grad()
#             output = net(adjacency, nodes, edges)
#             loss = criterion(output, target)
#             loss.backward()
#             torch.nn.utils.clip_grad_value_(net.parameters(), 5.0)
#             optimizer.step()
#
#         print('epoch: {}, training MSE: {}'.format(epoch + 1, loss))


################  test_example.py


NODE_FEATURES = GNNTestCase.NODE_FEATURES
EDGE_FEATURES = GNNTestCase.EDGE_FEATURES
MESSAGE_SIZE = GNNTestCase.MESSAGE_SIZE
MESSAGE_PASSES = GNNTestCase.MESSAGE_PASSES
OUT_FEATURES = GNNTestCase.OUT_FEATURES


class ExampleMPNNTestCase(GNNTestCase):
    net = ExampleAttentionMPNN(
        NODE_FEATURES, EDGE_FEATURES, OUT_FEATURES
    )


###############  test_implementation.py


NODE_FEATURES = GNNTestCase.NODE_FEATURES
EDGE_FEATURES = GNNTestCase.EDGE_FEATURES
MESSAGE_SIZE = GNNTestCase.MESSAGE_SIZE
MESSAGE_PASSES = GNNTestCase.MESSAGE_PASSES
OUT_FEATURES = GNNTestCase.OUT_FEATURES


class ENNS2VTestCase(GNNTestCase):
    net = ENNS2V(
        NODE_FEATURES, EDGE_FEATURES, MESSAGE_SIZE, MESSAGE_PASSES, OUT_FEATURES,
        enn_hidden_dim=4, out_hidden_dim=6, s2v_memory_size=5
    )


class GGNNTestCase(GNNTestCase):
    net = GGNN(
        NODE_FEATURES, EDGE_FEATURES, MESSAGE_SIZE, MESSAGE_PASSES, OUT_FEATURES,
        msg_hidden_dim=4, gather_width=7, gather_att_hidden_dim=9, gather_emb_hidden_dim=7, out_hidden_dim=3
    )


class AttentionENNS2VTestCase(GNNTestCase):
    net = AttentionENNS2V(
        NODE_FEATURES, EDGE_FEATURES, MESSAGE_SIZE, 10, OUT_FEATURES,
        enn_hidden_dim=4, out_hidden_dim=6, s2v_memory_size=5
    )


class AttentionGGNNTestCase(GNNTestCase):
    net = AttentionGGNN(
        NODE_FEATURES, EDGE_FEATURES, MESSAGE_SIZE, MESSAGE_PASSES, OUT_FEATURES,
        att_hidden_dim=9,
        msg_hidden_dim=4, gather_width=7, gather_att_hidden_dim=9, gather_emb_hidden_dim=7, out_hidden_dim=3
    )


class EMNImplementationTestCase(GNNTestCase):
    EDGE_EMBEDDING_SIZE = 7
    net = EMNImplementation(
        edge_features=EDGE_FEATURES, edge_embedding_size=EDGE_EMBEDDING_SIZE, message_passes=MESSAGE_PASSES,
        out_features=OUT_FEATURES, node_features=NODE_FEATURES,
        msg_hidden_dim=4, gather_width=7, gather_att_hidden_dim=9, gather_emb_hidden_dim=7, out_hidden_dim=3
    )


#################  losses.py


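class MaskedMultiTaskCrossEntropy(nn.Module):
    """Binary cross-entropy averaged over labelled entries only.

    Targets are encoded as 1 (active), -1 (inactive) and 0 (missing); missing entries are masked out.
    """

    def forward(self, input, target):
        target_active = (target == 1).float()  # from -1/1 to 0/1
        # binary_cross_entropy_with_logits applies the sigmoid internally in a numerically stable way,
        # avoiding log(0) = -inf (and then 0 * inf = nan) when sigmoid(input) saturates at 0 or 1
        loss_terms = nn.functional.binary_cross_entropy_with_logits(input, target_active, reduction='none')
        missing_values_mask = (target != 0).float()
        return (loss_terms * missing_values_mask).sum() / missing_values_mask.sum()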


LOSS_FUNCTIONS = {
    'MaskedMultiTaskCrossEntropy': MaskedMultiTaskCrossEntropy(),
    'MSE': nn.MSELoss()
}
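

# Sketch (illustrative, not part of the original code): the masked multi-task loss for plain Python lists.
# Targets use the MaskedMultiTaskCrossEntropy encoding: 1 = active, -1 = inactive, 0 = missing.
# Missing entries are dropped from both the sum and the denominator. The logit form
# max(z, 0) - z*y + log(1 + exp(-|z|)) equals -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))],
# but it stays finite for large |z|, where 1 - sigmoid(z) rounds to 0 in floating point.
import math


def masked_multitask_bce(logits, targets):
    total, count = 0.0, 0
    for logit_row, target_row in zip(logits, targets):
        for z, t in zip(logit_row, target_row):
            if t == 0:
                continue  # missing label: contributes nothing
            y = 1.0 if t == 1 else 0.0
            total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
            count += 1
    return total / count

# e.g. masked_multitask_bce([[0.0, 2.0]], [[1, 0]]) is log(2), about 0.693: only the first entry is labelled.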

#####################  predict.py

import argparse

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--cuda', action='store_true', default=False, help='Enables CUDA training')

parser.add_argument('--modelpath', type=str, help='Path to saved model', required=True)
parser.add_argument('--datapath', type=str, default='toydata/piece-of-tox21-test.csv.gz', help='Testing dataset path')
parser.add_argument('--score', type=str, choices=['roc-auc', 'pr-auc', 'MSE', 'RMSE'], required=True)

# if __name__ == '__main__':
#     global args
#     args = parser.parse_args()
#
#     with torch.no_grad():
#         net = torch.load(args.modelpath)
#         if args.cuda:
#             net = net.cuda()
#         else:
#             net = net.cpu()
#         net.eval()
#
#         dataset = MolGraphDataset(args.datapath, prediction=True)
#         dataloader = DataLoader(dataset, batch_size=50, collate_fn=molgraph_collate_fn)
#
#         batch_outputs = []
#         for i_batch, batch in enumerate(dataloader):
#             if args.cuda:
#                 batch = [tensor.cuda() for tensor in batch]
#             adjacency, nodes, edges, target = batch
#             batch_output = net(adjacency, nodes, edges)
#             if args.score == 'roc-auc' or args.score == 'pr-auc':
#                 batch_output = torch.sigmoid(batch_output)
#             batch_outputs.append(batch_output)
#
#         output = torch.cat(batch_outputs).cpu().numpy()
#
#         print('\t'.join([str(col) for col in dataset.header_cols]))
#         for i in range(len(output)):
#             comment = dataset.comments[i]
#             row_str = '\t'.join([str(x) for x in output[i]])
#             print('{}, {}'.format(comment, row_str))


#################  train.py


MODEL_CONSTRUCTOR_DICTS = {
    'ENNS2V': {
        'constructor': ENNS2V,
        'hyperparameters': {
            'message-passes': {'type': int, 'default': 5},
            'message-size': {'type': int, 'default': 50},
            'enn-depth': {'type': int, 'default': 3},
            'enn-hidden-dim': {'type': int, 'default': 100},
            'enn-dropout-p': {'type': float, 'default': 0.0},
            's2v-lstm-computations': {'type': int, 'default': 7},
            's2v-memory-size': {'type': int, 'default': 50},
            'out-depth': {'type': int, 'default': 2},
            'out-hidden-dim': {'type': int, 'default': 300},
            'out-dropout-p': {'type': float, 'default': 0.0}
        }
    },
    'GGNN': {
        'constructor': GGNN,
        'hyperparameters': {  # these defaults, with batch size 50, learning rate 1.176e-5 and 1200 epochs, work well for ESOL
            'message-passes': {'type': int, 'default': 1},
            'message-size': {'type': int, 'default': 25},
            'msg-depth': {'type': int, 'default': 2},
            'msg-hidden-dim': {'type': int, 'default': 50},
            'msg-dropout-p': {'type': float, 'default': 0.0},
            'gather-width': {'type': int, 'default': 45},
            'gather-att-depth': {'type': int, 'default': 2},
            'gather-att-hidden-dim': {'type': int, 'default': 26},
            'gather-att-dropout-p': {'type': float, 'default': 0.0},
            'gather-emb-depth': {'type': int, 'default': 2},
            'gather-emb-hidden-dim': {'type': int, 'default': 26},
            'gather-emb-dropout-p': {'type': float, 'default': 0.0},
            'out-depth': {'type': int, 'default': 2},
            'out-hidden-dim': {'type': int, 'default': 450},
            'out-dropout-p': {'type': float, 'default': 0.00463},
            'out-layer-shrinkage': {'type': float, 'default': 0.5028}
        }
    },
    'AttentionGGNN': {  # these defaults, with batch size 50, learning rate 1.560e-5 and 600 epochs, work well for BBBP
        'constructor': AttentionGGNN,
        'hyperparameters': {
            'message-passes': {'type': int, 'default': 8},
            'message-size': {'type': int, 'default': 25},
            'msg-depth': {'type': int, 'default': 2},
            'msg-hidden-dim': {'type': int, 'default': 50},
            'msg-dropout-p': {'type': float, 'default': 0.0},
            'att-depth': {'type': int, 'default': 2},
            'att-hidden-dim': {'type': int, 'default': 50},
            'att-dropout-p': {'type': float, 'default': 0.0},
            'gather-width': {'type': int, 'default': 45},
            'gather-att-depth': {'type': int, 'default': 2},
            'gather-att-hidden-dim': {'type': int, 'default': 45},
            'gather-att-dropout-p': {'type': float, 'default': 0.0},
            'gather-emb-depth': {'type': int, 'default': 2},
            'gather-emb-hidden-dim': {'type': int, 'default': 26},
            'gather-emb-dropout-p': {'type': float, 'default': 0.0},
            'out-depth': {'type': int, 'default': 2},
            'out-hidden-dim': {'type': int, 'default': 560},
            'out-dropout-p': {'type': float, 'default': 0.1},
            'out-layer-shrinkage': {'type': float, 'default': 0.6}
        }
    },
    'EMN': {  # these defaults, with batch size 50, learning rate 1e-4 and 1000 epochs, work well for SIDER
        'constructor': EMNImplementation,
        'hyperparameters': {
            'message-passes': {'type': int, 'default': 8},
            'edge-embedding-size': {'type': int, 'default': 50},
            'edge-emb-depth': {'type': int, 'default': 2},
            'edge-emb-hidden-dim': {'type': int, 'default': 105},
            'edge-emb-dropout-p': {'type': float, 'default': 0.0},
            'att-depth': {'type': int, 'default': 2},
            'att-hidden-dim': {'type': int, 'default': 85},
            'att-dropout-p': {'type': float, 'default': 0.0},
            'msg-depth': {'type': int, 'default': 2},
            'msg-hidden-dim': {'type': int, 'default': 150},
            'msg-dropout-p': {'type': float, 'default': 0.0},
            'gather-width': {'type': int, 'default': 45},
            'gather-att-depth': {'type': int, 'default': 2},
            'gather-att-hidden-dim': {'type': int, 'default': 45},
            'gather-att-dropout-p': {'type': float, 'default': 0.0},
            'gather-emb-depth': {'type': int, 'default': 2},
            'gather-emb-hidden-dim': {'type': int, 'default': 45},
            'gather-emb-dropout-p': {'type': float, 'default': 0.0},
            'out-depth': {'type': int, 'default': 2},
            'out-hidden-dim': {'type': int, 'default': 450},
            'out-dropout-p': {'type': float, 'default': 0.1},
            'out-layer-shrinkage': {'type': float, 'default': 0.6}
        }
    }
}
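

# Sketch (illustrative, not part of the original code): how a 'hyperparameters' entry above can become
# command-line flags and then constructor keyword arguments, mirroring the commented-out training script below.
# argparse stores '--message-passes' under the attribute 'message_passes', hence replace('-', '_').
import argparse


def hyperparameter_kwargs(hyperparameters, argv):
    hp_parser = argparse.ArgumentParser()
    for hp_name, hp_kwargs in hyperparameters.items():
        hp_parser.add_argument('--' + hp_name, **hp_kwargs)
    args_dict = vars(hp_parser.parse_args(argv))
    return {name.replace('-', '_'): args_dict[name.replace('-', '_')] for name in hyperparameters}

# e.g. hyperparameter_kwargs(MODEL_CONSTRUCTOR_DICTS['EMN']['hyperparameters'], ['--message-passes', '3'])
# gives message_passes=3 and every other EMN hyperparameter at its default.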

# common_args_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, add_help=False)
# 
# common_args_parser.add_argument('--cuda', action='store_true', default=False, help='Enables CUDA training')
# 
# common_args_parser.add_argument('--train-set', type=str, default='toydata/piece-of-tox21-train.csv.gz',
#                                 help='Training dataset path')
# common_args_parser.add_argument('--valid-set', type=str, default='toydata/piece-of-tox21-valid.csv.gz',
#                                 help='Validation dataset path')
# common_args_parser.add_argument('--test-set', type=str, default='toydata/piece-of-tox21-test.csv.gz',
#                                 help='Testing dataset path')
# common_args_parser.add_argument('--loss', type=str, default='MaskedMultiTaskCrossEntropy',
#                                 choices=[k for k, v in LOSS_FUNCTIONS.items()])
# common_args_parser.add_argument('--score', type=str, default='roc-auc', help='roc-auc or MSE')
# 
# common_args_parser.add_argument('--epochs', type=int, default=500, help='Number of training epochs')
# common_args_parser.add_argument('--batch-size', type=int, default=50, help='Number of graphs in a mini-batch')
# common_args_parser.add_argument('--learn-rate', type=float, default=1e-5)
# 
# common_args_parser.add_argument('--savemodel', action='store_true', default=False,
#                                 help='Saves model with highest validation score')
# common_args_parser.add_argument('--logging', type=str, default='less', choices=[k for k, v in LOG_FUNCTIONS.items()])
# 
# main_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# subparsers = main_parser.add_subparsers(help=', '.join([k for k, v in MODEL_CONSTRUCTOR_DICTS.items()]), dest='model')
# subparsers.required = True
# 
# model_parsers = {}
# for model_name, constructor_dict in MODEL_CONSTRUCTOR_DICTS.items():
#     subparser = subparsers.add_parser(model_name, parents=[common_args_parser])
#     for hp_name, hp_kwargs in constructor_dict['hyperparameters'].items():
#         subparser.add_argument('--' + hp_name, **hp_kwargs, help=model_name + ' hyperparameter')
#     model_parsers[model_name] = subparser
# 
# 
# def main():
#     global args
#     args = main_parser.parse_args()
#     args_dict = vars(args)
#     # dictionary of hyperparameters that are specific to the chosen model
#     model_hp_kwargs = {
#         name.replace('-', '_'): args_dict[name.replace('-', '_')]  # argparse converts to "_" implicitly
#         for name, v in MODEL_CONSTRUCTOR_DICTS[args.model]['hyperparameters'].items()
#     }
# 
#     train_dataset = MolGraphDataset(args.train_set)
#     train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
#                                   collate_fn=molgraph_collate_fn)
#     validation_dataset = MolGraphDataset(args.valid_set)
#     validation_dataloader = DataLoader(validation_dataset, batch_size=args.batch_size, collate_fn=molgraph_collate_fn)
#     test_dataset = MolGraphDataset(args.test_set)
#     test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, collate_fn=molgraph_collate_fn)
# 
#     ((sample_adjacency, sample_nodes, sample_edges), sample_target) = train_dataset[0]
#     net = MODEL_CONSTRUCTOR_DICTS[args.model]['constructor'](
#         node_features=len(sample_nodes[0]), edge_features=len(sample_edges[0, 0]), out_features=len(sample_target),
#         **model_hp_kwargs
#     )
#     if args.cuda:
#         net = net.cuda()
# 
#     optimizer = optim.Adam(net.parameters(), lr=args.learn_rate)
#     criterion = LOSS_FUNCTIONS[args.loss]
# 
#     for epoch in range(args.epochs):
#         net.train()
#         for i_batch, batch in enumerate(train_dataloader):
# 
#             if args.cuda:
#                 batch = [tensor.cuda() for tensor in batch]
#             adjacency, nodes, edges, target = batch
# 
#             optimizer.zero_grad()
#             output = net(adjacency, nodes, edges)
#             loss = criterion(output, target)
#             loss.backward()
#             torch.nn.utils.clip_grad_value_(net.parameters(), 5.0)
#             optimizer.step()
# 
#         with torch.no_grad():
#             net.eval()
#             LOG_FUNCTIONS[args.logging](
#                 net, train_dataloader, validation_dataloader, test_dataloader, criterion, epoch, args
#             )


#################  train_logging.py

from sklearn.metrics import roc_auc_score, average_precision_score

import datetime

OUTPUT_DIR = 'output/'
TENSORBOARDX_OUTPUT_DIR = 'tbxoutput/'
SAVEDMODELS_DIR = 'savedmodels/'
# time of importing this file, including microseconds because slurm may start queued jobs very close in time
DATETIME_STR = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')


class Globals:  # container for all objects getting passed between log calls
    evaluate_called = False


g = Globals()

TRAIN_SUBSET_SIZE = 500
SUBSET_LOADER_BATCH_SIZE = 50


def subset_loader(dataloader, subset_size, seed=0):
    np.random.seed(seed)
    random_indices = np.random.choice(len(dataloader.dataset), subset_size)
    np.random.seed()  # "reset" seed
    subset = data.Subset(dataloader.dataset, random_indices)
    return data.DataLoader(subset, batch_size=SUBSET_LOADER_BATCH_SIZE, collate_fn=dataloader.collate_fn)
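

# Sketch (illustrative alternative, not used above): np.random.choice(n, k) samples *with* replacement
# by default, so subset_loader may pick the same molecule twice. It also reseeds NumPy's global generator.
# A local generator avoids both problems, e.g. np.random.default_rng(seed).choice(n, k, replace=False)
# in NumPy. The same idea with the standard library:
import random


def distinct_subset_indices(dataset_size, subset_size, seed=0):
    rng = random.Random(seed)  # local generator: the global random state is untouched
    return rng.sample(range(dataset_size), subset_size)  # distinct indices, reproducible for a given seed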


def compute_roc_auc(output, target):
    def roc_auc_of_column(scores_column, targets_column):
        relevant_indices = targets_column.nonzero()
        relevant_targets = targets_column[relevant_indices]
        relevant_scores = scores_column[relevant_indices]
        relevant_targets_np = relevant_targets.cpu().numpy()
        relevant_targets_np = relevant_targets_np == 1  # -1s/1s => Falses/Trues
        try:
            score = roc_auc_score(relevant_targets_np, relevant_scores.cpu().detach().numpy())
        except ValueError:  # e.g. only one class present in this column, so ROC AUC is undefined
            score = np.nan
        return score

    scores = torch.sigmoid(output)
    roc_aucs = [
        roc_auc_of_column(scores[:, i], target[:, i])
        for i in range(target.shape[1])
    ]
    return roc_aucs
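

# Sketch (illustrative, not part of the original code): what compute_roc_auc measures for one task column,
# without sklearn. Entries labelled 0 (missing) are skipped. The AUC equals the probability that a randomly
# chosen active (1) scores higher than a randomly chosen inactive (-1), with ties counting 1/2.
# Like the try/except above, it returns nan when only one class is present. O(P*N), for clarity only.
import math


def masked_roc_auc(scores, targets):
    positives = [s for s, t in zip(scores, targets) if t == 1]
    negatives = [s for s, t in zip(scores, targets) if t == -1]
    if not positives or not negatives:
        return math.nan
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in positives for n in negatives)
    return wins / (len(positives) * len(negatives))

# e.g. masked_roc_auc([0.9, 0.8, 0.3, 0.1, 0.5], [1, -1, 1, -1, 0]) == 0.75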


def compute_pr_auc(output, target):
    def pr_auc_of_column(scores_column, targets_column):
        relevant_indices = targets_column.nonzero()
        relevant_targets = targets_column[relevant_indices]
        relevant_scores = scores_column[relevant_indices]
        relevant_targets_np = relevant_targets.cpu().numpy()
        relevant_targets_np = relevant_targets_np == 1  # -1s/1s => Falses/Trues
        return average_precision_score(relevant_targets_np, relevant_scores.cpu().detach().numpy())

    scores = torch.sigmoid(output)
    pr_aucs = [
        pr_auc_of_column(scores[:, i], target[:, i])
        for i in range(target.shape[1])
    ]
    return pr_aucs


def compute_mse(output, target):
    nn_mse = torch.nn.MSELoss()
    mses = [
        nn_mse(output[:, i], target[:, i]).cpu().detach().numpy()
        for i in range(target.shape[1])
    ]
    return mses


def compute_rmse(output, target):
    mses = compute_mse(output, target)
    return np.sqrt(mses)


SCORE_FUNCTIONS = {
    'roc-auc': compute_roc_auc, 'pr-auc': compute_pr_auc, 'MSE': compute_mse, 'RMSE': compute_rmse
}


def feed_net(net, dataloader, criterion, cuda):
    batch_outputs = []
    batch_losses = []
    batch_targets = []
    for i_batch, batch in enumerate(dataloader):
        if cuda:
            batch = [tensor.cuda(non_blocking=True) for tensor in batch]
        adjacency, nodes, edges, target = batch
        output = net(adjacency, nodes, edges)
        loss = criterion(output, target)
        batch_outputs.append(output)
        batch_losses.append(loss.item())
        batch_targets.append(target)
    outputs = torch.cat(batch_outputs)
    loss = np.mean(batch_losses)
    targets = torch.cat(batch_targets)
    return outputs, loss, targets


def evaluate_net(net, train_dataloader, validation_dataloader, test_dataloader, criterion, args):
    global g
    if not g.evaluate_called:
        g.evaluate_called = True
        if args.score == 'roc-auc' or args.score == 'pr-auc':
            g.best_mean_train_score, g.best_mean_validation_score, g.best_mean_test_score = 0, 0, 0
        elif args.score == 'MSE' or args.score == 'RMSE':
            # start from the worst possible error so the first evaluation is always recorded as the best
            g.best_mean_train_score, g.best_mean_validation_score, g.best_mean_test_score = np.inf, np.inf, np.inf
        # g.train_subset_loader = subset_loader(train_dataloader, TRAIN_SUBSET_SIZE, seed=0)
        g.train_subset_loader = train_dataloader

    train_output, train_loss, train_target = feed_net(net, g.train_subset_loader, criterion, args.cuda)
    validation_output, validation_loss, validation_target = feed_net(net, validation_dataloader, criterion, args.cuda)
    test_output, test_loss, test_target = feed_net(net, test_dataloader, criterion, args.cuda)

    train_scores = SCORE_FUNCTIONS[args.score](train_output, train_target)
    train_mean_score = np.nanmean(train_scores)
    validation_scores = SCORE_FUNCTIONS[args.score](validation_output, validation_target)
    validation_mean_score = np.nanmean(validation_scores)
    test_scores = SCORE_FUNCTIONS[args.score](test_output, test_target)
    test_mean_score = np.nanmean(test_scores)

    if args.score == 'roc-auc' or args.score == 'pr-auc':
        new_best_model_found = validation_mean_score > g.best_mean_validation_score
    elif args.score == 'MSE' or args.score == 'RMSE':
        new_best_model_found = validation_mean_score < g.best_mean_validation_score

    if new_best_model_found:
        g.best_mean_train_score = train_mean_score
        g.best_mean_validation_score = validation_mean_score
        g.best_mean_test_score = test_mean_score

        if args.savemodel:
            path = SAVEDMODELS_DIR + type(net).__name__ + DATETIME_STR
            torch.save(net, path)

    target_names = train_dataloader.dataset.target_names
    return {  # keep two levels deep: tensorboardX's add_scalars expects a flat {tag: scalar} dict for each key
        'loss': {'train': train_loss, 'validation': validation_loss, 'test': test_loss},
        'mean {}'.format(args.score):
            {'train': train_mean_score, 'validation': validation_mean_score, 'test': test_mean_score},
        'train {}s'.format(args.score): {target_names[i]: train_scores[i] for i in range(len(target_names))},
        'test {}s'.format(args.score): {target_names[i]: test_scores[i] for i in range(len(target_names))},
        'best mean {}'.format(args.score):
            {'train': g.best_mean_train_score, 'validation': g.best_mean_validation_score,
             'test': g.best_mean_test_score}
    }
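

# Sketch (illustrative, not part of the original code): the "new best model" rule in evaluate_net.
# The direction of improvement depends on the metric. Starting from -inf or inf, rather than 0 or an
# arbitrary large MSE, guarantees the first evaluation is recorded. A nan score never counts,
# because every comparison with nan is False.
import math

HIGHER_IS_BETTER = {'roc-auc': True, 'pr-auc': True, 'MSE': False, 'RMSE': False}


def initial_best_score(score_name):
    return -math.inf if HIGHER_IS_BETTER[score_name] else math.inf


def is_new_best(score_name, new_score, best_score):
    return new_score > best_score if HIGHER_IS_BETTER[score_name] else new_score < best_score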


def get_run_info(net, args):
    return {
        'net': type(net).__name__,
        'args': ', '.join([str(k) + ': ' + str(v) for k, v in vars(args).items()]),
        'modules': {name: str(module) for name, module in net._modules.items()}
    }


def less_log(net, train_dataloader, validation_dataloader, test_dataloader, criterion, epoch, args):
    scalars = evaluate_net(net, train_dataloader, validation_dataloader, test_dataloader, criterion, args)
    mean_score_key = 'mean {}'.format(args.score)
    print('epoch {}, training mean {}: {}, validation mean {}: {}, testing mean {}: {}'.format(
        epoch + 1,
        args.score, scalars[mean_score_key]['train'],
        args.score, scalars[mean_score_key]['validation'],
        args.score, scalars[mean_score_key]['test'])
    )


def more_log(net, train_dataloader, validation_dataloader, test_dataloader, criterion, epoch, args):
    mean_score_key = 'mean {}'.format(args.score)
    best_mean_score_key = 'best {}'.format(mean_score_key)
    global g
    if not g.evaluate_called:
        run_info = get_run_info(net, args)
        print('net: ' + run_info['net'])
        print('args: {' + run_info['args'] + '}')
        print('****** MODULES: ******')
        for name, description in run_info['modules'].items():
            print(name + ': ' + description)
        print('**********************')
        print('score metric: {}'.format(args.score))
        print('columns:')
        print(
            'epochs, ' + \
            'mean training score, mean validation score, mean testing score, ' + \
            'best-model-so-far mean training score, best-model-so-far mean validation score, best-model-so-far mean testing score'
        )

    scalars = evaluate_net(net, train_dataloader, validation_dataloader, test_dataloader, criterion, args)
    print(
        '%d, %f, %f, %f, %f, %f, %f' % (
            epoch + 1,
            scalars[mean_score_key]['train'], scalars[mean_score_key]['validation'], scalars[mean_score_key]['test'],
            scalars[best_mean_score_key]['train'], scalars[best_mean_score_key]['validation'],
            scalars[best_mean_score_key]['test']
        )
    )


# to open tensorboard training summaries, live or static:
# 1) do some training to generate them in tbxoutput/
# 2) install TensorBoard, e.g. pip install tensorboard (a separate environment is fine)
# 3) run tensorboard --port 6011 --logdir tbxoutput/ and open localhost:6011 in a browser
def tensorboardx_log(net, train_dataloader, validation_dataloader, test_dataloader, criterion, epoch, args):
    global g
    if not g.evaluate_called:
        from tensorboardX import SummaryWriter

        run_info = get_run_info(net, args)

        class_str = run_info['net']
        output_subdir = TENSORBOARDX_OUTPUT_DIR + class_str + ' ' + DATETIME_STR
        g.writer = SummaryWriter(output_subdir)

        g.writer.add_text('args', run_info['args'])
        for k, v in run_info['modules'].items():
            g.writer.add_text(k, v)
    else:
        # writer = SummaryWriter(output_subdir) # tensorboardx bug causes this to crash on epoch 40 or so
        g.writer.file_writer.reopen()  # workaround

    scalars = evaluate_net(net, train_dataloader, validation_dataloader, test_dataloader, criterion, args)

    for k, v in scalars.items():
        g.writer.add_scalars(k, v, epoch)

    # writer.close() # tensorboardx bug causes this to crash on epoch 40 or so
    g.writer.file_writer.close()  # workaround

    print('epoch %d, training loss: %f, validation loss: %f' %
          (epoch + 1, scalars['loss']['train'], scalars['loss']['validation']))


LOG_FUNCTIONS = {
    'less': less_log, 'more': more_log, 'tensorboardx': tensorboardx_log
}


In [None]:
#Imports
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

import datetime

In [None]:
train_dataset = MolGraphDataset('/kaggle/working/2019-nCov/Data/protease_train.csv.gz')
train_dataloader = DataLoader(train_dataset, batch_size=50, shuffle=True, collate_fn=molgraph_collate_fn)
validation_dataset = MolGraphDataset('/kaggle/working/2019-nCov/Data/protease_valid.csv.gz')
validation_dataloader = DataLoader(validation_dataset, batch_size=50, collate_fn=molgraph_collate_fn)
test_dataset = MolGraphDataset('/kaggle/working/2019-nCov/Data/protease_test.csv.gz')
test_dataloader = DataLoader(test_dataset, batch_size=50, collate_fn=molgraph_collate_fn)

In [None]:
((sample_adjacency, sample_nodes, sample_edges), sample_target) = train_dataset[0]

net = EMNImplementation(node_features=len(sample_nodes[0]),
                        edge_features=len(sample_edges[0, 0]),
                        out_features=len(sample_target),
                        message_passes=8, edge_embedding_size=50,
                        edge_emb_depth=2, edge_emb_hidden_dim=150,
                        edge_emb_dropout_p=0.0, att_depth=2, att_hidden_dim=85,
                        att_dropout_p=0.0, msg_depth=2, msg_hidden_dim=150,
                        msg_dropout_p=0.0, gather_width=45, gather_att_depth=2,
                        gather_att_hidden_dim=45, gather_att_dropout_p=0.0,
                        gather_emb_depth=2, gather_emb_hidden_dim=45,
                        gather_emb_dropout_p=0.0, out_depth=2, out_hidden_dim=450,
                        out_dropout_p=0.1, out_layer_shrinkage=0.6)

if torch.cuda.is_available():
    net = net.cuda()

optimizer = optim.Adam(net.parameters(), lr=1e-4)
criterion = nn.MSELoss()

In [None]:
SAVEDMODELS_DIR = "/kaggle/working/savedmodels/"
os.makedirs(SAVEDMODELS_DIR, exist_ok=True)  # unlike os.mkdir, does not fail when the cell is re-run

def evaluate_net(net, train_dataloader, validation_dataloader, test_dataloader, criterion):
    # Notebook version of evaluate_net above, hard-wired to MSE; the best scores so far live in globals
    global evaluate_called
    global best_mean_train_score
    global best_mean_validation_score
    global best_mean_test_score
    global train_subset_loader

    cuda = torch.cuda.is_available()

    if not evaluate_called:
        evaluate_called = True
        # start from the worst possible MSE so the first evaluation is always recorded as the best
        best_mean_train_score, best_mean_validation_score, best_mean_test_score = np.inf, np.inf, np.inf
        train_subset_loader = train_dataloader

    train_output, train_loss, train_target = feed_net(net, train_subset_loader, criterion, cuda)
    validation_output, validation_loss, validation_target = feed_net(net, validation_dataloader, criterion, cuda)
    test_output, test_loss, test_target = feed_net(net, test_dataloader, criterion, cuda)

    train_scores = compute_mse(train_output, train_target)
    train_mean_score = np.nanmean(train_scores)
    validation_scores = compute_mse(validation_output, validation_target)
    validation_mean_score = np.nanmean(validation_scores)
    test_scores = compute_mse(test_output, test_target)
    test_mean_score = np.nanmean(test_scores)

    # lower MSE is better; the model with the lowest validation MSE so far is saved
    new_best_model_found = validation_mean_score < best_mean_validation_score

    if new_best_model_found:
        best_mean_train_score = train_mean_score
        best_mean_validation_score = validation_mean_score
        best_mean_test_score = test_mean_score

        path = SAVEDMODELS_DIR + type(net).__name__ + DATETIME_STR
        torch.save(net, path)

    target_names = train_dataloader.dataset.target_names
    return {
        'loss': {'train': train_loss, 'validation': validation_loss, 'test': test_loss},
        'mean MSE': {'train': train_mean_score, 'validation': validation_mean_score, 'test': test_mean_score},
        'train MSEs': {target_names[i]: train_scores[i] for i in range(len(target_names))},
        'test MSEs': {target_names[i]: test_scores[i] for i in range(len(target_names))},
        'best mean MSE':
            {'train': best_mean_train_score, 'validation': best_mean_validation_score, 'test': best_mean_test_score}
    }

In [None]:
def less_log(net, train_dataloader, validation_dataloader, test_dataloader, criterion, epoch):
    scalars = evaluate_net(net, train_dataloader, validation_dataloader, test_dataloader, criterion)
    mean_score_key = 'mean {}'.format("MSE")
    print('epoch {}, training mean {}: {}, validation mean {}: {}, testing mean {}: {}'.format(
        epoch + 1,
        "MSE", scalars[mean_score_key]['train'],
        "MSE", scalars[mean_score_key]['validation'],
        "MSE", scalars[mean_score_key]['test'])
    )

In [None]:
evaluate_called = False
best_mean_train_score, best_mean_validation_score, best_mean_test_score = np.inf, np.inf, np.inf
train_subset_loader = None
DATETIME_STR = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')

for epoch in range(10):
    net.train()
    for i_batch, batch in enumerate(train_dataloader):

        if torch.cuda.is_available():
            batch = [tensor.cuda() for tensor in batch]
        adjacency, nodes, edges, target = batch

        optimizer.zero_grad()
        output = net(adjacency, nodes, edges)
        loss = criterion(output, target)
        loss.backward()
        torch.nn.utils.clip_grad_value_(net.parameters(), 5.0)
        optimizer.step()

    # once per epoch: evaluate, log and save the best model so far, without tracking gradients
    with torch.no_grad():
        net.eval()
        less_log(net, train_dataloader, validation_dataloader, test_dataloader, criterion, epoch)
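
The training loop clips every gradient element to [-5, 5] with `torch.nn.utils.clip_grad_value_` before each optimizer step. The dependency-free sketch below shows what that does to a gradient vector. For contrast, it also shows norm-based clipping, the idea behind `torch.nn.utils.clip_grad_norm_`. Both functions are illustrative simplifications, not the PyTorch implementations.

```python
def clip_by_value(grads, clip_value):
    """Clamp each element to [-clip_value, clip_value]; the direction of the vector can change."""
    return [min(max(g, -clip_value), clip_value) for g in grads]


def clip_by_norm(grads, max_norm):
    """Rescale the whole vector so its L2 norm is at most max_norm; the direction is preserved."""
    norm = sum(g * g for g in grads) ** 0.5
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grads]


print(clip_by_value([12.0, -0.5, -7.0], 5.0))  # [5.0, -0.5, -5.0]
print(clip_by_norm([3.0, 4.0], 1.0))           # approximately [0.6, 0.8]
```

Value clipping is cheap and bounds every single update. Norm clipping keeps the update direction intact, which is why it is often preferred for recurrent or deep message-passing models.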

### Predictions

In [None]:
def predict(test_set):
    with torch.no_grad():
        # Change this path to predict using a different trained model.
        # map_location='cpu' lets a GPU-trained model load on a CPU-only machine. On PyTorch >= 2.6,
        # also pass weights_only=False, because this file is a whole pickled nn.Module, not a state_dict.
        net = torch.load("/kaggle/working/2019-nCov/EMNN/savedmodels/EMNImplementation2020-03-01 17:49:07.529952",
                         map_location='cpu')
        if torch.cuda.is_available():
            net = net.cuda()
        net.eval()

        dataset = MolGraphDataset(test_set, prediction=True)
        dataloader = DataLoader(dataset, batch_size=50, collate_fn=molgraph_collate_fn)

        batch_outputs = []
        for i_batch, batch in enumerate(dataloader):
            if torch.cuda.is_available():
                batch = [tensor.cuda() for tensor in batch]
            adjacency, nodes, edges, target = batch
            batch_output = net(adjacency, nodes, edges)
            batch_outputs.append(batch_output)

        output = torch.cat(batch_outputs).cpu().numpy()

        df = pd.read_csv(test_set)

        # the model has a single regression target, so insert its first (and only) output column
        df.insert(1, 'pred_log_std_scaled', output[:, 0], True)

        return df

In [None]:
predict("/kaggle/working/2019-nCov/Data/protease_test.csv.gz")

The N3 ligand was extracted from the crystal structure, and its SMILES string was searched on PubChem to find similar structures. These compounds were saved, and their activities can now be predicted with the newly trained model.

In [None]:
sim_compound = pd.read_csv("/kaggle/working/2019-nCov/Data/n3_similar_compounds.csv")
cids = list(sim_compound[['cid']].values.astype("int32").squeeze())

In [None]:
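# Re-using the PubChem download code from above: fetch each compound's SDF by CID and parse it with RDKit.
# Absolute paths are used so the cell does not depend on the current working directory.
sdf_path = "/kaggle/working/2019-nCov/Data/cmp.sdf"
mols = []
for CID in cids:
    # On a Mac, use curl instead of wget:
    # os.system(f'curl https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{CID}/sdf -o {sdf_path}')
    os.system(f'wget https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{CID}/sdf -O {sdf_path}')
    if os.stat(sdf_path).st_size != 0:  # an empty file means the download failed
        mols.append(Chem.SDMolSupplier(sdf_path)[0])
    else:
        mols.append(None)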

In [None]:
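# Drop compounds that failed to download or parse, otherwise MolToSmiles fails on None
sim_df = pd.DataFrame([mol for mol in mols if mol is not None])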

In [None]:
sim_df.insert(0, 'smiles', [Chem.MolToSmiles(x) for x in sim_df[[0]].values[:,0]], True)

In [None]:
sim_df.insert(0, 'empty', [None]*len(sim_df), True)

In [None]:
sim_df[["empty","smiles"]].to_csv("/kaggle/working/2019-nCov/Data/n3_similarity_test.csv.gz", 
                                  index=False, compression='gzip', sep='\t')

In [None]:
predict("/kaggle/working/2019-nCov/Data/n3_similarity_test.csv.gz")

In [None]:
predictions = predict("/kaggle/working/2019-nCov/Data/n3_similarity_test.csv.gz")

In [None]:
predictions.sort_values("pred_log_std_scaled", ascending=False)

Because we have no way of verifying the accuracy of the predicted values for these compounds, we instead take the 10 compounds with the highest predicted values and dock them using AutoDock. This also means we do not need to undo the scaling of the predicted activity: standard scaling subtracts a mean and divides by a positive standard deviation, which preserves the order of the values, so the compounds predicted to be the best are the same on either scale.
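
A quick check of that reasoning, assuming (as the column name `pred_log_std_scaled` suggests) that the target was log-transformed and then standard-scaled. Both steps are strictly increasing, so the top-k selection is unchanged. The activity values below are made up for illustration:

```python
import math
import statistics


def log_std_scale(values):
    """Log-transform, then standardize to zero mean and unit (population) standard deviation."""
    logs = [math.log10(v) for v in values]
    mean = statistics.mean(logs)
    std = statistics.pstdev(logs)
    return [(x - mean) / std for x in logs]


def top_k(values, k):
    """Indices of the k largest values, largest first."""
    return sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k]


# Hypothetical raw activity values for six compounds
raw = [12.0, 0.5, 300.0, 4.2, 85.0, 1.1]
scaled = log_std_scale(raw)

# The same compounds come out on top, in the same order, on either scale
print(top_k(raw, 3), top_k(scaled, 3))  # [2, 4, 0] [2, 4, 0]
```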

In [None]:
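#Keeping the 10 compounds with the highest predicted activity (sort_values returns a copy, so sort here)
best_predicted = predictions.sort_values("pred_log_std_scaled", ascending=False)[["empty\tsmiles"]].values[:10,0]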

In [None]:
# predict() reads this tab-separated file with pandas' default comma separator, so each
# SMILES string keeps the tab that separated it from the empty first column; strip it off
best_predicted = [smiles.lstrip("\t") for smiles in best_predicted]

In [None]:
with open("/kaggle/working/2019-nCov/Data/best_predicted_smiles.pkl", "wb") as f:
    pickle.dump(best_predicted, f)

## 5. Molecular Generation I - Constrained Graph Variational Autoencoder

The authors of the Constrained Graph Variational Autoencoder (CGVAE) describe a variational autoencoder trained directly on molecular graphs. More details on VAEs can be found in Kingma and Welling's paper, *Auto-Encoding Variational Bayes*. Essentially, an encoder network maps each training molecule to a distribution in a latent space, and training pushes these distributions toward a standard normal in each latent dimension. A new vector can then be sampled from this latent distribution and decoded into a molecule.
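
The two VAE ingredients described above can be sketched in plain Python: sampling a latent vector with the reparameterization trick, and the KL-divergence term that pulls each latent dimension toward a standard normal. This is a generic VAE sketch, not CGVAE's TensorFlow code, and the 4-dimensional encoder output is hypothetical:

```python
import math
import random


def sample_latent(mu, log_var, rng=random):
    """Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, 1) in each dimension."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over dimensions."""
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))


# Hypothetical encoder output for one molecule in a 4-dimensional latent space
mu = [0.3, -1.2, 0.0, 0.8]
log_var = [-0.5, 0.1, 0.0, -1.0]

z = sample_latent(mu, log_var, random.Random(0))
print(len(z), round(kl_to_standard_normal(mu, log_var), 4))

# At generation time, a new latent vector is drawn directly from the prior N(0, I) and decoded
z_new = [random.gauss(0.0, 1.0) for _ in range(4)]
```

The KL term is zero only when every dimension has mean 0 and variance 1, which is why sampling from N(0, I) later produces vectors the decoder has learned to handle.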

Outline of this section: first, we train the model on the dataset obtained in section 3. We then use the VAE's built-in generation phase to generate new compounds and consider some of them as candidates. The design of the method also allows optimization in the latent space by gradient ascent on a target value; we did not use this, and generated molecules by random sampling instead.
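
For reference, the latent-space optimization we skipped amounts to repeatedly stepping a latent vector in the direction of the gradient of a property predictor, then decoding the result. A toy sketch, with a hypothetical quadratic "property" function standing in for a trained predictor:

```python
def property_score(z, target):
    """Hypothetical differentiable property predictor: highest when z equals `target`."""
    return -sum((zi - ti) ** 2 for zi, ti in zip(z, target))


def property_grad(z, target):
    """Analytic gradient of property_score with respect to z."""
    return [-2.0 * (zi - ti) for zi, ti in zip(z, target)]


def gradient_ascent(z, target, step=0.1, n_steps=50):
    """Move a latent vector uphill on the predicted property."""
    z = list(z)
    for _ in range(n_steps):
        g = property_grad(z, target)
        z = [zi + step * gi for zi, gi in zip(z, g)]
    return z


z0 = [0.0, 0.0, 0.0]
target = [1.0, -2.0, 0.5]
z_opt = gradient_ascent(z0, target)
print(property_score(z0, target), property_score(z_opt, target))
```

In a real model the predictor is a network over the latent space and the optimized vector is passed to the decoder; random sampling skips this loop entirely.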

In [None]:
!pip install docopt typing planarity
!git clone https://github.com/microsoft/constrained-graph-variational-autoencoder.git

In [None]:
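################   utils.py

# Standard-library and third-party imports used by the helper functions below
import os
import pickle
import queue
import sys
import threading
from collections import defaultdict, deque

import numpy as np
import planarity
import tensorflow as tf
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Crippen
from rdkit.Chem import Draw
from rdkit.Chem import QED
from rdkit.Chem import RDConfig
from rdkit.Chem import rdmolops

# sascorer ships in RDKit's Contrib directory rather than as an installed module
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer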

SMALL_NUMBER = 1e-7
LARGE_NUMBER = 1e10

geometry_numbers = [3, 4, 5, 6]  # triangle, square, pentagon, hexagon

# bond mapping
bond_dict = {'SINGLE': 0, 'DOUBLE': 1, 'TRIPLE': 2, "AROMATIC": 3}
number_to_bond = {0: Chem.rdchem.BondType.SINGLE, 1: Chem.rdchem.BondType.DOUBLE,
                  2: Chem.rdchem.BondType.TRIPLE, 3: Chem.rdchem.BondType.AROMATIC}


def dataset_info(dataset):  # qm9, zinc, cep
    if dataset == 'qm9':
        return {'atom_types': ["H", "C", "N", "O", "F"],
                'maximum_valence': {0: 1, 1: 4, 2: 3, 3: 2, 4: 1},
                'number_to_atom': {0: "H", 1: "C", 2: "N", 3: "O", 4: "F"},
                'bucket_sizes': np.array(list(range(4, 28, 2)) + [29])
                }
    elif dataset == 'zinc':
        return {'atom_types': ['Br1(0)', 'C4(0)', 'Cl1(0)', 'F1(0)', 'H1(0)', 'I1(0)',
                               'N2(-1)', 'N3(0)', 'N4(1)', 'O1(-1)', 'O2(0)', 'S2(0)', 'S4(0)', 'S6(0)'],
                'maximum_valence': {0: 1, 1: 4, 2: 1, 3: 1, 4: 1, 5: 1, 6: 2, 7: 3, 8: 4, 9: 1, 10: 2, 11: 2, 12: 4,
                                    13: 6, 14: 3},
                'number_to_atom': {0: 'Br', 1: 'C', 2: 'Cl', 3: 'F', 4: 'H', 5: 'I', 6: 'N', 7: 'N', 8: 'N', 9: 'O',
                                   10: 'O', 11: 'S', 12: 'S', 13: 'S'},
                'bucket_sizes': np.array(
                    [28, 31, 33, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 55, 58, 84])
                }

    elif dataset == "cep":
        return {'atom_types': ["C", "S", "N", "O", "Se", "Si"],
                'maximum_valence': {0: 4, 1: 2, 2: 3, 3: 2, 4: 2, 5: 4},
                'number_to_atom': {0: "C", 1: "S", 2: "N", 3: "O", 4: "Se", 5: "Si"},
                'bucket_sizes': np.array([25, 28, 29, 30, 32, 33, 34, 35, 36, 37, 38, 39, 43, 46])
                }
    else:
        # Raise instead of exit(1), which would shut down the notebook kernel
        raise ValueError("the datasets in use are qm9|zinc|cep")


# add one edge to adj matrix
def add_edge_mat(amat, src, dest, e, considering_edge_type=True):
    if considering_edge_type:
        amat[e, dest, src] = 1
        amat[e, src, dest] = 1
    else:
        amat[src, dest] = 1
        amat[dest, src] = 1


def graph_to_adj_mat(graph, max_n_vertices, num_edge_types, tie_fwd_bkwd=True, considering_edge_type=True):
    if considering_edge_type:
        amat = np.zeros((num_edge_types, max_n_vertices, max_n_vertices))
        for src, e, dest in graph:
            add_edge_mat(amat, src, dest, e)
    else:
        amat = np.zeros((max_n_vertices, max_n_vertices))
        for src, e, dest in graph:
            add_edge_mat(amat, src, dest, e, considering_edge_type=False)
    return amat


def check_edge_prob(dataset):
    with open('intermediate_results_%s' % dataset, 'rb') as f:
        adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels, mean, logvariance = pickle.load(
            f)
    for ep, epl in zip(edge_prob, edge_prob_label):
        print("prediction")
        print(ep)
        print("label")
        print(epl)


# check whether a graph is planar or not
def is_planar(location, adj_list, is_dense=False):
    if is_dense:
        new_adj_list = defaultdict(list)
        for x in range(len(adj_list)):
            for y in range(len(adj_list)):
                if adj_list[x][y] == 1:
                    new_adj_list[x].append((y, 1))
        adj_list = new_adj_list
    edges = []
    seen = set()
    for src, l in adj_list.items():
        for dst, e in l:
            if (dst, src) not in seen:
                edges.append((src, dst))
                seen.add((src, dst))
    edges += [location, (location[1], location[0])]
    return planarity.is_planar(edges)


def check_edge_type_prob(dataset, filter=None):
    with open('intermediate_results_%s' % dataset, 'rb') as f:
        adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels, mean, logvariance = pickle.load(
            f)
    for ep, epl in zip(edge_type_prob, edge_type_label):
        print("prediction")
        print(ep)
        print("label")
        print(epl)


def check_mean(dataset, filter=None):
    with open('intermediate_results_%s' % dataset, 'rb') as f:
        adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels, mean, logvariance = pickle.load(
            f)
    print(mean.tolist()[:40])


def check_variance(dataset, filter=None):
    with open('intermediate_results_%s' % dataset, 'rb') as f:
        adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels, mean, logvariance = pickle.load(
            f)
    print(np.exp(logvariance).tolist()[:40])


def check_node_prob(dataset, filter=None):
    print(dataset)
    with open('intermediate_results_%s' % dataset, 'rb') as f:
        adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels, mean, logvariance = pickle.load(
            f)
    print(node_symbol_prob[0])
    print(node_symbol[0])
    print(node_symbol_prob.shape)


def check_qed(filter=None):
    with open('intermediate_results_%s' % dataset, 'rb') as f:
        adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels, mean, logvariance = pickle.load(
            f)
    print(qed_prediction)
    print(qed_labels[0])
    print(np.mean(np.abs(qed_prediction - qed_labels[0])))


def onehot(idx, len):
    z = [0 for _ in range(len)]
    z[idx] = 1
    return z


def generate_empty_adj_matrix(maximum_vertice_num):
    return np.zeros((1, 3, maximum_vertice_num, maximum_vertice_num))


# standard normal with shape [a1, a2, a3]
def generate_std_normal(a1, a2, a3):
    return np.random.normal(0, 1, [a1, a2, a3])


def check_validity(dataset):
    with open('generated_smiles_%s' % dataset, 'rb') as f:
        all_smiles = set(pickle.load(f))
    count = 0
    for smiles in all_smiles:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            count += 1
    return len(all_smiles), count


# Get length for each graph based on node masks
def get_graph_length(all_node_mask):
    all_lengths = []
    for graph in all_node_mask:
        if 0 in graph:
            length = np.argmin(graph)
        else:
            length = len(graph)
        all_lengths.append(length)
    return all_lengths


def make_dir(path):
    if not os.path.exists(path):
        os.mkdir(path)
        print('made directory %s' % path)


# sample node symbols based on node predictions
def sample_node_symbol(all_node_symbol_prob, all_lengths, dataset):
    all_node_symbol = []
    for graph_idx, graph_prob in enumerate(all_node_symbol_prob):
        node_symbol = []
        for node_idx in range(all_lengths[graph_idx]):
            symbol = np.random.choice(np.arange(len(dataset_info(dataset)['atom_types'])), p=graph_prob[node_idx])
            node_symbol.append(symbol)
        all_node_symbol.append(node_symbol)
    return all_node_symbol


def dump(file_name, content):
    with open(file_name, 'wb') as out_file:
        pickle.dump(content, out_file, pickle.HIGHEST_PROTOCOL)


def load(file_name):
    with open(file_name, 'rb') as f:
        return pickle.load(f)


# Generate a new feature: whether adding the edge would create two rings sharing more than two atoms
def get_overlapped_edge_feature(edge_mask, color, new_mol):
    overlapped_edge_feature = []
    for node_in_focus, neighbor in edge_mask:
        if color[neighbor] == 1:
            # attempt to add the edge
            new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[0])
            # Check whether there are two cycles having more than two overlap edges
            try:
                ssr = Chem.GetSymmSSSR(new_mol)
            except:
                ssr = []
            overlap_flag = False
            for idx1 in range(len(ssr)):
                for idx2 in range(idx1 + 1, len(ssr)):
                    if len(set(ssr[idx1]) & set(ssr[idx2])) > 2:
                        overlap_flag = True
            # remove that edge
            new_mol.RemoveBond(int(node_in_focus), int(neighbor))
            if overlap_flag:
                overlapped_edge_feature.append((node_in_focus, neighbor))
    return overlapped_edge_feature


# adj_list [3, v, v] or defaultdict. bfs distance on a graph
def bfs_distance(start, adj_list, is_dense=False):
    distances = {}
    visited = set()
    queue = deque([(start, 0)])
    visited.add(start)
    while len(queue) != 0:
        current, d = queue.popleft()
        for neighbor, edge_type in adj_list[current]:
            if neighbor not in visited:
                distances[neighbor] = d + 1
                visited.add(neighbor)
                queue.append((neighbor, d + 1))
    return [(start, node, d) for node, d in distances.items()]


def get_initial_valence(node_symbol, dataset):
    return [dataset_info(dataset)['maximum_valence'][s] for s in node_symbol]


def add_atoms(new_mol, node_symbol, dataset):
    for number in node_symbol:
        if dataset == 'qm9' or dataset == 'cep':
            idx = new_mol.AddAtom(Chem.Atom(dataset_info(dataset)['number_to_atom'][number]))
        elif dataset == 'zinc':
            new_atom = Chem.Atom(dataset_info(dataset)['number_to_atom'][number])
            charge_num = int(dataset_info(dataset)['atom_types'][number].split('(')[1].strip(')'))
            new_atom.SetFormalCharge(charge_num)
            new_mol.AddAtom(new_atom)


def visualize_mol(path, new_mol):
    AllChem.Compute2DCoords(new_mol)
    print(path)
    Draw.MolToFile(new_mol, path)


def get_idx_of_largest_frag(frags):
    return np.argmax([len(frag) for frag in frags])


def remove_extra_nodes(new_mol):
    frags = Chem.rdmolops.GetMolFrags(new_mol)
    while len(frags) > 1:
        # Get the idx of the frag with largest length
        largest_idx = get_idx_of_largest_frag(frags)
        for idx in range(len(frags)):
            if idx != largest_idx:
                # Remove one atom that is not in the largest frag
                new_mol.RemoveAtom(frags[idx][0])
                break
        frags = Chem.rdmolops.GetMolFrags(new_mol)


def novelty_metric(dataset):
    with open('all_smiles_%s.pkl' % dataset, 'rb') as f:
        all_smiles = set(pickle.load(f))
    with open('generated_smiles_%s' % dataset, 'rb') as f:
        generated_all_smiles = set(pickle.load(f))
    total_new_molecules = 0
    for generated_smiles in generated_all_smiles:
        if generated_smiles not in all_smiles:
            total_new_molecules += 1

    return float(total_new_molecules) / len(generated_all_smiles)


def count_edge_type(dataset, generated=True):
    if generated:
        filename = 'generated_smiles_%s' % dataset
    else:
        filename = 'all_smiles_%s.pkl' % dataset
    with open(filename, 'rb') as f:
        all_smiles = set(pickle.load(f))

    counter = defaultdict(int)
    edge_type_per_molecule = []
    for smiles in all_smiles:
        nodes, edges = to_graph(smiles, dataset)
        edge_type_this_molecule = [0] * len(bond_dict)
        for edge in edges:
            edge_type = edge[1]
            edge_type_this_molecule[edge_type] += 1
            counter[edge_type] += 1
        edge_type_per_molecule.append(edge_type_this_molecule)
    total_sum = 0
    return len(all_smiles), counter, edge_type_per_molecule


def need_kekulize(mol):
    for bond in mol.GetBonds():
        if bond_dict[str(bond.GetBondType())] >= 3:
            return True
    return False


def check_planar(dataset):
    with open("generated_smiles_%s" % dataset, 'rb') as f:
        all_smiles = set(pickle.load(f))
    total_non_planar = 0
    for smiles in all_smiles:
        try:
            nodes, edges = to_graph(smiles, dataset)
        except:
            continue
        edges = [(src, dst) for src, e, dst in edges]
        if edges == []:
            continue

        if not planarity.is_planar(edges):
            total_non_planar += 1
    return len(all_smiles), total_non_planar


def count_atoms(dataset):
    with open("generated_smiles_%s" % dataset, 'rb') as f:
        all_smiles = set(pickle.load(f))
    counter = defaultdict(int)
    atom_count_per_molecule = []  # record the counts for each molecule
    for smiles in all_smiles:
        try:
            nodes, edges = to_graph(smiles, dataset)
        except:
            continue
        atom_count_this_molecule = [0] * len(dataset_info(dataset)['atom_types'])
        for node in nodes:
            atom_type = np.argmax(node)
            atom_count_this_molecule[atom_type] += 1
            counter[atom_type] += 1
        atom_count_per_molecule.append(atom_count_this_molecule)
    total_sum = 0

    return len(all_smiles), counter, atom_count_per_molecule


def to_graph(smiles, dataset):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return [], []
    # Kekulize it
    if need_kekulize(mol):
        rdmolops.Kekulize(mol)
        if mol is None:
            return None, None
    # remove stereo information, such as inward and outward edges
    Chem.RemoveStereochemistry(mol)

    edges = []
    nodes = []
    for bond in mol.GetBonds():
        edges.append((bond.GetBeginAtomIdx(), bond_dict[str(bond.GetBondType())], bond.GetEndAtomIdx()))
        assert bond_dict[str(bond.GetBondType())] != 3
    for atom in mol.GetAtoms():
        if dataset == 'qm9' or dataset == "cep":
            nodes.append(onehot(dataset_info(dataset)['atom_types'].index(atom.GetSymbol()),
                                len(dataset_info(dataset)['atom_types'])))
        elif dataset == 'zinc':  # transform using "<atom_symbol><valence>(<charge>)"  notation
            symbol = atom.GetSymbol()
            valence = atom.GetTotalValence()
            charge = atom.GetFormalCharge()
            atom_str = "%s%i(%i)" % (symbol, valence, charge)

            if atom_str not in dataset_info(dataset)['atom_types']:
                print('unrecognized atom type %s' % atom_str)
                return [], []

            nodes.append(
                onehot(dataset_info(dataset)['atom_types'].index(atom_str), len(dataset_info(dataset)['atom_types'])))

    return nodes, edges


def check_uniqueness(dataset):
    with open('generated_smiles_%s' % dataset, 'rb') as f:
        all_smiles = pickle.load(f)
    original_num = len(all_smiles)
    all_smiles = set(all_smiles)
    new_num = len(all_smiles)
    return new_num / original_num


def shape_count(dataset, remove_print=False, all_smiles=None):
    if all_smiles is None:
        with open('generated_smiles_%s' % dataset, 'rb') as f:
            all_smiles = set(pickle.load(f))

    geometry_counts = [0] * len(geometry_numbers)
    geometry_counts_per_molecule = []  # record the geometry counts for each molecule
    for smiles in all_smiles:
        nodes, edges = to_graph(smiles, dataset)
        if len(edges) <= 0:
            continue
        new_mol = Chem.MolFromSmiles(smiles)

        ssr = Chem.GetSymmSSSR(new_mol)
        counts_for_molecule = [0] * len(geometry_numbers)
        for idx in range(len(ssr)):
            ring_len = len(list(ssr[idx]))
            if ring_len in geometry_numbers:
                geometry_counts[geometry_numbers.index(ring_len)] += 1
                counts_for_molecule[geometry_numbers.index(ring_len)] += 1
        geometry_counts_per_molecule.append(counts_for_molecule)

    return len(all_smiles), geometry_counts, geometry_counts_per_molecule


def check_adjacent_sparse(adj_list, node, neighbor_in_doubt):
    for neighbor, edge_type in adj_list[node]:
        if neighbor == neighbor_in_doubt:
            return True, edge_type
    return False, None


def glorot_init(shape):
    initialization_range = np.sqrt(6.0 / (shape[-2] + shape[-1]))
    return np.random.uniform(low=-initialization_range, high=initialization_range, size=shape).astype(np.float32)


class ThreadedIterator:
    """An iterator object that computes its elements in a parallel thread to be ready to be consumed.
    The iterator should *not* return None"""

    def __init__(self, original_iterator, max_queue_size: int = 2):
        self.__queue = queue.Queue(maxsize=max_queue_size)
        self.__thread = threading.Thread(target=lambda: self.worker(original_iterator))
        self.__thread.start()

    def worker(self, original_iterator):
        for element in original_iterator:
            assert element is not None, 'By convention, iterator elements must not be None'
            self.__queue.put(element, block=True)
        self.__queue.put(None, block=True)

    def __iter__(self):
        next_element = self.__queue.get(block=True)
        while next_element is not None:
            yield next_element
            next_element = self.__queue.get(block=True)
        self.__thread.join()


# Implements multilayer perceptron
class MLP(object):
    def __init__(self, in_size, out_size, hid_sizes, dropout_keep_prob):
        self.in_size = in_size
        self.out_size = out_size
        self.hid_sizes = hid_sizes
        self.dropout_keep_prob = dropout_keep_prob
        self.params = self.make_network_params()

    def make_network_params(self):
        dims = [self.in_size] + self.hid_sizes + [self.out_size]
        weight_sizes = list(zip(dims[:-1], dims[1:]))
        weights = [tf.Variable(self.init_weights(s), name='MLP_W_layer%i' % i)
                   for (i, s) in enumerate(weight_sizes)]
        biases = [tf.Variable(np.zeros(s[-1]).astype(np.float32), name='MLP_b_layer%i' % i)
                  for (i, s) in enumerate(weight_sizes)]

        network_params = {
            "weights": weights,
            "biases": biases,
        }

        return network_params

    def init_weights(self, shape):
        return np.sqrt(6.0 / (shape[-2] + shape[-1])) * (2 * np.random.rand(*shape).astype(np.float32) - 1)

    def __call__(self, inputs):
        acts = inputs
        for W, b in zip(self.params["weights"], self.params["biases"]):
            hid = tf.matmul(acts, tf.nn.dropout(W, self.dropout_keep_prob)) + b
            acts = tf.nn.relu(hid)
        last_hidden = hid
        return last_hidden


class Graph():

    def __init__(self, V, g):
        self.V = V
        self.graph = g

    def addEdge(self, v, w):
        # Add w to v's list.
        self.graph[v].append(w)
        # Add v to w's list.
        self.graph[w].append(v)

    # A recursive function that uses visited[] and parent
    # to detect a cycle in the subgraph reachable from vertex v.
    def isCyclicUtil(self, v, visited, parent):

        # Mark current node as visited
        visited[v] = True

        # Recur for all the vertices adjacent 
        # for this vertex
        for i in self.graph[v]:
            # If an adjacent is not visited, 
            # then recur for that adjacent
            if visited[i] == False:
                if self.isCyclicUtil(i, visited, v) == True:
                    return True

            # If an adjacent is visited and not 
            # parent of current vertex, then there 
            # is a cycle.
            elif i != parent:
                return True

        return False

    # Returns true if the graph is a tree, 
    # else false.
    def isTree(self):
        # Mark all the vertices as not visited 
        # and not part of recursion stack
        visited = [False] * self.V

        # The call to isCyclicUtil serves multiple 
        # purposes. It returns true if graph reachable 
        # from vertex 0 is cyclic. It also marks 
        # all vertices reachable from 0.
        if self.isCyclicUtil(0, visited, -1) == True:
            return False

        # If we find a vertex which is not reachable
        # from 0 (not marked by isCyclicUtil(), 
        # then we return false
        for i in range(self.V):
            if visited[i] == False:
                return False

        return True


# Count how many molecular graphs have no cycles (i.e. are trees)
def check_cyclic(dataset, generated=True):
    if generated:
        with open("generated_smiles_%s" % dataset, 'rb') as f:
            all_smiles = set(pickle.load(f))
    else:
        with open("all_smiles_%s.pkl" % dataset, 'rb') as f:
            all_smiles = set(pickle.load(f))

    tree_count = 0
    for smiles in all_smiles:
        nodes, edges = to_graph(smiles, dataset)
        edges = [(src, dst) for src, e, dst in edges]
        if edges == []:
            continue
        new_adj_list = defaultdict(list)

        for src, dst in edges:
            new_adj_list[src].append(dst)
            new_adj_list[dst].append(src)
        graph = Graph(len(nodes), new_adj_list)
        if graph.isTree():
            tree_count += 1
    return len(all_smiles), tree_count


def check_sascorer(dataset):
    with open('generated_smiles_%s' % dataset, 'rb') as f:
        all_smiles = set(pickle.load(f))
    sa_sum = 0
    total = 0
    sa_score_per_molecule = []
    for smiles in all_smiles:
        new_mol = Chem.MolFromSmiles(smiles)
        try:
            val = sascorer.calculateScore(new_mol)
        except:
            continue
        sa_sum += val
        sa_score_per_molecule.append(val)
        total += 1
    return sa_sum / total, sa_score_per_molecule


def check_logp(dataset):
    with open('generated_smiles_%s' % dataset, 'rb') as f:
        all_smiles = set(pickle.load(f))
    logp_sum = 0
    total = 0
    logp_score_per_molecule = []
    for smiles in all_smiles:
        new_mol = Chem.MolFromSmiles(smiles)
        try:
            val = Crippen.MolLogP(new_mol)
        except:
            continue
        logp_sum += val
        logp_score_per_molecule.append(val)
        total += 1
    return logp_sum / total, logp_score_per_molecule


def check_qed(dataset):
    with open('generated_smiles_%s' % dataset, 'rb') as f:
        all_smiles = set(pickle.load(f))
    qed_sum = 0
    total = 0
    qed_score_per_molecule = []
    for smiles in all_smiles:
        new_mol = Chem.MolFromSmiles(smiles)
        try:
            val = QED.qed(new_mol)
        except:
            continue
        qed_sum += val
        qed_score_per_molecule.append(val)
        total += 1
    return qed_sum / total, qed_score_per_molecule


def sssr_metric(dataset):
    with open('generated_smiles_%s' % dataset, 'rb') as f:
        all_smiles = set(pickle.load(f))
    overlapped_molecule = 0
    for smiles in all_smiles:
        new_mol = Chem.MolFromSmiles(smiles)
        ssr = Chem.GetSymmSSSR(new_mol)
        overlap_flag = False
        for idx1 in range(len(ssr)):
            for idx2 in range(idx1 + 1, len(ssr)):
                if len(set(ssr[idx1]) & set(ssr[idx2])) > 2:
                    overlap_flag = True
        if overlap_flag:
            overlapped_molecule += 1
    return overlapped_molecule / len(all_smiles)


# select the best based on shapes and probs
def select_best(all_mol):
    # sort by shape
    all_mol = sorted(all_mol)
    best_shape = all_mol[-1][0]
    all_mol = [(p, m) for s, p, m in all_mol if s == best_shape]
    # sort by probs
    all_mol = sorted(all_mol)
    return all_mol[-1][1]


# A series of utility functions converting sparse matrix representations to dense ones

def incre_adj_mat_to_dense(incre_adj_mat, num_edge_types, maximum_vertice_num):
    new_incre_adj_mat = []
    for sparse_incre_adj_mat in incre_adj_mat:
        dense_incre_adj_mat = np.zeros((num_edge_types, maximum_vertice_num, maximum_vertice_num))
        for current, adj_list in sparse_incre_adj_mat.items():
            for neighbor, edge_type in adj_list:
                dense_incre_adj_mat[edge_type][current][neighbor] = 1
        new_incre_adj_mat.append(dense_incre_adj_mat)
    return new_incre_adj_mat  # [number_iteration,num_edge_types,maximum_vertice_num, maximum_vertice_num]


def distance_to_others_dense(distance_to_others, maximum_vertice_num):
    new_all_distance = []
    for sparse_distances in distance_to_others:
        dense_distances = np.zeros((maximum_vertice_num), dtype=int)
        for x, y, d in sparse_distances:
            dense_distances[y] = d
        new_all_distance.append(dense_distances)
    return new_all_distance  # [number_iteration, maximum_vertice_num]


def overlapped_edge_features_to_dense(overlapped_edge_features, maximum_vertice_num):
    new_overlapped_edge_features = []
    for sparse_overlapped_edge_features in overlapped_edge_features:
        dense_overlapped_edge_features = np.zeros((maximum_vertice_num), dtype=int)
        for node_in_focus, neighbor in sparse_overlapped_edge_features:
            dense_overlapped_edge_features[neighbor] = 1
        new_overlapped_edge_features.append(dense_overlapped_edge_features)
    return new_overlapped_edge_features  # [number_iteration, maximum_vertice_num]


def node_sequence_to_dense(node_sequence, maximum_vertice_num):
    new_node_sequence = []
    for node in node_sequence:
        s = [0] * maximum_vertice_num
        s[node] = 1
        new_node_sequence.append(s)
    return new_node_sequence  # [number_iteration, maximum_vertice_num]


def edge_type_masks_to_dense(edge_type_masks, maximum_vertice_num, num_edge_types):
    new_edge_type_masks = []
    for mask_sparse in edge_type_masks:
        mask_dense = np.zeros([num_edge_types, maximum_vertice_num])
        for node_in_focus, neighbor, bond in mask_sparse:
            mask_dense[bond][neighbor] = 1
        new_edge_type_masks.append(mask_dense)
    return new_edge_type_masks  # [number_iteration, 3, maximum_vertice_num]


def edge_type_labels_to_dense(edge_type_labels, maximum_vertice_num, num_edge_types):
    new_edge_type_labels = []
    for labels_sparse in edge_type_labels:
        labels_dense = np.zeros([num_edge_types, maximum_vertice_num])
        for node_in_focus, neighbor, bond in labels_sparse:
            labels_dense[bond][neighbor] = 1 / float(len(labels_sparse))  # spread probability mass uniformly over all valid labels
        new_edge_type_labels.append(labels_dense)
    return new_edge_type_labels  # [number_iteration, 3, maximum_vertice_num]


def edge_masks_to_dense(edge_masks, maximum_vertice_num):
    new_edge_masks = []
    for mask_sparse in edge_masks:
        mask_dense = [0] * maximum_vertice_num
        for node_in_focus, neighbor in mask_sparse:
            mask_dense[neighbor] = 1
        new_edge_masks.append(mask_dense)
    return new_edge_masks  # [number_iteration, maximum_vertice_num]


def edge_labels_to_dense(edge_labels, maximum_vertice_num):
    new_edge_labels = []
    for label_sparse in edge_labels:
        label_dense = [0] * maximum_vertice_num
        for node_in_focus, neighbor in label_sparse:
            label_dense[neighbor] = 1 / float(len(label_sparse))
        new_edge_labels.append(label_dense)
    return new_edge_labels  # [number_iteration, maximum_vertice_num]


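######  get_qm9.py

"""
Usage:
    get_qm9.py

Options:
    -h --help                Show this screen.
"""

import glob
import os
import sys

# __file__ is not defined inside a notebook, so add the cloned CGVAE repository
# (assumed to sit in the current working directory) to the path before importing from it
sys.path.append(os.path.abspath('constrained-graph-variational-autoencoder'))
from GGNN_core import ChemModel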

dataset = 'qm9'


def get_validation_file_names(unzip_path):
    print('loading train/validation split')
    with open('valid_idx_qm9.json', 'r') as f:
        valid_idx = json.load(f)['valid_idxs']
    valid_files = [os.path.join(unzip_path, 'dsgdb9nsd_%s.xyz' % i) for i in valid_idx]
    return valid_files


def read_xyz(file_path):
    with open(file_path, 'r') as f:
        lines = f.readlines()
        smiles = lines[-2].split('\t')[0]
        mu = QED.qed(Chem.MolFromSmiles(smiles))
    return {'smiles': smiles, 'QED': mu}


def train_valid_split(unzip_path):
    print('reading data...')
    raw_data = {'train': [], 'valid': []}  # save the train, valid dataset.
    all_files = glob.glob(os.path.join(unzip_path, '*.xyz'))
    valid_files = get_validation_file_names(unzip_path)

    file_count = 0
    for file_idx, file_path in enumerate(all_files):
        if file_path not in valid_files:
            raw_data['train'].append(read_xyz(file_path))
        else:
            raw_data['valid'].append(read_xyz(file_path))
        file_count += 1
        if file_count % 2000 == 0:
            print('finished reading: %d' % file_count, end='\r')
    return raw_data


def preprocess(raw_data, dataset):
    print('parsing smiles as graphs...')
    processed_data = {'train': [], 'valid': []}

    file_count = 0
    for section in ['train', 'valid']:
        all_smiles = []  # record all smiles in training dataset
        for i, (smiles, QED) in enumerate([(mol['smiles'], mol['QED'])
                                           for mol in raw_data[section]]):
            nodes, edges = to_graph(smiles, dataset)
            if len(edges) <= 0:
                continue
            processed_data[section].append({
                'targets': [[(QED)]],
                'graph': edges,
                'node_features': nodes,
                'smiles': smiles
            })
            all_smiles.append(smiles)
            if file_count % 2000 == 0:
                print('finished processing: %d' % file_count, end='\r')
            file_count += 1
        print('%s: 100 %%      ' % (section))
        # save the dataset
        with open('molecules_%s_%s.json' % (section, dataset), 'w') as f:
            json.dump(processed_data[section], f)
        # save all molecules in the training dataset
        if section == 'train':
            utils.dump('smiles_%s.pkl' % dataset, all_smiles)
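
# Sketch (hypothetical helper): the shape of one record that preprocess writes to
# molecules_<section>_<dataset>.json. to_graph is not shown in this section, so the
# (source, bond_type, target) edge triples and one-hot node features passed in below are
# assumed placeholders, not real to_graph output.
def make_record(smiles, qed_value, node_features, edges):
    """Build a preprocess-style record, or return None for a molecule with no edges."""
    if len(edges) <= 0:  # preprocess skips edgeless molecules (e.g. a single heavy atom)
        return None
    return {
        'targets': [[qed_value]],  # nested list: one task, one target value
        'graph': edges,
        'node_features': node_features,
        'smiles': smiles,
    }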


# if __name__ == "__main__":
#     # download   
#     download_path = 'dsgdb9nsd.xyz.tar.bz2'
#     if not os.path.exists(download_path):
#         print('downloading data to %s ...' % download_path)
#         source = 'https://ndownloader.figshare.com/files/3195389'
#         os.system('wget -O %s %s' % (download_path, source))
#         print('finished downloading')
# 
#     # unzip
#     unzip_path = 'qm9_raw'
#     if not os.path.exists(unzip_path):
#         print('extracting data to %s ...' % unzip_path)
#         os.mkdir(unzip_path)
#         os.system('tar xvjf %s -C %s' % (download_path, unzip_path))
#         print('finished extracting')
# 
#     raw_data = train_valid_split(unzip_path)
#     preprocess(raw_data, dataset)

######  get_zinc.py

# !/usr/bin/env python
"""
Usage:
    get_data.py --dataset zinc|qm9|cep

Options:
    -h --help                Show this screen.
    --dataset NAME           Dataset name: zinc, qm9, cep
"""

import sys, os

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
import glob
import csv

dataset = "zinc"


def train_valid_split(download_path):
    # load validation dataset
    with open("valid_idx_zinc.json", 'r') as f:
        valid_idx = set(json.load(f))  # set gives O(1) membership tests

    print('reading data...')
    raw_data = {'train': [], 'valid': []}  # save the train, valid dataset.
    with open(download_path, 'r') as f:
        all_data = list(csv.DictReader(f))

    file_count = 0
    for i, data_item in enumerate(all_data):
        smiles = data_item['smiles'].strip()
        qed_value = float(data_item['qed'])
        if i not in valid_idx:
            raw_data['train'].append({'smiles': smiles, 'QED': qed_value})
        else:
            raw_data['valid'].append({'smiles': smiles, 'QED': qed_value})
        file_count += 1
        if file_count % 2000 == 0:
            print('finished reading: %d' % file_count, end='\r')
    return raw_data
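
# Sketch (hypothetical helper): train_valid_split reads the ZINC CSV with csv.DictReader and
# uses only the 'smiles' and 'qed' columns. The SMILES field in
# 250k_rndm_zinc_drugs_clean_3.csv appears to carry a trailing newline inside its quotes,
# which is why the code calls .strip(). The CSV text used to exercise this is made up.
import csv
import io


def parse_zinc_rows(csv_text):
    """Return a list of {'smiles': str, 'QED': float} dicts from ZINC-style CSV text."""
    rows = csv.DictReader(io.StringIO(csv_text))
    return [{'smiles': r['smiles'].strip(), 'QED': float(r['qed'])} for r in rows]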


# 
# if __name__ == "__main__":
#     download_path = '250k_rndm_zinc_drugs_clean_3.csv'
#     if not os.path.exists(download_path):
#         print('downloading data to %s ...' % download_path)
#         source = 'https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv'
#         os.system('wget -O %s %s' % (download_path, source))
#         print('finished downloading')
# 
#     raw_data = train_valid_split(download_path)
#     preprocess(raw_data, dataset)


#######################   CGVAE.py

# !/usr/bin/env python
"""
Usage:
    CGVAE.py [options]

Options:
    -h --help                Show this screen
    --dataset NAME           Dataset name: zinc, qm9, cep
    --config-file FILE       Hyperparameter configuration file path (in JSON format)
    --config CONFIG          Hyperparameter configuration dictionary (in JSON format)
    --log_dir NAME           log dir name
    --data_dir NAME          data dir name
    --restore FILE           File to restore weights from.
    --freeze-graph-model     Freeze weights of graph model components
"""

'''
Comments provide the expected tensor shapes where helpful.

Key to symbols in comments:
---------------------------
[...]:  a tensor
; ; :   a list
b:      batch size
e:      number of edge types (3)
es:     maximum number of BFS transitions in this batch
v:      number of vertices per graph in this batch
h:      GNN hidden size
'''
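
# Sketch (NumPy rather than TensorFlow; hypothetical helper): one round of the message
# passing performed in compute_final_node_representations_*, using the shape key above.
# For each edge type e the node states are projected, m_e = h W_e + b_e, and each vertex
# sums its neighbours' messages, acts = sum_e A_e m_e. The GRU update that follows in the
# model is omitted, and the inputs used to exercise it are arbitrary placeholders.
import numpy as np


def ggnn_messages(h, adj, weights, biases):
    """h: [b, v, h], adj: [e, b, v, v], weights: [e, h, h], biases: [e, 1, h] -> [b, v, h]."""
    acts = np.zeros_like(h)
    for e in range(adj.shape[0]):
        m = h @ weights[e] + biases[e]  # [b, v, h]
        acts += adj[e] @ m              # [b, v, v] @ [b, v, h] -> [b, v, h]
    return acts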


class DenseGGNNChemModel(ChemModel):
    def __init__(self, args):
        super().__init__(args)

    @classmethod
    def default_params(cls):
        params = dict(super().default_params())
        params.update({
            'task_sample_ratios': {},
            'use_edge_bias': True,  # whether to use edge biases in the GNN

            'clamp_gradient_norm': 1.0,
            'out_layer_dropout_keep_prob': 1.0,

            'tie_fwd_bkwd': True,
            'task_ids': [0],  # id of property prediction

            'random_seed': 0,  # fixed for reproducibility 

            'batch_size': 8 if dataset == 'zinc' or dataset == 'cep' else 64,
            "qed_trade_off_lambda": 10,
            'prior_learning_rate': 0.05,
            'stop_criterion': 0.01,
            'num_epochs': 3 if dataset == 'zinc' or dataset == 'cep' else 10,
            'epoch_to_generate': 3 if dataset == 'zinc' or dataset == 'cep' else 10,
            'number_of_generation': 30000,
            'optimization_step': 0,
            'maximum_distance': 50,
            "use_argmax_generation": False,  # use random sampling or argmax during generation
            'residual_connection_on': True,  # whether residual connection is on
            'residual_connections': {  # for iteration i, indices of earlier hidden states (0 = initial) concatenated to the GRU input
                2: [0],
                4: [0, 2],
                6: [0, 2, 4],
                8: [0, 2, 4, 6],
                10: [0, 2, 4, 6, 8],
                12: [0, 2, 4, 6, 8, 10],
                14: [0, 2, 4, 6, 8, 10, 12],
            },
            'num_timesteps': 12,  # number of GNN propagation steps
            'hidden_size': 100,
            "kl_trade_off_lambda": 0.3,  # kl tradeoff
            'learning_rate': 0.001,
            'graph_state_dropout_keep_prob': 1,
            "compensate_num": 1,  # how many atoms to be added during generation

            'train_file': 'data/molecules_train_%s.json' % dataset,
            'valid_file': 'data/molecules_valid_%s.json' % dataset,

            'try_different_starting': True,
            "num_different_starting": 6,

            'generation': False,  # only for generation
            'use_graph': True,  # use gnn
            "label_one_hot": False,  # one hot label or not
            "multi_bfs_path": False,  # whether to sample several BFS paths for each molecule
            "bfs_path_count": 30,
            "path_random_order": False,  # False: canonical order, True: random order
            "sample_transition": False,  # whether to use transition sampling
            'edge_weight_dropout_keep_prob': 1,
            'check_overlap_edge': False,
            "truncate_distance": 10,
        })

        return params

    def prepare_specific_graph_model(self) -> None:
        h_dim = self.params['hidden_size']
        expanded_h_dim = self.params['hidden_size'] + self.params['hidden_size'] + 1  # 1 for focus bit
        self.placeholders['graph_state_keep_prob'] = tf.placeholder(tf.float32, None, name='graph_state_keep_prob')
        self.placeholders['edge_weight_dropout_keep_prob'] = tf.placeholder(tf.float32, None,
                                                                            name='edge_weight_dropout_keep_prob')
        self.placeholders['initial_node_representation'] = tf.placeholder(tf.float32,
                                                                          [None, None, self.params['hidden_size']],
                                                                          name='node_features')  # padded node symbols
        # mask out invalid node
        self.placeholders['node_mask'] = tf.placeholder(tf.float32, [None, None], name='node_mask')  # [b x v]
        self.placeholders['num_vertices'] = tf.placeholder(tf.int32, ())
        # adj for encoder
        self.placeholders['adjacency_matrix'] = tf.placeholder(tf.float32,
                                                               [None, self.num_edge_types, None, None],
                                                               name="adjacency_matrix")  # [b, e, v, v]
        # labels for node symbol prediction
        self.placeholders['node_symbols'] = tf.placeholder(tf.float32, [None, None, self.params[
            'num_symbols']])  # [b, v, num_symbols]
        # node symbols used to enhance latent representations
        self.placeholders['latent_node_symbols'] = tf.placeholder(tf.float32,
                                                                  [None, None, self.params['hidden_size']],
                                                                  name='latent_node_symbol')  # [b, v, h]
        # mask out cross entropies in decoder
        self.placeholders['iteration_mask'] = tf.placeholder(tf.float32, [None, None])  # [b, es]
        # adj matrices used in decoder
        self.placeholders['incre_adj_mat'] = tf.placeholder(tf.float32, [None, None, self.num_edge_types, None, None],
                                                            name='incre_adj_mat')  # [b, es, e, v, v]
        # distance from the node in focus to every node, per iteration step
        self.placeholders['distance_to_others'] = tf.placeholder(tf.int32, [None, None, None],
                                                                 name='distance_to_others')  # [b, es, v]
        # maximum iteration number of this batch
        self.placeholders['max_iteration_num'] = tf.placeholder(tf.int32, [], name='max_iteration_num')  # number
        # one-hot indicator of the node in focus at each iteration step
        self.placeholders['node_sequence'] = tf.placeholder(tf.float32, [None, None, None],
                                                            name='node_sequence')  # [b, es, v]
        # mask out invalid edge types at each iteration step 
        self.placeholders['edge_type_masks'] = tf.placeholder(tf.float32, [None, None, self.num_edge_types, None],
                                                              name='edge_type_masks')  # [b, es, e, v]
        # ground truth edge type labels at each iteration step 
        self.placeholders['edge_type_labels'] = tf.placeholder(tf.float32, [None, None, self.num_edge_types, None],
                                                               name='edge_type_labels')  # [b, es, e, v]
        # mask out invalid edge at each iteration step 
        self.placeholders['edge_masks'] = tf.placeholder(tf.float32, [None, None, None],
                                                         name='edge_masks')  # [b, es, v]
        # ground truth edge labels at each iteration step 
        self.placeholders['edge_labels'] = tf.placeholder(tf.float32, [None, None, None],
                                                          name='edge_labels')  # [b, es, v]        
        # ground truth labels for whether it stops at each iteration step
        self.placeholders['local_stop'] = tf.placeholder(tf.float32, [None, None], name='local_stop')  # [b, es]
        # z_prior sampled from standard normal distribution
        self.placeholders['z_prior'] = tf.placeholder(tf.float32, [None, None, self.params['hidden_size']],
                                                      name='z_prior')  # the prior of z sampled from normal distribution
        # put in front of kl latent loss
        self.placeholders['kl_trade_off_lambda'] = tf.placeholder(tf.float32, [], name='kl_trade_off_lambda')  # number
        # overlapped edge features
        self.placeholders['overlapped_edge_features'] = tf.placeholder(tf.int32, [None, None, None],
                                                                       name='overlapped_edge_features')  # [b, es, v]

        # weights for the encoder and decoder GNNs
        if self.params["residual_connection_on"]:
            # separate weights for every propagation step
            for scope in ['_encoder', '_decoder']:
                if scope == '_encoder':
                    new_h_dim = h_dim
                else:
                    new_h_dim = expanded_h_dim
                for iter_idx in range(self.params['num_timesteps']):
                    with tf.variable_scope("gru_scope" + scope + str(iter_idx), reuse=False):
                        self.weights['edge_weights' + scope + str(iter_idx)] = tf.Variable(
                            glorot_init([self.num_edge_types, new_h_dim, new_h_dim]))
                        if self.params['use_edge_bias']:
                            self.weights['edge_biases' + scope + str(iter_idx)] = tf.Variable(
                                np.zeros([self.num_edge_types, 1, new_h_dim]).astype(np.float32))

                        cell = tf.contrib.rnn.GRUCell(new_h_dim)
                        cell = tf.nn.rnn_cell.DropoutWrapper(cell,
                                                             state_keep_prob=self.placeholders['graph_state_keep_prob'])
                        self.weights['node_gru' + scope + str(iter_idx)] = cell
        else:
            for scope in ['_encoder', '_decoder']:
                if scope == '_encoder':
                    new_h_dim = h_dim
                else:
                    new_h_dim = expanded_h_dim
                self.weights['edge_weights' + scope] = tf.Variable(
                    glorot_init([self.num_edge_types, new_h_dim, new_h_dim]))
                if self.params['use_edge_bias']:
                    self.weights['edge_biases' + scope] = tf.Variable(
                        np.zeros([self.num_edge_types, 1, new_h_dim]).astype(np.float32))
                with tf.variable_scope("gru_scope" + scope):
                    cell = tf.contrib.rnn.GRUCell(new_h_dim)
                    cell = tf.nn.rnn_cell.DropoutWrapper(cell,
                                                         state_keep_prob=self.placeholders['graph_state_keep_prob'])
                    self.weights['node_gru' + scope] = cell

        # weights for calculating mean and variance
        self.weights['mean_weights'] = tf.Variable(glorot_init([h_dim, h_dim]))
        self.weights['mean_biases'] = tf.Variable(np.zeros([1, h_dim]).astype(np.float32))
        self.weights['variance_weights'] = tf.Variable(glorot_init([h_dim, h_dim]))
        self.weights['variance_biases'] = tf.Variable(np.zeros([1, h_dim]).astype(np.float32))

        # The weights for generating node symbol logits
        self.weights['node_symbol_weights'] = tf.Variable(glorot_init([h_dim, self.params['num_symbols']]))
        self.weights['node_symbol_biases'] = tf.Variable(np.zeros([1, self.params['num_symbols']]).astype(np.float32))

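        # edge (and stop-node) inputs concatenate 6 feature blocks, each of size expanded_h_dim
        feature_dimension = 6 * expanded_h_dim
        # record the number of concatenated feature blocks (not the total width)
        self.params["feature_dimension"] = 6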
        # weights for generating edge type logits
        for i in range(self.num_edge_types):
            self.weights['edge_type_%d' % i] = tf.Variable(glorot_init([feature_dimension, feature_dimension]))
            self.weights['edge_type_biases_%d' % i] = tf.Variable(np.zeros([1, feature_dimension]).astype(np.float32))
            self.weights['edge_type_output_%d' % i] = tf.Variable(glorot_init([feature_dimension, 1]))
        # weights for generating edge logits
        self.weights['edge_iteration'] = tf.Variable(glorot_init([feature_dimension, feature_dimension]))
        self.weights['edge_iteration_biases'] = tf.Variable(np.zeros([1, feature_dimension]).astype(np.float32))
        self.weights['edge_iteration_output'] = tf.Variable(glorot_init([feature_dimension, 1]))
        # Weights for the stop node
        self.weights["stop_node"] = tf.Variable(glorot_init([1, expanded_h_dim]))
        # Weight for distance embedding
        self.weights['distance_embedding'] = tf.Variable(glorot_init([self.params['maximum_distance'], expanded_h_dim]))
        # Weight for overlapped edge feature
        self.weights["overlapped_edge_weight"] = tf.Variable(glorot_init([2, expanded_h_dim]))
        # weights for linear projection on qed prediction input
        self.weights['qed_weights'] = tf.Variable(glorot_init([h_dim, h_dim]))
        self.weights['qed_biases'] = tf.Variable(np.zeros([1, h_dim]).astype(np.float32))
        # use node embeddings
        self.weights["node_embedding"] = tf.Variable(glorot_init([self.params["num_symbols"], h_dim]))

        # graph state mask
        self.ops['graph_state_mask'] = tf.expand_dims(self.placeholders['node_mask'], 2)

    # transform one hot vector to dense embedding vectors
    def get_node_embedding_state(self, one_hot_state):
        node_nums = tf.argmax(one_hot_state, axis=2)
        return tf.nn.embedding_lookup(self.weights["node_embedding"], node_nums) * self.ops['graph_state_mask']

    def compute_final_node_representations_with_residual(self, h, adj, scope_name):  # scope_name: _encoder or _decoder
        # h: initial representation, adj: adjacency matrix, different GNN parameters for encoder and decoder
        v = self.placeholders['num_vertices']
        # the decoder state is wider: latent sample + node-symbol embedding + focus bit (h + h + 1)
        if scope_name == "_decoder":
            h_dim = self.params['hidden_size'] + self.params['hidden_size'] + 1
        else:
            h_dim = self.params['hidden_size']
        h = tf.reshape(h, [-1, h_dim])  # [b*v, h]
        # record all hidden states at each iteration
        all_hidden_states = [h]
        for iter_idx in range(self.params['num_timesteps']):
            with tf.variable_scope("gru_scope" + scope_name + str(iter_idx), reuse=None) as g_scope:
                for edge_type in range(self.num_edge_types):
                    # messages sent by every vertex along this edge type
                    m = tf.matmul(h, self.weights['edge_weights' + scope_name + str(iter_idx)][edge_type])  # [b*v, h]
                    if self.params['use_edge_bias']:
                        m += self.weights['edge_biases' + scope_name + str(iter_idx)][edge_type]  # [b*v, h]
                    m = tf.reshape(m, [-1, v, h_dim])  # [b, v, h]
                    # each vertex sums the messages from its neighbours under this edge type
                    if edge_type == 0:
                        acts = tf.matmul(adj[edge_type], m)
                    else:
                        acts += tf.matmul(adj[edge_type], m)
                # all messages collected for each node
                acts = tf.reshape(acts, [-1, h_dim])  # [b*v, h]
                # add residual connection here
                layer_residual_connections = self.params['residual_connections'].get(iter_idx)
                if layer_residual_connections is None:
                    layer_residual_states = []
                else:
                    layer_residual_states = [all_hidden_states[residual_layer_idx]
                                             for residual_layer_idx in layer_residual_connections]
                # concat current hidden states with residual states
                acts = tf.concat([acts] + layer_residual_states, axis=1)  # [b*v, (1 + num residual connections) * h]

                # feed msg inputs and hidden states to GRU
                h = self.weights['node_gru' + scope_name + str(iter_idx)](acts, h)[1]  # [b*v, h]
                # record the new hidden states
                all_hidden_states.append(h)
        last_h = tf.reshape(all_hidden_states[-1], [-1, v, h_dim])
        return last_h

    def compute_final_node_representations_without_residual(self, h, adj, edge_weights, edge_biases, node_gru,
                                                            gru_scope_name):
        # h: initial representation, adj: adjacency matrix, different GNN parameters for encoder and decoder
        v = self.placeholders['num_vertices']
        if gru_scope_name == "gru_scope_decoder":
            h_dim = self.params['hidden_size'] + self.params['hidden_size'] + 1  # + 1 focus bit, matching expanded_h_dim
        else:
            h_dim = self.params['hidden_size']
        h = tf.reshape(h, [-1, h_dim])

        with tf.variable_scope(gru_scope_name) as scope:
            for i in range(self.params['num_timesteps']):
                if i > 0:
                    tf.get_variable_scope().reuse_variables()
                for edge_type in range(self.num_edge_types):
                    m = tf.matmul(h, tf.nn.dropout(edge_weights[edge_type],
                                                   keep_prob=self.placeholders[
                                                       'edge_weight_dropout_keep_prob']))  # [b*v, h]
                    if self.params['use_edge_bias']:
                        m += edge_biases[edge_type]  # [b*v, h]
                    m = tf.reshape(m, [-1, v, h_dim])  # [b, v, h]
                    if edge_type == 0:
                        acts = tf.matmul(adj[edge_type], m)
                    else:
                        acts += tf.matmul(adj[edge_type], m)
                acts = tf.reshape(acts, [-1, h_dim])  # [b*v, h]
                h = node_gru(acts, h)[1]  # [b*v, h]
            last_h = tf.reshape(h, [-1, v, h_dim])
        return last_h

    def compute_mean_and_logvariance(self):
        h_dim = self.params['hidden_size']
        reshaped_last_h = tf.reshape(self.ops['final_node_representations'], [-1, h_dim])
        mean = tf.matmul(reshaped_last_h, self.weights['mean_weights']) + self.weights['mean_biases']
        logvariance = tf.matmul(reshaped_last_h, self.weights['variance_weights']) + self.weights['variance_biases']
        return mean, logvariance

    def sample_with_mean_and_logvariance(self):
        v = self.placeholders['num_vertices']
        h_dim = self.params['hidden_size']
        # Sample from normal distribution
        z_prior = tf.reshape(self.placeholders['z_prior'], [-1, h_dim])
        # Training: reparameterization, z = mean + sqrt(exp(logvariance)) * z_prior. Generation: z = z_prior ~ N(0, I)
        z_sampled = tf.cond(self.placeholders['is_generative'], lambda: z_prior,  # standard normal 
                            lambda: tf.add(self.ops['mean'], tf.multiply(tf.sqrt(tf.exp(self.ops['logvariance'])),
                                                                         z_prior)))  # non-standard normal
        # zero out padded (non-existent) nodes
        z_sampled = tf.reshape(z_sampled, [-1, v, h_dim]) * self.ops['graph_state_mask']
        return z_sampled

    def fully_connected(self, input, hidden_weight, hidden_bias, output_weight):
        output = tf.nn.relu(tf.matmul(input, hidden_weight) + hidden_bias)
        output = tf.matmul(output, output_weight)
        return output

    def generate_cross_entropy(self, idx, cross_entropy_losses, edge_predictions, edge_type_predictions):
        v = self.placeholders['num_vertices']
        h_dim = self.params['hidden_size']
        num_symbols = self.params['num_symbols']
        batch_size = tf.shape(self.placeholders['initial_node_representation'])[0]
        # the latent representation is the decoder GNN's input
        filtered_z_sampled = self.ops["initial_repre_for_decoder"]  # [b, v, h+h]
        # data needed in this iteration
        incre_adj_mat = self.placeholders['incre_adj_mat'][:, idx, :, :, :]  # [b, e, v, v]
        distance_to_others = self.placeholders['distance_to_others'][:, idx, :]  # [b,v]
        overlapped_edge_features = self.placeholders['overlapped_edge_features'][:, idx, :]  # [b,v]
        node_sequence = self.placeholders['node_sequence'][:, idx, :]  # [b, v]
        node_sequence = tf.expand_dims(node_sequence, axis=2)  # [b,v,1]
        edge_type_masks = self.placeholders['edge_type_masks'][:, idx, :, :]  # [b, e, v]
        # masked-out positions become -LARGE_NUMBER (valid ones 0), so the softmax gives them ~0 probability
        edge_type_masks = edge_type_masks * LARGE_NUMBER - LARGE_NUMBER
        edge_type_labels = self.placeholders['edge_type_labels'][:, idx, :, :]  # [b, e, v]
        edge_masks = self.placeholders['edge_masks'][:, idx, :]  # [b, v]
        # same masking trick for the edge logits
        edge_masks = edge_masks * LARGE_NUMBER - LARGE_NUMBER
        edge_labels = self.placeholders['edge_labels'][:, idx, :]  # [b, v]  
        local_stop = self.placeholders['local_stop'][:, idx]  # [b]        
        # concat the hidden states with the node in focus
        filtered_z_sampled = tf.concat([filtered_z_sampled, node_sequence], axis=2)  # [b, v, h + h + 1]
        # Decoder GNN
        if self.params["use_graph"]:
            # [b, e, v, v] -> [e, b, v, v] so that adj[edge_type] selects one edge type
            decoder_adj = tf.transpose(incre_adj_mat, [1, 0, 2, 3])
            if self.params["residual_connection_on"]:
                new_filtered_z_sampled = self.compute_final_node_representations_with_residual(
                    filtered_z_sampled, decoder_adj, "_decoder")  # [b, v, h + h + 1]
            else:
                new_filtered_z_sampled = self.compute_final_node_representations_without_residual(
                    filtered_z_sampled, decoder_adj,
                    self.weights['edge_weights_decoder'],
                    self.weights['edge_biases_decoder'],
                    self.weights['node_gru_decoder'],
                    "gru_scope_decoder")  # [b, v, h + h + 1]
        else:
            new_filtered_z_sampled = filtered_z_sampled
        # Filter nonexist nodes
        new_filtered_z_sampled = new_filtered_z_sampled * self.ops['graph_state_mask']
        # Take out the node in focus
        node_in_focus = tf.reduce_sum(node_sequence * new_filtered_z_sampled, axis=1)  # [b, h + h + 1]
        # edge pair representation: (node in focus, candidate node)
        edge_repr = tf.concat(
            [tf.tile(tf.expand_dims(node_in_focus, 1), [1, v, 1]), new_filtered_z_sampled],
            axis=2)  # [b, v, 2 * (h + h + 1)]
        # local (current decoder states) and global (decoder input) graph representations: mean over real nodes
        local_graph_repr_before_expansion = tf.reduce_sum(new_filtered_z_sampled, axis=1) / \
                                            tf.reduce_sum(self.placeholders['node_mask'], axis=1,
                                                          keepdims=True)  # [b, h + h + 1]
        local_graph_repr = tf.expand_dims(local_graph_repr_before_expansion, 1)
        local_graph_repr = tf.tile(local_graph_repr, [1, v, 1])  # [b, v, h + h + 1]
        global_graph_repr_before_expansion = tf.reduce_sum(filtered_z_sampled, axis=1) / \
                                             tf.reduce_sum(self.placeholders['node_mask'], axis=1, keepdims=True)
        global_graph_repr = tf.expand_dims(global_graph_repr_before_expansion, 1)
        global_graph_repr = tf.tile(global_graph_repr, [1, v, 1])  # [b, v, h + h + 1]
        # distance representation
        distance_repr = tf.nn.embedding_lookup(self.weights['distance_embedding'], distance_to_others)  # [b, v, h + h + 1]
        # overlapped edge feature representation
        overlapped_edge_repr = tf.nn.embedding_lookup(self.weights['overlapped_edge_weight'],
                                                      overlapped_edge_features)  # [b, v, h+h]
        # concatenate the 6 feature blocks and flatten to [b*v, 6 * (h + h + 1)]
        combined_edge_repr = tf.concat([edge_repr, local_graph_repr,
                                        global_graph_repr, distance_repr, overlapped_edge_repr], axis=2)

        combined_edge_repr = tf.reshape(combined_edge_repr,
                                        [-1, self.params["feature_dimension"] * (h_dim + h_dim + 1)])
        # Calculate edge logits
        edge_logits = self.fully_connected(combined_edge_repr, self.weights['edge_iteration'],
                                           self.weights['edge_iteration_biases'], self.weights['edge_iteration_output'])
        edge_logits = tf.reshape(edge_logits, [-1, v])  # [b, v]
        # filter invalid terms
        edge_logits = edge_logits + edge_masks
        # Calculate whether generation stops at this step: a learned "stop node" is scored with the same MLP as edges
        # prepare the data
        expanded_stop_node = tf.tile(self.weights['stop_node'], [batch_size, 1])  # [b, h + h + 1]
        distance_to_stop_node = tf.nn.embedding_lookup(self.weights['distance_embedding'],
                                                       tf.tile([0], [batch_size]))  # [b, h + h + 1]
        overlap_edge_stop_node = tf.nn.embedding_lookup(self.weights['overlapped_edge_weight'],
                                                        tf.tile([0], [batch_size]))  # [b, h + h + 1]

        combined_stop_node_repr = tf.concat([node_in_focus, expanded_stop_node, local_graph_repr_before_expansion,
                                             global_graph_repr_before_expansion, distance_to_stop_node,
                                             overlap_edge_stop_node], axis=1)  # [b, 6 * (h + h + 1)]
        # logits for the stop node
        stop_logits = self.fully_connected(combined_stop_node_repr,
                                           self.weights['edge_iteration'], self.weights['edge_iteration_biases'],
                                           self.weights['edge_iteration_output'])  # [b, 1]
        edge_logits = tf.concat([edge_logits, stop_logits], axis=1)  # [b, v + 1]

        # Calculate edge type logits
        edge_type_logits = []
        for i in range(self.num_edge_types):
            edge_type_logit = self.fully_connected(combined_edge_repr,
                                                   self.weights['edge_type_%d' % i],
                                                   self.weights['edge_type_biases_%d' % i],
                                                   self.weights[
                                                       'edge_type_output_%d' % i])  # [b * v, 1]                        
            edge_type_logits.append(tf.reshape(edge_type_logit, [-1, 1, v]))  # [b, 1, v]

        edge_type_logits = tf.concat(edge_type_logits, axis=1)  # [b, e, v]
        # filter invalid items
        edge_type_logits = edge_type_logits + edge_type_masks  # [b, e, v]
        # softmax over edge type axis
        edge_type_probs = tf.nn.softmax(edge_type_logits, 1)  # [b, e, v]

        # edge labels
        edge_labels = tf.concat([edge_labels, tf.expand_dims(local_stop, 1)], axis=1)  # [b, v + 1]
        # cross entropy over the candidate edges plus the stop node
        edge_loss = - tf.reduce_sum(tf.log(tf.nn.softmax(edge_logits) + SMALL_NUMBER) * edge_labels, axis=1)
        # cross entropy over edge types
        edge_type_loss = - edge_type_labels * tf.log(edge_type_probs + SMALL_NUMBER)  # [b, e, v]
        edge_type_loss = tf.reduce_sum(edge_type_loss, axis=[1, 2])  # [b]
        # total loss
        iteration_loss = edge_loss + edge_type_loss
        cross_entropy_losses = cross_entropy_losses.write(idx, iteration_loss)
        edge_predictions = edge_predictions.write(idx, tf.nn.softmax(edge_logits))
        edge_type_predictions = edge_type_predictions.write(idx, edge_type_probs)
        return (idx + 1, cross_entropy_losses, edge_predictions, edge_type_predictions)

    def construct_logit_matrices(self):
        v = self.placeholders['num_vertices']
        batch_size = tf.shape(self.placeholders['initial_node_representation'])[0]
        h_dim = self.params['hidden_size']

        # Initial state: embedding
        latent_node_state = self.get_node_embedding_state(self.placeholders["latent_node_symbols"])
        # concatenate z_sampled with the node-symbol embeddings
        filtered_z_sampled = tf.concat([self.ops['z_sampled'],
                                        latent_node_state], axis=2)  # [b, v, h + h]
        self.ops["initial_repre_for_decoder"] = filtered_z_sampled
        # The tensor array used to collect the cross entropy losses at each step
        cross_entropy_losses = tf.TensorArray(dtype=tf.float32, size=self.placeholders['max_iteration_num'])
        edge_predictions = tf.TensorArray(dtype=tf.float32, size=self.placeholders['max_iteration_num'])
        edge_type_predictions = tf.TensorArray(dtype=tf.float32, size=self.placeholders['max_iteration_num'])
        idx_final, cross_entropy_losses_final, edge_predictions_final, edge_type_predictions_final = \
            tf.while_loop(
                lambda idx, cross_entropy_losses, edge_predictions, edge_type_predictions: idx < self.placeholders[
                    'max_iteration_num'],
                self.generate_cross_entropy,
                (tf.constant(0), cross_entropy_losses, edge_predictions, edge_type_predictions,))

        # record the first iteration's predictions; during generation each session call runs a single iteration
        self.ops['edge_predictions'] = edge_predictions_final.read(0)
        self.ops['edge_type_predictions'] = edge_type_predictions_final.read(0)

        # final cross entropy losses
        cross_entropy_losses_final = cross_entropy_losses_final.stack()
        self.ops['cross_entropy_losses'] = tf.transpose(cross_entropy_losses_final, [1, 0])  # [b, es]

        # Logits for node symbols
        self.ops['node_symbol_logits'] = tf.reshape(
            tf.matmul(tf.reshape(self.ops['z_sampled'], [-1, h_dim]), self.weights['node_symbol_weights']) +
            self.weights['node_symbol_biases'], [-1, v, self.params['num_symbols']])

    def construct_loss(self):
        v = self.placeholders['num_vertices']
        h_dim = self.params['hidden_size']
        kl_trade_off_lambda = self.placeholders['kl_trade_off_lambda']
        # Edge loss: sum the per-iteration cross entropies, masking out padded iterations
        self.ops["edge_loss"] = tf.reduce_sum(self.ops['cross_entropy_losses'] * self.placeholders['iteration_mask'],
                                              axis=1)
        # KL divergence between N(mean, exp(logvariance)) and N(0, I), masked by graph_state_mask
        kl_loss = 1 + self.ops['logvariance'] - tf.square(self.ops['mean']) - tf.exp(self.ops['logvariance'])
        kl_loss = tf.reshape(kl_loss, [-1, v, h_dim]) * self.ops['graph_state_mask']
        self.ops['kl_loss'] = -0.5 * tf.reduce_sum(kl_loss, [1, 2])
        # Node symbol loss
        self.ops['node_symbol_prob'] = tf.nn.softmax(self.ops['node_symbol_logits'])
        self.ops['node_symbol_loss'] = -tf.reduce_sum(tf.log(self.ops['node_symbol_prob'] + SMALL_NUMBER) *
                                                      self.placeholders['node_symbols'], axis=[1, 2])
        # Property regression loss for each task (task 0 is assumed to be QED)
        for (internal_id, task_id) in enumerate(self.params['task_ids']):
            with tf.variable_scope("out_layer_task%i" % task_id):
                with tf.variable_scope("regression_gate"):
                    self.weights['regression_gate_task%i' % task_id] = MLP(self.params['hidden_size'], 1, [],
                                                                           self.placeholders[
                                                                               'out_layer_dropout_keep_prob'])
                with tf.variable_scope("regression"):
                    self.weights['regression_transform_task%i' % task_id] = MLP(self.params['hidden_size'], 1, [],
                                                                                self.placeholders[
                                                                                    'out_layer_dropout_keep_prob'])
                normalized_z_sampled = tf.nn.l2_normalize(self.ops['z_sampled'], 2)
                self.ops['qed_computed_values'] = computed_values = self.gated_regression(normalized_z_sampled,
                                                                                          self.weights[
                                                                                              'regression_gate_task%i' % task_id],
                                                                                          self.weights[
                                                                                              'regression_transform_task%i' % task_id],
                                                                                          self.params["hidden_size"],
                                                                                          self.weights['qed_weights'],
                                                                                          self.weights['qed_biases'],
                                                                                          self.placeholders[
                                                                                              'num_vertices'],
                                                                                          self.placeholders[
                                                                                              'node_mask'])
                diff = computed_values - self.placeholders['target_values'][internal_id, :]  # [b]
                task_target_mask = self.placeholders['target_mask'][internal_id, :]
                task_target_num = tf.reduce_sum(task_target_mask) + SMALL_NUMBER
                diff = diff * task_target_mask  # Mask out unused values [b]
                self.ops['accuracy_task%i' % task_id] = tf.reduce_sum(tf.abs(diff)) / task_target_num
                task_loss = tf.reduce_sum(0.5 * tf.square(diff)) / task_target_num  # number
                # Normalise loss to account for fewer task-specific examples in batch:
                task_loss = task_loss * (1.0 / (self.params['task_sample_ratios'].get(str(task_id)) or 1.0))
                self.ops['qed_loss'].append(task_loss)
                if task_id == 0:  # Assume it is the QED score
                    z_sampled_shape = tf.shape(self.ops['z_sampled'])
                    flattened_z_sampled = tf.reshape(self.ops['z_sampled'], [z_sampled_shape[0], -1])
                    self.ops['l2_loss'] = 0.01 * tf.reduce_sum(flattened_z_sampled * flattened_z_sampled, axis=1) / 2
                    # Gradient of (predicted QED - l2 penalty) with respect to z_sampled, used for gradient ascent in latent space
                    self.ops['derivative_z_sampled'] = tf.gradients(self.ops['qed_computed_values'] -
                                                                    self.ops['l2_loss'], self.ops['z_sampled'])
        self.ops['total_qed_loss'] = tf.reduce_sum(self.ops['qed_loss'])  # number
        self.ops['mean_edge_loss'] = tf.reduce_mean(self.ops["edge_loss"])  # record the mean edge loss
        self.ops['mean_node_symbol_loss'] = tf.reduce_mean(self.ops["node_symbol_loss"])
        self.ops['mean_kl_loss'] = tf.reduce_mean(kl_trade_off_lambda * self.ops['kl_loss'])
        self.ops['mean_total_qed_loss'] = self.params["qed_trade_off_lambda"] * self.ops['total_qed_loss']
        return tf.reduce_mean(self.ops["edge_loss"] + self.ops['node_symbol_loss'] +
                              kl_trade_off_lambda * self.ops['kl_loss']) \
               + self.params["qed_trade_off_lambda"] * self.ops['total_qed_loss']

    def gated_regression(self, last_h, regression_gate, regression_transform, hidden_size, projection_weight,
                         projection_bias, v, mask):
        # last_h: [b x v x h]
        last_h = tf.reshape(last_h, [-1, hidden_size])  # [b*v, h]
        # linear projection followed by a ReLU
        last_h = tf.nn.relu(tf.matmul(last_h, projection_weight) + projection_bias)  # [b*v, h]
        # the gate sees the same projected features
        gate_input = last_h
        # gated output per node: sigmoid(gate) * tanh(transform)
        gated_outputs = tf.nn.sigmoid(regression_gate(gate_input)) * tf.nn.tanh(
            regression_transform(last_h))  # [b*v, 1]
        gated_outputs = tf.reshape(gated_outputs, [-1, v])  # [b, v]
        # zero out padded nodes, then sum over the real nodes of each graph
        masked_gated_outputs = gated_outputs * mask  # [b x v]
        output = tf.reduce_sum(masked_gated_outputs, axis=1)  # [b]
        # squash the graph-level score into (0, 1)
        output = tf.sigmoid(output)
        return output

    def calculate_incremental_results(self, raw_data, bucket_sizes, file_name):
        incremental_results = []
        # copy the raw_data if more than 1 BFS path is added
        new_raw_data = []
        for idx, d in enumerate(raw_data):
            # Use canonical order or random order here. canonical order starts from index 0. random order starts from random nodes
            if not self.params["path_random_order"]:
                # Use several different starting index if using multi BFS path
                if self.params["multi_bfs_path"]:
                    list_of_starting_idx = list(range(self.params["bfs_path_count"]))
                else:
                    list_of_starting_idx = [0]  # the index 0
            else:
                # get the node length for this molecule
                node_length = len(d["node_features"])
                if self.params["multi_bfs_path"]:
                    list_of_starting_idx = np.random.choice(node_length, self.params["bfs_path_count"],
                                                            replace=True)  # randomly choose several
                else:
                    list_of_starting_idx = [random.choice(list(range(node_length)))]  # randomly choose one
            for list_idx, starting_idx in enumerate(list_of_starting_idx):
                # choose a bucket
                chosen_bucket_idx = np.argmax(bucket_sizes > max([v for e in d['graph']
                                                                  for v in [e[0], e[2]]]))
                chosen_bucket_size = bucket_sizes[chosen_bucket_idx]

                # Calculate incremental results without master node
                nodes_no_master, edges_no_master = to_graph(d['smiles'], self.params["dataset"])
                incremental_adj_mat, distance_to_others, node_sequence, edge_type_masks, edge_type_labels, local_stop, edge_masks, edge_labels, overlapped_edge_features = \
                    construct_incremental_graph(self.params["dataset"], edges_no_master, chosen_bucket_size,
                                                len(nodes_no_master), nodes_no_master, self.params,
                                                initial_idx=starting_idx)
                if self.params["sample_transition"] and list_idx > 0:
                    incremental_results[-1] = [x + y for x, y in
                                               zip(incremental_results[-1], [incremental_adj_mat, distance_to_others,
                                                                             node_sequence, edge_type_masks,
                                                                             edge_type_labels, local_stop, edge_masks,
                                                                             edge_labels, overlapped_edge_features])]
                else:
                    incremental_results.append([incremental_adj_mat, distance_to_others, node_sequence, edge_type_masks,
                                                edge_type_labels, local_stop, edge_masks, edge_labels,
                                                overlapped_edge_features])
                    # copy the raw_data here 
                    new_raw_data.append(d)
            if idx % 50 == 0:
                print('finished calculating incremental matrices for %d molecules' % idx, end="\r")
        return incremental_results, new_raw_data

    # ----- Data preprocessing and chunking into minibatches:
    def process_raw_graphs(self, raw_data, is_training_data, file_name, bucket_sizes=None):
        if bucket_sizes is None:
            bucket_sizes = dataset_info(self.params["dataset"])["bucket_sizes"]
        incremental_results, raw_data = self.calculate_incremental_results(raw_data, bucket_sizes, file_name)
        bucketed = defaultdict(list)
        x_dim = len(raw_data[0]["node_features"][0])

        for d, (incremental_adj_mat, distance_to_others, node_sequence, edge_type_masks, edge_type_labels, local_stop,
                edge_masks, edge_labels, overlapped_edge_features) \
                in zip(raw_data, incremental_results):
            # choose a bucket
            chosen_bucket_idx = np.argmax(bucket_sizes > max([v for e in d['graph']
                                                              for v in [e[0], e[2]]]))
            chosen_bucket_size = bucket_sizes[chosen_bucket_idx]
            # total number of nodes in this data point
            n_active_nodes = len(d["node_features"])
            bucketed[chosen_bucket_idx].append({
                'adj_mat': graph_to_adj_mat(d['graph'], chosen_bucket_size, self.num_edge_types,
                                            self.params['tie_fwd_bkwd']),
                'incre_adj_mat': incremental_adj_mat,
                'distance_to_others': distance_to_others,
                'overlapped_edge_features': overlapped_edge_features,
                'node_sequence': node_sequence,
                'edge_type_masks': edge_type_masks,
                'edge_type_labels': edge_type_labels,
                'edge_masks': edge_masks,
                'edge_labels': edge_labels,
                'local_stop': local_stop,
                'number_iteration': len(local_stop),
                'init': d["node_features"] + [[0 for _ in range(x_dim)] for __ in
                                              range(chosen_bucket_size - n_active_nodes)],
                'labels': [d["targets"][task_id][0] for task_id in self.params['task_ids']],
                'mask': [1. for _ in range(n_active_nodes)] + [0. for _ in range(chosen_bucket_size - n_active_nodes)]
            })

        if is_training_data:
            for (bucket_idx, bucket) in bucketed.items():
                np.random.shuffle(bucket)
                for internal_id, task_id in enumerate(self.params['task_ids']):
                    task_sample_ratio = self.params['task_sample_ratios'].get(str(task_id))
                    if task_sample_ratio is not None:
                        ex_to_sample = int(len(bucket) * task_sample_ratio)
                        # drop labels beyond the sampled fraction; 'labels' is indexed by position in task_ids
                        for ex_id in range(ex_to_sample, len(bucket)):
                            bucket[ex_id]['labels'][internal_id] = None

        bucket_at_step = [[bucket_idx for _ in range(len(bucket_data) // self.params['batch_size'])]
                          for bucket_idx, bucket_data in bucketed.items()]
        bucket_at_step = [x for y in bucket_at_step for x in y]

        return (bucketed, bucket_sizes, bucket_at_step)

    def pad_annotations(self, annotations):
        return np.pad(annotations,
                      pad_width=[[0, 0], [0, 0], [0, self.params['hidden_size'] - self.params["num_symbols"]]],
                      mode='constant')

    def make_batch(self, elements, maximum_vertice_num):
        # Maximum number of iterations in this batch; it bounds the tf.while_loop in the decoder
        max_iteration_num = -1
        for d in elements:
            max_iteration_num = max(d['number_iteration'], max_iteration_num)
        batch_data = {'adj_mat': [], 'init': [], 'labels': [], 'edge_type_masks': [], 'edge_type_labels': [],
                      'edge_masks': [],
                      'edge_labels': [], 'node_mask': [], 'task_masks': [], 'node_sequence': [],
                      'iteration_mask': [], 'local_stop': [], 'incre_adj_mat': [], 'distance_to_others': [],
                      'max_iteration_num': max_iteration_num, 'overlapped_edge_features': []}
        for d in elements:
            # The incremental data are stored sparsely to save memory; expand them to dense arrays here
            incre_adj_mat = incre_adj_mat_to_dense(d['incre_adj_mat'], self.num_edge_types, maximum_vertice_num)
            distance_to_others = distance_to_others_dense(d['distance_to_others'], maximum_vertice_num)
            overlapped_edge_features = overlapped_edge_features_to_dense(d['overlapped_edge_features'],
                                                                         maximum_vertice_num)
            node_sequence = node_sequence_to_dense(d['node_sequence'], maximum_vertice_num)
            edge_type_masks = edge_type_masks_to_dense(d['edge_type_masks'], maximum_vertice_num, self.num_edge_types)
            edge_type_labels = edge_type_labels_to_dense(d['edge_type_labels'], maximum_vertice_num,
                                                         self.num_edge_types)
            edge_masks = edge_masks_to_dense(d['edge_masks'], maximum_vertice_num)
            edge_labels = edge_labels_to_dense(d['edge_labels'], maximum_vertice_num)

            batch_data['adj_mat'].append(d['adj_mat'])
            batch_data['init'].append(d['init'])
            batch_data['node_mask'].append(d['mask'])

            batch_data['incre_adj_mat'].append(incre_adj_mat +
                                               [np.zeros(
                                                   (self.num_edge_types, maximum_vertice_num, maximum_vertice_num))
                                                   for _ in range(max_iteration_num - d['number_iteration'])])
            batch_data['distance_to_others'].append(distance_to_others +
                                                    [np.zeros((maximum_vertice_num))
                                                     for _ in range(max_iteration_num - d['number_iteration'])])
            batch_data['overlapped_edge_features'].append(overlapped_edge_features +
                                                          [np.zeros((maximum_vertice_num))
                                                           for _ in range(max_iteration_num - d['number_iteration'])])
            batch_data['node_sequence'].append(node_sequence +
                                               [np.zeros((maximum_vertice_num))
                                                for _ in range(max_iteration_num - d['number_iteration'])])
            batch_data['edge_type_masks'].append(edge_type_masks +
                                                 [np.zeros((self.num_edge_types, maximum_vertice_num))
                                                  for _ in range(max_iteration_num - d['number_iteration'])])
            batch_data['edge_masks'].append(edge_masks +
                                            [np.zeros((maximum_vertice_num))
                                             for _ in range(max_iteration_num - d['number_iteration'])])
            batch_data['edge_type_labels'].append(edge_type_labels +
                                                  [np.zeros((self.num_edge_types, maximum_vertice_num))
                                                   for _ in range(max_iteration_num - d['number_iteration'])])
            batch_data['edge_labels'].append(edge_labels +
                                             [np.zeros((maximum_vertice_num))
                                              for _ in range(max_iteration_num - d['number_iteration'])])
            batch_data['iteration_mask'].append([1 for _ in range(d['number_iteration'])] +
                                                [0 for _ in range(max_iteration_num - d['number_iteration'])])
            batch_data['local_stop'].append([int(s) for s in d["local_stop"]] +
                                            [0 for _ in range(max_iteration_num - d['number_iteration'])])

            target_task_values = []
            target_task_mask = []
            for target_val in d['labels']:
                if target_val is None:  # This is one of the examples we didn't sample...
                    target_task_values.append(0.)
                    target_task_mask.append(0.)
                else:
                    target_task_values.append(target_val)
                    target_task_mask.append(1.)
            batch_data['labels'].append(target_task_values)
            batch_data['task_masks'].append(target_task_mask)

        return batch_data

    def get_dynamic_feed_dict(self, elements, latent_node_symbol, incre_adj_mat, num_vertices,
                              distance_to_others, overlapped_edge_dense, node_sequence, edge_type_masks, edge_masks,
                              random_normal_states):
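        # When only node-symbol predictions or latent gradients are needed, feed
        # minimal dummy arrays of the right rank to the incremental-generation placeholders
        if incre_adj_mat is None: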
            incre_adj_mat = np.zeros((1, 1, self.num_edge_types, 1, 1))
            distance_to_others = np.zeros((1, 1, 1))
            overlapped_edge_dense = np.zeros((1, 1, 1))
            node_sequence = np.zeros((1, 1, 1))
            edge_type_masks = np.zeros((1, 1, self.num_edge_types, 1))
            edge_masks = np.zeros((1, 1, 1))
            latent_node_symbol = np.zeros((1, 1, self.params["num_symbols"]))
        return {
            self.placeholders['z_prior']: random_normal_states,  # [1, v, h]
            self.placeholders['incre_adj_mat']: incre_adj_mat,  # [1, 1, e, v, v]
            self.placeholders['num_vertices']: num_vertices,  # v

            self.placeholders['initial_node_representation']: \
                self.pad_annotations([elements['init']]),
            self.placeholders['node_symbols']: [elements['init']],
            self.placeholders['latent_node_symbols']: self.pad_annotations(latent_node_symbol),
            self.placeholders['adjacency_matrix']: [elements['adj_mat']],
            self.placeholders['node_mask']: [elements['mask']],

            self.placeholders['graph_state_keep_prob']: 1,
            self.placeholders['edge_weight_dropout_keep_prob']: 1,
            self.placeholders['iteration_mask']: [[1]],
            self.placeholders['is_generative']: True,
            self.placeholders['out_layer_dropout_keep_prob']: 1.0,
            self.placeholders['distance_to_others']: distance_to_others,  # [1, 1,v]
            self.placeholders['overlapped_edge_features']: overlapped_edge_dense,
            self.placeholders['max_iteration_num']: 1,
            self.placeholders['node_sequence']: node_sequence,  # [1, 1, v]
            self.placeholders['edge_type_masks']: edge_type_masks,  # [1, 1, e, v]
            self.placeholders['edge_masks']: edge_masks,  # [1, 1, v]
        }

    def get_node_symbol(self, batch_feed_dict):
        fetch_list = [self.ops['node_symbol_prob']]
        result = self.sess.run(fetch_list, feed_dict=batch_feed_dict)
        return result[0]

    def node_symbol_one_hot(self, sampled_node_symbol, real_n_vertices, max_n_vertices):
        one_hot_representations = []
        for idx in range(max_n_vertices):
            representation = [0] * self.params["num_symbols"]
            if idx < real_n_vertices:
                atom_type = sampled_node_symbol[idx]
                representation[atom_type] = 1
            one_hot_representations.append(representation)
        return one_hot_representations

    def search_and_generate_molecule(self, initial_idx, valences,
                                     sampled_node_symbol, real_n_vertices, random_normal_states,
                                     elements, max_n_vertices):
        # New molecule
        new_mol = Chem.MolFromSmiles('')
        new_mol = Chem.rdchem.RWMol(new_mol)
        # Add atoms
        add_atoms(new_mol, sampled_node_symbol, self.params["dataset"])
        # Breadth first search over the molecule
        queue = deque([initial_idx])
        # color 0: not yet discovered, 1: in the queue, 2: already explored
        color = [0] * max_n_vertices
        color[initial_idx] = 1
        # Empty adj list at the beginning
        incre_adj_list = defaultdict(list)
        # record the log probabilities at each step
        total_log_prob = 0
        while len(queue) > 0:
            node_in_focus = queue.popleft()
            # keep adding edges from node_in_focus until the stop node is selected
            while True:
                # Prepare data for one iteration based on the graph state
                edge_type_mask_sparse, edge_mask_sparse = generate_mask(valences, incre_adj_list, color,
                                                                        real_n_vertices, node_in_focus,
                                                                        self.params["check_overlap_edge"], new_mol)
                edge_type_mask = edge_type_masks_to_dense([edge_type_mask_sparse], max_n_vertices,
                                                          self.num_edge_types)  # [1, e, v]
                edge_mask = edge_masks_to_dense([edge_mask_sparse], max_n_vertices)  # [1, v]
                node_sequence = node_sequence_to_dense([node_in_focus], max_n_vertices)  # [1, v]
                distance_to_others_sparse = bfs_distance(node_in_focus, incre_adj_list)
                distance_to_others = distance_to_others_dense([distance_to_others_sparse], max_n_vertices)  # [1, v]
                overlapped_edge_sparse = get_overlapped_edge_feature(edge_mask_sparse, color, new_mol)

                overlapped_edge_dense = overlapped_edge_features_to_dense([overlapped_edge_sparse],
                                                                          max_n_vertices)  # [1, v]
                incre_adj_mat = incre_adj_mat_to_dense([incre_adj_list],
                                                       self.num_edge_types, max_n_vertices)  # [1, e, v, v]
                sampled_node_symbol_one_hot = self.node_symbol_one_hot(sampled_node_symbol, real_n_vertices,
                                                                       max_n_vertices)

                # get feed_dict
                feed_dict = self.get_dynamic_feed_dict(elements, [sampled_node_symbol_one_hot],
                                                       [incre_adj_mat], max_n_vertices, [distance_to_others],
                                                       [overlapped_edge_dense],
                                                       [node_sequence], [edge_type_mask], [edge_mask],
                                                       random_normal_states)

                # fetch nn predictions
                fetch_list = [self.ops['edge_predictions'], self.ops['edge_type_predictions']]
                edge_probs, edge_type_probs = self.sess.run(fetch_list, feed_dict=feed_dict)
                # select an edge
                if not self.params["use_argmax_generation"]:
                    neighbor = np.random.choice(np.arange(max_n_vertices + 1), p=edge_probs[0])
                else:
                    neighbor = np.argmax(edge_probs[0])
                # update log prob
                total_log_prob += np.log(edge_probs[0][neighbor] + SMALL_NUMBER)
                # stop if the stop node is picked
                if neighbor == max_n_vertices:
                    break
                # otherwise choose an edge type
                if not self.params["use_argmax_generation"]:
                    bond = np.random.choice(np.arange(self.num_edge_types), p=edge_type_probs[0, :, neighbor])
                else:
                    bond = np.argmax(edge_type_probs[0, :, neighbor])
                # update log prob
                total_log_prob += np.log(edge_type_probs[0, :, neighbor][bond] + SMALL_NUMBER)
                # update valences
                valences[node_in_focus] -= (bond + 1)
                valences[neighbor] -= (bond + 1)
                # add the bond
                new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[bond])
                # add the edge to increment adj list
                incre_adj_list[node_in_focus].append((neighbor, bond))
                incre_adj_list[neighbor].append((node_in_focus, bond))
                # Explore neighbor nodes
                if color[neighbor] == 0:
                    queue.append(neighbor)
                    color[neighbor] = 1
            color[node_in_focus] = 2  # explored
        # Remove unconnected nodes
        remove_extra_nodes(new_mol)
        new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
        return new_mol, total_log_prob

    def gradient_ascent(self, random_normal_states, derivative_z_sampled):
        return random_normal_states + self.params['prior_learning_rate'] * derivative_z_sampled

    # optimization in latent space. generate one molecule for each optimization step
    def optimization_over_prior(self, random_normal_states, num_vertices, generated_all_similes, elements, count):
        # record how many optimization steps are taken
        step = 0
        # generate a new molecule
        self.generate_graph_with_state(random_normal_states, num_vertices, generated_all_similes, elements, step, count)
        fetch_list = [self.ops['derivative_z_sampled'], self.ops['qed_computed_values'], self.ops['l2_loss']]
        for _ in range(self.params['optimization_step']):
            # get current qed and derivative
            batch_feed_dict = self.get_dynamic_feed_dict(elements, None, None, num_vertices, None,
                                                         None, None, None, None,
                                                         random_normal_states)
            derivative_z_sampled, qed_computed_values, l2_loss = self.sess.run(fetch_list, feed_dict=batch_feed_dict)
            # update the states
            random_normal_states = self.gradient_ascent(random_normal_states,
                                                        derivative_z_sampled[0])
            # generate a new molecule
            step += 1
            self.generate_graph_with_state(random_normal_states, num_vertices,
                                           generated_all_similes, elements, step, count)
        return random_normal_states

    def generate_graph_with_state(self, random_normal_states, num_vertices,
                                  generated_all_similes, elements, step, count):
        # Get back node symbol predictions
        # Prepare dict
        node_symbol_batch_feed_dict = self.get_dynamic_feed_dict(elements, None, None,
                                                                 num_vertices, None, None, None, None, None,
                                                                 random_normal_states)
        # Get predicted node probs
        predicted_node_symbol_prob = self.get_node_symbol(node_symbol_batch_feed_dict)
        # Node numbers for each graph
        real_length = get_graph_length([elements['mask']])[0]  # [valid_node_number]
        # Sample node symbols
        sampled_node_symbol = sample_node_symbol(predicted_node_symbol_prob, [real_length],
                                                 self.params["dataset"])[0]  # [v]
        # Maximum valences for each node
        valences = get_initial_valence(sampled_node_symbol, self.params["dataset"])  # [v]
        # randomly pick the starting point or use zero 
        if not self.params["path_random_order"]:
            # Try different starting points
            if self.params["try_different_starting"]:
                # starting_point=list(range(self.params["num_different_starting"]))
                starting_point = random.sample(range(real_length),
                                               min(self.params["num_different_starting"], real_length))
            else:
                starting_point = [0]
        else:
            if self.params["try_different_starting"]:
                starting_point = random.sample(range(real_length),
                                               min(self.params["num_different_starting"], real_length))
            else:
                starting_point = [random.choice(list(range(real_length)))]  # randomly choose one
        # record all molecules from different starting points
        all_mol = []
        for idx in starting_point:
            # generate a new molecule
            new_mol, total_log_prob = self.search_and_generate_molecule(idx, np.copy(valences),
                                                                        sampled_node_symbol, real_length,
                                                                        random_normal_states, elements, num_vertices)
            # qm9: score each molecule by its total number of ring shapes
            if dataset == 'qm9' and new_mol is not None:
                all_mol.append((np.sum(shape_count(self.params["dataset"], True,
                                                   [Chem.MolToSmiles(new_mol)])[1]), total_log_prob, new_mol))
            # zinc and cep: score by the number of pentagons and hexagons (pentagons weighted 0.5 for zinc)
            elif dataset == 'zinc' and new_mol is not None:
                counts = shape_count(self.params["dataset"], True, [Chem.MolToSmiles(new_mol)])
                all_mol.append((0.5 * counts[1][2] + counts[1][3], total_log_prob, new_mol))
            elif dataset == 'cep' and new_mol is not None:
                all_mol.append((np.sum(shape_count(self.params["dataset"], True,
                                                   [Chem.MolToSmiles(new_mol)])[1][2:]), total_log_prob, new_mol))
        # select the best-scoring molecule (select_best in utils)
        best_mol = select_best(all_mol)
        # nothing generated
        if best_mol is None:
            return
        # visualize it
        make_dir('visualization_%s' % dataset)
        visualize_mol('visualization_%s/%d_%d.png' % (dataset, count, step), best_mol)
        # record the best molecule
        generated_all_similes.append(Chem.MolToSmiles(best_mol))
        dump('generated_smiles_%s' % (dataset), generated_all_similes)
        print("Real QED value")
        print(QED.qed(best_mol))
        if len(generated_all_similes) >= self.params['number_of_generation']:
            print("generation done")
            exit(0)

    def compensate_node_length(self, elements, bucket_size):
        maximum_length = bucket_size + self.params["compensate_num"]
        real_length = get_graph_length([elements['mask']])[0] + self.params["compensate_num"]
        elements['mask'] = [1] * real_length + [0] * (maximum_length - real_length)
        elements['init'] = np.zeros((maximum_length, self.params["num_symbols"]))
        elements['adj_mat'] = np.zeros((self.num_edge_types, maximum_length, maximum_length))
        return maximum_length

    def generate_new_graphs(self, data):
        # bucketed: data organized by bucket
        (bucketed, bucket_sizes, bucket_at_step) = data
        bucket_counters = defaultdict(int)
        # all generated SMILES strings
        generated_all_similes = []
        # number of graphs processed so far
        count = 0
        # shuffle the order in which buckets are visited
        np.random.shuffle(bucket_at_step)
        for step in range(len(bucket_at_step)):
            bucket = bucket_at_step[step]  # bucket number
            # data index
            start_idx = bucket_counters[bucket] * self.params['batch_size']
            end_idx = (bucket_counters[bucket] + 1) * self.params['batch_size']
            # batch data
            elements_batch = bucketed[bucket][start_idx:end_idx]
            for elements in elements_batch:
                # Add extra candidate node slots (compensate_num) for generation,
                # since BFS may leave some candidate nodes unused
                maximum_length = self.compensate_node_length(elements, bucket_sizes[bucket])
                # initial state: standard normal samples
                random_normal_states = generate_std_normal(1, maximum_length,
                                                           self.params['hidden_size'])  # [1, v, h]
                random_normal_states = self.optimization_over_prior(random_normal_states,
                                                                    maximum_length, generated_all_similes, elements,
                                                                    count)
                count += 1
            bucket_counters[bucket] += 1

    def make_minibatch_iterator(self, data, is_training: bool):
        (bucketed, bucket_sizes, bucket_at_step) = data
        if is_training:
            np.random.shuffle(bucket_at_step)
            for _, bucketed_data in bucketed.items():
                np.random.shuffle(bucketed_data)
        bucket_counters = defaultdict(int)
        dropout_keep_prob = self.params['graph_state_dropout_keep_prob'] if is_training else 1.
        edge_dropout_keep_prob = self.params['edge_weight_dropout_keep_prob'] if is_training else 1.
        for step in range(len(bucket_at_step)):
            bucket = bucket_at_step[step]
            start_idx = bucket_counters[bucket] * self.params['batch_size']
            end_idx = (bucket_counters[bucket] + 1) * self.params['batch_size']
            elements = bucketed[bucket][start_idx:end_idx]
            batch_data = self.make_batch(elements, bucket_sizes[bucket])

            num_graphs = len(batch_data['init'])
            initial_representations = batch_data['init']
            initial_representations = self.pad_annotations(initial_representations)
            batch_feed_dict = {
                self.placeholders['initial_node_representation']: initial_representations,
                self.placeholders['node_symbols']: batch_data['init'],
                self.placeholders['latent_node_symbols']: initial_representations,
                self.placeholders['target_values']: np.transpose(batch_data['labels'], axes=[1, 0]),
                self.placeholders['target_mask']: np.transpose(batch_data['task_masks'], axes=[1, 0]),
                self.placeholders['num_graphs']: num_graphs,
                self.placeholders['num_vertices']: bucket_sizes[bucket],
                self.placeholders['adjacency_matrix']: batch_data['adj_mat'],
                self.placeholders['node_mask']: batch_data['node_mask'],
                self.placeholders['graph_state_keep_prob']: dropout_keep_prob,
                self.placeholders['edge_weight_dropout_keep_prob']: edge_dropout_keep_prob,
                self.placeholders['iteration_mask']: batch_data['iteration_mask'],
                self.placeholders['incre_adj_mat']: batch_data['incre_adj_mat'],
                self.placeholders['distance_to_others']: batch_data['distance_to_others'],
                self.placeholders['node_sequence']: batch_data['node_sequence'],
                self.placeholders['edge_type_masks']: batch_data['edge_type_masks'],
                self.placeholders['edge_type_labels']: batch_data['edge_type_labels'],
                self.placeholders['edge_masks']: batch_data['edge_masks'],
                self.placeholders['edge_labels']: batch_data['edge_labels'],
                self.placeholders['local_stop']: batch_data['local_stop'],
                self.placeholders['max_iteration_num']: batch_data['max_iteration_num'],
                self.placeholders['kl_trade_off_lambda']: self.params['kl_trade_off_lambda'],
                self.placeholders['overlapped_edge_features']: batch_data['overlapped_edge_features']
            }
            bucket_counters[bucket] += 1
            yield batch_feed_dict
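

# --- Sketch (hypothetical helper, not part of the original CGVAE code) ---
# make_minibatch_iterator and generate_new_graphs walk the bucketed data the
# same way: bucket_at_step holds one bucket id per batch, and a per-bucket
# counter advances so that repeated visits to the same bucket take consecutive,
# non-overlapping slices of batch_size graphs. A minimal stand-alone version,
# assuming `bucketed` maps bucket ids to plain Python lists:
import collections


def sketch_bucket_batches(bucketed, bucket_at_step, batch_size):
    """Yield (bucket_id, batch) pairs in the order given by bucket_at_step."""
    counters = collections.defaultdict(int)
    for bucket in bucket_at_step:
        start = counters[bucket] * batch_size
        yield bucket, bucketed[bucket][start:start + batch_size]
        counters[bucket] += 1

# Example: sketch_bucket_batches({0: ['a', 'b', 'c'], 1: ['x', 'y']}, [0, 1, 0], 2)
# yields (0, ['a', 'b']), (1, ['x', 'y']), (0, ['c']).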


# if __name__ == "__main__":
#     args = docopt(__doc__)
#     dataset = args.get('--dataset')
#     try:
#         model = DenseGGNNChemModel(args)
#         evaluation = False
#         if evaluation:
#             model.example_evaluation()
#         else:
#             model.train()
#     except:
#         typ, value, tb = sys.exc_info()
#         traceback.print_exc()
#         pdb.post_mortem(tb)

#####################  data_augmentation.py

from utils import *
from copy import deepcopy


# Generate the edge masks from the valences and the adjacency list built so far.
# A (node_in_focus, neighbor, edge_type) triple is valid when neighbor's color < 2,
# node_in_focus != neighbor, there is no edge yet between node_in_focus and neighbor,
# and the edge satisfies the valence constraint.
def generate_mask(valences, adj_mat, color, real_n_vertices, node_in_focus, check_overlap_edge, new_mol):
    edge_type_mask = []
    edge_mask = []
    for neighbor in range(real_n_vertices):
        if neighbor != node_in_focus and color[neighbor] < 2 and \
                not check_adjacent_sparse(adj_mat, node_in_focus, neighbor)[0]:
            min_valence = min(valences[node_in_focus], valences[neighbor], 3)
            # Check whether two cycles have more than two overlap edges here
            # the neighbor color = 1 and there are left valences and 
            # adding that edge will not cause overlap edges.
            if check_overlap_edge and min_valence > 0 and color[neighbor] == 1:
                # attempt to add the edge
                new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[0])
                # Check whether there are two cycles having more than two overlap edges
                ssr = Chem.GetSymmSSSR(new_mol)
                overlap_flag = False
                for idx1 in range(len(ssr)):
                    for idx2 in range(idx1 + 1, len(ssr)):
                        if len(set(ssr[idx1]) & set(ssr[idx2])) > 2:
                            overlap_flag = True
                # remove that edge
                new_mol.RemoveBond(int(node_in_focus), int(neighbor))
                if overlap_flag:
                    continue
            for v in range(min_valence):
                assert v < 3
                edge_type_mask.append((node_in_focus, neighbor, v))
            # there might be an edge between node in focus and neighbor
            if min_valence > 0:
                edge_mask.append((node_in_focus, neighbor))
    return edge_type_mask, edge_mask
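

# --- Sketch (hypothetical helper, not part of the original CGVAE code) ---
# generate_mask tentatively adds a single bond, runs Chem.GetSymmSSSR, and
# rejects the candidate edge if any two rings then share more than two atoms.
# The same pairwise test on plain lists of ring atom indices, without RDKit:
import itertools


def sketch_rings_overlap_too_much(rings, max_shared_atoms=2):
    """Return True if any pair of rings shares more than max_shared_atoms atoms."""
    return any(len(set(r1) & set(r2)) > max_shared_atoms
               for r1, r2 in itertools.combinations(rings, 2))

# Example: fused rings sharing one bond (two atoms) pass, rings sharing three atoms are flagged:
# sketch_rings_overlap_too_much([[0, 1, 2, 3, 4, 5], [4, 5, 6, 7, 8, 9]])  -> False
# sketch_rings_overlap_too_much([[0, 1, 2, 3, 4], [2, 3, 4, 5, 6]])        -> True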


# when a new edge is about to be added, we generate labels based on ground truth
# if an edge is in ground truth and has not been added to incremental adj yet, we label it as positive
def generate_label(ground_truth_graph, incremental_adj, node_in_focus, real_neighbor, real_n_vertices, params):
    edge_type_label = []
    edge_label = []
    for neighbor in range(real_n_vertices):
        adjacent, edge_type = check_adjacent_sparse(ground_truth_graph, node_in_focus, neighbor)
        incre_adjacent, incre_edge_type = check_adjacent_sparse(incremental_adj, node_in_focus, neighbor)
        if not params["label_one_hot"] and adjacent and not incre_adjacent:
            assert edge_type < 3
            edge_type_label.append((node_in_focus, neighbor, edge_type))
            edge_label.append((node_in_focus, neighbor))
        elif params["label_one_hot"] and adjacent and not incre_adjacent and neighbor == real_neighbor:
            edge_type_label.append((node_in_focus, neighbor, edge_type))
            edge_label.append((node_in_focus, neighbor))
    return edge_type_label, edge_label
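

# --- Sketch (hypothetical helper, not part of the original CGVAE code) ---
# With label_one_hot = False, generate_label marks as positive every
# ground-truth edge at node_in_focus that is not yet in the incremental graph.
# A self-contained version, assuming both graphs are adjacency dicts of
# node -> [(neighbor, edge_type), ...] as built in construct_incremental_graph;
# the local lookup below stands in for check_adjacent_sparse from utils.
def sketch_edge_labels(ground_truth, incremental, node_in_focus, n_vertices):
    def edge_type_between(adj, a, b):
        # Return the edge type of a-b, or None if the nodes are not adjacent
        for neighbor, edge_type in adj.get(a, []):
            if neighbor == b:
                return edge_type
        return None

    edge_type_labels, edge_labels = [], []
    for neighbor in range(n_vertices):
        target_type = edge_type_between(ground_truth, node_in_focus, neighbor)
        # edge_type can be 0, so compare against None explicitly
        if target_type is not None and edge_type_between(incremental, node_in_focus, neighbor) is None:
            edge_type_labels.append((node_in_focus, neighbor, target_type))
            edge_labels.append((node_in_focus, neighbor))
    return edge_type_labels, edge_labels

# Example: ground truth has 0-1 (type 0) and 0-2 (type 1); 0-1 is already added:
# sketch_edge_labels({0: [(1, 0), (2, 1)], 1: [(0, 0)], 2: [(0, 1)]},
#                    {0: [(1, 0)], 1: [(0, 0)]}, 0, 3)  -> ([(0, 2, 1)], [(0, 2)])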


# Return a copy of the incremental adjacency list with one new (undirected) edge added
def genereate_incremental_adj(last_adj, node_in_focus, neighbor, edge_type):
    # copy last incremental adj matrix
    new_adj = deepcopy(last_adj)
    # Add a new edge into it
    new_adj[node_in_focus].append((neighbor, edge_type))
    new_adj[neighbor].append((node_in_focus, edge_type))
    return new_adj


def update_one_step(overlapped_edge_features, distance_to_others, node_sequence, node_in_focus, neighbor, edge_type,
                    edge_type_masks, valences, incremental_adj_mat,
                    color, real_n_vertices, graph, edge_type_labels, local_stop, edge_masks, edge_labels,
                    local_stop_label, params,
                    check_overlap_edge, new_mol, up_to_date_adj_mat, keep_prob):
    # check whether to keep this transition or not
    if params["sample_transition"] and random.random() > keep_prob:
        return
    # record the current node in focus
    node_sequence.append(node_in_focus)
    # generate mask based on current situation
    edge_type_mask, edge_mask = generate_mask(valences, up_to_date_adj_mat,
                                              color, real_n_vertices, node_in_focus, check_overlap_edge, new_mol)
    edge_type_masks.append(edge_type_mask)
    edge_masks.append(edge_mask)
    if not local_stop_label:
        # generate the label based on ground truth graph
        edge_type_label, edge_label = generate_label(graph, up_to_date_adj_mat, node_in_focus, neighbor,
                                                     real_n_vertices, params)
        edge_type_labels.append(edge_type_label)
        edge_labels.append(edge_label)
    else:
        edge_type_labels.append([])
        edge_labels.append([])
    # update local stop 
    local_stop.append(local_stop_label)
    # Calculate BFS distances from the current node to all other nodes, capped at truncate_distance
    distances = bfs_distance(node_in_focus, up_to_date_adj_mat)
    distances = [(start, node, min(d, params["truncate_distance"])) for start, node, d in distances]
    distance_to_others.append(distances)
    # Calculate the overlapped edge features
    overlapped_edge_features.append(get_overlapped_edge_feature(edge_mask, color, new_mol))
    # update the incremental adj mat at this step
    incremental_adj_mat.append(deepcopy(up_to_date_adj_mat))
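

# --- Sketch (hypothetical helper, not part of the original CGVAE code) ---
# update_one_step records, for the node in focus, the BFS hop-count distance to
# the other nodes of the partial graph, capped at params["truncate_distance"].
# A stand-alone version using the (start, node, distance) tuples seen above.
# Assumption: whether utils.bfs_distance lists the start node or unreachable
# nodes is not shown here; this sketch lists reachable nodes only, start included.
import collections


def sketch_truncated_bfs_distances(start, adj, truncate_distance):
    distance = {start: 0}
    queue = collections.deque([start])
    while queue:
        node = queue.popleft()
        for neighbor, _edge_type in adj.get(node, []):
            if neighbor not in distance:
                distance[neighbor] = distance[node] + 1
                queue.append(neighbor)
    # Cap long distances, as update_one_step does
    return [(start, node, min(d, truncate_distance)) for node, d in sorted(distance.items())]

# Example: on the chain 0-1-2-3 with truncate_distance=2, starting at 0:
# [(0, 0, 0), (0, 1, 1), (0, 2, 2), (0, 3, 2)]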


def construct_incremental_graph(dataset, edges, max_n_vertices, real_n_vertices, node_symbol, params, initial_idx=0):
    # Skip this when only generating new molecules, to save time
    if params["generation"]:
        return [], [], [], [], [], [], [], [], []
    # Fall back to node 0 if the initial index is out of range
    if initial_idx >= real_n_vertices:
        initial_idx = 0
    # Maximum valences for each node
    valences = get_initial_valence([np.argmax(symbol) for symbol in node_symbol], dataset)
    # Add backward edges
    edges_bw = [(dst, edge_type, src) for src, edge_type, dst in edges]
    edges = edges + edges_bw
    # Construct a graph object using the edges
    graph = defaultdict(list)
    for src, edge_type, dst in edges:
        graph[src].append((dst, edge_type))
    # Breadth first search over the molecule 
    # color 0: have not found 1: in the queue 2: searched already
    color = [0] * max_n_vertices
    color[initial_idx] = 1
    queue = deque([initial_idx])
    # create an adjacency list with no edges yet
    up_to_date_adj_mat = defaultdict(list)
    # record incremental adj mat
    incremental_adj_mat = []
    # record the distance to other nodes at the moment
    distance_to_others = []
    # soft constraint on overlapped edges
    overlapped_edge_features = []
    # the exploration order of the nodes
    node_sequence = []
    # edge type masks for nn predictions at each step
    edge_type_masks = []
    # edge type labels for nn predictions at each step
    edge_type_labels = []
    # edge masks for nn predictions at each step
    edge_masks = []
    # edge labels for nn predictions at each step
    edge_labels = []
    # local stop labels
    local_stop = []
    # record the incremental molecule
    new_mol = Chem.MolFromSmiles('')
    new_mol = Chem.rdchem.RWMol(new_mol)
    # Add atoms
    add_atoms(new_mol, sample_node_symbol([node_symbol], [len(node_symbol)], dataset)[0], dataset)
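    # Keep probability for each transition (one per bond plus one local stop per node;
    # edges now holds both directions, so len(edges) / 2 is the bond count).
    # The numerator equals the first factor of the denominator, so this reduces to
    # 1 / params["bfs_path_count"]; each transition is kept independently, so the
    # number of kept transitions is binomially distributed.
    sample_transition_count = real_n_vertices + len(edges) / 2
    keep_prob = float(sample_transition_count) / (
            (real_n_vertices + len(edges) / 2) * params["bfs_path_count"])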
    while len(queue) > 0:
        node_in_focus = queue.popleft()
        current_adj_list = graph[node_in_focus]
        # sort (canonical order) it or shuffle (random order) it 
        if not params["path_random_order"]:
            current_adj_list = sorted(current_adj_list)
        else:
            random.shuffle(current_adj_list)
        for neighbor, edge_type in current_adj_list:
            # Add this edge if the color of neighbor node is not 2
            if color[neighbor] < 2:
                update_one_step(overlapped_edge_features, distance_to_others, node_sequence, node_in_focus, neighbor,
                                edge_type,
                                edge_type_masks, valences, incremental_adj_mat, color, real_n_vertices, graph,
                                edge_type_labels, local_stop, edge_masks, edge_labels, False, params,
                                params["check_overlap_edge"], new_mol,
                                up_to_date_adj_mat, keep_prob)
                # Add the edge and obtain a new adj mat
                up_to_date_adj_mat = genereate_incremental_adj(
                    up_to_date_adj_mat, node_in_focus, neighbor, edge_type)
                # the edge is selected, so reduce both nodes' remaining valences by its bond order
                valences[node_in_focus] -= (edge_type + 1)
                valences[neighbor] -= (edge_type + 1)
                # update the incremental mol
                new_mol.AddBond(int(node_in_focus), int(neighbor), number_to_bond[edge_type])
            # Explore neighbor nodes
            if color[neighbor] == 0:
                queue.append(neighbor)
                color[neighbor] = 1
        # local stop here. We move on to another node for exploration or stop completely
        update_one_step(overlapped_edge_features, distance_to_others, node_sequence, node_in_focus, None, None,
                        edge_type_masks,
                        valences, incremental_adj_mat, color, real_n_vertices, graph,
                        edge_type_labels, local_stop, edge_masks, edge_labels, True, params,
                        params["check_overlap_edge"], new_mol, up_to_date_adj_mat, keep_prob)
        color[node_in_focus] = 2

    return incremental_adj_mat, distance_to_others, node_sequence, edge_type_masks, edge_type_labels, local_stop, edge_masks, edge_labels, overlapped_edge_features
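

# --- Sketch (hypothetical helper, not part of the original CGVAE code) ---
# construct_incremental_graph replays a molecule as a BFS: it pops the node in
# focus, adds an edge to every neighbor that has not been fully expanded
# (color < 2), then records a "local stop" before moving on. The transition
# order alone, in canonical (sorted) order and without masks, labels or RDKit,
# using the same (src, edge_type, dst) edge layout:
import collections


def sketch_bfs_transitions(edges, n_vertices, start=0):
    graph = collections.defaultdict(list)
    for src, edge_type, dst in edges:
        graph[src].append((dst, edge_type))
        graph[dst].append((src, edge_type))
    color = [0] * n_vertices  # 0: not found, 1: in the queue, 2: expanded
    color[start] = 1
    queue = collections.deque([start])
    transitions = []
    while queue:
        focus = queue.popleft()
        for neighbor, edge_type in sorted(graph[focus]):
            if color[neighbor] < 2:
                transitions.append((focus, neighbor, edge_type))  # add this edge
            if color[neighbor] == 0:
                color[neighbor] = 1
                queue.append(neighbor)
        transitions.append((focus, "stop"))  # local stop for this node
        color[focus] = 2
    return transitions

# Example: a three-membered ring whose 0-2 edge has edge_type 1:
# sketch_bfs_transitions([(0, 0, 1), (1, 0, 2), (0, 1, 2)], 3)
# -> [(0, 1, 0), (0, 2, 1), (0, 'stop'), (1, 2, 0), (1, 'stop'), (2, 'stop')]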


##################  evaluate.py

# !/usr/bin/env/python
"""
Usage:
    evaluate.py --dataset zinc|qm9|cep

Options:
    -h --help                Show this screen.
    --dataset NAME           Dataset name: zinc, qm9, cep
"""

# if __name__ == '__main__':
#     args = docopt(__doc__)
#     dataset=args.get('--dataset')
#     logpscorer, logp_score_per_molecule=utils.check_logp(dataset)
#     qedscorer, qed_score_per_molecule=utils.check_qed(dataset)
#     novelty=utils.novelty_metric(dataset)
#     total, nonplanar=utils.check_planar(dataset)
#     total, atom_counter, atom_per_molecule =utils.count_atoms(dataset)
#     total, edge_type_counter, edge_type_per_molecule=utils.count_edge_type(dataset)
#     total, shape_count, shape_count_per_molecule=utils.shape_count(dataset)
#     total, tree_count=utils.check_cyclic(dataset)    
#     sascorer, sa_score_per_molecule=utils.check_sascorer(dataset)
#     total, validity=utils.check_validity(dataset)
# 
#     print("------------------------------------------")
#     print("Metrics")
#     print("------------------------------------------")
#     print("total molecule")
#     print(total)
#     print("------------------------------------------")
#     print("percentage of nonplanar:")
#     print(nonplanar/total)
#     print("------------------------------------------")
#     print("avg atom:")
#     for atom_type, c in atom_counter.items():
#         print(dataset_info(dataset)['atom_types'][atom_type])
#         print(c/total)
#     print("standard deviation")
#     print(np.std(atom_per_molecule, axis=0))
#     print("------------------------------------------")
#     print("avg edge_type:")
#     for edge_type, c in edge_type_counter.items():
#         print(edge_type+1)
#         print(c/total)
#     print("standard deviation")
#     print(np.std(edge_type_per_molecule, axis=0))
#     print("------------------------------------------")
#     print("avg shape:")
#     for shape, c in zip(utils.geometry_numbers, shape_count):
#         print(shape)
#         print(c/total)
#     print("standard deviation")
#     print(np.std(shape_count_per_molecule, axis=0))
#     print("------------------------------------------")
#     print("percentage of tree:")
#     print(tree_count/total)
#     print("------------------------------------------")
#     print("percentage of validity:")
#     print(validity/total)
#     print("------------------------------------------")
#     print("avg sa_score:")
#     print(sascorer)
#     print("standard deviation")
#     print(np.std(sa_score_per_molecule))
#     print("------------------------------------------")
#     print("avg logp_score:")
#     print(logpscorer)
#     print("standard deviation")
#     print(np.std(logp_score_per_molecule))
#     print("------------------------------------------")
#     print("percentage of novelty:")
#     print(novelty)
#     print("------------------------------------------")
#     print("avg qed_score:")
#     print(qedscorer)
#     print("standard deviation")
#     print(np.std(qed_score_per_molecule))
#     print("------------------------------------------")
#     print("uniqueness")
#     print(utils.check_uniqueness(dataset))
#     print("------------------------------------------")
#     print("percentage of SSSR")
#     print(utils.sssr_metric(dataset))
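

# --- Sketch (hypothetical helper, not part of the original CGVAE code) ---
# Two of the metrics computed in the commented-out block above, uniqueness and
# novelty, only need the generated and training SMILES. This sketch assumes both
# lists already hold canonical SMILES (so string equality means the same
# molecule) and measures both over all generated molecules; the exact
# definitions in utils.check_uniqueness and utils.novelty_metric may differ.
def sketch_uniqueness_and_novelty(generated_smiles, training_smiles):
    if not generated_smiles:
        return 0.0, 0.0
    training = set(training_smiles)
    uniqueness = len(set(generated_smiles)) / len(generated_smiles)
    novelty = sum(s not in training for s in generated_smiles) / len(generated_smiles)
    return uniqueness, novelty

# Example: sketch_uniqueness_and_novelty(['CCO', 'CCO', 'c1ccccc1', 'CN'], ['CCO'])
# -> (0.75, 0.5)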


################  CGNN_core.py


# !/usr/bin/env/python

from typing import List, Any
import time
import json
import random
import utils
from utils import MLP, dataset_info, ThreadedIterator, SMALL_NUMBER, LARGE_NUMBER, graph_to_adj_mat
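

# --- Sketch (hypothetical helper, not part of the original CGVAE code) ---
# ChemModel.__init__ below builds self.params in three layers, each overriding
# the previous one: default_params(), then the JSON file given by
# --config-file, then the inline JSON string given by --config. The same merge
# order on in-memory values:
import json


def sketch_merge_params(defaults, config_file_text=None, config_json=None):
    params = dict(defaults)
    if config_file_text is not None:  # contents of the --config-file JSON file
        params.update(json.loads(config_file_text))
    if config_json is not None:  # inline --config JSON string, applied last
        params.update(json.loads(config_json))
    return params

# Example: sketch_merge_params({'learning_rate': 0.001, 'batch_size': 8},
#                              '{"batch_size": 16}', '{"learning_rate": 0.01}')
# -> {'learning_rate': 0.01, 'batch_size': 16}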


class ChemModel(object):
    @classmethod
    def default_params(cls):
        return {

        }

    def __init__(self, args):
        self.args = args

        # Collect command-line arguments:
        data_dir = ''
        if '--data_dir' in args and args['--data_dir'] is not None:
            data_dir = args['--data_dir']
        self.data_dir = data_dir

        # Collect parameters:
        params = self.default_params()
        config_file = args.get('--config-file')
        if config_file is not None:
            with open(config_file, 'r') as f:
                params.update(json.load(f))
        config = args.get('--config')
        if config is not None:
            params.update(json.loads(config))
        self.params = params

        # Get which dataset in use
        self.params['dataset'] = dataset = args.get('--dataset')
        # Number of atom types of this dataset
        self.params['num_symbols'] = len(dataset_info(dataset)["atom_types"])

        self.run_id = "_".join([time.strftime("%Y-%m-%d-%H-%M-%S"), str(os.getpid())])
        log_dir = args.get('--log_dir') or '.'
        self.log_file = os.path.join(log_dir, "%s_log_%s.json" % (self.run_id, dataset))
        self.best_model_file = os.path.join(log_dir, "%s_model.pickle" % self.run_id)

        with open(os.path.join(log_dir, "%s_params_%s.json" % (self.run_id, dataset)), "w") as f:
            json.dump(params, f)
        print("Run %s starting with following parameters:\n%s" % (self.run_id, json.dumps(self.params)))
        random.seed(params['random_seed'])
        np.random.seed(params['random_seed'])

        # Load data:
        self.max_num_vertices = 0
        self.num_edge_types = 0
        self.annotation_size = 0
        self.train_data = self.load_data(params['train_file'], is_training_data=True)
        self.valid_data = self.load_data(params['valid_file'], is_training_data=False)

        # Build the actual model
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph, config=config)
        with self.graph.as_default():
            tf.set_random_seed(params['random_seed'])
            self.placeholders = {}
            self.weights = {}
            self.ops = {}
            self.make_model()
            self.make_train_step()

            # Restore/initialize variables:
            restore_file = args.get('--restore')
            if restore_file is not None:
                self.restore_model(restore_file)
            else:
                self.initialize_model()

    def load_data(self, file_name, is_training_data: bool):
        full_path = os.path.join(self.data_dir, file_name)

        print("Loading data from %s" % full_path)
        with open(full_path, 'r') as f:
            data = json.load(f)

        restrict = self.args.get("--restrict_data")
        # docopt returns option values as strings, so convert before comparing and slicing
        if restrict is not None and int(restrict) > 0:
            data = data[:int(restrict)]

        # Get some common data out:
        num_fwd_edge_types = len(utils.bond_dict) - 1
        for g in data:
            self.max_num_vertices = max(self.max_num_vertices, max([v for e in g['graph'] for v in [e[0], e[2]]]))

        self.num_edge_types = max(self.num_edge_types, num_fwd_edge_types * (1 if self.params['tie_fwd_bkwd'] else 2))
        self.annotation_size = max(self.annotation_size, len(data[0]["node_features"][0]))

        return self.process_raw_graphs(data, is_training_data, file_name)

    @staticmethod
    def graph_string_to_array(graph_string: str) -> List[List[int]]:
        return [[int(v) for v in s.split(' ')]
                for s in graph_string.split('\n')]

    def process_raw_graphs(self, raw_data, is_training_data, file_name, bucket_sizes=None):
        raise Exception("Models have to implement process_raw_graphs!")

    def make_model(self):
        self.placeholders['target_values'] = tf.placeholder(tf.float32, [len(self.params['task_ids']), None],
                                                            name='target_values')
        self.placeholders['target_mask'] = tf.placeholder(tf.float32, [len(self.params['task_ids']), None],
                                                          name='target_mask')
        self.placeholders['num_graphs'] = tf.placeholder(tf.int64, [], name='num_graphs')
        self.placeholders['out_layer_dropout_keep_prob'] = tf.placeholder(tf.float32, [],
                                                                          name='out_layer_dropout_keep_prob')
        # whether this session is for generating new graphs or not
        self.placeholders['is_generative'] = tf.placeholder(tf.bool, [], name='is_generative')

        with tf.variable_scope("graph_model"):
            self.prepare_specific_graph_model()

            # Initial state: embedding
            initial_state = self.get_node_embedding_state(self.placeholders['initial_node_representation'])

            # This does the actual graph work:
            if self.params['use_graph']:
                if self.params["residual_connection_on"]:
                    self.ops['final_node_representations'] = self.compute_final_node_representations_with_residual(
                        initial_state,
                        tf.transpose(self.placeholders['adjacency_matrix'], [1, 0, 2, 3]),
                        "_encoder")
                else:
                    self.ops['final_node_representations'] = self.compute_final_node_representations_without_residual(
                        initial_state,
                        tf.transpose(self.placeholders['adjacency_matrix'], [1, 0, 2, 3]),
                        self.weights['edge_weights_encoder'],
                        self.weights['edge_biases_encoder'], self.weights['node_gru_encoder'], "gru_scope_encoder")
            else:
                self.ops['final_node_representations'] = initial_state

        # Calculate p(z|x)'s mean and log variance
        self.ops['mean'], self.ops['logvariance'] = self.compute_mean_and_logvariance()
        # Sample from a gaussian distribution according to the mean and log variance
        self.ops['z_sampled'] = self.sample_with_mean_and_logvariance()
        # Construct logit matrices for both edges and edge types
        self.construct_logit_matrices()

        # Obtain losses for edges and edge types
        self.ops['qed_loss'] = []
        self.ops['loss'] = self.construct_loss()

    def make_train_step(self):
        trainable_vars = self.sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        if self.args.get('--freeze-graph-model'):
            graph_vars = set(self.sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="graph_model"))
            filtered_vars = []
            for var in trainable_vars:
                if var not in graph_vars:
                    filtered_vars.append(var)
                else:
                    print("Freezing weights of variable %s." % var.name)
            trainable_vars = filtered_vars

        optimizer = tf.train.AdamOptimizer(self.params['learning_rate'])
        grads_and_vars = optimizer.compute_gradients(self.ops['loss'], var_list=trainable_vars)
        clipped_grads = []
        for grad, var in grads_and_vars:
            if grad is not None:
                clipped_grads.append((tf.clip_by_norm(grad, self.params['clamp_gradient_norm']), var))
            else:
                clipped_grads.append((grad, var))
        grads_for_display = []
        for grad, var in grads_and_vars:
            if grad is not None:
                grads_for_display.append((tf.clip_by_norm(grad, self.params['clamp_gradient_norm']), var))
        self.ops['grads'] = grads_for_display
        self.ops['train_step'] = optimizer.apply_gradients(clipped_grads)
        # Initialize newly-introduced variables:
        self.sess.run(tf.local_variables_initializer())

    def gated_regression(self, last_h, regression_gate, regression_transform):
        raise Exception("Models have to implement gated_regression!")

    def prepare_specific_graph_model(self) -> None:
        raise Exception("Models have to implement prepare_specific_graph_model!")

    def compute_mean_and_logvariance(self):
        raise Exception("Models have to implement compute_mean_and_logvariance!")

    def sample_with_mean_and_logvariance(self):
        raise Exception("Models have to implement sample_with_mean_and_logvariance!")

    def construct_logit_matrices(self):
        raise Exception("Models have to implement construct_logit_matrices!")

    def construct_loss(self):
        raise Exception("Models have to implement construct_loss!")

    def make_minibatch_iterator(self, data: Any, is_training: bool):
        raise Exception("Models have to implement make_minibatch_iterator!")

    """
    def save_intermediate_results(self, adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels, mean, logvariance):
        with open('intermediate_results_%s' % self.params["dataset"], 'wb') as out_file:
            pickle.dump([adjacency_matrix, edge_type_prob, edge_type_label, node_symbol_prob, node_symbol, edge_prob, edge_prob_label, qed_prediction, qed_labels, mean, logvariance], out_file, pickle.HIGHEST_PROTOCOL)
    """

    def save_probs(self, all_results):
        with open('epoch_prob_matices_%s' % self.params["dataset"], 'wb') as out_file:
            pickle.dump([all_results], out_file, pickle.HIGHEST_PROTOCOL)

    def run_epoch(self, epoch_name: str, epoch_num, data, is_training: bool):
        loss = 0
        start_time = time.time()
        processed_graphs = 0
        batch_iterator = ThreadedIterator(self.make_minibatch_iterator(data, is_training), max_queue_size=5)

        for step, batch_data in enumerate(batch_iterator):
            num_graphs = batch_data[self.placeholders['num_graphs']]
            processed_graphs += num_graphs
            batch_data[self.placeholders['is_generative']] = False
            # Randomly sample from normal distribution
            batch_data[self.placeholders['z_prior']] = utils.generate_std_normal(
                self.params['batch_size'], batch_data[self.placeholders['num_vertices']], self.params['hidden_size'])
            if is_training:
                batch_data[self.placeholders['out_layer_dropout_keep_prob']] = self.params[
                    'out_layer_dropout_keep_prob']
                fetch_list = [self.ops['loss'], self.ops['train_step'],
                              self.ops["edge_loss"], self.ops['kl_loss'],
                              self.ops['node_symbol_prob'], self.placeholders['node_symbols'],
                              self.ops['qed_computed_values'], self.placeholders['target_values'],
                              self.ops['total_qed_loss'],
                              self.ops['mean'], self.ops['logvariance'],
                              self.ops['grads'], self.ops['mean_edge_loss'], self.ops['mean_node_symbol_loss'],
                              self.ops['mean_kl_loss'], self.ops['mean_total_qed_loss']]
            else:
                batch_data[self.placeholders['out_layer_dropout_keep_prob']] = 1.0
                fetch_list = [self.ops['mean_edge_loss'], self.ops['accuracy_task0']]
            result = self.sess.run(fetch_list, feed_dict=batch_data)

            """try:
                if is_training:
                    self.save_intermediate_results(batch_data[self.placeholders['adjacency_matrix']], 
                        result[11], result[12], result[4], result[5], result[9], result[10], result[6], result[7], result[13], result[14])
            except IndexError:
                pass"""

            batch_loss = result[0]
            loss += batch_loss * num_graphs

            print("Running %s, batch %i (has %i graphs). Loss so far: %.4f" % (epoch_name,
                                                                               step,
                                                                               num_graphs,
                                                                               loss / processed_graphs), end='\r')
        loss = loss / processed_graphs
        instance_per_sec = processed_graphs / (time.time() - start_time)
        return loss, instance_per_sec

    def generate_new_graphs(self, data):
        raise Exception("Models have to implement generate_new_graphs!")

    def train(self):
        log_to_save = []
        total_time_start = time.time()
        with self.graph.as_default():
            for epoch in range(1, self.params['num_epochs'] + 1):
                if not self.params['generation']:
                    print("== Epoch %i" % epoch)

                    train_loss, train_speed = self.run_epoch("epoch %i (training)" % epoch, epoch,
                                                             self.train_data, True)
                    print("\r\x1b[K Train: loss: %.5f| instances/sec: %.2f" % (train_loss, train_speed))

                    valid_loss, valid_speed = self.run_epoch("epoch %i (validation)" % epoch, epoch,
                                                             self.valid_data, False)

                    print("\r\x1b[K Valid: loss: %.5f | instances/sec: %.2f" % (valid_loss, valid_speed))

                    epoch_time = time.time() - total_time_start

                    log_entry = {
                        'epoch': epoch,
                        'time': epoch_time,
                        'train_results': (train_loss, train_speed),
                    }
                    log_to_save.append(log_entry)
                    with open(self.log_file, 'w') as f:
                        json.dump(log_to_save, f, indent=4)
                    self.save_model(str(epoch) + ("_%s.pickle" % (self.params["dataset"])))
                # Run epoches for graph generation
                if epoch >= self.params['epoch_to_generate']:
                    self.generate_new_graphs(self.train_data)

    def save_model(self, path: str) -> None:
        weights_to_save = {}
        for variable in self.sess.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
            assert variable.name not in weights_to_save
            weights_to_save[variable.name] = self.sess.run(variable)

        data_to_save = {
            "params": self.params,
            "weights": weights_to_save
        }

        with open(path, 'wb') as out_file:
            pickle.dump(data_to_save, out_file, pickle.HIGHEST_PROTOCOL)

    def initialize_model(self) -> None:
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        self.sess.run(init_op)

    def restore_model(self, path: str) -> None:
        print("Restoring weights from file %s." % path)
        with open(path, 'rb') as in_file:
            data_to_load = pickle.load(in_file)

        variables_to_initialize = []
        with tf.name_scope("restore"):
            restore_ops = []
            used_vars = set()
            for variable in self.sess.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
                used_vars.add(variable.name)
                if variable.name in data_to_load['weights']:
                    restore_ops.append(variable.assign(data_to_load['weights'][variable.name]))
                else:
                    print('Freshly initializing %s since no saved value was found.' % variable.name)
                    variables_to_initialize.append(variable)
            for var_name in data_to_load['weights']:
                if var_name not in used_vars:
                    print('Saved weights for %s not used by model.' % var_name)
            restore_ops.append(tf.variables_initializer(variables_to_initialize))
            self.sess.run(restore_ops)


In [None]:
#Imports
import tensorflow as tf
from typing import Sequence, Any
from docopt import docopt
from collections import defaultdict, deque
import sys, traceback
import pdb
# from CGVAE.CGVAE import DenseGGNNChemModel
# from CGVAE.GGNN_core import ChemModel
# import CGVAE.utils
# from CGVAE.utils import *
# from CGVAE.data_augmentation import *
from numpy import linalg as LA
from copy import deepcopy


In [None]:
#Prepare the dataset
def train_valid_split(download_path):
    # load the validation indices; a set makes each membership test O(1) instead of O(n)
    with open("/kaggle/working/2019-nCov/Data/valid_idx_zinc.json", 'r') as f:
        valid_idx = set(json.load(f))

    print('reading data...')
    raw_data = {'train': [], 'valid': []}  # save the train and valid datasets
    with open(download_path, 'r') as f:
        all_data = list(csv.DictReader(f))

    for i, data_item in enumerate(all_data):
        smiles = data_item['smiles'].strip()
        # lower-case qed avoids shadowing the rdkit.Chem.QED module imported earlier
        qed = float(data_item['qed'])
        split = 'valid' if i in valid_idx else 'train'
        raw_data[split].append({'smiles': smiles, 'QED': qed})
        if (i + 1) % 2000 == 0:
            print('finished reading: %d' % (i + 1), end='\r')
    return raw_data

def preprocess(raw_data, dataset):
    print('parsing smiles as graphs...')
    processed_data = {'train': [], 'valid': []}

    file_count = 0
    for section in ['train', 'valid']:
        all_smiles = []  # record all SMILES in this section
        for mol in raw_data[section]:
            # lower-case qed avoids shadowing the rdkit.Chem.QED module imported earlier
            smiles, qed = mol['smiles'], mol['QED']
            nodes, edges = to_graph(smiles, dataset)
            if len(edges) <= 0:
                continue
            processed_data[section].append({
                'targets': [[qed]],
                'graph': edges,
                'node_features': nodes,
                'smiles': smiles
            })
            all_smiles.append(smiles)
            if file_count % 2000 == 0:
                print('finished processing: %d' % file_count, end='\r')
            file_count += 1
        print('%s: 100 %%      ' % (section))
        # save the dataset
        with open('/kaggle/working/2019-nCov/Data/molecules_%s_%s.json' % (section, dataset), 'w') as f:
            json.dump(processed_data[section], f)
        # save all molecules in the training dataset; written with pickle directly
        # because the CGVAE.utils import in the imports cell is commented out
        if section == 'train':
            with open('/kaggle/working/2019-nCov/Data/smiles_%s.pkl' % dataset, 'wb') as f:
                pickle.dump(all_smiles, f, pickle.HIGHEST_PROTOCOL)

In [None]:
path = '/kaggle/working/2019-nCov/Data/250k_rndm_zinc_drugs_clean_3.csv'
raw_data = train_valid_split(path)
preprocess(raw_data, 'zinc')

In [None]:
# The various arguments for the CGVAE; see the implementation for explanations.
# Keep in mind I also trained this model for 10 epochs ahead of time, and the weights are found in
# CGVAE/10_zinc.pickle (remove '--restore' from the arguments below if you want to train it yourself).
# More parameters and their defaults are found in CGVAE.py.
args = {'--config': None,
 '--config-file': None,
 '--data_dir': '/kaggle/working/2019-nCov/Data/',
 '--dataset': 'zinc',
 '--freeze-graph-model': False,
 '--help': False,
 '--log_dir': '/kaggle/working/2019-nCov/CGVAE/',
 '--restore': '/kaggle/working/2019-nCov/CGVAE/10_zinc.pickle'}

In [None]:
# The implementation here is quite straightforward thanks to some lucky implementation choices by the original authors.
# I highly recommend checking out their implementation in the CGVAE/ directory; it's great stuff!
# Also, my graphics card can only handle a mini-batch size of ONE, so be sure to raise it if
# you've got a better card!
# Model hyperparameters can be found in CGVAE.py.
model = DenseGGNNChemModel(args)
model.train()

## 6. Validation - Molecular Docking Studies Using Autodock Vina

### Re-docking the N3 Ligand as a baseline

The first step in validating the proposed structures is to re-dock the N3 ligand into the protease structure, to get a baseline docking score for its binding. Note that the N3 ligand shown in the X-ray structure is a covalent inhibitor: it reacts with the active site of the protein and forms a covalent bond to it. This gives much stronger binding between ligand and target than non-covalent inhibition does.

As a result, the re-docked pose may not match the X-ray pose, since docking does not model the covalent bond to the protein. The following is the procedure for docking the N3 ligand into the protein receptor; the same procedure is used for all of the later dockings of the candidate molecules. It requires three programs:

PyMOL - https://pymol.org/2/ or https://github.com/schrodinger/pymol-open-source

AutoDock Vina - http://vina.scripps.edu/download.html

AutoDockTools, part of MGLTools - http://mgltools.scripps.edu/downloads

Two excellent video tutorials are found here: https://youtu.be/-GVZP0X0Tg8 and https://youtu.be/blxSn3Lhdec

Docking procedure with AutoDock Vina:

1. Open the structure of the protein-ligand complex (.cif crystallographic information file) in PyMOL.

2. Select the ligand chain (in the bottom right, click "residues" so that it switches to "chains" to be able to select a chain).

3. Delete the ligand and save the file as a .pdb file.

4. Re-load the original file, and this time select and delete the protein, saving only the ligand, also as a .pdb file.

5. Open AutoDockTools and load the protein target molecule with File>Read Molecule (.pdb file).

6. Add hydrogens (Edit>Hydrogens>Add>Polar_only>Okay).

7. Display the molecular surface to locate the binding site.

8. Select the 10 residues involved in the binding site: THR26, LEU27, HIS41, MET49, ASN142, CYS145, HIS163, GLU166, HIS172, GLN189.

9. Go to Grid>Gridbox to show the grid box, then adjust its center and size until it completely encloses the selected side chains. Record the coordinates; they should be: center x=-11.963, y=15.683, z=69.193; spacing 1 Å; points x=20, y=24, z=22.

10. Flexible Residues>Choose Molecule.

11. Flexible Residues>Choose Torsions in selected residues, then accept the defaults. Various bonds on the 10 selected residues should now be shown in different colours. **This is a very important step: not accounting for flexibility in the protein will affect the accuracy of the scoring.**

12. Flexible Residues>Output>SaveRigid. Save the rigid part of the protein as a .pdbqt file.

13. Flexible Residues>Output>Saveflexible. Save the flexible part of the protein as a .pdbqt file.

14. Delete or hide the receptor and load the ligand with Ligand>Input>Open>ligand.pdb>OK.

15. Add hydrogens to the ligand (Edit>Hydrogens>Add>Polar_only).

16. Export the ligand, now with hydrogens, as a .pdb file.

17. Re-load the updated .pdb file.

18. Define the rotatable bonds with Ligand>TorsionTree>ChooseTorsions>Okay.

19. Ligand>Output>Save as PDBQT.

20. Close AutoDockTools.

21. Create a configuration file for Vina that matches the structure of conf.txt in the Docking/ directory of this repo. The exhaustiveness parameter sets how thoroughly the conformational space is searched; run time grows roughly in proportion to it.

22. Run Vina with `./vina --config conf.txt`.
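The binding-site selection, grid box, and configuration-file steps can also be scripted. The sketch below is a minimal, hypothetical helper set (`residue_coords`, `grid_box`, and `vina_config` are names introduced here, not part of this repo): it reads the atoms of the chosen residues from a receptor PDB file using the fixed PDB column layout, builds an axis-aligned box around them with an assumed 4 Å padding, and renders the box as Vina config lines. It ignores chain IDs and insertion codes, and it will not necessarily reproduce the exact box chosen by hand in AutoDockTools.

```python
from typing import Iterable, List, Set, Tuple

Coord = Tuple[float, float, float]


def residue_coords(pdb_text: str, residues: Set[Tuple[str, int]]) -> List[Coord]:
    """Return x, y, z for every ATOM/HETATM record whose (resName, resSeq) is in `residues`.

    Relies on the fixed-column PDB layout (resName cols 18-20, resSeq 23-26,
    x/y/z cols 31-54). Chain IDs and insertion codes are ignored (assumption:
    residue numbers are unique in the file).
    """
    coords = []
    for line in pdb_text.splitlines():
        if not line.startswith(("ATOM", "HETATM")):
            continue
        key = (line[17:20].strip(), int(line[22:26]))
        if key in residues:
            coords.append((float(line[30:38]), float(line[38:46]), float(line[46:54])))
    return coords


def grid_box(coords: Iterable[Coord], padding: float = 4.0) -> Tuple[Coord, Coord]:
    """Axis-aligned box around `coords`, widened by `padding` Å on every side.

    The 4 Å default is an illustrative assumption, not a value from this notebook.
    """
    coords = list(coords)
    if not coords:
        raise ValueError("no atoms matched the selected residues")
    xs, ys, zs = zip(*coords)
    lows, highs = (min(xs), min(ys), min(zs)), (max(xs), max(ys), max(zs))
    center = tuple((lo + hi) / 2 for lo, hi in zip(lows, highs))
    size = tuple((hi - lo) + 2 * padding for lo, hi in zip(lows, highs))
    return center, size


def vina_config(receptor: str, flex: str, ligand: str,
                center: Coord, size: Coord, exhaustiveness: int = 8) -> str:
    """Render the box as AutoDock Vina config lines (8 is Vina's default exhaustiveness)."""
    lines = [f"receptor = {receptor}", f"flex = {flex}", f"ligand = {ligand}"]
    lines += [f"center_{axis} = {value:.3f}" for axis, value in zip("xyz", center)]
    lines += [f"size_{axis} = {value:.3f}" for axis, value in zip("xyz", size)]
    lines.append(f"exhaustiveness = {exhaustiveness}")
    return "\n".join(lines) + "\n"
```

For 6LU7 one would pass the ten binding-site residues listed above as a set such as `{("HIS", 41), ("CYS", 145), ...}` together with the text of the receptor .pdb file.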

In [None]:
#Visualizing the N3 ligand
Chem.MolFromSmiles("CC(C)C[C@H](NC(=O)[C@@H](NC(=O)[C@H](C)NC(=O)c1cc(C)on1)C(C)C)C(=O)N[C@@H](C[C@@H]2CCNC2=O)\C=C/C(=O)OCc3ccccc3")

Following this procedure for the N3 ligand, the lowest-energy pose scores around -7.9 kcal/mol. The exact value doesn't tell us much on its own, because the specific parameters of the docking scoring function can vary, but it serves as a baseline for comparing later candidates. The following is the lowest-energy structure. It is in fact very different from the X-ray structure because it lacks the covalent bond to the protein, with the N3 ligand "bending back" on itself in this conformation.
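One way to quantify how far the re-docked pose is from the X-ray pose is the root-mean-square deviation (RMSD) over the ligand atoms. The minimal sketch below assumes both poses list the same atoms in the same order (e.g. heavy atoms only) and are already in the receptor's coordinate frame, so no superposition is applied; it does not handle symmetry-equivalent atoms. `pose_rmsd` is a hypothetical helper, not part of this repo. A re-docked pose within about 2 Å of the crystal pose is commonly treated as a successful reproduction.

```python
import math
from typing import Sequence, Tuple

Coord = Tuple[float, float, float]


def pose_rmsd(pose_a: Sequence[Coord], pose_b: Sequence[Coord]) -> float:
    """RMSD (Å) between two poses of the same ligand with matching atom order.

    No alignment is performed: both poses are assumed to share the receptor's frame.
    """
    if len(pose_a) != len(pose_b) or not pose_a:
        raise ValueError("poses must be non-empty and have the same number of atoms")
    total = sum((ax - bx) ** 2 + (ay - by) ** 2 + (az - bz) ** 2
                for (ax, ay, az), (bx, by, bz) in zip(pose_a, pose_b))
    return math.sqrt(total / len(pose_a))


# Toy example: every atom displaced by (0, 3, 4), i.e. 5 Å
crystal = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0)]
docked = [(0.0, 3.0, 4.0), (1.5, 3.0, 4.0)]
print(pose_rmsd(crystal, docked))  # 5.0
```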

Next, the candidates were docked. The same procedure as above was followed for each candidate, with one additional step, shown below: loading the structures and saving them as PDB files to be opened in AutoDockTools.

### Preparing the high scoring and generated compounds for docking

In [None]:
best_predicted = pickle.load(open("/kaggle/working/2019-nCov/Data/best_predicted_smiles.pkl", "rb"))

In [None]:
best_predicted_mols = [Chem.MolFromSmiles(x) for x in best_predicted]

In [None]:
rdkit.Chem.Draw.MolsToGridImage(best_predicted_mols, molsPerRow=2, maxMols=100, subImgSize=(800, 800))

These compounds are visually very similar to the N3 ligand (see the visualization above), which suggests that restricting the search to compounds similar to N3 may have limited the diversity of the candidates.
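Visual similarity can be made quantitative with the Tanimoto coefficient between fingerprints, |A ∩ B| / |A ∪ B| for the sets of on-bits. The sketch below is pure Python and takes each fingerprint as a set of on-bit indices; in this notebook those sets could come from RDKit, e.g. `set(Chem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048).GetOnBits())` (the radius and bit count are assumptions). `tanimoto` and `rank_by_similarity` are hypothetical helper names.

```python
from typing import Dict, Iterable, List, Tuple


def tanimoto(bits_a: Iterable[int], bits_b: Iterable[int]) -> float:
    """Tanimoto similarity of two fingerprints given as collections of on-bit indices."""
    a, b = set(bits_a), set(bits_b)
    union = a | b
    # Convention used here: two empty fingerprints get similarity 0.0
    return len(a & b) / len(union) if union else 0.0


def rank_by_similarity(reference: Iterable[int],
                       candidates: Dict[str, Iterable[int]]) -> List[Tuple[str, float]]:
    """Sort candidates from most to least similar to the reference fingerprint."""
    reference = set(reference)
    scored = [(name, tanimoto(reference, bits)) for name, bits in candidates.items()]
    return sorted(scored, key=lambda item: item[1], reverse=True)
```

Ranking the best predicted compounds against N3 this way would show numerically how close they are, rather than relying on the grid image alone.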

In [None]:
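def write_to_pdb(m, name):
    # Add explicit hydrogens and generate a 3D conformer before writing
    m = Chem.AddHs(m)
    if Chem.EmbedMolecule(m) == -1:
        raise ValueError("3D embedding failed for %s" % name)
    # MolToPDBFile opens and closes the file itself, so the output is fully flushed to disk
    Chem.MolToPDBFile(m, "/kaggle/working/2019-nCov/Docking/" + str(name) + ".pdb")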

In [None]:
for i, mol in enumerate(best_predicted_mols, start=1):
    write_to_pdb(mol, "bp_" + str(i))

In [None]:
generated = pickle.load(open("/kaggle/working/2019-nCov/Data/first_generated_smiles_zinc",
                             "rb"))
generated = [Chem.MolFromSmiles(x) for x in generated]
rdkit.Chem.Draw.MolsToGridImage(generated, molsPerRow=2, maxMols=100, subImgSize=(800, 800))

Most of the compounds generated by this method are VERY strange, many with large, unusual rings. This is interesting because this paper is a relatively early example of generating molecular graphs, and more recent methods have used a penalty on large rings such as the ones seen in these compounds. However, all is not lost: several of these compounds still have interesting structures and are small and comparable in size to the N3 ligand.

Of these 50 generated compounds, the ones that appeared most visually similar to the N3 ligand (mainly just the small ones, of which there aren't many) are at indices [5, 6, 9, 32, 33, 34, 36, 38, 43, 44].

In [None]:
for i in [5,6,9,32,33,34,36,38,43,44]:
    write_to_pdb(generated[i], str(i))

After preparation, the ligands were docked with AutoDock Vina, using the script multi_dock.sh to automate docking many compounds.
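The contents of multi_dock.sh are not shown here. As a hedged sketch of the follow-up step, the Python below pulls the scores out of Vina's output files: Vina writes one `REMARK VINA RESULT:` line per pose (affinity in kcal/mol, then two RMSD bounds), and the best pose is the one with the most negative affinity. The `out_<name>.pdbqt` naming follows the config-file cell below; `vina_affinities`, `best_affinity`, and `summarize` are hypothetical helper names.

```python
import re
from pathlib import Path
from typing import Dict, List, Optional

# One "REMARK VINA RESULT:  affinity  rmsd_lb  rmsd_ub" line per output pose
_VINA_RESULT = re.compile(r"^REMARK VINA RESULT:\s+(-?\d+(?:\.\d+)?)", re.MULTILINE)


def vina_affinities(pdbqt_text: str) -> List[float]:
    """All pose affinities (kcal/mol) found in a Vina output PDBQT, in file order."""
    return [float(value) for value in _VINA_RESULT.findall(pdbqt_text)]


def best_affinity(pdbqt_text: str) -> Optional[float]:
    """Most negative affinity (the best-scoring pose), or None if no poses were found."""
    scores = vina_affinities(pdbqt_text)
    return min(scores) if scores else None


def summarize(docking_dir: str, pattern: str = "out_*.pdbqt") -> Dict[str, Optional[float]]:
    """Map each ligand name (taken from out_<name>.pdbqt) to its best score."""
    results = {}
    for path in sorted(Path(docking_dir).glob(pattern)):
        name = path.stem[len("out_"):] if path.stem.startswith("out_") else path.stem
        results[name] = best_affinity(path.read_text())
    return results
```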

In [None]:
# Preparation of the config files for running AutoDock Vina from the command line.
# This is only for the "best predicted" compounds; I docked the other ones before
# I made this script, but the names could easily be changed.
docking_dir = "/u/macdougt/Research/2019-nCov/Docking/"
names = ["bp_" + str(i) for i in range(1, 11)]
for name in names:
    with open("/kaggle/working/2019-nCov/Docking/conf_" + name + ".txt", "w") as f:
        f.write("receptor = " + docking_dir + "6LU7_receptor_rigid.pdbqt\n")
        f.write("flex = " + docking_dir + "6LU7_receptor_flex.pdbqt\n")
        f.write("ligand = " + docking_dir + name + ".pdbqt\n")

        f.write("out = " + docking_dir + "out_" + name + ".pdbqt\n")
        f.write("log = " + docking_dir + "log_" + name + ".txt\n")

        # grid box center and size (Å) from the AutoDockTools procedure above
        f.write("center_x = -11.963\n")
        f.write("center_y = 15.683\n")
        f.write("center_z = 69.193\n")

        f.write("size_x = 20\n")
        f.write("size_y = 24\n")
        f.write("size_z = 22\n")

        f.write("exhaustiveness = 80\n")
        f.write("cpu = 7\n")

### Scores of various compounds

### Best predicted compounds

In [None]:
rdkit.Chem.Draw.MolsToGridImage(best_predicted_mols, molsPerRow=2, maxMols=100, subImgSize=(800, 800), legends=["-8.2", "-7.5", "-8.0", "-7.6", "-8.5", "-7.5", "-8.2", "-8.7", "-7.6", "-8.4"])

### Generated compounds

In [None]:
docked_generated = [generated[i] for i in [5,6,9,32,33,34,36,38,43,44]]

In [None]:
rdkit.Chem.Draw.MolsToGridImage(docked_generated, molsPerRow=2, maxMols=100, subImgSize=(800, 800), legends=["-7.3", "-8.2", "-8.8", "-9.8", "-8.4", "-6.9", "-7.6", "-7.9", "-6.7", "-9.4"])
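Picking the best-scoring compound from each scheme is just a minimum over the docking scores in the two grids above (more negative is better). The sketch below re-enters those scores by hand and compares each winner with the -7.9 kcal/mol N3 baseline; it is bookkeeping only, and `best_pose` is a hypothetical helper name.

```python
from typing import Dict, Tuple

BASELINE_N3 = -7.9  # re-docked N3 score reported above (kcal/mol)

# Keys are indices into best_predicted_mols and generated, respectively
predicted_scores = dict(enumerate([-8.2, -7.5, -8.0, -7.6, -8.5, -7.5, -8.2, -8.7, -7.6, -8.4]))
generated_scores = dict(zip([5, 6, 9, 32, 33, 34, 36, 38, 43, 44],
                            [-7.3, -8.2, -8.8, -9.8, -8.4, -6.9, -7.6, -7.9, -6.7, -9.4]))


def best_pose(scores: Dict[int, float]) -> Tuple[int, float]:
    """Index and score of the most negative (best) Vina affinity."""
    idx = min(scores, key=scores.get)
    return idx, scores[idx]


for label, scores in [("predicted", predicted_scores), ("generated", generated_scores)]:
    idx, score = best_pose(scores)
    print(f"best {label}: index {idx}, {score} kcal/mol ({score - BASELINE_N3:+.1f} vs. N3)")
```

This prints index 7 at -8.7 kcal/mol (-0.8 vs. N3) for the predicted compounds and index 32 at -9.8 kcal/mol (-1.9 vs. N3) for the generated ones, matching the two compounds shown next.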

Choosing the best-scoring compound from these two schemes, we get the following compound from the prediction method, with a score of -8.7 kcal/mol:

In [None]:
best_predicted_mols[7]

The best compound from the generative method is shown below, with a score of -9.8 kcal/mol:

In [None]:
generated[32]

### Visualizing the high scoring compounds in the active site

The following is the highest scoring predicted compound mentioned above.

The following is the highest scoring generated compound mentioned above.

## 7. Discussion and Conclusion

A predictive deep learning model was trained on a self-generated set of protease inhibitors, and PubChem was searched for compounds similar to the N3 ligand, yielding 183 compounds. Predictions were made on these compounds, and the 10 with the best predicted scores were docked into the protease. The highest-scoring compound is shown above and has a score of -8.7 kcal/mol.

A generative deep learning model (CGVAE) was trained on the ZINC dataset, and 50 new compounds were sampled from the latent space of the model. The 10 most promising compounds were docked into the protease. The highest-scoring compound is shown above and has a score of -9.8 kcal/mol.

The best compounds from each method show clear gains over the baseline score of -7.9 kcal/mol for the N3 ligand.

I would suggest investigating the high-scoring compound from the predictive model first. Because it was drawn from a test set of compounds found in PubChem, it is a known, chemically feasible compound, which is very important: it could be obtained or made quickly and used right away. The generated compound did have a better binding score, but as a newly generated structure it might be difficult to make, even for an experienced chemist; it's hard to say.