In [None]:
from google.colab import drive
import os

# Mount Google Drive into the Colab filesystem at /content/drive.
drive.mount('/content/drive')

# Project root on Drive. Later cells place the Chroma database under
# "<base_path>/chroma_db" and the saved LLM files under "<base_path>/models".
base_path = "/content/drive/MyDrive/RAG_project"
# exist_ok=True keeps this cell safe to re-run when the folder already exists.
os.makedirs(base_path, exist_ok=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
%pip install chromadb



In [None]:
import os
import json
import chromadb
from sentence_transformers import SentenceTransformer

# Sentence-embedding model used to vectorise the knowledge-base entries
MODEL = "all-mpnet-base-v2"
model = SentenceTransformer(MODEL)

# Persistent Chroma store kept inside the project folder
client = chromadb.PersistentClient(path=os.path.join(base_path, "chroma_db"))
collection_name = "security_kb"

# Read the attack records from docs.json (resolved against the current working directory)
with open("docs.json", "r", encoding="utf-8") as docs_file:
    attacks = json.load(docs_file)

# Parallel lists: texts[i], ids[i] and metas[i] all describe the same record
texts, ids, metas = [], [], []

for entry in attacks:
    # One sentence-like line per mitigation; every key of a mitigation is required
    mitigation_lines = [
        (
            f"Recommendation: {rec['recommendation']}. "
            f"Implementation: {rec['implementation']}. "
            f"Priority: {rec['priority']}. "
            f"Source: {rec['source']} ({rec['source_detail']})"
        )
        for rec in entry.get("mitigation_recommendations", [])
    ]

    # The sections are joined with newlines, so a record without mitigations
    # still gets the header followed by an empty line.
    sections = [
        f"Attack ID: {entry['id']}",
        f"Type: {entry['attack_type']}",
        f"Description: {entry['description']}",
        f"Indicators: {', '.join(entry.get('indicators', []))}",
        "Mitigation Recommendations:",
        "\n".join(mitigation_lines),
        f"Defense Layers: {', '.join(entry.get('defense_layers', []))}",
        f"Source: {entry.get('source', '')}",
    ]

    texts.append("\n".join(sections))
    ids.append(entry["id"])
    metas.append(
        {
            "source": entry.get("source", ""),
            "attack_type": entry.get("attack_type", ""),
        }
    )

# Encode embeddings (nothing to encode when docs.json contained no records)
embeddings = (
    model.encode(texts, show_progress_bar=True, batch_size=32).tolist() if texts else []
)

# Create the collection on the first run and reuse it afterwards.
# get_or_create_collection replaces the former try/bare-except, which treated
# *any* failure (including KeyboardInterrupt) as "collection does not exist".
coll = client.get_or_create_collection(name=collection_name)

# upsert instead of add: the ids come from docs.json and are stable, so
# re-running this cell must refresh the stored documents rather than collide
# with (or silently keep) the entries written by a previous run.
if ids:
    coll.upsert(ids=ids, documents=texts, metadatas=metas, embeddings=embeddings)

print("Indexed", len(ids))


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Indexed 9


In [None]:
import chromadb
from sentence_transformers import SentenceTransformer
import os

# Embedding model for queries (same model name as the one used when indexing)
MODEL = "all-mpnet-base-v2"
model = SentenceTransformer(MODEL)

# Re-open the persisted Chroma store and fetch the existing collection
base_path = "/content/drive/MyDrive/RAG_project"
chroma_dir = os.path.join(base_path, "chroma_db")
client = chromadb.PersistentClient(path=chroma_dir)
collection = client.get_collection("security_kb")

def query_rag_local(user_query, top_k=3):
    """Retrieve the stored documents that best match ``user_query``.

    The query is embedded with the module-level ``model`` and looked up in the
    module-level Chroma ``collection``.

    Args:
        user_query: Free-text question.
        top_k: Number of documents requested from the collection.

    Returns:
        The retrieved document texts as one string, separated by blank lines.
    """
    query_vector = model.encode([user_query])[0]
    hits = collection.query(query_embeddings=[query_vector], n_results=top_k)
    # A single query was sent, so only the first result list is relevant.
    return "\n\n".join(hits['documents'][0])

# Example usage: show the raw retrieved documents for a sample question
user_query = "Explain DDoS attacks"
answer = query_rag_local(user_query)
# The header ends with a newline and sep adds another one, which leaves a
# blank line between the header and the retrieved text.
print("Top retrieved documents:\n", answer, sep="\n")


Top retrieved documents:

Attack ID: attack_005
Type: DoS
Description: Déni de service visant à saturer les ressources du système ou du réseau.
Indicators: trafic réseau massif anormal, utilisation CPU à 100%, mémoire saturée
Mitigation Recommendations:
Recommendation: Service anti-DDoS avec scrubbing center. Implementation: Redirection trafic vers centre de nettoyage, filtrage patterns DDoS, anycast DNS. Priority: Critical. Source: NIST SP 800-61 (Guide to DDoS Attacks)
Recommendation: Rate limiting et traffic shaping. Implementation: Limites requêtes par seconde, quotas bande passante, priorisation trafic légitime. Priority: High. Source: MITRE ATT&CK (Mitigation ID: M1040)
Recommendation: Architecture scalable avec auto-scaling. Implementation: Load balancers, répartition de charge, ressources cloud élastiques, CDN. Priority: Medium. Source: CIS Control 12 (Boundary Defense)
Defense Layers: Network, Infrastructure, Availability
Source: MITRE ATT&CK

MITRE ATT&CK — Selected Technique

In [None]:
import os
import chromadb
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# ---------------------------
# Paths & model names
# ---------------------------
# Project root on Google Drive; the two folders below live inside it.
base_path = "/content/drive/MyDrive/RAG_project"
kb_path = os.path.join(base_path, "chroma_db")  # persisted Chroma database
rag_model_path = os.path.join(base_path, "models")  # local copy of the LLM + tokenizer
llm_model_name = "Qwen/Qwen2.5-1.5B-Instruct"  # Hugging Face id used when no local copy exists
embedding_model_name = "all-mpnet-base-v2"  # same model name the indexing cell uses

# ---------------------------
# Load embedding model
# ---------------------------
# Used by query_rag_llm to embed the user question before retrieval.
embed_model = SentenceTransformer(embedding_model_name)

# ---------------------------
# Load or download LLM
# ---------------------------
os.makedirs(rag_model_path, exist_ok=True)

# A config.json inside the models folder is taken as the sign that an earlier
# run already saved the model there.
llm_is_cached = os.path.exists(os.path.join(rag_model_path, "config.json"))
llm_source = rag_model_path if llm_is_cached else llm_model_name

print("Loading LLM from local Drive..." if llm_is_cached else "Downloading LLM from Hugging Face...")
tokenizer = AutoTokenizer.from_pretrained(llm_source)
llm_model = AutoModelForCausalLM.from_pretrained(llm_source, device_map="auto")

# Freshly downloaded: keep a copy on Drive so the next session loads locally.
if not llm_is_cached:
    tokenizer.save_pretrained(rag_model_path)
    llm_model.save_pretrained(rag_model_path)
    print(f"Model saved to {rag_model_path}")

# ---------------------------
# Load Chroma collection
# ---------------------------
client = chromadb.PersistentClient(path=kb_path)
collection_name = "security_kb"

try:
    collection = client.get_collection(collection_name)
except Exception as exc:
    # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still
    # propagate, and chain the original error so the real cause (missing
    # collection, unreadable database, ...) stays visible in the traceback.
    raise ValueError(f"Collection {collection_name} not found in Chroma DB") from exc




Loading LLM from local Drive...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
def query_rag_llm(user_query, top_k=3, max_tokens=300):
    """Answer ``user_query`` with the LLM, grounded on retrieved knowledge-base entries.

    Uses the module-level ``embed_model``, ``collection``, ``tokenizer`` and
    ``llm_model``.

    Args:
        user_query: Free-text question.
        top_k: Number of documents to retrieve from the Chroma collection.
        max_tokens: Upper bound on newly generated tokens (``max_new_tokens``).

    Returns:
        The text generated by the LLM only. The prompt (instructions, retrieved
        context and question) is not included in the returned string.
    """
    # 1. Embed query
    query_emb = embed_model.encode([user_query], normalize_embeddings=True)[0]

    # 2. Retrieve top-k relevant documents
    results = collection.query(query_embeddings=[query_emb], n_results=top_k)
    retrieved_docs = results['documents'][0]  # List of strings (one query was sent)

    # 3. Build context for LLM
    context = "\n\n".join(retrieved_docs)
    prompt = f"Using the following security knowledge, answer the question concisely:\n\nContext:\n{context}\n\nQuestion: {user_query}\nAnswer:"

    # 4. Tokenize and generate answer
    inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    outputs = llm_model.generate(**inputs, max_new_tokens=max_tokens)

    # A causal LM returns the prompt tokens followed by the continuation.
    # Decoding the whole sequence would echo the prompt and context back to the
    # caller, so keep only the tokens produced after the prompt.
    prompt_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][prompt_length:]
    answer = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return answer.strip()




In [None]:
# ---------------------------
# Example usage
# ---------------------------
query = "What are the recommended defenses against DDoS attacks?"
answer = query_rag_llm(query, top_k=3)
# The header ends with a newline and sep adds another one, which leaves a
# blank line between the header and the model output.
print("LLM Answer:\n", answer, sep="\n")

LLM Answer:

Using the following security knowledge, answer the question concisely:

Context:
Attack ID: attack_005
Type: DoS
Description: Déni de service visant à saturer les ressources du système ou du réseau.
Indicators: trafic réseau massif anormal, utilisation CPU à 100%, mémoire saturée
Mitigation Recommendations:
Recommendation: Service anti-DDoS avec scrubbing center. Implementation: Redirection trafic vers centre de nettoyage, filtrage patterns DDoS, anycast DNS. Priority: Critical. Source: NIST SP 800-61 (Guide to DDoS Attacks)
Recommendation: Rate limiting et traffic shaping. Implementation: Limites requêtes par seconde, quotas bande passante, priorisation trafic légitime. Priority: High. Source: MITRE ATT&CK (Mitigation ID: M1040)
Recommendation: Architecture scalable avec auto-scaling. Implementation: Load balancers, répartition de charge, ressources cloud élastiques, CDN. Priority: Medium. Source: CIS Control 12 (Boundary Defense)
Defense Layers: Network, Infrastructure, 

In [None]:
import torch

# Assume you have a trained PyTorch attack detection model
# and it outputs either a label or a probability distribution

def get_attack_label(input_features, detection_model):
    """Run the detection model on ``input_features`` and return its prediction.

    Args:
        input_features: Preprocessed features from your logs or data; anything
            ``torch.as_tensor`` accepts (nested lists, NumPy array, tensor).
            Either a single sample or a batch of samples.
        detection_model: Trained PyTorch model.

    Returns:
        * ``int`` - index of the highest-scoring class when the model output is
          ``(1, classes)`` or a 1-D vector with several entries.
        * ``list[int]`` - one class index per row when the output is
          ``(batch, classes)`` with more than one row.
        * ``float`` - the raw value when the model emits a single number.
    """
    detection_model.eval()

    # A plain torch.nn.Module has no ``device`` attribute. Honour it when the
    # model defines one, otherwise use the device of its first parameter and
    # fall back to CPU for parameter-free models.
    device = getattr(detection_model, "device", None)
    if device is None:
        first_param = next(detection_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    with torch.no_grad():
        inputs = torch.as_tensor(input_features).float().to(device)
        outputs = detection_model(inputs)

    # Multi-class, batched output: pick the highest score per row
    if outputs.ndim == 2:  # batch x classes
        # NOTE(review): a (batch, 1) output also lands here and always yields
        # index 0; confirm whether single-logit models should be handled here.
        label_indices = torch.argmax(outputs, dim=1)
        if label_indices.numel() == 1:
            return label_indices.item()
        # .item() only works for one element, so larger batches return a list
        return label_indices.tolist()

    # Multi-class output for one unbatched sample
    if outputs.ndim == 1 and outputs.numel() > 1:
        return torch.argmax(outputs).item()

    # Single output probability
    return outputs.item()

# ---------------------------
# RAG + LLM function
# ---------------------------
def answer_from_attack(attack_label, top_k=3, max_tokens=300):
    """Turn a detected attack label into a RAG query and return the LLM output.

    Args:
        attack_label: Label produced by the detection model. Labels without an
            entry in the table below fall back to a generic query.
        top_k: Forwarded to ``query_rag_llm``.
        max_tokens: Forwarded to ``query_rag_llm``.

    Returns:
        Whatever ``query_rag_llm`` returns for the chosen query.
    """
    # Descriptive query per attack label
    queries_by_label = {
        0: "Defenses against DoS attacks",
        1: "Defenses against Reconnaissance attacks",
        2: "Defenses against Exploits",
        # ... add mappings for all your attack labels
    }
    default_query = "Recommended defenses for the detected attack"

    if attack_label in queries_by_label:
        rag_query = queries_by_label[attack_label]
    else:
        rag_query = default_query

    # Retrieval and generation are delegated to query_rag_llm
    return query_rag_llm(rag_query, top_k=top_k, max_tokens=max_tokens)

# ---------------------------
# Example usage
# ---------------------------
# Label 0 is mapped to the DoS query inside answer_from_attack. The value is
# hard-coded here; in a full pipeline it would come from the detection model
# (for example via get_attack_label).
attack_label = 0  # output of your trained detection model
final_answer = answer_from_attack(attack_label)
print(final_answer)
