In [1]:
# chose the current file directory as the working directory
import os

os.chdir("/teamspace/studios/this_studio/rubikscubesolver")

In [2]:
from tqdm import tqdm

import wandb  # for logging
import time
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.nnx as nnx

import optax

from rubiktransformer.model_diffusion_dt import RubikDTTransformer
import rubiktransformer.dataset as dataset
from rubiktransformer.trainer import reshape_sample

cuda_plugin_extension is not found.


In [3]:
@dataclass
class Config:
    """Configuration class"""

    jax_key: jnp.ndarray = jax.random.PRNGKey(46)
    rngs = nnx.Rngs(45)
    batch_size: int = 128
    lr_1: float = 4e-3
    lr_2: float = 4e-3
    nb_games: int = 128 * 100
    len_seq: int = 20
    nb_step: int = 1000000
    log_every_step: int = 10
    log_eval_every_step: int = 10
    log_policy_reward_every_step: int = 10
    add_data_every_step: int = 500


config = Config()

# init wandb config
user = "forbu14"
project = "RubikTransformer"
display_name = "experiment_" + time.strftime("%Y%m%d-%H%M%S")

wandb.init(entity=user, project=project, name=display_name)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mforbu14[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
transformer = RubikDTTransformer(rngs=config.rngs, causal=True)

scheduler = optax.linear_schedule(init_value=0.0, end_value=1.0, transition_steps=4000)

# init optimizer
optimizer_dd = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.lion(config.lr_1 / 100.0),
    # optax.adamw(config.lr_1/10.),
    optax.scale_by_schedule(scheduler),
)

optimizer_diffuser = nnx.Optimizer(transformer, optimizer_dd)

# metrics
metrics_train = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
    loss_reward=nnx.metrics.Average("loss_reward"),
    loss_cross_entropy=nnx.metrics.Average("loss_cross_entropy"),
)

metrics_eval = nnx.MultiMetric(
    loss_eval=nnx.metrics.Average("loss_eval"),
    loss_reward_eval=nnx.metrics.Average("loss_reward_eval"),
    loss_cross_entropy_eval=nnx.metrics.Average("loss_cross_entropy_eval"),
)

In [6]:
# gather data from the environment
# init models and optimizers
env, buffer = dataset.init_env_buffer(sample_batch_size=config.batch_size)
env, buffer_eval = dataset.init_env_buffer(sample_batch_size=config.batch_size)


nb_games = config.nb_games
len_seq = config.len_seq

state_first = jnp.zeros((6, 3, 3))
state_next = jnp.zeros((len_seq, 6, 3, 3))
action = jnp.zeros((len_seq, 3))
action_proba = jnp.zeros((len_seq, 9))

# transform state to int8 type
state_first = state_first.astype(jnp.int8)
state_next = state_next.astype(jnp.int8)

# action to int32 type
action = action.astype(jnp.int32)

reward = jnp.zeros((len_seq))

jit_step = jax.jit(env.step)

buffer_list = buffer.init(
    {
        "state_first": state_first,
        "action": action,
        "reward": reward,
        "state_next": state_next,
        "action_pred": action_proba,
    }
)

buffer_list_eval = buffer_eval.init(
    {
        "state_first": state_first,
        "action": action,
        "reward": reward,
        "state_next": state_next,
        "action_pred": action_proba,
    }
)

In [7]:
def step_fn(state, key):
    """
    Simple step function
    We choose a random action
    """

    action = jax.random.randint(
        key=key,
        minval=env.action_spec.minimum,
        maxval=env.action_spec.maximum,
        shape=(3,),
    )

    new_state, timestep = jit_step(state, action)
    timestep.extras["action"] = action

    return new_state, timestep


def run_n_steps(state, key, n):
    random_keys = jax.random.split(key, n)
    state, rollout = jax.lax.scan(step_fn, state, random_keys)

    return rollout


vmap_reset = jax.vmap(jax.jit(env.reset))
vmap_step = jax.vmap(run_n_steps, in_axes=(0, 0, None))

In [9]:
nnx.display(transformer)

In [10]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer, buffer_list = dataset.fast_gathering_data_diffusion(
    env,
    vmap_reset,
    vmap_step,
    int(config.nb_games / 10),
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

In [12]:
sample = buffer.sample(buffer_list, subkey)
sample = reshape_sample(sample)

In [13]:
sample

TrajectoryBufferSample(experience={'action': Array([[[1.0342622e-03, 9.8763175e-02, 1.9616945e-01, ...,
         9.1598278e-01, 7.9124421e-02, 4.8927194e-03],
        [6.5340154e-04, 7.4344707e-01, 2.9543904e-03, ...,
         8.2277584e-01, 1.3074802e-01, 4.6476152e-02],
        [2.9181805e-01, 3.3277448e-02, 1.3126548e-01, ...,
         1.6328433e-03, 9.9831665e-01, 5.0603030e-05],
        ...,
        [3.3290795e-04, 2.8182656e-04, 1.0129113e-05, ...,
         1.2191311e-01, 4.8528105e-01, 3.9280584e-01],
        [7.0439512e-04, 9.9126726e-01, 8.9772133e-04, ...,
         2.1325910e-02, 7.3878825e-01, 2.3988584e-01],
        [1.1097225e-02, 2.3210780e-03, 8.3933562e-02, ...,
         3.7379863e-04, 9.6265030e-01, 3.6975913e-02]],

       [[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
         0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
        [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
         0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
        [0.0000000e+00, 0.00

In [12]:
def loss_fn_transformer_rf(model: RubikDTTransformer, batch):
    # rectified flow setup
    state_past, state_future = model(
        batch["state_past"], batch["state_future_noise"], batch["context"]
    )

    # reshape state_logits
    # from (batch_size, sequence_length, 324) => (batch_size, sequence_length -1, 54, 6)
    state_logits = state_future.reshape(
        (state_future.shape[0], state_future.shape[1], 54, 6)
    )

    loss_crossentropy = optax.softmax_cross_entropy_with_integer_labels(
        logits=state_logits, labels=batch["state_future"]
    ).mean()

    return loss, (loss_crossentropy)


@nnx.jit
def train_step_transformer(
    model: RubikDTTransformer,
    optimizer: nnx.Optimizer,
    metrics: nnx.MultiMetric,
    key: jax.random.PRNGKey,
    batch,
):
    """Train for a single step."""
    # TODO generate noise and time step for rectified flow setup
    time_step = jax.random.uniform(
        key, (batch["state_past"].shape[0], 1)
    )  # random value between 0 and 1

    # now contact the value to have the context for the rectified flow setup
    batch["context"] = jnp.concatenate([batch["reward_global"], time_step], axis=1)

    # now we generate the random noise for the rectified flow setup
    random_noise = jax.random.normal(key, batch["state_future"].shape)
    simplex_noise = jax.nn.softmax(random_noise, axis=-1)

    batch["state_future_noise"] = (
        time_step * simplex_noise + (1 - time_step) * batch["state_future"]
    )

    grad_fn = nnx.value_and_grad(loss_fn_transformer_rf, has_aux=True)
    (loss, (loss_crossentropy, loss_reward)), grads = grad_fn(model, batch)
    metrics.update(
        loss=loss, loss_reward=loss_reward, loss_cross_entropy=loss_crossentropy
    )
    optimizer.update(grads)

In [13]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer, buffer_list = dataset.fast_gathering_data(
    env,
    vmap_reset,
    vmap_step_proba,
    config.nb_games * 10,
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

In [13]:
# transformer model calibration
for idx_step in tqdm(range(config.nb_step)):
    # training for world model
    key, subkey = jax.random.split(config.jax_key)
    config.jax_key = key

    if idx_step % config.add_data_every_step == 0:
        buffer, buffer_list = dataset.fast_gathering_data(
            env,
            vmap_reset,
            vmap_step_proba,
            int(config.nb_games / 10.0),
            config.len_seq,
            buffer,
            buffer_list,
            config.jax_key,
        )

    sample = buffer.sample(buffer_list, subkey)
    sample = reshape_sample(sample)

    # we update the policy
    train_step_transformer(
        transformer, optimizer_worldmodel, metrics_train, sample.experience
    )

    if idx_step % config.log_every_step == 0:
        metrics_train_result = metrics_train.compute()
        print(metrics_train_result)

        wandb.log(metrics_train_result, step=idx_step)
        metrics_train.reset()

    if idx_step % config.log_eval_every_step == 0:
        key, subkey = jax.random.split(config.jax_key)
        config.jax_key = key

        buffer_eval, buffer_list_eval = dataset.fast_gathering_data(
            env,
            vmap_reset,
            vmap_step_proba,
            int(128),
            config.len_seq,
            buffer_eval,
            buffer_list_eval,
            subkey,
        )

        sample = buffer_eval.sample(buffer_list_eval, subkey)
        sample = reshape_sample(sample)

        loss, (loss_crossentropy, loss_reward) = loss_fn_transformer(
            transformer, sample.experience
        )

        metrics_eval.update(
            loss_eval=loss,
            loss_reward_eval=loss_reward,
            loss_cross_entropy_eval=loss_crossentropy,
        )
        wandb.log(metrics_eval.compute(), step=idx_step)

        metrics_eval.reset()

  9%|▉         | 89249/1000000 [4:37:23<26:27:33,  9.56it/s]

{'loss': Array(0.13090365, dtype=float32), 'loss_reward': Array(0.00194933, dtype=float32), 'loss_cross_entropy': Array(0.12895429, dtype=float32)}


  9%|▉         | 89259/1000000 [4:37:24<25:18:38, 10.00it/s]

{'loss': Array(0.13089804, dtype=float32), 'loss_reward': Array(0.00191141, dtype=float32), 'loss_cross_entropy': Array(0.12898663, dtype=float32)}


  9%|▉         | 89269/1000000 [4:37:26<29:32:48,  8.56it/s]

{'loss': Array(0.12980014, dtype=float32), 'loss_reward': Array(0.00193842, dtype=float32), 'loss_cross_entropy': Array(0.12786174, dtype=float32)}


  9%|▉         | 89280/1000000 [4:37:28<26:16:22,  9.63it/s]

{'loss': Array(0.13595587, dtype=float32), 'loss_reward': Array(0.00196807, dtype=float32), 'loss_cross_entropy': Array(0.13398778, dtype=float32)}


  9%|▉         | 89290/1000000 [4:37:30<26:22:50,  9.59it/s]

{'loss': Array(0.1410868, dtype=float32), 'loss_reward': Array(0.00197345, dtype=float32), 'loss_cross_entropy': Array(0.13911338, dtype=float32)}


  9%|▉         | 89300/1000000 [4:37:32<47:10:27,  5.36it/s]


{'loss': Array(0.13178296, dtype=float32), 'loss_reward': Array(0.00202771, dtype=float32), 'loss_cross_entropy': Array(0.12975524, dtype=float32)}


KeyboardInterrupt: 

In [None]:
buffer_eval, buffer_list_eval = dataset.fast_gathering_data(
    env,
    vmap_reset,
    vmap_step_proba,
    int(128),
    config.len_seq,
    buffer_eval,
    buffer_list_eval,
    subkey,
)

TrajectoryBufferSample(experience={'action': Array([[[1.32556781e-01, 7.96739519e-01, 5.36718592e-02, ...,
         3.91646661e-03, 4.48901858e-03, 9.91594553e-01],
        [3.49070907e-01, 4.57749265e-04, 4.38157976e-01, ...,
         7.23136306e-01, 1.23497941e-01, 1.53365776e-01],
        [6.12441264e-03, 2.50436477e-02, 1.35732419e-03, ...,
         3.82237613e-01, 5.98694921e-01, 1.90675538e-02],
        ...,
        [1.41329234e-04, 2.44877161e-03, 8.43136787e-01, ...,
         2.33344346e-01, 6.42170012e-01, 1.24485560e-01],
        [6.32655225e-04, 1.77795421e-02, 9.65278149e-01, ...,
         1.25269741e-02, 3.21629345e-01, 6.65843725e-01],
        [9.08881542e-04, 1.04175135e-01, 7.50824576e-04, ...,
         9.99683421e-03, 7.89827347e-01, 2.00175866e-01]],

       [[2.03237548e-01, 7.00179100e-01, 3.63819454e-05, ...,
         9.96583939e-01, 2.39940570e-03, 1.01662707e-03],
        [7.63220847e-01, 1.11325733e-01, 3.15520242e-02, ...,
         5.45369804e-01, 4.54322606e-0

In [None]:
sample = buffer_eval.sample(buffer_list_eval, subkey)
sample = reshape_sample(sample)

In [None]:
loss, (loss_crossentropy, loss_reward) = loss_fn_transformer(
    transformer, sample.experience
)

AttributeError: 'RubikTransformer' object has no attribute 'state_mapping'

In [14]:
# save buffer, buffer_list
# in pickle
import pickle

state_weight = nnx.state(transformer)

In [15]:
state_weight

State({
  'action_mapping': {
    'bias': VariableState(
      type=Param,
      value=Array([ 4.04955857e-02,  2.82790326e-02, -7.85927773e-02,  7.52160996e-02,
             -6.11112965e-03,  6.96982583e-03, -1.17343664e-02, -1.74523471e-03,
              1.28632234e-02, -7.83682019e-02,  2.75444221e-02, -4.00350802e-02,
              1.79233290e-02,  8.38570073e-02,  2.03401130e-02,  4.92124483e-02,
              8.69528428e-02,  2.20998153e-02, -3.42875794e-02, -6.76687211e-02,
              2.17811018e-02,  8.36544111e-02, -3.08539756e-02, -8.56901798e-03,
             -6.66830465e-02,  1.15918748e-01,  5.94779989e-03,  1.72799546e-02,
             -1.16014622e-01, -6.75882176e-02, -6.16184436e-02, -5.52975051e-02,
              4.17982265e-02, -5.43787293e-02,  1.19193546e-01, -9.40112211e-03,
             -4.03175130e-02, -3.47817354e-02, -6.77642366e-03,  8.85512680e-02,
              4.83243428e-02,  6.59283325e-02, -5.58541063e-03,  3.46533172e-02,
             -6.63223118e-02

In [16]:
# save state into pickle
with open("state_probainput_vscale5.pickle", "wb") as handle:
    pickle.dump(state_weight, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [17]:
state_weight

State({
  'action_mapping': {
    'bias': VariableState(
      type=Param,
      value=Array([ 4.04955857e-02,  2.82790326e-02, -7.85927773e-02,  7.52160996e-02,
             -6.11112965e-03,  6.96982583e-03, -1.17343664e-02, -1.74523471e-03,
              1.28632234e-02, -7.83682019e-02,  2.75444221e-02, -4.00350802e-02,
              1.79233290e-02,  8.38570073e-02,  2.03401130e-02,  4.92124483e-02,
              8.69528428e-02,  2.20998153e-02, -3.42875794e-02, -6.76687211e-02,
              2.17811018e-02,  8.36544111e-02, -3.08539756e-02, -8.56901798e-03,
             -6.66830465e-02,  1.15918748e-01,  5.94779989e-03,  1.72799546e-02,
             -1.16014622e-01, -6.75882176e-02, -6.16184436e-02, -5.52975051e-02,
              4.17982265e-02, -5.43787293e-02,  1.19193546e-01, -9.40112211e-03,
             -4.03175130e-02, -3.47817354e-02, -6.77642366e-03,  8.85512680e-02,
              4.83243428e-02,  6.59283325e-02, -5.58541063e-03,  3.46533172e-02,
             -6.63223118e-02

In [19]:
sample.experience.keys()

dict_keys(['action', 'action_pred', 'reward', 'state_first', 'state_next'])

In [None]:
state_pred_transformer, reward = transformer(
    sample.experience["state_first"], sample.experience["action_pred"]
)

In [None]:
proba = jax.nn.softmax(state_pred_transformer[0, 1, :].reshape(6, 3, 3, 6))
state_prediction = jnp.argmax(proba, axis=3)
print(state_prediction)

[[[5 1 3]
  [0 0 1]
  [0 2 0]]

 [[1 3 1]
  [4 1 4]
  [4 2 5]]

 [[2 4 2]
  [5 2 0]
  [3 0 4]]

 [[5 2 2]
  [2 3 3]
  [1 1 3]]

 [[1 3 4]
  [5 4 0]
  [2 4 0]]

 [[3 5 4]
  [3 5 1]
  [0 5 5]]]


In [None]:
sample.experience["state_next"][0, 0, :].reshape(6, 3, 3)

Array([[[5, 1, 3],
        [0, 0, 1],
        [0, 2, 0]],

       [[1, 3, 1],
        [4, 1, 4],
        [4, 2, 5]],

       [[2, 4, 2],
        [5, 2, 0],
        [3, 0, 4]],

       [[5, 2, 2],
        [2, 3, 3],
        [1, 1, 3]],

       [[1, 3, 4],
        [5, 4, 0],
        [2, 4, 0]],

       [[3, 5, 4],
        [3, 5, 1],
        [0, 5, 5]]], dtype=int8)