In [97]:
# Use the project checkout as the working directory so that relative paths
# (such as the pickled weights file opened further down) resolve against it.
# NOTE(review): this is a hard-coded absolute path, not the notebook's own
# directory - adjust it when running on another machine.
import os 
os.chdir("/teamspace/studios/this_studio/rubikscubesolver")


In [98]:
from tqdm import tqdm

import wandb  # for logging
import time
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.nnx as nnx

import optax

from rubiktransformer.models import RubikTransformer, PolicyModel
import rubiktransformer.dataset as dataset
from rubiktransformer.trainer import train
from rubiktransformer.trainer import reshape_sample

In [99]:
@dataclass
class Config:
    """Run settings for the policy-training notebook.

    Every annotated name below is a dataclass field and can be overridden
    through the constructor. ``rngs`` has no annotation, so ``dataclass``
    leaves it as a plain class attribute shared by all instances.
    """
    # JAX PRNG key; the default is built once, when the class body is evaluated.
    jax_key: jnp.ndarray = jax.random.PRNGKey(45)
    # Flax NNX random stream passed to the model constructors.
    rngs = nnx.Rngs(44)
    # Number of rollouts gathered and trained on per step.
    batch_size: int = 128
    # Base learning rates (the policy optimiser uses lr_1 / 100).
    lr_1: float = 0.004
    lr_2: float = 0.004
    # 100 batches of 128 games.
    nb_games: int = 100 * 128
    # Number of policy steps per rollout.
    len_seq: int = 10
    # Number of iterations of the training loop.
    nb_step: int = 1_000_000
    # Logging / data-refresh cadences, expressed in training steps.
    log_every_step: int = 10
    log_eval_every_step: int = 10
    log_policy_reward_every_step: int = 10
    add_data_every_step: int = 500

config = Config()

# Weights & Biases run: each execution of this cell starts a run whose name
# carries the current local timestamp.
user = "forbu14"
project = "RubikTransformer"
display_name = f"experiment_{time.strftime('%Y%m%d-%H%M%S')}"

wandb.init(name=display_name, project=project, entity=user)


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112274100014474, max=1.0…

In [100]:
env, buffer = dataset.init_env_buffer(sample_batch_size=config.batch_size)

# Both models receive the same `config.rngs` object; the policy is built first.
policy = PolicyModel(rngs=config.rngs)
transformer = RubikTransformer(rngs=config.rngs, causal=True)

# Policy optimiser: clip gradients to a global norm of 1.0, then apply Lion
# with a learning rate of lr_1 / 100. Only `policy` is registered with it.
optimizer_policy = nnx.Optimizer(
    policy,
    optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.lion(config.lr_1 / 100.0),
    ),
)


def _average_metrics(*names):
    """Return an ``nnx.MultiMetric`` with one ``Average`` per name, keyed by that name."""
    return nnx.MultiMetric(**{name: nnx.metrics.Average(name) for name in names})


# Running averages, grouped by where they are reported.
metrics_train = _average_metrics("loss", "loss_reward", "loss_cross_entropy")

metrics_eval = _average_metrics(
    "loss_eval", "loss_reward_eval", "loss_cross_entropy_eval"
)

metrics_policy = _average_metrics("sum_reward_policy")


In [101]:
# Load pretrained weights for the world-model transformer.
import pickle

# Relative path: resolved against the current working directory.
filename = "statev5.pickle"

# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only load checkpoint files that come from a trusted source.
with open(filename, "rb") as input_file:
    state = pickle.load(input_file)

# Copy the loaded state into the freshly constructed transformer.
# NOTE(review): the model is not switched to eval mode here; `transformer.eval()`
# is only called in a later inspection cell - confirm whether dropout is meant
# to stay active while the policy is trained against this model.
nnx.update(transformer, state)

In [102]:


# Convenience aliases for two config values.
nb_games = config.nb_games
len_seq = config.len_seq

# Batched reset: maps the jitted `env.reset` over a leading axis of PRNG keys.
vmap_reset = jax.vmap(jax.jit(env.reset))


In [103]:
def gather_data_policy(
    model_policy: PolicyModel,
    model_worldmodel: RubikTransformer,
    env,
    vmap_reset,
    batch_size,
    len_seq,
    key,):
    """Roll the policy out inside the world model and record its inputs.

    Starting from ``batch_size`` freshly reset environments, each step feeds
    the current state and two uniform noise tensors to the policy, appends the
    resulting action to the action history, and asks the world model for the
    state reached by applying the whole history to the *first* state (the same
    calling convention the policy loss uses).

    Args:
        model_policy: callable ``(state, uniform0, uniform1) -> action`` whose
            output carries the step on axis 1.
        model_worldmodel: callable ``(first_state, actions) -> (state_logits,
            reward)``; the last position of ``state_logits`` on axis 1 is read
            as 54 * 6 logits.
        env: unused; kept so existing call sites keep working.
        vmap_reset: batched reset taking ``batch_size`` PRNG keys and returning
            ``(state, timestep)``, where ``state.cube`` holds integer labels
            that are one-hot encoded with 6 classes.
        batch_size: number of parallel rollouts.
        len_seq: number of policy steps per rollout (must be >= 1).
        key: JAX PRNG key; consumed by this call and must not be reused.

    Returns:
        Tuple ``(state_pred_histo, uniform0_histo, uniform1_histo, action_list)``
        where each element is the per-step values concatenated along axis 1:
        the states shown to the policy, the ``(batch_size, len_seq, 6)`` and
        ``(batch_size, len_seq, 3)`` noise tensors, and the policy outputs.

    Raises:
        ValueError: if ``len_seq`` is smaller than 1.
    """
    if len_seq < 1:
        raise ValueError("len_seq must be at least 1")

    # Derive the reset keys from their own subkey so they never coincide with
    # the keys used for the per-step noise below.
    key, key_reset = jax.random.split(key)
    state, _ = vmap_reset(jax.random.split(key_reset, batch_size))

    # One-hot encode the cube and flatten it to (batch_size, 1, features).
    state_first = jnp.reshape(jax.nn.one_hot(state.cube, 6), (batch_size, 1, -1))

    state_pred = state_first
    action_list = None

    state_pred_list = []
    uniform0_list = []
    uniform1_list = []

    for step in range(len_seq):
        # One fresh key per noise tensor plus the carry for the next step.
        key, key_uniform0, key_uniform1 = jax.random.split(key, 3)

        uniform0 = jax.random.uniform(key_uniform0, (batch_size, 1, 6))
        uniform1 = jax.random.uniform(key_uniform1, (batch_size, 1, 3))

        # Ask the policy for the next action given the current state.
        action_result = model_policy(state_pred, uniform0, uniform1)

        if action_list is None:
            action_list = action_result
        else:
            action_list = jnp.concatenate((action_list, action_result), axis=1)

        state_pred_list.append(state_pred)
        uniform0_list.append(uniform0)
        uniform1_list.append(uniform1)

        # The state that would follow the last action is never recorded, so
        # the world-model call can be skipped on the final step.
        if step == len_seq - 1:
            break

        # The action history starts at the first state, so that is the state
        # the world model must be conditioned on (not the latest prediction,
        # which would apply the earlier actions a second time).
        state_logits, _ = model_worldmodel(state_first, action_list)

        # Keep only the newest position, pick the most likely class for each
        # of the 54 entries, and re-encode as a flat one-hot state.
        last_logits = state_logits[:, -1, :].reshape((-1, 54, 6))
        sticker_ids = jnp.argmax(last_logits, axis=-1)
        state_pred = jax.nn.one_hot(sticker_ids, 6).reshape(
            (sticker_ids.shape[0], 1, -1)
        )

    # Stack the per-step records along the sequence axis.
    state_pred_histo = jnp.concatenate(state_pred_list, axis=1)
    uniform0_histo = jnp.concatenate(uniform0_list, axis=1)
    uniform1_histo = jnp.concatenate(uniform1_list, axis=1)

    return state_pred_histo, uniform0_histo, uniform1_histo, action_list


# Fixed key for a one-off rollout that checks the data-gathering function runs.
key = jax.random.PRNGKey(48)

state_pred_histo, uniform0_histo, uniform1_histo, action_list = gather_data_policy(
    model_policy=policy,
    model_worldmodel=transformer,
    env=env,
    vmap_reset=vmap_reset,
    batch_size=config.batch_size,
    len_seq=config.len_seq,
    key=key,
)


In [104]:
# Same rollout again, with the unchanged `key`.
state_pred_histo, uniform0_histo, uniform1_histo, action_list = gather_data_policy(
    model_policy=policy,
    model_worldmodel=transformer,
    env=env,
    vmap_reset=vmap_reset,
    batch_size=config.batch_size,
    len_seq=config.len_seq,
    key=key,
)

In [105]:
# Show the structure of the world model after the weights were loaded.
nnx.display(transformer)

In [106]:

def reward_hacking(reward, scale=0.1, sharpness=4.0):
    """Exponentially reshape rewards so that high values dominate.

    Applies ``f(x) = scale * exp(sharpness * x)`` to every element. With the
    default arguments this is ``0.1 * exp(4 * x)``.

    Args:
        reward: array of rewards. The policy loss passes the world model's
            reward output, described by the original author as having shape
            (batch_size, len_seq, 1) with values between -1 and 1.
        scale: multiplicative factor applied after the exponential.
        sharpness: factor applied to the reward inside the exponential; larger
            values increase the gap between high and low rewards.

    Returns:
        Array with the same shape as ``reward`` holding the reshaped values.
    """
    return scale * jnp.exp(sharpness * reward)

def loss_fn_transformer_policy(model_policy: PolicyModel, model: RubikTransformer, batch):
    """Policy loss: negative reshaped reward predicted by the world model.

    Args:
        model_policy: policy applied to ``batch["states"]`` together with the
            noise tensors ``batch["uniform0"]`` and ``batch["uniform1"]``.
        model: world model called with ``batch["state_first"]`` and the
            policy's action plan; only its reward output is used.
        batch: mapping with the keys ``"states"``, ``"uniform0"``,
            ``"uniform1"`` and ``"state_first"``.

    Returns:
        ``(loss, loss_reward)``; both entries hold the same scalar.
    """
    planned_actions = model_policy(batch["states"], batch["uniform0"], batch["uniform1"])

    # The predicted states are not needed for this objective.
    _, predicted_reward = model(batch["state_first"], planned_actions)

    # Reshape the rewards so that the largest ones dominate the objective.
    shaped_reward = reward_hacking(predicted_reward)

    # Skip position 0 on the sequence axis, sum over the remaining positions,
    # average what is left, and negate so that minimising maximises reward.
    summed_over_steps = shaped_reward[:, 1:, :].sum(axis=1)
    loss_reward = -summed_over_steps.mean()

    return loss_reward, loss_reward

@nnx.jit
def train_step_transformer_policy(
    model_policy: PolicyModel, model: RubikTransformer, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch
):
    """Run one optimisation step of the policy.

    The loss is differentiated with ``nnx.value_and_grad`` (its default
    differentiates with respect to the first argument, i.e. the policy), the
    gradients are handed to ``optimizer`` and the loss value is recorded.

    Note that the value stored under ``sum_reward_policy`` is the loss itself,
    i.e. the negated reshaped reward.
    """
    loss_and_grad = nnx.value_and_grad(loss_fn_transformer_policy, has_aux=True)
    (loss_value, _), policy_grads = loss_and_grad(model_policy, model, batch)

    optimizer.update(policy_grads)
    metrics.update(sum_reward_policy=loss_value)


In [107]:
# Policy optimisation loop: gather a rollout with the current policy inside
# the world model, then take one gradient step on the policy.
for idx_step in tqdm(range(config.nb_step)):
    # Advance the stored key and reserve a fresh subkey for this rollout, so
    # the key given to the rollout is never the one that gets split again on
    # the next iteration.
    key, subkey = jax.random.split(config.jax_key)
    config.jax_key = key

    state_pred_histo, uniform0_histo, uniform1_histo, action_list = gather_data_policy(
        policy,
        transformer,
        env,
        vmap_reset,
        config.batch_size,
        config.len_seq,
        subkey)

    batch = {
        "states": state_pred_histo,
        "uniform0": uniform0_histo,
        "uniform1": uniform1_histo,
        # `:1` keeps the sequence axis, giving (batch, 1, features).
        "state_first": state_pred_histo[:, :1, :],
    }

    train_step_transformer_policy(
        policy,
        transformer,
        optimizer_policy,
        metrics_policy,
        batch
    )

    # Report the running average to wandb, then start a new averaging window.
    if idx_step % config.log_policy_reward_every_step == 0:
        result_metrics = metrics_policy.compute()

        wandb.log(result_metrics, step=idx_step)

        metrics_policy.reset()



  0%|          | 80/1000000 [01:34<302:43:38,  1.09s/it]

In [86]:
# Switch the world model to evaluation mode before inspecting its predictions
# (in Flax NNX this is expected to make dropout deterministic - confirm).
transformer.eval()

# Re-run the policy and the world model on whatever `batch` currently holds
# (the variable is left over from the training loop).
action_plan = policy(batch["states"], batch["uniform0"], batch["uniform1"])
states_next, reward_value = transformer(batch["state_first"], action_plan) 

In [87]:
# Policy output for sample 0 at the first sequence position.
action_plan[0, 0]

Array([0., 0., 0., 0., 0., 1., 0., 1., 0.], dtype=float32)

In [88]:
# Predicted rewards for sample 0, skipping sequence position 0
# (the same slice the policy loss sums over).
reward_value[0, 1:, :]

Array([[-0.30294454],
       [-0.2782483 ],
       [-0.27852708],
       [-0.27690354],
       [-0.27779618],
       [-0.27849773],
       [-0.28144747],
       [-0.28677082],
       [-0.28572685],
       [-0.28956443]], dtype=float32)

In [89]:
# Decode sample 0's first state from its one-hot encoding (54 entries x 6
# classes) back to class indices, laid out as (6, 3, 3) - presumably six
# 3x3 cube faces.
jnp.argmax(batch["state_first"][0, 0, :].reshape(54, 6), axis=1).reshape((6, 3, 3))


Array([[[4, 3, 3],
        [4, 0, 2],
        [0, 1, 2]],

       [[3, 4, 0],
        [2, 1, 3],
        [5, 2, 1]],

       [[1, 0, 2],
        [0, 2, 1],
        [2, 4, 3]],

       [[5, 5, 5],
        [0, 3, 2],
        [0, 5, 1]],

       [[1, 3, 2],
        [3, 4, 5],
        [4, 0, 4]],

       [[3, 1, 5],
        [4, 5, 5],
        [0, 1, 4]]], dtype=int32)

In [90]:
# Most likely class per entry in the world model's output for sample 0 at
# sequence position 1, laid out as (6, 3, 3) for comparison with the first state.
jnp.argmax(states_next[0, 1, :].reshape(54, 6), axis=1).reshape((6, 3, 3))

Array([[[4, 3, 1],
        [4, 5, 1],
        [4, 1, 2]],

       [[3, 0, 5],
        [3, 2, 3],
        [5, 2, 5]],

       [[1, 5, 4],
        [3, 5, 3],
        [2, 4, 3]],

       [[1, 0, 2],
        [1, 5, 4],
        [0, 5, 4]],

       [[3, 3, 0],
        [0, 0, 0],
        [4, 0, 4]],

       [[4, 1, 5],
        [4, 3, 5],
        [0, 1, 4]]], dtype=int32)

In [43]:
# Class probabilities (softmax over the 6 classes) for the first of the 54
# entries, for sample 0 at sequence position 1.
jax.nn.softmax(states_next[0, 1, :].reshape((54, 6)))[0, :]

Array([0.1177654 , 0.07482656, 0.28022024, 0.10275396, 0.34884286,
       0.07559103], dtype=float32)

In [85]:
# Inspect the `transformer` attribute of the world model (its list of blocks).
transformer.transformer

List(
  0=TransformerBlock(
    causal=True,
    dropout=Dropout(rate=0.05, broadcast_dims=(), deterministic=False, rng_collection='dropout', rngs=Rngs(
      default=RngStream(
        count=RngCount(
          tag='default',
          value=Array(786031, dtype=uint32)
        ),
        key=RngKey(
          tag='default',
          value=Array((), dtype=key<fry>) overlaying:
          [ 0 45]
        )
      )
    )),
    feedforward=FeedForward(
      linear1=Linear(
        bias=Param(
          value=Array(shape=(1024,), dtype=float32)
        ),
        bias_init=<function zeros at 0x7f7ef8f0b7f0>,
        dot_general=<function dot_general at 0x7f7ef9447910>,
        dtype=None,
        in_features=512,
        kernel=Param(
          value=Array(shape=(512, 1024), dtype=float32)
        ),
        kernel_init=<function variance_scaling.<locals>.init at 0x7f7ef874c040>,
        out_features=1024,
        param_dtype=<class 'jax.numpy.float32'>,
        precision=None,
        us