In [1]:
# chose the current file directory as the working directory
import os 
os.chdir("/teamspace/studios/this_studio/rubikscubesolver")


In [2]:
from tqdm import tqdm

import wandb  # for logging
import time
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.nnx as nnx

import optax

from rubiktransformer.models import RubikTransformer, PolicyModel
import rubiktransformer.dataset as dataset
from rubiktransformer.trainer import train
from rubiktransformer.trainer import reshape_sample

cuda_plugin_extension is not found.


In [3]:
@dataclass
class Config:
    """Configuration class"""
    jax_key: jnp.ndarray = jax.random.PRNGKey(46)
    rngs = nnx.Rngs(45)
    batch_size: int = 128
    lr_1: float = 4e-3
    lr_2: float = 4e-3
    nb_games: int = 128 * 100
    len_seq: int = 15
    nb_step: int = 1000000
    log_every_step: int = 10
    log_eval_every_step: int = 10
    log_policy_reward_every_step: int = 10
    add_data_every_step: int = 500

config = Config()

# init wandb config
user = "forbu14"
project = "RubikTransformer"
display_name = "experiment_" + time.strftime("%Y%m%d-%H%M%S")

wandb.init(entity=user, project=project, name=display_name)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mforbu14[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
policy = PolicyModel(rngs=config.rngs)
transformer = RubikTransformer(rngs=config.rngs, causal=True)

scheduler = optax.linear_schedule(init_value=0., end_value=1., transition_steps=4000)

# init optimizer
optimizer_optaxworldmodel = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.lion(config.lr_1 / 100.),
   # optax.adamw(config.lr_1/10.),
    optax.scale_by_schedule(scheduler),
)

optimizer_worldmodel = nnx.Optimizer(transformer, optimizer_optaxworldmodel)
optimizer_policy = nnx.Optimizer(policy, optax.adam(config.lr_2))

# metrics
metrics_train = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
    loss_reward=nnx.metrics.Average("loss_reward"),
    loss_cross_entropy=nnx.metrics.Average("loss_cross_entropy"),
)

metrics_eval = nnx.MultiMetric(
    loss_eval=nnx.metrics.Average("loss_eval"),
    loss_reward_eval=nnx.metrics.Average("loss_reward_eval"),
    loss_cross_entropy_eval=nnx.metrics.Average("loss_cross_entropy_eval"),
)


In [5]:
# gather data from the environment
# init models and optimizers
env, buffer = dataset.init_env_buffer(sample_batch_size=config.batch_size)
env, buffer_eval = dataset.init_env_buffer(sample_batch_size=config.batch_size)


nb_games = config.nb_games
len_seq = config.len_seq

state_first = jnp.zeros((6, 3, 3))
state_next = jnp.zeros((len_seq, 6, 3, 3))
action = jnp.zeros((len_seq, 3))
action_proba = jnp.zeros((len_seq, 9))

# transform state to int8 type
state_first = state_first.astype(jnp.int8)
state_next = state_next.astype(jnp.int8)

# action to int32 type
action = action.astype(jnp.int32)

reward = jnp.zeros((len_seq))

jit_step = jax.jit(env.step)

buffer_list = buffer.init(
    {
        "state_first": state_first,
        "action": action,
        "reward": reward,
        "state_next": state_next,
        "action_pred": action_proba,
    }
)

buffer_list_eval = buffer_eval.init(
    {
        "state_first": state_first,
        "action": action,
        "reward": reward,
        "state_next": state_next,
        "action_pred": action_proba,
    }
)


In [6]:

def step_fn(state, key):
    """
    Simple step function
    We choose a random action
    """

    action = jax.random.randint(
        key=key,
        minval=env.action_spec.minimum,
        maxval=env.action_spec.maximum,
        shape=(3,),
    )

    new_state, timestep = jit_step(state, action)
    timestep.extras["action"] = action
    timestep.extras["action_pred"] = jnp.zeros((9,))

    return new_state, timestep

def run_n_steps(state, key, n):
    random_keys = jax.random.split(key, n)
    state, rollout = jax.lax.scan(step_fn, state, random_keys)

    return rollout

vmap_reset = jax.vmap(jax.jit(env.reset))
vmap_step = jax.vmap(run_n_steps, in_axes=(0, 0, None))

In [7]:
scale_factor = 0.01

def step_fn_proba_setup(state, data):
    """
    Simple step function
    We choose a random action
    """

    key1, key2 = data

    # we choose a probability distribution over the action
    action_proba_0 = jax.random.normal(key1, shape=(6,)) * scale_factor
    action_proba_1 = jax.random.normal(key2, shape=(3,)) * scale_factor

    # softmax over the action
    action_proba_0 = jax.nn.softmax(action_proba_0)
    action_proba_1 = jax.nn.softmax(action_proba_1)

    # sample the action from the probability distribution
    action_proba_0_value = jax.random.categorical(key1, action_proba_0)
    action_proba_1_value = jax.random.categorical(key2, action_proba_1)

    action = jnp.array([action_proba_0_value, jnp.array(0), action_proba_1_value])

    # concat 
    action_proba = jnp.concatenate([action_proba_0, action_proba_1])

    new_state, timestep = jit_step(state, action)
    timestep.extras["action"] = action
    timestep.extras["action_pred"] = action_proba

    return new_state, timestep

def run_n_steps_proba(state, key, n):
    random_keys = jax.random.split(key, (n, 2))
    state, rollout = jax.lax.scan(step_fn_proba_setup, state, random_keys)

    return rollout

vmap_reset = jax.vmap(jax.jit(env.reset))
vmap_step_proba = jax.vmap(run_n_steps_proba, in_axes=(0, 0, None))

In [8]:
nnx.display(transformer)

In [9]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer, buffer_list = dataset.fast_gathering_data(
    env,
    vmap_reset,
    vmap_step,
    int(config.nb_games / 10),
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

In [10]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer, buffer_list = dataset.fast_gathering_data(
    env,
    vmap_reset,
    vmap_step_proba,
    int(config.nb_games / 10),
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

In [11]:

sample = buffer.sample(buffer_list, subkey)
sample = reshape_sample(sample)

In [28]:
sample

TrajectoryBufferSample(experience={'action': Array([[[2.06892246e-11, 9.55462814e-13, 3.47617985e-14, ...,
         1.70450963e-04, 9.51987028e-01, 4.78424989e-02],
        [2.23278240e-09, 9.99999881e-01, 6.94929483e-08, ...,
         2.38487264e-05, 9.99976158e-01, 7.72832406e-12],
        [9.99887466e-01, 2.59504759e-07, 4.24920762e-17, ...,
         2.51430203e-04, 9.93669868e-01, 6.07857388e-03],
        ...,
        [4.36003500e-09, 9.16709439e-07, 1.09520182e-02, ...,
         9.97643292e-01, 5.05451055e-04, 1.85125030e-03],
        [2.57591438e-03, 4.25381074e-07, 2.15260824e-03, ...,
         3.20033112e-04, 6.29408873e-07, 9.99679267e-01],
        [8.09437040e-07, 1.00822868e-02, 1.97056139e-07, ...,
         1.74851753e-02, 1.12450890e-01, 8.70063901e-01]],

       [[7.36846849e-02, 2.78365478e-04, 4.14273024e-01, ...,
         3.44924883e-05, 9.99965549e-01, 1.61586799e-10],
        [8.44443321e-07, 4.18407442e-09, 9.22161162e-01, ...,
         5.91705543e-07, 9.99983788e-0

In [8]:

def loss_fn_transformer(model: RubikTransformer, batch):
    state_logits, reward_value = model(batch["state_first"], batch["action"])

    # reshape state_logits
    # from (batch_size, sequence_length, 324) => (batch_size, sequence_length -1, 54, 6)
    state_logits = state_logits[:, 1:, :]
    state_logits = state_logits.reshape(
        (state_logits.shape[0], state_logits.shape[1], 54, 6)
    )

    reward_value = reward_value[:, 1:]

    loss_crossentropy = optax.softmax_cross_entropy_with_integer_labels(
        logits=state_logits, labels=batch["state_next"]
    ).mean()

    loss_reward = jnp.square(reward_value - batch["reward"]).mean()

    loss = loss_crossentropy + loss_reward

    return loss, (loss_crossentropy, loss_reward)

@nnx.jit
def train_step_transformer(
    model: RubikTransformer, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch
):
    """Train for a single step."""
    grad_fn = nnx.value_and_grad(loss_fn_transformer, has_aux=True)
    (loss, (loss_crossentropy, loss_reward)), grads = grad_fn(model, batch)
    metrics.update(
        loss=loss, loss_reward=loss_reward, loss_cross_entropy=loss_crossentropy
    )
    optimizer.update(grads)


In [9]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer, buffer_list = dataset.fast_gathering_data(
    env,
    vmap_reset,
    vmap_step_proba,
    config.nb_games * 10,
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)


In [10]:

# transformer model calibration
for idx_step in tqdm(range(config.nb_step)):

    # training for world model
    key, subkey = jax.random.split(config.jax_key)
    config.jax_key = key

    if idx_step % config.add_data_every_step == 0:
        buffer, buffer_list = dataset.fast_gathering_data(
            env,
            vmap_reset,
            vmap_step_proba,
            int(config.nb_games / 10.),
            config.len_seq,
            buffer,
            buffer_list,
            config.jax_key,
        )

    sample = buffer.sample(buffer_list, subkey)
    sample = reshape_sample(sample)

    # we update the policy
    train_step_transformer(
        transformer, optimizer_worldmodel, metrics_train, sample.experience
    )

    if idx_step % config.log_every_step == 0:
        metrics_train_result = metrics_train.compute()
        print(metrics_train_result)

        wandb.log(metrics_train_result, step=idx_step)
        metrics_train.reset()

    if idx_step % config.log_eval_every_step == 0:

        key, subkey = jax.random.split(config.jax_key)
        config.jax_key = key
        
        buffer_eval, buffer_list_eval = dataset.fast_gathering_data(
            env,
            vmap_reset,
            vmap_step_proba,
            int(128),
            config.len_seq,
            buffer_eval,
            buffer_list_eval,
            subkey,
        )

        sample = buffer_eval.sample(buffer_list_eval, subkey)
        sample = reshape_sample(sample)

        loss, (loss_crossentropy, loss_reward) = loss_fn_transformer(transformer, sample.experience)

        metrics_eval.update(loss_eval=loss, loss_reward_eval=loss_reward, loss_cross_entropy_eval=loss_crossentropy)
        wandb.log(metrics_eval.compute(), step=idx_step)

        metrics_eval.reset()



  0%|          | 0/1000000 [00:00<?, ?it/s]

{'loss': Array(2.417289, dtype=float32), 'loss_reward': Array(0.3267387, dtype=float32), 'loss_cross_entropy': Array(2.0905504, dtype=float32)}


  0%|          | 9/1000000 [00:37<567:27:30,  2.04s/it]  

{'loss': Array(2.5006628, dtype=float32), 'loss_reward': Array(0.40404272, dtype=float32), 'loss_cross_entropy': Array(2.0966203, dtype=float32)}


  0%|          | 20/1000000 [00:39<137:42:36,  2.02it/s]

{'loss': Array(2.4355872, dtype=float32), 'loss_reward': Array(0.34044582, dtype=float32), 'loss_cross_entropy': Array(2.0951414, dtype=float32)}


  0%|          | 30/1000000 [00:41<56:27:30,  4.92it/s] 

{'loss': Array(2.3759081, dtype=float32), 'loss_reward': Array(0.28704482, dtype=float32), 'loss_cross_entropy': Array(2.0888634, dtype=float32)}


  0%|          | 38/1000000 [00:42<43:52:49,  6.33it/s]

{'loss': Array(2.3407874, dtype=float32), 'loss_reward': Array(0.2572913, dtype=float32), 'loss_cross_entropy': Array(2.083496, dtype=float32)}


  0%|          | 50/1000000 [00:44<31:11:52,  8.90it/s]

{'loss': Array(2.3293843, dtype=float32), 'loss_reward': Array(0.25792685, dtype=float32), 'loss_cross_entropy': Array(2.0714576, dtype=float32)}


  0%|          | 60/1000000 [00:46<34:21:04,  8.09it/s]

{'loss': Array(2.255757, dtype=float32), 'loss_reward': Array(0.19924942, dtype=float32), 'loss_cross_entropy': Array(2.0565078, dtype=float32)}


  0%|          | 70/1000000 [00:48<32:15:21,  8.61it/s]

{'loss': Array(2.2013502, dtype=float32), 'loss_reward': Array(0.16402005, dtype=float32), 'loss_cross_entropy': Array(2.0373304, dtype=float32)}


  0%|          | 78/1000000 [00:49<36:51:45,  7.53it/s]

{'loss': Array(2.1614513, dtype=float32), 'loss_reward': Array(0.14299397, dtype=float32), 'loss_cross_entropy': Array(2.0184574, dtype=float32)}


  0%|          | 90/1000000 [00:51<35:34:32,  7.81it/s]

{'loss': Array(2.1054223, dtype=float32), 'loss_reward': Array(0.11403959, dtype=float32), 'loss_cross_entropy': Array(1.9913826, dtype=float32)}


  0%|          | 100/1000000 [00:53<30:30:36,  9.10it/s]

{'loss': Array(2.063242, dtype=float32), 'loss_reward': Array(0.09969811, dtype=float32), 'loss_cross_entropy': Array(1.9635441, dtype=float32)}


  0%|          | 108/1000000 [00:54<35:18:23,  7.87it/s]

{'loss': Array(2.010981, dtype=float32), 'loss_reward': Array(0.07355243, dtype=float32), 'loss_cross_entropy': Array(1.9374287, dtype=float32)}


  0%|          | 120/1000000 [00:56<29:18:33,  9.48it/s]

{'loss': Array(1.9665139, dtype=float32), 'loss_reward': Array(0.05944021, dtype=float32), 'loss_cross_entropy': Array(1.9070736, dtype=float32)}


  0%|          | 130/1000000 [00:58<32:33:49,  8.53it/s]

{'loss': Array(1.9269739, dtype=float32), 'loss_reward': Array(0.04869048, dtype=float32), 'loss_cross_entropy': Array(1.8782837, dtype=float32)}


  0%|          | 138/1000000 [00:59<36:18:50,  7.65it/s]

{'loss': Array(1.8867314, dtype=float32), 'loss_reward': Array(0.04094809, dtype=float32), 'loss_cross_entropy': Array(1.8457829, dtype=float32)}


  0%|          | 148/1000000 [01:01<33:55:46,  8.19it/s]

{'loss': Array(1.8544966, dtype=float32), 'loss_reward': Array(0.03419331, dtype=float32), 'loss_cross_entropy': Array(1.8203032, dtype=float32)}


  0%|          | 159/1000000 [01:03<37:36:57,  7.38it/s]

{'loss': Array(1.8225437, dtype=float32), 'loss_reward': Array(0.02836212, dtype=float32), 'loss_cross_entropy': Array(1.7941815, dtype=float32)}


  0%|          | 170/1000000 [01:05<30:02:19,  9.25it/s]

{'loss': Array(1.7940953, dtype=float32), 'loss_reward': Array(0.02486544, dtype=float32), 'loss_cross_entropy': Array(1.7692295, dtype=float32)}


  0%|          | 178/1000000 [01:06<35:42:20,  7.78it/s]

{'loss': Array(1.768383, dtype=float32), 'loss_reward': Array(0.02165408, dtype=float32), 'loss_cross_entropy': Array(1.7467289, dtype=float32)}


  0%|          | 189/1000000 [01:08<29:47:58,  9.32it/s]

{'loss': Array(1.7431366, dtype=float32), 'loss_reward': Array(0.01924718, dtype=float32), 'loss_cross_entropy': Array(1.7238894, dtype=float32)}


  0%|          | 200/1000000 [01:10<30:56:04,  8.98it/s]

{'loss': Array(1.7214451, dtype=float32), 'loss_reward': Array(0.01681979, dtype=float32), 'loss_cross_entropy': Array(1.7046254, dtype=float32)}


  0%|          | 208/1000000 [01:11<35:01:11,  7.93it/s]

{'loss': Array(1.7038043, dtype=float32), 'loss_reward': Array(0.01585401, dtype=float32), 'loss_cross_entropy': Array(1.6879504, dtype=float32)}


  0%|          | 219/1000000 [01:13<30:39:35,  9.06it/s]

{'loss': Array(1.6864923, dtype=float32), 'loss_reward': Array(0.01462694, dtype=float32), 'loss_cross_entropy': Array(1.6718657, dtype=float32)}


  0%|          | 229/1000000 [01:15<36:11:14,  7.67it/s]

{'loss': Array(1.6708558, dtype=float32), 'loss_reward': Array(0.01382474, dtype=float32), 'loss_cross_entropy': Array(1.6570309, dtype=float32)}


  0%|          | 239/1000000 [01:16<30:48:29,  9.01it/s]

{'loss': Array(1.6588856, dtype=float32), 'loss_reward': Array(0.01295055, dtype=float32), 'loss_cross_entropy': Array(1.6459349, dtype=float32)}


  0%|          | 250/1000000 [01:18<28:11:06,  9.85it/s]

{'loss': Array(1.6507918, dtype=float32), 'loss_reward': Array(0.01281718, dtype=float32), 'loss_cross_entropy': Array(1.6379747, dtype=float32)}


  0%|          | 260/1000000 [01:20<30:39:56,  9.06it/s]

{'loss': Array(1.6441786, dtype=float32), 'loss_reward': Array(0.01244008, dtype=float32), 'loss_cross_entropy': Array(1.6317385, dtype=float32)}


  0%|          | 268/1000000 [01:22<40:25:34,  6.87it/s]

{'loss': Array(1.6379908, dtype=float32), 'loss_reward': Array(0.01209219, dtype=float32), 'loss_cross_entropy': Array(1.6258987, dtype=float32)}


  0%|          | 280/1000000 [01:23<29:55:40,  9.28it/s]

{'loss': Array(1.6330721, dtype=float32), 'loss_reward': Array(0.01174925, dtype=float32), 'loss_cross_entropy': Array(1.6213229, dtype=float32)}


  0%|          | 288/1000000 [01:25<35:28:55,  7.83it/s]

{'loss': Array(1.6290725, dtype=float32), 'loss_reward': Array(0.0116009, dtype=float32), 'loss_cross_entropy': Array(1.6174717, dtype=float32)}


  0%|          | 300/1000000 [01:27<33:51:23,  8.20it/s]

{'loss': Array(1.6276983, dtype=float32), 'loss_reward': Array(0.01226265, dtype=float32), 'loss_cross_entropy': Array(1.6154358, dtype=float32)}


  0%|          | 308/1000000 [01:28<36:44:22,  7.56it/s]

{'loss': Array(1.6241249, dtype=float32), 'loss_reward': Array(0.01169965, dtype=float32), 'loss_cross_entropy': Array(1.6124251, dtype=float32)}


  0%|          | 319/1000000 [01:30<30:25:38,  9.13it/s]

{'loss': Array(1.6222908, dtype=float32), 'loss_reward': Array(0.01160028, dtype=float32), 'loss_cross_entropy': Array(1.6106905, dtype=float32)}


  0%|          | 329/1000000 [01:32<40:15:14,  6.90it/s]

{'loss': Array(1.6214848, dtype=float32), 'loss_reward': Array(0.0119967, dtype=float32), 'loss_cross_entropy': Array(1.6094879, dtype=float32)}


  0%|          | 340/1000000 [01:33<30:20:20,  9.15it/s]

{'loss': Array(1.6203766, dtype=float32), 'loss_reward': Array(0.01224383, dtype=float32), 'loss_cross_entropy': Array(1.608133, dtype=float32)}


  0%|          | 348/1000000 [01:35<35:27:43,  7.83it/s]

{'loss': Array(1.6197315, dtype=float32), 'loss_reward': Array(0.01220781, dtype=float32), 'loss_cross_entropy': Array(1.6075239, dtype=float32)}


  0%|          | 359/1000000 [01:37<30:09:32,  9.21it/s]

{'loss': Array(1.6184596, dtype=float32), 'loss_reward': Array(0.0125552, dtype=float32), 'loss_cross_entropy': Array(1.6059045, dtype=float32)}


  0%|          | 369/1000000 [01:38<34:01:07,  8.16it/s]

{'loss': Array(1.6181936, dtype=float32), 'loss_reward': Array(0.01248344, dtype=float32), 'loss_cross_entropy': Array(1.6057103, dtype=float32)}


  0%|          | 379/1000000 [01:40<30:32:58,  9.09it/s]

{'loss': Array(1.6175998, dtype=float32), 'loss_reward': Array(0.01308532, dtype=float32), 'loss_cross_entropy': Array(1.6045147, dtype=float32)}


  0%|          | 389/1000000 [01:42<31:03:30,  8.94it/s]

{'loss': Array(1.6174573, dtype=float32), 'loss_reward': Array(0.01355608, dtype=float32), 'loss_cross_entropy': Array(1.6039009, dtype=float32)}


  0%|          | 399/1000000 [01:44<38:14:38,  7.26it/s]

{'loss': Array(1.6163836, dtype=float32), 'loss_reward': Array(0.01334363, dtype=float32), 'loss_cross_entropy': Array(1.60304, dtype=float32)}


  0%|          | 410/1000000 [01:45<29:29:17,  9.42it/s]

{'loss': Array(1.6157662, dtype=float32), 'loss_reward': Array(0.01346018, dtype=float32), 'loss_cross_entropy': Array(1.6023062, dtype=float32)}


  0%|          | 420/1000000 [01:47<29:22:06,  9.45it/s]

{'loss': Array(1.6154, dtype=float32), 'loss_reward': Array(0.01399134, dtype=float32), 'loss_cross_entropy': Array(1.6014086, dtype=float32)}


  0%|          | 428/1000000 [01:48<34:41:03,  8.01it/s]

{'loss': Array(1.6149458, dtype=float32), 'loss_reward': Array(0.0139344, dtype=float32), 'loss_cross_entropy': Array(1.6010113, dtype=float32)}


  0%|          | 439/1000000 [01:50<34:25:21,  8.07it/s]

{'loss': Array(1.6148885, dtype=float32), 'loss_reward': Array(0.01480817, dtype=float32), 'loss_cross_entropy': Array(1.6000804, dtype=float32)}


  0%|          | 449/1000000 [01:52<30:42:12,  9.04it/s]

{'loss': Array(1.6146268, dtype=float32), 'loss_reward': Array(0.01512785, dtype=float32), 'loss_cross_entropy': Array(1.5994987, dtype=float32)}


  0%|          | 460/1000000 [01:54<28:05:09,  9.89it/s]

{'loss': Array(1.6145003, dtype=float32), 'loss_reward': Array(0.01562871, dtype=float32), 'loss_cross_entropy': Array(1.5988716, dtype=float32)}


  0%|          | 469/1000000 [01:55<42:03:50,  6.60it/s]

{'loss': Array(1.613854, dtype=float32), 'loss_reward': Array(0.01576485, dtype=float32), 'loss_cross_entropy': Array(1.5980891, dtype=float32)}


  0%|          | 480/1000000 [01:57<31:05:41,  8.93it/s]

{'loss': Array(1.6131966, dtype=float32), 'loss_reward': Array(0.01585658, dtype=float32), 'loss_cross_entropy': Array(1.5973399, dtype=float32)}


  0%|          | 488/1000000 [01:58<35:12:12,  7.89it/s]

{'loss': Array(1.6130278, dtype=float32), 'loss_reward': Array(0.01606592, dtype=float32), 'loss_cross_entropy': Array(1.5969619, dtype=float32)}


  0%|          | 499/1000000 [02:00<30:20:46,  9.15it/s]

{'loss': Array(1.6122477, dtype=float32), 'loss_reward': Array(0.01621721, dtype=float32), 'loss_cross_entropy': Array(1.5960305, dtype=float32)}


  0%|          | 508/1000000 [02:12<148:40:09,  1.87it/s]

{'loss': Array(1.611981, dtype=float32), 'loss_reward': Array(0.01645513, dtype=float32), 'loss_cross_entropy': Array(1.5955261, dtype=float32)}


  0%|          | 520/1000000 [02:14<52:26:52,  5.29it/s] 

{'loss': Array(1.6123126, dtype=float32), 'loss_reward': Array(0.0169848, dtype=float32), 'loss_cross_entropy': Array(1.5953277, dtype=float32)}


  0%|          | 528/1000000 [02:15<42:23:06,  6.55it/s]

{'loss': Array(1.6118355, dtype=float32), 'loss_reward': Array(0.01703436, dtype=float32), 'loss_cross_entropy': Array(1.594801, dtype=float32)}


  0%|          | 538/1000000 [02:17<44:03:55,  6.30it/s]

{'loss': Array(1.6112927, dtype=float32), 'loss_reward': Array(0.01666373, dtype=float32), 'loss_cross_entropy': Array(1.5946288, dtype=float32)}


  0%|          | 548/1000000 [02:19<34:35:24,  8.03it/s]

{'loss': Array(1.6110486, dtype=float32), 'loss_reward': Array(0.01665336, dtype=float32), 'loss_cross_entropy': Array(1.594395, dtype=float32)}


  0%|          | 560/1000000 [02:21<28:30:29,  9.74it/s]

{'loss': Array(1.6110274, dtype=float32), 'loss_reward': Array(0.01672965, dtype=float32), 'loss_cross_entropy': Array(1.5942976, dtype=float32)}


  0%|          | 570/1000000 [02:22<30:06:42,  9.22it/s]

{'loss': Array(1.6108963, dtype=float32), 'loss_reward': Array(0.01629159, dtype=float32), 'loss_cross_entropy': Array(1.5946046, dtype=float32)}


  0%|          | 578/1000000 [02:24<40:25:06,  6.87it/s]

{'loss': Array(1.610802, dtype=float32), 'loss_reward': Array(0.01643177, dtype=float32), 'loss_cross_entropy': Array(1.5943702, dtype=float32)}


  0%|          | 590/1000000 [02:26<30:03:00,  9.24it/s]

{'loss': Array(1.6112531, dtype=float32), 'loss_reward': Array(0.01690731, dtype=float32), 'loss_cross_entropy': Array(1.5943459, dtype=float32)}


  0%|          | 598/1000000 [02:27<35:23:31,  7.84it/s]

{'loss': Array(1.611309, dtype=float32), 'loss_reward': Array(0.01712762, dtype=float32), 'loss_cross_entropy': Array(1.5941814, dtype=float32)}


  0%|          | 610/1000000 [02:29<33:59:11,  8.17it/s]

{'loss': Array(1.611001, dtype=float32), 'loss_reward': Array(0.01684008, dtype=float32), 'loss_cross_entropy': Array(1.5941612, dtype=float32)}


  0%|          | 619/1000000 [02:31<33:12:05,  8.36it/s]

{'loss': Array(1.6110668, dtype=float32), 'loss_reward': Array(0.01722815, dtype=float32), 'loss_cross_entropy': Array(1.5938386, dtype=float32)}


  0%|          | 630/1000000 [02:32<29:15:29,  9.49it/s]

{'loss': Array(1.6108278, dtype=float32), 'loss_reward': Array(0.01709925, dtype=float32), 'loss_cross_entropy': Array(1.5937287, dtype=float32)}


  0%|          | 640/1000000 [02:34<40:38:49,  6.83it/s]

{'loss': Array(1.6111845, dtype=float32), 'loss_reward': Array(0.01757618, dtype=float32), 'loss_cross_entropy': Array(1.5936083, dtype=float32)}


  0%|          | 648/1000000 [02:36<38:57:24,  7.13it/s]

{'loss': Array(1.6107445, dtype=float32), 'loss_reward': Array(0.01731069, dtype=float32), 'loss_cross_entropy': Array(1.593434, dtype=float32)}


  0%|          | 660/1000000 [02:38<29:18:20,  9.47it/s]

{'loss': Array(1.6107676, dtype=float32), 'loss_reward': Array(0.01739084, dtype=float32), 'loss_cross_entropy': Array(1.5933769, dtype=float32)}


  0%|          | 670/1000000 [02:39<30:12:20,  9.19it/s]

{'loss': Array(1.610525, dtype=float32), 'loss_reward': Array(0.01713314, dtype=float32), 'loss_cross_entropy': Array(1.5933919, dtype=float32)}


  0%|          | 679/1000000 [02:41<38:34:10,  7.20it/s]

{'loss': Array(1.6104988, dtype=float32), 'loss_reward': Array(0.01720355, dtype=float32), 'loss_cross_entropy': Array(1.5932955, dtype=float32)}


  0%|          | 690/1000000 [02:43<28:02:41,  9.90it/s]

{'loss': Array(1.6106147, dtype=float32), 'loss_reward': Array(0.01743491, dtype=float32), 'loss_cross_entropy': Array(1.5931797, dtype=float32)}


  0%|          | 698/1000000 [02:44<35:07:42,  7.90it/s]

{'loss': Array(1.6110706, dtype=float32), 'loss_reward': Array(0.01780671, dtype=float32), 'loss_cross_entropy': Array(1.5932639, dtype=float32)}


  0%|          | 709/1000000 [02:46<40:04:06,  6.93it/s]

{'loss': Array(1.6108551, dtype=float32), 'loss_reward': Array(0.01762908, dtype=float32), 'loss_cross_entropy': Array(1.5932261, dtype=float32)}


  0%|          | 720/1000000 [02:48<30:47:33,  9.01it/s]

{'loss': Array(1.6103457, dtype=float32), 'loss_reward': Array(0.01722983, dtype=float32), 'loss_cross_entropy': Array(1.5931157, dtype=float32)}


  0%|          | 730/1000000 [02:50<30:00:00,  9.25it/s]

{'loss': Array(1.6107686, dtype=float32), 'loss_reward': Array(0.01769795, dtype=float32), 'loss_cross_entropy': Array(1.5930704, dtype=float32)}


  0%|          | 738/1000000 [02:51<36:21:06,  7.64it/s]

{'loss': Array(1.6112928, dtype=float32), 'loss_reward': Array(0.01834279, dtype=float32), 'loss_cross_entropy': Array(1.5929502, dtype=float32)}


  0%|          | 749/1000000 [02:53<33:57:36,  8.17it/s]

{'loss': Array(1.6108744, dtype=float32), 'loss_reward': Array(0.01788417, dtype=float32), 'loss_cross_entropy': Array(1.5929903, dtype=float32)}


  0%|          | 759/1000000 [02:55<30:53:18,  8.99it/s]

{'loss': Array(1.6111262, dtype=float32), 'loss_reward': Array(0.01812646, dtype=float32), 'loss_cross_entropy': Array(1.5929997, dtype=float32)}


  0%|          | 769/1000000 [02:56<30:37:05,  9.07it/s]

{'loss': Array(1.6113503, dtype=float32), 'loss_reward': Array(0.01826181, dtype=float32), 'loss_cross_entropy': Array(1.5930885, dtype=float32)}


  0%|          | 779/1000000 [02:58<41:42:30,  6.65it/s]

{'loss': Array(1.6109059, dtype=float32), 'loss_reward': Array(0.01801197, dtype=float32), 'loss_cross_entropy': Array(1.592894, dtype=float32)}


  0%|          | 790/1000000 [03:00<30:24:21,  9.13it/s]

{'loss': Array(1.6112518, dtype=float32), 'loss_reward': Array(0.01818835, dtype=float32), 'loss_cross_entropy': Array(1.5930634, dtype=float32)}


  0%|          | 798/1000000 [03:01<36:00:47,  7.71it/s]

{'loss': Array(1.6108481, dtype=float32), 'loss_reward': Array(0.0178459, dtype=float32), 'loss_cross_entropy': Array(1.5930022, dtype=float32)}


  0%|          | 810/1000000 [03:03<29:26:16,  9.43it/s]

{'loss': Array(1.611067, dtype=float32), 'loss_reward': Array(0.01821561, dtype=float32), 'loss_cross_entropy': Array(1.5928515, dtype=float32)}


  0%|          | 820/1000000 [03:05<34:30:00,  8.04it/s]

{'loss': Array(1.6113815, dtype=float32), 'loss_reward': Array(0.01848949, dtype=float32), 'loss_cross_entropy': Array(1.5928919, dtype=float32)}


  0%|          | 830/1000000 [03:07<31:56:22,  8.69it/s]

{'loss': Array(1.6112715, dtype=float32), 'loss_reward': Array(0.01828073, dtype=float32), 'loss_cross_entropy': Array(1.5929908, dtype=float32)}


  0%|          | 840/1000000 [03:08<31:37:05,  8.78it/s]

{'loss': Array(1.6110696, dtype=float32), 'loss_reward': Array(0.01811422, dtype=float32), 'loss_cross_entropy': Array(1.5929554, dtype=float32)}


  0%|          | 848/1000000 [03:10<45:56:59,  6.04it/s]

{'loss': Array(1.611585, dtype=float32), 'loss_reward': Array(0.01866044, dtype=float32), 'loss_cross_entropy': Array(1.5929247, dtype=float32)}


  0%|          | 858/1000000 [03:12<34:43:05,  7.99it/s]

{'loss': Array(1.6114386, dtype=float32), 'loss_reward': Array(0.0186101, dtype=float32), 'loss_cross_entropy': Array(1.5928286, dtype=float32)}


  0%|          | 870/1000000 [03:14<29:06:36,  9.53it/s]

{'loss': Array(1.6113338, dtype=float32), 'loss_reward': Array(0.01848349, dtype=float32), 'loss_cross_entropy': Array(1.5928503, dtype=float32)}


  0%|          | 878/1000000 [03:15<34:59:44,  7.93it/s]

{'loss': Array(1.6120872, dtype=float32), 'loss_reward': Array(0.01923378, dtype=float32), 'loss_cross_entropy': Array(1.5928534, dtype=float32)}


  0%|          | 889/1000000 [03:17<34:44:45,  7.99it/s]

{'loss': Array(1.611419, dtype=float32), 'loss_reward': Array(0.01850738, dtype=float32), 'loss_cross_entropy': Array(1.5929116, dtype=float32)}


  0%|          | 900/1000000 [03:19<29:37:59,  9.37it/s]

{'loss': Array(1.6118151, dtype=float32), 'loss_reward': Array(0.01900168, dtype=float32), 'loss_cross_entropy': Array(1.5928134, dtype=float32)}


  0%|          | 908/1000000 [03:20<35:11:08,  7.89it/s]

{'loss': Array(1.6116192, dtype=float32), 'loss_reward': Array(0.01874833, dtype=float32), 'loss_cross_entropy': Array(1.592871, dtype=float32)}


  0%|          | 919/1000000 [03:22<36:16:40,  7.65it/s]

{'loss': Array(1.611337, dtype=float32), 'loss_reward': Array(0.01857535, dtype=float32), 'loss_cross_entropy': Array(1.5927614, dtype=float32)}


  0%|          | 930/1000000 [03:24<29:58:30,  9.26it/s]

{'loss': Array(1.6117525, dtype=float32), 'loss_reward': Array(0.01902646, dtype=float32), 'loss_cross_entropy': Array(1.592726, dtype=float32)}


  0%|          | 940/1000000 [03:26<29:15:14,  9.49it/s]

{'loss': Array(1.611555, dtype=float32), 'loss_reward': Array(0.01876602, dtype=float32), 'loss_cross_entropy': Array(1.5927888, dtype=float32)}


  0%|          | 950/1000000 [03:27<40:36:45,  6.83it/s]

{'loss': Array(1.6114067, dtype=float32), 'loss_reward': Array(0.01865127, dtype=float32), 'loss_cross_entropy': Array(1.5927553, dtype=float32)}


  0%|          | 958/1000000 [03:29<38:59:19,  7.12it/s]

{'loss': Array(1.6112509, dtype=float32), 'loss_reward': Array(0.01850897, dtype=float32), 'loss_cross_entropy': Array(1.592742, dtype=float32)}


  0%|          | 970/1000000 [03:31<29:32:53,  9.39it/s]

{'loss': Array(1.6122532, dtype=float32), 'loss_reward': Array(0.01950678, dtype=float32), 'loss_cross_entropy': Array(1.5927465, dtype=float32)}


  0%|          | 978/1000000 [03:32<35:16:53,  7.87it/s]

{'loss': Array(1.6125351, dtype=float32), 'loss_reward': Array(0.01977454, dtype=float32), 'loss_cross_entropy': Array(1.5927604, dtype=float32)}


  0%|          | 989/1000000 [03:34<37:04:35,  7.48it/s]

{'loss': Array(1.6120297, dtype=float32), 'loss_reward': Array(0.01930654, dtype=float32), 'loss_cross_entropy': Array(1.5927233, dtype=float32)}


  0%|          | 999/1000000 [03:36<32:02:59,  8.66it/s]

{'loss': Array(1.6124867, dtype=float32), 'loss_reward': Array(0.01975856, dtype=float32), 'loss_cross_entropy': Array(1.5927281, dtype=float32)}


  0%|          | 1009/1000000 [03:47<139:08:27,  1.99it/s]

{'loss': Array(1.612893, dtype=float32), 'loss_reward': Array(0.02018882, dtype=float32), 'loss_cross_entropy': Array(1.5927042, dtype=float32)}


  0%|          | 1020/1000000 [03:49<61:54:24,  4.48it/s] 

{'loss': Array(1.6124382, dtype=float32), 'loss_reward': Array(0.0196757, dtype=float32), 'loss_cross_entropy': Array(1.5927625, dtype=float32)}


  0%|          | 1028/1000000 [03:51<47:22:16,  5.86it/s]

{'loss': Array(1.6126645, dtype=float32), 'loss_reward': Array(0.0199453, dtype=float32), 'loss_cross_entropy': Array(1.5927192, dtype=float32)}


  0%|          | 1039/1000000 [03:52<34:01:20,  8.16it/s]

{'loss': Array(1.6122698, dtype=float32), 'loss_reward': Array(0.01963391, dtype=float32), 'loss_cross_entropy': Array(1.5926358, dtype=float32)}


  0%|          | 1050/1000000 [03:54<29:15:02,  9.49it/s]

{'loss': Array(1.6122326, dtype=float32), 'loss_reward': Array(0.01949915, dtype=float32), 'loss_cross_entropy': Array(1.5927335, dtype=float32)}


  0%|          | 1058/1000000 [03:56<41:03:35,  6.76it/s]

{'loss': Array(1.6120423, dtype=float32), 'loss_reward': Array(0.01949306, dtype=float32), 'loss_cross_entropy': Array(1.5925492, dtype=float32)}


  0%|          | 1070/1000000 [03:58<30:24:47,  9.12it/s]

{'loss': Array(1.6125021, dtype=float32), 'loss_reward': Array(0.01982556, dtype=float32), 'loss_cross_entropy': Array(1.5926766, dtype=float32)}


  0%|          | 1078/1000000 [03:59<34:52:06,  7.96it/s]

{'loss': Array(1.612271, dtype=float32), 'loss_reward': Array(0.01958064, dtype=float32), 'loss_cross_entropy': Array(1.5926902, dtype=float32)}


  0%|          | 1089/1000000 [04:01<28:48:30,  9.63it/s]

{'loss': Array(1.6122936, dtype=float32), 'loss_reward': Array(0.01966412, dtype=float32), 'loss_cross_entropy': Array(1.5926294, dtype=float32)}


  0%|          | 1100/1000000 [04:03<31:12:04,  8.89it/s]

{'loss': Array(1.6117823, dtype=float32), 'loss_reward': Array(0.0191808, dtype=float32), 'loss_cross_entropy': Array(1.5926015, dtype=float32)}


  0%|          | 1108/1000000 [04:04<35:52:14,  7.74it/s]

{'loss': Array(1.6127074, dtype=float32), 'loss_reward': Array(0.02006105, dtype=float32), 'loss_cross_entropy': Array(1.5926464, dtype=float32)}


  0%|          | 1119/1000000 [04:06<31:04:15,  8.93it/s]

{'loss': Array(1.6119441, dtype=float32), 'loss_reward': Array(0.01939991, dtype=float32), 'loss_cross_entropy': Array(1.5925441, dtype=float32)}


  0%|          | 1130/1000000 [04:08<32:01:47,  8.66it/s]

{'loss': Array(1.612748, dtype=float32), 'loss_reward': Array(0.0202163, dtype=float32), 'loss_cross_entropy': Array(1.5925317, dtype=float32)}


  0%|          | 1139/1000000 [04:10<32:05:22,  8.65it/s]

{'loss': Array(1.6127646, dtype=float32), 'loss_reward': Array(0.02018717, dtype=float32), 'loss_cross_entropy': Array(1.5925776, dtype=float32)}


  0%|          | 1150/1000000 [04:11<28:45:18,  9.65it/s]

{'loss': Array(1.6127757, dtype=float32), 'loss_reward': Array(0.02023877, dtype=float32), 'loss_cross_entropy': Array(1.5925369, dtype=float32)}


  0%|          | 1160/1000000 [04:13<40:27:40,  6.86it/s]

{'loss': Array(1.6123904, dtype=float32), 'loss_reward': Array(0.0198669, dtype=float32), 'loss_cross_entropy': Array(1.5925235, dtype=float32)}


  0%|          | 1168/1000000 [04:15<38:57:16,  7.12it/s]

{'loss': Array(1.6119261, dtype=float32), 'loss_reward': Array(0.01948575, dtype=float32), 'loss_cross_entropy': Array(1.5924404, dtype=float32)}


  0%|          | 1180/1000000 [04:16<29:26:34,  9.42it/s]

{'loss': Array(1.6121699, dtype=float32), 'loss_reward': Array(0.01971254, dtype=float32), 'loss_cross_entropy': Array(1.5924574, dtype=float32)}


  0%|          | 1189/1000000 [04:18<31:36:57,  8.78it/s]

{'loss': Array(1.6119255, dtype=float32), 'loss_reward': Array(0.01956483, dtype=float32), 'loss_cross_entropy': Array(1.5923607, dtype=float32)}


  0%|          | 1198/1000000 [04:20<39:27:04,  7.03it/s]

{'loss': Array(1.6121047, dtype=float32), 'loss_reward': Array(0.01968398, dtype=float32), 'loss_cross_entropy': Array(1.5924208, dtype=float32)}


  0%|          | 1210/1000000 [04:21<28:58:09,  9.58it/s]

{'loss': Array(1.6121038, dtype=float32), 'loss_reward': Array(0.01962668, dtype=float32), 'loss_cross_entropy': Array(1.5924771, dtype=float32)}


  0%|          | 1218/1000000 [04:23<35:05:52,  7.90it/s]

{'loss': Array(1.6124598, dtype=float32), 'loss_reward': Array(0.02007307, dtype=float32), 'loss_cross_entropy': Array(1.5923867, dtype=float32)}


  0%|          | 1230/1000000 [04:25<35:17:22,  7.86it/s]

{'loss': Array(1.6124678, dtype=float32), 'loss_reward': Array(0.02008531, dtype=float32), 'loss_cross_entropy': Array(1.5923823, dtype=float32)}


  0%|          | 1238/1000000 [04:26<37:41:30,  7.36it/s]

{'loss': Array(1.6125431, dtype=float32), 'loss_reward': Array(0.020216, dtype=float32), 'loss_cross_entropy': Array(1.5923272, dtype=float32)}


  0%|          | 1250/1000000 [04:28<28:53:46,  9.60it/s]

{'loss': Array(1.6126823, dtype=float32), 'loss_reward': Array(0.02035136, dtype=float32), 'loss_cross_entropy': Array(1.5923312, dtype=float32)}


  0%|          | 1258/1000000 [04:30<36:03:40,  7.69it/s]

{'loss': Array(1.6122916, dtype=float32), 'loss_reward': Array(0.01988904, dtype=float32), 'loss_cross_entropy': Array(1.5924025, dtype=float32)}


  0%|          | 1268/1000000 [04:32<37:06:41,  7.48it/s]

{'loss': Array(1.6122792, dtype=float32), 'loss_reward': Array(0.01994496, dtype=float32), 'loss_cross_entropy': Array(1.5923342, dtype=float32)}


  0%|          | 1280/1000000 [04:33<29:07:17,  9.53it/s]

{'loss': Array(1.6121371, dtype=float32), 'loss_reward': Array(0.01987123, dtype=float32), 'loss_cross_entropy': Array(1.592266, dtype=float32)}


  0%|          | 1288/1000000 [04:35<35:26:44,  7.83it/s]

{'loss': Array(1.6121839, dtype=float32), 'loss_reward': Array(0.01987612, dtype=float32), 'loss_cross_entropy': Array(1.5923079, dtype=float32)}


  0%|          | 1299/1000000 [04:37<36:43:36,  7.55it/s]

{'loss': Array(1.6118828, dtype=float32), 'loss_reward': Array(0.01967403, dtype=float32), 'loss_cross_entropy': Array(1.592209, dtype=float32)}


  0%|          | 1310/1000000 [04:39<29:13:45,  9.49it/s]

{'loss': Array(1.6122884, dtype=float32), 'loss_reward': Array(0.02001898, dtype=float32), 'loss_cross_entropy': Array(1.5922693, dtype=float32)}


  0%|          | 1320/1000000 [04:40<30:14:15,  9.17it/s]

{'loss': Array(1.6119698, dtype=float32), 'loss_reward': Array(0.01980982, dtype=float32), 'loss_cross_entropy': Array(1.5921601, dtype=float32)}


  0%|          | 1330/1000000 [04:42<30:09:57,  9.20it/s]

{'loss': Array(1.6123238, dtype=float32), 'loss_reward': Array(0.02016601, dtype=float32), 'loss_cross_entropy': Array(1.5921577, dtype=float32)}


  0%|          | 1338/1000000 [04:44<40:10:29,  6.90it/s]

{'loss': Array(1.6117867, dtype=float32), 'loss_reward': Array(0.01963142, dtype=float32), 'loss_cross_entropy': Array(1.5921552, dtype=float32)}


  0%|          | 1349/1000000 [04:45<30:53:54,  8.98it/s]

{'loss': Array(1.6116478, dtype=float32), 'loss_reward': Array(0.01955317, dtype=float32), 'loss_cross_entropy': Array(1.5920948, dtype=float32)}


  0%|          | 1360/1000000 [04:47<28:06:31,  9.87it/s]

{'loss': Array(1.6120433, dtype=float32), 'loss_reward': Array(0.02006361, dtype=float32), 'loss_cross_entropy': Array(1.5919797, dtype=float32)}


  0%|          | 1370/1000000 [04:49<35:57:57,  7.71it/s]

{'loss': Array(1.6118755, dtype=float32), 'loss_reward': Array(0.01987417, dtype=float32), 'loss_cross_entropy': Array(1.5920012, dtype=float32)}


  0%|          | 1378/1000000 [04:50<36:48:33,  7.54it/s]

{'loss': Array(1.6120577, dtype=float32), 'loss_reward': Array(0.02005417, dtype=float32), 'loss_cross_entropy': Array(1.5920037, dtype=float32)}


  0%|          | 1390/1000000 [04:52<29:10:45,  9.51it/s]

{'loss': Array(1.6116695, dtype=float32), 'loss_reward': Array(0.01971119, dtype=float32), 'loss_cross_entropy': Array(1.5919586, dtype=float32)}


  0%|          | 1398/1000000 [04:53<35:00:17,  7.92it/s]

{'loss': Array(1.6109403, dtype=float32), 'loss_reward': Array(0.01912481, dtype=float32), 'loss_cross_entropy': Array(1.5918155, dtype=float32)}


  0%|          | 1410/1000000 [04:55<32:16:10,  8.60it/s]

{'loss': Array(1.6109848, dtype=float32), 'loss_reward': Array(0.01916529, dtype=float32), 'loss_cross_entropy': Array(1.5918195, dtype=float32)}


  0%|          | 1418/1000000 [04:57<36:05:13,  7.69it/s]

{'loss': Array(1.6112083, dtype=float32), 'loss_reward': Array(0.01951671, dtype=float32), 'loss_cross_entropy': Array(1.5916916, dtype=float32)}


  0%|          | 1428/1000000 [04:59<32:56:52,  8.42it/s]

{'loss': Array(1.6104561, dtype=float32), 'loss_reward': Array(0.01884924, dtype=float32), 'loss_cross_entropy': Array(1.591607, dtype=float32)}


  0%|          | 1439/1000000 [05:01<36:20:45,  7.63it/s]

{'loss': Array(1.611372, dtype=float32), 'loss_reward': Array(0.01974459, dtype=float32), 'loss_cross_entropy': Array(1.5916274, dtype=float32)}


  0%|          | 1450/1000000 [05:02<29:30:37,  9.40it/s]

{'loss': Array(1.6111187, dtype=float32), 'loss_reward': Array(0.0196018, dtype=float32), 'loss_cross_entropy': Array(1.5915169, dtype=float32)}


  0%|          | 1458/1000000 [05:04<35:43:41,  7.76it/s]

{'loss': Array(1.6101656, dtype=float32), 'loss_reward': Array(0.01869509, dtype=float32), 'loss_cross_entropy': Array(1.5914707, dtype=float32)}


  0%|          | 1469/1000000 [05:06<31:03:44,  8.93it/s]

{'loss': Array(1.611574, dtype=float32), 'loss_reward': Array(0.02022919, dtype=float32), 'loss_cross_entropy': Array(1.5913447, dtype=float32)}


  0%|          | 1480/1000000 [05:08<32:08:41,  8.63it/s]

{'loss': Array(1.6107801, dtype=float32), 'loss_reward': Array(0.0194093, dtype=float32), 'loss_cross_entropy': Array(1.5913709, dtype=float32)}


  0%|          | 1490/1000000 [05:09<31:00:34,  8.94it/s]

{'loss': Array(1.6096439, dtype=float32), 'loss_reward': Array(0.01854631, dtype=float32), 'loss_cross_entropy': Array(1.5910977, dtype=float32)}


  0%|          | 1498/1000000 [05:11<36:45:47,  7.54it/s]

{'loss': Array(1.6104449, dtype=float32), 'loss_reward': Array(0.01917441, dtype=float32), 'loss_cross_entropy': Array(1.5912706, dtype=float32)}


  0%|          | 1508/1000000 [05:23<160:39:00,  1.73it/s]

{'loss': Array(1.6104332, dtype=float32), 'loss_reward': Array(0.0194404, dtype=float32), 'loss_cross_entropy': Array(1.5909928, dtype=float32)}


  0%|          | 1520/1000000 [05:25<53:59:27,  5.14it/s] 

{'loss': Array(1.6108351, dtype=float32), 'loss_reward': Array(0.01981491, dtype=float32), 'loss_cross_entropy': Array(1.5910203, dtype=float32)}


  0%|          | 1528/1000000 [05:26<44:25:15,  6.24it/s]

{'loss': Array(1.6098124, dtype=float32), 'loss_reward': Array(0.01907094, dtype=float32), 'loss_cross_entropy': Array(1.5907413, dtype=float32)}


  0%|          | 1539/1000000 [05:28<42:38:23,  6.50it/s]

{'loss': Array(1.6107279, dtype=float32), 'loss_reward': Array(0.02000081, dtype=float32), 'loss_cross_entropy': Array(1.5907273, dtype=float32)}


  0%|          | 1550/1000000 [05:30<31:10:08,  8.90it/s]

{'loss': Array(1.609827, dtype=float32), 'loss_reward': Array(0.01901226, dtype=float32), 'loss_cross_entropy': Array(1.590815, dtype=float32)}


  0%|          | 1560/1000000 [05:31<30:18:45,  9.15it/s]

{'loss': Array(1.6099075, dtype=float32), 'loss_reward': Array(0.01947806, dtype=float32), 'loss_cross_entropy': Array(1.5904294, dtype=float32)}


  0%|          | 1568/1000000 [05:33<35:51:48,  7.73it/s]

{'loss': Array(1.6089735, dtype=float32), 'loss_reward': Array(0.01884677, dtype=float32), 'loss_cross_entropy': Array(1.5901266, dtype=float32)}


  0%|          | 1579/1000000 [05:35<34:28:21,  8.05it/s]

{'loss': Array(1.6096258, dtype=float32), 'loss_reward': Array(0.01954466, dtype=float32), 'loss_cross_entropy': Array(1.5900815, dtype=float32)}


  0%|          | 1589/1000000 [05:36<30:28:21,  9.10it/s]

{'loss': Array(1.608926, dtype=float32), 'loss_reward': Array(0.01909623, dtype=float32), 'loss_cross_entropy': Array(1.5898296, dtype=float32)}


  0%|          | 1599/1000000 [05:38<29:31:50,  9.39it/s]

{'loss': Array(1.609804, dtype=float32), 'loss_reward': Array(0.01995443, dtype=float32), 'loss_cross_entropy': Array(1.5898496, dtype=float32)}


  0%|          | 1609/1000000 [05:40<41:15:31,  6.72it/s]

{'loss': Array(1.6088766, dtype=float32), 'loss_reward': Array(0.01924426, dtype=float32), 'loss_cross_entropy': Array(1.5896324, dtype=float32)}


  0%|          | 1620/1000000 [05:41<30:46:54,  9.01it/s]

{'loss': Array(1.6086468, dtype=float32), 'loss_reward': Array(0.01897605, dtype=float32), 'loss_cross_entropy': Array(1.5896705, dtype=float32)}


  0%|          | 1628/1000000 [05:43<35:57:16,  7.71it/s]

{'loss': Array(1.6089535, dtype=float32), 'loss_reward': Array(0.01952014, dtype=float32), 'loss_cross_entropy': Array(1.5894333, dtype=float32)}


  0%|          | 1640/1000000 [05:45<28:41:49,  9.66it/s]

{'loss': Array(1.6078705, dtype=float32), 'loss_reward': Array(0.01896825, dtype=float32), 'loss_cross_entropy': Array(1.5889025, dtype=float32)}


  0%|          | 1649/1000000 [05:47<35:11:02,  7.88it/s]

{'loss': Array(1.6080383, dtype=float32), 'loss_reward': Array(0.01917191, dtype=float32), 'loss_cross_entropy': Array(1.5888665, dtype=float32)}


  0%|          | 1660/1000000 [05:48<27:56:22,  9.93it/s]

{'loss': Array(1.6086243, dtype=float32), 'loss_reward': Array(0.01981561, dtype=float32), 'loss_cross_entropy': Array(1.5888088, dtype=float32)}


  0%|          | 1668/1000000 [05:50<36:24:29,  7.62it/s]

{'loss': Array(1.6071271, dtype=float32), 'loss_reward': Array(0.01882015, dtype=float32), 'loss_cross_entropy': Array(1.5883068, dtype=float32)}


  0%|          | 1679/1000000 [05:52<40:53:14,  6.78it/s]

{'loss': Array(1.6077347, dtype=float32), 'loss_reward': Array(0.01935948, dtype=float32), 'loss_cross_entropy': Array(1.5883751, dtype=float32)}


  0%|          | 1690/1000000 [05:53<30:52:44,  8.98it/s]

{'loss': Array(1.6067923, dtype=float32), 'loss_reward': Array(0.01884268, dtype=float32), 'loss_cross_entropy': Array(1.5879495, dtype=float32)}


  0%|          | 1700/1000000 [05:55<29:43:43,  9.33it/s]

{'loss': Array(1.6069666, dtype=float32), 'loss_reward': Array(0.01927705, dtype=float32), 'loss_cross_entropy': Array(1.5876896, dtype=float32)}


  0%|          | 1710/1000000 [05:57<30:31:06,  9.09it/s]

{'loss': Array(1.6055739, dtype=float32), 'loss_reward': Array(0.01838092, dtype=float32), 'loss_cross_entropy': Array(1.587193, dtype=float32)}


  0%|          | 1718/1000000 [05:58<40:28:26,  6.85it/s]

{'loss': Array(1.6062968, dtype=float32), 'loss_reward': Array(0.0190402, dtype=float32), 'loss_cross_entropy': Array(1.5872567, dtype=float32)}


  0%|          | 1729/1000000 [06:00<31:41:47,  8.75it/s]

{'loss': Array(1.6055815, dtype=float32), 'loss_reward': Array(0.01854132, dtype=float32), 'loss_cross_entropy': Array(1.5870403, dtype=float32)}


  0%|          | 1740/1000000 [06:02<28:37:21,  9.69it/s]

{'loss': Array(1.6058266, dtype=float32), 'loss_reward': Array(0.01905876, dtype=float32), 'loss_cross_entropy': Array(1.5867678, dtype=float32)}


  0%|          | 1750/1000000 [06:04<35:48:59,  7.74it/s]

{'loss': Array(1.6050272, dtype=float32), 'loss_reward': Array(0.01870831, dtype=float32), 'loss_cross_entropy': Array(1.5863189, dtype=float32)}


  0%|          | 1760/1000000 [06:05<31:08:22,  8.90it/s]

{'loss': Array(1.6041237, dtype=float32), 'loss_reward': Array(0.01807913, dtype=float32), 'loss_cross_entropy': Array(1.5860447, dtype=float32)}


  0%|          | 1770/1000000 [06:07<30:29:05,  9.10it/s]

{'loss': Array(1.6047528, dtype=float32), 'loss_reward': Array(0.01905387, dtype=float32), 'loss_cross_entropy': Array(1.585699, dtype=float32)}


  0%|          | 1780/1000000 [06:09<40:52:32,  6.78it/s]

{'loss': Array(1.6044064, dtype=float32), 'loss_reward': Array(0.01888247, dtype=float32), 'loss_cross_entropy': Array(1.585524, dtype=float32)}


  0%|          | 1788/1000000 [06:10<39:00:48,  7.11it/s]

{'loss': Array(1.6036175, dtype=float32), 'loss_reward': Array(0.01856569, dtype=float32), 'loss_cross_entropy': Array(1.5850519, dtype=float32)}


  0%|          | 1800/1000000 [06:12<29:36:48,  9.36it/s]

{'loss': Array(1.6034869, dtype=float32), 'loss_reward': Array(0.01887347, dtype=float32), 'loss_cross_entropy': Array(1.5846133, dtype=float32)}


  0%|          | 1810/1000000 [06:13<30:23:01,  9.13it/s]

{'loss': Array(1.6028794, dtype=float32), 'loss_reward': Array(0.0184674, dtype=float32), 'loss_cross_entropy': Array(1.584412, dtype=float32)}


  0%|          | 1818/1000000 [06:15<42:21:32,  6.55it/s]

{'loss': Array(1.6035706, dtype=float32), 'loss_reward': Array(0.01988431, dtype=float32), 'loss_cross_entropy': Array(1.5836862, dtype=float32)}


  0%|          | 1829/1000000 [06:17<31:58:31,  8.67it/s]

{'loss': Array(1.6019115, dtype=float32), 'loss_reward': Array(0.01879171, dtype=float32), 'loss_cross_entropy': Array(1.5831199, dtype=float32)}


  0%|          | 1839/1000000 [06:19<31:41:30,  8.75it/s]

{'loss': Array(1.6020491, dtype=float32), 'loss_reward': Array(0.01906297, dtype=float32), 'loss_cross_entropy': Array(1.5829862, dtype=float32)}


  0%|          | 1850/1000000 [06:20<38:20:45,  7.23it/s]

{'loss': Array(1.6007063, dtype=float32), 'loss_reward': Array(0.01851553, dtype=float32), 'loss_cross_entropy': Array(1.5821905, dtype=float32)}


  0%|          | 1860/1000000 [06:22<32:10:07,  8.62it/s]

{'loss': Array(1.6009697, dtype=float32), 'loss_reward': Array(0.01848039, dtype=float32), 'loss_cross_entropy': Array(1.5824894, dtype=float32)}


  0%|          | 1868/1000000 [06:24<36:24:37,  7.61it/s]

{'loss': Array(1.6003754, dtype=float32), 'loss_reward': Array(0.01849909, dtype=float32), 'loss_cross_entropy': Array(1.5818762, dtype=float32)}


  0%|          | 1880/1000000 [06:25<29:08:19,  9.52it/s]

{'loss': Array(1.5999671, dtype=float32), 'loss_reward': Array(0.01860035, dtype=float32), 'loss_cross_entropy': Array(1.5813667, dtype=float32)}


  0%|          | 1889/1000000 [06:27<37:25:29,  7.41it/s]

{'loss': Array(1.5983251, dtype=float32), 'loss_reward': Array(0.01808234, dtype=float32), 'loss_cross_entropy': Array(1.580243, dtype=float32)}


  0%|          | 1899/1000000 [06:29<32:12:46,  8.61it/s]

{'loss': Array(1.5991777, dtype=float32), 'loss_reward': Array(0.01876087, dtype=float32), 'loss_cross_entropy': Array(1.5804169, dtype=float32)}


  0%|          | 1909/1000000 [06:31<30:21:15,  9.13it/s]

{'loss': Array(1.5985575, dtype=float32), 'loss_reward': Array(0.01873551, dtype=float32), 'loss_cross_entropy': Array(1.5798222, dtype=float32)}


  0%|          | 1919/1000000 [06:33<41:34:57,  6.67it/s]

{'loss': Array(1.5976932, dtype=float32), 'loss_reward': Array(0.01899756, dtype=float32), 'loss_cross_entropy': Array(1.5786957, dtype=float32)}


  0%|          | 1930/1000000 [06:34<30:36:09,  9.06it/s]

{'loss': Array(1.5973418, dtype=float32), 'loss_reward': Array(0.01828931, dtype=float32), 'loss_cross_entropy': Array(1.5790523, dtype=float32)}


  0%|          | 1940/1000000 [06:36<30:43:15,  9.02it/s]

{'loss': Array(1.5973269, dtype=float32), 'loss_reward': Array(0.01886993, dtype=float32), 'loss_cross_entropy': Array(1.5784568, dtype=float32)}


  0%|          | 1948/1000000 [06:37<36:45:46,  7.54it/s]

{'loss': Array(1.5965984, dtype=float32), 'loss_reward': Array(0.01864321, dtype=float32), 'loss_cross_entropy': Array(1.5779552, dtype=float32)}


  0%|          | 1958/1000000 [06:39<39:33:09,  7.01it/s]

{'loss': Array(1.5941837, dtype=float32), 'loss_reward': Array(0.01774093, dtype=float32), 'loss_cross_entropy': Array(1.5764426, dtype=float32)}


  0%|          | 1969/1000000 [06:41<31:29:28,  8.80it/s]

{'loss': Array(1.5940523, dtype=float32), 'loss_reward': Array(0.01864026, dtype=float32), 'loss_cross_entropy': Array(1.5754122, dtype=float32)}


  0%|          | 1979/1000000 [06:43<31:01:47,  8.93it/s]

{'loss': Array(1.5930102, dtype=float32), 'loss_reward': Array(0.01798056, dtype=float32), 'loss_cross_entropy': Array(1.5750296, dtype=float32)}


  0%|          | 1990/1000000 [06:45<35:53:41,  7.72it/s]

{'loss': Array(1.5934535, dtype=float32), 'loss_reward': Array(0.01782688, dtype=float32), 'loss_cross_entropy': Array(1.5756266, dtype=float32)}


  0%|          | 1998/1000000 [06:46<37:39:06,  7.36it/s]

{'loss': Array(1.5931762, dtype=float32), 'loss_reward': Array(0.01829362, dtype=float32), 'loss_cross_entropy': Array(1.5748826, dtype=float32)}


  0%|          | 2010/1000000 [06:58<129:29:57,  2.14it/s]

{'loss': Array(1.5918378, dtype=float32), 'loss_reward': Array(0.01827465, dtype=float32), 'loss_cross_entropy': Array(1.5735632, dtype=float32)}


  0%|          | 2020/1000000 [07:00<56:55:30,  4.87it/s] 

{'loss': Array(1.5911144, dtype=float32), 'loss_reward': Array(0.01800173, dtype=float32), 'loss_cross_entropy': Array(1.5731128, dtype=float32)}


  0%|          | 2029/1000000 [07:02<43:18:20,  6.40it/s]

{'loss': Array(1.5898024, dtype=float32), 'loss_reward': Array(0.0177531, dtype=float32), 'loss_cross_entropy': Array(1.5720491, dtype=float32)}


  0%|          | 2040/1000000 [07:03<31:42:02,  8.74it/s]

{'loss': Array(1.5886604, dtype=float32), 'loss_reward': Array(0.0180401, dtype=float32), 'loss_cross_entropy': Array(1.5706204, dtype=float32)}


  0%|          | 2048/1000000 [07:05<35:50:01,  7.74it/s]

{'loss': Array(1.5882168, dtype=float32), 'loss_reward': Array(0.01778066, dtype=float32), 'loss_cross_entropy': Array(1.5704361, dtype=float32)}


  0%|          | 2059/1000000 [07:07<36:55:36,  7.51it/s]

{'loss': Array(1.5893315, dtype=float32), 'loss_reward': Array(0.01816788, dtype=float32), 'loss_cross_entropy': Array(1.5711635, dtype=float32)}


  0%|          | 2070/1000000 [07:09<30:26:04,  9.11it/s]

{'loss': Array(1.5867928, dtype=float32), 'loss_reward': Array(0.01802971, dtype=float32), 'loss_cross_entropy': Array(1.5687631, dtype=float32)}


  0%|          | 2080/1000000 [07:10<30:08:53,  9.19it/s]

{'loss': Array(1.5855536, dtype=float32), 'loss_reward': Array(0.01811251, dtype=float32), 'loss_cross_entropy': Array(1.5674411, dtype=float32)}


  0%|          | 2090/1000000 [07:12<30:53:23,  8.97it/s]

{'loss': Array(1.5850195, dtype=float32), 'loss_reward': Array(0.01840317, dtype=float32), 'loss_cross_entropy': Array(1.5666164, dtype=float32)}


  0%|          | 2098/1000000 [07:14<40:28:42,  6.85it/s]

{'loss': Array(1.5853944, dtype=float32), 'loss_reward': Array(0.01833688, dtype=float32), 'loss_cross_entropy': Array(1.5670575, dtype=float32)}


  0%|          | 2109/1000000 [07:15<32:09:19,  8.62it/s]

{'loss': Array(1.5838487, dtype=float32), 'loss_reward': Array(0.01800904, dtype=float32), 'loss_cross_entropy': Array(1.5658396, dtype=float32)}


  0%|          | 2120/1000000 [07:17<29:02:08,  9.55it/s]

{'loss': Array(1.5824436, dtype=float32), 'loss_reward': Array(0.01814183, dtype=float32), 'loss_cross_entropy': Array(1.5643016, dtype=float32)}


  0%|          | 2130/1000000 [07:19<36:39:04,  7.56it/s]

{'loss': Array(1.5811809, dtype=float32), 'loss_reward': Array(0.01768948, dtype=float32), 'loss_cross_entropy': Array(1.5634913, dtype=float32)}


  0%|          | 2138/1000000 [07:20<37:18:45,  7.43it/s]

{'loss': Array(1.581054, dtype=float32), 'loss_reward': Array(0.01787846, dtype=float32), 'loss_cross_entropy': Array(1.5631757, dtype=float32)}


  0%|          | 2150/1000000 [07:22<30:19:29,  9.14it/s]

{'loss': Array(1.5802305, dtype=float32), 'loss_reward': Array(0.01735444, dtype=float32), 'loss_cross_entropy': Array(1.5628761, dtype=float32)}


  0%|          | 2158/1000000 [07:24<35:27:34,  7.82it/s]

{'loss': Array(1.5789468, dtype=float32), 'loss_reward': Array(0.01766478, dtype=float32), 'loss_cross_entropy': Array(1.5612819, dtype=float32)}


  0%|          | 2170/1000000 [07:26<32:09:11,  8.62it/s]

{'loss': Array(1.5776814, dtype=float32), 'loss_reward': Array(0.01778561, dtype=float32), 'loss_cross_entropy': Array(1.5598959, dtype=float32)}


  0%|          | 2178/1000000 [07:27<36:57:06,  7.50it/s]

{'loss': Array(1.5754145, dtype=float32), 'loss_reward': Array(0.01729615, dtype=float32), 'loss_cross_entropy': Array(1.5581183, dtype=float32)}


  0%|          | 2190/1000000 [07:29<29:23:11,  9.43it/s]

{'loss': Array(1.5756954, dtype=float32), 'loss_reward': Array(0.01767763, dtype=float32), 'loss_cross_entropy': Array(1.558018, dtype=float32)}


  0%|          | 2200/1000000 [07:31<37:03:06,  7.48it/s]

{'loss': Array(1.5745265, dtype=float32), 'loss_reward': Array(0.01817918, dtype=float32), 'loss_cross_entropy': Array(1.5563474, dtype=float32)}


  0%|          | 2208/1000000 [07:33<39:32:24,  7.01it/s]

{'loss': Array(1.573545, dtype=float32), 'loss_reward': Array(0.01813157, dtype=float32), 'loss_cross_entropy': Array(1.5554136, dtype=float32)}


  0%|          | 2220/1000000 [07:34<30:18:57,  9.14it/s]

{'loss': Array(1.5722315, dtype=float32), 'loss_reward': Array(0.01782799, dtype=float32), 'loss_cross_entropy': Array(1.5544035, dtype=float32)}


  0%|          | 2230/1000000 [07:36<40:44:35,  6.80it/s]

{'loss': Array(1.5714105, dtype=float32), 'loss_reward': Array(0.01733865, dtype=float32), 'loss_cross_entropy': Array(1.5540718, dtype=float32)}


  0%|          | 2240/1000000 [07:38<33:10:07,  8.36it/s]

{'loss': Array(1.5712004, dtype=float32), 'loss_reward': Array(0.01740038, dtype=float32), 'loss_cross_entropy': Array(1.5538001, dtype=float32)}


  0%|          | 2250/1000000 [07:40<32:12:22,  8.61it/s]

{'loss': Array(1.5693462, dtype=float32), 'loss_reward': Array(0.01784771, dtype=float32), 'loss_cross_entropy': Array(1.5514985, dtype=float32)}


  0%|          | 2258/1000000 [07:41<36:38:01,  7.57it/s]

{'loss': Array(1.5708894, dtype=float32), 'loss_reward': Array(0.01790408, dtype=float32), 'loss_cross_entropy': Array(1.5529853, dtype=float32)}


  0%|          | 2268/1000000 [07:43<38:50:57,  7.13it/s]

{'loss': Array(1.5691723, dtype=float32), 'loss_reward': Array(0.01796054, dtype=float32), 'loss_cross_entropy': Array(1.5512118, dtype=float32)}


  0%|          | 2280/1000000 [07:45<29:58:38,  9.25it/s]

{'loss': Array(1.5667126, dtype=float32), 'loss_reward': Array(0.01751754, dtype=float32), 'loss_cross_entropy': Array(1.549195, dtype=float32)}


  0%|          | 2288/1000000 [07:46<34:44:23,  7.98it/s]

{'loss': Array(1.5648428, dtype=float32), 'loss_reward': Array(0.01722374, dtype=float32), 'loss_cross_entropy': Array(1.547619, dtype=float32)}


  0%|          | 2299/1000000 [07:48<40:13:32,  6.89it/s]

{'loss': Array(1.5644556, dtype=float32), 'loss_reward': Array(0.01700648, dtype=float32), 'loss_cross_entropy': Array(1.5474492, dtype=float32)}


  0%|          | 2310/1000000 [07:50<30:19:43,  9.14it/s]

{'loss': Array(1.5660242, dtype=float32), 'loss_reward': Array(0.01790495, dtype=float32), 'loss_cross_entropy': Array(1.5481191, dtype=float32)}


  0%|          | 2318/1000000 [07:51<34:38:22,  8.00it/s]

{'loss': Array(1.5628603, dtype=float32), 'loss_reward': Array(0.01719429, dtype=float32), 'loss_cross_entropy': Array(1.5456659, dtype=float32)}


  0%|          | 2330/1000000 [07:53<28:50:33,  9.61it/s]

{'loss': Array(1.5623436, dtype=float32), 'loss_reward': Array(0.01746677, dtype=float32), 'loss_cross_entropy': Array(1.5448767, dtype=float32)}


  0%|          | 2340/1000000 [07:55<33:22:39,  8.30it/s]

{'loss': Array(1.5618013, dtype=float32), 'loss_reward': Array(0.01699553, dtype=float32), 'loss_cross_entropy': Array(1.5448056, dtype=float32)}


  0%|          | 2350/1000000 [07:56<31:11:53,  8.88it/s]

{'loss': Array(1.5606664, dtype=float32), 'loss_reward': Array(0.01673846, dtype=float32), 'loss_cross_entropy': Array(1.5439281, dtype=float32)}


  0%|          | 2358/1000000 [07:58<35:30:20,  7.81it/s]

{'loss': Array(1.5602621, dtype=float32), 'loss_reward': Array(0.01703041, dtype=float32), 'loss_cross_entropy': Array(1.5432318, dtype=float32)}


  0%|          | 2369/1000000 [08:00<41:41:17,  6.65it/s]

{'loss': Array(1.559163, dtype=float32), 'loss_reward': Array(0.01682181, dtype=float32), 'loss_cross_entropy': Array(1.5423412, dtype=float32)}


  0%|          | 2378/1000000 [08:01<35:23:22,  7.83it/s]

{'loss': Array(1.5587713, dtype=float32), 'loss_reward': Array(0.01680836, dtype=float32), 'loss_cross_entropy': Array(1.541963, dtype=float32)}


  0%|          | 2390/1000000 [08:03<28:52:27,  9.60it/s]

{'loss': Array(1.5579643, dtype=float32), 'loss_reward': Array(0.01724801, dtype=float32), 'loss_cross_entropy': Array(1.5407165, dtype=float32)}


  0%|          | 2400/1000000 [08:05<29:54:13,  9.27it/s]

{'loss': Array(1.5568072, dtype=float32), 'loss_reward': Array(0.01625854, dtype=float32), 'loss_cross_entropy': Array(1.5405486, dtype=float32)}


  0%|          | 2409/1000000 [08:07<36:09:06,  7.67it/s]

{'loss': Array(1.555326, dtype=float32), 'loss_reward': Array(0.01703342, dtype=float32), 'loss_cross_entropy': Array(1.5382924, dtype=float32)}


  0%|          | 2419/1000000 [08:08<31:40:31,  8.75it/s]

{'loss': Array(1.5553461, dtype=float32), 'loss_reward': Array(0.01681838, dtype=float32), 'loss_cross_entropy': Array(1.5385278, dtype=float32)}


  0%|          | 2430/1000000 [08:10<28:47:19,  9.63it/s]

{'loss': Array(1.5543107, dtype=float32), 'loss_reward': Array(0.01665605, dtype=float32), 'loss_cross_entropy': Array(1.5376546, dtype=float32)}


  0%|          | 2438/1000000 [08:12<44:15:28,  6.26it/s]

{'loss': Array(1.5543623, dtype=float32), 'loss_reward': Array(0.01648713, dtype=float32), 'loss_cross_entropy': Array(1.5378753, dtype=float32)}


  0%|          | 2450/1000000 [08:13<31:07:22,  8.90it/s]

{'loss': Array(1.5541972, dtype=float32), 'loss_reward': Array(0.0165092, dtype=float32), 'loss_cross_entropy': Array(1.537688, dtype=float32)}


  0%|          | 2458/1000000 [08:15<35:45:22,  7.75it/s]

{'loss': Array(1.5525615, dtype=float32), 'loss_reward': Array(0.01635644, dtype=float32), 'loss_cross_entropy': Array(1.536205, dtype=float32)}


  0%|          | 2470/1000000 [08:17<29:07:14,  9.52it/s]

{'loss': Array(1.5517124, dtype=float32), 'loss_reward': Array(0.01629524, dtype=float32), 'loss_cross_entropy': Array(1.5354173, dtype=float32)}


  0%|          | 2478/1000000 [08:19<40:00:42,  6.93it/s]

{'loss': Array(1.5498832, dtype=float32), 'loss_reward': Array(0.0165055, dtype=float32), 'loss_cross_entropy': Array(1.5333779, dtype=float32)}


  0%|          | 2489/1000000 [08:20<30:50:26,  8.98it/s]

{'loss': Array(1.5495323, dtype=float32), 'loss_reward': Array(0.01596658, dtype=float32), 'loss_cross_entropy': Array(1.5335655, dtype=float32)}


  0%|          | 2498/1000000 [08:22<32:15:20,  8.59it/s]

{'loss': Array(1.5510563, dtype=float32), 'loss_reward': Array(0.01589051, dtype=float32), 'loss_cross_entropy': Array(1.5351658, dtype=float32)}


  0%|          | 2509/1000000 [08:34<147:21:29,  1.88it/s]

{'loss': Array(1.5497456, dtype=float32), 'loss_reward': Array(0.01534388, dtype=float32), 'loss_cross_entropy': Array(1.5344017, dtype=float32)}


  0%|          | 2520/1000000 [08:36<54:20:46,  5.10it/s] 

{'loss': Array(1.5471947, dtype=float32), 'loss_reward': Array(0.01504616, dtype=float32), 'loss_cross_entropy': Array(1.5321485, dtype=float32)}


  0%|          | 2528/1000000 [08:37<44:33:45,  6.22it/s]

{'loss': Array(1.5479125, dtype=float32), 'loss_reward': Array(0.01548898, dtype=float32), 'loss_cross_entropy': Array(1.5324235, dtype=float32)}


  0%|          | 2540/1000000 [08:39<31:33:46,  8.78it/s]

{'loss': Array(1.5493765, dtype=float32), 'loss_reward': Array(0.01587622, dtype=float32), 'loss_cross_entropy': Array(1.5335006, dtype=float32)}


  0%|          | 2550/1000000 [08:41<33:18:59,  8.32it/s]

{'loss': Array(1.5459503, dtype=float32), 'loss_reward': Array(0.01569664, dtype=float32), 'loss_cross_entropy': Array(1.5302538, dtype=float32)}


  0%|          | 2558/1000000 [08:42<36:55:45,  7.50it/s]

{'loss': Array(1.5456909, dtype=float32), 'loss_reward': Array(0.01540244, dtype=float32), 'loss_cross_entropy': Array(1.5302886, dtype=float32)}


  0%|          | 2569/1000000 [08:44<30:46:57,  9.00it/s]

{'loss': Array(1.5447961, dtype=float32), 'loss_reward': Array(0.01483742, dtype=float32), 'loss_cross_entropy': Array(1.5299586, dtype=float32)}


  0%|          | 2579/1000000 [08:46<36:51:51,  7.52it/s]

{'loss': Array(1.5462699, dtype=float32), 'loss_reward': Array(0.01559236, dtype=float32), 'loss_cross_entropy': Array(1.5306777, dtype=float32)}


  0%|          | 2590/1000000 [08:48<30:30:20,  9.08it/s]

{'loss': Array(1.5462221, dtype=float32), 'loss_reward': Array(0.01541684, dtype=float32), 'loss_cross_entropy': Array(1.5308052, dtype=float32)}


  0%|          | 2600/1000000 [08:49<28:57:39,  9.57it/s]

{'loss': Array(1.5459694, dtype=float32), 'loss_reward': Array(0.01522626, dtype=float32), 'loss_cross_entropy': Array(1.530743, dtype=float32)}


  0%|          | 2610/1000000 [08:51<29:38:41,  9.35it/s]

{'loss': Array(1.5437484, dtype=float32), 'loss_reward': Array(0.01472654, dtype=float32), 'loss_cross_entropy': Array(1.529022, dtype=float32)}


  0%|          | 2618/1000000 [08:53<40:00:29,  6.92it/s]

{'loss': Array(1.5433339, dtype=float32), 'loss_reward': Array(0.01492354, dtype=float32), 'loss_cross_entropy': Array(1.5284103, dtype=float32)}


  0%|          | 2628/1000000 [08:54<33:58:38,  8.15it/s]

{'loss': Array(1.5436206, dtype=float32), 'loss_reward': Array(0.01524906, dtype=float32), 'loss_cross_entropy': Array(1.5283716, dtype=float32)}


  0%|          | 2639/1000000 [08:56<29:23:10,  9.43it/s]

{'loss': Array(1.5434692, dtype=float32), 'loss_reward': Array(0.01512046, dtype=float32), 'loss_cross_entropy': Array(1.5283488, dtype=float32)}


  0%|          | 2650/1000000 [08:58<33:20:40,  8.31it/s]

{'loss': Array(1.5419519, dtype=float32), 'loss_reward': Array(0.01483054, dtype=float32), 'loss_cross_entropy': Array(1.5271213, dtype=float32)}


  0%|          | 2660/1000000 [09:00<31:36:51,  8.76it/s]

{'loss': Array(1.5415787, dtype=float32), 'loss_reward': Array(0.01574117, dtype=float32), 'loss_cross_entropy': Array(1.5258374, dtype=float32)}


  0%|          | 2670/1000000 [09:01<29:49:39,  9.29it/s]

{'loss': Array(1.5409263, dtype=float32), 'loss_reward': Array(0.0150435, dtype=float32), 'loss_cross_entropy': Array(1.5258827, dtype=float32)}


  0%|          | 2680/1000000 [09:03<40:16:15,  6.88it/s]

{'loss': Array(1.5402097, dtype=float32), 'loss_reward': Array(0.01450764, dtype=float32), 'loss_cross_entropy': Array(1.5257019, dtype=float32)}


  0%|          | 2688/1000000 [09:05<37:56:26,  7.30it/s]

{'loss': Array(1.5395876, dtype=float32), 'loss_reward': Array(0.01469786, dtype=float32), 'loss_cross_entropy': Array(1.5248898, dtype=float32)}


  0%|          | 2700/1000000 [09:06<28:59:20,  9.56it/s]

{'loss': Array(1.5401987, dtype=float32), 'loss_reward': Array(0.01454841, dtype=float32), 'loss_cross_entropy': Array(1.5256504, dtype=float32)}


  0%|          | 2710/1000000 [09:08<30:25:56,  9.10it/s]

{'loss': Array(1.5390867, dtype=float32), 'loss_reward': Array(0.01478613, dtype=float32), 'loss_cross_entropy': Array(1.5243006, dtype=float32)}


  0%|          | 2719/1000000 [09:10<38:12:00,  7.25it/s]

{'loss': Array(1.5388327, dtype=float32), 'loss_reward': Array(0.01492946, dtype=float32), 'loss_cross_entropy': Array(1.5239033, dtype=float32)}


  0%|          | 2730/1000000 [09:11<30:21:24,  9.13it/s]

{'loss': Array(1.5395815, dtype=float32), 'loss_reward': Array(0.01427661, dtype=float32), 'loss_cross_entropy': Array(1.5253049, dtype=float32)}


  0%|          | 2738/1000000 [09:13<37:26:15,  7.40it/s]

{'loss': Array(1.5384829, dtype=float32), 'loss_reward': Array(0.01377486, dtype=float32), 'loss_cross_entropy': Array(1.5247082, dtype=float32)}


  0%|          | 2749/1000000 [09:15<41:25:36,  6.69it/s]

{'loss': Array(1.5378121, dtype=float32), 'loss_reward': Array(0.01456918, dtype=float32), 'loss_cross_entropy': Array(1.523243, dtype=float32)}


  0%|          | 2759/1000000 [09:17<31:32:28,  8.78it/s]

{'loss': Array(1.537951, dtype=float32), 'loss_reward': Array(0.0145321, dtype=float32), 'loss_cross_entropy': Array(1.5234188, dtype=float32)}


  0%|          | 2770/1000000 [09:18<28:35:33,  9.69it/s]

{'loss': Array(1.5373783, dtype=float32), 'loss_reward': Array(0.01452831, dtype=float32), 'loss_cross_entropy': Array(1.52285, dtype=float32)}


  0%|          | 2780/1000000 [09:20<29:18:50,  9.45it/s]

{'loss': Array(1.5386008, dtype=float32), 'loss_reward': Array(0.01459312, dtype=float32), 'loss_cross_entropy': Array(1.5240076, dtype=float32)}


  0%|          | 2790/1000000 [09:22<33:35:45,  8.25it/s]

{'loss': Array(1.5365324, dtype=float32), 'loss_reward': Array(0.01424979, dtype=float32), 'loss_cross_entropy': Array(1.5222826, dtype=float32)}


  0%|          | 2798/1000000 [09:23<36:24:48,  7.61it/s]

{'loss': Array(1.5351702, dtype=float32), 'loss_reward': Array(0.013843, dtype=float32), 'loss_cross_entropy': Array(1.5213271, dtype=float32)}


  0%|          | 2808/1000000 [09:25<31:47:51,  8.71it/s]

{'loss': Array(1.5360165, dtype=float32), 'loss_reward': Array(0.01414961, dtype=float32), 'loss_cross_entropy': Array(1.5218667, dtype=float32)}


  0%|          | 2819/1000000 [09:27<40:13:28,  6.89it/s]

{'loss': Array(1.5353047, dtype=float32), 'loss_reward': Array(0.01431306, dtype=float32), 'loss_cross_entropy': Array(1.5209917, dtype=float32)}


  0%|          | 2830/1000000 [09:29<30:59:44,  8.94it/s]

{'loss': Array(1.5349005, dtype=float32), 'loss_reward': Array(0.01385913, dtype=float32), 'loss_cross_entropy': Array(1.5210414, dtype=float32)}


  0%|          | 2838/1000000 [09:30<35:44:27,  7.75it/s]

{'loss': Array(1.534465, dtype=float32), 'loss_reward': Array(0.01385607, dtype=float32), 'loss_cross_entropy': Array(1.5206089, dtype=float32)}


  0%|          | 2849/1000000 [09:32<30:57:07,  8.95it/s]

{'loss': Array(1.5355966, dtype=float32), 'loss_reward': Array(0.01379914, dtype=float32), 'loss_cross_entropy': Array(1.5217974, dtype=float32)}


  0%|          | 2859/1000000 [09:34<34:28:31,  8.03it/s]

{'loss': Array(1.5333731, dtype=float32), 'loss_reward': Array(0.01365629, dtype=float32), 'loss_cross_entropy': Array(1.5197167, dtype=float32)}


  0%|          | 2868/1000000 [09:35<32:18:43,  8.57it/s]

{'loss': Array(1.5347027, dtype=float32), 'loss_reward': Array(0.01390503, dtype=float32), 'loss_cross_entropy': Array(1.5207976, dtype=float32)}


  0%|          | 2879/1000000 [09:37<29:11:56,  9.49it/s]

{'loss': Array(1.5346493, dtype=float32), 'loss_reward': Array(0.01337325, dtype=float32), 'loss_cross_entropy': Array(1.521276, dtype=float32)}


  0%|          | 2888/1000000 [09:39<42:59:04,  6.44it/s]

{'loss': Array(1.534992, dtype=float32), 'loss_reward': Array(0.01380779, dtype=float32), 'loss_cross_entropy': Array(1.5211842, dtype=float32)}


  0%|          | 2900/1000000 [09:41<30:48:12,  8.99it/s]

{'loss': Array(1.5334867, dtype=float32), 'loss_reward': Array(0.01391675, dtype=float32), 'loss_cross_entropy': Array(1.5195699, dtype=float32)}


  0%|          | 2908/1000000 [09:42<36:14:29,  7.64it/s]

{'loss': Array(1.5321115, dtype=float32), 'loss_reward': Array(0.01375422, dtype=float32), 'loss_cross_entropy': Array(1.5183574, dtype=float32)}


  0%|          | 2920/1000000 [09:44<26:49:12, 10.33it/s]

{'loss': Array(1.5332996, dtype=float32), 'loss_reward': Array(0.01347184, dtype=float32), 'loss_cross_entropy': Array(1.5198277, dtype=float32)}


  0%|          | 2928/1000000 [09:46<37:59:40,  7.29it/s]

{'loss': Array(1.5330104, dtype=float32), 'loss_reward': Array(0.01405722, dtype=float32), 'loss_cross_entropy': Array(1.5189532, dtype=float32)}


  0%|          | 2940/1000000 [09:47<28:21:49,  9.76it/s]

{'loss': Array(1.5310689, dtype=float32), 'loss_reward': Array(0.01366006, dtype=float32), 'loss_cross_entropy': Array(1.5174088, dtype=float32)}


  0%|          | 2949/1000000 [09:49<32:13:20,  8.60it/s]

{'loss': Array(1.5321738, dtype=float32), 'loss_reward': Array(0.01378683, dtype=float32), 'loss_cross_entropy': Array(1.5183868, dtype=float32)}


  0%|          | 2958/1000000 [09:51<44:19:19,  6.25it/s]

{'loss': Array(1.5337814, dtype=float32), 'loss_reward': Array(0.01342607, dtype=float32), 'loss_cross_entropy': Array(1.5203555, dtype=float32)}


  0%|          | 2970/1000000 [09:53<30:41:42,  9.02it/s]

{'loss': Array(1.5319055, dtype=float32), 'loss_reward': Array(0.01306323, dtype=float32), 'loss_cross_entropy': Array(1.5188422, dtype=float32)}


  0%|          | 2979/1000000 [09:54<31:12:22,  8.87it/s]

{'loss': Array(1.5309026, dtype=float32), 'loss_reward': Array(0.01322682, dtype=float32), 'loss_cross_entropy': Array(1.5176758, dtype=float32)}


  0%|          | 2990/1000000 [09:56<29:02:50,  9.53it/s]

{'loss': Array(1.5319127, dtype=float32), 'loss_reward': Array(0.01372117, dtype=float32), 'loss_cross_entropy': Array(1.5181915, dtype=float32)}


  0%|          | 2998/1000000 [09:58<39:06:57,  7.08it/s]

{'loss': Array(1.5310057, dtype=float32), 'loss_reward': Array(0.01338486, dtype=float32), 'loss_cross_entropy': Array(1.5176209, dtype=float32)}


  0%|          | 3010/1000000 [10:10<129:15:17,  2.14it/s]

{'loss': Array(1.5329306, dtype=float32), 'loss_reward': Array(0.01347833, dtype=float32), 'loss_cross_entropy': Array(1.5194522, dtype=float32)}


  0%|          | 3018/1000000 [10:11<70:24:00,  3.93it/s] 

{'loss': Array(1.5308822, dtype=float32), 'loss_reward': Array(0.01307767, dtype=float32), 'loss_cross_entropy': Array(1.5178046, dtype=float32)}


  0%|          | 3028/1000000 [10:13<52:33:30,  5.27it/s]

{'loss': Array(1.5310668, dtype=float32), 'loss_reward': Array(0.01260658, dtype=float32), 'loss_cross_entropy': Array(1.5184602, dtype=float32)}


  0%|          | 3040/1000000 [10:15<32:32:28,  8.51it/s]

{'loss': Array(1.529646, dtype=float32), 'loss_reward': Array(0.01291057, dtype=float32), 'loss_cross_entropy': Array(1.5167356, dtype=float32)}


  0%|          | 3048/1000000 [10:16<36:11:24,  7.65it/s]

{'loss': Array(1.5299597, dtype=float32), 'loss_reward': Array(0.01303145, dtype=float32), 'loss_cross_entropy': Array(1.5169281, dtype=float32)}


  0%|          | 3060/1000000 [10:18<29:41:00,  9.33it/s]

{'loss': Array(1.531416, dtype=float32), 'loss_reward': Array(0.0132408, dtype=float32), 'loss_cross_entropy': Array(1.5181752, dtype=float32)}


  0%|          | 3070/1000000 [10:20<32:41:21,  8.47it/s]

{'loss': Array(1.5309827, dtype=float32), 'loss_reward': Array(0.01312263, dtype=float32), 'loss_cross_entropy': Array(1.5178599, dtype=float32)}


  0%|          | 3080/1000000 [10:22<30:25:59,  9.10it/s]

{'loss': Array(1.5321101, dtype=float32), 'loss_reward': Array(0.01320889, dtype=float32), 'loss_cross_entropy': Array(1.5189011, dtype=float32)}


  0%|          | 3088/1000000 [10:23<35:11:34,  7.87it/s]

{'loss': Array(1.5283557, dtype=float32), 'loss_reward': Array(0.01299615, dtype=float32), 'loss_cross_entropy': Array(1.5153595, dtype=float32)}


  0%|          | 3099/1000000 [10:25<37:04:42,  7.47it/s]

{'loss': Array(1.5311474, dtype=float32), 'loss_reward': Array(0.0136113, dtype=float32), 'loss_cross_entropy': Array(1.517536, dtype=float32)}


  0%|          | 3110/1000000 [10:27<30:40:47,  9.03it/s]

{'loss': Array(1.5304736, dtype=float32), 'loss_reward': Array(0.01303215, dtype=float32), 'loss_cross_entropy': Array(1.5174414, dtype=float32)}


  0%|          | 3118/1000000 [10:28<34:39:00,  7.99it/s]

{'loss': Array(1.5293859, dtype=float32), 'loss_reward': Array(0.01279588, dtype=float32), 'loss_cross_entropy': Array(1.51659, dtype=float32)}


  0%|          | 3130/1000000 [10:30<37:54:27,  7.30it/s]

{'loss': Array(1.5293354, dtype=float32), 'loss_reward': Array(0.01299339, dtype=float32), 'loss_cross_entropy': Array(1.516342, dtype=float32)}


  0%|          | 3140/1000000 [10:32<32:30:42,  8.52it/s]

{'loss': Array(1.531354, dtype=float32), 'loss_reward': Array(0.01311114, dtype=float32), 'loss_cross_entropy': Array(1.5182427, dtype=float32)}


  0%|          | 3148/1000000 [10:33<35:48:22,  7.73it/s]

{'loss': Array(1.5285386, dtype=float32), 'loss_reward': Array(0.01320738, dtype=float32), 'loss_cross_entropy': Array(1.5153313, dtype=float32)}


  0%|          | 3159/1000000 [10:35<29:50:29,  9.28it/s]

{'loss': Array(1.5281191, dtype=float32), 'loss_reward': Array(0.01327131, dtype=float32), 'loss_cross_entropy': Array(1.5148476, dtype=float32)}


  0%|          | 3170/1000000 [10:37<33:03:43,  8.38it/s]

{'loss': Array(1.5283656, dtype=float32), 'loss_reward': Array(0.01296207, dtype=float32), 'loss_cross_entropy': Array(1.5154036, dtype=float32)}


  0%|          | 3178/1000000 [10:39<36:04:28,  7.68it/s]

{'loss': Array(1.5292325, dtype=float32), 'loss_reward': Array(0.01280239, dtype=float32), 'loss_cross_entropy': Array(1.5164303, dtype=float32)}


  0%|          | 3189/1000000 [10:40<30:04:53,  9.20it/s]

{'loss': Array(1.5279354, dtype=float32), 'loss_reward': Array(0.01289861, dtype=float32), 'loss_cross_entropy': Array(1.5150368, dtype=float32)}


  0%|          | 3199/1000000 [10:42<40:45:41,  6.79it/s]

{'loss': Array(1.5275626, dtype=float32), 'loss_reward': Array(0.01302812, dtype=float32), 'loss_cross_entropy': Array(1.5145344, dtype=float32)}


  0%|          | 3210/1000000 [10:44<31:26:34,  8.81it/s]

{'loss': Array(1.5285265, dtype=float32), 'loss_reward': Array(0.01237637, dtype=float32), 'loss_cross_entropy': Array(1.5161502, dtype=float32)}


  0%|          | 3218/1000000 [10:45<35:46:46,  7.74it/s]

{'loss': Array(1.5301025, dtype=float32), 'loss_reward': Array(0.01271517, dtype=float32), 'loss_cross_entropy': Array(1.5173873, dtype=float32)}


  0%|          | 3229/1000000 [10:47<29:33:36,  9.37it/s]

{'loss': Array(1.5284908, dtype=float32), 'loss_reward': Array(0.01289784, dtype=float32), 'loss_cross_entropy': Array(1.5155929, dtype=float32)}


  0%|          | 3238/1000000 [10:49<35:40:15,  7.76it/s]

{'loss': Array(1.5262622, dtype=float32), 'loss_reward': Array(0.01267727, dtype=float32), 'loss_cross_entropy': Array(1.513585, dtype=float32)}


  0%|          | 3250/1000000 [10:51<28:45:57,  9.63it/s]

{'loss': Array(1.5285711, dtype=float32), 'loss_reward': Array(0.01250061, dtype=float32), 'loss_cross_entropy': Array(1.5160704, dtype=float32)}


  0%|          | 3258/1000000 [10:52<35:01:17,  7.91it/s]

{'loss': Array(1.5254717, dtype=float32), 'loss_reward': Array(0.01241019, dtype=float32), 'loss_cross_entropy': Array(1.5130616, dtype=float32)}


  0%|          | 3270/1000000 [10:54<35:12:44,  7.86it/s]

{'loss': Array(1.5270493, dtype=float32), 'loss_reward': Array(0.0126507, dtype=float32), 'loss_cross_entropy': Array(1.5143986, dtype=float32)}


  0%|          | 3278/1000000 [10:56<37:42:30,  7.34it/s]

{'loss': Array(1.5281135, dtype=float32), 'loss_reward': Array(0.01248067, dtype=float32), 'loss_cross_entropy': Array(1.515633, dtype=float32)}


  0%|          | 3290/1000000 [10:57<29:23:58,  9.42it/s]

{'loss': Array(1.525802, dtype=float32), 'loss_reward': Array(0.01281295, dtype=float32), 'loss_cross_entropy': Array(1.5129889, dtype=float32)}


  0%|          | 3300/1000000 [10:59<30:17:13,  9.14it/s]

{'loss': Array(1.5271053, dtype=float32), 'loss_reward': Array(0.01259852, dtype=float32), 'loss_cross_entropy': Array(1.5145068, dtype=float32)}


  0%|          | 3308/1000000 [11:01<39:55:01,  6.94it/s]

{'loss': Array(1.5277153, dtype=float32), 'loss_reward': Array(0.012543, dtype=float32), 'loss_cross_entropy': Array(1.5151721, dtype=float32)}


  0%|          | 3320/1000000 [11:03<28:31:21,  9.71it/s]

{'loss': Array(1.5253075, dtype=float32), 'loss_reward': Array(0.01255535, dtype=float32), 'loss_cross_entropy': Array(1.5127522, dtype=float32)}


  0%|          | 3328/1000000 [11:04<36:04:22,  7.67it/s]

{'loss': Array(1.5267171, dtype=float32), 'loss_reward': Array(0.01245295, dtype=float32), 'loss_cross_entropy': Array(1.5142642, dtype=float32)}


  0%|          | 3338/1000000 [11:06<40:23:19,  6.85it/s]

{'loss': Array(1.5240945, dtype=float32), 'loss_reward': Array(0.01227249, dtype=float32), 'loss_cross_entropy': Array(1.5118222, dtype=float32)}


  0%|          | 3348/1000000 [11:08<32:51:21,  8.43it/s]

{'loss': Array(1.5269005, dtype=float32), 'loss_reward': Array(0.01253152, dtype=float32), 'loss_cross_entropy': Array(1.5143689, dtype=float32)}


  0%|          | 3360/1000000 [11:10<27:36:57, 10.02it/s]

{'loss': Array(1.5258955, dtype=float32), 'loss_reward': Array(0.01263186, dtype=float32), 'loss_cross_entropy': Array(1.5132637, dtype=float32)}


  0%|          | 3368/1000000 [11:11<34:15:13,  8.08it/s]

{'loss': Array(1.5257865, dtype=float32), 'loss_reward': Array(0.01208888, dtype=float32), 'loss_cross_entropy': Array(1.5136976, dtype=float32)}


  0%|          | 3379/1000000 [11:13<34:03:43,  8.13it/s]

{'loss': Array(1.5263826, dtype=float32), 'loss_reward': Array(0.01202527, dtype=float32), 'loss_cross_entropy': Array(1.5143572, dtype=float32)}


  0%|          | 3390/1000000 [11:15<29:37:52,  9.34it/s]

{'loss': Array(1.5255164, dtype=float32), 'loss_reward': Array(0.01274312, dtype=float32), 'loss_cross_entropy': Array(1.5127733, dtype=float32)}


  0%|          | 3398/1000000 [11:16<34:20:40,  8.06it/s]

{'loss': Array(1.5267329, dtype=float32), 'loss_reward': Array(0.01218437, dtype=float32), 'loss_cross_entropy': Array(1.5145487, dtype=float32)}


  0%|          | 3410/1000000 [11:18<34:43:27,  7.97it/s]

{'loss': Array(1.5244639, dtype=float32), 'loss_reward': Array(0.01202472, dtype=float32), 'loss_cross_entropy': Array(1.5124391, dtype=float32)}


  0%|          | 3418/1000000 [11:20<36:29:40,  7.59it/s]

{'loss': Array(1.5276082, dtype=float32), 'loss_reward': Array(0.01248744, dtype=float32), 'loss_cross_entropy': Array(1.5151207, dtype=float32)}


  0%|          | 3430/1000000 [11:21<29:03:22,  9.53it/s]

{'loss': Array(1.5255641, dtype=float32), 'loss_reward': Array(0.01216632, dtype=float32), 'loss_cross_entropy': Array(1.5133977, dtype=float32)}


  0%|          | 3438/1000000 [11:23<35:49:25,  7.73it/s]

{'loss': Array(1.5228084, dtype=float32), 'loss_reward': Array(0.01234599, dtype=float32), 'loss_cross_entropy': Array(1.5104623, dtype=float32)}


  0%|          | 3450/1000000 [11:25<32:13:53,  8.59it/s]

{'loss': Array(1.5233598, dtype=float32), 'loss_reward': Array(0.01214465, dtype=float32), 'loss_cross_entropy': Array(1.5112151, dtype=float32)}


  0%|          | 3460/1000000 [11:27<29:56:36,  9.24it/s]

{'loss': Array(1.5252435, dtype=float32), 'loss_reward': Array(0.01232107, dtype=float32), 'loss_cross_entropy': Array(1.5129224, dtype=float32)}


  0%|          | 3468/1000000 [11:28<35:47:45,  7.73it/s]

{'loss': Array(1.5243986, dtype=float32), 'loss_reward': Array(0.01195196, dtype=float32), 'loss_cross_entropy': Array(1.5124466, dtype=float32)}


  0%|          | 3480/1000000 [11:30<33:37:02,  8.23it/s]

{'loss': Array(1.5239856, dtype=float32), 'loss_reward': Array(0.01148683, dtype=float32), 'loss_cross_entropy': Array(1.5124987, dtype=float32)}


  0%|          | 3488/1000000 [11:32<35:56:27,  7.70it/s]

{'loss': Array(1.5255257, dtype=float32), 'loss_reward': Array(0.01218051, dtype=float32), 'loss_cross_entropy': Array(1.513345, dtype=float32)}


  0%|          | 3499/1000000 [11:33<30:24:43,  9.10it/s]

{'loss': Array(1.5241488, dtype=float32), 'loss_reward': Array(0.01183363, dtype=float32), 'loss_cross_entropy': Array(1.512315, dtype=float32)}


  0%|          | 3510/1000000 [11:46<129:17:11,  2.14it/s]

{'loss': Array(1.525804, dtype=float32), 'loss_reward': Array(0.01222512, dtype=float32), 'loss_cross_entropy': Array(1.5135789, dtype=float32)}


  0%|          | 3520/1000000 [11:47<52:22:37,  5.28it/s] 

{'loss': Array(1.5257016, dtype=float32), 'loss_reward': Array(0.01213823, dtype=float32), 'loss_cross_entropy': Array(1.5135633, dtype=float32)}


  0%|          | 3528/1000000 [11:49<42:05:07,  6.58it/s]

{'loss': Array(1.5246012, dtype=float32), 'loss_reward': Array(0.01192063, dtype=float32), 'loss_cross_entropy': Array(1.5126805, dtype=float32)}


  0%|          | 3539/1000000 [11:50<31:42:41,  8.73it/s]

{'loss': Array(1.5235813, dtype=float32), 'loss_reward': Array(0.01197687, dtype=float32), 'loss_cross_entropy': Array(1.5116047, dtype=float32)}


  0%|          | 3549/1000000 [11:52<33:57:29,  8.15it/s]

{'loss': Array(1.5263841, dtype=float32), 'loss_reward': Array(0.01200868, dtype=float32), 'loss_cross_entropy': Array(1.5143753, dtype=float32)}


  0%|          | 3558/1000000 [11:54<32:05:37,  8.62it/s]

{'loss': Array(1.5243083, dtype=float32), 'loss_reward': Array(0.01208275, dtype=float32), 'loss_cross_entropy': Array(1.5122256, dtype=float32)}


  0%|          | 3570/1000000 [11:56<28:02:22,  9.87it/s]

{'loss': Array(1.5227813, dtype=float32), 'loss_reward': Array(0.01186712, dtype=float32), 'loss_cross_entropy': Array(1.510914, dtype=float32)}


  0%|          | 3578/1000000 [11:57<44:36:00,  6.21it/s]

{'loss': Array(1.5235267, dtype=float32), 'loss_reward': Array(0.01185271, dtype=float32), 'loss_cross_entropy': Array(1.5116739, dtype=float32)}


  0%|          | 3589/1000000 [11:59<32:09:48,  8.61it/s]

{'loss': Array(1.5232385, dtype=float32), 'loss_reward': Array(0.0124832, dtype=float32), 'loss_cross_entropy': Array(1.5107553, dtype=float32)}


  0%|          | 3598/1000000 [12:01<33:40:26,  8.22it/s]

{'loss': Array(1.523551, dtype=float32), 'loss_reward': Array(0.01235283, dtype=float32), 'loss_cross_entropy': Array(1.5111982, dtype=float32)}


  0%|          | 3610/1000000 [12:02<28:52:14,  9.59it/s]

{'loss': Array(1.5238146, dtype=float32), 'loss_reward': Array(0.01158471, dtype=float32), 'loss_cross_entropy': Array(1.5122299, dtype=float32)}


  0%|          | 3620/1000000 [12:04<32:20:45,  8.56it/s]

{'loss': Array(1.5238403, dtype=float32), 'loss_reward': Array(0.01172636, dtype=float32), 'loss_cross_entropy': Array(1.5121139, dtype=float32)}


  0%|          | 3628/1000000 [12:06<36:28:45,  7.59it/s]

{'loss': Array(1.5230185, dtype=float32), 'loss_reward': Array(0.01192578, dtype=float32), 'loss_cross_entropy': Array(1.5110925, dtype=float32)}


  0%|          | 3640/1000000 [12:08<29:14:30,  9.46it/s]

{'loss': Array(1.5240592, dtype=float32), 'loss_reward': Array(0.0121929, dtype=float32), 'loss_cross_entropy': Array(1.5118662, dtype=float32)}


  0%|          | 3648/1000000 [12:09<44:34:45,  6.21it/s]

{'loss': Array(1.5224822, dtype=float32), 'loss_reward': Array(0.01175368, dtype=float32), 'loss_cross_entropy': Array(1.5107285, dtype=float32)}


  0%|          | 3658/1000000 [12:11<33:50:38,  8.18it/s]

{'loss': Array(1.5239824, dtype=float32), 'loss_reward': Array(0.0118372, dtype=float32), 'loss_cross_entropy': Array(1.5121452, dtype=float32)}


  0%|          | 3669/1000000 [12:13<30:37:56,  9.03it/s]

{'loss': Array(1.5234462, dtype=float32), 'loss_reward': Array(0.01210465, dtype=float32), 'loss_cross_entropy': Array(1.5113415, dtype=float32)}


  0%|          | 3679/1000000 [12:14<31:17:37,  8.84it/s]

{'loss': Array(1.5249896, dtype=float32), 'loss_reward': Array(0.01216104, dtype=float32), 'loss_cross_entropy': Array(1.5128285, dtype=float32)}


  0%|          | 3689/1000000 [12:16<34:01:45,  8.13it/s]

{'loss': Array(1.52184, dtype=float32), 'loss_reward': Array(0.01191813, dtype=float32), 'loss_cross_entropy': Array(1.5099218, dtype=float32)}


  0%|          | 3699/1000000 [12:18<30:28:07,  9.08it/s]

{'loss': Array(1.5225923, dtype=float32), 'loss_reward': Array(0.01128293, dtype=float32), 'loss_cross_entropy': Array(1.5113094, dtype=float32)}


  0%|          | 3710/1000000 [12:20<28:25:44,  9.73it/s]

{'loss': Array(1.5213909, dtype=float32), 'loss_reward': Array(0.01202275, dtype=float32), 'loss_cross_entropy': Array(1.5093682, dtype=float32)}


  0%|          | 3720/1000000 [12:22<36:15:46,  7.63it/s]

{'loss': Array(1.5217582, dtype=float32), 'loss_reward': Array(0.01172091, dtype=float32), 'loss_cross_entropy': Array(1.5100373, dtype=float32)}


  0%|          | 3730/1000000 [12:23<30:49:37,  8.98it/s]

{'loss': Array(1.5218492, dtype=float32), 'loss_reward': Array(0.01200372, dtype=float32), 'loss_cross_entropy': Array(1.5098455, dtype=float32)}


  0%|          | 3738/1000000 [12:25<35:28:41,  7.80it/s]

{'loss': Array(1.5229822, dtype=float32), 'loss_reward': Array(0.01145947, dtype=float32), 'loss_cross_entropy': Array(1.5115228, dtype=float32)}


  0%|          | 3750/1000000 [12:27<37:29:40,  7.38it/s]

{'loss': Array(1.521142, dtype=float32), 'loss_reward': Array(0.01144921, dtype=float32), 'loss_cross_entropy': Array(1.5096927, dtype=float32)}


  0%|          | 3758/1000000 [12:28<38:33:09,  7.18it/s]

{'loss': Array(1.5222113, dtype=float32), 'loss_reward': Array(0.01180193, dtype=float32), 'loss_cross_entropy': Array(1.5104094, dtype=float32)}


  0%|          | 3768/1000000 [12:30<33:02:20,  8.38it/s]

{'loss': Array(1.5227456, dtype=float32), 'loss_reward': Array(0.0119544, dtype=float32), 'loss_cross_entropy': Array(1.5107912, dtype=float32)}


  0%|          | 3778/1000000 [12:31<30:41:03,  9.02it/s]

{'loss': Array(1.5236783, dtype=float32), 'loss_reward': Array(0.01164781, dtype=float32), 'loss_cross_entropy': Array(1.5120306, dtype=float32)}


  0%|          | 3790/1000000 [12:34<30:54:54,  8.95it/s]

{'loss': Array(1.5221425, dtype=float32), 'loss_reward': Array(0.0115295, dtype=float32), 'loss_cross_entropy': Array(1.5106131, dtype=float32)}


  0%|          | 3798/1000000 [12:35<35:40:51,  7.76it/s]

{'loss': Array(1.5230796, dtype=float32), 'loss_reward': Array(0.01216579, dtype=float32), 'loss_cross_entropy': Array(1.5109137, dtype=float32)}


  0%|          | 3810/1000000 [12:37<29:17:19,  9.45it/s]

{'loss': Array(1.5231652, dtype=float32), 'loss_reward': Array(0.01185106, dtype=float32), 'loss_cross_entropy': Array(1.5113142, dtype=float32)}


  0%|          | 3820/1000000 [12:39<40:22:07,  6.85it/s]

{'loss': Array(1.5227094, dtype=float32), 'loss_reward': Array(0.01180167, dtype=float32), 'loss_cross_entropy': Array(1.5109079, dtype=float32)}


  0%|          | 3828/1000000 [12:40<38:24:05,  7.21it/s]

{'loss': Array(1.522433, dtype=float32), 'loss_reward': Array(0.01139571, dtype=float32), 'loss_cross_entropy': Array(1.5110371, dtype=float32)}


  0%|          | 3839/1000000 [12:42<31:34:04,  8.77it/s]

{'loss': Array(1.5212306, dtype=float32), 'loss_reward': Array(0.01163757, dtype=float32), 'loss_cross_entropy': Array(1.5095929, dtype=float32)}


  0%|          | 3850/1000000 [12:43<28:31:15,  9.70it/s]

{'loss': Array(1.52024, dtype=float32), 'loss_reward': Array(0.01171449, dtype=float32), 'loss_cross_entropy': Array(1.5085255, dtype=float32)}


  0%|          | 3859/1000000 [12:45<38:31:56,  7.18it/s]

{'loss': Array(1.5239944, dtype=float32), 'loss_reward': Array(0.0118835, dtype=float32), 'loss_cross_entropy': Array(1.5121108, dtype=float32)}


  0%|          | 3870/1000000 [12:47<30:01:53,  9.21it/s]

{'loss': Array(1.5231999, dtype=float32), 'loss_reward': Array(0.01176479, dtype=float32), 'loss_cross_entropy': Array(1.5114354, dtype=float32)}


  0%|          | 3878/1000000 [12:48<35:03:25,  7.89it/s]

{'loss': Array(1.522077, dtype=float32), 'loss_reward': Array(0.01142582, dtype=float32), 'loss_cross_entropy': Array(1.510651, dtype=float32)}


  0%|          | 3889/1000000 [12:50<41:38:51,  6.64it/s]

{'loss': Array(1.5205383, dtype=float32), 'loss_reward': Array(0.01176251, dtype=float32), 'loss_cross_entropy': Array(1.5087758, dtype=float32)}


  0%|          | 3900/1000000 [12:52<31:48:43,  8.70it/s]

{'loss': Array(1.5213438, dtype=float32), 'loss_reward': Array(0.01119221, dtype=float32), 'loss_cross_entropy': Array(1.5101517, dtype=float32)}


  0%|          | 3908/1000000 [12:54<36:44:38,  7.53it/s]

{'loss': Array(1.5205171, dtype=float32), 'loss_reward': Array(0.01136019, dtype=float32), 'loss_cross_entropy': Array(1.509157, dtype=float32)}


  0%|          | 3920/1000000 [12:56<29:41:14,  9.32it/s]

{'loss': Array(1.5213174, dtype=float32), 'loss_reward': Array(0.01163627, dtype=float32), 'loss_cross_entropy': Array(1.5096811, dtype=float32)}


  0%|          | 3930/1000000 [12:57<33:55:46,  8.15it/s]

{'loss': Array(1.5211366, dtype=float32), 'loss_reward': Array(0.01154683, dtype=float32), 'loss_cross_entropy': Array(1.5095899, dtype=float32)}


  0%|          | 3940/1000000 [12:59<30:10:12,  9.17it/s]

{'loss': Array(1.5207878, dtype=float32), 'loss_reward': Array(0.01114791, dtype=float32), 'loss_cross_entropy': Array(1.5096399, dtype=float32)}


  0%|          | 3949/1000000 [13:01<33:00:02,  8.38it/s]

{'loss': Array(1.5219271, dtype=float32), 'loss_reward': Array(0.01125835, dtype=float32), 'loss_cross_entropy': Array(1.5106686, dtype=float32)}


  0%|          | 3960/1000000 [13:03<36:09:00,  7.65it/s]

{'loss': Array(1.5213022, dtype=float32), 'loss_reward': Array(0.01125703, dtype=float32), 'loss_cross_entropy': Array(1.5100453, dtype=float32)}


  0%|          | 3968/1000000 [13:04<37:02:34,  7.47it/s]

{'loss': Array(1.5217971, dtype=float32), 'loss_reward': Array(0.0114178, dtype=float32), 'loss_cross_entropy': Array(1.5103792, dtype=float32)}


  0%|          | 3979/1000000 [13:06<31:01:54,  8.92it/s]

{'loss': Array(1.5204765, dtype=float32), 'loss_reward': Array(0.01143774, dtype=float32), 'loss_cross_entropy': Array(1.5090386, dtype=float32)}


  0%|          | 3990/1000000 [13:07<29:10:58,  9.48it/s]

{'loss': Array(1.5215248, dtype=float32), 'loss_reward': Array(0.01157496, dtype=float32), 'loss_cross_entropy': Array(1.5099498, dtype=float32)}


  0%|          | 4000/1000000 [13:09<33:02:54,  8.37it/s]

{'loss': Array(1.5209149, dtype=float32), 'loss_reward': Array(0.01156481, dtype=float32), 'loss_cross_entropy': Array(1.5093502, dtype=float32)}


  0%|          | 4009/1000000 [13:21<141:52:32,  1.95it/s]

{'loss': Array(1.5230434, dtype=float32), 'loss_reward': Array(0.0112001, dtype=float32), 'loss_cross_entropy': Array(1.5118431, dtype=float32)}


  0%|          | 4018/1000000 [13:23<60:05:13,  4.60it/s] 

{'loss': Array(1.5220212, dtype=float32), 'loss_reward': Array(0.0113884, dtype=float32), 'loss_cross_entropy': Array(1.5106329, dtype=float32)}


  0%|          | 4030/1000000 [13:25<38:02:21,  7.27it/s]

{'loss': Array(1.5208551, dtype=float32), 'loss_reward': Array(0.01150951, dtype=float32), 'loss_cross_entropy': Array(1.5093455, dtype=float32)}


  0%|          | 4038/1000000 [13:26<38:14:48,  7.23it/s]

{'loss': Array(1.5212721, dtype=float32), 'loss_reward': Array(0.01155165, dtype=float32), 'loss_cross_entropy': Array(1.5097203, dtype=float32)}


  0%|          | 4050/1000000 [13:28<30:16:17,  9.14it/s]

{'loss': Array(1.5204266, dtype=float32), 'loss_reward': Array(0.01113588, dtype=float32), 'loss_cross_entropy': Array(1.5092908, dtype=float32)}


  0%|          | 4060/1000000 [13:30<41:16:42,  6.70it/s]

{'loss': Array(1.5202678, dtype=float32), 'loss_reward': Array(0.01152714, dtype=float32), 'loss_cross_entropy': Array(1.5087407, dtype=float32)}


  0%|          | 4068/1000000 [13:31<39:33:06,  6.99it/s]

{'loss': Array(1.5226774, dtype=float32), 'loss_reward': Array(0.01118761, dtype=float32), 'loss_cross_entropy': Array(1.5114899, dtype=float32)}


  0%|          | 4080/1000000 [13:33<30:09:15,  9.17it/s]

{'loss': Array(1.5210811, dtype=float32), 'loss_reward': Array(0.01079105, dtype=float32), 'loss_cross_entropy': Array(1.5102899, dtype=float32)}


  0%|          | 4088/1000000 [13:35<36:05:02,  7.67it/s]

{'loss': Array(1.5207804, dtype=float32), 'loss_reward': Array(0.01112873, dtype=float32), 'loss_cross_entropy': Array(1.5096517, dtype=float32)}


  0%|          | 4099/1000000 [13:37<34:32:47,  8.01it/s]

{'loss': Array(1.51862, dtype=float32), 'loss_reward': Array(0.01095243, dtype=float32), 'loss_cross_entropy': Array(1.5076675, dtype=float32)}


  0%|          | 4109/1000000 [13:38<30:17:02,  9.13it/s]

{'loss': Array(1.5183109, dtype=float32), 'loss_reward': Array(0.01118532, dtype=float32), 'loss_cross_entropy': Array(1.5071256, dtype=float32)}


  0%|          | 4118/1000000 [13:40<31:41:34,  8.73it/s]

{'loss': Array(1.5207118, dtype=float32), 'loss_reward': Array(0.01114143, dtype=float32), 'loss_cross_entropy': Array(1.5095704, dtype=float32)}


  0%|          | 4129/1000000 [13:42<40:09:55,  6.89it/s]

{'loss': Array(1.5197232, dtype=float32), 'loss_reward': Array(0.01111728, dtype=float32), 'loss_cross_entropy': Array(1.5086058, dtype=float32)}


  0%|          | 4139/1000000 [13:43<32:05:45,  8.62it/s]

{'loss': Array(1.5196838, dtype=float32), 'loss_reward': Array(0.01089169, dtype=float32), 'loss_cross_entropy': Array(1.5087922, dtype=float32)}


  0%|          | 4149/1000000 [13:45<29:58:12,  9.23it/s]

{'loss': Array(1.5206364, dtype=float32), 'loss_reward': Array(0.01118647, dtype=float32), 'loss_cross_entropy': Array(1.5094498, dtype=float32)}


  0%|          | 4160/1000000 [13:47<28:20:15,  9.76it/s]

{'loss': Array(1.519018, dtype=float32), 'loss_reward': Array(0.01109273, dtype=float32), 'loss_cross_entropy': Array(1.5079253, dtype=float32)}


  0%|          | 4168/1000000 [13:49<40:09:28,  6.89it/s]

{'loss': Array(1.5193249, dtype=float32), 'loss_reward': Array(0.01076038, dtype=float32), 'loss_cross_entropy': Array(1.5085644, dtype=float32)}


  0%|          | 4180/1000000 [13:50<29:37:22,  9.34it/s]

{'loss': Array(1.5208964, dtype=float32), 'loss_reward': Array(0.01133985, dtype=float32), 'loss_cross_entropy': Array(1.5095565, dtype=float32)}


  0%|          | 4188/1000000 [13:52<36:14:28,  7.63it/s]

{'loss': Array(1.5196084, dtype=float32), 'loss_reward': Array(0.01104213, dtype=float32), 'loss_cross_entropy': Array(1.5085664, dtype=float32)}


  0%|          | 4198/1000000 [13:54<42:11:07,  6.56it/s]

{'loss': Array(1.5194747, dtype=float32), 'loss_reward': Array(0.0104418, dtype=float32), 'loss_cross_entropy': Array(1.509033, dtype=float32)}


  0%|          | 4208/1000000 [13:55<34:27:08,  8.03it/s]

{'loss': Array(1.5200546, dtype=float32), 'loss_reward': Array(0.01068786, dtype=float32), 'loss_cross_entropy': Array(1.5093666, dtype=float32)}


  0%|          | 4220/1000000 [13:57<28:13:21,  9.80it/s]

{'loss': Array(1.5198497, dtype=float32), 'loss_reward': Array(0.01100678, dtype=float32), 'loss_cross_entropy': Array(1.5088428, dtype=float32)}


  0%|          | 4228/1000000 [13:58<34:16:36,  8.07it/s]

{'loss': Array(1.5198079, dtype=float32), 'loss_reward': Array(0.01097503, dtype=float32), 'loss_cross_entropy': Array(1.5088328, dtype=float32)}


  0%|          | 4239/1000000 [14:01<33:05:17,  8.36it/s]

{'loss': Array(1.5185809, dtype=float32), 'loss_reward': Array(0.01119412, dtype=float32), 'loss_cross_entropy': Array(1.5073868, dtype=float32)}


  0%|          | 4250/1000000 [14:02<29:31:15,  9.37it/s]

{'loss': Array(1.5201256, dtype=float32), 'loss_reward': Array(0.01108711, dtype=float32), 'loss_cross_entropy': Array(1.5090387, dtype=float32)}


  0%|          | 4260/1000000 [14:04<30:24:54,  9.09it/s]

{'loss': Array(1.5181854, dtype=float32), 'loss_reward': Array(0.01070174, dtype=float32), 'loss_cross_entropy': Array(1.5074837, dtype=float32)}


  0%|          | 4270/1000000 [14:06<36:23:04,  7.60it/s]

{'loss': Array(1.521586, dtype=float32), 'loss_reward': Array(0.01118504, dtype=float32), 'loss_cross_entropy': Array(1.510401, dtype=float32)}


  0%|          | 4278/1000000 [14:07<36:41:51,  7.54it/s]

{'loss': Array(1.5187702, dtype=float32), 'loss_reward': Array(0.01087951, dtype=float32), 'loss_cross_entropy': Array(1.5078907, dtype=float32)}


  0%|          | 4289/1000000 [14:09<30:16:14,  9.14it/s]

{'loss': Array(1.5191883, dtype=float32), 'loss_reward': Array(0.01133751, dtype=float32), 'loss_cross_entropy': Array(1.5078508, dtype=float32)}


  0%|          | 4300/1000000 [14:11<37:04:09,  7.46it/s]

{'loss': Array(1.5172668, dtype=float32), 'loss_reward': Array(0.01080404, dtype=float32), 'loss_cross_entropy': Array(1.5064627, dtype=float32)}


  0%|          | 4310/1000000 [14:12<30:45:03,  8.99it/s]

{'loss': Array(1.517888, dtype=float32), 'loss_reward': Array(0.01094811, dtype=float32), 'loss_cross_entropy': Array(1.50694, dtype=float32)}


  0%|          | 4318/1000000 [14:14<34:50:30,  7.94it/s]

{'loss': Array(1.518035, dtype=float32), 'loss_reward': Array(0.01069794, dtype=float32), 'loss_cross_entropy': Array(1.5073372, dtype=float32)}


  0%|          | 4330/1000000 [14:15<28:43:57,  9.63it/s]

{'loss': Array(1.5193063, dtype=float32), 'loss_reward': Array(0.01063032, dtype=float32), 'loss_cross_entropy': Array(1.5086759, dtype=float32)}


  0%|          | 4338/1000000 [14:17<42:13:09,  6.55it/s]

{'loss': Array(1.5193539, dtype=float32), 'loss_reward': Array(0.01088817, dtype=float32), 'loss_cross_entropy': Array(1.5084658, dtype=float32)}


  0%|          | 4350/1000000 [14:19<31:32:13,  8.77it/s]

{'loss': Array(1.5182766, dtype=float32), 'loss_reward': Array(0.01044769, dtype=float32), 'loss_cross_entropy': Array(1.5078288, dtype=float32)}


  0%|          | 4360/1000000 [14:21<29:59:11,  9.22it/s]

{'loss': Array(1.5176241, dtype=float32), 'loss_reward': Array(0.01087352, dtype=float32), 'loss_cross_entropy': Array(1.5067505, dtype=float32)}


  0%|          | 4370/1000000 [14:23<39:58:50,  6.92it/s]

{'loss': Array(1.5194281, dtype=float32), 'loss_reward': Array(0.01046416, dtype=float32), 'loss_cross_entropy': Array(1.508964, dtype=float32)}


  0%|          | 4380/1000000 [14:24<32:00:14,  8.64it/s]

{'loss': Array(1.5176103, dtype=float32), 'loss_reward': Array(0.01054642, dtype=float32), 'loss_cross_entropy': Array(1.5070639, dtype=float32)}


  0%|          | 4388/1000000 [14:26<36:45:20,  7.52it/s]

{'loss': Array(1.517135, dtype=float32), 'loss_reward': Array(0.01076972, dtype=float32), 'loss_cross_entropy': Array(1.5063655, dtype=float32)}


  0%|          | 4399/1000000 [14:28<31:20:42,  8.82it/s]

{'loss': Array(1.5196129, dtype=float32), 'loss_reward': Array(0.01121431, dtype=float32), 'loss_cross_entropy': Array(1.5083988, dtype=float32)}


  0%|          | 4408/1000000 [14:29<38:14:30,  7.23it/s]

{'loss': Array(1.5183606, dtype=float32), 'loss_reward': Array(0.01118054, dtype=float32), 'loss_cross_entropy': Array(1.5071801, dtype=float32)}


  0%|          | 4418/1000000 [14:31<32:58:22,  8.39it/s]

{'loss': Array(1.5184861, dtype=float32), 'loss_reward': Array(0.01062474, dtype=float32), 'loss_cross_entropy': Array(1.5078614, dtype=float32)}


  0%|          | 4430/1000000 [14:33<28:34:52,  9.68it/s]

{'loss': Array(1.5195524, dtype=float32), 'loss_reward': Array(0.01084934, dtype=float32), 'loss_cross_entropy': Array(1.5087029, dtype=float32)}


  0%|          | 4440/1000000 [14:35<40:36:44,  6.81it/s]

{'loss': Array(1.5169755, dtype=float32), 'loss_reward': Array(0.01116533, dtype=float32), 'loss_cross_entropy': Array(1.5058101, dtype=float32)}


  0%|          | 4448/1000000 [14:36<38:28:18,  7.19it/s]

{'loss': Array(1.5170093, dtype=float32), 'loss_reward': Array(0.01073437, dtype=float32), 'loss_cross_entropy': Array(1.5062749, dtype=float32)}


  0%|          | 4460/1000000 [14:38<30:01:35,  9.21it/s]

{'loss': Array(1.5182505, dtype=float32), 'loss_reward': Array(0.01105506, dtype=float32), 'loss_cross_entropy': Array(1.5071954, dtype=float32)}


  0%|          | 4469/1000000 [14:39<31:59:17,  8.64it/s]

{'loss': Array(1.5192853, dtype=float32), 'loss_reward': Array(0.01052716, dtype=float32), 'loss_cross_entropy': Array(1.5087581, dtype=float32)}


  0%|          | 4479/1000000 [14:41<37:09:27,  7.44it/s]

{'loss': Array(1.5153595, dtype=float32), 'loss_reward': Array(0.01065153, dtype=float32), 'loss_cross_entropy': Array(1.5047079, dtype=float32)}


  0%|          | 4490/1000000 [14:43<30:43:17,  9.00it/s]

{'loss': Array(1.5176523, dtype=float32), 'loss_reward': Array(0.01074389, dtype=float32), 'loss_cross_entropy': Array(1.5069084, dtype=float32)}


  0%|          | 4498/1000000 [14:45<34:54:58,  7.92it/s]

{'loss': Array(1.5169576, dtype=float32), 'loss_reward': Array(0.01039694, dtype=float32), 'loss_cross_entropy': Array(1.5065607, dtype=float32)}


  0%|          | 4508/1000000 [14:57<159:53:25,  1.73it/s]

{'loss': Array(1.5166098, dtype=float32), 'loss_reward': Array(0.01084127, dtype=float32), 'loss_cross_entropy': Array(1.5057687, dtype=float32)}


  0%|          | 4520/1000000 [14:58<54:26:41,  5.08it/s] 

{'loss': Array(1.516265, dtype=float32), 'loss_reward': Array(0.01065143, dtype=float32), 'loss_cross_entropy': Array(1.5056137, dtype=float32)}


  0%|          | 4529/1000000 [15:00<39:58:01,  6.92it/s]

{'loss': Array(1.5172361, dtype=float32), 'loss_reward': Array(0.01090369, dtype=float32), 'loss_cross_entropy': Array(1.5063322, dtype=float32)}


  0%|          | 4540/1000000 [15:02<31:00:13,  8.92it/s]

{'loss': Array(1.5173538, dtype=float32), 'loss_reward': Array(0.01045552, dtype=float32), 'loss_cross_entropy': Array(1.5068984, dtype=float32)}


  0%|          | 4548/1000000 [15:03<40:16:42,  6.87it/s]

{'loss': Array(1.5171152, dtype=float32), 'loss_reward': Array(0.01059301, dtype=float32), 'loss_cross_entropy': Array(1.5065222, dtype=float32)}


  0%|          | 4560/1000000 [15:05<30:40:59,  9.01it/s]

{'loss': Array(1.5180067, dtype=float32), 'loss_reward': Array(0.01063436, dtype=float32), 'loss_cross_entropy': Array(1.5073723, dtype=float32)}


  0%|          | 4570/1000000 [15:07<30:30:44,  9.06it/s]

{'loss': Array(1.5173066, dtype=float32), 'loss_reward': Array(0.01081401, dtype=float32), 'loss_cross_entropy': Array(1.5064925, dtype=float32)}


  0%|          | 4580/1000000 [15:09<36:48:08,  7.51it/s]

{'loss': Array(1.5161641, dtype=float32), 'loss_reward': Array(0.0109688, dtype=float32), 'loss_cross_entropy': Array(1.5051953, dtype=float32)}


  0%|          | 4590/1000000 [15:10<31:56:10,  8.66it/s]

{'loss': Array(1.5170959, dtype=float32), 'loss_reward': Array(0.01053238, dtype=float32), 'loss_cross_entropy': Array(1.5065638, dtype=float32)}


  0%|          | 4600/1000000 [15:12<30:20:26,  9.11it/s]

{'loss': Array(1.5162592, dtype=float32), 'loss_reward': Array(0.01070136, dtype=float32), 'loss_cross_entropy': Array(1.5055579, dtype=float32)}


  0%|          | 4609/1000000 [15:14<32:45:32,  8.44it/s]

{'loss': Array(1.5188649, dtype=float32), 'loss_reward': Array(0.01044799, dtype=float32), 'loss_cross_entropy': Array(1.5084169, dtype=float32)}


  0%|          | 4619/1000000 [15:15<33:28:49,  8.26it/s]

{'loss': Array(1.5158473, dtype=float32), 'loss_reward': Array(0.01075681, dtype=float32), 'loss_cross_entropy': Array(1.5050905, dtype=float32)}


  0%|          | 4629/1000000 [15:17<31:42:08,  8.72it/s]

{'loss': Array(1.51762, dtype=float32), 'loss_reward': Array(0.01080615, dtype=float32), 'loss_cross_entropy': Array(1.5068139, dtype=float32)}


  0%|          | 4638/1000000 [15:19<33:30:11,  8.25it/s]

{'loss': Array(1.5179697, dtype=float32), 'loss_reward': Array(0.01074161, dtype=float32), 'loss_cross_entropy': Array(1.5072281, dtype=float32)}


  0%|          | 4648/1000000 [15:21<41:09:50,  6.72it/s]

{'loss': Array(1.5185719, dtype=float32), 'loss_reward': Array(0.01065792, dtype=float32), 'loss_cross_entropy': Array(1.5079138, dtype=float32)}


  0%|          | 4659/1000000 [15:23<33:09:17,  8.34it/s]

{'loss': Array(1.5161844, dtype=float32), 'loss_reward': Array(0.01066416, dtype=float32), 'loss_cross_entropy': Array(1.5055202, dtype=float32)}


  0%|          | 4669/1000000 [15:24<30:50:01,  8.97it/s]

{'loss': Array(1.5178117, dtype=float32), 'loss_reward': Array(0.01126919, dtype=float32), 'loss_cross_entropy': Array(1.5065424, dtype=float32)}


  0%|          | 4680/1000000 [15:26<38:31:07,  7.18it/s]

{'loss': Array(1.517866, dtype=float32), 'loss_reward': Array(0.0104915, dtype=float32), 'loss_cross_entropy': Array(1.5073744, dtype=float32)}


  0%|          | 4690/1000000 [15:28<33:03:37,  8.36it/s]

{'loss': Array(1.5181276, dtype=float32), 'loss_reward': Array(0.01035448, dtype=float32), 'loss_cross_entropy': Array(1.5077732, dtype=float32)}


  0%|          | 4698/1000000 [15:29<37:10:22,  7.44it/s]

{'loss': Array(1.5161695, dtype=float32), 'loss_reward': Array(0.01015963, dtype=float32), 'loss_cross_entropy': Array(1.5060099, dtype=float32)}


  0%|          | 4709/1000000 [15:31<30:44:25,  8.99it/s]

{'loss': Array(1.5171137, dtype=float32), 'loss_reward': Array(0.01040392, dtype=float32), 'loss_cross_entropy': Array(1.5067098, dtype=float32)}


  0%|          | 4719/1000000 [15:33<36:13:57,  7.63it/s]

{'loss': Array(1.5154796, dtype=float32), 'loss_reward': Array(0.01049988, dtype=float32), 'loss_cross_entropy': Array(1.5049797, dtype=float32)}


  0%|          | 4729/1000000 [15:35<30:44:20,  8.99it/s]

{'loss': Array(1.5151486, dtype=float32), 'loss_reward': Array(0.01054884, dtype=float32), 'loss_cross_entropy': Array(1.5045998, dtype=float32)}


  0%|          | 4740/1000000 [15:36<28:34:02,  9.68it/s]

{'loss': Array(1.5153672, dtype=float32), 'loss_reward': Array(0.01040323, dtype=float32), 'loss_cross_entropy': Array(1.5049641, dtype=float32)}


  0%|          | 4750/1000000 [15:38<40:05:51,  6.89it/s]

{'loss': Array(1.5165906, dtype=float32), 'loss_reward': Array(0.01047701, dtype=float32), 'loss_cross_entropy': Array(1.5061135, dtype=float32)}


  0%|          | 4758/1000000 [15:40<39:12:33,  7.05it/s]

{'loss': Array(1.5167884, dtype=float32), 'loss_reward': Array(0.01047502, dtype=float32), 'loss_cross_entropy': Array(1.5063132, dtype=float32)}


  0%|          | 4770/1000000 [15:41<30:23:54,  9.09it/s]

{'loss': Array(1.5133654, dtype=float32), 'loss_reward': Array(0.01023234, dtype=float32), 'loss_cross_entropy': Array(1.503133, dtype=float32)}


  0%|          | 4779/1000000 [15:43<31:14:47,  8.85it/s]

{'loss': Array(1.5171638, dtype=float32), 'loss_reward': Array(0.01062746, dtype=float32), 'loss_cross_entropy': Array(1.5065364, dtype=float32)}


  0%|          | 4789/1000000 [15:45<35:09:44,  7.86it/s]

{'loss': Array(1.51386, dtype=float32), 'loss_reward': Array(0.01051389, dtype=float32), 'loss_cross_entropy': Array(1.5033462, dtype=float32)}


  0%|          | 4800/1000000 [15:47<28:36:11,  9.66it/s]

{'loss': Array(1.5149761, dtype=float32), 'loss_reward': Array(0.00995556, dtype=float32), 'loss_cross_entropy': Array(1.5050206, dtype=float32)}


  0%|          | 4808/1000000 [15:48<34:56:15,  7.91it/s]

{'loss': Array(1.515304, dtype=float32), 'loss_reward': Array(0.0099713, dtype=float32), 'loss_cross_entropy': Array(1.5053325, dtype=float32)}


  0%|          | 4820/1000000 [15:50<35:11:27,  7.86it/s]

{'loss': Array(1.5168628, dtype=float32), 'loss_reward': Array(0.01034948, dtype=float32), 'loss_cross_entropy': Array(1.5065132, dtype=float32)}


  0%|          | 4828/1000000 [15:52<37:00:42,  7.47it/s]

{'loss': Array(1.5161012, dtype=float32), 'loss_reward': Array(0.01083527, dtype=float32), 'loss_cross_entropy': Array(1.5052662, dtype=float32)}


  0%|          | 4838/1000000 [15:53<32:54:13,  8.40it/s]

{'loss': Array(1.5145814, dtype=float32), 'loss_reward': Array(0.01066333, dtype=float32), 'loss_cross_entropy': Array(1.503918, dtype=float32)}


  0%|          | 4850/1000000 [15:55<26:30:38, 10.43it/s]

{'loss': Array(1.5141877, dtype=float32), 'loss_reward': Array(0.01019926, dtype=float32), 'loss_cross_entropy': Array(1.5039884, dtype=float32)}


  0%|          | 4859/1000000 [15:57<35:23:00,  7.81it/s]

{'loss': Array(1.514783, dtype=float32), 'loss_reward': Array(0.01013128, dtype=float32), 'loss_cross_entropy': Array(1.5046519, dtype=float32)}


  0%|          | 4870/1000000 [15:59<29:55:21,  9.24it/s]

{'loss': Array(1.5139518, dtype=float32), 'loss_reward': Array(0.01023529, dtype=float32), 'loss_cross_entropy': Array(1.5037167, dtype=float32)}


  0%|          | 4880/1000000 [16:00<29:26:03,  9.39it/s]

{'loss': Array(1.5132712, dtype=float32), 'loss_reward': Array(0.01003576, dtype=float32), 'loss_cross_entropy': Array(1.5032353, dtype=float32)}


  0%|          | 4890/1000000 [16:02<36:50:38,  7.50it/s]

{'loss': Array(1.5151373, dtype=float32), 'loss_reward': Array(0.01007032, dtype=float32), 'loss_cross_entropy': Array(1.5050671, dtype=float32)}


  0%|          | 4899/1000000 [16:04<32:48:57,  8.42it/s]

{'loss': Array(1.514764, dtype=float32), 'loss_reward': Array(0.01058231, dtype=float32), 'loss_cross_entropy': Array(1.5041817, dtype=float32)}


  0%|          | 4909/1000000 [16:05<30:05:29,  9.19it/s]

{'loss': Array(1.5146983, dtype=float32), 'loss_reward': Array(0.01014983, dtype=float32), 'loss_cross_entropy': Array(1.5045483, dtype=float32)}


  0%|          | 4920/1000000 [16:07<29:00:38,  9.53it/s]

{'loss': Array(1.5130376, dtype=float32), 'loss_reward': Array(0.01025722, dtype=float32), 'loss_cross_entropy': Array(1.5027803, dtype=float32)}


  0%|          | 4928/1000000 [16:09<39:36:53,  6.98it/s]

{'loss': Array(1.5163653, dtype=float32), 'loss_reward': Array(0.01005944, dtype=float32), 'loss_cross_entropy': Array(1.5063059, dtype=float32)}


  0%|          | 4939/1000000 [16:11<31:57:06,  8.65it/s]

{'loss': Array(1.5114313, dtype=float32), 'loss_reward': Array(0.01015394, dtype=float32), 'loss_cross_entropy': Array(1.5012774, dtype=float32)}


  0%|          | 4943/1000000 [16:12<54:06:47,  5.11it/s]

In [20]:
sample

TrajectoryBufferSample(experience={'action': Array([[[2.56099083e-05, 8.82284610e-08, 2.79197128e-08, ...,
         5.50557211e-10, 2.88656997e-06, 9.99997139e-01],
        [2.47293491e-10, 9.27802224e-10, 1.00000000e+00, ...,
         3.98532691e-04, 9.99601483e-01, 3.32171481e-08],
        [2.09763237e-14, 6.82480550e-06, 1.65303116e-09, ...,
         9.99999642e-01, 6.53804455e-10, 3.60816784e-07],
        ...,
        [1.27481584e-07, 9.99993205e-01, 4.13923908e-06, ...,
         6.99060142e-01, 2.78799862e-01, 2.21399572e-02],
        [1.02648977e-03, 3.72330985e-11, 9.40394122e-04, ...,
         1.11217632e-04, 7.39760087e-14, 9.99888778e-01],
        [1.46892916e-07, 2.50733981e-04, 3.89986113e-03, ...,
         3.50390064e-06, 9.99996543e-01, 2.23333588e-10]],

       [[2.48523785e-10, 3.07286285e-08, 1.07085234e-05, ...,
         9.99786496e-01, 2.13422667e-04, 2.66149840e-08],
        [4.37679797e-01, 9.33327512e-11, 1.15435640e-07, ...,
         2.28540314e-14, 7.32413685e-1

In [15]:
sample = buffer_eval.sample(buffer_list_eval, subkey)
sample = reshape_sample(sample)


In [16]:
loss, (loss_crossentropy, loss_reward) = loss_fn_transformer(transformer, sample.experience)


AttributeError: 'RubikTransformer' object has no attribute 'state_mapping'

In [15]:
# save buffer, buffer_list
# in pickle 
import pickle

state_weight = nnx.state(transformer)

In [16]:
state_weight

State({
  'action_mapping': {
    'bias': VariableState(
      type=Param,
      value=Array([ 0.0489077 ,  0.05748522, -0.04591497,  0.06825393, -0.01245873,
             -0.04504352, -0.02600798, -0.01612351,  0.04197448, -0.01106876,
              0.04133981,  0.00250468, -0.01917384,  0.08190358, -0.02015413,
              0.01493297,  0.07010985,  0.01872672,  0.00455121, -0.07804424,
             -0.02749256,  0.07241128, -0.01716232, -0.00095554, -0.07598419,
              0.07997464,  0.00542156,  0.0525552 , -0.02168683,  0.00650362,
             -0.05971894, -0.03304074,  0.01576223, -0.0412701 ,  0.07530559,
             -0.00969784, -0.03476293, -0.02465479,  0.01211519,  0.04951988,
              0.0479618 ,  0.02026099,  0.04587942, -0.00183689, -0.0394596 ,
              0.05025787,  0.03260148,  0.05018431,  0.03647119,  0.05557197,
              0.01052153, -0.07949464, -0.0741701 ,  0.01655396, -0.03202853,
             -0.01196375, -0.04284544, -0.04348039,  0.019043

In [17]:
# save state into pickle
with open('state_probainput_vscale2.pickle', 'wb') as handle:
    pickle.dump(state_weight, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [12]:
state_weight

State({})

In [13]:
nnx.state(transformer)

State({})

In [14]:

transformer

RubikTransformer()