In [106]:
# chose the current file directory as the working directory
import os 
os.chdir("/teamspace/studios/this_studio/rubikscubesolver")


In [107]:
from tqdm import tqdm

import wandb  # for logging
import time
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.nnx as nnx

import optax

from rubiktransformer.models import RubikTransformer, PolicyModel
import rubiktransformer.dataset as dataset
from rubiktransformer.trainer import train
from rubiktransformer.trainer import reshape_sample

In [108]:
@dataclass
class Config:
    """Configuration class"""
    jax_key: jnp.ndarray = jax.random.PRNGKey(46)
    rngs = nnx.Rngs(45)
    batch_size: int = 128
    lr_1: float = 4e-3
    lr_2: float = 4e-3
    nb_games: int = 128 * 100
    len_seq: int = 15
    nb_step: int = 1000000
    log_every_step: int = 10
    log_eval_every_step: int = 10
    log_policy_reward_every_step: int = 10
    add_data_every_step: int = 500

config = Config()

# init wandb config
user = "forbu14"
project = "RubikTransformer"
display_name = "experiment_" + time.strftime("%Y%m%d-%H%M%S")

wandb.init(entity=user, project=project, name=display_name)


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,█▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
loss_cross_entropy,█▅▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
loss_cross_entropy_eval,█▅▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
loss_eval,█▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
loss_reward,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_reward_eval,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,1.14226
loss_cross_entropy,1.13461
loss_cross_entropy_eval,1.1382
loss_eval,1.14528
loss_reward,0.00764
loss_reward_eval,0.00708


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111218802220214, max=1.0)…

In [109]:
policy = PolicyModel(rngs=config.rngs)
transformer = RubikTransformer(rngs=config.rngs, causal=True)

scheduler = optax.linear_schedule(init_value=0., end_value=1., transition_steps=4000)

# init optimizer
optimizer_optaxworldmodel = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.lion(config.lr_1 / 100.),
   # optax.adamw(config.lr_1/10.),
    optax.scale_by_schedule(scheduler),
)

optimizer_worldmodel = nnx.Optimizer(transformer, optimizer_optaxworldmodel)
optimizer_policy = nnx.Optimizer(policy, optax.adam(config.lr_2))

# metrics
metrics_train = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
    loss_reward=nnx.metrics.Average("loss_reward"),
    loss_cross_entropy=nnx.metrics.Average("loss_cross_entropy"),
)

metrics_eval = nnx.MultiMetric(
    loss_eval=nnx.metrics.Average("loss_eval"),
    loss_reward_eval=nnx.metrics.Average("loss_reward_eval"),
    loss_cross_entropy_eval=nnx.metrics.Average("loss_cross_entropy_eval"),
)


In [110]:
# gather data from the environment
# init models and optimizers
env, buffer = dataset.init_env_buffer(sample_batch_size=config.batch_size)
env, buffer_eval = dataset.init_env_buffer(sample_batch_size=config.batch_size)


nb_games = config.nb_games
len_seq = config.len_seq

state_first = jnp.zeros((6, 3, 3))
state_next = jnp.zeros((len_seq, 6, 3, 3))
action = jnp.zeros((len_seq, 3))

# transform state to int8 type
state_first = state_first.astype(jnp.int8)
state_next = state_next.astype(jnp.int8)

# action to int32 type
action = action.astype(jnp.int32)

reward = jnp.zeros((len_seq))

jit_step = jax.jit(env.step)

buffer_list = buffer.init(
    {
        "state_first": state_first,
        "action": action,
        "reward": reward,
        "state_next": state_next,
    }
)

buffer_list_eval = buffer_eval.init(
    {
        "state_first": state_first,
        "action": action,
        "reward": reward,
        "state_next": state_next,
    }
)

def step_fn(state, key):
    action = jax.random.randint(
        key=key,
        minval=env.action_spec.minimum,
        maxval=env.action_spec.maximum,
        shape=(3,),
    )

    new_state, timestep = jit_step(state, action)
    timestep.extras["action"] = action

    return new_state, timestep

def run_n_steps(state, key, n):
    random_keys = jax.random.split(key, n)
    state, rollout = jax.lax.scan(step_fn, state, random_keys)

    return rollout

vmap_reset = jax.vmap(jax.jit(env.reset))
vmap_step = jax.vmap(run_n_steps, in_axes=(0, 0, None))

In [111]:
nnx.display(transformer)

In [112]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer, buffer_list = dataset.fast_gathering_data(
    env,
    vmap_reset,
    vmap_step,
    config.nb_games,
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

In [113]:

def loss_fn_transformer(model: RubikTransformer, batch):
    state_logits, reward_value = model(batch["state_first"], batch["action"])

    # reshape state_logits
    # from (batch_size, sequence_length, 324) => (batch_size, sequence_length -1, 54, 6)
    state_logits = state_logits[:, 1:, :]
    state_logits = state_logits.reshape(
        (state_logits.shape[0], state_logits.shape[1], 54, 6)
    )

    reward_value = reward_value[:, 1:]

    loss_crossentropy = optax.softmax_cross_entropy_with_integer_labels(
        logits=state_logits, labels=batch["state_next"]
    ).mean()

    loss_reward = jnp.square(reward_value - batch["reward"]).mean()

    loss = loss_crossentropy + loss_reward

    return loss, (loss_crossentropy, loss_reward)

@nnx.jit
def train_step_transformer(
    model: RubikTransformer, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch
):
    """Train for a single step."""
    grad_fn = nnx.value_and_grad(loss_fn_transformer, has_aux=True)
    (loss, (loss_crossentropy, loss_reward)), grads = grad_fn(model, batch)
    metrics.update(
        loss=loss, loss_reward=loss_reward, loss_cross_entropy=loss_crossentropy
    )
    optimizer.update(grads)


In [114]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer, buffer_list = dataset.fast_gathering_data(
    env,
    vmap_reset,
    vmap_step,
    config.nb_games * 10,
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

# transformer model calibration
for idx_step in tqdm(range(config.nb_step)):

    # training for world model
    key, subkey = jax.random.split(config.jax_key)
    config.jax_key = key

    if idx_step % config.add_data_every_step == 0:
        buffer, buffer_list = dataset.fast_gathering_data(
            env,
            vmap_reset,
            vmap_step,
            int(config.nb_games / 10.),
            config.len_seq,
            buffer,
            buffer_list,
            config.jax_key,
        )

    sample = buffer.sample(buffer_list, subkey)
    sample = reshape_sample(sample)

    # we update the policy
    train_step_transformer(
        transformer, optimizer_worldmodel, metrics_train, sample.experience
    )

    if idx_step % config.log_every_step == 0:
        metrics_train_result = metrics_train.compute()
        print(metrics_train_result)

        wandb.log(metrics_train_result, step=idx_step)
        metrics_train.reset()

    if idx_step % config.log_eval_every_step == 0:

        key, subkey = jax.random.split(config.jax_key)
        config.jax_key = key
        
        buffer_eval, buffer_list_eval = dataset.fast_gathering_data(
            env,
            vmap_reset,
            vmap_step,
            int(128),
            config.len_seq,
            buffer,
            buffer_list,
            subkey,
        )

        sample = buffer_eval.sample(buffer_list_eval, subkey)
        sample = reshape_sample(sample)

        loss, (loss_crossentropy, loss_reward) = loss_fn_transformer(transformer, sample.experience)

        metrics_eval.update(loss_eval=loss, loss_reward_eval=loss_reward, loss_cross_entropy_eval=loss_crossentropy)
        wandb.log(metrics_eval.compute(), step=idx_step)

        metrics_eval.reset()



  4%|▎         | 37254/1000000 [1:45:37<43:52:08,  6.10it/s]

  4%|▎         | 37258/1000000 [1:45:38<29:24:26,  9.09it/s]

{'loss': Array(0.7982044, dtype=float32), 'loss_reward': Array(0.00522177, dtype=float32), 'loss_cross_entropy': Array(0.7929827, dtype=float32)}


  4%|▎         | 37269/1000000 [1:45:39<26:16:40, 10.18it/s]

{'loss': Array(0.78816444, dtype=float32), 'loss_reward': Array(0.00506678, dtype=float32), 'loss_cross_entropy': Array(0.78309757, dtype=float32)}


  4%|▎         | 37279/1000000 [1:45:41<36:17:53,  7.37it/s]

{'loss': Array(0.78788036, dtype=float32), 'loss_reward': Array(0.00485121, dtype=float32), 'loss_cross_entropy': Array(0.78302914, dtype=float32)}


  4%|▎         | 37290/1000000 [1:45:43<45:29:36,  5.88it/s]

{'loss': Array(0.78979343, dtype=float32), 'loss_reward': Array(0.00509722, dtype=float32), 'loss_cross_entropy': Array(0.7846963, dtype=float32)}





KeyboardInterrupt: 

In [53]:
sample = buffer_eval.sample(buffer_list_eval, subkey)


Array([[[[[4, 3, 2],
          [2, 0, 4],
          [4, 4, 5]],

         [[3, 5, 3],
          [2, 1, 1],
          [5, 3, 3]],

         [[2, 1, 0],
          [5, 2, 4],
          [4, 3, 3]],

         [[1, 4, 1],
          [0, 3, 1],
          [2, 3, 1]],

         [[0, 5, 0],
          [0, 4, 1],
          [5, 0, 2]],

         [[1, 2, 5],
          [2, 5, 5],
          [4, 0, 0]]],


        [[[4, 3, 3],
          [2, 0, 1],
          [4, 4, 3]],

         [[3, 5, 5],
          [2, 1, 5],
          [5, 3, 0]],

         [[4, 5, 2],
          [3, 2, 1],
          [3, 4, 0]],

         [[5, 4, 1],
          [4, 3, 1],
          [2, 3, 1]],

         [[0, 5, 0],
          [0, 4, 1],
          [5, 0, 2]],

         [[1, 2, 2],
          [2, 5, 0],
          [4, 0, 1]]],


        [[[3, 3, 3],
          [2, 0, 1],
          [5, 4, 3]],

         [[1, 5, 5],
          [2, 1, 5],
          [4, 3, 0]],

         [[4, 5, 2],
          [3, 2, 1],
          [3, 4, 0]],

         [[5, 4, 4],


In [121]:
# save buffer, buffer_list
# in pickle 
import pickle

state_weight = nnx.state(transformer)

In [122]:
# save state into pickle
with open('statev5.pickle', 'wb') as handle:
    pickle.dump(state_weight, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [28]:
state_weight

State({
  'action_mapping': {
    'bias': VariableState(
      type=Param,
      value=Array([ 1.02062002e-02,  8.40116176e-04,  4.30170354e-03, -2.32206844e-03,
             -3.86462361e-03, -3.70370690e-03, -8.79005529e-03,  8.39715451e-03,
              7.43274577e-03, -5.95704373e-03,  5.41053806e-03, -4.66044992e-03,
             -5.95248491e-03,  1.94322341e-03,  5.30058471e-03, -8.62675734e-06,
              1.46222953e-03, -4.82380530e-03,  6.11663517e-03, -4.79279645e-03,
             -9.01755528e-04,  1.05155641e-02,  3.69963143e-03, -4.41583851e-03,
             -1.33020673e-02, -1.03268318e-03, -2.12690933e-03, -4.83895745e-03,
             -6.36680424e-03,  8.21062829e-03,  1.37141149e-03, -7.68418843e-03,
              1.33816362e-03,  1.21833570e-03, -6.97358511e-04, -5.94444433e-03,
              1.58518285e-03,  1.02888222e-03, -1.85307261e-04,  6.62476011e-03,
             -7.09831342e-03, -5.51314838e-03,  3.09163728e-03, -3.30520235e-03,
             -8.18780065e-03