In [None]:
import gymnasium as gym

# Print the id of every registered environment whose id contains "ALE" (the Atari suite).
for env_id in gym.registry:
    if "ALE" not in env_id:
        continue
    print(env_id)


In [None]:
# Print the version string of the imported ale_py package.
import ale_py
print("ALE version:", ale_py.__version__)


# Show the game as it is being played

In [None]:
import gymnasium as gym
import matplotlib.pyplot as plt
import time

# Play 100 steps of Space Invaders with sampled actions and draw every
# observation returned by the environment.
env = gym.make("ALE/SpaceInvaders-v5", render_mode="rgb_array")
try:
    obs, info = env.reset()

    for step in range(100):
        # Sample an action from the environment's action space.
        action = env.action_space.sample()
        obs, reward, done, truncated, info = env.step(action)

        # Draw the frame, keep it on screen briefly, then clear the figure
        # so the next frame does not pile up on top of it.
        plt.imshow(obs)
        plt.axis('off')
        plt.pause(0.02)
        plt.clf()

        # Start a new episode once the current one has ended.
        if done or truncated:
            obs, info = env.reset()
finally:
    # Release the environment and the figure even when the loop raises or
    # the kernel is interrupted part-way through.
    env.close()
    plt.close()


In [None]:
# Show the action space's `n` and the observation space, one per line.
print(env.action_space.n, env.observation_space, sep="\n")

In [None]:
import pickle

def safe_load(filepath, max_items=10):
    """Load a pickle file and return at most its first ``max_items`` entries.

    The whole file is unpickled before it is truncated, so ``max_items``
    limits only what is returned, not how much is read into memory.

    SECURITY: ``pickle.load`` can execute arbitrary code contained in the
    file. Only call this on files from a trusted source.

    Args:
        filepath: Path of the pickle file to open.
        max_items: Upper slice bound applied to the loaded data.

    Returns:
        ``data[:max_items]`` when the unpickled object is a list or tuple
        (the result keeps that type). For any other iterable, a list built
        from it and sliced the same way. An empty list when the object is
        not iterable, or when opening/unpickling fails (the error is
        printed instead of raised).
    """
    try:
        with open(filepath, "rb") as f:
            data = pickle.load(f)

        # Lists and tuples can be sliced directly.
        if isinstance(data, (list, tuple)):
            return data[:max_items]

        # Anything else (e.g. a deque) is converted to a list first.
        try:
            return list(data)[:max_items]
        except TypeError:
            # The object is not iterable, so there are no entries to return.
            return []
    except Exception as e:
        # Best-effort loader: report the problem and fall back to no data.
        print("加载失败:", e)
        return []
    

def is_equal(a, b):
    """Return whether ``a`` and ``b`` hold equal values.

    Handles torch tensors, numpy arrays, (nested) lists/tuples that may
    contain them, and anything else that supports ``==``.

    Args:
        a: First value.
        b: Second value.

    Returns:
        The result of ``torch.equal`` when both values are tensors, of
        ``numpy.array_equal`` when at least one value is an ndarray, of an
        element-by-element comparison when both are lists or both are
        tuples, and of ``a == b`` otherwise.
    """
    import sys

    # A value can only be a Tensor/ndarray if its library has already been
    # loaded, so look the modules up instead of importing them. This avoids
    # importing optional dependencies just to run an isinstance check.
    torch = sys.modules.get("torch")
    if torch is not None and isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return torch.equal(a, b)

    np = sys.modules.get("numpy")
    if np is not None and (isinstance(a, np.ndarray) or isinstance(b, np.ndarray)):
        # array_equal also compares shapes and yields a single boolean,
        # whereas ``==`` with an ndarray operand compares element-wise.
        return np.array_equal(a, b)

    # ``==`` on sequences raises ValueError as soon as two elements are
    # multi-element arrays, so compare the elements one by one instead.
    # The exact type check keeps ``(1, 2) == [1, 2]`` unequal, as with ``==``.
    if isinstance(a, (list, tuple)) and type(a) is type(b):
        if len(a) != len(b):
            return False
        # ``x is y`` mirrors the identity shortcut of built-in sequence equality.
        return all(x is y or is_equal(x, y) for x, y in zip(a, b))

    return a == b


def compare_transitions(data):
    """Print how every entry of ``data`` compares with the first entry.

    When both the first entry and the entry being checked are tuples or
    lists, they are compared field by field (paired with ``zip``) and one
    result is printed per field. Otherwise the two entries are compared as
    a whole. The first entry is also compared with itself.

    Args:
        data: Sequence of entries to compare; nothing is compared when it
            is empty.

    Returns:
        None. All results are written with ``print``.
    """
    if len(data) == 0:
        print("没有数据可比较")
        return

    reference = data[0]
    reference_is_sequence = isinstance(reference, (tuple, list))
    print("前几条数据总长度:", len(data))
    print("开始逐条比较...\n")

    for position, candidate in enumerate(data, start=1):
        print(f"对比 第 1 条 和 第 {position} 条：")

        if not (reference_is_sequence and isinstance(candidate, (tuple, list))):
            # At least one side is not a tuple/list: compare the objects as a whole.
            print("  整体是否相同:", is_equal(reference, candidate))
        else:
            # Both sides are tuples/lists: report each paired field separately.
            for field_no, pair in enumerate(zip(reference, candidate)):
                print(f"  字段 {field_no}: {is_equal(*pair)}")

        print()


if __name__ == "__main__":
    # Load up to 5 entries from the saved replay buffer and compare each of
    # them against the first one.
    path = r"E:\rl-learn\replay_buffer.pkl"
    small_data = safe_load(path, max_items=5)
    compare_transitions(small_data)
