In [3]:
import torch, sys, json, os, pickle
from torchdrug import data, datasets, core, models, tasks, utils
import numpy as np
import rdkit.Chem.AllChem as Chem

In [31]:
# Some utilities 

# Load dataset
def load_from_pickle(pkl_file, name="QM9"):
    """Unpickle and return the object stored in ``pkl_file``.

    Args:
        pkl_file: Path of the pickle file to read.
        name: Label used only in the progress messages.

    Returns:
        The object that was pickled into ``pkl_file``.

    Note:
        ``pickle.load`` can execute arbitrary code while unpickling; only load
        files from a trusted source.
    """
    print(f"Loading {name} dataset...")
    with open(pkl_file, "rb") as handle:
        loaded = pickle.load(handle)
    print(f"Loaded {name} dataset.")
    return loaded
# Clean dataset
def clean_dataset(dataset):
    """Repair malformed ``stereo_atoms`` tensors on every molecule of ``dataset``.

    Each molecule is mutated in place:

    * a 2-D ``stereo_atoms`` with zero columns is replaced by an int64 zero
      tensor of shape ``(num_edge, 2)``;
    * otherwise, a molecule with no edges whose ``stereo_atoms`` is not None
      (and not already an empty 1-D tensor) gets it replaced by ``torch.tensor([])``.

    ``size(1)`` is only read on 2-D tensors, so molecules that already hold an
    empty 1-D tensor (e.g. from a previous call) no longer raise ``IndexError``
    and the function can be re-run safely.

    Args:
        dataset: Object with a ``data`` list of molecules exposing
            ``stereo_atoms`` (tensor or None) and ``num_edge``.

    Returns:
        The same ``dataset`` object. ``dataset.data`` is rebuilt as a new list
        when at least one molecule was repaired.
    """
    errors = 0
    good_mols = []
    for mol in dataset.data:
        stereo = mol.stereo_atoms
        if stereo is not None and stereo.dim() == 2 and stereo.size(1) == 0:
            errors += 1
            mol.stereo_atoms = torch.zeros([int(mol.num_edge), 2], dtype=torch.int64)
        elif int(mol.num_edge) == 0:
            # An empty 1-D tensor is exactly what this branch writes, so don't
            # count it as an error again on a repeat run.
            already_clean = (
                stereo is not None and stereo.dim() == 1 and stereo.numel() == 0
            )
            if stereo is not None and not already_clean:
                errors += 1
                mol.stereo_atoms = torch.tensor([])
        good_mols.append(mol)
    if errors > 0:
        print("Errors found, revising dataset.")
        dataset.data = good_mols
    else:
        print("No errors found!")
    return dataset
# Molecule Graph from SMILES
def mol_graph_from_smiles(SMILES, mol_feature=None):
    """Build a TorchDrug ``data.Molecule`` graph from a SMILES string.

    Args:
        SMILES: SMILES string describing the molecule.
        mol_feature: Forwarded unchanged to ``data.Molecule.from_smiles``.

    Returns:
        The ``data.Molecule`` produced by ``from_smiles``.
    """
    return data.Molecule.from_smiles(SMILES, mol_feature=mol_feature)
# Save dataset
def save_to_pickle(pkl_file, dset, name="QM9", size=20):
    """Pickle ``dset`` (or its first ``size`` items) into ``pkl_file``.

    Args:
        pkl_file: Destination path; an existing file is overwritten.
        dset: Sliceable object to save.
        name: Label used only in the progress messages.
        size: ``"full"`` to save ``dset`` as-is; any other value is used as
            the stop index of ``dset[:size]``.

    NOTE(review): slicing a TorchDrug dataset may yield a plain list rather
    than a dataset object; confirm before reloading it with ``load_from_pickle``.
    """
    print(f"Saving {name} dataset...")
    to_save = dset if size == "full" else dset[:size]
    with open(pkl_file, "wb") as handle:
        pickle.dump(to_save, handle)
    print(f"Saved {name} dataset to {pkl_file}.")
# Update edge features: append the inverse bond distance as an extra column
def edge_distance(molecule):
    """Append the inverse bond length as a new last column of ``edge_feature``.

    For each edge row ``(h, t, relation)`` of ``molecule.edge_list``, the
    Euclidean distance between ``node_position[h]`` and ``node_position[t]`` is
    computed, and ``1 / distance`` is written into a freshly appended column.

    The molecule is modified in place and also returned. Calling this twice on
    the same molecule appends a second column.

    Args:
        molecule: Graph exposing ``edge_list`` (rows whose first two entries
            are node indices), ``edge_feature`` (one row per edge) and
            ``node_position`` (one coordinate row per node).

    Returns:
        ``molecule`` with ``edge_feature`` widened by one column.
    """
    edge_feature = molecule.edge_feature
    # Add a dummy column (default float dtype, same device) to hold distances.
    extra = torch.zeros((edge_feature.size(0), 1), device=edge_feature.device)
    edge_feature = torch.hstack((edge_feature, extra))

    # Vectorised distance between the two endpoints of every edge.
    positions = molecule.node_position
    heads = molecule.edge_list[:, 0]
    tails = molecule.edge_list[:, 1]
    distances = torch.linalg.norm(positions[heads] - positions[tails], dim=-1)

    # Write into the appended (last) column rather than a hard-coded index,
    # so inputs with any number of existing edge features are handled.
    edge_feature[:, -1] = 1 / distances
    molecule.edge_feature = edge_feature
    return molecule

# Build a small test molecule (ethanol) with 3D coordinates and run
# edge_distance on it as a sanity check.
SMILES = "CCO"
rdkit_mol = Chem.MolFromSmiles(SMILES)
# A fixed seed makes the embedded coordinates (and hence the distance
# features) reproducible across kernel restarts.
# NOTE(review): RDKit warns that embedding without explicit hydrogens can give
# poor geometries ("Consider calling AddHs()"); consider Chem.AddHs first.
embed_status = Chem.EmbedMolecule(rdkit_mol, randomSeed=42)
if embed_status == -1:
    raise RuntimeError(f"RDKit failed to embed 3D coordinates for {SMILES!r}")
# Not used below; kept for interactive inspection of the raw coordinates.
node_position = torch.tensor(rdkit_mol.GetConformer().GetPositions())
# molecular graph
test_mol = data.Molecule.from_molecule(rdkit_mol)
# edge rep: appends 1/bond-length as an extra edge feature (in place)
out_mol = edge_distance(test_mol)

[18:22:00] Molecule does not have explicit Hs. Consider calling AddHs()


In [32]:
# Load the pickled QM9 dataset and append the inverse-bond-length edge feature
# to every molecule. `dataset` is an alias of `qm9`, so `qm9.data` is updated too.
# NOTE: pickle.load can execute arbitrary code; only load a trusted QM9.pkl.
qm9 = load_from_pickle("QM9.pkl", "QM9")
dataset = qm9
dataset.data = [edge_distance(mol) for mol in dataset.data]

Loading QM9 dataset...
Loaded QM9 dataset.


In [None]:
# Split dataset into train / validation / test = 80% / 10% / 10%.
# random_split assigns `lengths` in order, so the large 80% share must come
# first to go to the training set; the test set takes the rounding remainder.
SPLIT_SEED = 42
n_total = len(dataset)
lengths = [int(0.8 * n_total), int(0.1 * n_total)]
lengths += [n_total - sum(lengths)]
# Seeded generator so the split is reproducible across kernel restarts.
train_set, valid_set, test_set = torch.utils.data.random_split(
    dataset, lengths, generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Arguments
hidden_dim = 256
lr = 1e-3
batch_size = 128
epochs = 5
# Use GPU 0 when CUDA is available, otherwise let the engine run on CPU.
gpus = [0] if torch.cuda.is_available() else None

# Define model. Hidden widths are derived from hidden_dim
# (256 -> 128 -> 64 with the value above).
# NOTE(review): dataset.edge_feature_dim is presumed to include the extra
# distance column added by edge_distance; confirm.
model = models.GCN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[hidden_dim, hidden_dim // 2, hidden_dim // 4],
                   edge_input_dim=dataset.edge_feature_dim)

# Define task
task = tasks.PropertyPrediction(model, task=dataset.tasks)

# Optimizer
optimizer = torch.optim.Adam(task.parameters(), lr=lr)

# Solver
solver = core.Engine(task,
                     train_set,
                     valid_set,
                     test_set,
                     optimizer,
                     gpus=gpus,
                     batch_size=batch_size)

# Train model
solver.train(num_epoch=epochs)

In [None]:
# Inspect bond types and stereo atoms of the RDKit test molecule.
bonds = [rdkit_mol.GetBondWithIdx(i) for i in range(rdkit_mol.GetNumBonds())]
print(bonds)
for bond in bonds:
    # Named bond_type (not `type`) so the builtin `type` is not shadowed for
    # the rest of the notebook.
    bond_type = str(bond.GetBondType())
    stereo = bond.GetStereo()
    if stereo != Chem.rdchem.BondStereo.STEREONONE:
        _atoms = list(bond.GetStereoAtoms())
    else:
        _atoms = [0, 0]
    print(bond_type, _atoms)

In [25]:
# One-hot encode a bond type against SINGLE / DOUBLE / TRIPLE / AROMATIC.
d = {"b_type": Chem.rdchem.BondType.SINGLE}
e_t = [
    int(d["b_type"] == candidate)
    for candidate in (
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    )
]