# Install packages

https://wiki.postgresql.org/wiki/Using_psycopg2_with_PostgreSQL#Fetch_Records_using_a_Server-Side_Cursor

In [2]:
pip install "psycopg[binary]" openai aiofiles neo4j load_dotenv

Collecting load_dotenv
  Downloading load_dotenv-0.1.0-py3-none-any.whl.metadata (1.9 kB)
Collecting python-dotenv (from load_dotenv)
  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)
Downloading load_dotenv-0.1.0-py3-none-any.whl (7.2 kB)
Downloading python_dotenv-1.0.1-py3-none-any.whl (19 kB)
Installing collected packages: python-dotenv, load_dotenv

Successfully installed load_dotenv-0.1.0 python-dotenv-1.0.1


In [3]:
import load_dotenv
import pandas as pd
import psycopg                      # SQL query package
import os                           # get env variables
from urllib.request import urlopen  # package for HTTP connections
import time

from openai import OpenAI
import json
import csv
from typing import List, Dict, Any, Optional

from neo4j import GraphDatabase

In [None]:
# Load the variables defined in a local .env file into the process environment.
# `import load_dotenv` binds a module object, which is not callable, so the
# function is imported here from python-dotenv's `dotenv` package.
from dotenv import load_dotenv

load_dotenv()

In [None]:
# Connection settings and credentials, read from the process environment.
# load_dotenv() only loads a .env file; it does not look up individual keys,
# so each value is read with os.getenv, which returns None when the variable
# is not set.
DB_HOST = os.getenv("DB_HOST")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASS = os.getenv("DB_PASS")

OPENAI_KEY = os.getenv("OPENAI_KEY")
OPENAI_MODEL = os.getenv("OPENAI_MODEL")

NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USER = os.getenv("NEO4J_USER")
NEO4J_PASS = os.getenv("NEO4J_PASS")

In [None]:
# Single OpenAI client shared by the batch helper functions defined further down.
client = OpenAI(
    api_key=OPENAI_KEY,
)

# Connect to Postgres
Creates `list_of_ids`, a list of document IDs read from the Postgres table.

In [None]:
# libpq-style connection string assembled from the environment settings.
# NOTE(review): the values are interpolated inside single quotes without
# escaping, so a password containing a quote or backslash would break it.
conn_string = " ".join([
    f"host='{DB_HOST}'",
    f"dbname = '{DB_NAME}'",
    f"user='{DB_USER}'",
    f"password='{DB_PASS}'",
])

In [6]:
# Open the Postgres connection described by `conn_string`.
conn = psycopg.connect(conninfo=conn_string)

In [7]:
# Giving the cursor a name is the important part: psycopg then creates a
# server-side cursor, so the server does not have to send every record to the
# client at once.
cursor = conn.cursor(name='server_cursor')

In [None]:
# Selects only the `id` column; each returned row carries one document ID.
query = (
    "SELECT id "
    "FROM ucsf_opioid.table_name"
)

In [None]:
# Run the ID query, read the whole result set into `records`, and print the
# number of rows returned.
# NOTE(review): fetchall() retrieves every row in one go, so the incremental
# fetching a server-side cursor allows is not actually used here.
cursor.execute(query)
records = cursor.fetchall()
print(len(records))

652


In [10]:
# Close the Postgres connection; the fetched rows remain available in `records`.
conn.close()

In [11]:
# Peek at the first column of the first row (the document ID) to confirm the row shape.
records[0][0]

'htcf0232'

In [None]:
# Flatten the single-column rows into a plain list of document IDs.
list_of_ids = [row[0] for row in records]
list_of_ids

# Connect to Solr
Creates `emails_list`, a list of dictionaries of the form:
* {"email_id": email_id, "email_text": email_body}

Outputs `email_bodies_list.csv`, which holds every ID together with its email body text.

In [None]:
# Fetch each document's OCR text from Solr and pair it with its document ID.
# NOTE(review): `solr_url` is built from the record ID alone, with no scheme or
# host visible here; urlopen needs a full URL, so confirm the intended Solr
# endpoint before running.
emails_list = []
for record_id in list_of_ids:
    solr_url = f'{record_id}'

    # The context manager closes the HTTP response even when reading fails;
    # previously every response was left open for the whole loop.
    with urlopen(solr_url) as connection:
        response_text = connection.read().decode('utf-8')

    data = json.loads(response_text)
    # Uses the first document in the Solr response and the first entry of its
    # `ocr_text` field; a response without either raises IndexError/KeyError.
    email_body = data['response']['docs'][0]['ocr_text'][0]

    emails_list.append({"email_id": record_id, "email_text": email_body})

In [None]:
# Display the collected records.
# NOTE(review): this renders every email body in the cell output; consider
# showing a slice before sharing the notebook.
emails_list

In [None]:
# Persist the collected records to email_bodies_list.csv. Every CSV row holds a
# single cell: the string form of one {"email_id", "email_text"} dictionary.
with open("email_bodies_list.csv", 'w', newline="", encoding="utf-8") as f:
    csv.writer(f).writerows([record] for record in emails_list)

# Call Batch API and OpenAI API
Formats the email bodies into the schema by:

defining `schema` and `system_instructions`

handling checkpoints in case the system crashes halfway through

Outputs `checkpoints.jsonl` and, within Python, `output_list`.

## schema
`schema_json`

In [56]:
# Template of the JSON structure the model is asked to fill in for one email
# thread. It is serialized to `schema_json` below and embedded in the system
# prompt. The `#` comments inside the literal are Python comments only: they are
# not part of `schema_json`, so the model never sees them.
schema = {
  # Prefix-to-IRI mappings used by the "@type" values further down.
  "@context": {
    "@vocab": "https://schema.org/",
    "email": "https://schema.org/EmailMessage",
    "person": "https://schema.org/Person",
    "org": "https://schema.org/Organization",
    "document": "https://schema.org/DigitalDocument",
    "topicEntity": "https://schema.org/Thing",
    "url": "https://schema.org/URL",
    "gpe": "https://schema.org/Place",
    "drug": "https://schema.org/Drug"
  },

  # Thread-level wrapper
  # NOTE(review): the "case" prefix is not declared in "@context" — confirm intended.
  "@type": "case:Legislation",
  "semantic_type": "Legal Communication Record",
  "identifier": "",
  "legalStatus": "",
  "dateFiled": "",
  "language": [],   # could be "en" for English or a different language
  "confidentialityNotice": "",

  # Disabled field: would hold the full raw CSV cell, unchanged except for JSON
  # escaping. Comments below that mention raw_thread_text refer to the input text.
#   "raw_thread_text": "",

  # The thread: multiple messages in reverse-chronological order
  "hasPart": [
    {
      "@type": "email:EmailMessage",
      "semantic_type": "Email Communication",

      "identifier": "",
      "subject": "",      # substring of raw_thread_text (if present)
      "dateSent": "",
      "importance": "",

      "threadIndex": 0,   # 0 = most recent message
      "inReplyTo": "",    # optional

      "sender": {
        "@type": "person:Person",
        "semantic_type": "Person",
        "name": "",       # substring of raw_thread_text
        "email": "",      # substring of raw_thread_text
        "affiliation": {
          "@type": "org:Organization",
          "semantic_type": "ORG",
          "name": "",     # substring if present, else empty
          "role": "",
          # Only the sender's affiliation carries a parent organization.
          "parentOrganization": {
            "@type": "org:Organization",
            "semantic_type": "ORG",
            "name": "",
            "role": ""
          }
        }
      },

      "recipient": [
        {
          "@type": "person:Person",
          "semantic_type": "Person",
          "name": "",     # substring
          "email": "",    # substring if present
          "affiliation": {
            "@type": "org:Organization",
            "semantic_type": "ORG",
            "name": "",
            "role": ""
          }
        }
      ],

      # MUST be a literal substring of raw_thread_text
      "body": "",

      # MENTIONS: text grounded, labels inferred
      "mentions": [
        {
          "@type": "topicEntity",
          "semantic_type": "",   # e.g. "Legal Case", "Product Brand", "Business Operation", "Financial Document", "Drug Name", "GPE"
          "role": "",            # e.g. "Geographic Destination"
          # MUST be substring of the user input text
          "name": "",
          # MUST be substring if it appears in the text; otherwise leave ""
          "identifier": ""
        }
      ],

      # ATTACHMENTS: names/desc grounded, format may be inferred
      "attachments": [
        {
          "@type": "document:DigitalDocument",
          "semantic_type": "",   # e.g. "Spreadsheet Document", "Presentation Document", "Policy Document"

          # MUST be substring of raw_thread_text
          "name": "",            # e.g. "TLAIS costings 8.09.04.xls"

          # fileFormat may be inferred (application/pdf, etc.)
          "fileFormat": "",

          # MUST be substring if the description text appears in the email;
          # otherwise either leave "" or omit the field.
          "description": ""
        }
      ],

      # Serialized as JSON null by json.dumps.
      "forwardedMessage": None,
      "mentionsEmail": [],
      "structuredArgument": [],
      "complianceContext": {
          "@type": "CreativeWork",
          "semantic_type": "",
          "name": "",
          "keywords": [],
          "about": []
      }
    }
  ]
}

# Pretty-printed JSON text of the template, interpolated into the system prompt.
schema_json = json.dumps(schema, indent=2)

## continue

In [57]:
# System prompt sent with every request: the structuring and grounding rules,
# followed by the serialized schema template and a trailing newline.
system_instructions = """
- The user input is a single email thread (one CSV row).
- Split the thread into individual email messages.
- For each message, create one object in `hasPart` with @type "email:EmailMessage".
- Order `hasPart` in reverse-chronological order so index 0 is the most recent message.

GROUNDING RULES (VERY IMPORTANT):
- `body` MUST be a literal substring of the user input text. Do not paraphrase or summarize.
- `subject`, sender/recipient `name` and `email`, attachment `name`, and any description text
  MUST be literal substrings of the user input text where they appear.
- For each `mentions` item:
  - `name` MUST be a literal substring of the user input text.
  - `identifier` MUST be a literal substring of the user input text if present; otherwise leave "".
  - `semantic_type` and `role` are categorical labels and MAY be inferred.
- For each `attachments` item:
  - `name` MUST be a literal substring of the user input text.
  - `description` MUST be a literal substring if such text exists; otherwise use "" or omit the field.
  - `fileFormat` MAY be inferred (e.g., "application/vnd.ms-excel" for .xls) and
    does not need to match any literal span.
- Do NOT introduce any text in `body`, `name`, `description`, `subject`,
  or other free-text fields that is not a substring of the user input text.

- You may assign or infer categorical labels such as:
  - mentions.semantic_type: "Legal Case", "Product Brand", "Business Operation",
    "Financial Document", "Drug Name", "GPE", etc.
  - mentions.role: e.g., "Geographic Destination".
  - attachments.semantic_type: "Spreadsheet Document", "Presentation Document",
    "Policy Document", etc.
  - attachments.fileFormat: e.g., "application/vnd.ms-excel", "application/pdf", etc.
  - complianceContext.semantic_type: e.g., "Regulatory and Legal Framework".
  - complianceContext.name: e.g., "Export Control and Restricted Destination Compliance".

- Do NOT summarize the thread; your job is to structure it, not rewrite it.

- Output strictly must be valid JSON only (no markdown or commentary).
- Follow this JSON structure exactly: """ + schema_json + "\n"

In [58]:
def chunk(lst, size):
    """
    Yield consecutive slices of `lst`, each holding at most `size` items.

    The final slice is shorter when len(lst) is not a multiple of `size`;
    an empty `lst` yields nothing.
    """
    yield from (lst[start:start + size] for start in range(0, len(lst), size))

In [None]:
def build_batch_jsonl(
    emails: List[Dict[str, str]],
    system_instructions: str,
    jsonl_path: str,
    model: str = OPENAI_MODEL,  # or any chat model with JSON mode
) -> None:
    """
    Write one Batch API request per email to `jsonl_path`, one JSON object per line.

    Each request targets /v1/chat/completions, sends `system_instructions` as the
    system message and the email's text as the user message, and carries the
    string form of the email's ID as `custom_id`.

    Args:
        emails: records with "email_id" and "email_text" keys.
        system_instructions: system prompt shared by every request.
        jsonl_path: file to create; an existing file is overwritten.
        model: model name placed in each request body.
    """
    def _request_line(rec: Dict[str, str]) -> str:
        # One JSONL line: the serialized request for a single email.
        custom_id = str(rec["email_id"])
        conversation = [
            {"role": "system", "content": system_instructions},
            {"role": "user", "content": rec["email_text"]},
        ]
        request = {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": model,
                "response_format": {"type": "json_object"},
                "temperature": 0.0,
                "max_tokens": 6000,
                "messages": conversation,
            },
        }
        return json.dumps(request, ensure_ascii=False) + "\n"

    with open(jsonl_path, "w", encoding="utf-8") as f:
        f.writelines(_request_line(rec) for rec in emails)

In [60]:
def submit_batch(jsonl_path: str) -> str:
    """
    Upload the request file at `jsonl_path` and start a Batch job for it.

    The job is tagged with metadata job_type="ocr_email_structuring" and its ID
    is printed.

    Returns:
        The ID of the created batch.
    """
    # Step 1: upload the JSONL request file.
    with open(jsonl_path, "rb") as request_file:
        uploaded = client.files.create(file=request_file, purpose="batch")

    # Step 2: create the batch job that consumes the uploaded file.
    job = client.batches.create(
        input_file_id=uploaded.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",  # currently the only allowed window
        metadata={"job_type": "ocr_email_structuring"},
    )

    print("Created batch:", job.id)
    return job.id


In [61]:
def wait_for_batch(batch_id: str, poll_interval: int = 60, max_polls: int = 120):
    """
    Poll a batch until it reaches a terminal state or `max_polls` polls were made.

    Args:
        batch_id: ID of the batch to watch.
        poll_interval: seconds to sleep between polls.
        max_polls: maximum number of retrievals before giving up waiting.

    Returns:
        The last retrieved batch object. Its status is not necessarily
        "completed": callers must check `status` themselves, because the
        function also returns on failed/expired/cancelled and on timeout.
    """
    terminal_statuses = ("completed", "failed", "expired", "cancelled")
    polls = 0

    while True:
        batch = client.batches.retrieve(batch_id)
        polls += 1

        # Guard against a batch object whose request counts are not populated,
        # so progress logging can never abort the wait loop.
        counts = batch.request_counts
        completed = counts.completed if counts is not None else "?"
        total = counts.total if counts is not None else "?"
        print(
            f"Batch {batch_id} status: {batch.status} | "
            f"completed={completed} / total={total}"
        )

        if batch.status in terminal_statuses:
            return batch

        if polls >= max_polls:
            print(f"Reached max_polls={max_polls}, stopping wait loop.")
            return batch

        time.sleep(poll_interval)


In [62]:
def _parse_batch_record(record: Dict[str, Any]) -> Any:
    """
    Reduce one line of a Batch output file to the value stored for its email.

    Returns the parsed model JSON on success. Otherwise returns a diagnostic dict:
      {"_error": ...}                      per-request error reported by the batch
      {"_raw_record": record}              record has no "response"
      {"_raw_body": body}                  response has no choices
      {"_error": "truncated_output_max_tokens", "_raw_content": ...}
                                           generation stopped at the token limit
      {"_raw_content": ...}                content is neither str nor list, or is
                                           not valid JSON
    """
    # 1) Batch-level error for this task
    if record.get("error") is not None:
        return {"_error": record["error"]}

    response = record.get("response")
    if response is None:
        # Unexpected shape
        return {"_raw_record": record}

    body = response.get("body", {})
    choices = body.get("choices") or []
    if not choices:
        return {"_raw_body": body}

    choice0 = choices[0]
    content = choice0.get("message", {}).get("content")

    # If we hit the max_tokens limit, flag it explicitly
    if choice0.get("finish_reason") == "length":
        return {
            "_error": "truncated_output_max_tokens",
            "_raw_content": content,
        }

    # 2) Normalize content into a string
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, dict):
                if "text" in block and isinstance(block["text"], str):
                    parts.append(block["text"])
                elif "output_text" in block and isinstance(block["output_text"], dict):
                    t = block["output_text"].get("text")
                    if isinstance(t, str):
                        parts.append(t)
        content_str = "".join(parts).strip()
    elif isinstance(content, str):
        content_str = content.strip()
    else:
        return {"_raw_content": content}

    # 3) Parse JSON content
    try:
        return json.loads(content_str)
    except json.JSONDecodeError:
        print("FAILED TO PARSE JSON FOR", record.get("custom_id"))
        print("Content snippet:", repr(content_str[:300]))
        return {"_raw_content": content_str}


def download_and_parse_results(
    batch,
    output_jsonl_path: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Given a completed batch object, download its output JSONL and return:
        { custom_id (email_id): parsed_json }

    If output_jsonl_path is given, append each as:
        {"email_id": "<id>", "output": <parsed_json>}
    to that JSONL file. The file is opened in append mode, so calling this twice
    for the same batch writes its records twice.

    Records that could not be turned into model JSON are kept, with a
    diagnostic dict as their value (see `_parse_batch_record`).

    Raises:
        RuntimeError: if the batch has no output_file_id.
    """
    if not batch.output_file_id:
        raise RuntimeError(f"Batch {batch.id} has no output_file_id; status={batch.status}")

    file_content = client.files.content(batch.output_file_id).text

    results_by_email_id: Dict[str, Any] = {}

    out_f = open(output_jsonl_path, "a", encoding="utf-8") if output_jsonl_path else None

    try:
        for line in file_content.splitlines():
            if not line.strip():
                continue

            record = json.loads(line)
            custom_id = record.get("custom_id")
            parsed_json = _parse_batch_record(record)

            # Single recording path for every outcome, success or diagnostic.
            results_by_email_id[custom_id] = parsed_json
            if out_f is not None:
                out_f.write(json.dumps({"email_id": custom_id, "output": parsed_json},
                                       ensure_ascii=False) + "\n")
    finally:
        if out_f is not None:
            out_f.close()

    return results_by_email_id

In [None]:
# jsonl_path = "OpenAI_API_Output.jsonl"

In [None]:
# Split the emails into chunks of 300 and, for each chunk: write its request
# file, submit it as a batch, remember the batch ID, and block until that batch
# stops running (or the poll limit is hit) before starting the next chunk.
# NOTE(review): the status of `final_batch` is not checked here; the collection
# cell below skips batches that did not complete.
batch_ids = []
for i, email_chunk in enumerate(chunk(emails_list, 300)):
    jsonl_path = f"batch_input_{i:04d}.jsonl"
    build_batch_jsonl(email_chunk, system_instructions, jsonl_path)
    batch_id = submit_batch(jsonl_path)
    batch_ids.append(batch_id)
    final_batch = wait_for_batch(batch_id, poll_interval=60)
    

In [None]:
# Append the parsed output of every completed batch submitted above to a single
# JSONL file; batches in any other state are reported and skipped.
combined_jsonl = "OpenAI_API_Output.jsonl"

for batch_id in batch_ids:
    batch = client.batches.retrieve(batch_id)
    if batch.status == "completed":
        download_and_parse_results(batch, output_jsonl_path=combined_jsonl)
    else:
        print(f"Skipping batch {batch_id} with status={batch.status}")

print(f"Combined output written to {combined_jsonl}")


In [None]:
# Alternative collection path: gather results from the batches the API lists,
# rather than from the `batch_ids` recorded in this session.
combined_jsonl = "OpenAI_API_Output.jsonl"

# Start with a fresh/truncated file (optional)
# NOTE(review): results are appended to this file, so leaving the truncation
# disabled after the previous collection cell has run duplicates records.
# open(combined_jsonl, "w", encoding="utf-8").close()

# List batches from the API (tweak limit as needed)
batch_list = client.batches.list(limit=100)

In [None]:
# Inspect the batch objects returned by the listing call.
batch_list.data

In [None]:
# Append the parsed output of every listed batch that carries this notebook's
# job_type tag and has completed; tagged batches in other states are reported.
for batch in batch_list.data:
    # Batches without metadata, or tagged for a different job, are ignored silently.
    if not getattr(batch, "metadata", None):
        continue
    if batch.metadata.get("job_type") != "ocr_email_structuring":
        continue

    if batch.status != "completed":
        print(f"Skipping batch {batch.id} with status={batch.status}")
        continue

    print(f"Collecting from batch {batch.id}")
    download_and_parse_results(batch, output_jsonl_path=combined_jsonl)

print(f"Combined output written to {combined_jsonl}")

In [None]:
# Build JSONL batch file containing every email in one request file.
# NOTE(review): `jsonl_path` is not assigned in this cell; it holds whatever an
# earlier cell left in it (the chunk loop's last path), and that file is
# overwritten here. Set it explicitly before running this cell.
build_batch_jsonl(
    emails=emails_list,
    system_instructions=system_instructions,
    jsonl_path=jsonl_path,
    model=OPENAI_MODEL,   # or your preferred model
)

In [None]:
# Submit batch
batch_id = submit_batch(jsonl_path)

In [None]:
# Wait for completion
final_batch = wait_for_batch(batch_id, poll_interval=30)

if final_batch.status != "completed":
    print("Batch did not complete successfully.")
    print("Status:", final_batch.status)
    print("Errors:", getattr(final_batch, "errors", None))
    print("Request counts:", final_batch.request_counts)

    if final_batch.output_file_id:
        print("Attempting to inspect output file for per-request errors...")
        out_content = client.files.content(final_batch.output_file_id).text
        for i, line in enumerate(out_content.splitlines()):
            print(line)
            if i >= 20:  # show first 20 lines max
                break
    else:
        print("No output_file_id present.")

    # Only raise after logging everything:
    raise RuntimeError(f"Batch failed/ended with status={final_batch.status}")


In [None]:
# Download and parse structured results
structured_by_email_id = download_and_parse_results(final_batch)

In [None]:
# Example: pretty-print the structured output of the first three emails.
for email_id, structured in list(structured_by_email_id.items())[:3]:
    pretty = json.dumps(structured, indent=2, ensure_ascii=False)
    print(f"\n=== {email_id} ===", pretty, sep="\n")


# The code below this cell is deprecated.
Please refer to the `Part 3` notebook for the code that connects to Neo4j, and to `graph_queries.py` for the Cypher queries that build the knowledge graph.

# Connect to Neo4j
https://browser.neo4j.io/



In [None]:
# REMEMBER TO CLOSE CONNECTION
uri = NEO4J_URI
username = NEO4J_USER
password = NEO4J_PASS
driver = GraphDatabase.driver(uri, auth=(username, password))

In [None]:
# test query
with driver.session() as session:
    result = session.run("RETURN 1 AS n")
    print(result.single())

<Record n=1>


In [None]:
# email subject lines and drug name of any email that mentions drugs
with driver.session() as session:
    cypher_query = session.run('MATCH (e:Email)-[:MENTIONS]->(t:TopicEntity {semantic_type: "Drug Name"}) RETURN e.subject, t.name;')
    print(type(cypher_query))
    cypher_query = cypher_query.data()   # converts from Result to list
for i in cypher_query:
    print(i)

## add constraints
`setup_constraints` with parameters (uri, user, password)

In [None]:
def setup_constraints(uri, user, password):
    """
    Create the uniqueness constraints the graph import relies on.

    Every statement uses IF NOT EXISTS, so the function can be run repeatedly.

    Args:
        uri: Neo4j connection URI.
        user: Neo4j user name.
        password: Neo4j password.
    """
    constraint_statements = [
        # Core entities
        """
        CREATE CONSTRAINT case_identifier IF NOT EXISTS
        FOR (c:Case)
        REQUIRE c.identifier IS UNIQUE
        """,
        """
        CREATE CONSTRAINT email_identifier IF NOT EXISTS
        FOR (e:Email)
        REQUIRE e.identifier IS UNIQUE
        """,
        """
        CREATE CONSTRAINT person_key IF NOT EXISTS
        FOR (p:Person)
        REQUIRE p.key IS UNIQUE
        """,
        """
        CREATE CONSTRAINT org_name IF NOT EXISTS
        FOR (o:Organization)
        REQUIRE o.name IS UNIQUE
        """,
        """
        CREATE CONSTRAINT document_name IF NOT EXISTS
        FOR (d:Document)
        REQUIRE d.name IS UNIQUE
        """,
        """
        CREATE CONSTRAINT place_name IF NOT EXISTS
        FOR (pl:Place)
        REQUIRE pl.name IS UNIQUE
        """,
        """
        CREATE CONSTRAINT topicentity_name IF NOT EXISTS
        FOR (t:TopicEntity)
        REQUIRE t.name IS UNIQUE
        """,
        """
        CREATE CONSTRAINT crossrefemail_cid IF NOT EXISTS
        FOR (x:CrossRefEmail)
        REQUIRE x.cid IS UNIQUE
        """,

        # Enriched-content entities
        """
        CREATE CONSTRAINT decision_text IF NOT EXISTS
        FOR (d:Decision)
        REQUIRE d.text IS UNIQUE
        """,
        """
        CREATE CONSTRAINT concern_text IF NOT EXISTS
        FOR (c:Concern)
        REQUIRE c.text IS UNIQUE
        """,

        # FinancialMention – there are two flavors, so we use two constraints:
        # one for simple text mentions, one for (description, figure, currency)
        """
        CREATE CONSTRAINT financialmention_text IF NOT EXISTS
        FOR (f:FinancialMention)
        REQUIRE f.text IS UNIQUE
        """,
        """
        CREATE CONSTRAINT financialmention_desc_fig_cur IF NOT EXISTS
        FOR (f:FinancialMention)
        REQUIRE (f.description, f.figure, f.currency) IS UNIQUE
        """
    ]

    driver = GraphDatabase.driver(uri, auth=(user, password))
    try:
        with driver.session() as session:
            for stmt in constraint_statements:
                session.run(stmt)
    finally:
        # Close the driver even when opening the session or a statement fails;
        # previously an exception left it open.
        driver.close()


## cypher queries
Creates the function `import_jsonl_to_neo4j`
with parameters (jsonl_path, uri, user, password, log_every).

In [None]:
from neo4j import GraphDatabase
import json
from typing import Any, Dict, List, Union


def ensure_list(x: Union[None, Dict[str, Any], List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """
    Normalize a value that may be missing, a single object, or a list of objects.

    Returns the same list object when `x` is already a list, an empty list for
    None, and a one-element list wrapping `x` otherwise.
    """
    if isinstance(x, list):
        return x
    return [] if x is None else [x]


# ----------------- Upsert helpers ----------------- #

def upsert_case(tx, case_obj: Dict[str, Any]):
    """
    Merge the Case node for `case_obj`, then its case-level mentions and emails.

    Does nothing when the case has no "identifier". Entries of "mentions" and
    "hasPart" that are not dicts are ignored; "hasPart" may be a single object
    or a list.
    """
    case_id = case_obj.get("identifier")
    if not case_id:
        return

    case_properties = {
        field: case_obj.get(field)
        for field in (
            "semantic_type",
            "legalStatus",
            "dateFiled",
            "confidentialityNotice",
            "language",
        )
    }
    tx.run(
        """
        MERGE (c:Case {identifier: $identifier})
        SET
          c.semantic_type = $semantic_type,
          c.legalStatus = $legalStatus,
          c.dateFiled = $dateFiled,
          c.confidentialityNotice = $confidentialityNotice,
          c.language = $language
        """,
        identifier=case_id,
        **case_properties,
    )

    # Case-level mentions
    for mention in case_obj.get("mentions") or []:
        if not isinstance(mention, dict):
            continue
        upsert_case_mention(tx, case_id, mention)

    # Emails that make up the thread (top level, so no parent email)
    for email_obj in ensure_list(case_obj.get("hasPart")):
        if not isinstance(email_obj, dict):
            continue
        upsert_email_recursive(tx, case_id, email_obj, parent_email_id=None)


def upsert_case_mention(tx, case_id: str, mention: Dict[str, Any]):
    """
    Merge the node for a case-level mention and link the case to it.

    A mention whose "@type" is "gpe" becomes a Place node; any other type
    becomes a TopicEntity node. Mentions without a "name" are skipped.
    """
    name = mention.get("name")
    if not name:
        return

    # The label comes from this fixed pair, never from the input text.
    label = "Place" if mention.get("@type") == "gpe" else "TopicEntity"

    # Node
    tx.run(
        f"""
        MERGE (m:{label} {{name: $name}})
        SET
          m.semantic_type = $semantic_type,
          m.identifier = $identifier
        """,
        name=name,
        semantic_type=mention.get("semantic_type"),
        identifier=mention.get("identifier"),
    )

    # Relationship
    tx.run(
        f"""
        MATCH (c:Case {{identifier: $case_id}})
        MATCH (m:{label} {{name: $name}})
        MERGE (c)-[:CASE_MENTIONS]->(m)
        """,
        case_id=case_id,
        name=name,
    )


def upsert_person(tx, person: Dict[str, Any]) -> str:
    """
    Merge a Person node and, when present, the person's affiliation.

    The node key is the email address when there is one, otherwise the name;
    a missing or empty name is stored as "Unknown".

    Returns:
        The key of the merged Person node, or None when `person` is empty.
    """
    if not person:
        return None

    display_name = person.get("name") or "Unknown"
    address = person.get("email")
    node_key = address or display_name

    tx.run(
        """
        MERGE (p:Person {key: $key})
        SET
          p.name = $name,
          p.email = $email,
          p.semantic_type = $semantic_type
        """,
        key=node_key,
        name=display_name,
        email=address,
        semantic_type=person.get("semantic_type"),
    )

    affiliation = person.get("affiliation")
    if isinstance(affiliation, dict):
        upsert_org_for_person(tx, node_key, affiliation)

    return node_key


def upsert_org_for_person(tx, person_key: str, org: Dict[str, Any]):
    """
    Merge an Organization node, link the person to it, and link it to its parent.

    Does nothing when the organization has no "name". The parent organization is
    only processed when "parentOrganization" is a dict with a non-empty "name".
    """
    org_name = org.get("name")
    if not org_name:
        return

    # Org node
    tx.run(
        """
        MERGE (o:Organization {name: $name})
        SET
          o.semantic_type = $semantic_type,
          o.role = $role
        """,
        name=org_name,
        semantic_type=org.get("semantic_type"),
        role=org.get("role"),
    )

    # Person -> Org
    tx.run(
        """
        MATCH (p:Person {key: $person_key})
        MATCH (o:Organization {name: $name})
        MERGE (p)-[:AFFILIATED_WITH]->(o)
        """,
        person_key=person_key,
        name=org_name,
    )

    parent = org.get("parentOrganization")
    if not isinstance(parent, dict) or not parent.get("name"):
        return

    parent_name = parent.get("name")

    # Parent org
    tx.run(
        """
            MERGE (po:Organization {name: $pname})
            SET
              po.semantic_type = $p_sem,
              po.role = $p_role
            """,
        pname=parent_name,
        p_sem=parent.get("semantic_type"),
        p_role=parent.get("role"),
    )

    # Org -> Parent
    tx.run(
        """
            MATCH (o:Organization {name: $name})
            MATCH (po:Organization {name: $pname})
            MERGE (o)-[:SUBSIDIARY_OF]->(po)
            """,
        name=org_name,
        pname=parent_name,
    )


def upsert_mention_for_email(tx, email_id: str, mention: Dict[str, Any]):
    """
    Merge the node for something an email mentions and link the email to it.

    A mention whose "@type" is "gpe" becomes a Place linked with
    EMAIL_MENTIONS_PLACE; any other type becomes a TopicEntity linked with
    EMAIL_MENTIONS_TOPIC. Mentions without a "name" are skipped.
    """
    name = mention.get("name")
    if not name:
        return

    # Label and relationship type come from this fixed pair, never from input text.
    if mention.get("@type") == "gpe":
        label, rel_type = "Place", "EMAIL_MENTIONS_PLACE"
    else:
        label, rel_type = "TopicEntity", "EMAIL_MENTIONS_TOPIC"

    # Node
    tx.run(
        f"""
        MERGE (m:{label} {{name: $name}})
        SET
          m.semantic_type = $semantic_type,
          m.identifier = $identifier,
          m.role = $role
        """,
        name=name,
        semantic_type=mention.get("semantic_type"),
        identifier=mention.get("identifier"),
        role=mention.get("role"),
    )

    # Relationship
    tx.run(
        f"""
        MATCH (e:Email {{identifier: $email_id}})
        MATCH (m:{label} {{name: $name}})
        MERGE (e)-[:{rel_type}]->(m)
        """,
        email_id=email_id,
        name=name,
    )


def upsert_attachment(tx, email_id: str, case_id: str, attachment: Dict[str, Any]):
    """
    Merge a Document node for an attachment and link it to its email and case.

    Attachments without a "name" are skipped. The case link is only created
    when `case_id` is truthy.
    """
    doc_name = attachment.get("name")
    if not doc_name:
        return

    # Document node
    tx.run(
        """
        MERGE (d:Document {name: $name})
        SET
          d.semantic_type = $semantic_type,
          d.fileFormat = $fileFormat,
          d.description = $description
        """,
        name=doc_name,
        semantic_type=attachment.get("semantic_type"),
        fileFormat=attachment.get("fileFormat"),
        description=attachment.get("description"),
    )

    # Email–Document
    tx.run(
        """
        MATCH (e:Email {identifier: $email_id})
        MATCH (d:Document {name: $name})
        MERGE (e)-[:HAS_ATTACHMENT]->(d)
        """,
        email_id=email_id,
        name=doc_name,
    )

    if not case_id:
        return

    # Case–Document
    tx.run(
        """
            MATCH (c:Case {identifier: $case_id})
            MATCH (d:Document {name: $name})
            MERGE (c)-[:CASE_HAS_DOCUMENT]->(d)
            """,
        case_id=case_id,
        name=doc_name,
    )


def upsert_email_recursive(tx, case_id: str, email_obj: Dict[str, Any], parent_email_id: str = None):
    """
    Merge an Email node together with everything attached to it.

    Creates or updates the email, links it to its case and (for a forwarded
    message) to its parent email, then processes sender, recipients, mentions,
    attachments, the nested forwarded message, and referenced emails.

    Args:
        tx: transaction-like object exposing run(query, **parameters).
        case_id: identifier of the owning case; falsy to skip the case link.
        email_obj: one email object; an empty value is ignored.
        parent_email_id: identifier of the email that forwarded this one, if any.
    """
    if not email_obj:
        return

    email_id = email_obj.get("identifier")
    if not email_id:
        # Fallback identifier built from subject and send date. `or` is used
        # instead of dict.get defaults because the fields are usually present
        # but empty (or None); get() defaults would then never apply and the
        # identifier would degrade to "|" or contain the text "None".
        fallback_subject = email_obj.get("subject") or "Unknown"
        fallback_date = email_obj.get("dateSent") or ""
        email_id = f"{fallback_subject}|{fallback_date}"

    # Email node
    tx.run(
        """
        MERGE (e:Email {identifier: $identifier})
        SET
          e.semantic_type = $semantic_type,
          e.subject = $subject,
          e.dateSent = $dateSent,
          e.importance = $importance,
          e.body = $body
        """,
        identifier=email_id,
        semantic_type=email_obj.get("semantic_type"),
        subject=email_obj.get("subject"),
        dateSent=email_obj.get("dateSent"),
        importance=email_obj.get("importance"),
        body=email_obj.get("body"),
    )

    # Case–Email
    if case_id:
        tx.run(
            """
            MATCH (c:Case {identifier: $case_id})
            MATCH (e:Email {identifier: $email_id})
            MERGE (c)-[:HAS_EMAIL]->(e)
            """,
            case_id=case_id,
            email_id=email_id,
        )

    # Parent email relationship (for forwarded messages)
    if parent_email_id:
        tx.run(
            """
            MATCH (parent:Email {identifier: $parent_id})
            MATCH (child:Email {identifier: $email_id})
            MERGE (parent)-[:FORWARDED_MESSAGE]->(child)
            """,
            parent_id=parent_email_id,
            email_id=email_id,
        )

    # Sender
    sender = email_obj.get("sender")
    if isinstance(sender, dict):
        sender_key = upsert_person(tx, sender)
        if sender_key:
            tx.run(
                """
                MATCH (e:Email {identifier: $email_id})
                MATCH (p:Person {key: $sender_key})
                MERGE (p)-[:SENT]->(e)
                """,
                email_id=email_id,
                sender_key=sender_key,
            )

    # Recipients
    for rcpt in email_obj.get("recipient") or []:
        if isinstance(rcpt, dict):
            rcpt_key = upsert_person(tx, rcpt)
            if rcpt_key:
                tx.run(
                    """
                    MATCH (e:Email {identifier: $email_id})
                    MATCH (p:Person {key: $rcpt_key})
                    MERGE (e)-[:SENT_TO]->(p)
                    """,
                    email_id=email_id,
                    rcpt_key=rcpt_key,
                )

    # Mentions
    for mention in email_obj.get("mentions") or []:
        if isinstance(mention, dict):
            upsert_mention_for_email(tx, email_id, mention)

    # Attachments
    for att in email_obj.get("attachments") or []:
        if isinstance(att, dict):
            upsert_attachment(tx, email_id, case_id, att)

    # Forwarded / nested: the nested message becomes a child of this email
    fwd = email_obj.get("forwardedMessage")
    if isinstance(fwd, dict):
        upsert_email_recursive(tx, case_id, fwd, parent_email_id=email_id)

    # mentionsEmail: make sure the referenced email exists, then link to it
    for me in email_obj.get("mentionsEmail") or []:
        if isinstance(me, dict):
            ref_id = me.get("identifier")
            if ref_id:
                tx.run(
                    """
                    MERGE (ref:Email {identifier: $ref_id})
                    """,
                    ref_id=ref_id,
                )
                tx.run(
                    """
                    MATCH (e:Email {identifier: $email_id})
                    MATCH (ref:Email {identifier: $ref_id})
                    MERGE (e)-[:MENTIONS_EMAIL]->(ref)
                    """,
                    email_id=email_id,
                    ref_id=ref_id,
                )


# ----------------- Main import with logging & error handling ----------------- #

def _parse_case_line(line: str):
    """
    Decode one non-empty JSONL line into a case dict.

    The line is expected to be a JSON object ("wrapper") whose "output" field
    holds the case, either as a JSON-encoded string or as an already decoded
    object.

    Returns:
        (case_obj, None) on success, or (None, reason) where `reason` is a
        short explanation of why the line has to be skipped.
    """
    try:
        wrapper = json.loads(line)
    except json.JSONDecodeError as e:
        return None, f"invalid JSON wrapper ({e})"

    if not isinstance(wrapper, dict):
        return None, "JSON wrapper is not an object"

    output_raw = wrapper.get("output")
    if not output_raw:
        return None, "no 'output' field"

    # An already-decoded object must be detected *before* json.loads():
    # json.loads(dict) raises TypeError, not JSONDecodeError.
    if isinstance(output_raw, dict):
        return output_raw, None

    try:
        case_obj = json.loads(output_raw)
    except (TypeError, json.JSONDecodeError):
        return None, "invalid 'output' JSON"

    if not isinstance(case_obj, dict):
        return None, "'output' is not a JSON object"

    return case_obj, None


def import_jsonl_to_neo4j(
    jsonl_path: str,
    uri: str,
    user: str,
    password: str,
    log_every: int = 25,
):
    """
    Import JSONL case/email schemas into Neo4j.

    Each non-empty line is decoded, its "output" payload is extracted and
    written with `upsert_case` inside one write transaction per line.

      - progress is logged every `log_every` lines (a value <= 0 disables it)
      - malformed lines are skipped with a [WARN] message
      - a failing write is reported with an [ERROR] message and counted, so a
        single bad record doesn't kill the whole run
      - the driver is always closed, even if the file cannot be opened

    Args:
        jsonl_path: Path of the UTF-8 encoded JSONL file to read.
        uri: Neo4j connection URI.
        user: Neo4j user name.
        password: Neo4j password.
        log_every: Progress-log interval in lines.

    Raises:
        OSError: If `jsonl_path` cannot be opened.
    """
    driver = GraphDatabase.driver(uri, auth=(user, password))

    total_lines = 0
    success_cases = 0
    skipped_lines = 0
    failed_cases = 0

    start_time = time.time()
    try:
        with driver.session() as session, open(jsonl_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                total_lines += 1
                line = line.strip()
                if not line:
                    skipped_lines += 1
                    continue

                # Progress log (guarded so log_every=0 cannot divide by zero)
                if log_every and log_every > 0 and line_no % log_every == 0:
                    print(f"[INFO] Processing line {line_no}... (success={success_cases}, failed={failed_cases}, skipped={skipped_lines})")
                    print('\t took', time.time() - start_time, 'seconds')

                case_obj, skip_reason = _parse_case_line(line)
                if case_obj is None:
                    print(f"[WARN] Skipping line {line_no}: {skip_reason}")
                    skipped_lines += 1
                    continue

                case_id = case_obj.get("identifier")

                # Best-effort write: a single bad case is counted and reported,
                # then the import moves on to the next line.
                try:
                    # execute_write forwards extra arguments to the transaction function.
                    session.execute_write(upsert_case, case_obj)
                    success_cases += 1
                except Exception as e:
                    failed_cases += 1
                    print(f"[ERROR] Failed to import case on line {line_no} (case_id={case_id!r}): {type(e).__name__}: {e}")
    finally:
        # Release driver resources on every exit path.
        driver.close()

    print("\n=== Import summary ===")
    print(f"Total lines read:     {total_lines}")
    print(f"Successful cases:     {success_cases}")
    print(f"Failed cases:         {failed_cases}")
    print(f"Skipped lines:        {skipped_lines}")
    print('Runtime (s):          ', time.time() - start_time)


## Continue: set up constraints and import the JSONL file into Neo4j

In [None]:
# Run setup_constraints (defined earlier in this notebook) against the
# configured Neo4j instance before any data is imported.
setup_constraints(uri=NEO4J_URI, user=NEO4J_USER, password=NEO4J_PASS)

In [None]:
# JSONL file handed to import_jsonl_to_neo4j in the next cell
# (relative to the notebook's working directory).
jsonl_path = 'enriched_output.jsonl'

In [None]:
# Import every case from the JSONL file into Neo4j. The user comes from the
# NEO4J_USER setting (not a hard-coded "neo4j") so this call authenticates
# with the same credentials as the setup_constraints call.
import_jsonl_to_neo4j(
    jsonl_path=jsonl_path,
    uri=NEO4J_URI,
    user=NEO4J_USER,
    password=NEO4J_PASS,
    log_every=50,
)

[INFO] Processing line 50... (success=49, failed=0, skipped=0)
[INFO] Processing line 100... (success=99, failed=0, skipped=0)
[INFO] Processing line 150... (success=149, failed=0, skipped=0)
[INFO] Processing line 200... (success=199, failed=0, skipped=0)
[INFO] Processing line 250... (success=249, failed=0, skipped=0)
[INFO] Processing line 300... (success=299, failed=0, skipped=0)
[INFO] Processing line 350... (success=349, failed=0, skipped=0)
[INFO] Processing line 400... (success=399, failed=0, skipped=0)
[INFO] Processing line 450... (success=449, failed=0, skipped=0)
[INFO] Processing line 500... (success=499, failed=0, skipped=0)
[INFO] Processing line 550... (success=549, failed=0, skipped=0)
[INFO] Processing line 600... (success=599, failed=0, skipped=0)
[INFO] Processing line 650... (success=649, failed=0, skipped=0)

=== Import summary ===
Total lines read:     652
Successful cases:     652
Failed cases:         0
Skipped lines:        0


In [None]:
# Delete the entire graph: DETACH DELETE removes every node together with
# its relationships. Destructive — run only when a full reset is intended.
# NOTE(review): `driver` is not created in this cell; it must already exist
# (and still be open) from an earlier cell — confirm before running.
with driver.session() as session:
    # Despite its name, `cypher_query` holds the object returned by session.run().
    cypher_query = session.run("MATCH (n) DETACH DELETE n;")

In [47]:
# Close the notebook-level `driver` once all graph work is finished.
driver.close()