# 0. 导入

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.policies import BaseFeaturesExtractor
from PIL import Image
import matplotlib.pyplot as plt
from typing import Tuple, List, Dict, Optional, Any
import cv2
import random
from collections import defaultdict
import os

def load_images_to_dict(base_dir: str) -> Dict[str, List[np.ndarray]]:
    """Load the images of every sub-folder of ``base_dir`` into memory.

    Each immediate sub-directory becomes one dictionary entry whose value is
    the list of its images, decoded with PIL, converted to RGB and stored as
    NumPy arrays.  Files whose extension is not a known image extension are
    skipped, and files that fail to load are reported on stdout and skipped.

    Folders and files are visited in sorted order.  ``os.listdir`` returns
    entries in arbitrary order, which would make list indices into the
    returned image lists differ between runs and machines.

    Args:
        base_dir: Directory whose sub-directories hold the images.

    Returns:
        Mapping from sub-folder name to the list of loaded RGB images.
    """
    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif'}
    image_datasets: Dict[str, List[np.ndarray]] = {}

    for folder_name in sorted(os.listdir(base_dir)):
        folder_path = os.path.join(base_dir, folder_name)

        # Only sub-directories define a category; stray files are ignored.
        if not os.path.isdir(folder_path):
            continue

        image_list: List[np.ndarray] = []
        for filename in sorted(os.listdir(folder_path)):
            ext = os.path.splitext(filename)[1].lower()
            if ext not in valid_extensions:
                continue  # not an image file

            file_path = os.path.join(folder_path, filename)
            try:
                # Convert to RGB so every array has exactly three channels
                # (drops alpha, expands grayscale/palette images).
                with Image.open(file_path) as img:
                    image_list.append(np.array(img.convert('RGB')))
            except Exception as e:
                # Best effort: one unreadable file must not abort the load.
                print(f"Error loading {file_path}: {str(e)}")

        image_datasets[folder_name] = image_list

    return image_datasets

class CustomImageDataset(Dataset):
    """Dataset over an in-memory list of images stored as NumPy arrays."""

    def __init__(self, image_list: List[np.ndarray], transform: Any = None):
        """
        Args:
            image_list (List[np.ndarray]): Images as NumPy arrays.
            transform (callable, optional): Applied to each image when it is
                fetched; ``transforms.ToTensor()`` is used when omitted.
        """
        self.images = image_list
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the transformed image stored at position ``idx``."""
        sample = self.images[idx]

        # Non-uint8 arrays are cast directly, without any rescaling.
        # NOTE(review): this presumes values already lie in 0-255; a float
        # image in [0, 1] would be truncated to (almost) all zeros here.
        if sample.dtype != np.uint8:
            sample = sample.astype(np.uint8)

        # The array is handed to the transform as-is (no PIL round trip).
        # Without a user transform, fall back to ToTensor so the caller still
        # receives a tensor.
        converter = self.transform if self.transform else transforms.ToTensor()
        return converter(sample)


# 1. Survival Game

In [None]:

# ========================
# Environment configuration
# ========================
# Side length of the square grid map.
MAP_SIZE = 20
# Number of predators placed on the map edge at reset.
PREDATOR_COUNT = 5
# Upper bound on the number of food cells generated at reset.
MAX_FOOD = 50
# Spread of a food cluster: offsets are drawn with std CLUSTER_RADIUS / 2.
CLUSTER_RADIUS = 3
# Meaning of the integer codes stored in the map (3 is only used when rendering).
PIXEL_TYPES = {
    0: 'environment',
    1: 'predator',
    2: 'food',
    3: 'agent'
}
# Reward granted for stepping onto each kind of cell.
REWARDS = {
    'predator': -50,  # set to the -50 required by the spec document
    'food': +10,      # set to the +10 required by the spec document
    'environment': -1
}
NUM_VIEWS = 4  # Front, Left, Right, Back
INITIAL_SCORE = 100  # score the agent starts each episode with

# ========================
# 环境实现
# ========================
class SurvivalGameEnv(gym.Env):
    """Grid-world survival game observed through images of the four neighbouring cells.

    The map is a ``MAP_SIZE`` x ``MAP_SIZE`` grid of entity codes
    (0 environment, 1 predator, 2 food).  The agent is tracked separately in
    ``agent_pos`` and is never written into the map.  Agent movement and
    perception wrap around the map edges; predators do not wrap and are
    re-spawned on a random edge cell when they would leave the map.

    ``step`` accepts either a probability vector over the three behaviours
    (escape, eat, wander) or a single discrete behaviour index, which is
    treated as a one-hot vector.  The latter is what an agent sampling from
    ``action_space`` sends.
    """

    metadata = {'render_modes': ['human'], 'render_fps': 4}

    # Grid step (d_row, d_col) -> absolute heading code stored in
    # ``agent_direction``: 0 up, 1 left, 2 right, 3 down.
    _HEADING_OF_VECTOR = {(-1, 0): 0, (0, -1): 1, (0, 1): 2, (1, 0): 3}

    def __init__(self, image_datasets: Dict[str, List[np.ndarray]]):
        """Create the environment.

        Args:
            image_datasets: Images per entity kind; the keys ``'predator'``,
                ``'food'`` and ``'environment'`` are read by this class.
        """
        super(SurvivalGameEnv, self).__init__()

        self.image_datasets = image_datasets
        # This transform is for external use if needed, model handles its own
        self.vis_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.action_space = spaces.Discrete(3)  # Escape, Eat, Wander
        self.observation_space = spaces.Box(
            low=0, high=255,
            shape=(NUM_VIEWS, 100, 100, 3),  # (4 views, H, W, C)
            dtype=np.uint8
        )

        # Rendering state, created lazily by render().
        self.current_map_image = None
        self._render_figure = None
        self._render_axes = None

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Start a new episode: rebuild the map and place predators, food and the agent.

        Args:
            seed: Seed forwarded to ``gym.Env.reset`` and used to seed the
                global ``numpy.random`` and ``random`` generators.
            options: Unused.

        Returns:
            The initial observation and an empty info dict.
        """
        super().reset(seed=seed)
        np.random.seed(seed)
        random.seed(seed)

        self.map = np.zeros((MAP_SIZE, MAP_SIZE), dtype=np.uint8)
        self.predators = []
        self.foods = {}

        # Predators start on distinct, empty edge cells.
        for _ in range(PREDATOR_COUNT):
            pos = self._random_edge_position()
            while self.map[pos] != 0:
                pos = self._random_edge_position()
            self.map[pos] = 1
            self.predators.append({
                'pos': pos,
                'img_idx': np.random.randint(len(self.image_datasets['predator']))
            })

        # Food is generated in clusters, one cluster per 10 food items.
        num_clusters = max(1, MAX_FOOD // 10)
        cluster_centers = [self._random_position() for _ in range(num_clusters)]

        for center in cluster_centers:
            cluster_size = min(MAX_FOOD - len(self.foods), MAX_FOOD // num_clusters)
            for _ in range(cluster_size):
                # Gaussian offset around the cluster centre, clamped to the map.
                dx = int(np.random.normal(0, CLUSTER_RADIUS / 2))
                dy = int(np.random.normal(0, CLUSTER_RADIUS / 2))
                pos = (
                    int(np.clip(center[0] + dx, 0, MAP_SIZE - 1)),
                    int(np.clip(center[1] + dy, 0, MAP_SIZE - 1))
                )

                # Occupied cells are skipped, so a cluster may end up smaller.
                if self._is_valid_position(pos) and self.map[pos] == 0:
                    self.map[pos] = 2
                    self.foods[pos] = {
                        'img_idx': np.random.randint(len(self.image_datasets['food'])),
                        'consumed': False
                    }
                    if len(self.foods) >= MAX_FOOD:
                        break
            if len(self.foods) >= MAX_FOOD:
                break

        # The agent starts on a random empty cell.
        self.agent_pos = self._random_position()
        while self.map[self.agent_pos] != 0:
            self.agent_pos = self._random_position()

        self.agent_direction = 0  # absolute heading: 0 up, 1 left, 2 right, 3 down
        self.score = INITIAL_SCORE

        self.prev_agent_pos = self.agent_pos
        self.steps = 0
        self.terminated = False
        self.truncated = False

        return self._get_observation(), {}

    def step(self, action_probs: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
        """Advance the game by one move of the agent followed by one move of every predator.

        Args:
            action_probs: Probabilities of the escape / eat / wander
                behaviours, or a single discrete behaviour index.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
        """
        action_probs = self._coerce_action(action_probs)

        # 1. Turn behaviour probabilities into a distribution over the four
        #    relative directions and sample one.
        move_probs_dist = self._calculate_movement_distribution(action_probs)
        chosen_direction_idx = np.random.choice(NUM_VIEWS, p=move_probs_dist)

        # 2. Move.  The relative index must be resolved to a grid vector
        #    before the heading changes, and the new heading is the absolute
        #    direction of that vector -- not the relative index itself.
        dr, dc = self._get_direction_vectors()[chosen_direction_idx]
        self.prev_agent_pos = self.agent_pos
        self.agent_pos = self._move_in_direction(self.agent_pos, chosen_direction_idx)
        self.agent_direction = self._HEADING_OF_VECTOR[(dr, dc)]

        # 3. Reward for the cell the agent stepped onto.
        reward = 0
        caught = False
        collided_entity_type = self.map[self.agent_pos]

        if collided_entity_type == 1:  # predator
            reward += REWARDS['predator']
            caught = True
        elif collided_entity_type == 2:  # food
            if self.agent_pos in self.foods:
                reward += REWARDS['food']
                # Eaten food disappears immediately.
                del self.foods[self.agent_pos]
                self.map[self.agent_pos] = 0
        else:  # plain environment (or any unknown code)
            reward += REWARDS['environment']

        # 4. Predators move; one of them may walk onto the agent.  The
        #    penalty is applied at most once per step.
        self._move_predators()
        if not caught and any(pred['pos'] == self.agent_pos for pred in self.predators):
            reward += REWARDS['predator']
            caught = True

        # 5. Update the score with the complete reward of this step and check
        #    the end-of-game conditions.
        self.score += reward
        if caught or self.score < 1:
            self.terminated = True

        # 6. Step limit.
        self.steps += 1
        if self.steps >= 1000:
            self.truncated = True

        return self._get_observation(), float(reward), self.terminated, self.truncated, {}

    def _coerce_action(self, action: Any) -> np.ndarray:
        """Return ``action`` as a flat probability vector over the behaviours.

        A single value is interpreted as a discrete behaviour index and
        expanded to a one-hot vector; anything else is flattened unchanged.
        """
        values = np.asarray(action, dtype=np.float64).reshape(-1)
        if values.size == 1:
            one_hot = np.zeros(self.action_space.n, dtype=np.float64)
            one_hot[int(values[0])] = 1.0
            return one_hot
        return values

    def _random_image(self, category: str) -> np.ndarray:
        """Pick a random image of ``category``.

        Indexing is used instead of ``np.random.choice`` because the latter
        only accepts one-dimensional inputs and rejects a list of images.
        """
        images = self.image_datasets[category]
        return images[np.random.randint(len(images))]

    def _get_direction_vectors(self) -> List[Tuple[int, int]]:
        """Return the grid steps for front, left, right, back relative to the current heading."""
        # Absolute headings: 0 up, 1 left, 2 right, 3 down.
        if self.agent_direction == 0:  # facing up
            return [(-1, 0), (0, -1), (0, 1), (1, 0)]
        elif self.agent_direction == 1:  # facing left
            return [(0, -1), (1, 0), (-1, 0), (0, 1)]
        elif self.agent_direction == 2:  # facing right
            return [(0, 1), (-1, 0), (1, 0), (0, -1)]
        else:  # facing down
            return [(1, 0), (0, 1), (0, -1), (-1, 0)]

    def _get_observation(self) -> np.ndarray:
        """Return one image per neighbouring cell, ordered front, left, right, back."""
        observations = []

        for dr, dc in self._get_direction_vectors():
            target_pos = (
                (self.agent_pos[0] + dr) % MAP_SIZE,
                (self.agent_pos[1] + dc) % MAP_SIZE
            )
            entity_type = self.map[target_pos]

            if entity_type == 1:  # predator
                # Show the image assigned to the predator standing there;
                # fall back to a random predator image if none is registered.
                predator = next(
                    (p for p in self.predators if p['pos'] == target_pos),
                    None
                )
                if predator is not None:
                    img_array = self.image_datasets['predator'][predator['img_idx']]
                else:
                    img_array = self._random_image('predator')
            elif entity_type == 2:  # food
                img_array = self.image_datasets['food'][self.foods.get(target_pos, {}).get('img_idx', 0)]
            else:  # environment
                img_array = self._random_image('environment')

            observations.append(img_array)

        # NOTE(review): stacking requires all images to share one shape;
        # observation_space declares 100x100x3 -- confirm the datasets match.
        return np.array(observations, dtype=np.uint8)

    def _calculate_movement_distribution(self, action_probs: np.ndarray) -> np.ndarray:
        """Convert behaviour probabilities into a distribution over the four relative directions.

        Escape mass goes to the direction opposite each adjacent predator,
        eat mass goes towards each adjacent food, and wander mass (as well as
        escape/eat mass with nothing to react to) is spread uniformly.
        """
        surrounding_entities = []
        for dr, dc in self._get_direction_vectors():
            pos = (
                (self.agent_pos[0] + dr) % MAP_SIZE,
                (self.agent_pos[1] + dc) % MAP_SIZE
            )
            surrounding_entities.append(self.map[pos])

        move_dist = np.zeros(NUM_VIEWS)

        # Escape: move away from every sensed predator.  With the ordering
        # front, left, right, back the opposite of index i is 3 - i.
        predator_sensed = False
        for i, entity in enumerate(surrounding_entities):
            if entity == 1:
                move_dist[3 - i] += action_probs[0]
                predator_sensed = True
        if not predator_sensed:
            move_dist += action_probs[0] / NUM_VIEWS

        # Eat: move towards every sensed food.
        food_sensed = False
        for i, entity in enumerate(surrounding_entities):
            if entity == 2:
                move_dist[i] += action_probs[1]
                food_sensed = True
        if not food_sensed:
            move_dist += action_probs[1] / NUM_VIEWS

        # Wander: uniform over all directions.
        move_dist += action_probs[2] / NUM_VIEWS

        # Normalise; fall back to uniform when no mass was assigned at all.
        total = np.sum(move_dist)
        if total > 0:
            move_dist /= total
        else:
            move_dist = np.ones(NUM_VIEWS) / NUM_VIEWS

        return move_dist

    def _move_in_direction(self, current_pos: Tuple[int, int], direction_idx: int) -> Tuple[int, int]:
        """Return ``current_pos`` moved one cell in a relative direction, wrapping at the edges."""
        dr, dc = self._get_direction_vectors()[direction_idx]
        new_r = (current_pos[0] + dr) % MAP_SIZE
        new_c = (current_pos[1] + dc) % MAP_SIZE
        return new_r, new_c

    def _move_predators(self):
        """Move every predator one cell, never onto food or onto another predator.

        A predator whose chosen step leaves the map is replaced by a new
        predator on a random free edge cell.  A predator that cannot move
        stays where it is; if its own cell has been taken in the meantime, a
        replacement is spawned on a random free cell.
        """
        # Temporarily clear all predator marks so vacated cells count as free.
        for pred in self.predators:
            self.map[pred['pos']] = 0

        moved_positions = set()  # cells already claimed during this turn
        new_predators = []
        move_vectors = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # up, down, left, right

        # Random processing order so no predator is systematically favoured.
        random.shuffle(self.predators)

        for pred in self.predators:
            current_pos = pred['pos']
            moved = False

            random.shuffle(move_vectors)
            for dr, dc in move_vectors:
                new_pos = (current_pos[0] + dr, current_pos[1] + dc)

                if not (0 <= new_pos[0] < MAP_SIZE and 0 <= new_pos[1] < MAP_SIZE):
                    # Leaving the map: respawn as a new predator on the edge,
                    # or try the next direction if the sampled cell is taken.
                    edge_pos = self._random_edge_position()
                    if self.map[edge_pos] == 0 and edge_pos not in moved_positions:
                        new_predators.append({
                            'pos': edge_pos,
                            'img_idx': np.random.randint(len(self.image_datasets['predator']))
                        })
                        self.map[edge_pos] = 1
                        moved_positions.add(edge_pos)
                        moved = True
                        break
                    continue

                # Only free cells are valid targets.  The agent is not stored
                # in the map, so its cell counts as free.
                if self.map[new_pos] != 1 and self.map[new_pos] != 2 and new_pos not in moved_positions:
                    pred['pos'] = new_pos
                    self.map[new_pos] = 1
                    moved_positions.add(new_pos)
                    new_predators.append(pred)
                    moved = True
                    break

            if moved:
                continue

            if current_pos not in moved_positions and self.map[current_pos] != 2:
                # Nowhere to go: stay put.
                self.map[current_pos] = 1
                moved_positions.add(current_pos)
                new_predators.append(pred)
            else:
                # Own cell was taken: spawn a replacement on a free cell.
                # Retry a bounded number of times so the predator is not
                # silently lost when the first sampled cell is occupied.
                for _ in range(MAP_SIZE * MAP_SIZE):
                    new_pos = self._random_position()
                    if self.map[new_pos] == 0 and new_pos not in moved_positions:
                        new_predators.append({
                            'pos': new_pos,
                            'img_idx': np.random.randint(len(self.image_datasets['predator']))
                        })
                        self.map[new_pos] = 1
                        moved_positions.add(new_pos)
                        break

        self.predators = new_predators

    def _random_position(self) -> Tuple[int, int]:
        """Return a uniformly random cell of the map."""
        return (np.random.randint(0, MAP_SIZE), np.random.randint(0, MAP_SIZE))

    def _random_edge_position(self) -> Tuple[int, int]:
        """Return a random cell on one of the four map edges."""
        edge = np.random.choice(['top', 'bottom', 'left', 'right'])
        if edge == 'top':
            return (0, np.random.randint(0, MAP_SIZE))
        if edge == 'bottom':
            return (MAP_SIZE - 1, np.random.randint(0, MAP_SIZE))
        if edge == 'left':
            return (np.random.randint(0, MAP_SIZE), 0)
        return (np.random.randint(0, MAP_SIZE), MAP_SIZE - 1)  # right

    def _is_valid_position(self, pos: Tuple[int, int]) -> bool:
        """Return whether ``pos`` lies inside the map."""
        return 0 <= pos[0] < MAP_SIZE and 0 <= pos[1] < MAP_SIZE

    def render(self):
        """Draw the current map, with the agent marked, in an interactive matplotlib window."""
        if self.current_map_image is None:
            fig, ax = plt.subplots(figsize=(10, 10))
            self.current_map_image = ax.imshow(
                self.map,
                cmap='viridis',
                vmin=0,
                vmax=len(PIXEL_TYPES) - 1
            )
            fig.colorbar(self.current_map_image, ax=ax, label='Entity Types')
            ax.set_title("Survival Game Environment")
            ax.axis('off')
            # Keep handles so later updates and close() address this figure
            # even if the caller has other matplotlib figures open.
            self._render_figure = fig
            self._render_axes = ax
            plt.ion()
            plt.show()

        # The agent is not part of the map; overlay it with code 3.
        display_map = self.map.copy()
        display_map[self.agent_pos] = 3

        self.current_map_image.set_data(display_map)
        self._render_axes.set_title(
            f"Step: {self.steps}, Agent Position: {self.agent_pos}, Score: {self.score:.1f}"
        )
        self._render_figure.canvas.draw_idle()
        plt.pause(0.1)

    def close(self):
        """Close the render window, if one was opened."""
        if self.current_map_image is not None:
            plt.ioff()
            plt.close(self._render_figure)
            self.current_map_image = None
            self._render_figure = None
            self._render_axes = None

# 2. 模型

## 2.1 模型架构

In [None]:

# ========================
# Model Architecture (Wrappers for Perception)
# ========================
class PerceptionModule(nn.Module):
    """ConvNeXt-Tiny backbone applied independently to every view of an observation.

    The backbone's classifier head is replaced by ``nn.Identity`` so the
    module yields features instead of class logits.  When pretrained weights
    are used, inputs are normalised with the ImageNet mean/std first.
    """

    def __init__(self, pretrained_convnext=True):
        """
        Args:
            pretrained_convnext: Load the default pretrained ConvNeXt-Tiny
                weights (and enable input normalisation) when true.
        """
        super(PerceptionModule, self).__init__()
        weights = models.ConvNeXt_Tiny_Weights.DEFAULT if pretrained_convnext else None
        self.convnext = models.convnext_tiny(weights=weights)
        self.convnext.classifier = nn.Identity()  # remove original classifier

        self.do_normalize = pretrained_convnext
        if self.do_normalize:
            self.normalize_transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of multi-view observations.

        Args:
            x: Tensor of shape (B, N_VIEWS, C, H, W); expected to be scaled
                to [0, 1].

        Returns:
            Tensor of shape (B, N_VIEWS, feature_dim).
        """
        batch_size, num_views, C, H, W = x.size()

        # Fold the views into the batch dimension.  reshape (not view) is used
        # so that non-contiguous inputs, e.g. permuted tensors, are accepted.
        x_input_to_convnext = x.reshape(batch_size * num_views, C, H, W)

        if self.do_normalize:
            x_input_to_convnext = self.normalize_transform(x_input_to_convnext)

        raw_features = self.convnext(x_input_to_convnext)
        # Flatten whatever per-image feature shape the backbone produced.
        return raw_features.reshape(batch_size, num_views, -1)

class PretrainedEncoderWrapper(nn.Module):
    """Adapt a single-image encoder to multi-view observations.

    The wrapped encoder is applied to every view separately and the results
    are regrouped per observation.  No input normalisation is applied here.
    """

    def __init__(self, encoder_model: nn.Module):
        """
        Args:
            encoder_model: Module mapping a (N, C, H, W) batch to per-image
                features (in this file: the autoencoder's encoder).
        """
        super().__init__()
        self.encoder = encoder_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a (B, N_VIEWS, C, H, W) batch into (B, N_VIEWS, feature_dim) features."""
        batch_size, num_views, C, H, W = x.size()

        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors,
        # are accepted as well.
        x_reshaped = x.reshape(batch_size * num_views, C, H, W)

        features = self.encoder(x_reshaped)
        return features.reshape(batch_size, num_views, -1)

# Decision module as defined in doc (used by ModelVisualizer, SB3 PPO has its own MLP head)
class StandaloneDecisionModule(nn.Module):
    """MLP head mapping concatenated view features to three action probabilities."""

    def __init__(self, input_dim: int = NUM_VIEWS * 768, hidden_dims: List[int] = [512, 256]):
        """
        Args:
            input_dim: Size of the flattened feature vector fed to the head.
            hidden_dims: Widths of the hidden Linear+ReLU layers, in order.
        """
        super(StandaloneDecisionModule, self).__init__()
        widths = [input_dim, *hidden_dims]

        stack = []
        # One Linear+ReLU pair per consecutive pair of widths.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Final projection to the 3 actions, normalised to probabilities.
        stack += [nn.Linear(widths[-1], 3), nn.Softmax(dim=-1)]

        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return action probabilities for ``x`` of shape (B, input_dim)."""
        probabilities = self.net(x)
        return probabilities


# ========================
# 自编码器
# ========================
class Autoencoder(nn.Module):
    """Image autoencoder: ConvNeXt-Tiny encoder plus a transposed-convolution decoder."""

    def __init__(self):
        super(Autoencoder, self).__init__()

        # Encoder: pretrained ConvNeXt-Tiny with its classifier head removed.
        backbone = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT)
        backbone.classifier = nn.Identity()
        self.encoder = backbone

        # Decoder: (B, 768, 1, 1) -> (B, 3, 100, 100).
        # Each entry: (in_channels, out_channels, kernel_size, stride, padding),
        # followed by the spatial size it produces.
        upsampling_specs = [
            (768, 512, 5, 1, 0),  # -> 5 x 5
            (512, 256, 5, 5, 0),  # -> 25 x 25
            (256, 128, 4, 2, 1),  # -> 50 x 50
            (128, 64, 4, 2, 1),   # -> 100 x 100
        ]
        stages = []
        for c_in, c_out, kernel, stride, pad in upsampling_specs:
            stages.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad))
            stages.append(nn.ReLU(True))
        # Size-preserving projection to RGB, squashed into [0, 1].
        stages.append(nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1))
        stages.append(nn.Sigmoid())
        self.decoder = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the image batch ``x`` (B, 3, H, W) and return its reconstruction."""
        # The decoder expects a 768-channel 1x1 feature map per image.
        latent = self.encoder(x).view(-1, 768, 1, 1)
        return self.decoder(latent)

# ========================
# 特征提取器 for Stable Baselines3
# ========================
class CustomFeatureExtractor(BaseFeaturesExtractor):
    """Feature extractor that runs a perception module over every view of an observation."""

    def __init__(self, observation_space: spaces.Box, perception_module: nn.Module):
        """
        Args:
            observation_space: Observation space of the environment.
            perception_module: Module taking (B, N_VIEWS, C, H, W) inputs and
                returning per-view features; 768 features per view are
                assumed when declaring ``features_dim``.
        """
        super(CustomFeatureExtractor, self).__init__(observation_space, features_dim=NUM_VIEWS * 768)
        self.perception = perception_module

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map (B, N_VIEWS, H, W, C) observations in the 0-255 range to (B, N_VIEWS * feature_dim) features."""
        batch = observations.size(0)

        # Channels-last -> channels-first, then scale pixel values to [0, 1].
        channels_first = observations.permute(0, 1, 4, 2, 3)
        scaled = channels_first.float() / 255.0

        # One feature vector per view, concatenated per observation.
        per_view = self.perception(scaled)
        return per_view.reshape(batch, -1)

## 2.2 模型训练

In [None]:

# ========================
# 训练方法
# ========================
def fitness_training(env: gym.Env, total_timesteps: int = 10000):
    """Train a PPO agent whose perception is a pretrained ConvNeXt backbone.

    Args:
        env: Environment to train on.
        total_timesteps: Number of environment steps to train for.

    Returns:
        The trained PPO model.
    """
    # PerceptionModule handles its own normalization if pretrained
    perception_fitness = PerceptionModule(pretrained_convnext=True)

    policy_kwargs = dict(
        features_extractor_class=CustomFeatureExtractor,
        features_extractor_kwargs=dict(perception_module=perception_fitness),
        # Separate MLP heads for policy and value.  Passed as a plain dict:
        # the list-wrapped form [dict(...)] is deprecated in Stable-Baselines3.
        net_arch=dict(pi=[256, 128], vf=[256, 128]),
    )

    model = PPO("MlpPolicy", env, verbose=1, policy_kwargs=policy_kwargs, device="auto")
    model.learn(total_timesteps=total_timesteps)
    return model

def truth_training(env: gym.Env, ae_path: str, total_timesteps: int = 10000):
    """Train a PPO agent on top of a frozen, pretrained autoencoder encoder.

    Args:
        env: Environment to train on.
        ae_path: Path of the autoencoder ``state_dict`` saved by
            ``train_autoencoder``.
        total_timesteps: Number of environment steps to train for.

    Returns:
        The trained PPO model.
    """
    autoencoder = Autoencoder()
    # Load onto the CPU first so a checkpoint saved on a GPU machine can be
    # restored on a CPU-only host; PPO moves the policy to its device later.
    autoencoder.load_state_dict(torch.load(ae_path, map_location="cpu"))
    ae_encoder = autoencoder.encoder

    # Freeze the encoder: only the PPO heads are trained.
    for param in ae_encoder.parameters():
        param.requires_grad = False

    # Wrap the frozen AE encoder to match the perception module interface
    perception_truth = PretrainedEncoderWrapper(ae_encoder)

    policy_kwargs = dict(
        features_extractor_class=CustomFeatureExtractor,
        features_extractor_kwargs=dict(perception_module=perception_truth),
        # Plain dict form; the list-wrapped [dict(...)] form is deprecated
        # in Stable-Baselines3.
        net_arch=dict(pi=[256, 128], vf=[256, 128]),
    )

    model = PPO("MlpPolicy", env, verbose=1, policy_kwargs=policy_kwargs, device="auto")
    model.learn(total_timesteps=total_timesteps)
    return model

def train_autoencoder(image_datasets: Dict[str, List[np.ndarray]],
                      epochs: int = 10,
                      batch_size: int = 32,
                      ae_save_path: str = "autoencoder.pth",
                      num_workers: int = 0
                     ) -> Autoencoder:
    """Train an ``Autoencoder`` to reconstruct all given images and save its weights.

    Args:
        image_datasets: Images per category; all categories are pooled.
        epochs: Number of passes over the pooled images.
        batch_size: Mini-batch size.
        ae_save_path: File the trained ``state_dict`` is written to.
        num_workers: Worker processes for the ``DataLoader`` (0 = main process).

    Returns:
        The trained model.  If no images are available, the untrained model
        is returned and nothing is saved.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Autoencoder().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Pool the images of every category into one training set.
    all_images_np = []
    for cat_imgs in image_datasets.values():
        all_images_np.extend(cat_imgs)

    if not all_images_np:
        print("No images found to train the autoencoder. Aborting.")
        return model

    # ToTensor turns each HWC image array into a CHW float tensor.
    custom_dataset = CustomImageDataset(
        image_list=all_images_np,
        transform=transforms.Compose([transforms.ToTensor()]),
    )

    # Batches are produced on the fly and moved to the device one at a time;
    # pinned memory is only enabled when training on CUDA.
    dataloader = DataLoader(
        custom_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=(device.type == 'cuda'),
    )

    n_images = len(custom_dataset)
    n_batches = len(dataloader)
    report_every = max(1, n_batches // 10)  # roughly 10 progress lines per epoch

    print(f"Training Autoencoder with {n_images} images on {device}...")
    print(f"Number of batches per epoch: {n_batches}")

    for epoch_no in range(1, epochs + 1):
        model.train()
        running_loss = 0.0

        for batch_no, inputs in enumerate(dataloader, start=1):
            inputs = inputs.to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), inputs)  # reconstruction error
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            # Weight by batch size so the epoch average is per image.
            running_loss += batch_loss * inputs.size(0)

            if batch_no % report_every == 0:
                print(f"  Epoch {epoch_no}/{epochs}, Batch {batch_no}/{n_batches}, "
                      f"Current Avg Loss: {batch_loss:.4f}")

        epoch_loss = running_loss / n_images
        print(f"Epoch {epoch_no}/{epochs} finished. Average Training Loss: {epoch_loss:.4f}")

    torch.save(model.state_dict(), ae_save_path)
    print(f"Autoencoder trained and saved to {ae_save_path}")
    return model


# 3. 评估

## 3.1 Survival Evaluation

In [None]:

# ========================
# 生存评估函数
# ========================
def evaluate_survival(model: PPO, env: gym.Env, n_episodes: int = 10) -> float:
    """Return the average number of steps the agent survives per episode.

    At every step the policy's full action probability distribution (not a
    sampled action) is passed to ``env.step``.

    Args:
        model: Trained PPO model.
        env: Environment whose ``step`` accepts an action probability vector.
        n_episodes: Number of episodes to average over.

    Returns:
        Mean episode length, or 0.0 when ``n_episodes`` is not positive
        (consistent with ``evaluate_survival_with_render``).
    """
    if n_episodes <= 0:
        return 0.0

    total_steps = 0
    for _ in range(n_episodes):
        obs, _ = env.reset()
        terminated = False
        truncated = False
        episode_steps = 0

        while not (terminated or truncated):
            with torch.no_grad():
                # Add a batch dimension and place the observation on the
                # model's device; get_distribution is given the raw
                # observation tensor.
                obs_tensor = torch.as_tensor(obs, device=model.device).unsqueeze(0)
                dist = model.policy.get_distribution(obs_tensor)
                # Drop the batch dimension again to get one probability per action.
                action_probs = dist.distribution.probs.cpu().numpy().squeeze()

            obs, _, terminated, truncated, _ = env.step(action_probs)
            episode_steps += 1

        total_steps += episode_steps

    return total_steps / n_episodes

def evaluate_survival_with_render(model: PPO, env: gym.Env, n_episodes: int = 1) -> float:
    """Run rendered evaluation episodes and return the average episode length.

    The environment is rendered before every step, and progress is printed
    every 50 steps and at the end of each episode.  The render window is
    closed when evaluation ends, including when an episode raises.

    Args:
        model: Trained PPO model.
        env: Environment providing ``render``, ``close`` and a ``score``
            attribute, and whose ``step`` accepts an action probability vector.
        n_episodes: Number of episodes to run.

    Returns:
        Mean episode length, or 0.0 when ``n_episodes`` is not positive.
    """
    total_steps = 0
    if n_episodes <= 0:
        return 0.0

    try:
        for episode in range(n_episodes):
            obs, _ = env.reset()
            terminated = False
            truncated = False
            episode_steps = 0
            print(f"Starting Episode {episode + 1} with rendering...")

            while not (terminated or truncated):
                env.render()

                with torch.no_grad():
                    # Add a batch dimension and place the observation on the
                    # model's device; get_distribution is given the raw
                    # observation tensor.
                    obs_tensor = torch.as_tensor(obs, device=model.device).unsqueeze(0)
                    distribution = model.policy.get_distribution(obs_tensor)
                    # Drop the batch dimension: one probability per action.
                    action_probs = distribution.distribution.probs.cpu().numpy().squeeze()

                # The environment receives the whole probability distribution.
                obs, reward, terminated, truncated, _ = env.step(action_probs)
                episode_steps += 1

                if episode_steps % 50 == 0:
                    print(f"  Episode {episode + 1}, Step {episode_steps}, Current Score: {env.score:.1f}")

            total_steps += episode_steps
            print(f"Episode {episode + 1} finished after {episode_steps} steps. Final Score: {env.score:.1f}")
    finally:
        # Always release the render window, even if an episode failed.
        env.close()

    return total_steps / n_episodes

In [None]:

# ========================
# 可视化评估工具
# ========================
class ModelVisualizer:
    """Gradient-based visualisations for a perception/decision network pair.

    Three techniques are offered: activation maximisation, Grad-CAM, and a
    guided-backpropagation style saliency map (called VisualBackProp here).

    ``perception`` is called with a 5-D image batch and its output is
    flattened with ``.view(1, -1)`` before being fed to ``decision``. The
    output of ``decision`` is indexed as ``[0, action_idx]``, so it should
    emit one score per action.
    NOTE(review): a bare SB3 ``mlp_extractor.policy_net`` yields latent units
    rather than per-action scores -- confirm what callers pass in.

    Hooks installed by :meth:`grad_cam` and :meth:`visual_back_prop` stay on
    the network until the next visualisation call or an explicit
    :meth:`cleanup_hooks`.
    """

    def __init__(self, perception_module: nn.Module, decision_module: nn.Module):
        self.perception = perception_module
        self.decision = decision_module

        # Captured by the Grad-CAM hooks on the target layer.
        self.target_layer_hook_activations: Optional[torch.Tensor] = None
        self.target_layer_hook_gradients: Optional[torch.Tensor] = None

        # Handles of every hook currently registered by this visualizer.
        self._hook_handles: List[torch.utils.hooks.RemovableHandle] = []
        # Forward outputs of hooked activation modules (used by guided backprop).
        self.module_to_forward_output: Dict[nn.Module, torch.Tensor] = {}

    def cleanup_hooks(self):
        """Remove all registered hooks and drop every tensor they captured."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

        self.target_layer_hook_activations = None
        self.target_layer_hook_gradients = None
        self.module_to_forward_output = {}

    # --- Hook callback functions ---
    def _store_forward_output_hook(self, module: nn.Module, input_val: Any, output_val: torch.Tensor):
        """Forward hook: remember the module's (detached) output."""
        self.module_to_forward_output[module] = output_val.detach()

    def _guided_grad_input(self, module: nn.Module, grad_output: Tuple[torch.Tensor, ...]) -> Optional[Tuple[torch.Tensor, ...]]:
        """Return the guided-backprop replacement gradient for ``module``.

        Only positive incoming gradients at positions whose stored forward
        output was positive are let through. Returns ``None`` (gradient left
        unchanged) when no forward output was recorded for the module.
        """
        forward_output = self.module_to_forward_output.get(module)
        if forward_output is None:
            return None
        guided = torch.clamp(grad_output[0], min=0.0) * (forward_output > 0).float()
        return (guided,)

    def _gelu_hook_function_vbp(self, module: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: Tuple[torch.Tensor, ...]) -> Optional[Tuple[torch.Tensor, ...]]:
        """Backward hook: approximate guided backpropagation through GELU."""
        return self._guided_grad_input(module, grad_output)

    def _relu_hook_function_vbp(self, module: nn.Module, grad_input: Tuple[torch.Tensor, ...], grad_output: Tuple[torch.Tensor, ...]) -> Optional[Tuple[torch.Tensor, ...]]:
        """Backward hook: guided backpropagation through ReLU."""
        return self._guided_grad_input(module, grad_output)

    def _gradcam_activation_hook(self, module: nn.Module, input_val: Any, output_val: torch.Tensor):
        """Forward hook: store the target layer's activations for Grad-CAM."""
        self.target_layer_hook_activations = output_val.detach()

    def _gradcam_gradient_hook(self, module: nn.Module, grad_input: Any, grad_output: Tuple[torch.Tensor, ...]):
        """Backward hook: store the gradient w.r.t. the target layer's output."""
        self.target_layer_hook_gradients = grad_output[0].detach()

    # --- Internal helpers ---
    def _find_backbone(self) -> Optional[nn.Module]:
        """Return the perception module's ``convnext`` or ``encoder`` attribute, else ``None``."""
        if hasattr(self.perception, 'convnext'):
            return self.perception.convnext
        if hasattr(self.perception, 'encoder'):
            return self.perception.encoder
        return None

    def _register_gradcam_hooks(self):
        """Install Grad-CAM hooks on the last entry of the backbone's ``features``.

        Any previously registered hooks are removed first. If the target
        layer cannot be hooked, a message is printed and no hooks remain.

        Raises:
            TypeError: If the perception module exposes neither ``convnext``
                nor ``encoder``.
        """
        self.cleanup_hooks()

        backbone = self._find_backbone()
        if backbone is None:
            raise TypeError("Perception module is of an unknown type for GradCAM hook registration.")

        try:
            target_layer = backbone.features[-1]
            handle_fwd = target_layer.register_forward_hook(self._gradcam_activation_hook)
            handle_bwd = target_layer.register_full_backward_hook(self._gradcam_gradient_hook)
            self._hook_handles.extend([handle_fwd, handle_bwd])
        except Exception as e:
            # Best effort: grad_cam() detects the missing hooks and returns None.
            print(f"Error registering GradCAM hooks: {e}. GradCAM might not work correctly.")
            self.cleanup_hooks()

    # --- Main visualization methods ---
    def activation_maximization(self, action_idx: int, lr: float = 0.1, steps: int = 200, num_views_for_am: int = NUM_VIEWS) -> np.ndarray:
        """Optimise a 100x100 RGB image that maximises ``decision``'s output ``action_idx``.

        A single random image is optimised with Adam; on every step it is
        clamped to [0, 1] and replicated ``num_views_for_am`` times to form
        the multi-view input. A small total-variation penalty is added to the
        loss.

        Args:
            action_idx: Index into the decision module's output to maximise.
            lr: Adam learning rate.
            steps: Number of optimisation steps.
            num_views_for_am: Number of copies of the image stacked as views.

        Returns:
            The optimised image as an ``(H, W, 3)`` ``uint8`` array.
        """
        self.perception.eval()
        self.decision.eval()
        print(f"Starting AM for action {action_idx}...")

        device = next(self.perception.parameters()).device
        optimized_image_tensor = torch.rand(1, 3, 100, 100, requires_grad=True, device=device)
        optimizer = optim.Adam([optimized_image_tensor], lr=lr, weight_decay=1e-4)

        # Report roughly ten times per run; at least every step so that
        # steps < 10 does not lead to a modulo-by-zero.
        log_every = max(1, steps // 10)

        for i in range(steps):
            optimizer.zero_grad()

            current_image_0_1 = torch.clamp(optimized_image_tensor, 0.0, 1.0)

            # (1, 3, H, W) -> (1, num_views, 3, H, W): the same image for every view.
            multi_view_input = current_image_0_1.unsqueeze(0).repeat(1, num_views_for_am, 1, 1, 1)

            features_per_view = self.perception(multi_view_input)
            flat_features = features_per_view.view(1, -1)
            action_scores = self.decision(flat_features)

            # Total-variation penalty along width and height.
            tv_w = 0.0001 * torch.sum(torch.abs(current_image_0_1[:, :, :, :-1] - current_image_0_1[:, :, :, 1:]))
            tv_h = 0.0001 * torch.sum(torch.abs(current_image_0_1[:, :, :-1, :] - current_image_0_1[:, :, 1:, :]))
            loss = -action_scores[0, action_idx] + (tv_w + tv_h)

            loss.backward()
            optimizer.step()
            if i % log_every == 0:
                print(f"AM step {i}, loss {loss.item()}")

        final_image_0_1 = torch.clamp(optimized_image_tensor.detach(), 0.0, 1.0)
        generated_np = final_image_0_1[0].permute(1, 2, 0).cpu().numpy()
        return (generated_np * 255).astype(np.uint8)

    def grad_cam(self, obs_tensor_0_1: torch.Tensor, action_idx: int, target_view_idx: int = 0) -> Optional[np.ndarray]:
        """Compute a Grad-CAM heatmap for one view of an observation.

        Args:
            obs_tensor_0_1: Observation batch of size 1; its last two
                dimensions are taken as (H, W). Assumed to be
                ``(1, N_VIEWS, C, H, W)`` scaled to [0, 1] -- TODO confirm
                against the perception module. The tensor is not modified.
            action_idx: Index into the decision module's output whose score
                is back-propagated.
            target_view_idx: Index into the first dimension of the captured
                activations/gradients (one entry per view for batch size 1).

        Returns:
            A float heatmap resized to ``(H, W)`` and normalised to [0, 1]
            (all zeros if nothing is positive), or ``None`` if the hooks could
            not be registered or captured nothing.
        """
        self.perception.eval()
        self.decision.eval()
        self._register_gradcam_hooks()

        if not self._hook_handles:
            return None

        # Work on a private copy so the caller's tensor keeps its autograd state.
        obs_input = obs_tensor_0_1.clone().detach().requires_grad_(True)

        features_per_view = self.perception(obs_input)
        flat_features = features_per_view.view(1, -1)
        action_scores = self.decision(flat_features)

        self.perception.zero_grad()
        self.decision.zero_grad()
        action_scores[0, action_idx].backward()

        if self.target_layer_hook_activations is None or self.target_layer_hook_gradients is None:
            print("GradCAM: Activations or gradients not captured. Hooks might not be set correctly.")
            return None

        activations_target_view = self.target_layer_hook_activations[target_view_idx]
        gradients_target_view = self.target_layer_hook_gradients[target_view_idx]

        # Channel weights: spatial mean of the gradient of each feature map.
        pooled_gradients = torch.mean(gradients_target_view, dim=[1, 2])

        # Weight every channel without touching the cached activations.
        weighted_activations = activations_target_view * pooled_gradients[:, None, None]

        heatmap = torch.mean(weighted_activations, dim=0).cpu().numpy()
        heatmap = np.maximum(heatmap, 0)
        peak = np.max(heatmap)
        if peak > 0:
            heatmap /= peak

        original_h, original_w = obs_tensor_0_1.shape[-2:]
        return cv2.resize(heatmap, (original_w, original_h))

    def visual_back_prop(self, obs_tensor_0_1: torch.Tensor, target_view_idx: int = 0) -> Optional[np.ndarray]:
        """Compute a guided-backpropagation style saliency map for one view.

        Guided-backprop hooks are placed on every ``nn.GELU`` / ``nn.ReLU``
        inside the backbone, the summed perception features of the chosen
        view are back-propagated, and the absolute input gradient is reduced
        with a channel-wise maximum. The hooks are left in place afterwards.

        Args:
            obs_tensor_0_1: Observation batch indexed as
                ``[0, target_view_idx]``; a detached copy is used internally.
            target_view_idx: View whose features and input gradient are used.

        Returns:
            An ``(H, W)`` ``uint8`` saliency map scaled to [0, 255], or
            ``None`` if the backbone is not recognised, contains no GELU/ReLU
            modules, or no input gradient was produced.
        """
        self.perception.eval()
        self.cleanup_hooks()

        backbone = self._find_backbone()
        if backbone is None:
            print("VBP: Perception module type not recognized.")
            return None

        for module in backbone.modules():
            if isinstance(module, nn.GELU):
                backward_hook = self._gelu_hook_function_vbp
            elif isinstance(module, nn.ReLU):
                backward_hook = self._relu_hook_function_vbp
            else:
                continue
            self._hook_handles.append(module.register_forward_hook(self._store_forward_output_hook))
            self._hook_handles.append(module.register_full_backward_hook(backward_hook))

        if not self._hook_handles:
            print("VBP: No suitable activation layers (GELU/ReLU) found to hook in ConvNeXt.")
            return None

        input_img_for_vbp = obs_tensor_0_1.clone().detach().requires_grad_(True)

        # The forward pass both fills module_to_forward_output (via the
        # forward hooks) and yields the features to back-propagate from.
        features_per_view = self.perception(input_img_for_vbp)

        self.perception.zero_grad()

        # Back-propagating ones is equivalent to differentiating the sum of
        # the selected view's features.
        target_features_for_bp = features_per_view[0, target_view_idx, :]
        target_features_for_bp.backward(gradient=torch.ones_like(target_features_for_bp), retain_graph=False)

        saliency_grad = input_img_for_vbp.grad
        if saliency_grad is None:
            print("VBP: Gradient for input image was not computed.")
            return None

        saliency_target_view_np = saliency_grad.detach().abs()[0, target_view_idx].cpu().numpy()

        # Collapse channels with max, then scale to [0, 255].
        saliency_map = np.max(saliency_target_view_np, axis=0)
        peak = np.max(saliency_map)
        if peak > 0:
            saliency_map /= peak
        return (saliency_map * 255).astype(np.uint8)


# ========================
# 可视化评估函数
# ========================
def visualize_truth_evaluation(model: PPO, env: gym.Env, n_samples: int = 3, run_name="default"):
    """Save activation-maximisation, Grad-CAM and VisualBackProp figures for ``model``.

    Files written to the current directory:
      * ``am_{run_name}_{action}.png`` -- one per action name,
      * ``saliency_{run_name}_sample_{k}.png`` -- Grad-CAM, ``k = 1..n_samples``,
      * ``vbp_{run_name}_sample_{k}.png`` -- VisualBackProp, ``k = 1..n_samples``.

    The environment is advanced between samples (with the model's greedy
    action for Grad-CAM, with a fixed action for VisualBackProp) and reset
    whenever an episode ends.

    Args:
        model: Agent whose ``policy.features_extractor`` must expose a
            ``perception`` attribute; otherwise an error is printed and the
            function returns without producing anything.
        env: Environment supplying observations, indexed as ``obs[view]`` to
            obtain an image for plotting.
        n_samples: Number of observations visualised per technique.
        run_name: Label used in figure titles and file names.
    """
    print(f"\nPerforming Truth visualization evaluation for {run_name}...")

    feature_extractor = model.policy.features_extractor
    if not hasattr(feature_extractor, 'perception'):
        print("Error: PPO model's feature_extractor does not have a 'perception' attribute.")
        return

    perception_module_from_agent = feature_extractor.perception

    # mlp_extractor.policy_net only produces the actor's latent features; the
    # per-action logits come from policy.action_net applied on top of it.
    # Chain both so that indexing the output by action really selects an action.
    policy_net = model.policy.mlp_extractor.policy_net
    action_net = getattr(model.policy, 'action_net', None)
    if action_net is not None:
        decision_module_from_agent = nn.Sequential(policy_net, action_net)
    else:
        decision_module_from_agent = policy_net

    visualizer = ModelVisualizer(perception_module_from_agent, decision_module_from_agent)

    action_names = ['Escape', 'Eat', 'Wander']

    def _to_unit_tensor(obs: np.ndarray) -> torch.Tensor:
        # Move the last axis (channels) in front of H, W, add a batch
        # dimension and scale from [0, 255] to [0, 1].
        return torch.tensor(obs, dtype=torch.float32, device=model.device).permute(0, 3, 1, 2).unsqueeze(0) / 255.0

    try:
        # 1. Activation Maximization
        print("Generating Activation Maximization visualizations...")
        for action_idx, action_name in enumerate(action_names):
            am_image_np = visualizer.activation_maximization(action_idx)
            plt.figure(figsize=(5, 5))
            plt.imshow(am_image_np)
            plt.title(f'AM ({run_name}): {action_name}')
            plt.axis('off')
            plt.savefig(f'am_{run_name}_{action_name.lower()}.png')
            plt.close()

        # 2. Saliency Maps (GradCAM)
        print("Generating Saliency Maps (GradCAM)...")
        obs_np, _ = env.reset()
        for i in range(n_samples):
            obs_tensor_0_1 = _to_unit_tensor(obs_np)

            # predict() returns a NumPy value; turn it into a plain int index.
            predicted_action, _ = model.predict(obs_np, deterministic=True)
            action_idx_sb3 = int(np.asarray(predicted_action).item())
            action_name = action_names[action_idx_sb3]

            target_view_for_gradcam = 0
            grad_cam_map = visualizer.grad_cam(obs_tensor_0_1, action_idx_sb3, target_view_idx=target_view_for_gradcam)

            input_image_to_show_np = obs_np[target_view_for_gradcam]

            plt.figure(figsize=(10, 5))
            plt.subplot(1, 2, 1)
            plt.imshow(input_image_to_show_np)
            plt.title(f'Input View ({run_name}, Sample {i+1})\nAction: {action_name}')
            plt.axis('off')

            if grad_cam_map is not None:
                plt.subplot(1, 2, 2)
                plt.imshow(input_image_to_show_np, alpha=0.7)
                plt.imshow(grad_cam_map, cmap='jet', alpha=0.3)
                plt.title('GradCAM')
                plt.axis('off')

            plt.tight_layout()
            plt.savefig(f'saliency_{run_name}_sample_{i+1}.png')
            plt.close()

            # Advance the environment with a one-hot vector for the chosen action.
            action_probs_for_env = np.zeros(len(action_names))
            action_probs_for_env[action_idx_sb3] = 1.0
            obs_np, _, terminated, truncated, _ = env.step(action_probs_for_env)
            if terminated or truncated:
                obs_np, _ = env.reset()

        # 3. VisualBackProp
        print("Generating VisualBackProp visualizations...")
        for i in range(n_samples):
            # The first sample reuses the observation left by the loop above;
            # later samples step the environment with a fixed action first.
            if i > 0:
                action_probs_dummy = np.array([0.0, 0.0, 1.0])
                obs_np, _, terminated, truncated, _ = env.step(action_probs_dummy)
                if terminated or truncated:
                    # A reset still yields a valid observation, so the sample
                    # is rendered rather than skipped.
                    obs_np, _ = env.reset()

            obs_tensor_vbp_0_1 = _to_unit_tensor(obs_np)

            target_view_for_vbp = 0
            vbp_map = visualizer.visual_back_prop(obs_tensor_vbp_0_1, target_view_idx=target_view_for_vbp)

            input_image_np_vbp = obs_np[target_view_for_vbp]

            plt.figure(figsize=(10, 5))
            plt.subplot(1, 2, 1)
            plt.imshow(input_image_np_vbp)
            plt.title(f'Input View ({run_name}, VBP Sample {i+1})')
            plt.axis('off')

            plt.subplot(1, 2, 2)
            if vbp_map is not None:
                plt.imshow(vbp_map, cmap='gray')
                plt.title('VisualBackProp')
            else:
                plt.text(0.5, 0.5, "VBP Failed", ha='center', va='center')
            plt.axis('off')

            plt.tight_layout()
            plt.savefig(f'vbp_{run_name}_sample_{i+1}.png')
            plt.close()
    finally:
        # The hooks sit on the agent's own network; always take them off
        # again, even if a visualisation step raised.
        visualizer.cleanup_hooks()

    print(f"Truth visualization for {run_name} complete.")

# 4. 主工作流程

In [None]:

# ========================
# 主工作流程
# ========================
def main():
    """Run the whole pipeline: load data, train, evaluate and visualise.

    Steps, in order: load the image folders, build the environment, pretrain
    the autoencoder (checkpoint written to ``autoencoder.pth``), train and
    save the fitness and truth PPO models, evaluate both for survival time,
    and generate the visualisation figures for each. The environment is
    closed on exit, including when a step raises.
    """
    base_directory = "TrainingData"  # Replace with the path to your image folders.
    ae_checkpoint_path = "autoencoder.pth"
    # Deliberately tiny for a smoke test; use e.g. 50000+ for real training.
    train_timesteps = 2000

    image_datasets = load_images_to_dict(base_directory)
    print(f"Loaded {len(image_datasets)} folders")
    for folder, images in image_datasets.items():
        print(f"{folder}: {len(images)} images, shape={images[0].shape if images else 'N/A'}")

    env = SurvivalGameEnv(image_datasets)
    try:
        # Only the saved checkpoint is used later, so the returned model is
        # not kept. Few epochs: this is a quick test configuration.
        print("Pretraining Autoencoder...")
        train_autoencoder(image_datasets, epochs=2, ae_save_path=ae_checkpoint_path)

        print("\nTraining Fitness model...")
        fitness_model = fitness_training(env, total_timesteps=train_timesteps)
        fitness_model.save("ppo_fitness_model")

        print("\nTraining Truth model...")
        truth_model = truth_training(env, ae_path=ae_checkpoint_path, total_timesteps=train_timesteps)
        truth_model.save("ppo_truth_model")

        # To reuse saved models instead of retraining:
        #   fitness_model = PPO.load("ppo_fitness_model", env=env)
        #   truth_model = PPO.load("ppo_truth_model", env=env)

        print("\nEvaluating models for survival...")
        # The fitness model is evaluated with live rendering; for a headless
        # run use evaluate_survival(fitness_model, env, n_episodes=5) instead.
        # NOTE(review): evaluate_survival_with_render calls env.close() and
        # the same env is used again below -- confirm the environment
        # tolerates being stepped after close().
        fitness_score = evaluate_survival_with_render(fitness_model, env, n_episodes=1)
        truth_score = evaluate_survival(truth_model, env, n_episodes=5)

        print(f"\nResults:")
        print(f"Fitness Model Average Survival: {fitness_score:.1f} steps")
        print(f"Truth Model Average Survival: {truth_score:.1f} steps")

        visualize_truth_evaluation(fitness_model, env, run_name="FitnessModel")
        visualize_truth_evaluation(truth_model, env, run_name="TruthModel")
    finally:
        env.close()

# Run the full pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()