## Multi-Tool Creation Using Amazon Bedrock Converse API Function Calling

#### In this notebook you will create two tools, `bar_chart` and `query_athena`. The `query_athena` tool runs SQL queries against an Amazon Athena database, and the `bar_chart` tool plots the query results as a matplotlib bar chart.

### Prerequisites
1. Create an S3 bucket and run `create_athena_catalog.py` to create a table in Athena. Set `bucket` in the configuration cell below to that bucket's name.
2. Attach the `AmazonBedrockFullAccess` policy to the SageMaker notebook's execution role.


In [4]:
!pip3 install -qU boto3


In [5]:
!pip3 install matplotlib

In [6]:
!pip3 install PyAthena

In [14]:
# Standard library
import json
import sys
from datetime import datetime

# Third-party: AWS SDK, Athena DB-API driver, and pandas for result formatting
import boto3
import pandas as pd
from pyathena import connect

print(f'Running boto3 version: {boto3.__version__}')

In [15]:
# Bedrock model that drives the conversation. Alternative model IDs:
#   'anthropic.claude-3-haiku-20240307-v1:0'
#   'cohere.command-r-plus-v1:0'
#   'cohere.command-r-v1:0'
#   'mistral.mistral-large-2402-v1:0'
modelId = 'anthropic.claude-3-sonnet-20240229-v1:0'
print(f'Using modelId: {modelId}')

# AWS region used for the Bedrock runtime client.
region = 'us-east-1'
print(f'Using region:  {region}')

bedrock = boto3.client(service_name='bedrock-runtime', region_name=region)

# S3 bucket whose athena/ prefix is used as the Athena query staging directory.
bucket = 'bedrockconverseapibucket'

In [16]:
import matplotlib.pyplot as plt


class ToolsList:
    """Local implementations of the tools declared in `toolConfig`.

    converse() looks up the method whose name matches the model's toolUse name
    and calls it with the model-supplied input as keyword arguments, so each
    method's parameter names must match its tool's inputSchema properties.
    """

    def bar_chart(self, title, x_values, y_values, x_label, y_label):
        """Plot a bar chart, save it as a PNG, and display it.

        Args:
            title: Chart title; a filesystem-safe version of it is used as the
                PNG file name in the current working directory.
            x_values: Bar labels (x axis).
            y_values: Bar heights. Converted to float because the model may send
                them as strings, which matplotlib would otherwise plot as
                categories instead of heights.
            x_label: X-axis label.
            y_label: Y-axis label.

        Returns:
            A confirmation message that is sent back to the model.

        Raises:
            ValueError: if a y value cannot be converted to a number.
        """
        heights = [float(value) for value in y_values]

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.bar(x_values, heights, color='skyblue')

        # Adding title and labels
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

        # The title comes from the model, so replace path separators and other
        # characters that are unsafe in file names.
        safe_name = "".join(c if c.isalnum() or c in " -_." else "_" for c in title) or "bar_chart"
        # Save before show(): some backends clear or close the figure once shown.
        fig.savefig(f"{safe_name}.png")
        plt.show()
        # Release the figure so repeated tool calls don't accumulate open figures.
        plt.close(fig)

        return f'Your bar chart named {title} is saved'

    def query_athena(self, query):
        """Run a SQL query on Athena and return the rows as a plain-text table.

        Results are staged under s3://{bucket}/athena/ using the module-level
        `bucket` and `region` settings from the configuration cell.

        NOTE(review): `query` is model-generated SQL executed as-is; the
        credentials' Athena/S3 permissions are the only guard on what it can do.

        Args:
            query: SQL text to execute.

        Returns:
            The result rendered with DataFrame.to_string(index=False); column
            headers come from the cursor description when it is available.

        Raises:
            Any exception from connecting or executing, re-raised after logging.
        """
        print(f"{datetime.now().strftime('%H:%M:%S')} - Got tool query: {query}\n")
        conn = None
        try:
            conn = connect(s3_staging_dir=f"s3://{bucket}/athena/", region_name=region)
            cursor = conn.cursor()
            cursor.execute(query)
            # Keep real column names so the model can tell which value is which.
            columns = [col[0] for col in cursor.description] if cursor.description else None
            result = pd.DataFrame(cursor.fetchall(), columns=columns).to_string(index=False)
            print(f"{datetime.now().strftime('%H:%M:%S')} - Tool result: {result}\n")
        except Exception as e:
            print(f"{datetime.now().strftime('%H:%M:%S')} - Error: {e}")
            raise
        finally:
            # Always release the Athena connection, including on failure.
            if conn is not None:
                conn.close()
        return result
       
        


In [17]:
#Define the configuration for our tools...
# Each toolSpec name must match a ToolsList method: converse() dispatches the
# model's tool calls to ToolsList by name.
# toolChoice 'auto' lets the model decide whether to call a tool. Alternatives:
#   {'any': {}}                      -> the model must call some tool
#   {'tool': {'name': 'bar_chart'}}  -> the model must call that specific tool
toolConfig = {
    'tools': [
        {
            'toolSpec': {
                'name': 'bar_chart',
                'description': 'create a bar chart.',
                'inputSchema': {
                    'json': {
                        'type': 'object',
                        'properties': {
                            'title': {
                                'type': 'string',
                                'description': 'title of the bar chart'
                            },
                            'x_values': {
                                'type': 'array',
                                'description': 'x axis values of the bar chart',
                                'items': {'type': 'string'}
                            },
                            # Numeric items so bar heights arrive as numbers, not text.
                            'y_values': {
                                'type': 'array',
                                'description': 'y axis values (bar heights) of the bar chart',
                                'items': {'type': 'number'}
                            },
                            'x_label': {
                                'type': 'string',
                                'description': 'x axis label of the bar chart'
                            },
                            'y_label': {
                                'type': 'string',
                                'description': 'y axis label of the bar chart'
                            }
                        },
                        'required': ['title', 'x_values', 'y_values', 'x_label', 'y_label']
                    }
                }
            }
        },
        {
            'toolSpec': {
                'name': 'query_athena',
                'description': 'Query the Acme Bank Athena catalog.',
                'inputSchema': {
                    'json': {
                        'type': 'object',
                        'properties': {
                            'query': {'type': 'string', 'description': 'SQL query to run against the Athena catalog'}
                        },
                        'required': ['query']
                    }
                }
            }
        }
    ],
    'toolChoice': {'auto': {}}
}



In [18]:
#Function for calling the Bedrock Converse API...
def converse_with_tools(messages, system='', toolConfig=toolConfig):
    """Send the conversation so far to the Bedrock Converse API with tools attached.

    Args:
        messages: List of Converse message dicts (each with 'role' and 'content').
        system: Optional system prompt as a list of content blocks, e.g.
            [{"text": "..."}]. Falsy values (including the default '') are not sent.
        toolConfig: Tool definitions and toolChoice; defaults to the module-level
            toolConfig as it was when this function was defined.

    Returns:
        The raw response dict returned by bedrock.converse.
    """
    request = {
        'modelId': modelId,
        'messages': messages,
        'toolConfig': toolConfig,
    }
    # The API's `system` parameter must be a list of content blocks; sending the
    # empty-string default would fail botocore parameter validation.
    if system:
        request['system'] = system
    return bedrock.converse(**request)

In [19]:
#Function for orchestrating the conversation flow...
def _run_tool(tool_use):
    """Execute one model-requested tool call and wrap the outcome as a Converse toolResult block.

    Only tool names declared in `toolConfig` are dispatched to ToolsList methods, so
    the model cannot invoke arbitrary attributes. Unknown tools, tool exceptions and
    empty outputs are reported back to the model with status 'error' instead of
    aborting the conversation.
    """
    tool_name = tool_use['name']
    tool_args = tool_use.get('input') or {}
    print(f"\n{datetime.now().strftime('%H:%M:%S')} - Running ({tool_name}) tool...")

    allowed = {tool['toolSpec']['name'] for tool in toolConfig['tools']}
    if tool_name not in allowed:
        text, status = f"Unknown tool: {tool_name}", 'error'
    else:
        try:
            text = str(getattr(ToolsList(), tool_name)(**tool_args) or "")
            status = 'success' if text else 'error'
        except Exception as e:
            text, status = f"Tool {tool_name} failed: {e}", 'error'

    return {
        'toolResult': {
            'toolUseId': tool_use['toolUseId'],
            # NOTE(review): avoid sending a blank text block; the API presumably
            # rejects empty content, so substitute a short placeholder.
            'content': [{'text': text or 'Tool returned no output.'}],
            'status': status,
        }
    }


def converse(prompt, system='', max_turns=5):
    """Send a prompt to the model, running requested tools until it gives a final answer.

    Every toolUse block in an assistant turn receives a matching toolResult in the
    next user turn, as the Converse API requires; the model is then called again,
    so it can chain tools (e.g. query Athena, then plot the result).

    Args:
        prompt: User prompt text.
        system: Optional system prompt as a list of Converse content blocks.
        max_turns: Maximum number of tool-execution rounds before stopping.
    """
    messages = [{"role": "user", "content": [{"text": prompt}]}]
    print(f"\n{datetime.now().strftime('%H:%M:%S')} - Initial prompt:\n{json.dumps(messages, indent=2)}")

    output = converse_with_tools(messages, system)
    for turn in range(max_turns + 1):
        message = output['output']['message']
        # Keep the assistant turn in the history so tool results can refer to it.
        messages.append(message)

        tool_uses = [block['toolUse'] for block in message['content'] if 'toolUse' in block]
        if not tool_uses:
            print(f"\n{datetime.now().strftime('%H:%M:%S')} - Final output:\n{json.dumps(output['output'], indent=2, ensure_ascii=False)}\n")
            return

        print(f"\n{datetime.now().strftime('%H:%M:%S')} - Output so far:\n{json.dumps(output['output'], indent=2, ensure_ascii=False)}")
        if turn == max_turns:
            print(f"\n{datetime.now().strftime('%H:%M:%S')} - Stopping: reached max_turns={max_turns} tool rounds.")
            return

        # Answer every tool request from this turn in a single user message.
        messages.append({"role": "user", "content": [_run_tool(tool_use) for tool_use in tool_uses]})
        output = converse_with_tools(messages, system)
    return


In [20]:
prompts = [
    "return me the name of the table in acme_bank database",
    "return me 3 user_names which are different from each other without duplicates, and their corresponding amounts from the transactions table within acme_bank database, print the user names and the amounts as array list such as usernames:[x,y,z] and values:[x,y,z] ",
]

# System prompt for every request in this cell. Adjacent string literals keep the
# text free of stray backslashes and source indentation.
SYSTEM_PROMPT = [{"text": (
    "You're provided with a tool named 'bar_chart' that plots a bar chart with a given title, "
    "and a tool named 'query_athena' that creates and runs SQL queries against the Athena data catalog; "
    "only use the tools if required; "
    "you can use multiple tools at once or call a tool multiple times in the same response if required; "
    "don't make reference to the tools in your final answer."
)}]

for prompt in prompts:
    converse(system=SYSTEM_PROMPT, prompt=prompt)
    
    


In [21]:
    
prompt = "Create the bar chart, title is acme_bank_chart, x values are [Leonardo DiCaprio, Daniel Day-Lewis,Brad Pitt],y values are [750.0,1600.0,2000.0], x label is usernames and y label is amounts."
converse(
    # Adjacent string literals keep the system prompt free of stray backslashes
    # and source indentation.
    system=[{"text": (
        "You're provided with a tool named 'bar_chart' that plots a bar chart with a given title, "
        "and a tool named 'query_athena' that creates and runs SQL queries against the Athena data catalog; "
        "only use the tools if required; "
        "you can use multiple tools at once or call a tool multiple times in the same response if required; "
        "don't make reference to the tools in your final answer."
    )}],
    prompt=prompt,
)
    