In [1]:
# Switch the working directory to a hard-coded absolute path.
# NOTE(review): this is not derived from the notebook's own location; presumably
# it is the project checkout, so that the local `rubiktransformer` package
# imported in the next cell resolves — confirm, and consider making it configurable.
import os 
os.chdir("/teamspace/studios/this_studio/rubikscubesolver")


In [2]:
from tqdm import tqdm

import wandb  # for logging
import time
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import flax.nnx as nnx

import optax

from rubiktransformer.models import RubikTransformer, PolicyModel
import rubiktransformer.dataset as dataset
from rubiktransformer.trainer import train
from rubiktransformer.trainer import reshape_sample

cuda_plugin_extension is not found.


In [3]:
@dataclass
class Config:
    """Hyper-parameters and random-number state shared by the notebook cells.

    All defaults are evaluated once, when the class body runs. The training
    cell treats the instance as mutable state: it re-splits ``jax_key`` and
    writes the new key back onto the instance.
    """
    # Explicit JAX PRNG key (seed 45); re-split and overwritten by the training cell.
    jax_key: jnp.ndarray = jax.random.PRNGKey(45)
    # NOTE(review): no type annotation, so @dataclass does not turn this into a
    # field. It is a plain class attribute: a single Rngs object shared by every
    # Config instance and not settable through the constructor.
    rngs = nnx.Rngs(44)
    # Number of parallel rollouts per policy update.
    batch_size: int = 128
    # Base learning rate; both optimizers in this notebook use lr_1 / 100.
    lr_1: float = 4e-3
    # NOTE(review): not read by any visible cell.
    lr_2: float = 4e-3
    # Only copied to a module-level variable in the visible cells.
    nb_games: int = 128 * 100
    # Number of policy steps per rollout.
    len_seq: int = 5
    # Number of iterations of the training loop.
    nb_step: int = 1000000
    # NOTE(review): the next two are not read by any visible cell — presumably
    # intended for world-model train/eval logging; confirm.
    log_every_step: int = 10
    log_eval_every_step: int = 10
    # Policy metrics are pushed to wandb every this many training steps.
    log_policy_reward_every_step: int = 10
    # NOTE(review): not read by any visible cell.
    add_data_every_step: int = 500

# One configuration object shared by every later cell.
config = Config()

# Weights & Biases run identity: account, project and a timestamped run name.
user, project = "forbu14", "RubikTransformer"
display_name = f"experiment_{time.strftime('%Y%m%d-%H%M%S')}"

# Open the run that the training loop logs to.
wandb.init(name=display_name, project=project, entity=user)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mforbu14[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [27]:
# Policy network and world model; both are handed the same shared Rngs object.
policy = PolicyModel(rngs=config.rngs)
transformer = RubikTransformer(rngs=config.rngs, causal=False)

# init optimizer
# Each optax pipeline clips the global gradient norm to 1.0, then applies Lion.
# NOTE(review): both pipelines use lr_1 / 100 and Config.lr_2 is never read —
# confirm whether the policy was meant to use lr_2.
optimizer_optaxworldmodel = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.lion(config.lr_1 / 100.),
)

# Named like its world-model counterpart so that `optimizer_policy` is only
# ever bound to the nnx.Optimizer, never to the raw optax transformation.
optimizer_optaxpolicy = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.lion(config.lr_1 / 100.),
)

optimizer_worldmodel = nnx.Optimizer(transformer, optimizer_optaxworldmodel)
optimizer_policy = nnx.Optimizer(policy, optimizer_optaxpolicy)

# metrics
# NOTE(review): metrics_train and metrics_eval are not updated by any visible
# cell; only metrics_policy is used by the training loop.
metrics_train = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
    loss_reward=nnx.metrics.Average("loss_reward"),
    loss_cross_entropy=nnx.metrics.Average("loss_cross_entropy"),
)

metrics_eval = nnx.MultiMetric(
    loss_eval=nnx.metrics.Average("loss_eval"),
    loss_reward_eval=nnx.metrics.Average("loss_reward_eval"),
    loss_cross_entropy_eval=nnx.metrics.Average("loss_cross_entropy_eval"),
)


# Running average of the value passed as `sum_reward_policy` at each policy update.
metrics_policy = nnx.MultiMetric(
    sum_reward_policy=nnx.metrics.Average("sum_reward_policy"),
)


In [5]:


# Module-level shortcuts for two config values.
nb_games = config.nb_games
len_seq = config.len_seq

# Batched environment reset: `env.reset` is jitted and then vmapped, so it is
# called with a batch of PRNG keys (one per environment).
# NOTE(review): `env` is not created in any visible cell, so this line only
# works with leftover kernel state and fails on Restart & Run All. The cell
# that builds the environment needs to be restored above this one.
vmap_reset = jax.vmap(jax.jit(env.reset))


In [21]:
def gather_data_policy(
    model_policy: PolicyModel,
    model_worldmodel: RubikTransformer,
    env,
    vmap_reset,
    batch_size,
    len_seq,
    key,):
    """Roll the policy out inside the world model and record the policy inputs.

    Starting from freshly reset environments, the policy is queried ``len_seq``
    times. After every query the world model is asked for the next state, and
    the last time step of its prediction becomes the next policy input.

    Args:
        model_policy: callable ``(state, uniform0, uniform1) -> action``; its
            outputs are concatenated along axis 1.
        model_worldmodel: callable ``(state, actions) -> (next_states, reward)``.
            Only the last time step of ``next_states`` is used; ``reward`` is
            ignored here.
        env: unused; kept so existing call sites keep working.
        vmap_reset: batched reset taking ``batch_size`` PRNG keys and returning
            ``(state, timestep)``, where ``state.cube`` holds integer labels
            that are one-hot encoded with 6 classes.
        batch_size: number of parallel rollouts.
        len_seq: number of policy steps; must be at least 1.
        key: JAX PRNG key. It is split internally and never consumed twice.

    Returns:
        ``(state_pred_histo, uniform0_histo, uniform1_histo, action_list)``:
        the states fed to the policy, the two noise tensors of shape
        ``(batch_size, len_seq, 6)`` and ``(batch_size, len_seq, 3)``, and the
        policy outputs, all stacked along axis 1.

    Raises:
        ValueError: if ``len_seq`` is smaller than 1.
    """
    if len_seq < 1:
        raise ValueError("len_seq must be at least 1")

    # One sub-key feeds the environment resets, the other is carried through
    # the rollout. Previously the reset keys were re-derived and reused for
    # the noise in the first loop iteration.
    key, key_reset = jax.random.split(key)
    state, _ = vmap_reset(jax.random.split(key_reset, batch_size))

    # One-hot encode the cube and flatten it to (batch_size, 1, features).
    state_pred = jnp.reshape(jax.nn.one_hot(state.cube, 6), (batch_size, 1, -1))

    action_list = None
    state_pred_list = []
    uniform0_list = []
    uniform1_list = []

    # Collect a batch of rollouts
    for _ in range(len_seq):
        # Fresh, independent keys for this step; `key` is only ever split.
        key, key_uniform0, key_uniform1 = jax.random.split(key, 3)

        # Noise inputs of shape (batch_size, 1, 6) and (batch_size, 1, 3).
        uniform0 = jax.random.uniform(key_uniform0, (batch_size, 1, 6))
        uniform1 = jax.random.uniform(key_uniform1, (batch_size, 1, 3))

        # apply the policy
        action_result = model_policy(state_pred, uniform0, uniform1)

        # The world model needs the whole action history, so it is grown here.
        if action_list is None:
            action_list = action_result
        else:
            action_list = jnp.concatenate((action_list, action_result), axis=1)

        # Record exactly what the policy saw at this step.
        state_pred_list.append(state_pred)
        uniform0_list.append(uniform0)
        uniform1_list.append(uniform1)

        # NOTE(review): the world model receives the *latest* predicted state
        # together with the *full* action history, whereas
        # loss_fn_transformer_policy calls it with the first state and the
        # full plan — confirm which convention RubikTransformer expects.
        state_next, _ = model_worldmodel(state_pred, action_list)

        # Keep only the last predicted step and restore the time axis.
        state_pred = jnp.expand_dims(state_next[:, -1, :], axis=1)

    # here we create the dataset in a proper format
    state_pred_histo = jnp.concatenate(state_pred_list, axis=1)
    uniform0_histo = jnp.concatenate(uniform0_list, axis=1)
    uniform1_histo = jnp.concatenate(uniform1_list, axis=1)

    return state_pred_histo, uniform0_histo, uniform1_histo, action_list


# Run a single rollout with a dedicated PRNG key to inspect the outputs.
key = jax.random.PRNGKey(48)

state_pred_histo, uniform0_histo, uniform1_histo, action_list = gather_data_policy(
    model_policy=policy,
    model_worldmodel=transformer,
    env=env,
    vmap_reset=vmap_reset,
    batch_size=config.batch_size,
    len_seq=config.len_seq,
    key=key,
)


In [24]:
# Sanity check of the rollout output; with the default config this shows
# (batch_size, len_seq, 9).
action_list.shape

(128, 5, 9)

In [9]:
# Render the world-model module for visual inspection.
nnx.display(transformer)

In [33]:

def loss_fn_transformer_policy(model_policy: PolicyModel, model: RubikTransformer, batch):
    """Policy loss: the negated reward that the world model predicts for the plan.

    The policy maps the recorded states and noise tensors in ``batch`` to an
    action plan, and the world model scores that plan starting from
    ``batch["state_first"]``. Rewards are summed over axis 1, averaged over
    the remaining axes and negated, so minimising the loss maximises the
    predicted reward.

    Returns:
        ``(loss, loss_reward)`` where both entries are the same scalar; the
        second one is the auxiliary output expected by ``has_aux=True``.
    """
    planned_actions = model_policy(
        batch["states"],
        batch["uniform0"],
        batch["uniform1"],
    )

    # The predicted states are not needed for this objective.
    _, predicted_reward = model(batch["state_first"], planned_actions)

    summed_reward = predicted_reward.sum(axis=1)
    loss_reward = -summed_reward.mean()

    return loss_reward, loss_reward

@nnx.jit
def train_step_transformer_policy(
    model_policy: PolicyModel, model: RubikTransformer, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch
):
    """Run one gradient step on the policy.

    Gradients of ``loss_fn_transformer_policy`` are taken with respect to its
    first argument only (``model_policy``); ``model`` is evaluated but receives
    no gradients here.

    Args:
        model_policy: policy network being optimised.
        model: world model used to score the policy's action plan.
        optimizer: optimizer that receives the gradients — expected to be the
            one wrapping ``model_policy``.
        metrics: metric collection with a ``sum_reward_policy`` entry.
        batch: dict with the keys read by ``loss_fn_transformer_policy``.
    """
    grad_fn = nnx.value_and_grad(loss_fn_transformer_policy, has_aux=True)
    (loss, _), grads = grad_fn(model_policy, model, batch)
    # The loss is the *negated* mean summed reward. Flip the sign back so the
    # metric reports what its name says (it previously logged the loss itself).
    metrics.update(
        sum_reward_policy=-loss
    )
    optimizer.update(grads)


In [34]:
# Policy optimisation: only `policy` is updated (through `optimizer_policy`);
# the world model `transformer` is used to score the rollouts.
for idx_step in tqdm(range(config.nb_step)):
    # Advance the stored key and give the rollout its own fresh sub-key.
    # Previously the carried key itself was passed to the rollout and then
    # split again on the next iteration, i.e. the same key was consumed twice.
    key, subkey = jax.random.split(config.jax_key)
    config.jax_key = key

    # gather data from policy :
    state_pred_histo, uniform0_histo, uniform1_histo, action_list = gather_data_policy(
        policy,
        transformer,
        env,
        vmap_reset,
        config.batch_size,
        config.len_seq,
        subkey)

    batch = {
        "states": state_pred_histo,
        "uniform0": uniform0_histo,
        "uniform1": uniform1_histo,
        # Slicing with `:1` keeps the time axis: (batch, 1, features).
        "state_first": state_pred_histo[:, :1, :],
    }

    train_step_transformer_policy(
        policy,
        transformer,
        optimizer_policy,
        metrics_policy,
        batch
    )

    # Push the running average to wandb, then start a new averaging window.
    if idx_step % config.log_policy_reward_every_step == 0:
        result_metrics = metrics_policy.compute()

        wandb.log(result_metrics, step=idx_step)

        metrics_policy.reset()



  0%|          | 78/1000000 [02:33<496:33:35,  1.79s/it]