In [1]:
import re
import time
from enum import Enum
from typing import Dict, List, Tuple

import kscope
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

### Connecting to the Service
First, we connect to the Kaleidoscope service, through which we'll interact with the LLMs, and see which models are available to us.

In [2]:
# Address of the Kaleidoscope gateway that fronts the hosted LLMs
GATEWAY_HOST = "llm.cluster.local"
GATEWAY_PORT = 3001

# Open a client connection; every model interaction below goes through this handle
client = kscope.Client(gateway_host=GATEWAY_HOST, gateway_port=GATEWAY_PORT)

Show all model instances that are currently active.

In [3]:
# Inspect the model instances reported by the service; as the cell's last expression, the value is displayed below.
client.model_instances

[{'id': '6d599738-d7b0-4277-83f8-1de47854f9f5',
  'name': 'falcon-40b',
  'state': 'ACTIVE'}]

To start, we obtain a handle to a model. In this example, let's use the Falcon-40B model.

In [4]:
# Upper bound (seconds) on how long to wait for the model to become ACTIVE.
# Deliberately generous; adjust for your cluster.
MODEL_LAUNCH_TIMEOUT_S = 30 * 60

model = client.load_model("falcon-40b")
# If this model is not actively running, it will get launched in the background.
# In this case, wait until it moves into an "ACTIVE" state before proceeding, but
# fail loudly instead of spinning forever if it never gets there (e.g. a failed launch).
# NOTE(review): assumes `model.state` reflects the live service state on each access — confirm.
launch_deadline = time.monotonic() + MODEL_LAUNCH_TIMEOUT_S
while model.state != "ACTIVE":
    if time.monotonic() > launch_deadline:
        raise TimeoutError(
            f"Model did not become ACTIVE within {MODEL_LAUNCH_TIMEOUT_S} s (last state: {model.state})"
        )
    time.sleep(1)

In [5]:
# Short completion budget: enough room for a final answer.
# NOTE(review): with top_k=1 decoding is presumably greedy, so temperature should have little effect — confirm with kscope docs.
small_generation_config = {"max_tokens": 20, "top_k": 1, "top_p": 1.0, "temperature": 0.8}
# Same sampling settings, but a larger budget for free-form step-by-step reasoning.
moderate_generation_config = {**small_generation_config, "max_tokens": 128}

Let's ask the model a simple question to start. Note that the model keeps generating until the `max_tokens` threshold is reached, so we have to "parse" the answer out of the generated text.

In [6]:
# Ask a simple factual question with the short (20-token) generation budget
generation = model.generate("What is the capital of Canada?", small_generation_config)
# Generated texts live under the "sequences" key of the response payload; show the first one
first_sequence = generation.generation["sequences"][0]
print(first_sequence)


Ottawa
What is the capital of Canada?
Ottawa
What is the capital of


## Loading the Dataset

We'll start by loading the dataset and preparing a shuffled loader, from which we'll draw a sample of 100 examples for our test set.

In [7]:
class CoTDataset(Enum):
    """Word-problem datasets this notebook knows how to load and parse.

    Code elsewhere dispatches on the member itself (compared with ``is``); the string
    values are short identifiers for each dataset.
    """

    GSM8K = "gsm8k"
    MULTI_ARITH = "multi_arith"


# Which word-problem benchmark to evaluate on
dataset_name = CoTDataset.MULTI_ARITH

if dataset_name is CoTDataset.GSM8K:
    dataset = load_dataset("gsm8k", "main")
else:
    dataset = load_dataset("ChilleD/MultiArith")

# Seed torch's global RNG so the DataLoader's shuffle order is deterministic.
SHUFFLE_SEED = 1776
torch.manual_seed(SHUFFLE_SEED)
# Batch size of one: each batch carries a single datapoint, processed one at a time.
dataloader = DataLoader(dataset["train"], shuffle=True, batch_size=1)

In [8]:
def process_gsmk8_answers(answer_str: str) -> float:
    """Extract the numeric final answer from a GSM8K answer string.

    The final answer is the text after the last ``####`` marker (the whole string if
    there is no marker); surrounding whitespace and thousands-separator commas are removed.
    """
    _, _, final_segment = answer_str.rpartition("####")
    return float(final_segment.strip().replace(",", ""))


def process_multi_arith_answers(answer_str: str) -> float:
    """Convert a MultiArith final-answer string to a float, ignoring whitespace and commas."""
    without_commas = "".join(answer_str.strip().split(","))
    return float(without_commas)


def process_dataset_point(datapoint: Dict[str, List[str]], dataset_name: CoTDataset) -> Tuple[str, float]:
    """Turn one collated DataLoader batch into a ``(word_problem, answer)`` pair.

    Args:
        datapoint: A batch from the loader. Each field is indexed at ``[0]``, so it is
            presumably a one-element list (batch size of one) — confirm if batching changes.
        dataset_name: Which dataset the batch came from; selects the field names and
            the answer parser.

    Returns:
        The question text and its numeric answer.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    if dataset_name is CoTDataset.GSM8K:
        return datapoint["question"][0], process_gsmk8_answers(datapoint["answer"][0])
    if dataset_name is CoTDataset.MULTI_ARITH:
        question_text = datapoint["question"][0].strip()
        return question_text, process_multi_arith_answers(datapoint["final_ans"][0])
    raise ValueError("Dataset not supported...")

Now we construct a dataset of 100 word problems and their associated answers.

In [9]:
total_examples = 100
data_iterator = iter(dataloader)
# Draw the first `total_examples` shuffled batches (one datapoint each) and parse each into a
# (word_problem, answer) pair; next() raises StopIteration if the split runs out early.
parsed_examples = [process_dataset_point(next(data_iterator), dataset_name) for _ in range(total_examples)]
word_problems: List[str] = [problem for problem, _ in parsed_examples]
answers: List[float] = [label for _, label in parsed_examples]

print(f"Word Problem: {word_problems[0]}")
print(f"Answer: {answers[0]}")

Word Problem: Rachel was organizing her book case making sure each of the shelves had exactly 9 books on it. If she had 6 shelves of mystery books and 2 shelves of picture books, how many books did she have total?
Answer: 72.0


## Standard Zero-shot Prompting

First, let's measure the performance of a zero-shot prompt in solving the word problems in this task. 

**Note**: This is a two-stage process. First, we have the model generate a response. Next, we parse that response to extract a final answer and compare it to the label.

In [10]:
def create_prompt_from_template(word_problem: str) -> str:
    """Wrap a word problem in the zero-shot Q/A template.

    The prompt ends mid-sentence ("The answer is") so the model's continuation should
    begin with the answer itself.
    """
    prompt = f"Q: {word_problem}\nA: The answer is"
    return prompt

In [11]:
# Wrap every word problem in the zero-shot "The answer is" template
zero_shot_prompts = list(map(create_prompt_from_template, word_problems))
# Show one example prompt
print(zero_shot_prompts[0])

Q: Rachel was organizing her book case making sure each of the shelves had exactly 9 books on it. If she had 6 shelves of mystery books and 2 shelves of picture books, how many books did she have total?
A: The answer is


In [12]:
def parse_answer_to_float(raw_answer: str) -> float:
    """Convert a loosely formatted numeric string (e.g. ``"$1,200."`` or ``"-5"``) to a float.

    Leading non-digit characters (currency symbols, spaces, ...) are discarded, except that a
    minus sign immediately before the first digit is kept as the sign. Commas (optionally
    followed by whitespace) are removed as thousands separators, and trailing periods
    (sentence punctuation) are stripped.

    Raises:
        ValueError: If the string has no digit or the cleaned remainder is not a valid number.
    """
    text = raw_answer.strip()
    first_digit = re.search(r"\d", text)
    if first_digit is None:
        raise ValueError(f"No digits found in answer: {raw_answer!r}")
    start = first_digit.start()
    # A '-' directly before the first digit is a sign, not a symbol to throw away.
    is_negative = start > 0 and text[start - 1] == "-"
    digits = re.sub(r",\s*", "", text[start:]).rstrip(".")
    value = float(digits)
    return -value if is_negative else value


def parse_predicted_answer(full_answer: str) -> float:
    """Extract the first number that appears in a model generation.

    Returns NaN when no number is found or the match cannot be parsed. NaN never compares
    equal to any label, so unparseable generations are always scored as incorrect (a 0.0
    sentinel would silently count as correct whenever the true answer is 0).
    """
    # First run of digits, optionally preceded by '-', allowing ',' and '.' between digits
    answer_match = re.search(r"-?(\d[,.]*)+", full_answer)
    if answer_match is None:
        print(f"Failed to match to number: {full_answer}")
        return float("nan")
    split_answer = answer_match.group()
    try:
        return parse_answer_to_float(split_answer)
    except ValueError:
        print(f"Failed to parse: {full_answer}\nMatched: {split_answer}")
        return float("nan")

In [13]:
predicted_answers = []
for prompt_num, zero_shot_prompt in enumerate(zero_shot_prompts):
    # Short completion: the prompt ends with "The answer is", so the number should come first
    response = model.generate(zero_shot_prompt, generation_config=small_generation_config)
    completion = response.generation["sequences"][0]
    predicted_answers.append(parse_predicted_answer(completion))

    # Progress heartbeat every 10 prompts
    if (prompt_num + 1) % 10 == 0:
        print(f"Processed {prompt_num + 1} prompts...")

Processed 10 prompts...
Processed 20 prompts...
Processed 30 prompts...
Processed 40 prompts...
Processed 50 prompts...
Processed 60 prompts...
Processed 70 prompts...
Processed 80 prompts...
Processed 90 prompts...
Processed 100 prompts...


Now let's measure the accuracy of the model's parsed predictions against the true answers.

In [14]:
# Count exact matches between parsed predictions and the gold answers
correct = sum(
    1 for predicted_answer, true_answer in zip(predicted_answers, answers) if true_answer == predicted_answer
)
print(f"Zero-shot Prompt Accuracy: {correct/total_examples}")

Zero-shot Prompt Accuracy: 0.11


Clearly the model struggles to produce the correct answer for these problems. Let's see if we can improve the performance with zero-shot CoT!

## Zero-shot Chain-of-Thought Prompting

Now let's try zero-shot CoT prompting to see if we can get better performance. Remember that the zero-shot CoT prompting process has two stages. In the first stage, we ask the model to "think step by step" about how to solve the problem. In the second stage, we include that reasoning in the prompt and ask the model to provide a final answer.

**NOTE**: CoT queries take much longer to run: each problem needs two model calls, the first generating a longer chain of reasoning and the second processing a significantly longer context. The 100 queries will take at least 50 minutes to complete.

In [15]:
def construct_first_stage_prompt(word_problem: str) -> str:
    """Build the stage-one CoT prompt that asks the model to reason step by step."""
    reasoning_prompt = f"Q: {word_problem}\nA: Let's think step by step."
    return reasoning_prompt


def construct_second_stage_prompt(prompt: str, logic_generation: str) -> str:
    """Build the stage-two CoT prompt: the stage-one prompt, the model's reasoning appended
    verbatim, then a cue for the final answer on a new line."""
    answer_cue_prompt = f"{prompt}{logic_generation}\nTherefore, the final answer is"
    return answer_cue_prompt

Let's try out the two-stage process to see what it looks like.

In [16]:
# Walk through both CoT stages on the first word problem so the intermediate prompts are visible.
logic_prompt = construct_first_stage_prompt(word_problems[0])
print(f"Logic Prompt:\n{logic_prompt}\n")

# First stage prompt to generate logic (uses the larger generation budget for the reasoning)
logic_generation = model.generate(logic_prompt, generation_config=moderate_generation_config)
generated_logic = logic_generation.generation["sequences"][0]
# End with a blank line, like the other headers, so the logic doesn't run into "Answer Prompt:"
print(f"Generated Logic:\n{generated_logic}\n")

answer_prompt = construct_second_stage_prompt(logic_prompt, generated_logic)
print(f"Answer Prompt:\n{answer_prompt}\n")

# Second stage prompt to generate answer (short budget: only the final answer is needed)
answer_generation = model.generate(answer_prompt, generation_config=small_generation_config)
generated_answer = answer_generation.generation["sequences"][0]
print(f"Generated Answer:\n{generated_answer}")

Logic Prompt:
Q: Rachel was organizing her book case making sure each of the shelves had exactly 9 books on it. If she had 6 shelves of mystery books and 2 shelves of picture books, how many books did she have total?
A: Let's think step by step.

Generated Logic:

1. How many books are on the mystery shelves?
6 shelves x 9 books per shelf = 54 mystery books
2. How many books are on the picture shelves?
2 shelves x 9 books per shelf = 18 picture books
3. How many books are there total?
54 mystery books + 18 picture books = 72 books total
4. How many books are on the mystery shelves?
6 shelves x 9 books per
Answer Prompt:
Q: Rachel was organizing her book case making sure each of the shelves had exactly 9 books on it. If she had 6 shelves of mystery books and 2 shelves of picture books, how many books did she have total?
A: Let's think step by step.
1. How many books are on the mystery shelves?
6 shelves x 9 books per shelf = 54 mystery books
2. How many books are on the picture shelves?

In [17]:
predicted_answers = []
for prompt_num, word_problem in enumerate(word_problems):
    # Stage 1: elicit step-by-step reasoning with the larger generation budget
    logic_prompt = construct_first_stage_prompt(word_problem)
    reasoning_response = model.generate(logic_prompt, generation_config=moderate_generation_config)
    generated_logic = reasoning_response.generation["sequences"][0]

    # Stage 2: append that reasoning and ask for the final answer with the short budget
    answer_prompt = construct_second_stage_prompt(logic_prompt, generated_logic)
    final_response = model.generate(answer_prompt, generation_config=small_generation_config)
    full_answer = final_response.generation["sequences"][0]
    predicted_answers.append(parse_predicted_answer(full_answer))

    # Progress heartbeat every 10 prompts
    if (prompt_num + 1) % 10 == 0:
        print(f"Processed {prompt_num + 1} prompts...")

Processed 10 prompts...
Processed 20 prompts...
Processed 30 prompts...
Processed 40 prompts...
Processed 50 prompts...
Processed 60 prompts...
Processed 70 prompts...
Processed 80 prompts...
Processed 90 prompts...
Failed to match to number:  "He was a chicken."
Q: A man was driving his car when he saw a
Processed 100 prompts...


In [18]:
# Count exact matches between parsed CoT predictions and the gold answers
correct = sum(
    1 for predicted_answer, true_answer in zip(predicted_answers, answers) if true_answer == predicted_answer
)
print(f"Zero-shot CoT Prompt Accuracy: {correct/total_examples}")

Zero-shot CoT Prompt Accuracy: 0.32


The model's accuracy on this task has improved significantly. However, the model takes considerably longer to respond because of the extra computation: it must first generate its reasoning and then process that much longer prompt in a second generation.