<a href="https://colab.research.google.com/github/KaifAhmad1/code-test/blob/main/KG_Enhanced_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Problem Statement

#### Task
Develop a co-pilot for threat researchers, security analysts, and other security professionals that addresses the limitations of general-purpose AI tools such as ChatGPT and Perplexity.

#### Current Challenges
1. **Generic Data**: Existing AI solutions provide generic information that lacks specificity.
2. **Context Understanding**: These solutions fail to understand and maintain context.
3. **Limited Information**: The data sources are often limited and not comprehensive.
4. **Single Source Dependency**: Relying on a single source of information reduces reliability and accuracy.
5. **Inadequate AI Models**: Current models do not meet the specialized needs of cybersecurity professionals.

#### Requirement
Create a chatbot capable of collecting and curating data from multiple sources, starting with search engines, and expanding to website crawling and Twitter scraping.

#### Features Required

##### User Interface (UI)
- Chat UI with file upload capabilities.
- Options to save and select prompts.
- Configuration settings for connectors with enable/disable toggles.
- Interface for configuring knowledge and variables (similar to Dify.ai).

##### Technical Specifications
- **No Hallucinations**: Ensure the chatbot provides accurate and reliable information.
- **RAG (Retrieval-Augmented Generation)**: Use RAG to determine which connectors to use based on user inputs.
- **Query Chunking and Distribution**: Optimize the process of breaking down queries and distributing them across different sources.
- **Data Curation Steps**:
  1. Collect links from approximately 50 sources.
  2. Aggregate data from websites and Twitter.
  3. Curate data using a knowledge graph to find relationships and generate responses.
- **Chatbot Capabilities**: Answer queries such as:
  - "List all details on {{BFSI}} security incidents in {{India}}."
  - "List all ransomware attacks targeting the healthcare industry in {{last 7 days/last 3 months/last week/last month}}."
  - "Provide recent incidents related to Lockbit Ransomware gang / BlackBasta Ransomware."
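
The query templates above combine a free-text slot with a relative time window. A minimal sketch of how such a template could be filled and its window turned into concrete dates is shown below. The placeholder names (`sector`, `country`), the `WINDOW_DAYS` table, and the 30-day month approximation are assumptions for illustration and are not part of the requirement.

```python
from datetime import date, timedelta
from typing import Dict, Tuple

# Hypothetical mapping from the template phrases to a number of days.
# Months are approximated as 30 days; this is an assumption, not a spec.
WINDOW_DAYS = {
    "last 7 days": 7,
    "last week": 7,
    "last month": 30,
    "last 3 months": 90,
}


def resolve_window(phrase: str, today: date) -> Tuple[date, date]:
    """Return (start, end) dates for a supported time-window phrase."""
    key = phrase.strip().lower()
    if key not in WINDOW_DAYS:
        raise ValueError(f"Unsupported time window: {phrase!r}")
    return today - timedelta(days=WINDOW_DAYS[key]), today


def fill_template(template: str, values: Dict[str, str]) -> str:
    """Replace each {{name}} placeholder with the supplied value."""
    for name, value in values.items():
        template = template.replace("{{" + name + "}}", value)
    return template


query = fill_template(
    "List all details on {{sector}} security incidents in {{country}}.",
    {"sector": "BFSI", "country": "India"},
)
start, end = resolve_window("last 7 days", date(2024, 8, 15))
```

Resolving the window up front lets every connector receive the same explicit date range instead of interpreting "last week" on its own.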

#### Source Tools

##### Website Crawling and Scraping
- [Firecrawl](https://www.firecrawl.dev/playground)
- [Crawl4AI](https://github.com/unclecode/crawl4ai)
- [Apify](https://apify.com/apify/website-content-crawler)
- [Exa](https://exa.ai/search)

##### Twitter Sources
- [Apify Tweet Scraper](https://apify.com/apidojo/tweet-scraper)
- [Twitter API](https://developer.x.com/en/docs/twitter-api)

##### Development Tools
- [Flowise AI](https://flowiseai.com/)
- [Langgenius Dify](https://github.com/langgenius/dify)

#### Goal
Develop a data collector that integrates multiple specific sources to enrich the knowledge base, enabling the model to better understand context and deliver accurate results. The solution should be modular, allowing customization and configuration of sources.
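
One way to read the modularity requirement is a small registry in which each source is a connector that can be switched on or off. The sketch below is an assumption about how that could look; `Connector`, `ConnectorRegistry`, and the stub fetch functions are hypothetical names, and real connectors would wrap the crawling and scraping tools listed above.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class Connector:
    """A named data source with an enable/disable toggle."""
    name: str
    fetch: Callable[[str], List[str]]
    enabled: bool = True


class ConnectorRegistry:
    def __init__(self) -> None:
        self._connectors: Dict[str, Connector] = {}

    def register(self, connector: Connector) -> None:
        self._connectors[connector.name] = connector

    def set_enabled(self, name: str, enabled: bool) -> None:
        self._connectors[name].enabled = enabled

    def collect(self, query: str) -> Dict[str, List[str]]:
        """Run the query against every enabled connector."""
        return {
            c.name: c.fetch(query)
            for c in self._connectors.values()
            if c.enabled
        }


# Stub connectors standing in for real crawlers and scrapers.
registry = ConnectorRegistry()
registry.register(Connector("web", lambda q: [f"web result for {q}"]))
registry.register(Connector("twitter", lambda q: [f"tweet about {q}"]))
registry.set_enabled("twitter", False)
results = registry.collect("LockBit")
```

Keeping the toggle on the connector itself means the UI setting maps directly onto one field, and disabled sources are never called.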

#### Summary
The goal is to build a modular chatbot for cybersecurity professionals that overcomes the limitations of existing AI solutions by integrating multiple data sources and producing context-aware, accurate responses. The chatbot will use RAG and a knowledge graph to provide curated information from diverse sources.


In [1]:
# Uninstall conflicting packages
!pip uninstall -yq torch torchvision pandas

# Install specific versions
!pip install -q torch==2.3.1 torchvision==0.18.1 pandas==2.0.3

# Reinstall the other libraries
!pip install -qU langchain langchain-community faiss-cpu kuzu pyvis
!pip install -qU sentence-transformers plotly scikit-learn networkx
!pip install -qU langchain-groq apify_client langgraph python-dotenv


In [2]:
import os
import logging
import json
from typing import List, Dict, Any

import kuzu
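# Third-party integrations now live in langchain_community; the old langchain.* paths are deprecated
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS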
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate, ChatPromptTemplate
from langchain.chains import LLMChain
import networkx as nx
import plotly.graph_objects as go
import plotly.express as px
from sklearn.manifold import TSNE
import pandas as pd
from apify_client import ApifyClient
from langgraph.graph import Graph, END
from langgraph.prebuilt import ToolInvocation



In [3]:
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [4]:
# Initialize HuggingFace embeddings
model_name = "BAAI/bge-small-en"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings": True}
embeddings = HuggingFaceBgeEmbeddings(
    model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
)

# Initialize Llama-3.1 from Meta using Groq LPU Inference
# Read the API key from the environment instead of hard-coding it in the notebook
llm = ChatGroq(
    temperature=0,
    model="llama-3.1-70b-versatile",
    api_key=os.environ["GROQ_API_KEY"]
)

system = "You are a helpful assistant."
human = "{text}"
prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])

chain = prompt | llm


In [5]:
# Initialize Apify client
# APIFY_API_TOKEN is an assumed environment variable name; set it before running this cell
apify_client = ApifyClient(os.environ["APIFY_API_TOKEN"])

In [6]:
# Cybersecurity-specific websites
websites = [
    "https://www.cisa.gov/uscert/ncas/alerts",
    "https://www.virustotal.com/gui/home/upload",
    "https://attack.mitre.org/",
    "https://www.darkreading.com/",
    "https://threatpost.com/",
]

In [7]:
def scrape_websites(urls: List[str]) -> List[str]:
    """
    Scrape content from given websites using Apify.

    Args:
        urls (List[str]): List of URLs to scrape.

    Returns:
        List[str]: List of scraped text content.
    """
    logger.info(f"Scraping {len(urls)} websites...")
    run_input = {
        "startUrls": [{"url": url} for url in urls],
        "maxCrawlPages": 10,
        "maxCrawlDepth": 1,
    }

    try:
        run = apify_client.actor("apify/website-content-crawler").call(run_input=run_input)
        dataset_items = apify_client.dataset(run["defaultDatasetId"]).list_items().items
        scraped_content = [item.get('text', '') for item in dataset_items if 'text' in item]
        logger.info(f"Successfully scraped {len(scraped_content)} pages.")
        return scraped_content
    except Exception as e:
        logger.error(f"Error scraping websites: {str(e)}")
        return []

# Example usage
documents = scrape_websites(websites)
print(documents)

['Cybersecurity Alerts & Advisories | CISALockProblem loading page\nCybersecurity Advisory: In-depth reports covering a specific cybersecurity issue, often including threat actor tactics, techniques, and procedures; indicators of compromise; and mitigations.\nAlert: Concise summaries covering cybersecurity topics, such as mitigations that vendors have published for vulnerabilities in their products.\nICS Advisory: Concise summaries covering industrial control system (ICS) cybersecurity topics, primarily focused on mitigations that ICS vendors have published for vulnerabilities in their products.\nICS Medical Advisory: Concise summaries covering ICS medical cybersecurity topics, primarily focused on mitigations that ICS medical vendors have published for vulnerabilities in their products.\nAnalysis Report: In-depth analysis of a new or evolving cyber threat, including technical details and remediations.', 'VirusTotal - HomereCAPTCHA\nreCAPTCHA \nPrivacyTerms\nprotected by reCAPTCHA\nPri
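
The scraped output above includes a page that is mostly a reCAPTCHA challenge rather than article text. A minimal sketch of a post-scrape filter is shown below; the `filter_pages` helper, the marker list, and both length thresholds are assumptions chosen for illustration, not values from the original pipeline.

```python
from typing import List

# Hypothetical markers of pages that returned a challenge instead of content.
CHALLENGE_MARKERS = ("recaptcha",)


def filter_pages(
    pages: List[str],
    min_chars: int = 200,
    challenge_max_chars: int = 500,
) -> List[str]:
    """Drop pages that are very short or look like a challenge page."""
    kept = []
    for page in pages:
        text = page.strip()
        lowered = text.lower()
        if len(text) < min_chars:
            continue  # too little text to be useful
        looks_like_challenge = any(m in lowered for m in CHALLENGE_MARKERS)
        if looks_like_challenge and len(text) < challenge_max_chars:
            continue  # short page dominated by a challenge widget
        kept.append(text)
    return kept


sample_pages = [
    "VirusTotal - Home reCAPTCHA protected by reCAPTCHA",
    "reCAPTCHA " + "x" * 300,
    "Advisory: " + "details " * 100,
]
clean_pages = filter_pages(sample_pages)
```

A long page that merely mentions a marker is kept, so genuine articles about CAPTCHA bypasses are not discarded.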

In [8]:
def fetch_scraped_tweets(query: str, max_tweets: int = 100) -> List[Dict[str, Any]]:
    """
    Fetch tweets related to cybersecurity using Apify.

    Args:
        query (str): Search query for tweets.
        max_tweets (int): Maximum number of tweets to fetch.

    Returns:
        List[Dict[str, Any]]: List of tweet data.
    """
    logger.info(f"Fetching tweets for query: {query}")
    actor_input = {
        "queries": [query],
        "maxTweets": max_tweets
    }

    try:
        run = apify_client.actor("apidojo/tweet-scraper").call(run_input=actor_input)
        dataset_id = run["defaultDatasetId"]
        items = apify_client.dataset(dataset_id).list_items().items
        logger.info(f"Fetched {len(items)} tweets.")
        return items
    except Exception as e:
        logger.error(f"Error fetching tweets: {str(e)}")
        return []

# Example usage
tweets = fetch_scraped_tweets("#cybersecurity")
print(tweets)

[]


In [9]:
# Combine all texts
all_texts = documents + [tweet.get('full_text', '') for tweet in tweets]

# Split texts into chunks
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = text_splitter.split_text("\n\n".join(all_texts))

# Create a vector store
vectorstore = FAISS.from_texts(texts, embeddings)
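
To make the `chunk_size` and `chunk_overlap` parameters concrete, here is a simplified sliding-window chunker. It is only an illustration of the overlap idea: `CharacterTextSplitter` first splits on a separator and then merges pieces, so its chunk boundaries will differ from this sketch. The `sliding_chunks` name is hypothetical.

```python
from typing import List


def sliding_chunks(text: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> List[str]:
    """Split text into fixed-size windows that share chunk_overlap characters."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        # Stop once a window has reached the end of the text.
        if start + chunk_size >= len(text):
            break
    return chunks


demo_chunks = sliding_chunks("abcdefghij", chunk_size=4, chunk_overlap=2)
```

Each window repeats the last `chunk_overlap` characters of the previous one, which helps a sentence that straddles a boundary appear intact in at least one chunk.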



In [10]:
# Initialize Kuzu DB
db = kuzu.Database("cybersecurity_knowledge_graph")
conn = kuzu.Connection(db)

In [11]:
def initialize_knowledge_graph():
    """Initialize the knowledge graph schema."""
    try:
        conn.execute("CREATE NODE TABLE Entity (name STRING, type STRING, PRIMARY KEY (name))")
        conn.execute("CREATE REL TABLE Relation (FROM Entity TO Entity, predicate STRING)")
        logger.info("Knowledge graph schema initialized successfully.")
    except Exception as e:
        logger.error(f"Error initializing knowledge graph: {str(e)}")

# Example usage
initialize_knowledge_graph()

In [12]:
def update_knowledge_graph(triplets: List[tuple]):
    """
    Update the knowledge graph with new triplets.

    Args:
        triplets (List[tuple]): List of (subject, predicate, object) triplets.
    """
    for subject, predicate, obj in triplets:
        try:
            # Kuzu is queried with Cypher, not SQL: MERGE creates an entity only if it is missing
            for name in (subject, obj):
                conn.execute(
                    "MERGE (e:Entity {name: $name}) ON CREATE SET e.type = $entity_type",
                    {"name": name, "entity_type": "Cybersecurity_Entity"},
                )
            # Link the two entities with a Relation edge carrying the predicate
            conn.execute(
                "MATCH (s:Entity {name: $subject}), (o:Entity {name: $object}) "
                "CREATE (s)-[:Relation {predicate: $predicate}]->(o)",
                {"subject": subject, "object": obj, "predicate": predicate},
            )
        except Exception as e:
            logger.error(f"Error updating knowledge graph: {str(e)}")

In [13]:
def extract_triplets(text: str) -> List[tuple]:
    """
    Extract knowledge triplets from text using the LLM.

    Args:
        text (str): Input text to extract triplets from.

    Returns:
        List[tuple]: List of extracted (subject, predicate, object) triplets.
    """
    kg_triple_extract_template = """
    Extract up to 5 cybersecurity-related knowledge triplets from the text below in the form (subject, predicate, object).
    Focus on threats, vulnerabilities, attack techniques, and security measures.
    Text: {text}
    Triplets:
    """
    kg_triple_extract_prompt = PromptTemplate(
        input_variables=["text"],
        template=kg_triple_extract_template,
    )
    kg_triple_extract_chain = LLMChain(llm=llm, prompt=kg_triple_extract_prompt)

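    try:
        import ast  # local import so this cell runs on its own

        result = kg_triple_extract_chain.invoke({"text": text})
        triplets = []
        for line in result['text'].split('\n'):
            line = line.strip()
            if not line:
                continue
            try:
                # literal_eval accepts only Python literals, so model output cannot execute code
                parsed = ast.literal_eval(line)
            except (ValueError, SyntaxError):
                logger.warning(f"Skipping unparseable line: {line!r}")
                continue
            # Keep only well-formed (subject, predicate, object) tuples
            if isinstance(parsed, tuple) and len(parsed) == 3:
                triplets.append(parsed)
        return triplets
    except Exception as e:
        logger.error(f"Error extracting triplets: {str(e)}")
        return []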
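
Language models often return triplets without quotes, for example `(LockBit, targets, healthcare)`, which a Python-literal parser rejects. The sketch below shows a more tolerant, regex-based alternative. `parse_triplets` and `TRIPLET_RE` are hypothetical names, and the pattern assumes subjects and predicates contain no commas or parentheses.

```python
import re
from typing import List, Tuple

# Three comma-separated fields inside one pair of parentheses.
# Only the last field (the object) may itself contain commas.
TRIPLET_RE = re.compile(r"\(([^(),]+),([^(),]+),([^()]+)\)")


def parse_triplets(raw: str) -> List[Tuple[str, str, str]]:
    """Extract (subject, predicate, object) triplets from free-form model output."""
    triplets = []
    for match in TRIPLET_RE.finditer(raw):
        # Trim whitespace and any surrounding quotes from each field.
        subject, predicate, obj = (
            part.strip().strip("'\"").strip() for part in match.groups()
        )
        if subject and predicate and obj:
            triplets.append((subject, predicate, obj))
    return triplets


sample_output = (
    "1. (LockBit, targets, healthcare)\n"
    "('phishing', 'exploits', 'user trust')\n"
    "This line has no triplet."
)
parsed_triplets = parse_triplets(sample_output)
```

Because the pattern searches within each line, list numbering or bullets that the model adds in front of a triplet are ignored.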

In [14]:
def get_graph_data():
    """Retrieve graph data from the knowledge graph."""
    try:
        nodes_result = conn.execute("MATCH (e:Entity) RETURN e.name")
        edges_result = conn.execute("MATCH (e1:Entity)-[r:Relation]->(e2:Entity) RETURN e1.name, r.predicate, e2.name")

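        # Kuzu results are read with has_next()/get_next(); each row is a list of column values
        nodes = []
        while nodes_result.has_next():
            nodes.append(nodes_result.get_next()[0])
        edges = []
        while edges_result.has_next():
            edges.append(tuple(edges_result.get_next()))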

        return nodes, edges
    except Exception as e:
        logger.error(f"Error retrieving graph data: {str(e)}")
        return [], []
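
Once edges are available as `(subject, predicate, object)` tuples, they can be indexed by entity so that facts about a given threat actor or technique are easy to look up. This is a minimal in-memory sketch; `build_index` and `facts_about` are hypothetical helpers and do not query Kuzu themselves.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Edge = Tuple[str, str, str]  # (subject, predicate, object)


def build_index(edges: List[Edge]) -> Dict[str, List[Edge]]:
    """Map each entity to every edge in which it appears."""
    index: Dict[str, List[Edge]] = defaultdict(list)
    for subject, predicate, obj in edges:
        index[subject].append((subject, predicate, obj))
        if obj != subject:  # avoid listing a self-loop twice
            index[obj].append((subject, predicate, obj))
    return dict(index)


def facts_about(entity: str, index: Dict[str, List[Edge]]) -> List[str]:
    """Render the edges touching an entity as short sentences."""
    return [f"{s} {p} {o}" for s, p, o in index.get(entity, [])]


demo_edges = [
    ("LockBit", "targets", "healthcare"),
    ("LockBit", "uses", "phishing"),
]
demo_index = build_index(demo_edges)
lockbit_facts = facts_about("LockBit", demo_index)
```

The rendered sentences can be passed to the LLM as context, which ties generated answers back to relationships stored in the graph.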

In [15]:
def visualize_graph_plotly():
    """Visualize the knowledge graph using Plotly."""
    nodes, edges = get_graph_data()
    G = nx.Graph()

    for node in nodes:
        G.add_node(node)

    for edge in edges:
        G.add_edge(edge[0], edge[2], label=edge[1])

    pos = nx.spring_layout(G)

    edge_x = []
    edge_y = []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    node_x = [pos[node][0] for node in G.nodes()]
    node_y = [pos[node][1] for node in G.nodes()]

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            colorscale='YlGnBu',
            reversescale=True,
            color=[],
            size=10,
            colorbar=dict(
                thickness=15,
                # titleside is deprecated; the side is set on the title itself
                title=dict(text='Node Connections', side='right'),
                xanchor='left'
            ),
            line_width=2))

    node_adjacencies = []
    node_text = []
    for node, adjacencies in G.adjacency():
        node_adjacencies.append(len(adjacencies))
        node_text.append(f'{node} (# of connections: {len(adjacencies)})')

    node_trace.marker.color = node_adjacencies
    node_trace.text = node_text

    fig = go.Figure(data=[edge_trace, node_trace],
                    layout=go.Layout(
                        # titlefont_size is deprecated; the font is set on the title itself
                        title=dict(text='Knowledge Graph', font=dict(size=16)),
                        showlegend=False,
                        hovermode='closest',
                        margin=dict(b=20,l=5,r=5,t=40),
                        annotations=[ dict(
                            text="",
                            showarrow=False,
                            xref="paper", yref="paper",
                            x=0.005, y=-0.002 ) ],
                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False))
                    )

    fig.show()

In [16]:
def visualize_embeddings(texts: List[str]):
    """Visualize document embeddings using t-SNE and Plotly."""
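    import numpy as np  # local import so this cell runs on its own

    # Embed all texts in one batch; embed_query is meant for search queries, not documents
    doc_embeddings = np.array(embeddings.embed_documents(texts))
    # t-SNE requires perplexity to be smaller than the number of samples
    perplexity = min(30, len(texts) - 1)
    tsne = TSNE(n_components=2, random_state=0, perplexity=perplexity)
    tsne_results = tsne.fit_transform(doc_embeddings)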

    df = pd.DataFrame(tsne_results, columns=['x', 'y'])
    df['text'] = texts

    fig = px.scatter(df, x='x', y='y', hover_data=['text'], title='Document Embeddings Visualization')
    fig.show()

In [17]:
def threat_analyzer(state: Dict[str, Any]) -> Dict[str, Any]:
    query = "Analyze the latest cybersecurity threats and provide a summary."
    response = chain.invoke({"text": query})
    # Merge into the incoming state so later nodes can still read earlier results;
    # .content extracts the reply text from the returned message object
    return {**state, "threat_analysis": response.content}

def vulnerability_assessor(state: Dict[str, Any]) -> Dict[str, Any]:
    query = "Identify and assess critical vulnerabilities in cybersecurity systems."
    response = chain.invoke({"text": query})
    return {**state, "vulnerability_assessment": response.content}

def security_advisor(state: Dict[str, Any]) -> Dict[str, Any]:
    query = "Provide recommendations for improving cybersecurity based on current threats and vulnerabilities."
    response = chain.invoke({"text": query})
    return {**state, "security_advice": response.content}

def knowledge_graph_updater(state: Dict[str, Any]) -> Dict[str, Any]:
    threat_analysis = state.get("threat_analysis", "")
    vulnerability_assessment = state.get("vulnerability_assessment", "")
    security_advice = state.get("security_advice", "")

    combined_text = f"{threat_analysis}\n{vulnerability_assessment}\n{security_advice}"
    triplets = extract_triplets(combined_text)
    update_knowledge_graph(triplets)

    return {"graph_update": f"Knowledge graph updated with {len(triplets)} new triplets."}

In [18]:
def create_workflow():
    workflow = Graph()

    workflow.add_node("threat_analyzer", threat_analyzer)
    workflow.add_node("vulnerability_assessor", vulnerability_assessor)
    workflow.add_node("security_advisor", security_advisor)
    workflow.add_node("knowledge_graph_updater", knowledge_graph_updater)

    workflow.add_edge("threat_analyzer", "vulnerability_assessor")
    workflow.add_edge("vulnerability_assessor", "security_advisor")
    workflow.add_edge("security_advisor", "knowledge_graph_updater")
    # Mark the last node as the end of the workflow
    workflow.add_edge("knowledge_graph_updater", END)

    workflow.set_entry_point("threat_analyzer")

    return workflow.compile()
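
The workflow above is a linear chain in which a dictionary is handed from one node to the next. The plain-Python sketch below illustrates why it matters whether a node returns only its own key or merges into the incoming state. It assumes each node receives exactly what the previous node returned; `run_pipeline` and the toy nodes are hypothetical and do not use LangGraph.

```python
from typing import Any, Callable, Dict, List

State = Dict[str, Any]


def run_pipeline(nodes: List[Callable[[State], State]], state: State) -> State:
    """Call each node in order, feeding it the previous node's return value."""
    for node in nodes:
        state = node(state)
    return state


def analyze(state: State) -> State:
    return {**state, "analysis": "done"}


def advise_lossy(state: State) -> State:
    # Returns only its own key, so everything produced earlier is dropped.
    return {"advice": "patch"}


def advise_merging(state: State) -> State:
    # Copies the incoming state first, so earlier results survive.
    return {**state, "advice": "patch"}


lossy_result = run_pipeline([analyze, advise_lossy], {})
merged_result = run_pipeline([analyze, advise_merging], {})
```

With the lossy variant a final node that reads `analysis` would find nothing, which is the situation a graph-update step at the end of the chain needs to avoid.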

In [19]:
def run_cybersecurity_workflow():
    """Run the cybersecurity analysis workflow."""
    logger.info("Starting cybersecurity analysis workflow...")
    app = create_workflow()
    for step in app.stream({}, {"recursion_limit": 10}):
        # Each streamed step is a dict keyed by node name;
        # default=str keeps json.dumps from failing on non-JSON values such as message objects
        logger.info(f"Result: {json.dumps(step, indent=2, default=str)}")

    logger.info("Workflow completed. Updating visualizations...")
    visualize_graph_plotly()
    visualize_embeddings(texts)

In [20]:
def query_graph(query: str) -> Any:
    """Send the query to the LLM chain and return its response message.

    Note: this does not consult the knowledge graph or the vector store.
    """
    return chain.invoke({"text": query})
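
To ground answers in the collected data, retrieved text chunks and graph facts can be placed in the prompt ahead of the question. The sketch below only builds that prompt string. `build_grounded_prompt` is a hypothetical helper, and the wording of the instruction is an assumption; in this notebook the chunks could plausibly come from `vectorstore.similarity_search(query, k=3)` and the triplets from the knowledge graph.

```python
from typing import List, Tuple


def build_grounded_prompt(
    question: str,
    chunks: List[str],
    triplets: List[Tuple[str, str, str]],
    max_chunks: int = 3,
) -> str:
    """Assemble a prompt that restricts the answer to the supplied evidence."""
    context = "\n".join(
        f"[{i}] {chunk}" for i, chunk in enumerate(chunks[:max_chunks], start=1)
    )
    facts = "\n".join(f"- {s} {p} {o}" for s, p, o in triplets)
    return (
        "Answer using only the context and facts below. "
        "If they are insufficient, say so.\n\n"
        f"Context:\n{context or '(none)'}\n\n"
        f"Graph facts:\n{facts or '(none)'}\n\n"
        f"Question: {question}"
    )


demo_prompt = build_grounded_prompt(
    "Which sectors does LockBit target?",
    ["chunk one", "chunk two"],
    [("LockBit", "targets", "healthcare")],
    max_chunks=1,
)
```

The resulting string can be passed as the `text` variable of the existing chain, so the model sees the evidence and the question together.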

In [21]:
# Example queries
questions = [
    "What are the latest threats targeting the healthcare industry?",
    "Can you provide details on recent ransomware attacks?",
    "What are the most critical vulnerabilities discovered in the last month?",
    "How can organizations protect against phishing attacks?",
    "What are the emerging trends in cybersecurity for financial institutions?"
]

print("\nExample queries:")
for query in questions:
    answer = query_graph(query)
    print(f"Query: {query}\nAnswer: {answer}\n")


Example queries:
Query: What are the latest threats targeting the healthcare industry?
Answer: content='The healthcare industry is a prime target for cyber threats due to the sensitive nature of the data it handles. Here are some of the latest threats targeting the healthcare industry:\n\n1. **Ransomware attacks**: Ransomware attacks continue to plague the healthcare industry, with attackers encrypting sensitive data and demanding payment in exchange for the decryption key. Recent examples include the attacks on Universal Health Services (UHS) and the Sky Lakes Medical Center.\n2. **Phishing and Business Email Compromise (BEC)**: Phishing attacks are becoming increasingly sophisticated, with attackers using social engineering tactics to trick healthcare employees into divulging sensitive information or clicking on malicious links. BEC attacks, in particular, target healthcare executives and administrators, attempting to trick them into transferring funds or revealing sensitive informa