## World Model
### Introduction:
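This section gives a brief introduction to World Models, based mainly on the papers *Recurrent World Models Facilitate Policy Evolution* and *World Models*. The world model described in these papers is inspired by the mental model humans build of the world around them. It proposes a new paradigm for how an agent interacts with its environment and has helped advance reinforcement learning.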

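### Structure:
The world model proposed in the paper consists of two main components, V and M, while C is the agent (controller) that interacts with the environment. The three parts are:<br>
1. V: the visual sensory component, usually a VAE (Variational Autoencoder). It compresses the high-dimensional observation into a low-dimensional vector $z_t\in\mathbb{R}^{N_z}$, called the latent space vector, where $N_z$ is a hyperparameter. As a generative model, the VAE feeds the observation (usually an image) into the encoder, which outputs the latent representation; the decoder then maps this latent back to a reconstructed image.
2. M: the memory component, an MDN-RNN (mixture density network combined with a recurrent neural network). It predicts how the latent state evolves over time: it takes the $z_t$ produced by the VAE, and the RNN is trained to output a probability distribution over the next latent vector, $p(z_{t+1})$, rather than a single deterministic $z_{t+1}$. A distribution is used because complex real environments are stochastic, so a probability distribution represents the environment state better than a point prediction.
3. C: the decision-making component (controller), a simple neural network that chooses actions $a_t$ to maximize the agent's cumulative reward over a rollout.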
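To make V's output concrete, here is a minimal NumPy sketch of how a VAE encoder's outputs become $z_t$: the reparameterization trick $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, plus the closed-form KL term that pulls the encoder distribution toward $\mathcal{N}(0, I)$. The means and log-variances are made-up placeholders rather than outputs of a trained encoder, and the helper names are hypothetical; the experiment code below does the same thing in PyTorch (`VAE.reparameterize` and the KL term in `train_vae`).

```python
import numpy as np

def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(0.5 * logvar)."""
    eps = rng.standard_normal(mu.shape)
    return mu + eps * np.exp(0.5 * logvar)

def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, one value per sample."""
    return 0.5 * np.sum(mu**2 + np.exp(logvar) - 1.0 - logvar, axis=-1)

rng = np.random.default_rng(0)
N_z = 16                         # latent size used in the experiment code below
mu = np.zeros((2, N_z))          # hypothetical encoder means for a batch of 2 images
logvar = np.zeros((2, N_z))      # log-variance 0 means sigma = 1
z = reparameterize(mu, logvar, rng)
print(z.shape)                            # (2, 16)
print(kl_to_standard_normal(mu, logvar))  # zero for each sample: N(0, I) equals the prior
```

Sampling through `mu + eps * sigma` (instead of sampling `z` directly) keeps `z` differentiable with respect to `mu` and `logvar`, which is what lets the encoder be trained by backpropagation.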
<div style="display: flex; justify-content: space-around;">
  <div style="text-align: center;">
    <img src="res/World Model.2.png" alt="VAE structure" width="100%">
    <p>1. VAE structure</p>
  </div>
  <div style="text-align: center;">
    <img src="res/World Model.3.png" alt="MDN-RNN structure" width="70%">
    <p>2. MDN-RNN structure</p>
  </div>
</div>
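The MDN head in the experiment code below produces one flat vector of length `NMIX * (2 * NZ + 1)` per time step. The NumPy sketch here assumes the same layout (mixture logits first, then all means, then all log-variances) and shows how to split it and draw $z_{t+1}$. Unlike `MDNRNN.get_next_z`, which always picks the highest-weight component, this sketch samples the component from the mixture weights; the parameter vector is random and purely illustrative.

```python
import numpy as np

def split_mdn_params(params, n_mix, n_z):
    """Split a flat MDN output into (logits, means, log-variances).

    Assumed layout: [n_mix logits | n_mix*n_z means | n_mix*n_z log-variances]
    """
    logits = params[:n_mix]
    mus = params[n_mix:n_mix + n_mix * n_z].reshape(n_mix, n_z)
    logvars = params[n_mix + n_mix * n_z:].reshape(n_mix, n_z)
    return logits, mus, logvars

def sample_next_z(params, n_mix, n_z, rng):
    """Pick a component k ~ softmax(logits), then sample from its diagonal Gaussian."""
    logits, mus, logvars = split_mdn_params(params, n_mix, n_z)
    weights = np.exp(logits - logits.max())  # numerically stable softmax
    weights /= weights.sum()
    k = rng.choice(n_mix, p=weights)
    return mus[k] + np.exp(0.5 * logvars[k]) * rng.standard_normal(n_z)

NMIX, NZ = 5, 16
rng = np.random.default_rng(0)
params = rng.standard_normal(NMIX * (2 * NZ + 1))  # stand-in for one MDN-RNN output
logits, mus, logvars = split_mdn_params(params, NMIX, NZ)
z_next = sample_next_z(params, NMIX, NZ, rng)
print(mus.shape, logvars.shape, z_next.shape)  # (5, 16) (5, 16) (16,)
```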

### Training Process:
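1. Sample the time steps of one rollout, $t_1, t_2, \dots, t_{\text{done}}$, and choose positive integers for the hyperparameters $N_a, N_z, N_h$.
2. At each time step $t$, the environment provides an image as the observation, and V compresses it into $z_t$.
3. M takes the latent vector $z_t$ from V, the action $a_t$ from C, and its own hidden state $h_t$, and outputs a prediction of the next latent vector, $p(z_{t+1} \mid a_t, z_t, h_t)$. In the paper, $p(z)$ is modeled as a mixture of Gaussians.
4. C is the decision-making agent and is trained separately from V and M. It outputs $a_t = W_c[z_t, h_t] + b_c$, where $a_t\in\mathbb{R}^{N_a}$, and the learnable parameters $W_c\in\mathbb{R}^{N_a\times(N_z+N_h)}$ and $b_c\in\mathbb{R}^{N_a}$ map the input $[z_t, h_t]$ to the action $a_t$. Because C is a simple model with far fewer parameters than V and M, it can be trained with derivative-free optimizers instead of backpropagation; the paper uses the Covariance Matrix Adaptation Evolution Strategy (CMA-ES).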
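Step 4 can be checked with a few lines of NumPy. This sketch uses the sizes from the experiment below ($N_z=16$, $N_h=64$, $N_a=3$), with random placeholder values for $W_c$, $b_c$, $z_t$ and $h_t$. It also counts the controller's parameters, $N_a(N_z+N_h)+N_a = 3 \times 80 + 3 = 243$, which is the length of the flat vector that an optimizer such as CMA-ES searches over.

```python
import numpy as np

N_z, N_h, N_a = 16, 64, 3  # sizes used in the experiment code below

def controller(z_t, h_t, W_c, b_c):
    """Linear controller a_t = W_c [z_t, h_t] + b_c (before clipping to the action space)."""
    return W_c @ np.concatenate([z_t, h_t]) + b_c

rng = np.random.default_rng(0)
W_c = rng.normal(0.0, 0.1, size=(N_a, N_z + N_h))  # placeholder weights
b_c = np.zeros(N_a)                                # placeholder bias
z_t = rng.standard_normal(N_z)                     # stand-in for V's output
h_t = np.zeros(N_h)                                # stand-in for M's initial hidden state

a_t = controller(z_t, h_t, W_c, b_c)
n_params = W_c.size + b_c.size
print(a_t.shape, n_params)  # (3,) 243

# The same parameters as one flat vector, the form an evolution strategy works with
flat = np.concatenate([W_c.ravel(), b_c])
W_back = flat[:N_a * (N_z + N_h)].reshape(N_a, N_z + N_h)
b_back = flat[N_a * (N_z + N_h):]
```

`LinearController` in the experiment code performs the same flat-vector unpacking before computing actions.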
<div style="display: flex; justify-content: space-around;">
  <div style="text-align: center;">
    <img src="res/World Model.1.png" alt="World model architecture" width="90%">
    <p>3. World model architecture</p>
  </div>
  <div style="text-align: center;">
    <img src="res/World Model.4.png" alt="Training procedure pseudocode" width="80%">
    <p>4. Training procedure pseudocode</p>
  </div>
</div>
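Training C follows an ask/evaluate/tell loop: sample candidate parameter vectors, score each one by its cumulative reward, and move the search distribution toward the better candidates. The sketch below shows only that loop structure, using a deliberately simplified evolution strategy (fixed step size; the new mean is the average of the top candidates). It is not CMA-ES, which additionally adapts a full covariance matrix and the step size, and the quadratic `fitness` function is a hypothetical stand-in for a rollout's reward.

```python
import numpy as np

def fitness(params):
    """Hypothetical stand-in for a rollout's cumulative reward; maximal when every entry is 3."""
    return -np.sum((params - 3.0) ** 2)

def simple_es(dim, generations=200, pop_size=16, n_elite=4, sigma=0.5, seed=0):
    """Simplified (mu/mu, lambda) evolution strategy with a fixed step size (not CMA-ES)."""
    rng = np.random.default_rng(seed)
    mean = np.zeros(dim)
    for _ in range(generations):
        # "ask": sample a population of candidates around the current mean
        population = mean + sigma * rng.standard_normal((pop_size, dim))
        # evaluate every candidate (a real setup would run environment rollouts here)
        rewards = np.array([fitness(p) for p in population])
        # "tell": move the mean to the average of the highest-reward candidates
        elite = population[np.argsort(rewards)[-n_elite:]]
        mean = elite.mean(axis=0)
    return mean

best = simple_es(dim=5)
print(np.round(best, 2))
```

In the experiment code, `es.ask()` and `es.tell()` from the `cma` package play the roles of the sampling and update steps, and `evaluate_controller` takes the place of `fitness`.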

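### Notes:
Below are some of my own notes and thoughts on details of the paper:
1. The paper models $p(z)$ as a mixture of Gaussians, which reflects the stochastic yet continuous nature of environment sequences. A mixture can also represent random noise and several distinct plausible next states, which makes the model more flexible.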
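The mixture density in this note is $p(z) = \sum_{k=1}^{K} \pi_k\, \mathcal{N}(z;\, \mu_k, \operatorname{diag}(\sigma_k^2))$, and its negative log is the MDN-RNN's training loss. Evaluating $\log p(z)$ directly can underflow when every component density is tiny, so the sketch below works in log space with log-sum-exp and checks the result against the direct formula. The mixture size and all inputs are arbitrary illustrative values.

```python
import numpy as np

def gmm_log_prob(x, logits, mus, logvars):
    """log p(x) under a mixture of diagonal Gaussians, computed with log-sum-exp.

    x: (n_z,), logits: (K,) unnormalized mixture weights, mus / logvars: (K, n_z)
    """
    n_z = x.shape[0]
    log_pi = logits - np.logaddexp.reduce(logits)  # log-softmax: log mixture weights
    # log N(x; mu_k, diag(sigma_k^2)) for every component k
    log_comp = -0.5 * (n_z * np.log(2 * np.pi) + logvars.sum(axis=1)
                       + ((x - mus) ** 2 / np.exp(logvars)).sum(axis=1))
    return np.logaddexp.reduce(log_pi + log_comp)

rng = np.random.default_rng(0)
K, n_z = 3, 2
logits = rng.standard_normal(K)
mus = rng.standard_normal((K, n_z))
logvars = 0.1 * rng.standard_normal((K, n_z))
x = rng.standard_normal(n_z)

# Direct (non-log-space) evaluation of the same density, for comparison
w = np.exp(logits) / np.exp(logits).sum()
var = np.exp(logvars)
dens = np.prod(np.exp(-(x - mus) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var), axis=1)
print(np.isclose(gmm_log_prob(x, logits, mus, logvars), np.log(np.sum(w * dens))))  # True
```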

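### Experiment:
The code below reproduces the paper's Car Racing experiment. The paper uses the CarRacing-v0 environment and collects 10,000 rollouts as its dataset; the code below uses CarRacing-v3, collects only 100 rollouts, and shrinks several of the paper's parameters to speed up training, so its scores should not be expected to match the paper's.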
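Each CarRacing frame is a 96×96×3 RGB image, and the experiment code shrinks it to 64×64 with `torch.nn.functional.interpolate`, whose default mode is `'nearest'`. Below is a NumPy sketch of the same preprocessing, assuming (as far as I know) PyTorch's nearest rule of taking source index $\lfloor i \cdot \text{in}/\text{out} \rfloor$; the input frame is random noise rather than a real observation, and `preprocess` is a hypothetical helper name.

```python
import numpy as np

def preprocess(obs, size=64):
    """HWC uint8 frame -> CHW float32 in [0, 1], nearest-neighbour resized to size x size."""
    h, w, _ = obs.shape
    rows = (np.arange(size) * h / size).astype(int)  # floor(i * in / out)
    cols = (np.arange(size) * w / size).astype(int)
    resized = obs[rows][:, cols]                     # (size, size, 3)
    return resized.transpose(2, 0, 1).astype(np.float32) / 255.0

frame = np.random.default_rng(0).integers(0, 256, size=(96, 96, 3), dtype=np.uint8)  # fake frame
x = preprocess(frame)
print(x.shape, x.dtype)
```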


In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import cma
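# -------------------------- 1. Hyperparameters (scaled down from the paper) --------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# VAE
NZ = 16  # latent vector size N_z
IMAGE_SIZE = 64  # frames are resized to IMAGE_SIZE x IMAGE_SIZE
VAE_LR = 1e-4
# MDN-RNN
NH = 64  # LSTM hidden size N_h
NMIX = 5  # number of Gaussian mixture components
RNN_LR = 1e-4
# Training
RANDOM_ROLLOUTS = 100  # number of random-policy rollouts to collect
ROLLOUT_STEPS = 5  # maximum number of steps per rollout
CMA_SIGMA = 0.5  # initial CMA-ES step size (sigma0)
CMA_POP_SIZE = 16  # CMA-ES population size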
VAE_EPOCH = 30
MDN_RNN_EPOCH = 10
CMA_GENERATIONS = 30

# -------------------------- 2. VAE (image compression) --------------------------
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder: RGB image -> latent z
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),  # (3,64,64)→(32,32,32)
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),  # →(64,16,16)
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),  # →(128,8,8)
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, 2, 1),  # →(256,4,4)
            nn.ReLU(),
            nn.Flatten(),  # →256*4*4=4096
            nn.Linear(4096, NZ * 2)  # outputs mean and log-variance (NZ values each)
        )
        # Decoder: latent z -> RGB image
        self.decoder = nn.Sequential(
            nn.Linear(NZ, 4096),
            nn.ReLU(),
            nn.Unflatten(1, (256, 4, 4)),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # →(128,8,8)
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # →(64,16,16)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # →(32,32,32)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),  # →(3,64,64)
            nn.Sigmoid()  # pixel values in [0, 1]
        )

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + eps * sigma, with sigma = exp(0.5 * logvar)
        eps = torch.randn_like(mu)
        return mu + eps * torch.exp(0.5 * logvar)

    def forward(self, x):
        x = x.to(DEVICE)
        mu_logvar = self.encoder(x)
        mu, logvar = torch.chunk(mu_logvar, 2, dim=1)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        return x_recon, z, mu, logvar

    def encode(self, x):
        # Encode only (inference, no gradients)
        with torch.no_grad():
            x = x.to(DEVICE)
            mu_logvar = self.encoder(x)
            mu, _ = torch.chunk(mu_logvar, 2, dim=1)
            return mu  # simplification: return the mean instead of sampling z from N(mu, sigma^2)


# -------------------------- 3. MDN-RNN (sequence prediction) --------------------------
class MDNRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(
            input_size=NZ + 3,  # input: z (NZ dims) + action (3-dim continuous)
            hidden_size=NH,
            batch_first=True
        )
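        # MDN head: parameters of the Gaussian mixture over the next z
        # Layout: [NMIX mixture logits | NMIX*NZ means | NMIX*NZ log-variances] = NMIX*(2*NZ+1) values
        self.mdn_head = nn.Linear(NH, NMIX * (2 * NZ + 1))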

    def forward(self, z_seq, a_seq, h0=None):
        # z_seq: (batch, seq_len, NZ), a_seq: (batch, seq_len, 3)
        z_seq, a_seq = z_seq.to(DEVICE), a_seq.to(DEVICE)
        input_seq = torch.cat([z_seq, a_seq], dim=-1)  # (batch, seq_len, NZ+3)

        if h0 is None:
            rnn_out, hn = self.rnn(input_seq)  # hn = (h_n, c_n), each of shape (1, batch, NH)
        else:
            rnn_out, hn = self.rnn(input_seq, h0)

        # Mixture parameters for every time step
        mdn_params = self.mdn_head(rnn_out)  # (batch, seq_len, NMIX*(2NZ+1))
        return mdn_params, hn

    def get_next_z(self, z_t, a_t, h_t):
        # One-step prediction of the next z inside M (not called by the training code below)
        with torch.no_grad():
            z_t = z_t.unsqueeze(0).unsqueeze(0)  # (1,1,NZ)
            a_t = a_t.unsqueeze(0).unsqueeze(0)  # (1,1,3)
            mdn_params, h_next = self.forward(z_t, a_t, h_t)

            # Simplification: take the component with the largest weight instead of sampling one
            params = mdn_params.squeeze(0).squeeze(0)  # (NMIX*(2NZ+1))
            weights = torch.softmax(params[:NMIX], dim=0)
            max_idx = torch.argmax(weights).item()
            mu = params[NMIX + max_idx * NZ: NMIX + (max_idx + 1) * NZ]
            logvar = params[NMIX + NMIX * NZ + max_idx * NZ: NMIX + NMIX * NZ + (max_idx + 1) * NZ]

            # Sample the next z from that component's diagonal Gaussian
            eps = torch.randn_like(mu)
            z_next = mu + eps * torch.exp(0.5 * logvar)
            return z_next, h_next


# -------------------------- 4. Data collection (random policy) --------------------------
def collect_random_rollouts():
    env = gym.make("CarRacing-v3", render_mode="rgb_array")
    data = []
    print(f"Collecting {RANDOM_ROLLOUTS} random rollouts...")

    for _ in range(RANDOM_ROLLOUTS):
        obs, _ = env.reset()
        rollout = []
        for _ in range(ROLLOUT_STEPS):
            # Random action: [steering (-1..1), gas (0..1), brake (0..1)]
            action = np.random.uniform(low=[-1, 0, 0], high=[1, 1, 1])
            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # Preprocess the current frame: HWC uint8 -> CHW float in [0, 1], resized to IMAGE_SIZE
            obs_resized = torch.from_numpy(obs).permute(2, 0, 1).float() / 255.0
            obs_resized = torch.nn.functional.interpolate(
                obs_resized.unsqueeze(0), size=(IMAGE_SIZE, IMAGE_SIZE)
            ).squeeze(0)

            rollout.append((obs_resized, action))
            obs = next_obs
            if done:
                break
        data.append(rollout)
    env.close()
    print("Data collection finished!")
    return data


# -------------------------- 5. Model training (VAE + MDN-RNN) --------------------------
def train_vae(vae, data):
    # Build the VAE training set from all collected frames
    all_imgs = []
    for rollout in data:
        all_imgs.extend([step[0] for step in rollout])
    dataset = torch.utils.data.TensorDataset(torch.stack(all_imgs))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

    criterion = nn.MSELoss()  # reconstruction loss
    optimizer = optim.Adam(vae.parameters(), lr=VAE_LR)
    vae.train()

    print("Training VAE...")
    for epoch in range(VAE_EPOCH):
        total_loss = 0.0
        for batch in dataloader:
            imgs = batch[0].to(DEVICE)
            recon_imgs, _, mu, logvar = vae(imgs)

            # Loss: reconstruction + KL divergence (simplified weighting, not the paper's exact setup)
            recon_loss = criterion(recon_imgs, imgs)
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / imgs.size(0)
            loss = recon_loss + 0.001 * kl_loss  # small KL weight

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"VAE Epoch {epoch + 1}/{VAE_EPOCH} | Loss: {avg_loss:.4f}")
    torch.save(vae.state_dict(), "vae_car.pth")
    print("VAE training finished and weights saved!")


def train_mdn_rnn(mdn_rnn, vae, data):
    sequences = []
    vae.eval()
    for rollout in data:
        z_list = []  # latent vector z_t for every step of the rollout
        a_list = []  # action a_t for every step of the rollout
        for obs_img, action in rollout:
            z = vae.encode(obs_img.unsqueeze(0)).squeeze(0)
            z_list.append(z)
            a_list.append(torch.tensor(action, dtype=torch.float32))

        # A rollout needs at least 2 steps to form (input, target) pairs
        if len(z_list) > 1:
            # Inputs are (z_t, a_t) for t = 0..T-2; targets are the next latents z_{t+1}
            z_seq_tensor = torch.stack(z_list[:-1])  # (seq_len, NZ)
            a_seq_tensor = torch.stack(a_list[:-1])  # (seq_len, 3)
            target_z_tensor = torch.stack(z_list[1:])  # (seq_len, NZ)
            sequences.append((z_seq_tensor, a_seq_tensor, target_z_tensor))

    # Batched training. The default collate_fn stacks sequences, so every rollout in a batch
    # must have the same length (holds as long as no episode ends before ROLLOUT_STEPS steps).
    dataloader = torch.utils.data.DataLoader(sequences, batch_size=32, shuffle=True)
    optimizer = optim.Adam(mdn_rnn.parameters(), lr=RNN_LR)
    mdn_rnn.train()

    print("Training MDN-RNN...")
    for epoch in range(MDN_RNN_EPOCH):
        total_loss = 0.0
        for batch in dataloader:
            z_seq, a_seq, target_z = batch
            z_seq, a_seq, target_z = z_seq.to(DEVICE), a_seq.to(DEVICE), target_z.to(DEVICE)

            mdn_params, _ = mdn_rnn(z_seq, a_seq)
            batch_size, seq_len = z_seq.shape[0], z_seq.shape[1]

            # MDN loss: negative log-likelihood of target_z under the predicted Gaussian mixture,
            # vectorized over batch and time (same value as looping over every (b, t) pair)
            # mdn_params layout: [NMIX logits | NMIX*NZ means | NMIX*NZ log-variances]
            log_weights = torch.log_softmax(mdn_params[..., :NMIX], dim=-1)  # (batch, seq_len, NMIX)
            mus = mdn_params[..., NMIX: NMIX + NMIX * NZ].reshape(batch_size, seq_len, NMIX, NZ)
            logvars = mdn_params[..., NMIX + NMIX * NZ:].reshape(batch_size, seq_len, NMIX, NZ)

            # Log-density of the target under each diagonal-Gaussian component
            target = target_z.unsqueeze(2)  # (batch, seq_len, 1, NZ)
            log_probs = -0.5 * (NZ * np.log(2 * np.pi) + logvars.sum(dim=-1)
                                + ((target - mus) ** 2 / logvars.exp()).sum(dim=-1))  # (batch, seq_len, NMIX)

            # log p(z) = logsumexp_k(log pi_k + log N_k); average the NLL over batch and time
            loss = -torch.logsumexp(log_weights + log_probs, dim=-1).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"MDN-RNN Epoch {epoch + 1}/{MDN_RNN_EPOCH} | Loss: {avg_loss:.4f}")
    torch.save(mdn_rnn.state_dict(), "mdn_rnn_car.pth")
    print("MDN-RNN training finished and weights saved!")


# -------------------------- 6. Controller training (CMA-ES) --------------------------
class LinearController:
    def __init__(self, params=None):
        # Controller parameters: W_c (3 x (NZ+NH)) and b_c (3 x 1); with NZ=16, NH=64: 3*(16+64)+3 = 243 values
        self.param_size = 3 * (NZ + NH) + 3
        if params is None:
            self.params = np.random.normal(0, 0.1, self.param_size)
        else:
            self.params = params.copy()

        # Unpack the flat parameter vector into weight matrix and bias
        W_size = 3 * (NZ + NH)
        self.W = self.params[:W_size].reshape(3, NZ + NH)
        self.b = self.params[W_size:].reshape(3, 1)

    def get_action(self, z_t, h_t):
        # z_t: (NZ,); h_t: LSTM state (h, c), each (1, 1, NH) -> input x of shape (NZ+NH, 1)
        # .detach() drops any autograd graph before converting to NumPy
        z_t = z_t.detach().cpu().numpy().reshape(NZ, 1)
        h_t = h_t[0].detach().cpu().numpy().reshape(NH, 1)  # use the hidden state h, not the cell state c
        x = np.concatenate([z_t, h_t], axis=0)

        # Linear policy a = Wx + b, then clip each component to the action space
        action = self.W @ x + self.b
        action = action.squeeze(1)
        # steering: [-1, 1], gas: [0, 1], brake: [0, 1]
        action[0] = np.clip(action[0], -1.0, 1.0)
        action[1] = np.clip(action[1], 0.0, 1.0)
        action[2] = np.clip(action[2], 0.0, 1.0)

        return action


def evaluate_controller(controller, vae, mdn_rnn, num_trials=10):
    # Evaluate the controller in the real environment
    env = gym.make("CarRacing-v3", render_mode="rgb_array")
    total_rewards = []
    vae.eval()
    mdn_rnn.eval()

    for _ in range(num_trials):
        obs, _ = env.reset()
        h_t = (torch.zeros(1, 1, NH).to(DEVICE), torch.zeros(1, 1, NH).to(DEVICE))  # initial LSTM state (h_0, c_0)
        cumulative_reward = 0.0
        done = False

        while not done:
            # Preprocess + encode the current frame
            obs_img = torch.from_numpy(obs).permute(2, 0, 1).float() / 255.0
            obs_img = torch.nn.functional.interpolate(
                obs_img.unsqueeze(0), size=(IMAGE_SIZE, IMAGE_SIZE)
            ).squeeze(0)
            z_t = vae.encode(obs_img.unsqueeze(0)).squeeze(0)

            # Controller picks an action from [z_t, h_t]
            action = controller.get_action(z_t, h_t)

            # Step the environment
            obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            cumulative_reward += reward

            # Advance M's hidden state with (z_t, a_t); no gradients are needed here
            with torch.no_grad():
                _, h_t = mdn_rnn(
                    z_t.unsqueeze(0).unsqueeze(0),
                    torch.tensor(action, dtype=torch.float32).unsqueeze(0).unsqueeze(0),
                    h_t
                )

        total_rewards.append(cumulative_reward)
    env.close()
    return np.mean(total_rewards), np.std(total_rewards)


def train_controller(vae, mdn_rnn):
    # Initialize the CMA-ES optimizer
    es = cma.CMAEvolutionStrategy(
        x0=np.random.normal(0, 0.1, LinearController().param_size),
        sigma0=CMA_SIGMA,
        inopts={"popsize": CMA_POP_SIZE, "verbose": -1}
    )

    print("Training controller (CMA-ES)...")
    best_reward = -np.inf  # CarRacing returns can be negative, so do not start from 0
    for gen in range(CMA_GENERATIONS):
        # Sample a population of candidate parameter vectors
        params_list = es.ask()
        rewards = []

        # Evaluate each candidate in the real environment
        for params in params_list:
            controller = LinearController(params)
            avg_reward, _ = evaluate_controller(controller, vae, mdn_rnn, num_trials=2)
            rewards.append(avg_reward)

        # Update the search distribution; CMA-ES minimizes, so pass negative rewards
        es.tell(params_list, [-r for r in rewards])
        es.disp()

        # Keep track of the best candidate so far
        current_best_idx = np.argmax(rewards)
        current_best_reward = rewards[current_best_idx]
        if current_best_reward > best_reward:
            best_reward = current_best_reward
            best_params = params_list[current_best_idx]
            np.save("best_controller_car.npy", best_params)
            print(f"Generation {gen + 1} | Best Reward: {best_reward:.2f}")

        # Stop early once the paper's target score (900+) is reached
        if best_reward >= 900:
            print(f"Target score reached! Best Reward: {best_reward:.2f}")
            break

    # Load the best controller and evaluate it
    best_params = np.load("best_controller_car.npy")
    best_controller = LinearController(best_params)
    final_avg, final_std = evaluate_controller(best_controller, vae, mdn_rnn, num_trials=100)
    print(f"Final performance (100 episodes): {final_avg:.2f} ± {final_std:.2f}")
    return best_controller


# -------------------------- 7. Main pipeline --------------------------
if __name__ == "__main__":
    # 1. Collect data
    data = collect_random_rollouts()

    # 2. Build the models
    vae = VAE().to(DEVICE)
    mdn_rnn = MDNRNN().to(DEVICE)

    # 3. Train the VAE and the MDN-RNN
    train_vae(vae, data)
    train_mdn_rnn(mdn_rnn, vae, data)

    # 4. Train the controller
    best_controller = train_controller(vae, mdn_rnn)

    # 5. Visual test (optional; loops until interrupted)
    env = gym.make("CarRacing-v3", render_mode="human")
    obs, _ = env.reset()
    h_t = (torch.zeros(1, 1, NH).to(DEVICE), torch.zeros(1, 1, NH).to(DEVICE))
    vae.eval()
    mdn_rnn.eval()

    while True:
        # Preprocess + encode the current frame
        obs_img = torch.from_numpy(obs).permute(2, 0, 1).float() / 255.0
        obs_img = torch.nn.functional.interpolate(
            obs_img.unsqueeze(0), size=(IMAGE_SIZE, IMAGE_SIZE)
        ).squeeze(0)
        z_t = vae.encode(obs_img.unsqueeze(0)).squeeze(0)

        # Choose an action
        action = best_controller.get_action(z_t, h_t)

        # Step the environment (rendered on screen)
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        if done:
            # New episode: reset the environment and M's hidden state
            obs, _ = env.reset()
            h_t = (torch.zeros(1, 1, NH).to(DEVICE), torch.zeros(1, 1, NH).to(DEVICE))
        else:
            # Advance M's hidden state with (z_t, a_t); no gradients are needed here
            with torch.no_grad():
                _, h_t = mdn_rnn(
                    z_t.unsqueeze(0).unsqueeze(0),
                    torch.tensor(action, dtype=torch.float32).unsqueeze(0).unsqueeze(0),
                    h_t
                )