# SageMaker + Astra DB, integration example

Use an LLM and an Embedding model from Amazon SageMaker and a Vector Store from [Astra DB](https://astra.datastax.com) to run a simple RAG-based application.

In this notebook, you will:
- either deploy an Embedding model and an LLM, or connect to existing ones in SageMaker, and see them in action;
- connect to Astra DB and create a Vector Store in it;
- populate it with example "pretend entomology" information;
- run an AI-powered entomology assistant that helps identify insects from field observations.

> Note: this notebook is designed to run within Amazon SageMaker Studio. See [this page](https://awesome-astra.github.io/docs/pages/aiml/aws/aws-sagemaker/) for more information and references.

## General setup

_Note: `pip` may report some dependency-resolution errors in the output here. These can be ignored: the rest of this notebook will work just fine._

In [6]:
!pip install --upgrade pip
!pip install --quiet \
    "sagemaker==2.193.0" \
    "langchain==0.0.317" \
    "cassio>=0.1.3" \
    "datasets==2.14.5"


In [7]:
from typing import Dict, List, Optional, Any
import json


import boto3
import cassio

from datasets import load_dataset

from sagemaker.session import Session
from sagemaker import image_uris, model_uris
from sagemaker.predictor import Predictor
from sagemaker.model import Model
from sagemaker.utils import name_from_base
from sagemaker.base_serializers import JSONSerializer
from sagemaker.base_deserializers import JSONDeserializer

from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.vectorstores import Cassandra

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml


In [8]:
boto3_sm_client = boto3.client('runtime.sagemaker')
region_name = boto3.Session().region_name

sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()



#### Prepare a function that specializes the default SageMaker "Predictor": it will be later supplied when creating the `Model` objects.

In some cases a `Model` can be used out of the box, but for the models in this notebook you need to specify JSON serialization and deserialization for the interaction with the model endpoints.

In [9]:
def my_json_predictor(*pargs, **kwargs):
    return Predictor(*pargs, **kwargs,
                     serializer=JSONSerializer(),
                     deserializer=JSONDeserializer(),
    )

## Embedding model, setup

Here you can choose between a model already deployed through the UI and a programmatic deployment through the SageMaker SDK.

In [10]:
emb_endpoint_supplied = False

emb_endpoint_name = input("Enter the *embedding model* endpoint name if already deployed (leave empty if deploying with SDK):").strip()

if emb_endpoint_name == "":
    print(f"\n{'*' * 101}")
    print("*** INFO: the embedding model will be deployed programmatically, as no endpoint name was provided. **")
    print("***       Re-run this cell and supply the endpoint name if this is incorrect.                      **")
    print(f"{'*' * 101}")
else:
    emb_endpoint_supplied = True

Enter the *embedding model* endpoint name if already deployed (leave empty if deploying with SDK): jumpstart-dft-yo-hf-textembedding-gpt-j-6b


The following cells will go through the steps required for programmatic deployment of a JumpStart model through the SageMaker SDK.

Note that, if the embedding model endpoint name has already been provided, these cells will only print a message.

In [11]:
if not emb_endpoint_supplied:
    emb_model_id = "huggingface-textembedding-gpt-j-6b"
    # Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for the model IDs
    emb_endpoint_name = name_from_base(emb_model_id)
    print(f"[INFO] Embedding endpoint name = '{emb_endpoint_name}'")
    emb_instance_type = "ml.g5.24xlarge"
    emb_model_version = "*"
    emb_model_env = {}

    emb_deploy_image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="inference",
        model_id=emb_model_id,
        model_version=emb_model_version,
        instance_type=emb_instance_type,
    )
else:
    print("(nothing to do in this case)")

(nothing to do in this case)


In [12]:
if not emb_endpoint_supplied:
    emb_model_uri = model_uris.retrieve(
        model_id=emb_model_id,
        model_version=emb_model_version,
        model_scope="inference",
    )
    emb_model_inference = Model(
        image_uri=emb_deploy_image_uri,
        model_data=emb_model_uri,
        role=aws_role,
        predictor_cls=my_json_predictor,
        name=emb_endpoint_name,
        env=emb_model_env,
    )
else:
    print("(nothing to do in this case)")

(nothing to do in this case)


#### This is the actual deploy step.

> _Note: this cell may take even **ten minutes** to complete. You may check the SageMaker Studio 'endpoints' tab while this is running._

In [13]:
if not emb_endpoint_supplied:
    print("*** About to start the embedding model deploy ...\n")
    emb_predictor = emb_model_inference.deploy(
        initial_instance_count=1,
        instance_type=emb_instance_type,
        predictor_cls=my_json_predictor,
        endpoint_name=emb_endpoint_name,
    )
    print("\n*** Embedding model deploy completed.")
else:
    print("(nothing to do in this case)")

(nothing to do in this case)


## Embedding model, LangChain setup

To be able to work with the shape of the input and output specific to _this_ embedding model, we need to create and supply a suitable `EmbeddingsContentHandler` when instantiating the LangChain abstraction for the SageMaker embedding:

In [14]:
class SageMakerGPTJ6BContentHandler(EmbeddingsContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:
        input_encoded = json.dumps({
            "text_inputs": inputs,
            **model_kwargs,
        }).encode("utf-8")
        return input_encoded

    def transform_output(self, output: bytes) -> List[List[float]]:
        """
        `output` is actually a botocore.response.StreamingBody object in our case
        """
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json["embedding"]


emb_content_handler = SageMakerGPTJ6BContentHandler()

embeddings = SagemakerEndpointEmbeddings(
    endpoint_name=emb_endpoint_name,
    region_name=region_name,
    content_handler=emb_content_handler,
)
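
Before pointing the handler at a live endpoint, it can help to dry-run the same two transformations locally. The sketch below repeats the logic of `transform_input` and `transform_output` as plain functions and feeds them a made-up response. The function names are hypothetical, `io.BytesIO` stands in for the `StreamingBody` object (the code only relies on `.read()`), and the two-dimensional vectors are invented for illustration.

```python
import io
import json
from typing import Dict, List


def encode_embedding_request(inputs: List[str], model_kwargs: Dict) -> bytes:
    # Same payload shape as in `transform_input` above
    return json.dumps({"text_inputs": inputs, **model_kwargs}).encode("utf-8")


def decode_embedding_response(output) -> List[List[float]]:
    # Same parsing as in `transform_output` above; `output` only needs `.read()`
    return json.loads(output.read().decode("utf-8"))["embedding"]


body = encode_embedding_request(["first text", "second text"], {})

# Invented response body, wrapped in a file-like object
fake_output = io.BytesIO(
    json.dumps({"embedding": [[0.6, 0.8], [1.0, 0.0]]}).encode("utf-8")
)
vectors_demo = decode_embedding_response(fake_output)

print(json.loads(body))
print(vectors_demo)
```

A check of this kind makes it easier to tell a payload-shape problem apart from an endpoint or permission problem.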

### Embedding model, test invocation through LangChain

As a simple test, we check that the model returns vectors normalized to unit norm:

In [16]:
vector1 = embeddings.embed_query("Hello, SageMaker")
vectors = embeddings.embed_documents(["Can you embed multiple sentences at once?", "Sure, you can."])

print(f"Norm of 'vector1': {sum(x*x for x in vector1):.4f}")

print("Norms of 'vectors'")
for i, v in enumerate(vectors):
    print(f"    [{i}] norm = {sum(x*x for x in v):.4f}")

Norm of 'vector1': 1.0000
Norms of 'vectors'
    [0] norm = 1.0000
    [1] norm = 1.0000
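
Strictly speaking, the cell above prints the sum of squares, i.e. the squared norm; the two coincide only because the value is 1. The sketch below spells out the explicit check on made-up vectors (not real embeddings) and also hints at why unit norm is convenient: for unit vectors, the dot product already equals the cosine similarity.

```python
import math
from typing import List


def euclidean_norm(v: List[float]) -> float:
    # Square root of the sum of squares
    return math.sqrt(sum(x * x for x in v))


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def cosine_similarity(a: List[float], b: List[float]) -> float:
    return dot(a, b) / (euclidean_norm(a) * euclidean_norm(b))


# Invented vectors: `u` and `w` have unit norm, `t` does not
u = [0.6, 0.8]
w = [1.0, 0.0]
t = [3.0, 4.0]

print(f"norm(u) = {euclidean_norm(u):.4f}")
print(f"norm(t) = {euclidean_norm(t):.4f}, sum of squares = {dot(t, t):.4f}")
print(f"dot(u, w) = {dot(u, w):.4f}, cosine(u, w) = {cosine_similarity(u, w):.4f}")
```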


## LLM, setup

Here you can choose between a model already deployed through the UI and a programmatic deployment through the SageMaker SDK.

In [17]:
llm_endpoint_supplied = False

llm_endpoint_name = input("Enter the *LLM* endpoint name if already deployed (leave empty if deploying with SDK):").strip()

if llm_endpoint_name == "":
    print(f"\n{'*' * 89}")
    print("*** INFO: the LLM will be deployed programmatically, as no endpoint name was provided. **")
    print("***       Re-run this cell and supply the endpoint name if this is incorrect.          **")
    print(f"{'*' * 89}")
else:
    llm_endpoint_supplied = True

Enter the *LLM* endpoint name if already deployed (leave empty if deploying with SDK): jumpstart-dft-my2-meta-textgeneration-llama-2-70b-f


The following cells work similarly to the embedding model deployment seen earlier:

In [18]:
if not llm_endpoint_supplied:
    llm_model_id = "meta-textgeneration-llama-2-70b-f"
    # Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for the model IDs
    llm_endpoint_name = name_from_base(llm_model_id)
    print(f"[INFO] LLM endpoint name = '{llm_endpoint_name}'")
    llm_instance_type = "ml.g5.48xlarge"
    llm_model_version = "*"
    llm_model_env = {}
    # llm_model_env = {"SAGEMAKER_MODEL_SERVER_WORKERS": "1", "TS_DEFAULT_WORKERS_PER_MODEL": "1"} # TODO: check if relevant

    llm_deploy_image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="inference",
        model_id=llm_model_id,
        model_version=llm_model_version,
        instance_type=llm_instance_type,
    )
else:
    print("(nothing to do in this case)")

(nothing to do in this case)


In [19]:
if not llm_endpoint_supplied:
    llm_model_uri = model_uris.retrieve(
        model_id=llm_model_id,
        model_version=llm_model_version,
        model_scope="inference",
    )

    llm_model_inference = Model(
        image_uri=llm_deploy_image_uri,
        model_data=llm_model_uri,
        role=aws_role,
        predictor_cls=my_json_predictor,
        name=llm_endpoint_name,
        env=llm_model_env,
    )
else:
    print("(nothing to do in this case)")

(nothing to do in this case)


#### This is the actual deploy step.

> _Note: this cell may take even **twenty minutes or so** to complete. You may check the SageMaker Studio 'endpoints' tab while this is running._

In [20]:
if not llm_endpoint_supplied:
    print("*** About to start the LLM deploy ...\n")
    llm_predictor = llm_model_inference.deploy(
        initial_instance_count=1,
        instance_type=llm_instance_type,
        predictor_cls=my_json_predictor,
        endpoint_name=llm_endpoint_name,
    )
    print("\n*** LLM deploy completed.")
else:
    print("(nothing to do in this case)")

(nothing to do in this case)


## LLM, LangChain setup

As was done for the embedding model, we need to provide a "Content Handler" tailored to the specific signature of this LLM.

While we are at it, we allow for custom "system role" instructions to be coded in the actual payload to the LLM:

In [21]:
DEFAULT_SYSTEM_ROLE_INSTRUCTIONS = "You are a helpful chat assistant."

class Llama2_70BChatContentHandler(LLMContentHandler):

    content_type = "application/json"
    accepts = "application/json"
    system_role_instructions: str

    def __init__(self, *pargs, system_role_instructions: str = DEFAULT_SYSTEM_ROLE_INSTRUCTIONS, **kwargs):
        self.system_role_instructions = system_role_instructions
        LLMContentHandler.__init__(self, *pargs, **kwargs)
    
    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_dict = {
            "inputs": [
                [
                    {"role": "system", "content": self.system_role_instructions},
                    {"role": "user", "content": prompt}
                ]
            ],
            "parameters": model_kwargs,
        }
        return json.dumps(input_dict).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        return response_json[0]["generation"]["content"].strip()


llm_content_handler = Llama2_70BChatContentHandler(
    system_role_instructions=(
        "You are a helpful biology chat assistant; your task is to fulfill the incoming requests "
        "truthfully and with scientific accuracy, also cutting the clutter and pleasantries."
    ),
)
    
llm = SagemakerEndpoint(
    endpoint_name=llm_endpoint_name,
    region_name=region_name,
    model_kwargs={"max_new_tokens": 1024, "top_p": 0.4, "temperature": 0.8},
    content_handler=llm_content_handler,
    # model-specific (Llama requires acceptance of the EULA)
    endpoint_kwargs={
        'CustomAttributes': 'accept_eula=true',
    },
)

_A note about the `endpoint_kwargs` parameter._

As mentioned earlier, for this model each LLM call must carry a special header to signal acceptance of the EULA. At the LangChain level, this is accomplished by passing the `endpoint_kwargs` parameter when creating the `SagemakerEndpoint` instance. For reference, you can check how this parameter is used within the LangChain code ([check the code](https://github.com/langchain-ai/langchain/blob/7db6aabf65e70811e40ee6f2e1ba8e0425ba81c9/libs/langchain/langchain/llms/sagemaker_endpoint.py#L359C23-L359C39)). Essentially, the EULA acceptance flag is passed down to the underlying `boto3` library, whose `invoke_endpoint` method accepts the `CustomAttributes` parameter ([check the docs](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker-runtime/client/invoke_endpoint.html#invoke-endpoint)).
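
To see what goes into the request body itself, the sketch below builds the payload and parses a response by hand, following the same shapes used in `Llama2_70BChatContentHandler`. The helper names `build_chat_body` and `extract_answer` are hypothetical and the response text is made up. Note that the EULA flag is not part of this body: it travels separately, as the `CustomAttributes` parameter.

```python
import json


def build_chat_body(system: str, user: str, parameters: dict) -> bytes:
    # One dialog, made of a "system" turn followed by a "user" turn
    payload = {
        "inputs": [
            [
                {"role": "system", "content": system},
                {"role": "user", "content": user},
            ]
        ],
        "parameters": parameters,
    }
    return json.dumps(payload).encode("utf-8")


def extract_answer(raw: bytes) -> str:
    # The response is a list with one item per dialog in the request
    return json.loads(raw.decode("utf-8"))[0]["generation"]["content"].strip()


chat_body = build_chat_body(
    "You are a helpful chat assistant.",
    "How many legs do insects have?",
    {"max_new_tokens": 64},
)

# Invented response, shaped like the ones this model returns
fake_raw = json.dumps(
    [{"generation": {"role": "assistant", "content": " Six."}}]
).encode("utf-8")
answer = extract_answer(fake_raw)

print(json.dumps(json.loads(chat_body), indent=4))
print(answer)
```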

### LLM, test invocation through LangChain

In [22]:
print(llm("Would a wolf eat a mouse, or would the mouse eat the wolf? Keep the answer as short as possible."))

Wolf eats mouse.


## Vector store on Astra DB

In this section, first initialize `cassio` globally so that it establishes a connection to your database (your connection secrets must be provided).

In [23]:
from getpass import getpass

ASTRA_DB_ID = input("Enter your Astra DB ID ('0123abcd-'):")
ASTRA_DB_APPLICATION_TOKEN = getpass("Enter your Astra DB Token ('AstraCS:...'):")
ASTRA_DB_KEYSPACE = input("Enter your keyspace name (optional, default keyspace used if not provided):")

In [24]:
cassio.init(
    token=ASTRA_DB_APPLICATION_TOKEN,
    database_id=ASTRA_DB_ID,
    keyspace=ASTRA_DB_KEYSPACE if ASTRA_DB_KEYSPACE else None,
)

Now a vector store is created, ready for use:

In [25]:
astra_v_store = Cassandra(
    session=None,   # <-- meaning: fall back to globally-set "cassio.init()" connection
    keyspace=None,  # <-- meaning: fall back to globally-set "cassio.init()" connection
    table_name="sagemaker_demo_v_store",
    embedding=embeddings,
)

A small example dataset is loaded through HuggingFace. You can print a sample item to get an idea of its structure.

In [26]:
sample_dataset = load_dataset("datastax/entomology")["train"]

def _shorten(dct): return {k: v if len(v) < 40 else v[:40]+"..." for k, v in dct.items()}

print(f"Loaded {len(sample_dataset)} entries")
print("Example entry:")
print("\n".join(
    f"    {l}" for l in json.dumps(_shorten(sample_dataset[19]), indent=4).split("\n")
))

Loaded 32 entries
Example entry:
    {
        "id": "psilodermis_spectrus",
        "name": "Psilodermis spectrus",
        "order": "Coleoptera",
        "description": "Beetle with transparent wings and a spec..."
    }


The dataset is prepared for insertion in the vector store:

_(Note: the IDs are computed deterministically, to avoid accidentally creating duplicates if the `add_texts` cell is run more than once.)_

In [27]:
texts = [entry["description"] for entry in sample_dataset]
metadatas = [
    {
        "name": entry["name"],
        "order": entry["order"],
    }
    for entry in sample_dataset
]
ids = [entry["name"].lower().replace(" ", "_") for entry in sample_dataset]

print(f"Example from `texts`:\n    \"{texts[19][:40]}...\"")
print(f"Example from `metadatas`:\n    {metadatas[19]}")
print(f"Example from `ids`:\n    \"{ids[19]}\"")

Example from `texts`:
    "Beetle with transparent wings and a spec..."
Example from `metadatas`:
    {'name': 'Psilodermis spectrus', 'order': 'Coleoptera'}
Example from `ids`:
    "psilodermis_spectrus"


This is where the writes take place (and the embedding vectors are calculated for each item in `texts`):

In [28]:
inserted_ids = astra_v_store.add_texts(texts=texts, metadatas=metadatas, ids=ids)

print(f"Inserted: {', '.join(inserted_ids)[:80]}... ({len(inserted_ids)} items)")

Inserted: glitterus_aurorae, luminastrum_nocturnalis, prismaticus_geminus, cryptocicada_my... (32 items)
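
The benefit of the deterministic IDs computed earlier can be illustrated with a plain dictionary standing in for the vector store. This is an assumption made for the sake of a runnable example: the sketch only relies on a write with an existing ID overwriting the stored row instead of adding a new one. The two species names come from the dataset, while the descriptions are shortened placeholders and `toy_store` and `upsert` are hypothetical names.

```python
def make_id(name: str) -> str:
    # Same rule as in the notebook: lowercase, spaces replaced by underscores
    return name.lower().replace(" ", "_")


def upsert(store: dict, entries: list) -> list:
    # Writing with an existing key overwrites the row rather than duplicating it
    written = []
    for entry in entries:
        row_id = make_id(entry["name"])
        store[row_id] = entry["description"]
        written.append(row_id)
    return written


toy_store = {}
toy_entries = [
    {"name": "Psilodermis spectrus", "description": "Beetle with transparent wings..."},
    {"name": "Hexaplectra azurea", "description": "Bright blue body..."},
]

first_run = upsert(toy_store, toy_entries)
second_run = upsert(toy_store, toy_entries)  # simulates re-running the cell

print(first_run)
print(f"Rows after two runs: {len(toy_store)}")
```

With randomly generated IDs, the second run would instead have added two more rows with the same content.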


## Set up the full pipeline

### Retrieval part

Package the search part of the flow in a handy function:

In [29]:
def find_similar_entries(description, k=3, order=None):
    if order:
        md = {"order": order}
    else:
        md = {}
    documents = astra_v_store.similarity_search(description, k=k, filter=md)
    return documents

In [30]:
print(find_similar_entries("Long wings with brown spots, flies erratically, thin legs", k=2, order="Odonata"))

[Document(page_content='Bright blue body, wings edged with a golden hue. Often seen hovering above ponds and lakes during springtime. Characteristic double-wing beat.', metadata={'name': 'Hexaplectra azurea', 'order': 'Odonata'}), Document(page_content='Ruby-red dragonfly with black-tipped wings. Fast flyer, and can often be seen darting between flowers on sunny days.', metadata={'name': 'Rubroptera rosetta', 'order': 'Odonata'})]
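
Conceptually, a filtered similarity search restricts the candidates to the rows whose metadata match and then ranks them by vector similarity. The brute-force sketch below illustrates that idea on a handful of invented two-dimensional vectors. It is not how the vector store executes the query internally, and all names and numbers in it are made up for illustration.

```python
import math
from typing import Dict, List, Optional, Tuple

# (id, vector, metadata) triples; everything here is invented
Row = Tuple[str, List[float], Dict[str, str]]

toy_rows: List[Row] = [
    ("species_a", [1.0, 0.0], {"order": "Odonata"}),
    ("species_b", [0.8, 0.6], {"order": "Odonata"}),
    ("species_c", [0.9, 0.1], {"order": "Coleoptera"}),
    ("species_d", [0.0, 1.0], {"order": "Odonata"}),
]


def cosine(a: List[float], b: List[float]) -> float:
    dot_product = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot_product / (norm_a * norm_b)


def toy_search(query: List[float], k: int = 3, order: Optional[str] = None) -> List[str]:
    # Step 1: keep only the rows matching the metadata filter (if any)
    candidates = [row for row in toy_rows if order is None or row[2]["order"] == order]
    # Step 2: rank the survivors by similarity to the query, best first
    ranked = sorted(candidates, key=lambda row: cosine(query, row[1]), reverse=True)
    return [row[0] for row in ranked[:k]]


toy_query = [1.0, 0.1]
print(toy_search(toy_query, k=2))
print(toy_search(toy_query, k=2, order="Odonata"))
```

In this toy data the closest row overall belongs to a different order, so adding the filter changes which entries are returned.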


### Generation part

In [31]:
PROMPT_TEMPLATE = """You are an expert entomologist tasked with helping specimen identification on the field.
You are given relevant excerpts from an invertebrate textbook along with my field observation.
Your task is to compare my observation with the textbook excerpts and come to an identification, explaining why you came to that conclusion and giving the degree of certainty.
Only use the information provided in the user observation to come to your conclusion!
Be sure to provide, in your verdict, the species' Order together with the full Latin name.

USER OBSERVATION: {description}

TEXTBOOK CANDIDATE MATCHES:
{candidates}

YOUR EXPLAINED IDENTIFICATION:"""

The above is the prompt that will be used in the full RAG pipeline. Another handy tool is a utility function to turn the returned items from the vector store into a single-string description, ready for insertion in the prompt template:

In [32]:
def describe_candidates(matches):
    return "\n".join([
        f"Candidate species {i+1}: '{doc.metadata['name']}' (order: {doc.metadata['order']})\nDescription: {doc.page_content}\n"
        for i, doc in enumerate(matches)
    ])

In [33]:
print(describe_candidates(find_similar_entries("Long wings with brown spots, flies erratically, thin legs", k=2, order="Odonata")))

Candidate species 1: 'Hexaplectra azurea' (order: Odonata)
Description: Bright blue body, wings edged with a golden hue. Often seen hovering above ponds and lakes during springtime. Characteristic double-wing beat.

Candidate species 2: 'Rubroptera rosetta' (order: Odonata)
Description: Ruby-red dragonfly with black-tipped wings. Fast flyer, and can often be seen darting between flowers on sunny days.



This is the main function implementing the complete RAG pipeline: search, construction of the prompt, and invocation of the LLM to get the answer.

In [34]:
def identify_and_suggest(description, order=None):
    matches = find_similar_entries(description, k=3, order=order)
    candidates_text = describe_candidates(matches)
    prompt = PROMPT_TEMPLATE.format(
        description=description,
        candidates=candidates_text,
    )
    return llm(prompt)
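
When tuning the prompt, it is handy to look at the fully assembled text without spending an LLM call. The sketch below follows the same steps as `identify_and_suggest` but stops before the model invocation. It is a stand-alone approximation: `SimpleNamespace` objects stand in for the retrieved documents, `DEMO_TEMPLATE` is a shortened placeholder for the real template, and the helper names are hypothetical.

```python
from types import SimpleNamespace

# Shortened stand-in for the full prompt template
DEMO_TEMPLATE = """USER OBSERVATION: {description}

TEXTBOOK CANDIDATE MATCHES:
{candidates}

YOUR EXPLAINED IDENTIFICATION:"""


def describe(matches) -> str:
    # Same formatting as `describe_candidates`
    return "\n".join(
        f"Candidate species {i+1}: '{doc.metadata['name']}' "
        f"(order: {doc.metadata['order']})\nDescription: {doc.page_content}\n"
        for i, doc in enumerate(matches)
    )


def build_prompt(description: str, matches) -> str:
    return DEMO_TEMPLATE.format(
        description=description,
        candidates=describe(matches),
    )


# Stand-in for what the vector store would return
fake_matches = [
    SimpleNamespace(
        page_content="Ruby-red dragonfly with black-tipped wings.",
        metadata={"name": "Rubroptera rosetta", "order": "Odonata"},
    ),
]

demo_prompt = build_prompt("A red dragonfly near a pond.", fake_matches)
print(demo_prompt)
```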

### Putting it all to the test

In [35]:
print(identify_and_suggest("A large butterfly with elongated wing tips and a yellow spot in the middle of each wing."))

Based on your observation of a large butterfly with elongated wing tips and a yellow spot in the middle of each wing, I would identify the species as 'Diamantis glittoris' (order: Lepidoptera).

The description of 'Diamantis glittoris' matches your observation in several ways. Firstly, the butterfly is described as having sparkling wings that shine like diamonds in sunlight, which aligns with your observation of a yellow spot in the middle of each wing. Additionally, the species is known to have a slender body and short, clubbed antennae, which is consistent with your description of a large butterfly with elongated wing tips.

Furthermore, 'Diamantis glittoris' is known to inhabit flowery gardens and coastal habitats, which could explain why you spotted the butterfly in the field.

My degree of certainty for this identification is moderate, as there could be other species that match your observation. However, based on the information provided, 'Diamantis glittoris' seems to be the most

In [40]:
print(identify_and_suggest("I found an elongated brown but with small wings, dark elitra and sturdy antennae in a meadow."))

Based on your observation of an elongated brown body with small wings, dark elytra, and sturdy antennae, I would identify the insect you found as a member of the order Coleoptera, specifically the species Psilodermis spectrus.

The description of Psilodermis spectrus matches your observation in several ways. Firstly, the beetle has transparent wings, which fits with your observation of small wings. Secondly, the body of the beetle is slender and colorless, which aligns with your description of an elongated brown body. Additionally, the species is known to have faint, shifting colors on its back, which could explain the dark elytra you observed. Finally, the fact that the insect was found in a meadow suggests that it may have been attracted to the vegetation, which is consistent with the species' known habitat of haunted forests and eerie swamps.

While the other two candidate species, Anthroptila punctatus and Zephyrella albis, do share some similarities with your observation, they do 

In [37]:
print(identify_and_suggest("What looked like a leaf was in fact moving! It startled me greatly. But I'm not sure it's an insect, I did not see antennae. What was it?"))

Dear User,

Thank you for your observation. Based on your description of a moving, leaf-like entity without antennae, I would identify the organism you encountered as Neonymphalis radiatus, a species of butterfly in the order Lepidoptera.

The absence of antennae is not unusual for a butterfly, as many species have reduced or absent antennae. The description of the wings as "neon-bright, radiant patterns" matches the characteristic appearance of Neonymphalis radiatus, which has vibrant, iridescent markings on its wings. The robust body and glowing tips on the antennae also align with the description of this species.

While Cryptoclytra mirabilis and Psilodermis spectrus are both intriguing species, they do not match the details of your observation as closely. Cryptoclytra mirabilis has a more elongated body and long, slender antennae, and is typically found in historic buildings and libraries, whereas Psilodermis spectrus has a slender, colorless body and is found in haunted forests an

### The "final app":

The loop below is a simple "app" to repeatedly interact with the entomology assistant:

- Try it with simple observations such as _I found a strange bug in the library, whose appearance was that of an old piece of paper. What was it?_
- Enter an empty input to end the cell.

In [41]:
while True:
    observation = input("\n=============================\nEnter your field observation: ").strip()
    if observation:
        print("-----------------------------")
        result = identify_and_suggest(observation)
        print(f"Result ==> {result}")
    else:
        print("(no input)")
        break
        
print("\n========\nGoodbye.")


Enter your field observation:  I found a strange bug in the library, whose appearance was that of an old piece of paper. What was it?


-----------------------------
Result ==> Based on your observation of the bug's appearance, which resembles an old piece of paper, and its habitat in a library, I would identify it as Cryptoclytra mirabilis (order: Hemiptera).

The description of Cryptoclytra mirabilis matches your observation, with its wings resembling ancient parchment scrolls and its elongated body. The fact that it inhabits old libraries and historic buildings also fits with your finding the bug in a library.

While Phantasma mirus (order: Phasmida) is a good candidate due to its ability to mimic small, dry twigs, its habitat in deciduous forests does not match the library environment. Psilodermis spectrus (order: Coleoptera) is also a good candidate due to its ghostly appearance, but its habitat in haunted forests and eerie swamps does not match the library setting.

Therefore, with a high degree of certainty, I identify the bug you found in the library as Cryptoclytra mirabilis (order: Hemiptera).



Enter your field observation:  


(no input)

Goodbye.


## Appendix: non-LangChain model tests

The code below is not part of the main LangChain-based application, but shows how you can use the SageMaker endpoints at lower abstraction layers than LangChain, namely by calling the boto3 or SageMaker SDK primitives directly. Note that in the latter case, if you have deployed the model in the SageMaker UI, you will have to construct a `Predictor` object manually.

_These non-LangChain idioms are important in themselves, as they open the way to a richer set of possibilities for integrating Astra DB with Amazon SageMaker._

### Embedding model, test invocation through boto3

In [42]:
encoded_body = json.dumps(
    {
        "text_inputs": [
            "Can you invoke a SageMaker embedding model from boto3 directly?",
            "Wait and see..."
        ]
    }
).encode("utf-8")

response = boto3_sm_client.invoke_endpoint(
    EndpointName=emb_endpoint_name,
    Body=encoded_body,
    ContentType='application/json',
    Accept='application/json',
)

response_body = response['Body']
read_body = response_body.read()
response_json = json.loads(read_body.decode())

# This is a list of 2 lists, each made of 4096 floats:
embedding_vectors = response_json['embedding']

print(f"Returned {len(embedding_vectors)} embedding vectors.")
print(f"Each is made of {len(embedding_vectors[0])} float values.")
print(f"  The first one starts with: {str(embedding_vectors[0])[:80]}...")

Returned 2 embedding vectors.
Each is made of 4096 float values.
  The first one starts with: [0.016896691173315048, -1.7813106751418673e-05, -0.007678704336285591, 0.0056925...


### Embedding model, test invocation through SageMaker SDK

In [43]:
if emb_endpoint_supplied:
    emb_predictor = Predictor(
        emb_endpoint_name,
        serializer=JSONSerializer(),
        deserializer=JSONDeserializer(),
    )
else:
    # `emb_predictor` was already created as part of the deploy-from-code procedure
    pass

response_json = emb_predictor.predict(
    {"text_inputs": [
            "Can you show me how to use the SageMaker SDK directly for embeddings?",
            "Let me look at the docs..."
        ]
    }
)

# This is a list of 2 lists, each made of 4096 floats:
embedding_vectors = response_json["embedding"]

print(f"Returned {len(embedding_vectors)} embedding vectors.")
print(f"Each is made of {len(embedding_vectors[0])} float values.")
print(f"  The first one starts with: {str(embedding_vectors[0])[:80]}...")

Returned 2 embedding vectors.
Each is made of 4096 float values.
  The first one starts with: [0.010399903170764446, -0.014787523075938225, 0.005106857977807522, -0.008509762...


### LLM, test invocation through boto3

For this particular model, the exact shape of the input data can be found in [this blog post](https://aws.amazon.com/blogs/machine-learning/llama-2-foundation-models-from-meta-are-now-available-in-amazon-sagemaker-jumpstart/).

In [44]:
encoded_body = json.dumps({
    "inputs":
        [
            [
                {
                    "role": "system",
                    "content": "Always answer with scientific accuracy",
                },
                {
                    "role": "user",
                    "content": "How many legs do spiders have?",
                },
            ],
        ],
    "parameters": {
        "max_new_tokens": 256,
        "top_p": 0.9,
        "temperature": 0.6
    },
}).encode("utf-8")

response = boto3_sm_client.invoke_endpoint(
    EndpointName=llm_endpoint_name,
    Body=encoded_body,
    ContentType='application/json',
    Accept='application/json',
    # This is required for each invocation of this model:
    CustomAttributes='accept_eula=true',
)
response_body = response['Body']
read_body = response_body.read()
response_json = json.loads(read_body.decode())

print("Full response:\n")
print(json.dumps(response_json, indent=4))

Full response:

[
    {
        "generation": {
            "role": "assistant",
            "content": " Spiders have eight legs."
        }
    }
]


### LLM, test invocation through SageMaker SDK

Note how the EULA acceptance is passed in this case ([reference](https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html)).

In [45]:
if llm_endpoint_supplied:
    llm_predictor = Predictor(
        llm_endpoint_name,
        serializer=JSONSerializer(),
        deserializer=JSONDeserializer(),
    )
else:
    # `llm_predictor` was already created as part of the deploy-from-code procedure
    pass


response_json = llm_predictor.predict(
    {
        "inputs":
            [
                [
                    {
                        "role": "system",
                        "content": "Always answer with scientific accuracy",
                    },
                    {
                        "role": "user",
                        "content": "How many up quarks are in a proton?",
                    },
                ],
            ],
        "parameters": {
            "max_new_tokens": 256,
            "top_p": 0.9,
            "temperature": 0.6
        },
    },
    custom_attributes='accept_eula=true',
)

print("Full response:\n")
print(json.dumps(response_json, indent=4))

Full response:

[
    {
        "generation": {
            "role": "assistant",
            "content": " A proton is composed of three quarks: two up quarks and one down quark. Therefore, there are two up quarks in a proton."
        }
    }
]
