In [1]:
# chose the current file directory as the working directory
import os 
os.chdir("/teamspace/studios/this_studio/rubikscubesolver")


In [2]:
from tqdm import tqdm

import wandb  # for logging
import time
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.nnx as nnx

import optax

from rubiktransformer.models import RubikTransformer, PolicyModel
import rubiktransformer.dataset as dataset
from rubiktransformer.trainer import train
from rubiktransformer.trainer import reshape_sample

cuda_plugin_extension is not found.


In [3]:
@dataclass
class Config:
    """Configuration class"""
    jax_key: jnp.ndarray = jax.random.PRNGKey(45)
    rngs = nnx.Rngs(44)
    batch_size: int = 1024
    lr_1: float = 4e-3
    lr_2: float = 4e-3
    nb_games: int = 1024 * 100
    len_seq: int = 12
    nb_step: int = 1000000
    log_every_step: int = 10
    log_eval_every_step: int = 10
    log_policy_reward_every_step: int = 10
    add_data_every_step: int = 500
    log_true_model_reward_every_step: int = 50

config = Config()

# init wandb config
user = "forbu14"
project = "RubikTransformer"
display_name = "experiment_" + time.strftime("%Y%m%d-%H%M%S")

wandb.init(entity=user, project=project, name=display_name)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mforbu14[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
env, buffer = dataset.init_env_buffer(sample_batch_size=config.batch_size)

policy = PolicyModel(rngs=config.rngs, d_model=1024, temp=5.)
transformer = RubikTransformer(rngs=config.rngs, causal=True)

optimizer_policy = optax.chain(
    #optax.clip_by_global_norm(1.0),
    optax.adamw(config.lr_1 / 100.),
)


optimizer_policy = nnx.Optimizer(policy, optimizer_policy)

# metrics
metrics_train = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
    loss_reward=nnx.metrics.Average("loss_reward"),
    loss_cross_entropy=nnx.metrics.Average("loss_cross_entropy"),
)

metrics_eval = nnx.MultiMetric(
    loss_eval=nnx.metrics.Average("loss_eval"),
    loss_reward_eval=nnx.metrics.Average("loss_reward_eval"),
    loss_cross_entropy_eval=nnx.metrics.Average("loss_cross_entropy_eval"),
)


metrics_policy = nnx.MultiMetric(
    sum_reward_policy=nnx.metrics.Average("sum_reward_policy"),
)


In [5]:
# load weight from world model transformer:
import pickle

filename = "state_probainput_vscale5.pickle"

with open(filename, "rb") as input_file:
    state = pickle.load(input_file)

nnx.update(transformer, state)

In [6]:
state

State({
  'action_mapping': {
    'bias': VariableState(
      type=Param,
      value=Array([ 4.04955857e-02,  2.82790326e-02, -7.85927773e-02,  7.52160996e-02,
             -6.11112965e-03,  6.96982583e-03, -1.17343664e-02, -1.74523471e-03,
              1.28632234e-02, -7.83682019e-02,  2.75444221e-02, -4.00350802e-02,
              1.79233290e-02,  8.38570073e-02,  2.03401130e-02,  4.92124483e-02,
              8.69528428e-02,  2.20998153e-02, -3.42875794e-02, -6.76687211e-02,
              2.17811018e-02,  8.36544111e-02, -3.08539756e-02, -8.56901798e-03,
             -6.66830465e-02,  1.15918748e-01,  5.94779989e-03,  1.72799546e-02,
             -1.16014622e-01, -6.75882176e-02, -6.16184436e-02, -5.52975051e-02,
              4.17982265e-02, -5.43787293e-02,  1.19193546e-01, -9.40112211e-03,
             -4.03175130e-02, -3.47817354e-02, -6.77642366e-03,  8.85512680e-02,
              4.83243428e-02,  6.59283325e-02, -5.58541063e-03,  3.46533172e-02,
             -6.63223118e-02

In [7]:


nb_games = config.nb_games
len_seq = config.len_seq

vmap_reset = jax.vmap(jax.jit(env.reset))


In [8]:
def gather_data_policy(
    model_policy: PolicyModel,
    model_worldmodel: RubikTransformer,
    env,
    vmap_reset,
    batch_size,
    len_seq,
    key,):
    keys = jax.random.split(key, batch_size)
    state, timestep = vmap_reset(keys)

    one_hot = jax.nn.one_hot(state.cube, 6)
    state_first_policy = jnp.reshape(
        one_hot, (batch_size, 1, -1)
    )

    state_pred = jnp.copy(state_first_policy)
    action_list = None

    state_pred_list = []
    uniform0_list = []
    uniform1_list = []

    # Collect a batch of rollouts
    for i in range(len_seq):
        keys = jax.random.split(key, batch_size)
        key_uniform = jax.random.split(keys[0], 2)
        key = keys[1]
        
        # generate random values 
        # random_uniform0, random_uniform1
        # should be of size (batch_size, 6) and (batch_size, 3) 
        uniform0 = jax.random.uniform(key_uniform[0], (batch_size, 1, 6))
        uniform1 = jax.random.uniform(key_uniform[1], (batch_size, 1, 3))

        # apply the policy
        action_result = model_policy(state_pred, uniform0, uniform1)

        if action_list is None:
            action_list = action_result
        else:
            action_list = jnp.concatenate((action_list, action_result), axis=1)

        # save data into a list
        state_pred_list.append(state_pred)
        uniform0_list.append(uniform0)
        uniform1_list.append(uniform1)

        # now we can apply the world model to sample next state
        state_logits, reward = model_worldmodel(state_pred, action_list)

        # reshape then argmax
        state_logits = state_logits.reshape(
            (state_logits.shape[0], state_logits.shape[1], 54, 6)
        )

        state_pred = jnp.argmax(state_logits, axis=3)

        # onehot
        state_pred = jax.nn.one_hot(state_pred, 6)

        # shape to flatten
        state_pred = state_pred.reshape((state_pred.shape[0], state_pred.shape[1], -1))

        # take the last state
        state_pred = state_pred[:, -1, :]

        # add a dimension on axis 1
        state_pred = jnp.expand_dims(state_pred, axis=1)

    # here we create the dataset in a proper format
    state_pred_histo = jnp.concatenate(state_pred_list, axis=1)
    uniform0_histo = jnp.concatenate(uniform0_list, axis=1)
    uniform1_histo = jnp.concatenate(uniform1_list, axis=1)

    return state_pred_histo, uniform0_histo, uniform1_histo, action_list


key = jax.random.PRNGKey(48)

state_pred_histo, uniform0_histo, uniform1_histo, action_list = gather_data_policy(
    policy,
    transformer,
    env,
    vmap_reset,
    config.batch_size,
    config.len_seq,
    key,)


In [9]:
from rubiktransformer.dataset import *

# compute reward from true environment
vmap_reset = jax.vmap(jax.jit(env.reset))
jit_step = jax.vmap(jax.jit(env.step))

GOAL_OBSERVATION = jnp.zeros((6, 3, 3))
for i in range(6):
    GOAL_OBSERVATION = GOAL_OBSERVATION.at[i, :, :].set(i)


def compute_reward_custom(observation):
    """
    Here we compute the reward for a given observation
    the observation here is a 6x3x3 array with value between 0 and 5
    that define the observation of the rubik cube
    We want to check the distance between the observation and the goal
    the goal g is of size 6x3x3 with g[i, :, :] = i
    """
    if observation.shape == (6, 3, 3):
        return jnp.where(observation != GOAL_OBSERVATION, -1.0, 1.0).mean()
    elif len(observation.shape) == 4:
        # we repeat the goal_observation to match the shape of the observation
        goal_observation = jnp.repeat(
            GOAL_OBSERVATION[None, :, :, :], observation.shape[0], axis=0
        )
        return jnp.where(observation != goal_observation, -1.0, 1.0).mean(axis=[1, 2, 3])



def compute_reward_policy(key, batch_size, len_seq, with_reward_func=None):

    # create keys for the vmap
    keys = jax.random.split(key, batch_size)
    keys_step = jax.random.split(keys[0], 2)

    # one hot encoding

    state, timestep = vmap_reset(keys)


    reward_sum = 0

    for i in range(len_seq):
        # apply the policy

        one_hot = jax.nn.one_hot(state.cube, 6)
        state_first_policy = jnp.reshape(
            one_hot, (batch_size, -1)
        )

        state_pred = jnp.copy(state_first_policy)


        action_prob = policy(state_pred, None, None)

        action_proba_0 = action_prob[:, :6]
        action_proba_1 = action_prob[:, 6:9]

        action0 = jax.random.categorical(keys_step[0], action_proba_0)
        action1 = jax.random.categorical(keys_step[1], action_proba_1)

        action = jnp.stack((action0, jnp.zeros(batch_size), action1), axis=-1)

        # transform into int8
        action = jnp.int8(action)

        # now we can apply the true world model to sample next state
        state, timestep = jit_step(state, action)

        reward = compute_reward_custom(state.cube)

        if with_reward_func is None:
            reward_sum += reward.mean()
        else:
            reward_sum += with_reward_func(reward).mean()

    return reward_sum / len_seq




In [10]:
state_pred_histo, uniform0_histo, uniform1_histo, action_list = gather_data_policy(
    policy,
    transformer,
    env,
    vmap_reset,
    config.batch_size,
    config.len_seq,
    key,)

In [11]:
nnx.display(transformer)

In [12]:

def reward_hacking(reward):
    """
    reward is an array of value of shape (batch_size, len_seq, 1) with value between -1 and 1
    we want to apply to every element the funciton
    f(x) = 0.1 * jnp.exp(4 * x)
    """

    return 0.1 * jnp.exp(4. * reward)

def loss_fn_transformer_policy(model_policy: PolicyModel, model: RubikTransformer, batch):
    action_plan = model_policy(batch["states"], batch["uniform0"], batch["uniform1"])

    states_next, reward_value = model(batch["state_first"], action_plan) 

    # modify the reward learning dynamics (end goal is very important)
    reward_value = reward_hacking(reward_value)

    loss_reward = - (reward_value[:, 1:, :]).sum(axis=1).mean()

    loss = loss_reward

    return loss, (loss_reward)

@nnx.jit
def train_step_transformer_policy(
    model_policy: PolicyModel, model: RubikTransformer, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch
):
    """Train for a single step."""
    grad_fn = nnx.value_and_grad(loss_fn_transformer_policy, has_aux=True)
    (loss, (loss_reward)), grads = grad_fn(model_policy, model, batch)
    metrics.update(
        sum_reward_policy=loss
    )
    optimizer.update(grads)


In [13]:
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

# transformer model calibration
for idx_step in tqdm(range(15000)):
    # gather data from policy :
    key, subkey = jax.random.split(config.jax_key)
    config.jax_key = key
    
    state_pred_histo, uniform0_histo, uniform1_histo, action_list = gather_data_policy(
        policy,
        transformer,
        env,
        vmap_reset,
        config.batch_size,
        config.len_seq,
        config.jax_key)

    batch = {
        "states": state_pred_histo,
        "uniform0": uniform0_histo,
        "uniform1": uniform1_histo,
        "state_first": state_pred_histo[:, 0, :],
    }

    batch["state_first"] = jnp.expand_dims(batch["state_first"], axis=1)

    train_step_transformer_policy(
        policy,
        transformer,
        optimizer_policy,
        metrics_policy,
        batch
    )

    if idx_step % config.log_policy_reward_every_step == 0:
        result_metrics = metrics_policy.compute()

        wandb.log(result_metrics, step=idx_step)

        metrics_policy.reset()

    if idx_step % config.log_true_model_reward_every_step == 0:
        # now we can log the reward for true world model
        key, subkey = jax.random.split(config.jax_key)
        config.jax_key = key

        reward_true_model = compute_reward_policy(key, config.batch_size, config.len_seq + 10, with_reward_func=reward_hacking)

        wandb.log({"reward_true_model": reward_true_model}, step=idx_step)



  0%|          | 0/15000 [00:00<?, ?it/s]

 31%|███       | 4618/15000 [1:31:42<3:25:54,  1.19s/it]

In [41]:
compute_reward_policy(key, config.batch_size, config.len_seq, with_reward_func=reward_hacking)

Array(1.0793765e-10, dtype=float32)

In [14]:
transformer.train()

action_plan = policy(batch["states"], batch["uniform0"], batch["uniform1"])
states_next, reward_value = transformer(batch["state_first"], action_plan) 

In [15]:
print(action_plan[0, 0, :])

[4.3580148e-01 4.5662463e-01 1.9980942e-04 1.0158302e-01 2.5370296e-03
 3.2539819e-03 1.1360981e-02 8.6222529e-01 1.2641373e-01]


In [21]:
reward_value[4, 1:, :]

Array([[-0.343951  ],
       [-0.16684626],
       [-0.31941935],
       [-0.35973513],
       [-0.5181925 ],
       [-0.43773872],
       [-0.36427772],
       [-0.364562  ],
       [-0.36198646],
       [-0.35279733],
       [-0.3900961 ],
       [-0.3427981 ]], dtype=float32)

In [19]:
action_plan[0, 0]

Array([1.5827626e-09, 3.5747697e-05, 3.9419392e-06, 9.9820602e-07,
       9.9995935e-01, 2.2563261e-10, 7.2900742e-01, 2.7099249e-01,
       5.1224045e-08], dtype=float32)

In [45]:
jnp.where(reward_hacking(reward_value[:, 1:, :]) == reward_hacking(reward_value[:, 1:, :]).max())

(Array([555], dtype=int32), Array([1], dtype=int32), Array([0], dtype=int32))

In [46]:
reward_hacking(reward_value[555, 1:, :])

Array([[0.06450409],
       [0.17188816],
       [0.04082082],
       [0.04273823],
       [0.01966814],
       [0.02531038],
       [0.03303149],
       [0.01994882],
       [0.01997759],
       [0.03321796],
       [0.02282799],
       [0.01893796]], dtype=float32)

In [76]:
init_result = jnp.argmax(batch["state_first"][41, 0, :].reshape(54, 6), axis=1).reshape(6, 3, 3)

reward = jnp.where(init_result != dataset.GOAL_OBSERVATION, -1.0, 1.0)


reward_hacking(reward.mean())

Array(0.01960022, dtype=float32)

In [50]:
reward_hacking(reward.mean())

Array(0.02273007, dtype=float32)

In [47]:
jnp.argmax(batch["state_first"][555, 0, :].reshape(54, 6), axis=1).reshape((6, 3, 3))


Array([[[5, 2, 0],
        [0, 0, 2],
        [4, 1, 3]],

       [[1, 4, 4],
        [1, 1, 3],
        [3, 1, 4]],

       [[0, 0, 1],
        [5, 2, 0],
        [0, 1, 3]],

       [[2, 5, 2],
        [3, 3, 2],
        [5, 3, 0]],

       [[1, 4, 5],
        [3, 4, 5],
        [3, 4, 5]],

       [[4, 2, 1],
        [5, 5, 0],
        [2, 4, 2]]], dtype=int32)

In [49]:
jnp.argmax(states_next[555, 2, :].reshape(54, 6), axis=1).reshape((6, 3, 3))

Array([[[0, 0, 5],
        [5, 0, 2],
        [0, 2, 0]],

       [[1, 0, 1],
        [4, 1, 1],
        [1, 1, 3]],

       [[2, 5, 2],
        [2, 2, 0],
        [4, 1, 3]],

       [[1, 4, 3],
        [3, 3, 2],
        [5, 3, 0]],

       [[4, 3, 4],
        [3, 4, 1],
        [3, 4, 4]],

       [[5, 5, 5],
        [5, 5, 0],
        [2, 4, 2]]], dtype=int32)

In [142]:
reward_hacking(-0.4444)

Array(0.01690434, dtype=float32, weak_type=True)

In [106]:
jax.nn.softmax(states_next[0, 1, :].reshape((54, 6)))[1, :]

Array([3.9461483e-08, 1.9702059e-06, 1.0005269e-04, 9.9987459e-01,
       2.8885726e-12, 2.3303615e-05], dtype=float32)

In [85]:
transformer.transformer

List(
  0=TransformerBlock(
    causal=True,
    dropout=Dropout(rate=0.05, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(
      default=RngStream(
        count=RngCount(
          tag='default',
          value=Array(786031, dtype=uint32)
        ),
        key=RngKey(
          tag='default',
          value=Array((), dtype=key<fry>) overlaying:
          [ 0 45]
        )
      )
    )),
    feedforward=FeedForward(
      linear1=Linear(
        bias=Param(
          value=Array(shape=(1024,), dtype=float32)
        ),
        bias_init=<function zeros at 0x7f7ef8f0b7f0>,
        dot_general=<function dot_general at 0x7f7ef9447910>,
        dtype=None,
        in_features=512,
        kernel=Param(
          value=Array(shape=(512, 1024), dtype=float32)
        ),
        kernel_init=<function variance_scaling.<locals>.init at 0x7f7ef874c040>,
        out_features=1024,
        param_dtype=<class 'jax.numpy.float32'>,
        precision=None,
        us