### Llama 2 7B Supervised Fine-Tuning (SFT) with Amazon SageMaker

In [None]:
! pip install -U sagemaker boto3

In [None]:
import sagemaker
import boto3

# Temporary session, used only to resolve the account's default bucket; it is
# replaced by `sess` further down.
bootstrap_session = sagemaker.Session()

# Bucket for uploading data, models and logs. Leave as None to fall back to the
# session's default bucket (SageMaker creates it if it does not exist yet).
sagemaker_session_bucket = None
if sagemaker_session_bucket is None:
    sagemaker_session_bucket = bootstrap_session.default_bucket()

# get_execution_role() raises ValueError when it cannot determine the role
# (e.g. when running outside SageMaker); fall back to an IAM role looked up by name.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    role = boto3.client('iam').get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# The session used by the rest of the notebook, pinned to the chosen bucket.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")


In [None]:
import os
import warnings

# Hyperparameters handed to the training script (sft/main.py) via the
# `hyperparameters=` argument of the estimator cell.
hyperparameters = {}
SM_TRAIN_DIR = "/opt/ml/input/data"  # NOTE(review): not referenced elsewhere in this notebook

# Model and data
hyperparameters["model_name_or_path"] = "meta-llama/Llama-2-7b-hf"
hyperparameters["output_dir"] = "/opt/ml/model"
hyperparameters["data_path"] = "Dahoas/rm-static Dahoas/full-hh-rlhf Dahoas/synthetic-instruct-gptj-pairwise"
# NOTE(review): presumably per-phase split ratios for the DeepSpeed-Chat data
# pipeline — confirm against sft/main.py.
hyperparameters["data_split"] = "2,4,4"

# Batching and sequence length
hyperparameters["per_device_train_batch_size"] = 8
hyperparameters["per_device_eval_batch_size"] = 8
hyperparameters["max_seq_len"] = 1024

# Optimization schedule
hyperparameters["num_train_epochs"] = 1
hyperparameters["learning_rate"] = 9.65e-6
hyperparameters["weight_decay"] = 0.
hyperparameters["gradient_accumulation_steps"] = 1
hyperparameters["lr_scheduler_type"] = "cosine"
hyperparameters["num_warmup_steps"] = 0
hyperparameters["seed"] = 1234

# Memory / distributed settings. The empty-string values are presumably meant
# as bare on/off switches — confirm how sft/main.py parses them.
hyperparameters["gradient_checkpointing"] = ""
hyperparameters["zero_stage"] = 3
hyperparameters["deepspeed"] = ""

# Hugging Face access token: read it from the environment instead of
# hard-coding it in the notebook. Any token that was previously committed in
# this cell should be treated as leaked and revoked.
# NOTE(review): SageMaker hyperparameters are visible in the training job's
# description, so anyone who can describe the job can still read this value.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    hyperparameters["access_token"] = hf_token
else:
    warnings.warn(
        "HF_TOKEN is not set; no access_token will be passed to the training "
        "job, so downloading meta-llama/Llama-2-7b-hf may fail."
    )

In [None]:
# Environment variables injected into the training container through
# `environment=env` in the estimator cell (libfabric / NCCL / RDMA settings).
# NOTE(review): these target EFA networking between nodes; the estimator below
# uses instance_count=1, so confirm whether they are still needed.
env = {
    'FI_PROVIDER': 'efa',
    'NCCL_PROTO': 'simple',
    'FI_EFA_USE_DEVICE_RDMA': '1',
    'RDMAV_FORK_SAFE': '1',
}

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch

In [None]:
# SFT training job definition. `keep_alive_period_in_seconds` keeps the
# instance in a warm pool after the job ends; the last cell of this notebook
# resets it to 0.
base_job_name = "llama7b-sft-dschat"
estimator = PyTorch(
    # Training code: source_dir is packaged and entry_point is resolved inside it.
    entry_point="sft/main.py",
    source_dir="./scripts",
    base_job_name=base_job_name,
    role=role,
    # Framework container
    framework_version="1.13.1",
    py_version="py39",
    # Compute
    instance_type="ml.p4de.24xlarge",
    instance_count=1,
    keep_alive_period_in_seconds=600,
    distribution={"torch_distributed": {"enabled": True}},
    # Script configuration
    hyperparameters=hyperparameters,
    environment=env,
    # Output handling: skip the profiler and upload the model uncompressed.
    disable_profiler=True,
    disable_output_compression=True,
)

In [44]:
# Start the training job and block while streaming its logs. No input channels
# are passed; presumably the script fetches its datasets itself (see data_path).
estimator.fit(inputs=None, wait=True)

Model Parameters: 6.607 B, Latency: 3.14s, TFLOPs: 17.76, Samples/sec: 2.55, Time/seq 0.39s, Batch Size: 8, Sequence Length: 1024
Model Parameters: 6.607 B, Latency: 3.14s, TFLOPs: 17.78, Samples/sec: 2.55, Time/seq 0.39s, Batch Size: 8, Sequence Length: 1024
Model Parameters: 6.607 B, Latency: 3.15s, TFLOPs: 17.71, Samples/sec: 2.54, Time/seq 0.39s, Batch Size: 8, Sequence Length: 1024
Model Parameters: 6.607 B, Latency: 3.14s, TFLOPs: 17.76, Samples/sec: 2.55, Time/seq 0.39s, Batch Size: 8, Sequence Length: 1024
Model Parameters: 6.607 B, Latency: 3.13s, TFLOPs: 17.79, Samples/sec: 2.55, Time/seq 0.39s, Batch Size: 8, Sequence Length: 1024
Model Parameters: 6.607 B, Latency: 3.14s, TFLOPs: 17.75, Samples/sec: 2.55, Time/seq 0.39s, Batch Size: 8, Sequence Length: 1024
Model Parameters: 6.607 B, Latency: 3.15s, TFLOPs: 17.71, Samples/sec: 2.54, Time/seq 0.39s, Batch Size: 8, Sequence Length: 1024
Model Parameters: 6.607 B, Latency: 3.14s, TFLOPs: 17.79, Samples/sec: 2.55, Time/seq 0.39

In [45]:
# Set the finished job's keep-alive period to 0 so the warm-pool instance
# retained by keep_alive_period_in_seconds=600 is released.
training_job_name = estimator.latest_training_job.job_name
sess.update_training_job(
    training_job_name,
    resource_config={"KeepAlivePeriodInSeconds": 0},
)

INFO:sagemaker:Updating training job with name llama7b-sft-dschat-2023-10-04-05-48-33-742
