In [1]:
# Use the project checkout as the working directory so that the relative
# paths used later (e.g. the pickle checkpoints) resolve against it.
# NOTE(review): hard-coded absolute path; it only exists on the original machine.
import os
os.chdir("/teamspace/studios/this_studio/rubikscubesolver")

In [2]:
from tqdm import tqdm
import pickle

import wandb  # for logging
import time
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.nnx as nnx

import optax

from rubiktransformer.model_diffusion_dt import RubikDTTransformer, InverseRLModel
import rubiktransformer.dataset as dataset
from rubiktransformer.trainer import reshape_sample

cuda_plugin_extension is not found.


In [3]:
@dataclass
class Config:
    """Hyper-parameters and shared random state used by every cell of the notebook."""

    # Carry PRNG key: later cells split it and write the new carry back to `config.jax_key`.
    jax_key: jnp.ndarray = jax.random.PRNGKey(49)
    # No type annotation, so this is a plain class attribute (created once, shared by
    # all Config instances) and not a dataclass field / constructor argument.
    rngs = nnx.Rngs(48)
    batch_size: int = 128  # passed as sample_batch_size when the replay buffers are created
    lr_1: float = 4e-4  # AdamW learning rate of the diffusion transformer optimizer
    lr_2: float = 4e-4  # AdamW learning rate of the inverse model optimizer
    nb_games: int = 128 * 100  # passed (sometimes scaled) to dataset.fast_gathering_data_diffusion
    len_seq: int = 32  # number of steps in one stored trajectory
    nb_step: int = 1000000  # iterations of the main training loop
    log_every_step: int = 10  # period of the training-metric logging
    log_eval_every_step: int = 10  # period of the evaluation pass
    log_policy_reward_every_step: int = 10  # not referenced in the visible cells
    add_data_every_step: int = 500  # period at which new games are added to the training buffer

    save_model_every_step: int = 2000  # period at which both models are pickled to disk


# single configuration instance shared by all the following cells
config = Config()

# Weights & Biases run setup: entity, project and a timestamped run name
user = "forbu14"
project = "RubikTransformer"
display_name = "experiment_" + time.strftime("%Y%m%d-%H%M%S")

# start the run that the later wandb.log calls report to
wandb.init(entity=user, project=project, name=display_name)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mforbu14[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# --- models -----------------------------------------------------------------
# Both networks take their parameter-initialisation randomness from config.rngs.
transformer = RubikDTTransformer(causal=True, rngs=config.rngs)

inverse_rl_model = InverseRLModel(
    rngs=config.rngs,
    nb_layers=3,
    dim_middle=1024,
    dim_output_action=6 + 3,
    dim_input_state=6 * 6 * 3 * 3,
)

# --- optimizers -------------------------------------------------------------
# Multiplier that grows linearly from 0 to 1 over the first 4000 updates.
scheduler = optax.linear_schedule(init_value=0.0, end_value=1.0, transition_steps=4000)

# Diffusion transformer: gradient clipping -> AdamW -> scheduled scaling of the update.
optimizer_dd = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(config.lr_1),
    optax.scale_by_schedule(scheduler),
)

# Inverse model: gradient clipping -> AdamW (no schedule).
optimizer_rl_inverse = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(config.lr_2),
)

optimizer_diffuser = nnx.Optimizer(transformer, optimizer_dd)
optimizer_inverse = nnx.Optimizer(inverse_rl_model, optimizer_rl_inverse)

# --- metrics ----------------------------------------------------------------
# Each entry is a running average fed through the keyword of the same name.
metrics_train = nnx.MultiMetric(
    **{name: nnx.metrics.Average(name) for name in ("loss", "loss_cross_entropy")}
)

metrics_eval = nnx.MultiMetric(
    **{name: nnx.metrics.Average(name) for name in ("loss_eval", "loss_cross_entropy_eval")}
)

# metric for the inverse model
metrics_inverse = nnx.MultiMetric(
    **{name: nnx.metrics.Average(name) for name in ("loss_inverse",)}
)

In [5]:
def _restore_weights(model, checkpoint_path):
    """Load a pickled nnx state from `checkpoint_path` and copy it into `model`.

    Args:
        model: the nnx module whose variables are overwritten in place.
        checkpoint_path: path of a pickle file holding a state accepted by `nnx.update`.
    """
    # SECURITY: pickle.load can execute arbitrary code stored in the file;
    # only load checkpoints that come from a trusted source.
    with open(checkpoint_path, "rb") as input_file:
        restored_state = pickle.load(input_file)

    nnx.update(model, restored_state)


# pre-trained weights of the diffusion transformer
_restore_weights(transformer, "state_ddt_model_improved_v0.pickle")

# pre-trained weights of the inverse model
_restore_weights(inverse_rl_model, "state_inverse_rl_model_improved_v0.pickle")

In [6]:
# Environment plus two replay buffers (training / evaluation), both created with
# the same sample batch size. `env` ends up bound to the environment returned by
# the second call.
env, buffer = dataset.init_env_buffer(sample_batch_size=config.batch_size)
env, buffer_eval = dataset.init_env_buffer(sample_batch_size=config.batch_size)

nb_games, len_seq = config.nb_games, config.len_seq

# Zero-filled templates describing one stored item: cube states as int8,
# actions as int32, everything else in the default float dtype.
state_first = jnp.zeros((6, 3, 3), dtype=jnp.int8)
state_next = jnp.zeros((len_seq, 6, 3, 3), dtype=jnp.int8)
action = jnp.zeros((len_seq, 3), dtype=jnp.int32)
action_proba = jnp.zeros((len_seq, 9))
reward = jnp.zeros(1)

# compiled single-environment step, reused by the rollout helpers
jit_step = jax.jit(env.step)

# Both buffers are initialised from the same item layout.
experience_template = {
    "action": action,
    "reward": reward,
    "state_histo": state_next,
}

buffer_list = buffer.init(dict(experience_template))
buffer_list_eval = buffer_eval.init(dict(experience_template))

In [7]:
def step_fn(state, key):
    """
    Advance the environment by one step with a uniformly random action.

    The (carry, x) -> (carry, y) signature matches what `jax.lax.scan`
    expects: `state` is the carry and `key` the per-step PRNG key.

    Args:
        state: current environment state.
        key: PRNG key used to draw the action.

    Returns:
        (new_state, timestep) as produced by the jitted env step; the drawn
        action is stored in `timestep.extras["action"]`.
    """
    # jax.random.randint samples in [minval, maxval): the upper bound is
    # exclusive. `action_spec.maximum` is treated as the largest *valid* action,
    # hence the +1 so that this value can be drawn as well.
    # NOTE(review): assumes `maximum` is an inclusive bound (usual convention for
    # bounded array specs) - confirm against the env built in dataset.init_env_buffer.
    action = jax.random.randint(
        key=key,
        minval=env.action_spec.minimum,
        maxval=env.action_spec.maximum + 1,
        shape=(3,),
    )

    new_state, timestep = jit_step(state, action)
    # keep the action next to the observation so the rollout can be stored later
    timestep.extras["action"] = action

    return new_state, timestep


def run_n_steps(state, key, n):
    """Roll the environment forward `n` random steps from `state`.

    `key` is split into `n` per-step keys that are scanned over with
    `step_fn`; only the stacked per-step outputs are returned, the final
    environment state is discarded.
    """
    per_step_keys = jax.random.split(key, n)
    _final_state, trajectory = jax.lax.scan(step_fn, state, per_step_keys)
    return trajectory


# batched reset: the jitted env.reset mapped over the leading axis of its argument
vmap_reset = jax.vmap(jax.jit(env.reset))
# batched random rollout: state and key are mapped over axis 0, the step count `n` is shared
vmap_step = jax.vmap(run_n_steps, in_axes=(0, 0, None))

In [8]:
# inspect the structure of the diffusion transformer
nnx.display(transformer)

In [9]:
# advance the carry key and keep `subkey` for this cell's random work
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

# first fill of the training buffer, with a tenth of config.nb_games
buffer, buffer_list = dataset.fast_gathering_data_diffusion(
    env,
    vmap_reset,
    vmap_step,
    int(config.nb_games / 10),
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

In [10]:
# draw one batch from the training buffer
# NOTE(review): reuses the `subkey` already passed to the data gathering call of the previous cell
sample = buffer.sample(buffer_list, subkey)

def reshape_diffusion_setup(sample, key=jax.random.PRNGKey(0)):
    """
    Turn a replay-buffer sample into the training batch of both models.

    The caller's `sample.experience` is left untouched: the function works on
    a shallow copy, so the same sample can safely be processed more than once.

    Args:
        sample: object exposing an `experience` mapping with the entries
            "state_histo" (batch, len_seq, 6, 3, 3), "action"
            (batch, len_seq, 3) and "reward" (batch, 1).
        key: PRNG key; it is split so that the time-step draw and the noise
            draw use independent randomness.

    Returns:
        dict holding the original entries (with "state_histo" flattened to 54
        stickers and one-hot encoded over 6 classes) plus:
          - "time_step": uniform values in [0, 1), shape (batch, 1, 1, 1)
          - "context": reward and time step side by side, shape (batch, 2)
          - "state_past" / "state_future": first quarter / remaining steps
          - "state_future_noise": interpolation between Dirichlet noise and
            "state_future", weighted by the time step
          - "action_inverse": one-hot targets (6 + 3 columns), one row per transition
          - "state_histo_inverse_t" / "state_histo_inverse_td1": flattened
            states before / after each transition
    """
    # shallow copy: new entries are added to the copy, not to the caller's dict
    batch = dict(sample.experience)

    # independent keys for the two random draws (a JAX key must not be reused)
    key_time, key_noise = jax.random.split(key)

    # (batch, len_seq, 6, 3, 3) -> (batch, len_seq, 54) -> one-hot (batch, len_seq, 54, 6)
    state_histo = batch["state_histo"]
    state_histo = state_histo.reshape((state_histo.shape[0], state_histo.shape[1], 54))
    state_histo = jax.nn.one_hot(state_histo, num_classes=6, axis=-1)
    batch["state_histo"] = state_histo

    batch_size, len_seq = state_histo.shape[0], state_histo.shape[1]

    # one interpolation time per trajectory, random value between 0 and 1
    time_step = jax.random.uniform(key_time, (batch_size, 1, 1, 1))
    batch["time_step"] = time_step

    # conditioning vector of the rectified flow setup: (reward, time step)
    batch["context"] = jnp.concatenate([batch["reward"], time_step[:, :, 0, 0]], axis=1)

    # the first quarter of the trajectory is the observed past, the rest is generated
    batch["state_past"] = state_histo[:, : len_seq // 4, :, :]
    batch["state_future"] = state_histo[:, len_seq // 4 :, :, :]

    # noise on the probability simplex, one distribution per sticker and step
    simplex_noise = jax.random.dirichlet(
        key_noise, jnp.ones(6), batch["state_future"].shape[:-1]
    )

    # time_step = 0 gives pure noise, time_step = 1 gives the clean future
    batch["state_future_noise"] = (
        (1 - time_step) * simplex_noise + time_step * batch["state_future"]
    )

    # ---- inverse model data: (state_t, state_t+1) -> action -----------------
    # the action at index t + 1 is paired with the transition from state t to state t + 1
    action_inverse = batch["action"][:, 1:, :]

    # merge batch and time axes so that each row is one transition
    action_inverse = jnp.reshape(
        action_inverse, (action_inverse.shape[0] * action_inverse.shape[1], -1)
    )

    # Only components 0 and 2 of the action become targets (6 and 3 classes);
    # component 1 is dropped. Presumably these are face, depth and direction - confirm.
    action_inverse_0 = jax.nn.one_hot(action_inverse[:, 0], num_classes=6, axis=-1)
    action_inverse_1 = jax.nn.one_hot(action_inverse[:, 2], num_classes=3, axis=-1)
    batch["action_inverse"] = jnp.concatenate([action_inverse_0, action_inverse_1], axis=1)

    # states before / after each transition, flattened to one row per transition
    state_before = state_histo[:, :-1, :, :]
    state_after = state_histo[:, 1:, :, :]

    batch["state_histo_inverse_t"] = jnp.reshape(
        state_before, (state_before.shape[0] * state_before.shape[1], -1)
    )
    batch["state_histo_inverse_td1"] = jnp.reshape(
        state_after, (state_after.shape[0] * state_after.shape[1], -1)
    )

    return batch


# build a first training batch (no key is passed, so the function's default PRNG key is used)
sample = reshape_diffusion_setup(sample)


In [11]:
def loss_fn_transformer_rf(model: RubikDTTransformer, batch):
    """Rectified-flow loss of the diffusion transformer.

    The model receives the past states, the noised future states and the
    context; its second output is scored against the clean future states with
    a softmax cross-entropy averaged over axes 1 and 2 of the result.

    Returns:
        (weighted_loss, cross_entropy): the mean of the per-sample
        cross-entropy multiplied by clip(1 / (1 - t), 0.005, 1.5), and the
        plain mean cross-entropy (auxiliary value, used for logging).
    """
    # only the logits of the future states are scored; the first output is ignored
    _, logits_future = model(
        batch["state_past"], batch["state_future_noise"], batch["context"]
    )

    per_sample_ce = optax.softmax_cross_entropy(
        logits=logits_future, labels=batch["state_future"]
    ).mean(axis=[1, 2])

    # one interpolation time per sample
    flow_time = batch["time_step"][:, 0, 0, 0]
    weight = jnp.clip(1. / (1. - flow_time), min=0.005, max=1.5)

    weighted_loss = (per_sample_ce * weight).mean()

    return weighted_loss, (per_sample_ce.mean())


@nnx.jit
def train_step_transformer_rf(
    model: RubikDTTransformer,
    optimizer: nnx.Optimizer,
    metrics: nnx.MultiMetric,
    batch,
):
    """Run one optimisation step of the diffusion transformer.

    Computes the loss and its gradients on `batch`, records both loss values
    in `metrics`, then hands the gradients to `optimizer`. Returns nothing.
    """
    value_and_grad = nnx.value_and_grad(loss_fn_transformer_rf, has_aux=True)
    (weighted_loss, cross_entropy), gradients = value_and_grad(model, batch)

    metrics.update(loss=weighted_loss, loss_cross_entropy=cross_entropy)
    optimizer.update(gradients)

################# INVERSE RL ####################
def loss_fn_inverse_rl(model: InverseRLModel, batch):
    """Loss of the inverse model, which predicts the action between two states.

    The model output is read as two groups of logits: columns 0-5 and columns
    6 onwards. Each group is scored with a softmax cross-entropy against the
    matching columns of `batch["action_inverse"]`; the two mean losses are summed.
    """
    logits = model(
        batch["state_histo_inverse_t"], batch["state_histo_inverse_td1"]
    )
    targets = batch["action_inverse"]

    first_head = slice(None, 6)
    second_head = slice(6, None)

    loss_first_head = optax.softmax_cross_entropy(
        logits=logits[:, first_head], labels=targets[:, first_head]
    ).mean()

    loss_second_head = optax.softmax_cross_entropy(
        logits=logits[:, second_head], labels=targets[:, second_head]
    ).mean()

    return loss_first_head + loss_second_head


@nnx.jit
def train_step_inverse_rl(
    model: InverseRLModel,
    optimizer: nnx.Optimizer,
    metrics: nnx.MultiMetric,
    batch,
):
    """Run one optimisation step of the inverse model.

    Computes the loss and its gradients on `batch`, records the loss in
    `metrics` under `loss_inverse`, then hands the gradients to `optimizer`.
    Returns nothing.
    """
    value_and_grad = nnx.value_and_grad(loss_fn_inverse_rl)
    inverse_loss, gradients = value_and_grad(model, batch)

    metrics.update(loss_inverse=inverse_loss)
    optimizer.update(gradients)

In [12]:
# advance the carry key and keep `subkey` for the gathering below
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

# main fill of the training buffer before the training loop starts
buffer, buffer_list = dataset.fast_gathering_data_diffusion(
    env,
    vmap_reset,
    vmap_step,
    config.nb_games * 1, # old is int(config.nb_games * 10.0),
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

In [None]:


# Joint training of the diffusion transformer and of the inverse model.
for idx_step in tqdm(range(config.nb_step)):
    # Advance the carry key. Every random consumer below gets its own sub-key:
    # a JAX key must not be used twice, and the carry itself must never be
    # handed to a consumer because it is split again at the next iteration.
    key, subkey, gather_key = jax.random.split(config.jax_key, 3)
    config.jax_key = key
    sample_key, noise_key = jax.random.split(subkey)

    # periodically add fresh random games to the training buffer
    if idx_step % config.add_data_every_step == 0:
        buffer, buffer_list = dataset.fast_gathering_data_diffusion(
            env,
            vmap_reset,
            vmap_step,
            int(config.nb_games // 10),
            config.len_seq,
            buffer,
            buffer_list,
            gather_key,
        )

    sample = buffer.sample(buffer_list, sample_key)
    sample = reshape_diffusion_setup(sample, noise_key)

    # one optimisation step of the diffusion transformer
    train_step_transformer_rf(
        transformer, optimizer_diffuser, metrics_train, sample
    )

    # one optimisation step of the inverse model, on the same batch
    train_step_inverse_rl(
        inverse_rl_model, optimizer_inverse, metrics_inverse, sample
    )

    # training metrics: averages accumulated since the previous report
    if idx_step % config.log_every_step == 0:
        metrics_train_result = metrics_train.compute()
        print(metrics_train_result)

        wandb.log(metrics_train_result, step=idx_step)
        metrics_train.reset()

        metrics_inverse_result = metrics_inverse.compute()
        print(metrics_inverse_result)

        wandb.log(metrics_inverse_result, step=idx_step)
        metrics_inverse.reset()

    # evaluation of the transformer loss on freshly generated games
    if idx_step % config.log_eval_every_step == 0:
        key, subkey = jax.random.split(config.jax_key)
        config.jax_key = key
        eval_gather_key, eval_sample_key, eval_noise_key = jax.random.split(subkey, 3)

        buffer_eval, buffer_list_eval = dataset.fast_gathering_data_diffusion(
            env,
            vmap_reset,
            vmap_step,
            int(config.batch_size),
            config.len_seq,
            buffer_eval,
            buffer_list_eval,
            eval_gather_key,
        )

        sample = buffer_eval.sample(buffer_list_eval, eval_sample_key)
        sample = reshape_diffusion_setup(sample, eval_noise_key)

        # plain forward pass: no gradient, no parameter update
        loss, (loss_crossentropy) = loss_fn_transformer_rf(
            transformer, sample
        )

        metrics_eval.update(
            loss_eval=loss,
            loss_cross_entropy_eval=loss_crossentropy,
        )
        wandb.log(metrics_eval.compute(), step=idx_step)

        metrics_eval.reset()

    # periodic checkpoint of both models (each save overwrites the previous file)
    if idx_step % config.save_model_every_step == 0:

        state_weight = nnx.state(transformer)

        with open("state_ddt_model_improved.pickle", "wb") as handle:
            pickle.dump(state_weight, handle, protocol=pickle.HIGHEST_PROTOCOL)

        # save inverse model
        state_weight = nnx.state(inverse_rl_model)

        with open("state_inverse_rl_model_improved.pickle", "wb") as handle:
            pickle.dump(state_weight, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [16]:

def sampling_model(key, model, sample_eval, nb_step=100, config=None):
    """
    Generate future states with the diffusion model by iterative denoising.

    Starts from Dirichlet noise and applies `nb_step` update steps; at each
    step the model's second output is turned into probabilities with a
    softmax and the current sample is moved towards them. The model is
    conditioned on target rewards spread linearly from -0.5 to 0.5 over the
    batch (one target per batch row).

    `sample_eval` is only read: nothing is written into the caller's dict.

    Args:
        key: PRNG key used to draw the initial noise.
        model: callable `(state_past, noisy_future, context)` returning a pair
            whose second element holds the logits of the future states.
        sample_eval: mapping with a "state_past" entry.
        nb_step: number of denoising steps.
        config: object exposing `len_seq` and `batch_size` (required).

    Returns:
        array of shape (batch_size, len_seq - len_seq // 4, 54, 6).

    Raises:
        ValueError: if `config` is not provided.
    """
    if config is None:
        raise ValueError(
            "sampling_model requires a `config` providing `len_seq` and `batch_size`"
        )

    seq_len_future = config.len_seq - config.len_seq // 4
    noise_future = jax.random.dirichlet(
        key, jnp.ones(6) * 5., (config.batch_size, seq_len_future, 54)
    )

    state_past = sample_eval["state_past"]
    # one target reward per batch row, shape (batch_size, 1)
    reward_target = jnp.linspace(start=-0.5, stop=0.5, num=config.batch_size)[:, None]

    for t_step in range(nb_step):
        t_step_array = jnp.ones((config.batch_size, 1, 1, 1)) * float(t_step / nb_step)
        context = jnp.concatenate([reward_target, t_step_array[:, :, 0, 0]], axis=1)

        # the first output of the model is not used for sampling
        _, estimation_logits_future = model(state_past, noise_future, context)

        estimation_proba_future = jax.nn.softmax(estimation_logits_future, axis=-1)

        # step of size 1 / nb_step along (prediction - current) / (1 - t);
        # the 0.0001 keeps the division finite when t gets close to 1
        noise_future = noise_future + float(1. / nb_step) * 1. / (1. - t_step_array + 0.0001) * (estimation_proba_future - noise_future)

    return noise_future



In [17]:
# advance the carry key and keep `subkey` for this cell
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

# add one batch worth of fresh games to the evaluation buffer
buffer_eval, buffer_list_eval = dataset.fast_gathering_data_diffusion(
    env,
    vmap_reset,
    vmap_step,
    int(config.batch_size),
    config.len_seq,
    buffer_eval,
    buffer_list_eval,
    subkey,
)

# NOTE(review): the same `subkey` is used for gathering, sampling and the noise draw
sample = buffer_eval.sample(buffer_list_eval, subkey)
sample = reshape_diffusion_setup(sample, subkey)

In [23]:
# advance the carry key and keep `subkey` for this cell
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

# batch taken from the training buffer (not the evaluation one)
sample = buffer.sample(buffer_list, subkey)
sample = reshape_diffusion_setup(sample, subkey)


# generate future states conditioned on the past states of the batch
# NOTE(review): the carry key config.jax_key is passed directly instead of a fresh sub-key
result = sampling_model(key=config.jax_key, model=transformer, sample_eval=sample, config=config, nb_step=100)
result

Array([[[[1.09494431e-05, 1.66741665e-05, 9.99895751e-01,
          2.22197268e-05, 3.50209884e-05, 1.94163295e-05],
         [2.11554579e-05, 2.44809780e-05, 2.29226425e-05,
          9.99903977e-01, 8.82399036e-06, 1.86653342e-05],
         [2.21990049e-05, 8.59028660e-06, 2.98996456e-05,
          9.99905646e-01, 1.01229525e-05, 2.35370826e-05],
         ...,
         [1.43166399e-05, 3.20903491e-05, 9.99898672e-01,
          1.93911837e-05, 1.82657968e-05, 1.72907021e-05],
         [9.99898255e-01, 1.77349430e-05, 2.03370582e-05,
          3.52105126e-05, 1.59306219e-05, 1.25847291e-05],
         [2.01943330e-05, 8.95850826e-06, 9.99900341e-01,
          2.50139274e-05, 2.45582778e-05, 2.09209975e-05]],

        [[1.19837932e-05, 1.96152832e-05, 9.99905944e-01,
          1.66372629e-05, 2.55447812e-05, 2.02928204e-05],
         [1.41649507e-05, 1.69913983e-05, 2.56728381e-05,
          9.99911547e-01, 1.36010349e-05, 1.80726638e-05],
         [1.13089336e-05, 2.37140339e-05, 3.1119

In [24]:
# row of the batch inspected in this cell and in the following ones
index_batch  = 64

# Last past cube state of that row, decoded from one-hot back to colour indices.
# The hard-coded 128 and 8 correspond to config.batch_size and config.len_seq // 4.
jnp.argmax(sample["state_past"], axis=-1).reshape((128, 8, 6, 3, 3))[index_batch, -1, :, :, :]

Array([[[2, 0, 1],
        [0, 0, 5],
        [2, 4, 5]],

       [[5, 1, 4],
        [2, 1, 2],
        [1, 3, 3]],

       [[3, 4, 4],
        [1, 2, 1],
        [5, 5, 1]],

       [[5, 4, 3],
        [5, 3, 0],
        [0, 1, 0]],

       [[0, 3, 1],
        [2, 4, 3],
        [4, 4, 0]],

       [[4, 5, 2],
        [3, 5, 2],
        [3, 0, 2]]], dtype=int32)

In [25]:
# first generated future state of the same row (24 = config.len_seq - config.len_seq // 4)
jnp.argmax(result, axis=-1).reshape((128, 24, 6, 3, 3))[index_batch, 0, :, :, :]

Array([[[0, 0, 1],
        [0, 0, 5],
        [3, 4, 5]],

       [[2, 1, 4],
        [0, 1, 2],
        [2, 3, 3]],

       [[3, 4, 4],
        [1, 2, 1],
        [5, 5, 1]],

       [[5, 4, 3],
        [5, 3, 3],
        [0, 1, 4]],

       [[4, 2, 0],
        [4, 4, 3],
        [0, 3, 1]],

       [[5, 5, 2],
        [2, 5, 2],
        [1, 0, 2]]], dtype=int32)

In [26]:
# second generated future state of the same row, to compare with the previous one
jnp.argmax(result, axis=-1).reshape((128, 24, 6, 3, 3))[index_batch, 1, :, :, :]

Array([[[0, 0, 5],
        [0, 0, 5],
        [3, 4, 1]],

       [[2, 1, 4],
        [0, 1, 2],
        [2, 3, 4]],

       [[5, 4, 3],
        [1, 2, 1],
        [5, 5, 1]],

       [[2, 4, 3],
        [5, 3, 3],
        [0, 1, 4]],

       [[4, 2, 0],
        [4, 4, 3],
        [0, 3, 1]],

       [[5, 5, 3],
        [2, 5, 2],
        [1, 0, 2]]], dtype=int32)

In [None]:
# batch from the evaluation buffer, formatted with reshape_sample from
# rubiktransformer.trainer instead of reshape_diffusion_setup
# NOTE(review): `subkey` is whatever an earlier cell left in the namespace
sample = buffer_eval.sample(buffer_list_eval, subkey)
sample = reshape_sample(sample)

TrajectoryBufferSample(experience={'action': Array([[[1.32556781e-01, 7.96739519e-01, 5.36718592e-02, ...,
         3.91646661e-03, 4.48901858e-03, 9.91594553e-01],
        [3.49070907e-01, 4.57749265e-04, 4.38157976e-01, ...,
         7.23136306e-01, 1.23497941e-01, 1.53365776e-01],
        [6.12441264e-03, 2.50436477e-02, 1.35732419e-03, ...,
         3.82237613e-01, 5.98694921e-01, 1.90675538e-02],
        ...,
        [1.41329234e-04, 2.44877161e-03, 8.43136787e-01, ...,
         2.33344346e-01, 6.42170012e-01, 1.24485560e-01],
        [6.32655225e-04, 1.77795421e-02, 9.65278149e-01, ...,
         1.25269741e-02, 3.21629345e-01, 6.65843725e-01],
        [9.08881542e-04, 1.04175135e-01, 7.50824576e-04, ...,
         9.99683421e-03, 7.89827347e-01, 2.00175866e-01]],

       [[2.03237548e-01, 7.00179100e-01, 3.63819454e-05, ...,
         9.96583939e-01, 2.39940570e-03, 1.01662707e-03],
        [7.63220847e-01, 1.11325733e-01, 3.15520242e-02, ...,
         5.45369804e-01, 4.54322606e-0

In [40]:
def generate_past_state_with_with_random_policy(key, vmap_reset, step_jit_env, config):
    """
    Reset a batch of environments and play `len_seq // 4` random moves in each.

    Args:
        key: PRNG key; it is split into one key for the resets and one for
            the random actions.
        vmap_reset: batched reset, called with `batch_size` keys.
        step_jit_env: batched step, called with (state, action).
        config: configuration object (`batch_size` and `len_seq` are read).

    Returns:
        state_past: cube states observed after each random move,
            (batch_size, len_seq // 4, 6, 3, 3)
        state: environment state after the last random move
        actions_all: the random actions that were played,
            (batch_size, len_seq // 4, 3)
    """
    nb_past_steps = config.len_seq // 4

    # independent randomness for the initial cubes and for the actions
    key_reset, key_actions = jax.random.split(key)

    reset_keys = jax.random.split(key_reset, config.batch_size)
    state, timestep = vmap_reset(reset_keys)

    # jax.random.randint excludes `maxval`; `action_spec.maximum` is treated as
    # the largest valid action, hence the +1.
    # NOTE(review): assumes `maximum` is an inclusive bound - confirm against the env spec.
    actions_all = jax.random.randint(
        key=key_actions,
        minval=env.action_spec.minimum,
        maxval=env.action_spec.maximum + 1,
        shape=(config.batch_size, nb_past_steps, 3),
    )

    past_state = []

    for i in range(nb_past_steps):
        # apply the random policy and record the resulting cube
        state, timestep = step_jit_env(state, actions_all[:, i, :])
        past_state.append(state.cube)

    # list of (batch_size, 6, 3, 3) -> (batch_size, len_seq // 4, 6, 3, 3), time on axis 1
    state_past = jnp.stack(past_state, axis=1)

    return state_past, state, actions_all

# batched version of the jitted env.step
step_jit_env = jax.vmap(jit_step)

# NOTE(review): `key` is whatever an earlier cell left in the global namespace
state_past, state, actions_past = generate_past_state_with_with_random_policy(key, vmap_reset, step_jit_env, config)

In [31]:
# sanity check: expected (batch_size, len_seq // 4, 6, 3, 3)
state_past.shape

(128, 8, 6, 3, 3)

In [41]:

def apply_decision_diffuser_policy(key, state_past, decision_diffuser, inverse_rl_model, config):
    """
    Derive the next actions from the states imagined by the decision diffuser.

    1. Encode the observed past states the same way reshape_diffusion_setup
       does for training: 54 stickers per cube, one-hot over 6 colours.
    2. Generate the future states with `sampling_model` (the target rewards
       are chosen inside that function).
    3. Feed every pair of consecutive states, from the last past state
       onwards, to the inverse model to get the action between them.

    Args:
        key: PRNG key forwarded to `sampling_model`.
        state_past: integer cube states, (batch, len_seq // 4, 6, 3, 3).
        decision_diffuser: model used to generate the future states.
        inverse_rl_model: callable `(state_t, state_t+1)` returning action logits.
        config: configuration object forwarded to `sampling_model`; `len_seq` is read here.

    Returns:
        output of `inverse_rl_model`, one set of action logits per future step.
    """
    # (batch, past_len, 6, 3, 3) -> (batch, past_len, 54) -> one-hot (batch, past_len, 54, 6).
    # The diffuser must receive this flattened layout, which is the one used in
    # training, and not a one-hot of the un-flattened (6, 3, 3) cube.
    state_past = state_past.reshape((state_past.shape[0], state_past.shape[1], -1))
    state_past = jax.nn.one_hot(state_past, num_classes=6)

    sample_eval = {
        "state_past": state_past,
    }

    state_future = sampling_model(key, decision_diffuser, sample_eval, nb_step=100, config=config)

    # state_future is (batch_size, seq_len, dim_input_state / 6, 6)
    state_to_act = jnp.concatenate([state_past, state_future], axis=1)

    # consecutive pairs: from the last past state up to the second to last
    # state, and the same window shifted by one step
    state_to_act_futur_t = state_to_act[:, (config.len_seq // 4 - 1):(-1), :, :]
    state_to_act_futur_td1 = state_to_act[:, (config.len_seq // 4):, :, :]

    # flatten the last 2 axis
    state_to_act_futur_t = state_to_act_futur_t.reshape(
        (state_to_act_futur_t.shape[0], state_to_act_futur_t.shape[1], -1)
    )

    state_to_act_futur_td1 = state_to_act_futur_td1.reshape(
        (state_to_act_futur_td1.shape[0], state_to_act_futur_td1.shape[1], -1)
    )

    # the inverse model turns each pair of states into action logits
    actions = inverse_rl_model(state_to_act_futur_t, state_to_act_futur_td1)

    return actions

# action logits for the future steps
# NOTE(review): the carry key config.jax_key is passed directly instead of a fresh sub-key
actions_futur = apply_decision_diffuser_policy(config.jax_key, state_past, transformer, inverse_rl_model, config)

(128, 24, 9)

In [56]:

from rubiktransformer.dataset import GOAL_OBSERVATION

def gather_data_with_policy(state, state_past, actions_past, actions_futur, buffer, buffer_list, config):
    """
    Play the actions proposed by the policy and store the full trajectories.

    The logits in `actions_futur` are decoded with an argmax (columns 0-5 give
    the first action component, columns 6 onwards the third one, the second
    component is always 0), the environments are stepped with them, and each
    trajectory (random past + played future) is added to the replay buffer
    together with a reward.

    Args:
        state: environment state reached after the random past moves.
        state_past: cube states of the random past.
        actions_past: actions of the random past.
        actions_futur: action logits for the future steps, 9 columns per step.
        buffer, buffer_list: replay buffer and its current state.
        config: configuration object (`batch_size` and `len_seq` are read).

    Returns:
        (buffer, buffer_list): the buffer (unchanged) and its updated state.
    """
    nb_future_steps = config.len_seq - config.len_seq // 4

    # Decode all the logits once: argmax per group, zeros for the middle component.
    first_component = jnp.argmax(actions_futur[:, :, :6], axis=-1)
    third_component = jnp.argmax(actions_futur[:, :, 6:], axis=-1)
    middle_component = jnp.zeros((config.batch_size, first_component.shape[1]))

    action_all_futur = jnp.stack(
        [first_component, middle_component, third_component], axis=-1
    )
    # integer version used to step the environments
    actions_to_play = action_all_futur.astype(jnp.int32)

    # Play the decoded actions and record the cube after each step.
    state_futur_list = []
    for i in range(nb_future_steps):
        state, timestep = step_jit_env(state, actions_to_play[:, i, :])
        state_futur_list.append(state.cube)

    # Full trajectories: random past followed by the played future.
    action_all = jnp.concatenate([actions_past, action_all_futur], axis=1)
    action_all = action_all.astype(jnp.int32)

    state_futur = jnp.stack(state_futur_list, axis=1)
    state_all = jnp.concatenate([state_past, state_futur], axis=1)

    # Per-state score: +1 for each sticker equal to the goal, -1 otherwise,
    # averaged over the cube.
    goal_observation = jnp.broadcast_to(
        GOAL_OBSERVATION[None, None, :, :, :],
        (config.batch_size, config.len_seq) + GOAL_OBSERVATION.shape,
    )
    reward = jnp.where(state_all != goal_observation, -1.0, 1.0)
    reward = reward.mean(axis=(2, 3, 4))

    # Reward of a trajectory: score of its last state minus score at index len_seq // 4.
    # NOTE(review): index len_seq // 4 is the first played state, not the last
    # random one (len_seq // 4 - 1) - confirm which baseline is intended.
    reward = reward[:, -1] - reward[:, config.len_seq//4]

    # Add the trajectories one by one.
    # NOTE(review): the buffer was initialised with a reward template of shape (1,)
    # while reward[idx_batch] is a scalar - confirm buffer.add accepts this.
    for idx_batch in range(config.batch_size):
        experience = {
            "action": action_all[idx_batch],
            "reward": reward[idx_batch],
            "state_histo": state_all[idx_batch],
        }
        buffer_list = buffer.add(buffer_list, experience)

    return buffer, buffer_list

# add the trajectories played with the policy to the training buffer
buffer, buffer_list = gather_data_with_policy(state, state_past, actions_past, actions_futur, buffer, buffer_list, config)


In [None]:


def improve_training_loop(nb_iter=10000):
    """
    Online training loop: play with the current policy, store, train.

    At each iteration:
      1. reset a batch of environments and play random moves in them,
      2. use the decision diffuser and the inverse model to choose the next actions,
      3. play those actions and add the trajectories to the replay buffer,
      4. train the diffusion transformer on a batch sampled from the buffer,
      5. log the training metrics every `config.log_every_step` iterations.

    The replay buffer (`buffer`, `buffer_list`) and the PRNG carry
    (`config.jax_key`) are notebook-level state updated by this function.

    Args:
        nb_iter: number of iterations to run.
    """
    # The buffer is rebound at every iteration; without this declaration the
    # assignment below would make both names local and the first read would fail.
    global buffer, buffer_list

    for idx_step in range(nb_iter):
        # advance the carry key, then derive one independent key per random consumer
        key, subkey = jax.random.split(config.jax_key)
        config.jax_key = key
        key_env, key_policy, key_sample, key_noise = jax.random.split(subkey, 4)

        # first generate random state
        state_past, state, actions_past = generate_past_state_with_with_random_policy(
            key_env, vmap_reset, step_jit_env, config
        )

        # apply model to get some generation
        actions_futur = apply_decision_diffuser_policy(
            key_policy, state_past, transformer, inverse_rl_model, config
        )

        # update replay buffer dataset
        buffer, buffer_list = gather_data_with_policy(
            state, state_past, actions_past, actions_futur, buffer, buffer_list, config
        )

        # now we can do the training loop
        sample = buffer.sample(buffer_list, key_sample)
        sample = reshape_diffusion_setup(sample, key_noise)

        # one optimisation step of the diffusion transformer
        train_step_transformer_rf(
            transformer, optimizer_diffuser, metrics_train, sample
        )

        if idx_step % config.log_every_step == 0:
            metrics_train_result = metrics_train.compute()
            print(metrics_train_result)

            # NOTE(review): idx_step restarts at 0 here; if the same wandb run
            # already logged larger steps, these rows may be rejected - confirm.
            wandb.log(metrics_train_result, step=idx_step)
            metrics_train.reset()


