In [1]:
# chose the current file directory as the working directory
import os
os.chdir("/teamspace/studios/this_studio/rubikscubesolver")

In [2]:
from tqdm import tqdm
import pickle

import wandb  # for logging
import time
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.nnx as nnx

import rubiktransformer.dataset as dataset
from rubiktransformer.trainer import reshape_sample

from rubiktransformer.trainer_online import init_model_optimizer, init_buffer, train_step_transformer_rf, training_loop
from rubiktransformer.online_training_utils import run_n_steps, reshape_diffusion_setup

cuda_plugin_extension is not found.


In [3]:
@dataclass
class Config:
    """Configuration class"""

    jax_key: jnp.ndarray = jax.random.PRNGKey(49)
    rngs = nnx.Rngs(48)
    batch_size: int = 128
    lr_1: float = 4e-4
    lr_2: float = 4e-4
    nb_games: int = 128 * 100
    len_seq: int = 32
    nb_step: int = 1000000
    max_length_buffer: int = 1024 * 10
    log_every_step: int = 10
    log_eval_every_step: int = 10
    log_policy_reward_every_step: int = 10
    add_data_every_step: int = 500

    save_model_every_step: int = 2000


config = Config()

# init wandb config
user = "forbu14"
project = "RubikTransformer"
display_name = "experiment_" + time.strftime("%Y%m%d-%H%M%S")

wandb.init(entity=user, project=project, name=display_name)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mforbu14[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:

(
    optimizer_diffuser,
    optimizer_inverse,
    metrics_train,
    metrics_eval,
    metrics_inverse,
    transformer,
    inverse_rl_model,
) = init_model_optimizer(config)

env, buffer, buffer_eval, buffer_list, buffer_list_eval, jit_step = init_buffer(
    config
)

vmap_reset = jax.vmap(jax.jit(env.reset))
vmap_step = jax.vmap(run_n_steps, in_axes=(0, 0, None))

##### TRAINING #####
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key


In [5]:

buffer, buffer_list = dataset.fast_gathering_data_diffusion(
    env,
    vmap_reset,
    vmap_step,
    config.nb_games * 1,  # old is int(config.nb_games * 10.0),
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)


In [6]:
# load weight from world model transformer:
import pickle

filename = "state_ddt_model_improved_v2.pickle"

with open(filename, "rb") as input_file:
    state = pickle.load(input_file)

nnx.update(transformer, state)

# load weight from world model transformer:
import pickle

filename = "state_inverse_rl_model_improved_v2.pickle"

with open(filename, "rb") as input_file:
    state = pickle.load(input_file)

nnx.update(inverse_rl_model, state)

In [7]:
sample = buffer.sample(buffer_list, subkey)
sample = reshape_diffusion_setup(sample)


In [8]:
sample.keys()

dict_keys(['action', 'reward', 'state_histo', 'time_step', 'context', 'state_past', 'state_future', 'state_future_noise', 'action_inverse', 'state_histo_inverse_t', 'state_histo_inverse_td1'])

In [9]:
sample["action_inverse"]

Array([[0., 0., 1., ..., 1., 0., 0.],
       [0., 1., 0., ..., 1., 0., 0.],
       [0., 0., 1., ..., 1., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 1., ..., 1., 0., 0.],
       [0., 1., 0., ..., 1., 0., 0.]], dtype=float32)

In [10]:
sample["state_histo_inverse_t"]

Array([[0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       ...,
       [1., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [11]:
sample["state_histo_inverse_td1"]

Array([[0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [12]:
inverse_rl_model(sample["state_histo_inverse_t"], sample["state_histo_inverse_td1"])[:, :6]

Array([[-58.14122  , -12.908836 , 107.829445 , -26.21095  , -87.36493  ,
        -19.743734 ],
       [-17.515263 ,  80.81346  , -28.577995 , -49.040886 , -30.555523 ,
        -17.010876 ],
       [-30.758236 , -30.516747 ,  76.49304  , -14.1797085, -48.68231  ,
        -10.990468 ],
       ...,
       [-23.25964  , -62.139297 , -24.637434 ,  67.35371  , -12.765332 ,
        -13.466367 ],
       [-33.95061  , -17.303339 ,  73.70761  , -25.77107  , -45.953106 ,
        -11.735286 ],
       [-15.208434 ,  78.16751  , -26.444832 , -57.60352  , -29.599823 ,
        -18.02725  ]], dtype=float32)

In [29]:

def sampling_model(key, model, sample_eval, nb_step=100, config=None, target_reward=0.5):
    """
    Function used to sampling a state from a list 
    """
    seq_len_future = config.len_seq - config.len_seq // 4 
    noise_future  = jax.random.dirichlet(key, jnp.ones(6) * 5., (config.batch_size, seq_len_future, 54))
    sample_eval["reward"] = jnp.linspace(start=-0.1 + target_reward, stop=0.1 + target_reward, num=config.batch_size)[:, None]

    for t_step in range(nb_step):
        t_step_array = jnp.ones((config.batch_size, 1, 1, 1)) * float(t_step / nb_step)
        sample_eval["context"] = jnp.concatenate([sample_eval["reward"], t_step_array[:, :, 0, 0]], axis=1)

        estimation_logits_past, estimation_logits_future = model(
            sample_eval["state_past"], noise_future, sample_eval["context"]
        )

        estimation_proba_future = jax.nn.softmax(estimation_logits_future, axis=-1)

        noise_future = noise_future + float(1. / nb_step) * 1./ (1. - t_step_array + 0.0001) * (estimation_proba_future - noise_future)

    return noise_future



In [30]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer_eval, buffer_list_eval = dataset.fast_gathering_data_diffusion(
    env,
    vmap_reset,
    vmap_step,
    int(config.batch_size),
    config.len_seq,
    buffer_eval,
    buffer_list_eval,
    subkey,
)

sample = buffer_eval.sample(buffer_list_eval, subkey)
sample = reshape_diffusion_setup(sample, subkey)

In [31]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

sample = buffer.sample(buffer_list, subkey)
sample = reshape_diffusion_setup(sample, subkey)


result = sampling_model(key=config.jax_key, model=transformer, sample_eval=sample, config=config, nb_step=100)
result

Array([[[[9.99899507e-01, 1.74097950e-05, 1.12622511e-05,
          2.68227886e-05, 1.88681297e-05, 2.61575915e-05],
         [9.99913096e-01, 2.75000930e-05, 7.84196891e-06,
          1.92100415e-05, 1.67725375e-05, 1.56341121e-05],
         [1.78341288e-05, 1.04047358e-05, 1.42740319e-05,
          1.43313082e-05, 2.53261533e-05, 9.99917805e-01],
         ...,
         [1.94137683e-05, 1.77902402e-05, 1.11241825e-05,
          9.99910533e-01, 2.41212547e-05, 1.69395935e-05],
         [2.44495459e-05, 9.99926388e-01, 8.33871309e-06,
          8.53434904e-06, 8.80047446e-06, 2.35179905e-05],
         [1.31488778e-05, 7.83847645e-06, 9.99915063e-01,
          2.54756305e-05, 1.83007214e-05, 2.02064402e-05]],

        [[9.99909103e-01, 1.27787935e-05, 2.79305968e-05,
          2.04902608e-05, 9.05554043e-06, 2.06260011e-05],
         [3.59257683e-05, 1.37771713e-05, 1.21578341e-05,
          9.99911845e-01, 1.76803442e-05, 8.60699220e-06],
         [2.35249754e-05, 9.04168701e-06, 1.8083

In [32]:
index_batch  = 64

jnp.argmax(sample["state_past"], axis=-1).reshape((128, 8, 6, 3, 3))[index_batch, -1, :, :, :]

Array([[[4, 0, 2],
        [1, 0, 5],
        [3, 2, 2]],

       [[4, 3, 0],
        [3, 1, 1],
        [5, 5, 4]],

       [[1, 3, 3],
        [0, 2, 5],
        [3, 0, 5]],

       [[0, 3, 1],
        [4, 3, 1],
        [1, 4, 2]],

       [[0, 2, 5],
        [5, 4, 4],
        [3, 4, 1]],

       [[4, 2, 0],
        [1, 5, 2],
        [5, 0, 2]]], dtype=int32)

In [33]:
jnp.argmax(result, axis=-1).reshape((128, 24, 6, 3, 3))[index_batch, 0, :, :, :]

Array([[[2, 5, 2],
        [0, 0, 2],
        [4, 1, 3]],

       [[0, 2, 5],
        [3, 1, 1],
        [5, 5, 4]],

       [[4, 3, 0],
        [0, 2, 5],
        [3, 0, 5]],

       [[1, 3, 3],
        [4, 3, 1],
        [1, 4, 2]],

       [[0, 3, 1],
        [5, 4, 4],
        [3, 4, 1]],

       [[4, 2, 0],
        [1, 5, 2],
        [5, 0, 2]]], dtype=int32)

In [34]:
jnp.argmax(result, axis=-1).reshape((128, 24, 6, 3, 3))[index_batch, 1, :, :, :]

Array([[[3, 5, 0],
        [0, 0, 2],
        [4, 1, 3]],

       [[0, 2, 5],
        [3, 1, 1],
        [5, 5, 4]],

       [[4, 3, 2],
        [0, 2, 5],
        [3, 0, 2]],

       [[3, 1, 2],
        [3, 3, 4],
        [1, 4, 1]],

       [[5, 3, 1],
        [0, 4, 4],
        [2, 4, 1]],

       [[4, 2, 0],
        [1, 5, 2],
        [5, 5, 0]]], dtype=int32)

In [35]:
sample = buffer_eval.sample(buffer_list_eval, subkey)
sample = reshape_sample(sample)

KeyError: 'action_pred'

In [None]:
def generate_past_state_with_with_random_policy(key, vmap_reset, step_jit_env, config):
    """
    Generate past state with random policy

    Args:
        config: configuration object

    Returns:
        state_past: (batch_size, len_seq//4, 6, 3, 3)

    """

    key1, key2 = jax.random.split(config.jax_key)

    keys = jax.random.split(key1, config.batch_size)
    state, timestep = vmap_reset(keys)

    last_state = None
    past_state = []

    actions_all = jax.random.randint(
        key=config.jax_key,
        minval=env.action_spec.minimum,
        maxval=env.action_spec.maximum,
        shape=(config.batch_size, config.len_seq // 4, 3),
    )

    for i in range(config.len_seq // 4):

        # apply random policy and retrieve state
        action = actions_all[:, i, :]

        state, timestep  = step_jit_env(state, action)
        past_state.append(state.cube)

    # concat all the past state to get the shape (batch_size, len_seq//4, 6, 3, 3) from a list of state of size (batch_size, 6, 3, 3) by creating the 1 axis
    state_past = jnp.stack(past_state, axis=1)

    return state_past, state, actions_all

step_jit_env = jax.vmap(jit_step)

state_past, state, actions_past = generate_past_state_with_with_random_policy(key, vmap_reset, step_jit_env, config)

In [36]:
state_past.shape

(128, 8, 6, 3, 3)

In [38]:

def apply_decision_diffuser_policy(key, state_past, decision_diffuser, inverse_rl_model, config, target_reward=0.5):
    """
    1. Make a estimation of the targeted reward
    2. Generate futur state with those targeted reward
    3. Choose policy from that
    """
    sample_eval = {
        "state_past": jax.nn.one_hot(state_past, 6),
    }

    state_past = jnp.copy(state_past.reshape((state_past.shape[0], state_past.shape[1], -1)))
    state_past = jax.nn.one_hot(state_past, num_classes=6)

    state_future = sampling_model(key, decision_diffuser, sample_eval, nb_step=100, config=config, target_reward=target_reward)

    # state_future is (batch_size, seq_len, dim_input_state / 6, 6)
    state_to_act = jnp.concatenate([state_past, state_future], axis=1)
    state_to_act_futur_t = state_to_act[:, (config.len_seq // 4 - 1):(-1), :, :]
    state_to_act_futur_td1 = state_to_act[:, (config.len_seq // 4):, :, :]

    # flatten the last 2 axis
    state_to_act_futur_t = state_to_act_futur_t.reshape(
        (state_to_act_futur_t.shape[0], state_to_act_futur_t.shape[1], -1)
    )

    state_to_act_futur_td1 = state_to_act_futur_td1.reshape(
        (state_to_act_futur_td1.shape[0], state_to_act_futur_td1.shape[1], -1)
    )

    # now use reverse RL to compute the action TODO later
    actions = inverse_rl_model(state_to_act_futur_t, state_to_act_futur_td1)

    return actions

actions_futur = apply_decision_diffuser_policy(config.jax_key, state_past, transformer, inverse_rl_model, config)

In [39]:

from rubiktransformer.dataset import GOAL_OBSERVATION

def gather_data_with_policy(state, state_past, actions_past, actions_futur, buffer, buffer_list, config):
    """
    For loop with those policy and state

    log performance compare to target

    """
    state_futur_list = []

    for i in range(config.len_seq - config.len_seq // 4):
        actions_step = actions_futur[:, i, :]
        actions_0 = jnp.argmax(actions_step[:, :6], axis=1)
        actions_1 = jnp.argmax(actions_step[:, 6:], axis=1)

        actions_full = jnp.stack([actions_0, jnp.zeros(config.batch_size), actions_1], axis=1)
        
        # transform to int type
        actions_full = actions_full.astype(jnp.int32)
    
        # step 
        state, timestep  = step_jit_env(state, actions_full)

        state_futur_list.append(state.cube)

    # TODO SAVE DATA into batch format for later training
    actions_0_all_futur = jnp.argmax(actions_futur[:, :, :6], axis=-1)
    actions_1_all_futur = jnp.argmax(actions_futur[:, :, 6:], axis=-1)

    action_all_futur = jnp.stack([actions_0_all_futur, jnp.zeros((config.batch_size, actions_0_all_futur.shape[1])), actions_1_all_futur], axis=-1)

    action_all = jnp.concatenate([actions_past, action_all_futur], axis=1)
    action_all = action_all.astype(jnp.int32)

    state_futur = jnp.stack(state_futur_list, axis=1)

    state_all = jnp.concatenate([state_past, state_futur], axis=1)

    # compute reward 
    goal_observation = jnp.repeat(
        GOAL_OBSERVATION[None, None, :, :, :], config.batch_size, axis=0
    )
    goal_observation = jnp.repeat(goal_observation, config.len_seq, axis=1)
    reward = jnp.where(state_all != goal_observation, -1.0, 1.0)

    reward = reward.mean(axis=[2, 3, 4])
    reward = reward[:, -1] - reward[:, config.len_seq//4]


    for idx_batch in range(config.batch_size):
        buffer_list = buffer.add(
            buffer_list,
            {
                "action": action_all[idx_batch],
                "reward": reward[idx_batch],
                "state_histo": state_all[idx_batch],
            },
        )

    return buffer, buffer_list, reward

buffer, buffer_list, reward_real = gather_data_with_policy(state, state_past, actions_past, actions_futur, buffer, buffer_list, config)


In [41]:
def reward_hacking(reward):
    """
    reward is an array of value of shape (batch_size, len_seq, 1) with value between -1 and 1
    we want to apply to every element the funciton
    f(x) = 0.1 * jnp.exp(4 * x)
    """

    return 0.1 * jnp.exp(4.0 * reward)

def improve_training_loop(buffer, buffer_list, nb_iter=10000):
    """
    Relaunch the training loop with those new data incorporated into the buffer
    
    Full stuff here
    Online transformer setup

    1. We generate env setup 
    2. First random action in the different env
    3. Use decision_diffuser to choose the action to do from here
    4. Observe / apply policy  to retrieve data
    5. Add the data into the buffer
    6. Train model on those data

    Remember to log the performance data to compare with other run / algorithms
    """
    target_reward = 0.5
    
    for idx_step in range(nb_iter):

        print("begin iter")

        key, subkey = jax.random.split(config.jax_key)
        config.jax_key = key

        print("generate high reward value")
        # first generate random state
        state_past, state, actions_past = generate_past_state_with_with_random_policy(key, vmap_reset, step_jit_env, config)
        
        print("apply the strategy")
        # apply model to get some generation
        actions_futur = apply_decision_diffuser_policy(config.jax_key, state_past, transformer, inverse_rl_model, config)

        print("improve data buffer")
        # update replay buffer dataset
        buffer, buffer_list, reward_mean = gather_data_with_policy(state, state_past, actions_past, actions_futur, buffer, buffer_list, config)

        diff_target = reward_mean - target_reward
        target_reward = 1./2. * (target_reward + reward_mean.max())
        print("new target : ", target_reward)

        wandb.log({"reward_normalized" : reward_hacking(reward_mean).mean(), "target_reward_new" : target_reward, "diff_target_reward": diff_target.mean()})

        # now we can do the training loop
        sample = buffer.sample(buffer_list, subkey)
        sample = reshape_diffusion_setup(sample, subkey)

        print("trainign")

        # we update the policy
        train_step_transformer_rf(
            transformer, optimizer_diffuser, metrics_train, sample
        )

        if idx_step % config.log_every_step == 0:
            metrics_train_result = metrics_train.compute()
            print(metrics_train_result)

            wandb.log(metrics_train_result, step=idx_step)
            metrics_train.reset()

improve_training_loop(buffer, buffer_list, nb_iter=10000)


begin iter
generate high reward value
apply the strategy
improve data buffer
new target :  0.41666666
trainign
{'loss': Array(0.10118264, dtype=float32), 'loss_cross_entropy': Array(0.09603005, dtype=float32)}
begin iter
generate high reward value
apply the strategy




improve data buffer
new target :  0.375
trainign
begin iter
generate high reward value
apply the strategy
improve data buffer
new target :  0.3912037
trainign
begin iter
generate high reward value
apply the strategy
