# SageMaker JumpStart Foundation Models - OpenChatKit GPT-NeoXT-Chat-Base-20B Chatbot

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

---

---
Welcome to Amazon [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html)! You can use SageMaker JumpStart to solve many machine learning tasks with one click in SageMaker Studio, or through the [SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/overview.html#use-prebuilt-models-with-sagemaker-jumpstart).


In this demo notebook, we demonstrate how to use the SageMaker Python SDK to deploy the [GPT-NeoXT-Chat-Base-20B](https://huggingface.co/togethercomputer/GPT-NeoXT-Chat-Base-20B) model and query the model within an [OpenChatKit](https://github.com/togethercomputer/OpenChatKit) interactive shell. This demonstration provides an open-source Foundation Model chatbot for use within your application.

---

1. [Set up](#1.-Set-up)
2. [Select a pre-trained model](#2.-Select-a-pre-trained-model)
3. [Retrieve artifacts & deploy an endpoint](#3.-Retrieve-Artifacts-&-Deploy-an-Endpoint)
4. [Query endpoint and parse response](#4.-Query-endpoint-and-parse-response)
5. [Use an OpenChatKit shell to interact with your deployed endpoint](#5.-Use-an-OpenChatKit-shell-to-interact-with-your-deployed-endpoint)
6. [Clean up the endpoint](#6.-Clean-up-the-endpoint)

Note: This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel, and on an Amazon SageMaker notebook instance with the conda_python3 kernel.

### 1. Set up
---
Before running the notebook, complete the following setup steps.

---

In [None]:
%pip install --upgrade sagemaker --quiet

#### Permissions

---
To host on Amazon SageMaker, we need to set up and authenticate the use of AWS services. Here, we use the execution role associated with the current notebook as the AWS account role with SageMaker access. 

---

In [None]:
import sagemaker, boto3, json
from sagemaker.session import Session

sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()

### 2. Select a pre-trained model
***
You can continue with the default model, or set `model_id` in the next cell to a different JumpStart model ID. A complete list of SageMaker pre-trained models is available at [SageMaker pre-trained Models](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html#).
***

In [None]:
model_id, model_version = "huggingface-textgeneration2-gpt-neoxt-chat-base-20b-fp16", "*"

### 3. Retrieve Artifacts & Deploy an Endpoint

***

Using SageMaker, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset. We start by retrieving the `instance_type`, `image_uri`, and `model_uri` for the pre-trained model. To host the pre-trained model, we create an instance of [`sagemaker.model.Model`](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html) and deploy it. This may take a few minutes.

***

In [None]:
from sagemaker import image_uris, model_uris, instance_types
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer


endpoint_name = name_from_base(f"jumpstart-example-{model_id}")

# Retrieve the inference instance type for the specified model.
instance_type = instance_types.retrieve_default(
    model_id=model_id, model_version=model_version, scope="inference"
)

# Retrieve the inference docker container uri.
image_uri = image_uris.retrieve(
    region=None,
    framework=None,
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=instance_type,
)

# Retrieve the model uri.
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)

# Create the SageMaker model instance. The inference script is prepacked with the model artifact.
model = Model(
    image_uri=image_uri,
    model_data=model_uri,
    role=aws_role,
    predictor_cls=Predictor,
    name=endpoint_name,
)

# For regions without the default g5 instance, p3 will be default which needs a larger EBS volume.
volume_size = 256 if "p3" in instance_type else None

# Set the serializer/deserializer used to run inference through the sagemaker API.
serializer = JSONSerializer()
deserializer = JSONDeserializer()

# Deploy the Model.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    predictor_cls=Predictor,
    endpoint_name=endpoint_name,
    volume_size=volume_size,
    serializer=serializer,
    deserializer=deserializer,
)

### 4. Query endpoint and parse response

***
This model also supports many advanced parameters while performing inference. They include:

* **max_length:** Model generates text until the output length (which includes the input context length) reaches `max_length`. If specified, it must be a positive integer.
* **max_time:** Maximum time, in seconds, that the computation is allowed to run. Generation still finishes the current pass after the allotted time has elapsed. This setting can help return a response before the endpoint invocation times out.
* **num_return_sequences:** Number of output sequences returned. If specified, it must be a positive integer.
* **num_beams:** Number of beams used in beam search. If specified, it must be an integer greater than or equal to `num_return_sequences`.
* **no_repeat_ngram_size:** Model ensures that no sequence of `no_repeat_ngram_size` words is repeated in the output sequence. If specified, it must be a positive integer greater than 1.
* **temperature:** Controls the randomness in the output. Higher temperature results in an output sequence with low-probability words, and lower temperature results in an output sequence with high-probability words. As `temperature` approaches 0, the result approaches greedy decoding. If specified, it must be a positive float.
* **early_stopping:** If True, text generation is finished when all beam hypotheses reach the end of sentence token. If specified, it must be boolean.
* **do_sample:** If True, sample the next word as per the likelihood. If specified, it must be boolean.
* **top_k:** In each step of text generation, sample from only the `top_k` most likely words. If specified, it must be a positive integer.
* **top_p:** In each step of text generation, sample from the smallest possible set of words whose cumulative probability is at least `top_p`. If specified, it must be a float between 0 and 1.
* **seed:** Fix the randomized state for reproducibility. If specified, it must be an integer.

We may specify any subset of the parameters above when invoking an endpoint. Next, we show an example of how to invoke the endpoint with some of these arguments.
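The constraints listed above are easy to violate when hand-editing a payload. The sketch below is a hypothetical client-side helper, `validate_generation_params`; it is not part of the SageMaker SDK or the JumpStart inference script, and it only re-checks the listed rules locally so mistakes surface before an endpoint call. Accepting integer `temperature` values and treating the `top_p` bounds as inclusive are assumptions.

```python
def validate_generation_params(params: dict) -> list:
    """Return a list of problems found in ``params``; an empty list means none were found.

    Hypothetical helper: keys it does not know about, such as "text_inputs",
    "max_time", or "stopping_criteria", are ignored.
    """
    problems = []

    def is_int(value):
        # bool is a subclass of int in Python, so exclude it explicitly.
        return isinstance(value, int) and not isinstance(value, bool)

    def is_number(value):
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    # Parameters documented as positive integers.
    for name in ("max_length", "num_return_sequences", "num_beams", "top_k"):
        if name in params and not (is_int(params[name]) and params[name] > 0):
            problems.append(f"{name} must be a positive integer")

    # num_beams must be at least num_return_sequences (compared only when both are ints).
    if "num_beams" in params and "num_return_sequences" in params:
        beams, returns = params["num_beams"], params["num_return_sequences"]
        if is_int(beams) and is_int(returns) and beams < returns:
            problems.append("num_beams must be >= num_return_sequences")

    if "no_repeat_ngram_size" in params:
        value = params["no_repeat_ngram_size"]
        if not (is_int(value) and value > 1):
            problems.append("no_repeat_ngram_size must be an integer greater than 1")

    # Assumption: ints are accepted for temperature, although the list says "positive float".
    if "temperature" in params and not (is_number(params["temperature"]) and params["temperature"] > 0):
        problems.append("temperature must be a positive number")

    # Assumption: the bounds for top_p are inclusive.
    if "top_p" in params and not (is_number(params["top_p"]) and 0 <= params["top_p"] <= 1):
        problems.append("top_p must be between 0 and 1")

    for name in ("early_stopping", "do_sample"):
        if name in params and not isinstance(params[name], bool):
            problems.append(f"{name} must be a boolean")

    if "seed" in params and not is_int(params["seed"]):
        problems.append("seed must be an integer")

    return problems
```

For example, calling `assert not validate_generation_params(payload)` just before `predictor.predict(payload)` would stop a malformed request locally; the endpoint remains the final authority on which values it accepts.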

***

In [None]:
payload = {
    "text_inputs": "<human>: Tell me the steps to make a pizza\n<bot>:",
    "max_length": 500,
    "max_time": 50,
    "top_k": 50,
    "top_p": 0.95,
    "do_sample": True,
    "stopping_criteria": ["<human>"],
}
response = predictor.predict(payload)
print(response[0][0]["generated_text"])

***
Here, we have provided the payload argument `"stopping_criteria": ["<human>"]`, so the model response ends with the generated word sequence `"<human>"`. The SageMaker JumpStart model script accepts any list of strings as stop words, converts this list to a valid [`stopping_criteria` keyword argument](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationMixin.generate.stopping_criteria) for the transformers generate API, and terminates text generation when the output sequence contains any of the specified stop words. This is useful for two reasons: first, inference time is reduced because the endpoint does not continue to generate undesired text beyond the stop words; second, it prevents the OpenChatKit model from hallucinating additional human and bot turns until other stopping criteria are met.
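The stop-word behavior can be illustrated in plain Python. This is a hypothetical sketch, not the JumpStart inference script (which, as described above, builds a transformers `stopping_criteria` argument); `should_stop` and `trim_at_stop_word` are illustrative names. It assumes, as the shell later in this notebook does, that `generated_text` begins with the prompt. Because the response still ends with the stop word, the second helper shows one way to strip it client-side.

```python
from typing import List


def should_stop(continuation: str, stop_words: List[str]) -> bool:
    """Return True once the newly generated text contains any stop word.

    Only the continuation is checked: the prompt itself already contains "<human>",
    so checking the full text would stop generation immediately.
    """
    return any(stop in continuation for stop in stop_words)


def trim_at_stop_word(generated_text: str, prompt: str, stop_words: List[str]) -> str:
    """Strip the echoed prompt, then cut the reply at the earliest stop word."""
    # Assumption: generated_text starts with the prompt (the shell below slices it the same way).
    reply = generated_text[len(prompt):] if generated_text.startswith(prompt) else generated_text
    cut = len(reply)
    for stop in stop_words:
        index = reply.find(stop)
        if index != -1:
            cut = min(cut, index)
    return reply[:cut].strip()
```

With the pizza payload above, `trim_at_stop_word(response[0][0]["generated_text"], payload["text_inputs"], ["<human>"])` would return only the bot's answer, without the prompt or the trailing `<human>`.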
***

### 5. Use an OpenChatKit shell to interact with your deployed endpoint

***
OpenChatKit provides a command-line shell to interact with the chatbot. In this notebook, we show how to use this shell with your deployed endpoint.
***

In [None]:
%%bash
git clone --branch v0.16 --depth 1 https://github.com/togethercomputer/OpenChatKit.git
touch OpenChatKit/__init__.py
touch OpenChatKit/inference/__init__.py

***
In the following code block, we provide a bare-bones simplification of the inference scripts in this OpenChatKit repository that can interact with our deployed SageMaker endpoint. There are two main components to this:
1. A shell interpreter (`JumpStartOpenChatKitShell`) that allows for iterative inference invocations of the model endpoint, and
2. A conversation object (`Conversation`) that stores previous human/chatbot interactions locally within the interactive shell and appropriately formats past conversations for future inference context.
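To make the second component concrete, here is a hypothetical stand-in, `MiniConversation`, written only for illustration; it is not OpenChatKit's `Conversation` class, whose actual formatting may differ. It mirrors the method names the shell below calls and assumes the `<human>: ...` / `<bot>:` turn format used in the pizza payload earlier in this notebook.

```python
from typing import List, Tuple


class MiniConversation:
    """Hypothetical, simplified stand-in for OpenChatKit's Conversation (illustration only)."""

    def __init__(self, human_id: str = "<human>", bot_id: str = "<bot>"):
        self.human_id = human_id
        self.bot_id = bot_id
        self.turns: List[Tuple[str, str]] = []  # (speaker id, text) pairs in order

    def push_human_turn(self, text: str) -> None:
        self.turns.append((self.human_id, text))

    def push_model_response(self, text: str) -> None:
        # The endpoint output may end with the "<human>" stop word; keep only the text before it.
        cleaned = text.split(self.human_id)[0].strip()
        self.turns.append((self.bot_id, cleaned))

    def get_raw_prompt(self) -> str:
        # Replay every past turn, then open a new bot turn for the model to complete.
        history = "".join(f"{speaker}: {text}\n" for speaker, text in self.turns)
        return f"{history}{self.bot_id}:"

    def get_last_turn(self) -> str:
        speaker, text = self.turns[-1]
        return f"{speaker}: {text}"
```

Because every call to `get_raw_prompt` replays the whole history, each new request carries the full conversation as context, which is how earlier turns influence later answers.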

The `Conversation` object is imported as-is from the OpenChatKit repository. You can view the implementation of this object in the output of the following cell.
***

In [None]:
!pygmentize ./OpenChatKit/inference/conversation.py

***
While the `Conversation` object above is imported from OpenChatKit, the custom shell interpreter is explicitly defined in the following cell as a simplification of the OpenChatKit implementation. We encourage you to explore the previously cloned repository to see how more advanced features, such as token streaming, moderation models, and retrieval-augmented generation, may be used in this context. This notebook focuses on demonstrating a minimal viable chatbot with a SageMaker JumpStart endpoint; you can add complexity as needed from here.
***

In [None]:
import cmd
import json
from typing import List, Optional

from OpenChatKit.inference.conversation import Conversation


class JumpStartOpenChatKitShell(cmd.Cmd):
    intro = (
        "Welcome to the OpenChatKit chatbot shell, modified to use a SageMaker JumpStart endpoint! Type /help or /? to "
        "list commands. For example, type /quit to exit shell.\n"
    )
    prompt = ">>> "
    human_id = "<human>"
    bot_id = "<bot>"

    def __init__(self, predictor: Predictor, cmd_queue: Optional[List[str]] = None, **kwargs):
        super().__init__()
        self.predictor = predictor
        self.payload_kwargs = kwargs
        self.payload_kwargs["stopping_criteria"] = [self.human_id]
        if cmd_queue is not None:
            self.cmdqueue = cmd_queue

    def preloop(self):
        self.conversation = Conversation(self.human_id, self.bot_id)

    def precmd(self, line):
        command = line[1:] if line.startswith("/") else "say " + line
        return command

    def do_say(self, arg):
        self.conversation.push_human_turn(arg)
        prompt = self.conversation.get_raw_prompt()
        payload = {"text_inputs": prompt, **self.payload_kwargs}
        response = self.predictor.predict(payload)
        output = response[0][0]["generated_text"][len(prompt) :]
        self.conversation.push_model_response(output)
        print(self.conversation.get_last_turn())

    def do_reset(self, arg):
        self.conversation = Conversation(self.human_id, self.bot_id)

    def do_hyperparameters(self, arg):
        print(f"Hyperparameters: {self.payload_kwargs}\n")

    def do_quit(self, arg):
        return True

***
We can now launch this shell as a command loop. The loop repeatedly issues a prompt, accepts input, parses the input command, and dispatches actions. Because the shell otherwise keeps waiting for input indefinitely, this notebook passes a default command queue (`cmd_queue`, stored on the shell as `cmdqueue`) of input lines; when the last queued command, `/quit`, is executed, the shell terminates. To interact with the chatbot dynamically, remove the `cmd_queue` argument.
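Python's `cmd.Cmd` consumes lines from `cmdqueue` before it waits for input, and a truthy return from a `do_*` method ends `cmdloop()`, so the routing logic can be exercised without a deployed endpoint. The sketch below is hypothetical: `StubPredictor` and `QueueDemoShell` are illustrative names that fake the nested response shape this notebook indexes, letting you check command handling offline.

```python
import cmd
import io
from typing import List


class StubPredictor:
    """Hypothetical offline stand-in for a SageMaker Predictor; no endpoint is called."""

    def predict(self, payload: dict):
        # Mimic the shape the notebook reads as response[0][0]["generated_text"].
        return [[{"generated_text": payload["text_inputs"] + " (stub reply)"}]]


class QueueDemoShell(cmd.Cmd):
    """Minimal shell with the same precmd routing and cmdqueue handling as above."""

    prompt = ">>> "

    def __init__(self, predictor, cmd_queue: List[str], stdout):
        # completekey=None skips readline setup; a custom stdout lets us capture output.
        super().__init__(completekey=None, stdout=stdout)
        self.predictor = predictor
        self.cmdqueue = list(cmd_queue)  # queued lines are consumed before any input() call

    def precmd(self, line):
        # "/name args" runs do_name; anything else is treated as chat text for do_say.
        return line[1:] if line.startswith("/") else "say " + line

    def do_say(self, arg):
        response = self.predictor.predict({"text_inputs": arg})
        self.stdout.write(response[0][0]["generated_text"] + "\n")

    def do_quit(self, arg):
        return True  # a truthy return value ends cmdloop()


buffer = io.StringIO()
QueueDemoShell(StubPredictor(), ["Hello!", "/quit"], stdout=buffer).cmdloop()
print(buffer.getvalue(), end="")
```

Any lines queued after `/quit` are never executed, which is why the notebook places `/quit` last.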

***

In [None]:
cmd_queue = [
    "Hello!",
    "Make a markdown table of national parks with their location and date established.",
    "/hyperparameters",
    "/quit",
]
JumpStartOpenChatKitShell(
    predictor=predictor,
    cmd_queue=cmd_queue,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.6,
    top_k=40,
).cmdloop()

***
And that's it! As a reminder, you can remove or comment out the `cmd_queue` argument in the cell above to have an interactive dialog with the chatbot.
***

### 6. Clean up the endpoint

In [None]:
predictor.delete_model()
predictor.delete_endpoint()

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)

![This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|jumpstart-foundation-models|text-generation-chatbot-openchatkit.ipynb)
