# Causation (Bootstrap: Reprompting)

In this notebook, you can bootstrap chain-of-thought (CoT) examples using ChatGPT.

*Please note: this is a partial implementation of the Reprompting paper (see below). Only the first stage, bootstrapping, is implemented here. The second stage (iterative refinement via Gibbs sampling) is skipped for now because its convergence metric is unclear and needs further testing.*

👼 At the end of the notebook, you will be able to download a `.toml` file which you can use with the `causation-cotsc` notebook.

👼 This notebook reduces the human labour needed to annotate CoT examples: you can quickly generate reasoning and then review it. However, you still need to provide the class definitions and a labelled dataset.


## OpenAI Privacy Policy
This notebook uses OpenAI's API, meaning that your data will be sent to the OpenAI servers.

If you have concerns about how your data will be handled, please read OpenAI's API data usage policies [here](https://openai.com/policies/api-data-usage-policies).

## Paper: Reprompting
Reprompting: Automated Chain-of-Thought Prompt Inference Through Gibbs Sampling: https://arxiv.org/abs/2305.09993

In [None]:
from IPython.display import IFrame, display
display(IFrame(src='https://arxiv.org/pdf/2305.09993.pdf', width=1600, height=700))

## 1. Enter your OpenAI Key

In [None]:
from causation.utils import openai_apikey_input
openai_apikey_input()

## 2. Upload your instructions `.toml` file

In [None]:
from causation.utils import fileuploader

finput, instrs = fileuploader('.toml')
finput

In [None]:
assert instrs.get('data'), "Did you upload your CoT definitions?"

## 3. Build prompt for bootstrapping (No user interaction required)

In [None]:
import toml
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
)

# response schema
response_schemas = [
    ResponseSchema(name="reason", description="The reasoning for the classification in a logical step by step manner."),
    ResponseSchema(name="answer", description="The classification"),
]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)

# load instructions from toml
path = instrs.get('data')
data = toml.load(path)

# system messages
prefix = data.pop('PREFIX', None)
prefix = prefix.get('instruction') if prefix is not None else None

instructions = list()
classes = list(data.keys())
for clazz in classes:
    instruction = data.get(clazz).get('instruction')
    instruction = f"<class>\n{clazz}: {instruction}</class>"
    instructions.append(instruction)

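# the PREFIX table is optional; only prepend it when present
sys_template = (prefix + "\n\n" if prefix else "") + f"""
The following are {len(classes)} classes with a description of each.
These are XML delimited with <class> tags in the format: <class> Class: Description </class>.
Please classify each 'query' as one of the {len(classes)} classes.\n\n""" + '\n'.join(instructions) + "\n\n"
# escape literal braces so PromptTemplate does not read them as input variables
sys_template = sys_template.replace('{', '{{').replace('}', '}}')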

sys_prompt = PromptTemplate(template=sys_template, input_variables=[])
bootstrap_sys = SystemMessagePromptTemplate(prompt=sys_prompt)


human_template = r"""
{format_instructions}

Let's think logically step by step and make sure you provide your best answer. In your reasoning say why it is NOT the other classes.
Query: {query}"""

human_prompt = PromptTemplate(
    template=human_template,
    input_variables=['query'],
    partial_variables={'format_instructions': output_parser.get_format_instructions()},
)
bootstrap_human = HumanMessagePromptTemplate(prompt=human_prompt)

bootstrap_prompt = ChatPromptTemplate.from_messages([bootstrap_sys, bootstrap_human])

"Bootstrap prompt created. Please continue."
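Class descriptions are inserted verbatim into the system template, and `PromptTemplate`'s default f-string template format follows `str.format` brace rules, so a literal `{` or `}` in a description would be read as a template variable. A minimal sketch of the usual fix, doubling the braces, shown with plain `str.format`; `escape_braces` and the description text are hypothetical:

```python
def escape_braces(text: str) -> str:
    """Double literal braces so str.format-style templates leave them alone."""
    return text.replace("{", "{{").replace("}", "}}")

desc = "Mentions a set such as {rain, flood}."  # hypothetical class description

# Unescaped: '{rain, flood}' is parsed as a replacement field and formatting fails.
try:
    (desc + "\nQuery: {query}").format(query="The storm flooded the road.")
    unescaped_ok = True
except (KeyError, IndexError, ValueError):
    unescaped_ok = False

# Escaped: the braces survive as literal text and only {query} is filled in.
rendered = (escape_braces(desc) + "\nQuery: {query}").format(query="The storm flooded the road.")
print(unescaped_ok)
print(rendered)
```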

## 4. Define Bootstrapping (No user interaction required)

In [None]:
from langchain.chat_models.openai import ChatOpenAI

chat_model = ChatOpenAI(model='gpt-3.5-turbo', temperature=0.0)
f"Your model name is: {chat_model.model_name}"

In [None]:
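# postprocessing
import json
import re
from dataclasses import dataclass

@dataclass(frozen=True)
class ReasonAnswerPair:
    r: str
    a: str

def parse(output: str) -> ReasonAnswerPair:
    """Extract the 'reason' and 'answer' fields from the model's JSON reply.

    Falls back to (raw output, '') when no JSON can be recovered.
    """
    text = output.strip()
    try:
        if text.startswith('{') and text.endswith('}'):
            json_string = text  # bare JSON object
        else:
            # JSON wrapped in a ```json fenced block, as the format instructions request
            json_string = re.search(r'```json\n(.*?)```', text, re.DOTALL).group(1)
        j = json.loads(json_string)
        return ReasonAnswerPair(r=j.get('reason'), a=j.get('answer'))
    except (AttributeError, json.JSONDecodeError):
        # AttributeError: no fenced block found (re.search returned None)
        return ReasonAnswerPair(r=output, a='')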
    
"Postprocessing parser created. Please continue."
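When the reply is not a bare JSON object, `parse` relies on a regex that only matches a fence written as `json` followed by a newline; any other shape falls through to the fallback pair with an empty answer, which `bootstrap` then counts as a failure. A small sketch of both cases using the same pattern; `extract_answer` is a hypothetical helper, not part of the notebook:

```python
import json
import re

TICKS = "`" * 3  # built at runtime so the snippet can sit inside a Markdown fence
FENCE = re.compile(TICKS + r"json\n(.*?)" + TICKS, re.DOTALL)  # same pattern as parse()

def extract_answer(reply: str) -> str:
    """Hypothetical helper: return 'answer' from a json-fenced reply, else ''."""
    match = FENCE.search(reply)
    if match is None:
        return ""  # parse() falls back to an empty answer here
    return json.loads(match.group(1)).get("answer", "")

fenced = "Sure.\n" + TICKS + 'json\n{"reason": "storm -> flood", "answer": "Causal"}\n' + TICKS
same_line = TICKS + 'json {"reason": "storm -> flood", "answer": "Causal"}' + TICKS
print(extract_answer(fenced), repr(extract_answer(same_line)))  # Causal ''
```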

In [None]:
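r""" Bootstrapping

Use zero-shot GPT to generate an initial set of CoTRecipes.
Input: a list of queries, each paired with its labelled answer.
Output: recipes whose generated answer matches the label (successes), and the rest (failures).
"""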
import pandas as pd
from dataclasses import dataclass

@dataclass
class CoTRecipe:   
    query: str     # referred to as 'x' in the paper.
    reasoning: str    # referred to as 'z' in the paper.   -> a joint distribution of this variable is what we're trying to obtain.
    answer: str    # referred to as 'y' in the paper.
    
@dataclass(frozen=True)
class QueryAnswerPair:
    q: str
    a: str
    
def dataset_to_qapairs(df: pd.DataFrame) -> list[QueryAnswerPair]:
    qa_pairs = list()
    for row in df.itertuples():
        qa_pair = QueryAnswerPair(q=row.sentence, a=row.clazz)
        qa_pairs.append(qa_pair)
    return qa_pairs

def bootstrap(pairs: list[QueryAnswerPair]) -> tuple[list[CoTRecipe], list[tuple[CoTRecipe, str]]]:
    batch_answers = [p.a for p in pairs]
    batch_queries = [p.q for p in pairs]
    batch_messages = [bootstrap_prompt.format_prompt(query=q).to_messages() for q in batch_queries]
    assert len(batch_messages) == len(batch_answers) == len(batch_queries), "Mismatched number of queries and answers."
    
    # 1. generate the 'z', i.e. the reasoning, via zero-shot prompting.
    print("+ Making calls to OpenAI. Please wait...")
    results = chat_model.generate(batch_messages)
    print("+ Done.")
    
    success: list[CoTRecipe] = list()
    failed: list[tuple[CoTRecipe, str]] = list()  # (recipe, true answer) pairs
    for idx, (generation, query, true_answer) in enumerate(zip(results.generations, batch_queries, batch_answers)):
        if len(generation) > 1:
            print(f"+ Unexpected: more than one generation for {idx=}")
        rapair = parse(generation[0].text)
        recipe = CoTRecipe(query=query, reasoning=rapair.r, answer=rapair.a)
        if rapair.a == true_answer:
            success.append(recipe)
        else:
            failed.append((recipe, true_answer))
    return success, failed 

"Bootstrap logic defined. Please continue."
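Before spending API calls, the success/failure split can be tried offline. A minimal sketch that mirrors the filtering in `bootstrap()`, assuming the model replies with a bare JSON object; `fake_model` is a hypothetical rule-based stand-in for `ChatOpenAI`, `DemoRecipe` stands in for `CoTRecipe`, and the sentences and class names are made up:

```python
import json
from dataclasses import dataclass
from typing import Callable

@dataclass
class DemoRecipe:  # hypothetical stand-in for CoTRecipe
    query: str
    reasoning: str
    answer: str

def bootstrap_offline(pairs: list[tuple[str, str]], model: Callable[[str], str]):
    """Keep a generated reasoning only when its answer equals the gold label."""
    success: list[DemoRecipe] = []
    failed: list[tuple[DemoRecipe, str]] = []
    for query, gold in pairs:
        try:
            reply = json.loads(model(query))
            reasoning, answer = reply.get("reason", ""), reply.get("answer", "")
        except (json.JSONDecodeError, AttributeError):
            reasoning, answer = "", ""  # an unparseable reply counts as a failure
        recipe = DemoRecipe(query=query, reasoning=reasoning, answer=answer)
        if answer == gold:  # exact string match, as in bootstrap()
            success.append(recipe)
        else:
            failed.append((recipe, gold))
    return success, failed

def fake_model(query: str) -> str:
    # Hypothetical rule: anything containing "because" is labelled Causal.
    label = "Causal" if "because" in query else "NotCausal"
    return json.dumps({"reason": f"Keyword check on: {query}", "answer": label})

demo_pairs = [
    ("The road flooded because of the storm.", "Causal"),
    ("The meeting is on Tuesday.", "NotCausal"),
    ("Prices rose after the announcement.", "Causal"),
]
demo_success, demo_failed = bootstrap_offline(demo_pairs, fake_model)
print(len(demo_success), len(demo_failed))  # 2 1
```

The third sentence is causal but lacks the keyword, so it lands in the failures together with its gold label, which is the same `(recipe, true_answer)` shape that `bootstrap()` returns.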

## 5. Upload labelled dataset

Please upload your Excel dataset with the 'clazz' and 'sentence' columns.
The 'sentence' column holds the queries to classify; the 'clazz' column holds your labels, which are checked against the GPT-generated answers.

In [None]:
finput, dataset = fileuploader(ext='.xlsx')
finput

In [None]:
assert dataset.get('data'), "Did you upload your dataset?"

df = pd.read_excel(dataset.get('data'))
assert "clazz" in df.columns, "Missing 'clazz' column in dataset."
assert "sentence" in df.columns, "Missing 'sentence' column in dataset."

qa_pairs = dataset_to_qapairs(df)
f"Converted to {len(qa_pairs)} query-answer pairs. Run the next cell to start the bootstrap process."
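The model is expected to answer with one of the class names from the instruction file, and `bootstrap()` compares that answer to the label with exact string equality. A label that differs only in case or trailing whitespace can therefore never count as a success. A minimal pre-flight check, sketched with a hypothetical helper and made-up labels:

```python
def unknown_labels(labels: list[str], class_names: list[str]) -> dict[str, int]:
    """Count labels that are not exactly one of the class names (hypothetical helper)."""
    known = set(class_names)
    counts: dict[str, int] = {}
    for label in labels:
        if label not in known:
            counts[label] = counts.get(label, 0) + 1
    return counts

print(unknown_labels(["Causal", "NotCausal", "causal", "Causal "], ["Causal", "NotCausal"]))
# {'causal': 1, 'Causal ': 1}
```

In the notebook, the inputs would be `df['clazz'].astype(str).tolist()` and the `classes` list built in step 3; an empty result means every label matches a class name.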

## 6. Bootstrap: Ask GPT to generate the CoT exemplars

In [None]:
recipes, failed = bootstrap(qa_pairs)
f"Number of successes: {len(recipes)} fails: {len(failed)}"
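A per-class breakdown shows which classes will end up with few or no exemplars in the saved file. A minimal sketch with a hypothetical helper and made-up labels; in the notebook the inputs would be `[r.answer for r in recipes]` (for a success the answer equals the label) and `[gold for _, gold in failed]`:

```python
from collections import Counter

def success_rate_by_class(success_labels: list[str], failed_labels: list[str]) -> dict[str, float]:
    """Fraction of queries per gold class whose zero-shot answer matched the label."""
    ok, bad = Counter(success_labels), Counter(failed_labels)
    return {c: ok[c] / (ok[c] + bad[c]) for c in sorted(ok.keys() | bad.keys())}

rates = success_rate_by_class(["A", "A", "B"], ["A", "B", "B", "C"])
print(rates)
```

A class with a rate of zero produces no exemplars at all, which usually points to a vague class description or to labels that do not match the class names.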

In [None]:
# save recipes as .toml
import toml
import panel as pn
from pathlib import Path

data = dict()

# carry over the optional PREFIX instruction from the uploaded file
prev_data = toml.load(instrs.get('data'))
prefix = prev_data.pop('PREFIX', None)
if prefix is not None:
    data['PREFIX'] = {'instruction': prefix.get('instruction')}

# group successful recipes by class, keeping each class's original instruction
for recipe in recipes:
    if recipe.answer not in data:
        instruction = prev_data.get(recipe.answer, dict()).get('instruction', '')
        data[recipe.answer] = {'examples': [], 'instruction': instruction}
    data[recipe.answer]['examples'].append({'query': recipe.query, 'steps': recipe.reasoning})

path = Path('bootstrapped.toml')

# explicit UTF-8 so non-ASCII text in the reasoning is written correctly on every platform
with open(path, 'w', encoding='utf-8') as h:
    toml.dump(data, h)

pn.widgets.FileDownload(file=str(path), filename=path.name)

In [None]:
# [Optional] Quick check of bootstrapped.toml that was generated.
!cat bootstrapped.toml