# NUS-NCS Innovation Challenge (Final Round): GPTTransit Chatbot Application

**Team GPT (Green Path Traffic)**

This notebook is divided into 4 parts.

Overview of Content:
1) Setup & Define OpenAI / HuggingFace Models
2) Constructing Tools with Functions for Data Processing & Making API Calls
3) Chat Model, OpenAI Tool Agent & Memory
4) Gradio Chatbot User Interface

## Part 1: Setup & Define OpenAI / HuggingFace Models

In [1]:
# !pip install -U langchain
# !pip install -U langchain-community
# !pip install -U langchain-openai
# !pip install -U langchainhub
# !pip install -U huggingface_hub
# !pip install -U transformers
# !pip install -U text_generation
# !pip install -U openai
# !pip install -U tiktoken
# !pip install -U python-dotenv
# !pip install -U gradio  ## Guide to install Gradio: https://www.gradio.app/main/guides/installing-gradio-in-a-virtual-environment

In [2]:
# Import packages
import requests
import re
import zipfile  # needed by getLinkCSV to extract the LTA DataMall zip files
import pandas as pd
import time
from datetime import datetime, timedelta
from collections import Counter
from math import radians, sin, cos, sqrt, atan2

import openai
import langchain
from langchain_community.llms import HuggingFaceEndpoint
from langchain_openai import ChatOpenAI 
from langchain.chains import LLMChain  
from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain.agents import tool, Tool, load_tools, create_openai_tools_agent, AgentExecutor 
from langchain.tools import StructuredTool
from langchain.memory import ChatMessageHistory, ConversationSummaryMemory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables import RunnablePassthrough
from langchain_core.messages import HumanMessage, AIMessage

import gradio as gr

In [None]:
import os
from dotenv import load_dotenv, find_dotenv

# Initialise API keys: OpenAI, HuggingFace, LTA DataMall, OneMap
_ = load_dotenv(find_dotenv()) 

# read API keys from .env file
openai.api_key  = os.environ["OPENAI_API_KEY"]  
huggingfacehub_api_token = os.environ["HUGGINGFACEHUB_API_TOKEN"] 

LTA_API_KEY = os.environ["LTA_API_KEY"] 
ONEMAP_API_KEY = os.environ["ONEMAP_API_KEY"] 
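`os.environ[...]` raises a bare `KeyError` as soon as one key is missing from `.env`. A minimal sketch of a friendlier check that reports every missing (or empty) variable at once; `require_env` is a hypothetical helper, not part of the original notebook, and the names in the usage comment are the four keys read above:

```python
import os


def require_env(names):
    """Return {name: value} for each variable; raise if any is missing or empty.

    Hypothetical helper: collects every missing name into one RuntimeError
    instead of failing on the first KeyError.
    """
    missing = [name for name in names if not os.environ.get(name)]
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
    return {name: os.environ[name] for name in names}


# Example:
# keys = require_env(["OPENAI_API_KEY", "HUGGINGFACEHUB_API_TOKEN", "LTA_API_KEY", "ONEMAP_API_KEY"])
# LTA_API_KEY = keys["LTA_API_KEY"]
```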

In [None]:
# Define language models
# OpenAI's GPT-4 Chat model
chatllm = ChatOpenAI(model_name="gpt-4-turbo", temperature=0) 

# Mistral AI's Mixtral-8x7B Instruct model (from HuggingFace), used to summarise the chat memory
llm = HuggingFaceEndpoint(repo_id='mistralai/Mixtral-8x7B-Instruct-v0.1', temperature=0.01, huggingfacehub_api_token=huggingfacehub_api_token)

# Temperature is set to 0 (0.01 for the HuggingFace endpoint) to make outputs as deterministic and predictable as possible

## Part 2: Constructing Tools with Functions for Data Processing & Making API Calls

Tools are interfaces through which the agent calls functions: they let it take actions and gather contextual information through external API calls. <br>An agent uses a language model as a reasoning engine to decide which sequence of actions (tool calls) to take; the agent itself is defined in Part 3.
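Conceptually, each tool pairs a name, a natural-language description (for the `@tool` functions in this notebook, the function's docstring) and a callable; the language model reads the descriptions, picks a tool and supplies a string input. The toy registry below illustrates that contract in plain Python. It is not how LangChain implements tools internally, and `ToolSpec`, `register_tool`, `run_tool` and `shout` are hypothetical names:

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class ToolSpec:
    """A name, a description for the language model, and the function to call."""
    name: str
    description: str
    func: Callable[[str], str]


TOOLS: Dict[str, ToolSpec] = {}


def register_tool(func: Callable[[str], str]) -> Callable[[str], str]:
    """Register a function as a tool, using its docstring as the description."""
    TOOLS[func.__name__] = ToolSpec(func.__name__, (func.__doc__ or "").strip(), func)
    return func


def run_tool(name: str, tool_input: str) -> str:
    """Dispatch a tool call chosen by the model; unknown names return an error string."""
    spec = TOOLS.get(name)
    if spec is None:
        return f"Unknown tool '{name}'. Available tools: {', '.join(sorted(TOOLS))}"
    return spec.func(tool_input)


@register_tool
def shout(text: str) -> str:
    """Return the input in upper case (a stand-in for a real API call)."""
    return text.upper()


# The descriptions are what the model would see when choosing a tool
for spec in TOOLS.values():
    print(f"{spec.name}: {spec.description}")
print(run_tool("shout", "bishan"))  # BISHAN
```

Returning an error string rather than raising mirrors how the tools in this notebook report failures back to the agent.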

### (I) Get Train / Walking Route

- Description: Get the train / walking route between two MRT stations
- Function: get_public_transport_route_concise(station)
- Input: *string* "start_station,end_station"
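The tool receives both stations as one comma-separated string and normalises each name by upper-casing it and removing all spaces before looking it up, which implies that the `station_name` column of `mrtlrt_gps.csv` stores names in that same form. A standalone sketch of that parsing step; `parse_station_pair` is a hypothetical name, and unlike the original it splits on the first comma only and reports malformed input explicitly:

```python
from typing import Tuple


def parse_station_pair(tool_input: str) -> Tuple[str, str]:
    """Split 'start_station,end_station' and normalise both names."""
    if "," not in tool_input:
        raise ValueError("Expected input in the form 'start_station,end_station'")
    start, end = tool_input.split(",", 1)
    # Same normalisation as the tool: upper case, all spaces removed
    start, end = (s.upper().replace(" ", "") for s in (start, end))
    if not start or not end:
        raise ValueError("Both a start and an end station are required")
    return start, end


print(parse_station_pair("Ang Mo Kio, bishan"))  # ('ANGMOKIO', 'BISHAN')
```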

In [5]:
# Define function to get train route

@tool
def get_public_transport_route_concise(station):
    """Use this to find the journey time, cost and list of train stations from one MRT train station to another MRT train station. 
    Also, use this to find walking routes and number of steps between MRT train stations.
    You can also use this to find all MRT stations in between 2 stations, and check if you will pass by a certain station.
    To use the tool, you must provide two parameters in the following format: 'start_station,end_station'. Do not include 'mrt' in the input."""
    try:
        # Read the CSV file into a DataFrame
        df = pd.read_csv('mrtlrt_gps.csv')
        start_station, end_station = str(station).split(",")
        start_station = start_station.upper().replace(" ", "") # upper case with all spaces removed, to match station_name in the CSV
        end_station = end_station.upper().replace(" ", "") # upper case with all spaces removed, to match station_name in the CSV

        # Get starting station details
        station_start = df[df['station_name'] == start_station]
        if station_start.empty:
            start_not_found = f"Starting station '{start_station}' not found, do ensure that the spelling is correct."
            print(start_not_found)
            return start_not_found
        lat_start = station_start['lat'].values[0]
        lng_start = station_start['lng'].values[0]

        # Get destination station details
        station_end = df[df['station_name'] == end_station]
        if station_end.empty:
            dest_not_found = f"Destination station '{end_station}' not found, do ensure that the spelling is correct."
            print(dest_not_found)
            return dest_not_found
        lat_end = station_end['lat'].values[0]
        lng_end = station_end['lng'].values[0]
        code_end = station_end['station_code'].values[0] # get destination station code for walk as last leg
        name_end = station_end['full_name'].values[0] # get destination station name for walk as last leg

        # Define parameters for API request
        routeType = "pt"
        date = datetime.today().strftime("%m-%d-%Y")
        hour = '{:02d}'.format(datetime.now().hour)
        minute = '{:02d}'.format(datetime.now().minute)
        mode = "RAIL"
        maxWalkDistance = 100
        numItineraries = 3

        # Construct URL for API request
        url = f"https://www.onemap.gov.sg/api/public/routingsvc/route?start={lat_start}%2C{lng_start}&end={lat_end}%2C{lng_end}&routeType={routeType}&date={date}&time={hour}%3A{minute}%3A00&mode={mode}&maxWalkDistance={maxWalkDistance}&numItineraries={numItineraries}"

        # Define headers for API request (Access code valid until 23 Mar)
        headers = {
            "Authorization": ONEMAP_API_KEY
        }

        # Make API request
        response = requests.request("GET", url, headers=headers)
        
        # Handle case where API request fails
        if not response.ok:
            response_fail = "The API request failed, please try again later."
            print(response_fail)
            return response_fail
        
        # Extract route information from API response
        routes_str = ""
        api_response = response.json()
        route_count = 1
        api_response_itineraries = api_response['plan']['itineraries']
        walk_legs = 0

        for itinerary in api_response_itineraries:
            for leg in itinerary['legs']:
                if leg['mode'] == 'WALK' and leg['from']['name'] == 'Origin' and leg['to']['name'] == 'Destination':
                    walk_legs += 1

        routes_count = len(api_response_itineraries) - walk_legs
        routes_str += f"There are {routes_count} possible travel route(s).\n"

        for itinerary in api_response_itineraries: # pull out all the itineraries or routes from the api
            route_string = ""
            duration = itinerary['duration'] / 60  # calculate overall duration of the route
            fare = itinerary['fare']  # extract total fare of the route
            prev_station_name = None
            first_leg = itinerary['legs'][0]  # Extracting the first leg
            last_leg = itinerary['legs'][-1]  # Extracting the last leg
            transit_distance = 0

            if first_leg['mode'] == 'WALK' and \
                first_leg['from']['name'] == 'Origin' and \
                first_leg['to']['name'].replace(' MRT STATION', '').replace(' ', '') != start_station and \
                first_leg['to']['name'] != 'Destination':

                route_string += f"Walk {round(first_leg['distance'],0)} metres or {round(first_leg['distance']/0.75,0)} steps to {first_leg['to']['stopCode']} {first_leg['to']['name'].replace(' MRT STATION', '')} "

            for leg_index, leg in enumerate(itinerary['legs']):  # within each itinerary, look for each leg of the route
                
                if leg_index > 0 and leg_index < len(itinerary['legs']) - 1 and \
                    leg['mode'] == 'WALK' and \
                    leg['from']['name'] != leg['to']['name']: # to include "walk leg from one station to another" in the route string
                    route_string += f" then walk {round(leg['distance'],0)} metres or {round(leg['distance']/0.75,0)} steps"  

                if leg_index > 0 and leg_index < len(itinerary['legs']) - 1 and \
                    leg['mode'] == 'WALK' and \
                    leg['from']['name'] == leg['to']['name']: # to include "walk leg for transit" 
                    transit_distance = leg['distance']  

                if leg_index > 0 and leg_index < len(itinerary['legs']) - 1 and \
                    leg['mode'] == 'SUBWAY' and \
                    leg['from']['name'] == leg['to']['name']: # to include "walk leg for transit cross platform only" 
                    transit_distance = 0
                    
                if leg['mode'] == 'SUBWAY':  # if the leg is a subway route

                    if route_string.startswith("Walk"):  # Check if route_string starts with "Walk"
                        route_string += "then take train from "
                    elif route_string:
                        route_string += " to "
                    current_station_name = leg['from']['name'].replace(' MRT STATION', '') # remove excess words

                    if current_station_name == prev_station_name and transit_distance !=0:
                        route_string += f"transit by walking {round(transit_distance,0)} metres or {round(transit_distance/0.75,0)} steps to "  # include the walking to transit station
                    elif current_station_name == prev_station_name and transit_distance ==0:
                        route_string += "transit by crossing the platform (10 metres, 13 steps) to "  # cross-platform transfer, assumed to be about 10 metres
                                        
                    route_string += f"{leg['from']['stopCode']} {current_station_name}"
                    prev_station_name = leg['to']['name'].replace(' MRT STATION', '') # remove excess words
                    route_string += f" to {leg['to']['stopCode']} {prev_station_name}"

            # Walk-only itineraries leave route_string empty and are skipped (they are excluded from routes_count above)
            if route_string:
                # Add a final walking leg if the last leg walks from a station other than the destination station
                if last_leg['mode'] == 'WALK' and \
                    last_leg['to']['name'] == 'Destination' and \
                    last_leg['from']['name'].replace(' MRT STATION', '').replace(' ', '') != end_station and \
                    last_leg['from']['name'] != 'Origin':
                    routes_str += f"Route {route_count}: {route_string} then walk {round(last_leg['distance'],0)} metres or {round(last_leg['distance']/0.75,0)} steps to {code_end} {name_end} with an estimated duration of {round(duration, 0)} minutes and cost ${fare}.\n"
                else:
                    routes_str += f"Route {route_count}: {route_string} with an estimated duration of {round(duration, 0)} minutes and cost ${fare}.\n"
                
                # Only count routes that were actually listed, so the numbering has no gaps
                route_count += 1
                                  
        return routes_str

    except Exception as e:
        exception_msg = f"An error occurred: {str(e)}"
        print(exception_msg)
        return exception_msg
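The route text converts every walking distance into steps by dividing by 0.75, i.e. it assumes a stride of about 0.75 m per step. Because `round(x, 0)` returns a float, the output reads like "120.0 metres". A small sketch of a formatter that keeps the same 0.75 m assumption but prints whole numbers; `describe_walk` and `STRIDE_M` are hypothetical names:

```python
STRIDE_M = 0.75  # metres per step, the same assumption as the tool above


def describe_walk(distance_m: float, stride_m: float = STRIDE_M) -> str:
    """Describe a walking leg in whole metres and an estimated number of steps."""
    metres = round(distance_m)
    steps = round(distance_m / stride_m)
    return f"{metres} metres or {steps} steps"


print(describe_walk(10))     # 10 metres or 13 steps
print(describe_walk(412.6))  # 413 metres or 550 steps
```

The first call reproduces the "10 metres, 13 steps" figure used for cross-platform transfers.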
        

### (II) Get GPS

- Description: Get the GPS coordinates (latitude and longitude) and address of an input location
- Function: getGPS(query)
- Input: *string* any postal code, address or road name
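A OneMap search value can contain characters such as `#` or `&` that must be percent-encoded before they go into a query string (a raw `#` would otherwise start a URL fragment and cut the query short). A stdlib sketch of building the search URL with `urllib.parse.urlencode`; the endpoint and parameter names are copied from `getGPS`, while `build_search_url` and the sample address are hypothetical:

```python
from urllib.parse import quote, urlencode

SEARCH_ENDPOINT = "https://www.onemap.gov.sg/api/common/elastic/search"


def build_search_url(query: str, page: int = 1) -> str:
    """Build a OneMap search URL with the search value safely percent-encoded."""
    params = {"searchVal": query, "returnGeom": "Y", "getAddrDetails": "Y", "pageNum": page}
    # quote_via=quote encodes spaces as %20 (the default quote_plus would use '+')
    return f"{SEARCH_ENDPOINT}?{urlencode(params, quote_via=quote)}"


print(build_search_url("123 Example Road #01-01"))
```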

In [6]:
# Define function to get GPS coordinates

@tool
def getGPS(query):
    """
    Use this to get the latitude and longitude coordinates. You must provide a location or postal code or address or road name.
    """
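    url = "https://www.onemap.gov.sg/api/common/elastic/search"
    # Pass the query through `params` so that requests percent-encodes it (addresses may contain spaces, '#', '&', etc.)
    params = {"searchVal": query, "returnGeom": "Y", "getAddrDetails": "Y", "pageNum": 1}
    headers = {"Content-Type": "application/json"}
    response = requests.get(url, headers=headers, params=params)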
    if response.status_code == 200:
        api_output = response.json()
        if api_output["found"] > 0:
            result = api_output["results"][0]  # Extract the first result
            latitude = float(result["LATITUDE"])
            longitude = float(result["LONGITUDE"])
            address = result["ADDRESS"]
            output_msg = f"Latitude: {latitude}\nLongitude: {longitude}\nAddress: {address}"
            return output_msg
        else:
            output_msg = "No results were found. Ensure that the input is a valid location, postal code, address or road name."
            return output_msg
    else:
        error_code = str(response.status_code)
        error_msg = f"Error: {error_code}. Please try again later."
        return error_msg
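`getGPS` keeps only the first search result. A sketch of the same extraction as a pure function, which makes it testable without calling the API; the field names (`found`, `results`, `LATITUDE`, `LONGITUDE`, `ADDRESS`) are those read by `getGPS`, while `first_result` and the sample payload are made up for illustration:

```python
from typing import Optional, Tuple


def first_result(api_output: dict) -> Optional[Tuple[float, float, str]]:
    """Return (latitude, longitude, address) of the first search result, or None."""
    if api_output.get("found", 0) == 0 or not api_output.get("results"):
        return None
    result = api_output["results"][0]
    return float(result["LATITUDE"]), float(result["LONGITUDE"]), result["ADDRESS"]


# Hypothetical payload shaped like the fields getGPS reads (values are illustrative only)
sample = {"found": 1, "results": [{"LATITUDE": "1.3000", "LONGITUDE": "103.8000", "ADDRESS": "123 EXAMPLE ROAD"}]}
print(first_result(sample))                       # (1.3, 103.8, '123 EXAMPLE ROAD')
print(first_result({"found": 0, "results": []}))  # None
```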

### (III) Weather Forecast

- Description: Get the nationwide weather forecast for the next 2 and 24 hours.
- Function: get_2h_24h_weather_forecast(passthrough)
- Input: *not required* Accepts any string; the value is ignored (an input argument is needed for LangChain agents to run the tool).

In [7]:
# Function to get the nationwide aggregated weather forecast for the next 2 hours
def get_2h_weather_forecast():
    try:
        url = "https://api.data.gov.sg/v1/environment/2-hour-weather-forecast"
        response = requests.get(url)

        # Check if the request was successful (status code 200)
        if response.status_code == 200:
            # Parse JSON response
            data = response.json()

            # Extract forecasts
            forecasts = data["items"][0]["forecasts"]

            # Count the occurrences of each forecast
            forecast_counter = Counter([forecast["forecast"] for forecast in forecasts])

            # Get the most common forecast
            forecast_2hrs = forecast_counter.most_common(1)[0][0]
            return forecast_2hrs
        else:
            return "Error: Unable to fetch 2-hour weather forecast data from the API."
    except Exception as e:
        return f"Error: {e}"

# Function to get the nationwide weather forecast for the next 24 hours
def get_24h_weather_forecast():
    try:
        url = "https://api.data.gov.sg/v1/environment/24-hour-weather-forecast"
        response24 = requests.get(url)

        # Check if the request was successful (status code 200)
        if response24.status_code == 200:
            # Parse JSON response
            data24 = response24.json()

            # Extract forecast from the "general" key
            forecast_24hrs = data24["items"][0]["general"]["forecast"]
            return forecast_24hrs
        else:
            return "Error: Unable to fetch 24-hour weather forecast data from the API."
    except Exception as e:
        return f"Error: {e}"


# Function to get the nationwide 2 hour & 24 hour weather forecast
# Takes in any string input - required for agents to work, but will not be used by the function.
@tool
def get_2h_24h_weather_forecast(passthrough):
    """
    Use this to get the weather forecast for the next 2 and 24 hours. Provide any string as input.
    """
    forecast_2h = get_2h_weather_forecast()
    forecast_24h = get_24h_weather_forecast()

    # Check for errors. If yes, highlight that the forecast is not available.
    if "Error" in forecast_2h:
        forecast_2h = f"Not available due to '{forecast_2h}'"
    if "Error" in forecast_24h:
        forecast_24h = f"Not available due to '{forecast_24h}'"
        
    # Construct weather forecast message 
    forecast_message = f"2-Hour weather forecast: {forecast_2h}.\n24-Hour weather forecast: {forecast_24h}."

    return forecast_message
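The 2-hour forecast from data.gov.sg is reported per area, and `get_2h_weather_forecast` condenses it into one nationwide value by taking the most frequent forecast with `collections.Counter`. A standalone sketch of that majority vote; `majority_forecast` is a hypothetical name and the area forecasts below are made-up sample data:

```python
from collections import Counter


def majority_forecast(forecasts):
    """Return the most common forecast text among per-area forecasts."""
    counts = Counter(f["forecast"] for f in forecasts)
    # most_common(1) returns [(forecast, count)]; ties go to the forecast seen first
    return counts.most_common(1)[0][0]


sample = [
    {"area": "Area A", "forecast": "Cloudy"},
    {"area": "Area B", "forecast": "Light Rain"},
    {"area": "Area C", "forecast": "Cloudy"},
]
print(majority_forecast(sample))  # Cloudy
```

A majority vote hides local conditions (for example, rain in only a few areas), which is the trade-off of reporting a single nationwide value.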

### (IV) Supporting Functions for LTA DataMall Tools

Common helper functions used inside the larger tool functions. They are not exposed to the Agent as tools.
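The helpers `cleanTimePrompt` and `splitDateTime` below parse tool input of the form `'text; date,time'`, where the date is `DD-MM-YYYY`, the time is `HH:MM` or `HHMM`, and a missing part defaults to now. The day is then classified as `WEEKDAY` or `WEEKENDS/HOLIDAY`, with only Saturday and Sunday treated as non-weekdays. A simplified standalone sketch of the date/time part, with the reference time passed in so the result is deterministic; `parse_when` is a hypothetical helper and returns one combined `datetime` rather than `splitDateTime`'s separate time value:

```python
from datetime import datetime
from typing import Optional, Tuple


def parse_when(date_time: str, now: Optional[datetime] = None) -> Tuple[str, datetime]:
    """Parse 'DD-MM-YYYY,HHMM' (either part optional) into (day_type, datetime)."""
    now = now or datetime.now()
    text = date_time.strip()
    if "," in text:
        date_str, time_str = text.split(",", 1)
    elif len(text) <= 5:
        date_str, time_str = "", text  # a bare time such as '0830' or '08:30'
    else:
        date_str, time_str = text, ""  # a bare date such as '23-03-2024'
    date_str, time_str = date_str.strip(), time_str.strip()

    # Missing parts default to the reference time
    day = datetime.strptime(date_str, "%d-%m-%Y").date() if date_str else now.date()
    if len(time_str) == 4 and time_str.isdigit():  # 'HHMM' -> 'HH:MM'
        time_str = f"{time_str[:2]}:{time_str[2:]}"
    if time_str:
        clock = datetime.strptime(time_str, "%H:%M").time()
    else:
        clock = now.time().replace(second=0, microsecond=0)

    # Only Saturday/Sunday are detected; public holidays would need a holiday calendar
    day_type = "WEEKDAY" if day.weekday() < 5 else "WEEKENDS/HOLIDAY"
    return day_type, datetime.combine(day, clock)


ref = datetime(2024, 3, 18, 9, 5)  # a Monday, used as "now" so the output is deterministic
print(parse_when("23-03-2024,0830", now=ref))  # ('WEEKENDS/HOLIDAY', datetime.datetime(2024, 3, 23, 8, 30))
print(parse_when("0830", now=ref))             # ('WEEKDAY', datetime.datetime(2024, 3, 18, 8, 30))
```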

In [8]:
def cleanTimePrompt(prompt):
    # Segregate 'start_stn, end_stn; time'
    parts = prompt.split(';')
    text = parts[0].strip()
    
    # Check time info
    if (len(parts)) > 1 and parts[1].strip() != "":
        date_time = parts[1].strip()
    else:
        date_time = ","
    return text, date_time

def splitDateTime(prompt):
    if ',' in prompt:
        date_str, time_str = prompt.split(',', 1)
        # An empty date defaults to today and an empty time to the current time
        if not date_str.strip():
            date = datetime.now().strftime("%d-%m-%Y")
        else:
            date = date_str.strip()
        
        if not time_str.strip():
            time = datetime.now().strftime("%H:%M")
        else:
            # Check if time is in 'HHMM' format
            time_input = time_str.strip()
            if len(time_input) == 4 and time_input.isdigit():
                time = time_input[:2] + ":" + time_input[2:]
            else:
                time = time_input
    else:
        if len(prompt) == 0:
            date = datetime.now().strftime("%d-%m-%Y")
            time = datetime.now().strftime("%H:%M")
        else:
            if len(prompt) <= 5:
                date = datetime.now().strftime("%d-%m-%Y")
                if not prompt.strip():
                    time = datetime.now().strftime("%H:%M")
                else:
                    # Check if time is in 'HHMM' format
                    time_input = prompt.strip()
                    if len(time_input) == 4 and time_input.isdigit():
                        time = time_input[:2] + ":" + time_input[2:]
                    else:
                        time = time_input
            else:
                date = prompt
                time = datetime.now().strftime("%H:%M")
    
    # Check weekday/weekend
    if datetime.strptime(date, "%d-%m-%Y").weekday() < 5:
        day = 'WEEKDAY'
    else:
        day = 'WEEKENDS/HOLIDAY'

    time = datetime.strptime(time, "%H:%M")
    return day, time

def cleanPrompt(prompt):
    # Collect every station code (e.g. NS1, EW12) in the prompt, together with its line (e.g. NSL, EWL).
    # The whole prompt is scanned: route sentences can contain decimal points (e.g. "120.0 metres"),
    # so a sentence-level pattern such as r'Route .*?\.' would cut a route short.
    pattern = r'\b(?:[A-Z]{2}\d{2}|[A-Z]{2}\d{1})\b'
    codes = re.findall(pattern, prompt)
    lines = [station[:2] + 'L' for station in codes]
    stn_df = pd.DataFrame({
        'stn_codes': codes,
        'stn_lines': lines
    })
    
    # Origin and destination are the first and last station codes mentioned in the prompt
    origin = codes[0]
    destination = codes[-1]
    
    stn_df = stn_df.drop_duplicates(subset='stn_codes')
    return stn_df, origin, destination

def cleanHighCrowdPrompt(prompt, data_df):
    # Find the sentences that report a high crowd level
    highcrowd = []
    highcrowd_pattern = r'CROWD LEVEL HIGH .*?\.'
    highcrowd_prompt = re.findall(highcrowd_pattern, prompt)
    
    if highcrowd_prompt is None:
        return None
    else:
        highcrowd.append(highcrowd_prompt)
    
    # Collect the station codes found in the prompt
    stn_df = pd.DataFrame(columns=['stn_codes'])
    pattern = r'\b(?:[A-Z]{2}\d{2}|[A-Z]{2}\d{1})\b'
    codes = re.findall(pattern, prompt)
    highcrowd_stn = pd.DataFrame({
        'stn_codes': codes,
    })
    data_df = pd.concat([data_df, highcrowd_stn], ignore_index=True)
            
    data_df = data_df.drop_duplicates(subset='stn_codes')
    return data_df

def cleanAlertPrompt(prompt):
    # Extract the station codes (e.g. NS1) and their lines (e.g. NSL) from the prompt
    stn_df = pd.DataFrame()
    pattern = r'\b(?:[A-Z]{2}\d{2}|[A-Z]{2}\d{1})\b'
    codes = re.findall(pattern, prompt)
    lines = [station[:2] + 'L' for station in codes]
    route_stn = pd.DataFrame({
        'stn_codes': codes,
        'stn_lines': lines
    })
    stn_df = pd.concat([stn_df, route_stn], ignore_index=True)
    stn_df = stn_df.drop_duplicates(subset='stn_codes')
    return stn_df

def getCrowdRequest(url_type, stn_line):
    payload={}
    headers = {
      'AccountKey': LTA_API_KEY
    }
    url = f"http://datamall2.mytransport.sg/ltaodataservice/PCD{url_type}?TrainLine={stn_line}"
    response = requests.request("GET", url, headers=headers, data=payload)
    
    # Check for response
    if response.status_code == 200:
        return response.json()
    else:
        # print("Failed to fetch data, please try again in a short while.")
        return "Error: Failed to fetch data, please try again in a short while."

def getDataRequest(url_type):
    # URLs
    vol_by_stn_url = "http://datamall2.mytransport.sg/ltaodataservice/PV/Train"
    vol_to_fro_url = "http://datamall2.mytransport.sg/ltaodataservice/PV/ODTrain?$skip=500"
    alert_url = "http://datamall2.mytransport.sg/ltaodataservice/TrainServiceAlerts"

    if url_type == "vol_by_stn":
        url = vol_by_stn_url
    elif url_type == 'vol_to_fro':
        url = vol_to_fro_url
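    elif url_type == 'alert':
        url = alert_url
    else:
        raise ValueError(f"Unknown url_type '{url_type}'; expected 'vol_by_stn', 'vol_to_fro' or 'alert'")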
    
    payload={}
    headers = {
      'AccountKey': LTA_API_KEY
    }
    response = requests.request("GET", url, headers=headers, data=payload)
    
    # Check for response
    if response.status_code == 200:
        return response.json()
    else:
        print("Failed to fetch data, please try again in a short while.")
        return None

def getLinkCSV(url_type):
    # Pull Info
    response = getDataRequest(url_type)
    if response is not None:
        data_href = pd.DataFrame(response['value'])
        link = data_href['Link'].iloc[0]
        link_response = requests.get(link)
        with open('data.zip', 'wb') as f:
            f.write(link_response.content)
        with zipfile.ZipFile('data.zip', 'r') as zip_ref:
            zip_ref.extractall()
        match = re.search(r'/([\w_]+)\.zip', link)
        if match:
            filename = match.group(1)
        csv_data = pd.read_csv(filename + '.csv')
        csv_df = cleanCSV(csv_data, None)
        return csv_df
    else:
        return None

def cleanCSV(df, where):
    if where is None:
        col_name = 'PT_CODE'
    else:
        col_name = where + '_PT_CODE'
    modified_rows = []
    for index, row in df.iterrows():
        if '/' in row['PT_CODE']:
            codes = row['PT_CODE'].split('/')

            for code in codes:
                new_row = row.copy()
                new_row['PT_CODE'] = code
                modified_rows.append(new_row)
        else:
            modified_rows.append(row)

    modified_df = pd.DataFrame(modified_rows)
    modified_df.rename(columns={'PT_CODE': col_name}, inplace=True)
    return modified_df

def getTime(time_delta):
    now = datetime.now() + timedelta(minutes=time_delta)
    # Round down to the start of the current 30-minute interval (e.g. 14:47 -> 14:30)
    rounded_minutes = (now.minute // 30) * 30
    return now.replace(minute=rounded_minutes)

def getStnName(df, where):
    input_csv = pd.read_csv('mrtlrt_gps.csv')
    if where is None:
        col_name = 'Station'
    else:
        col_name = where + '_Station'
    for index, row in df.iterrows():
        match = input_csv[input_csv['station_code'] == row['Station']]
        if not match.empty:
            stn_code_name = row['Station'] + ' ' + match['full_name'].values[0]
            df.at[index, 'Station'] = stn_code_name
    df.rename(columns={'Station': col_name}, inplace=True)
    return df

def cleanForecastCrowd(jsonObj):
    # Clean json data
    platform_data = []
    for entry in jsonObj['value']:
        date = entry['Date']
        for station in entry['Stations']:
            station_name = station['Station']
            for interval in station['Interval']:
                interval_data = {
                    'Date': date,
                    'Station': station_name,
                    'Start': interval['Start'],
                    'CrowdLevel': interval['CrowdLevel']
                }
                platform_data.append(interval_data)
    df = pd.DataFrame(platform_data)
    df['Start'] = pd.to_datetime(df['Start']).dt.time
    df.drop(columns=["Date"], inplace=True)
    return df

def cleanRealTimeCrowd(jsonObj):
    # Clean json data
    df = pd.DataFrame(jsonObj['value'])
    df['StartTime'] = pd.to_datetime(df['StartTime']).dt.time
    df.drop(columns=["EndTime"], inplace=True)
    return df

def cleanTimeCrowd(data_df, prompt_df, date_time):
    # Filter for station
    crowd = []
    day, time = splitDateTime(date_time)
    for index, row in data_df.iterrows():
        if row['Station'] in prompt_df['stn_codes'].values:
            if (time-timedelta(minutes=30)).time() <= row['Start'] <= (time+timedelta(minutes=30)).time():
                crowd.append(row)
    
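    # No matching rows for these stations and times: nothing to report (as in cleanCrowd)
    if not crowd:
        return None
    
    # Clean output
    crowd_df = pd.DataFrame(crowd)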
    crowd_df.reset_index(drop=True, inplace=True)
    crowd_df.drop(columns=["Start"], inplace=True)
    crowd_df.drop_duplicates(subset=['Station', 'CrowdLevel'], inplace=True)
    check_crowd_df = pd.DataFrame(columns=['Station', 'CrowdLevel'])
    for station in crowd_df['Station'].unique():
        station_df = crowd_df[crowd_df['Station'] == station]
        crowd_levels = []

        # Arrange 'CrowdLevel' values by the order 'l', 'm', 'h'
        for level in ['l', 'm', 'h']:
            # Append values for current 'CrowdLevel' if present
            if level in station_df['CrowdLevel'].values:
                crowd_levels.append(level)
        current_crowd = pd.DataFrame({'Station': [station],
                                      'CrowdLevel': [crowd_levels]})
        check_crowd_df = pd.concat([check_crowd_df, current_crowd])
    check_crowd_df['CrowdLevel'] = check_crowd_df['CrowdLevel'].apply(replaceCrowdLevels)
    check_crowd_df.reset_index(drop=True, inplace=True)
    getStnName(check_crowd_df, None)
    crowd_string = summarizeCrowd(check_crowd_df)   
    return crowd_string

def replaceCrowdLevels(levels):
    if levels == ['l']:
        return 'CROWD LEVEL LOW'
    elif levels == ['m']:
        return 'CROWD LEVEL MODERATE'
    elif levels == ['h']:
        return 'CROWD LEVEL HIGH'
    elif levels == ['l', 'm']:
        return 'CROWD LEVEL LOW TO MODERATE'
    elif levels == ['m', 'h']:
        return 'CROWD LEVEL MODERATE TO HIGH'
    elif levels == ['l', 'm', 'h'] or levels == ['l', 'h']:
        return 'CROWD LEVEL LOW TO HIGH'
    else:
        return levels

def cleanCrowd(data_df, prompt_df):
    # Filter for the stations listed in prompt_df
    crowd = []
    for index, row in data_df.iterrows():
        # Check if the station code from data_df is in prompt_df's 'stn_codes' column
        if row['Station'] in prompt_df['stn_codes'].values:
            # Keep the matching row
            crowd.append(row)
    if not crowd:
        return None
    else:
        # Clean output: map crowd level codes to readable labels
        crowd_df = pd.DataFrame(crowd)
        crowd_df.reset_index(drop=True, inplace=True)
        crowd_df['CrowdLevel'] = crowd_df['CrowdLevel'].replace({
            'l': 'CROWD LEVEL LOW',
            'm': 'CROWD LEVEL MODERATE',
            'h': 'CROWD LEVEL HIGH'})
        crowd_df.drop(columns=["StartTime"], inplace=True)
        getStnName(crowd_df, None)
        crowd_string = summarizeCrowd(crowd_df)   
        return crowd_string

def cleanVolume(data_df, prompt_df, date_time):
    # Clean date
    day, time = splitDateTime(date_time)
    time_hr = time.hour
    
    selected_list = []
    passenger_upper_threshold = 95000   # above this: CROWD LEVEL HIGH
    passenger_lower_threshold = 15000   # below this: CROWD LEVEL LOW (in between: MODERATE)
    
    # Filtering through DataFrame
    filtered_df = pd.DataFrame()
    for stn_code in prompt_df['stn_codes'].unique():
        filtered_df = pd.concat([filtered_df, data_df[data_df['PT_CODE'] == stn_code]])
    for index, row in filtered_df.iterrows():
        if row['DAY_TYPE'] == day \
        and row['TIME_PER_HOUR'] == time_hr:
            selected_list.append(row)
    
    if not selected_list:
        return None
    else:
        # Check crowd volume per station (tap-ins + tap-outs for the selected day type and hour)
        selected_list_df = pd.DataFrame(selected_list)
        selected_list_df = selected_list_df.reset_index(drop=True)
        passenger_volume = selected_list_df['TOTAL_TAP_IN_VOLUME'] + selected_list_df['TOTAL_TAP_OUT_VOLUME']
        selected_list_df.insert(1, 'CrowdLevel', passenger_volume.apply(
                                lambda volume: 'CROWD LEVEL HIGH' if volume > passenger_upper_threshold
                                else ('CROWD LEVEL LOW' if volume < passenger_lower_threshold
                                else 'CROWD LEVEL MODERATE')))
        selected_list_df = selected_list_df.loc[:, ['PT_CODE', 'CrowdLevel']]
        selected_list_df.rename(columns={'PT_CODE': 'Station'}, inplace=True)
        getStnName(selected_list_df, None)
        selected_list_string = summarizeCrowd(selected_list_df)
        return selected_list_string

def cleanToFroVolume(data_df, start_stn, end_stn, date_time):
    # Clean date
    day, time = splitDateTime(date_time)
    time_hr = time.hour
    
    selected_list = []
    passenger_upper_threshold = 57   # 75% 
    passenger_lower_threshold = 4    # 25%
    
    # Filtering through DataFrame
    data_df = data_df[data_df['ORIGIN_PT_CODE'].str.contains(start_stn)]
    data_df = data_df[data_df['DESTINATION_PT_CODE'].str.contains(end_stn)]
    data_df.rename(columns={'ORIGIN_PT_CODE': 'PT_CODE'}, inplace=True)
    data_df = cleanCSV(data_df, 'ORIGIN')
    data_df.rename(columns={'DESTINATION_PT_CODE': 'PT_CODE'}, inplace=True)
    data_df = cleanCSV(data_df, 'DESTINATION')
    data_df = data_df[data_df['ORIGIN_PT_CODE'] == start_stn]
    data_df = data_df[data_df['DESTINATION_PT_CODE'] == end_stn]
    
    for index, row in data_df.iterrows():
        if row['DAY_TYPE'] == day \
        and row['TIME_PER_HOUR'] == time_hr:
            selected_list.append(row)
            break
    
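    # No trips recorded for this origin, destination, day type and hour
    if not selected_list:
        return None
    
    # Check crowd volume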
    selected_list_df = pd.DataFrame(selected_list).loc[:, ['ORIGIN_PT_CODE', 'DESTINATION_PT_CODE', 'TOTAL_TRIPS']]
    selected_list_df = selected_list_df.reset_index(drop=True)
    passenger_volume = selected_list_df.loc[0, 'TOTAL_TRIPS']
    selected_list_df.insert(1, 'CrowdLevel', 
                            'CROWD LEVEL HIGH' if passenger_volume > passenger_upper_threshold 
                            else ('CROWD LEVEL LOW' if passenger_volume < passenger_lower_threshold 
                            else 'CROWD LEVEL MODERATE'))
    selected_list_df = selected_list_df.loc[:, ['CrowdLevel', 'ORIGIN_PT_CODE', 'DESTINATION_PT_CODE']]
    selected_list_df.rename(columns={'ORIGIN_PT_CODE': 'Station'}, inplace=True)
    getStnName(selected_list_df, 'ORIGIN')
    selected_list_df.rename(columns={'DESTINATION_PT_CODE': 'Station'}, inplace=True)
    getStnName(selected_list_df, 'DESTINATION')
    selected_list_df.insert(1, 'from', 'from')
    selected_list_df.insert(3, 'to', 'to')
    selected_list_df.insert(5, '.', '.')
    return selected_list_df.to_string(index=False, header=False)

def getServiceStatus(data):
    status = data["value"]["Status"]
    affected_segments = data["value"]["AffectedSegments"]
    messages = data["value"]["Message"]
    
    # Extracting affected segments information
    lines = [segment["Line"] for segment in affected_segments]
    directions = [segment["Direction"] for segment in affected_segments]
    stations = [segment["Stations"] for segment in affected_segments]
    free_public_bus = [segment["FreePublicBus"] for segment in affected_segments]
    free_mrt_shuttle = [segment["FreeMRTShuttle"] for segment in affected_segments]
    mrt_shuttle_direction = [segment["MRTShuttleDirection"] for segment in affected_segments]
    
    # Extracting message information
    content = [msg["Content"] for msg in messages]
    created_date = [msg["CreatedDate"] for msg in messages]
    
    # Create DataFrame
    df = pd.DataFrame({
        "Status": status,
        "Line": lines,
        "Direction": directions,
        "Stations": stations,
        "FreePublicBus": free_public_bus,
        "FreeMRTShuttle": free_mrt_shuttle,
        "MRTShuttleDirection": mrt_shuttle_direction,
        "MessageContent": content,
        "CreatedDate": created_date
    })
    return df

def validateStnAlert(data_df, stn_code):
    # Drop a leading zero after the line prefix ('EW09' -> 'EW9') to match the alert's station codes
    if len(stn_code) >= 3 and stn_code[2] == '0':
        stn_code = stn_code[:2] + stn_code[3:]
    stn_line = stn_code[:2] + 'L'
    status = "No train service issues at selected stations."
    for index, row in data_df.iterrows():
        # Compare whole station codes so that 'EW1' does not match 'EW10'
        affected_stns = [s.strip() for s in str(row['Stations']).split(',')]
        if stn_code in affected_stns:
            service_message = row['MessageContent']
            # Clean service_message: split into one entry per 'HHMMhrs :' timestamp
            split_strings = re.split(r'(\d{4}hrs :)', service_message)
            combined_strings = []
            current_string = ''
            for string in split_strings:
                if string.endswith('hrs :'):
                    if current_string:
                        combined_strings.append(current_string.strip())
                    current_string = string
                else:
                    current_string += string
            if current_string:
                combined_strings.append(current_string.strip())
            # Return the first entry that mentions this station's line
            for string in combined_strings:
                if stn_line in string:
                    return string
            # Station is affected but no entry names its line: report the full message
            status = service_message.strip()
    return status

def summarizeCrowd(df):
    grouped = df.groupby('CrowdLevel')['Station'].apply(lambda x: ', '.join(x)).reset_index()

    # Construct the summary string
    summary = ''
    for index, row in grouped.iterrows():
        summary += f"CROWD LEVEL {row['CrowdLevel'].split()[-1]} at these stations: {row['Station']}.\n "

    # Remove the trailing newline and space
    summary = summary[:-2]
    return summary

def summarizeAlerts(df):
    grouped = df.groupby('Status')['Station'].apply(lambda x: ', '.join(x)).reset_index()
    
    # Drop no issues
    grouped = grouped[grouped['Status'] != 'No train service issues at selected stations.']
    
    if len(grouped)>0:
        # Construct the summary string
        summary = ''
        for index, row in grouped.iterrows():
            summary += f"{row['Status']} at these stations: {row['Station']}.\n"

        # Remove the trailing newline (keep the final full stop)
        summary = summary.rstrip('\n')
        return summary
    else:
        return ('No train service issues at selected stations.')
    
def summarizeNearestTaxi(df):
    df = df.reset_index(drop=True)
    df = df.loc[:, ['Name', 'Distance']]
    df['Distance'] = df['Distance'].round(1).astype(str)
    summary = ''
    for index, row in df.iterrows():
        summary += f"{row['Name']} at {row['Distance']}m"
        if index != len(df) - 1:
            summary += "; "
        else:
            summary += "."
    return summary

def haversine(lat1, lon1, lat2, lon2):
    # Convert latitude and longitude from degrees to radians
    lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])

    # Haversine formula
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = sin(dlat / 2)**2 + cos(lat1) * cos(lat2) * sin(dlon / 2)**2
    c = 2 * atan2(sqrt(a), sqrt(1 - a))
    distance = 6371 * c  # Radius of Earth in kilometers

    return distance * 1000

def getStnCoordinates(stn):
    stn_df = pd.read_csv('mrtlrt_gps.csv')
    # Get MRT station coordinates
    for i in range(len(stn_df)):
        if stn_df.loc[i, 'station_code'] == stn:
            stn_lat = stn_df.loc[i, 'lat']
            stn_lon = stn_df.loc[i, 'lng']
            stn_name = stn + " " + stn_df.loc[i, 'full_name']
    return stn_lat, stn_lon, stn_name

def calcTaxiDist(place_lat, place_lon, taxi_df):
    dist = []
    # Calculate distance between places
    for i in range(len(taxi_df)):
        taxi_lat = taxi_df.loc[i, 'Latitude']
        taxi_lon = taxi_df.loc[i, 'Longitude']
        dist.append(haversine(place_lat, place_lon, taxi_lat, taxi_lon))
    taxi_df['Distance'] = pd.Series(dist)
    return taxi_df

def cleanCrowdTime(data_df, prompt_df, date_time):
    # Filter for station
    crowd = []
    day, time = splitDateTime(date_time)
    
    # If the user gave a specific time, check from 90 minutes before to 30 minutes after it (within 7am - 10pm); otherwise check 7am - 10pm
    current_time = datetime.strptime((datetime.now().strftime("%H:%M")), "%H:%M")
    if time != current_time:
        start_check_time = (time-timedelta(minutes=90)).time()
        if start_check_time <= datetime.strptime('07:00', '%H:%M').time():
            start_check_time = datetime.strptime('07:00', '%H:%M').time()
        end_check_time = (time+timedelta(minutes=30)).time()
        if end_check_time >= datetime.strptime('22:00', '%H:%M').time():
            end_check_time = datetime.strptime('22:00', '%H:%M').time()
    else:
        start_check_time = datetime.strptime('07:00', '%H:%M').time()
        end_check_time = datetime.strptime('22:00', '%H:%M').time()
    
    # Find station
    for index, row in data_df.iterrows():
        if row['Station'] in prompt_df['stn_codes'].values:
            if start_check_time <= row['Start'] <= end_check_time:
                crowd.append(row)
    
    # Clean output
    crowd_df = pd.DataFrame(crowd)
    crowd_df['Start'] = crowd_df['Start'].apply(lambda x: x.strftime('%H:%M'))
    crowd_df.reset_index(drop=True, inplace=True)
    
    all_time_df = pd.DataFrame()
    for station in crowd_df['Station'].unique():
        station_df = crowd_df[crowd_df['Station'] == station]
        current_stn_df = station_df.groupby(['Station', 'CrowdLevel'])['Start'].apply(list).reset_index()
        # Arrange order 'l-m-h'
        current_stn_df['CrowdLevel'] = pd.Categorical(current_stn_df['CrowdLevel'], categories=['l', 'm', 'h'], ordered=True)
        current_stn_df = current_stn_df.sort_values('CrowdLevel')
        all_time_df = pd.concat([all_time_df, current_stn_df])
    all_time_df['CrowdLevel'] = all_time_df['CrowdLevel'].replace({\
                                                                   'l': 'LOW',
                                                                   'm': 'MODERATE',
                                                                   'h': 'HIGH'})
    all_time_df.reset_index(drop=True, inplace=True)
    getStnName(all_time_df, None)
    time_string = summarizeTime(all_time_df)
    return time_string

def cleanForecastVolume(data_df, prompt_df, date_time):
    # Clean date
    day, time = splitDateTime(date_time)
    data_df.rename(columns={'PT_CODE': 'Station', 'TIME_PER_HOUR': 'Start'}, inplace=True)
    
    # Check if user input an actual time (+- 2hrs), else 7am - 10pm
    current_time = datetime.strptime((datetime.now().strftime("%H:%M")), "%H:%M")
    if time != current_time:
        start_check_time = time.hour - 2
        if start_check_time <= 7:
            start_check_time = 7
        end_check_time = time.hour + 2
        if end_check_time >= 22:
            end_check_time = 22
    else:
        start_check_time = 7
        end_check_time = 22 
    
    selected_list = []
    passenger_upper_threshold = 95000   # threshold 'med-to-high'
    passenger_lower_threshold = 15000    # threshold 'low-to-med'
    
    # Find station
    for index, row in data_df.iterrows():
        if row['Station'] in prompt_df['stn_codes'].values:
            if start_check_time <= row['Start'] <= end_check_time:
                selected_list.append(row)
    
    # Check crowd volume
    selected_list_df = pd.DataFrame(selected_list)
    selected_list_df = selected_list_df.reset_index(drop=True)
    selected_list_df['Passenger_Volume'] = selected_list_df['TOTAL_TAP_IN_VOLUME'] + selected_list_df['TOTAL_TAP_OUT_VOLUME']
    for i in range(len(selected_list_df)):
        passenger_volume = selected_list_df.loc[i, 'Passenger_Volume']
        if passenger_volume > passenger_upper_threshold:
            selected_list_df.loc[i, 'CrowdVolume'] = 'HIGH'
        elif passenger_volume < passenger_lower_threshold:
            selected_list_df.loc[i, 'CrowdVolume'] = 'LOW'
        else:
            selected_list_df.loc[i, 'CrowdVolume'] = 'MODERATE'
    selected_list_df = selected_list_df.loc[:, ['Station', 'CrowdVolume', 'Passenger_Volume', 'Start', 'DAY_TYPE']]
    all_volume_df = pd.DataFrame()
    for station in selected_list_df['Station'].unique():
        station_df = selected_list_df[selected_list_df['Station'] == station]
        current_stn_df = station_df.groupby(['Station', 'CrowdVolume', 'DAY_TYPE'])['Start'].apply(list).reset_index()
        current_stn_df['Start'] = current_stn_df['Start'].apply(lambda hours: [f"{hour:02d}:00" for hour in hours])
        current_stn_df['Start'] = current_stn_df['Start'].apply(lambda times: sorted(times, key=lambda x: int(x.split(':')[0])))
        
        # Arrange order 'LOW-MODERATE-HIGH'
        current_stn_df['CrowdVolume'] = pd.Categorical(current_stn_df['CrowdVolume'], categories=['LOW', 'MODERATE', 'HIGH'], ordered=True)
        current_stn_df = current_stn_df.sort_values('CrowdVolume')
        all_volume_df = pd.concat([all_volume_df, current_stn_df])
    all_volume_df.reset_index(drop=True, inplace=True)
    getStnName(all_volume_df, None)
    time_string = summarizeVolumeTime(all_volume_df)
    return time_string

def summarizeTime(df):
    # Construct the summary string
    summary = ''
    for index, row in df.iterrows():
        summary += f"CROWD LEVEL at {row['Station']} is {row['CrowdLevel'].split()[-1]} at these timings: {row['Start']}.\n "

    # Remove the trailing newline and space
    summary = summary[:-2]
    return summary

def summarizeVolumeTime(df):
    grouped_df = df.groupby(['DAY_TYPE', 'Station', 'CrowdVolume'], observed=False)['Start'].apply(list).reset_index()
    grouped = grouped_df.dropna()
    
    summary = ''
    # Construct the summary string
    for day in grouped['DAY_TYPE'].unique():
        summary += f"On {day} : "
        for station in grouped['Station'].unique():
            summary += f"CROWD VOLUME at {station} is "
            count = 0
            for volume in grouped['CrowdVolume'].unique():
                timings = grouped[(grouped['DAY_TYPE'] == day) & 
                                  (grouped['Station'] == station) & 
                                  (grouped['CrowdVolume'] == volume)]['Start'].tolist()
                timings_str = str(timings).strip('[]')
                
                if len(timings_str) > 0:
                    if count > 0:
                        summary += ", "
                    summary += f"{volume} at these timings: {timings_str}"
                    # Only count volumes that were actually written, so no stray leading comma appears
                    count += 1
                if volume == grouped['CrowdVolume'].unique()[-1]:
                    summary += ".\n "

    # Remove the trailing newline and space
    summary = summary[:-2]
    return summary
    return summary

# Check if a given date is a weekday or weekend
def check_weekday_or_weekend(date_time_string):
    # Convert date/time string to datetime object
    date_time_obj = datetime.strptime(date_time_string, '%d-%m-%Y,%H:%M') #datetime.strptime(date_time_string, '%d-%m-%Y, %H%M')
    
    # Extract the weekday number (0: Monday, 1: Tuesday, ..., 6: Sunday)
    weekday_num = date_time_obj.weekday()
    
    # Check if it's a weekday (Monday to Friday)
    if weekday_num < 5:
        return f"{date_time_obj.strftime('%d-%m-%Y')} is a WEEKDAY.\n"
    else:
        return f"{date_time_obj.strftime('%d-%m-%Y')} is a WEEKEND.\n"

# Summarise nearest taxi stands & craft the Google Maps link
def summarizeNearestTaxiLink(df):
    df = df.reset_index(drop=True)
    df = df.loc[:, ['Latitude', 'Longitude', 'Name', 'Distance']]
    df['Distance'] = df['Distance'].round(1).astype(str)
#     summary = 'Nearest taxi stations are:' + df.to_string(index=False, header=False)
#     summary = summary.strip().replace('\n', ', ')
    summary = ''
    for index, row in df.iterrows():
        summary += f"{row['Name']} at {row['Distance']}m (Link: https://www.google.com/maps?q={row['Latitude']},{row['Longitude']})"
        if index != len(df) - 1:
            summary += " ; "
        else:
            summary += " ."
    return summary

# Obtain route based on a given start & end station
def get_public_transport_route(station):
    try:
        # Read the CSV file into a DataFrame
        df = pd.read_csv('mrtlrt_gps.csv')
        
        # Check for cases where there is only 1 train station given
        if ',' not in station:
            station_cleaned = str(station).upper().replace(" ", "")   # remove any spaces
            station_details = df[df['station_name'] == station_cleaned]
            if station_details.empty:
                stn_not_found = f"Error: Station '{station}' not found, do ensure that the spelling is correct."
                return stn_not_found
            # If there is only 1 MRT station, return the station code & station name on that station without calling the route API
            station_code = str(station_details['station_code'].values[0])
            station_name = str(station_details['full_name'].values[0])
            single_station = f"MRT station: {station_code} {station_name}."
            return single_station

        # For cases where there are 2 stations, split them into start_station & end_station based on the comma delimiter
        start_station, end_station = str(station).split(",")
        start_station = start_station.upper().replace(" ", "") # upper-case and remove all spaces
        end_station = end_station.upper().replace(" ", "") # upper-case and remove all spaces

        # Get starting station details
        station_start = df[df['station_name'] == start_station]
        if station_start.empty:
            start_not_found = f"Error: Starting station '{start_station}' not found, do ensure that the spelling is correct."
            return start_not_found
        lat_start = station_start['lat'].values[0]
        lng_start = station_start['lng'].values[0]

        # Get destination station details
        station_end = df[df['station_name'] == end_station]
        if station_end.empty:
            dest_not_found = f"Error: Destination station '{end_station}' not found, do ensure that the spelling is correct."
            return dest_not_found
        lat_end = station_end['lat'].values[0]
        lng_end = station_end['lng'].values[0]
        code_end = station_end['station_code'].values[0] # get destination station code for walk as last leg
        name_end = station_end['full_name'].values[0] # get destination station name for walk as last leg

        # Define parameters for API request
        routeType = "pt"
        date = datetime.today().strftime("%m-%d-%Y")
        hour = '{:02d}'.format(datetime.now().hour)
        minute = '{:02d}'.format(datetime.now().minute)
        mode = "RAIL"
        maxWalkDistance = 100
        numItineraries = 3

        # Construct URL for API request
        url = f"https://www.onemap.gov.sg/api/public/routingsvc/route?start={lat_start}%2C{lng_start}&end={lat_end}%2C{lng_end}&routeType={routeType}&date={date}&time={hour}%3A{minute}%3A00&mode={mode}&maxWalkDistance={maxWalkDistance}&numItineraries={numItineraries}"

        # Define headers for API request (Access code valid until 21 Mar)
        headers = {
            "Authorization": ONEMAP_API_KEY
        }

        # Make API request
        response = requests.request("GET", url, headers=headers)
        
        # Handle case where API request fails
        if not response.ok:
            response_fail = "Error: The API request failed, please try again later."
            return response_fail
        
        # Extract route information from API response
        routes_str = ""
        api_response = response.json()
        route_count = 1
        api_response_itineraries = api_response['plan']['itineraries']

        walk_legs = 0
        for itinerary in api_response_itineraries:
            for leg in itinerary['legs']:
                if leg['mode'] == 'WALK' and leg['from']['name'] == 'Origin' and leg['to']['name'] == 'Destination':
                    walk_legs += 1

        routes_count = len(api_response_itineraries) - walk_legs
        routes_str += f"There are {routes_count} possible travel route(s).\n"

        for itinerary in api_response_itineraries: # pull out all the itineraries or routes from the api
            route_string = ""
            duration = itinerary['duration'] / 60  # calculate overall duration of the route
            fare = itinerary['fare']  # extract total fare of the route
            prev_station_name = None
            first_leg = itinerary['legs'][0]  # Extracting the first leg
            last_leg = itinerary['legs'][-1]  # Extracting the last leg

            if first_leg['mode'] == 'WALK' and \
                first_leg['from']['name'] == 'Origin' and \
                first_leg['to']['name'].replace(' MRT STATION', '').replace(' ', '') != start_station and \
                first_leg['to']['name'] != 'Destination':

                route_string += f"Walk {round(first_leg['distance'],0)} metres or {round(first_leg['distance']/0.75,0)} steps to {first_leg['to']['stopCode']} {first_leg['to']['name'].replace(' MRT STATION', '')} "
            
            # for leg in itinerary['legs']:  # within each itinerary, look for each leg of the route
            for leg_index, leg in enumerate(itinerary['legs']):  # within each itinerary, look for each leg of the route
                if leg_index > 0 and leg_index < len(itinerary['legs']) - 1 and \
                    leg['mode'] == 'WALK' and \
                    leg['from']['name'] != leg['to']['name']: # to include "walk leg from one station to another" in the route string
                    route_string += f" then walk {round(leg['distance'],0)} metres or {round(leg['distance']/0.75,0)} steps"
                
                if leg['mode'] == 'SUBWAY':  # if the leg is a subway route

                    if route_string.startswith("Walk"):  # Check if route_string starts with "Walk"
                        route_string += "then take train from "
                    elif route_string:
                        route_string += " to "
                    current_station_name = leg['from']['name'].replace(' MRT STATION', '') # remove excess words

                    if current_station_name == prev_station_name:
                        route_string += "transit to "  # include the transit station
                    route_string += f"{leg['from']['stopCode']} {current_station_name}"
                    prev_station_name = leg['to']['name'].replace(' MRT STATION', '') # remove excess words

                    for stop in leg['intermediateStops']:
                        route_string += f" to {stop['stopCode']} {stop['name'].replace(' MRT STATION', '')}" # remove excess words
                    route_string += f" to {leg['to']['stopCode']} {prev_station_name}"

            # Check if the last leg is a 'WALK' mode and meets the conditions
            if last_leg['mode'] == 'WALK' and \
                last_leg['to']['name'] == 'Destination' and \
                last_leg['from']['name'].replace(' MRT STATION', '').replace(' ', '') != end_station and \
                last_leg['from']['name'] != 'Origin':
                if route_string:  # If route_string is not empty, add the final walk to the end station
                    routes_str += f"Route {route_count}: {route_string} then walk {round(last_leg['distance'],0)} metres or {round(last_leg['distance']/0.75,0)} steps to {code_end} {name_end} with an estimated duration of {round(duration, 0)} minutes and cost ${fare}.\n"
                    route_count += 1  # only number routes that are actually listed
            elif route_string:  # Check if route_string is not empty before adding to routes_str
                routes_str += f"Route {route_count}: {route_string} with an estimated duration of {round(duration, 0)} minutes and cost ${fare}.\n"
                route_count += 1  # only number routes that are actually listed

        return routes_str

    except Exception as e:
        exception_msg = f"Error: {str(e)}"
        return exception_msg
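
The `haversine` helper defined earlier in this cell multiplies its kilometre result by 1000, so it returns metres. A quick way to sanity-check it: one degree of latitude along a meridian should come out at about 111.2 km. The snippet below copies the same formula so it runs on its own. The sample coordinates are illustrative only.

```python
from math import radians, sin, cos, sqrt, atan2

def haversine(lat1, lon1, lat2, lon2):
    # Same formula as the helper above: great-circle distance in metres
    lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = sin(dlat / 2)**2 + cos(lat1) * cos(lat2) * sin(dlon / 2)**2
    c = 2 * atan2(sqrt(a), sqrt(1 - a))
    return 6371 * c * 1000  # mean Earth radius of 6371 km, converted to metres

# One degree of latitude is 6371000 * pi / 180 metres
one_degree = haversine(0.0, 103.8, 1.0, 103.8)
print(round(one_degree))  # 111195

# Identical points are 0 m apart
print(haversine(1.3, 103.8, 1.3, 103.8))  # 0.0
```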

### (V) Train Service Alerts
- Description : Check whether any train stations along the route are affected by a train service disruption.
- Function: checkTrainAlert(input_prompt)
- Input: *string* "station_name" OR "start_station,end_station"
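
The matching step in `validateStnAlert` is easiest to see on a toy input. The sketch below makes the same two assumptions as that helper:

- the alert's `Stations` field is a comma-separated list of codes without leading zeros (e.g. `EW9`);
- the message text is a run of entries, each prefixed by an `HHMMhrs :` timestamp.

The helper names `normalise_code` and `split_messages` are hypothetical, and the sample alert is made up for illustration.

```python
import re

def normalise_code(stn_code):
    # 'EW09' -> 'EW9': drop a leading zero after the two-letter line prefix
    if len(stn_code) >= 3 and stn_code[2] == '0':
        stn_code = stn_code[:2] + stn_code[3:]
    return stn_code

def split_messages(message):
    # Split on 'HHMMhrs :' and keep each timestamp with the text that follows it
    parts = re.split(r'(\d{4}hrs :)', message)
    entries, current = [], ''
    for part in parts:
        if part.endswith('hrs :'):
            if current.strip():
                entries.append(current.strip())
            current = part
        else:
            current += part
    if current.strip():
        entries.append(current.strip())
    return entries

# Hypothetical alert payload for illustration only
stations = "EW8,EW9,EW10"
message = "0815hrs : EWL - No train service between EW8 and EW10. 0830hrs : NSL - Train service resumed."

code = normalise_code('EW09')
line = code[:2] + 'L'                                        # 'EWL'
affected = code in [s.strip() for s in stations.split(',')]  # whole-code match
relevant = [m for m in split_messages(message) if line in m]
print(code, affected)   # EW9 True
print(relevant)
```

Matching whole codes after splitting on commas avoids a plain substring test treating `EW1` as present in `EW10`.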

In [9]:
# Check for any existing train alerts
@tool
def checkTrainAlert(input_prompt):
    """
    Use this to check for any disruptions or unavailability of train services, such as affected line and stations.
    To use the tool, you must provide one or two parameters in the following format: 'station_name' OR 'start_station,end_station'. Do not include 'mrt' in the input.
    """
    # Clean prompt
    text_input_prompt, datetime_input_prompt = cleanTimePrompt(input_prompt)
    prompt = get_public_transport_route(text_input_prompt)
    
    # if prompt is not None:
    if "Error" not in prompt:
        prompt_stn_df, origin, destination = cleanPrompt(prompt)

        # Get data
        url_type = 'alert'
        response = getDataRequest(url_type)
    #     response = test_a   # Simulate situation
        status = []

        if response is not None:
            # Initialize DataFrame
            service_df = pd.DataFrame({
                'Status': [response['value']['Status']],
                'AffectedSegments': [response['value']['AffectedSegments']],
                'Message': [response['value']['Message']]
            })

            # Check service status (1 = normal train service)
            if service_df.loc[0, 'Status'] == 1:
                status = "There are no real-time train service issues at the selected stations."
            else:
                # Parse the alert once, then check each station on the route against it
                message = getServiceStatus(response)
                for code in (prompt_stn_df['stn_codes']):
                    status.append(validateStnAlert(message, code))

            # Summarize Alerts
            status_df = pd.DataFrame({
                'Station': prompt_stn_df['stn_codes'],
                'Status': status
            })
            status_df = getStnName(status_df, None)
            status_string = summarizeAlerts(status_df)
            return status_string
        else:
            return "Error: The API call was unsuccessful. Please try again later."
    else:
        return "Error: No routes available. Please try again later."


### (VI) 3 Nearest Taxi Stands from Searched Location
- Description : Find the 3 nearest taxi stands to the selected location.
- Function: checkNearestTaxiStands(input_location)
- Input: *string* Name of location or postal code or address or road name
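
The nearest-stand lookup works in two steps, as `calcTaxiDist` and `nsmallest(3, 'Distance')` do in the cell below:

1. Compute a great-circle distance from the geocoded point to every stand.
2. Keep the three smallest distances.

The sketch below swaps the row-by-row loop for a NumPy-vectorised version of the same haversine formula. The helper name `add_distances` is hypothetical. The stand names and coordinates are placeholders, not rows from `taxi_stands_Monthly.csv`.

```python
import numpy as np
import pandas as pd

EARTH_RADIUS_M = 6371 * 1000  # same mean radius as the haversine helper, in metres

def add_distances(lat, lon, stands):
    # Vectorised haversine distance (metres) from (lat, lon) to every stand
    lat1, lon1 = np.radians(lat), np.radians(lon)
    lat2 = np.radians(stands['Latitude'].to_numpy())
    lon2 = np.radians(stands['Longitude'].to_numpy())
    a = np.sin((lat2 - lat1) / 2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2)**2
    out = stands.copy()
    out['Distance'] = 2 * EARTH_RADIUS_M * np.arctan2(np.sqrt(a), np.sqrt(1 - a))
    return out

# Hypothetical stands for illustration only
stands = pd.DataFrame({
    'Name': ['Stand A', 'Stand B', 'Stand C', 'Stand D'],
    'Latitude': [1.3000, 1.3010, 1.3100, 1.2950],
    'Longitude': [103.8000, 103.8000, 103.8200, 103.8000],
})

nearest = add_distances(1.3000, 103.8000, stands).nsmallest(3, 'Distance')
print(nearest[['Name', 'Distance']].round(1).to_string(index=False))
```

Working on whole columns at once avoids the per-row `.loc` lookups, which may matter when the taxi stand file has a few hundred rows and the tool is called repeatedly.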


In [10]:
# Support Function to obtain lat, long, address 
def get_GPS_support(query):
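    url = "https://www.onemap.gov.sg/api/common/elastic/search"
    # Pass the query as parameters so characters such as '&' or '#' in the search text are URL-encoded
    params = {"searchVal": query, "returnGeom": "Y", "getAddrDetails": "Y", "pageNum": 1}
    headers = {"Content-Type": "application/json"}
    response = requests.get(url, params=params, headers=headers)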
    if response.status_code == 200:
        api_output = response.json()
        if api_output["found"] > 0:
            result = api_output["results"][0]  # Extract the first result
            latitude = float(result["LATITUDE"])
            longitude = float(result["LONGITUDE"])
            address = result["ADDRESS"]
            return latitude, longitude, address
        else:
            # print("Error: No results were found. Ensure that the input is a valid location, postal code, address or road name.")
            return None, None, None
    else:
        print("Error:", response.status_code)
        return None, None, None

# Check for the nearest taxi stands based on a given location or postal code or address or road name
@tool
def checkNearestTaxiStands(input_location):
    """Use this to find the location of and distance to the nearest taxi stands. 
    You must provide a valid location or postal code (6-digit integer) or address or road name.
    """
    # Geocode the input location
    lat, lon, address = get_GPS_support(input_location)
    
    if lat is not None:
        # Calculate distance
        taxi_df = pd.read_csv('taxi_stands_Monthly.csv')
        taxi_df = calcTaxiDist(lat, lon, taxi_df)

        # Summarize info
        nearest_taxis = summarizeNearestTaxiLink(taxi_df.nsmallest(3, 'Distance'))
        summary = 'Nearest taxi stands at ' + input_location + ' are : ' + nearest_taxis
        return summary
    else:
        return "Error: No results were found. Ensure that the input is a valid location, postal code, address or road name."


### (VII) RealTime Crowd Levels
- Description : Check current crowd levels at each train station along the route.
- Function : checkRealTimeCrowd(input_prompt)
- Input: *string* "station_name" OR "start_station,end_station"
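
The real-time crowd tool, like the alert and route tools, accepts either one station or a comma-separated pair. `get_public_transport_route` normalises each name by upper-casing it and deleting every space before looking it up in the `station_name` column of `mrtlrt_gps.csv`. This implies that the column stores names in that form. The sketch below isolates the parsing step. The function name `parse_station_input` is hypothetical. Unlike the original, it raises an explicit error for more than two stations instead of relying on a caught unpacking error.

```python
def parse_station_input(text):
    """Split 'station' or 'start,end' into normalised lookup keys.

    Mirrors the normalisation in get_public_transport_route: upper-case and
    remove every space, so 'buona vista' becomes 'BUONAVISTA'.
    """
    parts = [p.upper().replace(" ", "") for p in str(text).split(",")]
    if len(parts) > 2:
        raise ValueError("Provide at most two stations: 'start_station,end_station'")
    return parts

print(parse_station_input("buona vista"))              # ['BUONAVISTA']
print(parse_station_input("buona vista, kent ridge"))  # ['BUONAVISTA', 'KENTRIDGE']
```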


In [11]:
# Check real-time platform crowdedness level for MRT/LRT stations
@tool
def checkRealTimeCrowd(input_prompt):
    """Use this to find the real time and current MRT train platform crowdedness level.
    To use the tool, you must provide one or two train station parameters in the following format: 'station_name' OR 'start_station,end_station'. Example: 'buona vista' for 1 station OR 'buona vista, kent ridge' for between 2 stations. Do not include 'mrt' in the input.
    """
    # Clean prompt
    text_input_prompt, datetime_input_prompt = cleanTimePrompt(input_prompt)
    prompt = get_public_transport_route(text_input_prompt)
    
    # if prompt is not None:
    if "Error" not in prompt:
        prompt_stn_df, origin, destination = cleanPrompt(prompt)
    
        # Get data
        url_type = 'RealTime'
        realtime_df = pd.DataFrame()
        for lines in (prompt_stn_df['stn_lines'].unique()):
            response = getCrowdRequest(url_type, lines)
            if response is not None:
                json_df = cleanRealTimeCrowd(response)
                realtime_df = pd.concat([realtime_df, json_df])
            else:
                # Return the message so the agent can relay it (print output never reaches the agent)
                return "Error: The API call was unsuccessful. Please try again later."
        realtime_crowd = cleanCrowd(realtime_df, prompt_stn_df)
        if realtime_crowd is not None:
            return realtime_crowd
        else:
            return "Error: No live crowd data available. Please try again later."
    else:
        return "Error: No routes available. Please try again later."


### (VIII) Forecast Crowd Levels
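- Description : Summarize the forecast crowd volume at the origin and destination stations for each hourly slot from 2 hours before to 2 hours after the specified time (limited to 07:00 to 22:00; the full 07:00 to 22:00 range if no time is given), based on historical tap-in and tap-out volumes.
- Function : checkForecastVolume(input_prompt)
- Input: *string* "station_name" OR "start_station,end_station", followed by ";DD-MM-YYYY,HH:MM", ";DD-MM-YYYY" or ";HH:MM"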
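
`cleanForecastVolume` applies two rules:

- It only looks at hourly slots from two hours before to two hours after the requested time, limited to 07:00 to 22:00.
- It labels the combined tap-in and tap-out volume LOW below 15000, HIGH above 95000, and MODERATE otherwise.

Separately, `check_weekday_or_weekend` decides whether the date is a weekday. The sketch below reproduces these rules for a `'DD-MM-YYYY,HH:MM'` string. The helper names `forecast_window` and `classify_volume` are hypothetical. Unlike the original, the sketch does not fall back to the full 07:00 to 22:00 range when no time is given.

```python
from datetime import datetime

LOW_THRESHOLD = 15000    # below this combined tap-in + tap-out volume: LOW
HIGH_THRESHOLD = 95000   # above this: HIGH; otherwise MODERATE

def forecast_window(date_time_string):
    # Parse the same 'DD-MM-YYYY,HH:MM' format used by check_weekday_or_weekend
    dt = datetime.strptime(date_time_string, '%d-%m-%Y,%H:%M')
    start_hour = max(dt.hour - 2, 7)   # never earlier than 07:00
    end_hour = min(dt.hour + 2, 22)    # never later than 22:00
    day_type = 'WEEKDAY' if dt.weekday() < 5 else 'WEEKEND'
    # Inclusive range, matching start_check_time <= hour <= end_check_time
    return day_type, list(range(start_hour, end_hour + 1))

def classify_volume(tap_in, tap_out):
    volume = tap_in + tap_out
    if volume > HIGH_THRESHOLD:
        return 'HIGH'
    if volume < LOW_THRESHOLD:
        return 'LOW'
    return 'MODERATE'

print(forecast_window('02-02-2024,18:00'))  # ('WEEKDAY', [16, 17, 18, 19, 20])
print(forecast_window('01-01-2024,06:00'))  # ('WEEKDAY', [7, 8])
print(classify_volume(40000, 30000))        # MODERATE
```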


In [12]:
# Check forecasted platform crowdedness level for the MRT/LRT stations

@tool
def checkForecastVolume(input_prompt):
    """Use this to find the forecasted or predicted future MRT train platform crowdedness level.
    To use the tool, you must provide one or two train station parameters ('station_name' OR 'start_station,end_station'), AND specify the date (DD-MM-YYYY) AND/OR time (HH:MM) in one of the following formats:
    Format | Example
    'station_name;DD-MM-YYYY,HH:MM' | 'buona vista;02-02-2024,18:00' OR
    'station_name;DD-MM-YYYY' | 'buona vista;02-02-2024' OR
    'station_name;HH:MM' | 'buona vista;18:00' OR
    'start_station,end_station;DD-MM-YYYY,HH:MM' | 'somerset,novena;01-01-2024,06:00' OR
    'start_station,end_station;DD-MM-YYYY' | 'somerset,novena;01-01-2024' OR
    'start_station,end_station;HH:MM' | 'somerset,novena;06:00'
    """ 
    # Clean prompt
    text_input_prompt, datetime_input_prompt = cleanTimePrompt(input_prompt)
    prompt = get_public_transport_route(text_input_prompt)
    
    # if prompt is not None:
    if "Error" not in prompt:
        prompt_stn_df, origin, destination = cleanPrompt(prompt)
        # prompt_stn_df = prompt_stn_df.iloc[[0, -1]]
        prompt_stn_df = pd.DataFrame({
            'stn_codes' : [origin, destination],
            'stn_lines' : [(origin[:2] + 'L'), (destination[:2] + 'L')]
        })

        # Get data
        url_type = 'vol_by_stn'
    #     data_df = getLinkCSV(url_type)
        data_df = cleanCSV(pd.read_csv('transport_node_train_202402.csv'), None)

        # Check crowd volume
        crowd_volume = cleanForecastVolume(data_df, prompt_stn_df, datetime_input_prompt)
        if crowd_volume is not None:
            # check if date is a weekday/weekend
            weekday_weekend_string = check_weekday_or_weekend(datetime_input_prompt)
            crowd_volume = weekday_weekend_string + crowd_volume
            return crowd_volume
        else:
            # Return the message so the agent can relay it (print output never reaches the agent)
            return "Error: No crowd volume data available. Please try again later."
    else:
        return "Error: No routes available. Please try again later."

### (IX) Nearest Restaurants & Attractions
- Description : Obtain recommendations for nearby restaurants or attractions
- Function : checkNearestAttractions(input_location)
- Input: *string* Name of location or postal code or address or road name
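
`checkNearestAttractions` does not call a places API. It geocodes the input with OneMap, then embeds the coordinates in two Google Maps search links, one for restaurants and one for things to do. The sketch below builds the same two links from a latitude/longitude pair using the URL pattern in the cell below. The helper name `build_maps_links` and the sample coordinates are hypothetical.

```python
def build_maps_links(lat, lon, zoom=16):
    # Same URL pattern as checkNearestAttractions: a search term centred on @lat,lon at the given zoom
    base = "http://www.google.com/maps/search/{term}/@{lat},{lon},{zoom}z/data=!3m1!4b1?entry=ttu"
    return {
        'restaurants': base.format(term="Restaurant", lat=lat, lon=lon, zoom=zoom),
        'attractions': base.format(term="Things+to+do", lat=lat, lon=lon, zoom=zoom),
    }

links = build_maps_links(1.3071, 103.7902)
print(links['restaurants'])
# http://www.google.com/maps/search/Restaurant/@1.3071,103.7902,16z/data=!3m1!4b1?entry=ttu
```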

In [13]:
# Obtain recommendations for nearby popular local restaurants or attractions

@tool
def checkNearestAttractions(input_location):
    """Use this to get recommendations for nearby restaurants or attractions.
    You must provide a valid location or postal code or address or road name.
    """
    # Geocode the input location
    lat, lon, address = get_GPS_support(input_location)
    
    if lat is not None:
        # Summarize info
        summary = f'If you are at {address}, try some local delights at http://www.google.com/maps/search/Restaurant/@{lat},{lon},16z/data=!3m1!4b1?entry=ttu and visit attractions at http://www.google.com/maps/search/Things+to+do/@{lat},{lon},16z/data=!3m1!4b1?entry=ttu .'
        return summary
    else:
        return "Error: No results were found. Ensure that the input is a valid location, postal code, address or road name."


### **Defining the set of Tools available to the Agent**

In [22]:
# Create custom tools from a given function

# Define tool to get the best train or walking route  
get_train_route = [Tool(
    name="Train-Journey-Route",
    func=get_public_transport_route_concise,
    description="""
    Use this to find the journey time, cost and list of train stations from one MRT train station to another MRT train station.
    Also, use this to find walking routes and number of steps between MRT train stations.
    You can also use this to find all MRT stations in between 2 stations, and check if you will pass by a certain station.
    To use the tool, you must provide two parameters in the following format: 'start_station,end_station'. Do not include 'mrt' in the input."""
    )]

# Define tool to get GPS coordinates based on a given location
get_gps = [Tool(
    name="Find-GPS-Coordinates",
    func=getGPS,
    description="Use this to get the latitude and longitude coordinates. You must provide a location or postal code or address or road name."
    )]

# Define tool to get the weather forecast for the next 2 and 24 hours
get_weather_forecast = [Tool(
    name="Weather-Forecast",
    func=get_2h_24h_weather_forecast,
    description="Use this to get the weather forecast for the next 2 and 24 hours. Use this to check if it is conducive for walking. Provide any string as input."
    )]

# Define tool to check for any active train service alerts
check_train_alert = [Tool(
    name="Get-Real-Time-Train-Service-Alerts",
    func=checkTrainAlert,
    description="""
    Use this to check for any CURRENT real time disruptions or unavailability of train services, such as affected lines and stations.
    To use the tool, you must provide one or two parameters in the following format: 'station_name' OR 'start_station,end_station'. 
    Example: 'buona vista' for 1 station OR 'buona vista,kent ridge' for between 2 stations. Do not include 'mrt' in the input.
    Do not use this tool if the user is not travelling today.
    """
    )]

# Define tool to identify the 3 nearest taxi stands based on a given location
nearest_taxi_stands = [Tool(
    name="Nearest-Taxi-Stands",
    func=checkNearestTaxiStands,
    description="""
    Use this to find the location of and distance to the nearest taxi stands. 
    You must provide a valid location or postal code or address or road name.
    """
    )]
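`checkNearestTaxiStands` is defined earlier and not reproduced here. The `radians, sin, cos, sqrt, atan2` import in Part 1 suggests distances are computed with the haversine formula. A minimal sketch of that approach, assuming taxi stands are available as hypothetical `(name, lat, lon)` tuples, ranks the stands by great-circle distance and keeps the three closest.

```python
from math import atan2, cos, radians, sin, sqrt
from typing import List, Tuple

EARTH_RADIUS_KM = 6371.0  # mean Earth radius


def haversine_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Great-circle distance between two (lat, lon) points, in kilometres."""
    phi1, phi2 = radians(lat1), radians(lat2)
    dphi = radians(lat2 - lat1)
    dlambda = radians(lon2 - lon1)
    a = sin(dphi / 2) ** 2 + cos(phi1) * cos(phi2) * sin(dlambda / 2) ** 2
    return 2 * EARTH_RADIUS_KM * atan2(sqrt(a), sqrt(1 - a))


def nearest_stands(
    lat: float, lon: float, stands: List[Tuple[str, float, float]], k: int = 3
) -> List[Tuple[str, float]]:
    """Return the k closest stands as (name, distance_km), nearest first.

    `stands` is a hypothetical list of (name, lat, lon) tuples.
    """
    ranked = sorted(
        ((name, haversine_km(lat, lon, s_lat, s_lon)) for name, s_lat, s_lon in stands),
        key=lambda item: item[1],
    )
    return ranked[:k]
```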

# Define tool to find the real-time MRT train platform crowdedness level (low, medium, high) at a given MRT station or between 2 stations
get_realtime_crowd = [Tool(
    name="Get-RealTime-Train-Platform-Crowd",
    func=checkRealTimeCrowd,
    description="""
    Use this to find the real time and current MRT train platform crowdedness level.
    To use the tool, you must provide one or two train station parameters in the following format: 'station_name' OR 'start_station,end_station'. 
    Example: 'buona vista' for 1 station OR 'buona vista,kent ridge' for between 2 stations. Do not include 'mrt' in the input.
    """
    )]
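Both the train-alert and real-time crowd tools receive a single string from the agent, so the `'station_name'` / `'start_station,end_station'` convention in their descriptions has to be parsed inside the tool function. A hypothetical parser, not the notebook's own code in `checkTrainAlert` or `checkRealTimeCrowd`, might normalise case and spacing and defensively strip a trailing "mrt" in case the agent ignores the instruction.

```python
import re
from typing import List


def parse_station_input(tool_input: str) -> List[str]:
    """Split 'buona vista' or 'buona vista,kent ridge' into station names.

    Hypothetical helper: lower-cases names, trims whitespace and drops a
    trailing standalone 'mrt' (e.g. 'Kent Ridge MRT' -> 'kent ridge').
    """
    stations = []
    for part in tool_input.split(","):
        name = part.strip().lower()
        # Remove a trailing word 'mrt' that the agent may have added
        name = re.sub(r"\s*\bmrt\b$", "", name).strip()
        if name:
            stations.append(name)
    if not 1 <= len(stations) <= 2:
        raise ValueError("Expected 'station_name' or 'start_station,end_station'")
    return stations
```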

# Define tool to find the forecasted MRT train platform crowdedness level (low, medium, high) at a given MRT station or between 2 stations, at a specified date and/or time
get_forecast_crowd = [Tool(
    name="Get-Forecast-Train-Platform-Crowd",
    func=checkForecastVolume,
    description="""Use this to find the forecasted or predicted future MRT train platform crowdedness level.
    To use the tool, you must provide one or two train station parameters ('station_name' OR 'start_station,end_station'), AND specify the time (HH:MM), optionally preceded by the date (DD-MM-YYYY), in one of the following formats:
    Format | Example
    'station_name;DD-MM-YYYY,HH:MM' | 'buona vista;02-02-2024,18:00' OR
    'station_name;HH:MM' | 'buona vista;18:00' OR
    'start_station,end_station;DD-MM-YYYY,HH:MM' | 'somerset,novena;01-01-2024,06:00' OR
    'start_station,end_station;HH:MM' | 'somerset,novena;06:00'
    """
    )]
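The forecast tool packs stations, date and time into one string. In the demo trace the agent calls it with `toa payoh,outram park;18-04-2023,16:00`, and the tool reports that 18-04-2023 is a WEEKDAY. A hypothetical parser covering the four formats in the description could look like the sketch below. It treats Saturday and Sunday as the weekend and ignores public holidays, which the real `checkForecastVolume` may handle differently.

```python
from datetime import date, datetime
from typing import List, Optional, Tuple


def parse_forecast_input(
    tool_input: str, today: Optional[date] = None
) -> Tuple[List[str], datetime, str]:
    """Parse 'station(s);DD-MM-YYYY,HH:MM' or 'station(s);HH:MM'.

    Hypothetical helper. Returns (stations, when, day_type); if only a time
    is given, the date defaults to `today` (or the current date).
    """
    station_part, _, when_part = tool_input.partition(";")
    stations = [s.strip().lower() for s in station_part.split(",") if s.strip()]
    if not 1 <= len(stations) <= 2 or not when_part.strip():
        raise ValueError("Expected 'station(s);DD-MM-YYYY,HH:MM' or 'station(s);HH:MM'")

    fields = [f.strip() for f in when_part.split(",")]
    if len(fields) == 2:
        # Date and time supplied, e.g. '18-04-2023,16:00'
        when = datetime.strptime(" ".join(fields), "%d-%m-%Y %H:%M")
    elif len(fields) == 1:
        # Time only, e.g. '18:00': combine it with the reference date
        clock = datetime.strptime(fields[0], "%H:%M").time()
        when = datetime.combine(today or date.today(), clock)
    else:
        raise ValueError(f"Unrecognised date/time part: {when_part!r}")

    # weekday(): Monday=0 ... Sunday=6; public holidays are not considered
    day_type = "WEEKEND" if when.weekday() >= 5 else "WEEKDAY"
    return stations, when, day_type
```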


# Define tool to obtain recommendations for nearby restaurants or attractions
nearest_attractions = [Tool(
    name="Get-Nearest-Attractions",
    func=checkNearestAttractions,
    description="""Use this to get recommendations for nearby restaurants or attractions.
    You must provide a valid location or postal code or address or road name.
    """
    )]

# Load built-in calculator tool
math_tools = load_tools(["llm-math"], llm=chatllm)

# Initialise tools with the list of custom & built-in tools 
tools = (math_tools + get_train_route + get_gps + get_weather_forecast + check_train_alert
         + nearest_taxi_stands + get_forecast_crowd + get_realtime_crowd + nearest_attractions)

## Part 3: Chat Model, OpenAI Tool Agent & Memory


In [23]:
# Define the chat LLM system prompt
system_prompt = """You are a highly adept customer service assistant specialized in public transport services.

You must always provide accurate and up-to-date information, pulling from the latest source files and tools available to you. Carefully analyze the provided output from these tools before formulating your responses. If uncertain about an answer, honestly state that you are not sure.

In your responses:
- Emphasize the benefits of public transport for environmental sustainability, such as its role in reducing carbon emissions when relevant.
- Suggest walking routes as an alternative when it is physically safe and the weather permits, but only if suitable paths are available. Provide the number of steps and highlight their health benefits when walking.
- Always check the weather conditions before proposing a walking route.

Tailor your recommendations meticulously based on:
1) The user's stated preferences,
2) Personal circumstances like health, safety, or age considerations,
    - Always suggest a route with the least walking if the user is elderly or sick
    - Always encourage a longer walking route if the user is young and if the weather is good
3) Current external conditions such as weather and operational status of transport services.

For instance, propose a scenic walking route for a healthy, young individual when conditions are favorable. Alternatively, recommend a direct and less strenuous public transport route for someone elderly or less mobile, especially in poor weather conditions or during service disruptions.

When multiple travel options exist, present only those that align closely with the user's specific situation. Each suggestion should include a concise rationale to inform the user why it is the most appropriate choice.

Structure your responses clearly to enhance readability and utility. Use the following guidelines:
- **Tables**: Whenever you present multiple options or alternative routes, organize the information in a table format.
- **Titles**: Use bold font to represent the start of each section. Do not use markdown headers "#".
- **Bullet Points**: Employ bullet points for lists, features, or brief points that do not fit neatly into a table format.

Although primarily focused on providing information related to transport options involving mass rapid transit (MRT) and walking routes, you may engage in general conversation about these topics without utilizing tools, should the user desire a more casual interaction.

You must refrain from discussing topics outside the scope of transportation and commuting. Maintain this focused character at all times without deviation."""

# Define the chat template
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),        
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)

# Construct the OpenAI Tools agent   
agent = create_openai_tools_agent(chatllm, tools, prompt) 

# Construct the agent executor
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True, max_iterations=5)
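`AgentExecutor` repeatedly asks the model for the next step: either a tool call, whose result is fed back through the `agent_scratchpad` placeholder, or a final answer. The demo trace shows two such tool calls within one turn, and `max_iterations=5` caps the number of steps. The pure-Python sketch below is a simplified illustration of that control flow, not LangChain's implementation; `decide` stands in for the model, `ToolCall` is a hypothetical type, and the stop message is a placeholder.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class ToolCall:
    """Hypothetical stand-in for a model's request to run a tool."""
    name: str
    tool_input: str


def run_agent_loop(
    decide: Callable[[str, list], object],
    tools: Dict[str, Callable[[str], str]],
    user_input: str,
    max_iterations: int = 5,
) -> str:
    """Simplified tool-calling loop (illustration only).

    `decide` plays the role of the LLM: given the input and the scratchpad
    of (ToolCall, observation) pairs so far, it returns either a ToolCall
    or a final answer string.
    """
    scratchpad = []
    for _ in range(max_iterations):
        step = decide(user_input, scratchpad)
        if isinstance(step, str):
            return step  # the model produced a final answer
        tool = tools.get(step.name)
        observation = tool(step.tool_input) if tool else f"Unknown tool: {step.name}"
        # Feed the observation back to the model on the next iteration
        scratchpad.append((step, observation))
    return "Stopped: iteration limit reached"
```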

# Wrap the agent executor with RunnableWithMessageHistory: before each call it loads the stored messages into 'chat_history', and afterwards it appends the new 'input' and the agent's 'output' to that history.
chat_memory = ChatMessageHistory()   # Stores the chat history (memory)
conversational_agent_executor = RunnableWithMessageHistory(
    agent_executor,
    lambda session_id: chat_memory,
    input_messages_key="input",
    output_messages_key="output",
    history_messages_key="chat_history",
)
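The factory passed to `RunnableWithMessageHistory` ignores `session_id` and always returns the global `chat_memory`, so every conversation, including every browser session of the Gradio app in Part 4, shares one history. If separate histories were wanted, the factory could look each one up by `session_id`, for example returning a `ChatMessageHistory` kept in a dict. The framework-free sketch below uses plain lists; `SessionStore` is hypothetical, and `summarize_messages`, which reads `chat_memory` directly, would need the same lookup.

```python
from typing import Dict, List


class SessionStore:
    """Hypothetical per-session history store (illustration only)."""

    def __init__(self) -> None:
        self._histories: Dict[str, List[str]] = {}

    def get(self, session_id: str) -> List[str]:
        # Create an empty history the first time a session is seen,
        # then keep returning that same list for the session
        return self._histories.setdefault(session_id, [])
```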

# Create a function that will distil previous interactions into a summary, instead of storing the entire chat history
def summarize_messages(chain_input):
    stored_messages = chat_memory.messages
    if len(stored_messages) == 0:
        return False
    summarization_prompt = ChatPromptTemplate.from_messages(
        [
            MessagesPlaceholder(variable_name="chat_history"),
            (
                "user",
                "Distill the above chat messages into a single summary message. Include many specific details and only contextual information relevant to transportation or commuting. Ignore failure or error messages.",
            ),
        ]
    )
    summarization_chain = summarization_prompt | llm    # use Mistral llm for summarisation
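    summary_message = summarization_chain.invoke({"chat_history": stored_messages})   # summarise the chat history
    # A plain LLM (not a chat model) returns a string rather than a message, so wrap it before storing it in the history
    if isinstance(summary_message, str):
        summary_message = AIMessage(content=summary_message)
    chat_memory.clear()    # clear the existing chat history & replace it with a summary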
    chat_memory.add_message(summary_message)
    return True

# Chain: 
# First step: Summarise chat history, indicate if 'messages_summarized' is True/False 
# Second step: Execute the agent with reference to the summarised chat history
chain_with_summarization = (
    RunnablePassthrough.assign(messages_summarized=summarize_messages)   # boolean True/False
    | conversational_agent_executor
)
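Because `summarize_messages` runs before every agent call, the stored history never holds more than one summary plus the most recent exchange: each turn, the previous summary and the newest input/output pair are folded into a fresh summary. The sketch below isolates that replace-with-summary step using plain strings; `summarize` is a stand-in for the Mistral chain, and `summarize_history` is a hypothetical name. The trade-off is an extra LLM call per turn, and any detail the summariser drops cannot be recovered later.

```python
from typing import Callable, List


def summarize_history(history: List[str], summarize: Callable[[List[str]], str]) -> bool:
    """Replace the whole history with a single summary entry, in place.

    Mirrors the logic of summarize_messages: returns False when there is
    nothing to summarise, True otherwise.
    """
    if not history:
        return False
    summary = summarize(history)
    history.clear()          # drop the full transcript ...
    history.append(summary)  # ... and keep only the summary
    return True
```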

# Define a function to take in the user's input query, and produce the output response from the chat LLM
def invoke_chatllm(input_message, session_id="unused"):
    output_message = chain_with_summarization.invoke(
        {"input": input_message},
        {"configurable": {"session_id": session_id}},
        )
    return output_message

## Part 4: Gradio Chatbot User Interface

In [16]:
# Log thumbs up / down feedback on a chatbot message
def print_like_dislike(x: gr.LikeData):
    print(x.index, x.value, x.liked)

# Append the user's message to the chat history, then clear and disable the textbox while the bot responds
def add_text(history, text):
    history = history + [(text, None)]
    return history, gr.Textbox(value="", interactive=False)

# Function to handle bot responses
def bot(history):
    query = history[-1][0]  # Get the user's query
    
    # Call the chatLLM, pass in the input & gather a response
    response = invoke_chatllm(query)
    response = response['output']
    
    history[-1][1] = response
    yield history
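The chatbot value here is a list of `[user_message, bot_message]` pairs; the welcome message is `[None, ...]`. `add_text` appends the user's turn with an empty reply slot, and `bot` fills that slot with the agent's answer. The framework-free sketch below isolates that data flow with a stand-in responder in place of `invoke_chatllm`; both function names are hypothetical. Each pair is a list rather than a tuple so that the in-place assignment is valid in plain Python.

```python
from typing import Callable, List, Optional


def add_user_message(history: List[list], text: str) -> List[list]:
    """Append the user's turn with an empty slot for the bot's reply."""
    return history + [[text, None]]


def fill_bot_reply(history: List[list], respond: Callable[[str], Optional[str]]) -> List[list]:
    """Answer the latest user message and store the reply next to it."""
    query = history[-1][0]
    history[-1][1] = respond(query)
    return history
```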

# Get the current directory where the Jupyter Notebook is located
current_directory = os.getcwd()

# Create Gradio interface
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(
        value=[[None, "🚇 Welcome to GPTTransit! 🚶‍♂️\n\nI'm here to help you navigate MRT trains, find walking routes and locate taxi stands effortlessly. Whether you're planning your journey or exploring options, I've got you covered! Let's make commuting a breeze together!"]],
        elem_id="chatbot",
        bubble_full_width=False,
        avatar_images=((os.path.join(current_directory, "user.png")), (os.path.join(current_directory, "LTA_logo.jpg"))),
        height=600,
        scale=1
    )

    with gr.Row():
        txt = gr.Textbox(
            scale=5,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
        )
        clear = gr.Button("Clear")

    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, chatbot, chatbot, api_name="bot_response"
    )
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    chatbot.like(print_like_dislike, None, None)
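    # Clear both the displayed conversation and the agent's stored memory
    def clear_chat():
        chat_memory.clear()
        return None

    clear.click(clear_chat, None, chatbot, queue=False)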

demo.queue()
demo.launch()   # share=True

Running on local URL:  http://127.0.0.1:7864

To create a public link, set `share=True` in `launch()`.




In [25]:
chat_memory

# Observe the verbose AgentExecutor trace from our video demo



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `Train-Journey-Route` with `toa payoh,outram park`


[0m[33;1m[1;3mThere are 3 possible travel route(s).
Route 1: NS19 TOA PAYOH to NS24 DHOBY GHAUT to transit by walking 135.0 metres or 180.0 steps to NE6 DHOBY GHAUT to NE3 OUTRAM PARK with an estimated duration of 21.0 minutes and cost $1.49.
Route 2: NS19 TOA PAYOH to NS26 RAFFLES PLACE to transit by crossing the platform (10 meters, 13 steps) to EW14 RAFFLES PLACE to EW16 OUTRAM PARK with an estimated duration of 22.0 minutes and cost $1.49.
Route 3: NS19 TOA PAYOH to NS26 RAFFLES PLACE then walk 1974.0 metres or 2632.0 steps to EW16 OUTRAM PARK with an estimated duration of 41.0 minutes and cost $1.49.

Invoking: `Get-Forecast-Train-Platform-Crowd` with `toa payoh,outram park;18-04-2023,16:00`


[0m[36;1m[1;3m18-04-2023 is a WEEKDAY.
On WEEKDAY : CROWD VOLUME at EW16 OUTRAM PARK is MODERATE at these timings: '14:00', '15:00', '16:00', HI