# 1. Retrieval Augmented In-context Learning

## 1.1 Contextual Embedding

In [1]:
# Step 0. Prepare the environment
!pip install InstructorEmbedding sentence-transformers datasets scikit-learn

Collecting InstructorEmbedding
  Downloading InstructorEmbedding-1.0.1-py2.py3-none-any.whl.metadata (20 kB)
Collecting sentence-transformers
  Downloading sentence_transformers-2.7.0-py3-none-any.whl.metadata (11 kB)
Downloading InstructorEmbedding-1.0.1-py2.py3-none-any.whl (19 kB)
Downloading sentence_transformers-2.7.0-py3-none-any.whl (171 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m171.5/171.5 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: InstructorEmbedding, sentence-transformers
Successfully installed InstructorEmbedding-1.0.1 sentence-transformers-2.7.0


In [2]:
!mkdir -p data/classification
!wget -O data/classification/train.txt https://raw.githubusercontent.com/ranpox/comp3361-spring2024/main/assignments/A1/data/classification/train.txt
!wget -O data/classification/dev.txt https://raw.githubusercontent.com/ranpox/comp3361-spring2024/main/assignments/A1/data/classification/dev.txt
!wget -O data/classification/test-blind.txt https://raw.githubusercontent.com/ranpox/comp3361-spring2024/main/assignments/A1/data/classification/test-blind.txt

--2024-04-29 19:43:36--  https://raw.githubusercontent.com/ranpox/comp3361-spring2024/main/assignments/A1/data/classification/train.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 738844 (722K) [text/plain]
Saving to: 'data/classification/train.txt'


2024-04-29 19:43:36 (13.3 MB/s) - 'data/classification/train.txt' saved [738844/738844]

--2024-04-29 19:43:37--  https://raw.githubusercontent.com/ranpox/comp3361-spring2024/main/assignments/A1/data/classification/dev.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 94400 (92

In [3]:
# Step 1. Declare the model & Example usage
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("hkunlp/instructor-base")
embeddings = model.encode(
    [
        "Dynamical Scalar Degree of Freedom in Horava-Lifshitz Gravity",
        "Comparison of Atmospheric Neutrino Flux Calculations at Low Energies",
        "Fermion Bags in the Massive Gross-Neveu Model",
        "QCD corrections to Associated t-tbar-H production at the Tevatron",
    ],
    prompt="Represent the Medicine sentence for clustering: ",
    show_progress_bar=True,
)

print(embeddings.shape)

modules.json:   0%|          | 0.00/461 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/66.2k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/439M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


tokenizer_config.json:   0%|          | 0.00/2.43k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/270 [00:00<?, ?B/s]

2_Dense/config.json:   0%|          | 0.00/115 [00:00<?, ?B/s]

2_Dense/pytorch_model.bin:   0%|          | 0.00/2.36M [00:00<?, ?B/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

(4, 768)


## 1.2 Retrieve Relevant Examples

In [7]:
import re
import json
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

def load_train_data():
    """Loads the GSM8k train dataset.

    Returns:
        A list of dictionaries containing questions, cot answers, and short answers.
    """
    ds = load_dataset("gsm8k", "main", split="train")
    examples = [{"question": example["question"], "answer": example["answer"]} for example in ds]
    for example in examples:
        example["short_answer"] = re.sub(r"(\d),(\d)", r"\1\2", example["answer"].split("####")[1].strip())
        example["cot_answer"] = re.sub(r"\<\<.*?\>\>", "", example["answer"].split("####")[0].strip()) \
            + " So the answer is " + example["short_answer"] + "."
    return examples

def load_test_data():
    """Loads the first 30 examples of the GSM8k test dataset.

    Returns:
        A list of dictionaries containing questions and answers.
    """
    ds = load_dataset("gsm8k", "main", split="test")
    examples = [{"question": example["question"], "answer": example["answer"].split("####")[1].strip()} for example in ds]
    for example in examples:
        example["answer"] = re.sub(r"(\d),(\d)", r"\1\2", example["answer"])
    return examples[:30]

In [8]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
def retrieve_examples(GSM8K_train, GSM8K_test):
    """Retrieve top-20 in-context examples from GSM8K training set for each testing examples.

    Returns:
        A dictionary mapping testing questions to a list of top-20 training examples.
    """
    # Encode all the questions in training set
    GSM8K_train = load_train_data()
    GSM8K_test = load_test_data()

    train_questions = [example['question'] for example in GSM8K_train]
    test_questions = [example['question'] for example in GSM8K_test]
    train_question_embeddings = model.encode(train_questions, prompt="Represent the question for retrieval: ", show_progress_bar=True)
    test_question_embeddings = model.encode(test_questions, prompt="Represent the question for retrieval: ", show_progress_bar=True)

    similarity_scores = cosine_similarity(test_question_embeddings, train_question_embeddings) # (30, 7473)
    top_20_indices = np.argsort(-similarity_scores, axis=1)[:, :20] # (30, 20)

    # print(f'Visualise similarity score: {np.take_along_axis(similarity_scores, top_20_indices,axis=1)}')

    #encoded_train_questions = [example['Encoded Question'] for example in GSM8K_train]
    # Loop over the 30 test questions
    for test_idx, test_question in enumerate(GSM8K_test):
        # Add the top-20 training examples to the dictionary of the test question
        for train_idx, example in enumerate(top_20_indices[test_idx]):
            test_question[f'Example {train_idx}'] = GSM8K_train[example]
    # Return the dictionary
    return GSM8K_test

In [9]:
GSM8K_train = load_train_data()
GSM8K_test = load_test_data()
RETRIVED_EXAMPLES = retrieve_examples(GSM8K_train, GSM8K_test)
with open("retrieved_examples.json", "w") as fout:
    json.dump(RETRIVED_EXAMPLES, fout)

Downloading readme:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

Downloading data: 100%|██████████| 2.31M/2.31M [00:00<00:00, 29.8MB/s]
Downloading data: 100%|██████████| 419k/419k [00:00<00:00, 6.15MB/s]


Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

Batches:   0%|          | 0/234 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Note: The following retrieval augmented generation does not require a GPU. Please consider saving and downloading the examples you retrieve from the left file tab so that you will not be hindered by Colab GPU restrictions.

In [10]:
import re
import json
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

#with open("retrieved_examples.json", "w") as fout:
#    json.dump(RETRIVED_EXAMPLES, fout)
with open("retrieved_examples.json", "r") as fin:
    RETRIVED_EXAMPLES = json.load(fin)

## 1.3 Generation with Huggingface Inference API

We will use LLM by querying huggingface inference api so we do not need GPU for the following code. Please generate HF_TOKEN at [hf.co/settings/tokens](hf.co/settings/tokens) and set as environment varible.

In [11]:
!pip install backoff evaluate

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl.metadata (9.4 kB)
Collecting responses<0.19 (from evaluate)
  Downloading responses-0.18.0-py3-none-any.whl.metadata (29 kB)
Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading responses-0.18.0-py3-none-any.whl (38 kB)
Installing collected packages: responses, evaluate
Successfully installed evaluate-0.4.1 responses-0.18.0


In [12]:
from abc import ABC, abstractmethod
from typing import List, Dict, Any
import os
import json
import backoff
import evaluate
import re
import time
import requests


os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_TOKEN"] =   #Your HF Token

from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

class LLM(object):
    def __init__(self, model_name="mistralai/Mistral-7B-v0.1"):
        self.model_name = model_name
        self.api_url = f"https://api-inference.huggingface.co/models/{model_name}"
        self.headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

    @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=60)
    def generate(self, prompts: List[str], **kwargs) -> List[str]:
        outputs = []

        def query(payload):
            response = requests.post(self.api_url, headers=self.headers, json=payload)
            if response.status_code != 200:
                raise ValueError(f"Request failed with code {response.status_code}, {response.text}")
            return response.json()

        for prompt in prompts:
            data = query(
                {
                    "inputs": prompt,
                    "parameters": { "max_new_tokens": 256, "stop": ["Question:"]},
                }
            )

            outputs.append(data[0]['generated_text'])

        return outputs

2024-04-29 19:45:12.477250: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-29 19:45:12.477396: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-29 19:45:12.596422: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [13]:
llm = LLM("codellama/CodeLlama-7b-hf")
#llm = LLM("mistralai/Mistral-7B-v0.1")
print(llm.generate(["Explain the importance of low latency LLMs", "What is the capital of France?"]))

['Explain the importance of low latency LLMs\n\n### Explain the importance of low latency LLMs\n\n-   Latency is the time it takes for a packet to travel from one point to another.\n\n-   Latency is the time it takes for a packet to travel from one point to another.\n\n-   Latency is the time it takes for a packet to travel from one point to another.\n\n-   Latency is the time it takes for a packet to travel from one point to another.\n\n-   Latency is the time it takes for a packet to travel from one point to another.\n\n-   Latency is the time it takes for a packet to travel from one point to another.\n\n-   Latency is the time it takes for a packet to travel from one point to another.\n\n-   Latency is the time it takes for a packet to travel from one point to another.\n\n-   Latency is the time it takes for a packet to travel from one point to another.\n\n-   Latency is the time it takes for a packet to travel from one point to another.\n\n-   Latency is the time it', 'What is the 

Please adapt your GSM8KCoTEvaluator for this API-based LLM. And report the performance of 8-shot chain-of-thought prompting on first 30 examples of GSM8K.

In [14]:
class Evaluator(ABC):
    def __init__(self, llm):
        self.llm = llm

    @abstractmethod
    def load_data(self):
        pass

    @abstractmethod
    def build_prompts(self):
        pass

    @abstractmethod
    def postprocess_output(self, output: str) -> str:
        pass

    def generate_completions(self, prompts: List[str], batch_size=2) -> List[str]:
        return self.llm.generate(prompts)

    def evaluate(self, batch_size=4, n_shot=8, save_dir="outputs"):
        dataset = self.load_data()
        prompts = self.build_prompts(dataset, n_shot)
        outputs = self.generate_completions(prompts, batch_size=batch_size)

        predictions = []
        for i, (example, prompt, output) in enumerate(zip(dataset, prompts, outputs)):
            prediction = {
                "task_id": example.get("task_id", f"task_{i}"),
                "prompt": prompt,
                "completion": output,
                "prediction": self.postprocess_output(output[len(prompt):]),
            }
            #print(f"task_id: {prediction['task_id']}")
            #print(f"  prompt: \n{prediction['prompt']}")
            #print(f"  completion: \n{prediction['completion']}")
            #print(f"  prediction: \n{prediction['prediction']}")
            predictions.append(prediction)

        # Save predictions to file
        os.makedirs(save_dir, exist_ok=True)
        prediction_save_path = os.path.join(save_dir, f"{type(self).__name__}_predictions.jsonl")
        with open(prediction_save_path, "w") as fout:
            for pred in predictions:
                fout.write(json.dumps(pred) + "\n")

        # Calculate metrics and print results
        results = self.calculate_metrics(predictions, dataset)
        print(f"Results for {type(self).__name__}: {results}")

    @abstractmethod
    def calculate_metrics(self):
        pass

GSM_EXAMPLARS = [
    {
        "question": "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?",
        "cot_answer": "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. So the answer is 6.",
        "pot_answer": "def solution():\n    \"\"\"There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\"\"\"\n    trees_initial = 15\n    trees_after = 21\n    trees_added = trees_after - trees_initial\n    result = trees_added\n    return result",
        "short_answer": "6"
    },
    {
        "question": "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?",
        "cot_answer": "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. So the answer is 5.",
        "pot_answer": "def solution():\n    \"\"\"If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\"\"\"\n    cars_initial = 3\n    cars_arrived = 2\n    total_cars = cars_initial + cars_arrived\n    result = total_cars\n    return result",
        "short_answer": "5"
    },
    {
        "question": "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?",
        "cot_answer": "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. So the answer is 39.",
        "pot_answer": "def solution():\n    \"\"\"Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\"\"\"\n    leah_chocolates = 32\n    sister_chocolates = 42\n    total_chocolates = leah_chocolates + sister_chocolates\n    chocolates_eaten = 35\n    chocolates_left = total_chocolates - chocolates_eaten\n    result = chocolates_left\n    return result",
        "short_answer": "39"
    },
    {
        "question": "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?",
        "cot_answer": "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. So the answer is 8.",
        "pot_answer": "def solution():\n    \"\"\"Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\"\"\"\n    jason_lollipops_initial = 20\n    jason_lollipops_after = 12\n    denny_lollipops = jason_lollipops_initial - jason_lollipops_after\n    result = denny_lollipops\n    return result",
        "short_answer": "8"
    },
    {
        "question": "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?",
        "cot_answer": "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. So the answer is 9.",
        "pot_answer": "def solution():\n    \"\"\"Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\"\"\"\n    toys_initial = 5\n    mom_toys = 2\n    dad_toys = 2\n    total_received = mom_toys + dad_toys\n    total_toys = toys_initial + total_received\n    result = total_toys\n    return result",
        "short_answer": "9"
    },
    {
        "question": "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?",
        "cot_answer": "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. So the answer is 29.",
        "pot_answer": "def solution():\n    \"\"\"Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\"\"\"\n    toys_initial = 5\n    mom_toys = 2\n    dad_toys = 2\n    total_received = mom_toys + dad_toys\n    total_toys = toys_initial + total_received\n    result = total_toys\n    return result",
        "short_answer": "29"
    },
    {
        "question": "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?",
        "cot_answer": "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. So the answer is 33.",
        "pot_answer": "def solution():\n    \"\"\"Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\"\"\"\n    golf_balls_initial = 58\n    golf_balls_lost_tuesday = 23\n    golf_balls_lost_wednesday = 2\n    golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday\n    result = golf_balls_left\n    return result",
        "short_answer": "33"
    },
    {
        "question": "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
        "cot_answer": "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. So the answer is 8.",
        "pot_answer": "def solution():\n    \"\"\"Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\"\"\"\n    money_initial = 23\n    bagels = 5\n    bagel_cost = 3\n    money_spent = bagels * bagel_cost\n    money_left = money_initial - money_spent\n    result = money_left\n    return result",
        "short_answer": "8"
    }
]

In [15]:
class GSM8KEvaluator(Evaluator):
    def load_data(self):
        ds = load_dataset("gsm8k", "main", split="test")
        examples = [{"question": example["question"], "answer": example["answer"].split("####")[1].strip()} for example in ds]
        for example in examples:
            example["answer"] = re.sub(r"(\d),(\d)", r"\1\2", example["answer"])
        return examples[:30]

    def build_prompts(self, dataset, n_shot=8, demos=GSM_EXAMPLARS):
        pass

    def postprocess_output(self, output: str) -> str:
        output = output.split('the answer is ')[-1] if 'the answer is ' in output else output
        output = output.split('=')[-1] if '=' in output else output
        output = re.sub(r"(\d),(\d)", r"\1\2", output) if len(re.sub(r"(\d),(\d)", r"\1\2", output)) != 0 else output
        output = re.findall(r'\d+', output)[-1] if len(re.findall(r'\d+', output)) != 0 else output
        output = output.split('.')[0] if (len(output.split('.')) == 2) else output
        return output

    def calculate_metrics(self, predictions, dataset):
        score = 0
        for i in range(len(predictions)):
            #print(f'''Prediction:{predictions[i]["prediction"]}
            #          Answer:{dataset[i]["answer"]}
            #          Match: {predictions[i]["prediction"] == dataset[i]["answer"]}''')
            if predictions[i]["prediction"] == dataset[i]["answer"]:
                score += 1
            else:
                continue
        score /= len(predictions)
        return score

In [16]:
class GSM8KCoTEvaluator(GSM8KEvaluator):
    def build_prompts(self, dataset, n_shot=8, demos=GSM_EXAMPLARS):
        final_prompt = "Answer the following questions."
        for i in range(n_shot):
            question = demos[i]['question']
            answer = demos[i]['cot_answer']
            prompt = f"\nQuestion: {question}\nAnswer: {answer}"
            final_prompt = final_prompt + prompt
        prompts = [f"{final_prompt}\nQuestion: {test['question']}\nAnswer: " for test in dataset]
        return prompts

In [17]:
cot_evaluator = GSM8KCoTEvaluator(llm)
cot_evaluator.evaluate(n_shot=8)

Results for GSM8KCoTEvaluator: 0.16666666666666666


## 1.4 Impact of Quantity on Few-shot Prompting

In [18]:
cot_evaluator = GSM8KCoTEvaluator(llm)
cot_evaluator.evaluate(n_shot=1)
cot_evaluator.evaluate(n_shot=2)
cot_evaluator.evaluate(n_shot=4)
#cot_evaluator.evaluate(n_shot=8)

Results for GSM8KCoTEvaluator: 0.1
Results for GSM8KCoTEvaluator: 0.1
Results for GSM8KCoTEvaluator: 0.1


## 1.5 Retrieval Augmented Few-shot Prompting

In [19]:
class GSM8KRetrievalICLEvaluator(GSM8KEvaluator):
    def build_prompts(self, dataset, n_shot=8, demos=RETRIVED_EXAMPLES):
        """Build prompts with RETRIVED_EXAMPLES we generated in 1.2.
        """
        prompts = []
        final_prompt = "Answer the following questions."
        for test in demos:
            final_prompt = "Answer the following questions."
            for i in range(n_shot):
                question = test[f'Example {i}']['question']
                answer = test[f'Example {i}']['cot_answer']
                prompt = f"\nQuestion: {question}\nAnswer: {answer}"
                final_prompt = final_prompt + prompt
            final_prompt = f"{final_prompt}\nQuestion: {test['question']}\nAnswer: "
            #print(f'{demos.index(test)}th prompt: \n   {final_prompt}')
            prompts.append(final_prompt)
        return prompts

In [20]:
retrieval_icl_evaluator = GSM8KRetrievalICLEvaluator(llm)
retrieval_icl_evaluator.evaluate(n_shot=1)
retrieval_icl_evaluator.evaluate(n_shot=2)
retrieval_icl_evaluator.evaluate(n_shot=4)
retrieval_icl_evaluator.evaluate(n_shot=8)

Results for GSM8KRetrievalICLEvaluator: 0.03333333333333333
Results for GSM8KRetrievalICLEvaluator: 0.03333333333333333
Results for GSM8KRetrievalICLEvaluator: 0.1
Results for GSM8KRetrievalICLEvaluator: 0.13333333333333333


In [21]:
# Rearrange the examples in ascending order of relevance.
class GSM8KRetrievalICLEvaluator_ascending(GSM8KEvaluator):
    def build_prompts(self, dataset, n_shot=8, demos=RETRIVED_EXAMPLES):
        """Build prompts with RETRIVED_EXAMPLES we generated in 1.2.
        """
        prompts = []
        final_prompt = "Answer the following questions."
        for test in demos:
            final_prompt = "Answer the following questions."
            for i in range(n_shot-1, -1, -1):
                question = test[f'Example {i}']['question']
                answer = test[f'Example {i}']['cot_answer']
                prompt = f"\nQuestion: {question}\nAnswer: {answer}"
                final_prompt = final_prompt + prompt
            final_prompt = f"{final_prompt}\nQuestion: {test['question']}\nAnswer: "
            #print(final_prompt)
            prompts.append(final_prompt)
        return prompts

In [23]:
retrieval_icl_ascending_evaluator = GSM8KRetrievalICLEvaluator_ascending(llm)
retrieval_icl_ascending_evaluator.evaluate(n_shot=1)
retrieval_icl_ascending_evaluator.evaluate(n_shot=2)
retrieval_icl_ascending_evaluator.evaluate(n_shot=4)
retrieval_icl_ascending_evaluator.evaluate(n_shot=8)

Results for GSM8KRetrievalICLEvaluator_ascending: 0.03333333333333333
Results for GSM8KRetrievalICLEvaluator_ascending: 0.1
Results for GSM8KRetrievalICLEvaluator_ascending: 0.1
Results for GSM8KRetrievalICLEvaluator_ascending: 0.1
