In [1]:
import json
from rich.pretty import pprint

In [2]:
# Load the raw ChatGPT chain-of-thought reasons; later cells iterate this as
# {statement_id: reason_text}.
with open("../../data/raw_cot_chatgpt_reasons.json", "r") as cot_file:
    cot = json.loads(cot_file.read())

# Check the typical token length of the reasons

In [3]:
# from transformers import AutoTokenizer
# import seaborn as sns

# tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
# lengths = [len(tokenizer(reason)["input_ids"]) for reason in list(cot.values())]

In [4]:
# sns.histplot(lengths, bins=20, color='skyblue', edgecolor='black', kde=False)

# Check the opening sentence

In [5]:
# Opening phrases that mark a reason for post-processing. Each prefix maps to a
# {statement_id: reason} dict holding the reasons that start with that phrase.
reasons_by_prefixes = {
    "The statement is a": {},
    "The statement claims that": {},
    "The statement \"": {},
    "1. The statement is: \"": {},
    "1. The statement: \"": {},
    "1. The statement says \"": {},
    "To determine if the statement": {}, # Maybe ok, but can be too long
    # "Step 1": {}, # This is fine
    # "The evidence provided": {}, # This is fine
    # "The evidence states": {}, # This is fine
}
# Positions (in cot's iteration order) of every reason matched by some prefix.
idx_to_remove = []

for idx, (statement_id, reason) in enumerate(cot.items()):
    # Prefixes are tried in declaration order; the first one that matches wins.
    for reason_prefix, matched_reasons in reasons_by_prefixes.items():
        if reason.startswith(reason_prefix):
            matched_reasons[statement_id] = reason
            idx_to_remove.append(idx)
            break

for prefix, reasons in reasons_by_prefixes.items():
    print(f"{prefix}: {len(reasons)}")

print(f"Unclassified: {len(cot) - len(idx_to_remove)}")
# Look positions up in a set so this second pass stays linear in len(cot)
# instead of rescanning the idx_to_remove list for every reason.
classified_idx = set(idx_to_remove)
for idx, (statement_id, reason) in enumerate(cot.items()):
    if idx not in classified_idx:
        pprint(f"{statement_id}, {reason}")

The statement is a: 332
The statement claims that: 132
The statement ": 107
1. The statement is: ": 7
1. The statement: ": 1
1. The statement says ": 1
To determine if the statement: 78
Unclassified: 1034


In [6]:
# Every reason that matched none of the known prefixes goes into the "other" bucket.
# A set makes each membership test O(1) instead of scanning the idx_to_remove list.
classified_idx = set(idx_to_remove)
reasons_by_prefixes["other"] = {
    statement_id: reason
    for idx, (statement_id, reason) in enumerate(cot.items())
    if idx not in classified_idx
}

In [7]:
# from matplotlib import pyplot as plt

# for reason_type, reasons in reasons_by_prefixes.items():
#     print(reason_type)
#     lengths = [len(tokenizer(reason)["input_ids"]) for reason in list(reasons.values())]

#     # Create a new figure and axis for each reason_type
#     fig, ax = plt.subplots()
    
#     # Plot the histogram for the current reason_type
#     sns.histplot(lengths, bins=20, color='skyblue', edgecolor='black', kde=False, ax=ax)
    
#     # Customize the plot
#     ax.set_title(f"Histogram for {len(reasons)} statements with '{reason_type}' prefix")
#     ax.set_xlabel("Length of input_ids")
#     ax.set_ylabel("Frequency")

#     # Show the plot for the current reason_type
#     plt.show()

# Postprocess ChatGPT-generated CoT reasons

## Prefix: "The statement is a"

In [8]:
# Start with one separate, empty {statement_id: reason} bucket per known prefix.
postprocessed_reasons_by_prefixes = dict((key, {}) for key in reasons_by_prefixes)

In [9]:
prefix = "The statement is a"

# Verdict boilerplate removed from each reason, in this exact order. The text is
# stripped after every removal, so an earlier removal can affect a later match.
verdict_boilerplate = (
    # Entailment
    "The statement is an entailment because it is supported by the evidence provided.",
    "The statement is an entailment because it accurately reflects the information provided in the evidence. ",
    "The statement is an entailment because ",
    "The statement is an entailment.",
    # Contradiction
    "The statement is a contradiction because ",
)

for statement_id, raw_reason in reasons_by_prefixes[prefix].items():
    # Runs first and without a strip: e.g. "The statement is a contradiction
    # because it claims ..." becomes "The statement claims ...".
    cleaned = raw_reason.replace("is a contradiction because it ", "")

    for phrase in verdict_boilerplate:
        cleaned = cleaned.replace(phrase, "").strip()

    postprocessed_reasons_by_prefixes[prefix][statement_id] = cleaned

pprint(postprocessed_reasons_by_prefixes[prefix])

## Prefix: The statement claims that

In [10]:
prefix = "The statement claims that"

for statement_id in list(reasons_by_prefixes[prefix].keys()):
    # Drop the opening sentence (everything up to the first ". "). Cutting once with
    # partition() leaves the remaining text untouched, so the space after each
    # period is preserved instead of gluing sentences together ("A.B").
    # A reason with no ". " at all yields an empty string.
    postprocessed_statement = reasons_by_prefixes[prefix][statement_id].partition(". ")[2].strip()

    # Remove sentences that only state the verdict.
    postprocessed_statement = postprocessed_statement.replace("However, based on the evidence provided, this statement is a contradiction.", "").strip()
    postprocessed_statement = postprocessed_statement.replace("based on the evidence provided, this statement is contradicted.", "").strip()

    # Strip leading connectives. Slicing by len() also removes the trailing space of
    # the connective, so the startswith() checks below see the real first word.
    for lead_in in ("However, ", "Contradiction: "):
        if postprocessed_statement.startswith(lead_in):
            postprocessed_statement = postprocessed_statement[len(lead_in):]

    # Drop a leading boilerplate sentence: keep only what follows its first period.
    if postprocessed_statement.startswith((
        "To determine if this statement is a contradiction or an entailment,",
        "To determine if this statement is an entailment or a contradiction,",
    )):
        postprocessed_statement = postprocessed_statement.partition(".")[2].strip()
    if postprocessed_statement.startswith("Contradiction occurs because"):
        postprocessed_statement = postprocessed_statement.partition(".")[2].strip()

    # Shorten references to the evidence.
    postprocessed_statement = postprocessed_statement.replace("based on the evidence provided,", "from the evidence,").strip()
    postprocessed_statement = postprocessed_statement.replace("according to the evidence provided,", "from the evidence,").strip()
    postprocessed_statement = postprocessed_statement.replace("evidence provided", "evidence").strip()

    postprocessed_reasons_by_prefixes[prefix][statement_id] = postprocessed_statement.strip()

# pprint(postprocessed_reasons_by_prefixes[prefix])

## Prefix: The statement "

In [11]:
prefix = "The statement \""

# Verdict boilerplate removed after the quoted statement has been cut out. Order
# matters: the text is stripped after every removal.
verdict_boilerplate = (
    "The statement is an entailment based on the evidence provided. ",
    "The statement is an entailment to the evidence provided. ",
    "The statement is an entailment to the presented evidence. ",
    "The statement is an entailment. ",

    "The statement is a contradiction based on the evidence provided. ",
    "The statement is a contradiction to the evidence provided. ",
    "The statement is a contradiction to the presented evidence. ",
    "The statement contradicts the evidence provided because ",
    "The statement contradicts the evidence because ",
    "The statement is a contradiction. ",
    "The statement contradicts the evidence provided. ",
)

for statement_id in list(reasons_by_prefixes[prefix].keys()):
    reason = reasons_by_prefixes[prefix][statement_id]

    # Locate the first quoted span. Passing a start offset to find() returns an
    # absolute index, or -1 when there is no closing quote, so the guard below
    # really detects an unterminated quote.
    quote_start = reason.find('"')
    quote_end = reason.find('"', quote_start + 1) if quote_start != -1 else -1

    if quote_end != -1:
        # Cut the quoted statement together with the character just before the
        # opening quote (the separating space).
        postprocessed_statement = reason[:max(quote_start - 1, 0)] + reason[quote_end + 1:]
    else:
        # No complete quoted span: keep the reason as-is rather than reusing a
        # value left over from a previous iteration.
        postprocessed_statement = reason

    for phrase in verdict_boilerplate:
        postprocessed_statement = postprocessed_statement.replace(phrase, "").strip()

    postprocessed_reasons_by_prefixes[prefix][statement_id] = postprocessed_statement.strip()

# print(reasons_by_prefixes[prefix]["17a821f8-5e68-4bf7-ac01-3f96ddfc5187"])
pprint(postprocessed_reasons_by_prefixes[prefix])

## Prefix: 1. The statement is: "

In [12]:
prefix = "1. The statement is: \""

for statement_id in list(reasons_by_prefixes[prefix].keys()):
    reason = reasons_by_prefixes[prefix][statement_id]

    # Locate the first quoted span. Passing a start offset to find() returns an
    # absolute index, or -1 when there is no closing quote, so the guard below
    # really detects an unterminated quote.
    quote_start = reason.find('"')
    quote_end = reason.find('"', quote_start + 1) if quote_start != -1 else -1

    if quote_end != -1:
        # Cut the quoted statement together with the character just before the
        # opening quote (the separating space).
        postprocessed_statement = reason[:max(quote_start - 1, 0)] + reason[quote_end + 1:]
    else:
        # No complete quoted span: keep the reason as-is rather than reusing a
        # value left over from a previous iteration.
        postprocessed_statement = reason

    # Remove the now-empty first list item and the numbering of the following ones.
    postprocessed_statement = postprocessed_statement.replace("1. The statement is:", "").strip()
    if postprocessed_statement.startswith("2. "):
        postprocessed_statement = postprocessed_statement[len("2. "):]
    for step in range(3, 8):
        postprocessed_statement = postprocessed_statement.replace(f"\n{step}. ", "\n").strip()

    postprocessed_reasons_by_prefixes[prefix][statement_id] = postprocessed_statement.strip()

# print(reasons_by_prefixes[prefix]["17a821f8-5e68-4bf7-ac01-3f96ddfc5187"])
pprint(postprocessed_reasons_by_prefixes[prefix])

## Prefix: 1. The statement: "

In [13]:
prefix = "1. The statement: \""
pprint(reasons_by_prefixes[prefix])

for statement_id in list(reasons_by_prefixes[prefix].keys()):
    reason = reasons_by_prefixes[prefix][statement_id]

    # Locate the first quoted span. Passing a start offset to find() returns an
    # absolute index, or -1 when there is no closing quote, so the guard below
    # really detects an unterminated quote.
    quote_start = reason.find('"')
    quote_end = reason.find('"', quote_start + 1) if quote_start != -1 else -1

    if quote_end != -1:
        # Cut the quoted statement together with the character just before the
        # opening quote (the separating space).
        postprocessed_statement = reason[:max(quote_start - 1, 0)] + reason[quote_end + 1:]
    else:
        # No complete quoted span: keep the reason as-is rather than reusing a
        # value left over from a previous iteration.
        postprocessed_statement = reason

    # Remove the now-empty first list item and the numbering of the following ones.
    postprocessed_statement = postprocessed_statement.replace("1. The statement:", "").strip()
    if postprocessed_statement.startswith("2. "):
        postprocessed_statement = postprocessed_statement[len("2. "):]
    for step in range(3, 8):
        postprocessed_statement = postprocessed_statement.replace(f"\n{step}. ", "\n").strip()

    postprocessed_reasons_by_prefixes[prefix][statement_id] = postprocessed_statement.strip()

# print(reasons_by_prefixes[prefix]["17a821f8-5e68-4bf7-ac01-3f96ddfc5187"])
pprint(postprocessed_reasons_by_prefixes[prefix])

## Prefix: 1. The statement says "

In [14]:
prefix = "1. The statement says \""
pprint(reasons_by_prefixes[prefix])

for statement_id in list(reasons_by_prefixes[prefix].keys()):
    reason = reasons_by_prefixes[prefix][statement_id]

    # Locate the first quoted span. Passing a start offset to find() returns an
    # absolute index, or -1 when there is no closing quote, so the guard below
    # really detects an unterminated quote.
    quote_start = reason.find('"')
    quote_end = reason.find('"', quote_start + 1) if quote_start != -1 else -1

    if quote_end != -1:
        # Cut the quoted statement together with the character just before the
        # opening quote (the separating space).
        postprocessed_statement = reason[:max(quote_start - 1, 0)] + reason[quote_end + 1:]
    else:
        # No complete quoted span: keep the reason as-is rather than reusing a
        # value left over from a previous iteration.
        postprocessed_statement = reason

    # Remove the now-empty first list item and the numbering of the following ones.
    postprocessed_statement = postprocessed_statement.replace("1. The statement says", "").strip()
    if postprocessed_statement.startswith("2. "):
        postprocessed_statement = postprocessed_statement[len("2. "):]
    for step in range(3, 8):
        postprocessed_statement = postprocessed_statement.replace(f"\n{step}. ", "\n").strip()

    postprocessed_reasons_by_prefixes[prefix][statement_id] = postprocessed_statement.strip()

# print(reasons_by_prefixes[prefix]["17a821f8-5e68-4bf7-ac01-3f96ddfc5187"])
pprint(postprocessed_reasons_by_prefixes[prefix])

## Prefix: To determine if the statement

In [15]:
prefix = "To determine if the statement"

# Both orderings of the boilerplate lead-in that ChatGPT opens these reasons with.
lead_ins = (
    "To determine if the statement is a contradiction or an entailment,",
    "To determine if the statement is an entailment or a contradiction,",
)

for statement_id, reason in reasons_by_prefixes[prefix].items():
    # When the reason opens with a lead-in, keep only the text after its first
    # period; other reasons are stored unchanged (and unstripped).
    if reason.startswith(lead_ins):
        reason = reason.partition(".")[2].strip()

    postprocessed_reasons_by_prefixes[prefix][statement_id] = reason

pprint(postprocessed_reasons_by_prefixes[prefix])

## Other type of CoT reasons

In [16]:
# Seed the post-processed "other" bucket with a shallow copy of the raw bucket.
# Sharing the same dict object would let later writes to the post-processed
# bucket overwrite the raw reasons as well.
postprocessed_reasons_by_prefixes["other"] = dict(reasons_by_prefixes["other"])

In [17]:
import re

prefix = "other"
# pprint(reasons_by_prefixes[prefix])

# Matches the "Step <n>:" labels of step-by-step reasons.
step_label = re.compile(r'Step \d+:', re.DOTALL)

# Iterate over a snapshot of the items, since the raw and post-processed buckets
# may be the same dict object.
for statement_id, reason in list(reasons_by_prefixes[prefix].items()):
    # Only reasons that open with "Step 1: " have their step labels removed.
    if reason.startswith("Step 1: "):
        reason = step_label.sub('', reason)

    postprocessed_reasons_by_prefixes[prefix][statement_id] = reason.strip()

# print(reasons_by_prefixes[prefix]["8fee5ce4-3e46-4731-842e-a5b1df451c7d"])
pprint(postprocessed_reasons_by_prefixes[prefix])

In [18]:
# postprocessed_reasons = [reason for reasons in list(postprocessed_reasons_by_prefixes.values()) for reason in list(reasons.values())]
# tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
# lengths = [len(tokenizer(reason)["input_ids"]) for reason in postprocessed_reasons]
# sns.histplot(lengths, bins=100, color='skyblue', edgecolor='black', kde=False)

## If "entailment" or "contradiction" is not present in the last sentence, add "The answer is entailment/contradiction"

In [22]:
answer_template = "The answer is "

with open("../../data/train.json", "r") as train_file:
    train_data = json.load(train_file)

# Flatten the per-prefix buckets into a single {statement_id: reason} mapping.
all_reasons = {}
for bucket in postprocessed_reasons_by_prefixes.values():
    all_reasons.update(bucket)

# Reasons whose statement id has no training example are dropped.
statement_ids_to_remove = [sid for sid in all_reasons if sid not in train_data]

postprocessed_reasons = {}
for sid, reason_text in all_reasons.items():
    if sid not in train_data:
        continue

    # "Last sentence" = the final two dot-separated pieces of the reason.
    closing = ".".join(reason_text.split(".")[-2:])
    names_verdict = "entailment" in closing or "contradiction" in closing
    if not names_verdict:
        # The closing text names no verdict: append the gold label on a new line.
        true_answer = train_data[sid]["Label"]
        reason_text = "\n".join([reason_text, f"{answer_template}{true_answer}"])

    postprocessed_reasons[sid] = reason_text

pprint(postprocessed_reasons)

In [26]:
# Attach each post-processed reason to its training example under "CoT_label".
# The direct lookup raises KeyError for a training id that has no reason.
for statement_id, example in train_data.items():
    example["CoT_label"] = postprocessed_reasons[statement_id]

# Write the augmented training set out as indented JSON.
with open("../../data/train_cot.json", "w") as out_file:
    json.dump(train_data, out_file, indent=4)