In [1]:
import pandas as pd
import numpy as np
import pickle
import os
from collections import OrderedDict
from glob import glob
from tqdm import tqdm
from file_handling import *

In [2]:
# Root of the molecular-properties dataset (train.csv, test.csv, structures/,
# multigraphs/). Set the MOLPROP_DATASET_DIR environment variable to point
# elsewhere instead of editing this absolute, machine-specific path.
dataset_dir = os.environ.get(
    'MOLPROP_DATASET_DIR',
    '/home/matthew/Programming/Kaggle/MolecularProperties/dataset',
)

In [3]:
# Element symbol (upper-cased, mol2 atom-type prefix) -> column in the
# 7-wide atom one-hot. The halogens CL, BR, F and I all share column 6.
atom_indicies = {'H': 0, 'C': 1, 'O': 2, 'N': 3, 'P': 4, 'S': 5}
atom_indicies.update(dict.fromkeys(('CL', 'BR', 'F', 'I'), 6))

# Tripos mol2 bond-type code -> column in the 5-wide bond one-hot.
bond_indicies = {
    bond_code: column
    for column, bond_code in enumerate(('1', '2', '3', 'am', 'ar'))
}

In [4]:
def constructMultiGraph(mol):
    """Build the atom/bond multigraph for one molecule from its .mol2 file.

    Reads ``<dataset_dir>/structures/<mol>.mol2`` with ``read_mol2`` and
    returns two OrderedDicts keyed by ``atom_id``, in atom-table order:

    * ``g`` -- for every atom that appears in at least one bond, a list of
      ``(bond_fp, neighbour_atom_id)`` pairs. ``bond_fp`` is a 5-element
      one-hot list that always marks column 0 ("covalent"). The actual mol2
      bond type is not encoded. Atoms with no bonds get no entry in ``g``.
    * ``h`` -- for every atom, a 7-element one-hot list of its element, looked
      up in ``atom_indicies`` by the upper-cased atom type before any '.'.

    Raises:
        KeyError: if an atom's element is not a key of ``atom_indicies``.
        Anything ``read_mol2`` raises for a missing or unreadable file.
    """
    ligand_path = os.path.join(dataset_dir, 'structures/%s.mol2' % mol)
    atom_df, bond_df = read_mol2(ligand_path)

    # Build neighbour lists in one pass over the bond table, instead of
    # boolean-filtering bond_df twice per atom (O(atoms + bonds) rather than
    # O(atoms * bonds)). For each atom, neighbours where it is atom_A come
    # first, then those where it is atom_B. Each group keeps bond-table order,
    # which is the same ordering the per-atom filters produced.
    as_atom_a = {}
    as_atom_b = {}
    for atom_a, atom_b in zip(bond_df.atom_A.values, bond_df.atom_B.values):
        as_atom_a.setdefault(atom_a, []).append(atom_b)
        as_atom_b.setdefault(atom_b, []).append(atom_a)

    g = OrderedDict()  # Bond information
    h = OrderedDict()  # Atom information

    for _, row in atom_df.iterrows():  # Build graph of ligand atoms
        atom_id = row.atom_id
        element = str(row.atom_type.split('.')[0]).upper()
        fp = np.zeros(7, dtype=np.int32)  # Atom type one-hot
        fp[atom_indicies[element]] = 1
        h[atom_id] = list(fp)

        connected = as_atom_a.get(atom_id, []) + as_atom_b.get(atom_id, [])
        for neighbour in connected:
            bfp = np.zeros(5, dtype=np.int32)  # Bond type one-hot
            bfp[0] = 1  # Covalent bond
            g.setdefault(atom_id, []).append((list(bfp), neighbour))

    return g, h

In [5]:
def saveMultiGraph(mol, g, h):
    """Pickle one molecule's multigraph into ``<dataset_dir>/multigraphs``.

    Writes ``<mol>_g.pkl`` (the ``g`` dict) and ``<mol>_h.pkl`` (the ``h``
    dict) with pickle's default protocol. Creates the output directory if it
    does not exist yet.
    """
    out_dir = os.path.join(dataset_dir, 'multigraphs')
    # Otherwise a missing directory makes every call fail with
    # FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)
    for suffix, obj in (('g', g), ('h', h)):
        f_path = os.path.join(out_dir, '%s_%s.pkl' % (mol, suffix))
        with open(f_path, 'wb') as f:
            pickle.dump(obj, f)

In [6]:
# Build and save a multigraph for every molecule in train.csv. A failing
# molecule is reported (with its name and error) and recorded in
# `train_failures`, and the run continues with the next molecule.
train_df = pd.read_csv(os.path.join(dataset_dir, 'train.csv'))
train_failures = []
for mol in tqdm(pd.unique(train_df.molecule_name)):
    try:
        g, h = constructMultiGraph(mol)
        saveMultiGraph(mol, g, h)
    except Exception as err:  # keep going, but don't hide which molecule failed or why
        train_failures.append((mol, err))
        # tqdm.write prints without corrupting the progress bar.
        tqdm.write('MultiGraph Error: %s (%r)' % (mol, err))

100%|██████████| 85003/85003 [48:47<00:00, 29.04it/s]


In [7]:
# Build and save a multigraph for every molecule in test.csv. A failing
# molecule is reported (with its name and error) and recorded in
# `test_failures`, and the run continues with the next molecule.
test_df = pd.read_csv(os.path.join(dataset_dir, 'test.csv'))
test_failures = []
for mol in tqdm(pd.unique(test_df.molecule_name)):
    try:
        g, h = constructMultiGraph(mol)
        saveMultiGraph(mol, g, h)
    except Exception as err:  # keep going, but don't hide which molecule failed or why
        test_failures.append((mol, err))
        # tqdm.write prints without corrupting the progress bar.
        tqdm.write('MultiGraph Error: %s (%r)' % (mol, err))

100%|██████████| 45772/45772 [26:17<00:00, 29.01it/s]
