In [None]:
# Import necessary libraries
from IPython.display import display, Image as IPImage
from sb3_contrib import MaskablePPO
from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib.common.wrappers import ActionMasker
from sb3_contrib.common.maskable.utils import get_action_masks
from utils import DATA_DIR
from utils.rl.env import FixedRoleDraftEnv, action_mask_fn
from utils.rl.self_play import ModelPool, SelfPlayWithPoolWrapper
from utils.rl.visualizer import integrate_with_env


# Create a simplified model pool for visualization
class VisualizationModelPool(ModelPool):
    """Model pool backed by a single, pre-trained model.

    Intended for visualization runs, where every opponent query should be
    answered by the same loaded checkpoint.
    """

    def __init__(self, model_path: str):
        """Load the MaskablePPO checkpoint stored at ``model_path``."""
        # The parent is initialised with an empty save directory because
        # nothing is persisted while visualizing.
        super().__init__(save_dir="")
        loaded_model = MaskablePPO.load(model_path)
        self.model = loaded_model

    def sample_opponent(self):
        """Return the one loaded model, regardless of how often it is called."""
        opponent = self.model
        return opponent


def visualize_self_play(
    model_path: str = f"{DATA_DIR}/self_play_models/final_model", num_games: int = 1
):
    """Play drafts of the model against itself and display each finished draft.

    The same loaded model drives both the agent (via ``predict``) and the
    opponent (via ``VisualizationModelPool``).

    Args:
        model_path: Path of the saved MaskablePPO model to load.
        num_games: Number of drafts to play; nothing is played when <= 0.
    """
    # Load the model once; the pool hands the same instance to the opponent.
    model_pool = VisualizationModelPool(model_path)

    # Build the wrapped single environment. The agent is pinned to the blue
    # side so the printed reward always refers to the same side.
    draft_env = integrate_with_env(FixedRoleDraftEnv)()
    draft_env = SelfPlayWithPoolWrapper(draft_env, model_pool, agent_side="blue")
    masked_env = ActionMasker(draft_env, action_mask_fn)

    # The factory closes over a name that is never rebound afterwards, so it
    # does not depend on DummyVecEnv invoking it eagerly.
    env = DummyVecEnv([lambda: masked_env])

    for game in range(num_games):
        print(f"\nGame {game + 1}:")
        obs = env.reset()
        done = False

        # The VecEnv step API reports a single `done` flag per environment,
        # so no separate truncation flag is tracked here.
        while not done:
            # Invalid actions are excluded through the environment's mask.
            action_masks = get_action_masks(env)
            action, _states = model_pool.model.predict(
                obs, action_masks=action_masks, deterministic=True
            )

            obs, reward, dones, _info = env.step(action)
            done = bool(dones[0])  # one entry per wrapped environment

            if done:
                print("Episode reward (blue side winrate):", reward[0])
                # NOTE(review): DummyVecEnv resets a sub-environment as soon as
                # its episode ends, so this render runs after that reset --
                # confirm the visualizer still holds the completed draft.
                image_data = env.envs[0].render()
                if image_data is not None:
                    display(IPImage(data=image_data))
                else:
                    print("Draft is not complete or visualization is not available.")

In [None]:
# View multiple games: play five self-play drafts with the default model path
# and show the rendered result of each one.
visualize_self_play(num_games=5)

In [None]:
from IPython.display import display, Image as IPImage, clear_output
from sb3_contrib import MaskablePPO
from stable_baselines3.common.vec_env import DummyVecEnv
from sb3_contrib.common.wrappers import ActionMasker
from utils import DATA_DIR
from utils.rl.env import FixedRoleDraftEnv, action_mask_fn
from utils.rl.self_play import ModelPool

import numpy as np
from typing import List, Optional, Set
from difflib import get_close_matches
from utils.rl.champions import Champion
from gymnasium import Wrapper
import sys
import time
import torch as th


class HumanPlayWrapper(Wrapper):
    """Wrapper that allows human-model interaction from either side.

    On every ``step`` the wrapper ignores the action passed in and instead
    obtains the real action itself: from ``input()`` when the team to move is
    the human's side, otherwise from the model. It also keeps per-team ban
    lists and a move history (with the model's suggestions at that point) for
    display purposes.
    """

    def __init__(
        self,
        env,
        human_side: int,  # 0 for blue, 1 for red
        model,
        role_names: Optional[List[str]] = None,
    ):
        """Set up name lookups and empty tracking state.

        Args:
            env: Draft environment to wrap.
            human_side: Team controlled by the human (0 for blue, 1 for red).
            model: Model used for the opposing moves and for suggestions.
            role_names: Optional role labels; defaults to the five standard roles.
        """
        super().__init__(env)
        self.human_side = human_side
        self.model = model
        self.role_names = role_names or ["TOP", "JUNGLE", "MID", "BOT", "UTILITY"]

        # Create lookup dictionaries for champion names
        self.id_to_name = {champion.id: champion.display_name for champion in Champion}
        self.name_to_id = {
            champion.display_name.lower(): champion.id for champion in Champion
        }
        # Add common abbreviations
        self.name_to_id.update(
            {
                "tf": self.name_to_id["twisted fate"],
                "mf": self.name_to_id["miss fortune"],
                "asol": self.name_to_id["aurelion sol"],
                "j4": self.name_to_id["jarvan iv"],
                "tk": self.name_to_id["tahm kench"],
            }
        )

        # Track bans
        self.blue_bans: List[int] = []
        self.red_bans: List[int] = []

        # Track draft history
        self.draft_history: List[dict] = []

    def reset(self, **kwargs):
        """Reset the environment and display initial state.

        Clears the ban lists and the move history so that nothing from a
        previous game is shown, then returns the wrapped environment's reset
        result unchanged.
        """
        self.blue_bans = []
        self.red_bans = []
        # The history must be cleared as well; otherwise moves of earlier
        # games keep being printed under "Move History".
        self.draft_history = []
        result = self.env.reset(**kwargs)
        self.display_state()
        return result

    def _champion_label(self, champ_id) -> str:
        """Return the display name for ``champ_id`` with a generic fallback."""
        return self.id_to_name.get(champ_id, f"Champion {champ_id}")

    def _get_model_suggestions(
        self, obs: dict, action_mask: np.ndarray, n_suggestions: int = 5
    ) -> List[tuple]:
        """Get top n suggested actions from the model with their probabilities.

        Args:
            obs: Unbatched observation dict of numpy arrays.
            action_mask: Per-action validity mask; falsy entries are invalid.
            n_suggestions: Maximum number of suggestions to return.

        Returns:
            ``(action_index, probability)`` tuples, most probable first. The
            probabilities are the policy's raw probabilities multiplied by the
            mask (not renormalized). Masked-out actions are never returned, so
            the list may be shorter than ``n_suggestions``.
        """
        if n_suggestions <= 0:
            return []

        # Add the batch dimension and place the tensors on the policy's device.
        device = self.model.policy.device
        obs_tensor = {
            key: th.from_numpy(value).unsqueeze(0).to(device)
            for key, value in obs.items()
        }
        with th.no_grad():
            action_probs = (
                self.model.policy.get_distribution(obs_tensor)
                .distribution.probs[0]
                .cpu()
                .numpy()
            )

        # Mask invalid actions
        action_probs = action_probs * action_mask

        # Highest masked probability first.
        top_indices = np.argsort(action_probs)[-n_suggestions:][::-1]

        # Drop invalid actions that only made the cut because fewer than
        # n_suggestions valid actions exist.
        return [(idx, action_probs[idx]) for idx in top_indices if action_mask[idx]]

    def _print_bans(self, title: str, bans: List[int]) -> None:
        """Print one team's bans under ``title``; prints nothing if empty."""
        if bans:
            print(title)
            print(", ".join(self.id_to_name[ban_id] for ban_id in bans))

    def _print_picks(self, title: str, picks, roles_picked) -> None:
        """Print one team's picks per role under ``title``; skips empty teams."""
        if not np.any(picks):
            return
        print(title)
        for i, role in enumerate(self.env.roles):
            if roles_picked[i]:
                champ_id = np.argmax(picks[i])
                print(f"{role}: {self._champion_label(champ_id)}")

    def display_state(self):
        """Display current draft state and available actions.

        Clears the previous cell output and prints the phase, team to move,
        the model's suggestions, bans, picks and the move history.
        """
        clear_output(wait=True)
        action_info = self.env.get_current_draft_step()
        phase = "BAN" if action_info["phase"] == 0 else "PICK"
        team = "RED" if action_info["team"] == 1 else "BLUE"

        print("\nCurrent Draft State:")
        print(f"Phase: {phase}")
        print(f"Team to move: {team}")

        # Get and display current model suggestions
        obs = self._get_obs()
        action_mask = self.env.get_action_mask()
        suggestions = self._get_model_suggestions(obs, action_mask)

        print("\nModel's top 5 suggestions for current state:")
        for champ_id, prob in suggestions:
            if champ_id in self.id_to_name:
                print(f"{self.id_to_name[champ_id]}: {prob:.3f}")

        # Display bans
        self._print_bans("\nBlue Team Bans:", self.blue_bans)
        self._print_bans("\nRed Team Bans:", self.red_bans)

        # Display current picks if any
        self._print_picks(
            "\nBlue Team Picks:", self.env.blue_picks, self.env.blue_roles_picked
        )
        self._print_picks(
            "\nRed Team Picks:", self.env.red_picks, self.env.red_roles_picked
        )

        # Display move history
        if self.draft_history:
            print("\nMove History:")
            for move in self.draft_history:
                phase_str = "BAN" if move["phase"] == 0 else "PICK"
                team_str = "RED" if move["team"] == 1 else "BLUE"
                chosen_champ = self._champion_label(move["chosen_action"])
                print(f"{team_str} {phase_str}: {chosen_champ}")
                print("Top suggestions were:")
                for champ_id, prob in move["suggestions"]:
                    # Same rule as the live suggestion list: ids without a
                    # known name are skipped instead of raising KeyError.
                    if champ_id in self.id_to_name:
                        print(f"  {self.id_to_name[champ_id]}: {prob:.3f}")
                print()

        sys.stdout.flush()
        time.sleep(0.1)  # Small delay to ensure display completes

    def step(self, base_action):
        """Handle both model and human turns.

        Args:
            base_action: Ignored. The action actually applied comes from the
                human prompt or from the model, depending on whose turn it is.

        Returns:
            The wrapped environment's ``step`` result for the chosen action.
        """
        action_info = self.env.get_current_draft_step()
        phase = action_info["phase"]
        current_team = action_info["team"]

        # Capture the model's suggestions before the move is made so the
        # history shows what was recommended at decision time.
        obs = self._get_obs()
        action_mask = self.env.get_action_mask()
        suggestions = self._get_model_suggestions(obs, action_mask)

        # Get action based on current turn
        if current_team == self.human_side:
            action = self._get_human_action(phase)
        else:
            action = self._get_model_action(phase)

        # The action may be a plain int, a 0-d array or a length-1
        # array/list; flatten first so indexing never fails on a 0-d array.
        chosen_action = int(np.asarray(action).reshape(-1)[0])

        # Store draft history
        self.draft_history.append(
            {
                "phase": phase,
                "team": current_team,
                "chosen_action": chosen_action,
                "suggestions": suggestions,
            }
        )

        # Track bans
        if phase == 0:
            if current_team == 0:  # Blue side
                self.blue_bans.append(chosen_action)
            else:  # Red side
                self.red_bans.append(chosen_action)

        return self.env.step(chosen_action)

    def _get_human_action(self, phase: int) -> int:
        """Get action from human input.

        Shows the current state and keeps prompting until a valid, available
        champion is chosen.

        Raises:
            EOFError, NotImplementedError: If no interactive input is
                available; retrying could never succeed in that case.
        """
        self.display_state()
        sys.stdout.flush()  # ensure state is displayed before input
        action_mask = self.env.get_action_mask()
        valid_actions = set(np.where(action_mask == 1)[0])

        prompt = (
            "\nEnter champion name to ban: "
            if phase == 0
            else "\nEnter champion name to pick: "
        )

        while True:
            try:
                search = input(prompt)
                chosen_id = self._process_human_input(search, valid_actions)
            except (EOFError, NotImplementedError):
                # Input stream closed or stdin unsupported: propagate instead
                # of looping forever on the same failure.
                raise
            except Exception as e:
                print(f"Error: {e}")
                print("Please try again.")
                continue

            if chosen_id is not None:
                return chosen_id

    def _get_model_action(self, phase: int) -> np.ndarray:
        """Get action from model prediction.

        ``phase`` is accepted for symmetry with ``_get_human_action`` but is
        not used; the observation already carries the phase.
        """
        obs = self._get_obs()
        action_mask = self.env.get_action_mask()

        action, _states = self.model.predict(
            obs,
            action_masks=np.array([action_mask]),
            deterministic=True,
        )
        return action

    def _get_obs(self):
        """Get observation in the format expected by the model.

        Builds the observation dict from the wrapped environment's current
        attributes; array attributes are copied.
        """
        draft_step = self.env.get_current_draft_step()
        return {
            "available_champions": self.env.available_champions.copy(),
            "blue_picks": self.env.blue_picks.copy(),
            "red_picks": self.env.red_picks.copy(),
            "blue_ordered_picks": self.env.blue_ordered_picks.copy(),
            "red_ordered_picks": self.env.red_ordered_picks.copy(),
            "blue_roles_picked": self.env.blue_roles_picked.copy(),
            "red_roles_picked": self.env.red_roles_picked.copy(),
            "phase": np.array([draft_step["phase"]], dtype=np.int8),
            "turn": np.array([draft_step["team"]], dtype=np.int8),
            "action_mask": self.env.get_action_mask(),
        }

    def _process_human_input(
        self, search: str, valid_actions: Set[int]
    ) -> Optional[int]:
        """Process human input and return chosen champion ID if valid.

        Args:
            search: Raw text typed by the user.
            valid_actions: Champion ids that may currently be chosen.

        Returns:
            The chosen champion id, or ``None`` when the user has to search
            again (no match, unavailable champion, cancelled or invalid
            selection).
        """
        matches = self.find_champion(search)

        if not matches:
            print("No champions found matching that name. Please try again.")
            return None

        if len(matches) > 1:
            valid_matches = [m for m in matches if m[0] in valid_actions]

            if not valid_matches:
                print("None of the matches are available. Please try again.")
                return None

            print("\nDid you mean:")
            for i, (champ_id, champ_name) in enumerate(valid_matches, 1):
                print(f"{i}. {champ_name}")

            choice = input("Enter number (or press Enter to search again): ")
            if not choice:
                return None

            try:
                selection = int(choice)
            except ValueError:
                print("Invalid choice. Please try again.")
                return None

            # Reject 0 and negatives explicitly: Python's negative indexing
            # would otherwise silently select an entry from the end.
            if not 1 <= selection <= len(valid_matches):
                print("Invalid choice. Please try again.")
                return None

            chosen_id = valid_matches[selection - 1][0]
        else:
            chosen_id = matches[0][0]

        if chosen_id in valid_actions:
            return chosen_id

        print(
            f"{self.id_to_name[chosen_id]} is not available. Please choose from the valid champions listed above."
        )
        return None

    def find_champion(self, search_term: str) -> List[tuple]:
        """Find champions matching the search term.

        An exact (case-insensitive) name or abbreviation yields exactly one
        result; otherwise up to three fuzzy matches are returned.

        Returns:
            ``(champion_id, display_name)`` tuples, without duplicates.
        """
        search_term = search_term.lower().strip()

        # Direct match with name or abbreviation
        if search_term in self.name_to_id:
            champ_id = self.name_to_id[search_term]
            return [(champ_id, self.id_to_name[champ_id])]

        # Try to find close matches
        all_names = list(self.name_to_id.keys())
        matches = get_close_matches(search_term, all_names, n=3, cutoff=0.6)

        # An abbreviation and the full name can both match the same champion;
        # keep only the first occurrence so the menu has no repeated entries.
        results: List[tuple] = []
        seen_ids = set()
        for name in matches:
            champ_id = self.name_to_id[name]
            if champ_id not in seen_ids:
                seen_ids.add(champ_id)
                results.append((champ_id, self.id_to_name[champ_id]))
        return results


def play_vs_model(
    human_side: str = "blue",
    model_path: str = f"{DATA_DIR}/self_play_models/final_model",
):
    """
    Play against the trained model.

    The human is prompted for every move of their side; the model plays the
    other side. When the draft ends, the final reward is printed from the
    human's perspective and the rendered draft is displayed if available.

    Args:
        human_side: "blue" or "red" (case-insensitive)
        model_path: path to the trained model

    Raises:
        ValueError: if human_side is neither "blue" nor "red"
    """
    # Convert side string to integer. Anything that is not a known side is
    # rejected rather than silently treated as red.
    side = human_side.strip().lower()
    if side not in ("blue", "red"):
        raise ValueError(f"human_side must be 'blue' or 'red', got {human_side!r}")
    human_side_int = 0 if side == "blue" else 1

    # Load the model
    model = MaskablePPO.load(model_path)

    # Create and wrap the environment
    draft_env = integrate_with_env(FixedRoleDraftEnv)()
    draft_env = HumanPlayWrapper(draft_env, human_side_int, model)
    masked_env = ActionMasker(draft_env, action_mask_fn)

    # The factory closes over a name that is never rebound afterwards, so it
    # does not depend on DummyVecEnv invoking it eagerly.
    env = DummyVecEnv([lambda: masked_env])

    # Reset the environment
    env.reset()
    done = False

    while not done:
        # Dummy action: HumanPlayWrapper.step ignores it and obtains the real
        # move from the human or the model.
        action = [0]
        _obs, reward, dones, _info = env.step(action)
        done = bool(dones[0])  # one entry per wrapped environment

    print("\nDraft Complete!")
    if human_side_int == 0:
        print("Final winrate prediction (your side):", reward[0])
    else:
        # The reward is reported as 1 - reward for the red side, i.e. it is
        # treated as the blue side's value.
        print("Final winrate prediction (your side):", 1 - reward[0])

    # NOTE(review): DummyVecEnv resets a sub-environment as soon as its episode
    # ends, so this render runs after that reset -- confirm the visualizer
    # still holds the completed draft.
    image_data = env.envs[0].render()
    if image_data is not None:
        display(IPImage(data=image_data))
    else:
        print("Visualization not available.")

In [None]:
# Start an interactive draft against the default model with the human on the
# blue side; moves are entered through input() prompts.
play_vs_model("blue")