In [None]:
#@title Install Stability SDK with fine-tuning support
import getpass
import io
import logging
import os
import shutil
import sys
import time
from pathlib import Path
from zipfile import ZipFile

if os.path.exists("../src/stability_sdk"):
    sys.path.append("../src") # use local SDK src
else:
    path = Path('stability-sdk')
    if path.exists():
        shutil.rmtree(path)
        !pip uninstall -y stability-sdk
    !git clone -b "PLATFORM-339" --recurse-submodules https://github.com/Stability-AI/stability-sdk
    !pip install ./stability-sdk

In [None]:
#@title Connect to the Stability API
from stability_sdk.api import Context as ApiContext
from stability_sdk.finetune import *

# @markdown To get your API key visit https://dreamstudio.ai/account
STABILITY_HOST = "grpc-staging.stability.ai:443" #@param {type:"string"}
# Use the STABILITY_KEY environment variable when it is set (useful for non-interactive
# runs); otherwise prompt for the key without echoing it to the notebook.
STABILITY_KEY = os.environ.get("STABILITY_KEY") or getpass.getpass('Enter your API Key')

engine_id = "stable-diffusion-xl-beta-v2-2-2" #@param ["stable-diffusion-xl-beta-v2-2-2"] {type:"string"}

# Create API context to query user info and generate images.
# Generation requests go to the base engine id with an "-ft" suffix appended.
api_context = ApiContext(STABILITY_HOST, STABILITY_KEY, generate_engine_id=engine_id+"-ft")
(balance, pfp) = api_context.get_user_info()
# NOTE(review): `_user_organization_id` is a private attribute of the SDK context; it may
# change without notice — prefer a public accessor if the SDK provides one.
print(f"Logged in org:{api_context._user_organization_id} with balance:{balance}")

# Create a fine-tuning context for model training
ft_context = Context(STABILITY_HOST, STABILITY_KEY)

# Redirect logs to print statements so we can see them in the notebook
class PrintHandler(logging.Handler):
    """Logging handler that writes each formatted record to stdout via print()."""

    def emit(self, record):
        print(self.format(record))

# Make this cell idempotent: re-running it must not stack up handlers (which would print
# every log line several times). Matching is by class *name* because a re-run redefines
# PrintHandler, so handlers created by an earlier run are instances of the old class.
_root_logger = logging.getLogger()
for _stale in [h for h in _root_logger.handlers if type(h).__name__ == "PrintHandler"]:
    _root_logger.removeHandler(_stale)
_root_logger.addHandler(PrintHandler())
_root_logger.setLevel(logging.INFO)


In [None]:
# List fine-tuned models for this user / organization.
# The loop variable is deliberately not named `model`: later cells operate on the global
# `model` (the one being trained), and re-running this cell must not overwrite it.
models = list_models(ft_context, org_id=api_context._user_organization_id)
print(f"Found {len(models)} models")
for listed_model in models:
    print(f"  Model {listed_model.id} {listed_model.name} {listed_model.status}")

In [None]:
#@title Specify folder of images or upload zip file
training_dir = "./train" #@param {type:"string"}

# If the folder is missing, offer a zip upload (only possible inside Colab).
if not os.path.exists(training_dir):
    try:
        from google.colab import files
    except ImportError:
        # Outside Colab there is no upload widget; say so instead of failing silently.
        print(f"Warning: training folder '{training_dir}' not found and Colab file upload is unavailable")
    else:
        upload_res = files.upload()
        if not upload_res:
            raise ValueError("No file was uploaded")
        zip_name = next(iter(upload_res))
        print(f"Received {zip_name}")
        if not zip_name.lower().endswith(".zip"):
            raise ValueError("Uploaded file must be a zip file")

        # Extract next to the notebook into a folder named after the archive.
        training_dir = Path(zip_name).stem
        print(f"Extracting to {training_dir}")
        with ZipFile(io.BytesIO(upload_res[zip_name]), "r") as zf:
            zf.extractall(training_dir)

print(f"Using training images from: {training_dir}")

In [None]:
#@title Perform fine-tuning
model_name = "cat-ft-01" #@param {type:"string"}
training_mode = "object" #@param ["face", "style", "object"] {type:"string"}
object_prompt = "cat" #@param {type:"string"}

# Gather training images. os.listdir returns entries in arbitrary order, so sort them
# to make the submitted image order reproducible between runs.
images = [
    os.path.join(training_dir, filename)
    for filename in sorted(os.listdir(training_dir))
    if os.path.splitext(filename)[1].lower() in ('.png', '.jpg', '.jpeg')
]
# Fail fast here instead of submitting a training job with no data.
if not images:
    raise ValueError(f"No .png/.jpg/.jpeg images found in {training_dir}")
print(f"Found {len(images)} training images in {training_dir}")

# Create the fine-tune model
params = FineTuneParameters(
    name=model_name,
    mode=FineTuneMode(training_mode),
    object_prompt=object_prompt,
    engine_id=engine_id,
)
model = create_model(ft_context, params, images)
print(f"Model {model_name} created.")
print(model)

In [None]:
# Check on training status: poll until the model reaches a terminal state.
POLL_INTERVAL_SEC = 5
start_time = time.time()
while True:
    # Always fetch at least once, so re-running this cell on a finished model still reports it.
    model = get_model(ft_context, model.id)
    elapsed = time.time() - start_time
    print(f"Model {model.name} ({model.id}) status: {model.status} for {elapsed:.2f} sec")
    if model.status in (FineTuneStatus.COMPLETED, FineTuneStatus.FAILED):
        break  # exit immediately instead of sleeping once more after the final status
    time.sleep(POLL_INTERVAL_SEC)

In [None]:
# If fine-tuning fails for some reason, you can resubmit the model.
# After resubmitting, re-run the status cell above to wait for the new training run.
if model.status == FineTuneStatus.FAILED:
    print("Training failed, resubmitting")
    model = resubmit_model(ft_context, model.id)
    print("Resubmitted - re-run the training status cell to monitor progress")

In [None]:
# Generate an image using the fine-tuned model.
# Guard against using a model whose training has not completed (e.g. when the status
# cell was skipped or training failed).
if model.status != FineTuneStatus.COMPLETED:
    raise RuntimeError(f"Model {model.id} is not ready for generation (status: {model.status})")

results = api_context.generate(
    # The prompt references the fine-tuned model by its id in angle brackets.
    prompts=[f"Illustration of <{model.id}> as a wizard"],
    weights=[1],
    finetune_models=[model.id],
    # NOTE(review): presumably the strength of the fine-tuned model's influence — confirm range.
    finetune_weights=[0.7]
)
# NOTE(review): `generation` is not imported explicitly in this notebook; it presumably
# comes from `from stability_sdk.finetune import *` — confirm and import it explicitly.
image = results[generation.ARTIFACT_IMAGE][0]
image

In [None]:
# Models can be updated to change settings before a resubmit, or renamed after training.
# The new name is derived from the `model_name` form value rather than hardcoded, so it
# stays correct when a different model name is chosen above.
update_model(ft_context, model.id, name=f"{model_name}-renamed")

In [None]:
# Delete the model when it's no longer needed.
# Deletes the model identified by the global `model.id` using the fine-tuning context.
# NOTE(review): the local `model` object is not cleared, so later cells that use it will
# refer to a model that no longer exists on the server.
delete_model(ft_context, model.id)