## Stable Diffusion (Text to Image)

#### I. Imports 

In [None]:
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models
from sagemaker.utils import name_from_base
from sagemaker.predictor import Predictor
from sagemaker import get_execution_role
from sagemaker.model import Model
from sagemaker import script_uris
from sagemaker import image_uris
from sagemaker import model_uris
import matplotlib.pyplot as plt
import numpy as np
import sagemaker
import datetime
import logging
import boto3
import json

##### Set up logging

In [None]:
# Log through the 'sagemaker' logger at DEBUG level, echoing to the console.
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
# Notebook cells are frequently re-run; attach our console handler only once
# so each message is not printed once per execution of this cell.
_HANDLER_NAME = 'notebook-console'
if not any(h.get_name() == _HANDLER_NAME for h in logger.handlers):
    _console_handler = logging.StreamHandler()
    _console_handler.set_name(_HANDLER_NAME)
    logger.addHandler(_console_handler)

##### Log versions of dependencies 

In [None]:
# Record the installed SDK versions so runs can be reproduced and debugged.
logger.info(f'[Using sagemaker version: {sagemaker.__version__}]')
logger.info(f'[Using boto3 version: {boto3.__version__}]')

#### II. Set up essentials

In [None]:
# AWS region of the default boto3 session (used when resolving the image URI).
REGION = boto3.Session().region_name
# IAM role that the deployed model will run under.
ROLE = get_execution_role()
session = sagemaker.Session()
logger.info(f'Region: {REGION}')

##### List all models from SageMaker JumpStart hub

In [None]:
# Fetch every model ID in the JumpStart hub and report how many there are.
models = list_jumpstart_models()
num_models = len(models)
logger.info(f'Total number of models in SageMaker JumpStart hub = {num_models}')

In [None]:
# Restrict the listing to text-to-image models; the list is shown as cell output.
FILTER = 'task == txt2img'
txt2img_models = list_jumpstart_models(filter=FILTER)
txt2img_models

In [None]:
# JumpStart model to deploy and the artifact scope to retrieve for it.
MODEL_ID = 'model-txt2img-stabilityai-stable-diffusion-v2-1-base'
MODEL_VERSION = '*'  # '*' selects the latest available version
SCOPE = 'inference'

# GPU instance type for the endpoint. Alternative: 'ml.p3.2xlarge'.
INFERENCE_INSTANCE_TYPE = 'ml.g5.xlarge'

#### III. Retrieve inference artifacts 

In [None]:
# Resolve the inference container image for this model on the chosen instance type.
inference_image_uri = image_uris.retrieve(
    framework=None,
    region=REGION,
    instance_type=INFERENCE_INSTANCE_TYPE,
    image_scope=SCOPE,
    model_id=MODEL_ID,
    model_version=MODEL_VERSION,
)
logger.info(f'Inference image URI: {inference_image_uri}')

In [None]:
# Resolve the location of the inference script bundle for this model.
inference_source_uri = script_uris.retrieve(
    script_scope=SCOPE,
    model_id=MODEL_ID,
    model_version=MODEL_VERSION,
)
logger.info(f'Inference source URI: {inference_source_uri}')

In [None]:
# Resolve the location of the pretrained model artifacts used for inference.
inference_model_uri = model_uris.retrieve(
    model_scope=SCOPE,
    model_id=MODEL_ID,
    model_version=MODEL_VERSION,
)
logger.info(f'Inference model URI: {inference_model_uri}')

In [None]:
# Container environment: raise the model server's maximum response size so a
# generated image fits in a single response (value presumably in bytes).
env = dict(MMS_MAX_RESPONSE_SIZE='20000000')

In [None]:
# Timestamp the endpoint name (second resolution) so each run gets its own name.
current_time = f'{datetime.datetime.now():%Y%m%d%H%M%S}'
endpoint_name = 'js-sd-' + current_time
logger.info(f'Endpoint name = {endpoint_name}')

In [None]:
# Assemble the deployable model from the retrieved image, scripts and weights.
# The model shares its name with the endpoint so the two are easy to match up.
model = Model(
    name=endpoint_name,
    role=ROLE,
    image_uri=inference_image_uri,
    model_data=inference_model_uri,
    source_dir=inference_source_uri,
    entry_point='inference.py',
    env=env,
    predictor_cls=Predictor,
)

#### IV. Deploy Stable Diffusion model as a SageMaker endpoint

Deployment takes around 11 to 14 minutes.

In [None]:
%%time

_ = model.deploy(initial_instance_count=1, 
                 instance_type=INFERENCE_INSTANCE_TYPE, 
                 predictor_cls=Predictor, 
                 endpoint_name=endpoint_name)

#### V. Invoke the endpoint for inference 

In [None]:
# Optional override: the name of an endpoint deployed in an earlier session
# (e.g. after a kernel restart), such as 'js-sd-20230509001303'.
# Leave as None to invoke the endpoint deployed above.
ENDPOINT_NAME = None
if ENDPOINT_NAME is not None:
    # The invocation cells below send their requests to `endpoint_name`.
    endpoint_name = ENDPOINT_NAME

In [None]:
# Low-level SageMaker runtime client used to send inference requests.
client = boto3.client(service_name='sagemaker-runtime')

In [None]:
# Short text prompt for the first request.
prompt = "Cat in a space suit"

In [None]:
%%time

response = client.invoke_endpoint(EndpointName=endpoint_name, 
                                  Body=prompt, 
                                  ContentType='application/x-text')

In [None]:
# Decode the JSON response; 'generated_image' presumably holds nested lists of
# pixel values (it is converted with np.array for display).
raw_body = response['Body'].read()
response_body = json.loads(raw_body.decode())
generated_image = response_body['generated_image']

In [None]:
# Render the generated image on a large, axis-free figure titled with its prompt.
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(np.array(generated_image))
ax.set_axis_off()
ax.set_title(prompt)
plt.show()

##### Invoke the model with generation parameters

In [None]:
# Detailed prompt for the parameterised request.
prompt = (
    'a cute magical flying cat, fantasy art drawn by disney concept artists, '
    'golden colour, high quality, highly detailed, elegant, sharp focus, '
    'concept art, character concepts, digital painting, mystery, adventure'
)

* **`num_images_per_prompt`**: The number of images to generate for each prompt. It can be any positive integer.
* **`num_inference_steps`**: The number of denoising steps the image generator takes to create an image. It can be any positive integer; higher values take longer and consume more resources.
* **`guidance_scale`**: Controls how closely the image generator follows the prompt. It can be any positive decimal value. Higher values make the image adhere more closely to the prompt, typically at the cost of diversity (this notebook uses `7.5`).

> **`num_inference_steps`** is the number of denoising steps that the image generator takes to create an image. Each denoising step applies the diffusion model to remove some of the noise, gradually refining the image. More denoising steps generally produce a higher-quality image, but they take longer and consume more resources. The default value for `num_inference_steps` is `50`, which works well for most cases. You can change this value to suit your needs.

In [None]:
# Generation request: one image, 50 denoising steps, guidance scale 7.5.
payload = dict(
    prompt=prompt,
    num_images_per_prompt=1,
    num_inference_steps=50,
    guidance_scale=7.5,
)

In [None]:
# Serialise the request to UTF-8 encoded JSON bytes for the request body.
payload = bytes(json.dumps(payload), 'utf-8')

In [None]:
%%time

response = client.invoke_endpoint(EndpointName=endpoint_name, 
                                  Body=payload, 
                                  ContentType='application/x-text')

In [None]:
response_body = json.loads(response['Body'].read().decode())
generated_image = response_body['generated_image']

In [None]:
# Render the generated image on a large, axis-free figure titled with its prompt.
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(np.array(generated_image))
ax.set_axis_off()
ax.set_title(prompt)
plt.show()