-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #16 from RobertoCorti/bugfix/travel_cities_filter
- Loading branch information
Showing
12 changed files
with
127 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
from gptravel.core.io.loggerconfig import logger | ||
from gptravel.core.travel_planner.travel_engine import TravelPlanJSON | ||
|
||
|
||
class Filter(ABC): | ||
def __init__(self, **kwargs): | ||
pass | ||
|
||
@abstractmethod | ||
def filter(self, travel_plan: TravelPlanJSON) -> None: | ||
pass | ||
|
||
|
||
class DeparturePlaceFilter(Filter): | ||
def filter(self, travel_plan: TravelPlanJSON) -> None: | ||
logger.debug("DeparturePlaceFilter: Start") | ||
departure_place = travel_plan.departure_place.lower() | ||
# if the departure place is present in the travel plan then remove it | ||
if departure_place in [city.lower() for city in travel_plan.travel_cities]: | ||
logger.debug("Found {} inside the travel plan".format(departure_place)) | ||
day_depth = travel_plan.keys_map["day"] | ||
if day_depth == 0: | ||
days_to_drop = [] | ||
for day in travel_plan.travel_plan.keys(): | ||
key_to_remove = [ | ||
city | ||
for city in travel_plan.travel_plan[day].keys() | ||
if city.lower() == departure_place | ||
] | ||
if key_to_remove: | ||
logger.debug( | ||
"Removed {} from the travel plan for {}".format( | ||
departure_place, day | ||
) | ||
) | ||
del travel_plan.travel_plan[day][key_to_remove[0]] | ||
# if the day container is empty then remove it | ||
if travel_plan.travel_plan[day] == {}: | ||
days_to_drop.append(day) | ||
logger.debug( | ||
"Removed {} completely from the travel plan".format(day) | ||
) | ||
if days_to_drop: | ||
for day_to_delete in days_to_drop: | ||
del travel_plan.travel_plan[day_to_delete] | ||
# fix the order of the | ||
day_first_word = days_to_drop[0].split(" ")[0] | ||
day_keys = list(travel_plan.travel_plan.keys()) | ||
n = 1 | ||
for old_key in day_keys: | ||
new_key = day_first_word + " " + str(n) | ||
travel_plan.travel_plan[new_key] = travel_plan.travel_plan.pop( | ||
old_key | ||
) | ||
n += 1 | ||
else: | ||
key_to_remove = [ | ||
city | ||
for city in travel_plan.travel_plan.keys() | ||
if city.lower() == departure_place | ||
][0] | ||
del travel_plan.travel_plan[key_to_remove] | ||
logger.debug("DeparturePlaceFilter: End") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from gptravel.core.services.geocoder import GeoCoder | ||
|
||
geo_coder = GeoCoder() | ||
geo_coder = GeoCoder() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
48 changes: 48 additions & 0 deletions
48
tests/test_gptravel/test_core/test_services/test_filters.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from copy import deepcopy | ||
|
||
import pytest | ||
|
||
from gptravel.core.services.filters import DeparturePlaceFilter, TravelPlanJSON | ||
|
||
|
||
@pytest.fixture | ||
def departure_place_filter() -> DeparturePlaceFilter: | ||
return DeparturePlaceFilter() | ||
|
||
|
||
class TestDeparturePlaceFilter: | ||
def test_on_travel_plan_with_departure( | ||
self, | ||
departure_place_filter: DeparturePlaceFilter, | ||
italian_travel_plan: TravelPlanJSON, | ||
): | ||
before_removal = deepcopy(italian_travel_plan) | ||
departure_place_filter.filter(italian_travel_plan) | ||
assert before_removal.travel_plan != italian_travel_plan.travel_plan | ||
assert ( | ||
italian_travel_plan.departure_place not in italian_travel_plan.travel_cities | ||
) | ||
assert italian_travel_plan.get_key_values_by_name("day") == [ | ||
"Day 1", | ||
"Day 2", | ||
"Day 3", | ||
] | ||
assert italian_travel_plan.travel_activities == [ | ||
"See San Marco", | ||
"Take a ride on gondola", | ||
"Eat an arancina", | ||
"Eat fiorentina", | ||
] | ||
|
||
def test_on_travel_plan_with_nodeparture( | ||
self, | ||
departure_place_filter: DeparturePlaceFilter, | ||
travel_plan_single_city_per_day: TravelPlanJSON, | ||
): | ||
before_removal = deepcopy(travel_plan_single_city_per_day) | ||
departure_place_filter.filter(travel_plan_single_city_per_day) | ||
assert ( | ||
travel_plan_single_city_per_day.departure_place | ||
not in travel_plan_single_city_per_day.travel_cities | ||
) | ||
assert before_removal.travel_plan == travel_plan_single_city_per_day.travel_plan |
2 changes: 0 additions & 2 deletions
2
tests/test_gptravel/test_core/test_travel_planner/test_prompt.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 0 additions & 2 deletions
2
tests/test_gptravel/test_core/test_travel_planner/test_travel_engine.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,3 @@ | ||
import pytest | ||
|
||
from gptravel.core.travel_planner.travel_engine import TravelPlanJSON | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,6 @@ | ||
import json | ||
from inspect import cleandoc | ||
|
||
import pytest | ||
|
||
from gptravel.core.utils.regex_tool import JsonExtractor | ||
|
||
|
||
|