# Import Necessary Packages

In [None]:
# Standard Library
import argparse
import copy
import itertools
import math
import os
import random
import time
from collections import deque
from typing import Dict, Optional, OrderedDict, Tuple

# Third-Party Libraries
# NOTE(review): import order matters below. `gymnasium.spaces.Dict` rebinds the
# `Dict` imported from `typing`, and the later `import gym` rebinds the `gym`
# alias created here -- confirm which packages are actually intended.
import gymnasium as gym
from gymnasium import core, spaces
from gymnasium.spaces import Box, Dict
from gymnasium.wrappers import RescaleAction
from dm_control import suite
from scipy.stats import norm
import numpy as np
from tqdm import tqdm

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import nn as torch_nn
from torch import func as thf
from torch.distributions import Normal, TransformedDistribution
from torch.distributions.transforms import TanhTransform
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair

# Aliases
import torch as th
import gym
import dm_env

# Argument Parser for Jupyter compatibility
# An optional `-f` option is declared (presumably the connection-file argument
# a Jupyter kernel is started with -- confirm) and any other command-line
# arguments are collected in `unknown` instead of causing an error.
# `args` is the namespace object the rest of this file passes around and
# attaches hyper-parameters to.
parser = argparse.ArgumentParser()
parser.add_argument("-f", required=False)  
args, unknown = parser.parse_known_args()

# Get the Model 

In [2]:
def get_model(env):
    """Build the agent for ``env``.

    Instantiates ``RandomEnsembleDoubleQLearning`` with the given environment
    and the module-level ``args`` namespace.
    """
    agent = RandomEnsembleDoubleQLearning(env, args)
    return agent

# DMControl Environment
Defines a wrapper that converts DeepMind Control Suite (DMC) environments into a Gym-compatible format.

In [3]:

# Return type of ``DMCEnv.step``: (observation, reward, terminated, truncated, info).
TimeStep = Tuple[np.ndarray, float, bool, bool, dict]

def dmc_spec2gym_space(spec):
    """Convert a dm_env spec (or a mapping of specs) into a Gym space.

    Args:
        spec: A ``dm_env.specs.BoundedArray``, a ``dm_env.specs.Array``, or a
            ``dict``/``OrderedDict`` whose values are themselves convertible.

    Returns:
        ``spaces.Dict`` for mappings (values converted recursively),
        ``spaces.Box`` with the spec's ``minimum``/``maximum`` for a
        ``BoundedArray``, or a ``spaces.Box`` spanning ``-inf..inf`` for a
        plain ``Array``.

    Raises:
        NotImplementedError: If ``spec`` is none of the supported types.
    """
    # OrderedDict is a dict subclass, so a single check covers both.
    if isinstance(spec, dict):
        # Shallow copy: the caller's mapping stays untouched and the mapping
        # type (and therefore its key order) is preserved.
        converted = copy.copy(spec)
        for key, sub_spec in spec.items():
            converted[key] = dmc_spec2gym_space(sub_spec)
        return spaces.Dict(converted)

    # Checked before Array; presumably BoundedArray is a subclass of Array,
    # in which case the order of these two tests matters.
    if isinstance(spec, dm_env.specs.BoundedArray):
        return spaces.Box(low=spec.minimum,
                          high=spec.maximum,
                          shape=spec.shape,
                          dtype=spec.dtype)

    if isinstance(spec, dm_env.specs.Array):
        return spaces.Box(low=-float('inf'),
                          high=float('inf'),
                          shape=spec.shape,
                          dtype=spec.dtype)

    raise NotImplementedError(
        f"Cannot convert spec of type {type(spec).__name__} to a Gym space"
    )


class DMCEnv(core.Env):
    """Gym-style environment backed by a dm_control / dm_env environment.

    The wrapped environment is either passed in directly (``env``) or loaded
    with ``suite.load`` from ``domain_name`` and ``task_name``.  Action and
    observation spaces are derived from the wrapped environment's specs via
    ``dmc_spec2gym_space``.
    """

    def __init__(self,
                 domain_name: Optional[str] = None,
                 task_name: Optional[str] = None,
                 env: Optional[dm_env.Environment] = None,
                 task_kwargs: Optional[Dict] = None,
                 environment_kwargs=None):
        """Create the wrapper.

        Args:
            domain_name: Domain passed to ``suite.load`` when ``env`` is None.
            task_name: Task passed to ``suite.load`` when ``env`` is None.
            env: An already constructed environment to wrap instead.
            task_kwargs: Task keyword arguments; must contain ``'random'``
                (the seed).
            environment_kwargs: Forwarded to ``suite.load``.
        """
        # ``None`` replaces the former shared ``{}`` default object.
        if task_kwargs is None:
            task_kwargs = {}
        assert 'random' in task_kwargs, 'Please specify a seed, for deterministic behaviour.'
        assert (
            env is not None
            or (domain_name is not None and task_name is not None)
        ), 'You must provide either an environment or domain and task names.'

        if env is None:
            env = suite.load(
                domain_name=domain_name,
                task_name=task_name,
                task_kwargs=task_kwargs,
                environment_kwargs=environment_kwargs,
                visualize_reward=True
            )

        self._env = env
        self.domain_name = domain_name
        self.task_name = task_name
        self.action_space = dmc_spec2gym_space(self._env.action_spec())

        self.observation_space = dmc_spec2gym_space(
            self._env.observation_spec())

    def __getattr__(self, name):
        """Delegate unknown attribute lookups to the wrapped environment."""
        # __getattr__ only runs when normal lookup fails.  If ``_env`` itself
        # is missing (e.g. the instance was created without running __init__),
        # delegating would look up ``_env`` again and recurse without end.
        if name == '_env':
            raise AttributeError(name)
        return getattr(self._env, name)

    def step(self, action: np.ndarray) -> TimeStep:
        """Advance the wrapped environment by one step.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.  A final
            step with ``discount == 1.0`` is reported as truncated, any other
            final step as terminated.
        """
        assert self.action_space.contains(action)

        time_step = self._env.step(action)
        reward = time_step.reward or 0  # a ``None`` reward is reported as 0
        done = time_step.last()
        obs = time_step.observation

        info = {}
        # bool(): the comparison result may not be a plain Python bool.
        trunc = bool(done and (time_step.discount == 1.0))
        term = bool(done and (time_step.discount != 1.0))
        if trunc:
            info['TimeLimit.truncated'] = True
        return obs, reward, term, trunc, info

    def reset(self, seed=None, options=None):
        """Reset the environment and return ``(observation, info)``.

        ``seed`` is forwarded to the base class ``reset`` only; the wrapped
        environment is reset without arguments.  ``options`` is unused.
        """
        super().reset(seed=seed)
        time_step = self._env.reset()
        info = {}
        return time_step.observation, info

    def render(self,
               mode='rgb_array',
               height: int = 84,
               width: int = 84,
               camera_id: int = 0):
        """Render a frame through the wrapped environment's physics.

        Args:
            mode: Must be ``'rgb_array'``.
            height: Frame height passed to the renderer.
            width: Frame width passed to the renderer.
            camera_id: Camera passed to the renderer.
        """
        assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
        # ``camera_id`` used to be accepted but silently dropped.
        return self._env.physics.render(height=height,
                                        width=width,
                                        camera_id=camera_id)
       

# Make Environment
Defines wrappers and utilities to preprocess Gym environments. The `make_env()` function creates either a registered Gym environment or a DeepMind Control Suite environment (named `<domain>-<task>`), applies the preprocessing wrappers, and seeds it.


In [4]:


class SinglePrecision(gym.ObservationWrapper):
    """Observation wrapper that casts observations to ``np.float32``.

    Supports ``Box`` observation spaces and ``Dict`` spaces of ``Box``-like
    sub-spaces; the advertised observation space is rebuilt from the original
    bounds and shapes.
    """

    def __init__(self, env):
        """Wrap ``env`` and rebuild its observation space.

        Raises:
            NotImplementedError: If the observation space is neither ``Box``
                nor ``Dict``.
        """
        super().__init__(env)

        obs_space = self.observation_space
        if isinstance(obs_space, Box):
            # No dtype is passed, so Box falls back to its default dtype
            # (presumably float32 -- confirm against the installed version).
            self.observation_space = Box(obs_space.low, obs_space.high,
                                         obs_space.shape)
        elif isinstance(obs_space, Dict):
            obs_spaces = copy.copy(obs_space.spaces)
            for k, v in obs_spaces.items():
                obs_spaces[k] = Box(v.low, v.high, v.shape)
            self.observation_space = Dict(obs_spaces)
        else:
            raise NotImplementedError

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """Return ``observation`` cast to ``np.float32``.

        Arrays are cast directly; for dicts a shallow copy is made and every
        value is cast.

        Raises:
            NotImplementedError: For any other observation type.  Previously
                such inputs fell through and ``None`` was returned silently.
        """
        if isinstance(observation, np.ndarray):
            return observation.astype(np.float32)
        if isinstance(observation, dict):
            observation = copy.copy(observation)
            for k, v in observation.items():
                observation[k] = v.astype(np.float32)
            return observation
        raise NotImplementedError(
            f"Unsupported observation type: {type(observation).__name__}"
        )
        
class FlattenAction(gym.ActionWrapper):
    """Action wrapper exposing a flattened view of the wrapped action space."""

    def __init__(self, env):
        super().__init__(env)
        # Advertise the flattened form of the wrapped env's action space.
        flat_space = gym.spaces.utils.flatten_space(self.env.action_space)
        self.action_space = flat_space

    def action(self, action):
        """Convert a flat action into the wrapped env's action structure."""
        structured = gym.spaces.utils.unflatten(self.env.action_space, action)
        return structured

    def reverse_action(self, action):
        """Convert a structured action into its flat representation."""
        flat = gym.spaces.utils.flatten(self.env.action_space, action)
        return flat

def make_env(env_name: str,
             seed: int,
             save_folder: Optional[str] = None,
             add_episode_monitor: bool = True,
             action_repeat: int = 1,
             frame_stack: int = 1,
             from_pixels: bool = False,
             pixels_only: bool = True,
             image_size: int = 84,
             sticky: bool = False,
             gray_scale: bool = False,
             flatten: bool = True,
             terminate_when_unhealthy: bool = True,
             action_concat: int = 1,
             obs_concat: int = 1,
             continuous: bool = True,
             ) -> gym.Env:
    """Create, wrap and seed an environment.

    ``env_name`` is looked up in the Gym registry first; if it is not
    registered it is split on ``'-'`` into ``<domain>-<task>`` and a
    ``DMCEnv`` is built with ``seed`` as its ``'random'`` task kwarg.

    Only ``env_name``, ``seed``, ``flatten`` and ``continuous`` influence the
    result; the remaining parameters are accepted but not used here.

    Returns:
        The environment wrapped (conditionally) in ``FlattenObservation`` +
        ``FlattenAction`` and ``RescaleAction(-1, 1)``, and always in
        ``SinglePrecision``.
    """
    registered_ids = tuple(gym.envs.registry.keys())

    if env_name not in registered_ids:
        domain_name, task_name = env_name.split('-')
        env = DMCEnv(domain_name=domain_name, task_name=task_name, task_kwargs={'random': seed})
    else:
        env = gym.make(env_name)

    # Dict observations are flattened; the action wrapper is applied in the
    # same case only.
    needs_flatten = flatten and isinstance(env.observation_space, gym.spaces.Dict)
    if needs_flatten:
        env = FlattenAction(gym.wrappers.FlattenObservation(env))

    if continuous:
        env = RescaleAction(env, -1.0, 1.0)

    env = SinglePrecision(env)

    # Seed the environment, its spaces, and the global NumPy / PyTorch RNGs.
    env.reset(seed=seed)
    for space in (env.action_space, env.observation_space):
        space.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return env

class Experiment:
    """Bundle the training env, evaluation env and agent for one run."""

    def __init__(self):
        self.args = args
        self.n_total_steps = 0
        self.max_steps = 100_000
        # Another environment id (e.g. 'Ant-v4') can be substituted here.
        env_name = 'cartpole-swingup'
        # Training and evaluation environments use different seeds.
        self.env = make_env(env_name, 1)
        self.eval_env = make_env(env_name, 101)
        self.agent = get_model(self.env)


# Architectures 
This cell defines the network architectures and their training wrappers, including the critic and actor networks.

In [None]:

def tonumpy(x):
    """Return the values of tensor ``x`` as a NumPy array on the CPU."""
    detached = x.detach()
    return detached.cpu().numpy()

class Critic(nn.Module):
    """Q-network plus target copy, trained with an MSE loss.

    ``arch`` is called as ``arch(n_state, n_action, args.n_hidden)`` and the
    resulting network is called as ``net(s, a)``.  The constructor writes
    ``device = "cpu"`` and ``tau = 0.005`` onto ``args``.
    """

    def __init__(self, arch, args, n_state, n_action):
        super().__init__()
        args.device = "cpu"
        self.args = args

        def build():
            return arch(n_state, n_action, args.n_hidden).to(args.device)

        # The online network is built first, then the target.
        self.model = build()
        self.target = build()
        self.init_target()
        self.loss = nn.MSELoss()
        self.optim = torch.optim.Adam(self.model.parameters(), args.learning_rate)
        self.iter = 0
        self.args.tau = 0.005

    def set_writer(self, writer):
        """Store ``writer`` on the instance."""
        self.writer = writer

    def init_target(self):
        """Copy every online parameter into the target network."""
        pairs = zip(self.target.parameters(), self.model.parameters())
        for tgt, src in pairs:
            tgt.data.copy_(src.data)

    @th.no_grad()
    def update_target(self):
        """Soft update: ``target <- (1 - tau) * target + tau * model``."""
        tau = self.args.tau
        pairs = zip(self.target.parameters(), self.model.parameters())
        for tgt, src in pairs:
            tgt.data.mul_(1.0 - tau)
            tgt.data.add_(tau * src.data)

    def Q(self, s, a):
        """Evaluate the online network on ``(s, a)``."""
        return self.model(s, a)

    def Q_t(self, s, a):
        """Evaluate the target network on ``(s, a)``."""
        return self.target(s, a)

    def update(self, s, a, y):
        """Take one optimiser step towards the Bellman target ``y``."""
        self.optim.zero_grad()
        prediction = self.Q(s, a)
        td_loss = self.loss(prediction, y)
        td_loss.backward()
        self.optim.step()
        self.iter += 1


class CriticEnsemble(nn.Module):
    """Sequentially evaluated ensemble of ``args.n_critics`` critics.

    Each member is an independent ``critictype`` instance with its own
    optimiser.  Members are held in a plain Python list.
    """

    def __init__(self, arch, args, n_state, n_action, critictype=Critic):
        super(CriticEnsemble, self).__init__()
        # ``args`` must be stored before it is read; previously
        # ``self.args.n_critics`` was accessed first and always raised.
        self.args = args
        self.n_elements = args.n_critics
        print(f"Number of elements: {self.n_elements}")
        self.critics = [
            critictype(arch, args, n_state, n_action) for _ in range(self.n_elements)
        ]
        self.gamma = 0.99
        self.iter = 0

    def __getitem__(self, item):
        """Return the ``item``-th ensemble member."""
        return self.critics[item]

    def set_writer(self, writer):
        """Store ``writer`` here and on every member."""
        self.writer = writer
        for critic in self.critics:
            critic.set_writer(writer)

    def Q(self, s, a):
        """Return a list with each member's online Q-value for ``(s, a)``."""
        return [critic.Q(s, a) for critic in self.critics]

    def Q_t(self, s, a):
        """Return a list with each member's target Q-value for ``(s, a)``."""
        return [critic.Q_t(s, a) for critic in self.critics]

    def update(self, s, a, y):
        """Update every member towards the same Bellman target ``y``."""
        for critic in self.critics:
            critic.update(s, a, y)
        self.iter += 1

    def update_target(self):
        """Soft-update every member's target network."""
        for critic in self.critics:
            critic.update_target()

    def reduce(self, q_val_list):
        """Element-wise minimum over a list of per-member Q-values."""
        return torch.stack(q_val_list, dim=-1).min(dim=-1)[0]

    @torch.no_grad()
    def get_bellman_target(self, r, sp, done, actor):
        """Compute ``r + gamma * (min_i Q_t_i(sp, a') - alpha * e) * (1 - done)``.

        ``a'`` and ``e`` come from ``actor.act(sp)``; ``alpha`` is
        ``exp(actor.log_alpha)`` when the actor has that attribute, else 0.
        ``r`` and ``done`` get a trailing axis added before broadcasting.
        """
        alpha = actor.log_alpha.exp().detach() if hasattr(actor, "log_alpha") else 0
        ap, ep = actor.act(sp)
        if ep is None:
            ep = 0
        qp_t = self.reduce(self.Q_t(sp, ap)) - alpha * ep
        # Prefer ``args.gamma``; fall back to the value stored at construction
        # when ``args`` does not define one.
        gamma = getattr(self.args, "gamma", self.gamma)
        y = r.unsqueeze(-1) + (gamma * qp_t * (1 - done.unsqueeze(-1)))
        return y
 

class ParallelCritic(nn.Module):
    """Single critic (online + target network) trained with a Huber loss.

    ``arch`` is called with ``depth=3, width=256, act="crelu",
    has_norm=True`` and the networks receive the concatenated state-action
    tensor.  The constructor writes ``device = "cpu"`` and
    ``learning_rate = 3e-4`` onto ``args``.
    """

    def __init__(self, arch, args, n_state, n_action):
        super().__init__()
        self.args = args
        self.arch = arch
        args.device = "cpu"
        args.learning_rate = 3e-4

        net_kwargs = {"depth": 3, "width": 256, "act": "crelu", "has_norm": True}
        # The online network is built first, then the target.
        self.model = arch(n_state, n_action, **net_kwargs).to(args.device)
        self.target = arch(n_state, n_action, **net_kwargs).to(args.device)
        self.init_target()
        self.loss = nn.HuberLoss()
        self.optim = torch.optim.Adam(self.model.parameters(), args.learning_rate)
        self.iter = 0

    def set_writer(self, writer):
        """Store ``writer`` on the instance."""
        self.writer = writer

    def init_target(self):
        """Copy every online parameter into the target network."""
        for tgt, src in zip(self.target.parameters(), self.model.parameters()):
            tgt.data.copy_(src.data)

    def update_target(self):
        """Soft update using ``args.tau`` (which must be set on ``args``)."""
        tau = self.args.tau
        for tgt, src in zip(self.target.parameters(), self.model.parameters()):
            tgt.data.mul_(1.0 - tau)
            tgt.data.add_(tau * src.data)

    @staticmethod
    def _join(s, a):
        """Concatenate state and action; a 0-d action is viewed as (1, 1)."""
        if a.shape == ():
            a = a.view(1, 1)
        return th.cat((s, a), -1)

    def Q(self, s, a):
        """Online Q-value for ``(s, a)``."""
        return self.model(self._join(s, a))

    def Q_t(self, s, a):
        """Target Q-value for ``(s, a)``."""
        return self.target(self._join(s, a))

    def update(self, s, a, y):
        """Take one optimiser step towards the Bellman target ``y``."""
        self.optim.zero_grad()
        prediction = self.Q(s, a)
        self.loss(prediction, y).backward()
        self.optim.step()
        self.iter += 1


class ParallelCritics(nn.Module):
    """Ensemble of 10 critics evaluated in one batched call via ``torch.func``.

    The members' parameters are stacked with ``stack_module_state`` and the
    forward pass is a ``vmap`` over ``functional_call`` on a meta-device copy
    of one member.  A single Adam optimiser trains the stacked online
    parameters.  The constructor writes ``verbose = False`` and
    ``tau = 0.005`` onto ``args``.
    """

    def __init__(self, arch, args, n_state, n_action, critictype=ParallelCritic):
        super(ParallelCritics, self).__init__()
        self.n_members = 10
        self.args = args
        self.args.verbose = False
        self.arch = arch
        self.n_state = n_state
        self.n_action = n_action
        self.critictype = critictype
        self.iter = 0
        self.args.tau = 0.005
        # One template critic supplies the loss (previously two throwaway
        # critics were built).  ``reset`` replaces ``optim`` with the
        # optimiser over the stacked parameters.
        template = self.critictype(self.arch, self.args, self.n_state, self.n_action)
        self.loss = template.loss
        self.optim = template.optim

        self.reset()

    def expand(self, x):
        """Broadcast ``x`` over a leading member axis unless it has >= 3 dims."""
        if len(x.shape) < 3:
            return x.expand(self.n_members, *x.shape)
        return x

    def reset(self):
        """(Re)create all members, stacked parameters and the optimiser."""
        self.critics = [
            self.critictype(self.arch, self.args, self.n_state, self.n_action)
            for _ in range(self.n_members)
        ]

        # Stack the networks of the members themselves.  Previously the
        # online and target stacks were taken from two further sets of
        # freshly initialised critics, so the targets did not start as copies
        # of the online networks and ``self.critics`` matched neither.
        self.critics_model = [critic.model for critic in self.critics]
        self.critics_target = [critic.target for critic in self.critics]

        self.params_model, self.buffers_model = thf.stack_module_state(
            self.critics_model
        )
        self.params_target, self.buffers_target = thf.stack_module_state(
            self.critics_target
        )

        # Meta-device copies of one member, used by ``functional_call``
        # together with the stacked parameters.
        self.base_model = copy.deepcopy(self.critics[0].model).to("meta")
        self.base_target = copy.deepcopy(self.critics[0].target).to("meta")

        def _fmodel(base_model, params, buffers, x):
            return thf.functional_call(base_model, (params, buffers), (x,))

        self.forward_model = thf.vmap(lambda p, b, x: _fmodel(self.base_model, p, b, x))
        self.forward_target = thf.vmap(
            lambda p, b, x: _fmodel(self.base_target, p, b, x)
        )
        self.optim = th.optim.Adam(
            self.params_model.values(), lr=self.args.learning_rate
        )

    def reduce(self, q_val):
        """Minimum over the leading (member) axis."""
        return q_val.min(0)[0]

    def __getitem__(self, item):
        """Return the ``item``-th member critic."""
        return self.critics[item]

    def unstack(self, target=False, single=True, net_id=None):
        """
        Extract the single parameters back to the individual members
        target: whether the target ensemble should be extracted or not
        single: whether just one member should be extracted (``net_id``,
                defaulting to the first); otherwise all members are
                extracted and ``net_id`` is ignored
        """
        params = self.params_target if target else self.params_model
        if single:
            member_ids = [0 if net_id is None else net_id]
        else:
            member_ids = range(self.n_members)
        attr = "target" if target else "model"

        for key, stacked in params.items():
            for member in member_ids:
                # Walk the dotted parameter name down to the tensor.
                tmp = getattr(self.critics[member], attr)
                for name in key.split("."):
                    tmp = getattr(tmp, name)
                tmp.data.copy_(stacked[member])

    def set_writer(self, writer):
        """Store ``writer`` (must be ``None``) here and on every member."""
        assert (
            writer is None
        ), "For now nothing else is implemented for the parallel version"
        self.writer = writer
        for critic in self.critics:
            critic.set_writer(writer)

    def _state_action(self, s, a):
        """Concatenate ``s`` and ``a`` (1-D actions become a column) and expand."""
        if len(a.shape) == 1:
            a = a.view(-1, 1)
        return self.expand(th.cat((s, a), -1))

    def Q(self, s, a):
        """Online Q-values of all members, stacked along the leading axis."""
        SA = self._state_action(s, a)
        return self.forward_model(self.params_model, self.buffers_model, SA)

    @th.no_grad()
    def Q_t(self, s, a):
        """Target Q-values of all members, stacked along the leading axis."""
        SA = self._state_action(s, a)
        return self.forward_target(self.params_target, self.buffers_target, SA)

    def update(self, s, a, y):  # y denotes bellman target
        """Take one optimiser step on the stacked online parameters."""
        self.optim.zero_grad()
        loss = self.loss(self.Q(s, a), self.expand(y))
        loss.backward()
        self.optim.step()
        self.iter += 1

    @torch.no_grad()
    def update_target(self):
        """Soft update of the stacked target parameters with ``args.tau``."""
        tau = self.args.tau
        for key, online in self.params_model.items():
            self.params_target[key].data.mul_(1.0 - tau)
            self.params_target[key].data.add_(tau * online.data)

    @torch.no_grad()
    def get_bellman_target(self, r, sp, done, actor):
        """Compute ``r + gamma * (min_i Q_t_i(sp, a') - alpha * e) * (1 - done)``.

        ``a'`` and ``e`` come from ``actor.act(sp)``; ``alpha`` is
        ``exp(actor.log_alpha)`` when the actor has that attribute, else 0.
        """
        alpha = actor.log_alpha.exp().detach() if hasattr(actor, "log_alpha") else 0
        ap, ep = actor.act(sp)
        qp = self.Q_t(sp, ap)
        qp_t = self.reduce(qp) - alpha * (ep if ep is not None else 0)
        # This class never sets ``args.gamma`` itself; fall back to 0.99 (the
        # value the other critic ensembles in this file use).
        gamma = getattr(self.args, "gamma", 0.99)
        y = r.unsqueeze(-1) + (gamma * qp_t * (1 - done.unsqueeze(-1)))
        # The target used to be printed on every call; it is now printed only
        # when verbose output is switched on.
        if getattr(self.args, "verbose", False):
            tqdm.write(f"{y = }")
        return y


class Actor(nn.Module):
    """Deterministic-style actor with an optional target network.

    ``arch`` is called with ``depth=3, width=256, act="crelu",
    has_norm=True``; the resulting network is called as
    ``net(s, is_training=...)`` and must return ``(action, extra)``.
    The constructor writes ``verbose = False`` onto ``args``.
    """

    def __init__(self, arch, args, n_state, n_action, has_target=False):
        super().__init__()
        net_kwargs = {"depth": 3, "width": 256, "act": "crelu", "has_norm": True}
        self.model = arch(n_state, n_action, **net_kwargs)
        self.optim = torch.optim.Adam(self.model.parameters(), args.learning_rate)
        self.args = args
        self.has_target = has_target
        self.args.verbose = False
        self.iter = 0
        self.is_episode_end = False
        self.states = []
        self.print_freq = 500

        if has_target:
            self.target = arch(n_state, n_action, **net_kwargs)
            self.init_target()

    def init_target(self):
        """Copy every online parameter into the target network."""
        for target_param, local_param in zip(
            self.target.parameters(), self.model.parameters()
        ):
            target_param.data.copy_(local_param.data)

    def set_writer(self, writer):
        """Store ``writer`` on the instance."""
        self.writer = writer

    def act(self, s, is_training=True):
        """Run the online network on ``s`` and return its ``(a, e)`` pair.

        While training with ``args.verbose`` set, the input states are
        recorded every ``print_freq`` iterations.
        """
        a, e = self.model(s, is_training=is_training)

        if is_training and self.args.verbose and self.iter % self.print_freq == 0:
            self.states.append(tonumpy(s))
        return a, e

    def act_target(self, s):
        """Run the target network on ``s`` and return its ``(a, e)`` pair."""
        a, e = self.target(s)
        return a, e

    def set_episode_status(self, is_end):
        """Record whether the current episode has ended."""
        self.is_episode_end = is_end

    @th.no_grad()
    def update_target(self):
        """Soft update: ``target <- (1 - tau) * target + tau * model``.

        ``tau`` is taken from an instance attribute ``tau`` if one has been
        set, otherwise from ``args.tau``.  Previously only ``self.tau`` was
        read, which this class never sets, so the call raised.
        """
        tau = getattr(self, "tau", None)
        if tau is None:
            tau = self.args.tau
        for target_param, local_param in zip(
            self.target.parameters(), self.model.parameters()
        ):
            target_param.data.mul_(1.0 - tau)
            target_param.data.add_(tau * local_param.data)

    def loss(self, s, critics):
        """Return ``(mean(-Q(s, act(s))), None)`` using the reduced critics."""
        a, _ = self.act(s)
        q_list = critics.Q(s, a)
        q = critics.reduce(q_list)
        return (-q).mean(), None

    def update(self, s, critics):
        """Take one optimiser step and, if present, soft-update the target."""
        self.optim.zero_grad()
        loss, _ = self.loss(s, critics)
        loss.backward()
        self.optim.step()

        if self.has_target:
            self.update_target()

        self.iter += 1

    def save_actor_params(self, path):
        """Save ``{"params_model": <online state_dict>}`` to ``path``."""
        params = {"params_model": self.model.state_dict()}
        torch.save(params, path)


class SoftActor(Actor):
    """Actor with a learned temperature ``alpha = exp(log_alpha)``.

    The constructor writes ``learning_rate = 3e-4`` and ``alpha = 1`` onto
    ``args`` after the base class has been initialised.
    """

    def __init__(self, arch, args, n_state, n_action, has_target=False):
        super().__init__(arch, args, n_state, n_action, has_target)
        self.H_target = -n_action[0]
        args.learning_rate = 3e-4
        args.alpha = 1
        self.device = "cpu"
        initial_log_alpha = math.log(args.alpha)
        self.log_alpha = torch.tensor(
            initial_log_alpha, requires_grad=True, device=self.device
        )
        self.optim_alpha = torch.optim.Adam([self.log_alpha], args.learning_rate)

    def update_alpha(self, e):
        """One optimiser step on ``log_alpha`` given the actor's ``e`` output."""
        self.optim_alpha.zero_grad()
        # Only ``log_alpha`` receives gradient; the (e + H_target) term is
        # detached.
        gap = (e + self.H_target).detach()
        alpha_loss = -(self.log_alpha.exp() * gap).mean()
        alpha_loss.backward()
        self.optim_alpha.step()

    def loss(self, s, critics):
        """Return ``(mean(-Q + alpha * e), e)`` for actions sampled at ``s``."""
        a, e = self.act(s)
        q = critics.reduce(critics.Q(s, a))
        alpha = self.log_alpha.exp()
        objective = (-q + alpha * e).mean()
        return objective, e

    def update(self, s, critics):
        """Step the policy optimiser, then the temperature optimiser."""
        self.optim.zero_grad()
        policy_loss, e = self.loss(s, critics)
        policy_loss.backward()
        self.optim.step()
        self.update_alpha(e)
        self.iter += 1


class CReLU(nn.Module):
    """Concatenated ReLU: ``relu`` of ``x`` and of ``-x`` joined on the last axis.

    The last dimension of the output is twice that of the input.
    """

    def forward(self, x):
        positive = F.relu(x)
        negative = F.relu(-x)
        return torch.cat((positive, negative), dim=-1)


def create_net(d_in, d_out, depth, width, act="crelu", has_norm=True, n_elements=1):
    """Build an MLP with ``depth`` linear layers.

    Args:
        d_in: Input feature size.
        d_out: Output feature size.
        depth: Number of linear layers (must be > 0).  ``depth == 1`` returns
            a bare ``nn.Linear``; larger values return an ``nn.Sequential``.
        width: Hidden layer width.
        act: ``"crelu"`` or ``"relu"``.  With ``"crelu"`` the activation
            doubles the feature size, so following layers take ``2 * width``.
        has_norm: Insert ``LayerNorm`` (without affine parameters) before each
            activation; otherwise ``nn.Identity`` is inserted.
        n_elements: Accepted for compatibility; it has no effect on the
            network.  It used to be passed as ``nn.Linear``'s third positional
            argument, which is ``bias``, and any value > 1 is truthy there.

    Raises:
        AssertionError: If ``depth`` is not positive.
        NotImplementedError: If ``act`` is not a supported name.
    """
    assert depth > 0, "Need at least one layer"

    if act == "crelu":
        activation, feat_width = CReLU, 2 * width
    elif act == "relu":
        activation, feat_width = nn.ReLU, width
    else:
        raise NotImplementedError(f"{act} is not implemented")

    def make_norm():
        if has_norm:
            return nn.LayerNorm(width, elementwise_affine=False)
        return nn.Identity()

    if depth == 1:
        return nn.Linear(d_in, d_out)

    if depth == 2:
        return nn.Sequential(
            nn.Linear(d_in, width),
            make_norm(),
            activation(),
            nn.Linear(feat_width, d_out),
        )

    # Layers are constructed in the same order as before (input, output, then
    # hidden -- including the final hidden Linear that is discarded below) so
    # that seeded initialisation is unchanged.
    in_layer = nn.Linear(d_in, width)
    out_layer = nn.Linear(feat_width, d_out)

    hidden = []
    for _ in range(depth - 1):
        hidden.append(make_norm())
        hidden.append(activation())
        hidden.append(nn.Linear(feat_width, width))
    hidden.pop()  # the output layer replaces the last hidden Linear

    return nn.Sequential(in_layer, *hidden, out_layer)



class SquashedGaussianHead(nn.Module):
    """Turn ``2 * n`` features into a sample from a tanh-squashed Gaussian.

    The first ``n`` features are the pre-tanh mean, the rest the pre-tanh
    log-variance, clamped to ``[-10, -upper_clamp]``.
    """

    def __init__(self, n, upper_clamp=-2.0):
        super().__init__()
        self._n = n
        self._upper_clamp = upper_clamp

    def forward(self, x, is_training=True):
        """Sample an action.

        Returns:
            ``(y, y_logprob)``.  In training mode ``y`` is one reparameterised
            sample and ``y_logprob`` its log-probability summed over the last
            axis (kept as a size-1 axis).  Otherwise ``y`` is the mean of 100
            samples and ``y_logprob`` is ``None``.
        """
        n = self._n
        # "bt" = before tanh
        mean_bt, raw_log_var = x[..., :n], x[..., n:]
        log_var_bt = raw_log_var.clamp(-10, -self._upper_clamp)
        std_bt = log_var_bt.exp().sqrt()
        dist = TransformedDistribution(
            Normal(mean_bt, std_bt), TanhTransform(cache_size=1)
        )

        if not is_training:
            return dist.rsample((100,)).mean(dim=0), None

        y = dist.rsample()
        return y, dist.log_prob(y).sum(dim=-1, keepdim=True)
       
    
class ActorNet(nn.Module):
    """Deterministic policy network: ``create_net`` MLP followed by ``tanh``.

    ``dim_obs`` and ``dim_act`` are indexed with ``[0]`` to obtain the input
    and output sizes.  ``upper_clamp`` is accepted but not used.
    """

    def __init__(
        self,
        dim_obs,
        dim_act,
        depth=3,
        width=256,
        act="crelu",
        has_norm=True,
        upper_clamp=None,
    ):
        super().__init__()

        net = create_net(dim_obs[0], dim_act[0], depth, width, act, has_norm)
        if isinstance(net, nn.Sequential):
            # Appending in place keeps the layer indices of the MLP.
            self.arch = net.append(nn.Tanh())
        else:
            # ``create_net`` returns a bare ``nn.Linear`` for ``depth == 1``,
            # which has no ``append``; wrap it instead of failing.
            self.arch = nn.Sequential(net, nn.Tanh())

    def forward(self, x, is_training=None):
        """Return ``(action, None)`` with the action clamped to (-1, 1).

        ``is_training`` is accepted but has no effect.
        """
        out = self.arch(x).clamp(-0.9999, 0.9999)
        return out, None


class ActorNetEnsemble(ActorNet):
    """Policy network emitting ``n_elements`` actions per observation.

    The base-class network is replaced by one whose output size is
    ``dim_act[0] * n_elements``; ``forward`` reshapes the output to
    ``(-1, n_elements, dim_act[0])``.
    """

    def __init__(
        self,
        dim_obs,
        dim_act,
        depth=3,
        width=256,
        act="crelu",
        has_norm=True,
        upper_clamp=None,
        n_elements=10
    ):
        super().__init__(dim_obs, dim_act, depth, width, act, has_norm, upper_clamp)

        self.dim_act = dim_act
        out_features = dim_act[0] * n_elements
        wide_net = create_net(dim_obs[0], out_features, depth, width, act, has_norm)
        # Overrides the single-action network built by the base class.
        self.arch = wide_net.append(nn.Tanh())
        self.n_elements = n_elements

    def forward(self, x, is_training=None):
        """Return ``(actions, None)``; ``is_training`` has no effect."""
        clamped = self.arch(x).clamp(-0.9999, 0.9999)
        actions = clamped.view(-1, self.n_elements, self.dim_act[0])
        return actions, None
    

class Critic(nn.Module):
    """Single critic (online + target network) trained with a Huber loss.

    The constructor writes its network hyper-parameters, ``device`` and
    ``learning_rate`` onto ``args`` and then builds both networks from them.
    The networks receive the concatenated state-action tensor.
    """

    def __init__(self, arch, args, n_state, n_action):
        super().__init__()
        self.args = args
        self.arch = arch

        # Fixed settings, stored on the shared ``args`` namespace.
        fixed_settings = {
            "depth_critic": 3,
            "width_critic": 256,
            "act_critic": "crelu",
            "no_norm_critic": False,
            "device": "cpu",
            "learning_rate": 3e-4,
        }
        for name, value in fixed_settings.items():
            setattr(args, name, value)

        def build():
            net = arch(
                n_state,
                n_action,
                depth=args.depth_critic,
                width=args.width_critic,
                act=args.act_critic,
                has_norm=not args.no_norm_critic,
            )
            return net.to(args.device)

        # The online network is built first, then the target.
        self.model = build()
        self.target = build()
        self.init_target()
        self.loss = nn.HuberLoss()
        self.optim = torch.optim.Adam(self.model.parameters(), args.learning_rate)
        self.iter = 0

    def set_writer(self, writer):
        """Store ``writer`` on the instance."""
        self.writer = writer

    def init_target(self):
        """Copy every online parameter into the target network."""
        for tgt, src in zip(self.target.parameters(), self.model.parameters()):
            tgt.data.copy_(src.data)

    def update_target(self):
        """Soft update using ``args.tau`` (which must be set on ``args``)."""
        tau = self.args.tau
        for tgt, src in zip(self.target.parameters(), self.model.parameters()):
            tgt.data.mul_(1.0 - tau)
            tgt.data.add_(tau * src.data)

    @staticmethod
    def _join(s, a):
        """Concatenate state and action; a 0-d action is viewed as (1, 1)."""
        if a.shape == ():
            a = a.view(1, 1)
        return th.cat((s, a), -1)

    def Q(self, s, a):
        """Online Q-value for ``(s, a)``."""
        return self.model(self._join(s, a))

    def Q_t(self, s, a):
        """Target Q-value for ``(s, a)``."""
        return self.target(self._join(s, a))

    def update(self, s, a, y):
        """Take one optimiser step towards the Bellman target ``y``."""
        self.optim.zero_grad()
        prediction = self.Q(s, a)
        self.loss(prediction, y).backward()
        self.optim.step()
        self.iter += 1


class Critics(nn.Module):
    """Ensemble of 10 critics evaluated in one batched call via ``torch.func``.

    The members' parameters are stacked with ``stack_module_state`` and the
    forward pass is a ``vmap`` over ``functional_call`` on a meta-device copy
    of one member.  A single Adam optimiser trains the stacked online
    parameters.  The constructor writes a set of fixed hyper-parameters
    (``gamma``, ``tau``, ``learning_rate``, ...) onto ``args``.
    """

    def __init__(self, arch, args, n_state, n_action, critictype=Critic):
        super().__init__()
        self.args = args
        self.args.gamma = 0.99
        self.args.tau = 0.005
        self.args.verbose = False
        self.args.buffer_size = 100000
        self.args.learning_rate = 3e-4
        self.args.depth_critic = 3
        self.args.width_critic = 256
        self.args.act_critic = "crelu"
        self.args.no_norm_critic = False
        self.args.device = "cpu"
        self.n_members = 10
        self.arch = arch
        self.n_state = n_state
        self.n_action = n_action
        self.critictype = critictype
        self.iter = 0
        # One template critic supplies the loss (previously two throwaway
        # critics were built).  ``reset`` replaces ``optim`` with the
        # optimiser over the stacked parameters.
        template = self.critictype(self.arch, self.args, self.n_state, self.n_action)
        self.loss = template.loss
        self.optim = template.optim

        self.reset()

    def expand(self, x):
        """Broadcast ``x`` over a leading member axis unless it has >= 3 dims."""
        if len(x.shape) < 3:
            return x.expand(self.n_members, *x.shape)
        return x

    def reset(self):
        """(Re)create all members, stacked parameters and the optimiser."""
        self.critics = [
            self.critictype(self.arch, self.args, self.n_state, self.n_action)
            for _ in range(self.n_members)
        ]

        # Stack the networks of the members themselves.  Previously the
        # online and target stacks were taken from two further sets of
        # freshly initialised critics, so the targets did not start as copies
        # of the online networks and ``self.critics`` matched neither.
        self.critics_model = [critic.model for critic in self.critics]
        self.critics_target = [critic.target for critic in self.critics]

        self.params_model, self.buffers_model = thf.stack_module_state(
            self.critics_model
        )
        self.params_target, self.buffers_target = thf.stack_module_state(
            self.critics_target
        )

        # Meta-device copies of one member, used by ``functional_call``
        # together with the stacked parameters.
        self.base_model = copy.deepcopy(self.critics[0].model).to("meta")
        self.base_target = copy.deepcopy(self.critics[0].target).to("meta")

        def _fmodel(base_model, params, buffers, x):
            return thf.functional_call(base_model, (params, buffers), (x,))

        self.forward_model = thf.vmap(lambda p, b, x: _fmodel(self.base_model, p, b, x))
        self.forward_target = thf.vmap(
            lambda p, b, x: _fmodel(self.base_target, p, b, x)
        )
        self.optim = th.optim.Adam(
            self.params_model.values(), lr=self.args.learning_rate
        )

    def reduce(self, q_val):
        """Minimum over the leading (member) axis."""
        return q_val.min(0)[0]

    def __getitem__(self, item):
        """Return the ``item``-th member critic."""
        return self.critics[item]

    def unstack(self, target=False, single=True, net_id=None):
        """
        Extract the single parameters back to the individual members
        target: whether the target ensemble should be extracted or not
        single: whether just one member should be extracted (``net_id``,
                defaulting to the first); otherwise all members are
                extracted and ``net_id`` is ignored
        """
        params = self.params_target if target else self.params_model
        if single:
            member_ids = [0 if net_id is None else net_id]
        else:
            member_ids = range(self.n_members)
        attr = "target" if target else "model"

        for key, stacked in params.items():
            for member in member_ids:
                # Walk the dotted parameter name down to the tensor.
                tmp = getattr(self.critics[member], attr)
                for name in key.split("."):
                    tmp = getattr(tmp, name)
                tmp.data.copy_(stacked[member])

    def set_writer(self, writer):
        """Store ``writer`` (must be ``None``) here and on every member."""
        assert (
            writer is None
        ), "For now nothing else is implemented for the parallel version"
        self.writer = writer
        for critic in self.critics:
            critic.set_writer(writer)

    def _state_action(self, s, a):
        """Concatenate ``s`` and ``a`` (1-D actions become a column) and expand."""
        if len(a.shape) == 1:
            a = a.view(-1, 1)
        return self.expand(th.cat((s, a), -1))

    def Q(self, s, a):
        """Online Q-values of all members, stacked along the leading axis."""
        SA = self._state_action(s, a)
        return self.forward_model(self.params_model, self.buffers_model, SA)

    @th.no_grad()
    def Q_t(self, s, a):
        """Target Q-values of all members, stacked along the leading axis."""
        SA = self._state_action(s, a)
        return self.forward_target(self.params_target, self.buffers_target, SA)

    def update(self, s, a, y):  # y denotes bellman target
        """Take one optimiser step on the stacked online parameters."""
        self.optim.zero_grad()
        loss = self.loss(self.Q(s, a), self.expand(y))
        loss.backward()
        self.optim.step()
        self.iter += 1

    @torch.no_grad()
    def update_target(self):
        """Soft update of the stacked target parameters with ``args.tau``."""
        tau = self.args.tau
        for key, online in self.params_model.items():
            self.params_target[key].data.mul_(1.0 - tau)
            self.params_target[key].data.add_(tau * online.data)

    @torch.no_grad()
    def get_bellman_target(self, r, sp, done, actor):
        """Compute ``r + gamma * (min_i Q_t_i(sp, a') - alpha * e) * (1 - done)``.

        ``a'`` and ``e`` come from ``actor.act(sp)``; ``alpha`` is
        ``exp(actor.log_alpha)`` when the actor has that attribute, else 0.
        """
        alpha = actor.log_alpha.exp().detach() if hasattr(actor, "log_alpha") else 0
        ap, ep = actor.act(sp)
        qp = self.Q_t(sp, ap)
        qp_t = self.reduce(qp) - alpha * (ep if ep is not None else 0)
        y = r.unsqueeze(-1) + (self.args.gamma * qp_t * (1 - done.unsqueeze(-1)))
        return y

    def save_params(self, path):
        """Write every member's state (see ``load_params``) to ``path``.

        The stacked online and target parameters are first copied back into
        the member critics.
        """
        self.unstack(target=False, single=False, net_id=None)
        self.unstack(target=True, single=False, net_id=None)
        params_list = [self.load_params(critic) for critic in self.critics]
        torch.save(params_list, path)

    def load_params(self, critic):
        """Collect (not load) the state of ``critic`` into a dict.

        Returns:
            A dict with the member's online and target ``state_dict`` and the
            shared optimiser's ``state_dict``.
        """
        return {
            "params_model": critic.model.state_dict(),
            "params_target": critic.target.state_dict(),
            "optim": self.optim.state_dict(),
        }

class ActorNetProbabilistic(nn.Module):
    """Stochastic policy network: a ``create_net`` trunk plus a squashed-Gaussian head.

    Args:
        dim_obs: Observation shape; only its first entry is used.
        dim_act: Action shape; only its first entry is used.
        depth: Forwarded to ``create_net``.
        width: Forwarded to ``create_net``.
        act: Activation name forwarded to ``create_net``.
        has_norm: Normalisation flag forwarded to ``create_net``.
        upper_clamp: Forwarded to ``SquashedGaussianHead``.
    """

    def __init__(
        self,
        dim_obs,
        dim_act,
        depth=3,
        width=256,
        act="crelu",
        has_norm=True,
        upper_clamp=-2.0,
    ):
        super().__init__()
        self.dim_act = dim_act

        n_obs, n_act = dim_obs[0], dim_act[0]
        # The trunk emits two values per action dimension (presumably the
        # mean and a spread parameter consumed by the head -- confirm).
        self.arch = create_net(n_obs, 2 * n_act, depth, width, act, has_norm)
        self.head = SquashedGaussianHead(n_act, upper_clamp)

    def forward(self, x, is_training=True):
        """Run the trunk on ``x`` and hand its features to the head."""
        features = self.arch(x)
        return self.head(features, is_training)

class CriticNet(nn.Module):
    """Q-network mapping a concatenated state-action vector to one scalar.

    Args:
        dim_obs: Observation shape; only its first entry is used.
        dim_act: Action shape; only its first entry is used.
        depth: Forwarded to ``create_net``.
        width: Forwarded to ``create_net``.
        act: Activation name forwarded to ``create_net``.
        has_norm: Normalisation flag forwarded to ``create_net``.
    """

    def __init__(
        self, dim_obs, dim_act, depth=3, width=256, act="crelu", has_norm=True
    ):
        super().__init__()

        # Input width is the state size plus the action size; output is 1.
        n_inputs = dim_obs[0] + dim_act[0]
        self.arch = create_net(n_inputs, 1, depth, width, act=act, has_norm=has_norm)

    def forward(self, xu):
        """Return the network output for the state-action input ``xu``."""
        q_value = self.arch(xu)
        return q_value



# Experience Memory
Stores past experiences and allows efficient sampling for training an agent.

In [None]:

class ExperienceMemoryTorch:
    """Fixed-size ring buffer of transitions stored in pre-allocated tensors.

    Each field in ``field_names`` owns one tensor whose first dimension is
    the buffer capacity. ``pointer`` is the next slot to be written and
    ``data_size`` is the number of valid entries (at most ``buffer_size``).
    """

    field_names = ["state", "action", "reward", "next_state", "terminated", "step"]

    def __init__(self, args):
        """Create the buffer.

        Args:
            args: Namespace providing ``dims`` (per-field shape whose first
                entry is the capacity; a bare int is accepted for 1-D
                fields) and optionally ``buffer_size`` (defaults to 100000).
        """
        self.device = "cpu"
        requested = getattr(args, "buffer_size", None)
        self.buffer_size = 100000 if requested is None else int(requested)
        self.dims = args.dims
        self.reset()

    def _field_shape(self, field):
        """Return the allocation shape for ``field`` at the current capacity."""
        spec = self.dims[field]
        # A bare int (e.g. ``(n)`` written without a trailing comma) describes
        # a 1-D field, so it has no per-entry dimensions.
        if isinstance(spec, (tuple, list)):
            trailing = tuple(int(d) for d in spec[1:])
        else:
            trailing = ()
        return (self.buffer_size, *trailing)

    def reset(self, buffer_size=None):
        """Drop all stored data and re-allocate the tensors.

        Args:
            buffer_size: Optional new capacity; the tensors are allocated
                with this capacity rather than the one recorded in ``dims``.
        """
        if buffer_size is not None:
            self.buffer_size = buffer_size
        self.data_size = 0
        self.pointer = 0
        self.memory = {
            field: th.empty(self._field_shape(field), device=self.device)
            for field in self.field_names
        }

    def add(self, state, action, reward, next_state, terminated, step):
        """Write one transition at ``pointer``, overwriting the oldest when full."""
        for field, value in zip(
            self.field_names, [state, action, reward, next_state, terminated, step]
        ):
            self.memory[field][self.pointer] = value
        self.pointer = (self.pointer + 1) % self.buffer_size
        self.data_size = min(self.data_size + 1, self.buffer_size)

    def sample_by_index(self, index):
        """Return a tuple with every field indexed by ``index`` (``field_names`` order)."""
        return tuple(self.memory[field][index] for field in self.field_names)

    def sample_by_index_fields(self, index, fields):
        """Like ``sample_by_index`` for selected fields; one field yields a bare tensor."""
        if len(fields) == 1:
            return self.memory[fields[0]][index]  # return a tensor
        return tuple(self.memory[field][index] for field in fields)

    def sample_random(self, batch_size):
        """Sample ``batch_size`` stored transitions uniformly with replacement."""
        index = th.randint(self.data_size, (batch_size,))
        return self.sample_by_index(index)

    @staticmethod
    def set_diff_1d(t1, t2, assume_unique=False):
        """
        Set difference of two 1D tensors.
        Returns the unique values in t1 that are not in t2.
        Source: https://stackoverflow.com/questions/55110047/finding-non-intersection-of-two-pytorch-tensors/72898627#72898627
        """
        if not assume_unique:
            t1 = torch.unique(t1)
            t2 = torch.unique(t2)
        return t1[(t1[:, None] != t2).all(dim=1)]

    def filter_by_nonterminal_steps_with_horizon(self, horizon):
        """Return start indices whose ``horizon``-long window crosses no terminal.

        A window ``[i, i + horizon - 1]`` is rejected when a terminal
        transition lies anywhere in it except at its last position. Only the
        ``data_size`` valid entries are inspected, in physical buffer order
        (ring wrap-around is not accounted for).
        """
        n_starts = max(self.data_size - horizon + 1, 0)
        all_indices = th.arange(n_starts, device=self.device)
        if n_starts == 0 or horizon < 2:
            return all_indices

        # Restrict to written slots: the rest of the buffer is uninitialised.
        terminated = self.memory["terminated"][: self.data_size]
        terminal_indices = th.nonzero(terminated != 0).flatten()
        if terminal_indices.numel() == 0:
            return all_indices

        # A terminal at t invalidates the starts t - horizon + 2, ..., t.
        offsets = th.arange(2 - horizon, 1, device=self.device)
        blocked = (terminal_indices[:, None] + offsets).flatten()
        blocked = blocked[(blocked >= 0) & (blocked < n_starts)]

        keep = th.ones(n_starts, dtype=th.bool, device=self.device)
        keep[blocked] = False
        return all_indices[keep]

    def sample_random_sequence_snippet(self, batch_size, sequence_length):
        """Sample ``batch_size`` windows of ``sequence_length`` consecutive entries.

        Returns:
            A list of length ``sequence_length``; element ``k`` is the tuple
            of field batches at offset ``k`` within each sampled window.

        Raises:
            ValueError: If no valid window of that length is stored.
        """
        starts = self.filter_by_nonterminal_steps_with_horizon(sequence_length)
        if starts.numel() == 0:
            raise ValueError(
                f"no terminal-free window of length {sequence_length} is available"
            )
        picks = th.randint(starts.numel(), (batch_size,))
        indices = starts[picks]
        # One batched lookup per time offset keeps the output time-major.
        return [self.sample_by_index(indices + k) for k in range(sequence_length)]

    def sample_all(self):
        """Return every valid entry, in physical buffer order."""
        return self.sample_by_index(th.arange(self.data_size))

    def clone(self, other_memory):
        """Become an independent copy of ``other_memory`` (data and write position)."""
        self.buffer_size = other_memory.buffer_size
        self.dims = other_memory.dims
        self.data_size = other_memory.data_size
        self.pointer = other_memory.pointer
        self.memory = {
            field: tensor.clone() for field, tensor in other_memory.memory.items()
        }

    def extend(self, other_memory):
        """Append the valid entries of ``other_memory``, oldest first.

        Entries are written through the ring buffer, so the oldest data here
        is overwritten once the capacity is exceeded.
        """
        count = other_memory.data_size
        if count == 0:
            return
        if count == other_memory.buffer_size:
            # Full ring: the oldest entry sits at the other buffer's pointer.
            source = (th.arange(count) + other_memory.pointer) % count
        else:
            source = th.arange(count)
        # Anything beyond our capacity would be overwritten immediately.
        source = source[-self.buffer_size:]
        count = source.numel()
        dest = (th.arange(count) + self.pointer) % self.buffer_size
        for field in self.field_names:
            self.memory[field][dest] = other_memory.memory[field][source].to(
                self.device
            )
        self.pointer = (self.pointer + count) % self.buffer_size
        self.data_size = min(self.data_size + count, self.buffer_size)

    def __len__(self):
        return self.data_size

    @property
    def size(self):
        """Number of valid entries currently stored."""
        return self.data_size

    def save(self, path):
        """Save the raw field tensors to ``<path>/experience_memory.pt``."""
        th.save(self.memory, os.path.join(path, "experience_memory.pt"))

    def get_last_observation(self):
        """Return the most recently added transition (each field has length 1).

        Raises:
            IndexError: If nothing has been stored yet.
        """
        if self.data_size == 0:
            raise IndexError("experience memory is empty")
        # The newest entry is just behind the write pointer, not at the
        # physical end of the buffer.
        return self.sample_by_index([(self.pointer - 1) % self.buffer_size])

    def get_last_observations(self, batch_size):
        """Return the newest ``batch_size`` transitions, oldest first.

        At most ``data_size`` entries are returned.
        """
        count = min(batch_size, self.data_size)
        index = (self.pointer - count + th.arange(count)) % self.buffer_size
        return self.sample_by_index(index)


# Agent
Defines a reinforcement learning agent that interacts with the environment, stores experiences, and updates its models using soft or hard updates.

In [None]:

# Re-parse the command line for this cell. `parse_known_args` returns the
# parsed namespace together with the list of arguments it did not recognise,
# so unexpected notebook/kernel arguments do not raise.
args, unknown = parser.parse_known_args()

def totorch(x, dtype=th.float32, device="cpu"):
    """Convert ``x`` into a tensor with the requested dtype and device.

    Args:
        x: Array-like data (list, NumPy array, tensor, scalar, ...).
        dtype: Target dtype; defaults to ``th.float32``.
        device: Target device; defaults to ``"cpu"``.

    Returns:
        The result of ``th.as_tensor`` for the given arguments.
    """
    tensor = th.as_tensor(x, dtype=dtype, device=device)
    return tensor


class Agent(nn.Module):
    """Base class for agents: holds the environment, hyper-parameters and replay memory.

    Subclasses are expected to provide ``learn`` and ``select_action`` and to
    create ``self.actor`` (used by ``store_transition``).
    """

    def __init__(self, env, args):
        """Read the environment's spaces and allocate the experience memory.

        Args:
            env: Environment exposing ``observation_space`` and
                ``action_space`` with ``shape``, ``low`` and ``high``.
            args: Mutable namespace; ``buffer_size`` and ``dims`` are written
                onto it here.
        """
        super().__init__()
        self.args = args
        self.device = "cpu"
        args.buffer_size = 100000
        self.tau = 0.005
        self.gamma = 0.99
        self.env = env
        self.dim_obs, self.dim_act = (
            self.env.observation_space.shape,
            self.env.action_space.shape,
        )
        print(f"INFO: dim_obs = {self.dim_obs} dim_act = {self.dim_act}")
        self.dim_obs_flat, self.dim_act_flat = np.prod(self.dim_obs), np.prod(
            self.dim_act
        )
        self._u_min = totorch(self.env.action_space.low, device=self.device)
        self._u_max = totorch(self.env.action_space.high, device=self.device)
        self._x_min = totorch(self.env.observation_space.low, device=self.device)
        self._x_max = totorch(self.env.observation_space.high, device=self.device)

        self._gamma = self.gamma
        self._tau = self.tau

        # Per-field buffer shapes; scalar fields are genuine 1-tuples.
        args.dims = {
            "state": (args.buffer_size, self.dim_obs_flat),
            "action": (args.buffer_size, self.dim_act_flat),
            "next_state": (args.buffer_size, self.dim_obs_flat),
            "reward": (args.buffer_size,),
            "terminated": (args.buffer_size,),
            "step": (args.buffer_size,),
        }

        self.experience_memory = ExperienceMemoryTorch(args)

    def _display_name(self):
        """Return the agent's name for messages (``_agent_name`` or the class name)."""
        return getattr(self, "_agent_name", type(self).__name__)

    def set_writer(self, writer):
        """Store the writer object on the agent."""
        self.writer = writer

    def _soft_update(self, local_model, target_model):
        """Polyak-average ``local_model`` into ``target_model`` in place.

        Uses ``args.tau`` when the namespace defines it and otherwise the
        agent's own ``_tau``.
        """
        tau = self.args.tau if hasattr(self.args, "tau") else self._tau
        for target_param, local_param in zip(
            target_model.parameters(), local_model.parameters()
        ):
            target_param.data.mul_(1.0 - tau)
            target_param.data.add_(tau * local_param.data)

    def _hard_update(self, local_model, target_model):
        """Copy every parameter of ``local_model`` into ``target_model``."""
        for target_param, local_param in zip(
            target_model.parameters(), local_model.parameters()
        ):
            target_param.data.copy_(local_param.data)

    def learn(self, max_iter=1):
        """Abstract: subclasses implement the training update."""
        raise NotImplementedError(
            f"learn() not implemented for {self._display_name()} agent"
        )

    def select_action(self, warmup=False, exploit=False):
        """Abstract: subclasses implement action selection."""
        raise NotImplementedError(
            f"select_action() not implemented for {self._display_name()} agent"
        )

    def store_transition(self, s, a, r, sp, terminated, truncated, step):
        """Record a transition and tell the actor whether the episode ended.

        ``truncated`` is not stored in the memory; it only contributes to the
        episode-ended flag passed to the actor.
        """
        self.experience_memory.add(s, a, r, sp, terminated, step)
        self.actor.set_episode_status(terminated or truncated)


# Actor-Critic Agent
Creates an actor-critic agent that learns by updating the critic network and adjusting the actor network based on feedback from the critic.

In [8]:

class ActorCritic(Agent):
    """Agent that pairs an actor with a critic ensemble and trains both from replay."""

    _agent_name = "AC"

    def __init__(
        self,
        env,
        args,
        actor_nn,
        critic_nn,
        CriticEnsembleType=CriticEnsemble,
        ActorType=Actor,
    ):
        """Build the critic ensemble and the actor from the given network classes.

        Args:
            env: Environment forwarded to ``Agent``.
            args: Configuration namespace; ``batch_size`` is set to 256 here.
            actor_nn: Network class handed to ``ActorType``.
            critic_nn: Network class handed to ``CriticEnsembleType``.
            CriticEnsembleType: Wrapper class for the critics.
            ActorType: Wrapper class for the actor.
        """
        super().__init__(env, args)
        self.critics = CriticEnsembleType(critic_nn, args, self.dim_obs, self.dim_act)
        self.actor = ActorType(actor_nn, args, self.dim_obs, self.dim_act)
        self.n_iter = 0
        self.policy_delay = 1
        self.args.batch_size = 256

    def set_writer(self, writer):
        """Store the writer and pass it on to the actor and the critics."""
        self.writer = writer
        for component in (self.actor, self.critics):
            component.set_writer(writer)

    def learn(self, max_iter=5):
        """Run up to ``max_iter`` update rounds on random replay batches.

        Does nothing (returns ``None``) while the memory holds fewer than
        ``args.batch_size`` transitions.
        """
        memory = self.experience_memory
        if len(memory) < self.args.batch_size:
            return None

        for _ in range(max_iter):
            s, a, r, sp, done, _step = memory.sample_random(self.args.batch_size)

            bellman_target = self.critics.get_bellman_target(r, sp, done, self.actor)
            self.critics.update(s, a, bellman_target)

            # The actor is refreshed only every ``policy_delay`` iterations.
            actor_due = self.n_iter % self.policy_delay == 0
            if actor_due:
                self.actor.update(s, self.critics)

            self.critics.update_target()
            self.n_iter += 1

    @torch.no_grad()
    def select_action(self, s, is_training=True):
        """Return the actor's action for ``s``; the second output of ``act`` is dropped."""
        action, _ = self.actor.act(s, is_training=is_training)
        return action

    def Q_value(self, s, a):
        """Return the first critic's Q-value for a single state-action pair as a float."""
        # Add a leading batch dimension to unbatched (1-D) inputs.
        s, a = (x[None] if len(x.shape) == 1 else x for x in (s, a))

        if isinstance(self.critics, ParallelCritics):
            self.critics.unstack(target=False, single=True)

        first_critic = self.critics[0]
        return first_critic.Q(s, a).item()


# REDQ
REDQ agent used to train the actor-critic agent and to collect environment transitions for constructing the train and test datasets.

In [9]:
class REDQCritics(Critics):
    """Critic ensemble whose target is the minimum over a random subset of members."""

    def __init__(self, arch, args, n_state, n_action, critictype=Critic):
        """Build the ensemble via the base class and fix the subset size to 2."""
        super().__init__(arch, args, n_state, n_action, critictype)
        self.args = args
        # Number of ensemble members entering each target minimum.
        self.n_in_target = 2

    def reduce(self, q_val_list):
        """Return the element-wise minimum over a random subset of member Q-values.

        Args:
            q_val_list: Per-member Q-values, indexable by member id.

        Returns:
            The minimum over ``n_in_target`` distinct, randomly chosen members.
        """
        # ``randperm`` draws without replacement. Drawing with replacement can
        # select the same member twice, which collapses the minimum to a
        # single critic and removes the pessimism the subset is meant to add.
        i_targets = torch.randperm(self.n_members)[: self.n_in_target]
        chosen = torch.stack([q_val_list[i] for i in i_targets], dim=-1)
        return chosen.min(-1)[0]
        
class RandomEnsembleDoubleQLearning(ActorCritic):
    """REDQ agent: an ``ActorCritic`` wired with ``REDQCritics`` and ``SoftActor``."""

    _agent_name = "REDQ"

    def __init__(self, env, args, actor_nn=ActorNetProbabilistic, critic_nn=CriticNet):
        """Build the agent and set the exploration noise to 0.1.

        Args:
            env: Environment forwarded to ``ActorCritic``.
            args: Configuration namespace; ``explore_noise`` is written to it.
            actor_nn: Network class for the actor.
            critic_nn: Network class for each critic.
        """
        super().__init__(
            env,
            args,
            actor_nn,
            critic_nn,
            CriticEnsembleType=REDQCritics,
            ActorType=SoftActor,
        )
        self.args.explore_noise = 0.1
        # The actor keeps its own copy of the noise level in ``c``.
        noise_level = self.args.explore_noise
        self.actor.c = noise_level
        self.args = args

# Control Experiments
Manage the training and evaluation of an agent in a control task, and its interactions with the environment.

Set "saveparams" to True in order to save the policy parameters during agent training. After training has completed, set "validationrounds" to True and save 200 validation rounds to collect the dataset (100 for training and 100 for testing).

In [None]:
from types import SimpleNamespace

class ControlExperiment(Experiment):
    """Training and evaluation driver for a control task with a hard-coded configuration.

    ``self.env``, ``self.eval_env`` and ``self.agent`` are used but not
    created here; presumably the ``Experiment`` base class provides them.
    """

    def __init__(self):
        """Initialise the base experiment and install the default settings."""
        super().__init__()
        # NOTE: this replaces any ``args`` object the base class may have set.
        self.args = SimpleNamespace()
        self.args.verbose = False
        self.eval_reward = 0
        self.late_eval_reward = 0
        self.is_break = False
        self.device_str = "cpu"
        self.optimizer_args = {"lr": 4e-3}
        self.n_total_steps = 0
        self.args.max_steps = 300000
        self.args.eval_frequency = 2000
        self.args.eval_episodes = 10
        self.args.gamma = 0.99
        self.args.warmup_steps = 10000
        self.args.learn_frequency = 1
        self.args.max_iter = 5
        self.args.n_critics = 10
        self.args.alpha = 1
        self.args.progress = False
        self.args.reset_frequency = 0
        self.args.depth_critic = 3
        self.args.width_critic = 256
        self.args.device = "cpu"
        self.args.saveparams = False
        self.args.validationrounds = False

    def train(self):
        """Run the interaction/learning loop for ``args.max_steps`` steps.

        Random actions are used for the first ``args.warmup_steps`` steps;
        afterwards the agent acts and learns every ``args.learn_frequency``
        steps. Evaluation runs every ``args.eval_frequency`` steps and once
        more at the end.

        Returns:
            dict: ``episode_rewards`` and ``episode_steps`` (indexed by
            episode) and ``step_rewards`` (``(episode, step, reward)`` per
            environment step).
        """
        information_dict = {
            "episode_rewards": th.zeros(1000000),
            "episode_steps": th.zeros(1000000),
            "step_rewards": np.empty((2 * self.args.max_steps), dtype=object),
        }

        s, _ = self.env.reset()
        s = totorch(s)
        r_cum = np.zeros(1)
        episode = 0
        self.last_saved_step = 0

        for step in tqdm(
            range(self.args.max_steps), leave=True, disable=not self.args.progress
        ):
            if (
                step > self.args.warmup_steps
                and self.args.reset_frequency > 0
                and step % self.args.reset_frequency == 0
            ):
                self.agent.critics.reset()
                self.agent.to(self.args.device)

            if step % self.args.eval_frequency == 0:
                self.eval(step)

            if step < self.args.warmup_steps:
                # Warm-up: fill the memory with random actions.
                a = self.env.action_space.sample()
                a = totorch(np.clip(a, -1.0, 1.0), device=self.args.device)
            else:
                a = self.agent.select_action(s.to("cpu")).clip(-1.0, 1.0)

            sp, r, done, truncated, info = self.env.step(tonumpy(a))
            sp = totorch(sp, device=self.args.device)

            if self.args.verbose and "sp" in self.args.env:
                print("X pos: ", info["x_pos"], "Action norm: ", info["action_norm"])
                # TODO: Write this instead into a file!

            self.agent.store_transition(s, a, r, sp, done, truncated, step + 1)

            information_dict["step_rewards"][self.n_total_steps + step] = (
                episode,
                step,
                r,
            )

            s = sp  # Update state
            r_cum += r  # Update cumulative reward

            if (
                step >= self.args.warmup_steps
                and (step % self.args.learn_frequency) == 0
            ):
                self.agent.learn(max_iter=self.args.max_iter)

            if self.args.saveparams:
                # Next multiple of 50000 after the last saved milestone.
                next_save_step = ((self.last_saved_step // 50000) + 1) * 50000

                # Save on the first step at/after each milestone and on the
                # very last step.
                if next_save_step <= step or step == self.args.max_steps - 1:
                    checkpoint_dir = (
                        f"_logs/{self.args.env}/{self.args.model}"
                        f"/seed_0{self.args.seed}"
                    )
                    # Nothing else in this class creates the run folder.
                    os.makedirs(checkpoint_dir, exist_ok=True)
                    self.agent.critics.save_params(
                        f"{checkpoint_dir}/params_{step}.pth"
                    )
                    self.agent.actor.save_actor_params(
                        f"{checkpoint_dir}/Actor_params_{step}.pth"
                    )
                    self.last_saved_step = next_save_step

            if done or truncated:
                information_dict["episode_rewards"][episode] = r_cum.item()
                # Records the global step at which the episode ended.
                information_dict["episode_steps"][episode] = step
                print('Episode:', episode, ' Reward: %.3f' % np.mean(r_cum), 'N-steps: %d' % step)
                s, _ = self.env.reset()
                s = totorch(s, device=self.args.device)
                r_cum = np.zeros(1)
                episode += 1

        # Computed from ``max_steps`` so a zero-length run does not depend on
        # the (then undefined) loop variable.
        self.eval(max(self.args.max_steps - 1, 0))
        return information_dict

    @torch.no_grad()
    def eval(self, n_step):
        """Roll out ``args.eval_episodes`` episodes with the non-training policy.

        The agent is switched to eval mode for the rollouts and is always
        switched back to train mode, even if an episode raises.

        Args:
            n_step: Training step at which evaluation happens (not used in
                the computation).

        Returns:
            dict: ``returns`` (undiscounted episode returns),
            ``discounted_returns``, ``q_values`` (critic estimate at each
            episode's first state), ``infos`` (per-episode env info dicts)
            and ``performance`` (trajectory data, filled only when
            ``args.validationrounds`` is true).
        """
        self.agent.eval()
        results = th.zeros(self.args.eval_episodes)
        q_values = th.zeros((self.args.eval_episodes, 2))
        avg_reward = th.zeros(self.args.eval_episodes)
        collect_infos = {}
        performance_eval_dict = {
            # One slot per evaluation episode is all that is ever written.
            "episode_info": np.empty(self.args.eval_episodes, dtype=object),
            "trajectory": [],
        }

        try:
            for episode in range(self.args.eval_episodes):
                collect_infos[episode] = []
                s, info = self.eval_env.reset()
                s = totorch(s)
                step = 0
                a = self.agent.select_action(s, is_training=False)
                q_values[episode] = self.agent.Q_value(
                    totorch(s, device=self.args.device),
                    totorch(a, device=self.args.device),
                )
                done = False

                while not done:
                    s = totorch(s)
                    a = self.agent.select_action(s, is_training=False)

                    sp, r, term, trunc, info = self.eval_env.step(tonumpy(a))
                    collect_infos[episode].append(info)

                    if self.args.validationrounds:
                        performance_eval_dict["trajectory"].append(
                            (episode, step, s, a, sp, r, term, trunc, info)
                        )

                    done = term or trunc
                    s = totorch(sp, device=self.args.device)
                    results[episode] += r
                    avg_reward[episode] += self.args.gamma**step * r
                    step += 1

                if self.args.validationrounds:
                    performance_eval_dict["episode_info"][episode] = (
                        episode,
                        avg_reward[episode],
                    )
        finally:
            self.agent.actor.states = []
            self.agent.train()

        return {
            "returns": results,
            "discounted_returns": avg_reward,
            "q_values": q_values,
            "infos": collect_infos,
            "performance": performance_eval_dict,
        }


# Run All

In [None]:
# Build the experiment with the settings hard-coded in
# `ControlExperiment.__init__` and run the full training loop.
exp = ControlExperiment()
exp.train()