In [1]:
#@title Install Stability SDK with fine-tuning support
import os
import shutil
import sys
from pathlib import Path

if os.path.exists("../src/stability_sdk"):
    sys.path.append("../src") # use local SDK src
else:
    path = Path('stability-sdk')
    if path.exists():
        shutil.rmtree(path)
        !pip uninstall -y stability-sdk
    !git clone -b "PLATFORM-339" --recurse-submodules https://github.com/Stability-AI/stability-sdk
    !pip install ./stability-sdk

In [2]:
#@title Connect to the Stability API
import getpass
from stability_sdk.api import Context as ApiContext
from stability_sdk.finetune import *

import os

# @markdown To get your API key visit https://dreamstudio.ai/account
STABILITY_HOST = "grpc-staging.stability.ai:443" #@param {type:"string"}
# Prefer a key supplied via the STABILITY_KEY environment variable so the notebook
# can run non-interactively (e.g. Restart & Run All); otherwise prompt for it
# without echoing it to the screen or saving it in the notebook.
STABILITY_KEY = os.environ.get("STABILITY_KEY") or getpass.getpass('Enter your API Key')

api_context = ApiContext(STABILITY_HOST, STABILITY_KEY)
# get_user_info() is unpacked into two values; only the balance is reported here.
(balance, pfp) = api_context.get_user_info()
print(f"Logged in org:{api_context._user_organization_id} with balance:{balance}")

# Create a fine-tuning context
ft_context = Context(STABILITY_HOST, STABILITY_KEY)

In [None]:
# List fine-tuned models for this user / organization.
# The loop variable is deliberately not named `model`: later cells operate on the
# `model` global (get / update / delete). Re-running this listing must not silently
# retarget them at whichever model happened to be listed last.
models = list_models(ft_context, org_id=api_context._user_organization_id)
print(f"Found {len(models)} models")
for listed_model in models:
    print(f"  Model {listed_model.id} {listed_model.name} {listed_model.status}")

In [None]:
#@title Perform fine-tuning
training_image_path = "./train" #@param {type:"string"}
model_name = "cat-ft-01" #@param {type:"string"}
training_mode = "object" #@param ["none", "face", "style", "object"] {type:"string"}
object_name = "cat" #@param {type:"string"}
engine_id = "stable-diffusion-512-v2-1" #@param {type:"string"}

import os

# Gather training images.
# - sorted(): os.listdir returns entries in arbitrary order; sorting makes the
#   uploaded image order reproducible across runs and machines.
# - isfile(): skip directories whose names happen to end in an image extension.
training_image_extensions = ('.png', '.jpg', '.jpeg')
images = []
for filename in sorted(os.listdir(training_image_path)):
    candidate = os.path.join(training_image_path, filename)
    if os.path.splitext(filename)[1].lower() in training_image_extensions and os.path.isfile(candidate):
        images.append(candidate)

# Fail fast locally rather than submitting a fine-tune job with no training data.
if not images:
    raise ValueError(
        f"No {'/'.join(training_image_extensions)} training images found in {training_image_path!r}"
    )

# Create the fine-tune model
params = FineTuneParameters(
    name=model_name,
    mode=FineTuneMode(training_mode),
    object_name=object_name,
    engine_id=engine_id,
)
model = create_model(ft_context, params, images)
print(model)

In [None]:
# Re-fetch the model record so `model.status` reflects the current training state
# reported by the fine-tuning service.
model = get_model(ft_context, model.id)
print("Model {} {} {}".format(model.id, model.name, model.status))

In [None]:
# A model whose training FAILED can be resubmitted. Its settings can be changed
# with update_model beforehand if needed.
training_failed = model.status == FineTuneStatus.FAILED
if training_failed:
    print("Training failed, resubmitting")
    model = resubmit_model(ft_context, model.id)

In [None]:
# Generate an image using the fine-tuned model.
# Pass the tracked model's name instead of a hard-coded literal, so this cell keeps
# working when a different `model_name` is chosen in the fine-tuning cell.
# NOTE(review): `generation` is never imported explicitly in this notebook;
# presumably it arrives via `from stability_sdk.finetune import *` — confirm.
results = api_context.generate(["a cute fluffy cat"], [1.0], finetune_model=model.name)
image = results[generation.ARTIFACT_IMAGE][0]
image

In [None]:
# A model can be updated to change its settings before a resubmit, or renamed after training.
renamed_model_name = "cat-ft-01-renamed"
update_model(ft_context, model.id, name=renamed_model_name)

In [None]:
# Clean up: delete the fine-tuned model once it is no longer needed.
model_id_to_delete = model.id
delete_model(ft_context, model_id_to_delete)