In [1]:
import duckdb
from duckdb import ConstraintException
from duckdb import BinderException
from google import genai
from google.genai import types
from google.genai.errors import ClientError
from google.genai.errors import ServerError
from datetime import datetime
import concurrent.futures
import json
import time
import re
from tqdm import tqdm

In [2]:
def get_api_key() -> str:
    """Read the user's Google Gemini API key from the config file.

    The key is looked up under ``gemini_api_key`` in ``../config.json``
    (resolved relative to the current working directory).

    Args:
        None

    Returns:
        The value stored under ``gemini_api_key``; ``None`` if the config
        file does not contain that key.
    """
    with open("../config.json", "r") as handle:
        return json.load(handle).get("gemini_api_key")

def write_log(msg: str, logfile: str):
    """Append a timestamped message to a log file under ``../logs``.

    The file is written as UTF-8 so that messages containing non-ASCII
    text (e.g. company names) do not raise ``UnicodeEncodeError`` on
    platforms whose default encoding cannot represent them. The log
    directory is created on first use.

    Args:
        msg: The message to write to the log file
        logfile: The name of the log file (inside ``../logs``)

    Returns:
        None
    """
    import os  # local import keeps this function self-contained

    log_dir = "../logs"
    os.makedirs(log_dir, exist_ok=True)
    file_path = f"{log_dir}/{logfile}"
    with open(file_path, "a", encoding="utf-8") as log_file:
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # Entry layout: timestamp line, message, blank separator line.
        log_file.write(f"{timestamp}\n{msg}\n\n")

In [3]:
# Module-level Gemini client, authenticated with the key returned by
# get_api_key(); used by call_gemini_api() below.
client = genai.Client(api_key=get_api_key())

# Connect to the database
# Opens the DuckDB database 'patent_database' in read-write mode
# (read_only=False); `con` is shared by every function and query below.
con = duckdb.connect(database='patent_database', read_only=False)

In [4]:
def create_label_table(reset: bool = False):
    """Create the ``labels`` table that stores one label per (han_id, firm_id).

    The pair ``(han_id, firm_id)`` is declared as the primary key so that a
    duplicate insert raises ``ConstraintException`` (which ``insert_label``
    treats as "entry already in database") instead of silently adding a
    second row for the same pair.

    Because the statement is ``CREATE TABLE IF NOT EXISTS``, an already
    existing ``labels`` table keeps its current schema unless ``reset`` is
    set.

    Args:
        reset: Whether to drop and recreate the table if it already exists
    Returns:
        None
    """
    if reset:
        con.execute("DROP TABLE IF EXISTS labels")

    con.execute("""
        CREATE TABLE IF NOT EXISTS labels (
            han_id INTEGER,
            firm_id INTEGER REFERENCES firm_names(firm_id),
            label INTEGER,
            PRIMARY KEY (han_id, firm_id)
        )
    """)

def insert_label(han_id: int, firm_id: int, label: int):
    """Insert one label row into the ``labels`` table.

    The values may originate from a model response, so they are validated
    in Python and passed to DuckDB as bound parameters rather than being
    interpolated into the SQL text. Anything that is not an integer id, or
    a label other than 0 or 1, is written to ``insert_exception_log.txt``
    and not inserted.

    Args:
        han_id: The id of the han record
        firm_id: The id of the firm record
        label: The label of the record (0 = no match, 1 = match)
    Returns:
        None
    """
    try:
        # int() also accepts numeric strings ("1") and numpy integers.
        values = [int(han_id), int(firm_id), int(label)]
    except (TypeError, ValueError):
        # None or non-numeric value (faulty gemini response)
        values = None

    if values is None or values[2] not in (0, 1):
        msg = f"""
            Faulty Response for:
            han_id: {han_id}
            firm_id: {firm_id}
            label: {label}
            """
        write_log(msg, "insert_exception_log.txt")
        return

    try:
        con.execute("INSERT INTO labels VALUES (?, ?, ?)", values)
    except ConstraintException:
        # Deliberate best-effort: a constraint violation is treated as
        # "entry already in database" and the insert is skipped.
        pass

In [5]:

def do_api_call(prompt, input_data, client, types) -> dict:
    """Send the prompt to the Gemini API, retrying on client/server errors.

    Up to five attempts are made, with a 30 second pause after each failed
    attempt except the last one.

    Args:
        prompt: The prompt to be sent to the API
        input_data: The data the prompt was built from (only used in the
            error log entry)
        client: The Google Gemini client
        types: The types module from the Google Gemini API
    Returns:
        The response object from the API, or an empty dict if all five
        attempts failed (the failure is written to ``api_call_error.txt``)
    """
    failures = 0

    while True:
        try:
            # First successful call ends the retry loop immediately.
            return client.models.generate_content(
                model='gemini-2.0-flash-thinking-exp-01-21',
                contents=prompt,
                config=types.GenerateContentConfig(
                    temperature=0.0
                )
            )
        except (ClientError, ServerError) as exc:
            failures += 1
            if failures == 5:
                error = f"""Failed to call the gemini api 5 times
                        Error: {exc}
                        Input data: {input_data}"""
                write_log(error, "api_call_error.txt")
                # Empty dict instead of None, as a "no response" marker
                return {}
            # Back off before the next attempt
            time.sleep(30)

def call_gemini_api_with_timeout(prompt, input_data, client, types, timeout_seconds=300) -> dict:
    """Call the Gemini API via ``do_api_call`` and give up after a timeout.

    The call runs in a single worker thread. The executor is deliberately
    NOT used as a context manager: leaving a ``with`` block calls
    ``shutdown(wait=True)``, which would block until the worker finishes and
    make the timeout ineffective. Instead the executor is shut down with
    ``wait=False`` so this function returns as soon as the timeout expires.

    Note that a timed-out worker thread cannot be cancelled; it keeps
    running in the background until ``do_api_call`` returns.

    Args:
        prompt: The prompt to be sent to the API
        input_data: The data to be sent to the API
        client: The Google Gemini client
        types: The types module from the Google Gemini API
        timeout_seconds: The timeout for the API call
    Returns:
        The response from the API, or an empty dict on timeout (the timeout
        is written to ``api_call_error.txt``)
    """
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = executor.submit(do_api_call, prompt, input_data, client, types)
    try:
        # Wait up to timeout_seconds for do_api_call to complete
        return future.result(timeout=timeout_seconds)
    except concurrent.futures.TimeoutError:
        write_log(f"Timed out after {timeout_seconds} seconds for input {input_data};\n returning empty dict.", "api_call_error.txt")
        return {}
    finally:
        # Do not wait for a still-running worker; see docstring.
        executor.shutdown(wait=False)

def call_gemini_api(input_data: str) -> dict:
    """Ask Gemini whether 'name' matches the PATSTAT names in the input.

    Builds the matching prompt around ``input_data``, sends it through
    ``call_gemini_api_with_timeout`` and extracts the first ``{...}`` object
    that follows the word ``json`` in the response text (presumably a
    fenced ``json`` block in the model output).

    Args:
        input_data: The data to be sent to the API (a JSON string)
    Returns:
        The parsed JSON object from the response. ``None`` is returned, and
        an entry written to ``label_training_api_call_log.txt``, when no
        such object is found or it cannot be decoded.
    """

    prompt = f""" Your task is to determine if a given company name ('name') matches any of the provided company names from the PATSTAT database ('han_name', 'person_name', 'psn_name'). You must be very thorough in your analysis. Assume that the provided names are accurate and free of spelling errors. Focus on identifying exact or near-exact matches, considering only common and accepted abbreviations. Do NOT consider minor variations or potential spelling mistakes as valid matches.

    Input Data (JSON):
    {input_data}

    Output (JSON):
    {{\n\"firm_id\": \"{{firm_id}}\",\n    \"han_id\": \"{{han_id}}\",\n    \"label\": \"{{label}}\" <--- The value of 'label' MUST be either '0' or '1'. '0' indicates no match, and '1' indicates a match.\n}}",

    "description": "This prompt instructs the model to perform company name matching, comparing a given name against PATSTAT names and outputting a JSON object with the firm_id, han_id, and a label indicating a match (1) or no match (0). The model is instructed to be thorough.
    """
    response = call_gemini_api_with_timeout(prompt, input_data, client, types)

    # Locate the JSON object in the raw model output
    raw_text = response.text
    found = re.search(r"json\s*(\{.*?\})\s*", raw_text, re.DOTALL)
    if found is None:
        write_log(f"Error: Gemini API did not return valid JSON.\n\n{response.text}", "label_training_api_call_log.txt")
        return None

    payload = found.group(1).strip()
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        write_log(f"Error: Gemini API did not return valid JSON.\n\n{response.text}", "label_training_api_call_log.txt")
        return None


def process_gemini_response(response: dict, row: dict):
    """Validate a Gemini answer against its input row and store the label.

    The ids echoed back by the model are converted to ``int`` and compared
    with the ids of the row that was sent. Only when both agree is the label
    handed to ``insert_label``; otherwise the mismatch is written to
    ``label_training_api_call_log.txt`` and nothing is inserted.

    Args:
        response: The response from the Gemini API
        row: The row of data that was sent to the API
    Returns:
        None
    """
    echoed_han_id = int(response.get("han_id"))
    echoed_firm_id = int(response.get("firm_id"))

    ids_agree = (
        echoed_han_id == row.get("han_id")
        and echoed_firm_id == row.get("firm_id")
    )
    if ids_agree:
        insert_label(echoed_han_id, echoed_firm_id, response.get("label"))
    else:
        # The model answered for a different pair than the one it was given
        write_log(f"Error: han_id or firm_id do not match the input data.\n\n{response}", "label_training_api_call_log.txt")

In [6]:
def process_data():
    """Label every candidate (firm_id, han_id) pair with similarity >= 0.9.

    For each candidate row:
      * pairs that already have a row in ``labels`` are skipped (and noted
        in ``entry_log.txt``);
      * pairs with similarity >= 0.99 are labelled as a match directly;
      * all other pairs are sent to the Gemini API, with up to five
        attempts per pair when the answer cannot be processed.

    Args:
        None
    Returns:
        None
    """
    sql = """
    SELECT DISTINCT firm_id, han_id, similarity, name, han_name, person_name, psn_name FROM patstat_firm_match
    JOIN firm_names USING(firm_id)
    JOIN patstat_data USING(han_id)
    WHERE similarity >= 0.9
    """
    data = con.execute(sql).fetchdf()
    max_attempts = 5

    for _, row in tqdm(data.iterrows(), total=len(data), desc="Processing rows"):
        # Plain ints so the ids can be bound as query parameters
        han_id = int(row["han_id"])
        firm_id = int(row["firm_id"])

        already_labelled = con.execute(
            "SELECT 1 FROM labels WHERE han_id = ? AND firm_id = ? LIMIT 1",
            [han_id, firm_id],
        ).fetchone() is not None
        if already_labelled:
            write_log(f"Entry han_id = {han_id} AND firm_id = {firm_id} already in DB", "entry_log.txt")
            # Entry already in Database
            continue

        if row['similarity'] >= 0.99:
            # If the name jaro-winkler similarity is >= .99, we assume it is a match
            insert_label(han_id, firm_id, 1)
            continue

        # We call the gemini api to determine if it is a match
        for attempt in range(1, max_attempts + 1):
            try:
                response = call_gemini_api(row.to_json())
                process_gemini_response(response, row)
                break
            except (TypeError, AttributeError, ValueError, UnicodeEncodeError) as e:
                # TypeError/AttributeError: missing or empty response.
                # ValueError: the model returned a non-numeric id, which
                # int() in process_gemini_response rejects; retry instead
                # of aborting the whole run.
                if attempt == max_attempts:
                    error = f"""Failed to call the gemini api 5 times\n
                            Error: {e}\n
                            Input data: {row}"""
                    write_log(error, "api_call_error.txt")
            

In [7]:
if __name__ == "__main__":
    # Make sure the labels table exists, then label all candidate pairs.
    create_label_table()
    try:
        process_data()
    except UnicodeEncodeError as e:
        # A log write failed on an unencodable character, so the log file
        # cannot be used to report it; print the message instead of
        # silently discarding it.
        msg = f"Failed to write log. Error:\n {e}"
        print(msg)

Processing rows:  87%|████████▋ | 26756/30668 [2:28:14<21:40,  3.01it/s]   


In [None]:
# Inspect pairs that were labelled as "no match" (label = 0) although their
# stored similarity is above 0.95.
sql = """
    SELECT DISTINCT han_id, firm_id, similarity, label, name, han_name, person_name, psn_name FROM labels
    JOIN firm_names USING(firm_id)
    JOIN patstat_data USING(han_id)
    JOIN patstat_firm_match USING(han_id, firm_id)
    WHERE label = 0
    AND similarity > 0.95
"""

con.execute(sql).fetchdf()

In [None]:
# Verify that no han_id is matched (label = 1) to more than one firm_id.
# Any row returned here is a han_id with several distinct matched firm_ids.
sql = """
    SELECT han_id, COUNT(DISTINCT firm_id) AS distinct_firm_ids
    FROM labels
    WHERE label = 1
    GROUP BY han_id
    HAVING COUNT(DISTINCT firm_id) > 1;
"""

con.execute(sql).fetchdf()


In [None]:
# Verify that no firm_id is matched (label = 1) to more than one han_id.
# Any row returned here is a firm_id with several distinct matched han_ids.
sql = """
    SELECT firm_id, COUNT(DISTINCT han_id) AS distinct_han_ids
    FROM labels
    WHERE label = 1
    GROUP BY firm_id
    HAVING COUNT(DISTINCT han_id) > 1;
"""

con.execute(sql).fetchdf()

In [None]:
# Show every labelled row (with names and similarity) for those firm_ids
# that have label = 1 for more than one distinct han_id.
sql = """
        SELECT DISTINCT
        l.han_id,
        l.firm_id,
        pm.similarity,
        l.label,
        f.name,
        p.han_name,
        p.person_name,
        p.psn_name
    FROM labels l
    JOIN firm_names f USING (firm_id)
    JOIN patstat_data p USING (han_id)
    JOIN patstat_firm_match pm USING (han_id, firm_id)
    WHERE l.firm_id IN (
        SELECT firm_id
        FROM labels
        WHERE label = 1
        GROUP BY firm_id
        HAVING COUNT(DISTINCT han_id) > 1
    );
"""

con.execute(sql).fetchdf()

In [None]:
# Number of distinct firm_ids that have at least one row in labels
con.execute("SELECT count(DISTINCT firm_id) FROM labels").fetchdf()

In [None]:
# Inspect all labelled rows, with the joined names, for one specific firm
# (firm_id = 12622).
sql = """
    SELECT DISTINCT han_id, firm_id, label, name, han_name, person_name, psn_name FROM labels
    JOIN firm_names USING(firm_id)
    JOIN patstat_data USING(han_id)
    WHERE firm_id = 12622
"""

con.execute(sql).fetchdf()

In [None]:
# Inspect pairs that were labelled as a match (label = 1) although their
# stored similarity is below 0.92.
sql = """
    SELECT DISTINCT han_id, firm_id, similarity, label, name, han_name, person_name, psn_name FROM labels
    JOIN firm_names USING(firm_id)
    JOIN patstat_data USING(han_id)
    JOIN patstat_firm_match USING(han_id, firm_id)
    WHERE label = 1
    AND similarity < 0.92
"""

con.execute(sql).fetchdf()

In [None]:
# List the rows of patstat_firm_match whose similarity is below 0.91
con.execute("select * from patstat_firm_match where similarity < 0.91").fetchdf()

In [None]:
# List the candidate pairs with similarity >= 0.99, i.e. the pairs that
# process_data() labels as a match without calling the Gemini API.
sql = """
SELECT DISTINCT firm_id, han_id, similarity, name, han_name, person_name, psn_name FROM patstat_firm_match
JOIN firm_names USING(firm_id)
JOIN patstat_data USING(han_id)
WHERE similarity >= 0.99
"""

con.execute(sql).fetchdf()