In [None]:
from IPython.display import display, Image as IPImage
from sb3_contrib import MaskablePPO
from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib.common.wrappers import ActionMasker
from sb3_contrib.common.maskable.utils import get_action_masks

from utils import DATA_DIR
from utils.rl.env import LoLDraftEnv, SelfPlayWrapper, action_mask_fn, FixedRoleDraftEnv
from utils.rl.visualizer import integrate_with_env

# Load the trained MaskablePPO policy saved under DATA_DIR.
model = MaskablePPO.load(f"{DATA_DIR}/lol_draft_ppo")

# Build the draft environment as a wrapper chain:
#   FixedRoleDraftEnv passed through integrate_with_env (from utils.rl.visualizer;
#   presumably adds rendering support — confirm), then SelfPlayWrapper, then
#   ActionMasker, which exposes legal-action masks computed by action_mask_fn.
draft_env = integrate_with_env(FixedRoleDraftEnv)()
draft_env = SelfPlayWrapper(draft_env)
draft_env = ActionMasker(draft_env, action_mask_fn)

# Bind the factory to `draft_env`, not `env`. A `lambda: env` looks up the global
# name at call time, and `env` is rebound to the DummyVecEnv on this very line.
# Any later call of such a factory would return the VecEnv instead of the draft env.
env = DummyVecEnv([lambda: draft_env])

In [None]:

# Play one deterministic self-play draft with the trained policy and render the result.
#
# Roll out on the underlying masked env (`env.envs[0]`) rather than through the
# DummyVecEnv. A VecEnv automatically resets a sub-env as soon as its episode ends,
# so calling render() after seeing `done` on the vectorized path would draw the
# freshly reset (empty) draft instead of the finished one.
eval_env = env.envs[0]

obs, _reset_info = eval_env.reset()
done = False

while not done:
    # Legal-action mask for the current draft state (via ActionMasker).
    action_masks = get_action_masks(eval_env)

    # Pass the mask so the policy cannot choose an illegal action.
    action, _states = model.predict(obs, action_masks=action_masks, deterministic=True)

    # Gymnasium step API: episode ends on either termination or truncation.
    obs, reward, terminated, truncated, info = eval_env.step(action)
    done = terminated or truncated

print("Episode reward(blue side winrate):", reward)

# The env has not been reset after the final step, so this renders the completed draft.
image_data = eval_env.render()

if image_data is not None:
    display(IPImage(data=image_data))
else:
    print("Draft is not complete or visualization is not available.")