In [2]:
# Standard library
import os
import warnings

#RAG
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
import sentence_transformers
# SQL imports
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
#functions 
def get_vector_store(filename):
    """Build a FAISS vector store from the text of a PDF file.

    Every page is extracted and split on its own, so each chunk's ``page``
    metadata is the page the chunk actually came from. Splitting the
    concatenated document once per page would index the whole document N
    times for an N-page file, each copy tagged with an unrelated page number.

    Args:
        filename: Path of the PDF to index; also stored in each chunk's
            ``filename`` metadata.

    Returns:
        The result of ``FAISS.from_documents`` over all page chunks, embedded
        with ``sentence-transformers/all-MiniLM-L6-v2``.

    Raises:
        ValueError: If no page of the PDF yields any text.
    """
    pdf_reader = PdfReader(filename)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size = 5000, chunk_overlap = 200)

    pdf_docs = []
    for idx, page in enumerate(pdf_reader.pages):
        # A page may yield no text (e.g. a scanned image); normalise to ""
        # so such pages are skipped instead of breaking the loop.
        page_text = page.extract_text() or ""
        if not page_text.strip():
            continue
        pdf_docs.extend(
            text_splitter.create_documents(
                texts = [page_text],
                metadatas = [{'filename': filename, 'page': idx+1}]  # 1-based page number
            )
        )

    if not pdf_docs:
        # Fail with a clear message rather than handing an empty list to FAISS.
        raise ValueError(f"No extractable text found in {filename!r}")

    embedding = HuggingFaceEmbeddings(model_name = "sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(pdf_docs, embedding)

def get_schema(_):
    """Return the table info reported by the module-level ``db``.

    The single positional argument is ignored; it exists so the function can
    be used as a chain step that receives the chain's input.
    """
    return db.get_table_info()


def run_query(query):
    """Execute ``query`` on the module-level ``db`` and return its result."""
    # NOTE(review): full_chain passes model-generated SQL here unchecked;
    # consider a read-only connection — confirm how the database is opened.
    result = db.run(query)
    return result

def sql_answer(user_question):
    """Answer a question via similarity search plus the SQL chain.

    Searches the module-level ``v_db`` with the question, uses the first hit
    as context, invokes the module-level ``full_chain`` and returns the
    ``content`` of its result.
    """
    # Only the first search hit is kept as context.
    leading_hits = v_db.similarity_search(user_question)[0:1]
    context = " ----- ".join(doc.page_content for doc in leading_hits)
    payload = {"context": context, "question": user_question}
    return full_chain.invoke(payload).content

# RAG

In [4]:
# Index the PDF once; sql_answer() reads this module-level store for its
# similarity search. (Presumably the data dictionary for the SQL database — confirm.)
v_db = get_vector_store('data_dict.pdf')

# SQL Agent 

In [5]:
# setup database and schema
# `db` is the module-level handle read by get_schema() and run_query().
# NOTE(review): 'snyth.db' may be a typo for 'synth.db' — confirm against the file on disk.
sqlite_uri = 'sqlite:///./snyth.db'
db = SQLDatabase.from_uri(sqlite_uri)

In [6]:
#setup model
# The API key is read from the environment so that no credential is stored in
# the notebook. Any key that was previously committed here must be treated as
# compromised and revoked.
key = os.environ.get("OPENAI_API_KEY")
if not key:
    raise RuntimeError(
        "Set the OPENAI_API_KEY environment variable before running this cell."
    )
llm = ChatOpenAI(openai_api_key=key)

In [7]:
# Prompt for the SQL-generation step. It has three placeholders: {schema},
# {context} and {question}. sql_chain fills {schema} via get_schema, so callers
# must provide "context" and "question".
from langchain_core.prompts import ChatPromptTemplate

template = """Based on the table schema and context below, write a SQL query that would answer the user's question:
{schema}
{context}
Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_template(template)

In [8]:
# Question -> SQL text. Input is a dict with "context" and "question".
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)  # adds a "schema" key computed by get_schema
    | prompt
    | llm.bind(stop=["\nSQLResult:"])  # binds a stop sequence on the model call
    | StrOutputParser()  # chain output is a plain string (the SQL text)
)


In [24]:
# user_question = 'how many assets are there in the registry?'
# sql_chain.invoke({"context": context, "question": user_question})

In [10]:
# Prompt for the final natural-language answer. Placeholders: {schema},
# {context}, {question}, {query} and {response}. Rebinds `template`; the SQL
# prompt object built earlier is unaffected because it was already constructed.
template = """Based on the table schema below, question, sql query, and sql response, and context write a natural language response:
{schema}
{context}
Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)

In [11]:
# End-to-end chain: generate SQL with sql_chain, attach the schema and the
# result of running that SQL, then ask the model for a natural-language answer.
# Input is a dict with "context" and "question"; the output is the model's
# message object (callers read its `.content`).
full_chain = (
    RunnablePassthrough.assign(query=sql_chain)
    .assign(schema=get_schema, response=lambda inputs: run_query(inputs["query"]))
    | prompt_response
    | llm
)

In [22]:
# Question passed to sql_answer() in the next cell.
# NOTE(review): the text contains typos ('likly', 'vulnerablity'); it is left
# as-is because it is the literal input sent to the model.
user_question = 'what is the most likly vulnerablity source?'

In [23]:
# Full pipeline: similarity search for context, SQL generation, SQL execution,
# then a natural-language answer (the string returned by sql_answer).
sql_answer(user_question)

SELECT vulnerability_source, COUNT(vulnerability_source) AS source_count
FROM vulnerability
GROUP BY vulnerability_source
ORDER BY source_count DESC
LIMIT 1;


"The most likely vulnerability source based on the data is 'Infra', with a total of 1526 vulnerabilities originating from this source."

In [8]:
user_question = 'how many assets are there in the registry?'
# sql_chain's prompt declares {schema}, {context} and {question}. get_schema
# supplies the schema, so both "context" and "question" must be passed here;
# omitting "context" makes the prompt step fail. An empty context gives a
# schema-only SQL generation with no retrieved text.
sql_chain.invoke({"context": "", "question": user_question})

'SELECT COUNT(DISTINCT asset_identifier) AS total_assets\nFROM registry;'

In [9]:
# Variant of the answer prompt without a {context} placeholder. Placeholders:
# {schema}, {question}, {query} and {response}.
# NOTE: this rebinds both `template` and `prompt_response`, replacing the
# context-aware answer prompt for any chain built after this cell runs.
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)


In [10]:
# Rebuilds full_chain on top of whichever `prompt_response` is current.
# The final step is `llm`, the ChatOpenAI instance created in the model-setup
# cell; the notebook never defines a `model` variable, so piping into `model`
# raised NameError on a fresh kernel.
# NOTE: this rebinds `full_chain`, which sql_answer() reads at call time.
full_chain = (
    RunnablePassthrough.assign(query=sql_chain).assign(
        schema=get_schema,
        # `inputs` is the accumulated chain dict; "query" was added by the step above.
        response=lambda inputs: run_query(inputs["query"]),
    )
    | prompt_response
    | llm
)

In [11]:
user_question = 'how many assets are there in the registry?'
# full_chain runs sql_chain first, and sql_chain's prompt declares {context},
# so a "context" key is required even when no retrieved text is used. An empty
# string keeps this a schema-only run.
message_content = full_chain.invoke({"context": "", "question": user_question}).content
message_content

'There are 999 assets in the registry table.'