In [115]:

from langchain_community.document_loaders import TextLoader 
from langchain_text_splitters import RecursiveCharacterTextSplitter 
from langchain_community.vectorstores import FAISS  
from langchain_core.prompts import ChatPromptTemplate  
from langchain_core.output_parsers import StrOutputParser,JsonOutputParser 
from langchain_core.runnables import RunnablePassthrough  
from langchain_ollama import ChatOllama,OllamaEmbeddings  
from langchain_core.chat_history import InMemoryChatMessageHistory  
from langchain_core.runnables import RunnableWithMessageHistory  
from langchain_core.tools import tool  
from langchain_core.messages import AIMessage, HumanMessage  
from langchain_community.retrievers import BM25Retriever


In [116]:
# Local Mistral chat model served through Ollama (temperature 0.5).
llm = ChatOllama(temperature=0.5, model="mistral")

# Knowledge base: the incident log file, read as UTF-8 into LangChain documents.
loader = TextLoader("security_incidents.txt", encoding="utf-8")
docs = loader.load()

print(f"Loaded {len(docs)} docs.")
print("Sample content:", docs[0].page_content[:200])

Loaded 1 docs.
Sample content: Incident #001 | User=johns | Alert=Multiple failed SSH logins | SourceIP=10.1.1.9 | Host=SRV-LNX-01 | OS=Ubuntu 20 | MITRE=T1110 | Severity=High | Resolution=Blocked source IP; Reset password; Enabled


In [117]:
# Break the knowledge base into chunks of at most 300 characters, with 50 characters
# of overlap between neighbouring chunks.
splitter = RecursiveCharacterTextSplitter(chunk_overlap=50, chunk_size=300)
splits = splitter.split_documents(docs)

print(f"Created {len(splits)} chunks.")
# Preview the first two chunks (length plus the first 100 characters).
for i, chunk in enumerate(splits[:2]):
    print(
        f"Chunk {i + 1} (len {len(chunk.page_content)} chars): "
        f"{chunk.page_content[:100]}..."
    )

Created 40 chunks.
Chunk 1 (len 205 chars): Incident #001 | User=johns | Alert=Multiple failed SSH logins | SourceIP=10.1.1.9 | Host=SRV-LNX-01 ...
Chunk 2 (len 221 chars): Incident #002 | User=markp | Alert=Suspicious PowerShell encoded command detected | Host=WKS-22 | OS...


In [118]:
# Dense (semantic) index: embed every chunk with nomic-embed-text via Ollama and store in FAISS.
embeddings = OllamaEmbeddings(model="nomic-embed-text:latest")
vectorstore = FAISS.from_documents(splits, embeddings)

# The number of hits must be passed through `search_kwargs`; `as_retriever(k=4)` forwards
# `k` as a retriever field, which the retriever does not have, so it was silently dropped.
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

print(f"Index built with {len(splits)} vectors.")

Index built with 40 vectors.


In [119]:
json_parser = JsonOutputParser()

# The parser's format instructions are spliced into the template text, so any literal
# braces they contain are doubled; otherwise ChatPromptTemplate would read them as
# template variables. The only real input variables are {context} and {question}.
prompt = ChatPromptTemplate.from_template(
    "You are a Security Incident Assistant that helps IT analysts to:\n"
    "- recommend resolutions,\n"
    "- retrieve similar past incidents,\n"
    "- recall analyst-specific preferences across sessions,\n"
    "- extract important entities (IPs, OS, MITRE tags),\n"
    "- provide enriched threat analysis based on historical data.\n"
    "Do not repeat these instructions in your response.\n"
    "Use the following context to answer the question.\n"
    "Ensure the response is in JSON format as per the schema: "
    + json_parser.get_format_instructions().replace("{", "{{").replace("}", "}}")
    + "\n"
    "Context: {context}\n"
    "Question: {question}\n"
    "Answer:"
)


In [121]:
from langchain_core.documents import Document
import rank_bm25
class HybridRetriever:
    """Fuse results from a dense (vector) retriever and a sparse (keyword) retriever.

    Both retrievers are queried and their ranked lists are merged with Reciprocal Rank
    Fusion (RRF): a document scores ``sum(1 / (rrf_c + rank))`` over the lists it appears
    in (``rank`` is 1-based). Documents found by both retrievers rise to the top, and strong
    sparse-only hits compete with dense hits instead of being truncated away. Documents
    are de-duplicated by ``page_content``; ties keep dense-first order.

    Args:
        dense_retriever: object exposing ``invoke(query) -> list[Document]``.
        sparse_retriever: object exposing ``invoke(query) -> list[Document]``.
        k: maximum number of documents returned by ``invoke``.
        rrf_c: RRF smoothing constant (60 is the conventional value).
    """

    def __init__(self, dense_retriever, sparse_retriever, k=4, rrf_c=60):
        self.dense_retriever = dense_retriever
        self.sparse_retriever = sparse_retriever
        self.k = k
        self.rrf_c = rrf_c

    def invoke(self, query: str) -> "list[Document]":
        """Return up to ``self.k`` unique documents for ``query``, best fused score first."""
        dense_results = self.dense_retriever.invoke(query)
        sparse_results = self.sparse_retriever.invoke(query)

        scores = {}     # page_content -> accumulated RRF score (insertion order = dense first)
        first_doc = {}  # page_content -> first Document object seen with that content
        for results in (dense_results, sparse_results):
            seen_here = set()  # count each document at most once per result list
            for rank, doc in enumerate(results, start=1):
                key = doc.page_content
                if key in seen_here:
                    continue
                seen_here.add(key)
                scores[key] = scores.get(key, 0.0) + 1.0 / (self.rrf_c + rank)
                first_doc.setdefault(key, doc)

        # sorted() is stable, so equal scores keep insertion (dense-first) order.
        ranked = sorted(scores, key=scores.__getitem__, reverse=True)
        return [first_doc[key] for key in ranked[: self.k]]
    
# Sparse keyword (BM25) retriever built over the same chunks that back the FAISS index.
bm2 = BM25Retriever.from_documents(documents=splits)

In [122]:
from langchain_core.runnables import RunnableLambda

# Combine the FAISS retriever and the BM25 retriever; at most 4 unique chunks come back per query.
hybrid_retriever = HybridRetriever(retriever, bm2, k=4)

# Runnable wrapper so the hybrid retriever can be composed into LCEL chains with `|`.
# The lambda looks up `hybrid_retriever` at call time, so re-running the line above is picked up.
hybrid_runnable = RunnableLambda(lambda q: hybrid_retriever.invoke(q))

In [123]:
import redis
import hashlib
import json

import os

# Redis connection used as a result cache. Host, port and db default to a local Redis
# (localhost:6379, db 0) and can be overridden with REDIS_HOST / REDIS_PORT / REDIS_DB.
# decode_responses=True makes reads return str rather than bytes.
r = redis.StrictRedis(
    host=os.environ.get("REDIS_HOST", "localhost"),
    port=int(os.environ.get("REDIS_PORT", "6379")),
    db=int(os.environ.get("REDIS_DB", "0")),
    decode_responses=True,
)

def cache_key(prompt: str) -> str:
    """Derive a fixed-length cache key: the hex MD5 digest of ``prompt``'s UTF-8 bytes.

    MD5 is used only as a compact, deterministic key, not for any security purpose.
    """
    digest = hashlib.md5(prompt.encode("utf-8"))
    return digest.hexdigest()

def get_kb_result(query: str):
    """
    Return the hybrid-retrieval hits for ``query``, cached in Redis for one hour.

    ``Document`` objects are not JSON-serializable, so each hit is stored as a
    ``{"page_content": ..., "metadata": ...}`` dict and rebuilt into a ``Document``
    on a cache hit; both paths return a list of ``Document``.

    The Redis key is derived from ``"retrieval:" + query`` so these raw retrieval
    results never share a key with other cache entries that hash the bare query.

    NOTE(review): ``get_kb_result`` is defined again in a later cell; once that cell
    runs, the later definition replaces this one.
    """
    from langchain_core.documents import Document  # only needed to rebuild cached hits

    key = cache_key("retrieval:" + query)
    cached_result = r.get(key)

    if cached_result:
        print("Cache hit!")
        return [
            Document(page_content=item["page_content"], metadata=item.get("metadata") or {})
            for item in json.loads(cached_result)
        ]

    print("Cache miss! Running RAG chain.")
    result = hybrid_retriever.invoke(query)

    serializable = [{"page_content": doc.page_content, "metadata": doc.metadata} for doc in result]
    # default=str keeps caching from failing on metadata values json cannot encode natively
    r.set(key, json.dumps(serializable, default=str), ex=3600)

    return result


In [134]:
import re
import json
import ipaddress
from langchain_core.runnables import RunnableLambda

# --- Entity extraction utilities ---
_ip_re = re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")
_mitre_re = re.compile(r"\bT\d{4}\b", re.IGNORECASE)
_severity_re = re.compile(r"\b(Critical|High|Medium|Low)\b", re.IGNORECASE)
_hostname_re = re.compile(r"\b[A-Za-z0-9][-A-Za-z0-9_.]{1,62}\b")  # loose hostname token matcher

# common OS tokens to look for in text
_OS_KEYWORDS = [
    "Windows Server", "Windows", "CentOS", "RedHat", "Red Hat", "Ubuntu",
    "Debian", "Fedora", "macOS", "mac os", "Windows 10", "Windows 11",
    "Windows Server 2019", "Win11", "Win10"
]
_bruteforce_indicators = [
    "failed login", "failed logins", "brute force", "brute-force", "throttled login",
    "password spray", "multiple failed", "account lockout"
]
_powershell_indicators = ["powershell", "encoded command", "Invoke-Expression", "IEX", "-EncodedCommand"]


def _clean_ips(ip_candidates):
    ips = []
    for ip in ip_candidates:
        try:
            # filter out bogus like 999.999.999.999
            ipaddress.ip_address(ip)
            ips.append(ip)
        except Exception:
            continue
    return sorted(set(ips))


def extract_entities(text: str) -> dict:
    """
    Extract IPs, OS mentions, hostnames, MITRE technique codes and severities from ``text``.

    Returns a dict with keys "IPs", "OS", "Hostnames", "MITRE", "Severity", each mapping
    to a list (all empty for empty/None input). IPs, MITRE IDs, severities and hostnames
    are de-duplicated and sorted; OS names keep ``_OS_KEYWORDS`` order.

    Hostname detection is heuristic. A token (at least 3 characters, not an OS keyword)
    counts when it contains a letter, is not a MITRE technique ID, and either contains a
    dash plus an uppercase letter or digit (``WKS-22``, ``SRV-LNX-01``), an uppercase run
    followed by digits (``SRV01``), or a run of two or more digits (``srv01``). Tokens with
    no letters (IPs, incident numbers, years) and plain labels such as "MITRE" or "SSH"
    are therefore excluded.
    """
    if not text:
        return {"IPs": [], "OS": [], "Hostnames": [], "MITRE": [], "Severity": []}

    ips = _clean_ips(_ip_re.findall(text))
    mitre = sorted({m.upper() for m in _mitre_re.findall(text)})
    severity = sorted({s.capitalize() for s in _severity_re.findall(text)})

    # OS detection: case-insensitive substring match per keyword, so overlapping
    # keywords (e.g. "Windows" and "Windows Server") can both be reported.
    os_found = []
    lower = text.lower()
    for tok in _OS_KEYWORDS:
        if tok.lower() in lower and tok not in os_found:
            os_found.append(tok)

    os_lower = {o.lower() for o in _OS_KEYWORDS}  # built once instead of once per token
    hostnames = set()
    for token in _hostname_re.findall(text):
        t = token.strip()
        if len(t) < 3 or t.lower() in os_lower:
            continue
        # Tokens without letters (IPs, "034", "2019") and MITRE IDs (incl. "T1059.001")
        # match the loose token regex but are never hostnames.
        if not re.search(r"[A-Za-z]", t) or _mitre_re.match(t):
            continue
        dashed = "-" in t and re.search(r"[A-Z0-9]", t)
        numbered = re.search(r"[A-Z]{2,}\d+|\d{2,}", t)
        if dashed or numbered:
            hostnames.add(t)

    return {"IPs": ips, "OS": os_found, "Hostnames": sorted(hostnames), "MITRE": mitre, "Severity": severity}


# --- Threat enrichment tool (local) ---
def _ip_in_text(ip: str, text: str) -> bool:
    """Return True if ``ip`` occurs in ``text`` as a whole address.

    Plain substring search would match "10.1.1.9" inside "10.1.1.95" or "110.1.1.9";
    the look-arounds reject neighbouring digits/dots but still allow a trailing
    sentence period ("... blocked 10.1.1.9.").
    """
    pattern = r"(?<![\d.])" + re.escape(ip) + r"(?!\d|\.\d)"
    return re.search(pattern, text) is not None


def threat_enrichment(entities: dict, docs: list) -> dict:
    """
    Produce lightweight, purely local enrichment from the retrieved documents.

    - ``malicious_IPs``: entity IPs that appear (as whole addresses) in any doc.
    - ``related_incidents``: ``{"source", "snippet"}`` for docs mentioning a MITRE ID or host.
    - ``snippets``: first 400 chars of docs mentioning any entity (at most 10 kept).

    NOTE(review): when ``entities`` were extracted from these same docs (as the
    enriched-context builder in this notebook does), nearly every extracted IP is
    flagged, so ``malicious_IPs`` reflects retrieval overlap rather than independent
    threat-intel evidence.

    Args:
        entities: dict with optional "IPs", "MITRE", "Hostnames" lists.
        docs: documents exposing ``page_content`` and a ``metadata`` dict.
    """
    ips = entities.get("IPs", [])
    mitres = entities.get("MITRE", [])
    hosts = entities.get("Hostnames", [])
    text_terms = mitres + hosts

    related_incidents = []
    malicious_ips = []
    snippets = []

    for d in docs:
        pc = d.page_content
        ips_here = [ip for ip in ips if _ip_in_text(ip, pc)]
        mentions_term = any(t in pc for t in text_terms)

        # collect snippets that contain any entity
        if ips_here or mentions_term:
            snippets.append(pc[:400])

        # an IP seen in a past incident is treated as malicious (local heuristic)
        for ip in ips_here:
            if ip not in malicious_ips:
                malicious_ips.append(ip)

        # incidents that mention one of the MITRE IDs or hosts
        if mentions_term:
            related_incidents.append({"source": d.metadata.get("source"), "snippet": pc[:300]})

    return {
        "malicious_IPs": malicious_ips,
        "related_incidents": related_incidents,
        "snippets": snippets[:10],
        "note": "Enrichment is local (based on retrieved KB). For external feeds use IOC feeds."
    }


# --- Threat scoring logic ---
def compute_threat_score(entities: dict, docs: list, enrichment: "dict | None" = None) -> tuple[int, dict]:
    """
    Score the threat level on a 0-100 scale.

    Components (summed, then capped at 100):
      - MITRE tags: 12 per entry in ``entities["MITRE"]``, capped at 40
      - Severity: Critical=35, High=25, Medium=12, Low=5 (highest present wins; unknown = 0)
      - Malicious IP presence (per ``threat_enrichment``): +20
      - Suspicious PowerShell indicators in ``docs``: +15
      - Brute-force indicators in ``docs``: +15

    NOTE(review): the PowerShell / brute-force checks scan only the retrieved ``docs``,
    not the query, so they reflect matching historical incidents rather than the
    incident being asked about.

    Args:
        entities: dict with optional "MITRE", "Severity", "IPs", "Hostnames" lists.
        docs: documents exposing ``page_content``.
        enrichment: optional precomputed ``threat_enrichment(entities, docs)`` result;
            pass it to avoid recomputing the enrichment.

    Returns:
        ``(final_score, breakdown)``; ``breakdown`` records each component's inputs and points.
    """
    score = 0
    breakdown = {}

    # MITRE
    mitres = entities.get("MITRE", [])
    mitre_score = min(40, len(mitres) * 12)
    breakdown["mitre_count"] = len(mitres)
    breakdown["mitre_score"] = mitre_score
    score += mitre_score

    # Severity
    sev_map = {"Critical": 35, "High": 25, "Medium": 12, "Low": 5}
    sevs = entities.get("Severity", [])
    sev_score = 0
    if sevs:
        # pick the highest mapped severity
        sev_score = max(sev_map.get(s, 0) for s in sevs)
    breakdown["severity_list"] = sevs
    breakdown["severity_score"] = sev_score
    score += sev_score

    # Malicious IPs
    if enrichment is None:
        enrichment = threat_enrichment(entities, docs)
    malicious_ips = enrichment.get("malicious_IPs", [])
    ip_score = 20 if malicious_ips else 0
    breakdown["malicious_IPs"] = malicious_ips
    breakdown["malicious_ip_score"] = ip_score
    score += ip_score

    # Lower-case the retrieved text once; both indicator checks below reuse it.
    combined_lower = " ".join(d.page_content for d in docs).lower()

    # Powershell indicators
    ps_present = any(tok.lower() in combined_lower for tok in _powershell_indicators)
    ps_score = 15 if ps_present else 0
    breakdown["powershell_present"] = ps_present
    breakdown["powershell_score"] = ps_score
    score += ps_score

    # brute-force indicators
    bf_present = any(kw in combined_lower for kw in _bruteforce_indicators)
    bf_score = 15 if bf_present else 0
    breakdown["bruteforce_present"] = bf_present
    breakdown["bruteforce_score"] = bf_score
    score += bf_score

    # normalize / cap
    final_score = min(100, int(score))
    breakdown["raw_score"] = score
    breakdown["final_score"] = final_score
    return final_score, breakdown


# --- Build enriched context Runnable for the chain ---
def _build_enriched_context(q: str) -> dict:
    """Turn a question into the ``{"context", "question"}`` mapping the prompt expects.

    Pipeline: hybrid retrieval -> entity extraction over the question plus the retrieved
    text -> local threat enrichment and scoring. All of it is serialized as indented JSON
    into ``context``; the retrieved text inside it is truncated to 2000 characters.
    """
    hits = hybrid_retriever.invoke(q)
    joined_hits = "\n\n---\n\n".join(hit.page_content for hit in hits)

    found_entities = extract_entities(q + " " + joined_hits)
    enrichment = threat_enrichment(found_entities, hits)
    score, breakdown = compute_threat_score(found_entities, hits)

    context_payload = dict(
        retrieved_snippets=joined_hits[:2000],
        extracted_entities=found_entities,
        threat_enrichment=enrichment,
        threat_score=score,
        score_breakdown=breakdown,
    )
    return {"context": json.dumps(context_payload, indent=2), "question": q}


# Full chain: the question string is expanded into an enriched {context, question}
# mapping, rendered by the prompt, answered by the LLM, and parsed as JSON.
tool_chain = RunnableLambda(_build_enriched_context) | prompt | llm | json_parser


# --- Updated caching-aware runner ---
def get_kb_result(query: str):
    """
    Run ``tool_chain`` for ``query`` and cache its JSON answer in Redis for one hour.

    This redefines the earlier retrieval-only ``get_kb_result``; after this cell runs,
    all later callers use this version.

    Redis is a best-effort cache here: if a Redis call fails, a message is printed and
    the chain result is still computed and returned (just not cached).

    Returns:
        The parsed chain output (normally a dict). A cached value that is not valid JSON
        is returned as the raw string; a non-dict/list chain result that is not valid
        JSON is wrapped as ``{"raw": str(result)}``.
    """
    key = cache_key(query)
    try:
        cached_result = r.get(key)
    except redis.exceptions.RedisError as exc:
        print(f"Cache unavailable ({exc}); running without cache.")
        cached_result = None

    if cached_result is not None:
        print("Cache hit!")
        try:
            return json.loads(cached_result)
        except ValueError:
            return cached_result

    print("Cache miss! Running RAG chain (with enrichment & scoring).")
    result = tool_chain.invoke(query)

    # ensure serializable and store as JSON
    if isinstance(result, (dict, list)):
        out = result
    else:
        try:
            out = json.loads(str(result))
        except ValueError:
            out = {"raw": str(result)}

    try:
        r.set(key, json.dumps(out), ex=3600)
    except redis.exceptions.RedisError as exc:
        print(f"Cache write failed ({exc}); result not cached.")
    return out


def get_response_with_cache(query: str):
    """Public entry point: delegate to ``get_kb_result`` and hand back whatever it returns."""
    answer = get_kb_result(query)
    return answer


# Example run. On a cold cache (as in the recorded output) this is a miss, so the full
# retrieval + enrichment + LLM chain executes and its parsed answer is stored in Redis.
response = get_response_with_cache(
    "user_id='analyst500', query='Suspicious bound traffic detected from multiple hosts in the network'"
)
print("Response: ", response)

Cache miss! Running RAG chain (with enrichment & scoring).
Response:  {'recommended_resolution': 'Conduct a comprehensive network scan to identify and block suspicious IPs; Alert the network team for further investigation.', 'similar_past_incidents': [{'source': 'security_incidents.txt', 'snippet': 'Incident #034 | User=rakesh | Alert=Outbound traffic to known malicious IP | SourceIP=10.22.3.9 | Host=SRV-DB01 | OS=CentOS 7 | MITRE=T1071 | Severity=High | Resolution=Blocked traffic; Conducted IOC scan.'}, {'source': 'security_incidents.txt', 'snippet': 'Incident #012 | User=kishor | Alert=Detected network port scanning | SourceIP=172.22.9.54 | Host=SRV-WIN-02 | OS=Windows Server 2019 | MITRE=T1046 | Severity=High | Resolution=Blocked IP; Alerted network team.'}], 'entities': {'IPs': ['10.22.3.9', '172.22.9.54'], 'OS': ['Windows Server', 'CentOS', 'Windows'], 'Hostnames': ['SRV-DB01', 'SRV-WIN-02'], 'MITRE': ['T1046', 'T1071'], 'Severity': ['High']}, 'user_id': 'analyst500'}


In [136]:
# Same query again: it hashes to the same Redis key, so within the one-hour TTL the
# stored answer is returned without invoking the chain.
response = get_response_with_cache(
    "user_id='analyst500', query='Suspicious bound traffic detected from multiple hosts in the network'"
)
print("Response: ", response)

Cache hit!
Response:  {'recommended_resolution': 'Conduct a comprehensive network scan to identify and block suspicious IPs; Alert the network team for further investigation.', 'similar_past_incidents': [{'source': 'security_incidents.txt', 'snippet': 'Incident #034 | User=rakesh | Alert=Outbound traffic to known malicious IP | SourceIP=10.22.3.9 | Host=SRV-DB01 | OS=CentOS 7 | MITRE=T1071 | Severity=High | Resolution=Blocked traffic; Conducted IOC scan.'}, {'source': 'security_incidents.txt', 'snippet': 'Incident #012 | User=kishor | Alert=Detected network port scanning | SourceIP=172.22.9.54 | Host=SRV-WIN-02 | OS=Windows Server 2019 | MITRE=T1046 | Severity=High | Resolution=Blocked IP; Alerted network team.'}], 'entities': {'IPs': ['10.22.3.9', '172.22.9.54'], 'OS': ['Windows Server', 'CentOS', 'Windows'], 'Hostnames': ['SRV-DB01', 'SRV-WIN-02'], 'MITRE': ['T1046', 'T1071'], 'Severity': ['High']}, 'user_id': 'analyst500'}
