In [1]:
import numpy as np
from langchain_core.tools import tool
import pandas as pd

In [2]:
import re
import openai
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

In [3]:
load_dotenv()

True

In [4]:
class VectorStoreRetriever:
    def __init__(self, docs: list, vectors: list, open_ai_client):
        self.arr = np.array(vectors)
        self.docs = docs
        self.open_ai_client = open_ai_client
    
    @classmethod
    def from_docs(cls, docs, open_ai_client):
        embeddings = open_ai_client.embeddings.create(
            model = "text-embedding-3-small", 
            input = [ doc['page_content'] for doc in docs ]
        )
        vectors = [ emb.embedding for emb in embeddings.data ]

        return cls(docs, vectors, open_ai_client)

    def query(self, query: str, k: int = 5) -> list[dict]:
        embed = self.open_ai_client.embeddings.create(
            model="text-embedding-3-small", input=[query]
        )
        # "@" is just a matrix multiplication in python
        scores = np.array(embed.data[0].embedding) @ self.arr.T
        top_k_idx = np.argpartition(scores, -k)[-k:]
        top_k_idx_sorted = top_k_idx[np.argsort(-scores[top_k_idx])]
        
        return [
            {**self.docs[idx], "similarity": scores[idx]} for idx in top_k_idx_sorted
        ]

document = open('./knowledge-base/swiss_faq.md', 'r')
content = document.read()
docs = [{"page_content": txt} for txt in re.split(r"(?=\n##)", content)]

retriever = VectorStoreRetriever.from_docs(docs, openai.Client())

@tool
def lookup_policy(query: str) -> str:
    """Consult the company policies to check whether certain options are permitted.
    Use this before making any flight changes performing other 'write' events."""
    docs = retriever.query(query, k=2)
    return "\n\n".join([doc["page_content"] for doc in docs])

In [5]:
print(lookup_policy('Should I reconfirm my flight?'))

  print(lookup_policy('Should I reconfirm my flight?'))



## Booking and Cancellation

1. How can I change my booking?
	* The ticket number must start with 724 (SWISS ticket no./plate).
	* The ticket was not paid for by barter or voucher (there are exceptions to voucher payments; if the ticket was paid for in full by voucher, then it may be possible to rebook online under certain circumstances. If it is not possible to rebook online because of the payment method, then you will be informed accordingly during the rebooking process).
	* There must be an active flight booking for your ticket. It is not possible to rebook open tickets or tickets without the corresponding flight segments online at the moment.
	* It is currently only possible to rebook outbound (one-way) tickets or return tickets with single flight routes (point-to-point).
2. Which tickets/bookings cannot be rebooked online currently?
	* Bookings containing flight segments with other airlines
	* Bookings containing reservations, where a ticket has not yet been issued
	* Bookings wi

In [6]:
def get_relevant_docs(query: str) -> str:
    """Get relevant docs for the query asked."""
    docs = retriever.query(query, k=2)
    return [doc["page_content"] for doc in docs]


def answer_query(query: str) -> str:
    """Answer a query about the company's policies."""
    docs = get_relevant_docs(query)
    if (not docs) or (len(docs) == 0):
        return {
            "success": False,
            "content": "No relevant information found." 
        }
    
    # Combine the content of the top documents
    content = "\n\n".join([doc for doc in docs])
    
    response = openai.Client().chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"Context: {content} \n\n Question: {query}"}
        ]
    )
    
    return {
        "success": True,
        "content": response.choices[0].message.content.strip()
    }

In [7]:
answer_query("Should I reconfirm my flight?")

{'success': True,
 'content': 'No, reconfirmation of SWISS flights is not required.'}

In [20]:
FLIGHT_TABLE = [
    {
        "id": "F101",
        "name": "Swiss F101",
        "source": "Delhi, India",
        "destination": "London, UK (United Kingdom)",
        "duration": "8h 30m"
    },
    {
        "id": "F201",
        "name": "Swiss F-201",
        "source": "Mumbai, India",
        "destination": "San Francisco, USA (United States of America)",
        "duration": "16h 40m"
    }
]


@tool
def get_flight_details_by_name(name: str) -> list[dict]:
    """Get flight details for the input name.

    Args:
        name: The name of the Flight.
    """
    results = []
    for flight_data in FLIGHT_TABLE:
        if name.lower() in flight_data['name'].lower():
            results.append(flight_data)

    return results


@tool
def get_flight_details_by_source(source: str) -> list[dict]:
    """Get flight details for the source location.

    Args:
        source: The name of the Source location.
    """
    results = []
    for flight_data in FLIGHT_TABLE:
        if source.lower() in flight_data['source'].lower():
            results.append(flight_data)

    return results

@tool
def get_flight_details_by_destination(destination: str) -> list[dict]:
    """Get flight details for the destination location.

    Args:
        destination: The name of the Destination location.
    """
    results = []
    for flight_data in FLIGHT_TABLE:
        if destination.lower() in flight_data['destination'].lower():
            results.append(flight_data)

    return results

In [10]:
from datetime import datetime

HOTELS_TABLE = [
    {
        "id": "H101",
        "name": "Hyatt Regency",
        "location": "Delhi, India",
        "booked_till": datetime(2025, 6, 15, 0, 0, 0) # 15th June 2025
    },
    {
        "id": "H102",
        "name": "Taj Palace",
        "location": "Mumbai, India",
        "booked_till": datetime(2025, 6, 20, 0, 0, 0) # 20th June 2025
    },
]


def get_hotel_availability(location: str, check_in_date: str) -> list[dict]:
    """Get hotel availability for the input location and check-in date.

    Args:
        location: The name of the Location.
        check_in_date: The check-in date in YYYY-MM-DD format.
    """
    results = []
    check_in_date = datetime.strptime(check_in_date, "%Y-%m-%d")
    
    for hotel_data in HOTELS_TABLE:
        if location.lower() in hotel_data['location'].lower():
            if hotel_data['booked_till'] < check_in_date:
                results.append(hotel_data)

    return results


In [24]:
def interrogate_agent(query: str):
    """Interrogate the agent with a query."""
    model = ChatOpenAI(model="gpt-4-turbo-preview")
    tools = [
        get_flight_details_by_name, 
        get_flight_details_by_source,
        get_flight_details_by_destination,
        get_hotel_availability,
    ]

    model_with_tools = model.bind_tools(tools)
    result = model_with_tools.invoke(query)
    print(result.tool_calls)

    return result

# response = interrogate_agent("Can you give me the Flight information ending at USA and also about the Flight F101 ?")
response = interrogate_agent("Can you provide me Hotel details in Mumbai for 2025-06-25 ?")

[{'name': 'get_hotel_availability', 'args': {'check_in_date': '2025-06-25', 'location': 'Mumbai'}, 'id': 'call_pKBzMEonryB6fdmlcGwQQPpb', 'type': 'tool_call'}]


In [22]:
def handle_tool_calls(response):
    if (len(response.tool_calls) == 0):
        return []
    
    tool_call_responses = []
    for tool_call in response.tool_calls:
        if tool_call["name"] == "get_flight_details_by_name":
            arguments = tool_call["args"]
            flight_name = arguments['name']
            print(flight_name)
            print(get_flight_details_by_name(flight_name))
            tool_call_responses.extend(get_flight_details_by_name(flight_name))
        elif tool_call["name"] == "get_flight_details_by_source":
            arguments = tool_call["args"]
            source_location = arguments['source']
            tool_call_responses.extend(get_flight_details_by_source(source_location))
        elif tool_call["name"] == "get_flight_details_by_destination":
            arguments = tool_call["args"]
            destination_location = arguments['destination']
            tool_call_responses.extend(get_flight_details_by_destination(destination_location))
        elif tool_call["name"] == "get_hotel_availability":
            arguments = tool_call["args"]
            location = arguments['location']
            check_in_date = arguments['check_in_date']
            tool_call_responses.extend(get_hotel_availability(location, check_in_date))
    
    return tool_call_responses

In [25]:
print(handle_tool_calls(response))

[{'id': 'H102', 'name': 'Taj Palace', 'location': 'Mumbai, India', 'booked_till': datetime.datetime(2025, 6, 20, 0, 0)}]
