Skip to content

Commit 7715c75

Browse files
committed
Add support for checkpoints
1 parent 064c60a commit 7715c75

File tree

2 files changed

+51
-13
lines changed

2 files changed

+51
-13
lines changed

main.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from utils.plot import show_or_save_fig, log_scale
1212
import os
1313
import sys
14+
import orbax.checkpoint as ocp
1415

1516
parser = ArgumentParser()
1617
parser.add_argument('--save_dir', type=str, default=None, help="Specify a path where the data will be stored.")
@@ -47,6 +48,8 @@
4748
parser.add_argument('--BS', type=int, default=512, help="Batch size used for training.")
4849
parser.add_argument('--lr', type=float, default=1e-4, help="Learning rate")
4950
parser.add_argument('--force_clip', type=float, default=1e8, help="Clipping value for the force")
51+
parser.add_argument('--load', type=bool, default=False, const=True, nargs='?',
52+
help="Continue training and load the model from the save_dir.")
5053

5154
parser.add_argument('--seed', type=int, default=1, help="The seed that will be used for initialization")
5255

@@ -55,7 +58,8 @@
5558
parser.add_argument('--dt', type=float, required=True)
5659

5760
# plotting
58-
parser.add_argument('--log_plots', type=bool, default=False, const=True, nargs='?', help="Save plots in log scale where possible")
61+
parser.add_argument('--log_plots', type=bool, default=False, const=True, nargs='?',
62+
help="Save plots in log scale where possible")
5963

6064

6165
def main():
@@ -102,13 +106,36 @@ def main():
102106
state_q = train_state.TrainState.create(apply_fn=setup.model_q.apply, params=params_q, tx=optimizer_q)
103107
loss_fn = setup.construct_loss(state_q, args.gamma, args.BS)
104108

109+
ckpt = {'model': state_q, 'losses': []}
110+
orbax_checkpointer = ocp.PyTreeCheckpointer()
111+
options = ocp.CheckpointManagerOptions(
112+
save_interval_steps=1_000,
113+
max_to_keep=3,
114+
create=True,
115+
cleanup_tmp_directories=True,
116+
save_on_steps=[args.epochs]
117+
)
118+
checkpoint_manager = ocp.CheckpointManager(os.path.abspath(args.save_dir), orbax_checkpointer, options)
119+
120+
if args.load:
121+
if checkpoint_manager.latest_step() is None:
122+
print("Warning: No checkpoint found.")
123+
else:
124+
print('Loading checkpoint:', checkpoint_manager.latest_step())
125+
126+
state_restored = checkpoint_manager.restore(checkpoint_manager.latest_step())
127+
# The model needs to be cast to a TrainState object
128+
state_restored['model'] = checkpoint_manager.restore(checkpoint_manager.latest_step(), items=ckpt)['model']
129+
ckpt = state_restored
130+
105131
key, train_key = jax.random.split(key)
106-
state_q, loss_plot = train(state_q, loss_fn, args.epochs, train_key)
107-
print("Number of potential evaluations", args.BS * args.epochs)
132+
ckpt = train(ckpt, loss_fn, args.epochs, train_key, checkpoint_manager)
133+
state_q = ckpt['model']
134+
print("Total number of potential evaluations", args.BS * args.epochs)
108135

109-
if jnp.isnan(jnp.array(loss_plot)).any():
136+
if jnp.isnan(jnp.array(ckpt['losses'])).any():
110137
print("Warning: Loss contains NaNs")
111-
plt.plot(loss_plot)
138+
plt.plot(ckpt['losses'])
112139
log_scale(args.log_plots, False, True)
113140
show_or_save_fig(args.save_dir, 'loss_plot.pdf')
114141

training/train.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,35 @@
1-
from typing import Callable, Tuple
1+
from typing import Callable, Dict, Any
2+
from flax.training import orbax_utils
23
from flax.training.train_state import TrainState
34
import jax
45
from jax.typing import ArrayLike
6+
from orbax.checkpoint import CheckpointManager
57
from tqdm import trange
68

79

8-
def train(state_q: TrainState, loss_fn: Callable, epochs: int, key: ArrayLike) -> Tuple[TrainState, list[float]]:
10+
def train(ckpt: Any, loss_fn: Callable, epochs: int, key: ArrayLike,
          checkpoint_manager: CheckpointManager) -> Dict[str, Any]:
    """Train the model stored in ``ckpt`` for up to ``epochs`` steps, checkpointing as we go.

    Args:
        ckpt: Checkpoint dict with keys ``'model'`` (a flax ``TrainState``) and
            ``'losses'`` (a list of per-step loss values). Mutated in place and returned.
        loss_fn: Loss ``(params, key) -> scalar``; differentiated w.r.t. ``params``.
        epochs: Total number of training steps (absolute, not additional steps).
        key: JAX PRNG key, split once per step.
        checkpoint_manager: Orbax manager deciding when to persist ``ckpt``.

    Returns:
        The updated ``ckpt`` dict.
    """
    # Resuming from a restored checkpoint: if training already reached the
    # requested number of epochs there is nothing to do.
    if ckpt['model'].step >= epochs:
        return ckpt

    @jax.jit
    def train_step(_state_q: TrainState, _key: ArrayLike) -> tuple[TrainState, float]:
        # One optimizer update: loss and gradients w.r.t. params (argnums=0).
        grad_fn = jax.value_and_grad(loss_fn, argnums=0)
        loss, grads = grad_fn(_state_q.params, _key)
        _state_q = _state_q.apply_gradients(grads=grads)
        return _state_q, loss

    # Start the progress bar at the restored step so resumed runs only
    # perform the remaining steps.
    with trange(ckpt['model'].step, epochs) as pbar:
        for i in pbar:
            key, loc_key = jax.random.split(key)
            ckpt['model'], loss = train_step(ckpt['model'], loc_key)
            pbar.set_postfix(loss=loss)
            # .item() forces the device value to a Python float so the
            # checkpointed loss history is plain data.
            ckpt['losses'].append(loss.item())

            # Steps are 1-based for the checkpoint manager (i + 1), matching
            # the save_interval/save_on_steps options configured by the caller.
            if checkpoint_manager.should_save(i + 1):
                save_args = orbax_utils.save_args_from_target(ckpt)
                checkpoint_manager.save(i + 1, ckpt, save_kwargs={'save_args': save_args})

    # Orbax saves can be asynchronous; block until the last one is durable.
    checkpoint_manager.wait_until_finished()

    return ckpt

0 commit comments

Comments
 (0)