In [1]:
"""
This is a RAG that:
 - gets mocked up data from the APSViz DB
 - uses a mini-LLM for the NLM.
 - uses
"""

import os
import psycopg2
import json
import warnings
import faiss
import torch

from dotenv import load_dotenv
from rich.console import Console
from rich import pretty

from sentence_transformers import SentenceTransformer
from transformers import RagSequenceForGeneration, DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer, BartTokenizer

# load the secret DB credentials
load_dotenv()

# init Rich formatting
pretty.install()

# create a console with the dark theme
console = Console() # theme=dark

# suppress warnings
warnings.filterwarnings('ignore')

# set to "cuda" if you have GPU and torch.cuda.is_available()
device = "cpu"


  from .autonotebook import tqdm as notebook_tqdm





In [2]:
def run_query(query):
    """
    runs a query against the APSViz DB.

    Note this notebook expects localhost to be connected to a postgres DB.
    :param query:
    :return:
    """
    # Database connection parameters
    connection = psycopg2.connect(dbname="apsviz", user=os.getenv("PG_USER"), password=os.getenv("PG_USER"), host="localhost", port="5432")

    results = None

    with connection.cursor() as cursor:
        try:
            # Create a cursor object
            cursor = connection.cursor()

            # Execute an SQL query
            cursor.execute(query)

            # Fetch and print results
            results = cursor.fetchall()

        except Exception as e:
            print("An error occurred:", e)

        finally:
            # Close the cursor and connection
            cursor.close()
            connection.close()

    return results[0][0]

In [3]:
# create the SQL and get mocked up data from the DB.
# note we are converting number values to feet.
query = """
            SELECT json_agg(row_to_json(t))
            FROM (
                SELECT name, station_id, abbrev, lon, lat,
                CASE WHEN nos_minor IS NOT NULL THEN (nos_minor * 3.28084) ELSE NULL END AS nos_minor,
                CASE WHEN nos_moderate IS NOT NULL THEN (nos_moderate * 3.28084) ELSE NULL END AS nos_moderate,
                CASE WHEN nos_major IS NOT NULL THEN (nos_major * 3.28084) ELSE NULL END AS nos_major,
                CASE WHEN nws_minor IS NOT NULL THEN (nws_minor * 3.28084) ELSE NULL END AS nws_minor,
                CASE WHEN nws_moderate IS NOT NULL THEN (nws_moderate * 3.28084) ELSE NULL END AS nws_moderate,
                CASE WHEN nws_major IS NOT NULL THEN (nws_major * 3.28084) ELSE NULL END AS nws_major,
                FLOOR(random() * 5 + 1)::INT AS current_level
                FROM noaa_station_levels
                ORDER BY name
            ) t ;
        """
# get the station data
data = run_query(query)

In [4]:
def get_flood_stage(values):
    """
    Gets the flood stage based on the station data

    Note "current_data" is a random number (1 to 5) already generated in the data
    """
    if ((values['nos_major'] and values['nos_major'] - values['current_level'] < 0) or (
            values['nws_major'] and values['nws_major'] - values['current_level'] < 0)):
        return 'major flooding'
    elif ((values['nos_moderate'] and values['nos_moderate'] - values['current_level'] < 0) or (
            values['nws_moderate'] and values['nws_moderate'] - values['current_level'] < 0)):
        return 'moderate flooding'
    elif ((values['nos_minor'] and values['nos_minor'] - values['current_level'] < 0) or (
            values['nws_minor'] and values['nws_minor'] - values['current_level'] < 0)):
        return 'minor flooding'
    else:
        return 'no flooding'

In [5]:
"""
get the data in an acceptable format
"""

docs = []
metadata = []

# for each record returned from the DB
for d in data:
    # get the flood stage label for this station
    flood_stage = get_flood_stage(d)

    # create some data tags
    tags = [str(d['station_id']), d['name'], flood_stage]

    # get the data in a understandable format
    text = f"{d['name']}: {flood_stage} (tags: {','.join(tags)})"

    # save the
    docs.append(text)
    metadata.append(d)

In [6]:
"""
embed that data and create an embedding index
"""

# load the transformer
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# encode the data
doc_embeddings = embedder.encode(docs, convert_to_numpy=True)

# get the extents of the data
dim = doc_embeddings.shape[1]

# create an index
index = faiss.IndexFlatL2(dim)

# add the embeddings to the index
index.add(doc_embeddings)

In [7]:
def retrieve_with_scores(query, top_k=2):
    """
    gets the answer data with the scores

    :param query:
    :param top_k:
    :return:
    """
    # encode the query
    q_emb = embedder.encode([query], convert_to_numpy=True)  # (1, dim)

    # get the scores and answer/data indexes
    scores, idx = index.search(q_emb, top_k)  # returned L2 dists

    # convert L2 -> similarity by negative distance (simple)
    similarity = -scores[0].astype(float)  # length top_k

    similarity = sorted(similarity, reverse=False)

    # init the results
    results = []

    j = 0

    # loop through the answers
    for i in idx[0]:
        # put away the result
        results.append({"text": docs[i], 'score': similarity[j], "metadata": metadata[i]})
        j += 1

    # sort the results by score
    results = sorted(results, key=lambda x: x["score"], reverse=False)

    # get the embeddings of the answers
    retrieved_embeddings = doc_embeddings[idx[0]]

    return results, similarity, retrieved_embeddings

In [8]:
results, similarity, retrieved_embeddings = retrieve_with_scores("Show me places that have minor flooding.", 5)

console.print('results:', json.dumps(results), 'similarity:', similarity)  # , 'embeddings:', retrieved_embeddings

In [9]:
# init facebook RAG encoders
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq").to(device)
q_tok = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
ctx_tok = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
gen_tok = BartTokenizer.from_pretrained("facebook/bart-large")

Some weights of the model checkpoint at facebook/rag-sequence-nq were not used when initializing RagSequenceForGeneration: ['rag.question_encoder.question_encoder.bert_model.pooler.dense.bias', 'rag.question_encoder.question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing RagSequenceForGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RagSequenceForGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class

In [10]:
def fallback_concat_generate(query: str, contexts):
    """
    dumps out the data when there has been an exception

    :param query:
    :param contexts:
    :return:
    """
    # create the model with the same generator weights
    gen_model = model.generator.to(device)

    # Build prompt: query + contexts joined (keeps provenance)
    full_prompt = query + " " + " ".join([f"[DOC{i}] {c['text']}" for i, c in enumerate(contexts, start=1)])

    # get the input tokens
    inputs = gen_tok(full_prompt, return_tensors="pt", truncation=True, padding=True).to(device)

    # generate the output
    out = gen_model.generate(input_ids=inputs["input_ids"],
                             attention_mask=inputs["attention_mask"],
                             max_length=64,
                             num_beams=2)

    # get the answer
    answer = gen_tok.batch_decode(out, skip_special_tokens=True)[0]

    console.print('Query:', query)

    # return the results
    return json.dumps({"answer": answer, "sources": contexts, "note": "fallback_concat_used"}, indent=4)

In [11]:
def rag_answer(query, top_k=2):
    """
    gets the answer

    :param query:
    :param top_k:
    :return:
    """
    # get the results
    retrieved_texts, sims, retrieved_embs = retrieve_with_scores(query, top_k=top_k)

    # get a list of the text outputs
    texts = [r["text"] for r in retrieved_texts]
    scores = [r["score"] for r in retrieved_texts]

    # get the length of the results
    k = len(texts)

    # make sure we have results
    if k == 0:
        return {"query": query, "answer": "(no answers)"}

    # tell the model how many docs per query
    model.config.n_docs = k

    # tokenize the question (DPRQuestion tokenizer)
    q_inputs = q_tok(query, return_tensors="pt", truncation=True, padding=True).to(device)

    # tokenize the contexts (DPRContext tokenizer)
    ctx_inputs = ctx_tok(texts, return_tensors="pt", truncation=True, padding=True).to(device)

    # build doc_scores tensor: shape (batch_size, n_docs) => here batch_size=1
    # use similarities computed from FAISS (converted to float)
    doc_scores = torch.tensor(scores, dtype=torch.float32, device=device).unsqueeze(0)  # (1, k)

    # debug prints to understand the results
    print("Debug shapes for: ", query)
    print(" q input_ids:", q_inputs["input_ids"].shape, "dtype:", q_inputs["input_ids"].dtype)
    print(" q attention_mask:", q_inputs["attention_mask"].shape)
    print(" ctx input_ids:", ctx_inputs["input_ids"].shape, "dtype:", ctx_inputs["input_ids"].dtype)
    print(" ctx attention_mask:", ctx_inputs["attention_mask"].shape)
    print(" model.config.n_docs:", model.config.n_docs)
    print(" doc_scores:", doc_scores.shape, "values:", doc_scores.tolist())

    # final safety checks
    assert "input_ids" in q_inputs and "input_ids" in ctx_inputs, "missing tokenized ids"
    assert q_inputs["input_ids"].dtype == torch.long and ctx_inputs["input_ids"].dtype == torch.long
    assert doc_scores.shape[1] == ctx_inputs["input_ids"].shape[0], f"doc_scores dim mismatch: {doc_scores.shape[1]} vs ctx count {ctx_inputs['input_ids'].shape[0]}"

    try:
        # get the RAG results
        out_ids = model.generate(input_ids=q_inputs["input_ids"],
                                 attention_mask=q_inputs["attention_mask"],
                                 context_input_ids=ctx_inputs["input_ids"],
                                 context_attention_mask=ctx_inputs["attention_mask"],
                                 doc_scores=doc_scores, n_docs=k, max_length=64, num_beams=2)

        # grab the answer
        answer = gen_tok.batch_decode(out_ids, skip_special_tokens=True)[0]

        console.print("query:", query)

        return {"query": query, "answer": answer, "sources": retrieved_texts}

    except AssertionError as e:
        console.print("RAG generate assertion ERROR:", e)

        # fall back to concatenation method (guaranteed to work)
        return fallback_concat_generate(query, retrieved_texts)
    except Exception as e:
        console.print("RAG generate other ERROR:", type(e).__name__, str(e))

        # fall back to concatenation method (guaranteed to work)
        return fallback_concat_generate(query, retrieved_texts)

In [12]:
prompts = [
    ['Show me places that have major flooding.', 5],
    ['Show me places that have moderate flooding.', 5],
    ['Show me places that have minor flooding.', 5],
    ['Show me places that have no flooding.', 5],
    ['What is going on in eastport?', 5]]

# console.print(rag_answer(prompts[0][0], prompts[0][1]))

for x in prompts:
    ret_val = rag_answer(x[0], x[1])
    console.print(ret_val)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Debug shapes for:  Show me places that have major flooding.
 q input_ids: torch.Size([1, 10]) dtype: torch.int64
 q attention_mask: torch.Size([1, 10])
 ctx input_ids: torch.Size([5, 25]) dtype: torch.int64
 ctx attention_mask: torch.Size([5, 25])
 model.config.n_docs: 5
 doc_scores: torch.Size([1, 5]) values: [[-0.783881425857544, -0.7603001594543457, -0.7555378079414368, -0.74961256980896, -0.6202335357666016]]


Debug shapes for:  Show me places that have moderate flooding.
 q input_ids: torch.Size([1, 10]) dtype: torch.int64
 q attention_mask: torch.Size([1, 10])
 ctx input_ids: torch.Size([5, 21]) dtype: torch.int64
 ctx attention_mask: torch.Size([5, 21])
 model.config.n_docs: 5
 doc_scores: torch.Size([1, 5]) values: [[-0.7174880504608154, -0.7146773934364319, -0.6849405765533447, -0.6782433986663818, -0.6591987013816833]]


Debug shapes for:  Show me places that have minor flooding.
 q input_ids: torch.Size([1, 10]) dtype: torch.int64
 q attention_mask: torch.Size([1, 10])
 ctx input_ids: torch.Size([5, 25]) dtype: torch.int64
 ctx attention_mask: torch.Size([5, 25])
 model.config.n_docs: 5
 doc_scores: torch.Size([1, 5]) values: [[-0.7437201738357544, -0.7253307104110718, -0.7077674865722656, -0.6955010294914246, -0.638798713684082]]


Debug shapes for:  Show me places that have no flooding.
 q input_ids: torch.Size([1, 10]) dtype: torch.int64
 q attention_mask: torch.Size([1, 10])
 ctx input_ids: torch.Size([5, 25]) dtype: torch.int64
 ctx attention_mask: torch.Size([5, 25])
 model.config.n_docs: 5
 doc_scores: torch.Size([1, 5]) values: [[-0.7934206128120422, -0.7930344343185425, -0.7926424741744995, -0.7643238306045532, -0.7381272912025452]]


Debug shapes for:  What is going on in eastport?
 q input_ids: torch.Size([1, 10]) dtype: torch.int64
 q attention_mask: torch.Size([1, 10])
 ctx input_ids: torch.Size([5, 35]) dtype: torch.int64
 ctx attention_mask: torch.Size([5, 35])
 model.config.n_docs: 5
 doc_scores: torch.Size([1, 5]) values: [[-1.111785650253296, -1.0407010316848755, -0.9967638850212097, -0.9609330892562866, -0.6696454882621765]]
