In [1]:
"""
This is a RAG that:
 - gets mocked up data from the APSViz DB
 - uses a mini-LLM for the NLM.
 - uses
"""

import os
import psycopg2
import json
import warnings
import faiss

from dotenv import load_dotenv
from rich.console import Console
from rich import pretty

from sentence_transformers import SentenceTransformer
from transformers import RagSequenceForGeneration, BartTokenizer

# load the secret DB credentials
load_dotenv()

# init Rich formatting
pretty.install()

# create a console with the dark theme
console = Console() # theme=dark

# suppress warnings
warnings.filterwarnings('ignore')

# set to "cuda" if you have GPU and torch.cuda.is_available()
device = "cpu"


  from .autonotebook import tqdm as notebook_tqdm





In [2]:
def run_query(sql_query):
    """
    runs a query against the APSViz DB.

    Note this notebook expects localhost to be connected to a postgres DB.
    :param sql_query:
    :return:
    """
    # Database connection parameters
    connection = psycopg2.connect(dbname="apsviz", user=os.getenv("PG_USER"), password=os.getenv("PG_PASS"), host="localhost", port="5432")

    # init the return
    ret_val = None

    with connection.cursor() as cursor:
        try:
            # Create a cursor object
            cursor = connection.cursor()

            # Execute an SQL query
            cursor.execute(sql_query)

            # Fetch and print results
            ret_val = cursor.fetchall()

        except Exception as e:
            print("An error occurred:", e)

        finally:
            # Close the cursor and connection
            cursor.close()
            connection.close()

    return ret_val[0][0]

In [3]:
# create the SQL and get mocked up data from the DB.
# note we are converting number values to feet, and we are creating a random value for the current water height.
sql = """
            SELECT json_agg(row_to_json(t))
            FROM (
                SELECT name, station_id, abbrev, lon, lat,
                CASE WHEN nos_minor IS NOT NULL THEN (nos_minor * 3.28084) ELSE 0 END AS nos_minor,
                CASE WHEN nos_moderate IS NOT NULL THEN (nos_moderate * 3.28084) ELSE 0 END AS nos_moderate,
                CASE WHEN nos_major IS NOT NULL THEN (nos_major * 3.28084) ELSE 0 END AS nos_major,
                CASE WHEN nws_minor IS NOT NULL THEN (nws_minor * 3.28084) ELSE 0 END AS nws_minor,
                CASE WHEN nws_moderate IS NOT NULL THEN (nws_moderate * 3.28084) ELSE 0 END AS nws_moderate,
                CASE WHEN nws_major IS NOT NULL THEN (nws_major * 3.28084) ELSE 0 END AS nws_major,
                (RANDOM() * 5.0 + 1.0) AS current_height
                FROM noaa_station_levels
                ORDER BY name
            ) t ;
        """

# get the station data
data = run_query(sql)

In [4]:
def get_flood_stage(values):
    """
    Gets the flood stage based on the station data

    Note "current_height" is a random number (1 to 5) already generated in the data
    """

    # init the return value
    ret_val = 'no flooding'

    if ((values['nos_major'] and values['nos_major'] - values['current_height'] < 0) or (
            values['nws_major'] and values['nws_major'] - values['current_height'] < 0)):
        ret_val = 'major flooding'
    elif ((values['nos_moderate'] and values['nos_moderate'] - values['current_height'] < 0) or (
            values['nws_moderate'] and values['nws_moderate'] - values['current_height'] < 0)):
        ret_val = 'moderate flooding'
    elif ((values['nos_minor'] and values['nos_minor'] - values['current_height'] < 0) or (
            values['nws_minor'] and values['nws_minor'] - values['current_height'] < 0)):
        ret_val = 'minor flooding'

    print('\nStation:', values['name'], 'current_height:', values['current_height'])
    print('values[nos_major]', values['nos_major'] - values['current_height'])
    print('values[nos_moderate]', values['nos_moderate'] - values['current_height'])
    print('values[nos_minor]', values['nos_minor'] - values['current_height'])
    print('values[nws_major]', values['nws_major'] - values['current_height'])
    print('values[nws_moderate]', values['nws_moderate'] - values['current_height'])
    print('values[nws_minor]', values['nws_minor'] - values['current_height'])
    print('ret_val', ret_val)

    return ret_val

In [5]:
"""
get the data in an acceptable format
"""

docs = []
metadata = []

# for each record returned from the DB
for d in data:
    # get the flood stage label for this station
    flood_stage = get_flood_stage(d)

    # create some data tags
    tags = [str(d['station_id']), d['name'], flood_stage]

    # get the data in an understandable format
    text = f"{d['name']}: {flood_stage} (tags: {','.join(tags)})"

    # save the
    docs.append(text)
    metadata.append(d)


Station: Annapolis current_height: 1.1960354083574178
values[nos_major] 3.4151452013986794
values[nos_moderate] 2.184830201398679
values[nos_minor] 1.2145817867645339
values[nws_major] 3.0850606892035572
values[nws_moderate] 1.3846253233498995
values[nws_minor] 0.6844460550572165
ret_val no flooding

Station: Apalachicola current_height: 4.678886838187268
values[nos_major] -0.06770622843117025
values[nos_moderate] -1.2980212284311698
values[nos_minor] -2.2682696430653166
values[nws_major] 2.0228290154712694
values[nws_moderate] 0.022316820349317368
values[nws_minor] -1.778144155260439
ret_val major flooding

Station: Aransas, Aransas Pass current_height: 1.5726516681617477
values[nos_major] -1.5726516681617477
values[nos_moderate] -1.5726516681617477
values[nos_minor] -1.5726516681617477
values[nws_major] 2.4483778440333746
values[nws_moderate] 1.448121746472399
values[nws_minor] 0.4478656489114239
ret_val no flooding

Station: Aransas Wildlife Refuge current_height: 3.986426605832157

In [6]:
"""
embed that data and create an embedding index
"""

# load the transformer
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# encode the data
encoded = embedder.encode(docs, convert_to_numpy=True)

# get the extents of the data
dim = encoded.shape[1]

# create an index
index = faiss.IndexFlatL2(dim)

# add the embeddings to the index
index.add(encoded)

In [7]:
def retrieve_with_scores(query, top_k=2):
    """
    gets the answer data with the scores

    :param query:
    :param top_k:
    :return:
    """
    # encode the query
    q_emb = embedder.encode([query], convert_to_numpy=True)  # (1, dim)

    # get the scores and answer/data indexes
    scores, idx = index.search(q_emb, top_k)  # returned L2 dists

    # convert L2 -> similarity by negative distance (simple)
    similarity_score = -scores[0].astype(float)  # length top_k

    # sort the response records in order of the scores
    similarity_score = sorted(similarity_score, reverse=False)

    # init the results
    contexts = []

    # init counter
    j = 0

    # loop through the answers
    for i in idx[0]:
        # put away the result
        contexts.append({"text": docs[i], 'score': similarity_score[j], "metadata": metadata[i]})

        # increment the counter
        j += 1

    # sort the results by score
    contexts = sorted(contexts, key=lambda s: s["score"], reverse=False)

    # get the embeddings of the answers
    embeddings = encoded[idx[0]]

    return contexts, similarity_score, embeddings

In [8]:
# get the results for a prompt
results, score, retrieved_embeddings = retrieve_with_scores("Show me places that have minor flooding.", 5)

#console.print('results:', json.dumps(results), 'score:', score)  # , 'embeddings:', retrieved_embeddings

In [9]:
# init facebook RAG encoders
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq").to(device)

gen_tok = BartTokenizer.from_pretrained("facebook/bart-large")

Some weights of the model checkpoint at facebook/rag-sequence-nq were not used when initializing RagSequenceForGeneration: ['rag.question_encoder.question_encoder.bert_model.pooler.dense.bias', 'rag.question_encoder.question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing RagSequenceForGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RagSequenceForGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
def concat_generate(query: str, contexts):
    """
    dumps out the data when there has been an exception

    :param query:
    :param contexts:
    :return:
    """
    # create the model with the same generator weights
    gen_model = model.generator.to(device)

    # Build the prompt: query + contexts joined as it keeps provenance
    full_prompt = query + " " + " ".join([f"[DOC{i}] {c['text']}" for i, c in enumerate(contexts, start=1)])

    # tokenize and get the input tokens
    inputs = gen_tok(full_prompt, return_tensors="pt", truncation=True, padding=True).to(device)

    # generate the output
    out = gen_model.generate(input_ids=inputs["input_ids"],
                             attention_mask=inputs["attention_mask"],
                             max_length=64,
                             num_beams=2)

    # get the answer
    answer = gen_tok.batch_decode(out, skip_special_tokens=True)[0]

    # return the results
    return json.dumps({"answer": answer, "sources": contexts}, indent=4)

In [11]:
def rag_answer(prompt, top_k=2):
    """
    gets the answer

    :param prompt:
    :param top_k:
    :return:
    """
    # get the results
    retrieved_texts, sims, retrieved_embs = retrieve_with_scores(prompt, top_k)

    # get a list of the text outputs
    texts = [r["text"] for r in retrieved_texts]

    # get the length of the results
    k = len(texts)

    # make sure we have results
    if k == 0:
        return {"query": prompt, "answer": "(no answers)"}

    # tell the model how many docs per query
    model.config.n_docs = k

    # get the RAG results
    ret_val = concat_generate(prompt, retrieved_texts)

    # return the output results
    return ret_val

In [12]:
prompts = [
    ['Show me places that have major flooding.', 5],

    # ['Show me places that have moderate flooding.', 5],
    # ['Show me places that have minor flooding.', 5],
    # ['Show me places that have no flooding.', 5],
    # ['What is going on in eastport?', 5]
]

# output the results
for x in prompts:
    console.print(f'-- Prompt: {x[0]} --')
    result = rag_answer(x[0], x[1])

    console.print('Result:', result)
    console.print('-- Complete --')