# Code for Bidirectional Chain-of-Thought (CoT, two-way) Prompting

## Import necessary libraries

In [None]:
import re
import os
import requests
import csv
import pandas as pd
from tqdm import tqdm
from datetime import datetime

##### Add the KoboldAI API base URL below (e.g. `http://localhost:5000/`)

In [None]:
# Base URL of the KoboldAI server, e.g. "http://localhost:5000/".
endpoint = ""
# Strip any trailing slash so the generate route is joined correctly whether or
# not the endpoint above ends with "/" (plain concatenation would otherwise give
# ".../5000api/v1/generate" for an endpoint without a trailing slash).
api_url = f"{endpoint.rstrip('/')}/api/v1/generate"

##### Telegram Notifications for process completion

In [None]:
# uncomment the variables below to enable telegram notifications

# api_token = "" # insert telegram API token
# chat_id = "" # insert the chat id that should receive notifications

In [None]:
# Uncomment the function below to start receiving notifications on Telegram

# def notify(text='Cell execution completed.'):
#     requests.post('https://api.telegram.org/' + 'bot{}/sendMessage'.format(api_token), params=dict(chat_id=chat_id, text=text))

## Define Prompts

In [None]:
def generate_prompt(topic1, topic2):
    """Build the two prompts used to classify the relation between two topics.

    The first prompt asks the model to reason step by step about ``topic1``
    (the '[TOPIC-A]' placeholder) and ``topic2`` (the '[TOPIC-B]' placeholder);
    the second asks it to pick one of six numbered statements given that
    discussion. Both prompts keep the blank-line/indent padding that surrounds
    the template text.

    Returns:
        tuple[str, str]: ``(prompt1, prompt2)`` with all placeholders filled in.
    """
    lead_padding = "\n    \n"
    tail_padding = "\n\n    "

    reasoning_text = """Classify the relationship between '[TOPIC-A]' and '[TOPIC-B]' by applying the following relationship definitions:
1. '[TOPIC-A]' is-broader-than '[TOPIC-B]' if '[TOPIC-A]' is a super-category of '[TOPIC-B]', that is '[TOPIC-B]' is a type, a branch, or a specialised aspect of '[TOPIC-A]' or that '[TOPIC-B]' is a tool or a methodology mostly used in the context of '[TOPIC-A]' (e.g., car is-broader-than wheel).
2. '[TOPIC-A]' is-narrower-than '[TOPIC-B]' if '[TOPIC-A]' is a sub-category of '[TOPIC-B]', that is '[TOPIC-A]' is a type, a branch, or a specialised aspect of '[TOPIC-B]' or that '[TOPIC-A]' is a tool or a methodology mostly used in the context of '[TOPIC-B]' (e.g., wheel is-narrower-than car).
3. '[TOPIC-A]' is-same-as-than '[TOPIC-B]' if '[TOPIC-A]' and '[TOPIC-B]' are synonymous terms denoting a very similar concept (e.g., 'beautiful' is-same-as-than 'attractive'), including when one is the plural form of the other (e.g., cat is-same-as-than cats).
4. '[TOPIC-A]' is-other-than '[TOPIC-B]' if '[TOPIC-A]' and '[TOPIC-B]' either have no direct relationship or share a different kind of relationship that does not fit into the other defined relationships.

Think step by step by following these sequential instructions:
1) Provide a precise definition for '[TOPIC-A]'.
2) Provide a precise definition for '[TOPIC-B]'.
3) Formulate a sentence that includes both '[TOPIC-A]' and '[TOPIC-B]'.
4) Discuss '[TOPIC-A]' and '[TOPIC-B]' usage and relationship (is-narrower-than, is-broader-than, is-same-as-than, or is-other-than)."""

    choice_text = """Given the previous discussion, determine which one of the following statements is correct:
1. '[TOPIC-A]' is-broader-than '[TOPIC-B]'
2. '[TOPIC-B]' is-narrower-than '[TOPIC-A]'
3. '[TOPIC-A]' is-narrower-than '[TOPIC-B]'
4. '[TOPIC-B]' is-broader-than '[TOPIC-A]'
5. '[TOPIC-A]' is-same-as-than '[TOPIC-B]'
6. '[TOPIC-A]' is-other-than '[TOPIC-B]'

Answer by only stating the number of the correct statement."""

    def fill(body):
        # Pad the template, then substitute [TOPIC-A] before [TOPIC-B].
        padded = lead_padding + body + tail_padding
        return padded.replace("[TOPIC-A]", topic1).replace("[TOPIC-B]", topic2)

    return fill(reasoning_text), fill(choice_text)

##### Function (`simple_parser`) to parse the LLM-generated response and classify the relationship between a pair of research topics

In [None]:
def simple_parser(text):
    """Map the model's answer to the second prompt onto a relationship label.

    The second prompt asks the model to answer with a statement number:
    1/2 -> "broader", 3/4 -> "narrower", 5 -> "same-as", 6 -> "other".
    The last non-blank line of ``text`` is inspected. The first standalone
    digit 1-6 on that line decides the label; if there is none, keywords are
    used as a fallback. Anything unrecognised is classified as "other".

    Args:
        text: raw model answer (may be ``None`` or empty).

    Returns:
        One of "broader", "narrower", "same-as", "other".
    """
    print('simple_parser -' + str(text) + '-')
    if not text:  # None or ''
        return "other"

    # Ignore trailing blank lines so "4\n\n" is still read as "4".
    non_blank = [line for line in text.splitlines() if line.strip()]
    if not non_blank:
        print('Empty answer, setting to "other"')
        return "other"
    last = non_blank[-1]

    # Prefer an explicit statement number. Require it not to be part of a
    # longer number, so "10" or "2.0" inside a topic name does not win over
    # the leading answer number.
    statement_labels = {
        "1": "broader", "2": "broader",
        "3": "narrower", "4": "narrower",
        "5": "same-as",
        "6": "other",
    }
    number = re.search(r'(?<!\d)([1-6])(?!\d)', last)
    if number:
        return statement_labels[number.group(1)]

    # Fallback: the model answered in words instead of a number.
    if "broader" in last:
        return "broader"
    if "narrower" in last:
        return "narrower"
    if "same" in last or "synonymous" in last:
        return "same-as"
    return "other"

In [None]:
def split_text(text):
    """Split ``text`` wherever a newline is immediately followed by an ASCII letter.

    The letter after the newline is consumed as part of the separator, so only
    the first piece is guaranteed intact. Callers in this notebook use
    element ``[0]`` to cut a generation off at the next line that starts with
    a letter (presumably a new speaker turn such as ``user:``).
    """
    turn_boundary = re.compile(r'\n[a-zA-Z]')
    return turn_boundary.split(text)

## Execute LLM

##### Define global variables for conversation history and user/bot settings

In [None]:
# Speaker labels used when formatting turns in the conversation history.
username = "user"
botname = "assistant"
num_lines_to_keep = 20  # NOTE(review): not referenced in the visible cells
# A module-level `global` statement is a no-op and does not create the name;
# initialise it so functions reading it work even before the history is loaded.
conversation_history = ""

##### The cell below defines the KoboldAI generation parameters for the model

In [None]:
def get_prompt(conversation_history, username, text):  # For KoboldAI Generation
    """Build the JSON payload for a KoboldAI ``/api/v1/generate`` request.

    The prompt is the existing ``conversation_history`` followed by a new
    ``username`` turn containing ``text`` and an open turn for the global
    ``botname``. Sampling is configured to be low-temperature and seeded
    (``sampler_seed`` + ``sampler_full_determinism``) so runs are repeatable.
    """
    new_turn = f"{username}: {text}\n\n{botname}:"
    return dict(
        prompt=conversation_history + new_turn,
        # Disable KoboldAI's story/memory/author's-note/world-info injection.
        use_story=False,
        use_memory=False,
        use_authors_note=False,
        use_world_info=False,
        max_context_length=1600,
        max_length=254,
        rep_pen=1.0,
        rep_pen_range=2048,
        rep_pen_slope=0.7,
        temperature=0.1,
        tfs=1,
        top_a=0,
        top_k=100,
        top_p=1,
        typical=1,
        sampler_order=[6, 0, 1, 3, 4, 2, 5],
        singleline=False,
        sampler_seed=42,  # fixed seed
        sampler_full_determinism=True,  # let the seed fully determine the output
        frmttriminc=False,
        frmtrmblln=False,
    )

##### Functions (`handle_message`, `continue_`) that send a prompt built from the conversation history (plus, for `handle_message`, the user input) to KoboldAI and record the reply

In [None]:
def _request_generation(payload):
    """POST ``payload`` to the KoboldAI endpoint and return the cleaned reply text.

    The generated text is cut at the first newline followed by a letter (via
    ``split_text``) and double spaces are collapsed once. Returns ``None``
    when the server does not answer with HTTP 200.
    """
    # Same generous timeout for every call; without one a stalled server hangs forever.
    response = requests.post(api_url, json=payload, timeout=2500)
    if response.status_code != 200:
        print(f"KoboldAI request failed with HTTP status {response.status_code}")
        return None
    text = response.json()['results'][0]['text']
    return split_text(text)[0].replace("  ", " ")


def _append_history(entry):
    """Append ``entry`` to the in-memory conversation history and the on-disk log."""
    global conversation_history
    conversation_history += entry
    with open(f'conv_history_{botname}_terminal.txt', "a") as f:
        f.write(entry)


def handle_message(user_message):
    """Send ``user_message`` as a new user turn and return the model's reply.

    On success the exchange is appended to ``conversation_history`` and to
    ``conv_history_{botname}_terminal.txt``, and the reply is returned with
    newlines removed. On an HTTP error nothing is recorded and ``""`` is
    returned, so callers can still take ``len()`` of the result.
    """
    prompt = get_prompt(conversation_history, username, user_message)
    response_text = _request_generation(prompt)
    if response_text is None:
        return ""
    _append_history(f"{username}: {user_message}\n{botname}: {response_text}\n")
    return response_text.replace("\n", "")


def continue_():
    """Ask the model to keep generating after the current conversation history.

    No new user turn is added: the raw history is sent as the prompt. On
    success the continuation is appended to the history and the log file and
    returned without newlines. On an HTTP error ``""`` is returned.
    """
    prompt = get_prompt(conversation_history, "", "")
    prompt['prompt'] = conversation_history  # continue as-is, no extra turn markers
    response_text = _request_generation(prompt)
    if response_text is None:
        return ""
    _append_history(f"{response_text}\n")
    return response_text.replace("\n", "")

##### Function (`classify`) to classify the relationship between a pair of research topics using LLM-generated responses

In [None]:
def classify(topic1, topic2, max_num_continue=3, verbose=True, answer_min_size=200):
    """Classify the relation of ``topic1`` to ``topic2`` with a two-step CoT dialogue.

    Prompt 1 asks the model to reason about the two topics. If the reasoning
    is shorter than ``answer_min_size`` words, generation is continued up to
    ``max_num_continue`` times. Prompt 2 then asks for a numbered verdict,
    which ``simple_parser`` maps to "broader", "narrower", "same-as" or "other".

    Args:
        topic1, topic2: the topic pair, in the direction being asked about.
        max_num_continue: maximum number of continuation requests.
        verbose: print response lengths and the raw final answer.
        answer_min_size: word count the reasoning must exceed to stop continuing.
    """
    prompt1, prompt2 = generate_prompt(topic1, topic2)
    words_p1 = len(prompt1.split())
    # Measure the answer relative to the history as it was before this call,
    # so earlier turns in the history are not counted as part of the answer.
    # (Speaker labels such as "user:" are still included; this is approximate.)
    history_words_before = len(conversation_history.split())

    r = handle_message(prompt1)  # Submit prompt 1
    if verbose:
        print("Response length:", len(r or ""))

    for _ in range(max_num_continue):
        answer_words = len(conversation_history.split()) - history_words_before - words_p1
        if verbose:
            print('Words in response:', answer_words)
        if answer_words > answer_min_size:
            break  # answer is long enough; further iterations would do nothing
        r = continue_()  # Keep generating after the current history
        if verbose:
            print("Response length:", len(r or ""))
        if not r:
            break  # empty continuation: the model has nothing more to add

    result = handle_message(prompt2)  # Submit prompt 2
    if verbose:
        print(result)
    return simple_parser(result)


In [None]:
# Reload any previously saved chat log. Mode 'a+' creates the file if it does
# not exist yet; seek(0) rewinds because 'a+' starts positioned at the end.
history_path = f'conv_history_{botname}_terminal.txt'
with open(history_path, 'a+') as file:
    file.seek(0)
    conversation_history = chathistory = file.read()

# Provide a hint to the user
print(f"Loaded conversation history from `{history_path}` file.")

## Data Preparation

In [None]:
# Filled by verify_dataset(): subjects, objects and gold labels, row-aligned.
topics1, topics2, target_labels = [], [], []

##### Verify and preprocess the dataset for classification of relationship between research topics

In [None]:
def verify_dataset(dataset):
    """Extract the topic pairs and gold labels from ``dataset`` into globals.

    Reads the ``subject``, ``object`` and ``simplified_label`` columns. Rows are
    filtered jointly: a row is kept only if all three fields are non-empty.
    Filtering each column on its own would silently shift pairs out of
    alignment whenever a blank cell appears in only one column.

    Sets the globals ``topics1``, ``topics2`` and ``target_labels``.

    Returns:
        dict with the number of kept rows and the three aligned lists.
    """
    global topics1, topics2, target_labels

    all_rows = list(zip(dataset['subject'], dataset['object'], dataset['simplified_label']))
    complete_rows = [row for row in all_rows if all(row)]
    dropped = len(all_rows) - len(complete_rows)
    if dropped:
        print(f"Dropped {dropped} row(s) with an empty subject, object or label.")

    topics1 = [subject for subject, _, _ in complete_rows]
    topics2 = [obj for _, obj, _ in complete_rows]
    target_labels = [label for _, _, label in complete_rows]

    return {
        'Number of topics': len(topics1),
        'Topics 1': topics1,
        'Topics 2': topics2,
        'Target labels': target_labels,
    }


# keep_default_na=False keeps empty cells as '' instead of NaN.
dataset = pd.read_csv('../datasets/PEM-Rel-8k/PEM-Rel-8k-Test.csv', encoding="UTF-8", keep_default_na=False)
dataset_properties = verify_dataset(dataset)

# Print dataset properties; the label list is summarised as frequencies.
print("Dataset verification successful!")
for property_name, property_value in dataset_properties.items():
    if property_name != 'Target labels':
        print(f"{property_name}: {property_value}")
        continue
    label_counts = pd.Series(property_value).value_counts()
    print("Label frequencies:")
    print(label_counts)

## Save the Output

In [None]:
# Generate a timestamp for the file name
# Generate a timestamp for the file name. `datetime` is the class imported via
# `from datetime import datetime`, so `now()` is called on it directly;
# `datetime.datetime.now()` would raise AttributeError.
current_time = datetime.now()
timestamp = current_time.strftime("%Y-%m-%d_%H-%M-%S")

# Define the directory and file name for the generated file
results_directory = "../results"
file_prefix = "PEM-Rel-8K"
file_extension = ".csv"
file_name = f"{file_prefix}_{timestamp}{file_extension}"
file_path = os.path.normpath(os.path.join(results_directory, file_name))

print(f"File is saved at: {file_path}")

In [None]:
# Create the results directory if it is missing (otherwise open() raises
# FileNotFoundError), then write the CSV header row. The file is closed when
# this `with` block exits.
os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
with open(file_path, 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['subject', 'object', 'simplified_label', 'predicted_label', 'Predicted 1', 'Predicted 2', 'log'])

In [None]:
# Accumulators for the inference loop: predictions, per-pair dialogue logs,
# per-gold-label counts and the total number of processed pairs.
results, results_history = [], []
n_cat, n = {}, 0

## Model Inference

In [None]:
def _combine_predictions(topic1, topic2, result1, result2):
    """Merge the two directional labels into one prediction for (topic1, topic2).

    ``result1`` is the label for (topic1, topic2) and ``result2`` the label for
    the reversed pair (topic2, topic1), so consistent answers point in opposite
    directions ("broader"/"narrower").
    """
    pair = (result1, result2)
    if pair == ('broader', 'narrower'):
        return 'broader'
    if pair == ('narrower', 'broader'):
        return 'narrower'
    if result1 == result2 and result1 in ('broader', 'narrower'):
        # Contradictory answers: assume the shorter topic name is the broader one.
        return 'broader' if len(topic1) <= len(topic2) else 'narrower'
    if 'same-as' in pair:
        return 'same-as'
    if pair in (('broader', 'other'), ('other', 'narrower')):
        return 'broader'
    if pair in (('narrower', 'other'), ('other', 'broader')):
        return 'narrower'
    return result1


# The header cell closes its file when its `with` block ends, so rows are
# appended through a handle opened (and closed) here.
with open(file_path, 'a', newline='') as results_file:
    writer = csv.writer(results_file)
    pairs = zip(topics1, topics2, target_labels)
    for topic1, topic2, target in tqdm(pairs, total=len(topics1), desc="Processing data"):
        n += 1
        n_cat[target] = n_cat.get(target, 0) + 1

        # Each direction starts from an empty context, so neither earlier pairs
        # nor the history loaded from disk can influence the prediction.
        conversation_history = ''
        result1 = classify(topic1, topic2, verbose=False, max_num_continue=0)
        conversation_history1 = conversation_history

        conversation_history = ''
        result2 = classify(topic2, topic1, verbose=False, max_num_continue=0)
        conversation_history2 = conversation_history

        predicted = _combine_predictions(topic1, topic2, result1, result2)
        pair_log = conversation_history1 + conversation_history2

        results.append(predicted)
        results_history.append(pair_log)
        conversation_history = ""

        # Use this pair's values directly: indexing `results` by position would
        # return stale entries if this cell is re-run without resetting it.
        writer.writerow([topic1, topic2, target, predicted, result1, result2, pair_log])
        results_file.flush()  # keep partial results on disk if the run is interrupted

        if n > 100000:
            break

# notify("Process Completed.")    #enable to receive notification on telegram when process is done
print("Process Completed.")