In [None]:
import botocore
import sagemaker, boto3, json
from sagemaker import get_execution_role
import os
from sagemaker.workflow.pipeline_context import PipelineSession

# Shared AWS context for the rest of the notebook:
#   aws_role   - IAM execution role resolved by the SageMaker SDK
#   aws_region - region of a default boto3 session (used for image lookup later)
#   session    - a SageMaker PipelineSession
aws_role = get_execution_role()
_default_boto_session = boto3.Session()
aws_region = _default_boto_session.region_name
session = PipelineSession()

In [None]:
%time
import json
import uuid
from time import strftime, gmtime
from sagemaker import image_uris, script_uris, model_uris, Model


train_model_id, train_model_version, train_scope = (
    "model-txt2img-stabilityai-stable-diffusion-v2-1-base",
    "*",
    "training",
)

inference_instance_type = "ml.g4dn.2xlarge"

# Retrieve the inference docker container uri
inference_image_uri = image_uris.retrieve(
    region=aws_region,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=train_model_id,
    model_version=train_model_version,
    instance_type=inference_instance_type,
)

# Retrieve the inference script uri. This includes scripts for model loading, inference handling etc.
deploy_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope="inference"
)

# model = Model(
#   image_uri=inference_image_uri,
#   model_data="s3://sagemaker-eu-central-1-562760952310/jumpstart-example-sd-training/model/pipelines-klavescck34y-sd-transfer-learning-QAEuYiT8xt/output/model.tar.gz",
#   sagemaker_session=session,
#   role=aws_role
# )


from sagemaker.pytorch import PyTorchModel

model = PyTorchModel(model_data="s3://sagemaker-eu-central-1-562760952310/jumpstart-example-sd-training/model/pipelines-klavescck34y-sd-transfer-learning-QAEuYiT8xt/output/model.tar.gz", 
                     image_uri=inference_image_uri,
                     entry_point="inference.py",  # entry point file in source_dir and present in deploy_source_uri
                     source_dir=deploy_source_uri,
                     role=aws_role)


predictor = model.deploy(initial_instance_count=1, instance_type=inference_instance_type)

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def query(model_predictor, text):
    """Send a text prompt to a deployed endpoint and return its raw response.

    Args:
        model_predictor: object exposing ``predict(data, initial_args)``, such as the
            predictor returned by ``model.deploy``.
        text: the prompt. It is JSON-encoded, so a str is sent as a quoted JSON
            string, and then UTF-8 encoded.

    Returns:
        Whatever ``model_predictor.predict`` returns, unchanged.
    """
    request_headers = {
        "ContentType": "application/x-text",
        "Accept": "application/json",
    }
    payload = json.dumps(text).encode("utf-8")
    return model_predictor.predict(payload, request_headers)


def parse_response(query_response):
    """Decode a JSON endpoint response into ``(generated_image, prompt)``.

    Args:
        query_response: a JSON document as str, bytes or bytearray.

    Returns:
        A 2-tuple of the ``"generated_image"`` and ``"prompt"`` values.

    Raises:
        KeyError: if either key is missing from the decoded object.
    """
    decoded = json.loads(query_response)
    generated_image = decoded["generated_image"]
    prompt_text = decoded["prompt"]
    return generated_image, prompt_text


def display_img_and_prompt(img, prmpt):
    """Render one generated image in a 12x12-inch figure, titled with its prompt.

    Args:
        img: image data passed through ``np.array`` before plotting, presumably the
            nested list decoded from the endpoint's JSON response (TODO confirm).
        prmpt: text shown as the figure title.
    """
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(np.array(img))
    ax.axis("off")
    ax.set_title(prmpt)
    plt.show()

In [None]:
# Query the endpoint seven times with the same prompt (presumably to compare several
# samples for it) and display each returned image titled by the prompt that produced it.
all_prompts = ["A photo of a cedar apple rust"] * 7
for prompt in all_prompts:
    query_response = query(predictor, prompt)
    # The echoed prompt in the response is not needed; the local one is used as title.
    img, _ = parse_response(query_response)
    display_img_and_prompt(img, prompt)