In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import py3Dmol
from rdkit.Chem.rdchem import GetPeriodicTable
from sklearn.model_selection import train_test_split

# RDKit periodic table; xyzfile() uses it to map atomic numbers to element symbols.
PTABLE = GetPeriodicTable()
# Print every numpy value (any dtype) through plain str().
np.set_printoptions(formatter={'all': str})

def xyzfile(axyz):
    """Render an atom array as the text of an XYZ file.

    Each row of ``axyz`` is read as ``(atomic_number, x, y, z)``. Column 0 is cast
    to ``int`` and looked up in ``PTABLE`` to get the element symbol. Columns 1: are
    unpacked as the three coordinates and printed with ``:f`` formatting.

    The header is the atom count followed by an empty comment line, and every
    line (including the last atom line) ends with a newline.
    """
    lines = [str(axyz.shape[0]), ""]
    for row in axyz:
        symbol = PTABLE.GetElementSymbol(int(row[0]))
        x, y, z = row[1:].tolist()
        lines.append(f"{symbol} {x:f} {y:f} {z:f}")
    return "\n".join(lines) + "\n"

def show(axyz):
    """Display an atom array inline as a 3D stick-and-sphere model.

    The array is serialised with ``xyzfile`` and loaded into an 800x400 py3Dmol
    viewer, which is zoomed to fit and shown in the notebook output.
    """
    xyz_text = xyzfile(axyz)
    viewer = py3Dmol.view(width=800, height=400)
    viewer.addModel(xyz_text, "xyz")
    viewer.setStyle({'stick': {}, 'sphere': {'scale': 0.25}})
    viewer.zoomTo()
    viewer.show()

In [2]:
# One metadata row per molecule, plus one flat array holding the atoms of all
# molecules concatenated together.
metadata = np.load("../raw/qm9/processed/metadata.npy")
coords = np.load("../raw/qm9/processed/coords.npy")

# Unbind coordinates: column 0 of `metadata` is used as each molecule's first row
# in the flat array. Cutting at every start offset except the first yields one
# array per molecule, so `coords` becomes a list of arrays.
# NOTE(review): assumes the offsets are ascending and begin at 0 -- confirm upstream.
start_indices = metadata[:, 0]
coords = np.split(coords, indices_or_sections=start_indices[1:], axis=0)

In [3]:
# Regenerate the 80/10/10 train/val/test split of molecule indices with fixed seeds.
split_ratio = [0.8, 0.1, 0.1]
seed = 100
D = np.arange(len(metadata))

train_frac, val_frac, test_frac = split_ratio
# Share of the held-out remainder assigned to validation (the rest becomes test).
val_test_ratio = val_frac / (val_frac + test_frac)

splits = dict.fromkeys(("train", "val", "test"))
# Stage 1: carve the training set out of all indices; D becomes the remainder.
splits["train"], D = train_test_split(D, train_size=train_frac, random_state=seed)
# Stage 2: divide the remainder between validation and test, using a different seed.
splits["val"], splits["test"] = train_test_split(D, train_size=val_test_ratio, random_state=seed + 1)

In [4]:
# Sanity check: each regenerated split must hold exactly the same molecule indices
# as the saved qm9_<split>.npy file. Order is ignored, hence the sorts.
for key, indices in splits.items():
    saved = np.load(f"qm9_{key}.npy")
    # np.array_equal also compares shapes. `(a == b).all()` would broadcast (a
    # length-1 array compared against every element can pass) or raise on other
    # length mismatches instead of reporting a clean failure.
    assert np.array_equal(np.sort(saved), np.sort(indices)), f"{key} split differs from qm9_{key}.npy"

In [5]:
import sys
from tqdm import tqdm

# Make the repository root importable for `src.datamodule`. The path is relative,
# so this assumes the notebook is run from its own directory. The guard stops
# re-runs of the cell from appending a duplicate entry each time.
if '../..' not in sys.path:
    sys.path.append('../..')
from src.datamodule import make_molecule

# One row per molecule: the first entry of the `moments` attribute on the object
# built by make_molecule.
# NOTE(review): presumably principal moments of inertia with a leading batch dim
# of 1 -- confirm in src.datamodule.
moments = []
# A bare zip has no length; `total` lets tqdm show a percentage and ETA.
for info, axyz in tqdm(zip(metadata, coords), total=len(coords)):
    M = make_molecule(info, axyz)
    moments.append(M.moments.numpy()[0])
moments = np.array(moments)

133471it [01:07, 1982.57it/s]


In [6]:
# Per-molecule minimum across the columns of `moments`.
min_moments = np.min(moments, axis=1)
np.shape(min_moments)

(133471,)

In [7]:
# Number of molecules whose minimum moment is below the 1e-5 cutoff.
np.sum(min_moments < 1e-5)

3109

In [8]:
# Visually inspect the molecule ranked 3200th by minimum moment. That is a little
# past the count below 1e-5 reported above, so it sits just above the cutoff.
ranked = np.argsort(min_moments)
i = ranked[:3200][-1]
show(coords[i])
min_moments[i]

1.2488858e-05

In [9]:
# Indices of molecules whose minimum moment is below 1e-5, written to qm9_small_moments.npy.
small_moments = np.flatnonzero(min_moments < 1e-5)
np.save("qm9_small_moments.npy", small_moments)

In [10]:
# Atom count per molecule; indices of molecules with fewer than 5 atoms are
# written to qm9_too_small.npy.
num_atoms = np.array([len(axyz) for axyz in coords])
small = np.flatnonzero(num_atoms < 5)
np.save("qm9_too_small.npy", small)