# Test OpenAI function call

In [9]:
import os
import openai
import re
from dotenv import load_dotenv
import json

load_dotenv()

openai.api_key = os.getenv("OPENAI_API_KEY")

COMPLETION_MODEL = "gpt-3.5-turbo-0613"


QUESTION = (
    "What's the result of 22 plus 5 in decimal added to the hexadecimal number A?"
)
messages = [
    {"role": "user", "content": QUESTION},
]


def add_decimal_values(arguments):
    value1 = int(re.search(r'"value1": (\d+)', str(arguments)).group(1))
    value2 = int(re.search(r'"value2": (\d+)', str(arguments)).group(1))

    result = value1 + value2
    print(f"{value1} + {value2} = {result} (decimal)")

    return value1 + value2


def add_hexadecimal_values(arguments):
    value1 = re.search(r'"value1": "(\w+)"', str(arguments)).group(1)
    value2 = re.search(r'"value2": "(\w+)"', str(arguments)).group(1)

    decimal1 = int(value1, 16)
    decimal2 = int(value2, 16)

    result = hex(decimal1 + decimal2)[2:]
    print(f"{value1} + {value2} = {result} (hex)")
    return result


def get_completion(messages):
    response = openai.ChatCompletion.create(
        model=COMPLETION_MODEL,
        messages=messages,
        functions=[
            {
                "name": "add_decimal_values",
                "description": "Add two decimal values",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "value1": {
                            "type": "integer",
                            "description": "The first decimal value to add. For example, 5",
                        },
                        "value2": {
                            "type": "integer",
                            "description": "The second decimal value to add. For example, 10",
                        },
                    },
                    "required": ["value1", "value2"],
                },
            },
            {
                "name": "add_hexadecimal_values",
                "description": "Add two hexadecimal values",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "value1": {
                            "type": "string",
                            "description": "The first hexadecimal value to add. For example, 5",
                        },
                        "value2": {
                            "type": "string",
                            "description": "The second hexadecimal value to add. For example, A",
                        },
                    },
                    "required": ["value1", "value2"],
                },
            },
        ],
        temperature=0,
    )

    return response


while True:
    print(json.dumps(messages, indent=4))
    print("-" * 50)
    response = get_completion(messages)
    print(response)
    print("-" * 50)

    if response.choices[0]["finish_reason"] == "stop":
        print(response.choices[0]["message"]["content"])
        break

    elif response.choices[0]["finish_reason"] == "function_call":
        fn_name = response.choices[0].message["function_call"].name
        arguments = response.choices[0].message["function_call"].arguments

        function = locals()[fn_name]
        result = function(arguments)
        print("-" * 100)

        messages.append(
            {
                "role": "assistant",
                "content": None,
                "function_call": {
                    "name": fn_name,
                    "arguments": arguments,
                },
            }
        )

        messages.append(
            {
                "role": "function", 
                "name": fn_name, 
                "content": f'{{"result": {str(result)} }}'}
        )


[
    {
        "role": "user",
        "content": "What's the result of 22 plus 5 in decimal added to the hexadecimal number A?"
    }
]
--------------------------------------------------
{
  "id": "chatcmpl-7UW4JF5YKVVZDqjbGO12O9Ct8LUH6",
  "object": "chat.completion",
  "created": 1687507467,
  "model": "gpt-3.5-turbo-0613",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": null,
        "function_call": {
          "name": "add_decimal_values",
          "arguments": "{\n  \"value1\": 22,\n  \"value2\": 5\n}"
        }
      },
      "finish_reason": "function_call"
    }
  ],
  "usage": {
    "prompt_tokens": 154,
    "completion_tokens": 25,
    "total_tokens": 179
  }
}
--------------------------------------------------
22 + 5 = 27 (decimal)
----------------------------------------------------------------------------------------------------
[
    {
        "role": "user",
        "content": "What's the result of 22 plus 5 i

# Test Pydantic + Function Call

In [12]:
import openai
import os
from openai_function_call import openai_function
from dotenv import load_dotenv

load_dotenv()

openai.api_key = os.getenv("OPENAI_API_KEY")

@openai_function
def sum(a:int, b:int) -> int:
    """Sum description adds a + b"""
    return a + b

completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo-0613",
        temperature=0,
        functions=[sum.openai_schema],
        messages=[
            {
                "role": "system",
                "content": "You must use the `sum` function instead of adding yourself.",
            },
            {
                "role": "user",
                "content": "What is 6+3 use the `sum` function",
            },
        ],
    )

result = sum.from_response(completion)
print(result)  # 9

9


In [75]:
import openai
import os
from pydantic import Field, BaseModel
from openai_function_call import OpenAISchema
from dotenv import load_dotenv
import json

load_dotenv()
os.environ["http_proxy"] = "http://10.10.10.10:17890"
os.environ["https_proxy"] = "http://10.10.10.10:17890"

openai.api_key = os.getenv("OPENAI_API_KEY")

class UserDetails(OpenAISchema):
    """User Details"""
    name: str = Field(..., description="User's name")
    age: int = Field(..., description="User's age")

completion = openai.ChatCompletion.create(
    model="gpt-3.5-turbo-0613",
    functions=[UserDetails.openai_schema],
    messages=[
        {
            "role": "system", 
            "content": "I'm going to ask for user details. Use UserDetails to parse this data."
        },
        {
            "role": "user", 
            "content": "My name is John Doe and I'm 30 years old."
        },
    ],
)

user_details = UserDetails.from_response(completion)
print(user_details)  # UserDetails(name="John Doe", age=30)
print(dict(user_details))
print(json.dumps(dict(user_details), indent=4))
print("-" * 50)
print(UserDetails.schema_json(indent=4))

name='John Doe' age=30
{'name': 'John Doe', 'age': 30}
{
    "name": "John Doe",
    "age": 30
}
--------------------------------------------------
{
    "title": "UserDetails",
    "description": "User Details",
    "type": "object",
    "properties": {
        "name": {
            "description": "User's name",
            "type": "string"
        },
        "age": {
            "description": "User's age",
            "type": "integer"
        }
    },
    "required": [
        "name",
        "age"
    ]
}


# Test Code_String Generation

In [77]:
field_list = ["section name", "location of the samples and sections", "GPS location", 
              "associated fossils", "lithology", "number of species and genera found"]
attribute_set_string = ""
for idx, field in enumerate(field_list):
        new_attribute = f"""{field.replace(" ", "_")}: str = Field(..., description="{field}")"""
        attribute_set_string = attribute_set_string + "\n" + new_attribute
print(attribute_set_string)
print("-" * 100)

# name: str = Field(..., description="User's name")
# age: int = Field(..., description="User's age")

attr1 = "name"
attr2 = "age"
code_string = f"""{attr1}: str = Field(..., description="User's name")
{attr2}: int = Field(..., description="User's age")
"""
print(code_string)


section_name: str = Field(..., description="section name")
location_of_the_samples_and_sections: str = Field(..., description="location of the samples and sections")
GPS_location: str = Field(..., description="GPS location")
associated_fossils: str = Field(..., description="associated fossils")
lithology: str = Field(..., description="lithology")
number_of_species_and_genera_found: str = Field(..., description="number of species and genera found")
----------------------------------------------------------------------------------------------------
name: str = Field(..., description="User's name")
age: int = Field(..., description="User's age")



# Test Extraction w/o Function Call

In [3]:
""" LLMs for DeepShovel: 结构化数据抽取 """
import os
import openai
import textract
import tiktoken
from dotenv import load_dotenv
import os
import regex as re
from tqdm import tqdm
import pandas as pd
import csv
import random
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from PyPDF2 import PdfReader
import ast
import json


# Split a text into smaller chunks of size n, preferably ending at the end of a sentence
def create_chunks(text, n, tokenizer):
    tokens = tokenizer.encode(text)
    """Yield successive n-sized chunks from text."""
    i = 0
    while i < len(tokens):
        # Find the nearest end of sentence within a range of 0.5 * n and 1.5 * n tokens
        j = min(i + int(1.5 * n), len(tokens))
        while j > i + int(0.5 * n):
            # Decode the tokens and check for full stop or newline
            chunk = tokenizer.decode(tokens[i:j])
            if chunk.endswith(".") or chunk.endswith("\n"):
                break
            j -= 1
        # If no end of sentence found, use n tokens as the chunk size
        if j == i + int(0.5 * n):
            j = min(i + n, len(tokens))
        yield tokens[i:j]
        i = j


# 使用gpt-3.5-turbo抽取数据，加入异常处理机制
def extract_chunk(document, template_prompt):
    for i in range(3):  # Retry the API call up to 3 times
        try:
            prompt=template_prompt.replace('<document>', document)
            response = openai.ChatCompletion.create(
                model='gpt-3.5-turbo', 
                messages=[
                    {"role": "user", "content": prompt}
                ],
                temperature=0,
                max_tokens=1500,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0
            )
            return "1. " + response['choices'][0]['message']['content']
        except openai.error.RateLimitError:  # If rate limit is exceeded
            wait_time = (2 ** i) + random.random()  # Exponential backoff with jitter
            logging.warning(f"Rate limit exceeded. Retrying after {wait_time} seconds.")
            time.sleep(wait_time)  # Wait before retrying
        except Exception as e:  # If any other error occurs
            logging.error(f"API call failed: {str(e)}")
            return None  # Return None for failure
    logging.error("Failed to call OpenAI API after multiple retries due to rate limiting.")
    return None  # Return None for failure


# 调用API并使用重试机制处理rate limit error和其他异常
def get_completion(prompt, model="gpt-3.5-turbo"):
    for i in range(3):  # Retry the API call up to 3 times
        try:
            messages = [{"role": "user", "content": prompt}]
            response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                temperature=0,
            )
            return response.choices[0].message["content"]
        except openai.error.RateLimitError:  # If rate limit is exceeded
            wait_time = (2 ** i) + random.random()  # Exponential backoff with jitter
            logging.warning(f"Rate limit exceeded. Retrying after {wait_time} seconds.")
            time.sleep(wait_time)  # Wait before retrying
        except Exception as e:  # If any other error occurs
            logging.error(f"API call failed: {str(e)}")
            return None  # Return None for failure
    logging.error("Failed to call OpenAI API after multiple retries due to rate limiting.")
    return None  # Return None for failure


# 处理过程信息写入log文件
def log_to_file(log_file, message):
    try:
        with open(log_file, 'a') as file:
            file.write(message + '\n')
    except Exception as e:
        logging.error(f'Failed to log to file {log_file}: {str(e)}')
        raise


# 使用PdfReader读取pdf文献，手动加入Page Number信息
def read_pdf(filepath):
    """Takes a filepath to a PDF and returns a string of the PDF's contents"""
    # creating a pdf reader object
    reader = PdfReader(filepath)
    pdf_text = ""
    page_number = 0
    for page in reader.pages:
        page_number += 1
        pdf_text += page.extract_text() + f"\nPage Number: {page_number}"
    return pdf_text


# 传入pdf路径和带抽取的属性列表，返回抽取的结构化数据
def data_extraction(pdf_path, field_list):
    # 1. 执行pdf解析和切片
    pdf_text = read_pdf(pdf_path)
    clean_text = pdf_text.replace("  ", " ").replace("\n", "; ").replace(';',' ')
    tokenizer = tiktoken.get_encoding("cl100k_base")
    chunks = create_chunks(clean_text, 1000, tokenizer)
    text_chunks = [tokenizer.decode(chunk) for chunk in chunks]

    # 2. 关键信息抽取：多线程对text_chunks处理，抽取关键信息
    question_format = "0. What is the value of the 'title' attribute"
    # 适应性地生成抽取问题，并集成到抽取提示extract_prompt中去
    for idx, field in enumerate(field_list):
        new_question = str(idx+1) + ". What is the value of the '" + field + "' attribute"
        question_format = question_format + "\n" + new_question
    document = '<document>'
    # 关键信息抽取prompt
    extract_prompt=f'''Extract key pieces of information from this regulation document.
If a particular piece of information is not present, output \"Not specified\".
When you extract a key piece of information, include the closest page number.
---
Use the following format:
{question_format}
---
Document: \"\"\"{document}\"\"\"\n
0. What is the value of the 'title' attribute: Origin of Lower Carboniferous cherts in southern Guizhou, South China (Page 1)
1.'''
    # 多线程对text_chunks处理，抽取关键信息
    results = []
    log_file = 'log_geo_extract.txt'
    # log_to_file(log_file, f'Number of chunks: {len(text_chunks)}')
    with ThreadPoolExecutor() as executor:
        # 多线程处理
        futures = {executor.submit(extract_chunk, chunk, extract_prompt): chunk for chunk in text_chunks}
        for future in tqdm(as_completed(futures), total=len(futures), desc='Processing chunks'):
            # 收集完成的线程处理好的结果
            response = future.result()
            if response is None:
                # log_to_file(log_file, f'Failed to process chunk {futures[future]}')
                pass
            else:
                # 汇总关键信息抽取的结果
                results.append(response)
                # log_to_file(log_file, f'Successfully processed chunk!')
    # 进一步整理关键信息抽取结果，便于下一步格式化转换
    groups = [r.split('\n') for r in results]
    groups = [y for x in groups for y in x]
    groups = sorted(groups)
    groups = [x for x in groups if "Not specified" not in x and "__" not in x]
    zipped = groups
    # 移除太长的结果 (保留len(r) <= 180)
    zipped = [r for r in zipped if len(r) <= 180]

    # 3. 数据格式转换：根据抽取的关键信息，转换生成JSON样式的结果
    zipped_example = ["1. What is the value of the 'section name' attribute: The end-Triassic extinction event (ETE) (Page 1)", "1. What is the value of the 'section name' attribute: Katsuyama section (Page 2)", "1. What is the value of the 'section name' attribute: The Inuyama area (Page 1)", "2. What is the value of the 'location of the samples and sections' attribute: Katsuyama section, Inuyama, Japan (Page 1)", "2. What is the value of the 'location of the samples and sections' attribute: Inuyama area, central Japan (Page 2)", "2. What is the value of the 'location of the samples and sections' attribute: Rock samples from TJ-3 to TJ + 4 (3 beds above TJ + 1) continuously (Page 2)", "3. What is the value of the 'GPS location' attribute: N 35◦25.367′, E 136◦58.261 (Page 2)", "4. What is the value of the 'associated fossils' attribute: Sea surface-dwelling radiolaria (Page 1)", "4. What is the value of the 'associated fossils' attribute: Radiolarian fossils (Page 1)", "4. What is the value of the 'associated fossils' attribute: Radiolarian fossils (Page 3)", "5. What is the value of the 'lithology' attribute: Bedded chert (Page 1)", "5. What is the value of the 'lithology' attribute: Bedded chert and siliciclastic rocks (Page 2)", "5. What is the value of the 'lithology' attribute: Siliceous mudstone, bedded chert sequence, and siliciclastic rocks (Page 1)"]
    zipped_str_example = str(zipped_example)[1:][:-1]
    field_list_example = ["section name", "location of the samples and sections", "GPS location", 
                          "associated fossils", "lithology", "number of species and genera found"]
    zipped_str = str(zipped)[1:][:-1]
    # 数据格式转换prompt
    transform_prompt = f'''You will read a paragraph, summarise it in JSON format according to keywords and remove duplicate values.
---
Here is an example: 

PARAGRAPH
{zipped_str_example}
KEYWORDS
{field_list_example}
OUTPUT
{{
    "section name": [
        "The end-Triassic extinction event (ETE)",
        "Katsuyama section",
        "The Inuyama area"
    ],
    "location of the samples and sections": [
        "Katsuyama section, Inuyama, Japan",
        "Inuyama area, central Japan",
        "Rock samples from TJ-3 to TJ + 4 (3 beds above TJ + 1) continuously"
    ],
    "GPS location": [
        "N 35◦25.367′, E 136◦58.261"
    ],
    "associated fossils": [
        "Sea surface-dwelling radiolaria",
        "Radiolarian fossils"
    ],
    "lithology": [
        "Bedded chert",
        "Bedded chert and siliciclastic rocks",
        "Siliceous mudstone, bedded chert sequence, and siliciclastic rocks"
    ],
    "number of species and genera found": []
}}
---
Here is the paragragh you need to process, summarise it in JSON format according to keywords and remove duplicate values: 

PARAGRAPH
{zipped_str}
KEYWORDS
{field_list}
OUTPUT

'''

    response = get_completion(transform_prompt)
    res_json = ast.literal_eval(response)

    return res_json


if __name__ == '__main__':
    # 环境初始化，用户上传OpenAI API key
    load_dotenv()
    os.environ["http_proxy"] = "http://10.10.1.3:10000"
    os.environ["https_proxy"] = "http://10.10.1.3:10000"
    # Load your API key from an environment variable or secret management service
    openai.api_key = os.getenv("OPENAI_API_KEY")
    os.environ['OPENAI_API_KEY'] = openai.api_key

    # 用户上传pdf，输入需要抽取的属性列表
    pdf_path = "data/radiolarian/000.pdf"
    field_list = ["section name", "location of the samples and sections", "GPS location", 
                  "associated fossils", "lithology", "number of species and genera found"]
    
    # LLMs结构化数据抽取
    res_json = data_extraction(pdf_path, field_list)
    with open('results/result.json', 'w', newline='\n') as file:
        json.dump(res_json, file, indent=4)

    print(json.dumps(res_json, indent=4))


Processing chunks: 100%|██████████| 16/16 [00:08<00:00,  1.81it/s]


{
    "section name": [
        "Abstract",
        "Kunga Island section"
    ],
    "location of the samples and sections": [
        "Kunga Island, Queen Charlotte Islands (QCI), and Inuyama",
        "Queen Charlotte Islands, B.C. (Canada) and Inuyama",
        "Queen Charlotte Islands, Canada",
        "Queen Charlotte Islands, Canada; Inuyama, Japan",
        "Queen Charlotte Islands, Canada; Inuyama, Kuzuu and Ikuno, Japan; New Zealand; Montenegro"
    ],
    "GPS location": [],
    "associated fossils": [],
    "lithology": [],
    "number of species and genera found": [
        "Nearly 20 genera and over 130 Rhaetian species disappeared at the end of the Triassic"
    ]
}


# Test Extraction w/ Function Call

In [4]:
""" LLMs for DeepShovel: 结构化数据抽取 """
import os
import openai
import textract
import tiktoken
from dotenv import load_dotenv
import os
import regex as re
from tqdm import tqdm
import pandas as pd
import csv
import random
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from PyPDF2 import PdfReader
import ast
import json
from pydantic import Field, BaseModel
from openai_function_call import OpenAISchema


# Split a text into smaller chunks of size n, preferably ending at the end of a sentence
def create_chunks(text, n, tokenizer):
    tokens = tokenizer.encode(text)
    """Yield successive n-sized chunks from text."""
    i = 0
    while i < len(tokens):
        # Find the nearest end of sentence within a range of 0.5 * n and 1.5 * n tokens
        j = min(i + int(1.5 * n), len(tokens))
        while j > i + int(0.5 * n):
            # Decode the tokens and check for full stop or newline
            chunk = tokenizer.decode(tokens[i:j])
            if chunk.endswith(".") or chunk.endswith("\n"):
                break
            j -= 1
        # If no end of sentence found, use n tokens as the chunk size
        if j == i + int(0.5 * n):
            j = min(i + n, len(tokens))
        yield tokens[i:j]
        i = j


# 使用gpt-3.5-turbo抽取数据，加入异常处理机制
def extract_chunk(document, template_prompt):
    for i in range(3):  # Retry the API call up to 3 times
        try:
            prompt=template_prompt.replace('<document>', document)
            response = openai.ChatCompletion.create(
                model='gpt-3.5-turbo', 
                messages=[
                    {"role": "user", "content": prompt}
                ],
                temperature=0,
                max_tokens=1500,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0
            )
            return "1. " + response['choices'][0]['message']['content']
        except openai.error.RateLimitError:  # If rate limit is exceeded
            wait_time = (2 ** i) + random.random()  # Exponential backoff with jitter
            logging.warning(f"Rate limit exceeded. Retrying after {wait_time} seconds.")
            time.sleep(wait_time)  # Wait before retrying
        except Exception as e:  # If any other error occurs
            logging.error(f"API call failed: {str(e)}")
            return None  # Return None for failure
    logging.error("Failed to call OpenAI API after multiple retries due to rate limiting.")
    return None  # Return None for failure


# 调用API并使用重试机制处理rate limit error和其他异常
def get_completion(prompt, model="gpt-3.5-turbo"):
    for i in range(3):  # Retry the API call up to 3 times
        try:
            messages = [{"role": "user", "content": prompt}]
            response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                temperature=0,
            )
            return response.choices[0].message["content"]
        except openai.error.RateLimitError:  # If rate limit is exceeded
            wait_time = (2 ** i) + random.random()  # Exponential backoff with jitter
            logging.warning(f"Rate limit exceeded. Retrying after {wait_time} seconds.")
            time.sleep(wait_time)  # Wait before retrying
        except Exception as e:  # If any other error occurs
            logging.error(f"API call failed: {str(e)}")
            return None  # Return None for failure
    logging.error("Failed to call OpenAI API after multiple retries due to rate limiting.")
    return None  # Return None for failure

# 调用API并使用重试机制处理rate limit error和其他异常，调用function call功能
def get_completion_function_call(prompt, attribute_set_string):
    class AttributeDict(OpenAISchema):
        """Attributes of user input"""
        exec(attribute_set_string)

    for i in range(3):  # Retry the API call up to 3 times
        try:
            messages = [{"role": "system", "content": "Use AttributeDict to parse this data."}, 
                        {"role": "user", "content": prompt}]
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo-0613",
                functions = [AttributeDict.openai_schema],
                messages=messages,
                temperature=0,
            )
            return response
        except openai.error.RateLimitError:  # If rate limit is exceeded
            wait_time = (2 ** i) + random.random()  # Exponential backoff with jitter
            logging.warning(f"Rate limit exceeded. Retrying after {wait_time} seconds.")
            time.sleep(wait_time)  # Wait before retrying
        except Exception as e:  # If any other error occurs
            logging.error(f"API call failed: {str(e)}")
            return None  # Return None for failure
    logging.error("Failed to call OpenAI API after multiple retries due to rate limiting.")
    return None  # Return None for failure


# 处理过程信息写入log文件
def log_to_file(log_file, message):
    try:
        with open(log_file, 'a') as file:
            file.write(message + '\n')
    except Exception as e:
        logging.error(f'Failed to log to file {log_file}: {str(e)}')
        raise


# 使用PdfReader读取pdf文献，手动加入Page Number信息
def read_pdf(filepath):
    """Takes a filepath to a PDF and returns a string of the PDF's contents"""
    # creating a pdf reader object
    reader = PdfReader(filepath)
    pdf_text = ""
    page_number = 0
    for page in reader.pages:
        page_number += 1
        pdf_text += page.extract_text() + f"\nPage Number: {page_number}"
    return pdf_text


# 传入pdf路径和带抽取的属性列表，返回抽取的结构化数据
def data_extraction(pdf_path, field_list):
    # 1. 执行pdf解析和切片
    pdf_text = read_pdf(pdf_path)
    clean_text = pdf_text.replace("  ", " ").replace("\n", "; ").replace(';',' ')
    tokenizer = tiktoken.get_encoding("cl100k_base")
    chunks = create_chunks(clean_text, 1000, tokenizer)
    text_chunks = [tokenizer.decode(chunk) for chunk in chunks]

    # 2. 关键信息抽取：多线程对text_chunks处理，抽取关键信息
    question_format = "0. What is the value of the 'title' attribute"
    # 适应性地生成抽取问题，并集成到抽取提示extract_prompt中去
    for idx, field in enumerate(field_list):
        new_question = str(idx+1) + ". What is the value of the '" + field + "' attribute"
        question_format = question_format + "\n" + new_question
    document = '<document>'
    # 关键信息抽取prompt
    extract_prompt=f'''Extract key pieces of information from this regulation document.
If a particular piece of information is not present, output \"Not specified\".
When you extract a key piece of information, include the closest page number.
---
Use the following format:
{question_format}
---
Document: \"\"\"{document}\"\"\"\n
0. What is the value of the 'title' attribute: Origin of Lower Carboniferous cherts in southern Guizhou, South China (Page 1)
1.'''
    # 多线程对text_chunks处理，抽取关键信息
    results = []
    log_file = 'log_geo_extract.txt'
    # log_to_file(log_file, f'Number of chunks: {len(text_chunks)}')
    with ThreadPoolExecutor() as executor:
        # 多线程处理
        futures = {executor.submit(extract_chunk, chunk, extract_prompt): chunk for chunk in text_chunks}
        for future in tqdm(as_completed(futures), total=len(futures), desc='Processing chunks'):
            # 收集完成的线程处理好的结果
            response = future.result()
            if response is None:
                # log_to_file(log_file, f'Failed to process chunk {futures[future]}')
                pass
            else:
                # 汇总关键信息抽取的结果
                results.append(response)
                # log_to_file(log_file, f'Successfully processed chunk!')
    # 进一步整理关键信息抽取结果，便于下一步格式化转换
    groups = [r.split('\n') for r in results]
    groups = [y for x in groups for y in x]
    groups = sorted(groups)
    groups = [x for x in groups if "Not specified" not in x and "__" not in x]
    zipped = groups
    # 移除太长的结果 (保留len(r) <= 180)
    zipped = [r for r in zipped if len(r) <= 180]

    # 3. 数据格式转换：根据抽取的关键信息，转换生成JSON样式的结果
    zipped_str = str(zipped)[1:][:-1]
    # 数据格式转换prompt
    transform_prompt = f'''I'm going to ask for attributes. Use AttributeDict to parse this data."
---
PARAGRAPH
{zipped_str}
'''

    # 适应性地生成attribute_set_string, 在AttributeDict类中exec()生成可执行代码
    attribute_set_string = ""
    for idx, field in enumerate(field_list):
        new_attribute = f"""{field.replace(" ", "_")}: str = Field(..., description="{field}")"""
        attribute_set_string = attribute_set_string + "\n" + new_attribute

    class AttributeDict(OpenAISchema):
        """Attributes of user input"""
        exec(attribute_set_string)

    response = get_completion_function_call(transform_prompt, attribute_set_string)
    arrtibutes = AttributeDict.from_response(response)
    res_json = dict(arrtibutes)

    return res_json


if __name__ == '__main__':
    # 环境初始化，用户上传OpenAI API key
    load_dotenv()
    os.environ["http_proxy"] = "http://10.10.1.3:10000"
    os.environ["https_proxy"] = "http://10.10.1.3:10000"
    # os.environ["http_proxy"] = "http://10.10.10.10:17890"
    # os.environ["https_proxy"] = "http://10.10.10.10:17890"
    # Load your API key from an environment variable or secret management service
    openai.api_key = os.getenv("OPENAI_API_KEY")
    os.environ['OPENAI_API_KEY'] = openai.api_key

    # 用户上传pdf，输入需要抽取的属性列表
    pdf_path = "data/radiolarian/466.pdf"
    field_list = ["section name", "location of the samples and sections", "GPS location", 
                  "associated fossils", "lithology", "number of species and genera found"]

    # LLMs结构化数据抽取
    res_json = data_extraction(pdf_path, field_list)
    with open('results/result.json', 'w', newline='\n') as file:
        json.dump(res_json, file, indent=4)

    print(json.dumps(res_json, indent=4))


Processing chunks: 100%|██████████| 12/12 [00:16<00:00,  1.35s/it]


{
    "section_name": "5.3 Buryella tetradica -Bekoma campechensis interval zone (Page 7)",
    "location_of_the_samples_and_sections": "Greater Indian passive continental margin (Page 21966)",
    "GPS_location": "",
    "associated_fossils": "Bekoma campechensis, Buryella tetradica, B. pentadica, Clathrocycloma (?) parcum, C. aff. catherinea (Page 7)",
    "lithology": "burgundy and gray laminated siliceous shale and siliceous rocks (Page 5)",
    "number_of_species_and_genera_found": "54 species of 30 radiolarian genera (Page 5)"
}


# Test sentence similarity prompt

In [6]:
import openai
from dotenv import load_dotenv
import random
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from PyPDF2 import PdfReader


# 调用API并使用重试机制处理rate limit error和其他异常
def get_completion(prompt, model="gpt-3.5-turbo"):
    for i in range(3):  # Retry the API call up to 3 times
        try:
            messages = [{"role": "user", "content": prompt}]
            response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                temperature=0,
            )
            return response.choices[0].message["content"]
        except openai.error.RateLimitError:  # If rate limit is exceeded
            wait_time = (2 ** i) + random.random()  # Exponential backoff with jitter
            logging.warning(f"Rate limit exceeded. Retrying after {wait_time} seconds.")
            time.sleep(wait_time)  # Wait before retrying
        except Exception as e:  # If any other error occurs
            logging.error(f"API call failed: {str(e)}")
            return None  # Return None for failure
    logging.error("Failed to call OpenAI API after multiple retries due to rate limiting.")
    return None  # Return None for failure


sentence1 = "The quick brown fox jumps over the lazy dog."
sentence2 = "The quick yellow fox jumps over the lazy cat."
prompt = f'''Please decide whether the following two sentences are semantically related.
Your output can only be relevant or irrelevant.
---
SENTENCE 1: {sentence1}
SENTENCE 2: {sentence2}
OUTPUT
'''

response = get_completion(prompt)
print(response)

Relevant


# Test revChatGPT.V3 chatbot

In [20]:
import os
import openai
from dotenv import load_dotenv
from revChatGPT.V3 import Chatbot


load_dotenv()
os.environ["http_proxy"] = "http://10.10.1.3:10000"
os.environ["https_proxy"] = "http://10.10.1.3:10000"
openai.api_key = os.getenv("OPENAI_API_KEY")
os.environ['OPENAI_API_KEY'] = openai.api_key

# basic example
# chatbot = Chatbot(api_key=openai.api_key)
# chatbot.ask("Hello world")

# streaming example
chatbot = Chatbot(api_key=openai.api_key)
for data in chatbot.ask_stream("Hello world"):
    print(data, end="", flush=True)

Hello! How can I assist you today?

# Test single paper extraction

In [1]:
# -*- coding: utf-8 -*-
import os
import sys
import re
from PyPDF2 import PdfReader


def read_pdf(filepath):
    """Takes a filepath to a PDF and returns a string of the PDF's contents"""
    # creating a pdf reader object
    reader = PdfReader(filepath)
    pdf_text = ""
    page_number = 0
    for page in reader.pages:
        page_number += 1
        pdf_text += page.extract_text() + f"\nPage Number: {page_number}"
    return pdf_text

pdf_text = read_pdf("data/Favorable Effects of Tacrolimus Monotherapy on Myasthenia Gravis Patients.pdf")
str_tmp = pdf_text.replace("  ", " ").replace("\n", "; ").replace(';',' ')

sentences_list = []
paper_idx = 0
paper_cnt = 0
sentence_cnt = 0
sentence_ls = str_tmp.split('. ')
for sentence in sentence_ls:
    # 正则过滤字符串，只保留空格、数字、字母、百分数、-、<、>、/、小数点等
    sentence = re.sub(r'[^ \dA-Za-z%<>/\.-]', '', sentence)
    # 如果句子太长，切分处理
    if len(sentence) > 500:
        idx = 0
        while idx < len(sentence):
            tmp_str = sentence[idx: min(idx + 500, len(sentence))] + " end"
            if len(tmp_str) <= 8:
                tmp_str += ' end'
            if tmp_str[-1] != ".":
                tmp_str += '.'
            sentences_list.append(tmp_str)
            sentence_cnt += 1
            idx += 500
    else:
        if len(sentence) <= 8:
            sentence += ' end'
        if sentence[-1] != ".":
            sentence += '.'
        sentences_list.append(sentence)
        sentence_cnt += 1
print("# total paper count: {},  total sentence count: {}".format(paper_cnt, sentence_cnt))

write_file = "../../engineering/mgkg/sentences/sentence_display.txt"
with open(write_file, 'w') as f:
    for item in sentences_list:
        f.write("%s\n" % item)


# extract sentences in title, abstract, and sections from retrieved data indexed between [left_index, right_index)
def extract_sentences():
    sentences_list = []
    paper_idx = 0
    paper_cnt = 0
    sentence_cnt = 0
    for paper in data:
        paper_idx += 1
        paper_cnt += 1
        str_tmp = ""
        if 'title' in paper.keys():
            str_tmp += paper['title'] + "."
        if 'abstractText' in paper.keys():
            str_tmp += " " + paper['abstractText']
        # append title+abstract to sentences_list
        if str_tmp != "":
            # split sentence by ". "
            sentence_ls = str_tmp.split('. ')
            for sentence in sentence_ls:
                # 正则过滤字符串，只保留空格、数字、字母、百分数、-、<、>、/、小数点等
                sentence = re.sub(r'[^ \dA-Za-z%<>/\.-]', '', sentence)
                # 如果句子太长，切分处理
                if len(sentence) > 500:
                    idx = 0
                    while idx < len(sentence):
                        tmp_str = sentence[idx: min(idx + 500, len(sentence))] + " end"
                        if len(tmp_str) <= 8:
                            tmp_str += ' end'
                        if tmp_str[-1] != ".":
                            tmp_str += '.'
                        sentences_list.append(tmp_str)
                        sentence_cnt += 1
                        idx += 500
                else:
                    if len(sentence) <= 8:
                        sentence += ' end'
                    if sentence[-1] != ".":
                        sentence += '.'
                    sentences_list.append(sentence)
                    sentence_cnt += 1
        if 'sections' in paper.keys():
            for section in paper['sections']:
                if "heading" in section.keys():
                    str_tmp = section['heading'] + ". " + section['text']
                else:
                    str_tmp = section['text']
                # append heading+text to sentences_list
                if str_tmp != "":
                    # split sentence by ". "
                    sentence_ls = str_tmp.split('. ')
                    for sentence in sentence_ls:
                        # 正则过滤字符串，只保留空格、数字、字母、百分数、-、<、>、/、小数点等
                        sentence = re.sub(r'[^ \dA-Za-z%<>/\.-]', '', sentence)
                        # 如果句子太长，切分处理
                        if len(sentence) > 500:
                            idx = 0
                            while idx < len(sentence):
                                tmp_str = sentence[idx: min(idx + 500, len(sentence))] + " end"
                                if len(tmp_str) <= 8:
                                    tmp_str += ' end'
                                if tmp_str[-1] != ".":
                                    tmp_str += '.'
                                sentences_list.append(tmp_str)
                                sentence_cnt += 1
                                idx += 500
                        else:
                            if len(sentence) <= 8:
                                sentence += ' end'
                            if sentence[-1] != ".":
                                sentence += '.'
                            sentences_list.append(sentence)
                            sentence_cnt += 1

        if paper_cnt % 1000 == 0:
            print("current paper count: {},  current sentence count: {}".format(paper_cnt, sentence_cnt))

    print("# total paper count: {},  total sentence count: {}".format(paper_cnt, sentence_cnt))
    # pubmed_0508: total paper count: 310,  total sentence count: 35812
    return sentences_list, sentence_cnt
    # paper interval: [1, 500001),  total paper count: 500000,  total sentence count: 111915016


# if __name__ == '__main__':
#     extracted_data, total_data_count = extract_sentences()
#     write_file = "./sentences/sentence_pubmed_0508.txt"
#     with open(write_file, 'w') as f:
#         for item in extracted_data:
#             f.write("%s\n" % item)


# total paper count: 0,  total sentence count: 345


# Test MGKG - SingleThread

In [15]:
import os
import openai
import json
import re
import tiktoken
import logging
import time
import random
import pandas as pd
from dotenv import load_dotenv
from revChatGPT.V3 import Chatbot
from PyPDF2 import PdfReader
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed


load_dotenv()
os.environ["http_proxy"] = "http://10.10.1.3:10000"
os.environ["https_proxy"] = "http://10.10.1.3:10000"
openai.api_key = os.getenv("OPENAI_API_KEY")
os.environ['OPENAI_API_KEY'] = openai.api_key


# 调用API并使用重试机制处理rate limit error和其他异常
def get_completion(prompt, model="gpt-3.5-turbo"):
    for i in range(3):  # Retry the API call up to 3 times
        try:
            messages = [
                {"role": "system", "content": "I want you to act as a entity and relation extractor to help me build a medical knowledge graph from a paragraph."},
                {"role": "user", "content": prompt}
                ]
            response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                temperature=0,
            )
            return response.choices[0].message["content"]
        except openai.error.RateLimitError:  # If rate limit is exceeded
            wait_time = (2 ** i) + random.random()  # Exponential backoff with jitter
            logging.warning(f"Rate limit exceeded. Retrying after {wait_time} seconds.")
            time.sleep(wait_time)  # Wait before retrying
        except Exception as e:  # If any other error occurs
            logging.error(f"API call failed: {str(e)}")
            return None  # Return None for failure
    logging.error("Failed to call OpenAI API after multiple retries due to rate limiting.")
    return None  # Return None for failure


# 从schema文件中获取本体（ontology - entity_label），20个
def get_entity_labels():
    entity_labels = []

    # 读取excel工作表MGKG_Schema_2023-05-05.xlsx - ontology
    df = pd.read_excel('../mgkg/MGKG_Schema_2023-05-05.xlsx', sheet_name='ontology')
    # 按行迭代数据
    for index, row in df.iterrows():
        # 读取行中的每个单元格
        entity_label = row['schema']
        entity_labels.append(entity_label)

    return entity_labels


# 从schema文件中获取关系（relation），33个
def get_relations():
    relations = []

    # 读取excel工作表MGKG_Schema_2023-05-05.xlsx - relations
    df = pd.read_excel('../mgkg/MGKG_Schema_2023-05-05.xlsx', sheet_name='relations')
    # 按行迭代数据
    for index, row in df.iterrows():
        # 读取行中的每个单元格
        relation_name = row['schema']
        relations.append(relation_name)

    return relations


def triple_extraction(paragraph: str, entity_labels: list, schema_relations: list):
    # system_prompt = "I want you to act as a entity and relation extractor to help me build an academic knowledge graph from several paragraphs."
    # chatbot = Chatbot(api_key=openai.api_key, system_prompt=system_prompt)
    
    prompt1 = f"""
I will give you a paragraph. Extract as many named entities as possible from it. Your answer should only contain a list and nothing else. 
---
Here is an example:

paragraph: 
myasthenia gravis is characterized by muscle weakness. prednisolone is a treatment for myasthenia gravis.

your answer: 
[
    "myasthenia gravis",
    "muscle weakness",
    "prednisolone"
]
---
Here is the paragraph you should process:
{paragraph}
"""

    # entity_list = chatbot.ask(prompt1)
    entity_list = get_completion(prompt1)
    # print(entity_list)
    
    prompt2 = f"""This is the entity list you have just generated:

{entity_list}

Classify every entity in into one of the categories in the following list. You should not classify any entity into a category that in not in the following list.

{entity_labels}

Your result should be a JSON dictionary with entities being the keys and categories being the values. There should be nothing in your answer except the JSON dictionary.
---
Here is an example:

paragraph: 
myasthenia gravis is characterized by muscle weakness. prednisolone is a treatment for myasthenia gravis.

entity list:
[
    "myasthenia gravis",
    "muscle weakness",
    "prednisolone"
]
your answer:
{{
    "myasthenia gravis": "disease",
    "muscle weakness": "symptom",
    "prednisolone": "medication"
}}
"""

    # entity_category_dict = chatbot.ask(prompt2)
    entity_category_dict = get_completion(prompt2)
    # print(entity_category_dict)
    
    prompt3 = f"""
The following is the paragraph:

{paragraph}

The following is the "entity list" you have just generated:

{entity_list}

Extract as many relations as possible from the paragraph. Your result should be a list of triples and nothing else. 
The first and third element in each triple should be in the "entity list" you have generated and the second element should be in the following "relation category list". 
You should not extract any relation that the second element in it is not in the following "relation category list". 
The relation you choose should be precise and diverse. You shouldn't use "treatment" to describe all the relations.

Here is the "relation category list":
{schema_relations}

---
Here is an example:

paragraph: 
myasthenia gravis is characterized by muscle weakness. prednisolone is a treatment for myasthenia gravis.

entity list:
[
    "myasthenia gravis",
    "muscle weakness",
    "prednisolone"
]

your answer:
[
    ["myasthenia gravis", "presented with", "muscle weakness"],
    ["prednisolone", "treatment", "myasthenia gravis"],
]

"""

    # relation_list = chatbot.ask(prompt3)
    relation_list = get_completion(prompt3)
    # print(relation_list)
    
    try:
        p_entity_list = json.loads(entity_list)
        p_entity_category_dict = json.loads(entity_category_dict)
        p_relation_list = json.loads(relation_list)
        # print("# JSON load successful!")
        load_flag = True
        return {
            "entity_list": p_entity_list,
            "entity_category_dict": p_entity_category_dict,
            "relation_list": p_relation_list
        }, load_flag
    except:
        # print("# JSON load failed!")
        load_flag = False
        return {
            "entity_list": entity_list,
            "entity_category_dict": entity_category_dict,
            "relation_list": relation_list
        }, load_flag


entity_labels = get_entity_labels()
schema_relations = get_relations()
print("# entity labels:\n", entity_labels)
print("-" * 100)
print("# schema_relations:\n", schema_relations)
print("-" * 120)


"""
paragraph_dict = {
    "paragraph": paragraph,
    "relations": [
        {
            "head": head,
            "head_label": head_label,
            "relation": relation,
            "tail": tail,
            "tail_label": tail_label
        },
        {
            "head": head,
            "head_label": head_label,
            "relation": relation,
            "tail": tail,
            "tail_label": tail_label
        },
        ...
    ]
}
"""
# paragraph_file = "./data/mgkg_data/paragraph_pubmed_0720_test.txt"
# result_write_file = "./results/mgkg_result/paragraph_pubmed_0720_test_result.json"
# retry_write_file = "./data/mgkg_data/paragraph_pubmed_0720_test_retry.txt"
paragraph_file = "./data/mgkg_data/paragraph_pubmed_0720.txt"
result_write_file = "./results/mgkg_result/paragraph_pubmed_0720_result.json"
result_single_write_file = "./results/mgkg_result/paragraph_pubmed_0720_result_single.json"
retry_write_file = "./data/mgkg_data/paragraph_pubmed_0720_retry.txt"
with open(paragraph_file, "r") as file:
    paragraphs = file.readlines()

retry_paragraphs = []   # JSON load失败的paragraph
paragraph_dict_list = []    # 所有paragraph的三元组抽取结果
success_cnt = 0
fail_cnt = 0

# 将抽取结果paragraph_dict持续性写入result_single_write_file文件
with open(result_single_write_file, "w", newline='\n') as wrt_single_file:
    for paragraph in tqdm(paragraphs, total=len(paragraphs), desc='Processing paragraphs'):
        # print(paragraph.strip())  # 输出每一行内容（去除换行符）
        result, load_flag = triple_extraction(paragraph, entity_labels, schema_relations)
        if load_flag:   # JSON load成功
            success_cnt += 1
            entities = result['entity_list']
            entity_labels = result['entity_category_dict']
            relations = result['relation_list']
            relation_dict_list = []  # 每个paragraph抽取的的所有relation
            for item in relations:
                head = item[0]
                relation = item[1]
                tail = item[2]
                head_lebal = ""
                tail_label = ""
                # print(f'{head}, {relation}, {tail}')
                entity_keys = entity_labels.keys()
                for key in entity_keys:
                    if key in head:
                        head_label = entity_labels[key]
                    if key in tail:
                        tail_label = entity_labels[key]
                relation_dict = {   # paragraph中抽取的一个relation
                    "head": head,
                    "head_label": head_label,
                    "relation": relation,
                    "tail": tail,
                    "tail_label": tail_label
                }
                relation_dict_list.append(relation_dict)
            paragraph_dict = {
                "paragraph": paragraph,
                "relations": relation_dict_list
            }
            paragraph_dict_list.append(paragraph_dict)
        else:   # JSON load失败
            fail_cnt += 1
            retry_paragraphs.append(paragraph)
            paragraph_dict = {
                "paragraph": paragraph,
                "relations": "failure"
            }
            paragraph_dict_list.append(paragraph_dict)
        json.dump(paragraph_dict, wrt_single_file, indent=4)
        wrt_single_file.write('\n')  # 添加换行符
        print("# paragraph_cnt {}, success_cnt {}, fail_cnt {}".format(len(paragraphs), success_cnt, fail_cnt))


print("# paragraph_cnt:", len(paragraphs))
print("# success_cnt:", success_cnt)
print("# fail_cnt:", fail_cnt)

# 将抽取结果paragraph_dict_list一次性写入文件
with open(result_write_file, "w") as json_file:
    json.dump(paragraph_dict_list, json_file, indent=4)

# 将JSON load失败的paragraph一次性写入新文件
with open(retry_write_file, 'w') as retry_file:
    for item in retry_paragraphs:
        retry_file.write("%s\n" % item)


# entity labels:
 ['medication', 'non-medication treatment', 'clinical feature', 'disease', 'symptom', 'sign', 'subgroup', 'scale', 'ancillary test', 'comorbidity', 'clinical effect', 'adverse effect', 'sex', 'age', 'age at onset', 'history of smoking', 'history of alcohol consumption', 'level of education', 'level of income', 'latitude of residence']
----------------------------------------------------------------------------------------------------
# schema_relations:
 ['adverse effect', 'alleviate', 'be superior to', 'biologics', 'caution', 'clinical effect', 'coadministration', 'combine with', 'complication', 'contraindication', 'conventional immunosuppression', 'drug intenration', 'especially alleviate', 'first-line medication', 'incompatibility', 'indication', 'medication for precaution', 'medication for treatment', 'postoperative drug', 'precaution', 'preoperative drug', 'presented with', 'second-line medication', 'steroid sparing', 'subclass', 'supportive treatment', 'surgery',

Processing paragraphs: 100%|██████████| 4/4 [01:51<00:00, 27.85s/it]

# paragraph_cnt: 4
# success_cnt: 4
# fail_cnt: 0





# Test MGKG - MultiThread

In [None]:
import os
import openai
import json
import re
import tiktoken
import logging
import time
import random
import pandas as pd
from dotenv import load_dotenv
from revChatGPT.V3 import Chatbot
from PyPDF2 import PdfReader
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed


load_dotenv()
os.environ["http_proxy"] = "http://10.10.1.3:10001"
os.environ["https_proxy"] = "http://10.10.1.3:10001"
openai.api_key = os.getenv("OPENAI_API_KEY")
os.environ['OPENAI_API_KEY'] = openai.api_key


# 调用API并使用重试机制处理rate limit error和其他异常
def get_completion(prompt, model="gpt-3.5-turbo"):
    for i in range(3):  # Retry the API call up to 3 times
        try:
            messages = [
                {"role": "system", "content": "I want you to act as a entity and relation extractor to help me build a medical knowledge graph from a paragraph."},
                {"role": "user", "content": prompt}
                ]
            response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                temperature=0,
            )
            return response.choices[0].message["content"]
        except openai.error.RateLimitError:  # If rate limit is exceeded
            wait_time = (2 ** i) + random.random()  # Exponential backoff with jitter
            logging.warning(f"Rate limit exceeded. Retrying after {wait_time} seconds.")
            time.sleep(wait_time)  # Wait before retrying
        except Exception as e:  # If any other error occurs
            logging.error(f"API call failed: {str(e)}")
            return None  # Return None for failure
    logging.error("Failed to call OpenAI API after multiple retries due to rate limiting.")
    return None  # Return None for failure


# 从schema文件中获取本体（ontology - entity_label），20个
def get_entity_labels():
    entity_labels = []

    # 读取excel工作表MGKG_Schema_2023-05-05.xlsx - ontology
    df = pd.read_excel('../mgkg/MGKG_Schema_2023-05-05.xlsx', sheet_name='ontology')
    # 按行迭代数据
    for index, row in df.iterrows():
        # 读取行中的每个单元格
        entity_label = row['schema']
        entity_labels.append(entity_label)

    return entity_labels


# 从schema文件中获取关系（relation），33个
def get_relations():
    relations = []

    # 读取excel工作表MGKG_Schema_2023-05-05.xlsx - relations
    df = pd.read_excel('../mgkg/MGKG_Schema_2023-05-05.xlsx', sheet_name='relations')
    # 按行迭代数据
    for index, row in df.iterrows():
        # 读取行中的每个单元格
        relation_name = row['schema']
        relations.append(relation_name)

    return relations


def triple_extraction(paragraph: str, entity_labels: list, schema_relations: list):
    # system_prompt = "I want you to act as a entity and relation extractor to help me build an academic knowledge graph from several paragraphs."
    # chatbot = Chatbot(api_key=openai.api_key, system_prompt=system_prompt)
    
    prompt1 = f"""
I will give you a paragraph. Extract as many named entities as possible from it. Your answer should only contain a list and nothing else. 
---
Here is an example:

paragraph: 
myasthenia gravis is characterized by muscle weakness. prednisolone is a treatment for myasthenia gravis.

your answer: 
[
    "myasthenia gravis",
    "muscle weakness",
    "prednisolone"
]
---
Here is the paragraph you should process:
{paragraph}
"""

    # entity_list = chatbot.ask(prompt1)
    entity_list = get_completion(prompt1)
    # print(entity_list)
    
    prompt2 = f"""This is the entity list you have just generated:

{entity_list}

Classify every entity in into one of the categories in the following list. You should not classify any entity into a category that in not in the following list.

{entity_labels}

Your result should be a JSON dictionary with entities being the keys and categories being the values. There should be nothing in your answer except the JSON dictionary.
---
Here is an example:

paragraph: 
myasthenia gravis is characterized by muscle weakness. prednisolone is a treatment for myasthenia gravis.

entity list:
[
    "myasthenia gravis",
    "muscle weakness",
    "prednisolone"
]
your answer:
{{
    "myasthenia gravis": "disease",
    "muscle weakness": "symptom",
    "prednisolone": "medication"
}}
"""

    # entity_category_dict = chatbot.ask(prompt2)
    entity_category_dict = get_completion(prompt2)
    # print(entity_category_dict)
    
    prompt3 = f"""
The following is the paragraph:

{paragraph}

The following is the "entity list" you have just generated:

{entity_list}

Extract as many relations as possible from the paragraph. Your result should be a list of triples and nothing else. 
The first and third element in each triple should be in the "entity list" you have generated and the second element should be in the following "relation category list". 
You should not extract any relation that the second element in it is not in the following "relation category list". 
The relation you choose should be precise and diverse. You shouldn't use "treatment" to describe all the relations.

Here is the "relation category list":
{schema_relations}

---
Here is an example:

paragraph: 
myasthenia gravis is characterized by muscle weakness. prednisolone is a treatment for myasthenia gravis.

entity list:
[
    "myasthenia gravis",
    "muscle weakness",
    "prednisolone"
]

your answer:
[
    ["myasthenia gravis", "presented with", "muscle weakness"],
    ["prednisolone", "treatment", "myasthenia gravis"],
]

"""

    # relation_list = chatbot.ask(prompt3)
    relation_list = get_completion(prompt3)
    # print(relation_list)
    
    try:
        p_entity_list = json.loads(entity_list)
        p_entity_category_dict = json.loads(entity_category_dict)
        p_relation_list = json.loads(relation_list)
        # print("# JSON load successful!")
        load_flag = True
        return {
            "entity_list": p_entity_list,
            "entity_category_dict": p_entity_category_dict,
            "relation_list": p_relation_list
        }, load_flag, paragraph
    except:
        # print("# JSON load failed!")
        load_flag = False
        return {
            "entity_list": entity_list,
            "entity_category_dict": entity_category_dict,
            "relation_list": relation_list
        }, load_flag, paragraph


entity_labels = get_entity_labels()
schema_relations = get_relations()
print("# entity labels:\n", entity_labels)
print("-" * 100)
print("# schema_relations:\n", schema_relations)
print("-" * 120)


"""
paragraph_dict = {
    "paragraph": paragraph,
    "relations": [
        {
            "head": head,
            "head_label": head_label,
            "relation": relation,
            "tail": tail,
            "tail_label": tail_label
        },
        {
            "head": head,
            "head_label": head_label,
            "relation": relation,
            "tail": tail,
            "tail_label": tail_label
        },
        ...
    ]
}
"""
# paragraph_file = "./data/mgkg_data/paragraph_pubmed_0720_test.txt"
# result_write_file = "./results/mgkg_result/paragraph_pubmed_0720_test_result.json"
# retry_write_file = "./data/mgkg_data/paragraph_pubmed_0720_test_retry.txt"
paragraph_file = "./data/mgkg_data/paragraph_pubmed_0720.txt"
result_write_file = "./results/mgkg_result/paragraph_pubmed_0720_result.json"
result_single_write_file = "./results/mgkg_result/paragraph_pubmed_0720_result_single.json"
retry_write_file = "./data/mgkg_data/paragraph_pubmed_0720_retry.txt"
with open(paragraph_file, "r") as file:
    paragraphs = file.readlines()

retry_paragraphs = []   # JSON load失败的paragraph
paragraph_dict_list = []    # 所有paragraph的三元组抽取结果
success_cnt = 0
fail_cnt = 0

# 将抽取结果paragraph_dict持续性写入result_single_write_file文件
with open(result_single_write_file, "w", newline='\n') as wrt_single_file:
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = {executor.submit(triple_extraction, paragraph, entity_labels, schema_relations): paragraph for paragraph in paragraphs}
        for future in tqdm(as_completed(futures), total=len(futures), desc='Processing paragraphs'):
            # 收集完成的线程处理好的结果
            result, load_flag, paragraph = future.result()
            if load_flag:   # JSON load成功
                success_cnt += 1
                entities = result['entity_list']
                entity_labels = result['entity_category_dict']
                relations = result['relation_list']
                relation_dict_list = []  # 每个paragraph抽取的的所有relation
                for item in relations:
                    head = item[0]
                    relation = item[1]
                    tail = item[2]
                    head_lebal = ""
                    tail_label = ""
                    # print(f'{head}, {relation}, {tail}')
                    entity_keys = entity_labels.keys()
                    for key in entity_keys:
                        if key in head:
                            head_label = entity_labels[key]
                        if key in tail:
                            tail_label = entity_labels[key]
                    relation_dict = {   # paragraph中抽取的一个relation
                        "head": head,
                        "head_label": head_label,
                        "relation": relation,
                        "tail": tail,
                        "tail_label": tail_label
                    }
                    relation_dict_list.append(relation_dict)
                paragraph_dict = {
                    "paragraph": paragraph,
                    "relations": relation_dict_list
                }
                paragraph_dict_list.append(paragraph_dict)
            else:   # JSON load失败
                fail_cnt += 1
                retry_paragraphs.append(paragraph)
                paragraph_dict = {
                    "paragraph": paragraph,
                    "relations": "failure"
                }
                paragraph_dict_list.append(paragraph_dict)
            json.dump(paragraph_dict, wrt_single_file, indent=4)
            wrt_single_file.write('\n')  # 添加换行符
            print("# paragraph_cnt {}, success_cnt {}, fail_cnt {}".format(len(paragraphs), success_cnt, fail_cnt))


print("# paragraph_cnt:", len(paragraphs))
print("# success_cnt:", success_cnt)
print("# fail_cnt:", fail_cnt)

# 将抽取结果paragraph_dict_list一次性写入文件
with open(result_write_file, "w") as json_file:
    json.dump(paragraph_dict_list, json_file, indent=4)

# 将JSON load失败的paragraph一次性写入新文件
with open(retry_write_file, 'w') as retry_file:
    for item in retry_paragraphs:
        retry_file.write("%s\n" % item)


# Test MGKG - Process

In [36]:
import json
from tqdm import tqdm


file_path = './results/mgkg_result/paragraph_pubmed_0720_result_single.json'
file_path_2 = './results/mgkg_result/paragraph_pubmed_0720_result_single_reverse.json'

with open(file_path, 'r') as f:
    data = json.load(f)
print("result:", type(data), len(data))

with open(file_path_2, 'r') as f_2:
    data_2 = json.load(f_2)
print("result (reverse):", type(data_2), len(data_2))


paragraph_file = "./data/mgkg_data/paragraph_pubmed_0720.txt"
result_write_file = "./results/mgkg_result/paragraph_pubmed_0720_result.json"
retry_write_file = "./data/mgkg_data/paragraph_pubmed_0720_retry.txt"

with open(paragraph_file, "r") as file:
    paragraphs = file.readlines()
paragraph_dict_list = []    # 所有paragraph的三元组抽取结果
retry_paragraphs = []   # JSON load失败的paragraph
relation_cnt = 0    # 抽取出的三元组数量
for paragraph in tqdm(paragraphs):
    find_flag = False
    # 先在正向的结果中找
    for item in data:
        if paragraph == item['paragraph'] and item['relations'] != "failure":
            find_flag = True
            paragraph_dict = item
            relation_cnt += len(paragraph_dict['relations'])
            paragraph_dict_list.append(paragraph_dict)
            break
    # 如果正向的结果中没有找到，再在反向的结果中找
    if not find_flag:
        for item in data_2:
            if paragraph == item['paragraph'] and item['relations'] != "failure":
                find_flag = True
                paragraph_dict = item
                relation_cnt += len(paragraph_dict['relations'])
                paragraph_dict_list.append(paragraph_dict)
                break
    # 如果正向和反向的结果中都没有找到，将paragraph加入retry_paragraphs
    if not find_flag:
        retry_paragraphs.append(paragraph)

print("relation_cnt:", relation_cnt, "| paragraph_cnt:", len(paragraph_dict_list))
print("result (merged):", type(paragraph_dict_list),len(paragraph_dict_list))
print("result (retry):", type(retry_paragraphs), len(retry_paragraphs))

# 将抽取结果paragraph_dict_list一次性写入文件
with open(result_write_file, "w") as json_file:
    json.dump(paragraph_dict_list, json_file, indent=4)

# 将JSON load失败的paragraph一次性写入新文件
with open(retry_write_file, 'w') as retry_file:
    for item in retry_paragraphs:
        retry_file.write("%s" % item)


result: <class 'list'> 1293
result (reverse): <class 'list'> 1318


100%|██████████| 1839/1839 [00:00<00:00, 6139.53it/s]


relation_cnt: 34944 | paragraph_cnt: 1495
result (merged): <class 'list'> 1495
result (retry): <class 'list'> 344
