Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Thin REST client for Cisco AI Defense Chat Inspection.
# Uses httpx.AsyncClient and the OpenAPI-defined endpoint/header.
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any

try:
Expand Down Expand Up @@ -37,11 +37,15 @@ class AIDefenseClient:
timeout_s: Timeout in seconds
"""

api_key: str
api_key: str = field(repr=False)
endpoint_url: str
timeout_s: float

_client: httpx.AsyncClient | None = None # type: ignore[name-defined]
_client: httpx.AsyncClient | None = field( # type: ignore[name-defined]
default=None,
repr=False,
compare=False,
)

async def _get_client(self) -> httpx.AsyncClient: # type: ignore[name-defined]
if not AI_DEFENSE_HTTPX_AVAILABLE: # pragma: no cover
Expand Down Expand Up @@ -88,7 +92,7 @@ async def close(self) -> None:
"""Close the HTTP client and release resources."""
await self.aclose()

async def __aenter__(self) -> "AIDefenseClient":
async def __aenter__(self) -> AIDefenseClient:
"""Async context manager entry."""
return self

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

from typing import Any, Literal
from pydantic import Field

from agent_control_evaluators import EvaluatorConfig
from pydantic import Field


class CiscoAIDefenseConfig(EvaluatorConfig):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from typing import Any
import json
import os
from typing import Any

from agent_control_evaluators import (
Evaluator,
Expand All @@ -10,7 +11,7 @@
)
from agent_control_models import EvaluatorResult

from .client import REGION_BASE_URLS, AIDefenseClient, build_endpoint, AI_DEFENSE_HTTPX_AVAILABLE
from .client import AI_DEFENSE_HTTPX_AVAILABLE, REGION_BASE_URLS, AIDefenseClient, build_endpoint
from .config import CiscoAIDefenseConfig


Expand Down Expand Up @@ -46,10 +47,49 @@ def _build_messages(
# Fallback to single

role = "assistant" if payload_field == "output" else "user"
content = "" if data is None else str(data)
content = _coerce_message_content(data, payload_field)
return [{"role": role, "content": content}]


def _coerce_message_content(data: Any, payload_field: str | None) -> str:
if data is None:
return ""

if isinstance(data, dict):
candidate: Any = None
preferred_keys: list[str] = []
if payload_field is not None:
preferred_keys.append(payload_field)
if payload_field == "output":
preferred_keys.extend(["output", "content", "text", "message"])
else:
preferred_keys.extend(["input", "content", "text", "message"])

for key in preferred_keys:
if key in data and data[key] is not None:
candidate = data[key]
break

if candidate is None:
candidate = data
return _stringify_message_content(candidate)

return _stringify_message_content(data)


def _stringify_message_content(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value
if isinstance(value, (int, float, bool)):
return str(value)
try:
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
except TypeError:
return str(value)


@register_evaluator
class CiscoAIDefenseEvaluator(Evaluator[CiscoAIDefenseConfig]):
"""Cisco AI Defense evaluator.
Expand Down Expand Up @@ -79,7 +119,7 @@ def __init__(self, config: CiscoAIDefenseConfig) -> None:
# API key
try:
api_key = _load_api_key(self.config.api_key_env)
except Exception as e: # noqa: BLE001
except RuntimeError as e:
# Fail fast during construction so misconfiguration is caught early.
raise ValueError(str(e)) from e

Expand Down Expand Up @@ -144,27 +184,34 @@ async def evaluate(self, data: Any) -> EvaluatorResult: # noqa: D401

# If no boolean is present, consider it an evaluator error
fallback = self.config.on_error
matched = fallback == "deny"
error_message = "Cisco AI Defense response missing 'is_safe'"
meta2: dict[str, Any] = {"fallback_action": fallback}
if self.config.include_raw_response:
meta2["raw"] = response
return EvaluatorResult(
matched=(fallback == "deny"),
matched=matched,
confidence=0.0,
message="Cisco AI Defense response missing 'is_safe'",
message=error_message,
metadata=meta2,
error=None if matched else error_message,
)
except Exception as e: # noqa: BLE001
fallback = self.config.on_error
matched = fallback == "deny"
# Pydantic model enforces: if error is set, matched must be False.
# Expose details via metadata always; set error field only on fail-open.
error_detail = str(e)
return EvaluatorResult(
matched=matched,
confidence=0.0,
message=f"Cisco AI Defense evaluation error: {e}",
message=f"Cisco AI Defense evaluation error: {error_detail}",
metadata={
"error": str(e),
"error": error_detail,
"error_type": type(e).__name__,
"fallback_action": fallback,
},
error=None if matched else error_detail,
)

async def aclose(self) -> None:
"""Close the underlying Cisco AI Defense client."""
await self._client.aclose()
1 change: 1 addition & 0 deletions evaluators/contrib/cisco/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for the Cisco AI Defense contrib evaluator."""
20 changes: 11 additions & 9 deletions evaluators/contrib/cisco/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import importlib
import sys
from types import SimpleNamespace
from typing import Any, Dict, List
from typing import Any

import pytest

Expand All @@ -13,24 +15,24 @@

@pytest.mark.asyncio
async def test_chat_inspect_happy_path_builds_headers_and_payload(monkeypatch: pytest.MonkeyPatch) -> None:
captured: Dict[str, Any] = {}
captured: dict[str, Any] = {}

class FakeResponse:
def __init__(self, data: Dict[str, Any]):
def __init__(self, data: dict[str, Any]):
self._data = data

def raise_for_status(self) -> None: # no-op
return None

def json(self) -> Dict[str, Any]:
def json(self) -> dict[str, Any]:
return self._data

class FakeAsyncClient:
def __init__(self, *_, **kwargs: Any):
captured["timeout"] = kwargs.get("timeout")
self.is_closed = False

async def post(self, url: str, json: Dict[str, Any], headers: Dict[str, str]):
async def post(self, url: str, json: dict[str, Any], headers: dict[str, str]):
captured["url"] = url
captured["json"] = json
captured["headers"] = headers
Expand Down Expand Up @@ -68,6 +70,7 @@ async def aclose(self) -> None:
assert captured["json"]["messages"][0]["content"] == "hello"
assert captured["json"]["metadata"] == {"trace_id": "t1"}
assert captured["json"]["config"] == {"mode": "strict"}
assert "api_key='k'" not in repr(c)


@pytest.mark.asyncio
Expand Down Expand Up @@ -108,7 +111,7 @@ class FakeResponse:
def raise_for_status(self) -> None:
raise FakeHTTPError("bad status")

def json(self) -> Dict[str, Any]: # never reached
def json(self) -> dict[str, Any]: # never reached
return {}

class FakeAsyncClient:
Expand All @@ -134,7 +137,7 @@ async def aclose(self) -> None:

@pytest.mark.asyncio
async def test_get_client_lifecycle_create_reuse_recreate(monkeypatch: pytest.MonkeyPatch) -> None:
instances: List["FakeAsyncClient"] = []
instances: list[FakeAsyncClient] = []

class FakeAsyncClient:
def __init__(self, *_, **kwargs: Any):
Expand Down Expand Up @@ -198,7 +201,7 @@ def test_build_endpoint_trailing_slash() -> None:
async def test_importerror_path_disables_httpx_and_get_client_raises(monkeypatch: pytest.MonkeyPatch) -> None:
# Ensure a clean reimport of the client module with ImportError for httpx
monkeypatch.setitem(sys.modules, "httpx", None)

class ImportBlocker:
def find_spec(self, fullname, path=None, target=None): # type: ignore[no-untyped-def]
if fullname == "httpx":
Expand All @@ -222,4 +225,3 @@ def find_spec(self, fullname, path=None, target=None): # type: ignore[no-untype
# Best effort: if httpx is available, restore it by deleting None placeholder
if sys.modules.get("httpx") is None:
del sys.modules["httpx"]

23 changes: 19 additions & 4 deletions evaluators/contrib/cisco/tests/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ async def fake(self: AIDefenseClient, **kwargs):
ev = CiscoAIDefenseEvaluator(cfg)
res = await ev.evaluate("text")
assert res.matched is False
assert res.error == "Cisco AI Defense response missing 'is_safe'"
assert res.metadata and res.metadata.get("fallback_action") == "allow"
assert "raw" not in (res.metadata or {})

Expand Down Expand Up @@ -135,10 +136,8 @@ async def capture(self: AIDefenseClient, **_):
assert captured["endpoint_url"] == "https://example.com/custom/chat"


## Removed: internal client reuse test

@pytest.mark.asyncio
async def test_on_error_allow_fail_open_no_error_field(monkeypatch: pytest.MonkeyPatch) -> None:
async def test_on_error_allow_fail_open_sets_error_field(monkeypatch: pytest.MonkeyPatch) -> None:
async def boom(self: AIDefenseClient, **kwargs):
raise RuntimeError("network down")

Expand All @@ -148,7 +147,7 @@ async def boom(self: AIDefenseClient, **kwargs):
ev = CiscoAIDefenseEvaluator(cfg)
res = await ev.evaluate("anything")
assert res.matched is False
assert res.error is None
assert res.error == "network down"
assert res.metadata and res.metadata.get("fallback_action") == "allow"


Expand Down Expand Up @@ -201,6 +200,22 @@ async def capture(self: AIDefenseClient, messages, **_):
assert captured["messages"] == [{"role": "user", "content": "hello world"}]


@pytest.mark.asyncio
async def test_dict_input_prefers_input_field_over_python_repr(monkeypatch: pytest.MonkeyPatch) -> None:
captured = {}

async def capture(self: AIDefenseClient, messages, **_):
captured["messages"] = messages
return {"is_safe": True}

monkeypatch.setattr(AIDefenseClient, "chat_inspect", capture, raising=True)

cfg = CiscoAIDefenseConfig(messages_strategy="single", payload_field="input")
ev = CiscoAIDefenseEvaluator(cfg)
_ = await ev.evaluate({"input": "hello world", "extra": "ignored"})
assert captured["messages"] == [{"role": "user", "content": "hello world"}]


def test_on_error_validation() -> None:
"""on_error must be either 'allow' or 'deny'."""
# Valid values
Expand Down
4 changes: 2 additions & 2 deletions examples/cisco_ai_defense/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
help:
@echo "Cisco AI Defense examples"
@echo " make run - run direct Chat Inspection demo"
@echo " make seed - seed controls/policy on server"
@echo " make register - register agent and persist .agent_id"
@echo " make seed - seed controls on the server and attach them to the agent"
@echo " make register - register the agent by name"
@echo " make decorator-post-run - run decorator POST-only (PII) example"
@echo " make decorator-all-run - run combined pre + post cases"

Expand Down
6 changes: 6 additions & 0 deletions examples/cisco_ai_defense/chat_inspect_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ class InspectOutcome:


class ChatInspectClient:
"""Standalone direct-HTTP client used by the demo.

This example intentionally avoids importing the contrib evaluator package so
the direct API demo can run with only the example environment dependencies.
"""

def __init__(
self,
api_key: str,
Expand Down
17 changes: 13 additions & 4 deletions examples/cisco_ai_defense/setup_ai_defense_controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,19 @@ async def _ensure_control(
except Exception as e: # noqa: BLE001
s = str(e).lower()
if "409" in s or "already" in s:
existing = await controls.list_controls(client, name=name, limit=1)
items = existing.get("controls", [])
if items:
return int(items[0]["id"]) # type: ignore[index]
cursor: int | None = None
while True:
existing = await controls.list_controls(client, name=name, limit=100, cursor=cursor)
items = existing.get("controls", [])
for item in items:
if item.get("name") == name:
return int(item["id"]) # type: ignore[index]

pagination = existing.get("pagination", {})
if not pagination.get("has_more"):
break
next_cursor = pagination.get("next_cursor")
cursor = int(next_cursor) if next_cursor is not None else None
raise


Expand Down
Loading