In [None]:
import sagemaker, boto3
from sagemaker.model import Model
from sagemaker.session import Session
from sagemaker.predictor import Predictor
from sagemaker import image_uris, model_uris

In [None]:
# Look up the deployed stack in CloudFormation and collect its outputs; the
# SageMaker endpoint names consumed by the configuration cell are read from here.
stackname = 'LLMStack'
cf_client = boto3.client('cloudformation')

response = cf_client.describe_stacks(StackName=stackname)
outputs = response["Stacks"][0]["Outputs"]

# Flatten the list of {OutputKey, OutputValue} records into a key -> value map.
cf_outputs = {entry['OutputKey']: entry['OutputValue'] for entry in outputs}


In [None]:
def _endpoint_spec(model_id, instance_type, output_key):
    """Return the deployment settings for one endpoint.

    `output_key` is the CloudFormation output (in `cf_outputs`) that holds the
    endpoint name. Every call builds a fresh `env` dict, so entries never share it.
    """
    return {
        'model_id': model_id,
        'instance_type': instance_type,
        'instance_count': 1,
        'endpoint_name': cf_outputs[output_key],
        'env': {"TS_DEFAULT_WORKERS_PER_MODEL": "1"},
        'predictor': Predictor,
    }


# Model name -> deployment settings consumed by `deploy_model`.
endpoint_config = {
    'generative': _endpoint_spec(
        'huggingface-text2text-flan-t5-xl', 'ml.g5.2xlarge', "SageMakerEndpointGenerative"
    ),
    'embeddings': _endpoint_spec(
        'huggingface-textembedding-gpt-j-6b', 'ml.g5.4xlarge', "SageMakerEndpointEmbeddings"
    ),
}

In [None]:
# One SageMaker session supplies both the caller's IAM role and the region,
# instead of constructing additional sagemaker/boto3 sessions for each value.
sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()
aws_region = sagemaker_session.boto_region_name

# Kept for backward compatibility with any cell that refers to `sess`; it is
# now the same session object as `sagemaker_session`.
sess = sagemaker_session

print(f'aws_role={aws_role}')
print(f'aws_region={aws_region}')

In [None]:
def deploy_model(name: str, config: dict):
    """Deploy the model described by ``config[name]`` to a SageMaker endpoint.

    Resolves the JumpStart inference image URI and model artifact URI for the
    entry's ``model_id``, wraps them in a ``sagemaker.model.Model`` (using the
    notebook-level ``aws_role``), and deploys it. Both the model and the endpoint
    are named after the entry's ``endpoint_name``.

    Args:
        name: Key of the entry in ``config`` to deploy (e.g. ``'embeddings'``).
        config: Mapping of name -> settings dict with the keys ``model_id``,
            ``instance_type``, ``instance_count``, ``endpoint_name``, ``env``
            and ``predictor`` (the predictor class to attach).

    Returns:
        The predictor object returned by ``Model.deploy``.

    Raises:
        ValueError: If ``name`` is not a key of ``config``.
    """
    # Validate against the mapping actually passed in (the original checked the
    # global `endpoint_config`). An explicit raise also survives `python -O`,
    # which strips `assert` statements.
    if name not in config:
        raise ValueError(
            f'Provide proper name of the model. {name!r} is not one of {list(config)}.'
        )

    settings = config[name]
    model_id = settings['model_id']
    instance_type = settings['instance_type']
    instance_count = settings['instance_count']
    endpoint_name = settings['endpoint_name']
    env = settings['env']
    predictor = settings['predictor']

    # model_version='*' asks JumpStart for the latest available version.
    # NOTE(review): region=None presumably falls back to the default session's region.
    deploy_image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="inference",
        model_id=model_id,
        model_version='*',
        instance_type=instance_type)

    model_uri = model_uris.retrieve(
        model_id=model_id,
        model_version='*',
        model_scope="inference"
    )

    print(f'deploy_image_uri: {deploy_image_uri} \n')
    print(f'model_uri: {model_uri}')

    model_inference = Model(
        image_uri=deploy_image_uri,
        model_data=model_uri,
        role=aws_role,
        predictor_cls=predictor,
        name=endpoint_name,
        env=env,
    )

    model_predictor_inference = model_inference.deploy(
        initial_instance_count=instance_count,
        instance_type=instance_type,
        predictor_cls=predictor,
        endpoint_name=endpoint_name,
    )

    print(f'Deployed model with endpoint: {endpoint_name}')
    # Return the predictor so callers can invoke the endpoint directly.
    return model_predictor_inference

In [None]:
deploy_model(name='embeddings', config=endpoint_config)  # embeddings endpoint

In [None]:
deploy_model(name='generative', config=endpoint_config)  # text-generation endpoint