# 0. 导入

In [1]:
import matplotlib
matplotlib.use('TkAgg') # 或 'Qt5Agg' 等
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.policies import BaseFeaturesExtractor
from PIL import Image
import matplotlib.pyplot as plt
from typing import Tuple, List, Dict, Optional, Any, Union
import cv2
import random
from collections import defaultdict
from scipy.special import softmax # For softmax function
import os


def load_images_to_dict(base_dir: str) -> Dict[str, List[np.ndarray]]:
    image_datasets = {}
    # 定义支持的图片扩展名
    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif'}
    
    # 遍历基础目录下的所有子文件夹
    for folder_name in os.listdir(base_dir):
        folder_path = os.path.join(base_dir, folder_name)
        
        # 确保是文件夹
        if not os.path.isdir(folder_path):
            continue
        
        image_list = []
        
        # 遍历文件夹中的文件
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            
            # 检查文件扩展名
            ext = os.path.splitext(filename)[1].lower()
            if ext not in valid_extensions:
                continue  # 跳过非图片文件
            
            try:
                # 用PIL打开图片并转换为RGB（避免Alpha通道问题）
                with Image.open(file_path) as img:
                    img = img.convert('RGB')  # 统一转换为RGB三通道
                    image_list.append(np.array(img))
            except Exception as e:
                print(f"Error loading {file_path}: {str(e)}")
                continue
        
        # 将文件夹名作为key，图片数组列表作为value
        image_datasets[folder_name] = image_list
    
    return image_datasets

class CustomImageDataset(Dataset):
    def __init__(self, image_list: List[np.ndarray], transform: Any = None):
        """
        Args:
            image_list (List[np.ndarray]): List of images as NumPy arrays.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.images = image_list
        self.transform = transform

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> torch.Tensor:
        image_np = self.images[idx]

        # Ensure image is in uint8 format for PIL conversion if necessary,
        # or if ToTensor is directly applied to numpy array.
        if image_np.dtype != np.uint8:
            # This can happen if images were, e.g., float32 from some processing.
            # Assuming they were originally 0-255 range if they are not uint8.
            # If they are already 0-1 float, this needs adjustment.
            # For typical image data, astype(np.uint8) is safe.
            image_np = image_np.astype(np.uint8)

        # transforms.ToTensor() can handle HWC uint8 NumPy arrays directly.
        # It will convert to CHW FloatTensor and scale to [0.0, 1.0].
        # If you prefer to use PIL Image explicitly:
        # image_pil = Image.fromarray(image_np)
        # if self.transform:
        #     image_tensor = self.transform(image_pil)
        # else: # Fallback if no transform, though ToTensor is crucial
        #     image_tensor = transforms.ToTensor()(image_pil)

        # Direct application to NumPy array (preferred if array is HWC)
        if self.transform:
            image_tensor = self.transform(image_np)
        else:
            # Default to ToTensor if no specific transform is given,
            # as it's essential for PyTorch models.
            image_tensor = transforms.ToTensor()(image_np)

        return image_tensor


# 1. Survival Game

In [None]:

# ========================
# 环境配置
# ========================
MAP_SIZE = 50
PREDATOR_COUNT = 3
MAX_FOOD = 250
CLUSTER_RADIUS = 3
PIXEL_TYPES = {
    0: 'environment',
    1: 'predator',
    2: 'food',
    3: 'agent'
}
REWARDS = {
    'predator': -50,  # 调整为文档要求的-50
    'food': +10,      # 调整为文档要求的+10
    'environment': -1
}
NUM_VIEWS = 4  # Front, Left, Right, Back
INITIAL_SCORE = 100  # 添加初始分数

# ========================
# 环境实现
# ========================
class SurvivalGameEnv(gym.Env):
    metadata = {'render_modes': ['human'], 'render_fps': 4}

    def __init__(self, image_datasets: Dict[str, List[np.ndarray]]):
        super(SurvivalGameEnv, self).__init__()

        self.image_datasets = image_datasets
        # This transform is for external use if needed, model handles its own
        self.vis_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # These will be converted to probabilities for Escape, Eat, Wander inside step()
        # stable_baselines3 requires finite bounds for Box action space.
        # These outputs will be treated as logits for 3 intents.
        # A large range like [-100, 100] should be sufficient.
        self.action_space = spaces.Box(low=-100.0, high=100.0, shape=(3,), dtype=np.float32)

        self.observation_space = spaces.Box(
            low=0, high=255,
            shape=(NUM_VIEWS, 100, 100, 3),
            dtype=np.uint8
        )
        self.current_map_image = None  # For rendering
        self.fig = None
        self.ax = None
        self.current_map_image_artist = None # 将存储imshow返回的artist对象

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple[np.ndarray, Dict[str, Any]]:
        super().reset(seed=seed)
        np.random.seed(seed)
        random.seed(seed)
        
        # 初始化地图
        self.map = np.zeros((MAP_SIZE, MAP_SIZE), dtype=np.uint8)
        self.predators = []
        self.foods = {}
        
        # 放置捕食者
        for _ in range(PREDATOR_COUNT):
            pos = self._random_edge_position()
            while self.map[pos] != 0:  # 确保位置为空
                pos = self._random_edge_position()
            self.map[pos] = 1
            self.predators.append({
                'pos': pos,
                'img_idx': np.random.randint(len(self.image_datasets['predator']))
            })
        
        # 集群式生成食物
        num_clusters = max(1, MAX_FOOD // 10)  # 每10个食物一个集群
        cluster_centers = [self._random_position() for _ in range(num_clusters)]
        
        for center in cluster_centers:
            # 使用高斯分布生成集群
            cluster_size = min(MAX_FOOD - len(self.foods), MAX_FOOD // num_clusters)
            for _ in range(cluster_size):
                # 在集群半径内生成随机偏移
                dx = int(np.random.normal(0, CLUSTER_RADIUS / 2))
                dy = int(np.random.normal(0, CLUSTER_RADIUS / 2))
                pos = (
                    np.clip(center[0] + dx, 0, MAP_SIZE - 1),
                    np.clip(center[1] + dy, 0, MAP_SIZE - 1)
                )
                
                if self._is_valid_position(pos) and self.map[pos] == 0:
                    self.map[pos] = 2
                    self.foods[pos] = {
                        'img_idx': np.random.randint(len(self.image_datasets['food'])),
                        'consumed': False
                    }
                    if len(self.foods) >= MAX_FOOD:
                        break
            if len(self.foods) >= MAX_FOOD:
                break

        # 放置代理
        self.agent_pos = self._random_position()
        while self.map[self.agent_pos] != 0:
            self.agent_pos = self._random_position()
        
        # 初始化代理方向和分数
        self.agent_direction = 0  # 0:上, 1:左, 2:右, 3:下
        self.score = INITIAL_SCORE  # 添加分数系统
        
        self.prev_agent_pos = self.agent_pos
        self.steps = 0
        self.terminated = False
        self.truncated = False

        return self._get_observation(), {}

    def step(self, raw_actions: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
        if not (isinstance(raw_actions, np.ndarray) and raw_actions.shape == (3,)):
            raise ValueError(
                f"Invalid action type or shape passed to env.step(). "
                f"Expected 3-element np.ndarray (logits/preferences), got {type(raw_actions)} with value {raw_actions}"
            )
        # Convert raw_actions (logits/preferences) to probabilities for the 3 high-level intents
        # using softmax.
        # raw_actions can be thought of as the direct output of the policy network's last layer
        # for the 3 high-level intents before any sampling.
        intent_probs = softmax(raw_actions).astype(np.float32)

        # 1. 计算移动概率分布 (using the new intent_probs)
        move_probs_dist = self._calculate_movement_distribution(intent_probs) # This method already expects probabilities

        # 2. 选择移动方向
        chosen_direction_idx = np.random.choice(NUM_VIEWS, p=move_probs_dist)

        # 3. 移动代理并更新方向
        self.prev_agent_pos = self.agent_pos
        self.agent_pos = self._move_in_direction(self.agent_pos, chosen_direction_idx)
        self.agent_direction = chosen_direction_idx

        # 4. 处理交互和奖励
        reward = 0
        collided_entity_type = self.map[self.agent_pos]

        if collided_entity_type == 1:
            reward += REWARDS['predator']
            #self.terminated = True
        elif collided_entity_type == 2:
            if self.agent_pos in self.foods:
                reward += REWARDS['food']
                del self.foods[self.agent_pos]
                self.map[self.agent_pos] = 0
        elif collided_entity_type == 0:
            reward += REWARDS['environment']
        else:
            reward += REWARDS['environment']

        self.score += reward
        if self.score < 1:
            self.terminated = True

        self._move_predators()

        if not self.terminated:
            for pred in self.predators:
                if pred['pos'] == self.agent_pos:
                    # If caught by predator moving into agent's square
                    # Ensure this penalty is applied correctly, potentially overwriting previous reward for this step
                    # Or add a large negative value. Let's assume it sets the reward for this step to predator penalty.
                    current_step_reward_value = REWARDS['predator']
                    self.score += (current_step_reward_value - reward) # Adjust score based on the new reward for this step
                    reward = current_step_reward_value

                    if self.score < 1:
                        self.score = 0
                        self.terminated = True
                    break
        
        self.steps += 1
        if self.steps >= 1000:
            self.truncated = True

        return self._get_observation(), float(reward), self.terminated, self.truncated, {}

    def _get_direction_vectors(self) -> List[Tuple[int, int]]:
        """获取方向向量：前、左、右、后（相对于当前方向）"""
        # 绝对方向：0:上, 1:左, 2:右, 3:下
        # 根据代理当前方向计算相对方向
        if self.agent_direction == 0:  # 面向"上"
            return [(-1, 0), (0, -1), (0, 1), (1, 0)]  # 前、左、右、后
        elif self.agent_direction == 1:  # 面向"左"
            return [(0, -1), (1, 0), (-1, 0), (0, 1)]  # 前、左、右、后
        elif self.agent_direction == 2:  # 面向"右"
            return [(0, 1), (-1, 0), (1, 0), (0, -1)]  # 前、左、右、后
        else:  # 面向"下"
            return [(1, 0), (0, 1), (0, -1), (-1, 0)]  # 前、左、右、后

    def _get_observation(self) -> np.ndarray:
        """获取代理四个方向的观察图像（基于当前方向）"""
        observations = []
        direction_vectors = self._get_direction_vectors()
        
        # Determine a default image shape from the observation space if a category has no images
        # obs_space shape is (NUM_VIEWS, H, W, C)
        # Individual image shape is (H, W, C)
        h_obs, w_obs, c_obs = self.observation_space.shape[1], self.observation_space.shape[2], self.observation_space.shape[3]
        default_image_for_category = np.zeros((h_obs, w_obs, c_obs), dtype=np.uint8)
        
        for dr, dc in direction_vectors:
            target_pos = (
                (self.agent_pos[0] + dr) % MAP_SIZE,
                (self.agent_pos[1] + dc) % MAP_SIZE
            )
            entity_type = self.map[target_pos]
            
            img_array = None # Initialize to ensure it gets assigned
            
            if entity_type == 1:  # 捕食者
                predator_images = self.image_datasets.get('predator', []) # Get list, default to empty
                specific_pred_img = None
                
                # Try to find the specific predator for this position to use its assigned image
                if predator_images: # Only proceed if there are predator images
                    for p in self.predators:
                        if p['pos'] == target_pos:
                            img_idx = p['img_idx']
                            if 0 <= img_idx < len(predator_images):
                                specific_pred_img = predator_images[img_idx]
                            else:
                                # Log if a specific predator has an invalid img_idx
                                print(f"Warning: Predator at {target_pos} has invalid img_idx {img_idx}. Using random predator image.")
                            break # Found the predator entry in self.predators
                
                if specific_pred_img is not None:
                    img_array = specific_pred_img
                elif predator_images: # Fallback to a random predator image if specific one not found/invalid
                    img_array = random.choice(predator_images)
                else: # No predator images available at all
                    img_array = default_image_for_category
                    # Optional: print("Warning: 'predator' image dataset is empty. Using default image.")

            elif entity_type == 2:  # 食物
                food_images = self.image_datasets.get('food', [])
                food_details = self.foods.get(target_pos)

                if food_details and food_images: # If details for this food exist and there are food images
                    img_idx = food_details.get('img_idx', -1) # Use -1 to indicate invalid if key missing
                    if 0 <= img_idx < len(food_images):
                        img_array = food_images[img_idx]
                    else: # img_idx is out of bounds or was missing, pick a random one
                        # print(f"Warning: Food at {target_pos} has invalid/missing img_idx. Using random food image.")
                        img_array = random.choice(food_images)
                elif food_images: # Food on map, but no details in self.foods (consistency issue) or no food_details. Pick random.
                    # print(f"Warning: Food details not found for {target_pos} or img_idx issue. Using random food image.")
                    img_array = random.choice(food_images)
                else: # No food images available
                    img_array = default_image_for_category
                    # Optional: print("Warning: 'food' image dataset is empty. Using default image.")
            
            else:  # 环境 (entity_type == 0 or any other type)
                environment_images = self.image_datasets.get('environment', [])
                if environment_images:
                    img_array = random.choice(environment_images)
                else: # No environment images available
                    img_array = default_image_for_category
                    # Optional: print("Warning: 'environment' image dataset is empty. Using default image.")
            
            # This check should ideally not be needed if logic above is exhaustive
            if img_array is None:
                print(f"Critical Warning: img_array ended up as None for {target_pos}, entity_type {entity_type}. Using default.")
                img_array = default_image_for_category

            observations.append(img_array)
        
        return np.array(observations, dtype=np.uint8)

    def _calculate_movement_distribution(self, action_probs: np.ndarray) -> np.ndarray:
        """计算移动概率分布（基于相对方向）"""
        # 获取周围实体
        surrounding_entities = []
        direction_vectors = self._get_direction_vectors()
        
        for dr, dc in direction_vectors:
            pos = (
                (self.agent_pos[0] + dr) % MAP_SIZE,
                (self.agent_pos[1] + dc) % MAP_SIZE
            )
            surrounding_entities.append(self.map[pos])
        
        move_dist = np.zeros(NUM_VIEWS)
        
        # Escape逻辑：远离感知到的捕食者（向相反方向移动）
        predator_sensed = False
        for i, entity in enumerate(surrounding_entities):
            if entity == 1:  # 捕食者
                # 相反方向是索引3-i（前<->后，左<->右）
                opposite_idx = 3 - i
                move_dist[opposite_idx] += action_probs[0]
                predator_sensed = True
        
        if not predator_sensed:
            move_dist += action_probs[0] / NUM_VIEWS
        
        # Eat逻辑：朝向感知到的食物
        food_sensed = False
        for i, entity in enumerate(surrounding_entities):
            if entity == 2:  # 食物
                move_dist[i] += action_probs[1]
                food_sensed = True
        
        if not food_sensed:
            move_dist += action_probs[1] / NUM_VIEWS
        
        # Wander逻辑：随机移动
        move_dist += action_probs[2] / NUM_VIEWS
        
        # 归一化
        total = np.sum(move_dist)
        if total > 0:
            move_dist /= total
        else:
            move_dist = np.ones(NUM_VIEWS) / NUM_VIEWS
        
        return move_dist

    def _move_in_direction(self, current_pos: Tuple[int, int], direction_idx: int) -> Tuple[int, int]:
        """在指定方向上移动位置（基于相对方向）"""
        dr, dc = self._get_direction_vectors()[direction_idx]
        new_r = (current_pos[0] + dr) % MAP_SIZE
        new_c = (current_pos[1] + dc) % MAP_SIZE
        return new_r, new_c

    def _move_predators(self):
        """改进的捕食者移动逻辑，解决移动冲突且禁止移动到食物位置"""
        # 清除所有捕食者位置标记（临时设为环境）
        for pred in self.predators:
            self.map[pred['pos']] = 0
        
        moved_positions = set()  # 记录所有计划移动到的位置
        new_predators = []       # 存储移动后的捕食者
        move_vectors = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # 上、下、左、右
        
        # 随机顺序处理捕食者，确保公平性
        random.shuffle(self.predators)
        
        for pred in self.predators:
            current_pos = pred['pos']
            moved = False
            
            # 尝试所有可能的移动方向（随机顺序）
            random.shuffle(move_vectors)
            
            for dr, dc in move_vectors:
                new_pos = (current_pos[0] + dr, current_pos[1] + dc)
                
                # 检查是否移出边界
                if not (0 <= new_pos[0] < MAP_SIZE and 0 <= new_pos[1] < MAP_SIZE):
                    # 在边缘生成新捕食者
                    edge_pos = self._random_edge_position()
                    
                    # 确保新位置有效且未被占用
                    if self.map[edge_pos] == 0 and edge_pos not in moved_positions:
                        new_pred = {
                            'pos': edge_pos,
                            'img_idx': np.random.randint(len(self.image_datasets['predator']))
                        }
                        self.map[edge_pos] = 1
                        moved_positions.add(edge_pos)
                        new_predators.append(new_pred)
                        moved = True
                        break
                    # 如果位置无效，继续尝试其他方向
                    continue
                
                # 检查目标位置是否有效（只能是环境或代理，不能是食物或其他捕食者）
                if self.map[new_pos] != 1 and self.map[new_pos] != 2 and new_pos not in moved_positions:
                    # 检查是否与食物或其他捕食者冲突
                    pred['pos'] = new_pos
                    self.map[new_pos] = 1
                    moved_positions.add(new_pos)
                    new_predators.append(pred)
                    moved = True
                    break
            
            # 如果无法移动，留在原位（检查位置是否仍可用）
            if not moved:
                if current_pos not in moved_positions and self.map[current_pos] != 2:  # 不是食物
                    self.map[current_pos] = 1
                    moved_positions.add(current_pos)
                    new_predators.append(pred)
                else:
                    # 当前位置已被占用，在随机位置生成新捕食者
                    new_pos = self._random_position()
                    if self.map[new_pos] == 0 and new_pos not in moved_positions:
                        new_pred = {
                            'pos': new_pos,
                            'img_idx': np.random.randint(len(self.image_datasets['predator']))
                        }
                        self.map[new_pos] = 1
                        moved_positions.add(new_pos)
                        new_predators.append(new_pred)
        
        self.predators = new_predators

    def _random_position(self) -> Tuple[int, int]:
        """随机位置（地图内）"""
        return (np.random.randint(0, MAP_SIZE), np.random.randint(0, MAP_SIZE))

    def _random_edge_position(self) -> Tuple[int, int]:
        """随机边缘位置"""
        edge = np.random.choice(['top', 'bottom', 'left', 'right'])
        if edge == 'top': 
            return (0, np.random.randint(0, MAP_SIZE))
        if edge == 'bottom': 
            return (MAP_SIZE - 1, np.random.randint(0, MAP_SIZE))
        if edge == 'left': 
            return (np.random.randint(0, MAP_SIZE), 0)
        return (np.random.randint(0, MAP_SIZE), MAP_SIZE - 1)  # right

    def _is_valid_position(self, pos: Tuple[int, int]) -> bool:
        """检查位置是否有效"""
        return 0 <= pos[0] < MAP_SIZE and 0 <= pos[1] < MAP_SIZE

    def render(self, mode='human'): # 添加 mode 参数以符合gym规范
        if mode != 'human':
            return

        # 检查存储的figure是否仍然有效和打开
        # 如果 self.fig 是 None (首次调用) 或 figure编号不存在 (窗口已关闭)
        if self.fig is None or not plt.fignum_exists(self.fig.number):
            # 如果存在一个其窗口已关闭的旧figure对象，清理它
            if self.fig is not None:
                plt.close(self.fig) # 关闭figure对象本身

            # 创建新的figure和axes
            self.fig, self.ax = plt.subplots(figsize=(8, 8)) # 可调整大小
            
            # 为imshow artist准备初始地图数据
            # self.map 应该在首次调用render之前由reset()初始化
            initial_display_map = self.map.copy()
            # self.agent_pos 也应该由reset()初始化
            # 确保属性存在，避免在reset()之前调用render()出错
            if hasattr(self, 'agent_pos') and self.agent_pos is not None:
                initial_display_map[self.agent_pos[0], self.agent_pos[1]] = 3 # 标记智能体

            # 创建图像艺术家 (image artist)
            self.current_map_image_artist = self.ax.imshow(
                initial_display_map,
                cmap='viridis',
                vmin=0,
                vmax=len(PIXEL_TYPES) - 1 
            )
            
            # 为新绘图添加颜色条、标题并关闭坐标轴显示
            self.fig.colorbar(self.current_map_image_artist, ax=self.ax, label='Entity Types')
            self.ax.set_title("Survival Game Environment") # 初始标题
            self.ax.axis('off')
            
            # 确保matplotlib处于交互模式
            if not plt.isinteractive():
                plt.ion()
            
            # 显示图形。对于某些后端，plt.show(block=False)或仅让事件循环运行就足够了。
            # self.fig.show() 可能更适合某些情况，或者直接依赖 plt.pause() 来刷新。
            # 为了简单起见，首次显示依赖于Gym环境的典型用法，即plt.pause()会处理。
            # 如果在Jupyter中，%matplotlib widget会自动处理显示。

        # --- 更新现有绘图以反映当前游戏状态 ---
        current_display_map = self.map.copy()
        if hasattr(self, 'agent_pos') and self.agent_pos is not None:
            current_display_map[self.agent_pos[0], self.agent_pos[1]] = 3 # 标记智能体当前位置

        # 更新图像艺术家的数据
        if self.current_map_image_artist: # 确保artist存在
            self.current_map_image_artist.set_data(current_display_map)
        
        # 更新标题
        title_str = "Survival Game"
        if hasattr(self, 'steps'): title_str += f" Step: {self.steps}"
        if hasattr(self, 'agent_pos'): title_str += f", Agent: {self.agent_pos}"
        if hasattr(self, 'score'): title_str += f", Score: {self.score:.1f}"
        
        if self.ax: # 确保ax存在
            self.ax.set_title(title_str)
        
        # 重绘画布
        if self.fig: # 确保fig存在
            self.fig.canvas.draw_idle()
            # self.fig.canvas.flush_events() # 对于某些后端可能需要
        
        # 暂停以允许绘图更新并可见
        plt.pause(0.05) # 较短的暂停时间，使动画更平滑

    def close(self):
        if self.fig is not None:
            plt.close(self.fig) # 关闭图形窗口和figure对象
            self.fig = None
            self.ax = None
            self.current_map_image_artist = None


# 2. 模型

## 2.1 模型架构

In [3]:

# ========================
# Model Architecture (Wrappers for Perception)
# ========================
class PerceptionModule(nn.Module):
    def __init__(self, pretrained_convnext=True):
        super(PerceptionModule, self).__init__()
        self.convnext = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT if pretrained_convnext else None)
        self.convnext.classifier = nn.Identity() # Remove original classifier
        
        self.do_normalize = pretrained_convnext
        if self.do_normalize:
            self.normalize_transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is (B, N_VIEWS, C, H, W), scaled [0,1]
        
        # Hooks for visualization (GradCAM needs these on a specific layer, see ModelVisualizer)
        # For general input/output grads if needed:
        # if x.requires_grad:
        #     x.register_hook(lambda grad: setattr(self, 'input_gradients', grad))
        
        batch_size, num_views, C, H, W = x.size()
        x_input_to_convnext = x.view(batch_size * num_views, C, H, W)

        if self.do_normalize:
            x_input_to_convnext = self.normalize_transform(x_input_to_convnext)
        
        # Register hook for GradCAM on the output of the target layer if not done externally
        # features = self.convnext(x_input_to_convnext)
        # For GradCAM, typically want features from a specific conv layer, not the final output of convnext here.
        # The ModelVisualizer will handle hooking the specific internal layer.
        
        raw_features = self.convnext(x_input_to_convnext) # Output: (B*N_VIEWS, feature_dim e.g. 768)
        return raw_features.view(batch_size, num_views, -1) # (B, N_VIEWS, feature_dim)

class PretrainedEncoderWrapper(nn.Module):
    def __init__(self, encoder_model: nn.Module):
        super().__init__()
        self.encoder = encoder_model # This is a ConvNeXt model (autoencoder.encoder)
        # AE encoder was trained on [0,1] images, so no ImageNet normalization here.
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is (batch_size, num_views, C, H, W), already scaled to [0,1]
        batch_size, num_views, C, H, W = x.size()
        x_reshaped = x.view(batch_size * num_views, C, H, W)
        
        features = self.encoder(x_reshaped) # encoder is ConvNeXt, outputs (B*N_VIEWS, 768)
        return features.view(batch_size, num_views, -1)

# Decision module as defined in doc (used by ModelVisualizer, SB3 PPO has its own MLP head)
class StandaloneDecisionModule(nn.Module):
    def __init__(self, input_dim: int = NUM_VIEWS * 768, hidden_dims: List[int] = [512, 256]):
        super(StandaloneDecisionModule, self).__init__()
        layers = []
        prev_dim = input_dim
        for dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, dim))
            layers.append(nn.ReLU())
            prev_dim = dim
        layers.append(nn.Linear(prev_dim, 3)) # Output 3 action probabilities
        layers.append(nn.Softmax(dim=-1))
        self.net = nn.Sequential(*layers)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is (B, N_VIEWS * feature_dim)
        return self.net(x)


# ========================
# 自编码器
# ========================
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT) # Or None for from scratch
        self.encoder.classifier = nn.Identity() # Output of encoder is (B, 768)

        # Decoder: (B, 768, 1, 1) -> (B, 3, 100, 100)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(768, 512, kernel_size=5, stride=1, padding=0), # (B, 512, 5, 5)
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, kernel_size=5, stride=5, padding=0), # (B, 256, 25, 25)
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # (B, 128, 50, 50)
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # (B, 64, 100, 100)
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1),   # (B, 3, 100, 100)
            nn.Sigmoid() # Output images in [0,1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor: # x is (B, 3, H, W)
        encoded_flat = self.encoder(x) # (B, 768)
        encoded_reshaped = encoded_flat.view(-1, 768, 1, 1) # Reshape for ConvTranspose
        decoded = self.decoder(encoded_reshaped)
        return decoded

# ========================
# 特征提取器 for Stable Baselines3
# ========================
class CustomFeatureExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: spaces.Box, perception_module: nn.Module):
        # features_dim is N_VIEWS * perception_module_output_dim_per_view
        # Assuming perception_module outputs 768 features per view
        super(CustomFeatureExtractor, self).__init__(observation_space, features_dim=NUM_VIEWS * 768)
        self.perception = perception_module # This is PerceptionModule or PretrainedEncoderWrapper

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        # Input: (batch, NUM_VIEWS, H, W, C) uint8
        # Permute to (batch, NUM_VIEWS, C, H, W) and scale to [0,1]
        observations_processed = observations.permute(0, 1, 4, 2, 3).float() / 255.0
        
        # self.perception module (PerceptionModule or PretrainedEncoderWrapper)
        # expects (B, N_VIEWS, C, H, W) scaled [0,1]
        # and returns (B, N_VIEWS, feature_dim_per_view)
        features_per_view = self.perception(observations_processed)
        
        # Flatten features from all views: (B, N_VIEWS * feature_dim_per_view)
        return features_per_view.reshape(observations.size(0), -1)

## 2.2 模型训练

In [4]:

# ========================
# 训练方法
# ========================
def fitness_training(env: gym.Env, total_timesteps: int = 10000):
    # PerceptionModule handles its own normalization if pretrained
    perception_fitness = PerceptionModule(pretrained_convnext=True)
    # Move perception_fitness to the device PPO will use.
    # If PPO is on CPU, perception module should also be on CPU.
    # If PPO is on GPU, perception module should also be on GPU.
    # Since we're likely choosing 'cpu' for PPO, let's ensure it.
    # However, CustomFeatureExtractor in SB3 handles moving its submodules
    # to the PPO agent's device automatically. So, manual moving here isn't strictly necessary
    # but doesn't hurt for clarity if you want to be explicit.
    # device_to_use = "cpu" # or "cuda" if you decide to try GPU despite warning
    # perception_fitness.to(device_to_use)

    policy_kwargs = dict(
        features_extractor_class=CustomFeatureExtractor,
        features_extractor_kwargs=dict(perception_module=perception_fitness),
        # Corrected net_arch: remove the outer list
        net_arch=dict(pi=[256, 128], vf=[256, 128]),
    )

    # Corrected device: explicitly set to "cpu" to address the warning
    model = PPO("MlpPolicy", env, verbose=1, policy_kwargs=policy_kwargs, device="cpu")
    model.learn(total_timesteps=total_timesteps)
    print(f"Total timesteps = {total_timesteps}")
    return model

def truth_training(env: gym.Env, ae_path: str, total_timesteps: int = 10000):
    autoencoder = Autoencoder()
    # Determine device before loading state_dict if model might move
    # For simplicity, let's assume ae_encoder will be moved to PPO's device
    # by the CustomFeatureExtractor.
    autoencoder.load_state_dict(torch.load(ae_path)) # load_state_dict typically loads to CPU by default
    ae_encoder = autoencoder.encoder

    for param in ae_encoder.parameters(): # Freeze encoder parameters
        param.requires_grad = False

    perception_truth = PretrainedEncoderWrapper(ae_encoder)
    # Similar to fitness_training, CustomFeatureExtractor will handle device placement
    # device_to_use = "cpu" # or "cuda"
    # perception_truth.to(device_to_use)

    policy_kwargs = dict(
        features_extractor_class=CustomFeatureExtractor,
        features_extractor_kwargs=dict(perception_module=perception_truth),
        # Corrected net_arch: remove the outer list
        net_arch=dict(pi=[256, 128], vf=[256, 128])
    )

    # Corrected device: explicitly set to "cpu"
    model = PPO("MlpPolicy", env, verbose=1, policy_kwargs=policy_kwargs, device="cpu")
    model.learn(total_timesteps=total_timesteps)
    print(f"Total timesteps = {total_timesteps}")
    return model

def train_autoencoder(image_datasets: Dict[str, List[np.ndarray]],
                      epochs: int = 10,
                      batch_size: int = 32,
                      ae_save_path: str = "autoencoder.pth",
                      num_workers: int = 0 # For DataLoader: 0 means data loaded in main process
                     ) -> Autoencoder: # Assuming Autoencoder class is defined
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Autoencoder().to(device) # Ensure Autoencoder class is defined
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Combine all images for training from the input dictionary
    all_images_np = [img for cat_imgs in image_datasets.values() for img in cat_imgs]

    if not all_images_np:
        print("No images found to train the autoencoder. Aborting.")
        return model # Or raise an error

    # Define the transformation - ToTensor is crucial
    # It converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255]
    # to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # Create the custom dataset
    custom_dataset = CustomImageDataset(image_list=all_images_np, transform=transform)

    # Create the DataLoader
    # The DataLoader will now handle batching and shuffling efficiently.
    # Images are loaded and transformed on-the-fly (or by worker processes if num_workers > 0)
    # and only a batch at a time is moved to the GPU.
    dataloader = DataLoader(custom_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=num_workers, # Adjust based on your system
                            pin_memory=True if device.type == 'cuda' else False) # Speeds up CPU to GPU transfer

    print(f"Training Autoencoder with {len(custom_dataset)} images on {device}...")
    print(f"Number of batches per epoch: {len(dataloader)}")

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for i, inputs in enumerate(dataloader): # inputs are (B, C, H, W)
            # Move the batch of inputs to the designated device
            inputs = inputs.to(device)

            optimizer.zero_grad()
            outputs = model(inputs) # Autoencoder reconstructs the input
            loss = criterion(outputs, inputs)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

            if (i + 1) % (max(1, len(dataloader) // 10)) == 0: # Print progress 10 times per epoch
                print(f"  Epoch {epoch+1}/{epochs}, Batch {i+1}/{len(dataloader)}, "
                      f"Current Avg Loss: {loss.item():.4f}")


        epoch_loss = running_loss / len(custom_dataset)
        print(f"Epoch {epoch+1}/{epochs} finished. Average Training Loss: {epoch_loss:.4f}")

    torch.save(model.state_dict(), ae_save_path)
    print(f"Autoencoder trained and saved to {ae_save_path}")
    return model


# 3. 评估

## 3.1 Survival Evaluation

In [5]:

# ========================
# 生存评估函数
# ========================

# ... (evaluate_survival) ...
def evaluate_survival(model: PPO, env: gym.Env, n_episodes: int = 10) -> float:
    total_steps = 0
    for _ in range(n_episodes):
        obs, _ = env.reset()
        terminated = False
        truncated = False
        episode_steps = 0

        while not (terminated or truncated):
            # For Box space, model.predict returns the continuous action vector
            # These are the "raw_actions" (logits/preferences) our env.step now expects
            action_vector, _ = model.predict(obs, deterministic=True) # Use deterministic for evaluation

            # Environment interaction (env.step now handles softmax internally)
            obs, _, terminated, truncated, _ = env.step(action_vector)
            episode_steps += 1

        total_steps += episode_steps
    return total_steps / n_episodes

# ... (evaluate_survival_with_render) ...
def evaluate_survival_with_render(model: PPO, env: gym.Env, n_episodes: int = 1) -> float:
    total_steps = 0
    if n_episodes <= 0:
        return 0.0

    for episode in range(n_episodes):
        obs, _ = env.reset()
        terminated = False
        truncated = False
        episode_steps = 0
        print(f"Starting Episode {episode + 1} with rendering...")

        while not (terminated or truncated):
            env.render()

            # For Box space, model.predict returns the continuous action vector
            action_vector, _ = model.predict(obs, deterministic=True)

            # Environment interaction
            obs, reward, terminated, truncated, _ = env.step(action_vector)
            episode_steps += 1

            if episode_steps % 50 == 0:
                print(f"  Episode {episode + 1}, Step {episode_steps}, Current Score: {env.score:.1f}")

        total_steps += episode_steps
        print(f"Episode {episode + 1} finished after {episode_steps} steps. Final Score: {env.score:.1f}")

    env.close()
    return total_steps / n_episodes

## 3.2 Truth Evaluation

In [6]:

# ========================
# 可视化评估工具
# ========================
class ModelVisualizer:
    def __init__(self, perception_module: nn.Module, decision_module: nn.Module):
        self.perception = perception_module
        self.decision = decision_module # PPO's policy_net (mlp_extractor.policy_net)
        
        # --- Hooks and stored data management ---
        # For GradCAM specific target layer
        self.target_layer_hook_activations: Optional[torch.Tensor] = None
        self.target_layer_hook_gradients: Optional[torch.Tensor] = None
        
        # For VBP/GuidedBP (and general hook management)
        self._hook_handles: List[torch.utils.hooks.RemovableHandle] = []
        self.module_to_forward_output: Dict[nn.Module, torch.Tensor] = {}

    def cleanup_hooks(self):
        """Removes all registered hooks and clears stored hook-related data."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []
        
        self.target_layer_hook_activations = None
        self.target_layer_hook_gradients = None
        self.module_to_forward_output = {}

    # --- Hook callback functions ---
    def _store_forward_output_hook(self, module: nn.Module, input_val: Any, output_val: torch.Tensor):
        """Stores the output of a module during forward pass (for VBP)."""
        self.module_to_forward_output[module] = output_val.detach()

    def _gelu_hook_function_vbp(self, module: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: Tuple[torch.Tensor, ...]) -> Optional[Tuple[torch.Tensor, ...]]:
        """Approximated Guided Backpropagation hook for GELU."""
        if module not in self.module_to_forward_output:
            # This can happen if forward pass didn't go through this specific module
            # or if hooks were not registered correctly for the forward pass.
            return None # Or grad_input if we want to allow passthrough

        corresponding_forward_output = self.module_to_forward_output[module]
        # Guided BP logic: only pass gradient if grad_output is positive AND forward output was positive
        # grad_output[0] is the gradient w.r.t. the module's output.
        # torch.clamp(grad_output[0], min=0.0) ensures only positive gradients from above are considered.
        # (corresponding_forward_output > 0).float() ensures neuron was active.
        guided_grad = torch.clamp(grad_output[0], min=0.0) * (corresponding_forward_output > 0).float()
        return (guided_grad,)

    def _relu_hook_function_vbp(self, module: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: Tuple[torch.Tensor, ...]) -> Optional[Tuple[torch.Tensor, ...]]:
        """Guided Backpropagation hook for ReLU."""
        if module not in self.module_to_forward_output:
            return None
            
        corresponding_forward_output = self.module_to_forward_output[module]
        guided_grad = torch.clamp(grad_output[0], min=0.0) * (corresponding_forward_output > 0).float()
        return (guided_grad,)

    def _gradcam_activation_hook(self, module: nn.Module, input_val: Any, output_val: torch.Tensor):
        """Stores activations for GradCAM."""
        self.target_layer_hook_activations = output_val.detach()

    def _gradcam_gradient_hook(self, module: nn.Module, grad_input: Any, grad_output: Tuple[torch.Tensor, ...]):
        """Stores gradients for GradCAM."""
        self.target_layer_hook_gradients = grad_output[0].detach()

    # --- Main Visualization Methods ---
    def _register_gradcam_hooks(self):
        """Registers hooks for GradCAM on the target ConvNeXt layer."""
        self.cleanup_hooks() # Clear any previous hooks

        convnext_model = None
        if hasattr(self.perception, 'convnext'): # For PerceptionModule
            convnext_model = self.perception.convnext
        elif hasattr(self.perception, 'encoder'): # For PretrainedEncoderWrapper
            convnext_model = self.perception.encoder
        else:
            raise TypeError("Perception module is of an unknown type for GradCAM hook registration.")

        try:
            # Target the output of the last stage in ConvNeXt features
            target_layer = convnext_model.features[-1] 
            handle_fwd = target_layer.register_forward_hook(self._gradcam_activation_hook)
            handle_bwd = target_layer.register_full_backward_hook(self._gradcam_gradient_hook)
            self._hook_handles.extend([handle_fwd, handle_bwd])
        except Exception as e:
            print(f"Error registering GradCAM hooks: {e}. GradCAM might not work correctly.")
            self.cleanup_hooks() # Ensure partial hooks are removed
    
    def activation_maximization(self, action_idx: int, lr: float = 0.1, steps: int = 200, num_views_for_am = NUM_VIEWS) -> np.ndarray:
        self.perception.eval()
        self.decision.eval()
        print(f"Starting AM for action {action_idx}...")

        # Optimizes a single (1,3,H,W) image, assuming it's one of the N_VIEWS inputs
        # The perception module will process it as (1, 1, C, H, W) effectively
        # Then its features are replicated for the decision module.
        optimized_image_tensor = torch.rand(1, 3, 100, 100, requires_grad=True, device=next(self.perception.parameters()).device)
        optimizer = optim.Adam([optimized_image_tensor], lr=lr, weight_decay=1e-4)

        for i in range(steps):
            optimizer.zero_grad()
            
            # Clamp and ensure image is in [0,1] range for perception module
            current_image_0_1 = torch.clamp(optimized_image_tensor, 0.0, 1.0)
            
            # Perception module expects (B, N_VIEWS, C, H, W)
            # We form an input where one view is the optimized image, others could be neutral (e.g., gray)
            # For simplicity in AM: assume the optimized image is so dominant it works if it's just one view.
            # The self.perception here is the SB3 model's feature extractor's perception part.
            # It expects (B, N_VIEWS, C, H, W) format.
            # So, let's treat the optimized image as if it's all N_VIEWS for AM purposes.
            multi_view_input = current_image_0_1.repeat(1, num_views_for_am, 1, 1, 1).squeeze(0) # (N_VIEWS, C, H, W)
            multi_view_input = multi_view_input.unsqueeze(0) # (1, N_VIEWS, C, H, W)


            # Get features from perception ( (1, N_VIEWS, feat_dim) )
            features_per_view = self.perception(multi_view_input)
            # Flatten for decision module ( (1, N_VIEWS * feat_dim) )
            flat_features = features_per_view.view(1, -1)
            
            action_distribution = self.decision(flat_features) # decision is PPO's policy_net
            
            loss = -action_distribution[0, action_idx] # Maximize prob of this action
            
            # Add some regularization to the image (e.g., total variation)
            loss += 0.0001 * torch.sum(torch.abs(current_image_0_1[:, :, :, :-1] - current_image_0_1[:, :, :, 1:])) + \
                    0.0001 * torch.sum(torch.abs(current_image_0_1[:, :, :-1, :] - current_image_0_1[:, :, 1:, :]))

            loss.backward()
            optimizer.step()
            if i % (steps // 10) == 0:
                 print(f"AM step {i}, loss {loss.item()}")

        final_image_0_1 = torch.clamp(optimized_image_tensor.detach(), 0.0, 1.0)
        generated_np = final_image_0_1.squeeze().permute(1, 2, 0).cpu().numpy()
        return (generated_np * 255).astype(np.uint8)

    def grad_cam(self, obs_tensor_0_1: torch.Tensor, action_idx: int, target_view_idx: int = 0) -> Optional[np.ndarray]:
        # obs_tensor_0_1 is (1, N_VIEWS, C, H, W), scaled [0,1]
        # target_view_idx specifies which of the N_VIEWS to generate GradCAM for.
        self.perception.eval()
        self.decision.eval()
        self._register_gradcam_hooks() # Ensure hooks are on the correct layer
        
        if not self._hook_handles: # Check if hooks were successfully registered
            return None
        
        
        obs_tensor_0_1.requires_grad_(True)
        
        # Forward pass
        # self.perception is the perception module from the SB3 agent (e.g., PerceptionModule or PretrainedEncoderWrapper)
        features_per_view = self.perception(obs_tensor_0_1) # (1, N_VIEWS, feat_dim)
        flat_features = features_per_view.view(1, -1)       # (1, N_VIEWS * feat_dim)
        
        # self.decision is the PPO's policy_net
        action_distribution = self.decision(flat_features) # (1, num_actions)
        
        # Backward pass for the target action
        self.perception.zero_grad() # Zero grads for perception module's ConvNeXt
        if self.decision.parameters(): # Also zero grads for decision MLP if it has params
            self.decision.zero_grad()

        score = action_distribution[0, action_idx]
        score.backward(retain_graph=True)

        if self.target_layer_hook_activations is None or self.target_layer_hook_gradients is None:
            print("GradCAM: Activations or gradients not captured. Hooks might not be set correctly.")
            return None

        # Activations/Gradients are from the ConvNeXt internal layer, shape (N_VIEWS_eff, C_feat, H_feat, W_feat)
        # N_VIEWS_eff is batch_size * num_views from the perception module's internal reshaping. Here batch_size=1.
        activations_all_views = self.target_layer_hook_activations # (NUM_VIEWS, C_feat, H_feat, W_feat)
        gradients_all_views = self.target_layer_hook_gradients     # (NUM_VIEWS, C_feat, H_feat, W_feat)

        # Select the specific view
        activations_target_view = activations_all_views[target_view_idx] # (C_feat, H_feat, W_feat)
        gradients_target_view = gradients_all_views[target_view_idx]   # (C_feat, H_feat, W_feat)
        
        # Compute weights (alpha_k)
        pooled_gradients = torch.mean(gradients_target_view, dim=[1, 2]) # (C_feat)
        
        # Weight activations
        for i in range(activations_target_view.shape[0]): # Loop over channels
            activations_target_view[i, :, :] *= pooled_gradients[i]
            
        heatmap = torch.mean(activations_target_view, dim=0).cpu().numpy() # (H_feat, W_feat)
        heatmap = np.maximum(heatmap, 0) # ReLU
        if np.max(heatmap) > 0:
            heatmap /= np.max(heatmap) # Normalize
        
        # Resize to original image size
        original_h, original_w = obs_tensor_0_1.shape[-2:]
        heatmap_resized = cv2.resize(heatmap, (original_w, original_h))
        return heatmap_resized
    
    def visual_back_prop(self, obs_tensor_0_1: torch.Tensor, target_view_idx: int = 0) -> Optional[np.ndarray]:
        """
        Generates a VisualBackProp (Guided Backpropagation style) saliency map.
        Shows general input patterns contributing to the perception module's features for a view.
        """
        self.perception.eval()
        self.cleanup_hooks() # Clears all hooks and stored data (module_to_forward_output too)

        convnext_model = None
        if hasattr(self.perception, 'convnext'):
            convnext_model = self.perception.convnext
        elif hasattr(self.perception, 'encoder'):
            convnext_model = self.perception.encoder
        else:
            print("VBP: Perception module type not recognized.")
            return None

        # Register VBP hooks on all GELU/ReLU layers in the ConvNeXt model
        for module_name, module in convnext_model.named_modules():
            if isinstance(module, nn.GELU):
                self._hook_handles.append(module.register_forward_hook(self._store_forward_output_hook))
                self._hook_handles.append(module.register_full_backward_hook(self._gelu_hook_function_vbp))
            elif isinstance(module, nn.ReLU): # Fallback for any ReLUs
                self._hook_handles.append(module.register_forward_hook(self._store_forward_output_hook))
                self._hook_handles.append(module.register_full_backward_hook(self._relu_hook_function_vbp))
        
        if not self._hook_handles:
            print("VBP: No suitable activation layers (GELU/ReLU) found to hook in ConvNeXt.")
            return None # cleanup_hooks already called, so state is clean.

        # Prepare input image: clone, detach, and set requires_grad
        input_img_for_vbp = obs_tensor_0_1.clone().detach().requires_grad_(True)
        
        # --- Forward pass ---
        # This single forward pass will:
        # 1. Populate self.module_to_forward_output via the forward hooks.
        # 2. Give us features_per_view to backpropagate from.
        features_per_view = self.perception(input_img_for_vbp) # Output: (B, N_VIEWS, feature_dim)
        
        # --- Backward pass ---
        self.perception.zero_grad() # Zero gradients for the perception model parameters
        if input_img_for_vbp.grad is not None:
            input_img_for_vbp.grad.data.zero_()

        # Target for backpropagation: Sum of features for the target_view_idx.
        # This gives a "general pattern" for that view's features.
        target_features_for_bp = features_per_view[0, target_view_idx, :] # Shape: (feature_dim)
        
        # Gradient for backward pass (sum of features -> gradient of 1 for each feature)
        grad_outputs_for_bp = torch.ones_like(target_features_for_bp)
        
        target_features_for_bp.backward(gradient=grad_outputs_for_bp, retain_graph=False)

        # --- Retrieve and process gradient on the input image ---
        saliency_grad = input_img_for_vbp.grad
        if saliency_grad is None:
            print("VBP: Gradient for input image was not computed.")
            # Hooks will be cleaned up by the next call to a viz method or external `visualizer.cleanup_hooks()`
            return None

        # Take absolute value of gradients for the target view
        saliency_abs = saliency_grad.data.abs() # Shape: (1, N_VIEWS, C, H, W)
        saliency_target_view_abs = saliency_abs[0, target_view_idx, :, :, :] # Shape: (C, H, W)
        
        saliency_target_view_np = saliency_target_view_abs.cpu().numpy()
        
        # Normalize: max across channels, then scale to [0, 255]
        saliency_map = np.max(saliency_target_view_np, axis=0) # Shape: (H, W)
        if np.max(saliency_map) > 0:
            saliency_map /= np.max(saliency_map)
        saliency_map_uint8 = (saliency_map * 255).astype(np.uint8)
        
        # Hooks will be cleaned up by the next call to a viz method or an explicit call to self.cleanup_hooks().
        return saliency_map_uint8
    def cleanup_hooks(self):
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []


# ========================
# 可视化评估函数
# ========================
def visualize_truth_evaluation(
    model: PPO,
    # MODIFIED: Now takes a specific subset of images to evaluate
    evaluation_image_subset: Dict[str, List[Tuple[np.ndarray, int]]], # List of (image_array, original_index_or_id)
    generic_env_img: np.ndarray, # Pass the pre-selected generic environment image
    env: gym.Env, # Still useful for some observation space properties if needed
    run_name="default"
):
    print(f"\nPerforming Truth visualization evaluation for {run_name}...")

    device = model.device # Get device from the model

    feature_extractor = model.policy.features_extractor
    if not hasattr(feature_extractor, 'perception'):
        print("Error: PPO model's feature_extractor does not have a 'perception' attribute.")
        return

    perception_module_from_agent = feature_extractor.perception
    decision_module_from_agent = model.policy.mlp_extractor.policy_net

    visualizer = ModelVisualizer(perception_module_from_agent, decision_module_from_agent)
    action_names = ['Escape', 'Eat', 'Wander']

    # 1. Activation Maximization (run once per model, as it generates images)
    print(f"[{run_name}] Generating Activation Maximization visualizations...")
    am_output_dir = f"am_results_{run_name}"
    os.makedirs(am_output_dir, exist_ok=True)
    for action_idx, action_name in enumerate(action_names):
        print(f"  Running AM for intent: {action_name}")
        am_image_np = visualizer.activation_maximization(action_idx, steps=200)
        if am_image_np is not None:
            plt.figure(figsize=(5, 5))
            plt.imshow(am_image_np)
            plt.title(f'AM ({run_name}): Target Intent Neuron {action_name}')
            plt.axis('off')
            plt.savefig(os.path.join(am_output_dir, f'am_{action_name.lower()}.png'))
            plt.close()
        else:
            print(f"  AM failed for intent: {action_name}")

    # --- Specific Image Evaluation ---
    print(f"[{run_name}] Generating Saliency Maps for specific image types...")

    saliency_output_dir = f"saliency_results_{run_name}"
    os.makedirs(saliency_output_dir, exist_ok=True)

    # MODIFIED: Iterate through the pre-selected subset of images
    for img_type_name, images_to_eval in evaluation_image_subset.items():
        print(f"  Evaluating image type: {img_type_name} ({len(images_to_eval)} images)")
        if not images_to_eval:
            print(f"    Skipping '{img_type_name}', no images provided in subset.")
            continue

        for target_image_np, img_idx_or_id in images_to_eval: # img_idx_or_id for unique naming
            print(f"    Processing image index/id: {img_idx_or_id} for type {img_type_name}")
            # Construct the 4-view observation: target image in front, others generic
            current_obs_np = np.stack([
                target_image_np,    # Front view (our target image)
                generic_env_img,    # Left view
                generic_env_img,    # Right view
                generic_env_img     # Back view
            ])

            obs_tensor_0_1 = torch.tensor(current_obs_np, dtype=torch.float32, device=device)
            obs_tensor_0_1 = obs_tensor_0_1.permute(0, 3, 1, 2).unsqueeze(0) / 255.0

            target_view_idx_for_viz = 0

            # 2. Saliency Maps (GradCAM) - For each of the 3 intents
            # print(f"    Generating GradCAM for {img_type_name} image {img_idx_or_id}...") # Verbose
            for intent_idx, intent_name in enumerate(action_names):
                grad_cam_map = visualizer.grad_cam(obs_tensor_0_1.clone(), intent_idx, target_view_idx=target_view_idx_for_viz)

                plt.figure(figsize=(10, 5))
                # MODIFIED: More specific title and filename
                title = f'GradCAM ({run_name}) - Img: {img_type_name}_{img_idx_or_id}, Intent: {intent_name}'
                filename = f'gradcam_{img_type_name}_{img_idx_or_id}_intent_{intent_name.lower()}.png'
                plt.suptitle(title, fontsize=14)

                plt.subplot(1, 2, 1)
                plt.imshow(target_image_np)
                plt.title(f'Input: {img_type_name} (ID: {img_idx_or_id})')
                plt.axis('off')

                if grad_cam_map is not None:
                    plt.subplot(1, 2, 2)
                    plt.imshow(target_image_np, alpha=0.7)
                    plt.imshow(grad_cam_map, cmap='jet', alpha=0.3)
                    plt.title(f'GradCAM Overlay')
                    plt.axis('off')
                else:
                    plt.subplot(1, 2, 2)
                    plt.text(0.5, 0.5, "GradCAM Failed", ha='center', va='center')
                    plt.axis('off')

                plt.tight_layout(rect=[0, 0, 1, 0.96])
                plt.savefig(os.path.join(saliency_output_dir, filename))
                plt.close()

            # 3. VisualBackProp
            # print(f"    Generating VisualBackProp for {img_type_name} image {img_idx_or_id}...") # Verbose
            vbp_map = visualizer.visual_back_prop(obs_tensor_0_1.clone(), target_view_idx=target_view_idx_for_viz)

            plt.figure(figsize=(10, 5))
            # MODIFIED: More specific title and filename
            title = f'VisualBackProp ({run_name}) - Img: {img_type_name}_{img_idx_or_id}'
            filename = f'vbp_{img_type_name}_{img_idx_or_id}.png'
            plt.suptitle(title, fontsize=14)

            plt.subplot(1, 2, 1)
            plt.imshow(target_image_np)
            plt.title(f'Input: {img_type_name} (ID: {img_idx_or_id})')
            plt.axis('off')

            if vbp_map is not None:
                plt.subplot(1, 2, 2)
                plt.imshow(vbp_map, cmap='gray')
                plt.title('VBP Saliency')
                plt.axis('off')
            else:
                plt.subplot(1,2,2)
                plt.text(0.5, 0.5, "VBP Failed", ha='center', va='center')
                plt.axis('off')

            plt.tight_layout(rect=[0, 0, 1, 0.96])
            plt.savefig(os.path.join(saliency_output_dir, filename))
            plt.close()

    visualizer.cleanup_hooks() # Final cleanup
    print(f"Truth visualization for {run_name} (specific images) complete. AM results in '{am_output_dir}', Saliency in '{saliency_output_dir}'.")

# 4. 主工作流程

## 4.1 加载数据

In [None]:

# ========================
# 主工作流程
# ========================

# Training parameters
TRAIN_TIMESTEPS = 20480 # Very short for testing, increase for real training (e.g., 50000+)

base_directory = "TrainingData"  # 替换为你的文件夹路径
image_datasets = load_images_to_dict(base_directory)
# 打印结构信息
print(f"Loaded {len(image_datasets)} folders")
for folder, images in image_datasets.items():
    print(f"{folder}: {len(images)} images, shape={images[0].shape if images else 'N/A'}")

env = SurvivalGameEnv(image_datasets)


Loaded 3 folders
environment: 22736 images, shape=(100, 100, 3)
food: 102790 images, shape=(100, 100, 3)
predator: 861 images, shape=(100, 100, 3)


## 4.2 预训练自编码器

In [None]:

# Pretrain Autoencoder
# print("Pretraining Autoencoder...")
# autoencoder_model = train_autoencoder(image_datasets, epochs=5, ae_save_path="autoencoder.pth") # Short epochs for test


## 4.3 训练Fitness模型

In [None]:

# Train Fitness model
print("\nTraining Fitness model...")
fitness_model = fitness_training(env, total_timesteps=TRAIN_TIMESTEPS)
fitness_model.save("ppo_fitness_model")


## 4.4 训练Truth模型

In [None]:

# Train Truth model
print("\nTraining Truth model...")
truth_model = truth_training(env, ae_path="autoencoder.pth", total_timesteps=TRAIN_TIMESTEPS)
truth_model.save("ppo_truth_model")


## 4.5 评估模型

In [9]:

# Load models if needed (example)
fitness_model = PPO.load("ppo_fitness_model", env=env)
truth_model = PPO.load("ppo_truth_model", env=env)

# Evaluate models
print("\nEvaluating models for survival...")
# 非实时渲染fitness模型评估
fitness_score = evaluate_survival(fitness_model, env, n_episodes=5)
truth_score = evaluate_survival(truth_model, env, n_episodes=5)
# 实时渲染fitness模型评估
# fitness_score = evaluate_survival_with_render(fitness_model, env, n_episodes=1)
# truth_score = evaluate_survival_with_render(truth_model, env, n_episodes=1)

print(f"\nResults:")
print(f"Fitness Model Average Survival: {fitness_score:.1f} steps")
print(f"Truth Model Average Survival: {truth_score:.1f} steps")

if 'image_datasets' not in locals():
    print("Reloading image_datasets for visualization...")
    base_directory = "TrainingData" 
    image_datasets = load_images_to_dict(base_directory)
# Truth visualization evaluation
print("\\nPreparing images for Truth visualization evaluation...")
N_IMAGES_PER_CATEGORY = 10
evaluation_image_subset = {}
image_categories_for_viz = ['predator', 'food', 'environment']

for category in image_categories_for_viz:
    if category in image_datasets and image_datasets[category]:
        available_images = image_datasets[category]
        num_to_sample = min(N_IMAGES_PER_CATEGORY, len(available_images))
        if len(available_images) < N_IMAGES_PER_CATEGORY:
            print(f"Warning: Category '{category}' has only {len(available_images)} images. Sampling all of them.")
        
        # Sample 'num_to_sample' images. random.sample ensures no replacement.
        # Store as (image_array, original_index_within_category_list_for_uniqueness)
        # If you need globally unique IDs, you might need a different strategy for 'idx'
        sampled_indices = random.sample(range(len(available_images)), num_to_sample)
        evaluation_image_subset[category] = [(available_images[i], i) for i in sampled_indices]
    else:
        print(f"Warning: No images found for category '{category}'. It will be skipped in visualization.")
        evaluation_image_subset[category] = []

# Select a generic environment image for filler views
generic_env_img_list = image_datasets.get('environment')
if not generic_env_img_list:
    print("Critical Warning: No 'environment' images found for generic filler. Using a black image.")
    # Assuming image shape from observation_space or a default
    h_obs, w_obs, c_obs = (100,100,3) 
    if hasattr(env.observation_space, 'shape') and len(env.observation_space.shape) == 4:
         h_obs, w_obs, c_obs = env.observation_space.shape[1:4]
    generic_env_img = np.zeros((h_obs, w_obs, c_obs), dtype=np.uint8)
else:
    generic_env_img = random.choice(generic_env_img_list)


# Truth visualization evaluation for Fitness Model
visualize_truth_evaluation(
    fitness_model, 
    evaluation_image_subset, 
    generic_env_img,
    env, 
    run_name="FitnessModel"
)

# Truth visualization evaluation for Truth Model
visualize_truth_evaluation(
    truth_model, 
    evaluation_image_subset, 
    generic_env_img,
    env, 
    run_name="TruthModel"
)


env.close()

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

Evaluating models for survival...

Results:
Fitness Model Average Survival: 143.4 steps
Truth Model Average Survival: 199.2 steps
\nPreparing images for Truth visualization evaluation...

Performing Truth visualization evaluation for FitnessModel...
[FitnessModel] Generating Activation Maximization visualizations...
  Running AM for intent: Escape
Starting AM for action 0...
AM step 0, loss 2.053948402404785
AM step 20, loss 0.3078014850616455
AM step 40, loss 0.18436838686466217
AM step 60, loss 0.08877288550138474
AM step 80, loss 0.08494089543819427
AM step 100, loss 0.08495432138442993
AM step 120, loss 0.0851912647485733
AM step 140, loss 0.08495713025331497
AM step 160, loss 0.08518572151660919
AM step 180, loss 0.08520476520061493
  Running AM for intent: Eat
Starting AM for action 1...
AM step 0, loss 2.029921770095825
AM step 20, loss 0.2604913115501404
AM step 40, loss 0.13519862294197083
AM step 60