In [1]:
import numpy as np
import matplotlib.pyplot as plt
import copy
import torch
import os
import time
import json
import logging

from tqdm import tqdm
import gymnasium as gym

# Import your TD3 components
from TD3.memory import ReplayBuffer
from TD3.td3 import TD3

# Import the custom Hockey environment
import hockey.hockey_env as h_env
from importlib import reload

# Reload the hockey environment to ensure the latest version is used
reload(h_env)

# ============================================
# Environment Wrapper
# ============================================
class CustomHockeyEnv:
    def __init__(self, mode=h_env.Mode.NORMAL, render_mode=None):
        """
        Wraps the HockeyEnv to match the Gym API more closely.
        """
        self.env = h_env.HockeyEnv(mode=mode)
        self.render_mode = render_mode

    def reset(self):
        obs, info = self.env.reset()
        return obs, info

    def step(self, action):
        """
        The custom step function returns (obs, r, d, t, info).
        We convert to the standard Gym format: (obs, reward, done, info).
        """
        obs, r, d, t, info = self.env.step(action)
        done = d or t
        return obs, r, done, info

    def close(self):
        self.env.close()

    def seed(self, seed):
        self.env.set_seed(seed)
        self.env.action_space.seed(seed)

    @property
    def observation_space(self):
        return self.env.observation_space

    @property
    def action_space(self):
        return self.env.action_space

    def get_info_agent_two(self):
        return self.env.get_info_agent_two()


# ============================================
# Trained Opponent Class
# ============================================
class TrainedOpponent:
    def __init__(self, agent):
        """
        Opponent that uses a trained TD3 agent to select actions.
        """
        self.agent = agent

    def act(self, observation):
        """
        Selects an action based on the current observation using the trained agent.
        """
        observation = torch.FloatTensor(observation).to(self.agent.device)
        with torch.no_grad():
            action = self.agent.actor(observation).cpu().numpy()
        return action


# ============================================
# Opponent Definitions
# ============================================
def get_opponent(opponent_type, env, trained_agent=None):
    """
    Returns an opponent based on the specified type ('weak', 'strong', 'trained').

    Args:
        opponent_type (str): Type of opponent ('weak', 'strong', 'trained').
        env (CustomHockeyEnv): The environment instance.
        trained_agent (TD3 or None): The trained agent to use as an opponent if opponent_type is 'trained'.

    Returns:
        An opponent instance compatible with the environment.
    """
    if opponent_type == "weak":
        return h_env.BasicOpponent(weak=True)
    elif opponent_type == "strong":
        return h_env.BasicOpponent(weak=False)
    elif opponent_type == "trained":
        if trained_agent is None:
            raise ValueError("trained_agent must be provided for 'trained' opponent type.")
        return TrainedOpponent(trained_agent)
    else:
        raise ValueError(f"Unknown opponent type: {opponent_type}")


# ============================================
# Extended Evaluation Function
# ============================================
def eval_policy_extended(
    policy, 
    eval_episodes=100, 
    seed=42,
    mode=h_env.Mode.NORMAL,
    opponent_type=None,
    trained_agent=None  # Existing parameter
):
    """
    A single evaluation function that can handle:
      - Normal mode with strong/weak/trained opponent
      - Shooting mode (no opponent)
      - Defense mode (no opponent)

    If `opponent_type` is 'strong', 'weak', or 'trained', we create that opponent.
    If `opponent_type` is None, we do not create any built-in opponent.

    Returns a dictionary of evaluation stats:
        {
            "avg_reward": float,
            "win": int,
            "loss": int,
            "draw": int,
            "win_rate": float
        }
    """
    # Create a new environment for evaluation to avoid state carry-over
    eval_env = CustomHockeyEnv(mode=mode)
    eval_env.seed(seed)

    # Set networks to eval mode
    policy.actor.eval()
    policy.critic.eval()

    # Initialize the opponent if needed
    if opponent_type in ["strong", "weak", "trained"]:
        opponent = get_opponent(opponent_type, eval_env, trained_agent=trained_agent)
    else:
        opponent = None  # e.g., for shooting or defense

    total_rewards = []
    results = {'win': 0, 'loss': 0, 'draw': 0}

    with torch.no_grad():
        for _ in range(eval_episodes):
            state, info = eval_env.reset()
            done = False
            episode_reward = 0

            while not done:
                # Agent action
                agent_action = policy.act(np.array(state), add_noise=False)

                if opponent is not None:
                    # Opponent action
                    opponent_obs = eval_env.env.obs_agent_two()
                    opponent_action = opponent.act(opponent_obs)
                else:
                    # No built-in opponent, so second agent does nothing
                    opponent_action = np.array([0, 0, 0, 0], dtype=np.float32)

                # Combine actions
                full_action = np.hstack([agent_action, opponent_action])

                next_state, reward, done, info = eval_env.step(full_action)
                episode_reward += reward
                state = next_state

            total_rewards.append(episode_reward)

            # Determine final outcome via `env._get_info()`
            final_info = eval_env.env._get_info()
            winner = final_info.get('winner', 0)  # 1 => agent1, -1 => agent1 loses, 0 => draw

            if winner == 1:
                results['win'] += 1
            elif winner == -1:
                results['loss'] += 1
            else:
                results['draw'] += 1

    avg_reward = np.mean(total_rewards)
    total_games = results['win'] + results['loss'] + results['draw']
    win_rate = (results['win'] / total_games) if total_games > 0 else 0.0

    # Restore networks to train mode
    policy.actor.train()
    policy.critic.train()

    eval_stats = {
        "avg_reward": float(avg_reward),
        "win": results['win'],
        "loss": results['loss'],
        "draw": results['draw'],
        "win_rate": float(win_rate)
    }

    return eval_stats


# ============================================
# Plotting Functions
# ============================================
def plot_losses(loss_data, mode, save_path):
    episodes = range(1, len(loss_data["critic_loss"]) + 1)

    plt.figure(figsize=(14, 6))

    # Plot Critic Loss
    plt.subplot(1, 2, 1)
    plt.plot(episodes, loss_data["critic_loss"], label='Critic Loss', color='blue')
    plt.xlabel('Episode')
    plt.ylabel('Loss')
    plt.title(f'Critic Loss over Episodes for {mode}')
    plt.legend()
    plt.grid(True)

    # Plot Actor Loss
    plt.subplot(1, 2, 2)
    plt.plot(episodes, loss_data["actor_loss"], label='Actor Loss', color='orange')
    plt.xlabel('Episode')
    plt.ylabel('Loss')
    plt.title(f'Actor Loss over Episodes for {mode}')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(save_path, f"losses_{mode}.png"))
    plt.close()

def plot_overall_statistics(overall_stats, opponent_type, save_path):
    """
    Plots overall statistics for an opponent type.

    Args:
        overall_stats (dict): Statistics including win_rate, loss_rate, etc.
        opponent_type (str): Type of opponent.
        save_path (str): Directory to save the plots.
    """
    os.makedirs(save_path, exist_ok=True)

    categories = ['Win Rate', 'Loss Rate', 'Draw Rate']
    values = [overall_stats['Agent1']['win_rate'], 
              overall_stats['Agent1']['loss_rate'], 
              overall_stats['Agent1']['draw_rate']]
    colors = ['green', 'red', 'blue']

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.bar(categories, values, color=colors)
    ax.set_ylim(0, 1)
    ax.set_ylabel('Rate')
    ax.set_title(f'Overall Statistics vs {opponent_type.capitalize()} Opponent')

    for i, v in enumerate(values):
        ax.text(i, v + 0.02, f"{v:.2f}", ha='center')

    plt.tight_layout()
    plt.savefig(os.path.join(save_path, f"overall_statistics_{opponent_type}.png"))
    plt.close()
    
def plot_rewards(rewards, mode, save_path, window=20):
    plt.figure(figsize=(12,6))
    plt.plot(rewards, label='Episode Reward', alpha=0.3)
    if len(rewards) >= window:
        moving_avg = np.convolve(rewards, np.ones(window)/window, mode='valid')
        plt.plot(range(window-1, len(rewards)), moving_avg, label=f'{window}-Episode Moving Average', color='red')
    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title(f'TD3 on {mode}: Episode Rewards with Moving Average')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_path, f"rewards_{mode}.png"))
    plt.close()


def plot_multi_winrate_curves(winrate_dict, mode, save_path):
    """
    Plots multiple win-rate curves in the same figure.
    `winrate_dict` should be a dictionary like:
       {
         "strong": [...],
         "weak": [...],
         "trained_1": [...],
         "trained_2": [...],
         "shooting": [...],
         "defense": [...]
       }
    Each list is the time series of the agent's win rate in that scenario.
    """
    plt.figure(figsize=(12,6))

    # We'll plot each scenario if it exists
    for scenario, wr_list in winrate_dict.items():
        if len(wr_list) > 0:
            plt.plot(wr_list, label=f'Win Rate vs {scenario.capitalize()}')

    plt.xlabel('Evaluation Index (every eval_freq episodes)')
    plt.ylabel('Win Rate')
    plt.title(f'{mode} - Win Rates Over Training (All Scenarios)')
    plt.ylim([0, 1])
    plt.grid(True)
    plt.legend()
    plt.savefig(os.path.join(save_path, f"winrates_{mode}.png"))
    plt.close()


def plot_noise_comparison(agent, training_config, env_name, save_path):
    """
    Plots Pink/OU noise vs Gaussian noise for comparison, if the agent uses pink or OU noise.
    """
    if training_config["expl_noise_type"].lower() not in ["pink", "ou"]:
        print(f"No noise comparison plot available for noise type: {training_config['expl_noise_type']}")
        return

    max_steps = training_config.get("max_episode_steps", 600)
    action_dim = agent.action_dim

    # Generate noise sequence
    noise_sequence = []
    for _ in range(max_steps):
        if training_config["expl_noise_type"].lower() == "pink":
            noise = agent.pink_noise.get_noise() * training_config["expl_noise"]
        elif training_config["expl_noise_type"].lower() == "ou":
            noise = agent.ou_noise.sample() * training_config["expl_noise"]
        noise_sequence.append(noise)
    noise_sequence = np.array(noise_sequence)

    # Generate Gaussian noise for comparison
    gaussian_noise_sequence = np.random.normal(0, training_config["expl_noise"], size=(max_steps, action_dim))

    # Plot for each action dimension
    for dim in range(action_dim):
        plt.figure(figsize=(12, 4))
        plt.plot(noise_sequence[:, dim], label=f'{training_config["expl_noise_type"].capitalize()} Noise')
        plt.plot(gaussian_noise_sequence[:, dim], label='Gaussian Noise', alpha=0.7)
        plt.title(f'Noise Comparison for Action Dimension {dim} in {env_name}')
        plt.xlabel('Step')
        plt.ylabel('Noise Value')
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(save_path, f"noise_comparison_dim_{dim}_{env_name}.png"))
        plt.close()


# ============================================
# Logging Setup
# ============================================
def setup_evaluation_logging(results_dir, seed, opponent_type):
    log_file = os.path.join(results_dir, f"evaluation_log_seed_{seed}_{opponent_type}.log")
    # To prevent adding multiple handlers if this function is called multiple times
    if not logging.getLogger().hasHandlers():
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s | %(levelname)s | %(message)s',
            handlers=[
                logging.FileHandler(log_file),
                logging.StreamHandler()  # Also log to console
            ]
        )
    else:
        # Add a new handler for each opponent_type
        logger = logging.getLogger()
        handler = logging.FileHandler(log_file)
        formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    logging.info(f"Evaluation started for opponent_type: {opponent_type}, seed: {seed}")


def setup_logging(results_dir, seed, mode):
    log_file = os.path.join(results_dir, f"training_log_seed_{seed}.log")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s | %(levelname)s | %(message)s',
        handlers=[
            logging.FileHandler(log_file)
        ]
    )
    logging.info(f"Training started with mode: {mode}, seed: {seed}")


# ============================================
# Save Training Information
# ============================================
def save_training_info(config, mode, mixed_cycle, opponent_agent_paths, results_dir):
    # Convert mixed_cycle to a serializable format (convert Mode enums to strings)
    if mixed_cycle is not None:
        mixed_cycle_serializable = []
        for opponent, mode_enum in mixed_cycle:
            mode_str = mode_enum.name if mode_enum else None
            mixed_cycle_serializable.append((opponent, mode_str))
    else:
        mixed_cycle_serializable = None

    training_info = {
        "training_config": config,
        "training_mode": mode,
        "mixed_cycle": mixed_cycle_serializable,
        "opponent_agent_paths": opponent_agent_paths
    }
    with open(os.path.join(results_dir, "training_info.json"), "w") as f:
        json.dump(training_info, f, indent=4)


# ============================================
# Training Loop
# ============================================
def get_trained_opponent_index(episode, num_trained_opponents):
    """
    Determines which trained opponent to use based on the episode number.
    Cycles through the available trained opponents.
    """
    if num_trained_opponents == 0:
        return None
    return (episode - 1) % num_trained_opponents


def main(
    mode="vs_strong",
    episodes=700,
    seed=42,
    save_model=True,
    training_config=None,
    load_agent_path=None,
    opponent_agent_paths=None  # Changed from single path to list of paths
):
    """
    Main function to train TD3 on the Custom Hockey Environment in various modes:
      - "vs_strong": Always train against the strong built-in opponent.
      - "vs_weak": Always train against the weak built-in opponent.
      - "shooting": Train in shooting mode (no built-in opponent needed).
      - "defense": Train in defense mode (no built-in opponent needed).
      - "mixed": Cycle between (strong, weak, trained1, trained2, ..., shooting, defense) each episode.

    Args:
        mode (str): Training mode ("vs_strong", "vs_weak", "shooting", "defense", "mixed").
        episodes (int): Number of training episodes.
        seed (int): Random seed for reproducibility.
        save_model (bool): Whether to save the trained models.
        training_config (dict): Dictionary containing training configurations.
        load_agent_path (str or None): If provided, loads an existing agent from this path before training.
        opponent_agent_paths (list of str or None): List of paths to trained agents to be used as opponents in mixed mode.

    Returns:
        dict: A dictionary of final evaluation results.
    """
    # Default training configuration if none provided
    if training_config is None:
        training_config = {
            "discount": 0.99,
            "tau": 0.005,
            "policy_noise": 0.2,
            "noise_clip": 0.5,
            "policy_freq": 2,
            "max_episodes": episodes,       # total training episodes
            "start_timesteps": 1000,        # number of initial random steps
            "eval_freq": 50,                # how often to evaluate (in episodes)
            "batch_size": 256,              # batch size for training
            "expl_noise_type": "pink",      # type of exploration noise: "gaussian", "pink", or "ou"
            "expl_noise": 0.1,              # exploration noise scale
            "pink_noise_params": {          # pink noise specific params
                "exponent": 1.0,
                "fmin": 0.0
            },
            "ou_noise_params": {            # OU noise specific params
                "mu": 0.0,
                "theta": 0.15,
                "sigma": 0.2
            },
            "use_layer_norm": True,         # toggle LayerNorm
            "ln_eps": 1e-5,                
            "save_model": save_model,
            "save_model_freq": 100,         # how often to save
            "use_rnd": True,                # toggle RND
            "rnd_weight": 1.0,
            "rnd_lr": 1e-4,
            "rnd_hidden_dim": 128,
            "max_episode_steps": 600
        }

    # Decide environment mode based on user choice
    # If user chooses "shooting" or "defense", we set env_mode accordingly. 
    # If user chooses "mixed", we will override the env mode each episode.
    if mode == "shooting":
        env_mode = h_env.Mode.TRAIN_SHOOTING
    elif mode == "defense":
        env_mode = h_env.Mode.TRAIN_DEFENSE
    else:
        # For "vs_strong", "vs_weak", or "mixed", use NORMAL
        env_mode = h_env.Mode.NORMAL

    # Create environment
    env = CustomHockeyEnv(mode=env_mode)

    # Set seeds
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Environment dimensions
    state_dim = env.observation_space.shape[0]
    action_dim = 4  # We only control the first 4 actions for the first agent
    max_action = float(env.action_space.high[0])

    # Directories
    base_results_dir = "./results_hockey"
    base_models_dir = "./models_hockey"
    os.makedirs(base_results_dir, exist_ok=True)
    os.makedirs(base_models_dir, exist_ok=True)

    file_name = f"TD3_Hockey_{mode}_seed_{seed}"
    results_dir = os.path.join(base_results_dir, mode, f"seed_{seed}")
    models_dir = os.path.join(base_models_dir, mode, f"seed_{seed}")
    os.makedirs(results_dir, exist_ok=True)
    if save_model:
        os.makedirs(models_dir, exist_ok=True)

    # Setup logging
    setup_logging(results_dir, seed, mode)

    logging.info(f"Training Mode: {mode}")
    logging.info(f"Environment Mode: {env_mode}")
    logging.info(f"State Dim: {state_dim}, Action Dim: {action_dim}, Max Action: {max_action}")

    # Save training info
    if mode == "mixed":
        mixed_cycle = [
            ("strong", h_env.Mode.NORMAL),
            ("strong", h_env.Mode.NORMAL),
            ("strong", h_env.Mode.NORMAL),
            ("weak", h_env.Mode.NORMAL),
            ("trained", h_env.Mode.NORMAL),
            (None,     h_env.Mode.TRAIN_SHOOTING),
            (None,     h_env.Mode.TRAIN_DEFENSE)
        ]
    else:
        mixed_cycle = None

    save_training_info(training_config, mode, mixed_cycle, opponent_agent_paths, results_dir)

    # Initialize the TD3 agent
    agent = TD3(
        state_dim=state_dim,
        action_dim=action_dim,
        max_action=max_action,
        training_config=training_config,
    )

    # If we have a path to load from, do so
    if load_agent_path is not None:
        logging.info(f"Loading agent from: {load_agent_path}")
        agent.load(load_agent_path)

    # Load the opponent agents if provided and mode is mixed
    opponent_agents = []
    if mode == "mixed" and opponent_agent_paths is not None:
        for idx, path in enumerate(opponent_agent_paths):
            logging.info(f"Loading opponent agent {idx + 1} from: {path}")
            trained_agent = TD3(
                state_dim=state_dim,
                action_dim=action_dim,
                max_action=max_action,
                training_config=training_config,
            )
            trained_agent.load(path)
            trained_agent.critic.eval()
            trained_agent.actor.eval()
            trained_agent.actor_target.eval()
            trained_agent.critic_target.eval()
            opponent_agents.append(trained_agent)

    # Replay buffer
    replay_buffer = ReplayBuffer(state_dim, action_dim, max_size=int(1e6))

    # Tracking
    total_timesteps = 0
    evaluation_results = []   # store training rewards (per-episode)
    loss_results = {
        "critic_loss": [],
        "actor_loss": []
    }

    # We track multiple win-rate curves for each scenario:
    # 1. vs Strong
    # 2. vs Weak
    # 3. vs Trained1, vs Trained2, etc.
    # 4. Shooting
    # 5. Defense
    evaluation_winrates = {
        "strong": [],
        "weak": [],
        # Dynamically add "trained_1", "trained_2", etc.
        "shooting": [],
        "defense": []
    }

    # Opponent function for the "mixed" mode
    # We'll cycle through different (opponent, env_mode)
    if mode == "mixed":
        mixed_cycle = [
            ("strong", h_env.Mode.NORMAL),
            ("strong", h_env.Mode.NORMAL),
            ("strong", h_env.Mode.NORMAL),
            ("weak", h_env.Mode.NORMAL),
            ("trained", h_env.Mode.NORMAL),
            (None,     h_env.Mode.TRAIN_SHOOTING),
            (None,     h_env.Mode.TRAIN_DEFENSE)
        ]
    else:
        mixed_cycle = None  # Not used for other modes

    pbar = tqdm(total=training_config["max_episodes"], desc=f"Training {mode} Seed {seed}")
    for episode in range(1, training_config["max_episodes"] + 1):
        # Decide opponent and environment mode for this episode
        if mode == "vs_strong":
            opponent_type = "strong"
            current_env_mode = env_mode
            selected_trained_agent = None
        elif mode == "vs_weak":
            opponent_type = "weak"
            current_env_mode = env_mode
            selected_trained_agent = None
        elif mode == "shooting":
            opponent_type = None
            current_env_mode = h_env.Mode.TRAIN_SHOOTING
            selected_trained_agent = None
        elif mode == "defense":
            opponent_type = None
            current_env_mode = h_env.Mode.TRAIN_DEFENSE
            selected_trained_agent = None
        elif mode == "mixed":
            # Cycle through different (opponent, env_mode)
            idx = (episode - 1) % len(mixed_cycle)
            cycle_opponent_type, cycle_env_mode = mixed_cycle[idx]
            
            if cycle_opponent_type == "trained" and len(opponent_agents) > 0:
                # Select which trained opponent to use
                trained_idx = get_trained_opponent_index(episode, len(opponent_agents))
                opponent_type = "trained"
                current_env_mode = cycle_env_mode
                selected_trained_agent = opponent_agents[trained_idx]
            else:
                opponent_type = cycle_opponent_type
                current_env_mode = cycle_env_mode
                selected_trained_agent = None

        # If needed, reinitialize environment for changed mode
        if mode == "mixed" and env.env.mode != current_env_mode:
            env.close()
            env = CustomHockeyEnv(mode=current_env_mode)
            env.seed(seed + episode)  # or some offset
            logging.info(f"Switched environment mode to: {current_env_mode}")

        # Create opponent if applicable
        if opponent_type in ["strong", "weak", "trained"]:
            opponent = get_opponent(opponent_type, env, trained_agent=selected_trained_agent)
        else:
            opponent = None  # no built-in opponent for shooting/defense

        state, info = env.reset()
        done = False
        episode_reward = 0
        episode_timesteps = 0

        # Reset noise if needed (OU or Pink)
        if (training_config["expl_noise_type"].lower() == "ou") and hasattr(agent, 'ou_noise'):
            agent.ou_noise.reset()
        if (training_config["expl_noise_type"].lower() == "pink") and hasattr(agent, 'pink_noise'):
            agent.pink_noise.reset()

        # Reset loss accumulators
        cumulative_critic_loss = 0.0
        cumulative_actor_loss = 0.0
        loss_steps = 0
        actor_loss_steps = 0

        while not done:
            episode_timesteps += 1
            total_timesteps += 1

            # Select action
            if total_timesteps < training_config["start_timesteps"]:
                action = env.env.action_space.sample()[:action_dim]
            else:
                action = agent.act(np.array(state), add_noise=True)

            # Opponent action if applicable
            if opponent is not None:
                opp_obs = env.env.obs_agent_two()
                opp_action = opponent.act(opp_obs)
                full_action = np.hstack([action, opp_action])
            else:
                # Shooting/Defense modes do not require an explicit opponent
                full_action = np.hstack([action, [0,0,0,0]])

            next_state, reward, done, info = env.step(full_action)
            done_bool = float(done) if episode_timesteps < training_config["max_episode_steps"] else 0

            # Store in replay buffer
            replay_buffer.add(state, action, next_state, reward, done_bool)

            state = next_state
            episode_reward += reward

            # Train the agent (after enough steps)
            if total_timesteps >= training_config["start_timesteps"]:
                critic_loss, actor_loss = agent.train(replay_buffer, training_config["batch_size"])
                cumulative_critic_loss += critic_loss
                loss_steps += 1
                if actor_loss is not None:
                    cumulative_actor_loss += actor_loss
                    actor_loss_steps += 1

        # Track losses
        avg_critic_loss = cumulative_critic_loss / loss_steps if loss_steps > 0 else 0
        avg_actor_loss = cumulative_actor_loss / actor_loss_steps if actor_loss_steps > 0 else 0
        loss_results["critic_loss"].append(avg_critic_loss)
        loss_results["actor_loss"].append(avg_actor_loss)

        # Track rewards
        evaluation_results.append(episode_reward)

        # Log per-episode information
        logging.info(f"Episode {episode} | Reward: {episode_reward:.2f} | Critic Loss: {avg_critic_loss:.4f} | Actor Loss: {avg_actor_loss:.4f}")

        # Evaluate at intervals
        if episode % training_config["eval_freq"] == 0:
            logging.info(f"===== Evaluation at Episode {episode} =====")
            # Evaluate vs STRONG
            stats_strong = eval_policy_extended(
                policy=agent, 
                eval_episodes=100,
                seed=seed + 10, 
                mode=h_env.Mode.NORMAL,
                opponent_type="strong"
            )
            evaluation_winrates["strong"].append(stats_strong["win_rate"])
            logging.info(f"  vs Strong  => WinRate: {stats_strong['win_rate']:.2f}")

            # Evaluate vs WEAK
            stats_weak = eval_policy_extended(
                policy=agent, 
                eval_episodes=100,
                seed=seed + 20, 
                mode=h_env.Mode.NORMAL,
                opponent_type="weak"
            )
            evaluation_winrates["weak"].append(stats_weak["win_rate"])
            logging.info(f"  vs Weak    => WinRate: {stats_weak['win_rate']:.2f}")

            # Evaluate each TRAINED opponent
            if mode == "mixed" and len(opponent_agents) > 0:
                for idx, trained_agent in enumerate(opponent_agents):
                    stats_trained = eval_policy_extended(
                        policy=agent, 
                        eval_episodes=100,
                        seed=seed + 25 + idx,  # Different seed for each opponent
                        mode=h_env.Mode.NORMAL,
                        opponent_type="trained",
                        trained_agent=trained_agent
                    )
                    key = f"trained_{idx + 1}"
                    evaluation_winrates[key] = evaluation_winrates.get(key, [])
                    evaluation_winrates[key].append(stats_trained["win_rate"])
                    logging.info(f"  vs Trained_{idx + 1} => WinRate: {stats_trained['win_rate']:.2f}")

            # Evaluate SHOOTING mode
            stats_shooting = eval_policy_extended(
                policy=agent, 
                eval_episodes=100,
                seed=seed + 30, 
                mode=h_env.Mode.TRAIN_SHOOTING,
                opponent_type=None
            )
            evaluation_winrates["shooting"].append(stats_shooting["win_rate"])
            logging.info(f"  Shooting   => WinRate: {stats_shooting['win_rate']:.2f}")

            # Evaluate DEFENSE mode
            stats_defense = eval_policy_extended(
                policy=agent, 
                eval_episodes=100,
                seed=seed + 40, 
                mode=h_env.Mode.TRAIN_DEFENSE,
                opponent_type=None
            )
            evaluation_winrates["defense"].append(stats_defense["win_rate"])
            logging.info(f"  Defense    => WinRate: {stats_defense['win_rate']:.2f}")

            # Save intermediate results
            np.save(os.path.join(results_dir, f"{file_name}_evaluations.npy"), evaluation_results)
            np.save(os.path.join(results_dir, f"{file_name}_winrates.npy"), evaluation_winrates)

            # Optionally save model
            if save_model and (episode % training_config["save_model_freq"] == 0):
                agent.save(os.path.join(models_dir, f"{file_name}_episode_{episode}.pth"))
                logging.info(f"Saved model at episode {episode}")

        pbar.update(1)

    pbar.close()

    # Optionally save final model
    if save_model:
        agent.save(os.path.join(models_dir, f"{file_name}_final.pth"))
        logging.info(f"Saved final model to {models_dir}")

    # ==========================
    # Post-Training Analysis
    # ==========================
    evaluation_results = np.array(evaluation_results)

    # Plot the loss curves
    plot_losses(loss_results, mode, results_dir)

    # Plot the episode rewards
    plot_rewards(evaluation_results, mode, results_dir, window=20)

    # Plot the stored multi-scenario win-rate curves (across training)
    if len(evaluation_winrates["strong"]) > 0:
        plot_multi_winrate_curves(evaluation_winrates, mode, results_dir)

    # Plot the final noise comparison (if OU or Pink)
    plot_noise_comparison(agent, training_config, mode, results_dir)

    # -- Final Extended Evaluation (again, for all scenarios) --
    logging.info("\n===== Final Extended Evaluation (100 episodes each scenario) =====")
    final_stats_strong = eval_policy_extended(agent, mode=h_env.Mode.NORMAL, opponent_type="strong", seed=seed+100)
    final_stats_weak = eval_policy_extended(agent, mode=h_env.Mode.NORMAL, opponent_type="weak", seed=seed+200)
    if mode == "mixed" and len(opponent_agents) > 0:
        for idx, trained_agent in enumerate(opponent_agents):
            final_stats_trained = eval_policy_extended(
                agent, 
                mode=h_env.Mode.NORMAL, 
                opponent_type="trained", 
                seed=seed+250+idx, 
                trained_agent=trained_agent
            )
            logging.info(f"  vs Trained_{idx + 1} => WinRate: {final_stats_trained['win_rate']:.2f}")
    final_stats_shooting = eval_policy_extended(agent, mode=h_env.Mode.TRAIN_SHOOTING, opponent_type=None, seed=seed+300)
    final_stats_defense = eval_policy_extended(agent, mode=h_env.Mode.TRAIN_DEFENSE, opponent_type=None, seed=seed+400)

    logging.info(f"  vs Strong  => WinRate: {final_stats_strong['win_rate']:.2f}")
    logging.info(f"  vs Weak    => WinRate: {final_stats_weak['win_rate']:.2f}")
    logging.info(f"  Shooting   => WinRate: {final_stats_shooting['win_rate']:.2f}")
    logging.info(f"  Defense    => WinRate: {final_stats_defense['win_rate']:.2f}")

    # Store final results in the dictionary
    final_evaluation_results = {
        "final_eval_strong": final_stats_strong,
        "final_eval_weak": final_stats_weak,
        "final_eval_shooting": final_stats_shooting,
        "final_eval_defense": final_stats_defense,
    }

    if mode == "mixed" and len(opponent_agents) > 0:
        for idx, trained_agent in enumerate(opponent_agents):
            # Ensure the final_stats_trained is correctly stored
            key = f"final_eval_trained_{idx + 1}"
            final_evaluation_results[key] = eval_policy_extended(
                agent, 
                mode=h_env.Mode.NORMAL, 
                opponent_type="trained", 
                seed=seed+250+idx, 
                trained_agent=trained_agent
            )

    final_evaluation_results.update({
        "loss_results": loss_results,
        "training_rewards": evaluation_results.tolist(),
        "winrates_vs_scenarios": evaluation_winrates
    })

    # Plot the stored multi-scenario win-rate curves (across training)
    if len(evaluation_winrates["strong"]) > 0:
        plot_multi_winrate_curves(evaluation_winrates, mode, results_dir)

    # Save final evaluation dictionary
    with open(os.path.join(results_dir, f"{file_name}_final_evaluations.json"), "w") as f:
        json.dump(final_evaluation_results, f, indent=4)

    logging.info("Training completed successfully.")
    return final_evaluation_results


# ===============================
# Example Usage
# ===============================
if __name__ == "__main__":
    mode = "mixed"  # options: "vs_weak", "vs_strong", "shooting", "defense", "mixed"
    episodes = 70000
    seed = 44
    save_model = True

    custom_training_config = {
        "discount": 0.99,
        "tau": 0.005,
        "policy_noise": 0.2,
        "noise_clip": 0.5,
        "policy_freq": 2,
        "max_episodes": episodes,
        "start_timesteps": 1000,
        "eval_freq": 200,      # Evaluate every 200 episodes
        "batch_size": 256,
        "expl_noise_type": "pink",
        "expl_noise": 0.1,
        "pink_noise_params": {"exponent": 1.0, "fmin": 0.0},
        "ou_noise_params": {"mu": 0.0, "theta": 0.15, "sigma": 0.2},
        "use_layer_norm": True,
        "ln_eps": 1e-5,
        "save_model": save_model,
        "save_model_freq": 10000,
        "use_rnd": True,
        "rnd_weight": 1.0,
        "rnd_lr": 1e-4,
        "rnd_hidden_dim": 128,
        "max_episode_steps": 600
    }

    # Paths to the trained opponent agents
    trained_opponent_paths = [
        "models_hockey/vs_strong/seed_42/TD3_Hockey_vs_strong_seed_42_final.pth",
        "models_hockey/mixed/seed_43/TD3_Hockey_mixed_seed_43_final.pth",
        # Add more paths as needed
    ]

    final_results = main(
        mode=mode,
        episodes=episodes,
        seed=seed,
        save_model=save_model,
        training_config=custom_training_config,
        load_agent_path=None,  # or "models_hockey/vs_strong/seed_42/TD3_Hockey_vs_strong_seed_42_final.pth" to resume
        opponent_agent_paths=trained_opponent_paths  # Provide the list of paths to the trained agents
    )

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  self.critic.load_state_dict(torch.load(filename + "_critic.pth"))
  self.critic_optimizer.load_state_dict(torch.load(filename + "_critic_optimizer.pth"))
  self.actor.load_state_dict(torch.load(filename + "_actor.pth"))
  self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer.pth"))
  self.rnd.target_network.load_state_dict(torch.load(filename + "_rnd_target.pth"))
  self.rnd.predictor_network.load_state_dict(torch.load(filename + "_rnd_predictor.pth"))
  self.rnd.optimizer.load_state_dict(torch.load(filename + "_rnd_optimizer.pth"))
Training mixed Seed 44:   0%|          | 12/70000 [00:02<4:43:55,  4.11it/s]

KeyboardInterrupt: 

In [2]:
def evaluate_trained_agent(
    agent_path, 
    opponent_types,  # List of opponent types to evaluate against
    trained_opponent_paths=None,  # Dict mapping opponent_type to list of trained opponent paths
    num_games=100, 
    env_mode=h_env.Mode.NORMAL,
    seed=42
):
    """
    Evaluates the trained TD3 agent against specified opponents independently, ensuring side consistency.

    Args:
        agent_path (str): Path to the trained agent model file.
        opponent_types (list of str): List of opponent types to evaluate against ("weak", "strong", "shooting", "defense", "trained").
        trained_opponent_paths (dict of str: list of str or None): 
            - Keys are opponent types (e.g., "trained").
            - Values are lists of paths to trained agent models for those types.
            - For non-trained opponents, keys can be omitted or set to None.
        num_games (int): Number of games to play for each evaluation per role.
        env_mode (h_env.Mode): The environment mode to use for evaluation.
        seed (int): Random seed for reproducibility.

    Returns:
        dict: A nested dictionary containing statistics for each opponent type and side.
    """
    # Initialize the environment
    env = CustomHockeyEnv(mode=env_mode)
    env.seed(seed)  # Parameterize seed if needed

    # Set up directories for saving evaluation results
    base_results_dir = "./evaluation_results_hockey"
    os.makedirs(base_results_dir, exist_ok=True)

    # Load the trained agent
    state_dim = env.observation_space.shape[0]
    action_dim = 4  # Assuming action_dim is 4 as per training
    max_action = float(env.action_space.high[0])

    # Define a basic training configuration for initializing the agent
    custom_training_config = {
        "discount": 0.99,
        "tau": 0.005,
        "policy_noise": 0.2,
        "noise_clip": 0.5,
        "policy_freq": 2,
        "max_episodes": 70000,  # Not used in evaluation
        "start_timesteps": 1000,  # Not used in evaluation
        "eval_freq": 200,  # Not used in evaluation
        "batch_size": 256,
        "expl_noise_type": "pink",
        "expl_noise": 0.1,
        "pink_noise_params": {
            "exponent": 1.0,
            "fmin": 0.0
        },
        "ou_noise_params": {
            "mu": 0.0,
            "theta": 0.15,
            "sigma": 0.2
        },
        "use_layer_norm": True,
        "ln_eps": 1e-5,
        "save_model": False,  # Not saving during evaluation
        "save_model_freq": 100,
        "use_rnd": True,
        "rnd_weight": 1.0,
        "rnd_lr": 1e-4,
        "rnd_hidden_dim": 128,
        "max_episode_steps": 600
    }

    agent = TD3(
        state_dim=state_dim,
        action_dim=action_dim,
        max_action=max_action,
        training_config=custom_training_config
    )

    agent.load(agent_path)
    logging.info(f"Loaded trained agent from {agent_path}")

    # Set networks to evaluation mode
    agent.actor.eval()
    agent.critic.eval()

    # Load trained opponents if any
    trained_opponents = {}
    if trained_opponent_paths:
        for opp_type, paths in trained_opponent_paths.items():
            if opp_type not in opponent_types:
                continue  # Only load trained opponents for specified types
            trained_opponents[opp_type] = []
            for path in paths:
                trained_agent = TD3(
                    state_dim=state_dim,
                    action_dim=action_dim,
                    max_action=max_action,
                    training_config=custom_training_config
                )
                trained_agent.load(path)
                trained_agent.actor.eval()
                trained_agent.critic.eval()
                trained_opponents[opp_type].append(trained_agent)
                logging.info(f"Loaded trained opponent agent from {path} for type '{opp_type}'")
                
    print(trained_opponents)
    # Initialize a dictionary to hold all statistics
    all_stats = {}

    # Iterate over each opponent type
    for opponent_type in opponent_types:
        logging.info(f"\n===== Evaluating against {opponent_type} opponent =====")

        # Setup logging for this opponent type
        results_dir = os.path.join(base_results_dir, opponent_type)
        os.makedirs(results_dir, exist_ok=True)
        setup_evaluation_logging(results_dir, seed, opponent_type)

        # Initialize statistics containers
        stats_agent1 = {'win': 0, 'loss': 0, 'draw': 0}

        # Determine if the opponent type requires arbitrary actions
        requires_arbitrary_actions = opponent_type in ["shooting", "defense"]

        # Determine opponents based on type
        if opponent_type == "trained":
            # Use the list of trained opponents for this type
            if opponent_type in trained_opponents and trained_opponents[opponent_type]:
                opponents = trained_opponents[opponent_type]
            else:
                logging.warning(f"No trained opponents provided for type '{opponent_type}'. Skipping evaluation.")
                continue
        elif requires_arbitrary_actions:
            # No opponent agent; actions will be arbitrary
            opponents = [None]  # Placeholder
        else:
            # For built-in opponents, instantiate once
            opponent = get_opponent(opponent_type, env)
            opponents = [opponent]

        # Evaluation Loop for Agent1 (Left Side)
        for idx, current_opponent in enumerate(opponents):
            # For 'shooting' and 'defense' modes, do not switch sides
            is_switchable = False  # Always evaluate as Agent1

            # If 'trained', evaluate separately for each trained opponent
            if opponent_type == "trained":
                opponent_identifier = f"{opponent_type}_{idx + 1}"
            else:
                opponent_identifier = opponent_type

            logging.info(f"\n--- Evaluating against {opponent_identifier} ---")

            # Set the environment mode accordingly
            if opponent_type == "shooting":
                current_env_mode = h_env.Mode.TRAIN_SHOOTING
            elif opponent_type == "defense":
                current_env_mode = h_env.Mode.TRAIN_DEFENSE
            else:
                current_env_mode = h_env.Mode.NORMAL

            # Reinitialize the environment with the appropriate mode
            env.close()
            env = CustomHockeyEnv(mode=current_env_mode)
            env.seed(seed + idx)  # Different seed per opponent to ensure diversity

            # Log the current environment mode
            logging.info(f"Environment mode set to: {current_env_mode}")

            # Initialize evaluation statistics for this opponent
            stats_agent1_opponent = {'win': 0, 'loss': 0, 'draw': 0}

            # Evaluate as Agent1 (Left Side)
            logging.info(f"Evaluating as Agent1 (Left Side) against {opponent_identifier}")
            for game in tqdm(range(1, num_games + 1), desc=f"{opponent_identifier} - Agent1"):
                state, info = env.reset()
                done = False

                episode_reward1 = 0

                while not done:
                    # Agent1 selects action
                    agent_action = agent.act(np.array(state), add_noise=False)

                    # Opponent selects action
                    if requires_arbitrary_actions:
                        opponent_action = get_arbitrary_action(opponent_type)
                    else:
                        opponent_action = current_opponent.act(env.env.obs_agent_two())

                    # Combine actions
                    full_action = np.hstack([agent_action, opponent_action])

                    # Step environment
                    next_state, reward, done, info = env.step(full_action)
                    episode_reward1 += reward

                    # Prepare for next step
                    state = next_state

                # Determine the outcome
                winner = info.get('winner', 0)  # 1 => agent1 wins, -1 => agent1 loses, 0 => draw

                if winner == 1:
                    stats_agent1_opponent['win'] += 1
                elif winner == -1:
                    stats_agent1_opponent['loss'] += 1
                else:
                    stats_agent1_opponent['draw'] += 1

                # Log per-game results (optional)
                logging.debug(f"Game {game} as Agent1: Reward1={episode_reward1:.2f}, Winner={winner}")

            # Calculate statistics for Agent1
            win_rate_agent1 = stats_agent1_opponent['win'] / num_games
            loss_rate_agent1 = stats_agent1_opponent['loss'] / num_games
            draw_rate_agent1 = stats_agent1_opponent['draw'] / num_games

            stats = {}
            stats["Agent1"] = {
                "total_games": num_games,
                "wins": stats_agent1_opponent['win'],
                "losses": stats_agent1_opponent['loss'],
                "draws": stats_agent1_opponent['draw'],
                "win_rate": win_rate_agent1,
                "loss_rate": loss_rate_agent1,
                "draw_rate": draw_rate_agent1,
                "mean_reward": None,  # Placeholder if you decide to track rewards
                "std_reward": None
            }

            # Aggregate statistics
            all_stats[opponent_identifier] = stats

            # Log the statistics
            logging.info("\n===== Evaluation Statistics =====")
            for side, side_stats in stats.items():
                logging.info(f"\n{side} Statistics:")
                for key, value in side_stats.items():
                    if value is not None:
                        logging.info(f"  {key}: {value:.4f}")
                    else:
                        logging.info(f"  {key}: N/A")

            # Plot overall statistics
            plot_overall_statistics(
                overall_stats=stats,
                opponent_type=opponent_identifier,
                save_path=results_dir
            )

            # Save detailed evaluation results
            evaluation_results = {
                "Agent1": {
                    "wins": stats_agent1_opponent['win'],
                    "losses": stats_agent1_opponent['loss'],
                    "draws": stats_agent1_opponent['draw']
                },
                "Statistics": stats
            }

            with open(os.path.join(results_dir, f"evaluation_results_{opponent_identifier}.json"), "w") as f:
                json.dump(evaluation_results, f, indent=4)

    # After evaluating all opponent types, reset the environment to NORMAL mode
    env.close()
    env = CustomHockeyEnv(mode=h_env.Mode.NORMAL)
    env.seed(seed)
    logging.info("Environment reset to NORMAL mode.")

    # Save all statistics to a master JSON file
    master_stats_path = os.path.join(base_results_dir, "master_evaluation_stats.json")
    with open(master_stats_path, "w") as f:
        json.dump(all_stats, f, indent=4)

    logging.info("\n===== All Evaluations Completed =====")
    return all_stats


def get_arbitrary_action(opponent_type):
    """
    Returns an arbitrary action based on the opponent type.
    For 'shooting', returns a fixed shooting action.
    For 'defense', returns a fixed defensive action.
    """
    if opponent_type == "shooting":
        return np.array([1, 0, 0, 1])  # Example shooting action
    elif opponent_type == "defense":
        return np.array([0.1, 0, 0, 1])  # Example defensive action
    else:
        # Default arbitrary action
        return np.zeros(4)

In [6]:
agent_path = "models_hockey/mixed/seed_44/TD3_Hockey_mixed_seed_44_final.pth"
opponent_types = ["strong", "weak", "shooting", "defense", "trained"]
num_games = 1000  # Number of evaluation games
env_mode = h_env.Mode.NORMAL  # Or other modes like TRAIN_SHOOTING, TRAIN_DEFENSE
trained_opponent_agent_paths = {
    "trained": [
        "models_hockey/vs_strong/seed_42/TD3_Hockey_vs_strong_seed_42_final.pth",
        "models_hockey/mixed/seed_43/TD3_Hockey_mixed_seed_43_final.pth",
        "models_hockey/mixed/seed_44/TD3_Hockey_mixed_seed_44_final.pth"
        # Add more paths as needed
    ]
}
seed = 44

evaluation_stats = evaluate_trained_agent(
        agent_path=agent_path,
        opponent_types=opponent_types,
        trained_opponent_paths=trained_opponent_agent_paths,  # Provide the dict of trained opponent paths
        num_games=num_games,
        env_mode=env_mode,
        seed=seed
    )

# Print the evaluation statistics
print("\n===== Final Evaluation Statistics =====")
for opponent, stats in evaluation_stats.items():
    print(f"\nOpponent Identifier: {opponent}")
    for side, side_stats in stats.items():
        print(f"  {side}:")
        for key, value in side_stats.items():
            if value is not None:
                print(f"    {key}: {value:.4f}")
            else:
                print(f"    {key}: N/A")

{'trained': [<TD3.td3.TD3 object at 0x7f60249308d0>, <TD3.td3.TD3 object at 0x7f6055715710>, <TD3.td3.TD3 object at 0x7f60244034d0>]}


strong - Agent1: 100%|██████████| 1000/1000 [00:22<00:00, 43.76it/s]
weak - Agent1: 100%|██████████| 1000/1000 [00:21<00:00, 46.24it/s]
shooting - Agent1: 100%|██████████| 1000/1000 [00:11<00:00, 90.72it/s]
defense - Agent1: 100%|██████████| 1000/1000 [00:09<00:00, 102.79it/s]
trained_1 - Agent1: 100%|██████████| 1000/1000 [00:33<00:00, 30.14it/s]
trained_2 - Agent1: 100%|██████████| 1000/1000 [00:23<00:00, 43.42it/s]
trained_3 - Agent1: 100%|██████████| 1000/1000 [00:48<00:00, 20.67it/s]



===== Final Evaluation Statistics =====

Opponent Identifier: strong
  Agent1:
    total_games: 1000.0000
    wins: 977.0000
    losses: 20.0000
    draws: 3.0000
    win_rate: 0.9770
    loss_rate: 0.0200
    draw_rate: 0.0030
    mean_reward: N/A
    std_reward: N/A

Opponent Identifier: weak
  Agent1:
    total_games: 1000.0000
    wins: 991.0000
    losses: 6.0000
    draws: 3.0000
    win_rate: 0.9910
    loss_rate: 0.0060
    draw_rate: 0.0030
    mean_reward: N/A
    std_reward: N/A

Opponent Identifier: shooting
  Agent1:
    total_games: 1000.0000
    wins: 974.0000
    losses: 11.0000
    draws: 15.0000
    win_rate: 0.9740
    loss_rate: 0.0110
    draw_rate: 0.0150
    mean_reward: N/A
    std_reward: N/A

Opponent Identifier: defense
  Agent1:
    total_games: 1000.0000
    wins: 966.0000
    losses: 21.0000
    draws: 13.0000
    win_rate: 0.9660
    loss_rate: 0.0210
    draw_rate: 0.0130
    mean_reward: N/A
    std_reward: N/A

Opponent Identifier: trained_1
  Agent1: