In [137]:
import sys
import os
import argparse
import os
import json
from datasets import load_dataset
import multiprocessing as mp
from tqdm import tqdm
from functools import partial
from openai import OpenAI
import networkx as nx
import re


# Hugging Face Hub dataset ID ("owner/name"). This is NOT a filesystem path, so it
# is written literally: os.path.join would insert "\\" on Windows and produce an
# invalid Hub ID.
input_file = "rmanluo/RoG-webqsp"
output_dir = os.path.join("datasets/AlignData", "RoG-webqsp")

# Number of training examples to process; set to None to run the full split.
N_SAMPLES = 10

# exist_ok makes this idempotent on re-run and avoids a check-then-create race.
os.makedirs(output_dir, exist_ok=True)

dataset = load_dataset(input_file, split="train")
if N_SAMPLES is not None:
    # min() guards against splits smaller than N_SAMPLES (select would raise).
    dataset = dataset.select(range(min(N_SAMPLES, len(dataset))))

In [139]:
def get_entity_edges_with_neighbors(entity: str, graph: "nx.Graph") -> "tuple[list, list]":
    """Return the relations and neighbours adjacent to ``entity``.

    Args:
        entity: Node to expand.
        graph: Graph whose edges carry a ``"relation"`` attribute (as set by
            ``build_graph``).

    Returns:
        ``(edges, neighbors)``: two parallel lists where ``edges[i]`` is the
        relation on the edge between ``entity`` and ``neighbors[i]``. Both lists
        are empty when ``entity`` is not a node of ``graph``.
    """
    if not graph.has_node(entity):
        return [], []

    # graph[entity] maps each neighbour to that edge's attribute dict, in the
    # same order graph.neighbors(entity) would yield them.
    adjacency = graph[entity]
    neighbors = list(adjacency)
    edges = [adjacency[neighbor]["relation"] for neighbor in neighbors]
    return edges, neighbors

In [140]:
def query_api(prompt, model="gpt-3.5-turbo-0125", temperature=0, client=None):
    """Send a single-turn chat prompt to the OpenAI API.

    Args:
        prompt: User message text.
        model: Chat model name.
        temperature: Sampling temperature (default 0).
        client: Optional pre-built ``OpenAI`` client to reuse across calls. When
            omitted, a new client is created with the key taken from the
            ``OPENAI_API_KEY`` environment variable.

    Returns:
        dict with keys ``"prompt"`` (the input) and ``"response"`` (content of
        the first returned choice).

    Raises:
        Any exception raised by the OpenAI client (e.g. a context-length 400
        error); callers are expected to catch these.
    """
    if client is None:
        # Credentials come from the environment instead of being hardcoded in
        # the notebook, where they would leak on commit/share.
        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
    )
    return {"prompt": prompt, "response": completion.choices[0].message.content}

In [141]:
def process_str(s):
    """Merge a sequence of "x -> y -> z" hop strings into one path string.

    Every hop is split on " -> " and the pieces are concatenated in order. A
    piece identical to the one immediately before it is dropped, so the entity
    shared by consecutive hops appears only once.

    Args:
        s: Iterable of hop strings.

    Returns:
        The merged path joined with " -> " ("" for empty input).
    """
    merged = []
    for hop in s:
        for token in hop.split(" -> "):
            # Skip a token that merely repeats the previous one.
            if merged and merged[-1] == token:
                continue
            merged.append(token)
    return " -> ".join(merged)

In [142]:
def build_graph(graph: list) -> nx.Graph:
    """Build an undirected NetworkX graph from (head, relation, tail) triples.

    Each triple becomes an edge between head and tail whose ``relation``
    attribute is the whitespace-stripped relation string.

    NOTE(review): nx.Graph is undirected and holds one edge per node pair, so
    head/tail direction is not preserved, and a later triple linking the same
    two nodes overwrites the earlier relation. Confirm this is acceptable for
    the reasoning-path prompts.
    """
    kg = nx.Graph()
    for head, relation, tail in graph:
        kg.add_edge(head, tail, relation=relation.strip())
    return kg

In [144]:
# Upper bound on hops per question, so a model that never answers EOS still stops.
MAX_HOPS = 10

# Option 0 is always offered; its text is sent verbatim to the model.
EOS_OPTION = "0: EOS -> The final entity of current reasoning steps can directly answers the query. End of Selection."


def _build_prompt(question, starting_entity, reasoning_path, options_str):
    """Build the multiple-choice prompt for a single reasoning step.

    The first step names the starting entity; later steps show the path so far.
    """
    if reasoning_path:
        context = (
            "To proceed, you must identify the most relevant reasoning path based on "
            f"the current reasoning steps: {process_str(reasoning_path)}"
        )
    else:
        context = f"To proceed, the starting entity is {starting_entity}."
    return (
        f"User query: {question}\n\n"
        f"{context}\n\n"
        "Please review the following options and select the most appropriate reasoning "
        "path for the query, also including the corresponding entity where applicable:\n\n"
        f"{options_str}\n\n"
        "After evaluating the options, please provide only the index of the selected "
        "reasoning path. If the final entity from the current reasoning steps directly "
        "answers the query, respond with option 0: EOS, End of Selection."
    )


def _parse_choice(response, n_candidates):
    """Translate the model's reply into a 0-based candidate index.

    Returns None when the model chose EOS (the word "EOS" or option 0).
    Raises ValueError when the reply contains no option number, or a number
    outside 1..n_candidates.
    """
    if "EOS" in response:
        return None
    match = re.search(r"\d+", response)
    if match is None:
        raise ValueError(f"no option index in response {response!r}")
    option = int(match.group())
    if option == 0:
        # The prompt asks for "only the index", so EOS often arrives as a bare
        # "0". Subtracting 1 without this check gives -1, which would silently
        # select the LAST candidate.
        return None
    if option > n_candidates:
        raise ValueError(f"option {option} is outside 1..{n_candidates}")
    return option - 1


def _walk(question, graph, starting_entity):
    """Let the model walk the graph from starting_entity, one hop per API call.

    Returns (reasoning_path, status), where status is one of "eos",
    "max_hops", "api_error" or "parse_error".
    """
    reasoning_path = []
    while len(reasoning_path) < MAX_HOPS:
        path_candidates, neighbors = get_entity_edges_with_neighbors(starting_entity, graph)
        options = [EOS_OPTION] + [
            f"{i}: {starting_entity} -> {p} -> {n}"
            for i, (p, n) in enumerate(zip(path_candidates, neighbors), start=1)
        ]
        # TODO: high-degree entities can yield prompts beyond the model's context
        # window (the API rejects them with context_length_exceeded); consider
        # capping or pre-ranking the options.
        prompt = _build_prompt(question, starting_entity, reasoning_path, "\n".join(options))

        try:
            response = query_api(prompt)["response"].strip()
        except Exception as e:  # best-effort: log, record status, continue with next question
            tqdm.write(str(e))
            tqdm.write(f"Failed to get response for query: {question}")
            return reasoning_path, "api_error"

        try:
            choice = _parse_choice(response, len(path_candidates))
        except ValueError as e:
            tqdm.write(f"Unusable response for query: {question}: {e}")
            return reasoning_path, "parse_error"

        if choice is None:
            return reasoning_path, "eos"

        neighbor = neighbors[choice]
        reasoning_path.append(f"{starting_entity} -> {path_candidates[choice]} -> {neighbor}")
        starting_entity = neighbor

    return reasoning_path, "max_hops"


for data in tqdm(dataset):
    sample_id = data["id"]  # not `id`, which would shadow the builtin
    question = data["question"]
    graph = build_graph(data["graph"])

    if data["q_entity"]:
        reasoning_path, status = _walk(question, graph, data["q_entity"][0])
    else:
        reasoning_path, status = [], "no_q_entity"

    with open(os.path.join(output_dir, f"{sample_id}.json"), "w") as f:
        # "status" records why the walk stopped, so paths truncated by an API or
        # parse failure can be told apart from genuine EOS answers.
        json.dump(
            {
                "question": question,
                "reasoning_path": process_str(reasoning_path),
                "status": status,
            },
            f,
        )

 30%|███       | 3/10 [00:21<00:45,  6.52s/it]

Error code: 400 - {'error': {'message': "This model's maximum context length is 16385 tokens. However, your messages resulted in 28012 tokens. Please reduce the length of the messages.", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}}
Failed to get response for query: what country is the grand bahama island in


 40%|████      | 4/10 [00:22<00:26,  4.40s/it]

Error code: 400 - {'error': {'message': "This model's maximum context length is 16385 tokens. However, your messages resulted in 28555 tokens. Please reduce the length of the messages.", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}}
Failed to get response for query: what kind of money to take to bahamas


100%|██████████| 10/10 [01:12<00:00,  7.29s/it]

Error code: 400 - {'error': {'message': "This model's maximum context length is 16385 tokens. However, your messages resulted in 49249 tokens. Please reduce the length of the messages.", 'type': 'invalid_request_error', 'param': 'messages', 'code': 'context_length_exceeded'}}
Failed to get response for query: which countries border the us



