# Amazon Bedrock and Stability.ai Stable Image Ultra 1.0 Advertising Demo

In this notebook we will learn how to use the AWS SDK for Python (Boto3) to create ad posters with [Amazon Bedrock](https://aws.amazon.com/bedrock/), the [Stability.ai](https://stability.ai/stable-image) Stable Image Ultra 1.0 model, and [Anthropic Claude 3](https://www.anthropic.com/claude). Claude first drafts advertising concepts for a brand called “Young Generational Shoes” (“YGS”) and turns each concept into an image prompt. Stable Image Ultra then renders each prompt as a poster. Every prompt comes from the same brief and asks for the YGS logo on the shoes, so the resulting series of images stays consistent in brand and message. Combining LLM ideation with image generation in this way lets marketing teams produce high-quality visual assets tailored to their target audience in less time. It also lets them iterate quickly as market trends and consumer preferences change.


Technologies:

- **Amazon Bedrock**: Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models (FMs) from leading AI companies like AI21 Labs, Anthropic, Cohere, Meta, Mistral AI, Stability AI, and Amazon via a single API.

- **Stable Image Ultra 1.0**: Stability AI's Stable Image Ultra is built on the Stable Diffusion 3 (SD3) model family and is tuned for photorealistic output. Compared with earlier Stable Diffusion models, it produces more detailed imagery, better composition, and improved face generation.

- **Anthropic Claude 3 Model**: Claude 3 is a family of large language models developed by Anthropic, with a 200K-token context window.


References:

- [Amazon Bedrock: Stability AI Diffusion 1.0 image-to-image parameters](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-diffusion-1-0-image-image.html)
- [Stability.ai API documentation: image-to-image (v1)](https://platform.stability.ai/docs/api-reference#tag/v1generation/operation/imageToImage)


## Prerequisites


### Python Environment

Create a virtual Python environment and install the required packages.


In [None]:
%%sh
# Install Python requirements
python3 -m pip install -r requirements.txt -Uq

### Authenticate with Your AWS Credentials

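Your method of authentication may vary depending on your environment. If you use an AWS profile, AWS IAM Identity Center (SSO), or an attached IAM role, leave the values below empty so that boto3 falls back to its default credential chain. Avoid saving real keys in the notebook.


In [None]:
# Authenticate with AWS. Only non-empty values are exported, so empty values
# leave boto3's default credential chain in charge.
import os

for name, value in {
    "AWS_ACCESS_KEY_ID": "",
    "AWS_SECRET_ACCESS_KEY": "",
    "AWS_SESSION_TOKEN": "",
}.items():
    if value:
        os.environ[name] = value

# Region used by clients that don't set one; matches the text-to-image helper.
os.environ.setdefault("AWS_DEFAULT_REGION", "us-west-2")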

## Define functions
Define the text-to-image, image-to-image, and other utility functions.

In [None]:
import base64
import io
import json
import logging
import os
import time
from enum import Enum, unique

import boto3
from botocore.exceptions import ClientError
from PIL import Image

GENERATED_IMAGES = "./generated_images"


In [None]:
# Amazon Bedrock Model ID used throughout this notebook
# Model IDs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns
MODEL_ID = "stability.stable-image-ultra-v1:0" 


In [None]:
# Create the output folder for generated images if it does not exist yet.
os.makedirs(GENERATED_IMAGES, exist_ok=True)

### Define the text-to-image function

In [None]:
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Shows how to generate an image with Stable Image Ultra 1.0.
"""

class ImageError(Exception):
    """
    Custom exception for errors returned by the image generation models.
    """

    def __init__(self, message):
        super().__init__(message)
        self.message = message


# Set up logging for notebook environment
logger = logging.getLogger(__name__)
if logger.hasHandlers():
    logger.handlers.clear()
handler = logging.StreamHandler()
logger.addHandler(handler)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.setLevel(logging.INFO)


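def generate_image_from_text(model_id, body):
    """
    Generate an image using Stable Image Ultra 1.0 on demand.
    Args:
        model_id (str): The model ID to use.
        body (str): The JSON request body to use.
    Returns:
        image_bytes (bytes): The image generated by the model.
    Raises:
        ImageError: If the model reports a finish reason (for example, a content filter).
    """

    logger.info("Generating image with Stable Image Ultra 1.0 model %s", model_id)

    bedrock = boto3.client("bedrock-runtime", region_name="us-west-2")

    response = bedrock.invoke_model(modelId=model_id, body=body)
    response_body = json.loads(response["body"].read())

    # finish_reasons holds None for a successful image, or a reason string
    # when generation was blocked or failed.
    finish_reason = (response_body.get("finish_reasons") or [None])[0]
    if finish_reason is not None:
        raise ImageError(f"Image generation error: {finish_reason}")

    image_data = base64.b64decode(response_body["images"][0])

    logger.info("Successfully generated image with the Stable Image Ultra 1.0 model %s", model_id)
    return image_data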

def text_to_image_request(
    model_id,
    positive_prompt,
    # negative_prompt,
    save_image_path=None,
    seed=1664300763
):
    """
    Generate an image from a text prompt and save it to disk.
    Args:
        model_id (str): The model ID to use.
        positive_prompt (str): The positive prompt to use.
        save_image_path (str): Where to save the image. Defaults to a
            timestamped file in GENERATED_IMAGES.
        seed (int): Seed for reproducible generations.
    """

    # Build request body. Request JPEG output so the bytes match the .jpg
    # file extension used below (the model's default output format is PNG).
    body = json.dumps(
        {
            "prompt": positive_prompt,
            "mode": "text-to-image",
            "seed": seed,
            "output_format": "jpeg",
        }
    )

    # Generate and save image
    try:
        image = generate_image_from_text(model_id=model_id, body=body)

        if save_image_path:
            generated_image_path = save_image_path
        else:
            epoch_time = int(time.time())
            generated_image_path = f"{GENERATED_IMAGES}/image_{epoch_time}.jpg"

        logger.info("Generated image: %s", generated_image_path)
        with open(generated_image_path, "wb") as file:
            file.write(image)

        print(f"The generated image has been saved to {generated_image_path}.")

    except ClientError as err:
        message = err.response["Error"]["Message"]
        logger.error("A client error occurred: %s", message)
    except ImageError as err:
        logger.error(err.message)

    else:
        logger.info("Finished generating image with Stable Image Ultra 1.0 model %s.", model_id)
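The request and response handling above can be pulled into two pure functions so they can be checked offline, without calling Bedrock. This is a minimal sketch. The allowed aspect ratios, the `jpeg`/`png` output formats, the seed range, and the `finish_reasons` field reflect our reading of the Amazon Bedrock Stable Image Ultra parameter reference and should be treated as assumptions. `build_ultra_request_body` and `decode_ultra_response` are hypothetical names, and the sketch raises `RuntimeError` instead of the notebook's `ImageError` so that it stays self-contained.

```python
import base64
import json

# Assumed values, based on the Bedrock Stable Image Ultra reference; verify them.
ALLOWED_ASPECT_RATIOS = {"16:9", "1:1", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"}
ALLOWED_OUTPUT_FORMATS = {"jpeg", "png"}
MAX_SEED = 4294967294


def build_ultra_request_body(prompt, negative_prompt=None, aspect_ratio="1:1",
                             seed=0, output_format="jpeg"):
    """Validate the arguments and return a JSON text-to-image request body."""
    if not prompt:
        raise ValueError("prompt must be a non-empty string")
    if aspect_ratio not in ALLOWED_ASPECT_RATIOS:
        raise ValueError(f"unsupported aspect_ratio: {aspect_ratio}")
    if output_format not in ALLOWED_OUTPUT_FORMATS:
        raise ValueError(f"unsupported output_format: {output_format}")
    if not 0 <= seed <= MAX_SEED:
        raise ValueError(f"seed must be between 0 and {MAX_SEED}")
    body = {"prompt": prompt, "mode": "text-to-image", "aspect_ratio": aspect_ratio,
            "seed": seed, "output_format": output_format}
    if negative_prompt:  # only send a negative prompt when one is given
        body["negative_prompt"] = negative_prompt
    return json.dumps(body)


def decode_ultra_response(raw_body):
    """Return the first image's bytes from a response body (bytes or str).

    Assumes the shape {"images": [<base64>], "finish_reasons": [None], ...},
    where a non-None finish reason means the image was blocked or failed.
    """
    response_body = json.loads(raw_body)
    finish_reason = (response_body.get("finish_reasons") or [None])[0]
    if finish_reason is not None:
        raise RuntimeError(f"Image generation error: {finish_reason}")
    images = response_body.get("images") or []
    if not images:
        raise RuntimeError("response contained no images")
    return base64.b64decode(images[0])


# Offline round trip with a fake response; no AWS call is made.
request_body = build_ultra_request_body(
    "running shoes on a city track at sunrise, text 'YGS' on the shoes as a logo",
    negative_prompt="blurry, low quality",
    aspect_ratio="2:3",
    seed=1664300763,
)
fake_response = json.dumps({
    "images": [base64.b64encode(b"fake-jpeg-bytes").decode("ascii")],
    "finish_reasons": [None],
    "seeds": [1664300763],
})
print(json.loads(request_body)["aspect_ratio"])  # 2:3
print(decode_ultra_response(fake_response))  # b'fake-jpeg-bytes'
```

A portrait ratio such as `2:3` suits poster layouts. The returned string has the same form as the `body` argument that `generate_image_from_text` passes to `invoke_model`.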

In [None]:
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
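"""
Shows how to generate an image from a reference image (on demand). This cell
uses the Stability AI Diffusion 1.0 (SDXL) image-to-image request schema
(text_prompts, init_image, style_preset, ...), so it needs an SDXL model ID such
as stability.stable-diffusion-xl-v1 rather than the Stable Image Ultra MODEL_ID.
"""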

class ImageToImageRequest:
    """
    Class for handling image to image request parameters.
    """

    def __init__(
        self,
        image_width,
        image_height,
        positive_prompt,
        # negative_prompt,
        init_image_mode="IMAGE_STRENGTH",
        image_strength=0.5,
        cfg_scale=7,
        clip_guidance_preset="SLOWER",
        sampler="K_DPMPP_2M",
        samples=1,
        seed=1,
        steps=30,
        style_preset="photographic",
        extras=None,
    ):
        self.image_width = image_width
        self.image_height = image_height
        self.positive_prompt = positive_prompt
        # self.negative_prompt = negative_prompt
        self.init_image_mode = init_image_mode
        self.image_strength = image_strength
        self.cfg_scale = cfg_scale
        self.clip_guidance_preset = clip_guidance_preset
        self.sampler = sampler
        self.samples = samples
        self.seed = seed
        self.steps = steps
        self.style_preset = style_preset
        self.extras = extras


@unique
class StylesPresets(Enum):
    """
    Enumerator for SDXL 1.0 style presets.
    """

    THREE_D_MODEL = "3d-model"
    ANALOG_FILM = "analog-film"
    ANIME = "anime"
    CINEMATIC = "cinematic"
    COMIC_BOOK = "comic-book"
    DIGITAL_ART = "digital-art"
    ENHANCE = "enhance"
    FANTASY_ART = "fantasy-art"
    ISOMETRIC = "isometric"
    LINE_ART = "line-art"
    LOW_POLY = "low-poly"
    MODELING_COMPOUND = "modeling-compound"
    NEON_PUNK = "neon-punk"
    ORIGAMI = "origami"
    PHOTOGRAPHIC = "photographic"
    PIXEL_ART = "pixel-art"
    TILE_TEXTURE = "tile-texture"


def generate_image_from_image(model_id, body):
    """
    Generate an image from a reference image using SDXL 1.0 on demand.
    Args:
        model_id (str): The model ID to use.
        body (str): The request body to use.
    Returns:
        image_bytes (bytes): The image generated by the model.
    """

    logger.info("Generating image with SDXL 1.0 model %s", model_id)

    bedrock = boto3.client(service_name="bedrock-runtime")

    accept = "application/json"
    content_type = "application/json"

    response = bedrock.invoke_model(
        body=body, modelId=model_id, accept=accept, contentType=content_type
    )
    response_body = json.loads(response.get("body").read())
    logger.info("Bedrock result: %s", response_body.get("result"))

    artifact = response_body.get("artifacts")[0]

    # Check the finish reason before decoding the image.
    finish_reason = artifact.get("finishReason")
    if finish_reason in ("ERROR", "CONTENT_FILTERED"):
        raise ImageError(f"Image generation error. Error code is {finish_reason}")

    image_bytes = base64.b64decode(artifact.get("base64"))

    logger.info("Successfully generated image with the SDXL 1.0 model %s", model_id)

    return image_bytes


def image_to_image_request(
    imageToImageRequest,
    source_image,
    save_image_path=None,
    save_image_folder=None,
    model_id="stability.stable-diffusion-xl-v1",
):
    """
    Args:
        imageToImageRequest (ImageToImageRequest): The image to image request to use.
        source_image (str): Path to the source image.
        save_image_path (str): Exact path for the generated image (optional).
        save_image_folder (str): Folder for an auto-named generated image (optional).
        model_id (str): An SDXL 1.0 model ID. This request schema is not the
            Stable Image Ultra schema.
    """

    # Resize the source image and base64-encode it as JPEG in memory.
    # convert("RGB") drops any alpha channel, which JPEG cannot store.
    image = Image.open(source_image).convert("RGB")
    new_image = image.resize(
        (imageToImageRequest.image_width, imageToImageRequest.image_height)
    )
    buffer = io.BytesIO()
    new_image.save(buffer, format="JPEG")
    init_image = base64.b64encode(buffer.getvalue()).decode("utf8")

    # Build request body
    body = json.dumps(
        {
            "text_prompts": [
                {"text": imageToImageRequest.positive_prompt, "weight": 1}
                # {"text": imageToImageRequest.negative_prompt, "weight": -1},
            ],
            "init_image": init_image,
            "init_image_mode": imageToImageRequest.init_image_mode,
            "image_strength": imageToImageRequest.image_strength,
            "cfg_scale": imageToImageRequest.cfg_scale,
            "clip_guidance_preset": imageToImageRequest.clip_guidance_preset,
            "sampler": imageToImageRequest.sampler,
            "samples": imageToImageRequest.samples,
            "seed": imageToImageRequest.seed,
            "steps": imageToImageRequest.steps,
            "style_preset": imageToImageRequest.style_preset,
        }
    )

    try:
        logger.info("Source image: %s", source_image)
        image_bytes = generate_image_from_image(model_id=model_id, body=body)
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        epoch_time = int(time.time())
        req = imageToImageRequest
        file_name = (
            f"image_{epoch_time}_{req.seed}_{req.sampler}_{req.image_strength}"
            f"_{req.cfg_scale}_{req.steps}_{req.style_preset}.jpg"
        )

        if save_image_path is not None:
            generated_image_path = save_image_path
        elif save_image_folder is not None:
            generated_image_path = f"{save_image_folder}/{file_name}"
        else:
            generated_image_path = f"{GENERATED_IMAGES}/{file_name}"

        logger.info("Generated image: %s", generated_image_path)
        image.save(generated_image_path, format="JPEG", quality=95)

    except ClientError as err:
        message = err.response["Error"]["Message"]
        logger.error("A client error occurred: %s", message)
    except ImageError as err:
        logger.error(err.message)

    else:
        logger.info("Finished generating image with SDXL 1.0 model %s.", model_id)
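The image-to-image request resizes the source image to whatever `image_width` and `image_height` the `ImageToImageRequest` carries. SDXL 1.0 accepts only a fixed set of sizes, so a small helper can pick the allowed size whose aspect ratio is closest to the source. This is a sketch. The size list is our reading of the SDXL 1.0 dimensions in the Stability AI v1 API reference and should be verified, and `closest_sdxl_dimensions` is a hypothetical name.

```python
# Assumed SDXL 1.0 sizes (width, height); verify against the Stability AI v1 API reference.
SDXL_DIMENSIONS = [
    (1024, 1024), (1152, 896), (1216, 832), (1344, 768), (1536, 640),
    (640, 1536), (768, 1344), (832, 1216), (896, 1152),
]


def closest_sdxl_dimensions(width, height):
    """Return the allowed SDXL size whose aspect ratio is closest to width/height."""
    if width <= 0 or height <= 0:
        raise ValueError("width and height must be positive")
    target = width / height
    # Compare aspect ratios only; the image is resized to the chosen size afterwards.
    return min(SDXL_DIMENSIONS, key=lambda wh: abs(wh[0] / wh[1] - target))


print(closest_sdxl_dimensions(1920, 1080))  # (1344, 768)
print(closest_sdxl_dimensions(1080, 1350))  # (896, 1152)
```

The returned pair can be passed as `image_width` and `image_height` when constructing an `ImageToImageRequest`. Matching the aspect ratio keeps the resize from visibly stretching the product in the reference image.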

In [None]:
from PIL import Image
from IPython.display import display

def display_image(source_image_name, width=None, height=None):
    source_image = Image.open(source_image_name)
    if width and height:
        display(source_image.resize((width, height)))
    else:
        display(source_image)
    print(source_image_name)
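Once the final step has written several posters, they can be reviewed together with `display_image`. A plain alphabetical sort would list `ad_poster_10.jpg` before `ad_poster_2.jpg`, so the hypothetical helpers below order files by the first number in their names. They assume the `ad_poster_<index>.jpg` naming used by the poster loop.

```python
import re
from pathlib import Path


def sort_image_paths(paths):
    """Sort paths by the first number in each file name (ad_poster_2 before ad_poster_10)."""
    def sort_key(path):
        match = re.search(r"\d+", path.stem)
        # Names without a number sort last, then alphabetically.
        return (match is None, int(match.group()) if match else 0, path.name)

    return sorted(paths, key=sort_key)


def list_generated_images(folder, pattern="ad_poster_*.jpg"):
    """Return the image files in folder that match pattern, in numeric order."""
    return sort_image_paths(Path(folder).glob(pattern))


# In the notebook, after the posters are generated:
# for path in list_generated_images(GENERATED_IMAGES):
#     display_image(str(path), width=512, height=512)
```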

### Define the Claude function

In [None]:
def invoke_claude(
    client,
    prompt,
    max_tokens_to_sample=2000,
    modelId="anthropic.claude-3-sonnet-20240229-v1:0",
    temperature=1,
    top_k=250,
    top_p=0.999,
    stop_sequences=None,
    retry=3,
):
    """
    Send a single-turn prompt to Claude 3 through the Bedrock Messages API and
    return the text of the first content block. Throttled requests get up to
    `retry` attempts, 60 seconds apart; any other error is raised.
    """
    body_dict = {
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": prompt}]}
        ],
        "max_tokens": max_tokens_to_sample,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "stop_sequences": (stop_sequences or []) + ["\n\nHuman:"],
        "anthropic_version": "bedrock-2023-05-31",
    }
    request = {
        "modelId": modelId,
        "contentType": "application/json",
        "accept": "application/json",
        "body": json.dumps(body_dict),
    }

    attempts = max(1, retry)
    for trial in range(attempts):
        try:
            response = client.invoke_model(**request)
            break
        except ClientError as err:
            # Retry only throttling errors, and re-raise after the last attempt.
            throttled = err.response["Error"]["Code"] == "ThrottlingException"
            if not throttled or trial == attempts - 1:
                raise
            print("Bedrock request is throttled. Retrying in 60 seconds.")
            time.sleep(60)

    response_body = json.loads(response["body"].read())
    return response_body["content"][0]["text"]
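`invoke_claude` waits a fixed 60 seconds between retries. A common alternative is exponential backoff with jitter: each wait limit doubles, and the actual wait is randomized so that parallel clients do not retry in lockstep. The sketch below is generic. `call_with_backoff` is a hypothetical helper, and the caller-supplied `is_retryable` predicate decides which errors to retry (for Bedrock, typically a `ClientError` whose code is `ThrottlingException`).

```python
import random
import time


def call_with_backoff(fn, is_retryable, max_attempts=5, base_delay=1.0,
                      max_delay=60.0, sleep=time.sleep):
    """Call fn() and retry retryable failures with exponential backoff and full jitter.

    Errors rejected by is_retryable, or raised on the last attempt, propagate.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(max_attempts):
        try:
            return fn()
        except Exception as exc:
            if not is_retryable(exc) or attempt == max_attempts - 1:
                raise
            # The cap grows as base_delay * 2**attempt (up to max_delay); the actual
            # wait is a random value between 0 and that cap ("full jitter").
            delay_cap = min(max_delay, base_delay * 2 ** attempt)
            sleep(random.uniform(0, delay_cap))


# Offline demo: a function that fails twice, then succeeds.
attempt_counter = {"count": 0}


def flaky_call():
    attempt_counter["count"] += 1
    if attempt_counter["count"] < 3:
        raise TimeoutError("simulated throttling")
    return "ok"


recorded_delays = []
outcome = call_with_backoff(
    flaky_call,
    is_retryable=lambda exc: isinstance(exc, TimeoutError),
    sleep=recorded_delays.append,  # record waits instead of sleeping
)
print(outcome, attempt_counter["count"], len(recorded_delays))  # ok 3 2
```

With boto3, `fn` could be `lambda: client.invoke_model(**request)`, and `is_retryable` could check `err.response["Error"]["Code"]`. The injectable `sleep` parameter keeps the helper testable without real delays.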

## Let's start the Ad demo

### Generate the advertising concepts

In [None]:
prompt = """You are a seasoned veteran in the advertising industry with a wealth of experience in creating captivating and impactful campaigns. Your task is to generate five different creative advertising concepts for our new line of shoes under the brand "YGS".
Our product range includes running shoes, soccer shoes, and training shoes. Our target audience is the young generation, a demographic known for their energy, trendiness, and desire to express their individuality.
Each advertising concept should seamlessly incorporate the following elements:

1. The specific type of shoe (running, soccer, or training) and its intended usage.
2. A vivid description of the colors and unique features that make our shoes stand out.
3. A compelling scenario that vividly illustrates when and where these shoes would be worn, capturing the essence of the active lifestyle our target audience embraces.

Your concepts should be fresh, engaging, and resonate with the youthful spirit of our target market. Creativity, originality, and a deep understanding of our audience's aspirations and passions should shine through in your advertising ideas.
Remember, the goal is to craft compelling narratives that not only showcase our product's features but also tap into the emotions and desires of the young generation, inspiring them to embrace our brand as an extension of their vibrant lifestyles.

Return only a JSON array in the following format, with no text before or after it:

[
    {
        "concept": "xxx",
        "Description": "xxx",
        "Scenario": "xxx"
    },
    {
        "concept": "xxx",
        "Description": "xxx",
        "Scenario": "xxx"
    },
    ...
]

"""
client = boto3.client(service_name="bedrock-runtime")
result = invoke_claude(client, prompt)
print(result)
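`parse_json_string_to_prompt` passes `result` straight to `json.loads`. That call raises `JSONDecodeError` if Claude adds an introduction or wraps the array in Markdown code fences. A more forgiving parser can scan the reply for the first JSON array of objects. This sketch uses only the standard library's `json.JSONDecoder.raw_decode`, and `extract_json_array` is a hypothetical name.

```python
import json


def extract_json_array(text):
    """Return the first JSON array of objects found in text, ignoring surrounding prose.

    Tries raw_decode at each "[" in turn, so Markdown fences or an
    introduction before the array do not prevent parsing.
    """
    decoder = json.JSONDecoder()
    start = text.find("[")
    while start != -1:
        try:
            value, _ = decoder.raw_decode(text, start)
            # Skip arrays such as "[1]" that are not lists of objects.
            if isinstance(value, list) and all(isinstance(v, dict) for v in value):
                return value
        except json.JSONDecodeError:
            pass
        start = text.find("[", start + 1)
    raise ValueError("no JSON array of objects found in text")


fence = "`" * 3  # three backticks, as Claude might use for a Markdown code fence
sample_reply = (
    "Here are five concepts:\n"
    + fence + "json\n"
    + '[{"concept": "Night Run", "Description": "Reflective neon uppers", '
    + '"Scenario": "City streets after dark"}]\n'
    + fence
)
concepts = extract_json_array(sample_reply)
print(concepts[0]["concept"])  # Night Run
```

In the notebook, `json.dumps(extract_json_array(result))` could be passed to `parse_json_string_to_prompt` in place of the raw `result`.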

### Parse the advertising concepts and generate prompts for Stable Image Ultra

In [None]:
import json

def parse_json_string_to_prompt(json_string):
    prompts = []
    internal_prompt = """You are an expert at using Stable Diffusion models to generate shoe ad posters. Use the following content to write a positive prompt for the model:
    - "Concept": {Concept}
    - "Description": {Description}
    - "Scenario": {Scenario}

    Return only JSON in the following format, with no other text:
    [
        {
            "positive_prompt": "xxx"
        }
    ]
    Add the following phrase to the positive prompt:
    text 'YGS' on the shoes as a logo.

    """
    try:
        data = json.loads(json_string)
        for item in data:
            final_prompt = internal_prompt.replace("{Concept}", item['concept']).replace("{Description}", item['Description']).replace("{Scenario}", item['Scenario'])
            generated_prompt = invoke_claude(client, final_prompt)
            print(f"generated_prompt: {generated_prompt}")
            prompts.append(generated_prompt)
    except json.JSONDecodeError:
        print("Invalid JSON string")
    
    return prompts

json_string = result

prompts = parse_json_string_to_prompt(json_string)
print(prompts)
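`parse_json_string_to_prompt` reads `item['concept']`, `item['Description']`, and `item['Scenario']` with exact key names. A reply that capitalizes a key differently therefore raises a `KeyError`, which the `except json.JSONDecodeError` clause does not catch. The hypothetical helpers below look keys up case-insensitively and fill the same `{Concept}`, `{Description}`, and `{Scenario}` placeholders. This is a sketch, not a drop-in replacement.

```python
def get_field(item, name, default=""):
    """Return item[name], matching the key case-insensitively ("concept" or "Concept")."""
    wanted = name.lower()
    for key, value in item.items():
        if isinstance(key, str) and key.lower() == wanted:
            return value
    return default


def fill_prompt_template(template, item):
    """Replace {Concept}, {Description}, and {Scenario} placeholders with item's values.

    str.replace is used rather than str.format because the template also holds
    literal JSON braces, which str.format would try to interpret as fields.
    """
    for field in ("Concept", "Description", "Scenario"):
        template = template.replace("{" + field + "}", str(get_field(item, field)))
    return template


sample_template = '- "Concept": {Concept}\n- "Scenario": {Scenario}\n[{"positive_prompt": "xxx"}]'
sample_concept = {"Concept": "Night Run", "scenario": "City streets after dark"}
print(fill_prompt_template(sample_template, sample_concept))
```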


### Generate the Ad poster for the advertising concepts

In [None]:
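for index, prompt_json in enumerate(prompts):
    try:
        prompt = json.loads(prompt_json)
    except json.JSONDecodeError:
        print(f"Skipping prompt {index}: Claude did not return valid JSON.")
        continue

    # Claude may return a single object or a list of objects.
    items = prompt if isinstance(prompt, list) else [prompt]
    positive_prompt = None
    for item in items:
        if isinstance(item, dict) and "positive_prompt" in item:
            positive_prompt = item["positive_prompt"]
            print(positive_prompt)

    # Skip rather than reuse the previous iteration's prompt.
    if positive_prompt is None:
        print(f"Skipping prompt {index}: no positive_prompt found.")
        continue

    save_image_path = f"{GENERATED_IMAGES}/ad_poster_{index}.jpg"
    print(f"Attempting to save image to: {save_image_path}")

    text_to_image_request(
        MODEL_ID,
        positive_prompt,
        save_image_path=save_image_path,
    )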
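The poster loop decodes each Claude reply with `json.loads`, which fails if the reply includes any text around the JSON. The hypothetical helper below takes a more forgiving approach. It returns the first `positive_prompt` string it finds in either a list of objects or a single object, or `None` if there is none. Like the rest of this notebook's parsing, it assumes the key is spelled exactly `positive_prompt`.

```python
import json


def extract_positive_prompt(reply):
    """Return the first "positive_prompt" string in a Claude reply, or None.

    Tries json.JSONDecoder.raw_decode at each "[" or "{" so that text around
    the JSON is ignored; accepts a list of objects or a single object.
    """
    decoder = json.JSONDecoder()
    for position, char in enumerate(reply):
        if char not in "[{":
            continue
        try:
            value, _ = decoder.raw_decode(reply, position)
        except json.JSONDecodeError:
            continue
        candidates = value if isinstance(value, list) else [value]
        for candidate in candidates:
            if isinstance(candidate, dict) and isinstance(candidate.get("positive_prompt"), str):
                return candidate["positive_prompt"]
    return None


sample = 'Sure, here it is:\n[{"positive_prompt": "neon running shoes on a night track"}]'
print(extract_positive_prompt(sample))  # neon running shoes on a night track
print(extract_positive_prompt("no JSON here"))  # None
```

Inside the loop, `extract_positive_prompt(prompt_json)` followed by a `None` check could take the place of the `json.loads` call and the inner loop over items.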