diff --git a/notebooks/getting_started.ipynb b/notebooks/getting_started.ipynb index 5807cd8..0c38f11 100644 --- a/notebooks/getting_started.ipynb +++ b/notebooks/getting_started.ipynb @@ -44,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -64,9 +64,18 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\tzehl\\Documents\\programming\\promptolution\\.venv\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "from promptolution.helpers import run_experiment\n", "from promptolution.config import Config" @@ -81,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -90,25 +99,28 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "config = Config(\n", - " task_name=\"agnews\",\n", - " ds_path=\"../data_sets/cls/agnews/\",\n", - " n_steps=3,\n", - " optimizer=\"evopromptga\",\n", + " task_name=\"subj\",\n", + " ds_path=\"../data_sets/cls/subj/\",\n", + " n_steps=8,\n", + " optimizer=\"evopromptde\",\n", " meta_llm=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n", " evaluation_llm=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n", " downstream_llm=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n", " api_token=token,\n", + " prepend_exemplars=True,\n", + " exemplar_selector=\"random_search\",\n", + " n_exemplars=3,\n", ")" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -117,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -147,104 +159,74 @@ " \n", " \n", " \n", - " 0\n", - " Classify the news story into one of the follow...\n", - " 0.95\n", - " \n", - " \n", - " 7\n", - " You will be required to classify a news articl...\n", - " 0.90\n", - " \n", - " \n", - " 11\n", - " Classify the topic of the following news as \"W...\n", - " 0.90\n", + " 1\n", + " evaluate each sentence as either objective or ...\n", + " 0.80\n", " \n", " \n", - " 12\n", - " Classify news articles into categories (World,...\n", - " 0.90\n", + " 8\n", + " As a linguist, analyze a statement from a movi...\n", + " 0.80\n", " \n", " \n", - " 4\n", - " Classify the given news article into one of th...\n", - " 0.85\n", + " 3\n", + " identify whether the given sentence was expres...\n", + " 0.65\n", " \n", " \n", - " 6\n", - " Your job is to determine whether a news articl...\n", - " 0.85\n", + " 5\n", + " Analyze the textual content of a given stateme...\n", + " 0.65\n", " \n", " \n", - " 13\n", - " Categorize the provided news article according...\n", - " 0.85\n", + " 9\n", + " determine the classification of each sentence ...\n", + " 0.60\n", " \n", " \n", - " 1\n", - " Categorize the news article into one of four c...\n", - " 0.80\n", + " 0\n", + " evaluate each statement as either subjective o...\n", + " 0.50\n", " \n", " \n", " 2\n", - " Your responsibility is to accurately categoriz...\n", - " 0.80\n", + " Classify the sentence according to its subject...\n", + " 0.40\n", " \n", " \n", - " 3\n", - " Identify the 
primary theme of a news article a...\n", - " 0.80\n", - " \n", - " \n", - " 8\n", - " In this task, you are given a news article. Yo...\n", - " 0.80\n", - " \n", - " \n", - " 14\n", - " Accurately categorize news articles into World...\n", - " 0.80\n", - " \n", - " \n", - " 5\n", - " Accurately categorize the provided news articl...\n", - " 0.75\n", + " 6\n", + " As a classifier, interpret phrases in movie re...\n", + " 0.35\n", " \n", " \n", - " 9\n", - " Determine the theme of the news item. Choose f...\n", - " 0.75\n", + " 7\n", + " and\\n\\nshae is about to return to bed when she...\n", + " 0.35\n", " \n", " \n", - " 10\n", - " Determine the primary theme of the news articl...\n", - " 0.70\n", + " 4\n", + " Analyze reviews and label them as subjective o...\n", + " 0.30\n", " \n", " \n", "\n", "" ], "text/plain": [ - " prompt score\n", - "0 Classify the news story into one of the follow... 0.95\n", - "7 You will be required to classify a news articl... 0.90\n", - "11 Classify the topic of the following news as \"W... 0.90\n", - "12 Classify news articles into categories (World,... 0.90\n", - "4 Classify the given news article into one of th... 0.85\n", - "6 Your job is to determine whether a news articl... 0.85\n", - "13 Categorize the provided news article according... 0.85\n", - "1 Categorize the news article into one of four c... 0.80\n", - "2 Your responsibility is to accurately categoriz... 0.80\n", - "3 Identify the primary theme of a news article a... 0.80\n", - "8 In this task, you are given a news article. Yo... 0.80\n", - "14 Accurately categorize news articles into World... 0.80\n", - "5 Accurately categorize the provided news articl... 0.75\n", - "9 Determine the theme of the news item. Choose f... 0.75\n", - "10 Determine the primary theme of the news articl... 0.70" + " prompt score\n", + "1 evaluate each sentence as either objective or ... 0.80\n", + "8 As a linguist, analyze a statement from a movi... 0.80\n", + "3 identify whether the given sentence was expres... 0.65\n", + "5 Analyze the textual content of a given stateme... 0.65\n", + "9 determine the classification of each sentence ... 0.60\n", + "0 evaluate each statement as either subjective o... 0.50\n", + "2 Classify the sentence according to its subject... 0.40\n", + "6 As a classifier, interpret phrases in movie re... 0.35\n", + "7 and\\n\\nshae is about to return to bed when she... 0.35\n", + "4 Analyze reviews and label them as subjective o... 0.30" ] }, - "execution_count": 11, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } diff --git a/poetry.lock b/poetry.lock index c3c7763..1938d9c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. 
[[package]] name = "aiohappyeyeballs" @@ -1376,6 +1376,17 @@ files = [ {file = "jiter-0.5.0.tar.gz", hash = "sha256:1d916ba875bcab5c5f7d927df998c4cb694d27dceddf3392e58beaf10563368a"}, ] +[[package]] +name = "joblib" +version = "1.4.2" +description = "Lightweight pipelining with Python functions" +optional = false +python-versions = ">=3.8" +files = [ + {file = "joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6"}, + {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, +] + [[package]] name = "jsonpatch" version = "1.33" @@ -3361,6 +3372,106 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] +[[package]] +name = "scikit-learn" +version = "1.5.2" +description = "A set of python modules for machine learning and data mining" +optional = false +python-versions = ">=3.9" +files = [ + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6"}, + {file = "scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540"}, + {file = "scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8"}, + {file = "scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445"}, + {file = "scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675"}, + {file = "scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1"}, + {file = "scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a"}, + {file = "scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, + {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, + {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = 
"scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, + {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, + {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca64b3089a6d9b9363cd3546f8978229dcbb737aceb2c12144ee3f70f95684b7"}, + {file = "scikit_learn-1.5.2-cp39-cp39-win_amd64.whl", hash = "sha256:3bed4909ba187aca80580fe2ef370d9180dcf18e621a27c4cf2ef10d279a7efe"}, + {file = "scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d"}, +] + +[package.dependencies] +joblib = ">=1.2.0" +numpy = ">=1.19.5" +scipy = ">=1.6.0" +threadpoolctl = ">=3.1.0" + +[package.extras] +benchmark = ["matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "pandas (>=1.1.5)"] +build = ["cython (>=3.0.10)", "meson-python (>=0.16.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)"] +docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.3.4)", "memory_profiler (>=0.57.0)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pydata-sphinx-theme (>=0.15.3)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)", "sphinx (>=7.3.7)", "sphinx-copybutton (>=0.5.2)", "sphinx-design (>=0.5.0)", "sphinx-design (>=0.6.0)", "sphinx-gallery (>=0.16.0)", "sphinx-prompt (>=1.4.0)", "sphinx-remove-toctrees (>=1.0.0.post1)", "sphinxcontrib-sass (>=0.3.4)", "sphinxext-opengraph (>=0.9.1)"] +examples = ["matplotlib (>=3.3.4)", "pandas (>=1.1.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.17.2)", "seaborn (>=0.9.0)"] +install = ["joblib (>=1.2.0)", "numpy (>=1.19.5)", "scipy (>=1.6.0)", "threadpoolctl (>=3.1.0)"] +maintenance = ["conda-lock (==2.5.6)"] +tests = ["black (>=24.3.0)", "matplotlib (>=3.3.4)", "mypy (>=1.9)", "numpydoc (>=1.2.0)", "pandas (>=1.1.5)", "polars (>=0.20.30)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pyarrow (>=12.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.2.1)", "scikit-image (>=0.17.2)"] + +[[package]] +name = "scipy" +version = "1.14.1" +description = "Fundamental algorithms for scientific computing in Python" +optional = false +python-versions = ">=3.10" +files = [ + {file = "scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = 
"sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0"}, + {file = "scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d"}, + {file = "scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69"}, + {file = "scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad"}, + {file = "scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617"}, + {file = "scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37"}, + {file = "scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2"}, + {file = "scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2"}, + {file = "scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5"}, + {file = "scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310"}, + {file = "scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066"}, + {file = "scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1"}, + {file = "scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79"}, + {file = 
"scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73"}, + {file = "scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d"}, + {file = "scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e"}, + {file = "scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06"}, + {file = "scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84"}, + {file = "scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417"}, +] + +[package.dependencies] +numpy = ">=1.23.5,<2.3" + +[package.extras] +dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] +doc = ["jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.13.1)", "jupytext", "matplotlib (>=3.5)", "myst-nb", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<=7.3.7)", "sphinx-design (>=0.4.0)"] +test = ["Cython", "array-api-strict (>=2.0)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja", "pooch", "pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] + [[package]] name = "seaborn" version = "0.13.2" @@ -3525,6 +3636,17 @@ files = [ doc = ["reno", "sphinx"] test = ["pytest", "tornado (>=4.5)", "typeguard"] +[[package]] +name = "threadpoolctl" +version = "3.5.0" +description = "threadpoolctl" +optional = false +python-versions = ">=3.8" +files = [ + {file = "threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467"}, + {file = "threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107"}, +] + [[package]] name = "tiktoken" version = "0.7.0" @@ -4050,4 +4172,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "7f2a00d58b72f3b7cec0991808ffc354c5f10f87e1c78bd9ed4d5369932d243d" +content-hash = "6c9aacc81e214e934481f8764b4ecf4db4366f0860952bf045649e2b405f83a5" diff --git a/promptolution/config.py b/promptolution/config.py index 10000f5..dac2d9a 100644 --- a/promptolution/config.py +++ b/promptolution/config.py @@ -2,7 +2,7 @@ import configparser from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional @dataclass @@ -13,13 +13,13 @@ class Config: either from a config file or from keyword arguments. Attributes: - task_name (str): Name of the task. - ds_path (str): Path to the dataset. - n_steps (int): Number of optimization steps. - optimizer (str): Name of the optimizer to use. - meta_llm (str): Name of the meta language model. - downstream_llm (str): Name of the downstream language model. - evaluation_llm (str): Name of the evaluation language model. + task_name (str): Name of the task. 
Should not be None if used.
+        ds_path (str): Path to the dataset. Should not be None if used.
+        n_steps (int): Number of optimization steps. Should not be None if used.
+        optimizer (str): Name of the optimizer to use. Should not be None if used.
+        meta_llm (str): Name of the meta language model. Should not be None if used.
+        downstream_llm (str): Name of the downstream language model. Should not be None if used.
+        evaluation_llm (str): Name of the evaluation language model. Should not be None if used.
         init_pop_size (int): Initial population size. Defaults to 10.
         logging_dir (str): Directory for logging. Defaults to "logs/run.csv".
         experiment_name (str): Name of the experiment. Defaults to "experiment".
@@ -27,10 +27,20 @@ class Config:
         donor_random (bool): Whether to use random donor prompts for EvoPromptDE. Defaults to False.
         random_seed (int): Random seed for reproducibility. Defaults to 42.
         selection_mode (str): Selection mode for EvoPromptGA. Defaults to "random".
-        meta_bs (int): Batch size for local meta LLM. Defaults to None.
-        downstream_bs (int): Batch size for local downstream LLM. Defaults to None.
-        api_token (str): API token for different APIs, as implemented in LLM classes. Defaults to None.
-        meta_prompt (str): Prompt template for the meta LLM. Defaults to None.
+        meta_bs (int): Batch size for local meta LLM. Should not be None if the LLM is run locally. Defaults to None.
+        downstream_bs (int): Batch size for local downstream LLM.
+            Should not be None if the LLM is run locally. Defaults to None.
+        api_token (str): API token for different APIs, as implemented in LLM classes.
+            Should not be None if APILLM is used. Defaults to None.
+        meta_prompt (str): Prompt template for the meta LLM.
+            If None, the default meta_prompts from template.py will be used. Defaults to None.
+        prepend_exemplars (bool): Whether to run exemplar search and prepend few-shot examples. Defaults to False.
+        n_exemplars (int): How many exemplars to prepend. Only used if prepend_exemplars is True. Defaults to 5.
+        exemplar_selector (str): Which exemplar selector to use. Should not be None if prepend_exemplars is True.
+            Defaults to None.
+        n_ds_samples_to_meta (int): How many dataset examples to show to the meta LLM
+            (not applicable to every optimizer). Defaults to 2.
+        n_eval_samples (int): How many examples to use when evaluating prompts with the evaluation LLM. Defaults to 20.
""" task_name: str = None @@ -46,11 +56,16 @@ class Config: include_task_desc: bool = True donor_random: bool = False random_seed: int = 42 - selection_mode: Optional[str] = "random" + selection_mode: Optional[Literal["random", "wheel", "tour"]] = "random" meta_bs: Optional[int] = None downstream_bs: Optional[int] = None api_token: Optional[str] = None meta_prompt: Optional[str] = None + prepend_exemplars: Optional[bool] = False + n_exemplars: Optional[int] = 5 + exemplar_selector: Optional[str] = None + n_ds_samples_to_meta: Optional[int] = 2 + n_eval_samples: Optional[int] = 20 def __post_init__(self): """Validate the configuration after initialization.""" diff --git a/promptolution/exemplar_selectors/__init__.py b/promptolution/exemplar_selectors/__init__.py new file mode 100644 index 0000000..f234373 --- /dev/null +++ b/promptolution/exemplar_selectors/__init__.py @@ -0,0 +1,33 @@ +"""Module for exemplar selectors.""" + +from typing import Literal + +from promptolution.exemplar_selectors.random_search_selector import RandomSearchSelector +from promptolution.exemplar_selectors.random_selector import RandomSelector +from promptolution.predictors.base_predictor import BasePredictor +from promptolution.tasks.base_task import BaseTask + +SELECTOR_MAP = { + "random": RandomSelector, + "random_search": RandomSearchSelector, +} + + +def get_exemplar_selector(name: Literal["random", "random_search"], task: BaseTask, predictor: BasePredictor): + """Factory function to get an exemplar selector based on the given name. + + Args: + name (str): The name of the exemplar selector to instantiate. + task (BaseTask): The task object to be passed to the selector. + predictor (BasePredictor): The predictor object to be passed to the selector. + + Returns: + BaseExemplarSelector: An instance of the requested exemplar selector. + + Raises: + ValueError: If the requested selector name is not found. + """ + if name not in SELECTOR_MAP: + raise ValueError(f"Exemplar selector '{name}' not found. Available selectors: {list(SELECTOR_MAP.keys())}") + + return SELECTOR_MAP[name](task, predictor) diff --git a/promptolution/exemplar_selectors/base_exemplar_selector.py b/promptolution/exemplar_selectors/base_exemplar_selector.py new file mode 100644 index 0000000..dd96e7b --- /dev/null +++ b/promptolution/exemplar_selectors/base_exemplar_selector.py @@ -0,0 +1,41 @@ +"""Base class for exemplar selectors.""" + +from abc import ABC, abstractmethod +from typing import Any, List, Tuple + +from promptolution.predictors.base_predictor import BasePredictor +from promptolution.tasks.base_task import BaseTask + + +class BaseExemplarSelector(ABC): + """An abstract base class for exemplar selectors. + + This class defines the basic interface and common functionality + that all exemplar selectors should implement. + """ + + def __init__(self, task: BaseTask, predictor: BasePredictor): + """Initialize the BaseExemplarSelector. + + Args: + task (BaseTask): An object representing the task to be performed. + predictor (BasePredictor): An object capable of making predictions based on prompts. + """ + self.task = task + self.predictor = predictor + + @abstractmethod + def select_exemplars(self, prompt: str, n_examples: int = 5) -> str: + """Select exemplars based on the given prompt. + + Args: + prompt (str): The input prompt to base the exemplar selection on. + n_examples (int, optional): The number of exemplars to select. Defaults to 5. + + Returns: + str: A new prompt that includes the original prompt and the selected exemplars. 
+ + Raises: + NotImplementedError: This method should be implemented by subclasses. + """ + raise NotImplementedError("This method should be implemented by subclasses.") diff --git a/promptolution/exemplar_selectors/random_search_selector.py b/promptolution/exemplar_selectors/random_search_selector.py new file mode 100644 index 0000000..005fef8 --- /dev/null +++ b/promptolution/exemplar_selectors/random_search_selector.py @@ -0,0 +1,39 @@ +"""Random search exemplar selector.""" + +from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector + + +class RandomSearchSelector(BaseExemplarSelector): + """A selector that uses random search to find the best set of exemplars. + + This class implements a strategy that generates multiple sets of random examples, + evaluates their performance, and selects the best performing set. + """ + + def select_exemplars(self, prompt, n_examples: int = 5, n_trials: int = 5): + """Select exemplars using a random search strategy. + + This method generates multiple sets of random examples, evaluates their performance + when combined with the original prompt, and returns the best performing set. + + Args: + prompt (str): The input prompt to base the exemplar selection on. + n_examples (int, optional): The number of exemplars to select in each trial. Defaults to 5. + n_trials (int, optional): The number of random trials to perform. Defaults to 5. + + Returns: + str: The best performing prompt, which includes the original prompt and the selected exemplars. + """ + best_score = 0 + best_prompt = prompt + + for _ in range(n_trials): + _, seq = self.task.evaluate(prompt, self.predictor, n_samples=n_examples, subsample=True, return_seq=True) + prompt_with_examples = "\n\n".join([prompt] + seq) + "\n\n" + # evaluate prompts as few shot prompt + score = self.task.evaluate(prompt_with_examples, self.predictor, subsample=True) + if score > best_score: + best_score = score + best_prompt = prompt_with_examples + + return best_prompt diff --git a/promptolution/exemplar_selectors/random_selector.py b/promptolution/exemplar_selectors/random_selector.py new file mode 100644 index 0000000..5fe01ae --- /dev/null +++ b/promptolution/exemplar_selectors/random_selector.py @@ -0,0 +1,46 @@ +"""Random exemplar selector.""" + +from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector +from promptolution.predictors.base_predictor import BasePredictor +from promptolution.tasks.base_task import BaseTask + + +class RandomSelector(BaseExemplarSelector): + """A selector that randomly selects correct exemplars. + + This class implements a strategy that generates random examples and selects + those that are evaluated as correct until the desired number of exemplars is reached. + """ + + def __init__(self, task: BaseTask, predictor: BasePredictor, desired_score: int = 1): + """Initialize the RandomSelector. + + Args: + task (BaseTask): An object representing the task to be performed. + predictor (BasePredictor): An object capable of making predictions based on prompts. + desired_score (int, optional): The desired score for the exemplars. Defaults to 1. + """ + super().__init__(task, predictor) + self.desired_score = desired_score + + def select_exemplars(self, prompt, n_examples: int = 5): + """Select exemplars using a random selection strategy. + + This method generates random examples and selects those that are evaluated as correct + (score == self.desired_score) until the desired number of exemplars is reached. 
+ + Args: + prompt (str): The input prompt to base the exemplar selection on. + n_examples (int, optional): The number of exemplars to select. Defaults to 5. + + Returns: + str: A new prompt that includes the original prompt and the selected exemplars. + """ + examples = [] + while len(examples) < n_examples: + score, seq = self.task.evaluate(prompt, self.predictor, n_samples=1, return_seq=True) + if score == self.desired_score: + examples.append(seq[0]) + prompt = "\n\n".join([prompt] + examples) + "\n\n" + + return prompt diff --git a/promptolution/helpers.py b/promptolution/helpers.py index 11f942d..9d776a9 100644 --- a/promptolution/helpers.py +++ b/promptolution/helpers.py @@ -6,6 +6,7 @@ import pandas as pd from promptolution.config import Config +from promptolution.exemplar_selectors import get_exemplar_selector from promptolution.llms import get_llm from promptolution.optimizers import get_optimizer from promptolution.predictors import Classificator @@ -50,10 +51,15 @@ def run_optimization(config: Config): initial_prompts=init_pop, task=task, predictor=predictor, + n_eval_samples=config.n_eval_samples, ) prompts = optimizer.optimize(n_steps=config.n_steps) + if config.prepend_exemplars: + selector = get_exemplar_selector(config.exemplar_selector, task, predictor) + prompts = [selector.select_exemplars(p, n_examples=config.n_exemplars) for p in prompts] + return prompts @@ -72,7 +78,7 @@ def run_evaluation(config: Config, prompts: List[str]): llm = get_llm(config.evaluation_llm, token=config.api_token) predictor = Classificator(llm, classes=task.classes) - scores = task.evaluate(prompts, predictor) + scores = task.evaluate(prompts, predictor, subsample=True, n_samples=config.n_eval_samples) df = pd.DataFrame(dict(prompt=prompts, score=scores)) df = df.sort_values("score", ascending=False) diff --git a/promptolution/llms/api_llm.py b/promptolution/llms/api_llm.py index df3410e..1c34709 100644 --- a/promptolution/llms/api_llm.py +++ b/promptolution/llms/api_llm.py @@ -72,14 +72,11 @@ def __init__(self, model_id: str, token: str = None): ValueError: If an unknown model identifier is provided. """ if "claude" in model_id: - ANTHROPIC_API_KEY = open("anthropictoken.txt", "r").read() if token is None else token - self.model = ChatAnthropic(model=model_id, api_key=ANTHROPIC_API_KEY) + self.model = ChatAnthropic(model=model_id, api_key=token) elif "gpt" in model_id: - OPENAI_API_KEY = open("openaitoken.txt", "r").read() if token is None else token - self.model = ChatOpenAI(model=model_id, api_key=OPENAI_API_KEY) + self.model = ChatOpenAI(model=model_id, api_key=token) else: - DEEPINFRA_API_KEY = open("deepinfratoken.txt", "r").read() if token is None else token - self.model = ChatDeepInfra(model_name=model_id, deepinfra_api_token=DEEPINFRA_API_KEY) + self.model = ChatDeepInfra(model_name=model_id, deepinfra_api_token=token) def get_response(self, prompts: List[str]) -> List[str]: """Get responses for a list of prompts in a synchronous manner. 
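
Usage sketch: the exemplar selectors introduced above can also be driven outside run_optimization. The snippet below is illustrative only and not part of the patch; the dataset path, model id, API token and seed prompt are placeholders, and the task/predictor wiring mirrors run_optimization in helpers.py.

# Illustrative only: exercising the new exemplar-selector API directly.
from promptolution.exemplar_selectors import get_exemplar_selector
from promptolution.llms import get_llm
from promptolution.predictors import Classificator
from promptolution.tasks.classification_tasks import ClassificationTask

task = ClassificationTask(dataset_path="../data_sets/cls/subj/", split="dev")  # placeholder path
llm = get_llm("meta-llama/Meta-Llama-3-8B-Instruct", token="<api-token>")  # placeholder model id / token
predictor = Classificator(llm, classes=task.classes)

# "random_search" scores a few random exemplar sets and keeps the best one;
# "random" keeps only exemplars the predictor currently classifies correctly.
selector = get_exemplar_selector("random_search", task, predictor)
few_shot_prompt = selector.select_exemplars(
    "evaluate each sentence as either objective or subjective", n_examples=3
)
print(few_shot_prompt)
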
diff --git a/promptolution/optimizers/__init__.py b/promptolution/optimizers/__init__.py
index 11d2956..ae4ed93 100644
--- a/promptolution/optimizers/__init__.py
+++ b/promptolution/optimizers/__init__.py
@@ -66,7 +66,7 @@ def get_optimizer(
     if config.optimizer == "opro":
         prompt_template = OPRO_TEMPLATE
         prompt_template = config.meta_prompt if config.meta_prompt else prompt_template
-        n_samples = kwargs.get("n_samples", config.n_samples if config is not None else None)
+        n_samples = kwargs.get("n_samples", config.n_ds_samples_to_meta if config is not None else None)
 
         return Opro(prompt_template=prompt_template, n_samples=n_samples, *args, **kwargs)
 
     raise ValueError(f"Unknown optimizer: {config.optimizer}")
diff --git a/promptolution/optimizers/base_optimizer.py b/promptolution/optimizers/base_optimizer.py
index effc329..2cac685 100644
--- a/promptolution/optimizers/base_optimizer.py
+++ b/promptolution/optimizers/base_optimizer.py
@@ -26,12 +26,20 @@ class BaseOptimizer(ABC):
         predictor (optional): Predictor for prompt evaluation. Defaults to None.
     """
 
-    def __init__(self, initial_prompts: list[str], task: BaseTask, callbacks: list[Callable] = [], predictor=None):
+    def __init__(
+        self,
+        initial_prompts: list[str],
+        task: BaseTask,
+        callbacks: list[Callable] = [],
+        predictor=None,
+        n_eval_samples=20,
+    ):
         """Initialize the BaseOptimizer."""
         self.prompts = initial_prompts
         self.task = task
         self.callbacks = callbacks
         self.predictor = predictor
+        self.n_eval_samples = n_eval_samples
 
     @abstractmethod
     def optimize(self, n_steps: int) -> List[str]:
diff --git a/promptolution/optimizers/evoprompt_de.py b/promptolution/optimizers/evoprompt_de.py
index 7772a03..17d74b3 100644
--- a/promptolution/optimizers/evoprompt_de.py
+++ b/promptolution/optimizers/evoprompt_de.py
@@ -51,7 +51,7 @@ def optimize(self, n_steps: int) -> List[str]:
         Returns:
             List[str]: The optimized list of prompts after all steps.
         """
-        self.scores = self.task.evaluate(self.prompts, self.predictor)
+        self.scores = self.task.evaluate(self.prompts, self.predictor, subsample=True, n_samples=self.n_eval_samples)
 
         self.prompts = [prompt for _, prompt in sorted(zip(self.scores, self.prompts), reverse=True)]
         self.scores = sorted(self.scores, reverse=True)
@@ -80,7 +80,9 @@ def optimize(self, n_steps: int) -> List[str]:
             child_prompts = self.meta_llm.get_response(meta_prompts)
             child_prompts = [prompt.split("<prompt>")[-1].split("</prompt>")[0].strip() for prompt in child_prompts]
 
-            child_scores = self.task.evaluate(child_prompts, self.predictor)
+            child_scores = self.task.evaluate(
+                child_prompts, self.predictor, subsample=True, n_samples=self.n_eval_samples
+            )
 
             for i in range(len(self.prompts)):
                 if child_scores[i] > self.scores[i]:
diff --git a/promptolution/optimizers/evoprompt_ga.py b/promptolution/optimizers/evoprompt_ga.py
index 2393ef5..2ec789b 100644
--- a/promptolution/optimizers/evoprompt_ga.py
+++ b/promptolution/optimizers/evoprompt_ga.py
@@ -56,7 +56,9 @@ def optimize(self, n_steps: int) -> List[str]:
             List[str]: The optimized list of prompts after all steps.
""" # get scores from task - self.scores = self.task.evaluate(self.prompts, self.predictor).tolist() + self.scores = self.task.evaluate( + self.prompts, self.predictor, subsample=True, n_samples=self.n_eval_samples + ).tolist() # sort prompts by score self.prompts = [prompt for _, prompt in sorted(zip(self.scores, self.prompts), reverse=True)] self.scores = sorted(self.scores, reverse=True) @@ -64,7 +66,12 @@ def optimize(self, n_steps: int) -> List[str]: for _ in range(n_steps): new_prompts = self._crossover(self.prompts, self.scores) prompts = self.prompts + new_prompts - scores = self.scores + self.task.evaluate(new_prompts, self.predictor).tolist() + scores = ( + self.scores + + self.task.evaluate( + new_prompts, self.predictor, subsample=True, n_samples=self.n_eval_samples + ).tolist() + ) # sort scores and prompts self.prompts = [prompt for _, prompt in sorted(zip(scores, prompts), reverse=True)][: len(self.prompts)] diff --git a/promptolution/optimizers/opro.py b/promptolution/optimizers/opro.py index ddf504e..b2fa645 100644 --- a/promptolution/optimizers/opro.py +++ b/promptolution/optimizers/opro.py @@ -38,7 +38,9 @@ def __init__(self, meta_llm: BaseLLM, n_samples: int = 2, prompt_template: str = super().__init__(**args) self.meta_prompt = self.meta_prompt.replace("", self.task.description) - self.scores = [self.task.evaluate(p, self.predictor) for p in self.prompts] + self.scores = [ + self.task.evaluate(p, self.predictor, subsample=True, n_samples=self.n_eval_samples) for p in self.prompts + ] def _sample_examples(self): """Sample examples from the task dataset with their label. @@ -78,7 +80,7 @@ def optimize(self, n_steps: int) -> List[str]: prompt = self.meta_llm.get_response([meta_prompt])[0] prompt = prompt.split("")[-1].split("")[0].strip() - score = self.task.evaluate(prompt, self.predictor) + score = self.task.evaluate(prompt, self.predictor, subsample=True, n_samples=self.n_eval_samples) self.prompts.append(prompt) self.scores.append(score) diff --git a/promptolution/predictors/base_predictor.py b/promptolution/predictors/base_predictor.py index 941ee9a..eea7f74 100644 --- a/promptolution/predictors/base_predictor.py +++ b/promptolution/predictors/base_predictor.py @@ -1,10 +1,12 @@ """Base module for predictors.""" from abc import abstractmethod -from typing import List +from typing import List, Tuple import numpy as np +from promptolution.llms.base_llm import BaseLLM + class BasePredictor: """Abstract base class for predictors in the promptolution library. @@ -12,37 +14,30 @@ class BasePredictor: This class defines the interface that all concrete predictor implementations should follow. Attributes: - model_id (str): Identifier for the model used by the predictor. - classes (List[str]): List of possible class labels for classification tasks. + llm: The language model used for generating predictions. + Methods: predict: An abstract method that should be implemented by subclasses to make predictions based on prompts and input data. """ - def __init__(self, model_id, classes, *args, **kwargs): + def __init__(self, llm: BaseLLM): """Initialize the BasePredictor. Args: - model_id (str): Identifier for the model to use. - classes (List[str]): List of possible class labels. - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. + llm: The language model to use for predictions. + classes (List[str]): The list of valid class labels. 
""" - self.model_id = model_id - self.classes = classes + self.llm = llm - @abstractmethod - def predict( - self, - prompts: List[str], - xs: np.ndarray, - ) -> np.ndarray: + def predict(self, prompts: List[str], xs: np.ndarray, return_seq: bool = False) -> np.ndarray: """Abstract method to make predictions based on prompts and input data. Args: prompts (List[str]): List of prompts to use for prediction. xs (np.ndarray): Array of input data. + return_seq (bool, optional): whether to return the generating sequence Returns: np.ndarray: Array of predictions. @@ -50,6 +45,24 @@ def predict( Raises: NotImplementedError: If not implemented by a subclass. """ + if isinstance(prompts, str): + prompts = [prompts] + + outputs = self.llm.get_response([prompt + "\n" + x for prompt in prompts for x in xs]) + preds = self._extract_preds(outputs, (len(prompts), len(xs))) + + if return_seq: + return preds, [i + "\n" + o for i, o in zip(xs, outputs)] + + return preds + + def _extract_preds(self, preds: List[str], shape: Tuple[int, int]) -> np.ndarray: + """Extract class labels from the predictions, based on the list of valid class labels. + + Args: + preds: The raw predictions from the language model. + shape: The shape of the output array: (n_prompts, n_samples). + """ raise NotImplementedError diff --git a/promptolution/predictors/classificator.py b/promptolution/predictors/classificator.py index 7cf6fe9..f33bfc6 100644 --- a/promptolution/predictors/classificator.py +++ b/promptolution/predictors/classificator.py @@ -1,6 +1,6 @@ """Module for classification predictors.""" -from typing import List +from typing import List, Tuple import numpy as np @@ -11,7 +11,10 @@ class Classificator(BasePredictor): """A predictor class for classification tasks using language models. This class takes a language model and a list of classes, and provides a method - to predict classes for given prompts and input data. + to predict classes for given prompts and input data. The class labels are extracted + by matching the words in the prediction with the list of valid class labels. + The first occurrence of a valid class label in the prediction is used as the predicted class. + If no valid class label is found, the first class label in the list is used as the default prediction. Attributes: llm: The language model used for generating predictions. @@ -28,39 +31,19 @@ def __init__(self, llm, classes, *args, **kwargs): llm: The language model to use for predictions. classes (List[str]): The list of valid class labels. """ - self.llm = llm + super().__init__(llm) self.classes = classes - def predict( - self, - prompts: List[str], - xs: np.ndarray, - ) -> np.ndarray: - """Predict classes for given prompts and input data. - - This method generates predictions using the language model and then - extracts the predicted class from the model's output. + def _extract_preds(self, preds: List[str], shape: Tuple[int, int]) -> np.ndarray: + """Extract class labels from the predictions, based on the list of valid class labels. Args: - prompts (List[str]): The list of prompts to use for prediction. - xs (np.ndarray): The input data array. - - Returns: - np.ndarray: A 2D array of predicted classes, with shape (len(prompts), len(xs)). - - Note: - The method concatenates each prompt with each input data point, - passes it to the language model, and then extracts the first word - in the response that matches a class in self.classes. + preds: The raw predictions from the language model. 
+ shape: The shape of the output array: (n_prompts, n_samples). """ - if isinstance(prompts, str): - prompts = [prompts] - - preds = self.llm.get_response([prompt + "\n" + x for prompt in prompts for x in xs]) - response = [] for pred in preds: - predicted_class = "" + predicted_class = self.classes[0] # use first class as default pred for word in pred.split(" "): word = "".join([c for c in word if c.isalnum()]) if word in self.classes: @@ -69,5 +52,5 @@ def predict( response.append(predicted_class) - response = np.array(response).reshape(len(prompts), len(xs)) + response = np.array(response).reshape(*shape) return response diff --git a/promptolution/tasks/classification_tasks.py b/promptolution/tasks/classification_tasks.py index 9da7a17..f37deec 100644 --- a/promptolution/tasks/classification_tasks.py +++ b/promptolution/tasks/classification_tasks.py @@ -2,9 +2,10 @@ import json from pathlib import Path -from typing import Dict, List, Literal, Optional +from typing import Callable, Dict, List, Literal, Optional import numpy as np +from sklearn.metrics import accuracy_score from promptolution.predictors.base_predictor import BasePredictor from promptolution.tasks.base_task import BaseTask @@ -25,8 +26,9 @@ class ClassificationTask(BaseTask): xs (Optional[np.ndarray]): Input data for the task. ys (Optional[np.ndarray]): Ground truth labels for the task. classes (Optional[List]): List of possible class labels. - split (Literal["dev", "test"]): Dataset split to use. seed (int): Random seed for reproducibility. + split (Literal["dev", "test"]): Dataset split to use. + metric (Callable): Metric to use as an evaluation score for the prompts. Inherits from: BaseTask: The base class for tasks in the promptolution library. @@ -38,6 +40,7 @@ def __init__( task_id: str = "Classification Task", seed: int = 42, split: Literal["dev", "test"] = "dev", + metric: Callable = accuracy_score, ): """Initialize the ClassificationTask. @@ -46,6 +49,7 @@ def __init__( dataset_path (str): Path to the dataset description JSON file. seed (int, optional): Random seed for reproducibility. Defaults to 42. split (Literal["dev", "test"], optional): Dataset split to use. Defaults to "dev". + metric (Callable): Metric to use as an evaluation score for the prompts. Defaults to sklearn's accuracy. """ self.task_id: str = task_id self.path: Path = dataset_path @@ -56,6 +60,7 @@ def __init__( self.ys: Optional[np.ndarray] = None self.classes: Optional[List] = None self.split: Literal["dev", "test"] = split + self.metric = metric self._parse_task() self.reset_seed(seed) @@ -95,7 +100,12 @@ def _parse_task(self): self.ys = np.array(ys) def evaluate( - self, prompts: List[str], predictor: BasePredictor, n_samples: int = 20, subsample: bool = True + self, + prompts: List[str], + predictor: BasePredictor, + n_samples: int = 20, + subsample: bool = False, + return_seq: bool = False, ) -> np.ndarray: """Evaluate a set of prompts using a given predictor. @@ -103,7 +113,9 @@ def evaluate( prompts (List[str]): List of prompts to evaluate. predictor (BasePredictor): Predictor to use for evaluation. n_samples (int, optional): Number of samples to use if subsampling. Defaults to 20. - subsample (bool, optional): Whether to use subsampling. Defaults to True. + subsample (bool, optional): Whether to use subsampling. + If set to true, samples a different subset per call. Defaults to False. + return_seq (bool, optional): whether to return the generating sequence Returns: np.ndarray: Array of accuracy scores for each prompt. 
@@ -120,10 +132,17 @@ def evaluate( ys_subsample = self.ys[indices] # Make predictions on the subsample - preds = predictor.predict(prompts, xs_subsample) + preds = predictor.predict(prompts, xs_subsample, return_seq=return_seq) + + if return_seq: + preds, seqs = preds + + scores = np.array([self.metric(ys_subsample, pred) for pred in preds]) + + if return_seq: + return scores, seqs - # Calculate accuracy: number of correct predictions / total number of predictions per prompt - return np.mean(preds == ys_subsample, axis=1) + return scores def reset_seed(self, seed: int = None): """Reset the random seed.""" diff --git a/pyproject.toml b/pyproject.toml index 659f202..8635296 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "promptolution" -version = "0.2.0" +version = "1.0.0" description = "" authors = ["Tom Zehle, Moritz Schlager, Timo Heiß"] readme = "README.md" @@ -14,6 +14,7 @@ langchain-core = "^0.2.29" langchain-community = "^0.2.12" pandas = "^2.2.2" tqdm = "^4.66.5" +scikit-learn = "^1.5.2" [tool.poetry.group.dev.dependencies] matplotlib = "^3.9.2"
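
Usage sketch: after this change, ClassificationTask scores prompts with a pluggable metric (sklearn's accuracy_score by default) and draws a fresh subsample per evaluate() call when subsample=True. The snippet below is illustrative only and not part of the patch; the dataset path, model id, token and prompt are placeholders, and balanced_accuracy_score simply stands in for any callable with the (y_true, y_pred) shape of accuracy_score.

# Illustrative only: custom evaluation metric plus subsampled evaluation.
from sklearn.metrics import balanced_accuracy_score

from promptolution.llms import get_llm
from promptolution.predictors import Classificator
from promptolution.tasks.classification_tasks import ClassificationTask

task = ClassificationTask(
    dataset_path="../data_sets/cls/subj/",  # placeholder path
    split="test",
    metric=balanced_accuracy_score,  # any callable taking (y_true, y_pred)
)
llm = get_llm("meta-llama/Meta-Llama-3-8B-Instruct", token="<api-token>")  # placeholder model id / token
predictor = Classificator(llm, classes=task.classes)

# subsample=True draws a fresh subset of n_samples examples on every call,
# matching how the optimizers now evaluate prompts with n_eval_samples.
scores = task.evaluate(
    ["evaluate each sentence as either objective or subjective"],
    predictor,
    subsample=True,
    n_samples=20,
)
print(scores)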