In [None]:
# Reload edited project modules (e.g. shobu_rl, models) before each cell runs,
# so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import random

import numpy as np
import torch

from shobu_rl import Shobu_RL

In [None]:
# Run configuration.
# DEBUG enables two expensive diagnostics:
#   * CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous, so device-side
#     errors are reported at the Python line that caused them;
#   * autograd anomaly detection pinpoints the op that produced a NaN/inf
#     gradient.
# Both slow training considerably -- set DEBUG = False for real runs.
DEBUG = True
SEED = 0

if DEBUG:
    # Set this before anything touches CUDA (including torch.cuda.is_available
    # below); it is only guaranteed to take effect if set before CUDA is
    # initialised.
    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
torch.autograd.set_detect_anomaly(DEBUG)

# Seed every RNG the training code may draw from so runs are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the PPO network and move its weights to the selected device.
# `nn.Module.to` returns the module itself, so the construction and the
# device transfer can be chained; the trailing bare `model` displays the
# network's repr as the cell output.
from models import Shobu_PPO

model = Shobu_PPO().to(device)
model

In [None]:
# Build three optimizer parameter groups -- policy heads (actor), shared
# backbone, and value head (critic) -- so each can use its own learning rate.
ACTOR_LR = 3e-4
BACKBONE_LR = 3e-4
CRITIC_LR = 1e-4
WEIGHT_DECAY = 1e-3
LR_STEP_SIZE = 1000  # number of scheduler.step() calls between LR decays
LR_GAMMA = 0.5       # multiplicative LR decay factor applied at each decay

critic_params = list(model.critic.parameters())
backbone_params = list(model.backbone.parameters())

# Match parameters by identity (tensors overload `==`, so `p in some_list`
# would compare element-wise). Sets of ids make each membership test O(1)
# instead of rescanning both lists for every model parameter.
critic_ids = {id(p) for p in critic_params}
backbone_ids = {id(p) for p in backbone_params}
# AdamW rejects a parameter that appears in two groups; fail with a clearer
# message if the critic and backbone share weights.
assert critic_ids.isdisjoint(backbone_ids), "critic and backbone share parameters"

# All remaining parameters (policy heads) form the actor group.
actor_params = [
    p for p in model.parameters()
    if id(p) not in critic_ids and id(p) not in backbone_ids
]

optimizer = torch.optim.AdamW([
    {'params': actor_params, 'lr': ACTOR_LR},
    {'params': backbone_params, 'lr': BACKBONE_LR},
    {'params': critic_params, 'lr': CRITIC_LR},
], amsgrad=True, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA
)

In [None]:
# Wrap the model in the Shobu RL training harness imported from `shobu_rl`.
shobu_rl = Shobu_RL(model)

In [None]:
# Run training with the optimizer and LR scheduler built earlier.
# NOTE(review): `sparse=False` presumably selects a dense (shaped) reward
# instead of a win/loss-only reward -- confirm against Shobu_RL.train.
shobu_rl.train(optimizer, scheduler, sparse=False)