In [31]:
# @ todo imports
import json
import duckdb
from google import genai
from google.genai import types
from datetime import datetime

In [32]:
def get_api_key(config_path: str = "../config.json") -> str:
    """Gets the user's Google Gemini API key from a JSON config file.

    Args:
        config_path: Path of the JSON config file. Defaults to
            "../config.json", resolved relative to the current working
            directory.

    Returns:
        The value stored under "gemini_api_key", or None if the file has no
        such key.

    Raises:
        FileNotFoundError: If the config file does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = json.load(config_file)
    return config.get("gemini_api_key")

def write_log(msg: str, logfile: str, log_dir: str = "../logs"):
    """Appends a timestamped message to a log file.

    Each entry is written as "<YYYY-mm-dd HH:MM:SS>\\n<msg>\\n\\n".

    Args:
        msg: The message to write to the log file
        logfile: The name of the log file inside ``log_dir``
        log_dir: Directory holding the log files. Defaults to "../logs",
            relative to the current working directory. Created if missing.

    Returns:
        None
    """
    from pathlib import Path

    directory = Path(log_dir)
    # Create the directory on first use: this function is called from error
    # handlers, where a FileNotFoundError would mask the original problem.
    directory.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(directory / logfile, "a", encoding="utf-8") as log_file:
        log_file.write(f"{timestamp}\n{msg}\n\n")

In [33]:
# Gemini API client, authenticated with the key returned by get_api_key().
client = genai.Client(api_key=get_api_key())

# Connect to the database
# Read-write connection to the DuckDB database "patent_database" (path relative
# to the current working directory). The functions below use this global `con`.
con = duckdb.connect(database='patent_database', read_only=False)

In [51]:
def create_label_table(reset: bool = False):
    """Ensures the ``labels`` table exists, optionally recreating it.

    The table has the columns han_id, firm_id (referencing
    firm_names.firm_id) and label, all INTEGER.

    Args:
        reset: If True, drop any existing ``labels`` table, including its
            rows, before creating it again.
    Returns:
        None
    """
    statements = ["DROP TABLE IF EXISTS labels"] if reset else []
    statements.append("""
        CREATE TABLE IF NOT EXISTS labels (
            han_id INTEGER,
            firm_id INTEGER REFERENCES firm_names(firm_id),
            label INTEGER
        )
    """)
    for statement in statements:
        con.execute(statement)

def insert_label(han_id: int, firm_id: int, label: int):
    """Inserts one (han_id, firm_id, label) row into the labels table.

    Args:
        han_id: The id of the han record
        firm_id: The id of the firm record
        label: The label of the record
    Returns:
        None
    """
    # Bind the values as query parameters rather than formatting them into the
    # SQL text. The values come from model output, and a string label such as
    # "match" would otherwise yield invalid SQL, or inject arbitrary SQL.
    con.execute(
        "INSERT INTO labels VALUES (?, ?, ?)",
        [han_id, firm_id, label],
    )

In [52]:
def call_gemini_api(input_data: str, model: str = "gemini-2.0-flash") -> dict:
    """Asks Gemini whether a firm name matches its PATSTAT name variants.

    The model is asked whether "name" refers to the same company as
    "han_name", "person_name" and "psn_name" in ``input_data``.

    Args:
        input_data: JSON string describing one candidate pair, e.g. the output
            of ``row.to_json()`` for a row with firm_id, han_id, name,
            han_name, person_name and psn_name.
        model: Name of the Gemini model to query.
    Returns:
        The JSON object parsed from the model's reply (requested keys:
        firm_id, han_id, label), or None if the reply was empty, not valid
        JSON, or not a JSON object. Each failure is printed and appended to
        label_training_api_call_log.txt.
    """

    prompt = f"""
    You are an expert in entity resolution. I will provide you with multiple name variants for a single company record.
    Your task is to determine if the variable "name" matches to the patstat names, which are "han_name", "person_name" and "psn_name".
    The 'Input Data' is provided below in JSON format.

    **Output your answer in the following JSON format.** The output should be a single JSON object containing the following key-value pairs:
    - 'firm_id': must match the 'firm_id' in the input data
    - 'han_id': must match the 'han_id' in the input data
    - 'label': must be the integer 1 if they match, i.e. they refer to the same company, or the integer 0 if they do not

    ```json
    {{
      "firm_id":  // The firm_id of the record,
      "han_id":   // The han_id of the record,
      "label":    // 1 (match) or 0 (no match)
    }}
    ```
    **Input Data:**

    ```json
    {input_data}
    ```
    """

    response = client.models.generate_content(
        model=model,
        contents=prompt,
        config=types.GenerateContentConfig(
            temperature=0.0,
            # Request a bare JSON body; the fence stripping below is kept as a
            # fallback in case the model still wraps its answer in Markdown.
            response_mime_type="application/json",
        ),
    )

    log_name = "label_training_api_call_log.txt"
    # .text can be None when the response carries no text part; treat that as
    # an empty (and therefore invalid) reply instead of crashing.
    raw_text = response.text or ""

    # Drop Markdown code-fence lines (``` / ```json) and a stray "json" line.
    kept_lines = [
        line
        for line in raw_text.splitlines()
        if not line.strip().startswith("```") and line.strip().lower() != "json"
    ]

    try:
        parsed = json.loads("\n".join(kept_lines))
    except json.JSONDecodeError:
        print("Error: Gemini API did not return valid JSON.")
        write_log(f"Error: Gemini API did not return valid JSON.\n\n{raw_text}", log_name)
        return None

    if not isinstance(parsed, dict):
        print("Error: Gemini API did not return a JSON object.")
        write_log(f"Error: Gemini API did not return a JSON object.\n\n{raw_text}", log_name)
        return None

    return parsed


def _normalize_label(raw_label):
    """Maps a model label to 1 (match) or 0 (no match); None if unrecognised."""
    # bool is a subclass of int, so handle it first to return a plain int.
    if isinstance(raw_label, bool):
        return int(raw_label)
    if isinstance(raw_label, int):
        return raw_label if raw_label in (0, 1) else None
    if isinstance(raw_label, str):
        aliases = {"1": 1, "0": 0, "match": 1, "no_match": 0}
        return aliases.get(raw_label.strip().lower())
    return None


def process_gemini_response(response: dict):
    """Validates a parsed Gemini reply and stores it in the labels table.

    Replies that are not dicts (e.g. None from a failed API call), that lack
    han_id or firm_id, or whose label cannot be mapped to 0/1 are reported
    and skipped instead of being inserted.

    Args:
        response: The parsed JSON object returned by the Gemini API
    Returns:
        None
    """
    if not isinstance(response, dict):
        print(f"Skipping invalid Gemini response: {response!r}")
        return

    han_id = response.get("han_id")
    firm_id = response.get("firm_id")
    label = _normalize_label(response.get("label"))
    if han_id is None or firm_id is None or label is None:
        print(f"Skipping incomplete Gemini response: {response!r}")
        return

    print(f"han_id: {han_id}, firm_id: {firm_id}, label: {label}")
    insert_label(han_id, firm_id, label)

In [53]:
def process_data(similarity_threshold: float = 0.87):
    """Labels every candidate firm/PATSTAT pair above a similarity threshold.

    Each candidate row is sent to the Gemini API and the reply is stored via
    process_gemini_response(). Rows whose API call yields no usable reply are
    reported and skipped, so one bad reply does not abort the whole run.

    Args:
        similarity_threshold: Only pairs with similarity strictly greater
            than this value are labelled. Defaults to 0.87.
    Returns:
        None
    """
    # DISTINCT applies to the whole selected row, not just firm_id.
    sql = """
    SELECT DISTINCT firm_id, han_id, similarity, name, han_name, person_name, psn_name
    FROM patstat_firm_match
    JOIN firm_names USING(firm_id)
    JOIN patstat_data USING(han_id)
    WHERE similarity > ?
    """
    data = con.execute(sql, [similarity_threshold]).fetchdf()
    for _, row in data.iterrows():
        response = call_gemini_api(row.to_json())
        if response is None:
            print(
                f"Skipping firm_id={row['firm_id']}, han_id={row['han_id']}: "
                "no usable Gemini response"
            )
            continue
        process_gemini_response(response)

In [None]:
if __name__ == "__main__":
    # Rebuild the labels table from scratch, then label all candidate pairs.
    create_label_table(reset=True)
    process_data()

han_id: 868544, firm_id: 5017, label: 1
han_id: 1384087, firm_id: 6729, label: 1
han_id: 3653281, firm_id: 10848, label: 0
han_id: 1432953, firm_id: 6881, label: 1
han_id: 4536091, firm_id: 3254, label: 0
han_id: 366042, firm_id: 1645, label: 0
han_id: 2163779, firm_id: 9189, label: 0
han_id: 4698488, firm_id: 9563, label: 0
han_id: 1636172, firm_id: 13430, label: 0
han_id: 4653424, firm_id: 12495, label: 1
han_id: 369580, firm_id: 1933, label: 1
han_id: 10961, firm_id: 9575, label: 1
han_id: 3289414, firm_id: 14134, label: 1
han_id: 4398346, firm_id: 3918, label: 0
han_id: 181583047, firm_id: 12260, label: 0
han_id: 555281, firm_id: 6175, label: 0
han_id: 1904196, firm_id: 13308, label: 1
han_id: 3288696, firm_id: 10592, label: 0
han_id: 2409972, firm_id: 11581, label: 1
han_id: 297236, firm_id: 1662, label: 0
han_id: 1631792, firm_id: 7583, label: 0
han_id: 3094619, firm_id: 14538, label: 0
han_id: 1982247, firm_id: 9096, label: 0
han_id: 1514031, firm_id: 13307, label: 1
han_id: 427

In [20]:
# Candidate pairs with similarity > 0.87, loaded as a DataFrame for interactive
# inspection and single-row testing.
# NOTE(review): DISTINCT(firm_id) applies DISTINCT to the whole selected row;
# the parentheses only wrap the first column.
sql = """
    SELECT DISTINCT(firm_id), han_id, similarity, name, han_name, person_name, psn_name FROM patstat_firm_match
    JOIN firm_names USING(firm_id)
    JOIN patstat_data USING(han_id)
    WHERE similarity > 0.87
    """
test_data = con.execute(sql).fetchdf()

In [30]:
# Smoke test: label only the first candidate row and insert the result into `labels`.
process_gemini_response(call_gemini_api(test_data.loc[0].to_json()))

han_id: 868544, firm_id: 5017, label: 1


In [26]:
# Display the JSON payload that is sent to the model for the first row.
test_data.loc[0].to_json()

'{"firm_id":5017,"han_id":868544,"similarity":1.0,"name":"FMC CORP","han_name":"FMC CORP","person_name":"FMC Corporation","psn_name":"FMC CORPORATION"}'