In [1]:
# chose the current file directory as the working directory
import os
os.chdir("/teamspace/studios/this_studio/rubikscubesolver")

In [2]:
from tqdm import tqdm
import pickle

import wandb  # for logging
import time
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.nnx as nnx

import optax

from rubiktransformer.model_diffusion_dt import RubikDTTransformer, InverseRLModel
import rubiktransformer.dataset as dataset
from rubiktransformer.trainer import reshape_sample

cuda_plugin_extension is not found.


In [3]:
@dataclass
class Config:
    """Configuration class"""

    jax_key: jnp.ndarray = jax.random.PRNGKey(49)
    rngs = nnx.Rngs(48)
    batch_size: int = 128
    lr_1: float = 4e-4
    lr_2: float = 4e-4
    nb_games: int = 128 * 100
    len_seq: int = 32
    nb_step: int = 1000000
    log_every_step: int = 10
    log_eval_every_step: int = 10
    log_policy_reward_every_step: int = 10
    add_data_every_step: int = 500

    save_model_every_step: int = 2000


config = Config()

# init wandb config
user = "forbu14"
project = "RubikTransformer"
display_name = "experiment_" + time.strftime("%Y%m%d-%H%M%S")

wandb.init(entity=user, project=project, name=display_name)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mforbu14[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [27]:
transformer = RubikDTTransformer(rngs=config.rngs, causal=True)

inverse_rl_model = InverseRLModel(
    dim_input_state = 6 * 6 * 3 * 3,
    dim_output_action = 6 + 3,
    dim_middle = 1024,
    nb_layers = 3,
    rngs=config.rngs,
)

scheduler = optax.linear_schedule(init_value=0.0, end_value=1.0, transition_steps=4000)

# init optimizer
optimizer_dd = optax.chain(
    optax.clip_by_global_norm(1.0),
    #optax.lion(config.lr_1 / 10.0),
    optax.adamw(config.lr_1),
    optax.scale_by_schedule(scheduler),
)

optimizer_rl_inverse = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(config.lr_2),
)

optimizer_diffuser = nnx.Optimizer(transformer, optimizer_dd)
optimizer_inverse = nnx.Optimizer(inverse_rl_model, optimizer_rl_inverse)

# metrics
metrics_train = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
    loss_cross_entropy=nnx.metrics.Average("loss_cross_entropy"),
)

metrics_eval = nnx.MultiMetric(
    loss_eval=nnx.metrics.Average("loss_eval"),

    loss_cross_entropy_eval=nnx.metrics.Average("loss_cross_entropy_eval"),
)

# metric for inverse model
metrics_inverse = nnx.MultiMetric(
    loss_inverse=nnx.metrics.Average("loss_inverse"),
)

In [28]:
import pickle

filename = "state_ddt_model_v1.pickle"

with open(filename, "rb") as input_file:
    state = pickle.load(input_file)

nnx.update(transformer, state)

In [29]:
# gather data from the environment
# init models and optimizers
env, buffer = dataset.init_env_buffer(sample_batch_size=config.batch_size)
env, buffer_eval = dataset.init_env_buffer(sample_batch_size=config.batch_size)


nb_games = config.nb_games
len_seq = config.len_seq

state_first = jnp.zeros((6, 3, 3))
state_next = jnp.zeros((len_seq, 6, 3, 3))
action = jnp.zeros((len_seq, 3))
action_proba = jnp.zeros((len_seq, 9))

# transform state to int8 type
state_first = state_first.astype(jnp.int8)
state_next = state_next.astype(jnp.int8)

# action to int32 type
action = action.astype(jnp.int32)

reward = jnp.zeros((1))

jit_step = jax.jit(env.step)

buffer_list = buffer.init(
    {
        "action": action,
        "reward": reward,
        "state_histo": state_next,
    }
)

buffer_list_eval = buffer_eval.init(
    {
        "action": action,
        "reward": reward,
        "state_histo": state_next,
    }
)

In [30]:
def step_fn(state, key):
    """
    Simple step function
    We choose a random action
    """

    action = jax.random.randint(
        key=key,
        minval=env.action_spec.minimum,
        maxval=env.action_spec.maximum,
        shape=(3,),
    )

    new_state, timestep = jit_step(state, action)
    timestep.extras["action"] = action

    return new_state, timestep


def run_n_steps(state, key, n):
    random_keys = jax.random.split(key, n)
    state, rollout = jax.lax.scan(step_fn, state, random_keys)

    return rollout


vmap_reset = jax.vmap(jax.jit(env.reset))
vmap_step = jax.vmap(run_n_steps, in_axes=(0, 0, None))

In [31]:
nnx.display(transformer)

In [32]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer, buffer_list = dataset.fast_gathering_data_diffusion(
    env,
    vmap_reset,
    vmap_step,
    int(config.nb_games / 10),
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

In [33]:
sample = buffer.sample(buffer_list, subkey)

def reshape_diffusion_setup(sample, key=jax.random.PRNGKey(0)):
    sample.experience["state_histo"] = sample.experience["state_histo"].reshape(
        (sample.experience["state_histo"].shape[0], sample.experience["state_histo"].shape[1], 54)
    )

    # one hot encoding for state_histo
    sample.experience["state_histo"] = jax.nn.one_hot(
        sample.experience["state_histo"],
        num_classes=6,
        axis=-1,
    )

    # batch creation
    batch  = sample.experience
    len_seq = batch["state_histo"].shape[1]

    time_step = jax.random.uniform(
        key, (batch["state_histo"].shape[0], 1, 1, 1)
    ) # random value between 0 and 1

    batch['time_step'] = time_step

    # now contact the value to have the context for the rectified flow setup
    batch["context"] = jnp.concatenate([batch["reward"], time_step[:, :, 0, 0]], axis=1)

    batch["state_past"] = batch["state_histo"][:, :len_seq//4, :, :]
    batch["state_future"] = batch["state_histo"][:, len_seq//4:, :, :]

    # now we generate the random noise for the rectified flow setup
    simplex_noise = jax.random.dirichlet(key, jnp.ones(6), batch["state_future"].shape[:-1])

    batch["state_future_noise"] = (
        (1 - time_step) * simplex_noise + time_step * batch["state_future"]
    )

    batch["action_inverse"] = sample.experience["action"][:, 1:, :]

    # flatten the action_inverse to only have batch data
    batch["action_inverse"] = jnp.reshape(batch["action_inverse"], (batch["action_inverse"].shape[0] * batch["action_inverse"].shape[1], -1))

    # now we can one hot encode the action_inverse

    action_inverse_0 = jax.nn.one_hot(batch["action_inverse"][:, 0], num_classes=6, axis=-1)
    action_inverse_1 = jax.nn.one_hot(batch["action_inverse"][:, 2], num_classes=3, axis=-1)

    batch["action_inverse"] = jnp.concatenate([action_inverse_0, action_inverse_1], axis=1)

    state_histo_inverse_t = sample.experience["state_histo"][:, :-1, :, :]
    state_histo_inverse_td1 = sample.experience["state_histo"][:, 1:, :, :]

    batch["state_histo_inverse_t"] = state_histo_inverse_t
    batch["state_histo_inverse_td1"] = state_histo_inverse_td1

    # we flatten the two state_histo_inverse
    batch["state_histo_inverse_t"] = jnp.reshape(batch["state_histo_inverse_t"], (batch["state_histo_inverse_t"].shape[0] * batch["state_histo_inverse_t"].shape[1], -1))
    batch["state_histo_inverse_td1"] = jnp.reshape(batch["state_histo_inverse_td1"], (batch["state_histo_inverse_td1"].shape[0] * batch["state_histo_inverse_td1"].shape[1], -1))

    return batch


sample = reshape_diffusion_setup(sample)


(128, 32, 3)

(128, 32, 6, 3, 3)

In [34]:
def loss_fn_transformer_rf(model: RubikDTTransformer, batch):
    # rectified flow setup
    state_past, state_future = model(
        batch["state_past"], batch["state_future_noise"], batch["context"]
    )

    loss_crossentropy = optax.softmax_cross_entropy(
        logits=state_future, labels=batch["state_future"]
    ).mean(axis=[1, 2])

    weight = jnp.clip(1. / (1. - batch["time_step"][:, 0, 0, 0]), min=0.005, max=1.5)

    loss_cross_entropy_weight = loss_crossentropy * weight

    return loss_cross_entropy_weight.mean(), (loss_crossentropy.mean())


@nnx.jit
def train_step_transformer_rf(
    model: RubikDTTransformer,
    optimizer: nnx.Optimizer,
    metrics: nnx.MultiMetric,
    batch,
):
    """Train for a single step."""

    grad_fn = nnx.value_and_grad(loss_fn_transformer_rf, has_aux=True)
    (loss, (loss_crossentropy)), grads = grad_fn(model, batch)
    metrics.update(
        loss=loss, loss_cross_entropy=loss_crossentropy
    )
    optimizer.update(grads)

################# INVERSE RL ####################
def loss_fn_inverse_rl(model: InverseRLModel, batch):
    # rectified flow setup
    action = model(
        batch["state_histo_inverse_t"], batch["state_histo_inverse_td1"]
    )

    loss_crossentropy_0 = optax.softmax_cross_entropy(
        logits=action[:, :6], labels=batch["action_inverse"][:, :6]
    ).mean()

    loss_cross_entropy_1 = optax.softmax_cross_entropy(
        logits=action[:, 6:], labels=batch["action_inverse"][:, 6:]
    ).mean()

    return loss_crossentropy_0 + loss_cross_entropy_1


@nnx.jit
def train_step_inverse_rl(
    model: InverseRLModel,
    optimizer: nnx.Optimizer,
    metrics: nnx.MultiMetric,
    batch,
):
    """Train for a single step."""

    grad_fn = nnx.value_and_grad(loss_fn_inverse_rl)
    loss, grads = grad_fn(model, batch)
    metrics.update(
        loss_inverse=loss
    )
    optimizer.update(grads)

In [38]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer, buffer_list = dataset.fast_gathering_data_diffusion(
    env,
    vmap_reset,
    vmap_step,
    config.nb_games * 15, # old is int(config.nb_games * 10.0),
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

In [39]:


# transformer model calibration
for idx_step in tqdm(range(config.nb_step)):
    # training for world model
    key, subkey = jax.random.split(config.jax_key)
    config.jax_key = key

    if idx_step % config.add_data_every_step == 0:
        buffer, buffer_list = dataset.fast_gathering_data_diffusion(
            env,
            vmap_reset,
            vmap_step,
            int(config.nb_games // 10),
            config.len_seq,
            buffer,
            buffer_list,
            config.jax_key,
        )

    sample = buffer.sample(buffer_list, subkey)
    sample = reshape_diffusion_setup(sample, subkey)

    # we update the policy
    train_step_transformer_rf(
        transformer, optimizer_diffuser, metrics_train, sample
    )

    # train the inverse model
    train_step_inverse_rl(
        inverse_rl_model, optimizer_inverse, metrics_inverse, sample
    )

    if idx_step % config.log_every_step == 0:
        metrics_train_result = metrics_train.compute()
        print(metrics_train_result)

        wandb.log(metrics_train_result, step=idx_step)
        metrics_train.reset()

        metrics_inverse_result = metrics_inverse.compute()
        print(metrics_inverse_result)

        wandb.log(metrics_inverse_result, step=idx_step)
        metrics_inverse.reset()

    if idx_step % config.log_eval_every_step == 0:
        key, subkey = jax.random.split(config.jax_key)
        config.jax_key = key

        buffer_eval, buffer_list_eval = dataset.fast_gathering_data_diffusion(
            env,
            vmap_reset,
            vmap_step,
            int(128),
            config.len_seq,
            buffer_eval,
            buffer_list_eval,
            subkey,
        )

        sample = buffer_eval.sample(buffer_list_eval, subkey)
        sample = reshape_diffusion_setup(sample, subkey)

        loss, (loss_crossentropy) = loss_fn_transformer_rf(
            transformer, sample
        )

        metrics_eval.update(
            loss_eval=loss,
            loss_cross_entropy_eval=loss_crossentropy,
        )
        wandb.log(metrics_eval.compute(), step=idx_step)

        metrics_eval.reset()

    if idx_step % config.save_model_every_step == 0:

        state_weight = nnx.state(transformer)

        with open("state_ddt_model.pickle", "wb") as handle:
            pickle.dump(state_weight, handle, protocol=pickle.HIGHEST_PROTOCOL)

        # save inverse model
        state_weight = nnx.state(inverse_rl_model)

        with open("state_inverse_rl_model.pickle", "wb") as handle:
            pickle.dump(state_weight, handle, protocol=pickle.HIGHEST_PROTOCOL)



{'loss': Array(0.15461496, dtype=float32), 'loss_cross_entropy': Array(0.14609098, dtype=float32)}
{'loss_inverse': Array(0.39211738, dtype=float32)}




{'loss': Array(0.13382857, dtype=float32), 'loss_cross_entropy': Array(0.12704742, dtype=float32)}
{'loss_inverse': Array(0.37786818, dtype=float32)}


  0%|          | 19/1000000 [00:10<43:52:32,  6.33it/s]

{'loss': Array(0.12655406, dtype=float32), 'loss_cross_entropy': Array(0.11972822, dtype=float32)}
{'loss_inverse': Array(0.35755154, dtype=float32)}




{'loss': Array(0.12772916, dtype=float32), 'loss_cross_entropy': Array(0.12123533, dtype=float32)}
{'loss_inverse': Array(0.33645052, dtype=float32)}


  0%|          | 39/1000000 [00:13<39:05:00,  7.11it/s]

{'loss': Array(0.14599803, dtype=float32), 'loss_cross_entropy': Array(0.13886742, dtype=float32)}
{'loss_inverse': Array(0.3199631, dtype=float32)}




{'loss': Array(0.1419088, dtype=float32), 'loss_cross_entropy': Array(0.13468434, dtype=float32)}
{'loss_inverse': Array(0.30310372, dtype=float32)}


  0%|          | 59/1000000 [00:16<28:34:28,  9.72it/s]

{'loss': Array(0.14238988, dtype=float32), 'loss_cross_entropy': Array(0.13478689, dtype=float32)}
{'loss_inverse': Array(0.28719926, dtype=float32)}


  0%|          | 69/1000000 [00:18<48:23:35,  5.74it/s]

{'loss': Array(0.12645268, dtype=float32), 'loss_cross_entropy': Array(0.12010985, dtype=float32)}
{'loss_inverse': Array(0.27153423, dtype=float32)}


  0%|          | 79/1000000 [00:20<31:53:34,  8.71it/s]

{'loss': Array(0.12694852, dtype=float32), 'loss_cross_entropy': Array(0.12042151, dtype=float32)}
{'loss_inverse': Array(0.257411, dtype=float32)}


  0%|          | 89/1000000 [00:21<29:37:05,  9.38it/s]

{'loss': Array(0.12761395, dtype=float32), 'loss_cross_entropy': Array(0.12062955, dtype=float32)}
{'loss_inverse': Array(0.24374273, dtype=float32)}


  0%|          | 99/1000000 [00:23<28:35:57,  9.71it/s]

{'loss': Array(0.13570766, dtype=float32), 'loss_cross_entropy': Array(0.12850901, dtype=float32)}
{'loss_inverse': Array(0.22928198, dtype=float32)}


  0%|          | 109/1000000 [00:25<33:15:35,  8.35it/s]

{'loss': Array(0.13847332, dtype=float32), 'loss_cross_entropy': Array(0.13076948, dtype=float32)}
{'loss_inverse': Array(0.22026172, dtype=float32)}




{'loss': Array(0.1361432, dtype=float32), 'loss_cross_entropy': Array(0.12899071, dtype=float32)}
{'loss_inverse': Array(0.20654252, dtype=float32)}


  0%|          | 129/1000000 [00:28<29:05:11,  9.55it/s]

{'loss': Array(0.14016752, dtype=float32), 'loss_cross_entropy': Array(0.13303308, dtype=float32)}
{'loss_inverse': Array(0.19620974, dtype=float32)}


  0%|          | 139/1000000 [00:30<33:36:07,  8.27it/s]

{'loss': Array(0.11907648, dtype=float32), 'loss_cross_entropy': Array(0.11241098, dtype=float32)}
{'loss_inverse': Array(0.18558194, dtype=float32)}




{'loss': Array(0.13597381, dtype=float32), 'loss_cross_entropy': Array(0.12885323, dtype=float32)}
{'loss_inverse': Array(0.17578025, dtype=float32)}


  0%|          | 159/1000000 [00:33<28:55:26,  9.60it/s]

{'loss': Array(0.126873, dtype=float32), 'loss_cross_entropy': Array(0.11997842, dtype=float32)}
{'loss_inverse': Array(0.16681968, dtype=float32)}


  0%|          | 169/1000000 [00:35<35:21:26,  7.85it/s]

{'loss': Array(0.1336355, dtype=float32), 'loss_cross_entropy': Array(0.12668796, dtype=float32)}
{'loss_inverse': Array(0.15811981, dtype=float32)}




{'loss': Array(0.12374101, dtype=float32), 'loss_cross_entropy': Array(0.11711033, dtype=float32)}
{'loss_inverse': Array(0.15203266, dtype=float32)}


  0%|          | 189/1000000 [00:38<28:22:58,  9.78it/s]

{'loss': Array(0.11756764, dtype=float32), 'loss_cross_entropy': Array(0.11167746, dtype=float32)}
{'loss_inverse': Array(0.14012067, dtype=float32)}




{'loss': Array(0.13815388, dtype=float32), 'loss_cross_entropy': Array(0.13108376, dtype=float32)}
{'loss_inverse': Array(0.13342348, dtype=float32)}


  0%|          | 209/1000000 [00:41<33:45:19,  8.23it/s]

{'loss': Array(0.12165167, dtype=float32), 'loss_cross_entropy': Array(0.11554223, dtype=float32)}
{'loss_inverse': Array(0.12943797, dtype=float32)}




{'loss': Array(0.12544082, dtype=float32), 'loss_cross_entropy': Array(0.11922165, dtype=float32)}
{'loss_inverse': Array(0.11753571, dtype=float32)}


  0%|          | 229/1000000 [00:44<28:27:39,  9.76it/s]

{'loss': Array(0.130931, dtype=float32), 'loss_cross_entropy': Array(0.1242337, dtype=float32)}
{'loss_inverse': Array(0.1127223, dtype=float32)}


  0%|          | 239/1000000 [00:46<43:26:14,  6.39it/s]

{'loss': Array(0.14917068, dtype=float32), 'loss_cross_entropy': Array(0.14184795, dtype=float32)}
{'loss_inverse': Array(0.10624486, dtype=float32)}


  0%|          | 249/1000000 [00:48<31:48:22,  8.73it/s]

{'loss': Array(0.12971804, dtype=float32), 'loss_cross_entropy': Array(0.12245561, dtype=float32)}
{'loss_inverse': Array(0.10117482, dtype=float32)}


  0%|          | 259/1000000 [00:49<29:10:18,  9.52it/s]

{'loss': Array(0.1295768, dtype=float32), 'loss_cross_entropy': Array(0.12327977, dtype=float32)}
{'loss_inverse': Array(0.09558576, dtype=float32)}


  0%|          | 269/1000000 [00:51<43:47:55,  6.34it/s]

{'loss': Array(0.13121776, dtype=float32), 'loss_cross_entropy': Array(0.12457725, dtype=float32)}
{'loss_inverse': Array(0.0914719, dtype=float32)}


  0%|          | 279/1000000 [00:53<30:53:46,  8.99it/s]

{'loss': Array(0.15685098, dtype=float32), 'loss_cross_entropy': Array(0.15011756, dtype=float32)}
{'loss_inverse': Array(0.08797484, dtype=float32)}


  0%|          | 289/1000000 [00:54<28:43:58,  9.66it/s]

{'loss': Array(0.12183086, dtype=float32), 'loss_cross_entropy': Array(0.11517253, dtype=float32)}
{'loss_inverse': Array(0.07979413, dtype=float32)}


  0%|          | 299/1000000 [00:56<28:38:35,  9.69it/s]

{'loss': Array(0.14052002, dtype=float32), 'loss_cross_entropy': Array(0.13386753, dtype=float32)}
{'loss_inverse': Array(0.07695884, dtype=float32)}


  0%|          | 309/1000000 [00:58<49:21:33,  5.63it/s]

{'loss': Array(0.13415633, dtype=float32), 'loss_cross_entropy': Array(0.127549, dtype=float32)}
{'loss_inverse': Array(0.07476064, dtype=float32)}


  0%|          | 319/1000000 [00:59<32:03:43,  8.66it/s]

{'loss': Array(0.1265234, dtype=float32), 'loss_cross_entropy': Array(0.11998122, dtype=float32)}
{'loss_inverse': Array(0.06793172, dtype=float32)}


  0%|          | 329/1000000 [01:01<29:01:22,  9.57it/s]

{'loss': Array(0.13664609, dtype=float32), 'loss_cross_entropy': Array(0.12952438, dtype=float32)}
{'loss_inverse': Array(0.06430935, dtype=float32)}


  0%|          | 339/1000000 [01:02<28:19:20,  9.80it/s]

{'loss': Array(0.134382, dtype=float32), 'loss_cross_entropy': Array(0.12765907, dtype=float32)}
{'loss_inverse': Array(0.0603136, dtype=float32)}


  0%|          | 349/1000000 [01:04<42:38:43,  6.51it/s]

{'loss': Array(0.14205173, dtype=float32), 'loss_cross_entropy': Array(0.13461228, dtype=float32)}
{'loss_inverse': Array(0.05632284, dtype=float32)}


  0%|          | 359/1000000 [01:05<31:24:14,  8.84it/s]

{'loss': Array(0.13634533, dtype=float32), 'loss_cross_entropy': Array(0.12874728, dtype=float32)}
{'loss_inverse': Array(0.05672811, dtype=float32)}


  0%|          | 369/1000000 [01:07<28:35:18,  9.71it/s]

{'loss': Array(0.12931442, dtype=float32), 'loss_cross_entropy': Array(0.12260609, dtype=float32)}
{'loss_inverse': Array(0.05056208, dtype=float32)}


  0%|          | 379/1000000 [01:08<28:37:16,  9.70it/s]

{'loss': Array(0.1225698, dtype=float32), 'loss_cross_entropy': Array(0.1162341, dtype=float32)}
{'loss_inverse': Array(0.05149758, dtype=float32)}


  0%|          | 389/1000000 [01:10<37:00:04,  7.50it/s]

{'loss': Array(0.12145109, dtype=float32), 'loss_cross_entropy': Array(0.11449033, dtype=float32)}
{'loss_inverse': Array(0.04644219, dtype=float32)}


  0%|          | 399/1000000 [01:12<29:49:47,  9.31it/s]

{'loss': Array(0.10768602, dtype=float32), 'loss_cross_entropy': Array(0.10200881, dtype=float32)}
{'loss_inverse': Array(0.04402, dtype=float32)}


  0%|          | 409/1000000 [01:13<29:08:48,  9.53it/s]

{'loss': Array(0.12651625, dtype=float32), 'loss_cross_entropy': Array(0.11983191, dtype=float32)}
{'loss_inverse': Array(0.0417438, dtype=float32)}


  0%|          | 419/1000000 [01:15<43:55:24,  6.32it/s]

{'loss': Array(0.13844495, dtype=float32), 'loss_cross_entropy': Array(0.13134657, dtype=float32)}
{'loss_inverse': Array(0.03997377, dtype=float32)}


  0%|          | 429/1000000 [01:17<30:13:47,  9.18it/s]

{'loss': Array(0.13533534, dtype=float32), 'loss_cross_entropy': Array(0.12818274, dtype=float32)}
{'loss_inverse': Array(0.03722059, dtype=float32)}


  0%|          | 439/1000000 [01:18<28:47:07,  9.65it/s]

{'loss': Array(0.12485391, dtype=float32), 'loss_cross_entropy': Array(0.11831007, dtype=float32)}
{'loss_inverse': Array(0.03316877, dtype=float32)}


  0%|          | 449/1000000 [01:20<30:08:11,  9.21it/s]

{'loss': Array(0.1283969, dtype=float32), 'loss_cross_entropy': Array(0.12165745, dtype=float32)}
{'loss_inverse': Array(0.03308463, dtype=float32)}


  0%|          | 459/1000000 [01:22<39:24:04,  7.05it/s]

{'loss': Array(0.12999693, dtype=float32), 'loss_cross_entropy': Array(0.12268096, dtype=float32)}
{'loss_inverse': Array(0.02478969, dtype=float32)}


  0%|          | 469/1000000 [01:23<30:54:43,  8.98it/s]

{'loss': Array(0.14236252, dtype=float32), 'loss_cross_entropy': Array(0.13509165, dtype=float32)}
{'loss_inverse': Array(0.02382129, dtype=float32)}


  0%|          | 479/1000000 [01:25<29:46:05,  9.33it/s]

{'loss': Array(0.14194837, dtype=float32), 'loss_cross_entropy': Array(0.13511798, dtype=float32)}
{'loss_inverse': Array(0.02382745, dtype=float32)}


  0%|          | 489/1000000 [01:26<28:27:55,  9.75it/s]

{'loss': Array(0.13075204, dtype=float32), 'loss_cross_entropy': Array(0.12384278, dtype=float32)}
{'loss_inverse': Array(0.02106911, dtype=float32)}


  0%|          | 499/1000000 [01:28<38:00:41,  7.30it/s]

{'loss': Array(0.11454239, dtype=float32), 'loss_cross_entropy': Array(0.10832538, dtype=float32)}
{'loss_inverse': Array(0.01906163, dtype=float32)}


  0%|          | 509/1000000 [01:36<97:06:17,  2.86it/s] 

{'loss': Array(0.1249349, dtype=float32), 'loss_cross_entropy': Array(0.11890507, dtype=float32)}
{'loss_inverse': Array(0.01862136, dtype=float32)}


  0%|          | 519/1000000 [01:38<38:55:18,  7.13it/s] 

{'loss': Array(0.12260435, dtype=float32), 'loss_cross_entropy': Array(0.11621933, dtype=float32)}
{'loss_inverse': Array(0.01797247, dtype=float32)}


  0%|          | 529/1000000 [01:39<29:59:33,  9.26it/s]

{'loss': Array(0.14227808, dtype=float32), 'loss_cross_entropy': Array(0.13490178, dtype=float32)}
{'loss_inverse': Array(0.01617466, dtype=float32)}


  0%|          | 539/1000000 [01:41<38:50:02,  7.15it/s]

{'loss': Array(0.12488742, dtype=float32), 'loss_cross_entropy': Array(0.1182989, dtype=float32)}
{'loss_inverse': Array(0.01443574, dtype=float32)}


  0%|          | 549/1000000 [01:43<29:48:10,  9.32it/s]

{'loss': Array(0.13510153, dtype=float32), 'loss_cross_entropy': Array(0.12824453, dtype=float32)}
{'loss_inverse': Array(0.01468645, dtype=float32)}


  0%|          | 559/1000000 [01:44<28:40:30,  9.68it/s]

{'loss': Array(0.13589881, dtype=float32), 'loss_cross_entropy': Array(0.12920657, dtype=float32)}
{'loss_inverse': Array(0.01437861, dtype=float32)}


  0%|          | 569/1000000 [01:46<50:00:41,  5.55it/s]

{'loss': Array(0.13484217, dtype=float32), 'loss_cross_entropy': Array(0.1277432, dtype=float32)}
{'loss_inverse': Array(0.01382351, dtype=float32)}


  0%|          | 579/1000000 [01:48<32:10:22,  8.63it/s]

{'loss': Array(0.1271976, dtype=float32), 'loss_cross_entropy': Array(0.12117817, dtype=float32)}
{'loss_inverse': Array(0.01267937, dtype=float32)}


  0%|          | 589/1000000 [01:49<28:32:47,  9.72it/s]

{'loss': Array(0.11881521, dtype=float32), 'loss_cross_entropy': Array(0.11257967, dtype=float32)}
{'loss_inverse': Array(0.01111125, dtype=float32)}


  0%|          | 599/1000000 [01:51<29:00:18,  9.57it/s]

{'loss': Array(0.14442192, dtype=float32), 'loss_cross_entropy': Array(0.13756941, dtype=float32)}
{'loss_inverse': Array(0.01163621, dtype=float32)}


  0%|          | 609/1000000 [01:53<34:56:22,  7.95it/s]

{'loss': Array(0.12704663, dtype=float32), 'loss_cross_entropy': Array(0.12122308, dtype=float32)}
{'loss_inverse': Array(0.01138796, dtype=float32)}


  0%|          | 619/1000000 [01:54<29:04:57,  9.55it/s]

{'loss': Array(0.13615577, dtype=float32), 'loss_cross_entropy': Array(0.1296629, dtype=float32)}
{'loss_inverse': Array(0.00998064, dtype=float32)}


  0%|          | 629/1000000 [01:56<29:22:46,  9.45it/s]

{'loss': Array(0.12741986, dtype=float32), 'loss_cross_entropy': Array(0.12077361, dtype=float32)}
{'loss_inverse': Array(0.01026065, dtype=float32)}


  0%|          | 639/1000000 [01:57<42:25:25,  6.54it/s]

{'loss': Array(0.12137487, dtype=float32), 'loss_cross_entropy': Array(0.11493187, dtype=float32)}
{'loss_inverse': Array(0.01143807, dtype=float32)}


  0%|          | 649/1000000 [01:59<30:04:18,  9.23it/s]

{'loss': Array(0.1273386, dtype=float32), 'loss_cross_entropy': Array(0.12060974, dtype=float32)}
{'loss_inverse': Array(0.01050278, dtype=float32)}


  0%|          | 659/1000000 [02:00<29:26:00,  9.43it/s]

{'loss': Array(0.12134924, dtype=float32), 'loss_cross_entropy': Array(0.11498404, dtype=float32)}
{'loss_inverse': Array(0.01097289, dtype=float32)}


  0%|          | 669/1000000 [02:02<48:41:38,  5.70it/s]

{'loss': Array(0.12723367, dtype=float32), 'loss_cross_entropy': Array(0.12070673, dtype=float32)}
{'loss_inverse': Array(0.0089547, dtype=float32)}


  0%|          | 679/1000000 [02:04<31:30:23,  8.81it/s]

{'loss': Array(0.12044891, dtype=float32), 'loss_cross_entropy': Array(0.11441912, dtype=float32)}
{'loss_inverse': Array(0.00805375, dtype=float32)}


  0%|          | 690/1000000 [02:05<26:58:00, 10.29it/s]

{'loss': Array(0.11808826, dtype=float32), 'loss_cross_entropy': Array(0.11175104, dtype=float32)}
{'loss_inverse': Array(0.00778345, dtype=float32)}


  0%|          | 700/1000000 [02:07<27:37:47, 10.05it/s]

{'loss': Array(0.12420559, dtype=float32), 'loss_cross_entropy': Array(0.11815748, dtype=float32)}
{'loss_inverse': Array(0.00780144, dtype=float32)}


  0%|          | 710/1000000 [02:08<27:54:49,  9.94it/s]

{'loss': Array(0.1331269, dtype=float32), 'loss_cross_entropy': Array(0.1262717, dtype=float32)}
{'loss_inverse': Array(0.0071905, dtype=float32)}


  0%|          | 720/1000000 [02:10<36:08:47,  7.68it/s]

{'loss': Array(0.11213092, dtype=float32), 'loss_cross_entropy': Array(0.10634883, dtype=float32)}
{'loss_inverse': Array(0.00678427, dtype=float32)}


  0%|          | 730/1000000 [02:12<29:10:46,  9.51it/s]

{'loss': Array(0.13767236, dtype=float32), 'loss_cross_entropy': Array(0.13047192, dtype=float32)}
{'loss_inverse': Array(0.0068629, dtype=float32)}


  0%|          | 740/1000000 [02:13<28:13:20,  9.84it/s]

{'loss': Array(0.11598795, dtype=float32), 'loss_cross_entropy': Array(0.10944493, dtype=float32)}
{'loss_inverse': Array(0.00629597, dtype=float32)}


  0%|          | 750/1000000 [02:15<27:43:45, 10.01it/s]

{'loss': Array(0.11867534, dtype=float32), 'loss_cross_entropy': Array(0.11220054, dtype=float32)}
{'loss_inverse': Array(0.00654362, dtype=float32)}


  0%|          | 760/1000000 [02:17<33:38:12,  8.25it/s]

{'loss': Array(0.14284827, dtype=float32), 'loss_cross_entropy': Array(0.13609968, dtype=float32)}
{'loss_inverse': Array(0.00725114, dtype=float32)}


  0%|          | 770/1000000 [02:18<29:03:41,  9.55it/s]

{'loss': Array(0.13000384, dtype=float32), 'loss_cross_entropy': Array(0.12358674, dtype=float32)}
{'loss_inverse': Array(0.00694343, dtype=float32)}


  0%|          | 780/1000000 [02:20<28:06:13,  9.88it/s]

{'loss': Array(0.1481434, dtype=float32), 'loss_cross_entropy': Array(0.14105926, dtype=float32)}
{'loss_inverse': Array(0.00704452, dtype=float32)}


  0%|          | 790/1000000 [02:22<44:33:25,  6.23it/s]

{'loss': Array(0.08993089, dtype=float32), 'loss_cross_entropy': Array(0.08504378, dtype=float32)}
{'loss_inverse': Array(0.0053534, dtype=float32)}


  0%|          | 800/1000000 [02:23<31:25:04,  8.83it/s]

{'loss': Array(0.12313833, dtype=float32), 'loss_cross_entropy': Array(0.11692175, dtype=float32)}
{'loss_inverse': Array(0.00564529, dtype=float32)}


  0%|          | 810/1000000 [02:25<28:50:45,  9.62it/s]

{'loss': Array(0.13347386, dtype=float32), 'loss_cross_entropy': Array(0.12682898, dtype=float32)}
{'loss_inverse': Array(0.00508916, dtype=float32)}


  0%|          | 820/1000000 [02:26<29:25:42,  9.43it/s]

{'loss': Array(0.13070369, dtype=float32), 'loss_cross_entropy': Array(0.1242164, dtype=float32)}
{'loss_inverse': Array(0.00535275, dtype=float32)}


  0%|          | 829/1000000 [02:28<37:03:21,  7.49it/s]

{'loss': Array(0.13230158, dtype=float32), 'loss_cross_entropy': Array(0.12572257, dtype=float32)}
{'loss_inverse': Array(0.0057195, dtype=float32)}


  0%|          | 839/1000000 [02:30<30:06:31,  9.22it/s]

{'loss': Array(0.12714803, dtype=float32), 'loss_cross_entropy': Array(0.1205663, dtype=float32)}
{'loss_inverse': Array(0.0050381, dtype=float32)}


  0%|          | 849/1000000 [02:31<29:20:37,  9.46it/s]

{'loss': Array(0.12256064, dtype=float32), 'loss_cross_entropy': Array(0.11654279, dtype=float32)}
{'loss_inverse': Array(0.00547542, dtype=float32)}


  0%|          | 859/1000000 [02:33<43:32:07,  6.38it/s]

{'loss': Array(0.12328092, dtype=float32), 'loss_cross_entropy': Array(0.11670341, dtype=float32)}
{'loss_inverse': Array(0.00582399, dtype=float32)}


  0%|          | 869/1000000 [02:35<31:02:14,  8.94it/s]

{'loss': Array(0.13381974, dtype=float32), 'loss_cross_entropy': Array(0.12724632, dtype=float32)}
{'loss_inverse': Array(0.00485764, dtype=float32)}


  0%|          | 879/1000000 [02:36<28:47:01,  9.64it/s]

{'loss': Array(0.12417831, dtype=float32), 'loss_cross_entropy': Array(0.11838746, dtype=float32)}
{'loss_inverse': Array(0.00467128, dtype=float32)}


  0%|          | 889/1000000 [02:38<28:11:29,  9.84it/s]

{'loss': Array(0.1306154, dtype=float32), 'loss_cross_entropy': Array(0.12402103, dtype=float32)}
{'loss_inverse': Array(0.00458535, dtype=float32)}


  0%|          | 899/1000000 [02:39<38:36:13,  7.19it/s]

{'loss': Array(0.11696763, dtype=float32), 'loss_cross_entropy': Array(0.11037413, dtype=float32)}
{'loss_inverse': Array(0.00461166, dtype=float32)}


  0%|          | 909/1000000 [02:41<30:19:26,  9.15it/s]

{'loss': Array(0.1205241, dtype=float32), 'loss_cross_entropy': Array(0.11429077, dtype=float32)}
{'loss_inverse': Array(0.00473827, dtype=float32)}


  0%|          | 919/1000000 [02:42<28:23:37,  9.77it/s]

{'loss': Array(0.12778044, dtype=float32), 'loss_cross_entropy': Array(0.12138664, dtype=float32)}
{'loss_inverse': Array(0.00401734, dtype=float32)}


  0%|          | 929/1000000 [02:44<28:17:25,  9.81it/s]

{'loss': Array(0.11271645, dtype=float32), 'loss_cross_entropy': Array(0.10668581, dtype=float32)}
{'loss_inverse': Array(0.00392904, dtype=float32)}


  0%|          | 939/1000000 [02:46<38:30:34,  7.21it/s]

{'loss': Array(0.12896603, dtype=float32), 'loss_cross_entropy': Array(0.12286883, dtype=float32)}
{'loss_inverse': Array(0.00364122, dtype=float32)}


  0%|          | 949/1000000 [02:47<29:35:40,  9.38it/s]

{'loss': Array(0.12413516, dtype=float32), 'loss_cross_entropy': Array(0.11783488, dtype=float32)}
{'loss_inverse': Array(0.00386887, dtype=float32)}


  0%|          | 959/1000000 [02:49<28:19:38,  9.80it/s]

{'loss': Array(0.13587241, dtype=float32), 'loss_cross_entropy': Array(0.1290578, dtype=float32)}
{'loss_inverse': Array(0.00390088, dtype=float32)}


  0%|          | 969/1000000 [02:50<28:56:58,  9.59it/s]

{'loss': Array(0.11720798, dtype=float32), 'loss_cross_entropy': Array(0.11120278, dtype=float32)}
{'loss_inverse': Array(0.00318461, dtype=float32)}


  0%|          | 980/1000000 [02:52<31:09:12,  8.91it/s]

{'loss': Array(0.10971852, dtype=float32), 'loss_cross_entropy': Array(0.1035999, dtype=float32)}
{'loss_inverse': Array(0.00346536, dtype=float32)}


  0%|          | 990/1000000 [02:54<28:11:35,  9.84it/s]

{'loss': Array(0.11591876, dtype=float32), 'loss_cross_entropy': Array(0.10991492, dtype=float32)}
{'loss_inverse': Array(0.00310038, dtype=float32)}


  0%|          | 1000/1000000 [02:55<28:49:02,  9.63it/s]

{'loss': Array(0.1350899, dtype=float32), 'loss_cross_entropy': Array(0.12881646, dtype=float32)}
{'loss_inverse': Array(0.00293528, dtype=float32)}


  0%|          | 1010/1000000 [03:04<110:03:14,  2.52it/s]

{'loss': Array(0.12348186, dtype=float32), 'loss_cross_entropy': Array(0.11761621, dtype=float32)}
{'loss_inverse': Array(0.00284611, dtype=float32)}


  0%|          | 1020/1000000 [03:06<42:08:35,  6.58it/s] 

{'loss': Array(0.13595384, dtype=float32), 'loss_cross_entropy': Array(0.1291087, dtype=float32)}
{'loss_inverse': Array(0.0031078, dtype=float32)}


  0%|          | 1030/1000000 [03:07<30:07:32,  9.21it/s]

{'loss': Array(0.13315336, dtype=float32), 'loss_cross_entropy': Array(0.12654142, dtype=float32)}
{'loss_inverse': Array(0.00305881, dtype=float32)}


  0%|          | 1040/1000000 [03:08<28:38:00,  9.69it/s]

{'loss': Array(0.13333938, dtype=float32), 'loss_cross_entropy': Array(0.12658164, dtype=float32)}
{'loss_inverse': Array(0.00320852, dtype=float32)}


  0%|          | 1050/1000000 [03:10<33:19:14,  8.33it/s]

{'loss': Array(0.12590209, dtype=float32), 'loss_cross_entropy': Array(0.11973244, dtype=float32)}
{'loss_inverse': Array(0.0027018, dtype=float32)}


  0%|          | 1060/1000000 [03:12<28:41:08,  9.67it/s]

{'loss': Array(0.12421019, dtype=float32), 'loss_cross_entropy': Array(0.11773254, dtype=float32)}
{'loss_inverse': Array(0.00307521, dtype=float32)}


  0%|          | 1070/1000000 [03:13<28:51:21,  9.62it/s]

{'loss': Array(0.13245149, dtype=float32), 'loss_cross_entropy': Array(0.12570432, dtype=float32)}
{'loss_inverse': Array(0.00297492, dtype=float32)}


  0%|          | 1080/1000000 [03:15<38:24:25,  7.22it/s]

{'loss': Array(0.12508915, dtype=float32), 'loss_cross_entropy': Array(0.11882367, dtype=float32)}
{'loss_inverse': Array(0.00281553, dtype=float32)}


  0%|          | 1090/1000000 [03:17<29:13:15,  9.50it/s]

{'loss': Array(0.107086, dtype=float32), 'loss_cross_entropy': Array(0.10154056, dtype=float32)}
{'loss_inverse': Array(0.00297418, dtype=float32)}


  0%|          | 1100/1000000 [03:18<28:11:28,  9.84it/s]

{'loss': Array(0.13243161, dtype=float32), 'loss_cross_entropy': Array(0.12552224, dtype=float32)}
{'loss_inverse': Array(0.00259458, dtype=float32)}


  0%|          | 1110/1000000 [03:20<50:57:43,  5.44it/s]

{'loss': Array(0.13771747, dtype=float32), 'loss_cross_entropy': Array(0.13064332, dtype=float32)}
{'loss_inverse': Array(0.00240435, dtype=float32)}


  0%|          | 1120/1000000 [03:22<31:37:32,  8.77it/s]

{'loss': Array(0.14468382, dtype=float32), 'loss_cross_entropy': Array(0.13740198, dtype=float32)}
{'loss_inverse': Array(0.00282165, dtype=float32)}


  0%|          | 1130/1000000 [03:23<28:57:54,  9.58it/s]

{'loss': Array(0.14295045, dtype=float32), 'loss_cross_entropy': Array(0.13631366, dtype=float32)}
{'loss_inverse': Array(0.00253788, dtype=float32)}


  0%|          | 1140/1000000 [03:25<28:23:52,  9.77it/s]

{'loss': Array(0.14344369, dtype=float32), 'loss_cross_entropy': Array(0.13629226, dtype=float32)}
{'loss_inverse': Array(0.00266564, dtype=float32)}


  0%|          | 1150/1000000 [03:27<36:10:49,  7.67it/s]

{'loss': Array(0.12239143, dtype=float32), 'loss_cross_entropy': Array(0.1163175, dtype=float32)}
{'loss_inverse': Array(0.00228492, dtype=float32)}


  0%|          | 1160/1000000 [03:28<30:06:31,  9.22it/s]

{'loss': Array(0.12334409, dtype=float32), 'loss_cross_entropy': Array(0.11715093, dtype=float32)}
{'loss_inverse': Array(0.00209554, dtype=float32)}


  0%|          | 1170/1000000 [03:30<28:41:29,  9.67it/s]

{'loss': Array(0.11721265, dtype=float32), 'loss_cross_entropy': Array(0.11052202, dtype=float32)}
{'loss_inverse': Array(0.00235293, dtype=float32)}


  0%|          | 1180/1000000 [03:31<30:11:26,  9.19it/s]

{'loss': Array(0.12536903, dtype=float32), 'loss_cross_entropy': Array(0.11883741, dtype=float32)}
{'loss_inverse': Array(0.0023798, dtype=float32)}


  0%|          | 1190/1000000 [03:33<33:31:36,  8.28it/s]

{'loss': Array(0.12963419, dtype=float32), 'loss_cross_entropy': Array(0.12334533, dtype=float32)}
{'loss_inverse': Array(0.00217991, dtype=float32)}


  0%|          | 1200/1000000 [03:35<29:04:32,  9.54it/s]

{'loss': Array(0.11482503, dtype=float32), 'loss_cross_entropy': Array(0.1086024, dtype=float32)}
{'loss_inverse': Array(0.00209911, dtype=float32)}


  0%|          | 1210/1000000 [03:36<28:48:54,  9.63it/s]

{'loss': Array(0.11473546, dtype=float32), 'loss_cross_entropy': Array(0.10846736, dtype=float32)}
{'loss_inverse': Array(0.00231864, dtype=float32)}


  0%|          | 1220/1000000 [03:38<28:15:02,  9.82it/s]

{'loss': Array(0.12571457, dtype=float32), 'loss_cross_entropy': Array(0.11963618, dtype=float32)}
{'loss_inverse': Array(0.00227879, dtype=float32)}


  0%|          | 1230/1000000 [03:40<35:52:27,  7.73it/s]

{'loss': Array(0.1276605, dtype=float32), 'loss_cross_entropy': Array(0.12130415, dtype=float32)}
{'loss_inverse': Array(0.00222015, dtype=float32)}


  0%|          | 1240/1000000 [03:41<30:55:23,  8.97it/s]

{'loss': Array(0.13483079, dtype=float32), 'loss_cross_entropy': Array(0.12848924, dtype=float32)}
{'loss_inverse': Array(0.00209936, dtype=float32)}


  0%|          | 1250/1000000 [03:43<28:26:11,  9.76it/s]

{'loss': Array(0.13948682, dtype=float32), 'loss_cross_entropy': Array(0.13264707, dtype=float32)}
{'loss_inverse': Array(0.00200245, dtype=float32)}


  0%|          | 1260/1000000 [03:44<28:55:43,  9.59it/s]

{'loss': Array(0.13641904, dtype=float32), 'loss_cross_entropy': Array(0.13004296, dtype=float32)}
{'loss_inverse': Array(0.00177406, dtype=float32)}


  0%|          | 1270/1000000 [03:46<33:53:35,  8.19it/s]

{'loss': Array(0.1254205, dtype=float32), 'loss_cross_entropy': Array(0.11920978, dtype=float32)}
{'loss_inverse': Array(0.00220733, dtype=float32)}


  0%|          | 1280/1000000 [03:48<29:45:20,  9.32it/s]

{'loss': Array(0.12693922, dtype=float32), 'loss_cross_entropy': Array(0.1203947, dtype=float32)}
{'loss_inverse': Array(0.00192961, dtype=float32)}


  0%|          | 1290/1000000 [03:49<28:33:22,  9.71it/s]

{'loss': Array(0.11078527, dtype=float32), 'loss_cross_entropy': Array(0.10476196, dtype=float32)}
{'loss_inverse': Array(0.00176618, dtype=float32)}


  0%|          | 1300/1000000 [03:51<38:46:11,  7.16it/s]

{'loss': Array(0.13919288, dtype=float32), 'loss_cross_entropy': Array(0.1321125, dtype=float32)}
{'loss_inverse': Array(0.00164923, dtype=float32)}


  0%|          | 1310/1000000 [03:52<29:44:00,  9.33it/s]

{'loss': Array(0.12644242, dtype=float32), 'loss_cross_entropy': Array(0.12040839, dtype=float32)}
{'loss_inverse': Array(0.00185444, dtype=float32)}


  0%|          | 1320/1000000 [03:54<27:55:29,  9.93it/s]

{'loss': Array(0.13963997, dtype=float32), 'loss_cross_entropy': Array(0.13280149, dtype=float32)}
{'loss_inverse': Array(0.00166765, dtype=float32)}


  0%|          | 1330/1000000 [03:56<43:17:54,  6.41it/s]

{'loss': Array(0.12760165, dtype=float32), 'loss_cross_entropy': Array(0.12063064, dtype=float32)}
{'loss_inverse': Array(0.00177893, dtype=float32)}


  0%|          | 1340/1000000 [03:57<30:38:27,  9.05it/s]

{'loss': Array(0.12272602, dtype=float32), 'loss_cross_entropy': Array(0.11609165, dtype=float32)}
{'loss_inverse': Array(0.00173552, dtype=float32)}


  0%|          | 1350/1000000 [03:59<28:18:09,  9.80it/s]

{'loss': Array(0.12131323, dtype=float32), 'loss_cross_entropy': Array(0.11490365, dtype=float32)}
{'loss_inverse': Array(0.00184197, dtype=float32)}


  0%|          | 1360/1000000 [04:00<28:52:41,  9.61it/s]

{'loss': Array(0.12716249, dtype=float32), 'loss_cross_entropy': Array(0.12058844, dtype=float32)}
{'loss_inverse': Array(0.00168599, dtype=float32)}


  0%|          | 1370/1000000 [04:02<49:08:10,  5.65it/s]

{'loss': Array(0.11648799, dtype=float32), 'loss_cross_entropy': Array(0.11017709, dtype=float32)}
{'loss_inverse': Array(0.00150696, dtype=float32)}


  0%|          | 1380/1000000 [04:04<31:56:11,  8.69it/s]

{'loss': Array(0.1265587, dtype=float32), 'loss_cross_entropy': Array(0.1201761, dtype=float32)}
{'loss_inverse': Array(0.00144549, dtype=float32)}


  0%|          | 1389/1000000 [04:05<30:36:09,  9.06it/s]

{'loss': Array(0.13589838, dtype=float32), 'loss_cross_entropy': Array(0.12910591, dtype=float32)}
{'loss_inverse': Array(0.00153629, dtype=float32)}


  0%|          | 1399/1000000 [04:07<28:31:15,  9.73it/s]

{'loss': Array(0.13441728, dtype=float32), 'loss_cross_entropy': Array(0.12805781, dtype=float32)}
{'loss_inverse': Array(0.00127341, dtype=float32)}


  0%|          | 1409/1000000 [04:09<42:59:43,  6.45it/s]

{'loss': Array(0.13695966, dtype=float32), 'loss_cross_entropy': Array(0.13000862, dtype=float32)}
{'loss_inverse': Array(0.0013801, dtype=float32)}


  0%|          | 1420/1000000 [04:10<28:42:35,  9.66it/s]

{'loss': Array(0.11815858, dtype=float32), 'loss_cross_entropy': Array(0.11200126, dtype=float32)}
{'loss_inverse': Array(0.00165789, dtype=float32)}


  0%|          | 1430/1000000 [04:12<27:41:59, 10.01it/s]

{'loss': Array(0.11800839, dtype=float32), 'loss_cross_entropy': Array(0.11170878, dtype=float32)}
{'loss_inverse': Array(0.00158239, dtype=float32)}


  0%|          | 1440/1000000 [04:13<27:48:46,  9.97it/s]

{'loss': Array(0.10230514, dtype=float32), 'loss_cross_entropy': Array(0.09672127, dtype=float32)}
{'loss_inverse': Array(0.00147157, dtype=float32)}


  0%|          | 1450/1000000 [04:15<33:59:23,  8.16it/s]

{'loss': Array(0.1282941, dtype=float32), 'loss_cross_entropy': Array(0.12187263, dtype=float32)}
{'loss_inverse': Array(0.00136123, dtype=float32)}


  0%|          | 1460/1000000 [04:17<28:57:19,  9.58it/s]

{'loss': Array(0.13116549, dtype=float32), 'loss_cross_entropy': Array(0.12425923, dtype=float32)}
{'loss_inverse': Array(0.0013069, dtype=float32)}


  0%|          | 1470/1000000 [04:18<28:26:50,  9.75it/s]

{'loss': Array(0.12778842, dtype=float32), 'loss_cross_entropy': Array(0.12111473, dtype=float32)}
{'loss_inverse': Array(0.00156154, dtype=float32)}


  0%|          | 1480/1000000 [04:20<39:31:15,  7.02it/s]

{'loss': Array(0.11580864, dtype=float32), 'loss_cross_entropy': Array(0.10975613, dtype=float32)}
{'loss_inverse': Array(0.0014866, dtype=float32)}


  0%|          | 1490/1000000 [04:22<30:01:01,  9.24it/s]

{'loss': Array(0.13276553, dtype=float32), 'loss_cross_entropy': Array(0.12614177, dtype=float32)}
{'loss_inverse': Array(0.00143602, dtype=float32)}


  0%|          | 1500/1000000 [04:23<28:11:55,  9.84it/s]

{'loss': Array(0.13362975, dtype=float32), 'loss_cross_entropy': Array(0.12716037, dtype=float32)}
{'loss_inverse': Array(0.00127641, dtype=float32)}


  0%|          | 1510/1000000 [04:31<94:21:19,  2.94it/s] 

{'loss': Array(0.13062738, dtype=float32), 'loss_cross_entropy': Array(0.12392282, dtype=float32)}
{'loss_inverse': Array(0.00124712, dtype=float32)}


  0%|          | 1520/1000000 [04:33<46:33:56,  5.96it/s] 

{'loss': Array(0.1399421, dtype=float32), 'loss_cross_entropy': Array(0.13341893, dtype=float32)}
{'loss_inverse': Array(0.00124586, dtype=float32)}


  0%|          | 1530/1000000 [04:34<30:48:14,  9.00it/s]

{'loss': Array(0.12939797, dtype=float32), 'loss_cross_entropy': Array(0.12269055, dtype=float32)}
{'loss_inverse': Array(0.00117078, dtype=float32)}


  0%|          | 1540/1000000 [04:36<29:43:15,  9.33it/s]

{'loss': Array(0.13617584, dtype=float32), 'loss_cross_entropy': Array(0.12916747, dtype=float32)}
{'loss_inverse': Array(0.00127202, dtype=float32)}


  0%|          | 1550/1000000 [04:37<28:10:04,  9.85it/s]

{'loss': Array(0.1388017, dtype=float32), 'loss_cross_entropy': Array(0.13180345, dtype=float32)}
{'loss_inverse': Array(0.00137165, dtype=float32)}


  0%|          | 1560/1000000 [04:39<38:29:03,  7.21it/s]

{'loss': Array(0.10914327, dtype=float32), 'loss_cross_entropy': Array(0.10389829, dtype=float32)}
{'loss_inverse': Array(0.00139584, dtype=float32)}


  0%|          | 1570/1000000 [04:41<30:04:48,  9.22it/s]

{'loss': Array(0.1332715, dtype=float32), 'loss_cross_entropy': Array(0.12665333, dtype=float32)}
{'loss_inverse': Array(0.00110022, dtype=float32)}


  0%|          | 1580/1000000 [04:42<28:11:28,  9.84it/s]

{'loss': Array(0.12811612, dtype=float32), 'loss_cross_entropy': Array(0.12169772, dtype=float32)}
{'loss_inverse': Array(0.00112731, dtype=float32)}


  0%|          | 1590/1000000 [04:44<27:23:04, 10.13it/s]

{'loss': Array(0.14520103, dtype=float32), 'loss_cross_entropy': Array(0.13769992, dtype=float32)}
{'loss_inverse': Array(0.00131612, dtype=float32)}


  0%|          | 1600/1000000 [04:46<43:14:50,  6.41it/s]

{'loss': Array(0.13316989, dtype=float32), 'loss_cross_entropy': Array(0.126283, dtype=float32)}
{'loss_inverse': Array(0.00127469, dtype=float32)}


  0%|          | 1610/1000000 [04:47<30:34:24,  9.07it/s]

{'loss': Array(0.11769217, dtype=float32), 'loss_cross_entropy': Array(0.11150854, dtype=float32)}
{'loss_inverse': Array(0.00119323, dtype=float32)}


  0%|          | 1620/1000000 [04:49<28:24:06,  9.76it/s]

{'loss': Array(0.11186849, dtype=float32), 'loss_cross_entropy': Array(0.10606103, dtype=float32)}
{'loss_inverse': Array(0.00116752, dtype=float32)}


  0%|          | 1630/1000000 [04:50<29:05:52,  9.53it/s]

{'loss': Array(0.11995635, dtype=float32), 'loss_cross_entropy': Array(0.11327725, dtype=float32)}
{'loss_inverse': Array(0.00121553, dtype=float32)}


  0%|          | 1640/1000000 [04:52<43:03:09,  6.44it/s]

{'loss': Array(0.1319304, dtype=float32), 'loss_cross_entropy': Array(0.12506829, dtype=float32)}
{'loss_inverse': Array(0.00094828, dtype=float32)}


  0%|          | 1650/1000000 [04:54<30:36:42,  9.06it/s]

{'loss': Array(0.12284458, dtype=float32), 'loss_cross_entropy': Array(0.11685568, dtype=float32)}
{'loss_inverse': Array(0.0009164, dtype=float32)}


  0%|          | 1660/1000000 [04:55<29:21:59,  9.44it/s]

{'loss': Array(0.1382917, dtype=float32), 'loss_cross_entropy': Array(0.13146618, dtype=float32)}
{'loss_inverse': Array(0.00104019, dtype=float32)}


  0%|          | 1670/1000000 [04:57<28:24:49,  9.76it/s]

{'loss': Array(0.1367072, dtype=float32), 'loss_cross_entropy': Array(0.12971476, dtype=float32)}
{'loss_inverse': Array(0.00110619, dtype=float32)}


  0%|          | 1680/1000000 [04:59<33:32:54,  8.27it/s]

{'loss': Array(0.13289951, dtype=float32), 'loss_cross_entropy': Array(0.12588714, dtype=float32)}
{'loss_inverse': Array(0.00103446, dtype=float32)}


  0%|          | 1690/1000000 [05:00<30:51:18,  8.99it/s]

{'loss': Array(0.13446598, dtype=float32), 'loss_cross_entropy': Array(0.12760246, dtype=float32)}
{'loss_inverse': Array(0.00099988, dtype=float32)}


  0%|          | 1700/1000000 [05:02<28:41:11,  9.67it/s]

{'loss': Array(0.12632163, dtype=float32), 'loss_cross_entropy': Array(0.11955135, dtype=float32)}
{'loss_inverse': Array(0.00094058, dtype=float32)}


  0%|          | 1710/1000000 [05:04<38:48:39,  7.14it/s]

{'loss': Array(0.11686315, dtype=float32), 'loss_cross_entropy': Array(0.11080142, dtype=float32)}
{'loss_inverse': Array(0.00095583, dtype=float32)}


  0%|          | 1720/1000000 [05:05<31:22:36,  8.84it/s]

{'loss': Array(0.10923686, dtype=float32), 'loss_cross_entropy': Array(0.10372329, dtype=float32)}
{'loss_inverse': Array(0.00090673, dtype=float32)}


  0%|          | 1730/1000000 [05:07<29:00:05,  9.56it/s]

{'loss': Array(0.12910257, dtype=float32), 'loss_cross_entropy': Array(0.12243296, dtype=float32)}
{'loss_inverse': Array(0.00089119, dtype=float32)}


  0%|          | 1740/1000000 [05:09<49:05:10,  5.65it/s]

{'loss': Array(0.13328917, dtype=float32), 'loss_cross_entropy': Array(0.12670183, dtype=float32)}
{'loss_inverse': Array(0.00102627, dtype=float32)}


  0%|          | 1750/1000000 [05:10<33:03:50,  8.39it/s]

{'loss': Array(0.13321702, dtype=float32), 'loss_cross_entropy': Array(0.12669289, dtype=float32)}
{'loss_inverse': Array(0.00119039, dtype=float32)}


  0%|          | 1760/1000000 [05:12<29:11:01,  9.50it/s]

{'loss': Array(0.13272066, dtype=float32), 'loss_cross_entropy': Array(0.12593842, dtype=float32)}
{'loss_inverse': Array(0.00118047, dtype=float32)}


  0%|          | 1770/1000000 [05:13<28:29:35,  9.73it/s]

{'loss': Array(0.13515508, dtype=float32), 'loss_cross_entropy': Array(0.12839726, dtype=float32)}
{'loss_inverse': Array(0.00095366, dtype=float32)}


  0%|          | 1780/1000000 [05:15<36:51:01,  7.52it/s]

{'loss': Array(0.1362379, dtype=float32), 'loss_cross_entropy': Array(0.12956843, dtype=float32)}
{'loss_inverse': Array(0.00107879, dtype=float32)}


  0%|          | 1790/1000000 [05:16<29:47:51,  9.31it/s]

{'loss': Array(0.10832145, dtype=float32), 'loss_cross_entropy': Array(0.10241669, dtype=float32)}
{'loss_inverse': Array(0.00098232, dtype=float32)}


  0%|          | 1800/1000000 [05:18<29:07:39,  9.52it/s]

{'loss': Array(0.1444471, dtype=float32), 'loss_cross_entropy': Array(0.13722222, dtype=float32)}
{'loss_inverse': Array(0.00118718, dtype=float32)}


  0%|          | 1810/1000000 [05:19<28:29:04,  9.73it/s]

{'loss': Array(0.13066588, dtype=float32), 'loss_cross_entropy': Array(0.12471335, dtype=float32)}
{'loss_inverse': Array(0.00167925, dtype=float32)}


  0%|          | 1820/1000000 [05:21<33:40:16,  8.23it/s]

{'loss': Array(0.12554824, dtype=float32), 'loss_cross_entropy': Array(0.11906736, dtype=float32)}
{'loss_inverse': Array(0.00138934, dtype=float32)}


  0%|          | 1830/1000000 [05:23<29:09:04,  9.51it/s]

{'loss': Array(0.10706103, dtype=float32), 'loss_cross_entropy': Array(0.10128202, dtype=float32)}
{'loss_inverse': Array(0.00107372, dtype=float32)}


  0%|          | 1840/1000000 [05:24<28:18:21,  9.80it/s]

{'loss': Array(0.13572133, dtype=float32), 'loss_cross_entropy': Array(0.12942474, dtype=float32)}
{'loss_inverse': Array(0.00101858, dtype=float32)}


  0%|          | 1850/1000000 [05:26<28:36:30,  9.69it/s]

{'loss': Array(0.15139072, dtype=float32), 'loss_cross_entropy': Array(0.14467229, dtype=float32)}
{'loss_inverse': Array(0.00081544, dtype=float32)}


  0%|          | 1860/1000000 [05:28<35:21:19,  7.84it/s]

{'loss': Array(0.13774744, dtype=float32), 'loss_cross_entropy': Array(0.13098106, dtype=float32)}
{'loss_inverse': Array(0.00082831, dtype=float32)}


  0%|          | 1870/1000000 [05:29<29:47:37,  9.31it/s]

{'loss': Array(0.12493124, dtype=float32), 'loss_cross_entropy': Array(0.11813994, dtype=float32)}
{'loss_inverse': Array(0.00080977, dtype=float32)}


  0%|          | 1880/1000000 [05:31<28:47:56,  9.63it/s]

{'loss': Array(0.13543625, dtype=float32), 'loss_cross_entropy': Array(0.1292983, dtype=float32)}
{'loss_inverse': Array(0.00079191, dtype=float32)}


  0%|          | 1890/1000000 [05:32<28:46:07,  9.64it/s]

{'loss': Array(0.12309003, dtype=float32), 'loss_cross_entropy': Array(0.11613321, dtype=float32)}
{'loss_inverse': Array(0.00082549, dtype=float32)}


  0%|          | 1900/1000000 [05:34<33:23:15,  8.30it/s]

{'loss': Array(0.10856485, dtype=float32), 'loss_cross_entropy': Array(0.10277062, dtype=float32)}
{'loss_inverse': Array(0.00078221, dtype=float32)}


  0%|          | 1910/1000000 [05:36<29:18:51,  9.46it/s]

{'loss': Array(0.12614205, dtype=float32), 'loss_cross_entropy': Array(0.11953324, dtype=float32)}
{'loss_inverse': Array(0.00085362, dtype=float32)}


  0%|          | 1920/1000000 [05:37<28:16:01,  9.81it/s]

{'loss': Array(0.14063405, dtype=float32), 'loss_cross_entropy': Array(0.13443835, dtype=float32)}
{'loss_inverse': Array(0.00070745, dtype=float32)}


  0%|          | 1930/1000000 [05:39<37:46:09,  7.34it/s]

{'loss': Array(0.13752063, dtype=float32), 'loss_cross_entropy': Array(0.13008466, dtype=float32)}
{'loss_inverse': Array(0.00068413, dtype=float32)}


  0%|          | 1940/1000000 [05:41<29:42:27,  9.33it/s]

{'loss': Array(0.13602439, dtype=float32), 'loss_cross_entropy': Array(0.12937345, dtype=float32)}
{'loss_inverse': Array(0.00061115, dtype=float32)}


  0%|          | 1950/1000000 [05:42<29:02:27,  9.55it/s]

{'loss': Array(0.1249239, dtype=float32), 'loss_cross_entropy': Array(0.11841408, dtype=float32)}
{'loss_inverse': Array(0.00057604, dtype=float32)}


  0%|          | 1960/1000000 [05:44<43:25:10,  6.38it/s]

{'loss': Array(0.13719894, dtype=float32), 'loss_cross_entropy': Array(0.13012488, dtype=float32)}
{'loss_inverse': Array(0.00067411, dtype=float32)}


  0%|          | 1970/1000000 [05:46<31:26:12,  8.82it/s]

{'loss': Array(0.11696861, dtype=float32), 'loss_cross_entropy': Array(0.11086088, dtype=float32)}
{'loss_inverse': Array(0.00065057, dtype=float32)}


  0%|          | 1980/1000000 [05:47<28:38:19,  9.68it/s]

{'loss': Array(0.11096513, dtype=float32), 'loss_cross_entropy': Array(0.10512742, dtype=float32)}
{'loss_inverse': Array(0.00071604, dtype=float32)}


  0%|          | 1990/1000000 [05:49<27:27:09, 10.10it/s]

{'loss': Array(0.11674752, dtype=float32), 'loss_cross_entropy': Array(0.11031613, dtype=float32)}
{'loss_inverse': Array(0.00078914, dtype=float32)}


  0%|          | 2000/1000000 [05:51<49:40:23,  5.58it/s]

{'loss': Array(0.11153787, dtype=float32), 'loss_cross_entropy': Array(0.10603692, dtype=float32)}
{'loss_inverse': Array(0.00060912, dtype=float32)}


  0%|          | 2009/1000000 [05:59<103:48:28,  2.67it/s]

{'loss': Array(0.13873503, dtype=float32), 'loss_cross_entropy': Array(0.13195737, dtype=float32)}
{'loss_inverse': Array(0.00066658, dtype=float32)}


  0%|          | 2019/1000000 [06:00<41:00:22,  6.76it/s] 

{'loss': Array(0.13003737, dtype=float32), 'loss_cross_entropy': Array(0.1235429, dtype=float32)}
{'loss_inverse': Array(0.00072938, dtype=float32)}


  0%|          | 2029/1000000 [06:02<29:44:10,  9.32it/s]

{'loss': Array(0.13099562, dtype=float32), 'loss_cross_entropy': Array(0.12465888, dtype=float32)}
{'loss_inverse': Array(0.00047607, dtype=float32)}


  0%|          | 2039/1000000 [06:04<49:07:30,  5.64it/s]

{'loss': Array(0.13287048, dtype=float32), 'loss_cross_entropy': Array(0.12577794, dtype=float32)}
{'loss_inverse': Array(0.0007547, dtype=float32)}


  0%|          | 2050/1000000 [06:05<29:10:03,  9.50it/s]

{'loss': Array(0.13121234, dtype=float32), 'loss_cross_entropy': Array(0.12470768, dtype=float32)}
{'loss_inverse': Array(0.00064842, dtype=float32)}


  0%|          | 2060/1000000 [06:07<27:49:56,  9.96it/s]

{'loss': Array(0.12856737, dtype=float32), 'loss_cross_entropy': Array(0.12151083, dtype=float32)}
{'loss_inverse': Array(0.00056329, dtype=float32)}


  0%|          | 2070/1000000 [06:08<28:17:41,  9.80it/s]

{'loss': Array(0.1461679, dtype=float32), 'loss_cross_entropy': Array(0.13905929, dtype=float32)}
{'loss_inverse': Array(0.00071485, dtype=float32)}


  0%|          | 2080/1000000 [06:10<34:56:48,  7.93it/s]

{'loss': Array(0.12793688, dtype=float32), 'loss_cross_entropy': Array(0.121452, dtype=float32)}
{'loss_inverse': Array(0.00063716, dtype=float32)}


  0%|          | 2090/1000000 [06:12<28:45:21,  9.64it/s]

{'loss': Array(0.12222462, dtype=float32), 'loss_cross_entropy': Array(0.11604726, dtype=float32)}
{'loss_inverse': Array(0.00077798, dtype=float32)}


  0%|          | 2100/1000000 [06:13<27:41:59, 10.01it/s]

{'loss': Array(0.117455, dtype=float32), 'loss_cross_entropy': Array(0.11118468, dtype=float32)}
{'loss_inverse': Array(0.00072122, dtype=float32)}


  0%|          | 2110/1000000 [06:15<36:15:31,  7.64it/s]

{'loss': Array(0.12812568, dtype=float32), 'loss_cross_entropy': Array(0.12184824, dtype=float32)}
{'loss_inverse': Array(0.00067879, dtype=float32)}


  0%|          | 2120/1000000 [06:17<28:59:47,  9.56it/s]

{'loss': Array(0.11941879, dtype=float32), 'loss_cross_entropy': Array(0.11274399, dtype=float32)}
{'loss_inverse': Array(0.00067533, dtype=float32)}


  0%|          | 2130/1000000 [06:18<27:58:54,  9.91it/s]

{'loss': Array(0.12354962, dtype=float32), 'loss_cross_entropy': Array(0.11669022, dtype=float32)}
{'loss_inverse': Array(0.00073895, dtype=float32)}


  0%|          | 2140/1000000 [06:20<28:45:46,  9.64it/s]

{'loss': Array(0.13484246, dtype=float32), 'loss_cross_entropy': Array(0.12792367, dtype=float32)}
{'loss_inverse': Array(0.00070213, dtype=float32)}


  0%|          | 2150/1000000 [06:22<38:53:45,  7.13it/s]

{'loss': Array(0.13206157, dtype=float32), 'loss_cross_entropy': Array(0.12537742, dtype=float32)}
{'loss_inverse': Array(0.00053242, dtype=float32)}


  0%|          | 2160/1000000 [06:23<29:56:00,  9.26it/s]

{'loss': Array(0.12466685, dtype=float32), 'loss_cross_entropy': Array(0.11819126, dtype=float32)}
{'loss_inverse': Array(0.0005149, dtype=float32)}


  0%|          | 2170/1000000 [06:25<27:49:52,  9.96it/s]

{'loss': Array(0.14083229, dtype=float32), 'loss_cross_entropy': Array(0.13393612, dtype=float32)}
{'loss_inverse': Array(0.00060523, dtype=float32)}


  0%|          | 2180/1000000 [06:26<28:43:08,  9.65it/s]

{'loss': Array(0.15404378, dtype=float32), 'loss_cross_entropy': Array(0.14696068, dtype=float32)}
{'loss_inverse': Array(0.00057605, dtype=float32)}


  0%|          | 2190/1000000 [06:28<33:23:53,  8.30it/s]

{'loss': Array(0.14810652, dtype=float32), 'loss_cross_entropy': Array(0.14106326, dtype=float32)}
{'loss_inverse': Array(0.00059232, dtype=float32)}


  0%|          | 2200/1000000 [06:30<28:58:59,  9.56it/s]

{'loss': Array(0.12151419, dtype=float32), 'loss_cross_entropy': Array(0.11518034, dtype=float32)}
{'loss_inverse': Array(0.00070307, dtype=float32)}


  0%|          | 2210/1000000 [06:31<29:09:59,  9.50it/s]

{'loss': Array(0.11595637, dtype=float32), 'loss_cross_entropy': Array(0.10956756, dtype=float32)}
{'loss_inverse': Array(0.00055138, dtype=float32)}


  0%|          | 2220/1000000 [06:33<43:24:37,  6.38it/s]

{'loss': Array(0.1243551, dtype=float32), 'loss_cross_entropy': Array(0.11772484, dtype=float32)}
{'loss_inverse': Array(0.00061624, dtype=float32)}


  0%|          | 2230/1000000 [06:35<30:53:48,  8.97it/s]

{'loss': Array(0.12720472, dtype=float32), 'loss_cross_entropy': Array(0.12065404, dtype=float32)}
{'loss_inverse': Array(0.00054454, dtype=float32)}


  0%|          | 2240/1000000 [06:36<29:12:48,  9.49it/s]

{'loss': Array(0.12233268, dtype=float32), 'loss_cross_entropy': Array(0.11545445, dtype=float32)}
{'loss_inverse': Array(0.00053962, dtype=float32)}


  0%|          | 2250/1000000 [06:38<28:21:41,  9.77it/s]

{'loss': Array(0.10475963, dtype=float32), 'loss_cross_entropy': Array(0.0994307, dtype=float32)}
{'loss_inverse': Array(0.00046039, dtype=float32)}


  0%|          | 2260/1000000 [06:39<32:49:41,  8.44it/s]

{'loss': Array(0.12959659, dtype=float32), 'loss_cross_entropy': Array(0.12273379, dtype=float32)}
{'loss_inverse': Array(0.00051562, dtype=float32)}


  0%|          | 2270/1000000 [06:41<28:57:20,  9.57it/s]

{'loss': Array(0.13562766, dtype=float32), 'loss_cross_entropy': Array(0.12843363, dtype=float32)}
{'loss_inverse': Array(0.00036457, dtype=float32)}


  0%|          | 2280/1000000 [06:42<27:55:14,  9.93it/s]

{'loss': Array(0.11860663, dtype=float32), 'loss_cross_entropy': Array(0.11218859, dtype=float32)}
{'loss_inverse': Array(0.00042367, dtype=float32)}


  0%|          | 2290/1000000 [06:44<38:37:46,  7.17it/s]

{'loss': Array(0.11481722, dtype=float32), 'loss_cross_entropy': Array(0.10900284, dtype=float32)}
{'loss_inverse': Array(0.00050288, dtype=float32)}


  0%|          | 2300/1000000 [06:46<30:26:49,  9.10it/s]

{'loss': Array(0.13989733, dtype=float32), 'loss_cross_entropy': Array(0.13284719, dtype=float32)}
{'loss_inverse': Array(0.00042536, dtype=float32)}


  0%|          | 2310/1000000 [06:47<28:44:26,  9.64it/s]

{'loss': Array(0.12135834, dtype=float32), 'loss_cross_entropy': Array(0.11463193, dtype=float32)}
{'loss_inverse': Array(0.00046799, dtype=float32)}


  0%|          | 2320/1000000 [06:49<27:59:30,  9.90it/s]

{'loss': Array(0.14024802, dtype=float32), 'loss_cross_entropy': Array(0.13326614, dtype=float32)}
{'loss_inverse': Array(0.00038499, dtype=float32)}


  0%|          | 2330/1000000 [06:51<35:43:30,  7.76it/s]

{'loss': Array(0.12341794, dtype=float32), 'loss_cross_entropy': Array(0.11727146, dtype=float32)}
{'loss_inverse': Array(0.00037453, dtype=float32)}


  0%|          | 2340/1000000 [06:52<29:07:46,  9.51it/s]

{'loss': Array(0.12284001, dtype=float32), 'loss_cross_entropy': Array(0.11631757, dtype=float32)}
{'loss_inverse': Array(0.00037157, dtype=float32)}


  0%|          | 2350/1000000 [06:54<28:26:48,  9.74it/s]

{'loss': Array(0.1210342, dtype=float32), 'loss_cross_entropy': Array(0.11423898, dtype=float32)}
{'loss_inverse': Array(0.00042486, dtype=float32)}


  0%|          | 2360/1000000 [06:55<29:04:24,  9.53it/s]

{'loss': Array(0.12911168, dtype=float32), 'loss_cross_entropy': Array(0.12218546, dtype=float32)}
{'loss_inverse': Array(0.00041271, dtype=float32)}


  0%|          | 2370/1000000 [06:57<35:50:16,  7.73it/s]

{'loss': Array(0.1288216, dtype=float32), 'loss_cross_entropy': Array(0.12219252, dtype=float32)}
{'loss_inverse': Array(0.00039068, dtype=float32)}


  0%|          | 2380/1000000 [06:59<29:09:41,  9.50it/s]

{'loss': Array(0.11049705, dtype=float32), 'loss_cross_entropy': Array(0.10462242, dtype=float32)}
{'loss_inverse': Array(0.00039357, dtype=float32)}


  0%|          | 2390/1000000 [07:00<30:11:57,  9.18it/s]

{'loss': Array(0.13945363, dtype=float32), 'loss_cross_entropy': Array(0.13298117, dtype=float32)}
{'loss_inverse': Array(0.00064966, dtype=float32)}


  0%|          | 2400/1000000 [07:02<28:26:30,  9.74it/s]

{'loss': Array(0.12987387, dtype=float32), 'loss_cross_entropy': Array(0.1226619, dtype=float32)}
{'loss_inverse': Array(0.00053886, dtype=float32)}


  0%|          | 2410/1000000 [07:04<32:59:38,  8.40it/s]

{'loss': Array(0.13766631, dtype=float32), 'loss_cross_entropy': Array(0.13089287, dtype=float32)}
{'loss_inverse': Array(0.00056668, dtype=float32)}


  0%|          | 2420/1000000 [07:05<30:19:46,  9.14it/s]

{'loss': Array(0.12147024, dtype=float32), 'loss_cross_entropy': Array(0.11526904, dtype=float32)}
{'loss_inverse': Array(0.00066106, dtype=float32)}


  0%|          | 2430/1000000 [07:07<28:32:20,  9.71it/s]

{'loss': Array(0.12995611, dtype=float32), 'loss_cross_entropy': Array(0.12298223, dtype=float32)}
{'loss_inverse': Array(0.0006105, dtype=float32)}


  0%|          | 2440/1000000 [07:09<42:34:54,  6.51it/s]

{'loss': Array(0.13368116, dtype=float32), 'loss_cross_entropy': Array(0.12716001, dtype=float32)}
{'loss_inverse': Array(0.00073358, dtype=float32)}


  0%|          | 2450/1000000 [07:10<31:33:37,  8.78it/s]

{'loss': Array(0.1311674, dtype=float32), 'loss_cross_entropy': Array(0.12475156, dtype=float32)}
{'loss_inverse': Array(0.00056281, dtype=float32)}


  0%|          | 2460/1000000 [07:12<28:43:33,  9.65it/s]

{'loss': Array(0.14220032, dtype=float32), 'loss_cross_entropy': Array(0.13456976, dtype=float32)}
{'loss_inverse': Array(0.00038565, dtype=float32)}


  0%|          | 2470/1000000 [07:13<28:20:08,  9.78it/s]

{'loss': Array(0.12756254, dtype=float32), 'loss_cross_entropy': Array(0.12096324, dtype=float32)}
{'loss_inverse': Array(0.00060136, dtype=float32)}


  0%|          | 2480/1000000 [07:15<36:37:14,  7.57it/s]

{'loss': Array(0.1497829, dtype=float32), 'loss_cross_entropy': Array(0.14278431, dtype=float32)}
{'loss_inverse': Array(0.00044992, dtype=float32)}


  0%|          | 2490/1000000 [07:17<29:49:29,  9.29it/s]

{'loss': Array(0.1413069, dtype=float32), 'loss_cross_entropy': Array(0.1345957, dtype=float32)}
{'loss_inverse': Array(0.00044248, dtype=float32)}


  0%|          | 2500/1000000 [07:18<28:40:42,  9.66it/s]

{'loss': Array(0.1337333, dtype=float32), 'loss_cross_entropy': Array(0.12633456, dtype=float32)}
{'loss_inverse': Array(0.00042898, dtype=float32)}


  0%|          | 2510/1000000 [07:27<107:04:40,  2.59it/s]

{'loss': Array(0.128491, dtype=float32), 'loss_cross_entropy': Array(0.12224573, dtype=float32)}
{'loss_inverse': Array(0.00039617, dtype=float32)}


  0%|          | 2520/1000000 [07:28<41:22:41,  6.70it/s] 

{'loss': Array(0.14943115, dtype=float32), 'loss_cross_entropy': Array(0.14091876, dtype=float32)}
{'loss_inverse': Array(0.00039177, dtype=float32)}


  0%|          | 2530/1000000 [07:30<32:13:21,  8.60it/s]

{'loss': Array(0.12618141, dtype=float32), 'loss_cross_entropy': Array(0.11907449, dtype=float32)}
{'loss_inverse': Array(0.00040778, dtype=float32)}


  0%|          | 2540/1000000 [07:31<30:24:46,  9.11it/s]

{'loss': Array(0.15878074, dtype=float32), 'loss_cross_entropy': Array(0.15006994, dtype=float32)}
{'loss_inverse': Array(0.00039265, dtype=float32)}


  0%|          | 2550/1000000 [07:33<39:08:50,  7.08it/s]

{'loss': Array(0.12869202, dtype=float32), 'loss_cross_entropy': Array(0.12202179, dtype=float32)}
{'loss_inverse': Array(0.00048554, dtype=float32)}


  0%|          | 2560/1000000 [07:35<32:19:52,  8.57it/s]

{'loss': Array(0.14873177, dtype=float32), 'loss_cross_entropy': Array(0.14111711, dtype=float32)}
{'loss_inverse': Array(0.00043689, dtype=float32)}


  0%|          | 2570/1000000 [07:36<29:42:27,  9.33it/s]

{'loss': Array(0.1370731, dtype=float32), 'loss_cross_entropy': Array(0.13046525, dtype=float32)}
{'loss_inverse': Array(0.00046091, dtype=float32)}


  0%|          | 2580/1000000 [07:38<28:30:04,  9.72it/s]

{'loss': Array(0.12604956, dtype=float32), 'loss_cross_entropy': Array(0.11873028, dtype=float32)}
{'loss_inverse': Array(0.00044639, dtype=float32)}


  0%|          | 2590/1000000 [07:40<37:47:55,  7.33it/s]

{'loss': Array(0.12215527, dtype=float32), 'loss_cross_entropy': Array(0.11594605, dtype=float32)}
{'loss_inverse': Array(0.00044896, dtype=float32)}


  0%|          | 2600/1000000 [07:41<30:29:57,  9.08it/s]

{'loss': Array(0.12439405, dtype=float32), 'loss_cross_entropy': Array(0.11777158, dtype=float32)}
{'loss_inverse': Array(0.00048915, dtype=float32)}


  0%|          | 2610/1000000 [07:43<28:20:32,  9.78it/s]

{'loss': Array(0.13002524, dtype=float32), 'loss_cross_entropy': Array(0.12352046, dtype=float32)}
{'loss_inverse': Array(0.00050617, dtype=float32)}


  0%|          | 2620/1000000 [07:45<43:16:22,  6.40it/s]

{'loss': Array(0.13975857, dtype=float32), 'loss_cross_entropy': Array(0.13274167, dtype=float32)}
{'loss_inverse': Array(0.00045189, dtype=float32)}


  0%|          | 2630/1000000 [07:46<30:12:12,  9.17it/s]

{'loss': Array(0.13205756, dtype=float32), 'loss_cross_entropy': Array(0.1251448, dtype=float32)}
{'loss_inverse': Array(0.00054074, dtype=float32)}


  0%|          | 2640/1000000 [07:48<28:28:33,  9.73it/s]

{'loss': Array(0.15576072, dtype=float32), 'loss_cross_entropy': Array(0.14777029, dtype=float32)}
{'loss_inverse': Array(0.00077255, dtype=float32)}


  0%|          | 2650/1000000 [07:50<49:58:50,  5.54it/s]

{'loss': Array(0.12513645, dtype=float32), 'loss_cross_entropy': Array(0.11869609, dtype=float32)}
{'loss_inverse': Array(0.00076333, dtype=float32)}


  0%|          | 2660/1000000 [07:51<31:32:48,  8.78it/s]

{'loss': Array(0.15161668, dtype=float32), 'loss_cross_entropy': Array(0.14413083, dtype=float32)}
{'loss_inverse': Array(0.00092467, dtype=float32)}


  0%|          | 2670/1000000 [07:53<28:08:51,  9.84it/s]

{'loss': Array(0.12440505, dtype=float32), 'loss_cross_entropy': Array(0.11791661, dtype=float32)}
{'loss_inverse': Array(0.00086474, dtype=float32)}


  0%|          | 2680/1000000 [07:54<27:54:55,  9.92it/s]

{'loss': Array(0.15500869, dtype=float32), 'loss_cross_entropy': Array(0.14742407, dtype=float32)}
{'loss_inverse': Array(0.00136132, dtype=float32)}


  0%|          | 2690/1000000 [07:56<43:05:54,  6.43it/s]

{'loss': Array(0.12326508, dtype=float32), 'loss_cross_entropy': Array(0.11701488, dtype=float32)}
{'loss_inverse': Array(0.00183451, dtype=float32)}


  0%|          | 2700/1000000 [07:58<30:57:07,  8.95it/s]

{'loss': Array(0.12864226, dtype=float32), 'loss_cross_entropy': Array(0.12186902, dtype=float32)}
{'loss_inverse': Array(0.00203292, dtype=float32)}


  0%|          | 2710/1000000 [07:59<27:54:39,  9.93it/s]

{'loss': Array(0.12315788, dtype=float32), 'loss_cross_entropy': Array(0.11655904, dtype=float32)}
{'loss_inverse': Array(0.01856505, dtype=float32)}


  0%|          | 2720/1000000 [08:00<28:57:09,  9.57it/s]

{'loss': Array(0.12869674, dtype=float32), 'loss_cross_entropy': Array(0.12229245, dtype=float32)}
{'loss_inverse': Array(0.01439741, dtype=float32)}


  0%|          | 2730/1000000 [08:02<43:07:11,  6.42it/s]

{'loss': Array(0.12501702, dtype=float32), 'loss_cross_entropy': Array(0.11854146, dtype=float32)}
{'loss_inverse': Array(0.00970135, dtype=float32)}


  0%|          | 2740/1000000 [08:04<29:54:59,  9.26it/s]

{'loss': Array(0.13288096, dtype=float32), 'loss_cross_entropy': Array(0.12625833, dtype=float32)}
{'loss_inverse': Array(0.00892363, dtype=float32)}


  0%|          | 2750/1000000 [08:05<29:42:11,  9.33it/s]

{'loss': Array(0.14054646, dtype=float32), 'loss_cross_entropy': Array(0.1333559, dtype=float32)}
{'loss_inverse': Array(0.00363733, dtype=float32)}


  0%|          | 2760/1000000 [08:07<28:03:39,  9.87it/s]

{'loss': Array(0.12057555, dtype=float32), 'loss_cross_entropy': Array(0.11439315, dtype=float32)}
{'loss_inverse': Array(0.00164496, dtype=float32)}


  0%|          | 2770/1000000 [08:09<38:27:18,  7.20it/s]

{'loss': Array(0.12991808, dtype=float32), 'loss_cross_entropy': Array(0.12387109, dtype=float32)}
{'loss_inverse': Array(0.00103418, dtype=float32)}


  0%|          | 2779/1000000 [08:10<31:29:41,  8.80it/s]

{'loss': Array(0.12570058, dtype=float32), 'loss_cross_entropy': Array(0.11869276, dtype=float32)}
{'loss_inverse': Array(0.00106341, dtype=float32)}


  0%|          | 2789/1000000 [08:12<29:00:34,  9.55it/s]

{'loss': Array(0.1320614, dtype=float32), 'loss_cross_entropy': Array(0.12519221, dtype=float32)}
{'loss_inverse': Array(0.00066845, dtype=float32)}


  0%|          | 2799/1000000 [08:13<28:23:53,  9.75it/s]

{'loss': Array(0.12872009, dtype=float32), 'loss_cross_entropy': Array(0.12185766, dtype=float32)}
{'loss_inverse': Array(0.00058821, dtype=float32)}


  0%|          | 2809/1000000 [08:15<34:17:47,  8.08it/s]

{'loss': Array(0.13443606, dtype=float32), 'loss_cross_entropy': Array(0.12800165, dtype=float32)}
{'loss_inverse': Array(0.00041942, dtype=float32)}


  0%|          | 2819/1000000 [08:17<30:02:42,  9.22it/s]

{'loss': Array(0.12595221, dtype=float32), 'loss_cross_entropy': Array(0.11952555, dtype=float32)}
{'loss_inverse': Array(0.00034814, dtype=float32)}


  0%|          | 2829/1000000 [08:18<28:26:31,  9.74it/s]

{'loss': Array(0.11898929, dtype=float32), 'loss_cross_entropy': Array(0.1124712, dtype=float32)}
{'loss_inverse': Array(0.00033059, dtype=float32)}


  0%|          | 2839/1000000 [08:20<36:12:48,  7.65it/s]

{'loss': Array(0.13262954, dtype=float32), 'loss_cross_entropy': Array(0.12609534, dtype=float32)}
{'loss_inverse': Array(0.00030267, dtype=float32)}


  0%|          | 2849/1000000 [08:22<30:08:05,  9.19it/s]

{'loss': Array(0.14535424, dtype=float32), 'loss_cross_entropy': Array(0.13839136, dtype=float32)}
{'loss_inverse': Array(0.00029425, dtype=float32)}


  0%|          | 2859/1000000 [08:23<28:34:51,  9.69it/s]

{'loss': Array(0.12258434, dtype=float32), 'loss_cross_entropy': Array(0.11623126, dtype=float32)}
{'loss_inverse': Array(0.00033997, dtype=float32)}


  0%|          | 2869/1000000 [08:25<43:13:25,  6.41it/s]

{'loss': Array(0.14033455, dtype=float32), 'loss_cross_entropy': Array(0.13313738, dtype=float32)}
{'loss_inverse': Array(0.00032231, dtype=float32)}


  0%|          | 2879/1000000 [08:27<30:44:59,  9.01it/s]

{'loss': Array(0.12244135, dtype=float32), 'loss_cross_entropy': Array(0.11609509, dtype=float32)}
{'loss_inverse': Array(0.00026471, dtype=float32)}


  0%|          | 2889/1000000 [08:28<28:12:38,  9.82it/s]

{'loss': Array(0.1401925, dtype=float32), 'loss_cross_entropy': Array(0.13312632, dtype=float32)}
{'loss_inverse': Array(0.00037701, dtype=float32)}


  0%|          | 2899/1000000 [08:30<49:30:05,  5.60it/s]

{'loss': Array(0.13895588, dtype=float32), 'loss_cross_entropy': Array(0.13129933, dtype=float32)}
{'loss_inverse': Array(0.00025117, dtype=float32)}


  0%|          | 2909/1000000 [08:32<31:39:08,  8.75it/s]

{'loss': Array(0.12844609, dtype=float32), 'loss_cross_entropy': Array(0.12140435, dtype=float32)}
{'loss_inverse': Array(0.0002507, dtype=float32)}


  0%|          | 2919/1000000 [08:33<28:42:07,  9.65it/s]

{'loss': Array(0.12918797, dtype=float32), 'loss_cross_entropy': Array(0.12245529, dtype=float32)}
{'loss_inverse': Array(0.00025239, dtype=float32)}


  0%|          | 2929/1000000 [08:35<50:13:45,  5.51it/s]

{'loss': Array(0.1236155, dtype=float32), 'loss_cross_entropy': Array(0.1166492, dtype=float32)}
{'loss_inverse': Array(0.00022856, dtype=float32)}


  0%|          | 2939/1000000 [08:37<32:26:57,  8.54it/s]

{'loss': Array(0.14568824, dtype=float32), 'loss_cross_entropy': Array(0.13861375, dtype=float32)}
{'loss_inverse': Array(0.00021794, dtype=float32)}


  0%|          | 2949/1000000 [08:38<28:08:01,  9.84it/s]

{'loss': Array(0.13529027, dtype=float32), 'loss_cross_entropy': Array(0.12849464, dtype=float32)}
{'loss_inverse': Array(0.00021355, dtype=float32)}


  0%|          | 2959/1000000 [08:39<27:29:01, 10.08it/s]

{'loss': Array(0.12783791, dtype=float32), 'loss_cross_entropy': Array(0.12108994, dtype=float32)}
{'loss_inverse': Array(0.00018353, dtype=float32)}


  0%|          | 2969/1000000 [08:41<28:09:59,  9.83it/s]

{'loss': Array(0.1349601, dtype=float32), 'loss_cross_entropy': Array(0.12799802, dtype=float32)}
{'loss_inverse': Array(0.00023252, dtype=float32)}


  0%|          | 2979/1000000 [08:43<32:27:42,  8.53it/s]

{'loss': Array(0.11902814, dtype=float32), 'loss_cross_entropy': Array(0.11269279, dtype=float32)}
{'loss_inverse': Array(0.00016161, dtype=float32)}


  0%|          | 2989/1000000 [08:44<28:37:26,  9.68it/s]

{'loss': Array(0.13612461, dtype=float32), 'loss_cross_entropy': Array(0.1294552, dtype=float32)}
{'loss_inverse': Array(0.00020207, dtype=float32)}


  0%|          | 2999/1000000 [08:46<28:53:32,  9.59it/s]

{'loss': Array(0.14939557, dtype=float32), 'loss_cross_entropy': Array(0.14190654, dtype=float32)}
{'loss_inverse': Array(0.00016781, dtype=float32)}


  0%|          | 3009/1000000 [08:54<114:03:22,  2.43it/s]

{'loss': Array(0.1332548, dtype=float32), 'loss_cross_entropy': Array(0.12647817, dtype=float32)}
{'loss_inverse': Array(0.00018124, dtype=float32)}


  0%|          | 3019/1000000 [08:56<42:48:50,  6.47it/s] 

{'loss': Array(0.13337986, dtype=float32), 'loss_cross_entropy': Array(0.1270103, dtype=float32)}
{'loss_inverse': Array(0.00022497, dtype=float32)}


  0%|          | 3029/1000000 [08:57<31:18:24,  8.85it/s]

{'loss': Array(0.12485826, dtype=float32), 'loss_cross_entropy': Array(0.11848626, dtype=float32)}
{'loss_inverse': Array(0.0002684, dtype=float32)}


  0%|          | 3039/1000000 [08:59<28:23:44,  9.75it/s]

{'loss': Array(0.1273831, dtype=float32), 'loss_cross_entropy': Array(0.12060869, dtype=float32)}
{'loss_inverse': Array(0.00026095, dtype=float32)}


  0%|          | 3049/1000000 [09:01<36:08:46,  7.66it/s]

{'loss': Array(0.11174967, dtype=float32), 'loss_cross_entropy': Array(0.10553154, dtype=float32)}
{'loss_inverse': Array(0.00029477, dtype=float32)}


  0%|          | 3059/1000000 [09:02<28:56:44,  9.57it/s]

{'loss': Array(0.1200912, dtype=float32), 'loss_cross_entropy': Array(0.11369321, dtype=float32)}
{'loss_inverse': Array(0.00025761, dtype=float32)}


  0%|          | 3069/1000000 [09:04<27:54:22,  9.92it/s]

{'loss': Array(0.12446427, dtype=float32), 'loss_cross_entropy': Array(0.11828604, dtype=float32)}
{'loss_inverse': Array(0.00020115, dtype=float32)}


  0%|          | 3079/1000000 [09:06<43:15:21,  6.40it/s]

{'loss': Array(0.12089331, dtype=float32), 'loss_cross_entropy': Array(0.1145978, dtype=float32)}
{'loss_inverse': Array(0.00020535, dtype=float32)}


  0%|          | 3089/1000000 [09:07<30:05:06,  9.20it/s]

{'loss': Array(0.13756816, dtype=float32), 'loss_cross_entropy': Array(0.13061646, dtype=float32)}
{'loss_inverse': Array(0.00022165, dtype=float32)}


  0%|          | 3099/1000000 [09:08<28:37:17,  9.68it/s]

{'loss': Array(0.12244924, dtype=float32), 'loss_cross_entropy': Array(0.1162064, dtype=float32)}
{'loss_inverse': Array(0.00020801, dtype=float32)}


  0%|          | 3109/1000000 [09:10<30:08:58,  9.18it/s]

{'loss': Array(0.15237758, dtype=float32), 'loss_cross_entropy': Array(0.14463432, dtype=float32)}
{'loss_inverse': Array(0.00018299, dtype=float32)}


  0%|          | 3119/1000000 [09:12<42:44:13,  6.48it/s]

{'loss': Array(0.1301982, dtype=float32), 'loss_cross_entropy': Array(0.1230995, dtype=float32)}
{'loss_inverse': Array(0.00016846, dtype=float32)}


  0%|          | 3129/1000000 [09:13<30:44:58,  9.01it/s]

{'loss': Array(0.12406578, dtype=float32), 'loss_cross_entropy': Array(0.11725509, dtype=float32)}
{'loss_inverse': Array(0.00017189, dtype=float32)}


  0%|          | 3139/1000000 [09:15<30:02:52,  9.22it/s]

{'loss': Array(0.12141271, dtype=float32), 'loss_cross_entropy': Array(0.11545318, dtype=float32)}
{'loss_inverse': Array(0.00015081, dtype=float32)}


  0%|          | 3149/1000000 [09:16<28:11:33,  9.82it/s]

{'loss': Array(0.14601123, dtype=float32), 'loss_cross_entropy': Array(0.13848259, dtype=float32)}
{'loss_inverse': Array(0.00018458, dtype=float32)}


  0%|          | 3159/1000000 [09:18<35:36:35,  7.78it/s]

{'loss': Array(0.12408469, dtype=float32), 'loss_cross_entropy': Array(0.11739983, dtype=float32)}
{'loss_inverse': Array(0.00023365, dtype=float32)}


  0%|          | 3169/1000000 [09:20<30:11:46,  9.17it/s]

{'loss': Array(0.1261593, dtype=float32), 'loss_cross_entropy': Array(0.11956602, dtype=float32)}
{'loss_inverse': Array(0.00020536, dtype=float32)}


  0%|          | 3179/1000000 [09:21<29:09:59,  9.49it/s]

{'loss': Array(0.13508074, dtype=float32), 'loss_cross_entropy': Array(0.12799816, dtype=float32)}
{'loss_inverse': Array(0.00016332, dtype=float32)}


  0%|          | 3189/1000000 [09:23<48:57:52,  5.65it/s]

{'loss': Array(0.12645285, dtype=float32), 'loss_cross_entropy': Array(0.11967446, dtype=float32)}
{'loss_inverse': Array(0.0001742, dtype=float32)}


  0%|          | 3199/1000000 [09:25<31:19:07,  8.84it/s]

{'loss': Array(0.13963096, dtype=float32), 'loss_cross_entropy': Array(0.13295837, dtype=float32)}
{'loss_inverse': Array(0.00016626, dtype=float32)}


  0%|          | 3209/1000000 [09:26<28:32:18,  9.70it/s]

{'loss': Array(0.132389, dtype=float32), 'loss_cross_entropy': Array(0.12497689, dtype=float32)}
{'loss_inverse': Array(0.00015228, dtype=float32)}


  0%|          | 3219/1000000 [09:28<27:48:14,  9.96it/s]

{'loss': Array(0.1468857, dtype=float32), 'loss_cross_entropy': Array(0.13962422, dtype=float32)}
{'loss_inverse': Array(0.00015975, dtype=float32)}


  0%|          | 3229/1000000 [09:30<34:44:42,  7.97it/s]

{'loss': Array(0.12378891, dtype=float32), 'loss_cross_entropy': Array(0.11764701, dtype=float32)}
{'loss_inverse': Array(0.00014186, dtype=float32)}


  0%|          | 3239/1000000 [09:31<29:29:16,  9.39it/s]

{'loss': Array(0.11666628, dtype=float32), 'loss_cross_entropy': Array(0.11005262, dtype=float32)}
{'loss_inverse': Array(0.00013653, dtype=float32)}


  0%|          | 3249/1000000 [09:32<27:45:58,  9.97it/s]

{'loss': Array(0.12176128, dtype=float32), 'loss_cross_entropy': Array(0.11502364, dtype=float32)}
{'loss_inverse': Array(0.00012333, dtype=float32)}


  0%|          | 3259/1000000 [09:34<27:20:54, 10.12it/s]

{'loss': Array(0.14358462, dtype=float32), 'loss_cross_entropy': Array(0.13622652, dtype=float32)}
{'loss_inverse': Array(0.00017256, dtype=float32)}


  0%|          | 3269/1000000 [09:36<33:36:58,  8.24it/s]

{'loss': Array(0.12102754, dtype=float32), 'loss_cross_entropy': Array(0.11422013, dtype=float32)}
{'loss_inverse': Array(0.0001403, dtype=float32)}


  0%|          | 3279/1000000 [09:37<29:05:22,  9.52it/s]

{'loss': Array(0.12122793, dtype=float32), 'loss_cross_entropy': Array(0.11481573, dtype=float32)}
{'loss_inverse': Array(0.00015819, dtype=float32)}


  0%|          | 3289/1000000 [09:39<27:43:50,  9.98it/s]

{'loss': Array(0.14434934, dtype=float32), 'loss_cross_entropy': Array(0.13743137, dtype=float32)}
{'loss_inverse': Array(0.00016351, dtype=float32)}


  0%|          | 3299/1000000 [09:41<42:57:24,  6.45it/s]

{'loss': Array(0.11999323, dtype=float32), 'loss_cross_entropy': Array(0.11343348, dtype=float32)}
{'loss_inverse': Array(0.0001514, dtype=float32)}


  0%|          | 3309/1000000 [09:42<29:39:48,  9.33it/s]

{'loss': Array(0.12855954, dtype=float32), 'loss_cross_entropy': Array(0.12211596, dtype=float32)}
{'loss_inverse': Array(0.00015321, dtype=float32)}


  0%|          | 3319/1000000 [09:44<28:11:21,  9.82it/s]

{'loss': Array(0.14786983, dtype=float32), 'loss_cross_entropy': Array(0.14022373, dtype=float32)}
{'loss_inverse': Array(0.00017503, dtype=float32)}


  0%|          | 3329/1000000 [09:45<28:39:18,  9.66it/s]

{'loss': Array(0.12859558, dtype=float32), 'loss_cross_entropy': Array(0.122086, dtype=float32)}
{'loss_inverse': Array(0.00013703, dtype=float32)}


  0%|          | 3339/1000000 [09:47<32:49:00,  8.44it/s]

{'loss': Array(0.12506893, dtype=float32), 'loss_cross_entropy': Array(0.11888558, dtype=float32)}
{'loss_inverse': Array(0.00016078, dtype=float32)}


  0%|          | 3349/1000000 [09:49<28:52:17,  9.59it/s]

{'loss': Array(0.13147514, dtype=float32), 'loss_cross_entropy': Array(0.1247765, dtype=float32)}
{'loss_inverse': Array(0.00016085, dtype=float32)}


  0%|          | 3359/1000000 [09:50<29:58:59,  9.23it/s]

{'loss': Array(0.1358546, dtype=float32), 'loss_cross_entropy': Array(0.12915419, dtype=float32)}
{'loss_inverse': Array(0.00015502, dtype=float32)}


  0%|          | 3369/1000000 [09:52<38:59:19,  7.10it/s]

{'loss': Array(0.15094905, dtype=float32), 'loss_cross_entropy': Array(0.14307736, dtype=float32)}
{'loss_inverse': Array(0.0001509, dtype=float32)}


  0%|          | 3379/1000000 [09:54<29:59:04,  9.23it/s]

{'loss': Array(0.14679675, dtype=float32), 'loss_cross_entropy': Array(0.1391853, dtype=float32)}
{'loss_inverse': Array(0.00017482, dtype=float32)}


  0%|          | 3389/1000000 [09:55<29:44:30,  9.31it/s]

{'loss': Array(0.12791102, dtype=float32), 'loss_cross_entropy': Array(0.12096536, dtype=float32)}
{'loss_inverse': Array(0.00014502, dtype=float32)}


  0%|          | 3399/1000000 [09:56<28:04:04,  9.86it/s]

{'loss': Array(0.1276872, dtype=float32), 'loss_cross_entropy': Array(0.1205477, dtype=float32)}
{'loss_inverse': Array(0.00013421, dtype=float32)}


  0%|          | 3409/1000000 [09:58<35:06:20,  7.89it/s]

{'loss': Array(0.1464126, dtype=float32), 'loss_cross_entropy': Array(0.1390913, dtype=float32)}
{'loss_inverse': Array(0.0001527, dtype=float32)}


  0%|          | 3419/1000000 [10:00<31:04:52,  8.91it/s]

{'loss': Array(0.15069245, dtype=float32), 'loss_cross_entropy': Array(0.14279686, dtype=float32)}
{'loss_inverse': Array(0.00012759, dtype=float32)}


  0%|          | 3429/1000000 [10:01<28:41:28,  9.65it/s]

{'loss': Array(0.1337565, dtype=float32), 'loss_cross_entropy': Array(0.12689415, dtype=float32)}
{'loss_inverse': Array(0.00012575, dtype=float32)}


  0%|          | 3439/1000000 [10:03<28:09:54,  9.83it/s]

{'loss': Array(0.11836598, dtype=float32), 'loss_cross_entropy': Array(0.11198673, dtype=float32)}
{'loss_inverse': Array(0.00012806, dtype=float32)}


  0%|          | 3449/1000000 [10:05<40:24:25,  6.85it/s]

{'loss': Array(0.13611458, dtype=float32), 'loss_cross_entropy': Array(0.12883298, dtype=float32)}
{'loss_inverse': Array(0.00017842, dtype=float32)}


  0%|          | 3459/1000000 [10:06<30:28:27,  9.08it/s]

{'loss': Array(0.12140837, dtype=float32), 'loss_cross_entropy': Array(0.11466759, dtype=float32)}
{'loss_inverse': Array(0.00012465, dtype=float32)}


  0%|          | 3469/1000000 [10:08<27:35:53, 10.03it/s]

{'loss': Array(0.13188085, dtype=float32), 'loss_cross_entropy': Array(0.12499461, dtype=float32)}
{'loss_inverse': Array(0.00011289, dtype=float32)}


  0%|          | 3479/1000000 [10:09<27:36:34, 10.03it/s]

{'loss': Array(0.11120327, dtype=float32), 'loss_cross_entropy': Array(0.10504323, dtype=float32)}
{'loss_inverse': Array(0.00016287, dtype=float32)}


  0%|          | 3489/1000000 [10:11<34:26:14,  8.04it/s]

{'loss': Array(0.14486638, dtype=float32), 'loss_cross_entropy': Array(0.13748918, dtype=float32)}
{'loss_inverse': Array(0.00014059, dtype=float32)}


  0%|          | 3499/1000000 [10:13<27:52:46,  9.93it/s]

{'loss': Array(0.14230177, dtype=float32), 'loss_cross_entropy': Array(0.13512474, dtype=float32)}
{'loss_inverse': Array(0.00013818, dtype=float32)}


  0%|          | 3509/1000000 [10:21<93:03:55,  2.97it/s] 

{'loss': Array(0.14074303, dtype=float32), 'loss_cross_entropy': Array(0.13374984, dtype=float32)}
{'loss_inverse': Array(0.00012864, dtype=float32)}


  0%|          | 3519/1000000 [10:22<59:16:56,  4.67it/s] 

{'loss': Array(0.13199429, dtype=float32), 'loss_cross_entropy': Array(0.12486011, dtype=float32)}
{'loss_inverse': Array(0.00014305, dtype=float32)}


  0%|          | 3529/1000000 [10:24<32:33:48,  8.50it/s]

{'loss': Array(0.14028881, dtype=float32), 'loss_cross_entropy': Array(0.13244139, dtype=float32)}
{'loss_inverse': Array(0.00014988, dtype=float32)}


  0%|          | 3539/1000000 [10:25<28:55:06,  9.57it/s]

{'loss': Array(0.12328988, dtype=float32), 'loss_cross_entropy': Array(0.11661322, dtype=float32)}
{'loss_inverse': Array(0.00014077, dtype=float32)}


  0%|          | 3549/1000000 [10:27<27:42:52,  9.99it/s]

{'loss': Array(0.14224589, dtype=float32), 'loss_cross_entropy': Array(0.13532284, dtype=float32)}
{'loss_inverse': Array(0.00013097, dtype=float32)}


  0%|          | 3559/1000000 [10:29<37:03:31,  7.47it/s]

{'loss': Array(0.12357821, dtype=float32), 'loss_cross_entropy': Array(0.11722672, dtype=float32)}
{'loss_inverse': Array(0.00011676, dtype=float32)}


  0%|          | 3569/1000000 [10:30<29:28:28,  9.39it/s]

{'loss': Array(0.12126514, dtype=float32), 'loss_cross_entropy': Array(0.11462214, dtype=float32)}
{'loss_inverse': Array(0.00010806, dtype=float32)}


  0%|          | 3579/1000000 [10:32<27:38:11, 10.02it/s]

{'loss': Array(0.15573302, dtype=float32), 'loss_cross_entropy': Array(0.14821492, dtype=float32)}
{'loss_inverse': Array(0.0001167, dtype=float32)}


  0%|          | 3589/1000000 [10:34<41:59:43,  6.59it/s]

{'loss': Array(0.12888359, dtype=float32), 'loss_cross_entropy': Array(0.12223039, dtype=float32)}
{'loss_inverse': Array(0.00010607, dtype=float32)}


  0%|          | 3599/1000000 [10:35<31:17:11,  8.85it/s]

{'loss': Array(0.13511102, dtype=float32), 'loss_cross_entropy': Array(0.12840925, dtype=float32)}
{'loss_inverse': Array(0.00010516, dtype=float32)}


  0%|          | 3609/1000000 [10:37<28:10:26,  9.82it/s]

{'loss': Array(0.13621499, dtype=float32), 'loss_cross_entropy': Array(0.12885492, dtype=float32)}
{'loss_inverse': Array(0.00010821, dtype=float32)}


  0%|          | 3619/1000000 [10:38<27:14:16, 10.16it/s]

{'loss': Array(0.14401053, dtype=float32), 'loss_cross_entropy': Array(0.13675606, dtype=float32)}
{'loss_inverse': Array(0.00014133, dtype=float32)}


  0%|          | 3629/1000000 [10:40<33:45:22,  8.20it/s]

{'loss': Array(0.12910414, dtype=float32), 'loss_cross_entropy': Array(0.12203787, dtype=float32)}
{'loss_inverse': Array(0.00012168, dtype=float32)}


  0%|          | 3639/1000000 [10:41<28:35:52,  9.68it/s]

{'loss': Array(0.14164637, dtype=float32), 'loss_cross_entropy': Array(0.13401939, dtype=float32)}
{'loss_inverse': Array(0.00011857, dtype=float32)}


  0%|          | 3649/1000000 [10:43<27:34:23, 10.04it/s]

{'loss': Array(0.12887153, dtype=float32), 'loss_cross_entropy': Array(0.12239621, dtype=float32)}
{'loss_inverse': Array(0.00011097, dtype=float32)}


  0%|          | 3659/1000000 [10:44<26:57:23, 10.27it/s]

{'loss': Array(0.14423726, dtype=float32), 'loss_cross_entropy': Array(0.13687125, dtype=float32)}
{'loss_inverse': Array(0.00010484, dtype=float32)}


  0%|          | 3669/1000000 [10:46<32:01:10,  8.64it/s]

{'loss': Array(0.12126835, dtype=float32), 'loss_cross_entropy': Array(0.11450205, dtype=float32)}
{'loss_inverse': Array(0.00010114, dtype=float32)}


  0%|          | 3679/1000000 [10:47<27:48:38,  9.95it/s]

{'loss': Array(0.14322838, dtype=float32), 'loss_cross_entropy': Array(0.13613893, dtype=float32)}
{'loss_inverse': Array(9.501948e-05, dtype=float32)}


  0%|          | 3689/1000000 [10:49<27:16:22, 10.15it/s]

{'loss': Array(0.12734197, dtype=float32), 'loss_cross_entropy': Array(0.12112956, dtype=float32)}
{'loss_inverse': Array(8.255414e-05, dtype=float32)}


  0%|          | 3699/1000000 [10:51<47:38:15,  5.81it/s]

{'loss': Array(0.13153139, dtype=float32), 'loss_cross_entropy': Array(0.12444244, dtype=float32)}
{'loss_inverse': Array(0.00010978, dtype=float32)}


  0%|          | 3709/1000000 [10:52<30:08:42,  9.18it/s]

{'loss': Array(0.12507986, dtype=float32), 'loss_cross_entropy': Array(0.1184381, dtype=float32)}
{'loss_inverse': Array(0.00011605, dtype=float32)}


  0%|          | 3719/1000000 [10:54<27:34:11, 10.04it/s]

{'loss': Array(0.15288885, dtype=float32), 'loss_cross_entropy': Array(0.14550662, dtype=float32)}
{'loss_inverse': Array(0.00012656, dtype=float32)}


  0%|          | 3729/1000000 [10:55<28:22:53,  9.75it/s]

{'loss': Array(0.11897407, dtype=float32), 'loss_cross_entropy': Array(0.11233498, dtype=float32)}
{'loss_inverse': Array(0.00012389, dtype=float32)}


  0%|          | 3739/1000000 [10:57<37:17:11,  7.42it/s]

{'loss': Array(0.13607393, dtype=float32), 'loss_cross_entropy': Array(0.12895666, dtype=float32)}
{'loss_inverse': Array(0.00012983, dtype=float32)}


  0%|          | 3749/1000000 [10:58<28:46:38,  9.62it/s]

{'loss': Array(0.12609829, dtype=float32), 'loss_cross_entropy': Array(0.11922503, dtype=float32)}
{'loss_inverse': Array(0.00011983, dtype=float32)}


  0%|          | 3759/1000000 [11:00<29:40:45,  9.32it/s]

{'loss': Array(0.14195298, dtype=float32), 'loss_cross_entropy': Array(0.13448058, dtype=float32)}
{'loss_inverse': Array(0.00011132, dtype=float32)}


  0%|          | 3769/1000000 [11:02<48:25:38,  5.71it/s]

{'loss': Array(0.13738815, dtype=float32), 'loss_cross_entropy': Array(0.13054383, dtype=float32)}
{'loss_inverse': Array(0.00011869, dtype=float32)}


  0%|          | 3779/1000000 [11:03<30:47:57,  8.98it/s]

{'loss': Array(0.12740892, dtype=float32), 'loss_cross_entropy': Array(0.12011541, dtype=float32)}
{'loss_inverse': Array(9.2443945e-05, dtype=float32)}


  0%|          | 3789/1000000 [11:05<27:28:16, 10.07it/s]

{'loss': Array(0.12871179, dtype=float32), 'loss_cross_entropy': Array(0.12107771, dtype=float32)}
{'loss_inverse': Array(8.8819506e-05, dtype=float32)}


  0%|          | 3799/1000000 [11:06<27:49:41,  9.94it/s]

{'loss': Array(0.14331925, dtype=float32), 'loss_cross_entropy': Array(0.1355965, dtype=float32)}
{'loss_inverse': Array(0.00010432, dtype=float32)}


  0%|          | 3809/1000000 [11:08<34:19:46,  8.06it/s]

{'loss': Array(0.1315665, dtype=float32), 'loss_cross_entropy': Array(0.12420505, dtype=float32)}
{'loss_inverse': Array(0.00010457, dtype=float32)}


  0%|          | 3819/1000000 [11:10<29:46:21,  9.29it/s]

{'loss': Array(0.14589687, dtype=float32), 'loss_cross_entropy': Array(0.13838649, dtype=float32)}
{'loss_inverse': Array(8.863674e-05, dtype=float32)}


  0%|          | 3829/1000000 [11:11<28:16:27,  9.79it/s]

{'loss': Array(0.13968349, dtype=float32), 'loss_cross_entropy': Array(0.1322999, dtype=float32)}
{'loss_inverse': Array(8.7284134e-05, dtype=float32)}


  0%|          | 3839/1000000 [11:13<37:39:48,  7.35it/s]

{'loss': Array(0.13292441, dtype=float32), 'loss_cross_entropy': Array(0.12536035, dtype=float32)}
{'loss_inverse': Array(0.00010116, dtype=float32)}


  0%|          | 3849/1000000 [11:14<28:42:09,  9.64it/s]

{'loss': Array(0.12739597, dtype=float32), 'loss_cross_entropy': Array(0.12047118, dtype=float32)}
{'loss_inverse': Array(7.228582e-05, dtype=float32)}


  0%|          | 3859/1000000 [11:16<28:02:41,  9.87it/s]

{'loss': Array(0.12927279, dtype=float32), 'loss_cross_entropy': Array(0.12197914, dtype=float32)}
{'loss_inverse': Array(8.242114e-05, dtype=float32)}


  0%|          | 3869/1000000 [11:17<27:05:03, 10.22it/s]

{'loss': Array(0.12668559, dtype=float32), 'loss_cross_entropy': Array(0.11997081, dtype=float32)}
{'loss_inverse': Array(0.00012647, dtype=float32)}


  0%|          | 3879/1000000 [11:19<37:21:50,  7.41it/s]

{'loss': Array(0.15494776, dtype=float32), 'loss_cross_entropy': Array(0.14725363, dtype=float32)}
{'loss_inverse': Array(8.7399945e-05, dtype=float32)}


  0%|          | 3889/1000000 [11:21<29:25:38,  9.40it/s]

{'loss': Array(0.15641545, dtype=float32), 'loss_cross_entropy': Array(0.14854865, dtype=float32)}
{'loss_inverse': Array(0.00010564, dtype=float32)}


  0%|          | 3899/1000000 [11:22<27:29:58, 10.06it/s]

{'loss': Array(0.13858572, dtype=float32), 'loss_cross_entropy': Array(0.13169686, dtype=float32)}
{'loss_inverse': Array(0.00011637, dtype=float32)}


  0%|          | 3909/1000000 [11:24<27:16:07, 10.15it/s]

{'loss': Array(0.11016539, dtype=float32), 'loss_cross_entropy': Array(0.10457464, dtype=float32)}
{'loss_inverse': Array(0.00013516, dtype=float32)}


  0%|          | 3919/1000000 [11:25<34:19:40,  8.06it/s]

{'loss': Array(0.12016094, dtype=float32), 'loss_cross_entropy': Array(0.11348631, dtype=float32)}
{'loss_inverse': Array(0.00010947, dtype=float32)}


  0%|          | 3929/1000000 [11:27<28:40:09,  9.65it/s]

{'loss': Array(0.12483793, dtype=float32), 'loss_cross_entropy': Array(0.11879362, dtype=float32)}
{'loss_inverse': Array(9.93029e-05, dtype=float32)}


  0%|          | 3939/1000000 [11:28<27:57:54,  9.89it/s]

{'loss': Array(0.14686573, dtype=float32), 'loss_cross_entropy': Array(0.13908593, dtype=float32)}
{'loss_inverse': Array(0.00010917, dtype=float32)}


  0%|          | 3949/1000000 [11:30<42:51:32,  6.46it/s]

{'loss': Array(0.14127585, dtype=float32), 'loss_cross_entropy': Array(0.13429372, dtype=float32)}
{'loss_inverse': Array(9.088889e-05, dtype=float32)}


  0%|          | 3959/1000000 [11:32<30:13:38,  9.15it/s]

{'loss': Array(0.12656832, dtype=float32), 'loss_cross_entropy': Array(0.11979258, dtype=float32)}
{'loss_inverse': Array(9.960035e-05, dtype=float32)}


  0%|          | 3969/1000000 [11:33<27:51:28,  9.93it/s]

{'loss': Array(0.14526089, dtype=float32), 'loss_cross_entropy': Array(0.137286, dtype=float32)}
{'loss_inverse': Array(7.6573495e-05, dtype=float32)}


  0%|          | 3979/1000000 [11:35<27:35:33, 10.03it/s]

{'loss': Array(0.12725723, dtype=float32), 'loss_cross_entropy': Array(0.12090979, dtype=float32)}
{'loss_inverse': Array(8.946796e-05, dtype=float32)}


  0%|          | 3989/1000000 [11:37<32:18:21,  8.56it/s]

{'loss': Array(0.12557885, dtype=float32), 'loss_cross_entropy': Array(0.11847889, dtype=float32)}
{'loss_inverse': Array(7.471773e-05, dtype=float32)}


  0%|          | 3999/1000000 [11:38<28:30:34,  9.70it/s]

{'loss': Array(0.12679423, dtype=float32), 'loss_cross_entropy': Array(0.11984872, dtype=float32)}
{'loss_inverse': Array(6.699803e-05, dtype=float32)}


  0%|          | 4009/1000000 [11:46<97:06:59,  2.85it/s] 

{'loss': Array(0.12700017, dtype=float32), 'loss_cross_entropy': Array(0.11974978, dtype=float32)}
{'loss_inverse': Array(7.6716555e-05, dtype=float32)}


  0%|          | 4019/1000000 [11:48<38:16:49,  7.23it/s] 

{'loss': Array(0.13419203, dtype=float32), 'loss_cross_entropy': Array(0.12690882, dtype=float32)}
{'loss_inverse': Array(7.478652e-05, dtype=float32)}


  0%|          | 4029/1000000 [11:50<35:25:36,  7.81it/s]

{'loss': Array(0.1600358, dtype=float32), 'loss_cross_entropy': Array(0.15251951, dtype=float32)}
{'loss_inverse': Array(6.857243e-05, dtype=float32)}


  0%|          | 4039/1000000 [11:51<28:45:45,  9.62it/s]

{'loss': Array(0.11627492, dtype=float32), 'loss_cross_entropy': Array(0.10926729, dtype=float32)}
{'loss_inverse': Array(8.871582e-05, dtype=float32)}


  0%|          | 4049/1000000 [11:53<27:36:39, 10.02it/s]

{'loss': Array(0.13642198, dtype=float32), 'loss_cross_entropy': Array(0.12996249, dtype=float32)}
{'loss_inverse': Array(6.9388094e-05, dtype=float32)}


  0%|          | 4059/1000000 [11:54<47:35:18,  5.81it/s]

{'loss': Array(0.14190899, dtype=float32), 'loss_cross_entropy': Array(0.13412124, dtype=float32)}
{'loss_inverse': Array(0.00014245, dtype=float32)}


  0%|          | 4069/1000000 [11:56<30:57:03,  8.94it/s]

{'loss': Array(0.14382766, dtype=float32), 'loss_cross_entropy': Array(0.13656649, dtype=float32)}
{'loss_inverse': Array(0.00011115, dtype=float32)}


  0%|          | 4079/1000000 [11:57<27:32:16, 10.05it/s]

{'loss': Array(0.14834379, dtype=float32), 'loss_cross_entropy': Array(0.14081903, dtype=float32)}
{'loss_inverse': Array(0.0001189, dtype=float32)}


  0%|          | 4089/1000000 [11:59<27:27:15, 10.08it/s]

{'loss': Array(0.14246829, dtype=float32), 'loss_cross_entropy': Array(0.13493057, dtype=float32)}
{'loss_inverse': Array(0.00011035, dtype=float32)}


  0%|          | 4099/1000000 [12:01<38:00:35,  7.28it/s]

{'loss': Array(0.14197683, dtype=float32), 'loss_cross_entropy': Array(0.13455474, dtype=float32)}
{'loss_inverse': Array(8.673899e-05, dtype=float32)}


  0%|          | 4109/1000000 [12:02<28:31:19,  9.70it/s]

{'loss': Array(0.12931906, dtype=float32), 'loss_cross_entropy': Array(0.12259244, dtype=float32)}
{'loss_inverse': Array(6.682381e-05, dtype=float32)}


  0%|          | 4119/1000000 [12:04<27:19:42, 10.12it/s]

{'loss': Array(0.15646206, dtype=float32), 'loss_cross_entropy': Array(0.14868648, dtype=float32)}
{'loss_inverse': Array(6.3362524e-05, dtype=float32)}


  0%|          | 4129/1000000 [12:05<42:36:13,  6.49it/s]

{'loss': Array(0.12141369, dtype=float32), 'loss_cross_entropy': Array(0.11448355, dtype=float32)}
{'loss_inverse': Array(6.947628e-05, dtype=float32)}


  0%|          | 4139/1000000 [12:07<29:30:29,  9.37it/s]

{'loss': Array(0.11530311, dtype=float32), 'loss_cross_entropy': Array(0.10869894, dtype=float32)}
{'loss_inverse': Array(9.4042814e-05, dtype=float32)}


  0%|          | 4149/1000000 [12:08<27:26:25, 10.08it/s]

{'loss': Array(0.13371101, dtype=float32), 'loss_cross_entropy': Array(0.126598, dtype=float32)}
{'loss_inverse': Array(9.593021e-05, dtype=float32)}


  0%|          | 4159/1000000 [12:10<30:17:52,  9.13it/s]

{'loss': Array(0.13504907, dtype=float32), 'loss_cross_entropy': Array(0.12746242, dtype=float32)}
{'loss_inverse': Array(8.601477e-05, dtype=float32)}


  0%|          | 4169/1000000 [12:12<32:27:48,  8.52it/s]

{'loss': Array(0.12633093, dtype=float32), 'loss_cross_entropy': Array(0.11959245, dtype=float32)}
{'loss_inverse': Array(8.081114e-05, dtype=float32)}


  0%|          | 4179/1000000 [12:13<27:57:03,  9.90it/s]

{'loss': Array(0.13892291, dtype=float32), 'loss_cross_entropy': Array(0.13162659, dtype=float32)}
{'loss_inverse': Array(7.572231e-05, dtype=float32)}


  0%|          | 4189/1000000 [12:15<28:00:43,  9.87it/s]

{'loss': Array(0.14516538, dtype=float32), 'loss_cross_entropy': Array(0.13726526, dtype=float32)}
{'loss_inverse': Array(8.642447e-05, dtype=float32)}


  0%|          | 4199/1000000 [12:16<27:22:47, 10.10it/s]

{'loss': Array(0.11461641, dtype=float32), 'loss_cross_entropy': Array(0.10866486, dtype=float32)}
{'loss_inverse': Array(6.885407e-05, dtype=float32)}


  0%|          | 4209/1000000 [12:18<31:38:26,  8.74it/s]

{'loss': Array(0.15624855, dtype=float32), 'loss_cross_entropy': Array(0.14812331, dtype=float32)}
{'loss_inverse': Array(8.3477375e-05, dtype=float32)}


  0%|          | 4219/1000000 [12:19<28:08:12,  9.83it/s]

{'loss': Array(0.13528423, dtype=float32), 'loss_cross_entropy': Array(0.12853354, dtype=float32)}
{'loss_inverse': Array(6.568789e-05, dtype=float32)}


  0%|          | 4229/1000000 [12:21<27:59:21,  9.88it/s]

{'loss': Array(0.13779247, dtype=float32), 'loss_cross_entropy': Array(0.13081548, dtype=float32)}
{'loss_inverse': Array(6.594496e-05, dtype=float32)}


  0%|          | 4239/1000000 [12:23<47:51:03,  5.78it/s]

{'loss': Array(0.13136046, dtype=float32), 'loss_cross_entropy': Array(0.12434583, dtype=float32)}
{'loss_inverse': Array(7.182544e-05, dtype=float32)}


  0%|          | 4249/1000000 [12:24<31:11:34,  8.87it/s]

{'loss': Array(0.15588939, dtype=float32), 'loss_cross_entropy': Array(0.14805919, dtype=float32)}
{'loss_inverse': Array(7.264292e-05, dtype=float32)}


  0%|          | 4259/1000000 [12:26<28:31:37,  9.70it/s]

{'loss': Array(0.12990773, dtype=float32), 'loss_cross_entropy': Array(0.12275296, dtype=float32)}
{'loss_inverse': Array(8.513181e-05, dtype=float32)}


  0%|          | 4269/1000000 [12:27<27:25:07, 10.09it/s]

{'loss': Array(0.13013737, dtype=float32), 'loss_cross_entropy': Array(0.12345964, dtype=float32)}
{'loss_inverse': Array(7.994242e-05, dtype=float32)}


  0%|          | 4279/1000000 [12:29<37:31:56,  7.37it/s]

{'loss': Array(0.14912574, dtype=float32), 'loss_cross_entropy': Array(0.14083503, dtype=float32)}
{'loss_inverse': Array(8.543248e-05, dtype=float32)}


  0%|          | 4289/1000000 [12:31<29:27:45,  9.39it/s]

{'loss': Array(0.1342097, dtype=float32), 'loss_cross_entropy': Array(0.12719151, dtype=float32)}
{'loss_inverse': Array(6.713436e-05, dtype=float32)}


  0%|          | 4299/1000000 [12:32<27:41:38,  9.99it/s]

{'loss': Array(0.12365021, dtype=float32), 'loss_cross_entropy': Array(0.11750633, dtype=float32)}
{'loss_inverse': Array(7.247829e-05, dtype=float32)}


  0%|          | 4309/1000000 [12:34<48:11:17,  5.74it/s]

{'loss': Array(0.15098624, dtype=float32), 'loss_cross_entropy': Array(0.143467, dtype=float32)}
{'loss_inverse': Array(6.23365e-05, dtype=float32)}


  0%|          | 4319/1000000 [12:35<31:03:46,  8.90it/s]

{'loss': Array(0.14284317, dtype=float32), 'loss_cross_entropy': Array(0.13545679, dtype=float32)}
{'loss_inverse': Array(6.0548988e-05, dtype=float32)}


  0%|          | 4329/1000000 [12:37<27:33:04, 10.04it/s]

{'loss': Array(0.13808359, dtype=float32), 'loss_cross_entropy': Array(0.13077818, dtype=float32)}
{'loss_inverse': Array(7.063476e-05, dtype=float32)}


  0%|          | 4339/1000000 [12:38<27:04:52, 10.21it/s]

{'loss': Array(0.14486586, dtype=float32), 'loss_cross_entropy': Array(0.13704486, dtype=float32)}
{'loss_inverse': Array(5.3367614e-05, dtype=float32)}


  0%|          | 4349/1000000 [12:40<34:41:21,  7.97it/s]

{'loss': Array(0.12131067, dtype=float32), 'loss_cross_entropy': Array(0.11475622, dtype=float32)}
{'loss_inverse': Array(5.34912e-05, dtype=float32)}


  0%|          | 4359/1000000 [12:42<28:34:35,  9.68it/s]

{'loss': Array(0.12534192, dtype=float32), 'loss_cross_entropy': Array(0.11862627, dtype=float32)}
{'loss_inverse': Array(4.474213e-05, dtype=float32)}


  0%|          | 4369/1000000 [12:43<27:37:51, 10.01it/s]

{'loss': Array(0.1409338, dtype=float32), 'loss_cross_entropy': Array(0.1328923, dtype=float32)}
{'loss_inverse': Array(5.0177914e-05, dtype=float32)}


  0%|          | 4379/1000000 [12:45<38:44:34,  7.14it/s]

{'loss': Array(0.13585843, dtype=float32), 'loss_cross_entropy': Array(0.12846191, dtype=float32)}
{'loss_inverse': Array(4.57972e-05, dtype=float32)}


  0%|          | 4389/1000000 [12:47<29:25:09,  9.40it/s]

{'loss': Array(0.12213465, dtype=float32), 'loss_cross_entropy': Array(0.11531484, dtype=float32)}
{'loss_inverse': Array(4.7407433e-05, dtype=float32)}


  0%|          | 4399/1000000 [12:48<27:34:10, 10.03it/s]

{'loss': Array(0.13947496, dtype=float32), 'loss_cross_entropy': Array(0.13218537, dtype=float32)}
{'loss_inverse': Array(5.1497376e-05, dtype=float32)}


  0%|          | 4409/1000000 [12:49<26:45:08, 10.34it/s]

{'loss': Array(0.14907847, dtype=float32), 'loss_cross_entropy': Array(0.14150195, dtype=float32)}
{'loss_inverse': Array(5.9742855e-05, dtype=float32)}


  0%|          | 4419/1000000 [12:51<37:02:37,  7.47it/s]

{'loss': Array(0.13506177, dtype=float32), 'loss_cross_entropy': Array(0.12802762, dtype=float32)}
{'loss_inverse': Array(5.732231e-05, dtype=float32)}


  0%|          | 4429/1000000 [12:53<28:36:24,  9.67it/s]

{'loss': Array(0.13818587, dtype=float32), 'loss_cross_entropy': Array(0.13054888, dtype=float32)}
{'loss_inverse': Array(5.127563e-05, dtype=float32)}


  0%|          | 4439/1000000 [12:54<27:11:13, 10.17it/s]

{'loss': Array(0.16241565, dtype=float32), 'loss_cross_entropy': Array(0.15429492, dtype=float32)}
{'loss_inverse': Array(7.501517e-05, dtype=float32)}


  0%|          | 4449/1000000 [12:56<27:21:13, 10.11it/s]

{'loss': Array(0.15617083, dtype=float32), 'loss_cross_entropy': Array(0.1480423, dtype=float32)}
{'loss_inverse': Array(5.520847e-05, dtype=float32)}


  0%|          | 4459/1000000 [12:57<34:23:19,  8.04it/s]

{'loss': Array(0.14325081, dtype=float32), 'loss_cross_entropy': Array(0.13593163, dtype=float32)}
{'loss_inverse': Array(9.600493e-05, dtype=float32)}


  0%|          | 4469/1000000 [12:59<28:15:36,  9.79it/s]

{'loss': Array(0.1375288, dtype=float32), 'loss_cross_entropy': Array(0.13060312, dtype=float32)}
{'loss_inverse': Array(0.00013115, dtype=float32)}


  0%|          | 4479/1000000 [13:00<28:07:00,  9.84it/s]

{'loss': Array(0.12650336, dtype=float32), 'loss_cross_entropy': Array(0.11973581, dtype=float32)}
{'loss_inverse': Array(0.0001227, dtype=float32)}


  0%|          | 4489/1000000 [13:02<42:17:59,  6.54it/s]

{'loss': Array(0.13311502, dtype=float32), 'loss_cross_entropy': Array(0.12607309, dtype=float32)}
{'loss_inverse': Array(0.00013576, dtype=float32)}


  0%|          | 4499/1000000 [13:04<29:33:53,  9.35it/s]

{'loss': Array(0.13122925, dtype=float32), 'loss_cross_entropy': Array(0.12416691, dtype=float32)}
{'loss_inverse': Array(0.00020237, dtype=float32)}


  0%|          | 4509/1000000 [13:12<92:33:12,  2.99it/s] 

{'loss': Array(0.14235882, dtype=float32), 'loss_cross_entropy': Array(0.13434197, dtype=float32)}
{'loss_inverse': Array(0.00046151, dtype=float32)}


  0%|          | 4519/1000000 [13:13<37:54:30,  7.29it/s] 

{'loss': Array(0.13385373, dtype=float32), 'loss_cross_entropy': Array(0.12701373, dtype=float32)}
{'loss_inverse': Array(0.00043251, dtype=float32)}


  0%|          | 4529/1000000 [13:15<37:01:09,  7.47it/s]

{'loss': Array(0.1238456, dtype=float32), 'loss_cross_entropy': Array(0.11712305, dtype=float32)}
{'loss_inverse': Array(0.00044001, dtype=float32)}


  0%|          | 4539/1000000 [13:17<29:06:36,  9.50it/s]

{'loss': Array(0.15928064, dtype=float32), 'loss_cross_entropy': Array(0.1505357, dtype=float32)}
{'loss_inverse': Array(0.00042625, dtype=float32)}


  0%|          | 4549/1000000 [13:18<27:33:50, 10.03it/s]

{'loss': Array(0.15052795, dtype=float32), 'loss_cross_entropy': Array(0.14296281, dtype=float32)}
{'loss_inverse': Array(0.00061121, dtype=float32)}


  0%|          | 4559/1000000 [13:20<43:09:25,  6.41it/s]

{'loss': Array(0.1254925, dtype=float32), 'loss_cross_entropy': Array(0.11862985, dtype=float32)}
{'loss_inverse': Array(0.00061637, dtype=float32)}


  0%|          | 4569/1000000 [13:21<32:20:39,  8.55it/s]

{'loss': Array(0.12359805, dtype=float32), 'loss_cross_entropy': Array(0.11657732, dtype=float32)}
{'loss_inverse': Array(0.00113099, dtype=float32)}


  0%|          | 4579/1000000 [13:23<28:57:57,  9.55it/s]

{'loss': Array(0.1343802, dtype=float32), 'loss_cross_entropy': Array(0.12731944, dtype=float32)}
{'loss_inverse': Array(0.0358163, dtype=float32)}


  0%|          | 4589/1000000 [13:24<28:18:03,  9.77it/s]

{'loss': Array(0.13890271, dtype=float32), 'loss_cross_entropy': Array(0.13144536, dtype=float32)}
{'loss_inverse': Array(0.01031438, dtype=float32)}


  0%|          | 4599/1000000 [13:26<43:10:11,  6.40it/s]

{'loss': Array(0.14876747, dtype=float32), 'loss_cross_entropy': Array(0.14101918, dtype=float32)}
{'loss_inverse': Array(0.00450151, dtype=float32)}


  0%|          | 4609/1000000 [13:28<30:37:47,  9.03it/s]

{'loss': Array(0.13102919, dtype=float32), 'loss_cross_entropy': Array(0.12427229, dtype=float32)}
{'loss_inverse': Array(0.00308849, dtype=float32)}


  0%|          | 4619/1000000 [13:29<28:44:12,  9.62it/s]

{'loss': Array(0.15012042, dtype=float32), 'loss_cross_entropy': Array(0.1427694, dtype=float32)}
{'loss_inverse': Array(0.00174397, dtype=float32)}


  0%|          | 4629/1000000 [13:31<28:31:42,  9.69it/s]

{'loss': Array(0.12017132, dtype=float32), 'loss_cross_entropy': Array(0.1141718, dtype=float32)}
{'loss_inverse': Array(0.00092038, dtype=float32)}


  0%|          | 4639/1000000 [13:33<35:42:01,  7.74it/s]

{'loss': Array(0.12296255, dtype=float32), 'loss_cross_entropy': Array(0.1160579, dtype=float32)}
{'loss_inverse': Array(0.00068534, dtype=float32)}


  0%|          | 4649/1000000 [13:34<29:24:39,  9.40it/s]

{'loss': Array(0.12974475, dtype=float32), 'loss_cross_entropy': Array(0.12275845, dtype=float32)}
{'loss_inverse': Array(0.00064222, dtype=float32)}


  0%|          | 4659/1000000 [13:36<28:51:16,  9.58it/s]

{'loss': Array(0.11977589, dtype=float32), 'loss_cross_entropy': Array(0.11352461, dtype=float32)}
{'loss_inverse': Array(0.00054368, dtype=float32)}


  0%|          | 4669/1000000 [13:38<48:40:55,  5.68it/s]

{'loss': Array(0.145805, dtype=float32), 'loss_cross_entropy': Array(0.13799739, dtype=float32)}
{'loss_inverse': Array(0.00048904, dtype=float32)}


  0%|          | 4679/1000000 [13:39<33:03:11,  8.36it/s]

{'loss': Array(0.13455598, dtype=float32), 'loss_cross_entropy': Array(0.12758945, dtype=float32)}
{'loss_inverse': Array(0.00035378, dtype=float32)}


  0%|          | 4689/1000000 [13:41<29:21:21,  9.42it/s]

{'loss': Array(0.12499874, dtype=float32), 'loss_cross_entropy': Array(0.1182904, dtype=float32)}
{'loss_inverse': Array(0.00036489, dtype=float32)}


  0%|          | 4699/1000000 [13:42<27:42:06,  9.98it/s]

{'loss': Array(0.13364996, dtype=float32), 'loss_cross_entropy': Array(0.12627906, dtype=float32)}
{'loss_inverse': Array(0.00020302, dtype=float32)}


  0%|          | 4709/1000000 [13:44<35:14:56,  7.84it/s]

{'loss': Array(0.13133113, dtype=float32), 'loss_cross_entropy': Array(0.12391271, dtype=float32)}
{'loss_inverse': Array(0.00030841, dtype=float32)}


  0%|          | 4719/1000000 [13:46<30:21:30,  9.11it/s]

{'loss': Array(0.13009878, dtype=float32), 'loss_cross_entropy': Array(0.12358635, dtype=float32)}
{'loss_inverse': Array(0.00026465, dtype=float32)}


  0%|          | 4729/1000000 [13:47<28:34:00,  9.68it/s]

{'loss': Array(0.1401182, dtype=float32), 'loss_cross_entropy': Array(0.1326106, dtype=float32)}
{'loss_inverse': Array(0.00016211, dtype=float32)}


  0%|          | 4739/1000000 [13:49<27:50:21,  9.93it/s]

{'loss': Array(0.11868065, dtype=float32), 'loss_cross_entropy': Array(0.11206182, dtype=float32)}
{'loss_inverse': Array(0.00016938, dtype=float32)}


  0%|          | 4749/1000000 [13:51<33:20:32,  8.29it/s]

{'loss': Array(0.13185044, dtype=float32), 'loss_cross_entropy': Array(0.12450004, dtype=float32)}
{'loss_inverse': Array(0.00014821, dtype=float32)}


  0%|          | 4759/1000000 [13:52<28:52:40,  9.57it/s]

{'loss': Array(0.13055074, dtype=float32), 'loss_cross_entropy': Array(0.12387003, dtype=float32)}
{'loss_inverse': Array(0.00017338, dtype=float32)}


  0%|          | 4769/1000000 [13:54<28:38:20,  9.65it/s]

{'loss': Array(0.12582564, dtype=float32), 'loss_cross_entropy': Array(0.11943483, dtype=float32)}
{'loss_inverse': Array(0.00012112, dtype=float32)}


  0%|          | 4779/1000000 [13:56<44:27:57,  6.22it/s]

{'loss': Array(0.13424608, dtype=float32), 'loss_cross_entropy': Array(0.12693608, dtype=float32)}
{'loss_inverse': Array(9.788057e-05, dtype=float32)}


  0%|          | 4789/1000000 [13:57<31:00:25,  8.92it/s]

{'loss': Array(0.14684436, dtype=float32), 'loss_cross_entropy': Array(0.13846886, dtype=float32)}
{'loss_inverse': Array(0.00011651, dtype=float32)}


  0%|          | 4799/1000000 [13:59<29:13:50,  9.46it/s]

{'loss': Array(0.16349679, dtype=float32), 'loss_cross_entropy': Array(0.1555127, dtype=float32)}
{'loss_inverse': Array(0.00010532, dtype=float32)}


  0%|          | 4809/1000000 [14:00<30:04:25,  9.19it/s]

{'loss': Array(0.14420374, dtype=float32), 'loss_cross_entropy': Array(0.13640833, dtype=float32)}
{'loss_inverse': Array(8.808646e-05, dtype=float32)}


  0%|          | 4819/1000000 [14:02<33:37:22,  8.22it/s]

{'loss': Array(0.11673243, dtype=float32), 'loss_cross_entropy': Array(0.11063679, dtype=float32)}
{'loss_inverse': Array(0.00011762, dtype=float32)}


  0%|          | 4829/1000000 [14:04<29:12:08,  9.47it/s]

{'loss': Array(0.1429132, dtype=float32), 'loss_cross_entropy': Array(0.13491777, dtype=float32)}
{'loss_inverse': Array(0.00010014, dtype=float32)}


  0%|          | 4839/1000000 [14:05<29:04:12,  9.51it/s]

{'loss': Array(0.1414856, dtype=float32), 'loss_cross_entropy': Array(0.1338769, dtype=float32)}
{'loss_inverse': Array(0.00012839, dtype=float32)}


  0%|          | 4849/1000000 [14:07<38:05:49,  7.26it/s]

{'loss': Array(0.12022238, dtype=float32), 'loss_cross_entropy': Array(0.11383355, dtype=float32)}
{'loss_inverse': Array(0.00018121, dtype=float32)}


  0%|          | 4859/1000000 [14:08<29:56:54,  9.23it/s]

{'loss': Array(0.13382716, dtype=float32), 'loss_cross_entropy': Array(0.12627909, dtype=float32)}
{'loss_inverse': Array(9.223604e-05, dtype=float32)}


  0%|          | 4869/1000000 [14:10<30:09:49,  9.16it/s]

{'loss': Array(0.1549968, dtype=float32), 'loss_cross_entropy': Array(0.14629154, dtype=float32)}
{'loss_inverse': Array(0.00012747, dtype=float32)}


  0%|          | 4879/1000000 [14:11<28:34:57,  9.67it/s]

{'loss': Array(0.14695223, dtype=float32), 'loss_cross_entropy': Array(0.13898273, dtype=float32)}
{'loss_inverse': Array(0.00010205, dtype=float32)}


  0%|          | 4889/1000000 [14:13<34:47:49,  7.94it/s]

{'loss': Array(0.11672314, dtype=float32), 'loss_cross_entropy': Array(0.10994929, dtype=float32)}
{'loss_inverse': Array(0.0001101, dtype=float32)}


  0%|          | 4899/1000000 [14:15<31:41:25,  8.72it/s]

{'loss': Array(0.13572712, dtype=float32), 'loss_cross_entropy': Array(0.12864281, dtype=float32)}
{'loss_inverse': Array(7.7896846e-05, dtype=float32)}


  0%|          | 4909/1000000 [14:16<28:39:34,  9.64it/s]

{'loss': Array(0.13976994, dtype=float32), 'loss_cross_entropy': Array(0.13212049, dtype=float32)}
{'loss_inverse': Array(0.00011825, dtype=float32)}


  0%|          | 4919/1000000 [14:18<27:21:53, 10.10it/s]

{'loss': Array(0.14547746, dtype=float32), 'loss_cross_entropy': Array(0.1379116, dtype=float32)}
{'loss_inverse': Array(8.782269e-05, dtype=float32)}


  0%|          | 4929/1000000 [14:20<42:09:07,  6.56it/s]

{'loss': Array(0.1282189, dtype=float32), 'loss_cross_entropy': Array(0.12130902, dtype=float32)}
{'loss_inverse': Array(0.00012747, dtype=float32)}


  0%|          | 4939/1000000 [14:21<30:33:43,  9.04it/s]

{'loss': Array(0.14260478, dtype=float32), 'loss_cross_entropy': Array(0.13478407, dtype=float32)}
{'loss_inverse': Array(8.117794e-05, dtype=float32)}


  0%|          | 4949/1000000 [14:23<28:07:21,  9.83it/s]

{'loss': Array(0.14236042, dtype=float32), 'loss_cross_entropy': Array(0.13477331, dtype=float32)}
{'loss_inverse': Array(0.00012398, dtype=float32)}


  0%|          | 4959/1000000 [14:24<28:27:33,  9.71it/s]

{'loss': Array(0.13412192, dtype=float32), 'loss_cross_entropy': Array(0.12776475, dtype=float32)}
{'loss_inverse': Array(7.3982446e-05, dtype=float32)}


  0%|          | 4970/1000000 [14:26<31:02:25,  8.90it/s]

{'loss': Array(0.1505035, dtype=float32), 'loss_cross_entropy': Array(0.14302972, dtype=float32)}
{'loss_inverse': Array(7.9493926e-05, dtype=float32)}


  0%|          | 4980/1000000 [14:28<28:16:47,  9.77it/s]

{'loss': Array(0.13245931, dtype=float32), 'loss_cross_entropy': Array(0.12516934, dtype=float32)}
{'loss_inverse': Array(0.000106, dtype=float32)}


  0%|          | 4990/1000000 [14:29<28:38:09,  9.65it/s]

{'loss': Array(0.1368325, dtype=float32), 'loss_cross_entropy': Array(0.13005118, dtype=float32)}
{'loss_inverse': Array(9.5641764e-05, dtype=float32)}


  0%|          | 5000/1000000 [14:31<42:33:01,  6.50it/s]

{'loss': Array(0.13076158, dtype=float32), 'loss_cross_entropy': Array(0.1234115, dtype=float32)}
{'loss_inverse': Array(8.317594e-05, dtype=float32)}


  1%|          | 5010/1000000 [14:39<96:55:14,  2.85it/s] 

{'loss': Array(0.11789536, dtype=float32), 'loss_cross_entropy': Array(0.11165829, dtype=float32)}
{'loss_inverse': Array(7.101056e-05, dtype=float32)}


  1%|          | 5020/1000000 [14:41<39:23:31,  7.02it/s] 

{'loss': Array(0.12964725, dtype=float32), 'loss_cross_entropy': Array(0.12314355, dtype=float32)}
{'loss_inverse': Array(9.521626e-05, dtype=float32)}


  1%|          | 5030/1000000 [14:42<30:14:09,  9.14it/s]

{'loss': Array(0.13939779, dtype=float32), 'loss_cross_entropy': Array(0.13177834, dtype=float32)}
{'loss_inverse': Array(8.463365e-05, dtype=float32)}


  1%|          | 5040/1000000 [14:44<42:45:42,  6.46it/s]

{'loss': Array(0.12203052, dtype=float32), 'loss_cross_entropy': Array(0.11529046, dtype=float32)}
{'loss_inverse': Array(6.858193e-05, dtype=float32)}


  1%|          | 5050/1000000 [14:46<30:41:29,  9.00it/s]

{'loss': Array(0.13085322, dtype=float32), 'loss_cross_entropy': Array(0.12411853, dtype=float32)}
{'loss_inverse': Array(0.00012065, dtype=float32)}


  1%|          | 5060/1000000 [14:47<29:16:18,  9.44it/s]

{'loss': Array(0.13210218, dtype=float32), 'loss_cross_entropy': Array(0.12512457, dtype=float32)}
{'loss_inverse': Array(6.854012e-05, dtype=float32)}


  1%|          | 5070/1000000 [14:49<28:28:41,  9.70it/s]

{'loss': Array(0.15091895, dtype=float32), 'loss_cross_entropy': Array(0.14291194, dtype=float32)}
{'loss_inverse': Array(7.866238e-05, dtype=float32)}


  1%|          | 5080/1000000 [14:51<34:16:19,  8.06it/s]

{'loss': Array(0.15026459, dtype=float32), 'loss_cross_entropy': Array(0.14243606, dtype=float32)}
{'loss_inverse': Array(7.611827e-05, dtype=float32)}


  1%|          | 5090/1000000 [14:52<29:41:59,  9.31it/s]

{'loss': Array(0.11072214, dtype=float32), 'loss_cross_entropy': Array(0.10465974, dtype=float32)}
{'loss_inverse': Array(7.8570585e-05, dtype=float32)}


  1%|          | 5100/1000000 [14:54<28:51:56,  9.57it/s]

{'loss': Array(0.132928, dtype=float32), 'loss_cross_entropy': Array(0.1254993, dtype=float32)}
{'loss_inverse': Array(7.75925e-05, dtype=float32)}


  1%|          | 5110/1000000 [14:56<43:12:05,  6.40it/s]

{'loss': Array(0.11901404, dtype=float32), 'loss_cross_entropy': Array(0.11227205, dtype=float32)}
{'loss_inverse': Array(0.00010016, dtype=float32)}


  1%|          | 5120/1000000 [14:57<30:20:53,  9.11it/s]

{'loss': Array(0.12784714, dtype=float32), 'loss_cross_entropy': Array(0.12104284, dtype=float32)}
{'loss_inverse': Array(8.845864e-05, dtype=float32)}


  1%|          | 5130/1000000 [14:59<28:34:54,  9.67it/s]

{'loss': Array(0.13274339, dtype=float32), 'loss_cross_entropy': Array(0.12536605, dtype=float32)}
{'loss_inverse': Array(7.924445e-05, dtype=float32)}


  1%|          | 5140/1000000 [15:00<28:54:28,  9.56it/s]

{'loss': Array(0.13146941, dtype=float32), 'loss_cross_entropy': Array(0.1249453, dtype=float32)}
{'loss_inverse': Array(7.106358e-05, dtype=float32)}


  1%|          | 5150/1000000 [15:02<42:15:36,  6.54it/s]

{'loss': Array(0.12958606, dtype=float32), 'loss_cross_entropy': Array(0.12261178, dtype=float32)}
{'loss_inverse': Array(6.843116e-05, dtype=float32)}


  1%|          | 5160/1000000 [15:04<30:29:32,  9.06it/s]

{'loss': Array(0.12447287, dtype=float32), 'loss_cross_entropy': Array(0.11801939, dtype=float32)}
{'loss_inverse': Array(9.555672e-05, dtype=float32)}


  1%|          | 5170/1000000 [15:05<29:16:20,  9.44it/s]

{'loss': Array(0.13002615, dtype=float32), 'loss_cross_entropy': Array(0.12333923, dtype=float32)}
{'loss_inverse': Array(6.682856e-05, dtype=float32)}


  1%|          | 5180/1000000 [15:07<28:35:30,  9.66it/s]

{'loss': Array(0.13361943, dtype=float32), 'loss_cross_entropy': Array(0.12743635, dtype=float32)}
{'loss_inverse': Array(6.758012e-05, dtype=float32)}


  1%|          | 5190/1000000 [15:09<35:04:38,  7.88it/s]

{'loss': Array(0.12448581, dtype=float32), 'loss_cross_entropy': Array(0.11829566, dtype=float32)}
{'loss_inverse': Array(7.2023126e-05, dtype=float32)}


  1%|          | 5200/1000000 [15:10<30:15:54,  9.13it/s]

{'loss': Array(0.1138001, dtype=float32), 'loss_cross_entropy': Array(0.10733008, dtype=float32)}
{'loss_inverse': Array(5.5382647e-05, dtype=float32)}


  1%|          | 5210/1000000 [15:12<28:22:01,  9.74it/s]

{'loss': Array(0.1402453, dtype=float32), 'loss_cross_entropy': Array(0.13303927, dtype=float32)}
{'loss_inverse': Array(4.7861893e-05, dtype=float32)}


  1%|          | 5220/1000000 [15:14<42:26:54,  6.51it/s]

{'loss': Array(0.13370548, dtype=float32), 'loss_cross_entropy': Array(0.12642539, dtype=float32)}
{'loss_inverse': Array(4.9877457e-05, dtype=float32)}


  1%|          | 5230/1000000 [15:15<31:48:51,  8.69it/s]

{'loss': Array(0.12316649, dtype=float32), 'loss_cross_entropy': Array(0.11642623, dtype=float32)}
{'loss_inverse': Array(6.497058e-05, dtype=float32)}


  1%|          | 5240/1000000 [15:17<29:14:32,  9.45it/s]

{'loss': Array(0.13823682, dtype=float32), 'loss_cross_entropy': Array(0.13116047, dtype=float32)}
{'loss_inverse': Array(4.620693e-05, dtype=float32)}


  1%|          | 5250/1000000 [15:18<28:12:28,  9.80it/s]

{'loss': Array(0.15496486, dtype=float32), 'loss_cross_entropy': Array(0.1467852, dtype=float32)}
{'loss_inverse': Array(4.967913e-05, dtype=float32)}


  1%|          | 5260/1000000 [15:20<33:48:21,  8.17it/s]

{'loss': Array(0.1098357, dtype=float32), 'loss_cross_entropy': Array(0.10361272, dtype=float32)}
{'loss_inverse': Array(4.622573e-05, dtype=float32)}


  1%|          | 5270/1000000 [15:21<28:44:00,  9.62it/s]

{'loss': Array(0.12323407, dtype=float32), 'loss_cross_entropy': Array(0.1163216, dtype=float32)}
{'loss_inverse': Array(3.7123693e-05, dtype=float32)}


  1%|          | 5280/1000000 [15:23<28:07:33,  9.82it/s]

{'loss': Array(0.1402678, dtype=float32), 'loss_cross_entropy': Array(0.13320675, dtype=float32)}
{'loss_inverse': Array(6.187737e-05, dtype=float32)}


  1%|          | 5290/1000000 [15:25<49:53:44,  5.54it/s]

{'loss': Array(0.12043224, dtype=float32), 'loss_cross_entropy': Array(0.11395317, dtype=float32)}
{'loss_inverse': Array(5.0214032e-05, dtype=float32)}


  1%|          | 5300/1000000 [15:26<32:32:22,  8.49it/s]

{'loss': Array(0.16070823, dtype=float32), 'loss_cross_entropy': Array(0.15257059, dtype=float32)}
{'loss_inverse': Array(5.978464e-05, dtype=float32)}


  1%|          | 5310/1000000 [15:28<29:14:28,  9.45it/s]

{'loss': Array(0.13226211, dtype=float32), 'loss_cross_entropy': Array(0.12490722, dtype=float32)}
{'loss_inverse': Array(5.1926596e-05, dtype=float32)}


  1%|          | 5320/1000000 [15:29<28:07:51,  9.82it/s]

{'loss': Array(0.13533889, dtype=float32), 'loss_cross_entropy': Array(0.12806498, dtype=float32)}
{'loss_inverse': Array(4.8176324e-05, dtype=float32)}


  1%|          | 5330/1000000 [15:31<43:19:20,  6.38it/s]

{'loss': Array(0.13522922, dtype=float32), 'loss_cross_entropy': Array(0.12796077, dtype=float32)}
{'loss_inverse': Array(4.78799e-05, dtype=float32)}


  1%|          | 5340/1000000 [15:33<30:49:16,  8.96it/s]

{'loss': Array(0.1411686, dtype=float32), 'loss_cross_entropy': Array(0.1336123, dtype=float32)}
{'loss_inverse': Array(4.844159e-05, dtype=float32)}


  1%|          | 5350/1000000 [15:34<28:32:13,  9.68it/s]

{'loss': Array(0.14589652, dtype=float32), 'loss_cross_entropy': Array(0.13782378, dtype=float32)}
{'loss_inverse': Array(3.603119e-05, dtype=float32)}


  1%|          | 5360/1000000 [15:36<29:59:16,  9.21it/s]

{'loss': Array(0.14394937, dtype=float32), 'loss_cross_entropy': Array(0.13663, dtype=float32)}
{'loss_inverse': Array(4.1008312e-05, dtype=float32)}


  1%|          | 5370/1000000 [15:38<33:38:54,  8.21it/s]

{'loss': Array(0.12923972, dtype=float32), 'loss_cross_entropy': Array(0.12267376, dtype=float32)}
{'loss_inverse': Array(4.6988964e-05, dtype=float32)}


  1%|          | 5380/1000000 [15:39<29:23:29,  9.40it/s]

{'loss': Array(0.12086892, dtype=float32), 'loss_cross_entropy': Array(0.11435585, dtype=float32)}
{'loss_inverse': Array(4.6333425e-05, dtype=float32)}


  1%|          | 5390/1000000 [15:41<28:16:18,  9.77it/s]

{'loss': Array(0.15424405, dtype=float32), 'loss_cross_entropy': Array(0.14595057, dtype=float32)}
{'loss_inverse': Array(6.0292747e-05, dtype=float32)}


  1%|          | 5400/1000000 [15:43<38:15:01,  7.22it/s]

{'loss': Array(0.12907912, dtype=float32), 'loss_cross_entropy': Array(0.12205564, dtype=float32)}
{'loss_inverse': Array(4.6199413e-05, dtype=float32)}


  1%|          | 5410/1000000 [15:44<29:15:04,  9.44it/s]

{'loss': Array(0.12693566, dtype=float32), 'loss_cross_entropy': Array(0.12068392, dtype=float32)}
{'loss_inverse': Array(4.6153767e-05, dtype=float32)}


  1%|          | 5420/1000000 [15:46<29:34:59,  9.34it/s]

{'loss': Array(0.14654939, dtype=float32), 'loss_cross_entropy': Array(0.13906728, dtype=float32)}
{'loss_inverse': Array(4.30084e-05, dtype=float32)}


  1%|          | 5430/1000000 [15:47<28:27:59,  9.71it/s]

{'loss': Array(0.13539316, dtype=float32), 'loss_cross_entropy': Array(0.12902215, dtype=float32)}
{'loss_inverse': Array(4.0242026e-05, dtype=float32)}


  1%|          | 5440/1000000 [15:49<34:49:00,  7.93it/s]

{'loss': Array(0.1445325, dtype=float32), 'loss_cross_entropy': Array(0.13713776, dtype=float32)}
{'loss_inverse': Array(7.023805e-05, dtype=float32)}


  1%|          | 5450/1000000 [15:51<29:46:13,  9.28it/s]

{'loss': Array(0.14799333, dtype=float32), 'loss_cross_entropy': Array(0.14020461, dtype=float32)}
{'loss_inverse': Array(4.561668e-05, dtype=float32)}


  1%|          | 5460/1000000 [15:52<28:00:02,  9.87it/s]

{'loss': Array(0.14793274, dtype=float32), 'loss_cross_entropy': Array(0.14072166, dtype=float32)}
{'loss_inverse': Array(4.8330352e-05, dtype=float32)}


  1%|          | 5470/1000000 [15:54<28:15:18,  9.78it/s]

{'loss': Array(0.14424068, dtype=float32), 'loss_cross_entropy': Array(0.1368305, dtype=float32)}
{'loss_inverse': Array(5.738798e-05, dtype=float32)}


  1%|          | 5480/1000000 [15:56<35:50:13,  7.71it/s]

{'loss': Array(0.15476517, dtype=float32), 'loss_cross_entropy': Array(0.14683978, dtype=float32)}
{'loss_inverse': Array(9.064074e-05, dtype=float32)}


  1%|          | 5490/1000000 [15:57<31:26:31,  8.79it/s]

{'loss': Array(0.12706617, dtype=float32), 'loss_cross_entropy': Array(0.11986364, dtype=float32)}
{'loss_inverse': Array(8.1098246e-05, dtype=float32)}


  1%|          | 5500/1000000 [15:59<28:27:23,  9.71it/s]

{'loss': Array(0.15383436, dtype=float32), 'loss_cross_entropy': Array(0.14619274, dtype=float32)}
{'loss_inverse': Array(5.8901736e-05, dtype=float32)}


  1%|          | 5510/1000000 [16:07<94:50:11,  2.91it/s] 

{'loss': Array(0.14348564, dtype=float32), 'loss_cross_entropy': Array(0.13586037, dtype=float32)}
{'loss_inverse': Array(5.7368663e-05, dtype=float32)}


  1%|          | 5520/1000000 [16:09<44:09:10,  6.26it/s] 

{'loss': Array(0.14133461, dtype=float32), 'loss_cross_entropy': Array(0.13397385, dtype=float32)}
{'loss_inverse': Array(4.491712e-05, dtype=float32)}


  1%|          | 5530/1000000 [16:10<31:34:12,  8.75it/s]

{'loss': Array(0.12414088, dtype=float32), 'loss_cross_entropy': Array(0.11731931, dtype=float32)}
{'loss_inverse': Array(4.972467e-05, dtype=float32)}


  1%|          | 5540/1000000 [16:12<29:13:37,  9.45it/s]

{'loss': Array(0.13200684, dtype=float32), 'loss_cross_entropy': Array(0.12517393, dtype=float32)}
{'loss_inverse': Array(4.7226844e-05, dtype=float32)}


  1%|          | 5550/1000000 [16:14<38:05:47,  7.25it/s]

{'loss': Array(0.13416816, dtype=float32), 'loss_cross_entropy': Array(0.12760161, dtype=float32)}
{'loss_inverse': Array(4.3007047e-05, dtype=float32)}


  1%|          | 5560/1000000 [16:15<31:48:59,  8.68it/s]

{'loss': Array(0.14733265, dtype=float32), 'loss_cross_entropy': Array(0.13950911, dtype=float32)}
{'loss_inverse': Array(3.4094912e-05, dtype=float32)}


  1%|          | 5570/1000000 [16:17<30:29:14,  9.06it/s]

{'loss': Array(0.1527162, dtype=float32), 'loss_cross_entropy': Array(0.14529411, dtype=float32)}
{'loss_inverse': Array(6.958085e-05, dtype=float32)}


  1%|          | 5580/1000000 [16:19<42:32:56,  6.49it/s]

{'loss': Array(0.1640697, dtype=float32), 'loss_cross_entropy': Array(0.15620618, dtype=float32)}
{'loss_inverse': Array(3.7924627e-05, dtype=float32)}


  1%|          | 5590/1000000 [16:20<31:42:50,  8.71it/s]

{'loss': Array(0.12113614, dtype=float32), 'loss_cross_entropy': Array(0.11459314, dtype=float32)}
{'loss_inverse': Array(5.1768293e-05, dtype=float32)}


  1%|          | 5600/1000000 [16:22<28:27:50,  9.70it/s]

{'loss': Array(0.14122336, dtype=float32), 'loss_cross_entropy': Array(0.13368267, dtype=float32)}
{'loss_inverse': Array(5.7292636e-05, dtype=float32)}


  1%|          | 5610/1000000 [16:23<28:25:12,  9.72it/s]

{'loss': Array(0.1465638, dtype=float32), 'loss_cross_entropy': Array(0.13858469, dtype=float32)}
{'loss_inverse': Array(4.8646118e-05, dtype=float32)}


  1%|          | 5620/1000000 [16:25<34:19:36,  8.05it/s]

{'loss': Array(0.13067882, dtype=float32), 'loss_cross_entropy': Array(0.12375476, dtype=float32)}
{'loss_inverse': Array(4.8566384e-05, dtype=float32)}


  1%|          | 5630/1000000 [16:27<28:36:35,  9.65it/s]

{'loss': Array(0.13410906, dtype=float32), 'loss_cross_entropy': Array(0.12660427, dtype=float32)}
{'loss_inverse': Array(4.713658e-05, dtype=float32)}


  1%|          | 5640/1000000 [16:28<28:07:02,  9.82it/s]

{'loss': Array(0.11593091, dtype=float32), 'loss_cross_entropy': Array(0.10987695, dtype=float32)}
{'loss_inverse': Array(3.478261e-05, dtype=float32)}


  1%|          | 5650/1000000 [16:30<43:21:37,  6.37it/s]

{'loss': Array(0.15456167, dtype=float32), 'loss_cross_entropy': Array(0.14706233, dtype=float32)}
{'loss_inverse': Array(3.529701e-05, dtype=float32)}


  1%|          | 5660/1000000 [16:31<30:42:45,  8.99it/s]

{'loss': Array(0.13889816, dtype=float32), 'loss_cross_entropy': Array(0.13134788, dtype=float32)}
{'loss_inverse': Array(4.1561154e-05, dtype=float32)}


  1%|          | 5670/1000000 [16:33<28:23:51,  9.73it/s]

{'loss': Array(0.1410871, dtype=float32), 'loss_cross_entropy': Array(0.1338121, dtype=float32)}
{'loss_inverse': Array(3.2961434e-05, dtype=float32)}


  1%|          | 5680/1000000 [16:34<28:57:38,  9.54it/s]

{'loss': Array(0.1420555, dtype=float32), 'loss_cross_entropy': Array(0.13407609, dtype=float32)}
{'loss_inverse': Array(3.6917543e-05, dtype=float32)}


  1%|          | 5690/1000000 [16:36<28:34:07,  9.67it/s]

{'loss': Array(0.15080655, dtype=float32), 'loss_cross_entropy': Array(0.14250588, dtype=float32)}
{'loss_inverse': Array(3.3388493e-05, dtype=float32)}


  1%|          | 5700/1000000 [16:38<33:34:54,  8.22it/s]

{'loss': Array(0.13876992, dtype=float32), 'loss_cross_entropy': Array(0.13141297, dtype=float32)}
{'loss_inverse': Array(3.2699787e-05, dtype=float32)}


  1%|          | 5710/1000000 [16:39<28:43:02,  9.62it/s]

{'loss': Array(0.1381135, dtype=float32), 'loss_cross_entropy': Array(0.13062486, dtype=float32)}
{'loss_inverse': Array(3.7400572e-05, dtype=float32)}


  1%|          | 5720/1000000 [16:41<28:48:36,  9.59it/s]

{'loss': Array(0.12431498, dtype=float32), 'loss_cross_entropy': Array(0.11734574, dtype=float32)}
{'loss_inverse': Array(3.4281973e-05, dtype=float32)}


  1%|          | 5730/1000000 [16:42<28:21:35,  9.74it/s]

{'loss': Array(0.13176456, dtype=float32), 'loss_cross_entropy': Array(0.12468227, dtype=float32)}
{'loss_inverse': Array(3.0118394e-05, dtype=float32)}


  1%|          | 5740/1000000 [16:44<33:02:49,  8.36it/s]

{'loss': Array(0.13425435, dtype=float32), 'loss_cross_entropy': Array(0.12726143, dtype=float32)}
{'loss_inverse': Array(2.7706084e-05, dtype=float32)}


  1%|          | 5750/1000000 [16:46<29:27:01,  9.38it/s]

{'loss': Array(0.14175415, dtype=float32), 'loss_cross_entropy': Array(0.13441488, dtype=float32)}
{'loss_inverse': Array(3.2360644e-05, dtype=float32)}


  1%|          | 5760/1000000 [16:47<27:44:44,  9.95it/s]

{'loss': Array(0.160515, dtype=float32), 'loss_cross_entropy': Array(0.15252897, dtype=float32)}
{'loss_inverse': Array(2.997645e-05, dtype=float32)}


  1%|          | 5770/1000000 [16:49<37:33:09,  7.35it/s]

{'loss': Array(0.13750528, dtype=float32), 'loss_cross_entropy': Array(0.12994291, dtype=float32)}
{'loss_inverse': Array(4.6676698e-05, dtype=float32)}


  1%|          | 5780/1000000 [16:51<31:15:44,  8.83it/s]

{'loss': Array(0.1416748, dtype=float32), 'loss_cross_entropy': Array(0.13405712, dtype=float32)}
{'loss_inverse': Array(2.7981965e-05, dtype=float32)}


  1%|          | 5790/1000000 [16:52<28:04:13,  9.84it/s]

{'loss': Array(0.14083926, dtype=float32), 'loss_cross_entropy': Array(0.13332257, dtype=float32)}
{'loss_inverse': Array(2.9341778e-05, dtype=float32)}


  1%|          | 5800/1000000 [16:54<42:28:26,  6.50it/s]

{'loss': Array(0.14573179, dtype=float32), 'loss_cross_entropy': Array(0.1376234, dtype=float32)}
{'loss_inverse': Array(2.7967364e-05, dtype=float32)}


  1%|          | 5810/1000000 [16:56<31:03:42,  8.89it/s]

{'loss': Array(0.12930274, dtype=float32), 'loss_cross_entropy': Array(0.12228469, dtype=float32)}
{'loss_inverse': Array(3.0205341e-05, dtype=float32)}


  1%|          | 5820/1000000 [16:57<28:32:11,  9.68it/s]

{'loss': Array(0.13799602, dtype=float32), 'loss_cross_entropy': Array(0.13068743, dtype=float32)}
{'loss_inverse': Array(2.7648666e-05, dtype=float32)}


  1%|          | 5830/1000000 [16:59<28:38:58,  9.64it/s]

{'loss': Array(0.13113968, dtype=float32), 'loss_cross_entropy': Array(0.12421119, dtype=float32)}
{'loss_inverse': Array(2.7693879e-05, dtype=float32)}


  1%|          | 5840/1000000 [17:01<50:14:25,  5.50it/s]

{'loss': Array(0.13679367, dtype=float32), 'loss_cross_entropy': Array(0.1294663, dtype=float32)}
{'loss_inverse': Array(2.6834512e-05, dtype=float32)}


  1%|          | 5849/1000000 [17:02<31:56:56,  8.64it/s]

{'loss': Array(0.14831558, dtype=float32), 'loss_cross_entropy': Array(0.14056613, dtype=float32)}
{'loss_inverse': Array(3.099169e-05, dtype=float32)}


  1%|          | 5859/1000000 [17:04<28:44:34,  9.61it/s]

{'loss': Array(0.13836586, dtype=float32), 'loss_cross_entropy': Array(0.13136293, dtype=float32)}
{'loss_inverse': Array(2.3175662e-05, dtype=float32)}


  1%|          | 5869/1000000 [17:05<30:16:22,  9.12it/s]

{'loss': Array(0.14275216, dtype=float32), 'loss_cross_entropy': Array(0.13489823, dtype=float32)}
{'loss_inverse': Array(3.6200232e-05, dtype=float32)}


  1%|          | 5879/1000000 [17:07<43:13:07,  6.39it/s]

{'loss': Array(0.11633956, dtype=float32), 'loss_cross_entropy': Array(0.10971429, dtype=float32)}
{'loss_inverse': Array(3.300139e-05, dtype=float32)}


  1%|          | 5889/1000000 [17:09<31:15:16,  8.84it/s]

{'loss': Array(0.1610518, dtype=float32), 'loss_cross_entropy': Array(0.15281071, dtype=float32)}
{'loss_inverse': Array(2.8409751e-05, dtype=float32)}


  1%|          | 5899/1000000 [17:10<30:11:38,  9.15it/s]

{'loss': Array(0.13465677, dtype=float32), 'loss_cross_entropy': Array(0.1275693, dtype=float32)}
{'loss_inverse': Array(2.8146822e-05, dtype=float32)}


  1%|          | 5909/1000000 [17:12<28:38:15,  9.64it/s]

{'loss': Array(0.12500483, dtype=float32), 'loss_cross_entropy': Array(0.11818756, dtype=float32)}
{'loss_inverse': Array(2.6194117e-05, dtype=float32)}


  1%|          | 5920/1000000 [17:14<31:05:05,  8.88it/s]

{'loss': Array(0.13369198, dtype=float32), 'loss_cross_entropy': Array(0.12623934, dtype=float32)}
{'loss_inverse': Array(2.5018342e-05, dtype=float32)}


  1%|          | 5930/1000000 [17:15<29:30:43,  9.36it/s]

{'loss': Array(0.13617606, dtype=float32), 'loss_cross_entropy': Array(0.12895817, dtype=float32)}
{'loss_inverse': Array(6.105424e-05, dtype=float32)}


  1%|          | 5940/1000000 [17:17<30:11:31,  9.15it/s]

{'loss': Array(0.15182841, dtype=float32), 'loss_cross_entropy': Array(0.1443803, dtype=float32)}
{'loss_inverse': Array(8.255429e-05, dtype=float32)}


  1%|          | 5950/1000000 [17:19<38:49:04,  7.11it/s]

{'loss': Array(0.13700567, dtype=float32), 'loss_cross_entropy': Array(0.129531, dtype=float32)}
{'loss_inverse': Array(6.08832e-05, dtype=float32)}


  1%|          | 5960/1000000 [17:20<31:01:28,  8.90it/s]

{'loss': Array(0.14204247, dtype=float32), 'loss_cross_entropy': Array(0.13492224, dtype=float32)}
{'loss_inverse': Array(8.537668e-05, dtype=float32)}


  1%|          | 5970/1000000 [17:22<28:31:30,  9.68it/s]

{'loss': Array(0.14369707, dtype=float32), 'loss_cross_entropy': Array(0.13698372, dtype=float32)}
{'loss_inverse': Array(5.820102e-05, dtype=float32)}


  1%|          | 5980/1000000 [17:23<28:15:35,  9.77it/s]

{'loss': Array(0.14123736, dtype=float32), 'loss_cross_entropy': Array(0.13417573, dtype=float32)}
{'loss_inverse': Array(4.219475e-05, dtype=float32)}


  1%|          | 5990/1000000 [17:25<36:18:54,  7.60it/s]

{'loss': Array(0.14222687, dtype=float32), 'loss_cross_entropy': Array(0.13551874, dtype=float32)}
{'loss_inverse': Array(3.07118e-05, dtype=float32)}


  1%|          | 6000/1000000 [17:27<29:50:55,  9.25it/s]

{'loss': Array(0.15314107, dtype=float32), 'loss_cross_entropy': Array(0.14501159, dtype=float32)}
{'loss_inverse': Array(4.774142e-05, dtype=float32)}


  1%|          | 6010/1000000 [17:35<98:05:39,  2.81it/s] 

{'loss': Array(0.1401286, dtype=float32), 'loss_cross_entropy': Array(0.13313745, dtype=float32)}
{'loss_inverse': Array(2.9663608e-05, dtype=float32)}


  1%|          | 6020/1000000 [17:37<40:49:26,  6.76it/s] 

{'loss': Array(0.14294265, dtype=float32), 'loss_cross_entropy': Array(0.1354373, dtype=float32)}
{'loss_inverse': Array(3.0093128e-05, dtype=float32)}


  1%|          | 6030/1000000 [17:39<34:57:37,  7.90it/s]

{'loss': Array(0.13282506, dtype=float32), 'loss_cross_entropy': Array(0.12580146, dtype=float32)}
{'loss_inverse': Array(3.1905747e-05, dtype=float32)}


  1%|          | 6040/1000000 [17:40<30:56:53,  8.92it/s]

{'loss': Array(0.13935032, dtype=float32), 'loss_cross_entropy': Array(0.13167746, dtype=float32)}
{'loss_inverse': Array(3.419243e-05, dtype=float32)}


  1%|          | 6050/1000000 [17:41<28:43:29,  9.61it/s]

{'loss': Array(0.13630356, dtype=float32), 'loss_cross_entropy': Array(0.12886904, dtype=float32)}
{'loss_inverse': Array(5.0219765e-05, dtype=float32)}


  1%|          | 6060/1000000 [17:43<42:44:54,  6.46it/s]

{'loss': Array(0.13852526, dtype=float32), 'loss_cross_entropy': Array(0.13125281, dtype=float32)}
{'loss_inverse': Array(5.605229e-05, dtype=float32)}


  1%|          | 6070/1000000 [17:45<32:39:07,  8.46it/s]

{'loss': Array(0.12739952, dtype=float32), 'loss_cross_entropy': Array(0.12076364, dtype=float32)}
{'loss_inverse': Array(4.3110318e-05, dtype=float32)}


  1%|          | 6080/1000000 [17:46<29:22:10,  9.40it/s]

{'loss': Array(0.14593704, dtype=float32), 'loss_cross_entropy': Array(0.13863966, dtype=float32)}
{'loss_inverse': Array(2.8975846e-05, dtype=float32)}


  1%|          | 6090/1000000 [17:48<28:12:35,  9.79it/s]

{'loss': Array(0.11539447, dtype=float32), 'loss_cross_entropy': Array(0.10920115, dtype=float32)}
{'loss_inverse': Array(3.5332832e-05, dtype=float32)}


  1%|          | 6100/1000000 [17:50<34:36:13,  7.98it/s]

{'loss': Array(0.13008513, dtype=float32), 'loss_cross_entropy': Array(0.12313672, dtype=float32)}
{'loss_inverse': Array(4.7266974e-05, dtype=float32)}


  1%|          | 6110/1000000 [17:51<29:34:08,  9.34it/s]

{'loss': Array(0.13620044, dtype=float32), 'loss_cross_entropy': Array(0.1287133, dtype=float32)}
{'loss_inverse': Array(4.7906015e-05, dtype=float32)}


  1%|          | 6120/1000000 [17:53<27:44:53,  9.95it/s]

{'loss': Array(0.11445693, dtype=float32), 'loss_cross_entropy': Array(0.10835586, dtype=float32)}
{'loss_inverse': Array(0.00010779, dtype=float32)}


  1%|          | 6130/1000000 [17:55<44:15:15,  6.24it/s]

{'loss': Array(0.1285549, dtype=float32), 'loss_cross_entropy': Array(0.12177237, dtype=float32)}
{'loss_inverse': Array(0.00017477, dtype=float32)}


  1%|          | 6140/1000000 [17:56<30:57:05,  8.92it/s]

{'loss': Array(0.12618144, dtype=float32), 'loss_cross_entropy': Array(0.12011307, dtype=float32)}
{'loss_inverse': Array(0.0001632, dtype=float32)}


  1%|          | 6150/1000000 [17:58<28:40:00,  9.63it/s]

{'loss': Array(0.15411961, dtype=float32), 'loss_cross_entropy': Array(0.14646712, dtype=float32)}
{'loss_inverse': Array(0.00019742, dtype=float32)}


  1%|          | 6160/1000000 [17:59<28:08:27,  9.81it/s]

{'loss': Array(0.12656318, dtype=float32), 'loss_cross_entropy': Array(0.11972704, dtype=float32)}
{'loss_inverse': Array(0.00040565, dtype=float32)}


  1%|          | 6170/1000000 [18:01<33:45:32,  8.18it/s]

{'loss': Array(0.13523823, dtype=float32), 'loss_cross_entropy': Array(0.128327, dtype=float32)}
{'loss_inverse': Array(0.0235283, dtype=float32)}


  1%|          | 6180/1000000 [18:03<29:24:39,  9.39it/s]

{'loss': Array(0.13365924, dtype=float32), 'loss_cross_entropy': Array(0.12677024, dtype=float32)}
{'loss_inverse': Array(0.00443875, dtype=float32)}


  1%|          | 6190/1000000 [18:04<28:03:36,  9.84it/s]

{'loss': Array(0.13429426, dtype=float32), 'loss_cross_entropy': Array(0.1269645, dtype=float32)}
{'loss_inverse': Array(0.00460282, dtype=float32)}


  1%|          | 6200/1000000 [18:06<28:56:08,  9.54it/s]

{'loss': Array(0.11709856, dtype=float32), 'loss_cross_entropy': Array(0.11101791, dtype=float32)}
{'loss_inverse': Array(0.00298832, dtype=float32)}


  1%|          | 6210/1000000 [18:08<34:48:46,  7.93it/s]

{'loss': Array(0.1362258, dtype=float32), 'loss_cross_entropy': Array(0.12891755, dtype=float32)}
{'loss_inverse': Array(0.00168058, dtype=float32)}


  1%|          | 6220/1000000 [18:09<28:28:42,  9.69it/s]

{'loss': Array(0.1276826, dtype=float32), 'loss_cross_entropy': Array(0.12016888, dtype=float32)}
{'loss_inverse': Array(0.00107709, dtype=float32)}


  1%|          | 6230/1000000 [18:11<28:26:00,  9.71it/s]

{'loss': Array(0.12737538, dtype=float32), 'loss_cross_entropy': Array(0.12013374, dtype=float32)}
{'loss_inverse': Array(0.00061361, dtype=float32)}


  1%|          | 6240/1000000 [18:13<48:14:19,  5.72it/s]

{'loss': Array(0.15047656, dtype=float32), 'loss_cross_entropy': Array(0.1426705, dtype=float32)}
{'loss_inverse': Array(0.00069179, dtype=float32)}


  1%|          | 6250/1000000 [18:14<31:42:05,  8.71it/s]

{'loss': Array(0.1456841, dtype=float32), 'loss_cross_entropy': Array(0.13864274, dtype=float32)}
{'loss_inverse': Array(0.00050406, dtype=float32)}


  1%|          | 6260/1000000 [18:16<29:23:57,  9.39it/s]

{'loss': Array(0.14182435, dtype=float32), 'loss_cross_entropy': Array(0.13440947, dtype=float32)}
{'loss_inverse': Array(0.00050974, dtype=float32)}


  1%|          | 6270/1000000 [18:17<27:57:11,  9.87it/s]

{'loss': Array(0.12312416, dtype=float32), 'loss_cross_entropy': Array(0.11621092, dtype=float32)}
{'loss_inverse': Array(0.00046654, dtype=float32)}


  1%|          | 6280/1000000 [18:19<34:52:18,  7.92it/s]

{'loss': Array(0.13922058, dtype=float32), 'loss_cross_entropy': Array(0.13215575, dtype=float32)}
{'loss_inverse': Array(0.00040356, dtype=float32)}


  1%|          | 6290/1000000 [18:20<29:28:29,  9.36it/s]

{'loss': Array(0.13371901, dtype=float32), 'loss_cross_entropy': Array(0.12670667, dtype=float32)}
{'loss_inverse': Array(0.00049787, dtype=float32)}


  1%|          | 6300/1000000 [18:22<28:49:39,  9.58it/s]

{'loss': Array(0.14229052, dtype=float32), 'loss_cross_entropy': Array(0.13513824, dtype=float32)}
{'loss_inverse': Array(0.00053654, dtype=float32)}


  1%|          | 6310/1000000 [18:24<43:14:25,  6.38it/s]

{'loss': Array(0.14364187, dtype=float32), 'loss_cross_entropy': Array(0.1360654, dtype=float32)}
{'loss_inverse': Array(0.00028547, dtype=float32)}


  1%|          | 6320/1000000 [18:25<31:07:59,  8.87it/s]

{'loss': Array(0.14339218, dtype=float32), 'loss_cross_entropy': Array(0.1353227, dtype=float32)}
{'loss_inverse': Array(0.00019656, dtype=float32)}


  1%|          | 6330/1000000 [18:27<28:48:53,  9.58it/s]

{'loss': Array(0.16806835, dtype=float32), 'loss_cross_entropy': Array(0.15972926, dtype=float32)}
{'loss_inverse': Array(0.00030582, dtype=float32)}


  1%|          | 6340/1000000 [18:28<28:42:10,  9.62it/s]

{'loss': Array(0.12087791, dtype=float32), 'loss_cross_entropy': Array(0.11416636, dtype=float32)}
{'loss_inverse': Array(0.00020362, dtype=float32)}


  1%|          | 6350/1000000 [18:30<38:48:40,  7.11it/s]

{'loss': Array(0.12411793, dtype=float32), 'loss_cross_entropy': Array(0.11696649, dtype=float32)}
{'loss_inverse': Array(0.00018926, dtype=float32)}


  1%|          | 6360/1000000 [18:32<30:06:02,  9.17it/s]

{'loss': Array(0.13637297, dtype=float32), 'loss_cross_entropy': Array(0.12910135, dtype=float32)}
{'loss_inverse': Array(0.00023875, dtype=float32)}


  1%|          | 6370/1000000 [18:33<28:31:36,  9.68it/s]

{'loss': Array(0.13429163, dtype=float32), 'loss_cross_entropy': Array(0.12707916, dtype=float32)}
{'loss_inverse': Array(0.00017892, dtype=float32)}


  1%|          | 6380/1000000 [18:35<28:53:41,  9.55it/s]

{'loss': Array(0.14368445, dtype=float32), 'loss_cross_entropy': Array(0.13631979, dtype=float32)}
{'loss_inverse': Array(0.00030347, dtype=float32)}


  1%|          | 6390/1000000 [18:37<39:41:04,  6.95it/s]

{'loss': Array(0.12448356, dtype=float32), 'loss_cross_entropy': Array(0.11701442, dtype=float32)}
{'loss_inverse': Array(0.0001855, dtype=float32)}


  1%|          | 6400/1000000 [18:38<30:03:43,  9.18it/s]

{'loss': Array(0.14934047, dtype=float32), 'loss_cross_entropy': Array(0.1419565, dtype=float32)}
{'loss_inverse': Array(0.00018439, dtype=float32)}


  1%|          | 6410/1000000 [18:40<29:35:26,  9.33it/s]

{'loss': Array(0.14054321, dtype=float32), 'loss_cross_entropy': Array(0.13306797, dtype=float32)}
{'loss_inverse': Array(0.00022282, dtype=float32)}


  1%|          | 6420/1000000 [18:41<29:17:04,  9.42it/s]

{'loss': Array(0.14074147, dtype=float32), 'loss_cross_entropy': Array(0.13385785, dtype=float32)}
{'loss_inverse': Array(0.00018528, dtype=float32)}


  1%|          | 6430/1000000 [18:43<32:59:59,  8.36it/s]

{'loss': Array(0.14762738, dtype=float32), 'loss_cross_entropy': Array(0.1398146, dtype=float32)}
{'loss_inverse': Array(0.00021918, dtype=float32)}


  1%|          | 6440/1000000 [18:45<28:48:34,  9.58it/s]

{'loss': Array(0.15040132, dtype=float32), 'loss_cross_entropy': Array(0.14197792, dtype=float32)}
{'loss_inverse': Array(0.00022911, dtype=float32)}


  1%|          | 6450/1000000 [18:46<28:51:52,  9.56it/s]

{'loss': Array(0.14085941, dtype=float32), 'loss_cross_entropy': Array(0.13330431, dtype=float32)}
{'loss_inverse': Array(0.00013069, dtype=float32)}


  1%|          | 6460/1000000 [18:48<43:14:52,  6.38it/s]

{'loss': Array(0.12944107, dtype=float32), 'loss_cross_entropy': Array(0.12214911, dtype=float32)}
{'loss_inverse': Array(0.00024996, dtype=float32)}


  1%|          | 6470/1000000 [18:50<30:02:26,  9.19it/s]

{'loss': Array(0.15220757, dtype=float32), 'loss_cross_entropy': Array(0.14434902, dtype=float32)}
{'loss_inverse': Array(0.00016703, dtype=float32)}


  1%|          | 6480/1000000 [18:51<28:41:22,  9.62it/s]

{'loss': Array(0.13914756, dtype=float32), 'loss_cross_entropy': Array(0.13202906, dtype=float32)}
{'loss_inverse': Array(0.00020938, dtype=float32)}


  1%|          | 6490/1000000 [18:53<28:28:59,  9.69it/s]

{'loss': Array(0.1302936, dtype=float32), 'loss_cross_entropy': Array(0.12329169, dtype=float32)}
{'loss_inverse': Array(0.00034082, dtype=float32)}


  1%|          | 6500/1000000 [18:55<33:17:13,  8.29it/s]

{'loss': Array(0.12155759, dtype=float32), 'loss_cross_entropy': Array(0.11521073, dtype=float32)}
{'loss_inverse': Array(0.00030019, dtype=float32)}


  1%|          | 6510/1000000 [19:03<96:12:23,  2.87it/s] 

{'loss': Array(0.13124315, dtype=float32), 'loss_cross_entropy': Array(0.12356978, dtype=float32)}
{'loss_inverse': Array(0.00035228, dtype=float32)}


  1%|          | 6520/1000000 [19:04<39:26:54,  7.00it/s] 

{'loss': Array(0.15773426, dtype=float32), 'loss_cross_entropy': Array(0.14966968, dtype=float32)}
{'loss_inverse': Array(0.00053466, dtype=float32)}


  1%|          | 6530/1000000 [19:06<39:41:45,  6.95it/s]

{'loss': Array(0.14996436, dtype=float32), 'loss_cross_entropy': Array(0.14265119, dtype=float32)}
{'loss_inverse': Array(0.00035822, dtype=float32)}


  1%|          | 6540/1000000 [19:08<29:37:08,  9.32it/s]

{'loss': Array(0.13269593, dtype=float32), 'loss_cross_entropy': Array(0.12549138, dtype=float32)}
{'loss_inverse': Array(0.00025922, dtype=float32)}


  1%|          | 6550/1000000 [19:09<28:25:53,  9.71it/s]

{'loss': Array(0.15730649, dtype=float32), 'loss_cross_entropy': Array(0.14865522, dtype=float32)}
{'loss_inverse': Array(0.00025287, dtype=float32)}


  1%|          | 6560/1000000 [19:11<49:10:28,  5.61it/s]

{'loss': Array(0.14555667, dtype=float32), 'loss_cross_entropy': Array(0.13793166, dtype=float32)}
{'loss_inverse': Array(0.00016174, dtype=float32)}


  1%|          | 6570/1000000 [19:13<31:30:04,  8.76it/s]

{'loss': Array(0.1403728, dtype=float32), 'loss_cross_entropy': Array(0.13316976, dtype=float32)}
{'loss_inverse': Array(0.00020036, dtype=float32)}


  1%|          | 6580/1000000 [19:14<28:28:00,  9.69it/s]

{'loss': Array(0.13303386, dtype=float32), 'loss_cross_entropy': Array(0.1253865, dtype=float32)}
{'loss_inverse': Array(0.00026044, dtype=float32)}


  1%|          | 6590/1000000 [19:16<27:52:36,  9.90it/s]

{'loss': Array(0.14023693, dtype=float32), 'loss_cross_entropy': Array(0.1327623, dtype=float32)}
{'loss_inverse': Array(0.00027229, dtype=float32)}


  1%|          | 6600/1000000 [19:17<28:12:13,  9.78it/s]

{'loss': Array(0.12225275, dtype=float32), 'loss_cross_entropy': Array(0.11555976, dtype=float32)}
{'loss_inverse': Array(0.00024697, dtype=float32)}


  1%|          | 6610/1000000 [19:19<32:37:27,  8.46it/s]

{'loss': Array(0.11405756, dtype=float32), 'loss_cross_entropy': Array(0.1083283, dtype=float32)}
{'loss_inverse': Array(0.00012365, dtype=float32)}


  1%|          | 6620/1000000 [19:21<29:40:28,  9.30it/s]

{'loss': Array(0.14989237, dtype=float32), 'loss_cross_entropy': Array(0.14215095, dtype=float32)}
{'loss_inverse': Array(0.00014354, dtype=float32)}


  1%|          | 6630/1000000 [19:22<28:38:02,  9.64it/s]

{'loss': Array(0.14102508, dtype=float32), 'loss_cross_entropy': Array(0.13327599, dtype=float32)}
{'loss_inverse': Array(0.0002208, dtype=float32)}


  1%|          | 6640/1000000 [19:24<49:48:07,  5.54it/s]

{'loss': Array(0.14165668, dtype=float32), 'loss_cross_entropy': Array(0.1338429, dtype=float32)}
{'loss_inverse': Array(0.00029029, dtype=float32)}


  1%|          | 6649/1000000 [19:25<32:59:48,  8.36it/s]

{'loss': Array(0.13729748, dtype=float32), 'loss_cross_entropy': Array(0.13044582, dtype=float32)}
{'loss_inverse': Array(0.00041495, dtype=float32)}


  1%|          | 6659/1000000 [19:27<29:16:10,  9.43it/s]

{'loss': Array(0.15269202, dtype=float32), 'loss_cross_entropy': Array(0.14473307, dtype=float32)}
{'loss_inverse': Array(0.00070767, dtype=float32)}


  1%|          | 6669/1000000 [19:28<28:05:09,  9.82it/s]

{'loss': Array(0.14425395, dtype=float32), 'loss_cross_entropy': Array(0.13663125, dtype=float32)}
{'loss_inverse': Array(0.00138871, dtype=float32)}


  1%|          | 6679/1000000 [19:30<39:05:34,  7.06it/s]

{'loss': Array(0.15208791, dtype=float32), 'loss_cross_entropy': Array(0.1447228, dtype=float32)}
{'loss_inverse': Array(0.00117304, dtype=float32)}


  1%|          | 6689/1000000 [19:32<29:02:32,  9.50it/s]

{'loss': Array(0.13344395, dtype=float32), 'loss_cross_entropy': Array(0.12653518, dtype=float32)}
{'loss_inverse': Array(0.00084607, dtype=float32)}


  1%|          | 6699/1000000 [19:33<27:59:53,  9.85it/s]

{'loss': Array(0.15866469, dtype=float32), 'loss_cross_entropy': Array(0.15089965, dtype=float32)}
{'loss_inverse': Array(0.00095843, dtype=float32)}


  1%|          | 6709/1000000 [19:35<51:47:04,  5.33it/s]

{'loss': Array(0.16030154, dtype=float32), 'loss_cross_entropy': Array(0.15236165, dtype=float32)}
{'loss_inverse': Array(0.00081742, dtype=float32)}


  1%|          | 6719/1000000 [19:37<32:16:06,  8.55it/s]

{'loss': Array(0.1351457, dtype=float32), 'loss_cross_entropy': Array(0.127844, dtype=float32)}
{'loss_inverse': Array(0.00101184, dtype=float32)}


  1%|          | 6729/1000000 [19:38<28:57:37,  9.53it/s]

{'loss': Array(0.12750372, dtype=float32), 'loss_cross_entropy': Array(0.12098737, dtype=float32)}
{'loss_inverse': Array(0.00056007, dtype=float32)}


  1%|          | 6739/1000000 [19:40<30:57:05,  8.91it/s]

{'loss': Array(0.13505493, dtype=float32), 'loss_cross_entropy': Array(0.12803067, dtype=float32)}
{'loss_inverse': Array(0.00069077, dtype=float32)}


  1%|          | 6750/1000000 [19:42<30:51:46,  8.94it/s]

{'loss': Array(0.1460816, dtype=float32), 'loss_cross_entropy': Array(0.13837309, dtype=float32)}
{'loss_inverse': Array(0.00052156, dtype=float32)}


  1%|          | 6760/1000000 [19:43<28:47:06,  9.58it/s]

{'loss': Array(0.13626014, dtype=float32), 'loss_cross_entropy': Array(0.12876986, dtype=float32)}
{'loss_inverse': Array(0.00057106, dtype=float32)}


  1%|          | 6770/1000000 [19:45<30:51:23,  8.94it/s]

{'loss': Array(0.15066718, dtype=float32), 'loss_cross_entropy': Array(0.14287381, dtype=float32)}
{'loss_inverse': Array(0.0006502, dtype=float32)}


  1%|          | 6780/1000000 [19:47<36:38:59,  7.53it/s]

{'loss': Array(0.13356178, dtype=float32), 'loss_cross_entropy': Array(0.12603834, dtype=float32)}
{'loss_inverse': Array(0.0004647, dtype=float32)}


  1%|          | 6790/1000000 [19:48<29:17:42,  9.42it/s]

{'loss': Array(0.14401096, dtype=float32), 'loss_cross_entropy': Array(0.13647918, dtype=float32)}
{'loss_inverse': Array(0.00056974, dtype=float32)}


  1%|          | 6800/1000000 [19:50<30:06:02,  9.17it/s]

{'loss': Array(0.12545647, dtype=float32), 'loss_cross_entropy': Array(0.11872047, dtype=float32)}
{'loss_inverse': Array(0.00042743, dtype=float32)}


  1%|          | 6810/1000000 [19:51<28:12:13,  9.78it/s]

{'loss': Array(0.1441326, dtype=float32), 'loss_cross_entropy': Array(0.13744757, dtype=float32)}
{'loss_inverse': Array(0.00032657, dtype=float32)}


  1%|          | 6820/1000000 [19:53<38:38:02,  7.14it/s]

{'loss': Array(0.14038242, dtype=float32), 'loss_cross_entropy': Array(0.13349117, dtype=float32)}
{'loss_inverse': Array(0.00030673, dtype=float32)}


  1%|          | 6830/1000000 [19:55<31:48:34,  8.67it/s]

{'loss': Array(0.1452772, dtype=float32), 'loss_cross_entropy': Array(0.13761465, dtype=float32)}
{'loss_inverse': Array(0.00029117, dtype=float32)}


  1%|          | 6840/1000000 [19:56<28:47:42,  9.58it/s]

{'loss': Array(0.13064054, dtype=float32), 'loss_cross_entropy': Array(0.12377826, dtype=float32)}
{'loss_inverse': Array(0.00038102, dtype=float32)}


  1%|          | 6850/1000000 [19:58<28:28:08,  9.69it/s]

{'loss': Array(0.14122832, dtype=float32), 'loss_cross_entropy': Array(0.13359319, dtype=float32)}
{'loss_inverse': Array(0.00019396, dtype=float32)}


  1%|          | 6860/1000000 [20:00<35:26:28,  7.78it/s]

{'loss': Array(0.13020204, dtype=float32), 'loss_cross_entropy': Array(0.12293839, dtype=float32)}
{'loss_inverse': Array(0.00018661, dtype=float32)}


  1%|          | 6870/1000000 [20:01<29:44:44,  9.27it/s]

{'loss': Array(0.14348856, dtype=float32), 'loss_cross_entropy': Array(0.13633196, dtype=float32)}
{'loss_inverse': Array(0.00023061, dtype=float32)}


  1%|          | 6880/1000000 [20:03<28:00:11,  9.85it/s]

{'loss': Array(0.15774475, dtype=float32), 'loss_cross_entropy': Array(0.14941178, dtype=float32)}
{'loss_inverse': Array(0.00072885, dtype=float32)}


  1%|          | 6890/1000000 [20:05<44:57:43,  6.14it/s]

{'loss': Array(0.14687972, dtype=float32), 'loss_cross_entropy': Array(0.13917144, dtype=float32)}
{'loss_inverse': Array(0.00137582, dtype=float32)}


  1%|          | 6900/1000000 [20:06<31:27:09,  8.77it/s]

{'loss': Array(0.13949949, dtype=float32), 'loss_cross_entropy': Array(0.13192005, dtype=float32)}
{'loss_inverse': Array(0.00071312, dtype=float32)}


  1%|          | 6910/1000000 [20:08<28:04:35,  9.83it/s]

{'loss': Array(0.13236575, dtype=float32), 'loss_cross_entropy': Array(0.12546755, dtype=float32)}
{'loss_inverse': Array(0.00035662, dtype=float32)}


  1%|          | 6920/1000000 [20:09<28:33:10,  9.66it/s]

{'loss': Array(0.12229013, dtype=float32), 'loss_cross_entropy': Array(0.11590304, dtype=float32)}
{'loss_inverse': Array(0.00026696, dtype=float32)}


  1%|          | 6930/1000000 [20:11<33:32:13,  8.23it/s]

{'loss': Array(0.13318375, dtype=float32), 'loss_cross_entropy': Array(0.1263733, dtype=float32)}
{'loss_inverse': Array(0.00028103, dtype=float32)}


  1%|          | 6940/1000000 [20:13<29:42:23,  9.29it/s]

{'loss': Array(0.1399947, dtype=float32), 'loss_cross_entropy': Array(0.13269429, dtype=float32)}
{'loss_inverse': Array(0.000158, dtype=float32)}


  1%|          | 6950/1000000 [20:14<28:31:36,  9.67it/s]

{'loss': Array(0.15194875, dtype=float32), 'loss_cross_entropy': Array(0.14372289, dtype=float32)}
{'loss_inverse': Array(0.0003771, dtype=float32)}


  1%|          | 6960/1000000 [20:16<38:56:22,  7.08it/s]

{'loss': Array(0.14993794, dtype=float32), 'loss_cross_entropy': Array(0.14219724, dtype=float32)}
{'loss_inverse': Array(0.00024441, dtype=float32)}


  1%|          | 6970/1000000 [20:18<29:43:00,  9.28it/s]

{'loss': Array(0.14845389, dtype=float32), 'loss_cross_entropy': Array(0.14085928, dtype=float32)}
{'loss_inverse': Array(0.00015794, dtype=float32)}


  1%|          | 6980/1000000 [20:19<28:15:22,  9.76it/s]

{'loss': Array(0.14190792, dtype=float32), 'loss_cross_entropy': Array(0.1343651, dtype=float32)}
{'loss_inverse': Array(0.00022675, dtype=float32)}


  1%|          | 6990/1000000 [20:21<28:36:24,  9.64it/s]

{'loss': Array(0.12294278, dtype=float32), 'loss_cross_entropy': Array(0.11601267, dtype=float32)}
{'loss_inverse': Array(0.00010977, dtype=float32)}


  1%|          | 7000/1000000 [20:23<35:22:14,  7.80it/s]

{'loss': Array(0.14016384, dtype=float32), 'loss_cross_entropy': Array(0.13234606, dtype=float32)}
{'loss_inverse': Array(0.00039489, dtype=float32)}


  1%|          | 7010/1000000 [20:31<95:47:42,  2.88it/s] 

{'loss': Array(0.16515215, dtype=float32), 'loss_cross_entropy': Array(0.1564608, dtype=float32)}
{'loss_inverse': Array(0.0003905, dtype=float32)}


  1%|          | 7020/1000000 [20:32<39:10:43,  7.04it/s] 

{'loss': Array(0.12831216, dtype=float32), 'loss_cross_entropy': Array(0.12139302, dtype=float32)}
{'loss_inverse': Array(0.00028139, dtype=float32)}


  1%|          | 7030/1000000 [20:34<37:20:16,  7.39it/s]

{'loss': Array(0.15239818, dtype=float32), 'loss_cross_entropy': Array(0.14494923, dtype=float32)}
{'loss_inverse': Array(0.0002072, dtype=float32)}


  1%|          | 7040/1000000 [20:36<30:13:43,  9.12it/s]

{'loss': Array(0.14328712, dtype=float32), 'loss_cross_entropy': Array(0.13519643, dtype=float32)}
{'loss_inverse': Array(0.00014973, dtype=float32)}


  1%|          | 7050/1000000 [20:37<27:36:19,  9.99it/s]

{'loss': Array(0.14331661, dtype=float32), 'loss_cross_entropy': Array(0.13583697, dtype=float32)}
{'loss_inverse': Array(0.00032942, dtype=float32)}


  1%|          | 7060/1000000 [20:39<27:53:31,  9.89it/s]

{'loss': Array(0.13582055, dtype=float32), 'loss_cross_entropy': Array(0.12840664, dtype=float32)}
{'loss_inverse': Array(0.00018711, dtype=float32)}


  1%|          | 7070/1000000 [20:41<35:06:36,  7.86it/s]

{'loss': Array(0.13986996, dtype=float32), 'loss_cross_entropy': Array(0.13235871, dtype=float32)}
{'loss_inverse': Array(0.0001964, dtype=float32)}


  1%|          | 7080/1000000 [20:42<28:37:23,  9.64it/s]

{'loss': Array(0.13283825, dtype=float32), 'loss_cross_entropy': Array(0.12526853, dtype=float32)}
{'loss_inverse': Array(0.00021855, dtype=float32)}


  1%|          | 7090/1000000 [20:44<28:01:23,  9.84it/s]

{'loss': Array(0.13773933, dtype=float32), 'loss_cross_entropy': Array(0.13057472, dtype=float32)}
{'loss_inverse': Array(0.00015233, dtype=float32)}


  1%|          | 7100/1000000 [20:46<50:04:26,  5.51it/s]

{'loss': Array(0.14319988, dtype=float32), 'loss_cross_entropy': Array(0.13585706, dtype=float32)}
{'loss_inverse': Array(0.00012767, dtype=float32)}


  1%|          | 7109/1000000 [20:47<31:17:02,  8.82it/s]

{'loss': Array(0.13364844, dtype=float32), 'loss_cross_entropy': Array(0.12641549, dtype=float32)}
{'loss_inverse': Array(0.00015993, dtype=float32)}


  1%|          | 7119/1000000 [20:48<27:53:16,  9.89it/s]

{'loss': Array(0.12910195, dtype=float32), 'loss_cross_entropy': Array(0.1221021, dtype=float32)}
{'loss_inverse': Array(0.00014149, dtype=float32)}


  1%|          | 7129/1000000 [20:50<28:39:14,  9.63it/s]

{'loss': Array(0.15123793, dtype=float32), 'loss_cross_entropy': Array(0.14352143, dtype=float32)}
{'loss_inverse': Array(0.00022216, dtype=float32)}


  1%|          | 7139/1000000 [20:52<43:04:55,  6.40it/s]

{'loss': Array(0.13335651, dtype=float32), 'loss_cross_entropy': Array(0.12578602, dtype=float32)}
{'loss_inverse': Array(0.00014051, dtype=float32)}


  1%|          | 7149/1000000 [20:53<30:00:49,  9.19it/s]

{'loss': Array(0.1114156, dtype=float32), 'loss_cross_entropy': Array(0.10556855, dtype=float32)}
{'loss_inverse': Array(0.00020185, dtype=float32)}


  1%|          | 7159/1000000 [20:55<28:07:16,  9.81it/s]

{'loss': Array(0.14047506, dtype=float32), 'loss_cross_entropy': Array(0.1330068, dtype=float32)}
{'loss_inverse': Array(9.795721e-05, dtype=float32)}


  1%|          | 7169/1000000 [20:56<28:49:47,  9.57it/s]

{'loss': Array(0.14423494, dtype=float32), 'loss_cross_entropy': Array(0.1368759, dtype=float32)}
{'loss_inverse': Array(0.00014022, dtype=float32)}


  1%|          | 7179/1000000 [20:58<33:23:31,  8.26it/s]

{'loss': Array(0.12633137, dtype=float32), 'loss_cross_entropy': Array(0.11961666, dtype=float32)}
{'loss_inverse': Array(0.00014857, dtype=float32)}


  1%|          | 7189/1000000 [21:00<28:55:20,  9.54it/s]

{'loss': Array(0.11208317, dtype=float32), 'loss_cross_entropy': Array(0.10631203, dtype=float32)}
{'loss_inverse': Array(8.187228e-05, dtype=float32)}


  1%|          | 7199/1000000 [21:01<29:44:25,  9.27it/s]

{'loss': Array(0.1250747, dtype=float32), 'loss_cross_entropy': Array(0.11805616, dtype=float32)}
{'loss_inverse': Array(9.350471e-05, dtype=float32)}


  1%|          | 7209/1000000 [21:03<38:38:04,  7.14it/s]

{'loss': Array(0.12354636, dtype=float32), 'loss_cross_entropy': Array(0.11679853, dtype=float32)}
{'loss_inverse': Array(0.00012889, dtype=float32)}


  1%|          | 7219/1000000 [21:05<29:28:45,  9.35it/s]

{'loss': Array(0.12583369, dtype=float32), 'loss_cross_entropy': Array(0.11920192, dtype=float32)}
{'loss_inverse': Array(0.00011995, dtype=float32)}


  1%|          | 7229/1000000 [21:06<28:40:25,  9.62it/s]

{'loss': Array(0.1254506, dtype=float32), 'loss_cross_entropy': Array(0.11884066, dtype=float32)}
{'loss_inverse': Array(0.00040484, dtype=float32)}


  1%|          | 7239/1000000 [21:08<27:27:11, 10.05it/s]

{'loss': Array(0.14521997, dtype=float32), 'loss_cross_entropy': Array(0.13701645, dtype=float32)}
{'loss_inverse': Array(0.00059249, dtype=float32)}


  1%|          | 7249/1000000 [21:10<35:09:20,  7.84it/s]

{'loss': Array(0.13396576, dtype=float32), 'loss_cross_entropy': Array(0.1271952, dtype=float32)}
{'loss_inverse': Array(0.00048667, dtype=float32)}


  1%|          | 7259/1000000 [21:11<29:10:21,  9.45it/s]

{'loss': Array(0.1283702, dtype=float32), 'loss_cross_entropy': Array(0.12132871, dtype=float32)}
{'loss_inverse': Array(0.00023054, dtype=float32)}


  1%|          | 7269/1000000 [21:12<27:34:54, 10.00it/s]

{'loss': Array(0.13984641, dtype=float32), 'loss_cross_entropy': Array(0.13222562, dtype=float32)}
{'loss_inverse': Array(0.0002884, dtype=float32)}


  1%|          | 7279/1000000 [21:14<27:22:25, 10.07it/s]

{'loss': Array(0.13845062, dtype=float32), 'loss_cross_entropy': Array(0.1312699, dtype=float32)}
{'loss_inverse': Array(0.00022949, dtype=float32)}


  1%|          | 7289/1000000 [21:16<38:08:56,  7.23it/s]

{'loss': Array(0.13384756, dtype=float32), 'loss_cross_entropy': Array(0.12683043, dtype=float32)}
{'loss_inverse': Array(0.00029422, dtype=float32)}


  1%|          | 7299/1000000 [21:17<29:56:00,  9.21it/s]

{'loss': Array(0.12056043, dtype=float32), 'loss_cross_entropy': Array(0.11387914, dtype=float32)}
{'loss_inverse': Array(0.00012667, dtype=float32)}


  1%|          | 7309/1000000 [21:19<28:02:34,  9.83it/s]

{'loss': Array(0.12807845, dtype=float32), 'loss_cross_entropy': Array(0.12145305, dtype=float32)}
{'loss_inverse': Array(0.00038142, dtype=float32)}


  1%|          | 7319/1000000 [21:20<28:07:47,  9.80it/s]

{'loss': Array(0.1398845, dtype=float32), 'loss_cross_entropy': Array(0.13262282, dtype=float32)}
{'loss_inverse': Array(0.00042478, dtype=float32)}


  1%|          | 7329/1000000 [21:22<35:00:06,  7.88it/s]

{'loss': Array(0.13484089, dtype=float32), 'loss_cross_entropy': Array(0.12730826, dtype=float32)}
{'loss_inverse': Array(0.00021902, dtype=float32)}


  1%|          | 7339/1000000 [21:24<29:29:20,  9.35it/s]

{'loss': Array(0.1300993, dtype=float32), 'loss_cross_entropy': Array(0.12290007, dtype=float32)}
{'loss_inverse': Array(0.00053445, dtype=float32)}


  1%|          | 7349/1000000 [21:25<29:07:05,  9.47it/s]

{'loss': Array(0.13247561, dtype=float32), 'loss_cross_entropy': Array(0.12533501, dtype=float32)}
{'loss_inverse': Array(0.0005618, dtype=float32)}


  1%|          | 7359/1000000 [21:27<48:20:05,  5.70it/s]

{'loss': Array(0.14163229, dtype=float32), 'loss_cross_entropy': Array(0.13360585, dtype=float32)}
{'loss_inverse': Array(0.00068415, dtype=float32)}


  1%|          | 7369/1000000 [21:29<31:21:20,  8.79it/s]

{'loss': Array(0.12183374, dtype=float32), 'loss_cross_entropy': Array(0.11529165, dtype=float32)}
{'loss_inverse': Array(0.00074198, dtype=float32)}


  1%|          | 7379/1000000 [21:30<29:03:39,  9.49it/s]

{'loss': Array(0.12644167, dtype=float32), 'loss_cross_entropy': Array(0.11928158, dtype=float32)}
{'loss_inverse': Array(0.00071312, dtype=float32)}


  1%|          | 7389/1000000 [21:32<28:06:46,  9.81it/s]

{'loss': Array(0.1262203, dtype=float32), 'loss_cross_entropy': Array(0.11923188, dtype=float32)}
{'loss_inverse': Array(0.00046258, dtype=float32)}


  1%|          | 7399/1000000 [21:33<35:03:19,  7.87it/s]

{'loss': Array(0.12570336, dtype=float32), 'loss_cross_entropy': Array(0.1191465, dtype=float32)}
{'loss_inverse': Array(0.00043986, dtype=float32)}


  1%|          | 7409/1000000 [21:35<30:36:41,  9.01it/s]

{'loss': Array(0.12944196, dtype=float32), 'loss_cross_entropy': Array(0.12238733, dtype=float32)}
{'loss_inverse': Array(0.00095998, dtype=float32)}


  1%|          | 7419/1000000 [21:36<28:22:18,  9.72it/s]

{'loss': Array(0.1466333, dtype=float32), 'loss_cross_entropy': Array(0.13901249, dtype=float32)}
{'loss_inverse': Array(0.00089348, dtype=float32)}


  1%|          | 7429/1000000 [21:38<37:50:22,  7.29it/s]

{'loss': Array(0.14888251, dtype=float32), 'loss_cross_entropy': Array(0.14127226, dtype=float32)}
{'loss_inverse': Array(0.00057165, dtype=float32)}


  1%|          | 7439/1000000 [21:40<32:08:41,  8.58it/s]

{'loss': Array(0.11769795, dtype=float32), 'loss_cross_entropy': Array(0.11108096, dtype=float32)}
{'loss_inverse': Array(0.00060865, dtype=float32)}


  1%|          | 7449/1000000 [21:41<29:03:25,  9.49it/s]

{'loss': Array(0.13578483, dtype=float32), 'loss_cross_entropy': Array(0.12894537, dtype=float32)}
{'loss_inverse': Array(0.00053111, dtype=float32)}


  1%|          | 7459/1000000 [21:43<48:51:44,  5.64it/s]

{'loss': Array(0.14317822, dtype=float32), 'loss_cross_entropy': Array(0.1353337, dtype=float32)}
{'loss_inverse': Array(0.0005158, dtype=float32)}


  1%|          | 7469/1000000 [21:45<32:55:16,  8.37it/s]

{'loss': Array(0.14189728, dtype=float32), 'loss_cross_entropy': Array(0.1343153, dtype=float32)}
{'loss_inverse': Array(0.00060586, dtype=float32)}


  1%|          | 7479/1000000 [21:46<28:44:29,  9.59it/s]

{'loss': Array(0.1474414, dtype=float32), 'loss_cross_entropy': Array(0.13946381, dtype=float32)}
{'loss_inverse': Array(0.00023505, dtype=float32)}


  1%|          | 7489/1000000 [21:48<27:41:17,  9.96it/s]

{'loss': Array(0.13656305, dtype=float32), 'loss_cross_entropy': Array(0.12915914, dtype=float32)}
{'loss_inverse': Array(0.00020849, dtype=float32)}


  1%|          | 7499/1000000 [21:50<47:44:00,  5.78it/s]

{'loss': Array(0.14051609, dtype=float32), 'loss_cross_entropy': Array(0.1334678, dtype=float32)}
{'loss_inverse': Array(0.00038658, dtype=float32)}


  1%|          | 7509/1000000 [21:58<98:14:33,  2.81it/s] 

{'loss': Array(0.13797854, dtype=float32), 'loss_cross_entropy': Array(0.13083364, dtype=float32)}
{'loss_inverse': Array(0.00033583, dtype=float32)}


  1%|          | 7519/1000000 [21:59<39:34:34,  6.97it/s] 

{'loss': Array(0.13225661, dtype=float32), 'loss_cross_entropy': Array(0.12487879, dtype=float32)}
{'loss_inverse': Array(0.00022793, dtype=float32)}


  1%|          | 7529/1000000 [22:01<29:38:30,  9.30it/s]

{'loss': Array(0.13142464, dtype=float32), 'loss_cross_entropy': Array(0.12516017, dtype=float32)}
{'loss_inverse': Array(0.00038424, dtype=float32)}


  1%|          | 7539/1000000 [22:03<48:15:20,  5.71it/s]

{'loss': Array(0.12616228, dtype=float32), 'loss_cross_entropy': Array(0.11922711, dtype=float32)}
{'loss_inverse': Array(0.00015082, dtype=float32)}


  1%|          | 7549/1000000 [22:04<31:54:18,  8.64it/s]

{'loss': Array(0.11884979, dtype=float32), 'loss_cross_entropy': Array(0.11199655, dtype=float32)}
{'loss_inverse': Array(0.00029464, dtype=float32)}


  1%|          | 7559/1000000 [22:06<30:20:58,  9.08it/s]

{'loss': Array(0.12264591, dtype=float32), 'loss_cross_entropy': Array(0.11620543, dtype=float32)}
{'loss_inverse': Array(0.00023775, dtype=float32)}


  1%|          | 7569/1000000 [22:07<28:05:40,  9.81it/s]

{'loss': Array(0.12621498, dtype=float32), 'loss_cross_entropy': Array(0.1194033, dtype=float32)}
{'loss_inverse': Array(0.00021082, dtype=float32)}


  1%|          | 7579/1000000 [22:09<34:11:22,  8.06it/s]

{'loss': Array(0.14807492, dtype=float32), 'loss_cross_entropy': Array(0.14032678, dtype=float32)}
{'loss_inverse': Array(0.00040615, dtype=float32)}


  1%|          | 7589/1000000 [22:11<29:02:07,  9.49it/s]

{'loss': Array(0.12929098, dtype=float32), 'loss_cross_entropy': Array(0.12201262, dtype=float32)}
{'loss_inverse': Array(0.00016552, dtype=float32)}


  1%|          | 7599/1000000 [22:12<27:54:57,  9.87it/s]

{'loss': Array(0.14154026, dtype=float32), 'loss_cross_entropy': Array(0.13388796, dtype=float32)}
{'loss_inverse': Array(0.00026768, dtype=float32)}


  1%|          | 7609/1000000 [22:14<42:07:30,  6.54it/s]

{'loss': Array(0.13491666, dtype=float32), 'loss_cross_entropy': Array(0.12741584, dtype=float32)}
{'loss_inverse': Array(0.00022739, dtype=float32)}


  1%|          | 7619/1000000 [22:16<30:21:37,  9.08it/s]

{'loss': Array(0.11942679, dtype=float32), 'loss_cross_entropy': Array(0.11331494, dtype=float32)}
{'loss_inverse': Array(0.00014476, dtype=float32)}


  1%|          | 7629/1000000 [22:17<28:10:44,  9.78it/s]

{'loss': Array(0.12550505, dtype=float32), 'loss_cross_entropy': Array(0.11890008, dtype=float32)}
{'loss_inverse': Array(0.00023708, dtype=float32)}


  1%|          | 7639/1000000 [22:19<48:41:20,  5.66it/s]

{'loss': Array(0.13629337, dtype=float32), 'loss_cross_entropy': Array(0.12942408, dtype=float32)}
{'loss_inverse': Array(0.00013239, dtype=float32)}


  1%|          | 7649/1000000 [22:20<32:00:34,  8.61it/s]

{'loss': Array(0.15208633, dtype=float32), 'loss_cross_entropy': Array(0.1446837, dtype=float32)}
{'loss_inverse': Array(0.00043919, dtype=float32)}


  1%|          | 7659/1000000 [22:22<28:38:59,  9.62it/s]

{'loss': Array(0.11303984, dtype=float32), 'loss_cross_entropy': Array(0.10690346, dtype=float32)}
{'loss_inverse': Array(0.00067235, dtype=float32)}


  1%|          | 7669/1000000 [22:23<27:47:04,  9.92it/s]

{'loss': Array(0.14640008, dtype=float32), 'loss_cross_entropy': Array(0.13904195, dtype=float32)}
{'loss_inverse': Array(0.00061989, dtype=float32)}


  1%|          | 7679/1000000 [22:25<29:43:24,  9.27it/s]

{'loss': Array(0.13343072, dtype=float32), 'loss_cross_entropy': Array(0.1267011, dtype=float32)}
{'loss_inverse': Array(0.00100524, dtype=float32)}


  1%|          | 7689/1000000 [22:27<35:02:33,  7.87it/s]

{'loss': Array(0.1401445, dtype=float32), 'loss_cross_entropy': Array(0.13250108, dtype=float32)}
{'loss_inverse': Array(0.00124024, dtype=float32)}


  1%|          | 7699/1000000 [22:28<29:00:07,  9.50it/s]

{'loss': Array(0.15570997, dtype=float32), 'loss_cross_entropy': Array(0.14745212, dtype=float32)}
{'loss_inverse': Array(0.00155756, dtype=float32)}


  1%|          | 7709/1000000 [22:30<29:23:22,  9.38it/s]

{'loss': Array(0.13850947, dtype=float32), 'loss_cross_entropy': Array(0.13169365, dtype=float32)}
{'loss_inverse': Array(0.00186129, dtype=float32)}


  1%|          | 7719/1000000 [22:31<28:25:39,  9.70it/s]

{'loss': Array(0.12038648, dtype=float32), 'loss_cross_entropy': Array(0.11412863, dtype=float32)}
{'loss_inverse': Array(0.00160658, dtype=float32)}


  1%|          | 7729/1000000 [22:33<35:16:34,  7.81it/s]

{'loss': Array(0.14420396, dtype=float32), 'loss_cross_entropy': Array(0.13632278, dtype=float32)}
{'loss_inverse': Array(0.00078549, dtype=float32)}


  1%|          | 7739/1000000 [22:35<28:58:06,  9.51it/s]

{'loss': Array(0.14369802, dtype=float32), 'loss_cross_entropy': Array(0.13660513, dtype=float32)}
{'loss_inverse': Array(0.00067514, dtype=float32)}


  1%|          | 7749/1000000 [22:36<27:53:10,  9.88it/s]

{'loss': Array(0.14278981, dtype=float32), 'loss_cross_entropy': Array(0.13541186, dtype=float32)}
{'loss_inverse': Array(0.0003314, dtype=float32)}


  1%|          | 7759/1000000 [22:38<48:03:23,  5.74it/s]

{'loss': Array(0.14677754, dtype=float32), 'loss_cross_entropy': Array(0.13890456, dtype=float32)}
{'loss_inverse': Array(0.00040002, dtype=float32)}


  1%|          | 7769/1000000 [22:40<31:23:46,  8.78it/s]

{'loss': Array(0.13927798, dtype=float32), 'loss_cross_entropy': Array(0.13170907, dtype=float32)}
{'loss_inverse': Array(0.00037439, dtype=float32)}


  1%|          | 7779/1000000 [22:41<28:53:39,  9.54it/s]

{'loss': Array(0.12544279, dtype=float32), 'loss_cross_entropy': Array(0.11859661, dtype=float32)}
{'loss_inverse': Array(0.00012908, dtype=float32)}


  1%|          | 7789/1000000 [22:43<28:31:13,  9.66it/s]

{'loss': Array(0.1395043, dtype=float32), 'loss_cross_entropy': Array(0.13223974, dtype=float32)}
{'loss_inverse': Array(0.00025791, dtype=float32)}


  1%|          | 7799/1000000 [22:44<35:15:16,  7.82it/s]

{'loss': Array(0.12908392, dtype=float32), 'loss_cross_entropy': Array(0.12238995, dtype=float32)}
{'loss_inverse': Array(0.00022395, dtype=float32)}


  1%|          | 7809/1000000 [22:46<29:47:24,  9.25it/s]

{'loss': Array(0.12118369, dtype=float32), 'loss_cross_entropy': Array(0.11478342, dtype=float32)}
{'loss_inverse': Array(0.00019013, dtype=float32)}


  1%|          | 7819/1000000 [22:47<27:36:09,  9.98it/s]

{'loss': Array(0.12054376, dtype=float32), 'loss_cross_entropy': Array(0.11421062, dtype=float32)}
{'loss_inverse': Array(0.00031685, dtype=float32)}


  1%|          | 7829/1000000 [22:49<41:30:31,  6.64it/s]

{'loss': Array(0.14209409, dtype=float32), 'loss_cross_entropy': Array(0.1344742, dtype=float32)}
{'loss_inverse': Array(0.00029245, dtype=float32)}


  1%|          | 7839/1000000 [22:51<30:37:19,  9.00it/s]

{'loss': Array(0.1240503, dtype=float32), 'loss_cross_entropy': Array(0.11700326, dtype=float32)}
{'loss_inverse': Array(0.00031606, dtype=float32)}


  1%|          | 7849/1000000 [22:52<28:51:18,  9.55it/s]

{'loss': Array(0.12696543, dtype=float32), 'loss_cross_entropy': Array(0.12019525, dtype=float32)}
{'loss_inverse': Array(0.00034216, dtype=float32)}


  1%|          | 7859/1000000 [22:54<27:44:46,  9.93it/s]

{'loss': Array(0.11981811, dtype=float32), 'loss_cross_entropy': Array(0.11353835, dtype=float32)}
{'loss_inverse': Array(0.00045702, dtype=float32)}


  1%|          | 7869/1000000 [22:56<39:16:09,  7.02it/s]

{'loss': Array(0.1282324, dtype=float32), 'loss_cross_entropy': Array(0.12153566, dtype=float32)}
{'loss_inverse': Array(0.00018361, dtype=float32)}


  1%|          | 7879/1000000 [22:57<30:17:49,  9.10it/s]

{'loss': Array(0.12435192, dtype=float32), 'loss_cross_entropy': Array(0.11742671, dtype=float32)}
{'loss_inverse': Array(0.00071473, dtype=float32)}


  1%|          | 7889/1000000 [22:59<28:00:13,  9.84it/s]

{'loss': Array(0.13211565, dtype=float32), 'loss_cross_entropy': Array(0.12533039, dtype=float32)}
{'loss_inverse': Array(0.00109147, dtype=float32)}


  1%|          | 7899/1000000 [23:00<28:26:36,  9.69it/s]

{'loss': Array(0.15197626, dtype=float32), 'loss_cross_entropy': Array(0.1443625, dtype=float32)}
{'loss_inverse': Array(0.00110323, dtype=float32)}


  1%|          | 7909/1000000 [23:02<38:22:10,  7.18it/s]

{'loss': Array(0.12904415, dtype=float32), 'loss_cross_entropy': Array(0.1223492, dtype=float32)}
{'loss_inverse': Array(0.00090588, dtype=float32)}


  1%|          | 7919/1000000 [23:04<29:16:26,  9.41it/s]

{'loss': Array(0.15107736, dtype=float32), 'loss_cross_entropy': Array(0.14365546, dtype=float32)}
{'loss_inverse': Array(0.00063763, dtype=float32)}


  1%|          | 7929/1000000 [23:05<29:44:30,  9.27it/s]

{'loss': Array(0.1457589, dtype=float32), 'loss_cross_entropy': Array(0.1382288, dtype=float32)}
{'loss_inverse': Array(0.00050057, dtype=float32)}


  1%|          | 7939/1000000 [23:07<28:26:25,  9.69it/s]

{'loss': Array(0.13623503, dtype=float32), 'loss_cross_entropy': Array(0.12901026, dtype=float32)}
{'loss_inverse': Array(0.00035588, dtype=float32)}


  1%|          | 7949/1000000 [23:08<35:06:35,  7.85it/s]

{'loss': Array(0.13510394, dtype=float32), 'loss_cross_entropy': Array(0.1281582, dtype=float32)}
{'loss_inverse': Array(0.00040175, dtype=float32)}


  1%|          | 7959/1000000 [23:10<31:25:19,  8.77it/s]

{'loss': Array(0.139634, dtype=float32), 'loss_cross_entropy': Array(0.13213173, dtype=float32)}
{'loss_inverse': Array(0.00027325, dtype=float32)}


  1%|          | 7969/1000000 [23:11<28:35:23,  9.64it/s]

{'loss': Array(0.12519504, dtype=float32), 'loss_cross_entropy': Array(0.11886014, dtype=float32)}
{'loss_inverse': Array(0.00027133, dtype=float32)}


  1%|          | 7979/1000000 [23:13<42:04:08,  6.55it/s]

{'loss': Array(0.13775472, dtype=float32), 'loss_cross_entropy': Array(0.13053967, dtype=float32)}
{'loss_inverse': Array(0.00028682, dtype=float32)}


  1%|          | 7989/1000000 [23:15<31:43:28,  8.69it/s]

{'loss': Array(0.14749148, dtype=float32), 'loss_cross_entropy': Array(0.13959779, dtype=float32)}
{'loss_inverse': Array(0.00017441, dtype=float32)}


  1%|          | 7999/1000000 [23:16<28:54:13,  9.53it/s]

{'loss': Array(0.14372484, dtype=float32), 'loss_cross_entropy': Array(0.13638835, dtype=float32)}
{'loss_inverse': Array(0.00030308, dtype=float32)}


  1%|          | 8009/1000000 [23:25<94:40:14,  2.91it/s] 

{'loss': Array(0.12866983, dtype=float32), 'loss_cross_entropy': Array(0.12144418, dtype=float32)}
{'loss_inverse': Array(0.000435, dtype=float32)}


  1%|          | 8019/1000000 [23:27<44:30:19,  6.19it/s] 

{'loss': Array(0.10875201, dtype=float32), 'loss_cross_entropy': Array(0.10226616, dtype=float32)}
{'loss_inverse': Array(0.000256, dtype=float32)}


  1%|          | 8029/1000000 [23:28<30:13:09,  9.12it/s]

{'loss': Array(0.12684445, dtype=float32), 'loss_cross_entropy': Array(0.11962103, dtype=float32)}
{'loss_inverse': Array(0.00023504, dtype=float32)}


  1%|          | 8039/1000000 [23:29<28:06:33,  9.80it/s]

{'loss': Array(0.14094636, dtype=float32), 'loss_cross_entropy': Array(0.13316554, dtype=float32)}
{'loss_inverse': Array(0.00017184, dtype=float32)}


  1%|          | 8049/1000000 [23:31<27:50:13,  9.90it/s]

{'loss': Array(0.12605368, dtype=float32), 'loss_cross_entropy': Array(0.11951271, dtype=float32)}
{'loss_inverse': Array(0.00011767, dtype=float32)}


  1%|          | 8059/1000000 [23:33<39:08:03,  7.04it/s]

{'loss': Array(0.1322884, dtype=float32), 'loss_cross_entropy': Array(0.12497381, dtype=float32)}
{'loss_inverse': Array(0.00023054, dtype=float32)}


  1%|          | 8069/1000000 [23:34<29:17:41,  9.41it/s]

{'loss': Array(0.12668264, dtype=float32), 'loss_cross_entropy': Array(0.12003887, dtype=float32)}
{'loss_inverse': Array(0.00026355, dtype=float32)}


  1%|          | 8079/1000000 [23:36<28:47:08,  9.57it/s]

{'loss': Array(0.14236799, dtype=float32), 'loss_cross_entropy': Array(0.13507637, dtype=float32)}
{'loss_inverse': Array(0.00012772, dtype=float32)}


  1%|          | 8089/1000000 [23:37<28:28:37,  9.68it/s]

{'loss': Array(0.14050715, dtype=float32), 'loss_cross_entropy': Array(0.1329813, dtype=float32)}
{'loss_inverse': Array(0.00018202, dtype=float32)}


  1%|          | 8099/1000000 [23:39<34:48:16,  7.92it/s]

{'loss': Array(0.14468877, dtype=float32), 'loss_cross_entropy': Array(0.13699701, dtype=float32)}
{'loss_inverse': Array(0.00016611, dtype=float32)}


  1%|          | 8109/1000000 [23:41<30:08:12,  9.14it/s]

{'loss': Array(0.13013443, dtype=float32), 'loss_cross_entropy': Array(0.12346764, dtype=float32)}
{'loss_inverse': Array(0.00015048, dtype=float32)}


  1%|          | 8119/1000000 [23:42<28:20:01,  9.72it/s]

{'loss': Array(0.15291198, dtype=float32), 'loss_cross_entropy': Array(0.14587133, dtype=float32)}
{'loss_inverse': Array(0.00011506, dtype=float32)}


  1%|          | 8129/1000000 [23:44<41:53:21,  6.58it/s]

{'loss': Array(0.14605278, dtype=float32), 'loss_cross_entropy': Array(0.13882977, dtype=float32)}
{'loss_inverse': Array(8.304761e-05, dtype=float32)}


  1%|          | 8139/1000000 [23:46<30:23:00,  9.07it/s]

{'loss': Array(0.12974209, dtype=float32), 'loss_cross_entropy': Array(0.12284738, dtype=float32)}
{'loss_inverse': Array(0.00011571, dtype=float32)}


  1%|          | 8149/1000000 [23:47<28:11:11,  9.77it/s]

{'loss': Array(0.13663198, dtype=float32), 'loss_cross_entropy': Array(0.12972341, dtype=float32)}
{'loss_inverse': Array(9.093498e-05, dtype=float32)}


  1%|          | 8159/1000000 [23:49<27:24:03, 10.05it/s]

{'loss': Array(0.1434661, dtype=float32), 'loss_cross_entropy': Array(0.13598807, dtype=float32)}
{'loss_inverse': Array(0.00015766, dtype=float32)}


  1%|          | 8169/1000000 [23:51<33:26:02,  8.24it/s]

{'loss': Array(0.14756571, dtype=float32), 'loss_cross_entropy': Array(0.14039649, dtype=float32)}
{'loss_inverse': Array(0.00010847, dtype=float32)}


  1%|          | 8179/1000000 [23:52<28:26:12,  9.69it/s]

{'loss': Array(0.12901524, dtype=float32), 'loss_cross_entropy': Array(0.12224736, dtype=float32)}
{'loss_inverse': Array(0.00019323, dtype=float32)}


  1%|          | 8189/1000000 [23:54<27:37:31,  9.97it/s]

{'loss': Array(0.14009799, dtype=float32), 'loss_cross_entropy': Array(0.1330241, dtype=float32)}
{'loss_inverse': Array(0.00036511, dtype=float32)}


  1%|          | 8199/1000000 [23:56<34:57:34,  7.88it/s]

{'loss': Array(0.13673161, dtype=float32), 'loss_cross_entropy': Array(0.12993471, dtype=float32)}
{'loss_inverse': Array(0.00037312, dtype=float32)}


  1%|          | 8209/1000000 [23:57<28:27:20,  9.68it/s]

{'loss': Array(0.14111641, dtype=float32), 'loss_cross_entropy': Array(0.13364342, dtype=float32)}
{'loss_inverse': Array(0.00030799, dtype=float32)}


  1%|          | 8219/1000000 [23:58<27:55:33,  9.87it/s]

{'loss': Array(0.14566708, dtype=float32), 'loss_cross_entropy': Array(0.13829462, dtype=float32)}
{'loss_inverse': Array(0.00046334, dtype=float32)}


  1%|          | 8229/1000000 [24:00<43:44:37,  6.30it/s]

{'loss': Array(0.1432087, dtype=float32), 'loss_cross_entropy': Array(0.13580136, dtype=float32)}
{'loss_inverse': Array(0.00092645, dtype=float32)}


  1%|          | 8239/1000000 [24:02<30:14:07,  9.11it/s]

{'loss': Array(0.12610976, dtype=float32), 'loss_cross_entropy': Array(0.11889507, dtype=float32)}
{'loss_inverse': Array(0.0007295, dtype=float32)}


  1%|          | 8249/1000000 [24:03<28:41:36,  9.60it/s]

{'loss': Array(0.14760108, dtype=float32), 'loss_cross_entropy': Array(0.1396858, dtype=float32)}
{'loss_inverse': Array(0.00089921, dtype=float32)}


  1%|          | 8259/1000000 [24:05<44:00:13,  6.26it/s]

{'loss': Array(0.12813991, dtype=float32), 'loss_cross_entropy': Array(0.12129145, dtype=float32)}
{'loss_inverse': Array(0.00051142, dtype=float32)}


  1%|          | 8269/1000000 [24:07<31:00:15,  8.89it/s]

{'loss': Array(0.12183511, dtype=float32), 'loss_cross_entropy': Array(0.11549808, dtype=float32)}
{'loss_inverse': Array(0.00065314, dtype=float32)}


  1%|          | 8279/1000000 [24:08<28:48:00,  9.57it/s]

{'loss': Array(0.12995464, dtype=float32), 'loss_cross_entropy': Array(0.12329038, dtype=float32)}
{'loss_inverse': Array(0.00056067, dtype=float32)}


  1%|          | 8289/1000000 [24:10<30:00:38,  9.18it/s]

{'loss': Array(0.14468256, dtype=float32), 'loss_cross_entropy': Array(0.13653833, dtype=float32)}
{'loss_inverse': Array(0.00053629, dtype=float32)}


  1%|          | 8299/1000000 [24:12<33:05:31,  8.32it/s]

{'loss': Array(0.13708945, dtype=float32), 'loss_cross_entropy': Array(0.12939487, dtype=float32)}
{'loss_inverse': Array(0.0002493, dtype=float32)}


  1%|          | 8309/1000000 [24:13<29:01:13,  9.49it/s]

{'loss': Array(0.15386307, dtype=float32), 'loss_cross_entropy': Array(0.14554226, dtype=float32)}
{'loss_inverse': Array(0.00079824, dtype=float32)}


  1%|          | 8319/1000000 [24:15<29:42:30,  9.27it/s]

{'loss': Array(0.13740937, dtype=float32), 'loss_cross_entropy': Array(0.13016164, dtype=float32)}
{'loss_inverse': Array(0.00072577, dtype=float32)}


  1%|          | 8329/1000000 [24:17<49:21:41,  5.58it/s]

{'loss': Array(0.14392515, dtype=float32), 'loss_cross_entropy': Array(0.13649035, dtype=float32)}
{'loss_inverse': Array(0.00055016, dtype=float32)}


  1%|          | 8339/1000000 [24:18<31:31:57,  8.74it/s]

{'loss': Array(0.15101263, dtype=float32), 'loss_cross_entropy': Array(0.1429141, dtype=float32)}
{'loss_inverse': Array(0.00050492, dtype=float32)}


  1%|          | 8349/1000000 [24:20<29:07:12,  9.46it/s]

{'loss': Array(0.15219806, dtype=float32), 'loss_cross_entropy': Array(0.14430182, dtype=float32)}
{'loss_inverse': Array(0.00045188, dtype=float32)}


  1%|          | 8359/1000000 [24:21<29:19:50,  9.39it/s]

{'loss': Array(0.13389407, dtype=float32), 'loss_cross_entropy': Array(0.12719946, dtype=float32)}
{'loss_inverse': Array(0.00036188, dtype=float32)}


  1%|          | 8369/1000000 [24:23<28:02:34,  9.82it/s]

{'loss': Array(0.12608998, dtype=float32), 'loss_cross_entropy': Array(0.12006289, dtype=float32)}
{'loss_inverse': Array(0.00032385, dtype=float32)}


  1%|          | 8379/1000000 [24:25<34:52:52,  7.90it/s]

{'loss': Array(0.13663988, dtype=float32), 'loss_cross_entropy': Array(0.12891941, dtype=float32)}
{'loss_inverse': Array(0.00027256, dtype=float32)}


  1%|          | 8389/1000000 [24:26<29:45:03,  9.26it/s]

{'loss': Array(0.14394468, dtype=float32), 'loss_cross_entropy': Array(0.13650315, dtype=float32)}
{'loss_inverse': Array(0.00031251, dtype=float32)}


  1%|          | 8399/1000000 [24:28<28:03:09,  9.82it/s]

{'loss': Array(0.13486679, dtype=float32), 'loss_cross_entropy': Array(0.12714142, dtype=float32)}
{'loss_inverse': Array(0.00020668, dtype=float32)}


  1%|          | 8409/1000000 [24:29<27:34:37,  9.99it/s]

{'loss': Array(0.12007795, dtype=float32), 'loss_cross_entropy': Array(0.11314092, dtype=float32)}
{'loss_inverse': Array(0.00029467, dtype=float32)}


  1%|          | 8419/1000000 [24:31<34:51:07,  7.90it/s]

{'loss': Array(0.1322696, dtype=float32), 'loss_cross_entropy': Array(0.12508969, dtype=float32)}
{'loss_inverse': Array(0.00025683, dtype=float32)}


  1%|          | 8429/1000000 [24:33<29:09:07,  9.45it/s]

{'loss': Array(0.12720047, dtype=float32), 'loss_cross_entropy': Array(0.12023018, dtype=float32)}
{'loss_inverse': Array(0.00023543, dtype=float32)}


  1%|          | 8439/1000000 [24:34<28:11:54,  9.77it/s]

{'loss': Array(0.13709716, dtype=float32), 'loss_cross_entropy': Array(0.13047768, dtype=float32)}
{'loss_inverse': Array(0.0003341, dtype=float32)}


  1%|          | 8449/1000000 [24:36<43:24:05,  6.35it/s]

{'loss': Array(0.13558565, dtype=float32), 'loss_cross_entropy': Array(0.12860309, dtype=float32)}
{'loss_inverse': Array(0.00015828, dtype=float32)}


  1%|          | 8459/1000000 [24:38<30:12:39,  9.12it/s]

{'loss': Array(0.15081643, dtype=float32), 'loss_cross_entropy': Array(0.14323512, dtype=float32)}
{'loss_inverse': Array(0.00017557, dtype=float32)}


  1%|          | 8469/1000000 [24:39<28:10:42,  9.77it/s]

{'loss': Array(0.11985015, dtype=float32), 'loss_cross_entropy': Array(0.11358821, dtype=float32)}
{'loss_inverse': Array(0.00014326, dtype=float32)}


  1%|          | 8479/1000000 [24:41<43:01:54,  6.40it/s]

{'loss': Array(0.14383744, dtype=float32), 'loss_cross_entropy': Array(0.13610423, dtype=float32)}
{'loss_inverse': Array(6.871067e-05, dtype=float32)}


  1%|          | 8489/1000000 [24:43<29:57:07,  9.20it/s]

{'loss': Array(0.1490793, dtype=float32), 'loss_cross_entropy': Array(0.14198814, dtype=float32)}
{'loss_inverse': Array(0.00016536, dtype=float32)}


  1%|          | 8499/1000000 [24:44<28:12:30,  9.76it/s]

{'loss': Array(0.13069889, dtype=float32), 'loss_cross_entropy': Array(0.12360263, dtype=float32)}
{'loss_inverse': Array(9.7109325e-05, dtype=float32)}


  1%|          | 8509/1000000 [24:52<95:09:42,  2.89it/s] 

{'loss': Array(0.12232679, dtype=float32), 'loss_cross_entropy': Array(0.11532877, dtype=float32)}
{'loss_inverse': Array(7.200688e-05, dtype=float32)}


  1%|          | 8519/1000000 [24:54<59:57:47,  4.59it/s] 

{'loss': Array(0.13395482, dtype=float32), 'loss_cross_entropy': Array(0.1266088, dtype=float32)}
{'loss_inverse': Array(0.00016728, dtype=float32)}


  1%|          | 8529/1000000 [24:56<33:21:22,  8.26it/s]

{'loss': Array(0.12953922, dtype=float32), 'loss_cross_entropy': Array(0.12281962, dtype=float32)}
{'loss_inverse': Array(0.00012835, dtype=float32)}


  1%|          | 8539/1000000 [24:57<29:23:03,  9.37it/s]

{'loss': Array(0.13203757, dtype=float32), 'loss_cross_entropy': Array(0.12551181, dtype=float32)}
{'loss_inverse': Array(0.00013624, dtype=float32)}


  1%|          | 8549/1000000 [24:59<28:13:06,  9.76it/s]

{'loss': Array(0.13980213, dtype=float32), 'loss_cross_entropy': Array(0.13253023, dtype=float32)}
{'loss_inverse': Array(0.00024577, dtype=float32)}


  1%|          | 8559/1000000 [25:01<44:07:29,  6.24it/s]

{'loss': Array(0.13558342, dtype=float32), 'loss_cross_entropy': Array(0.1281618, dtype=float32)}
{'loss_inverse': Array(9.0499234e-05, dtype=float32)}


  1%|          | 8569/1000000 [25:02<30:48:14,  8.94it/s]

{'loss': Array(0.1357243, dtype=float32), 'loss_cross_entropy': Array(0.1287329, dtype=float32)}
{'loss_inverse': Array(0.00010674, dtype=float32)}


  1%|          | 8579/1000000 [25:04<28:48:58,  9.56it/s]

{'loss': Array(0.16860865, dtype=float32), 'loss_cross_entropy': Array(0.15963975, dtype=float32)}
{'loss_inverse': Array(0.00010234, dtype=float32)}


  1%|          | 8589/1000000 [25:05<28:27:13,  9.68it/s]

{'loss': Array(0.13757485, dtype=float32), 'loss_cross_entropy': Array(0.1301796, dtype=float32)}
{'loss_inverse': Array(0.00015634, dtype=float32)}


  1%|          | 8599/1000000 [25:07<34:31:50,  7.98it/s]

{'loss': Array(0.1335946, dtype=float32), 'loss_cross_entropy': Array(0.12699433, dtype=float32)}
{'loss_inverse': Array(0.00015419, dtype=float32)}


  1%|          | 8609/1000000 [25:09<28:37:17,  9.62it/s]

{'loss': Array(0.13438652, dtype=float32), 'loss_cross_entropy': Array(0.12684074, dtype=float32)}
{'loss_inverse': Array(9.785004e-05, dtype=float32)}


  1%|          | 8619/1000000 [25:10<29:09:42,  9.44it/s]

{'loss': Array(0.15558863, dtype=float32), 'loss_cross_entropy': Array(0.14740516, dtype=float32)}
{'loss_inverse': Array(0.00013889, dtype=float32)}


  1%|          | 8629/1000000 [25:12<42:15:17,  6.52it/s]

{'loss': Array(0.13073064, dtype=float32), 'loss_cross_entropy': Array(0.12431592, dtype=float32)}
{'loss_inverse': Array(0.00012496, dtype=float32)}


  1%|          | 8639/1000000 [25:13<30:00:14,  9.18it/s]

{'loss': Array(0.13348754, dtype=float32), 'loss_cross_entropy': Array(0.1258575, dtype=float32)}
{'loss_inverse': Array(6.4918626e-05, dtype=float32)}


  1%|          | 8649/1000000 [25:15<29:46:53,  9.25it/s]

{'loss': Array(0.13311587, dtype=float32), 'loss_cross_entropy': Array(0.12631567, dtype=float32)}
{'loss_inverse': Array(0.0001116, dtype=float32)}


  1%|          | 8659/1000000 [25:16<29:18:31,  9.40it/s]

{'loss': Array(0.12684305, dtype=float32), 'loss_cross_entropy': Array(0.1202606, dtype=float32)}
{'loss_inverse': Array(9.1443755e-05, dtype=float32)}


  1%|          | 8669/1000000 [25:18<38:30:20,  7.15it/s]

{'loss': Array(0.12092294, dtype=float32), 'loss_cross_entropy': Array(0.11433285, dtype=float32)}
{'loss_inverse': Array(7.343693e-05, dtype=float32)}


  1%|          | 8679/1000000 [25:20<31:48:06,  8.66it/s]

{'loss': Array(0.12426157, dtype=float32), 'loss_cross_entropy': Array(0.11771057, dtype=float32)}
{'loss_inverse': Array(0.00016662, dtype=float32)}


  1%|          | 8689/1000000 [25:21<29:08:52,  9.45it/s]

{'loss': Array(0.13048793, dtype=float32), 'loss_cross_entropy': Array(0.12360194, dtype=float32)}
{'loss_inverse': Array(0.00011472, dtype=float32)}


  1%|          | 8699/1000000 [25:23<27:32:36, 10.00it/s]

{'loss': Array(0.12676945, dtype=float32), 'loss_cross_entropy': Array(0.11973548, dtype=float32)}
{'loss_inverse': Array(0.00017688, dtype=float32)}


  1%|          | 8709/1000000 [25:25<40:01:08,  6.88it/s]

{'loss': Array(0.13430327, dtype=float32), 'loss_cross_entropy': Array(0.12699428, dtype=float32)}
{'loss_inverse': Array(7.162784e-05, dtype=float32)}


  1%|          | 8719/1000000 [25:26<30:17:10,  9.09it/s]

{'loss': Array(0.14905308, dtype=float32), 'loss_cross_entropy': Array(0.14167082, dtype=float32)}
{'loss_inverse': Array(5.6874018e-05, dtype=float32)}


  1%|          | 8729/1000000 [25:28<28:12:18,  9.76it/s]

{'loss': Array(0.13522755, dtype=float32), 'loss_cross_entropy': Array(0.128424, dtype=float32)}
{'loss_inverse': Array(3.3408567e-05, dtype=float32)}


  1%|          | 8739/1000000 [25:29<29:07:09,  9.46it/s]

{'loss': Array(0.12767395, dtype=float32), 'loss_cross_entropy': Array(0.12091004, dtype=float32)}
{'loss_inverse': Array(0.00019382, dtype=float32)}


  1%|          | 8749/1000000 [25:31<33:28:27,  8.23it/s]

{'loss': Array(0.13396993, dtype=float32), 'loss_cross_entropy': Array(0.12712573, dtype=float32)}
{'loss_inverse': Array(0.00023539, dtype=float32)}


  1%|          | 8759/1000000 [25:33<29:24:26,  9.36it/s]

{'loss': Array(0.14403276, dtype=float32), 'loss_cross_entropy': Array(0.13635746, dtype=float32)}
{'loss_inverse': Array(0.00014764, dtype=float32)}


  1%|          | 8769/1000000 [25:34<28:55:20,  9.52it/s]

{'loss': Array(0.12762123, dtype=float32), 'loss_cross_entropy': Array(0.12033953, dtype=float32)}
{'loss_inverse': Array(6.550069e-05, dtype=float32)}


  1%|          | 8779/1000000 [25:36<39:00:38,  7.06it/s]

{'loss': Array(0.12316275, dtype=float32), 'loss_cross_entropy': Array(0.11616013, dtype=float32)}
{'loss_inverse': Array(0.000141, dtype=float32)}


  1%|          | 8789/1000000 [25:38<30:38:03,  8.99it/s]

{'loss': Array(0.13997853, dtype=float32), 'loss_cross_entropy': Array(0.13281077, dtype=float32)}
{'loss_inverse': Array(7.422277e-05, dtype=float32)}


  1%|          | 8799/1000000 [25:39<28:38:13,  9.61it/s]

{'loss': Array(0.14977665, dtype=float32), 'loss_cross_entropy': Array(0.1418597, dtype=float32)}
{'loss_inverse': Array(0.000198, dtype=float32)}


  1%|          | 8809/1000000 [25:41<43:16:05,  6.36it/s]

{'loss': Array(0.14835669, dtype=float32), 'loss_cross_entropy': Array(0.14037006, dtype=float32)}
{'loss_inverse': Array(0.00022899, dtype=float32)}


  1%|          | 8819/1000000 [25:43<30:09:18,  9.13it/s]

{'loss': Array(0.13481627, dtype=float32), 'loss_cross_entropy': Array(0.12733905, dtype=float32)}
{'loss_inverse': Array(0.00023971, dtype=float32)}


  1%|          | 8829/1000000 [25:44<27:50:57,  9.89it/s]

{'loss': Array(0.13443919, dtype=float32), 'loss_cross_entropy': Array(0.1269677, dtype=float32)}
{'loss_inverse': Array(0.00037496, dtype=float32)}


  1%|          | 8839/1000000 [25:46<27:57:40,  9.85it/s]

{'loss': Array(0.12720768, dtype=float32), 'loss_cross_entropy': Array(0.11975342, dtype=float32)}
{'loss_inverse': Array(0.00036027, dtype=float32)}


  1%|          | 8849/1000000 [25:48<37:25:45,  7.36it/s]

{'loss': Array(0.12648009, dtype=float32), 'loss_cross_entropy': Array(0.11959036, dtype=float32)}
{'loss_inverse': Array(0.00096736, dtype=float32)}


  1%|          | 8859/1000000 [25:49<29:22:25,  9.37it/s]

{'loss': Array(0.13926123, dtype=float32), 'loss_cross_entropy': Array(0.13191038, dtype=float32)}
{'loss_inverse': Array(0.00044401, dtype=float32)}


  1%|          | 8869/1000000 [25:51<29:02:04,  9.48it/s]

{'loss': Array(0.14268357, dtype=float32), 'loss_cross_entropy': Array(0.13499372, dtype=float32)}
{'loss_inverse': Array(0.00054878, dtype=float32)}


  1%|          | 8879/1000000 [25:52<28:50:53,  9.54it/s]

{'loss': Array(0.15396409, dtype=float32), 'loss_cross_entropy': Array(0.14534311, dtype=float32)}
{'loss_inverse': Array(0.00044466, dtype=float32)}


  1%|          | 8889/1000000 [25:54<38:42:13,  7.11it/s]

{'loss': Array(0.133165, dtype=float32), 'loss_cross_entropy': Array(0.12592462, dtype=float32)}
{'loss_inverse': Array(0.00088231, dtype=float32)}


  1%|          | 8899/1000000 [25:56<30:25:02,  9.05it/s]

{'loss': Array(0.16047081, dtype=float32), 'loss_cross_entropy': Array(0.15234533, dtype=float32)}
{'loss_inverse': Array(0.00038588, dtype=float32)}


  1%|          | 8909/1000000 [25:57<28:05:12,  9.80it/s]

{'loss': Array(0.14273785, dtype=float32), 'loss_cross_entropy': Array(0.13574886, dtype=float32)}
{'loss_inverse': Array(0.00064067, dtype=float32)}


  1%|          | 8919/1000000 [25:59<27:22:58, 10.05it/s]

{'loss': Array(0.15103672, dtype=float32), 'loss_cross_entropy': Array(0.1435156, dtype=float32)}
{'loss_inverse': Array(0.00064315, dtype=float32)}


  1%|          | 8929/1000000 [26:00<35:05:15,  7.85it/s]

{'loss': Array(0.14309911, dtype=float32), 'loss_cross_entropy': Array(0.1355729, dtype=float32)}
{'loss_inverse': Array(0.00051425, dtype=float32)}


  1%|          | 8939/1000000 [26:02<28:56:27,  9.51it/s]

{'loss': Array(0.13323592, dtype=float32), 'loss_cross_entropy': Array(0.1266056, dtype=float32)}
{'loss_inverse': Array(0.00035521, dtype=float32)}


  1%|          | 8949/1000000 [26:03<28:05:18,  9.80it/s]

{'loss': Array(0.14458208, dtype=float32), 'loss_cross_entropy': Array(0.13721035, dtype=float32)}
{'loss_inverse': Array(0.00034163, dtype=float32)}


  1%|          | 8959/1000000 [26:05<44:17:01,  6.22it/s]

{'loss': Array(0.14366154, dtype=float32), 'loss_cross_entropy': Array(0.13595483, dtype=float32)}
{'loss_inverse': Array(0.0003943, dtype=float32)}


  1%|          | 8969/1000000 [26:07<30:41:09,  8.97it/s]

{'loss': Array(0.15080254, dtype=float32), 'loss_cross_entropy': Array(0.14276485, dtype=float32)}
{'loss_inverse': Array(0.00030346, dtype=float32)}


  1%|          | 8979/1000000 [26:08<28:35:37,  9.63it/s]

{'loss': Array(0.14088385, dtype=float32), 'loss_cross_entropy': Array(0.13356791, dtype=float32)}
{'loss_inverse': Array(0.00027854, dtype=float32)}


  1%|          | 8989/1000000 [26:10<30:11:27,  9.12it/s]

{'loss': Array(0.15894608, dtype=float32), 'loss_cross_entropy': Array(0.15102065, dtype=float32)}
{'loss_inverse': Array(0.00020589, dtype=float32)}


  1%|          | 8999/1000000 [26:12<32:57:24,  8.35it/s]

{'loss': Array(0.13389842, dtype=float32), 'loss_cross_entropy': Array(0.12687083, dtype=float32)}
{'loss_inverse': Array(0.00028187, dtype=float32)}


  1%|          | 9009/1000000 [26:20<96:52:08,  2.84it/s] 

{'loss': Array(0.13532789, dtype=float32), 'loss_cross_entropy': Array(0.12770545, dtype=float32)}
{'loss_inverse': Array(0.0004535, dtype=float32)}


  1%|          | 9019/1000000 [26:21<39:42:29,  6.93it/s] 

{'loss': Array(0.1506587, dtype=float32), 'loss_cross_entropy': Array(0.1426888, dtype=float32)}
{'loss_inverse': Array(0.00075339, dtype=float32)}


  1%|          | 9029/1000000 [26:23<43:33:55,  6.32it/s]

{'loss': Array(0.13501534, dtype=float32), 'loss_cross_entropy': Array(0.12793386, dtype=float32)}
{'loss_inverse': Array(0.00030094, dtype=float32)}


  1%|          | 9039/1000000 [26:25<32:34:09,  8.45it/s]

{'loss': Array(0.129819, dtype=float32), 'loss_cross_entropy': Array(0.12329543, dtype=float32)}
{'loss_inverse': Array(0.00063633, dtype=float32)}


  1%|          | 9049/1000000 [26:26<29:25:32,  9.35it/s]

{'loss': Array(0.1537226, dtype=float32), 'loss_cross_entropy': Array(0.14537738, dtype=float32)}
{'loss_inverse': Array(0.00060292, dtype=float32)}


  1%|          | 9059/1000000 [26:28<28:25:35,  9.68it/s]

{'loss': Array(0.13884173, dtype=float32), 'loss_cross_entropy': Array(0.13184422, dtype=float32)}
{'loss_inverse': Array(0.00053174, dtype=float32)}


  1%|          | 9069/1000000 [26:30<34:41:33,  7.93it/s]

{'loss': Array(0.12381462, dtype=float32), 'loss_cross_entropy': Array(0.11751758, dtype=float32)}
{'loss_inverse': Array(0.00030437, dtype=float32)}


  1%|          | 9079/1000000 [26:31<29:03:09,  9.47it/s]

{'loss': Array(0.12834209, dtype=float32), 'loss_cross_entropy': Array(0.12164646, dtype=float32)}
{'loss_inverse': Array(0.00026771, dtype=float32)}


  1%|          | 9089/1000000 [26:33<28:01:49,  9.82it/s]

{'loss': Array(0.11900254, dtype=float32), 'loss_cross_entropy': Array(0.11240289, dtype=float32)}
{'loss_inverse': Array(0.00036605, dtype=float32)}


  1%|          | 9099/1000000 [26:34<27:15:45, 10.10it/s]

{'loss': Array(0.12738314, dtype=float32), 'loss_cross_entropy': Array(0.1208204, dtype=float32)}
{'loss_inverse': Array(0.00018408, dtype=float32)}


  1%|          | 9109/1000000 [26:36<32:25:14,  8.49it/s]

{'loss': Array(0.11793847, dtype=float32), 'loss_cross_entropy': Array(0.11162752, dtype=float32)}
{'loss_inverse': Array(0.00028902, dtype=float32)}


  1%|          | 9119/1000000 [26:38<29:28:17,  9.34it/s]

{'loss': Array(0.12845919, dtype=float32), 'loss_cross_entropy': Array(0.12166554, dtype=float32)}
{'loss_inverse': Array(0.00024618, dtype=float32)}


  1%|          | 9129/1000000 [26:39<27:58:54,  9.84it/s]

{'loss': Array(0.14855355, dtype=float32), 'loss_cross_entropy': Array(0.14113136, dtype=float32)}
{'loss_inverse': Array(0.00018896, dtype=float32)}


  1%|          | 9139/1000000 [26:41<48:51:25,  5.63it/s]

{'loss': Array(0.11839762, dtype=float32), 'loss_cross_entropy': Array(0.11166157, dtype=float32)}
{'loss_inverse': Array(0.00019184, dtype=float32)}


  1%|          | 9149/1000000 [26:43<31:24:06,  8.77it/s]

{'loss': Array(0.13959329, dtype=float32), 'loss_cross_entropy': Array(0.13176398, dtype=float32)}
{'loss_inverse': Array(0.00048536, dtype=float32)}


  1%|          | 9159/1000000 [26:44<29:02:19,  9.48it/s]

{'loss': Array(0.12496294, dtype=float32), 'loss_cross_entropy': Array(0.11829004, dtype=float32)}
{'loss_inverse': Array(0.0003226, dtype=float32)}


  1%|          | 9169/1000000 [26:46<29:28:51,  9.34it/s]

{'loss': Array(0.13559197, dtype=float32), 'loss_cross_entropy': Array(0.12876502, dtype=float32)}
{'loss_inverse': Array(0.00028092, dtype=float32)}


  1%|          | 9179/1000000 [26:48<38:06:15,  7.22it/s]

{'loss': Array(0.14007913, dtype=float32), 'loss_cross_entropy': Array(0.13307263, dtype=float32)}
{'loss_inverse': Array(0.00053213, dtype=float32)}


  1%|          | 9189/1000000 [26:49<29:20:24,  9.38it/s]

{'loss': Array(0.15082023, dtype=float32), 'loss_cross_entropy': Array(0.14273094, dtype=float32)}
{'loss_inverse': Array(0.00057461, dtype=float32)}


  1%|          | 9199/1000000 [26:51<29:20:59,  9.38it/s]

{'loss': Array(0.12383121, dtype=float32), 'loss_cross_entropy': Array(0.11685379, dtype=float32)}
{'loss_inverse': Array(0.00070403, dtype=float32)}


  1%|          | 9209/1000000 [26:53<49:02:20,  5.61it/s]

{'loss': Array(0.1413826, dtype=float32), 'loss_cross_entropy': Array(0.13381492, dtype=float32)}
{'loss_inverse': Array(0.00049582, dtype=float32)}


  1%|          | 9219/1000000 [26:54<31:20:19,  8.78it/s]

{'loss': Array(0.12703885, dtype=float32), 'loss_cross_entropy': Array(0.12032175, dtype=float32)}
{'loss_inverse': Array(0.00038263, dtype=float32)}


  1%|          | 9229/1000000 [26:56<28:29:48,  9.66it/s]

{'loss': Array(0.13927674, dtype=float32), 'loss_cross_entropy': Array(0.13154909, dtype=float32)}
{'loss_inverse': Array(0.00044097, dtype=float32)}


  1%|          | 9239/1000000 [26:57<27:00:49, 10.19it/s]

{'loss': Array(0.1362276, dtype=float32), 'loss_cross_entropy': Array(0.12855403, dtype=float32)}
{'loss_inverse': Array(0.0002833, dtype=float32)}


  1%|          | 9249/1000000 [26:59<35:53:10,  7.67it/s]

{'loss': Array(0.13137256, dtype=float32), 'loss_cross_entropy': Array(0.12381507, dtype=float32)}
{'loss_inverse': Array(0.00020064, dtype=float32)}


  1%|          | 9259/1000000 [27:01<29:23:18,  9.36it/s]

{'loss': Array(0.12037744, dtype=float32), 'loss_cross_entropy': Array(0.11360043, dtype=float32)}
{'loss_inverse': Array(0.00013469, dtype=float32)}


  1%|          | 9269/1000000 [27:02<27:53:33,  9.87it/s]

{'loss': Array(0.14225402, dtype=float32), 'loss_cross_entropy': Array(0.13473359, dtype=float32)}
{'loss_inverse': Array(0.00038601, dtype=float32)}


  1%|          | 9279/1000000 [27:04<38:34:00,  7.14it/s]

{'loss': Array(0.12537313, dtype=float32), 'loss_cross_entropy': Array(0.1186101, dtype=float32)}
{'loss_inverse': Array(0.00026316, dtype=float32)}


  1%|          | 9289/1000000 [27:06<30:25:42,  9.04it/s]

{'loss': Array(0.14654163, dtype=float32), 'loss_cross_entropy': Array(0.13884136, dtype=float32)}
{'loss_inverse': Array(0.00045451, dtype=float32)}


  1%|          | 9299/1000000 [27:07<28:14:53,  9.74it/s]

{'loss': Array(0.1358879, dtype=float32), 'loss_cross_entropy': Array(0.1291403, dtype=float32)}
{'loss_inverse': Array(0.00021984, dtype=float32)}


  1%|          | 9309/1000000 [27:08<27:33:59,  9.98it/s]

{'loss': Array(0.10985482, dtype=float32), 'loss_cross_entropy': Array(0.10391185, dtype=float32)}
{'loss_inverse': Array(0.00016129, dtype=float32)}


  1%|          | 9319/1000000 [27:10<38:47:45,  7.09it/s]

{'loss': Array(0.12487083, dtype=float32), 'loss_cross_entropy': Array(0.1188907, dtype=float32)}
{'loss_inverse': Array(0.00032002, dtype=float32)}


  1%|          | 9329/1000000 [27:12<29:37:03,  9.29it/s]

{'loss': Array(0.12437053, dtype=float32), 'loss_cross_entropy': Array(0.11791792, dtype=float32)}
{'loss_inverse': Array(0.0003788, dtype=float32)}


  1%|          | 9339/1000000 [27:13<28:22:20,  9.70it/s]

{'loss': Array(0.13062258, dtype=float32), 'loss_cross_entropy': Array(0.12360918, dtype=float32)}
{'loss_inverse': Array(0.000414, dtype=float32)}


  1%|          | 9349/1000000 [27:15<29:04:51,  9.46it/s]

{'loss': Array(0.13851771, dtype=float32), 'loss_cross_entropy': Array(0.13115393, dtype=float32)}
{'loss_inverse': Array(0.00052184, dtype=float32)}


  1%|          | 9360/1000000 [27:17<30:25:15,  9.05it/s]

{'loss': Array(0.11842219, dtype=float32), 'loss_cross_entropy': Array(0.11187475, dtype=float32)}
{'loss_inverse': Array(0.00037605, dtype=float32)}


  1%|          | 9370/1000000 [27:18<27:34:38,  9.98it/s]

{'loss': Array(0.16467087, dtype=float32), 'loss_cross_entropy': Array(0.15660603, dtype=float32)}
{'loss_inverse': Array(0.00038125, dtype=float32)}


  1%|          | 9380/1000000 [27:20<27:56:21,  9.85it/s]

{'loss': Array(0.11937747, dtype=float32), 'loss_cross_entropy': Array(0.11305433, dtype=float32)}
{'loss_inverse': Array(0.00044636, dtype=float32)}


  1%|          | 9390/1000000 [27:22<41:59:58,  6.55it/s]

{'loss': Array(0.14058942, dtype=float32), 'loss_cross_entropy': Array(0.13333088, dtype=float32)}
{'loss_inverse': Array(0.00026026, dtype=float32)}


  1%|          | 9400/1000000 [27:23<29:47:06,  9.24it/s]

{'loss': Array(0.14306477, dtype=float32), 'loss_cross_entropy': Array(0.13683234, dtype=float32)}
{'loss_inverse': Array(0.00026303, dtype=float32)}


  1%|          | 9410/1000000 [27:25<28:22:50,  9.70it/s]

{'loss': Array(0.14471063, dtype=float32), 'loss_cross_entropy': Array(0.1370027, dtype=float32)}
{'loss_inverse': Array(0.0003829, dtype=float32)}


  1%|          | 9420/1000000 [27:26<29:33:00,  9.31it/s]

{'loss': Array(0.14514886, dtype=float32), 'loss_cross_entropy': Array(0.1375048, dtype=float32)}
{'loss_inverse': Array(0.00053251, dtype=float32)}


  1%|          | 9430/1000000 [27:28<32:19:17,  8.51it/s]

{'loss': Array(0.13679513, dtype=float32), 'loss_cross_entropy': Array(0.12970598, dtype=float32)}
{'loss_inverse': Array(0.0002848, dtype=float32)}


  1%|          | 9440/1000000 [27:30<29:34:21,  9.30it/s]

{'loss': Array(0.11529016, dtype=float32), 'loss_cross_entropy': Array(0.10925718, dtype=float32)}
{'loss_inverse': Array(0.00124317, dtype=float32)}


  1%|          | 9450/1000000 [27:31<28:34:01,  9.63it/s]

{'loss': Array(0.15911093, dtype=float32), 'loss_cross_entropy': Array(0.15068641, dtype=float32)}
{'loss_inverse': Array(0.00072267, dtype=float32)}


  1%|          | 9460/1000000 [27:33<37:45:36,  7.29it/s]

{'loss': Array(0.15221663, dtype=float32), 'loss_cross_entropy': Array(0.14506412, dtype=float32)}
{'loss_inverse': Array(0.00069399, dtype=float32)}


  1%|          | 9470/1000000 [27:34<30:10:00,  9.12it/s]

{'loss': Array(0.13858233, dtype=float32), 'loss_cross_entropy': Array(0.13152663, dtype=float32)}
{'loss_inverse': Array(0.00071897, dtype=float32)}


  1%|          | 9480/1000000 [27:36<28:22:29,  9.70it/s]

{'loss': Array(0.12495802, dtype=float32), 'loss_cross_entropy': Array(0.11791123, dtype=float32)}
{'loss_inverse': Array(0.0001961, dtype=float32)}


  1%|          | 9490/1000000 [27:37<28:47:49,  9.55it/s]

{'loss': Array(0.15138422, dtype=float32), 'loss_cross_entropy': Array(0.14298682, dtype=float32)}
{'loss_inverse': Array(0.00054943, dtype=float32)}


  1%|          | 9500/1000000 [27:39<34:59:31,  7.86it/s]

{'loss': Array(0.12759684, dtype=float32), 'loss_cross_entropy': Array(0.12068782, dtype=float32)}
{'loss_inverse': Array(0.00036246, dtype=float32)}


  1%|          | 9510/1000000 [27:47<94:49:08,  2.90it/s] 

{'loss': Array(0.1323077, dtype=float32), 'loss_cross_entropy': Array(0.12511246, dtype=float32)}
{'loss_inverse': Array(0.00048581, dtype=float32)}


  1%|          | 9520/1000000 [27:49<40:16:15,  6.83it/s] 

{'loss': Array(0.12913863, dtype=float32), 'loss_cross_entropy': Array(0.12179649, dtype=float32)}
{'loss_inverse': Array(0.0005557, dtype=float32)}


  1%|          | 9529/1000000 [27:51<38:04:29,  7.23it/s]

{'loss': Array(0.13893457, dtype=float32), 'loss_cross_entropy': Array(0.13175134, dtype=float32)}
{'loss_inverse': Array(0.0004489, dtype=float32)}


  1%|          | 9539/1000000 [27:52<29:39:57,  9.27it/s]

{'loss': Array(0.1382829, dtype=float32), 'loss_cross_entropy': Array(0.13069117, dtype=float32)}
{'loss_inverse': Array(0.00035232, dtype=float32)}


  1%|          | 9549/1000000 [27:54<28:45:10,  9.57it/s]

{'loss': Array(0.13088712, dtype=float32), 'loss_cross_entropy': Array(0.12390246, dtype=float32)}
{'loss_inverse': Array(0.00045022, dtype=float32)}


  1%|          | 9559/1000000 [27:55<28:43:32,  9.58it/s]

{'loss': Array(0.12924486, dtype=float32), 'loss_cross_entropy': Array(0.121864, dtype=float32)}
{'loss_inverse': Array(0.00022515, dtype=float32)}


  1%|          | 9569/1000000 [27:57<35:41:30,  7.71it/s]

{'loss': Array(0.16522217, dtype=float32), 'loss_cross_entropy': Array(0.15667926, dtype=float32)}
{'loss_inverse': Array(0.00041262, dtype=float32)}


  1%|          | 9579/1000000 [27:59<29:32:34,  9.31it/s]

{'loss': Array(0.10290178, dtype=float32), 'loss_cross_entropy': Array(0.09721806, dtype=float32)}
{'loss_inverse': Array(0.00041447, dtype=float32)}


  1%|          | 9589/1000000 [28:00<28:36:41,  9.62it/s]

{'loss': Array(0.1309266, dtype=float32), 'loss_cross_entropy': Array(0.12345928, dtype=float32)}
{'loss_inverse': Array(0.00020224, dtype=float32)}


  1%|          | 9599/1000000 [28:02<28:04:05,  9.80it/s]

{'loss': Array(0.1299287, dtype=float32), 'loss_cross_entropy': Array(0.12231769, dtype=float32)}
{'loss_inverse': Array(0.00020251, dtype=float32)}


  1%|          | 9609/1000000 [28:04<32:55:06,  8.36it/s]

{'loss': Array(0.14119576, dtype=float32), 'loss_cross_entropy': Array(0.13378933, dtype=float32)}
{'loss_inverse': Array(0.00012075, dtype=float32)}


  1%|          | 9619/1000000 [28:05<29:44:40,  9.25it/s]

{'loss': Array(0.13134827, dtype=float32), 'loss_cross_entropy': Array(0.12422667, dtype=float32)}
{'loss_inverse': Array(0.00031874, dtype=float32)}


  1%|          | 9629/1000000 [28:07<28:44:53,  9.57it/s]

{'loss': Array(0.14752547, dtype=float32), 'loss_cross_entropy': Array(0.14030853, dtype=float32)}
{'loss_inverse': Array(0.00045755, dtype=float32)}


  1%|          | 9639/1000000 [28:09<42:12:09,  6.52it/s]

{'loss': Array(0.1376309, dtype=float32), 'loss_cross_entropy': Array(0.13032179, dtype=float32)}
{'loss_inverse': Array(0.0002556, dtype=float32)}


  1%|          | 9650/1000000 [28:10<28:55:21,  9.51it/s]

{'loss': Array(0.11633921, dtype=float32), 'loss_cross_entropy': Array(0.11088645, dtype=float32)}
{'loss_inverse': Array(0.00022417, dtype=float32)}


  1%|          | 9660/1000000 [28:12<28:09:47,  9.77it/s]

{'loss': Array(0.13881272, dtype=float32), 'loss_cross_entropy': Array(0.13128799, dtype=float32)}
{'loss_inverse': Array(0.00061885, dtype=float32)}


  1%|          | 9670/1000000 [28:13<28:13:23,  9.75it/s]

{'loss': Array(0.13510716, dtype=float32), 'loss_cross_entropy': Array(0.12769304, dtype=float32)}
{'loss_inverse': Array(0.00033947, dtype=float32)}


  1%|          | 9680/1000000 [28:15<33:48:49,  8.14it/s]

{'loss': Array(0.1390581, dtype=float32), 'loss_cross_entropy': Array(0.13216008, dtype=float32)}
{'loss_inverse': Array(0.00033104, dtype=float32)}


  1%|          | 9690/1000000 [28:17<29:09:48,  9.43it/s]

{'loss': Array(0.13362733, dtype=float32), 'loss_cross_entropy': Array(0.12611505, dtype=float32)}
{'loss_inverse': Array(0.0002475, dtype=float32)}


  1%|          | 9700/1000000 [28:18<29:45:38,  9.24it/s]

{'loss': Array(0.13130069, dtype=float32), 'loss_cross_entropy': Array(0.12406089, dtype=float32)}
{'loss_inverse': Array(0.00026457, dtype=float32)}


  1%|          | 9710/1000000 [28:20<39:21:55,  6.99it/s]

{'loss': Array(0.14850318, dtype=float32), 'loss_cross_entropy': Array(0.1408873, dtype=float32)}
{'loss_inverse': Array(0.00045195, dtype=float32)}


  1%|          | 9720/1000000 [28:22<29:51:47,  9.21it/s]

{'loss': Array(0.14193495, dtype=float32), 'loss_cross_entropy': Array(0.1344178, dtype=float32)}
{'loss_inverse': Array(0.00045709, dtype=float32)}


  1%|          | 9730/1000000 [28:23<29:05:40,  9.45it/s]

{'loss': Array(0.12260741, dtype=float32), 'loss_cross_entropy': Array(0.11551934, dtype=float32)}
{'loss_inverse': Array(0.00038323, dtype=float32)}


  1%|          | 9740/1000000 [28:25<29:58:24,  9.18it/s]

{'loss': Array(0.13814786, dtype=float32), 'loss_cross_entropy': Array(0.13109127, dtype=float32)}
{'loss_inverse': Array(0.00090218, dtype=float32)}


  1%|          | 9749/1000000 [28:27<36:17:23,  7.58it/s]

{'loss': Array(0.13281746, dtype=float32), 'loss_cross_entropy': Array(0.12546025, dtype=float32)}
{'loss_inverse': Array(0.00109861, dtype=float32)}


  1%|          | 9759/1000000 [28:28<29:24:02,  9.36it/s]

{'loss': Array(0.12274901, dtype=float32), 'loss_cross_entropy': Array(0.1153629, dtype=float32)}
{'loss_inverse': Array(0.00085046, dtype=float32)}


  1%|          | 9769/1000000 [28:30<29:51:42,  9.21it/s]

{'loss': Array(0.14409354, dtype=float32), 'loss_cross_entropy': Array(0.13660106, dtype=float32)}
{'loss_inverse': Array(0.00068322, dtype=float32)}


  1%|          | 9779/1000000 [28:31<28:42:49,  9.58it/s]

{'loss': Array(0.13325527, dtype=float32), 'loss_cross_entropy': Array(0.12637019, dtype=float32)}
{'loss_inverse': Array(0.00042296, dtype=float32)}


  1%|          | 9789/1000000 [28:33<38:46:25,  7.09it/s]

{'loss': Array(0.14679573, dtype=float32), 'loss_cross_entropy': Array(0.13882445, dtype=float32)}
{'loss_inverse': Array(0.00033765, dtype=float32)}


  1%|          | 9799/1000000 [28:35<30:53:30,  8.90it/s]

{'loss': Array(0.12738861, dtype=float32), 'loss_cross_entropy': Array(0.12066348, dtype=float32)}
{'loss_inverse': Array(0.0002313, dtype=float32)}


  1%|          | 9809/1000000 [28:36<28:43:48,  9.57it/s]

{'loss': Array(0.14741509, dtype=float32), 'loss_cross_entropy': Array(0.14015956, dtype=float32)}
{'loss_inverse': Array(0.00040359, dtype=float32)}


  1%|          | 9819/1000000 [28:38<28:32:52,  9.63it/s]

{'loss': Array(0.14031081, dtype=float32), 'loss_cross_entropy': Array(0.13278167, dtype=float32)}
{'loss_inverse': Array(0.0001868, dtype=float32)}


  1%|          | 9829/1000000 [28:40<36:59:19,  7.44it/s]

{'loss': Array(0.11645993, dtype=float32), 'loss_cross_entropy': Array(0.1098175, dtype=float32)}
{'loss_inverse': Array(0.00017254, dtype=float32)}


  1%|          | 9839/1000000 [28:41<29:39:41,  9.27it/s]

{'loss': Array(0.11697095, dtype=float32), 'loss_cross_entropy': Array(0.11034937, dtype=float32)}
{'loss_inverse': Array(0.00025173, dtype=float32)}


  1%|          | 9849/1000000 [28:43<29:04:45,  9.46it/s]

{'loss': Array(0.13023, dtype=float32), 'loss_cross_entropy': Array(0.12323856, dtype=float32)}
{'loss_inverse': Array(0.00019267, dtype=float32)}


  1%|          | 9859/1000000 [28:45<51:54:52,  5.30it/s]

{'loss': Array(0.13576092, dtype=float32), 'loss_cross_entropy': Array(0.12825422, dtype=float32)}
{'loss_inverse': Array(0.00044803, dtype=float32)}


  1%|          | 9869/1000000 [28:46<32:26:43,  8.48it/s]

{'loss': Array(0.12308248, dtype=float32), 'loss_cross_entropy': Array(0.1164154, dtype=float32)}
{'loss_inverse': Array(0.00052239, dtype=float32)}


  1%|          | 9879/1000000 [28:48<28:22:57,  9.69it/s]

{'loss': Array(0.14383392, dtype=float32), 'loss_cross_entropy': Array(0.13609104, dtype=float32)}
{'loss_inverse': Array(0.00037098, dtype=float32)}


  1%|          | 9889/1000000 [28:49<28:28:01,  9.66it/s]

{'loss': Array(0.13167869, dtype=float32), 'loss_cross_entropy': Array(0.1247499, dtype=float32)}
{'loss_inverse': Array(0.0003707, dtype=float32)}


  1%|          | 9899/1000000 [28:51<36:01:04,  7.64it/s]

{'loss': Array(0.12970975, dtype=float32), 'loss_cross_entropy': Array(0.12305675, dtype=float32)}
{'loss_inverse': Array(0.00057423, dtype=float32)}


  1%|          | 9909/1000000 [28:53<28:59:39,  9.49it/s]

{'loss': Array(0.12982486, dtype=float32), 'loss_cross_entropy': Array(0.12264603, dtype=float32)}
{'loss_inverse': Array(0.00055591, dtype=float32)}


  1%|          | 9919/1000000 [28:54<27:53:46,  9.86it/s]

{'loss': Array(0.13171844, dtype=float32), 'loss_cross_entropy': Array(0.12416253, dtype=float32)}
{'loss_inverse': Array(0.00032695, dtype=float32)}


  1%|          | 9929/1000000 [28:56<38:31:11,  7.14it/s]

{'loss': Array(0.13198642, dtype=float32), 'loss_cross_entropy': Array(0.12528098, dtype=float32)}
{'loss_inverse': Array(0.00018705, dtype=float32)}


  1%|          | 9939/1000000 [28:58<29:41:16,  9.26it/s]

{'loss': Array(0.15191369, dtype=float32), 'loss_cross_entropy': Array(0.14443718, dtype=float32)}
{'loss_inverse': Array(0.0001547, dtype=float32)}


  1%|          | 9949/1000000 [28:59<28:40:51,  9.59it/s]

{'loss': Array(0.12461019, dtype=float32), 'loss_cross_entropy': Array(0.11827912, dtype=float32)}
{'loss_inverse': Array(0.00016951, dtype=float32)}


  1%|          | 9959/1000000 [29:01<50:02:30,  5.50it/s]

{'loss': Array(0.12833405, dtype=float32), 'loss_cross_entropy': Array(0.12123799, dtype=float32)}
{'loss_inverse': Array(0.00022161, dtype=float32)}


  1%|          | 9969/1000000 [29:03<32:09:55,  8.55it/s]

{'loss': Array(0.15876783, dtype=float32), 'loss_cross_entropy': Array(0.15048747, dtype=float32)}
{'loss_inverse': Array(0.00025812, dtype=float32)}


  1%|          | 9979/1000000 [29:04<28:46:11,  9.56it/s]

{'loss': Array(0.10543399, dtype=float32), 'loss_cross_entropy': Array(0.09947944, dtype=float32)}
{'loss_inverse': Array(0.00024999, dtype=float32)}


  1%|          | 9989/1000000 [29:06<28:43:45,  9.57it/s]

{'loss': Array(0.13817596, dtype=float32), 'loss_cross_entropy': Array(0.13111615, dtype=float32)}
{'loss_inverse': Array(0.00017009, dtype=float32)}


  1%|          | 9999/1000000 [29:08<48:32:58,  5.66it/s]

{'loss': Array(0.14718263, dtype=float32), 'loss_cross_entropy': Array(0.13925178, dtype=float32)}
{'loss_inverse': Array(0.00015365, dtype=float32)}


  1%|          | 10009/1000000 [29:16<100:36:43,  2.73it/s]

{'loss': Array(0.14417438, dtype=float32), 'loss_cross_entropy': Array(0.13647294, dtype=float32)}
{'loss_inverse': Array(0.00012998, dtype=float32)}


  1%|          | 10019/1000000 [29:18<40:39:44,  6.76it/s] 

{'loss': Array(0.12633567, dtype=float32), 'loss_cross_entropy': Array(0.11922842, dtype=float32)}
{'loss_inverse': Array(7.855202e-05, dtype=float32)}


  1%|          | 10029/1000000 [29:19<30:47:16,  8.93it/s]

{'loss': Array(0.13331895, dtype=float32), 'loss_cross_entropy': Array(0.1266906, dtype=float32)}
{'loss_inverse': Array(0.00021795, dtype=float32)}


  1%|          | 10039/1000000 [29:21<43:33:22,  6.31it/s]

{'loss': Array(0.13522933, dtype=float32), 'loss_cross_entropy': Array(0.12793358, dtype=float32)}
{'loss_inverse': Array(9.724439e-05, dtype=float32)}


  1%|          | 10049/1000000 [29:23<31:43:18,  8.67it/s]

{'loss': Array(0.15956815, dtype=float32), 'loss_cross_entropy': Array(0.15161335, dtype=float32)}
{'loss_inverse': Array(0.00010805, dtype=float32)}


  1%|          | 10059/1000000 [29:24<28:12:59,  9.75it/s]

{'loss': Array(0.13382553, dtype=float32), 'loss_cross_entropy': Array(0.12664747, dtype=float32)}
{'loss_inverse': Array(5.9892314e-05, dtype=float32)}


  1%|          | 10069/1000000 [29:26<28:21:15,  9.70it/s]

{'loss': Array(0.13369924, dtype=float32), 'loss_cross_entropy': Array(0.12645112, dtype=float32)}
{'loss_inverse': Array(5.3911826e-05, dtype=float32)}


  1%|          | 10079/1000000 [29:28<32:43:41,  8.40it/s]

{'loss': Array(0.13519, dtype=float32), 'loss_cross_entropy': Array(0.12792401, dtype=float32)}
{'loss_inverse': Array(5.952251e-05, dtype=float32)}


  1%|          | 10089/1000000 [29:29<28:50:48,  9.53it/s]

{'loss': Array(0.12838206, dtype=float32), 'loss_cross_entropy': Array(0.12139749, dtype=float32)}
{'loss_inverse': Array(7.9464975e-05, dtype=float32)}


  1%|          | 10099/1000000 [29:31<29:44:47,  9.24it/s]

{'loss': Array(0.12489146, dtype=float32), 'loss_cross_entropy': Array(0.11795413, dtype=float32)}
{'loss_inverse': Array(0.00013678, dtype=float32)}


  1%|          | 10109/1000000 [29:33<39:04:56,  7.04it/s]

{'loss': Array(0.13299847, dtype=float32), 'loss_cross_entropy': Array(0.12601845, dtype=float32)}
{'loss_inverse': Array(0.00010096, dtype=float32)}


  1%|          | 10119/1000000 [29:34<29:24:02,  9.35it/s]

{'loss': Array(0.12181465, dtype=float32), 'loss_cross_entropy': Array(0.11494537, dtype=float32)}
{'loss_inverse': Array(0.00013335, dtype=float32)}


  1%|          | 10129/1000000 [29:36<28:24:53,  9.68it/s]

{'loss': Array(0.13234012, dtype=float32), 'loss_cross_entropy': Array(0.12514828, dtype=float32)}
{'loss_inverse': Array(0.00017288, dtype=float32)}


  1%|          | 10139/1000000 [29:37<28:01:59,  9.81it/s]

{'loss': Array(0.12376263, dtype=float32), 'loss_cross_entropy': Array(0.1173998, dtype=float32)}
{'loss_inverse': Array(0.00019903, dtype=float32)}


  1%|          | 10149/1000000 [29:39<34:30:24,  7.97it/s]

{'loss': Array(0.13061152, dtype=float32), 'loss_cross_entropy': Array(0.12394928, dtype=float32)}
{'loss_inverse': Array(0.00042016, dtype=float32)}


  1%|          | 10159/1000000 [29:41<28:58:22,  9.49it/s]

{'loss': Array(0.11944778, dtype=float32), 'loss_cross_entropy': Array(0.11313342, dtype=float32)}
{'loss_inverse': Array(0.00032009, dtype=float32)}


  1%|          | 10169/1000000 [29:42<28:19:34,  9.71it/s]

{'loss': Array(0.12682347, dtype=float32), 'loss_cross_entropy': Array(0.12023901, dtype=float32)}
{'loss_inverse': Array(0.00059988, dtype=float32)}


  1%|          | 10179/1000000 [29:43<27:21:30, 10.05it/s]

{'loss': Array(0.12222105, dtype=float32), 'loss_cross_entropy': Array(0.11579802, dtype=float32)}
{'loss_inverse': Array(0.0003143, dtype=float32)}


  1%|          | 10189/1000000 [29:45<38:04:24,  7.22it/s]

{'loss': Array(0.14159022, dtype=float32), 'loss_cross_entropy': Array(0.13368027, dtype=float32)}
{'loss_inverse': Array(0.00052733, dtype=float32)}


  1%|          | 10199/1000000 [29:47<29:05:03,  9.45it/s]

{'loss': Array(0.12152483, dtype=float32), 'loss_cross_entropy': Array(0.11510421, dtype=float32)}
{'loss_inverse': Array(0.00050268, dtype=float32)}


  1%|          | 10209/1000000 [29:48<27:47:37,  9.89it/s]

{'loss': Array(0.13603692, dtype=float32), 'loss_cross_entropy': Array(0.12837544, dtype=float32)}
{'loss_inverse': Array(0.00026689, dtype=float32)}


  1%|          | 10219/1000000 [29:50<27:45:00,  9.91it/s]

{'loss': Array(0.13191313, dtype=float32), 'loss_cross_entropy': Array(0.12495637, dtype=float32)}
{'loss_inverse': Array(0.00049087, dtype=float32)}


  1%|          | 10229/1000000 [29:52<43:03:40,  6.38it/s]

{'loss': Array(0.12711947, dtype=float32), 'loss_cross_entropy': Array(0.12026685, dtype=float32)}
{'loss_inverse': Array(0.00026192, dtype=float32)}


  1%|          | 10239/1000000 [29:53<30:01:00,  9.16it/s]

{'loss': Array(0.15462978, dtype=float32), 'loss_cross_entropy': Array(0.14662842, dtype=float32)}
{'loss_inverse': Array(0.00039529, dtype=float32)}


  1%|          | 10249/1000000 [29:55<28:48:30,  9.54it/s]

{'loss': Array(0.14139101, dtype=float32), 'loss_cross_entropy': Array(0.13398409, dtype=float32)}
{'loss_inverse': Array(0.00032017, dtype=float32)}


  1%|          | 10259/1000000 [29:56<27:55:07,  9.85it/s]

{'loss': Array(0.14080034, dtype=float32), 'loss_cross_entropy': Array(0.13308394, dtype=float32)}
{'loss_inverse': Array(0.00051883, dtype=float32)}


  1%|          | 10269/1000000 [29:58<48:27:20,  5.67it/s]

{'loss': Array(0.13480611, dtype=float32), 'loss_cross_entropy': Array(0.12775493, dtype=float32)}
{'loss_inverse': Array(0.0004106, dtype=float32)}


  1%|          | 10279/1000000 [30:00<30:45:00,  8.94it/s]

{'loss': Array(0.13334003, dtype=float32), 'loss_cross_entropy': Array(0.12681665, dtype=float32)}
{'loss_inverse': Array(0.00035837, dtype=float32)}


  1%|          | 10289/1000000 [30:01<28:52:01,  9.52it/s]

{'loss': Array(0.13284348, dtype=float32), 'loss_cross_entropy': Array(0.12515984, dtype=float32)}
{'loss_inverse': Array(0.00033979, dtype=float32)}


  1%|          | 10299/1000000 [30:03<28:18:55,  9.71it/s]

{'loss': Array(0.13848297, dtype=float32), 'loss_cross_entropy': Array(0.13094215, dtype=float32)}
{'loss_inverse': Array(0.00034239, dtype=float32)}


  1%|          | 10309/1000000 [30:05<42:33:33,  6.46it/s]

{'loss': Array(0.13336116, dtype=float32), 'loss_cross_entropy': Array(0.12611286, dtype=float32)}
{'loss_inverse': Array(0.00022826, dtype=float32)}


  1%|          | 10319/1000000 [30:06<30:29:48,  9.01it/s]

{'loss': Array(0.14271152, dtype=float32), 'loss_cross_entropy': Array(0.13557313, dtype=float32)}
{'loss_inverse': Array(0.00020619, dtype=float32)}


  1%|          | 10329/1000000 [30:08<29:02:28,  9.47it/s]

{'loss': Array(0.14159973, dtype=float32), 'loss_cross_entropy': Array(0.13435248, dtype=float32)}
{'loss_inverse': Array(0.00019821, dtype=float32)}


  1%|          | 10339/1000000 [30:09<27:48:19,  9.89it/s]

{'loss': Array(0.11119344, dtype=float32), 'loss_cross_entropy': Array(0.10478523, dtype=float32)}
{'loss_inverse': Array(0.00014964, dtype=float32)}


  1%|          | 10349/1000000 [30:11<35:22:57,  7.77it/s]

{'loss': Array(0.12284308, dtype=float32), 'loss_cross_entropy': Array(0.11616743, dtype=float32)}
{'loss_inverse': Array(0.00024764, dtype=float32)}


  1%|          | 10359/1000000 [30:13<29:54:23,  9.19it/s]

{'loss': Array(0.13023989, dtype=float32), 'loss_cross_entropy': Array(0.12350845, dtype=float32)}
{'loss_inverse': Array(0.00015739, dtype=float32)}


  1%|          | 10369/1000000 [30:14<27:56:09,  9.84it/s]

{'loss': Array(0.13875698, dtype=float32), 'loss_cross_entropy': Array(0.13113461, dtype=float32)}
{'loss_inverse': Array(0.0002687, dtype=float32)}


  1%|          | 10379/1000000 [30:16<38:10:52,  7.20it/s]

{'loss': Array(0.13769393, dtype=float32), 'loss_cross_entropy': Array(0.13041379, dtype=float32)}
{'loss_inverse': Array(0.00054677, dtype=float32)}


  1%|          | 10389/1000000 [30:18<30:41:18,  8.96it/s]

{'loss': Array(0.11771911, dtype=float32), 'loss_cross_entropy': Array(0.11121408, dtype=float32)}
{'loss_inverse': Array(0.00044371, dtype=float32)}


  1%|          | 10399/1000000 [30:19<27:53:57,  9.85it/s]

{'loss': Array(0.13545793, dtype=float32), 'loss_cross_entropy': Array(0.12815456, dtype=float32)}
{'loss_inverse': Array(0.00019284, dtype=float32)}


  1%|          | 10409/1000000 [30:21<48:52:32,  5.62it/s]

{'loss': Array(0.13628918, dtype=float32), 'loss_cross_entropy': Array(0.12902783, dtype=float32)}
{'loss_inverse': Array(0.00043182, dtype=float32)}


  1%|          | 10419/1000000 [30:22<31:09:05,  8.82it/s]

{'loss': Array(0.11942279, dtype=float32), 'loss_cross_entropy': Array(0.11255584, dtype=float32)}
{'loss_inverse': Array(0.00020754, dtype=float32)}


  1%|          | 10429/1000000 [30:24<28:29:15,  9.65it/s]

{'loss': Array(0.14367925, dtype=float32), 'loss_cross_entropy': Array(0.1363894, dtype=float32)}
{'loss_inverse': Array(0.00062464, dtype=float32)}


  1%|          | 10439/1000000 [30:25<28:53:54,  9.51it/s]

{'loss': Array(0.13247178, dtype=float32), 'loss_cross_entropy': Array(0.12494697, dtype=float32)}
{'loss_inverse': Array(0.00027374, dtype=float32)}


  1%|          | 10449/1000000 [30:27<35:34:32,  7.73it/s]

{'loss': Array(0.1292984, dtype=float32), 'loss_cross_entropy': Array(0.12252569, dtype=float32)}
{'loss_inverse': Array(0.00027782, dtype=float32)}


  1%|          | 10459/1000000 [30:29<29:09:30,  9.43it/s]

{'loss': Array(0.13692115, dtype=float32), 'loss_cross_entropy': Array(0.12981224, dtype=float32)}
{'loss_inverse': Array(0.00041306, dtype=float32)}


  1%|          | 10469/1000000 [30:30<28:40:24,  9.59it/s]

{'loss': Array(0.12786488, dtype=float32), 'loss_cross_entropy': Array(0.12063923, dtype=float32)}
{'loss_inverse': Array(0.00053419, dtype=float32)}


  1%|          | 10479/1000000 [30:32<27:30:44,  9.99it/s]

{'loss': Array(0.14207016, dtype=float32), 'loss_cross_entropy': Array(0.13467221, dtype=float32)}
{'loss_inverse': Array(0.00039049, dtype=float32)}


  1%|          | 10489/1000000 [30:34<32:55:55,  8.35it/s]

{'loss': Array(0.14787, dtype=float32), 'loss_cross_entropy': Array(0.13969587, dtype=float32)}
{'loss_inverse': Array(0.0001264, dtype=float32)}


  1%|          | 10499/1000000 [30:35<29:14:19,  9.40it/s]

{'loss': Array(0.13089924, dtype=float32), 'loss_cross_entropy': Array(0.12381418, dtype=float32)}
{'loss_inverse': Array(0.00025605, dtype=float32)}


  1%|          | 10509/1000000 [30:44<94:50:06,  2.90it/s] 

{'loss': Array(0.14657374, dtype=float32), 'loss_cross_entropy': Array(0.13885301, dtype=float32)}
{'loss_inverse': Array(0.00025755, dtype=float32)}


  1%|          | 10519/1000000 [30:45<40:52:10,  6.73it/s] 

{'loss': Array(0.11730893, dtype=float32), 'loss_cross_entropy': Array(0.11084118, dtype=float32)}
{'loss_inverse': Array(0.00017626, dtype=float32)}


  1%|          | 10529/1000000 [30:47<39:58:51,  6.87it/s]

{'loss': Array(0.11188455, dtype=float32), 'loss_cross_entropy': Array(0.10592204, dtype=float32)}
{'loss_inverse': Array(0.00051945, dtype=float32)}


  1%|          | 10539/1000000 [30:48<30:18:52,  9.07it/s]

{'loss': Array(0.13781457, dtype=float32), 'loss_cross_entropy': Array(0.13055915, dtype=float32)}
{'loss_inverse': Array(0.00040026, dtype=float32)}


  1%|          | 10549/1000000 [30:50<29:38:51,  9.27it/s]

{'loss': Array(0.11771777, dtype=float32), 'loss_cross_entropy': Array(0.11133015, dtype=float32)}
{'loss_inverse': Array(0.00044442, dtype=float32)}


  1%|          | 10559/1000000 [30:51<28:02:13,  9.80it/s]

{'loss': Array(0.12059183, dtype=float32), 'loss_cross_entropy': Array(0.11375206, dtype=float32)}
{'loss_inverse': Array(0.00016539, dtype=float32)}


  1%|          | 10569/1000000 [30:53<32:10:53,  8.54it/s]

{'loss': Array(0.12467771, dtype=float32), 'loss_cross_entropy': Array(0.11781635, dtype=float32)}
{'loss_inverse': Array(0.00017802, dtype=float32)}


  1%|          | 10579/1000000 [30:55<30:58:38,  8.87it/s]

{'loss': Array(0.12467936, dtype=float32), 'loss_cross_entropy': Array(0.11824808, dtype=float32)}
{'loss_inverse': Array(0.00042783, dtype=float32)}


  1%|          | 10589/1000000 [30:56<27:55:57,  9.84it/s]

{'loss': Array(0.13885927, dtype=float32), 'loss_cross_entropy': Array(0.13188162, dtype=float32)}
{'loss_inverse': Array(0.00041423, dtype=float32)}


  1%|          | 10599/1000000 [30:58<42:15:01,  6.50it/s]

{'loss': Array(0.14164202, dtype=float32), 'loss_cross_entropy': Array(0.13425387, dtype=float32)}
{'loss_inverse': Array(0.00039986, dtype=float32)}


  1%|          | 10609/1000000 [31:00<31:06:51,  8.83it/s]

{'loss': Array(0.13988061, dtype=float32), 'loss_cross_entropy': Array(0.1327806, dtype=float32)}
{'loss_inverse': Array(0.00073461, dtype=float32)}


  1%|          | 10619/1000000 [31:01<28:51:06,  9.53it/s]

{'loss': Array(0.11826213, dtype=float32), 'loss_cross_entropy': Array(0.11185862, dtype=float32)}
{'loss_inverse': Array(0.00050771, dtype=float32)}


  1%|          | 10629/1000000 [31:03<43:42:25,  6.29it/s]

{'loss': Array(0.14476672, dtype=float32), 'loss_cross_entropy': Array(0.13708654, dtype=float32)}
{'loss_inverse': Array(0.00052062, dtype=float32)}


  1%|          | 10639/1000000 [31:05<31:35:10,  8.70it/s]

{'loss': Array(0.12315823, dtype=float32), 'loss_cross_entropy': Array(0.11663129, dtype=float32)}
{'loss_inverse': Array(0.00056954, dtype=float32)}


  1%|          | 10649/1000000 [31:06<28:51:39,  9.52it/s]

{'loss': Array(0.13576369, dtype=float32), 'loss_cross_entropy': Array(0.12783025, dtype=float32)}
{'loss_inverse': Array(0.00024624, dtype=float32)}


  1%|          | 10659/1000000 [31:08<27:48:46,  9.88it/s]

{'loss': Array(0.11197738, dtype=float32), 'loss_cross_entropy': Array(0.10541135, dtype=float32)}
{'loss_inverse': Array(0.00031491, dtype=float32)}


  1%|          | 10669/1000000 [31:10<49:13:01,  5.58it/s]

{'loss': Array(0.14904808, dtype=float32), 'loss_cross_entropy': Array(0.1415523, dtype=float32)}
{'loss_inverse': Array(0.00045728, dtype=float32)}


  1%|          | 10679/1000000 [31:11<32:17:14,  8.51it/s]

{'loss': Array(0.13465454, dtype=float32), 'loss_cross_entropy': Array(0.1279648, dtype=float32)}
{'loss_inverse': Array(0.00044388, dtype=float32)}


  1%|          | 10689/1000000 [31:13<28:43:51,  9.56it/s]

{'loss': Array(0.13480614, dtype=float32), 'loss_cross_entropy': Array(0.12809223, dtype=float32)}
{'loss_inverse': Array(0.00035568, dtype=float32)}


  1%|          | 10699/1000000 [31:14<27:36:16,  9.96it/s]

{'loss': Array(0.1314037, dtype=float32), 'loss_cross_entropy': Array(0.12406474, dtype=float32)}
{'loss_inverse': Array(0.00054063, dtype=float32)}


  1%|          | 10709/1000000 [31:16<42:22:45,  6.48it/s]

{'loss': Array(0.14717868, dtype=float32), 'loss_cross_entropy': Array(0.14004368, dtype=float32)}
{'loss_inverse': Array(0.00028363, dtype=float32)}


  1%|          | 10719/1000000 [31:18<31:04:25,  8.84it/s]

{'loss': Array(0.14535032, dtype=float32), 'loss_cross_entropy': Array(0.13782075, dtype=float32)}
{'loss_inverse': Array(0.00013949, dtype=float32)}


  1%|          | 10729/1000000 [31:19<28:11:26,  9.75it/s]

{'loss': Array(0.14117122, dtype=float32), 'loss_cross_entropy': Array(0.13371049, dtype=float32)}
{'loss_inverse': Array(0.00069431, dtype=float32)}


  1%|          | 10739/1000000 [31:21<27:53:57,  9.85it/s]

{'loss': Array(0.1459701, dtype=float32), 'loss_cross_entropy': Array(0.13830756, dtype=float32)}
{'loss_inverse': Array(0.00081331, dtype=float32)}


  1%|          | 10749/1000000 [31:23<34:33:30,  7.95it/s]

{'loss': Array(0.13694392, dtype=float32), 'loss_cross_entropy': Array(0.12900607, dtype=float32)}
{'loss_inverse': Array(0.00039243, dtype=float32)}


  1%|          | 10759/1000000 [31:24<28:20:06,  9.70it/s]

{'loss': Array(0.12819846, dtype=float32), 'loss_cross_entropy': Array(0.12157436, dtype=float32)}
{'loss_inverse': Array(0.00025745, dtype=float32)}


  1%|          | 10769/1000000 [31:25<28:29:01,  9.65it/s]

{'loss': Array(0.1264943, dtype=float32), 'loss_cross_entropy': Array(0.11928421, dtype=float32)}
{'loss_inverse': Array(0.00033243, dtype=float32)}


  1%|          | 10779/1000000 [31:27<42:22:02,  6.49it/s]

{'loss': Array(0.12002188, dtype=float32), 'loss_cross_entropy': Array(0.11354631, dtype=float32)}
{'loss_inverse': Array(0.00021193, dtype=float32)}


  1%|          | 10789/1000000 [31:29<29:41:22,  9.26it/s]

{'loss': Array(0.15054013, dtype=float32), 'loss_cross_entropy': Array(0.14258575, dtype=float32)}
{'loss_inverse': Array(0.00022259, dtype=float32)}


  1%|          | 10799/1000000 [31:30<28:55:00,  9.50it/s]

{'loss': Array(0.14574002, dtype=float32), 'loss_cross_entropy': Array(0.1380627, dtype=float32)}
{'loss_inverse': Array(0.00044028, dtype=float32)}


  1%|          | 10809/1000000 [31:32<29:00:59,  9.47it/s]

{'loss': Array(0.15154305, dtype=float32), 'loss_cross_entropy': Array(0.14369446, dtype=float32)}
{'loss_inverse': Array(0.00033925, dtype=float32)}


  1%|          | 10819/1000000 [31:34<37:58:30,  7.24it/s]

{'loss': Array(0.13355087, dtype=float32), 'loss_cross_entropy': Array(0.12676369, dtype=float32)}
{'loss_inverse': Array(0.00043122, dtype=float32)}


  1%|          | 10829/1000000 [31:35<30:09:09,  9.11it/s]

{'loss': Array(0.12758456, dtype=float32), 'loss_cross_entropy': Array(0.12032545, dtype=float32)}
{'loss_inverse': Array(0.00029058, dtype=float32)}


  1%|          | 10839/1000000 [31:37<28:20:32,  9.69it/s]

{'loss': Array(0.13325553, dtype=float32), 'loss_cross_entropy': Array(0.12634663, dtype=float32)}
{'loss_inverse': Array(0.00042983, dtype=float32)}


  1%|          | 10849/1000000 [31:38<26:43:22, 10.28it/s]

{'loss': Array(0.15693012, dtype=float32), 'loss_cross_entropy': Array(0.14954875, dtype=float32)}
{'loss_inverse': Array(0.00028636, dtype=float32)}


  1%|          | 10859/1000000 [31:40<38:35:52,  7.12it/s]

{'loss': Array(0.13560684, dtype=float32), 'loss_cross_entropy': Array(0.12848364, dtype=float32)}
{'loss_inverse': Array(0.00024137, dtype=float32)}


  1%|          | 10869/1000000 [31:42<29:35:15,  9.29it/s]

{'loss': Array(0.13619575, dtype=float32), 'loss_cross_entropy': Array(0.12852739, dtype=float32)}
{'loss_inverse': Array(0.00029053, dtype=float32)}


  1%|          | 10879/1000000 [31:43<27:56:25,  9.83it/s]

{'loss': Array(0.14965235, dtype=float32), 'loss_cross_entropy': Array(0.1417554, dtype=float32)}
{'loss_inverse': Array(0.00027557, dtype=float32)}


  1%|          | 10889/1000000 [31:45<28:28:22,  9.65it/s]

{'loss': Array(0.1365401, dtype=float32), 'loss_cross_entropy': Array(0.12959002, dtype=float32)}
{'loss_inverse': Array(0.00013976, dtype=float32)}


  1%|          | 10899/1000000 [31:47<33:42:45,  8.15it/s]

{'loss': Array(0.12245095, dtype=float32), 'loss_cross_entropy': Array(0.11585655, dtype=float32)}
{'loss_inverse': Array(0.00016392, dtype=float32)}


  1%|          | 10909/1000000 [31:48<28:37:27,  9.60it/s]

{'loss': Array(0.13535151, dtype=float32), 'loss_cross_entropy': Array(0.12769482, dtype=float32)}
{'loss_inverse': Array(0.00021001, dtype=float32)}


  1%|          | 10919/1000000 [31:50<28:17:06,  9.71it/s]

{'loss': Array(0.14924923, dtype=float32), 'loss_cross_entropy': Array(0.14184533, dtype=float32)}
{'loss_inverse': Array(0.00035573, dtype=float32)}


  1%|          | 10929/1000000 [31:52<38:21:48,  7.16it/s]

{'loss': Array(0.16198096, dtype=float32), 'loss_cross_entropy': Array(0.15406026, dtype=float32)}
{'loss_inverse': Array(0.00064819, dtype=float32)}


  1%|          | 10939/1000000 [31:53<29:25:47,  9.34it/s]

{'loss': Array(0.12981729, dtype=float32), 'loss_cross_entropy': Array(0.12239969, dtype=float32)}
{'loss_inverse': Array(0.00034715, dtype=float32)}


  1%|          | 10949/1000000 [31:54<27:57:29,  9.83it/s]

{'loss': Array(0.13986753, dtype=float32), 'loss_cross_entropy': Array(0.13231269, dtype=float32)}
{'loss_inverse': Array(0.00026222, dtype=float32)}


  1%|          | 10959/1000000 [31:56<42:52:18,  6.41it/s]

{'loss': Array(0.14989887, dtype=float32), 'loss_cross_entropy': Array(0.14151044, dtype=float32)}
{'loss_inverse': Array(0.0001666, dtype=float32)}


  1%|          | 10969/1000000 [31:58<30:28:13,  9.02it/s]

{'loss': Array(0.14170656, dtype=float32), 'loss_cross_entropy': Array(0.13354824, dtype=float32)}
{'loss_inverse': Array(0.00049743, dtype=float32)}


  1%|          | 10979/1000000 [31:59<28:08:12,  9.76it/s]

{'loss': Array(0.13049623, dtype=float32), 'loss_cross_entropy': Array(0.12322431, dtype=float32)}
{'loss_inverse': Array(0.00042463, dtype=float32)}


  1%|          | 10989/1000000 [32:01<28:22:43,  9.68it/s]

{'loss': Array(0.14447978, dtype=float32), 'loss_cross_entropy': Array(0.13680662, dtype=float32)}
{'loss_inverse': Array(0.00030849, dtype=float32)}


  1%|          | 10999/1000000 [32:03<38:24:43,  7.15it/s]

{'loss': Array(0.15165384, dtype=float32), 'loss_cross_entropy': Array(0.14374983, dtype=float32)}
{'loss_inverse': Array(0.00025699, dtype=float32)}


  1%|          | 11009/1000000 [32:11<95:36:01,  2.87it/s] 

{'loss': Array(0.1166917, dtype=float32), 'loss_cross_entropy': Array(0.1103872, dtype=float32)}
{'loss_inverse': Array(0.00011856, dtype=float32)}


  1%|          | 11019/1000000 [32:13<39:32:27,  6.95it/s] 

{'loss': Array(0.13682176, dtype=float32), 'loss_cross_entropy': Array(0.12969688, dtype=float32)}
{'loss_inverse': Array(0.00031301, dtype=float32)}


  1%|          | 11029/1000000 [32:14<29:29:49,  9.31it/s]

{'loss': Array(0.12274163, dtype=float32), 'loss_cross_entropy': Array(0.11575514, dtype=float32)}
{'loss_inverse': Array(0.00037947, dtype=float32)}


  1%|          | 11039/1000000 [32:15<28:42:25,  9.57it/s]

{'loss': Array(0.1251534, dtype=float32), 'loss_cross_entropy': Array(0.11829948, dtype=float32)}
{'loss_inverse': Array(0.00042456, dtype=float32)}


  1%|          | 11049/1000000 [32:17<32:50:44,  8.36it/s]

{'loss': Array(0.13253103, dtype=float32), 'loss_cross_entropy': Array(0.12530065, dtype=float32)}
{'loss_inverse': Array(0.00026453, dtype=float32)}


  1%|          | 11059/1000000 [32:19<28:05:41,  9.78it/s]

{'loss': Array(0.1319333, dtype=float32), 'loss_cross_entropy': Array(0.12499159, dtype=float32)}
{'loss_inverse': Array(0.00030086, dtype=float32)}


  1%|          | 11069/1000000 [32:20<28:00:59,  9.81it/s]

{'loss': Array(0.13675334, dtype=float32), 'loss_cross_entropy': Array(0.12930588, dtype=float32)}
{'loss_inverse': Array(0.00025041, dtype=float32)}


  1%|          | 11079/1000000 [32:22<27:13:33, 10.09it/s]

{'loss': Array(0.13776879, dtype=float32), 'loss_cross_entropy': Array(0.13050304, dtype=float32)}
{'loss_inverse': Array(0.00036971, dtype=float32)}


  1%|          | 11089/1000000 [32:24<32:31:30,  8.45it/s]

{'loss': Array(0.12445356, dtype=float32), 'loss_cross_entropy': Array(0.11812712, dtype=float32)}
{'loss_inverse': Array(0.00035625, dtype=float32)}


  1%|          | 11099/1000000 [32:25<29:01:54,  9.46it/s]

{'loss': Array(0.1409292, dtype=float32), 'loss_cross_entropy': Array(0.13314003, dtype=float32)}
{'loss_inverse': Array(0.00013497, dtype=float32)}


  1%|          | 11109/1000000 [32:27<27:47:29,  9.88it/s]

{'loss': Array(0.13580707, dtype=float32), 'loss_cross_entropy': Array(0.12867963, dtype=float32)}
{'loss_inverse': Array(7.962567e-05, dtype=float32)}


  1%|          | 11119/1000000 [32:29<38:40:11,  7.10it/s]

{'loss': Array(0.1437497, dtype=float32), 'loss_cross_entropy': Array(0.1362148, dtype=float32)}
{'loss_inverse': Array(0.00016617, dtype=float32)}


  1%|          | 11129/1000000 [32:30<31:10:04,  8.81it/s]

{'loss': Array(0.14382051, dtype=float32), 'loss_cross_entropy': Array(0.13616571, dtype=float32)}
{'loss_inverse': Array(0.0001819, dtype=float32)}


  1%|          | 11139/1000000 [32:32<28:46:51,  9.54it/s]

{'loss': Array(0.14039128, dtype=float32), 'loss_cross_entropy': Array(0.13299352, dtype=float32)}
{'loss_inverse': Array(0.00025728, dtype=float32)}


  1%|          | 11149/1000000 [32:34<43:27:08,  6.32it/s]

{'loss': Array(0.14090246, dtype=float32), 'loss_cross_entropy': Array(0.13333425, dtype=float32)}
{'loss_inverse': Array(0.00025409, dtype=float32)}


  1%|          | 11159/1000000 [32:35<31:30:47,  8.72it/s]

{'loss': Array(0.12586805, dtype=float32), 'loss_cross_entropy': Array(0.11881917, dtype=float32)}
{'loss_inverse': Array(0.00037224, dtype=float32)}


  1%|          | 11169/1000000 [32:37<28:44:35,  9.56it/s]

{'loss': Array(0.14123002, dtype=float32), 'loss_cross_entropy': Array(0.13421296, dtype=float32)}
{'loss_inverse': Array(0.00030791, dtype=float32)}


  1%|          | 11179/1000000 [32:38<28:00:30,  9.81it/s]

{'loss': Array(0.11954238, dtype=float32), 'loss_cross_entropy': Array(0.11331654, dtype=float32)}
{'loss_inverse': Array(0.00038125, dtype=float32)}


  1%|          | 11189/1000000 [32:40<33:30:30,  8.20it/s]

{'loss': Array(0.14517315, dtype=float32), 'loss_cross_entropy': Array(0.13748518, dtype=float32)}
{'loss_inverse': Array(0.00041477, dtype=float32)}


  1%|          | 11199/1000000 [32:42<28:27:34,  9.65it/s]

{'loss': Array(0.12735383, dtype=float32), 'loss_cross_entropy': Array(0.12041821, dtype=float32)}
{'loss_inverse': Array(0.00052713, dtype=float32)}


  1%|          | 11209/1000000 [32:43<28:27:32,  9.65it/s]

{'loss': Array(0.14919794, dtype=float32), 'loss_cross_entropy': Array(0.14074177, dtype=float32)}
{'loss_inverse': Array(0.00035385, dtype=float32)}


  1%|          | 11219/1000000 [32:45<51:01:04,  5.38it/s]

{'loss': Array(0.12408117, dtype=float32), 'loss_cross_entropy': Array(0.11713844, dtype=float32)}
{'loss_inverse': Array(0.00035306, dtype=float32)}


  1%|          | 11229/1000000 [32:47<32:29:18,  8.45it/s]

{'loss': Array(0.14468352, dtype=float32), 'loss_cross_entropy': Array(0.13673733, dtype=float32)}
{'loss_inverse': Array(0.00028, dtype=float32)}


  1%|          | 11239/1000000 [32:48<28:04:04,  9.79it/s]

{'loss': Array(0.13232365, dtype=float32), 'loss_cross_entropy': Array(0.12546796, dtype=float32)}
{'loss_inverse': Array(0.00040923, dtype=float32)}


  1%|          | 11249/1000000 [32:50<28:34:11,  9.61it/s]

{'loss': Array(0.14410053, dtype=float32), 'loss_cross_entropy': Array(0.13656078, dtype=float32)}
{'loss_inverse': Array(0.00012895, dtype=float32)}


  1%|          | 11259/1000000 [32:51<27:54:28,  9.84it/s]

{'loss': Array(0.11685409, dtype=float32), 'loss_cross_entropy': Array(0.10989442, dtype=float32)}
{'loss_inverse': Array(0.00024929, dtype=float32)}


  1%|          | 11269/1000000 [32:53<32:40:18,  8.41it/s]

{'loss': Array(0.13981822, dtype=float32), 'loss_cross_entropy': Array(0.1323336, dtype=float32)}
{'loss_inverse': Array(0.00025824, dtype=float32)}


  1%|          | 11279/1000000 [32:55<29:57:06,  9.17it/s]

{'loss': Array(0.12327302, dtype=float32), 'loss_cross_entropy': Array(0.11669427, dtype=float32)}
{'loss_inverse': Array(0.00023421, dtype=float32)}


  1%|          | 11289/1000000 [32:56<28:10:30,  9.75it/s]

{'loss': Array(0.13695326, dtype=float32), 'loss_cross_entropy': Array(0.1296871, dtype=float32)}
{'loss_inverse': Array(0.00048394, dtype=float32)}


  1%|          | 11299/1000000 [32:58<27:49:25,  9.87it/s]

{'loss': Array(0.13885657, dtype=float32), 'loss_cross_entropy': Array(0.1319447, dtype=float32)}
{'loss_inverse': Array(0.0001921, dtype=float32)}


  1%|          | 11301/1000000 [32:59<79:05:09,  3.47it/s]

In [37]:

def sampling_model(key, model, sample_eval, nb_step=100, config=None):
    """
    Function used to sampling a state from a list 
    """
    seq_len_future = config.seq_len - config.seq_len // 4 
    noise_future  = jax.random.dirichlet(key, jnp.ones(6) * 5., (config.batch_size, seq_len_future, 54))
    sample_eval["reward"] = jnp.linspace(start=-0.5, stop=0.5, num=config.batch_size)[:, None]

    for t_step in range(nb_step):
        t_step_array = jnp.ones((config.batch_size, 1, 1, 1)) * float(t_step / nb_step)
        sample_eval["context"] = jnp.concatenate([sample_eval["reward"], t_step_array[:, :, 0, 0]], axis=1)

        estimation_logits_past, estimation_logits_future = model(
            sample_eval["state_past"], noise_future, sample_eval["context"]
        )

        estimation_proba_future = jax.nn.softmax(estimation_logits_future, axis=-1)

        noise_future = noise_future + float(1. / nb_step) * 1./ (1. - t_step_array + 0.0001) * (estimation_proba_future - noise_future)

    return noise_future



In [18]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer_eval, buffer_list_eval = dataset.fast_gathering_data_diffusion(
    env,
    vmap_reset,
    vmap_step,
    int(128),
    config.len_seq,
    buffer_eval,
    buffer_list_eval,
    subkey,
)

sample = buffer_eval.sample(buffer_list_eval, subkey)
sample = reshape_diffusion_setup(sample, subkey)

In [19]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

sample = buffer.sample(buffer_list, subkey)
sample = reshape_diffusion_setup(sample, subkey)


result = sampling_model(key=config.jax_key, model=transformer, sample_eval=sample, nb_batch_explore=128, nb_step=100)
result

Array([[[[9.12160613e-06, 1.71264401e-05, 1.48795079e-05,
          9.99910355e-01, 2.99175736e-05, 1.85837271e-05],
         [1.93071319e-05, 9.99914408e-01, 2.00294890e-05,
          1.98439229e-05, 9.08272341e-06, 1.74135203e-05],
         [1.85810495e-05, 6.38415804e-06, 9.99915481e-01,
          2.45338306e-05, 1.44244405e-05, 2.06008554e-05],
         ...,
         [1.11915870e-05, 2.57499050e-05, 9.99916375e-01,
          1.56825408e-05, 1.51600689e-05, 1.58642652e-05],
         [9.99907255e-01, 1.58234034e-05, 1.93407759e-05,
          3.18640377e-05, 1.44680962e-05, 1.12638809e-05],
         [1.67636899e-05, 6.71518501e-06, 9.99912083e-01,
          2.26497650e-05, 2.02378724e-05, 2.14804895e-05]],

        [[1.03620114e-05, 1.93770975e-05, 9.99914050e-01,
          1.39374752e-05, 2.23403331e-05, 1.99424103e-05],
         [1.14582945e-05, 1.58930197e-05, 2.36062333e-05,
          9.99919713e-01, 1.25739025e-05, 1.67943072e-05],
         [7.86513556e-06, 2.15659384e-05, 2.7151

In [24]:
index_batch  = 64

jnp.argmax(sample["state_past"], axis=-1).reshape((128, 8, 6, 3, 3))[index_batch, -1, :, :, :]

Array([[[2, 0, 1],
        [0, 0, 5],
        [2, 4, 5]],

       [[5, 1, 4],
        [2, 1, 2],
        [1, 3, 3]],

       [[3, 4, 4],
        [1, 2, 1],
        [5, 5, 1]],

       [[5, 4, 3],
        [5, 3, 0],
        [0, 1, 0]],

       [[0, 3, 1],
        [2, 4, 3],
        [4, 4, 0]],

       [[4, 5, 2],
        [3, 5, 2],
        [3, 0, 2]]], dtype=int32)

In [25]:
jnp.argmax(result, axis=-1).reshape((128, 24, 6, 3, 3))[index_batch, 0, :, :, :]

Array([[[2, 0, 1],
        [0, 0, 5],
        [3, 1, 5]],

       [[4, 2, 3],
        [1, 1, 3],
        [5, 2, 1]],

       [[2, 4, 4],
        [5, 2, 1],
        [4, 5, 1]],

       [[5, 4, 3],
        [5, 3, 0],
        [0, 1, 0]],

       [[0, 3, 5],
        [2, 4, 4],
        [4, 4, 2]],

       [[1, 3, 0],
        [3, 5, 2],
        [3, 0, 2]]], dtype=int32)

In [22]:
jnp.argmax(result, axis=-1).reshape((128, 24, 6, 3, 3))[index_batch, 1, :, :, :]

Array([[[2, 5, 0],
        [5, 0, 3],
        [4, 2, 5]],

       [[3, 0, 2],
        [4, 1, 5],
        [4, 3, 4]],

       [[1, 0, 1],
        [3, 2, 1],
        [5, 2, 1]],

       [[2, 1, 5],
        [0, 3, 1],
        [4, 1, 0]],

       [[3, 4, 0],
        [4, 4, 0],
        [4, 3, 5]],

       [[1, 4, 3],
        [2, 5, 0],
        [3, 2, 0]]], dtype=int32)

In [None]:
sample = buffer_eval.sample(buffer_list_eval, subkey)
sample = reshape_sample(sample)

TrajectoryBufferSample(experience={'action': Array([[[1.32556781e-01, 7.96739519e-01, 5.36718592e-02, ...,
         3.91646661e-03, 4.48901858e-03, 9.91594553e-01],
        [3.49070907e-01, 4.57749265e-04, 4.38157976e-01, ...,
         7.23136306e-01, 1.23497941e-01, 1.53365776e-01],
        [6.12441264e-03, 2.50436477e-02, 1.35732419e-03, ...,
         3.82237613e-01, 5.98694921e-01, 1.90675538e-02],
        ...,
        [1.41329234e-04, 2.44877161e-03, 8.43136787e-01, ...,
         2.33344346e-01, 6.42170012e-01, 1.24485560e-01],
        [6.32655225e-04, 1.77795421e-02, 9.65278149e-01, ...,
         1.25269741e-02, 3.21629345e-01, 6.65843725e-01],
        [9.08881542e-04, 1.04175135e-01, 7.50824576e-04, ...,
         9.99683421e-03, 7.89827347e-01, 2.00175866e-01]],

       [[2.03237548e-01, 7.00179100e-01, 3.63819454e-05, ...,
         9.96583939e-01, 2.39940570e-03, 1.01662707e-03],
        [7.63220847e-01, 1.11325733e-01, 3.15520242e-02, ...,
         5.45369804e-01, 4.54322606e-0

In [None]:
# now we should (easily) train an inverse model







AttributeError: 'RubikTransformer' object has no attribute 'state_mapping'

In [14]:
# save buffer, buffer_list
state_weight = nnx.state(transformer)

In [15]:
state_weight

State({
  'action_mapping': {
    'bias': VariableState(
      type=Param,
      value=Array([ 4.04955857e-02,  2.82790326e-02, -7.85927773e-02,  7.52160996e-02,
             -6.11112965e-03,  6.96982583e-03, -1.17343664e-02, -1.74523471e-03,
              1.28632234e-02, -7.83682019e-02,  2.75444221e-02, -4.00350802e-02,
              1.79233290e-02,  8.38570073e-02,  2.03401130e-02,  4.92124483e-02,
              8.69528428e-02,  2.20998153e-02, -3.42875794e-02, -6.76687211e-02,
              2.17811018e-02,  8.36544111e-02, -3.08539756e-02, -8.56901798e-03,
             -6.66830465e-02,  1.15918748e-01,  5.94779989e-03,  1.72799546e-02,
             -1.16014622e-01, -6.75882176e-02, -6.16184436e-02, -5.52975051e-02,
              4.17982265e-02, -5.43787293e-02,  1.19193546e-01, -9.40112211e-03,
             -4.03175130e-02, -3.47817354e-02, -6.77642366e-03,  8.85512680e-02,
              4.83243428e-02,  6.59283325e-02, -5.58541063e-03,  3.46533172e-02,
             -6.63223118e-02

In [31]:
def generate_past_state_with_with_random_policy(key, vmap_reset, step_jit_env, config):
    """
    Generate past state with random policy

    Args:
        config: configuration object

    Returns:
        state_past: (batch_size, len_seq//4, 6, 3, 3)

    """

    key1, key2 = jax.random.split(config.jax_key)

    keys = jax.random.split(key1, config.batch_size)
    state, timestep = vmap_reset(keys)

    last_state = None
    past_state = []

    for i in range(config.len_seq // 4):

        # apply random policy and retrieve state
        action = jax.random.randint(
            key=config.jax_key,
            minval=env.action_spec.minimum,
            maxval=env.action_spec.maximum,
            shape=(config.batch_size, 3),
        )

        state, timestep  = step_jit_env(state, action)
        past_state.append(state.cube)

    # concat all the past state to get the shape (batch_size, len_seq//4, 6, 3, 3) from a list of state of size (batch_size, 6, 3, 3) by creating the 1 axis
    state_past = jnp.stack(past_state, axis=1)

    return state_past, state

step_jit_env = jax.vmap(jit_step)

state_past, state = generate_past_state_with_with_random_policy(key, vmap_reset, step_jit_env, config)

In [32]:

def apply_decision_diffuser_policy(key, state_past, decision_diffuser, inverse_rl_model, config):
    """
    1. Make a estimation of the targeted reward
    2. Generate futur state with those targeted reward
    3. Choose policy from that
    """
    sample_eval = {
        "state_past": jax.nn.one_hot(state_past, 6),
    }

    state_future = sampling_model(key, decision_diffuser, sample_eval, nb_step=100, config=config)

    # state_future is (batch_size, seq_len, dim_input_state / 6, 6)
    state_to_act = jnp.concatenate([state_past, state_future], axis=1)
    state_to_act_futur_t = state_to_act[:, (config.seq_len // 4 - 1):(-1), :, :]
    state_to_act_futur_td1 = state_to_act[:, (config.seq_len // 4):, :, :]

    # flatten the last 2 axis
    state_to_act_futur_t = state_to_act_futur_t.reshape(
        (state_to_act_futur_t.shape[0], state_to_act_futur_t.shape[1], -1)
    )

    state_to_act_futur_td1 = state_to_act_futur_td1.reshape(
        (state_to_act_futur_td1.shape[0], state_to_act_futur_td1.shape[1], -1)
    )

    # now use reverse RL to compute the action TODO later
    actions = inverse_rl_model(state_to_act_futur_t, state_to_act_futur_td1)

    return actions



(128, 9, 6, 3, 3)

In [None]:

def gather_data_with_policy(state, state_past, actions, buffer, buffer_list, config):
    """
    For loop with those policy and state

    log performance compare to target

    """


    for i in range(config.len_seq - config.len_seq // 4):
        actions_step = actions[:, i, :]
        actions_0 = jnp.argmax(actions_step[:, :6], axis=1)
        actions_1 = jnp.argmax(actions_step[:, 6:], axis=1)

        actions_full = jnp.concatenate([actions_0, jnp.zeros(config.batch_size, 1.), actions_1], axis=1)
    
        # step 
        state, timestep  = step_jit_env(state, actions_full)

    # TODO SAVE DATA into batch format for later training



def improve_training_loop():
    """
    Relaunch the training loop with those new data incorporated into the buffer
    """
    pass


# now we want to modify the code that generate proper data according to policy
def policy_generation_calibration(decision_diffuser, inverse_model, buffer, buffer_list, key, nb_path_value, nb_futur, config):
    """
    Full stuff here

    Online transformer setup

    1. We generate env setup 
    2. First random action in the different env
    3. Use decision_diffuser to choose the action to do from here
    4. Observe / apply policy  to retrieve data
    5. Add the data into the buffer
    6. Train model on those data

    Remember to log the performance data to compare with other run / algorithms

    """
    pass
