# Stability.AI A/B Testing Notebook

## Usage Instructions

1. Complete the input fields below as appropriate to specify output and config file locations.
2. Specify your experiment in yaml
  - See the `%%writefile` cell for an example. You can either modify the values in that cell to write a new `test_config.yaml` to parameterize your experiments, or upload an appropriate yaml file.
  - **`defaults`**: settings that will be used across test cases.
  - **`combinatorial_parameters`**: permuted to generate randomized settings which will be shared across a given sample of images shown to the user.
  - **`differentiators`**: specifies the experiment names and what settings are specific to each experimental test case.
    - Parameterize your experiment with two or more test cases. The current example uses three.
    - Each test case (each top level entry below `differentiators`) will be assigned its own api client, so you can use settings like `engine` or `grpc_host` as differentiating attributes
3. Running the "Load a random sample" cell will:
  1. log the results from the previous sample
  2. generate a visualization of the recorded experiment outcomes
  3. load a new random set of images to compare.
    - The ordering in which the respective test cases are displayed is randomized each time the cell is executed (i.e. with each new set of images)
    - Push the button below an image to pick it as your favorite. 
    - Click again to deselect if you change your mind
    - The notebook does not currently constrain the user to only select one option, but that's how we recommend you use it. 
    - When you're satisfied with your selection, execute the cell again to log your feedback and generate a new set of images.

In [None]:
%%capture

########################
# install dependencies #
########################

try:
    import stability_sdk
except ImportError:
    # to do: requirements file
    !pip install stability-sdk
    !pip install omegaconf panel loguru

###########
# imports #
###########

# python stdlib
from collections import Counter
import copy
import importlib.metadata
import io
from itertools import product
import json
import os
from pathlib import Path
import random
import time
import warnings

# google colab stdlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from scipy.stats import beta

# external deps
from loguru import logger
from omegaconf import OmegaConf
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation

########################
# user workspace setup #
########################

# Identifier stored with every logged outcome so results can be attributed to a rater.
notebook_user_id = 'digthatdata' # @param {type: 'string'}
if not notebook_user_id:
    raise ValueError("\n> New fone, who dis?\nPlease identify yourself.")

# Name of the file (inside the project folder) that outcomes are appended to.
explog_fname = "outcomes.log"

# `${active_project}` is an OmegaConf interpolation; it is filled in from the
# `active_project` key of `workspace_cfg` when `project_root` is read.
proj_root_str = '${active_project}'
mount_gdrive = True # @param {type:'boolean'}
if mount_gdrive:
    # Imported lazily so the notebook can still run outside Colab when
    # `mount_gdrive` is unchecked.
    from google.colab import drive
    drive.mount('/content/drive')
    proj_root_str = '/content/drive/MyDrive/AI/StabilityAbTesting/${active_project}'


project_name = 'abtest1' # @param {type:'string'}
if not project_name.strip():
    # No name given: fall back to the current timestamp as the project name.
    project_name = str(time.time())

experiments_configfile_name = 'test_config.yaml' # @param {type:'string'}

# Only a warning: the config may still be created later by the `%%writefile` cell.
if not Path(experiments_configfile_name).exists():
  warnings.warn(
      f"Experiment config file {experiments_configfile_name} "
      "not detected. You may need to upload the experiment config "
      "to this workspace. Alternatively, you can create one by running "
      "the `%%writefile` cell below"
  )

# @markdown if gdrive is mounted and `local_config_priority` is selected, the notebook will check for a
# @markdown local experiment config file. If it finds one, it will load the 
# @markdown experiment from the local file and copy the config to the project folder on gdrive
# @markdown -- overwriting the remote experiment config if it exists.

# @markdown if `local_config_priority` is not selected, the notebook will only
# @markdown look for an experiment config in the project folder (which will be on the gdrive if that's been mounted).


# @markdown The intention here is to facilitate modifying experiment parameters from the notebook's `%%writefile` cell mid-experiment

local_config_priority = True # @param {type:'boolean'}

#################################
# Config for notebook workspace #
#################################

# Single container for the form values above; the later cells read their
# settings from this object.
workspace_cfg = OmegaConf.create({
    'active_project':project_name,
    'project_root':proj_root_str,
    'gdrive_mounted':mount_gdrive,
    'notebook_user_id':notebook_user_id,
    'exp_cfg_fname': experiments_configfile_name,
    'explog_fname':explog_fname,
    'local_config_priority':local_config_priority,
})

# Persist the workspace settings next to the notebook. The already-open handle
# is handed to OmegaConf so the file is only opened for writing once.
with open('config.yaml','w') as fp:
    OmegaConf.save(config=workspace_cfg, f=fp)

#############################
# misc setup                #
# - global variables (yuck) #
# - function definitions    #
#############################

# Location of the experiment config inside the project folder.
exp_cfg_fpath = Path(workspace_cfg.project_root) / workspace_cfg.exp_cfg_fname

# to do -> mirror this to the notebook workspace
# Tally of how often each test case was picked as the favorite.
running_score = Counter()

# this stuff could come from the local config, e.g. to increment
# sample number without needing a prefix
SAMPLE_IDX = 0
# Timestamp taken when this cell runs; used as a prefix for saved image filenames.
RANDOM_PREFIX = str(time.time())

# would be nice if there was a way to query the API for its version
SDK_VERSION = importlib.metadata.version("stability_sdk") # if installing from local repo, use git commit hash instead (or as suffix?)

import pandas as pd

def barchart(running_score):
    """
    Draw a horizontal bar chart of the preference counts.

    running_score: mapping of test case name -> number of times it was picked.
    Bars are ordered by test case name (index-sorted) and the figure is shown
    immediately.
    """
    counts_by_case = pd.Series(running_score).sort_index()
    counts_by_case.plot.barh()
    plt.show()

def posterior_plot(running_score, alpha=0.95):
    """
    Visualize the outcomes recorded so far.

    running_score: mapping of test case name -> number of times it was picked.
    alpha: currently unused; kept so the signature matches
        `posterior_plot_binary`.

    At the moment this always draws the plain bar chart. The two-case
    posterior plot (`posterior_plot_binary`, with the bar chart as fallback
    on NotImplementedError) is not wired in.
    """
    barchart(running_score)


def posterior_plot_binary(running_score, alpha=0.95):
    """
    Plot the Beta posterior for the probability that the first test case is preferred.

    The two counts in `running_score` (each increased by 0.5, i.e. a Jeffreys
    prior) are used as the parameters of a Beta density. The mean, the mode
    and an equal-tailed credible interval are marked on the curve. The
    interval is centred on the median rather than the mode, which is
    accepted here as a close-enough approximation.

    running_score: mapping with exactly two entries, test case name -> count.
    alpha: probability mass covered by the credible interval
        (0.95 gives a 95% interval).

    Raises NotImplementedError if `running_score` does not contain exactly
    two test cases.
    """
    # Only the two-way comparison is modelled; anything else (including fewer
    # than two entries) is reported the same way so callers can fall back.
    if len(running_score) != 2:
        raise NotImplementedError(
            "posterior_plot_binary only supports exactly two test cases"
        )

    _, ax = plt.subplots(1, 1)

    first_case = next(iter(running_score))
    a, b = running_score.values()
    # jeffrey's prior
    a += 0.5
    b += 0.5
    x = np.linspace(0, 1, 100)
    ax.plot(x, beta.pdf(x, a=a, b=b), 'r-', label='beta pdf')
    ax.set_xlim(0, 1)
    ax.get_yaxis().set_visible(False)
    plt.title(f"MAP estimate for likelihood that {first_case} is preferred")

    mu = a / (a + b)
    print(f"mu: {mu}")
    median = beta.median(a=a, b=b)
    print(f"median: {median}")
    plt.vlines(x=mu, ymin=0, ymax=beta.pdf(mu, a, b), linestyles='dashed', color='blue')

    # The closed-form mode only exists for a > 1 and b > 1; otherwise the mean
    # is used as a stand-in marker.
    mode = mu
    if (a > 1) and (b > 1):
        mode = (a - 1) / (a + b - 2)
        print(f"mode: {mode}")
    plt.vlines(x=mode, ymin=0, ymax=beta.pdf(mode, a, b), linestyles='dashed')

    # The coverage is passed positionally: SciPy renamed this keyword from
    # `alpha` to `confidence`, so the positional form works on both old and
    # new releases.
    lwr, upr = beta.interval(alpha, a, b)
    plt.vlines(x=lwr, ymin=0, ymax=beta.pdf(lwr, a, b), linestyles='dashed')
    plt.vlines(x=upr, ymin=0, ymax=beta.pdf(upr, a, b), linestyles='dashed')
    xs2 = np.linspace(lwr, upr, 100)
    plt.fill_between(xs2, beta.pdf(xs2, a=a, b=b), color='r', alpha=0.4)

    plt.show()

In [None]:
%%writefile test_config.yaml

### settings that will be used across test cases.
defaults:
  grpc_host: grpc.stability.ai:443
  # If API key not provided in test_config.yaml, user prompted with getpass
  key:

### randomly permuted to produce settings shared across test cases for a generated set of images
combinatorial_parameters:
  prompt:
    - mom's spaghetti, knees weak, arm's sweaty. but for real, mom's spaghetti is delicious
    - prompt with an optional middle part. {middle} this is the end of the prompt.
  cfg_scale:
    - 7
    - 9
    - 12
    - 15
  steps:
    - 40
    - 50
    - 60

##  specifies the experiment names and what settings are specific to each experimental test case.
# not a fan of this name. maybe call this section "experiments"?
differentiators:
  test_case_A:
    engine: stable-diffusion-512-v2-0
    prompt_chunks:
      middle: ''
  test_case_B:
    engine: stable-diffusion-512-v2-1
    prompt_chunks:
      middle: this is the optional middle of the prompt. it only goes with test_case_B.
  test_case_C:
    engine: stable-diffusion-v1-5
    # If using prompt chunks, all test_cases need at least a prompt_chunks dict with the same keys and 
    # empty strings as values. If empty strings aren't specified, you'll get "None" as the filler chunk.
    prompt_chunks:
      middle: ''
      # Don't do this, results in `middle:"None"`
      # middle:


In [None]:
# @markdown ## Load Experiments

from omegaconf import OmegaConf
import getpass
from stability_sdk import client

import panel as pn
pn.extension()


# `exp_cfg_fpath_out` is where the experiment config lives inside the project
# folder; `exp_cfg_fpath` is where it is looked up first (the local workspace
# when `local_config_priority` is set, otherwise the project folder itself).
exp_cfg_fpath_out = Path(workspace_cfg.project_root) / workspace_cfg.exp_cfg_fname
exp_cfg_fpath = exp_cfg_fpath_out
if workspace_cfg.local_config_priority:
    exp_cfg_fpath = Path(workspace_cfg.exp_cfg_fname)
if exp_cfg_fpath.exists():
    cfg = OmegaConf.load(exp_cfg_fpath)
elif exp_cfg_fpath_out.exists():
    cfg = OmegaConf.load(exp_cfg_fpath_out)
else:
    # Trailing spaces keep the concatenated sentences apart in the message.
    raise RuntimeError(
        f"Experiment config file {workspace_cfg.exp_cfg_fname} not found. "
        "Make sure you've saved or uploaded the file, "
        "and that it's named correctly."
    )

# Copy whichever config was loaded into the project folder, overwriting any
# config already stored there.
Path(workspace_cfg.project_root).mkdir(parents=True, exist_ok=True)
with exp_cfg_fpath_out.open('w') as fp:
    OmegaConf.save(config=cfg, f=fp)


########################
# propagate invariants #
########################

# Every test case inherits each default it does not set itself.
test_case_names = list(cfg.differentiators.keys())
invariant_attr_names = list(cfg.defaults.keys())
for test_case in cfg.differentiators:
    settings = cfg.differentiators[test_case]
    for param in cfg.defaults:
        if param in settings:
            continue
        settings[param] = cfg.defaults[param]


#####################################
# request from user if not provided #
#####################################

# Attributes every test case must end up with; missing or empty ones are
# asked for interactively (input is hidden, as one of them is the API key).
required_attributes = [
    'grpc_host',
    #'api_key'
    'key',
]

for test_case in cfg.differentiators:
    settings = cfg.differentiators[test_case]
    for attr in required_attributes:
        if settings.get(attr):
            continue
        settings[attr] = getpass.getpass(f"[{test_case}] {attr}: ")

##########################################
# Build a client for each differentiator #
##########################################

# Doing this because it's likely the differentiators are different engines or grpc endpoints

# Config attribute -> `StabilityInference` keyword. The experiment config
# names the endpoint `grpc_host`, while the client keyword is `host`; without
# this mapping a configured endpoint would never reach the client. A literal
# `host` entry is still honoured and, being listed last, takes precedence.
client_arg_names = [
    ('grpc_host', 'host'),
    ('host', 'host'),
    ('key', 'key'),
    ('engine', 'engine'),
]

clients = {}
for test_case in cfg.differentiators:
    settings = cfg.differentiators[test_case]
    kargs = {}
    for cfg_name, arg in client_arg_names:
        if cfg_name in settings:
            kargs[arg] = settings[cfg_name]
    clients[test_case] = client.StabilityInference(**kargs)

#######################################
# precompute combinations and shuffle #
#######################################

# Every combination of the combinatorial parameter values, in random order.
gen = product(*cfg.combinatorial_parameters.values())
experiments = [*gen]
random.shuffle(experiments)

# Starts out empty so the sampling cell has nothing to log on its first run.
items = []

# Make sure every test case shows up in the tally, even with zero picks.
running_score.update(dict.fromkeys(cfg.differentiators, 0))



In [None]:
# @markdown # Load a random sample to score preference

# Counts executions of this cell. Note it is incremented before the previous
# sample is logged, so filenames written by `log_items` carry the new value.
SAMPLE_IDX += 1

##########################
# Log experiment outcome #
##########################

# Form options read by `log_items` when deciding whether to write images to
# the project folder.
save_images = False # @param {type:'boolean'}
save_favorite_only = False # @param {type:'boolean'}


# to do: make this not a closure.
def log_items(items):
    """
    Record the user's picks for a sample and append them to the outcome log.

    For each item the state of its 'Favorite' toggle is stored as
    `is_preference`, the image is optionally saved (controlled by the
    `save_images` / `save_favorite_only` form settings), and `running_score`
    is incremented for preferred test cases. The whole batch is then appended
    as a single JSON line to the outcome log in the project folder.

    Reads the notebook globals `workspace_cfg`, `explog_fname`,
    `RANDOM_PREFIX`, `SAMPLE_IDX`, `save_images`, `save_favorite_only` and
    updates `running_score`.

    items: list of dicts holding at least 'img', 'button' and 'test_case'.
        The dicts are modified in place: 'img' and 'button' are removed.

    Raises KeyError if an item has no 'button' or 'img' entry (for example
    because it was already logged). Items processed before the offending one
    have already been modified; nothing is written to the log in that case.
    """
    project_dir = Path(workspace_cfg.project_root)
    recs = []
    for rec in items:
        # Look the toggle up before removing anything, so an item that was
        # already logged fails here without being touched.
        button = rec['button']
        img = rec.pop('img')
        rec.pop('button')

        # The toggle state must be read *before* deciding what to save;
        # until now `is_preference` still holds its initial value.
        rec['is_preference'] = button.value

        if save_favorite_only:
            save_im = rec['is_preference']
        else:
            save_im = save_images
        if save_im:
            img_fname = f"{RANDOM_PREFIX}_{SAMPLE_IDX}_{rec['test_case']}.png"
            img_fpath = project_dir / img_fname
            print(img_fpath)
            rec['img_fpath'] = str(img_fpath)
            img.save(img_fpath)

        if rec['is_preference']:
            running_score[rec['test_case']] += 1
        recs.append(rec)

    outfile = project_dir / explog_fname
    with outfile.open('a') as f:
        json.dump(recs, f)
        f.write('\n')
    logger.debug(running_score)
    
# Log the picks for the sample shown by the previous run of this cell, then
# refresh the chart. Logging stays best-effort: items that are missing
# expected fields (already logged, or left over from an interrupted run) are
# reported and skipped instead of stopping the notebook.
if items:
    try:
        log_items(items)
        posterior_plot(running_score)
    except KeyError as err:
        logger.warning(f"Previous sample was not logged; missing field {err}")


# One seed per sample, shared by every test case.
SEED = random.randrange(0, 4294967295)

blind_test = False # @param {type: "boolean"}

def item_to_ux(item):
    """
    Build the panel column shown to the user for one generated image.

    item: dict with 'img', 'test_case' and 'kwargs'. A 'Favorite' toggle is
        created and stored back on the dict as 'button', and 'is_preference'
        is set to the toggle's current value.

    With `blind_test` enabled the test case heading is omitted and the shared
    `kwargs_exp` settings are shown instead of the item's own kwargs.

    Returns a `pn.Column` holding the (optional) heading, the image, the
    settings text and the toggle.
    """
    picture = item['img']
    case_name = item['test_case']
    case_kwargs = item['kwargs']

    if blind_test:
        widgets = [picture, f"{kwargs_exp}"]
    else:
        widgets = [f"# {case_name}", picture, f"{case_kwargs}"]

    favorite = pn.widgets.Toggle(name='Favorite', button_type='success')
    widgets.append(favorite)
    item['button'] = favorite
    item['is_preference'] = favorite.value
    return pn.Column(*widgets)


# Settings that configure the API client rather than a generation request;
# they are stripped before calling `generate`.
non_generation_arguments = ['grpc_host', 'host', 'engine', 'key']

# Pick one combination of the combinatorial parameters; together with the
# seed it is shared by every test case in this sample.
rec = random.choice(experiments)

keys = cfg.combinatorial_parameters.keys()
kwargs_exp = dict(zip(keys, rec))
kwargs_exp['seed'] = SEED

items = []
for test_case, api in clients.items():
    logger.debug(f"requesting image for {test_case}")
    kwargs_test = copy.deepcopy(kwargs_exp)
    kwargs_diff = cfg.differentiators[test_case]
    kwargs_test.update(kwargs_diff)
    # keep the API key out of anything derived from these settings
    kwargs_test.pop('key', None)
    kwargs_gen = copy.deepcopy(kwargs_test)
    for key in non_generation_arguments:
        kwargs_gen.pop(key, None)
    #########
    # handle prompt_chunks
    # `or {}` covers a `prompt_chunks:` entry left empty in the yaml
    chunks = kwargs_gen.pop('prompt_chunks', {}) or {}
    if '{' in kwargs_gen['prompt']:
        kwargs_gen['prompt'] = kwargs_gen['prompt'].format(**chunks)
    #########

    # Reset per test case so a response without an image can never reuse the
    # image generated for the previous test case.
    img = None
    answers = api.generate(**kwargs_gen)
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed. "
                    "Please modify the prompt and try again.")
            if artifact.type == generation.ARTIFACT_IMAGE:
                img = Image.open(io.BytesIO(artifact.binary))
                img = img.resize((512, 512))
    if img is None:
        warnings.warn(
            f"No image was returned for {test_case}; "
            "it is left out of this sample."
        )
        continue
    items.append({
        'img': img,
        'test_case': test_case,
        'kwargs': kwargs_gen,
        'is_preference': False,
        # additional metadata
        'SDK_VERSION': SDK_VERSION,
        'timestamp': time.time(),
        'user_id': workspace_cfg.notebook_user_id,
        'project_name': workspace_cfg.active_project,
    })

# Shuffle so the on-screen position does not give away the test case.
random.shuffle(items)
pn.Row(*[item_to_ux(it) for it in items])