
Add alexandria dataset to matsciml toolkit https://github.com/IntelLabs/matsciml/discussions/107 #132

Merged: 27 commits merged into IntelLabs:main on Mar 13, 2024

Conversation

@JonathanSchmidt1 (Contributor)

This pull request adds an AlexandriaRequest class to download the Alexandria database into an LMDB dataset, and an AlexandriaDataset class to use it for ML within the matsciml toolkit.
The AlexandriaDataset class mostly follows the Materials Project example.
I also added most of the same tests.
At the moment there are no tests for the AlexandriaRequest class.
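
A minimal usage sketch (illustrative only; the constructor signatures match the snippets quoted later in this thread, and the download entry point is a hypothetical placeholder):

    from matsciml.datasets.alexandria import AlexandriaDataset, AlexandriaRequest
    from matsciml.datasets.transforms import (
        PeriodicPropertiesTransform,
        PointCloudToGraphTransform,
    )

    # build a devset request: index [0] selects the first 100k-structure JSON shard
    request = AlexandriaRequest.devset("./alexandria_lmdb")
    # request.download()  # hypothetical name; the actual download call may differ

    # load the resulting LMDB for ML, applying graph transforms
    dset = AlexandriaDataset(
        "./alexandria_lmdb",
        transforms=[
            PeriodicPropertiesTransform(6.0),
            PointCloudToGraphTransform("dgl", cutoff_dist=6.0),
        ],
    )
    sample = dset.__getitem__(0)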

@laserkelvin self-requested a review February 20, 2024
@laserkelvin added the enhancement (New feature or request) and data (Issues related to data loading, pipelining, etc.) labels Feb 20, 2024
@laserkelvin (Collaborator) left a comment

Thanks @JonathanSchmidt1 for the great PR!!

For the most part things look good and functional to me; I did leave a few style remarks, and wanted to see if you could add a PyG test as well. If you don't normally use that pipeline, don't worry about it and we can add it later.
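
For reference, a hedged sketch of what such a PyG test could look like, mirroring the DGL pattern discussed later in this thread (the devset path attribute and test name are assumptions, not the repo's actual test code):

    from matsciml.datasets import AlexandriaDataset
    from matsciml.datasets.transforms import (
        PeriodicPropertiesTransform,
        PointCloudToGraphTransform,
    )

    def test_pyg_graph_transform():
        dset = AlexandriaDataset(
            AlexandriaDataset.__devset__,  # assumed, by analogy with MaterialsProjectDataset
            transforms=[
                PeriodicPropertiesTransform(6.0),
                PointCloudToGraphTransform("pyg", cutoff_dist=6.0),
            ],
        )
        sample = dset.__getitem__(0)
        assert "graph" in sample
        # PyG convention: edge_index has shape [2, num_edges]
        assert sample["graph"].edge_index.size(0) == 2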

Review threads:
matsciml/datasets/alexandria/README.md
matsciml/datasets/alexandria/dataset.py (×4)


@registry.register_dataset("AlexandriaDataset")
class AlexandriaDataset(PointCloudDataset):
Collaborator:

More a comment for discussion with @melo-gonzo but the fact that @JonathanSchmidt1 had to copy over a bunch of functions here means we could do better on the abstraction: what do you think?

Contributor (Author):

The parse_structure/parse_symmetry functions from the Materials Project class seem quite general. Maybe it makes sense to put them in a utils file instead and import them in the respective datasets.

Collaborator:

Agreed, a lot of this can be refactored into some sort of utils file - it would make all the datasets much cleaner. I can draft up an issue for this.

Review thread: matsciml/datasets/alexandria/tests/test_alexandria.py

@melo-gonzo (Collaborator) left a comment

Thanks for the PR! It looks pretty good overall, I've left a few initial comments. In addition to what was noted, can you add some details to the DATASETS.md as well? It's a bit redundant but for consistency it will be nice to have some details in there as well.

Review threads:
examples/model_demos/mpnn_alexandria_dgl.py
matsciml/datasets/alexandria/README.md
matsciml/datasets/alexandria/api.py
matsciml/datasets/alexandria/dataset.py
@laserkelvin (Collaborator)

LOL the timing on our reviews

Commit: …each dataset. Also removed 1-atom structures from devset due to DGL errors and added max_atom parameter for download to avoid OOM errors during training
@JonathanSchmidt1 (Contributor, Author)

Thank you for both of the reviews. I will go through them today and tomorrow.
I noticed one issue: the way periodic boundary conditions are handled was updated over the last few weeks, which made my graph-transform test fail.
How is the output of a periodic graph supposed to look after the transformation with [PeriodicPropertiesTransform(20.), PointCloudToGraphTransform('pyg', cutoff_dist=20.0)] now?
I.e., are the edges supposed to already include the periodic boundary conditions? Or is this calculated somewhere else later based on the shifts and offsets (if so, where)?

By the way, maybe this is an issue for the Materials Project dataset as well, but for my dataset the check for precomputed src and dst nodes leads to there being only self-loops in the end:

g.edge_index
tensor([[0, 1, 2],
        [0, 1, 2]])

instead of

tensor([[1, 2, 2],
        [0, 0, 1]])

with the check commented out

@JonathanSchmidt1 (Contributor, Author) commented Feb 21, 2024

OK, the first question solved itself. Now I am just wondering about the edge case where you do not use the PeriodicPropertiesTransform,
but in the Materials Project or Alexandria dataset you already have src/dst nodes set from the PointCloudDataset choose_dst_nodes method. However, these do not seem to be the correct edges to me; e.g. with a cutoff of 40 I would expect connections between the different atoms (this is related to the second point in the previous question):

>>> dset.__getitem__(0)['graph'].edge_index
tensor([[0, 1, 2],
        [0, 1, 2]])
>>> dset.__getitem__(0)['graph'].pos
tensor([[ 0.0000,  0.0000,  0.0000],
        [22.0659,  4.5494,  2.7011],
        [ 6.3406,  1.3073,  0.7761]])

But maybe I am misunderstanding the edge indices.

Commits: Tests do not run without PeriodicPropertiesTransform;
Added units to README, changed num_atoms to 100, fixed __init__ that was changed by pre-commit hooks
@laserkelvin (Collaborator) commented Feb 26, 2024

> but in the Materials Project or Alexandria dataset you already have src/dst nodes set from the PointCloudDataset choose_dst_nodes method. However, these do not seem to be the correct edges to me; e.g. with a cutoff of 40 I would expect connections between the different atoms (this is related to the second point in the previous question):

To be clear (I don't know if this is a misunderstanding or not, so forgive me): the function choose_dst_nodes is probably poorly named, because point clouds don't have edges. That function is used to construct featurizations for point clouds and is not related to graph edge construction, so the cutoff radii do not come into play at all. That's why I think the edge_index doesn't make sense here.

If you want to understand the graph edge construction, there are basically two implemented paths now: the cutoff_radius value is applied either via an explicit dense distance-matrix calculation (if there is no cell attribute in the data sample, i.e. we assume it's basically a molecular graph), or, if there is a cell, we rely on pymatgen to compute the connectivity.
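
Schematically, the branching described above could look like the following sketch (not matsciml's actual implementation; only the pymatgen call is a real API):

    from __future__ import annotations

    import torch
    from pymatgen.core import Structure

    def build_edges(
        pos: torch.Tensor, cutoff: float, structure: Structure | None = None
    ) -> torch.Tensor:
        if structure is None:
            # no cell attribute: treat the sample like a molecular graph and
            # threshold a dense pairwise distance matrix at the cutoff
            dist = torch.cdist(pos, pos)
            src, dst = torch.nonzero(dist <= cutoff, as_tuple=True)
            mask = src != dst  # drop self-loops
            return torch.stack([src[mask], dst[mask]])
        # periodic case: let pymatgen compute connectivity across cell images
        center, neighbor, _images, _dists = structure.get_neighbor_list(r=cutoff)
        return torch.stack([torch.from_numpy(center), torch.from_numpy(neighbor)])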

Sorry about the late reply - it was a crazy week last week. Hopefully my response answers your questions; please let us know if and when the code is ready for re-review.

@JonathanSchmidt1 (Contributor, Author) commented Feb 26, 2024

> but in the Materials Project or Alexandria dataset you already have src/dst nodes set from the PointCloudDataset choose_dst_nodes method. However, these do not seem to be the correct edges to me; e.g. with a cutoff of 40 I would expect connections between the different atoms (this is related to the second point in the previous question):

> To be clear (I don't know if this is a misunderstanding or not, so forgive me): the function choose_dst_nodes is probably poorly named, because point clouds don't have edges. That function is used to construct featurizations for point clouds and is not related to graph edge construction, so the cutoff radii do not come into play at all. That's why I think the edge_index doesn't make sense here.

> If you want to understand the graph edge construction, there are basically two implemented paths now: the cutoff_radius value is applied either via an explicit dense distance-matrix calculation (if there is no cell attribute in the data sample, i.e. we assume it's basically a molecular graph), or, if there is a cell, we rely on pymatgen to compute the connectivity.

Thank you very much for the explanation.
Maybe I just have a different bug in my code, but right now I think the following is happening.
This block already sets the src/dst nodes in the last line (both in Alexandria and Materials Project):

        src_nodes, dst_nodes = chosen_nodes["src_nodes"], chosen_nodes["dst_nodes"]
        atom_numbers = torch.LongTensor(structure.atomic_numbers)
        # uses one-hot encoding featurization
        pc_features = point_cloud_featurization(
            atom_numbers[src_nodes],
            atom_numbers[dst_nodes],
            100,
        )
        # keep atomic numbers for graph featurization
        return_dict["atomic_numbers"] = atom_numbers
        return_dict["pc_features"] = pc_features
        return_dict["sizes"] = system_size
        return_dict.update(**chosen_nodes)

After that return_dict.update(**chosen_nodes), the PointCloudToGraphTransform does not recompute anything, because its precomputed-edge check finds the src/dst keys already present:

        def _convert_dgl(self, data: DataDict) -> None:
            atom_numbers = self.get_atom_types(data)
            coords = data["pos"]
            atom_numbers, coords = self._apply_mask(atom_numbers, coords, data)
            num_nodes = len(atom_numbers)
            # use pre-computed edges with periodic boundary conditions
            if all([f"{key}_nodes" in data for key in ["src", "dst"]]):

returning "wrong edges" that were actually part of the point_cloud_featurization.
If I use the PeriodicPropertiesTransform this is ok as the src/dst nodes get overwritten anyway and than the precomputed check works as intended.

If you run:

from matsciml.datasets import MaterialsProjectDataset
from matsciml.datasets.transforms import PointCloudToGraphTransform, PeriodicPropertiesTransform
dset = MaterialsProjectDataset(MaterialsProjectDataset.__devset__, [PointCloudToGraphTransform('dgl', cutoff_dist=40.0)])
dset.__getitem__(0)['graph']

and add a print statement to the precomputed check, you will find that it
already uses precomputed edges.

> Sorry about the late reply - it was a crazy week last week. Hopefully my response answers your questions; please let us know if and when the code is ready for re-review.

Sorry, I will have to find another few hours to finish it, maybe this weekend.

@JonathanSchmidt1 (Contributor, Author)

@laserkelvin One question concerning the test, which I did not completely understand, is still open. The rest should be ready for re-review.

@laserkelvin (Collaborator)

I'm out of office for the rest of the week and I'll review it once I get back.

@laserkelvin (Collaborator) left a comment

I think just some minor comments, and otherwise everything looks good and we can merge afterwards!

Review threads:
examples/datasets/alexandria/single_task_devset.py (×2)
examples/datasets/alexandria/single_task_egnn.py
matsciml/datasets/alexandria/__init__.py
self.get_data_dict(entry)
for entry in tqdm(data["entries"], desc=f"Processing file {index}")
]
if self.devset:
Collaborator:

Would it make sense to move this up so you don't run all of the requests for devsets? That way you don't put unnecessary load on the server if you don't need to.

Contributor (Author):

I can move it up one step so we run self.get_data_dict only for the 100 entries (see the sketch after this snippet). Right now the database is only split into the 100k-structure JSONs, so one has to download at least one of them (and it will only be one of them for the devset, so that should not be an issue for the server).

    @classmethod
    def devset(cls, lmdb_target_dir: str) -> AlexandriaRequest:
        return cls([0], lmdb_target_dir, dataset="3D", devset=True)
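
For illustration, a sketch of that reordering (hypothetical variable names, not the actual api.py code): truncate the entries before any per-entry processing.

    entries = data["entries"]
    if self.devset:
        entries = entries[:100]  # devset: only process the first 100 entries
    data_dicts = [
        self.get_data_dict(entry)
        for entry in tqdm(entries, desc=f"Processing file {index}")
    ]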

Review thread: matsciml/datasets/alexandria/dataset.py
@JonathanSchmidt1 (Contributor, Author)

Thanks for the corrections :)

@JonathanSchmidt1 (Contributor, Author)

I implemented the suggestions.
I also added an extra warning, if you want to have a look below.
This is needed because the 2D/1D structures only have a vacuum of 15 Å, so a larger cutoff radius can result in wrong neighbor lists.

    def __init__(
        self,
        lmdb_root_path: str | Path,
        transforms: list[Callable[..., Any]] | None = None,
        full_pairwise: bool = True,
    ) -> None:
        super().__init__(lmdb_root_path, transforms, full_pairwise)
        if self.transforms:
            for transform in self.transforms:
                if (
                    (
                        hasattr(transform, "cutoff_radius")
                        and transform.cutoff_radius > 15.0
                    )
                    or (
                        hasattr(transform, "cutoff_dist")
                        and transform.cutoff_dist > 15.0
                    )
                    or (
                        hasattr(transform, "adaptive_cutoff")
                        and transform.adaptive_cutoff > 15.0
                    )
                ):
                    warnings.warn(
                        f"Transform {transform} has a cutoff radius > 15.0; this will lead to wrong neighborlists for the two- and one-dimensional datasets."
                    )
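
For example (illustrative only; the __devset__ path attribute is assumed by analogy with MaterialsProjectDataset), a cutoff above 15 Å should trigger the warning:

    from matsciml.datasets import AlexandriaDataset
    from matsciml.datasets.transforms import PointCloudToGraphTransform

    dset = AlexandriaDataset(
        AlexandriaDataset.__devset__,  # assumed devset path attribute
        [PointCloudToGraphTransform("dgl", cutoff_dist=20.0)],
    )
    # UserWarning: Transform ... has a cutoff radius > 15.0; this will lead to
    # wrong neighborlists for the two- and one-dimensional datasets.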

@melo-gonzo (Collaborator) left a comment

A few more minor updates from my end, and then will be good to approve!

Review threads:
examples/datasets/alexandria/single_task_base.py
examples/datasets/alexandria/single_task_egnn.py
examples/datasets/alexandria/single_task_gala.py
@laserkelvin (Collaborator) left a comment

I just highlighted the unused import flagged by pre-commit, but other than that I'm giving my approval for merging.

@laserkelvin (Collaborator)

I think we're good to go - I looked into the failing tests and nothing is relevant to this PR.

Thank you @JonathanSchmidt1 for all your hard work on this PR!

@laserkelvin laserkelvin merged commit d1d7e87 into IntelLabs:main Mar 13, 2024
1 of 2 checks passed