Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 210 additions & 0 deletions agent/src/attack_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
"""
Argus Agent — attack chain correlation engine
Copyright (c) 2026 Kaushikkumaran

Correlates individual Falco alerts into multi-stage attack chains.
Uses a 30-minute sliding window to group related alerts by namespace/node.

MITRE ATT&CK kill chain stages mapped to Falco rules:
Reconnaissance → port scans, service enumeration
Initial Access → shell spawn, unexpected process
Execution → curl/wget, script execution, memfd
Persistence → cron modification, startup file write
Privilege Escalation → capability grant, sudo, SUID
Lateral Movement → cross-namespace traffic, SSH
Exfiltration → large outbound transfer, DNS tunneling
Defense Evasion → log clearing, binary deletion
"""

import hashlib
import time
from collections import defaultdict
from datetime import datetime, timezone
import structlog

log = structlog.get_logger()

CHAIN_WINDOW_SECONDS = 1800 # 30 minutes
MAX_CHAINS = 50

# MITRE kill chain stage mapping
RULE_TO_STAGE = {
"fileless execution via memfd_create": "Execution",
"shell spawned in container": "Initial Access",
"curl or wget executed in container": "Execution",
"read sensitive file untrusted": "Reconnaissance",
"write below etc": "Persistence",
"clear log activities": "Defense Evasion",
"contact k8s api server from container": "Reconnaissance",
"network tool launched in container": "Reconnaissance",
"modify binary dirs": "Persistence",
"sudo potential privilege escalation": "Privilege Escalation",
"launch privileged container": "Privilege Escalation",
"outbound connection to c2 server": "Exfiltration",
"exfiltration over alternative protocol": "Exfiltration",
"container escape attempt": "Privilege Escalation",
"ptrace attached to process": "Execution",
}

STAGE_ORDER = [
"Reconnaissance",
"Initial Access",
"Execution",
"Privilege Escalation",
"Persistence",
"Defense Evasion",
"Lateral Movement",
"Exfiltration",
]

STAGE_MITRE = {
"Reconnaissance": "TA0043",
"Initial Access": "TA0001",
"Execution": "TA0002",
"Privilege Escalation": "TA0004",
"Persistence": "TA0003",
"Defense Evasion": "TA0005",
"Lateral Movement": "TA0008",
"Exfiltration": "TA0010",
}

# In-memory chain store
attack_chains: list[dict] = []

# Active correlation windows: key → list of alerts
_windows: dict[str, list[dict]] = defaultdict(list)


def _correlation_key(audit_entry: dict) -> str:
"""Group alerts by namespace+node — same attacker likely in same area."""
ns = audit_entry.get("namespace") or "host"
node = audit_entry.get("hostname") or "unknown"
return f"{ns}:{node}"


def _get_stage(rule: str) -> str:
    """Resolve a Falco rule name to its MITRE kill chain stage.

    Matching is a case-insensitive substring scan over RULE_TO_STAGE
    (first match wins); unmatched rules default to "Execution".
    """
    needle = rule.lower()
    match = next(
        (stage for pattern, stage in RULE_TO_STAGE.items() if pattern in needle),
        None,
    )
    return match if match is not None else "Execution"


def _chain_confidence(stages: list[str]) -> float:
"""
More stages = higher confidence this is a real attack.
Single stage: low confidence.
3+ distinct stages: high confidence.
"""
distinct = len(set(stages))
if distinct == 1:
return 0.35
elif distinct == 2:
return 0.60
elif distinct == 3:
return 0.80
else:
return min(0.95, 0.80 + (distinct - 3) * 0.05)


def _build_chain(key: str, alerts: list[dict]) -> dict:
    """Build a chain object from correlated alerts.

    `alerts` must be non-empty; each entry carries "rule", "stage", "ts" and
    optional location keys (see correlate_alert). The returned dict is the
    serialized schema exposed via the /attack-chains API — field names and
    ordering should be treated as a stable contract.
    """
    stages = [a["stage"] for a in alerts]
    # Order detected stages by canonical kill-chain position; unknown stages
    # sort last (index 99).
    stage_order_indices = {s: i for i, s in enumerate(STAGE_ORDER)}
    sorted_stages = sorted(set(stages), key=lambda s: stage_order_indices.get(s, 99))

    first_ts = min(a["ts"] for a in alerts)
    last_ts = max(a["ts"] for a in alerts)
    duration_seconds = int(last_ts - first_ts)

    # Location metadata is taken from the oldest alert in insertion order.
    ns = alerts[0].get("namespace") or "host"
    node = alerts[0].get("hostname") or "unknown"
    pod = alerts[0].get("pod") or "unknown"

    # Deterministic non-cryptographic id: same bucket + same window start
    # → same chain, so refreshed windows update the existing chain.
    chain_id = hashlib.md5(f"{key}{first_ts}".encode()).hexdigest()[:12]

    return {
        "id": chain_id,
        "created_at": datetime.fromtimestamp(first_ts, tz=timezone.utc).isoformat(),
        "last_seen": datetime.fromtimestamp(last_ts, tz=timezone.utc).isoformat(),
        "duration_seconds": duration_seconds,
        "namespace": ns,
        "hostname": node,
        "pod": pod,
        "alert_count": len(alerts),
        "stages_detected": sorted_stages,
        "stage_count": len(sorted_stages),
        "confidence": _chain_confidence(stages),
        # Severity scales with the breadth of the kill chain covered.
        "severity": "CRITICAL" if len(sorted_stages) >= 3 else "HIGH" if len(sorted_stages) == 2 else "MED",
        "alerts": [
            {
                "rule": a["rule"],
                "stage": a["stage"],
                "mitre_tactic": STAGE_MITRE.get(a["stage"], ""),
                "ts": a["ts"],
                "severity": a.get("severity", "MED"),
            }
            # Emit alerts in chronological order.
            for a in sorted(alerts, key=lambda x: x["ts"])
        ],
        "mitre_tactics": [STAGE_MITRE.get(s, "") for s in sorted_stages],
    }


def correlate_alert(audit_entry: dict) -> dict | None:
    """
    Feed one Falco alert into the sliding correlation window.

    Returns the (new or refreshed) chain object when the window for this
    namespace/node bucket spans at least two distinct kill chain stages;
    returns None otherwise.
    """
    ts_now = time.time()
    window_key = _correlation_key(audit_entry)

    entry = {
        "rule": audit_entry.get("rule", "unknown"),
        "stage": _get_stage(audit_entry.get("rule", "")),
        "ts": ts_now,
        "severity": audit_entry.get("severity", "MED"),
        "namespace": audit_entry.get("namespace"),
        "hostname": audit_entry.get("hostname"),
        "pod": audit_entry.get("pod"),
    }

    # Drop alerts that aged out of the window, then record the new one.
    window = [a for a in _windows[window_key] if ts_now - a["ts"] < CHAIN_WINDOW_SECONDS]
    window.append(entry)
    _windows[window_key] = window

    # A chain needs at least two alerts...
    if len(window) < 2:
        return None

    # ...covering at least two distinct kill chain stages.
    if len({a["stage"] for a in window}) < 2:
        return None

    chain = _build_chain(window_key, window)

    # Refresh a known chain in place; otherwise register a new one.
    for known in attack_chains:
        if known["id"] == chain["id"]:
            known.update(chain)
            break
    else:
        attack_chains.append(chain)
        if len(attack_chains) > MAX_CHAINS:
            attack_chains.pop(0)
        log.info(
            "attack_chain_detected",
            chain_id=chain["id"],
            stages=chain["stages_detected"],
            confidence=chain["confidence"],
            namespace=chain["namespace"],
            alert_count=chain["alert_count"],
        )

    return chain
13 changes: 13 additions & 0 deletions agent/src/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,19 @@ async def audit_log(
if len(incident_store) > 500:
incident_store.pop(0)

try:
from attack_chain import correlate_alert
chain = correlate_alert(entry)
if chain:
log.info(
"attack_chain_updated",
chain_id=chain["id"],
stages=chain["stages_detected"],
confidence=chain["confidence"],
)
except Exception as e:
log.warning("attack_chain_error", error=str(e))

log.info(
"argus_audit",
**{k: v for k, v in entry.items() if k not in ("assessment", "blast_radius")}
Expand Down
59 changes: 57 additions & 2 deletions agent/src/enricher.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,56 @@ async def fetch_policy_violations(namespace: str) -> list[dict] | None:
return None


async def fetch_vulnerability_report(namespace: str, pod_name: str) -> dict | None:
    """
    Fetch Trivy vulnerability data for the pod's container image.

    Reads aquasecurity.github.io/v1alpha1 VulnerabilityReport CRDs in the
    pod's namespace and aggregates CRITICAL/HIGH counts across every report
    whose name contains the pod name.

    Returns None (best-effort) when inputs are missing, no Kubernetes client
    is available, or the API call fails.
    """
    if not namespace or not pod_name:
        return None
    try:
        k8s = _get_k8s_client()
        if not k8s:
            return None
        custom_api = k8s.CustomObjectsApi()
        # get_running_loop(): get_event_loop() is deprecated inside coroutines.
        loop = asyncio.get_running_loop()
        reports = await loop.run_in_executor(
            None,
            lambda: custom_api.list_namespaced_custom_object(
                group="aquasecurity.github.io",
                version="v1alpha1",
                namespace=namespace,
                plural="vulnerabilityreports",
            )
        )
        total_critical = 0
        total_high = 0
        top_cves = []
        for report in reports.get("items", []):
            # NOTE(review): substring match assumes report names embed the
            # pod/workload name — confirm against Trivy operator naming.
            if pod_name in report.get("metadata", {}).get("name", ""):
                summary = report.get("report", {}).get("summary", {})
                total_critical += summary.get("criticalCount", 0)
                total_high += summary.get("highCount", 0)
                vulns = report.get("report", {}).get("vulnerabilities", [])
                # Bug fix: filter to CRITICAL/HIGH *before* truncating — the
                # previous code sliced the first 3 entries and could miss
                # severe CVEs appearing later in the report.
                severe = [v for v in vulns if v.get("severity") in ("CRITICAL", "HIGH")]
                for v in severe[:3]:
                    top_cves.append({
                        "id": v.get("vulnerabilityID"),
                        "severity": v.get("severity"),
                        "package": v.get("resource"),
                        "fixed_version": v.get("fixedVersion"),
                    })
        return {
            "critical_count": total_critical,
            "high_count": total_high,
            "top_cves": top_cves[:5],
            # Simple weighted severity score, capped at 100.
            "risk_score": min(100, total_critical * 20 + total_high * 5),
        }
    except Exception as e:
        # Best-effort enrichment: log and degrade gracefully rather than fail.
        log.warning("trivy_fetch_failed", pod=pod_name, namespace=namespace, error=str(e))
        return None


async def enrich_context(alert_payload: dict) -> dict:
"""
Main enrichment entry point. Runs all data source queries concurrently.
Expand Down Expand Up @@ -310,19 +360,20 @@ async def enrich_context(alert_payload: dict) -> dict:

# Run all queries concurrently with timeout
try:
pod_ctx, logs, flows, violations = await asyncio.wait_for(
pod_ctx, logs, flows, violations, vuln_report = await asyncio.wait_for(
asyncio.gather(
fetch_pod_context(namespace, pod_name),
fetch_recent_logs(namespace, pod_name),
fetch_network_flows(namespace, pod_name),
fetch_policy_violations(namespace),
fetch_vulnerability_report(namespace, pod_name),
return_exceptions=True,
),
timeout=ENRICHMENT_TIMEOUT,
)
except asyncio.TimeoutError:
log.warning("enrichment_timeout", timeout=ENRICHMENT_TIMEOUT)
pod_ctx = logs = flows = violations = None
pod_ctx = logs = flows = violations = vuln_report = None

# Handle exceptions from individual queries (return_exceptions=True)
if isinstance(pod_ctx, Exception):
Expand All @@ -333,6 +384,8 @@ async def enrich_context(alert_payload: dict) -> dict:
flows = None
if isinstance(violations, Exception):
violations = None
if isinstance(vuln_report, Exception):
vuln_report = None

duration_ms = round((asyncio.get_event_loop().time() - start) * 1000)

Expand All @@ -342,6 +395,7 @@ async def enrich_context(alert_payload: dict) -> dict:
("loki", logs),
("hubble", flows),
("kyverno", violations),
("trivy", vuln_report),
]
if val is not None
]
Expand All @@ -359,6 +413,7 @@ async def enrich_context(alert_payload: dict) -> dict:
"logs": logs,
"flows": flows,
"violations": violations,
"vulnerabilities": vuln_report,
"enrichment_duration_ms": duration_ms,
"enrichment_sources": successful_sources,
}
16 changes: 16 additions & 0 deletions agent/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,22 @@ async def get_incident_stats():
"total_all_time": len(incident_store),
}

@app.get("/attack-chains")
async def get_attack_chains():
    """List the 20 most recent attack chains, newest first."""
    from attack_chain import attack_chains

    recent = attack_chains[-20:]
    recent.reverse()
    return {"chains": recent, "total": len(attack_chains)}

@app.get("/attack-chains/{chain_id}")
async def get_attack_chain(chain_id: str):
    """Fetch one attack chain by id; 404 when it is unknown."""
    from attack_chain import attack_chains

    for chain in attack_chains:
        if chain["id"] == chain_id:
            return chain
    raise HTTPException(status_code=404, detail="Chain not found")


@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
Expand Down
13 changes: 13 additions & 0 deletions agent/src/reasoning.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def _build_user_prompt(context: dict) -> str:
logs = context.get("logs")
flows = context.get("flows")
violations = context.get("violations")
vulnerabilities = context.get("vulnerabilities")

sections = []

Expand Down Expand Up @@ -221,6 +222,18 @@ def _build_user_prompt(context: dict) -> str:
else:
sections.append("## Active Policy Violations\nUnavailable.")

# Image vulnerability context
if vulnerabilities:
sections.append(f"""## Image Vulnerabilities
Critical CVEs: {vulnerabilities.get("critical_count", 0)}
High CVEs: {vulnerabilities.get("high_count", 0)}
Image risk score: {vulnerabilities.get("risk_score", 0)}/100
Top CVEs: {json.dumps(vulnerabilities.get("top_cves", []))}""")
elif vulnerabilities == {}:
sections.append("## Image Vulnerabilities\nNone reported.")
else:
sections.append("## Image Vulnerabilities\nUnavailable.")

return "\n\n".join(sections)


Expand Down
Loading
Loading