Rename src/dst Nodes For Point Cloud Featurization #228

Merged 2 commits on May 29, 2024 (changes shown below are from 1 commit)
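Every hunk below applies the same two key renames in the point cloud pipeline. As a quick reference, the mapping is spelled out in the short snippet here; the dictionary itself is illustrative only and does not exist in the codebase.

# Illustrative reference only (not part of matsciml): old key -> new key
RENAMED_POINT_CLOUD_KEYS = {
    "src_nodes": "pc_src_nodes",
    "dst_nodes": "pc_dst_nodes",
}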
7 changes: 5 additions & 2 deletions matsciml/datasets/alexandria/dataset.py
@@ -100,7 +100,10 @@ def _parse_structure(
         cell = torch.from_numpy(structure.lattice.matrix.copy()).float()
         return_dict["cell"] = cell
         chosen_nodes = self.choose_dst_nodes(system_size, self.full_pairwise)
-        src_nodes, dst_nodes = chosen_nodes["src_nodes"], chosen_nodes["dst_nodes"]
+        src_nodes, dst_nodes = (
+            chosen_nodes["pc_src_nodes"],
+            chosen_nodes["pc_dst_nodes"],
+        )
         atom_numbers = torch.LongTensor(structure.atomic_numbers)
         # uses one-hot encoding featurization
         pc_features = point_cloud_featurization(
@@ -261,5 +264,5 @@ def collate_fn(batch: list[DataDict]) -> BatchDict:
         return concatenate_keys(
             batch,
             pad_keys=["pc_features"],
-            unpacked_keys=["sizes", "src_nodes", "dst_nodes"],
+            unpacked_keys=["sizes", "pc_src_nodes", "pc_dst_nodes"],
         )
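To make the pad_keys/unpacked_keys split in the collate_fn above concrete, here is a minimal, hypothetical stand-in for the collation step. It is not matsciml's concatenate_keys (whose real behavior may differ), and the sample dictionaries are invented. The idea it sketches: padded keys are zero-padded to a common shape and stacked into one batch tensor, while unpacked keys, such as the renamed pc_src_nodes and pc_dst_nodes index tensors, presumably vary in length per sample and are carried through as lists.

import torch

def toy_collate(batch, pad_keys, unpacked_keys):
    # conceptual stand-in, NOT matsciml's concatenate_keys
    out = {}
    for key in pad_keys:
        tensors = [sample[key] for sample in batch]
        # zero-pad every tensor up to the largest shape in the batch, then stack
        max_shape = [max(t.shape[d] for t in tensors) for d in range(tensors[0].dim())]
        padded = []
        for t in tensors:
            target = torch.zeros(*max_shape, dtype=t.dtype)
            target[tuple(slice(0, s) for s in t.shape)] = t
            padded.append(target)
        out[key] = torch.stack(padded)
    for key in unpacked_keys:
        # ragged per-sample entries are kept as plain lists
        out[key] = [sample[key] for sample in batch]
    return out

# made-up samples using the renamed keys; shapes are arbitrary
samples = [
    {"pc_features": torch.rand(4, 4, 8), "sizes": 4,
     "pc_src_nodes": torch.arange(4), "pc_dst_nodes": torch.arange(4)},
    {"pc_features": torch.rand(6, 6, 8), "sizes": 6,
     "pc_src_nodes": torch.arange(6), "pc_dst_nodes": torch.arange(6)},
]
collated = toy_collate(
    samples,
    pad_keys=["pc_features"],
    unpacked_keys=["sizes", "pc_src_nodes", "pc_dst_nodes"],
)
# collated["pc_features"].shape == (2, 6, 6, 8); collated["pc_src_nodes"] is a list of two tensors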
4 changes: 2 additions & 2 deletions matsciml/datasets/alexandria/tests/test_alexandria.py
@@ -78,7 +78,7 @@ def test_pairwise_pointcloud():
     assert all(
         [
             key in sample
-            for key in ["pos", "pc_features", "sizes", "src_nodes", "dst_nodes"]
+            for key in ["pos", "pc_features", "sizes", "pc_src_nodes", "pc_dst_nodes"]
         ],
     )
     # for a pairwise point cloud sizes should be equal
@@ -93,7 +93,7 @@ def test_sampled_pointcloud():
     assert all(
         [
             key in sample
-            for key in ["pos", "pc_features", "sizes", "src_nodes", "dst_nodes"]
+            for key in ["pos", "pc_features", "sizes", "pc_src_nodes", "pc_dst_nodes"]
         ],
     )
     # for a pairwise point cloud sizes should be equal
10 changes: 4 additions & 6 deletions matsciml/datasets/base.py
@@ -6,7 +6,7 @@
 from abc import abstractmethod
 from pathlib import Path
 from random import sample
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable

 import torch
 from torch.utils.data import DataLoader, Dataset
@@ -253,13 +253,11 @@ def representation(self, value: str) -> None:
         self._representation = value

     @property
-    def pad_keys(self) -> list[str]:
-        ...
+    def pad_keys(self) -> list[str]: ...

     @pad_keys.setter
     @abstractmethod
-    def pad_keys(self, keys: list[str]) -> None:
-        ...
+    def pad_keys(self, keys: list[str]) -> None: ...

     @classmethod
     def from_devset(cls, transforms: list[Callable] | None = None, **kwargs):
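For readers skimming the pad_keys property being condensed above: a concrete dataset is expected to provide the matching getter/setter pair, roughly as in the sketch below. The class name and backing attribute are invented for illustration; this is not code from matsciml, and the real class would derive from the abstract dataset that declares pad_keys above.

from __future__ import annotations

class ExamplePointCloudDataset:  # hypothetical; would subclass the abstract dataset above
    def __init__(self, pad_keys: list[str] | None = None) -> None:
        self._pad_keys = pad_keys or ["pc_features"]

    @property
    def pad_keys(self) -> list[str]:
        # keys whose tensors get padded during collation
        return self._pad_keys

    @pad_keys.setter
    def pad_keys(self, keys: list[str]) -> None:
        self._pad_keys = keys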
@@ -355,4 +353,4 @@ def choose_dst_nodes(size: int, full_pairwise: bool) -> dict[str, torch.Tensor]:
             dst_indices = torch.randperm(size)[:num_neighbors].sort().values
         else:
             dst_indices = src_indices
-        return {"src_nodes": src_indices, "dst_nodes": dst_indices}
+        return {"pc_src_nodes": src_indices, "pc_dst_nodes": dst_indices}
7 changes: 5 additions & 2 deletions matsciml/datasets/carolina_db/dataset.py
@@ -28,7 +28,7 @@ def collate_fn(batch: list[DataDict]) -> BatchDict:
         return concatenate_keys(
             batch,
             pad_keys=["pc_features"],
-            unpacked_keys=["sizes", "src_nodes", "dst_nodes"],
+            unpacked_keys=["sizes", "pc_src_nodes", "pc_dst_nodes"],
         )

     def raw_sample(self, idx):
@@ -74,7 +74,10 @@ def data_from_key(
         return_dict["pos"] = coords
         system_size = coords.size(0)
         node_choices = self.choose_dst_nodes(system_size, self.full_pairwise)
-        src_nodes, dst_nodes = node_choices["src_nodes"], node_choices["dst_nodes"]
+        src_nodes, dst_nodes = (
+            node_choices["pc_src_nodes"],
+            node_choices["pc_dst_nodes"],
+        )
         atom_numbers = torch.LongTensor(data["atomic_numbers"])
         # uses one-hot encoding featurization
         pc_features = point_cloud_featurization(
7 changes: 5 additions & 2 deletions matsciml/datasets/colabfit/dataset.py
@@ -67,7 +67,10 @@ def data_from_key(
         assert coords.size(-1) == 3
         system_size = coords.size(0)
         node_choices = self.choose_dst_nodes(system_size, self.full_pairwise)
-        src_nodes, dst_nodes = node_choices["src_nodes"], node_choices["dst_nodes"]
+        src_nodes, dst_nodes = (
+            node_choices["pc_src_nodes"],
+            node_choices["pc_dst_nodes"],
+        )
         # typecast atomic numbers
         atom_numbers = torch.LongTensor(data["atomic_numbers"])
         # uses one-hot encoding featurization
@@ -102,5 +105,5 @@ def collate_fn(batch: list[DataDict]) -> BatchDict:
         return concatenate_keys(
             batch,
             pad_keys=["pc_features"],
-            unpacked_keys=["sizes", "src_nodes", "dst_nodes"],
+            unpacked_keys=["sizes", "pc_src_nodes", "pc_dst_nodes"],
         )
4 changes: 2 additions & 2 deletions matsciml/datasets/colabfit/tests/test_dataset.py
@@ -74,7 +74,7 @@ def test_pairwise_pointcloud():
     assert all(
         [
             key in sample
-            for key in ["pos", "pc_features", "sizes", "src_nodes", "dst_nodes"]
+            for key in ["pos", "pc_features", "sizes", "pc_src_nodes", "pc_dst_nodes"]
         ],
     )
     # for a pairwise point cloud sizes should be equal
@@ -90,7 +90,7 @@ def test_sampled_pointcloud():
     assert all(
         [
             key in sample
-            for key in ["pos", "pc_features", "sizes", "src_nodes", "dst_nodes"]
+            for key in ["pos", "pc_features", "sizes", "pc_src_nodes", "pc_dst_nodes"]
         ],
     )
     # for a non-pairwise point cloud sizes should not be equal
10 changes: 6 additions & 4 deletions matsciml/datasets/lips/dataset.py
@@ -1,9 +1,8 @@
 from __future__ import annotations

 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any

-import numpy as np
 import torch

 from matsciml.common.registry import registry
@@ -54,7 +53,7 @@ def collate_fn(batch: list[DataDict]) -> BatchDict:
         return concatenate_keys(
             batch,
             pad_keys=["pc_features"],
-            unpacked_keys=["sizes", "src_nodes", "dst_nodes"],
+            unpacked_keys=["sizes", "pc_src_nodes", "pc_dst_nodes"],
         )

     def data_from_key(
@@ -86,7 +85,7 @@ def data_from_key(
         coords = data["pos"]
         system_size = coords.size(0)
         node_choices = self.choose_dst_nodes(system_size, self.full_pairwise)
-        src_nodes, dst_nodes = node_choices["src_nodes"], node_choices["dst_nodes"]
+        src_nodes, dst_nodes = (
+            node_choices["pc_src_nodes"],
+            node_choices["pc_dst_nodes"],
+        )
         atom_numbers = torch.LongTensor(data["atomic_numbers"])
         # uses one-hot encoding featurization
         pc_features = point_cloud_featurization(
8 changes: 4 additions & 4 deletions matsciml/datasets/lips/tests/test_lips_reps.py
@@ -18,8 +18,8 @@ def test_pairwise_pointcloud():
             "pos",
             "pc_features",
             "sizes",
-            "src_nodes",
-            "dst_nodes",
+            "pc_src_nodes",
+            "pc_dst_nodes",
             "force",
         ]
     ],
@@ -40,8 +40,8 @@ def test_sampled_pointcloud():
             "pos",
             "pc_features",
             "sizes",
-            "src_nodes",
-            "dst_nodes",
+            "pc_src_nodes",
+            "pc_dst_nodes",
             "force",
         ]
     ],
7 changes: 5 additions & 2 deletions matsciml/datasets/materials_project/dataset.py
@@ -115,7 +115,10 @@ def _parse_structure(
         system_size = len(coords)
         return_dict["pos"] = coords
         chosen_nodes = self.choose_dst_nodes(system_size, self.full_pairwise)
-        src_nodes, dst_nodes = chosen_nodes["src_nodes"], chosen_nodes["dst_nodes"]
+        src_nodes, dst_nodes = (
+            chosen_nodes["pc_src_nodes"],
+            chosen_nodes["pc_dst_nodes"],
+        )
         atom_numbers = torch.LongTensor(structure.atomic_numbers)
         # uses one-hot encoding featurization
         pc_features = point_cloud_featurization(
@@ -340,7 +343,7 @@ def collate_fn(batch: list[DataDict]) -> BatchDict:
         return concatenate_keys(
             batch,
             pad_keys=["pc_features", "force", "stress"],
-            unpacked_keys=["sizes", "src_nodes", "dst_nodes"],
+            unpacked_keys=["sizes", "pc_src_nodes", "pc_dst_nodes"],
         )

4 changes: 2 additions & 2 deletions matsciml/datasets/materials_project/tests/test_mp_reps.py
@@ -11,7 +11,7 @@ def test_pairwise_pointcloud():
     assert all(
         [
             key in sample
-            for key in ["pos", "pc_features", "sizes", "src_nodes", "dst_nodes"]
+            for key in ["pos", "pc_features", "sizes", "pc_src_nodes", "pc_dst_nodes"]
         ],
     )
     # for a pairwise point cloud sizes should be equal
@@ -26,7 +26,7 @@ def test_sampled_pointcloud():
     assert all(
         [
             key in sample
-            for key in ["pos", "pc_features", "sizes", "src_nodes", "dst_nodes"]
+            for key in ["pos", "pc_features", "sizes", "pc_src_nodes", "pc_dst_nodes"]
         ],
     )
     # for a pairwise point cloud sizes should be equal
7 changes: 5 additions & 2 deletions matsciml/datasets/nomad/dataset.py
@@ -29,7 +29,7 @@ def collate_fn(batch: list[DataDict]) -> BatchDict:
         return concatenate_keys(
             batch,
             pad_keys=["pc_features"],
-            unpacked_keys=["sizes", "src_nodes", "dst_nodes"],
+            unpacked_keys=["sizes", "pc_src_nodes", "pc_dst_nodes"],
         )

     def raw_sample(self, idx):
@@ -112,7 +112,10 @@ def _parse_data(self, data: dict[str, Any], return_dict: dict[str, Any]) -> dict
         system_size = len(cart_coords)
         return_dict["pos"] = cart_coords
         chosen_nodes = self.choose_dst_nodes(system_size, self.full_pairwise)
-        src_nodes, dst_nodes = chosen_nodes["src_nodes"], chosen_nodes["dst_nodes"]
+        src_nodes, dst_nodes = (
+            chosen_nodes["pc_src_nodes"],
+            chosen_nodes["pc_dst_nodes"],
+        )

         atomic_numbers = torch.LongTensor(
             [
11 changes: 8 additions & 3 deletions matsciml/datasets/ocp_datasets.py
@@ -5,7 +5,6 @@
 from pathlib import Path
 from typing import Callable

-import dgl
 import torch

 from matsciml.common.registry import registry
@@ -44,7 +43,10 @@ def data_from_key(
         data = super().data_from_key(lmdb_index, subindex)
         system_size = data["pos"].size(0)
         node_choices = self.choose_dst_nodes(system_size, self.full_pairwise)
-        src_nodes, dst_nodes = node_choices["src_nodes"], node_choices["dst_nodes"]
+        src_nodes, dst_nodes = (
+            node_choices["pc_src_nodes"],
+            node_choices["pc_dst_nodes"],
+        )
         atom_numbers = data["atomic_numbers"].to(torch.int)
         # uses one-hot encoding featurization
         pc_features = point_cloud_featurization(
@@ -95,7 +97,10 @@ def data_from_key(
         data = super().data_from_key(lmdb_index, subindex)
         system_size = data["pos"].size(0)
         node_choices = self.choose_dst_nodes(system_size, self.full_pairwise)
-        src_nodes, dst_nodes = node_choices["src_nodes"], node_choices["dst_nodes"]
+        src_nodes, dst_nodes = (
+            node_choices["pc_src_nodes"],
+            node_choices["pc_dst_nodes"],
+        )
         atom_numbers = data["atomic_numbers"].to(torch.int)
         # uses one-hot encoding featurization
         pc_features = point_cloud_featurization(
7 changes: 5 additions & 2 deletions matsciml/datasets/oqmd/dataset.py
@@ -28,7 +28,7 @@ def collate_fn(batch: list[DataDict]) -> BatchDict:
         return concatenate_keys(
             batch,
             pad_keys=["pc_features"],
-            unpacked_keys=["sizes", "src_nodes", "dst_nodes"],
+            unpacked_keys=["sizes", "pc_src_nodes", "pc_dst_nodes"],
         )

     def raw_sample(self, idx):
@@ -79,7 +79,10 @@ def data_from_key(
         cell = data["cell"]
         system_size = coords.size(0)
         node_choices = self.choose_dst_nodes(system_size, self.full_pairwise)
-        src_nodes, dst_nodes = node_choices["src_nodes"], node_choices["dst_nodes"]
+        src_nodes, dst_nodes = (
+            node_choices["pc_src_nodes"],
+            node_choices["pc_dst_nodes"],
+        )
         atom_numbers = torch.LongTensor(data["atomic_numbers"])
         # uses one-hot encoding featurization
         pc_features = point_cloud_featurization(
4 changes: 2 additions & 2 deletions matsciml/datasets/symmetry/dataset.py
@@ -65,8 +65,8 @@ def data_from_key(
         # get number of particles in the original system
         sample["sizes"] = len(sample["pos"])
         # these are meant to correspond to indices
-        sample["src_nodes"] = torch.arange(sample["num_centers"])
-        sample["dst_nodes"] = torch.arange(sample["num_centers"])
+        sample["pc_src_nodes"] = torch.arange(sample["num_centers"])
+        sample["pc_dst_nodes"] = torch.arange(sample["num_centers"])
         # clean up keys
         for key in ["label"]:
             del sample[key]