In [1]:
import numpy as np
import pickle
import torch

In [2]:
# Open the QM9 training archive and keep the handle in `train`; later cells
# index it by key (train[train.files[0]]) to obtain the number of training rows.
# NOTE(review): the archive is never closed here, unlike the `with np.load(...)`
# block that builds `dataset` further down.
train = np.load('qm9/train.npz')

In [3]:
# Load the pickled `smiles` object (second half of the training set, judging
# by the file name).  Later cells unpack each entry as an (index, SMILES) pair.
# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only open pickle files that you produced yourself or otherwise trust.
with open('qm9_second_half_smiles.pickle', mode='rb') as smiles_file:
    smiles = pickle.load(smiles_file)

In [4]:
# Number of entries loaded from the pickle.
len(smiles)

48841

In [5]:
# Read every array stored in the training archive and wrap it as a torch
# tensor, keyed by its name inside the .npz file.  The context manager closes
# the archive as soon as all arrays have been converted.
with np.load('qm9/train.npz') as archive:
    dataset = {name: torch.from_numpy(array) for name, array in archive.items()}

In [6]:
# Names of the arrays that were read from the training archive.
dataset.keys()

dict_keys(['num_atoms', 'charges', 'positions', 'index', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 'omega1', 'zpve_thermo', 'U0_thermo', 'U_thermo', 'H_thermo', 'G_thermo', 'Cv_thermo'])

In [7]:
# Seed NumPy's global RNG so the shuffle is reproducible, then draw a
# permutation over all training rows.  The row count is taken from the first
# array listed in the archive.
np.random.seed(0)
n_train_rows = len(train[train.files[0]])
fixed_perm = np.random.permutation(n_train_rows)

In [8]:
# Keep the second half of the shuffled row indices.  The midpoint is derived
# from the permutation itself: np.random.permutation(n) has exactly n entries,
# so this equals len(train[train.files[0]]) // 2 without reading an array out
# of the archive a second time.
sliced_perm = fixed_perm[len(fixed_perm) // 2:]

In [9]:
# Restrict every tensor in `dataset` to the rows selected by `sliced_perm`.
# The dict is updated in place, so running this cell a second time would index
# the already-halved tensors with row numbers that refer to the full training
# set.  Tensors that already have the sliced length are therefore left
# untouched, which makes the cell safe to re-run.
n_full, n_sliced = len(fixed_perm), len(sliced_perm)
for key, tensor in dataset.items():
    # Only treat a tensor as "already sliced" when the sliced length is
    # actually distinguishable from the full length.
    already_sliced = n_sliced != n_full and len(tensor) == n_sliced
    if not already_sliced:
        # Re-assigning an existing key does not resize the dict, so this is
        # safe while iterating over items().
        dataset[key] = tensor[sliced_perm]

In [10]:
# Split the (index, SMILES) pairs into two parallel sequences: the SMILES
# strings as a plain list, and the indices as a long tensor that the next
# cell uses to select rows of `dataset`.
valid_smiles = [smi for _, smi in smiles]
valid_idx = torch.tensor([idx for idx, _ in smiles], dtype=torch.long)

In [11]:
# Build a new dict holding, for every property, only the rows listed in
# `valid_idx`; `dataset` itself is left unchanged.
valid = {name: values[valid_idx] for name, values in dataset.items()}

In [12]:
# Attach the SMILES strings (a Python list, not a tensor) under their own key.
valid['smiles'] = valid_smiles

In [13]:
# Persist the filtered dict (tensors plus the SMILES list) to disk.
torch.save(valid, 'qm9_second_half_smiles_dataset.pt')

In [14]:
# Reload the file that was just written.
# NOTE(review): torch.load unpickles the file; the default of its
# `weights_only` argument differs between PyTorch releases — confirm the
# behaviour on the version in use and only load files you trust.
rag_dataset = torch.load('qm9_second_half_smiles_dataset.pt')

In [15]:
# Display the reloaded dict.
# NOTE(review): this prints every tensor in the dict; a per-key shape summary
# would be easier to read and keep the saved notebook smaller.
rag_dataset

{'num_atoms': tensor([19, 17, 13,  ..., 23, 25, 19]),
 'charges': tensor([[6, 6, 6,  ..., 0, 0, 0],
         [8, 6, 6,  ..., 0, 0, 0],
         [8, 6, 6,  ..., 0, 0, 0],
         ...,
         [6, 6, 6,  ..., 0, 0, 0],
         [6, 6, 6,  ..., 0, 0, 0],
         [8, 6, 6,  ..., 0, 0, 0]]),
 'positions': tensor([[[ 0.1442,  1.4871,  0.1681],
          [ 0.1413, -0.0395,  0.0671],
          [ 0.5729, -0.5419, -1.3249],
          ...,
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],
 
         [[ 0.4061,  1.1421, -0.4790],
          [ 0.0523, -0.1354, -0.0311],
          [ 0.1556, -0.4097,  1.5266],
          ...,
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]],
 
         [[ 0.0377, -0.0351, -0.2158],
          [-0.0515,  1.1177,  0.0610],
          [-1.2947,  1.9717,  0.0678],
          ...,
          [ 0.0000,  0.0000,  0.0000],
          [ 0.0000,  

In [16]:
# Keys of the reloaded dict; 'smiles' should appear alongside the array names.
rag_dataset.keys()

dict_keys(['num_atoms', 'charges', 'positions', 'index', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 'omega1', 'zpve_thermo', 'U0_thermo', 'U_thermo', 'H_thermo', 'G_thermo', 'Cv_thermo', 'smiles'])

In [17]:
# Display the SMILES list stored in the reloaded dict.
rag_dataset['smiles']

['[H]c1c2nnn([H])c2c(N([H])[H])n1[H]',
 '[H]C(=O)C(=NC([H])([H])C#N)N([H])[H]',
 '[H]C([H])([H])C([H])([H])C([H])([H])C1(C([H])([H])[H])OC([H])([H])C1([H])C([H])([H])[H]',
 '[H]C([H])([H])OC1(C([H])([H])[H])C([H])([H])C1(C([H])([H])[H])C([H])([H])C([H])([H])[H]',
 '[H]C1([H])OC([H])([H])C([H])(C([H])([H])[H])C([H])([H])C([H])(C([H])([H])[H])O1',
 '[H]C1([H])OC(=O)N(C([H])([H])C([H])([H])[H])C1=O',
 '[H]N(C(=O)C1([H])C([H])([H])C1([H])[H])C([H])([H])[H]',
 '[H]OC1([H])C([H])([H])C12OC([H])([H])C2([H])[H]',
 '[H]c1noc2nnnn12',
 '[H]N(C(=O)C([H])(OC([H])([H])C([H])([H])[H])C([H])([H])[H])C([H])([H])[H]',
 '[H]C1([H])OC2([H])C([H])([H])C3([H])C([H])([H])C3([H])C12[H]',
 '[H]OC1([H])C2([H])OC3([H])C([H])([H])C([H])([H])C1([H])C32[H]',
 '[H]OC1(C([H])([H])[H])C2([H])C([H])(C([H])([H])[H])C3([H])C2([H])C31[H]',
 '[H]C([H])([H])C1([H])C(=O)C2([H])C3([H])OC1([H])C32[H]',
 '[H]C([H])([H])C([H])([H])c1noc(C#N)n1',
 '[H]C(=O)C1(C([H])([H])C#N)N([H])C1([H])C([H])([H])[H]',
 '[H]N=c1oc(F)c([H])c([H]

In [18]:
# Print the shapes of six property tensors on a single line, in this order:
# homo, lumo, alpha, gap, mu, Cv.
print(*(rag_dataset[name].shape for name in ('homo', 'lumo', 'alpha', 'gap', 'mu', 'Cv')))

torch.Size([48841]) torch.Size([48841]) torch.Size([48841]) torch.Size([48841]) torch.Size([48841]) torch.Size([48841])
