In [None]:
!pip install mctx
!pip install jumanji
!pip install flashbax
!pip install dm-haiku

In [None]:
# solver.py
from typing import Callable, Any

import jax
import jax.numpy as jnp


def bisection_method(
        f: Callable[[jax.Array, dict[str, Any]], jax.Array],
        n: int,
        max_steps: int,
        step_size: float = 0.0
) -> Callable[..., jax.Array]:
    """Compile a bisection method.

    Converges to the root (if it exists) with O(n * sqrt(max_steps))

    Only supports 1-Dimensional search.
    """
    batch_fun = jax.vmap(f, in_axes=(0, None))

    def body(
            carry: tuple[tuple[jax.Array, jax.Array], dict],
            x: None = None
    ) -> tuple[
        tuple[tuple[jax.Array, jax.Array], dict],
        jax.Array | float
    ]:
        prev_bounds, kwargs = carry

        grid = jnp.linspace(*prev_bounds, 2 * n + 1)
        step = grid[1] - grid[0]
        values = batch_fun(grid, kwargs).squeeze()
        values = jnp.nan_to_num(values, nan=1e32, neginf=-1e10, posinf=1e10)

        # Transform values to make the function's roots an attractor.
        # Then argmax/ argmin can find the best bounds without array reshaping.
        best_positive = grid.at[jnp.argmax(1.0 / (values + 1e-32))].get()
        best_negative = grid.at[jnp.argmin(1.0 / (values - 1e-32))].get()

        # Construct new-bounds closest to the function root with some slack.
        new_bounds = (best_negative - step * step_size,
                      best_positive + step * step_size)

        return (new_bounds, kwargs), sum(new_bounds) / 2.0

    def run(
            bounds: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike],
            kwargs: dict[str, Any]
    ) -> jax.Array:
        _, best = jax.lax.scan(
            body, (bounds, kwargs), xs=None, length=max_steps
        )
        return best[-1]

    return run


In [None]:
# network_2048.py
import jax
import jax.numpy as jnp
from flax import linen as nn
from jax import grad, jit
from flax.training import train_state
import optax


class PolicyValueNetwork_2048(nn.Module):
    """Joint policy/value network: one convolution feeding two MLP heads.

    `num_actions` sets the width of the policy output and `num_channels`
    the number of features of the convolution.
    """
    num_actions: int
    num_channels: int

    @nn.compact
    def __call__(self, x):
        """Return `(action_probabilities, value)` for a batch of inputs.

        The input needs a leading batch axis: the flatten step keeps
        `shape[0]` and merges every remaining axis.
        """
        # Shared torso: a single 2x2 convolution, then flatten per sample.
        features = nn.Conv(features=self.num_channels, kernel_size=(2, 2))(x)
        features = nn.leaky_relu(features)
        features = jnp.reshape(features, (features.shape[0], -1))

        # Policy head. Layers are created in the same order as before so the
        # automatically assigned parameter names do not change.
        # TODO try value norm
        policy = features
        for width in (128, 128):
            policy = nn.leaky_relu(nn.Dense(width)(policy))
        policy = nn.softmax(nn.Dense(self.num_actions)(policy))

        # Value head.
        value = features
        for width in (256, 256):
            value = nn.leaky_relu(nn.Dense(width)(value))
        value = nn.Dense(1)(value)

        return policy, value

In [None]:
# network_knapsack.py
import jax
import jax.numpy as jnp
import chex
import haiku as hk
from flax.linen import compact
from jumanji.training.networks.knapsack.actor_critic import (
    make_knapsack_masks,
    make_knapsack_query,
    KnapsackTorso,
)
from jumanji.environments.packing.knapsack.types import Observation
from typing import Tuple

# Transformer hyperparameters read by `value_fn` and `policy_fn` below.
config = {
    "transformer_num_blocks": 6,  # passed to KnapsackTorso
    "transformer_num_heads": 8,  # passed to KnapsackTorso and hk.MultiHeadAttention
    "transformer_key_size": 16,  # passed to KnapsackTorso and hk.MultiHeadAttention
    "transformer_mlp_units": [512],  # passed to KnapsackTorso
}

def _knapsack_item_scores(
    observation: Observation,
) -> Tuple[chex.Array, chex.Array]:
    """Compute scaled per-item attention scores shared by both heads.

    Each call constructs its own `KnapsackTorso` and
    `hk.MultiHeadAttention`, in the same order the callers previously did
    inline, so calling this once per head keeps the heads' modules separate.

    Args:
        observation: Knapsack observation; `weights`, `values` and whatever
            `make_knapsack_masks` / `make_knapsack_query` read are used.

    Returns:
        A pair `(scores, cross_attention_mask)` where `scores` is the
        embedding/attention inner product divided by
        `sqrt(cross_attention_block.model_size)`.
    """
    torso = KnapsackTorso(
        transformer_num_blocks=config["transformer_num_blocks"],
        transformer_num_heads=config["transformer_num_heads"],
        transformer_key_size=config["transformer_key_size"],
        transformer_mlp_units=config["transformer_mlp_units"],
        name="torso",
    )
    self_attention_mask, cross_attention_mask = make_knapsack_masks(observation)
    # Every item is described by its (weight, value) pair.
    items_features = jnp.concatenate(
        [observation.weights[..., None], observation.values[..., None]], axis=-1
    )
    embeddings = torso(items_features, self_attention_mask)
    query = make_knapsack_query(observation, embeddings)
    cross_attention_block = hk.MultiHeadAttention(
        num_heads=config["transformer_num_heads"],
        key_size=config["transformer_key_size"],
        w_init=hk.initializers.VarianceScaling(1.0),
        name="cross_attention_block",
    )
    cross_attention = cross_attention_block(
        query=query,
        value=embeddings,
        key=embeddings,
        mask=cross_attention_mask,
    ).squeeze(axis=-2)
    scores = jnp.einsum("...Tk,...k->...T", embeddings, cross_attention)
    scores = scores / jnp.sqrt(cross_attention_block.model_size)
    return scores, cross_attention_mask


def value_fn(observation: Observation) -> chex.Array:
    """Value head: sum of the item scores selected by the cross-attention mask."""
    values, cross_attention_mask = _knapsack_item_scores(observation)
    value = values.sum(axis=-1, where=cross_attention_mask.squeeze(axis=(-2, -3)))
    return value


def policy_fn(observation: Observation) -> chex.Array:
    """Policy head: softmax over squashed item scores with illegal items masked."""
    logits, _ = _knapsack_item_scores(observation)
    logits = 10 * jnp.tanh(logits)  # clip to [-10,10]
    # Entries where `action_mask` is False get the most negative float32 so
    # the softmax assigns them (numerically) zero probability.
    logits = jnp.where(observation.action_mask, logits, jnp.finfo(jnp.float32).min)
    actions = jax.nn.softmax(logits)
    return actions


def forward_fn(inputs) -> Tuple[chex.Array, chex.Array]:
    """Unpack a stacked knapsack observation and evaluate both heads.

    `inputs` is indexed as `inputs[:, i, :]`, so it needs three axes with at
    least four entries along the middle one. Entry 0 is read as weights,
    1 as values, 2 as the packed-items flags and 3 as the action mask.

    Returns:
        `(policy_fn(observation), value_fn(observation))`.
    """
    weights = inputs[:, 0, :].astype(jnp.float32)
    values = inputs[:, 1, :].astype(jnp.float32)
    # `jnp.bool_` is used instead of `jnp.bool` for portability across JAX
    # releases (NOTE(review): `jnp.bool` is presumably missing in older ones).
    packed_items = inputs[:, 2, :].astype(jnp.bool_)
    action_mask = inputs[:, 3, :].astype(jnp.bool_)
    observation = Observation(weights, values, packed_items, action_mask)
    # Keep the call order (policy first, then value) unchanged.
    return policy_fn(observation), value_fn(observation)


In [None]:
# cnn.py
import jax
import jax.numpy as jnp
from flax import linen as nn
from jax import grad, jit
from flax.training import train_state
import optax
from jax import random


class CNNPolicyNetwork(nn.Module):
    """Convolutional policy network that outputs action probabilities."""

    num_actions: int  # Number of possible actions
    num_channels: int  # Number of features of each convolution

    @nn.compact
    def __call__(self, x):
        """Map a batch of inputs to a softmax distribution over actions.

        The input must be batched: the flatten step keeps `shape[0]` and
        merges all remaining axes.
        """
        hidden = x

        # Two 3x3 convolutions with leaky-ReLU activations.
        for _ in range(2):
            conv = nn.Conv(features=self.num_channels, kernel_size=(3, 3))
            hidden = nn.leaky_relu(conv(hidden))

        hidden = jnp.reshape(hidden, (hidden.shape[0], -1))  # Flatten

        # Three hidden dense layers of width 64.
        for _ in range(3):
            hidden = nn.leaky_relu(nn.Dense(64)(hidden))

        return nn.softmax(nn.Dense(self.num_actions)(hidden))


# Define the Value Network
class CNNValueNetwork(nn.Module):
    """Convolutional value network with a single output per sample."""

    num_channels: int  # Number of features of each convolution

    @nn.compact
    def __call__(self, x):
        """Map a batch of inputs to one value each (last axis of size 1).

        The input must be batched: the flatten step keeps `shape[0]` and
        merges all remaining axes.
        """
        hidden = x

        # Two 3x3 convolutions with leaky-ReLU activations.
        for _ in range(2):
            conv = nn.Conv(features=self.num_channels, kernel_size=(3, 3))
            hidden = nn.leaky_relu(conv(hidden))

        hidden = jnp.reshape(hidden, (hidden.shape[0], -1))  # Flatten

        # Three hidden dense layers of width 128.
        for _ in range(3):
            hidden = nn.leaky_relu(nn.Dense(128)(hidden))

        return nn.Dense(1)(hidden)


In [None]:
# network_snake.py
import jax.numpy as jnp
from flax import linen as nn


class PolicyValueNetwork(nn.Module):
    """Joint policy/value network: two convolutions feeding two MLP heads.

    `num_actions` sets the width of the policy output and `num_channels`
    the number of features of both convolutions.
    """
    num_actions: int
    num_channels: int

    @nn.compact
    def __call__(self, x):
        """Return `(action_probabilities, value)` for a batch of inputs.

        The input needs a leading batch axis: the flatten step keeps
        `shape[0]` and merges every remaining axis.
        """
        kernel = (2, 2)

        # Shared torso: a strided convolution followed by an unstrided one.
        features = nn.leaky_relu(
            nn.Conv(features=self.num_channels, kernel_size=kernel, strides=(2, 2))(x)
        )
        features = nn.leaky_relu(
            nn.Conv(features=self.num_channels, kernel_size=kernel)(features)
        )
        features = jnp.reshape(features, (features.shape[0], -1))  # Flatten

        # Policy head. Layers are created in the same order as before so the
        # automatically assigned parameter names do not change.
        policy = features
        for width in (64, 64):
            policy = nn.leaky_relu(nn.Dense(width)(policy))
        policy = nn.softmax(nn.Dense(self.num_actions)(policy))

        # Value head.
        value = features
        for width in (128, 128):
            value = nn.leaky_relu(nn.Dense(width)(value))
        value = nn.Dense(1)(value)

        return policy, value


In [None]:
# wandb_logging.py
import wandb
import jax.numpy as jnp

# Short labels, keyed by `params['env_name']`, used to build W&B run names.
env_short_names = {
    "Game2048-v1": "2048",
    "Knapsack-v1": "Knapsack",
    "Maze-v0": "Maze",
    "Snake-v1": "Snake",
}

def init_wandb(params):
    """Log in to Weights & Biases and start a run for this experiment.

    The API key comes from the Kaggle secret `wandb_api` when
    `params["run_in_kaggle"]` is truthy, otherwise from the
    `WANDB_API_KEY` environment variable (a missing variable raises
    `KeyError`). The run config is `params` minus a fixed list of keys.
    """
    if params["run_in_kaggle"]:
        from kaggle_secrets import UserSecretsClient
        wandb_api = UserSecretsClient().get_secret("wandb_api")
    else:
        import os
        wandb_api = os.environ['WANDB_API_KEY']

    # Entries that are not forwarded to the W&B run config.
    excluded_keys = [
        "maze_size", "agent", "num_actions", "obs_spec", "run_in_kaggle",
        "logging", "buffer_max_length", "buffer_min_length",
    ]
    relevant_params = {}
    for key, value in params.items():
        if key not in excluded_keys:
            relevant_params[key] = value

    wandb.login(key=wandb_api)

    run_name = f"{env_short_names[params['env_name']]}_{params['policy']}_sim{params['num_simulations']}_seed{params['seed']}"
    wandb.init(
        project="action-selection-mcts",
        name=run_name,
        config=relevant_params,
    )


def log_rewards(reward, loss, episode, params):
    """Forward `reward` to `wandb.log`.

    `loss`, `episode` and `params` are accepted but not used here.
    NOTE(review): `reward` is presumably a dict of metrics, as `wandb.log`
    is called with it directly — confirm against the callers.
    """
    wandb.log(reward)

In [None]:
# plotting.py
import matplotlib.pyplot as plt
import numpy as np
import wandb
def plot_losses(all_results_array):
    """Plot the per-episode losses and, if a W&B run is active, log the figure.

    Args:
        all_results_array: 2-D array-like with one row per episode. With
            exactly three columns, column 2 is plotted as a single combined
            loss. Otherwise column 2 is plotted as the value loss and
            column 3 as the policy loss on a second y-axis.
    """
    all_results_array = np.array(all_results_array)

    fig, ax1 = plt.subplots()

    # Loss curve(s) go on the left y-axis.
    ax1.set_xlabel('Episode')
    ax1.tick_params(axis='y', labelcolor='b')
    if all_results_array.shape[1] == 3:  # Single, combined loss
        ax1.plot(all_results_array[:, 2], label="Loss", color='b')
        ax1.set_ylabel('Loss', color='b')
    else:  # Separate value and policy loss
        ax1.set_ylabel('Value Loss', color='b')
        ax1.plot(all_results_array[:, 2], label="Value Loss", color='b')

        # Create a second y-axis sharing the same x-axis
        ax2 = ax1.twinx()
        ax2.plot(all_results_array[:, 3], label="Policy Loss", color='r')
        ax2.set_ylabel('Policy Loss', color='r')
        ax2.tick_params(axis='y', labelcolor='r')

    ax1.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

    # One legend for the whole figure so both axes are covered.
    fig.legend(loc="upper right", bbox_to_anchor=(1, 1), bbox_transform=ax1.transAxes)

    plt.show()
    if wandb.run is not None:
        plot = wandb.Image(fig)
        wandb.log({"Loss": plot})

    # Always release this specific figure. `plt.clf()` only cleared whichever
    # figure was current and left figures open across repeated calls.
    plt.close(fig)

def plot_rewards(all_results_array):
    """Plot the per-episode returns and, if a W&B run is active, log the figure.

    Args:
        all_results_array: 2-D array-like with one row per episode. Column 0
            is plotted as "Total Return" on the left y-axis and column 1 as
            "Max Return" on a second y-axis.
    """
    all_results_array = np.array(all_results_array)

    fig, ax1 = plt.subplots()

    # Plot the Total Return on the left y-axis
    ax1.plot(all_results_array[:, 0], label="Total Return", color='b')
    ax1.set_xlabel('Episode')
    ax1.set_ylabel('Total Return', color='b')
    ax1.tick_params(axis='y', labelcolor='b')

    # Create a second y-axis sharing the same x-axis
    ax2 = ax1.twinx()
    ax2.plot(all_results_array[:, 1], label="Max Return", color='r')
    ax2.set_ylabel('Max Return', color='r')
    ax2.tick_params(axis='y', labelcolor='r')

    ax1.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

    # One legend for the whole figure so both axes are covered.
    fig.legend(loc="upper right", bbox_to_anchor=(1, 1), bbox_transform=ax1.transAxes)

    plt.show()
    if wandb.run is not None:
        plot = wandb.Image(fig)
        wandb.log({"Returns": plot})

    # Always release this specific figure. `plt.clf()` only cleared whichever
    # figure was current and left figures open across repeated calls.
    plt.close(fig)

In [None]:
# interface.py
from typing import Any

import abc

import jax
import jax.numpy as jnp
import jax.typing as jxt

from dataclasses import dataclass



@dataclass
class PolicyObjective(abc.ABC):
    """Base class for implementing proposal distributions for control.

    The proposal distribution solves a constrained program,
        max_q <Q, q>
        s.t.,
        D(q, pi) < epsilon
        int q(a) da = 1

    For the hard-constraint, one needs to specify epsilon to __call__.
    For the soft-constraint solution, one can also specify the Lagrange
    multiplier directly.

    Only discrete spaces are supported for a very good reason. For
    compatibility with continuous spaces one needs to sample from `pi`
    directly, and then passing a uniform distribution to __call__.
    """
    solver: Any | None = None
    logits: bool = False

    # Tolerance constants for validating solutions.
    _norm_tolerance: float = 0.001
    _epsilon_ltol: float = 0.001
    _epsilon_rtol: float = 0.01

    def set_solver(self, solver):
        self.solver = solver

    @staticmethod
    def _validate_args(
            inv_beta: jxt.ArrayLike | None,
            epsilon: jxt.ArrayLike | None,
            raise_ambiguity: bool = False
    ):
        if (epsilon is None) and (inv_beta is None):
            raise ValueError("Both `epsilon` and `inv_beta` cannot be None!")

        if raise_ambiguity:
            if (epsilon is not None) and (inv_beta is not None):
                raise ValueError(
                    "Ambiguity Error. Values given for both "
                    "`epsilon` and `inv_beta`! "
                )

    @staticmethod
    def epsilon_greedy(
            q: jax.Array,
            eps: jax.typing.ArrayLike = 0.0
    ) -> jax.Array:
        greedy = (q == q.max())
        return (greedy / greedy.sum()) * (1 - eps) + eps / q.size

    @abc.abstractmethod
    def trust_region_upperbound(
            self, q: jax.Array, pi: jax.Array
    ) -> jax.Array:
        pass

    @abc.abstractmethod
    def lagrangian(self, *args, **kwargs) -> tuple:
        pass

    @staticmethod
    @abc.abstractmethod
    def divergence(
            q_star: jax.Array,
            pi: jax.Array
    ) -> jxt.ArrayLike:
        pass

    @abc.abstractmethod
    def __call__(
            self,
            q: jax.Array,
            pi: jax.Array,
            *,
            epsilon: jxt.ArrayLike | None = None,
            inv_beta: jxt.ArrayLike | None = None
    ) -> jax.Array:
        # Compute Lagrangian solution w.r.t. q
        pass


@dataclass
class AnalyticObjective(PolicyObjective, abc.ABC):
    """Base class for objectives whose solution is fully analytical."""

    def __init__(self, *, logits: bool = False):
        # A bare `object()` is stored as the solver; presumably only a
        # non-None placeholder, since no numerical solve is performed here.
        placeholder_solver = object()
        super().__init__(solver=placeholder_solver, logits=logits)

    def lagrangian(self, *args, **kwargs) -> tuple:
        """Return the one-element tuple `(0,)`; the solution is analytical."""
        return (0,)


@dataclass
class NumericalNormalizerObjective(PolicyObjective, abc.ABC):
    """Base class for objectives that solve for the normalizer numerically.

    If no solver is supplied, a `bisection_method` solver is built over the
    first component of `lagrangian`, using `num_init` and `recursive_steps`
    as its `n` and `max_steps`.
    """
    num_init: int = 30
    recursive_steps: int = 10

    def __post_init__(self):
        # Respect a solver that was passed in explicitly.
        if self.solver is not None:
            return

        def root_function(x, fun_kwargs):
            return self._objective(x, **fun_kwargs)

        self.solver = bisection_method(
            root_function, self.num_init, self.recursive_steps
        )

    def _objective(self, x, **kwargs) -> jax.Array:
        """Return the first element of the `lagrangian` tuple."""
        terms = self.lagrangian(x, **kwargs)
        return terms[0]

    @abc.abstractmethod
    def get_search_bounds(
            self,
            q: jax.Array,
            log_pi: jax.Array,
            *,
            inv_beta: jax.Array | None = None,
            epsilon: jax.Array | None = None
    ) -> tuple[jax.Array, jax.Array]:
        """Get informative bounds for the normalizer in the true search-space.
        """

    @abc.abstractmethod
    def solve_normalizer(self, *args, **kwargs) -> jxt.ArrayLike:
        """Abstract: implemented by subclasses or mixins."""


@dataclass
class NumericalTrustRegionObjective(PolicyObjective, abc.ABC):
    """Base class for objectives without analytical trust-region constraint.

    If no solver is supplied, a `bisection_method` solver is built over the
    first component of `lagrangian`, using `num_init` and `recursive_steps`
    as its `n` and `max_steps`.
    """
    num_init: int = 30
    recursive_steps: int = 10

    def __post_init__(self):
        # Respect a solver that was passed in explicitly.
        if self.solver is not None:
            return

        def root_function(x, fun_kwargs):
            return self._objective(x, **fun_kwargs)

        self.solver = bisection_method(
            root_function, self.num_init, self.recursive_steps
        )

    def _objective(self, x, **kwargs) -> jax.Array:
        """Return the first element of the `lagrangian` tuple."""
        terms = self.lagrangian(x, **kwargs)
        return terms[0]

    @abc.abstractmethod
    def solve_trust_region(self, *args, **kwargs) -> jxt.ArrayLike:
        """Abstract: implemented by subclasses or mixins."""


@dataclass
class NumericalNormalizerAndTrustRegionObjective(PolicyObjective, abc.ABC):
    """Base class for objectives solving normalizer and trust-region numerically.

    `_objective` stacks every component of `lagrangian` into one array. If
    no solver is supplied, a `bisection_method` solver is built over
    component 0 of that array, using `num_init` and `recursive_steps` as its
    `n` and `max_steps`.
    """
    num_init: int = 30
    recursive_steps: int = 10

    def __post_init__(self):
        # Respect a solver that was passed in explicitly.
        if self.solver is not None:
            return

        def first_component(x, fun_kwargs):
            return self._objective(x, **fun_kwargs)[0]

        self.solver = bisection_method(
            first_component,
            self.num_init, self.recursive_steps, step_size=0.0
        )

    def _objective(self, x, **kwargs) -> jax.Array:
        """Return all `lagrangian` components as a single array."""
        terms = self.lagrangian(x, **kwargs)
        return jnp.asarray(terms)

    @abc.abstractmethod
    def get_search_bounds(
            self,
            q: jax.Array,
            log_pi: jax.Array,
            *,
            inv_beta: jax.Array | None = None,
            epsilon: jax.Array | None = None
    ) -> tuple[jax.Array, jax.Array]:
        """Get informative bounds for the normalizer in the true search-space.
        """

    @abc.abstractmethod
    def solve(self, *args, **kwargs) -> jxt.ArrayLike:
        """Abstract: implemented by subclasses or mixins."""


In [None]:
# mixins.py
"""Module for useful mixin classes to extend base class functionalities.

"""
from typing import Callable, Any
from functools import partial

import abc

import jax
import jax.numpy as jnp



class TypesMixin:
    """Declarations of what the solver mixins expect from their host class.

    Apart from `bounds_slack`, every entry is a bare annotation without a
    value; the values are expected to come from the class the mixin is
    combined with.
    """
    # Solver configuration and the solver callable itself.
    num_init: int
    recursive_steps: int
    solver: Any
    # Search bounds; the mixins use them as bounds for log_inv_beta.
    bounds: tuple[float, float]

    # Slack subtracted from the transformed lower bound (index 0) and added
    # to the transformed upper bound (index 1) before solving.
    bounds_slack: tuple[float, float] = (1.0, 0.1)

    # Methods the mixins call on the host class.
    _validate_args: Callable[[jax.Array, jax.Array, bool], None]
    _objective: Callable[..., jax.Array]


class SolveNormalizerMixin(TypesMixin):
    """Mixin implementing a generic numerical solve for the normalizer.

    Uses `solver`, `bounds_slack` and `_validate_args` of the host class.
    """

    @staticmethod
    def transform(x):
        """Map a value from the true search-space to the log search-space."""
        return jnp.log(x)

    def get_search_bounds(
            self,
            q: jax.Array,
            log_pi: jax.Array,
            *,
            inv_beta: jax.Array | None = None,
            epsilon: jax.Array | None = None
    ) -> tuple[jax.Array, jax.Array]:
        """Get informative bounds in the true search-space."""
        ...

    def solve_normalizer(
            self,
            q: jax.Array,
            log_pi: jax.Array,
            *,
            inv_beta: jax.Array | None = None,
            epsilon: jax.Array | None = None
    ) -> jax.Array:
        """Run the solver for the normalizer inside slack-widened bounds.

        The bounds from `get_search_bounds` are mapped through `transform`
        (log by default) and widened by `bounds_slack`; searching in the
        transformed space helps stability, especially at the boundaries.

        Exactly one of `inv_beta` and `epsilon` must be given. The result
        is returned in the *transformed* space.
        """
        self._validate_args(inv_beta, epsilon, True)

        # Informative bounds in the true space speed up the search.
        true_low, true_high = self.get_search_bounds(
            q, log_pi, inv_beta=inv_beta, epsilon=epsilon
        )
        search_low = self.transform(true_low) - self.bounds_slack[0]
        search_high = self.transform(true_high) + self.bounds_slack[1]

        solver_kwargs = dict(
            log_pi=log_pi, q=q, inv_beta=inv_beta, epsilon=epsilon
        )
        return self.solver((search_low, search_high), kwargs=solver_kwargs)


class SolveNormalizerAndTrustRegionMixin(TypesMixin):
    """Mixin that solves for the normalizer and the trust-region multiplier.

    An outer search over `log_inv_beta` (`inv_beta_solver`) wraps the host
    class's inner `solver` for the normalizer.
    """
    # Split up compute-budget for normalization and the trust-region.
    num_init_tr: int
    recursive_steps_tr: int

    def __post_init__(self):
        super().__post_init__()  # type: ignore

        # Compose two searches: the outer one runs over log_inv_beta and its
        # objective (`_tr_objective`) runs the inner normalizer solve. This
        # replaces an earlier inline-lambda version of the same composition.
        self.inv_beta_solver = bisection_method(
            self._tr_objective,
            n=self.num_init_tr,
            max_steps=self.recursive_steps_tr,
            step_size=0.1,
        )

    def _tr_objective(self, x: jax.Array, kwargs: dict[str, Any]):
        """Objective of the outer search, evaluated at `log_inv_beta = x`.

        `kwargs['bounds']` holds the bounds for the inner normalizer solve;
        all remaining entries are forwarded to the solver and `_objective`.
        """
        eta_bounds = kwargs['bounds']
        forwarded = {
            name: value for name, value in kwargs.items() if name != 'bounds'
        }

        # For a given log_inv_beta = 'x', normalize the solution
        f_eta = self.solver(
            eta_bounds, kwargs={'log_inv_beta': x, **forwarded}
        )

        # Then return component 1 of the objective of that solution for 'x'
        components = self._objective(f_eta, log_inv_beta=x, **forwarded)
        return components[1]

    @staticmethod
    def transform(x):
        """Map a value from the true search-space to the log search-space."""
        return jnp.log(x)

    def get_search_bounds(
            self,
            q: jax.Array,
            log_pi: jax.Array,
            *,
            inv_beta: jax.Array | None = None,
            epsilon: jax.Array | None = None
    ) -> tuple[jax.Array, jax.Array]:
        """Get informative bounds in the true search-space."""
        ...

    def solve(
            self,
            q: jax.Array,
            log_pi: jax.Array,
            *,
            inv_beta: jax.Array | None = None,
            epsilon: jax.Array | None = None
    ):
        """Solve for the normalizer and, if needed, for `log_inv_beta`.

        Exactly one of `inv_beta` and `epsilon` must be given. When
        `inv_beta` is given its log is used directly; otherwise
        `log_inv_beta` is searched for within `self.bounds`.

        Returns:
            `(search_eta_star, log_inv_beta)`; the normalizer is the raw
            solver output, i.e. in the transformed search-space.
        """
        self._validate_args(inv_beta, epsilon, True)

        def eta_search_bounds(**constraint):
            # True-space bounds -> transformed space, widened by the slack.
            low, high = self.get_search_bounds(q, log_pi, **constraint)
            return (self.transform(low) - self.bounds_slack[0],
                    self.transform(high) + self.bounds_slack[1])

        if inv_beta is not None:
            log_inv_beta = jnp.log(inv_beta)
        else:
            eta_bounds = eta_search_bounds(epsilon=epsilon)
            log_inv_beta = self.inv_beta_solver(
                self.bounds,  # log_inv_beta bounds
                kwargs=dict(
                    bounds=eta_bounds,  # eta bounds
                    q=q, log_pi=log_pi, epsilon=epsilon
                )
            )

        # Find the normalizer using tight bounds. With log_inv_beta fixed
        # this is a 1D optimization.
        tight_bounds = eta_search_bounds(inv_beta=jnp.exp(log_inv_beta))
        search_eta_star = self.solver(
            tight_bounds,
            kwargs=dict(q=q, log_pi=log_pi, log_inv_beta=log_inv_beta)
        )
        return search_eta_star, log_inv_beta


class SandwichTrustRegionMixin(TypesMixin, abc.ABC):
    """Mixin that guards the trust-region solve with its two extreme cases.

    `__call__` dispatches with `jax.lax.cond` between the interior solution
    (`_trust_region_interior`, supplied by the host class) and the
    boundary solutions: the prior `pi` or the (epsilon-)greedy policy.
    """

    # Core Constants
    logits: bool
    _epsilon_ltol: float
    _epsilon_rtol: float

    # Core Methods
    epsilon_greedy: Callable[[jax.Array, jax.typing.ArrayLike], jax.Array]
    trust_region_upperbound: Callable[[jax.Array, jax.Array], jax.Array]

    # Mixin-functionality
    _greedy_jitter: float = 0.0

    def _trust_region_interior(
            self,
            q: jax.Array,
            pi: jax.Array,
            *,
            epsilon: jax.typing.ArrayLike | None = None,
            inv_beta: jax.typing.ArrayLike | None = None
    ) -> jax.Array:
        """Solution strictly inside the trust region; provided by the host."""
        ...

    def _trust_region_boundary(
            self,
            q: jax.Array,
            pi: jax.Array,
            *,
            use_prior: bool
    ):
        """Get q-star when the trust-region constraint gives an extreme-case

        If epsilon is tiny, then return the prior. If epsilon exceeds the
        divergence between the greedy policy and the prior, return the greedy
        policy.

        With `self.logits` set, the log of the selected distribution is
        returned, so zero probabilities become -inf.
        """
        greedy = self.epsilon_greedy(q, self._greedy_jitter)
        out = jax.lax.select(use_prior, pi, greedy)

        return jnp.log(out) if self.logits else out

    def __call__(
            self,
            q: jax.Array,
            pi: jax.Array,
            *,
            epsilon: jax.typing.ArrayLike | None = None,
            inv_beta: jax.typing.ArrayLike | None = None
    ) -> jax.Array:
        """Compute q-star for either a hard (`epsilon`) or soft (`inv_beta`) constraint.

        Exactly one of `epsilon` and `inv_beta` must be given.

        With `epsilon`: the prior is used if all q-values are (close to)
        equal or `epsilon < _epsilon_ltol`; the greedy policy is used if
        `epsilon` exceeds the upper bound minus `_epsilon_rtol`; otherwise
        the interior solution is used.

        With `inv_beta`: the value is clipped from below at
        `exp(self.bounds[0])`, and the prior is returned if all q-values
        are (close to) equal, else the interior solution.
        """
        self._validate_args(inv_beta, epsilon, True)

        all_same = jnp.isclose(q, q[0]).all()

        if epsilon is not None:
            epsilon_ub = self.trust_region_upperbound(q, pi)

            use_prior = all_same | (epsilon < self._epsilon_ltol)
            use_greedy = epsilon > jnp.clip(epsilon_ub - self._epsilon_rtol, 0)

            return jax.lax.cond(
                ~use_prior & ~use_greedy,
                partial(self._trust_region_interior, epsilon=epsilon),
                partial(self._trust_region_boundary, use_prior=use_prior),
                q, pi
            )

        def prior_fallback(_q: jax.Array, prior: jax.Array) -> jax.Array:
            # Return the prior in the same representation as
            # `_trust_region_boundary` does: log-probabilities when
            # `self.logits` is set, probabilities otherwise.
            return jnp.log(prior) if self.logits else prior

        # We cannot check bounds for inv_beta without knowing the normalizer.
        inv_beta = jnp.clip(inv_beta, jnp.exp(self.bounds[0]))
        return jax.lax.cond(
            ~all_same,
            partial(self._trust_region_interior, inv_beta=inv_beta),
            prior_fallback,
            q, pi
        )


In [None]:
# agent.py
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state, checkpoints
import wandb

class Agent:
    """Base agent wrapping a policy/value network and its optimizer state.

    Subclasses must implement `input_shape_fn` and
    `get_state_from_observation`; `normalize_rewards` and
    `reverse_normalize_rewards` are identity hooks that may be overridden.
    """

    def __init__(self, params):
        """Build the network, optimizer and initial train state.

        Reads from `params`: optional `"network"` (defaults to
        `PolicyValueNetwork_2048`), `"num_actions"`, `"num_channels"`,
        `"lr"`, `"obs_spec"` and `"seed"`.
        """
        self.network = params.get("network", PolicyValueNetwork_2048)(num_actions=params["num_actions"], num_channels=params["num_channels"])
        self.optimizer = optax.adam(params['lr'])
        self.input_shape = self.input_shape_fn(params["obs_spec"])

        self.key = jax.random.PRNGKey(params['seed'])

        # Parameters are initialised from a dummy batch of size one.
        self.train_state = train_state.TrainState.create(
            apply_fn=self.network.apply,
            params=self.network.init(self.key, jnp.ones((1, *self.input_shape))),
            tx=self.optimizer
        )

        self.net_apply_fn = jax.jit(self.train_state.apply_fn)
        self.grad_fn = jax.value_and_grad(self.loss_fn)

        # Loss values collected by `loss_fn` since the last `log_losses`.
        self.last_mse_losses = []
        self.last_kl_losses = []

    def input_shape_fn(self, observation_spec):
        """Return the network input shape (without batch axis) for the spec."""
        raise NotImplementedError()

    def get_state_from_observation(self, observation, batched):
        """Convert an environment observation into the network input."""
        raise NotImplementedError()

    def normalize_rewards(self, r):
        """Hook applied to returns before the update; identity by default."""
        return r

    def reverse_normalize_rewards(self, r):
        """Hook applied to predicted values on output; identity by default."""
        return r

    def save(self, path, step):
        """Write the train state as a checkpoint with prefix `agent_`."""
        checkpoints.save_checkpoint(
            target=self.train_state,
            ckpt_dir=path,
            step=step,
            overwrite=True,
            prefix="agent_",
        )

    def loss_fn(self, params, states, actions, returns, episode):
        """Return the KL policy loss plus the L2 value loss.

        `actions` is used as the target distribution of the policy. Both
        loss terms are also appended to `last_kl_losses` /
        `last_mse_losses` as Python numbers; `.item()` needs concrete
        values, so this function is not meant to be wrapped in `jax.jit`.
        `episode` is not used.
        """
        # KL Loss for policy part of the network:
        probs, values = self.net_apply_fn(params, states)
        # optax expects this to be log probabilities
        log_probs = jnp.log(probs + 1e-9)

        targets = actions

        kl_loss = optax.losses.kl_divergence(log_predictions=log_probs, targets=targets)
        kl_loss = jnp.mean(kl_loss)

        # MSE Loss for value part of the network:
        mse_loss = optax.l2_loss(values.flatten(), returns)
        mse_loss = jnp.mean(mse_loss)

        self.last_mse_losses.append(mse_loss.item())
        self.last_kl_losses.append(kl_loss.item())

        return kl_loss + mse_loss

    def update_fn(self, states, actions, returns, episode):
        """Apply one gradient step and return the loss of that step."""
        returns = self.normalize_rewards(returns)
        loss, grads = self.grad_fn(self.train_state.params, states, actions, returns, episode)
        self.train_state = self.train_state.apply_gradients(grads=grads)
        return loss

    def get_output(self, state):
        """Return `(masked, renormalized action probabilities, value)`.

        The probabilities are divided by their masked sum, so they are only
        finite if at least one unmasked action has non-zero probability.
        """
        mask = state.action_mask

        # the state has to be gotten depending on the environment
        state = self.get_state_from_observation(state, True)

        # forward pass of the network
        actions, value = self.net_apply_fn(self.train_state.params, state)
        actions = jnp.ravel(actions)
        value = jnp.ravel(value)[0]

        # mask and renormalize the actions
        masked_actions = self.mask_actions(actions, mask)
        renormalized_actions = masked_actions / jnp.sum(masked_actions)

        value = self.reverse_normalize_rewards(value)

        return renormalized_actions, value

    def log_losses(self, episode, params):
        """Log the mean collected losses to W&B and reset the collections.

        Nothing is logged when `params["logging"]` is falsy or when no
        losses were collected since the last reset (this avoids dividing
        by zero). The collections are reset in every case.
        """
        has_losses = bool(self.last_kl_losses) and bool(self.last_mse_losses)
        if params["logging"] and has_losses:
            wandb.log({
                "kl_loss": sum(self.last_kl_losses) / len(self.last_kl_losses),
                "mse_loss": sum(self.last_mse_losses) / len(self.last_mse_losses),

            }, step=episode*params["num_steps"]*params["num_batches"])
        self.last_kl_losses = []
        self.last_mse_losses = []

    def mask_actions(self, actions, mask):
        """Zero out the entries of `actions` where `mask` is False."""
        return jnp.where(mask, actions, 0)



In [None]:
# solve_norm.py
from __future__ import annotations

from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import typing as jxt



@dataclass
class ExPropKullbackLeibler(
    SandwichTrustRegionMixin, SolveNormalizerMixin, NumericalNormalizerObjective
):
    """Forward Kullback-Leibler divergence, measured from the prior to the model.

    KL(prior || model) = int_X prior(x) log(prior(x) / model(x)) dx

    Minimising this KL w.r.t. the model is equivalent to minimising the
    cross-entropy loss, which is why it appears in empirical density fitting.
    It equals the Variational KL with the model and prior arguments swapped,
    and is known in the literature as Expectation-Propagation.

    The name "forward-KL" is also used for it, albeit ambiguously.

    For parametric model-fitting, this divergence induces moment-matching
    behaviour.
    """

    num_init: int = 16
    recursive_steps: int = 5
    bounds: tuple[float, float] = (-20, 10.0)

    _greedy_jitter: float = 1e-2

    def lagrangian(
        self,
        log_eta: jax.Array,
        q: jax.Array,
        log_pi: jax.Array,
        *,
        inv_beta: jax.Array | None = None,
        epsilon: jax.Array | None = None,
    ) -> tuple:
        """Partial log-Lagrangian for the normalizer eta, as a 1-tuple.

        The returned value is the log of the total mass of the candidate
        solution, so its root (log-mass 0) marks a normalised distribution.
        """
        self._validate_args(inv_beta, epsilon, True)

        # log(eta - q + 1e-32), assembled in log-space from log(eta) and log(-q).
        stacked_terms = jnp.asarray(
            [jnp.broadcast_to(log_eta, q.shape), jnp.log(-q + 1e-32)]
        )
        log_normalizer = jax.nn.logsumexp(stacked_terms, axis=0)

        if epsilon is None:
            log_multiplier = jnp.log(inv_beta)
        else:
            log_multiplier = jnp.sum(jnp.exp(log_pi) * log_normalizer) - epsilon

        candidate_logits = log_pi + log_multiplier - log_normalizer

        # 1 - exp(logits) = 0 --> logits = 0
        return (jax.nn.logsumexp(candidate_logits),)

    def get_search_bounds(
        self,
        q: jax.Array,
        log_pi: jax.Array,
        *,
        inv_beta: jax.Array | None = None,
        epsilon: jax.Array | None = None,
    ) -> tuple[jax.Array, jax.Array]:
        """Lower and upper search bounds for eta.

        The bound for eta depends on inv_beta, and for a given epsilon the
        solution for inv_beta depends on eta in turn. With a hard constraint
        (epsilon given) the domain is therefore widened to the exponentiated
        ends of ``self.bounds``; with a soft constraint ``inv_beta`` is used
        directly.
        """
        if epsilon is None:
            lower_scale = upper_scale = inv_beta
        else:
            lower_scale = jnp.exp(self.bounds[0])
            upper_scale = jnp.exp(self.bounds[-1])

        lower = jnp.max(q + lower_scale * jnp.exp(log_pi))
        upper = q.max() + upper_scale
        return lower, upper

    @staticmethod
    def divergence(q_star: jax.Array, pi: jax.Array) -> jxt.ArrayLike:
        """KL(pi || q_star), with both logarithms clipped from below at -1e3."""
        clipped_log_pi = jnp.clip(jnp.log(pi), -1e3)
        clipped_log_q_star = jnp.clip(jnp.log(q_star), -1e3)
        pointwise = pi * (clipped_log_pi - clipped_log_q_star)
        return pointwise.sum()

    @staticmethod
    def inv_beta(
        q: jax.Array,
        pi: jax.Array,
        eta: jax.Array,
        epsilon: jxt.ArrayLike | None = None,
    ) -> jax.Array:
        """Analytical trust-region multiplier for a given eta and epsilon."""
        log_normalizer = jnp.log(eta - q)
        expected_log_normalizer = jnp.sum(pi * log_normalizer)
        return jnp.exp(expected_log_normalizer - epsilon)

    def _trust_region_interior(
        self,
        q: jax.Array,
        pi: jax.Array,
        *,
        epsilon: jxt.ArrayLike | None = None,
        inv_beta: jxt.ArrayLike | None = None,
    ) -> jax.Array:
        """Solve for eta and return the regularised policy.

        Returns log-space values when ``self.logits`` is truthy, otherwise
        their exponential.
        """
        shifted_q = q - q.max()
        log_pi = jnp.log(jnp.clip(pi, 1e-16))

        log_eta = self.solve_normalizer(
            shifted_q, log_pi, inv_beta=inv_beta, epsilon=epsilon
        )

        # Same log(eta - q) construction as in `lagrangian`.
        stacked_terms = jnp.asarray(
            [
                jnp.broadcast_to(log_eta, shifted_q.shape),
                jnp.log(-shifted_q + 1e-32),
            ]
        )
        log_normalizer = jax.nn.logsumexp(stacked_terms, axis=0)

        if epsilon is None:
            log_multiplier = jnp.log(inv_beta)
        else:
            log_multiplier = jnp.sum(jnp.exp(log_pi) * log_normalizer) - epsilon

        policy_logits = log_pi + log_multiplier - log_normalizer
        if self.logits:
            return policy_logits
        return jnp.exp(policy_logits)

    def trust_region_upperbound(self, q: jax.Array, pi: jax.Array) -> jax.Array:
        """Upper bound on epsilon, using an epsilon-greedy stand-in for greedy.

        The KL from pi to the greedy policy is undefined: the greedy policy
        is the dirac measure on the maximum of q, and the logarithm inside
        the KL grows to infinity for it. An epsilon-greedy policy is used as
        a heuristic instead, and the result is the smaller of its divergence
        and the divergence reached at the smallest supported inv_beta.
        """
        smallest_inv_beta = jnp.exp(self.bounds[0])
        max_supported = self.divergence(self(q, pi, inv_beta=smallest_inv_beta), pi)

        near_greedy = self.epsilon_greedy(q, self._greedy_jitter)
        jittered = self.divergence(near_greedy, pi)

        return jnp.minimum(max_supported, jittered)


@dataclass
class SquaredHellinger(
    SandwichTrustRegionMixin, SolveNormalizerMixin, NumericalNormalizerObjective
):
    """Implements the squared Hellinger distance from prior to model.

    H^2(prior, model) = 1 - int_X sqrt(prior(x) * model(x)) dx

    This objective is symmetric in its arguments and bounded between [0, 1].

    Note: as a result of the divergence being bounded, epsilon and inv_beta
          are also bounded. If epsilon exceeds tuned bounds, we opt for a
          uniform or greedy distribution to prevent numerical problems in
          the solver. For inv_beta we cannot check the bounds without knowing
          the normalizer, so we clip it to the epsilon-bound in the solver.
    """

    num_init: int = 16
    recursive_steps: int = 5
    bounds: tuple[float, float] = (-20, 10.0)

    _epsilon_rtol: float = 0.01

    def lagrangian(
        self,
        log_eta: jax.Array,
        q: jax.Array,
        log_pi: jax.Array,
        *,
        inv_beta: jax.Array | None = None,
        epsilon: jax.Array | None = None,
    ) -> tuple:
        """Computes the partial log-Lagrangian for the normalizer eta.

        Args:
            log_eta: candidate normalizer, in log-space.
            q: values the policy is regularised towards.
            log_pi: log-probabilities of the prior.
            inv_beta: trust-region multiplier (used when epsilon is None).
            epsilon: divergence budget (takes precedence when given).

        Returns:
            A 1-tuple holding the log of the total mass of the candidate
            solution; it is zero when the candidate is normalised.
        """
        self._validate_args(inv_beta, epsilon, True)
        eta = jnp.exp(log_eta)

        # z = 2 * |eta - q|, kept in log-space with a small offset against log(0).
        log_z = jnp.log((2 * jnp.abs(eta - q)) + 1e-32)
        norm = jax.nn.logsumexp(log_pi - log_z)

        if epsilon is not None:
            # inv_beta = (1 - epsilon) / sum(pi / z), expressed in logs.
            log_inv_beta = jnp.log(1 - epsilon) - norm
        else:
            log_inv_beta = jnp.log(jnp.clip(inv_beta, 1e-32))

        # q_star = pi * (inv_beta / z) ** 2
        log_q_star = log_pi + 2 * (log_inv_beta - log_z)

        # 1 - exp(logits) = 0 --> logits = 0
        return (jax.nn.logsumexp(log_q_star),)

    def get_search_bounds(
        self,
        q: jax.Array,
        log_pi: jax.Array,
        *,
        inv_beta: jax.Array | None = None,
        epsilon: jax.Array | None = None,
    ) -> tuple[jax.Array, jax.Array]:
        """Returns the (lower, upper) search bounds for eta."""
        # Bound for eta depends on beta_inv. Given epsilon, the solution to
        # beta_inv depends on eta. So unless the constraint is soft, we
        # need to expand the search-domain for the hard-constraint program.
        if epsilon is not None:
            return (
                jnp.max(q + jnp.exp(self.bounds[0]) * jnp.exp(0.5 * log_pi)),
                q.max() + jnp.exp(self.bounds[-1]),
            )

        return (jnp.max(q + inv_beta * jnp.exp(0.5 * log_pi)), q.max() + inv_beta)

    @staticmethod
    def inv_beta(
        q: jax.Array,
        pi: jax.Array,
        eta: jax.Array,
        epsilon: jxt.ArrayLike | None = None,
    ) -> jax.Array:
        """Analytical trust-region multiplier for a given eta and epsilon.

        ``epsilon`` defaults to None in the signature, but the expression
        below subtracts it from 1, so a value has to be supplied.
        """
        # Get analytical solution to the trust-region multiplier.
        z = (2 * jnp.abs(eta - q)) + 1e-8
        norm = jnp.sum(pi / z)

        return (1 - epsilon) / norm

    @staticmethod
    def divergence(q_star: jax.Array, pi: jax.Array) -> jxt.ArrayLike:
        """Squared Hellinger distance: 1 - sum(sqrt(q_star * pi))."""
        return 1 - jnp.sqrt(q_star * pi).sum()

    def _trust_region_interior(
        self,
        q: jax.Array,
        pi: jax.Array,
        *,
        epsilon: jxt.ArrayLike | None = None,
        inv_beta: jxt.ArrayLike | None = None,
    ) -> jax.Array:
        """Solves for eta and returns the regularised policy.

        Returns log-space values when ``self.logits`` is truthy, otherwise
        their exponential.
        """
        q = q - q.max()
        log_pi = jnp.log(jnp.clip(pi, 1e-32))

        log_eta = self.solve_normalizer(q, log_pi, inv_beta=inv_beta, epsilon=epsilon)

        # Same closed form as in `lagrangian`, evaluated at the solved eta.
        log_z = jnp.log((2 * jnp.abs(jnp.exp(log_eta) - q)) + 1e-32)
        norm = jax.nn.logsumexp(log_pi - log_z)

        if epsilon is not None:
            log_inv_beta = jnp.log(1 - epsilon) - norm
        else:
            log_inv_beta = jnp.log(jnp.clip(inv_beta, 1e-32))

        log_q_star = log_pi + 2 * (log_inv_beta - log_z)

        return log_q_star if self.logits else jnp.exp(log_q_star)

    def trust_region_upperbound(self, q: jax.Array, pi: jax.Array) -> jax.Array:
        """Upper bound on epsilon.

        The result is the smaller of the divergence reached at the smallest
        supported inv_beta and the divergence of an epsilon-greedy policy.
        """
        max_supported = self.divergence(
            self(q, pi, inv_beta=jnp.exp(self.bounds[0])), pi
        )
        # This class does not declare `_greedy_jitter` itself, unlike the two
        # KL objectives in this file. Use an inherited value when there is
        # one, else the 1e-2 those objectives use.
        # NOTE(review): confirm whether a mixin is meant to provide this.
        greedy_jitter = getattr(self, "_greedy_jitter", 1e-2)
        jittered = self.divergence(self.epsilon_greedy(q, greedy_jitter), pi)
        return jnp.minimum(max_supported, jittered)


In [None]:
# solve_trust_region.py
from __future__ import annotations
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import typing as jxt



@dataclass
class VariationalKullbackLeibler(
    SandwichTrustRegionMixin, NumericalTrustRegionObjective
):
    """Reverse Kullback-Leibler divergence, measured from the model to the prior.

    KL(model || prior) = int_X model(x) log(model(x) / prior(x)) dx

    This is the KL that falls out uniquely of the evidence lower bound; the
    name "reverse-KL" is also used for it, albeit ambiguously.

    For parametric model-fitting, this divergence induces mode-finding
    behaviour.
    """

    num_init: int = 10
    recursive_steps: int = 5

    bounds: tuple[float, float] = (-20.0, 10.0)

    _greedy_jitter = 0.01

    def lagrangian(
        self,
        min_log_beta: jax.Array,
        q: jax.Array,
        log_pi: jax.Array,
        *,
        epsilon: jax.Array,
    ) -> tuple:
        """Partial log-Lagrangian for inv_beta in log-space, as a 1-tuple.

        The value is log(epsilon) minus the log of the KL between the
        candidate solution and the prior, so it is zero when the candidate
        spends the divergence budget exactly.
        """
        unnormalised = log_pi + q * jnp.exp(-min_log_beta)
        log_partition = jax.nn.logsumexp(unnormalised)
        candidate_log_probs = unnormalised - log_partition

        # log sum_a q*(a) * (log q*(a) - log pi(a)), with the weights passed as `b`.
        log_kl = jax.nn.logsumexp(
            candidate_log_probs, b=candidate_log_probs - log_pi
        )
        return (jnp.log(epsilon) - log_kl,)

    def solve_trust_region(
        self, q: jax.Array, log_pi: jax.Array, epsilon: jxt.ArrayLike
    ) -> jax.Array:
        """Run ``self.solver`` over ``self.bounds``; the result is log(inv_beta)."""
        solver_kwargs = dict(q=q, log_pi=log_pi, epsilon=epsilon)
        return self.solver(self.bounds, kwargs=solver_kwargs)

    @staticmethod
    def divergence(q_star: jax.Array, pi: jax.Array) -> jxt.ArrayLike:
        """KL(q_star || pi), with both logarithms clipped from below at -1e3."""
        clipped_log_q_star = jnp.clip(jnp.log(q_star), -1e3)
        clipped_log_pi = jnp.clip(jnp.log(pi), -1e3)
        pointwise = q_star * (clipped_log_q_star - clipped_log_pi)
        return pointwise.sum()

    def _trust_region_interior(
        self,
        q: jax.Array,
        pi: jax.Array,
        *,
        epsilon: jxt.ArrayLike | None = None,
        inv_beta: jxt.ArrayLike | None = None,
    ) -> jax.Array:
        """Return the regularised policy for the given epsilon or inv_beta.

        When ``epsilon`` is given, inv_beta is solved for numerically and
        replaces the ``inv_beta`` argument. The result is the unnormalised
        logits when ``self.logits`` is truthy, otherwise their softmax.
        """
        shifted_q = q - q.max()
        log_pi = jnp.log(jnp.clip(pi, 1e-16))

        if epsilon is not None:
            solved_log_inv_beta = self.solve_trust_region(
                shifted_q, log_pi, epsilon=epsilon
            )
            inv_beta = jnp.exp(solved_log_inv_beta)

        policy_logits = log_pi + shifted_q / jnp.clip(inv_beta, 1e-16)
        return policy_logits if self.logits else jax.nn.softmax(policy_logits)

    def trust_region_upperbound(self, q: jax.Array, pi: jax.Array) -> jax.Array:
        """Upper bound on epsilon.

        The result is the smaller of the divergence reached at the smallest
        supported inv_beta and the divergence of an epsilon-greedy policy.
        """
        smallest_inv_beta = jnp.exp(self.bounds[0])
        max_supported = self.divergence(self(q, pi, inv_beta=smallest_inv_beta), pi)

        near_greedy = self.epsilon_greedy(q, self._greedy_jitter)
        jittered = self.divergence(near_greedy, pi)

        return jnp.minimum(max_supported, jittered)


In [None]:
# agent_2048.py
import jax.numpy as jnp

class Agent2048(Agent):
    """Agent for the 2048 environment, built on ``PolicyValueNetwork_2048``."""

    # Fixed divisor used to scale returns for the value head. A log2-based
    # scaling (log2(r + 1) / 16) was tried earlier and is currently disabled.
    _REWARD_SCALE = 50

    def __init__(self, params):
        params["network"] = PolicyValueNetwork_2048
        super().__init__(params)

    def normalize_rewards(self, r):
        """Scale raw returns down by the fixed reward scale."""
        return r / self._REWARD_SCALE

    def reverse_normalize_rewards(self, r):
        """Inverse of ``normalize_rewards``."""
        return r * self._REWARD_SCALE

    def input_shape_fn(self, observation_spec):
        """Network input shape: the shape of the board spec."""
        return observation_spec.board.shape

    def get_state_from_observation(self, observation, batched=True):
        """Return the board, with a leading batch axis added to a 2-D board."""
        board = observation.board
        needs_batch_axis = batched and len(board.shape) == 2
        if needs_batch_axis:
            return board[None, ...]
        return board

In [None]:
# agent_knapsack.py
import jax.numpy as jnp
import haiku.experimental.flax as hkflax
import haiku as hk
import jax
from flax.training import train_state
import optax

class AgentKnapsack(Agent):
    """Agent for the Knapsack environment, using a Haiku network wrapped for Flax."""

    def __init__(self, params):
        # Builds its own network and train state; the base initialiser is
        # not called here.
        self.network = hkflax.Module(hk.transform(forward_fn))
        self.optimizer = optax.adam(params['lr'])
        self.input_shape = self.input_shape_fn(params["obs_spec"])

        self.key = jax.random.PRNGKey(params['seed'])

        # A batch of one all-ones input is enough to initialise the parameters.
        dummy_batch = jnp.ones((1, *self.input_shape))
        initial_params = self.network.init(self.key, dummy_batch)
        self.train_state = train_state.TrainState.create(
            apply_fn=self.network.apply,
            params=initial_params,
            tx=self.optimizer,
        )

        self.net_apply_fn = jax.jit(self.train_state.apply_fn)
        self.grad_fn = jax.value_and_grad(self.loss_fn)

        # Loss histories, one list per loss term.
        self.last_mse_losses = []
        self.last_kl_losses = []

    def input_shape_fn(self, observation_spec):
        """Network input shape: four feature rows over the shape of ``weights``."""
        return (4, *observation_spec.weights.shape)

    def get_state_from_observation(self, observation, batched=True):
        """Stack the four observation fields into one array.

        ``weights``, ``values``, ``packed_items`` and ``action_mask`` become
        rows on the second-to-last axis; a 2-D result gets a leading batch
        axis when ``batched`` is set.
        """
        feature_rows = [
            observation.weights,
            observation.values,
            observation.packed_items,
            observation.action_mask,
        ]
        state = jnp.stack(feature_rows, axis=-2)
        assert state.shape[-2:] == (4, observation.weights.shape[-1])
        if batched and len(state.shape) == 2:
            return state[None, ...]
        return state

In [None]:
# agent_maze.py
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

class AgentMaze(Agent):
    """Agent for the Maze environment with separate policy and value CNNs."""

    def __init__(self, params, env):
        # TODO: these two params entries are probably not needed; check
        # whether they can be removed.
        params["policy_network"] = CNNPolicyNetwork
        params["value_network"] = CNNValueNetwork

        self.key = jax.random.PRNGKey(params['seed'])
        self.env = env
        # The freshly reset environment state serves as the template from
        # which the network input shape is derived below.
        initial_state, self.timestep = jax.jit(env.reset)(self.key)
        self._observation_spec = initial_state
        self._action_spec = env.action_spec

        num_channels = params['num_channels']
        self.policy_network = params["policy_network"](
            num_actions=self._action_spec.num_values, num_channels=num_channels
        )
        self.value_network = params["value_network"](num_channels=params['num_channels'])
        self.policy_optimizer = optax.adam(params['lr'])
        self.value_optimizer = optax.adam(params['lr'])

        self.input_shape = self.input_shape_fn(self._observation_spec).shape
        print(self.input_shape)

        # One independent key per network initialisation.
        policy_key, value_key = jax.random.split(self.key)

        policy_params = self.policy_network.init(
            policy_key, jnp.ones((1, *self.input_shape))
        )
        self.policy_train_state = train_state.TrainState.create(
            apply_fn=self.policy_network.apply,
            params=policy_params,
            tx=self.policy_optimizer,
        )

        value_params = self.value_network.init(
            value_key, jnp.ones((1, *self.input_shape))
        )
        self.value_train_state = train_state.TrainState.create(
            apply_fn=self.value_network.apply,
            params=value_params,
            tx=self.value_optimizer,
        )

        self.policy_apply_fn = jax.jit(self.policy_train_state.apply_fn)
        self.value_apply_fn = jax.jit(self.value_train_state.apply_fn)

        self.policy_grad_fn = jax.value_and_grad(self.compute_policy_loss)
        self.value_grad_fn = jax.value_and_grad(self.compute_value_loss)

    def input_shape_fn(self, observation_spec):
        """Return the processed observation; callers read its ``.shape``."""
        return self.process_observation(observation_spec)

    def get_state_from_observation(self, observation, batched=True):
        """Processed observation, with a leading batch axis added to a 3-D result."""
        processed = self.process_observation(observation)
        if batched and len(processed.shape) == 3:
            return processed[None, ...]
        return processed

    def process_observation(self, observation):
        """Add the agent and the target to the walls array."""
        agent_marker = 2
        target_marker = 3
        grid = observation.walls.astype(float)
        grid = (
            grid
            .at[tuple(observation.agent_position)].set(agent_marker)
            .at[tuple(observation.target_position)].set(target_marker)
        )
        # Adding a channels axis.
        return jnp.expand_dims(grid, axis=-1)

In [None]:
# agent_snake.py
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

class AgentSnake(Agent):
    """Agent for the Snake environment, built on ``PolicyValueNetwork``."""

    def __init__(self, params):
        params["network"] = PolicyValueNetwork
        super().__init__(params)

    def input_shape_fn(self, observation_spec):
        """Network input shape: the shape of the grid spec."""
        return observation_spec.grid.shape

    def get_state_from_observation(self, observation, batched=True):
        """Return the grid, with a leading batch axis added to a 3-D grid."""
        grid = observation.grid
        needs_batch_axis = batched and len(grid.shape) == 3
        if needs_batch_axis:
            return grid[None, ...]
        return grid

In [None]:
# action_selection.py
import chex
import jax
import mctx



def custom_action_selection(
        rng_key: chex.PRNGKey,
        tree: mctx.Tree,
        node_index: chex.Numeric,
        depth: chex.Numeric,
        *,
        pb_c_init: float = 1.25,
        pb_c_base: float = 19652.0,
        qtransform=mctx.qtransform_by_parent_and_siblings,
        selector=VariationalKullbackLeibler()
) -> chex.Array:
    """Samples an action for a node from a divergence-regularised policy.

    The node's prior probabilities and its transformed Q-values are combined
    by `selector`, and an action is sampled from the result.

    Args:
      rng_key: random number generator state.
      tree: _unbatched_ MCTS tree state.
      node_index: scalar index of the node from which to select an action.
      depth: the scalar depth of the current node. Not used by this rule;
        kept so the function matches the action-selection signature.
      pb_c_init: unused; kept for signature compatibility with PUCT selection.
      pb_c_base: unused; kept for signature compatibility with PUCT selection.
      qtransform: called as `qtransform(tree, node_index)` to obtain the
        value score of every action.
      selector: regularised-policy objective, called as
        `selector(q, pi, inv_beta=1.0)` with `q` the value scores and `pi`
        the prior probabilities.

    Returns:
      action: the action selected from the given node.
    """
    prior_logits = tree.children_prior_logits[node_index]
    prior_probs = jax.nn.softmax(prior_logits)
    value_score = qtransform(tree, node_index)

    # The objectives in this file take (q, pi): values first, then the prior
    # as probabilities (they compute log(clip(pi)) internally).
    dist = selector(value_score, prior_probs, inv_beta=1.0)

    # NOTE(review): `jax.random.categorical` interprets `dist` as logits, so
    # the selector should be configured to return logits -- confirm.
    return jax.random.categorical(rng_key, dist)


In [None]:
# tree_policies.py
import functools
from typing import Optional, Tuple

import chex
import jax
import jax.numpy as jnp
import mctx
from mctx import search

def muzero_custom_policy(
    params,
    selector,
    rng_key: chex.PRNGKey,
    root: mctx.RootFnOutput,
    recurrent_fn: mctx.RecurrentFn,
    num_simulations: int,
    invalid_actions: Optional[chex.Array] = None,
    max_depth: Optional[int] = None,
    loop_fn = jax.lax.fori_loop,
    *,
    qtransform = mctx.qtransform_by_parent_and_siblings,
    dirichlet_fraction: chex.Numeric = 0.25,
    dirichlet_alpha: chex.Numeric = 0.3,
    pb_c_init: chex.Numeric = 1.25,
    pb_c_base: chex.Numeric = 19652,
    temperature: chex.Numeric = 1.0) -> mctx.PolicyOutput[None]:
  """Runs a MuZero-style search with a custom action selector.

  The search loop is the MuZero one, but actions inside the tree (root
  included) are chosen by `custom_action_selection` using `selector`.

  In the shape descriptions, `B` denotes the batch dimension.

  Args:
    params: params to be forwarded to root and recurrent functions.
    selector: objective forwarded to `custom_action_selection`.
    rng_key: random number generator state, the key is consumed.
    root: a `(prior_logits, value, embedding)` `RootFnOutput`. The
      `prior_logits` are from a policy network. The shapes are
      `([B, num_actions], [B], [B, ...])`, respectively.
    recurrent_fn: a callable to be called on the leaf nodes and unvisited
      actions retrieved by the simulation step, which takes as args
      `(params, rng_key, action, embedding)` and returns a `RecurrentFnOutput`
      and the new state embedding. The `rng_key` argument is consumed.
    num_simulations: the number of simulations.
    invalid_actions: a mask with invalid actions. Invalid actions
      have ones, valid actions have zeros in the mask. Shape `[B, num_actions]`.
    max_depth: maximum search tree depth allowed during simulation.
    loop_fn: Function used to run the simulations. It may be required to pass
      hk.fori_loop if using this function inside a Haiku module.
    qtransform: function to obtain completed Q-values for a node.
    dirichlet_fraction: float from 0 to 1 interpolating between using only the
      prior policy or just the Dirichlet noise.
    dirichlet_alpha: concentration parameter to parametrize the Dirichlet
      distribution.
    pb_c_init: accepted for signature compatibility; not forwarded to the
      action selection.
    pb_c_base: accepted for signature compatibility; not forwarded to the
      action selection.
    temperature: temperature for acting proportionally to
      `visit_counts**(1 / temperature)`.

  Returns:
    `PolicyOutput` containing the proposed action, action_weights and the used
    search tree.
  """
  action_rng_key, noise_rng_key, tree_rng_key = jax.random.split(rng_key, 3)

  # Root prior: mix in Dirichlet noise, go back to logits, mask invalid actions.
  root_probs = jax.nn.softmax(root.prior_logits)
  noisy_probs = _add_dirichlet_noise(
      noise_rng_key,
      root_probs,
      dirichlet_fraction=dirichlet_fraction,
      dirichlet_alpha=dirichlet_alpha)
  masked_logits = _mask_invalid_actions(
      _get_logits_from_probs(noisy_probs), invalid_actions)
  root = root.replace(prior_logits=masked_logits)

  # The same selection rule is used everywhere; the root pins depth to 0.
  select_interior = functools.partial(
      custom_action_selection,
      selector=selector,
      qtransform=qtransform)
  select_root = functools.partial(select_interior, depth=0)

  search_tree = search(
      params=params,
      rng_key=tree_rng_key,
      root=root,
      recurrent_fn=recurrent_fn,
      root_action_selection_fn=select_root,
      interior_action_selection_fn=select_interior,
      num_simulations=num_simulations,
      max_depth=max_depth,
      invalid_actions=invalid_actions,
      loop_fn=loop_fn)

  # Sampling the proposed action proportionally to the visit counts.
  action_weights = search_tree.summary().visit_probs
  visit_logits = _get_logits_from_probs(action_weights)
  tempered_logits = _apply_temperature(visit_logits, temperature)
  action = jax.random.categorical(action_rng_key, tempered_logits)

  return mctx.PolicyOutput(
      action=action,
      action_weights=action_weights,
      search_tree=search_tree)
  
def _mask_invalid_actions(logits, invalid_actions):
  """Returns logits with zero mass to invalid actions."""
  if invalid_actions is None:
    return logits
  chex.assert_equal_shape([logits, invalid_actions])
  logits = logits - jnp.max(logits, axis=-1, keepdims=True)
  # At the end of an episode, all actions can be invalid. A softmax would then
  # produce NaNs, if using -inf for the logits. We avoid the NaNs by using
  # a finite `min_logit` for the invalid actions.
  min_logit = jnp.finfo(logits.dtype).min
  return jnp.where(invalid_actions, min_logit, logits)


def _get_logits_from_probs(probs):
  tiny = jnp.finfo(probs).tiny
  return jnp.log(jnp.maximum(probs, tiny))


def _add_dirichlet_noise(rng_key, probs, *, dirichlet_alpha,
                         dirichlet_fraction):
  """Mixes the probs with Dirichlet noise.

  `probs` must be rank 2. One noise vector is drawn per row from a symmetric
  Dirichlet with concentration `dirichlet_alpha`, and blended in with weight
  `dirichlet_fraction`.
  """
  chex.assert_rank(probs, 2)
  chex.assert_type([dirichlet_alpha, dirichlet_fraction], float)

  batch_size, num_actions = probs.shape
  concentration = jnp.full([num_actions], fill_value=dirichlet_alpha)
  noise = jax.random.dirichlet(
      rng_key, alpha=concentration, shape=(batch_size,))
  return (1 - dirichlet_fraction) * probs + dirichlet_fraction * noise


def _apply_temperature(logits, temperature):
  """Returns `logits / temperature`, supporting also temperature=0."""
  # The max subtraction prevents +inf after dividing by a small temperature.
  logits = logits - jnp.max(logits, keepdims=True, axis=-1)
  tiny = jnp.finfo(logits.dtype).tiny
  return logits / jnp.maximum(tiny, temperature)


In [None]:
# Main.py
import functools
from functools import partial
from typing import Optional

import flashbax as fbx
import jax
import jax.numpy as jnp
from jax import random
from flax.training import checkpoints

import logging
import jumanji
import mctx
from jumanji.environments.routing.maze import generator
from jumanji.types import StepType
from jumanji.wrappers import AutoResetWrapper

import wandb

# Run configuration, shared by every function in this cell.
# Environments: Snake-v1, Knapsack-v1, Game2048-v1, Maze-v0
params = dict(
    # Environment and agent.
    env_name="Knapsack-v1",
    maze_size=(5, 5),  # only read when env_name is "Maze-v0"
    policy="KL_ex_prop",  # key into policy_dict
    agent=AgentKnapsack,
    num_channels=32,
    seed=42,
    lr=2e-4,  # 0.00003
    # Data collection.
    num_episodes=4000,
    num_steps=200,  # environment steps per rollout
    num_actions=4,  # replaced by the environment's action count at start-up
    obs_spec=Optional,  # placeholder, replaced by the environment's spec at start-up
    # Replay buffer.
    buffer_max_length=20000,
    buffer_min_length=256,
    num_batches=64,  # rollouts gathered in parallel
    sample_size=256,
    # Tree search.
    num_simulations=8,  # 16,
    max_tree_depth=4,  # 12,
    discount=1,
    # Logging and checkpoints.
    logging=True,
    run_in_kaggle=True,
    checkpoint_dir=r'/kaggle/working',
    checkpoint_interval=5,
)

# Search policies selectable through params["policy"]. Every entry except
# "default" is the custom MuZero policy bound to one regularisation objective.
policy_dict = {
    "default": mctx.muzero_policy,
    **{
        policy_name: functools.partial(
            muzero_custom_policy, selector=objective_cls()
        )
        for policy_name, objective_cls in (
            ("KL_variational", VariationalKullbackLeibler),
            ("KL_ex_prop", ExPropKullbackLeibler),
            ("squared_hellinger", SquaredHellinger),
        )
    },
}


class Timestep:
    """Minimal record pairing a step type with its reward.

    TODO Consider renaming to avoid confusion with the environment timestep.

    Attributes:
        step_type: The type of the step (e.g., LAST).
        reward: The reward received at this timestep.
    """

    def __init__(self, step_type, reward):
        self.step_type, self.reward = step_type, reward


@jax.jit
def env_step(state, action):
    """Advance the module-level `env` by one step (jitted).

    Returns the `(next_state, next_timestep)` pair produced by `env.step`.
    """
    stepped_state, stepped_timestep = env.step(state, action)
    return (stepped_state, stepped_timestep)


def ep_loss_reward(timestep):
    """Reward transformation: -10 on a LAST step, the raw reward otherwise."""
    is_last_step = timestep.step_type == StepType.LAST
    return jnp.where(is_last_step, -10, timestep.reward)


def recurrent_fn(agent: Agent, rng_key, action, embedding):
    """One simulation step in MCTS.

    Args:
        agent: agent whose `get_output` evaluates the new observation.
        rng_key: unused.
        action: action to simulate from the embedded state.
        embedding: `(state, timestep)` pair of the parent node.

    Returns:
        `(RecurrentFnOutput, (new_state, new_timestep))`. The output carries
        the reward of the simulated transition, the configured discount, and
        the prior logits and value of the new state.
    """
    del rng_key

    # Only the environment state of the parent is needed; its timestep
    # describes the transition *into* the parent, not the one simulated here.
    state, _ = embedding
    new_state, new_timestep = env_step(state, action)

    # `get_output` returns a masked, renormalized probability vector, while
    # the search expects logits: convert, flooring zeros at the dtype's tiny.
    prior_probs, value = agent.get_output(new_timestep.observation)
    tiny = jnp.finfo(prior_probs.dtype).tiny
    prior_logits = jnp.log(jnp.maximum(prior_probs, tiny))

    # return the recurrent function output
    recurrent_fn_output = mctx.RecurrentFnOutput(
        reward=new_timestep.reward,
        discount=params["discount"],
        prior_logits=prior_logits,
        value=value,
    )
    return recurrent_fn_output, (new_state, new_timestep)


def get_actions(agent, state, timestep, subkey):
    """Run the configured search policy from the current state.

    Args:
        agent: agent evaluated at the root; also passed to the policy as its
            `params`, so it reaches `recurrent_fn`.
        state: current environment state.
        timestep: current environment timestep.
        subkey: random key consumed by the policy.

    Returns:
        The policy output of `policy_dict[params["policy"]]`, computed for a
        batch of size one.
    """

    def root_fn(state, timestep, _):
        """Root function for the MCTS."""
        prior_probs, value = agent.get_output(timestep.observation)

        # `get_output` yields probabilities; the search expects logits.
        # Zeros (masked actions) are floored at the dtype's tiny before the log.
        tiny = jnp.finfo(prior_probs.dtype).tiny
        return mctx.RootFnOutput(
            prior_logits=jnp.log(jnp.maximum(prior_probs, tiny)),
            value=value,
            embedding=(state, timestep),
        )

    policy = policy_dict[params["policy"]]

    # Mapping over a dummy length-one array gives the root a batch axis of 1.
    batched_root = jax.vmap(root_fn, (None, None, 0))(state, timestep, jnp.ones(1))

    policy_output = policy(
        params=agent,
        rng_key=subkey,
        root=batched_root,
        recurrent_fn=jax.vmap(recurrent_fn, (None, None, 0, 0)),
        num_simulations=params["num_simulations"],
        max_depth=params["max_tree_depth"],
        qtransform=partial(
            mctx.qtransform_completed_by_mix_value,
            value_scale=0.1,
            maxvisit_init=50,
            rescale_values=False,
        ),
    )

    return policy_output


def get_rewards(timestep, prev_reward_arr, episode):
    """Collect finished-game totals from a batch of rollouts and report them.

    Args:
        timestep: stacked timesteps; `step_type` and `reward` are indexed by
            batch first, then by step.
        prev_reward_arr: one list per batch holding the rewards of the game
            that was still running when the previous call ended. Lists of
            unfinished games are extended in place.
        episode: current episode number, used for the logging step and the
            printed summary.

    Returns:
        The per-batch lists of rewards of the games still in progress.
    """
    finished_games = []
    carried_over = []

    max_reward = jnp.max(timestep.reward)

    for batch_num, pending in enumerate(prev_reward_arr):
        batch_steps = zip(timestep.step_type[batch_num], timestep.reward[batch_num])

        for step_number, (step_type, ep_rew) in enumerate(batch_steps, start=1):
            if step_type == StepType.LAST:
                # The game ended on this step: total it up and start afresh.
                rew = {
                    "reward": sum(pending) + ep_rew,
                    "max_reward": max_reward,
                }
                pending = []

                finished_games.append(rew)
                if params["logging"]:
                    global_step = (
                        (episode - 1) * params["num_batches"] * params["num_steps"]
                        + batch_num * params["num_steps"]
                        + step_number
                    )
                    wandb.log(rew, step=global_step)
            else:
                pending.append(ep_rew)

        carried_over.append(pending)

    total_reward = sum(game["reward"] for game in finished_games)
    avg_reward = total_reward / max(1, len(finished_games))

    steps = (episode - 1) * params["num_batches"] * params["num_steps"] + params[
        "num_steps"
    ]
    print(
        f"Episode {episode}, Average reward: {str(round(avg_reward, 1))}, Max Reward: {max_reward}, Steps: {steps} / {params['num_episodes']*params['num_batches']*params['num_steps']}"
    )

    return carried_over


def step_fn(agent, state_timestep, subkey):
    """Search for an action, take it, and emit the data recorded for the step.

    Shaped as a scan body: the carry is `(state, timestep)` and the per-step
    output is `(timestep after the action, root action weights, q-value of
    the chosen action)`.
    """
    state, timestep = state_timestep
    policy_output = get_actions(agent, state, timestep, subkey)

    # The search runs on a batch of exactly one.
    assert policy_output.action.shape[0] == 1
    assert policy_output.action_weights.shape[0] == 1

    chosen_action = policy_output.action[0]
    next_state, next_timestep = env_step(state, chosen_action)

    tree = policy_output.search_tree
    # NOTE(review): ROOT_INDEX is used here as the first index into
    # `qvalues` -- confirm that axis is the intended one.
    q_value = tree.summary().qvalues[tree.ROOT_INDEX, chosen_action]

    # NOTE(review): the emitted timestep is the one *after* the action, while
    # the action weights belong to the state before it -- confirm downstream
    # code pairs them as intended.
    step_record = (next_timestep, policy_output.action_weights[0], q_value)
    return (next_state, next_timestep), step_record


def run_n_steps(state, timestep, subkey, agent, n):
    """Roll the environment forward `n` steps with `step_fn` under `lax.scan`.

    Returns `(timesteps, action_weights, q_values, final_state,
    final_timestep)`; the first three are stacked over the steps.
    """
    step_keys = jax.random.split(subkey, n)
    # Bind the agent so the scan body has the (carry, x) signature.
    scan_body = functools.partial(step_fn, agent)
    final_carry, stacked_outputs = jax.lax.scan(
        scan_body, (state, timestep), step_keys
    )
    next_ep_state, next_ep_timestep = final_carry
    cum_timestep, actions, q_values = stacked_outputs
    return cum_timestep, actions, q_values, next_ep_state, next_ep_timestep


def gather_data(state, timestep, subkey):
    """Run `params["num_batches"]` rollouts in parallel with `jax.vmap`.

    Uses the module-level `agent` and `params`. State, timestep and keys are
    mapped over their leading axis; the agent and the step count are shared.
    Returns the five-tuple produced by `run_n_steps`.
    """
    rollout_keys = jax.random.split(subkey, params["num_batches"])
    batched_rollout = jax.vmap(run_n_steps, in_axes=(0, 0, 0, None, None))
    (
        cum_timestep,
        actions,
        q_values,
        next_ep_state,
        next_ep_timestep,
    ) = batched_rollout(state, timestep, rollout_keys, agent, params["num_steps"])

    return cum_timestep, actions, q_values, next_ep_state, next_ep_timestep


def train(agent: Agent, action_weights_arr, q_values_arr, states_arr, episode):
    """Call ``agent.update_fn`` once per sampled item and average the results.

    Args:
        agent: Object exposing ``update_fn(states, actions, q_values, episode)``.
        action_weights_arr: Iterable of action-weight targets.
        q_values_arr: Iterable of q-value targets, zipped with the above.
        states_arr: Iterable of states, zipped with the above.
        episode: Forwarded unchanged to every ``update_fn`` call.

    Returns:
        The mean over axis 0 of the stacked ``update_fn`` return values.
    """
    losses = []
    for actions, q_values, states in zip(
        action_weights_arr, q_values_arr, states_arr
    ):
        losses.append(agent.update_fn(states, actions, q_values, episode))

    stacked_losses = jnp.array(losses)
    return jnp.mean(stacked_losses, axis=0)


if __name__ == "__main__":
    # Entry point: build the environment, agent and replay buffer, then
    # alternate between gathering rollouts and training on buffer samples.

    # Initialize wandb
    if params["logging"]:
        init_wandb(params)

    # Initialize the environment
    if params["env_name"] == "Maze-v0":
        gen = generator.RandomGenerator(*params["maze_size"])
        env = jumanji.make(params["env_name"], generator=gen)
    else:
        env = jumanji.make(params["env_name"])

    print(f"running {params['env_name']}")
    env = AutoResetWrapper(env)

    # Initialize the agent
    params["num_actions"] = env.action_spec.num_values
    params["obs_spec"] = env.observation_spec
    agent = params.get("agent", Agent)(params)

    # Specify buffer parameters
    buffer = fbx.make_flat_buffer(
        max_length=params["buffer_max_length"],
        min_length=params["buffer_min_length"],
        sample_batch_size=params["sample_size"],
        add_batch_size=params["num_batches"],
    )

    # Jit the buffer functions
    buffer = buffer.replace(
        init=jax.jit(buffer.init),
        add=jax.jit(buffer.add, donate_argnums=0),
        sample=jax.jit(buffer.sample),
        can_sample=jax.jit(buffer.can_sample),
    )

    # Specify buffer format. The layout is the same for every environment;
    # only the dtype of the stored states differs.
    if params["env_name"] in ["Snake-v1", "Knapsack-v1"]:
        buffer_state_dtype = jnp.float32
    else:
        buffer_state_dtype = jnp.int32
    fake_timestep = {
        "q_value": jnp.zeros((params["num_steps"])),
        "actions": jnp.zeros(
            (params["num_steps"], params["num_actions"]), dtype=jnp.float32
        ),
        "states": jnp.zeros(
            (params["num_steps"], *agent.input_shape), dtype=buffer_state_dtype
        ),
    }
    buffer_state = buffer.init(fake_timestep)

    # Initialize the random keys. The seed key is split exactly once: `key`
    # is carried through the training loop and `subkey` is consumed only for
    # the environment resets, so no key is ever used twice.
    key, subkey = jax.random.split(jax.random.PRNGKey(params["seed"]))
    keys = jax.random.split(subkey, params["num_batches"])

    # Get the initial state and timestep
    next_ep_state, next_ep_timestep = jax.vmap(env.reset)(keys)

    # Wrap the key split in jit once instead of on every loop iteration.
    jitted_split = jax.jit(jax.random.split)

    prev_reward_arr = [[] for _ in range(params["num_batches"])]
    for episode in range(1, params["num_episodes"] + 1):

        # Get new key every episode
        key, sample_key = jitted_split(key)
        # Gather data
        timestep, actions, q_values, next_ep_state, next_ep_timestep = gather_data(
            next_ep_state, next_ep_timestep, sample_key
        )

        prev_reward_arr = get_rewards(timestep, prev_reward_arr, episode)

        # Get state in the correct format given environment
        states = agent.get_state_from_observation(timestep.observation, True)

        # Add data to buffer
        buffer_state = buffer.add(
            buffer_state,
            {
                "q_value": q_values,
                "actions": actions,
                "states": states,
            },
        )

        # Train only once the buffer reports that it can be sampled.
        if buffer.can_sample(buffer_state):
            key, sample_key = jitted_split(key)
            data = buffer.sample(buffer_state, sample_key).experience.first
            loss = train(
                agent, data["actions"], data["q_value"], data["states"], episode
            )
            agent.log_losses(episode, params)
        else:
            loss = None

        if episode % params["checkpoint_interval"] == 0:
            print(f"Saving checkpoint for episode {episode}")
            agent.save(params["checkpoint_dir"], episode)

    if params["logging"]:
        wandb.finish()
