# Entity alignment experiment

Problem: Entities extracted by an LLM can be messy, so we need to align them to a canonical form to make sure that different surface forms refer to the same thing.

Steps:

1. Extract entities
2. Project to semantic space
3. Use a distance metric to define their canonical forms, semi-manually

In [None]:
import json
import logging
import os
import sqlite3
from concurrent.futures import ThreadPoolExecutor
from contextlib import closing

import requests
import tenacity
from dotenv import load_dotenv

# Emit INFO-level (and above) records from the root logger so progress messages are visible.
logging.basicConfig(level=logging.INFO)

# Load environment variables (presumably from a local .env file) before reading them below.
load_dotenv()
# Connection settings for the LLM endpoint; each falls back to an empty string when unset.
CHTC_LLM_HOST = os.getenv('CHTC_LLM_HOST', "")
CHTC_LLM_PORT = os.getenv('CHTC_LLM_PORT', "")
CHTC_LLM_API_KEY = os.getenv('CHTC_LLM_API_KEY', "")
# Base URL of the LLM service (plain HTTP), built from the host and port above.
CHTC_LLM_API_URL = f"http://{CHTC_LLM_HOST}:{CHTC_LLM_PORT}"


# Default number of objects requested per batch by get_batch_with_cursor.
BATCH_SIZE = 8

Check if LLM endpoint is running

In [None]:
# Smoke test: GET the API root and decode the body as JSON; this raises if the
# server is unreachable or the reply is not valid JSON.
requests.get(CHTC_LLM_API_URL).json()

Basic LLM extraction function

In [None]:
@tenacity.retry(wait=tenacity.wait_fixed(5), stop=tenacity.stop_after_attempt(5))
def extract(text: str, timeout: float = 300.0) -> dict:
    """Extract locations, stratigraphic names and lithologies from text via the LLM.

    The call is retried up to 5 times, 5 seconds apart, on any exception
    (HTTP error, timeout, malformed JSON in the reply).

    Args:
        text: Free text to send to the model as the user message.
        timeout: Seconds to wait for the HTTP request before giving up on the
            current attempt. Without a timeout a stalled server would block
            forever and the retry policy would never get a chance to run.

    Returns:
        The model's reply parsed as JSON. The prompt asks for the keys
        "locations", "stratigraphic_names" and "lithologies", but the model's
        output is not validated here.

    Raises:
        tenacity.RetryError: If all 5 attempts fail.
    """

    auth_headers = {"Api-Key": CHTC_LLM_API_KEY}
    payload = {
        "model": "my_model",
        "messages": [
            {
                "role": "system",
                # The JSON template must use exactly the key names consumed downstream
                # (insert_case's parameters), so no stray punctuation inside the keys.
                "content": "Try to extract all locations, stratigraphic names, lithologies from the text provided. Reply in JSON format with the following structure: {\"locations\": \"\", \"stratigraphic_names\": \"\", \"lithologies\": \"\"}. If you can't find any of the requested information, leave the corresponding field empty.Do not include any other information in the response."
            },
            {
                "role": "user",
                "content": text
            }
        ],
    }
    response = requests.post(
        f"{CHTC_LLM_API_URL}/v1/chat/completions",
        headers=auth_headers,
        json=payload,
        timeout=timeout,
    )
    response.raise_for_status()
    # The service wraps the chat-completion payload as a JSON string under "_content",
    # so it has to be decoded a second time.
    content = json.loads(response.json()['_content'])
    # The assistant message itself is expected to be a JSON document as well.
    return json.loads(content["choices"][0]['message']['content'])


In [None]:
# Synthetic geology paragraph used to try out the extraction prompt by hand.
sample_text = "In the rolling hills of the Devonshire region, geologists have identified a fascinating stratigraphic layer known as the Devonian Slate. This stratum, rich in history and significance, dates back to the Devonian period, showcasing a deep, grayish-black coloration that speaks to its volcanic ash origin. The lithology of the Devonian Slate is particularly noteworthy; it exhibits a fine-grained texture with a smooth to slightly fissile surface, indicating a high degree of compaction and low-grade metamorphism over millions of years. Embedded within this slate, one can find intermittent layers of quartz and small fossils, hinting at the dynamic environmental conditions that prevailed during its formation. The study of such layers not only enriches our understanding of Earth's geological past but also provides valuable insights into the processes that have shaped the planet's crust over aeons."

# Run one extraction against the live endpoint and display the parsed result.
extract(sample_text)

Make a pipeline to process the criticalMAAS data

In [None]:
import weaviate

# Weaviate credentials come from the environment; both are None when unset.
WEAVIATE_APIKEY = os.getenv("WEAVIATE_APIKEY")
WEAVIATE_URL = os.getenv("WEAVIATE_URL")

# Module-level Weaviate client, authenticated with the API key, used by the query helper below.
client = weaviate.Client(WEAVIATE_URL, weaviate.AuthApiKey(api_key=WEAVIATE_APIKEY))

def get_batch_with_cursor(class_properties: list[str], class_name:str="Paragraph", batch_size:int=BATCH_SIZE, offset:int | None=None):
    """Fetch one batch of objects tagged with the "criticalmaas" topic from Weaviate.

    Despite the name, paging is done with a numeric offset rather than a cursor.

    Args:
        class_properties: Properties to return for each object. "topic_list" is
            always requested in addition; the caller's list is left untouched.
        class_name: Weaviate class to query.
        batch_size: Maximum number of objects to return.
        offset: Number of matching objects to skip, or None to start from the
            beginning.

    Returns:
        Whatever the Weaviate query's ``do()`` returns (presumably the raw
        GraphQL response dict — confirm against the client version in use).
    """
    # Copy before appending so repeated calls do not grow the caller's list.
    properties = list(class_properties)
    if "topic_list" not in properties:
        properties.append("topic_list")

    query = (
        client.query.get(class_name, properties)
        .with_additional(["id"])
        .with_where({"path": "topic_list", "operator": "ContainsAny", "valueText": ["criticalmaas"]})
    )

    if offset is not None:
        query = query.with_offset(offset)

    return query.with_limit(batch_size).do()

Make a local DB to store results

In [None]:
def insert_case(hashed_text:str, paper_id:str, locations:str, stratigraphic_names:str, lithologies:str) -> None:
    """Insert one row of extracted entities into the ``entities`` table of ``entities.db``.

    The table must already exist with the five columns named after this
    function's parameters.

    Args:
        hashed_text: Hash identifying the source paragraph.
        paper_id: Identifier of the paper the paragraph belongs to.
        locations: Extracted locations.
        stratigraphic_names: Extracted stratigraphic names.
        lithologies: Extracted lithologies.

    Raises:
        sqlite3.Error: If the database cannot be opened or the insert fails.
    """
    # A sqlite3 connection used as a context manager only commits or rolls back;
    # it does not close the connection. closing() guarantees the handle is released.
    with closing(sqlite3.connect("entities.db")) as conn:
        with conn:  # commit on success, roll back if the INSERT raises
            conn.execute('''
            INSERT INTO entities (hashed_text, paper_id, locations, stratigraphic_names, lithologies)
            VALUES (?, ?, ?, ?, ?)
            ''', (hashed_text, paper_id, locations, stratigraphic_names, lithologies))

In [None]:
def in_db(hashed_text: str) -> bool:
    """Return True if a row with this ``hashed_text`` exists in the ``entities`` table.

    Args:
        hashed_text: Hash identifying the paragraph to look up.

    Raises:
        sqlite3.Error: If the database cannot be opened or the table is missing.
    """
    # closing() releases the connection; sqlite3's own context manager would
    # only manage the transaction and leave the connection open.
    with closing(sqlite3.connect("entities.db")) as conn:
        row = conn.execute(
            "SELECT hashed_text FROM entities WHERE hashed_text=?", (hashed_text,)
        ).fetchone()
    return row is not None

In [None]:
# Entity fields requested from the model; they match insert_case's keyword parameters.
_ENTITY_FIELDS = ("locations", "stratigraphic_names", "lithologies")


def _entity_to_text(value) -> str:
    """Coerce one extracted field to the text form stored in the database."""
    if value is None:
        return ""
    if isinstance(value, str):
        return value
    # Lists and dicts cannot be bound as SQLite parameters, so serialise them.
    return json.dumps(value, ensure_ascii=False)


def process_case(paragraph: dict) -> None:
    """Extract entities for one paragraph and store them, skipping known paragraphs.

    Failures of the extraction step are logged and swallowed so that one bad
    paragraph does not abort a whole batch.

    Args:
        paragraph: Mapping that provides "hashed_text", "text_content" and
            "paper_id".
    """
    hashed_text = paragraph["hashed_text"]

    # Check if the paragraph has already been processed
    if in_db(hashed_text):
        logging.info("Paragraph %s already processed.", hashed_text)
        return

    try:
        entities = extract(paragraph["text_content"])
    except tenacity.RetryError:
        logging.error("Failed to extract entities for paragraph %s.", hashed_text)
        return

    # The model's reply is free-form JSON: it may not be an object at all.
    if not isinstance(entities, dict):
        logging.error("Unexpected extraction result for paragraph %s: %r", hashed_text, entities)
        return

    # Tolerate stray whitespace/colons around keys returned by the model
    # (e.g. "locations:"), then keep only the expected fields so that extra or
    # missing keys cannot break the keyword call to insert_case.
    cleaned = {str(key).strip(" :"): value for key, value in entities.items()}
    record = {field: _entity_to_text(cleaned.get(field, "")) for field in _ENTITY_FIELDS}
    record["hashed_text"] = hashed_text
    record["paper_id"] = paragraph["paper_id"]

    logging.info("Extracted entities for paragraph %s: %s", hashed_text, record)
    insert_case(**record)
    logging.info("Inserted entities for paragraph %s into the database.", hashed_text)

