# Optimizing Prompts with **Automatic Prompt Engineer** (APE)

This notebook demonstrates how to use Automatic Prompt Engineer (APE) (arxiv link) to optimize prompts for text generation. In its simplest form, APE takes a dataset (a list of inputs and a list of corresponding outputs) and a prompt template, and searches for a prompt to fill that template so that the model produces the given outputs from the given inputs.

APE does this in two steps. First, it uses a language model to generate a set of candidate prompts. Then, it scores each candidate with a prompt evaluation function and returns the prompt with the highest score.
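
The generate-then-select idea can be sketched in a few lines. This is only an illustration of the control flow, not the library's implementation: `select_best_prompt`, `toy_generate`, and `toy_score` are hypothetical names, and the toy scorer stands in for a real model-based evaluation.

```python
from typing import Callable, List, Tuple


def select_best_prompt(
    candidates: List[str],
    score_fn: Callable[[str], float],
) -> Tuple[str, float]:
    """Score every unique candidate and return the highest-scoring one."""
    # Drop exact duplicates while keeping the original order.
    unique = list(dict.fromkeys(candidates))
    scored = [(prompt, score_fn(prompt)) for prompt in unique]
    return max(scored, key=lambda pair: pair[1])


def toy_generate() -> List[str]:
    # Hypothetical stand-in for sampling candidates from a language model.
    return [
        "Repeat the word.",
        "Write the opposite of the word.",
        "Write the opposite of the word.",
    ]


def toy_score(prompt: str) -> float:
    # Hypothetical stand-in for a likelihood-based evaluation.
    return 1.0 if "opposite" in prompt else 0.0


best_prompt, best_score = select_best_prompt(toy_generate(), toy_score)
print(best_prompt, best_score)
```

In practice both the generator and the scorer would call a language model; only the selection logic carries over unchanged.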

In [1]:
# First, let's define a simple dataset consisting of words and their antonyms.
words = ["sane", "direct", "informally", "unpopular", "subtractive", "nonresidential",
    "inexact", "uptown", "incomparable", "powerful", "gaseous", "evenly", "formality",
    "deliberately", "off"]
antonyms = ["insane", "indirect", "formally", "popular", "additive", "residential",
    "exact", "downtown", "comparable", "powerless", "solid", "unevenly", "informality",
    "accidentally", "on"]

In [2]:
# Now, we need to define the format of the prompt that we are using.

eval_template = \
"""Instruction: [PROMPT]
Input: [INPUT]
Output: [OUTPUT]"""
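
To make the placeholders concrete, here is a minimal sketch of how such a template could be filled for one example. The helper `fill_template` is hypothetical and uses plain string replacement; the library performs its own substitution internally and may do it differently.

```python
eval_template = """Instruction: [PROMPT]
Input: [INPUT]
Output: [OUTPUT]"""


def fill_template(template: str, prompt: str, input_: str, output: str = "") -> str:
    """Substitute the three placeholders with concrete strings."""
    return (
        template.replace("[PROMPT]", prompt)
        .replace("[INPUT]", input_)
        .replace("[OUTPUT]", output)
    )


filled = fill_template(
    eval_template,
    prompt="Write an antonym to the following word.",
    input_="sane",
    output="insane",
)
print(filled)
```

A likelihood-based evaluation would then ask how probable the text after `Output:` is, given everything before it.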

In [3]:
# Now, let's use APE to find prompts that generate antonyms for each word.
from automatic_prompt_engineer import ape

result, demo_fn = ape.simple_ape(
    dataset=(words, antonyms),
    eval_template=eval_template,
    eval_model="facebook/opt-1.3b",
    prompt_gen_model="facebook/opt-1.3b",
    prompt_gen_mode="opt",
    prompt_gen_batch_size=100,
    eval_batch_size=250
)

Config before update: {'generation': {'num_subsamples': 5, 'num_demos': 5, 'num_prompts_per_subsample': 10, 'model': {'name': 'OPT', 'batch_size': 100, 'gpt_config': {'model': 'facebook/opt-1.3b', 'temperature': 0.9, 'max_tokens': 50, 'top_p': 0.9, 'frequency_penalty': 0.0, 'presence_penalty': 0.0}}}, 'evaluation': {'method': 'bandits', 'rounds': 20, 'num_prompts_per_round': 0.334, 'bandit_method': 'ucb', 'bandit_config': {'c': 1.0}, 'base_eval_method': 'likelihood', 'base_eval_config': {'num_samples': 5, 'num_few_shot': 5, 'model': {'name': 'GPT_forward', 'batch_size': 250, 'gpt_config': {'model': 'facebook/opt-1.3b', 'temperature': 0.7, 'max_tokens': 200, 'top_p': 1.0, 'frequency_penalty': 0.0, 'presence_penalty': 0.0}}}}, 'demo': {'model': {'name': 'GPT_forward', 'batch_size': 500, 'gpt_config': {'model': 'text-davinci-002', 'temperature': 0.7, 'max_tokens': 200, 'top_p': 1.0, 'frequency_penalty': 0.0, 'presence_penalty': 0.0}}}}
Config after update: {'generation': {'num_subsamples'

  0%|                                                                            | 0/1 [00:00<?, ?it/s]


<generator object OPT.__generate_text.<locals>.<genexpr> at 0x7f009b422900>


100%|███████████████████████████████████████████████████████████████████| 1/1 [02:16<00:00, 136.34s/it]


Model returned 10 prompts. Deduplicating...
Model: {'name': 'GPT_forward', 'batch_size': 500, 'gpt_config': {'model': 'text-davinci-002', 'temperature': 0.7, 'max_tokens': 200, 'top_p': 1.0, 'frequency_penalty': 0.0, 'presence_penalty': 0.0}}
Deduplicated to 10 prompts.
Evaluating prompts...
Model: {'name': 'OPT', 'batch_size': 500, 'gpt_config': {'model': 'facebook/opt-1.3b', 'temperature': 0.7, 'max_tokens': 200, 'top_p': 1.0, 'frequency_penalty': 0.0, 'presence_penalty': 0.0}}


Evaluating prompts:   5%|██▎                                            | 1/20 [00:04<01:34,  4.95s/it]

input_logprobs type: <class 'list'> 
input_logprobs: [-6.04153321821069, -8.191540450506887, -3.9753765498344413, -3.2273789393350234, -5.871518315137005, -6.755135425012705, -2.7258399203428296, -4.072195061215885, -3.6079127414925876, -5.194091638441127, -2.99706993364682, -8.879172181282495, -0.05822003787611586, -1.348471323926764, -4.384973980261906, -5.849977565183112, -6.322347394548752, -2.4541524874662555, -1.0974799098658143, -9.538201230607848, -6.369417576842661, -0.06670313484693761, -3.915289697789944, -0.6542572982419781, -0.6425408778366558, -0.28847222862232497, -3.1319832759320634, -1.2148583829416988, -15.854948853423245, -4.9897377849924025, -3.226749553024832, -0.016500058646735923, -4.300266892913412, -0.17910350627494168, -0.1986627889894784, -4.540434683269391, -0.3087609129811979, -9.295749330374015, -7.282380144695635, -0.013664668518190917, -0.9239931798471402, -0.024573341067275, -0.004056200331300653, -3.209387099123652, -0.12097881092313394, -0.04535571519

Evaluating prompts: 100%|██████████████████████████████████████████████| 20/20 [00:59<00:00,  2.96s/it]
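
The configuration printed above sets the evaluation method to `bandits` with `bandit_method: 'ucb'`, `c: 1.0`, 20 rounds, and `num_prompts_per_round: 0.334`. The sketch below shows a generic upper-confidence-bound (UCB) selection loop under those settings. It is not the library's code: the function names, the Gaussian noise model, and the reading of 0.334 as the fraction of prompts evaluated per round are all assumptions made for illustration.

```python
import math
import random
from typing import List, Tuple


def ucb_select(counts: List[int], means: List[float], c: float, k: int) -> List[int]:
    """Return the indices of the k arms with the highest UCB score."""
    total = sum(counts)

    def score(i: int) -> float:
        if counts[i] == 0:
            return float("inf")  # Unseen arms are always tried first.
        return means[i] + c * math.sqrt(math.log(total) / counts[i])

    return sorted(range(len(counts)), key=score, reverse=True)[:k]


def run_bandit(
    true_scores: List[float],
    rounds: int = 20,
    fraction: float = 0.334,  # Assumed meaning of num_prompts_per_round.
    c: float = 1.0,
    noise: float = 0.5,       # Hypothetical evaluation noise.
    seed: int = 0,
) -> Tuple[List[int], List[float]]:
    rng = random.Random(seed)
    n = len(true_scores)
    k = max(1, math.ceil(fraction * n))
    counts = [0] * n
    means = [0.0] * n
    for _ in range(rounds):
        for i in ucb_select(counts, means, c, k):
            # A noisy score stands in for evaluating the prompt on a data subsample.
            reward = true_scores[i] + rng.gauss(0.0, noise)
            counts[i] += 1
            means[i] += (reward - means[i]) / counts[i]  # Running average.
    return counts, means


# Ten hypothetical prompts with made-up "true" scores.
counts, means = run_bandit([-4.0, -3.5, -3.0, -2.5, -2.3, -3.9, -3.2, -2.8, -3.7, -4.2])
print(counts)
```

The point of a bandit here is to spend most of the evaluation budget on promising prompts instead of scoring every candidate on every subsample.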





In [4]:
# Let's see the results.
print(result)

score: prompt
----------------
-2.34: I gave a friend an instruction. Based on the instruction they produced the following input-output pairs:

Input: deliberately
Output: accidentally

Input: gaseous
Output: solid

Input: powerful
Output: powerless

Input: incomparable
Output: comparable

Input: uptown
Output: downtown

The instruction was to give the output of the first input pair to the input of the second input pair and to receive the output of the second input pair as a result. I then gave them two more instructions:

Input: gaseous
Output: solid

-2.34: I gave a friend an instruction. Based on the instruction they produced the following input-output pairs:

Input: deliberately
Output: accidentally

Input: gaseous
Output: solid

Input: powerful
Output: powerless

Input: incomparable
Output: comparable

Input: uptown
Output: downtown

The instruction was to use the correct input-output pairs in the following order:

Input: gaseous
Output: solid

Input: powerful
Output: powerless

I

Let's compare with a prompt written by a human:

"*Write an antonym to the following word.*"

In [5]:
from automatic_prompt_engineer import ape

manual_prompt = "Write an antonym to the following word."

human_result = ape.simple_eval(
    dataset=(words, antonyms),
    eval_template=eval_template,
    prompts=[manual_prompt],
    eval_model="facebook/opt-1.3b"
)

input_logprobs type: <class 'list'> 
input_logprobs: [-6.041535137894718, -8.191533373468227, -3.975373888354679, -3.227389056005671, -7.111651321954099, -2.9326585625111723, -7.920237050846291, -6.883951970460335, -0.13750898484275476, -1.479471824272397, -1.2375481058671203, -3.588140262831139, -2.2355612430772616, -1.9828868777346877, -1.0564682116750672, -8.687120660898874, -0.40303626924352925, -13.672876027515763, -1.5074674948461584, -0.8112016482235674, -0.892408129314914, -0.009294831798180303, -6.457041334535585, -6.041535137894718, -8.191533373468227, -3.975373888354679, -3.227389056005671, -7.111651321954099, -2.9326585625111723, -7.920237050846291, -6.883951970460335, -0.13750898484275476, -1.479471824272397, -1.2375481058671203, -3.588140262831139, -2.2355612430772616, -1.9828868777346877, -1.0564682116750672, -8.687120660898874, -0.40303626924352925, -11.379039731911057, -3.192938345953957, -2.651258216316079, -2.688038108661779, -0.011933916143942203, -5.048263183460186

In [6]:
print(human_result)

log(p): prompt
----------------
-3.90: Write an antonym to the following word.
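
To put the two numbers side by side: the best APE prompt scored -2.34 and the human-written prompt scored -3.90, so the APE prompt scores higher (less negative). The snippet below computes the gap, assuming both values are natural-log probabilities on the same scale; the two tables label them differently (`score` and `log(p)`), so treat this as a rough reading rather than an exact comparison.

```python
import math

ape_score = -2.34    # Best score printed by print(result).
human_score = -3.90  # Score printed by print(human_result).

# Difference in log-probability (assumed to be in nats).
gap = ape_score - human_score

# Under that assumption, exp(gap) is the ratio of the two likelihoods.
ratio = math.exp(gap)

print(f"gap = {gap:.2f}, likelihood ratio = {ratio:.2f}")
```

If the scores are averaged over several examples, as the `likelihood` setting in the config suggests, the ratio describes a typical example rather than the dataset as a whole.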

