## Deploy Text Embedding Model (GPT-J 6B FP-16)

#### Imports

In [None]:
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models
from sagemaker.predictor import Predictor
from sagemaker import get_execution_role
from sagemaker.model import Model
from sagemaker import script_uris
from sagemaker import image_uris 
from sagemaker import model_uris
import sagemaker
import logging
import boto3
import time
import json

##### Set up logging

In [None]:
# Route DEBUG-and-above records from the 'sagemaker' logger to a stream
# handler so SDK activity and the messages logged below show up in the notebook.
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)

# logging.getLogger returns the same logger object on every call, so re-running
# this cell would stack another StreamHandler and print each record multiple
# times. The exact-type check is deliberate: subclasses such as FileHandler do
# not write to the console and must not suppress adding the stream handler.
if not any(type(handler) is logging.StreamHandler for handler in logger.handlers):
    logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies

In [None]:
# Record the versions of the two SDKs this notebook relies on, in
# pip-style `name==version` form, to make runs reproducible.
for lib in (sagemaker, boto3):
    logger.info(f'Using {lib.__name__}=={lib.__version__}')

#### Set up essentials

##### List and filter all text embedding models available in JumpStart

In [None]:
# Unfiltered JumpStart listing; only its size is logged.
models = list_jumpstart_models()
logger.info(f'Total number of models in SageMaker JumpStart hub = {len(models)}')

# Narrow the listing down to the text-embedding task.
# NOTE(review): despite its name, `txt2img_models` holds the result of the
# 'textembedding' filter below — the name looks like a leftover; rename once
# no other cell depends on it.
FILTER = 'task == textembedding'
txt2img_models = list_jumpstart_models(filter=FILTER)

# Bare expression as the last line so the notebook renders the filtered list.
txt2img_models

##### Set up config params

In [None]:
# --- JumpStart model selection ---
MODEL_ID = 'huggingface-textembedding-gpt-j-6b-fp16'
# NOTE(review): '*' is presumably a wildcard resolving to the latest
# available version — confirm against the JumpStart documentation.
MODEL_VERSION = '*'
IMAGE_SCOPE = 'inference'

# --- Hosting configuration ---
INSTANCE_TYPE = 'ml.g5.2xlarge'
INSTANCE_COUNT = 1

# --- Deployment timeouts ---
MODEL_DATA_DOWNLOAD_TIMEOUT = 3600  # in seconds
# NOTE(review): presumably also in seconds, like the download timeout — confirm.
CONTAINER_STARTUP_HEALTH_CHECK_TIMEOUT = 3600

# --- Request format ---
CONTENT_TYPE = 'application/json'

# --- Client and role ---
# Runtime client used further down to invoke the deployed endpoint.
client = boto3.client('sagemaker-runtime')
# Execution role handed to the SageMaker Model below.
ROLE = get_execution_role()
logger.info(f'Role => {ROLE}')

In [None]:
# Suffix the model id with the current Unix timestamp (whole seconds) so that
# each run of the notebook gets a distinct endpoint name.
unix_time = int(time.time())
endpoint_name = '-'.join((MODEL_ID, str(unix_time)))
logger.info(f'Endpoint name: {endpoint_name}')

#### Retrieve image and model URIs

In [None]:
# Look up the container image for the chosen JumpStart model, scope and
# instance type. `region` and `framework` are passed as None, as in the
# original call, leaving their resolution to the SDK.
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,
    image_scope=IMAGE_SCOPE,
    model_id=MODEL_ID,
    model_version=MODEL_VERSION,
    instance_type=INSTANCE_TYPE,
)
logger.info(f'Deploy image URI => {deploy_image_uri}')

In [None]:
# Look up the model artifact URI for the same model id and version.
# IMAGE_SCOPE ('inference') is reused here as the model scope.
model_uri = model_uris.retrieve(
    model_id=MODEL_ID,
    model_version=MODEL_VERSION,
    model_scope=IMAGE_SCOPE,
)
logger.info(f'Model URI => {model_uri}')

In [None]:
# Environment variables passed to the SageMaker Model via its `env` argument.
# Every value is a string.
env = {
    # Server-side timeout setting, written out as a literal string.
    'SAGEMAKER_MODEL_SERVER_TIMEOUT': '3600',
    # Paths and entry point under /opt/ml/model.
    'MODEL_CACHE_ROOT': '/opt/ml/model',
    'SAGEMAKER_ENV': '1',
    'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/model/code/',
    'SAGEMAKER_PROGRAM': 'inference.py',
    # Worker-count settings, both set to one.
    'SAGEMAKER_MODEL_SERVER_WORKERS': '1',
    'TS_DEFAULT_WORKERS_PER_MODEL': '1',
}

#### Create SageMaker Model

In [None]:
# Assemble the SageMaker Model from the retrieved image and artifact URIs.
# The model is given the same name as the endpoint.
model = Model(
    image_uri=deploy_image_uri,
    model_data=model_uri,
    role=ROLE,
    predictor_cls=Predictor,
    name=endpoint_name,
    env=env,
)

#### Deploy text embedding model as SageMaker endpoint for real-time synchronous inference

In [None]:
%%time

# Deploy the model behind a real-time endpoint. The return value is discarded
# on purpose: inference below goes through the boto3 runtime client instead.
_ = model.deploy(
    initial_instance_count=INSTANCE_COUNT,
    instance_type=INSTANCE_TYPE,
    endpoint_name=endpoint_name,
    model_data_download_timeout=MODEL_DATA_DOWNLOAD_TIMEOUT,
    container_startup_health_check_timeout=CONTAINER_STARTUP_HEALTH_CHECK_TIMEOUT,
)

### Test SageMaker endpoint for inference

In [None]:
# To test an endpoint that is already deployed, uncomment the line below and set its name:
# endpoint_name = 'huggingface-textembedding-gpt-j-6b-fp16-1680825746'

In [None]:
# Sample natural-language query whose embedding is requested below.
query = "what is the meaning of life according to an ant?"

In [None]:
# Wrap the query in a one-element 'text_inputs' list, serialize it to JSON and
# encode it as UTF-8 bytes for use as the request body.
payload = json.dumps({'text_inputs': [query]}).encode('utf-8')

In [None]:
%%time

# Send the serialized payload to the deployed endpoint through the
# SageMaker runtime client.
response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    # Use the CONTENT_TYPE constant from the config cell instead of repeating
    # the literal, so the request format is defined in one place.
    ContentType=CONTENT_TYPE,
    Body=payload,
)
    

##### Parse model response to extract query embedding

In [None]:
# Decode the JSON document from the response body stream.
body = json.load(response['Body'])

# Take the first entry under 'embedding'.
# NOTE(review): presumably one vector per input text, so index 0 matches the
# single query sent above — confirm against the model's response schema.
embedding = body['embedding'][0]

# Bare expression as the last line so the notebook renders the vector.
embedding