In [1]:
import html
import os
import re
import warnings
from datetime import datetime
from functools import lru_cache

import ipywidgets as widgets
import torch
from dotenv import load_dotenv
from IPython.display import display, HTML
from langchain_community.vectorstores import Neo4jVector
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, BitsAndBytesConfig, pipeline

warnings.filterwarnings("ignore")

In [None]:
# Load settings from the local .env file; override=True lets values from the
# file replace variables that are already set in the process environment.
load_dotenv(dotenv_path=".env", override=True)

# Neo4j connection settings (each is None when the variable is not set).
NEO4J_URI = os.environ.get("NEO4J_URI")
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD")
# An unset or empty NEO4J_DATABASE falls back to "neo4j".
NEO4J_DATABASE = os.environ.get("NEO4J_DATABASE") or "neo4j"

# Vector index settings: index name, label of the indexed nodes, the node
# properties holding the text, and the property holding the embedding.
VECTOR_INDEX_NAME = "texts_from_records"
VECTOR_NODE_LABEL = "recordWithText"
VECTOR_SOURCE_PROPERTY = ["text"]
VECTOR_EMBEDDING_PROPERTY = "textEmbedding"

# Hugging Face access token used when loading the LLM and its tokenizer.
HF_TOKEN = os.environ.get("HF_TOKEN")

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model identifiers come from the environment (.env). The original cell had bare
# placeholders after `=`, which is a syntax error and stops the notebook from
# running at all.
model_name_embedder = os.getenv("EMBEDDER_MODEL_NAME")
model_name_llm = os.getenv("LLM_MODEL_NAME")
if not model_name_embedder or not model_name_llm:
    raise ValueError(
        "Set EMBEDDER_MODEL_NAME and LLM_MODEL_NAME (for example in .env) to the "
        "Hugging Face model identifiers to load."
    )

# Embedding model: moved to the selected device and switched to eval mode.
tokenizer_embedder = AutoTokenizer.from_pretrained(model_name_embedder)
model_embedder = AutoModel.from_pretrained(model_name_embedder).to(device).eval()

# `token=` replaces the deprecated `use_auth_token=` argument.
tokenizer_llm = AutoTokenizer.from_pretrained(model_name_llm, token=HF_TOKEN)
# Reuse the end-of-sequence token for padding.
tokenizer_llm.pad_token = tokenizer_llm.eos_token
tokenizer_llm.pad_token_id = tokenizer_llm.eos_token_id

# 8-bit quantized loading of the LLM.
# NOTE(review): `bnb_8bit_compute_dtype` does not appear among the documented
# BitsAndBytesConfig options (only a 4-bit variant does) - confirm it has any effect.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_compute_dtype=torch.float16,
)

# device_map="auto" leaves device placement to the library, so no .to(device) here.
model_llm = AutoModelForCausalLM.from_pretrained(
    model_name_llm,
    quantization_config=quantization_config,
    token=HF_TOKEN,
    device_map="auto",
).eval()

In [None]:
class Embeddings:
    """Embed queries and documents with a transformer encoder.

    Exposes `embed_query` and `embed_documents`; an instance is passed as the
    `embedding` argument when the Neo4j vector store is created.

    Args:
        model: Encoder whose output has a `last_hidden_state` tensor.
        tokenizer: Tokenizer matching `model`; its output must support `.to(device)`.
        device: Device the tokenized inputs are moved to before the forward pass.
    """

    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # Prefix added to queries only; documents are embedded without it.
        self.query_instruction = "Represent this sentence for searching relevant passages: "

    def embed_query(self, text):
        """Return the embedding (list of floats) of one query string."""
        return self._embed([self.query_instruction + text])[0]

    def embed_documents(self, texts):
        """Return one embedding (list of floats) per text in `texts`."""
        return self._embed(texts)

    def _embed(self, texts):
        """Tokenize `texts`, run the model and return L2-normalized vectors.

        Returns:
            list[list[float]]: one vector per input text; [] for empty input.
        """
        if not texts:
            # Nothing to tokenize; avoid handing an empty batch to the tokenizer.
            return []

        encoded_input = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            model_output = self.model(**encoded_input)

        # Use the hidden state of the first token of every sequence as its embedding.
        embeddings = model_output.last_hidden_state[:, 0]
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        # Move to the CPU before converting: the original sent the tensor to the
        # global `device` and called .numpy(), which fails for CUDA tensors.
        return embeddings.cpu().tolist()


# Shared embedder instance; passed as `embedding=` when the Neo4j vector store is built.
my_embeddings = Embeddings(model_embedder, tokenizer_embedder, device)

In [None]:
# Text-generation pipeline around the loaded LLM, used for both request
# classification and answer generation. Sampling is turned off, and temperature
# is passed as None alongside do_sample=False.
text_generation_pipe = pipeline(
    task="text-generation",
    tokenizer=tokenizer_llm,
    model=model_llm,
    do_sample=False,
    temperature=None,
    repetition_penalty=1.15,
    max_new_tokens=100,
)

In [None]:
# Cypher fragment handed to Neo4jVector as `retrieval_query`. It uses `node` and
# `score` without defining them, so it presumably runs after the vector-index
# lookup that yields both (confirm against the Neo4jVector documentation).
# For every hit it:
#   1. widens the chunk to the longest window of up to one `Next` neighbour on
#      each side and joins the texts of that window;
#   2. finds the node with the same naId and chunkSeqId = 0 (the first chunk);
#   3. follows `Includes` relationships up to a root without incoming `Includes`;
#   4. collects the nodes linked to the first chunk by the listed relationships;
#   5. returns the joined text, the score and a metadata map.
_RETRIEVAL_QUERY = """
    MATCH (node:recordWithText)
    CALL (node) {
        MATCH window = (:recordWithText)-[:Next*0..1]->(node)-[:Next*0..1]->(:recordWithText)
        WITH window
        ORDER BY length(window) DESC
        LIMIT 1
        RETURN window AS longestWindow
    }
    WITH node, score, longestWindow
    WITH nodes(longestWindow) AS chunkList, node, score
    UNWIND chunkList AS chunkRows
    WITH collect(chunkRows.text) AS textList, node, score
    MATCH (firstNode)
    WHERE firstNode.naId = node.naId AND firstNode.chunkSeqId = 0
    OPTIONAL MATCH path = (root)-[:Includes*0..]->(firstNode)
    WHERE NOT (root)<-[:Includes]-()
    OPTIONAL MATCH (relatedNode)-[
        :broaderTerm
        |:contributor
        |:creator
        |:subject
        |:donor
        |:narrowerTerm
        |:organizationalReference
        |:relatedTerm
        |:jurisdiction
        |:organizationName
        |:personalReference
    ]->(firstNode)
    WITH
        textList, node, score,
        COLLECT(DISTINCT relatedNode {.authorityType, .heading, .source}) AS relatedAuthorities,
        firstNode,
        path
    WITH
        textList, node, score, relatedAuthorities,
        CASE firstNode.recordType
            WHEN 'description' THEN 
                [n IN reverse(nodes(path)) | n {
                    .recordType,
                    .levelOfDescription,
                    .title,
                    .logicalDate_coverageStartDate,
                    .logicalDate_coverageEndDate,
                    .source
                }]
            WHEN 'authority' THEN
                [n IN nodes(path) | n {
                    .recordType,
                    .authorityType,
                    .heading,
                    .source
                }]
        END AS pathNodes
    RETURN
        apoc.text.join(textList, "\\n") AS text,
        score,
        {
            path_nodes: pathNodes, 
            score: score,
            related_authorities: relatedAuthorities
        } AS metadata
    """


@lru_cache(maxsize=1)
def _get_vector_store():
    """Create the Neo4j vector store once and reuse it for later requests.

    The original built a new store, and with it a new database connection, on
    every call. NEO4J_DATABASE is now forwarded as well, so a non-default
    database configured in the environment is actually the one queried.
    """
    return Neo4jVector.from_existing_graph(
        embedding=my_embeddings,
        url=NEO4J_URI,
        username=NEO4J_USERNAME,
        password=NEO4J_PASSWORD,
        database=NEO4J_DATABASE,
        index_name=VECTOR_INDEX_NAME,
        node_label=VECTOR_NODE_LABEL,
        text_node_properties=VECTOR_SOURCE_PROPERTY,
        embedding_node_property=VECTOR_EMBEDDING_PROPERTY,
        retrieval_query=_RETRIEVAL_QUERY,
    )


def retrieve_docs(request, k=5):
    """Return the documents most similar to `request` from the Neo4j vector index.

    Args:
        request: The user's request text.
        k: Number of hits asked of the retriever (default 5, as before).

    Returns:
        The retriever's documents. In each document's `metadata["path_nodes"]`,
        every node whose recordType is "authority" is given a `title` copied
        from its `heading` and "N/A" coverage dates, so later formatting can
        read the same keys for every node.
    """
    retriever = _get_vector_store().as_retriever(search_kwargs={"k": k})

    # `invoke` replaces the deprecated `get_relevant_documents`.
    docs = retriever.invoke(request)

    for doc in docs:
        # path_nodes is null when recordType matches neither CASE branch of the query.
        for node in doc.metadata.get("path_nodes") or []:
            if node.get("recordType") == "authority":
                node["title"] = node["heading"]
                node["logicalDate_coverageStartDate"] = "N/A"
                node["logicalDate_coverageEndDate"] = "N/A"

    return docs

In [None]:
def classify_request_type(request, tokenizer_llm, text_generation_pipe):
    """Classify a user request as "question" or "show_records" using the LLM.

    Args:
        request: The user's request text.
        tokenizer_llm: Tokenizer providing `apply_chat_template`.
        text_generation_pipe: Text-generation pipeline (callable) returning a
            list whose first item has a "generated_text" entry.

    Returns:
        "show_records" when the model's reply contains that word
        (case-insensitive), otherwise "question".
    """
    classification_prompt = """
    Classify the type of user request. Reply with only one word:
    - "question": if this is a question that needs an answer;
    - "show_records": if the user is directly asking to display, list, show, etc.
    archival materials/documents/full texts/records/sources.

    If the request does NOT explicitly ask to display, list, show, etc., classify as
    "question".

    Examples:

    Request: Why is the sky blue?
    Output: question

    Request: Show me documents about politics.
    Output: show_records

    Request: List the sources.
    Output: show_records

    Request: Give me all related documents.
    Output: show_records

    Request: Who was president in 1952?
    Output: question
    """

    chat = [
        {"role": "system", "content": classification_prompt},
        {"role": "user", "content": f"Request: {request}"},
    ]

    input_text = tokenizer_llm.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )

    # Ask for the completion only. By default the pipeline returns the prompt
    # followed by the completion; the prompt itself contains "show_records", so
    # with a chat template that has no "[/INST]" marker every request would be
    # classified as "show_records".
    generated = text_generation_pipe(
        input_text, max_new_tokens=10, return_full_text=False
    )
    completion = generated[0]["generated_text"]
    # Kept as a fallback in case the prompt is echoed anyway.
    if "[/INST]" in completion:
        completion = completion.split("[/INST]")[-1]
    response = completion.strip()

    return "show_records" if "show_records" in response.lower() else "question"

In [None]:
def process_date(metadata):
    """Format a node's coverage dates as " (dd/mm/YYYY — dd/mm/YYYY)".

    Args:
        metadata: Mapping that may hold "logicalDate_coverageStartDate" and
            "logicalDate_coverageEndDate" as "YYYY-MM-DD" strings, or "N/A".

    Returns:
        The formatted range with a leading space, or "" when either date is
        "N/A", missing, None, not a string, or not in "YYYY-MM-DD" format.
    """
    start_raw = metadata.get("logicalDate_coverageStartDate")
    end_raw = metadata.get("logicalDate_coverageEndDate")
    if start_raw == "N/A" or end_raw == "N/A":
        return ""

    try:
        start_date = datetime.strptime(start_raw, "%Y-%m-%d")
        end_date = datetime.strptime(end_raw, "%Y-%m-%d")
    except (TypeError, ValueError):
        # TypeError covers a key that is present with a None value, which the
        # original did not catch; ValueError covers a malformed date string.
        return ""

    return f" ({start_date.strftime('%d/%m/%Y')} — {end_date.strftime('%d/%m/%Y')})"


def process_authority_type(metadata):
    """Return " (<authority type>)" for authority nodes and "" for anything else.

    Args:
        metadata: Node mapping with "recordType" and, for authorities,
            "authorityType".

    Returns:
        The display name from `pretty_names` in parentheses with a leading
        space. An authority type without an entry in `pretty_names` is shown
        verbatim instead of raising KeyError; a missing type yields "".
    """
    if metadata.get("recordType") != "authority":
        return ""

    authority_type = metadata.get("authorityType")
    if not authority_type:
        return ""
    return f" ({pretty_names.get(authority_type, authority_type)})"


def generate_answer(request, request_type, docs, tokenizer_llm, text_generation_pipe):
    """Build the response payload for a request from the retrieved documents.

    Args:
        request: The user's request text.
        request_type: "question" or "show_records".
        docs: Retrieved documents; each has `page_content` and a `metadata`
            mapping with "path_nodes" (record first, then its ancestors) and
            "related_authorities".
        tokenizer_llm: Tokenizer providing `apply_chat_template`.
        text_generation_pipe: Text-generation pipeline used for "question".

    Returns:
        A dict with "answer" and "nodes_info"; when `docs` is non-empty it also
        has "ancestors_info", "authorities_info" and "request_type". With no
        documents, "nodes_info" is the string "N/A". For any other
        `request_type` nothing is returned (None), as in the original.
    """
    if not docs:
        return {
            "answer": "I cannot respond to your request based on the available archival materials.",
            "nodes_info": "N/A",
        }

    def display_name(key):
        # Unmapped keys are shown verbatim instead of raising KeyError.
        return pretty_names.get(key, str(key))

    nodes_info = []
    ancestors_info = []
    authorities_info = []
    for doc in docs:
        path_nodes = doc.metadata["path_nodes"]
        record = path_nodes[0]

        # One line per record: «title» (dates) (authority type): source
        nodes_info.append(
            f"«{record.get('title')}»"
            f"{process_date(record)}"
            f"{process_authority_type(record)}: "
            f"{record.get('source')}"
        )

        ancestors = [
            (
                f"{display_name(ancestor.get('levelOfDescription'))}\n"
                f"«{ancestor.get('title')}»{process_date(ancestor)}: "
                f"{ancestor.get('source')}"
            )
            for ancestor in path_nodes[1:]
        ]
        ancestors_info.append(ancestors or ["N/A"])

        authorities = []
        for authority in doc.metadata["related_authorities"]:
            # The link is HTML built here, so the attribute is quoted and both
            # values are escaped. The original also nested double quotes inside
            # a double-quoted f-string, a syntax error before Python 3.12.
            link_target = html.escape(str(authority.get("source")), quote=True)
            link_text = html.escape(str(authority.get("heading")))
            authorities.append(
                f"{display_name(authority.get('authorityType'))}: "
                f'<a href="{link_target}">{link_text}</a>'
            )
        authorities_info.append(authorities or ["N/A"])

    if request_type == "show_records":
        return {
            "answer": "Sure, here are the archival materials related to your request:",
            "nodes_info": nodes_info,
            "ancestors_info": ancestors_info,
            "authorities_info": authorities_info,
            "request_type": request_type,
        }
    elif request_type == "question":
        context = "\n\n".join(
            [
                f"Title: {doc.metadata['path_nodes'][0]['title']} Text: {doc.page_content}"
                for doc in docs
            ]
        )

        answer_prompt = """
        Answer the question based ONLY on the provided context.
        Use ONLY information and wording from the context.
        Do NOT add information that is not explicitly stated in the context,
        even if it seems logical or obvious.
        Do NOT include extra information that is present in the context but is not
        directly relevant to the question.
        If the context doesn't contain an answer to the question, say "I cannot
        respond to your request based on the available archival materials."
        Answer in maximum three sentences.

        Examples:

        Context: Title: Discussion with Congressman Y Text: Congressman Y talked
        about his role in a particular panel, highlighting that the working groups
        operate with significant autonomy. He mentioned that the responsibilities are
        extremely intensive, hindering participants from adequately participating in
        additional congressional tasks and attending to their constituencies.
        Question: What did Congressman Y say about the panel's workload?
        Answer: Congressman Y said the responsibilities are extremely intensive,
        hindering participants from adequately participating in additional
        congressional tasks and attending to their constituencies.

        Context: Title: Interview with Legislator B Text: The speaker indicated
        a preference for a collaborative dynamic with federal bureaus, involving
        mutual idea-sharing. Conversely, he portrayed a fellow lawmaker, Mr. Q, who
        often employs aggressive rhetoric and harbors suspicion toward these entities.
        Question: How does the Legislator B's method with bureaus contrast with Mr. Q's?
        Answer: Legislator B favors collaboration and mutual idea-sharing, whereas Mr.
        Q employs aggressive rhetoric and shows suspicion.

        Context: Title: Dialogue with Lawmaker C Text: Lawmaker C pointed out that
        an individual with intense personal stakes in a specific issue domain ought not
        to be placed on the task force overseeing it, since they might lack
        impartiality.
        Question: According to Lawmaker C, what kind of individual should avoid
        placement on a task force?
        Answer: An individual with intense personal stakes in a specific issue domain
        should not be placed on the task force overseeing it, as they might lack
        impartiality.

        Context: Title: Study Notes on a Task Force Text: The document describes
        assignment to a particular task force as a secondary role due to its divisive
        nature. It emphasizes that participants should secure an additional, more
        favorable position as well.
        Question: Why is assignment to this task force viewed as a secondary role?
        Answer: It is viewed as a secondary role due to its divisive nature, and the
        document emphasizes that participants should secure an additional, more
        favorable position.

        Context: Title: Meeting with Senator Z Text: Senator Z discussed the
        procedural hurdles in forming a bipartisan committee, noting that scheduling
        conflicts among senior members have caused significant delays. He expressed
        hope that the committee would be operational by the next fiscal quarter.
        Question: What views did Senator Z express about tax reforms?
        Answer: I cannot respond to your request based on the available archival
        materials.
        """

        chat = [
            {"role": "system", "content": answer_prompt},
            {"role": "user", "content": f"Context: {context}\n\nQuestion: {request}"},
        ]

        input_text = tokenizer_llm.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        # Ask for the completion only; by default the pipeline returns the whole
        # prompt followed by the completion, which would end up in the answer for
        # any chat template without an "[/INST]" marker.
        generated = text_generation_pipe(input_text, return_full_text=False)
        completion = generated[0]["generated_text"]
        # Kept as a fallback in case the prompt is echoed anyway.
        if "[/INST]" in completion:
            completion = completion.split("[/INST]")[-1]
        answer_text = completion.strip()

        return {
            "answer": answer_text,
            "nodes_info": nodes_info,
            "ancestors_info": ancestors_info,
            "authorities_info": authorities_info,
            "request_type": request_type,
        }


# Display labels for the raw `levelOfDescription` and `authorityType` values
# found on the nodes.
pretty_names = dict(
    item="Item",
    fileUnit="File Unit",
    series="Series",
    recordGroup="Record Group",
    collection="Collection",
    geographicPlaceName="Geographic Place",
    organization="Organization",
    person="Person",
    specificRecordsType="Specific Records Type",
    topicalSubject="Topical Subject",
)

In [None]:
# Stylesheet for the chat view, emitted once when the chat is set up: a column
# container, right-aligned blue bubbles for the user, left-aligned grey bubbles
# for the bot (each with a small tail drawn by ::before), and link colours
# inside bot bubbles.
bubble_css = """
<style>
.chat-container {
    width: 100%;
    max-width: 800px;
    margin: 0 auto;
    font-family: Arial, sans-serif;
    display: flex;
    flex-direction: column;
}

.user-bubble {
    background-color: #276dc1;
    color: #fff;
    padding: 10px 15px;
    border-radius: 20px;
    margin: 10px 0;
    max-width: 70%;
    align-self: flex-end;
    position: relative;
    box-shadow: 0 1px 2px rgba(0,0,0,0.1);
}

.user-bubble::before {
    content: "";
    position: absolute;
    bottom: -10px;
    right: 20px;
    border-width: 10px 10px 0 10px;
    border-style: solid;
    border-color: #276dc1 transparent transparent transparent;
}

.bot-bubble {
    background-color: #E5E5EA;
    color: #000;
    padding: 10px 15px;
    border-radius: 20px;
    margin: 10px 0;
    max-width: 70%;
    align-self: flex-start;
    position: relative;
    box-shadow: 0 1px 2px rgba(0,0,0,0.1);
}

.bot-bubble::before {
    content: "";
    position: absolute;
    bottom: -10px;
    left: 20px;
    border-width: 10px 10px 0 10px;
    border-style: solid;
    border-color: #E5E5EA transparent transparent transparent;
}

.chat-message {
    display: flex;
    flex-direction: column;
}

.bot-bubble a {
    color: #0066cc;
    text-decoration: underline;
}

.bot-bubble a:hover {
    color: #004499;
}
</style>
"""


def create_bubble(text, is_user=True):
    """Wrap a message in the markup of one chat bubble.

    Args:
        text: Message body; it is inserted as-is apart from newlines, which
            become <br> tags.
        is_user: True for the user's bubble style, False for the bot's.

    Returns:
        The HTML snippet for the bubble.
    """
    css_class = "bot-bubble"
    if is_user:
        css_class = "user-bubble"

    body = "<br>".join(text.split("\n"))
    return f"""
    <div class="chat-message">
        <div class="{css_class}">
            {body}
        </div>
    </div>
    """


def make_urls_clickable(text):
    """Turn bare URLs in `text` into HTML links that open in a new tab.

    Matches "http://", "https://" and "www." URLs up to the next whitespace,
    angle bracket or double quote. "www." URLs get an "https://" prefix in both
    the link target and the link text.

    Sentence punctuation directly after a URL (. , ; : ! ?) is left outside the
    link; the original swallowed it into the address.

    Args:
        text: Text that may contain URLs.

    Returns:
        The text with every matched URL replaced by an <a> element.
    """
    url_pattern = r'(https?://[^\s<>"]+|www\.[^\s<>"]+)'
    trailing_punctuation = ".,;:!?"

    def replace_url(match):
        matched = match.group(0)
        url = matched.rstrip(trailing_punctuation)
        # Whatever was stripped is put back after the closing tag.
        tail = matched[len(url):]
        if not url.startswith("http"):
            url = "https://" + url
        return f'<a href="{url}" target="_blank">{url}</a>{tail}'

    return re.sub(url_pattern, replace_url, text)


def filter_unique_documents(nodes_info, ancestors_info, authorities_info, max_documents=5):
    """Drop repeated records while keeping the three parallel lists aligned.

    A record is a repeat when its `nodes_info` entry equals one already kept;
    the first occurrence wins and the original order is preserved.

    Args:
        nodes_info: One entry per retrieved record.
        ancestors_info: Ancestor entries, parallel to `nodes_info`.
        authorities_info: Authority entries, parallel to `nodes_info`.
        max_documents: Maximum number of records to keep (default 5, the value
            that was hard-coded before); None keeps every unique record.

    Returns:
        A tuple (unique_nodes, unique_ancestors, unique_authorities).
    """
    unique_nodes = []
    unique_ancestors = []
    unique_authorities = []

    for index, node in enumerate(nodes_info):
        if max_documents is not None and len(unique_nodes) >= max_documents:
            break
        if node in unique_nodes:
            continue
        unique_nodes.append(node)
        unique_ancestors.append(ancestors_info[index])
        unique_authorities.append(authorities_info[index])

    return unique_nodes, unique_ancestors, unique_authorities


def format_documents_section(unique_nodes, unique_ancestors, unique_authorities, is_cannot=False):
    """Render the list of records as an HTML fragment.

    Each record gets a centered "Record #N:" heading, its source line with URLs
    turned into links, its authorities (unless ["N/A"]), a numbered list of its
    ancestors (unless ["N/A"]) and a closing horizontal rule.

    Args:
        unique_nodes: Source line of each record.
        unique_ancestors: Ancestor lines per record, or ["N/A"].
        unique_authorities: Authority HTML snippets per record, or ["N/A"].
        is_cannot: Accepted for the callers' sake; the body does not use it.

    Returns:
        The concatenated HTML string.
    """

    def centered(inner):
        return f'<div style="text-align: center;">{inner}</div><br>'

    chunks = []
    for number, record_source in enumerate(unique_nodes, start=1):
        record_ancestors = unique_ancestors[number - 1]
        record_authorities = unique_authorities[number - 1]
        linked_source = make_urls_clickable(record_source)

        chunks.append(centered(f"<strong>Record #{number}:</strong>"))
        chunks.append(centered(f"<cite>{linked_source}</cite>"))

        if record_authorities != ["N/A"]:
            chunks.append("; ".join(record_authorities) + "<br><br>")

        if record_ancestors != ["N/A"]:
            chunks.append(centered("Ancestor(s) of this archival material:"))
            chunks.extend(
                f"{position}. {make_urls_clickable(ancestor_line)}<br><br>"
                for position, ancestor_line in enumerate(record_ancestors, 1)
            )

        chunks.append("<br><hr><br>")

    return "".join(chunks)


def format_question_response_chat(answer):
    """Render the reply to a "question" request as HTML.

    Args:
        answer: Payload with "answer", "nodes_info", "ancestors_info" and
            "authorities_info".

    Returns:
        HTML made of the answer text, a rule, a heading introducing the records
        and the records section. When the model declined to answer, the
        "Answer:" label is dropped and the records are introduced as
        suggestions instead of sources.
    """
    cannot_respond = (
        "I cannot respond to your request based on the available archival materials."
    )
    raw_answer = answer["answer"]
    # Collapse whitespace before comparing: the prompt shows the refusal wrapped
    # over two lines, and the model may quote it or add text around it, so the
    # original exact-equality check could miss a refusal.
    is_cannot = cannot_respond in " ".join(raw_answer.split())

    unique_nodes, unique_ancestors, unique_authorities = filter_unique_documents(
        answer["nodes_info"], answer["ancestors_info"], answer["authorities_info"]
    )

    # The answer is model-generated plain text; escape it so stray "<" or "&"
    # cannot break or inject markup.
    answer_text = html.escape(raw_answer, quote=False).replace("\n", "<br>")

    if is_cannot:
        lead = f"<p>{answer_text}</p>"
        records_header = "You might be interested in the following archival materials:"
    else:
        lead = f"<p><strong>Answer: </strong>{answer_text}</p>"
        records_header = (
            "The answer was generated based on the following archival materials:"
        )

    parts = [
        lead,
        "<br><hr><br>",
        f'<div style="text-align: center;">{records_header}</div><br>',
        format_documents_section(
            unique_nodes, unique_ancestors, unique_authorities, is_cannot
        ),
    ]
    return "".join(parts)


def format_show_records_response_chat(answer):
    """Render the reply to a "show_records" request as HTML.

    Args:
        answer: Payload with "answer", "nodes_info", "ancestors_info" and
            "authorities_info".

    Returns:
        HTML made of the answer paragraph, a rule and the records section. If
        the answer is exactly the refusal sentence, an extra line break and a
        heading offering the records as suggestions are inserted as well.
    """
    refusal = "I cannot respond to your request based on the available archival materials."
    is_cannot = answer["answer"] == refusal

    nodes, ancestors, authorities = filter_unique_documents(
        answer["nodes_info"], answer["ancestors_info"], answer["authorities_info"]
    )

    message = "<br>".join(answer["answer"].split("\n"))
    if is_cannot:
        pieces = [
            f"<p>{message}</p><br>",
            "<br><hr><br>",
            '<div style="text-align: center;">'
            "You might be interested in the following archival materials:"
            "</div><br>",
        ]
    else:
        pieces = [f"<p>{message}</p>", "<br><hr><br>"]

    pieces.append(format_documents_section(nodes, ancestors, authorities, is_cannot))
    return "".join(pieces)


def setup_chat():
    """Build and display the chat interface.

    Shows an output area (initialized with the bubble stylesheet), a text field
    and a Send button. Each click retrieves documents for the typed request,
    classifies the request, generates the answer and appends the user and bot
    bubbles to the output area.
    """
    chat_output = widgets.Output()
    with chat_output:
        display(HTML(bubble_css + '<div class="chat-container"></div>'))

    request_input = widgets.Text(
        value="",
        placeholder="Type your request here...",
        description="Request:",
        disabled=False,
        layout=widgets.Layout(width="80%"),
    )

    send_button = widgets.Button(
        description="Send",
        disabled=False,
        button_style="",
        tooltip="Send request",
        icon="paper-plane",
        style={"button_color": "#1a4480"},
    )

    def show_bubble(markup, is_user):
        """Append one chat bubble to the output area."""
        with chat_output:
            display(HTML(create_bubble(markup, is_user=is_user)))

    def on_send_button_clicked(b):
        """Handle one click of the Send button."""
        request = request_input.value.strip()
        if not request:
            return

        # The request is raw user text: escape it before it is placed in HTML.
        show_bubble(html.escape(request), is_user=True)
        request_input.value = ""

        # Block further clicks while this request is being processed.
        send_button.disabled = True
        try:
            docs = retrieve_docs(request)

            request_type = classify_request_type(
                request, tokenizer_llm, text_generation_pipe
            )

            answer = generate_answer(
                request, request_type, docs, tokenizer_llm, text_generation_pipe
            )

            if answer["nodes_info"] == "N/A":
                bot_text = answer["answer"]
            elif answer["request_type"] == "question":
                bot_text = format_question_response_chat(answer)
            elif answer["request_type"] == "show_records":
                bot_text = format_show_records_response_chat(answer)
            else:
                bot_text = "Unknown request type."
        except Exception as exc:
            # Without this the chat would simply stay silent on failure. The
            # exception is re-raised so the traceback is not lost.
            show_bubble(
                "Sorry, something went wrong while processing your request: "
                + html.escape(str(exc)),
                is_user=False,
            )
            raise
        else:
            show_bubble(bot_text, is_user=False)
        finally:
            send_button.disabled = False

    send_button.on_click(on_send_button_clicked)

    display(chat_output)
    input_box = widgets.HBox([request_input, send_button])
    display(input_box)


# Build the chat widgets and display them as this cell's output.
setup_chat()