SWA checkpoints and swa virial/stress weights #35

Closed
wants to merge 8 commits
12 changes: 12 additions & 0 deletions mace/tools/arg_parser.py
@@ -281,9 +281,21 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--virials_weight", help="weight of virials loss", type=float, default=1.0
)
parser.add_argument(
"--swa_virials_weight",
help="weight of virials loss after starting swa",
type=float,
default=10.0,
)
parser.add_argument(
"--stress_weight", help="weight of virials loss", type=float, default=1.0
)
parser.add_argument(
"--swa_stress_weight",
help="weight of stress loss after starting swa",
type=float,
default=10.0,
)
parser.add_argument(
"--dipole_weight", help="weight of dipoles loss", type=float, default=1.0
)
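For reference, a minimal standalone sketch (not part of this diff, and deliberately mirroring only the two flags added above rather than the full MACE parser) of how the new SWA-phase weights are exposed and overridden:

```python
import argparse

# Standalone sketch: reproduces only the two options added in this hunk.
parser = argparse.ArgumentParser()
parser.add_argument("--swa_virials_weight", type=float, default=10.0,
                    help="weight of virials loss after starting swa")
parser.add_argument("--swa_stress_weight", type=float, default=10.0,
                    help="weight of stress loss after starting swa")

args = parser.parse_args(["--swa_virials_weight", "25.0"])
print(args.swa_virials_weight)  # 25.0 (overridden on the command line)
print(args.swa_stress_weight)   # 10.0 (default)
```

In run_train.py these values feed the SWA loss functions built below, so the virial or stress term can be up-weighted once SWA starts, analogous to the existing --swa_energy_weight and --swa_forces_weight flags.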
74 changes: 58 additions & 16 deletions mace/tools/checkpoint.py
@@ -47,19 +47,32 @@ class CheckpointPathInfo:
path: str
tag: str
epochs: int
swa: bool


class CheckpointIO:
def __init__(self, directory: str, tag: str, keep: bool = False) -> None:
def __init__(
self, directory: str, tag: str, keep: bool = False, swa_start: Optional[int] = None
) -> None:
self.directory = directory
self.tag = tag
self.keep = keep
self.old_path: Optional[str] = None
self.swa_start = swa_start

self._epochs_string = "_epoch-"
self._filename_extension = "pt"

def _get_checkpoint_filename(self, epochs: int) -> str:
def _get_checkpoint_filename(self, epochs: int, swa_start: Optional[int] = None) -> str:
if swa_start is not None and epochs > swa_start:
return (
self.tag
+ self._epochs_string
+ str(epochs)
+ "_swa"
+ "."
+ self._filename_extension
)
return (
self.tag
+ self._epochs_string
@@ -81,17 +94,29 @@ def _parse_checkpoint_path(self, path: str) -> Optional[CheckpointPathInfo]:
regex = re.compile(
rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)\.{self._filename_extension}$"
)
regex2 = re.compile(
rf"^(?P<tag>.+){self._epochs_string}(?P<epochs>\d+)_swa\.{self._filename_extension}$"
)
match = regex.match(filename)
match2 = regex2.match(filename)
swa = False
if match2:
match = match2
swa = True
elif not match:
return None

return CheckpointPathInfo(
path=path,
tag=match.group("tag"),
epochs=int(match.group("epochs")),
swa=swa,
)

def _get_latest_checkpoint_path(self) -> Optional[str]:
def _get_latest_checkpoint_path(self, swa: bool) -> Optional[str]:
all_file_paths = self._list_file_paths()
checkpoint_info_list = [
self._parse_checkpoint_path(path) for path in all_file_paths
@@ -105,28 +130,42 @@ def _get_latest_checkpoint_path(self) -> Optional[str]:
f"Cannot find checkpoint with tag '{self.tag}' in '{self.directory}'"
)
return None

latest_checkpoint_info = max(
selected_checkpoint_info_list, key=lambda info: info.epochs
)
selected_checkpoint_info_list_swa = []
selected_checkpoint_info_list_no_swa = []

for ckp in selected_checkpoint_info_list:
if ckp.swa:
selected_checkpoint_info_list_swa.append(ckp)
else:
selected_checkpoint_info_list_no_swa.append(ckp)
if swa:
latest_checkpoint_info = max(
selected_checkpoint_info_list_swa, key=lambda info: info.epochs
)
else:
latest_checkpoint_info = max(
selected_checkpoint_info_list_no_swa, key=lambda info: info.epochs
)
return latest_checkpoint_info.path

def save(self, checkpoint: Checkpoint, epochs: int) -> None:
if not self.keep and self.old_path:
def save(
self, checkpoint: Checkpoint, epochs: int, keep_last: bool = False
) -> None:
if not self.keep and self.old_path and not keep_last:
logging.debug(f"Deleting old checkpoint file: {self.old_path}")
os.remove(self.old_path)

filename = self._get_checkpoint_filename(epochs)
filename = self._get_checkpoint_filename(epochs, self.swa_start)
path = os.path.join(self.directory, filename)
logging.debug(f"Saving checkpoint: {path}")
os.makedirs(self.directory, exist_ok=True)
torch.save(obj=checkpoint, f=path)
self.old_path = path

def load_latest(
self, device: Optional[torch.device] = None
self, swa: bool = False, device: Optional[torch.device] = None
) -> Optional[Tuple[Checkpoint, int]]:
path = self._get_latest_checkpoint_path()
path = self._get_latest_checkpoint_path(swa=swa)
if path is None:
return None

@@ -152,17 +191,20 @@ def __init__(self, *args, **kwargs) -> None:
self.io = CheckpointIO(*args, **kwargs)
self.builder = CheckpointBuilder()

def save(self, state: CheckpointState, epochs: int) -> None:
def save(
self, state: CheckpointState, epochs: int, keep_last: bool = False
) -> None:
checkpoint = self.builder.create_checkpoint(state)
self.io.save(checkpoint, epochs)
self.io.save(checkpoint, epochs, keep_last)

def load_latest(
self,
state: CheckpointState,
swa: bool = False,
device: Optional[torch.device] = None,
strict=False,
) -> Optional[int]:
result = self.io.load_latest(device=device)
result = self.io.load_latest(swa=swa, device=device)
if result is None:
return None

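A standalone sketch (not part of this diff; the tag and epoch numbers are made up) of the filename convention the parser above now distinguishes: checkpoints written before the SWA phase end in `_epoch-<N>.pt`, while those written after `swa_start` end in `_epoch-<N>_swa.pt`:

```python
import re

# Same two patterns as _parse_checkpoint_path above, with the tag and
# extension pieces hard-coded for illustration.
epochs_string, ext = "_epoch-", "pt"
plain_re = re.compile(rf"^(?P<tag>.+){epochs_string}(?P<epochs>\d+)\.{ext}$")
swa_re = re.compile(rf"^(?P<tag>.+){epochs_string}(?P<epochs>\d+)_swa\.{ext}$")

for name in ["mace_run-1_epoch-120.pt", "mace_run-1_epoch-180_swa.pt"]:
    match = swa_re.match(name) or plain_re.match(name)
    is_swa = swa_re.match(name) is not None
    print(name, "->", match.group("tag"), int(match.group("epochs")), is_swa)
# mace_run-1_epoch-120.pt -> mace_run-1 120 False
# mace_run-1_epoch-180_swa.pt -> mace_run-1 180 True
```

load_latest(swa=True) then resolves to the newest *_swa.pt file, so resuming and final evaluation can target the SWA phase specifically.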
43 changes: 29 additions & 14 deletions mace/tools/train.py
@@ -59,11 +59,28 @@ def train(
lowest_loss = np.inf
patience_counter = 0
swa_start = True
keep_last = False

if max_grad_norm is not None:
logging.info(f"Using gradient clipping with tolerance={max_grad_norm:.3f}")
logging.info("Started training")
for epoch in range(start_epoch, max_num_epochs):
epoch = start_epoch
while epoch < max_num_epochs:
# LR scheduler and SWA update
if swa is None or epoch < swa.start:
if epoch > start_epoch:
lr_scheduler.step(valid_loss) # Can break if exponential LR, TODO fix that!
else:
if swa_start:
logging.info("Changing loss based on SWA")
lowest_loss = np.inf
swa_start = False
keep_last = True
loss_fn = swa.loss_fn
swa.model.update_parameters(model)
if epoch > start_epoch:
swa.scheduler.step()

# Train
for batch in train_loader:
_, opt_metrics = take_step(
@@ -160,7 +177,12 @@ def train(
)
if valid_loss >= lowest_loss:
patience_counter += 1
if patience_counter >= patience:
if patience_counter >= patience and swa is not None and epoch < swa.start:
logging.info(
f"Stopping optimization after {patience_counter} epochs without improvement and starting swa"
)
epoch = swa.start
elif patience_counter >= patience:
logging.info(
f"Stopping optimization after {patience_counter} epochs without improvement"
)
@@ -173,24 +195,17 @@ def train(
checkpoint_handler.save(
state=CheckpointState(model, optimizer, lr_scheduler),
epochs=epoch,
keep_last=keep_last,
)
keep_last = False
else:
checkpoint_handler.save(
state=CheckpointState(model, optimizer, lr_scheduler),
epochs=epoch,
keep_last=keep_last,
)

# LR scheduler and SWA update
if swa is None or epoch < swa.start:
lr_scheduler.step(valid_loss) # Can break if exponential LR, TODO fix that!
else:
if swa_start:
logging.info("Changing loss based on SWA")
swa_start = False
loss_fn = swa.loss_fn
swa.model.update_parameters(model)
swa.scheduler.step()

keep_last = False
epoch += 1
logging.info("Training complete")


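To make the new control flow easier to follow, here is a self-contained toy sketch (all objects are dummies; the real loop additionally swaps loss_fn, calls swa.model.update_parameters, and steps swa.scheduler) of the two behavioural changes: the scheduler/SWA update now runs at the top of each epoch, and exhausting patience before swa.start jumps ahead to the SWA phase instead of stopping training:

```python
from dataclasses import dataclass

@dataclass
class DummySWA:
    start: int

def run_epochs(start_epoch, max_num_epochs, swa, patience, improves_until):
    """Toy loop: the validation loss improves until `improves_until`, then plateaus."""
    epoch, patience_counter, lowest_loss = start_epoch, 0, float("inf")
    visited = []
    while epoch < max_num_epochs:
        in_swa = swa is not None and epoch >= swa.start
        visited.append((epoch, "swa" if in_swa else "plain"))
        valid_loss = 1.0 / (epoch + 1) if epoch < improves_until else 1.0
        if valid_loss >= lowest_loss:
            patience_counter += 1
            if patience_counter >= patience:
                if swa is not None and not in_swa:
                    epoch = swa.start           # jump straight to the SWA phase
                    patience_counter = 0
                    lowest_loss = float("inf")  # the loss function changes, reset the best
                    continue
                break                           # no SWA, or already in SWA: stop
        else:
            lowest_loss, patience_counter = valid_loss, 0
        epoch += 1
    return visited

print(run_epochs(0, 20, DummySWA(start=15), patience=3, improves_until=5))
# plain epochs 0-7, then swa epochs 15-18
```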
110 changes: 73 additions & 37 deletions scripts/run_train.py
@@ -368,27 +368,29 @@ def main() -> None:
else:
raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'")

checkpoint_handler = tools.CheckpointHandler(
directory=args.checkpoints_dir, tag=tag, keep=args.keep_checkpoints
)

start_epoch = 0
if args.restart_latest:
opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler), device=device
)
if opt_start_epoch is not None:
start_epoch = opt_start_epoch

swa: Optional[tools.SWAContainer] = None
swas = [False]
if args.swa:
assert dipole_only is False, "swa for dipole fitting not implemented"
swas.append(True)
if args.start_swa is None:
args.start_swa = (
args.max_num_epochs // 4 * 3
) # if not set start swa at 75% of training
if args.loss == "forces_only":
logging.info("Can not select swa with forces only loss.")
elif args.loss == "virials":
loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss(
energy_weight=args.swa_energy_weight,
forces_weight=args.swa_forces_weight,
virials_weight=args.swa_virials_weight,
)
elif args.loss == "stress":
loss_fn_energy = modules.WeightedEnergyForcesStressLoss(
energy_weight=args.swa_energy_weight,
forces_weight=args.swa_forces_weight,
stress_weight=args.swa_stress_weight,
)
elif args.loss == "energy_forces_dipole":
loss_fn_energy = modules.WeightedEnergyForcesDipoleLoss(
args.swa_energy_weight,
@@ -418,6 +420,30 @@ def main() -> None:
loss_fn=loss_fn_energy,
)

checkpoint_handler = tools.CheckpointHandler(
directory=args.checkpoints_dir,
tag=tag,
keep=args.keep_checkpoints,
swa_start=args.start_swa,
)

start_epoch = 0
if args.restart_latest:
try:
opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=True,
device=device,
)
except Exception:  # no SWA checkpoint yet; fall back to the latest plain checkpoint
opt_start_epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=False,
device=device,
)
if opt_start_epoch is not None:
start_epoch = opt_start_epoch

ema: Optional[ExponentialMovingAverage] = None
if args.ema:
ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay)
@@ -447,11 +473,6 @@ def main() -> None:
log_errors=args.error_table,
)

epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler), device=device
)
logging.info(f"Loaded model from epoch {epoch}")

# Evaluation on test datasets
logging.info("Computing metrics for training, validation, and test sets")

@@ -460,28 +481,43 @@
("valid", collections.valid),
] + collections.tests

table = create_error_table(
table_type=args.error_table,
all_collections=all_collections,
z_table=z_table,
r_max=args.r_max,
valid_batch_size=args.valid_batch_size,
model=model,
loss_fn=loss_fn,
output_args=output_args,
device=device,
)

logging.info("\n" + str(table))
for swa_eval in swas:
epoch = checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=swa_eval,
device=device,
)
model.to(device)
logging.info(f"Loaded model from epoch {epoch}")

table = create_error_table(
table_type=args.error_table,
all_collections=all_collections,
z_table=z_table,
r_max=args.r_max,
valid_batch_size=args.valid_batch_size,
model=model,
loss_fn=loss_fn,
output_args=output_args,
device=device,
)

# Save entire model
model_path = Path(args.checkpoints_dir) / (tag + ".model")
logging.info(f"Saving model to {model_path}")
if args.save_cpu:
model = model.to("cpu")
torch.save(model, model_path)
logging.info("\n" + str(table))

torch.save(model, Path(args.model_dir) / (args.name + ".model"))
# Save entire model
if swa_eval:
model_path = Path(args.checkpoints_dir) / (tag + "_swa.model")
else:
model_path = Path(args.checkpoints_dir) / (tag + ".model")
logging.info(f"Saving model to {model_path}")
if args.save_cpu:
model = model.to("cpu")
torch.save(model, model_path)

if swa_eval:
torch.save(model, Path(args.model_dir) / (args.name + "_swa.model"))
else:
torch.save(model, Path(args.model_dir) / (args.name + ".model"))

logging.info("Done")

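A small usage sketch (the directory and run name are illustrative, not taken from this diff) of what the changes above produce: with --swa enabled the script now writes both the final model and a separate model taken from the SWA phase, and --restart_latest first tries the newest *_swa checkpoint before falling back to the plain one:

```python
from pathlib import Path
import torch

# Hypothetical output locations; substitute your own --model_dir and --name.
model_dir, name = Path("MACE_models"), "mace_run-1"

plain_model = torch.load(model_dir / f"{name}.model", map_location="cpu")
swa_model = torch.load(model_dir / f"{name}_swa.model", map_location="cpu")
print(type(plain_model).__name__, type(swa_model).__name__)
```

The try/except around load_latest(swa=True) is what makes restarts work for runs interrupted before swa.start: if no *_swa.pt checkpoint exists yet, loading falls back to the latest plain checkpoint.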
3 changes: 1 addition & 2 deletions tests/test_calculator.py
@@ -3,10 +3,9 @@
import sys
from pathlib import Path

import pytest

import ase.io
import numpy as np
import pytest
from ase.atoms import Atoms
from ase.calculators.test import gradient_test
from ase.constraints import ExpCellFilter