In [None]:
import botocore
import sagemaker, boto3, json
from sagemaker import get_execution_role
import os


# SageMaker execution context: the IAM role for jobs/endpoints, the region of
# the current boto3 session, and a SageMaker session (used later to look up the
# default S3 bucket for training outputs).
aws_role = get_execution_role()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()

# Local folder that holds the training images and dataset_info.json.
# If uploading to a different folder, change this variable.
local_training_dataset_folder = "training_images"
# exist_ok=True makes the cell safe to re-run, avoids the exists()/mkdir()
# race, and also creates any missing parent folders for a nested path.
os.makedirs(local_training_dataset_folder, exist_ok=True)

In [None]:
# Text prompt paired with the training images; it is saved into
# dataset_info.json so the training script can read it.
instance_prompt = (
    "A photo of cedar apple rust"
)

In [None]:
# Instance prompt is fed into the training script via dataset_info.json present in the training folder. Here, we write that file.
import os
import json

# Serialize the instance prompt straight into the file; json.dump emits the
# same text json.dumps would return.
with open(os.path.join(local_training_dataset_folder, "dataset_info.json"), "w") as f:
    json.dump({"instance_prompt": instance_prompt}, f)

In [None]:
# Account ID of the current credentials, looked up via STS.
account_id = (
    boto3.client("sts")
    .get_caller_identity()
    .get("Account")
)

# Training-data bucket name built from the region and the account ID.
training_bucket = "stable-diffusion-jumpstart-{}-{}".format(aws_region, account_id)

In [None]:
# Regional JumpStart cache bucket; it hosts the helper module downloaded below.
assets_bucket = f"jumpstart-cache-prod-{aws_region}"


s3 = boto3.client("s3")
# Reuse assets_bucket rather than rebuilding the same bucket name inline, so
# the two can never drift apart.
s3.download_file(
    assets_bucket,
    "ai_services_assets/custom_labels/cl_jumpstart_ic_notebook_utils.py",
    "utils.py",
)


# NOTE(review): importing utils.py executes code fetched from S3; only do this
# for a bucket you trust.
from utils import create_bucket_if_not_exists

create_bucket_if_not_exists(training_bucket)

In [None]:
train_s3_path = f"s3://{training_bucket}/custom_cedar_apple_rust_stable_diffusion_dataset/"

!aws s3 cp --recursive $local_training_dataset_folder $train_s3_path

In [None]:
from sagemaker import image_uris, model_uris, script_uris

# JumpStart model to fine-tune. The version is the "*" wildcard, and the
# artifacts requested below are the ones for the "training" scope.
train_model_id = "model-txt2img-stabilityai-stable-diffusion-v2-1-base"
train_model_version = "*"
train_scope = "training"

# Verified on ml.g4dn.2xlarge (16GB GPU memory) and ml.g5.2xlarge (24GB GPU
# memory); other instance types may also work. If ml.g5.2xlarge is available
# to you, switch to it for faster training.
training_instance_type = "ml.g4dn.2xlarge"

# Docker image for the training job (the framework is inferred from model_id).
train_image_uri = image_uris.retrieve(
    model_id=train_model_id,
    model_version=train_model_version,
    image_scope=train_scope,
    instance_type=training_instance_type,
    region=None,
    framework=None,
)

# Training script bundle (data processing, model training, etc.).
train_source_uri = script_uris.retrieve(
    model_id=train_model_id,
    model_version=train_model_version,
    script_scope=train_scope,
)

# Pre-trained model tarball that fine-tuning starts from.
train_model_uri = model_uris.retrieve(
    model_id=train_model_id,
    model_version=train_model_version,
    model_scope=train_scope,
)

In [None]:
# Training artifacts are written under the SageMaker session's default bucket.
output_bucket = (
    sess.default_bucket()
)

In [None]:
# Training outputs go to s3://<output_bucket>/<output_prefix>/output.
output_prefix = "jumpstart-example-sd-training"
s3_output_location = "s3://{}/{}/output".format(output_bucket, output_prefix)

In [None]:
# Import the module under an alias: the retrieved dict below is bound to the
# name `hyperparameters` (used by the Estimator cell), and without the alias it
# would shadow the module for the rest of the kernel session.
from sagemaker import hyperparameters as sm_hyperparameters

# Retrieve the default hyper-parameters for fine-tuning the model.
hyperparameters = sm_hyperparameters.retrieve_default(
    model_id=train_model_id, model_version=train_model_version
)

# [Optional] Override default hyperparameters with custom values. This controls the duration of the training and the quality of the output.
# If max_steps is too small, training will be fast but the model will not be able to generate custom images for your use case.
# If max_steps is too large, training will be very slow.
hyperparameters["max_steps"] = "200"
print(hyperparameters)

In [None]:
%time
from sagemaker.estimator import Estimator
from sagemaker.utils import name_from_base
from sagemaker.tuner import HyperparameterTuner

training_job_name = name_from_base(f"jumpstart-example-{train_model_id}-transfer-learning")

# Create SageMaker Estimator instance
sd_estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
    source_dir=train_source_uri,
    model_uri=train_model_uri,
    entry_point="transfer_learning.py",  # Entry-point file in source_dir and present in train_source_uri.
    instance_count=1,
    instance_type=training_instance_type,
    max_run=360000,
    hyperparameters=hyperparameters,
    output_path=s3_output_location,
    base_job_name=training_job_name,
)

# Launch a SageMaker Training job by passing s3 path of the training data
sd_estimator.fit({"training": train_s3_path}, logs=True)

In [None]:
%time

inference_instance_type = "ml.g4dn.2xlarge"

# Retrieve the inference docker container uri
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=train_model_id,
    model_version=train_model_version,
    instance_type=inference_instance_type,
)
# Retrieve the inference script uri. This includes scripts for model loading, inference handling etc.
deploy_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope="inference"
)

endpoint_name = name_from_base(f"jumpstart-example-FT-{train_model_id}-")

# Use the estimator from the previous step to deploy to a SageMaker endpoint
finetuned_predictor = sd_estimator.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    entry_point="inference.py",  # entry point file in source_dir and present in deploy_source_uri
    image_uri=deploy_image_uri,
    source_dir=deploy_source_uri,
    endpoint_name=endpoint_name,
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def query(model_predictor, text):
    """Send a text prompt to the deployed endpoint and return its raw response.

    The prompt is JSON-encoded (i.e. sent as a quoted JSON string), then UTF-8
    encoded, and passed to ``model_predictor.predict`` along with request
    arguments declaring the payload content type and the accepted response type.

    Args:
        model_predictor: Object exposing ``predict(data, initial_args)``.
        text: Prompt to send.

    Returns:
        Whatever ``model_predictor.predict`` returns.
    """
    payload = json.dumps(text).encode("utf-8")
    request_args = {
        "ContentType": "application/x-text",
        "Accept": "application/json",
    }
    return model_predictor.predict(payload, request_args)


def parse_response(query_response):
    """Decode the endpoint's JSON response.

    Args:
        query_response: JSON document (str or bytes) returned by the endpoint.

    Returns:
        A ``(generated_image, prompt)`` tuple read from the decoded object.

    Raises:
        KeyError: If either key is missing from the response.
    """
    payload = json.loads(query_response)
    generated_image = payload["generated_image"]
    prompt = payload["prompt"]
    return generated_image, prompt


def display_img_and_prompt(img, prmpt):
    """Show a generated image with its prompt as the title.

    ``img`` is converted with ``np.array`` before plotting. The figure is
    12x12 inches and the axes are hidden.
    """
    _, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(np.array(img))
    ax.axis("off")
    ax.set_title(prmpt)
    plt.show()

In [None]:
# Query the fine-tuned endpoint with the same prompt used for training
# (`instance_prompt`) instead of seven hand-copied literals that had drifted
# from it ("a cedar apple rust" vs "cedar apple rust"). The prompt is repeated
# to request several samples; presumably each call returns a different image
# because no seed is passed — confirm against the inference script.
num_samples = 7
all_prompts = [instance_prompt] * num_samples
for prompt in all_prompts:
    query_response = query(finetuned_predictor, prompt)
    img, _ = parse_response(query_response)
    display_img_and_prompt(img, prompt)