# Tables and Figures
This notebook contains the code we used to generate some of the tables and figures in the paper.

## Load Packages

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import glob
import json
import textwrap

import requests
from PIL import Image
from io import BytesIO
sns.set('talk')

# Quantitative Results

### Steerability of Image Generation Models

In [31]:
# Summary metrics table; the bar-chart cell and the table cell below read it as `df`.
df = pd.read_csv('data/summary_metrics.csv')

In [None]:
# Order models by mean rating, then drop the aggregate 'Average' row if the CSV has one.
df_sorted = (
    df.sort_values('rating_avg')
      .loc[lambda frame: frame['model'] != 'Average']
)
x = np.arange(len(df_sorted))
width = 0.35
color_palette = sns.color_palette()

# Two side-by-side panels: all-attempt ratings (left), first vs. last attempt (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))

# Left panel: mean rating over all attempts, SEM as error bars.
rating_bars = ax1.bar(
    x,
    df_sorted['rating_avg'],
    yerr=df_sorted['rating_sem'],
    capsize=5,
    label='All attempts',
)
ax1.set_ylabel('Similarity to goal image')

# Right panel: paired bars per model; no y-axis label is set on this panel.
half_width = width / 2
first_bars = ax2.bar(
    x - half_width,
    df_sorted['first_attempt_avg'],
    width,
    label='First attempt',
    yerr=df_sorted['first_attempt_sem'],
    capsize=5,
    color=color_palette[1],
)
last_bars = ax2.bar(
    x + half_width,
    df_sorted['last_attempt_avg'],
    width,
    label='Last attempt',
    yerr=df_sorted['last_attempt_sem'],
    capsize=5,
    color=color_palette[2],
)

# Formatting shared by both panels: model names on the x-axis and a dashed
# reference line at 10 ("perfect similarity") inside a fixed 0-11 range.
for panel in (ax1, ax2):
    panel.set_xticks(x)
    panel.set_xticklabels(df_sorted['model'], rotation=45, ha='right')
    panel.axhline(10, color='darkred', linestyle='--', linewidth=1, label='Perfect similarity')
    panel.set_ylim(0, 11)

ax1.legend(loc='upper right')
ax2.legend()

plt.tight_layout()

plt.savefig('all_attempts_barchart.pdf', bbox_inches='tight', dpi=300)


In [None]:
# Display the per-model columns behind the table: each metric as an avg/sem pair.
df[[
    'model',
    'last_over_first_avg', 'last_over_first_sem',
    'first_matches_goal_avg', 'first_matches_goal_sem',
    'last_matches_goal_avg', 'last_matches_goal_sem',
]]

### Text Steering vs. Image Steering vs. Image Steering with RL

In [33]:
# Hard-coded inputs for the improvement-by-attempt figure: the attempt numbers
# on the x-axis, then one (means, sems) pair per steering condition.
unique_attempts = [1, 2, 3, 4, 5]

means_text_steering = [
    0.0,
    0.002495847021540006,
    0.01192834104100863,
    0.02294766716659069,
    0.02898228106399377,
]
sems_text_steering = [
    0.0,
    0.014831217294403013,
    0.015463105048206157,
    0.014722723334135723,
    0.01510288011453466,
]

means_random_stepsize = [
    0.0,
    0.02797472513980747,
    0.04062674078882111,
    0.04724131867011882,
    0.052637251082414425,
]
sems_random_stepsize = [
    0.0,
    0.004340646216735713,
    0.004960735710954918,
    0.005284981183033871,
    0.005362842282494922,
]

means_new_rl = [
    0.0,
    0.04049745921430917,
    0.05952271690656399,
    0.06805097260351839,
    0.0715540506202599,
]
sems_new_rl = [
    0.0,
    0.0057135848847513545,
    0.006633413178525755,
    0.00707276899252904,
    0.007134218214304489,
]

In [None]:
sns.set("talk")
plt.figure(figsize=(8.6, 6))

# (means, SEMs, legend label) for each steering condition, in drawing order.
steering_curves = (
    (means_text_steering, sems_text_steering, 'Text Steering'),
    (means_random_stepsize, sems_random_stepsize, 'Image Steering (random)'),
    (means_new_rl, sems_new_rl, 'Image Steering (RLHS)'),
)
for curve_means, curve_sems, curve_label in steering_curves:
    plt.errorbar(
        unique_attempts,
        curve_means,
        yerr=curve_sems,
        fmt='o-',
        capsize=5,
        label=curve_label,
    )

plt.xlabel('Attempt Number', fontsize=23)
plt.ylabel('Improvement over first attempt', fontsize=23)

# One tick per attempt, then enlarge the tick labels on both axes.
plt.xticks(unique_attempts)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

plt.grid(True)
plt.legend(loc='upper left')
plt.tight_layout()
plt.savefig('improvement_by_attempt.pdf', bbox_inches='tight', dpi=300)

### Image Steering with RL Results

In [None]:
# Values drawn in the RL results figure below. Each array has 4 rows (the
# plotting cell draws row k in the panel titled "Iteration k+1") and 10 columns
# (plotted against 10 evenly spaced "Proposal step size" values from 0.1 to 1).
# X_stay is drawn with the 'Previous stay' label; its all-zero first row is
# never drawn, because the first panel only shows X_not_stay.
X_stay = np.array([[0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.04642136, 0.05038638, 0.04917789, 0.05267864, 0.05179059,
        0.05574815, 0.05335869, 0.05451705, 0.0551046 , 0.05405061],
       [0.08927953, 0.08747603, 0.08821033, 0.0898371 , 0.08980018,
        0.08727004, 0.09088007, 0.08836148, 0.089135  , 0.09181658],
       [0.09804241, 0.09788405, 0.09963065, 0.10032585, 0.09884076,
        0.09839644, 0.0972433 , 0.09653404, 0.09663419, 0.09819333]])
# X_not_stay is drawn with the 'Previous switch' label, except for its first
# row, which is drawn alone in the first panel under the label 'First turn'.
X_not_stay = np.array([[0.09855232, 0.10071277, 0.10350775, 0.10599847, 0.10902616,
        0.11065273, 0.11191053, 0.11209451, 0.11139376, 0.11168785],
       [0.13162178, 0.13280695, 0.13582534, 0.13732035, 0.13839655,
        0.1380477 , 0.13738207, 0.13916574, 0.13703005, 0.13656748],
       [0.12548979, 0.12919237, 0.12751702, 0.13099927, 0.133716  ,
        0.12982938, 0.12979796, 0.13273502, 0.13228459, 0.13002534],
       [0.12670392, 0.12695364, 0.12751615, 0.1273911 , 0.12708219,
        0.12593613, 0.12495567, 0.12699805, 0.12644036, 0.12650806]])
# Integer counts with the same 4 x 10 layout as X_stay / X_not_stay
# (presumably the number of samples behind each value - TODO confirm).
counts_stay = np.array([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
       [1961, 2199, 2200, 2500, 2362, 2729, 2415, 2587, 2794, 2674],
       [3780, 3661, 3862, 3996, 3843, 3470, 4354, 3746, 3771, 4259],
       [4351, 4581, 5105, 5562, 4800, 4546, 4349, 4139, 4262, 4717]])
counts_not_stay = np.array([[4250, 4761, 5531, 6158, 7524, 8539, 9355, 9755, 8894, 8969],
       [4046, 4233, 4856, 5180, 5126, 5093, 5017, 5588, 5052, 5124],
       [2978, 3386, 2923, 3545, 4047, 3417, 3472, 3811, 3869, 3546],
       [2819, 2797, 2752, 2743, 2780, 2612, 2702, 2732, 2678, 2709]])


def _count_based_se(counts, scale=0.3):
    """Return ``scale / sqrt(counts)`` elementwise, with NaN wherever ``counts`` is 0.

    Dividing by ``sqrt(0)`` would emit a RuntimeWarning and yield ``inf``; a
    cell with no observations has no defined standard error, so it is reported
    as NaN instead.

    Parameters
    ----------
    counts : array-like of non-negative numbers
    scale : float
        Numerator of the standard error (0.3 in this notebook; presumably an
        assumed per-sample standard deviation - TODO confirm).

    Returns
    -------
    numpy.ndarray of float with the same shape as ``counts``.
    """
    counts = np.asarray(counts, dtype=float)
    se = np.full(counts.shape, np.nan)
    observed = counts > 0
    se[observed] = scale / np.sqrt(counts[observed])
    return se


counts_stay_se = _count_based_se(counts_stay)
counts_not_stay_se = _count_based_se(counts_not_stay)


In [None]:
# Create x-axis values
sns.set_context('talk')
x = np.linspace(0.1, 1, 10)

# Create figure with 4 subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

# Colors for consistency
color_palette = sns.color_palette()
stay_color = color_palette[0]
not_stay_color = color_palette[1]

# Create each subplot
for i in range(4):
    # Plot not_stay data
    if i != 0:
        axes[i].errorbar(x, X_not_stay[i], #yerr=counts_not_stay_se[i], 
                        color=not_stay_color, label='Previous switch', 
                        capsize=5, marker='o')
        
        # Plot stay data
        axes[i].errorbar(x, X_stay[i], #yerr=counts_stay_se[i], 
                        color=stay_color, label='Previous stay', 
                        capsize=5, marker='o')
    else:
        axes[i].errorbar(x, X_not_stay[i], #yerr=counts_not_stay_se[i], 
                        color=sns.color_palette()[2], label='First turn', 
                        capsize=5, marker='o')
    
    axes[i].set_xlabel('Proposal step size')
    axes[i].set_ylabel('Value')
    axes[i].legend()
    axes[i].grid(True, alpha=0.3)
    axes[i].set_title(f'Iteration {i+1}')

plt.tight_layout()
plt.savefig('rl_results.pdf', bbox_inches='tight', dpi=300)

# Qualitative Results: Image Generation Attempts

## Text Steering

In [27]:
# Attempts table passed to plot_multiple_attempts below, which reads its rows by position.
attempts = pd.read_csv('data/steering_all_seeds.csv')

In [51]:
def plot_multiple_attempts(df, indices, figsize=(24, 6)):
    """Draw one figure row per entry of ``indices``: a goal image followed by five attempts.

    Parameters
    ----------
    df : pandas.DataFrame
        Read positionally with ``.iloc``. The columns accessed are
        ``goal_image`` and ``generated_image`` (both fetched as URLs),
        ``attempt`` (shown in the panel title) and ``prompt`` (shown below the
        image).
    indices : sequence of int
        For each figure row, the positional index of its first attempt; rows
        ``start`` through ``start + 4`` of ``df`` are shown as the five attempts
        and the goal image is taken from row ``start``.
    figsize : tuple of float
        ``(width, height_per_row)``; the height is multiplied by ``len(indices)``.

    Returns
    -------
    (fig, axes)
        The figure and a 2-D array of axes with shape ``(len(indices), 6)``.

    Raises
    ------
    requests.RequestException
        If an image URL cannot be fetched (HTTP error status or timeout).
    IndexError
        If ``start + 4`` runs past the end of ``df``.
    """
    request_timeout_s = 30

    def _fetch_image(url):
        # Surface HTTP failures as such instead of handing an error page to
        # PIL, and never hang forever on an unresponsive host.
        response = requests.get(url, timeout=request_timeout_s)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))

    num_rows = len(indices)
    # squeeze=False keeps `axes` two-dimensional even when there is one row.
    fig, axes = plt.subplots(num_rows, 6,
                             figsize=(figsize[0], figsize[1] * num_rows),
                             squeeze=False)

    # Thin full-height axes holding a vertical line that separates the goal
    # column from the attempt columns.
    divider_ax = fig.add_axes([0.16, 0.0, 0.001, 0.99])
    divider_ax.axvline(x=0, color='black', alpha=0.75)
    divider_ax.axis('off')

    for row_idx in range(num_rows):
        start_idx = indices[row_idx]
        row_data = df.iloc[start_idx]

        # Goal image in the first column.
        goal_ax = axes[row_idx, 0]
        goal_ax.imshow(_fetch_image(row_data['goal_image']))
        goal_ax.axis('off')
        goal_ax.set_title('Goal Image', pad=10, fontsize=20)

        # The five consecutive rows starting at start_idx fill columns 1-5.
        for col_idx in range(5):
            attempt = df.iloc[start_idx + col_idx]
            attempt_ax = axes[row_idx, col_idx + 1]

            attempt_ax.imshow(_fetch_image(attempt['generated_image']))
            attempt_ax.axis('off')
            attempt_ax.set_title(f'Attempt {attempt["attempt"]}',
                                 fontsize=20,
                                 fontstyle='italic')

            # Quoted prompt, wrapped to the panel width, just below the image.
            prompt_text = textwrap.fill('"' + attempt['prompt'] + '"', width=35, break_long_words=True)
            attempt_ax.text(-0.0, -0.05, prompt_text.strip(),
                            ha='left', va='top',
                            transform=attempt_ax.transAxes,
                            fontsize=20)
    plt.tight_layout()
    # Multi-row figures need extra room between rows for the prompt text.
    if num_rows != 1:
        plt.subplots_adjust(bottom=0.3/num_rows, wspace=0.1, hspace=0.4)

    return fig, axes

In [None]:
# Single-row figure whose five attempts start at positional row 1170 of `attempts`.
indices = [1170]
fig, axes = plot_multiple_attempts(attempts, indices, figsize=(30, 6.75))

In [None]:
# Single-row figure whose five attempts start at positional row 2515 of `attempts`.
indices = [2515]
fig, axes = plot_multiple_attempts(attempts, indices, figsize=(30, 6.75))

In [None]:
# Four-row figure; each index is the positional row in `attempts` where that
# row's five attempts start. figsize gives the height per row (7.25).
indices = [2515, 305, 735, 2260]
fig, axes = plot_multiple_attempts(attempts, indices, figsize=(30, 7.25))

## Image Steering

In [10]:
def plot_multiple_attempts_image_steering(all_attempts, indices, figsize=(24, 6)):
    """Draw one figure row per selected record: goal image, first image, then each round's pick.

    Parameters
    ----------
    all_attempts : sequence of dict
        Records indexed by ``indices``. The keys read from each record are
        ``goal_url``, ``user_prompt`` and ``rounds``; from each round,
        ``selected_image_url`` and ``iteration``; and from the first round
        only, ``generated_images`` (its first element is shown as "Attempt 1").
    indices : sequence of int
        Positions in ``all_attempts`` to plot, one figure row each.
    figsize : tuple of float
        ``(width, height_per_row)``; the height is multiplied by ``len(indices)``.

    Returns
    -------
    (fig, axes)
        The figure and a 2-D array of axes with shape ``(len(indices), 6)``.

    Raises
    ------
    requests.RequestException
        If an image URL cannot be fetched (HTTP error status or timeout).
    IndexError
        If a record has more than four rounds (only six columns are available).
    """
    request_timeout_s = 30

    def _fetch_image(url):
        # Surface HTTP failures as such instead of handing an error page to
        # PIL, and never hang forever on an unresponsive host.
        response = requests.get(url, timeout=request_timeout_s)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))

    num_rows = len(indices)
    # squeeze=False keeps `axes` two-dimensional even when there is one row.
    fig, axes = plt.subplots(num_rows, 6,
                             figsize=(figsize[0], figsize[1] * num_rows),
                             squeeze=False)

    # Hide every panel's frame and ticks up front, so that columns left
    # without an attempt stay blank instead of showing an empty box.
    for panel in axes.flat:
        panel.axis('off')

    # Thin full-height axes holding a vertical line that separates the goal
    # column from the attempt columns.
    divider_ax = fig.add_axes([0.167, 0.0, 0.001, 0.99])
    divider_ax.axvline(x=0, color='black', alpha=0.75)
    divider_ax.axis('off')

    selected_records = [all_attempts[i] for i in indices]

    for row_idx, row_data in enumerate(selected_records):
        # Goal image in the first column.
        goal_ax = axes[row_idx, 0]
        goal_ax.imshow(_fetch_image(row_data['goal_url']))
        goal_ax.set_title('Goal Image', pad=10, fontsize=20)

        rounds = row_data['rounds']

        # "Attempt 1" is the first generated image of the first round, shown
        # together with the user's prompt underneath.
        first_ax = axes[row_idx, 1]
        first_ax.imshow(_fetch_image(rounds[0]['generated_images'][0]))
        first_ax.set_title('Attempt 1', fontsize=20)

        prompt_text = textwrap.fill(row_data['user_prompt'], width=35, break_long_words=False)
        first_ax.text(0., -0.05, prompt_text,
                      ha='left', va='top',
                      transform=first_ax.transAxes,
                      fontsize=20)

        # The image selected in each round fills the columns from index 2 on.
        for attempt_idx, attempt in enumerate(rounds):
            attempt_ax = axes[row_idx, attempt_idx + 2]
            attempt_ax.imshow(_fetch_image(attempt['selected_image_url']))
            attempt_ax.set_title(f'Attempt {1+attempt["iteration"]}',
                                 fontsize=20)

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.3/num_rows, wspace=0.1, hspace=0.4)

    return fig, axes

In [None]:
from ast import literal_eval


def parse(x):
    """Evaluate ``x`` as a Python literal.

    Returns the evaluated object. If ``literal_eval`` raises ``ValueError`` or
    ``SyntaxError``, the error is printed and ``None`` is returned.
    """
    try:
        return literal_eval(x)
    except (ValueError, SyntaxError) as e:
        print(f"Error evaluating string: {e}")
        return None

def _read_attempt_lists(csv_path):
    """Return the parsed 'all_attempts' cell of every row of ``csv_path``.

    Each cell is evaluated with ``parse``; rows for which ``parse`` returned
    None (it prints the reason) are left out, so that the flattening below does
    not fail on them with a TypeError. The CSV is read into a local frame, so
    the notebook-level ``df`` (the summary metrics table) is not overwritten.
    """
    parsed_cells = pd.read_csv(csv_path)['all_attempts'].apply(parse)
    return [cell for cell in parsed_cells if cell is not None]


# Each parsed cell is itself a list of records; flatten them into one list.
all_attempts = [item
                for sublist in _read_attempt_lists('data/img_steering.csv')
                for item in sublist]

# Append the records of the RL variant after the ones loaded above.
all_attempts_2 = _read_attempt_lists('data/img_steering_rl.csv')
all_attempts += [item for sublist in all_attempts_2 for item in sublist]
len(all_attempts)


In [None]:
# Negative values index `all_attempts` from its end; one figure row per index.
indices = [-14, -18, -58, -1]
fig, axes = plot_multiple_attempts_image_steering(all_attempts, indices, figsize=(30, 7))


### Image Steering on Tiles

In [None]:
def parse_json_safely(raw):
    """Decode one 'all_attempts' cell, returning None when it cannot be decoded.

    NOTE(review): this cell called ``parse_json_safely`` but no definition of
    it exists in the notebook, so a fresh "Run All" stopped here with a
    NameError. This is a reconstruction: JSON is tried first (as the name
    suggests), then a Python literal (the format the other image-steering CSVs
    are parsed with). Confirm against the original helper if it is available.
    """
    from ast import literal_eval

    if not isinstance(raw, str):
        # Missing cells are read by pandas as float NaN.
        return None
    try:
        return json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a subclass of ValueError
        pass
    try:
        return literal_eval(raw)
    except (ValueError, SyntaxError):
        return None


# Keep rows whose StartDate parses as a datetime and whose ID is present,
# decode 'all_attempts', then keep only rows whose decoded value has exactly
# two items. Built as one chain so that no column is assigned on a filtered
# slice of another frame.
df = (
    pd.read_csv('data/img_steering_rl_tiles.csv')
    .loc[lambda frame: pd.to_datetime(frame['StartDate'], errors='coerce').notnull()]
    .loc[lambda frame: ~pd.isnull(frame['ID'])]
    .assign(all_attempts=lambda frame: frame['all_attempts'].apply(parse_json_safely))
    .loc[lambda frame: frame['all_attempts'].notna()]
    .loc[lambda frame: frame['all_attempts'].apply(len) == 2]
)

# Flatten the per-row pairs into a single list of records.
all_attempts = [item for sublist in df['all_attempts'].tolist() for item in sublist]
len(all_attempts)


In [None]:
# Positions in the tiles `all_attempts` list to plot, one figure row each.
indices = [68, 87, 2, 19]#, 22]
fig, axes = plot_multiple_attempts_image_steering(all_attempts, indices, figsize=(30, 7.6))


# Blind Steering

In [23]:
# Attempts table (the same CSV that is loaded as `attempts` further up).
data = pd.read_csv('data/steering_all_seeds.csv')

# Blind-steering baselines; the numeric suffix of each name matches its file name.
llm_baseline4 = pd.read_csv('data/blind_steering_4.csv')
llm_baseline7 = pd.read_csv('data/blind_steering_7.csv')
llm_baseline10 = pd.read_csv('data/blind_steering_10.csv')
llm_baseline20 = pd.read_csv('data/blind_steering_20.csv')

In [None]:
# Get the max score for each goal image
max_scores = data.groupby('goal_image')['dreamsim'].max()
first_scores = data.groupby('goal_image')['dreamsim'].first()
np.mean((max_scores - first_scores) / first_scores)

In [25]:
def get_improvement_stats(baseline_df, num_bootstraps=1000, seed=0):
    """Ratio of mean LLM improvement to mean human improvement, with a bootstrap SE.

    Only goal images present both in the notebook-level ``data`` frame (column
    ``goal_image``) and in ``baseline_df`` (column ``goal_url``) are used. For
    each such image:

    * the reference score is the ``dreamsim`` value of the first row of
      ``data`` with ``attempt == 1``;
    * the human improvement is the maximum ``dreamsim`` over all rows of that
      image minus the reference score;
    * the LLM improvement is the maximum ``score`` in ``baseline_df`` minus the
      reference score, floored at 0 because the first prompt is sometimes the
      best one and it is not stored in the baseline table.

    Parameters
    ----------
    baseline_df : pandas.DataFrame
        Must have the columns ``goal_url`` and ``score``.
    num_bootstraps : int
        Number of bootstrap resamples used for the standard error.
    seed : int or None
        Seed of the bootstrap random generator. The fixed default makes the
        reported standard error reproducible across runs; pass ``None`` for a
        fresh, unpredictable stream.

    Returns
    -------
    (mean_ratio, se)
        ``mean(llm_improvements) / mean(human_improvements)`` and the standard
        deviation of that ratio over the bootstrap resamples.

    Raises
    ------
    ValueError
        If ``data`` and ``baseline_df`` have no goal image in common.
    """
    # Sorted so that the per-image arrays (and hence the bootstrap draws) do
    # not depend on the iteration order of a set.
    common_urls = sorted(
        set(data['goal_image']).intersection(baseline_df['goal_url']),
        key=str,
    )
    if not common_urls:
        raise ValueError("no goal image is shared between `data` and `baseline_df`")

    human_improvements = []
    llm_improvements = []
    for common_url in common_urls:
        relevant_data = data[data['goal_image'] == common_url]
        original_score = relevant_data[relevant_data['attempt'] == 1]['dreamsim'].values[0].item()
        human_improvements.append(relevant_data['dreamsim'].max() - original_score)

        relevant_llm = baseline_df[baseline_df['goal_url'] == common_url]
        llm_improvements.append(max(relevant_llm['score'].max() - original_score, 0.0))

    human_improvements = np.array(human_improvements)
    llm_improvements = np.array(llm_improvements)
    mean_ratio = np.mean(llm_improvements) / np.mean(human_improvements)

    # Bootstrap over goal images: every row of `resample_idx` is one resample
    # (with replacement) applied to both arrays, so pairs stay aligned.
    rng = np.random.default_rng(seed)
    num_images = len(human_improvements)
    resample_idx = rng.integers(0, num_images, size=(num_bootstraps, num_images))
    bootstrap_ratios = (
        llm_improvements[resample_idx].mean(axis=1)
        / human_improvements[resample_idx].mean(axis=1)
    )
    se = np.std(bootstrap_ratios)
    return mean_ratio, se

# One (mean_ratio, se) pair per baseline, in the order 4, 7, 10, 20.
baseline_stats = [
    get_improvement_stats(baseline)
    for baseline in (llm_baseline4, llm_baseline7, llm_baseline10, llm_baseline20)
]

# Both lists start with a leading 0 entry, ahead of the four baselines.
improvement_means = [0] + [mean_ratio for mean_ratio, _ in baseline_stats]
improvement_ses = [0] + [se for _, se in baseline_stats]

In [None]:
# x positions for the five entries of improvement_means / improvement_ses.
xs = [0, 4, 7, 10, 20]
ys = improvement_means

plt.figure(figsize=(7, 4))
plt.plot(xs, ys, 'b-')

# One error-bar marker per point, all in the first palette colour.
point_color = sns.color_palette()[0]
for x_val, y_val, y_err in zip(xs, ys, improvement_ses):
    plt.errorbar(x_val, y_val, yerr=y_err, color=point_color, fmt='o', capsize=5)

# Dashed reference line at a ratio of 1.
plt.axhline(y=1, color='k', linestyle='--', label='Human performance')
plt.xticks(xs)
plt.ylim([-0.05, 1.05])
plt.legend()
plt.xlabel('Number of blind prompt rewrites')
# Note placed in data coordinates, below the lower y-limit.
plt.text(10., -0.305, "(Humans have 4 attempts to reprompt)", ha='center', va='top', fontsize=15, color='dimgray')
plt.ylabel("Fraction of human\n performance attained")
plt.title("How close is human steering to blind steering?")
plt.savefig('human_improvement_vs_llm_improvement.pdf', dpi=300, bbox_inches='tight')


# The Steerability vs. Producibility Frontier

In [None]:
# One entry per plotted point. The first point is annotated below as
# 'Constrained to 1 seed' and the last as 'Default model'; the 1e9 entry is
# presumably a stand-in for "no seed constraint" - TODO confirm.
num_seeds = [1, 2, 3, 1e9]
steering_scores = [0.716, 0.689, 0.680, 0.652]
producibility_scores = [0.684, 0.689, 0.697, 0.742]
# The error values are defined here but are not drawn in the figure.
procudibility_errs = [0.01, 0.01, 0.01, 0.01]
producibility_errs = procudibility_errs  # correctly spelled alias; the old name is kept
steering_errs = [0.0129, 0.0127, 0.0121, 0.010]


plt.figure(figsize=(7, 4))
plt.scatter(steering_scores, producibility_scores, s=100)

# Dashed line joining the points in list order.
plt.plot(steering_scores, producibility_scores, 'b--')
plt.xlim(0.645, 0.725)
plt.ylim(0.6726, 0.7526)
plt.xlabel('Steerability score')
plt.ylabel('Producibility score')
plt.text(0.6645, 0.7425, 'Default model', ha='center', va='center', fontsize=15, color='black')
plt.text(0.715, 0.693, 'Constrained\n to 1 seed', ha='center', va='center', fontsize=15, color='black')
plt.title("Tradeoff between steerability and producibility")
# dpi matches every other savefig call in this notebook (it was 35 here).
plt.savefig('frontier.pdf', dpi=300, bbox_inches='tight')

