In [None]:
!pip install mo-gymnasium morl-baselines wandb



In [None]:
import warnings
# Install a blanket "ignore" filter: every Python warning raised after this cell is silenced.
# NOTE(review): this also hides numerical warnings (e.g. divide-by-zero) raised by later cells.
warnings.filterwarnings('ignore')

In [None]:
from morl_baselines.common.evaluation import eval_mo
import mo_gymnasium as mo_gym
from morl_baselines.multi_policy.envelope.envelope import Envelope
import numpy as np

In [None]:
# from morl_baselines.common.evaluation import eval_mo
# import mo_gymnasium as mo_gym
# from morl_baselines.multi_policy.envelope.envelope import Envelope
# import numpy as np

# env = mo_gym.make("resource-gathering-v0")
# eval_env = mo_gym.make("resource-gathering-v0")

# agent = Envelope(
#     env,
#     log=True,
# )

# agent.train(
#     total_timesteps=1000,
#     eval_env=eval_env,
#     ref_point=np.array([0.0, 0.0, -200.0]),
#     eval_freq=100,
# )

# scalar_return, scalarized_disc_return, vec_ret, vec_disc_ret = eval_mo(agent, env=eval_env, w=np.array([0.5, 0.4, 0.1]))

In [None]:
"""FMDQ (Fair Multi-Objective DQN) implementation."""

import os
from typing import List, Optional, Union
from typing_extensions import override

import gymnasium as gym
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb

from morl_baselines.common.buffer import ReplayBuffer
from morl_baselines.common.evaluation import (
    log_all_multi_policy_metrics,
    log_episode_info,
)
from morl_baselines.common.morl_algorithm import MOAgent, MOPolicy
from morl_baselines.common.networks import (
    NatureCNN,
    get_grad_norm,
    layer_init,
    mlp,
    polyak_update,
)
from morl_baselines.common.prioritized_buffer import PrioritizedReplayBuffer
from morl_baselines.common.utils import linearly_decaying_value
from morl_baselines.common.weights import equally_spaced_weights, random_weights


def np_ggf_scalarization(v: np.ndarray, w: np.ndarray) -> float:
    """Numpy GGF (Generalized Gini welfare) scalarization, used for logging.

    The values are sorted in ascending order and the weights in descending order, so the
    largest weight is always paired with the smallest value.

    Args:
        v: Objective values (array or any array-like). Sorted along its last axis.
        w: Weights (array or any array-like). Flattened to 1-D if it has more dimensions.

    Returns:
        The dot product of the ascending-sorted values and the descending-sorted weights.
    """
    # Coerce first so plain lists/tuples work too (they have no `.ndim` attribute).
    v = np.asarray(v)
    w = np.asarray(w)
    # Ensure w is 1d
    if w.ndim > 1:
        w = w.flatten()
    # Sort v ascending, w descending
    v_sorted = np.sort(v)
    w_sorted = np.sort(w)[::-1]
    return np.dot(v_sorted, w_sorted)


class VectorQNet(nn.Module):
    """Q-network for multi-objective RL: one Q-vector (one entry per objective) per action."""

    def __init__(self, obs_shape, action_dim, rew_dim, net_arch):
        """Build the network.

        Args:
            obs_shape: shape of a single observation
            action_dim: number of actions
            rew_dim: number of objectives
            net_arch: hidden layer sizes handed to `mlp`
        """
        super().__init__()
        self.obs_shape = obs_shape
        self.action_dim = action_dim
        self.rew_dim = rew_dim

        n_obs_axes = len(obs_shape)
        if n_obs_axes < 1:
            raise ValueError(f"Invalid observation shape: {obs_shape}")

        if n_obs_axes == 1:
            # Flat observations are fed to the MLP directly.
            self.feature_extractor = None
            input_dim = obs_shape[0]
        else:
            # More than one axis: treated as an image and encoded by a CNN first.
            self.feature_extractor = NatureCNN(self.obs_shape, features_dim=512)
            input_dim = self.feature_extractor.features_dim

        # The head emits action_dim * rew_dim numbers, reshaped in forward().
        self.net = mlp(input_dim, action_dim * rew_dim, net_arch)
        self.apply(layer_init)

    def forward(self, obs):
        """Compute the Q-vectors of every action.

        Args:
            obs: observation tensor; a 1-D tensor is given a leading batch axis when no
                feature extractor is used

        Returns:
            Tensor of shape (batch_size, action_dim, rew_dim)
        """
        extractor = self.feature_extractor
        if extractor is None:
            features = obs.unsqueeze(0) if obs.dim() == 1 else obs
        else:
            features = extractor(obs)

        flat_q = self.net(features)
        # Batch size X Actions X Rewards
        return flat_q.view(-1, self.action_dim, self.rew_dim)


class FMDQ(MOPolicy, MOAgent):
    """FMDQ Algorithm.

    This algorithm learns a vector-valued Q-function and uses the Generalized Gini Welfare (GGF)
    function for action selection and target updates to learn a set of "fair" Pareto-optimal policies.
    It is based on the paper "Learning Fair Pareto Optimal Policies in Multi-Objective Reinforcement Learning".
    """

    def __init__(
        self,
        env,
        learning_rate: float = 3e-4,
        initial_epsilon: float = 1.0,
        final_epsilon: float = 0.05,
        epsilon_decay_steps: int = 50000,
        tau: float = 1.0,
        target_net_update_freq: int = 200,  # ignored if tau != 1.0
        buffer_size: int = int(1e6),
        net_arch: List = [256, 256],
        batch_size: int = 256,
        learning_starts: int = 100,
        gradient_updates: int = 1,
        gamma: float = 0.99,
        max_grad_norm: Optional[float] = 1.0,
        per: bool = True,
        per_alpha: float = 0.6,
        project_name: str = "MORL-Baselines",
        experiment_name: str = "FMDQ",
        wandb_entity: Optional[str] = None,
        log: bool = True,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        group: Optional[str] = None,
    ):
        """FMDQ algorithm.

        Args:
            env: The environment to learn from.
            learning_rate: The learning rate (alpha).
            initial_epsilon: The initial epsilon value for epsilon-greedy exploration.
            final_epsilon: The final epsilon value for epsilon-greedy exploration.
            epsilon_decay_steps: The number of steps to decay epsilon over.
            tau: The soft update coefficient (keep in [0, 1]).
            target_net_update_freq: The frequency with which the target network is updated.
            buffer_size: The size of the replay buffer.
            net_arch: The size of the hidden layers of the value net.
            batch_size: The size of the batch to sample from the replay buffer.
            learning_starts: The number of steps before learning starts.
            gradient_updates: The number of gradient updates per step.
            gamma: The discount factor (gamma).
            max_grad_norm: The maximum norm for the gradient clipping. If None, no gradient clipping is applied.
            per: Whether to use prioritized experience replay.
            per_alpha: The alpha parameter for prioritized experience replay.
            project_name: The name of the project, for wandb logging.
            experiment_name: The name of the experiment, for wandb logging.
            wandb_entity: The entity of the project, for wandb logging.
            log: Whether to log to wandb.
            seed: The seed for the random number generator.
            device: The device to use for training.
            group: The wandb group to use for logging.
        """
        MOAgent.__init__(self, env, device=device, seed=seed)
        MOPolicy.__init__(self, device)
        # Stored here so save() has a default filename even when log=False and
        # setup_wandb() is therefore never called.
        # NOTE(review): presumably setup_wandb() assigns the same attribute - confirm.
        self.experiment_name = experiment_name
        self.learning_rate = learning_rate
        self.initial_epsilon = initial_epsilon
        self.epsilon = initial_epsilon
        self.epsilon_decay_steps = epsilon_decay_steps
        self.final_epsilon = final_epsilon
        self.tau = tau
        self.target_net_update_freq = target_net_update_freq
        self.gamma = gamma
        self.max_grad_norm = max_grad_norm
        self.buffer_size = buffer_size
        self.net_arch = net_arch
        self.learning_starts = learning_starts
        self.batch_size = batch_size
        self.per = per
        self.per_alpha = per_alpha
        self.gradient_updates = gradient_updates

        self.q_net = VectorQNet(self.observation_shape, self.action_dim, self.reward_dim, net_arch=net_arch).to(self.device)
        self.target_q_net = VectorQNet(self.observation_shape, self.action_dim, self.reward_dim, net_arch=net_arch).to(
            self.device
        )
        # The target network starts as an exact copy and is never trained by the optimizer.
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        for param in self.target_q_net.parameters():
            param.requires_grad = False

        self.q_optim = optim.Adam(self.q_net.parameters(), lr=self.learning_rate)

        if self.per:
            self.replay_buffer = PrioritizedReplayBuffer(
                self.observation_shape,
                1,
                rew_dim=self.reward_dim,
                max_size=buffer_size,
                action_dtype=np.uint8,
            )
        else:
            self.replay_buffer = ReplayBuffer(
                self.observation_shape,
                1,
                rew_dim=self.reward_dim,
                max_size=buffer_size,
                action_dtype=np.uint8,
            )

        self.log = log
        if log:
            self.setup_wandb(project_name, experiment_name, wandb_entity, group)

    @override
    def get_config(self):
        """Return the hyperparameters of this agent as a plain dictionary."""
        return {
            "env_id": self.env.unwrapped.spec.id,
            "learning_rate": self.learning_rate,
            "initial_epsilon": self.initial_epsilon,
            "final_epsilon": self.final_epsilon,
            "epsilon_decay_steps": self.epsilon_decay_steps,
            "batch_size": self.batch_size,
            "tau": self.tau,
            "clip_grand_norm": self.max_grad_norm,
            "target_net_update_freq": self.target_net_update_freq,
            "gamma": self.gamma,
            "net_arch": self.net_arch,
            "per": self.per,
            "per_alpha": self.per_alpha,
            "gradient_updates": self.gradient_updates,
            "buffer_size": self.buffer_size,
            "learning_starts": self.learning_starts,
            "seed": self.seed,
        }

    def save(self, save_replay_buffer: bool = True, save_dir: str = "weights/", filename: Optional[str] = None):
        """Save the Q-network, its optimizer state and (optionally) the replay buffer.

        Args:
            save_replay_buffer: Whether to store the replay buffer object in the checkpoint.
            save_dir: Directory to write to; created if it does not exist.
            filename: File name without extension. Defaults to the experiment name.
        """
        os.makedirs(save_dir, exist_ok=True)
        saved_params = {
            "q_net_state_dict": self.q_net.state_dict(),
            "q_net_optimizer_state_dict": self.q_optim.state_dict(),
        }
        if save_replay_buffer:
            saved_params["replay_buffer"] = self.replay_buffer
        filename = self.experiment_name if filename is None else filename
        th.save(saved_params, os.path.join(save_dir, filename + ".tar"))

    def load(self, path: str, load_replay_buffer: bool = True):
        """Restore a checkpoint written by `save`.

        Both the online and the target network are loaded from the saved Q-network weights.

        Args:
            path: Path of the checkpoint file.
            load_replay_buffer: Whether to restore the replay buffer if the checkpoint has one.
        """
        # map_location lets a checkpoint saved on another device (e.g. GPU) load on this one.
        # SECURITY: weights_only=False unpickles arbitrary objects (needed for the replay
        # buffer); only load checkpoints from a trusted source.
        params = th.load(path, map_location=self.device, weights_only=False)
        self.q_net.load_state_dict(params["q_net_state_dict"])
        self.target_q_net.load_state_dict(params["q_net_state_dict"])
        self.q_optim.load_state_dict(params["q_net_optimizer_state_dict"])
        if load_replay_buffer and "replay_buffer" in params:
            self.replay_buffer = params["replay_buffer"]

    def __sample_batch_experiences(self):
        """Sample `batch_size` transitions from the replay buffer as tensors on `self.device`."""
        return self.replay_buffer.sample(self.batch_size, to_tensor=True, device=self.device)

    def ggf_scalarization(self, q_vectors: th.Tensor, w: th.Tensor) -> th.Tensor:
        """Applies GGF scalarization.

        Args:
            q_vectors: (batch_size, num_actions, rew_dim)
            w: (batch_size, 1, rew_dim) or (batch_size, rew_dim)

        Returns:
            Scalarized Q-values (batch_size, num_actions)
        """
        if w.dim() == 2:  # (B, R)
            w = w.unsqueeze(1)  # (B, 1, R)
        # Sort Q-vectors ascending along the reward dimension
        q_sorted, _ = th.sort(q_vectors, dim=-1, descending=False)
        # Sort weights descending along the reward dimension
        w_sorted, _ = th.sort(w, dim=-1, descending=True)
        # GGF is the dot product of sorted vectors
        # w_sorted broadcasts across the action dimension
        return th.sum(q_sorted * w_sorted, dim=-1)

    def ggf_scalarization_vec(self, v: th.Tensor, w: th.Tensor) -> th.Tensor:
        """Applies GGF scalarization to 1D vectors.

        Args:
            v: (batch_size, rew_dim)
            w: (batch_size, rew_dim)

        Returns:
            Scalarized values (batch_size,)
        """
        v_sorted, _ = th.sort(v, dim=-1, descending=False)
        w_sorted, _ = th.sort(w, dim=-1, descending=True)
        return th.sum(v_sorted * w_sorted, dim=-1)

    @override
    def update(self):
        """Run `gradient_updates` optimisation steps, then refresh the target net and epsilon.

        Each step samples a batch, draws one random weight vector per transition, regresses
        the Q-vector of the taken action onto the FMDQ target with an MSE loss and, when
        prioritized replay is enabled, writes new priorities back to the buffer.
        """
        critic_losses = []
        for _ in range(self.gradient_updates):
            if self.per:
                (
                    b_obs,
                    b_actions,
                    b_rewards,
                    b_next_obs,
                    b_dones,
                    b_inds,
                ) = self.__sample_batch_experiences()
            else:
                (
                    b_obs,
                    b_actions,
                    b_rewards,
                    b_next_obs,
                    b_dones,
                ) = self.__sample_batch_experiences()

            # Sample a weight vector for each transition in the batch.
            # The reshape forces a (batch, rew_dim) layout even if a single weight vector
            # comes back 1-D (train() and max_action() handle 1-D weights for n=1).
            sampled_w = random_weights(dim=self.reward_dim, n=self.batch_size, dist="gaussian", rng=self.np_random)
            b_w = th.tensor(np.reshape(sampled_w, (-1, self.reward_dim))).float().to(self.device)

            with th.no_grad():
                target_q_vecs = self.fmdq_target(b_next_obs, b_w)
                # Bootstrapping is switched off for transitions flagged as done.
                target_q = b_rewards + (1 - b_dones) * self.gamma * target_q_vecs

            q_values = self.q_net(b_obs)
            # Pick, for every transition, the Q-vector of the action that was actually taken.
            q_value = q_values.gather(
                1,
                b_actions.long().reshape(-1, 1, 1).expand(q_values.size(0), 1, q_values.size(2)),
            )
            q_value = q_value.squeeze(1)  # (batch_size, rew_dim)

            critic_loss = F.mse_loss(q_value, target_q)  # Vector-valued MSE loss

            self.q_optim.zero_grad()
            critic_loss.backward()

            if self.log and self.global_step % 100 == 0:
                wandb.log(
                    {
                        "losses/grad_norm": get_grad_norm(self.q_net.parameters()).item(),
                        "global_step": self.global_step,
                    },
                )
            if self.max_grad_norm is not None:
                th.nn.utils.clip_grad_norm_(self.q_net.parameters(), self.max_grad_norm)
            self.q_optim.step()
            critic_losses.append(critic_loss.item())

            if self.per:
                # Calculate TD-error using GGF scalarization for priority
                td_err = (q_value - target_q).detach()
                priority = self.ggf_scalarization_vec(td_err, b_w).abs()
                priority = priority.cpu().numpy().flatten()
                priority = (priority + self.replay_buffer.min_priority) ** self.per_alpha
                self.replay_buffer.update_priorities(b_inds, priority)

        # With tau != 1 the target is updated on every call; otherwise only every
        # `target_net_update_freq` steps.
        if self.tau != 1 or self.global_step % self.target_net_update_freq == 0:
            polyak_update(self.q_net.parameters(), self.target_q_net.parameters(), self.tau)

        if self.epsilon_decay_steps is not None:
            self.epsilon = linearly_decaying_value(
                self.initial_epsilon,
                self.epsilon_decay_steps,
                self.global_step,
                self.learning_starts,
                self.final_epsilon,
            )

        if self.log and self.global_step % 100 == 0:
            wandb.log(
                {
                    "losses/critic_loss": np.mean(critic_losses),
                    "metrics/epsilon": self.epsilon,
                    "global_step": self.global_step,
                },
            )
            if self.per:
                # `priority` holds the priorities of the last gradient update above.
                wandb.log({"metrics/mean_priority": np.mean(priority)})

    @override
    def eval(self, obs: np.ndarray, w: np.ndarray) -> int:
        """Return the greedy (GGF-maximising) action for a numpy observation and weight."""
        obs_tensor = th.as_tensor(obs).float().to(self.device)
        w_tensor = th.as_tensor(w).float().to(self.device)
        return self.max_action(obs_tensor, w_tensor)

    def act(self, obs: th.Tensor, w: th.Tensor) -> int:
        """Epsilon-greedily select an action given an observation and weight."""
        if self.np_random.random() < self.epsilon:
            return self.env.action_space.sample()
        else:
            return self.max_action(obs, w)

    @th.no_grad()
    def max_action(self, obs: th.Tensor, w: th.Tensor) -> int:
        """Select the action with the highest GGF Q-value."""
        # Add batch dim if missing
        if w.dim() == 1:
            w = w.unsqueeze(0)
        if obs.dim() < len(self.observation_shape) + 1:
            obs = obs.unsqueeze(0)

        q_values = self.q_net(obs)  # (1, num_actions, rew_dim)
        scalarized_q_values = self.ggf_scalarization(q_values, w.unsqueeze(1))  # (1, num_actions)
        max_act = th.argmax(scalarized_q_values, dim=1)
        return max_act.detach().item()

    @th.no_grad()
    def fmdq_target(self, obs: th.Tensor, w: th.Tensor) -> th.Tensor:
        """FMDQ target (DDQN target with GGF for action selection).

        Args:
            obs: Next observation (batch_size, obs_shape)
            w: Weight vector (batch_size, rew_dim)

        Returns:
            The target Q-vector (batch_size, rew_dim)
        """
        # 1. Select best action a' using main Q-net and GGF
        q_values = self.q_net(obs)  # (B, A, R)
        scalarized_q_values = self.ggf_scalarization(q_values, w.unsqueeze(1))  # (B, A)
        max_acts = th.argmax(scalarized_q_values, dim=1)  # (B,)

        # 2. Get Q-vector for a' from target Q-net
        q_values_target = self.target_q_net(obs)  # (B, A, R)

        # Gather the Q-vector corresponding to max_acts
        q_values_target = q_values_target.gather(
            1,
            max_acts.long().reshape(-1, 1, 1).expand(-1, 1, self.reward_dim),
        )
        q_values_target = q_values_target.squeeze(1)  # (B, R)
        return q_values_target

    def train(
        self,
        total_timesteps: int,
        eval_env: Optional[gym.Env] = None,
        ref_point: Optional[np.ndarray] = None,
        known_pareto_front: Optional[List[np.ndarray]] = None,
        weight: Optional[np.ndarray] = None,
        total_episodes: Optional[int] = None,
        reset_num_timesteps: bool = True,
        eval_freq: int = 10000,
        num_eval_weights_for_front: int = 100,
        num_eval_episodes_for_front: int = 5,
        num_eval_weights_for_eval: int = 50,
        reset_learning_starts: bool = False,
        verbose: bool = False,
    ):
        """Train the agent.

        Args:
            total_timesteps: Maximum number of environment steps to run.
            eval_env: Environment used for the periodic evaluation; requires `ref_point`.
            ref_point: Reference point passed to the multi-policy metrics (hypervolume).
            known_pareto_front: Optional reference front passed to the multi-policy metrics.
            weight: Fixed weight vector to act with. If None, a random weight vector is
                drawn at the start and again after every episode.
            total_episodes: If given, training stops once this many episodes have finished.
            reset_num_timesteps: Whether to reset the step and episode counters to 0.
            eval_freq: Number of steps between two evaluations.
            num_eval_weights_for_front: Number of equally spaced weights evaluated to build the front.
            num_eval_episodes_for_front: Number of episodes run per evaluation weight.
            num_eval_weights_for_eval: Passed to the metrics logger as `n_sample_weights`.
            reset_learning_starts: If True, `learning_starts` is set to the current global step.
            verbose: Passed to the episode logger.
        """
        if eval_env is not None:
            assert ref_point is not None, "Reference point must be provided for the hypervolume computation."
        if self.log:
            self.register_additional_config(
                {
                    "total_timesteps": total_timesteps,
                    "ref_point": ref_point.tolist() if ref_point is not None else None,
                    "known_front": known_pareto_front,
                    "weight": weight.tolist() if weight is not None else None,
                    "total_episodes": total_episodes,
                    "reset_num_timesteps": reset_num_timesteps,
                    "eval_freq": eval_freq,
                    "num_eval_weights_for_front": num_eval_weights_for_front,
                    "num_eval_episodes_for_front": num_eval_episodes_for_front,
                    "num_eval_weights_for_eval": num_eval_weights_for_eval,
                    "reset_learning_starts": reset_learning_starts,
                }
            )

        self.global_step = 0 if reset_num_timesteps else self.global_step
        self.num_episodes = 0 if reset_num_timesteps else self.num_episodes
        if reset_learning_starts:
            self.learning_starts = self.global_step

        num_episodes = 0
        eval_weights = equally_spaced_weights(self.reward_dim, n=num_eval_weights_for_front)
        obs, _ = self.env.reset()

        w = weight if weight is not None else random_weights(self.reward_dim, 1, dist="gaussian", rng=self.np_random)
        tensor_w = th.tensor(w).float().to(self.device)

        for _ in range(1, total_timesteps + 1):
            if total_episodes is not None and num_episodes == total_episodes:
                break

            # Act uniformly at random until learning starts, epsilon-greedily afterwards.
            if self.global_step < self.learning_starts:
                action = self.env.action_space.sample()
            else:
                action = self.act(th.as_tensor(obs).float().to(self.device), tensor_w)

            next_obs, vec_reward, terminated, truncated, info = self.env.step(action)
            self.global_step += 1

            # Only `terminated` is stored as the done flag; truncation is not.
            self.replay_buffer.add(obs, action, vec_reward, next_obs, terminated)
            if self.global_step >= self.learning_starts:
                self.update()

            if eval_env is not None and self.log and self.global_step % eval_freq == 0:
                current_front = [
                    self.policy_eval(eval_env, weights=ew, num_episodes=num_eval_episodes_for_front, log=self.log)[3]
                    for ew in eval_weights
                ]
                log_all_multi_policy_metrics(
                    current_front=current_front,
                    hv_ref_point=ref_point,
                    reward_dim=self.reward_dim,
                    global_step=self.global_step,
                    n_sample_weights=num_eval_weights_for_eval,
                    ref_front=known_pareto_front,
                )

            if terminated or truncated:
                obs, _ = self.env.reset()
                num_episodes += 1
                self.num_episodes += 1

                if self.log and "episode" in info.keys():
                    # Log with GGF scalarization instead of dot product
                    log_episode_info(info["episode"], np_ggf_scalarization, w, self.global_step, verbose=verbose)

                # A new preference is drawn for the next episode unless a fixed weight was given.
                if weight is None:
                    w = random_weights(self.reward_dim, 1, dist="gaussian", rng=self.np_random)
                    tensor_w = th.tensor(w).float().to(self.device)

            else:
                obs = next_obs

In [None]:
# from morl_baselines.common.evaluation import eval_mo
# import mo_gymnasium as mo_gym

# env = mo_gym.make("resource-gathering-v0")
# eval_env = mo_gym.make("resource-gathering-v0")

# agent = FMDQ(
#     env,
#     log=True,
#     project_name = "Exp-1"
# )

# agent.train(
#     total_timesteps=5000,
#     eval_env=eval_env,
#     ref_point=np.array([0.0, 0.0, -200.0]),
#     eval_freq=10,
# )

# scalar_return, scalarized_disc_return, vec_ret, vec_disc_ret = eval_mo(agent, env=eval_env, w=np.array([0.5, 0.4, 0.1]))


[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33m2201ai04_ammar[0m ([33m2201ai04_ammar-indian-institute-of-technology-patna[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pymoo.indicators.hv import HV
import mo_gymnasium as mo_gym
from morl_baselines.common.evaluation import eval_mo
from morl_baselines.multi_policy.envelope.envelope import Envelope

def compute_metrics(name, returns, ref_point=None):
    """Print and return summary metrics for a set of multi-objective return vectors.

    Args:
        name: Label printed in the results header.
        returns: Array-like of shape (n_points, n_objectives). A single 1-D vector is
            treated as one point. Objectives are treated as quantities to maximize.
        ref_point: Hypervolume reference point, expressed in the same (maximization) space
            as `returns`. Defaults to the origin with one coordinate per objective.

    Returns:
        A dict with the hypervolume ("hv"), the number of points ("cd"), the coefficient
        of variation of the per-point totals ("cv") and the mean GGF score ("ggf_mean").

    Raises:
        ValueError: If `returns` is empty or `ref_point` does not have one entry per objective.
    """
    returns = np.atleast_2d(np.asarray(returns, dtype=float))
    if returns.size == 0:
        raise ValueError(f"compute_metrics({name!r}): `returns` is empty, nothing to evaluate.")
    n_obj = returns.shape[1]

    ref = np.zeros(n_obj) if ref_point is None else np.asarray(ref_point, dtype=float)
    if ref.shape != (n_obj,):
        raise ValueError(f"ref_point must have shape ({n_obj},), got {ref.shape}.")

    print(f"\n===== {name} RESULTS =====")

    # Hypervolume. pymoo's HV indicator assumes minimization, so both the points and the
    # reference point are negated to measure the volume dominated by maximized returns.
    hv = HV(ref_point=-ref)(-returns)
    print("Hypervolume:", hv)

    # Cardinality
    cd = len(returns)
    print("Cardinality:", cd)

    # Totals, max, min
    total_resources = returns.sum(axis=1)
    print("Avg Total:", total_resources.mean())
    print("Avg Max:", returns.max(axis=1).mean())
    print("Avg Min:", returns.min(axis=1).mean())

    # CV (coefficient of variation). A zero mean yields nan/inf; this is made explicit
    # here instead of relying on globally suppressed warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        cv_per_obj = returns.std(axis=0) / returns.mean(axis=0)
        cv_total = total_resources.std() / total_resources.mean()
    print("CV per objective:", cv_per_obj)
    print("CV total:", cv_total)

    # GGF: decreasing weights applied to each point's objectives sorted in ascending order.
    # NOTE(review): with these weights the largest objective of each point gets weight 0
    # and the weights do not sum to 1 - confirm this is the intended weighting.
    lambdas = (n_obj - np.arange(1, n_obj + 1)) / n_obj
    ggf_scores = np.sort(returns, axis=1) @ lambdas
    print("GGF mean:", ggf_scores.mean())
    print("GGF min:", ggf_scores.min())
    print("GGF max:", ggf_scores.max())

    return {
        "hv": hv,
        "cd": cd,
        "cv": cv_total,
        "ggf_mean": ggf_scores.mean()
    }


def plot_front(returns, title):
    """Scatter the first two columns of `returns` against each other and show the figure.

    Args:
        returns: 2-D array; column 0 goes on the x axis and column 1 on the y axis.
        title: Title of the plot.
    """
    first_objective = returns[:, 0]
    second_objective = returns[:, 1]

    plt.scatter(first_objective, second_objective)
    plt.title(title)
    plt.xlabel("Resource 1")
    plt.ylabel("Resource 2")
    plt.show()

In [None]:
# Imported in this cell so it does not depend on the imports of the FMDQ cell.
import wandb
from morl_baselines.common.weights import equally_spaced_weights

ENV_ID = "resource-gathering-v0"
# The other cells of this notebook use 3-component vectors for this environment, so the
# hypervolume reference point needs three entries as well.
# NOTE(review): assumed to be dominated by every achievable return - confirm against the
# environment's reward bounds.
REF_POINT = np.array([-1.0, -1.0, -2.0])
NUM_FRONT_WEIGHTS = 20  # preference weights evaluated per agent after training
NUM_FRONT_EPISODES = 5  # episodes averaged per preference weight


def evaluate_front(agent, eval_env, num_weights=NUM_FRONT_WEIGHTS, num_episodes=NUM_FRONT_EPISODES):
    """Evaluate `agent` on equally spaced weights and return its non-dominated returns.

    Neither agent exposes an archive of solutions, so the front is estimated by running
    the trained policy once per preference weight.

    Args:
        agent: Trained agent exposing `reward_dim` and `policy_eval`.
        eval_env: Environment to run the evaluation episodes in.
        num_weights: Number of equally spaced weight vectors to evaluate.
        num_episodes: Number of episodes averaged for each weight vector.

    Returns:
        Array of shape (n_points, reward_dim) holding the distinct, non-dominated
        average return vectors.
    """
    weights = equally_spaced_weights(agent.reward_dim, n=num_weights)
    # Index 2 is taken to be the undiscounted vector return.
    # NOTE(review): assumes policy_eval returns the same 4-tuple order as eval_mo - confirm.
    points = np.unique(
        np.array([agent.policy_eval(eval_env, weights=w, num_episodes=num_episodes)[2] for w in weights]),
        axis=0,
    )
    # Keep a point only if no other point is at least as good everywhere and better somewhere.
    keep = [not np.any(np.all(points >= p, axis=1) & np.any(points > p, axis=1)) for p in points]
    return points[keep]


env = mo_gym.make(ENV_ID)
eval_env = mo_gym.make(ENV_ID)

envelope_agent = Envelope(
    env,
    log=True,
)

envelope_agent.train(
    total_timesteps=1000,
    eval_env=eval_env,
    ref_point=REF_POINT,
    eval_freq=100,
)

env_returns = evaluate_front(envelope_agent, eval_env)
env_metrics = compute_metrics("Envelope", env_returns)
plot_front(env_returns, "Envelope Pareto Front")

# Close the Envelope run so the FMDQ agent below logs to a run of its own.
wandb.finish()


fmdq_agent = FMDQ(
    env,
    log=True,
    project_name="FMDQ-Run"
)

fmdq_agent.train(
    total_timesteps=5000,
    eval_env=eval_env,
    ref_point=REF_POINT,
    eval_freq=100,
)

fmdq_returns = evaluate_front(fmdq_agent, eval_env)

fmdq_metrics = compute_metrics("FMDQ", fmdq_returns)

plot_front(fmdq_returns, "F-MDQ Pareto Front")

wandb.finish()


In [None]:
# Side-by-side summary of the metric dictionaries returned by compute_metrics for the two agents.
print("\n========== FINAL COMPARISON ==========")
# Hypervolume of each agent's set of return vectors.
print(f"Envelope HV: {env_metrics['hv']:.4f} vs FMDQ HV: {fmdq_metrics['hv']:.4f}")
# Cardinality: number of return vectors in each set.
print(f"Envelope CD: {env_metrics['cd']} vs FMDQ CD: {fmdq_metrics['cd']}")
# Coefficient of variation of the per-point totals.
print(f"Envelope CV: {env_metrics['cv']:.4f} vs FMDQ CV: {fmdq_metrics['cv']:.4f}")
# Mean GGF score over each set.
print(f"Envelope GGF mean: {env_metrics['ggf_mean']:.4f} vs FMDQ GGF mean: {fmdq_metrics['ggf_mean']:.4f}")