<a href="https://colab.research.google.com/github/aditijha2000/FRE/blob/main/FRE_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install jax jaxlib



In [2]:
!pip install ml-collections

Collecting ml-collections
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/77.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.9/77.9 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: ml-collections
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
  Created wheel for ml-collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94506 sha256=0fc9b3966b2041571b41e95b65f7bd9d80dde9a25f4a63832237be8dc793781d
  Stored in directory: /root/.cache/pip/wheels/7b/89/c9/a9b87790789e94aadcfc393c283e3ecd5ab916aed0a31be8fe
Successfully built ml-collections
Installing collected packages: ml-collections
Successfully installed ml-collections-0.1.1


In [3]:
!pip install dm_control

Collecting dm_control
  Downloading dm_control-1.0.18-py3-none-any.whl (56.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
Collecting dm-env (from dm_control)
  Downloading dm_env-1.6-py3-none-any.whl (26 kB)
Collecting glfw (from dm_control)
  Downloading glfw-2.7.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38-none-manylinux2014_x86_64.whl (211 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.8/211.8 kB[0m [31m26.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting labmaze (from dm_control)
  Downloading labmaze-1.0.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.9/4.9 MB[0m [31m88.3 MB/s[0m eta [36m0:00:00[0m
Collecting mujoco>=3.1.4 (from dm_control)
  Downloading mujoco-3.1.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
!pip install dm-env



In [6]:
#fre\common\typing

from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
import numpy as np
import jax.numpy as jnp
import flax

PRNGKey = Any
Params = flax.core.FrozenDict[str, Any]
PRNGKey = Any
Shape = Sequence[int]
Dtype = Any  # this could be a real type?
InfoDict = Dict[str, float]
Array = Union[np.ndarray, jnp.ndarray]
Data = Union[Array, Dict[str, "Data"]]
Batch = Dict[str, Data]
ModuleMethod = Union[
    str, Callable, None
]  # A method to be passed into TrainState.__call__


#fre\common\dataset
###############################
#
#  Dataset Pytrees for offline data, replay buffers, etc.
#
###############################

import numpy as np
#from fre.common.typing import Data, Array
from flax.core.frozen_dict import FrozenDict
from jax import tree_util


def get_size(data: Data) -> int:
    sizes = tree_util.tree_map(lambda arr: len(arr), data)
    return max(tree_util.tree_leaves(sizes))


class Dataset(FrozenDict):
    """
    A class for storing (and retrieving batches of) data in nested dictionary format.

    Example:
        dataset = Dataset({
            'observations': {
                'image': np.random.randn(100, 28, 28, 1),
                'state': np.random.randn(100, 4),
            },
            'actions': np.random.randn(100, 2),
        })

        batch = dataset.sample(32)
        # Batch should have nested shape: {
        # 'observations': {'image': (32, 28, 28, 1), 'state': (32, 4)},
        # 'actions': (32, 2)
        # }
    """

    @classmethod
    def create(
        cls,
        observations: Data,
        actions: Array,
        rewards: Array,
        masks: Array,
        next_observations: Data,
        freeze=True,
        **extra_fields
    ):
        data = {
            "observations": observations,
            "actions": actions,
            "rewards": rewards,
            "masks": masks,
            "next_observations": next_observations,
            **extra_fields,
        }
        # Force freeze
        if freeze:
            tree_util.tree_map(lambda arr: arr.setflags(write=False), data)
        return cls(data)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.size = get_size(self._dict)

    def sample(self, batch_size: int, indx=None):
        """
        Sample a batch of data from the dataset. Use `indx` to specify a specific
        set of indices to retrieve. Otherwise, a random sample will be drawn.

        Returns a dictionary with the same structure as the original dataset.
        """
        if indx is None:
            indx = np.random.randint(self.size, size=batch_size)
        return self.get_subset(indx)

    def get_subset(self, indx):
        return tree_util.tree_map(lambda arr: arr[indx], self._dict)


class ReplayBuffer(Dataset):
    """
    Dataset where data is added to the buffer.

    Example:
        example_transition = {
            'observations': {
                'image': np.random.randn(28, 28, 1),
                'state': np.random.randn(4),
            },
            'actions': np.random.randn(2),
        }
        buffer = ReplayBuffer.create(example_transition, size=1000)
        buffer.add_transition(example_transition)
        batch = buffer.sample(32)

    """

    @classmethod
    def create(cls, transition: Data, size: int):
        def create_buffer(example):
            example = np.array(example)
            return np.zeros((size, *example.shape), dtype=example.dtype)

        buffer_dict = tree_util.tree_map(create_buffer, transition)
        return cls(buffer_dict)

    @classmethod
    def create_from_initial_dataset(cls, init_dataset: dict, size: int):
        def create_buffer(init_buffer):
            buffer = np.zeros((size, *init_buffer.shape[1:]), dtype=init_buffer.dtype)
            buffer[: len(init_buffer)] = init_buffer
            return buffer

        buffer_dict = tree_util.tree_map(create_buffer, init_dataset)
        dataset = cls(buffer_dict)
        dataset.size = dataset.pointer = get_size(init_dataset)
        return dataset

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.max_size = get_size(self._dict)
        self.size = 0
        self.pointer = 0

    def add_transition(self, transition):
        def set_idx(buffer, new_element):
            buffer[self.pointer] = new_element

        tree_util.tree_map(set_idx, self._dict, transition)
        self.pointer = (self.pointer + 1) % self.max_size
        self.size = max(self.pointer, self.size)

    def clear(self):
        self.size = self.pointer = 0



  # fre/common/evaluation

###############################
#
#  Tools for evaluating policies in environments.
#
###############################


from typing import Dict
import jax
import gym
import numpy as np
from collections import defaultdict
import time
#import wandb


def flatten(d, parent_key="", sep="."):
    """
    Helper function that flattens a dictionary of dictionaries into a single dictionary.
    E.g: flatten({'a': {'b': 1}}) -> {'a.b': 1}
    """
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if hasattr(v, "items"):
            items.extend(flatten(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


def add_to(dict_of_lists, single_dict):
    for k, v in single_dict.items():
        dict_of_lists[k].append(v)


def evaluate(policy_fn, env: gym.Env, num_episodes: int, record_video : bool = False,
             return_trajectories=False, clip_return_at_goal=False, binary_return=False, use_discrete_xy=False, clip_margin=0):
    print("Clip return at goal is", clip_return_at_goal)
    stats = defaultdict(list)
    frames = []
    trajectories = []
    for i in range(num_episodes):
        now = time.time()
        trajectory = defaultdict(list)
        ob_list = []
        ac_list = []
        observation, done = env.reset(), False
        ob_list.append(observation)
        while not done:
            if use_discrete_xy:
                #import fre.common.envs.d4rl.d4rl_ant as d4rl_ant
                ob_input = d4rl_ant.discretize_obs(observation)
            else:
                ob_input = observation
            action = policy_fn(ob_input)
            action = np.array(action)
            next_observation, r, done, info = env.step(action)
            add_to(stats, flatten(info))

            if type(observation) is dict:
                obs_pure = observation['observation']
                next_obs_pure = next_observation['observation']
            else:
                obs_pure = observation
                next_obs_pure = next_observation
            transition = dict(
                observation=obs_pure,
                next_observation=next_obs_pure,
                action=action,
                reward=r,
                done=done,
                info=info,
            )
            observation = next_observation
            ob_list.append(observation)
            ac_list.append(action)
            add_to(trajectory, transition)

            if i <= 3 and record_video:
                frames.append(env.render(mode="rgb_array"))
        add_to(stats, flatten(info, parent_key="final"))
        trajectories.append(trajectory)
        print("Finished Episode", i, "in", time.time() - now, "seconds")

    if clip_return_at_goal and 'episode.return' in stats:
        print("Episode finished. Return is {}. Length is {}.".format(stats['episode.return'], stats['episode.length']))
        stats['episode.return'] = np.clip(np.array(stats['episode.length']) + np.array(stats['episode.return']) - clip_margin, 0, 1) # Goal is a binary indicator.
        print("Clipped return is {}.".format(stats['episode.return']))
    elif binary_return and 'episode.return' in stats:
        # Assume that the reward is either 0 or 1 at each timestep.
        print("Episode finished. Return is {}. Length is {}.".format(stats['episode.return'], stats['episode.length']))
        stats['episode.return'] = np.clip(np.array(stats['episode.return']), 0, 1)
        print("Clipped return is {}.".format(stats['episode.return']))

    if 'episode.return' in stats:
        print("Episode finished. Return is {}. Length is {}.".format(stats['episode.return'], stats['episode.length']))

    for k, v in stats.items():
        stats[k] = np.mean(v)

    if record_video:
        stacked = np.stack(frames)
        stacked = stacked.transpose(0, 3, 1, 2)
        while stacked.shape[2] > 160:
            stacked = stacked[:, :, ::2, ::2]
        stats['video'] = wandb.Video(stacked, fps=60)

    if return_trajectories:
        return stats, trajectories
    else:
        return stats

  #fre/common/train_state
  ###############################
#
#  Structures for managing training of flax networks.
#
###############################

#from fre.common.typing import *
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import tree_util
import optax
import functools

import gym

nonpytree_field = functools.partial(flax.struct.field, pytree_node=False)


def shard_batch(batch):
    d = jax.local_device_count()

    def reshape(x):
        assert (
            x.shape[0] % d == 0
        ), f"Batch size needs to be divisible by # devices, got {x.shape[0]} and {d}"
        return x.reshape((d, x.shape[0] // d, *x.shape[1:]))

    return tree_util.tree_map(reshape, batch)


def target_update(
    model: "TrainState", target_model: "TrainState", tau: float
) -> "TrainState":
    new_target_params = jax.tree_map(
        lambda p, tp: p * tau + tp * (1 - tau), model.params, target_model.params
    )
    return target_model.replace(params=new_target_params)


class TrainState(flax.struct.PyTreeNode):
    """
    Core abstraction of a model in this repository.

    Creation:
    ```
        model_def = nn.Dense(12) # or any other flax.linen Module
        params = model_def.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))['params']
        model = TrainState.create(model_def, params, tx=None) # Optionally, pass in an optax optimizer
    ```

    Usage:
    ```
        y = model(jnp.ones((1, 4))) # By default, uses the `__call__` method of the model_def and params stored in TrainState
        y = model(jnp.ones((1, 4)), params=params) # You can pass in params (useful for gradient computation)
        y = model(jnp.ones((1, 4)), method=method) # You can apply a different method as well
    ```

    More complete example:
    ```
        def loss(params):
            y_pred = model(x, params=params)
            return jnp.mean((y - y_pred) ** 2)

        grads = jax.grad(loss)(model.params)
        new_model = model.apply_gradients(grads=grads) # Alternatively, new_model = model.apply_loss_fn(loss_fn=loss)
    ```
    """

    step: int
    apply_fn: Callable[..., Any] = nonpytree_field()
    model_def: Any = nonpytree_field()
    params: Params
    tx: Optional[optax.GradientTransformation] = nonpytree_field()
    opt_state: Optional[optax.OptState] = None

    @classmethod
    def create(
        cls,
        model_def: nn.Module,
        params: Params,
        tx: Optional[optax.GradientTransformation] = None,
        **kwargs,
    ) -> "TrainState":
        if tx is not None:
            opt_state = tx.init(params)
        else:
            opt_state = None

        return cls(
            step=1,
            apply_fn=model_def.apply,
            model_def=model_def,
            params=params,
            tx=tx,
            opt_state=opt_state,
            **kwargs,
        )

    def __call__(
        self,
        *args,
        params=None,
        extra_variables: dict = None,
        method: ModuleMethod = None,
        **kwargs,
    ):
        """
        Internally calls model_def.apply_fn with the following logic:

        Arguments:
            params: If not None, use these params instead of the ones stored in the model.
            extra_variables: Additional variables to pass into apply_fn
            method: If None, use the `__call__` method of the model_def. If a string, uses
                the method of the model_def with that name (e.g. 'encode' -> model_def.encode).
                If a function, uses that function.

        """
        if params is None:
            params = self.params

        variables = {"params": params}

        if extra_variables is not None:
            variables = {**variables, **extra_variables}

        if isinstance(method, str):
            method = getattr(self.model_def, method)

        return self.apply_fn(variables, *args, method=method, **kwargs)

    def do(self, method):
        return functools.partial(self, method=method)

    def apply_gradients(self, *, grads, **kwargs):
        """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.

        Note that internally this function calls `.tx.update()` followed by a call
        to `optax.apply_updates()` to update `params` and `opt_state`.

        Args:
            grads: Gradients that have the same pytree structure as `.params`.
            **kwargs: Additional dataclass attributes that should be `.replace()`-ed.

        Returns:
            An updated instance of `self` with `step` incremented by one, `params`
            and `opt_state` updated by applying `grads`, and additional attributes
            replaced as specified by `kwargs`.
        """
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)

        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=new_opt_state,
            **kwargs,
        )

    def apply_loss_fn(self, *, loss_fn, pmap_axis=None, has_aux=False):
        """
        Takes a gradient step towards minimizing `loss_fn`. Internally, this calls
        `jax.grad` followed by `TrainState.apply_gradients`. If pmap_axis is provided,
        additionally it averages gradients (and info) across devices before performing update.
        """
        if has_aux:
            grads, info = jax.grad(loss_fn, has_aux=has_aux)(self.params)
            if pmap_axis is not None:
                grads = jax.lax.pmean(grads, axis_name=pmap_axis)
                info = jax.lax.pmean(info, axis_name=pmap_axis)

            return self.apply_gradients(grads=grads), info

        else:
            grads = jax.grad(loss_fn, has_aux=has_aux)(self.params)
            if pmap_axis is not None:
                grads = jax.lax.pmean(grads, axis_name=pmap_axis)
            return self.apply_gradients(grads=grads)

class NormalizeActionWrapper(gym.Wrapper):
    """A wrapper that maps actions from [-1,1] to [low, hgih]."""
    def __init__(self, env):
        super().__init__(env)
        self.active = type(env.action_space) == gym.spaces.Box
        if self.active:
            self.action_low = env.action_space.low
            self.action_high = env.action_space.high
            self.action_scale = (self.action_high - self.action_low) * 0.5
            self.action_mid = (self.action_high + self.action_low) * 0.5
            print("Normalizing Action Space from [{}, {}] to [-1, 1]".format(self.action_low[0], self.action_high[0]))
    def step(self, action):
        if self.active:
            action = np.clip(action, -1, 1)
            action = action * self.action_scale
            action = action + self.action_mid
        return self.env.step(action)

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

  # fre/common/utils
  ###############################
#
#  Some shared utility functions
#
###############################

import jax

def supply_rng(f, rng=jax.random.PRNGKey(0)):
    """
    Wraps a function to supply jax rng. It will remember the rng state for that function.
    """
    def wrapped(*args, **kwargs):
        nonlocal rng
        rng, key = jax.random.split(rng)
        return f(*args, seed=key, **kwargs)

    return wrapped

#fre/common/wandb
"""WandB logging helpers.

Run setup_wandb(hyperparam_dict, ...) to initialize wandb logging.
See default_wandb_config() for a list of available configurations.

We recommend the following workflow (see examples/mujoco/d4rl_iql.py for a more full example):

    from ml_collections import config_flags
    from jaxrl_m.wandb import setup_wandb, default_wandb_config
    import wandb

    # This line allows us to change wandb config flags from the command line
    config_flags.DEFINE_config_dict('wandb', default_wandb_config(), lock_config=False)

    ...
    def main(argv):
        hyperparams = ...
        setup_wandb(hyperparams, **FLAGS.wandb)

        # Log metrics as you wish now
        wandb.log({'metric': 0.0}, step=0)


With the following setup, you may set wandb configurations from the command line, e.g.
    python main.py --wandb.project=my_project --wandb.group=my_group --wandb.offline
"""
#import wandb

import tempfile
import absl.flags as flags
import ml_collections
from  ml_collections.config_dict import FieldReference
import datetime
#import wandb
import time
import numpy as np
import os


def get_flag_dict():
    flag_dict = {k: getattr(flags.FLAGS, k) for k in flags.FLAGS}
    for k in flag_dict:
        if isinstance(flag_dict[k], ml_collections.ConfigDict):
            flag_dict[k] = flag_dict[k].to_dict()
    return flag_dict


def default_wandb_config():
    config = ml_collections.ConfigDict()
    config.offline = False  # Syncs online or not?
    config.project = "jaxrl_m"  # WandB Project Name
    config.entity = FieldReference(None, field_type=str)  # Which entity to log as (default: your own user)

    group_name = FieldReference(None, field_type=str)  # Group name
    config.exp_prefix = group_name  # Group name (deprecated, but kept for backwards compatibility)
    config.group = group_name  # Group name

    experiment_name = FieldReference(None, field_type=str) # Experiment name
    config.name = experiment_name  # Run name (will be formatted with flags / variant)
    config.exp_descriptor = experiment_name  # Run name (deprecated, but kept for backwards compatibility)

    config.unique_identifier = ""  # Unique identifier for run (will be automatically generated unless provided)
    config.random_delay = 0  # Random delay for wandb.init (in seconds)
    return config


def setup_wandb(
    hyperparam_dict,
    entity=None,
    project="jaxrl_m",
    group=None,
    name=None,
    unique_identifier="",
    offline=False,
    random_delay=0,
    **additional_init_kwargs,
):
    """
    Utility for setting up wandb logging (based on Young's simplesac):

    Arguments:
        - hyperparam_dict: dict of hyperparameters for experiment
        - offline: bool, whether to sync online or not
        - project: str, wandb project name
        - entity: str, wandb entity name (default is your user)
        - group: str, Group name for wandb
        - name: str, Experiment name for wandb (formatted with FLAGS & hyperparameter_dict)
        - unique_identifier: str, Unique identifier for wandb (default is timestamp)
        - random_delay: float, Random delay for wandb.init (in seconds) to avoid collisions
        - additional_init_kwargs: dict, additional kwargs to pass to wandb.init
    Returns:
        - wandb.run

    """
    if "exp_descriptor" in additional_init_kwargs:
        # Remove deprecated exp_descriptor
        additional_init_kwargs.pop("exp_descriptor")
        additional_init_kwargs.pop("exp_prefix")

    if not unique_identifier:
        if random_delay:
            time.sleep(np.random.uniform(0, random_delay))
        unique_identifier = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        unique_identifier += f"_{np.random.randint(0, 1000000):06d}"
        flag_dict = get_flag_dict()
        if 'seed' in flag_dict:
            unique_identifier += f"_{flag_dict['seed']:02d}"

    if name is not None:
        name = name.format(**{**get_flag_dict(), **hyperparam_dict})

    if group is not None and name is not None:
        experiment_id = f"{name}_{unique_identifier}"
    elif name is not None:
        experiment_id = f"{name}_{unique_identifier}"
    else:
        experiment_id = None

    # check if dir exists.
    wandb_output_dir = tempfile.mkdtemp()
    tags = [group] if group is not None else None

    init_kwargs = dict(
        config=hyperparam_dict,
        project=project,
        entity=entity,
        tags=tags,
        group=group,
        dir=wandb_output_dir,
        id=experiment_id,
        name=name,
        settings=wandb.Settings(
            start_method="thread",
            _disable_stats=False,
        ),
        mode="offline" if offline else "online",
        save_code=True,
    )

    init_kwargs.update(additional_init_kwargs)
    run = wandb.init(**init_kwargs)

    wandb.config.update(get_flag_dict())

    wandb_config = dict(
        exp_prefix=group,
        exp_descriptor=name,
        experiment_id=experiment_id,
    )
    wandb.config.update(wandb_config)
    return run

#fre/common/envs/data_transforms

###############################
#
#  Helpful utilities for processing actions, observations.
#
###############################

import numpy as np
import jax.numpy as jnp

class ActionTransform():
    pass

class ActionDiscretizeBins(ActionTransform):
    def __init__(self, bins_per_dim, action_dim):
        self.bins_per_dim = bins_per_dim
        self.action_dim = action_dim
        self.bins = np.linspace(-1, 1, bins_per_dim + 1)

    # Assumes action is in [-1, 1].
    def action_to_ids(self, action):
        ids = np.digitize(action, self.bins) - 1
        ids = np.clip(ids, 0, self.bins_per_dim - 1)
        return ids

    def ids_to_action(self, ids):
        action = (self.bins[ids] + self.bins[ids + 1]) / 2
        return action

class ActionDiscretizeCluster(ActionTransform):
    def __init__(self, num_clusters, data_actions):
        self.num_clusters = num_clusters
        assert len(data_actions.shape) == 2 # (data_size, action_dim)
        print("Clustering actions of shape", data_actions.shape)

        # Cluster the data.
        from sklearn.cluster import KMeans
        kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(data_actions)
        self.centers = kmeans.cluster_centers_
        # self.labels = kmeans.labels_
        self.centers = jnp.array(self.centers)
        print("Average cluster error is", kmeans.inertia_ / len(data_actions))
        print("Average cluster error per dimension is", (kmeans.inertia_ / len(data_actions)) / data_actions.shape[1])
        # print(self.centers.shape)

    def action_to_ids(self, action):
        if len(action.shape) == 1:
            action = action[None]
        assert len(action.shape) == 2 # (batch, action_dim,)
        # Find the closest cluster center.
        dists = jnp.linalg.norm(self.centers[None] - action[:, None], axis=-1)
        ids = jnp.argmin(dists, axis=-1)
        return ids

    def ids_to_action(self, ids):
        action = self.centers[ids]
        return action

# Test
# action_discretize_bins = ActionDiscretizeBins(32, 2)
# action = np.array([-1, -0.999, -0.5, 0, 0.5, 0.999, 1])
# ids = action_discretize_bins.action_to_ids(action)
# print(ids)
# action_recreate = action_discretize_bins.ids_to_action(ids)
# print(action_recreate)
# assert np.abs(action - action_recreate).max() < 0.1

# action_discretize_cluster = ActionDiscretizeCluster(32, np.random.uniform(low=-1, high=1, size=(10000, 1)))
# action = np.array([-1, -0.999, -0.5, 0, 0.5, 0.999, 1])[:, None] # [7, 1]
# ids = action_discretize_cluster.action_to_ids(action)
# print(ids)
# action_recreate = action_discretize_cluster.ids_to_action(ids)
# print(action_recreate)
# assert np.abs(action - action_recreate).max() < 0.1


#fre\common\envs\env_helper


###############################
#
#   Helper that initializes environments with the proper imports.
#   Returns an environment that is:
#   - Action normalized.
#   - Video rendering works.
#   - Episode monitor attached.
#
###############################

import matplotlib.pyplot as plt
import numpy as np
import os
import os.path as osp
import gym
import numpy as np
import functools as ft
#from fre.common.train_state import NormalizeActionWrapper
#from fre.common.envs.wrappers import EpisodeMonitor


# Supported envs:
env_list = [
    # From Gym
    'HalfCheetah-v2',
    'Hopper-v2',
    'Walker2d-v2',
    'Pendulum-v1',
    'CartPole-v1',
    'Acrobot_v1',
    'MountainCar-v0',
    'MountainCarContinuous-v0',
    # From DMC
    'pendulum_swingup',
    'acrobot_swingup',
    'acrobot_swingup_sparse',
    'cartpole_swingup', # has exorl dataset.
    'cartpole_swingup_sparse',
    'pointmass_easy',
    'reacher_easy',
    'reacher_hard',
    'cheetah_run',  # has exorl dataset.
    'hopper_hop',
    'walker_stand', # has exorl dataset.
    'walker_walk', # has exorl dataset.
    'walker_run', # has exorl dataset.
    'quadruped_walk', # has exorl dataset.
    'quadruped_run', # has exorl dataset.
    'humanoid_stand',
    'humanoid_run',
    'jaco_reach_top_left', # has exorl dataset.
    'jaco_reach_bottom_right', # has exorl dataset.
    # TODO: Atari games
    # Offline D4RL envs
    'antmaze-large-diverse-v2', # Start in the corner, goal is in the top corner.
    'gc-antmaze-large-diverse-v2', # Start in the corner, goal is in the top corner.
    'center-antmaze-large-diverse-v2', # Start in the center, goal is UNDEFINED (this is for RARE rewards).
    'maze2d-large-v1',
    'gc-maze2d-large-v1',
    'center-maze2d-large-v1',
    # D4RL mujoco
    'halfcheetah-expert-v2',
    'walker2d-expert-v2',
    'hopper-expert-v2',
    'kitchen-complete-v0' # broken
    'kitchen-mixed-v0' # broken
]

# Making an environment.
def make_env(env_name, **kwargs):
    if 'exorl' in env_name:
        import os
        os.environ['DISPLAY'] = ':0'
        #import fre.common.envs.exorl.dmc as dmc
        _, env_name, task_name = env_name.split('_', 2)
        def make_env(env_name, task_name):
            env = dmc.make(f'{env_name}_{task_name}', obs_type='states', frame_stack=1, action_repeat=1, seed=0)
            env = dmc.DMCWrapper(env, 0)
            return env
        env = make_env(env_name, task_name)
        env.reset()
    elif '_' in env_name: # DMC Control
        #import fre.common.envs.dmc as dmc2gym
        import os
        os.environ['DISPLAY'] = ':0'
        suite, task = env_name.split('_', 1)
        print(suite, task)
        if suite == 'pointmass':
            suite = 'point_mass'
        frame_skip = kwargs['frame_skip'] if 'frame_skip' in kwargs else 2
        visualize_reward = kwargs['visualize_reward'] if 'visualize_reward' in kwargs else False
        env = dmc2gym.make(
            domain_name=suite,
            task_name=task, seed=1,
            frame_skip=frame_skip,
            visualize_reward=visualize_reward)
        env = NormalizeActionWrapper(env)
    elif 'antmaze' in env_name:
        #from fre.common.envs.d4rl.d4rl_ant import CenteredMaze, GoalReachingMaze, MazeWrapper
        if 'gc-antmaze' in env_name:
            env = GoalReachingMaze('antmaze-large-diverse-v2')
        elif 'center-antmaze' in env_name:
            env = CenteredMaze('antmaze-large-diverse-v2')
        else:
            env = MazeWrapper('antmaze-large-diverse-v2')
    elif 'maze2d' in env_name:
        #from fre.common.envs.d4rl.d4rl_ant import CenteredMaze, GoalReachingMaze, MazeWrapper
        if 'gc-maze2d' in env_name:
            env = GoalReachingMaze('maze2d-large-v1')
        elif 'center-maze2d' in env_name:
            env = CenteredMaze('maze2d-large-v1')
        else:
            env = CenteredMaze('maze2d-large-v1', start_loc='original')
    elif 'halfcheetah-' in env_name or 'walker2d-' in env_name or 'hopper-' in env_name: # D4RL Mujoco
        #import d4rl
        #import d4rl.gym_mujoco
        env = gym.make(env_name)
    elif 'kitchen' in env_name: # This doesn't work yet.
        import os
        os.environ['DISPLAY'] = ':0'
        #from fre.common.envs.d4rl.d4rl_utils import KitchenRenderWrapper
        env = KitchenRenderWrapper(gym.make(env_name))
    elif 'bandit' in env_name:
        #from fre.common.envs.bandit.bandit import BanditEnv
        env = BanditEnv()
    else:
        env = gym.make(env_name)
    env = EpisodeMonitor(env)
    return env

# For getting offline data.
def get_dataset(env, env_name, **kwargs):
    if 'exorl' in env_name:
        #from fre.common.envs.exorl.exorl_utils import get_dataset
        env_name_short = env_name.split('_', 1)[1]
        return get_dataset(env, env_name_short, **kwargs)
    elif 'ant' in env_name or 'maze2d' in env_name or 'kitchen' in env_name or 'halfcheetah' in env_name or 'walker2d' in env_name or 'hopper' in env_name:
        #from fre.common.envs.d4rl.d4rl_utils import get_dataset, normalize_dataset
        dataset = get_dataset(env, env_name, **kwargs)
        dataset = normalize_dataset(env_name, dataset)
        return dataset
    elif 'cartpole' in env_name or 'cheetah' in env_name or 'jaco' in env_name or 'quadruped' in env_name or 'walker' in env_name:
        #from fre.common.envs.exorl.exorl_utils import get_dataset
        return get_dataset(env, env_name, **kwargs)

def make_vec_env(env_name, num_envs, **kwargs):
    from gym.vector import SyncVectorEnv
    envs = [lambda : make_env(env_name, **kwargs) for _ in range(num_envs)]
    env = SyncVectorEnv(envs)
    return env

#fre\common\env\gc_utils

#from fre.common.dataset import Dataset
from flax.core.frozen_dict import FrozenDict
from flax.core import freeze
import dataclasses
import numpy as np
import jax
import ml_collections

@dataclasses.dataclass
class GCDataset:
    dataset: Dataset
    p_randomgoal: float
    p_trajgoal: float
    p_currgoal: float
    geom_sample: int
    discount: float
    terminal_key: str = 'dones_float'
    reward_scale: float = 1.0
    reward_shift: float = -1.0
    mask_terminal: int = 1

    @staticmethod
    def get_default_config():
        return ml_collections.ConfigDict({
            'p_randomgoal': 0.3,
            'p_trajgoal': 0.5,
            'p_currgoal': 0.2,
            'geom_sample': 1,
            'discount': 0.99,
            'reward_scale': 1.0,
            'reward_shift': -1.0,
            'mask_terminal': 1,
        })

    def __post_init__(self):
        self.terminal_locs, = np.nonzero(self.dataset[self.terminal_key] > 0)
        assert np.isclose(self.p_randomgoal + self.p_trajgoal + self.p_currgoal, 1.0)

    def sample_goals(self, indx, p_randomgoal=None, p_trajgoal=None, p_currgoal=None):
        if p_randomgoal is None:
            p_randomgoal = self.p_randomgoal
        if p_trajgoal is None:
            p_trajgoal = self.p_trajgoal
        if p_currgoal is None:
            p_currgoal = self.p_currgoal

        batch_size = len(indx)
        # Random goals
        goal_indx = np.random.randint(self.dataset.size, size=batch_size)

        # Goals from the same trajectory
        final_state_indx = self.terminal_locs[np.searchsorted(self.terminal_locs, indx)]

        distance = np.random.rand(batch_size)
        if self.geom_sample:
            us = np.random.rand(batch_size)
            middle_goal_indx = np.minimum(indx + np.ceil(np.log(1 - us) / np.log(self.discount)).astype(int), final_state_indx)
        else:
            middle_goal_indx = np.round((np.minimum(indx + 1, final_state_indx) * distance + final_state_indx * (1 - distance))).astype(int)

        goal_indx = np.where(np.random.rand(batch_size) < p_trajgoal / (1.0 - p_currgoal), middle_goal_indx, goal_indx)

        # Goals at the current state
        goal_indx = np.where(np.random.rand(batch_size) < p_currgoal, indx, goal_indx)

        return goal_indx

    def sample(self, batch_size: int, indx=None):
        if indx is None:
            indx = np.random.randint(self.dataset.size-1, size=batch_size)

        batch = self.dataset.sample(batch_size, indx)
        goal_indx = self.sample_goals(indx)

        success = (indx == goal_indx)
        batch['rewards'] = success.astype(float) * self.reward_scale + self.reward_shift
        batch['goals'] = jax.tree_map(lambda arr: arr[goal_indx], self.dataset['observations'])

        if self.mask_terminal:
            batch['masks'] = 1.0 - success.astype(float)
        else:
            batch['masks'] = np.ones(batch_size)

        return batch

    def sample_traj_random(self, batch_size, num_traj_states, num_random_states, num_random_states_decode):
        indx = np.random.randint(self.dataset.size-1, size=batch_size)
        batch = self.dataset.sample(batch_size, indx)
        indx_expand = np.repeat(indx, num_traj_states-1) # (batch_size * num_traj_states)
        traj_indx = self.sample_goals(indx_expand, p_randomgoal=0.0, p_trajgoal=1.0, p_currgoal=0.0)
        traj_indx = traj_indx.reshape(batch_size, num_traj_states-1) # (batch_size, num_traj_states)
        batch['traj_states'] = jax.tree_map(lambda arr: arr[traj_indx], self.dataset['observations'])
        batch['traj_states'] = np.concatenate([batch['observations'][:,None,:], batch['traj_states']], axis=1)

        rand_indx = np.random.randint(self.dataset.size-1, size=batch_size * num_random_states)
        rand_indx = rand_indx.reshape(batch_size, num_random_states)
        batch['random_states'] = jax.tree_map(lambda arr: arr[rand_indx], self.dataset['observations'])

        rand_indx_decode = np.random.randint(self.dataset.size-1, size=batch_size * num_random_states_decode)
        rand_indx_decode = rand_indx_decode.reshape(batch_size, num_random_states_decode)
        batch['random_states_decode'] = jax.tree_map(lambda arr: arr[rand_indx_decode], self.dataset['observations'])
        return batch

def flatten_obgoal(obgoal):
    return np.concatenate([obgoal['observation'], obgoal['goal']], axis=-1)

#fre\common\envs\gc_utils
#from fre.common.dataset import Dataset
from flax.core.frozen_dict import FrozenDict
from flax.core import freeze
import dataclasses
import numpy as np
import jax
import ml_collections

@dataclasses.dataclass
class GCDataset:
    dataset: Dataset
    p_randomgoal: float
    p_trajgoal: float
    p_currgoal: float
    geom_sample: int
    discount: float
    terminal_key: str = 'dones_float'
    reward_scale: float = 1.0
    reward_shift: float = -1.0
    mask_terminal: int = 1

    @staticmethod
    def get_default_config():
        return ml_collections.ConfigDict({
            'p_randomgoal': 0.3,
            'p_trajgoal': 0.5,
            'p_currgoal': 0.2,
            'geom_sample': 1,
            'discount': 0.99,
            'reward_scale': 1.0,
            'reward_shift': -1.0,
            'mask_terminal': 1,
        })

    def __post_init__(self):
        self.terminal_locs, = np.nonzero(self.dataset[self.terminal_key] > 0)
        assert np.isclose(self.p_randomgoal + self.p_trajgoal + self.p_currgoal, 1.0)

    def sample_goals(self, indx, p_randomgoal=None, p_trajgoal=None, p_currgoal=None):
        if p_randomgoal is None:
            p_randomgoal = self.p_randomgoal
        if p_trajgoal is None:
            p_trajgoal = self.p_trajgoal
        if p_currgoal is None:
            p_currgoal = self.p_currgoal

        batch_size = len(indx)
        # Random goals
        goal_indx = np.random.randint(self.dataset.size, size=batch_size)

        # Goals from the same trajectory
        final_state_indx = self.terminal_locs[np.searchsorted(self.terminal_locs, indx)]

        distance = np.random.rand(batch_size)
        if self.geom_sample:
            us = np.random.rand(batch_size)
            middle_goal_indx = np.minimum(indx + np.ceil(np.log(1 - us) / np.log(self.discount)).astype(int), final_state_indx)
        else:
            middle_goal_indx = np.round((np.minimum(indx + 1, final_state_indx) * distance + final_state_indx * (1 - distance))).astype(int)

        goal_indx = np.where(np.random.rand(batch_size) < p_trajgoal / (1.0 - p_currgoal), middle_goal_indx, goal_indx)

        # Goals at the current state
        goal_indx = np.where(np.random.rand(batch_size) < p_currgoal, indx, goal_indx)

        return goal_indx

    def sample(self, batch_size: int, indx=None):
        if indx is None:
            indx = np.random.randint(self.dataset.size-1, size=batch_size)

        batch = self.dataset.sample(batch_size, indx)
        goal_indx = self.sample_goals(indx)

        success = (indx == goal_indx)
        batch['rewards'] = success.astype(float) * self.reward_scale + self.reward_shift
        batch['goals'] = jax.tree_map(lambda arr: arr[goal_indx], self.dataset['observations'])

        if self.mask_terminal:
            batch['masks'] = 1.0 - success.astype(float)
        else:
            batch['masks'] = np.ones(batch_size)

        return batch

    def sample_traj_random(self, batch_size, num_traj_states, num_random_states, num_random_states_decode):
        indx = np.random.randint(self.dataset.size-1, size=batch_size)
        batch = self.dataset.sample(batch_size, indx)
        indx_expand = np.repeat(indx, num_traj_states-1) # (batch_size * num_traj_states)
        traj_indx = self.sample_goals(indx_expand, p_randomgoal=0.0, p_trajgoal=1.0, p_currgoal=0.0)
        traj_indx = traj_indx.reshape(batch_size, num_traj_states-1) # (batch_size, num_traj_states)
        batch['traj_states'] = jax.tree_map(lambda arr: arr[traj_indx], self.dataset['observations'])
        batch['traj_states'] = np.concatenate([batch['observations'][:,None,:], batch['traj_states']], axis=1)

        rand_indx = np.random.randint(self.dataset.size-1, size=batch_size * num_random_states)
        rand_indx = rand_indx.reshape(batch_size, num_random_states)
        batch['random_states'] = jax.tree_map(lambda arr: arr[rand_indx], self.dataset['observations'])

        rand_indx_decode = np.random.randint(self.dataset.size-1, size=batch_size * num_random_states_decode)
        rand_indx_decode = rand_indx_decode.reshape(batch_size, num_random_states_decode)
        batch['random_states_decode'] = jax.tree_map(lambda arr: arr[rand_indx_decode], self.dataset['observations'])
        return batch

def flatten_obgoal(obgoal):
    return np.concatenate([obgoal['observation'], obgoal['goal']], axis=-1)


#fre/common/envs/wrapper
###############################
#
#  Wrappers on top of gym environments
#
###############################

from typing import Dict
import gym
import numpy as np
import time

class EpisodeMonitor(gym.ActionWrapper):
    """A class that computes episode returns and lengths."""

    def __init__(self, env: gym.Env):
        super().__init__(env)
        self._reset_stats()
        self.total_timesteps = 0

    def _reset_stats(self):
        self.reward_sum = 0.0
        self.episode_length = 0
        self.start_time = time.time()

    def step(self, action: np.ndarray):
        observation, reward, done, info = self.env.step(action)

        self.reward_sum += reward
        self.episode_length += 1
        self.total_timesteps += 1
        info["total"] = {"timesteps": self.total_timesteps}

        if done:
            info["episode"] = {}
            info["episode"]["return"] = self.reward_sum
            info["episode"]["length"] = self.episode_length
            info["episode"]["duration"] = time.time() - self.start_time

            if hasattr(self, "get_normalized_score"):
                info["episode"]["normalized_return"] = (
                    self.get_normalized_score(info["episode"]["return"]) * 100.0
                )

        return observation, reward, done, info

    def reset(self, **kwargs) -> np.ndarray:
        self._reset_stats()
        return self.env.reset(**kwargs)

class RewardOverride(gym.ActionWrapper):
    def __init__(self, env: gym.Env):
        super().__init__(env)
        self.reward_fn = None

    def step(self, action: np.ndarray):
        observation, reward, done, info = self.env.step(action)

        if self.env.observation_space.shape[0] == 24:
            horizontal_velocity = self.env.physics.horizontal_velocity()
            torso_upright = self.env.physics.torso_upright()
            torso_height = self.env.physics.torso_height()
            aux = np.array([horizontal_velocity, torso_upright, torso_height])
            observation_aux = np.concatenate([observation, aux])
            reward = self.reward_fn(observation_aux)
        elif self.env.observation_space.shape[0] == 17:
            horizontal_velocity = self.env.physics.speed()
            aux = np.array([horizontal_velocity])
            observation_aux = np.concatenate([observation, aux])
            reward = self.reward_fn(observation_aux)
        else:
            reward = self.reward_fn(observation)
        return observation, reward, done, info

    def reset(self, **kwargs) -> np.ndarray:
        return self.env.reset(**kwargs)

class TruncateObservation(gym.ObservationWrapper):
    def __init__(self, env: gym.Env, truncate_size: int):
        super().__init__(env)
        self.truncate_size = truncate_size

    def observation(self, observation: np.ndarray) -> np.ndarray:
        return observation[:self.truncate_size]

class GoalWrapper(gym.ObservationWrapper):
    def __init__(self, env: gym.Env):
        super().__init__(env)
        self.custom_goal = None

    def observation(self, observation: np.ndarray) -> np.ndarray:
        if self.custom_goal is not None:
            return np.concatenate([observation, self.custom_goal])
        else:
            return observation

#fre\common\envs\bandit\bandit

import numpy as np
import gym

# Here's a super simple bandit environment that follows the OpenAI Gym API.
# There is one continuous action. The observation is always zero.
# A reward of 1 is given if the action is either 0.5 or -0.5.

class BanditEnv(gym.Env):
    def __init__(self):
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
        self._state = None
        self.width = 0.15

    def reset(self):
        self._state = np.zeros(1)
        return self._state

    def step(self, action):
        reward = 0
        if (np.abs(action[0] - 0.5) < self.width) or (np.abs(action[0] + 0.5) < self.width):
            reward = 1
        self.last_action = action
        return self._state, reward, True, {}

    def render(self, mode='human'):
        # Render the last action on a line. Also indicate where the reward is. Return this as a numpy array.
        img = np.ones((20, 100, 3), dtype=np.uint8) * 255
        # Render reward zones in green. 0-100 means actions between -1 and 1.
        center_low = 25
        center_high = 75
        width_int = int(self.width * 50)
        img[:, center_low-width_int:center_low+width_int, :] = [0, 255, 0]
        img[:, center_high-width_int:center_high+width_int, :] = [0, 255, 0]
        # Render the last action in red.
        action = self.last_action[0]
        action = int((action + 1) * 50)
        img[:, action:action+1, :] = [255, 0, 0]
        return img




    def close(self):
        pass

#fre/common/envs/d4rl/antmaze_actions.npy

import numpy as np

# Load the .npy file
file_path = '/content/antmaze_actions.npy'  # Replace with the actual path to your .npy file
loaded_array = np.load(file_path)

# Now you can use loaded_array in your code
#print(loaded_array)

#fre/common/envs/d4rl/d4rl_ant
import matplotlib
matplotlib.use('Agg')
from matplotlib import patches

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from functools import partial
from mpl_toolkits.axes_grid1 import make_axes_locatable

import os
import os.path as osp

import gym
#import d4rl
import numpy as np
import functools as ft
import math
import matplotlib.gridspec as gridspec

#from fre.common.envs.gc_utils import GCDataset

class MazeWrapper(gym.Wrapper):
    def __init__(self, env_name):
        self.env = gym.make(env_name)
        self.env.render(mode='rgb_array', width=200, height=200)
        self.env_name = env_name
        self.inner_env = get_inner_env(self.env)
        if 'antmaze' in env_name:
            if 'medium' in env_name:
                self.env.viewer.cam.lookat[0] = 10
                self.env.viewer.cam.lookat[1] = 10
                self.env.viewer.cam.distance = 40
                self.env.viewer.cam.elevation = -90
            elif 'umaze' in env_name:
                self.env.viewer.cam.lookat[0] = 4
                self.env.viewer.cam.lookat[1] = 4
                self.env.viewer.cam.distance = 30
                self.env.viewer.cam.elevation = -90
            elif 'large' in env_name:
                self.env.viewer.cam.lookat[0] = 18
                self.env.viewer.cam.lookat[1] = 13
                self.env.viewer.cam.distance = 55
                self.env.viewer.cam.elevation = -90
            self.inner_env.goal_sampler = ft.partial(valid_goal_sampler, self.inner_env)
        elif 'maze2d' in env_name:
            if 'open' in env_name:
                pass
            elif 'large' in env_name:
                self.env.viewer.cam.lookat[0] = 5
                self.env.viewer.cam.lookat[1] = 6.5
                self.env.viewer.cam.distance = 15
                self.env.viewer.cam.elevation = -90
                self.env.viewer.cam.azimuth = 180
            self.draw_ant_maze = get_inner_env(gym.make('antmaze-large-diverse-v2'))
        self.action_space = self.env.action_space

    def render(self, *args, **kwargs):
        img = self.env.render(*args, **kwargs)
        if 'maze2d' in self.env_name:
            img = img[::-1]
        return img

    # ======== BELOW is helper stuff for drawing and visualizing ======== #

    def get_starting_boundary(self):
        if 'antmaze' in self.env_name:
            self = self.inner_env
        else:
            self = self.draw_ant_maze
        torso_x, torso_y = self._init_torso_x, self._init_torso_y
        S =  self._maze_size_scaling
        return (0 - S / 2 + S - torso_x, 0 - S/2 + S - torso_y), (len(self._maze_map[0]) * S - torso_x - S/2 - S, len(self._maze_map) * S - torso_y - S/2 - S)

    def XY(self, n=20, m=30):
        bl, tr = self.get_starting_boundary()
        X = np.linspace(bl[0] + 0.04 * (tr[0] - bl[0]) , tr[0] - 0.04 * (tr[0] - bl[0]), m)
        Y = np.linspace(bl[1] + 0.04 * (tr[1] - bl[1]) , tr[1] - 0.04 * (tr[1] - bl[1]), n)

        X,Y = np.meshgrid(X,Y)
        states = np.array([X.flatten(), Y.flatten()]).T
        return states

    def four_goals(self):
        self = self.inner_env

        valid_cells = []
        goal_cells = []

        for i in range(len(self._maze_map)):
            for j in range(len(self._maze_map[0])):
                if self._maze_map[i][j] in [0, 'r', 'g']:
                    valid_cells.append(self._rowcol_to_xy((i, j), add_random_noise=False))

        goals = []
        goals.append(max(valid_cells, key=lambda x: -x[0]-x[1]))
        goals.append(max(valid_cells, key=lambda x: x[0]-x[1]))
        goals.append(max(valid_cells, key=lambda x: x[0]+x[1]))
        goals.append(max(valid_cells, key=lambda x: -x[0] + x[1]))
        return goals

    def draw(self, ax=None, scale=1.0):
        if not ax: ax = plt.gca()
        if 'antmaze' in self.env_name:
            self = self.inner_env
        else:
            self = self.draw_ant_maze
        torso_x, torso_y = self._init_torso_x, self._init_torso_y
        S =  self._maze_size_scaling
        if scale < 1.0:
            S *= 0.965
            torso_x -= 0.7
            torso_y -= 0.95
        for i in range(len(self._maze_map)):
            for j in range(len(self._maze_map[0])):
                struct = self._maze_map[i][j]
                if struct == 1:
                    rect = patches.Rectangle((j *S - torso_x - S/ 2,
                                            i * S- torso_y - S/ 2),
                                            S,
                                            S, linewidth=1, edgecolor='none', facecolor='grey', alpha=1.0)

                    ax.add_patch(rect)
        ax.set_xlim(0 - S /2 + 0.6 * S - torso_x, len(self._maze_map[0]) * S - torso_x - S/2 - S * 0.6)
        ax.set_ylim(0 - S/2 + 0.6 * S - torso_y, len(self._maze_map) * S - torso_y - S/2 - S * 0.6)
        ax.axis('off')

class CenteredMaze(MazeWrapper):
    start_loc: str = "center"

    def __init__(self, env_name, start_loc="center"):
        super().__init__(env_name)
        self.start_loc = start_loc
        self.t = 0

    def step(self, action):
        next_obs, r, done, info = self.env.step(action)
        if 'antmaze' in self.env_name:
            info['x'], info['y'] = self.get_xy()
        self.t += 1
        done = self.t >= 2000
        return next_obs, r, done, info

    def reset(self, **kwargs):
        self.t = 0
        obs = self.env.reset(**kwargs)
        if 'maze2d' in self.env_name:
            if self.start_loc == 'center' or self.start_loc == 'center2':
                obs = self.env.reset_to_location([4, 5.8])
            elif self.start_loc == 'original':
                obs = self.env.reset_to_location([0.9, 0.9])
            else:
                raise NotImplementedError
        elif 'antmaze' in self.env_name:
            if self.start_loc == 'center' or self.start_loc == 'center2':
                self.env.set_xy([20, 15])
                obs[:2] = [20, 15]
            elif self.start_loc == 'original':
                pass
            else:
                raise NotImplementedError
        return obs

class GoalReachingMaze(MazeWrapper):
    def __init__(self, env_name):
        super().__init__(env_name)
        self.observation_space = gym.spaces.Dict({
            'observation': self.env.observation_space,
            'goal': self.env.observation_space,
        })

    def step(self, action):
        next_obs, r, done, info = self.env.step(action)

        if 'antmaze' in self.env_name:
            achieved = self.get_xy()
            desired = self.target_goal
        elif 'maze2d' in self.env_name:
            achieved = next_obs[:2]
            desired = self.env.get_target()
        distance = np.linalg.norm(achieved - desired)
        info['x'], info['y'] = achieved
        info['achieved_goal'] = np.array(achieved)
        info['desired_goal'] = np.copy(desired)
        info['success'] = float(distance < 0.5)
        done = 'TimeLimit.truncated' in info or info['success']

        return self.get_obs(next_obs), r, done, info

    def get_obs(self, obs):
        if 'antmaze' in self.env_name:
            desired = self.target_goal
        elif 'maze2d' in self.env_name:
            desired = self.env.get_target()
        target_goal = obs.copy()
        target_goal[:2] = desired
        if 'antmaze' in self.env_name:
            obs = discretize_obs(obs)
            target_goal = discretize_obs(target_goal)
        return dict(observation=obs, goal=target_goal)

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        if 'maze2d' in self.env_name:
            obs = self.env.reset_to_location([0.9, 0.9])
        return self.get_obs(obs)

    def get_normalized_score(self, score):
        return score

# ===================================
# HELPER FUNCTIONS FOR OB DISCRETIZATION
# ===================================

def discretize_obs(ob, num_bins=32, disc_type='tanh', disc_temperature=1.0):
    min_ob = np.array([0, 0])
    max_ob = np.array([35, 35])
    disc_dims = 2
    bins = np.linspace(min_ob, max_ob, num_bins).T # [num_bins,] values from min_ob to max_ob
    bin_size = (max_ob - min_ob) / num_bins
    if disc_type == 'twohot':
        raise NotImplementedError
    elif disc_type == 'tanh':
        orig_ob = ob
        ob = np.expand_dims(ob, -1)
        # Convert each discretized dimension into num_bins dimensions. Value of each dimension is tanh of the distance from the bin center.
        bin_diff = ob[..., :disc_dims, :] - bins[:disc_dims]
        bin_diff_normalized = bin_diff / np.expand_dims(bin_size[:disc_dims], -1) * disc_temperature
        bin_tanh = np.tanh(bin_diff_normalized).reshape(*orig_ob.shape[:-1], -1)
        disc_ob = np.concatenate([bin_tanh, orig_ob[..., disc_dims:]], axis=-1)
        return disc_ob
    else:
        raise NotImplementedError

# ===================================
# HELPER FUNCTIONS FOR VISUALIZATION
# ===================================

def get_canvas_image(canvas):
    canvas.draw()
    out_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
    out_image = out_image.reshape(canvas.get_width_height()[::-1] + (3,))
    return out_image

def valid_goal_sampler(self, np_random):
    valid_cells = []
    goal_cells = []
    # print('Hello')

    for i in range(len(self._maze_map)):
      for j in range(len(self._maze_map[0])):
        if self._maze_map[i][j] in [0, 'r', 'g']:
          valid_cells.append((i, j))

    # If there is a 'goal' designated, use that. Otherwise, any valid cell can
    # be a goal.
    sample_choices = valid_cells
    cell = sample_choices[np_random.choice(len(sample_choices))]
    xy = self._rowcol_to_xy(cell, add_random_noise=True)

    random_x = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling
    random_y = np.random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling

    xy = (max(xy[0] + random_x, 0), max(xy[1] + random_y, 0))

    return xy


def get_inner_env(env):
    if hasattr(env, '_maze_size_scaling'):
        return env
    elif hasattr(env, 'env'):
        return get_inner_env(env.env)
    elif hasattr(env, 'wrapped_env'):
        return get_inner_env(env.wrapped_env)
    return env


# ===================================
# PLOT VALUE FUNCTION
# ===================================

def value_image(env, dataset, value_fn):
    """
    Visualize the value function.
    Args:
        env: The environment.
        value_fn: a function with signature value_fn([# states, state_dim]) -> [#states, 1]
    Returns:
        A numpy array of the image.
    """
    fig, axs = plt.subplots(2, 2, tight_layout=True)
    axs_flat = axs.flatten()
    canvas = FigureCanvas(fig)
    if type(dataset) is GCDataset:
        dataset = dataset.dataset
    if 'antmaze' in env.env_name:
        goals = env.four_goals()
        goal_states = dataset['observations'][0]
        goal_states = goal_states[-29:] # Remove discretized observations.
        goal_states = np.tile(goal_states, (len(goals), 1))
        goal_states[:, :2] = goals
        goal_states = discretize_obs(goal_states)
    elif 'maze2d' in env.env_name:
        goals = np.array([[0.8, 0.8], [1, 9.7], [6.8, 9], [6.8, 1]])
        goal_states = dataset['observations'][0]
        goal_states = np.tile(goal_states, (len(goals), 1))
        goal_states[:, :2] = goals
    for i in range(4):
        plot_value(goal_states[i], env, dataset, value_fn, axs_flat[i])
    image = get_canvas_image(canvas)
    plt.close(fig)
    return image

def plot_value(goal_observation, env, dataset, value_fn, ax):
    N = 14
    M = 20
    ob_xy = env.XY(n=N, m=M)

    goal_observation = np.tile(goal_observation, (ob_xy.shape[0], 1)) # (N*M, 29)

    base_observation = np.copy(dataset['observations'][0])
    xy_observations = np.tile(base_observation, (ob_xy.shape[0], 1)) # (N*M, 29)
    if 'antmaze' in env.env_name:
        xy_observations = xy_observations[:, -29:] # Remove discretized observations.
        xy_observations[:, :2] = ob_xy # Set to XY.
        xy_observations = discretize_obs(xy_observations) # Discretize again.
        assert xy_observations.shape[1] == 91
    elif 'maze2d' in env.env_name:
        ob_xy_scaled = ob_xy / 3.5
        ob_xy_scaled = ob_xy_scaled[:, [1, 0]]
        xy_observations[:, :2] = ob_xy_scaled
        assert xy_observations.shape[1] == 4 # (x, y, vx, vy)
    values = value_fn(xy_observations, goal_observation) # (N*M, 1)

    x, y = ob_xy[:, 0], ob_xy[:, 1]
    x = x.reshape(N, M)
    y = y.reshape(N, M) * 0.975 + 0.7
    values = values.reshape(N, M)
    mesh = ax.pcolormesh(x, y, values, cmap='viridis')

    env.draw(ax, scale=0.95)


# ===================================
# PLOT TRAJECTORIES
# ===================================

# Makes an image of the trajectory the Ant follows.
def trajectory_image(env, trajectories, **kwargs):
    fig = plt.figure(tight_layout=True)
    canvas = FigureCanvas(fig)

    plot_trajectories(env, trajectories, fig, plt.gca(), **kwargs)

    plt.tight_layout()
    image = get_canvas_image(canvas)
    plt.close(fig)
    return image

# Helper that plots the XY coordinates as scatter plots.
def plot_trajectories(env, trajectories, fig, ax, color_list=None):
    if color_list is None:
        from itertools import cycle
        color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
        color_list = cycle(color_cycle)

    for color, trajectory in zip(color_list, trajectories):
        obs = np.array(trajectory['observation'])

        # convert back to xy?
        if 'ant' in env.env_name:
            all_x = []
            all_y = []
            for info in trajectory['info']:
                all_x.append(info['x'])
                all_y.append(info['y'])
            all_x = np.array(all_x)
            all_y = np.array(all_y)
        elif 'maze2d' in env.env_name:
            all_x = obs[:, 1] * 4 - 3.2
            all_y = obs[:, 0] * 4 - 3.2
        ax.scatter(all_x, all_y, s=5, c=color, alpha=0.2)
        ax.scatter(all_x[-1], all_y[-1], s=50, c=color, marker='*', alpha=1, edgecolors='black')

    env.draw(ax)


#fre/common/envs/d4rl/d4rl_utils
"""import d4rl
import d4rl.gym_mujoco"""
import gym
import numpy as np
from jax import tree_util


"""import fre.common.envs.d4rl.d4rl_ant as d4rl_ant
from fre.common.dataset import Dataset
"""

# Note on AntMaze. Reward = 1 at the goal, and Terminal = 1 at the goal.
# Masks = Does the episode end due to final state?
# Dones_float = Does the episode end due to time limit? OR does the episode end due to final state?
def get_dataset(env: gym.Env, env_name: str, clip_to_eps: bool = True,
                eps: float = 1e-5, dataset=None, filter_terminals=False, obs_dtype=np.float32):
    if dataset is None:
        dataset = d4rl.qlearning_dataset(env)

    if clip_to_eps:
        lim = 1 - eps
        dataset['actions'] = np.clip(dataset['actions'], -lim, lim)

    # Mask everything that is marked as a terminal state.
    # For AntMaze, this should mask the end of each trajectory.
    masks = 1.0 - dataset['terminals']

    # In the AntMaze data, terminal is 1 when at the goal. But the episode doesn't end.
    # This just ensures that we treat AntMaze trajectories as non-ending.
    if "antmaze" in env_name or "maze2d" in env_name:
        dataset['terminals'] = np.zeros_like(dataset['terminals'])

    # if 'antmaze' in env_name:
    #     print("Discretizing AntMaze observations.")
    #     print("Raw observations looks like", dataset['observations'].shape[1:])
    #     dataset['observations'] = d4rl_ant.discretize_obs(dataset['observations'])
    #     dataset['next_observations'] = d4rl_ant.discretize_obs(dataset['next_observations'])
    #     print("Discretized observations looks like", dataset['observations'].shape[1:])

    # Compute dones if terminal OR orbservation jumps.
    dones_float = np.zeros_like(dataset['rewards'])

    imputed_next_observations = np.roll(dataset['observations'], -1, axis=0)
    same_obs = np.all(np.isclose(imputed_next_observations, dataset['next_observations'], atol=1e-5), axis=-1)
    dones_float = 1.0 - same_obs.astype(np.float32)
    dones_float += dataset['terminals']
    dones_float[-1] = 1.0
    dones_float = np.clip(dones_float, 0.0, 1.0)

    observations = dataset['observations'].astype(obs_dtype)
    next_observations = dataset['next_observations'].astype(obs_dtype)

    return Dataset.create(
        observations=observations,
        actions=dataset['actions'].astype(np.float32),
        rewards=dataset['rewards'].astype(np.float32),
        masks=masks.astype(np.float32),
        dones_float=dones_float.astype(np.float32),
        next_observations=next_observations,
    )

def get_normalization(dataset):
    returns = []
    ret = 0
    for r, term in zip(dataset['rewards'], dataset['dones_float']):
        ret += r
        if term:
            returns.append(ret)
            ret = 0
    return (max(returns) - min(returns)) / 1000

def normalize_dataset(env_name, dataset):
    print("Normalizing", env_name)
    if 'antmaze' in env_name or 'maze2d' in env_name:
        return dataset.copy({'rewards': dataset['rewards']- 1.0})
    else:
        normalizing_factor = get_normalization(dataset)
        print(f'Normalizing factor: {normalizing_factor}')
        dataset = dataset.copy({'rewards': dataset['rewards'] / normalizing_factor})
        return dataset

# Flattens environment with a dictionary of observation,goal to a single concatenated observation.
class GoalReachingFlat(gym.Wrapper):
    """A wrapper that maps actions from [-1,1] to [low, hgih]."""
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(self.observation_space['observation'].shape[0] + self.observation_space['goal'].shape[0],), dtype=np.float32)

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        ob_flat = np.concatenate([ob['observation'], ob['goal']])
        return ob_flat, reward, done, info

    def reset(self, **kwargs):
        ob = self.env.reset(**kwargs)
        ob_flat = np.concatenate([ob['observation'], ob['goal']])
        return ob_flat

def parse_trajectories(dataset):
    trajectory_ids = np.where(dataset['dones_float'] == 1)[0] + 1
    trajectory_ids = np.concatenate([[0], trajectory_ids])
    num_trajectories = trajectory_ids.shape[0] - 1
    print("There are {} trajectories. Some traj lens are {}".format(num_trajectories, [trajectory_ids[i + 1] - trajectory_ids[i] for i in range(min(5, num_trajectories))]))
    trajectories = []
    for i in range(len(trajectory_ids) - 1):
        trajectories.append(tree_util.tree_map(lambda arr: arr[trajectory_ids[i]:trajectory_ids[i + 1]], dataset._dict))
    return trajectories

class KitchenRenderWrapper(gym.Wrapper):
    def render(self, *args, **kwargs):
        from dm_control.mujoco import engine
        camera = engine.MovableCamera(self.sim, 1920, 2560)
        camera.set_pose(distance=2.2, lookat=[-0.2, .5, 2.], azimuth=70, elevation=-35)
        img = camera.render()
        return img

#fre/common/envs/dmc/__init__.py

import gym
from gym.envs.registration import register


def make(
        domain_name,
        task_name,
        seed=1,
        visualize_reward=True,
        from_pixels=False,
        height=84,
        width=84,
        camera_id=0,
        frame_skip=1,
        episode_length=1000,
        environment_kwargs=None,
        time_limit=None,
        channels_first=True
):
    env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed)

    if from_pixels:
        assert not visualize_reward, 'cannot use visualize reward when learning from pixels'

    # shorten episode length
    max_episode_steps = (episode_length + frame_skip - 1) // frame_skip

    if not env_id in gym.envs.registry.env_specs:
        task_kwargs = {}
        if seed is not None:
            task_kwargs['random'] = seed
        if time_limit is not None:
            task_kwargs['time_limit'] = time_limit
        register(
            id=env_id,
            # entry_point='dmc2gym.wrappers:DMCWrapper',
            entry_point='fre.common.envs.dmc.wrappers:DMCWrapper',
            kwargs=dict(
                domain_name=domain_name,
                task_name=task_name,
                task_kwargs=task_kwargs,
                environment_kwargs=environment_kwargs,
                visualize_reward=visualize_reward,
                from_pixels=from_pixels,
                height=height,
                width=width,
                camera_id=camera_id,
                frame_skip=frame_skip,
                channels_first=channels_first,
            ),
            max_episode_steps=max_episode_steps,
        )
    return gym.make(env_id)

#fre/common/envs/dmc/jaco
"""A task where the goal is to move the hand close to a target prop or site."""

import collections

from dm_control import composer
from dm_control.composer import initializers
from dm_control.composer.variation import distributions
from dm_control.entities import props
from dm_control.manipulation.shared import arenas
from dm_control.manipulation.shared import cameras
from dm_control.manipulation.shared import constants
from dm_control.manipulation.shared import observations
from dm_control.manipulation.shared import robots
from dm_control.manipulation.shared import workspaces
from dm_control.utils import rewards
from dm_env import specs
import numpy as np

_ReachWorkspace = collections.namedtuple(
    '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset'])

# Ensures that the props are not touching the table before settling.
_PROP_Z_OFFSET = 0.001

_DUPLO_WORKSPACE = _ReachWorkspace(
    target_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, _PROP_Z_OFFSET),
                                       upper=(0.1, 0.1, _PROP_Z_OFFSET)),
    tcp_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, 0.2),
                                    upper=(0.1, 0.1, 0.4)),
    arm_offset=robots.ARM_OFFSET)

_SITE_WORKSPACE = _ReachWorkspace(
    target_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02),
                                       upper=(0.2, 0.2, 0.4)),
    tcp_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02),
                                    upper=(0.2, 0.2, 0.4)),
    arm_offset=robots.ARM_OFFSET)

_TARGET_RADIUS = 0.05
_TIME_LIMIT = 10.

TASKS = [('reach_top_left', np.array([-0.09, 0.09, _PROP_Z_OFFSET])),
         ('reach_top_right', np.array([0.09, 0.09, _PROP_Z_OFFSET])),
         ('reach_bottom_left', np.array([-0.09, -0.09, _PROP_Z_OFFSET])),
         ('reach_bottom_right', np.array([0.09, -0.09, _PROP_Z_OFFSET]))]


def make(task_id, obs_type, seed):
    obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES
    task = _reach(task_id, obs_settings=obs_settings, use_site=True)
    return composer.Environment(task,
                                time_limit=_TIME_LIMIT,
                                random_state=seed)


class MultiTaskReach(composer.Task):
    """Bring the hand close to a target prop or site."""

    def __init__(self, task_id, arena, arm, hand, prop, obs_settings,
                 workspace, control_timestep):
        """Initializes a new `Reach` task.
    Args:
      arena: `composer.Entity` instance.
      arm: `robot_base.RobotArm` instance.
      hand: `robot_base.RobotHand` instance.
      prop: `composer.Entity` instance specifying the prop to reach to, or None
        in which case the target is a fixed site whose position is specified by
        the workspace.
      obs_settings: `observations.ObservationSettings` instance.
      workspace: `_ReachWorkspace` specifying the placement of the prop and TCP.
      control_timestep: Float specifying the control timestep in seconds.
    """
        self._arena = arena
        self._arm = arm
        self._hand = hand
        self._arm.attach(self._hand)
        self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
        self.control_timestep = control_timestep
        self._tcp_initializer = initializers.ToolCenterPointInitializer(
            self._hand,
            self._arm,
            position=distributions.Uniform(*workspace.tcp_bbox),
            quaternion=workspaces.DOWN_QUATERNION)

        # Add custom camera observable.
        self._task_observables = cameras.add_camera_observables(
            arena, obs_settings, cameras.FRONT_CLOSE)

        if task_id == 'reach_multitask':
            self._targets = [target for (_, target) in TASKS]
        else:
            self._targets = [
                target for (task, target) in TASKS if task == task_id
            ]
            assert len(self._targets) > 0

        #target_pos_distribution = distributions.Uniform(*TASKS[task_id])
        self._prop = prop
        if prop:
            # The prop itself is used to visualize the target location.
            self._make_target_site(parent_entity=prop, visible=False)
            self._target = self._arena.add_free_entity(prop)
            self._prop_placer = initializers.PropPlacer(
                props=[prop],
                position=target_pos_distribution,
                quaternion=workspaces.uniform_z_rotation,
                settle_physics=True)
        else:
            if len(self._targets) == 1:
                self._target = self._make_target_site(parent_entity=arena,
                                                      visible=True)

            #obs = observable.MJCFFeature('pos', self._target)
            # obs.configure(**obs_settings.prop_pose._asdict())
            #self._task_observables['target_position'] = obs

        # Add sites for visualizing the prop and target bounding boxes.
        workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody,
                                 lower=workspace.tcp_bbox.lower,
                                 upper=workspace.tcp_bbox.upper,
                                 rgba=constants.GREEN,
                                 name='tcp_spawn_area')
        workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody,
                                 lower=workspace.target_bbox.lower,
                                 upper=workspace.target_bbox.upper,
                                 rgba=constants.BLUE,
                                 name='target_spawn_area')

    def _make_target_site(self, parent_entity, visible):
        return workspaces.add_target_site(
            body=parent_entity.mjcf_model.worldbody,
            radius=_TARGET_RADIUS,
            visible=visible,
            rgba=constants.RED,
            name='target_site')

    @property
    def root_entity(self):
        return self._arena

    @property
    def arm(self):
        return self._arm

    @property
    def hand(self):
        return self._hand

    def get_reward_spec(self):
        n = len(self._targets)
        return specs.Array(shape=(n,), dtype=np.float32, name='reward')

    @property
    def task_observables(self):
        return self._task_observables

    def get_reward(self, physics):
        hand_pos = physics.bind(self._hand.tool_center_point).xpos
        rews = []
        for target_pos in self._targets:
            distance = np.linalg.norm(hand_pos - target_pos)
            reward = rewards.tolerance(distance,
                                       bounds=(0, _TARGET_RADIUS),
                                       margin=_TARGET_RADIUS)
            rews.append(reward)
        rews = np.array(rews).astype(np.float32)
        if len(self._targets) == 1:
            return rews[0]
        return rews

    def initialize_episode(self, physics, random_state):
        self._hand.set_grasp(physics, close_factors=random_state.uniform())
        self._tcp_initializer(physics, random_state)
        if self._prop:
            self._prop_placer(physics, random_state)
        else:
            if len(self._targets) == 1:
                physics.bind(self._target).pos = self._targets[0]


def _reach(task_id, obs_settings, use_site):
    """Configure and instantiate a `Reach` task.
  Args:
    obs_settings: An `observations.ObservationSettings` instance.
    use_site: Boolean, if True then the target will be a fixed site, otherwise
      it will be a moveable Duplo brick.
  Returns:
    An instance of `reach.Reach`.
  """
    arena = arenas.Standard()
    arm = robots.make_arm(obs_settings=obs_settings)
    hand = robots.make_hand(obs_settings=obs_settings)
    if use_site:
        workspace = _SITE_WORKSPACE
        prop = None
    else:
        workspace = _DUPLO_WORKSPACE
        prop = props.Duplo(observable_options=observations.make_options(
            obs_settings, observations.FREEPROP_OBSERVABLES))
    task = MultiTaskReach(task_id,
                          arena=arena,
                          arm=arm,
                          hand=hand,
                          prop=prop,
                          obs_settings=obs_settings,
                          workspace=workspace,
                          control_timestep=constants.CONTROL_TIMESTEP)
    return task

#fre/common/envs/dmc/wrappers
from gym import core, spaces
from dm_control import suite
from dm_env import specs
import numpy as np


def _spec_to_box(spec, dtype):
    def extract_min_max(s):
        assert s.dtype == np.float64 or s.dtype == np.float32
        dim = int(np.prod(s.shape))
        if type(s) == specs.Array:
            bound = np.inf * np.ones(dim, dtype=np.float32)
            return -bound, bound
        elif type(s) == specs.BoundedArray:
            zeros = np.zeros(dim, dtype=np.float32)
            return s.minimum + zeros, s.maximum + zeros

    mins, maxs = [], []
    for s in spec:
        mn, mx = extract_min_max(s)
        mins.append(mn)
        maxs.append(mx)
    low = np.concatenate(mins, axis=0).astype(dtype)
    high = np.concatenate(maxs, axis=0).astype(dtype)
    assert low.shape == high.shape
    return spaces.Box(low, high, dtype=dtype)


def _flatten_obs(obs):
    obs_pieces = []
    for v in obs.values():
        flat = np.array([v]) if np.isscalar(v) else v.ravel()
        obs_pieces.append(flat)
    return np.concatenate(obs_pieces, axis=0)


class DMCWrapper(core.Env):
    def __init__(
        self,
        domain_name,
        task_name,
        task_kwargs=None,
        visualize_reward={},
        from_pixels=False,
        height=84,
        width=84,
        camera_id=0,
        frame_skip=1,
        environment_kwargs=None,
        channels_first=True
    ):
        assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour'
        self._from_pixels = from_pixels
        self._height = height
        self._width = width
        self._camera_id = camera_id
        self._frame_skip = frame_skip
        self._channels_first = channels_first

        # create task
        if domain_name == 'jaco':
            #import fre.common.envs.dmc.jaco as jaco
            self._env = jaco.make(task_id=task_name, obs_type=jaco.observations.PERFECT_FEATURES, seed=1)
        else:
            self._env = suite.load(
                domain_name=domain_name,
                task_name=task_name,
                task_kwargs=task_kwargs,
                visualize_reward=visualize_reward,
                environment_kwargs=environment_kwargs
            )

        # true and normalized action spaces
        self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32)
        self._norm_action_space = spaces.Box(
            low=-1.0,
            high=1.0,
            shape=self._true_action_space.shape,
            dtype=np.float32
        )

        # create observation space
        if from_pixels:
            shape = [3, height, width] if channels_first else [height, width, 3]
            self._observation_space = spaces.Box(
                low=0, high=255, shape=shape, dtype=np.uint8
            )
        else:
            self._observation_space = _spec_to_box(
                self._env.observation_spec().values(),
                np.float64
            )

        self._state_space = _spec_to_box(
            self._env.observation_spec().values(),
            np.float64
        )

        self.current_state = None

        # set seed
        self.seed(seed=task_kwargs.get('random', 1))

    def __getattr__(self, name):
        return getattr(self._env, name)

    def _get_obs(self, time_step):
        if self._from_pixels:
            obs = self.render(
                height=self._height,
                width=self._width,
                camera_id=self._camera_id
            )
            if self._channels_first:
                obs = obs.transpose(2, 0, 1).copy()
        else:
            obs = _flatten_obs(time_step.observation)

        return obs

    def _convert_action(self, action):
        action = action.astype(np.float32)
        true_delta = self._true_action_space.high - self._true_action_space.low
        norm_delta = self._norm_action_space.high - self._norm_action_space.low
        action = (action - self._norm_action_space.low) / norm_delta
        action = action * true_delta + self._true_action_space.low
        action = action.astype(np.float32)
        return action

    @property
    def observation_space(self):
        return self._observation_space

    @property
    def state_space(self):
        return self._state_space

    @property
    def action_space(self):
        return self._norm_action_space

    @property
    def reward_range(self):
        return 0, self._frame_skip

    def seed(self, seed):
        self._true_action_space.seed(seed)
        self._norm_action_space.seed(seed)
        self._observation_space.seed(seed)

    def step(self, action):
        assert self._norm_action_space.contains(action)
        action = self._convert_action(action)
        assert self._true_action_space.contains(action)
        reward = 0
        extra = {'internal_state': self._env.physics.get_state().copy()}

        for _ in range(self._frame_skip):
            time_step = self._env.step(action)
            reward += time_step.reward or 0
            done = time_step.last()
            if done:
                break
        obs = self._get_obs(time_step)
        self.current_state = _flatten_obs(time_step.observation)
        extra['discount'] = time_step.discount
        return obs, reward, done, extra

    def reset(self):
        time_step = self._env.reset()
        self.current_state = _flatten_obs(time_step.observation)
        obs = self._get_obs(time_step)
        return obs

    def render(self, mode='rgb_array', height=None, width=None, camera_id=0):
        assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
        height = height or self._height
        width = width or self._width
        camera_id = camera_id or self._camera_id
        return self._env.physics.render(
            height=height, width=width, camera_id=camera_id
        )

#fre/common/envs/exorl/custom_dmc_tasks/__init__.py

import typing as tp
"""from . import cheetah
from . import walker
from . import hopper
from . import quadruped
from . import jaco

"""

def make(domain, task,
         task_kwargs=None,
         environment_kwargs=None,
         visualize_reward: bool = False):

    if domain == 'cheetah':
        return cheetah.make(task,
                            task_kwargs=task_kwargs,
                            environment_kwargs=environment_kwargs,
                            visualize_reward=visualize_reward)
    elif domain == 'walker':
        return walker.make(task,
                           task_kwargs=task_kwargs,
                           environment_kwargs=environment_kwargs,
                           visualize_reward=visualize_reward)
    elif domain == 'hopper':
        return hopper.make(task,
                           task_kwargs=task_kwargs,
                           environment_kwargs=environment_kwargs,
                           visualize_reward=visualize_reward)
    elif domain == 'quadruped':
        return quadruped.make(task,
                              task_kwargs=task_kwargs,
                              environment_kwargs=environment_kwargs,
                              visualize_reward=visualize_reward)
    elif domain == 'point_mass_maze':
        return point_mass_maze.make(task,
                                    task_kwargs=task_kwargs,
                                    environment_kwargs=environment_kwargs,
                                    visualize_reward=visualize_reward)

    else:
        raise ValueError(f'{task} not found')

    assert None


def make_jaco(task, obs_type, seed) -> tp.Any:
    return jaco.make(task, obs_type, seed)

#fre\common\envs\exorl\custom_dmc_tasks\cheetah
"""Cheetah Domain."""

import collections
import os
import typing as tp
from typing import Any, Tuple

from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
from dm_control.utils import io as resources

_DEFAULT_TIME_LIMIT: int
_RUN_SPEED: int
_SPIN_SPEED: int

# How long the simulation will run, in seconds.
_DEFAULT_TIME_LIMIT = 10

# Running speed above which reward is 1.
_RUN_SPEED = 10
_WALK_SPEED = 2
_SPIN_SPEED = 5

SUITE = containers.TaggedTasks()


def make(task,
         task_kwargs=None,
         environment_kwargs=None,
         visualize_reward: bool = False):
    task_kwargs = task_kwargs or {}
    if environment_kwargs is not None:
        task_kwargs = task_kwargs.copy()
        task_kwargs['environment_kwargs'] = environment_kwargs
    env = SUITE[task](**task_kwargs)
    env.task.visualize_reward = visualize_reward
    return env


def get_model_and_assets() -> Tuple[Any, Any]:
    """Returns a tuple containing the model XML string and a dict of assets."""
    root_dir = os.path.dirname(os.path.dirname(__file__))
    xml = resources.GetResource(
        os.path.join(root_dir, 'custom_dmc_tasks', 'cheetah.xml'))
    return xml, common.ASSETS


@SUITE.add('benchmarking')
def walk(time_limit: int = _DEFAULT_TIME_LIMIT,
         random=None,
         environment_kwargs=None):
    """Returns the run task."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Cheetah(move_speed=_WALK_SPEED, forward=True, flip=False, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics,
                               task,
                               time_limit=time_limit,
                               **environment_kwargs)


@SUITE.add('benchmarking')
def walk_backward(time_limit: int = _DEFAULT_TIME_LIMIT,
                  random=None,
                  environment_kwargs=None):
    """Returns the run task."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Cheetah(move_speed=_WALK_SPEED, forward=False, flip=False, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics,
                               task,
                               time_limit=time_limit,
                               **environment_kwargs)


@SUITE.add('benchmarking')
def run_backward(time_limit: int = _DEFAULT_TIME_LIMIT,
                 random=None,
                 environment_kwargs=None):
    """Returns the run task."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Cheetah(forward=False, flip=False, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics,
                               task,
                               time_limit=time_limit,
                               **environment_kwargs)


@SUITE.add('benchmarking')
def flip(time_limit: int = _DEFAULT_TIME_LIMIT,
         random=None,
         environment_kwargs=None):
    """Returns the run task."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Cheetah(move_speed=_WALK_SPEED, forward=True, flip=True, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics,
                               task,
                               time_limit=time_limit,
                               **environment_kwargs)


@SUITE.add('benchmarking')
def flip_backward(time_limit: int = _DEFAULT_TIME_LIMIT,
                  random=None,
                  environment_kwargs=None):
    """Returns the run task."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Cheetah(move_speed=_WALK_SPEED, forward=False, flip=True, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics,
                               task,
                               time_limit=time_limit,
                               **environment_kwargs)


class Physics(mujoco.Physics):
    """Physics simulation with additional features for the Cheetah domain."""

    def speed(self) -> Any:
        """Returns the horizontal speed of the Cheetah."""
        return self.named.data.sensordata['torso_subtreelinvel'][0]

    def angmomentum(self) -> Any:
        """Returns the angular momentum of torso of the Cheetah about Y axis."""
        return self.named.data.subtree_angmom['torso'][1]


class Cheetah(base.Task):
    """A `Task` to train a running Cheetah."""

    def __init__(self, move_speed=_RUN_SPEED, forward=True, flip=False, random=None) -> None:
        self._move_speed = move_speed
        self._forward = 1 if forward else -1
        self._flip = flip
        super(Cheetah, self).__init__(random=random)
        self._timeout_progress = 0

    def initialize_episode(self, physics) -> None:
        """Sets the state of the environment at the start of each episode."""
        # The indexing below assumes that all joints have a single DOF.
        assert physics.model.nq == physics.model.njnt
        is_limited = physics.model.jnt_limited == 1
        lower, upper = physics.model.jnt_range[is_limited].T
        physics.data.qpos[is_limited] = self.random.uniform(lower, upper)

        # Stabilize the model before the actual simulation.
        for _ in range(200):
            physics.step()

        physics.data.time = 0
        self._timeout_progress = 0
        super().initialize_episode(physics)

    def get_observation(self, physics) -> tp.Dict[str, Any]:
        """Returns an observation of the state, ignoring horizontal position."""
        obs = collections.OrderedDict()
        # Ignores horizontal position to maintain translational invariance.
        obs['position'] = physics.data.qpos[1:].copy()
        obs['velocity'] = physics.velocity()
        return obs

    def get_reward(self, physics) -> Any:
        """Returns a reward to the agent."""
        if self._flip:
            reward = rewards.tolerance(self._forward * physics.angmomentum(),
                                       bounds=(_SPIN_SPEED, float('inf')),
                                       margin=_SPIN_SPEED,
                                       value_at_margin=0,
                                       sigmoid='linear')

        else:
            reward = rewards.tolerance(self._forward * physics.speed(),
                                       bounds=(self._move_speed, float('inf')),
                                       margin=self._move_speed,
                                       value_at_margin=0,
                                       sigmoid='linear')
        return reward

#fre\common\envs\exorl\custom_dmc_tasks\hopper
"""Hopper domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import typing as tp
from typing import Any, Tuple

import numpy as np
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
from dm_control.utils import io as resources

_CONTROL_TIMESTEP: float
_DEFAULT_TIME_LIMIT: int
_HOP_SPEED: int
_SPIN_SPEED: int
_STAND_HEIGHT: float

SUITE = containers.TaggedTasks()

_CONTROL_TIMESTEP = .02  # (Seconds)

# Default duration of an episode, in seconds.
_DEFAULT_TIME_LIMIT = 20

# Minimal height of torso over foot above which stand reward is 1.
_STAND_HEIGHT = 0.6

# Hopping speed above which hop reward is 1.
_HOP_SPEED = 2
_SPIN_SPEED = 5


def make(task,
         task_kwargs=None,
         environment_kwargs=None,
         visualize_reward: bool = False):
    task_kwargs = task_kwargs or {}
    if environment_kwargs is not None:
        task_kwargs = task_kwargs.copy()
        task_kwargs['environment_kwargs'] = environment_kwargs
    env = SUITE[task](**task_kwargs)
    env.task.visualize_reward = visualize_reward
    return env


def get_model_and_assets() -> Tuple[Any, Any]:
    """Returns a tuple containing the model XML string and a dict of assets."""
    root_dir = os.path.dirname(os.path.dirname(__file__))
    xml = resources.GetResource(
        os.path.join(root_dir, 'custom_dmc_tasks', 'hopper.xml'))
    return xml, common.ASSETS


@SUITE.add('benchmarking')
def hop_backward(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns a Hopper that strives to hop forward."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Hopper(hopping=True, forward=False, flip=False, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics,
                               task,
                               time_limit=time_limit,
                               control_timestep=_CONTROL_TIMESTEP,
                               **environment_kwargs)


@SUITE.add('benchmarking')
def flip(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns a Hopper that strives to hop forward."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Hopper(hopping=True, forward=True, flip=True, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics,
                               task,
                               time_limit=time_limit,
                               control_timestep=_CONTROL_TIMESTEP,
                               **environment_kwargs)


@SUITE.add('benchmarking')
def flip_backward(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns a Hopper that strives to hop forward."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = Hopper(hopping=True, forward=False, flip=True, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics,
                               task,
                               time_limit=time_limit,
                               control_timestep=_CONTROL_TIMESTEP,
                               **environment_kwargs)


class Physics(mujoco.Physics):
    """Physics simulation with additional features for the Hopper domain."""

    def height(self) -> Any:
        """Returns height of torso with respect to foot."""
        return (self.named.data.xipos['torso', 'z'] -
                self.named.data.xipos['foot', 'z'])

    def speed(self) -> Any:
        """Returns horizontal speed of the Hopper."""
        return self.named.data.sensordata['torso_subtreelinvel'][0]

    def touch(self) -> Any:
        """Returns the signals from two foot touch sensors."""
        return np.log1p(self.named.data.sensordata[['touch_toe',
                                                    'touch_heel']])

    def angmomentum(self) -> Any:
        """Returns the angular momentum of torso of the Cheetah about Y axis."""
        return self.named.data.subtree_angmom['torso'][1]


class Hopper(base.Task):
    """A Hopper's `Task` to train a standing and a jumping Hopper."""

    def __init__(self, hopping, forward=True, flip=False, random=None) -> None:
        """Initialize an instance of `Hopper`.

    Args:
      hopping: Boolean, if True the task is to hop forwards, otherwise it is to
        balance upright.
      random: Optional, either a `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically (default).
    """
        self._hopping = hopping
        self._forward = 1 if forward else -1
        self._flip = flip
        self._timeout_progress = 0
        super(Hopper, self).__init__(random=random)

    def initialize_episode(self, physics) -> None:
        """Sets the state of the environment at the start of each episode."""
        randomizers.randomize_limited_and_rotational_joints(
            physics, self.random)
        self._timeout_progress = 0
        super(Hopper, self).initialize_episode(physics)

    def get_observation(self, physics) -> tp.Dict[str, Any]:
        """Returns an observation of positions, velocities and touch sensors."""
        obs = collections.OrderedDict()
        # Ignores horizontal position to maintain translational invariance:
        obs['position'] = physics.data.qpos[1:].copy()
        obs['velocity'] = physics.velocity()
        obs['touch'] = physics.touch()
        return obs

    def get_reward(self, physics) -> Any:
        """Returns a reward applicable to the performed task."""
        standing = rewards.tolerance(physics.height(), (_STAND_HEIGHT, 2))
        assert self._hopping
        if self._flip:
            hopping = rewards.tolerance(self._forward * physics.angmomentum(),
                                        bounds=(_SPIN_SPEED, float('inf')),
                                        margin=_SPIN_SPEED,
                                        value_at_margin=0,
                                        sigmoid='linear')
        else:
            hopping = rewards.tolerance(self._forward * physics.speed(),
                                        bounds=(_HOP_SPEED, float('inf')),
                                        margin=_HOP_SPEED / 2,
                                        value_at_margin=0.5,
                                        sigmoid='linear')
        return standing * hopping


#fre\common\envs\exorl\custom_dmc_tasks\jaco
"""A task where the goal is to move the hand close to a target prop or site."""

import collections

from dm_control import composer
from dm_control.composer import initializers
from dm_control.composer.variation import distributions
from dm_control.entities import props
from dm_control.manipulation.shared import arenas
from dm_control.manipulation.shared import cameras
from dm_control.manipulation.shared import constants
from dm_control.manipulation.shared import observations
from dm_control.manipulation.shared import robots
from dm_control.manipulation.shared import workspaces
from dm_control.utils import rewards
from dm_env import specs
import numpy as np

_ReachWorkspace = collections.namedtuple(
    '_ReachWorkspace', ['target_bbox', 'tcp_bbox', 'arm_offset'])

# Ensures that the props are not touching the table before settling.
_PROP_Z_OFFSET = 0.001

_DUPLO_WORKSPACE = _ReachWorkspace(
    target_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, _PROP_Z_OFFSET),
                                       upper=(0.1, 0.1, _PROP_Z_OFFSET)),
    tcp_bbox=workspaces.BoundingBox(lower=(-0.1, -0.1, 0.2),
                                    upper=(0.1, 0.1, 0.4)),
    arm_offset=robots.ARM_OFFSET)

_SITE_WORKSPACE = _ReachWorkspace(
    target_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02),
                                       upper=(0.2, 0.2, 0.4)),
    tcp_bbox=workspaces.BoundingBox(lower=(-0.2, -0.2, 0.02),
                                    upper=(0.2, 0.2, 0.4)),
    arm_offset=robots.ARM_OFFSET)

_TARGET_RADIUS = 0.05
_TIME_LIMIT = 10.

TASKS = [('reach_top_left', np.array([-0.09, 0.09, _PROP_Z_OFFSET])),
         ('reach_top_right', np.array([0.09, 0.09, _PROP_Z_OFFSET])),
         ('reach_bottom_left', np.array([-0.09, -0.09, _PROP_Z_OFFSET])),
         ('reach_bottom_right', np.array([0.09, -0.09, _PROP_Z_OFFSET]))]


def make(task_id, obs_type, seed):
    obs_settings = observations.VISION if obs_type == 'pixels' else observations.PERFECT_FEATURES
    task = _reach(task_id, obs_settings=obs_settings, use_site=True)
    return composer.Environment(task,
                                time_limit=_TIME_LIMIT,
                                random_state=seed)


class MultiTaskReach(composer.Task):
    """Bring the hand close to a target prop or site."""

    def __init__(self, task_id, arena, arm, hand, prop, obs_settings,
                 workspace, control_timestep):
        """Initializes a new `Reach` task.

    Args:
      arena: `composer.Entity` instance.
      arm: `robot_base.RobotArm` instance.
      hand: `robot_base.RobotHand` instance.
      prop: `composer.Entity` instance specifying the prop to reach to, or None
        in which case the target is a fixed site whose position is specified by
        the workspace.
      obs_settings: `observations.ObservationSettings` instance.
      workspace: `_ReachWorkspace` specifying the placement of the prop and TCP.
      control_timestep: Float specifying the control timestep in seconds.
    """
        self._arena = arena
        self._arm = arm
        self._hand = hand
        self._arm.attach(self._hand)
        self._arena.attach_offset(self._arm, offset=workspace.arm_offset)
        self.control_timestep = control_timestep
        self._tcp_initializer = initializers.ToolCenterPointInitializer(
            self._hand,
            self._arm,
            position=distributions.Uniform(*workspace.tcp_bbox),
            quaternion=workspaces.DOWN_QUATERNION)

        # Add custom camera observable.
        self._task_observables = cameras.add_camera_observables(
            arena, obs_settings, cameras.FRONT_CLOSE)

        if task_id == 'reach_multitask':
            self._targets = [target for (_, target) in TASKS]
        else:
            self._targets = [
                target for (task, target) in TASKS if task == task_id
            ]
            assert len(self._targets) > 0

        #target_pos_distribution = distributions.Uniform(*TASKS[task_id])
        self._prop = prop
        if prop:
            # The prop itself is used to visualize the target location.
            self._make_target_site(parent_entity=prop, visible=False)
            self._target = self._arena.add_free_entity(prop)
            self._prop_placer = initializers.PropPlacer(
                props=[prop],
                position=target_pos_distribution,
                quaternion=workspaces.uniform_z_rotation,
                settle_physics=True)
        else:
            if len(self._targets) == 1:
                self._target = self._make_target_site(parent_entity=arena,
                                                      visible=True)

            #obs = observable.MJCFFeature('pos', self._target)
            # obs.configure(**obs_settings.prop_pose._asdict())
            #self._task_observables['target_position'] = obs

        # Add sites for visualizing the prop and target bounding boxes.
        workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody,
                                 lower=workspace.tcp_bbox.lower,
                                 upper=workspace.tcp_bbox.upper,
                                 rgba=constants.GREEN,
                                 name='tcp_spawn_area')
        workspaces.add_bbox_site(body=self.root_entity.mjcf_model.worldbody,
                                 lower=workspace.target_bbox.lower,
                                 upper=workspace.target_bbox.upper,
                                 rgba=constants.BLUE,
                                 name='target_spawn_area')

    def _make_target_site(self, parent_entity, visible):
        return workspaces.add_target_site(
            body=parent_entity.mjcf_model.worldbody,
            radius=_TARGET_RADIUS,
            visible=visible,
            rgba=constants.RED,
            name='target_site')

    @property
    def root_entity(self):
        return self._arena

    @property
    def arm(self):
        return self._arm

    @property
    def hand(self):
        return self._hand

    def get_reward_spec(self):
        n = len(self._targets)
        return specs.Array(shape=(n,), dtype=np.float32, name='reward')

    @property
    def task_observables(self):
        return self._task_observables

    def get_reward(self, physics):
        hand_pos = physics.bind(self._hand.tool_center_point).xpos
        rews = []
        for target_pos in self._targets:
            distance = np.linalg.norm(hand_pos - target_pos)
            reward = rewards.tolerance(distance,
                                       bounds=(0, _TARGET_RADIUS),
                                       margin=_TARGET_RADIUS)
            rews.append(reward)
        rews = np.array(rews).astype(np.float32)
        if len(self._targets) == 1:
            return rews[0]
        return rews

    def initialize_episode(self, physics, random_state):
        self._hand.set_grasp(physics, close_factors=random_state.uniform())
        self._tcp_initializer(physics, random_state)
        if self._prop:
            self._prop_placer(physics, random_state)
        else:
            if len(self._targets) == 1:
                physics.bind(self._target).pos = self._targets[0]


def _reach(task_id, obs_settings, use_site):
    """Configure and instantiate a `Reach` task.

  Args:
    obs_settings: An `observations.ObservationSettings` instance.
    use_site: Boolean, if True then the target will be a fixed site, otherwise
      it will be a moveable Duplo brick.

  Returns:
    An instance of `reach.Reach`.
  """
    arena = arenas.Standard()
    arm = robots.make_arm(obs_settings=obs_settings)
    hand = robots.make_hand(obs_settings=obs_settings)
    if use_site:
        workspace = _SITE_WORKSPACE
        prop = None
    else:
        workspace = _DUPLO_WORKSPACE
        prop = props.Duplo(observable_options=observations.make_options(
            obs_settings, observations.FREEPROP_OBSERVABLES))
    task = MultiTaskReach(task_id,
                          arena=arena,
                          arm=arm,
                          hand=hand,
                          prop=prop,
                          obs_settings=obs_settings,
                          workspace=workspace,
                          control_timestep=constants.CONTROL_TIMESTEP)
    return task

#fre/common/envs/exorl/custom_dmc_tasks/quadrupled
"""Quadruped Domain."""

import collections
import typing as tp
from typing import Any
import os

from dm_control import mujoco
from dm_control.mujoco.wrapper import mjbindings
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
from dm_control.utils import xml_tools
from lxml import etree
import numpy as np
from scipy import ndimage


enums = mjbindings.enums
mjlib = mjbindings.mjlib


_DEFAULT_TIME_LIMIT = 20
_CONTROL_TIMESTEP = .02

# Horizontal speeds above which the move reward is 1.
_RUN_SPEED = 5
_WALK_SPEED = 0.5

_JUMP_HEIGHT = 1.0

# Constants related to terrain generation.
_HEIGHTFIELD_ID = 0
_TERRAIN_SMOOTHNESS = 0.15  # 0.0: maximally bumpy; 1.0: completely smooth.
_TERRAIN_BUMP_SCALE = 2  # Spatial scale of terrain bumps (in meters).

# Named model elements.
_TOES = ['toe_front_left', 'toe_back_left', 'toe_back_right', 'toe_front_right']
_WALLS = ['wall_px', 'wall_py', 'wall_nx', 'wall_ny']

SUITE = containers.TaggedTasks()


def make(task,
         task_kwargs=None,
         environment_kwargs=None,
         visualize_reward: bool = False):
    task_kwargs = task_kwargs or {}
    if environment_kwargs is not None:
        task_kwargs = task_kwargs.copy()
        task_kwargs['environment_kwargs'] = environment_kwargs
    env = SUITE[task](**task_kwargs)
    env.task.visualize_reward = visualize_reward
    return env


# REMOVED since resources is undefined
# def get_model_and_assets() -> Tuple[Any, Any]:
#     """Returns a tuple containing the model XML string and a dict of assets."""
#     root_dir = os.path.dirname(os.path.dirname(__file__))
#     xml = resources.GetResource(
#         os.path.join(root_dir, 'custom_dmc_tasks', 'quadruped.xml'))
#     return xml, common.ASSETS


def make_model(floor_size=None, terrain: bool = False, rangefinders: bool = False,
               walls_and_ball: bool = False):
    """Returns the model XML string."""
    root_dir = os.path.dirname(os.path.dirname(__file__))
    xml_string = common.read_model(os.path.join(root_dir, 'custom_dmc_tasks', 'quadruped.xml'))
    parser = etree.XMLParser(remove_blank_text=True)
    mjcf = etree.XML(xml_string, parser)

    # Set floor size.
    if floor_size is not None:
        floor_geom = mjcf.find('.//geom[@name=\'floor\']')
        floor_geom.attrib['size'] = f'{floor_size} {floor_size} .5'

    # Remove walls, ball and target.
    if not walls_and_ball:
        for wall in _WALLS:
            wall_geom = xml_tools.find_element(mjcf, 'geom', wall)
            wall_geom.getparent().remove(wall_geom)

        # Remove ball.
        ball_body = xml_tools.find_element(mjcf, 'body', 'ball')
        ball_body.getparent().remove(ball_body)

        # Remove target.
        target_site = xml_tools.find_element(mjcf, 'site', 'target')
        target_site.getparent().remove(target_site)

    # Remove terrain.
    if not terrain:
        terrain_geom = xml_tools.find_element(mjcf, 'geom', 'terrain')
        terrain_geom.getparent().remove(terrain_geom)

    # Remove rangefinders if they're not used, as range computations can be
    # expensive, especially in a scene with heightfields.
    if not rangefinders:
        rangefinder_sensors = mjcf.findall('.//rangefinder')
        for rf in rangefinder_sensors:
            rf.getparent().remove(rf)

    return etree.tostring(mjcf, pretty_print=True)


@SUITE.add()
def stand(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Walk task."""
    xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
    physics = Physics.from_xml_string(xml_string, common.ASSETS)
    task = Stand(random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics, task, time_limit=time_limit,
                               control_timestep=_CONTROL_TIMESTEP,
                               **environment_kwargs)


@SUITE.add()
def jump(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Walk task."""
    xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
    physics = Physics.from_xml_string(xml_string, common.ASSETS)
    task = Jump(desired_height=_JUMP_HEIGHT, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics, task, time_limit=time_limit,
                               control_timestep=_CONTROL_TIMESTEP,
                               **environment_kwargs)


@SUITE.add()
def roll(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Walk task."""
    xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
    physics = Physics.from_xml_string(xml_string, common.ASSETS)
    task = Roll(desired_speed=_WALK_SPEED, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics, task, time_limit=time_limit,
                               control_timestep=_CONTROL_TIMESTEP,
                               **environment_kwargs)


@SUITE.add()
def roll_fast(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Walk task."""
    xml_string = make_model(floor_size=_DEFAULT_TIME_LIMIT * _WALK_SPEED)
    physics = Physics.from_xml_string(xml_string, common.ASSETS)
    task = Roll(desired_speed=_RUN_SPEED, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics, task, time_limit=time_limit,
                               control_timestep=_CONTROL_TIMESTEP,
                               **environment_kwargs)


@SUITE.add()
def escape(time_limit: int = _DEFAULT_TIME_LIMIT, random=None,
           environment_kwargs=None):
    """Returns the Escape task."""
    xml_string = make_model(floor_size=40, terrain=True, rangefinders=True)
    physics = Physics.from_xml_string(xml_string, common.ASSETS)
    task = Escape(random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics, task, time_limit=time_limit,
                               control_timestep=_CONTROL_TIMESTEP,
                               **environment_kwargs)


@SUITE.add()
def fetch(time_limit: int = _DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the Fetch task."""
    xml_string = make_model(walls_and_ball=True)
    physics = Physics.from_xml_string(xml_string, common.ASSETS)
    task = Fetch(random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics, task, time_limit=time_limit,
                               control_timestep=_CONTROL_TIMESTEP,
                               **environment_kwargs)


# pylint: disable=attribute-defined-outside-init
class Physics(mujoco.Physics):
    """Physics simulation with additional features for the Quadruped domain."""

    def _reload_from_data(self, data) -> None:
        super()._reload_from_data(data)
        # Clear cached sensor names when the physics is reloaded.
        self._sensor_types_to_names: tp.Dict[tp.Tuple[tp.Any, ...], tp.List[str]] = {}
        self._hinge_names: tp.List[str] = []

    def _get_sensor_names(self, *sensor_types) -> Any:
        try:
            sensor_names = self._sensor_types_to_names[sensor_types]
        except KeyError:
            [sensor_ids] = np.where(np.in1d(self.model.sensor_type, sensor_types))
            sensor_names = [self.model.id2name(s_id, 'sensor') for s_id in sensor_ids]
            self._sensor_types_to_names[sensor_types] = sensor_names
        return sensor_names

    def torso_upright(self) -> np.ndarray:
        """Returns the dot-product of the torso z-axis and the global z-axis."""
        return np.asarray(self.named.data.xmat['torso', 'zz'])

    def torso_velocity(self) -> Any:
        """Returns the velocity of the torso, in the local frame."""
        return self.named.data.sensordata['velocimeter'].copy()

    def com_height(self) -> Any:
        return self.named.data.sensordata['center_of_mass'].copy()[2]

    def egocentric_state(self) -> Any:
        """Returns the state without global orientation or position."""
        if not self._hinge_names:
            [hinge_ids] = np.nonzero(self.model.jnt_type ==
                                     enums.mjtJoint.mjJNT_HINGE)
            self._hinge_names = [self.model.id2name(j_id, 'joint')
                                 for j_id in hinge_ids]
        return np.hstack((self.named.data.qpos[self._hinge_names],
                          self.named.data.qvel[self._hinge_names],
                          self.data.act))

    def toe_positions(self) -> Any:
        """Returns toe positions in egocentric frame."""
        torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
        torso_pos = self.named.data.xpos['torso']
        torso_to_toe = self.named.data.xpos[_TOES] - torso_pos
        return torso_to_toe.dot(torso_frame)

    def force_torque(self) -> Any:
        """Returns scaled force/torque sensor readings at the toes."""
        force_torque_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_FORCE,
                                                      enums.mjtSensor.mjSENS_TORQUE)
        return np.arcsinh(self.named.data.sensordata[force_torque_sensors])

    def imu(self) -> Any:
        """Returns IMU-like sensor readings."""
        imu_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_GYRO,
                                             enums.mjtSensor.mjSENS_ACCELEROMETER)
        return self.named.data.sensordata[imu_sensors]

    def rangefinder(self) -> Any:
        """Returns scaled rangefinder sensor readings."""
        rf_sensors = self._get_sensor_names(enums.mjtSensor.mjSENS_RANGEFINDER)
        rf_readings = self.named.data.sensordata[rf_sensors]
        no_intersection = -1.0
        return np.where(rf_readings == no_intersection, 1.0, np.tanh(rf_readings))

    def origin_distance(self) -> np.ndarray:
        """Returns the distance from the origin to the workspace."""
        return np.asarray(np.linalg.norm(self.named.data.site_xpos['workspace']))

    def origin(self) -> Any:
        """Returns origin position in the torso frame."""
        torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
        torso_pos = self.named.data.xpos['torso']
        return -torso_pos.dot(torso_frame)

    def ball_state(self) -> Any:
        """Returns ball position and velocity relative to the torso frame."""
        data = self.named.data
        torso_frame = data.xmat['torso'].reshape(3, 3)
        ball_rel_pos = data.xpos['ball'] - data.xpos['torso']
        ball_rel_vel = data.qvel['ball_root'][:3] - data.qvel['root'][:3]
        ball_rot_vel = data.qvel['ball_root'][3:]
        ball_state = np.vstack((ball_rel_pos, ball_rel_vel, ball_rot_vel))
        return ball_state.dot(torso_frame).ravel()

    def target_position(self) -> Any:
        """Returns target position in torso frame."""
        torso_frame = self.named.data.xmat['torso'].reshape(3, 3)
        torso_pos = self.named.data.xpos['torso']
        torso_to_target = self.named.data.site_xpos['target'] - torso_pos
        return torso_to_target.dot(torso_frame)

    def ball_to_target_distance(self) -> Any:
        """Returns horizontal distance from the ball to the target."""
        ball_to_target = (self.named.data.site_xpos['target'] -
                          self.named.data.xpos['ball'])
        return np.linalg.norm(ball_to_target[:2])

    def self_to_ball_distance(self) -> Any:
        """Returns horizontal distance from the quadruped workspace to the ball."""
        self_to_ball = (self.named.data.site_xpos['workspace']
                        - self.named.data.xpos['ball'])
        return np.linalg.norm(self_to_ball[:2])


def _find_non_contacting_height(physics, orientation, x_pos: float = 0.0, y_pos: float = 0.0) -> None:
    """Find a height with no contacts given a body orientation.
    Args:
      physics: An instance of `Physics`.
      orientation: A quaternion.
      x_pos: A float. Position along global x-axis.
      y_pos: A float. Position along global y-axis.
    Raises:
      RuntimeError: If a non-contacting configuration has not been found after
      10,000 attempts.
    """
    z_pos = 0.0  # Start embedded in the floor.
    num_contacts = 1
    num_attempts = 0
    # Move up in 1cm increments until no contacts.
    while num_contacts > 0:
        try:
            with physics.reset_context():
                physics.named.data.qpos['root'][:3] = x_pos, y_pos, z_pos
                physics.named.data.qpos['root'][3:] = orientation
        except control.PhysicsError:
            # We may encounter a PhysicsError here due to filling the contact
            # buffer, in which case we simply increment the height and continue.
            pass
        num_contacts = physics.data.ncon
        z_pos += 0.01
        num_attempts += 1
        if num_attempts > 10000:
            raise RuntimeError('Failed to find a non-contacting configuration.')


def _common_observations(physics) -> tp.Dict[str, Any]:
    """Returns the observations common to all tasks."""
    obs = collections.OrderedDict()
    obs['egocentric_state'] = physics.egocentric_state()
    obs['torso_velocity'] = physics.torso_velocity()
    obs['torso_upright'] = physics.torso_upright()
    obs['imu'] = physics.imu()
    obs['force_torque'] = physics.force_torque()
    return obs


def _upright_reward(physics, deviation_angle: int = 0):
    """Returns a reward proportional to how upright the torso is.
    Args:
      physics: an instance of `Physics`.
      deviation_angle: A float, in degrees. The reward is 0 when the torso is
        exactly upside-down and 1 when the torso's z-axis is less than
        `deviation_angle` away from the global z-axis.
    """
    deviation = np.cos(np.deg2rad(deviation_angle))
    return rewards.tolerance(
        physics.torso_upright(),
        bounds=(deviation, float('inf')),
        sigmoid='linear',
        margin=1 + deviation,
        value_at_margin=0)


class Move(base.Task):
    """A quadruped task solved by moving forward at a designated speed."""

    def __init__(self, desired_speed, random=None) -> None:
        """Initializes an instance of `Move`.
        Args:
          desired_speed: A float. If this value is zero, reward is given simply
            for standing upright. Otherwise this specifies the horizontal velocity
            at which the velocity-dependent reward component is maximized.
          random: Optional, either a `numpy.random.RandomState` instance, an
            integer seed for creating a new `RandomState`, or None to select a seed
            automatically (default).
        """
        self._desired_speed = desired_speed
        super().__init__(random=random)

    def initialize_episode(self, physics) -> None:
        """Sets the state of the environment at the start of each episode.
        Args:
          physics: An instance of `Physics`.
        """
        # Initial configuration.
        orientation = self.random.randn(4)
        orientation /= np.linalg.norm(orientation)
        _find_non_contacting_height(physics, orientation)
        super().initialize_episode(physics)

    def get_observation(self, physics) -> tp.Dict[str, Any]:
        """Returns an observation to the agent."""
        return _common_observations(physics)

    def get_reward(self, physics) -> Any:
        """Returns a reward to the agent."""

        # Move reward term.
        move_reward = rewards.tolerance(
            physics.torso_velocity()[0],
            bounds=(self._desired_speed, float('inf')),
            margin=self._desired_speed,
            value_at_margin=0.5,
            sigmoid='linear')

        return _upright_reward(physics) * move_reward


class Stand(base.Task):
    """A quadruped task solved by moving forward at a designated speed."""

    def __init__(self, random=None) -> None:
        """Initializes an instance of `Move`.
        Args:
          desired_speed: A float. If this value is zero, reward is given simply
            for standing upright. Otherwise this specifies the horizontal velocity
            at which the velocity-dependent reward component is maximized.
          random: Optional, either a `numpy.random.RandomState` instance, an
            integer seed for creating a new `RandomState`, or None to select a seed
            automatically (default).
        """
        super().__init__(random=random)

    def initialize_episode(self, physics) -> None:
        """Sets the state of the environment at the start of each episode.
        Args:
          physics: An instance of `Physics`.
        """
        # Initial configuration.
        orientation = self.random.randn(4)
        orientation /= np.linalg.norm(orientation)
        _find_non_contacting_height(physics, orientation)
        super().initialize_episode(physics)

    def get_observation(self, physics) -> tp.Dict[str, Any]:
        """Returns an observation to the agent."""
        return _common_observations(physics)

    def get_reward(self, physics) -> Any:
        """Returns a reward to the agent."""

        return _upright_reward(physics)


class Jump(base.Task):
    """A quadruped task solved by moving forward at a designated speed."""

    def __init__(self, desired_height, random=None) -> None:
        """Initializes an instance of `Move`.
        Args:
          desired_speed: A float. If this value is zero, reward is given simply
            for standing upright. Otherwise this specifies the horizontal velocity
            at which the velocity-dependent reward component is maximized.
          random: Optional, either a `numpy.random.RandomState` instance, an
            integer seed for creating a new `RandomState`, or None to select a seed
            automatically (default).
        """
        self._desired_height = desired_height
        super().__init__(random=random)

    def initialize_episode(self, physics) -> None:
        """Sets the state of the environment at the start of each episode.
        Args:
          physics: An instance of `Physics`.
        """
        # Initial configuration.
        orientation = self.random.randn(4)
        orientation /= np.linalg.norm(orientation)
        _find_non_contacting_height(physics, orientation)
        super().initialize_episode(physics)

    def get_observation(self, physics) -> tp.Dict[str, Any]:
        """Returns an observation to the agent."""
        return _common_observations(physics)

    def get_reward(self, physics) -> Any:
        """Returns a reward to the agent."""

        # Move reward term.
        jump_up = rewards.tolerance(
            physics.com_height(),
            bounds=(self._desired_height, float('inf')),
            margin=self._desired_height,
            value_at_margin=0.5,
            sigmoid='linear')

        return _upright_reward(physics) * jump_up


class Roll(base.Task):
    """A quadruped task solved by moving forward at a designated speed."""

    def __init__(self, desired_speed, random=None) -> None:
        """Initializes an instance of `Move`.
        Args:
          desired_speed: A float. If this value is zero, reward is given simply
            for standing upright. Otherwise this specifies the horizontal velocity
            at which the velocity-dependent reward component is maximized.
          random: Optional, either a `numpy.random.RandomState` instance, an
            integer seed for creating a new `RandomState`, or None to select a seed
            automatically (default).
        """
        self._desired_speed = desired_speed
        super().__init__(random=random)

    def initialize_episode(self, physics) -> None:
        """Sets the state of the environment at the start of each episode.
        Args:
          physics: An instance of `Physics`.
        """
        # Initial configuration.
        orientation = self.random.randn(4)
        orientation /= np.linalg.norm(orientation)
        _find_non_contacting_height(physics, orientation)
        super().initialize_episode(physics)

    def get_observation(self, physics) -> tp.Dict[str, Any]:
        """Returns an observation to the agent."""
        return _common_observations(physics)

    def get_reward(self, physics) -> Any:
        """Returns a reward to the agent."""
        # Move reward term.
        move_reward = rewards.tolerance(
            np.linalg.norm(physics.torso_velocity()),
            bounds=(self._desired_speed, float('inf')),
            margin=self._desired_speed,
            value_at_margin=0.5,
            sigmoid='linear')

        return _upright_reward(physics) * move_reward


class Escape(base.Task):
    """A quadruped task solved by escaping a bowl-shaped terrain."""

    def initialize_episode(self, physics) -> None:
        """Sets the state of the environment at the start of each episode.
        Args:
          physics: An instance of `Physics`.
        """
        # Get heightfield resolution, assert that it is square.
        res = physics.model.hfield_nrow[_HEIGHTFIELD_ID]
        assert res == physics.model.hfield_ncol[_HEIGHTFIELD_ID]
        # Sinusoidal bowl shape.
        row_grid, col_grid = np.ogrid[-1:1:res * 1j, -1:1:res * 1j]
        radius = np.clip(np.sqrt(col_grid**2 + row_grid**2), .04, 1)
        bowl_shape = .5 - np.cos(2 * np.pi * radius) / 2
        # Random smooth bumps.
        terrain_size = 2 * physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
        bump_res = int(terrain_size / _TERRAIN_BUMP_SCALE)
        bumps = self.random.uniform(_TERRAIN_SMOOTHNESS, 1, (bump_res, bump_res))
        smooth_bumps = ndimage.zoom(bumps, res / float(bump_res))
        # Terrain is elementwise product.
        terrain = bowl_shape * smooth_bumps
        start_idx = physics.model.hfield_adr[_HEIGHTFIELD_ID]
        physics.model.hfield_data[start_idx:start_idx + res**2] = terrain.ravel()
        super().initialize_episode(physics)

        # If we have a rendering context, we need to re-upload the modified
        # heightfield data.
        if physics.contexts:
            with physics.contexts.gl.make_current() as ctx:
                ctx.call(mjlib.mjr_uploadHField,
                         physics.model.ptr,
                         physics.contexts.mujoco.ptr,
                         _HEIGHTFIELD_ID)

        # Initial configuration.
        orientation = self.random.randn(4)
        orientation /= np.linalg.norm(orientation)
        _find_non_contacting_height(physics, orientation)

    def get_observation(self, physics) -> tp.Dict[str, Any]:
        """Returns an observation to the agent."""
        obs = _common_observations(physics)
        obs['origin'] = physics.origin()
        obs['rangefinder'] = physics.rangefinder()
        return obs

    def get_reward(self, physics) -> Any:
        """Returns a reward to the agent."""

        # Escape reward term.
        terrain_size = physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
        escape_reward = rewards.tolerance(
            physics.origin_distance(),
            bounds=(terrain_size, float('inf')),
            margin=terrain_size,
            value_at_margin=0,
            sigmoid='linear')

        return _upright_reward(physics, deviation_angle=20) * escape_reward


class Fetch(base.Task):
    """A quadruped task solved by bringing a ball to the origin."""

    def initialize_episode(self, physics) -> None:
        """Sets the state of the environment at the start of each episode.
        Args:
          physics: An instance of `Physics`.
        """
        # Initial configuration, random azimuth and horizontal position.
        azimuth = self.random.uniform(0, 2 * np.pi)
        orientation = np.array((np.cos(azimuth / 2), 0, 0, np.sin(azimuth / 2)))
        spawn_radius = 0.9 * physics.named.model.geom_size['floor', 0]
        x_pos, y_pos = self.random.uniform(-spawn_radius, spawn_radius, size=(2,))
        _find_non_contacting_height(physics, orientation, x_pos, y_pos)

        # Initial ball state.
        physics.named.data.qpos['ball_root'][:2] = self.random.uniform(
            -spawn_radius, spawn_radius, size=(2,))
        physics.named.data.qpos['ball_root'][2] = 2
        physics.named.data.qvel['ball_root'][:2] = 5 * self.random.randn(2)
        super().initialize_episode(physics)

    def get_observation(self, physics) -> tp.Dict[str, Any]:
        """Returns an observation to the agent."""
        obs = _common_observations(physics)
        obs['ball_state'] = physics.ball_state()
        obs['target_position'] = physics.target_position()
        return obs

    def get_reward(self, physics) -> Any:
        """Returns a reward to the agent."""

        # Reward for moving close to the ball.
        arena_radius = physics.named.model.geom_size['floor', 0] * np.sqrt(2)
        workspace_radius = physics.named.model.site_size['workspace', 0]
        ball_radius = physics.named.model.geom_size['ball', 0]
        reach_reward = rewards.tolerance(
            physics.self_to_ball_distance(),
            bounds=(0, workspace_radius + ball_radius),
            sigmoid='linear',
            margin=arena_radius, value_at_margin=0)

        # Reward for bringing the ball to the target.
        target_radius = physics.named.model.site_size['target', 0]
        fetch_reward = rewards.tolerance(
            physics.ball_to_target_distance(),
            bounds=(0, target_radius),
            sigmoid='linear',
            margin=arena_radius, value_at_margin=0)

        reach_then_fetch = reach_reward * (0.5 + 0.5 * fetch_reward)

        return _upright_reward(physics) * reach_then_fetch

#fre/common/envs/exorl/custom_dmc_tasks/walker
"""Planar Walker Domain."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from typing import Any, Tuple
import typing as tp
import os

from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
from dm_control.utils import io as resources

_CONTROL_TIMESTEP: float
_DEFAULT_TIME_LIMIT: int
_RUN_SPEED: int
_SPIN_SPEED: int
_STAND_HEIGHT: float
_WALK_SPEED: int
# from dm_control import suite  # TODO useless?

_DEFAULT_TIME_LIMIT = 25
_CONTROL_TIMESTEP = .025

# Minimal height of torso over foot above which stand reward is 1.
_STAND_HEIGHT = 1.2

# Horizontal speeds (meters/second) above which move reward is 1.
_WALK_SPEED = 1
_RUN_SPEED = 8
_SPIN_SPEED = 5

SUITE = containers.TaggedTasks()


def make(task,
         task_kwargs=None,
         environment_kwargs=None,
         visualize_reward: bool = False):
    task_kwargs = task_kwargs or {}
    if environment_kwargs is not None:
        task_kwargs = task_kwargs.copy()
        task_kwargs['environment_kwargs'] = environment_kwargs
    env = SUITE[task](**task_kwargs)
    env.task.visualize_reward = visualize_reward
    return env


def get_model_and_assets() -> Tuple[Any, Any]:
    """Returns a tuple containing the model XML string and a dict of assets."""
    root_dir = os.path.dirname(os.path.dirname(__file__))
    xml = resources.GetResource(os.path.join(root_dir, 'custom_dmc_tasks',
                                             'walker.xml'))
    return xml, common.ASSETS


@SUITE.add('benchmarking')
def flip(time_limit: int = _DEFAULT_TIME_LIMIT,
         random=None,
         environment_kwargs=None):
    """Returns the Run task."""
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = PlanarWalker(move_speed=_RUN_SPEED,
                        forward=True,
                        flip=True,
                        random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(physics,
                               task,
                               time_limit=time_limit,
                               control_timestep=_CONTROL_TIMESTEP,
                               **environment_kwargs)


class Physics(mujoco.Physics):
    """Physics simulation with additional features for the Walker domain."""

    def torso_upright(self) -> Any:
        """Returns projection from z-axes of torso to the z-axes of world."""
        return self.named.data.xmat['torso', 'zz']

    def torso_height(self) -> Any:
        """Returns the height of the torso."""
        return self.named.data.xpos['torso', 'z']

    def horizontal_velocity(self) -> Any:
        """Returns the horizontal velocity of the center-of-mass."""
        return self.named.data.sensordata['torso_subtreelinvel'][0]

    def orientations(self) -> Any:
        """Returns planar orientations of all bodies."""
        return self.named.data.xmat[1:, ['xx', 'xz']].ravel()

    def angmomentum(self) -> Any:
        """Returns the angular momentum of torso of the Cheetah about Y axis."""
        return self.named.data.subtree_angmom['torso'][1]


class PlanarWalker(base.Task):
    """A planar walker task."""

    def __init__(self, move_speed, forward=True, flip=False, random=None) -> None:
        """Initializes an instance of `PlanarWalker`.

    Args:
      move_speed: A float. If this value is zero, reward is given simply for
        standing up. Otherwise this specifies a target horizontal velocity for
        the walking task.
      random: Optional, either a `numpy.random.RandomState` instance, an
        integer seed for creating a new `RandomState`, or None to select a seed
        automatically (default).
    """
        self._move_speed = move_speed
        self._forward = 1 if forward else -1
        self._flip = flip
        super(PlanarWalker, self).__init__(random=random)

    def initialize_episode(self, physics) -> None:
        """Sets the state of the environment at the start of each episode.

    In 'standing' mode, use initial orientation and small velocities.
    In 'random' mode, randomize joint angles and let fall to the floor.

    Args:
      physics: An instance of `Physics`.

    """
        randomizers.randomize_limited_and_rotational_joints(
            physics, self.random)
        super(PlanarWalker, self).initialize_episode(physics)

    def get_observation(self, physics) -> tp.Dict[str, Any]:
        """Returns an observation of body orientations, height and velocites."""
        obs = collections.OrderedDict()
        obs['orientations'] = physics.orientations()
        obs['height'] = physics.torso_height()
        obs['velocity'] = physics.velocity()
        return obs

    def get_reward(self, physics) -> Any:
        """Returns a reward to the agent."""
        standing = rewards.tolerance(physics.torso_height(),
                                     bounds=(_STAND_HEIGHT, float('inf')),
                                     margin=_STAND_HEIGHT / 2)
        upright = (1 + physics.torso_upright()) / 2
        stand_reward = (3 * standing + upright) / 4

        if self._flip:
            move_reward = rewards.tolerance(self._forward *
                                            physics.angmomentum(),
                                            bounds=(_SPIN_SPEED, float('inf')),
                                            margin=_SPIN_SPEED,
                                            value_at_margin=0,
                                            sigmoid='linear')
        else:
            move_reward = rewards.tolerance(
                self._forward * physics.horizontal_velocity(),
                bounds=(self._move_speed, float('inf')),
                margin=self._move_speed / 2,
                value_at_margin=0.5,
                sigmoid='linear')

        return stand_reward * (5 * move_reward + 1) / 6

#fre/common/envs/exorl/dmc.py
import unittest
import dataclasses
from collections import OrderedDict, deque
import typing as tp
from typing import Any

from gym import core, spaces
from dm_env import Environment
from dm_env import StepType, specs
from dm_control import suite
from dm_control.suite.wrappers import action_scale, pixels
"""import fre.common.envs.exorl.custom_dmc_tasks as cdmc"""
import numpy as np


S = tp.TypeVar("S", bound="TimeStep")
Env = tp.Union["EnvWrapper", Environment]


@dataclasses.dataclass
class TimeStep:
    step_type: StepType
    reward: float
    discount: float
    observation: np.ndarray
    physics: np.ndarray = dataclasses.field(default=np.ndarray([]), init=False)

    def first(self) -> bool:
        return self.step_type == StepType.FIRST  # type: ignore

    def mid(self) -> bool:
        return self.step_type == StepType.MID  # type: ignore

    def last(self) -> bool:
        return self.step_type == StepType.LAST  # type: ignore

    def __getitem__(self, attr: str) -> tp.Any:
        return getattr(self, attr)

    def _replace(self: S, **kwargs: tp.Any) -> S:
        for name, val in kwargs.items():
            setattr(self, name, val)
        return self


@dataclasses.dataclass
class ExtendedTimeStep(TimeStep):
    action: tp.Any


class EnvWrapper:
    def __init__(self, env: Env) -> None:
        self._env = env

    def _augment_time_step(self, time_step: TimeStep, action: tp.Optional[np.ndarray] = None) -> TimeStep:
        if not isinstance(time_step, TimeStep):
            # dm_env time step is a named tuple
            time_step = TimeStep(**time_step._asdict())
        if self.physics is not None:
            return time_step._replace(physics=self.physics.get_state())
        else:
            return time_step

    def reset(self) -> TimeStep:
        time_step = self._env.reset()
        return self._augment_time_step(time_step)

    def step(self, action: np.ndarray) -> TimeStep:
        time_step = self._env.step(action)
        return self._augment_time_step(time_step, action)

    def observation_spec(self) -> tp.Any:
        assert isinstance(self, EnvWrapper)
        return self._env.observation_spec()

    def action_spec(self) -> specs.Array:
        return self._env.action_spec()

    def render(self, *args: tp.Any, **kwargs: tp.Any) -> np.ndarray:
        return self._env.render(*args, **kwargs)  # type: ignore

    @property
    def base_env(self) -> tp.Any:
        env = self._env
        if isinstance(env, EnvWrapper):
            return self.base_env
        return env

    @property
    def physics(self) -> tp.Any:
        if hasattr(self._env, "physics"):
            return self._env.physics

    def __getattr__(self, name):
        return getattr(self._env, name)


class FlattenJacoObservationWrapper(EnvWrapper):
    def __init__(self, env: Env) -> None:
        super().__init__(env)
        self._obs_spec = OrderedDict()
        wrapped_obs_spec = env.observation_spec().copy()
        if 'front_close' in wrapped_obs_spec:
            spec = wrapped_obs_spec['front_close']
            # drop batch dim
            self._obs_spec['pixels'] = specs.BoundedArray(shape=spec.shape[1:],
                                                          dtype=spec.dtype,
                                                          minimum=spec.minimum,
                                                          maximum=spec.maximum,
                                                          name='pixels')
            wrapped_obs_spec.pop('front_close')

        for spec in wrapped_obs_spec.values():
            assert spec.dtype == np.float64
            assert type(spec) == specs.Array
        dim = np.sum(
            np.fromiter((int(np.prod(spec.shape))  # type: ignore
                         for spec in wrapped_obs_spec.values()), np.int32))

        self._obs_spec['observations'] = specs.Array(shape=(dim,),
                                                     dtype=np.float32,
                                                     name='observations')

    def observation_spec(self) -> tp.Any:
        return self._obs_spec

    def _augment_time_step(self, time_step: TimeStep, action: tp.Optional[np.ndarray] = None) -> TimeStep:
        super()._augment_time_step(time_step=time_step, action=action)
        obs = OrderedDict()

        # TODO: this is badly typed since observation is a dict in this case
        if 'front_close' in time_step.observation:
            pixels = time_step.observation['front_close']
            time_step.observation.pop('front_close')  # type: ignore
            pixels = np.squeeze(pixels)
            obs['pixels'] = pixels

        features = []
        for feature in time_step.observation.values():  # type: ignore
            features.append(feature.ravel())
        obs['observations'] = np.concatenate(features, axis=0)
        return time_step._replace(observation=obs)


class ActionRepeatWrapper(EnvWrapper):
    def __init__(self, env: tp.Any, num_repeats: int) -> None:
        super().__init__(env)
        self._num_repeats = num_repeats

    def step(self, action: np.ndarray) -> TimeStep:
        reward = 0.0
        discount = 1.0
        for _ in range(self._num_repeats):
            time_step = self._env.step(action)
            reward += (time_step.reward or 0.0) * discount
            discount *= time_step.discount
            if time_step.last():
                break

        return time_step._replace(reward=reward, discount=discount)


class FrameStackWrapper(EnvWrapper):
    def __init__(self, env: Env, num_frames: int, pixels_key: str = 'pixels') -> None:
        super().__init__(env)
        self._num_frames = num_frames
        self._frames: tp.Deque[np.ndarray] = deque([], maxlen=num_frames)
        self._pixels_key = pixels_key

        wrapped_obs_spec = env.observation_spec()
        assert pixels_key in wrapped_obs_spec

        pixels_shape = wrapped_obs_spec[pixels_key].shape
        # remove batch dim
        if len(pixels_shape) == 4:
            pixels_shape = pixels_shape[1:]
        self._obs_spec = specs.BoundedArray(shape=np.concatenate(
            [[pixels_shape[2] * num_frames], pixels_shape[:2]], axis=0),
            dtype=np.uint8,
            minimum=0,
            maximum=255,
            name='observation')

    def _augment_time_step(self, time_step: TimeStep, action: tp.Optional[np.ndarray] = None) -> TimeStep:
        super()._augment_time_step(time_step=time_step, action=action)
        assert len(self._frames) == self._num_frames
        obs = np.concatenate(list(self._frames), axis=0)
        return time_step._replace(observation=obs)

    def _extract_pixels(self, time_step: TimeStep) -> np.ndarray:
        pixels_ = time_step.observation[self._pixels_key]
        # remove batch dim
        if len(pixels_.shape) == 4:
            pixels_ = pixels_[0]
        return pixels_.transpose(2, 0, 1).copy()

    def reset(self) -> TimeStep:
        time_step = self._env.reset()
        pixels_ = self._extract_pixels(time_step)
        for _ in range(self._num_frames):
            self._frames.append(pixels_)
        return self._augment_time_step(time_step)

    def step(self, action: np.ndarray) -> TimeStep:
        time_step = self._env.step(action)
        pixels_ = self._extract_pixels(time_step)
        self._frames.append(pixels_)
        return self._augment_time_step(time_step)


class ActionDTypeWrapper(EnvWrapper):
    def __init__(self, env: Env, dtype) -> None:
        super().__init__(env)
        wrapped_action_spec = env.action_spec()
        self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
                                               dtype,
                                               wrapped_action_spec.minimum,
                                               wrapped_action_spec.maximum,
                                               'action')

    def action_spec(self) -> specs.BoundedArray:
        return self._action_spec

    def step(self, action) -> Any:
        action = action.astype(self._env.action_spec().dtype)
        return self._env.step(action)


class ObservationDTypeWrapper(EnvWrapper):
    def __init__(self, env: Env, dtype) -> None:
        super().__init__(env)
        self._dtype = dtype
        wrapped_obs_spec = env.observation_spec()['observations']
        self._obs_spec = specs.Array(wrapped_obs_spec.shape, dtype,
                                     'observation')

    def _augment_time_step(self, time_step: TimeStep, action: tp.Optional[np.ndarray] = None) -> TimeStep:
        obs = time_step.observation['observations'].astype(self._dtype)
        return time_step._replace(observation=obs)

    def observation_spec(self) -> Any:
        return self._obs_spec


class ExtendedTimeStepWrapper(EnvWrapper):

    def _augment_time_step(self, time_step: TimeStep, action: tp.Optional[np.ndarray] = None) -> TimeStep:
        if action is None:
            action_spec = self.action_spec()
            action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
        ts = ExtendedTimeStep(observation=time_step.observation,
                              step_type=time_step.step_type,
                              action=action,
                              reward=time_step.reward or 0.0,
                              discount=time_step.discount or 1.0)
        return super()._augment_time_step(time_step=ts, action=action)


def _make_jaco(obs_type, domain, task, frame_stack, action_repeat, seed) -> FlattenJacoObservationWrapper:
    env = cdmc.make_jaco(task, obs_type, seed)
    env = ActionDTypeWrapper(env, np.float32)
    env = ActionRepeatWrapper(env, action_repeat)
    env = FlattenJacoObservationWrapper(env)
    return env


def _make_dmc(obs_type, domain, task, frame_stack, action_repeat, seed):
    visualize_reward = False
    if (domain, task) in suite.ALL_TASKS:
        env = suite.load(domain,
                         task,
                         task_kwargs=dict(random=seed),
                         environment_kwargs=dict(flat_observation=True),
                         visualize_reward=visualize_reward)
    else:
        env = cdmc.make(domain,
                        task,
                        task_kwargs=dict(random=seed),
                        environment_kwargs=dict(flat_observation=True),
                        visualize_reward=visualize_reward)
    env = ActionDTypeWrapper(env, np.float32)
    env = ActionRepeatWrapper(env, action_repeat)
    if obs_type == 'pixels':
        # zoom in camera for quadruped
        camera_id = dict(quadruped=2).get(domain, 0)
        render_kwargs = dict(height=84, width=84, camera_id=camera_id)
        env = pixels.Wrapper(env,
                             pixels_only=True,
                             render_kwargs=render_kwargs)
    return env


def make(
    name: str, obs_type='states', frame_stack=1, action_repeat=1,
    seed=1,
) -> EnvWrapper:
    assert obs_type in ['states', 'pixels']
    if name.startswith('point_mass_maze'):
        domain = 'point_mass_maze'
        _, _, _, task = name.split('_', 3)
    else:
        domain, task = name.split('_', 1)
    domain = dict(cup='ball_in_cup').get(domain, domain)

    make_fn = _make_jaco if domain == 'jaco' else _make_dmc
    env = make_fn(obs_type, domain, task, frame_stack, action_repeat, seed)

    if obs_type == 'pixels':
        env = FrameStackWrapper(env, frame_stack)
    else:
        env = ObservationDTypeWrapper(env, np.float32)

    env = action_scale.Wrapper(env, minimum=-1.0, maximum=+1.0)
    env = ExtendedTimeStepWrapper(env)
    return env


def extract_physics(env: Env) -> tp.Dict[str, float]:
    """Extract some physics available in the env"""
    output = {}
    names = ["torso_height", "torso_upright", "horizontal_velocity", "torso_velocity"]
    for name in names:
        if not hasattr(env.physics, name):
            continue
        val: tp.Union[float, np.ndarray] = getattr(env.physics, name)()
        if isinstance(val, (int, float)) or not val.ndim:
            output[name] = float(val)
        else:
            for k, v in enumerate(val):
                output[f"{name}#{k}"] = float(v)
    return output


class FloatStats:
    """Handle for keeping track of the statistics of a float variable"""

    def __init__(self) -> None:
        self.min = np.inf
        self.max = -np.inf
        self.mean = 0.0
        self._count = 0

    def add(self, value: float) -> "FloatStats":
        self.min = min(value, self.min)
        self.max = max(value, self.max)
        self._count += 1
        self.mean = (self._count - 1) / self._count * self.mean + 1 / self._count * value
        return self

    def items(self) -> tp.Iterator[tp.Tuple[str, float]]:
        for name, val in self.__dict__.items():
            if not name.startswith("_"):
                yield name, val


class PhysicsAggregator:
    """Aggregate stats on the physics of an environment"""

    def __init__(self) -> None:
        self.stats: tp.Dict[str, FloatStats] = {}

    def add(self, env: Env) -> "PhysicsAggregator":
        phy = extract_physics(env)
        for key, val in phy.items():
            self.stats.setdefault(key, FloatStats()).add(val)
        return self

    def dump(self) -> tp.Iterator[tp.Tuple[str, float]]:
        """Exports all statistics and reset the statistics"""
        for key, stats in self.stats.items():
            for stat, val in stats.items():
                yield (f'{key}/{stat}', val)
        self.stats.clear()


def _spec_to_box(spec, dtype):
    def extract_min_max(s):
        assert s.dtype == np.float64 or s.dtype == np.float32
        dim = int(np.prod(s.shape))
        if type(s) == specs.Array:
            bound = np.inf * np.ones(dim, dtype=np.float32)
            return -bound, bound
        elif type(s) == specs.BoundedArray:
            zeros = np.zeros(dim, dtype=np.float32)
            return s.minimum + zeros, s.maximum + zeros

    mins, maxs = [], []
    for s in spec:
        mn, mx = extract_min_max(s)
        mins.append(mn)
        maxs.append(mx)
    low = np.concatenate(mins, axis=0).astype(dtype)
    high = np.concatenate(maxs, axis=0).astype(dtype)
    assert low.shape == high.shape
    return spaces.Box(low, high, dtype=dtype)


def _flatten_obs(obs):
    obs_pieces = []
    v = obs
    flat = np.array([v]) if np.isscalar(v) else v.ravel()
    obs_pieces.append(flat)
    return np.concatenate(obs_pieces, axis=0)


class DMCWrapper(core.Env):
    def __init__(
            self,
            env,
            seed,
            from_pixels=False,
            height=84,
            width=84,
            camera_id=0,
            frame_skip=1,
            channels_first=True,
    ):
        self._from_pixels = from_pixels
        self._height = height
        self._width = width
        self._camera_id = camera_id
        self._frame_skip = frame_skip
        self._channels_first = channels_first

        self._env = env

        # true and normalized action spaces
        self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32)
        self._norm_action_space = spaces.Box(
            low=-1.0,
            high=1.0,
            shape=self._true_action_space.shape,
            dtype=np.float32
        )

        # create observation space
        if from_pixels:
            shape = [3, height, width] if channels_first else [height, width, 3]
            self._observation_space = spaces.Box(
                low=0, high=255, shape=shape, dtype=np.uint8
            )
        else:
            self._observation_space = _spec_to_box([self._env.observation_spec()], np.float32)

        self._state_space = _spec_to_box([self._env.observation_spec()], np.float32)

        self.current_state = None

        self.seed(seed)

    def __getattr__(self, name):
        return getattr(self._env, name)

    def _get_obs(self, time_step):
        if self._from_pixels:
            obs = self.render(
                height=self._height,
                width=self._width,
                camera_id=self._camera_id
            )
            if self._channels_first:
                obs = obs.transpose(2, 0, 1).copy()
        else:
            obs = _flatten_obs(time_step.observation)
        return obs

    def _convert_action(self, action):
        true_delta = self._true_action_space.high - self._true_action_space.low
        norm_delta = self._norm_action_space.high - self._norm_action_space.low
        action = (action - self._norm_action_space.low) / norm_delta
        action = action * true_delta + self._true_action_space.low
        return action

    @property
    def observation_space(self):
        return self._observation_space

    @property
    def state_space(self):
        return self._state_space

    @property
    def action_space(self):
        return self._norm_action_space

    @property
    def reward_range(self):
        return 0, self._frame_skip

    def seed(self, seed):
        self._true_action_space.seed(seed)
        self._norm_action_space.seed(seed)
        self._observation_space.seed(seed)

    def step(self, action):
        assert self._norm_action_space.contains(action)
        action = self._convert_action(action)
        assert self._true_action_space.contains(action)
        reward = 0
        extra = {'internal_state': self._env.physics.get_state().copy()}

        for _ in range(self._frame_skip):
            time_step = self._env.step(action)
            reward += time_step.reward or 0
            done = time_step.last()
            if done:
                break
        obs = self._get_obs(time_step)
        self.current_state = _flatten_obs(time_step.observation)
        extra['discount'] = time_step.discount
        return obs, reward, done, extra

    def reset(self):
        time_step = self._env.reset()
        self.current_state = _flatten_obs(time_step.observation)
        obs = self._get_obs(time_step)
        return obs

    def render(self, mode='rgb_array', height=None, width=None, camera_id=0):
        assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
        height = height or self._height
        width = width or self._width
        camera_id = camera_id or self._camera_id
        return self._env.physics.render(
            height=height, width=width, camera_id=camera_id
        )

#fre/common/envs/exorl/exorl_utils.py
import os
from pathlib import Path
import glob
import tqdm
import numpy as np
from collections import defaultdict

# from fre.common.envs.d4rl import d4rl_utils
#from fre.common.dataset import Dataset

# get path relative to 'fre' package
data_path = open('/content/download.py')
print("Path to exorl data is", data_path)

def get_dataset(env, env_name, method='rnd', dmc_dataset_size=10000000, use_task_reward=True):

    # dmc_dataset_size /= 10
    # print("WARNING: Only using 10 percent of exorl data.")

    domain_name, task_name = env_name.split('_', 1)

    path = os.path.join(data_path, domain_name, method)
    if not os.path.exists(path):
        print("Downloading exorl data.")
        os.makedirs(path)
        url = "https://dl.fbaipublicfiles.com/exorl/" + domain_name + "/" + method + ".zip"
        print("Downloading from", url)
        os.system("wget " + url + " -P " + path)
        os.system("unzip " + path + "/" + method + ".zip -d " + path)

    # process data into Dataset object.
    path = os.path.join(data_path, domain_name, method, 'buffer')
    npzs = sorted(glob.glob(f'{path}/*.npz'))
    dataset_npy = os.path.join(data_path, domain_name, method, task_name + '.npy')
    if os.path.exists(dataset_npy):
        dataset = np.load(dataset_npy, allow_pickle=True).item()
    else:
        print("Calculating exorl rewards.")
        dataset = defaultdict(list)
        num_steps = 0
        for i, npz in tqdm.tqdm(enumerate(npzs)):
            traj_data = dict(np.load(npz))
            dataset['observations'].append(traj_data['observation'][:-1, :])
            dataset['next_observations'].append(traj_data['observation'][1:, :])
            dataset['actions'].append(traj_data['action'][1:, :])
            dataset['physics'].append(traj_data['physics'][1:, :])  # Note that this corresponds to next_observations (i.e., r(s, a, s') = r(s') -- following the original DMC rewards)

            if use_task_reward:
                # TODO: make this faster and sanity check it
                rewards = []
                reward_spec = env.reward_spec()
                states = traj_data['physics']
                for j in range(states.shape[0]):
                    with env.physics.reset_context():
                        env.physics.set_state(states[j])
                    reward = env.task.get_reward(env.physics)
                    reward = np.full(reward_spec.shape, reward, reward_spec.dtype)
                    rewards.append(reward)
                traj_data['reward'] = np.array(rewards, dtype=reward_spec.dtype)
                dataset['rewards'].append(traj_data['reward'][1:])
            else:
                dataset['rewards'].append(traj_data['reward'][1:, 0])

            terminals = np.full((len(traj_data['observation']) - 1,), False)
            dataset['terminals'].append(terminals)
            num_steps += len(traj_data['observation']) - 1
            if num_steps >= dmc_dataset_size:
                break
        print("Loaded {} steps".format(num_steps))
        for k, v in dataset.items():
            dataset[k] = np.concatenate(v, axis=0)
        np.save(dataset_npy, dataset)



    # Processing
    masks = 1.0 - dataset['terminals']
    dones_float = dataset['terminals']

    return Dataset.create(
        observations=dataset['observations'],
        actions=dataset['actions'],
        rewards=dataset['rewards'],
        masks=masks,
        dones_float=dones_float,
        next_observations=dataset['next_observations'],
    )

#fre/common/networks/basic.py
###############################
#
#  Common Flax Networks.
#
###############################
#from fre.common.typing import *



  and should_run_async(code)


Path to exorl data is <_io.TextIOWrapper name='/content/download.py' mode='r' encoding='UTF-8'>


In [7]:
!pip install distrax

  and should_run_async(code)


Collecting distrax
  Downloading distrax-0.1.5-py3-none-any.whl (319 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/319.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━[0m [32m307.2/319.7 kB[0m [31m9.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m319.7/319.7 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: distrax
Successfully installed distrax-0.1.5


In [8]:
import flax.linen as nn
import jax.numpy as jnp

import distrax
import flax.linen as nn
import jax.numpy as jnp
from dataclasses import field

###############################
#
#  Common Networks
#
###############################

def mish(x):
    return x * jnp.tanh(nn.softplus(x))

def default_init(scale: Optional[float] = 1.0):
    return nn.initializers.variance_scaling(scale, "fan_avg", "uniform")

class MLP(nn.Module):
    hidden_dims: Sequence[int]
    activations: Callable[[jnp.ndarray], jnp.ndarray] = mish
    activate_final: int = False
    use_layer_norm: bool = True
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_init()

    def setup(self):
        self.layers = [
            nn.Dense(size, kernel_init=self.kernel_init) for size in self.hidden_dims
        ]
        if self.use_layer_norm:
            self.layer_norms = [nn.LayerNorm() for _ in self.hidden_dims]

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i + 1 < len(self.layers) and self.use_layer_norm:
                x = self.layer_norms[i](x)
            if i + 1 < len(self.layers) or self.activate_final:
                x = self.activations(x)
        return x

###############################
#
#  Common RL Networks
#
###############################


# DQN-style critic.
class DiscreteCritic(nn.Module):
    hidden_dims: Sequence[int]
    n_actions: int
    mlp_kwargs: Dict[str, Any] = field(default_factory=dict)

    @nn.compact
    def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
        return MLP((*self.hidden_dims, self.n_actions), **self.mlp_kwargs)(
            observations
        )

# Q(s,a) critic.
class Critic(nn.Module):
    hidden_dims: Sequence[int]
    mlp_kwargs: Dict[str, Any] = field(default_factory=dict)

    @nn.compact
    def __call__(self, observations: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
        inputs = jnp.concatenate([observations, actions], -1)
        critic = MLP((*self.hidden_dims, 1), **self.mlp_kwargs)(inputs)
        return jnp.squeeze(critic, -1)

# V(s) critic.
class ValueCritic(nn.Module):
    hidden_dims: Sequence[int]
    mlp_kwargs: Dict[str, Any] = field(default_factory=dict)

    @nn.compact
    def __call__(self, observations: jnp.ndarray) -> jnp.ndarray:
        critic = MLP((*self.hidden_dims, 1), **self.mlp_kwargs)(observations)
        return jnp.squeeze(critic, -1)

# pi(a|s). Returns a distrax distribution.
class Policy(nn.Module):
    hidden_dims: Sequence[int]
    action_dim: int
    mlp_kwargs: Dict[str, Any] = field(default_factory=dict)

    is_discrete: bool = False
    log_std_min: Optional[float] = -20
    log_std_max: Optional[float] = 2
    mean_min: Optional[float] = -5
    mean_max: Optional[float] = 5
    tanh_squash_distribution: bool = False
    state_dependent_std: bool = True
    final_fc_init_scale: float = 1e-2

    @nn.compact
    def __call__(
        self, observations: jnp.ndarray, temperature: float = 1.0
    ) -> distrax.Distribution:
        outputs = MLP(
            self.hidden_dims,
            activate_final=True,
            **self.mlp_kwargs
        )(observations)

        if self.is_discrete:
            logits = nn.Dense(
                self.action_dim, kernel_init=default_init(self.final_fc_init_scale)
            )(outputs)
            distribution = distrax.Categorical(logits=logits / jnp.maximum(1e-6, temperature))
        else:
            means = nn.Dense(
                self.action_dim, kernel_init=default_init(self.final_fc_init_scale)
            )(outputs)
            if self.state_dependent_std:
                log_stds = nn.Dense(
                    self.action_dim, kernel_init=default_init(self.final_fc_init_scale)
                )(outputs)
            else:
                log_stds = self.param("log_stds", nn.initializers.zeros, (self.action_dim,))

            log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max)
            means = jnp.clip(means, self.mean_min, self.mean_max)

            distribution = distrax.MultivariateNormalDiag(
                loc=means, scale_diag=jnp.exp(log_stds) * temperature
            )
            if self.tanh_squash_distribution:
                distribution = TransformedWithMode(
                    distribution, distrax.Block(distrax.Tanh(), ndims=1)
                )
        return distribution

###############################
#
#   Helper Things
#
###############################


class TransformedWithMode(distrax.Transformed):
    def mode(self) -> jnp.ndarray:
        return self.bijector.forward(self.distribution.mode())

def ensemblize(cls, num_qs, out_axes=0, **kwargs):
    """
    Useful for making ensembles of Q functions (e.g. double Q in SAC).

    Usage:

        critic_def = ensemblize(Critic, 2)(hidden_dims=hidden_dims)

    """
    return nn.vmap(
        cls,
        variable_axes={"params": 0},
        split_rngs={"params": True},
        in_axes=None,
        out_axes=out_axes,
        axis_size=num_qs,
        **kwargs
    )

#fre/common/networks/transformers
from typing import Any, Callable, Optional, Tuple, Type

import flax.linen as nn
import jax.numpy as jnp

Array = Any
PRNGKey = Any
Shape = Tuple[int]
Dtype = Any


class IdentityLayer(nn.Module):
    """Identity layer, convenient for giving a name to an array."""

    @nn.compact
    def __call__(self, x):
        return x


class AddPositionEmbs(nn.Module):
    # Need to define function that adds the poisition embeddings to the input.
    posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]

    @nn.compact
    def __call__(self, inputs):
        """
            inputs.shape is (batch_size, timesteps, emb_dim).
            Output tensor with shape `(batch_size, timesteps, in_dim)`.
        """
        assert inputs.ndim == 3, ('Number of dimensions should be 3, but it is: %d' % inputs.ndim)

        position_ids = jnp.arange(inputs.shape[1])[None] # (1, timesteps)
        pos_embeddings = nn.Embed(
            128, # Max Positional Embeddings
            inputs.shape[2],
            embedding_init=self.posemb_init,
            dtype=inputs.dtype,
        )(position_ids)
        print("For Input Shape {}, Pos Embes Shape is {}".format(inputs.shape, pos_embeddings.shape))
        return inputs + pos_embeddings

        # pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
        # pe = self.param('pos_embedding', self.posemb_init, pos_emb_shape)
        # return inputs + pe


class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block."""

    mlp_dim: int
    dtype: Dtype = jnp.float32
    out_dim: Optional[int] = None
    dropout_rate: float = None
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform()
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, *, deterministic):
        """It's just an MLP, so the input shape is (batch, len, emb)."""
        actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
        x = nn.Dense(
                features=self.mlp_dim,
                dtype=self.dtype,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init)(inputs)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        output = nn.Dense(
                features=actual_out_dim,
                dtype=self.dtype,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init)(x)
        output = nn.Dropout(
                rate=self.dropout_rate)(output, deterministic=deterministic)
        return output


class Encoder1DBlock(nn.Module):
    """Transformer encoder layer.
    Given a sequence, it passes it through an attention layer, then through a mlp layer.
    In each case it is a residual block with a layer norm.
    """

    mlp_dim: int
    num_heads: int
    causal: bool
    dropout_rate: float
    attention_dropout_rate: float
    dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, inputs, *, deterministic, train=True):

        if self.causal:
            causal_mask = nn.make_causal_mask(jnp.ones((inputs.shape[0], inputs.shape[1]),
                                                        dtype="bool"), dtype="bool")
            print("Using Causal Mask with shape", causal_mask.shape, "and inputs shape", inputs.shape, ".")
        else:
            causal_mask = None

        # Attention block.
        assert inputs.ndim == 3, f'Expected (batch, seq, hidden) got {inputs.shape}'
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = nn.MultiHeadDotProductAttention(
            dtype=self.dtype,
            kernel_init=nn.initializers.xavier_uniform(),
            broadcast_dropout=False,
            deterministic=deterministic,
            dropout_rate=self.attention_dropout_rate,
            decode=False,
            num_heads=self.num_heads)(x, x, causal_mask)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = x + inputs

        # MLP block. This does NOT change the embedding dimension!
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(y, deterministic=deterministic)

        return x + y


class Transformer(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation.
    """

    num_layers: int
    emb_dim: int
    mlp_dim: int
    num_heads: int
    dropout_rate: float
    attention_dropout_rate: float
    causal: bool = True

    @nn.compact
    def __call__(self, x, *, train):
        assert x.ndim == 3  # (batch, len, emb)
        assert x.shape[-1] == self.emb_dim

        # Input Encoder. Each layer processes x, but the shape of x does not change.
        for lyr in range(self.num_layers):
            x = Encoder1DBlock(
                    mlp_dim=self.mlp_dim,
                    dropout_rate=self.dropout_rate,
                    attention_dropout_rate=self.attention_dropout_rate,
                    name=f'encoderblock_{lyr}',
                    causal=self.causal,
                    num_heads=self.num_heads)(
                            x, deterministic=not train, train=train)
        encoded = nn.LayerNorm(name='encoder_norm')(x)

        return encoded

def get_default_config():
    import ml_collections

    config = ml_collections.ConfigDict({
        'num_layers': 4,
        'emb_dim': 256,
        'mlp_dim': 256,
        'num_heads': 4,
        'dropout_rate': 0.0,
        'attention_dropout_rate': 0.0,
        'causal': True,
    })
    return config

#fre/experiment/ant_helper.py

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

def get_canvas_image(canvas):
    canvas.draw()
    out_image = np.frombuffer(canvas.tostring_rgb(), dtype='uint8')
    out_image = out_image.reshape(canvas.get_width_height()[::-1] + (3,))
    return out_image

def value_image(env, dataset, value_fn, mask, clip=False):
    """
    Visualize the value function.
    Args:
        env: The environment.
        value_fn: a function with signature value_fn([# states, state_dim]) -> [#states, 1]
    Returns:
        A numpy array of the image.
    """
    fig = plt.figure(tight_layout=True)
    canvas = FigureCanvas(fig)
    plot_value(env, dataset, value_fn, mask, fig, plt.gca(), clip=clip)
    image = get_canvas_image(canvas)
    plt.close(fig)
    return image

def plot_value(env, dataset, value_fn, mask, fig, ax, title=None, clip=True):
    N = 14
    M = 20
    ob_xy = env.XY(n=N, m=M)

    base_observation = np.copy(dataset['observations'][0])
    base_observations = np.tile(base_observation, (5, ob_xy.shape[0], 1))
    base_observations[:, :, :2] = ob_xy
    base_observations[:, :, 15:17] = 0.0
    base_observations[0, :, 15] = 1.0
    base_observations[1, :, 16] = 1.0
    base_observations[2, :, 15] = -1.0
    base_observations[3, :, 16] = -1.0
    print("Base observations, ", base_observations.shape)


    values = []
    for i in range(5):
        values.append(value_fn(base_observations[i]))
    values = np.stack(values, axis=0)
    print("Values", values.shape)

    x, y = ob_xy[:, 0], ob_xy[:, 1]
    x = x.reshape(N, M)
    y = y.reshape(N, M) * 0.975 + 0.7
    values = values.reshape(5, N, M)
    values[-1, 10, 0] = np.min(values[-1]) + 0.3 # Hack to make the scaling not show small errors.
    print("Clip:", clip)
    if clip:
        mesh = ax.pcolormesh(x, y, values[-1], cmap='viridis', vmin=-0.1, vmax=1.0)
    else:
        mesh = ax.pcolormesh(x, y, values[-1], cmap='viridis')

    v = (values[1] - values[3]) / 2
    u = (values[0] - values[2]) / 2
    uv_dist = np.sqrt(u**2 + v**2) + 1e-6
    # Normalize u,v
    un = u / uv_dist
    vn = v / uv_dist
    un[uv_dist < 0.1] = 0
    vn[uv_dist < 0.1] = 0

    plt.quiver(x, y, un, vn, color='r', pivot='mid', scale=0.75, scale_units='xy')

    if mask is not None and type(mask) == np.ndarray:
        # mask = NxM array of things to unmask.
        from matplotlib.colors import LinearSegmentedColormap
        colors = [(0,0,0,c) for c in np.linspace(0,1,100)]
        cmapred = LinearSegmentedColormap.from_list('mycmap', colors, N=5)
        mask_mesh_ax = ax.pcolormesh(x, y, mask, cmap=cmapred)
    elif mask is not None and type(mask) is list:
        maskmesh = np.ones((N, M))
        for xy in mask:
            for xi in range(N):
                for yi in range(M):
                    if np.linalg.norm(np.array(xy) - np.array([x[xi, yi], y[xi, yi]])) < 1.4:
                        # print(xy, x[xi, yi], y[xi, yi])
                        maskmesh[xi,yi] = 0
        from matplotlib.colors import LinearSegmentedColormap
        colors = [(0,0,0,c) for c in np.linspace(0,1,100)]
        cmapred = LinearSegmentedColormap.from_list('mycmap', colors, N=5)
        mask_mesh_ax = ax.pcolormesh(x, y, maskmesh, cmap=cmapred)

    env.draw(ax, scale=0.95)



    # env.draw(ax, scale=1.0)

    # divider = make_axes_locatable(ax)
    # cax = divider.append_axes('right', size='5%', pad=0.05)
    # fig.colorbar(mesh, cax=cax, orientation='vertical')

    if title:
        ax.set_title(title)



  and should_run_async(code)


In [9]:
!pip install opensimplex

  and should_run_async(code)


Collecting opensimplex
  Downloading opensimplex-0.4.5.1-py3-none-any.whl (267 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.0/268.0 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: opensimplex
Successfully installed opensimplex-0.4.5.1


In [10]:
#fre/experiment/rewards_unsupervised
import numpy as np
import tqdm
import opensimplex
import jax
import jax.numpy as jnp
from functools import partial

class RewardFunction():
    # Given a batch of trajectory states and random states, generate a reward function.
    # Return the labelled state-reward pairs. (batch_size, num_pairs, obs_dim + 1)
    def generate_params_and_pairs(self, traj_states, random_states):
        raise NotImplementedError

    # Given a batch of states and a batch of parameters, compute the reward.
    def compute_reward(self, states, params):
        raise NotImplementedError

class GoalReachingRewardFunction(RewardFunction):
    def __init__(self):
        self.p_current = 0.2
        self.p_trajectory = 0.5
        self.p_random = 0.3

    # TODO: If this is slow, we can try and JIT it.
    # Select a random goal from the provided states.
    def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
        all_states = np.concatenate([traj_states, random_states], axis=1)
        batch_size = all_states.shape[0]
        p_trajectory_normalized = self.p_trajectory / traj_states.shape[1]
        p_random_normalized = self.p_random / random_states.shape[1]
        probabilities = [self.p_current] + [p_trajectory_normalized] * (traj_states.shape[1]-1) \
            + [p_random_normalized] * random_states.shape[1]
        probabilities = np.array(probabilities) / np.sum(probabilities)
        selected_goal_idx = np.random.choice(len(probabilities), size=(batch_size,), p=probabilities)
        selected_goal = all_states[np.arange(batch_size), selected_goal_idx]

        params = selected_goal # (batch_size, obs_dim)
        encode_pairs = np.concatenate([traj_states, random_states], axis=1)
        encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
        encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)

        decode_pairs = random_states_decode
        decode_pairs[:, 0] = params # Decode the goal state too.
        decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
        decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)

        rewards = encode_pairs[:, 0, -1] # (batch_size,)
        masks = -rewards # If (rew=-1, mask=1), else (rew=0, mask=0)

        return params, encode_pairs, decode_pairs, rewards, masks

    def compute_reward(self, states, params, delta=False):
        assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
        if states.shape[-1] == 29: # AntMaze
            if delta:
                dists = np.linalg.norm(states - params, axis=-1)
                is_goal = (dists < 0.1)
            else:
                dists = np.linalg.norm(states[..., :2] - params[..., :2], axis=-1)
                is_goal = (dists < 2)
            return -1 + is_goal.astype(float) # (batch_size,)
        elif states.shape[-1] == 18: # Cheetah
            std = np.array([[0.4407440506721877, 10.070289916801876, 0.5172332956856273, 0.5601041145815341, 0.518947027289748, 0.3204431592542281, 0.5501848643154092, 0.3856393812067661, 1.9882502334402663, 1.6377168569884073, 4.308505013609855, 12.144181770553105, 13.537567521831702, 16.88983033626308, 7.715009572436841, 14.345667964212357, 10.6904255152284, 100]])
            assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
            # if len(states.shape) == 3:
            #     breakpoint()
            dists_per_dim = states - params
            dists_per_dim = dists_per_dim / std
            dists = np.linalg.norm(dists_per_dim, axis=-1) / states.shape[-1]
            is_goal = (dists < 0.08)
            # print(dists_per_dim)
            # print(dists, is_goal)
            return -1 + is_goal.astype(float) # (batch_size,)
        elif states.shape[-1] == 27: # Walker
            std = np.array([[0.7212967364054736, 0.6775020895964047, 0.7638155887842976, 0.6395721376821286, 0.6849394775886244, 0.7078581708129903, 0.7113168519036742, 0.6753408522523937, 0.6818095329625652, 0.7133958718133511, 0.65227578338642, 0.757622576816855, 0.7311826446274479, 0.6745824928740024, 0.36822491550384456, 2.1134839667805805, 1.813353841099317, 10.594648894374815, 17.41041469033713, 17.836743227082106, 22.399097178637533, 16.1492222730888, 15.693574546557201, 18.539929326905067, 100, 100, 100]])
            assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
            dists_per_dim = states - params
            dists_per_dim = dists_per_dim / std
            dists = np.linalg.norm(dists_per_dim, axis=-1) / states.shape[-1]
            is_goal = (dists < 0.2)
            return -1 + is_goal.astype(float) # e6yfwsc ebnev (batch_size,)
        elif states.shape[-1] == 30: # Kitchen
            dists_per_dim = states - params
            dists_per_dim = dists_per_dim
            dists = np.linalg.norm(dists_per_dim, axis=-1) / states.shape[-1]
            is_goal = (dists < 1e-6)
            return -1 + is_goal.astype(float)
        else:
            raise NotImplementedError

    def make_encoder_pairs_testing(self, params, random_states):
        assert len(params.shape) == 2, params.shape # (batch_size, 2)
        assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)

        if random_states.shape[-1] == 29: # AntMaze
            random_states[:, 0, :2] = params[:, :2] # Make sure to include the goal.
        else:
            random_states[:, 0] = params
        reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
        reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
        return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)

class LinearRewardFunction(RewardFunction):
    def __init__(self):
        pass

    # Randomly generate a linear weighting over state features.
    def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
        assert len(traj_states.shape) == 3, traj_states.shape # (batch_size, traj_len, obs_dim)
        batch_size = traj_states.shape[0]
        state_len = traj_states.shape[-1]

        params = np.random.uniform(-1, 1, size=(batch_size, state_len)) # Uniform weighting.
        random_mask = np.random.uniform(size=(batch_size,state_len)) < 0.9
        if state_len == 29:
            random_mask[:, :2] = True # Zero out the XY position for antmaze.
        random_mask_positive = np.random.randint(2, state_len, size=(batch_size))
        random_mask[np.arange(batch_size), random_mask_positive] = False # Force at least one positive weight.
        params[random_mask] = 0 # Zero out some of the weights.
        # if state_len == 29:
        #     params = params / np.linalg.norm(params, axis=-1, keepdims=True) # Normalize XY

        # Remove auxilliary features during training.
        if state_len == 27:
            params[:, -3:] = 0
        if state_len == 18:
            params[:, -1:] = 0

        clip_bit = np.random.uniform(size=(batch_size,)) < 0.5
        params = np.concatenate([params, clip_bit[:, None]], axis=-1) # (batch_size, obs_dim + 1)

        encode_pairs = np.concatenate([traj_states, random_states], axis=1)
        encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
        encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)

        decode_pairs = random_states_decode
        decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
        decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)

        rewards = encode_pairs[:, 0, -1] # (batch_size,)
        masks = np.ones_like(rewards) # (batch_size,)

        return params, encode_pairs, decode_pairs, rewards, masks

    def compute_reward(self, states, params):
        params_raw = params[..., :-1]
        assert len(states.shape) == len(params_raw.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
        r = np.sum(states * params_raw, axis=-1) # (batch_size,)
        r = np.where(params[..., -1] > 0, np.clip(r, 0, 1), np.clip(r, -1, 1))
        return r

    def make_encoder_pairs_testing(self, params, random_states):
        assert len(params.shape) == 2, params.shape # (batch_size, 2)
        assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)

        reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
        reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
        return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)

class RandomRewardFunction(RewardFunction):
    def __init__(self, num_simplex, obs_len=29):
        # Pre-compute parameter matrices.
        print("Generating parameter matrices...")
        self.simplex_size = num_simplex
        np_random = np.random.RandomState(0)
        self.param_w1 = np_random.normal(size=(self.simplex_size, obs_len, 32)) * np.sqrt(1/32)
        self.param_b1 = np_random.normal(size=(self.simplex_size, 1, 32)) * np.sqrt(16)
        self.param_w2 = np_random.normal(size=(self.simplex_size, 32, 1)) * np.sqrt(1/16)

        # Remove auxilliary features during training.
        if obs_len == 27:
            self.param_w1[:, -3:] = 0
        if obs_len == 18:
            self.param_w1[:, -1:] = 0

    # Random neural network.
    def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
        batch_size = traj_states.shape[0]
        params = np.random.randint(self.simplex_size, size=(batch_size, 1)) # (batch_size, 1)

        encode_pairs = np.concatenate([traj_states, random_states], axis=1)
        encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
        encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)

        decode_pairs = random_states_decode
        decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
        decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)

        rewards = encode_pairs[:, 0, -1] # (batch_size,)
        masks = np.ones_like(rewards) # (batch_size,)

        return params, encode_pairs, decode_pairs, rewards, masks

    def compute_reward(self, states, params):
        assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)

        param_id = params[..., 0].astype(int)
        param1_w = self.param_w1[param_id]
        param1_b = self.param_b1[param_id]
        param2_w = self.param_w2[param_id]

        obs = states
        x = np.expand_dims(obs, -2) # [batch, (pairs), 1, features_in]
        x = np.matmul(x, param1_w) # [batch, (pairs), 1, features_out]
        x = x + param1_b
        x = np.tanh(x)
        x = np.matmul(x, param2_w) # [batch, (pairs), 1, 1]
        x = x.squeeze(-1).squeeze(-1) # [batch, (pairs)]
        x = np.clip(x, -1, 1)
        return x

    def make_encoder_pairs_testing(self, params, random_states):
        assert len(params.shape) == 2, params.shape # (batch_size, 2)
        assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)
        batch_size = random_states.shape[0]

        reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
        reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
        return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)


  and should_run_async(code)


In [12]:
#fre/experiment/rewards_eval.py

import numpy as np
import tqdm
import opensimplex
import jax
import jax.numpy as jnp
from functools import partial

#from fre.experiment.rewards_unsupervised import RewardFunction


class VelocityRewardFunction(RewardFunction):
    def __init__(self):
        pass

    # Select an XY velocity from a future state in the trajectory.
    def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
        batch_size = traj_states.shape[0]
        selected_traj_state_idx = np.random.randint(traj_states.shape[1], size=(batch_size,))
        selected_traj_state = traj_states[np.arange(batch_size), selected_traj_state_idx] # (batch_size, obs_dim)
        params = selected_traj_state[:, 15:17] # (batch_size, 2)
        params[:batch_size//4] = np.random.uniform(-1, 1, size=(batch_size//4, 2)) # Randomize 25% of the time.
        params = params / np.linalg.norm(params, axis=-1, keepdims=True) # Normalize XY

        encode_pairs = np.concatenate([traj_states, random_states], axis=1)
        encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
        encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)

        decode_pairs = random_states_decode
        decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
        decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)

        rewards = encode_pairs[:, 0, -1] # (batch_size,)
        masks = np.ones_like(rewards) # (batch_size,)

        return params, encode_pairs, decode_pairs, rewards, masks

    def compute_reward(self, states, params):
        assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
        xy_vels = states[..., 15:17] * 0.33820298
        return np.sum(xy_vels * params, axis=-1) # (batch_size,)

    def make_encoder_pairs_testing(self, params, random_states):
        assert len(params.shape) == 2, params.shape # (batch_size, 2)
        assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)

        reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
        reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
        return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)

class TestRewMatrix(RewardFunction):
    def __init__(self):
        self.pos = np.zeros((36, 24))
        self.xvel = np.zeros((36, 24))
        self.yvel = np.zeros((36, 24))

    def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
        batch_size = traj_states.shape[0]
        params = np.zeros((batch_size, 1)) # (batch_size, 1)

        encode_pairs = np.concatenate([traj_states, random_states], axis=1)
        encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
        encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)

        decode_pairs = random_states_decode
        decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
        decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)

        rewards = encode_pairs[:, 0, -1] # (batch_size,)
        masks = np.ones_like(rewards) # (batch_size,)

        return params, encode_pairs, decode_pairs, rewards, masks

    def compute_reward(self, s, params):
        rews = np.zeros_like(s[..., 0]) # (batch, examples)
        # XY Vel Reward
        xy_vels = s[..., 15:17] * 0.33820298

        x = s[..., 0].astype(int).clip(0, 35)
        y = s[..., 1].astype(int).clip(0, 23)
        simplex = self.pos[x, y]
        simplex_xvel = self.xvel[x, y]
        simplex_yvel = self.yvel[x, y]
        rews = (simplex > 0.3).astype(float) * 0.5
        rews += xy_vels[...,0] * simplex_xvel + xy_vels[...,1] * simplex_yvel

        return rews

    def make_encoder_pairs_testing(self, params, random_states):
        assert len(params.shape) == 2, params.shape # (batch_size, 2)
        assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)
        batch_size = random_states.shape[0]

        # TODO: Be smarter about the states to use here.

        reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
        reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
        return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)

class SimplexRewardFunction(RewardFunction):
    def __init__(self, num_simplex):
        self.simplex_size = num_simplex
        self.simplex_seeds_pos = np.zeros((self.simplex_size, 36, 24))
        self.simplex_seeds_xvel = np.zeros((self.simplex_size, 36, 24))
        self.simplex_seeds_yvel = np.zeros((self.simplex_size, 36, 24))
        self.simplex_best_xy = np.zeros((self.simplex_size, 10, 2))
        print("Generating simplex seeds")
        xi = np.arange(36)
        yi = np.arange(24)
        for r in tqdm.tqdm(range(self.simplex_size)):
            opensimplex.seed(r)
            self.simplex_seeds_pos[r] = opensimplex.noise2array(x=xi/20.0, y=yi/20.0).T
            opensimplex.seed(r + self.simplex_size)
            self.simplex_seeds_xvel[r] = opensimplex.noise2array(x=xi/20.0, y=yi/20.0).T
            opensimplex.seed(r + self.simplex_size * 2)
            self.simplex_seeds_yvel[r] = opensimplex.noise2array(x=xi/20.0, y=yi/20.0).T

            best_topn = np.argpartition(self.simplex_seeds_pos[r].flatten(), -10)[-10:] # (10,)
            best_xy = np.array(np.unravel_index(best_topn, self.simplex_seeds_pos[r].shape)).T # (10, 2)
            self.simplex_best_xy[r] = best_xy
        self.simplex_seeds_xvel[np.abs(self.simplex_seeds_xvel) < 0.5] = 0
        self.simplex_seeds_yvel[np.abs(self.simplex_seeds_yvel) < 0.5] = 0

    # Select an XY velocity from a future state in the trajectory.
    def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
        batch_size = traj_states.shape[0]
        params = np.random.randint(self.simplex_size, size=(batch_size, 1)) # (batch_size, 1)

        encode_pairs = np.concatenate([traj_states, random_states], axis=1)
        encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
        encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)

        decode_pairs = random_states_decode
        decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
        decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)

        rewards = encode_pairs[:, 0, -1] # (batch_size,)
        masks = np.ones_like(rewards) # (batch_size,)

        return params, encode_pairs, decode_pairs, rewards, masks

    def compute_reward(self, states, params):
        assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)

        simplex_id = params[..., 0].astype(int)
        x = states[..., 0].astype(int).clip(0, 35)
        y = states[..., 1].astype(int).clip(0, 23)
        simplex = self.simplex_seeds_pos[simplex_id, x, y]
        simplex_xvel = self.simplex_seeds_xvel[simplex_id, x, y]
        simplex_yvel = self.simplex_seeds_yvel[simplex_id, x, y]
        rews = -1 + (simplex > 0.3).astype(float) * 0.5
        xy_vels = states[..., 15:17] * 0.33820298
        rews += xy_vels[...,0] * simplex_xvel + xy_vels[...,1] * simplex_yvel
        return rews # (batch_size,)

    def make_encoder_pairs_testing(self, params, random_states):
        assert len(params.shape) == 2, params.shape # (batch_size, 2)
        assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)
        batch_size = random_states.shape[0]

        # For simplex rewards, make sure to include the top 4 best points.
        simplex_id = params[..., 0].astype(int)
        random_best_4 = np.random.randint(0, 10, size=(batch_size, 4))
        random_states[:, :4, :2] = self.simplex_best_xy[simplex_id[:, None], random_best_4, :]

        reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
        reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
        return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)

class TestRewMatrixEdges(TestRewMatrix):
    def __init__(self):
        super().__init__()
        self.pos[:3, :] = 1
        self.pos[-3:, :] = 1
        self.pos[:, :3] = 1
        self.pos[:, -3:] = 1

class TestRewLoop(TestRewMatrix):
    def __init__(self):
        super().__init__()
        self.pos[22:33, 14:18] = 1
        self.xvel[22:33, 14:18] = -1

        self.pos[21:, 0:3] = 1
        self.xvel[21:, 0:3] = 1

        self.pos[33:, 3:18] = 1
        self.yvel[33:, 3:18] = 1

        self.pos[18:21, 0:7] = 1
        self.yvel[18:21, 0:7] = -1

class TestRewPath(TestRewMatrix):
    def __init__(self):
        super().__init__()
        self.pos[3:21, 7:10] = 1
        self.xvel[3:21, 7:10] = -1

        self.pos[0:3, 3:10] = 1
        self.yvel[0:3, 3:10] = -1

        self.pos[0:18, 0:3] = 1
        self.xvel[0:18, 0:3] = 1

class TestRewLoop2(TestRewMatrix):
    def __init__(self):
        super().__init__()
        self.pos[22:33, 14:18] = 1
        self.pos[21:, 0:3] = 1
        self.pos[33:, 3:18] = 1
        self.pos[18:21, 0:7] = 1

class TestRewPath2(TestRewMatrix):
    def __init__(self):
        super().__init__()
        self.pos[3:21, 7:10] = 1
        self.pos[0:3, 3:10] = 1
        self.pos[0:18, 0:3] = 1


# =================== For DMC

class VelocityRewardFunctionWalker(RewardFunction):
    def __init__(self):
        pass

    def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
        batch_size = traj_states.shape[0]
        params = np.random.uniform(low=0, high=8, size=(batch_size, 1)) # (batch_size, 1)

        encode_pairs = np.concatenate([traj_states, random_states], axis=1)
        encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
        encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)

        decode_pairs = random_states_decode
        decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
        decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)

        rewards = encode_pairs[:, 0, -1] # (batch_size,)
        masks = np.ones_like(rewards) # (batch_size,)

        return params, encode_pairs, decode_pairs, rewards, masks

    def _sigmoids(self, x, value_at_1, sigmoid):
        if sigmoid == 'gaussian':
            scale = np.sqrt(-2 * np.log(value_at_1))
            return np.exp(-0.5 * (x*scale)**2)

        elif sigmoid == 'linear':
            scale = 1-value_at_1
            scaled_x = x*scale
            return np.where(abs(scaled_x) < 1, 1 - scaled_x, 0.0)

    def tolerance(self, x, lower, upper, margin=0.0, sigmoid='gaussian', value_at_margin=0.1):
        in_bounds = np.logical_and(lower <= x, x <= upper)
        d = np.where(x < lower, lower - x, x - upper) / margin
        value = np.where(in_bounds, 1.0, self._sigmoids(d, value_at_margin, sigmoid))
        return value

    def compute_reward(self, states, params):
        assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)

        _STAND_HEIGHT = 1.2
        horizontal_velocity = states[..., 24:25]
        torso_upright = states[..., 25:26]
        torso_height = states[..., 26:27]
        standing = self.tolerance(torso_height, lower=_STAND_HEIGHT, upper=float('inf'), margin=_STAND_HEIGHT/2)
        upright = (1 + torso_upright) / 2
        stand_reward = (3*standing + upright) / 4
        move_reward = self.tolerance(horizontal_velocity,
                                        lower=params,
                                        upper=float('inf'),
                                        margin=params/2,
                                        value_at_margin=0.5,
                                        sigmoid='linear')
        # move_reward[params == 0] = stand_reward[params == 0]
        rew = stand_reward * (5*move_reward + 1) / 6
        return rew[..., 0]

    def make_encoder_pairs_testing(self, params, random_states):
        assert len(params.shape) == 2, params.shape # (batch_size, 2)
        assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)

        reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
        reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
        return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)

class VelocityRewardFunctionCheetah(RewardFunction):
    def __init__(self):
        pass

    def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
        batch_size = traj_states.shape[0]
        params = np.random.uniform(low=-10, high=10, size=(batch_size, 1)) # (batch_size, 1)

        encode_pairs = np.concatenate([traj_states, random_states], axis=1)
        encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
        encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)

        decode_pairs = random_states_decode
        decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
        decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)

        rewards = encode_pairs[:, 0, -1] # (batch_size,)
        masks = np.ones_like(rewards) # (batch_size,)

        return params, encode_pairs, decode_pairs, rewards, masks

    def _sigmoids(self, x, value_at_1, sigmoid):
        if sigmoid == 'linear':
            scale = 1-value_at_1
            scaled_x = x*scale
            return np.where(abs(scaled_x) < 1, 1 - scaled_x, 0.0)
        else:
            raise NotImplementedError

    def tolerance(self, x, lower, upper, margin=0.0, sigmoid='linear', value_at_margin=0):
        in_bounds = np.logical_and(lower <= x, x <= upper)
        d = np.where(x < lower, lower - x, x - upper) / margin
        value = np.where(in_bounds, 1.0, self._sigmoids(d, value_at_margin, sigmoid))
        return value

    def compute_reward(self, states, params):
        assert len(states.shape) == len(params.shape), states.shape # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)

        horizontal_velocity = states[..., 17:18]
        sign_of_param = np.sign(params)
        horizontal_velocity = horizontal_velocity * sign_of_param
        rew = self.tolerance(horizontal_velocity,
                             lower=np.abs(params),
                             upper=float('inf'),
                             margin=np.abs(params),
                             value_at_margin=0,
                             sigmoid='linear')
        return rew[..., 0]

    def make_encoder_pairs_testing(self, params, random_states):
        assert len(params.shape) == 2, params.shape # (batch_size, 2)
        assert len(random_states.shape) == 3, random_states.shape # (batch_size, num_pairs, obs_dim)

        reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
        reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
        return reward_pairs # (batch_size, reward_pairs, obs_dim + 1)

# =================== For Kitchen

class SingleTaskRewardFunction(RewardFunction):
    def __init__(self):
        self.obs_element_indices = {
            "bottom left burner": np.array([11, 12]),
            "top left burner": np.array([15, 16]),
            "light switch": np.array([17, 18]),
            "slide cabinet": np.array([19]),
            "hinge cabinet": np.array([20, 21]),
            "microwave": np.array([22]),
            "kettle": np.array([23, 24, 25, 26, 27, 28, 29]),
        }
        self.obs_element_goals = {
            "bottom left burner": np.array([-0.88, -0.01]),
            "top left burner": np.array([-0.92, -0.01]),
            "light switch": np.array([-0.69, -0.05]),
            "slide cabinet": np.array([0.37]),
            "hinge cabinet": np.array([0.0, 1.45]),
            "microwave": np.array([-0.75]),
            "kettle": np.array([-0.23, 0.75, 1.62, 0.99, 0.0, 0.0, -0.06]),
        }
        self.dist_thresh = 0.3
        self.num_tasks = len(self.obs_element_indices)

    def generate_params_and_pairs(self, traj_states, random_states, random_states_decode):
        batch_size = traj_states.shape[0]
        params = np.random.randint(self.num_tasks, size=(batch_size, 1))  # (batch_size, 1)
        params = np.eye(self.num_tasks)[params[:, 0]]  # (batch_size, num_tasks)

        encode_pairs = np.concatenate([traj_states, random_states], axis=1)
        encode_rewards = self.compute_reward(encode_pairs, params[:, None, :])[:, :, None]
        encode_pairs = np.concatenate([encode_pairs, encode_rewards], axis=-1)

        decode_pairs = random_states_decode
        decode_rewards = self.compute_reward(decode_pairs, params[:, None, :])[:, :, None]
        decode_pairs = np.concatenate([random_states_decode, decode_rewards], axis=-1)

        rewards = encode_pairs[:, 0, -1]  # (batch_size,)
        masks = np.ones_like(rewards)  # (batch_size,)

        return params, encode_pairs, decode_pairs, rewards, masks

    def compute_reward(self, states, params):
        assert len(states.shape) == len(params.shape), states.shape  # (batch_size, obs_dim) OR (batch_size, num_pairs, obs_dim)
        task_rewards = []
        for task, target_indices in self.obs_element_indices.items():
            task_dists = np.linalg.norm(states[..., target_indices] - self.obs_element_goals[task], axis=-1)
            task_completes = (task_dists < self.dist_thresh).astype(float)
            task_rewards.append(task_completes)
        task_rewards = np.stack(task_rewards, axis=-1)

        return np.sum(task_rewards * params, axis=-1)

    def make_encoder_pairs_testing(self, params, random_states):
        assert len(params.shape) == 2, params.shape  # (batch_size, 2)
        assert len(random_states.shape) == 3, random_states.shape  # (batch_size, num_pairs, obs_dim)

        reward_pair_rews = self.compute_reward(random_states, params[:, None, :])[..., None]
        reward_pairs = np.concatenate([random_states, reward_pair_rews], axis=-1)
        return reward_pairs  # (batch_size, reward_pairs, obs_dim + 1)


#fre\eperiment\run_free

import jax
import jax.numpy as jnp
import numpy as np
import optax
import flax.linen as nn
import functools
import ml_collections
from ml_collections import config_flags
from absl import app, flags
import os
import pickle
import tqdm
from flax.training import checkpoints
"""
import matplotlib.pyplot as plt
import wandb

from fre.common.typing import *
from fre.common.networks.transformer import Transformer
import fre.common.networks.transformer as transformer
from fre.common.dataset import Dataset
from fre.common.typing import *
from fre.common.train_state import TrainState, target_update
from fre.common.networks.basic import Policy, ValueCritic, Critic, ensemblize
from fre.common.wandb import setup_wandb, default_wandb_config, get_flag_dict
from fre.common.envs.gc_utils import GCDataset
from fre.common.envs.env_helper import make_env, get_dataset
from fre.common.evaluation import evaluate
from fre.common.envs.wrappers import EpisodeMonitor, RewardOverride, TruncateObservation
from fre.common.utils import supply_rng
from fre.experiment.rewards_unsupervised import *
from fre.experiment.rewards_eval import *

"""
import flax

###############################
#  Configs
###############################

for name in list(flags.FLAGS):
      delattr(flags.FLAGS,name)

FLAGS = flags.FLAGS
flags.DEFINE_string('env_name', 'antmaze-large-diverse-v2', 'Environment name.')
flags.DEFINE_integer('dmc_dataset_size', 5000000, 'ExORL dataset size.')
flags.DEFINE_string('name', 'default', '')

flags.DEFINE_string('save_dir', None, 'Logging dir (if not None, save params).')
flags.DEFINE_string('load_dir', None, 'Logging dir (if not None, load params).')

flags.DEFINE_integer('seed', np.random.choice(1000000), 'Random seed.')
flags.DEFINE_integer('eval_episodes', 20,
                     'Number of episodes used for evaluation.')
flags.DEFINE_integer('log_interval', 1000, 'Logging interval.')
flags.DEFINE_integer('eval_interval', 200000, 'Eval interval.')
flags.DEFINE_integer('save_interval', 200000, 'Eval interval.')
flags.DEFINE_integer('video_interval', 10050000, 'Eval interval.')
flags.DEFINE_integer('batch_size', 512, 'Mini batch size.')
flags.DEFINE_integer('max_steps', int(1e6), 'Number of training steps.')

flags.DEFINE_integer('reward_pairs_encode', 32, 'Number of reward pairs to use for encoding.')
flags.DEFINE_integer('reward_pairs_decode', 8, 'Number of reward pairs to use for decoding.')
flags.DEFINE_integer('reward_pairs_encode_test', 32, 'Number of reward pairs to use for encoding (for testing).')

flags.DEFINE_float('rew_ratio_goal', 0.3333, 'Ratio of reward functions that are goal.')
flags.DEFINE_float('rew_ratio_linear', 0.3333, 'Ratio of reward functions that are linear.')
flags.DEFINE_float('rew_ratio_mlp', 0.3333, 'Ratio of reward functions that are random mlp.')

# Env-Specific Settings
flags.DEFINE_string('start_loc', 'center2', 'Starting location of the ant')
flags.DEFINE_integer('use_discrete_xy', 1, 'Use discrete XY encoding for antmaze?')
flags.DEFINE_integer('dmc_use_oracle', 0, 'Use true rewards during training?')

agent_config = ml_collections.ConfigDict({
    'lr': 1e-4,
    'reward_pairs_emb_dim': 128,
    'hidden_dims': (512, 512, 512),
    'discount': 0.99,
    'expectile': 0.8,
    'temperature': 3.0, # 0 for behavior cloning.
    'tau': 0.001,
    'opt_decay_schedule': 'none',
    'warmup_steps': 150000,
    "num_discrete_embeddings": 32,
    'kl_weight': 0.01,
    'actor_loss_type': 'awr', # awr or ddpg.
    'bc_coefficient': 0.0,
})

wandb_config = default_wandb_config()
wandb_config.update({
    'project': 'fre_fre',
    'name': 'fre_{env_name}',
})

def GCDataset_get_default_config():
        return ml_collections.ConfigDict({
            'p_randomgoal': 0.3,
            'p_trajgoal': 0.5,
            'p_currgoal': 0.2,
            'geom_sample': 1,
            'discount': 0.99,
            'reward_scale': 1.0,
            'reward_shift': -1.0,
            'mask_terminal': 1,
        })

def transformer_get_default_config():
    config = ml_collections.ConfigDict({
        'num_layers': 4,
        'emb_dim': 256,
        'mlp_dim': 256,
        'num_heads': 4,
        'dropout_rate': 0.0,
        'attention_dropout_rate': 0.0,
        'causal': True,
    })
    return config

config_flags.DEFINE_config_dict('wandb', wandb_config, lock_config=False)
config_flags.DEFINE_config_dict('agent', agent_config, lock_config=False)
config_flags.DEFINE_config_dict('gcdataset', GCDataset_get_default_config(), lock_config=False)
config_flags.DEFINE_config_dict('transformer', transformer_get_default_config(), lock_config=False)


###############################
#  Agent. Contains the neural networks, training logic, and sampling.
###############################

def expectile_loss(adv, diff, expectile=0.7):
    weight = jnp.where(adv >= 0, expectile, (1 - expectile))
    return weight * (diff**2)

class IdentityLayer(nn.Module):
    """Identity layer, convenient for giving a name to an array."""

    @nn.compact
    def __call__(self, x):
        return jnp.expand_dims(x, axis=-1)


class FRENetwork(nn.Module):
    transformer_params: dict
    hidden_dims: Sequence[int]
    action_dim: int
    reward_pairs_emb_dim : int
    num_discrete_embeddings: int

    def setup(self):
        self.encoder_transformer = Transformer(**self.transformer_params)
        self.encoder_mean = nn.Dense(self.reward_pairs_emb_dim)
        self.encoder_log_std = nn.Dense(self.reward_pairs_emb_dim)

        self.reward_embed = nn.Embed(self.num_discrete_embeddings, self.reward_pairs_emb_dim // 2)
        self.embed_reward_pairs = nn.Dense(self.reward_pairs_emb_dim // 2)

        self.value = ValueCritic(self.hidden_dims)
        self.critic = ensemblize(Critic, num_qs=2)(self.hidden_dims)
        self.actor = Policy(self.hidden_dims, action_dim=self.action_dim,
            log_std_min=-5.0, state_dependent_std=False, tanh_squash_distribution=False)

        self.reward_predict = ValueCritic(self.hidden_dims)

    def __call__(self, x): # (batch_size, timesteps, emb_dim)
        raise None

    def get_transformer_encoding(self, reward_state_pairs):
        reward_states = reward_state_pairs[:, :, :-1]
        reward_values = reward_state_pairs[:, :, -1]
        reward_values_idx = jnp.floor((reward_values / 2.0 + 0.5) * self.num_discrete_embeddings).astype(jnp.int32)
        reward_values_idx = jnp.clip(reward_values_idx, 0, self.num_discrete_embeddings - 1)

        reward_state_emb = self.embed_reward_pairs(reward_states)
        reward_state_val = self.reward_embed(reward_values_idx)
        reward_state_pairs = jnp.concatenate([reward_state_emb, reward_state_val], axis=-1)

        w_pre = self.encoder_transformer(reward_state_pairs, train=True) # [batch, reward_pairs, emb_dim]
        w_pair_mean = w_pre.mean(axis=1)
        w_mean = self.encoder_mean(w_pair_mean)
        w_log_std = self.encoder_log_std(w_pair_mean)

        return w_mean, w_log_std # (batch_size, emb_dim)

    def get_value(self, w, obs):
        w_and_obs = jnp.concatenate([w, obs], axis=-1)
        return self.value(w_and_obs)

    def get_critic(self, w, obs, actions):
        w_and_obs = jnp.concatenate([w, obs], axis=-1)
        return self.critic(w_and_obs, actions)

    def get_actor(self, w, obs, temperature=1.0):
        w_and_obs = jnp.concatenate([w, obs], axis=-1)
        return self.actor(w_and_obs, temperature)

    def get_reward_pred(self, w, reward_pairs): # Reward Pairs: [batch, reward_pairs, obs_dim + 1]
        z_expand = jnp.expand_dims(w, axis=1) # [batch, 1, emb_dim]
        z_expand = jnp.repeat(z_expand, repeats=reward_pairs.shape[1], axis=1)
        reward_states = reward_pairs[:, :, :-1]
        w_and_obs = jnp.concatenate([z_expand, reward_states], axis=-1)
        reward_pred = self.reward_predict(w_and_obs)
        return reward_pred # [batch, reward_pairs]

    def get_all(self, reward_state_pairs, obs, actions):
        w_mean, w_log_std = self.get_transformer_encoding(reward_state_pairs)
        w = w_mean
        w_and_obs = jnp.concatenate([w, obs], axis=-1)
        ret = self.value(w_and_obs), self.get_actor(w, obs), self.get_reward_pred(w, reward_state_pairs), self.critic(w_and_obs, actions)
        return ret


class FREAgent(flax.struct.PyTreeNode):
    rng: PRNGKey
    fre: TrainState
    target_fre: TrainState
    config: dict = flax.struct.field(pytree_node=False)

    @functools.partial(jax.jit, static_argnames=('train_encoder', 'train_actor', 'train_critic'))
    def update(agent, batch: Batch, train_encoder=True, train_actor=True, train_critic=True, apply_updates=True) -> InfoDict:
        new_rng, w_key = jax.random.split(agent.rng, 2)
        reward_state_pairs = batch['reward_pairs_encode']
        reward_pairs_decode = batch['reward_pairs_decode']

        def full_loss_fn(params):
            if train_encoder:
                w_mean, w_log_std = agent.fre.do('get_transformer_encoding')(reward_state_pairs, params=params)
            else:
                w_mean, w_log_std = agent.fre.do('get_transformer_encoding')(reward_state_pairs)
            w_no_grad = jax.lax.stop_gradient(w_mean)

            if train_encoder:
                # Reward Pred Loss
                w = w_mean + jax.random.normal(agent.rng, w_mean.shape) * jnp.exp(w_log_std)
                reward_pred = agent.fre.do('get_reward_pred')(w, reward_pairs_decode, params=params)
                reward_truths = reward_pairs_decode[:, :, -1]
                reward_pred_loss = ((reward_pred - reward_truths)**2).mean()
                kl_loss = -0.5 * (1 + w_log_std - w_mean**2 - jnp.exp(w_log_std)).mean()
                reward_kl_loss = reward_pred_loss + kl_loss * agent.config['kl_weight']
                reward_pred_info = {
                    'reward_pred_loss': reward_pred_loss,
                    'reward_pred': reward_pred.mean(),
                    'kl_loss': kl_loss,
                }
            else:
                reward_kl_loss = 0.0
                reward_pred_info = {}

            if train_critic:
                # Implicit Q-Learning
                # Value Loss: Update V towards expectile of min(q1, q2).
                w_target_mean = w_no_grad
                w_mean = w_no_grad
                q1, q2 = agent.target_fre.do("get_critic")(w_target_mean, batch['observations'], batch['actions'])
                q = jnp.minimum(q1, q2)
                v = agent.fre.do("get_value")(w_mean, batch['observations'], params=params)
                adv = q - v
                v_loss = expectile_loss(adv, q - v, agent.config['expectile'])
                v_loss = (v_loss).mean()

                # Critic Loss. Update Q = r
                next_v = jax.lax.stop_gradient(agent.fre.do("get_value")(w_mean, batch['next_observations']))
                q = batch['rewards'] + agent.config['discount'] * batch['masks'] * next_v

                q1, q2 = agent.fre.do("get_critic")(w_mean, batch['observations'], batch['actions'], params=params)
                q_loss = (q1 - q) ** 2 + (q2 - q) ** 2
                q_loss = (q_loss).mean()

                value_loss = v_loss + q_loss
                value_info = {
                    # 'value_loss': value_loss,
                    'v_loss': v_loss,
                    'q_loss': q_loss,
                    'v': v.mean(),
                    'q': q.mean(),
                }
            else:
                value_loss = 0.0
                value_info = {}

            if train_actor:
                # Actor Loss
                actor_w = w_mean
                if agent.config['actor_loss_type'] == 'awr':
                    v = agent.fre.do("get_value")(w_no_grad, batch['observations'])
                    q1, q2 = agent.fre.do("get_critic")(w_no_grad, batch['observations'], batch['actions'])
                    q = jnp.minimum(q1, q2)
                    adv = q - v

                    actions = batch['actions']
                    exp_a = jnp.exp(adv * agent.config['temperature'])
                    exp_a = jnp.minimum(exp_a, 100.0)
                    dist = agent.fre.do('get_actor')(actor_w, batch['observations'], params=params)
                    log_probs = dist.log_prob(actions)
                    assert exp_a.shape == log_probs.shape
                    print("Log probs shape", log_probs.shape)
                    actor_loss = -(exp_a * log_probs).mean()
                elif agent.config['actor_loss_type'] == 'ddpg':
                    dist = agent.fre.do("get_actor")(actor_w, batch['observations'], params=params)
                    normalized_actions = jnp.tanh(dist.loc)
                    q1, q2 = agent.fre.do("get_critic")(w_no_grad, batch['observations'], normalized_actions)
                    q = (q1 + q2) / 2

                    q_loss = -q.mean()

                    log_probs = dist.log_prob(batch['actions'])
                    bc_loss = -((agent.config['bc_coefficient'] * log_probs)).mean()

                    actor_loss = ((q_loss + bc_loss)).mean()

                std = dist.stddev().mean()
                mse_error = jnp.square(dist.loc - batch['actions']).mean()
                actor_info = {
                    'actor_loss': actor_loss,
                    'std': std,
                    'adv': adv.mean(),
                    'mse_error': mse_error,
                }
            else:
                actor_loss = 0.0
                actor_info = {}

            return value_loss + actor_loss + reward_kl_loss, {**value_info, **actor_info, **reward_pred_info}

        new_fre, info = agent.fre.apply_loss_fn(loss_fn=full_loss_fn, has_aux=True)
        new_target_fre = target_update(agent.fre, agent.target_fre, agent.config['target_update_rate'])

        return agent.replace(fre=new_fre, target_fre=new_target_fre, rng=new_rng), {
            **info
        }

    @jax.jit
    def sample_actions(agent,
                       observations: np.ndarray, # [obs_dim]
                       reward_pairs: np.ndarray, # [1, reward_pairs, obs_dim + 1]
                       *,
                       seed: PRNGKey,
                       temperature: float = 1.0) -> jnp.ndarray:
        if type(observations) is dict:
            observations = jnp.concatenate([observations['observation'], observations['goal']], axis=-1)
        observations = jnp.expand_dims(observations, axis=0)
        print("Reward pairs shape", reward_pairs.shape)
        w_mean, w_log_std = agent.fre.do('get_transformer_encoding')(reward_pairs)
        print("W shape", w_mean.shape)
        actions = agent.fre.do('get_actor')(w_mean, observations, temperature=temperature).sample(seed=seed)
        actions = jnp.clip(actions, -1, 1)
        return actions[0]

    def get_reward_pred(agent, observations: np.ndarray, reward_pairs: np.ndarray):
        # append a dummy reward to the observations.
        decode_pairs = jnp.concatenate([observations, np.ones((observations.shape[0], 1))], axis=-1)[None]
        w_mean, w_log_std = agent.fre.do('get_transformer_encoding')(reward_pairs)
        return agent.fre.do('get_reward_pred')(w_mean, decode_pairs)

    def get_value_pred(agent, observations: np.ndarray, reward_pairs: np.ndarray):
        w_mean, w_log_std = agent.fre.do('get_transformer_encoding')(reward_pairs) # [batch, emb_dim]
        w_expand = jnp.repeat(w_mean, repeats=observations.shape[0], axis=0)
        v = agent.fre.do('get_value')(w_expand, observations)
        return v

def create_learner(
                seed: int,
                batch: Batch,
                transformer_params: dict,
                lr: float,
                reward_pairs_emb_dim: int,
                num_discrete_embeddings: int,
                kl_weight: float,
                hidden_dims: Sequence[int],
                discount: float,
                tau: float,
                expectile: float,
                temperature: float,
                max_steps: Optional[int],
                opt_decay_schedule: str,
                actor_loss_type: str,
                bc_coefficient: float,
            **kwargs):

        print('Extra kwargs:', kwargs)

        rng = jax.random.PRNGKey(seed)
        rng, actor_key, critic_key, value_key = jax.random.split(rng, 4)

        action_dim = batch['actions'].shape[-1]
        transformer_params['causal'] = False
        transformer_params['emb_dim'] = reward_pairs_emb_dim
        transformer_params['num_heads'] = 2
        transformer_params['num_layers'] = 2
        fre_def = FRENetwork(transformer_params, hidden_dims, action_dim, reward_pairs_emb_dim=reward_pairs_emb_dim, num_discrete_embeddings=num_discrete_embeddings)

        if opt_decay_schedule == "cosine":
            schedule_fn = optax.cosine_decay_schedule(-lr, max_steps)
            tx = optax.chain(optax.scale_by_adam(),
                                    optax.scale_by_schedule(schedule_fn))
        else:
            tx = optax.adam(learning_rate=lr)

        params = fre_def.init(actor_key, batch['reward_pairs_encode'], batch['observations'], batch['actions'], method='get_all')['params']
        fre = TrainState.create(fre_def, params, tx=tx)
        target_fre = TrainState.create(fre_def, params)

        config = flax.core.FrozenDict(dict(
            discount=discount, temperature=temperature, expectile=expectile, target_update_rate=tau, reward_pairs_emb_dim=reward_pairs_emb_dim, kl_weight=kl_weight, actor_loss_type=actor_loss_type, bc_coefficient=bc_coefficient, num_discrete_embeddings=num_discrete_embeddings
        ))

        return FREAgent(rng, fre=fre, target_fre=target_fre, config=config)

###############################
#  Run Script. Loads data, logs to wandb, and runs the training loop.
###############################

def main(_):
    # Create wandb logger
    setup_wandb(FLAGS.agent.to_dict(), **FLAGS.wandb)
    assert 'ant' in FLAGS.env_name or 'dmc' in FLAGS.env_name or 'kitchen' in FLAGS.env_name

    agent = None

    if FLAGS.save_dir is not None:
        os.makedirs(FLAGS.save_dir, exist_ok=True)
        print(f'Saving config to {FLAGS.save_dir}/config.pkl')
        with open(os.path.join(FLAGS.save_dir, 'config.pkl'), 'wb') as f:
            pickle.dump(get_flag_dict(), f)

    if 'ant' in FLAGS.env_name:
        #import fre.common.envs.d4rl.d4rl_ant as d4rl_ant
        env = d4rl_ant.CenteredMaze(FLAGS.env_name)
        dataset = get_dataset(env, FLAGS.env_name)
        dataset = dataset.copy({'masks': np.ones_like(dataset['masks'])})
        dataset_gc = GCDataset(dataset, **FLAGS.gcdataset.to_dict())
        example_batch = dataset.sample(1)
        eval_env = EpisodeMonitor(RewardOverride(d4rl_ant.CenteredMaze(FLAGS.env_name)))
        ## =============== Reward Functions for Testing =============== ##

        base_ob = example_batch['observations'][0]
        def goal_at(x,y):
            goal = base_ob.copy()
            goal[:2] = [x,y]
            return goal
        reward_fn_ratios = [FLAGS.rew_ratio_goal, FLAGS.rew_ratio_linear, FLAGS.rew_ratio_mlp]
        GoalReachingRewards = GoalReachingRewardFunction()
        VelocityRewards = VelocityRewardFunction()
        LinearRewards = LinearRewardFunction()
        SimplexRewards = SimplexRewardFunction(num_simplex=10)
        RandomRewards = RandomRewardFunction(num_simplex=10000)
        reward_fns = [GoalReachingRewards, LinearRewards, RandomRewards]

        linear_states = dataset.sample(5)['observations'][:, None, :]
        linear_params = LinearRewards.generate_params_and_pairs(linear_states, linear_states, linear_states)[0] # (5, params_dim)
        print("Linear Params: ", linear_params)
        test_rewards = [
            (GoalReachingRewards, 'goal_bottom', goal_at(28, 0)),
            (GoalReachingRewards, 'goal_left', goal_at(0, 15)),
            (GoalReachingRewards, 'goal_top', goal_at(35, 24)),
            (GoalReachingRewards, 'goal_center', goal_at(12, 24)),
            (GoalReachingRewards, 'goal_right', goal_at(33, 16)),
            (VelocityRewards, 'vel_left', np.array([-1, 0])),
            (VelocityRewards, 'vel_up', np.array([0, 1])),
            (VelocityRewards, 'vel_down', np.array([0, -1])),
            (VelocityRewards, 'vel_right', np.array([1, 0])),
            (SimplexRewards, 'simplex_1', np.array([1])),
            (SimplexRewards, 'simplex_2', np.array([2])),
            (SimplexRewards, 'simplex_3', np.array([3])),
            (SimplexRewards, 'simplex_4', np.array([4])),
            (SimplexRewards, 'simplex_5', np.array([5])),
            (TestRewPath(), 'path_center', np.array([0])),
            (TestRewLoop(), 'path_loop', np.array([0])),
            (TestRewMatrixEdges(), 'path_edges', np.array([0])),
        ]

        slices = []
        slices.append(0)
        for j in range(len(reward_fns)):
            slices.append(int(FLAGS.batch_size * reward_fn_ratios[j]))
        slices[-1] = FLAGS.batch_size - sum(slices[:-1])
        print("Number of samples for each reward func: ", slices)
        slices = np.cumsum(slices)
        print("Cumsum of samples for each reward func: ", slices)
    elif 'dmc' in FLAGS.env_name:
        _, env_name, task_name = FLAGS.env_name.split('_')
        env = make_env(f'{env_name}_{task_name}')
        env.reset()

        # Load dataset.
        import pathlib
        file_path = str(pathlib.Path().resolve().parents[0])
        path = file_path + f'/fre/data/exorl/{env_name}/rnd'
        dataset_npy = os.path.join(path, task_name + '.npy')
        dataset = np.load(dataset_npy, allow_pickle=True).item()
        dataset['dones_float'] = np.zeros_like(dataset['rewards'])
        dataset['dones_float'][::1000] = 1.0 # Each exorl trajectory is length 1000.
        dataset['dones_float'][-1] = 1.0 # Last state is terminal.

        # For evaluating the velocity rewareds, we need an augmented observation that uses the physics state.
        if 'walker' in FLAGS.env_name:
            aux = np.load(file_path+'/fre/data/aux_walker.npy', allow_pickle=True)
        elif 'cheetah' in FLAGS.env_name:
            aux = np.load(file_path+'/fre/data/aux_cheetah.npy', allow_pickle=True)
        dataset['observations'] = np.concatenate([dataset['observations'], aux], axis=1)
        aux_shifted = np.concatenate([aux[1::], aux[-1:]], axis=0)
        dataset['next_observations'] = np.concatenate([dataset['next_observations'], aux_shifted], axis=1)

        dataset = Dataset(dataset)
        dataset_gc = GCDataset(dataset, **FLAGS.gcdataset.to_dict())
        eval_env = EpisodeMonitor(RewardOverride(make_env(f'{env_name}_{task_name}')))
        eval_env.reset()

        def goal_at(seed):
            return dataset.sample(1, indx=seed*777)['observations']

        VelocityRewardsWalker = VelocityRewardFunctionWalker()
        VelocityRewardsCheetah = VelocityRewardFunctionCheetah()
        GoalReachingRewards = GoalReachingRewardFunction()
        LinearRewards = LinearRewardFunction()
        if 'walker' in FLAGS.env_name:
            reward_fns = [VelocityRewardsWalker]
            RandomRewards = RandomRewardFunction(num_simplex=10000, obs_len=27)
            test_rewards = [
                (VelocityRewardsWalker, 'vel0.1', np.array([0.1])),
                (VelocityRewardsWalker, 'vel1', np.array([1])),
                (VelocityRewardsWalker, 'vel4', np.array([4])),
                (VelocityRewardsWalker, 'vel8', np.array([8])),
                (GoalReachingRewards, 'goal_1', goal_at(1)),
                (GoalReachingRewards, 'goal_2', goal_at(2)),
                (GoalReachingRewards, 'goal_3', goal_at(3)),
                (GoalReachingRewards, 'goal_4', goal_at(4)),
                (GoalReachingRewards, 'goal_5', goal_at(5)),
            ]
        elif 'cheetah' in FLAGS.env_name:
            reward_fns = [VelocityRewardsCheetah]
            RandomRewards = RandomRewardFunction(num_simplex=10000, obs_len=18)
            test_rewards = [
                (VelocityRewardsCheetah, 'vel10Back', np.array([-10])),
                (VelocityRewardsCheetah, 'vel2Back', np.array([-2])),
                (VelocityRewardsCheetah, 'vel2', np.array([2])),
                (VelocityRewardsCheetah, 'vel10', np.array([10])),
                (GoalReachingRewards, 'goal_1', goal_at(1)),
                (GoalReachingRewards, 'goal_2', goal_at(2)),
                (GoalReachingRewards, 'goal_3', goal_at(3)),
                (GoalReachingRewards, 'goal_4', goal_at(4)),
                (GoalReachingRewards, 'goal_5', goal_at(5)),
            ]
        if FLAGS.dmc_use_oracle:
            pass
        else:
            reward_fns = [GoalReachingRewards, LinearRewards, RandomRewards]

        reward_fn_ratios = [FLAGS.rew_ratio_goal, FLAGS.rew_ratio_linear, FLAGS.rew_ratio_mlp]
        slices = []
        slices.append(0)
        for j in range(len(reward_fns)):
            slices.append(int(FLAGS.batch_size * reward_fn_ratios[j]))
        slices[-1] = FLAGS.batch_size - sum(slices[:-1])
        print("Number of samples for each reward func: ", slices)
        slices = np.cumsum(slices)
        print("Cumsum of samples for each reward func: ", slices)

    elif 'kitchen' in FLAGS.env_name:
        # HACK: Monkey patching to make it compatible with Python 3.10.
        import collections
        if not hasattr(collections, 'Mapping'):
            collections.Mapping = collections.abc.Mapping

        def make_kitchen_env(env_name):
            env = make_env(env_name)
            # Only use the first 30 dimensions (because the other half corresponds to the goal).
            env = TruncateObservation(env, truncate_size=30)
            return env
        dataset = get_dataset(make_kitchen_env(FLAGS.env_name), FLAGS.env_name, filter_terminals=True)
        dataset = dataset.copy({'observations': dataset['observations'][:, :30], 'next_observations': dataset['next_observations'][:, :30]})
        dataset = dataset.copy({'masks': np.ones_like(dataset['masks'])})
        dataset_gc = GCDataset(dataset, **FLAGS.gcdataset.to_dict())
        eval_env = EpisodeMonitor(RewardOverride(make_kitchen_env(FLAGS.env_name)))

        SingleTaskRewards = SingleTaskRewardFunction()
        GoalReachingRewards = GoalReachingRewardFunction()
        LinearRewards = LinearRewardFunction()
        RandomRewards = RandomRewardFunction(num_simplex=10000, obs_len=30)
        reward_fns = [GoalReachingRewards, LinearRewards, RandomRewards]

        reward_fn_ratios = [FLAGS.rew_ratio_goal, FLAGS.rew_ratio_linear, FLAGS.rew_ratio_mlp]
        slices = []
        slices.append(0)
        for j in range(len(reward_fns)):
            slices.append(int(FLAGS.batch_size * reward_fn_ratios[j]))
        slices[-1] = FLAGS.batch_size - sum(slices[:-1])
        print("Number of samples for each reward func: ", slices)
        slices = np.cumsum(slices)
        print("Cumsum of samples for each reward func: ", slices)
        test_rewards = [
            (SingleTaskRewards, 'binary_bottom_left_burner', np.array([1, 0, 0, 0, 0, 0, 0])),
            (SingleTaskRewards, 'binary_top_left_burner', np.array([0, 1, 0, 0, 0, 0, 0])),
            (SingleTaskRewards, 'binary_light_switch', np.array([0, 0, 1, 0, 0, 0, 0])),
            (SingleTaskRewards, 'binary_slide_cabinet', np.array([0, 0, 0, 1, 0, 0, 0])),
            (SingleTaskRewards, 'binary_hinge_cabinet', np.array([0, 0, 0, 0, 1, 0, 0])),
            (SingleTaskRewards, 'binary_microwave', np.array([0, 0, 0, 0, 0, 1, 0])),
            (SingleTaskRewards, 'binary_kettle', np.array([0, 0, 0, 0, 0, 0, 1])),
        ]
    else:
        raise NotImplementedError

    for i in tqdm.tqdm(range(1, FLAGS.max_steps + 1),
                       smoothing=0.1,
                       dynamic_ncols=True):

        # Sample a batch of trajectories.
        # Get future states from the trajectory, and random states.
        num_traj_states = min(8, FLAGS.reward_pairs_encode-1)

        num_random_states = FLAGS.reward_pairs_encode - num_traj_states
        num_random_states_decode = FLAGS.reward_pairs_decode
        batch = dataset_gc.sample_traj_random(FLAGS.batch_size, num_traj_states, num_random_states, num_random_states_decode)
        # The first index of traj_states contains the CURRENT state.
        assert batch['traj_states'].shape == (FLAGS.batch_size, num_traj_states, dataset['observations'].shape[-1])
        assert batch['random_states'].shape == (FLAGS.batch_size, num_random_states, dataset['observations'].shape[-1])
        assert batch['random_states_decode'].shape == (FLAGS.batch_size, num_random_states_decode, dataset['observations'].shape[-1])

        encode_pairs = np.zeros((FLAGS.batch_size, FLAGS.reward_pairs_encode, dataset['observations'].shape[-1] + 1))
        decode_pairs = np.zeros((FLAGS.batch_size, FLAGS.reward_pairs_decode, dataset['observations'].shape[-1] + 1))
        rewards = np.zeros((FLAGS.batch_size))
        masks = np.zeros((FLAGS.batch_size))
        for j in range(len(reward_fns)):
            batch_traj_states = batch['traj_states'][slices[j]:slices[j+1], :, :]
            batch_random_states = batch['random_states'][slices[j]:slices[j+1], :, :]
            batch_random_states_decode = batch['random_states_decode'][slices[j]:slices[j+1], :, :]
            params_slice, encode_pairs_slice, decode_pairs_slice, rewards_slice, masks_slice = reward_fns[j].generate_params_and_pairs(batch_traj_states, batch_random_states, batch_random_states_decode)
            encode_pairs[slices[j]:slices[j+1], :, :] = encode_pairs_slice
            decode_pairs[slices[j]:slices[j+1], :, :] = decode_pairs_slice
            rewards[slices[j]:slices[j+1]] = rewards_slice
            masks[slices[j]:slices[j+1]] = masks_slice

        assert len(encode_pairs.shape) == 3 # (batch_size, reward_pairs_encode, obs_dim + 1)
        assert len(decode_pairs.shape) == 3 # (batch_size, reward_pairs_decode, obs_dim + 1)

        batch['rewards'] = rewards
        batch['masks'] = masks
        batch['reward_pairs_encode'] = encode_pairs
        batch['reward_pairs_decode'] = decode_pairs

        if FLAGS.use_discrete_xy and 'ant' in FLAGS.env_name:
            batch['observations'] = d4rl_ant.discretize_obs(batch['observations'])
            batch['next_observations'] = d4rl_ant.discretize_obs(batch['next_observations'])

        # Don't train agents using the auxilliary physics states.
        if 'walker' in FLAGS.env_name:
            batch['observations'] = batch['observations'][:, :24]
            batch['next_observations'] = batch['next_observations'][:, :24]
        elif 'cheetah' in FLAGS.env_name:
            batch['observations'] = batch['observations'][:, :17]
            batch['next_observations'] = batch['next_observations'][:, :17]

        if agent is None:
            agent = create_learner(FLAGS.seed,
                batch,
                transformer_params=FLAGS.transformer.to_dict(),
                max_steps=FLAGS.max_steps,
                **FLAGS.agent)
            if FLAGS.load_dir is not None:
                agent = checkpoints.restore_checkpoint(FLAGS.load_dir, agent)

        agent, update_info = agent.update(batch,
                                        train_encoder=(i <= FLAGS.agent['warmup_steps']),
                                        train_actor=(i > FLAGS.agent['warmup_steps']),
                                        train_critic=(i > FLAGS.agent['warmup_steps'])
        )

        if i % FLAGS.log_interval == 0:
            train_metrics = {f'training/{k}': v for k, v in update_info.items()}

            # Debug logs for wandb.
            if i == FLAGS.log_interval:
                fig, ax = plt.subplots(figsize=(3.0, 3.0))
                reward_pair_rewards = batch['reward_pairs_encode'][:, :, -1].flatten()
                ax.hist(reward_pair_rewards, bins=32)
                ax.set_title("Reward Values (For Reward Pairs)")
                wandb.log({"Reward Values (For Reward Pairs)": wandb.Image(fig)}, step=i)
                fig.clf()
                plt.cla()
                plt.clf()
                plt.close('all')

            wandb.log(train_metrics, step=i)

        # Evaluate the unsupervised rewards on the evaluation rewards.
        if i % 10000 == 0 and i < FLAGS.agent['warmup_steps']:
            eval_rew_metrics = {}
            for k, test_reward in enumerate(test_rewards):
                test_reward_generator, test_reward_label, test_reward_params = test_reward
                random_states_encode = dataset.sample(FLAGS.reward_pairs_encode_test)['observations']
                random_states_decode = dataset.sample(FLAGS.reward_pairs_encode_test)['observations']
                test_reward_pairs = test_reward_generator.make_encoder_pairs_testing(test_reward_params[None], \
                                                                                        random_states_encode[None])
                test_reward_pairs_decode = test_reward_generator.make_encoder_pairs_testing(test_reward_params[None], \
                                                                                        random_states_decode[None])
                assert test_reward_pairs.shape == (1, FLAGS.reward_pairs_encode_test, dataset['observations'].shape[-1] + 1)
                true_decode_rewards = test_reward_pairs_decode[0, :, -1] # (reward_pairs_encode_test, )
                decode_predictions = agent.get_reward_pred(random_states_decode, test_reward_pairs)[0] # (reward_pairs_encode_test, )
                assert true_decode_rewards.shape == decode_predictions.shape
                loss = jnp.mean((true_decode_rewards - decode_predictions)**2)
                eval_rew_metrics[f'rew_pred/{test_reward_label}'] = loss
            if 'ant' in FLAGS.env_name:
                # Merge separate metrics into simpler metrics.
                total_goals = eval_rew_metrics['rew_pred/goal_bottom'] + eval_rew_metrics['rew_pred/goal_center'] + eval_rew_metrics['rew_pred/goal_top'] + eval_rew_metrics['rew_pred/goal_left'] + eval_rew_metrics['rew_pred/goal_right']
                total_velocity = eval_rew_metrics['rew_pred/vel_left'] + eval_rew_metrics['rew_pred/vel_up'] + eval_rew_metrics['rew_pred/vel_down'] + eval_rew_metrics['rew_pred/vel_right']
                total_simplex = eval_rew_metrics['rew_pred/simplex_1'] + eval_rew_metrics['rew_pred/simplex_2'] + eval_rew_metrics['rew_pred/simplex_3'] + eval_rew_metrics['rew_pred/simplex_4'] + eval_rew_metrics['rew_pred/simplex_5']
                total_path = eval_rew_metrics['rew_pred/path_center'] + eval_rew_metrics['rew_pred/path_loop'] + eval_rew_metrics['rew_pred/path_edges']
                eval_rew_metrics['rew_pred_total/total_goals'] = total_goals
                eval_rew_metrics['rew_pred_total/total_velocity'] = total_velocity
                eval_rew_metrics['rew_pred_total/total_simplex'] = total_simplex
                eval_rew_metrics['rew_pred_total/total_path'] = total_path
                wandb.log(eval_rew_metrics, step=i)
            elif 'dmc' in FLAGS.env_name:
                # Merge separate metrics into simpler metrics.
                total_goals = eval_rew_metrics['rew_pred/goal_1'] + eval_rew_metrics['rew_pred/goal_2'] + eval_rew_metrics['rew_pred/goal_3'] + eval_rew_metrics['rew_pred/goal_4'] + eval_rew_metrics['rew_pred/goal_5']
                if 'cheetah' in FLAGS.env_name:
                    total_vel = eval_rew_metrics['rew_pred/vel10Back'] + eval_rew_metrics['rew_pred/vel2Back'] + eval_rew_metrics['rew_pred/vel2'] + eval_rew_metrics['rew_pred/vel10']
                elif 'walker' in FLAGS.env_name:
                    total_vel = eval_rew_metrics['rew_pred/vel0.1'] + eval_rew_metrics['rew_pred/vel1'] + eval_rew_metrics['rew_pred/vel4'] + eval_rew_metrics['rew_pred/vel8']
                eval_rew_metrics['rew_pred_total/total_goals'] = total_goals
                eval_rew_metrics['rew_pred_total/total_vel'] = total_vel
                wandb.log(eval_rew_metrics, step=i)

        # Evaluate on test tasks. These are training tasks AND test tasks.
        if i % FLAGS.eval_interval == 0 or (i == 10000 and FLAGS.eval_interval < 10006000):
            print("Performing Eval Loop")
            record_video = i % FLAGS.video_interval == 0
            eval_metrics = {}

            for k, test_reward in enumerate(test_rewards):
                test_reward_generator, test_reward_label, test_reward_params = test_reward
                print("Eval on reward function", test_reward_label)

                # Update eval env to record the right reward.
                def override_reward(s):
                    r = test_reward_generator.compute_reward(s[None,:], test_reward_params[None, :])
                    return r[0]
                eval_env.env.reward_fn = override_reward
                random_states_encode = dataset.sample(FLAGS.reward_pairs_encode_test)['observations']
                test_reward_pairs = test_reward_generator.make_encoder_pairs_testing(test_reward_params[None], \
                                                                                        random_states_encode[None])
                assert test_reward_pairs.shape == (1, FLAGS.reward_pairs_encode_test, dataset['observations'].shape[-1] + 1)

                # Run policy.
                policy_fn = functools.partial(supply_rng(agent.sample_actions), temperature=0.0, reward_pairs=test_reward_pairs)
                if 'dmc' in FLAGS.env_name:
                    eval_info, trajs = evaluate(policy_fn, eval_env, num_episodes=FLAGS.eval_episodes, record_video=record_video, return_trajectories=True, clip_return_at_goal=('goal' in test_reward_label), use_discrete_xy=False, clip_margin=100)
                elif 'antmaze' in FLAGS.env_name:
                    eval_info, trajs = evaluate(policy_fn, eval_env, num_episodes=FLAGS.eval_episodes, record_video=record_video, return_trajectories=True, clip_return_at_goal=('goal' in test_reward_label), use_discrete_xy=FLAGS.use_discrete_xy)
                elif 'kitchen' in FLAGS.env_name:
                    eval_info, trajs = evaluate(policy_fn, eval_env, num_episodes=FLAGS.eval_episodes, record_video=record_video, return_trajectories=True, clip_return_at_goal=('goal' in test_reward_label), use_discrete_xy=False, binary_return=('binary' in test_reward_label))
                else:
                    raise NotImplementedError

                eval_metrics[f'evaluation/{test_reward_label}.return'] = eval_info['episode.return']
                if record_video:
                    wandb.log({f'{test_reward_label}.video': eval_info['video']}, step=i)

                # Antmaze Specific Logging
                if 'antmaze' in FLAGS.env_name and 'large' in FLAGS.env_name and FLAGS.env_name.startswith('antmaze'):
                    #import fre.experiment.ant_helper as ant_helper
                    # Make an image of the trajectories.
                    traj_image = d4rl_ant.trajectory_image(eval_env, trajs)
                    # eval_metrics[f'trajectories/{test_reward_label}'] = wandb.Image(traj_image)

                    # Make image of reward function predictions.
                    test_reward_expand = np.tile(test_reward_params[None, :], (280, 1)) # (280, 3)
                    ground_truth_rew = lambda s_grid : test_reward_generator.compute_reward(s_grid, test_reward_expand)
                    true_rew_img = ant_helper.value_image(eval_env, dataset, ground_truth_rew, None)
                    # eval_metrics[f'draw_true/{test_reward_label}'] = wandb.Image(true_rew_img)

                    mask = []
                    for pair in test_reward_pairs[0]:
                        mask.append(pair[:2])
                    mask_rew_img = ant_helper.value_image(eval_env, dataset, ground_truth_rew, mask)
                    # eval_metrics[f'draw_mask/{test_reward_label}'] = wandb.Image(mask_rew_img)

                    pred_rew = lambda s_grid : agent.get_reward_pred(s_grid, test_reward_pairs)
                    pred_rew_img = ant_helper.value_image(eval_env, dataset, pred_rew, None)
                    # eval_metrics[f'draw_pred/{test_reward_label}'] = wandb.Image(pred_rew_img)

                    def pred_value(s_grid):
                        if FLAGS.use_discrete_xy and 'ant' in FLAGS.env_name:
                            s_grid = d4rl_ant.discretize_obs(s_grid)
                        return agent.get_value_pred(s_grid, test_reward_pairs)
                    pred_value_img = ant_helper.value_image(eval_env, dataset, pred_value, None, clip=False)
                    # eval_metrics[f'draw_value1/{test_reward_label}'] = wandb.Image(pred_value_img1)


                    full_img = np.concatenate([
                        np.concatenate([true_rew_img, mask_rew_img], axis=0),
                        np.concatenate([pred_rew_img, traj_image], axis=0),
                        np.concatenate([pred_value_img, pred_value_img], axis=0)
                    ], axis=1)
                    print("Min/Max of full_img is", np.min(full_img), np.max(full_img))
                    # if any nans, breakpoint.
                    if np.isnan(full_img).any():
                        breakpoint()
                    eval_metrics[f'draw/{test_reward_label}'] = wandb.Image(full_img)

            if 'ant' in FLAGS.env_name:
                # Merge separate metrics into simpler metrics.
                total_goals = eval_metrics['evaluation/goal_bottom.return'] + eval_metrics['evaluation/goal_center.return'] + eval_metrics['evaluation/goal_top.return'] + eval_metrics['evaluation/goal_left.return'] + eval_metrics['evaluation/goal_right.return']
                total_velocity = eval_metrics['evaluation/vel_left.return'] + eval_metrics['evaluation/vel_up.return'] + eval_metrics['evaluation/vel_down.return'] + eval_metrics['evaluation/vel_right.return']
                total_simplex = eval_metrics['evaluation/simplex_1.return'] + eval_metrics['evaluation/simplex_2.return'] + eval_metrics['evaluation/simplex_3.return'] + eval_metrics['evaluation/simplex_4.return'] + eval_metrics['evaluation/simplex_5.return']
                total_path = eval_metrics['evaluation/path_center.return'] + eval_metrics['evaluation/path_loop.return'] + eval_metrics['evaluation/path_edges.return']
                eval_metrics['evaluation_total/total_goals'] = total_goals
                eval_metrics['evaluation_total/total_velocity'] = total_velocity
                eval_metrics['evaluation_total/total_simplex'] = total_simplex
                eval_metrics['evaluation_total/total_path'] = total_path
                print(eval_metrics)
            elif 'dmc' in FLAGS.env_name:
                total_goals = eval_metrics['evaluation/goal_1.return'] + eval_metrics['evaluation/goal_2.return'] + eval_metrics['evaluation/goal_3.return'] + eval_metrics['evaluation/goal_4.return'] + eval_metrics['evaluation/goal_5.return']
                eval_metrics['evaluation/total_goals'] = total_goals
            elif 'kitchen' in FLAGS.env_name:
                total_test = 0.
                for test_reward in test_rewards:
                    total_test += eval_metrics[f'evaluation/{test_reward[1]}.return']
                eval_metrics['evaluation/total_test'] = total_test

            wandb.log(eval_metrics, step=i)

        if i % FLAGS.save_interval == 0 and FLAGS.save_dir is not None:
            checkpoints.save_checkpoint(FLAGS.save_dir, agent, i)

if __name__ == '__main__':
    app.run(main)


  and should_run_async(code)


KeyError: 'verbosity'