Skip to content

Commit

Permalink
allow specifying decay fraction in CosineAnnealing LR scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
BowenD-UCB committed Oct 12, 2023
1 parent a760ec8 commit dbb0305
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
1 change: 1 addition & 0 deletions chgnet/model/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def forward(
new_bond_feas = aggregate(
bond_update, bond_graph[:, 1], average=False, num_owner=len(bond_feas)
)

# New bond features
if self.use_mlp_out:
new_bond_feas = self.mlp_out(new_bond_feas)
Expand Down
13 changes: 10 additions & 3 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,23 @@ def __init__(
self.scheduler = ExponentialLR(self.optimizer, **scheduler_params)
self.scheduler_type = "exp"
elif scheduler in ["CosineAnnealingLR", "CosLR", "Cos", "cos"]:
scheduler_params = kwargs.pop("scheduler_params", {"decay_fraction": 1e-2})
decay_fraction = scheduler_params.pop("decay_fraction")
self.scheduler = CosineAnnealingLR(
self.optimizer,
T_max=10 * epochs, # Maximum number of iterations.
eta_min=1e-2 * learning_rate,
eta_min=decay_fraction * learning_rate,
)
self.scheduler_type = "cos"
elif scheduler in ["CosRestartLR"]:
scheduler_params = kwargs.pop("scheduler_params", {"T_0": 10, "T_mult": 2})
scheduler_params = kwargs.pop(
"scheduler_params", {"decay_fraction": 1e-2, "T_0": 10, "T_mult": 2}
)
decay_fraction = scheduler_params.pop("decay_fraction")
self.scheduler = CosineAnnealingWarmRestarts(
self.optimizer, eta_min=1e-2 * learning_rate, **scheduler_params
self.optimizer,
eta_min=decay_fraction * learning_rate,
**scheduler_params,
)
self.scheduler_type = "cosrestart"
else:
Expand Down

0 comments on commit dbb0305

Please sign in to comment.