In [1]:
# Use the project checkout as the working directory (presumably so that the
# `rubiktransformer` package imported in the next cell resolves).
# NOTE(review): this is an absolute, machine-specific path; the notebook will
# fail on any other machine unless the path is adjusted.
import os 
os.chdir("/teamspace/studios/this_studio/rubikscubesolver")


In [2]:
import time
from dataclasses import dataclass, field

import flax.nnx as nnx
import jax
import jax.numpy as jnp
import optax
import wandb  # for logging
from tqdm import tqdm

import rubiktransformer.dataset as dataset
from rubiktransformer.models import RubikTransformer, PolicyModel
from rubiktransformer.trainer import train
from rubiktransformer.trainer import reshape_sample

cuda_plugin_extension is not found.


In [3]:
@dataclass
class Config:
    """Hyper-parameters and RNG state for the world-model training notebook.

    The two RNG attributes are built per instance through ``default_factory``
    instead of once at class-definition time, so separate ``Config()`` objects
    never share (and never mutate) the same RNG state.  ``rngs`` is a real
    dataclass field now (it used to be an un-annotated class attribute); it is
    declared last so the positional order of the pre-existing fields is
    unchanged.
    """

    # JAX PRNG key used for data gathering / buffer sampling; the notebook
    # replaces it with the carry key after every split.
    jax_key: jnp.ndarray = field(default_factory=lambda: jax.random.PRNGKey(46))
    # Batch size passed to the replay buffers.
    batch_size: int = 128
    # Base learning rates (lr_1: world model, lr_2: policy).
    lr_1: float = 4e-3
    lr_2: float = 4e-3
    # Number of games gathered per data-collection round (scaled by callers).
    nb_games: int = 128 * 100
    # Number of environment steps per gathered sequence.
    len_seq: int = 15
    # Number of optimisation steps.
    nb_step: int = 1000000
    # Logging / data refresh periods, in optimisation steps.
    log_every_step: int = 10
    log_eval_every_step: int = 10
    log_policy_reward_every_step: int = 10
    add_data_every_step: int = 500
    # flax.nnx RNG container handed to the model constructors.  The annotation
    # is a string so the class body does not need `nnx` at definition time.
    rngs: "nnx.Rngs" = field(default_factory=lambda: nnx.Rngs(45))

config = Config()

# Weights & Biases run identity: account, project and a timestamped run name.
user = "forbu14"
project = "RubikTransformer"
display_name = f"experiment_{time.strftime('%Y%m%d-%H%M%S')}"

wandb.init(name=display_name, project=project, entity=user)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mforbu14[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Instantiate the policy and the world model; both receive the same
# `config.rngs` object.
policy = PolicyModel(rngs=config.rngs)
transformer = RubikTransformer(rngs=config.rngs, causal=True)

# Linear warm-up factor going from 0 to 1 over the first 4000 steps.
scheduler = optax.linear_schedule(init_value=0., end_value=1., transition_steps=4000)

# World-model optimizer: clip gradients to a global norm of 1.0, apply Lion
# with learning rate lr_1 / 100, then scale the update by the warm-up factor.
optimizer_optaxworldmodel = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.lion(config.lr_1 / 100.),
   # optax.adamw(config.lr_1/10.),
    optax.scale_by_schedule(scheduler),
)

# nnx wrappers binding each optax transformation to its model.
optimizer_worldmodel = nnx.Optimizer(transformer, optimizer_optaxworldmodel)
optimizer_policy = nnx.Optimizer(policy, optax.adam(config.lr_2))

# Running averages of the training losses; the keyword names here are the
# keys the training step passes to `metrics.update(...)`.
metrics_train = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
    loss_reward=nnx.metrics.Average("loss_reward"),
    loss_cross_entropy=nnx.metrics.Average("loss_cross_entropy"),
)

# Same three averages for the evaluation losses, under `*_eval` names.
metrics_eval = nnx.MultiMetric(
    loss_eval=nnx.metrics.Average("loss_eval"),
    loss_reward_eval=nnx.metrics.Average("loss_reward_eval"),
    loss_cross_entropy_eval=nnx.metrics.Average("loss_cross_entropy_eval"),
)


In [11]:
# Create the environment plus one training and one evaluation buffer.
# Each call returns an environment; the second one replaces the first.
env, buffer = dataset.init_env_buffer(sample_batch_size=config.batch_size)
env, buffer_eval = dataset.init_env_buffer(sample_batch_size=config.batch_size)

nb_games, len_seq = config.nb_games, config.len_seq

# Zero-filled example arrays that fix the shape and dtype of a buffer item:
# states are int8, actions int32, rewards and action one-hots default float.
state_first = jnp.zeros((6, 3, 3), dtype=jnp.int8)
state_next = jnp.zeros((len_seq, 6, 3, 3), dtype=jnp.int8)
action = jnp.zeros((len_seq, 3), dtype=jnp.int32)
action_proba = jnp.zeros((len_seq, 9))
reward = jnp.zeros(len_seq)

# Jitted environment step, used by the rollout functions defined below.
jit_step = jax.jit(env.step)

# Both buffers are initialised from the same example item layout.
buffer_list = buffer.init(
    dict(
        state_first=state_first,
        action=action,
        reward=reward,
        state_next=state_next,
        action_pred=action_proba,
    )
)

buffer_list_eval = buffer_eval.init(
    dict(
        state_first=state_first,
        action=action,
        reward=reward,
        state_next=state_next,
        action_pred=action_proba,
    )
)


In [None]:

def step_fn(state, key):
    """Advance the environment by one step with a uniformly random action.

    Written as a `jax.lax.scan` body: `state` is the carry and `key` is the
    per-step PRNG key.

    Args:
        state: current environment state.
        key: PRNG key used to draw the 3-component action.

    Returns:
        `(new_state, timestep)`; the timestep's `extras` additionally carry
        the sampled `"action"` and an all-zero `"action_pred"` of size 9.
    """

    # `jax.random.randint` excludes `maxval`, while the action spec's
    # `maximum` is the largest *valid* value, hence the + 1.  Without it the
    # last value of every action component could never be sampled.
    action = jax.random.randint(
        key=key,
        minval=env.action_spec.minimum,
        maxval=env.action_spec.maximum + 1,
        shape=(3,),
    )

    new_state, timestep = jit_step(state, action)
    timestep.extras["action"] = action
    # Random rollouts carry no action distribution: store a zero placeholder.
    timestep.extras["action_pred"] = jnp.zeros((9,))

    return new_state, timestep

def run_n_steps(state, key, n):
    """Scan `step_fn` for `n` steps from `state` and return the stacked timesteps."""
    per_step_keys = jax.random.split(key, n)
    # The final carry (last environment state) is not needed by callers.
    _, trajectory = jax.lax.scan(step_fn, state, per_step_keys)
    return trajectory

# Batched reset: maps the jitted `env.reset` over a leading axis of its argument.
vmap_reset = jax.vmap(jax.jit(env.reset))
# Batched rollout: maps `run_n_steps` over axis 0 of the states and keys,
# while the step count `n` is shared by the whole batch (in_axes=None).
vmap_step = jax.vmap(run_n_steps, in_axes=(0, 0, None))

In [101]:
# Multiplier applied to the standard-normal draws in `step_fn_proba_setup`.
scale_factor = 10

def step_fn_proba_setup(state, data):
    """Advance the environment by one step with an action drawn from a random
    categorical distribution.

    Written as a `jax.lax.scan` body: `state` is the carry and `data` is the
    per-step PRNG key.  Random logits are drawn for the first action component
    (6 choices) and the third one (3 choices), one value is sampled from each
    resulting distribution, and the middle component is always 0.

    Args:
        state: current environment state.
        data: PRNG key for this step.

    Returns:
        `(new_state, timestep)`; the timestep's `extras` additionally carry
        the sampled `"action"` and `"action_pred"`, the concatenated one-hot
        encodings (6 + 3 = 9 values) of the two sampled components.
    """

    # Four independent keys: reusing one key for both the logits and the
    # categorical draw would correlate the two samples.
    logits_key_0, logits_key_1, sample_key_0, sample_key_1 = jax.random.split(data, 4)

    # Random logits defining the distribution over each sampled component.
    action_logits_0 = jax.random.normal(logits_key_0, shape=(6,)) * scale_factor
    action_logits_1 = jax.random.normal(logits_key_1, shape=(3,)) * scale_factor

    # `jax.random.categorical` expects unnormalised log-probabilities, so the
    # logits are passed directly; this samples from softmax(logits).  Passing
    # the softmax output instead would sample from softmax(softmax(logits)),
    # which is close to uniform.
    action_proba_0_value = jax.random.categorical(sample_key_0, action_logits_0)
    action_proba_1_value = jax.random.categorical(sample_key_1, action_logits_1)

    action = jnp.array([action_proba_0_value, jnp.array(0), action_proba_1_value])

    # One-hot encode the sampled components and concatenate them.
    action_proba = jnp.concatenate(
        [
            jax.nn.one_hot(action_proba_0_value, 6),
            jax.nn.one_hot(action_proba_1_value, 3),
        ]
    )

    new_state, timestep = jit_step(state, action)
    timestep.extras["action"] = action
    timestep.extras["action_pred"] = action_proba

    return new_state, timestep

def run_n_steps_proba(state, key, n):
    """Scan `step_fn_proba_setup` for `n` steps from `state` and return the stacked timesteps."""
    per_step_keys = jax.random.split(key, n)
    # The final carry (last environment state) is not needed by callers.
    _, trajectory = jax.lax.scan(step_fn_proba_setup, state, per_step_keys)
    return trajectory

# Batched reset: maps the jitted `env.reset` over a leading axis of its argument.
# NOTE(review): an identical `vmap_reset` is also defined in the previous cell.
vmap_reset = jax.vmap(jax.jit(env.reset))
# Batched rollout: maps `run_n_steps_proba` over axis 0 of the states and
# keys, while the step count `n` is shared by the whole batch (in_axes=None).
vmap_step_proba = jax.vmap(run_n_steps_proba, in_axes=(0, 0, None))

In [102]:
# Define the key in this cell: no earlier cell assigns `key`, so on a fresh
# kernel (Restart & Run All) the reset below would raise a NameError.
key = jax.random.PRNGKey(0)
# Single (un-batched) reset, used by the smoke test in the next cell.
state, timeStep = jax.jit(env.reset)(key)

In [None]:
# Smoke test: run a single un-batched 10-step rollout and display the result.
key = jax.random.PRNGKey(0)
run_n_steps_proba(state, key, 10)

In [22]:
# Render the world model's structure with the nnx display helper.
nnx.display(transformer)

jaxlib.xla_extension.ArrayImpl

In [12]:
# Advance the config's PRNG key and keep a one-off subkey for this gathering.
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

# Fill the training buffer with random rollouts (a tenth of `nb_games`).
buffer, buffer_list = dataset.fast_gathering_data(
    env, vmap_reset, vmap_step,
    config.nb_games // 10, config.len_seq,
    buffer, buffer_list,
    subkey,
)

In [28]:
# Draw a fresh subkey: the previous gathering cell already consumed `subkey`,
# and reusing it would feed the same randomness into this gathering.
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

# Add rollouts generated by `vmap_step_proba` to the training buffer.
# NOTE(review): the recorded run of this cell raised
# "TypeError: Only scalar arrays can be converted to Python scalars"; its
# origin is not visible in this notebook and is not addressed here.
buffer, buffer_list = dataset.fast_gathering_data(
    env,
    vmap_reset,
    vmap_step_proba,
    int(config.nb_games / 100),
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

TypeError: Only scalar arrays can be converted to Python scalars; got arr.ndim=1

In [18]:
# Scratch check: draw one uniform sample with `key` and convert the `< 0.05`
# comparison to a Python bool.
(jax.random.uniform(key) < 0.05).item()

False

In [24]:

def loss_fn_transformer(model: RubikTransformer, batch):
    """World-model loss: state cross-entropy plus squared reward error.

    Args:
        model: called as `model(batch["state_first"], batch["action"])`; must
            return `(state_logits, reward_value)`.
        batch: mapping with keys "state_first", "action", "state_next" and
            "reward".

    Returns:
        `(loss, (loss_crossentropy, loss_reward))` where `loss` is the sum of
        the two terms.
    """
    logits, predicted_reward = model(batch["state_first"], batch["action"])

    # Drop the first sequence position of both outputs, then unflatten the
    # last logits axis into (54, 6): one 6-way classification per position.
    logits = logits[:, 1:, :]
    predicted_reward = predicted_reward[:, 1:]
    n_batch, n_steps = logits.shape[0], logits.shape[1]
    logits = logits.reshape((n_batch, n_steps, 54, 6))

    per_position_ce = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch["state_next"]
    )
    loss_crossentropy = per_position_ce.mean()

    reward_error = predicted_reward - batch["reward"]
    loss_reward = jnp.square(reward_error).mean()

    return loss_crossentropy + loss_reward, (loss_crossentropy, loss_reward)

@nnx.jit
def train_step_transformer(
    model: RubikTransformer, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch
):
    """Run one optimisation step of the world model on `batch`.

    Computes the loss and its gradients, applies them through `optimizer`,
    and records the three loss values in `metrics`.  Returns nothing.
    """
    value_and_grad = nnx.value_and_grad(loss_fn_transformer, has_aux=True)
    (total, (cross_entropy, reward_mse)), grads = value_and_grad(model, batch)

    optimizer.update(grads)
    metrics.update(
        loss=total, loss_reward=reward_mse, loss_cross_entropy=cross_entropy
    )


In [None]:
# Pre-fill the training buffer with rollouts before optimisation starts.
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer, buffer_list = dataset.fast_gathering_data(
    env,
    vmap_reset,
    vmap_step,
    config.nb_games * 10,
    config.len_seq,
    buffer,
    buffer_list,
    subkey,
)

# transformer (world model) calibration
for idx_step in tqdm(range(config.nb_step)):

    # One independent key per random consumer of this iteration: `subkey`
    # for batch sampling, `gather_key` for data gathering.  The carry key
    # stored in `config.jax_key` is only ever split, never consumed directly.
    key, subkey, gather_key = jax.random.split(config.jax_key, 3)
    config.jax_key = key

    # Periodically add rollouts to the training buffer.
    if idx_step % config.add_data_every_step == 0:
        buffer, buffer_list = dataset.fast_gathering_data(
            env,
            vmap_reset,
            vmap_step,
            int(config.nb_games / 10.),
            config.len_seq,
            buffer,
            buffer_list,
            gather_key,
        )

    sample = buffer.sample(buffer_list, subkey)
    sample = reshape_sample(sample)

    # one optimisation step of the world model
    train_step_transformer(
        transformer, optimizer_worldmodel, metrics_train, sample.experience
    )

    if idx_step % config.log_every_step == 0:
        metrics_train_result = metrics_train.compute()
        print(metrics_train_result)

        wandb.log(metrics_train_result, step=idx_step)
        metrics_train.reset()

    if idx_step % config.log_eval_every_step == 0:

        # Separate keys for gathering the evaluation rollouts and for
        # sampling the evaluation batch.
        key, eval_gather_key, eval_sample_key = jax.random.split(config.jax_key, 3)
        config.jax_key = key

        # Gather into the *evaluation* buffer.  Passing the training buffer
        # here would leak evaluation rollouts into the training data and make
        # `buffer_eval` an alias of the training buffer.
        buffer_eval, buffer_list_eval = dataset.fast_gathering_data(
            env,
            vmap_reset,
            vmap_step,
            int(128),
            config.len_seq,
            buffer_eval,
            buffer_list_eval,
            eval_gather_key,
        )

        sample = buffer_eval.sample(buffer_list_eval, eval_sample_key)
        sample = reshape_sample(sample)

        # Loss only: no gradient step is taken on evaluation data.
        loss, (loss_crossentropy, loss_reward) = loss_fn_transformer(transformer, sample.experience)

        metrics_eval.update(loss_eval=loss, loss_reward_eval=loss_reward, loss_cross_entropy_eval=loss_crossentropy)
        wandb.log(metrics_eval.compute(), step=idx_step)

        metrics_eval.reset()



In [15]:
# Draw one batch from the evaluation buffer (reusing the current `subkey`)
# and reshape it with the project helper.
sample = buffer_eval.sample(buffer_list_eval, subkey)
sample = reshape_sample(sample)


In [16]:
# Evaluate the world-model loss on the sampled batch (no gradient step).
# NOTE(review): the recorded run of this cell raised an AttributeError about a
# missing `state_mapping` attribute; the cause is not visible in this notebook.
loss, (loss_crossentropy, loss_reward) = loss_fn_transformer(transformer, sample.experience)


AttributeError: 'RubikTransformer' object has no attribute 'state_mapping'

In [26]:
# Extract the transformer's state with `nnx.state`; it is pickled to disk in a
# later cell.
# NOTE(review): `pickle` is imported mid-notebook; consider moving it to the
# import cell at the top.
import pickle

state_weight = nnx.state(transformer)

In [27]:
# Display the extracted model state.
state_weight

State({
  'action_mapping': {
    'bias': VariableState(
      type=Param,
      value=Array([ 2.98494305e-02,  7.37091200e-03,  1.95764122e-03, -7.42887263e-04,
              4.47213883e-03, -1.97569397e-03, -1.05965240e-02,  2.91437353e-03,
             -2.21973169e-04, -9.10157617e-03,  2.28590495e-03, -1.90868392e-03,
              1.00215327e-03, -1.14400042e-02, -3.59990627e-05, -5.88784646e-03,
             -6.00204896e-03, -1.50419520e-02,  6.35023054e-04,  2.19244440e-03,
              9.63746384e-03,  6.40006363e-03,  9.55769233e-03, -1.84854679e-02,
             -1.29062552e-02,  7.26002501e-03,  4.44780141e-02, -6.89176784e-04,
             -1.76944733e-02,  5.26020071e-03,  1.60558335e-02, -1.72243211e-02,
              1.93193310e-03, -4.42172680e-03, -8.73702206e-03, -2.39823805e-03,
              1.03886090e-02,  3.90930893e-03, -1.02841277e-02,  9.13410354e-03,
              1.48430187e-03,  3.35732801e-03, -2.83761718e-03, -9.74433613e-04,
             -1.45434914e-02

In [28]:
# Serialise the extracted model state to disk with the newest pickle protocol.
with open('statev6.pickle', 'wb') as state_file:
    pickle.dump(state_weight, state_file, protocol=pickle.HIGHEST_PROTOCOL)

In [12]:
# Display the extracted model state again (the recorded output is an empty State).
state_weight

State({})

In [13]:
# Re-extract the state directly from the model to compare with `state_weight`.
nnx.state(transformer)

State({})

In [14]:

# Display the model object's repr.
transformer

RubikTransformer()