In [1]:
# Work from the repository root so relative paths (package data, outputs) resolve
# consistently. Override the location with the RUBIK_PROJECT_DIR environment
# variable; the default is the original Lightning Studio checkout.
import os

PROJECT_DIR = os.environ.get(
    "RUBIK_PROJECT_DIR", "/teamspace/studios/this_studio/rubikscubesolver"
)
os.chdir(PROJECT_DIR)


In [2]:
import time
from dataclasses import dataclass, field

import flax.nnx as nnx
import jax
import jax.numpy as jnp
import optax
import wandb  # for logging
from tqdm import tqdm

import rubiktransformer.dataset as dataset
from rubiktransformer.models import RubikTransformer, PolicyModel
from rubiktransformer.trainer import train
from rubiktransformer.trainer import reshape_sample

cuda_plugin_extension is not found.


In [3]:
@dataclass
class Config:
    """Hyperparameters and random state for the world-model / policy experiment."""

    # JAX PRNG key for data gathering and buffer sampling. The notebook splits it and
    # writes the new carry key back (`config.jax_key = key`), so it evolves during the run.
    jax_key: jnp.ndarray = jax.random.PRNGKey(45)
    # Stateful flax RNG streams used to initialise the models. Declared as a real
    # dataclass field (an un-annotated name would be a class attribute shared by every
    # instance); `default_factory` gives each Config its own streams and allows
    # `Config(rngs=...)`.
    rngs: nnx.Rngs = field(default_factory=lambda: nnx.Rngs(44))
    batch_size: int = 128  # sample batch size of the replay buffers
    lr_1: float = 4e-3  # world-model base learning rate (the optimizer cell divides it by 100)
    lr_2: float = 4e-3  # policy learning rate
    nb_games: int = 128 * 100  # base number of random games per data-gathering call
    len_seq: int = 5  # number of steps (actions / rewards / next states) per stored sequence
    nb_step: int = 1000000  # number of training iterations
    log_every_step: int = 10  # log training metrics every N steps
    log_eval_every_step: int = 10  # evaluate on held-out data every N steps
    add_data_every_step: int = 500  # gather fresh training games every N steps


config = Config()

# Weights & Biases run setup. The entity can be overridden with WANDB_ENTITY without
# editing the notebook; the default keeps the original account.
user = os.environ.get("WANDB_ENTITY", "forbu14")
project = "RubikTransformer"
display_name = "experiment_" + time.strftime("%Y%m%d-%H%M%S")

# Store the scalar hyperparameters with the run so experiments can be compared later;
# PRNG keys / nnx.Rngs are not meaningful config values and are skipped.
hyperparameters = {
    name: value
    for name, value in vars(config).items()
    if isinstance(value, (int, float))
}

wandb.init(entity=user, project=project, name=display_name, config=hyperparameters)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mforbu14[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111227162222753, max=1.0)…

In [4]:
# Both models draw their initial parameters from the same stateful `config.rngs`
# (nnx.Rngs advances on every draw), so the construction order below matters.
policy = PolicyModel(rngs=config.rngs)
transformer = RubikTransformer(rngs=config.rngs, causal=False)

# World model: global-norm gradient clipping, then Lion at 1/100 of lr_1.
worldmodel_lr = config.lr_1 / 100.
optimizer_optaxworldmodel = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.lion(worldmodel_lr),
)
optimizer_worldmodel = nnx.Optimizer(transformer, optimizer_optaxworldmodel)

# Policy: plain Adam at lr_2.
optimizer_policy = nnx.Optimizer(policy, optax.adam(config.lr_2))


def _running_averages(*names):
    """Build a MultiMetric holding one running Average per name (name == update kwarg)."""
    return nnx.MultiMetric(**{name: nnx.metrics.Average(name) for name in names})


metrics_train = _running_averages("loss", "loss_reward", "loss_cross_entropy")
metrics_eval = _running_averages(
    "loss_eval", "loss_reward_eval", "loss_cross_entropy_eval"
)


In [5]:
# gather data from the environment
# init models and optimizers
# Two replay buffers built the same way: `buffer` for training, `buffer_eval` for
# held-out evaluation.
# NOTE(review): the second call rebinds `env`, so everything below uses the second
# environment; both calls take identical arguments, so the envs are presumably
# equivalent — TODO confirm in dataset.init_env_buffer.
env, buffer = dataset.init_env_buffer(sample_batch_size=config.batch_size)
env, buffer_eval = dataset.init_env_buffer(sample_batch_size=config.batch_size)

nb_games = config.nb_games
len_seq = config.len_seq

# Zero-filled template transition. It only fixes the structure, shapes and dtypes of
# what the buffers store: cube states as int8, actions as int32, rewards as float.
state_first = jnp.zeros((6, 3, 3), dtype=jnp.int8)
state_next = jnp.zeros((len_seq, 6, 3, 3), dtype=jnp.int8)
action = jnp.zeros((len_seq, 3), dtype=jnp.int32)
reward = jnp.zeros((len_seq,))

jit_step = jax.jit(env.step)

template_transition = {
    "state_first": state_first,
    "action": action,
    "reward": reward,
    "state_next": state_next,
}
# Each buffer is initialised from its own copy of the template dict.
buffer_list, buffer_list_eval = (
    buffer.init(dict(template_transition)),
    buffer_eval.init(dict(template_transition)),
)

def step_fn(state, key):
    """Take one uniformly random action in the environment (body of a `jax.lax.scan`).

    Args:
        state: current environment state (the scan carry).
        key: PRNG key used to draw the action.

    Returns:
        ``(new_state, timestep)``; the sampled action is also stored in
        ``timestep.extras["action"]`` so it is part of the recorded rollout.
    """
    spec = env.action_spec
    # `jax.random.randint` treats `maxval` as EXCLUSIVE, whereas the spec's `maximum`
    # is assumed inclusive (dm_env / jumanji bounded-spec convention — TODO confirm for
    # this env). Without the +1 the largest value of every action component was never
    # sampled.
    action = jax.random.randint(
        key=key,
        minval=spec.minimum,
        maxval=spec.maximum + 1,
        shape=(3,),
    )

    new_state, timestep = jit_step(state, action)
    timestep.extras["action"] = action

    return new_state, timestep

def run_n_steps(state, key, n):
    """Roll the environment forward `n` random steps starting from `state`.

    `key` is split into `n` per-step keys fed to `step_fn` through `jax.lax.scan`.
    Only the stacked per-step outputs are returned; the final carry is discarded.
    `n` must be a static Python int (it fixes the scan length).
    """
    per_step_keys = jax.random.split(key, n)
    _final_state, trajectory = jax.lax.scan(step_fn, state, per_step_keys)
    return trajectory

# Batched reset: jitted env.reset vectorised over axis 0 of its inputs
# (presumably one PRNG key per game — TODO confirm against env.reset's signature).
vmap_reset = jax.vmap(jax.jit(env.reset))
# Batched rollout: states and keys are batched along axis 0, while the step count `n`
# is shared by the whole batch (in_axes=None), as run_n_steps needs it static.
vmap_step = jax.vmap(run_n_steps, in_axes=(0, 0, None))

In [6]:
# Visualise the world-model module tree with flax's nnx.display.
nnx.display(transformer)

In [7]:
# Seed the training buffer with `config.nb_games` random-play games.
# Advance the global key first so `subkey` is used exactly once.
key, subkey = jax.random.split(config.jax_key)
config.jax_key = key

buffer, buffer_list = dataset.fast_gathering_data(
    env, vmap_reset, vmap_step,
    config.nb_games, config.len_seq,
    buffer, buffer_list,
    subkey,
)

In [8]:

def loss_fn_transformer(model: RubikTransformer, batch):
    """World-model objective: next-state cross-entropy plus reward mean-squared error.

    Args:
        model: world model, called as ``model(state_first, action)`` and returning
            ``(state_logits, reward_value)``.
        batch: mapping with keys ``"state_first"``, ``"action"``, ``"state_next"``
            (integer class labels) and ``"reward"``.

    Returns:
        ``(loss, (loss_crossentropy, loss_reward))``, where ``loss`` is their sum. The
        nested aux tuple is the layout ``nnx.value_and_grad(..., has_aux=True)`` expects.
    """
    logits, predicted_reward = model(batch["state_first"], batch["action"])

    # Both heads drop their first sequence position before comparison with the targets.
    # assumes logits are (batch, seq, 324) — TODO confirm against RubikTransformer.
    logits = logits[:, 1:, :]
    batch_dim, seq_dim = logits.shape[0], logits.shape[1]
    # 324 = 54 * 6: a 6-way categorical for each of the 54 (= 6 x 3 x 3) cube cells.
    logits = logits.reshape((batch_dim, seq_dim, 54, 6))

    cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch["state_next"]
    ).mean()
    reward_mse = jnp.square(predicted_reward[:, 1:] - batch["reward"]).mean()

    return cross_entropy + reward_mse, (cross_entropy, reward_mse)

@nnx.jit
def train_step_transformer(
    model: RubikTransformer, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch
):
    """Run one jitted optimisation step of the world model.

    Evaluates ``loss_fn_transformer`` and its gradients w.r.t. ``model``, accumulates
    the total / cross-entropy / reward losses into ``metrics``, and applies the
    gradients through ``optimizer``. Nothing is returned; the effects are the state
    updates of ``optimizer`` (and its model) and ``metrics``.
    """
    loss_and_grad = nnx.value_and_grad(loss_fn_transformer, has_aux=True)
    (total_loss, aux), grads = loss_and_grad(model, batch)
    cross_entropy, reward_mse = aux

    metrics.update(
        loss=total_loss,
        loss_cross_entropy=cross_entropy,
        loss_reward=reward_mse,
    )
    optimizer.update(grads)


In [None]:
# Training loop for the world model (RubikTransformer).
# Held-out evaluation games go ONLY to `buffer_eval` / `buffer_list_eval`; the training
# buffer never sees them, and evaluation never samples training data.
N_WARMUP_GAMES = config.nb_games * 10  # initial fill of the training buffer
N_REFRESH_GAMES = config.nb_games // 10  # added every `config.add_data_every_step` steps
N_EVAL_GAMES = 128  # fresh held-out games gathered at each evaluation


def _next_key():
    """Split `config.jax_key`, keep the new carry key, and return a single-use subkey."""
    config.jax_key, subkey = jax.random.split(config.jax_key)
    return subkey


def _gather(buf, buf_state, n_games):
    """Play `n_games` random games into `buf_state`; returns the updated (buffer, state)."""
    return dataset.fast_gathering_data(
        env,
        vmap_reset,
        vmap_step,
        n_games,
        config.len_seq,
        buf,
        buf_state,
        _next_key(),  # every gathering call gets a fresh, never-reused key
    )


# Compiled evaluation of the same objective used for training.
eval_loss_fn = nnx.jit(loss_fn_transformer)

buffer, buffer_list = _gather(buffer, buffer_list, N_WARMUP_GAMES)

progress = tqdm(range(config.nb_step))
for idx_step in progress:
    if idx_step % config.add_data_every_step == 0:
        buffer, buffer_list = _gather(buffer, buffer_list, N_REFRESH_GAMES)

    # One optimisation step of the world model on a batch from the training buffer.
    sample = reshape_sample(buffer.sample(buffer_list, _next_key()))
    train_step_transformer(
        transformer, optimizer_worldmodel, metrics_train, sample.experience
    )

    if idx_step % config.log_every_step == 0:
        metrics_train_result = metrics_train.compute()
        progress.set_postfix(
            {name: float(value) for name, value in metrics_train_result.items()}
        )
        wandb.log(metrics_train_result, step=idx_step)
        metrics_train.reset()

    if idx_step % config.log_eval_every_step == 0:
        buffer_eval, buffer_list_eval = _gather(
            buffer_eval, buffer_list_eval, N_EVAL_GAMES
        )
        sample_eval = reshape_sample(buffer_eval.sample(buffer_list_eval, _next_key()))

        loss, (loss_crossentropy, loss_reward) = eval_loss_fn(
            transformer, sample_eval.experience
        )

        metrics_eval.update(
            loss_eval=loss,
            loss_reward_eval=loss_reward,
            loss_cross_entropy_eval=loss_crossentropy,
        )
        wandb.log(metrics_eval.compute(), step=idx_step)
        metrics_eval.reset()

