## Multitool Creation using Amazon Bedrock Converse API Function Calling

In this notebook we will create a fictional database called acme_bank, which holds transaction data: user names and their corresponding transfer amounts. After the database is created, we will define two tools, **bar_chart** and **query_athena**. Using function calling in the Amazon Bedrock Converse API, the LLM takes the user's natural-language input and writes a SQL query against the Amazon Athena table by calling the query_athena tool. Afterwards, the user can ask the LLM to create a bar plot from the values returned by the query, which it does by calling the bar_chart tool.

### Install and Import the Required Modules

In [None]:
!pip3 install -qU boto3

In [None]:
!pip3 install matplotlib

In [None]:
!pip3 install PyAthena

In [None]:
import boto3
import json, sys
from datetime import datetime
from pyathena import connect
import pandas as pd

print('Running boto3 version:', boto3.__version__)

### Attach the necessary IAM permissions

In the SageMaker console, open your notebook instance to find the IAM role attached to it. Paste that role name into `role_name` and run the cell below to give the role access to Amazon Bedrock.

In [None]:
# Create an IAM client
iam = boto3.client('iam')

# Define the role name and policy ARN
role_name = 'INSERT YOUR SAGEMAKER IAM ROLE'
policy_arn = 'arn:aws:iam::aws:policy/AmazonBedrockFullAccess'  
# Attach the policy to the role
try:
    response = iam.attach_role_policy(
        RoleName=role_name,
        PolicyArn=policy_arn
    )
    print(f"Policy {policy_arn} attached to role {role_name} successfully.")
except iam.exceptions.NoSuchEntityException:
    print(f"Role {role_name} or policy {policy_arn} does not exist.")
except Exception as e:
    print(f"Error attaching policy: {e}")
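Instead of copying the role name by hand, you could derive it from the notebook's own credentials. The sketch below assumes the notebook runs under an assumed-role session (on a notebook instance the caller ARN typically looks like `arn:aws:sts::<account>:assumed-role/<RoleName>/SageMaker`). `role_name_from_arn` and `current_role_name` are hypothetical helper names; only `sts.get_caller_identity` is a real boto3 call.

```python
def role_name_from_arn(arn):
    """Hypothetical helper: extract the IAM role name from a role ARN.

    Assumed-role ARN (what STS reports inside a notebook):
        arn:aws:sts::<account>:assumed-role/<RoleName>/<SessionName>
    Role ARN:
        arn:aws:iam::<account>:role/<optional/path/><RoleName>
    """
    resource = arn.split(':', 5)[5]          # e.g. 'assumed-role/MyRole/SageMaker'
    kind, _, rest = resource.partition('/')
    if kind == 'assumed-role':
        return rest.split('/')[0]            # role name comes before the session name
    if kind == 'role':
        return rest.split('/')[-1]           # role name is the last path segment
    raise ValueError(f'Not a role ARN: {arn}')


def current_role_name():
    # Assumption: boto3 is installed and credentials come from the notebook's role
    import boto3
    arn = boto3.client('sts').get_caller_identity()['Arn']
    return role_name_from_arn(arn)
```

Usage: run `role_name = current_role_name()` and pass the result to `attach_role_policy` in place of the placeholder string.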

### Create the S3 Bucket for Amazon Athena Query Results

We will save the output query results from Amazon Athena to the S3 bucket created below. Make sure you create the S3 bucket before you create the Athena table.


In [None]:
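from botocore.exceptions import ClientError

# Create an S3 client
s3 = boto3.client('s3')

# Define the bucket name (bucket names must be globally unique)
bucket_name = 'BUCKET_NAME'

# Create the S3 bucket. Outside us-east-1 you must also pass
# CreateBucketConfiguration={'LocationConstraint': '<your-region>'}
try:
    s3.create_bucket(Bucket=bucket_name)
    print(f'Bucket {bucket_name} created successfully!')
except ClientError as e:
    # create_bucket raises on failure instead of returning a non-200 status
    print(f"Failed to create the bucket. Error: {e}")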
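The cell above works unchanged only in us-east-1. Below is a minimal sketch of building the `create_bucket` arguments for any region. `create_bucket_kwargs` is a hypothetical helper, and its name check is a simplified version of the S3 naming rules (for example, it does not reject IP-address-style names). Because uppercase letters are not allowed in bucket names, it also catches an unreplaced `BUCKET_NAME` placeholder.

```python
import re

# Simplified S3 bucket-name rule: 3-63 chars of lowercase letters, digits,
# dots and hyphens, starting and ending with a letter or digit
_BUCKET_NAME_RE = re.compile(r'^[a-z0-9][a-z0-9.-]{1,61}[a-z0-9]$')


def create_bucket_kwargs(bucket_name, region):
    """Hypothetical helper: keyword arguments for s3.create_bucket() in `region`."""
    if not _BUCKET_NAME_RE.match(bucket_name):
        raise ValueError(f'Invalid S3 bucket name: {bucket_name!r}')
    kwargs = {'Bucket': bucket_name}
    # us-east-1 is the default location and must not be sent as a LocationConstraint;
    # every other region needs an explicit CreateBucketConfiguration
    if region != 'us-east-1':
        kwargs['CreateBucketConfiguration'] = {'LocationConstraint': region}
    return kwargs
```

Usage: `s3.create_bucket(**create_bucket_kwargs(bucket_name, 'eu-west-1'))`, with the S3 client created in that same region.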

### Creating the Amazon Athena Catalog

Next, we create the acme_bank database in Amazon Athena. Before running the script below, replace the bucket name in create_athena_catalog.py with the name of the bucket you created in the previous step.


In [None]:
!python create_athena_catalog.py
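To confirm that the script registered its tables, you can list them from the AWS Glue Data Catalog, which Athena uses as its default catalog. This is a sketch assuming the database is named acme_bank. `list_catalog_tables` is a hypothetical helper; the `get_tables` paginator and the `TableList` response key come from the boto3 Glue client.

```python
def list_catalog_tables(glue_client, database='acme_bank'):
    """Hypothetical helper: return the table names in a Glue Data Catalog database."""
    names = []
    # get_tables is paginated; the paginator follows NextToken automatically
    paginator = glue_client.get_paginator('get_tables')
    for page in paginator.paginate(DatabaseName=database):
        names.extend(table['Name'] for table in page['TableList'])
    return names
```

Usage: `print(list_catalog_tables(boto3.client('glue')))` should show the transactions table once the script has run.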

### Amazon Bedrock Configurations

In [None]:
modelId = 'anthropic.claude-3-sonnet-20240229-v1:0'
#modelId = 'anthropic.claude-3-haiku-20240307-v1:0'
#modelId = 'cohere.command-r-plus-v1:0'
#modelId = 'cohere.command-r-v1:0'
#modelId = 'mistral.mistral-large-2402-v1:0'
print(f'Using modelId: {modelId}')

region = 'us-east-1'
print('Using region: ', region)

bedrock = boto3.client(
    service_name = 'bedrock-runtime',
    region_name = region,
    )


### Defining our Tools

The bar_chart tool creates a bar plot with Matplotlib for the given user names and their corresponding transaction amounts. The query_athena tool runs a SQL query against the Athena table and returns the result rows as text, for example the table names in the database or the values in its rows and columns.

In [None]:
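import matplotlib.pyplot as plt


class ToolsList:
    #Define our bar_chart tool function...
    def bar_chart(self, title, x_values, y_values, x_label, y_label):
        # The model may send numbers as strings, so convert the bar heights to floats
        y_values = [float(y) for y in y_values]

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.bar(x_values, y_values, color='skyblue')
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        # Save before showing, the usual safe order with pyplot
        output_name = f"{title}.png"
        fig.savefig(output_name)
        plt.show()

        # Return a short status message for the model
        return f'Your bar chart named {title} is saved as {output_name}'

    #Define our query_athena tool function...
    def query_athena(self, query):

        print(f"{datetime.now().strftime('%H:%M:%S')} - Got tool query: {query}\n")

        try:
            cursor = connect(s3_staging_dir=f"s3://{bucket_name}/athena/",
                             region_name=region).cursor()
            cursor.execute(query)
            # Keep the column names so the model can tell which value is which
            columns = [col[0] for col in cursor.description] if cursor.description else None
            result = pd.DataFrame(cursor.fetchall(), columns=columns).to_string(index=False)
            print(f"{datetime.now().strftime('%H:%M:%S')} - Tool result: {result}\n")

        except Exception as e:
            print(f"{datetime.now().strftime('%H:%M:%S')} - Error: {e}")
            raise

        # Return the rows as plain text for the model
        return result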

In [None]:
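#Define the configuration for our tools...
toolConfig = {'tools': [],
'toolChoice': {
    'auto': {},  # let the model decide whether to call a tool
    # Alternatives (only some models support them):
    #'any': {},                          # force the model to call some tool
    #'tool': {'name': 'query_athena'}    # force one specific tool
    }
}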

toolConfig['tools'].append({
        'toolSpec': {
            'name': 'bar_chart',
            'description': 'create a bar chart.',
            'inputSchema': {
                'json': {
                    'type': 'object',
                    'properties': {
                        'title': {
                            'type': 'string',
                            'description': 'title of the bar chart'
                        },
                        
                         'x_values': {
                            'type': 'array',
                            'description': 'x axis values of the bar chart',
                            'items': {
                                "type": "string"
                            }
                        },
                        
                         'y_values': {
                            'type': 'array',
                            'description': 'y axis values (numeric bar heights) of the bar chart',
                            'items': {
                                'type': 'number'
                            }
                        },
                        
                         'x_label': {
                            'type': 'string',
                            'description': 'x axis label of the bar chart'
                        },
                        
                          'y_label': {
                            'type': 'string',
                            'description': 'y axis label of the bar chart'
                        }
                        
                    },
                    'required': ['title','x_values','y_values','x_label','y_label']
                }
            }
        }
    })



#Define the query_athena tool schema...
toolConfig['tools'].append({
        'toolSpec': {
            'name': 'query_athena',
            'description': 'Query the Acme Bank Athena catalog.',
            'inputSchema': {
                'json': {
                    'type': 'object',
                    'properties': {
                        'query': {'type': 'string', 'description': 'SQL query to run against the Athena catalog'}
                    },
                    'required': ['query']
                }
            }
        }
    })
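The two `toolSpec` entries above repeat the same nesting. A small builder can remove that boilerplate and make the three `toolChoice` modes explicit. This is a sketch; `make_tool_spec` and `build_tool_config` are hypothetical helpers, and support for the `any` and `tool` choices varies by model.

```python
def make_tool_spec(name, description, properties, required=None):
    """Hypothetical helper: wrap a JSON Schema object in the Converse toolSpec layout."""
    return {
        'toolSpec': {
            'name': name,
            'description': description,
            'inputSchema': {
                'json': {
                    'type': 'object',
                    'properties': properties,
                    # Default: every property is required
                    'required': list(required if required is not None else properties),
                }
            },
        }
    }


def build_tool_config(tools, choice='auto', tool_name=None):
    """Hypothetical helper: 'auto' lets the model decide, 'any' forces some tool,
    'tool' forces the tool called tool_name."""
    if choice == 'tool':
        if tool_name is None:
            raise ValueError("choice='tool' needs a tool_name")
        tool_choice = {'tool': {'name': tool_name}}
    elif choice in ('auto', 'any'):
        tool_choice = {choice: {}}
    else:
        raise ValueError(f'Unknown tool choice: {choice}')
    return {'tools': list(tools), 'toolChoice': tool_choice}
```

With these helpers, the query_athena entry becomes a single `make_tool_spec('query_athena', ..., {'query': {...}})` call.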



In [None]:
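#Function for calling the Bedrock Converse API...
def converse_with_tools(messages, system=None, toolConfig=toolConfig):
    # Only pass `system` when provided: the Converse API expects a list of
    # content blocks such as [{"text": "..."}], not a plain string
    kwargs = {'modelId': modelId, 'messages': messages, 'toolConfig': toolConfig}
    if system:
        kwargs['system'] = system
    response = bedrock.converse(**kwargs)
    return response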
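`bedrock.converse` returns the assistant message under `response['output']['message']` as a list of content blocks, together with a `stopReason` that is `'tool_use'` when the model is waiting for tool results. The sketch below pulls the tool calls out of a response. The helper names are hypothetical, and the `sample` dict is hand-built for illustration, not real model output.

```python
def extract_tool_uses(response):
    """Hypothetical helper: return the toolUse blocks of a Converse response (may be empty)."""
    content = response['output']['message']['content']
    return [block['toolUse'] for block in content if 'toolUse' in block]


def wants_tools(response):
    # stopReason is 'tool_use' when the model stopped to wait for tool results
    return response.get('stopReason') == 'tool_use'


# Hand-built response in the Converse shape (illustrative only)
sample = {
    'output': {'message': {'role': 'assistant', 'content': [
        {'text': 'Let me look up the tables.'},
        {'toolUse': {'toolUseId': 'tooluse_1', 'name': 'query_athena',
                     'input': {'query': 'SHOW TABLES IN acme_bank'}}},
    ]}},
    'stopReason': 'tool_use',
}
calls = extract_tool_uses(sample)
```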

In [None]:
#Function for orchestrating the conversation flow...
def converse(prompt, system=None, max_turns=5):
    #Add the initial prompt:
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "text": prompt
                }
            ]
        }
    ]
    print(f"\n{datetime.now().strftime('%H:%M:%S')} - Initial prompt:\n{json.dumps(messages, indent=2)}")

    # max_turns caps the number of model calls so a tool loop cannot run forever
    for _ in range(max_turns):
        #Invoke the model:
        output = converse_with_tools(messages, system)
        print(f"\n{datetime.now().strftime('%H:%M:%S')} - Output so far:\n{json.dumps(output['output'], indent=2, ensure_ascii=False)}")

        #Add the intermediate output to the conversation:
        messages.append(output['output']['message'])

        #Collect every tool call in this turn; the model may request several at once:
        tool_calls = [c['toolUse'] for c in output['output']['message']['content'] if 'toolUse' in c]

        #No tool requested, so this output is the final answer:
        if not tool_calls:
            break

        #Run each tool and build one toolResult per toolUse:
        tool_results = []
        for function_calling in tool_calls:
            tool_name = function_calling['name']
            tool_args = function_calling['input'] or {}
            print(f"\n{datetime.now().strftime('%H:%M:%S')} - Running ({tool_name}) tool...")
            tool_response = getattr(ToolsList(), tool_name)(**tool_args) or ""
            tool_results.append(
                {
                    'toolResult': {
                        'toolUseId': function_calling['toolUseId'],
                        'content': [
                            {
                                # Empty text blocks can be rejected by the API, so send a placeholder
                                "text": tool_response or "The tool returned no output."
                            }
                        ],
                        'status': 'success' if tool_response else 'error'
                    }
                }
            )

        #Return all tool results to the model in a single user message:
        messages.append({"role": "user", "content": tool_results})
    return
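The loop above dispatches tools through `getattr(ToolsList(), tool_name)`, so an unknown tool name or a failing query (note that `query_athena` re-raises its exception) stops the conversation. One alternative, sketched below with a hypothetical `run_tool_calls` helper, is an explicit name-to-function registry that reports failures back to the model as `status: 'error'` results, giving it a chance to retry with a corrected query.

```python
def run_tool_calls(tool_calls, registry):
    """Hypothetical helper: execute toolUse blocks and build matching toolResult blocks.

    `registry` maps tool names to Python callables. Unknown tools and raised
    exceptions become 'error' results instead of stopping the conversation.
    """
    results = []
    for call in tool_calls:
        func = registry.get(call['name'])
        try:
            if func is None:
                raise ValueError(f"Unknown tool: {call['name']}")
            text = str(func(**(call.get('input') or {})))
            status = 'success'
        except Exception as exc:  # surface the error text to the model
            text = f'{type(exc).__name__}: {exc}'
            status = 'error'
        results.append({
            'toolResult': {
                'toolUseId': call['toolUseId'],
                'content': [{'text': text}],
                'status': status,
            }
        })
    return results
```

Usage inside the loop: build `registry = {'bar_chart': tools.bar_chart, 'query_athena': tools.query_athena}` from `tools = ToolsList()`, then append `{"role": "user", "content": run_tool_calls(tool_calls, registry)}` to `messages`.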


In [None]:
# System prompt shared by the cells below
system_prompt = [{"text": (
    "You're provided with a tool named 'bar_chart' that plots a bar chart with a given title, "
    "and another tool named 'query_athena' that creates and runs SQL queries against the Athena data catalog. "
    "Only use a tool if required. You can use multiple tools at once, or call a tool multiple times "
    "in the same response if required. Don't make reference to the tools in your final answer."
)}]

prompts = [
    "return me the name of the table in acme_bank database",
    "return me 3 user_names which are different from each other without duplicates, and their corresponding amounts from the transactions table within acme_bank database, print the user names and the amounts as array list such as usernames:[x,y,z] and values:[x,y,z]",
]

for prompt in prompts:
    converse(system=system_prompt, prompt=prompt)


In [None]:
prompt = "Create the bar chart, title is acme_bank_chart, x values are [Leonardo DiCaprio, Daniel Day-Lewis, Brad Pitt], y values are [750.0, 1600.0, 2000.0], x label is usernames and y label is amounts."
converse(system=system_prompt, prompt=prompt)

On a successful run, the LLM's response to the second prompt should look similar to the example below:

*"message": {"role": "assistant", "content": [{"text": "The query selects 3 distinct user_names and their corresponding amounts from the transactions table in the acme_bank database. \n\nusernames: ['Brad Pitt', 'Tilda Swinton', 'Christian Bale'] \nvalues: [1800.0, 550.0,1300.0]"}]}*


Keep in mind that the LLM may return different user names from the ones in this example.
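Rather than retyping the values into the bar chart prompt, you could parse them from the model's text. The sketch below assumes the model follows the `usernames:[...]` / `values:[...]` format requested in the second prompt. `parse_chart_lists` is a hypothetical helper; when the names are not quoted it falls back to splitting on commas.

```python
import ast
import re


def _parse_list(raw):
    """Parse '[...]' as a Python literal, falling back to a plain comma split."""
    try:
        return list(ast.literal_eval(raw))
    except (ValueError, SyntaxError):
        # Fallback for unquoted items such as [Brad Pitt, Tilda Swinton]
        return [item.strip().strip('\'"') for item in raw.strip('[]').split(',') if item.strip()]


def parse_chart_lists(text):
    """Hypothetical helper: return (names, amounts) from the model's answer."""
    names_match = re.search(r'usernames\s*:\s*(\[.*?\])', text, re.IGNORECASE)
    values_match = re.search(r'values\s*:\s*(\[.*?\])', text, re.IGNORECASE)
    if not (names_match and values_match):
        raise ValueError('Could not find both lists in the model output')
    names = _parse_list(names_match.group(1))
    amounts = [float(v) for v in _parse_list(values_match.group(1))]
    return names, amounts
```

Usage: `names, amounts = parse_chart_lists(text)`, then `ToolsList().bar_chart('acme_bank_chart', names, amounts, 'usernames', 'amounts')` plots them directly.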

### Conclusion

In this notebook we learned how to use function calling with multiple tools in Amazon Bedrock. With the tools defined, a large language model on Amazon Bedrock can call the functions it needs based on the user input and then compose the final response.

In our scenario, the user asks for information about tables in a bank database, and the LLM turns the natural-language request into a SQL query by calling the query_athena function. If the user then requests a bar plot, the LLM calls the bar_chart function to plot the query results.







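### Clean Up (Optional)

From the AWS Console:

* Go to S3 -> Buckets, select your bucket, empty it, and then delete it
* Go to AWS Glue -> Databases -> acme_bank -> Delete
* Optionally, go to IAM -> Roles -> your SageMaker role and detach the AmazonBedrockFullAccess policy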
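The same clean-up can be scripted with boto3. This sketch assumes the bucket is not versioned (a versioned bucket would need `bucket.object_versions.delete()` instead) and that the database is named acme_bank. `clean_up` is a hypothetical helper; `Bucket.objects.all().delete()`, `Bucket.delete()` and the Glue client's `delete_database` are boto3 calls.

```python
def clean_up(s3_resource, glue_client, bucket_name, database='acme_bank'):
    """Hypothetical helper: delete the Athena results bucket and the Glue database.

    S3 refuses to delete a non-empty bucket, so the objects are removed first.
    Deleting the Glue database also removes the table definitions it contains.
    """
    bucket = s3_resource.Bucket(bucket_name)
    bucket.objects.all().delete()   # batch-delete every object in the bucket
    bucket.delete()                 # the now-empty bucket can be removed
    glue_client.delete_database(Name=database)
```

Usage: `clean_up(boto3.resource('s3'), boto3.client('glue', region_name=region), bucket_name)`.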