
### Import necessary libraries

In [99]:
import re
import requests
import csv
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from IPython.display import display, HTML

##### Add KoboldAI API endpoint URL below

In [113]:
# Base URL of the KoboldAI server, written with its trailing slash
# (for example "http://localhost:5000/"); the API path is appended to it below.
endpoint = ""

In [None]:
# Full URL of the text-generation route. The base URL is normalised first so
# that it works whether or not it was entered with a trailing "/"
# (plain concatenation would otherwise yield "...hostapi/v1/generate").
api_url = f"{endpoint.rstrip('/')}/api/v1/generate"

##### Telegram Notifications for process completion

In [114]:
# uncomment the variables below to enable telegram notifications

# api_token = "" # insert telegram API token
# chat_id = "" # insert the chat id for receiving notifications

In [115]:
# uncomment the function below to start receiving notifications on Telegram

# def notify(text='Cell execution completed.'):
#     requests.post('https://api.telegram.org/' + 'bot{}/sendMessage'.format(api_token), params=dict(chat_id=chat_id, text=text))

### Define the classification prompt

In [116]:
def generate_prompt(topic1, topic2):
    """Build the classification prompt for an ordered pair of topics.

    The prompt first defines the four relationships (broader, narrower,
    same-as, other) and then lists six numbered statements about the pair,
    asking which one is correct.

    Args:
        topic1: Text substituted for every ``[TOPIC-A]`` placeholder.
        topic2: Text substituted for every ``[TOPIC-B]`` placeholder.

    Returns:
        The prompt string with both placeholders filled in.
    """
    # The template is assembled from adjacent literals, one per prompt line,
    # so that the leading/trailing whitespace of the prompt is explicit.
    template = (
        "\n"
        "    \n"
        "Classify the relationship between '[TOPIC-A]' and '[TOPIC-B]' by applying the following relationship definitions:\n"
        "1. '[TOPIC-A]' is-broader-than '[TOPIC-B]' if '[TOPIC-A]' is a super-category of '[TOPIC-B]', that is '[TOPIC-B]' is a type, a branch, or a specialized aspect of '[TOPIC-A]' or that '[TOPIC-B]' is a tool or a methodology mostly used in the context of '[TOPIC-A]' (e.g., car is-broader-than wheel).\n"
        "2. '[TOPIC-A]' is-narrower-than '[TOPIC-B]' if '[TOPIC-A]' is a sub-category of '[TOPIC-B]', that is '[TOPIC-A]' is a type, a branch, or a specialized aspect of '[TOPIC-B]' or that '[TOPIC-A]' is a tool or a methodology mostly used in the context of '[TOPIC-B]' (e.g., wheel is-narrower-than car).\n"
        "3. '[TOPIC-A]' is-same-as-than '[TOPIC-B]' if '[TOPIC-A]' and '[TOPIC-B]' are synonymous terms denoting an identical concept (e.g., beautiful is-same-as-than attractive), including when one is the plural form of the other (e.g., cat is-same-as-than cats).\n"
        "4. '[TOPIC-A]' is-other-than '[TOPIC-B]' if '[TOPIC-A]' and '[TOPIC-B]' either have no direct relationship or share a different kind of relationship that does not fit into the other defined relationships.\n"
        "\n"
        "Given the previous definitions, determine which one of the following statements is correct:\n"
        "1. '[TOPIC-A]' is-broader-than '[TOPIC-B]'\n"
        "2. '[TOPIC-B]' is-narrower-than '[TOPIC-A]'\n"
        "3. '[TOPIC-A]' is-narrower-than '[TOPIC-B]'\n"
        "4. '[TOPIC-B]' is-broader-than '[TOPIC-A]'\n"
        "5. '[TOPIC-A]' is-same-as-than '[TOPIC-B]'\n"
        "6. '[TOPIC-A]' is-other-than '[TOPIC-B]'\n"
        "\n"
        "    "
    )

    # Placeholders are filled one after the other, [TOPIC-A] first.
    filled = template
    for placeholder, topic in (("[TOPIC-A]", topic1), ("[TOPIC-B]", topic2)):
        filled = filled.replace(placeholder, topic)
    return filled

### Execute LLM

##### Use the cell below to establish parameters for LLM and then parse the resulting output.

In [117]:
# Statement numbers offered by the prompt, mapped to the relationship of
# TOPIC-A towards TOPIC-B that each statement expresses.
_STATEMENT_LABELS = {
    "1": "broader",
    "2": "broader",
    "3": "narrower",
    "4": "narrower",
    "5": "same-as",
    "6": "other",
}

# Keyword fallback, tried in this order when no statement number is present.
_KEYWORD_LABELS = (
    ("broader", ("broader",)),
    ("narrower", ("narrower",)),
    ("same-as", ("same", "synonymous")),
    ("other", ("other", "distinct")),
)


def _label_from_answer_line(answer_line):
    """Translate the model's answer line into a relationship label.

    A standalone statement number (1-6) takes precedence over keywords:
    statement 4 reads "'B' is-broader-than 'A'" and therefore contains the
    word "broader" although it means "narrower".
    """
    # Topic names are echoed inside single quotes and may contain digits
    # (e.g. 'web 2.0'); remove them so they are not read as a statement number.
    without_topics = re.sub(r"'[^']*'", " ", answer_line)

    # A digit 1-6 that is not part of a word, a longer number or a decimal.
    numbered = re.search(r"(?<![\w.])([1-6])(?!\w|\.\d)", without_topics)
    if numbered:
        return _STATEMENT_LABELS[numbered.group(1)]

    for label, keywords in _KEYWORD_LABELS:
        if any(keyword in answer_line for keyword in keywords):
            return label
    return "other"


def execute_llm(prompt, api_url, is_final_prompt=False):
    """Send a prompt to the generation endpoint and return the model output.

    Args:
        prompt: Prompt text sent to the model.
        api_url: URL of the generate endpoint that receives the POST request.
        is_final_prompt: When True, the reply is parsed into a relationship
            label; when False, the raw generated text is returned.

    Returns:
        If ``is_final_prompt`` is False: the generated text (``""`` when the
        server does not answer with HTTP 200 or returns no result).
        If ``is_final_prompt`` is True: a tuple ``(label, generated_text)``
        where ``label`` is one of ``"broader"``, ``"narrower"``, ``"same-as"``
        or ``"other"``; ``("other", "")`` when the request fails or the reply
        is empty.

    Exceptions raised by ``requests.post`` or by JSON decoding propagate to
    the caller.
    """
    # Sampling settings sent with every request (low temperature, fixed seed).
    request_body = {
        "prompt": prompt,
        "use_story": False,
        "use_memory": False,
        "use_authors_note": False,
        "use_world_info": False,
        "max_context_length": 1600,
        "max_length": 254,
        "rep_pen": 1.0,
        "rep_pen_range": 2048,
        "rep_pen_slope": 0.7,
        "temperature": 0.1,
        "tfs": 1,
        "top_a": 0,
        "top_k": 100,
        "top_p": 1,
        "typical": 1,
        "sampler_order": [6, 0, 1, 3, 4, 2, 5],
        "singleline": False,
        "sampler_seed": 42,
        "sampler_full_determinism": True,
        "frmttriminc": False,
        "frmtrmblln": False
    }

    response = requests.post(api_url, json=request_body)

    # Keep the return shape consistent with the requested mode even on failure.
    if response.status_code != 200:
        return ("other", "") if is_final_prompt else ""

    # An absent or empty "results" list is treated like an empty generation.
    results = response.json().get("results") or [{}]
    response_text = results[0].get("text", "")

    if not is_final_prompt:
        return response_text

    if not response_text.strip():
        return "other", ""

    # The answer is expected on the last line that actually has content;
    # trailing blank lines in the generation are skipped.
    content_lines = [line for line in response_text.splitlines() if line.strip()]
    return _label_from_answer_line(content_lines[-1]), response_text

##### The function ```classify_relationship``` is responsible for classifying Research Topics and returning the output obtained.

In [118]:
def classify_relationship(topic1, topic2, api_url):
    """Classify how ``topic1`` relates to ``topic2`` with a single LLM call.

    Args:
        topic1: Topic placed in the first (TOPIC-A) position of the prompt.
        topic2: Topic placed in the second (TOPIC-B) position of the prompt.
        api_url: URL of the generate endpoint.

    Returns:
        A tuple ``(label, prompt, raw_output)``: the relationship label
        returned by ``execute_llm``, the prompt that was sent, and the text
        generated by the model.
    """
    prompt_text = generate_prompt(topic1, topic2)
    label, raw_output = execute_llm(prompt_text, api_url, is_final_prompt=True)
    return label, prompt_text, raw_output

### Load Gold Standard

In [119]:
# Path of the gold-standard CSV that is loaded below (plain string: the
# original f-string had no placeholders).
input_relations = "../dataset/GS_1000.csv"

##### Use the cell below to parse ```Gold Standard``` and create a ```pandas``` dataframe

In [None]:
# Load cells as the literal text found in the file. With pandas' default NA
# handling, topics spelled like "null", "NA" or "nan" (and empty cells) are
# loaded as float NaN, which breaks the string operations applied to topics
# later on (str.replace in the prompt, len() in the tie-break rule).
df = pd.read_csv(input_relations, encoding="utf-8", keep_default_na=False)

In [None]:
# Report how many rows were loaded, then preview the first few of them.

num_rows = df.shape[0]
print(f"Total number of rows: {num_rows}")
df.head()

### Creating ```CSV```

In [None]:
# Timestamp of this run (second resolution), used in the name of the results file.

current_datetime = format(datetime.now(), "%Y-%m-%d_%H-%M-%S")

##### Use the cell below to store the output obtained as ```CSV``` in the folder [results](../results/)

In [None]:
# Output path for this run; the run timestamp is embedded in the file name.
predicted_relations = f"../results/GS_1000_{current_datetime}.csv"
print(f"Results will be saved in {predicted_relations}")

### Iterating over the relationships of the Gold standard

In [None]:
def _merge_two_way_predictions(forward, backward, subject, obj):
    """Combine the labels of the two prompt directions into one prediction.

    ``forward`` is the label obtained for (subject, object) and ``backward``
    the label obtained for the swapped pair (object, subject).

    Empirical rules:
      * both directional and different -> the forward label;
      * both directional and equal (a contradiction) -> "broader" when the
        subject is not longer than the object, otherwise "narrower";
      * either one is "same-as" -> "same-as";
      * one directional, the other "other" -> the direction seen from the
        subject (a backward label is mirrored);
      * anything else -> the forward label.
    """
    directional = ("broader", "narrower")

    if forward in directional and backward in directional:
        if forward != backward:
            return forward
        return "broader" if len(subject) <= len(obj) else "narrower"

    if "same-as" in (forward, backward):
        return "same-as"

    pair = (forward, backward)
    if pair in (("broader", "other"), ("other", "narrower")):
        return "broader"
    if pair in (("narrower", "other"), ("other", "broader")):
        return "narrower"
    return forward


with open(predicted_relations, mode="w", newline="", encoding="utf-8") as output_file:

    # Header / column order of the results CSV.
    fieldnames = [
        "subject", "object", "predicted_label", "original_label",
        "Predicted 1", "Predicted 2",
        "Final Prompt 1", "Final Prompt 2",
        "Buffer Output 1", "Buffer Output 2",
    ]

    writer = csv.DictWriter(output_file, fieldnames=fieldnames)
    writer.writeheader()

    for index, row in df.iterrows():
        topic1, topic2, target_label = row["subject"], row["object"], row["original_label"]

        # Two-way strategy: ask once for (subject, object) and once for the swapped pair.
        forward_label, forward_prompt, forward_output = classify_relationship(topic1, topic2, api_url)
        backward_label, backward_prompt, backward_output = classify_relationship(topic2, topic1, api_url)

        merged_label = _merge_two_way_predictions(forward_label, backward_label, topic1, topic2)

        # One output row per gold-standard pair; empty values are written as "".
        writer.writerow({
            "subject": topic1,
            "object": topic2,
            "predicted_label": merged_label or "",
            "original_label": target_label or "",
            "Predicted 1": forward_label or "",
            "Predicted 2": backward_label or "",
            "Final Prompt 1": forward_prompt or "",
            "Final Prompt 2": backward_prompt or "",
            "Buffer Output 1": forward_output or "",
            "Buffer Output 2": backward_output or ""
        })


# notify(f"Predictions for {input_relations} written to {predicted_relations}") # uncomment the line to receive telegram notification on process completion
display(HTML(f"""<a href="{predicted_relations}">Click here to see the results.</a>"""))