Skip to content

Commit

Permalink
Merge pull request #16 from RobertoCorti/bugfix/travel_cities_filter
Browse files Browse the repository at this point in the history
  • Loading branch information
stefano-polo committed Jul 12, 2023
2 parents fa5da43 + 092f714 commit 567e3db
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/gptravel/core/services/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ def check(self, travel_plan: TravelPlanJSON) -> bool:
if all_exists:
logger.debug("Check passed")
else:
logger.debug("Check not passed")
logger.warning("Check not passed")
return all_exists
65 changes: 65 additions & 0 deletions src/gptravel/core/services/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from abc import ABC, abstractmethod

from gptravel.core.io.loggerconfig import logger
from gptravel.core.travel_planner.travel_engine import TravelPlanJSON


class Filter(ABC):
def __init__(self, **kwargs):
pass

@abstractmethod
def filter(self, travel_plan: TravelPlanJSON) -> None:
pass


class DeparturePlaceFilter(Filter):
def filter(self, travel_plan: TravelPlanJSON) -> None:
logger.debug("DeparturePlaceFilter: Start")
departure_place = travel_plan.departure_place.lower()
# if the departure place is present in the travel plan then remove it
if departure_place in [city.lower() for city in travel_plan.travel_cities]:
logger.debug("Found {} inside the travel plan".format(departure_place))
day_depth = travel_plan.keys_map["day"]
if day_depth == 0:
days_to_drop = []
for day in travel_plan.travel_plan.keys():
key_to_remove = [
city
for city in travel_plan.travel_plan[day].keys()
if city.lower() == departure_place
]
if key_to_remove:
logger.debug(
"Removed {} from the travel plan for {}".format(
departure_place, day
)
)
del travel_plan.travel_plan[day][key_to_remove[0]]
# if the day container is empty then remove it
if travel_plan.travel_plan[day] == {}:
days_to_drop.append(day)
logger.debug(
"Removed {} completely from the travel plan".format(day)
)
if days_to_drop:
for day_to_delete in days_to_drop:
del travel_plan.travel_plan[day_to_delete]
# fix the order of the
day_first_word = days_to_drop[0].split(" ")[0]
day_keys = list(travel_plan.travel_plan.keys())
n = 1
for old_key in day_keys:
new_key = day_first_word + " " + str(n)
travel_plan.travel_plan[new_key] = travel_plan.travel_plan.pop(
old_key
)
n += 1
else:
key_to_remove = [
city
for city in travel_plan.travel_plan.keys()
if city.lower() == departure_place
][0]
del travel_plan.travel_plan[key_to_remove]
logger.debug("DeparturePlaceFilter: End")
3 changes: 2 additions & 1 deletion src/gptravel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from gptravel.core.services.scorer import TravelPlanScore
from gptravel.core.travel_planner.openai_engine import ChatGPTravelEngine
from gptravel.core.travel_planner.prompt import PromptFactory
#from gptravel.core.travel_planner.travel_engine import TravelPlanJSON

# from gptravel.core.travel_planner.travel_engine import TravelPlanJSON

load_dotenv()

Expand Down
2 changes: 1 addition & 1 deletion src/gptravel/prototype/pages/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from gptravel.core.services.geocoder import GeoCoder

geo_coder = GeoCoder()
geo_coder = GeoCoder()
6 changes: 6 additions & 0 deletions src/gptravel/prototype/pages/travel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from streamlit_folium import st_folium

from gptravel.core.io.loggerconfig import logger
from gptravel.core.services.checker import ExistingDestinationsChecker
from gptravel.core.services.filters import DeparturePlaceFilter
from gptravel.core.services.geocoder import GeoCoder
from gptravel.core.travel_planner import openai_engine
from gptravel.core.travel_planner.prompt import Prompt, PromptFactory
Expand Down Expand Up @@ -148,6 +150,10 @@ def _get_travel_plan(
prompt=prompt, max_tokens=max_number_tokens
)
logger.info("Generating Travel Plan: End")
travel_filter = DeparturePlaceFilter()
travel_filter.filter(travel_plan_json)
checker = ExistingDestinationsChecker(geocoder)
checker.check(travel_plan_json)
travel_plan_dict = travel_plan_json.travel_plan

score_dict = prototype_utils.get_score_map(travel_plan_json)
Expand Down
1 change: 0 additions & 1 deletion src/gptravel/prototype/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from datetime import date
from typing import Dict, List, Tuple, Union

import numpy as np
import openai

from gptravel.core.io.loggerconfig import logger
Expand Down
48 changes: 48 additions & 0 deletions tests/test_gptravel/test_core/test_services/test_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from copy import deepcopy

import pytest

from gptravel.core.services.filters import DeparturePlaceFilter, TravelPlanJSON


@pytest.fixture
def departure_place_filter() -> DeparturePlaceFilter:
return DeparturePlaceFilter()


class TestDeparturePlaceFilter:
def test_on_travel_plan_with_departure(
self,
departure_place_filter: DeparturePlaceFilter,
italian_travel_plan: TravelPlanJSON,
):
before_removal = deepcopy(italian_travel_plan)
departure_place_filter.filter(italian_travel_plan)
assert before_removal.travel_plan != italian_travel_plan.travel_plan
assert (
italian_travel_plan.departure_place not in italian_travel_plan.travel_cities
)
assert italian_travel_plan.get_key_values_by_name("day") == [
"Day 1",
"Day 2",
"Day 3",
]
assert italian_travel_plan.travel_activities == [
"See San Marco",
"Take a ride on gondola",
"Eat an arancina",
"Eat fiorentina",
]

def test_on_travel_plan_with_nodeparture(
self,
departure_place_filter: DeparturePlaceFilter,
travel_plan_single_city_per_day: TravelPlanJSON,
):
before_removal = deepcopy(travel_plan_single_city_per_day)
departure_place_filter.filter(travel_plan_single_city_per_day)
assert (
travel_plan_single_city_per_day.departure_place
not in travel_plan_single_city_per_day.travel_cities
)
assert before_removal.travel_plan == travel_plan_single_city_per_day.travel_plan
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Any, Dict

import pytest

from gptravel.core.travel_planner.prompt import (
PlainTravelPrompt,
PromptFactory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ def chatgpt_token_manager() -> ChatGptTokenManager:
[(10, 100, 757), (0, 0, 383), (10, 5000, 756), (-100000000, 5, 383)],
)
def test_chatgpt_token_manager(
chatgpt_token_manager: ChatGptTokenManager, n_days: int, distance: float, expected: int
chatgpt_token_manager: ChatGptTokenManager,
n_days: int,
distance: float,
expected: int,
):
result = chatgpt_token_manager.get_number_tokens(n_days=n_days, distance=distance)
assert result == pytest.approx(expected, 1)
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pytest

from gptravel.core.travel_planner.travel_engine import TravelPlanJSON


Expand Down
2 changes: 0 additions & 2 deletions tests/test_gptravel/test_core/test_utils/test_general.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import json

import pytest

from gptravel.core.utils.general import (
extract_inner_lists_from_json,
extract_keys_by_depth_from_json,
Expand Down
2 changes: 0 additions & 2 deletions tests/test_gptravel/test_core/test_utils/test_regex_tool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import json
from inspect import cleandoc

import pytest

from gptravel.core.utils.regex_tool import JsonExtractor


Expand Down

0 comments on commit 567e3db

Please sign in to comment.