|
| 1 | +# Copyright (c) Microsoft. All rights reserved. |
| 2 | + |
| 3 | + |
| 4 | +import asyncio |
| 5 | +from collections.abc import Awaitable, Callable |
| 6 | +from typing import TYPE_CHECKING, Any |
| 7 | + |
| 8 | +from samples.concepts.memory.azure_ai_search_hotel_samples.data_model import ( |
| 9 | + HotelSampleClass, |
| 10 | + custom_index, |
| 11 | + load_records, |
| 12 | +) |
| 13 | +from semantic_kernel.agents import ChatCompletionAgent |
| 14 | +from semantic_kernel.agents.agent import AgentThread |
| 15 | +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior |
| 16 | +from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAITextEmbedding |
| 17 | +from semantic_kernel.connectors.memory import AzureAISearchCollection |
| 18 | +from semantic_kernel.filters import FilterTypes, FunctionInvocationContext |
| 19 | +from semantic_kernel.functions import KernelParameterMetadata |
| 20 | +from semantic_kernel.functions.kernel_plugin import KernelPlugin |
| 21 | +from semantic_kernel.kernel_types import OptionalOneOrList |
| 22 | + |
| 23 | +if TYPE_CHECKING: |
| 24 | + from semantic_kernel.functions import KernelParameterMetadata |
| 25 | + |
| 26 | + |
| 27 | +""" |
| 28 | +This sample builds on the previous one, but can be run independently. |
| 29 | +It uses the data model defined in step_0_data_model.py, and with that creates a collection |
| 30 | +and creates two kernel functions from those that are then made available to a LLM. |
| 31 | +The first function is a search function that allows you to search for hotels, optionally filtering for a city. |
| 32 | +The second function is a details function that allows you to get details about a hotel. |
| 33 | +""" |
| 34 | + |
| 35 | + |
| 36 | +# Create an Azure AI Search collection. |
| 37 | +collection = AzureAISearchCollection[str, HotelSampleClass]( |
| 38 | + data_model_type=HotelSampleClass, embedding_generator=OpenAITextEmbedding() |
| 39 | +) |
| 40 | +# load the records |
| 41 | +records = load_records() |
| 42 | +# get the set of cities |
| 43 | +cities: set[str] = set() |
| 44 | +for record in records: |
| 45 | + if record.Address.Country == "USA" and record.Address.City: |
| 46 | + cities.add(record.Address.City) |
| 47 | + |
| 48 | + |
| 49 | +# Before we create the plugin, we want to create a function that will help the plugin work the way we want it to. |
| 50 | +# This function allows us to create the plugin with a parameter called `city` that |
| 51 | +# then get's put into a filter for address/city. |
| 52 | +# This function has to adhere to the `DynamicFilterFunction` signature. |
| 53 | +# which consists of 2 named arguments, `filter`, and `parameters`. |
| 54 | +# and kwargs. |
| 55 | +# It returns the updated filter. |
| 56 | +# The default version that is used when not supplying this, reads the parameters and if there is |
| 57 | +# a parameter that is not `query`, `top`, or 'skip`, and it can find a value for it, either in the kwargs |
| 58 | +# or the default value specified in the parameter, it will add a filter to the options. |
| 59 | +# In this case, we are adding a filter to the options to filter by the city, but since the technical name |
| 60 | +# of that field in the index is `address/city`, want to do this manually. |
| 61 | +# this can also be used to replace a complex technical name in your index with a friendly name towards the LLM. |
| 62 | +def filter_update( |
| 63 | + filter: OptionalOneOrList[Callable | str] | None = None, |
| 64 | + parameters: list["KernelParameterMetadata"] | None = None, |
| 65 | + **kwargs: Any, |
| 66 | +) -> OptionalOneOrList[Callable | str] | None: |
| 67 | + if "city" in kwargs: |
| 68 | + city = kwargs["city"] |
| 69 | + if city not in cities: |
| 70 | + raise ValueError(f"City '{city}' is not in the list of cities: {', '.join(cities)}") |
| 71 | + # we need the actual value and not a named param, otherwise the parser will not be able to find it. |
| 72 | + new_filter = f"lambda x: x.Address.City == '{city}'" |
| 73 | + if filter is None: |
| 74 | + filter = new_filter |
| 75 | + elif isinstance(filter, list): |
| 76 | + filter.append(new_filter) |
| 77 | + else: |
| 78 | + filter = [filter, new_filter] |
| 79 | + return filter |
| 80 | + |
| 81 | + |
# Next we create the Agent, with two functions.
# NOTE: constructing OpenAIChatCompletion() here means the service client is
# created at import time; it reads its configuration from the environment.
travel_agent = ChatCompletionAgent(
    name="TravelAgent",
    description="A travel agent that helps you find a hotel.",
    service=OpenAIChatCompletion(),
    instructions="""You are a travel agent. Your name is Mosscap and
you have one goal: help people find a hotel.
Your full name, should you need to know it, is
Splendid Speckled Mosscap. You communicate
effectively, but you tend to answer with long
flowery prose. You always make sure to include the
hotel_id in your answers so that the user can
use it to get more information.""",
    # Auto lets the service decide which (if any) of the plugin functions to call.
    function_choice_behavior=FunctionChoiceBehavior.Auto(),
    plugins=[
        KernelPlugin(
            name="azure_ai_search",
            description="A plugin that allows you to search for hotels in Azure AI Search.",
            functions=[
                collection.create_search_function(
                    # this create search method uses the `search` method of the text search object.
                    # remember that the text_search object for this sample is based on
                    # the text_search method of the Azure AI Search.
                    # but it can also be used with the other vector search methods.
                    # This method's description, name and parameters are what will be serialized as part of the tool
                    # call functionality of the LLM.
                    # And crafting these should be part of the prompt design process.
                    # The default parameters are `query`, `top`, and `skip`, but you specify your own.
                    # The default parameters match the parameters of the VectorSearchOptions class.
                    description="A hotel search engine, allows searching for hotels in specific cities, "
                    "you do not have to specify that you are searching for hotels, for all, use `*`.",
                    search_type="keyword_hybrid",
                    # Next to the dynamic filters based on parameters, I can specify options that are always used.
                    # this can include the `top` and `skip` parameters, but also filters that are always applied.
                    # In this case, I am filtering by country, so only hotels in the USA are returned.
                    filter=lambda x: x.Address.Country == "USA",
                    parameters=[
                        KernelParameterMetadata(
                            name="query",
                            description="What to search for.",
                            type="str",
                            is_required=True,
                            type_object=str,
                        ),
                        KernelParameterMetadata(
                            # The valid city values are baked into the description so the
                            # LLM knows the closed set it can choose from.
                            name="city",
                            description="The city that you want to search for a hotel "
                            f"in, values are: {', '.join(cities)}",
                            type="str",
                            type_object=str,
                        ),
                        KernelParameterMetadata(
                            name="top",
                            description="Number of results to return.",
                            type="int",
                            default_value=5,
                            type_object=int,
                        ),
                    ],
                    # and here the above created function is passed in.
                    filter_update_function=filter_update,
                    # finally, we specify the `string_mapper` function that is used to convert the record to a string.
                    # This is used to make sure the relevant information from the record is passed to the LLM.
                    string_mapper=lambda x: f"(hotel_id :{x.record.HotelId}) {x.record.HotelName} (rating {x.record.Rating}) - {x.record.Description}. Address: {x.record.Address.StreetAddress}, {x.record.Address.City}, {x.record.Address.StateProvince}, {x.record.Address.Country}. Number of room types: {len(x.record.Rooms)}. Last renovated: {x.record.LastRenovationDate}.", # noqa: E501
                ),
                collection.create_search_function(
                    # This second function is a more detailed one, that uses a `HotelId` to get details about a hotel.
                    # we set the top to 1, so that only 1 record is returned.
                    function_name="get_details",
                    description="Get details about a hotel, by ID, use the generic search function to get the ID.",
                    top=1,
                    parameters=[
                        KernelParameterMetadata(
                            name="HotelId",
                            description="The hotel ID to get details for.",
                            type="str",
                            is_required=True,
                            type_object=str,
                        ),
                    ],
                ),
            ],
        )
    ],
)
| 167 | + |
| 168 | + |
# This filter will log all calls to the Azure AI Search plugin.
# This allows us to see what parameters are being passed to the plugin.
# And this gives us a way to debug the search experience and if necessary tweak the parameters and descriptions.
@travel_agent.kernel.filter(filter_type=FilterTypes.FUNCTION_INVOCATION)
async def log_search_filter(
    context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]]
):
    """Print the function name and arguments of every invocation, then continue the chain."""
    print(f"Calling Azure AI Search ({context.function.name}) with arguments:")
    for arg in context.arguments:
        # BUG FIX: the original test was `arg in ("chat_history")` — the parens
        # do not make a tuple, so this was a substring check against the string
        # "chat_history" and would also skip args like "chat" or "history".
        if arg == "chat_history":
            continue
        print(f'  {arg}: "{context.arguments[arg]}"')
    await next(context)
| 182 | + |
| 183 | + |
async def chat():
    """Run the interactive chat loop against the travel agent.

    Makes sure the Azure AI Search collection exists and is seeded with the
    sample records, then forwards user input to the agent until the user
    types 'exit' (or sends Ctrl-C / EOF). Finally offers to delete the
    collection.
    """
    async with collection:
        # Create and seed the index only when it is missing or empty.
        if not await collection.does_collection_exist():
            await collection.create_collection(index=custom_index)
        if not await collection.get(top=1):
            await collection.upsert(records)

        thread: AgentThread | None = None
        while True:
            try:
                user_input = input("User:> ")
            except (KeyboardInterrupt, EOFError):
                # Both interrupt paths end the chat the same way.
                print("\n\nExiting chat...")
                break

            if user_input == "exit":
                print("\n\nExiting chat...")
                break

            response = await travel_agent.get_response(messages=user_input, thread=thread)
            print(f"Agent: {response.content}")
            # Keep the thread so the conversation retains its history.
            thread = response.thread

        if input("Do you want to delete the collection? (y/n): ").lower() == "y":
            await collection.delete_collection()
            print("Collection deleted.")
        else:
            print("Collection not deleted.")
| 217 | + |
| 218 | + |
async def main():
    """Print a welcome banner and start the interactive chat loop."""
    # BUG FIX: the original banner used backslash line-continuations inside the
    # string literal, which embedded the source-code indentation as runs of
    # stray spaces before each newline in the printed output. Implicit string
    # concatenation produces the intended text.
    print(
        "Welcome to the chat bot!\n"
        " Type 'exit' to exit.\n"
        " Try to find a hotel to your liking!"
    )
    await chat()
| 226 | + |
| 227 | + |
| 228 | +if __name__ == "__main__": |
| 229 | + import asyncio |
| 230 | + |
| 231 | + asyncio.run(main()) |
0 commit comments