# Chain-of-Thought Prompting

This notebook explains chain-of-thought (CoT) prompting in the context of large language models (LLMs). By the end, you'll know how to benchmark any LLM on any reasoning task that might benefit from CoT prompting. Hopefully this will also improve your intuition for how to use LLMs like ChatGPT more effectively.

## What is Chain-of-Thought Prompting?

In 2022, a group of Google researchers released this [paper](https://arxiv.org/pdf/2201.11903)

![CoT paper](./images/cot_paper.png)

which showed that including a series of intermediate reasoning steps in the prompt significantly improves the ability of LLMs to perform complex reasoning. They decided to call this series of intermediate reasoning steps a "chain-of-thought (CoT)." A brief demonstration of their findings is shown below.

One example of a complex reasoning task outlined in the paper is "arithmetic reasoning," which essentially means solving math word problems.

Suppose we had the following question:

<div align="center" style="font-size:20px;">

*The cafeteria had 23 apples. If they used 20 to make lunch and bought 6 more, how many apples do they have?*

</div>

By doing some basic math, we can figure out that the answer is 23 - 20 + 6 = 9.

Let's see if an LLM can solve it.

The code below takes our question and sends it to GPT-3 using the OpenAI API.

Unlike LLMs such as ChatGPT, which you might be more familiar with, the version of GPT-3 used in the original CoT paper has not been fine-tuned for human dialogue. In other words, it's not a chatbot. Instead, it takes your prompt and keeps generating the most likely next token (we set `temperature=0`) until it hits a stop condition, such as the limit set by the `max_tokens` parameter.

The reason we include a `Q:`/`A:` example before our question of interest is to show the model that we want it to respond in the format `A: <answer>`.
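To make the `Q:`/`A:` format concrete, here is a minimal sketch of how such a few-shot prompt could be assembled from exemplars. `build_prompt` is a hypothetical helper written for illustration; it is not part of the OpenAI library, and the prompts in this notebook are written out by hand instead.

```python
def build_prompt(exemplars, question):
    """Format (question, answer) exemplars, followed by the new question."""
    lines = []
    for q, a in exemplars:
        lines.append(f"Q: {q}")
        lines.append(f"A: {a}")
    lines.append(f"Q: {question}")
    # End with a newline so the model can continue with its own "A: ..." line.
    return "\n".join(lines) + "\n"


exemplars = [
    (
        "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. "
        "Each can has 3 tennis balls. How many tennis balls does he have now?",
        "The answer is 11.",
    )
]
prompt = build_prompt(
    exemplars,
    "The cafeteria had 23 apples. If they used 20 to make lunch and "
    "bought 6 more, how many apples do they have?",
)
print(prompt)
```

Swapping the exemplar answer for one that spells out its reasoning is all it takes to turn this into a CoT prompt.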

In [None]:
from dotenv import load_dotenv
import os
from openai import OpenAI
import matplotlib.pyplot as plt
from plotting import *

load_dotenv()
# Assumes you have an OpenAI API key in your .env file
api_key = os.getenv('API_KEY')
client = OpenAI(
    api_key=api_key
)

no_cot_prompt = """
Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?
A: The answer is 11.
Q: The cafeteria had 23 apples. If they used 20 to make lunch and bought 6 more, how many apples do they have?
"""

In [None]:
# Run inference using GPT-3
response_no_cot = client.completions.create(
  model="davinci-002",
  prompt=no_cot_prompt,
  max_tokens=32,
  temperature=0
)
print(response_no_cot.choices[0].text.split('\n', 1)[0].strip())

Yikes... looks like GPT-3 is worse at math than your average elementary schooler. But all hope isn't lost yet. Let's try including a chain-of-thought in our example question.

In [None]:
cot_prompt = """
Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?
A: Roger started with 5 balls. 2 cans of 3 tennis balls each is 6 tennis balls. 5 + 6 = 11. The answer is 11.
Q: The cafeteria had 23 apples. If they used 20 to make lunch and bought 6 more, how many apples do they have?
"""

response_cot = client.completions.create(
  model="davinci-002",
  prompt=cot_prompt,
  max_tokens=128,
  temperature=0
)
print(response_cot.choices[0].text.split('\n', 1)[0].strip())

Would you look at that. Looks like GPT-3 might be good at math after all.

Let's try to make this finding more rigorous.

## Evaluation

We replicate the results of the paper using three models. 
- GPT-3 Babbage (<1.3B parameters)
- GPT-3 Davinci (<175B parameters)
- Apple's OpenELM (a 3B-parameter LLM that was unbenchmarked!)

We performed experiments on the following three datasets, and we provide examples of what each one looks like.
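Before looking at the datasets, here is a minimal sketch of the kind of accuracy loop a benchmark like this needs. It is not the exact code behind the plots; `accuracy`, `fake_generate`, and `last_word` are hypothetical names, and the stand-in model just returns a fixed string so the example runs without an API key.

```python
def accuracy(examples, generate, extract):
    """Fraction of (prompt, gold) pairs where the extracted prediction matches gold."""
    if not examples:
        raise ValueError("examples must not be empty")
    correct = 0
    for prompt, gold in examples:
        completion = generate(prompt)
        if extract(completion) == gold:
            correct += 1
    return correct / len(examples)


# Stand-in for a real model call (illustration only): always answers 9.
def fake_generate(prompt):
    return "A: The answer is 9."


# Naive extractor (assumption): the answer is the last word of the completion.
def last_word(completion):
    return completion.rstrip(".").split()[-1]


examples = [
    ("Q: The cafeteria had 23 apples. ...\n", "9"),
    ("Q: If there are 3 cars in the parking lot ...\n", "5"),
]
print(accuracy(examples, fake_generate, last_word))  # 0.5
```

In a real run, `generate` would wrap a call to the model being benchmarked, and the same loop would be run once with a no-CoT prompt and once with a CoT prompt.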

### Arithmetic Reasoning

##### Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?

##### A (No CoT): The answer is 5.

##### A (CoT): There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.
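Both answer styles end with "The answer is N.", which makes automatic scoring fairly simple. The sketch below assumes completions follow that pattern; `extract_number` and `ANSWER_RE` are hypothetical names, and a real evaluation may need to handle more answer formats than this regex does.

```python
import re

# Matches "The answer is 5", "The answer is 1,234", "The answer is $18.50", etc.
ANSWER_RE = re.compile(r"The answer is\s*\$?(-?[\d,]*\.?\d+)")


def extract_number(completion):
    """Return the first number after 'The answer is', or None if there is none."""
    match = ANSWER_RE.search(completion)
    if match is None:
        return None
    return float(match.group(1).replace(",", ""))


no_cot = "The answer is 5."
cot = "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5."
print(extract_number(no_cot), extract_number(cot))  # 5.0 5.0
```

Returning `None` when the pattern is missing lets the scorer count a malformed completion as wrong instead of crashing.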

In [None]:
plot_gsm8k()

### Commonsense Reasoning

##### Q: Do hamsters provide food for any animals?

##### A (No CoT): Yes

##### A (CoT): Hamsters are prey animals. Prey are food for predators. Thus, hamsters provide food for some animals. So the answer is yes.

In [None]:
plot_stratqa()

### Symbolic Reasoning

##### Q: Take the last letters of the words in "Elon Musk" and concatenate them.

##### A (No CoT): The answer is nk.

##### A (CoT): The last letter of "Elon" is "n". The last letter of "Musk" is "k". Concatenating them is "nk". The answer is nk.
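One nice property of this task is that the ground-truth answer can be computed in one line, so labeled examples are cheap to generate. A minimal sketch, where `last_letter_concat` is a hypothetical helper name:

```python
def last_letter_concat(phrase):
    """Concatenate the last letter of each whitespace-separated word."""
    return "".join(word[-1] for word in phrase.split())


print(last_letter_concat("Elon Musk"))  # nk
```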

In [None]:
plot_in_domain()

##### Q: Take the last letters of the words in "GERALD MCKNIGHT KATRINA CARPENTER" and concatenate them.

##### A: The answer is DTAR.

In [None]:
plot_out_domain()

## Conclusion

Based on these results, it seems like you can significantly improve an LLM's ability to solve reasoning tasks by simply being a bit more thoughtful about your prompt.

Furthermore, one of the lasting impacts of the original CoT paper is that the GSM-8K dataset is now a standard benchmark that many tech companies report whenever they launch a new LLM. For example, here's Meta's new Llama 3, which they released a few weeks ago:

![Llama 3 Benchmarks](./images/llama_benchmarks.png)

And now you all know how to compute this exact benchmark.

Note: our working theory for why Apple left their OpenELM model unbenchmarked on GSM-8K is that it does badly on it.

Although CoT's ability to improve model performance on reasoning tasks has been largely undisputed since the original Google research paper came out, there have been many [papers](https://arxiv.org/abs/2305.04388) that question our ability to make any interpretability statements from CoT reasoning. However, we don't have enough time to address these limitations here.