In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical

# ==========================================
# Variant A: Reward-Space Combination
# ==========================================
class ScalarRewardWrapper(gym.Wrapper):
    """Preference-conditioned wrapper that emits a scalarized two-part reward.

    A 2-element preference vector ``w = (w1, w2)`` is appended to every
    observation, and the wrapped environment's own reward is replaced by
    ``w1 * r1 + w2 * r2``.  ``r1`` grows as the first observation component
    (the cart position ``x``) moves towards -2.4, ``r2`` as it moves towards
    +2.4.  The preference is chosen in :meth:`reset` and stays fixed until
    the next reset.
    """

    # Number of reward components (and therefore preference weights).
    NUM_OBJECTIVES = 2

    def __init__(self, env):
        super().__init__(env)
        original_shape = env.observation_space.shape[0]
        print(original_shape)
        # The observation is the wrapped observation followed by w (w1, w2).
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(original_shape + self.NUM_OBJECTIVES,),
            dtype=np.float32
        )
        print(self.observation_space)
        # Placeholder preference until the first reset() picks one.
        self.current_w = np.array([0.5, 0.5], dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Start a new episode and choose the preference vector for it.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``.
            options: Optional dict.  When it contains the key ``'w'`` that
                value is used as the preference vector; otherwise a random
                vector normalized to sum to one is drawn.

        Returns:
            ``(observation, info)`` where the observation is the wrapped
            observation with the preference vector appended.

        Raises:
            ValueError: If ``options['w']`` does not have exactly
                ``NUM_OBJECTIVES`` entries.
        """
        if options and 'w' in options:
            w = np.array(options['w'], dtype=np.float32)
            if w.shape != (self.NUM_OBJECTIVES,):
                raise ValueError(
                    f"options['w'] must have shape ({self.NUM_OBJECTIVES},), "
                    f"got {w.shape}"
                )
            self.current_w = w
        else:
            w = np.random.rand(self.NUM_OBJECTIVES)
            # np.random.rand yields float64; cast so the observation keeps
            # the float32 dtype declared in observation_space.
            self.current_w = (w / w.sum()).astype(np.float32)

        obs, info = self.env.reset(seed=seed)
        return np.concatenate([obs, self.current_w]), info

    def step(self, action):
        """Advance one step and return the preference-weighted reward.

        Returns:
            ``(observation, scalar_reward, terminated, truncated, info)``;
            the wrapped environment's own reward is discarded.
        """
        obs, _, terminated, truncated, info = self.env.step(action)
        x = obs[0]

        # x range [-2.4, 2.4] -> both terms are normalized to [0, 1].
        norm_left = (2.4 - x) / 4.8
        norm_right = (x + 2.4) / 4.8

        r1 = norm_left ** 2
        r2 = norm_right ** 2

        # No reward on the step that terminates the episode.
        if terminated:
            r1 = 0.0
            r2 = 0.0

        # Scalarize the reward vector: Reward = w1 * r1 + w2 * r2
        scalar_reward = self.current_w[0] * r1 + self.current_w[1] * r2

        new_obs = np.concatenate([obs, self.current_w])

        # The caller only ever sees the scalar reward.
        return new_obs, scalar_reward, terminated, truncated, info

# ==========================================
# 2. PPO Agent
# ==========================================
# Orthogonal weight initialization helper
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialize ``layer`` in place and hand the same object back.

    The weight matrix gets an orthogonal initialization scaled by ``std``
    and every bias entry is set to ``bias_const``.  Returning the layer
    lets the call be nested directly inside ``nn.Sequential(...)``.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer

class Agent(nn.Module):
    """PPO actor-critic with a scalar critic and a categorical policy.

    Both networks are two-hidden-layer tanh MLPs.  Their input width is read
    from ``env.observation_space`` (previously hard-coded to 6) and the
    policy's output width from ``env.action_space.n``.
    """

    def __init__(self, env):
        super().__init__()
        obs_dim = env.observation_space.shape[0]
        n_actions = env.action_space.n

        # Critic: estimates V(s)
        self.critic = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )
        # Actor: produces the logits of the action distribution
        self.actor = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, n_actions), std=0.01),
        )

    def get_value(self, x):
        """Return the critic output for observation(s) ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Evaluate the policy and the critic on ``x``.

        Args:
            x: Observation tensor fed to both networks.
            action: Action(s) to score.  When ``None`` an action is sampled
                from the current policy.

        Returns:
            ``(action, log_prob(action), policy entropy, critic value)``.
        """
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)

# ==========================================
# 3. main loop
# ==========================================
if __name__ == "__main__":
    # Variant A driver: train PPO on the scalarized reward, then roll the
    # trained policy out once for each of a few fixed preference vectors.

    # hyperparameters
    learning_rate = 3e-4
    num_steps = 128          # environment steps collected per update
    total_timesteps = 80000
    gamma = 0.99
    gae_lambda = 0.95
    update_epochs = 4        # full-batch gradient steps per rollout
    clip_coef = 0.2
    ent_coef = 0.001
    vf_coef = 0.5

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize env and agent
    env = ScalarRewardWrapper(gym.make("CartPole-v1"))
    agent = Agent(env).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=learning_rate)

    # Rollout buffers.  The observation width is read from the wrapper
    # rather than hard-coded, so it always matches what the env returns.
    obs_dim = env.observation_space.shape[0]
    obs = torch.zeros((num_steps, obs_dim)).to(device)
    actions = torch.zeros((num_steps,)).to(device)
    logprobs = torch.zeros((num_steps,)).to(device)
    rewards = torch.zeros((num_steps,)).to(device)
    dones = torch.zeros((num_steps,)).to(device)
    values = torch.zeros((num_steps,)).to(device)

    global_step = 0
    next_obs, _ = env.reset()
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.zeros(1).to(device)

    num_updates = total_timesteps // num_steps

    print("Starting Training (Variant A: Reward Scalarization)...")

    for update in range(1, num_updates + 1):

        # Rollout
        for step in range(num_steps):
            global_step += 1
            # save current observation; dones[step] flags that this
            # observation is the first one after an episode end
            obs[step] = next_obs
            dones[step] = next_done

            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()

            actions[step] = action
            logprobs[step] = logprob

            real_next_obs, reward, terminated, truncated, info = env.step(action.item())
            done = terminated or truncated

            # save reward
            rewards[step] = torch.tensor(reward).to(device)
            # update next_obs
            next_obs = torch.Tensor(real_next_obs).to(device)
            next_done = torch.tensor(float(done)).to(device)

            if done:
                # The episode ended: continue the rollout from a fresh
                # episode (reset also draws a new preference vector).
                next_obs_np, _ = env.reset()
                next_obs = torch.Tensor(next_obs_np).to(device)

        # GAE
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            # initialize
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = 0

            for t in reversed(range(num_steps)):
                if t == num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]

                # TD Error = r + gamma * V(s') - V(s)
                delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
                # GAE = delta + gamma * lambda * GAE(t+1)
                advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam

            # Returns = Advantage + Value
            returns = advantages + values

        # PPO Update (the whole rollout is used as a single batch)
        b_obs = obs
        b_logprobs = logprobs
        b_actions = actions
        b_advantages = advantages
        b_returns = returns
        b_values = values

        # Advantage Normalization.  The batch does not change between
        # epochs, so this is computed once instead of once per epoch.
        mb_advantages = (b_advantages - b_advantages.mean()) / (b_advantages.std() + 1e-8)

        for epoch in range(update_epochs):
            _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs, b_actions)
            #  ratio = new / old
            logratio = newlogprob - b_logprobs
            ratio = logratio.exp()

            # Policy Loss (clipped surrogate)
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            # Value Loss (MSE Loss)
            v_loss = 0.5 * ((newvalue.view(-1) - b_returns) ** 2).mean()

            # total Loss
            loss = pg_loss - ent_coef * entropy.mean() + vf_coef * v_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if update % 10 == 0:
            print(f"Update {update}/{num_updates}, Loss: {loss.item():.4f}, Mean Reward: {rewards.mean().item():.2f}")

    print("Training Finished!")
    # save model weight
    torch.save(agent.state_dict(), "variant_a_agent.pth")
    # The training environment is not used past this point.
    env.close()


    # ==========================================
    # 4. Verification
    # ==========================================
    print("\n=== Running Verification (Variant A) ===")

    # The former try/except NameError built the same env in both branches,
    # so a single construction is equivalent.
    demo_env = ScalarRewardWrapper(gym.make("CartPole-v1"))

    agent.eval()

    test_weights = [
        [1.0, 0.0],  # extreme left
        [0.7, 0.3],  # left
        [0.5, 0.5],  # middle
        [0.3, 0.7],  # right
        [0.0, 1.0]   # extreme right
    ]

    print(f"{'Weight (w1, w2)':<20} | {'Avg Position':<15} | {'Steps':<10}")
    print("-" * 50)

    for w in test_weights:
        # One episode per preference vector; note that `obs` now shadows
        # the training buffer of the same name.
        obs, _ = demo_env.reset(options={'w': w})
        obs = torch.Tensor(obs).to(device)

        positions = []
        steps = 0

        while True:
            with torch.no_grad():
                # Actions are sampled from the policy (no action is passed).
                action, _, _, _ = agent.get_action_and_value(obs)

            real_next_obs, _, term, trunc, _ = demo_env.step(action.item())

            # First observation component is the cart position x.
            positions.append(real_next_obs[0])
            steps += 1

            obs = torch.Tensor(real_next_obs).to(device)
            if term or trunc:
                break

        avg_pos = np.mean(positions)
        print(f"{str(w):<20} | {avg_pos: .4f}          | {steps:<10}")

    demo_env.close()

Using device: cpu
4
Box(-inf, inf, (6,), float32)
Starting Training (Variant A: Reward Scalarization)...
Update 10/625, Loss: 0.8019, Mean Reward: 0.24
Update 20/625, Loss: 1.8382, Mean Reward: 0.26
Update 30/625, Loss: 1.1735, Mean Reward: 0.25
Update 40/625, Loss: 1.7226, Mean Reward: 0.27
Update 50/625, Loss: 1.4436, Mean Reward: 0.25
Update 60/625, Loss: 2.1940, Mean Reward: 0.28
Update 70/625, Loss: 5.4131, Mean Reward: 0.42
Update 80/625, Loss: 2.5732, Mean Reward: 0.35
Update 90/625, Loss: 0.8587, Mean Reward: 0.26
Update 100/625, Loss: 1.8339, Mean Reward: 0.30
Update 110/625, Loss: 0.8233, Mean Reward: 0.27
Update 120/625, Loss: 4.1180, Mean Reward: 0.32
Update 130/625, Loss: 0.1123, Mean Reward: 0.15
Update 140/625, Loss: 0.4919, Mean Reward: 0.26
Update 150/625, Loss: 4.4566, Mean Reward: 0.25
Update 160/625, Loss: 4.5162, Mean Reward: 0.26
Update 170/625, Loss: 3.2008, Mean Reward: 0.25
Update 180/625, Loss: 0.6855, Mean Reward: 0.29
Update 190/625, Loss: 0.2180, Mean Rewar

In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical

# ==========================================
# Variant B: Value/Q-Space Combination
# ==========================================
class SteerableCartPoleWrapper(gym.Wrapper):
    """Preference-conditioned wrapper that exposes a two-part vector reward.

    A 2-element preference vector ``w`` is appended to every observation.
    Each step stores the reward vector ``[r1, r2]`` in
    ``info['vec_reward']``; ``r1`` grows as the first observation component
    (the cart position ``x``) moves towards -2.4 and ``r2`` as it moves
    towards +2.4.  The scalar reward returned by :meth:`step` is the
    unweighted sum ``r1 + r2`` and is meant for logging only.
    """

    # Number of reward components (and therefore preference weights).
    NUM_OBJECTIVES = 2

    def __init__(self, env):
        super().__init__(env)
        original_shape = env.observation_space.shape[0]
        # Observation = wrapped observation followed by w.
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(original_shape + self.NUM_OBJECTIVES,),
            dtype=np.float32
        )
        # Placeholder preference until the first reset() picks one.
        self.current_w = np.array([0.5, 0.5], dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Start a new episode and choose the preference vector for it.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``.
            options: Optional dict.  When it contains the key ``'w'`` that
                value is used as the preference vector; otherwise a random
                vector normalized to sum to one is drawn.

        Returns:
            ``(observation, info)`` with the preference vector appended to
            the wrapped observation.

        Raises:
            ValueError: If ``options['w']`` does not have exactly
                ``NUM_OBJECTIVES`` entries.
        """
        if options and 'w' in options:
            w = np.array(options['w'], dtype=np.float32)
            if w.shape != (self.NUM_OBJECTIVES,):
                raise ValueError(
                    f"options['w'] must have shape ({self.NUM_OBJECTIVES},), "
                    f"got {w.shape}"
                )
            self.current_w = w
        else:
            w = np.random.rand(self.NUM_OBJECTIVES)
            # np.random.rand yields float64; cast so the observation keeps
            # the float32 dtype declared in observation_space.
            self.current_w = (w / w.sum()).astype(np.float32)

        obs, info = self.env.reset(seed=seed)
        # combine State + w
        return np.concatenate([obs, self.current_w]), info

    def step(self, action):
        """Advance one step and publish the vector reward through ``info``.

        Returns:
            ``(observation, r1 + r2, terminated, truncated, info)`` where
            ``info['vec_reward']`` holds ``[r1, r2]`` as float32.
        """
        obs, _, terminated, truncated, info = self.env.step(action)
        x = obs[0]

        # Both terms are normalized to [0, 1] over x in [-2.4, 2.4].
        norm_left = (2.4 - x) / 4.8
        norm_right = (x + 2.4) / 4.8

        r1 = norm_left ** 2
        r2 = norm_right ** 2

        # No reward on the step that terminates the episode.
        if terminated:
            r1 = 0.0
            r2 = 0.0

        # reward as a vector
        vec_reward = np.array([r1, r2], dtype=np.float32)

        # save into info so the training loop can read the components
        info['vec_reward'] = vec_reward

        # step() must return a scalar reward; this one is not used for
        # training, only as a log value.
        scalar_reward_log = r1 + r2

        new_obs = np.concatenate([obs, self.current_w])
        return new_obs, scalar_reward_log, terminated, truncated, info

# ==========================================
# 2. Vector-Critic Agent
# ==========================================
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Apply orthogonal weights and a constant bias to ``layer`` in place.

    ``std`` is the gain of the orthogonal initialization and ``bias_const``
    the value written to every bias entry.  The layer itself is returned.
    """
    weight, bias = layer.weight, layer.bias
    torch.nn.init.orthogonal_(weight, gain=std)
    torch.nn.init.constant_(bias, val=bias_const)
    return layer

class VectorAgent(nn.Module):
    """PPO actor-critic whose critic predicts one value per objective.

    Both networks are two-hidden-layer tanh MLPs.  Their input width is read
    from ``env.observation_space`` (previously hard-coded to 6).

    Args:
        env: Environment supplying ``observation_space.shape`` and
            ``action_space.n``.
        num_objectives: Width of the critic output.  Defaults to 2, the
            value that was previously hard-coded.
    """

    def __init__(self, env, num_objectives=2):
        super().__init__()
        obs_dim = env.observation_space.shape[0]
        n_actions = env.action_space.n

        # Critic
        # Input: observation (state + w)
        # Output: one value per objective (V1 and V2 by default)
        self.critic = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, num_objectives), std=1.0),
        )
        # Actor: same architecture, outputs action logits
        self.actor = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, n_actions), std=0.01),
        )

    def get_value(self, x):
        """Return the critic output; its last dimension is the objective."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Evaluate the policy and the vector critic on ``x``.

        Args:
            x: Observation tensor fed to both networks.
            action: Action(s) to score.  When ``None`` an action is sampled
                from the current policy.

        Returns:
            ``(action, log_prob(action), policy entropy, critic values)``.
        """
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(x)

# ==========================================
# 3. main loop
# ==========================================
if __name__ == "__main__":
    # Variant B driver: the critic learns one value per objective, the
    # per-objective advantages are combined with the preference weights
    # into a scalar advantage, and PPO is run on that scalar.

    # hyperparameters
    learning_rate = 3e-4
    num_steps = 128          # environment steps collected per update
    total_timesteps = 80000
    gamma = 0.99
    gae_lambda = 0.95
    update_epochs = 4        # full-batch gradient steps per rollout
    clip_coef = 0.2
    ent_coef = 0.001
    vf_coef = 0.5

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # initialize env and VectorAgent
    env = SteerableCartPoleWrapper(gym.make("CartPole-v1"))
    agent = VectorAgent(env).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=learning_rate)

    # Buffer sizes are read from the wrapper instead of being hard-coded.
    obs_dim = env.observation_space.shape[0]
    num_objectives = len(env.current_w)

    # initialize Buffer
    obs = torch.zeros((num_steps, obs_dim)).to(device)
    actions = torch.zeros((num_steps,)).to(device)
    logprobs = torch.zeros((num_steps,)).to(device)
    # Rewards and Values hold one column per objective
    rewards = torch.zeros((num_steps, num_objectives)).to(device)
    values = torch.zeros((num_steps, num_objectives)).to(device)
    dones = torch.zeros((num_steps,)).to(device)

    # save Context w for each step
    contexts = torch.zeros((num_steps, num_objectives)).to(device)

    # initialize
    global_step = 0
    next_obs, _ = env.reset()
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.tensor(0.0).to(device)
    num_updates = total_timesteps // num_steps

    print("Starting Training (Variant B: Q-Space Scalarization)...")

    for update in range(1, num_updates + 1):

        # 1.  Rollout
        for step in range(num_steps):
            global_step += 1
            obs[step] = next_obs
            dones[step] = next_done
            # save w (the trailing entries of the observation)
            contexts[step] = next_obs[-num_objectives:]

            with torch.no_grad():
                # Value is a vector with one entry per objective
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()

            actions[step] = action
            logprobs[step] = logprob

            real_next_obs, _, terminated, truncated, info = env.step(action.item())
            done = terminated or truncated

            # get Vector Reward from info
            r_vec = info['vec_reward']
            rewards[step] = torch.tensor(r_vec).to(device)

            next_obs = torch.Tensor(real_next_obs).to(device)
            next_done = torch.tensor(float(done)).to(device)

            if done:
                # Continue the rollout from a fresh episode (reset also
                # draws a new preference vector).
                next_obs_np, _ = env.reset()
                next_obs = torch.Tensor(next_obs_np).to(device)

        # 2. Vector GAE
        with torch.no_grad():
            # get next state Vector Value
            next_value = agent.get_value(next_obs).reshape(1, -1)  # [1, num_objectives]
            advantages = torch.zeros_like(rewards).to(device)      # [num_steps, num_objectives]
            lastgaelam = torch.zeros(num_objectives).to(device)    # [num_objectives]

            # calculate Advantage
            for t in reversed(range(num_steps)):
                if t == num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value[0]
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]

                # TD Error, one entry per objective
                delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]

                # GAE recursive
                advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam

            # Returns = Advantage + Value
            returns = advantages + values

        # 3. Scalarization & PPO Update

        # advantages shape: [num_steps, num_objectives] (advantage per objective)
        # contexts shape:   [num_steps, num_objectives] (w)

        # Scalarization:
        # Scalar Advantage = w1 * A1 + w2 * A2
        scalar_advantages = (advantages * contexts).sum(dim=1)

        # Batch data prep
        b_obs = obs
        b_logprobs = logprobs
        b_actions = actions
        b_returns = returns             # Value Loss - Vector Returns
        b_values = values               # Value - Vector
        b_scalar_advantages = scalar_advantages # Policy Loss - Scalar Advantage

        # Standardize the scalar advantage.  The batch does not change
        # between epochs, so this is computed once instead of once per epoch.
        mb_advantages = (b_scalar_advantages - b_scalar_advantages.mean()) / (b_scalar_advantages.std() + 1e-8)

        for epoch in range(update_epochs):
            _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs, b_actions)
            logratio = newlogprob - b_logprobs
            ratio = logratio.exp()

            #  Policy Loss (Scalar Advantage)
            pg_loss1 = -mb_advantages * ratio
            pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            #  Value Loss (Vector Returns)
            # newvalue and b_returns are both [num_steps, num_objectives]
            # MSE Loss
            v_loss = 0.5 * ((newvalue - b_returns) ** 2).mean()

            # total Loss
            loss = pg_loss - ent_coef * entropy.mean() + vf_coef * v_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if update % 20 == 0:
            # print actual Scalar Reward instead of Advantage
            train_scalar_rewards = (rewards * contexts).sum(dim=1).mean().item()
            print(f"Update {update}/{num_updates}, Loss: {loss.item():.4f}, Mean Scalar Reward: {train_scalar_rewards:.2f}")

    print("Training Finished!")
    torch.save(agent.state_dict(), "variant_b_agent.pth")
    # The training environment is not used past this point.
    env.close()

    # ==========================================
    # 4. Verification
    # ==========================================
    print("\n=== Running Verification ===")
    try:
        demo_env = SteerableCartPoleWrapper(gym.make("CartPole-v1", render_mode="human"))
    except Exception:
        # Fall back to a non-rendering env.  `Exception` rather than a bare
        # `except` so Ctrl-C / SystemExit are not swallowed.
        # NOTE(review): a missing rendering backend may only surface on the
        # first reset()/step() rather than in gym.make() — confirm.
        demo_env = SteerableCartPoleWrapper(gym.make("CartPole-v1"))

    agent.eval()

    test_weights = [
        [1.0, 0.0],
        [0.7, 0.3],
        [0.5, 0.5],
        [0.3, 0.7],
        [0.0, 1.0]
    ]

    print(f"{'Weight (w1, w2)':<20} | {'Avg Position':<15} | {'Steps':<10}")
    print("-" * 50)

    for w in test_weights:
        # One episode per preference vector; note that `obs` now shadows
        # the training buffer of the same name.
        obs, _ = demo_env.reset(options={'w': w})
        obs = torch.Tensor(obs).to(device)

        positions = []
        steps = 0

        while True:
            with torch.no_grad():
                # Actions are sampled from the policy (no action is passed).
                action, _, _, _ = agent.get_action_and_value(obs)

            real_next_obs, _, term, trunc, _ = demo_env.step(action.item())

            # First observation component is the cart position x.
            positions.append(real_next_obs[0])
            steps += 1

            obs = torch.Tensor(real_next_obs).to(device)
            if term or trunc:
                break

        avg_pos = np.mean(positions)
        print(f"{str(w):<20} | {avg_pos: .4f}          | {steps:<10}")

    demo_env.close()

Using device: cpu
Starting Training (Variant B: Q-Space Scalarization)...
Update 20/625, Loss: 0.8657, Mean Scalar Reward: 0.23
Update 40/625, Loss: 1.1028, Mean Scalar Reward: 0.25
Update 60/625, Loss: 1.6978, Mean Scalar Reward: 0.23
Update 80/625, Loss: 3.4439, Mean Scalar Reward: 0.24
Update 100/625, Loss: 3.4219, Mean Scalar Reward: 0.27
Update 120/625, Loss: 1.6548, Mean Scalar Reward: 0.32
Update 140/625, Loss: 2.1339, Mean Scalar Reward: 0.36
Update 160/625, Loss: 1.3735, Mean Scalar Reward: 0.25
Update 180/625, Loss: 3.6503, Mean Scalar Reward: 0.26
Update 200/625, Loss: 1.9251, Mean Scalar Reward: 0.33
Update 220/625, Loss: 2.3691, Mean Scalar Reward: 0.25
Update 240/625, Loss: 0.8280, Mean Scalar Reward: 0.21
Update 260/625, Loss: 2.3836, Mean Scalar Reward: 0.31
Update 280/625, Loss: 4.5013, Mean Scalar Reward: 0.27
Update 300/625, Loss: 0.2873, Mean Scalar Reward: 0.25
Update 320/625, Loss: 0.2973, Mean Scalar Reward: 0.25
Update 340/625, Loss: 0.2667, Mean Scalar Reward: 

In [None]:
# ==========================================
# 3. Variant C: Gradient-Space Combination
# ==========================================
# reuse SteerableCartPoleWrapper and VectorAgent class

if __name__ == "__main__":
    # Variant C driver: a clipped PPO policy loss is built separately for
    # each objective and the per-objective losses are then mixed with the
    # preference weights.  Relies on the imports, SteerableCartPoleWrapper
    # and VectorAgent defined in the earlier cells.
    learning_rate = 3e-4
    num_steps = 128          # environment steps collected per update
    total_timesteps = 80000
    gamma = 0.99
    gae_lambda = 0.95
    update_epochs = 4        # full-batch gradient steps per rollout
    clip_coef = 0.2
    ent_coef = 0.001
    vf_coef = 0.5

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    env = SteerableCartPoleWrapper(gym.make("CartPole-v1"))
    agent = VectorAgent(env).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=learning_rate)

    # Buffer sizes are read from the wrapper instead of being hard-coded.
    obs_dim = env.observation_space.shape[0]
    num_objectives = len(env.current_w)

    # Initialize Buffer
    obs = torch.zeros((num_steps, obs_dim)).to(device)
    actions = torch.zeros((num_steps,)).to(device)
    logprobs = torch.zeros((num_steps,)).to(device)
    rewards = torch.zeros((num_steps, num_objectives)).to(device)
    values = torch.zeros((num_steps, num_objectives)).to(device)
    dones = torch.zeros((num_steps,)).to(device)
    contexts = torch.zeros((num_steps, num_objectives)).to(device)

    global_step = 0
    next_obs, _ = env.reset()
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.tensor(0.0).to(device)
    num_updates = total_timesteps // num_steps

    print("Starting Training (Variant C: Gradient Mixing)...")

    for update in range(1, num_updates + 1):

        # Rollout
        for step in range(num_steps):
            global_step += 1
            obs[step] = next_obs
            dones[step] = next_done
            # The preference w is stored in the trailing observation entries.
            contexts[step] = next_obs[-num_objectives:]

            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()

            actions[step] = action
            logprobs[step] = logprob

            real_next_obs, _, terminated, truncated, info = env.step(action.item())
            done = terminated or truncated

            r_vec = info['vec_reward']
            rewards[step] = torch.tensor(r_vec).to(device)

            next_obs = torch.Tensor(real_next_obs).to(device)
            next_done = torch.tensor(float(done)).to(device)

            if done:
                # Continue the rollout from a fresh episode (reset also
                # draws a new preference vector).
                next_obs_np, _ = env.reset()
                next_obs = torch.Tensor(next_obs_np).to(device)

        #  Vector GAE (the same as B)
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(device)
            lastgaelam = torch.zeros(num_objectives).to(device)

            for t in reversed(range(num_steps)):
                if t == num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value[0]
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]

                delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam

            returns = advantages + values

        # Variant C core Update

        # we keep Vector Advantage
        b_obs = obs
        b_logprobs = logprobs
        b_actions = actions
        b_returns = returns
        b_values = values
        b_vector_advantages = advantages  # [num_steps, num_objectives]
        b_contexts = contexts             # [num_steps, num_objectives]

        # Normalize each objective's advantage column on its own.  The batch
        # does not change between epochs, so this is computed once.
        b_norm_advantages = (
            b_vector_advantages - b_vector_advantages.mean(dim=0)
        ) / (b_vector_advantages.std(dim=0) + 1e-8)

        for epoch in range(update_epochs):
            _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs, b_actions)
            logratio = newlogprob - b_logprobs
            ratio = logratio.exp()
            clipped_ratio = torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)

            # calculate the loss per objective, then mix with the weights
            total_pg_loss = 0

            for k in range(num_objectives):
                # normalized advantage of the k-th objective
                adv_k = b_norm_advantages[:, k]

                # per-sample clipped loss for the k-th objective
                pg_loss1 = -adv_k * ratio
                pg_loss2 = -adv_k * clipped_ratio
                loss_k = torch.max(pg_loss1, pg_loss2)

                # Weight every sample by the preference it was collected
                # under.  A rollout spans several episodes with different w,
                # so the former batch-mean weight gave every sample the same
                # weight and the policy gradient did not depend on the
                # sample's own preference.
                total_pg_loss = total_pg_loss + (b_contexts[:, k] * loss_k).mean()

            # Value Loss - Vector MSE
            v_loss = 0.5 * ((newvalue - b_returns) ** 2).mean()

            # total Loss
            loss = total_pg_loss - ent_coef * entropy.mean() + vf_coef * v_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if update % 20 == 0:
            train_scalar_rewards = (rewards * contexts).sum(dim=1).mean().item()
            print(f"Update {update}/{num_updates}, Loss: {loss.item():.4f}, Mean Scalar Reward: {train_scalar_rewards:.2f}")

    print("Training Finished (Variant C)!")
    torch.save(agent.state_dict(), "variant_c_agent.pth")
    # The training environment is not used past this point.
    env.close()

    # ==========================================
    # Verification
    # ==========================================
    print("\n=== Running Verification (Variant C) ===")
    # The former try/except built the same env in both branches, so a
    # single construction is equivalent.
    demo_env = SteerableCartPoleWrapper(gym.make("CartPole-v1"))

    agent.eval()

    test_weights = [
        [1.0, 0.0],
        [0.7, 0.3],
        [0.5, 0.5],
        [0.3, 0.7],
        [0.0, 1.0]
    ]

    print(f"{'Weight (w1, w2)':<20} | {'Avg Position':<15} | {'Steps':<10}")
    print("-" * 50)

    for w in test_weights:
        # One episode per preference vector; note that `obs` now shadows
        # the training buffer of the same name.
        obs, _ = demo_env.reset(options={'w': w})
        obs = torch.Tensor(obs).to(device)

        positions = []
        steps = 0

        while True:
            with torch.no_grad():
                # Actions are sampled from the policy (no action is passed).
                action, _, _, _ = agent.get_action_and_value(obs)

            real_next_obs, _, term, trunc, _ = demo_env.step(action.item())

            # First observation component is the cart position x.
            positions.append(real_next_obs[0])
            steps += 1

            obs = torch.Tensor(real_next_obs).to(device)
            if term or trunc:
                break

        avg_pos = np.mean(positions)
        print(f"{str(w):<20} | {avg_pos: .4f}          | {steps:<10}")

    demo_env.close()

Using device: cpu
Starting Training (Variant C: Gradient Mixing)...
Update 20/625, Loss: 0.5886, Mean Scalar Reward: 0.24
Update 40/625, Loss: 1.7689, Mean Scalar Reward: 0.26
Update 60/625, Loss: 1.6883, Mean Scalar Reward: 0.27
Update 80/625, Loss: 2.5371, Mean Scalar Reward: 0.23
Update 100/625, Loss: 3.6777, Mean Scalar Reward: 0.27
Update 120/625, Loss: 2.4311, Mean Scalar Reward: 0.29
Update 140/625, Loss: 4.6784, Mean Scalar Reward: 0.23
Update 160/625, Loss: 1.4494, Mean Scalar Reward: 0.29
Update 180/625, Loss: 0.5523, Mean Scalar Reward: 0.28
Update 200/625, Loss: 4.7431, Mean Scalar Reward: 0.24
Update 220/625, Loss: 6.2157, Mean Scalar Reward: 0.21
Update 240/625, Loss: 5.5149, Mean Scalar Reward: 0.29
Update 260/625, Loss: 0.8777, Mean Scalar Reward: 0.22
Update 280/625, Loss: 1.7413, Mean Scalar Reward: 0.34
Update 300/625, Loss: 1.5268, Mean Scalar Reward: 0.37
Update 320/625, Loss: 2.3168, Mean Scalar Reward: 0.29
Update 340/625, Loss: 0.1880, Mean Scalar Reward: 0.25
U

In [None]:
import gymnasium as gym
import numpy as np

class SteerableWalkerWrapper(gym.Wrapper):
    """Preference-conditioned wrapper with a three-part vector reward.

    A 3-element preference vector ``w`` is appended to every observation.
    Each step builds the reward vector ``[velocity, survive, energy]`` from
    the wrapped environment's ``info`` dict, stores it in
    ``info['vec_reward']`` and returns ``dot(w, vec_reward)`` as the scalar
    reward.  Components missing from ``info`` count as 0.
    """

    def __init__(self, env):
        super().__init__(env)
        original_shape = env.observation_space.shape[0]

        # Three objectives: velocity, survival, energy.
        self.num_objectives = 3

        # Observation = wrapped state + preference
        # (17 + 3 = 20 according to the original note).
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf,
            shape=(original_shape + self.num_objectives,),
            dtype=np.float32
        )
        # Placeholder weights (w1, w2, w3) until the first reset() picks some.
        self.current_w = np.array([0.33, 0.33, 0.33], dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Start a new episode and choose the preference vector for it.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``.
            options: Optional dict.  When it contains the key ``'w'`` that
                value is used as the preference vector; otherwise a random
                vector normalized to sum to one is drawn.

        Returns:
            ``(observation, info)`` with the preference vector appended to
            the wrapped observation.

        Raises:
            ValueError: If ``options['w']`` does not have exactly
                ``num_objectives`` entries.
        """
        if options and 'w' in options:
            w = np.array(options['w'], dtype=np.float32)
            if w.shape != (self.num_objectives,):
                raise ValueError(
                    f"options['w'] must have shape ({self.num_objectives},), "
                    f"got {w.shape}"
                )
            self.current_w = w
        else:
            # Draw random weights and normalize them to sum to one.
            w = np.random.rand(self.num_objectives)
            # np.random.rand yields float64; cast so the observation keeps
            # the float32 dtype declared in observation_space.
            self.current_w = (w / w.sum()).astype(np.float32)

        obs, info = self.env.reset(seed=seed)
        return np.concatenate([obs, self.current_w]), info

    def step(self, action):
        """Advance one step and return the preference-weighted reward.

        Returns:
            ``(observation, dot(w, vec_reward), terminated, truncated,
            info)`` where ``info['vec_reward']`` holds the float32 vector
            ``[velocity, survive, energy]``.
        """
        obs, _, terminated, truncated, info = self.env.step(action)

        # 1. Velocity (forward reward)
        r_velocity = info.get('reward_forward', 0.0)

        # 2. Survival (healthy reward); 0 when the key is absent.
        r_survive = info.get('reward_survive', 0.0)

        # 3. Energy (control cost).  The objective to maximize is the
        #    negative cost.  Gymnasium's MuJoCo environments report
        #    'reward_ctrl' as the already-negated cost, so blindly negating
        #    it again would reward energy use.  -abs() yields the negative
        #    cost under either sign convention, since the cost magnitude is
        #    non-negative.
        r_energy = -abs(info.get('reward_ctrl', 0.0))

        # Assemble the 3-dimensional reward vector.
        vec_reward = np.array([r_velocity, r_survive, r_energy], dtype=np.float32)

        # Stored in info for Variants B/C.
        info['vec_reward'] = vec_reward

        # Scalar reward (for Variant A, or as a log value):
        # Scalar = w1*Vel + w2*Survive + w3*Energy
        weighted_scalar_reward = np.dot(self.current_w, vec_reward)

        new_obs = np.concatenate([obs, self.current_w])

        return new_obs, weighted_scalar_reward, terminated, truncated, info

In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal

# ... (make sure the SteerableWalkerWrapper class has already been defined above) ...

# ==========================================
# 2. Continuous Scalar Agent (used only by Variant A)
# ==========================================
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialize a layer's weights and fill its bias.

    Args:
        layer: Module exposing ``weight`` and ``bias`` tensors; it is
            modified in place.
        std: Gain passed to the orthogonal initializer.
        bias_const: Value every bias entry is set to.

    Returns:
        The same ``layer`` object, so the call can wrap a constructor inline.
    """
    initializers = torch.nn.init
    initializers.orthogonal_(layer.weight, std)
    initializers.constant_(layer.bias, bias_const)
    return layer

class ContinuousScalarAgent(nn.Module):
    """Gaussian-policy actor-critic with a scalar value head (Variant A).

    The actor outputs the mean of a diagonal Normal over actions; the
    log-std is a state-independent learnable parameter. The critic maps an
    observation to a single value.

    Args:
        env: Object exposing ``observation_space.shape`` and
            ``action_space.shape``; only the first entry of each is read.
    """

    def __init__(self, env):
        super().__init__()
        # Per the original note: 17 state dims + 3 weight dims = 20 for the
        # wrapped Walker2d, and 6 action dims.
        obs_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]

        # === Critic: one scalar value per observation ===
        self.critic = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0),
        )

        # === Actor: mean of the continuous action distribution ===
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, action_dim), std=0.01),
        )
        # Kept as (1, action_dim) so existing checkpoints still load.
        self.actor_logstd = nn.Parameter(torch.zeros(1, action_dim))

    def get_value(self, x):
        """Return the critic output for observation(s) ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Sample (or score) an action and evaluate the critic.

        Accepts either a single observation of shape ``(obs_dim,)`` or a batch
        of shape ``(batch, obs_dim)``.

        Args:
            x: Observation tensor.
            action: If given, its log-probability is computed instead of
                sampling a new action.

        Returns:
            ``(action, log_prob, entropy, value)``; ``log_prob`` and
            ``entropy`` are summed over the action dimension.
        """
        action_mean = self.actor_mean(x)
        # Flatten the (1, action_dim) parameter so it broadcasts against both
        # an unbatched (action_dim,) mean and a batched (batch, action_dim)
        # mean. Expanding the 2-D parameter to a 1-D shape raises in PyTorch.
        action_logstd = self.actor_logstd.reshape(-1).expand_as(action_mean)
        action_std = torch.exp(action_logstd)

        probs = Normal(action_mean, action_std)

        if action is None:
            action = probs.sample()

        # Reduce over the last (action) axis so unbatched input also works;
        # for 2-D input this is the same as summing over dim 1.
        log_prob = probs.log_prob(action).sum(-1)
        entropy = probs.entropy().sum(-1)

        return action, log_prob, entropy, self.critic(x)

# ==========================================
# 3. Main Training Loop (Variant A: Reward Scalarization)
# ==========================================
if __name__ == "__main__":
    # PPO training script for Variant A: the wrapper scalarizes the reward
    # vector with the episode's weights, so this is ordinary scalar PPO on a
    # weight-augmented observation.

    # === Hyperparameters ===
    learning_rate = 3e-4
    num_steps = 2048
    total_timesteps = 1000000
    gamma = 0.99
    gae_lambda = 0.95
    update_epochs = 10
    clip_coef = 0.2
    ent_coef = 0.0
    vf_coef = 0.5
    max_grad_norm = 0.5
    batch_size = 64

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize Env
    env = SteerableWalkerWrapper(gym.make("Walker2d-v4"))
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    # Scalar-critic agent
    agent = ContinuousScalarAgent(env).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=learning_rate)

    # Rollout buffers
    obs = torch.zeros((num_steps, obs_dim), device=device)
    actions = torch.zeros((num_steps, action_dim), device=device)
    logprobs = torch.zeros((num_steps,), device=device)

    # Variant A: rewards and values are plain scalars per step.
    rewards = torch.zeros((num_steps,), device=device)
    values = torch.zeros((num_steps,), device=device)
    dones = torch.zeros((num_steps,), device=device)

    # The weights are not stored separately here: the reward coming out of the
    # wrapper is already weighted, so the loss never needs them.

    global_step = 0
    next_obs, _ = env.reset()
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.tensor(0.0).to(device)

    num_updates = total_timesteps // num_steps

    print(f"Starting Walker2d Training (Variant A)... Target: {total_timesteps} steps")

    for update in range(1, num_updates + 1):

        # 1. Rollout
        for step in range(num_steps):
            global_step += 1
            obs[step] = next_obs
            dones[step] = next_done

            with torch.no_grad():
                # There is a single (non-vectorized) env, so next_obs is 1-D.
                # Add a batch axis for the agent and strip it again below.
                action, logprob, _, value = agent.get_action_and_value(next_obs.unsqueeze(0))

            action = action.squeeze(0)
            values[step] = value.reshape(())
            actions[step] = action
            logprobs[step] = logprob.reshape(())

            # Step
            real_next_obs, scalar_reward, terminated, truncated, info = env.step(action.cpu().numpy())
            done = terminated or truncated

            # Variant A: store the scalar reward computed by the wrapper.
            rewards[step] = float(scalar_reward)

            next_obs = torch.Tensor(real_next_obs).to(device)
            next_done = torch.tensor(float(done)).to(device)

            if done:
                next_obs_np, _ = env.reset()
                next_obs = torch.Tensor(next_obs_np).to(device)

        # 2. Standard scalar GAE
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(())
            advantages = torch.zeros_like(rewards)
            lastgaelam = 0

            for t in reversed(range(num_steps)):
                if t == num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]

                # TD residual; nextnonterminal stops bootstrapping across an
                # episode boundary.
                delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam

            returns = advantages + values

        # 3. Standard PPO update
        b_obs = obs.reshape((-1, obs_dim))
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1, action_dim))
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)

        b_inds = np.arange(num_steps)

        for epoch in range(update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, num_steps, batch_size):
                end = start + batch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                # Per-minibatch advantage normalization.
                mb_adv = b_advantages[mb_inds]
                mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)

                # Clipped surrogate objective (written as a loss).
                pg_loss1 = -mb_adv * ratio
                pg_loss2 = -mb_adv * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss (scalar MSE)
                v_loss = 0.5 * ((newvalue.view(-1) - b_returns[mb_inds]) ** 2).mean()

                loss = pg_loss - ent_coef * entropy.mean() + vf_coef * v_loss

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
                optimizer.step()

        if update % 10 == 0:
            # `loss` is the loss of the last minibatch of this update.
            print(f"Update {update}/{num_updates} | Loss: {loss.item():.4f} | Mean Reward: {rewards.mean().item():.2f}")

    print("Training Finished (Variant A)!")
    torch.save(agent.state_dict(), "walker_variant_a_scalar.pth")
    env.close()

In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal

# ... (your SteerableWalkerWrapper code should be defined above) ...

# ==========================================
# 2. Continuous Vector-Critic Agent (for Walker2d)
# ==========================================
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialize a layer's weight and set its bias to a constant.

    Args:
        layer: Module with a ``weight`` tensor and, optionally, a ``bias``.
        std: Gain passed to ``torch.nn.init.orthogonal_``.
        bias_const: Value every bias element is set to.

    Returns:
        The same ``layer`` object, so the call can be nested inside
        ``nn.Sequential(...)``.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with bias=False expose `bias` as None; there is nothing
    # to fill in that case.
    if getattr(layer, "bias", None) is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class ContinuousVectorAgent(nn.Module):
    """Gaussian-policy actor-critic with a vector-valued critic.

    The actor outputs the mean of a diagonal Normal over actions; the
    log-std is a state-independent learnable parameter. The critic outputs
    one value per objective.

    Args:
        env: Object exposing ``observation_space.shape`` and
            ``action_space.shape``; only the first entry of each is read.
        num_objectives: Width of the critic output.
    """

    def __init__(self, env, num_objectives=3):
        super().__init__()
        # Per the original note: 17 state dims + 3 weight dims = 20 for the
        # wrapped Walker2d, and 6 action dims.
        obs_dim = env.observation_space.shape[0]
        action_dim = env.action_space.shape[0]

        self.num_objectives = num_objectives

        # === Critic: one value per objective (velocity, survive, energy) ===
        self.critic = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, num_objectives), std=1.0),
        )

        # === Actor: mean of the continuous action distribution ===
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(obs_dim, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, action_dim), std=0.01),
        )
        # Learnable log-std that does not depend on the state. Kept as
        # (1, action_dim) so existing checkpoints still load.
        self.actor_logstd = nn.Parameter(torch.zeros(1, action_dim))

    def get_value(self, x):
        """Return the critic output; the last axis has ``num_objectives`` entries."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Sample (or score) an action and evaluate the vector critic.

        Accepts either a single observation of shape ``(obs_dim,)`` or a batch
        of shape ``(batch, obs_dim)``.

        Args:
            x: Observation tensor.
            action: If given, its log-probability is computed instead of
                sampling a new action.

        Returns:
            ``(action, log_prob, entropy, value)``; ``log_prob`` and
            ``entropy`` are summed over the action dimension.
        """
        action_mean = self.actor_mean(x)
        # Flatten the (1, action_dim) parameter so it broadcasts against both
        # an unbatched (action_dim,) mean and a batched (batch, action_dim)
        # mean. Expanding the 2-D parameter to a 1-D shape raises in PyTorch.
        action_logstd = self.actor_logstd.reshape(-1).expand_as(action_mean)
        action_std = torch.exp(action_logstd)

        # Diagonal Gaussian policy.
        probs = Normal(action_mean, action_std)

        if action is None:
            action = probs.sample()

        # The joint log-prob of independent action dimensions is the sum of
        # the per-dimension log-probs. Reducing over the last axis keeps
        # unbatched input working and equals dim 1 for 2-D input.
        log_prob = probs.log_prob(action).sum(-1)
        entropy = probs.entropy().sum(-1)

        return action, log_prob, entropy, self.critic(x)

# ==========================================
# 3. Main Training Loop (Variant B)
# ==========================================
if __name__ == "__main__":
    # PPO training script for Variant B: a vector critic estimates one
    # advantage per objective, and the advantages are combined with the
    # episode's weights before the policy loss is computed.

    # === Hyperparameters for MuJoCo ===
    learning_rate = 3e-4
    num_steps = 2048           # longer rollout horizon for MuJoCo
    total_timesteps = 1000000  # 1M environment steps
    gamma = 0.99
    gae_lambda = 0.95
    update_epochs = 10
    clip_coef = 0.2
    ent_coef = 0.0             # continuous control usually needs little entropy bonus
    vf_coef = 0.5
    max_grad_norm = 0.5
    batch_size = 64            # minibatch size
    num_objectives = 3         # velocity, survive, energy

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize Env
    env = SteerableWalkerWrapper(gym.make("Walker2d-v4"))
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    agent = ContinuousVectorAgent(env, num_objectives=num_objectives).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=learning_rate)

    # Rollout buffers. Observations already include the weight vector.
    obs = torch.zeros((num_steps, obs_dim), device=device)
    actions = torch.zeros((num_steps, action_dim), device=device)
    logprobs = torch.zeros((num_steps,), device=device)

    # Rewards / values hold one entry per objective.
    rewards = torch.zeros((num_steps, num_objectives), device=device)
    values = torch.zeros((num_steps, num_objectives), device=device)

    dones = torch.zeros((num_steps,), device=device)
    contexts = torch.zeros((num_steps, num_objectives), device=device)  # weights per step

    global_step = 0
    next_obs, _ = env.reset()
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.tensor(0.0).to(device)

    num_updates = total_timesteps // num_steps

    print(f"Starting Walker2d Training (3 Objectives)... Target: {total_timesteps} steps")

    for update in range(1, num_updates + 1):

        # 1. Rollout phase
        for step in range(num_steps):
            global_step += 1
            obs[step] = next_obs
            dones[step] = next_done
            # The wrapper appends the weights at the end of the observation.
            contexts[step] = next_obs[-num_objectives:]

            with torch.no_grad():
                # There is a single (non-vectorized) env, so next_obs is 1-D.
                # Add a batch axis for the agent and strip it again below.
                action, logprob, _, value = agent.get_action_and_value(next_obs.unsqueeze(0))

            action = action.squeeze(0)
            values[step] = value.reshape(num_objectives)
            actions[step] = action
            logprobs[step] = logprob.reshape(())

            # Step the environment; it expects a numpy action.
            real_next_obs, _, terminated, truncated, info = env.step(action.cpu().numpy())
            done = terminated or truncated

            # Per-objective reward vector published by the wrapper.
            r_vec = info['vec_reward']
            rewards[step] = torch.as_tensor(r_vec, dtype=torch.float32, device=device)

            next_obs = torch.Tensor(real_next_obs).to(device)
            next_done = torch.tensor(float(done)).to(device)

            if done:
                next_obs_np, _ = env.reset()
                next_obs = torch.Tensor(next_obs_np).to(device)

        # 2. Vector GAE: the scalar recursion applied to each objective.
        with torch.no_grad():
            next_value = agent.get_value(next_obs)
            advantages = torch.zeros_like(rewards)
            lastgaelam = torch.zeros(num_objectives, device=device)

            for t in reversed(range(num_steps)):
                if t == num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]

                # Vector TD residual; nextnonterminal stops bootstrapping
                # across an episode boundary.
                delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam

            returns = advantages + values

        # 3. Scalarization & PPO update
        # Variant B: weighted sum at the advantage level -> shape [num_steps].
        scalar_advantages = (advantages * contexts).sum(dim=1)

        # Flatten batch
        b_obs = obs.reshape((-1, obs_dim))
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1, action_dim))
        b_scalar_advantages = scalar_advantages.reshape(-1)
        b_returns = returns.reshape((-1, num_objectives))

        # Minibatch update
        b_inds = np.arange(num_steps)

        for epoch in range(update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, num_steps, batch_size):
                end = start + batch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                # Per-minibatch advantage normalization.
                mb_adv = b_scalar_advantages[mb_inds]
                mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)

                # Clipped surrogate objective (written as a loss).
                pg_loss1 = -mb_adv * ratio
                pg_loss2 = -mb_adv * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss: MSE averaged over the batch and all objectives.
                v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                loss = pg_loss - ent_coef * entropy.mean() + vf_coef * v_loss

                optimizer.zero_grad()
                loss.backward()
                # Gradient clipping
                nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
                optimizer.step()

        if update % 10 == 0:
            # Average weighted reward over the rollout; `loss` is the loss of
            # the last minibatch of this update.
            train_scalar_rewards = (rewards * contexts).sum(dim=1).sum().item() / num_steps
            print(f"Update {update}/{num_updates} | Loss: {loss.item():.4f} | Avg Weighted Reward: {train_scalar_rewards:.2f}")

    print("Training Finished!")
    torch.save(agent.state_dict(), "walker_variant_b_3obj.pth")
    env.close()

In [None]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.normal import Normal

# ... (make sure the SteerableWalkerWrapper and ContinuousVectorAgent classes are already defined above) ...
# ... (the ContinuousVectorAgent defined for Variant B can be reused as-is) ...

# ==========================================
# 3. Main Training Loop (Variant C: Gradient Mixing)
# ==========================================
if __name__ == "__main__":
    # PPO training script for Variant C: a separate clipped PPO loss is built
    # for every objective and the losses are mixed with the episode's weights,
    # so the weighting happens at the gradient level.

    # === Hyperparameters ===
    learning_rate = 3e-4
    num_steps = 2048
    total_timesteps = 1000000
    gamma = 0.99
    gae_lambda = 0.95
    update_epochs = 10
    clip_coef = 0.2
    ent_coef = 0.0
    vf_coef = 0.5
    max_grad_norm = 0.5
    batch_size = 64
    num_objectives = 3  # velocity, survive, energy

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize Env
    env = SteerableWalkerWrapper(gym.make("Walker2d-v4"))
    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    # Variant C also needs the vector critic to estimate one advantage per
    # objective.
    agent = ContinuousVectorAgent(env, num_objectives=num_objectives).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=learning_rate)

    # Rollout buffers
    obs = torch.zeros((num_steps, obs_dim), device=device)
    actions = torch.zeros((num_steps, action_dim), device=device)
    logprobs = torch.zeros((num_steps,), device=device)

    # Per-objective rewards and values
    rewards = torch.zeros((num_steps, num_objectives), device=device)
    values = torch.zeros((num_steps, num_objectives), device=device)
    dones = torch.zeros((num_steps,), device=device)

    # The weights must be stored: they scale the per-objective losses below.
    contexts = torch.zeros((num_steps, num_objectives), device=device)

    global_step = 0
    next_obs, _ = env.reset()
    next_obs = torch.Tensor(next_obs).to(device)
    next_done = torch.tensor(0.0).to(device)

    num_updates = total_timesteps // num_steps

    print(f"Starting Walker2d Training (Variant C: Multi-Objective Gradient)...")

    for update in range(1, num_updates + 1):

        # 1. Rollout (same as Variant B)
        for step in range(num_steps):
            global_step += 1
            obs[step] = next_obs
            dones[step] = next_done
            # The wrapper appends the weights at the end of the observation.
            contexts[step] = next_obs[-num_objectives:]

            with torch.no_grad():
                # There is a single (non-vectorized) env, so next_obs is 1-D.
                # Add a batch axis for the agent and strip it again below.
                action, logprob, _, value = agent.get_action_and_value(next_obs.unsqueeze(0))

            action = action.squeeze(0)
            values[step] = value.reshape(num_objectives)
            actions[step] = action
            logprobs[step] = logprob.reshape(())

            real_next_obs, _, terminated, truncated, info = env.step(action.cpu().numpy())
            done = terminated or truncated

            r_vec = info['vec_reward']
            rewards[step] = torch.as_tensor(r_vec, dtype=torch.float32, device=device)

            next_obs = torch.Tensor(real_next_obs).to(device)
            next_done = torch.tensor(float(done)).to(device)

            if done:
                next_obs_np, _ = env.reset()
                next_obs = torch.Tensor(next_obs_np).to(device)

        # 2. Vector GAE (same as Variant B): one advantage per objective.
        with torch.no_grad():
            next_value = agent.get_value(next_obs)
            advantages = torch.zeros_like(rewards)
            lastgaelam = torch.zeros(num_objectives, device=device)

            for t in reversed(range(num_steps)):
                if t == num_steps - 1:
                    nextnonterminal = 1.0 - next_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]

                delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam

            returns = advantages + values

        # 3. Variant C update: gradient mixing.
        # The advantages stay per-objective here; no weighted sum yet.
        b_obs = obs.reshape((-1, obs_dim))
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1, action_dim))
        b_vector_advantages = advantages.reshape((-1, num_objectives))  # [batch, objectives]
        b_returns = returns.reshape((-1, num_objectives))
        b_contexts = contexts.reshape((-1, num_objectives))             # [batch, objectives]

        b_inds = np.arange(num_steps)

        for epoch in range(update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, num_steps, batch_size):
                end = start + batch_size
                mb_inds = b_inds[start:end]

                # Log-probs of the stored actions under the current policy.
                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                # === Core of Variant C ===
                total_pg_loss = 0

                # One clipped PPO loss per objective.
                for k in range(num_objectives):
                    adv_k = b_vector_advantages[mb_inds, k]

                    # Normalize each objective on its own so that no single
                    # objective's scale dominates the mix.
                    adv_k = (adv_k - adv_k.mean()) / (adv_k.std() + 1e-8)

                    # Independent clipping: which side of the clip is active
                    # depends only on this objective's advantage.
                    loss1 = -adv_k * ratio
                    loss2 = -adv_k * torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)

                    # Element-wise loss, one entry per sample.
                    loss_k_element = torch.max(loss1, loss2)

                    # Weight of objective k for each sample.
                    w_k = b_contexts[mb_inds, k]

                    # Weighted accumulation; mean() averages over the batch.
                    total_pg_loss += (w_k * loss_k_element).mean()

                # Value loss: MSE averaged over the batch and all objectives.
                v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                # Total loss
                loss = total_pg_loss - ent_coef * entropy.mean() + vf_coef * v_loss

                optimizer.zero_grad()
                # Backprop through the weighted sum of losses yields the
                # weighted sum of the per-objective gradients.
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)
                optimizer.step()

        if update % 10 == 0:
            # Log the weighted reward for readability; `loss` is the loss of
            # the last minibatch of this update.
            train_scalar_rewards = (rewards * contexts).sum(dim=1).sum().item() / num_steps
            print(f"Update {update}/{num_updates} | Loss: {loss.item():.4f} | Avg Weighted Reward: {train_scalar_rewards:.2f}")

    print("Training Finished (Variant C)!")
    torch.save(agent.state_dict(), "walker_variant_c_grad.pth")
    env.close()