In [None]:
# Import necessary libraries
from IPython.display import display, Image as IPImage
import os
from typing import List, Optional
import numpy as np
import pickle
from pathlib import Path
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from sb3_contrib.common.maskable.utils import get_action_masks
from utils import DATA_DIR
from utils.rl.env import FlexibleRoleDraftEnv, action_mask_fn
from utils.rl.self_play import ModelPool, SelfPlayWithPoolWrapper
from utils.rl.visualizer import integrate_with_env
from utils.match_prediction import get_best_device, PREPARED_DATA_DIR

import warnings

# sb3_contrib has not caught up with the latest API; silence the deprecation
# warning it triggers, which reads:
#   WARN: env.get_action_mask to get variables from other wrappers is deprecated and will be removed in v1.0
_GET_ACTION_MASK_WARNING = ".*env.get_action_mask.*"
warnings.filterwarnings("ignore", message=_GET_ACTION_MASK_WARNING)

# Device every model in this notebook is loaded onto.
device = get_best_device()


# Simplified model pool used only for visualization
class VisualizationModelPool(ModelPool):
    """Single-model pool: ``sample_opponent`` always returns the model loaded from ``model_path``."""

    def __init__(self, model_path: str):
        # Nothing is saved while visualizing, so no save directory is needed.
        super().__init__(save_dir="")
        # The device must be given explicitly to avoid an error caused by a numeric constraint.
        # TODO: this might cause problems if inference is done on a different device
        loaded_model = MaskablePPO.load(model_path, device=device)
        self.model = loaded_model

    def sample_opponent(self):
        """Return the one loaded model (there is nothing to sample from)."""
        return self.model


# TODO: refactor to common code
def get_latest_patches(
    n_patches: int = 5, patch_mapping_path: Optional[Path] = None
) -> List[int]:
    """
    Load the patch mapping and return the ``n_patches`` most recent numerical patches.

    Args:
        n_patches: Number of latest patches to return. Values <= 0 yield an empty list.
        patch_mapping_path: Pickle file holding the patch mapping. Defaults to
            ``PREPARED_DATA_DIR / "patch_mapping.pkl"``.

    Returns:
        List of numerical patch values in ascending order (oldest first, newest last).
    """
    if patch_mapping_path is None:
        patch_mapping_path = Path(PREPARED_DATA_DIR) / "patch_mapping.pkl"
    # NOTE: pickle can execute arbitrary code on load; only use locally produced files.
    with open(patch_mapping_path, "rb") as f:
        patch_data = pickle.load(f)

    # The mapping's keys are the raw numerical patches (dict keys are already unique).
    raw_patches = sorted(patch_data["mapping"].keys())

    # Guard: raw_patches[-0:] would return every patch instead of none.
    if n_patches <= 0:
        return []
    return raw_patches[-n_patches:]


patches = get_latest_patches(n_patches=5)  # most recent patches, used to build every env


def _apply_side_role_matrix(state, side: str, custom_role_matrix, use_fallback: bool) -> None:
    """Override ``state.<side>_role_matrix`` in place on a freshly reset draft state.

    ``use_fallback`` takes precedence and copies ``state.<side>_fallback_role_matrix``.
    Otherwise ``custom_role_matrix`` is copied in when it is provided. With neither,
    the state is left untouched.

    Args:
        state: The environment's draft state (``env.state``).
        side: ``"blue"`` or ``"red"``.
        custom_role_matrix: Optional role matrix to install for this side.
        use_fallback: If True, install the side's fallback role matrix instead.
    """
    if use_fallback:
        source = getattr(state, f"{side}_fallback_role_matrix")
    elif custom_role_matrix is not None:
        source = custom_role_matrix
    else:
        return
    setattr(state, f"{side}_role_matrix", source.copy())


def visualize_self_play(
    model_path: str = f"{DATA_DIR}/self_play_models/final_model",
    num_games: int = 1,
    team1_custom_role_matrix: Optional[np.ndarray] = None,
    team2_custom_role_matrix: Optional[np.ndarray] = None,
    team1_use_fallback: bool = False,
    team2_use_fallback: bool = False,
    numeric_patch: Optional[int] = None,
):
    """
    Visualize self-play matches between two teams, optionally with custom role matrices.

    Plays num_games matches with team1 as blue and team2 as red, then num_games
    matches with sides swapped. The learning agent always controls blue. Swapping
    sides is done by swapping which team's role matrix is installed for blue and red.

    Parameters:
        model_path: Path to the trained model.
        num_games: Number of games to play per side.
        team1_custom_role_matrix: Optional custom role matrix for team1.
        team2_custom_role_matrix: Optional custom role matrix for team2.
        team1_use_fallback: If True, use the fallback role matrix for team1. This takes
            precedence over ``team1_custom_role_matrix`` and also applies when no
            custom matrix is given.
        team2_use_fallback: Same as ``team1_use_fallback`` for team2.
        numeric_patch: Optional patch number to use.
    """
    # Load the model and create the pool
    model_pool = VisualizationModelPool(model_path)

    side_configurations = [
        # First num_games games, team1 as blue, team2 as red
        {
            "agent_side": "blue",
            "blue_custom_role_matrix": team1_custom_role_matrix,
            "blue_use_fallback": team1_use_fallback,
            "red_custom_role_matrix": team2_custom_role_matrix,
            "red_use_fallback": team2_use_fallback,
            "label": "Team1 as Blue, Team2 as Red"
        },
        # Second num_games games, team2 as blue, team1 as red
        {
            "agent_side": "blue",
            "blue_custom_role_matrix": team2_custom_role_matrix,
            "blue_use_fallback": team2_use_fallback,
            "red_custom_role_matrix": team1_custom_role_matrix,
            "red_use_fallback": team1_use_fallback,
            "label": "Team2 as Blue, Team1 as Red"
        },
    ]

    for config in side_configurations:
        print(f"\n{'='*50}\n{config['label']}\n{'='*50}")
        for game in range(num_games):
            print(f"\nGame {game + 1}:")
            # Create and wrap the environment each time
            env = integrate_with_env(FlexibleRoleDraftEnv)(patches=patches)

            # TODO: first obs is without patch info! need to fix that
            obs, _ = env.reset()

            if numeric_patch is not None:
                env.patch = numeric_patch

            # Install role matrices on the bare env, before any wrapper is added.
            _apply_side_role_matrix(
                env.state, "blue",
                config["blue_custom_role_matrix"], config["blue_use_fallback"],
            )
            _apply_side_role_matrix(
                env.state, "red",
                config["red_custom_role_matrix"], config["red_use_fallback"],
            )

            env = SelfPlayWithPoolWrapper(
                env, model_pool, agent_side=config["agent_side"]
            )
            env = ActionMasker(env, action_mask_fn)

            done = truncated = False
            while not (done or truncated):
                # Only legal actions may be chosen by the policy.
                action_masks = get_action_masks(env)
                action, _states = model_pool.model.predict(
                    obs, action_masks=action_masks, deterministic=True
                )
                obs, reward, done, truncated, info = env.step(int(action))

            print(f"Episode reward ({config['agent_side']} winrate):", reward)
            # Get the final render
            image_data = env.render()
            if image_data is not None:
                display(IPImage(data=image_data))
            else:
                print("Draft is not complete or visualization is not available.")

In [None]:
# One game per side configuration on patch 14.22 (encoded as major * 50 + minor)
visualize_self_play(numeric_patch=14 * 50 + 22, num_games=1)

In [None]:
# Import necessary libraries
from IPython.display import display, Image as IPImage, clear_output
import os
from typing import List, Optional, Set, Dict, Tuple
import pickle
from pathlib import Path
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from utils import DATA_DIR
from utils.rl.env import FlexibleRoleDraftEnv, action_mask_fn
from utils.rl.self_play import ModelPool
from utils.rl.visualizer import integrate_with_env
from utils.match_prediction import get_best_device, PREPARED_DATA_DIR
from utils.rl.champions import Champion
from difflib import get_close_matches
import numpy as np
import sys
import time
import warnings
import torch as th

# sb3_contrib still goes through a deprecated wrapper accessor; hide that warning.
warnings.filterwarnings(
    "ignore",
    message=".*env.get_action_mask.*",
)

device = get_best_device()  # device the suggestion model is loaded onto


def get_latest_patches(
    n_patches: int = 5, patch_mapping_path: Optional[Path] = None
) -> List[int]:
    """
    Load the patch mapping and return the ``n_patches`` most recent numerical patches.

    Args:
        n_patches: Number of latest patches to return. Values <= 0 yield an empty list.
        patch_mapping_path: Pickle file holding the patch mapping. Defaults to
            ``PREPARED_DATA_DIR / "patch_mapping.pkl"``.

    Returns:
        List of numerical patch values in ascending order (oldest first, newest last).
    """
    if patch_mapping_path is None:
        patch_mapping_path = Path(PREPARED_DATA_DIR) / "patch_mapping.pkl"
    # NOTE: pickle can execute arbitrary code on load; only use locally produced files.
    with open(patch_mapping_path, "rb") as f:
        patch_data = pickle.load(f)

    # The mapping's keys are the raw numerical patches (dict keys are already unique).
    raw_patches = sorted(patch_data["mapping"].keys())

    # Guard: raw_patches[-0:] would return every patch instead of none.
    if n_patches <= 0:
        return []
    return raw_patches[-n_patches:]


patches = get_latest_patches(n_patches=5)  # patches the draft environment is built with


# Simplified model pool for visualization / interactive drafting
class VisualizationModelPool(ModelPool):
    """Pool holding a single trained model that is returned for every opponent request."""

    def __init__(self, model_path: str):
        super().__init__(save_dir="")  # no checkpoints are written here
        trained = MaskablePPO.load(model_path, device=device)
        self.model = trained

    def sample_opponent(self):
        """Return the single loaded model."""
        return self.model


def get_action_suggestions(
    model: MaskablePPO,
    obs: Dict[str, np.ndarray],
    action_mask: np.ndarray,
    n_suggestions: int = 5,
    picks_to_watch: Optional[List[Champion]] = None,
) -> Tuple[List[Tuple[int, float]], List[Tuple[int, float]]]:
    """
    Get top n suggested actions and probabilities for specific champions from the model.

    Probabilities come from the policy's action distribution, restricted to the
    valid actions in ``action_mask`` and renormalised over them.

    Args:
        model: The trained RL model
        obs: Current observation (a dict of arrays; an ``(obs, ...)`` tuple is also
            accepted and its first element is used)
        action_mask: Action mask indicating valid actions (non-zero = valid)
        n_suggestions: Maximum number of top suggestions to return. Only valid actions
            are ever suggested, so fewer are returned when fewer actions are legal.
            Values <= 0 yield no suggestions.
        picks_to_watch: Optional list of Champion objects to track probabilities for

    Returns:
        Tuple containing:
        - List of (action_id, probability) for the top valid suggestions, most likely first
        - List of (action_id, probability) for watched picks that are valid moves,
          most likely first
    """
    # Handle both dictionary and tuple observations
    if isinstance(obs, tuple):
        obs = obs[0]

    # Add a batch dimension and move each observation entry to the model's device.
    obs_tensor = {
        key: th.from_numpy(value).unsqueeze(0).to(model.device)
        for key, value in obs.items()
    }
    with th.no_grad():
        distribution = model.policy.get_distribution(obs_tensor)
        action_probs = distribution.distribution.probs[0].cpu().numpy()

    # Zero out invalid actions and renormalise over the valid ones
    # (epsilon avoids dividing by zero when nothing is valid).
    action_mask = np.asarray(action_mask)
    action_probs = action_probs * action_mask
    action_probs /= action_probs.sum() + 1e-8

    # Rank only the valid actions, so masked-out ids can never be suggested.
    valid_ids = np.flatnonzero(action_mask)
    ranked_ids = valid_ids[np.argsort(-action_probs[valid_ids], kind="stable")]
    top_suggestions = [
        (int(idx), float(action_probs[idx]))
        for idx in ranked_ids[: max(n_suggestions, 0)]
    ]

    # Get probabilities for watched picks if they're valid moves
    watched_suggestions = []
    if picks_to_watch:
        watched_suggestions = [
            (champion.id, float(action_probs[champion.id]))
            for champion in picks_to_watch
            if action_mask[champion.id]
        ]
        watched_suggestions.sort(key=lambda pair: pair[1], reverse=True)

    return top_suggestions, watched_suggestions


def create_role_matrix_from_pools(
    champion_pools: Dict[str, List[Champion]],
    num_champions: Optional[int] = None,
) -> np.ndarray:
    """
    Create a role matrix from specified champion pools.

    Entry ``[champion_id, role_index]`` is 1 when the champion is in that role's
    pool and 0 otherwise. A champion may appear in several roles.

    Args:
        champion_pools: Dictionary mapping role names to lists of Champions. Role
            names are case-insensitive: TOP, JUNGLE, MID, BOT, UTILITY. The aliases
            MIDDLE, BOTTOM, ADC and SUPPORT are also accepted.
        num_champions: Number of rows in the matrix. Defaults to the largest
            ``Champion`` id + 2.

    Returns:
        Numpy int8 array of shape (num_champions, 5) representing the role matrix.

    Raises:
        KeyError: If a role name is not recognised.
    """
    if num_champions is None:
        # +2 presumably accounts for unknown-champion slots. TODO: load the real size.
        num_champions = max(champ.id for champ in Champion) + 2

    role_indices = {
        "TOP": 0,
        "JUNGLE": 1,
        "MID": 2,
        "MIDDLE": 2,
        "BOT": 3,
        "BOTTOM": 3,
        "ADC": 3,
        "UTILITY": 4,
        "SUPPORT": 4,
    }
    num_roles = 5
    role_matrix = np.zeros((num_champions, num_roles), dtype=np.int8)
    for role_name, champions in champion_pools.items():
        key = role_name.strip().upper()
        if key not in role_indices:
            raise KeyError(
                f"Unknown role {role_name!r}; expected one of {sorted(role_indices)}"
            )
        role_idx = role_indices[key]
        for champ in champions:
            role_matrix[champ.id, role_idx] = 1
    return role_matrix


def _print_role_rows(state, roles, id_to_name) -> None:
    """Print one ``role  blue_champ  red_champ`` row per role ("---" when unassigned).

    Reads ``state.blue_roles`` / ``state.red_roles``, which map champion id -> role name.
    """
    for role in roles:
        blue_champ = "---"
        red_champ = "---"
        for champ_id, assigned_role in state.blue_roles.items():
            if assigned_role == role:
                blue_champ = id_to_name.get(champ_id, "???")
                break
        for champ_id, assigned_role in state.red_roles.items():
            if assigned_role == role:
                red_champ = id_to_name.get(champ_id, "???")
                break
        print(f"{role:<10}{blue_champ:<20}{red_champ:<20}")


def _prompt_for_champion(prompt, name_to_id, id_to_name, valid_actions) -> int:
    """Prompt until the user names a champion whose id is in ``valid_actions``.

    Exact lower-cased names or abbreviations are accepted directly. Otherwise up to
    five fuzzy matches are listed and the user picks one by number.

    Returns:
        The chosen champion id.
    """
    while True:
        try:
            search_term = input(prompt).strip().lower()
            if not search_term:
                print("Input cannot be empty. Please try again.")
                continue

            if search_term in name_to_id:
                # Direct match with name or abbreviation
                chosen_id = name_to_id[search_term]
            else:
                matches = get_close_matches(search_term, list(name_to_id), n=5, cutoff=0.6)
                if not matches:
                    print("No champions found matching that name. Please try again.")
                    continue
                print("\nDid you mean:")
                for idx, name in enumerate(matches):
                    print(f"{idx + 1}. {name.title()}")
                choice = input("Enter number (or press Enter to search again): ")
                if not choice:
                    continue
                try:
                    choice_idx = int(choice) - 1
                    if choice_idx < 0 or choice_idx >= len(matches):
                        raise IndexError
                    chosen_id = name_to_id[matches[choice_idx]]
                except (ValueError, IndexError):
                    print("Invalid choice. Please try again.")
                    continue

            if chosen_id not in valid_actions:
                print(
                    f"{id_to_name.get(chosen_id, 'Unknown')} is not available. Please choose again."
                )
                continue
            return chosen_id
        except Exception as e:  # interactive boundary: report the problem and re-prompt
            print(f"Error: {e}")
            print("Please try again.")


def human_vs_human_with_suggestions(
    model_path: str = f"{DATA_DIR}/self_play_models/final_model",
    blue_custom_role_matrix: Optional[np.ndarray] = None,
    red_custom_role_matrix: Optional[np.ndarray] = None,
    blue_use_fallback: bool = False,
    red_use_fallback: bool = False,
    numeric_patch: Optional[int] = None,
    picks_to_watch: Optional[List[Champion]] = None,
):
    """
    Allow a human to select all moves from both sides, with the model suggesting moves.

    Each step works as follows:
    - Show the model's top-5 suggestions and the probabilities of any watched champions.
    - Show the bans, picks and role assignments so far.
    - Prompt for the next ban, pick or role assignment.

    Args:
        model_path: Path to the trained model.
        blue_custom_role_matrix: Custom role matrix for blue team.
        red_custom_role_matrix: Custom role matrix for red team.
        blue_use_fallback: If True, use fallback (all champions) for blue team.
            Overrides ``blue_custom_role_matrix``.
        red_use_fallback: If True, use fallback (all champions) for red team.
            Overrides ``red_custom_role_matrix``.
        numeric_patch: Optional patch number to draft on.
        picks_to_watch: Optional list of Champion objects to track probabilities for
    """
    # Load the model
    model_pool = VisualizationModelPool(model_path)

    # Configure the environment *before* wrapping it: an attribute such as
    # ``patch`` assigned on the ActionMasker wrapper is stored on the wrapper
    # object itself, so the environment would never see the requested patch.
    draft_env = integrate_with_env(FlexibleRoleDraftEnv)(patches=patches)
    obs, _ = draft_env.reset()  # reset now to create the state

    if numeric_patch is not None:
        draft_env.patch = numeric_patch

    state = draft_env.state
    # Override role matrices if provided
    if blue_custom_role_matrix is not None:
        state.blue_role_matrix = blue_custom_role_matrix.copy()
    if red_custom_role_matrix is not None:
        state.red_role_matrix = red_custom_role_matrix.copy()
    # Fallback matrices take precedence over custom ones
    if blue_use_fallback:
        state.blue_role_matrix = state.blue_fallback_role_matrix.copy()
    if red_use_fallback:
        state.red_role_matrix = state.red_fallback_role_matrix.copy()

    env = ActionMasker(draft_env, action_mask_fn)

    # Create lookup dictionaries for champion names
    id_to_name = {champion.id: champion.display_name for champion in Champion}
    name_to_id = {champion.display_name.lower(): champion.id for champion in Champion}
    # Add common abbreviations
    name_to_id.update(
        {
            "tf": name_to_id["twisted fate"],
            "mf": name_to_id["miss fortune"],
            "asol": name_to_id["aurelion sol"],
            "j4": name_to_id["jarvan iv"],
            "tk": name_to_id["tahm kench"],
            # Add more abbreviations as needed
        }
    )

    roles = ["TOP", "JUNGLE", "MID", "BOT", "UTILITY"]

    # ANSI color codes
    BLUE = "\033[94m"
    RED = "\033[91m"
    RESET = "\033[0m"
    BOLD = "\033[1m"
    BAN = "\033[95m"
    PICK = "\033[96m"

    blue_bans: List[int] = []
    red_bans: List[int] = []
    done = truncated = False
    reward = None  # stays None if the episode ends without a step

    while not (done or truncated):
        action_mask = draft_env.get_action_mask()
        valid_actions = np.where(action_mask == 1)[0]
        action_info = draft_env.draft_order[draft_env.current_step]
        current_team = action_info["team"]
        phase = action_info["phase"]

        # Fetch model suggestions and translate ids to display names
        suggestions, watched_picks = get_action_suggestions(
            model_pool.model,
            obs,
            action_mask,
            n_suggestions=5,
            picks_to_watch=picks_to_watch,
        )
        suggestions = [
            (id_to_name.get(action, "Unknown"), prob) for action, prob in suggestions
        ]
        watched_picks = [
            (id_to_name.get(action, "Unknown"), prob) for action, prob in watched_picks
        ]

        clear_output(wait=True)
        print("\nModel Suggestions:")
        print("-" * 100)
        for name, prob in suggestions:
            print(f"{name:<20} - Probability: {prob:.2%}")

        if watched_picks:
            print("\nWatched Picks:")
            print("-" * 100)
            for name, prob in watched_picks:
                print(f"{name:<20} - Probability: {prob:.2%}")

        print(
            f"{BOLD}Draft Phase: {phase}, Team: {'BLUE' if current_team == 0 else 'RED'}{RESET}"
        )
        print("-" * 100)

        # Bans
        print("\nBans:")
        print("-" * 100)
        print(f"{'Order':<10}{BLUE}BLUE{RESET:<20}{RED}RED{RESET:<20}")
        print("-" * 50)
        for i in range(max(len(blue_bans), len(red_bans))):
            blue_ban = (
                id_to_name.get(blue_bans[i], "---") if i < len(blue_bans) else "---"
            )
            red_ban = id_to_name.get(red_bans[i], "---") if i < len(red_bans) else "---"
            print(f"{BAN}Ban {i+1}{RESET:<6}{blue_ban:<20}{red_ban:<20}")

        # Picks
        print("\nPicks:")
        print("-" * 100)
        print(f"{'Order':<10}{BLUE}BLUE{RESET:<20}{RED}RED{RESET:<20}")
        print("-" * 50)
        blue_picks = draft_env.state.blue_picks
        red_picks = draft_env.state.red_picks
        for i in range(max(len(blue_picks), len(red_picks))):
            blue_pick = (
                id_to_name.get(blue_picks[i], "---") if i < len(blue_picks) else "---"
            )
            red_pick = (
                id_to_name.get(red_picks[i], "---") if i < len(red_picks) else "---"
            )
            print(f"{PICK}Pick {i+1}{RESET:<6}{blue_pick:<20}{red_pick:<20}")

        # Team compositions
        print("\nTeam Compositions:")
        print("-" * 100)
        print(f"{'Role':<10}{BLUE}BLUE{RESET:<20}{RED}RED{RESET:<20}")
        print("-" * 50)
        _print_role_rows(draft_env.state, roles, id_to_name)

        # Build the prompt outside the input loop so a malformed action_info raises
        # instead of spinning forever inside the retry loop.
        if phase == 0:
            prompt = "\nEnter champion to BAN: "
        elif phase == 1:
            prompt = "\nEnter champion to PICK: "
        elif phase == 2:
            prompt = f"\nAssign champion to {roles[action_info['role_index']]}: "
        else:
            prompt = "\nEnter action: "
        chosen_id = _prompt_for_champion(prompt, name_to_id, id_to_name, valid_actions)

        # Keep track of bans
        if phase == 0:
            if current_team == 0:
                blue_bans.append(chosen_id)
            else:
                red_bans.append(chosen_id)

        obs, reward, done, truncated, info = env.step(chosen_id)

    # After the draft is complete
    clear_output(wait=True)
    print("\nFinal Draft:")
    print("-" * 100)

    print("\nTeam Compositions:")
    print("-" * 100)
    print(f"{'Role':<10}{BLUE}BLUE{RESET:<20}{RED}RED{RESET:<20}")
    print("-" * 50)
    _print_role_rows(draft_env.state, roles, id_to_name)

    print("\nDraft Complete!")
    print("Final winrate prediction (blue side winrate):", reward)

    # Get the final render
    image_data = env.render()
    if image_data is not None:
        display(IPImage(data=image_data))
    else:
        print("Visualization not available.")

In [None]:
# Example usage
# Prepare custom champion pools for roles
from utils.rl.champions import Champion

# Extra pools that are not part of the clash roster below
guerric_jungle_pool = [
    Champion.WARWICK, Champion.HECARIM, Champion.SKARNER,
    Champion.UDYR, Champion.REK_SAI, Champion.GRAVES,
]

filip_utility_pool = [
    Champion.LEONA, Champion.BRAUM, Champion.BARD,
    Champion.NAMI, Champion.RAKAN, Champion.JANNA,
    Champion.LULU, Champion.MAOKAI, Champion.BLITZCRANK,
]

# Actual players of the clash team:
filip_jungle_pool = [
    Champion.IVERN, Champion.SEJUANI, Champion.ZYRA, Champion.ZAC,
    Champion.BRIAR, Champion.POPPY, Champion.RAMMUS, Champion.AMUMU,
    # Currently excluded: Champion.FIDDLESTICKS, Champion.TALIYAH
]

geoffroy_bot_pool = [
    Champion.VEIGAR, Champion.ZIGGS, Champion.TRISTANA, Champion.JINX,
    Champion.VAYNE, Champion.EZREAL, Champion.ASHE, Champion.KOG_MAW,
    Champion.XAYAH, Champion.HWEI, Champion.SERAPHINE, Champion.VARUS,
    Champion.SWAIN,
]

cyprien_mid_pool = [
    Champion.AHRI, Champion.AKALI, Champion.ORIANNA, Champion.XERATH,
    Champion.RYZE, Champion.AZIR, Champion.VIKTOR, Champion.LEBLANC,
    Champion.JAYCE, Champion.YONE,
]

arthur_top_pool = [Champion.K_SANTE, Champion.POPPY, Champion.JAX, Champion.YONE]

mathias_utility_pool = [
    Champion.BARD, Champion.ZAC, Champion.PYKE, Champion.XERATH,
    Champion.PANTHEON, Champion.BRAND, Champion.NEEKO, Champion.SYLAS,
    Champion.SHACO, Champion.SWAIN, Champion.TARIC, Champion.POPPY,
]

# Role -> champion pool for the clash team (keys are the role names the draft uses)
clash_champion_pools = dict(
    TOP=arthur_top_pool,
    JUNGLE=filip_jungle_pool,
    MID=cyprien_mid_pool,
    BOT=geoffroy_bot_pool,
    UTILITY=mathias_utility_pool,
)

# 0/1 (champion x role) matrix restricting the clash team to its pools
clash_role_matrix = create_role_matrix_from_pools(clash_champion_pools)

# Patch 14.22, passed to the draft functions as major * 50 + minor
major_patch, minor_patch = 14, 22

In [None]:

# Interactive draft: blue may pick any champion (fallback matrix), red is
# restricted to the clash team's pools. Omitted flags keep their defaults
# (no blue custom matrix, red_use_fallback=False).
human_vs_human_with_suggestions(
    numeric_patch=major_patch * 50 + minor_patch,
    red_custom_role_matrix=clash_role_matrix,
    blue_use_fallback=True,
)


In [None]:
major_patch = 14
minor_patch = 22

# Each team is drafted num_games times as blue and num_games times as red,
# so the flags below apply per team, not per side.
visualize_self_play(
    num_games=5,
    team1_custom_role_matrix=clash_role_matrix,
    team1_use_fallback=False,  # team1 is restricted to the clash team's pools
    team2_use_fallback=True,  # team2 uses its fallback (all-champion) role matrix
    numeric_patch=major_patch * 50 + minor_patch,
)

In [None]:
# Human drafts both sides while the model's probabilities for Filip's utility pool are tracked
human_vs_human_with_suggestions(
    numeric_patch=major_patch * 50 + minor_patch,
    picks_to_watch=filip_utility_pool,
)