In [1]:
import os
from copy import deepcopy

import numpy as np
import torch
from agilerl.hpo.mutation import Mutations
from agilerl.hpo.tournament import TournamentSelection
from agilerl.utils.utils import create_population, make_vect_envs
from tqdm import trange

In [2]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np

In [None]:


# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", device)

# Network configuration; only "arch" is read further down (by the Mutations
# object), because the agents are built from custom networks.
NET_CONFIG = {
    "arch": "mlp",  # Network architecture
    "hidden_size": [32, 32],  # Actor hidden size
}

# Initial hyperparameters shared by every member of the population.
INIT_HP = {
    "POP_SIZE": 6,  # Population size
    "DISCRETE_ACTIONS": False,  # Discrete action space
    "BATCH_SIZE": 128,  # Batch size
    "LR": 1e-3,  # Learning rate
    "LEARN_STEP": 128,  # Learning frequency
    "GAMMA": 0.99,  # Discount factor
    "GAE_LAMBDA": 0.95,  # Lambda for general advantage estimation
    "ACTION_STD_INIT": 0.6,  # Initial action standard deviation
    "CLIP_COEF": 0.2,  # Surrogate clipping coefficient
    "ENT_COEF": 0.01,  # Entropy coefficient
    "VF_COEF": 0.5,  # Value function coefficient
    "MAX_GRAD_NORM": 0.5,  # Maximum norm for gradient clipping
    "TARGET_KL": None,  # Target KL divergence threshold
    "UPDATE_EPOCHS": 4,  # Number of policy update epochs
    # Credentials must not live in the notebook: read the Weights & Biases key
    # from the environment (None when the variable is not set).
    "wandb_api_key": os.environ.get("WANDB_API_KEY"),
    "WANDB": True,  # Weights & Biases tracking flag
    # Swap image channels dimension from last to first [H, W, C] -> [C, H, W]
    "CHANNELS_LAST": False,
}

num_envs = 2
env_name = "Pendulum-v1"
env = make_vect_envs(env_name, num_envs=num_envs)  # Create environment

# Alternative: a custom environment wrapped in a gymnasium vector env.
# env = gym.vector.AsyncVectorEnv(
#         [lambda: Simple2DEnv() for i in range(num_envs)]
#     )

# Spaces that expose `.n` are treated as discrete; any other space is described
# by its shape. Only the missing-attribute case is caught, so unrelated errors
# are no longer silently swallowed.
try:
    state_dim = env.single_observation_space.n  # Discrete observation space
    one_hot = True  # Requires one-hot encoding
except AttributeError:
    state_dim = env.single_observation_space.shape  # Continuous observation space
    one_hot = False  # Does not require one-hot encoding
try:
    action_dim = env.single_action_space.n  # Discrete action space
except AttributeError:
    action_dim = env.single_action_space.shape[0]  # Continuous action space

# Image observations stored as [H, W, C] are described to the networks as [C, H, W].
if INIT_HP["CHANNELS_LAST"]:
    state_dim = (state_dim[2], state_dim[0], state_dim[1])



DEVICE: cpu


In [4]:
# --- Evolution / evaluation schedule ---------------------------------------
max_steps = 200000  # Max steps
evo_steps = 10000  # Evolution frequency
eval_steps = None  # Evaluation steps per episode - go until done
eval_loop = 1  # Number of evaluation episodes

# Global counter of environment steps taken across the whole population.
total_steps = 0

# --- Tournament selection ---------------------------------------------------
tournament = TournamentSelection(
    tournament_size=2,  # Tournament selection size
    elitism=True,  # Elitism in tournament selection
    population_size=INIT_HP["POP_SIZE"],  # Population size
    eval_loop=eval_loop,  # Evaluate using last N fitness scores
)

In [5]:
class MyMutations(Mutations):
    """Mutations subclass that rebuilds the optimizer of a three-network PPO agent.

    Only single-agent individuals whose ``algo`` attribute equals ``"PPO"`` are
    supported; everything else raises ``NotImplementedError``.
    """

    def reinit_opt(self, individual):
        """Replace the individual's optimizer with a fresh one of the same class.

        The new optimizer gets three parameter groups, all using
        ``individual.lr``: the network named by ``self.algo["actor"]["eval"]``
        and the networks named by the first two entries of
        ``self.algo["critics"]``. It is stored under the attribute named by
        ``self.algo["actor"]["optimizer"]``.

        :param individual: Agent whose optimizer is re-created in place.
        :raises NotImplementedError: For multi-agent setups or non-PPO agents.
        """
        # The unsupported cases are rejected up front; the code that used to
        # follow these raises was unreachable and has been removed.
        if self.multi_agent:
            raise NotImplementedError(
                f"Mutations is not implemented for {individual.algo}"
            )
        if individual.algo not in ["PPO"]:
            raise NotImplementedError(
                f"Mutations 2 is not implemented for {individual.algo}"
            )

        optimizer_attr = self.algo["actor"]["optimizer"]
        # Only the class of the current optimizer is reused; its state is dropped.
        current_opt = getattr(individual, optimizer_attr)

        network_attrs = (
            self.algo["actor"]["eval"],
            self.algo["critics"][0]["eval"],
            self.algo["critics"][1]["eval"],
        )
        param_groups = [
            {"params": getattr(individual, attr).parameters(), "lr": individual.lr}
            for attr in network_attrs
        ]
        setattr(individual, optimizer_attr, type(current_opt)(param_groups))

In [6]:
# Names of the agent attributes the mutation machinery operates on.
# SigmaPPO stores its third network as `vol_net` and drives all three networks
# with the single optimizer kept in `optimizer`, so every entry must point at
# attributes that actually exist on the agent (the previous "actor_var",
# "critic_optimizer" and "actor_var_optimizer" names did not).
nets = {
    "actor": {"eval": "actor", "optimizer": "optimizer"},
    "critics": [
        {"eval": "critic", "optimizer": "optimizer"},
        {"eval": "vol_net", "optimizer": "optimizer"},
    ],
}

mutations = MyMutations(
    algo="PPO",  # Algorithm
    no_mutation=0.4,  # No mutation
    architecture=0.2,  # Architecture mutation
    new_layer_prob=0.2,  # New layer mutation
    parameters=0.2,  # Network parameters mutation
    activation=0,  # Activation layer mutation
    rl_hp=0.2,  # Learning HP mutation
    rl_hp_selection=["lr", "batch_size", "learn_step"],  # RL HPs to choose from
    mutation_sd=0.1,  # Mutation strength
    arch=NET_CONFIG["arch"],  # Network architecture
    rand_seed=1,  # Random seed
    device=device,
)

# Override the network mapping built for algo="PPO" with the three-network one.
mutations.algo = mutations.algo | nets


### Custom Actor and Volatility Net

In [7]:
import torch.nn as nn
import torch

In [8]:
class MLPActor(nn.Module):
    """Actor network: Linear(input_size, 64) -> ReLU -> Linear(64, output_size).

    The output layer has no activation.
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        # Attribute names and creation order are kept so that parameter
        # initialisation and state-dict keys are unchanged.
        self.linear_layer_1 = nn.Linear(input_size, 64)
        self.linear_layer_2 = nn.Linear(64, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a tensor whose last dimension is ``input_size`` to ``output_size`` values."""
        hidden = self.linear_layer_1(x)
        activated = self.relu(hidden)
        return self.linear_layer_2(activated)

# One input per observation feature, one output per action dimension.
actor = MLPActor(input_size=state_dim[0], output_size=action_dim)

In [9]:
class VolatilityNet(nn.Module):
    """Network with a non-negative output.

    Structure: Linear(input_size, 64) -> ReLU -> Linear(64, output_size) -> ReLU.
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        # Attribute names and creation order are kept so that parameter
        # initialisation and state-dict keys are unchanged.
        self.linear_layer_1 = nn.Linear(input_size, 64)
        self.linear_layer_2 = nn.Linear(64, output_size)
        self.relu = nn.ReLU()
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Return ``output_size`` values >= 0 for every input row."""
        hidden = self.relu(self.linear_layer_1(x))
        # The final ReLU clamps the outputs at zero.
        return self.relu2(self.linear_layer_2(hidden))


# One non-negative output per action dimension.
vol_net = VolatilityNet(input_size=state_dim[0], output_size=action_dim)

In [10]:
import copy
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from agilerl.networks.custom_components import GumbelSoftmax, NoisyLinear

In [11]:
from agilerl.wrappers.make_evolvable import MakeEvolvable



# Wrap the custom actor with AgileRL's MakeEvolvable wrapper.
# `input_tensor` is a random example observation of length state_dim[0];
# presumably it is used to trace the architecture - confirm against the AgileRL docs.
evolvable_actor = MakeEvolvable(actor,
                                input_tensor=torch.randn(state_dim[0]),
                                device=device)

In [12]:
# Wrap the volatility net the same way as the actor, again with a random
# example observation of length state_dim[0].
evolvable_vol_net = MakeEvolvable(vol_net,
                                input_tensor=torch.randn(state_dim[0]),
                                device=device)

### Quick Tests

In [13]:
# Reset the vectorised environment to get an initial batch of observations.
state, info = env.reset()

In [14]:
# Sanity check: run the raw (unwrapped) volatility net on a fresh batch of
# observations; the cell displays the network output.
state, info = env.reset()
vol_net(torch.tensor(state))

tensor([[0.3404],
        [0.0314]], grad_fn=<ReluBackward0>)

In [15]:
# Same check through the MakeEvolvable wrapper, reusing `state` from the previous cell.
# NOTE(review): the recorded output differs from the raw net's output for the same
# state - confirm whether the wrapper is expected to preserve the original weights.
evolvable_vol_net.forward(torch.tensor(state))

tensor([[0.0195],
        [0.0000]], grad_fn=<ReluBackward0>)

### Custom Critic

In [16]:
class MLPCritic(nn.Module):
    """Critic network: Linear(input_size, 64) -> ReLU -> Linear(64, output_size).

    The output layer has no activation.
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        # Attribute names and creation order are kept so that parameter
        # initialisation and state-dict keys are unchanged.
        self.linear_layer_1 = nn.Linear(input_size, 64)
        self.linear_layer_2 = nn.Linear(64, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a tensor whose last dimension is ``input_size`` to ``output_size`` values."""
        hidden = self.relu(self.linear_layer_1(x))
        return self.linear_layer_2(hidden)

In [17]:
# The critic outputs a single value per observation.
critic = MLPCritic(input_size=state_dim[0], output_size=1)

# Wrap it with MakeEvolvable, using a random example observation as input.
evolvable_critic = MakeEvolvable(
    critic,
    input_tensor=torch.randn(state_dim[0]),
    device=device,
)

### Custom PPO Class


In [18]:
from agilerl.algorithms.ppo import PPO
from torch.distributions import Categorical, MultivariateNormal


In [19]:
import copy
import inspect

import dill
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical, MultivariateNormal
from torch.nn.utils import clip_grad_norm_

from agilerl.networks.evolvable_cnn import EvolvableCNN
from agilerl.networks.evolvable_mlp import EvolvableMLP
from agilerl.utils.algo_utils import chkpt_attribute_to_device, unwrap_optimizer
from agilerl.wrappers.make_evolvable import MakeEvolvable

In [20]:
class SigmaPPO(PPO):
    """PPO agent whose action-distribution spread comes from a separate network.

    The agent holds three networks: ``actor`` (action means, or probabilities
    for discrete actions), ``critic`` (state values) and ``vol_net``. For
    continuous actions the output of ``vol_net`` is used directly as the
    diagonal of the covariance matrix of a multivariate normal policy.
    All three networks are trained by one Adam optimizer (``self.optimizer``).

    The networks must be supplied by the caller: ``actor_network`` is the actor
    and ``critic_network`` is an indexable pair ``[critic, vol_net]``. Building
    default networks from ``net_config`` is not supported.

    ``PPO.__init__`` is intentionally not called; this constructor sets every
    attribute itself.

    NOTE(review): ``learn``, ``clone``, ``wrap_models`` and checkpointing are
    inherited from ``PPO`` and are not visible here - confirm that they handle
    the additional ``vol_net`` and the three-group optimizer.
    """

    def __init__(
        self,
        state_dim,
        action_dim,
        one_hot,
        discrete_actions,
        max_action=1,
        min_action=-1,
        index=0,
        net_config={"arch": "mlp", "hidden_size": [64, 64]},
        batch_size=64,
        lr=1e-4,
        learn_step=2048,
        gamma=0.99,
        gae_lambda=0.95,
        mut=None,
        action_std_init=0.6,
        clip_coef=0.2,
        ent_coef=0.01,
        vf_coef=0.5,
        max_grad_norm=0.5,
        target_kl=None,
        update_epochs=4,
        actor_network=None,
        critic_network=None,
        device="cpu",
        accelerator=None,
        wrap=True,
    ):
        """Validate the hyperparameters and set up networks and optimizer.

        :param state_dim: Observation dimensions (list or tuple).
        :param action_dim: Number of action dimensions (int).
        :param one_hot: Whether observations need one-hot encoding.
        :param discrete_actions: Whether the action space is discrete.
        :param max_action: Upper action bound (scalar, list or array).
        :param min_action: Lower action bound (scalar, list or array).
        :param index: Agent index within the population.
        :param net_config: Network configuration; replaced by the actor's
            config (EvolvableMLP/EvolvableCNN) or ``None`` (MakeEvolvable).
        :param actor_network: Actor network (required).
        :param critic_network: Pair ``[critic, vol_net]`` (required).
        :param device: Device the networks are moved to when no accelerator is used.
        :param accelerator: Optional accelerator; when given and ``wrap`` is
            true, ``self.wrap_models()`` is called instead of moving the networks.
        :param wrap: Whether to wrap models when an accelerator is given.
        :raises NotImplementedError: If either network argument is ``None``.
        :raises AssertionError: If a hyperparameter fails validation or the
            actor/critic types are unsupported.

        The remaining parameters are stored unchanged as attributes of the
        same name after the type/range checks below.
        """
        assert isinstance(
            state_dim, (list, tuple)
        ), "State dimension must be a list or tuple."
        assert isinstance(
            action_dim, (int, np.integer)
        ), "Action dimension must be an integer."
        assert isinstance(
            one_hot, bool
        ), "One-hot encoding flag must be boolean value True or False."
        assert isinstance(
            discrete_actions, bool
        ), "Discrete actions flag must be boolean value True or False."
        assert isinstance(
            max_action,
            (float, int, np.float32, np.float64, np.integer, list, np.ndarray),
        ), "Max action must be a float or integer."
        assert isinstance(
            min_action,
            (float, int, np.float32, np.float64, np.integer, list, np.ndarray),
        ), "Min action must be a float or integer."
        if isinstance(min_action, list):
            assert (
                len(min_action) == action_dim
            ), "Length of min_action must be equal to action_dim."
            min_action = np.array(min_action)
        if isinstance(max_action, list):
            assert (
                len(max_action) == action_dim
            ), "Length of max_action must be equal to action_dim."
            max_action = np.array(max_action)
        if isinstance(max_action, np.ndarray) or isinstance(min_action, np.ndarray):
            assert np.all(
                max_action > min_action
            ), "Max action must be greater than min action."
        else:
            assert (
                max_action > min_action
            ), "Max action must be greater than min action."
        assert isinstance(index, int), "Agent index must be an integer."
        assert isinstance(batch_size, int), "Batch size must be an integer."
        assert batch_size >= 1, "Batch size must be greater than or equal to one."
        assert isinstance(lr, float), "Learning rate must be a float."
        assert lr > 0, "Learning rate must be greater than zero."
        assert isinstance(gamma, (float, int)), "Gamma must be a float."
        assert isinstance(gae_lambda, (float, int)), "Lambda must be a float."
        assert gae_lambda >= 0, "Lambda must be greater than or equal to zero."
        assert isinstance(
            action_std_init, (float, int)
        ), "Action standard deviation must be a float."
        assert (
            action_std_init >= 0
        ), "Action standard deviation must be greater than or equal to zero."
        assert isinstance(
            clip_coef, (float, int)
        ), "Clipping coefficient must be a float."
        assert (
            clip_coef >= 0
        ), "Clipping coefficient must be greater than or equal to zero."
        assert isinstance(
            ent_coef, (float, int)
        ), "Entropy coefficient must be a float."
        assert (
            ent_coef >= 0
        ), "Entropy coefficient must be greater than or equal to zero."
        assert isinstance(
            vf_coef, (float, int)
        ), "Value function coefficient must be a float."
        assert (
            vf_coef >= 0
        ), "Value function coefficient must be greater than or equal to zero."
        assert isinstance(
            max_grad_norm, (float, int)
        ), "Maximum norm for gradient clipping must be a float."
        assert (
            max_grad_norm >= 0
        ), "Maximum norm for gradient clipping must be greater than or equal to zero."
        assert (
            isinstance(target_kl, (float, int)) or target_kl is None
        ), "Target KL divergence threshold must be a float."
        if target_kl is not None:
            assert (
                target_kl >= 0
            ), "Target KL divergence threshold must be greater than or equal to zero."
        assert isinstance(
            update_epochs, int
        ), "Policy update epochs must be an integer."
        assert (
            update_epochs >= 1
        ), "Policy update epochs must be greater than or equal to one."
        assert isinstance(
            wrap, bool
        ), "Wrap models flag must be boolean value True or False."

        # Mutations and tournament code dispatch on this name, so the agent
        # identifies itself as plain "PPO".
        self.algo = "PPO"
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.one_hot = one_hot
        self.discrete_actions = discrete_actions
        self.max_action = max_action
        self.min_action = min_action
        self.net_config = net_config
        self.batch_size = batch_size
        self.lr = lr
        self.learn_step = learn_step
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.mut = mut
        self.action_std_init = action_std_init
        self.clip_coef = clip_coef
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.target_kl = target_kl
        self.update_epochs = update_epochs
        self.actor_network = actor_network
        self.critic_network = critic_network
        self.device = device
        self.accelerator = accelerator

        self.index = index
        self.scores = []
        self.fitness = []
        self.steps = [0]

        # Default networks cannot be built because there is no default for the
        # volatility net; the unreachable net_config-based construction that
        # used to follow this raise has been removed.
        if self.actor_network is None or self.critic_network is None:
            raise NotImplementedError(
                "SigmaPPO requires both 'actor_network' and 'critic_network' "
                "([critic, vol_net]); building networks from 'net_config' is "
                "not supported."
            )

        self.actor = actor_network
        self.critic = critic_network[0]  # state-value network
        self.vol_net = critic_network[1]  # covariance-diagonal network

        if isinstance(self.actor, (EvolvableMLP, EvolvableCNN)) and isinstance(
            self.critic, (EvolvableMLP, EvolvableCNN)
        ):
            self.net_config = self.actor.net_config
        elif isinstance(self.actor, MakeEvolvable) and isinstance(
            self.critic, MakeEvolvable
        ):
            self.net_config = None
        else:
            assert (
                False
            ), f"'actor_network' argument is of type {type(actor_network)} and 'critic_network' of type {type(critic_network)}, \
                                both must be the same type and be of type EvolvableMLP, EvolvableCNN or MakeEvolvable"

        self.arch = (
            self.net_config["arch"] if self.net_config is not None else self.actor.arch
        )

        # One optimizer, one parameter group per network, all at the same lr.
        self.optimizer = optim.Adam(
            [
                {"params": self.actor.parameters(), "lr": self.lr},
                {"params": self.critic.parameters(), "lr": self.lr},
                {"params": self.vol_net.parameters(), "lr": self.lr},
            ]
        )

        if self.accelerator is not None:
            if wrap:
                self.wrap_models()
        else:
            self.actor = self.actor.to(self.device)
            self.critic = self.critic.to(self.device)
            self.vol_net = self.vol_net.to(self.device)

    def _evaluate_networks(self, state, unsupported_msg):
        """Run actor, volatility net and critic on ``state``.

        :param state: Prepared observation batch.
        :param unsupported_msg: Message of the ``NotImplementedError`` raised
            when the actor returns a tuple (actors that emit their own spread
            are not supported).
        :return: ``(action_values, std_values, state_values)`` where
            ``state_values`` has its last dimension squeezed.
        """
        act_out = self.actor(state)
        if isinstance(act_out, tuple):
            raise NotImplementedError(unsupported_msg)
        action_values = act_out
        std_values = self.vol_net(state)
        state_values = self.critic(state).squeeze(-1)
        return action_values, std_values, state_values

    def get_action(self, state, action=None, grad=False, action_mask=None):
        """Returns the next action to take in the environment.

        :param state: Environment observation, or multiple observations in a batch
        :type state: numpy.ndarray[float]
        :param action: Action in environment to evaluate, defaults to None
        :type action: torch.Tensor(), optional
        :param grad: Calculate gradients on actions, defaults to False
        :type grad: bool, optional
        :param action_mask: Mask of legal actions 1=legal 0=illegal, defaults to None
        :type action_mask: numpy.ndarray, optional
        :return: ``(action, log_prob, entropy, state_values)``. When ``action``
            is ``None`` a new action is sampled and all four items are NumPy
            arrays; when ``action`` is given all four are tensors. Continuous
            actions are passed through ``scale_to_action_space``.
        :raises NotImplementedError: If the actor returns a tuple.
        """
        state = self.prepare_state(state)

        if not grad:
            # Inference: evaluation mode and no autograd graph; training mode
            # is restored afterwards.
            self.actor.eval()
            self.critic.eval()
            self.vol_net.eval()
            with torch.no_grad():
                action_values, std_values, state_values = self._evaluate_networks(
                    state, "Get Action 1"
                )
            self.actor.train()
            self.critic.train()
            self.vol_net.train()
        else:
            action_values, std_values, state_values = self._evaluate_networks(
                state, "Get Action 2"
            )

        if self.discrete_actions:
            if action_mask is not None:
                action_mask = torch.from_numpy(action_mask)
                action_values *= action_mask

            dist = Categorical(action_values)
        else:
            # The volatility-net outputs are the diagonal of the covariance
            # matrix. A small constant keeps the matrix positive definite even
            # when the net outputs exact zeros. Adding it before diag_embed is
            # equivalent to adding 1e-6 * I per sample, but stays on the
            # tensor's own device and needs no Python loop over the batch.
            cov_mat = torch.diag_embed(std_values + 1e-6)
            dist = MultivariateNormal(action_values, cov_mat)

        return_tensors = True
        if action is None:
            action = dist.sample()
            return_tensors = False
        elif self.accelerator is None:
            action = action.to(self.device)
        else:
            action = action.to(self.accelerator.device)

        action_logprob = dist.log_prob(action)
        dist_entropy = dist.entropy()

        if return_tensors:
            return (
                (
                    self.scale_to_action_space(action, convert_to_torch=True)
                    if not self.discrete_actions
                    else action
                ),
                action_logprob,
                dist_entropy,
                state_values,
            )
        else:
            return (
                (
                    self.scale_to_action_space(action.cpu().data.numpy())
                    if not self.discrete_actions
                    else action.cpu().data.numpy()
                ),
                action_logprob.cpu().data.numpy(),
                dist_entropy.cpu().data.numpy(),
                state_values.cpu().data.numpy(),
            )






### Create Population of SigmaPPO

In [21]:

def custom_create_population(
    algo,
    state_dim,
    action_dim,
    one_hot,
    net_config,
    INIT_HP,
    actor_network=None,
    critic_network=None,
    population_size=1,
    num_envs=1,
    device="cpu",
    accelerator=None,
    torch_compiler=None,
):
    """Build a population of SigmaPPO agents with identical hyperparameters.

    Each agent receives its own deep copy of ``actor_network`` and
    ``critic_network``. Previously every agent was handed the very same
    network objects, so the whole population shared (and jointly updated) one
    set of parameters.

    :param algo: Algorithm name; must contain ``"PPO"``.
    :param state_dim: Observation dimensions passed to every agent.
    :param action_dim: Number of action dimensions passed to every agent.
    :param one_hot: One-hot encoding flag passed to every agent.
    :param net_config: Network configuration passed to every agent.
    :param INIT_HP: Dictionary of initial hyperparameters (BATCH_SIZE, LR,
        LEARN_STEP, GAMMA, GAE_LAMBDA, ACTION_STD_INIT, CLIP_COEF, ENT_COEF,
        VF_COEF, MAX_GRAD_NORM, TARGET_KL, UPDATE_EPOCHS).
    :param actor_network: Actor network template, copied per agent.
    :param critic_network: ``[critic, vol_net]`` template, copied per agent.
    :param population_size: Number of agents to create.
    :param num_envs: Accepted for signature compatibility; currently unused.
    :param device: Device passed to every agent.
    :param accelerator: Accelerator passed to every agent.
    :param torch_compiler: Accepted for signature compatibility; currently unused.
    :return: List of ``population_size`` SigmaPPO agents with ``index`` 0..n-1.
    :raises NotImplementedError: If ``algo`` does not contain ``"PPO"``.
    """
    if "PPO" not in algo:
        raise NotImplementedError("Only PPO for now")

    population = []

    for idx in range(population_size):
        agent = SigmaPPO(
            state_dim=state_dim,
            action_dim=action_dim,
            one_hot=one_hot,
            # Agents are always created with continuous actions.
            discrete_actions=False,
            index=idx,
            net_config=net_config,
            batch_size=INIT_HP["BATCH_SIZE"],
            lr=INIT_HP["LR"],
            learn_step=INIT_HP["LEARN_STEP"],
            gamma=INIT_HP["GAMMA"],
            gae_lambda=INIT_HP["GAE_LAMBDA"],
            action_std_init=INIT_HP["ACTION_STD_INIT"],
            clip_coef=INIT_HP["CLIP_COEF"],
            ent_coef=INIT_HP["ENT_COEF"],
            vf_coef=INIT_HP["VF_COEF"],
            max_grad_norm=INIT_HP["MAX_GRAD_NORM"],
            target_kl=INIT_HP["TARGET_KL"],
            update_epochs=INIT_HP["UPDATE_EPOCHS"],
            # Independent copies so that agents do not share parameters.
            actor_network=deepcopy(actor_network),
            critic_network=deepcopy(critic_network),
            device=device,
            accelerator=accelerator,
        )
        population.append(agent)

    return population

In [22]:
# Population of SigmaPPO agents built from the custom evolvable networks.
pop = custom_create_population(
    algo="SigmaPPO",  # Algorithm
    state_dim=state_dim,  # State dimension
    action_dim=action_dim,  # Action dimension
    one_hot=one_hot,  # One-hot encoding
    net_config=None,  # Network configuration set as None
    INIT_HP=INIT_HP,  # Initial hyperparameters
    actor_network=evolvable_actor,  # Custom evolvable actor
    critic_network=[evolvable_critic, evolvable_vol_net],  # [critic, volatility net]
    population_size=INIT_HP["POP_SIZE"],  # Population size
    device=device,
)

In [23]:
# Smoke test on the first agent only: show its volatility net and make sure
# an action can be computed for a freshly reset environment.
if pop:
    agent = pop[0]
    print(agent.vol_net)
    state, info = env.reset()
    agent.get_action(state)

MakeEvolvable(
  (feature_net): Sequential(
    (feature_linear_layer_0): Linear(in_features=3, out_features=64, bias=True)
    (feature_activation_0): ReLU()
    (feature_linear_layer_output): Linear(in_features=64, out_features=1, bias=True)
    (feature_activation_output): ReLU()
  )
)


### Training

In [24]:


# TRAINING LOOP
# Each pass of the outer `while` is one generation: every agent collects
# experience and learns, the population is evaluated, then tournament
# selection and mutation produce the next population.
# NOTE(review): `total_steps` is initialised in an earlier cell, so re-running
# only this cell keeps accumulating into the old value.
print("Training...")
pbar = trange(max_steps, unit="step")
# Continue while every agent's latest step count is still below max_steps.
while np.less([agent.steps[-1] for agent in pop], max_steps).all():
    pop_episode_scores = []
    for agent in pop:  # Loop through population
        state, info = env.reset()  # Reset environment at start of episode
        scores = np.zeros(num_envs)  # Running return of each parallel environment
        completed_episode_scores = []
        steps = 0

        # -(a // -b) is ceiling division: number of learn iterations needed to
        # cover at least evo_steps environment steps.
        for _ in range(-(evo_steps // -agent.learn_step)):

            # Rollout buffers for one learn step
            states = []
            actions = []
            log_probs = []
            rewards = []
            dones = []
            values = []

            learn_steps = 0

            # Ceiling division again: each iteration advances all num_envs environments.
            for idx_step in range(-(agent.learn_step // -num_envs)):
                if INIT_HP["CHANNELS_LAST"]:
                    state = np.moveaxis(state, [-1], [-3])

                # Get next action from agent
                action, log_prob, _, value = agent.get_action(state)

                # Act in environment
                next_state, reward, terminated, truncated, info = env.step(action)

                total_steps += num_envs
                steps += num_envs
                learn_steps += num_envs

                states.append(state)
                actions.append(action)
                log_probs.append(log_prob)
                rewards.append(reward)
                # Only `terminated` is stored as the done flag; `truncated` is
                # used below for score bookkeeping only.
                dones.append(terminated)
                values.append(value)

                state = next_state
                scores += np.array(reward)

                # Record and reset the return of every environment whose episode ended.
                for idx, (d, t) in enumerate(zip(terminated, truncated)):
                    if d or t:
                        completed_episode_scores.append(scores[idx])
                        agent.scores.append(scores[idx])
                        scores[idx] = 0

            # The bar advances by this agent's steps divided by the population size.
            pbar.update(learn_steps // len(pop))

            if INIT_HP["CHANNELS_LAST"]:
                next_state = np.moveaxis(next_state, [-1], [-3])

            # The final next_state is passed along with the rollout lists.
            experiences = (
                states,
                actions,
                log_probs,
                rewards,
                dones,
                values,
                next_state,
            )
            # Learn according to agent's RL algorithm
            agent.learn(experiences)

        agent.steps[-1] += steps
        pop_episode_scores.append(completed_episode_scores)

    # Evaluate population
    fitnesses = [
        agent.test(
            env,
            swap_channels=INIT_HP["CHANNELS_LAST"],
            max_steps=eval_steps,
            loop=eval_loop,
        )
        for agent in pop
    ]
    # Mean training return per agent, or a placeholder string when no episode finished.
    mean_scores = [
        (
            np.mean(episode_scores)
            if len(episode_scores) > 0
            else "0 completed episodes"
        )
        for episode_scores in pop_episode_scores
    ]

    print(f"--- Global steps {total_steps} ---")
    print(f"Steps {[agent.steps[-1] for agent in pop]}")
    print(f"Scores: {mean_scores}")
    print(f'Fitnesses: {["%.2f"%fitness for fitness in fitnesses]}')
    # Mean of the last (up to) five entries of each agent's fitness list.
    print(
        f'5 fitness avgs: {["%.2f"%np.mean(agent.fitness[-5:]) for agent in pop]}'
    )

    # Tournament selection and population mutation
    elite, pop = tournament.select(pop)
    pop = mutations.mutation(pop)

    # Update step counter
    # A new entry, starting from the current count, is appended for the next generation.
    for agent in pop:
        agent.steps.append(agent.steps[-1])

pbar.close()
env.close()

Training...


  5%|▍         | 9954/200000 [03:37<1:03:57, 49.52step/s]

--- Global steps 60672 ---
Steps [10112, 10112, 10112, 10112, 10112, 10112]
Scores: [-1243.6453003181653, -1555.3733385481316, -1507.773033485779, -1496.7876791630763, -1388.7472103345315, -1339.5358883137058]
Fitnesses: ['-1452.71', '-1281.13', '-1388.72', '-1450.54', '-1380.77', '-1456.92']
5 fitness avgs: ['-1452.71', '-1281.13', '-1388.72', '-1450.54', '-1380.77', '-1456.92']


In [None]:
# Alternative to the hand-written training loop: AgileRL's packaged
# `train_on_policy` entry point. The whole cell is currently commented out.
#
# NOTE(review): the next cell reads `trained_pop`, and this call is the only
# place it is assigned within view. On Restart & Run All the next cell will
# raise NameError unless this cell is re-enabled or `trained_pop` is defined
# elsewhere.
# NOTE(review): `env_name` here is "LunarLander-v2", whereas the setup cell
# builds `env` from "Pendulum-v1" -- confirm which environment is intended.
# NOTE(review): the preceding training cell ends with `env.close()`, so `env`
# presumably has to be recreated before this call -- TODO confirm.
# NOTE(review): if re-enabled, move the import to the notebook's import cell.
# from agilerl.training.train_on_policy import train_on_policy

# trained_pop, pop_fitnesses = train_on_policy(
#     env=env,                              # Gym-style environment
#     env_name="LunarLander-v2",  # Environment name
#     algo="PPO",  # Algorithm
#     pop=pop,  # Population of agents
#     swap_channels=INIT_HP['CHANNELS_LAST'],  # Swap image channel from last to first
#     max_steps=200000,  # Max number of training steps
#     evo_steps=10000,  # Evolution frequency
#     eval_steps=None,  # Number of steps in evaluation episode
#     eval_loop=1,  # Number of evaluation episodes
#     target=200.,  # Target score for early stopping
#     tournament=tournament,  # Tournament selection object
#     mutation=mutations,  # Mutations object
#     wb=INIT_HP['WANDB'],  # Weights and Biases tracking
#     wandb_api_key=INIT_HP['wandb_api_key'],
# )

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc



Training...


  5%|4         | 9975/200000 [  02:03<31:55:58,  1.65step/s]


                --- Global Steps 60672 ---
                Fitness:		['-48.77', '-101.89', '-221.74', '-271.12', '-53.30', '-107.97']
                Score:		[-176.23922055595082, -155.75562717239725, -170.69177859329497, -177.7026453190563, -150.02117360401311, -141.84729437692192]
                5 fitness avgs:	['-48.77', '-53.30', '-48.77', '-48.77', '-221.74', '-221.74']
                10 score avgs:	['-173.84', '-101.44', '-173.84', '-173.84', '-184.69', '-184.69']
                Agents:		[0, 6, 7, 8, 9, 10]
                Steps:		[10112, 10112, 10112, 10112, 10112, 10112]
                Mutations:		['None', 'arch', 'None', 'None', 'None', 'None']
                

 10%|9         | 19929/200000 [  03:54<38:13:48,  1.31step/s]


                --- Global Steps 121344 ---
                Fitness:		['-86.23', '-122.45', '-265.49', '-120.54', '-102.91', '-167.36']
                Score:		[-141.97666576566547, -140.5941714317149, -136.11190112576617, -103.09168031440569, -109.01930863991645, -147.17937029822264]
                5 fitness avgs:	['-67.50', '-67.50', '-162.32', '-157.13', '-162.32', '-84.65']
                10 score avgs:	['-158.96', '-158.96', '-48.81', '-161.82', '-48.81', '-146.81']
                Agents:		[0, 11, 12, 13, 14, 15]
                Steps:		[20224, 20224, 20224, 20224, 20224, 20224]
                Mutations:		['lr', 'param', 'None', 'None', 'None', 'None']
                

 15%|#4        | 29883/200000 [  06:00<41:38:16,  1.13step/s]


                --- Global Steps 182016 ---
                Fitness:		['-22.68', '-166.74', '-15.47', '-231.88', '-66.99', '-138.34']
                Score:		[-128.88283848932244, -80.64383562592732, -108.82206316567299, -215.3515963769123, -133.86973578170796, -77.7678894780903]
                5 fitness avgs:	['-113.37', '-113.37', '-113.37', '-130.55', '-113.37', '-52.56']
                10 score avgs:	['-120.29', '-120.29', '-120.29', '-106.53', '-120.29', '-96.86']
                Agents:		[12, 16, 17, 18, 19, 20]
                Steps:		[30336, 30336, 30336, 30336, 30336, 30336]
                Mutations:		['arch', 'param', 'None', 'param', 'param', 'None']
                

 20%|#9        | 39837/200000 [  07:55<21:28:28,  2.07step/s]


                --- Global Steps 242688 ---
                Fitness:		['4.29', '16.51', '82.47', '-152.01', '44.58', '-88.86']
                Score:		[-72.76982630581337, -57.184796246310185, -69.05850945749296, -123.83158000776433, -87.83964405370523, -60.962476671962065]
                5 fitness avgs:	['-64.41', '-64.41', '-61.63', '-80.90', '-64.41', '-64.41']
                10 score avgs:	['-74.54', '-74.54', '-63.47', '-39.71', '-74.54', '-74.54']
                Agents:		[17, 21, 22, 23, 24, 25]
                Steps:		[40448, 40448, 40448, 40448, 40448, 40448]
                Mutations:		['arch', 'None', 'None', 'None', 'None', 'lr']
                

 25%|##4       | 49791/200000 [  10:01<23:20:46,  1.79step/s]


                --- Global Steps 303360 ---
                Fitness:		['29.64', '61.96', '-18.47', '66.85', '36.93', '-29.34']
                Score:		['0 completed episodes', -83.06512454319315, -57.54712295723905, 5.477688339793149, -11.707507770864643, 20.49466635440831]
                5 fitness avgs:	['-51.35', '-39.14', '-51.35', '-51.35', '-45.60', '-39.14']
                10 score avgs:	['28.65', '-52.42', '28.65', '28.65', '-74.54', '-52.42']
                Agents:		[23, 26, 27, 28, 29, 30]
                Steps:		[50560, 50560, 50560, 50560, 50560, 50560]
                Mutations:		['None', 'None', 'bs', 'None', 'bs', 'None']
                

 30%|##9       | 59745/200000 [  12:07<31:31:35,  1.24step/s]


                --- Global Steps 364032 ---
                Fitness:		['156.71', '-5.36', '32.20', '-58.64', '-48.40', '63.58']
                Score:		[-48.297843928705305, 152.63450061304798, -34.48455497363226, '0 completed episodes', -100.00753211582501, 52.239736687181704]
                5 fitness avgs:	['24.34', '24.34', '4.14', '24.34', '4.14', '17.93']
                10 score avgs:	['-10.65', '-10.65', '-35.89', '-10.65', '-35.89', '-40.74']
                Agents:		[23, 31, 32, 33, 34, 35]
                Steps:		[60672, 60672, 60672, 60672, 60672, 60672]
                Mutations:		['arch', 'None', 'None', 'None', 'None', 'param']
                

 35%|###4      | 69699/200000 [  14:08<32:00:40,  1.13step/s]


                --- Global Steps 424704 ---
                Fitness:		['-131.08', '-82.06', '-113.25', '-42.78', '-55.96', '-31.20']
                Score:		[-141.02258322140182, '0 completed episodes', '0 completed episodes', 26.175770216717346, -77.31401164717573, -94.97396464715598]
                5 fitness avgs:	['32.27', '32.27', '36.36', '28.51', '32.27', '32.27']
                10 score avgs:	['-45.22', '-45.22', '7.73', '-10.65', '-45.22', '-45.22']
                Agents:		[35, 36, 37, 38, 39, 40]
                Steps:		[70784, 70784, 70784, 70784, 70784, 70784]
                Mutations:		['arch', 'None', 'None', 'param', 'None', 'None']
                

 40%|###9      | 79653/200000 [  16:08<27:45:25,  1.20step/s]


                --- Global Steps 485376 ---
                Fitness:		['-93.20', '-70.47', '-142.07', '-157.16', '-47.05', '17.44']
                Score:		[-88.25023465417148, -97.57860022799575, -71.3074533580399, -107.21928275432323, -57.24363410105409, -81.93577449494707]
                5 fitness avgs:	['38.85', '38.85', '21.27', '38.85', '21.27', '25.95']
                10 score avgs:	['-74.24', '-74.24', '-99.51', '-74.24', '-99.51', '-46.65']
                Agents:		[40, 41, 42, 43, 44, 45]
                Steps:		[80896, 80896, 80896, 80896, 80896, 80896]
                Mutations:		['arch', 'param', 'lr', 'lr', 'None', 'None']
                

 45%|####4     | 89607/200000 [  18:07<27:09:10,  1.13step/s]


                --- Global Steps 546048 ---
                Fitness:		['-67.54', '-51.07', '-71.25', '-83.75', '-144.27', '-40.82']
                Score:		[-92.54870476187605, -83.09168859413971, -117.26894048102199, -69.45736984938493, -115.59761381897825, -99.19487029624085]
                5 fitness avgs:	['1.30', '1.30', '12.14', '5.61', '-9.47', '12.14']
                10 score avgs:	['-103.84', '-103.84', '-72.45', '-46.62', '-125.35', '-72.45']
                Agents:		[45, 46, 47, 48, 49, 50]
                Steps:		[91008, 91008, 91008, 91008, 91008, 91008]
                Mutations:		['lr', 'param', 'arch', 'None', 'param', 'None']
                

 50%|####9     | 99561/200000 [  20:13<24:05:44,  1.16step/s]


                --- Global Steps 606720 ---
                Fitness:		['-96.48', '-65.29', '-139.11', '-19.42', '-38.49', '-108.90']
                Score:		[-74.14464925564207, -92.40042057585039, -86.12892448037017, -82.34105152346754, -114.89749455277043, -88.7301991013866]
                5 fitness avgs:	['-10.67', '-10.67', '-10.67', '-10.67', '-10.67', '-30.39']
                10 score avgs:	['-60.90', '-60.90', '-60.90', '-60.90', '-60.90', '-54.50']
                Agents:		[48, 51, 52, 53, 54, 55]
                Steps:		[101120, 101120, 101120, 101120, 101120, 101120]
                Mutations:		['None', 'arch', 'arch', 'None', 'None', 'arch']
                

 55%|#####4    | 109515/200000 [  21:59<18:06:31,  1.39step/s]


                --- Global Steps 667392 ---
                Fitness:		['-26.62', '-129.02', '-116.78', '-43.23', '-52.59', '-81.52']
                Score:		[-87.79195366292578, -82.76861279905067, -89.82692392858988, -81.3693802368153, -104.06089211134508, -104.75442039041761]
                5 fitness avgs:	['-28.71', '-28.71', '-28.71', '-59.41', '-32.03', '-33.90']
                10 score avgs:	['-94.38', '-94.38', '-94.38', '-109.45', '-76.83', '-118.17']
                Agents:		[48, 56, 57, 58, 59, 60]
                Steps:		[111232, 111232, 111232, 111232, 111232, 111232]
                Mutations:		['None', 'None', 'None', 'None', 'param', 'lr']
                

 60%|#####9    | 119469/200000 [  23:54<16:52:43,  1.33step/s]


                --- Global Steps 728064 ---
                Fitness:		['-107.22', '-122.53', '-70.48', '-97.26', '-139.52', '-113.37']
                Score:		[-97.03724675199348, -90.94350791484624, -100.73155590224533, -97.49177504347637, -104.50244626598483, -94.74597996485201]
                5 fitness avgs:	['-36.57', '-72.62', '-46.98', '-36.57', '-46.98', '-43.92']
                10 score avgs:	['-114.34', '-95.58', '-87.41', '-114.34', '-87.41', '-102.19']
                Agents:		[57, 61, 62, 63, 64, 65]
                Steps:		[121344, 121344, 121344, 121344, 121344, 121344]
                Mutations:		['param', 'None', 'None', 'None', 'None', 'param']
                

 65%|######4   | 129423/200000 [  25:45<15:25:15,  1.27step/s]


                --- Global Steps 788736 ---
                Fitness:		['-61.52', '-80.04', '-108.44', '-75.89', '-52.23', '-110.31']
                Score:		[-105.76898978118282, -103.55824822308122, -111.13394707135542, -126.12025056059439, -110.67323653469296, -100.81077999164954]
                5 fitness avgs:	['-60.91', '-60.91', '-60.91', '-52.36', '-52.36', '-60.91']
                10 score avgs:	['-109.87', '-109.87', '-109.87', '-113.57', '-113.57', '-109.87']
                Agents:		[64, 66, 67, 68, 69, 70]
                Steps:		[131456, 131456, 131456, 131456, 131456, 131456]
                Mutations:		['bs', 'arch', 'arch', 'arch', 'param', 'lr']
                

 70%|######9   | 139377/200000 [  27:44<14:16:51,  1.18step/s]


                --- Global Steps 849408 ---
                Fitness:		['-117.73', '-54.99', '-90.13', '-69.13', '-106.27', '-86.08']
                Score:		[-99.72556868278174, -114.89462332511049, -104.11453362907882, -103.15080000801638, -92.02041686700224, -94.89752364088878]
                5 fitness avgs:	['-55.16', '-61.38', '-61.38', '-55.16', '-49.43', '-67.71']
                10 score avgs:	['-111.71', '-99.95', '-99.95', '-111.71', '-103.78', '-107.10']
                Agents:		[66, 71, 72, 73, 74, 75]
                Steps:		[141568, 141568, 141568, 141568, 141568, 141568]
                Mutations:		['arch', 'None', 'None', 'arch', 'None', 'arch']
                

 75%|#######4  | 149331/200000 [  29:30<9:39:42,  1.46step/s] 


                --- Global Steps 910080 ---
                Fitness:		['-137.86', '-57.56', '-43.97', '-102.76', '-82.33', '-105.38']
                Score:		[-116.23200768557322, -99.67399672589278, -98.35852634132415, -99.39424578817624, -83.89605974888391, -100.09812341202554]
                5 fitness avgs:	['-66.29', '-66.29', '-69.01', '-66.29', '-69.01', '-69.01']
                10 score avgs:	['-85.88', '-85.88', '-100.77', '-85.88', '-100.77', '-100.77']
                Agents:		[72, 76, 77, 78, 79, 80]
                Steps:		[151680, 151680, 151680, 151680, 151680, 151680]
                Mutations:		['param', 'param', 'arch', 'None', 'param', 'lr']
                

 80%|#######9  | 159285/200000 [  31:17<8:39:31,  1.31step/s] 


                --- Global Steps 970752 ---
                Fitness:		['-36.86', '-52.54', '-145.31', '-73.44', '-22.79', '-91.80']
                Score:		[-98.76523084071081, -51.27714049258225, -149.54934579919768, -94.16636710447274, -68.46088076628932, -65.87190853059533]
                5 fitness avgs:	['-68.24', '-68.33', '-71.47', '-68.33', '-68.24', '-68.24']
                10 score avgs:	['-43.85', '-112.83', '-38.84', '-112.83', '-43.85', '-43.85']
                Agents:		[79, 81, 82, 83, 84, 85]
                Steps:		[161792, 161792, 161792, 161792, 161792, 161792]
                Mutations:		['param', 'arch', 'None', 'param', 'None', 'None']
                

 85%|########4 | 169239/200000 [  33:12<3:50:55,  2.22step/s] 


                --- Global Steps 1031424 ---
                Fitness:		['2.36', '-146.93', '-70.97', '-38.51', '-93.83', '-91.97']
                Score:		[-63.35056974193875, -122.79288938239314, -50.13711297656004, -53.38054260923758, -91.08038090166053, -61.11030378091783]
                5 fitness avgs:	['-43.26', '-51.53', '-51.53', '-61.16', '-51.53', '-61.16']
                10 score avgs:	['-68.45', '-22.18', '-22.18', '-28.47', '-22.18', '-28.47']
                Agents:		[79, 86, 87, 88, 89, 90]
                Steps:		[171904, 171904, 171904, 171904, 171904, 171904]
                Mutations:		['param', 'None', 'bs', 'None', 'param', 'None']
                

 90%|########9 | 179193/200000 [  35:03<4:26:48,  1.30step/s] 


                --- Global Steps 1092096 ---
                Fitness:		['-1.69', '-67.20', '-26.65', '-104.65', '-119.55', '17.78']
                Score:		[-28.744827500340183, -81.33719910546594, -71.65083239862695, -97.06579737346472, -122.42425156069758, -71.17673045598136]
                5 fitness avgs:	['-47.15', '-47.15', '-46.41', '-54.52', '-54.52', '-47.15']
                10 score avgs:	['-52.27', '-52.27', '-65.06', '-95.58', '-95.58', '-52.27']
                Agents:		[90, 91, 92, 93, 94, 95]
                Steps:		[182016, 182016, 182016, 182016, 182016, 182016]
                Mutations:		['param', 'arch', 'param', 'None', 'arch', 'None']
                

 95%|#########4| 189147/200000 [  37:04<2:42:47,  1.11step/s] 


                --- Global Steps 1152768 ---
                Fitness:		['-37.63', '-166.77', '-60.11', '-75.89', '-110.12', '5.74']
                Score:		[-34.022859262901896, -131.45391749016397, -61.97133241476612, -108.11194094444645, -80.71938323373868, -56.98650555624211]
                5 fitness avgs:	['-28.79', '-37.47', '-28.79', '-28.79', '-52.49', '-41.22']
                10 score avgs:	['-49.22', '-42.36', '-49.22', '-49.22', '-122.45', '-24.18']
                Agents:		[95, 96, 97, 98, 99, 100]
                Steps:		[192128, 192128, 192128, 192128, 192128, 192128]
                Mutations:		['arch', 'None', 'arch', 'arch', 'None', 'None']
                

100%|#########9| 199080/200000 [  38:20<  00:04, 184.10step/s]


                --- Global Steps 1213440 ---
                Fitness:		['-67.83', '-42.92', '-104.14', '-22.30', '-103.30', '-59.10']
                Score:		[-37.1213742741476, -86.57895201494533, -64.83913583892107, -25.01514108695892, -101.08372238821408, -48.12941662413789]
                5 fitness avgs:	['-24.46', '-24.46', '-33.56', '-64.35', '-33.56', '-44.24']
                10 score avgs:	['50.49', '50.49', '8.14', '-113.30', '8.14', '-31.23']
                Agents:		[98, 101, 102, 103, 104, 105]
                Steps:		[202240, 202240, 202240, 202240, 202240, 202240]
                Mutations:		['None', 'bs', 'None', 'None', 'None', 'lr']
                

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval/best_fitness,▂▁▃▆▅█▃▄▂▃▃▁▂▂▂▃▄▄▄▃
eval/mean_fitness,▁▁▃▆██▄▄▄▄▄▂▄▃▃▄▄▅▄▄
global_step,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
train/agent_0_loss,█▃▂▁▁▁▄▃▃▂▂▁▁▃▁▂▂▁▂▃
train/agent_1_loss,█▅█▂▂▃▁▁▇▂▂▂▄▂▂▄▂▁▃▃
train/agent_2_loss,█▂▃▁▂▂▁▂▁▂▁▂▁▃▂▁▁▃▆▂
train/agent_3_loss,▆█▅▄▁▁▁▂▅▃▂▁▂▂▂▂▄▁▂▃
train/agent_4_loss,█▅▆▃▁▃▁▂▁▂▃▃▂▂▁▃▁▃▁▂
train/agent_5_loss,▅▄▃▂▁▄▂▃▃▃▂▃▁▂▃▃▃▂▄█
train/mean_score,▁▂▃▅▇█▅▄▄▄▄▄▃▄▄▄▅▅▄▅

0,1
eval/best_fitness,-22.29863
eval/mean_fitness,-66.5977
global_step,1213440.0
train/agent_0_loss,0.16962
train/agent_1_loss,0.10075
train/agent_2_loss,0.1201
train/agent_3_loss,0.13996
train/agent_4_loss,0.0332
train/agent_5_loss,0.46607
train/mean_score,-60.46129


  lambda data: self._console_raw_callback("stderr", data),
100%|#########9| 199080/200000 [  38:59<  00:10, 85.09step/s] 


In [None]:
# Take the first member of the trained population for inspection.
# NOTE(review): `trained_pop` is only assigned by the `train_on_policy` call,
# which is commented out in the previous cell; this line depends on that cell
# having been executed.
agent = trained_pop[0]

array([[-0.85239583,  1.19907   ,  1.0199465 ,  1.1254752 ,  2.9390442 ,
         1.2153194 ,  0.7284626 ,  0.4107304 ],
       [-1.0123198 , -0.08863777,  4.0035477 , -1.8826923 ,  0.6858207 ,
        -0.0595629 ,  0.9837324 ,  0.4361292 ],
       [-0.03299822, -1.384671  , -3.9240286 , -0.6785452 , -2.195526  ,
         4.4582644 ,  0.03838794,  0.660428  ],
       [-0.87020344,  1.2908074 ,  3.0512488 , -3.3140237 , -0.12082151,
         4.762933  ,  0.26961353,  0.6631864 ],
       [ 0.6808699 ,  0.7150591 ,  2.635245  , -3.7163386 ,  1.1280224 ,
        -2.3501148 ,  0.53162354,  0.758555  ],
       [ 0.18915331, -0.74293584,  2.6423454 ,  1.3652818 ,  0.46170154,
        -0.42255315,  0.13336295,  0.57097054],
       [-0.9718503 , -0.5298286 ,  3.7354355 ,  1.5907434 , -1.3374785 ,
        -1.350843  ,  0.0289788 ,  0.9720034 ],
       [-0.31048515,  0.2554536 , -0.04210167,  1.3130286 , -2.9176698 ,
         1.969855  ,  0.26110598,  0.20814197],
       [-1.1522628 ,  0.30782232

In [None]:
# Draw a random sample from the vector env's action space (displayed as the
# cell output); presumably one action per sub-environment -- TODO confirm.
# NOTE(review): `env.close()` is called at the end of the training cell;
# confirm `env` is still usable (or recreated) when this cell runs.
env.action_space.sample()

array([3, 1, 0, 1, 2, 1, 0, 1, 3, 2, 2, 0, 3, 3, 1, 1])

In [None]:
# Smoke-test the trained agent on a randomly sampled observation batch.
# The recorded output is a 4-tuple of arrays; presumably
# (action, log_prob, entropy, value) -- TODO confirm against the agent's
# `get_action` documentation.
agent.get_action(state=env.observation_space.sample())

(array([3, 1, 1, 1, 3, 1, 2, 1, 1, 3, 1, 1, 3, 1, 1, 1]),
 array([-0.13260274, -0.00351821, -0.00383964, -0.6369603 , -0.04836098,
        -0.00381565, -0.09055927, -0.04441704, -0.05331632, -0.17483655,
        -0.07071134, -0.01006466, -0.00390373, -0.00330666, -0.008079  ,
        -1.0205497 ], dtype=float32),
 array([0.37867695, 0.02484129, 0.02674724, 0.7026828 , 0.19157144,
        0.02653088, 0.30269036, 0.18310887, 0.20520574, 0.4447777 ,
        0.2560623 , 0.05715971, 0.02692572, 0.02410962, 0.04798388,
        0.7158035 ], dtype=float32),
 array([-12.471022 ,   2.430461 , -39.258907 , -34.27118  , -44.87177  ,
          8.222554 , -14.760967 , -45.669674 , -39.85141  , -42.973305 ,
        -13.004054 , -22.519735 ,  -4.1115828, -33.528423 , -25.043116 ,
        -36.668076 ], dtype=float32))