# Creating, Saving and Loading Checkpoints for RL training loops
This notebook shows how to create, save and load checkpoints in reinforcement learning training loops built on PyTorch networks. It covers only the checkpoint logic; the full code of the networks and of the RL training loop is not included.

In [None]:
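import os
import pickle
import pathlib

import torch
import torch.optim as optim  # used by start() to build the AdamW optimizers
from dotmap import DotMap    # third-party package: pip install dotmap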

>First, define where checkpoints are saved and how often (every CHECKPOINT_FREQUENCY training iterations) they are saved.

In [None]:
# define name of folder to save checkpoints, checkpoint frequency and path to save checkpoints
EXPERIMENT_NAME = 'lunarlander-goestomars'
CHECKPOINT_FREQUENCY = 10
ROOT_DIR = 'C:/Users/name/Documents/folder'
BASE_CHECKPOINT_PATH = f"{ROOT_DIR}/checkpoints/{EXPERIMENT_NAME}/"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
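>BASE_CHECKPOINT_PATH is a plain string with forward slashes, and the functions below extend it with `f"{iteration}/"`, so each iteration gets its own folder. If you prefer not to manage separators by hand, pathlib can build the same per-iteration layout. A minimal sketch; `checkpoint_dir` is a hypothetical helper that the rest of the notebook does not use.

```python
from pathlib import Path


def checkpoint_dir(root_dir, experiment_name, iteration):
    """Hypothetical helper: <root_dir>/checkpoints/<experiment_name>/<iteration> as a Path."""
    return Path(root_dir) / "checkpoints" / experiment_name / str(iteration)


example_path = checkpoint_dir("/tmp/rl", "lunarlander-goestomars", 20)
print(example_path.as_posix())  # /tmp/rl/checkpoints/lunarlander-goestomars/20
```

`Path(...).mkdir(parents=True, exist_ok=True)` then creates the folder, just as save_checkpoint() does with its string path.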

>Each checkpoint also stores the hyperparameters, which are defined as follows:

In [None]:
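from dataclasses import dataclass

# Parameters used in training; they are saved with every checkpoint.
# The UPPERCASE defaults are constants assumed to be defined elsewhere in your code.
# @dataclass generates __init__ (so keyword arguments work) and __eq__
# (so load_checkpoint() can compare hp with the saved copy field by field).
@dataclass
class HyperParameters:
    scale_reward:         float = SCALE_REWARD
    min_reward:           float = MIN_REWARD
    batch_size:           int   = BATCH_SIZE
    discount:             float = DISCOUNT
    gae_lambda:           float = GAE_LAMBDA
    ppo_clip:             float = PPO_CLIP
    ppo_epochs:           int   = PPO_EPOCHS
    max_grad_norm:        float = MAX_GRAD_NORM
    entropy_factor:       float = ENTROPY_FACTOR
    actor_learning_rate:  float = ACTOR_LEARNING_RATE
    critic_learning_rate: float = CRITIC_LEARNING_RATE
    rollout_steps:        int   = ROLLOUT_STEPS
    parallel_rollouts:    int   = PARALLEL_ROLLOUTS
    # fields used below and by Actor(); constant names assumed, define them like the others
    recurrent_seq_len:    int   = RECURRENT_SEQ_LEN
    trainable_std_dev:    bool  = TRAINABLE_STD_DEV
    init_log_std_dev:     float = INIT_LOG_STD_DEV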
hp = HyperParameters(parallel_rollouts=32, rollout_steps=2000, batch_size=600)
batch_count = hp.parallel_rollouts * hp.rollout_steps / hp.recurrent_seq_len / hp.batch_size
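>load_checkpoint() resumes only if `hp == checkpoint.hp`. That check is only meaningful if the class defines `__eq__`: a plain class compares objects by identity, so an unpickled copy never equals the current object, and a plain class also rejects the keyword arguments passed to `HyperParameters(...)` above. The standalone sketch below shows what `@dataclass` adds, using a small illustrative subset of fields with made-up defaults (not this notebook's settings) and hypothetical class names.

```python
from dataclasses import dataclass, replace


class PlainHyperParameters:
    batch_size: int = 512


@dataclass
class ExampleHyperParameters:
    # illustrative subset of fields with made-up defaults, not this notebook's settings
    batch_size: int = 512
    discount: float = 0.99
    parallel_rollouts: int = 8


# A plain class has no generated __init__, so keyword arguments are rejected
plain_error = None
try:
    PlainHyperParameters(batch_size=600)
except TypeError as err:
    plain_error = err
    print("TypeError:", err)

# ...and it compares by identity, so two identical-looking objects are not equal
print(PlainHyperParameters() == PlainHyperParameters())  # False

# @dataclass generates a field-by-field __eq__, which the resume check relies on
current = ExampleHyperParameters(parallel_rollouts=32, batch_size=600)
from_checkpoint = ExampleHyperParameters(parallel_rollouts=32, batch_size=600)
print(current == from_checkpoint)  # True

# Changing any single field makes the check fail
print(replace(current, discount=0.95) == from_checkpoint)  # False
```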

>Next are the functions to save and load the checkpoints.

- save_checkpoint() writes one folder per iteration. It pickles a DotMap holding the environment, iteration and hyperparameters, pickles the Actor and Critic classes, and saves the state dicts of both networks and their optimizers. More items can be added, such as the termination conditions used to stop training early.

- load_checkpoint() loads a saved checkpoint, after checking that the environment and hyperparameters match the current ones, so that training can continue from it.

- load_trained_model() loads only the actor and critic, e.g. for testing or evaluation.

- get_last_checkpoint_iteration() returns the iteration of the latest saved checkpoint (0 if none exists), so that the iteration count continues from there.
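>The non-tensor part of a checkpoint can be tried out without PyTorch or DotMap. The sketch below is a simplified stand-in, not the notebook's code: it pickles a plain dict (instead of a DotMap) into `<base>/<iteration>/parameters.pt` inside a temporary folder and reads it back. Because load_checkpoint() returns `checkpoint.stop_conditions`, the stop conditions have to be written into this same file. The helper names `save_metadata`/`load_metadata`, the environment name and the stop-condition values are hypothetical.

```python
import os
import pickle
import tempfile


def save_metadata(base_path, iteration, env, hp, stop_conditions):
    """Pickle the non-tensor checkpoint data to <base_path>/<iteration>/parameters.pt."""
    checkpoint_path = os.path.join(base_path, str(iteration))
    os.makedirs(checkpoint_path, exist_ok=True)
    metadata = {
        "env": env,
        "iteration": iteration,
        "hp": hp,
        # stored here so that loading can hand the stop conditions back
        "stop_conditions": stop_conditions,
    }
    with open(os.path.join(checkpoint_path, "parameters.pt"), "wb") as f:
        pickle.dump(metadata, f)


def load_metadata(base_path, iteration):
    """Read back the dict written by save_metadata()."""
    with open(os.path.join(base_path, str(iteration), "parameters.pt"), "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as base:
    example_hp = {"batch_size": 600, "rollout_steps": 2000}  # stand-in for HyperParameters
    save_metadata(base, 10, "LunarLander-v2", example_hp, {"max_iterations": 300})
    folders = sorted(os.listdir(base))
    loaded = load_metadata(base, 10)

print(folders)                    # ['10']
print(loaded["stop_conditions"])  # {'max_iterations': 300}
```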

In [None]:
def save_checkpoint(actor, critic, actor_optimizer, critic_optimizer, iteration, stop_conditions):
    """Save a training checkpoint to BASE_CHECKPOINT_PATH/<iteration>/."""
    print("Saving Checkpoint")
    # ENV, hp, Actor and Critic are globals defined in your own training code
    checkpoint = DotMap()
    checkpoint.env = ENV
    checkpoint.iteration = iteration
    checkpoint.hp = hp
    checkpoint.stop_conditions = stop_conditions  # returned again by load_checkpoint()
    CHECKPOINT_PATH = BASE_CHECKPOINT_PATH + f"{iteration}/"
    pathlib.Path(CHECKPOINT_PATH).mkdir(parents=True, exist_ok=True)
    with open(CHECKPOINT_PATH + "parameters.pt", "wb") as f:
        pickle.dump(checkpoint, f)
    # pickle stores classes by reference (module and name), not their source code
    with open(CHECKPOINT_PATH + "actor_class.pt", "wb") as f:
        pickle.dump(Actor, f)
    with open(CHECKPOINT_PATH + "critic_class.pt", "wb") as f:
        pickle.dump(Critic, f)
    # network weights and optimizer states are saved as state dicts
    torch.save(actor.state_dict(), CHECKPOINT_PATH + "actor.pt")
    torch.save(critic.state_dict(), CHECKPOINT_PATH + "critic.pt")
    torch.save(actor_optimizer.state_dict(), CHECKPOINT_PATH + "actor_optimizer.pt")
    torch.save(critic_optimizer.state_dict(), CHECKPOINT_PATH + "critic_optimizer.pt")
    
def load_checkpoint(iteration):
    """Load the checkpoint saved at `iteration` so that training can resume from it."""
    print("Loading Checkpoint")
    CHECKPOINT_PATH = BASE_CHECKPOINT_PATH + f"{iteration}/"
    with open(CHECKPOINT_PATH + "parameters.pt", "rb") as f:
        checkpoint = pickle.load(f)

    # resuming only makes sense with the same environment and hyperparameters
    # (hp must define __eq__, e.g. via @dataclass, for this comparison to work)
    assert ENV == checkpoint.env, "To resume training, the environment must match the current settings."
    assert hp == checkpoint.hp, "To resume training, the hyperparameters must match the current settings."

    # map_location places every loaded tensor on DEVICE
    actor_state_dict = torch.load(CHECKPOINT_PATH + "actor.pt", map_location=torch.device(DEVICE))
    critic_state_dict = torch.load(CHECKPOINT_PATH + "critic.pt", map_location=torch.device(DEVICE))
    actor_optimizer_state_dict = torch.load(CHECKPOINT_PATH + "actor_optimizer.pt", map_location=torch.device(DEVICE))
    critic_optimizer_state_dict = torch.load(CHECKPOINT_PATH + "critic_optimizer.pt", map_location=torch.device(DEVICE))

    return (actor_state_dict, critic_state_dict,
            actor_optimizer_state_dict, critic_optimizer_state_dict,
            checkpoint.stop_conditions)

def load_trained_model(iteration):
    """Load only the actor and critic saved at `iteration`, e.g. for testing or evaluation."""
    print("Loading Trained Model")
    # get_env_space(), Actor and Critic come from your own code; actor and critic are PyTorch NNs
    obsv_dim, action_dim, continuous_action_space = get_env_space()
    actor = Actor(obsv_dim,
                  action_dim,
                  continuous_action_space=continuous_action_space,
                  trainable_std_dev=hp.trainable_std_dev,
                  init_log_std_dev=hp.init_log_std_dev).to(DEVICE)
    critic = Critic(obsv_dim).to(DEVICE)

    # only the network weights are needed here, not parameters.pt or the optimizer states
    CHECKPOINT_PATH = BASE_CHECKPOINT_PATH + f"{iteration}/"
    actor_state_dict = torch.load(CHECKPOINT_PATH + "actor.pt", map_location=torch.device(DEVICE))
    critic_state_dict = torch.load(CHECKPOINT_PATH + "critic.pt", map_location=torch.device(DEVICE))

    actor.load_state_dict(actor_state_dict, strict=True)
    critic.load_state_dict(critic_state_dict, strict=True)

    return actor, critic

def get_last_checkpoint_iteration():
    """Return the highest saved checkpoint iteration, or 0 if no checkpoint exists."""
    if not os.path.isdir(BASE_CHECKPOINT_PATH):
        return 0
    # only numeric folder names are checkpoint iterations; default=0 covers an empty folder
    iterations = [int(dirname) for dirname in os.listdir(BASE_CHECKPOINT_PATH) if dirname.isdigit()]
    return max(iterations, default=0)


>start() creates the actor and critic networks and their optimizers and, if a checkpoint exists, restores them from the latest one.

In [None]:
def start():
    """Create the actor, critic and optimizers, resuming from the latest checkpoint if one exists."""
    max_checkpoint_iteration = get_last_checkpoint_iteration()
    stop_conditions = None  # placeholder: replace with your own initial stop conditions

    obsv_dim, action_dim, continuous_action_space = get_env_space()
    # actor and critic are PyTorch NNs; move them to DEVICE before creating the optimizers
    actor = Actor(obsv_dim,
                  action_dim,
                  continuous_action_space=continuous_action_space,
                  trainable_std_dev=hp.trainable_std_dev,
                  init_log_std_dev=hp.init_log_std_dev).to(DEVICE)
    critic = Critic(obsv_dim).to(DEVICE)

    actor_optimizer = torch.optim.AdamW(actor.parameters(), lr=hp.actor_learning_rate)
    critic_optimizer = torch.optim.AdamW(critic.parameters(), lr=hp.critic_learning_rate)

    # if a checkpoint exists, load it
    if max_checkpoint_iteration > 0:
        (actor_state_dict, critic_state_dict,
         actor_optimizer_state_dict, critic_optimizer_state_dict,
         stop_conditions) = load_checkpoint(max_checkpoint_iteration)

        actor.load_state_dict(actor_state_dict, strict=True)
        critic.load_state_dict(critic_state_dict, strict=True)
        actor_optimizer.load_state_dict(actor_optimizer_state_dict)
        critic_optimizer.load_state_dict(critic_optimizer_state_dict)

        # move optimizer state tensors to DEVICE (a safeguard: load_state_dict()
        # already casts the state to the device of the parameters)
        for optimizer in (actor_optimizer, critic_optimizer):
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(DEVICE)

    return (actor, critic, actor_optimizer, critic_optimizer,
            max_checkpoint_iteration, stop_conditions)

>To use these functions, call save_checkpoint() inside the training loop. A checkpoint is written whenever the iteration counter is a multiple of CHECKPOINT_FREQUENCY, and the iteration number is used as the name of that checkpoint's folder.

In [None]:
# add this to your training loop to create and save checkpoints
# (actor, critic, the optimizers and the starting iteration are returned by start())
if iteration % CHECKPOINT_FREQUENCY == 0:
    save_checkpoint(actor, critic, actor_optimizer, critic_optimizer, iteration, stop_conditions)
iteration += 1
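>The snippet above sits inside a loop like the following sketch. The function `run_training_loop` and its two callbacks are hypothetical placeholders for your own training step, save_checkpoint() and early-stopping logic; the point is the checkpoint schedule. When resuming from the iteration returned by start() (always a multiple of CHECKPOINT_FREQUENCY), the first pass saves that iteration again, overwriting its folder, and from then on a checkpoint is written every `checkpoint_frequency` iterations.

```python
def run_training_loop(start_iteration, max_iterations, checkpoint_frequency,
                      save_checkpoint_fn, should_stop_fn):
    """Hypothetical skeleton: the callbacks stand in for save_checkpoint() and your stop conditions."""
    iteration = start_iteration
    while iteration < max_iterations:
        # ... collect rollouts and update the actor and critic here ...
        if iteration % checkpoint_frequency == 0:
            save_checkpoint_fn(iteration)
        if should_stop_fn(iteration):
            break
        iteration += 1
    return iteration


# Resume at iteration 20 and run until 55 without early stopping
saved = []
last = run_training_loop(20, 55, 10, saved.append, lambda it: False)
print(saved)  # [20, 30, 40, 50]
print(last)   # 55

# Same, but a stop condition triggers at iteration 33
saved_early = []
last_early = run_training_loop(20, 55, 10, saved_early.append, lambda it: it >= 33)
print(saved_early, last_early)  # [20, 30] 33
```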

>After training, you can either resume training from the latest checkpoint or load a trained actor and critic for testing. The following code loads the actor and critic from each of several checkpoints.

In [None]:
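# Load the trained actor and critic saved at each checkpoint iteration
for iteration in range(0, 300, CHECKPOINT_FREQUENCY):
    actor, critic = load_trained_model(iteration)
    # evaluate or test this actor/critic pair here before loading the next one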
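>Instead of listing the iterations up front, the saved iterations can be read from the checkpoint folder. A minimal sketch with a hypothetical helper `list_checkpoint_iterations`: it keeps only numeric folder names and sorts them as integers, so `100` comes after `20` rather than between `10` and `20` as a string sort would place it.

```python
import os
import tempfile


def list_checkpoint_iterations(base_path):
    """Hypothetical helper: saved iteration numbers in ascending numeric order."""
    if not os.path.isdir(base_path):
        return []
    return sorted(
        int(name)
        for name in os.listdir(base_path)
        # keep only numeric folder names, ignoring stray files or notes
        if name.isdigit() and os.path.isdir(os.path.join(base_path, name))
    )


with tempfile.TemporaryDirectory() as base:
    for name in ["0", "10", "20", "100", "notes"]:
        os.makedirs(os.path.join(base, name))
    found = list_checkpoint_iterations(base)

print(found)  # [0, 10, 20, 100]
```

With the notebook's globals, the loading loop could then iterate over `list_checkpoint_iterations(BASE_CHECKPOINT_PATH)` and pass each value to load_trained_model().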