From 092f714391b44c97d7c61028b860cbe527ec4083 Mon Sep 17 00:00:00 2001 From: Stefano Polo Date: Mon, 10 Jul 2023 22:11:33 +0200 Subject: [PATCH] implement filter engine that removes the departure place from the generated travel plan --- src/gptravel/core/services/checker.py | 2 +- src/gptravel/core/services/filters.py | 65 +++++++++++++++++++ src/gptravel/main.py | 3 +- src/gptravel/prototype/pages/__init__.py | 2 +- src/gptravel/prototype/pages/travel.py | 6 ++ src/gptravel/prototype/utils.py | 1 - .../test_core/test_services/test_filters.py | 48 ++++++++++++++ .../test_travel_planner/test_prompt.py | 2 - .../test_travel_planner/test_tokenizer.py | 5 +- .../test_travel_planner/test_travel_engine.py | 2 - .../test_core/test_utils/test_general.py | 2 - .../test_core/test_utils/test_regex_tool.py | 2 - 12 files changed, 127 insertions(+), 13 deletions(-) create mode 100644 src/gptravel/core/services/filters.py create mode 100644 tests/test_gptravel/test_core/test_services/test_filters.py diff --git a/src/gptravel/core/services/checker.py b/src/gptravel/core/services/checker.py index c834b45..c9686cd 100644 --- a/src/gptravel/core/services/checker.py +++ b/src/gptravel/core/services/checker.py @@ -29,5 +29,5 @@ def check(self, travel_plan: TravelPlanJSON) -> bool: if all_exists: logger.debug("Check passed") else: - logger.debug("Check not passed") + logger.warning("Check not passed") return all_exists diff --git a/src/gptravel/core/services/filters.py b/src/gptravel/core/services/filters.py new file mode 100644 index 0000000..9bb1ae6 --- /dev/null +++ b/src/gptravel/core/services/filters.py @@ -0,0 +1,65 @@ +from abc import ABC, abstractmethod + +from gptravel.core.io.loggerconfig import logger +from gptravel.core.travel_planner.travel_engine import TravelPlanJSON + + +class Filter(ABC): + def __init__(self, **kwargs): + pass + + @abstractmethod + def filter(self, travel_plan: TravelPlanJSON) -> None: + pass + + +class DeparturePlaceFilter(Filter): + def filter(self, travel_plan: TravelPlanJSON) -> None: + logger.debug("DeparturePlaceFilter: Start") + departure_place = travel_plan.departure_place.lower() + # if the departure place is present in the travel plan then remove it + if departure_place in [city.lower() for city in travel_plan.travel_cities]: + logger.debug("Found {} inside the travel plan".format(departure_place)) + day_depth = travel_plan.keys_map["day"] + if day_depth == 0: + days_to_drop = [] + for day in travel_plan.travel_plan.keys(): + key_to_remove = [ + city + for city in travel_plan.travel_plan[day].keys() + if city.lower() == departure_place + ] + if key_to_remove: + logger.debug( + "Removed {} from the travel plan for {}".format( + departure_place, day + ) + ) + del travel_plan.travel_plan[day][key_to_remove[0]] + # if the day container is empty then remove it + if travel_plan.travel_plan[day] == {}: + days_to_drop.append(day) + logger.debug( + "Removed {} completely from the travel plan".format(day) + ) + if days_to_drop: + for day_to_delete in days_to_drop: + del travel_plan.travel_plan[day_to_delete] + # fix the order of the + day_first_word = days_to_drop[0].split(" ")[0] + day_keys = list(travel_plan.travel_plan.keys()) + n = 1 + for old_key in day_keys: + new_key = day_first_word + " " + str(n) + travel_plan.travel_plan[new_key] = travel_plan.travel_plan.pop( + old_key + ) + n += 1 + else: + key_to_remove = [ + city + for city in travel_plan.travel_plan.keys() + if city.lower() == departure_place + ][0] + del travel_plan.travel_plan[key_to_remove] + logger.debug("DeparturePlaceFilter: End") diff --git a/src/gptravel/main.py b/src/gptravel/main.py index 9218c42..ad74374 100644 --- a/src/gptravel/main.py +++ b/src/gptravel/main.py @@ -12,7 +12,8 @@ from gptravel.core.services.scorer import TravelPlanScore from gptravel.core.travel_planner.openai_engine import ChatGPTravelEngine from gptravel.core.travel_planner.prompt import PromptFactory -#from gptravel.core.travel_planner.travel_engine import TravelPlanJSON + +# from gptravel.core.travel_planner.travel_engine import TravelPlanJSON load_dotenv() diff --git a/src/gptravel/prototype/pages/__init__.py b/src/gptravel/prototype/pages/__init__.py index 2fb977f..7a2494e 100644 --- a/src/gptravel/prototype/pages/__init__.py +++ b/src/gptravel/prototype/pages/__init__.py @@ -1,3 +1,3 @@ from gptravel.core.services.geocoder import GeoCoder -geo_coder = GeoCoder() \ No newline at end of file +geo_coder = GeoCoder() diff --git a/src/gptravel/prototype/pages/travel.py b/src/gptravel/prototype/pages/travel.py index 7462e50..df00304 100644 --- a/src/gptravel/prototype/pages/travel.py +++ b/src/gptravel/prototype/pages/travel.py @@ -9,6 +9,8 @@ from streamlit_folium import st_folium from gptravel.core.io.loggerconfig import logger +from gptravel.core.services.checker import ExistingDestinationsChecker +from gptravel.core.services.filters import DeparturePlaceFilter from gptravel.core.services.geocoder import GeoCoder from gptravel.core.travel_planner import openai_engine from gptravel.core.travel_planner.prompt import Prompt, PromptFactory @@ -148,6 +150,10 @@ def _get_travel_plan( prompt=prompt, max_tokens=max_number_tokens ) logger.info("Generating Travel Plan: End") + travel_filter = DeparturePlaceFilter() + travel_filter.filter(travel_plan_json) + checker = ExistingDestinationsChecker(geocoder) + checker.check(travel_plan_json) travel_plan_dict = travel_plan_json.travel_plan score_dict = prototype_utils.get_score_map(travel_plan_json) diff --git a/src/gptravel/prototype/utils.py b/src/gptravel/prototype/utils.py index 886bab3..1ed5c80 100644 --- a/src/gptravel/prototype/utils.py +++ b/src/gptravel/prototype/utils.py @@ -1,7 +1,6 @@ from datetime import date from typing import Dict, List, Tuple, Union -import numpy as np import openai from gptravel.core.io.loggerconfig import logger diff --git a/tests/test_gptravel/test_core/test_services/test_filters.py b/tests/test_gptravel/test_core/test_services/test_filters.py new file mode 100644 index 0000000..5d11c94 --- /dev/null +++ b/tests/test_gptravel/test_core/test_services/test_filters.py @@ -0,0 +1,48 @@ +from copy import deepcopy + +import pytest + +from gptravel.core.services.filters import DeparturePlaceFilter, TravelPlanJSON + + +@pytest.fixture +def departure_place_filter() -> DeparturePlaceFilter: + return DeparturePlaceFilter() + + +class TestDeparturePlaceFilter: + def test_on_travel_plan_with_departure( + self, + departure_place_filter: DeparturePlaceFilter, + italian_travel_plan: TravelPlanJSON, + ): + before_removal = deepcopy(italian_travel_plan) + departure_place_filter.filter(italian_travel_plan) + assert before_removal.travel_plan != italian_travel_plan.travel_plan + assert ( + italian_travel_plan.departure_place not in italian_travel_plan.travel_cities + ) + assert italian_travel_plan.get_key_values_by_name("day") == [ + "Day 1", + "Day 2", + "Day 3", + ] + assert italian_travel_plan.travel_activities == [ + "See San Marco", + "Take a ride on gondola", + "Eat an arancina", + "Eat fiorentina", + ] + + def test_on_travel_plan_with_nodeparture( + self, + departure_place_filter: DeparturePlaceFilter, + travel_plan_single_city_per_day: TravelPlanJSON, + ): + before_removal = deepcopy(travel_plan_single_city_per_day) + departure_place_filter.filter(travel_plan_single_city_per_day) + assert ( + travel_plan_single_city_per_day.departure_place + not in travel_plan_single_city_per_day.travel_cities + ) + assert before_removal.travel_plan == travel_plan_single_city_per_day.travel_plan diff --git a/tests/test_gptravel/test_core/test_travel_planner/test_prompt.py b/tests/test_gptravel/test_core/test_travel_planner/test_prompt.py index 5141ba2..8b76536 100644 --- a/tests/test_gptravel/test_core/test_travel_planner/test_prompt.py +++ b/tests/test_gptravel/test_core/test_travel_planner/test_prompt.py @@ -1,7 +1,5 @@ from typing import Any, Dict -import pytest - from gptravel.core.travel_planner.prompt import ( PlainTravelPrompt, PromptFactory, diff --git a/tests/test_gptravel/test_core/test_travel_planner/test_tokenizer.py b/tests/test_gptravel/test_core/test_travel_planner/test_tokenizer.py index 2cd6dc2..0057774 100644 --- a/tests/test_gptravel/test_core/test_travel_planner/test_tokenizer.py +++ b/tests/test_gptravel/test_core/test_travel_planner/test_tokenizer.py @@ -13,7 +13,10 @@ def chatgpt_token_manager() -> ChatGptTokenManager: [(10, 100, 757), (0, 0, 383), (10, 5000, 756), (-100000000, 5, 383)], ) def test_chatgpt_token_manager( - chatgpt_token_manager: ChatGptTokenManager, n_days: int, distance: float, expected: int + chatgpt_token_manager: ChatGptTokenManager, + n_days: int, + distance: float, + expected: int, ): result = chatgpt_token_manager.get_number_tokens(n_days=n_days, distance=distance) assert result == pytest.approx(expected, 1) diff --git a/tests/test_gptravel/test_core/test_travel_planner/test_travel_engine.py b/tests/test_gptravel/test_core/test_travel_planner/test_travel_engine.py index fb4c1e0..7bec1bf 100644 --- a/tests/test_gptravel/test_core/test_travel_planner/test_travel_engine.py +++ b/tests/test_gptravel/test_core/test_travel_planner/test_travel_engine.py @@ -1,5 +1,3 @@ -import pytest - from gptravel.core.travel_planner.travel_engine import TravelPlanJSON diff --git a/tests/test_gptravel/test_core/test_utils/test_general.py b/tests/test_gptravel/test_core/test_utils/test_general.py index f6260f7..8199bde 100644 --- a/tests/test_gptravel/test_core/test_utils/test_general.py +++ b/tests/test_gptravel/test_core/test_utils/test_general.py @@ -1,7 +1,5 @@ import json -import pytest - from gptravel.core.utils.general import ( extract_inner_lists_from_json, extract_keys_by_depth_from_json, diff --git a/tests/test_gptravel/test_core/test_utils/test_regex_tool.py b/tests/test_gptravel/test_core/test_utils/test_regex_tool.py index d17b56f..a1ef9df 100644 --- a/tests/test_gptravel/test_core/test_utils/test_regex_tool.py +++ b/tests/test_gptravel/test_core/test_utils/test_regex_tool.py @@ -1,8 +1,6 @@ import json from inspect import cleandoc -import pytest - from gptravel.core.utils.regex_tool import JsonExtractor