# Create Bedrock Agents for Splunk AI Assistant with Function Definition

In this notebook we will create an agent for Splunk use cases in Amazon Bedrock using the function definition capability.

We will create an agent and functions in Amazon Bedrock for several tasks: finding the Splunk sourcetypes relevant to a natural-language question, getting the schema for a given sourcetype, getting lookups and lookup values for a given sourcetype, and executing an SPL query.


## Prerequisites
Before starting, let's update the botocore, boto3, and awscli packages to ensure we have the latest versions.

In [None]:
!python3 -m pip install --upgrade -q botocore
!python3 -m pip install --upgrade -q boto3
!python3 -m pip install --upgrade -q awscli

Let's now check the boto3 version to ensure the correct version has been installed. Your version should be greater than or equal to 1.34.90.

In [None]:
import boto3
import json
import time
import zipfile
from io import BytesIO
import uuid
import pprint
import logging
print(boto3.__version__)
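
The version check above can also be done programmatically. The following is a minimal sketch, assuming plain dotted version strings; the helper names `parse_version` and `version_at_least` are hypothetical and not part of boto3.

```python
import itertools

def parse_version(version: str) -> tuple:
    """Turn '1.34.90' into (1, 34, 90), keeping only the leading digits of each part."""
    parts = []
    for piece in version.split("."):
        digits = "".join(itertools.takewhile(str.isdigit, piece))
        parts.append(int(digits) if digits else 0)
    return tuple(parts)

def version_at_least(installed: str, minimum: str = "1.34.90") -> bool:
    """Return True when the installed version is greater than or equal to the minimum."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    # Pad the shorter tuple with zeros so that '1.35' compares like '1.35.0'
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b

# Example (in the notebook): version_at_least(boto3.__version__)
```

Comparing integer tuples avoids the pitfall of comparing version strings lexicographically, where "1.9.0" would sort after "1.34.90".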

Let's now create the boto3 clients for the required AWS services

In [2]:
# getting boto3 clients for required AWS services
sts_client = boto3.client('sts')
iam_client = boto3.client('iam')
lambda_client = boto3.client('lambda')
bedrock_agent_client = boto3.client('bedrock-agent')
bedrock_agent_runtime_client = boto3.client('bedrock-agent-runtime')

Next, we set some configuration variables for the agent and for the Lambda function being created.

In [None]:
session = boto3.session.Session()
region = session.region_name
account_id = sts_client.get_caller_identity()["Account"]
region, account_id

In [4]:
# configuration variables
suffix = f"{region}-{account_id}"
agent_name = "splunk-chatbot-agent"
agent_bedrock_allow_policy_name = f"{agent_name}-ba-{suffix}"
agent_role_name = f'AmazonBedrockExecutionRoleForAgents_{agent_name}'
agent_foundation_model = "anthropic.claude-3-sonnet-20240229-v1:0"
agent_description = "Security agent for querying Splunk and executing SPL queries"
agent_instruction = "You are a Splunk chatbot agent, helping security analysts to write efficient Splunk \
SPL queries and execute them to analyze the results. You need to understand the schema of Splunk sourcetypes \
to write efficient SPL queries."
agent_action_group_name = "SplunkQueryAgentGroup"
agent_action_group_description = "Actions for writing efficient SPL queries for AWS data sources and interpret the results"
agent_alias_name = f"{agent_name}-alias"
lambda_function_role = f'{agent_name}-lambda-role-{suffix}'
lambda_function_name = f'{agent_name}-{suffix}'

Replace the placeholder below with the ARN of the secret you created as part of the prerequisites.

In [None]:
secret_arn="arn:aws:secretsmanager:AWS-REGION:AWS-ACCOUNTID:secret:SECRET-NAME" 
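
The Lambda function defined later in this notebook reads the keys `SplunkToken` and `SplunkHost` from the secret's JSON string. As a rough sketch, the secret payload can be checked for those keys before deploying; the helper name `missing_secret_keys` and the placeholder values are hypothetical.

```python
import json

# Keys that the Lambda code in this notebook reads from the secret
REQUIRED_SECRET_KEYS = ("SplunkToken", "SplunkHost")

def missing_secret_keys(secret_string: str) -> list:
    """Return the required keys that are absent or empty in the secret JSON."""
    secret = json.loads(secret_string)
    return [key for key in REQUIRED_SECRET_KEYS if not secret.get(key)]

# Placeholder values only; never hard-code a real token in a notebook
example_secret = json.dumps({"SplunkToken": "<token>", "SplunkHost": "splunk.example.com"})
print(missing_secret_keys(example_secret))
```

In the notebook, the string to check would be the `SecretString` returned by the Secrets Manager `get_secret_value` call for `secret_arn`.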

## Prepare RAG VectorDB for AWS source types
In this step we will create embeddings for AWS source types and store them in an OpenSearch Serverless vector collection.

In [69]:
# Initialize the clients
opensearchserverless_client = boto3.client('opensearchserverless')
sts_client = boto3.client('sts')
connected_role = sts_client.get_caller_identity()['Arn'].split('/')[-2]

Let's create functions and policies for OpenSearch Serverless, then create the collection that stores the embeddings.

In [None]:
def create_vector_collection(collection_name):
    try:
        response = opensearchserverless_client.create_collection(
            name=collection_name,
            description='Vector search collection',
            type='VECTORSEARCH'
        )
        print(f"Vector collection creation initiated: {response['createCollectionDetail']['name']}")
        return response['createCollectionDetail']['id']
    except opensearchserverless_client.exceptions.ConflictException:
        print(f"Collection {collection_name} already exists.")
        collections = opensearchserverless_client.list_collections()['collectionSummaries']
        return next((c['id'] for c in collections if c['name'] == collection_name), None)

def wait_for_collection_creation(collection_id):
    while True:
        response = opensearchserverless_client.batch_get_collection(ids=[collection_id])
        status = response['collectionDetails'][0]['status']
        if status == 'ACTIVE':
            print("Collection is now active.")
            return response['collectionDetails'][0]['collectionEndpoint']
        elif status in ['FAILED', 'DELETED']:
            raise Exception(f"Collection creation failed with status: {status}")
        print("Waiting for collection to become active...")
        time.sleep(60)

def create_access_policy(collection_name):
    policy_name = f"{collection_name}-access-policy"
    account_id = sts_client.get_caller_identity()['Account']
    policy_document = [{
        "Description": f"Access policy for {collection_name}",
        "Rules": [
            {
                "ResourceType": "index",
                "Resource": [f"index/{collection_name}/*"],
                "Permission": [
                    "aoss:CreateIndex",
                    "aoss:DeleteIndex",
                    "aoss:UpdateIndex",
                    "aoss:DescribeIndex",
                    "aoss:ReadDocument",
                    "aoss:WriteDocument"
                ]
            },
            {
                "ResourceType": "collection",
                "Resource": [f"collection/{collection_name}"],
                "Permission": [
                    "aoss:CreateCollectionItems",
                    "aoss:DeleteCollectionItems",
                    "aoss:UpdateCollectionItems"
                ]
            }
        ],
        "Principal": [
            f"arn:aws:iam::{account_id}:role/{connected_role}",
            f"arn:aws:iam::{account_id}:role/{lambda_function_role}"
        ]
    }
    ]

    try:
        response = opensearchserverless_client.create_access_policy(
            name=policy_name,
            type='data',
            policy=json.dumps(policy_document)
        )
        print(f"Access policy created: {response['accessPolicyDetail']['name']}")
    except opensearchserverless_client.exceptions.ConflictException:
        print(f"Access policy {policy_name} already exists.")

def create_encryption_policy(collection_name):
    policy_name = f"{collection_name}-encryption-policy"
    policy_document = {
        "Rules": [
            {
                "ResourceType": "collection",
                "Resource": [f"collection/{collection_name}"]
            }
        ],
        "AWSOwnedKey": True
    }

    try:
        response = opensearchserverless_client.create_security_policy(
            name=policy_name,
            policy=json.dumps(policy_document),
            type='encryption'
        )
        print(f"Encryption policy created: {response['securityPolicyDetail']['name']}")
    except opensearchserverless_client.exceptions.ConflictException:
        print(f"Encryption policy {policy_name} already exists.")

def create_network_policy(collection_name):
    policy_name = f"{collection_name}-network-policy"
    policy_document = [{
        "Rules": [
            {
                "ResourceType": "collection",
                "Resource": [f"collection/{collection_name}"]
            },
            {
                "ResourceType": "dashboard",
                "Resource": [f"collection/{collection_name}"]
            }
        ],
        "AllowFromPublic": True
    }]

    try:
        response = opensearchserverless_client.create_security_policy(
            name=policy_name,
            policy=json.dumps(policy_document),
            type='network'
        )
        print(f"Network policy created: {response['securityPolicyDetail']['name']}")
    except opensearchserverless_client.exceptions.ConflictException:
        print(f"Network policy {policy_name} already exists.")

# Usage
collection_name = 'splunk-vector'

# Create the encryption policy
create_encryption_policy(collection_name)

# Create the network policy
create_network_policy(collection_name)

# Create the vector collection
collection_id = create_vector_collection(collection_name)

# Wait for the collection to become active
endpoint = wait_for_collection_creation(collection_id)

# Create the access policy
create_access_policy(collection_name)

print(f"Vector collection endpoint: {endpoint}")
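
The `wait_for_collection_creation` loop above polls until the collection is active and has no upper time limit. A minimal sketch of the same polling pattern with a timeout is shown here; `wait_for_status` and its parameters are hypothetical names, and the status source is injected so the logic can be tried without AWS access.

```python
import time

def wait_for_status(get_status, target="ACTIVE", failed=("FAILED", "DELETED"),
                    timeout_seconds=1800, poll_seconds=60, sleep=time.sleep):
    """Poll get_status() until it returns target, a failed status, or the timeout expires."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = get_status()
        if status == target:
            return status
        if status in failed:
            raise RuntimeError(f"Resource entered terminal status: {status}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Still {status} after {timeout_seconds} seconds")
        sleep(poll_seconds)

# Demo with a fake status sequence and a no-op sleep, so it finishes immediately
fake_statuses = iter(["CREATING", "CREATING", "ACTIVE"])
final_status = wait_for_status(lambda: next(fake_statuses), sleep=lambda seconds: None)
print(final_status)
```

With the real client, `get_status` could be a small lambda that calls `batch_get_collection` for `collection_id` and returns the `status` field, as the loop above does.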

### Create embeddings and load them into OpenSearch Serverless

In [None]:
!pip install opensearch-py
!pip install requests_aws4auth

Initialize client for bedrock runtime and OpenSearch

In [71]:
import pandas as pd
import boto3
import json
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth

# Initialize clients
bedrock_client = boto3.client('bedrock-runtime', region_name=region)
credentials = boto3.Session().get_credentials()
awsauth = AWS4Auth(credentials.access_key, credentials.secret_key,
                   region, "aoss", session_token=credentials.token)
aoss_host = endpoint.replace("https://", "")
# Initialize OpenSearch client
opensearch_client = OpenSearch(
    hosts=[{'host': aoss_host, 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    timeout=300, 
    max_retries=10, 
    retry_on_timeout=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection
)

Create the index in OpenSearch and load the sourcetype data.

In [112]:
import csv  # required by process_csv_file below

index_name = 'splunk-sourcetypes'
def create_index_if_not_exists():
    if not opensearch_client.indices.exists(index=index_name):
        index_body = {
          'settings': {
              'index': {
                  'knn': True,
              }
          },
          'mappings': {
              'properties': {
                  'my_vector': {
                      'type': 'knn_vector',
                      'dimension': 1024,  # Dimension of the Titan embedding 
                      "method": {
                        "name": "hnsw",
                        "space_type": "l2",  # Or "cosinesimil" for cosine similarity
                        "engine": "faiss",
                        "parameters": {
                            "ef_construction": 128,
                            "m": 24
                        }              
                  }
                },

                  'content': {'type': 'text'}

          }
        }
        }
        opensearch_client.indices.create(index=index_name, body=index_body)
        print(f"Created index: {index_name}")

def generate_embedding(text):
    body = json.dumps({"inputText": text})
    response = bedrock_client.invoke_model(
        body=body,
        modelId='amazon.titan-embed-text-v2:0',
        accept='application/json',
        contentType='application/json'
    )
    response_body = json.loads(response['body'].read())
    return response_body['embedding']

def process_csv_file(file_path):
    create_index_if_not_exists()
    with open(file_path, 'r', encoding='utf-8-sig') as data:
      for line in csv.DictReader(data):
        embedding = generate_embedding(json.dumps(line).replace('\\u00a0', ' '))
        # Prepare document
        doc = {
            'content': json.dumps(line).replace('\\u00a0', ' '),
            'my_vector': embedding
        }
        # Index the document
        print(doc['content'])
        response = opensearch_client.index(index=index_name, body=doc)
        print(response)

In [None]:
# Usage
csv_file_path = 'data/aws-source-types.csv'
process_csv_file(csv_file_path)

Create a function to do an approximate search for a query and verify results

In [None]:
def retrieve_similar_documents(query):
    query_embedding = generate_embedding(query)
    response = opensearch_client.search(
        index=index_name,
        body={
            "query": {
                "knn": {
                    "my_vector": {
                        "vector": query_embedding,
                        "k": 1  # Return only the single most relevant document
                    }
                }
            }
        }
    )
    # return response
    if response['hits']['hits']:
        return [item['_source']['content'] for item in response['hits']['hits']]
    else:
        return "No sourcetypes found"

# Example user query
user_query = "VPC FlowLogs"
relevant_documents = retrieve_similar_documents(user_query)
relevant_documents
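
The `knn` query above returns the `k` stored vectors closest to the query embedding under the index's `space_type` (`l2` here). As an illustration only, the exact brute-force equivalent on toy data looks like the sketch below; the three-dimensional vectors and document labels are made up, and the real index uses approximate HNSW search over 1024-dimensional Titan embeddings.

```python
import math

def l2_distance(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def nearest_documents(query_vector, documents, k=1):
    """Return the contents of the k documents whose vectors are closest to the query.

    documents is a list of (content, vector) pairs.
    """
    ranked = sorted(documents, key=lambda doc: l2_distance(query_vector, doc[1]))
    return [content for content, _ in ranked[:k]]

# Toy data: labels and vectors are illustrative, not real embeddings
toy_documents = [
    ("doc about CloudTrail", [1.0, 0.0, 0.0]),
    ("doc about VPC Flow Logs", [0.0, 1.0, 0.0]),
    ("doc about S3 access logs", [0.0, 0.0, 1.0]),
]
closest = nearest_documents([0.1, 0.9, 0.0], toy_documents, k=1)
print(closest)
```

Raising `k` returns more candidates, which can help when a question maps to several source types.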

## Creating Lambda function
Now that we have created a vector DB holding the embeddings for AWS source types, let's build a Lambda function that queries these embeddings and implements the other functions for the Splunk use cases.

Install the libraries for the Lambda function.

In [None]:
!pip install splunk_sdk --target ./lambda_package/
!pip install opensearch-py --target ./lambda_package/
!pip install requests_aws4auth --target ./lambda_package/

In [None]:
%%writefile ./lambda_package/lambda_function.py
import os
import re
import requests
import json
import sys
from time import sleep
import splunklib.client as splunk_client
import splunklib.results as results
import boto3
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth

session = boto3.session.Session()
region = session.region_name
index_name = os.environ['aoss_index']
# Initialize clients
bedrock_client = boto3.client('bedrock-runtime', region_name=region)
credentials = boto3.Session().get_credentials()
awsauth = AWS4Auth(credentials.access_key, credentials.secret_key,
                   region, "aoss", session_token=credentials.token)
endpoint = os.environ['aoss_endpoint']
aoss_host = endpoint.replace("https://", "")
# Initialize OpenSearch client
opensearch_client = OpenSearch(
    hosts=[{'host': aoss_host, 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    timeout=300, 
    max_retries=10, 
    retry_on_timeout=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection
)

def generate_embedding(text):    
    body = json.dumps({"inputText": text})
    response = bedrock_client.invoke_model(
        body=body,
        modelId='amazon.titan-embed-text-v2:0',
        accept='application/json',
        contentType='application/json'
    )
    response_body = json.loads(response['body'].read())
    return response_body['embedding']

def search_aws_sourcetypes(awssourcetype: str) ->str:
    query_embedding = generate_embedding(awssourcetype)
    response = opensearch_client.search(
        index=index_name,
        body={
            "query": {
                "knn": {
                    "my_vector": {
                        "vector": query_embedding,
                        "k": 1  # Return only the single most relevant document
                    }
                }
            }
        }
    )
    # return response
    if response['hits']['hits']:
        return [item['_source']['content'] for item in response['hits']['hits']]
    else:
        return "No sourcetypes found"

def get_splunk_fields(sourcetype: str) ->str:
    """
    Gets Splunk sourcetype as input and returns the list of fields in the sourcetype. Give the input sourcetype as only input string parameter. 
    This tool is useful to know which fields are stored in sourcetype to be used in SPL Queries. 
    """
    # Create a Secrets Manager client
    secrets_client = boto3.client('secretsmanager')
    secret_arn = os.environ['secret_arn']
    response = secrets_client.get_secret_value(SecretId=secret_arn)
    secret = json.loads(response['SecretString'])
    splunkToken = secret['SplunkToken']
    HOST=secret['SplunkHost']
    PORT = 8089
    # Connect to Splunk with the token retrieved from Secrets Manager
    service = splunk_client.connect(
        host=HOST,
        port=PORT,
        splunkToken=splunkToken)
        # username=USERNAME,
        # password=PASSWORD)

    # Create search Query
    search_query = "search index=main "+ "sourcetype="+sourcetype+" | fieldsummary | fields field"
    # Create the search job

    kwargs_normalsearch = {"exec_mode": "normal", "earliest_time": "-15m"}
    job = service.jobs.create(search_query, **kwargs_normalsearch)
    # Wait for the search to complete
    while True:
        # Wait until the job is ready, sleeping between checks instead of busy-waiting
        while not job.is_ready():
            sleep(1)
        stats = {"isDone": job["isDone"],
                 "doneProgress": float(job["doneProgress"])*100,
                  "scanCount": int(job["scanCount"]),
                  "eventCount": int(job["eventCount"]),
                  "resultCount": int(job["resultCount"])}
    
        status = ("\r%(doneProgress)03.1f%%   %(scanCount)d scanned   "
                  "%(eventCount)d matched   %(resultCount)d results") % stats
    
        sys.stdout.write(status)
        sys.stdout.flush()
        if stats["isDone"] == "1":
            sys.stdout.write("\n\nDone!\n\n")
            break
        sleep(2)
    fields = []        
    print(status)
    for result in results.JSONResultsReader(job.results(output_mode='json')):
        fields.append(result['field'])
    # return results.JSONResultsReader(job.results(output_mode='json'))
    # fields = job.results()
    # print(fields)
    # fields = [field["field"] for field in fields]
    return fields

def get_splunk_results(search_query:str) ->str:
    """
    Executes a Splunk search query and returns the results as JSON data. Give the input search query as a string variable.
    Don't put search values in quotes unless the value contains a space.
    Here is an example of search query:
    search index=main sourcetype=aws:cloudtrail errorCode!=success | stats count by eventSource, eventName, errorCode | sort - count
    """
    # Create a Secrets Manager client and retrieve Splunk Secrets
    secrets_client = boto3.client('secretsmanager')
    secret_arn = os.environ['secret_arn']
    response = secrets_client.get_secret_value(SecretId=secret_arn)
    secret = json.loads(response['SecretString'])
    splunkToken = secret['SplunkToken']
    HOST=secret['SplunkHost']
    PORT = 8089


    # Create a Service instance and log in
    service = splunk_client.connect(
        host=HOST,
        port=PORT,
        splunkToken=splunkToken)
        # username=USERNAME,
        # password=PASSWORD)

    kwargs_normalsearch = {"exec_mode": "normal", "earliest_time": "-24h"}
    job = service.jobs.create(search_query, **kwargs_normalsearch)

    # Wait for the search to complete
    while True:
        # Wait until the job is ready, sleeping between checks instead of busy-waiting
        while not job.is_ready():
            sleep(1)
        stats = {"isDone": job["isDone"],
                 "doneProgress": float(job["doneProgress"])*100,
                  "scanCount": int(job["scanCount"]),
                  "eventCount": int(job["eventCount"]),
                  "resultCount": int(job["resultCount"])}

        status = ("\r%(doneProgress)03.1f%%   %(scanCount)d scanned   "
                  "%(eventCount)d matched   %(resultCount)d results") % stats

        # sys.stdout.write(status)
        sys.stdout.flush()
        if stats["isDone"] == "1":
            # sys.stdout.write("\n\nDone!\n\n")
            break
        sleep(2)
    # Read all results before cancelling the job. Keep only result rows (dicts),
    # because the reader can also yield diagnostic Message objects that json.dumps cannot serialize.
    results_list = [row for row in results.JSONResultsReader(job.results(output_mode="json"))
                    if isinstance(row, dict)]
    job.cancel()
    return results_list



def get_splunk_lookups(sourcetype: str) ->str:
    """
    Gets a Splunk sourcetype as input and returns the list of lookup names for the sourcetype. Useful when the initial SPL query does not return 
    any results due to one or more lookup fields. This function gets all lookup names for a given sourcetype, which the agent can use to get the right 
    lookup values for the SPL query by calling the function get_splunk_lookup_values. 
    """
    # Create a Secrets Manager client
    secrets_client = boto3.client('secretsmanager')
    secret_arn = os.environ['secret_arn']
    response = secrets_client.get_secret_value(SecretId=secret_arn)
    secret = json.loads(response['SecretString'])
    splunkToken = secret['SplunkToken']
    HOST=secret['SplunkHost']
    PORT = 8089
    # Connect to Splunk with the token retrieved from Secrets Manager
    service = splunk_client.connect(
        host=HOST,
        port=PORT,
        splunkToken=splunkToken)


    # Create the query to get all lookup names for the sourcetype.
    search_query = "| rest /servicesNS/-/-/data/props/lookups | search stanza="+sourcetype+" \
    | dedup transform | fields transform"
    # Create the search job

    kwargs_normalsearch = {"exec_mode": "normal", "earliest_time": "-15m"}
    job = service.jobs.create(search_query, **kwargs_normalsearch)
    # Wait for the search to complete
    while True:
        # Wait until the job is ready, sleeping between checks instead of busy-waiting
        while not job.is_ready():
            sleep(1)
        stats = {"isDone": job["isDone"],
                 "doneProgress": float(job["doneProgress"])*100,
                  "scanCount": int(job["scanCount"]),
                  "eventCount": int(job["eventCount"]),
                  "resultCount": int(job["resultCount"])}
    
        status = ("\r%(doneProgress)03.1f%%   %(scanCount)d scanned   "
                  "%(eventCount)d matched   %(resultCount)d results") % stats
    
        sys.stdout.write(status)
        sys.stdout.flush()
        if stats["isDone"] == "1":
            sys.stdout.write("\n\nDone!\n\n")
            break
        sleep(2)
    fields = []        
    print(status)
    print(job.results(output_mode='json'))
    for result in results.JSONResultsReader(job.results(output_mode='json')):
        fields.append(result['transform'])
    if not fields:
        fields.append("No lookups found for sourcetype "+sourcetype)
        return fields
    return fields        

def get_splunk_lookup_values(lookup_name: str) ->str:
    """
    Gets a Splunk lookup name as input and returns the lookup values. Useful when the initial SPL query does not return 
    any results due to one or more lookup fields. This function gets all lookup values for a given lookup name, which can be useful to rewrite the SPL queries
    with appropriate lookup values. 
    """
    # Create a Secrets Manager client
    secrets_client = boto3.client('secretsmanager')
    secret_arn = os.environ['secret_arn']
    response = secrets_client.get_secret_value(SecretId=secret_arn)
    secret = json.loads(response['SecretString'])
    splunkToken = secret['SplunkToken']
    HOST=secret['SplunkHost']
    PORT = 8089
    # Connect to Splunk with the token retrieved from Secrets Manager
    service = splunk_client.connect(
        host=HOST,
        port=PORT,
        splunkToken=splunkToken)


    # Create the query to read all values of the given lookup.
    search_query = "| inputlookup "+ lookup_name
    # Create the search job

    kwargs_normalsearch = {"exec_mode": "normal", "earliest_time": "-15m"}
    job = service.jobs.create(search_query, **kwargs_normalsearch)
    # Wait for the search to complete
    while True:
        # Wait until the job is ready, sleeping between checks instead of busy-waiting
        while not job.is_ready():
            sleep(1)
        stats = {"isDone": job["isDone"],
                 "doneProgress": float(job["doneProgress"])*100,
                  "scanCount": int(job["scanCount"]),
                  "eventCount": int(job["eventCount"]),
                  "resultCount": int(job["resultCount"])}
    
        status = ("\r%(doneProgress)03.1f%%   %(scanCount)d scanned   "
                  "%(eventCount)d matched   %(resultCount)d results") % stats
    
        sys.stdout.write(status)
        sys.stdout.flush()
        if stats["isDone"] == "1":
            sys.stdout.write("\n\nDone!\n\n")
            break
        sleep(2)
    #print(status)
    print(job.results(output_mode="json"))
    # Get the results and return them as a list
    fields = []
    for result in results.JSONResultsReader(job.results(output_mode='json')):
        print(result)
        if re.search('ERROR', str(result)):
            fields.append("No lookup values found for lookup name "+lookup_name)
            return fields
        elif re.search('INFO', str(result)):
            continue
        fields.append(result)
    if not fields:
        fields.append("No lookup values found for lookup name "+lookup_name)
        return fields
    return fields  


def lambda_handler(event, context):
    ip = requests.get('https://checkip.amazonaws.com').text.strip()
    print("Extern IP:",ip)
    print("*********Printing event data *************")
    print(event)
    agent = event['agent']
    actionGroup = event['actionGroup']
    function = event['function']
    parameters = event.get('parameters', [])
    responseBody =  {
        "TEXT": {
            "body": "Error, no function was called"
        }
    }
    
    if function == 'search_aws_sourcetypes':
        aws_sourcetype = None
        for param in parameters:
            if param["name"] == "awssourcetype":
                aws_sourcetype = param["value"]
        if not aws_sourcetype:
            raise Exception("Missing mandatory parameter: awssourcetype")
        aws_sourcetype = search_aws_sourcetypes(aws_sourcetype)
        print(type(json.dumps(aws_sourcetype)))
        responseBody =  {
            'TEXT': {
                "body": json.dumps(aws_sourcetype)
            }
        }
    elif function == 'get_splunk_fields':
        sourcetype = None
        for param in parameters:
            if param["name"] == "sourcetype":
                sourcetype = param["value"]
        if not sourcetype:
            raise Exception("Missing mandatory parameter: sourcetype")      
        sourcetype_fields = get_splunk_fields(sourcetype)
        responseBody =  {
            'TEXT': {
                "body": json.dumps(sourcetype_fields)
            }
        }        
    elif function == 'get_splunk_results':
        search_query = None
        for param in parameters:
            if param["name"] == "search_query":
                search_query = param["value"]
        if not search_query:
            raise Exception("Missing mandatory parameter: search_query")      
        query_results = get_splunk_results(search_query)
        responseBody =  {
            'TEXT': {
                "body": json.dumps(query_results)
            }
        }         

    elif function == 'get_splunk_lookups':
        sourcetype = None
        for param in parameters:
            if param["name"] == "sourcetype":
                sourcetype = param["value"]
        if not sourcetype:
            raise Exception("Missing mandatory parameter: sourcetype")      
        sourcetype_fields = get_splunk_lookups(sourcetype)
        responseBody =  {
            'TEXT': {
                "body": json.dumps(sourcetype_fields)
            }
        }  

    elif function == 'get_splunk_lookup_values':
        lookup_name = None
        for param in parameters:
            if param["name"] == "lookup_name":
                lookup_name = param["value"]
        if not lookup_name:
            raise Exception("Missing mandatory parameter: lookup_name")      
        lookup_values = get_splunk_lookup_values(lookup_name)
        print("lookup values type",type(lookup_values))
        responseBody =  {
            'TEXT': {
                "body": json.dumps(lookup_values)
            }
        } 


    action_response = {
        'actionGroup': actionGroup,
        'function': function,
        'functionResponse': {
            'responseBody': responseBody
        }

    }

    function_response = {'response': action_response, 'messageVersion': event['messageVersion']}
    print("Response: {}".format(function_response))

    return function_response                                              
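
The handler above repeats the same parameter lookup and response construction for every function. The sketch below factors those two steps into helpers and exercises them with a sample event; the helper names and the sample values are hypothetical, and the event mirrors only the keys that the handler above reads.

```python
def get_parameter(event, name):
    """Return the value of a named parameter from the agent event, or raise if it is missing."""
    for param in event.get("parameters", []):
        if param["name"] == name:
            return param["value"]
    raise ValueError(f"Missing mandatory parameter: {name}")

def build_function_response(event, body_text):
    """Wrap a text body in the same response structure that the handler above returns."""
    return {
        "messageVersion": event["messageVersion"],
        "response": {
            "actionGroup": event["actionGroup"],
            "function": event["function"],
            "functionResponse": {"responseBody": {"TEXT": {"body": body_text}}},
        },
    }

# Hypothetical sample event for local testing
sample_event = {
    "messageVersion": "1.0",
    "agent": {"name": "splunk-chatbot-agent"},
    "actionGroup": "SplunkQueryAgentGroup",
    "function": "get_splunk_fields",
    "parameters": [{"name": "sourcetype", "type": "string", "value": "aws:cloudtrail"}],
}

sourcetype = get_parameter(sample_event, "sourcetype")
sample_response = build_function_response(sample_event, '["eventName", "eventSource"]')
print(sample_response)
```

Testing the handler locally with an event like this can catch parameter-name mismatches before the function is deployed.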

In [None]:
# Create IAM Role for the Lambda function
try:
    assume_role_policy_document = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {
                    "Service": "lambda.amazonaws.com"
                },
                "Action": "sts:AssumeRole"
            }
        ]
    }

    assume_role_policy_document_json = json.dumps(assume_role_policy_document)

    lambda_iam_role = iam_client.create_role(
        RoleName=lambda_function_role,
        AssumeRolePolicyDocument=assume_role_policy_document_json
    )

    # Pause to make sure role is created
    time.sleep(10)
except iam_client.exceptions.EntityAlreadyExistsException:
    # The role already exists, so reuse it
    lambda_iam_role = iam_client.get_role(RoleName=lambda_function_role)

iam_client.attach_role_policy(
    RoleName=lambda_function_role,
    PolicyArn='arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole'
)

### Zip the libraries and lambda function to a zip file

In [76]:
import zipfile
import os

In [152]:
def zip_directory_contents(source_dir, output_filename):
    # Ensure the source directory exists
    if not os.path.exists(source_dir):
        print(f"Error: The directory {source_dir} does not exist.")
        return

    # Create a ZipFile object
    with zipfile.ZipFile(output_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
        # Walk through the directory
        for root, _, files in os.walk(source_dir):
            for file in files:
                # Get the full path of the file
                file_path = os.path.join(root, file)
                
                # Calculate the arc name (file name within the zip)
                arc_name = os.path.relpath(file_path, source_dir)
                #print(arc_name)
                # Add the file to the zip
                zipf.write(file_path, arc_name)

    print(f"Successfully created {output_filename} with the contents of {source_dir}")

# Usage
source_directory = './lambda_package'
output_zip = 'lambda_archive.zip'

zip_directory_contents(source_directory, output_zip)
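
Because the Lambda handler is configured as `lambda_function.lambda_handler`, the file `lambda_function.py` has to sit at the root of the archive rather than inside a subfolder. A small sketch of that check follows; the helper name `archive_has_root_file` is hypothetical, and the demo builds a throwaway archive in a temporary directory so that it does not depend on the files created above.

```python
import os
import tempfile
import zipfile

def archive_has_root_file(zip_path, filename="lambda_function.py"):
    """Return True when filename is stored at the top level of the zip archive."""
    with zipfile.ZipFile(zip_path) as archive:
        return filename in archive.namelist()

# Demo: build a tiny archive and verify its layout
with tempfile.TemporaryDirectory() as workdir:
    source_file = os.path.join(workdir, "lambda_function.py")
    with open(source_file, "w") as handle:
        handle.write("def lambda_handler(event, context):\n    return {}\n")
    demo_zip = os.path.join(workdir, "demo.zip")
    with zipfile.ZipFile(demo_zip, "w", zipfile.ZIP_DEFLATED) as archive:
        # The second argument is the name inside the archive (no directory prefix)
        archive.write(source_file, "lambda_function.py")
    demo_ok = archive_has_root_file(demo_zip)

print(demo_ok)
```

In the notebook, the same helper could be called on `output_zip` after `zip_directory_contents` has run.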

In [None]:
# Create an S3 bucket to upload the zip file
from botocore.exceptions import ClientError  # used by the except clauses below

s3_client = boto3.client('s3')

# Bucket name
bucket_name = 'splunk-bedrock-code-'+account_id

# Check if the bucket exists
try:
    response = s3_client.head_bucket(Bucket=bucket_name)
except ClientError as e:
    error_code = e.response['Error']['Code']
    if error_code == '404':
        # Bucket does not exist, create it
        try:
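            # Outside us-east-1, S3 requires an explicit LocationConstraint
            if region == 'us-east-1':
                s3_client.create_bucket(Bucket=bucket_name)
            else:
                s3_client.create_bucket(
                    Bucket=bucket_name,
                    CreateBucketConfiguration={'LocationConstraint': region}
                )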
            print(f"Bucket '{bucket_name}' created successfully.")
        except ClientError as e:
            print(f"Error creating bucket: {e}")
            raise
    else:
        # Other error, re-raise the exception
        raise

# Upload an object to the bucket
#object_key = 'lambda_package'
zip_file_path = 'lambda_archive'+'.zip'
# Upload the file to S3
try:
    s3_client.upload_file(zip_file_path, bucket_name, zip_file_path)
    print(f"File '{zip_file_path}' uploaded to bucket '{bucket_name}' with key '{zip_file_path}'.")
except ClientError as e:
    print(f"Error uploading file: {e}")
    raise

In [96]:
# Create Lambda Function
lambda_function = lambda_client.create_function(
    FunctionName=lambda_function_name,
    Runtime='python3.12',
    Timeout=300,
    Role=lambda_iam_role['Role']['Arn'],
    Code={'S3Bucket': bucket_name,
        'S3Key': zip_file_path},
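    Environment={
        'Variables': {
            'secret_arn': secret_arn,
            # The Lambda code also reads these two variables at import time
            'aoss_index': index_name,
            'aoss_endpoint': endpoint
        }
    },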
    Handler='lambda_function.lambda_handler'
)


In [None]:
# # Optional step to update lambda function with new code, if you make any code changes. Uncomment this block and run.
# try:
#     response = lambda_client.update_function_code(
#         FunctionName=lambda_function_name,
#         S3Bucket= bucket_name,
#         S3Key= zip_file_path,
#         Publish=True  # Create a new version
#     )
#     print(f"Successfully updated Lambda function: {lambda_function_name}")
#     print(f"New version: {response['Version']}")
# except lambda_client.exceptions.ResourceNotFoundException:
#     print(f"Error: Lambda function '{lambda_function_name}' not found.")
# except lambda_client.exceptions.ResourceConflictException:
#     print(f"Error: The function '{lambda_function_name}' is currently updating or in an inconsistent state.")
# except lambda_client.exceptions.InvalidParameterValueException as e:
#     print(f"Error: Invalid parameter value. {str(e)}")
# except lambda_client.exceptions.CodeStorageExceededException:
#     print("Error: You have exceeded your maximum total code size per account.")
# except Exception as e:
#     print(f"An unexpected error occurred: {str(e)}")

In [None]:
# Extend Lambda role with IAM policy to retrieve secret manager secret 
smgr_policy_document = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "secretsmanager:GetSecretValue"
            ],
            "Resource": secret_arn
        }
    ]
}

# Add the inline policy to the existing role
try:
    iam_client.put_role_policy(
        RoleName=lambda_function_role,
        PolicyName='SecretsManagerAccessInlinePolicy',
        PolicyDocument=json.dumps(smgr_policy_document)
    )
    print(f"Successfully added inline policy 'SecretsManagerAccessInlinePolicy' to role {lambda_function_role}")
except iam_client.exceptions.LimitExceededException:
    print("Error: Limit exceeded. The role might have too many inline policies.")
except iam_client.exceptions.NoSuchEntityException:
    print(f"Error: The role '{lambda_function_role}' does not exist.")
except iam_client.exceptions.UnmodifiableEntityException:
    print("Error: Cannot modify the role. It might be a service-linked role.")
except Exception as e:
    print(f"Error adding inline policy to role: {str(e)}")


In [None]:
# Extend Lambda role with IAM policy to invoke Titan Model
titan_policy_document = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "bedrock:InvokeModel"
            ],
            "Resource": "arn:aws:bedrock:"+region+"::foundation-model/amazon.titan-embed-text-v2:0"
        }
    ]
}

# Add the inline policy to the existing role
try:
    iam_client.put_role_policy(
        RoleName=lambda_function_role,
        PolicyName='BedRockTitanInlinePolicy',
        PolicyDocument=json.dumps(titan_policy_document)
    )
    print(f"Successfully added inline policy 'BedRockTitanInlinePolicy' to role {lambda_function_role}")
except iam_client.exceptions.LimitExceededException:
    print("Error: Limit exceeded. The role might have too many inline policies.")
except iam_client.exceptions.NoSuchEntityException:
    print(f"Error: The role '{lambda_function_role}' does not exist.")
except iam_client.exceptions.UnmodifiableEntityException:
    print("Error: Cannot modify the role. It might be a service-linked role.")
except Exception as e:
    print(f"Error adding inline policy to role: {str(e)}")

In [None]:
collection_name = 'splunk-vector'
# Extend Lambda role with IAM policy for OpenSearch Access
aoss_policy_document = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "aoss:*"
            ],
            "Resource": [
                "arn:aws:aoss:"+region+":"+account_id+":dashboards/default",
                "arn:aws:aoss:"+region+":"+account_id+":collection/*"
            ]
        }
    ]
}

# Add the inline policy to the existing role
try:
    iam_client.put_role_policy(
        RoleName=lambda_function_role,
        PolicyName='OpenSearchPolicy',
        PolicyDocument=json.dumps(aoss_policy_document)
    )
    print(f"Successfully added inline policy 'OpenSearchPolicy' to role {lambda_function_role}")
except iam_client.exceptions.LimitExceededException:
    print("Error: Limit exceeded. The role might have too many inline policies.")
except iam_client.exceptions.NoSuchEntityException:
    print(f"Error: The role '{lambda_function_role}' does not exist.")
except iam_client.exceptions.UnmodifiableEntityException:
    print("Error: Cannot modify the role. It might be a service-linked role.")
except Exception as e:
    print(f"Error adding inline policy to role: {str(e)}")
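The OpenSearch policy above grants `aoss:*` on every collection in the account. If you prefer a narrower grant, the sketch below builds a policy limited to data-plane access on a single collection. It assumes `aoss:APIAccessAll` is sufficient for what the Lambda function does, which we have not verified here; `build_scoped_aoss_policy` and the `collection_id` argument are hypothetical names.

```python
def build_inline_policy(actions, resources):
    """Build a single-statement IAM policy document (hypothetical helper)."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": list(actions),
                "Resource": list(resources),
            }
        ],
    }


def build_scoped_aoss_policy(region, account_id, collection_id):
    """Limit access to the data-plane API of one OpenSearch Serverless collection."""
    collection_arn = f"arn:aws:aoss:{region}:{account_id}:collection/{collection_id}"
    # Assumption: the Lambda function only needs data-plane access
    return build_inline_policy(["aoss:APIAccessAll"], [collection_arn])
```

The resulting dictionary can be passed through `json.dumps` and used as the `PolicyDocument` in `put_role_policy`, exactly like the documents above.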

## Create Agent
We will now create the agent. To do so, we first need to create the agent policy that allows Bedrock model invocation for a specific foundation model, and the agent IAM role with that policy attached.

In [103]:

# Create IAM policies for agent
bedrock_agent_bedrock_allow_policy_statement = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "AmazonBedrockAgentBedrockFoundationModelPolicy",
            "Effect": "Allow",
            "Action": "bedrock:InvokeModel",
            "Resource": [
                f"arn:aws:bedrock:{region}::foundation-model/{agent_foundation_model}"
            ]
        }
    ]
}

bedrock_policy_json = json.dumps(bedrock_agent_bedrock_allow_policy_statement)

agent_bedrock_policy = iam_client.create_policy(
    PolicyName=agent_bedrock_allow_policy_name,
    PolicyDocument=bedrock_policy_json
)


In [None]:
# Create IAM Role for the agent and attach IAM policies
assume_role_policy_document = {
    "Version": "2012-10-17",
    "Statement": [{
          "Effect": "Allow",
          "Principal": {
            "Service": "bedrock.amazonaws.com"
          },
          "Action": "sts:AssumeRole"
    }]
}

assume_role_policy_document_json = json.dumps(assume_role_policy_document)
agent_role = iam_client.create_role(
    RoleName=agent_role_name,
    AssumeRolePolicyDocument=assume_role_policy_document_json
)

# Pause to make sure role is created
time.sleep(10)
    
iam_client.attach_role_policy(
    RoleName=agent_role_name,
    PolicyArn=agent_bedrock_policy['Policy']['Arn']
)
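The fixed `time.sleep(10)` above works in most cases, but IAM also exposes a `role_exists` waiter that polls until the role is visible. The sketch below wraps it in a hypothetical helper, `wait_for_role`. Note that the waiter only confirms that `GetRole` succeeds; propagation to other services can still take a few more seconds, so a short pause may still be useful.

```python
def wait_for_role(iam_client, role_name, delay=2, max_attempts=15):
    """Block until IAM reports that the role exists (hypothetical helper)."""
    waiter = iam_client.get_waiter("role_exists")
    waiter.wait(
        RoleName=role_name,
        WaiterConfig={"Delay": delay, "MaxAttempts": max_attempts},
    )
```

In this notebook it would be called as `wait_for_role(iam_client, agent_role_name)` right after `create_role`.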

With the IAM role and policy in place, we can now create the agent itself.

In [None]:
response = bedrock_agent_client.create_agent(
    agentName=agent_name,
    agentResourceRoleArn=agent_role['Role']['Arn'],
    description=agent_description,
    idleSessionTTLInSeconds=1800,
    foundationModel=agent_foundation_model,
    instruction=agent_instruction,
)
response

In [None]:
agent_id = response['agent']['agentId']
agent_id

## Create Agent Action Group
We will now create an agent action group that uses the lambda function created earlier. The [`create_agent_action_group`](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-agent/client/create_agent_action_group.html) function provides this functionality. We will use `DRAFT` as the agent version since we haven't yet created an agent version or alias. To inform the agent about the action group capabilities, we provide an action group description.

To define the functions using a function schema, you need to provide the `name`, `description` and `parameters` for each function.

In [214]:
agent_functions = [
    {
        'name': 'search_aws_sourcetypes',
        'description': 'Searches a Vector database and returns right sourcetype for AWS source data',
        'parameters': {
            "awssourcetype": {
                "description": "the source type to be searched for a given AWS data",
                "required": True,
                "type": "string"
            }
        }
    },
    {
        'name': 'get_splunk_fields',
        'description': 'Provides the list of fields found in a given Splunk sourcetype. \
        Useful to understand the schema of Splunk source types',
        'parameters': {
            "sourcetype": {
                "description": "sourcetype for the schema to be returned",
                "required": True,
                "type": "string"
            }
        }
    },
    {
        'name': 'get_splunk_results',
        'description': 'Executes a Splunk search query and returns the results as JSON data. \
        Give the input search query as a string variable. Do not put search values in quotes unless the values contain a space. \
        Always provide the Splunk sourcetype to get the result. If an index is not given, use main as the index; otherwise use the given index name. \
        Here is an example SPL query that queries cloudtrail data and groups the results by account id, event source, event name and error code: \
        search index=main sourcetype=aws:cloudtrail | stats count by recipientAccountId, eventSource, eventName, errorCode',
        'parameters': {
            "search_query": {
                "description": "splunk search query",
                "required": True,
                "type": "string"
            }
        }
    },
    {
        'name': 'get_splunk_lookups',
        'description': 'Gets Splunk sourcetype as input and returns the list of lookup values for the sourcetype. \
        Useful to identify any lookups associated with sourcetype and is required for the SPL query to execute. \
        This function gets all lookup names for a given sourcetype which the agent can use to get the right lookup values for \
        SPL query by calling the function get_splunk_lookup_values.',
        'parameters': {
            "sourcetype": {
                "description": "source type to get the associated lookups",
                "required": True,
                "type": "string"
            }
        }
    },
    {
        'name': 'get_splunk_lookup_values',
        'description': 'Gets Splunk lookup name as input and returns the lookup values. This function gets all lookup values \
        for a given lookup name, which can be useful to rewrite the SPL queries with appropriate lookup values.',
        'parameters': {
            "lookup_name": {
                "description": "lookup_name to get all the lookup values",
                "required": True,
                "type": "string"
            }
        }
    }    
]
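Before sending a function list like the one above to Bedrock, it can help to run a quick local sanity check. The sketch below verifies that every function has a unique `name`, a `description`, and parameters with a supported `type` and a boolean `required` flag. `validate_agent_functions` is a hypothetical helper, and the check is deliberately shallow; the service performs its own validation.

```python
# Parameter types accepted by the function schema
ALLOWED_PARAMETER_TYPES = {"string", "number", "integer", "boolean", "array"}


def validate_agent_functions(functions):
    """Return a list of problems found; an empty list means none were detected."""
    problems = []
    seen_names = set()
    for index, function in enumerate(functions):
        name = function.get("name")
        label = name or f"function #{index}"
        if not name:
            problems.append(f"{label}: missing 'name'")
        elif name in seen_names:
            problems.append(f"{label}: duplicate function name")
        else:
            seen_names.add(name)
        if not function.get("description"):
            problems.append(f"{label}: missing 'description'")
        for param_name, spec in function.get("parameters", {}).items():
            if spec.get("type") not in ALLOWED_PARAMETER_TYPES:
                problems.append(f"{label}.{param_name}: unsupported type {spec.get('type')!r}")
            if not isinstance(spec.get("required", False), bool):
                problems.append(f"{label}.{param_name}: 'required' must be a boolean")
    return problems
```

Calling `validate_agent_functions(agent_functions)` should return an empty list for the definitions in this notebook.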

In [108]:
# Pause to make sure agent is created
time.sleep(30)
# Now, we can configure and create an action group here:
agent_action_group_response = bedrock_agent_client.create_agent_action_group(
    agentId=agent_id,
    agentVersion='DRAFT',
    actionGroupExecutor={
        'lambda': lambda_function['FunctionArn']
    },
    actionGroupName=agent_action_group_name,
    functionSchema={
        'functions': agent_functions
    },
    description=agent_action_group_description
)
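For action groups defined with a function schema, Bedrock passes the function name and a list of parameters to the Lambda function and expects a `functionResponse` in return. The sketch below shows that request and response shape with a single stubbed function. It is an illustration only: the handler packaged in `lambda_archive.zip` may be organized differently, and `get_splunk_fields_stub` and `HANDLERS` are hypothetical names.

```python
import json


def get_splunk_fields_stub(sourcetype):
    # Hypothetical placeholder; the real function would query Splunk
    return {"sourcetype": sourcetype, "fields": []}


# Map function names from the schema to Python callables
HANDLERS = {
    "get_splunk_fields": lambda params: get_splunk_fields_stub(params["sourcetype"]),
}


def lambda_handler(event, context):
    function_name = event["function"]
    # Parameters arrive as a list of {"name", "type", "value"} dictionaries
    params = {p["name"]: p["value"] for p in event.get("parameters", [])}

    handler = HANDLERS.get(function_name)
    if handler is None:
        body = {"error": f"Unknown function: {function_name}"}
    else:
        body = handler(params)

    return {
        "messageVersion": "1.0",
        "response": {
            "actionGroup": event["actionGroup"],
            "function": function_name,
            "functionResponse": {
                "responseBody": {"TEXT": {"body": json.dumps(body)}}
            },
        },
    }
```

Dispatching through a dictionary keeps the handler short as more functions are added to the action group.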

## Allowing Agent to invoke Action Group Lambda
Before using the action group, we need to allow the agent to invoke the Lambda function associated with it. This is done via a resource-based policy. Let's add the resource-based policy to the Lambda function we created.

In [109]:
# Create allow invoke permission on lambda
response = lambda_client.add_permission(
    FunctionName=lambda_function_name,
    StatementId='allow_bedrock',
    Action='lambda:InvokeFunction',
    Principal='bedrock.amazonaws.com',
    SourceArn=f"arn:aws:bedrock:{region}:{account_id}:agent/{agent_id}",
)


## Preparing Agent

Let's create a DRAFT version of the agent that can be used for internal testing.


In [None]:
response = bedrock_agent_client.prepare_agent(
    agentId=agent_id
)
print(response)

In [42]:
# Pause to make sure agent is prepared
time.sleep(30)
# Use the Test Alias for the Agent
agent_alias_id = "TSTALIASID"
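The fixed 30-second pauses above are simple, but the agent status can also be polled with `get_agent`, which returns it under `agent.agentStatus`. The sketch below waits until the agent reaches one of the wanted statuses (for example `PREPARED` after `prepare_agent`, or `NOT_PREPARED` after `create_agent`). `wait_for_agent_status` is a hypothetical helper and the timeout values are arbitrary.

```python
import time


def wait_for_agent_status(client, agent_id, wanted, poll_seconds=5,
                          timeout_seconds=300, sleep=time.sleep):
    """Poll get_agent until the agent status is one of `wanted` (hypothetical helper)."""
    waited = 0
    while True:
        status = client.get_agent(agentId=agent_id)["agent"]["agentStatus"]
        if status in wanted:
            return status
        if status == "FAILED":
            raise RuntimeError(f"Agent {agent_id} entered FAILED status")
        if waited >= timeout_seconds:
            raise TimeoutError(f"Agent {agent_id} still in status {status}")
        sleep(poll_seconds)
        waited += poll_seconds
```

For example, `wait_for_agent_status(bedrock_agent_client, agent_id, {"PREPARED"})` could replace the pause after `prepare_agent`.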

## Invoke Agent

Now that we've created the agent, let's use the `bedrock-agent-runtime` client to invoke this agent and perform some tasks.

In [43]:
# setting logger
logging.basicConfig(format='[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
## create a random id for session initiator id
session_id:str = str(uuid.uuid1())
enable_trace:bool = True
end_session:bool = False
query_text = "Can you write and execute SPL to query AWS cloudtrail data. \
I need to see which AWS account produces the highest count of non-success error codes, and by which AWS service and event. \
Give me a table of final results and provide your summary."
# invoke the agent API
agentResponse = bedrock_agent_runtime_client.invoke_agent(
    inputText=query_text,
    agentId=agent_id,
    agentAliasId=agent_alias_id, 
    sessionId=session_id,
    enableTrace=enable_trace, 
    endSession= end_session
)


# pprint.pprint() prints and returns None; pformat() returns the formatted string for the logger
logger.info(pprint.pformat(agentResponse))

Review how Bedrock handles the prompt and the individual function calls the agent makes.

In [None]:
%%time
event_stream = agentResponse['completion']
agent_answer = ""
for event in event_stream:
    if 'chunk' in event:
        # The final answer can arrive in several chunks, so accumulate them
        text = event['chunk']['bytes'].decode('utf8')
        agent_answer += text
        logger.info(f"Final answer ->\n{text}")
    elif 'trace' in event:
        # default=str keeps json.dumps from failing on non-JSON values such as timestamps
        logger.info(json.dumps(event['trace'], indent=2, default=str))
    else:
        raise Exception("unexpected event.", event)
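The trace output is verbose, so it can be handy to pull out just the function calls the agent decided to make. The sketch below assumes each trace part carries the call under `trace.orchestrationTrace.invocationInput.actionGroupInvocationInput`, which matches our understanding of the trace format but may vary between API versions. `extract_function_calls` is a hypothetical helper that expects a list of the `event['trace']` dictionaries collected from the stream.

```python
def extract_function_calls(trace_parts):
    """Return (function_name, parameters) tuples found in a list of trace parts."""
    calls = []
    for part in trace_parts:
        orchestration = part.get("trace", {}).get("orchestrationTrace", {})
        invocation = orchestration.get("invocationInput", {}).get("actionGroupInvocationInput")
        # Only action group invocations that name a function are of interest
        if invocation and "function" in invocation:
            params = {p["name"]: p["value"] for p in invocation.get("parameters", [])}
            calls.append((invocation["function"], params))
    return calls
```

To use it, append each `event['trace']` to a list inside the loop and pass that list to the helper once the stream is exhausted.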

## Clean up (optional)

The next steps are optional and demonstrate how to delete our agent. To delete the agent we need to:

1. Delete OpenSearch index and collection
2. Update the action group to disable it
3. Delete agent action group
4. Delete agent
5. Delete lambda function
6. Delete the created IAM roles and policies


In [None]:
# Delete index
def delete_index(index_name):
    try:
        # Check if the index exists
        if opensearch_client.indices.exists(index=index_name):
            # Delete the index
            response = opensearch_client.indices.delete(index=index_name)
            if response.get('acknowledged', False):
                print(f"Successfully deleted index: {index_name}")
            else:
                print(f"Failed to delete index: {index_name}")
        else:
            print(f"Index {index_name} does not exist.")
    except Exception as e:
        print(f"An error occurred while deleting the index: {str(e)}")

# Usage
delete_index(index_name)

In [None]:
client = boto3.client('opensearchserverless')

def delete_security_policy(policy_name, policy_type):
    try:
        client.delete_security_policy(
            name=policy_name,
            type=policy_type
        )
        print(f"Deleted {policy_type} security policy: {policy_name}")
    except client.exceptions.ResourceNotFoundException:
        print(f"{policy_type.capitalize()} security policy {policy_name} not found.")
    except Exception as e:
        print(f"Error deleting {policy_type} security policy {policy_name}: {str(e)}")

def delete_access_policy(policy_name):
    try:
        client.delete_access_policy(
            name=policy_name,
            type='data'
        )
        print(f"Deleted access policy: {policy_name}")
    except client.exceptions.ResourceNotFoundException:
        print(f"Access policy {policy_name} not found.")
    except Exception as e:
        print(f"Error deleting access policy {policy_name}: {str(e)}")

def delete_collection(collection_name,collection_id):
    try:
        response = client.delete_collection(
            clientToken=collection_name,
            id=collection_id
        )
        print(f"Deletion initiated for collection: {collection_name}")
        return response['deleteCollectionDetail']['id']
    except client.exceptions.ResourceNotFoundException:
        print(f"Collection {collection_name} not found.")
        return None
    except Exception as e:
        print(f"Error deleting collection {collection_name}: {str(e)}")
        return None

def wait_for_collection_deletion(collection_id):
    if not collection_id:
        return

    print("Waiting for collection deletion to complete...")
    waiter = client.get_waiter('collection_deleted')
    try:
        waiter.wait(
            ids=[collection_id],
            WaiterConfig={
                'Delay': 30,
                'MaxAttempts': 40
            }
        )
        print("Collection deletion completed.")
    except Exception as e:
        print(f"Error waiting for collection deletion: {str(e)}")

# Usage
policy_prefix = f"{collection_name}"  # Assuming policies are named with collection name as prefix

# Delete security policies
delete_security_policy(f"{policy_prefix}-network-policy", 'network')
delete_security_policy(f"{policy_prefix}-encryption-policy", 'encryption')

# Delete access policy
delete_access_policy(f"{policy_prefix}-access-policy")

# Delete collection
collection_id = endpoint.split('/')[2].split('.')[0]
delete_collection_id = delete_collection(collection_name, collection_id)

# Wait for collection deletion to complete
wait_for_collection_deletion(collection_id)


In [None]:
wait_for_collection_deletion(collection_id)
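If `get_waiter('collection_deleted')` is not available in your boto3 version, the same effect can be achieved by polling `batch_get_collection` until the collection no longer appears in `collectionDetails`. The sketch below is one way to do that; `wait_until_collection_gone` is a hypothetical helper, and the delay and attempt count simply mirror the values used above.

```python
import time


def wait_until_collection_gone(aoss_client, collection_id, poll_seconds=30,
                               max_attempts=40, sleep=time.sleep):
    """Poll until the collection is no longer returned; True if it disappeared in time."""
    for _ in range(max_attempts):
        response = aoss_client.batch_get_collection(ids=[collection_id])
        # A deleted collection no longer appears in collectionDetails
        if not response.get("collectionDetails"):
            return True
        sleep(poll_seconds)
    return False
```

It would be called as `wait_until_collection_gone(client, collection_id)` with the `opensearchserverless` client created earlier.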


In [None]:
# This step is optional: the agent can be deleted without deleting its action group first.
# If you do delete the action group, you need to disable it first.
action_group_id = agent_action_group_response['agentActionGroup']['actionGroupId']
action_group_name = agent_action_group_response['agentActionGroup']['actionGroupName']

response = bedrock_agent_client.update_agent_action_group(
    agentId=agent_id,
    agentVersion='DRAFT',
    actionGroupId= action_group_id,
    actionGroupName=action_group_name,
    actionGroupExecutor={
        'lambda': lambda_function['FunctionArn']
    },
    functionSchema={
        'functions': agent_functions
    },
    actionGroupState='DISABLED',
)

action_group_deletion = bedrock_agent_client.delete_agent_action_group(
    agentId=agent_id,
    agentVersion='DRAFT',
    actionGroupId= action_group_id
)

In [None]:
agent_deletion = bedrock_agent_client.delete_agent(
    agentId=agent_id
)

In [None]:
# Delete Lambda function
lambda_client.delete_function(
    FunctionName=lambda_function_name
)

In [None]:
# Delete IAM Roles and policies

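for policy in [agent_bedrock_allow_policy_name]:
    iam_client.detach_role_policy(RoleName=agent_role_name, PolicyArn=f'arn:aws:iam::{account_id}:policy/{policy}')

iam_client.detach_role_policy(RoleName=lambda_function_role, PolicyArn='arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole')

# Inline policies added with put_role_policy must be removed before the role can be deleted
for policy_name in iam_client.list_role_policies(RoleName=lambda_function_role)['PolicyNames']:
    iam_client.delete_role_policy(RoleName=lambda_function_role, PolicyName=policy_name)

for role_name in [agent_role_name, lambda_function_role]:
    iam_client.delete_role(
        RoleName=role_name
    )

for policy in [agent_bedrock_policy]:
    iam_client.delete_policy(
        PolicyArn=policy['Policy']['Arn']
    )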


## Conclusion
We have now experimented with using the boto3 SDK and Splunk SDKs to create, invoke and delete a Splunk AI Assistant agent built with function definitions.

## Takeaways
Adapt this notebook to customize and create new agents using function definitions for your application.

## Thank You