## RLHF for Llama 2 7B using Amazon SageMaker (SM)

In [None]:
# Upgrade the sagemaker and boto3 packages (runs pip through the notebook's shell escape).
! pip install -U sagemaker boto3


In [None]:
import sagemaker
import boto3
# Bootstrap session, used here only to look up the default bucket.
sess = sagemaker.Session()

# Bucket SageMaker uses for uploading data, models and logs.
# SageMaker will automatically create this bucket if it does not exist.
# Leave as None to fall back to the session's default bucket.
sagemaker_session_bucket = None
if sess is not None and sagemaker_session_bucket is None:
    sagemaker_session_bucket = sess.default_bucket()

# Resolve the IAM role ARN handed to the training job.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    # get_execution_role() raised ValueError, so look the role up by its
    # name through IAM instead.
    iam = boto3.client("iam")
    role_description = iam.get_role(RoleName="sagemaker_execution_role")
    role = role_description["Role"]["Arn"]

# Final session, pinned to the bucket chosen above.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

# Summarize the resolved configuration.
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")


In [None]:
# Container directory holding the input data; the actor/critic model paths
# below sit under it. Kept as a notebook-level constant.
SM_TRAIN_DIR = "/opt/ml/input/data"

# Hyperparameters handed to the training entry point via the estimator.
# NOTE(review): entries whose value is "" are presumably interpreted as
# value-less boolean switches by the entry script -- confirm against its
# argument parser.
hyperparameters = {
    # Actor / critic checkpoints; the sub-directory names match the "sft"
    # and "reward" channel keys passed to estimator.fit().
    "actor_model_name_or_path": "/opt/ml/input/data/sft",
    "critic_model_name_or_path": "/opt/ml/input/data/reward",
    "output_dir": "/opt/ml/model",
    # Dataset selection.
    "data_path": "Dahoas/rm-static",
    "data_split": "2,4,4",
    "num_padding_at_beginning": 1,  # alternative value to try: 0
    # Batch sizes and generation / PPO loop settings.
    "per_device_generation_batch_size": 1,
    "per_device_training_batch_size": 1,
    "generation_batches": 1,
    "ppo_epochs": 1,
    "max_answer_seq_len": 256,
    "max_prompt_seq_len": 256,
    # Optimizer settings.
    "actor_learning_rate": 9.65e-6,
    "critic_learning_rate": 5e-6,
    "actor_weight_decay": 0.1,
    "critic_weight_decay": 0.1,
    "num_train_epochs": 1,
    "lr_scheduler_type": "cosine",
    "gradient_accumulation_steps": 1,
    # Switch-style options (empty-string values).
    "actor_gradient_checkpointing": "",
    "critic_gradient_checkpointing": "",
    "offload_reference_model": "",
    "disable_actor_dropout": "",
    "deepspeed": "",
    "actor_zero_stage": 3,
    "critic_zero_stage": 3,
    "enable_hybrid_engine": "",
    # Optional LoRA / dtype settings, disabled by default. Uncomment to use:
    # "enable_mixed_precision_lora": "",
    # "actor_lora_dim": 64,
    # "critic_lora_dim": 64,
    # "critic_lora_module_name": "layers.",
    # "actor_lora_module_name": "layers.",
    # "dtype": "bf16",
    # NOTE(review): this token travels as a plain hyperparameter; confirm
    # that is acceptable for a secret in your account.
    "access_token": "hf_XXXXX",  # Replace the access token
}

In [None]:
# Environment variables passed to the estimator for the training container.
# NOTE(review): these look like libfabric/NCCL settings for EFA networking --
# confirm they are required for the chosen instance type and count.
env = {
    "FI_PROVIDER": "efa",
    "NCCL_PROTO": "simple",
    "FI_EFA_USE_DEVICE_RDMA": "1",
    "RDMAV_FORK_SAFE": "1",
}

In [None]:
from sagemaker.pytorch import PyTorch

In [None]:
# Echo the assembled hyperparameters for a quick visual check before launching
# (note that the output includes the access_token value).
hyperparameters

In [None]:
# Prefix used for the names of training jobs launched from this notebook.
base_job_name = "llama7b-rlhf-dschat"

# PyTorch estimator that runs scripts/rlhf/main.py with the hyperparameters
# and environment variables assembled in the earlier cells.
estimator = PyTorch(
    # Job identity and permissions.
    base_job_name=base_job_name,
    role=role,
    # Training code.
    source_dir="./scripts",
    entry_point="rlhf/main.py",
    hyperparameters=hyperparameters,
    environment=env,
    # Framework container.
    framework_version="1.13.1",
    py_version="py39",
    # Compute resources and launcher.
    instance_type="ml.p4d.24xlarge",
    instance_count=1,
    distribution={"torch_distributed": {"enabled": True}},
    # Job behaviour; the keep-alive period is reset to 0 by the notebook's
    # last cell once training is done.
    keep_alive_period_in_seconds=600,
    disable_profiler=True,
    disable_output_compression=True,
)

In [None]:
# Placeholders: replace with the locations of the model produced by step 1
# (SFT, used as the actor) and by step 2 (reward model, used as the critic).
# NOTE(review): presumably S3 URIs -- confirm what estimator.fit() accepts.
actor_model_path = "<path to model from step 1 sft>"
critic_model_path = "<path to model from step 2 reward>"

In [None]:
# Input channels for the job; the "sft" and "reward" keys line up with the
# /opt/ml/input/data/sft and /opt/ml/input/data/reward model paths configured
# in the hyperparameters.
input_channels = {"sft": actor_model_path, "reward": critic_model_path}
estimator.fit(input_channels)

In [None]:
# Set the keep-alive period of the most recent training job to 0 seconds
# (the estimator was created with keep_alive_period_in_seconds=600).
latest_job_name = estimator.latest_training_job.job_name
sess.update_training_job(
    latest_job_name,
    resource_config={"KeepAlivePeriodInSeconds": 0},
)