Install Dependencies

In [None]:
%pip install gymnasium[box2d]

In [None]:
%pip install swig

In [None]:
%pip install stable-baselines3[extra]

In [None]:
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

Import Dependencies

In [None]:
import os
import gymnasium
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3 import DQN


Constants

In [None]:
# --- Environment ---
# NOTE: gymnasium>=1.0 renamed this environment to 'LunarLander-v3' and rejects 'LunarLander-v2',
# so this id requires gymnasium<1.0.
ENV_NAME = 'LunarLander-v2'
# render_mode passed to gymnasium.make for the single env used by every mode.
ENV_RENDER_MODE = 'rgb_array'
# Mode passed to VecEnv.render() when watching a trained agent.
VEC_ENV_RENDER_MODE = 'human'

# --- Callback / logging ---
# Built with os.path.join so the paths are relative to the working directory and OS-independent
# (backslash literals such as '\logs' contain invalid escape sequences and point at the drive root on Windows).
CALLBACK_LOG_DIR = os.path.join('logs', 'lunar-lander')
CALLBACK_CHECKPOINT_DIR = os.path.join('models', 'lunar-lander')
# A checkpoint is written every CALLBACK_CHECK_FREQ calls to the callback's _on_step.
CALLBACK_CHECK_FREQ = 10000
# File name of the rolling checkpoint (overwritten on every save).
CALLBACK_ON_TRAINING_MODEL = 'lunar_lander_v2_training_model'
# Prefix of the file saved when training finishes; the timestep count is appended to it.
CALLBACK_ON_TRAINING_END = 'lunar_lander_v2_training_end_'

VERBOSE = 1

# --- Algorithm (DQN) ---
# Model file loaded by the "render" and "load" modes.
ALGORITHM_BEST_MODEL_NAME = 'lunar_lander_v2_best_model.zip'
ALGORITHM_POLICY = 'MlpPolicy'
ALGORITHM_DEVICE = 'cuda'
ALGORITHM_TOTAL_TIMESTEPS = 1000000

# --- Rendering ---
ALGORITHM_RENDER_MODEL = False
ALGORITHM_RENDER_EPISODES = 10
# Upper bound on the number of steps rendered per episode.
ALGORITHM_RENDER_STEPS = 10000
# NOTE: with deterministic=False, SB3's DQN.predict applies epsilon-greedy exploration,
# so some rendered actions may be random.
ALGORITHM_PREDICT_DETERMINISTIC = False

# --- Mode selection (checked in order: render, then load, then new) ---
ALGORITHM_NEW_MODEL = True
ALGORITHM_LOAD_MODEL = False


MESSAGE_RENDERING_MODEL = 'RENDERING MODEL'
MESSAGE_LOADING_MODEL = 'LOADING MODEL'
MESSAGE_TRAINING_NEW_MODEL = 'NEW MODEL'


Creating the Env

In [None]:
# One Gymnasium environment, shared by the rendering, resume-training and new-training branches below.
env = gymnasium.make(id=ENV_NAME, render_mode=ENV_RENDER_MODE)

Creating Callback

In [None]:
class TrainAndLoggingCallback(BaseCallback):
    """Stable-Baselines3 callback that checkpoints the model during and after training.

    Every ``check_freq`` calls to ``_on_step`` the model is saved to
    ``<save_path>/<CALLBACK_ON_TRAINING_MODEL>`` (overwriting the previous
    checkpoint). When training ends it is saved once more to
    ``<save_path>/<CALLBACK_ON_TRAINING_END>_<ALGORITHM_TOTAL_TIMESTEPS>``.

    Args:
        check_freq: Positive number of ``_on_step`` calls between checkpoints.
        save_path: Directory for the checkpoints (created if missing). If ``None``,
            the callback saves nothing.
        verbose: Verbosity level forwarded to ``BaseCallback``.

    Raises:
        ValueError: If ``check_freq`` is not positive.
    """

    def __init__(self, check_freq, save_path, verbose=VERBOSE):
        super().__init__(verbose)
        if check_freq <= 0:
            # ``n_calls % check_freq`` is undefined (or meaningless) for non-positive values.
            raise ValueError(f'check_freq must be positive, got {check_freq!r}')
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        # Ensure the checkpoint directory exists before any save is attempted.
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, CALLBACK_ON_TRAINING_MODEL)
            self.model.save(model_path)

        # Returning False would abort training; this callback never stops it.
        return True

    def _on_training_end(self):
        if self.save_path is None:
            return
        # Strip any trailing '_' from the prefix so exactly one separator precedes the timestep count
        # (the prefix constant already ends with '_', which previously produced '__').
        prefix = CALLBACK_ON_TRAINING_END.rstrip('_')
        model_path = os.path.join(self.save_path, f'{prefix}_{ALGORITHM_TOTAL_TIMESTEPS}')
        self.model.save(model_path)

callback = TrainAndLoggingCallback(check_freq=CALLBACK_CHECK_FREQ, save_path=CALLBACK_CHECKPOINT_DIR)

Loading, Training and Rendering

In [None]:
# Run exactly one mode, selected by the flags in the constants cell and checked in this order:
# render a saved model, resume training a saved model, or train a new model from scratch.
# `model` stays None when every flag is False.
model = None

if ALGORITHM_RENDER_MODEL:
    print(MESSAGE_RENDERING_MODEL)
    model = DQN.load(ALGORITHM_BEST_MODEL_NAME, env=env)
    # SB3 wraps `env` in a VecEnv: reset() returns only the observation batch and
    # step() returns (obs, rewards, dones, infos), each batched over the envs.
    vec_env = model.get_env()

    try:
        for _episode in range(ALGORITHM_RENDER_EPISODES):
            obs = vec_env.reset()
            # ALGORITHM_RENDER_STEPS only caps the length of one episode.
            for _step in range(ALGORITHM_RENDER_STEPS):
                action, _state = model.predict(obs, deterministic=ALGORITHM_PREDICT_DETERMINISTIC)
                obs, _rewards, dones, _infos = vec_env.step(action)
                vec_env.render(VEC_ENV_RENDER_MODE)
                # A VecEnv auto-resets finished episodes, so without this break one outer
                # iteration would keep playing new episodes for ALGORITHM_RENDER_STEPS steps.
                # There is a single env here, so its done flag is dones[0].
                if dones[0]:
                    break
    finally:
        # Close the env even if the viewer loop is interrupted (e.g. KeyboardInterrupt).
        env.close()

elif ALGORITHM_LOAD_MODEL:
    # Continue training a previously saved model on the same env.
    print(MESSAGE_LOADING_MODEL)
    model = DQN.load(ALGORITHM_BEST_MODEL_NAME, env=env, device=ALGORITHM_DEVICE)
    model.learn(total_timesteps=ALGORITHM_TOTAL_TIMESTEPS, callback=callback)

elif ALGORITHM_NEW_MODEL:
    print(MESSAGE_TRAINING_NEW_MODEL)
    model = DQN(
        ALGORITHM_POLICY,
        env,
        tensorboard_log=CALLBACK_LOG_DIR,
        verbose=VERBOSE,
        device=ALGORITHM_DEVICE,
    )
    model.learn(total_timesteps=ALGORITHM_TOTAL_TIMESTEPS, callback=callback)