# <a id='toc1_'></a>[Using Bria 1.4 with AWS JumpStart](#toc0_)

This sample notebook shows how to deploy the BRIA v1.4 Safe for Commercial Use model as a real-time endpoint on Amazon SageMaker and generate images with it.

> **Note**: This is a reference notebook. It cannot run unless you make the changes suggested in the notebook, such as replacing the placeholder values.
 
## <a id='toc1_1_'></a>[Prerequisites](#toc0_)

1. **Note**: Open this notebook from an Amazon SageMaker Notebook Instance or Amazon SageMaker Studio.

1. Ensure that the IAM role you use has the **AmazonSageMakerFullAccess** policy attached.

1. To deploy the ML model successfully using the steps in this notebook, ensure that either:
    1. Your IAM role has the following three permissions and you have authority to make AWS Marketplace subscriptions in the AWS account used: 
        1. **aws-marketplace:ViewSubscriptions**
        1. **aws-marketplace:Unsubscribe**
        1. **aws-marketplace:Subscribe**  
    2. Or your AWS account already has a subscription to [the Bria 1.4 - PLACEHOLDER MARKETPLACE LINK](https://aws.amazon.com/marketplace/pp/). If so, skip the step [Subscribe to the model package](#1.-Subscribe-to-the-model-package).
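
As an optional pre-flight check, you can ask IAM whether a managed policy named `AmazonSageMakerFullAccess` is attached to your role. This is a minimal sketch: `role_name_from_arn` and `has_managed_policy` are hypothetical helper names, the check only looks at directly attached managed policies (inline and group policies are ignored), it needs the `iam:ListAttachedRolePolicies` permission itself, and it does not verify the AWS Marketplace permissions.

```python
def role_name_from_arn(role_arn: str) -> str:
    """Return the role name (last path segment) of an IAM role ARN."""
    # e.g. arn:aws:iam::123456789012:role/service-role/MyRole -> MyRole
    resource = role_arn.split(":", 5)[5]  # "role/service-role/MyRole"
    return resource.split("/")[-1]


def has_managed_policy(iam_client, role_arn: str, policy_name: str = "AmazonSageMakerFullAccess") -> bool:
    """Check whether a managed policy with this name is directly attached to the role.

    Hypothetical helper: inline policies, group policies and permission
    boundaries are NOT inspected.
    """
    role_name = role_name_from_arn(role_arn)
    # ListAttachedRolePolicies is paginated, so walk every page
    paginator = iam_client.get_paginator("list_attached_role_policies")
    for page in paginator.paginate(RoleName=role_name):
        for policy in page["AttachedPolicies"]:
            if policy["PolicyName"] == policy_name:
                return True
    return False
```

For example, `has_managed_policy(boto3.client("iam"), get_execution_role())` returns `True` when the policy is attached.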


## <a id='toc1_3_'></a>[Usage instructions](#toc0_)
You can run this notebook one cell at a time (press Shift+Enter to run a cell).


   
- [1: Subscribe to the Bria model package](https://aws.amazon.com/marketplace/pp/prodview-pn2xoztixtsbc)
- [2: Create an endpoint and perform real-time inference](#toc2_)
  - [A: Text to image](#toc4_1_)
- [3: Install the Bria attribution agent](#toc5_)
- [4: Delete the endpoint](#toc6_)

# <a id='toc2_'></a>[2: Create an endpoint and perform real-time inference](#toc0_)

```python
import base64
import json
from io import BytesIO

import boto3
import sagemaker
from PIL import Image
from sagemaker import ModelPackage, get_execution_role
from sagemaker.utils import name_from_base
```

Once you have subscribed to the Bria product in AWS Marketplace, copy its model package ARN into the cell below:


```python
# PLACEHOLDER: replace with the model package ARN from your AWS Marketplace subscription
package_arn = "arn:aws:sagemaker:us-east-1:865070037744:model-package/bria-v1-4-1-56b14d06134839da9044c693f8822318"

# name_from_base appends a timestamp, so each run gets a unique endpoint name
endpoint_name = name_from_base("bria-1-4-jumpstart")

region = boto3.Session().region_name

# Note: get_execution_role() only works in a SageMaker environment such as a Studio notebook.
# If you run this code locally, substitute the ARN of a role that has the
# AmazonSageMakerFullAccess IAM policy attached.
role_arn = get_execution_role()

sagemaker_session = sagemaker.Session()
```
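
AWS Marketplace model packages are generally published with a separate ARN in each supported region, and the region is the fourth colon-separated field of an ARN. The sketch below uses hypothetical helpers `arn_region` and `check_package_region` to fail early when the package ARN does not match the session region, assuming `region` holds the session region computed in the cell above.

```python
def arn_region(arn: str) -> str:
    """Return the region field of an ARN (arn:partition:service:region:account:resource)."""
    parts = arn.split(":", 5)
    if len(parts) != 6 or parts[0] != "arn":
        raise ValueError(f"Not an ARN: {arn!r}")
    return parts[3]


def check_package_region(package_arn: str, region: str) -> None:
    """Raise ValueError if the model package ARN belongs to a different region."""
    package_region = arn_region(package_arn)
    if package_region != region:
        raise ValueError(
            f"Model package is in {package_region}, but the session region is {region}; "
            "use the package ARN listed for your region."
        )
```

Call `check_package_region(package_arn, region)` before creating the `ModelPackage`.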

Create a deployable `ModelPackage`. For Bria 1.4, use one of the following instance types: `ml.g5.xlarge` or `ml.p4d.24xlarge`. Specify it as `instance_type` below.


```python
# predictor_cls makes deploy() return a Predictor; without it, ModelPackage.deploy() returns None
model = ModelPackage(
    role=role_arn,
    model_package_arn=package_arn,
    sagemaker_session=sagemaker_session,
    predictor_cls=sagemaker.Predictor,
)

# Deploy the ModelPackage. This takes 5-10 minutes to run
instance_type = "ml.g5.xlarge"  # valid instance types for this model are ml.g5.xlarge and ml.p4d.24xlarge
deployed_model = model.deploy(
    initial_instance_count=1, instance_type=instance_type, endpoint_name=endpoint_name
)
```
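
The `deploy` call above blocks until the endpoint is ready. If you pass `wait=False` to `model.deploy` instead, you can poll for readiness yourself with the boto3 `endpoint_in_service` waiter. This is a sketch under that assumption: `wait_until_in_service` is a hypothetical helper, and the delay and attempt values are arbitrary defaults (30 seconds × 60 attempts, roughly 30 minutes).

```python
def wait_until_in_service(sm_client, endpoint_name: str, delay: int = 30, max_attempts: int = 60) -> str:
    """Block until the endpoint reaches InService and return its final status.

    Uses the boto3 "endpoint_in_service" waiter, which raises
    botocore.exceptions.WaiterError if the endpoint fails or attempts run out.
    """
    waiter = sm_client.get_waiter("endpoint_in_service")
    waiter.wait(
        EndpointName=endpoint_name,
        WaiterConfig={"Delay": delay, "MaxAttempts": max_attempts},
    )
    # Read back the status once the waiter returns successfully
    return sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
```

For example: `wait_until_in_service(boto3.client("sagemaker"), endpoint_name)`.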

If you have already deployed your model, you can also access it via your chosen `endpoint_name` and `sagemaker_session`:


```python
deployed_model = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker_session,
)
```

Now you can invoke the endpoint. It returns a JSON response that contains the generated image.

```python
# Generation parameters sent to the endpoint as a JSON body
request = {
    "prompt": "A towering redwood tree in a forest, during twilight",
    "width": 512,
    "height": 512,
    "steps": 50,
    "seed": 42,
    "negative_prompt": "blue sky, people",
}
output = deployed_model.predict(
    data=json.dumps(request),
    initial_args={"Accept": "application/json", "ContentType": "application/json"},
).decode("utf-8")
output = json.loads(output)
```
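
The same request can be sent without the SageMaker Python SDK by calling the SageMaker Runtime `invoke_endpoint` API through boto3, for example from an application outside this notebook. A minimal sketch: `invoke_bria` is a hypothetical helper name, and it assumes the endpoint accepts and returns JSON exactly as in the `predict` call above.

```python
import json


def invoke_bria(runtime_client, endpoint_name: str, payload: dict) -> dict:
    """Send a JSON request to the endpoint with boto3 and return the parsed JSON response."""
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=json.dumps(payload),
    )
    # "Body" is a botocore StreamingBody; read() returns the raw response bytes
    return json.loads(response["Body"].read().decode("utf-8"))
```

Create the client with `boto3.client("sagemaker-runtime")` and pass the same dictionary of generation parameters used above as `payload`.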

Output images are included in the response's `artifacts` list as base64-encoded strings. The helper function below decodes the first image and displays it:

```python
from IPython.display import display


def image_decode(model_response) -> None:
    """
    Decodes and displays the first image from the model output.

    Args:
        model_response (dict): The parsed JSON response from the model.

    Returns:
        None
    """
    image = model_response["artifacts"][0]["image_base64"]
    image_data = base64.b64decode(image)
    image = Image.open(BytesIO(image_data))
    display(image)


image_decode(output)
```
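
The helper above only shows the first entry of `artifacts`. To keep every returned image, you can decode each base64 string and write it to disk. A sketch: `save_artifacts` is a hypothetical helper, and the `.png` suffix is an assumption about the encoding the endpoint returns (check it, for example with `Image.open(...).format`, before relying on it).

```python
import base64
from pathlib import Path


def save_artifacts(model_response: dict, out_dir: str = "bria_outputs",
                   stem: str = "image", suffix: str = ".png") -> list:
    """Decode every base64 image in model_response["artifacts"] and write it to out_dir.

    The bytes are written unchanged; the file suffix is an assumption, not
    something read from the response.
    """
    directory = Path(out_dir)
    directory.mkdir(parents=True, exist_ok=True)
    paths = []
    for i, artifact in enumerate(model_response["artifacts"]):
        path = directory / f"{stem}_{i}{suffix}"
        path.write_bytes(base64.b64decode(artifact["image_base64"]))
        paths.append(path)
    return paths
```

For example, `save_artifacts(output)` writes `bria_outputs/image_0.png` for a single-image response.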

# <a id='toc5_'></a>[3: Install attribution agent](#toc0_)

Follow the instructions in the [Bria-AI/agent repository](https://github.com/Bria-AI/agent) to deploy the attribution agent.

Once the agent is deployed, the same API you used to call the JumpStart endpoint directly is available through the AWS Lambda function installed by the agent stack.

```python
import json

import boto3

# Set up the AWS Lambda client (replace with the region where the agent stack is deployed)
lambda_client = boto3.client("lambda", region_name="your_region")

# Name of the Lambda function created by the agent stack (replace with yours)
function_name = "your_lambda_function_name"

# Same request body as the direct endpoint call
payload = {
    "prompt": "A towering redwood tree in a forest, during twilight",
    "width": 512,
    "height": 512,
    "steps": 50,
    "seed": 42,
    "negative_prompt": "blue sky, people",
}

# Invoke the Lambda function synchronously
response = lambda_client.invoke(
    FunctionName=function_name,
    InvocationType="RequestResponse",
    Payload=json.dumps(payload),
)

# The response Payload is a stream; json.load reads and parses it
output = json.load(response["Payload"])

image_decode(output)
```
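
When a Lambda handler raises an error, `invoke` still returns normally: the response carries a `FunctionError` field and the payload holds the error details instead of images, so passing it straight to `image_decode` fails with a `KeyError`. The sketch below surfaces that case explicitly. `invoke_agent` is a hypothetical helper, and it assumes the agent's Lambda returns the same JSON shape as the endpoint, as described above.

```python
import json


def invoke_agent(lambda_client, function_name: str, payload: dict) -> dict:
    """Synchronously invoke the agent's Lambda function and return its parsed JSON result."""
    response = lambda_client.invoke(
        FunctionName=function_name,
        InvocationType="RequestResponse",
        Payload=json.dumps(payload),
    )
    body = json.loads(response["Payload"].read())
    # Handler errors are reported through the FunctionError field, not as an exception;
    # the payload then contains the error message instead of the generated images.
    if "FunctionError" in response:
        raise RuntimeError(f"Lambda function error: {body}")
    return body
```

For example: `image_decode(invoke_agent(lambda_client, function_name, payload))`.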

# <a id='toc6_'></a>[4: Delete the endpoint](#toc0_)

When you've finished working, delete the endpoint to release the instance associated with it and stop billing.

List your SageMaker endpoints with the boto3 SageMaker client (the AWS SDK for Python):

```python
sm_client = boto3.client("sagemaker")

# ListEndpoints returns results in pages, so iterate over all of them
for page in sm_client.get_paginator("list_endpoints").paginate():
    for endpoint in page["Endpoints"]:
        print(endpoint["EndpointName"])
```
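
`name_from_base` appends a timestamp to `bria-1-4-jumpstart`, so after several runs it can be hard to spot which endpoints this notebook created. The sketch below filters by the base name. `find_endpoints` is a hypothetical helper; because `NameContains` matches anywhere in the name, the results are filtered again with `startswith`.

```python
def find_endpoints(sm_client, prefix: str = "bria-1-4-jumpstart") -> list:
    """Return the names of all endpoints whose name starts with prefix."""
    names = []
    paginator = sm_client.get_paginator("list_endpoints")
    # NameContains narrows the server-side listing; startswith enforces the prefix
    for page in paginator.paginate(NameContains=prefix):
        for endpoint in page["Endpoints"]:
            if endpoint["EndpointName"].startswith(prefix):
                names.append(endpoint["EndpointName"])
    return names
```

For example: `find_endpoints(boto3.client("sagemaker"))`.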

Delete the endpoint created in this notebook:

```python
# Deletes the endpoint and, by default, its endpoint configuration
deployed_model.delete_endpoint()
# Rerun the endpoint-listing cell above to confirm that it's gone.
```
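
Deleting the endpoint stops the instance charges, but `deploy` also creates an endpoint configuration and a model object that can remain in your account. The sketch below removes all three, looking up the configuration and model names with `describe_endpoint` and `describe_endpoint_config`. `delete_endpoint_resources` is a hypothetical helper, and it does not wait for the deletion to finish.

```python
def delete_endpoint_resources(sm_client, endpoint_name: str) -> dict:
    """Delete an endpoint together with its endpoint configuration and model(s).

    Returns the names of the resources it deleted.
    """
    # Look up the names first: they cannot be described once the endpoint is gone
    endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
    config_name = endpoint["EndpointConfigName"]
    config = sm_client.describe_endpoint_config(EndpointConfigName=config_name)
    model_names = [variant["ModelName"] for variant in config["ProductionVariants"]]

    # Delete the endpoint first so the configuration and models are no longer in use
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=config_name)
    for model_name in model_names:
        sm_client.delete_model(ModelName=model_name)
    return {"endpoint": endpoint_name, "endpoint_config": config_name, "models": model_names}
```

Run `delete_endpoint_resources(boto3.client("sagemaker"), endpoint_name)` instead of, not after, the deletion cell above, because the endpoint must still exist for the lookup to succeed.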