Skip to content

Commit

Permalink
v0.3.1
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Oct 23, 2023
1 parent 74a6a70 commit 19de8ff
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions chgnet/model/composition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _assemble_graphs(self, graphs: list[CrystalGraph]):
assembled batch_graph that contains all information for model.
"""
composition_feas = []
for _graph_idx, graph in enumerate(graphs):
for graph in graphs:
composition_fea = torch.bincount(
graph.atomic_number - 1, minlength=self.max_num_elements
)
Expand Down Expand Up @@ -201,7 +201,7 @@ def initialize_from(self, dataset: str):
"""Initialize pre-fitted weights from a dataset."""
if dataset in ["MPtrj", "MPtrj_e"]:
self.initialize_from_MPtrj()
elif dataset in ["MPF"]:
elif dataset == "MPF":
self.initialize_from_MPF()
else:
raise NotImplementedError(f"{dataset=} not supported yet")
Expand Down
6 changes: 3 additions & 3 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __init__(
eta_min=decay_fraction * learning_rate,
)
self.scheduler_type = "cos"
elif scheduler in ["CosRestartLR"]:
elif scheduler == "CosRestartLR":
scheduler_params = kwargs.pop(
"scheduler_params", {"decay_fraction": 1e-2, "T_0": 10, "T_mult": 2}
)
Expand Down Expand Up @@ -471,7 +471,7 @@ def get_best_model(self):
if self.best_model is None:
raise RuntimeError("the model needs to be trained first")
MAE = min(self.training_history["e"]["val"])
print(f"Best model has val {MAE = :.4}")
print(f"Best model has val {MAE =:.4}")
return self.best_model

@property
Expand Down Expand Up @@ -616,7 +616,7 @@ def __init__(
self.criterion = nn.MSELoss()
elif criterion in ["MAE", "mae", "l1"]:
self.criterion = nn.L1Loss()
elif criterion in ["Huber"]:
elif criterion == "Huber":
self.criterion = nn.HuberLoss(delta=delta)
else:
raise NotImplementedError
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "chgnet"
version = "0.3.0"
version = "0.3.1"
description = "Pretrained Universal Neural Network Potential for Charge-informed Atomistic Modeling"
authors = [{ name = "Bowen Deng", email = "bowendeng@berkeley.edu" }]
requires-python = ">=3.9"
Expand Down Expand Up @@ -45,7 +45,7 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }

[tool.setuptools.package-data]
"chgnet" = ["*.json"]
"chgnet.pretrained" = ["**/*"]
"chgnet.pretrained" = ["*", "**/*"]

[tool.ruff]
target-version = "py39"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_predict_supercell() -> None:
np.repeat(pristine_prediction["f"], 4, axis=0), rel=1e-4, abs=1e-4
)

assert out["s"] == pytest.approx(pristine_prediction["s"], rel=1e-4, abs=1e-4)
assert out["s"] == pytest.approx(pristine_prediction["s"], rel=1e-3, abs=1e-3)

assert out["site_energies"] == pytest.approx(
np.repeat(pristine_prediction["site_energies"], 4), rel=1e-4, abs=1e-4
Expand Down

0 comments on commit 19de8ff

Please sign in to comment.