<a href="https://colab.research.google.com/github/KaifAhmad1/code-test/blob/main/Cyber_Knowledge_Graph_Creation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Installations**

In [None]:
!pip install -q langchain langchain-community langchain-groq pandas networkx pyvis ampligraph transformers relik python-louvain seaborn beautifulsoup4

**Web Data Loading**
- `load_data_from_websites` loads documents from a list of URLs using LangChain's `WebBaseLoader`.

- This is the first step of the pipeline: it fetches the raw page text that the later steps split into chunks, annotate with entities, and turn into a graph.

In [None]:
from langchain_community.document_loaders import WebBaseLoader

def load_data_from_websites(urls):
    """
    Load documents from a list of URLs.

    Args:
        urls (list): List of URLs to load documents from.

    Returns:
        list: List of loaded documents.
    """
    web_base_loader = WebBaseLoader(urls)
    documents = web_base_loader.load()
    print(f"Loaded {len(documents)} documents.")
    return documents

websites = [
    "https://www.scmagazine.com/home/security-news/",
    "https://thehackernews.com/",
    "https://www.securityweek.com/",
    "https://www.darkreading.com/",
    "https://krebsonsecurity.com/",
    "https://www.bleepingcomputer.com/",
    "https://threatpost.com/",
    "https://www.cyberscoop.com/",
    "https://www.infosecurity-magazine.com/",
    "https://www.zdnet.com/topic/security/",
    "https://www.wired.com/category/security/",
    "https://nakedsecurity.sophos.com/",
    "https://www.cisomag.com/",
    "https://www.cshub.com/",
    "https://www.cybersecuritydive.com/",
    "https://www.cybersecurity-insiders.com/",
    "https://www.csoonline.com/",
    "https://www.securitymagazine.com/topics/2236-cyber-security-news",
    "https://www.helpnetsecurity.com/"
]

documents = load_data_from_websites(websites)

**Split Documents into Chunks**
- `chunk_data` splits the loaded documents into smaller, overlapping chunks.
- Chunking keeps each piece of text small enough to be processed independently by the entity-detection and triplet-extraction steps.
- `chunk_size` and `chunk_overlap` are measured in characters, because `length_function=len`.
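
To make the effect of `chunk_size` and `chunk_overlap` concrete, here is a simplified, dependency-free sketch of fixed-size chunking with overlap. It is not how `RecursiveCharacterTextSplitter` works internally (that splitter prefers to break on separators such as paragraph and line breaks); the helper name `sliding_chunks` is hypothetical and only illustrates how consecutive chunks share characters.

```python
def sliding_chunks(text, chunk_size=10, chunk_overlap=3):
    """Split text into fixed-size character windows that overlap by chunk_overlap."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    if not text:
        return []
    step = chunk_size - chunk_overlap
    # Stop once the remaining text is already covered by the previous window.
    last_start = max(len(text) - chunk_overlap, 1)
    return [text[start:start + chunk_size] for start in range(0, last_start, step)]


sample = "abcdefghijklmnopqrstuvwxyz"
pieces = sliding_chunks(sample, chunk_size=10, chunk_overlap=3)
for piece in pieces:
    print(piece)
```

Each chunk begins with the last three characters of the previous one, so an entity mention that straddles a chunk boundary still appears whole in at least one chunk, provided it is no longer than the overlap.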

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

def chunk_data(documents, chunk_size=1000, chunk_overlap=50):
    """
    Split documents into smaller chunks.

    Args:
        documents (list): List of documents to split.
        chunk_size (int): Size of each chunk.
        chunk_overlap (int): Overlap between chunks.

    Returns:
        list: List of document chunks.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )
    chunks = text_splitter.split_documents(documents)
    print(f"Number of chunks created: {len(chunks)}")
    return chunks

chunks = chunk_data(documents)

**Create a DataFrame of Chunks**
- `create_chunks_dataframe` converts the list of chunks into a pandas DataFrame with one row per chunk.
- Each row holds the chunk text, its source URL, and a chunk ID, which the later steps use to tie entities back to the text they came from.

In [None]:
import pandas as pd

def create_chunks_dataframe(chunks):
    """
    Convert list of chunks into a pandas DataFrame.

    Args:
        chunks (list): List of document chunks.

    Returns:
        pd.DataFrame: DataFrame containing the chunks.
    """
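    data = []
    for index, chunk in enumerate(chunks):
        data.append({
            "text": chunk.page_content,
            "source": chunk.metadata.get("source", ""),
            # The loader and splitter do not set "chunk_id", so fall back to the
            # chunk's position. Stored as a string because later steps join IDs.
            "chunk_id": str(chunk.metadata.get("chunk_id", index))
        })
    return pd.DataFrame(data)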

chunks_df = create_chunks_dataframe(chunks)
chunks_df.head()

**Detect and Classify Entities**
- `detect_entities_relik` uses ReLiK's pretrained entity-linking model to detect entity mentions in each text chunk.
- The detected entities are the key concepts of the text and become the nodes of the knowledge graph.
- The result is a DataFrame with one row per detected entity, linking the entity to its source URL and chunk ID.

In [None]:
from relik import Relik
from relik.inference.data.objects import RelikOutput

def detect_entities_relik(chunks_df):
    """
    Detect and classify entities using ReLiK.

    Args:
        chunks_df (pd.DataFrame): DataFrame containing the chunks.

    Returns:
        pd.DataFrame: DataFrame containing the detected entities.
    """
    relik = Relik.from_pretrained("sapienzanlp/relik-entity-linking-large")

    def extract_concepts(text):
        """
        Extract concepts from text using ReLiK.

        Args:
            text (str): Text to extract concepts from.

        Returns:
            list: List of extracted concepts.
        """
        relik_out: RelikOutput = relik(text)
        concepts = [span.text for span in relik_out.spans]
        return concepts

    def create_concepts_list(chunks_df):
        """
        Create a list of concepts from the chunks DataFrame.

        Args:
            chunks_df (pd.DataFrame): DataFrame containing the chunks.

        Returns:
            list: List of concepts.
        """
        concepts_list = []
        for index, row in chunks_df.iterrows():
            concepts = extract_concepts(row['text'])
            for concept in concepts:
                concepts_list.append({
                    "node_1": concept,
                    "node_2": row['source'],
                    "edge": "contains",
                    "chunk_id": row['chunk_id']
                })
        return concepts_list

    concepts_list = create_concepts_list(chunks_df)
    return pd.DataFrame(concepts_list)

concepts_df = detect_entities_relik(chunks_df)
concepts_df.head()

**Initializing LLM for Knowledge Extraction**
- Initialize the Mixtral 8x7B model (`mixtral-8x7b-32768`) through Groq for knowledge extraction.

In [None]:
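import os
from getpass import getpass

from langchain_groq import ChatGroq

# Read the Groq API key from the environment (or prompt for it) instead of
# hard-coding a secret in the notebook.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY") or getpass("Enter your Groq API key: ")

# Initialize the Mixtral 8x7B model served by Groq
llm = ChatGroq(
    temperature=0,
    model="mixtral-8x7b-32768",
    api_key=GROQ_API_KEY
)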

**Extract Triplets using ReLiK**
- Use the ReLiK model for relation extraction.
- Extracting triplets helps in understanding the relationships between entities, which is essential for building a knowledge graph.
- This function is used to extract triplets from the text chunks and create a list of triplets.
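
A triplet is a (subject, predicate, object) statement. As a small sketch of why it can be safer to keep triplets as structured values than to format them into strings, the example below uses a hypothetical `Triplet` named tuple (not part of ReLiK) and invented sample values. An entity that contains a comma survives intact as a field, whereas splitting the formatted string on `", "` no longer yields three parts.

```python
from typing import NamedTuple


class Triplet(NamedTuple):
    """Hypothetical container for one (subject, predicate, object) statement."""
    subject: str
    predicate: str
    obj: str


# Invented sample values, for illustration only.
triplet = Triplet("Acme, Inc.", "company", "Jane Doe")

# Formatting to a string and splitting it back loses the field boundaries.
as_string = f"({triplet.subject}, {triplet.predicate}, {triplet.obj})"
parts = as_string.strip("()").split(", ")

print(len(parts))       # 4 parts instead of 3
print(triplet.subject)  # the structured value is unaffected
```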

In [None]:
def extract_triplets_relik(chunks):
    """
    Extract triplets using ReLiK.

    Args:
        chunks (list): List of document chunks.

    Returns:
        list: List of extracted triplets.
    """
    relik = Relik.from_pretrained("sapienzanlp/relik-relation-extraction-nyt-large")

    def extract_triplets_from_text(text):
        """
        Extract triplets from text using ReLiK.

        Args:
            text (str): Text to extract triplets from.

        Returns:
            list: List of extracted triplets.
        """
        relik_out: RelikOutput = relik(text)
        triplets = [f"({triplet.subject.text}, {triplet.label}, {triplet.object.text})" for triplet in relik_out.triplets]
        return triplets

    all_triplets = []
    for chunk in chunks:
        triplets = extract_triplets_from_text(chunk.page_content)
        all_triplets.extend(triplets)

    return all_triplets

triplets = extract_triplets_relik(chunks)

**Analyze Relationships**
- This function converts the list of triplets into a DataFrame.
- Converting triplets into a DataFrame allows for easier analysis and manipulation of the relationships between entities.
- This function is used to create a structured format for the triplets, which will be used in subsequent steps.

In [None]:
def analyze_relationships(triplets):
    """
    Convert list of triplets into a DataFrame.

    Args:
        triplets (list): List of triplets.

    Returns:
        pd.DataFrame: DataFrame containing the triplets.
    """
    def create_triplets_dataframe(triplets):
        """
        Create a DataFrame from the list of triplets.

        Args:
            triplets (list): List of triplets.

        Returns:
            pd.DataFrame: DataFrame containing the triplets.
        """
        triplet_data = []
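        for triplet in triplets:
            # Triplets arrive as "(subject, predicate, object)" strings. Skip any that
            # do not split into exactly three parts (e.g. an entity containing ", ").
            parts = triplet.strip("()").split(", ")
            if len(parts) != 3:
                continue
            subject, predicate, obj = parts
            triplet_data.append({
                "subject": subject.strip(),
                "predicate": predicate.strip(),
                "object": obj.strip()
            })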
        return pd.DataFrame(triplet_data)

    triplets_df = create_triplets_dataframe(triplets)
    return triplets_df

triplets_df = analyze_relationships(triplets)
triplets_df.head()

**Calculate Contextual Proximity**
- This function calculates the contextual proximity between nodes by merging the DataFrame with itself and counting the occurrences of node pairs within the same chunk.
- Contextual proximity helps in understanding the relationships between entities that appear together in the same context.
- This function is used to create a DataFrame of contextual proximity relationships between entities.
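
The self-merge described above can be sketched without pandas. The example below uses invented `(chunk_id, node)` rows and a hypothetical helper, `cooccurrence_counts`; it pairs every node in a chunk with every other node in the same chunk, counts how often each ordered pair occurs, and then drops pairs seen only once, mirroring the `count != 1` filter in the function.

```python
from collections import Counter
from itertools import product

# Invented (chunk_id, node) rows: the "long format" obtained by melting node_1/node_2.
rows = [
    ("c1", "LockBit"), ("c1", "ransomware"),
    ("c2", "LockBit"), ("c2", "ransomware"), ("c2", "phishing"),
]


def cooccurrence_counts(rows):
    """Count ordered node pairs that appear together in the same chunk."""
    by_chunk = {}
    for chunk_id, node in rows:
        by_chunk.setdefault(chunk_id, []).append(node)

    counts = Counter()
    for nodes in by_chunk.values():
        # Equivalent to merging the rows with themselves on chunk_id.
        for a, b in product(nodes, repeat=2):
            if a != b:  # drop self-loops
                counts[(a, b)] += 1
    return counts


counts = cooccurrence_counts(rows)
frequent = {pair: n for pair, n in counts.items() if n != 1}
print(frequent)
```

Because pairs are ordered, each co-occurrence is counted in both directions, which is also what the self-merge produces.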


In [None]:
import numpy as np

def calculate_contextual_proximity(df):
    """
    Calculate contextual proximity between nodes.

    Args:
        df (pd.DataFrame): DataFrame containing the nodes.

    Returns:
        pd.DataFrame: DataFrame containing the contextual proximity.
    """
    long_format_df = pd.melt(
        df, id_vars=["chunk_id"], value_vars=["node_1", "node_2"], value_name="node"
    )
    long_format_df.drop(columns=["variable"], inplace=True)

    merged_df = pd.merge(long_format_df, long_format_df, on="chunk_id", suffixes=("_1", "_2"))

    self_loops_index = merged_df[merged_df["node_1"] == merged_df["node_2"]].index
    merged_df = merged_df.drop(index=self_loops_index).reset_index(drop=True)

    grouped_df = (
        merged_df.groupby(["node_1", "node_2"])
        .agg({"chunk_id": [",".join, "count"]})
        .reset_index()
    )
    grouped_df.columns = ["node_1", "node_2", "chunk_id", "count"]

    grouped_df.replace("", np.nan, inplace=True)
    grouped_df.dropna(subset=["node_1", "node_2"], inplace=True)

    grouped_df = grouped_df[grouped_df["count"] != 1]

    grouped_df["edge"] = "contextual proximity"

    return grouped_df

contextual_proximity_df = calculate_contextual_proximity(concepts_df)
contextual_proximity_df.head()

**Merge DataFrames**
- `merge_dataframes` concatenates the concepts DataFrame with the contextual proximity DataFrame and aggregates rows that share the same node pair.
- For each pair, the chunk IDs and edge labels are joined into comma-separated strings and the counts are summed.
- The merged DataFrame is the edge list used to create the graph.
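
As a rough, pandas-free sketch of that aggregation, the example below groups invented rows by node pair. The helper `merge_rows` is hypothetical; rows without a `count` contribute 0, matching how a pandas `sum` skips missing values.

```python
from collections import defaultdict

# Invented rows for illustration.
concept_rows = [
    {"node_1": "LockBit", "node_2": "site-a", "edge": "contains", "chunk_id": "c1"},
    {"node_1": "LockBit", "node_2": "site-a", "edge": "contains", "chunk_id": "c2"},
]
proximity_rows = [
    {"node_1": "LockBit", "node_2": "site-a", "edge": "contextual proximity",
     "chunk_id": "c1,c2", "count": 2},
]


def merge_rows(*row_groups):
    """Group rows by (node_1, node_2), joining labels and summing counts."""
    grouped = defaultdict(lambda: {"chunk_id": [], "edge": [], "count": 0})
    for rows in row_groups:
        for row in rows:
            entry = grouped[(row["node_1"], row["node_2"])]
            entry["chunk_id"].append(row["chunk_id"])
            entry["edge"].append(row["edge"])
            entry["count"] += row.get("count", 0)
    return {
        pair: {
            "chunk_id": ",".join(entry["chunk_id"]),
            "edge": ",".join(entry["edge"]),
            "count": entry["count"],
        }
        for pair, entry in grouped.items()
    }


merged = merge_rows(concept_rows, proximity_rows)
print(merged[("LockBit", "site-a")])
```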

In [None]:
def merge_dataframes(concepts_df, contextual_proximity_df):
    """
    Merge the concepts DataFrame with the contextual proximity DataFrame.

    Args:
        concepts_df (pd.DataFrame): DataFrame containing the concepts.
        contextual_proximity_df (pd.DataFrame): DataFrame containing the contextual proximity.

    Returns:
        pd.DataFrame: Merged DataFrame.
    """
    merged_df = pd.concat([concepts_df, contextual_proximity_df], axis=0)
    merged_df = (
        merged_df.groupby(["node_1", "node_2"])
        .agg({"chunk_id": ",".join, "edge": ','.join, 'count': 'sum'})
        .reset_index()
    )
    return merged_df

merged_df = merge_dataframes(concepts_df, contextual_proximity_df)

**Create NetworkX Graph**
- This function creates a NetworkX graph from the merged DataFrame, with nodes and edges representing the relationships between entities.
- Creating a graph allows for visualizing and analyzing the relationships between entities.
- This function is used to create a graph that will be used for further analysis and visualization.

In [None]:
import networkx as nx

def create_networkx_graph(merged_df):
    """
    Create a NetworkX graph from the merged DataFrame.

    Args:
        merged_df (pd.DataFrame): Merged DataFrame containing the nodes and edges.

    Returns:
        nx.Graph: NetworkX graph.
    """
    nodes = pd.concat([merged_df['node_1'], merged_df['node_2']], axis=0).unique()
    graph = nx.Graph()

    for node in nodes:
        graph.add_node(str(node))

    for index, row in merged_df.iterrows():
        graph.add_edge(
            str(row["node_1"]),
            str(row["node_2"]),
            title=row["edge"],
            weight=row['count']/4
        )

    return graph

graph = create_networkx_graph(merged_df)

**Calculate Graph Metrics**
- This function calculates various centrality metrics for the graph and sets them as node attributes.
- Centrality metrics help in understanding the importance and influence of nodes in the graph.
- This function is used to enrich the graph with additional metrics that will be used for analysis and visualization.
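
As a worked example of one of these metrics, degree centrality is a node's number of neighbors divided by the number of other nodes in the graph. The sketch below computes it by hand for a small, invented star-shaped graph; `degree_centrality_sketch` is a hypothetical helper, not a NetworkX function.

```python
def degree_centrality_sketch(edges):
    """Degree centrality: neighbors of a node divided by (number of nodes - 1)."""
    neighbors = {}
    for u, v in edges:
        neighbors.setdefault(u, set()).add(v)
        neighbors.setdefault(v, set()).add(u)

    others = len(neighbors) - 1
    if others <= 0:
        # Degenerate graph with at most one node.
        return {node: 1.0 for node in neighbors}
    return {node: len(adjacent) / others for node, adjacent in neighbors.items()}


# Invented edges: "malware" is connected to every other node.
edges = [("malware", "phishing"), ("malware", "ransomware"), ("malware", "botnet")]
scores = degree_centrality_sketch(edges)
print(scores)
```

The hub node is connected to all three other nodes and scores 1.0, while each leaf is connected to one of three and scores about 0.33.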



In [None]:
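from networkx.algorithms.centrality import degree_centrality, betweenness_centrality, closeness_centrality, eigenvector_centrality
# pagerank lives in the link_analysis module, not in centrality
from networkx.algorithms.link_analysis import pagerank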

def calculate_graph_metrics(graph):
    """
    Calculate various centrality metrics for the graph.

    Args:
        graph (nx.Graph): NetworkX graph.
    """
    degree_centrality_values = degree_centrality(graph)
    betweenness_centrality_values = betweenness_centrality(graph)
    closeness_centrality_values = closeness_centrality(graph)
    eigenvector_centrality_values = eigenvector_centrality(graph)
    pagerank_values = pagerank(graph)

    nx.set_node_attributes(graph, degree_centrality_values, 'degree_centrality')
    nx.set_node_attributes(graph, betweenness_centrality_values, 'betweenness_centrality')
    nx.set_node_attributes(graph, closeness_centrality_values, 'closeness_centrality')
    nx.set_node_attributes(graph, eigenvector_centrality_values, 'eigenvector_centrality')
    nx.set_node_attributes(graph, pagerank_values, 'pagerank')

calculate_graph_metrics(graph)

**Calculate Communities**
- This function calculates communities in the graph using the Louvain method and assigns colors to each community.
- Identifying communities helps in understanding the structure and organization of the graph.
- This function is used to create a DataFrame of communities that will be used for visualization.
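
`community_louvain.best_partition` returns a mapping from each node to a community ID. The sketch below shows, on an invented partition, how that mapping can be inverted into community groups and how each group can be given its own color. It uses the standard-library `colorsys` module in place of seaborn to stay dependency-free, and both helper names are hypothetical.

```python
import colorsys

# Invented node -> community ID mapping, shaped like a Louvain partition.
partition = {"LockBit": 0, "ransomware": 0, "phishing": 1}


def group_by_community(partition):
    """Invert node -> community ID into community ID -> list of nodes."""
    communities = {}
    for node, community_id in partition.items():
        communities.setdefault(community_id, []).append(node)
    return communities


def community_colors(communities):
    """Give every node the hex color of its community, with evenly spaced hues."""
    colors = {}
    total = len(communities)
    for index, (_, nodes) in enumerate(sorted(communities.items())):
        r, g, b = colorsys.hls_to_rgb(index / total, 0.6, 0.65)
        hex_color = "#{:02x}{:02x}{:02x}".format(
            round(r * 255), round(g * 255), round(b * 255)
        )
        for node in nodes:
            colors[node] = hex_color
    return colors


communities = group_by_community(partition)
colors = community_colors(communities)
print(communities)
print(colors)
```

Note that the color assignment iterates over the node lists of each community rather than over the community IDs themselves.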

In [None]:
import random

import seaborn as sns
from community import community_louvain  # provided by the python-louvain package

def assign_colors_to_communities(communities):
    """
    Assign colors to communities.

    Args:
        communities (dict): Dictionary containing the communities.

    Returns:
        pd.DataFrame: DataFrame containing the community colors.
    """
    palette = sns.color_palette("hls", len(communities)).as_hex()
    random.shuffle(palette)
    rows = []
    group = 0
    for community in communities.values():
        color = palette.pop()
        group += 1
        for node in community:
            rows += [{"node": node, "color": color, "group": group}]
    colors_df = pd.DataFrame(rows)
    return colors_df

def calculate_communities(graph):
    """
    Calculate communities in the graph using the Louvain method.

    Args:
        graph (nx.Graph): NetworkX graph.

    Returns:
        dict: Dictionary containing the communities.
    """
    communities_generator = community_louvain.best_partition(graph)
    communities = {}
    for node, community_id in communities_generator.items():
        if community_id not in communities:
            communities[community_id] = []
        communities[community_id].append(node)

    return communities

communities = calculate_communities(graph)
colors_df = assign_colors_to_communities(communities)

**Enhanced Graph Visualization with AmpliGraph**
- This function converts the NetworkX graph into (subject, predicate, object) triples, trains an AmpliGraph ComplEx embedding model on them, evaluates the model, and saves an interactive view of the graph.
- The embedding model scores how plausible each triple is; the interactive HTML view is rendered with pyvis.
- The code assumes the AmpliGraph 1.x API (`ComplEx`, `evaluate_performance`); later AmpliGraph releases expose a different interface.

In [None]:
import numpy as np
from ampligraph.latent_features import ComplEx
from ampligraph.evaluation import evaluate_performance, mrr_score, hits_at_n_score
from pyvis.network import Network

def create_interactive_graph(graph, output_path):
    """
    Train a ComplEx model on the graph's triples and save an interactive view.

    Args:
        graph (nx.Graph): NetworkX graph.
        output_path (str): Path to save the visualization.
    """
    # AmpliGraph expects an array of (subject, predicate, object) triples
    triples = np.array([(u, d['title'], v) for u, v, d in graph.edges(data=True)])

    # Train a ComplEx model
    model = ComplEx(batches_count=10, seed=0, epochs=20, k=100, eta=5,
                    optimizer='adam', optimizer_params={'lr': 1e-3},
                    loss='multiclass_nll', regularizer=None,
                    regularizer_params={}, verbose=True)

    model.fit(triples)

    # Evaluate the model: evaluate_performance returns the rank of each test triple
    ranks = evaluate_performance(triples, model=model,
                                 filter_triples=triples,
                                 use_default_protocol=True,
                                 verbose=True)
    print(f"MRR: {mrr_score(ranks):.3f}, Hits@10: {hits_at_n_score(ranks, n=10):.3f}")

    # AmpliGraph does not appear to provide a `plot_2D_graph` helper, so this step
    # uses pyvis (installed above) to write the interactive HTML file instead.
    net = Network(height="750px", width="100%")
    net.from_nx(graph)
    net.save_graph(output_path)

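import os

# Main execution: write the visualization to a local output directory.
# "output" is an assumed location; change it to suit your environment.
output_directory = "output"
os.makedirs(output_directory, exist_ok=True)
output_path = os.path.join(output_directory, "graph.html")
create_interactive_graph(graph, output_path)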