In [1]:
import tqdm

from atomic_datasets import QM9Dataset, tmQMDataset, GEOMDrugsDataset

In [2]:
# Each dataset is bound to its own name so that all three stay available;
# assigning them all to `dataset` would keep only the last one.
qm9_dataset = QM9Dataset(
    root_dir="data/qm9",
    check_with_rdkit=False,
    remove_uncharacterized_molecules=True,
    # presumably caps how many molecules are loaded, keeping the demo small
    # -- confirm against the QM9Dataset documentation.
    max_num_molecules=100,
)

tmqm_dataset = tmQMDataset(
    root_dir="data/tmqm",
)

geom_drugs_dataset = GEOMDrugsDataset(
    root_dir="data/geom_drugs",
)

# `dataset` is the one the following cells iterate over. Point it at
# `qm9_dataset` or `tmqm_dataset` to explore a different dataset.
dataset = geom_drugs_dataset

In [None]:
# Printing every graph would flood the notebook output (and take a long time
# for a large dataset), so only the first few entries are previewed.
NUM_PREVIEW_GRAPHS = 5

for index, graph in enumerate(dataset):
    if index >= NUM_PREVIEW_GRAPHS:
        break
    print(graph)

Here, we see how to use `atomic_datasets` with PyTorch Geometric:

In [None]:
from typing import Optional, Callable

import torch.utils.data
import torch_geometric.data


class QM9(torch.utils.data.Dataset):
    """QM9 dataset in PyTorch Geometric format.

    Wraps a ``QM9Dataset`` and converts each of its samples into a
    ``torch_geometric.data.Data`` object at access time.

    Args:
        root_dir: Root directory forwarded to ``QM9Dataset``.
        transform: Optional callable applied to each ``Data`` object before
            it is returned from ``__getitem__``.
    """

    def __init__(self, root_dir, transform: Optional[Callable] = None):
        # Zero-argument super() keeps working if this cell is re-run and the
        # class object is replaced while older instances are still alive.
        super().__init__()
        self.dataset = QM9Dataset(root_dir, check_with_rdkit=False)
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a ``torch_geometric.data.Data`` object.

        The returned object carries ``pos`` (from ``nodes['positions']``) and
        ``species`` (from ``nodes['species']``) as tensors. If a transform was
        supplied, its return value is returned instead.
        """
        nodes = self.dataset[idx]['nodes']
        data = torch_geometric.data.Data(
            pos=torch.as_tensor(nodes['positions']),
            species=torch.as_tensor(nodes['species']),
        )
        # Compare against None rather than relying on truthiness: a callable
        # that happens to be falsy (e.g. it defines __len__ and is empty)
        # must still be applied.
        if self.transform is not None:
            data = self.transform(data)
        return data
