Skip to content

Commit

Permalink
feat(langchain): add langchain module
Browse files Browse the repository at this point in the history
  • Loading branch information
arekusandr committed Apr 12, 2024
1 parent 834c7e9 commit 753c9e1
Show file tree
Hide file tree
Showing 6 changed files with 1,099 additions and 81 deletions.
4 changes: 2 additions & 2 deletions last_layer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .core import RiskModel, scan_llm, scan_prompt
from .core import RiskModel, Threat, scan_llm, scan_prompt

__all__ = ["scan_llm", "scan_prompt", "RiskModel"]
__all__ = ["scan_llm", "scan_prompt", "RiskModel", "Threat"]
3 changes: 3 additions & 0 deletions last_layer/langchain_module/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Public API of the langchain integration package: re-export the security
# wrapper so callers can write
# `from last_layer.langchain_module import LastLayerSecurity`.
from .llm import LastLayerSecurity

__all__ = ["LastLayerSecurity"]
43 changes: 43 additions & 0 deletions last_layer/langchain_module/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import logging
from typing import Any, Callable, List

import last_layer

from langchain_core.language_models.llms import BaseLLM

logger = logging.getLogger(__name__)


def default_handler(text: str, risk: "last_layer.RiskModel") -> None:
    """Default risk callback: warn when a scan did not pass.

    Used as the default for both prompt and response handling in
    ``LastLayerSecurity``.  Does nothing when ``risk.passed`` is true.

    Args:
        text: The scanned prompt or generated text.
        risk: Scan result from ``last_layer``; only ``passed`` is inspected.
    """
    if risk.passed:
        return
    # Lazy %-style arguments: the message is only formatted if the
    # warning is actually emitted (logging best practice for hot paths).
    logger.warning("Security risk: %s detected in text: %s", risk, text)


class LastLayerSecurity(BaseLLM):
    """LLM wrapper that scans prompts and generations for security risks.

    Delegates generation to the wrapped ``llm`` while running
    ``last_layer.scan_prompt`` on every prompt beforehand and
    ``last_layer.scan_llm`` on every generated text afterwards.  Each scan
    result is passed to a configurable handler (warning log by default).
    """

    # The wrapped model that actually produces text.
    llm: BaseLLM
    # Handlers are called as handler(text, risk).
    # BUGFIX: Callable takes a *list* of argument types; the previous
    # ``Callable[str, ...]`` form raises TypeError at class-creation time.
    handle_prompt_risk: Callable[[str, last_layer.RiskModel], None] = default_handler
    handle_response_risk: Callable[[str, last_layer.RiskModel], None] = default_handler
    # Threat categories to skip during scanning.  Pydantic copies mutable
    # defaults per instance, so the shared-list pitfall does not apply here.
    # TODO(review): currently unused — presumably should be forwarded to
    # scan_prompt/scan_llm; confirm against the last_layer API.
    ignore_opts: List[last_layer.Threat] = []

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses to label this LLM type."""
        return "LastLayerSecurity"

    def _generate(
        self,
        prompts: List[str],
        **kwargs: Any,
    ) -> Any:
        """Scan *prompts*, generate via the wrapped LLM, then scan the output.

        Args:
            prompts: Prompt strings to send to the wrapped model.
            **kwargs: Passed through unchanged to the wrapped LLM.

        Returns:
            The wrapped LLM's generation result, unmodified.
        """
        for prompt in prompts:
            risk = last_layer.scan_prompt(prompt)
            self.handle_prompt_risk(prompt, risk)
        result = self.llm._generate(prompts, **kwargs)

        # result.generations is a list (one entry per prompt) of lists of
        # individual generations; scan every generated text.
        for prompt_generations in result.generations:
            for generation in prompt_generations:
                risk = last_layer.scan_llm(generation.text)
                self.handle_response_risk(generation.text, risk)
        return result
19 changes: 19 additions & 0 deletions last_layer/langchain_module/test_lc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import unittest
from .llm import LastLayerSecurity


class TestScanPrompt(unittest.TestCase):
    """Integration test for LastLayerSecurity wrapping a fake LLM."""

    @unittest.skip("Not implemented")
    def test_integration(self):
        """Wrap a canned-response FakeLLM and invoke it once end to end."""
        # Imported lazily so the optional langchain_contrib test dependency
        # is only needed when this (currently skipped) test actually runs.
        from langchain_contrib.llms.testing import FakeLLM

        secure_llm = LastLayerSecurity(
            llm=FakeLLM(verbose=True, sequenced_responses=["One", "Two", "Three"])
        )
        response = secure_llm.invoke(
            "Summarize this message: my name is Bob Dylan. My SSN is 123-45-6789."
        )
        print(f"{response=}")
Loading

0 comments on commit 753c9e1

Please sign in to comment.