## Train Llama 2 on AWS Trainium using Amazon SageMaker

In [None]:
#retrive the docker image URL stored in step 1
%store -r docker_image 

use_fsx = False # set this to true and check other fsx parameters to use FSxL for the job
use_checkpoint = True # set this to True if you ran Notebook 4 and have checkpoint created.

In [None]:
!pip install -U sagemaker boto3

In [None]:
import sagemaker
import boto3

sess = sagemaker.Session()

# Bucket used for uploading data, models and logs. Leaving it as None falls back
# to the session's default bucket, which SageMaker creates if it does not exist.
sagemaker_session_bucket = None
if sess is not None and sagemaker_session_bucket is None:
    sagemaker_session_bucket = sess.default_bucket()

# Prefer the role attached to the current SageMaker environment. When
# get_execution_role() raises ValueError (presumably because no role can be
# inferred outside SageMaker), look up the role named 'sagemaker_execution_role'.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Rebuild the session so the chosen bucket is its default bucket.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")


NeMo Megatron uses Hydra-based configuration with YAML config files. To support this, we pass our overrides as SageMaker hyperparameters, which SageMaker forwards as command-line arguments to the entry-point script. The entry-point script then uses the Hydra compose API to read these arguments and apply them as overrides to the config.

Note: `trainer.num_nodes` must equal the number of instances in the estimator (`instance_count`). Both are set from `NUM_NODES` below, so change `NUM_NODES` instead of editing either setting directly.

In [None]:
# Number of training instances. It feeds trainer.num_nodes here and the
# estimator's instance_count later, so the two cannot drift apart.
NUM_NODES = 2

# Hydra-style "dotted.key" overrides forwarded to the entry-point script as
# SageMaker hyperparameters. Key order is kept stable on purpose.
hyperparameters = {
    # --- Trainer ---
    "trainer.devices": 32,
    "trainer.num_nodes": NUM_NODES,
    # Passed as the string "null"; presumably the Hydra override turns it into None.
    "trainer.max_epochs": "null",
    "trainer.max_steps": 500,
    "trainer.val_check_interval": 0.99,
    "trainer.log_every_n_steps": 1,
    "trainer.limit_val_batches": 1,
    "trainer.limit_test_batches": 1,
    "trainer.accumulate_grad_batches": 1,
    "trainer.precision": 32,
    # --- Model: 7B configuration ---
    "model.micro_batch_size": 1,
    "model.global_batch_size": 256,
    "model.tensor_model_parallel_size": 8,
    "model.pipeline_model_parallel_size": 1,
    "model.max_position_embeddings": 4096,
    "model.encoder_seq_length": 4096,
    "model.hidden_size": 4096,
    "model.num_layers": 32,
    "model.num_attention_heads": 32,
    "model.init_method_std": 0.021,
    "model.hidden_dropout": 0,
    "model.layernorm_epsilon": 1e-5,
    # --- Data ---
    "model.data.num_workers": 1,
    "model.data.seq_length": 4096,
    # Disabled override: "model.data.splits_string": "'980,10,10'",
    # --- Optimizer and LR schedule ---
    "model.optim.name": "adamw",
    "model.optim.lr": 3.0e-4,
    "model.optim.betas": "[0.9,0.95]",
    "model.optim.weight_decay": 0.1,
    "model.optim.sched.name": "CosineAnnealing",
    "model.optim.sched.warmup_steps": 10,
    "model.optim.sched.constant_steps": 0,
    "model.optim.sched.min_lr": 3.0e-5,
    "model.optim.capturable": True,
    # --- Parallelism and activation checkpointing ---
    "model.sequence_parallel": True,
    "model.activations_checkpoint_granularity": "full",
    "model.activations_checkpoint_method": "uniform",
    "model.activations_checkpoint_num_layers": 1,
    "model.save_xser": True,
    # --- Experiment manager ---
    "exp_manager.create_tensorboard_logger": False,
    "exp_manager.resume_if_exists": False,
    "exp_manager.resume_ignore_no_checkpoint": False,
    "exp_manager.create_checkpoint_callback": True,
    "exp_manager.checkpoint_callback_params.train_time_interval": 36000,
    "exp_manager.checkpoint_callback_params.save_last": True,
    "model.use_cpu_initialization": True,
}


In [None]:
# Retrive the FSX details from Store Magic 

if use_fsx:
    #retrive fsx details
    %store -r fsx_id
    %store -r sec_group
    %store -r private_subnet_id
    %store -r fsx_mount
    %store -r fsx_file_system_path
else:
    use_fsx = False

In [None]:
# Decide where the job reads its artifacts (tokenizer, tokenized corpus, optional
# pre-converted weights, Neuron compile cache) and writes its logs.
from sagemaker.inputs import FileSystemInput

if use_fsx:
    FS_ID = fsx_id                       # FSx file system ID
    FS_BASE_PATH = "/" + fsx_mount + "/" + fsx_file_system_path  # path to mount
    SUBNET_ID = private_subnet_id        # subnet the SageMaker job runs in
    SEC_GRP = [sec_group]

    # The file system is attached read/write as the "train" input channel.
    fsx_train_input = FileSystemInput(
        file_system_id=FS_ID,
        file_system_type='FSxLustre',
        directory_path=FS_BASE_PATH + "/nemo_llama",
        file_system_access_mode="rw"
    )
    data_channels = {"train": fsx_train_input}
    artifact_root = "/opt/ml/input/data/train"
    log_dir = artifact_root + "/logs"
else:
    # Without FSx we rely on SageMaker's S3 checkpoint sync, because the job
    # needs read/write access to these paths.
    checkpoint_s3_uri = "s3://" + sagemaker_session_bucket + "/nemo_llama_experiment"
    checkpoint_dir = '/opt/ml/checkpoints'
    artifact_root = checkpoint_dir
    log_dir = "/opt/ml/model"

hyperparameters["model.tokenizer.type"] = artifact_root + "/llama7b-hf"
hyperparameters["model.data.data_prefix"] = "[1.0," + artifact_root + "/wiki_text_document]"
if use_checkpoint:
    hyperparameters["model.resume_from_checkpoint"] = (
        artifact_root + "/llama7b_weights/mp_rank_07/model_optim_rng.ckpt"
    )
    hyperparameters["model.load_xser"] = True
hyperparameters["exp_manager.explicit_log_dir"] = log_dir
cache_dir = artifact_root + "/neuron_cache"

## Launch Training Job

We launch the training job on `ml.trn1.32xlarge` instances.

**_NOTE:_** When using S3 as the data source, first run the training job on a single node for a few training steps, then stop it and increase the number of nodes. This is necessary because the NeMo dataset loader creates index files and stores them in the checkpoint path when training starts. The rank-0 process builds these files, and processes on the other nodes read them once it is done. Because the checkpoint path is synchronized to S3, processes on other nodes may hit a file-not-found error when they try to read the index files: the synchronization of files from local disk to S3 does not complete in time.

In [None]:
import time

from sagemaker.pytorch import PyTorch


# Timestamped so that every launch gets a unique training job name.
job_name = f'llama-neuron-nemo-{time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())}'

# The Hydra config and the estimator must agree on the node count.
assert hyperparameters["trainer.num_nodes"] == NUM_NODES, \
    "trainer.num_nodes must match the estimator's instance_count"

# Container environment: EFA/libfabric settings for inter-node communication,
# BF16 on XLA, and the Neuron compiler cache directory chosen in the storage cell.
env = {}

env['FI_PROVIDER'] = 'efa'
env['NCCL_PROTO'] = 'simple'
env['FI_EFA_USE_DEVICE_RDMA'] = '1'
env['RDMAV_FORK_SAFE'] = '1'
env['FI_EFA_FORK_SAFE'] = '1'
# '^' excludes the listed prefixes: use any interface except loopback and docker
# bridges. This is set once; a second, conflicting assignment was dead code.
env['NCCL_SOCKET_IFNAME'] = '^lo,docker'
env['XLA_USE_BF16'] = '1'
env['NEURON_CC_FLAGS'] = "--cache_dir=" + cache_dir

# estimator
pt_estimator = PyTorch(
    entry_point='train.py',
    source_dir='./scripts',
    instance_type="ml.trn1.32xlarge",
    image_uri=docker_image,
    instance_count=NUM_NODES,
    hyperparameters=hyperparameters,
    role=role,
    job_name=job_name,
    environment=env,
    disable_output_compression=True,
    # S3 checkpoint sync is only used when FSx is not.
    checkpoint_s3_uri=checkpoint_s3_uri if not use_fsx else None,
    checkpoint_local_path=checkpoint_dir if not use_fsx else None,
    # The SDK only attaches a VPC config when BOTH subnets and security group IDs
    # are given. Passing subnets alone silently drops the VPC config, and the job
    # then cannot reach the FSx file system.
    subnets=[SUBNET_ID] if use_fsx else None,
    security_group_ids=SEC_GRP if use_fsx else None,
    # Retain the instances in a warm pool for 10 minutes after the job ends.
    keep_alive_period_in_seconds=600,
    distribution={"torch_distributed": {"enabled": True}}  # enable torchrun
)

In [None]:
# Start training. With FSx, the file system is passed as the "train" input
# channel; otherwise no channels are passed (fit's default is None) and the job
# reads its data from the S3-synced checkpoint directory.
pt_estimator.fit(data_channels if use_fsx else None)

In [None]:
# Reduce the keep-alive period set on the estimator (600 s) to 0 so the
# instances are released rather than kept warm after this job.
sess.update_training_job(
    pt_estimator.latest_training_job.job_name,
    resource_config={"KeepAlivePeriodInSeconds": 0},
)