# LLMs for producing taxonomies of research topics

## Classifying relationships with Chain of Thought
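This notebook sends requests to a large language model (through the Amazon Bedrock wrapper or the GPT wrapper) to classify the relationship between pairs of research topics, as a step towards building taxonomies of research topics.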

## Setup

In [None]:
import pandas as pd
from collections import defaultdict
import requests, json, os, io, re, base64, random, time, csv, datetime
from IPython.display import display, HTML
from tqdm import tqdm


%load_ext autoreload
%autoreload 2
from bedrock import BedrockWrapper
from gpt import GPTWrapper

### Initialising the LLM wrapper (Amazon Bedrock or GPT) and global settings

In [None]:
# Backend selector read by the cells below: set it to "bedrock" or "gpt".
# With any other value (including the empty default) no wrapper is created.
use = "" #bedrock or gpt

In [None]:
# Running transcript of the current conversation; handle_message() reads and
# rewrites this module-level string, and the main loop resets it per pair.
conversation_history = ""
botname = "assistant"
username = "user"

# Accumulators filled by the main loop: one predicted label and one full
# conversation transcript per processed relationship.
results = []
results_history = []

# Other values left in the original comment: "TOY-40", "GS_2650"
GOLD_STANDARD_FILE = "../dataset/IEEE-Rel-1K.csv"
RESULTS_FOLDER = "../results"
DYNAMIC = True  # True: timestamped results file name; False: fixed "final" suffix

# Build the wrapper for the selected backend. Fail here with a clear message
# instead of leaving `wrapper` undefined and crashing later in the main loop.
if use == "bedrock":
    wrapper = BedrockWrapper(model="YOUR_MODEL")
elif use == "gpt":
    wrapper = GPTWrapper(api_key="YOUR_API_KEY", model="YOUR_MODEL")
else:
    raise ValueError(f'Unsupported backend {use!r}: set `use` to "bedrock" or "gpt".')

### Prompt area

In [None]:
def generate_prompt(topic1, topic2):
    """Build the two chain-of-thought prompts for a pair of topics.

    The first prompt states the four relationship definitions and asks for a
    step-by-step discussion; the second asks which one of six numbered
    statements is correct, given that discussion.

    Args:
        topic1: Text substituted for every '[TOPIC-A]' placeholder.
        topic2: Text substituted for every '[TOPIC-B]' placeholder.

    Returns:
        A tuple ``(prompt1, prompt2)`` with both placeholders filled in.
    """
    # Both templates are wrapped in the same leading and trailing whitespace.
    lead, trail = "\n    \n", "\n\n    "

    reasoning_body = "\n".join((
        "Classify the relationship between '[TOPIC-A]' and '[TOPIC-B]' by applying the following relationship definitions:",
        "1. '[TOPIC-A]' is-broader-than '[TOPIC-B]' if '[TOPIC-A]' is a super-category of '[TOPIC-B]', that is '[TOPIC-B]' is a type, a branch, or a specialised aspect of '[TOPIC-A]' or that '[TOPIC-B]' is a tool or a methodology mostly used in the context of '[TOPIC-A]' (e.g., car is-broader-than wheel).",
        "2. '[TOPIC-A]' is-narrower-than '[TOPIC-B]' if '[TOPIC-A]' is a sub-category of '[TOPIC-B]', that is '[TOPIC-A]' is a type, a branch, or a specialised aspect of '[TOPIC-B]' or that '[TOPIC-A]' is a tool or a methodology mostly used in the context of '[TOPIC-B]' (e.g., wheel is-narrower-than car).",
        "3. '[TOPIC-A]' is-same-as-than '[TOPIC-B]' if '[TOPIC-A]' and '[TOPIC-B]' are synonymous terms denoting a very similar concept (e.g., 'beautiful' is-same-as-than 'attractive'), including when one is the plural form of the other (e.g., cat is-same-as-than cats).",
        "4. '[TOPIC-A]' is-other-than '[TOPIC-B]' if '[TOPIC-A]' and '[TOPIC-B]' either have no direct relationship or share a different kind of relationship that does not fit into the other defined relationships.",
        "",
        "Think step by step by following these sequential instructions:",
        "1) Provide a precise definition for '[TOPIC-A]'.",
        "2) Provide a precise definition for '[TOPIC-B]'.",
        "3) Formulate a sentence that includes both '[TOPIC-A]' and '[TOPIC-B]'.",
        "4) Discuss '[TOPIC-A]' and '[TOPIC-B]' usage and relationship (is-narrower-than, is-broader-than, is-same-as-than, or is-other-than).",
    ))

    decision_body = "\n".join((
        "Given the previous discussion, determine which one of the following statements is correct:",
        "1. '[TOPIC-A]' is-broader-than '[TOPIC-B]'",
        "2. '[TOPIC-B]' is-narrower-than '[TOPIC-A]'",
        "3. '[TOPIC-A]' is-narrower-than '[TOPIC-B]'",
        "4. '[TOPIC-B]' is-broader-than '[TOPIC-A]'",
        "5. '[TOPIC-A]' is-same-as-than '[TOPIC-B]'",
        "6. '[TOPIC-A]' is-other-than '[TOPIC-B]'",
        "",
        "Answer by only stating the number of the correct statement.",
    ))

    def fill(template):
        # '[TOPIC-A]' is substituted first, then '[TOPIC-B]'.
        return template.replace("[TOPIC-A]", topic1).replace("[TOPIC-B]", topic2)

    return fill(lead + reasoning_body + trail), fill(lead + decision_body + trail)

### Routines for classification and parsing

In [None]:
def parser_for_mistral_amazon_bedrock(text:str, verbose:bool=False)->str:
    """Return the first line of a model answer that contains a digit.

    Everything from the last occurrence of "Explanation:" onwards is dropped
    before scanning, so digits in a trailing explanation are not mistaken for
    the chosen statement.

    Args:
        text: Raw answer returned by the model.
        verbose: When True, print the list of lines being scanned.

    Returns:
        The first line containing at least one digit, or "6" (the number of
        the 'is-other-than' statement) when the answer is empty or has no
        digits.
    """
    default = "6"
    if not text:
        return default

    text = text.strip()
    # rfind() returns -1 when the marker is absent; slicing with -1 would
    # silently drop the last character (turning a bare "3" into ""), so the
    # answer is only truncated when the marker is really there.
    cut = text.rfind("Explanation:")
    if cut != -1:
        text = text[:cut]
    if not text:
        return default

    lines = text.splitlines()
    if verbose: print(lines)
    for line in lines:
        if re.search(r'\d', line):
            return line
    return default


def parser_for_cohere_amazon_bedrock(text:str, verbose:bool=False)->str:
    """Extract the answer number from the first line that contains a digit.

    Args:
        text: Raw answer returned by the model.
        verbose: When True, print the list of lines being scanned.

    Returns:
        For the first line containing a digit: the first run of digits
        together with the single character that follows it (e.g. "3." or
        "12 "). If no character follows the digits, the whole line is
        returned (and printed). "6" (the number of the 'is-other-than'
        statement) is returned when the answer is empty or has no digits.
    """
    default = "6"
    if not text:
        return default

    text = text.strip()
    if not text:
        return default

    lines = text.splitlines()
    if verbose: print(lines)
    for line in lines:
        if not re.search(r'\d', line):
            continue
        # Raw string avoids the invalid "\d" escape; the match is tested
        # explicitly instead of relying on a bare `except` around .group().
        match = re.search(r'\d+.', line)
        if match:
            return match.group()
        # The only digit is the last character of the line: fall back to the
        # whole line (the diagnostic print is kept from the original).
        print(line)
        return line
    return default


def gpt(text:str, verbose:bool=False)->str:
    """Return the last line of a GPT answer.

    Args:
        text: Raw answer returned by the model.
        verbose: When True, print the list of lines of the stripped answer.

    Returns:
        The final line of the answer after stripping surrounding whitespace,
        or "6" when nothing is left after stripping.
    """
    stripped = text.strip()
    if not stripped:
        return "6"

    lines = stripped.splitlines()
    if verbose:
        print(lines)
    return lines[-1]
    


## Converts the number chosen by the model into the corresponding relationship label
def simple_parser(text:str)->str:
    """Map the model's final answer to a relationship label.

    The answer line is extracted with the parser matching the selected
    backend (`use`), and the last digit found in it is interpreted as the
    number of the chosen statement.

    Args:
        text: Raw answer of the model to the second prompt.

    Returns:
        "broader" for statements 1-2, "narrower" for 3-4, "same-as" for 5,
        and "other" for 6, for any other digit, or when the extracted line
        contains no digit at all.

    Raises:
        ValueError: If `use` is neither "bedrock" nor "gpt".
    """
    if use == "bedrock":
        last = parser_for_mistral_amazon_bedrock(text)
    elif use == "gpt":
        last = gpt(text)
    else:
        # Previously `last` stayed unbound here and the next line crashed.
        raise ValueError(f'Unsupported backend {use!r}: `use` must be "bedrock" or "gpt".')

    digits = re.findall(r'\d', last)
    if not digits:
        # E.g. the last line of a GPT answer may be plain text; treat it like
        # the parsers' own default (statement 6) instead of raising IndexError.
        return "other"

    # The last digit wins, should there be more than one.
    labels = {
        "1": "broader", "2": "broader",
        "3": "narrower", "4": "narrower",
        "5": "same-as",
    }
    return labels.get(digits[-1], "other")


In [None]:
def handle_message(user_message, wrapper, verbose = False):
    """Send one user turn to the model and record the exchange.

    The prompt sent to the model is the accumulated `conversation_history`
    (if any) followed by the new user turn and an open assistant turn. The
    exchange is then appended both to the module-level history and to the
    `conv_history_{botname}_terminal.txt` log file.

    Args:
        user_message: Text of the new user turn.
        wrapper: Object exposing `invoke_model(prompt, verbose=..., test=...)`.
        verbose: Forwarded to `wrapper.invoke_model`.

    Returns:
        Whatever `wrapper.invoke_model` returns (presumably the response
        text; confirm against the wrapper implementation).
    """
    global conversation_history

    new_turn = f"{username}: {user_message}\n\n{botname}:"
    if conversation_history == "":
        new_user_message = new_turn
    else:
        new_user_message = f"{conversation_history}\n\n{new_turn}"

    response_text = wrapper.invoke_model(new_user_message, verbose = verbose, test = False)

    # Update the conversation history with the user message and bot response
    conversation_history = f"{conversation_history}\n\n{username}: {user_message}\n\n{botname}: {response_text}\n"

    # Explicit UTF-8: the platform default encoding may be unable to encode
    # non-ASCII characters in topics or model output.
    with open(f'conv_history_{botname}_terminal.txt', "a", encoding="utf-8") as f:
        f.write(f"{username}: {user_message}\n\n{botname}: {response_text}\n") # Append conversation to text file

    return response_text
        
        
    
# this is the core function
def classify(topic1, topic2, wrapper, max_num_continue=3, verbose=True, answer_min_size=200):
    """Classify the relationship between two topics with a two-step dialogue.

    The first prompt makes the model discuss the two topics; the second asks
    it to choose one numbered statement, which is then converted to a label.
    Both turns go through handle_message(), so they share (and extend) the
    module-level conversation history.

    Args:
        topic1: First topic ('[TOPIC-A]' in the prompts).
        topic2: Second topic ('[TOPIC-B]' in the prompts).
        wrapper: Model wrapper forwarded to handle_message().
        max_num_continue: Currently unused; kept for interface compatibility.
        verbose: When True, print the discussion, its length and the final answer.
        answer_min_size: Currently unused; kept for interface compatibility.

    Returns:
        The label produced by simple_parser() for the model's final answer.
    """
    prompt1, prompt2 = generate_prompt(topic1, topic2) # we create both prompts

    discussion = handle_message(prompt1, wrapper, verbose = False) # submit prompt 1
    if verbose:
        print("response:" + discussion)
        print("response len:", len(discussion) )

    answer = handle_message(prompt2, wrapper, verbose = False) # submit prompt 2
    if verbose: print(answer)
    return simple_parser(answer)

### Load dataset

In [None]:
# GOLD_STANDARD_FILE may be a bare dataset name or a path that already ends
# in ".csv"; only append the extension when it is missing, otherwise the
# path becomes "...csv.csv" and the read fails.
dataset_path = GOLD_STANDARD_FILE if GOLD_STANDARD_FILE.lower().endswith(".csv") else f"{GOLD_STANDARD_FILE}.csv"

# keep_default_na=False is passed so that pandas does not apply its default
# NaN markers while parsing the file.
dataset = pd.read_csv(dataset_path, encoding = "UTF-8", keep_default_na=False)
dataset = dataset[["subject", "object","original_label"]]

# Translate the gold-standard labels into the labels produced by the
# classifier; labels not listed here are left unchanged.
mapping = {"supertopic":"broader", "subtopic":"narrower", "same_as":"same-as", "not_related":"other"}
dataset["original_label"] = dataset["original_label"].apply(lambda x: mapping.get(x, x))
print(f"Total number of rows: {len(dataset)}")
dataset.head()

### Creating CSV 
This file will store the results of the classification.

In [None]:
###########################################
# create file name
current_time = datetime.datetime.now()
# Timestamp such as 2024-01-31_12-30-45 when DYNAMIC, fixed suffix otherwise.
ttime = current_time.strftime("%Y-%m-%d_%H-%M-%S") if DYNAMIC else "final"
# BOTH_ORDER is not defined in the visible cells; default to True because
# the main loop classifies every pair in both orders.
sides = "DOUBLE-SIDED" if globals().get("BOTH_ORDER", True) else "SINGLE-SIDED"
num_prompts = "DOUBLE-PROMPT"

# Use only the dataset's base name (without directory or ".csv"): embedding
# the full GOLD_STANDARD_FILE path would point outside the results folder.
dataset_name = os.path.basename(GOLD_STANDARD_FILE)
if dataset_name.lower().endswith(".csv"):
    dataset_name = dataset_name[:-len(".csv")]
results_file_name= f'{RESULTS_FOLDER}/GPT4/{dataset_name}_{ttime}.csv'
os.makedirs(os.path.dirname(results_file_name), exist_ok=True)
print(f"Results will be saved in {results_file_name}")

###########################################
# initialize csv

# The handle is deliberately left open: the main loop in a later cell keeps
# writing rows through `writer` and flushes after each one.
file=open(results_file_name, 'w', newline='', encoding='utf-8')
writer = csv.writer(file)
writer.writerow(['subject', 'object', 'original_label', 'predicted_label', 'Predicted1', 'Predicted2', 'log', 'tokens_used'])
file.flush()

## Iterating over the relationships of the Gold standard

In [None]:
# Classify every gold-standard pair in both orders, merge the two answers
# into one prediction and write one CSV row per pair.
relationship_processed = 0
VERBOSE = False
START = 0             # first dataset index to process (raise it to resume a run)
END = len(dataset)    # exclusive upper bound

with tqdm(total=END-START) as pbar:
    for idx, row in dataset.iterrows():
        if not (START <= idx < END):
            continue

        topic1 = row["subject"]
        topic2 = row["object"]
        target = row["original_label"]

        # Start from a clean conversation even if a previous run of this cell
        # was interrupted before its end-of-iteration reset.
        conversation_history = ""

        # Direct order: topic1 vs topic2.
        result1 = classify(topic1, topic2, wrapper, verbose=False, max_num_continue=0)
        conversation_history1 = conversation_history
        conversation_history = ''

        # Reverse order, in a fresh conversation; result2 is therefore
        # expressed from the point of view of topic2.
        result2 = classify(topic2, topic1, wrapper, verbose=False, max_num_continue=0)
        conversation_history2 = conversation_history
        conversation_history = conversation_history1 + conversation_history2

        if result1 == 'broader' and result2 == 'narrower':
            predicted = 'broader'   # both orders agree
        elif result1 == 'narrower' and result2 == 'broader':
            predicted = 'narrower'  # both orders agree
        elif (result1 == 'narrower' and result2 == 'narrower') or (result1 == 'broader' and result2 == 'broader'):
            # Contradiction: the topic with the shorter name is taken as the broader one.
            if len(topic1) <= len(topic2):
                predicted = 'broader'
            else:
                predicted = 'narrower'
        elif result1 == 'same-as' or result2 == 'same-as':
            predicted = 'same-as'   # if one of them is same-as, keep it
        elif result1 == 'broader' and result2 == 'other' or result1 == 'other' and result2 == 'narrower':
            predicted = 'broader'   # direction wins over other
        elif result1 == 'narrower' and result2 == 'other' or result1 == 'other' and result2 == 'broader':
            predicted = 'narrower'  # direction wins over other
        else:
            predicted = result1

        results.append(predicted)
        results_history.append(conversation_history)

        if idx % 10 == 0:
            print(f"Computed {idx} iterations")

        if VERBOSE:
            if predicted == target:
                print("Matched", end=" ")

        # The header declares eight columns; 'tokens_used' is not available
        # here, so an empty field keeps every row aligned with the header.
        writer.writerow([topic1, topic2, target, predicted, result1, result2, conversation_history, ""])
        file.flush()
        relationship_processed += 1

        # resetting conversation
        conversation_history = ""
        pbar.update(1)

display(HTML(f"""<a href="{results_file_name}">To see these results click here.</a>"""))