In [None]:
import sys
import time
import logging
import sagemaker, boto3, json
from sagemaker.model import Model
from sagemaker.session import Session
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base
from sagemaker import image_uris, model_uris, script_uris, hyperparameters

In [None]:
# CONSTANTS
# A small model and a modest instance type are used here to keep costs down.
# For production, a larger variant (e.g. XXL) on a stronger instance will
# perform better.
MODEL_ID = "huggingface-text2text-flan-t5-base"
INSTANCE_TYPE = 'ml.g4dn.2xlarge'
INSTANCE_COUNT = 1
# Used as both the SageMaker model name and the endpoint name. SageMaker only
# accepts alphanumerics and hyphens in these names, so no underscores here.
ENDPOINT_NAME = 'LLM-QA-FLAN-T5-BASE'

In [None]:
# Make sure the root logger has a handler; without one, INFO records fall
# through to logging's last-resort handler, which only emits WARNING and above.
# basicConfig() does nothing if the kernel already installed a root handler,
# so the level is still set explicitly afterwards.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
# Create a single SageMaker session and reuse it, instead of building two
# SageMaker sessions plus a separate boto3 session.
sagemaker_session = Session()
sess = sagemaker_session  # alias kept for any code that refers to `sess`

# ARN of the calling identity; passed as the model's `role` when deploying.
aws_role = sagemaker_session.get_caller_identity_arn()
# Read the region from the same session so role and region cannot disagree.
aws_region = sagemaker_session.boto_region_name

logger.info(f'aws_role={aws_role}')
logger.info(f'aws_region={aws_region}')

In [None]:
# Look up the two artifacts needed to host MODEL_ID: the model data URI and
# the inference container image URI for the chosen instance type.
# NOTE(review): model_version '*' presumably resolves to the latest published
# version, so results may change between runs; pin a version if that matters.
model_uri = model_uris.retrieve(
    model_scope="inference",
    model_id=MODEL_ID,
    model_version='*',
)

# region and framework are left as None and resolved by the SDK.
deploy_image_uri = image_uris.retrieve(
    model_id=MODEL_ID,
    model_version='*',
    image_scope="inference",
    instance_type=INSTANCE_TYPE,
    region=None,
    framework=None,
)

logger.info(f"deploy_image_uri={deploy_image_uri}, model_uri={model_uri}")

In [None]:
# Describe the model to host: container image, model artifacts and IAM role.
# The model is registered under the same name as the endpoint.
model_inference = Model(
    name=ENDPOINT_NAME,
    role=aws_role,
    model_data=model_uri,
    image_uri=deploy_image_uri,
    predictor_cls=Predictor,
    # NOTE(review): presumably limits the model server to one worker per
    # model; confirm against the serving container's documentation.
    env={"TS_DEFAULT_WORKERS_PER_MODEL": "1"},
)

# Create the endpoint; the returned predictor is what later cells use to
# invoke the model.
model_predictor_inference = model_inference.deploy(
    endpoint_name=ENDPOINT_NAME,
    instance_type=INSTANCE_TYPE,
    initial_instance_count=INSTANCE_COUNT,
    predictor_cls=Predictor,
)