<a href="https://colab.research.google.com/github/KaifAhmad1/code-test/blob/main/KG_Enhanced_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Problem Statement

#### Task
Develop a co-pilot for threat researchers, security analysts, and professionals that addresses the limitations of current AI solutions like ChatGPT and Perplexity.

#### Current Challenges
1. **Generic Data**: Existing AI solutions provide generic information that lacks specificity.
2. **Context Understanding**: These solutions fail to understand and maintain context.
3. **Limited Information**: The data sources are often limited and not comprehensive.
4. **Single Source Dependency**: Relying on a single source of information reduces reliability and accuracy.
5. **Inadequate AI Models**: Current models do not meet the specialized needs of cybersecurity professionals.

#### Requirement
Create a chatbot capable of collecting and curating data from multiple sources, starting with search engines, and expanding to website crawling and Twitter scraping.

#### Features Required

##### User Interface (UI)
- Chat UI with file upload capabilities.
- Options to save and select prompts.
- Configuration settings for connectors with enable/disable toggles.
- Interface for configuring knowledge and variables (similar to Dify.ai).

##### Technical Specifications
- **No Hallucinations**: Ground every answer in the collected sources so that the chatbot provides accurate and reliable information.
- **RAG (Retrieval-Augmented Generation)**: Use RAG to determine which connectors to use based on user inputs.
- **Query Chunking and Distribution**: Optimize the process of breaking down queries and distributing them across different sources.
- **Data Curation Steps**:
  1. Collect links from approximately 50 sources.
  2. Aggregate data from websites and Twitter.
  3. Curate data using a knowledge graph to find relationships and generate responses.
- **Chatbot Capabilities**: Answer queries such as:
  - "List all details on {{BFSI}} security incidents in {{India}}."
  - "List all ransomware attacks targeting the healthcare industry in {{last 7 days/last 3 months/last week/last month}}."
  - "Provide recent incidents related to the LockBit ransomware gang / Black Basta ransomware."
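
The example queries above use `{{...}}` slots for values such as sector, country, and time window. A minimal sketch of filling such slots follows; the `render_query` helper and the slot names `sector` and `country` are hypothetical and only illustrate the idea.

```python
import re

# Matches {{name}} placeholders; the slot names used below are hypothetical.
PLACEHOLDER = re.compile(r"\{\{\s*(\w+)\s*\}\}")


def render_query(template: str, variables: dict) -> str:
    """Replace each {{name}} slot with its value; fail loudly on a missing one."""
    def substitute(match: re.Match) -> str:
        name = match.group(1)
        if name not in variables:
            raise KeyError(f"No value supplied for placeholder '{name}'")
        return str(variables[name])

    return PLACEHOLDER.sub(substitute, template)


template = "List all details on {{sector}} security incidents in {{country}}."
query = render_query(template, {"sector": "BFSI", "country": "India"})
print(query)
```

Raising on a missing slot keeps a half-filled prompt from reaching the retriever or the model.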

#### Goal
Develop a data collector that integrates multiple specific sources to enrich the knowledge base, enabling the model to better understand context and deliver accurate results. The solution should be modular, allowing customization and configuration of sources.
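
One way to picture the modular, configurable sources described above is a small registry in which each connector carries an enable/disable toggle. This is only a sketch under assumptions: `Connector`, `ConnectorRegistry`, and the stub fetchers are hypothetical names, and real connectors would call a search engine, crawler, or Twitter scraper.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Connector:
    """One data source plus its enable/disable toggle (hypothetical structure)."""
    name: str
    fetch: Callable[[str], List[str]]
    enabled: bool = True


class ConnectorRegistry:
    def __init__(self) -> None:
        self._connectors: Dict[str, Connector] = {}

    def register(self, connector: Connector) -> None:
        self._connectors[connector.name] = connector

    def set_enabled(self, name: str, enabled: bool) -> None:
        self._connectors[name].enabled = enabled

    def collect(self, query: str) -> Dict[str, List[str]]:
        """Send the query to every enabled connector and group results by source."""
        return {
            c.name: c.fetch(query)
            for c in self._connectors.values()
            if c.enabled
        }


# Stub fetchers stand in for real search, crawl, and Twitter connectors.
registry = ConnectorRegistry()
registry.register(Connector("search", lambda q: [f"search result for {q}"]))
registry.register(Connector("twitter", lambda q: [f"tweet about {q}"]))
registry.set_enabled("twitter", False)

results = registry.collect("LockBit")
print(results)
```

Keeping results grouped by source name makes it possible to show provenance later and to compare what each source contributed.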

#### Summary
The goal is to build an advanced, modular chatbot for cybersecurity professionals that overcomes the limitations of existing AI solutions by integrating multiple data sources and ensuring context-aware, accurate responses. The chatbot will utilize state-of-the-art techniques like RAG and knowledge graphs to provide comprehensive, curated information from diverse sources.


#### Installation and Setup

In [None]:
!pip uninstall -yq torch torchvision pandas
!pip install -q torch==2.3.1 torchvision==0.18.1 pandas==2.0.3
!pip install -qU langchain langchain-community faiss-cpu kuzu pyvis
!pip install -qU sentence-transformers networkx pydantic
!pip install -qU langchain-groq apify_client langgraph python-dotenv

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m779.1/779.1 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m43.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.3/12.3 MB[0m [31m59.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.3/21.3 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.4/50.4 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m990.3/990.3 kB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m67.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m27.0/27.0 MB[0m [31m54.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

**Imports**

In [None]:
import os
import logging
from typing import List, Dict, Any

import networkx as nx
from pyvis.network import Network
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain
from langchain.schema import SystemMessage, HumanMessage
from langchain_groq import ChatGroq
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.tools import BaseTool
from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain_core.messages import BaseMessage
from langchain.output_parsers import PydanticOutputParser

from langgraph.graph import END, StateGraph

from apify_client import ApifyClient
from pydantic import BaseModel, Field

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

**Embedding and LLM Initialization**

In [None]:
# Initialize HuggingFace embeddings
model_name = "BAAI/bge-small-en"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": True}
embeddings = HuggingFaceBgeEmbeddings(
    model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
)

# Initialize Llama-3.1 from Meta using Groq LPU Inference
llm = ChatGroq(
    temperature=0,
    model="llama-3.1-70b-versatile",
    api_key=os.environ["GROQ_API_KEY"],  # read the key from the environment; never hard-code secrets
)

system = "You are a helpful assistant."
human = "{text}"
prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])

chain = prompt | llm
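
Both the Groq and Apify clients need credentials. A minimal sketch of reading them from environment variables rather than from the notebook source is shown below; `require_env` and the variable name `EXAMPLE_API_KEY` are hypothetical.

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable or raise a clear error."""
    value = os.environ.get(name, "").strip()
    if not value:
        raise RuntimeError(
            f"Set the {name} environment variable before running this notebook."
        )
    return value


# Demonstration with a dummy value; in practice the variable is set outside the notebook.
os.environ["EXAMPLE_API_KEY"] = "dummy-value"
print(require_env("EXAMPLE_API_KEY"))
```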

**Knowledge Graph Implementation**

In [None]:
class KnowledgeGraph:
    def __init__(self):
        self.graph = nx.Graph()

    def add_entity(self, entity: str, entity_type: str):
        self.graph.add_node(entity, type=entity_type)

    def add_relation(self, entity1: str, entity2: str, relation: str):
        self.graph.add_edge(entity1, entity2, relation=relation)

    def get_related_entities(self, entity: str) -> List[Dict[str, str]]:
        # Unknown entities have no neighbors; networkx would raise on them.
        if entity not in self.graph:
            return []
        related = []
        for neighbor in self.graph.neighbors(entity):
            edge_data = self.graph.get_edge_data(entity, neighbor)
            related.append({
                "entity": neighbor,
                "relation": edge_data["relation"]
            })
        return related

    def visualize(self, output_file: str = "knowledge_graph.html"):
        net = Network(notebook=True, width="100%", height="500px")
        for node, node_data in self.graph.nodes(data=True):
            net.add_node(node, label=node, title=f"Type: {node_data['type']}")
        for edge in self.graph.edges(data=True):
            net.add_edge(edge[0], edge[1], title=edge[2]['relation'])
        net.show(output_file)

# Initialize knowledge graph
kg = KnowledgeGraph()

**Data Collection Functions**

In [None]:
apify_client = ApifyClient("YOUR_APIFY_API_TOKEN")

def scrape_websites(urls: List[str]) -> List[str]:
    """Scrape content from given websites using Apify."""
    logger.info(f"Scraping {len(urls)} websites...")
    run_input = {
        "startUrls": [{"url": url} for url in urls],
        "maxCrawlPages": 10,
        "maxCrawlDepth": 1,
    }
    try:
        run = apify_client.actor("apify/website-content-crawler").call(run_input=run_input)
        dataset_items = apify_client.dataset(run["defaultDatasetId"]).list_items().items
        scraped_content = [item.get('text', '') for item in dataset_items if 'text' in item]
        logger.info(f"Successfully scraped {len(scraped_content)} pages.")
        return scraped_content
    except Exception as e:
        logger.error(f"Error scraping websites: {str(e)}")
        return []

def fetch_scraped_tweets(query: str, max_tweets: int = 100) -> List[Dict[str, Any]]:
    """Fetch tweets related to cybersecurity using Apify."""
    logger.info(f"Fetching tweets for query: {query}")
    actor_input = {
        "queries": [query],
        "maxTweets": max_tweets
    }
    try:
        run = apify_client.actor("apidojo/tweet-scraper").call(run_input=actor_input)
        dataset_id = run["defaultDatasetId"]
        items = apify_client.dataset(dataset_id).list_items().items
        logger.info(f"Fetched {len(items)} tweets.")
        return items
    except Exception as e:
        logger.error(f"Error fetching tweets: {str(e)}")
        return []

# Cybersecurity-specific websites
websites = [
    "https://www.cisa.gov/uscert/ncas/alerts",
    "https://www.virustotal.com/gui/home/upload",
    "https://attack.mitre.org/",
    "https://www.darkreading.com/",
    "https://threatpost.com/",
]

# Scrape websites
scraped_content = scrape_websites(websites)

# Fetch tweets
tweets = fetch_scraped_tweets("#cybersecurity")
tweet_content = [tweet.get('full_text', '') for tweet in tweets]

# Combine scraped content and tweets
all_content = scraped_content + tweet_content

**Vector Store and Retriever Setup Functions**

In [None]:
def create_vectorstore(texts: List[str]) -> FAISS:
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    documents = text_splitter.create_documents(texts)
    return FAISS.from_documents(documents, embeddings)

def setup_retriever(vectorstore: FAISS) -> ContextualCompressionRetriever:
    base_retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 4})
    compressor = LLMChainExtractor.from_llm(llm)
    return ContextualCompressionRetriever(base_compressor=compressor, base_retriever=base_retriever)

# Create vector store and retriever
vectorstore = create_vectorstore(all_content)
retriever = setup_retriever(vectorstore)
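
The `chunk_size` and `chunk_overlap` settings above control how much text each chunk holds and how much neighboring chunks share. The sketch below is a fixed-width simplification meant only to show what overlap does; `RecursiveCharacterTextSplitter` itself splits on separators and behaves differently. `sliding_chunks` is a hypothetical helper.

```python
from typing import List


def sliding_chunks(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """Cut text into fixed-width windows that share chunk_overlap characters."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        # Stop once the current window already reaches the end of the text.
        if start + chunk_size >= len(text):
            break
    return chunks


chunks = sliding_chunks("abcdefghij", chunk_size=4, chunk_overlap=2)
print(chunks)  # ['abcd', 'cdef', 'efgh', 'ghij']
```

Each chunk repeats the last two characters of the previous one, so a sentence that straddles a boundary still appears whole in at least one chunk when the overlap is large enough.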

**Pydantic Models for Structured Output**

In [None]:
class ThreatAnalysis(BaseModel):
    threat_type: str = Field(description="Type of cybersecurity threat")
    severity: str = Field(description="Severity level of the threat (Low, Medium, High, Critical)")
    description: str = Field(description="Brief description of the threat")
    potential_impact: str = Field(description="Potential impact on organizations")
    mitigation_steps: List[str] = Field(description="List of steps to mitigate the threat")

class VulnerabilityAssessment(BaseModel):
    vulnerability_name: str = Field(description="Name or identifier of the vulnerability")
    affected_systems: List[str] = Field(description="List of affected systems or software")
    cvss_score: float = Field(description="CVSS score of the vulnerability")
    description: str = Field(description="Brief description of the vulnerability")
    remediation_steps: List[str] = Field(description="List of steps to remediate the vulnerability")

class SecurityRecommendation(BaseModel):
    recommendation: str = Field(description="Security recommendation")
    priority: str = Field(description="Priority level (Low, Medium, High)")
    implementation_difficulty: str = Field(description="Difficulty of implementation (Easy, Moderate, Complex)")
    expected_impact: str = Field(description="Expected impact of implementing the recommendation")

**Specialized Agent Tools**

In [None]:
class ThreatAnalyzerTool(BaseTool):
    # BaseTool is a Pydantic model, so overridden fields need type annotations.
    # Names without spaces are also valid function names for tool-calling models.
    name: str = "threat_analyzer"
    description: str = "Analyzes cybersecurity threats and provides detailed information"

    def _run(self, query: str) -> ThreatAnalysis:
        parser = PydanticOutputParser(pydantic_object=ThreatAnalysis)
        prompt = ChatPromptTemplate.from_messages([
            ("system", "You are a cybersecurity threat analyst. Provide a detailed analysis of the given threat.\n{format_instructions}"),
            ("human", "{query}"),
        ]).partial(format_instructions=parser.get_format_instructions())
        # Pipe prompt -> LLM -> parser so the reply is parsed into ThreatAnalysis.
        return (prompt | llm | parser).invoke({"query": query})

class VulnerabilityAssessorTool(BaseTool):
    name: str = "vulnerability_assessor"
    description: str = "Assesses cybersecurity vulnerabilities and provides detailed information"

    def _run(self, query: str) -> VulnerabilityAssessment:
        parser = PydanticOutputParser(pydantic_object=VulnerabilityAssessment)
        prompt = ChatPromptTemplate.from_messages([
            ("system", "You are a vulnerability assessment specialist. Provide a detailed assessment of the given vulnerability.\n{format_instructions}"),
            ("human", "{query}"),
        ]).partial(format_instructions=parser.get_format_instructions())
        return (prompt | llm | parser).invoke({"query": query})

class SecurityAdvisorTool(BaseTool):
    name: str = "security_advisor"
    description: str = "Provides security recommendations based on current threats and vulnerabilities"

    def _run(self, query: str) -> SecurityRecommendation:
        parser = PydanticOutputParser(pydantic_object=SecurityRecommendation)
        prompt = ChatPromptTemplate.from_messages([
            ("system", "You are a cybersecurity advisor. Provide a detailed security recommendation based on the given context.\n{format_instructions}"),
            ("human", "{query}"),
        ]).partial(format_instructions=parser.get_format_instructions())
        return (prompt | llm | parser).invoke({"query": query})

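class KnowledgeGraphQueryTool(BaseTool):
    name: str = "knowledge_graph_query"
    description: str = "Queries the knowledge graph for related information"
    # Declare the graph as a field: BaseTool is a Pydantic model and rejects
    # attributes that are assigned without being declared.
    kg: Any

    def __init__(self, kg: KnowledgeGraph):
        super().__init__(kg=kg)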

    def _run(self, query: str) -> str:
        entities = extract_entities(query)
        results = []
        for entity in entities:
            related = self.kg.get_related_entities(entity)
            results.extend([f"{entity} is related to {r['entity']} via {r['relation']}" for r in related])
        return "\n".join(results)

**Helper Functions**

In [None]:
def extract_entities(text: str) -> List[str]:
    # This is a placeholder. In a real-world scenario, you'd use a named entity recognition model.
    return [word.strip() for word in text.split() if len(word) > 5]

def update_knowledge_graph(kg: KnowledgeGraph, text: str):
    entities = extract_entities(text)
    for i, entity in enumerate(entities):
        kg.add_entity(entity, "Concept")
        if i > 0:
            kg.add_relation(entities[i-1], entity, "related_to")

# Update knowledge graph with initial content
for text in all_content:
    update_knowledge_graph(kg, text)
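
The placeholder extractor above keeps every word longer than five characters, which fills the graph with ordinary vocabulary. As a hedged intermediate step before adopting a real named entity recognition model, a pattern-based extractor could pick out CVE identifiers and names from a small watchlist. `extract_security_entities`, the watchlist, and the sample text are all hypothetical.

```python
import re
from typing import List

CVE_PATTERN = re.compile(r"\bCVE-\d{4}-\d{4,}\b", re.IGNORECASE)
# Illustrative watchlist only; a real deployment would load names from a curated source.
KNOWN_ACTORS = ["LockBit", "BlackBasta"]


def extract_security_entities(text: str) -> List[str]:
    """Return CVE identifiers and watchlist names found in the text, without duplicates."""
    found = [match.upper() for match in CVE_PATTERN.findall(text)]
    lowered = text.lower()
    found.extend(actor for actor in KNOWN_ACTORS if actor.lower() in lowered)
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(found))


# Made-up sample text; the CVE number is deliberately fictional.
sample = "Report mentions Lockbit and CVE-2099-0001; cve-2099-0001 appears twice."
entities = extract_security_entities(sample)
print(entities)  # ['CVE-2099-0001', 'LockBit']
```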

**LangGraph Nodes**

In [None]:
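from typing import TypedDict

# langgraph.graph does not export an AgentState class, so the shared state
# schema is defined here explicitly.
class AgentState(TypedDict):
    messages: List[BaseMessage]
    kg: KnowledgeGraph
    retriever: ContextualCompressionRetriever

def retriever_node(state: AgentState) -> AgentState:
    # Graph nodes receive only the state, so the query is taken from the latest message.
    query = state["messages"][-1].content
    relevant_docs = state["retriever"].invoke(query)
    context = "\n".join([doc.page_content for doc in relevant_docs])
    state["messages"].append(HumanMessage(content=f"Context: {context}\n\nQuery: {query}"))
    return state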

def knowledge_graph_node(state: AgentState) -> AgentState:
    query = state["messages"][-1].content
    kg_tool = KnowledgeGraphQueryTool(state["kg"])
    kg_info = kg_tool._run(query)
    state["messages"].append(HumanMessage(content=f"Knowledge Graph Information:\n{kg_info}"))
    return state

def threat_analysis_node(state: AgentState) -> AgentState:
    query = state["messages"][-1].content
    threat_tool = ThreatAnalyzerTool()
    analysis = threat_tool._run(query)
    state["messages"].append(HumanMessage(content=f"Threat Analysis:\n{analysis.json()}"))
    return state

def vulnerability_assessment_node(state: AgentState) -> AgentState:
    query = state["messages"][-1].content
    vuln_tool = VulnerabilityAssessorTool()
    assessment = vuln_tool._run(query)
    state["messages"].append(HumanMessage(content=f"Vulnerability Assessment:\n{assessment.json()}"))
    return state

def security_recommendation_node(state: AgentState) -> AgentState:
    query = state["messages"][-1].content
    sec_tool = SecurityAdvisorTool()
    recommendation = sec_tool._run(query)
    state["messages"].append(HumanMessage(content=f"Security Recommendation:\n{recommendation.json()}"))
    return state

def agent_node(state: AgentState) -> AgentState:
    tools = [
        ThreatAnalyzerTool(),
        VulnerabilityAssessorTool(),
        SecurityAdvisorTool(),
        KnowledgeGraphQueryTool(state["kg"])
    ]

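    # create_openai_functions_agent expects a prompt template with an
    # "agent_scratchpad" placeholder rather than a plain string. The system text
    # also asks for the "FINAL RESPONSE:" prefix that should_continue checks for.
    agent_prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a cybersecurity expert assistant. "
                   "Analyze the given information and provide a comprehensive response to the query. "
                   "Begin your final answer with 'FINAL RESPONSE:'."),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),
    ])
    agent = create_openai_functions_agent(llm, tools, agent_prompt)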

    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    response = agent_executor.invoke({"input": state["messages"][-1].content})
    state["messages"].append(HumanMessage(content=response["output"]))
    return state

def should_continue(state: AgentState) -> str:
    last_message = state["messages"][-1].content
    if "FINAL RESPONSE:" in last_message:
        return "end"
    return "continue"
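
Because the loop ends only when the agent emits the `FINAL RESPONSE:` marker, a reply without it would send the graph back to retrieval indefinitely. A minimal sketch of a capped variant is given below. It assumes a hypothetical `iterations` counter in the state that a node would increment on each pass, and it uses plain strings for messages to stay self-contained.

```python
from typing import Any, Dict

MAX_ITERATIONS = 3  # hypothetical cap; tune to the workload


def should_continue_capped(state: Dict[str, Any]) -> str:
    """Stop on the completion marker or once the loop budget is used up."""
    last_message = state["messages"][-1]
    if "FINAL RESPONSE:" in last_message:
        return "end"
    if state.get("iterations", 0) >= MAX_ITERATIONS:
        return "end"
    return "continue"


print(should_continue_capped({"messages": ["draft answer"], "iterations": 1}))
```

LangGraph also enforces its own recursion limit at run time, so a cap like this mainly makes the stopping rule explicit in the graph itself.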

**Main Workflow**

In [None]:
workflow = StateGraph(AgentState)

# Define nodes. The retrieval node is named "retrieve" because LangGraph does
# not allow a node name that is also a state key ("retriever").
workflow.add_node("retrieve", retriever_node)
workflow.add_node("knowledge_graph", knowledge_graph_node)
workflow.add_node("threat_analysis", threat_analysis_node)
workflow.add_node("vulnerability_assessment", vulnerability_assessment_node)
workflow.add_node("security_recommendation", security_recommendation_node)
workflow.add_node("agent", agent_node)

# Define edges
workflow.add_edge("retrieve", "knowledge_graph")
workflow.add_edge("knowledge_graph", "threat_analysis")
workflow.add_edge("threat_analysis", "vulnerability_assessment")
workflow.add_edge("vulnerability_assessment", "security_recommendation")
workflow.add_edge("security_recommendation", "agent")

# Add conditional edges
workflow.add_conditional_edges(
    "agent",
    should_continue,
    {
        "continue": "retrieve",
        "end": END
    }
)

# Set entry point
workflow.set_entry_point("retrieve")

# Compile the workflow
app = workflow.compile()

**Initialize Agent State and Run Workflow**

In [None]:
# Initialize agent state
initial_state = AgentState(
    messages=[HumanMessage(content="What are the latest cybersecurity threats and vulnerabilities?")],
    kg=kg,
    retriever=retriever
)

# Run the workflow
final_state = app.invoke(initial_state)

# Print the final response
print("Final Response:")
print(final_state["messages"][-1].content)

# Visualize the knowledge graph
kg.visualize("cybersecurity_knowledge_graph.html")

# Optional: Print additional information or analysis
print("\nKnowledge Graph Statistics:")
print(f"Number of entities: {len(kg.graph.nodes)}")
print(f"Number of relationships: {len(kg.graph.edges)}")

print("\nMost connected entities:")
sorted_nodes = sorted(kg.graph.degree, key=lambda x: x[1], reverse=True)[:5]
for node, degree in sorted_nodes:
    print(f"{node}: {degree} connections")

print("\nSample relationships:")
for i, (node1, node2, data) in enumerate(kg.graph.edges(data=True)):
    if i >= 5:  # Print only first 5 relationships
        break
    print(f"{node1} is {data['relation']} {node2}")

# Optional: Save the collected data for future use
import json

with open("collected_data.json", "w") as f:
    json.dump({
        "scraped_content": scraped_content,
        "tweets": tweets
    }, f)

print("\nData collection and analysis complete. Results saved to 'cybersecurity_knowledge_graph.html' and 'collected_data.json'.")