In [2]:
# !pip install torch 
# !pip install psycopg2
# !pip install rank_bm25
# !pip install langchain 
# !pip install langchain_community
# !pip install langchain_openai
# !pip install sentencepiece
# !pip install transformers 
# !pip install datasets
# !pip install accelerate 
# %env SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL True
# !pip install -U sklearn
# !pip install spacy
# !python3 --version
# !pip install -U transformers 
# !pip install PyPDF2
# !pip install sentence_transformers
# !pip install --upgrade boto3
# !pip install faiss-cpu

In [3]:
# !pip install faiss-cpu

In [4]:
import pandas as pd 
import numpy as np
import torch
import torch.nn as nn 
from sklearn.model_selection import train_test_split 

import json 
import copy 
import gc 
import os 
import re 
from collections import defaultdict
from pathlib import Path 

from transformers import AutoTokenizer 


from spacy.lang.en import English 
from transformers.tokenization_utils import PreTrainedTokenizerBase 
from transformers.models.deberta_v2 import (
    DebertaV2ForTokenClassification,
    DebertaV2TokenizerFast,
)
from transformers.trainer import Trainer 
from transformers.training_args import TrainingArguments
from transformers.trainer_utils import EvalPrediction 
from transformers.data.data_collator import DataCollatorForTokenClassification
from datasets import (
    Dataset, 
    DatasetDict, 
    concatenate_datasets,
    features
)
from transformers import AutoConfig

from transformers import TextClassificationPipeline

import argparse 
from itertools import chain 
from functools import partial 

from transformers import AutoModelForSequenceClassification, DataCollatorWithPadding

import random 

#SQL Agent imports 
#RAG
from PyPDF2 import PdfReader
from transformers import AutoTokenizer, AutoModel 
import torch 
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
import sentence_transformers
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain.chains.conversation.memory import ConversationBufferMemory

import warnings 
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from langchain_core.messages import AIMessage, HumanMessage

# Initialize routing agent from saved model 

In [6]:
#Please provide the pre-trained model path
modelpath = '/root/checkpoint-894'
model = AutoModelForSequenceClassification.from_pretrained(modelpath)
tokenizer = AutoTokenizer.from_pretrained(modelpath)

pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer)

Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [7]:
def routing_agent(user_question):
    prediction = pipe(user_question)
    print(prediction)
    if prediction[0]['label'] == 'LABEL_1':
        sql_result = sql_answer(user_question)
        memory.save_context({"input":  sql_result["query"]}, {"output": sql_result["result"]})
        return sql_result
    if prediction[0]['label'] == 'LABEL_0':
        tdr_result = "This is a tdr question"
        answer = create_RetrievalQA_chain(user_question,model,thread_bm25_store,prompt,'bm25',30) 
        memory.save_context({"input":  answer["query"]}, {"output": answer["result"]})
        return answer


In [27]:
#add memory 
memory = ConversationBufferMemory(return_messages=True)

# Add SQL Agent  

In [9]:
import boto3
from langchain_community.chat_models import BedrockChat

from boto3 import client
from botocore.config import Config
def load_model():
    config = Config(read_timeout=1000)

    bedrock_runtime = boto3.client(service_name='bedrock-runtime', 
                          region_name='us-east-1',
                          config=config)

    model_id = "anthropic.claude-3-5-sonnet-20240620-v1:0"
    
    model_kwargs = { 
        "max_tokens": 100000,
        "temperature": 0,
        "top_k": 250,
        "top_p": 1,
        "stop_sequences": ["\n\nHuman"],
        
    }
    
 
    
    model = BedrockChat(
        client=bedrock_runtime,
        model_id=model_id,
        model_kwargs=model_kwargs,
    )

    return model

sql_llm = load_model()

In [10]:
#functions 
def get_vector_store(filename):
    filename = filename
    pdf_reader = PdfReader(filename)
    
    text = ""
    for page in pdf_reader.pages:
        text += page.extract_text()
        
    pdf_docs = []
    text_splitter = RecursiveCharacterTextSplitter(chunk_size = 5000, chunk_overlap = 200)

    for idx, page in enumerate(pdf_reader.pages):
        if len(text) > 0:
            pdf_docs.extend(
                text_splitter.create_documents(
                    texts = [text],
                    metadatas = [{'filename': filename, 'page': idx+1}]
                )
            )
            
    embedding = HuggingFaceEmbeddings(model_name = "sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(pdf_docs, embedding)

def get_schema(_):
    schema = db.get_table_info()
    return schema


def run_query(query):
    return db.run(query)

def sql_answer(user_question):
    # add Try to run SQL 5 times 
    error_counter = 0 
    while error_counter < 5:
        try:
            question = user_question
            ss_result = v_db.similarity_search(question)
            top_ss_docs = ss_result[0:1]
            context = " ----- ".join([ss_result.page_content for ss_result in top_ss_docs])  
            result = full_chain.invoke(({"context": context,"history":memory.load_memory_variables({}), "question": user_question})).content
            return {"query" : user_question, 'result': result}
        except:
            error_counter += 1 
    result = "Unable to generate answer based on the question being asked. please try again with a different question."
    return {"query" : user_question, 'result': result}
    

In [11]:
# RAG
v_db = get_vector_store('data_dict.pdf')

In [12]:
# SQL Agent 

In [13]:
# setup database and schema
sqlite_uri = 'sqlite:///./snyth.db' 
db = SQLDatabase.from_uri(sqlite_uri)

In [14]:
#setup model 
key = ''
sql_llm_gpt = ChatOpenAI(openai_api_key=key)

In [15]:
#sql chain 
template = """Based on the table schema and context below, write a SQL query that would answer the user's question. Only return the sql code not any explanation.:
{schema}
{context}
{history} 
Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_template(template)

In [16]:
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | sql_llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

In [17]:
#full chain 
template = """Based on the table schema below, question, sql query, and sql response, and context write a natural language response. The answer should be concise one or two sentance and should include any nubers from the original query. If the answer involvs a list of output, reutrn the full list. You should not refernce the database or its tables:
{schema}
{context}
{history}
Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)

In [18]:
full_chain = (
    RunnablePassthrough.assign(query=sql_chain).assign(
        schema=get_schema,
        response=lambda vars: run_query(vars["query"]),
    )
    | prompt_response
    | sql_llm
)

In [19]:
user_question = 'what is the most likly vulnerablity source?'
# user_question = 'How many apples are in the dog table?'
sql_answer(user_question)

{'query': 'what is the most likly vulnerablity source?',
 'result': 'The most likely vulnerability source is Infra, with 1,526 occurrences in the database.'}

# Add TDR Agent  

In [22]:
import boto3
from langchain_community.chat_models import BedrockChat

from boto3 import client
from botocore.config import Config
def load_model():
    config = Config(read_timeout=1000)

    bedrock_runtime = boto3.client(service_name='bedrock-runtime', 
                          region_name='us-east-1',
                          config=config)

    model_id = "anthropic.claude-3-5-sonnet-20240620-v1:0"
    
    model_kwargs = { 
        "max_tokens": 100000,
        "temperature": 0,
        "top_k": 250,
        "top_p": 1,
        "stop_sequences": ["\n\nHuman"],
        
    }
    
 
    
    model = BedrockChat(
        client=bedrock_runtime,
        model_id=model_id,
        model_kwargs=model_kwargs,
    )

    return model

llm = load_model()


In [23]:
import boto3
from langchain_community.chat_models import BedrockChat

from boto3 import client
from botocore.config import Config

In [24]:
from langchain import PromptTemplate


def load_prompt():
    with open(f'prompt_template/final_prompt.txt', 'r') as file:
        data = file.read()
    prompt_template = data
    prompt = PromptTemplate(
        template=prompt_template, input_variables=["context","question"]
    )
    return prompt
prompt = load_prompt()

In [25]:
def load_model():
    config = Config(read_timeout=1000)

    bedrock_runtime = boto3.client(service_name='bedrock-runtime', 
                          region_name='us-east-1',
                          config=config)
    model_id = "anthropic.claude-3-5-sonnet-20240620-v1:0"
    
    model_kwargs = { 
        "max_tokens": 100000,
        "temperature": 0,
        "top_k": 250,
        "top_p": 1,
        "stop_sequences": ["\n\nHuman"],
        
    }
    
 
    
    model = BedrockChat(
        client=bedrock_runtime,
        model_id=model_id,
        model_kwargs=model_kwargs,
    )

    return model
model = load_model()

In [26]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import RetrievalQA

def create_RetrievalQA_chain(query,model,store,prompt,store_type,no_of_doc_per_batch):
    print("Connecting to bedrock")
    if store_type == "Vector store":
        retriever = store.as_retriever(
            search_type="mmr", search_kwargs={"k": 30,"include_metadata": True}
        )
    elif store_type == "bm25":
        retriever  = store
        retriever.k = no_of_doc_per_batch
        retriever.get_relevant_documents(query)
        
    chain = RetrievalQA.from_chain_type(
        llm=model,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={"prompt": prompt, 
                          }
    )
    
    result = chain.invoke({"query": query,
                                           
                          })
 
    return result


In [27]:
from langchain.retrievers import BM25Retriever
def generate_thread_sql(collection_id,cve_result_filter):
    collection_ids = collection_id
    params = []
    collection_ids_tuple = tuple(collection_ids)
    params.append(collection_ids_tuple)
    
    
    prefinal_sql = f"""
       SELECT   * 
                 FROM 
                     langchain_pg_embedding lpe 
                 WHERE 
                     lpe.collection_id in (
                         SELECT 
                             lpc.uuid 
                         FROM
                             langchain_pg_collection lpc
                        WHERE
                             lpc.name in  %s
                     )
    """

    for key, value in cve_result_filter.items():
        
        if len(value) == 0:
            pass
        else:
                prefinal_sql = prefinal_sql.strip() + f"""
                             AND
                             lpe.cmetadata ->> '{key}' in %s
                            """
                params.append(tuple(value))
                
    prefinal_sql = prefinal_sql.strip() + """
    order by (lpe.cmetadata->>'_time')::timestamptz desc
    """
    
    
    sql = f"""
            SELECT document FROM (
            {prefinal_sql.strip()}
            ) as subquery  
    """
    params = tuple(params)

    return sql,params



def get_documents_for_thread(collection_id, filter_dict,conn_for_vector_db):
    
    
    sql,params = generate_thread_sql(collection_id,filter_dict,)

    with conn_for_vector_db.cursor() as cur:
        
        cur.execute(sql, params)

        rows = cur.fetchall()
        bm25_documents = [row[0] for row in rows] 
        conn_for_vector_db.commit()
        conn_for_vector_db.close()

    bm25_store = BM25Retriever.from_texts(bm25_documents)
    
    return bm25_store


In [28]:
#collection names created for the final results
collection_id = ["TDR_Analysis"]

In [29]:
onprem_cloud_filter = {
  'host': []
}

In [30]:
import boto3
import json
import logging
import psycopg2
import psycopg2.extras as psycopg2_extras
from psycopg2 import pool
from pprint import pprint
conn_for_vector_db = psycopg2.connect(
    dbname="postgres",
    user="postgres", 
    password= "password1", 
    host="database-2-instance-1.cnzoukwx6fwf.us-east-1.rds.amazonaws.com"
)

In [31]:
thread_bm25_store = get_documents_for_thread(collection_id, onprem_cloud_filter,conn_for_vector_db)

In [32]:
# memory = ConversationBufferMemory(return_messages=True)

In [33]:
# memory.load_memory_variables({})

In [34]:
# question_one = "Analyze security threat analysis reports for the hosts and provide a comprehensive report of the threat that takes into account individual host based reports."

In [35]:
# output_one = routing_agent(question_one)
# output_one

In [34]:
memory.load_memory_variables({})

{'history': []}

In [36]:
question_two = 'what is the most likly vulnerablity source?'

In [37]:
output_two = routing_agent(question_two)
output_two

[{'label': 'LABEL_1', 'score': 0.9998812675476074}]


{'query': 'what is the most likly vulnerablity source?',
 'result': 'The most likely vulnerability source is Infra, with 1,526 occurrences in the database.'}

In [38]:
question_three = "What are the risk levels of those vulnerabilities?" 

In [39]:
output_three = routing_agent(question_three)
output_three

[{'label': 'LABEL_1', 'score': 0.9987605810165405}]


{'query': 'What are the risk levels of those vulnerabilities?',
 'result': 'The vulnerabilities sourced from Infra have varying risk levels, with Very High being the most common (588 occurrences), followed closely by High (576 occurrences). There are also 184 Critical and 178 Emergency level vulnerabilities.'}

In [40]:
question_four = "Who owns the assets with the risk level of Emergency?" 

In [39]:
output_four = routing_agent(question_four)
output_four

[{'label': 'LABEL_1', 'score': 0.9999977350234985}]


{'query': 'Who owns the assets with the risk level of Emergency?',
 'result': 'The assets with a risk level of Emergency are owned by 82 different individuals, including Brittney Fisher, Michael Ashley, Amy Brooks, Donna Dunlap, Aaron Stevens, and many others. The full list of asset owners includes 82 unique names.'}

In [42]:
question_five = "Who are the other owners" 

In [9]:
output_five = routing_agent(question_five)
output_five

In [13]:
memory.load_memory_variables({})

In [40]:
question_one = "Analyze security threat analysis reports for the hosts and provide a comprehensive report of the threat that takes into account individual host based reports."

In [41]:
output_one = routing_agent(question_one)
output_one

[{'label': 'LABEL_0', 'score': 0.9999979734420776}]
Connecting to bedrock


{'query': 'Analyze security threat analysis reports for the hosts and provide a comprehensive report of the threat that takes into account individual host based reports.',
 'result': 'Based on the analysis of security threat reports for multiple hosts, I can provide a comprehensive overview of the potential threat scenario:\n\n1. Threat Summary:\nA sophisticated cyber attack appears to have targeted multiple hosts within the organization\'s network. The attack demonstrates characteristics of an Advanced Persistent Threat (APT), involving multiple stages of compromise, lateral movement, and attempts to establish persistence.\n\n2. Affected Systems:\n- wstp-srpa_013 (172.16.10.8)\n- wstp-srpa015 (172.16.10.12)\n- wstp-srpa017 (no malicious activity detected, but included for completeness)\n\n3. Attack Timeline:\nMay 22, 2024:\n14:12:05 - Initial compromise of wstp-srpa_013\n15:21:28 - Attempted expansion on wstp-srpa_013\n15:23:07 - Malware delivery to wstp-srpa_013\n15:27:03 - Malware e

In [15]:
memory.load_memory_variables({})

NameError: name 'memory' is not defined