# Making 2D and 3D training sets
The goal here is to take a copy of the QM9 dataset and save two versions: one with the DFT geometries and another with 2D geometries created with RDKit.

In [1]:
%matplotlib inline
from schnetpack.data import AtomsData
from matplotlib import pyplot as plt
from rdkit.Chem import AllChem
from rdkit import Chem
from ase.db import connect
from ase.io.xyz import read_xyz
from typing import List
from pathlib import Path
from io import StringIO
from tqdm import tqdm
import pandas as pd
import numpy as np

Configuration

In [2]:
qm9_path = '../../../JCESR/g4mp2-atomization-energy/data/output/g4mp2_data.json.gz'

## Load the QM9 dataset
I'm using a copy I prepared [in another project](https://github.com/globus-labs/g4mp2-atomization-energy).

In [3]:
qm9 = pd.read_json(qm9_path, lines=True)

## Save a 3D version
[SchNetPack](https://schnetpack.readthedocs.io/en/stable/) requires data to be saved in a [special `AtomsData` format](https://schnetpack.readthedocs.io/en/stable/tutorials/tutorial_01_preparing_data.html#Preparing-your-own-data).
That database requires the molecular geometries to be stored as ASE `Atoms` objects
and the properties to be provided as a dictionary of arrays that are at least 1D.
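
As a minimal sketch of that property format, the snippet below promotes scalar values to 1D arrays with `np.atleast_1d`. The `rows` list and its values are made up for illustration and only stand in for rows of the real dataframe.

```python
import numpy as np

# Hypothetical rows standing in for the dataframe; the values are made up
rows = [
    {"u0": -40.5, "bandgap": 0.4},
    {"u0": -76.4, "bandgap": 0.3},
]
property_cols = ["u0", "bandgap"]

# One dictionary per molecule, with each scalar promoted to a 1D array
properties = [
    {col: np.atleast_1d(row[col]) for col in property_cols}
    for row in rows
]

print(properties[0]["u0"].shape)
```

Each value ends up as an array with a single entry, which is the shape the conversion function in this notebook builds for every molecule.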

For convenience, we'll make a function that takes our Pandas dataframe and saves it to an `AtomsData` database.


In [4]:
def to_spk_db(path: str, data: pd.DataFrame, xyz_col: str = 'xyz', property_cols: List[str] = ('u0', 'bandgap')):
    """Convert a pandas dataframe to a SchNetPack database
    
    Args:
        path: Path to the ase database
        data: Dataframe containing geometries and properties
        xyz_col: Column holding the XYZ-format molecular geometry
        property_cols: Names of the columns holding the properties
    """
    
    # Make sure the database doesn't exist already
    path = Path(path)
    if path.is_file():
        path.unlink()  # Deletes the file if it exists
    
    # Convert the XYZ strings in the requested column to ASE Atoms objects
    atoms = data[xyz_col].apply(lambda x: next(read_xyz(StringIO(x), slice(None))))
    
    # Convert the properties to the format needed for SchNetPack (SPK) -
    #  single dictionary per molecule and the property as a 1D array
    properties = [{col: np.atleast_1d(row[col]) for col in property_cols} for _, row in data.iterrows()]
    
    # Add them to a database
    db = AtomsData(str(path), available_properties=property_cols)
    db.add_systems(atoms, properties)
    return db

In [5]:
%%time
to_spk_db('3d_qm9.db', qm9)

Wall time: 6min 13s


<schnetpack.data.atoms.AtomsData at 0x2380c200188>

## Save a 2D Version
RDKit can generate 2D coordinates for molecules. To use them with SchNetPack, we'll first make a function that returns a molecule's 2D geometry as an XYZ-format string, which we can then pass to the "save to SPK" function above.

In [6]:
def generate_2d_xyz(smiles: str) -> str:
    """Generate an XYZ file with 2D coordinates of a molecule
    
    Args:
        smiles: SMILES string of a molecule
    Returns:
        The geometry as an XYZ-format string
    """
    
    # Parse the molecule
    mol = Chem.MolFromSmiles(smiles)
    mol = Chem.AddHs(mol)
    
    # Generate the 2D coordinates of the molecule
    AllChem.Compute2DCoords(mol)
        
    # Write the geometry in XYZ format (z is 0 for every atom in a 2D layout)
    xyz = f"{mol.GetNumAtoms()}\n"
    xyz += smiles + "\n"
    conf = mol.GetConformer()
    for i, a in enumerate(mol.GetAtoms()):
        s = a.GetSymbol()
        c = conf.GetAtomPosition(i)
        xyz += f"{s} {c[0]} {c[1]} {c[2]}\n"
    return xyz
print(generate_2d_xyz('c1ccccc1'))

12
c1ccccc1
C 1.5000000000000004 1.4802973661668753e-16 0.0
C 0.7499999999999993 -1.2990381056766584 0.0
C -0.7500000000000006 -1.2990381056766578 0.0
C -1.5 3.317267564887905e-16 0.0
C -0.7499999999999996 1.2990381056766584 0.0
C 0.7500000000000006 1.2990381056766582 0.0
H 3.0 3.700743415417188e-16 0.0
H 1.4999999999999996 -2.598076211353318 0.0
H -1.5000000000000007 -2.5980762113533156 0.0
H -3.0 5.921189464667501e-16 0.0
H -1.4999999999999998 2.598076211353316 0.0
H 1.5000000000000007 2.598076211353316 0.0
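
A quick sanity check on output like this is to confirm that every z coordinate is zero. The sketch below uses a hypothetical helper, `parse_xyz`, which is not part of RDKit, ASE, or SchNetPack, and assumes the simple "symbol x y z" layout shown above.

```python
def parse_xyz(xyz: str):
    """Parse an XYZ-format string into (symbol, x, y, z) tuples"""
    lines = xyz.strip().split("\n")
    n_atoms = int(lines[0])
    atoms = []
    for line in lines[2:2 + n_atoms]:  # Skip the atom count and comment lines
        symbol, x, y, z = line.split()
        atoms.append((symbol, float(x), float(y), float(z)))
    return atoms


# First three atoms of the benzene output, truncated for brevity
example = """3
c1ccccc1
C 1.5000000000000004 1.4802973661668753e-16 0.0
C 0.7499999999999993 -1.2990381056766584 0.0
C -0.7500000000000006 -1.2990381056766578 0.0
"""

atoms = parse_xyz(example)
is_planar = all(z == 0.0 for _, _, _, z in atoms)
print(is_planar)
```

The same check could be applied to the full string returned by `generate_2d_xyz`.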



Generate the 2D geometries for all molecules in the dataset.

In [7]:
qm9['xyz_2d'] = [generate_2d_xyz(x) for x in tqdm(qm9['smiles_0'])]

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 130258/130258 [02:02<00:00, 1060.58it/s]


Save the 2D data to disk.

In [8]:
%%time
to_spk_db('2d_qm9.db', qm9, xyz_col='xyz_2d')

Wall time: 4min 10s


<schnetpack.data.atoms.AtomsData at 0x2380c1f73c8>