In [2]:
import ll
import rich

# Switch on `ll`'s pretty-output mode for this kernel session.
# NOTE(review): what exactly this configures is defined inside the `ll`
# package and is not visible from this notebook — confirm against its docs.
ll.pretty()

In [3]:
import datasets
import torch
from torch_geometric.data import Data

# Pull the "val" split of the `nimashoghi/mptrj` dataset.
dataset = datasets.load_dataset(path="nimashoghi/mptrj", split="val")
# Have indexing return torch tensors; `to_pyg` below calls tensor methods
# (`unsqueeze`, `torch.zeros_like`) directly on the record fields.
dataset.set_format(type="torch")


def to_pyg(data_dict):
    """Wrap a single dataset record in a ``torch_geometric.data.Data`` object.

    Args:
        data_dict: Mapping that provides the keys ``"positions"``, ``"cell"``,
            ``"num_atoms"`` and ``"numbers"``. The ``"cell"`` and ``"numbers"``
            entries must be torch tensors, since tensor operations are applied
            to them here.

    Returns:
        A ``Data`` instance with four attributes:

        * ``pos``    -- the ``"positions"`` entry, passed through unchanged.
        * ``cell``   -- the ``"cell"`` entry with a new leading axis.
        * ``natoms`` -- the ``"num_atoms"`` entry, passed through unchanged.
        * ``tags``   -- an all-zero ``torch.long`` tensor shaped like
          ``"numbers"``. The atomic numbers themselves are not stored.
    """
    positions = data_dict["positions"]
    # Add a leading axis to the cell.
    # NOTE(review): presumably (3, 3) -> (1, 3, 3) so that batching yields one
    # cell per graph — confirm against the dataset schema.
    cell = data_dict["cell"].unsqueeze(0)
    natoms = data_dict["num_atoms"]
    # Only the shape of "numbers" is used; every tag is zero.
    tags = torch.zeros_like(data_dict["numbers"], dtype=torch.long)
    return Data(pos=positions, cell=cell, natoms=natoms, tags=tags)


# Sanity check: convert the first record of the loaded split and display the
# resulting Data object.
rich.print(to_pyg(dataset[0]))

In [4]:
import torch

from jmppeft.utils.radius_graph import radius_graph_pbc

# Build the periodic radius graph for one structure (the first record) and
# display the three values returned by `radius_graph_pbc`.
data = to_pyg(dataset[0])
edge_index, cell_offsets, num_neighbors = radius_graph_pbc(
    pos=data.pos,
    cell=data.cell,
    n_atoms=data.natoms,
    radius=6.0,
    max_num_neighbors_threshold=30,
)

rich.print(
    dict(
        edge_index=edge_index,
        cell_offsets=cell_offsets,
        num_neighbors=num_neighbors,
    )
)

In [19]:
import torch

from jmppeft.utils.radius_graph import radius_graph_pbc
from torch_geometric.data import Batch

# Batch two copies of the first record, then build the periodic radius graph
# for the whole batch in a single call, using the same cutoff settings as the
# single-structure cell above.
data = Batch.from_data_list([to_pyg(dataset[0]) for _ in range(2)])
# The three returned values are stored directly on the batch object.
data.edge_index, data.edge_cell_offsets, data.num_neighbors = radius_graph_pbc(
    pos=data.pos,
    cell=data.cell,
    n_atoms=data.natoms,
    radius=6.0,
    max_num_neighbors_threshold=30,
)

rich.print(
    data,
    dict(
        edge_index=data.edge_index,
        cell_offsets=data.edge_cell_offsets,
        num_neighbors=data.num_neighbors,
    ),
)

# Walk through the batch one graph at a time, pick out that graph's edges, and
# shift their node indices back to graph-local numbering before printing.
size_so_far = 0  # number of nodes in the graphs already visited
for bsz_i in range(int(data.batch.max()) + 1):
    node_mask = data.batch.eq(bsz_i)
    # An edge is assigned to this graph according to its first endpoint ...
    edge_mask = node_mask[data.edge_index[0]]
    # ... and the second endpoint must agree, i.e. no edge crosses graphs.
    assert torch.equal(edge_mask, node_mask[data.edge_index[1]])
    rich.print(
        dict(
            edge_index=data.edge_index[:, edge_mask] - size_so_far,
            cell_offsets=data.edge_cell_offsets[edge_mask],
            # NOTE(review): this is the batch-wide value, printed unchanged on
            # every iteration, whereas the two entries above are sliced per
            # graph — confirm whether a per-graph slice was intended.
            num_neighbors=data.num_neighbors,
        ),
    )

    size_so_far += int(node_mask.sum())