# Chain SageMaker Endpoints and Wikipedia

# Set Up the Environment

Install the AI21 Labs PyPI package, LangChain, and the `wikipedia` package for retrieval-augmented generation.

In [None]:
!pip install langchain>=0.0.123 --quiet
!pip install "ai21[SM]" --quiet
!pip install boto3 --quiet
!pip install nest_asyncio --quiet
!pip install --upgrade langchain --quiet
!pip -q install wikipedia --quiet

In [4]:
import os
from typing import Dict
import json
from typing import Optional, List, Mapping, Any, Dict
import boto3
import sagemaker

import ai21

from langchain.docstore.document import Document
from langchain.document_loaders import S3DirectoryLoader, WebBaseLoader
from langchain.llms.base import LLM
from langchain.chains.question_answering import load_qa_chain
from langchain import PromptTemplate
from langchain.agents import initialize_agent, Tool
from langchain.tools import BaseTool
from langchain.utilities import WikipediaAPIWrapper
import nest_asyncio

# Add Wikipedia to LangChain

In [38]:
# Wikipedia API wrapper; its `run` method is used below both as an agent tool and to fetch context for the QA chain.
wikipedia = WikipediaAPIWrapper()

In [None]:
# Quick sanity check of the Wikipedia wrapper on a sample question.
sample_question = 'when is SVB founded'
wikipedia.run(sample_question)

In [40]:
# Expose the Wikipedia lookup to LangChain agents as a named tool.
wikipedia_tool = Tool(
    name="wikipedia",
    description="Useful for when you need to look up a topic, country or person on wikipedia",
    func=wikipedia.run,
)


In [5]:
# Apply nest_asyncio's patch to asyncio, presumably so async code can run inside the
# notebook kernel's already-running event loop.
# NOTE(review): no visible cell awaits anything — confirm this is still required.
nest_asyncio.apply()

# Deploy LLM Endpoint to SageMaker

We use the Jurassic Jumbo Instruct model from AWS Marketplace because it was the best instruction-following model available on AWS when this demo was created.

## Set Up SageMaker Variables

Create the SageMaker session and look up the Jurassic model-package ARN for the current AWS region.

In [180]:
# AWS Marketplace model-package ARNs for Jurassic Jumbo Instruct, keyed by region.
model_package_map = {
    "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/j2-jumbo-instruct-v1-0-20-8b2be365d1883a15b7d78da7217cdeab",
}
region = boto3.Session().region_name
# Fail with an actionable message instead of a bare KeyError when the notebook
# runs in a region (or with no region configured) that has no mapped package.
if region not in model_package_map:
    raise ValueError(
        f"No Jurassic Jumbo Instruct model package configured for region {region!r}; "
        f"supported regions: {sorted(model_package_map)}. "
        "Add the region's model-package ARN to model_package_map."
    )
model_package_arn = model_package_map[region]
role = sagemaker.get_execution_role()
sagemaker_session = sagemaker.Session()
runtime_sm_client = boto3.client("runtime.sagemaker")

## Deploy Jurassic to Endpoint

Deploy the Jurassic model to a SageMaker endpoint. **Please note: this instance (`ml.p4d.24xlarge`) is LARGE — use caution and make sure to clean up when you are done (see Cleanup below).**

In [None]:
# Deployment settings. The LLM wrapper and cleanup cells below address the
# endpoint by this same name.
model_name = "j2-jumbo-instruct"
content_type = "application/json"
real_time_inference_instance_type = "ml.p4d.24xlarge"

# Wrap the Marketplace model package so it can be deployed in this account.
model = sagemaker.ModelPackage(
    role=role,
    model_package_arn=model_package_arn,
    sagemaker_session=sagemaker_session,
)

# One real-time instance; generous timeouts give the large model time to
# download and pass its startup health check.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type=real_time_inference_instance_type,
    endpoint_name=model_name,
    model_data_download_timeout=3600,
    container_startup_health_check_timeout=600,
)

# Create a Custom LLM Handler for SageMaker Model

We can now use LangChain's `LLM` base class to create a wrapper around the Jurassic model endpoint.

In [7]:
class SageMakerLLM(LLM):
    """LangChain ``LLM`` wrapper around a Jurassic Jumbo Instruct SageMaker endpoint.

    Each call sends one completion request through ``ai21.Completion.execute``
    and returns the text of the first completion.

    Attributes:
        endpoint_name: SageMaker endpoint to invoke; defaults to the name used
            by the deployment cell ("j2-jumbo-instruct").
        max_tokens: Maximum tokens to generate (sent as ``maxTokens``).
        temperature: Sampling temperature (sent as ``temperature``).
    """

    endpoint_name: str = "j2-jumbo-instruct"
    max_tokens: int = 500
    temperature: float = 0

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain reports for this LLM type."""
        return "jurassic-jumbo-instruct"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters that distinguish one configured instance from another."""
        return {
            "endpoint_name": self.endpoint_name,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """Run a single completion against the endpoint.

        Args:
            prompt: Text prompt sent to the model.
            stop: Optional stop sequences, forwarded as ``stopSequences``.
            run_manager: Accepted for compatibility with LangChain versions that
                pass a callback manager; unused here.
            **kwargs: Extra keyword arguments from LangChain; ignored.

        Returns:
            The text of the first completion in the endpoint's response.
        """
        response = ai21.Completion.execute(
            sm_endpoint=self.endpoint_name,
            prompt=prompt,
            maxTokens=self.max_tokens,
            temperature=self.temperature,
            numResults=1,  # only the first completion is read below
            stopSequences=stop,
        )
        return response['completions'][0]['data']['text']

In [26]:
# Smoke-test the wrapper against the deployed endpoint.
llm = SageMakerLLM()
llm("How many US states are there?\n")

'There are 50 US states.'

# Answer a Question via a `chain`

In [29]:
# Question answered by the QA chain below (trailing newline kept from the original).
query = "When was SVB founded?\n"

In [44]:
# Fetch Wikipedia context for the question and wrap it as a single LangChain Document.
wiki_context = wikipedia.run(query)
docs = [Document(page_content=wiki_context)]

In [45]:
# Prompt that places the retrieved context ahead of the question.
prompt_template = (
    "Use the following pieces of context to answer the question at the end.\n"
    "\n"
    "{context}\n"
    "\n"
    "Question: {question}\n"
    "Answer:"
)
PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)

# Question-answering chain backed by the SageMaker-hosted Jurassic model.
chain = load_qa_chain(llm=SageMakerLLM(), prompt=PROMPT)

first_doc = docs[0]
chain(
    {"input_documents": [first_doc], "question": query},
    return_only_outputs=True,
)

{'output_text': ' 1983'}

# Cleanup

In [None]:
# Tear down the endpoint and model deployed above. The endpoint name must match
# the deployed name exactly ("j2-jumbo-instruct"); otherwise the deletes target a
# non-existent endpoint and the large instance is left running.
j2 = sagemaker.predictor.Predictor(endpoint_name="j2-jumbo-instruct")
j2.delete_model()
j2.delete_endpoint(delete_endpoint_config=True)