In [1]:
from rich.console import Console
from rich_theme_manager import ThemeManager
import pathlib
import warnings
"""
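This is a simple RAG (retrieval) demo that:
 - gets mocked-up station data from the APSViz DB
 - embeds each station's flood stage with a sentence transformer and stores the vectors in a Qdrant vector DB
 - uses a compact MiniLM model to match natural-language questions against those vectors.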
"""
from dotenv import load_dotenv
import os
import psycopg2

# pull the DB credentials out of the local .env file into the process environment
load_dotenv()

# theme files live in ./themes; the notebook uses the one named "dark"
theme_dir = pathlib.Path("themes")
theme_manager = ThemeManager(theme_dir=theme_dir)
dark = theme_manager.get("dark")

# every rich print in this notebook goes through this themed console
console = Console(theme=dark)

# silence all Python warnings for the rest of the session (this also hides library deprecation notices)
warnings.simplefilter('ignore')

In [2]:
def run_query(query, dbname="apsviz", host="localhost", port="5432"):
    """
    runs a query against the APSViz DB and returns the first column of the first row.

    Note this notebook expects localhost to be connected to a postgres DB. The user name and
    password are read from the environment (the notebook populates it from .env via load_dotenv()).

    :param query: the SQL text to execute.
    :param dbname: database name (defaults to "apsviz").
    :param host: database host (defaults to "localhost").
    :param port: database port (defaults to "5432").
    :return: the value in the first column of the first result row, or None when the query
             fails (the error is printed) or returns no rows.
    :raises: errors raised by psycopg2.connect() are not caught and propagate to the caller.
    """
    # NOTE(review): PG_PASSWORD is an assumed variable name — TODO confirm against the project's .env.
    # It falls back to PG_USER, which is what this function previously sent as the password,
    # so existing setups keep working.
    connection = psycopg2.connect(
        dbname=dbname,
        user=os.getenv("PG_USER"),
        password=os.getenv("PG_PASSWORD", os.getenv("PG_USER")),
        host=host,
        port=port,
    )

    try:
        # the cursor context manager closes the cursor; a single cursor is enough
        with connection.cursor() as cursor:
            cursor.execute(query)
            results = cursor.fetchall()
    except Exception as e:
        # keep the notebook's best-effort behaviour: report the problem, return no data
        print("An error occurred:", e)
        return None
    finally:
        # always release the connection, whether the query succeeded or not
        connection.close()

    # no rows -> nothing to unwrap (previously this raised IndexError)
    if not results:
        return None

    return results[0][0]

In [3]:
# create the SQL and get mocked-up data from the DB.
# The flood threshold columns are multiplied by FEET_PER_METER so they are reported in feet
# (presumably the DB stores meters — TODO confirm). A NULL threshold stays NULL because
# NULL * x is NULL in SQL, so no CASE expression is needed.
FEET_PER_METER = 3.28084
THRESHOLD_COLUMNS = ("nos_minor", "nos_moderate", "nos_major", "nws_minor", "nws_moderate", "nws_major")

threshold_sql = ",\n                ".join(f"{col} * {FEET_PER_METER} AS {col}" for col in THRESHOLD_COLUMNS)

# current_level is a random integer from 1 to 5 standing in for a live water level (mock data)
query = f"""
            SELECT json_agg(row_to_json(t))
            FROM (
                SELECT station_id, abbrev, name, lon, lat,
                {threshold_sql},
                FLOOR(random() * 5 + 1)::INT AS current_level
                FROM noaa_station_levels
                ORDER BY name
            ) t ;
        """

# get the station data
stations = run_query(query)

# json_agg over zero rows (or a failed query) gives None; stop here with a clear message
# rather than failing later when the stations are iterated.
if not stations:
    raise RuntimeError("no station data was returned from the APSViz DB")

In [4]:
def get_flood_stage(values):
    """
    Gets the flood stage based on the station data.

    A station is in a stage when its "current_level" is strictly above the NOS or NWS threshold
    for that stage. Stages are checked from most to least severe, and the first match wins.
    A threshold of None (NULL in the DB) means "not defined" and is skipped.

    Note "current_level" is a random number (1 to 5) already generated in the data.

    :param values: a station record with the keys current_level, nos_minor, nos_moderate,
                   nos_major, nws_minor, nws_moderate and nws_major.
    :return: 'major flooding', 'moderate flooding', 'minor flooding' or 'no flooding'.
    """
    for severity in ('major', 'moderate', 'minor'):
        for agency in ('nos', 'nws'):
            threshold = values[f'{agency}_{severity}']

            # compare with None explicitly: a 0.0 threshold is a real value and must not be
            # treated as missing (a plain truthiness test would skip it)
            if threshold is not None and values['current_level'] > threshold:
                return f'{severity} flooding'

    return 'no flooding'


In [5]:
from qdrant_client import models, QdrantClient
from qdrant_client.models import VectorParams, PointStruct, Filter, FieldCondition, Range
from sentence_transformers import SentenceTransformer

"""
Create a vector DB and load it with sentence-embedding vectors built from the station data.

"""
# local, in-memory qdrant instance (":memory:" means nothing is written to disk)
qdrant = QdrantClient(":memory:")

# sentence embedding model (an alternative is paraphrase-multilingual-MiniLM-L12-v2)
encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# (re)create the "stations" collection: vectors are sized to the encoder's output
# and compared by cosine distance
qdrant.recreate_collection(
    collection_name="stations",
    vectors_config=VectorParams(
        size=encoder.get_sentence_embedding_dimension(),
        distance=models.Distance.COSINE,
    ),
)




True

In [6]:
# array for data points
points = []

# get_flood_stage() only ever returns a handful of distinct labels, so cache each label's
# embedding instead of running the sentence transformer once per station.
stage_embeddings = {}

# create an embedded vector for each station
for item in stations:
    # get the flood stage based on the data
    flood_stage = get_flood_stage(item)

    # encode each distinct flood stage only once and reuse the vector afterwards
    if flood_stage not in stage_embeddings:
        stage_embeddings[flood_stage] = encoder.encode(flood_stage).tolist()

    # give every point its own copy of the vector so no two points share a mutable list
    embedding = list(stage_embeddings[flood_stage])

    # create the data points; the payload keeps the raw station values plus the derived stage.
    # NOTE(review): qdrant point ids must be unsigned integers or UUIDs — assumes station_id
    # is an integer; TODO confirm against the noaa_station_levels schema.
    points.append(
        PointStruct(id=item["station_id"], vector=embedding,
                    payload={
                        "station_id": item["station_id"],
                        "abbrev": item["abbrev"],
                        "name": item["name"],
                        "lon": item["lon"], "lat": item["lat"],
                        "nos_minor": item['nos_minor'],
                        "nos_moderate": item['nos_moderate'],
                        "nos_major": item['nos_major'],
                        "nws_minor": item['nws_minor'],
                        "nws_moderate": item['nws_moderate'],
                        "nws_major": item['nws_major'],
                        "current_level": item["current_level"],
                        "flooded": flood_stage}
        )
    )

In [7]:
# push every station point into the "stations" collection in a single call
qdrant.upsert(points=points, collection_name="stations")

UpdateResult(operation_id=0, status=<UpdateStatus.COMPLETED: 'completed'>)

In [8]:
# fetch the collection's info and render it through the themed rich console
stations_collection = qdrant.get_collection(collection_name="stations")
console.print(stations_collection)

In [9]:
# the natural-language question to answer
user_prompt = "which stations have major flooding"

# embed the question with the same encoder used for the station points
query_vector = encoder.encode(user_prompt).tolist()

# the three points closest to the question vector
results = qdrant.search(collection_name="stations", query_vector=query_vector, limit=3)

console.print('Q:', user_prompt)

# one answer line plus the score and raw payload for every hit
for hit in results:
    payload = hit.payload
    console.print('A: Station', payload['name'], 'has', payload['flooded'] + '.')
    console.print('score:', round(hit.score, 3), '\nfull payload:', payload)