# Text to Image generation on SageMaker

In this notebook, you will learn how to fine-tune an existing Stable Diffusion model on SageMaker and deploy it for inference.

## 0. Setup

In [None]:
import multiprocessing as mp
import torch
import subprocess
import os

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.huggingface import HuggingFace
from sagemaker import get_execution_role

This notebook is purely educational: it shows how to fine-tune latent Stable Diffusion on Amazon SageMaker. Neither the images produced nor the code represent Amazon or its views in any way, shape, or form. To properly leverage this codebase, read the corresponding licenses from [CompVis](https://huggingface.co/spaces/CompVis/stable-diffusion-license) (the model) and [Conceptual Captions](https://huggingface.co/datasets/conceptual_captions) (the dataset, from Google, although you will use the Hugging Face copy).

This demo requires an ml.g5.12xlarge instance or a more powerful one.

Model weights are provided by CompVis/stable-diffusion-v1-4. You can find the license, README, and more [here](https://huggingface.co/CompVis/stable-diffusion-v1-4). To download the weights, you need a Hugging Face account: accept the terms at the link above, then generate a user access token. These steps are beyond the scope of this notebook. Note that the `finetune.py` script has been slightly modified from a pull request [here](https://github.com/huggingface/diffusers/pull/356).
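If you already have a token, the download itself can be scripted. The sketch below is an assumption, not part of the original notebook: it reads the token from a hypothetical `HF_TOKEN` environment variable and calls `huggingface_hub.snapshot_download` to place the weights in a local `sd-base-model` directory. That directory name mirrors the `pretrained_model_name_or_path` hyperparameter used later (`/opt/ml/input/data/training/sd-base-model`), so the folder would need to sit under the training channel prefix in S3.

```python
import os


def get_hf_token(env_var="HF_TOKEN"):
    """Read a Hugging Face user access token from an environment variable (hypothetical convention)."""
    token = os.environ.get(env_var)
    if not token:
        raise RuntimeError(f"Set {env_var} to your Hugging Face user access token")
    return token


def download_base_model(local_dir="sd-base-model"):
    """Download the CompVis/stable-diffusion-v1-4 repository into local_dir."""
    # Imported lazily so the helpers above can be defined without huggingface_hub installed
    from huggingface_hub import snapshot_download

    return snapshot_download(
        repo_id="CompVis/stable-diffusion-v1-4",
        local_dir=local_dir,
        token=get_hf_token(),
    )
```

After `download_base_model()` finishes, something like `!aws s3 sync sd-base-model {s3_train_channel}/sd-base-model` would make the weights visible to the training job under the path the hyperparameters expect.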

Install the libraries required to run Stable Diffusion locally.

In [None]:
!pip install -r ./src/requirements.txt -q

## 1. Download Model and Data
First, you will download the model and data. You can modify the following cell if you want to run the example with your own token.

Otherwise, you can download the data from the S3 location defined in the following cell (set `bucket` to your own bucket name).

In [None]:
bucket = "INSERT BUCKET NAME"
path = "conceptual_captions"
s3_train_channel = f"s3://{bucket}/{path}"
image = '0.jpg'
image_file = f"./dta/{image}"

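The following cell copies a sample image and the `dataset.parquet` file from `s3_train_channel` into a local `dta` directory. Alternatively, if you would like to use the original dataset from Hugging Face, download its parquet file and then download the images independently.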

In [None]:
!mkdir -p dta
!aws s3 cp {s3_train_channel}/{image} ./dta/
!aws s3 cp {s3_train_channel}/dataset.parquet ./dta/
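If you take the Hugging Face route, the images have to be fetched separately from the URLs in the parquet file. The sketch below assumes the Conceptual Captions column name `image_url`, and it assumes (hypothetically) that `sm_key`, the image column the training job reads, holds a local file name such as `0.jpg`, matching the sample image above. It uses only the standard library and skips URLs that fail, since some source links may no longer resolve.

```python
import urllib.request
from pathlib import Path


def add_image_keys(records):
    """Give each record a local file name (hypothetical sm_key scheme: '<index>.jpg')."""
    return [dict(rec, sm_key=f"{i}.jpg") for i, rec in enumerate(records)]


def download_images(records, out_dir="dta", url_key="image_url", timeout=10):
    """Download each record's image to out_dir/<sm_key>; return only the records that succeeded."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    kept = []
    for rec in records:
        target = out / rec["sm_key"]
        try:
            with urllib.request.urlopen(rec[url_key], timeout=timeout) as resp:
                target.write_bytes(resp.read())
            kept.append(rec)
        except (OSError, ValueError) as err:  # URLError, HTTPError and timeouts are OSErrors
            print(f"skipping {rec[url_key]}: {err}")
    return kept
```

With pandas, `records = add_image_keys(df.to_dict("records"))` followed by `pd.DataFrame(download_images(records))` would give a parquet-ready table containing only rows whose images were downloaded.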

In [None]:
import pandas as pd

df = pd.read_parquet('./dta/dataset.parquet')

In [None]:
df.head(n=3)
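The training job configured in Section 2 reads `caption_column='caption'` and `image_column='sm_key'`, so it can save a failed run to confirm those columns exist before uploading. A minimal sketch (the helper name is hypothetical) that works with `df.columns` or any iterable of names:

```python
REQUIRED_COLUMNS = ("caption", "sm_key")


def missing_columns(columns, required=REQUIRED_COLUMNS):
    """Return the required column names that are absent from `columns`, in order."""
    present = set(columns)
    return [name for name in required if name not in present]
```

For example, `assert not missing_columns(df.columns)` fails early if the parquet file lacks either column.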

In [None]:
from PIL import Image

caption = df['caption'][0]
print(caption)

Image.open(image_file)

In [None]:
!ls src

Additionally, the data you will be using comes from MS COCO. You can also download it from [here](https://huggingface.co/datasets/ChristophSchuhmann/MS_COCO_2017_URL_TEXT), which is based on the dataset from [here](https://academictorrents.com/details/74dec1dd21ae4994dfd9069f9cb0443eb960c962). Then use [img2dataset](https://github.com/rom1504/img2dataset) to quickly download the image files. For the purposes of this notebook, the few samples downloaded in the cells above are enough.

## 2. Training
You will use distributed training, which runs one training process per available GPU. The `get_processes_per_host` function below determines how many GPUs the selected instance type has.

In [None]:
#Add Profiler

In [None]:
from sagemaker.debugger import ProfilerConfig, FrameworkProfile

profiler_config = ProfilerConfig(
    system_monitor_interval_millis=500, framework_profile_params=FrameworkProfile(num_steps=10)
)

In [None]:
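local = False
output = None

def get_processes_per_host(instance_type):
    """Return the number of GPUs (one MPI training process per GPU) for instance_type."""
    global local  # without this, setting local below would not affect the notebook-level flag
    if instance_type == 'ml.g5.12xlarge':
        # ml.g5.12xlarge has 4 NVIDIA A10G GPUs
        processes_per_host = 4
    elif 'local' in instance_type:
        # Local mode: count the GPUs visible on this machine
        from torch import cuda
        processes_per_host = cuda.device_count()
        local = True
    else:
        # Fail loudly instead of returning an undefined value
        raise ValueError('Please look up the number of GPUs per node from the EC2 page here: '
                         'https://aws.amazon.com/ec2/instance-types/')

    return processes_per_host


instance_type = 'ml.g5.12xlarge'

processes_per_host = get_processes_per_host(instance_type)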
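To support other instance types without editing the function each time, a lookup table is one option. The sketch below lists GPU counts for a few SageMaker instance types (as published on the EC2 instance-type pages) and computes how many images one optimizer step sees, assuming plain data parallelism with one process per GPU and no gradient accumulation; both helper names are hypothetical.

```python
# GPUs per instance for a few SageMaker GPU instance types
GPUS_PER_INSTANCE = {
    "ml.g5.xlarge": 1,
    "ml.g5.2xlarge": 1,
    "ml.g5.4xlarge": 1,
    "ml.g5.8xlarge": 1,
    "ml.g5.16xlarge": 1,
    "ml.g5.12xlarge": 4,
    "ml.g5.24xlarge": 4,
    "ml.g5.48xlarge": 8,
    "ml.p3.2xlarge": 1,
    "ml.p3.8xlarge": 4,
    "ml.p3.16xlarge": 8,
    "ml.p4d.24xlarge": 8,
}


def gpus_per_host(instance_type):
    """Look up the GPU count for instance_type, failing loudly for unknown types."""
    try:
        return GPUS_PER_INSTANCE[instance_type]
    except KeyError:
        raise ValueError(
            f"Unknown instance type {instance_type!r}; see https://aws.amazon.com/ec2/instance-types/"
        ) from None


def global_batch_size(instance_count, processes_per_host, per_device_batch_size):
    """Images per optimizer step under data parallelism without gradient accumulation."""
    return instance_count * processes_per_host * per_device_batch_size
```

With the settings in this notebook (`instance_count=1`, 4 GPUs, `train_batch_size=2`), `global_batch_size(1, 4, 2)` is 8 images per step.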
    

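The following cells build an estimator that can train either locally (when `instance_type` contains `local`, in which case `./process.sh` runs first) or on a SageMaker training instance, and fit it on the dataset you prepared earlier.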

In [None]:
%%capture output
# If you want to train locally you will need to run the following 
if local :
    !./process.sh

In [None]:
import os
from sagemaker.huggingface import HuggingFace
from sagemaker import get_execution_role
from sagemaker.local import LocalSession
from sagemaker import Session
import boto3


est = HuggingFace(
    entry_point='finetune.py',
    source_dir='src',
    image_uri='763104351884.dkr.ecr.us-east-1.amazonaws.com' + 
     '/pytorch-training:1.12.1-gpu-py38-cu113-ubuntu20.04-sagemaker',
    sagemaker_session=Session() if 'local' not in instance_type else LocalSession(boto_session=LocalSession().boto_session),
    role=get_execution_role(),
    instance_type=instance_type,
    keep_alive_time_in_seconds = 28800,
    # output_path= can define s3 output here,
    py_version='py38',
    base_job_name='stable-diffusion',
    instance_count=1,
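    # Paths under /opt/ml are inside the SageMaker training container: the 'training' channel is
    # downloaded to /opt/ml/input/data/training, and /opt/ml/model is uploaded as the model artifact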
    hyperparameters={
        'pretrained_model_name_or_path':'/opt/ml/input/data/training/sd-base-model',
        'dataset_name':'/opt/ml/input/data/training/dataset.parquet',
        'caption_column':'caption',
        'image_column':'sm_key',
        'resolution':256,
        'mixed_precision':'fp16',
        'train_batch_size':2,
        'learning_rate': '1e-10',
        'max_train_steps':100,
        'num_train_epochs':1,
        'output_dir':'/opt/ml/model/sd-output-final',   
    },    
    distribution={"mpi":{"enabled":True,"processes_per_host":processes_per_host}},
    profiler_config=profiler_config
)
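In SageMaker script mode, the `hyperparameters` dictionary reaches the entry point as `--key value` command-line arguments. The sketch below approximates that conversion and shows a hypothetical `argparse` setup that a script like `finetune.py` could use to read them; the defaults are illustrative assumptions, not the values in the actual script.

```python
import argparse


def to_cli_args(hyperparameters):
    """Approximate how script mode turns a hyperparameter dict into --key value arguments."""
    argv = []
    for key, value in hyperparameters.items():
        argv += [f"--{key}", str(value)]
    return argv


def parse_training_args(argv=None):
    """Hypothetical argument parser for a Stable Diffusion fine-tuning script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_name_or_path", required=True)
    parser.add_argument("--dataset_name", required=True)
    parser.add_argument("--caption_column", default="caption")
    parser.add_argument("--image_column", default="image")
    parser.add_argument("--resolution", type=int, default=512)
    parser.add_argument("--mixed_precision", default="no", choices=["no", "fp16", "bf16"])
    parser.add_argument("--train_batch_size", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)  # '1e-10' is parsed to a float
    parser.add_argument("--max_train_steps", type=int, default=None)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--output_dir", default="/opt/ml/model")
    return parser.parse_args(argv)
```

Running `parse_training_args(to_cli_args(est.hyperparameters()))` locally is not equivalent, because the SDK may JSON-encode values; passing the original dict literal is the safer way to check the mapping.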

In [None]:
#Please note training can take upwards of 25 minutes (13 minutes for saving the model). 

In [None]:
est.fit(inputs={'training':s3_train_channel},wait=True)

The log may appear to hang at the "Aborting on container exit" line for up to 20 minutes while the large model is compressed, saved, and uploaded.

In [None]:
print(est.model_data)  # Note this S3 URI in case you have to restart the kernel.
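`est.model_data` is an S3 URI pointing at the compressed model artifact. If you need the bucket and key separately (for example, after a kernel restart), a small standard-library helper can split it; the function name is hypothetical.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split 's3://bucket/key' into (bucket, key), rejecting anything that is not an S3 URI."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")
```

For instance, `!aws s3 cp {est.model_data} ./model.tar.gz` copies the artifact locally, while `split_s3_uri(est.model_data)` gives values usable with boto3 calls that take a bucket and key.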

## 3. Inference
Before running inference, you need to extend an existing Deep Learning Container; see `Dockerfile-Inf` under the `src` directory for details. The following cell builds the extended image (running `./process.sh` first in local mode), pushes it to Amazon ECR with `./src/push_to_ecr.sh`, and reads the resulting image URI from `output.txt`.

In [None]:
%%capture output
#Process and push_to_ecr may take some time to complete
if local and (output is None):
    !./process.sh
!./src/push_to_ecr.sh
with open('output.txt','r') as f:
    image_uri = f.read()
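Because `image_uri` is read from a file, a quick sanity check can catch a malformed or wrong-region URI before deployment; SageMaker generally expects a private ECR image to live in the same region as the endpoint. The sketch below assumes `output.txt` holds a single tagged URI in the standard `<account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>` form (it does not handle the China partition's `amazonaws.com.cn` domain); the helper name is hypothetical.

```python
import re

ECR_URI = re.compile(
    r"^(?P<account>\d{12})\.dkr\.ecr\.(?P<region>[a-z0-9-]+)\.amazonaws\.com/"
    r"(?P<repository>[a-z0-9._/-]+):(?P<tag>[A-Za-z0-9_.-]+)$"
)


def parse_ecr_uri(uri):
    """Return account, region, repository and tag of a tagged ECR image URI."""
    match = ECR_URI.match(uri.strip())  # strip() drops the trailing newline from output.txt
    if match is None:
        raise ValueError(f"not a tagged ECR image URI: {uri!r}")
    return match.groupdict()
```

For example, `parse_ecr_uri(image_uri)["region"] == boto3.Session().region_name` confirms the image and your session share a region. The same check applies to the training image in Section 2, which is hard-coded to `us-east-1`.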

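Define your model for deployment from the artifacts of the previous training job (`est.model_data`) and the extended inference image. Note that this cell reassigns `est` to the new model object. (According to the original setup, this step can be skipped because of the previous training job.)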

In [None]:
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.local import LocalSession
from sagemaker.session import Session
import os

est=HuggingFaceModel(role=get_execution_role(),
                     py_version='py38',
                     model_data=est.model_data,
                     image_uri=image_uri.strip(),
                     sagemaker_session=LocalSession() if 'local' in instance_type else Session(),
                     model_server_workers= 1
)

Deploy your model for inference!

In [None]:
pred = est.deploy(instance_type=instance_type,
                  initial_instance_count=1)

Provide prompts for inference. The first prompt is the caption of the dataset sample you viewed earlier.

In [None]:
prompts = [caption,'A photo of an astronaut riding a horse on mars', 
           'A dragonfruit wearing karate belt in the snow.', 
           'Teddy bear swimming at the Olympics 400m Butter-fly event.',
           'A cute sloth holding a small glowing treasure chest.']

For more generation parameters, see the [Stable Diffusion pipeline documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion). To use them, add a `'parameters': {'key': 'value'}` entry to the input dictionary.
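A small helper keeps the request shape consistent when you add parameters. This is a sketch: `num_inference_steps` and `guidance_scale` are real arguments of the diffusers Stable Diffusion pipeline, but whether the custom inference handler forwards every key to the pipeline is an assumption about the code in `src`.

```python
def build_payload(prompt, **parameters):
    """Build the endpoint input: {'inputs': prompt} plus an optional 'parameters' dict."""
    payload = {"inputs": prompt}
    if parameters:
        payload["parameters"] = dict(parameters)
    return payload
```

Usage: `pred.predict(build_payload(prompt, num_inference_steps=30, guidance_scale=7.5))`. Omitting keyword arguments produces exactly the `{'inputs': prompt}` payload used below.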

In [None]:
outputs = [pred.predict({'inputs': prompt}) for prompt in prompts]

In [None]:
outputs = [output['images'][0] for output in outputs]

In [None]:
def process_result(out):
    from PIL import Image
    from io import BytesIO
    import base64
    return Image.open(BytesIO(base64.b64decode(out)))

In [None]:
images = [[process_result(output),prompt] for output,prompt in zip(outputs,prompts)]
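Since the endpoint is deleted at the end of the notebook, you may want to keep the generated images. The sketch below writes the base64 strings in `outputs` straight to disk, choosing the file extension by sniffing the PNG or JPEG magic bytes rather than assuming a format; the helper names and the `generated` directory are hypothetical.

```python
import base64
import re
from pathlib import Path

PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
JPEG_MAGIC = b"\xff\xd8\xff"


def slugify(text, max_len=60):
    """Turn a prompt into a lowercase, filesystem-safe name."""
    slug = re.sub(r"[^a-z0-9]+", "-", text.lower()).strip("-")
    return slug[:max_len].rstrip("-") or "image"


def save_b64_images(b64_images, prompts, out_dir="generated"):
    """Decode each base64 image and save it as <index>-<prompt-slug>.<ext>; return the paths."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    paths = []
    for i, (b64, prompt) in enumerate(zip(b64_images, prompts)):
        data = base64.b64decode(b64)
        if data.startswith(PNG_MAGIC):
            ext = ".png"
        elif data.startswith(JPEG_MAGIC):
            ext = ".jpg"
        else:
            ext = ".bin"  # unknown format; keep the raw bytes
        path = out / f"{i:02d}-{slugify(prompt)}{ext}"
        path.write_bytes(data)
        paths.append(path)
    return paths
```

Calling `save_b64_images(outputs, prompts)` after the cell above stores one file per prompt.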

In [None]:
#Visualize the results from the inference

In [None]:
import matplotlib.pyplot as plt

for i in range(len(images)):
    plt.figure()
    plt.title(images[i][1])
    plt.imshow(images[i][0])

In [None]:
# clean up your endpoint
pred.delete_endpoint()