# SageMaker JumpStart Foundation Models - OpenChatKit GPT-NeoXT-Chat-Base-20B Chatbot

---
Welcome to Amazon [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html)! You can use SageMaker JumpStart to solve many machine learning tasks with one click in SageMaker Studio, or through the [SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/overview.html#use-prebuilt-models-with-sagemaker-jumpstart).


In this demo notebook, we demonstrate how to use the SageMaker Python SDK to deploy the [GPT-NeoXT-Chat-Base-20B](https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B) model and query the model within an [OpenChatKit](https://github.com/togethercomputer/OpenChatKit) interactive shell. This demonstration provides an open-source Foundation Model chatbot for use within your application.

---

1. [Set Up](#1.-Set-Up)
2. [Select a pre-trained model](#2.-Select-a-pre-trained-model)
3. [Retrieve Artifacts & Deploy an Endpoint](#3.-Retrieve-Artifacts-&-Deploy-an-Endpoint)
4. [Query endpoint and parse response](#4.-Query-endpoint-and-parse-response)
5. [Use an OpenChatKit shell to interact with your deployed endpoint](#5.-Use-an-OpenChatKit-shell-to-interact-with-your-deployed-endpoint)
6. [Clean up the endpoint](#6.-Clean-up-the-endpoint)

Note: This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel, and on an Amazon SageMaker notebook instance with the conda_python3 kernel.

### 1. Set Up

---
Before running the notebook, install its dependencies: ipywidgets and an up-to-date version of the SageMaker Python SDK.

---

In [None]:
!pip install ipywidgets==7.0.0 --quiet
!pip install --upgrade sagemaker --quiet

#### Permissions and environment variables

---
To host on Amazon SageMaker, we need to set up and authenticate the use of AWS services. Here, we use the execution role associated with the current notebook as the AWS account role with SageMaker access. 

---

In [None]:
import json

import boto3
from sagemaker.session import Session

sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()  # execution role attached to this notebook
aws_region = boto3.Session().region_name

### 2. Select a pre-trained model
***
This notebook uses the fp16 version of the GPT-NeoXT-Chat-Base-20B model, identified by the JumpStart model ID in the next cell (a `model_version` of `"*"` selects the latest version). The rest of the notebook, including the `<human>`/`<bot>` prompt format, is written for this model. A complete list of SageMaker pre-trained models is available at [SageMaker pre-trained Models](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html#).
***

In [None]:
model_id, model_version = "huggingface-textgeneration2-gpt-neoxt-chat-base-20b-fp16", "*"

### 3. Retrieve Artifacts & Deploy an Endpoint

***

Using SageMaker, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset. We start by retrieving the `instance_type`, `image_uri`, and `model_uri` for the pre-trained model. To host the pre-trained model, we create an instance of [`sagemaker.model.Model`](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html) and deploy it. This may take a few minutes.

***

In [None]:
from sagemaker import image_uris, instance_types, model_uris
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base


endpoint_name = name_from_base(f"jumpstart-example-{model_id}")

# Retrieve the default inference instance type for the specified model.
instance_type = instance_types.retrieve_default(
    model_id=model_id, model_version=model_version, scope="inference"
)

# Retrieve the inference container image URI for the current region.
image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=instance_type,
)

# Retrieve the S3 URI of the model artifacts.
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)

# Create the SageMaker model instance.
# The model artifacts are prepackaged with the inference script, so the `source_dir` argument to Model is not required.
model = Model(
    image_uri=image_uri,
    model_data=model_uri,
    role=aws_role,
    predictor_cls=Predictor,
    name=endpoint_name,
)

# Deploy the model. Because we deploy through the generic Model class, we pass the Predictor class
# explicitly so that deploy() returns a predictor we can use to run inference through the SageMaker API.
model_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    predictor_cls=Predictor,
    endpoint_name=endpoint_name,
)

### 4. Query endpoint and parse response

***
The endpoint accepts a JSON payload whose `text_inputs` field holds the prompt. The model also supports many advanced parameters that control text generation during inference. They include:

* **max_length:** Model generates text until the output length (which includes the input context length) reaches `max_length`. If specified, it must be a positive integer.
* **max_time:** The maximum amount of time, in seconds, that you allow the computation to run. Generation still finishes the current pass after the allotted time has passed. This setting can help return a response before the endpoint invocation times out.
* **num_return_sequences:** Number of output sequences returned. If specified, it must be a positive integer.
* **num_beams:** Number of beams used in beam search. If specified, it must be an integer greater than or equal to `num_return_sequences`.
* **no_repeat_ngram_size:** The model ensures that no sequence of `no_repeat_ngram_size` words appears more than once in the output sequence. If specified, it must be a positive integer greater than 1.
* **temperature:** Controls the randomness of the output. A higher temperature makes low-probability words more likely to appear, and a lower temperature favors high-probability words. As `temperature` approaches 0, decoding becomes greedy. If specified, it must be a positive float.
* **early_stopping:** If True, text generation is finished when all beam hypotheses reach the end of sentence token. If specified, it must be boolean.
* **do_sample:** If True, sample the next word as per the likelihood. If specified, it must be boolean.
* **top_k:** In each step of text generation, sample from only the `top_k` most likely words. If specified, it must be a positive integer.
* **top_p:** In each step of text generation, sample from the smallest possible set of words with cumulative probability `top_p`. If specified, it must be a float between 0 and 1.
* **seed:** Fix the randomized state for reproducibility. If specified, it must be an integer.
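
The sampling parameters above act on the model's next-word probability distribution. The sketch below illustrates this on a toy four-word vocabulary with made-up logits: `temperature` rescales the logits before the softmax, `top_k` keeps only the k most likely words, and `top_p` keeps the smallest set of most likely words whose cumulative probability reaches p. The helper names and numbers are hypothetical, and this pure-Python version illustrates the general technique; it is not the code that runs inside the endpoint's container.

```python
import math
import random


def apply_temperature(logits, temperature):
    """Turn logits into probabilities with a temperature-scaled softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be a positive float")
    scaled = [x / temperature for x in logits]
    largest = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - largest) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def _renormalize(probs, keep):
    """Zero out every index not in `keep`, then rescale to sum to 1."""
    kept = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    total = sum(kept)
    return [p / total for p in kept]


def top_k_filter(probs, k):
    """Keep only the k most likely words."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return _renormalize(probs, set(ranked[:k]))


def top_p_filter(probs, p):
    """Keep the smallest set of most likely words whose cumulative probability reaches p."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in ranked:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    return _renormalize(probs, keep)


vocab = ["pizza", "dough", "cheese", "banana"]  # hypothetical toy vocabulary
logits = [2.0, 1.5, 1.0, -1.0]  # made-up scores for the next word

for t in (0.2, 1.0):
    print(f"temperature={t}:", [round(p, 3) for p in apply_temperature(logits, t)])

probs = apply_temperature(logits, 1.0)
print("top_k=2:", [round(p, 3) for p in top_k_filter(probs, 2)])
print("top_p=0.9:", [round(p, 3) for p in top_p_filter(probs, 0.9)])

rng = random.Random(0)  # a fixed seed makes the draw reproducible
choice = rng.choices(range(len(vocab)), weights=top_p_filter(probs, 0.9), k=1)[0]
print("sampled word:", vocab[choice])
```

Comparing the two printed temperature rows shows the lower temperature pushing more probability onto the most likely word, which is why a `temperature` close to 0 behaves like greedy decoding.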

You can specify any subset of these parameters when invoking the endpoint. The next cell shows how to invoke the endpoint with some of them.

***

In [None]:
payload = {
    "text_inputs": "<human>: Tell me the steps to make a pizza\n<bot>:",
    "max_length": 50,
    "max_time": 50,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
}


def query_endpoint_with_json_payload(encoded_json, endpoint_name):
    client = boto3.client("runtime.sagemaker")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType="application/json", Body=encoded_json
    )
    return response


query_response = query_endpoint_with_json_payload(
    json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
)


def parse_response(query_response):
    model_predictions = json.loads(query_response["Body"].read())
    generated_text = model_predictions[0][0]["generated_text"]
    return generated_text


generated_texts = parse_response(query_response)
print(generated_texts)
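
Before sending a request, you may want to catch obviously invalid values locally rather than discovering them from a failed endpoint invocation. The following sketch checks a payload dictionary against the constraints listed in the parameter descriptions in this section. The `validate_payload` helper is hypothetical, and the endpoint may enforce additional or slightly different rules, so treat an empty result as a sanity check rather than a guarantee.

```python
def validate_payload(payload):
    """Return a list of problems found in a text-generation payload (empty if none).

    Hypothetical helper: it checks only the documented constraints and does
    not guarantee the endpoint will accept the payload.
    """
    problems = []

    def is_int(value):
        # bool is a subclass of int in Python, so exclude it explicitly
        return isinstance(value, int) and not isinstance(value, bool)

    def is_number(value):
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    if not isinstance(payload.get("text_inputs"), str):
        problems.append("text_inputs must be a string")
    for name in ("max_length", "num_return_sequences", "top_k"):
        if name in payload and not (is_int(payload[name]) and payload[name] > 0):
            problems.append(f"{name} must be a positive integer")
    if "no_repeat_ngram_size" in payload:
        value = payload["no_repeat_ngram_size"]
        if not (is_int(value) and value > 1):
            problems.append("no_repeat_ngram_size must be an integer greater than 1")
    if "num_beams" in payload:
        value = payload["num_beams"]
        num_sequences = payload.get("num_return_sequences", 1)
        if not is_int(num_sequences):
            num_sequences = 1  # already reported above
        if not (is_int(value) and value >= num_sequences):
            problems.append("num_beams must be an integer >= num_return_sequences")
    if "temperature" in payload:
        value = payload["temperature"]
        if not (is_number(value) and value > 0):
            problems.append("temperature must be a positive float")
    if "top_p" in payload:
        value = payload["top_p"]
        if not (is_number(value) and 0 <= value <= 1):
            problems.append("top_p must be a float between 0 and 1")
    for name in ("early_stopping", "do_sample"):
        if name in payload and not isinstance(payload[name], bool):
            problems.append(f"{name} must be a boolean")
    if "seed" in payload and not is_int(payload["seed"]):
        problems.append("seed must be an integer")
    return problems


# A deliberately broken example payload.
example = {"text_inputs": "<human>: Hi\n<bot>:", "num_return_sequences": 2, "num_beams": 1, "top_p": 1.5}
for problem in validate_payload(example):
    print(problem)
```

Applied to the pizza `payload` from the previous cell, `validate_payload(payload)` returns an empty list.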

### 5. Use an OpenChatKit shell to interact with your deployed endpoint

***
OpenChatKit provides a command-line shell for interacting with the chatbot. Here, we show how to use this shell with your deployed endpoint.
***

In [None]:
%%bash
git clone --branch v0.16 --depth 1 https://github.com/togethercomputer/OpenChatKit.git
touch OpenChatKit/__init__.py
touch OpenChatKit/inference/__init__.py

***
Here, we provide a bare-bones simplification of the inference scripts in the OpenChatKit repository that can interact with our deployed SageMaker endpoint. There are three main components:
1. A model object (`JumpStartChatModel`) as a light wrapper around our endpoint query and parsing scripts,
2. A shell interpreter (`JumpStartOpenChatKitShell`) that allows for iterative inference invocations, and
3. A conversation object (`Conversation`) that stores previous human/chatbot interactions within the interactive shell.

The `Conversation` object is used as-is from the OpenChatKit repository. The model and shell objects, however, are defined explicitly in the following cell because they are notably simpler than their OpenChatKit counterparts. We encourage you to explore the cloned repository to see how more advanced features, such as token streaming, moderation models, and retrieval-augmented generation, can be used in this setting.
***
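
To make the prompt format concrete, the sketch below uses a hypothetical, simplified stand-in (`SimpleConversation`) for the `Conversation` object. In this sketch, every turn is appended to a single plain-text prompt using the `<human>:` and `<bot>:` tags seen in the earlier pizza payload, and the model's reply is truncated at any `<human>:` tag it generates. The actual class in `OpenChatKit/inference/conversation.py` does more (for example, it adds a preamble to the prompt and cleans up special tokens), so consult the cloned source for its exact behavior.

```python
class SimpleConversation:
    """Hypothetical, simplified stand-in for OpenChatKit's Conversation class."""

    def __init__(self, human_id="<human>", bot_id="<bot>"):
        self.human_id = human_id
        self.bot_id = bot_id
        self.prompt = ""

    def push_human_turn(self, query):
        # Append the user's turn plus an open bot tag for the model to complete.
        self.prompt += f"{self.human_id}: {query}\n{self.bot_id}:"

    def push_model_response(self, response):
        # The model may keep going and invent the next human turn,
        # so keep only the text before the next human tag.
        bot_turn = response.split(f"{self.human_id}:")[0].strip("\n")
        self.prompt += f"{bot_turn}\n"
        return bot_turn


convo = SimpleConversation()
convo.push_human_turn("Hello!")
reply = convo.push_model_response(" Hi! How can I help?\n<human>: What is")
convo.push_human_turn("Tell me a joke.")
print(convo.prompt)
```

The whole accumulated prompt is sent on every turn, which is how the model "remembers" earlier exchanges; it also means long conversations consume more of the `max_length` budget.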

In [None]:
import cmd
import json
from typing import List, Optional

from OpenChatKit.inference.conversation import Conversation


class JumpStartChatModel:
    human_id = "<human>"
    bot_id = "<bot>"

    def __init__(self, endpoint_name: str):
        self.endpoint_name = endpoint_name

    def do_inference(self, prompt, **kwargs):
        payload = {"text_inputs": prompt, **kwargs}
        payload_json = json.dumps(payload).encode("utf-8")
        query_response = query_endpoint_with_json_payload(payload_json, endpoint_name=self.endpoint_name)
        generated_text = parse_response(query_response)
        # generated_text starts with the prompt (context), so strip it to keep only the new text.
        return generated_text[len(prompt):]


class JumpStartOpenChatKitShell(cmd.Cmd):
    intro = (
        "Welcome to the OpenChatKit chatbot shell, modified to use a SageMaker JumpStart endpoint! Type /help or /? to "
        "list commands. For example, type /quit to exit shell.\n"
    )
    prompt = ">>> "
    
    def __init__(self, endpoint_name: str, cmd_queue: Optional[List[str]] = None, **kwargs):
        super().__init__()
        self._endpoint_name = endpoint_name
        self._payload_kwargs = kwargs
        if cmd_queue is not None:
            self.cmdqueue = cmd_queue

    def preloop(self):
        self._model = JumpStartChatModel(self._endpoint_name)
        self._convo = Conversation(self._model.human_id, self._model.bot_id)

    def precmd(self, line):
        command = line[1:] if line.startswith('/') else 'say ' + line
        return command

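    def do_say(self, arg):
        """Send a message to the chatbot (any input line that does not start with '/')."""
        self._convo.push_human_turn(arg)
        output = self._model.do_inference(self._convo.get_raw_prompt(), **self._payload_kwargs)
        self._convo.push_model_response(output)
        print(self._convo.get_last_turn())

    def do_reset(self, arg):
        """Start a new conversation, discarding all previous turns."""
        self._convo = Conversation(self._model.human_id, self._model.bot_id)

    def do_hyperparameters(self, arg):
        """Show the generation parameters sent with every request."""
        print(f"Hyperparameters: {self._payload_kwargs}\n")

    def do_quit(self, arg):
        """Exit the shell."""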
        return True
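
`JumpStartChatModel.do_inference` makes two assumptions about the endpoint response: the JSON body has the nested-list shape implied by the indexing in `parse_response`, and `generated_text` begins with the prompt, so slicing off `len(prompt)` characters leaves only the new text. The following offline check exercises that logic with a fabricated response body instead of a real endpoint call; `io.BytesIO` stands in for the streaming body that boto3 returns, and the response contents are made up for illustration.

```python
import io
import json


def parse_response(query_response):
    # Same logic as parse_response defined earlier, repeated so this snippet runs on its own.
    model_predictions = json.loads(query_response["Body"].read())
    return model_predictions[0][0]["generated_text"]


prompt = "<human>: Hello!\n<bot>:"
# Fabricated response: nested-list shape mirrors parse_response's indexing,
# and the generated text is assumed to start with the prompt.
fake_body = json.dumps([[{"generated_text": prompt + " Hi there!\n<human>: Bye"}]]).encode("utf-8")
fake_response = {"Body": io.BytesIO(fake_body)}

generated_text = parse_response(fake_response)
completion = generated_text[len(prompt):]  # what do_inference would return
print(repr(completion))
```

If an endpoint returned only the newly generated text instead of echoing the prompt, this slice would cut off the beginning of the reply, so it is worth comparing against one real response before relying on it.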

***
We can now launch this shell as a command loop. The loop repeatedly issues a prompt, accepts input, parses the input command, and dispatches actions. Because the shell would otherwise keep waiting for input indefinitely, this notebook passes a default command queue (`cmd_queue`): a list of input lines that the shell processes in order. Because the last line is the command `/quit`, the shell exits once the queue is exhausted. To chat with the bot interactively, remove the `cmd_queue` argument.

***
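
The command-queue mechanism comes from Python's standard `cmd` module: `cmdloop()` consumes lines from `cmdqueue` before it falls back to reading input, and it stops when a `do_*` method returns a true value. The toy shell below (a hypothetical `EchoShell` that echoes text instead of calling an endpoint) uses the same `/` prefix convention as `JumpStartOpenChatKitShell`, so you can try the loop without a deployed model.

```python
import cmd
import io


class EchoShell(cmd.Cmd):
    """Hypothetical toy shell that echoes chat lines instead of querying an endpoint."""

    prompt = ">>> "

    def __init__(self, cmd_queue, stdout):
        super().__init__(stdout=stdout)
        self.cmdqueue = list(cmd_queue)  # lines consumed before cmdloop() asks for input

    def precmd(self, line):
        # Lines starting with "/" are commands; everything else is chat text for do_say.
        return line[1:] if line.startswith("/") else "say " + line

    def do_say(self, arg):
        self.stdout.write(f"echo: {arg}\n")

    def do_quit(self, arg):
        return True  # a true return value ends cmdloop()


out = io.StringIO()
EchoShell(["Hello!", "How are you?", "/quit"], stdout=out).cmdloop()
print(out.getvalue())
```

Because the queue ends with `/quit`, `cmdloop()` returns without ever prompting for input; dropping `/quit` from the list would leave the loop waiting on `input()` after the queued lines run out.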

In [None]:
cmd_queue = [
    "Hello!",
    "Make a markdown table of national parks with their location and date established.",
    "/quit",
]
JumpStartOpenChatKitShell(
    endpoint_name=endpoint_name,
    cmd_queue=cmd_queue,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.6,
    top_k=40,
).cmdloop()

### 6. Clean up the endpoint

In [None]:
# Delete the SageMaker endpoint
model_predictor.delete_model()
model_predictor.delete_endpoint()