In [16]:
import numpy as np
import scipy.spatial
import ase.db
import torch


class AseDbDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset that reads its items from an ASE database.

    Item ``i`` (0-based) is the database row fetched with id ``i + 1``, passed
    through ``transformer``. Negative indices count from the end, as for
    Python sequences.

    NOTE(review): the index-to-id mapping assumes row ids are contiguous
    ``1..len(db)`` (no deleted rows) -- confirm for the database in use.

    Args:
        asedb_path: Path handed to ``ase.db.connect``.
        transformer: Callable applied to every fetched row. A falsy value
            (e.g. ``None``) selects the identity function.
        **kwargs: Forwarded unchanged to the parent class constructor.
    """

    def __init__(self, asedb_path, transformer=None, **kwargs):
        super().__init__(**kwargs)
        self.asedb_path = asedb_path
        self.asedb_connection = ase.db.connect(asedb_path)
        self.transformer = transformer if transformer else lambda x: x

    def __len__(self):
        """Return the number of rows reported by the database connection."""
        return len(self.asedb_connection)

    def __getitem__(self, key):
        """Fetch row ``key`` (0-based) and return it after transformation.

        Raises:
            IndexError: If ``key`` does not correspond to a row in the database.
        """
        if key < 0:
            # Sequence-style negative indexing: -1 is the last row.
            key += len(self)
            if key < 0:
                raise IndexError("index out of range")
        # Note that ASE databases are 1-indexed
        try:
            row = self.asedb_connection[key + 1]
        except KeyError as err:
            raise IndexError("index out of range") from err
        # The transformer runs outside the try block so that a KeyError raised
        # by the transformer itself is not misreported as an out-of-range index
        # (which would silently end ``for item in dataset`` loops early).
        return self.transformer(row)

            
class TransformRowToGraphXyz:
    """Callable converting an ASE DB row into a graph dictionary that retains
    the xyz coordinates of every vertex."""

    def __init__(self, cutoff=5.0, target_property="U0"):
        # Pairs of atoms closer than this distance (strict "<") become edges
        self.cutoff = cutoff
        # Name of the value to read from the row (attribute first, then row.data)
        self.target_property = target_property

    def __call__(self, row):
        """Build the graph dictionary for ``row``.

        Returns a dict with tensors under "nodes" (atomic numbers),
        "nodes_xyz" (positions in the default torch dtype), "num_nodes",
        "edges" (index pairs) and "num_edges", plus the raw value under
        "target" (``np.nan`` when the row does not carry the property).
        """
        atoms = row.toatoms()

        # Connectivity from a plain distance cutoff; periodic boundary
        # conditions are not taken into account.
        edge_index = self.get_edges_simple(atoms)
        target = self._read_target(row)

        atomic_numbers = atoms.get_atomic_numbers()
        float_dtype = torch.get_default_dtype()
        return {
            "nodes": torch.tensor(atomic_numbers),
            "nodes_xyz": torch.tensor(atoms.get_positions(), dtype=float_dtype),
            "num_nodes": torch.tensor(len(atomic_numbers)),
            "edges": torch.tensor(edge_index),
            "num_edges": torch.tensor(edge_index.shape[0]),
            "target": target,
        }

    def _read_target(self, row):
        """Return the target value stored on ``row``, or ``np.nan`` if absent."""
        name = self.target_property
        if hasattr(row, name):
            return getattr(row, name)
        if hasattr(row, "data") and name in row.data:
            return row.data[name]
        return np.nan

    def get_edges_simple(self, atoms):
        """Return a ``num_edges x 2`` array of directed index pairs ``(i, j)``,
        ``i != j``, whose atoms lie closer together than ``self.cutoff``."""
        positions = atoms.get_positions()
        pairwise = scipy.spatial.distance_matrix(positions, positions)
        # Keep pairs inside the cutoff and mask the diagonal to drop self-loops
        off_diagonal = ~np.eye(pairwise.shape[0], dtype=bool)
        within_cutoff = (pairwise < self.cutoff) & off_diagonal
        return np.argwhere(within_cutoff)


# Location of the ASE database file, relative to the working directory
dataset_path = "qm9.db"

# Wrap the database so each item is returned as a graph dictionary
dataset = AseDbDataset(
    asedb_path=dataset_path,
    transformer=TransformRowToGraphXyz(),
)

# Fetch the first item (row id 1 in the database) and show it
example = dataset[0]
print(example)

DatabaseError: ignored

In [11]:
pip install ase

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
