|
11 | 11 | from utils.plot import show_or_save_fig, log_scale
|
12 | 12 | import os
|
13 | 13 | import sys
|
| 14 | +import orbax.checkpoint as ocp |
14 | 15 |
|
15 | 16 | parser = ArgumentParser()
|
16 | 17 | parser.add_argument('--save_dir', type=str, default=None, help="Specify a path where the data will be stored.")
|
|
47 | 48 | parser.add_argument('--BS', type=int, default=512, help="Batch size used for training.")
|
48 | 49 | parser.add_argument('--lr', type=float, default=1e-4, help="Learning rate")
|
49 | 50 | parser.add_argument('--force_clip', type=float, default=1e8, help="Clipping value for the force")
|
# NOTE: argparse's `type=bool` is a trap — bool("False") is True, so any
# explicit value (e.g. `--load False`) would still enable loading.
# `action='store_true'` gives the intended flag semantics:
# flag absent -> False, `--load` -> True.
parser.add_argument('--load', action='store_true',
                    help="Continue training and load the model from the save_dir.")
50 | 53 |
|
51 | 54 | parser.add_argument('--seed', type=int, default=1, help="The seed that will be used for initialization")
|
52 | 55 |
|
|
55 | 58 | parser.add_argument('--dt', type=float, required=True)
|
56 | 59 |
|
57 | 60 | # plotting
|
58 |
| -parser.add_argument('--log_plots', type=bool, default=False, const=True, nargs='?', help="Save plots in log scale where possible") |
# `action='store_true'` replaces the buggy `type=bool` (bool("False") == True,
# so every explicit value was truthy). Usage is unchanged for the common case:
# omit the flag for False, pass `--log_plots` for True.
parser.add_argument('--log_plots', action='store_true',
                    help="Save plots in log scale where possible")
59 | 63 |
|
60 | 64 |
|
61 | 65 | def main():
|
@@ -102,13 +106,36 @@ def main():
|
102 | 106 | state_q = train_state.TrainState.create(apply_fn=setup.model_q.apply, params=params_q, tx=optimizer_q)
|
103 | 107 | loss_fn = setup.construct_loss(state_q, args.gamma, args.BS)
|
104 | 108 |
|
| 109 | + ckpt = {'model': state_q, 'losses': []} |
| 110 | + orbax_checkpointer = ocp.PyTreeCheckpointer() |
| 111 | + options = ocp.CheckpointManagerOptions( |
| 112 | + save_interval_steps=1_000, |
| 113 | + max_to_keep=3, |
| 114 | + create=True, |
| 115 | + cleanup_tmp_directories=True, |
| 116 | + save_on_steps=[args.epochs] |
| 117 | + ) |
| 118 | + checkpoint_manager = ocp.CheckpointManager(os.path.abspath(args.save_dir), orbax_checkpointer, options) |
| 119 | + |
| 120 | + if args.load: |
| 121 | + if checkpoint_manager.latest_step() is None: |
| 122 | + print("Warning: No checkpoint found.") |
| 123 | + else: |
| 124 | + print('Loading checkpoint:', checkpoint_manager.latest_step()) |
| 125 | + |
| 126 | + state_restored = checkpoint_manager.restore(checkpoint_manager.latest_step()) |
| 127 | + # The model needs to be cast to a TrainState object |
| 128 | + state_restored['model'] = checkpoint_manager.restore(checkpoint_manager.latest_step(), items=ckpt)['model'] |
| 129 | + ckpt = state_restored |
| 130 | + |
105 | 131 | key, train_key = jax.random.split(key)
|
106 |
| - state_q, loss_plot = train(state_q, loss_fn, args.epochs, train_key) |
107 |
| - print("Number of potential evaluations", args.BS * args.epochs) |
| 132 | + ckpt = train(ckpt, loss_fn, args.epochs, train_key, checkpoint_manager) |
| 133 | + state_q = ckpt['model'] |
| 134 | + print("Total number of potential evaluations", args.BS * args.epochs) |
108 | 135 |
|
109 |
| - if jnp.isnan(jnp.array(loss_plot)).any(): |
| 136 | + if jnp.isnan(jnp.array(ckpt['losses'])).any(): |
110 | 137 | print("Warning: Loss contains NaNs")
|
111 |
| - plt.plot(loss_plot) |
| 138 | + plt.plot(ckpt['losses']) |
112 | 139 | log_scale(args.log_plots, False, True)
|
113 | 140 | show_or_save_fig(args.save_dir, 'loss_plot.pdf')
|
114 | 141 |
|
|
0 commit comments