# Stability API Standard Feature Demo

This notebook showcases a few key features available through our API.
You will need to obtain an API key, and you may need to be whitelisted for some features.

* Stability SDXL Keys are available here: https://platform.stability.ai/account/keys

*For a complete reference of the Stability API, please visit https://platform.stability.ai/docs/api-reference* <br>
Please note that a REST API and gRPC API are available.

In [None]:
#@title Import Dependencies
import requests
import shutil
import getpass
import os
import base64
from google.colab import files
from PIL import Image

In [None]:
#@title Load in Sample Images
#Feel free to replace these with your own images
url_mappings = {"dog_with_armor": "https://i.imgur.com/4nnSP8q.png",
                "dog_with_armor_inpaint": "https://i.imgur.com/eu44gJe.png",
                "dog_with_armor_inpaint_just_armor": "https://i.imgur.com/Mw6QU6P.png",
                "dog_outpaint_example": "https://i.imgur.com/yv9RxjQ.png",
                "outpaint_mask_1024_1024": "https://i.imgur.com/L1lqrXm.png"
                }
# Save every sample image as /content/<name>.png, the paths the example cells below read from.
for name, image_url in url_mappings.items():
  # The context manager releases the streamed connection even if the copy fails midway.
  with requests.get(image_url, stream=True) as response:
    # Fail here rather than silently writing an HTTP error page into a .png file.
    response.raise_for_status()
    # Ask the raw stream to undo any content encoding (e.g. gzip) while it is being read.
    response.raw.decode_content = True
    with open(f'/content/{name}.png', 'wb') as out_file:
      shutil.copyfileobj(response.raw, out_file)

In [None]:
#@title Stability API Key
# You will be prompted to enter your api keys after running this code
# You can view your API key here: https://next.platform.stability.ai/account/keys
# getpass reads the key without echoing it, so it does not show up in the cell output.
# The value is stored in `api_key`, which the example cells below put in their Authorization header.
api_key = getpass.getpass('Enter your API Key')

In [None]:
#@title Text To Image Example
url = "https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/text-to-image"

# Request body: generation settings plus two weighted prompts (the second has a negative weight).
body = {
  "steps": 30,
  "width": 1024,
  "height": 1024,
  "seed": 0,
  "cfg_scale": 5,
  "samples": 1,
  "text_prompts": [
    {
      "text": "A painting of a cat wearing armor, intricate filigree, cinematic masterpiece digital art",
      "weight": 1
    },
    {
      "text": "blurry, bad",
      "weight": -1
    }
  ],
}

headers = {
  "Accept": "application/json",
  "Content-Type": "application/json",
  "Authorization": f"Bearer {api_key}",
}

response = requests.post(
  url,
  headers=headers,
  json=body,
)

if response.status_code != 200:
    raise Exception("Non-200 response: " + str(response.text))

data = response.json()

# make sure the out directory exists
os.makedirs("./out", exist_ok=True)

# Each artifact carries the image as base64 together with the seed it was generated from.
for image in data["artifacts"]:
    out_path = f'./out/txt2img_{image["seed"]}.png'
    with open(out_path, "wb") as f:
        f.write(base64.b64decode(image["base64"]))
    # Start the browser download only after the file is closed, so every byte is on disk.
    files.download(out_path)

In [None]:
#@title Inpainting Example
#replace init image and mask image with your image and mask
# Both files are opened in a `with` block so their handles are closed once the upload finishes.
with open("/content/dog_with_armor.png", "rb") as init_image, \
     open("/content/dog_with_armor_inpaint_just_armor.png", "rb") as mask_image:
    response = requests.post(
        "https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/image-to-image/masking",
        headers={
            "Accept": "application/json",
            "Authorization": f"Bearer {api_key}"
        },
        files={
            "init_image": init_image,
            "mask_image": mask_image
        },
        # Multipart form fields: the prompt list is flattened into text_prompts[i][...] keys.
        data={
            "mask_source": "MASK_IMAGE_BLACK",
            "steps": 40,
            "seed": 0,
            "cfg_scale": 5,
            "samples": 1,
            "text_prompts[0][text]": 'Dog Armor made of chocolate',
            "text_prompts[0][weight]": 1,
            "text_prompts[1][text]": 'blurry, bad',
            "text_prompts[1][weight]": -1,
        }
    )

if response.status_code != 200:
    raise Exception("Non-200 response: " + str(response.text))

data = response.json()

# make sure the out directory exists
os.makedirs("./out", exist_ok=True)

for image in data["artifacts"]:
    out_path = f'./out/img2img_{image["seed"]}.png'
    with open(out_path, "wb") as f:
        f.write(base64.b64decode(image["base64"]))
    # Start the browser download only after the file is closed, so every byte is on disk.
    files.download(out_path)

In [None]:
#@title Inpainting - Change Background Example
# Both files are opened in a `with` block so their handles are closed once the upload finishes.
with open("/content/dog_with_armor.png", "rb") as init_image, \
     open("/content/dog_with_armor_inpaint.png", "rb") as mask_image:
    response = requests.post(
        "https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/image-to-image/masking",
        headers={
            "Accept": "application/json",
            "Authorization": f"Bearer {api_key}"
        },
        files={
            "init_image": init_image,
            "mask_image": mask_image
        },
        data={
            # Flipping to white will make it inpaint but remove background, even though dog is masked black
            "mask_source": "MASK_IMAGE_WHITE",
            "steps": 40,
            "seed": 0,
            "cfg_scale": 5,
            "samples": 1,
            "text_prompts[0][text]": 'Medieval castle',
            "text_prompts[0][weight]": 1,
            "text_prompts[1][text]": 'blurry, bad',
            "text_prompts[1][weight]": -1,
        }
    )

if response.status_code != 200:
    raise Exception("Non-200 response: " + str(response.text))

data = response.json()

# make sure the out directory exists
os.makedirs("./out", exist_ok=True)

for image in data["artifacts"]:
    out_path = f'./out/img2img_{image["seed"]}.png'
    with open(out_path, "wb") as f:
        f.write(base64.b64decode(image["base64"]))
    # Start the browser download only after the file is closed, so every byte is on disk.
    files.download(out_path)

In [None]:
#@title Outpainting Example

# Init image has to be the same size as mask image
# Paste the smaller init image onto the mask
# The mask is already blurred, which will improve coherence
# Both images are opened in a `with` block so their underlying files are closed afterwards.
with Image.open("/content/dog_outpaint_example.png") as initial_init_image, \
     Image.open("/content/outpaint_mask_1024_1024.png") as mask:
    # paste() without a position places the init image at the top-left corner of the mask canvas.
    mask.paste(initial_init_image)
    mask.save('/content/dog_outpaint_init_image.png', quality=95)


# The composited canvas is sent as the init image; the untouched mask file is sent as the mask.
with open("/content/dog_outpaint_init_image.png", "rb") as init_image, \
     open("/content/outpaint_mask_1024_1024.png", "rb") as mask_image:
    response = requests.post(
        "https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/image-to-image/masking",
        headers={
            "Accept": "application/json",
            "Authorization": f"Bearer {api_key}"
        },
        files={
            "init_image": init_image,
            "mask_image": mask_image
        },
        data={
            "mask_source": "MASK_IMAGE_BLACK",
            "steps": 40,
            "seed": 0,
            "cfg_scale": 5,
            "samples": 1,
            "text_prompts[0][text]": 'Medieval castle',
            "text_prompts[0][weight]": 1,
            "text_prompts[1][text]": 'blurry, bad',
            "text_prompts[1][weight]": -1,
        }
    )

if response.status_code != 200:
    raise Exception("Non-200 response: " + str(response.text))

data = response.json()

# make sure the out directory exists
os.makedirs("./out", exist_ok=True)

for image in data["artifacts"]:
    out_path = f'./out/img2img_{image["seed"]}.png'
    with open(out_path, "wb") as f:
        f.write(base64.b64decode(image["base64"]))
    # Start the browser download only after the file is closed, so every byte is on disk.
    files.download(out_path)

In [None]:
#@title Image Upscaling Example
# The source image is opened in a `with` block so its handle is closed once the upload finishes.
with open("/content/dog_with_armor.png", "rb") as source_image:
    response = requests.post(
        "https://api.stability.ai/v1/generation/esrgan-v1-x2plus/image-to-image/upscale",
        headers={
            "Accept": "application/json",
            "Authorization": f"Bearer {api_key}"
        },
        files={
            "image": source_image
        },
        # Requested output width, sent as a form field.
        data={
            "width": 2048
        }
    )

if response.status_code != 200:
    raise Exception("Non-200 response: " + str(response.text))

data = response.json()

# make sure the out directory exists
os.makedirs("./out", exist_ok=True)

for image in data["artifacts"]:
    out_path = f'./out/img2img_{image["seed"]}.png'
    with open(out_path, "wb") as f:
        f.write(base64.b64decode(image["base64"]))
    # Start the browser download only after the file is closed, so every byte is on disk.
    files.download(out_path)

# Stability SDXL Enterprise API Demo

Stability provides enterprise-grade features for customers that require faster speeds and dedicated managed services with support. These nodes can have significantly faster speeds based on the Stability Supercomputer, and may include prototype / preview models. <br>
The below section will leverage a demo node that is prepared at request. If you would like to try an enterprise node, please reach out to the Stability team.

In [None]:
# This demo notebook is designed to help illustrate the latency on prototype nodes with early access models
# This REST implementation will hit the special demo node. Please note that this is a demo only and results in production should exceed these speeds.
# Make sure to enable batch downloads from your browser if you want to see the images that will be downloaded locally
# OUTPUT: You will see the average and total time for image generation. Note that the first call may include a warm-up time of up to 2 seconds, and that Colab will add an additional ~1.5 seconds

import base64
import requests
import os
import time
from google.colab import files


def make_request(index):
    """Request one text-to-image generation from the demo node and save the result.

    Every returned artifact is written to ``./out/txt2img_<seed>_<index>.png`` and then
    handed to ``files.download``.

    Args:
        index: Value embedded in the output file name so repeated calls do not overwrite
            each other's images.

    Raises:
        Exception: If the API answers with any status code other than 200.
    """
    #replace the <node> with the name of the node and module provided to you
    url = "https://test.api.stability.ai/v1/generation/<node>/<module>"
    #Steps: Increasing can improve quality, and increase latency
    body = {
      "steps": 22,
      "width": 1024,
      "height": 1024,
      "seed": 0,
      "cfg_scale": 6,
      "samples": 1,
      "text_prompts": [
        {
          "text": "octane render of a barabaric software engineer",
          "weight": 1
        },
        {
          "text": "blurry, bad",
          "weight": -1
        }
      ],
    }

    headers = {
      "Accept": "application/json",
      "Content-Type": "application/json",
      #insert your Key
      "Authorization": "Bearer <key>",
    }

    response = requests.post(
      url,
      headers=headers,
      json=body,
    )

    if response.status_code != 200:
        raise Exception("Non-200 response: " + str(response.text))

    data = response.json()

    os.makedirs("./out", exist_ok=True)

    for image in data["artifacts"]:
        out_path = f'./out/txt2img_{image["seed"]}_{index}.png'
        with open(out_path, "wb") as f:
            f.write(base64.b64decode(image["base64"]))

        # The download starts only after the file is closed, so every byte is on disk.
        #Please comment the below line to execute pure benchmarking without downloading the images
        files.download(out_path)

total_time = 0

#Adjust num to change the number of images to get as batch
num = 2
for i in range(num):
    print(i)
    # perf_counter is a monotonic, high-resolution clock, so system clock adjustments
    # cannot distort the measured interval the way they can with time.time().
    start = time.perf_counter()
    make_request(i)
    total_time += time.perf_counter() - start

# Guard the average so that num = 0 reports 0 instead of raising ZeroDivisionError.
print("Average: ", total_time/num if num else 0.0)
print("Total_Time: ", total_time)
print("Num Iterations: ", num)

# SDXL Finetuning REST API Demo

Stability is offering a private beta of its fine-tuning service to select customers. <br>

Note that this is a **developer beta** - bugs and quality issues with the generated fine-tunes may occur. Please reach out to Stability if this is the case - and share what you've made as well!

The code below hits the Stability REST API.  This REST API contract is rather solid, so it's unlikely to see large changes before the production release of fine-tuning.

Known issues:

* Style fine-tunes may result in overfitting - if this is the case, uncomment the `# weight=1.0` field of `DiffusionFineTune` in the diffusion section and provide a value between -1 and 1. You may need to go as low as 0.2 or 0.1.
* We will be exposing test parameters soon - please reach out with examples of datasets that produce overfitting or errors if you have them.

In [None]:
#@title Stability API key
import getpass

#@markdown Execute this step and paste your API key in the box that appears. <br/> Visit https://platform.stability.ai/account/keys to get your API key! <br/> <br/> <em>Note: If you are not on the fine-tuning whitelist you will receive an error during training.</em>

# getpass reads the key without echoing it, so it does not appear in the cell output.
API_KEY = getpass.getpass('Paste your Stability API Key here and press Enter: ')

# Base URL handed to FineTuningRESTWrapper; every wrapper request is built on top of it.
API_HOST = "https://preview-api.stability.ai"

# Engine id passed both when creating the fine-tune and when generating images with it.
ENGINE_ID = "stable-diffusion-xl-1024-v1-0"

In [None]:
#@title Initialize the REST API wrapper
import io
import logging
import requests
import os
import shutil
import sys
import time
import json
import base64
from enum import Enum
from dataclasses import dataclass, is_dataclass, field, asdict
from typing import List, Optional, Any
from IPython.display import clear_output
from pathlib import Path
from PIL import Image
from zipfile import ZipFile


class Printable:
    """Mixin whose ``str()`` form is ``ClassName: <JSON dump indented by 4>``."""

    @staticmethod
    def to_json(obj: Any) -> Any:
        """Fallback serializer handed to ``json.dumps`` as ``default``.

        Enum members become their ``value``, dataclasses become dicts via ``asdict``,
        and anything else is returned unchanged.
        """
        if isinstance(obj, Enum):
            return obj.value
        return asdict(obj) if is_dataclass(obj) else obj

    def __str__(self):
        rendered = json.dumps(self, default=self.to_json, indent=4)
        return f"{self.__class__.__name__}: {rendered}"


class ToDict:
    """Mixin adding ``to_dict`` to a dataclass; fields set to ``None`` are left out."""

    def to_dict(self):
        """Return ``asdict(self)`` minus every entry whose value is ``None``."""
        all_fields = asdict(self)
        return {
            field_name: field_value
            for field_name, field_value in all_fields.items()
            if field_value is not None
        }


@dataclass
class FineTune(Printable):
    """A fine-tune record; the REST wrapper builds it from a fine-tunes response body."""
    id: str
    user_id: str
    name: str
    mode: str
    engine_id: str
    training_set_id: str
    status: str
    # The remaining fields may be omitted by the caller and then stay None.
    failure_reason: Optional[str] = None
    duration_seconds: Optional[int] = None
    object_prompt: Optional[str] = None


@dataclass
class DiffusionFineTune(Printable, ToDict):
    """One fine-tune to apply during generation: its id, its prompt token, and an optional weight."""
    id: str
    token: str
    # Left as None unless set; ToDict.to_dict() then omits it from the request payload.
    weight: Optional[float] = None


@dataclass
class TextPrompt(Printable, ToDict):
    """A prompt string with an optional weight."""
    text: str
    # Left as None unless set; ToDict.to_dict() then omits it from the request payload.
    weight: Optional[float] = None


class Sampler(Enum):
    """Sampler names; a member's ``value`` is what gets sent as the request's ``sampler`` field."""
    DDIM = "DDIM"
    DDPM = "DDPM"
    K_DPMPP_2M = "K_DPMPP_2M"
    K_DPMPP_2S_ANCESTRAL = "K_DPMPP_2S_ANCESTRAL"
    K_DPM_2 = "K_DPM_2"
    K_DPM_2_ANCESTRAL = "K_DPM_2_ANCESTRAL"
    K_EULER = "K_EULER"
    K_EULER_ANCESTRAL = "K_EULER_ANCESTRAL"
    K_HEUN = "K_HEUN"
    K_LMS = "K_LMS"

    @staticmethod
    def from_string(val) -> "Sampler":
        """Return the member whose value equals ``val``.

        Args:
            val: A sampler name such as ``"K_EULER"``. An existing ``Sampler`` member is
                also accepted and returned as-is.

        Returns:
            The matching ``Sampler`` member. This function never returns ``None``.

        Raises:
            Exception: With the message ``Unknown Sampler: <val>`` when nothing matches.
        """
        try:
            # Enum's by-value lookup replaces the manual scan over all members.
            return Sampler(val)
        except ValueError:
            # Keep the exception type and message that existing callers see.
            raise Exception(f"Unknown Sampler: {val}") from None


@dataclass
class TextToImageParams(Printable):
    """Everything needed for one text-to-image request made through the REST wrapper."""
    fine_tunes: List[DiffusionFineTune]
    text_prompts: List[TextPrompt]
    samples: int
    sampler: Sampler
    engine_id: str
    steps: int
    # Defaults used when the caller does not supply a value.
    seed: Optional[int] = 0
    cfg_value: Optional[int] = 7
    width: Optional[int] = 1024
    height: Optional[int] = 1024


@dataclass
class DiffusionResult:
    """One generated image: base64 payload, seed, and finish reason.

    The string form leaves the base64 payload out so that printing a result stays short.
    """
    base64: str
    seed: int
    finish_reason: str

    def __str__(self):
        template = "DiffusionResult(base64='too long to print', seed='{}', finish_reason='{}')"
        return template.format(self.seed, self.finish_reason)

    def __repr__(self):
        # repr() deliberately matches str() so lists of results print compactly too.
        return self.__str__()


@dataclass
class TrainingSetBase(Printable):
    """Id and name of a training set; ``list_training_sets`` returns a list of these."""
    id: str
    name: str


@dataclass
class TrainingSetImage(Printable):
    """A single image of a training set, identified only by its id."""
    id: str


@dataclass
class TrainingSet(TrainingSetBase):
    """A training set (id and name) together with its images."""
    # NOTE(review): get_training_set builds this with TrainingSet(**response.json()), so
    # `images` presumably holds plain dicts rather than TrainingSetImage instances - confirm.
    images: List[TrainingSetImage]


class FineTuningRESTWrapper:
    """
    Helper class to simplify interacting with the fine-tuning service via
    Stability's REST API.

    While this class can be copied to your local environment, it is not likely
    robust enough for your needs and does not support all of the features that
    the REST API offers.

    Every method sends its request(s) to ``api_host`` and hands each response
    to ``raise_on_non200``, so an unsuccessful call surfaces as an exception.
    """

    def __init__(self, api_key: str, api_host: str):
        """Store the API key and the base URL (scheme + host) used for every request."""
        self.api_key = api_key
        self.api_host = api_host

    def _url(self, *parts: str) -> str:
        """Join ``parts`` onto ``<api_host>/v1/`` to form an endpoint URL."""
        return f"{self.api_host}/v1/" + "/".join(parts)

    def _headers(self, json_body: bool = False) -> dict:
        """Build the request headers; add a JSON Content-Type when ``json_body`` is true."""
        # NOTE(review): the key is sent bare, while other cells in this notebook send
        # "Bearer <key>" - confirm that the API accepts both forms.
        headers = {"Authorization": self.api_key}
        if json_body:
            headers["Content-Type"] = "application/json"
        return headers

    def create_fine_tune(self,
                         name: str,
                         images: List[str],
                         engine_id: str,
                         mode: str,
                         object_prompt: Optional[str] = None) -> FineTune:
        """Create a training set, upload ``images`` to it, and start a fine-tune on it.

        Args:
            name: Name used for both the training set and the fine-tune.
            images: Paths of the image files to upload.
            engine_id: Engine to fine-tune.
            mode: Fine-tune mode, sent to the API unchanged.
            object_prompt: Sent only when it is not ``None``.

        Returns:
            The created fine-tune, built from the response body.
        """
        print(f"Creating {mode} fine-tune called '{name}' using {len(images)} images...")

        payload = {"name": name, "engine_id": engine_id, "mode": mode}
        if object_prompt is not None:
            payload["object_prompt"] = object_prompt

        # Create a training set
        training_set_id = self.create_training_set(name=name)
        payload["training_set_id"] = training_set_id
        print(f"\tCreated training set {training_set_id}")

        # Add images to the training set
        for image in images:
            print(f"\t\tAdding {os.path.basename(image)}")
            self.add_image_to_training_set(
                training_set_id=training_set_id,
                image=image
            )

        # Create the fine-tune
        print("\tCreating a fine-tune from the training set")
        response = requests.post(
            self._url("fine-tunes"),
            json=payload,
            headers=self._headers(json_body=True)
        )
        raise_on_non200(response)

        # Parse the body once and reuse it for both the log line and the return value.
        created = response.json()
        print(f"\tCreated fine-tune {created['id']}")

        print("Success")
        return FineTune(**created)

    def get_fine_tune(self, fine_tune_id: str) -> FineTune:
        """Fetch the fine-tune with the given id."""
        response = requests.get(
            self._url("fine-tunes", fine_tune_id),
            headers=self._headers()
        )

        raise_on_non200(response)

        return FineTune(**response.json())

    def list_fine_tunes(self) -> List[FineTune]:
        """Return every fine-tune contained in the list response."""
        response = requests.get(
            self._url("fine-tunes"),
            headers=self._headers()
        )

        raise_on_non200(response)

        return [FineTune(**ft) for ft in response.json()]

    def rename_fine_tune(self, fine_tune_id: str, name: str) -> FineTune:
        """Send a ``RENAME`` operation for the fine-tune and return the response as a FineTune."""
        response = requests.patch(
            self._url("fine-tunes", fine_tune_id),
            json={"operation": "RENAME", "name": name},
            headers=self._headers(json_body=True)
        )

        raise_on_non200(response)

        return FineTune(**response.json())

    def retrain_fine_tune(self, fine_tune_id: str) -> FineTune:
        """Send a ``RETRAIN`` operation for the fine-tune and return the response as a FineTune."""
        response = requests.patch(
            self._url("fine-tunes", fine_tune_id),
            json={"operation": "RETRAIN"},
            headers=self._headers(json_body=True)
        )

        raise_on_non200(response)

        return FineTune(**response.json())

    def delete_fine_tune(self, fine_tune: FineTune):
        """Delete the fine-tune's training set first, then the fine-tune itself."""
        # Delete the underlying training set
        self.delete_training_set(fine_tune.training_set_id)

        # Delete the fine-tune
        response = requests.delete(
            self._url("fine-tunes", fine_tune.id),
            headers=self._headers()
        )

        raise_on_non200(response)

    def create_training_set(self, name: str) -> str:
        """Create an empty training set and return the ``id`` from the response."""
        response = requests.post(
            self._url("training-sets"),
            json={"name": name},
            headers=self._headers(json_body=True)
        )

        raise_on_non200(response)

        return response.json().get('id')

    def get_training_set(self, training_set_id: str) -> TrainingSet:
        """Fetch one training set, including its ``images`` list."""
        response = requests.get(
            self._url("training-sets", training_set_id),
            headers=self._headers()
        )

        raise_on_non200(response)

        return TrainingSet(**response.json())

    def list_training_sets(self) -> List[TrainingSetBase]:
        """Return every training set contained in the list response."""
        response = requests.get(
            self._url("training-sets"),
            headers=self._headers()
        )

        raise_on_non200(response)

        return [TrainingSetBase(**tsb) for tsb in response.json()]

    def add_image_to_training_set(self, training_set_id: str, image: str) -> str:
        """Upload the file at path ``image`` to the training set; return the response ``id``."""
        # The file stays open only for the duration of the multipart upload.
        with open(image, 'rb') as image_file:
            response = requests.post(
                self._url("training-sets", training_set_id, "images"),
                headers=self._headers(),
                files={'image': image_file}
            )

        raise_on_non200(response)

        return response.json().get('id')

    def remove_image_from_training_set(self, training_set_id: str, image_id: str) -> None:
        """Delete one image from a training set."""
        response = requests.delete(
            self._url("training-sets", training_set_id, "images", image_id),
            headers=self._headers()
        )

        raise_on_non200(response)

    def delete_training_set(self, training_set_id: str) -> None:
        """Delete a training set."""
        response = requests.delete(
            self._url("training-sets", training_set_id),
            headers=self._headers()
        )

        raise_on_non200(response)

    def text_to_image(self, params: TextToImageParams) -> List[DiffusionResult]:
        """Generate images with the fine-tunes and prompts described by ``params``.

        Returns:
            One ``DiffusionResult`` per entry of the response's ``artifacts`` list.
        """
        payload = {
            "fine_tunes": [ft.to_dict() for ft in params.fine_tunes],
            "text_prompts": [tp.to_dict() for tp in params.text_prompts],
            "samples": params.samples,
            "sampler": params.sampler.value,
            "steps": params.steps,
            "seed": params.seed,
            "width": params.width,
            "height": params.height,
            # The generation endpoints used elsewhere in this notebook name this field
            # `cfg_scale`; the value comes from `params.cfg_value`.
            "cfg_scale": params.cfg_value,
        }

        response = requests.post(
            self._url("generation", params.engine_id, "text-to-image"),
            json=payload,
            headers={**self._headers(), "Accept": "application/json"}
        )

        raise_on_non200(response)

        return [
            DiffusionResult(base64=item["base64"], seed=item["seed"], finish_reason=item["finishReason"])
            for item in response.json()["artifacts"]
        ]


def raise_on_non200(response):
    """Raise unless ``response`` carries a 2xx status code.

    Args:
        response: Any object exposing ``status_code``, ``json()`` and ``text``
            (for example a ``requests.Response``).

    Raises:
        Exception: For every status outside 200-299. The message contains the status
            code followed by the pretty-printed JSON body, or by the raw body text
            when the body is not valid JSON.
    """
    if 200 <= response.status_code < 300:
        return
    try:
        details = json.dumps(response.json(), indent=4)
    except ValueError:
        # Error bodies are not always JSON (e.g. an HTML page from a proxy). Fall back to
        # the raw text so the status code is reported instead of a JSON decoding error.
        details = response.text
    raise Exception(f"Status code {response.status_code}: {details}")


# Redirect logs to print statements so we can see them in the notebook
class PrintHandler(logging.Handler):
    """Logging handler that prints every formatted record."""

    def emit(self, record):
        print(self.format(record))

# Re-running this cell must not stack another handler, which would print every log line
# several times. The class object is recreated on each run, so an isinstance check would
# miss a handler installed by an earlier run; compare class names instead.
_root_logger = logging.getLogger()
if not any(type(handler).__name__ == "PrintHandler" for handler in _root_logger.handlers):
    _root_logger.addHandler(PrintHandler())
_root_logger.setLevel(logging.INFO)

# Build the REST client that the fine-tuning cells below use, from the key and host
# entered in the "Stability API key" cell
rest_api = FineTuningRESTWrapper(api_key=API_KEY, api_host=API_HOST)

In [None]:
#@title List your existing fine-tunes

# Fetch the fine-tunes returned by the API and print one summary line per model.
fine_tunes = rest_api.list_fine_tunes()
print(f"Found {len(fine_tunes)} models")
# NOTE: `fine_tune` is a notebook-level name, so after this loop it stays bound to the
# last listed fine-tune until another cell reassigns it.
for fine_tune in fine_tunes:
    # `:<9` left-aligns the status in a 9-character column so the names line up.
    print(f"  Model {fine_tune.id} {fine_tune.status:<9} {fine_tune.name}")

## Add Training Images

For training, we need a dataset of images in a `.zip` file.

<em>Please only upload images that you have the permission to use.</em>


### Image Dimensions

- Images **cannot** have any side less than 328px
- Images **cannot** be larger than 10MB

There is no upper-bound for what we'll accept for an image's dimensions, but any side above 1024px will be scaled down to 1024px, while preserving aspect ratio. For example:
- `3024x4032` will be scaled down to `768x1024`
- `1118x1118` will be scaled down to `1024x1024`


### Image Quantity

- Datasets **cannot** have fewer than 3 images
- Datasets **cannot** have more than 64 images

A larger dataset often tends to result in a more accurate fine-tune, but will also take longer to train.

While each mode can accept up to 64 images, we have a few suggestions for a starter dataset based on the mode you are using:
*   `FACE`: 6 or more images.
*   `OBJECT`: 6 - 10 images.
*   `STYLE`: 20 - 30 images.

In [None]:
#@title Upload ZIP file of images
training_dir = "./train"
Path(training_dir).mkdir(exist_ok=True)
try:
    from google.colab import files

    upload_res = files.upload()
    extracted_dir = list(upload_res.keys())[0]
    print(f"Received {extracted_dir}")
    if not extracted_dir.endswith(".zip"):
        raise ValueError("Uploaded file must be a zip file")

    zf = ZipFile(io.BytesIO(upload_res[extracted_dir]), "r")
    extracted_dir = Path(extracted_dir).stem
    print(f"Extracting to {extracted_dir}")
    zf.extractall(extracted_dir)

    for root, dirs, files in os.walk(extracted_dir):
        for file in files:
            source_path = os.path.join(root, file)
            target_path = os.path.join(training_dir, file)

            # Ignore Mac-specific files
            if 'MACOSX' in source_path or 'DS' in source_path:
              continue

            # Move the file to the target directory
            print('Copying', source_path, '==>', target_path)
            shutil.move(source_path, target_path)


except ImportError:
    pass

print(f"Using training images from: {training_dir}")

## Train a Fine-Tune

Now we're ready to train our fine-tune. Use the parameters below to configure the name and the kind of fine-tune.

Please note that the training duration will vary based on:
- The number of images in your dataset
- The `training_mode` used
- The `engine_id` that is being fine-tuned on

The following are some rough estimates for the training duration for each mode based on our recommended dataset sizes:

* `FACE`: 4 - 5 minutes.
* `OBJECT`: 5 - 10 minutes.
* `STYLE`: 20 - 30 minutes.

In [None]:
#@title Begin Training
fine_tune_name = "my dog spot" #@param {type:"string"}
#@markdown > Requirements: <ul><li>Must be unique (only across your account, not globally)</li> <li>Must be between 3 and 64 characters (inclusive)</li> <li>Must only contain letters, numbers, spaces, or hyphens</li></ul>
training_mode = "OBJECT" #@param ["FACE", "STYLE", "OBJECT"] {type:"string"}
#@markdown > Determines the kind of fine-tune you're creating: <ul><li><code>FACE</code> - a fine-tune on faces; expects pictures containing a face; automatically crops and centers on the face detected in the input photos.</li> <li> <code>OBJECT</code> - a fine-tune on a particular object (e.g. a bottle); segments out the object using the `object_prompt` below</li> <li><code>STYLE</code> - a fine-tune on a particular style (e.g. satellite photos of earth); crops the images and filters for image quality.</li></ul>
object_prompt = "dog" #@param {type:"string"}
#@markdown > Used for segmenting out your subject when the `training_mode` is `OBJECT`.  (i.e. if you want to fine tune on a cat, put `cat` - for a bottle of liquor, use `bottle`. In general, it's best to use the most general word you can to describe your object.)

# Dataset size limits stated in the "Image Quantity" section of this notebook.
MIN_TRAINING_IMAGES = 3
MAX_TRAINING_IMAGES = 64
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.heic')

# Gather training images (sorted, so the upload order is the same on every run)
images = sorted(
    os.path.join(training_dir, filename)
    for filename in os.listdir(training_dir)
    if os.path.splitext(filename)[1].lower() in IMAGE_EXTENSIONS
)

# Check the count before any API call, so an unusable dataset does not leave a
# half-populated training set behind on the server.
if not MIN_TRAINING_IMAGES <= len(images) <= MAX_TRAINING_IMAGES:
    raise ValueError(
        f"Found {len(images)} training images in {training_dir}; a dataset needs "
        f"between {MIN_TRAINING_IMAGES} and {MAX_TRAINING_IMAGES} images."
    )

# Create the fine-tune
fine_tune = rest_api.create_fine_tune(
    name=fine_tune_name,
    images=images,
    mode=training_mode,
    # The object prompt is only sent for OBJECT fine-tunes.
    object_prompt=object_prompt if training_mode == "OBJECT" else None,
    engine_id=ENGINE_ID,
)

print()
print(fine_tune)

In [None]:
#@title Wait For Training to Finish
# Statuses at which polling stops.
TERMINAL_STATUSES = ("COMPLETED", "FAILED")
POLL_INTERVAL_SECONDS = 10

start_time = time.time()
# Defined up front so the final message works even when the fine-tune is already in a
# terminal state on entry (e.g. when this cell is re-run) and the loop body never runs.
elapsed = 0.0
while fine_tune.status not in TERMINAL_STATUSES:
    fine_tune = rest_api.get_fine_tune(fine_tune.id)
    elapsed = time.time() - start_time
    clear_output(wait=True)
    print(f"Training '{fine_tune.name}' ({fine_tune.id}) status: {fine_tune.status} for {elapsed:.0f} seconds")
    # Sleep only when another poll will follow; there is nothing to wait for once finished.
    if fine_tune.status not in TERMINAL_STATUSES:
        time.sleep(POLL_INTERVAL_SECONDS)

clear_output(wait=True)
status_message = "completed" if fine_tune.status == "COMPLETED" else "failed"
print(f"Training '{fine_tune.name}' ({fine_tune.id}) {status_message} after {elapsed:.0f} seconds")

In [None]:
#@title (Optional) Retrain if Training Failed
# Does nothing unless the most recently fetched status is FAILED.
if fine_tune.status == "FAILED":
    print(f"Training failed, due to {fine_tune.failure_reason}. Retraining...")
    # `fine_tune` is replaced by the object returned from the retrain request. This cell
    # does not wait for the new run; re-run the "Wait For Training to Finish" cell for that.
    fine_tune = rest_api.retrain_fine_tune(fine_tune.id)

## Use your Fine-Tune

Time to diffuse!  The example below uses a single fine-tune, but using multiple fine-tunes is where this process really shines.  While this Colab doesn't directly support diffusing with multiple fine-tunes, you can still try it out by uncommenting the second `DiffusionFineTune` entry in the code cell below and filling in its `id` and `token`.

In [None]:
#@title Generate Images

prompt_token="$my-dog" #@param {type:"string"}
#@markdown > This token is an alias for your fine-tune, allowing you to reference your fine-tune directly in your prompt. Each fine-tune you want to diffuse with must provide a unique alias. <br/><br/> For example, if your token was `$my-dog` you might use a prompt like: `a picture of $my-dog` or `$my-dog chasing a rabbit`. <br/><br/> If you have more than one fine-tune you can combine them!  Given some fine-tune of film noir images you could use a prompt like `$my-dog in the style of $film-noir`.
prompt="a photo of $my-dog"  #@param {type:"string"}
#@markdown > The prompt to diffuse with.  Must contain the `prompt_token` at least once.
dimensions="1024x1024" #@param ['1024x1024', '1152x896', '1216x832', '1344x768', '1536x640', '640x1536', '768x1344', '832x1216', '896x1152']
#@markdown > The dimensions of the image to generate (width x height).
samples=4 #@param {type:"slider", min:1, max:10, step:1}
#@markdown > The number of images to generate. Requesting a large number of images may negatively affect response time.
steps=32 #@param {type:"slider", min:30, max:60, step:1}
#@markdown > The number of iterations or stages a diffusion model goes through in the process of generating an image from a given text prompt. Lower steps will generate more quickly, but if steps are lowered too much, image quality will suffer. Images with higher steps take longer to generate, but often give more detailed results.
cfg_value=7 #@param {type:"slider", min:0, max:35, step:1}
#@markdown > CFG (Classifier Free Guidance) scale determines how strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt).
seed=0  #@param {type:"number"}
#@markdown > The noise seed to use during diffusion.  Using `0` means a random seed will be generated for each image.  If you provide a non-zero value, images will be far less random.

# "WIDTHxHEIGHT" -> two ints
width, height = (int(side) for side in dimensions.split("x"))

params = TextToImageParams(
    fine_tunes=[
        DiffusionFineTune(
            id=fine_tune.id,
            token=prompt_token,
            # Uncomment the following to provide a weight for the fine-tune
            # weight=1.0
        ),

        # Uncomment the following to use multiple fine-tunes at once
        # DiffusionFineTune(
        #     id="",
        #     token="",
        #     # weight=1.0
        # ),
    ],
    text_prompts=[
        TextPrompt(
            text=prompt,
            # weight=1.0
        ),
    ],
    engine_id=ENGINE_ID,
    samples=samples,
    steps=steps,
    # Forward the form value; this used to be hard-coded to 0, which ignored the field above.
    seed=seed,
    cfg_value=cfg_value,
    width=width,
    height=height,
    sampler=Sampler.K_DPMPP_2S_ANCESTRAL
)

start_time = time.time()
images = rest_api.text_to_image(params)

elapsed = time.time() - start_time
print(f"Diffusion completed in {elapsed:.0f} seconds!")
print(f"{len(images)} result{'s' if len(images) != 1 else ''} will be displayed below momentarily (depending on the speed of Colab).\n")

# Decode each base64 payload and render it inline.
for image in images:
  display(Image.open(io.BytesIO(base64.b64decode(image.base64))))

In [None]:
#@title (Optional) Download Images
# Imported here so that `files` refers to the Colab module in this cell, whatever earlier
# cells may have bound that name to.
from google.colab import files

os.makedirs("./out", exist_ok=True)

for index, image in enumerate(images):
    out_path = f'./out/txt2img_{image.seed}_{index}.png'
    with open(out_path, "wb") as f:
        f.write(base64.b64decode(image.base64))
    # Start the browser download only after the file is closed, so every byte is on disk.
    files.download(out_path)

In [None]:
#@title (Optional) Rename Fine-Tune

name = "" #@param {type:"string"}
# The form field is empty by default. Skip the request in that case, so that running all
# cells does not send a rename with an empty name.
if name.strip():
    # Keep the notebook's `fine_tune` in sync with the renamed object the API returns.
    fine_tune = rest_api.rename_fine_tune(fine_tune.id, name=name)
    print(fine_tune)
else:
    print("No new name provided; skipping rename.")