# Prepare Input Data
The first step of building a model with TensorFlow is to get the data in an easy-to-use format.
For TensorFlow, easy-to-use means that the data is stored as numeric arrays (matrices) in a format that can be read rapidly from disk.
Here, we show how to do that.

In [1]:
%matplotlib inline
from typing import List, Tuple
from matplotlib import pyplot as plt
from mpnn.data import make_tfrecord
from sklearn.model_selection import train_test_split
from rdkit import Chem
from tqdm import tqdm
import tensorflow as tf
import pandas as pd
import numpy as np
import json

## Get the Data
The data is stored on a [GitHub page](https://github.com/globus-labs/g4mp2-atomization-energy) from a previous project.

In [2]:
data = pd.read_json('../datasets/qm9.json.gz', lines=True)
print(f'Loaded {len(data)} training entries')

Loaded 25000 training entries
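With `lines=True`, `pd.read_json` expects the JSON Lines format (one JSON object per line), and its default `compression='infer'` picks up gzip from the `.gz` extension. As a rough illustration of what that means on disk, here is a standard-library sketch that parses the same kind of file. The two-record sample is hypothetical and only includes the `smiles_0` column.

```python
import gzip
import io
import json
from typing import List

def read_json_lines_gz(raw: bytes) -> List[dict]:
    """Parse gzip-compressed JSON Lines data (one JSON object per line)."""
    with gzip.open(io.BytesIO(raw), mode="rt", encoding="utf-8") as handle:
        # Skip blank lines, such as a trailing newline at the end of the file
        return [json.loads(line) for line in handle if line.strip()]

# Hypothetical sample; the real qm9.json.gz has many more columns per record
raw = gzip.compress(b'{"smiles_0": "C"}\n{"smiles_0": "O"}\n')
records = read_json_lines_gz(raw)
print(len(records), records[0]["smiles_0"])  # 2 C
```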


Convert the SMILES strings to RDKit molecules, making sure to add the hydrogen atoms explicitly.

In [3]:
%%time
data['mol'] = data['smiles_0'].apply(Chem.MolFromSmiles).apply(Chem.AddHs)

CPU times: user 1.14 s, sys: 38.8 ms, total: 1.18 s
Wall time: 1.18 s


## Convert the Molecule Records to Dictionaries of Arrays
While RDKit molecules are convenient, TensorFlow works with numeric _tensors_. The next few cells show how to convert an RDKit molecule to a numeric format that TensorFlow can use.

Our first step is to prepare to convert the types of atoms and bonds to numeric values. We do that by finding all the types of atoms and bonds present in our dataset. Take a look at [./mpnn/data.py](./mpnn/data.py) to get a better idea of what this function does.

In [4]:
def make_type_lookup_tables(mols: List[Chem.Mol]) -> Tuple[List[int], List[Chem.BondType]]:
    """Create lists of observed atom and bond types

    Args:
        mols: List of molecules used for our training set
    Returns:
        - List of atom types (atomic numbers)
        - List of bond types (RDKit BondType values)
    """

    # Initialize the sets of observed types
    atom_types = set()
    bond_types = set()

    # Get all types observed in these graphs
    for mol in mols:
        atom_types.update([x.GetAtomicNum() for x in mol.GetAtoms()])
        bond_types.update([x.GetBondType() for x in mol.GetBonds()])

    # Return as sorted lists
    return sorted(atom_types), sorted(bond_types)

In [5]:
atom_types, bond_types = make_type_lookup_tables(data['mol'])
print(f'Found {len(atom_types)} types of atoms: {atom_types}')
print(f'Found {len(bond_types)} types of bonds: {bond_types}')

Found 5 types of atoms: [1, 6, 7, 8, 9]
Found 4 types of bonds: [rdkit.Chem.rdchem.BondType.SINGLE, rdkit.Chem.rdchem.BondType.DOUBLE, rdkit.Chem.rdchem.BondType.TRIPLE, rdkit.Chem.rdchem.BondType.AROMATIC]
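`convert_mol_to_dict` below turns each type into an integer id with `atom_types.index(...)`, which scans the list on every lookup. With five atom types and four bond types that is cheap, but a dictionary built once from the sorted list yields the same ids in constant time. Here is a minimal sketch on plain atomic numbers (no RDKit needed); `make_index` is a hypothetical helper, not part of `mpnn`.

```python
from typing import Dict, Iterable, List

def make_index(types: Iterable[int]) -> Dict[int, int]:
    """Map each observed type to its position in the sorted list of unique types."""
    return {t: i for i, t in enumerate(sorted(set(types)))}

# Hypothetical atomic numbers gathered from a few molecules (H, C, N, O, F)
observed = [6, 1, 1, 8, 7, 1, 9, 6]
atom_index = make_index(observed)
print(atom_index)  # {1: 0, 6: 1, 7: 2, 8: 3, 9: 4}

# Methane with explicit hydrogens: one carbon followed by four hydrogens
methane_ids: List[int] = [atom_index[z] for z in [6, 1, 1, 1, 1]]
print(methane_ids)  # [1, 0, 0, 0, 0]
```

The ids agree with what `atom_types.index` returns, because both come from the same sorted order.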


The next step is to convert the molecules into dictionaries. We need to store the type of each atom in a molecule, the types of bonds, and which bonds connect which other atoms.

In [6]:
def convert_mol_to_dict(mol: Chem.Mol, atom_types: List[int], bond_types: List[Chem.BondType]) -> dict:
    """Convert RDKit representation of a molecule to an MPNN-ready dict
    
    Args:
        mol: Molecule to be converted
        atom_types: Lookup table of observed atom types
        bond_types: Lookup table of observed bond types
    Returns:
        (dict) Molecule as a dict
    """

    # Get the atom types, look them up in the atom_type list
    atom_type = [a.GetAtomicNum() for a in mol.GetAtoms()]
    atom_type_id = list(map(atom_types.index, atom_type))

    # Get the bond types and which atoms these connect
    connectivity = []
    edge_type = []
    for bond in mol.GetBonds():
        # Get information about the bond
        a = bond.GetBeginAtomIdx()
        b = bond.GetEndAtomIdx()
        b_type = bond.GetBondType()
        
        # Store how they are connected
        connectivity.append([a, b])
        connectivity.append([b, a])
        edge_type.append(b_type)
        edge_type.append(b_type)
    edge_type_id = list(map(bond_types.index, edge_type))

    # Sort connectivity array by the first column
    #  This is needed for the MPNN code to efficiently group messages for
    #  each atom when performing the message passing step
    connectivity = np.array(connectivity)
    if connectivity.size > 0:
        # Skip a special case of a molecule w/o bonds
        inds = np.lexsort((connectivity[:, 1], connectivity[:, 0]))
        connectivity = connectivity[inds, :]

        # TensorFlow's "segment_sum" will cause problems if the last atom
        #  is not bonded, because it returns an array with only
        #  max(segment_ids) + 1 rows, fewer than the number of atoms
        if connectivity.max() != len(atom_type) - 1:
            smiles = Chem.MolToSmiles(mol)
            raise ValueError(f"Problem with unconnected atoms for {smiles}")
    else:
        connectivity = np.zeros((0, 2), dtype=int)

    return {
        'n_atom': len(atom_type),
        'n_bond': len(edge_type),
        'atom': atom_type_id,
        'bond': edge_type_id,
        'connectivity': connectivity
    }
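Two details in `convert_mol_to_dict` are easier to see on a small array. `np.lexsort` uses its *last* key as the primary sort key, so passing `(connectivity[:, 1], connectivity[:, 0])` orders rows by the first column and breaks ties with the second. The unbonded-atom check exists because a segment sum returns rows only up to `max(segment_ids)`. The sketch below uses NumPy only: `segment_sum` is a hypothetical stand-in built on `np.add.at` (not TensorFlow's implementation), and grouping messages by the first column is an assumption based on the comment in the code.

```python
import numpy as np

# Methane's directed edges in the order the bond loop appends them:
# bond (0, k) is stored as [0, k] and then [k, 0]
connectivity = np.array([[0, 1], [1, 0], [0, 2], [2, 0],
                         [0, 3], [3, 0], [0, 4], [4, 0]])

# The last key is the primary sort key: sort by column 0, then column 1
inds = np.lexsort((connectivity[:, 1], connectivity[:, 0]))
connectivity = connectivity[inds, :]
print(connectivity[:, 0])  # [0 0 0 0 1 2 3 4]

def segment_sum(data: np.ndarray, segment_ids: np.ndarray) -> np.ndarray:
    """Sum entries of `data` that share a segment id; output has max(id) + 1 rows."""
    out = np.zeros((segment_ids.max() + 1,) + data.shape[1:], dtype=data.dtype)
    np.add.at(out, segment_ids, data)  # unbuffered add handles repeated ids
    return out

# One unit "message" per directed edge, summed onto the edge's first atom
messages = np.ones(len(connectivity), dtype=int)
per_atom = segment_sum(messages, connectivity[:, 0])
print(per_atom)  # [4 1 1 1 1]

# Hypothetical 3-atom graph where atom 2 has no bonds: ids only reach 1
short = segment_sum(np.ones(2, dtype=int), np.array([0, 1]))
print(short.shape)  # (2,) rather than (3,)
```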

Let's use methane as an example.

In [7]:
convert_mol_to_dict(Chem.AddHs(Chem.MolFromSmiles('C')), atom_types, bond_types)

{'n_atom': 5,
 'n_bond': 8,
 'atom': [1, 0, 0, 0, 0],
 'bond': [0, 0, 0, 0, 0, 0, 0, 0],
 'connectivity': array([[0, 1],
        [0, 2],
        [0, 3],
        [0, 4],
        [1, 0],
        [2, 0],
        [3, 0],
        [4, 0]])}

Checking this off:
- Methane has 5 atoms
- There are 4 bonds (8 when you count both forward and backward)
- The first atom is carbon (type 1 in our lookup table)
- The remaining atoms are hydrogen (type 0 in our lookup table)
- All bonds are single bonds (type 0 in our lookup table)
- All bonds either start or end at the carbon atom (atom number 0)
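The same checks can be written as assertions on the methane dictionary printed above, which is handy for spot-checking other molecules too. `check_graph_dict` is a hypothetical helper (not part of `mpnn`) covering the structural checks; the last three assertions are specific to methane.

```python
import numpy as np

# Dictionary copied from the methane output above
methane = {
    'n_atom': 5,
    'n_bond': 8,
    'atom': [1, 0, 0, 0, 0],
    'bond': [0, 0, 0, 0, 0, 0, 0, 0],
    'connectivity': np.array([[0, 1], [0, 2], [0, 3], [0, 4],
                              [1, 0], [2, 0], [3, 0], [4, 0]]),
}

def check_graph_dict(record: dict) -> None:
    """Hypothetical consistency checks for one converted molecule."""
    conn = record['connectivity']
    assert len(record['atom']) == record['n_atom']
    assert len(record['bond']) == record['n_bond'] == len(conn)
    # Every bond is stored in both directions
    edges = {tuple(e) for e in conn.tolist()}
    assert all((b, a) in edges for a, b in edges)
    # Rows are sorted by the first column, and the last atom appears in it
    assert np.all(np.diff(conn[:, 0]) >= 0)
    assert conn[:, 0].max() == record['n_atom'] - 1

check_graph_dict(methane)

# Methane-specific checks from the list above
assert methane['atom'][0] == 1 and set(methane['atom'][1:]) == {0}  # one C, four H
assert set(methane['bond']) == {0}  # all single bonds
assert all(0 in pair for pair in methane['connectivity'].tolist())  # all touch atom 0
```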

It looks like it is working correctly, so let's run it on the whole dataset.

In [8]:
data['dict'] = data['mol'].apply(lambda x: convert_mol_to_dict(x, atom_types, bond_types))

## Save the Data as TFRecords
TensorFlow has a preferred data format, [`TFRecord`](https://www.tensorflow.org/tutorials/load_data/tfrecord), which stores data in a binary format that is fast to read from disk. The details are a little advanced for this tutorial, but the short version is that we must convert each entry to this binary format and then save the entries into a special archive format.

The `make_tfrecord` function takes one of these dictionaries and converts it to this binary format. You may notice some familiar words in the output, such as `n_bond`, but most of it is not meant to be read by humans.

In [9]:
make_tfrecord(data['dict'].iloc[0])

b'\n\xf4\x01\n\x0f\n\x06n_bond\x12\x05\x1a\x03\n\x01.\n\x0f\n\x06n_atom\x12\x05\x1a\x03\n\x01\x16\np\n\x0cconnectivity\x12`\x1a^\n\\\x00\x01\x00\t\x00\n\x00\x0b\x01\x00\x01\x02\x01\x05\x02\x01\x02\x03\x02\x06\x03\x02\x03\x04\x03\x0c\x03\r\x04\x03\x04\x05\x04\x0e\x04\x0f\x05\x01\x05\x04\x05\x10\x05\x11\x06\x02\x06\x07\x06\x08\x06\x12\x07\x06\x07\x08\x07\x13\x07\x14\x08\x06\x08\x07\x08\x15\t\x00\n\x00\x0b\x00\x0c\x03\r\x03\x0e\x04\x0f\x04\x10\x05\x11\x05\x12\x06\x13\x07\x14\x07\x15\x08\n"\n\x04atom\x12\x1a\x1a\x18\n\x16\x01\x01\x01\x01\x01\x01\x01\x01\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\n:\n\x04bond\x122\x1a0\n.\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
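Assuming `make_tfrecord` serializes a `tf.train.Example` with int64 features (the byte layout above is consistent with that), the output can be decoded by hand using the protobuf wire format: each field starts with a varint key holding `field_number << 3 | wire_type`, and nested messages, strings, and packed lists are length-prefixed. The leading `\n\xf4\x01` is field 1 of `Example` (`features`, 244 bytes long), and the sketch below decodes the first two entries that follow it. It is a teaching aid that only handles non-negative int64 features; in practice you would parse records with `tf.train.Example.FromString` or `tf.io.parse_single_example`.

```python
from typing import Dict, Iterator, List, Tuple

def read_varint(buf: bytes, pos: int) -> Tuple[int, int]:
    """Decode a base-128 varint starting at pos; return (value, new position)."""
    result, shift = 0, 0
    while True:
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:  # high bit clear marks the last byte
            return result, pos
        shift += 7

def iter_fields(buf: bytes) -> Iterator[Tuple[int, int, object]]:
    """Yield (field_number, wire_type, value) for each field in a message."""
    pos = 0
    while pos < len(buf):
        key, pos = read_varint(buf, pos)
        field_number, wire_type = key >> 3, key & 0x07
        if wire_type == 0:  # varint
            value, pos = read_varint(buf, pos)
        elif wire_type == 2:  # length-delimited: message, string, or packed list
            length, pos = read_varint(buf, pos)
            value = buf[pos:pos + length]
            pos += length
        else:
            raise ValueError(f"wire type {wire_type} is not handled in this sketch")
        yield field_number, wire_type, value

def decode_int64_features(features: bytes) -> Dict[str, List[int]]:
    """Decode int64 entries of a Features message (non-negative values only)."""
    out = {}
    for _, _, entry in iter_fields(features):  # map entries: key=1, value=2
        fields = {num: val for num, _, val in iter_fields(entry)}
        key, feature = fields[1].decode(), fields.get(2, b"")
        values: List[int] = []
        for num, _, int64_list in iter_fields(feature):
            if num != 3:  # 3 = int64_list (1 = bytes_list, 2 = float_list)
                continue
            for _, wire_type, payload in iter_fields(int64_list):
                if wire_type == 0:  # a single unpacked value
                    values.append(payload)
                else:  # a packed run of varints
                    pos = 0
                    while pos < len(payload):
                        value, pos = read_varint(payload, pos)
                        values.append(value)
        out[key] = values
    return out

# First two entries of the Features message from the output above
features = b'\n\x0f\n\x06n_bond\x12\x05\x1a\x03\n\x01.\n\x0f\n\x06n_atom\x12\x05\x1a\x03\n\x01\x16'
print(decode_int64_features(features))  # {'n_bond': [46], 'n_atom': [22]}
```

The decoded values fit the rest of the record: 22 atoms and 46 directed bonds (23 bonds stored in both directions).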

Now, let's split our data into separate training (used to learn the parameters of our neural network), validation (used to assess when our model is done training), and test sets (used to assess the model's performance after training).

In [10]:
train_data, test_data = train_test_split(data, shuffle=True, train_size=0.9)

In [11]:
train_data, valid_data = train_test_split(train_data, train_size=0.9)
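The two calls above first hold out 10% of the data for testing, then hold out 10% of what remains for validation. Here is a standard-library sketch of the same two-stage split; it is not scikit-learn's implementation, and rounding the training share down is an assumption of this sketch. For 25000 entries it gives the same sizes as the progress bars in the next cell.

```python
import math
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")

def two_stage_split(items: Sequence[T], train_frac: float = 0.9,
                    seed: int = 0) -> Tuple[List[T], List[T], List[T]]:
    """Shuffle once, hold out a test set, then carve a validation set from the rest."""
    rng = random.Random(seed)
    shuffled = list(items)
    rng.shuffle(shuffled)
    # First split: train+valid vs. test
    n_train_valid = math.floor(train_frac * len(shuffled))
    train_valid, test = shuffled[:n_train_valid], shuffled[n_train_valid:]
    # Second split: train vs. valid (already shuffled, so slicing is enough)
    n_train = math.floor(train_frac * len(train_valid))
    return train_valid[:n_train], train_valid[n_train:], test

train, valid, test = two_stage_split(range(25000))
print(len(train), len(valid), len(test))  # 20250 2250 2500
```

Unlike this sketch, the notebook's `train_test_split` calls do not set `random_state`, so the split changes every time the notebook runs; passing a fixed `random_state` would make it reproducible.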

Save each split as a TFRecord file of serialized "protobuf" records

In [12]:
for name, dataset in zip(['train', 'valid', 'test'], [train_data, valid_data, test_data]):
    # Open the file in which to store the data
    with tf.io.TFRecordWriter(f'datasets/{name}_data.proto') as writer:
        # Loop over each entry in the dataset
        for _, entry in tqdm(dataset.iterrows(), desc=name):
            # Copy the molecule dict so the DataFrame is not modified,
            #  then store some output values in it as well
            record = dict(entry['dict'])
            for o in ['u0_atom', 'bandgap']:
                record[o] = entry[o]
            writer.write(make_tfrecord(record))

train: 20250it [00:02, 8253.40it/s]
valid: 2250it [00:00, 7943.84it/s]
test: 2500it [00:00, 8139.57it/s]
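`tf.io.TFRecordWriter` stores each serialized record inside a simple framing that TensorFlow documents: an 8-byte little-endian length, a masked CRC-32C of that length, the record bytes, and a masked CRC-32C of the record. The pure-Python sketch below writes and reads that framing for uncompressed files to show what ends up in the `datasets/*_data.proto` files. It is an illustration only, not a replacement for the TensorFlow writer or `tf.data.TFRecordDataset`.

```python
import struct
from typing import Iterator, List

def _make_crc32c_table() -> List[int]:
    """Lookup table for CRC-32C (Castagnoli, reflected polynomial 0x82F63B78)."""
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table

_CRC32C_TABLE = _make_crc32c_table()

def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF

def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF

def frame_record(data: bytes) -> bytes:
    """Wrap one serialized record the way an uncompressed TFRecord file stores it."""
    length = struct.pack("<Q", len(data))
    return (length + struct.pack("<I", masked_crc(length))
            + data + struct.pack("<I", masked_crc(data)))

def iter_records(blob: bytes) -> Iterator[bytes]:
    """Read framed records back, verifying both checksums."""
    pos = 0
    while pos < len(blob):
        length_bytes = blob[pos:pos + 8]
        (length,) = struct.unpack("<Q", length_bytes)
        (length_crc,) = struct.unpack("<I", blob[pos + 8:pos + 12])
        if length_crc != masked_crc(length_bytes):
            raise ValueError("corrupt length header")
        data = blob[pos + 12:pos + 12 + length]
        (data_crc,) = struct.unpack("<I", blob[pos + 12 + length:pos + 16 + length])
        if data_crc != masked_crc(data):
            raise ValueError("corrupt record")
        yield data
        pos += 16 + length

blob = frame_record(b"record one") + frame_record(b"record two")
print(list(iter_records(blob)))  # [b'record one', b'record two']
```

Each record costs 16 bytes of framing, and the checksums let a reader detect a truncated or corrupted file instead of silently loading bad data.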


Great! We are now ready to train an MPNN. Note that the test set is 10% of the full dataset (2500 entries), the validation set is 10% of the remaining data (2250 entries), and the training set is the other 90% of the remainder, or 81% of the full data (20250 entries).