diff --git a/notebooks/getting_started.ipynb b/notebooks/getting_started.ipynb
index 5807cd8..0c38f11 100644
--- a/notebooks/getting_started.ipynb
+++ b/notebooks/getting_started.ipynb
@@ -44,7 +44,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -64,9 +64,18 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 1,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "c:\\Users\\tzehl\\Documents\\programming\\promptolution\\.venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
"source": [
"from promptolution.helpers import run_experiment\n",
"from promptolution.config import Config"
@@ -81,7 +90,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -90,25 +99,28 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"config = Config(\n",
- " task_name=\"agnews\",\n",
- " ds_path=\"../data_sets/cls/agnews/\",\n",
- " n_steps=3,\n",
- " optimizer=\"evopromptga\",\n",
+ " task_name=\"subj\",\n",
+ " ds_path=\"../data_sets/cls/subj/\",\n",
+ " n_steps=8,\n",
+ " optimizer=\"evopromptde\",\n",
" meta_llm=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
" evaluation_llm=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
" downstream_llm=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
" api_token=token,\n",
+ " prepend_exemplars=True,\n",
+ " exemplar_selector=\"random_search\",\n",
+ " n_exemplars=3,\n",
")"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
@@ -117,7 +129,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -147,104 +159,74 @@
" \n",
"
\n",
" \n",
- " | 0 | \n",
- " Classify the news story into one of the follow... | \n",
- " 0.95 | \n",
- "
\n",
- " \n",
- " | 7 | \n",
- " You will be required to classify a news articl... | \n",
- " 0.90 | \n",
- "
\n",
- " \n",
- " | 11 | \n",
- " Classify the topic of the following news as \"W... | \n",
- " 0.90 | \n",
+ " 1 | \n",
+ " evaluate each sentence as either objective or ... | \n",
+ " 0.80 | \n",
"
\n",
" \n",
- " | 12 | \n",
- " Classify news articles into categories (World,... | \n",
- " 0.90 | \n",
+ " 8 | \n",
+ " As a linguist, analyze a statement from a movi... | \n",
+ " 0.80 | \n",
"
\n",
" \n",
- " | 4 | \n",
- " Classify the given news article into one of th... | \n",
- " 0.85 | \n",
+ " 3 | \n",
+ " identify whether the given sentence was expres... | \n",
+ " 0.65 | \n",
"
\n",
" \n",
- " | 6 | \n",
- " Your job is to determine whether a news articl... | \n",
- " 0.85 | \n",
+ " 5 | \n",
+ " Analyze the textual content of a given stateme... | \n",
+ " 0.65 | \n",
"
\n",
" \n",
- " | 13 | \n",
- " Categorize the provided news article according... | \n",
- " 0.85 | \n",
+ " 9 | \n",
+ " determine the classification of each sentence ... | \n",
+ " 0.60 | \n",
"
\n",
" \n",
- " | 1 | \n",
- " Categorize the news article into one of four c... | \n",
- " 0.80 | \n",
+ " 0 | \n",
+ " evaluate each statement as either subjective o... | \n",
+ " 0.50 | \n",
"
\n",
" \n",
" | 2 | \n",
- " Your responsibility is to accurately categoriz... | \n",
- " 0.80 | \n",
+ " Classify the sentence according to its subject... | \n",
+ " 0.40 | \n",
"
\n",
" \n",
- " | 3 | \n",
- " Identify the primary theme of a news article a... | \n",
- " 0.80 | \n",
- "
\n",
- " \n",
- " | 8 | \n",
- " In this task, you are given a news article. Yo... | \n",
- " 0.80 | \n",
- "
\n",
- " \n",
- " | 14 | \n",
- " Accurately categorize news articles into World... | \n",
- " 0.80 | \n",
- "
\n",
- " \n",
- " | 5 | \n",
- " Accurately categorize the provided news articl... | \n",
- " 0.75 | \n",
+ " 6 | \n",
+ " As a classifier, interpret phrases in movie re... | \n",
+ " 0.35 | \n",
"
\n",
" \n",
- " | 9 | \n",
- " Determine the theme of the news item. Choose f... | \n",
- " 0.75 | \n",
+ " 7 | \n",
+ " and\\n\\nshae is about to return to bed when she... | \n",
+ " 0.35 | \n",
"
\n",
" \n",
- " | 10 | \n",
- " Determine the primary theme of the news articl... | \n",
- " 0.70 | \n",
+ " 4 | \n",
+ " Analyze reviews and label them as subjective o... | \n",
+ " 0.30 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " prompt score\n",
- "0 Classify the news story into one of the follow... 0.95\n",
- "7 You will be required to classify a news articl... 0.90\n",
- "11 Classify the topic of the following news as \"W... 0.90\n",
- "12 Classify news articles into categories (World,... 0.90\n",
- "4 Classify the given news article into one of th... 0.85\n",
- "6 Your job is to determine whether a news articl... 0.85\n",
- "13 Categorize the provided news article according... 0.85\n",
- "1 Categorize the news article into one of four c... 0.80\n",
- "2 Your responsibility is to accurately categoriz... 0.80\n",
- "3 Identify the primary theme of a news article a... 0.80\n",
- "8 In this task, you are given a news article. Yo... 0.80\n",
- "14 Accurately categorize news articles into World... 0.80\n",
- "5 Accurately categorize the provided news articl... 0.75\n",
- "9 Determine the theme of the news item. Choose f... 0.75\n",
- "10 Determine the primary theme of the news articl... 0.70"
+ " prompt score\n",
+ "1 evaluate each sentence as either objective or ... 0.80\n",
+ "8 As a linguist, analyze a statement from a movi... 0.80\n",
+ "3 identify whether the given sentence was expres... 0.65\n",
+ "5 Analyze the textual content of a given stateme... 0.65\n",
+ "9 determine the classification of each sentence ... 0.60\n",
+ "0 evaluate each statement as either subjective o... 0.50\n",
+ "2 Classify the sentence according to its subject... 0.40\n",
+ "6 As a classifier, interpret phrases in movie re... 0.35\n",
+ "7 and\\n\\nshae is about to return to bed when she... 0.35\n",
+ "4 Analyze reviews and label them as subjective o... 0.30"
]
},
- "execution_count": 11,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
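For reference only (not part of the patch), the changed notebook cells amount to roughly the following end-to-end flow; the API token is a placeholder, and `run_experiment` is assumed to return the prompt/score DataFrame shown in the output above:

# Consolidated sketch of the notebook after this change (assumptions noted above).
from promptolution.helpers import run_experiment
from promptolution.config import Config

token = "<your-api-token>"  # placeholder; the notebook loads this in an earlier cell

config = Config(
    task_name="subj",
    ds_path="../data_sets/cls/subj/",
    n_steps=8,
    optimizer="evopromptde",
    meta_llm="meta-llama/Meta-Llama-3-8B-Instruct",
    evaluation_llm="meta-llama/Meta-Llama-3-8B-Instruct",
    downstream_llm="meta-llama/Meta-Llama-3-8B-Instruct",
    api_token=token,
    prepend_exemplars=True,          # new in this release: few-shot exemplar search
    exemplar_selector="random_search",
    n_exemplars=3,
)

df = run_experiment(config)          # assumed to return the sorted prompt/score DataFrame
df.head()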
diff --git a/poetry.lock b/poetry.lock
index c3c7763..1938d9c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@@ -1376,6 +1376,17 @@ files = [
{file = "jiter-0.5.0.tar.gz", hash = "sha256:1d916ba875bcab5c5f7d927df998c4cb694d27dceddf3392e58beaf10563368a"},
]
+[[package]]
+name = "joblib"
+version = "1.4.2"
+description = "Lightweight pipelining with Python functions"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"},
+ {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"},
+]
+
[[package]]
name = "jsonpatch"
version = "1.33"
@@ -3361,6 +3372,106 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"]
testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"]
torch = ["safetensors[numpy]", "torch (>=1.10)"]
+[[package]]
+name = "scikit-learn"
+version = "1.5.2"
+description = "A set of python modules for machine learning and data mining"
+optional = false
+python-versions = ">=3.9"
+files = [
+ {file = "scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6"},
+ {file = "scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0"},
+ {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540"},
+ {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8"},
+ {file = "scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113"},
+ {file = "scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445"},
+ {file = "scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de"},
+ {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675"},
+ {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1"},
+ {file = "scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6"},
+ {file = "scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a"},
+ {file = "scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1"},
+ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"},
+ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"},
+ {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"},
+ {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"},
+ {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"},
+ {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"},
+ {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"},
+ {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"},
+ {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"},
+ {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"},
+ {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"},
+ {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca64b3089a6d9b9363cd3546f8978229dcbb737aceb2c12144ee3f70f95684b7"},
+ {file = "scikit_learn-1.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:3bed4909ba187aca80580fe2ef370d9180dcf18e621a27c4cf2ef10d279a7efe"},
+ {file = "scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d"},
+]
+
+[package.dependencies]
+joblib = ">=1.2.0"
+numpy = ">=1.19.5"
+scipy = ">=1.6.0"
+threadpoolctl = ">=3.1.0"
+
+[package.extras]
+benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"]
+build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"]
+docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"]
+examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"]
+install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"]
+maintenance = ["conda-lock (==2.5.6)"]
+tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"]
+
+[[package]]
+name = "scipy"
+version = "1.14.1"
+description = "Fundamental algorithms for scientific computing in Python"
+optional = false
+python-versions = ">=3.10"
+files = [
+ {file = "scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389"},
+ {file = "scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3"},
+ {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0"},
+ {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3"},
+ {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d"},
+ {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69"},
+ {file = "scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad"},
+ {file = "scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5"},
+ {file = "scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675"},
+ {file = "scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2"},
+ {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617"},
+ {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8"},
+ {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37"},
+ {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2"},
+ {file = "scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2"},
+ {file = "scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94"},
+ {file = "scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d"},
+ {file = "scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07"},
+ {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5"},
+ {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc"},
+ {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310"},
+ {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066"},
+ {file = "scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1"},
+ {file = "scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f"},
+ {file = "scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79"},
+ {file = "scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e"},
+ {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73"},
+ {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e"},
+ {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d"},
+ {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e"},
+ {file = "scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06"},
+ {file = "scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84"},
+ {file = "scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417"},
+]
+
+[package.dependencies]
+numpy = ">=1.23.5,<2.3"
+
+[package.extras]
+dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"]
+doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"]
+test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"]
+
[[package]]
name = "seaborn"
version = "0.13.2"
@@ -3525,6 +3636,17 @@ files = [
doc = ["reno", "sphinx"]
test = ["pytest", "tornado (>=4.5)", "typeguard"]
+[[package]]
+name = "threadpoolctl"
+version = "3.5.0"
+description = "threadpoolctl"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"},
+ {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"},
+]
+
[[package]]
name = "tiktoken"
version = "0.7.0"
@@ -4050,4 +4172,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
-content-hash = "7f2a00d58b72f3b7cec0991808ffc354c5f10f87e1c78bd9ed4d5369932d243d"
+content-hash = "6c9aacc81e214e934481f8764b4ecf4db4366f0860952bf045649e2b405f83a5"
diff --git a/promptolution/config.py b/promptolution/config.py
index 10000f5..dac2d9a 100644
--- a/promptolution/config.py
+++ b/promptolution/config.py
@@ -2,7 +2,7 @@
import configparser
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Literal, Optional
@dataclass
@@ -13,13 +13,13 @@ class Config:
either from a config file or from keyword arguments.
Attributes:
- task_name (str): Name of the task.
- ds_path (str): Path to the dataset.
- n_steps (int): Number of optimization steps.
- optimizer (str): Name of the optimizer to use.
- meta_llm (str): Name of the meta language model.
- downstream_llm (str): Name of the downstream language model.
- evaluation_llm (str): Name of the evaluation language model.
+ task_name (str): Name of the task. Should not be None if used.
+ ds_path (str): Path to the dataset. Should not be None if used.
+ n_steps (int): Number of optimization steps. Should not be None if used.
+ optimizer (str): Name of the optimizer to use. Should not be None if used.
+ meta_llm (str): Name of the meta language model. Should not be None if used.
+ downstream_llm (str): Name of the downstream language model. Should not be None if used.
+ evaluation_llm (str): Name of the evaluation language model. Should not be None if used.
init_pop_size (int): Initial population size. Defaults to 10.
logging_dir (str): Directory for logging. Defaults to "logs/run.csv".
experiment_name (str): Name of the experiment. Defaults to "experiment".
@@ -27,10 +27,20 @@ class Config:
donor_random (bool): Whether to use random donor prompts for EvoPromptDE. Defaults to False.
random_seed (int): Random seed for reproducibility. Defaults to 42.
selection_mode (str): Selection mode for EvoPromptGA. Defaults to "random".
- meta_bs (int): Batch size for local meta LLM. Defaults to None.
- downstream_bs (int): Batch size for local downstream LLM. Defaults to None.
- api_token (str): API token for different APIs, as implemented in LLM classes. Defaults to None.
- meta_prompt (str): Prompt template for the meta LLM. Defaults to None.
+ meta_bs (int): Batch size for local meta LLM. Should not be None if the LLM is run locally. Defaults to None.
+ downstream_bs (int): Batch size for local downstream LLM.
+ Should not be None if the LLM is run locally. Defaults to None.
+ api_token (str): API token for different APIs, as implemented in LLM classes.
+ Should not be None if APILLM is used. Defaults to None.
+ meta_prompt (str): Prompt template for the meta LLM.
+ If None is set, default meta_prompts from template.py will be used. Defaults to None.
+ prepend_exemplars (bool): Whether to run exemplar search and prepend few-shot examples. Defaults to False.
+ n_exemplars (int): How many exemplars to prepend. Only used if prepend_exemplars is True. Defaults to 5.
+ exemplar_selector (str): Which exemplar selector to use. Should not be None if prepend_exemplars is True.
+ Defaults to None.
+ n_ds_samples_to_meta (int): How many dataset examples to show to the meta LLM
+ (not applicable to every optimizer). Defaults to 2.
+ n_eval_samples (int): How many dataset samples to use when evaluating prompts. Defaults to 20.
"""
task_name: str = None
@@ -46,11 +56,16 @@ class Config:
include_task_desc: bool = True
donor_random: bool = False
random_seed: int = 42
- selection_mode: Optional[str] = "random"
+ selection_mode: Optional[Literal["random", "wheel", "tour"]] = "random"
meta_bs: Optional[int] = None
downstream_bs: Optional[int] = None
api_token: Optional[str] = None
meta_prompt: Optional[str] = None
+ prepend_exemplars: Optional[bool] = False
+ n_exemplars: Optional[int] = 5
+ exemplar_selector: Optional[str] = None
+ n_ds_samples_to_meta: Optional[int] = 2
+ n_eval_samples: Optional[int] = 20
def __post_init__(self):
"""Validate the configuration after initialization."""
diff --git a/promptolution/exemplar_selectors/__init__.py b/promptolution/exemplar_selectors/__init__.py
new file mode 100644
index 0000000..f234373
--- /dev/null
+++ b/promptolution/exemplar_selectors/__init__.py
@@ -0,0 +1,33 @@
+"""Module for exemplar selectors."""
+
+from typing import Literal
+
+from promptolution.exemplar_selectors.random_search_selector import RandomSearchSelector
+from promptolution.exemplar_selectors.random_selector import RandomSelector
+from promptolution.predictors.base_predictor import BasePredictor
+from promptolution.tasks.base_task import BaseTask
+
+SELECTOR_MAP = {
+ "random": RandomSelector,
+ "random_search": RandomSearchSelector,
+}
+
+
+def get_exemplar_selector(name: Literal["random", "random_search"], task: BaseTask, predictor: BasePredictor):
+ """Factory function to get an exemplar selector based on the given name.
+
+ Args:
+ name (str): The name of the exemplar selector to instantiate.
+ task (BaseTask): The task object to be passed to the selector.
+ predictor (BasePredictor): The predictor object to be passed to the selector.
+
+ Returns:
+ BaseExemplarSelector: An instance of the requested exemplar selector.
+
+ Raises:
+ ValueError: If the requested selector name is not found.
+ """
+ if name not in SELECTOR_MAP:
+ raise ValueError(f"Exemplar selector '{name}' not found. Available selectors: {list(SELECTOR_MAP.keys())}")
+
+ return SELECTOR_MAP[name](task, predictor)
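Outside of run_optimization, the factory can also be called directly; a minimal sketch, assuming the same building blocks helpers.py uses (dataset path, model id, token, and the prompt string are placeholders):

# Sketch only; constructor arguments follow the signatures shown in this diff.
from promptolution.exemplar_selectors import get_exemplar_selector
from promptolution.llms import get_llm
from promptolution.predictors import Classificator
from promptolution.tasks.classification_tasks import ClassificationTask

task = ClassificationTask(dataset_path="../data_sets/cls/subj/", task_id="subj")
llm = get_llm("meta-llama/Meta-Llama-3-8B-Instruct", token="<api-token>")
predictor = Classificator(llm, classes=task.classes)

selector = get_exemplar_selector("random_search", task, predictor)
few_shot_prompt = selector.select_exemplars(
    "Classify the sentence as subjective or objective.",  # placeholder prompt
    n_examples=3,
)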
diff --git a/promptolution/exemplar_selectors/base_exemplar_selector.py b/promptolution/exemplar_selectors/base_exemplar_selector.py
new file mode 100644
index 0000000..dd96e7b
--- /dev/null
+++ b/promptolution/exemplar_selectors/base_exemplar_selector.py
@@ -0,0 +1,41 @@
+"""Base class for exemplar selectors."""
+
+from abc import ABC, abstractmethod
+from typing import Any, List, Tuple
+
+from promptolution.predictors.base_predictor import BasePredictor
+from promptolution.tasks.base_task import BaseTask
+
+
+class BaseExemplarSelector(ABC):
+ """An abstract base class for exemplar selectors.
+
+ This class defines the basic interface and common functionality
+ that all exemplar selectors should implement.
+ """
+
+ def __init__(self, task: BaseTask, predictor: BasePredictor):
+ """Initialize the BaseExemplarSelector.
+
+ Args:
+ task (BaseTask): An object representing the task to be performed.
+ predictor (BasePredictor): An object capable of making predictions based on prompts.
+ """
+ self.task = task
+ self.predictor = predictor
+
+ @abstractmethod
+ def select_exemplars(self, prompt: str, n_examples: int = 5) -> str:
+ """Select exemplars based on the given prompt.
+
+ Args:
+ prompt (str): The input prompt to base the exemplar selection on.
+ n_examples (int, optional): The number of exemplars to select. Defaults to 5.
+
+ Returns:
+ str: A new prompt that includes the original prompt and the selected exemplars.
+
+ Raises:
+ NotImplementedError: This method should be implemented by subclasses.
+ """
+ raise NotImplementedError("This method should be implemented by subclasses.")
diff --git a/promptolution/exemplar_selectors/random_search_selector.py b/promptolution/exemplar_selectors/random_search_selector.py
new file mode 100644
index 0000000..005fef8
--- /dev/null
+++ b/promptolution/exemplar_selectors/random_search_selector.py
@@ -0,0 +1,39 @@
+"""Random search exemplar selector."""
+
+from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector
+
+
+class RandomSearchSelector(BaseExemplarSelector):
+ """A selector that uses random search to find the best set of exemplars.
+
+ This class implements a strategy that generates multiple sets of random examples,
+ evaluates their performance, and selects the best performing set.
+ """
+
+ def select_exemplars(self, prompt, n_examples: int = 5, n_trials: int = 5):
+ """Select exemplars using a random search strategy.
+
+ This method generates multiple sets of random examples, evaluates their performance
+ when combined with the original prompt, and returns the best performing set.
+
+ Args:
+ prompt (str): The input prompt to base the exemplar selection on.
+ n_examples (int, optional): The number of exemplars to select in each trial. Defaults to 5.
+ n_trials (int, optional): The number of random trials to perform. Defaults to 5.
+
+ Returns:
+ str: The best performing prompt, which includes the original prompt and the selected exemplars.
+ """
+ best_score = 0
+ best_prompt = prompt
+
+ for _ in range(n_trials):
+ _, seq = self.task.evaluate(prompt, self.predictor, n_samples=n_examples, subsample=True, return_seq=True)
+ prompt_with_examples = "\n\n".join([prompt] + seq) + "\n\n"
+ # evaluate the exemplar-augmented prompt as a few-shot prompt
+ score = self.task.evaluate(prompt_with_examples, self.predictor, subsample=True)
+ if score > best_score:
+ best_score = score
+ best_prompt = prompt_with_examples
+
+ return best_prompt
diff --git a/promptolution/exemplar_selectors/random_selector.py b/promptolution/exemplar_selectors/random_selector.py
new file mode 100644
index 0000000..5fe01ae
--- /dev/null
+++ b/promptolution/exemplar_selectors/random_selector.py
@@ -0,0 +1,46 @@
+"""Random exemplar selector."""
+
+from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector
+from promptolution.predictors.base_predictor import BasePredictor
+from promptolution.tasks.base_task import BaseTask
+
+
+class RandomSelector(BaseExemplarSelector):
+ """A selector that randomly selects correct exemplars.
+
+ This class implements a strategy that generates random examples and selects
+ those that are evaluated as correct until the desired number of exemplars is reached.
+ """
+
+ def __init__(self, task: BaseTask, predictor: BasePredictor, desired_score: int = 1):
+ """Initialize the RandomSelector.
+
+ Args:
+ task (BaseTask): An object representing the task to be performed.
+ predictor (BasePredictor): An object capable of making predictions based on prompts.
+ desired_score (int, optional): The desired score for the exemplars. Defaults to 1.
+ """
+ super().__init__(task, predictor)
+ self.desired_score = desired_score
+
+ def select_exemplars(self, prompt, n_examples: int = 5):
+ """Select exemplars using a random selection strategy.
+
+ This method generates random examples and selects those that are evaluated as correct
+ (score == self.desired_score) until the desired number of exemplars is reached.
+
+ Args:
+ prompt (str): The input prompt to base the exemplar selection on.
+ n_examples (int, optional): The number of exemplars to select. Defaults to 5.
+
+ Returns:
+ str: A new prompt that includes the original prompt and the selected exemplars.
+ """
+ examples = []
+ while len(examples) < n_examples:
+ score, seq = self.task.evaluate(prompt, self.predictor, n_samples=1, return_seq=True)
+ if score == self.desired_score:
+ examples.append(seq[0])
+ prompt = "\n\n".join([prompt] + examples) + "\n\n"
+
+ return prompt
diff --git a/promptolution/helpers.py b/promptolution/helpers.py
index 11f942d..9d776a9 100644
--- a/promptolution/helpers.py
+++ b/promptolution/helpers.py
@@ -6,6 +6,7 @@
import pandas as pd
from promptolution.config import Config
+from promptolution.exemplar_selectors import get_exemplar_selector
from promptolution.llms import get_llm
from promptolution.optimizers import get_optimizer
from promptolution.predictors import Classificator
@@ -50,10 +51,15 @@ def run_optimization(config: Config):
initial_prompts=init_pop,
task=task,
predictor=predictor,
+ n_eval_samples=config.n_eval_samples,
)
prompts = optimizer.optimize(n_steps=config.n_steps)
+ if config.prepend_exemplars:
+ selector = get_exemplar_selector(config.exemplar_selector, task, predictor)
+ prompts = [selector.select_exemplars(p, n_examples=config.n_exemplars) for p in prompts]
+
return prompts
@@ -72,7 +78,7 @@ def run_evaluation(config: Config, prompts: List[str]):
llm = get_llm(config.evaluation_llm, token=config.api_token)
predictor = Classificator(llm, classes=task.classes)
- scores = task.evaluate(prompts, predictor)
+ scores = task.evaluate(prompts, predictor, subsample=True, n_samples=config.n_eval_samples)
df = pd.DataFrame(dict(prompt=prompts, score=scores))
df = df.sort_values("score", ascending=False)
diff --git a/promptolution/llms/api_llm.py b/promptolution/llms/api_llm.py
index df3410e..1c34709 100644
--- a/promptolution/llms/api_llm.py
+++ b/promptolution/llms/api_llm.py
@@ -72,14 +72,11 @@ def __init__(self, model_id: str, token: str = None):
ValueError: If an unknown model identifier is provided.
"""
if "claude" in model_id:
- ANTHROPIC_API_KEY = open("anthropictoken.txt", "r").read() if token is None else token
- self.model = ChatAnthropic(model=model_id, api_key=ANTHROPIC_API_KEY)
+ self.model = ChatAnthropic(model=model_id, api_key=token)
elif "gpt" in model_id:
- OPENAI_API_KEY = open("openaitoken.txt", "r").read() if token is None else token
- self.model = ChatOpenAI(model=model_id, api_key=OPENAI_API_KEY)
+ self.model = ChatOpenAI(model=model_id, api_key=token)
else:
- DEEPINFRA_API_KEY = open("deepinfratoken.txt", "r").read() if token is None else token
- self.model = ChatDeepInfra(model_name=model_id, deepinfra_api_token=DEEPINFRA_API_KEY)
+ self.model = ChatDeepInfra(model_name=model_id, deepinfra_api_token=token)
def get_response(self, prompts: List[str]) -> List[str]:
"""Get responses for a list of prompts in a synchronous manner.
diff --git a/promptolution/optimizers/__init__.py b/promptolution/optimizers/__init__.py
index 11d2956..ae4ed93 100644
--- a/promptolution/optimizers/__init__.py
+++ b/promptolution/optimizers/__init__.py
@@ -66,7 +66,7 @@ def get_optimizer(
if config.optimizer == "opro":
prompt_template = OPRO_TEMPLATE
prompt_template = config.meta_prompt if config.meta_prompt else prompt_template
- n_samples = kwargs.get("n_samples", config.n_samples if config is not None else None)
+ n_samples = kwargs.get("n_samples", config.n_ds_samples_to_meta if config is not None else None)
return Opro(prompt_template=prompt_template, n_samples=n_samples, *args, **kwargs)
raise ValueError(f"Unknown optimizer: {config.optimizer}")
diff --git a/promptolution/optimizers/base_optimizer.py b/promptolution/optimizers/base_optimizer.py
index effc329..2cac685 100644
--- a/promptolution/optimizers/base_optimizer.py
+++ b/promptolution/optimizers/base_optimizer.py
@@ -26,12 +26,20 @@ class BaseOptimizer(ABC):
predictor (optional): Predictor for prompt evaluation. Defaults to None.
"""
- def __init__(self, initial_prompts: list[str], task: BaseTask, callbacks: list[Callable] = [], predictor=None):
+ def __init__(
+ self,
+ initial_prompts: list[str],
+ task: BaseTask,
+ callbacks: list[Callable] = [],
+ predictor=None,
+ n_eval_samples=20,
+ ):
"""Initialize the BaseOptimizer."""
self.prompts = initial_prompts
self.task = task
self.callbacks = callbacks
self.predictor = predictor
+ self.n_eval_samples = n_eval_samples
@abstractmethod
def optimize(self, n_steps: int) -> List[str]:
diff --git a/promptolution/optimizers/evoprompt_de.py b/promptolution/optimizers/evoprompt_de.py
index 7772a03..17d74b3 100644
--- a/promptolution/optimizers/evoprompt_de.py
+++ b/promptolution/optimizers/evoprompt_de.py
@@ -51,7 +51,7 @@ def optimize(self, n_steps: int) -> List[str]:
Returns:
List[str]: The optimized list of prompts after all steps.
"""
- self.scores = self.task.evaluate(self.prompts, self.predictor)
+ self.scores = self.task.evaluate(self.prompts, self.predictor, subsample=True, n_samples=self.n_eval_samples)
self.prompts = [prompt for _, prompt in sorted(zip(self.scores, self.prompts), reverse=True)]
self.scores = sorted(self.scores, reverse=True)
@@ -80,7 +80,9 @@ def optimize(self, n_steps: int) -> List[str]:
child_prompts = self.meta_llm.get_response(meta_prompts)
child_prompts = [prompt.split("<prompt>")[-1].split("</prompt>")[0].strip() for prompt in child_prompts]
- child_scores = self.task.evaluate(child_prompts, self.predictor)
+ child_scores = self.task.evaluate(
+ child_prompts, self.predictor, subsample=True, n_samples=self.n_eval_samples
+ )
for i in range(len(self.prompts)):
if child_scores[i] > self.scores[i]:
diff --git a/promptolution/optimizers/evoprompt_ga.py b/promptolution/optimizers/evoprompt_ga.py
index 2393ef5..2ec789b 100644
--- a/promptolution/optimizers/evoprompt_ga.py
+++ b/promptolution/optimizers/evoprompt_ga.py
@@ -56,7 +56,9 @@ def optimize(self, n_steps: int) -> List[str]:
List[str]: The optimized list of prompts after all steps.
"""
# get scores from task
- self.scores = self.task.evaluate(self.prompts, self.predictor).tolist()
+ self.scores = self.task.evaluate(
+ self.prompts, self.predictor, subsample=True, n_samples=self.n_eval_samples
+ ).tolist()
# sort prompts by score
self.prompts = [prompt for _, prompt in sorted(zip(self.scores, self.prompts), reverse=True)]
self.scores = sorted(self.scores, reverse=True)
@@ -64,7 +66,12 @@ def optimize(self, n_steps: int) -> List[str]:
for _ in range(n_steps):
new_prompts = self._crossover(self.prompts, self.scores)
prompts = self.prompts + new_prompts
- scores = self.scores + self.task.evaluate(new_prompts, self.predictor).tolist()
+ scores = (
+ self.scores
+ + self.task.evaluate(
+ new_prompts, self.predictor, subsample=True, n_samples=self.n_eval_samples
+ ).tolist()
+ )
# sort scores and prompts
self.prompts = [prompt for _, prompt in sorted(zip(scores, prompts), reverse=True)][: len(self.prompts)]
diff --git a/promptolution/optimizers/opro.py b/promptolution/optimizers/opro.py
index ddf504e..b2fa645 100644
--- a/promptolution/optimizers/opro.py
+++ b/promptolution/optimizers/opro.py
@@ -38,7 +38,9 @@ def __init__(self, meta_llm: BaseLLM, n_samples: int = 2, prompt_template: str =
super().__init__(**args)
self.meta_prompt = self.meta_prompt.replace("<task_desc>", self.task.description)
- self.scores = [self.task.evaluate(p, self.predictor) for p in self.prompts]
+ self.scores = [
+ self.task.evaluate(p, self.predictor, subsample=True, n_samples=self.n_eval_samples) for p in self.prompts
+ ]
def _sample_examples(self):
"""Sample examples from the task dataset with their label.
@@ -78,7 +80,7 @@ def optimize(self, n_steps: int) -> List[str]:
prompt = self.meta_llm.get_response([meta_prompt])[0]
prompt = prompt.split("<prompt>")[-1].split("</prompt>")[0].strip()
- score = self.task.evaluate(prompt, self.predictor)
+ score = self.task.evaluate(prompt, self.predictor, subsample=True, n_samples=self.n_eval_samples)
self.prompts.append(prompt)
self.scores.append(score)
diff --git a/promptolution/predictors/base_predictor.py b/promptolution/predictors/base_predictor.py
index 941ee9a..eea7f74 100644
--- a/promptolution/predictors/base_predictor.py
+++ b/promptolution/predictors/base_predictor.py
@@ -1,10 +1,12 @@
"""Base module for predictors."""
from abc import abstractmethod
-from typing import List
+from typing import List, Tuple
import numpy as np
+from promptolution.llms.base_llm import BaseLLM
+
class BasePredictor:
"""Abstract base class for predictors in the promptolution library.
@@ -12,37 +14,30 @@ class BasePredictor:
This class defines the interface that all concrete predictor implementations should follow.
Attributes:
- model_id (str): Identifier for the model used by the predictor.
- classes (List[str]): List of possible class labels for classification tasks.
+ llm: The language model used for generating predictions.
+
Methods:
predict: An abstract method that should be implemented by subclasses
to make predictions based on prompts and input data.
"""
- def __init__(self, model_id, classes, *args, **kwargs):
+ def __init__(self, llm: BaseLLM):
"""Initialize the BasePredictor.
Args:
- model_id (str): Identifier for the model to use.
- classes (List[str]): List of possible class labels.
- *args: Variable length argument list.
- **kwargs: Arbitrary keyword arguments.
+ llm: The language model to use for predictions.
+ classes (List[str]): The list of valid class labels.
"""
- self.model_id = model_id
- self.classes = classes
+ self.llm = llm
- @abstractmethod
- def predict(
- self,
- prompts: List[str],
- xs: np.ndarray,
- ) -> np.ndarray:
+ def predict(self, prompts: List[str], xs: np.ndarray, return_seq: bool = False) -> np.ndarray:
"""Abstract method to make predictions based on prompts and input data.
Args:
prompts (List[str]): List of prompts to use for prediction.
xs (np.ndarray): Array of input data.
+ return_seq (bool, optional): Whether to also return the generated sequences. Defaults to False.
Returns:
np.ndarray: Array of predictions.
@@ -50,6 +45,24 @@ def predict(
Raises:
NotImplementedError: If not implemented by a subclass.
"""
+ if isinstance(prompts, str):
+ prompts = [prompts]
+
+ outputs = self.llm.get_response([prompt + "\n" + x for prompt in prompts for x in xs])
+ preds = self._extract_preds(outputs, (len(prompts), len(xs)))
+
+ if return_seq:
+ return preds, [i + "\n" + o for i, o in zip(xs, outputs)]
+
+ return preds
+
+ def _extract_preds(self, preds: List[str], shape: Tuple[int, int]) -> np.ndarray:
+ """Extract class labels from the predictions, based on the list of valid class labels.
+
+ Args:
+ preds: The raw predictions from the language model.
+ shape: The shape of the output array: (n_prompts, n_samples).
+ """
raise NotImplementedError
diff --git a/promptolution/predictors/classificator.py b/promptolution/predictors/classificator.py
index 7cf6fe9..f33bfc6 100644
--- a/promptolution/predictors/classificator.py
+++ b/promptolution/predictors/classificator.py
@@ -1,6 +1,6 @@
"""Module for classification predictors."""
-from typing import List
+from typing import List, Tuple
import numpy as np
@@ -11,7 +11,10 @@ class Classificator(BasePredictor):
"""A predictor class for classification tasks using language models.
This class takes a language model and a list of classes, and provides a method
- to predict classes for given prompts and input data.
+ to predict classes for given prompts and input data. The class labels are extracted
+ by matching the words in the prediction with the list of valid class labels.
+ The first occurrence of a valid class label in the prediction is used as the predicted class.
+ If no valid class label is found, the first class label in the list is used as the default prediction.
Attributes:
llm: The language model used for generating predictions.
@@ -28,39 +31,19 @@ def __init__(self, llm, classes, *args, **kwargs):
llm: The language model to use for predictions.
classes (List[str]): The list of valid class labels.
"""
- self.llm = llm
+ super().__init__(llm)
self.classes = classes
- def predict(
- self,
- prompts: List[str],
- xs: np.ndarray,
- ) -> np.ndarray:
- """Predict classes for given prompts and input data.
-
- This method generates predictions using the language model and then
- extracts the predicted class from the model's output.
+ def _extract_preds(self, preds: List[str], shape: Tuple[int, int]) -> np.ndarray:
+ """Extract class labels from the predictions, based on the list of valid class labels.
Args:
- prompts (List[str]): The list of prompts to use for prediction.
- xs (np.ndarray): The input data array.
-
- Returns:
- np.ndarray: A 2D array of predicted classes, with shape (len(prompts), len(xs)).
-
- Note:
- The method concatenates each prompt with each input data point,
- passes it to the language model, and then extracts the first word
- in the response that matches a class in self.classes.
+ preds: The raw predictions from the language model.
+ shape: The shape of the output array: (n_prompts, n_samples).
"""
- if isinstance(prompts, str):
- prompts = [prompts]
-
- preds = self.llm.get_response([prompt + "\n" + x for prompt in prompts for x in xs])
-
response = []
for pred in preds:
- predicted_class = ""
+ predicted_class = self.classes[0] # use first class as default pred
for word in pred.split(" "):
word = "".join([c for c in word if c.isalnum()])
if word in self.classes:
@@ -69,5 +52,5 @@ def predict(
response.append(predicted_class)
- response = np.array(response).reshape(len(prompts), len(xs))
+ response = np.array(response).reshape(*shape)
return response
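To make the refactor concrete, here is a small self-contained rehearsal of the extraction rule that _extract_preds now implements (the class list and sample outputs are made up for illustration):

import numpy as np

# Standalone re-implementation of the matching rule described above, for illustration only.
classes = ["subjective", "objective"]
raw_outputs = ["That sentence is clearly subjective.", "no label in this output"]

extracted = []
for pred in raw_outputs:
    predicted_class = classes[0]  # first class acts as the default prediction
    for word in pred.split(" "):
        word = "".join(c for c in word if c.isalnum())  # strip punctuation
        if word in classes:
            predicted_class = word
            break  # first occurrence of a valid label wins
    extracted.append(predicted_class)

print(np.array(extracted).reshape(1, 2))  # [['subjective' 'subjective']]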
diff --git a/promptolution/tasks/classification_tasks.py b/promptolution/tasks/classification_tasks.py
index 9da7a17..f37deec 100644
--- a/promptolution/tasks/classification_tasks.py
+++ b/promptolution/tasks/classification_tasks.py
@@ -2,9 +2,10 @@
import json
from pathlib import Path
-from typing import Dict, List, Literal, Optional
+from typing import Callable, Dict, List, Literal, Optional
import numpy as np
+from sklearn.metrics import accuracy_score
from promptolution.predictors.base_predictor import BasePredictor
from promptolution.tasks.base_task import BaseTask
@@ -25,8 +26,9 @@ class ClassificationTask(BaseTask):
xs (Optional[np.ndarray]): Input data for the task.
ys (Optional[np.ndarray]): Ground truth labels for the task.
classes (Optional[List]): List of possible class labels.
- split (Literal["dev", "test"]): Dataset split to use.
seed (int): Random seed for reproducibility.
+ split (Literal["dev", "test"]): Dataset split to use.
+ metric (Callable): Metric to use as an evaluation score for the prompts.
Inherits from:
BaseTask: The base class for tasks in the promptolution library.
@@ -38,6 +40,7 @@ def __init__(
task_id: str = "Classification Task",
seed: int = 42,
split: Literal["dev", "test"] = "dev",
+ metric: Callable = accuracy_score,
):
"""Initialize the ClassificationTask.
@@ -46,6 +49,7 @@ def __init__(
dataset_path (str): Path to the dataset description JSON file.
seed (int, optional): Random seed for reproducibility. Defaults to 42.
split (Literal["dev", "test"], optional): Dataset split to use. Defaults to "dev".
+ metric (Callable): Metric to use as an evaluation score for the prompts. Defaults to sklearn's accuracy.
"""
self.task_id: str = task_id
self.path: Path = dataset_path
@@ -56,6 +60,7 @@ def __init__(
self.ys: Optional[np.ndarray] = None
self.classes: Optional[List] = None
self.split: Literal["dev", "test"] = split
+ self.metric = metric
self._parse_task()
self.reset_seed(seed)
@@ -95,7 +100,12 @@ def _parse_task(self):
self.ys = np.array(ys)
def evaluate(
- self, prompts: List[str], predictor: BasePredictor, n_samples: int = 20, subsample: bool = True
+ self,
+ prompts: List[str],
+ predictor: BasePredictor,
+ n_samples: int = 20,
+ subsample: bool = False,
+ return_seq: bool = False,
) -> np.ndarray:
"""Evaluate a set of prompts using a given predictor.
@@ -103,7 +113,9 @@ def evaluate(
prompts (List[str]): List of prompts to evaluate.
predictor (BasePredictor): Predictor to use for evaluation.
n_samples (int, optional): Number of samples to use if subsampling. Defaults to 20.
- subsample (bool, optional): Whether to use subsampling. Defaults to True.
+ subsample (bool, optional): Whether to use subsampling.
+ If set to True, samples a different subset per call. Defaults to False.
+ return_seq (bool, optional): Whether to also return the generated sequences. Defaults to False.
Returns:
np.ndarray: Array of accuracy scores for each prompt.
@@ -120,10 +132,17 @@ def evaluate(
ys_subsample = self.ys[indices]
# Make predictions on the subsample
- preds = predictor.predict(prompts, xs_subsample)
+ preds = predictor.predict(prompts, xs_subsample, return_seq=return_seq)
+
+ if return_seq:
+ preds, seqs = preds
+
+ scores = np.array([self.metric(ys_subsample, pred) for pred in preds])
+
+ if return_seq:
+ return scores, seqs
- # Calculate accuracy: number of correct predictions / total number of predictions per prompt
- return np.mean(preds == ys_subsample, axis=1)
+ return scores
def reset_seed(self, seed: int = None):
"""Reset the random seed."""
diff --git a/pyproject.toml b/pyproject.toml
index 659f202..8635296 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "promptolution"
-version = "0.2.0"
+version = "1.0.0"
description = ""
authors = ["Tom Zehle, Moritz Schlager, Timo Heiß"]
readme = "README.md"
@@ -14,6 +14,7 @@ langchain-core = "^0.2.29"
langchain-community = "^0.2.12"
pandas = "^2.2.2"
tqdm = "^4.66.5"
+scikit-learn = "^1.5.2"
[tool.poetry.group.dev.dependencies]
matplotlib = "^3.9.2"