# Deploy [Llama Code Shield](https://www.llama.com/llama-protections/) on Databricks
For use with [AI Gateway](https://www.databricks.com/product/artificial-intelligence/ai-gateway) as a custom guardrail

## About Llama Code Shield

Code Shield provides inference-time filtering of insecure code produced by LLMs. It mitigates the risk of insecure code suggestions and supports secure command execution for 7 programming languages, with an average latency of 200 ms.

In [0]:
!pip install mlflow==3.8.1
!pip install codeshield==1.0.1
dbutils.library.restartPython()

In [0]:
from databricks.sdk import WorkspaceClient

# Workspace client; reused further down in the notebook for the serving-endpoint calls.
ws = WorkspaceClient()

# Catalog picker. Only the first 1000 names are offered as choices
# (presumably to stay within a widget choice limit -- TODO confirm).
catalogs = [catalog.full_name for catalog in ws.catalogs.list()]
dbutils.widgets.dropdown(
    "catalog",
    defaultValue=catalogs[0],
    choices=catalogs[:1000],
    label="Catalog",
)

# Schema picker, populated from whichever catalog the widget currently holds.
selected_catalog = dbutils.widgets.get("catalog")
schemas = [schema.name for schema in ws.schemas.list(catalog_name=selected_catalog)]
dbutils.widgets.dropdown(
    "schema",
    defaultValue=schemas[0],
    choices=schemas[:1000],
    label="Schema",
)

# Free-text widgets for the serving endpoint name and the Unity Catalog model name.
dbutils.widgets.text(
    name="model_serving_endpoint",
    defaultValue="llama-code-shield",
    label="Model serving endpoint to deploy model to",
)
dbutils.widgets.text(
    name="model_name",
    defaultValue="llama_code_shield",
    label="Model name to register to UC",
)

# Read once here; later cells use it for the model file name and the endpoint name.
model_serving_endpoint = dbutils.widgets.get("model_serving_endpoint")

## Step 1: Create a class for our custom guardrail

In [0]:
%%writefile "{model_serving_endpoint}.py"

"""
To define a custom guardrail pyfunc, the following must be implemented:
1. def _translate_output_guardrail_request(self, model_input) -> Translates the model input between an OpenAI Chat Completions (ChatV1, https://platform.openai.com/docs/api-reference/chat/create) response and our custom guardrails format.
2. def _invoke_guardrail(self, code_content) -> Invokes our custom moderation logic.
3. def _translate_guardrail_response(self, scan_result, original_content) -> Translates our custom guardrails response to the OpenAI Chat Completions (ChatV1) format.
4. def predict(self, context, model_input, params) -> Applies the guardrail to the model input/output and returns the guardrail response.
"""

from typing import Any, Dict, List, Union
import json
import copy
import mlflow
from mlflow.models import set_model
import os
import pandas as pd
import asyncio
from codeshield.cs import CodeShield

class CodeShieldGuardrail(mlflow.pyfunc.PythonModel):
    """Custom output guardrail that scans LLM-generated text with Code Shield.

    ``predict`` accepts a Chat Completions style response dict (or a DataFrame
    whose first row is one), concatenates the ``content`` of every choice,
    scans it with ``CodeShield.scan_code`` and returns a dict whose
    ``decision`` is ``"proceed"``, ``"sanitize"`` or ``"reject"``.
    """

    async def _invoke_guardrail_async(self, code_content: str):
        """Await ``CodeShield.scan_code`` on ``code_content`` and return its result."""
        return await CodeShield.scan_code(code_content)

    def _invoke_guardrail(self, code_content: str):
        """Run the Code Shield scan synchronously and return the scan result.

        Args:
            code_content: Text to scan.

        Raises:
            Exception: Any failure, re-raised with the message prefix
                "Code Shield scan failed: " and the original error chained.
        """
        try:
            try:
                running_loop = asyncio.get_running_loop()
            except RuntimeError:
                running_loop = None

            if running_loop is None:
                # No loop is running in this thread: asyncio.run creates a
                # fresh loop and closes it afterwards, so repeated calls do
                # not accumulate unclosed event loops.
                return asyncio.run(self._invoke_guardrail_async(code_content))

            # A loop is already running in this thread, so it must be patched
            # to allow a nested run_until_complete.
            import nest_asyncio
            nest_asyncio.apply()
            return running_loop.run_until_complete(
                self._invoke_guardrail_async(code_content)
            )
        except Exception as e:
            # Only the stringified error goes into the message.
            raise Exception(f"Code Shield scan failed: {e}") from e

    def _translate_output_guardrail_request(self, request: dict):
        """Validate an output-phase request and extract the text to scan.

        Args:
            request: Dict with a ``"mode"`` dict (``"phase"``, ``"stream_mode"``)
                and a ``"choices"`` list of ``{"message": {"content": ...}}``.

        Returns:
            The ``content`` strings of all choices concatenated in order.

        Raises:
            Exception: If the mode is missing, is not the output phase, or is
                streaming; or if ``choices`` / ``message`` / ``content`` is
                missing; or if a ``content`` value is not a string.
        """
        mode = request.get("mode")
        if (
            not isinstance(mode, dict)
            or mode.get("phase") != "output"
            or mode.get("stream_mode") is None
            or mode.get("stream_mode") == "streaming"
        ):
            raise Exception(f"Invalid mode: {request}.")
        if "choices" not in request:
            raise Exception(f"Missing key \"choices\" in request: {request}.")

        contents = []
        for choice in request["choices"]:
            if "message" not in choice:
                raise Exception(f"Missing key \"message\" in \"choices\": {request}.")
            if "content" not in choice["message"]:
                raise Exception(f"Missing key \"content\" in \"choices[\"message\"]\": {request}.")

            content = choice["message"]["content"]
            if not isinstance(content, str):
                # Fail with an explicit message instead of an opaque TypeError
                # from string concatenation.
                raise Exception(f"Key \"content\" in \"choices[\"message\"]\" must be a string: {request}.")
            contents.append(content)

        return "".join(contents)

    def _translate_guardrail_response(self, scan_result, original_content):
        """Map a Code Shield scan result to a guardrail decision dict.

        Args:
            scan_result: Object exposing ``is_insecure`` and
                ``recommended_treatment`` (an enum with ``.value`` or anything
                convertible with ``str``).
            original_content: The text that was scanned.

        Returns:
            ``{"decision": "reject", ...}`` when the code is insecure and the
            treatment is ``"block"``; ``{"decision": "sanitize", ...}`` when it
            is insecure and the treatment is ``"warn"``; otherwise
            ``{"decision": "proceed", ...}``.
        """
        # Reduce the treatment to a plain string so the response stays JSON-serializable.
        treatment = scan_result.recommended_treatment
        treatment_str = treatment.value if hasattr(treatment, "value") else str(treatment)

        scan_result_dict = {
            "is_insecure": bool(scan_result.is_insecure),
            "recommended_treatment": treatment_str,
        }

        if scan_result.is_insecure and treatment_str == "block":
            # Block the response entirely.
            return {
                "decision": "reject",
                "reject_reason": "🚫🚫🚫 The generated code has been flagged as insecure by AI guardrails. 🚫🚫🚫",
                "guardrail_response": {"include_in_response": True, "response": scan_result_dict},
            }

        if scan_result.is_insecure and treatment_str == "warn":
            # The warning is returned in guardrail_response; the message
            # content itself is passed through unchanged, collapsed into a
            # single assistant choice.
            warning_message = "⚠️⚠️⚠️ WARNING: The generated code has been flagged as having potential security issues by AI guardrails. Please review carefully before use. ⚠️⚠️⚠️"
            return {
                "decision": "sanitize",
                "guardrail_response": {
                    "include_in_response": True,
                    "response": warning_message,
                    "scan_result": scan_result_dict,
                },
                "sanitized_input": {
                    "choices": [{
                        "message": {
                            "role": "assistant",
                            "content": original_content,
                        }
                    }]
                },
            }

        # Either nothing was flagged, or the recommended treatment is neither
        # "block" nor "warn": let the response through with the scan summary.
        return {
            "decision": "proceed",
            "guardrail_response": {"include_in_response": True, "response": scan_result_dict},
        }

    def predict(self, context, model_input, params=None):
        """Apply the Code Shield guardrail to a model output.

        Args:
            context: Unused.
            model_input: Request dict, or a DataFrame whose first row is the
                request.
            params: Unused.

        Returns:
            The guardrail decision dict. Unparseable input and any error raised
            while translating or scanning produce a ``"reject"`` decision with
            a ``reject_reason``; no exception is raised for those cases.
        """
        # The input is converted to a pandas DataFrame when the model is served.
        if isinstance(model_input, pd.DataFrame):
            records = model_input.to_dict("records")
            if not records:
                # An empty frame has no request row to inspect.
                return {"decision": "reject", "reject_reason": f"Couldn't parse model input: {model_input}"}
            model_input = records[0]
        elif not isinstance(model_input, dict):
            return {"decision": "reject", "reject_reason": f"Couldn't parse model input: {model_input}"}

        try:
            code_content = self._translate_output_guardrail_request(model_input)
            scan_result = self._invoke_guardrail(code_content)
            return self._translate_guardrail_response(scan_result, code_content)
        except Exception as e:
            # Only the stringified error goes into the response so that it stays serializable.
            return {"decision": "reject", "reject_reason": f"Errored with the following error message: {e}"}
      
# Register an instance as the model for this file; the notebook later logs this
# file by path (python_model=<this file>) rather than passing a Python object.
set_model(CodeShieldGuardrail())

## Step 2: Log the model to MLflow and register it to UC

In [0]:
import mlflow
import logging

# Only surface MLflow errors; lower-severity log output is suppressed.
logging.getLogger("mlflow").setLevel(logging.ERROR)

# Example guardrail request: one assistant choice containing code, in the
# non-streaming output phase. Reused by the test cell below.
model_input_example = {
  "choices": [
      {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": "def hash_password(password):\n    return hashlib.md5(password.encode()).hexdigest()",
            "refusal": None,
            "annotations": [],
        },
        "logprobs": None,
        "finish_reason": "stop",
      }
    ],
  "mode": {
    "phase": "output",
    "stream_mode": "non_streaming"
  }
}

# The guardrail class was written to this file by the %%writefile cell above.
pyfunc_model_path = f"{model_serving_endpoint}.py"
# Three-level Unity Catalog name: <catalog>.<schema>.<model_name>.
registered_model_path = f"{dbutils.widgets.get('catalog')}.{dbutils.widgets.get('schema')}.{dbutils.widgets.get('model_name')}"

with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        # MLflow 3 takes the logged model's name via `name`; `artifact_path` is deprecated.
        name=model_serving_endpoint,
        python_model=pyfunc_model_path,
        metadata={
            "task": "llm/v1/chat",
        },
        input_example=model_input_example,
        registered_model_name=registered_model_path,
        pip_requirements=[
            "mlflow==3.8.1",
            "codeshield==1.0.1",
            "nest-asyncio==1.6.0",
        ],
    )

print(f"✅ Model logged to: {model_info.model_uri}")
print(f"✅ Model registered as: {registered_model_path}")

## Step 4: Test the model 
Test the model with various inputs to ensure that it's working correctly

In [0]:
# Load the model that was just logged and exercise it locally.
loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

# Case 1: the example logged with the model is expected to come back as "sanitize".
print("Testing sanitize content...")
response = loaded_model.predict(model_input_example)
print(f"Sanitize content response: {response}")
assert "decision" in response
assert response["decision"] == "sanitize", "Sanitize content should be subject to a warning"

# Case 2: C code using strcpy/gets is expected to be rejected.
print("Testing insecure content...")

insecure_content = {
  "choices": [
      {
        "index": 0,
        "message": {
            "role": "assistant",
            # "\\n" keeps the escape sequence inside the C string literal; a bare
            # "\n" would put a real line break in the middle of the printf string.
             "content": """void vulnerable_function(char *user_input) {
              char buffer[50];
              strcpy(buffer, user_input);  // Not checking about length
              printf("User input: %s\\n", buffer);
            }
            
            int main() {
              char input[100];
              printf("Enter input: ");
              gets(input); // Unsafe function
              vulnerable_function(input);
              return 0;
              }""",
            "refusal": None,
            "annotations": [],
        },
        "logprobs": None,
        "finish_reason": "stop",
      }
    ],
  "mode": {
    "phase": "output",
    "stream_mode": "non_streaming"
  }
}

response = loaded_model.predict(insecure_content)

print(f"Insecure content response: {response}")
assert "decision" in response
assert response["decision"] == "reject", "Insecure content should be rejected"

print("\n✅ All validation tests completed!")

## Step 5: Deploy to Model Serving

Now deploy the model to a serving endpoint for use with AI Gateway.

In [0]:
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedEntityInput
from datetime import timedelta

# Create or update the serving endpoint.
# Both paths deploy the same single served entity: the model version that was
# registered when the model was logged.
served_entities = [
    ServedEntityInput(
        entity_name=registered_model_path,
        entity_version=model_info.registered_model_version,
        workload_size="Small",
        scale_to_zero_enabled=True,
    )
]
endpoint_timeout = timedelta(minutes=60)

try:
    ws.serving_endpoints.create_and_wait(
        name=model_serving_endpoint,
        config=EndpointCoreConfigInput(served_entities=served_entities),
        timeout=endpoint_timeout,
    )
    print(f"✅ Serving endpoint '{model_serving_endpoint}' created successfully!")

except Exception as e:
    # An existing endpoint is detected from the error text; any other failure
    # is re-raised unchanged with its original traceback.
    if "already exists" not in str(e):
        raise

    print(f"Endpoint '{model_serving_endpoint}' already exists. Updating...")
    ws.serving_endpoints.update_config_and_wait(
        name=model_serving_endpoint,
        served_entities=served_entities,
        timeout=endpoint_timeout,
    )
    print(f"✅ Serving endpoint '{model_serving_endpoint}' updated successfully!")

## Step 6: Use the model for inference

In [0]:
# Send the insecure C example defined in the test cell to the deployed endpoint.
# (The previous name, `insecure_code_example`, is not defined anywhere in this notebook.)
response = ws.serving_endpoints.query(
    name=model_serving_endpoint,
    inputs=insecure_content
)
print(f"✅ Model serving endpoint query successfully: \n{response.predictions}")

## Step 7: Add the model serving endpoint as a custom guardrail

![Custom Guardrail.png](./Custom Guardrail.png "Custom Guardrail.png")

## Step 8: Use your foundation model for inference

![Inference Guardrail.png](./Inference Guardrail.png "Inference Guardrail.png")