Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions litellm/proxy/guardrails/guardrail_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,18 +630,21 @@ async def get_guardrail_info(guardrail_id: str):
from litellm.litellm_core_utils.litellm_logging import _get_masked_values
from litellm.proxy.guardrails.guardrail_registry import IN_MEMORY_GUARDRAIL_HANDLER
from litellm.proxy.proxy_server import prisma_client
from litellm.types.guardrails import GUARDRAIL_DEFINITION_LOCATION

if prisma_client is None:
raise HTTPException(status_code=500, detail="Prisma client not initialized")

try:
guardrail_definition_location: GUARDRAIL_DEFINITION_LOCATION = GUARDRAIL_DEFINITION_LOCATION.DB
result = await GUARDRAIL_REGISTRY.get_guardrail_by_id_from_db(
guardrail_id=guardrail_id, prisma_client=prisma_client
)
if result is None:
result = IN_MEMORY_GUARDRAIL_HANDLER.get_guardrail_by_id(
guardrail_id=guardrail_id
)
guardrail_definition_location: GUARDRAIL_DEFINITION_LOCATION = GUARDRAIL_DEFINITION_LOCATION.CONFIG

if result is None:
raise HTTPException(
Expand Down Expand Up @@ -669,6 +672,7 @@ async def get_guardrail_info(guardrail_id: str):
guardrail_info=dict(result.get("guardrail_info") or {}),
created_at=result.get("created_at"),
updated_at=result.get("updated_at"),
guardrail_definition_location=guardrail_definition_location,
)
except HTTPException as e:
raise e
Expand Down
5 changes: 4 additions & 1 deletion litellm/types/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,9 @@ class GuardrailEventHooks(str, Enum):
class DynamicGuardrailParams(TypedDict):
extra_body: Dict[str, Any]

class GUARDRAIL_DEFINITION_LOCATION(str, Enum):
DB = "db"
CONFIG = "config"

class GuardrailInfoResponse(BaseModel):
guardrail_id: Optional[str] = None
Expand All @@ -597,7 +600,7 @@ class GuardrailInfoResponse(BaseModel):
guardrail_info: Optional[Dict] = None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
guardrail_definition_location: Literal["config", "db"] = "config"
guardrail_definition_location: GUARDRAIL_DEFINITION_LOCATION = GUARDRAIL_DEFINITION_LOCATION.CONFIG

def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand Down
68 changes: 67 additions & 1 deletion tests/test_litellm/proxy/guardrails/test_guardrail_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,4 +929,70 @@ async def test_apply_guardrail_execution_error(mocker):
await apply_guardrail(request=request, user_api_key_dict=mock_user_auth)

# Verify error is properly handled
assert "Bedrock guardrail failed" in str(exc_info.value.message)
assert "Bedrock guardrail failed" in str(exc_info.value.message)

@pytest.mark.asyncio
async def test_get_guardrail_info_endpoint_config_guardrail(mocker):
"""
Test get_guardrail_info endpoint returns proper response when guardrail is found in config.
"""
from litellm.proxy.guardrails.guardrail_endpoints import get_guardrail_info

# Mock prisma_client to not be None (patch at the source where it's imported from)
mock_prisma = mocker.Mock()
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma)

# Mock the GUARDRAIL_REGISTRY to return None from DB (so it checks config)
mock_registry = mocker.Mock()
mock_registry.get_guardrail_by_id_from_db = AsyncMock(return_value=None)
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_registry)

# Mock IN_MEMORY_GUARDRAIL_HANDLER at its source to return config guardrail
mock_in_memory_handler = mocker.Mock()
mock_in_memory_handler.get_guardrail_by_id.return_value = MOCK_CONFIG_GUARDRAIL
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)

# Mock _get_masked_values to return values as-is
mocker.patch(
"litellm.litellm_core_utils.litellm_logging._get_masked_values",
side_effect=lambda x, **kwargs: x
)

# Call endpoint and expect GuardrailInfoResponse
result = await get_guardrail_info(guardrail_id="test-config-guardrail")

# Verify the response is of the correct type
assert isinstance(result, GuardrailInfoResponse)
assert result.guardrail_id == "test-config-guardrail"
assert result.guardrail_name == "Test Config Guardrail"
assert result.guardrail_definition_location == "config"

@pytest.mark.asyncio
async def test_get_guardrail_info_endpoint_db_guardrail(mocker):
"""
Test get_guardrail_info endpoint returns proper response when guardrail is found in DB.
"""
from litellm.proxy.guardrails.guardrail_endpoints import get_guardrail_info

# Mock prisma_client to not be None (patch at the source where it's imported from)
mock_prisma = mocker.Mock()
mocker.patch("litellm.proxy.proxy_server.prisma_client", mock_prisma)

# Mock the GUARDRAIL_REGISTRY to return a guardrail from DB
mock_registry = mocker.Mock()
mock_registry.get_guardrail_by_id_from_db = AsyncMock(return_value=MOCK_DB_GUARDRAIL)
mocker.patch("litellm.proxy.guardrails.guardrail_endpoints.GUARDRAIL_REGISTRY", mock_registry)

# Mock IN_MEMORY_GUARDRAIL_HANDLER to return None
mock_in_memory_handler = mocker.Mock()
mock_in_memory_handler.get_guardrail_by_id.return_value = None
mocker.patch("litellm.proxy.guardrails.guardrail_registry.IN_MEMORY_GUARDRAIL_HANDLER", mock_in_memory_handler)

# Call endpoint and expect GuardrailInfoResponse
result = await get_guardrail_info(guardrail_id="test-db-guardrail")

# Verify the response is of the correct type
assert isinstance(result, GuardrailInfoResponse)
assert result.guardrail_id == "test-db-guardrail"
assert result.guardrail_name == "Test DB Guardrail"
assert result.guardrail_definition_location == "db"
113 changes: 113 additions & 0 deletions ui/litellm-dashboard/src/components/guardrails/guardrail_info.test.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import * as networking from "@/components/networking";
import { fireEvent, render, waitFor } from "@testing-library/react";
import { afterEach, describe, expect, it, vi } from "vitest";
import GuardrailInfoView from "./guardrail_info";

// Mock the networking module
vi.mock("@/components/networking", () => ({
getGuardrailInfo: vi.fn(),
getGuardrailUISettings: vi.fn(),
getGuardrailProviderSpecificParams: vi.fn(),
updateGuardrailCall: vi.fn(),
}));

describe("Guardrail Info", () => {
afterEach(() => {
vi.clearAllMocks();
});

it("should render the guardrail info after loading", async () => {
// Mock the network responses
vi.mocked(networking.getGuardrailInfo).mockResolvedValue({
guardrail_id: "123",
guardrail_name: "Test Guardrail",
litellm_params: {
guardrail: "presidio",
mode: "pre_call",
default_on: true,
},
created_at: "2024-01-01T00:00:00Z",
updated_at: "2024-01-01T00:00:00Z",
guardrail_definition_location: "database",
});

vi.mocked(networking.getGuardrailUISettings).mockResolvedValue({
supported_entities: ["PERSON", "EMAIL"],
supported_actions: ["MASK", "REDACT"],
pii_entity_categories: [],
supported_modes: ["pre_call", "post_call"],
});

vi.mocked(networking.getGuardrailProviderSpecificParams).mockResolvedValue({});

const { getAllByText, getByText } = render(
<GuardrailInfoView guardrailId="123" onClose={() => {}} accessToken="123" isAdmin={true} />,
);

// Wait for the loading to complete and data to be rendered
await waitFor(() => {
// The guardrail name appears in multiple places (title and settings tab)
const elements = getAllByText("Test Guardrail");
expect(elements.length).toBeGreaterThan(0);
});

// Verify other key elements are present
expect(getByText("Back to Guardrails")).toBeInTheDocument();
expect(getByText("Overview")).toBeInTheDocument();
expect(getByText("Settings")).toBeInTheDocument();
});

it("should not render the edit button for config guardrails", async () => {
// Mock the network responses
vi.mocked(networking.getGuardrailInfo).mockResolvedValue({
guardrail_id: "123",
guardrail_name: "Test Guardrail",
litellm_params: {
guardrail: "presidio",
mode: "pre_call",
default_on: true,
},
created_at: "2024-01-01T00:00:00Z",
updated_at: "2024-01-01T00:00:00Z",
guardrail_definition_location: "config",
});

vi.mocked(networking.getGuardrailUISettings).mockResolvedValue({
supported_entities: ["PERSON", "EMAIL"],
supported_actions: ["MASK", "REDACT"],
pii_entity_categories: [],
supported_modes: ["pre_call", "post_call"],
});

vi.mocked(networking.getGuardrailProviderSpecificParams).mockResolvedValue({});

const { getByText, container } = render(
<GuardrailInfoView guardrailId="123" onClose={() => {}} accessToken="123" isAdmin={true} />,
);

await waitFor(() => {
expect(getByText("Settings")).toBeInTheDocument();
});

// Click the Settings tab
fireEvent.click(getByText("Settings"));

// Wait for the Settings panel to render
await waitFor(() => {
expect(getByText("Guardrail Settings")).toBeInTheDocument();
});

// Find the info icon and hover over it
const infoIcon = container.querySelector(".anticon-info-circle");
expect(infoIcon).toBeInTheDocument();

if (infoIcon) {
fireEvent.mouseEnter(infoIcon);

// Wait for the tooltip to appear
await waitFor(() => {
expect(getByText("Guardrail is defined in the config file and cannot be edited.")).toBeInTheDocument();
});
}
});
});
14 changes: 12 additions & 2 deletions ui/litellm-dashboard/src/components/guardrails/guardrail_info.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ import {
TabPanels,
TextInput,
} from "@tremor/react";
import { Button, Form, Input, Select, Divider } from "antd";
import { Button, Form, Input, Select, Divider, Tooltip } from "antd";
import { InfoCircleOutlined } from "@ant-design/icons";
import {
getGuardrailInfo,
updateGuardrailCall,
Expand Down Expand Up @@ -328,6 +329,8 @@ const GuardrailInfoView: React.FC<GuardrailInfoProps> = ({ guardrailId, onClose,
}
};

const isConfigGuardrail = guardrailData.guardrail_definition_location === "config";

return (
<div className="p-4">
<div>
Expand Down Expand Up @@ -434,7 +437,14 @@ const GuardrailInfoView: React.FC<GuardrailInfoProps> = ({ guardrailId, onClose,
<Card>
<div className="flex justify-between items-center mb-4">
<Title>Guardrail Settings</Title>
{!isEditing && <TremorButton onClick={() => setIsEditing(true)}>Edit Settings</TremorButton>}
{isConfigGuardrail && (
<Tooltip title="Guardrail is defined in the config file and cannot be edited.">
<InfoCircleOutlined />
</Tooltip>
)}
{!isEditing && !isConfigGuardrail && (
<TremorButton onClick={() => setIsEditing(true)}>Edit Settings</TremorButton>
)}
</div>

{isEditing ? (
Expand Down
Loading