In [4]:
!pip install tensorflow
!pip install spektral

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple/
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple/
Collecting spektral
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/26/63/d756a1490d360d1cb398bb606eb282054e702e213ea983d3f76e99b83334/spektral-1.0.8-py3-none-any.whl (123 kB)
     |████████████████████████████████| 123 kB 1.6 MB/s            
Collecting numpy<1.20
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/66/d7/3b133b17e185f14137bc8afe7a41daf1f31556900f10238312a5ae9c7345/numpy-1.19.5-cp38-cp38-manylinux2010_x86_64.whl (14.9 MB)
     |████████████████████████████████| 14.9 MB 1.9 MB/s            
Collecting lxml
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ae/64/f002c327e99a3f3e8ec5c84590f665c071a39f85547987bdbab585fda1df/lxml-4.7.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (6.9 MB)
     |████████████████████████████████| 6.9 MB 2.9 MB/s             
Collecting tensorflow>=2.1.0
  Downloading

In [9]:
import os
import numpy as np

from joblib import Parallel, delayed
from tensorflow.keras.utils import get_file
from tqdm import tqdm

from spektral.data import Dataset, Graph
from spektral.utils import label_to_one_hot, sparse
from spektral.utils.io import load_csv, load_sdf

# Atomic numbers that get a one-hot slot in the node features
# (see atom_to_feature).
ATOM_TYPES = [
    1,  # H
    6,  # C
    7,  # N
    8,  # O
    9,  # F
]
# Bond-type codes that get a one-hot slot in the edge features
# (see mol_to_adj). Presumably the SDF/CTfile bond-type field
# (1 single, 2 double, 3 triple, 4 aromatic) -- confirm against load_sdf.
BOND_TYPES = [
    1,
    2,
    3,
    4,
]

class QM9(Dataset):
    """QM9 molecular graphs read from a local SDF file.

    Modified from Spektral's ``QM9`` dataset so that the molecules are read
    from a local SDF file rather than the downloaded ``gdb9.sdf``. Labels are
    still taken from ``gdb9.sdf.csv`` inside ``self.path``.

    Args:
        amount: if not None, only the first ``amount`` molecules of the SDF
            file are loaded.
        n_jobs: number of joblib workers used to convert molecules to graphs.
        sdf_file: path of the SDF file to read. Defaults to
            ``DEFAULT_SDF_FILE`` (the original hard-coded location).
        **kwargs: forwarded to ``spektral.data.Dataset``.
    """

    url = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/gdb9.tar.gz"

    # Original hard-coded location, kept as the default for backward
    # compatibility; pass ``sdf_file=...`` to read from somewhere else.
    DEFAULT_SDF_FILE = os.path.join(
        "/home/liyang/Documents/csprogram/graph-mol-infer-main/sdf_files",
        "gdb9_15794_eli.sdf",
    )

    def __init__(self, amount=None, n_jobs=1, sdf_file=None, **kwargs):
        self.amount = amount
        self.n_jobs = n_jobs
        self.sdf_file = self.DEFAULT_SDF_FILE if sdf_file is None else sdf_file
        # Set the attributes above first: constructing the dataset ends up
        # calling ``read()``, which uses them.
        super().__init__(**kwargs)

    def download(self):
        """Fetch and extract the archive at ``url`` into ``self.path``.

        Presumably this is where ``gdb9.sdf.csv`` (read by ``_load_labels``)
        comes from; the archive itself is deleted after extraction.
        """
        get_file(
            "qm9.tar.gz",
            self.url,
            extract=True,
            cache_dir=self.path,
            cache_subdir=self.path,
        )
        os.remove(os.path.join(self.path, "qm9.tar.gz"))

    def read(self):
        """Convert every SDF molecule to a ``Graph`` and attach its labels.

        Returns:
            list of ``Graph(x, a, e, y)``. ``x`` comes from ``atom_to_feature``,
            ``a``/``e`` come from ``mol_to_adj`` and ``y`` is the molecule's
            row of ``gdb9.sdf.csv`` (without the ``mol_id`` column). Empty if
            the SDF file contains no molecules.
        """
        print("loading QM9 dataset.")
        mols = load_sdf(self.sdf_file, amount=self.amount)
        # Capture the molecule names before the dicts are turned into arrays;
        # they are needed to look up each molecule's label row.
        mol_names = [(mol.get("name") or "").strip() for mol in mols]

        def read_mol(mol):
            x = np.array([atom_to_feature(atom) for atom in mol["atoms"]])
            a, e = mol_to_adj(mol)
            return x, a, e

        converted = Parallel(n_jobs=self.n_jobs)(
            delayed(read_mol)(mol) for mol in tqdm(mols, ncols=80)
        )
        if not converted:
            # zip(*[]) cannot be unpacked into three lists.
            return []
        x_list, a_list, e_list = zip(*converted)

        labels = self._load_labels(mol_names)

        return [
            Graph(x=x, a=a, e=e, y=y)
            for x, a, e, y in zip(x_list, a_list, e_list, labels)
        ]

    def _load_labels(self, mol_names):
        """Return label rows for ``mol_names``, in the same order.

        The SDF file may be a filtered subset of QM9, so its i-th molecule is
        not necessarily the i-th row of ``gdb9.sdf.csv``. When every molecule
        name is a ``mol_id`` in the CSV, rows are looked up by name. Otherwise
        this falls back to the original positional pairing (the first
        ``amount`` rows) and emits a warning.
        """
        import warnings

        labels_file = os.path.join(self.path, "gdb9.sdf.csv")
        labels_df = load_csv(labels_file).set_index("mol_id")

        if all(name in labels_df.index for name in mol_names):
            return labels_df.loc[mol_names].values

        warnings.warn(
            "Some SDF molecule names are not mol_ids in gdb9.sdf.csv; pairing "
            "molecules with labels by position, which is only correct if the "
            "SDF file follows the CSV order.",
            stacklevel=2,
        )
        labels = labels_df.values
        if self.amount is not None:
            labels = labels[: self.amount]
        return labels


def atom_to_feature(atom):
    """Build the node-feature vector for one parsed SDF atom.

    The vector is the concatenation of:
      * a one-hot encoding of ``atom["atomic_num"]`` over ``ATOM_TYPES``,
      * ``atom["coords"]``,
      * the scalar ``atom["charge"]`` and ``atom["iso"]`` fields.
    """
    one_hot_element = label_to_one_hot(atom["atomic_num"], ATOM_TYPES)
    position = atom["coords"]
    scalar_fields = [atom["charge"], atom["iso"]]
    return np.concatenate((one_hot_element, position, scalar_fields), axis=-1)

def mol_to_adj(mol):
    """Turn a parsed SDF molecule's bond list into adjacency and edge features.

    Every bond is recorded in both directions (i -> j and j -> i) with weight
    1, and its ``type`` is one-hot encoded over ``BOND_TYPES``.

    Returns:
        ``(a, e)`` as produced by ``sparse.edge_index_to_matrix``.
    """
    senders, receivers, bond_types = [], [], []
    for bond in mol["bonds"]:
        i, j = bond["start_atom"], bond["end_atom"]
        senders.extend((i, j))
        receivers.extend((j, i))
        bond_types.extend((bond["type"], bond["type"]))

    edge_index = np.array((senders, receivers)).T
    edge_weight = np.ones_like(senders)
    edge_features = label_to_one_hot(bond_types, BOND_TYPES)

    a, e = sparse.edge_index_to_matrix(
        edge_index=edge_index,
        edge_weight=edge_weight,
        edge_features=edge_features,
    )
    return a, e

In [11]:
# Build the dataset (downloads into the Spektral data folder on first use),
# then show a summary of the collection and of its first graph.
dataset = QM9()
print(dataset)
first_graph = dataset[0]
print(first_graph)

loading QM9 dataset.
Reading SDF


100%|████████████████████████████████████| 15794/15794 [00:22<00:00, 701.69it/s]


QM9(n_graphs=15794)
Graph(n_nodes=14, n_node_features=10, n_edge_features=4, n_labels=19)
