### Installing the required libraries
```bash
conda create -n eppo python=3.12 -y
conda activate eppo

pip install gymnasium torch numpy scipy seaborn tqdm imageio 
pip install "gymnasium[mujoco]"
```


### Libraries

In [None]:
import torch
import torch.nn as nn
import numpy as np
import gymnasium as gym
import itertools
import time
import copy
from gymnasium.wrappers import RescaleAction
from typing import Optional
from torch.distributions import Normal
from tqdm import tqdm

## Hyperparameters

In [None]:
# common parameters
max_steps = 500_000  # environment steps spent on each task of the curriculum
learning_rate = 3e-4  # passed to both the actor and the critic Adam optimizers
# The training loop calls agent.learn() whenever `step % learn_frequency == 0`.
# Original author's note: the intended cadence is 2048 steps, written as 2047
# because the step counter is 0-based.
# NOTE(review): with a modulus of 2047 the trigger fires at steps 0, 2047,
# 4094, ... rather than every 2048 steps -- confirm this is intended.
learn_frequency = 2047
batch_size = 256  # mini-batch size; learn() is skipped while the buffer holds fewer samples
max_iter = 10  # passes over the collected rollout per learn() call
act_actor = "relu"
act_critic = "relu"
depth_actor = 3
depth_critic = 3
width_actor = 256
width_critic = 256
gamma = 0.99  # discount factor
no_norm_actor = False  # False -> LayerNorm is inserted in the actor network
no_norm_critic = False  # False -> LayerNorm is inserted in the critic network
eval_frequency = 20_000  # evaluate whenever `step % eval_frequency == 0`
eval_episodes = 10  # episodes averaged per evaluation
buffer_size = 2048  # capacity of the rollout buffer

# ppo related
clip_param = 0.2  # clipping range of the probability ratio
gae_lambda = 0.95
max_grad_norm = 0.5  # gradient-norm clipping threshold for actor and critic

# eppo related
regularization_coeff = 0.01  # weight of the prior log-probability term in the critic loss
radius = 0.01  # weight of the advantage standard deviation ("cor" / "ind" modes)
seed = 1
exploration_types = ["mean", "cor", "ind"]
exploration_type = exploration_types[0]

# experiment related
environments = ["ant", "halfcheetah"]
# Strategy names understood by get_idx_to_paralyze for each environment.
strategies = {
    "ant": ["back_one", "front_one", "back_two", "front_two", "parallel", "cross"],
    "halfcheetah": ["back_one", "front_one", "cross_v1", "cross_v2"],
}
environment = environments[1]
strategy = strategies[environment][0]
# exp_name is parsed by substring matching in get_idx_to_paralyze and make_env.
exp_name = f"{environment}_{strategy}"

device = "cuda" if torch.cuda.is_available() else "cpu"

## Environment

In [None]:
# ParalyzeActionWrapper
class ParalyzeActionWrapper(gym.ActionWrapper):
    """Scale selected action dimensions to emulate weakened or paralyzed joints.

    Every action is multiplied element-wise by ``coeff``, a vector that equals
    ``paralyzed_ratio`` at the positions listed in ``joint_idxs`` and 1
    everywhere else. A ratio of 0.0 therefore zeroes the selected actions and
    a ratio of 1.0 leaves the action unchanged.

    Args:
        env: Environment to wrap; ``coeff`` takes the shape of its action space.
        joint_idxs: Indices of the action dimensions to scale, or ``None`` to
            scale nothing.
        paralyzed_ratio: Multiplier applied to the selected action dimensions.
    """

    def __init__(self, env, joint_idxs, paralyzed_ratio=0.0):
        super().__init__(env)
        self.paralyzed_ratio = paralyzed_ratio
        self.joint_idxs = joint_idxs
        self.coeff = np.ones(self.action_space.shape)
        # Indexing with None is NumPy's "newaxis": ``coeff[None] = ratio``
        # would overwrite the whole vector and scale *every* joint. Treat
        # None as "no joints selected" instead.
        if joint_idxs is not None:
            self.coeff[joint_idxs] = paralyzed_ratio

    def action(self, action):
        """Return ``action`` scaled element-wise by the paralysis coefficients."""
        return action * self.coeff

    def step(self, action):
        """Step the wrapped environment with the scaled action."""
        return self.env.step(self.action(action))

class SinglePrecision(gym.ObservationWrapper):
    """Observation wrapper that casts observations to ``np.float32``.

    The observation space is rebuilt from the wrapped space's bounds and
    shape without an explicit dtype, so it takes ``gym.spaces.Box``'s default
    dtype (presumably float32 -- confirm against the installed gymnasium).
    Only ``Box`` spaces and ``Dict`` spaces are supported; the entries of a
    ``Dict`` are each rebuilt as a ``Box``.

    Raises:
        NotImplementedError: If the observation space is neither ``Box`` nor
            ``Dict``.
    """

    def __init__(self, env):
        super().__init__(env)

        space = self.observation_space
        if isinstance(space, gym.spaces.Box):
            self.observation_space = gym.spaces.Box(
                space.low, space.high, space.shape
            )
            return

        if isinstance(space, gym.spaces.Dict):
            # Shallow-copy the mapping so the wrapped env's space is untouched.
            sub_spaces = copy.copy(space.spaces)
            for key, sub_space in sub_spaces.items():
                sub_spaces[key] = gym.spaces.Box(
                    sub_space.low, sub_space.high, sub_space.shape
                )
            self.observation_space = gym.spaces.Dict(sub_spaces)
            return

        raise NotImplementedError

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """Return ``observation`` cast to float32.

        Arrays are cast directly; for a dict a shallow copy is returned whose
        values are cast. Any other input type yields ``None``.
        """
        if isinstance(observation, np.ndarray):
            return observation.astype(np.float32)

        if isinstance(observation, dict):
            converted = copy.copy(observation)
            for key, value in observation.items():
                converted[key] = value.astype(np.float32)
            return converted

        return None


def get_idx_to_paralyze(exp_name):
    """Return the action indices to paralyze for the given experiment name.

    The name is matched by substring: first the environment (``"ant"`` or
    ``"halfcheetah"``), then the strategy. Strategies are tested in the order
    listed below and the first match wins.

    Args:
        exp_name: Experiment name such as ``"ant_front_one"``.

    Returns:
        A new list of integer action indices.

    Raises:
        ValueError: If the name contains no known environment, or contains a
            known environment but no known strategy. The latter case used to
            fall through and return ``None`` silently.
    """
    if "ant" in exp_name:
        strategy_table = (
            ("front_one", [2, 3]),  # front left joints
            ("front_two", [2, 3, 4, 5]),  # front left and right joints
            ("back_one", [6, 7]),  # back left joints
            ("back_two", [0, 1, 6, 7]),  # back left and right joints
            ("parallel", [2, 3, 6, 7]),  # left front and back joints
            ("cross", [2, 3, 0, 1]),  # left front and right back joints
        )
    elif "halfcheetah" in exp_name:
        strategy_table = (
            ("front_one", [5]),
            ("back_one", [2]),
            ("cross_v1", [2, 4]),
            ("cross_v2", [1, 5]),
        )
    else:
        raise ValueError(f"Unknown experiment: {exp_name}")

    for strategy, joint_idxs in strategy_table:
        if strategy in exp_name:
            return joint_idxs

    raise ValueError(f"Unknown paralysis strategy in experiment: {exp_name}")


def make_env(
    exp_name: str,
    seed: int,
    idxs: Optional[list] = None,
    paralyzed_ratio: Optional[float] = 0.0,
) -> gym.Env:
    """Build a seeded, wrapped MuJoCo environment for a paralysis experiment.

    The environment id is chosen by substring match on ``exp_name`` ("ant" is
    tested before "halfcheetah"). The base env is wrapped, innermost first,
    with ``RescaleAction`` to [-1, 1], ``SinglePrecision`` and
    ``ParalyzeActionWrapper``.

    Besides seeding the env and its spaces, this also reseeds the *global*
    NumPy and torch random generators with ``seed``.

    Args:
        exp_name: Experiment name containing the environment keyword.
        seed: Seed for the env, its spaces, NumPy and torch.
        idxs: Action indices handed to ``ParalyzeActionWrapper``.
        paralyzed_ratio: Scaling factor handed to ``ParalyzeActionWrapper``.

    Returns:
        The wrapped environment, already reset once with ``seed``.

    Raises:
        NotImplementedError: If ``exp_name`` names no supported environment.
    """
    known_envs = (("ant", "Ant-v5"), ("halfcheetah", "HalfCheetah-v5"))
    for keyword, env_name in known_envs:
        if keyword in exp_name:
            break
    else:
        raise NotImplementedError(f"Environment {exp_name} not implemented")

    env = ParalyzeActionWrapper(
        SinglePrecision(RescaleAction(gym.make(env_name), -1.0, 1.0)),
        idxs,
        paralyzed_ratio,
    )

    # Seed everything that carries randomness, in a fixed order.
    env.reset(seed=seed)
    for space in (env.action_space, env.observation_space):
        space.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return env

## Experimenter

In [None]:
def totorch(x, dtype=torch.float32, device="cuda"):
    """Convert ``x`` to a tensor of ``dtype`` on ``device`` using ``torch.as_tensor``.

    Note that ``device`` defaults to ``"cuda"``; callers on CPU-only machines
    must pass the device explicitly.
    """
    tensor_options = {"dtype": dtype, "device": device}
    return torch.as_tensor(x, **tensor_options)


def tonumpy(x):
    """Return the data of tensor ``x`` as a NumPy array, moved to the CPU first."""
    host_tensor = x.data.cpu()
    return host_tensor.numpy()


class ExperienceMemoryTorch:
    """Fixed-size ring buffer of transitions stored as torch tensors.

    One tensor is kept per entry of ``field_names``. Its shape comes from
    ``dims[field]``; ``buffer_size`` only controls where the write pointer
    wraps around and the upper bound of ``data_size``.
    """

    field_names = ["state", "action", "reward", "next_state", "terminated", "step"]

    def __init__(self, device, buffer_size, dims):
        self.device = device
        self.buffer_size = buffer_size
        self.dims = dims
        self.reset()

    def reset(self, buffer_size=None):
        """Drop all stored data and allocate fresh (uninitialised) tensors.

        If ``buffer_size`` is given it replaces the current capacity; the
        tensor shapes are still taken from ``dims``.
        """
        if buffer_size is not None:
            self.buffer_size = buffer_size
        self.data_size = 0
        self.pointer = 0
        storage = {}
        for name in self.field_names:
            storage[name] = torch.empty(self.dims[name], device=self.device)
        self.memory = storage

    def add(self, state, action, reward, next_state, terminated, step):
        """Write one transition at the current pointer and advance it (with wrap)."""
        slot = self.pointer
        values = (state, action, reward, next_state, terminated, step)
        for name, value in zip(self.field_names, values):
            self.memory[name][slot] = value
        self.pointer = (slot + 1) % self.buffer_size
        self.data_size = min(self.data_size + 1, self.buffer_size)

    def sample_by_index(self, index):
        """Return a tuple, ordered like ``field_names``, of each field indexed by ``index``."""
        return tuple(self.memory[name][index] for name in self.field_names)

    def sample_all(self):
        """Return the first ``data_size`` entries of every field."""
        return self.sample_by_index(range(self.data_size))

    def __len__(self):
        return self.data_size

    @property
    def size(self):
        """Number of stored transitions (same value as ``len(self)``)."""
        return self.data_size


class ParalysisExpriment(object):
    """Train one agent on a sequence of joint-paralysis tasks and evaluate it.

    The tasks differ only in the paralysis ratio applied to the joints chosen
    by ``get_idx_to_paralyze(exp_name)``. Each task runs for ``max_steps``
    environment steps; the global step counter keeps increasing across tasks.
    Mean evaluation returns are collected in ``AULCS`` (every evaluation) and
    ``FINAL_RETURNS`` (the evaluation at the end of each task).

    (The class name keeps its original spelling so existing callers still work.)

    Args:
        agent: Object providing ``select_action``, ``store_transition``,
            ``learn``, ``end_task``, ``eval`` and ``train``.
        exp_name: Experiment name, e.g. ``"halfcheetah_back_one"``.
        seed: Seed of the training env; the evaluation env uses ``seed + 100``.
        max_steps: Environment steps per task.
        eval_frequency: Evaluate whenever ``step % eval_frequency == 0``.
        device: Torch device observations are moved to.
        learn_frequency: Call ``agent.learn`` whenever
            ``step % learn_frequency == 0``.
        eval_episodes: Episodes averaged per evaluation.
        max_iter: Forwarded to ``agent.learn(max_iter=...)``.
        gamma: Discount factor (stored only; not used by this class).
    """

    def __init__(
        self,
        agent,
        exp_name,
        seed,
        max_steps,
        eval_frequency,
        device,
        learn_frequency,
        eval_episodes,
        max_iter,
        gamma,
    ):
        self.exp_name = exp_name
        self.seed = seed
        self.max_steps = max_steps
        self.eval_frequency = eval_frequency
        self.device = device
        self.learn_frequency = learn_frequency
        self.eval_episodes = eval_episodes
        self.max_iter = max_iter
        self.gamma = gamma

        self.idx_to_paralyze = get_idx_to_paralyze(self.exp_name)
        self.prepare_tasks()

        self.agent = agent

        self.AULCS = []
        self.FINAL_RETURNS = []

    def prepare_tasks(self):
        """Build ``self.tasks``: task id -> paralysis ratio and joint indices.

        The ratio schedule goes from 1.0 down to 0.0 and back up to 1.0.
        """
        task_order = [1.0, 0.75, 0.5, 0.25, 0.0, 0.25, 0.5, 0.75, 1.0]
        self.n_tasks = len(task_order)

        self.tasks = {
            task_id: {"task": coeff, "idxs": self.idx_to_paralyze}
            for task_id, coeff in enumerate(task_order)
        }
        task_names = [task_info["task"] for task_info in self.tasks.values()]
        print(f"Tasks: {task_names}")

    def set_task(self, task_id, task_info):
        """Create the training and evaluation environments for one task.

        Environments left over from the previous task are closed first so
        that simulator instances do not pile up over the task sequence.

        ``make_env`` reseeds the global NumPy and torch generators on every
        call, so after this method they are seeded with ``seed + 100``.

        Returns:
            The paralysis ratio of the task.
        """
        task = task_info["task"]

        for attr in ("env", "eval_env"):
            previous = getattr(self, attr, None)
            if previous is not None:
                previous.close()

        self.env = make_env(
            exp_name=self.exp_name,
            seed=self.seed,
            idxs=task_info["idxs"],
            paralyzed_ratio=task,
        )

        self.eval_env = make_env(
            exp_name=self.exp_name,
            seed=self.seed + 100,
            idxs=task_info["idxs"],
            paralyzed_ratio=task,
        )

        return task

    def train(self):
        """Run all tasks in order, interleaving data collection, learning and evaluation.

        Prints per-episode progress (every 10th episode), the total training
        time, and the means of ``AULCS`` and ``FINAL_RETURNS``.
        """
        time_start = time.time()

        # Per-episode and per-step logs; filled in below but not returned.
        information_dict = {
            "episode_rewards": torch.zeros(self.max_steps * (self.n_tasks + 1)),
            "episode_steps": torch.zeros(self.max_steps * (self.n_tasks + 1)),
            "step_rewards": np.empty((2 * self.max_steps * self.n_tasks), dtype=object),
        }

        episode = 0
        for task_id, task_info in self.tasks.items():
            # task starts
            task = self.set_task(task_id, task_info)
            print(f"Starting to task {task_id}: {task}")

            r_cum = np.zeros(1)
            s, _ = self.env.reset()
            s = totorch(s, device=self.device)
            for step in tqdm(
                range(task_id * self.max_steps, (task_id + 1) * self.max_steps),
                leave=True,
                disable=True,
            ):
                if step % self.eval_frequency == 0:
                    self.eval(step)

                a = self.agent.select_action(s).clip(-1.0, 1.0)

                sp, r, done, truncated, _ = self.env.step(tonumpy(a))
                sp = totorch(sp, device=self.device)

                self.agent.store_transition(s, a, r, sp, done, truncated, step + 1)

                information_dict["step_rewards"][step] = (
                    episode,
                    step,
                    r,
                )

                s = sp  # Update state
                r_cum += r  # Update cumulative reward

                if (step % self.learn_frequency) == 0:
                    self.agent.learn(max_iter=self.max_iter)

                if done or truncated:
                    information_dict["episode_rewards"][episode] = r_cum.item()
                    information_dict["episode_steps"][episode] = step
                    if episode % 10 == 0:
                        print(
                            f"Episode: {episode + 1:4d}\tN-steps: {step:7d}\tReward: {r_cum.item():10.3f}"
                        )
                    s, _ = self.env.reset()
                    s = totorch(s, device=self.device)
                    r_cum = np.zeros(1)
                    episode += 1

            # task finishes
            self.eval(step, final=True)
            self.agent.end_task()

        time_end = time.time()
        print(f"Training time: {time_end - time_start:.2f} seconds")

        print(f"AULC: {np.mean(self.AULCS)}")
        print(f"Final Return: {np.mean(self.FINAL_RETURNS)}")

    @torch.no_grad()
    def eval(self, n_step, final=False):
        """Run ``eval_episodes`` evaluation episodes and record the mean return.

        The agent is switched to eval mode for the rollouts and back to train
        mode afterwards. Actions are requested with ``is_training=False``.

        Args:
            n_step: Global step, used only in the printed report.
            final: If True, the mean return is also appended to
                ``FINAL_RETURNS``.
        """
        self.agent.eval()
        results = torch.zeros(self.eval_episodes)
        for episode in range(self.eval_episodes):
            s, _ = self.eval_env.reset()
            s = totorch(s, device=self.device)
            done = False

            while not done:
                a = self.agent.select_action(s, is_training=False)

                sp, r, term, trunc, _ = self.eval_env.step(tonumpy(a))

                done = term or trunc
                s = totorch(sp, device=self.device)
                results[episode] += r

        mean_return = results.mean()
        print(f"EVALUATION\tN-steps: {n_step:7d}\tMean_Reward: {mean_return:10.3f}")
        self.AULCS.append(mean_return)
        if final:
            self.FINAL_RETURNS.append(mean_return)

        self.agent.train()

## Networks

In [None]:
def get_activation(act):
    """Return the ``torch.nn`` activation class (not an instance) named by ``act``.

    Supported names are ``"relu"`` and ``"tanh"``.

    Raises:
        NotImplementedError: For any other name.
    """
    supported = (("relu", nn.ReLU), ("tanh", nn.Tanh))
    for name, activation_cls in supported:
        if act == name:
            return activation_cls
    raise NotImplementedError(f"{act} is not implemented")


def create_net(d_in, d_out, depth, width, act="relu", has_norm=True, n_elements=1):
    """Build a fully connected network with ``depth`` linear layers.

    ``depth == 1`` gives a single ``nn.Linear``. Otherwise the result is an
    ``nn.Sequential`` of ``Linear -> [norm] -> activation`` groups followed by
    a final ``Linear``. The norm is a non-affine ``LayerNorm`` when
    ``has_norm`` is True and ``nn.Identity`` otherwise.

    Args:
        d_in: Input feature size.
        d_out: Output feature size.
        depth: Number of linear layers; must be positive.
        width: Hidden layer size.
        act: Activation name understood by ``get_activation``. The default
            used to be ``"crelu"``, which ``get_activation`` rejects, so
            calling without ``act`` always raised.
        has_norm: Whether to insert LayerNorm before each activation.
        n_elements: Kept for backward compatibility; it has no effect. It was
            previously passed into ``nn.Linear``'s ``bias`` argument, where any
            value greater than 1 just meant ``bias=True``.

    Returns:
        The constructed ``nn.Module``.
    """
    assert depth > 0, "Need at least one layer"

    act = get_activation(act)

    def make_norm():
        if has_norm:
            return nn.LayerNorm(width, elementwise_affine=False)
        return nn.Identity()

    if depth == 1:
        return nn.Linear(d_in, d_out)

    if depth == 2:
        return nn.Sequential(
            nn.Linear(d_in, width),
            make_norm(),
            act(),
            nn.Linear(width, d_out),
        )

    # Layers are created in the original order (input, output, then hidden)
    # so that seeded weight initialisation draws the same random numbers.
    in_layer = nn.Linear(d_in, width)
    out_layer = nn.Linear(width, d_out)

    hidden = []
    for _ in range(depth - 1):
        hidden.extend((make_norm(), act(), nn.Linear(width, width)))
    # The last group's Linear is replaced by out_layer. It is still built
    # above (and dropped here) to keep the RNG consumption unchanged.
    hidden.pop()

    return nn.Sequential(in_layer, *hidden, out_layer)


class GaussianHead(nn.Module):
    """Turn a feature vector into a diagonal Gaussian over ``n`` dimensions.

    The first ``n`` entries of the last axis are the mean; the remaining
    entries are the log standard deviation, clamped to [-10, -2].
    """

    def __init__(self, n):
        super().__init__()
        self._n = n

    def forward(self, x, is_training=True, return_dist=False):
        """Sample from (training) or take the mode of (evaluation) the Gaussian.

        Returns:
            ``(y, y_logprob, dist)`` if ``return_dist`` else ``(y, mean)``.
            ``y_logprob`` is the log-probability of ``y`` summed over the last
            axis (keepdim) in training mode, and ``None`` otherwise.
        """
        split = self._n
        mean, raw_logstd = x[..., :split], x[..., split:]
        std = raw_logstd.clamp(-10.0, -2.0).exp()
        dist = Normal(mean, std, validate_args=False)

        y_logprob = None
        if is_training:
            y = dist.rsample()
            y_logprob = dist.log_prob(y).sum(dim=-1, keepdim=True)
        else:
            y = dist.mode

        return (y, y_logprob, dist) if return_dist else (y, mean)


class ActorNetProbabilistic(nn.Module):
    """Policy network: an MLP body followed by a ``GaussianHead``.

    The body outputs ``2 * dim_act`` values, which the head splits into the
    mean and the log standard deviation of the action distribution.
    """

    def __init__(self, dim_obs, dim_act, depth=3, width=256, act="relu", has_norm=True):
        super().__init__()
        self.dim_act = dim_act
        n_outputs = 2 * dim_act  # mean and log-std for every action dimension
        self.arch = create_net(dim_obs, n_outputs, depth, width, act, has_norm)

        self.head = GaussianHead(self.dim_act)

    def forward(self, x, is_training=True, return_dist=False):
        """Run the body on ``x`` and pass the result, with both flags, to the head."""
        features = self.arch(x)
        return self.head(features, is_training, return_dist=return_dist)


class EvidentialCriticNet(nn.Module):
    """MLP with four outputs interpreted as evidential parameters.

    The four outputs are ``gamma`` (used as is) and the logarithms of ``v``,
    ``alpha - 1`` and ``beta``.
    """

    def __init__(self, dim_obs, depth=3, width=256, act="relu", has_norm=False):
        super().__init__()

        net_options = {"act": act, "has_norm": has_norm}
        self.arch = create_net(dim_obs, 4, depth, width, **net_options)

    @staticmethod
    def evidence(x):
        """Map an unconstrained value to a positive one via ``exp``."""
        return torch.exp(x)

    def forward(self, x):
        """Return ``(gamma, v, alpha, beta)``, each keeping a trailing axis of size 1."""
        raw = self.arch(x)
        gamma, logv, log_alpha, log_beta = torch.chunk(raw, 4, dim=-1)
        v = self.evidence(logv)
        alpha = 1 + self.evidence(log_alpha)  # shift so that alpha > 1
        beta = self.evidence(log_beta)
        return gamma, v, alpha, beta

# EPPO

In [None]:
class EPPOActor(nn.Module):
    """Policy wrapper that owns the actor network, its optimizer and the clipped loss.

    Args:
        arch: Callable building the policy network; called as
            ``arch(n_state, n_action, depth=..., width=..., act=..., has_norm=...)``.
        n_state: Observation size.
        n_action: Action size.
        clip_param: Clipping range of the probability ratio.
        max_grad_norm: Gradient-norm clipping threshold.
        learning_rate: Adam learning rate.
        depth_actor, width_actor, act_actor: Network hyperparameters.
        no_norm_actor: If True the network is built with ``has_norm=False``.
        device: Device the network is moved to.
    """

    def __init__(
        self,
        arch,
        n_state,
        n_action,
        clip_param,
        max_grad_norm,
        learning_rate,
        depth_actor,
        width_actor,
        act_actor,
        no_norm_actor,
        device,
    ):
        super().__init__()
        self.n_state = n_state
        self.n_action = n_action
        self.arch = arch
        self.clip = clip_param
        self.max_grad_norm = max_grad_norm
        self.learning_rate = learning_rate
        self.depth_actor = depth_actor
        self.width_actor = width_actor
        self.act_actor = act_actor
        self.no_norm_actor = no_norm_actor
        self.device = device

        self.initialize()

    def initialize(self):
        """(Re)create the policy network and its Adam optimizer."""
        network = self.arch(
            self.n_state,
            self.n_action,
            depth=self.depth_actor,
            width=self.width_actor,
            act=self.act_actor,
            has_norm=not self.no_norm_actor,
        )
        self.model = network.to(self.device)
        self.optim = torch.optim.Adam(self.model.parameters(), self.learning_rate)

    def evaluate(self, s):
        """Return the action distribution the network produces for states ``s``."""
        _sample, _logprob, action_dist = self.model(s, return_dist=True)
        return action_dist

    def loss(self, s, a, old_probs, adv):
        """Clipped surrogate policy loss.

        ``old_probs`` holds per-dimension log-probabilities; they are summed
        over axis 1, like the new ones. The log-ratio is clamped to [-20, 1]
        before exponentiation. As a side effect ``self.clip_fraction`` is set
        to the share of samples whose ratio lies outside the clip range.
        """
        new_logp = self.evaluate(s).log_prob(a).sum(1, keepdim=True)
        old_logp = old_probs.sum(1, keepdim=True)
        log_ratio = (new_logp - old_logp).clamp(-20, 1)
        prob_ratio = torch.exp(log_ratio)

        unclipped = adv * prob_ratio
        clipped = torch.clamp(prob_ratio, 1 - self.clip, 1 + self.clip) * adv
        actor_loss = -torch.min(unclipped, clipped).mean()

        outside_clip = torch.abs(prob_ratio - 1) > self.clip
        self.clip_fraction = outside_clip.to(torch.float).mean()

        return actor_loss

    def update(self, s, a, old_probs, adv):
        """Take one optimizer step on the policy loss, with gradient-norm clipping."""
        self.optim.zero_grad()
        policy_loss = self.loss(s, a, old_probs, adv)
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optim.step()

    def act(self, s, is_training=True):
        """Return the two values the network yields for ``s`` (action first)."""
        action, extra = self.model(s, is_training=is_training)
        return action, extra


class EvidentialCritic(nn.Module):
    """Value function wrapper that owns the evidential network, optimizer and loss.

    The network returns ``(gamma, v, alpha, beta)`` for a batch of states.
    ``gamma`` is used as the value estimate; all four enter the loss and the
    predictive variance.

    Args:
        arch: Callable building the network; called as
            ``arch(n_state, depth=..., width=..., act=..., has_norm=...)``.
        n_state: Observation size.
        max_grad_norm: Gradient-norm clipping threshold.
        depth_critic, width_critic, act_critic: Network hyperparameters.
        no_norm_critic: If True the network is built with ``has_norm=False``.
        device: Device for the network and the prior distributions.
        learning_rate: Adam learning rate.
        regularization_coeff: Weight of the prior log-probability term.
    """

    def __init__(
        self,
        arch,
        n_state,
        max_grad_norm,
        depth_critic,
        width_critic,
        act_critic,
        no_norm_critic,
        device,
        learning_rate,
        regularization_coeff,
    ):
        super().__init__()

        self.arch = arch
        self.max_grad_norm = max_grad_norm
        self.depth_critic = depth_critic
        self.width_critic = width_critic
        self.act_critic = act_critic
        self.no_norm_critic = no_norm_critic
        self.n_state = n_state
        self.learning_rate = learning_rate
        self.iter = 0
        self.device = device
        self.initialize()

        self.regularization_coeff = regularization_coeff

    def initialize(self):
        """(Re)create the network, its Adam optimizer and the parameter priors.

        No ``self.loss`` attribute is assigned here: an ``nn.MSELoss`` stored
        under that name was only registered as an unused child module, hidden
        behind the ``loss`` method below.
        """
        self.model = self.arch(
            self.n_state,
            depth=self.depth_critic,
            width=self.width_critic,
            act=self.act_critic,
            has_norm=not self.no_norm_critic,
        ).to(self.device)
        self.optim = torch.optim.Adam(self.model.parameters(), self.learning_rate)

        def constant(value):
            return torch.tensor(value).to(self.device)

        # Priors whose log-probabilities regularise the four network outputs.
        self.prior_gamma = torch.distributions.Normal(constant(0.0), constant(100.0))
        self.prior_v = torch.distributions.Gamma(constant(5.0), constant(1.0))
        # alpha is always > 1, so its prior is a Gamma shifted by +1.
        self.prior_alpha = torch.distributions.TransformedDistribution(
            torch.distributions.Gamma(constant(5.0), constant(1.0)),
            [
                torch.distributions.transforms.AffineTransform(
                    loc=1.0, scale=1.0, cache_size=1
                )
            ],
        )
        self.prior_beta = torch.distributions.Gamma(constant(5.0), constant(1.0))

    def get_prior(self, x):
        """Return the network outputs ``(gamma, v, alpha, beta)`` for ``x``."""
        gamma, v, alpha, beta = self.model(x)
        return gamma, v, alpha, beta

    def loss(self, state, target):
        """Evidential loss of ``target`` minus the weighted prior log-probabilities.

        The per-sample term is computed from the network outputs for ``state``.
        The batch means of the four prior log-probabilities, scaled by
        ``regularization_coeff``, are subtracted before averaging.

        Returns:
            A scalar tensor.
        """
        gamma, v, alpha, beta = self.get_prior(state)
        twoBlambda = 2 * beta * (1 + v)
        loss = (
            -0.5 * torch.log(v)
            - alpha * torch.log(twoBlambda)
            + (alpha + 0.5) * torch.log(v * (target - gamma) ** 2 + twoBlambda)
            + torch.lgamma(alpha)
            - torch.lgamma(alpha + 0.5)
        )

        regularization = (
            self.prior_gamma.log_prob(gamma).mean()
            + self.prior_v.log_prob(v).mean()
            + self.prior_alpha.log_prob(alpha).mean()
            + self.prior_beta.log_prob(beta).mean()
        )
        loss -= regularization * self.regularization_coeff

        return loss.mean()

    def update(self, state, target):
        """Take one optimizer step on the loss; ``target`` is the return target."""
        self.optim.zero_grad()
        loss = self.loss(state, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optim.step()

    def forward(self, x, target=False):
        """Return the value estimate ``gamma`` for ``x``.

        ``target`` is accepted for interface compatibility; it does not change
        the result.
        """
        gamma, _, _, _ = self.get_prior(x)
        return gamma

    def get_mean_and_variance_of_y(self, x):
        """Return ``(gamma, beta / (alpha - 1) * (1 + 1 / v))`` for ``x``."""
        gamma, v, alpha, beta = self.get_prior(x)
        mean = gamma
        variance = (beta / (alpha - 1)) * (1 + 1.0 / v)
        return mean, variance


class EvidentialProximalPolicyOptimization(nn.Module):
    """PPO-style agent with an evidential critic ("EPPO").

    The agent owns an ``EPPOActor``, an ``EvidentialCritic`` and an
    ``ExperienceMemoryTorch`` rollout buffer. ``exploration_type`` selects how
    advantages are computed:

    * ``"mean"``: generalised advantage estimation on the critic's mean.
    * ``"cor"`` / ``"ind"``: the same mean advantage plus ``radius`` times the
      standard deviation of an accumulated variance term. The two differ only
      in the weight of the current state's variance (1 for ``"cor"``,
      ``(1 - lambda) / (1 + lambda)`` for ``"ind"``).

    Args:
        env: Environment used only to read observation/action spaces.
        actor_nn: Network factory passed to ``EPPOActor``.
        critic_nn: Network factory passed to ``EvidentialCritic``.
        device: Torch device for networks and buffer.
        gamma: Discount factor.
        buffer_size: Capacity of the rollout buffer.
        clip_param: Clipping range of the probability ratio.
        gae_lambda: GAE lambda.
        max_grad_norm: Gradient-norm clipping threshold.
        learning_rate: Learning rate of both optimizers.
        depth_actor, width_actor, act_actor, no_norm_actor: Actor network options.
        depth_critic, width_critic, act_critic, no_norm_critic: Critic network options.
        regularization_coeff: Weight of the critic's prior regulariser.
        radius: Weight of the advantage standard deviation.
        exploration_type: ``"mean"``, ``"cor"`` or ``"ind"``.
        batch_size: Mini-batch size used in ``learn``.

    Raises:
        ValueError: If ``exploration_type`` is not one of the three names.
    """

    _agent_name = "EPPO"

    def __init__(
        self,
        env,
        actor_nn,
        critic_nn,
        device,
        gamma,
        buffer_size,
        clip_param,
        gae_lambda,
        max_grad_norm,
        learning_rate,
        depth_actor,
        width_actor,
        act_actor,
        no_norm_actor,
        depth_critic,
        width_critic,
        act_critic,
        no_norm_critic,
        regularization_coeff,
        radius,
        exploration_type,
        batch_size,
    ):
        super().__init__()

        self.device = device
        self.env = env
        self.eps = 1e-6  # small value to avoid division by zero
        self.clip_param = clip_param
        self.gae_lambda = gae_lambda
        self.max_grad_norm = max_grad_norm
        self.learning_rate = learning_rate
        self.depth_actor = depth_actor
        self.width_actor = width_actor
        self.act_actor = act_actor
        self.no_norm_actor = no_norm_actor
        self.depth_critic = depth_critic
        self.width_critic = width_critic
        self.act_critic = act_critic
        self.no_norm_critic = no_norm_critic
        self.regularization_coeff = regularization_coeff
        self.radius = radius
        # Stored as the plain name (a stray trailing comma used to turn this
        # into a 1-tuple).
        self.exploration_type = exploration_type
        self.batch_size = batch_size

        self.dim_obs, self.dim_act = (
            self.env.observation_space.shape,
            self.env.action_space.shape,
        )
        self.dim_obs_flat, self.dim_act_flat = np.prod(self.dim_obs), np.prod(
            self.dim_act
        )
        self._u_min = totorch(self.env.action_space.low, device=self.device)
        self._u_max = totorch(self.env.action_space.high, device=self.device)
        self._x_min = totorch(self.env.observation_space.low, device=self.device)
        self._x_max = totorch(self.env.observation_space.high, device=self.device)

        self._gamma = gamma
        self.buffer_size = buffer_size

        # Shapes of the per-field tensors in the rollout buffer.
        dims = {
            "state": (self.buffer_size, self.dim_obs_flat),
            "action": (self.buffer_size, self.dim_act_flat),
            "next_state": (self.buffer_size, self.dim_obs_flat),
            "reward": (self.buffer_size,),
            "terminated": (self.buffer_size,),
            "step": (self.buffer_size,),
        }

        self.experience_memory = ExperienceMemoryTorch(
            self.device, self.buffer_size, dims
        )

        self.actor = EPPOActor(
            actor_nn,
            self.dim_obs_flat,
            self.dim_act_flat,
            self.clip_param,
            self.max_grad_norm,
            self.learning_rate,
            self.depth_actor,
            self.width_actor,
            self.act_actor,
            self.no_norm_actor,
            self.device,
        )

        self.critic = EvidentialCritic(
            critic_nn,
            self.dim_obs_flat,
            self.max_grad_norm,
            self.depth_critic,
            self.width_critic,
            self.act_critic,
            self.no_norm_critic,
            self.device,
            self.learning_rate,
            self.regularization_coeff,
        )

        # Coefficients of the variance recursion used by "cor" and "ind".
        self._variance_coeff = (1.0 - self.gae_lambda) / (1.0 + self.gae_lambda)
        self._next_variance_coeff = ((1.0 - self.gae_lambda) / (self.gae_lambda)) ** 2
        self._accumulation_coeff = (self._gamma * self.gae_lambda) ** 2

        if exploration_type == "mean":
            self.calculate_advantages = self.calculate_advantages_mean
        elif exploration_type == "cor":
            self.calculate_advantages = self.calculate_advantages_cor
        elif exploration_type == "ind":
            self.calculate_advantages = self.calculate_advantages_ind
        else:
            raise ValueError(f"Unknown exploration type: {exploration_type}")

    def end_task(self):
        """Discard the collected rollout at a task boundary."""
        self.experience_memory.reset()

    @torch.no_grad()
    def calculate_advantages_mean(self, states, next_states, rewards, dones):
        """GAE on the critic's mean value (EPPO_mean).

        Returns:
            ``(advantages, returns)``: advantages are standardised over the
            whole rollout; returns are the raw advantages plus the values.
        """
        values = self.critic(states)
        next_values = self.critic(next_states)
        deltas = rewards + self._gamma * next_values * (1 - dones) - values
        advantages = torch.zeros_like(rewards)
        advantage = 0
        # Backward recursion; a done flag cuts the accumulation.
        for i in reversed(range(len(deltas))):
            advantage = (
                self._gamma * self.gae_lambda * advantage * (1 - dones[i]) + deltas[i]
            )
            advantages[i] = advantage

        returns = advantages + values
        advantages = (advantages - advantages.mean()) / (advantages.std() + self.eps)
        return advantages, returns

    @torch.no_grad()
    def _calculate_advantages_with_variance(
        self, states, next_states, rewards, dones, value_variance_coeff
    ):
        """Shared implementation of the "cor" and "ind" advantage estimators.

        The mean advantage follows the usual GAE recursion. In parallel a
        variance is accumulated backwards from the critic's predictive
        variance of the next states. The advantage variance at step ``i`` is
        ``value_variance_coeff * Var[V(s_i)]`` plus ``_next_variance_coeff``
        times that accumulated term. The final advantage is
        ``mean + radius * sqrt(variance)``, standardised over the rollout.
        """
        mean_values, variance_values = self.critic.get_mean_and_variance_of_y(states)
        mean_next_values, variance_next_values = self.critic.get_mean_and_variance_of_y(
            next_states
        )
        mean_deltas = (
            rewards + self._gamma * mean_next_values * (1 - dones) - mean_values
        )

        mean_advantages = torch.zeros_like(rewards)
        variance_advantages = torch.zeros_like(rewards)
        mean_advantage = 0
        variance_accumulated = 0
        for i in reversed(range(len(mean_deltas))):
            mean_advantage = (
                self._gamma * self.gae_lambda * mean_advantage * (1 - dones[i])
                + mean_deltas[i]
            )
            mean_advantages[i] = mean_advantage
            variance_accumulated = self._accumulation_coeff * (
                variance_accumulated * (1 - dones[i]) + variance_next_values[i]
            )
            variance_advantages[i] = (
                value_variance_coeff * variance_values[i]
                + self._next_variance_coeff * variance_accumulated
            )

        std_advantages = torch.sqrt(variance_advantages)
        advantages = mean_advantages + self.radius * std_advantages
        returns = mean_advantages + mean_values
        advantages = (advantages - advantages.mean()) / (advantages.std() + self.eps)
        return advantages, returns

    @torch.no_grad()
    def calculate_advantages_cor(self, states, next_states, rewards, dones):
        """EPPO_cor: the current state's variance enters with weight 1."""
        return self._calculate_advantages_with_variance(
            states, next_states, rewards, dones, 1.0
        )

    @torch.no_grad()
    def calculate_advantages_ind(self, states, next_states, rewards, dones):
        """EPPO_ind: the current state's variance enters with ``_variance_coeff``."""
        return self._calculate_advantages_with_variance(
            states, next_states, rewards, dones, self._variance_coeff
        )

    @torch.no_grad()
    def calculate_old_probs(self, states, actions):
        """Per-dimension log-probabilities of ``actions`` under the current policy."""
        dist = self.actor.evaluate(states)
        return dist.log_prob(actions)

    def learn(self, max_iter=1):
        """Update actor and critic from the buffered rollout, then clear it.

        Does nothing (returns ``None``) while the buffer holds fewer than
        ``batch_size`` transitions. Advantages, returns and old
        log-probabilities are computed once; then ``max_iter`` shuffled passes
        of mini-batch updates are run.
        """
        if self.batch_size > len(self.experience_memory):
            return None

        states, actions, rewards, next_states, terminateds, _ = (
            self.experience_memory.sample_all()
        )
        rewards = rewards.reshape(-1, 1)
        terminateds = terminateds.reshape(-1, 1)
        advantages, returns = self.calculate_advantages(
            states, next_states, rewards, terminateds
        )
        old_probs = self.calculate_old_probs(states, actions)

        n_samples = len(states)
        for _ in range(max_iter):
            # shuffle data
            indices = torch.randperm(n_samples)
            for start in range(0, n_samples, self.batch_size):
                batch_indices = indices[start : start + self.batch_size]

                self.actor.update(
                    states[batch_indices],
                    actions[batch_indices],
                    old_probs[batch_indices],
                    advantages[batch_indices],
                )
                self.critic.update(states[batch_indices], returns[batch_indices])

        # clear memory after learning due to on-policy
        self.experience_memory.reset()

    @torch.no_grad()
    def select_action(self, s, is_training=True):
        """Ask the actor for an action and clamp it just inside the action bounds.

        ``is_training`` is forwarded to the actor unchanged.
        """
        a, _ = self.actor.act(s, is_training=is_training)
        return a.clamp(self._u_min + self.eps, self._u_max - self.eps)

    def Q_value(self, s, a):
        """Return the critic's value of ``s``; the action ``a`` is ignored."""
        return self.critic(s)

    def store_transition(self, s, a, r, sp, terminated, truncated, step):
        """Append a transition; the stored done flag is ``terminated or truncated``."""
        self.experience_memory.add(s, a, r, sp, terminated or truncated, step)

# Experimenting

In [None]:
# env
# This environment instance is handed to the agent below, which reads the
# observation/action spaces from it. The experimenter builds its own training
# and evaluation environments for every task.
eval_env = make_env(
    exp_name=exp_name,
    seed=seed + 100,
    idxs=get_idx_to_paralyze(exp_name),
    paralyzed_ratio=0.0,
)

# EPPO
# All values come from the hyperparameter cell; the network classes are
# passed as factories and instantiated inside the agent.
agent = EvidentialProximalPolicyOptimization(
    env=eval_env,
    actor_nn=ActorNetProbabilistic,
    critic_nn=EvidentialCriticNet,
    device=device,
    gamma=gamma,
    buffer_size=buffer_size,
    clip_param=clip_param,
    gae_lambda=gae_lambda,
    max_grad_norm=max_grad_norm,
    learning_rate=learning_rate,
    depth_actor=depth_actor,
    width_actor=width_actor,
    act_actor=act_actor,
    no_norm_actor=no_norm_actor,
    depth_critic=depth_critic,
    width_critic=width_critic,
    act_critic=act_critic,
    no_norm_critic=no_norm_critic,
    regularization_coeff=regularization_coeff,
    radius=radius,
    exploration_type=exploration_type,
    batch_size=batch_size,
)

# experimenter
# Runs the full task sequence: max_steps environment steps per task.
experimenter = ParalysisExpriment(
    agent=agent,
    exp_name=exp_name,
    seed=seed,
    max_steps=max_steps,
    eval_frequency=eval_frequency,
    device=device,
    learn_frequency=learn_frequency,
    eval_episodes=eval_episodes,
    max_iter=max_iter,
    gamma=gamma,
)

# Starts training immediately when the cell is executed.
experimenter.train()