## Stable Diffusion (Text to Image) - Fine-tune with cat images (limited data)

#### Imports 

In [None]:
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models
from sagemaker.utils import name_from_base
from sagemaker.estimator import Estimator
from sagemaker import get_execution_role
from sagemaker import hyperparameters
from sagemaker import script_uris
from sagemaker import image_uris
from sagemaker import model_uris
import matplotlib.pyplot as plt
import numpy as np
import sagemaker
import datetime
import logging
import boto3
import json

##### Setup logging 

In [None]:
# Route the SageMaker SDK's log records (DEBUG and up) to the notebook output.
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
# Re-running this cell used to attach one more StreamHandler each time, which
# printed every log line once per run. Tag our handler and add it only once.
if not any(h.get_name() == 'notebook-stream' for h in logger.handlers):
    handler = logging.StreamHandler()
    handler.set_name('notebook-stream')
    logger.addHandler(handler)

##### Log versions of dependencies 

In [None]:
# Record which SageMaker SDK version this notebook ran against.
logger.info('[Using SageMaker version: %s]', sagemaker.__version__)

#### I. Setup essentials 

In [None]:
# IAM role for the training/hosting jobs, the current AWS region, and a
# SageMaker session reused by later cells (e.g. for the default bucket).
ROLE = get_execution_role()
REGION = boto3.Session().region_name
session = sagemaker.Session()
logger.info('Region: %s', REGION)

##### List all models from SageMaker JumpStart hub

In [None]:
# Every model ID in the JumpStart hub; the cell displays how many there are.
models = list_jumpstart_models()
len(models)

In [None]:
# Restrict the listing to text-to-image models and display the result.
FILTER = "task == txt2img"
txt2img_models = list_jumpstart_models(filter=FILTER)
txt2img_models

In [None]:
# Which JumpStart model to fine-tune, and on what hardware.
MODEL_ID = "model-txt2img-stabilityai-stable-diffusion-v2-1-base"
MODEL_VERSION = "*"  # wildcard: latest available version
SCOPE = "training"  # artifact scope passed to the *_uris.retrieve calls
TRAIN_INSTANCE_TYPE = "ml.g4dn.2xlarge"

#### II. Retrieve training artifacts 

In [None]:
# Container image used by the training job for this model/version/instance.
train_image_uri = image_uris.retrieve(
    framework=None,
    region=REGION,
    model_id=MODEL_ID,
    model_version=MODEL_VERSION,
    image_scope=SCOPE,
    instance_type=TRAIN_INSTANCE_TYPE,
)
logger.info('Training image URI: %s', train_image_uri)

In [None]:
# Location of the training scripts; the Estimator uses it as source_dir.
train_source_uri = script_uris.retrieve(
    model_id=MODEL_ID,
    model_version=MODEL_VERSION,
    script_scope=SCOPE,
)
logger.info('Training source URI: %s', train_source_uri)

In [None]:
# Location of the pre-trained model artifacts passed to the Estimator as model_uri.
train_model_uri = model_uris.retrieve(
    model_id=MODEL_ID,
    model_version=MODEL_VERSION,
    model_scope=SCOPE,
)
logger.info('Training model URI: %s', train_model_uri)

##### Setup data locations 

In [None]:
# Session default bucket; both the training input and output live here.
DEFAULT_BUCKET = session.default_bucket()
logger.info('Default bucket = %s', DEFAULT_BUCKET)

In [None]:
# Key prefix inside DEFAULT_BUCKET that receives the uploaded cat images.
TRAIN_DATA_INPUT_PREFIX = "js-input/cats/"

In [None]:
# Full S3 URI of the training images (the prefix already ends with '/').
TRAIN_DATA_INPUT_S3_PATH = 's3://{}/{}'.format(DEFAULT_BUCKET, TRAIN_DATA_INPUT_PREFIX)
logger.info('Training data input S3 location => %s', TRAIN_DATA_INPUT_S3_PATH)

#### III. Copy the fine-tuning dataset from local storage to S3

In [None]:
!rm -rf ./data/.ipynb_checkpoints/ 

In [None]:
!aws s3 cp ./data {TRAIN_DATA_INPUT_S3_PATH} --recursive

In [None]:
# Where SageMaker writes the training job's output artifacts.
TRAIN_DATA_OUTPUT_PREFIX = "js-output"
TRAIN_DATA_OUTPUT_S3_PATH = 's3://{}/{}'.format(DEFAULT_BUCKET, TRAIN_DATA_OUTPUT_PREFIX)
logger.info('Training output S3 location => %s', TRAIN_DATA_OUTPUT_S3_PATH)

#### IV. Retrieve and update the default hyperparameters

In [None]:
# JumpStart's default hyperparameters for this model/version, displayed below.
hyperparams = hyperparameters.retrieve_default(
    model_id=MODEL_ID,
    model_version=MODEL_VERSION,
)
hyperparams

You can also override any of these hyperparameters:

In [None]:
# Override the number of training steps and the random seed (string values).
hyperparams.update({'max_steps': '400', 'seed': '123'})
hyperparams

#### V. Fine-tune the Stable Diffusion model

In [None]:
model_prefix = name_from_base(f'js-{MODEL_ID}-')
# SageMaker job names are capped at 63 characters. The Estimator treats
# ``base_job_name`` as a base: it appends its own timestamp and trims the base
# to fit. The previous base (long MODEL_ID + timestamp + "-finetuning") was cut
# down to its first 39 characters, so the suffix was lost and the name logged
# here never matched the real job. Keep the base short. After fit(), the
# actual name is ``estimator.latest_training_job.name``.
training_job_name = 'js-sd-v2-1-base-finetuning'
logger.info(f'Train job base name => {training_job_name}')

In [None]:
# Upper bound on training time passed to the Estimator as max_run, in seconds.
MAX_RUN = 100 * 60 * 60  # 100 hours

In [None]:
# Script-mode estimator combining the JumpStart training image, the training
# scripts (entry point transfer_learning.py), and the pre-trained model.
# Artifacts are written under TRAIN_DATA_OUTPUT_S3_PATH.
estimator = Estimator(
    image_uri=train_image_uri,
    source_dir=train_source_uri,
    entry_point='transfer_learning.py',
    model_uri=train_model_uri,
    role=ROLE,
    instance_type=TRAIN_INSTANCE_TYPE,
    instance_count=1,
    max_run=MAX_RUN,
    hyperparameters=hyperparams,
    output_path=TRAIN_DATA_OUTPUT_S3_PATH,
    base_job_name=training_job_name,
)

In [None]:
%%time

estimator.fit({'training': TRAIN_DATA_INPUT_S3_PATH}, logs=False)

## Download the fine-tuned model to the Studio app (GPU instance, e.g. ml.g4dn.2xlarge) and test it locally

In [None]:
!aws s3 cp {TRAIN_DATA_OUTPUT_S3_PATH}/{estimator.hyperparameters()["sagemaker_job_name"]}/output/model.tar.gz .
!mkdir model
!tar -zxvf model.tar.gz -C model
!rm model.tar.gz

In [None]:
!pip install ipywidgets==7.0.0 diffusers transformers --quiet

In [None]:
# Example prompts for the fine-tuned model ("riobugger" presumably refers to
# the cat in the training images - confirm against the dataset). Previously
# six successive reassignments meant only the last one ever took effect; pick
# the prompt to use by index instead.
EXAMPLE_PROMPTS = [
    "riobugger cat in superman suit",
    "riobugger cat portrait, a renaissance painting",
    "pencil sketch of riobugger cat",
    "riobugger cat in cartoon animation",
    "riobugger cat on the beach",
    "a photo of a riobugger cat",
]
text = EXAMPLE_PROMPTS[-1]

In [None]:
from diffusers import StableDiffusionPipeline
import torch

# Load the fine-tuned weights extracted into ./model/ and generate one image
# for ``text`` (set in the prompt cell above) in half precision on the GPU.
# The previous hard-coded ``text = ...`` here silently overrode that choice.
# ``revision`` and ``use_auth_token`` only apply when downloading from the
# Hugging Face Hub and are ignored for a local directory (``use_auth_token``
# is also deprecated), so they are omitted.
pipe = StableDiffusionPipeline.from_pretrained(
    "./model/",
    torch_dtype=torch.float16,
).to("cuda")
image = pipe(text).images[0]

image

#### VI. Deploy fine-tuned model as a SageMaker endpoint

##### Retrieve artifacts for inference 

In [None]:
# Switch artifact retrieval to the hosting side. Note: this rebinds SCOPE, so
# re-running the training-retrieval cells afterwards would use 'inference'.
SCOPE = "inference"
INFERENCE_INSTANCE_TYPE = "ml.g4dn.2xlarge"  # preferred

In [None]:
# Container image for serving this model on the inference instance type.
deploy_image_uri = image_uris.retrieve(
    framework=None,
    region=REGION,
    model_id=MODEL_ID,
    model_version=MODEL_VERSION,
    image_scope=SCOPE,
    instance_type=INFERENCE_INSTANCE_TYPE,
)
logger.info('Deploy image URI => %s', deploy_image_uri)

In [None]:
# Location of the inference scripts; used as source_dir when deploying.
deploy_source_uri = script_uris.retrieve(
    model_id=MODEL_ID,
    model_version=MODEL_VERSION,
    script_scope=SCOPE,
)
logger.info('Deploy source URI => %s', deploy_source_uri)

In [None]:
# Timestamped endpoint name, e.g. js-ep-20240101120000.
current_time = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
endpoint_name = 'js-ep-{}'.format(current_time)
endpoint_name

In [None]:
%%time

predictor = estimator.deploy(initial_instance_count=1, 
                             instance_type=INFERENCE_INSTANCE_TYPE, 
                             entry_point='inference.py', 
                             image_uri=deploy_image_uri, 
                             source_dir=deploy_source_uri, 
                             endpoint_name=endpoint_name)

#### VII. Invoke the endpoint for inference using Predictor 

In [None]:
# Text prompt sent to the endpoint through the Predictor.
prompt = "riobugger cat in a space suit"

In [None]:
# The request is sent as 'application/x-text', i.e. the raw prompt text.
# json.dumps() wrapped the prompt in literal double quotes, which were sent to
# the model as part of the prompt. Encode the plain string instead.
prompt = prompt.encode('utf-8')

In [None]:
# Request/response content types: plain-text prompt in, JSON out.
MIME_INFO = dict(
    ContentType='application/x-text',
    Accept='application/json',
)

In [None]:
%%time

response = predictor.predict(prompt, MIME_INFO)

In [None]:
# Decode the JSON reply; it carries the generated image and the echoed prompt.
response = json.loads(response)
img, prompt = response['generated_image'], response['prompt']

In [None]:
# Show the generated image full-size, titled with its prompt.
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(np.array(img))
ax.axis('off')
ax.set_title(prompt)
plt.show()

#### VIII. Invoke the endpoint for inference using SageMaker run-time client (Alternative)

In [None]:
# Prompt for the low-level runtime-client example.
prompt = "riobugger cat animated and dressed as a police officer"

In [None]:
# Low-level SageMaker runtime client for invoking the endpoint directly.
client = boto3.client(service_name='sagemaker-runtime')

In [None]:
%%time

response = client.invoke_endpoint(EndpointName=endpoint_name, 
                                  Body=prompt, 
                                  ContentType='application/x-text')

In [None]:
# Read the streaming body, decode it to text, then parse the JSON payload.
raw_body = response['Body'].read().decode()
response_body = json.loads(raw_body)
generated_image = response_body['generated_image']

In [None]:
# Display the image returned by the runtime-client call.
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(np.array(generated_image))
ax.axis('off')
ax.set_title(prompt)
plt.show()

In [None]:
# Delete the SageMaker model and the endpoint (with its config) to save costs
# and avoid leftover resources. delete_model() looks up the model names
# through the endpoint, so it must run before delete_endpoint().
predictor.delete_model()
predictor.delete_endpoint()