# Chain-of-Thought Prompting

This notebook explains chain-of-thought (CoT) prompting in the context of large language models (LLMs). By the end, you'll know how to benchmark an LLM on a reasoning task that might benefit from CoT prompting. Hopefully, this will also sharpen your intuition for using LLMs, such as ChatGPT, more effectively.

## What is Chain-of-Thought Prompting?

In 2022, a group of Google researchers released this [paper](https://arxiv.org/pdf/2201.11903).

![CoT paper](./images/cot_paper.png)

The paper showed that including a series of intermediate reasoning steps in the prompt significantly improves the ability of LLMs to perform complex reasoning. The authors call this series of intermediate reasoning steps a "chain of thought" (CoT). A brief demonstration of their findings is shown below.

One example of a complex reasoning task outlined in the paper is "arithmetic reasoning," which essentially means solving math word problems.

Suppose we had the following question:

<div align="center" style="font-size:20px;">

*The cafeteria had 23 apples. If they used 20 to make lunch and bought 6 more, how many apples do they have?*

</div>

By doing some basic math, we can figure out that the answer is 23 - 20 + 6 = 9.

Let's see if an LLM can solve it.

The code below takes our question and sends it to GPT-3 using the OpenAI API.

Unlike LLMs such as ChatGPT, which you might be more familiar with, the version of GPT-3 used in the original CoT paper has not been fine-tuned for human dialogue. In other words, it's not a chatbot. Instead, it takes your prompt and keeps appending the most likely next token until it reaches a stopping point, such as the limit you set with the `max_tokens` parameter.
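
To make the "most likely next token until `max_tokens`" behavior concrete, here is a toy sketch of greedy decoding. The lookup table `NEXT_TOKEN_PROBS` and the function `greedy_decode` are made up for illustration. A real LLM computes these probabilities with a neural network over the whole context, not with a one-word lookup.

```python
# Toy next-token table (hypothetical): maps the previous token to a
# probability distribution over possible next tokens.
NEXT_TOKEN_PROBS = {
    "The": {"answer": 0.7, "cat": 0.3},
    "answer": {"is": 0.9, "was": 0.1},
    "is": {"9": 0.6, "27": 0.4},
    "9": {".": 1.0},
}

def greedy_decode(prompt_tokens, max_tokens):
    """Append the most likely next token until max_tokens tokens have been
    generated or the table has no continuation for the last token."""
    tokens = list(prompt_tokens)
    generated = []
    for _ in range(max_tokens):
        dist = NEXT_TOKEN_PROBS.get(tokens[-1])
        if dist is None:
            break  # nothing left to predict in this toy table
        next_token = max(dist, key=dist.get)  # pick the top-probability token
        tokens.append(next_token)
        generated.append(next_token)
    return generated

print(greedy_decode(["The"], max_tokens=3))   # ['answer', 'is', '9']
print(greedy_decode(["The"], max_tokens=10))  # ['answer', 'is', '9', '.']
```

The first call stops because it hits the `max_tokens` budget. The second stops early only because the toy table runs out of continuations.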

This is why we include a `Q:`/`A:` example before our question of interest: it shows the model that we want it to respond in the format `A: <answer>`.

In [None]:
from dotenv import load_dotenv
import os
from openai import OpenAI
import matplotlib.pyplot as plt

load_dotenv()
# Assumes your .env file defines API_KEY, set to your OpenAI API key
api_key = os.getenv('API_KEY')
client = OpenAI(
    api_key=api_key
)

no_cot_prompt = """
Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?
A: The answer is 11.
Q: The cafeteria had 23 apples. If they used 20 to make lunch and bought 6 more, how many apples do they have?
"""

In [None]:
# Run inference using GPT-3
response_no_cot = client.completions.create(
  model="davinci-002",
  prompt=no_cot_prompt,
  max_tokens=32,
  temperature=0
)
print(response_no_cot.choices[0].text.split('\n', 1)[0].strip())

Yikes... looks like GPT-3 is worse at math than your average elementary schooler. But all hope isn't lost yet. Let's try including a chain-of-thought in our example question.

In [None]:
cot_prompt = """
Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?
A: Roger started with 5 balls. 2 cans of 3 tennis balls each is 6 tennis balls. 5 + 6 = 11. The answer is 11.
Q: The cafeteria had 23 apples. If they used 20 to make lunch and bought 6 more, how many apples do they have?
"""

response_cot = client.completions.create(
  model="davinci-002",
  prompt=cot_prompt,
  max_tokens=128,
  temperature=0
)
print(response_cot.choices[0].text.split('\n', 1)[0].strip())

Would you look at that. Looks like GPT-3 might be good at math after all.
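
The two prompts above differ only in the exemplar answer: one states the answer directly, and the other spells out the reasoning first. A small helper can make that difference explicit. This is only a sketch; the function name `build_few_shot_prompt` and the exemplar dictionary keys are made up for illustration and are not part of the OpenAI API.

```python
def build_few_shot_prompt(exemplars, question, use_cot=False):
    """Build a Q:/A: few-shot prompt.

    exemplars: list of dicts with keys "question", "rationale", "answer"
    (a hypothetical schema chosen for this sketch).
    """
    lines = []
    for ex in exemplars:
        lines.append(f"Q: {ex['question']}")
        if use_cot:
            # CoT exemplar: reasoning steps first, then the final answer
            lines.append(f"A: {ex['rationale']} The answer is {ex['answer']}.")
        else:
            # Standard exemplar: final answer only
            lines.append(f"A: The answer is {ex['answer']}.")
    lines.append(f"Q: {question}")
    return "\n".join(lines) + "\n"

exemplars = [{
    "question": "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. "
                "Each can has 3 tennis balls. How many tennis balls does he have now?",
    "rationale": "Roger started with 5 balls. 2 cans of 3 tennis balls each is "
                 "6 tennis balls. 5 + 6 = 11.",
    "answer": "11",
}]
question = ("The cafeteria had 23 apples. If they used 20 to make lunch and "
            "bought 6 more, how many apples do they have?")

print(build_few_shot_prompt(exemplars, question, use_cot=False))
print(build_few_shot_prompt(exemplars, question, use_cot=True))
```

Keeping the exemplars as data makes it easy to run the same questions with and without CoT, so any accuracy gap can be attributed to the reasoning steps alone.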

Let's try to make this finding more rigorous.
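
A rigorous comparison needs an automatic way to pull the final answer out of a completion. One possible approach, assuming the model copies the exemplar's "The answer is X." phrasing, is a small regular expression applied to the first line of the output. The name `extract_numeric_answer` is hypothetical, and this sketch handles numeric answers only.

```python
import re

# Matches "The answer is 9", "The answer is $1,234.50", "The answer is -3", ...
ANSWER_PATTERN = re.compile(r"The answer is\s*\$?(-?\d[\d,]*(?:\.\d+)?)")

def extract_numeric_answer(completion):
    """Return the number following 'The answer is' on the first line of the
    completion, or None if the pattern is not found."""
    # Keep only the first line: a base model may go on to invent new questions
    first_line = completion.strip().split("\n", 1)[0]
    match = ANSWER_PATTERN.search(first_line)
    if match is None:
        return None
    return float(match.group(1).replace(",", ""))

print(extract_numeric_answer("A: 23 - 20 = 3. 3 + 6 = 9. The answer is 9.\nQ: ..."))  # 9.0
print(extract_numeric_answer("A: I am not sure."))  # None
```

Returning `None` for unparseable completions lets them be counted as wrong instead of crashing the evaluation.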

## Evaluation

We replicate the results of the paper using three models:
- GPT-3 Babbage (<1.3B parameters)
- GPT-3 Davinci (<175B parameters)
- Apple's OpenELM (a 3B-parameter LLM that was previously unbenchmarked!)
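
Benchmarking any of these models comes down to the same loop: run the model over a dataset and count how often its final answer matches the reference. Below is a minimal sketch, assuming a hypothetical `generate(prompt)` callable that wraps whichever model is being tested and a dataset given as `(question, answer)` pairs. The `stub_model` is a stand-in so the example runs without an API key.

```python
def solve_rate(generate, few_shot_prefix, dataset):
    """Percentage of questions whose first completion line ends with
    'The answer is <gold>.' (a deliberately strict string match)."""
    if not dataset:
        raise ValueError("dataset must not be empty")
    correct = 0
    for question, gold in dataset:
        prompt = f"{few_shot_prefix}Q: {question}\n"
        first_line = generate(prompt).strip().split("\n", 1)[0]
        if first_line.rstrip().endswith(f"The answer is {gold}."):
            correct += 1
    return 100.0 * correct / len(dataset)

def stub_model(prompt):
    # Hypothetical stand-in for a real model call: always answers 9
    return "A: The answer is 9.\nQ: some follow-up the model made up"

dataset = [
    ("The cafeteria had 23 apples. If they used 20 to make lunch and "
     "bought 6 more, how many apples do they have?", "9"),
    ("Roger has 5 tennis balls. He buys 2 more cans of tennis balls. "
     "Each can has 3 tennis balls. How many tennis balls does he have now?", "11"),
]

print(solve_rate(stub_model, "", dataset))  # 50.0
```

Swapping `stub_model` for a function that calls a real model, and passing a prefix with or without reasoning steps, yields the two numbers to compare for each model.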

We performed experiments on the following three datasets. We provide an example of what each one looks like.
- GSM8K
    - Q: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
    - A: 72
- StrategyQA
    - Q: Are more people today related to Genghis Khan than Julius Caesar?
    - A: Yes
- Last Letter Concatenation
    - In domain
        - Q: Take the last letters of the words in “SHEILA PUCKETT” and concatenate them.
        - A: AT
    - Out of domain
        - Q: Take the last letters of the words in “CRAIG GUY STACY SANTANA” and concatenate them.
        - A: GYYA
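
Unlike the other two datasets, reference answers for last letter concatenation can be computed directly, which is handy for generating test cases or double-checking labels. A short sketch follows (the function name `last_letter_concat` is made up for illustration). Note that in the examples above the in-domain name has two words and the out-of-domain name has four.

```python
def last_letter_concat(name):
    """Concatenate the last letter of each whitespace-separated word."""
    return "".join(word[-1] for word in name.split())

print(last_letter_concat("SHEILA PUCKETT"))           # AT
print(last_letter_concat("CRAIG GUY STACY SANTANA"))  # GYYA
```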

### Plotting code

In [None]:
def plot_accuracy_vs_size(no_cot_accuracies, cot_accuracies, dataset_name):
    """Plot solve rate against model scale, with and without CoT prompting."""
    # Evenly spaced x positions; the tick labels carry the parameter counts
    positions = [1, 2, 3]
    size_labels = ["1", "3", "175"]

    plt.figure(figsize=(10, 6))

    # Plot data points
    plt.scatter(positions, no_cot_accuracies, color='blue', label='No CoT')
    plt.scatter(positions, cot_accuracies, color='red', label='CoT')

    # Connect the points with lines
    plt.plot(positions, no_cot_accuracies, color='blue', linestyle='-')
    plt.plot(positions, cot_accuracies, color='red', linestyle='-')

    # Ticks
    plt.xticks(positions, labels=size_labels)

    # Labels and title
    plt.xlabel('Model Scale (# parameters in billions)')
    plt.ylabel('Solve Rate (%)')
    plt.title(f'Model Accuracy vs Size on {dataset_name}')
    plt.legend()

    plt.show()

def plot_gsm8k():
    plot_accuracy_vs_size([3.5, 2.79, 12.6], [3.4, 4.79, 36.2], 'GSM8K')

def plot_stratqa():
    plot_accuracy_vs_size([48.8, 56.5, 59.4], [56.5, 54.8, 70.1], 'StrategyQA')

def plot_in_domain():
    plot_accuracy_vs_size([0.6, 2.4, 1.0], [3.2, 6.2, 93.2],
                          'Last Letter Concatenation In Domain')

def plot_out_domain():
    plot_accuracy_vs_size([0.0, 0.0, 0.0], [0.0, 0.2, 30.2],
                          'Last Letter Concatenation Out of Domain')

### Arithmetic Reasoning

In [None]:
plot_gsm8k()

### Commonsense Reasoning

In [None]:
plot_stratqa()

### Symbolic Reasoning

In [None]:
plot_in_domain()

In [None]:
plot_out_domain()