In [1]:
# chose the current file directory as the working directory
import os
os.chdir("/teamspace/studios/this_studio/rubikscubesolver")

In [2]:
from tqdm import tqdm

import wandb  # for logging
import time
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.nnx as nnx

import optax

from rubiktransformer.model_diffusion_dt import RubikDTTransformer
import rubiktransformer.dataset as dataset
from rubiktransformer.trainer import reshape_sample

cuda_plugin_extension is not found.


In [3]:
@dataclass
class Config:
    """Configuration class"""

    jax_key: jnp.ndarray = jax.random.PRNGKey(49)
    rngs = nnx.Rngs(48)
    batch_size: int = 128
    lr_1: float = 4e-4
    lr_2: float = 4e-4
    nb_games: int = 128 * 100
    len_seq: int = 32
    nb_step: int = 1000000
    log_every_step: int = 10
    log_eval_every_step: int = 10
    log_policy_reward_every_step: int = 10
    add_data_every_step: int = 500

    save_model_every_step: int = 2000


config = Config()

# init wandb config
user = "forbu14"
project = "RubikTransformer"
display_name = "experiment_" + time.strftime("%Y%m%d-%H%M%S")

wandb.init(entity=user, project=project, name=display_name)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mforbu14[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
transformer = RubikDTTransformer(rngs=config.rngs, causal=True)

scheduler = optax.linear_schedule(init_value=0.0, end_value=1.0, transition_steps=4000)

# init optimizer
optimizer_dd = optax.chain(
    optax.clip_by_global_norm(1.0),
    #optax.lion(config.lr_1 / 10.0),
    optax.adamw(config.lr_1),
    optax.scale_by_schedule(scheduler),
)

optimizer_diffuser = nnx.Optimizer(transformer, optimizer_dd)

# metrics
metrics_train = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
    loss_cross_entropy=nnx.metrics.Average("loss_cross_entropy"),
)

metrics_eval = nnx.MultiMetric(
    loss_eval=nnx.metrics.Average("loss_eval"),

    loss_cross_entropy_eval=nnx.metrics.Average("loss_cross_entropy_eval"),
)

In [5]:
import pickle

filename = "state_ddt_model.pickle"

with open(filename, "rb") as input_file:
    state = pickle.load(input_file)

nnx.update(transformer, state)

In [6]:
# gather data from the environment
# init models and optimizers
env, buffer = dataset.init_env_buffer(sample_batch_size=config.batch_size)
env, buffer_eval = dataset.init_env_buffer(sample_batch_size=config.batch_size)


nb_games = config.nb_games
len_seq = config.len_seq

state_first = jnp.zeros((6, 3, 3))
state_next = jnp.zeros((len_seq, 6, 3, 3))
action = jnp.zeros((len_seq, 3))
action_proba = jnp.zeros((len_seq, 9))

# transform state to int8 type
state_first = state_first.astype(jnp.int8)
state_next = state_next.astype(jnp.int8)

# action to int32 type
action = action.astype(jnp.int32)

reward = jnp.zeros((1))

jit_step = jax.jit(env.step)

buffer_list = buffer.init(
    {
        "action": action,
        "reward": reward,
        "state_histo": state_next,
    }
)

buffer_list_eval = buffer_eval.init(
    {
        "action": action,
        "reward": reward,
        "state_histo": state_next,
    }
)

In [7]:
def step_fn(state, key):
    """
    Simple step function
    We choose a random action
    """

    action = jax.random.randint(
        key=key,
        minval=env.action_spec.minimum,
        maxval=env.action_spec.maximum,
        shape=(3,),
    )

    new_state, timestep = jit_step(state, action)
    timestep.extras["action"] = action

    return new_state, timestep


def run_n_steps(state, key, n):
    random_keys = jax.random.split(key, n)
    state, rollout = jax.lax.scan(step_fn, state, random_keys)

    return rollout


vmap_reset = jax.vmap(jax.jit(env.reset))
vmap_step = jax.vmap(run_n_steps, in_axes=(0, 0, None))

In [8]:
nnx.display(transformer)

In [9]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer, buffer_list = dataset.fast_gathering_data_diffusion(
    env,
    vmap_reset,
    vmap_step,
    int(config.nb_games / 10),
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

In [13]:
sample = buffer.sample(buffer_list, subkey)

def reshape_diffusion_setup(sample, key=jax.random.PRNGKey(0)):
    sample.experience["state_histo"] = sample.experience["state_histo"].reshape(
        (sample.experience["state_histo"].shape[0], sample.experience["state_histo"].shape[1], 54)
    )

    # one hot encoding for state_histo
    sample.experience["state_histo"] = jax.nn.one_hot(
        sample.experience["state_histo"],
        num_classes=6,
        axis=-1,
    )

    # batch creation
    batch  = sample.experience
    len_seq = batch["state_histo"].shape[1]

    time_step = jax.random.uniform(
        key, (batch["state_histo"].shape[0], 1, 1, 1)
    ) # random value between 0 and 1

    batch['time_step'] = time_step

    # now contact the value to have the context for the rectified flow setup
    batch["context"] = jnp.concatenate([batch["reward"], time_step[:, :, 0, 0]], axis=1)

    batch["state_past"] = batch["state_histo"][:, :len_seq//4, :, :]
    batch["state_future"] = batch["state_histo"][:, len_seq//4:, :, :]

    # now we generate the random noise for the rectified flow setup
    simplex_noise = jax.random.dirichlet(key, jnp.ones(6), batch["state_future"].shape[:-1])

    batch["state_future_noise"] = (
        (1 - time_step) * simplex_noise + time_step * batch["state_future"]
    )
    
    return batch


sample = reshape_diffusion_setup(sample)


In [11]:
def loss_fn_transformer_rf(model: RubikDTTransformer, batch):
    # rectified flow setup
    state_past, state_future = model(
        batch["state_past"], batch["state_future_noise"], batch["context"]
    )

    loss_crossentropy = optax.softmax_cross_entropy(
        logits=state_future, labels=batch["state_future"]
    ).mean(axis=[1, 2])

    weight = jnp.clip(1. / (1. - batch["time_step"][:, 0, 0, 0]), min=0.005, max=1.5)

    loss_cross_entropy_weight = loss_crossentropy * weight

    return loss_cross_entropy_weight.mean(), (loss_crossentropy.mean())


@nnx.jit
def train_step_transformer_rf(
    model: RubikDTTransformer,
    optimizer: nnx.Optimizer,
    metrics: nnx.MultiMetric,
    batch,
):
    """Train for a single step."""

    grad_fn = nnx.value_and_grad(loss_fn_transformer_rf, has_aux=True)
    (loss, (loss_crossentropy)), grads = grad_fn(model, batch)
    metrics.update(
        loss=loss, loss_cross_entropy=loss_crossentropy
    )
    optimizer.update(grads)

In [10]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer, buffer_list = dataset.fast_gathering_data_diffusion(
    env,
    vmap_reset,
    vmap_step,
    config.nb_games, # old is int(config.nb_games * 10.0),
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

In [None]:
import pickle

# transformer model calibration
for idx_step in tqdm(range(config.nb_step)):
    # training for world model
    key, subkey = jax.random.split(config.jax_key)
    config.jax_key = key

    if idx_step % config.add_data_every_step == 0:
        buffer, buffer_list = dataset.fast_gathering_data_diffusion(
            env,
            vmap_reset,
            vmap_step,
            int(config.nb_games // 10),
            config.len_seq,
            buffer,
            buffer_list,
            config.jax_key,
        )

    sample = buffer.sample(buffer_list, subkey)
    sample = reshape_diffusion_setup(sample, subkey)

    # we update the policy
    train_step_transformer_rf(
        transformer, optimizer_diffuser, metrics_train, sample
    )

    if idx_step % config.log_every_step == 0:
        metrics_train_result = metrics_train.compute()
        print(metrics_train_result)

        wandb.log(metrics_train_result, step=idx_step)
        metrics_train.reset()

    if idx_step % config.log_eval_every_step == 0:
        key, subkey = jax.random.split(config.jax_key)
        config.jax_key = key

        buffer_eval, buffer_list_eval = dataset.fast_gathering_data_diffusion(
            env,
            vmap_reset,
            vmap_step,
            int(128),
            config.len_seq,
            buffer_eval,
            buffer_list_eval,
            subkey,
        )

        sample = buffer_eval.sample(buffer_list_eval, subkey)
        sample = reshape_diffusion_setup(sample, subkey)

        loss, (loss_crossentropy) = loss_fn_transformer_rf(
            transformer, sample
        )

        metrics_eval.update(
            loss_eval=loss,
            loss_cross_entropy_eval=loss_crossentropy,
        )
        wandb.log(metrics_eval.compute(), step=idx_step)

        metrics_eval.reset()

    if idx_step % config.save_model_every_step == 0:

        state_weight = nnx.state(transformer)

        with open("state_ddt_model.pickle", "wb") as handle:
            pickle.dump(state_weight, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [11]:

def sampling_model(key, model, sample_eval, nb_batch_explore=128, nb_step=100):
    """
    Function used to sampling a state from a list 
    """
    seq_len_future = sample_eval["state_future"].shape[1]
    noise_future  = jax.random.dirichlet(key, jnp.ones(6) * 5., (nb_batch_explore, seq_len_future, 54))
    sample_eval["reward"] = jnp.linspace(start=-0.5, stop=0.5, num=nb_batch_explore)[:, None]

    for t_step in range(nb_step):
        t_step_array = jnp.ones((nb_batch_explore, 1, 1, 1)) * float(t_step / nb_step)
        sample_eval["context"] = jnp.concatenate([sample_eval["reward"], t_step_array[:, :, 0, 0]], axis=1)

        estimation_logits_past, estimation_logits_future = model(
            sample_eval["state_past"], noise_future, sample_eval["context"]
        )


        estimation_proba_future = jax.nn.softmax(estimation_logits_future, axis=-1)

        noise_future = noise_future + float(1. / nb_step) * 1./ (1. - t_step_array + 0.0001) * (estimation_proba_future - noise_future)

    return noise_future



In [18]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer_eval, buffer_list_eval = dataset.fast_gathering_data_diffusion(
    env,
    vmap_reset,
    vmap_step,
    int(128),
    config.len_seq,
    buffer_eval,
    buffer_list_eval,
    subkey,
)

sample = buffer_eval.sample(buffer_list_eval, subkey)
sample = reshape_diffusion_setup(sample, subkey)

In [19]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

sample = buffer.sample(buffer_list, subkey)
sample = reshape_diffusion_setup(sample, subkey)


result = sampling_model(key=config.jax_key, model=transformer, sample_eval=sample, nb_batch_explore=128, nb_step=100)
result

Array([[[[9.12160613e-06, 1.71264401e-05, 1.48795079e-05,
          9.99910355e-01, 2.99175736e-05, 1.85837271e-05],
         [1.93071319e-05, 9.99914408e-01, 2.00294890e-05,
          1.98439229e-05, 9.08272341e-06, 1.74135203e-05],
         [1.85810495e-05, 6.38415804e-06, 9.99915481e-01,
          2.45338306e-05, 1.44244405e-05, 2.06008554e-05],
         ...,
         [1.11915870e-05, 2.57499050e-05, 9.99916375e-01,
          1.56825408e-05, 1.51600689e-05, 1.58642652e-05],
         [9.99907255e-01, 1.58234034e-05, 1.93407759e-05,
          3.18640377e-05, 1.44680962e-05, 1.12638809e-05],
         [1.67636899e-05, 6.71518501e-06, 9.99912083e-01,
          2.26497650e-05, 2.02378724e-05, 2.14804895e-05]],

        [[1.03620114e-05, 1.93770975e-05, 9.99914050e-01,
          1.39374752e-05, 2.23403331e-05, 1.99424103e-05],
         [1.14582945e-05, 1.58930197e-05, 2.36062333e-05,
          9.99919713e-01, 1.25739025e-05, 1.67943072e-05],
         [7.86513556e-06, 2.15659384e-05, 2.7151

In [24]:
index_batch  = 64

jnp.argmax(sample["state_past"], axis=-1).reshape((128, 8, 6, 3, 3))[index_batch, -1, :, :, :]

Array([[[2, 0, 1],
        [0, 0, 5],
        [2, 4, 5]],

       [[5, 1, 4],
        [2, 1, 2],
        [1, 3, 3]],

       [[3, 4, 4],
        [1, 2, 1],
        [5, 5, 1]],

       [[5, 4, 3],
        [5, 3, 0],
        [0, 1, 0]],

       [[0, 3, 1],
        [2, 4, 3],
        [4, 4, 0]],

       [[4, 5, 2],
        [3, 5, 2],
        [3, 0, 2]]], dtype=int32)

In [25]:
jnp.argmax(result, axis=-1).reshape((128, 24, 6, 3, 3))[index_batch, 0, :, :, :]

Array([[[2, 0, 1],
        [0, 0, 5],
        [3, 1, 5]],

       [[4, 2, 3],
        [1, 1, 3],
        [5, 2, 1]],

       [[2, 4, 4],
        [5, 2, 1],
        [4, 5, 1]],

       [[5, 4, 3],
        [5, 3, 0],
        [0, 1, 0]],

       [[0, 3, 5],
        [2, 4, 4],
        [4, 4, 2]],

       [[1, 3, 0],
        [3, 5, 2],
        [3, 0, 2]]], dtype=int32)

In [22]:
jnp.argmax(result, axis=-1).reshape((128, 24, 6, 3, 3))[index_batch, 1, :, :, :]

Array([[[2, 5, 0],
        [5, 0, 3],
        [4, 2, 5]],

       [[3, 0, 2],
        [4, 1, 5],
        [4, 3, 4]],

       [[1, 0, 1],
        [3, 2, 1],
        [5, 2, 1]],

       [[2, 1, 5],
        [0, 3, 1],
        [4, 1, 0]],

       [[3, 4, 0],
        [4, 4, 0],
        [4, 3, 5]],

       [[1, 4, 3],
        [2, 5, 0],
        [3, 2, 0]]], dtype=int32)

In [None]:
sample = buffer_eval.sample(buffer_list_eval, subkey)
sample = reshape_sample(sample)

TrajectoryBufferSample(experience={'action': Array([[[1.32556781e-01, 7.96739519e-01, 5.36718592e-02, ...,
         3.91646661e-03, 4.48901858e-03, 9.91594553e-01],
        [3.49070907e-01, 4.57749265e-04, 4.38157976e-01, ...,
         7.23136306e-01, 1.23497941e-01, 1.53365776e-01],
        [6.12441264e-03, 2.50436477e-02, 1.35732419e-03, ...,
         3.82237613e-01, 5.98694921e-01, 1.90675538e-02],
        ...,
        [1.41329234e-04, 2.44877161e-03, 8.43136787e-01, ...,
         2.33344346e-01, 6.42170012e-01, 1.24485560e-01],
        [6.32655225e-04, 1.77795421e-02, 9.65278149e-01, ...,
         1.25269741e-02, 3.21629345e-01, 6.65843725e-01],
        [9.08881542e-04, 1.04175135e-01, 7.50824576e-04, ...,
         9.99683421e-03, 7.89827347e-01, 2.00175866e-01]],

       [[2.03237548e-01, 7.00179100e-01, 3.63819454e-05, ...,
         9.96583939e-01, 2.39940570e-03, 1.01662707e-03],
        [7.63220847e-01, 1.11325733e-01, 3.15520242e-02, ...,
         5.45369804e-01, 4.54322606e-0

AttributeError: 'RubikTransformer' object has no attribute 'state_mapping'

In [14]:
# save buffer, buffer_list
# in pickle
import pickle

state_weight = nnx.state(transformer)

In [15]:
state_weight

State({
  'action_mapping': {
    'bias': VariableState(
      type=Param,
      value=Array([ 4.04955857e-02,  2.82790326e-02, -7.85927773e-02,  7.52160996e-02,
             -6.11112965e-03,  6.96982583e-03, -1.17343664e-02, -1.74523471e-03,
              1.28632234e-02, -7.83682019e-02,  2.75444221e-02, -4.00350802e-02,
              1.79233290e-02,  8.38570073e-02,  2.03401130e-02,  4.92124483e-02,
              8.69528428e-02,  2.20998153e-02, -3.42875794e-02, -6.76687211e-02,
              2.17811018e-02,  8.36544111e-02, -3.08539756e-02, -8.56901798e-03,
             -6.66830465e-02,  1.15918748e-01,  5.94779989e-03,  1.72799546e-02,
             -1.16014622e-01, -6.75882176e-02, -6.16184436e-02, -5.52975051e-02,
              4.17982265e-02, -5.43787293e-02,  1.19193546e-01, -9.40112211e-03,
             -4.03175130e-02, -3.47817354e-02, -6.77642366e-03,  8.85512680e-02,
              4.83243428e-02,  6.59283325e-02, -5.58541063e-03,  3.46533172e-02,
             -6.63223118e-02

In [None]:
# now we want to modify the code that generate proper data according to policy
def policy_generation_calibration(decision_diffuser, inverse_model, buffer, buffer_list):
    """
    Online transformer setup

    1. We generate env setup 
    2. First random action in the different env
    3. Use decision_diffuser to choose the action to do from here
    4. Observe / apply policy  to retrieve data
    5. Add the data into the buffer
    6. Train model on those data

    Remember to log the performance data to compare with other run / algorithms

    """
    def generate_past_state_with_with_random_policy():
        pass

    def apply_decision_diffuser_policy():
        pass

    def gather_data_with_policy():
        pass

    def improve_training_loop():
        pass

    
    # we sample the buffer to get a proper information setup
    pass