In [None]:
# Automatically reload edited modules (e.g. the rubiks_rl package) before each cell executes,
# so local code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# install from git
!if [ -e ./rubiks ]; then rm -rf ./rubiks; fi
!pip uninstall rubiks_rl -y --quiet
!git clone https://github.com/LongDangHoang/rubik_rl ./rubiks --quiet
!cd ./rubiks; pip install . --quiet; cd ..

# some packages
!pip install torchinfo python-dotenv wandb==0.15.0 protobuf==3.20.3 matplotlib --quiet

In [None]:
import os
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import random
import functools
import pythreejs as p3

from torchinfo import summary
from pathlib import Path
from tqdm import tqdm

from rubiks_rl.colors import Color
from rubiks_rl.rubik54 import Rubik54
from rubiks_rl.models import RLRubikModel 
from rubiks_rl.s3_utils import S3Syncer
from rubiks_rl.world import *
from rubiks_rl.logs import RubiksLogger

# Seed every RNG this notebook touches so runs are reproducible: python's `random`
# (used to sample cubes for the training sanity check), numpy's legacy global RNG
# (in case library code relies on it), and torch on CPU and all GPUs.
orig_seed = 314
random.seed(orig_seed)
np.random.seed(orig_seed)
torch.manual_seed(orig_seed)
torch.cuda.manual_seed_all(orig_seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Environment: credentials come from a local .env file when one is found; otherwise the
# notebook assumes it runs on Kaggle and copies them from Kaggle's secret store into env vars.
from dotenv import load_dotenv

ON_KAGGLE = False
if not load_dotenv():
    from kaggle_secrets import UserSecretsClient

    # (environment variable, Kaggle secret label) pairs
    for env_var, secret_label in (
        ("WANDB_API_KEY", "wandb_api"),
        ("AWS_ACCESS_KEY_ID", "s3_aws_access_key"),
        ("AWS_SECRET_ACCESS_KEY", "s3_aws_secret_access_key"),
    ):
        os.environ[env_var] = UserSecretsClient().get_secret(secret_label)
    ON_KAGGLE = True

import wandb

# Define hyper parameters (all of them are recorded in the wandb run config below).
# NOTE(review): num_blocks, use_rl_model and accumulate_grad_batches are only logged here;
# no visible cell passes them on to the model or the training loop — confirm intended.
max_depth_scramble = 30
num_cubes = 1_024
num_epochs = 5000
num_blocks = 128
batch_size = 8192
accumulate_grad_batches = None
use_constant_lr = False
weight_decay = 0.01
refresh_every_n_epoch = 1
lr = 2e-4
use_rl_model = True

# run ids
log_wandb = True
use_existing_run = None     # id of a previous run whose checkpoints should be resumed
use_pretrain = None         # relative path of a checkpoint used to initialise the weights
init_new_wandb_run = True   # when resuming, still open a fresh wandb run instead of continuing the old one

# start a new wandb run (or continue an old one) to track this script
if log_wandb:
    wandb.login()

    # a `run` left in the kernel from an earlier execution of this cell is reused as-is
    if "run" not in globals():
        resume_wandb_run = bool(use_existing_run and not init_new_wandb_run)
        run = wandb.init(
            project="rubiks_rl",
            id=use_existing_run if resume_wandb_run else None,
            resume="must" if resume_wandb_run else None,
            config={
                "max_depth_scramble": max_depth_scramble,
                "num_cubes": num_cubes,
                "num_epochs": num_epochs,
                "num_blocks": num_blocks,
                "batch_size": batch_size,
                "accumulate_grad_batches": accumulate_grad_batches,
                "use_constant_lr": use_constant_lr,
                "refresh_every_n_epoch": refresh_every_n_epoch,
                "weight_decay": weight_decay,
                "lr": lr,
                "use_rl_model": use_rl_model,
            },
        )
        assert run is not None

## Load model

In [None]:
# setting up checkpoint directories and the S3 syncer, then build the model — initialised
# either from a pretrained checkpoint, from the run being resumed, or from scratch.
save_ckpt_local_dir = (
    Path(f"./rubiks_rl/{run.id}/checkpoints")
    if "run" in globals()
    else Path("./rubiks_rl/local_run/checkpoints")
)
save_ckpt_local_dir.mkdir(parents=True, exist_ok=True)

if use_existing_run:
    load_ckpt_local_dir = Path(f"./rubiks_rl/{use_existing_run}/checkpoints")
    load_ckpt_local_dir.mkdir(parents=True, exist_ok=True)
else:
    load_ckpt_local_dir = None

s3_syncer = S3Syncer(
    save_local_dir=save_ckpt_local_dir,
    load_local_dir=load_ckpt_local_dir,
    every_n_epochs=1,
)

if use_existing_run:
    s3_syncer.download_files_from_s3()

# NOTE(review): the model is built with its defaults; `num_blocks` / `use_rl_model` from the
# config cell are not passed in — confirm RLRubikModel's defaults match the logged values.
model = RLRubikModel()

if use_pretrain and not use_existing_run:  # if resuming run, no need to pretrain
    pretrain_ckpt_path = Path(".") / use_pretrain
    pretrain_ckpt_local_dir = pretrain_ckpt_path.parent
    pretrain_ckpt_local_dir.mkdir(parents=True, exist_ok=True)

    s3_pretrain_sync = S3Syncer(save_local_dir=pretrain_ckpt_local_dir, load_local_dir=pretrain_ckpt_local_dir)
    if not pretrain_ckpt_path.exists():
        s3_pretrain_sync.download_filename(pretrain_ckpt_path.name)
        if not pretrain_ckpt_path.exists():
            raise FileNotFoundError(f"Pretrained checkpoint not found locally or on S3: {pretrain_ckpt_path}")

    # map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa)
    state_dict = torch.load(pretrain_ckpt_path, map_location=device)
    model.load_state_dict(state_dict["model_state_dict"])
elif use_existing_run:
    # `state_dict` stays global: later cells restore optimizer/scheduler state from it
    state_dict = torch.load(s3_syncer.load_local_dir / "last.ckpt", map_location=device)
    model.load_state_dict(state_dict["model_state_dict"])

model = model.to(device)
# torchinfo takes the example shape as `input_size`; an `input_shape` kwarg is not a summary
# option and is only forwarded to model.forward, so no shaped forward-pass summary was shown.
summary(model, input_size=(16, 54, 6), device=device)

## Policy iteration loop

- How deep should the scrambles go?
    - If the model only learns on short scrambles, it may struggle as the number of steps increases, so it is best to make this a parameter (`max_depth_scramble`).
- For each sample, find best action by evaluating on breadth-1
    - First, generate in a batch fashion the states to be evaluated
    - This will have shape (num_cubes * max_depth_scramble * 12, 54, 6)
- Then, evaluate the batch and add in the reward (-1 if the state is not solved, else 1)
- Retrieve the bootstrapped labels by taking the argmax over the reward
- Train the model using these output labels (num_cubes * max_depth_scramble of them, obtained by chunking the evaluations into groups of 12)
    - When averaging the loss, we also weight each sample based on its scramble distance
- This forms one epoch. **TODO**: profile it.
- Repeat until reaching `num_epochs`

In [None]:
# set up optimiser and learning rate scheduler
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

if use_constant_lr:
    # `use_constant_lr` is a logged hyperparameter, so honour it: a fixed multiplier of 1.0
    # keeps the learning rate at `lr` for the whole run.
    lr_sched = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda _step: 1.0)
else:
    # One scheduler step per training mini-batch, so steps per epoch is
    # ceil(#training states / batch_size); assumes num_cubes * max_depth_scramble states per epoch.
    lr_sched = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr * 10,
        epochs=num_epochs,
        steps_per_epoch=int(np.ceil(num_cubes * max_depth_scramble / batch_size)),
    )

# apply states: when resuming a run, restore optimizer and scheduler from the loaded checkpoint
if use_existing_run and ("state_dict" in globals()):
    optimizer.load_state_dict(state_dict["optimizer_state_dict"])
    lr_sched.load_state_dict(state_dict["scheduler_state_dict"])

In [None]:
# run training
# Each epoch (re)generates bootstrapped policy/value targets with the current model when due,
# then takes one optimizer + scheduler step per mini-batch and writes a checkpoint.
save_path = None
metrics_logger = RubiksLogger(prefix="train")

if log_wandb:
    wandb.watch(model)

# scheduler steps per epoch, same formula the optimizer cell uses for OneCycleLR
steps_per_epoch = int(np.ceil(num_cubes * max_depth_scramble / batch_size))

# When resuming, continue right after the last epoch stored in the checkpoint. Checkpoints
# written before the "epoch" key existed fall back to the scheduler's step count, rounded up
# (equivalent to the old rule of skipping every epoch_n < steps_done / steps_per_epoch).
start_training_epoch = 0
if use_existing_run and ("state_dict" in globals()):
    if "epoch" in state_dict:
        start_training_epoch = int(state_dict["epoch"]) + 1
    else:
        start_training_epoch = int(np.ceil((lr_sched._step_count - 1) / steps_per_epoch))

with tqdm(total=num_epochs, initial=start_training_epoch, dynamic_ncols=True) as pbar:
    for epoch_n in range(start_training_epoch, num_epochs):
        # get new data
        if epoch_n % refresh_every_n_epoch == 0 or ("best_cube_state_values" not in globals()):
            seed = epoch_n - (epoch_n % refresh_every_n_epoch)
            X = get_n_cubes_k_scrambles(num_cubes=num_cubes, max_depth_scramble=max_depth_scramble, seed=seed)["data"]
            weight = get_weights_by_scrambling_distance(num_cubes=num_cubes, max_depth_scramble=max_depth_scramble)

            # generate bootstrapped labels using bfs; this is inference, so any dropout /
            # batch-norm layers must run in eval mode (train mode is restored below)
            model.eval()
            evaluate_fn = functools.partial(model_evaluate, model=model, batch_size=batch_size*32, device=device)
            best_action, best_cube_state_values = find_best_move_and_value_from(X, evaluate_fn)

            # sanity check: the first num_cubes states are scramble distance 1, so the chosen
            # move must solve each of them; for speed only sample (up to) 100 of them
            for i in random.sample(range(num_cubes), k=min(100, num_cubes)):
                turn = CUBE.turn_mat[best_action[i]]
                assert np.all(X[i][turn] == CUBE.get_solved_state()), f"Model stops being able to solve scramble distance 1 at epoch {epoch_n}"

            # mask of states in X that are already the solved cube
            is_solved = (X == CUBE.get_solved_state()).all(axis=(1, 2))

            if log_wandb:
                # average target value per scramble depth, logged as a single wandb step;
                # assumes X is laid out depth-major (num_cubes states per depth) — TODO confirm
                target_values = {}
                if is_solved.any():
                    target_values["target_state_value_d=0"] = best_cube_state_values[is_solved][0]
                n_logged_depths = min(10, max_depth_scramble)
                depth_means = (
                    best_cube_state_values[:num_cubes * n_logged_depths]
                    .reshape((n_logged_depths, num_cubes))
                    .mean(axis=1)
                )
                for d, depth_mean in enumerate(depth_means, start=1):
                    target_values[f"target_state_value_d={d}"] = depth_mean
                wandb.log(target_values)

            print(f"Finish generating training data using model at epoch {epoch_n}")

            # shuffle ids, but remove ids for solved state as we won't see the state at inference time
            ids = np.flatnonzero(~is_solved)
            generator = np.random.default_rng(seed=seed)
            ids = generator.permutation(ids)

        # train the model
        model.train()
        metrics_logger.refresh()
        # iterate over the filtered ids rather than len(X): solved states were dropped above,
        # so stepping over len(X) could produce short or empty trailing batches
        for batch_start_idx in range(0, len(ids), batch_size):
            batch_ids = ids[batch_start_idx:batch_start_idx + batch_size]
            optimizer.zero_grad()
            batch = torch.as_tensor(X[batch_ids], device=device, dtype=torch.float32)
            v, p = model(batch)
            p_lab = torch.as_tensor(best_action[batch_ids], device=device, dtype=torch.int64)
            v_lab = torch.as_tensor(best_cube_state_values[batch_ids], device=device, dtype=torch.float32).unsqueeze(1)
            loss_weight = torch.as_tensor(weight[batch_ids], device=device, dtype=torch.float32)

            # per-sample losses, combined and weighted by scramble distance
            loss_v = F.mse_loss(v, v_lab, reduction="none").reshape((-1,))
            loss_p = F.cross_entropy(p, p_lab, reduction="none").reshape((-1,))
            total_loss = loss_v + loss_p

            weighted_loss = (total_loss * loss_weight).mean()
            weighted_loss.backward()
            optimizer.step()
            lr_sched.step()

            if log_wandb:
                wandb.log({'lr': lr_sched.get_last_lr()[0]})

            # add to metric logger
            metrics_logger.update({
                "total_loss": total_loss.detach().cpu().numpy(),
                "policy_loss": loss_p.detach().cpu().numpy(),
                "state_value_loss": loss_v.detach().cpu().numpy(),
                "weight": loss_weight.detach().cpu().numpy(),
            })

        if log_wandb:
            wandb.log({
                "train_epoch_avg_weighted_loss": metrics_logger.avg_weighted_loss,
                "train_epoch_avg_weighted_state_value_loss": metrics_logger.avg_weighted_state_value_loss,
                "train_epoch_avg_weighted_policy_loss": metrics_logger.avg_weighted_policy_loss,
            })

        # save to checkpoint; "epoch" lets a resumed run continue at exactly the next epoch
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': lr_sched.state_dict(),
            'epoch': epoch_n,
        }
        torch.save(checkpoint, s3_syncer.save_local_dir / 'last.ckpt')
        if epoch_n > 0 and epoch_n % 10 == 0:
            # keep only last.ckpt plus the newest periodic checkpoint locally
            for file in s3_syncer.save_local_dir.iterdir():
                if file.name != "last.ckpt":
                    os.remove(file)
            torch.save(checkpoint, s3_syncer.save_local_dir / f'epoch={epoch_n}.ckpt')

            if log_wandb:
                s3_syncer.upload_files_to_s3()
                print(f"Pushed model at epoch={epoch_n} to s3")

        # update pbar
        pbar.update(1)
        if log_wandb:
            wandb.log({'epoch': epoch_n})