In [39]:
import mysql.connector
import yaml


In [40]:
# Read the connection settings from a local YAML file so that no credentials
# are hard-coded in the notebook. The file is opened in a `with` block so the
# handle is closed as soon as parsing finishes, and `safe_load` is used because
# a settings file only needs plain YAML types (mappings, strings, numbers).
with open('credentials.yml') as creds_file:
    creds = yaml.safe_load(creds_file)

# All database settings live under the top-level 'TiDB' key.
tidb_creds = creds['TiDB']

# Open a TLS-verified connection: the server certificate is checked against
# the configured CA bundle and the certificate identity is verified as well.
connection = mysql.connector.connect(
  host = tidb_creds['host'],
  port = tidb_creds['port'],
  user = tidb_creds['user'],
  password = tidb_creds['password'],
  database = tidb_creds['database'],
  ssl_ca = tidb_creds['ssl_ca'],
  ssl_verify_cert = True,
  ssl_verify_identity = True
)

In [21]:
from sentence_transformers import SentenceTransformer

# Load the "BAAI/bge-m3" sentence-embedding model once; the same `embedder`
# object is reused by the functions defined further down in this notebook.
embedder = SentenceTransformer("BAAI/bge-m3")
text = "I am learning ML"
# Smoke test: encode one sample sentence and turn the result into a list via .tolist().
embedding = embedder.encode(text).tolist()

In [22]:
# Number of values in the sample embedding vector.
len(embedding)

1024

In [23]:
# Small in-memory corpus of cybersecurity passages. Each string is treated as
# one document: it is embedded and stored by the ingestion loop below.
documents = [
    "Cybersecurity refers to the practice of protecting systems, networks, and programs from digital attacks. These attacks are usually aimed at accessing, changing, or destroying sensitive information; extorting money from users; or interrupting normal business processes.",
    "The primary goals of cybersecurity are often summarized by the 'CIA Triad': Confidentiality, Integrity, and Availability. Confidentiality ensures information is accessed only by authorized parties. Integrity ensures data is accurate and trustworthy. Availability guarantees authorized users can access information when needed.",
    "Common types of cyber threats include malware (like viruses, worms, ransomware), phishing (deceptive communications to steal info), SQL injection (malicious code for database access), Denial-of-Service (DoS) attacks (overwhelming systems to disrupt service), and Man-in-the-Middle (MitM) attacks (eavesdropping on communications).",
    "Malware is a broad term for malicious software designed to disrupt, damage, or gain unauthorized access to a computer system. This category includes viruses, worms, Trojan horses, spyware, adware, and ransomware.",
    "Ransomware is a particularly destructive type of malware that encrypts a victim's files, demanding a ransom payment (often in cryptocurrency) for their decryption. It poses a significant threat to individuals and organizations worldwide.",
    "Phishing attacks involve sending fraudulent communications that appear to come from a reputable source, typically email. The goal is to trick individuals into revealing sensitive information like usernames, passwords, and credit card details.",
    "Multi-Factor Authentication (MFA) is a security system that requires more than one method of verification from independent categories of credentials to verify the user's identity for a login or other transaction. Examples include something you know (password), something you have (phone/token), or something you are (fingerprint).",
    "Firewalls act as a barrier between a trusted internal network and untrusted external networks (like the internet). They monitor and control incoming and outgoing network traffic based on predetermined security rules, preventing unauthorized access.",
    "Encryption is the process of converting information or data into a code to prevent unauthorized access. Only authorized parties with the correct decryption key can access the original, readable information.",
    "Intrusion Detection Systems (IDS) and Intrusion Prevention Systems (IPS) monitor network traffic for suspicious activity and known threats. While an IDS simply alerts, an IPS can automatically block or prevent detected attacks.",
    "Vulnerability management involves identifying, evaluating, prioritizing, and remediating security vulnerabilities in systems, software, and networks. Regular scans and updates are crucial for this process.",
    "Incident response is an organized approach to addressing and managing the aftermath of a security breach or cyberattack. The goal is to contain the damage, reduce recovery time and costs, and prevent future incidents.",
    "Security Information and Event Management (SIEM) systems aggregate and analyze security-related data from various sources across an organization's IT infrastructure, providing a centralized view for real-time threat detection and compliance reporting.",
    "Cloud security refers to the set of policies, controls, procedures, and technologies that work together to protect cloud-based systems, data, and infrastructure. It's a shared responsibility model between the cloud provider and the customer.",
    "Endpoint security involves protecting individual end-user devices (like laptops, desktops, and mobile phones) from cyber threats. It's a critical layer of defense, especially with remote workforces.",
    "Social engineering is a manipulation technique that exploits human error to gain private information, access, or valuables. Attackers use psychological manipulation to trick users into divulging confidential information or performing actions that compromise security.",
    "Data privacy regulations, such as GDPR (General Data Protection Regulation) and CCPA (California Consumer Privacy Act), establish rules for how organizations should collect, store, and process personal data, emphasizing user rights and data protection.",
    "Zero Trust is a security model based on the principle of 'never trust, always verify.' It assumes that no user or device, whether inside or outside an organization's network, should be trusted by default, and every access request must be authenticated and authorized.",
    "Security awareness training educates employees about common cyber threats, organizational security policies, and best practices for protecting sensitive information, significantly reducing the risk of human error in security breaches.",
    "Penetration testing (pen testing) is a simulated cyber attack against your computer system to check for exploitable vulnerabilities. It's a proactive security measure to identify weaknesses before real attackers do."
]

In [25]:
import json

def add_document(connection, text):
    """Embed ``text`` and insert it into the ``documents`` table.

    The text is encoded with the notebook-level ``embedder``; the resulting
    vector is serialised with ``json.dumps`` and stored next to the text in
    the ``embedding`` column. The insert is committed before returning.

    Args:
        connection: Open database connection providing ``cursor()`` and
            ``commit()``.
        text: Document text to embed and store.

    Returns:
        None.

    Raises:
        Whatever the embedder or the database driver raises; the cursor is
        closed before the exception propagates.
    """
    # Encode first so that an embedding failure never leaves a cursor open.
    embedding = embedder.encode(text).tolist()

    cur = connection.cursor()
    try:
        cur.execute(
            "INSERT INTO documents (document, embedding) VALUES (%s, %s)",
            (text, json.dumps(embedding)),
        )
        connection.commit()
    finally:
        # Release the cursor even when the insert or the commit raises.
        cur.close()

# Store every passage; add_document is called once per document.
# NOTE(review): nothing visible here guards against duplicates, so re-running
# this cell presumably inserts the same passages again — confirm against the
# table schema.
for doc in documents:
    add_document(connection, doc)

In [31]:
# k is the number of closest documents to return
def query_tidb(connection, query_text, k=5):
    """Return the ``k`` stored documents closest to ``query_text``.

    The query text is encoded with the notebook-level ``embedder`` and
    compared against the ``embedding`` column of the ``documents`` table using
    ``vec_cosine_distance``; rows are ordered by ascending distance.

    Args:
        connection: Open database connection providing ``cursor()``.
        query_text: Text to search for.
        k: Maximum number of rows to return. Coerced with ``int``; must not
            be negative.

    Returns:
        A list of ``(document, distance)`` tuples, closest first.

    Raises:
        ValueError: If ``k`` is negative or cannot be converted to an integer.
        TypeError: If ``k`` is of a type ``int()`` does not accept.
    """
    # Validate the limit up front so a bad value never reaches the SQL text.
    limit = int(k)
    if limit < 0:
        raise ValueError(f"k must be non-negative, got {k!r}")

    # Serialise the query vector the same way stored vectors are written.
    query_embedding = json.dumps(embedder.encode(query_text).tolist())

    # Both values are bound as parameters instead of being formatted into the
    # statement, matching how the INSERT in this notebook passes its values.
    sql_query = """
        SELECT document, vec_cosine_distance(embedding, %s) AS distance
        FROM documents
        ORDER BY distance
        LIMIT %s
    """

    cur = connection.cursor()
    try:
        cur.execute(sql_query, (query_embedding, limit))
        results = cur.fetchall()
    finally:
        # Release the cursor even when the query fails.
        cur.close()

    return [(content, distance) for content, distance in results]

# Example lookup: fetch the stored passages closest to a short query,
# using the default k=5, and display the (document, distance) pairs.
query_text = "Security Awareness"
results = query_tidb(connection, query_text)
results

[('Security awareness training educates employees about common cyber threats, organizational security policies, and best practices for protecting sensitive information, significantly reducing the risk of human error in security breaches.',
  0.37797851026698526),
 ('Intrusion Detection Systems (IDS) and Intrusion Prevention Systems (IPS) monitor network traffic for suspicious activity and known threats. While an IDS simply alerts, an IPS can automatically block or prevent detected attacks.',
  0.4832984517038562),
 ("Security Information and Event Management (SIEM) systems aggregate and analyze security-related data from various sources across an organization's IT infrastructure, providing a centralized view for real-time threat detection and compliance reporting.",
  0.4925887584685955),
 ('Vulnerability management involves identifying, evaluating, prioritizing, and remediating security vulnerabilities in systems, software, and networks. Regular scans and updates are crucial for this 

In [41]:
# Now that the embeddings are stored in the database,
# build the RAG (retrieval-augmented generation) step on top of them

import ollama

def generate_response(connection, query_text, k=5, model="llama3.2"):
    """Answer ``query_text`` using retrieved documents as context.

    The closest documents are fetched with ``query_tidb``, joined with
    newlines into a context block, and sent together with the question to an
    Ollama chat model.

    Args:
        connection: Open database connection, passed through to ``query_tidb``.
        query_text: The user's question.
        k: Number of documents to retrieve as context (default 5, the same
            default ``query_tidb`` uses).
        model: Name of the Ollama model to query (default ``"llama3.2"``).

    Returns:
        The text content of the model's reply.
    """
    retrieved_docs = query_tidb(connection, query_text, k=k)

    # Each retrieved row is (document, distance); only the text goes into the prompt.
    context = "\n".join(doc[0] for doc in retrieved_docs)

    prompt = f"Please answer the question based on the following context:\n{context}\n\nQuestion: {query_text}"

    response = ollama.chat(model=model, messages=[
        {"role": "system", "content": "You are the security engineer for this company."},
        {"role": "user", "content": prompt}
    ])

    return response["message"]["content"]

# Ask a question about a topic the stored passages cover (security awareness training).
query_text = "What is the meaning of Security Awareness?"
generate_response(connection, query_text)

'According to the context provided, Security Awareness refers to educating employees about common cyber threats, organizational security policies, and best practices for protecting sensitive information, significantly reducing the risk of human error in security breaches.'

In [42]:
# Ask using an abbreviation ("Pentest"); the stored passage spells it out as
# "Penetration testing (pen testing)".
query_text = "What is the meaning of Pentest?"
generate_response(connection, query_text)

'The term "Pentest" is an abbreviation for Penetration Testing, which I mentioned earlier in our context. In this case, "Pentest" refers to a simulated cyber attack against a computer system conducted by security engineers to identify vulnerabilities and weaknesses, thereby strengthening the overall security posture of the organization.'