# Notebook to obtain the third part of the context-enhanced datasets

In this notebook, we use ChatGPT (through the provided wrapper) to add context and relevant keywords to each datapoint of our datasets. These datasets are useful both to finetune the QA model, and the keyword retriever model. For scalability reason, the prompting has been divided in three notebooks, this is the third notebook of the three.

In [1]:
%pip install artifacts/gpt_wrapper-0.0.8-py3-none-any.whl
%pip install tiktoken
%pip install datasets
%pip install wikipedia-api

Processing ./artifacts/gpt_wrapper-0.0.8-py3-none-any.whl
gpt-wrapper is already installed with the same version as the provided wheel. Use --force-reinstall to force an installation of the wheel.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import json
import gpt_wrapper
from gpt_wrapper.chat import Chat
import wikipediaapi
import random
import time
from IPython.display import clear_output
import re

In [19]:
gpt_wrapper.api_key = "API-KEY-REMOVED-FOR-PRIVACY"

In [4]:
# Load JSON file
def load_json(file):
    with open(file, 'r') as f:
        data = json.load(f)
    return data

In [5]:
def extract_keywords_from_response(response):
    # Define regular expressions to match the desired patterns
    pattern_en = r'^EN\.\s*\[?(.+)\]?$'
    pattern_fr = r'^FR\.\s*\[?(.+)\]?$'

    # Check if the response matches the EN pattern
    match_en = re.match(pattern_en, response)
    if match_en:
        return "EN", match_en.group(1)  # Return "EN" and the extracted keyword

    # Check if the response matches the FR pattern
    match_fr = re.match(pattern_fr, response)
    if match_fr:
        return "FR", match_fr.group(1)  # Return "FR" and the extracted keyword

    # If the response doesn't match any pattern, return None
    return None, None

In [6]:
debug = False

In [7]:
def flatten_and_concatenate(nested_list):
    # If the input is a string, return it
    if isinstance(nested_list, str):
        return nested_list

    # If the input is a list, apply the function to each element and concatenate the results
    if isinstance(nested_list, list):
        return ' '.join(flatten_and_concatenate(element) for element in nested_list)

    # If the input is neither a string nor a list, return an empty string
    return ''

In [20]:
# Create an empty list to store the datapoints
QandA = []
start_index = 0
success_count = 0

#solutions = load_json('DATASETS/gen_dataset_modern_nlm_ai.json')
#solutions = solutions + load_json('DATASETS/gen_dataset_modern_nlm_class_dataset.json')
#solutions = solutions + load_json('DATASETS/gen_dataset_modern_nlm_eli5_questions_answers.json')
#solutions = solutions + load_json('DATASETS/gen_dataset_modern_nlm_hh_questions_answers.json')
#solutions = solutions + load_json('DATASETS/gen_dataset_modern_nlm_sciq.json')
#solutions = solutions + load_json('DATASETS/gen_dataset_modern_nlm_stack_exchange.json')
solutions = load_json('DATASETS/gen_dataset_modern_nlm_synthetic_questions_answers.json')
solutions = solutions + load_json('DATASETS/gen_dataset_modern_nlm_texas_SFT.json')
solutions = solutions + load_json('DATASETS/gen_modern_nlm_stack_exchange.json')

try:
    # Load the previous list of datapoints (useful to resume the process if it was interrupted)
    with open('Definitions_3.json', 'r') as f:
        QandA = json.load(f)
        success_count = len(QandA)

    # Find the highest ID among the existing datapoints
    start_index = max([int(item['ID']) for item in QandA]) + 1 if QandA else 0

    if start_index == 0:
        print("Error: No datapoints found in the existing datasets.")
    else:
        print("Starting index:", start_index)

except Exception as e:
    print(f"Failed to load previous data: {e}")
    
data_len = len(solutions)
start_time = time.time()
elapsed_times = []
token_costs = []

for index_q, datapoint in enumerate(solutions[start_index:], start=start_index):

    iteration_start_time = time.time()
    finished = False
    remaining_trials = 4

    if 'question' in datapoint:
        question = datapoint['question']
        if 'choices' in datapoint and datapoint['choices'] is not None:
            question += flatten_and_concatenate(datapoint['choices'])

        if question is not None:
            while remaining_trials > 0 and finished is False:
                remaining_trials -= 1
                chat = Chat.create("Key_topics_" + str(random.randint(0, 1000000)))
                query = "Given the following question, which concept or definition, if looked up on Wikipedia, would be most likely to help you answer it? Return a single concept, ideally corresponding to a Wikipedia page title. If the question is in English, answer in the following format: EN. [CONCEPT], if the question is in French, answer in the following format: FR. [CONCEPT]\n\n" + question
                A = chat.ask(content=query)
                used_after = Chat.budget()['usage']
                language, keyword = extract_keywords_from_response(A.content)
                if keyword is not None:
                    if debug:
                        print("Keyword: ", keyword)
                else:
                    if debug:
                        print(f"Response was not in the expected format")

                if keyword is not None:
                    # Create a Wikipedia API client
                    wiki_wiki = wikipediaapi.Wikipedia(language)

                    # Retrieve and print the main definition for each keyword
                    page = wiki_wiki.page(keyword)
                    if page.exists():
                        if "may refer to" in page.text or "plusieurs concepts" in page.text or "dans les articles suivants" in page.text or "Suivant le contexte, le terme" in page.text:
                            if debug:
                                print(f"Skipping disambiguation page for '{keyword}'")
                        else:
                            if debug:
                                print(f"Main definition for '{keyword}':")
                                print(page.summary)

                            datapoint = {
                                "Language": language,
                                "Question": question,
                                "Keyword": keyword,
                                "Wiki_Summary": page.summary,
                                "ID": index_q
                            }
                            if debug:
                                print(datapoint)

                            QandA.append(datapoint)
                            finished = True
                            success_count += 1
                    else:
                        if debug:
                            print(f"No webpage found for '{keyword}'")

    # Save the list of datapoints as a JSON file
    with open('Definitions_3.json', 'w') as f:
        json.dump(QandA, f)

    iteration_end_time = time.time()

    elapsed_time = iteration_end_time - iteration_start_time
    elapsed_times.append(elapsed_time)

    average_time_per_datapoint = sum(elapsed_times) / len(elapsed_times)
    remaining_datapoints = data_len - index_q
    estimated_time_remaining = remaining_datapoints * average_time_per_datapoint

    clear_output()
    print("########################################################")
    print("Processed datapoint", index_q, "of", (data_len - 1), "(", round(index_q/(data_len-1)*100, 2), "%)")
    print("Success count: ", success_count)
    print("########################################################")

    print("ETA (minutes): ", round(estimated_time_remaining/60, 2))
    print("ETA (hours): ", round(estimated_time_remaining/3600, 2))
    tokens_remaining = Chat.budget()['limit'] - used_after
    print("tokens remaining: ", tokens_remaining)

########################################################
Processed datapoint 31932 of 32110 ( 99.45 %)
Success count:  25747
########################################################
ETA (minutes):  7.71
ETA (hours):  0.13
tokens remaining:  -697


APIException: Team has reached its budget