# In [None]:  (notebook cell marker left over from export; kept as a comment so the file parses as Python)
import collections
import time
import typing as tt
from dataclasses import dataclass
import os

import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Categorical
import torch.optim as optim
from tqdm import tqdm

from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.envs.unity_parallel_env import UnityParallelEnv
from mlagents_envs.envs.custom_side_channel import CustomDataChannel, StringSideChannel
import wandb

# --- Config & Hyperparameters ---
# Single source of truth for the run. It is handed to `wandb.init(config=...)`
# in the entry point, and most of the code reads the resulting `run.config`
# via attribute access (e.g. `config.GAMMA`).
config_dict = {
    # --- Run Identification ---
    # Used as the W&B run name and as the checkpoint sub-directory name.
    "RUN_NAME": "d4x21vnd_final_optimized", 

    # --- PPO & GAE ---
    "GAMMA": 0.99,            # discount factor used in the GAE recursion
    "GAE_LAMBDA": 0.95,       # GAE lambda (bias/variance trade-off)
    "PPO_EPSILON": 0.2,       # clip range for the probability ratio
    "PPO_EPOCHS": 6,          # passes over each rollout
    "PPO_BATCH_SIZE": 1024,   # minibatch size within an epoch
    
    "VALUE_LOSS_COEF": 0.5,   # weight of the critic MSE term in the total loss
    
    # Entropy bonus coefficient, linearly annealed from START to END over the
    # planned number of updates when ENTROPY_ANNEAL is True.
    "ENTROPY_ANNEAL": True,
    "ENTROPY_COEF_START": 0.02,   
    "ENTROPY_COEF_END": 0.005,    

    # Adam learning rate, linearly annealed down to MIN_LEARNING_RATE when
    # LR_SCHEDULE_ANNEAL is True.
    "LEARNING_RATE": 1e-4,
    "MIN_LEARNING_RATE": 1e-5,   
    "LR_SCHEDULE_ANNEAL": True,
    
    # Divided by NUM_ENVS * PPO_STEPS_PER_ENV to obtain the number of updates.
    "TOTAL_TIMESTEPS": 50_000_000, 

    # --- Environment & Sampling ---
    "SKIP_N": 4,                 # each chosen action is repeated for this many env steps
    "STACK_N": 4,                # number of preprocessed frames stacked into one state
    "MAX_EPISODE_STEPS": 2000,   # raw env steps after which an episode is force-ended
    
    "NUM_ENVS": 16,              # number of Unity processes launched
    "PPO_STEPS_PER_ENV": 1024,   # agent steps collected per env per update
    # Env i listens on BASE_PORT + i * PORT_STRIDE.
    "BASE_PORT": 5004,
    "PORT_STRIDE": 10,
    # Machine-specific path to the Unity build.
    "EXECUTABLE_PATH": r"C:\Users\tan04\Downloads\BuildTraining\BuildTraining\dp.exe",

    # --- Model Management & Evaluation ---
    "SAVE_EVERY_FRAMES": 100_000,   # minimum frames between periodic checkpoints
    "MODEL_SAVE_PATH": "saved_models_ppo_dual_resnet_final",   # checkpoint root directory

    # --- Gating / Balancing Mechanism (Hysteresis) ---
    # score_diff = mean episode reward of p0 minus that of p1 (recent episodes).
    # p0 stops training when score_diff >  SCORE_GATING_THRESHOLD and resumes
    # when score_diff <  SCORE_RESUME_THRESHOLD; p1 uses the mirrored tests
    # (score_diff < -SCORE_GATING_THRESHOLD / score_diff > -SCORE_RESUME_THRESHOLD).
    "SCORE_GATING_THRESHOLD": 2.0, 
    "SCORE_RESUME_THRESHOLD": -1, 
    # Gating is only evaluated once this many finished episodes are buffered.
    "GATING_WARMUP_EPISODES": 100, 
}


class CheckpointManager:
    """Saves and restores the full training state so a run can be resumed.

    A single file, ``<base_model_path>/<run_name>/checkpoint.pth``, holds both
    networks, both optimizers, the frame/update counters, the W&B run id and
    the gating flags. The file is replaced atomically on every save so that an
    interruption while writing never destroys the previous checkpoint.
    """

    def __init__(self, base_model_path: str, run_name: str):
        """Resolve the checkpoint location and make sure its directory exists.

        Args:
            base_model_path: Root directory under which runs are stored.
            run_name: Name of this run; used as the sub-directory name.
        """
        self.checkpoint_dir = os.path.join(base_model_path, run_name)
        self.checkpoint_file = os.path.join(self.checkpoint_dir, "checkpoint.pth")
        if not os.path.isdir(self.checkpoint_dir):
            # exist_ok avoids a crash if another process creates the directory
            # between the check above and this call.
            os.makedirs(self.checkpoint_dir, exist_ok=True)
            print(f"Created checkpoint directory: {self.checkpoint_dir}")

    def save_checkpoint(self, net_p0: nn.Module, net_p1: nn.Module, 
                        optimizer_p0: optim.Optimizer, optimizer_p1: optim.Optimizer,
                        total_frames: int, update: int, wandb_run_id: str,
                        p0_active: bool, p1_active: bool):
        """Save the complete state needed to resume training, including gating status.

        The state is first written to ``checkpoint.pth.tmp`` and then moved over
        the real file with ``os.replace``. If serialization fails, the temporary
        file is removed, the previous checkpoint is left untouched and the
        original exception is re-raised.

        Args:
            net_p0, net_p1: The two players' networks.
            optimizer_p0, optimizer_p1: Their optimizers.
            total_frames: Environment frames consumed so far.
            update: Index of the update this state corresponds to.
            wandb_run_id: W&B run id, stored so the same run can be resumed.
            p0_active, p1_active: Current gating flags for each player.
        """
        state = {
            'net_p0_state_dict': net_p0.state_dict(),
            'net_p1_state_dict': net_p1.state_dict(),
            'optimizer_p0_state_dict': optimizer_p0.state_dict(),
            'optimizer_p1_state_dict': optimizer_p1.state_dict(),
            'total_frames': total_frames,
            'update': update,
            'wandb_run_id': wandb_run_id,
            'p0_active': p0_active, 
            'p1_active': p1_active 
        }
        print(f"Saving checkpoint to {self.checkpoint_file} at frame {total_frames}...")
        # The directory may have been removed since construction.
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        tmp_file = self.checkpoint_file + ".tmp"
        try:
            torch.save(state, tmp_file)
            # Atomic swap: readers see either the old or the new checkpoint,
            # never a half-written one.
            os.replace(tmp_file, self.checkpoint_file)
        except BaseException:
            # Includes KeyboardInterrupt: drop the partial file, keep the old one.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)
            raise
        print("Checkpoint saved.")

    def load_checkpoint(self, device: torch.device) -> tt.Optional[dict]:
        """Load the checkpoint if it exists.

        Args:
            device: Device onto which stored tensors are mapped.

        Returns:
            The saved state dictionary, or ``None`` when no checkpoint file exists.
        """
        if not os.path.exists(self.checkpoint_file):
            print("No checkpoint found. Starting a new training run.")
            return None

        print(f"Loading checkpoint from {self.checkpoint_file}...")
        checkpoint = torch.load(self.checkpoint_file, map_location=device)
        print("Checkpoint loaded.")
        return checkpoint

# --- Neural Network Classes ---
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with an identity skip connection.

    Both convolutions use stride 1 and padding 1 and keep the channel count,
    so the output has exactly the same shape as the input.
    """

    def __init__(self, channels: int):
        super().__init__()
        same_shape_conv = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **same_shape_conv)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, **same_shape_conv)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(conv2(relu(conv1(x))) + x)``."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        # Add the untouched input back in before the final activation.
        return self.relu(self.conv2(hidden) + x)

class ActorCriticResNet(nn.Module):
    """Shared convolutional trunk with three categorical policy heads and a value head.

    Args:
        input_shape: ``(channels, height, width)`` of one stacked observation.
        n_actions_per_head: Number of choices for each of the three action
            heads; entries 0, 1 and 2 are used.
    """

    def __init__(self, input_shape: tt.Tuple[int, int, int], n_actions_per_head: tt.List[int]):
        super().__init__()
        in_channels, _height, _width = input_shape

        trunk_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            ResidualBlock(64),
            ResidualBlock(64),
            nn.Flatten(),
        ]
        self.conv = nn.Sequential(*trunk_layers)

        # Probe the trunk with a dummy input to learn the flattened feature size.
        with torch.no_grad():
            probe = torch.zeros(1, *input_shape)
            conv_out_size = self.conv(probe).size(-1)

        hidden_size = 512
        self.fc_base = nn.Sequential(nn.Linear(conv_out_size, hidden_size), nn.ReLU())

        self.action_head_0 = nn.Linear(hidden_size, n_actions_per_head[0])
        self.action_head_1 = nn.Linear(hidden_size, n_actions_per_head[1])
        self.action_head_2 = nn.Linear(hidden_size, n_actions_per_head[2])
        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> tt.Tuple[tt.Tuple[torch.Tensor, ...], torch.Tensor]:
        """Return ``(logits, value)`` for a batch of observations.

        The input is cast to float and divided by 255 before entering the trunk.
        ``logits`` is a 3-tuple (one tensor per action head); ``value`` is the
        output of the value head, with a trailing dimension of size 1.
        """
        scaled = x.float() / 255.0
        base_features = self.fc_base(self.conv(scaled))
        policy_heads = (self.action_head_0, self.action_head_1, self.action_head_2)
        logits = tuple(head(base_features) for head in policy_heads)
        return logits, self.value_head(base_features)


# --- Environment Wrappers ---
class UnityEnvWrapper:
    """Adapter around one two-agent ``UnityParallelEnv``.

    It exposes a single preprocessed frame (taken from the first agent's
    observation), per-agent float rewards, and one combined ``done`` flag that
    is also raised once ``max_episode_steps`` raw steps have elapsed.
    """

    def __init__(self, env: UnityParallelEnv, max_episode_steps: int):
        self.env = env
        self.max_episode_steps = max_episode_steps
        self.agent_ids = self.env.agents
        # The first two agents reported by the env are treated as player 0 / player 1.
        self.p0_id = self.agent_ids[0]
        self.p1_id = self.agent_ids[1]
        self._step_count = 0

    @staticmethod
    def _preprocess_obs(img: np.ndarray) -> np.ndarray:
        """Convert a 3-channel image to a cropped single-channel uint8 frame.

        Channels 0/1/2 are mixed with weights 0.299/0.587/0.114, rows
        ``24:-6`` and columns ``18:-18`` are kept, and the result is multiplied
        by 255 before the uint8 cast (assumes inputs are scaled to [0, 1] —
        TODO confirm against the Unity sensor settings).
        """
        red, green, blue = img[0], img[1], img[2]
        gray = 0.299 * red + 0.587 * green + 0.114 * blue
        cropped = gray[np.newaxis, 24:-6, 18:-18]
        return (cropped * 255).astype(np.uint8)

    def _p0_frame(self, observations) -> np.ndarray:
        """Extract and preprocess player 0's first observation entry."""
        raw = observations[self.p0_id]['observation'][0]
        return self._preprocess_obs(raw)

    def reset(self) -> np.ndarray:
        """Reset the env and the step counter; return the first preprocessed frame."""
        self._step_count = 0
        return self._p0_frame(self.env.reset())

    def step(self, actions: dict) -> tt.Tuple[np.ndarray, dict, bool, dict]:
        """Advance the env by one raw step.

        Returns:
            ``(frame, rewards, done, info)`` where ``rewards`` maps each player
            id to a float (0.0 if the env reported none) and ``info`` is always
            an empty dict.
        """
        observations, rewards, dones, _infos = self.env.step(actions)
        frame = self._p0_frame(observations)
        step_rewards = {
            agent_id: float(rewards.get(agent_id, 0.0))
            for agent_id in (self.p0_id, self.p1_id)
        }
        # NOTE(review): relies on the env reporting a '__all__' key in `dones`;
        # if it only reports per-agent flags, episodes end solely via the step
        # limit below — confirm against UnityParallelEnv.
        env_finished = bool(dones.get('__all__', False))
        self._step_count += 1
        hit_step_limit = self._step_count >= self.max_episode_steps
        return frame, step_rewards, env_finished or hit_step_limit, {}

    def close(self):
        """Close the underlying Unity environment."""
        self.env.close()

class Multi_Env_With_SkipN_and_Stack:
    """Runs several Unity environments in lock-step with frame skipping and stacking.

    Every call to :meth:`step` repeats the chosen action ``SKIP_N`` times in each
    environment, sums the rewards, and pushes the last frame onto a per-env
    stack of ``STACK_N`` frames. Finished environments are reset immediately and
    their stack is refilled with the first frame of the new episode.

    Attributes:
        envs: The wrapped environments, one per Unity process.
        p0_id, p1_id: Agent ids of the two players (taken from the first env).
        states: The most recent stacked state of every environment; kept in sync
            by both :meth:`reset` and :meth:`step`.
    """

    def __init__(self, config):
        """Launch ``config.NUM_ENVS`` Unity processes and reset them all.

        If anything fails part-way through start-up, every environment launched
        so far is closed before the exception propagates, so no orphaned Unity
        process keeps holding its port.
        """
        self.config = config
        self.envs: tt.List[UnityEnvWrapper] = []
        self._frame_stacks: tt.List[collections.deque] = []
        
        print(f"Starting {self.config.NUM_ENVS} environments...")
        try:
            for i in range(self.config.NUM_ENVS):
                port = self.config.BASE_PORT + i * self.config.PORT_STRIDE
                channel = CustomDataChannel()
                string_channel = StringSideChannel()
                # Project-specific side-channel message queued before launch; its
                # meaning is defined by CustomDataChannel (not visible here).
                channel.send_data(serve=212, p1=0, p2=0) 
                unity_env = UnityEnvironment(file_name=self.config.EXECUTABLE_PATH, base_port=port, side_channels=[string_channel, channel])
                try:
                    wrapped_env = UnityEnvWrapper(UnityParallelEnv(unity_env), self.config.MAX_EPISODE_STEPS)
                except BaseException:
                    # Not yet tracked in self.envs, so close it here.
                    unity_env.close()
                    raise
                self.envs.append(wrapped_env)
                self._frame_stacks.append(collections.deque(maxlen=self.config.STACK_N))
                print(f"  - Env {i} started on port {port}.")

            first_env = self.envs[0]
            self.p0_id, self.p1_id = first_env.p0_id, first_env.p1_id

            self.states = self.reset()
        except BaseException:
            # Best-effort shutdown of whatever was already started.
            for env in self.envs:
                try:
                    env.close()
                except Exception as close_err:
                    print(f"  - Failed to close an environment during cleanup: {close_err}")
            raise
        
    def _get_env_state(self, frame_stack: collections.deque) -> np.ndarray:
        """Concatenate the stacked frames along axis 0 into one state array."""
        return np.concatenate(list(frame_stack), axis=0)

    def _refill_stack(self, env_index: int, frame: np.ndarray) -> None:
        """Start a fresh stack for one env consisting of ``STACK_N`` copies of ``frame``."""
        stack = self._frame_stacks[env_index]
        stack.clear()
        stack.extend([frame] * self.config.STACK_N)

    def reset(self) -> tt.List[np.ndarray]:
        """Reset every environment and return the list of initial stacked states."""
        for i, env in enumerate(self.envs):
            self._refill_stack(i, env.reset())
        self.states = [self._get_env_state(fs) for fs in self._frame_stacks]
        return self.states

    @staticmethod
    def _decode_action(action_index: int) -> np.ndarray:
        """Split a flat index into three base-3 digits ``[a0, a1, a2]`` (a0 least significant)."""
        a0 = action_index % 3
        a1 = (action_index // 3) % 3
        a2 = (action_index // 9) % 3
        return np.array([a0, a1, a2], dtype=np.int32)

    def step(self, action_idxs_p0: tt.List[int], action_idxs_p1: tt.List[int]) -> tt.Tuple[tt.List[np.ndarray], dict, tt.List[bool]]:
        """Apply one action per player per env, repeated for ``SKIP_N`` raw steps.

        Args:
            action_idxs_p0: Flat action index for player 0 in each env.
            action_idxs_p1: Flat action index for player 1 in each env.

        Returns:
            ``(states, rewards, dones)``: the new stacked state of every env, a
            dict mapping each player id to a list of summed rewards (each
            player's reward is its own raw reward minus the opponent's), and a
            per-env flag telling whether the episode ended during this step.
        """
        num_envs = len(self.envs)
        cum_p0_rewards = [0.0] * num_envs
        cum_p1_rewards = [0.0] * num_envs
        dones = [False] * num_envs
        last_frames = [None] * num_envs

        decoded_p0_actions = [self._decode_action(idx) for idx in action_idxs_p0]
        decoded_p1_actions = [self._decode_action(idx) for idx in action_idxs_p1]

        for _ in range(self.config.SKIP_N):
            for i, env in enumerate(self.envs):
                if dones[i]:
                    # Episode already ended earlier in this skip window.
                    continue
                env_actions = {self.p0_id: decoded_p0_actions[i], self.p1_id: decoded_p1_actions[i]}
                new_frame, rewards, done, _ = env.step(env_actions)

                # Zero-sum shaping: each player is credited with the reward margin.
                cum_p0_rewards[i] += rewards[self.p0_id] - rewards[self.p1_id]
                cum_p1_rewards[i] += rewards[self.p1_id] - rewards[self.p0_id]

                last_frames[i] = new_frame
                if done:
                    dones[i] = True
        
        for i in range(num_envs):
            if dones[i]:
                # The returned state belongs to the *next* episode.
                self._refill_stack(i, self.envs[i].reset())
            else:
                self._frame_stacks[i].append(last_frames[i])
        
        # Keep the public attribute current so it never lags behind the envs.
        self.states = [self._get_env_state(fs) for fs in self._frame_stacks]
        reward_dict = {self.p0_id: cum_p0_rewards, self.p1_id: cum_p1_rewards}
        
        return self.states, reward_dict, dones

    def close(self):
        """Close every wrapped environment."""
        for env in self.envs:
            env.close()

# --- PPO Implementation ---
@dataclass
class TrajectoryItem:
    """One player's transition for a single agent step, as recorded by the collector."""
    state: np.ndarray  # stacked observation the action was sampled from
    action: int  # flat action index: a0 + 3 * a1 + 9 * a2
    reward: float  # reward summed over the frame-skip window for this player
    done: bool  # True if the episode ended (or hit the step limit) during this step
    log_prob: float  # summed log-probability of the sampled head actions at collection time
    value: float  # critic estimate for `state` at collection time

class Collector:
    """Rolls out both players' policies in the vectorised environment.

    The collector remembers the latest states between calls, so consecutive
    rollouts continue where the previous one stopped.
    """

    def __init__(self, envs: Multi_Env_With_SkipN_and_Stack, config):
        self.envs = envs
        self.config = config
        self.current_states = self.envs.states

    @staticmethod
    def _sample_actions(net, states_v: torch.Tensor) -> tt.Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Sample one flat action per env from ``net`` without tracking gradients.

        Returns:
            ``(action_idxs, log_probs, values)`` as host NumPy arrays with one
            entry per env. The flat index is ``a0 + 3 * a1 + 9 * a2`` and the
            log-probability is the sum over the three heads.
        """
        with torch.no_grad():
            logits, values_v = net(states_v)
            dists = [Categorical(logits=head_logits) for head_logits in logits]
            actions = [dist.sample() for dist in dists]
            log_probs = sum(dist.log_prob(a) for dist, a in zip(dists, actions))
            action_idxs = actions[0] + 3 * actions[1] + 9 * actions[2]
        # One host transfer per tensor instead of one `.item()` sync per env.
        return (
            action_idxs.cpu().numpy(),
            log_probs.cpu().numpy(),
            values_v.reshape(-1).cpu().numpy(),
        )

    def collect_trajectories(self, net_p0: ActorCriticResNet, net_p1: ActorCriticResNet, device: torch.device) -> tt.Tuple[dict, dict, np.ndarray]:
        """Collect ``PPO_STEPS_PER_ENV`` agent steps from every environment.

        Both players observe the same stacked state at each step.

        Args:
            net_p0: Policy/value network of player 0.
            net_p1: Policy/value network of player 1.
            device: Device on which the networks are evaluated.

        Returns:
            ``(trajectories_p0, trajectories_p1, last_states)``: two dicts
            mapping env index to that env's list of ``TrajectoryItem`` for the
            respective player, and the array of states following the final
            step (used for value bootstrapping).
        """
        num_envs = self.config.NUM_ENVS
        trajectories_p0 = {i: [] for i in range(num_envs)}
        trajectories_p1 = {i: [] for i in range(num_envs)}
        
        for _ in range(self.config.PPO_STEPS_PER_ENV):
            states_v = torch.from_numpy(np.array(self.current_states)).to(device)

            # Player 0 is sampled before player 1, matching the RNG draw order
            # of a straightforward sequential implementation.
            acts_p0, log_probs_p0, values_p0 = self._sample_actions(net_p0, states_v)
            acts_p1, log_probs_p1, values_p1 = self._sample_actions(net_p1, states_v)

            new_states, rewards_dict, dones = self.envs.step(acts_p0, acts_p1)
            rewards_p0 = rewards_dict[self.envs.p0_id]
            rewards_p1 = rewards_dict[self.envs.p1_id]
            
            for i in range(num_envs):
                state = self.current_states[i]
                trajectories_p0[i].append(TrajectoryItem(
                    state=state, action=int(acts_p0[i]),
                    reward=rewards_p0[i], done=dones[i],
                    log_prob=float(log_probs_p0[i]), value=float(values_p0[i])
                ))
                trajectories_p1[i].append(TrajectoryItem(
                    state=state, action=int(acts_p1[i]),
                    reward=rewards_p1[i], done=dones[i],
                    log_prob=float(log_probs_p1[i]), value=float(values_p1[i])
                ))
            
            self.current_states = new_states
        
        last_states = np.array(self.current_states)
        return trajectories_p0, trajectories_p1, last_states

class PPOTrainer:
    """Runs clipped-surrogate PPO updates for one player's network.

    Args:
        net: The actor-critic network to optimise.
        optimizer: Optimizer over ``net``'s parameters.
        device: Device on which minibatches are evaluated.
        agent_id: Label used as a prefix in the logged metric names.
        config: Object exposing the hyperparameters as attributes
            (``GAMMA``, ``GAE_LAMBDA``, ``PPO_EPSILON``, ``PPO_EPOCHS``,
            ``PPO_BATCH_SIZE``, ``VALUE_LOSS_COEF``).
    """

    def __init__(self, net: ActorCriticResNet, optimizer: optim.Optimizer, device: torch.device, agent_id: str, config):
        self.net = net
        self.optimizer = optimizer
        self.device = device
        self.agent_id = agent_id
        self.config = config

    def _calculate_advantages_and_returns(self, trajectories_by_env: tt.Dict[int, tt.List[TrajectoryItem]], last_states: np.ndarray):
        """Compute GAE advantages and returns for every env, in env-index order.

        Args:
            trajectories_by_env: Maps env index ``0..n-1`` to that env's
                time-ordered transitions.
            last_states: State following the final transition of each env; row
                ``i`` belongs to env ``i`` and supplies the bootstrap value.

        Returns:
            ``(advantages, returns)`` as flat Python lists, envs concatenated in
            index order and each env's items in time order.
        """
        all_advantages, all_returns = [], []
        last_states_v = torch.from_numpy(last_states).to(self.device)
        with torch.no_grad():
            _, last_vals_v = self.net(last_states_v)

        gamma = self.config.GAMMA
        gae_lambda = self.config.GAE_LAMBDA

        for i in range(len(trajectories_by_env)):
            env_trajectory = trajectories_by_env[i]
            if not env_trajectory:
                # Nothing collected for this env; it contributes no samples.
                continue

            advantages, returns = [], []
            gae_advantage = 0.0
            next_value = last_vals_v[i].item()
            if env_trajectory[-1].done:
                # The state after a finished episode belongs to a new episode.
                next_value = 0.0

            # Walk backwards so each step can reuse its successor's advantage.
            for traj in reversed(env_trajectory):
                not_done = 1 - int(traj.done)
                delta = traj.reward + gamma * next_value * not_done - traj.value
                gae_advantage = delta + gamma * gae_lambda * gae_advantage * not_done
                returns.append(gae_advantage + traj.value)
                advantages.append(gae_advantage)
                next_value = traj.value

            all_advantages.extend(reversed(advantages))
            all_returns.extend(reversed(returns))

        return all_advantages, all_returns

    def train(self, trajectories_by_env: tt.Dict[int, tt.List[TrajectoryItem]], last_states: np.ndarray, current_entropy_coef: float):
        """Run ``PPO_EPOCHS`` passes of minibatch PPO over one rollout and log metrics.

        Advantages are normalised over the whole rollout. The logged policy
        loss, value loss, entropy, approximate KL and clip fraction come from
        the final minibatch of the final epoch. Metrics are logged with
        ``commit=False`` so the caller decides when the W&B step is committed.

        Args:
            trajectories_by_env: Maps env index ``0..n-1`` to its transitions.
            last_states: Bootstrap states, one row per env.
            current_entropy_coef: Weight of the entropy bonus for this update.
        """
        # Flatten in env-index order: the same order the advantages are built
        # in, so sample k and advantage k always refer to the same transition
        # regardless of how the dict was populated.
        trajectories = [
            item
            for i in range(len(trajectories_by_env))
            for item in trajectories_by_env[i]
        ]
        if not trajectories:
            # An empty rollout has nothing to optimise or report.
            return

        advantages, returns = self._calculate_advantages_and_returns(trajectories_by_env, last_states)
        
        states = np.array([t.state for t in trajectories])
        actions = np.array([t.action for t in trajectories])
        log_probs_old = np.array([t.log_prob for t in trajectories], dtype=np.float32)
        advantages = np.array(advantages, dtype=np.float32)
        returns = np.array(returns, dtype=np.float32)
        values_old = np.array([t.value for t in trajectories], dtype=np.float32)

        # Fraction of the return variance explained by the pre-update critic.
        var_y = np.var(returns)
        explained_var = 1 - np.var(returns - values_old) / (var_y + 1e-8) if var_y > 1e-8 else 0
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        for _ in range(self.config.PPO_EPOCHS):
            indices = np.random.permutation(len(trajectories))
            for start in range(0, len(trajectories), self.config.PPO_BATCH_SIZE):
                end = start + self.config.PPO_BATCH_SIZE
                batch_indices = indices[start:end]                  
                
                batch_states = torch.from_numpy(states[batch_indices]).to(self.device)
                batch_actions_v = torch.from_numpy(actions[batch_indices]).to(self.device)
                batch_log_probs_old_v = torch.from_numpy(log_probs_old[batch_indices]).to(self.device)
                batch_advantages_v = torch.from_numpy(advantages[batch_indices]).to(self.device)
                batch_returns_v = torch.from_numpy(returns[batch_indices]).to(self.device)
                
                logits, values_v = self.net(batch_states)
                values_v = values_v.squeeze(-1)

                # Recover the per-head choices from the flat index (base-3 digits).
                dists = [Categorical(logits=l) for l in logits]
                a0 = batch_actions_v % 3
                a1 = (batch_actions_v // 3) % 3
                a2 = (batch_actions_v // 9) % 3
                log_probs_new = dists[0].log_prob(a0) + dists[1].log_prob(a1) + dists[2].log_prob(a2)
                
                ratio = torch.exp(log_probs_new - batch_log_probs_old_v)
                surr1 = ratio * batch_advantages_v
                surr2 = torch.clamp(ratio, 1.0 - self.config.PPO_EPSILON, 1.0 + self.config.PPO_EPSILON) * batch_advantages_v
                policy_loss = -torch.min(surr1, surr2).mean()

                value_loss = nn.functional.mse_loss(values_v, batch_returns_v)
                entropy = sum(d.entropy().mean() for d in dists)
                
                total_loss = (policy_loss + self.config.VALUE_LOSS_COEF * value_loss - current_entropy_coef * entropy)
                
                self.optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), max_norm=0.5)
                self.optimizer.step()

        # Diagnostics from the last minibatch processed above.
        with torch.no_grad():
            log_ratio = log_probs_new - batch_log_probs_old_v
            approx_kl = torch.mean((torch.exp(log_ratio) - 1) - log_ratio).item()
            clip_fraction = torch.mean((torch.abs(ratio - 1.0) > self.config.PPO_EPSILON).float()).item()
        
        wandb.log({
            f"train/{self.agent_id}_policy_loss": policy_loss.item(),
            f"train/{self.agent_id}_value_loss": value_loss.item(),
            f"train/{self.agent_id}_entropy": entropy.item(),
            f"train/{self.agent_id}_explained_variance": explained_var,
            f"train/{self.agent_id}_approx_kl": approx_kl,
            f"train/{self.agent_id}_clip_fraction": clip_fraction,
        }, commit=False)

if __name__ == "__main__":
    # Entry point: self-play PPO for two players (p0 / p1) sharing the same
    # observations, with W&B logging, resumable checkpoints and a hysteresis
    # gate that pauses the training of whichever player is too far ahead.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # 1. Run identification and checkpoint discovery
    RUN_NAME = config_dict.get("RUN_NAME", wandb.util.generate_id())
    MODEL_SAVE_PATH = config_dict.get("MODEL_SAVE_PATH", "saved_models_ppo_dual_resnet_final")

    checkpoint_manager = CheckpointManager(MODEL_SAVE_PATH, RUN_NAME)
    checkpoint = checkpoint_manager.load_checkpoint(device)

    # 2. State variables (defaults for a fresh run)
    total_frames = 0
    start_update = 1
    last_save_frame = 0
    wandb_run_id = None 
    # Index of the last *completed* update. This is what the final checkpoint
    # records, so an interrupted update is repeated on resume, not skipped.
    update_for_save = 0
    
    # Gating States (Persist across updates)
    p0_training_active = True
    p1_training_active = True

    # 3. Resume Logic
    if checkpoint:
        try:
            # Read every required key before touching any state, so a corrupt
            # checkpoint cannot leave the counters half-restored.
            ckpt_run_id = checkpoint['wandb_run_id']
            ckpt_total_frames = checkpoint['total_frames']
            ckpt_update = checkpoint['update']
        except KeyError as e:
            print(f"Checkpoint file is corrupt or missing key: {e}. Starting a new run.")
            checkpoint = None 
        else:
            wandb_run_id = ckpt_run_id
            total_frames = ckpt_total_frames
            start_update = ckpt_update + 1
            update_for_save = ckpt_update
            last_save_frame = total_frames
            
            # Load gating states if available, else default to True
            p0_training_active = checkpoint.get('p0_active', True)
            p1_training_active = checkpoint.get('p1_active', True)
            
            print(f"Found checkpoint for W&B run ID: {wandb_run_id}. Resuming from frame {total_frames}.")
            print(f"Restored gating states: P0 Active={p0_training_active}, P1 Active={p1_training_active}")
    
    # 4. Initialize W&B (re-attach to the stored run when resuming)
    if wandb_run_id:
        run = wandb.init(
            project="PPO-Skip-N-Step-Parallel-Resumable",
            config=config_dict,
            id=wandb_run_id,
            name=RUN_NAME, 
            resume="must"
        )
    else:
        print("Starting a new training run.")
        run = wandb.init(
            project="PPO-Skip-N-Step-Parallel-Resumable",
            config=config_dict,
            name=RUN_NAME, 
        )
    
    wandb_run_id = run.id
    config = run.config 
    print(f"W&B Run ID: {wandb_run_id}")

    # 5. Initialize envs, models, optimizers
    envs = Multi_Env_With_SkipN_and_Stack(config)
    actual_input_shape = envs.states[0].shape
    print(f"Detected observation shape from environment: {actual_input_shape}")
    
    n_actions_per_head = [3, 3, 3]

    net_p0 = ActorCriticResNet(actual_input_shape, n_actions_per_head).to(device)
    net_p1 = ActorCriticResNet(actual_input_shape, n_actions_per_head).to(device)
    
    optimizer_p0 = optim.Adam(net_p0.parameters(), lr=config.LEARNING_RATE, eps=1e-5)
    optimizer_p1 = optim.Adam(net_p1.parameters(), lr=config.LEARNING_RATE, eps=1e-5)

    # 6. Load model/optimizer states
    if checkpoint:
        print("Loading model and optimizer states from checkpoint...")
        net_p0.load_state_dict(checkpoint['net_p0_state_dict'])
        net_p1.load_state_dict(checkpoint['net_p1_state_dict'])
        optimizer_p0.load_state_dict(checkpoint['optimizer_p0_state_dict'])
        optimizer_p1.load_state_dict(checkpoint['optimizer_p1_state_dict'])
        print("States loaded successfully.")
    
    collector = Collector(envs, config)
    trainer_p0 = PPOTrainer(net_p0, optimizer_p0, device, agent_id="p0", config=config)
    trainer_p1 = PPOTrainer(net_p1, optimizer_p1, device, agent_id="p1", config=config)
    
    # Rolling windows over the most recent finished episodes.
    ep_info_buffer_p0 = collections.deque(maxlen=100)
    ep_info_buffer_p1 = collections.deque(maxlen=100)
    ep_len_buffer = collections.deque(maxlen=100)
    
    # Running totals for the episode currently in progress in each env.
    current_ep_rewards_p0 = np.zeros(config.NUM_ENVS, dtype=np.float32)
    current_ep_rewards_p1 = np.zeros(config.NUM_ENVS, dtype=np.float32)
    current_ep_lengths = np.zeros(config.NUM_ENVS, dtype=np.int32)
    
    # --- Main Training Loop ---
    start_time = time.time()
    num_updates = config.TOTAL_TIMESTEPS // (config.NUM_ENVS * config.PPO_STEPS_PER_ENV)
    
    try:
        # `initial` is the number of updates already done, so the bar ends
        # exactly at `total` instead of overshooting by one.
        for update in tqdm(range(start_update, num_updates + 1), desc="PPO Updates", initial=start_update - 1, total=num_updates):
            # --- Schedule: Learning Rate & Entropy ---
            # frac goes linearly from 1.0 at the first update towards 0.
            frac = 1.0 - (update - 1.0) / num_updates

            if config.LR_SCHEDULE_ANNEAL:
                min_lr = config.MIN_LEARNING_RATE
                max_lr = config.LEARNING_RATE
                new_lr = min_lr + (max_lr - min_lr) * frac
                optimizer_p0.param_groups[0]["lr"] = new_lr
                optimizer_p1.param_groups[0]["lr"] = new_lr

            # Read from the run config like every other hyperparameter, so
            # values overridden through W&B are honoured here too.
            if config.get("ENTROPY_ANNEAL", False):
                start_e = config["ENTROPY_COEF_START"]
                end_e = config["ENTROPY_COEF_END"]
                current_entropy_coef = end_e + (start_e - end_e) * frac
            else:
                current_entropy_coef = 0.01 # Default fallback

            # -----------------------------------------

            trajectories_p0_by_env, trajectories_p1_by_env, last_states = collector.collect_trajectories(net_p0, net_p1, device)
            
            # Each agent step consumed SKIP_N raw environment frames.
            frames_in_batch = sum(len(traj_list) for traj_list in trajectories_p0_by_env.values()) * config.SKIP_N
            total_frames += frames_in_batch
            
            # Accumulate per-episode statistics; episodes may span several rollouts.
            for i in range(config.NUM_ENVS):
                for t_item_p0, t_item_p1 in zip(trajectories_p0_by_env[i], trajectories_p1_by_env[i]):
                    current_ep_rewards_p0[i] += t_item_p0.reward
                    current_ep_rewards_p1[i] += t_item_p1.reward
                    current_ep_lengths[i] += 1
                    
                    if t_item_p0.done:
                        ep_info_buffer_p0.append(current_ep_rewards_p0[i])
                        ep_info_buffer_p1.append(current_ep_rewards_p1[i])
                        ep_len_buffer.append(current_ep_lengths[i])
                        
                        current_ep_rewards_p0[i] = 0
                        current_ep_rewards_p1[i] = 0
                        current_ep_lengths[i] = 0

            # --- Gating Mechanism Start (Hysteresis Logic) ---
            PAUSE_THRESHOLD = config.SCORE_GATING_THRESHOLD
            RESUME_THRESHOLD = config.SCORE_RESUME_THRESHOLD

            mean_score_p0 = 0
            mean_score_p1 = 0
            score_diff = 0.0

            # Only act once enough finished episodes have been observed.
            if len(ep_info_buffer_p0) >= config.GATING_WARMUP_EPISODES:
                mean_score_p0 = np.mean(ep_info_buffer_p0)
                mean_score_p1 = np.mean(ep_info_buffer_p1)
                score_diff = mean_score_p0 - mean_score_p1
                
                # --- P0 Control Logic ---
                if p0_training_active:
                    if score_diff > PAUSE_THRESHOLD:
                        p0_training_active = False
                else:
                    if score_diff < RESUME_THRESHOLD:
                        p0_training_active = True

                # --- P1 Control Logic (Symmetric) ---
                if p1_training_active:
                    if score_diff < -PAUSE_THRESHOLD:
                        p1_training_active = False
                else:
                    if score_diff > -RESUME_THRESHOLD:
                        p1_training_active = True
            
            should_train_p0 = p0_training_active
            should_train_p1 = p1_training_active
            # --- Gating Mechanism End ---

            if should_train_p0:
                trainer_p0.train(trajectories_p0_by_env, last_states, current_entropy_coef)
            
            if should_train_p1:
                trainer_p1.train(trajectories_p1_by_env, last_states, current_entropy_coef)

            # Both networks now reflect this update; safe to record it as done.
            update_for_save = update
            
            fps = frames_in_batch / (time.time() - start_time)
            start_time = time.time()

            log_data = {
                "charts/total_frames": total_frames,
                "charts/fps": fps,
                "charts/learning_rate": optimizer_p0.param_groups[0]["lr"],
                "charts/entropy_coef": current_entropy_coef,
                "charts/update_step": update, 
                "gating/score_diff": score_diff,
                "gating/train_p0_active": int(should_train_p0), 
                "gating/train_p1_active": int(should_train_p1), 
            }
            
            if len(ep_len_buffer) > 0:
                log_data["episodes/mean_reward_p0"] = np.mean(ep_info_buffer_p0)
                log_data["episodes/mean_reward_p1"] = np.mean(ep_info_buffer_p1)
                log_data["episodes/mean_length"] = np.mean(ep_len_buffer)

            # Commits the trainers' commit=False metrics together with these.
            wandb.log(log_data, commit=True)

            if total_frames - last_save_frame >= config.SAVE_EVERY_FRAMES:
                checkpoint_manager.save_checkpoint(
                    net_p0, net_p1, optimizer_p0, optimizer_p1,
                    total_frames, update, wandb_run_id,
                    p0_training_active, p1_training_active
                )
                last_save_frame = total_frames

        print("Training finished.")

    except KeyboardInterrupt:
        print("\nTraining interrupted by user.")
    finally:
        print("Saving final checkpoint and closing all environments...")
        # Nested so the Unity processes are shut down and the W&B run is
        # finished even if the final save (or the env shutdown) raises.
        try:
            checkpoint_manager.save_checkpoint(
                net_p0, net_p1, optimizer_p0, optimizer_p1,
                total_frames, update_for_save, wandb_run_id,
                p0_training_active, p1_training_active
            )
        finally:
            try:
                envs.close()
            finally:
                wandb.finish()
        print("Cleanup complete. Final checkpoint saved.")