In [3]:
from guardrails import Guard
import google.generativeai as genai
import json
from typing import Optional

## Loading the API key for Gemini

In [4]:
def load_api_key(config_path: str = 'config.json', key: str = 'apiKey') -> str:
    """Read the Gemini API key from a local JSON config file.

    Reading the key from a file keeps the secret itself out of the notebook's
    code cells.

    Args:
        config_path: Path to the JSON config file. Defaults to ``config.json``
            in the current working directory (the original hard-coded path).
        key: Name of the top-level field that holds the API key. Defaults to
            ``apiKey`` (the original hard-coded field name).

    Returns:
        The value stored under ``key`` in the parsed config (expected to be the
        API key string).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the parsed config has no ``key`` field.
    """
    # Explicit encoding so the result does not depend on the platform default.
    with open(config_path, encoding='utf-8') as config_file:
        config = json.load(config_file)
    return config[key]

In [5]:
# Configure the google.generativeai client with the API key read by
# load_api_key() (from config.json), so the key is not written in the notebook.
genai.configure(api_key=load_api_key())

# Configuring the Gemini Model

In [6]:
# Generation settings handed to the Gemini model below.
# NOTE(review): temperature=1 makes responses vary from run to run; consider a
# lower value if reproducible structured output matters.
gen_config = dict(
    temperature=1,
    top_p=0.95,
    top_k=64,
    max_output_tokens=8192,
    response_mime_type="text/plain",
)

In [7]:
# Instantiate the Gemini model that answers our prompts, using gen_config.
model = genai.GenerativeModel("gemini-1.5-flash-latest", generation_config=gen_config)

In [8]:
# Open a chat that starts with an empty history. my_llm_api sends every
# request through this one session object.
chat_session = model.start_chat(history=[])

# Defining a custom LLM

In [9]:
# This is the function template provided by Guardrails AI to define a custom LLM.
# It forwards the request to the Gemini chat session and returns the text reply.

def my_llm_api(
    prompt: Optional[str] = None,
    instruction: Optional[str] = None,
    msg_history: Optional[list[dict]] = None,
    **kwargs
) -> str:
    """Custom LLM API wrapper that sends Guardrails' request to Gemini.

    At least one of prompt, instruction or msg_history should be provided.
    All non-empty inputs are combined into a single message, in this order:
    the instruction, the ``content`` of each msg_history entry, then the prompt.
    Parts are separated by blank lines. Previously only ``prompt`` was sent, so
    an instruction or history was silently dropped. When only ``prompt`` is
    given, it is sent unchanged.

    Args:
        prompt (str): The prompt to be passed to the LLM API
        instruction (str): The instruction to be passed to the LLM API
        msg_history (list[dict]): The message history to be passed to the LLM
            API. Each entry's ``"content"`` value is included. This assumes
            the ``{"role": ..., "content": ...}`` message format; confirm it
            against the Guardrails version in use.
        **kwargs: Additional arguments. ``chat_session`` may be passed to use a
            specific chat session instead of the notebook-level
            ``chat_session``. All other keyword arguments are ignored.

    Returns:
        str: The ``.text`` of the LLM's reply.

    Raises:
        ValueError: If prompt, instruction and msg_history are all empty.
    """
    session = kwargs.get("chat_session")
    if session is None:
        # Fall back to the shared notebook session. NOTE(review): this session
        # presumably accumulates history across calls (including re-asks and
        # re-runs of the guard cell); confirm that this is desired.
        session = chat_session

    parts = [instruction]
    parts.extend(entry.get("content") for entry in (msg_history or []))
    parts.append(prompt)
    # Drop missing/empty parts so a prompt-only call is sent byte-for-byte.
    message_text = "\n\n".join(part for part in parts if part)
    if not message_text:
        raise ValueError(
            "At least one of prompt, instruction or msg_history must be provided."
        )

    return session.send_message(message_text).text

# Using the RAIL spec file to generate a validated response

In [25]:
# Load the RAIL spec file and build a guard from its configuration. The guard
# is called without a prompt argument, so the prompt presumably comes from the
# spec itself (TODO confirm in guardrail.xml).
guard = Guard.from_rail('./guardrail.xml')
# Run our custom LLM wrapper through the guard. Only the second element of the
# result (the validated output) is kept; the first element goes to `_` and
# any remaining elements go to `rest`, both unused.
_, validated_output, *rest = guard(
    my_llm_api,
)

In [27]:
# Show the validated output in a readable, indented form. If validation failed,
# say so explicitly instead of printing a bare `None`.
if validated_output is None:
    print("Guardrails validation failed; no validated output was produced.")
else:
    print(json.dumps(validated_output, indent=2))

{'fees': [{'index': 1, 'name': 'monthly fee', 'explanation': 'A fee charged for maintaining an account.', 'value': 0.01}, {'index': 2, 'name': 'overdraft fee', 'explanation': 'A fee charged for withdrawing more money than is available.', 'value': 0.02}, {'index': 3, 'name': 'atm fee', 'explanation': 'A fee charged for using an ATM outside of the network.', 'value': 0.03}], 'interest_rates': 'Interest rates vary depending on the type of account or loan.'}
