In [None]:
import os
import shutil
import textwrap
import warnings

import torch
from dotenv import load_dotenv
from langchain_community.vectorstores import Neo4jVector
from transformers import AutoModel, AutoTokenizer, pipeline

warnings.filterwarnings("ignore")

In [None]:
load_dotenv(".env", override=True)
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE") or "neo4j"

# Fail fast with a readable message instead of an opaque connection/auth error
# when the vector store is created several cells later.
_missing_settings = [
    setting_name
    for setting_name, setting_value in (
        ("NEO4J_URI", NEO4J_URI),
        ("NEO4J_USERNAME", NEO4J_USERNAME),
        ("NEO4J_PASSWORD", NEO4J_PASSWORD),
    )
    if not setting_value
]
if _missing_settings:
    raise RuntimeError(
        "Missing Neo4j settings in .env / environment: " + ", ".join(_missing_settings)
    )

# Vector index configuration; these must match the existing index in the graph.
VECTOR_INDEX_NAME = "texts_from_records"
VECTOR_NODE_LABEL = "recordWithText"
VECTOR_SOURCE_PROPERTY = ["text"]
VECTOR_EMBEDDING_PROPERTY = "textEmbedding"

# Width used to wrap printed results. `shutil` is imported in the first cell, so
# no globals() guard is needed; get_terminal_size() already falls back to 80
# columns when the terminal size cannot be queried.
columns = shutil.get_terminal_size().columns

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face model id of the encoder. It must be the same model that produced
# the stored `textEmbedding` vectors, otherwise query and document vectors are not
# comparable. The query instruction used by MyEmbedding appears to be the BGE
# English one (e.g. "BAAI/bge-base-en-v1.5") — TODO confirm which model built the index.
embedder_model_name = os.getenv("EMBEDDER_MODEL_NAME")
if not embedder_model_name:
    raise ValueError(
        "Set EMBEDDER_MODEL_NAME (in .env or the environment) to the Hugging Face "
        "model id that was used to build the vector index."
    )

embedder_tokenizer = AutoTokenizer.from_pretrained(embedder_model_name)
# eval() disables dropout so embeddings are deterministic.
embedder_model = AutoModel.from_pretrained(embedder_model_name).to(device).eval()

In [None]:
class MyEmbedding:
    """Embedding adapter exposing ``embed_query`` / ``embed_documents`` for Neo4jVector.

    Texts are tokenized, run through the encoder, pooled by taking the hidden state
    of the first token, and L2-normalized. Search queries are prefixed with
    ``query_instruction``; documents are embedded as-is.
    """

    def __init__(self, model, tokenizer, device, max_length=512):
        """
        Args:
            model: Encoder called as ``model(**encoded_input)``; its output must
                expose ``last_hidden_state``.
            tokenizer: Tokenizer paired with ``model``.
            device: Device string (e.g. ``"cpu"`` / ``"cuda"``) the inputs are moved to.
            max_length: Token limit; longer texts are truncated.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length
        # Prepended to search queries only (asymmetric query/passage retrieval).
        self.query_instruction = "Represent this sentence for searching relevant passages: "

    def embed_query(self, text: str) -> list[float]:
        """Return the embedding of one search query (instruction-prefixed)."""
        return self._embed([self.query_instruction + text])[0]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding per document text (no instruction prefix)."""
        return self._embed(texts)

    def _embed(self, texts):
        """Encode ``texts`` and return L2-normalized first-token vectors as Python lists.

        Returns an empty list for empty input without calling the tokenizer.
        """
        if not texts:
            return []

        encoded_input = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            model_output = self.model(**encoded_input)

        # First-token (CLS-style) pooling, then unit-normalize so dot product == cosine.
        embeddings = model_output.last_hidden_state[:, 0]
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        # Bring the tensor back to host memory first: converting a CUDA tensor
        # to numpy raises, and the original also used the global `device`.
        return embeddings.cpu().tolist()


# Single embedder instance shared by index maintenance and query-time search.
my_embedding = MyEmbedding(
    model=embedder_model,
    tokenizer=embedder_tokenizer,
    device=device,
)

In [None]:
# Example questions for exercising the retriever. The first one is used below;
# change the index (or edit the list) to try another.
EXAMPLE_REQUESTS = [
    "What criticism did Rep. John J. 'Jack' Flynt, Jr. have regarding the use of the Labor, Health, Education, and Welfare Subcommittee?",
    "What criticism did Rep. John E. Fogarty (D-RI) have regarding the use of the Labor, Health, Education, and Welfare Subcommittee?",
    "Show me the documents about Labor, Health, Education, and Welfare Subcommittee",
    "What color is pineapple?",
    "Show me the documents about pineapples",
    "Richard F. Fenno Interview Notes",
]
request = EXAMPLE_REQUESTS[0]

In [None]:
# Cypher appended by Neo4jVector after its vector-index lookup: `node` and `score`
# are not bound in this string, they come from that lookup. For each hit it:
#   1. takes the hit chunk plus its immediate :Next neighbours (keeping the longest
#      available window) and joins their `text` values with "\n" into one passage;
#   2. finds the record's first chunk (same naId, chunkSeqId = 0) and walks :Includes
#      edges up to a root, i.e. a node with no incoming :Includes;
#   3. collects nodes pointing at that first chunk via the listed relationship types
#      as `relatedAuthorities` (authorityType / heading / source);
#   4. projects the :Includes path into `pathNodes`: for 'description' records the
#      path is reversed so element 0 is the first chunk itself; for 'authority'
#      records it is not reversed, so element 0 is the root;
#   5. returns the `text`, `score` and `metadata` columns Neo4jVector builds Documents from.
# NOTE(review): the label `recordWithText` is hard-coded here and must stay in sync
# with VECTOR_NODE_LABEL.
# NOTE(review): the CASE has no ELSE, so `pathNodes` is null for any other recordType;
# consumers of metadata["path_nodes"] must tolerate a missing/null value.
# NOTE(review): `MATCH (firstNode)` has no label, so it may scan every node; and since
# `path` is a grouping key, a record reachable from more than one root yields one row
# per root — presumably duplicate Documents. Confirm against the data model.
retrieval_query = """
MATCH (node:recordWithText)
CALL (node) {
    MATCH window = (:recordWithText)-[:Next*0..1]->(node)-[:Next*0..1]->(:recordWithText)
    WITH window
    ORDER BY length(window) DESC
    LIMIT 1
    RETURN window AS longestWindow
}
WITH node, score, longestWindow
WITH nodes(longestWindow) AS chunkList, node, score
UNWIND chunkList AS chunkRows
WITH collect(chunkRows.text) AS textList, node, score
MATCH (firstNode)
WHERE firstNode.naId = node.naId AND firstNode.chunkSeqId = 0
OPTIONAL MATCH path = (root)-[:Includes*0..]->(firstNode)
WHERE NOT (root)<-[:Includes]-()
OPTIONAL MATCH (relatedNode)-[
    :broaderTerm
    |:contributor
    |:creator
    |:subject
    |:donor
    |:narrowerTerm
    |:organizationalReference
    |:relatedTerm
    |:jurisdiction
    |:organizationName
    |:personalReference
]->(firstNode)
WITH
    textList, node, score,
    COLLECT(DISTINCT relatedNode {.authorityType, .heading, .source}) AS relatedAuthorities,
    firstNode,
    path
WITH
    textList, node, score, relatedAuthorities,
    CASE firstNode.recordType
        WHEN 'description' THEN 
            [n IN reverse(nodes(path)) | n {
                .recordType,
                .levelOfDescription,
                .title,
                .logicalDate_coverageStartDate,
                .logicalDate_coverageEndDate,
                .source
            }]
        WHEN 'authority' THEN
            [n IN nodes(path) | n {
                .recordType,
                .authorityType,
                .heading,
                .source
            }]
    END AS pathNodes
RETURN
    apoc.text.join(textList, "\n") AS text,
    score,
    {
        path_nodes: pathNodes, 
        score: score,
        related_authorities: relatedAuthorities
    } AS metadata
"""

TOP_K = 5  # number of vector-index hits to retrieve per request

# Connect to the existing vector index on VECTOR_NODE_LABEL nodes; the custom
# retrieval_query shapes each hit into text + metadata.
neo4j_vector_store = Neo4jVector.from_existing_graph(
    embedding=my_embedding,
    url=NEO4J_URI,
    username=NEO4J_USERNAME,
    password=NEO4J_PASSWORD,
    # Pass the configured database so NEO4J_DATABASE is actually honoured.
    database=NEO4J_DATABASE,
    index_name=VECTOR_INDEX_NAME,
    node_label=VECTOR_NODE_LABEL,
    text_node_properties=VECTOR_SOURCE_PROPERTY,
    embedding_node_property=VECTOR_EMBEDDING_PROPERTY,
    retrieval_query=retrieval_query,
)

retriever = neo4j_vector_store.as_retriever(search_kwargs={"k": TOP_K})

# `invoke` is the Runnable API that replaces the deprecated `get_relevant_documents`.
docs = retriever.invoke(request)

# Authority records are projected with `heading` but no `title` or coverage dates;
# give them the same keys as description records so they can be printed uniformly.
for doc in docs:
    # `path_nodes` may be missing or None (the query's CASE has no ELSE branch),
    # so normalise it to a list before iterating.
    path_nodes = doc.metadata.get("path_nodes") or []
    doc.metadata["path_nodes"] = path_nodes
    for path_node in path_nodes:
        if path_node.get("recordType") == "authority":
            path_node["title"] = path_node.get("heading")
            path_node["logicalDate_coverageStartDate"] = "N/A"
            path_node["logicalDate_coverageEndDate"] = "N/A"

for doc in docs:
    path_nodes = doc.metadata["path_nodes"]
    # Guard against an empty path instead of raising IndexError.
    title = path_nodes[0].get("title") if path_nodes else "N/A"
    # Extracting `title` first avoids reusing the outer quote character inside the
    # f-string, which only parses on Python 3.12+.
    print(textwrap.fill(f"Title: {title} Text: {doc.page_content}", width=columns))
    print("=" * columns, end="\n\n")

In [None]:
# Display the raw retrieved Document objects (page_content plus metadata). The
# post-processing loop in the previous cell has already added `title` and coverage
# date keys to authority path nodes, so this view reflects those edits.
docs