In [1]:
import os
import jax
import wandb
import socket
import logging
import warnings
import argparse
import gymnasium as gym
from pprint import pprint
from datetime import datetime
from tensorboardX import SummaryWriter
from tqdm import TqdmExperimentalWarning

In [2]:
# Global runtime configuration used by the rest of the notebook.
# W&B: start the run in a thread instead of a separate process.
os.environ["WANDB_START_METHOD"] = "thread"
# JAX: enable 64-bit dtypes instead of the default 32-bit ones.
jax.config.update("jax_enable_x64", True)
# tqdm: silence its TqdmExperimentalWarning.
warnings.filterwarnings(action="ignore", category=TqdmExperimentalWarning)

In [3]:
from powr.utils import *
from powr.wrappers import *
from powr.powr import POWR
from powr.kernels import dirac_kernel, gaussian_kernel, gaussian_kernel_diag

In [4]:
# Root logger plus the jax / tensorboardX loggers only report WARNING and above.
logging.basicConfig(level=logging.WARNING)
for _logger_name in ("jax", "tensorboardX"):
    logging.getLogger(_logger_name).setLevel(logging.WARNING)

In [5]:
# Run configuration. A script would get this from argparse; inside a Jupyter
# kernel parse_args() would see the kernel's own command-line flags, so the
# Namespace is built directly. Key order is kept stable because vars(args) is
# saved to disk and logged as a table.
args = argparse.Namespace(**{
    "env": "MountainCar-v0",
    "group": None,
    "project": None,
    "la": 1e-6,
    "eta": 0.1,
    "gamma": 0.99,
    "sigma": 0.2,
    "q_mem": 0,
    "delete_Q_memory": False,
    "early_stopping": None,
    "warmup_episodes": 1,
    "epochs": 8,
    "train_episodes": 1,
    "parallel_envs": 3,
    "subsamples": 1_000,
    "iter_pmd": 1,
    "eval_episodes": 1,
    "save_gif_every": None,
    "save_checkpoint_every": 20,
    "eval_every": 1,
    "seed": 0,
    "checkpoint": None,
    "device": "gpu",
    "notes": None,
    "tags": [],
    "offline": True,
})
# Algorithm name: not a tunable flag, but used for run paths and W&B names.
setattr(args, "algo", "powr")

In [6]:
# ** Wandb Settings **
# Either resume the W&B run stored in `args.checkpoint`, or create a fresh run
# directory and W&B run. Both branches define `checkpoint`, `run_path`,
# `starting_epoch` and `total_timesteps`, which the training cell uses.

checkpoint = args.checkpoint
if checkpoint is not None:
    checkpoint_data = load_checkpoint(checkpoint)
    project = args.project  # project requested in this session, if any

    # Load saved `args`, `total_timesteps`, and `wandb_run_id`
    args = argparse.Namespace(**checkpoint_data["args"])
    total_timesteps = checkpoint_data["total_timesteps"]
    starting_epoch = checkpoint_data["epoch"]
    wandb_run_id = checkpoint_data["wandb_run_id"]

    # resume="must" only finds the run inside the project it was created in.
    # Fresh runs (else-branch below) default to "powr" when no project is
    # given, so apply the same fallback here, preferring the project recorded
    # in the checkpoint, instead of letting W&B choose its own default project.
    if project is None:
        project = getattr(args, "project", None) or "powr"

    print("Resuming WandB run: ", wandb_run_id)
    # Resume Wandb run with saved run ID
    wandb.init(
        project=project,
        id=wandb_run_id,  # Use saved Wandb run ID to resume the run
        save_code=True,
        sync_tensorboard=True,
        monitor_gym=True,
        resume="must",
        mode=("online" if not args.offline else "disabled"),
    )

    run_path = f"{checkpoint}/"
else:
    pprint(vars(args))
    random_string = get_random_string(5)
    current_date = datetime.today().strftime("%Y_%m_%d_%H_%M_%S")
    run_path = (
        f"runs/{args.env}/{args.algo}/"
        f"{get_run_name(args, current_date)}_{random_string}/"
    )
    create_dirs(run_path)
    save_config(vars(args), run_path)

    # Human-readable W&B run name encoding the main hyper-parameters.
    run_name = (
        f"{current_date}_{args.env}_{args.algo}"
        f"_eta={args.eta}_la={args.la}"
        f"_train_eps={args.train_episodes}_pmd_iters={args.iter_pmd}"
        f"_earlystop={args.early_stopping}_seed{args.seed}_{random_string}"
    )

    # Initialize wandb
    wandb.init(
        config=vars(args),
        project=("powr" if args.project is None else args.project),
        group=(f"{args.env}/{args.algo}" if args.group is None else args.group),
        name=run_name,
        save_code=True,
        sync_tensorboard=True,
        tags=args.tags,
        monitor_gym=True,
        notes=args.notes,
        mode=("online" if not args.offline else "disabled"),
    )
    starting_epoch = 0
    total_timesteps = 0

{'algo': 'powr',
 'checkpoint': None,
 'delete_Q_memory': False,
 'device': 'gpu',
 'early_stopping': None,
 'env': 'MountainCar-v0',
 'epochs': 8,
 'eta': 0.1,
 'eval_episodes': 1,
 'eval_every': 1,
 'gamma': 0.99,
 'group': None,
 'iter_pmd': 1,
 'la': 1e-06,
 'notes': None,
 'offline': True,
 'parallel_envs': 3,
 'project': None,
 'q_mem': 0,
 'save_checkpoint_every': 20,
 'save_gif_every': None,
 'seed': 0,
 'sigma': 0.2,
 'subsamples': 1000,
 'tags': [],
 'train_episodes': 1,
 'warmup_episodes': 1}


AttributeError: 'Namespace' object has no attribute 'total_timesteps'

In [11]:
# ** Device Settings **
# Make the requested backend JAX's default device for the rest of the run.
device_setting = args.device
if device_setting == "gpu":
    # jax.devices("gpu") raises when no GPU backend is available; let that
    # surface instead of silently training on another device.
    device = jax.devices("gpu")[0]
    jax.config.update("jax_default_device", device)  # Update the default device to GPU

    print(f"Currently running on \033[92mGPU {RESET}")
elif device_setting == "cpu":
    # NOTE(review): jax is already imported here, so this variable presumably
    # does not reconfigure JAX in this kernel; it is inherited by any processes
    # spawned afterwards (os.environ is copied to children) — confirm it is needed.
    os.environ["JAX_PLATFORMS"] = "cpu"
    # No try/except: a failure here (or a KeyboardInterrupt) should propagate
    # rather than be retried with the very same calls.
    device = jax.devices("cpu")[0]
    jax.config.update("jax_default_device", device)  # Update the default device to CPU

    print(f"Currently running on \033[92mCPU {RESET}")
else:
    raise ValueError(f"Unknown device setting {device_setting}, please use <cpu> or <gpu>")


# ** Logging Settings **
# TensorBoard event files go into the run directory; the complete argument set
# is recorded once as a two-column Markdown table.
writer = SummaryWriter(f"{run_path}")
_hparam_table_rows = [f"|{key}|{value}|" for key, value in vars(args).items()]
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n" + "\n".join(_hparam_table_rows),
)

# Plain-text log file in the run directory, opened in append mode so that a
# resumed run (same run_path) keeps the earlier contents.
log_file = open(os.path.join(run_path, "log_file.txt"), "a", encoding="utf-8")

# ** Hyperparameters Settings **
subsamples = args.subsamples
la = args.la
eta = args.eta
gamma = args.gamma
q_memories = args.q_mem

parallel_envs = args.parallel_envs
assert parallel_envs > 0, "Number of parallel environments must be greater than 0"


def _episodes_to_batches(n_episodes, n_envs, kind):
    """Convert an episode budget into a number of `n_envs`-sized batches.

    The count is rounded *up* (ceil division). When `n_episodes` is not a
    multiple of `n_envs`, a UserWarning reports the effective number of
    episodes that will be run.

    Args:
        n_episodes: requested number of episodes.
        n_envs: number of parallel environments (must be non-zero).
        kind: label used in the warning text ("warmup", "training", ...).

    Returns:
        ceil(n_episodes / n_envs) as an int.
    """
    n_batches, remainder = divmod(n_episodes, n_envs)
    if remainder != 0:
        n_batches += 1
        warnings.warn(
            f"Number of {kind} episodes {n_episodes} not divisible by parallel environments {n_envs}, considering {n_batches * n_envs} {kind} episodes",
            UserWarning,
            stacklevel=2,
        )
    return n_batches


# The episode budgets below are converted into numbers of parallel batches,
# which is what powr.train receives further down.
warmup_episodes = args.warmup_episodes
assert warmup_episodes > 0, "Number of warmup episodes must be greater than 0"
warmup_episodes = _episodes_to_batches(warmup_episodes, parallel_envs, "warmup")

epochs = args.epochs
train_episodes = _episodes_to_batches(args.train_episodes, parallel_envs, "training")

iter_pmd = args.iter_pmd
eval_episodes = _episodes_to_batches(args.eval_episodes, parallel_envs, "evaluation")

assert args.early_stopping is None or args.early_stopping > 0, "Number of early stopping episodes must be greater than 0"
# NOTE(review): unlike the budgets above this is true division, so it may be a
# float (e.g. 5 / 3); confirm how POWR compares it before changing to ceil.
early_stopping = args.early_stopping/parallel_envs if args.early_stopping is not None else None

save_gif_every = args.save_gif_every
eval_every = args.eval_every
save_checkpoint_every = args.save_checkpoint_every
delete_Q_memory = args.delete_Q_memory

# ** Environment Settings **
# `parse_env` returns the environment for `args.env` (presumably vectorised
# over `parallel_envs` copies) together with the kernel to use on it.
env, kernel = parse_env(args.env, parallel_envs, args.sigma)

# ** Kernel Settings **
def to_be_jit_kernel(X, Y):
    """Return ``kernel(X, Y)``; exists so the call can be wrapped by ``jax.jit``."""
    k_xy = kernel(X, Y)
    return k_xy


jit_kernel = jax.jit(to_be_jit_kernel)
# vmapped variant of the jitted kernel.
v_jit_kernel = jax.vmap(jit_kernel)  # TODO Not used

# ** Seed Settings**
set_seed(args.seed)

# ** POWR Initialization **
# NOTE(review): the same `env` object is passed twice (presumably as the
# training and the evaluation environment) — confirm sharing it is intended.
powr = POWR(
    env,
    env,
    args,
    # algorithm hyper-parameters
    eta=eta,
    la=la,
    gamma=gamma,
    kernel=jit_kernel,
    subsamples=subsamples,
    # memory / early-stopping options
    q_memories=q_memories,
    delete_Q_memory=delete_Q_memory,
    early_stopping=early_stopping,
    # logging, checkpointing and resumption
    tensorboard_writer=writer,
    starting_logging_epoch=starting_epoch,
    starting_logging_timestep=total_timesteps,
    run_path=run_path,
    checkpoint=checkpoint,
    # runtime
    seed=args.seed,
    device=device_setting,
    offline=args.offline,
)

# ** Training **
try:
    print(f"\033[1m\033[94mTraining the policy{RESET}")
    powr.train(
        epochs=epochs,
        warmup_episodes=warmup_episodes,
        train_episodes=train_episodes,
        eval_episodes=eval_episodes,
        iterations_pmd=iter_pmd,
        eval_every=eval_every,
        save_gif_every=save_gif_every,
        save_checkpoint_every=save_checkpoint_every,
        args_to_save=args,
    )

    # ** Testing **
    print(f"\033[1m\033[94mTesting the policy{RESET}")
    # NOTE(review): train() receives episode budgets already divided by
    # parallel_envs; confirm whether evaluate() counts episodes or batches
    # before trusting the "over N episodes" wording below.
    n_test_episodes = 10
    mean_reward = powr.evaluate(n_test_episodes)

    print(f"Policy mean reward over {n_test_episodes} episodes: {mean_reward}")
finally:
    # Release what was opened above even if training fails: flush/close the
    # TensorBoard writer and the log file, then end the W&B run so a re-run of
    # the notebook starts cleanly.
    writer.close()
    log_file.close()
    wandb.finish()

usage: ipykernel_launcher.py [-h] [--env ENV] [--group GROUP]
                             [--project PROJECT] [--la LA] [--eta ETA]
                             [--gamma GAMMA] [--sigma SIGMA] [--q-mem Q_MEM]
                             [--delete-Q-memory]
                             [--early-stopping EARLY_STOPPING]
                             [--warmup-episodes WARMUP_EPISODES]
                             [--epochs EPOCHS]
                             [--train-episodes TRAIN_EPISODES]
                             [--parallel-envs PARALLEL_ENVS]
                             [--subsamples SUBSAMPLES] [--iter-pmd ITER_PMD]
                             [--eval-episodes EVAL_EPISODES]
                             [--save-gif-every SAVE_GIF_EVERY]
                             [--save-checkpoint-every SAVE_CHECKPOINT_EVERY]
                             [--eval-every EVAL_EVERY] [--seed SEED]
                             [--checkpoint CHECKPOINT] [--device DEVICE]
                      

SystemExit: 2