# Elyza Japanese Llama2 TGI SageMaker Finetuning

This is sample code that fine-tunes `elyza/ELYZA-japanese-Llama-2-7b-instruct` with LoRA on SageMaker and deploys the result to a SageMaker endpoint running Text Generation Inference (TGI).

In [None]:
# Upgrade the SageMaker Python SDK, pip, and the AWS SDK (boto3/botocore) used in this notebook.
%pip install sagemaker pip boto3 botocore --upgrade  --quiet

In [None]:
import sagemaker, boto3, json
from sagemaker import get_execution_role
from sagemaker.huggingface import HuggingFace, HuggingFaceModel, get_huggingface_llm_image_uri

# IAM role passed to the training job and to the deployed model below.
role = get_execution_role()
region = boto3.Session().region_name
sess = sagemaker.Session()
# Default SageMaker S3 bucket for this session.
bucket = sess.default_bucket()

# Shown as the cell output to confirm which SDK version is installed.
sagemaker.__version__

## Upload Data

Fine-tuning 用の日本語データをローカルフォルダに配置し、S3 にアップロードします。

ここでは例として [Databricks Dolly 15k](https://github.com/databrickslabs/dolly/tree/master/data) データセットを日本語に翻訳したものを利用します。(License: [Creative Commons Attribution-ShareAlike 3.0 Unported License](https://creativecommons.org/licenses/by-sa/3.0/legalcode))

In [None]:
# Download the Japanese translation of Dolly 15k; --create-dirs creates ./data if needed.
!curl -L https://huggingface.co/datasets/kunishou/databricks-dolly-15k-ja/resolve/main/databricks-dolly-15k-ja.json --create-dirs -o ./data/databricks-dolly-15k-ja.json

In [None]:
# Peek at the beginning of the downloaded JSON to check its structure.
!head ./data/databricks-dolly-15k-ja.json

In [None]:
import pandas as pd
# Take only the first 3,000 records (a subset, presumably to keep the training job short)
# and save them as a JSON array of records, leaving Japanese text unescaped.
df = pd.read_json("./data/databricks-dolly-15k-ja.json").head(3000)
df.to_json(
    "./data/databricks-dolly-15k-ja-filtered.json",
    orient="records",
    force_ascii=False,
)

In [None]:
# Upload the filtered dataset to S3 under the "Dolly" key prefix. The returned S3 URI
# is passed as the 'train' channel to the training job's fit() call.
input_train = sess.upload_data(
    path="./data/databricks-dolly-15k-ja-filtered.json",
    key_prefix="Dolly"
)
# Display the S3 URI as the cell output.
input_train

## Fine-tuning

In [None]:
import os

base_job_name = "Elyza"
# Hyperparameters passed by SageMaker to the entry point script (finetune.py) as CLI arguments.
hyperparameters = {
    'base_model': 'elyza/ELYZA-japanese-Llama-2-7b-instruct',
    # 'load_in_8bit': True,
    # 'load_in_4bit': True,
    # The 'train' channel is mounted at /opt/ml/input/data/train inside the training container.
    'data_path': '/opt/ml/input/data/train/databricks-dolly-15k-ja-filtered.json',
    'save_merged': True,
    'num_epochs': 1,  # default 3
    'cutoff_len': 512,
    'group_by_length': False,
    'output_dir': '/opt/ml/model',
    # 'resume_from_checkpoint': '/opt/ml/checkpoints',
    # Llama-2 attention projections. The previous value also listed fc_in/fc_out, which are
    # GPT-J-style MLP names that do not exist in Llama-2 and therefore matched no module.
    'lora_target_modules': '[q_proj,v_proj]',
    'lora_r': 16,
    'batch_size': 8,
    'micro_batch_size': 8,
    'prompt_template_name': 'llama2',
    ## wandb setting
    # 'wandb_project': 'elyza',
    # 'wandb_run_name': "elyza",
    # 'wandb_watch': "gradients",  # options: false | gradients | all
    # 'wandb_log_model': "false",  # options: false | true
}
# Environment variables for the training container. Do not paste a real API key into the
# notebook: export WANDB_API_KEY before starting Jupyter; otherwise the placeholder is sent.
environment = {
    'WANDB_API_KEY': os.environ.get('WANDB_API_KEY', '<API KEY>'),
}

In [None]:
# Launch the LoRA fine-tuning job: scripts/code/finetune.py runs in the Hugging Face
# PyTorch training container on a single ml.g5.2xlarge.
huggingface_estimator = HuggingFace(
    base_job_name=base_job_name,
    role=role,
    entry_point='finetune.py',
    source_dir='./scripts/code',
    instance_type='ml.g5.2xlarge',
    instance_count=1,
    volume_size=200,
    transformers_version='4.26',
    pytorch_version='1.13',
    py_version='py39',
    # Managed spot training; max_wait (seconds) caps the total time including waiting
    # for spot capacity.
    use_spot_instances=True,
    max_wait=86400,
    hyperparameters=hyperparameters,
    environment=environment,
    # Regexes scraped from the training logs. Raw strings avoid the invalid "\d" escape
    # sequence in a normal string literal; "\d+" also captures losses >= 10.
    metric_definitions=[
        {'Name': 'eval_loss', 'Regex': r"'eval_loss': (\d+\.\d+)"},
        {'Name': 'train_loss', 'Regex': r"'loss': (\d+\.\d+)"},
    ],
    # checkpoint_s3_uri=f"s3://{bucket}/{base_job_name}/checkpoint/",
)
# The 'train' channel is what the data_path hyperparameter points at inside the container.
huggingface_estimator.fit({'train': input_train})

## Retrieve Model Artifact

In [None]:
# Get Model Artifact Location

import boto3
import sagemaker

def get_latest_training_job_artifact(base_job_name):
    """Return the S3 model artifact URI of the newest completed training job.

    Used to recover ``model_data`` when the estimator object is no longer available
    (e.g. after a kernel restart).

    Args:
        base_job_name: Substring matched against training job names
            (the ``NameContains`` filter of ``ListTrainingJobs``).

    Returns:
        The ``S3ModelArtifacts`` URI of the most recently created matching job whose
        status is ``Completed``.

    Raises:
        ValueError: If no completed training job name contains ``base_job_name``.
    """
    sagemaker_client = boto3.client('sagemaker')
    response = sagemaker_client.list_training_jobs(
        NameContains=base_job_name,
        # In-progress or failed jobs may have no ModelArtifacts; only consider finished ones.
        StatusEquals='Completed',
        SortBy='CreationTime',
        SortOrder='Descending',
        MaxResults=1,
    )
    summaries = response['TrainingJobSummaries']
    if not summaries:
        raise ValueError(
            f"No completed training job found whose name contains {base_job_name!r}"
        )
    description = sagemaker_client.describe_training_job(
        TrainingJobName=summaries[0]['TrainingJobName']
    )
    return description['ModelArtifacts']['S3ModelArtifacts']

# Prefer the artifact of the estimator trained in this session.
try:
    model_data = huggingface_estimator.model_data
except NameError:
    # `huggingface_estimator` no longer exists after a kernel restart; fall back to the
    # newest training job whose name contains "Elyza". Catching only NameError keeps
    # unrelated failures (and KeyboardInterrupt) visible instead of silently swallowed.
    model_data = get_latest_training_job_artifact('Elyza')

print(model_data)

## Deploy Model

Text Generation Inference (TGI) コンテナを使って SageMaker エンドポイントにデプロイします。

In [None]:
# Directory where the model artifact (LoRA) is extracted inside the container.
# AutoPeftModelForCausalLM loads the base model named in adapter_config.json and
# applies the LoRA weights (available from TGI v1.0.1).
# NOTE(review): the TGI image requested in the next cell is version 0.9.3, older than
# v1.0.1; this presumably works only because training used save_merged=True — confirm
# what finetune.py writes to /opt/ml/model.
hf_model_id = "/opt/ml/model"
number_of_gpus = 1 # number of gpus to use for inference and tensor parallelism
health_check_timeout = 300 # Increase the timeout for the health check to 5 minutes for downloading the model
instance_type = "ml.g5.2xlarge" # instance type to use for deployment

In [None]:
# Look up the Hugging Face LLM inference (TGI) container image, pinned to version 0.9.3.
# NOTE(review): serving a PEFT adapter directory needs TGI >= 1.0.1; with 0.9.3 the
# artifact presumably must already contain merged weights — confirm.
llm_image = get_huggingface_llm_image_uri(
    "huggingface",
    version="0.9.3"
)
# Unique endpoint name built from this prefix.
endpoint_name = sagemaker.utils.name_from_base("elyza-7b-lora")
llm_model = HuggingFaceModel(
    role=role,
    image_uri=llm_image,
    model_data=model_data,  # S3 URI of the training artifact resolved in the previous section
    env={
        'HF_MODEL_ID': hf_model_id,  # load the model from the extracted artifact directory
        'MODEL_CACHE_ROOT': "/opt/ml/model",
        'SM_NUM_GPUS': str(number_of_gpus),  # GPUs used for tensor parallelism
        'DTYPE': 'bfloat16',
        'MAX_INPUT_LENGTH': "2048",  # Max length of input text
        'MAX_TOTAL_TOKENS': "4096",  # Max length of the generation (including input text)
        'MAX_BATCH_TOTAL_TOKENS': "8192",  # TGI token budget across all requests in a batch
    }
)
# Create the real-time endpoint; the container gets health_check_timeout seconds to start.
llm = llm_model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    container_startup_health_check_timeout=health_check_timeout,
    endpoint_name=endpoint_name,
)

## Run Inference

In [None]:
import json
import boto3
import logging
import io

# Stream log records at INFO level to stdout. The empty logger name "" means all loggers,
# so INFO messages from every library are shown, not just boto3's.
boto3.set_stream_logger("",logging.INFO)
# SageMaker Runtime client used to invoke the endpoint.
smr = boto3.client('sagemaker-runtime')

# Name of the endpoint created by llm_model.deploy().
endpoint_name = llm.endpoint_name


class LineIterator:
    """
    A helper class for parsing the byte stream returned by the TGI container.

    The output of the model will be in the following format:
    ```
    b'data:{"token": {"text": " a"}}\n\n'
    b'data:{"token": {"text": " challenging"}}\n\n'
    b'data:{"token": {"text": " problem"
    b'}}'
    ...
    ```

    While usually each PayloadPart event from the event stream contains a byte array
    with a full json, this is not guaranteed and some of the json objects may be split across
    PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```

    This class accounts for this by appending every PayloadPart's bytes to an internal
    buffer and yielding complete lines (without the trailing '\n') from it. It tracks the
    position of the last read so previously returned bytes are not yielded again, and keeps
    any pending bytes that do not yet end with '\n' so truncated lines are concatenated with
    the next chunk. When the stream ends, any remaining unterminated bytes are returned as
    a final line.

    Args:
        stream: Iterable of event dicts, e.g. the ``Body`` of an
            ``invoke_endpoint_with_response_stream`` response.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0  # offset of the first byte not yet returned to the caller

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            # readline() stops after the first '\n' or at the end of the buffer, so a line
            # without a trailing '\n' is exactly the unconsumed remainder.
            line = self.buffer.readline()
            if line and line[-1] == ord('\n'):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                # The stream is exhausted. Flush trailing bytes that were never terminated
                # by '\n' as a final line; retrying would loop forever because an exhausted
                # iterator keeps raising StopIteration.
                if line:
                    self.read_pos += len(line)
                    return line
                raise
            if 'PayloadPart' not in chunk:
                # Other event types (e.g. error events) are dicts; print them as-is instead
                # of concatenating a dict to a str.
                print('Unknown event type:', chunk)
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])


            
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
stop_token = '</s>'
# Default system prompt ("You are a sincere and excellent Japanese assistant.").
DEFAULT_SYSTEM_PROMPT = "あなたは誠実で優秀な日本人のアシスタントです。"


def _build_prompt(text, system):
    """Return *text* wrapped in the Llama-2 chat template with *system* as the system prompt.

    Layout: "<s>", "[INST] ", the <<SYS>> block holding *system*, *text*, then " [/INST] ".
    NOTE(review): TGI's tokenizer may also prepend a BOS token; verify the request does
    not end up with two.
    """
    return "{bos_token}{b_inst} {system}{prompt} {e_inst} ".format(
        bos_token="<s>",
        b_inst=B_INST,
        system=f"{B_SYS}{system}{E_SYS}",
        prompt=text,
        e_inst=E_INST,
    )


def _build_body(text, system, stream=False):
    """Return the TGI request payload (prompt plus sampling parameters) for *text*."""
    body = {
        "inputs": _build_prompt(text, system),
        "parameters": {
            "max_new_tokens": 512,
            "return_full_text": False,  # return only the completion, not the prompt
            "do_sample": True,
            "temperature": 0.3,
            "stop": [stop_token],
        },
    }
    if stream:
        body["stream"] = True
    return body


def inference(text, system=DEFAULT_SYSTEM_PROMPT):
    """Send *text* to the endpoint and print the complete generated reply.

    Args:
        text: User message.
        system: System prompt placed inside the <<SYS>> block.
    """
    response = smr.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Accept='application/json',
        Body=json.dumps(_build_body(text, system)),
    )
    # The non-streaming response is a JSON list whose first item holds 'generated_text'.
    print(json.loads(response['Body'].read())[0]['generated_text'])


def inference_stream(text, system=DEFAULT_SYSTEM_PROMPT):
    """Send *text* to the endpoint and print the reply token by token as it streams in.

    Args:
        text: User message.
        system: System prompt placed inside the <<SYS>> block.

    Raises:
        RuntimeError: If a stream event carries no 'token' (e.g. an error event).
    """
    response = smr.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(_build_body(text, system, stream=True)),
        ContentType='application/json',
    )
    start_json = b'{'
    for line in LineIterator(response['Body']):
        # Event lines look like b'data:{...}'; blank separator lines contain no '{'.
        if start_json not in line:
            continue
        data = json.loads(line[line.find(start_json):].decode('utf-8'))
        token = data.get('token')
        if token is None:
            raise RuntimeError(f"Unexpected event from the endpoint: {data}")
        # Hide the end-of-sequence marker; flush so tokens appear as soon as they arrive.
        if stop_token not in token['text']:
            print(token['text'], end='', flush=True)
    print()  # terminate the streamed output with a newline

In [None]:
# Ask "What is AWS? Please summarize it in one phrase." and stream the answer to stdout.
inference_stream("AWSとはなんですか？一言で要約してください")

## Delete Endpoint

In [None]:
# Clean up: delete the SageMaker model and the endpoint created by llm_model.deploy()
# so the inference instance is not left running.
llm.delete_model()
llm.delete_endpoint()