In [1]:
#@title Install Stability SDK with fine-tuning support
import getpass
import io
import logging
import os
import shutil
import sys
import time
from IPython.display import clear_output
from pathlib import Path
from zipfile import ZipFile

if os.path.exists("../src/stability_sdk"):
    sys.path.append("../src") # use local SDK src
else:
    path = Path('stability-sdk')
    if path.exists():
        shutil.rmtree(path)
        !pip uninstall -y stability-sdk
    !git clone -b "PLATFORM-339" --recurse-submodules https://github.com/Stability-AI/stability-sdk
    !pip install ./stability-sdk

In [None]:
#@title Connect to the Stability API
from stability_sdk.api import Context, generation
from stability_sdk.finetune import (
    create_model, delete_model, get_model, list_models, resubmit_model, update_model,
    FineTuneMode, FineTuneParameters, FineTuneStatus
)

# @markdown To get your API key visit https://dreamstudio.ai/account
# API endpoint (host:port) handed to Context below; editable through the notebook form.
STABILITY_HOST = "grpc-staging.stability.ai:443" #@param {type:"string"}
# Prompt for the key at run time (input is not echoed) so it is not stored in the notebook source.
STABILITY_KEY = getpass.getpass('Enter your API Key')

# Engine id passed to Context here and to FineTuneParameters in the fine-tuning cell.
engine_id = "stable-diffusion-xl-1024-v0-9" #@param ["stable-diffusion-xl-1024-v0-9", "stable-diffusion-xl-1024-v1-0"] {type:"string"}

# Create API context to query user info and generate images
context = Context(STABILITY_HOST, STABILITY_KEY, generate_engine_id=engine_id)
# get_user_info() is unpacked into two values; only `balance` is used below.
# NOTE(review): `pfp` is presumably the profile picture - confirm against the SDK.
(balance, pfp) = context.get_user_info()
# NOTE(review): _user_organization_id is a private Context attribute, presumably
# populated by get_user_info() - confirm; a public accessor would be preferable.
print(f"Logged in org:{context._user_organization_id} with balance:{balance}")

# Redirect logs to print statements so we can see them in the notebook
class PrintHandler(logging.Handler):
    """Logging handler that prints each formatted record to the cell output."""

    def emit(self, record):
        """Format `record` with this handler's formatter and print it."""
        print(self.format(record))

_root_logger = logging.getLogger()
# Re-running this cell redefines PrintHandler, so isinstance() would not match
# handlers from an earlier run. Match by class name and replace them, otherwise
# every re-run would add one more handler and each log line would print repeatedly.
for _stale_handler in [h for h in _root_logger.handlers if type(h).__name__ == "PrintHandler"]:
    _root_logger.removeHandler(_stale_handler)
_root_logger.addHandler(PrintHandler())
_root_logger.setLevel(logging.INFO)

In [None]:
# List fine-tuned models for this user / organization
models = list_models(context, org_id=context._user_organization_id)
print(f"Found {len(models)} models")
# The loop variable is deliberately NOT called `model`: notebook variables are
# global, and `model` is the fine-tune that the status / generate / update /
# delete cells operate on. Reusing the name here would silently repoint those
# cells at the last listed model.
for listed_model in models:
    print(f"  Model {listed_model.id} {listed_model.name} {listed_model.status}")

In [None]:
#@title Specify folder of images or upload zip file
training_dir = "./train" #@param {type:"string"}

# If the folder is missing, fall back to uploading a zip archive (Colab only)
# and extract it into a folder named after the archive.
if not os.path.exists(training_dir):
    try:
        from google.colab import files
    except ImportError:
        # Upload is a Colab-only convenience; outside Colab, say so explicitly
        # instead of silently continuing with a folder that does not exist.
        print(f"Warning: {training_dir} does not exist and zip upload is only available in Google Colab")
    else:
        upload_res = files.upload()
        if not upload_res:
            raise ValueError("No file was uploaded")

        # Only the first uploaded file is used.
        zip_name = next(iter(upload_res))
        print(f"Received {zip_name}")
        if not zip_name.lower().endswith(".zip"):
            raise ValueError("Uploaded file must be a zip file")

        # training_dir is reassigned only after validation, so a rejected
        # upload never leaves it pointing at a file name.
        training_dir = Path(zip_name).stem
        print(f"Extracting to {training_dir}")
        # Context manager closes the archive even if extraction fails.
        with ZipFile(io.BytesIO(upload_res[zip_name]), "r") as zf:
            zf.extractall(training_dir)

print(f"Using training images from: {training_dir}")

In [None]:
#@title Perform fine-tuning
model_name = "cat-ft-01" #@param {type:"string"}
training_mode = "object" #@param ["face", "style", "object"] {type:"string"}
object_prompt = "cat" #@param {type:"string"}

# Gather training images: regular files directly inside training_dir whose
# extension (case-insensitive) is a supported image type. Sorted so the
# submission order is reproducible (os.listdir order is arbitrary).
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
images = sorted(
    os.path.join(training_dir, filename)
    for filename in os.listdir(training_dir)
    if os.path.splitext(filename)[1].lower() in IMAGE_EXTENSIONS
    and os.path.isfile(os.path.join(training_dir, filename))
)
# Fail early with a clear message rather than submitting a fine-tune with no data.
if not images:
    raise ValueError(f"No {', '.join(IMAGE_EXTENSIONS)} images found in {training_dir}")
print(f"Found {len(images)} training images")

# Create the fine-tune model
params = FineTuneParameters(
    name=model_name,
    mode=FineTuneMode(training_mode),
    object_prompt=object_prompt,
    engine_id=engine_id,
)
model = create_model(context, params, images)
print(f"Model {model_name} created.")
print(model)

In [6]:
# Check on training status: poll the API until the model reaches a terminal
# state (COMPLETED or FAILED), refreshing a one-line status display each time.
TERMINAL_STATUSES = (FineTuneStatus.COMPLETED, FineTuneStatus.FAILED)
start_time = time.time()
# Defined up front so the summary below works even when the model is already
# in a terminal state and the loop body never runs.
elapsed = 0.0
while model.status not in TERMINAL_STATUSES:
    model = get_model(context, model.id)
    elapsed = time.time() - start_time
    clear_output(wait=True)
    print(f"Model {model.name} ({model.id}) status: {model.status} for {elapsed:.0f} seconds")
    # Only wait when another poll is actually needed.
    if model.status not in TERMINAL_STATUSES:
        time.sleep(5)

clear_output(wait=True)
status_message = "completed" if model.status == FineTuneStatus.COMPLETED else "failed"
print(f"Model {model.name} ({model.id}) {status_message} after {elapsed:.0f} seconds")

In [13]:
# A failed fine-tune can be submitted again by id; `model` is replaced with
# whatever resubmit_model returns so later cells see the new state.
training_failed = model.status == FineTuneStatus.FAILED
if training_failed:
    print("Training failed, resubmitting")
    model = resubmit_model(context, model.id)

In [None]:
# Generate an image using the fine-tuned model.
# The model is referenced inside the prompt as <model-id:0.7>
# (NOTE(review): 0.7 is presumably the fine-tune strength - confirm in the SDK docs).
wizard_prompt = f"Illustration of <{model.id}:0.7> as a wizard"
results = context.generate(
    prompts=[wizard_prompt],
    weights=[1],
    width=1024,
    height=1024,
    seed=42,
    sampler=generation.SAMPLER_DDIM,
    preset="photographic",
)
# Results are keyed by artifact type; show the first image artifact.
image_artifacts = results[generation.ARTIFACT_IMAGE]
image = image_artifacts[0]
display(image)

In [None]:
# Models can be updated to change settings before a resubmit, or renamed after training.
# Here only the name is changed, for the model currently held in `model`.
update_model(context, model.id, name="cat-ft-01-renamed")

In [None]:
# Delete the model when it's no longer needed.
# This acts on whichever model `model` currently refers to - check model.id before running.
delete_model(context, model.id)

In [None]:
#@title Example using StabilityInference class
import warnings
from stability_sdk.client import StabilityInference
from PIL import Image

# Same generation request as above, issued through the StabilityInference
# client, which yields responses whose artifacts are inspected one by one.
si = StabilityInference(STABILITY_HOST, STABILITY_KEY, engine=engine_id)
results = si.generate(
    f"Illustration of <{model.id}:0.7> as a wizard",
    width=1024,
    height=1024,
    seed=42,
    sampler=generation.SAMPLER_DDIM,
    style_preset="photographic"
)
for resp in results:
    for artifact in resp.artifacts:
        if artifact.finish_reason == generation.FILTER:
            # Adjacent string literals are concatenated as-is, so the space
            # between the two sentences has to be written explicitly.
            warnings.warn(
                "Your request activated the API's safety filters and could not be processed. "
                "Please modify the prompt and try again.")
        if artifact.type == generation.ARTIFACT_IMAGE:
            # Image artifacts carry raw bytes; decode them with PIL for display.
            display(Image.open(io.BytesIO(artifact.binary)))