In [None]:
from IPython.display import display, Image as IPImage
from sb3_contrib import MaskablePPO
from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib.common.wrappers import ActionMasker
from sb3_contrib.common.maskable.utils import get_action_masks

from utils import DATA_DIR
from utils.rl.env import LoLDraftEnv, SelfPlayWrapper, action_mask_fn, FixedRoleDraftEnv
from utils.rl.visualizer import integrate_with_env

# Restore the trained MaskablePPO policy from the data directory
model = MaskablePPO.load(f"{DATA_DIR}/lol_draft_ppo")


def _make_self_play_env():
    """Build the draft environment with visualizer, self-play and action-mask wrappers."""
    draft_env = integrate_with_env(FixedRoleDraftEnv)()
    return ActionMasker(SelfPlayWrapper(draft_env), action_mask_fn)


# DummyVecEnv calls the factory once while it is being constructed
env = DummyVecEnv([_make_self_play_env])

In [None]:

# Play a single self-play draft to completion with the trained policy.
obs = env.reset()

while True:
    # The policy must only choose among currently legal actions
    action_masks = get_action_masks(env)
    action, _states = model.predict(obs, action_masks=action_masks, deterministic=True)

    # The vectorized env returns one entry per sub-environment for reward/done/info
    obs, reward, done, info = env.step(action)

    # Only one sub-environment exists, so index 0 is the episode we are watching
    if not done[0]:
        continue

    print("Episode reward(blue side winrate):", reward[0])

    # NOTE(review): DummyVecEnv resets a sub-environment as soon as its episode ends,
    # so this render happens after that reset — confirm the visualizer still returns
    # the finished draft at this point.
    image_data = env.envs[0].render()

    if image_data is not None:
        display(IPImage(data=image_data))
    else:
        print("Draft is not complete or visualization is not available.")
    break

In [None]:
from IPython.display import display, Image as IPImage, clear_output
import numpy as np
from typing import List, Optional
from difflib import get_close_matches
from utils.rl.champions import Champion


# NOTE(review): `gym` is not imported in any of the cells visible here; the module that
# provides `Wrapper` for the draft environment must be in scope before this cell runs.
class HumanPlayWrapper(gym.Wrapper):
    """Interactive draft wrapper: team 0 (blue) actions are passed through, team 1 (red) is typed in.

    On a blue turn the action handed to ``step`` is forwarded to the wrapped
    environment. On a red turn that action is ignored and the user is prompted
    for a champion name instead. Bans made by either side are recorded on the
    wrapper so they can be shown by ``display_state``.

    Args:
        env: Draft environment exposing ``get_action_info()``, ``get_action_mask()``,
            ``roles``, ``blue_picks``/``red_picks`` and
            ``blue_roles_picked``/``red_roles_picked``.
        role_names: Labels used when listing the roles that are still open.
            Defaults to ``["TOP", "JUNGLE", "MID", "BOT", "UTILITY"]``.
    """

    def __init__(self, env, role_names: Optional[List[str]] = None):
        super().__init__(env)
        self.role_names = role_names or ["TOP", "JUNGLE", "MID", "BOT", "UTILITY"]

        # Lookup tables between champion ids and (lower-cased) display names
        self.id_to_name = {champion.id: champion.display_name for champion in Champion}
        self.name_to_id = {
            champion.display_name.lower(): champion.id for champion in Champion
        }
        # Common abbreviations resolve to the same ids as the full names
        self.name_to_id.update(
            {
                "tf": self.name_to_id["twisted fate"],
                "mf": self.name_to_id["miss fortune"],
                "asol": self.name_to_id["aurelion sol"],
                "j4": self.name_to_id["jarvan iv"],
                "tk": self.name_to_id["tahm kench"],
            }
        )

        # Champion ids banned so far, per side
        self.blue_bans = []
        self.red_bans = []

    def reset(self, **kwargs):
        """Reset the wrapped environment, clear the recorded bans and show the initial state.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        self.blue_bans = []
        self.red_bans = []
        obs = self.env.reset(**kwargs)
        self.display_state()
        return obs

    def _print_bans(self, header: str, bans) -> None:
        """Print ``header`` and the comma-separated ban names; prints nothing when ``bans`` is empty."""
        if bans:
            print(header)
            print(", ".join(self.id_to_name[ban_id] for ban_id in bans))

    def _print_picks(self, header: str, picks, roles_picked) -> None:
        """Print ``header`` and one ``role: champion`` line per filled role; nothing if no pick exists."""
        if not np.any(picks):
            return
        print(header)
        for i, role in enumerate(self.env.roles):
            if roles_picked[i]:
                champ_id = np.argmax(picks[i])
                champ_name = self.id_to_name.get(champ_id, f"Champion {champ_id}")
                print(f"{role}: {champ_name}")

    def display_state(self):
        """Clear the cell output and print phase, side to move, bans, picks and legal champions."""
        clear_output(wait=True)
        action_info = self.env.get_action_info()
        phase = "BAN" if action_info["phase"] == 0 else "PICK"
        team = "RED" if action_info["team"] == 1 else "BLUE"

        print("\nCurrent Draft State:")
        print(f"Phase: {phase}")
        print(f"Team to move: {team}")

        self._print_bans("\nBlue Team Bans:", self.blue_bans)
        self._print_bans("\nRed Team Bans:", self.red_bans)

        self._print_picks(
            "\nBlue Team Picks:", self.env.blue_picks, self.env.blue_roles_picked
        )
        self._print_picks(
            "\nRed Team Picks:", self.env.red_picks, self.env.red_roles_picked
        )

        # During the pick phase, list the roles the side to move has not filled yet
        if phase == "PICK":
            roles_picked = (
                self.env.red_roles_picked
                if action_info["team"] == 1
                else self.env.blue_roles_picked
            )
            available_roles = [
                role for i, role in enumerate(self.role_names) if roles_picked[i] == 0
            ]
            print(f"\nAvailable roles: {', '.join(available_roles)}")

        # Only mask entries that correspond to a known champion id are listed
        action_mask = self.env.get_action_mask()
        valid_champs = [
            champ_id
            for champ_id in np.where(action_mask == 1)[0]
            if champ_id in self.id_to_name
        ]

        if phase == "BAN":
            print("\nAvailable champions to ban:")
        else:
            print("\nAvailable champions to pick:")

        # Alphabetical order is easier to scan than id order
        for champ_id in sorted(valid_champs, key=lambda x: self.id_to_name[x]):
            print(f"- {self.id_to_name[champ_id]}")

    @staticmethod
    def _as_champion_id(action) -> int:
        """Reduce a scalar, 0-d array, list or 1-d array action to a plain ``int`` id."""
        return int(np.asarray(action).reshape(-1)[0])

    def _prompt_for_champion(self, phase, valid_actions):
        """Ask the user for a champion until a currently legal one is chosen.

        Args:
            phase: Current phase from ``get_action_info()`` (0 means ban).
            valid_actions: Set of action ids the action mask currently allows.

        Returns:
            The id of the chosen champion; always a member of ``valid_actions``.

        Raises:
            Whatever ``input()`` raises when no interactive stdin is available.
            Retrying in that situation could never succeed, so it is not caught.
        """
        if phase == 0:
            prompt = "\nEnter champion name to ban: "
        else:
            prompt = "\nEnter champion name to pick: "

        while True:
            matches = self.find_champion(input(prompt))

            if not matches:
                print("No champions found matching that name. Please try again.")
                continue

            if len(matches) > 1:
                valid_matches = [m for m in matches if m[0] in valid_actions]
                if not valid_matches:
                    print("None of the matches are available. Please try again.")
                    continue

                print("\nDid you mean:")
                for i, (_, champ_name) in enumerate(valid_matches, 1):
                    print(f"{i}. {champ_name}")

                choice = input("Enter number (or press Enter to search again): ")
                if not choice:
                    continue

                try:
                    index = int(choice)
                except ValueError:
                    index = 0
                # Reject 0 and negatives explicitly: Python indexing would otherwise
                # wrap around and silently pick from the end of the list.
                if not 1 <= index <= len(valid_matches):
                    print("Invalid choice. Please try again.")
                    continue
                chosen_id = valid_matches[index - 1][0]
            else:
                chosen_id = matches[0][0]

            if chosen_id in valid_actions:
                return chosen_id

            print(
                f"{self.id_to_name[chosen_id]} is not available. Please choose from the valid champions listed above."
            )

    def step(self, action):
        """Advance the draft by one action, prompting the user when it is red's turn.

        Args:
            action: Action for the blue side. Ignored when it is red's turn.

        Returns:
            Whatever the wrapped environment's ``step`` returns.
        """
        action_info = self.env.get_action_info()
        phase = action_info["phase"]
        current_team = action_info["team"]

        if current_team == 0:  # AI's turn (blue side)
            result = self.env.step(action)
            if phase == 0:
                self.blue_bans.append(self._as_champion_id(action))
            self.display_state()
            return result

        # Human's turn (red side)
        self.display_state()
        action_mask = self.env.get_action_mask()
        valid_actions = set(np.where(action_mask == 1)[0])
        chosen_id = self._prompt_for_champion(phase, valid_actions)

        # Step first so a failing step never leaves a ban recorded that did not happen
        result = self.env.step(chosen_id)
        if phase == 0:
            self.red_bans.append(chosen_id)
        return result

    def find_champion(self, search_term: str) -> List[tuple]:
        """Find champions matching the search term.

        An exact (case-insensitive) name or abbreviation returns a single result.
        Otherwise up to three close matches are returned, each champion at most once.

        Returns:
            List of ``(champion_id, display_name)`` tuples; empty when nothing matches.
        """
        search_term = search_term.lower().strip()
        if not search_term:
            return []

        # Direct match with name or abbreviation
        if search_term in self.name_to_id:
            champ_id = self.name_to_id[search_term]
            return [(champ_id, self.id_to_name[champ_id])]

        # A full name and its abbreviation can both be close matches; keep the first only
        results = []
        seen_ids = set()
        for name in get_close_matches(
            search_term, list(self.name_to_id), n=3, cutoff=0.6
        ):
            champ_id = self.name_to_id[name]
            if champ_id not in seen_ids:
                seen_ids.add(champ_id)
                results.append((champ_id, self.id_to_name[champ_id]))
        return results


# Function to display the current state of the draft
def display_draft_state(env):
    """Print the phase, the side to move and every pick made so far.

    Args:
        env: Object providing ``get_action_info()`` and an ``env`` attribute that
            holds ``roles``, ``blue_picks``/``red_picks`` and
            ``blue_roles_picked``/``red_roles_picked``.
    """
    info = env.get_action_info()
    phase_label = "PICK" if info["phase"] != 0 else "BAN"
    team_label = "BLUE" if info["team"] != 1 else "RED"

    print("\nCurrent Draft State:")
    print(f"Phase: {phase_label}")
    print(f"Team to move: {team_label}")

    # Same report for both sides; attributes are looked up per side, blue first
    for header, side in (("\nBlue Team Picks:", "blue"), ("\nRed Team Picks:", "red")):
        side_picks = getattr(env.env, f"{side}_picks")
        if not np.any(side_picks):
            continue
        print(header)
        for slot, role in enumerate(env.env.roles):
            if not getattr(env.env, f"{side}_roles_picked")[slot]:
                continue
            champ = np.argmax(getattr(env.env, f"{side}_picks")[slot])
            print(f"{role}: Champion {champ}")

In [None]:
from IPython.display import display, Image as IPImage, clear_output
from sb3_contrib import MaskablePPO
from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib.common.wrappers import ActionMasker
from sb3_contrib.common.maskable.utils import get_action_masks

# Restore the trained MaskablePPO policy from the data directory
model = MaskablePPO.load(f"{DATA_DIR}/lol_draft_ppo")


def _make_human_play_env():
    """Build the draft environment with visualizer, human-input and action-mask wrappers."""
    draft_env = integrate_with_env(FixedRoleDraftEnv)()
    return ActionMasker(HumanPlayWrapper(draft_env), action_mask_fn)


# DummyVecEnv calls the factory once while it is being constructed
env = DummyVecEnv([_make_human_play_env])

# Start a fresh draft
obs = env.reset()
done = False

while not done:
    # Legal actions for the side to move
    action_masks = get_action_masks(env)

    # Wrapper chain: ActionMasker -> HumanPlayWrapper -> draft environment
    action_info = env.envs[0].env.env.get_action_info()

    if action_info["team"] != 0:
        # Human's turn: the wrapper prompts for input and ignores this placeholder
        action = [0]
    else:
        # AI's turn: let the policy choose among the legal actions
        action, _states = model.predict(obs, action_masks=action_masks, deterministic=True)

    obs, reward, done, info = env.step(action)

    # DummyVecEnv reports one done flag per sub-environment; only one exists here
    if not done[0]:
        continue

    print("\nDraft Complete!")
    print("Episode reward (blue side winrate):", reward[0])

    # Show the rendered draft when the visualizer produced one
    image_data = env.envs[0].render()
    if image_data is None:
        print("Visualization not available.")
    else:
        display(IPImage(data=image_data))
