Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 89 additions & 22 deletions nemoguardrails/library/jailbreak_detection/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,25 @@

import logging
import os
from typing import Optional
from time import time
from typing import Dict, Optional

from nemoguardrails.actions import action
from nemoguardrails.context import llm_call_info_var
from nemoguardrails.library.jailbreak_detection.request import (
jailbreak_detection_heuristics_request,
jailbreak_detection_model_request,
jailbreak_nim_request,
)
from nemoguardrails.llm.cache import CacheInterface
from nemoguardrails.llm.cache.utils import (
CacheEntry,
create_normalized_cache_key,
get_from_cache_and_restore_stats,
)
from nemoguardrails.llm.taskmanager import LLMTaskManager
from nemoguardrails.logging.explain import LLMCallInfo
from nemoguardrails.logging.processing_log import processing_log_var

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,6 +99,7 @@ async def jailbreak_detection_heuristics(
async def jailbreak_detection_model(
llm_task_manager: LLMTaskManager,
context: Optional[dict] = None,
model_caches: Optional[Dict[str, CacheInterface]] = None,
) -> bool:
"""Uses a trained classifier to determine if a user input is a jailbreak attempt"""
prompt: str = ""
Expand All @@ -102,6 +113,30 @@ async def jailbreak_detection_model(
if context is not None:
prompt = context.get("user_message", "")

# we do this as a hack to treat this action as an LLM call for tracing
llm_call_info_var.set(LLMCallInfo(task="jailbreak_detection_model"))

cache = model_caches.get("jailbreak_detection") if model_caches else None

if cache:
cache_key = create_normalized_cache_key(prompt)
cache_read_start = time()
cached_result = get_from_cache_and_restore_stats(cache, cache_key)
if cached_result is not None:
cache_read_duration = time() - cache_read_start
llm_call_info = llm_call_info_var.get()
if llm_call_info:
llm_call_info.from_cache = True
llm_call_info.duration = cache_read_duration
llm_call_info.started_at = time() - cache_read_duration
llm_call_info.finished_at = time()

log.debug("Jailbreak detection cache hit")
return cached_result["jailbreak"]

jailbreak_result = None
api_start_time = time()

if not jailbreak_api_url and not nim_base_url:
from nemoguardrails.library.jailbreak_detection.model_based.checks import (
check_jailbreak,
Expand All @@ -114,32 +149,64 @@ async def jailbreak_detection_model(
try:
jailbreak = check_jailbreak(prompt=prompt)
log.info(f"Local model jailbreak detection result: {jailbreak}")
return jailbreak["jailbreak"]
jailbreak_result = jailbreak["jailbreak"]
except RuntimeError as e:
log.error(f"Jailbreak detection model not available: {e}")
return False
jailbreak_result = False
except ImportError as e:
log.error(
f"Failed to import required dependencies for local model. Install scikit-learn and torch, or use NIM-based approach",
exc_info=e,
)
return False

if nim_base_url:
jailbreak = await jailbreak_nim_request(
prompt=prompt,
nim_url=nim_base_url,
nim_auth_token=nim_auth_token,
nim_classification_path=nim_classification_path,
)
elif jailbreak_api_url:
jailbreak = await jailbreak_detection_model_request(
prompt=prompt, api_url=jailbreak_api_url
)

if jailbreak is None:
log.warning("Jailbreak endpoint not set up properly.")
# If no result, assume not a jailbreak
return False
jailbreak_result = False
else:
return jailbreak
if nim_base_url:
jailbreak = await jailbreak_nim_request(
prompt=prompt,
nim_url=nim_base_url,
nim_auth_token=nim_auth_token,
nim_classification_path=nim_classification_path,
)
elif jailbreak_api_url:
jailbreak = await jailbreak_detection_model_request(
prompt=prompt, api_url=jailbreak_api_url
)

if jailbreak is None:
log.warning("Jailbreak endpoint not set up properly.")
jailbreak_result = False
else:
jailbreak_result = jailbreak

api_duration = time() - api_start_time

llm_call_info = llm_call_info_var.get()
if llm_call_info:
llm_call_info.from_cache = False
llm_call_info.duration = api_duration
llm_call_info.started_at = api_start_time
llm_call_info.finished_at = time()

processing_log = processing_log_var.get()
if processing_log is not None:
processing_log.append(
{
"type": "llm_call_info",
"timestamp": time(),
"data": llm_call_info,
}
)

if cache:
from nemoguardrails.llm.cache.utils import extract_llm_metadata_for_cache

cache_key = create_normalized_cache_key(prompt)
cache_entry: CacheEntry = {
"result": {"jailbreak": jailbreak_result},
"llm_stats": None,
"llm_metadata": extract_llm_metadata_for_cache(),
}
cache.put(cache_key, cache_entry)
log.debug("Jailbreak detection result cached")

return jailbreak_result
3 changes: 3 additions & 0 deletions nemoguardrails/llm/cache/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def get_from_cache_and_restore_stats(
if cached_metadata:
restore_llm_metadata_from_cache(cached_metadata)

if cached_metadata:
restore_llm_metadata_from_cache(cached_metadata)

processing_log = processing_log_var.get()
if processing_log is not None:
llm_call_info = llm_call_info_var.get()
Expand Down
2 changes: 1 addition & 1 deletion nemoguardrails/rails/llm/llmrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def _init_llms(self):
llms = dict()

for llm_config in self.config.models:
if llm_config.type == "embeddings":
if llm_config.type in ["embeddings", "jailbreak_detection"]:
continue

# If a constructor LLM is provided, skip initializing any 'main' model from config
Expand Down
Loading