run_baseline_parallel.py does not seem to restart from checkpoint? #56

Open
SomeBelgianDude opened this issue Oct 16, 2023 · 4 comments

Comments

@SomeBelgianDude

The cause might be line 50 of run_baseline_parallel.py, where file_name is set to:
file_name = 'session_e41c9eff/poke_38207488_steps' #'session_e41c9eff/poke_250871808_steps'
In run_baseline.py it is defined differently:
file_name = 'poke_' #'best_12-7/poke_12_b'

I have no folder 'session_e41c9eff', so this seems to be misconfigured.

Can this be fixed?

@PWhiddy
Owner

PWhiddy commented Oct 17, 2023

Ah, this was left over from some old experiments. If the folder isn't found it will start a new training session. If you point it to an existing one, it will restart training from that checkpoint. This could definitely be made clearer in the code 🤔
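
The fallback described in this comment can be sketched as below. This is a minimal illustration, not the script's actual code: `resolve_checkpoint` is a hypothetical helper, and it assumes checkpoints are stored as `<file_name>.zip`, the way Stable-Baselines3 saves them.

```python
from pathlib import Path
from typing import Optional


def resolve_checkpoint(file_name: str) -> Optional[str]:
    """Return file_name if '<file_name>.zip' exists on disk, else None.

    Hypothetical helper: assumes the configured name has no extension
    and that the checkpoint archive sits next to it with '.zip' appended.
    """
    return file_name if Path(file_name + ".zip").is_file() else None


# The stale default quoted in the issue; the folder is absent on a fresh clone.
checkpoint = resolve_checkpoint("session_e41c9eff/poke_38207488_steps")
if checkpoint is None:
    print("no checkpoint found, starting a new training session")
else:
    print(f"resuming from {checkpoint}")
```

Printing which branch was taken at startup makes a stale `file_name` visible immediately instead of silently starting from scratch.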


@techmore

I was able to get it to restore the latest session but I'm getting an error. Maybe someone sees what I'm doing wrong. I added the code needed at the bottom.

session_4da05e87_main_good/poke_439746560_steps

loading checkpoint
/Users/seandolbec/miniconda3/lib/python3.11/site-packages/stable_baselines3/common/save_util.py:166: UserWarning: Could not deserialize object lr_schedule. Consider using custom_objects argument to replace this object.
Exception: code expected at least 16 arguments, got 15
warnings.warn(
/Users/techmore/miniconda3/lib/python3.11/site-packages/stable_baselines3/common/save_util.py:166: UserWarning: Could not deserialize object clip_range. Consider using custom_objects argument to replace this object.
Exception: code expected at least 16 arguments, got 15
warnings.warn(
/Users/techmore/miniconda3/lib/python3.11/site-packages/stable_baselines3/common/vec_env/patch_gym.py:95: UserWarning: You loaded a model that was trained using OpenAI Gym. We strongly recommend transitioning to Gymnasium by saving that model again.
warnings.warn(
Wrapping the env in a VecTransposeImage.
/Users/techmore/miniconda3/lib/python3.11/site-packages/stable_baselines3/common/base_class.py:752: UserWarning: You are probably loading a model saved with SB3 < 1.7.0, we deactivated exact_match so you can save the model again to avoid issues in the future (see DLR-RM/stable-baselines3#1233 for more info). Original error: Error(s) in loading state_dict for ActorCriticCnnPolicy:
Missing key(s) in state_dict: "pi_features_extractor.cnn.0.weight", "pi_features_extractor.cnn.0.bias", "pi_features_extractor.cnn.2.weight", "pi_features_extractor.cnn.2.bias", "pi_features_extractor.cnn.4.weight", "pi_features_extractor.cnn.4.bias", "pi_features_extractor.linear.0.weight", "pi_features_extractor.linear.0.bias", "vf_features_extractor.cnn.0.weight", "vf_features_extractor.cnn.0.bias", "vf_features_extractor.cnn.2.weight", "vf_features_extractor.cnn.2.bias", "vf_features_extractor.cnn.4.weight", "vf_features_extractor.cnn.4.bias", "vf_features_extractor.linear.0.weight", "vf_features_extractor.linear.0.bias".
Note: the model should still work fine, this only a warning.
warnings.warn(
step: 490 event: 0.00 level: 4.00 heal: 0.08 op_lvl: 0.00 dead: -0.00 badge: 0.00 explore: 1.01 sum: 5.08healed: 0.5135135135135135
step: 531 event: 0.00 level: 2.00 heal: 0.41 op_lvl: 0.00 dead: -0.00 badge: 0.00 explore: 0.59 sum: 3.00healed: 0.7692307692307692
step: 2047 event: 0.00 level: 5.00 heal: 1.64 op_lvl: 0.00 dead: -0.10 badge: 0.00 explore: 2.16 sum: 8.70------------------------------
| time/ | |
| fps | 96 |
| iterations | 1 |
| time_elapsed | 211 |
| total_timesteps | 20480 |
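
The two "Could not deserialize object" warnings in the log above suggest the `custom_objects` argument. A possible sketch follows, assuming the checkpoint was pickled under a different Python version (a common cause of the "code expected at least 16 arguments" message, though not confirmed here). The helper names are hypothetical, and the default values 3e-4 and 0.2 are placeholders rather than the values this project trains with.

```python
from typing import Any, Callable, Dict


def constant_schedule(value: float) -> Callable[[float], float]:
    """Return a schedule that ignores training progress and yields value."""
    def schedule(progress_remaining: float) -> float:
        return value
    return schedule


def make_custom_objects(learning_rate: float, clip_range: float) -> Dict[str, Any]:
    # Keys match the two objects named in the UserWarnings above.
    return {
        "lr_schedule": constant_schedule(learning_rate),
        "clip_range": constant_schedule(clip_range),
    }


def load_with_replacements(path: str, env: Any,
                           learning_rate: float = 3e-4,
                           clip_range: float = 0.2):
    # Imported lazily so the helpers above work without SB3 installed.
    # The default values are assumptions; use the ones from the original run.
    from stable_baselines3 import PPO
    return PPO.load(
        path,
        env=env,
        custom_objects=make_custom_objects(learning_rate, clip_range),
    )
```

After loading this way, saving the model again under the current Python and SB3 versions should keep the warnings from reappearing on the next restart.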

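import glob
import os
import re

def find_latest_session_and_poke():
    # Session folders are named 'session_' plus 8 hex digits, optionally followed by a suffix.
    session_folders = [f for f in os.listdir() if re.match(r'session_[0-9a-fA-F]{8}', f)]
    # Newest first by modification time; the hex id is random, so sorting by name says nothing about age.
    session_folders.sort(key=os.path.getmtime, reverse=True)

    for session_folder in session_folders:
        poke_files = glob.glob(f"{session_folder}/poke_*_steps.zip")
        if poke_files:
            # Compare step counts as numbers; a plain string sort would rank 9 above 10.
            latest_poke_file = max(poke_files, key=lambda p: int(re.search(r'poke_(\d+)_steps', p).group(1)))
            # Remove '.zip' from the filename
            return session_folder, latest_poke_file[:-4]

    return None, None

if __name__ == '__main__':
    ...
    session_folder, latest_poke_file = find_latest_session_and_poke()
    # Check for None before printing; '\n' + None would raise a TypeError.
    if latest_poke_file:
        print('\n' + latest_poke_file)
        print('\nloading checkpoint')
        model = PPO.load(latest_poke_file, env=env)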

@techmore

I just noticed you can read "saves_to_record.txt" and get the last value instead.
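
A rough sketch of that idea, under a loud assumption: the layout of "saves_to_record.txt" is not shown in this thread, so the code assumes one entry per line with the most recent entry last. `last_recorded_save` is a hypothetical helper name.

```python
from pathlib import Path
from typing import Optional


def last_recorded_save(record_path: str = "saves_to_record.txt") -> Optional[str]:
    """Return the last non-empty line of the record file, or None.

    Assumes one entry per line with the newest entry last; the real
    file layout is not confirmed in this thread.
    """
    path = Path(record_path)
    if not path.is_file():
        return None
    lines = [line.strip() for line in path.read_text().splitlines() if line.strip()]
    return lines[-1] if lines else None
```

Returning None for a missing or empty file lets the caller fall back to starting a new session.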
