In [1]:
import os
import re
import json
from amadeus import Client, ResponseError
from langchain.chat_models import ChatOpenAI
from langchain.agents import initialize_agent, AgentType
from langchain.memory import ConversationBufferMemory
from langchain.tools import Tool
from langchain.schema import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
import warnings
warnings.filterwarnings("ignore")
from dotenv import load_dotenv


load_dotenv()


GEMINI_API_KEY = os.getenv("GOOGLE_API_KEY")
AMADEUS_API_KEY = os.getenv("AMADEUS_API_KEY")
AMADEUS_API_SECRET = os.getenv("AMADEUS_API_SECRET")

amadeus = Client(
    client_id=os.environ["AMADEUS_API_KEY"],
    client_secret=os.environ["AMADEUS_API_SECRET"]
)

llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
)

def format_duration(iso_duration):
    match = re.match(r'PT(?:(\d+)H)?(?:(\d+)M)?', iso_duration)
    hours = match.group(1) if match.group(1) else "0"
    minutes = match.group(2) if match.group(2) else "0"
    return f"{int(hours)} hours {int(minutes)} minutes"

# Function to get full airline names from codes using Gemini
def get_airline_full_name(airline_code):
    prompt = f"Please provide the full name for the airline only '{airline_code}'."
    response = llm([HumanMessage(content=prompt)])
    return response.content.strip() if response else airline_code

# Tool: Flight Search
def fetch_flights(origin, destination, departure_date, return_date=None, max_price=None, airline_name =None):
    try:
        # Set a high default max_price if not provided
        max_price = max_price if max_price else 20000
        params = {
            "originLocationCode": origin,
            "destinationLocationCode": destination,
            "departureDate": departure_date,
            "adults": 1,
            "maxPrice": max_price
        }

        if return_date:
            params["returnDate"] = return_date

        # Fetch flights from Amadeus API
        response = amadeus.shopping.flight_offers_search.get(**params)
        flights = response.data

        if flights:
            result = []
            for flight in flights[:5]:  # Limit to top 5 results
                if float(flight['price']['total']) <= max_price:
                    # Outbound flight details
                    segments = flight['itineraries'][0]['segments']
                    airline_code = segments[0]['carrierCode']
                    airline = get_airline_full_name(airline_code)  # Get full airline name
                    # Only add flights that match the specified airline, if provided
                    if airline_name and airline and airline.lower() not in airline_name.lower():
                        continue
                    departure_time = segments[0]['departure']['at']
                    arrival_time = segments[-1]['arrival']['at']
                    flight_duration = format_duration(flight['itineraries'][0]['duration'])

                    # Only include return details if a return date is provided
                    if return_date and len(flight['itineraries']) > 1:
                        return_segments = flight['itineraries'][1]['segments']
                        return_departure_time = return_segments[0]['departure']['at']
                        return_arrival_time = return_segments[-1]['arrival']['at']
                        return_duration = format_duration(flight['itineraries'][1]['duration'])
                        return_info = (
                            f"\nReturn Departure: {return_departure_time}\n"
                            f"Return Arrival: {return_arrival_time}\n"
                            f"Return Duration: {return_duration}\n"
                        )
                    else:
                        return_info = ""

                    # Append both outbound and return information (if available) to results
                    result.append(
                        f"Airline: {airline}\nPrice: ${flight['price']['total']}\n"
                        f"Departure: {departure_time}\nArrival: {arrival_time}\n"
                        f"Duration: {flight_duration}{return_info}"
                        "\n----------------------------------------"
                    )
            return "\n\n".join(result) if result else "No flights found within the budget."
        return "No flights found."
    except ResponseError as error:
        return f"An error occurred: {error.response.result}"

# Define the flight search tool
flight_search_tool = Tool(
    name="Flight Search",
    func=lambda input_str: fetch_flights(
        **json.loads(input_str)
    ),
    description="Find flights based on origin, destination, departure date, and return date."
)

# Initialize memory and agent with the flight search tool
memory = ConversationBufferMemory()
agent = initialize_agent(
    tools=[flight_search_tool],
    llm=llm,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    memory=memory,
    verbose=True
)

def parse_flight_details(query):
    prompt = (
        f"Extract the following details from this flight request:\n"
        f"- Origin location (IATA code if possible)\n"
        f"- Destination location (IATA code if possible)\n"
        f"- Departure date (MM-DD-YYYY)\n"
        f"- Return date (MM-DD-YYYY, if provided)\n\n"
        f"Request: '{query}'\n\n"
        f"Respond ONLY in this format:\n"
        f"Origin: <origin>\n"
        f"Destination: <destination>\n"
        f"Departure Date: <MM-DD-YYYY>\n"
        f"Return Date: <MM-DD-YYYY or None>\n"
        f"No extra text."
    )

    message = HumanMessage(content=prompt)
    response = llm([message])
    response_text = response.content.strip()

    # Regex to extract fields
    origin_match = re.search(r"Origin:\s*(.+)", response_text)
    dest_match = re.search(r"Destination:\s*(.+)", response_text)
    dep_match = re.search(r"Departure Date:\s*(.+)", response_text)
    ret_match = re.search(r"Return Date:\s*(.+)", response_text)

    origin = origin_match.group(1).strip() if origin_match else None
    destination = dest_match.group(1).strip() if dest_match else None
    departure_date = dep_match.group(1).strip() if dep_match else None
    return_date = ret_match.group(1).strip() if ret_match else None

    if not all([origin, destination, departure_date]):
        raise ValueError(f"Incomplete flight details extracted from the query: {response_text}")

    return origin, destination, departure_date, return_date


# Main function
def get_flight_recommendations():
    user_query = input("Enter prompt: ")

    try:
        origin, destination, departure_date, return_date = parse_flight_details(user_query)
    except ValueError as e:
        print(f"Error parsing flight details: {e}")
        return

    flight_details = {
        "origin": origin,
        "destination": destination,
        "departure_date": departure_date,
        "return_date": return_date
    }
    response = agent.run(input=json.dumps(flight_details)) # Pass the dictionary as a JSON string under the 'input' key
    print(response)
get_flight_recommendations()


For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  return _bootstrap._gcd_import(name[level:], package, level)
* 'allow_population_by_field_name' has been renamed to 'populate_by_name'




[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: Flight Search
Action Input: {"origin": "BOM", "destination": "DXB", "departure_date": "08-19-2025", "return_date": "08-30-2025"}[0m
Observation: [36;1m[1;3mAn error occurred: {'errors': [{'status': 400, 'code': 477, 'title': 'INVALID FORMAT', 'detail': 'departureDate format is YYYY-MM-DD', 'source': {'pointer': 'departureDate', 'example': '2030-12-31'}}, {'status': 400, 'code': 477, 'title': 'INVALID FORMAT', 'detail': 'returnDate format is YYYY-MM-DD', 'source': {'pointer': 'returnDate', 'example': '2030-12-31'}}]}[0m
Thought:[32;1m[1;3mAction: Flight Search
Action Input: {"origin": "BOM", "destination": "DXB", "departure_date": "2025-08-19", "return_date": "2025-08-30"}[0m
Observation: [36;1m[1;3mAirline: The full name for the airline with the IATA code 'WY' is **Oman Air**.
Price: $271.38
Departure: 2025-08-19T14:50:00
Arrival: 2025-08-19T21:25:00
Duration: 8 hours 5 minutes
Return Departure: 2025-08-30T06