Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Entity Recognizer #26

Merged
merged 10 commits into from
Sep 18, 2023
Merged
1,481 changes: 1,023 additions & 458 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.8,<3.9.7"
wikipedia-api = "^0.5.8"
geopy = "^2.3.0"
poetry-dotenv-plugin = "^0.1.0"
openai = "^0.27.2"
Expand All @@ -19,7 +18,11 @@ python-tsp = "^0.3.1"
streamlit = "^1.22.0"
folium = "^0.14.0"
streamlit-folium = "^0.11.1"
spacy = "^3.5.3"
wikipedia-api = "^0.6.0"

[tool.poetry.dependencies.en_core_web_md]
url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.6.0/en_core_web_md-3.6.0.tar.gz"

[tool.poetry.group.dev.dependencies]
black = "^23.1.0"
Expand Down
6 changes: 3 additions & 3 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ exclude_lines =
# Don't complain if tests don't hit defensive assertion code:
raise NotImplementedError

omit = src/gptravel/main.py

ignore_errors = True

show_missing = True
show_missing = True

omit = __init__.py, main.py
70 changes: 42 additions & 28 deletions src/gptravel/core/services/engine/classifier.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import requests
from dotenv import load_dotenv
Expand All @@ -26,7 +26,7 @@ def multi_label(self, multi_label: bool) -> None:
@abstractmethod
def predict(
self, input_text_list: List[str], label_classes: List[str]
) -> Dict[str, Dict[str, float]]:
) -> Optional[Dict[str, Dict[str, float]]]:
pass


Expand All @@ -39,28 +39,23 @@ def _query(self, payload: Dict[str, Any], api_url: str) -> Dict[str, Any]:
headers = {"Authorization": f"Bearer {self._api_token}"}
logger.debug("HuggingFace API fetching response: start")
response = requests.post(
api_url, headers=headers, json=payload, timeout=20
api_url, headers=headers, json=payload, timeout=50
).json()
logger.debug("HuggingFace API fetching response: complete")
if isinstance(response, dict):
logger.error(
"Hugging Face classifier: error in retrieving API response from %s",
api_url,
)
logger.error("API respone: %s", response)
raise HuggingFaceError
return response

def _query_on_different_apis(self, payload: Dict[str, Any]) -> Dict[str, Any]:
api_urls = [
"https://api-inference.huggingface.co/models/facebook/bart-large-mnli",
"https://api-inference.huggingface.co/models/MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
"https://api-inference.huggingface.co/models/joeddav/xlm-roberta-large-xnli",
]
for api_url in api_urls:
response = self._query(payload=payload, api_url=api_url)
if isinstance(response, list):
logger.debug("Using response from API url: %s", api_url)
return response
return response

def predict(
def _predict(
self,
input_text_list: List[str],
label_classes: List[str],
api_url: str,
) -> Dict[str, Dict[str, float]]:
payload = {
"inputs": input_text_list,
Expand All @@ -69,15 +64,34 @@ def predict(
"multi_label": self._multi_label,
},
}
try:
response = self._query_on_different_apis(payload=payload)
return {
item["sequence"]: {
label: float(value)
for label, value in zip(item["labels"], item["scores"])
}
for item in response
response = self._query(payload=payload, api_url=api_url)
return {
item["sequence"]: {
label: float(value)
for label, value in zip(item["labels"], item["scores"])
}
except Exception as exc:
logger.error("Hugging Face classifier: error in retrieving API response")
raise HuggingFaceError from exc
for item in response
}

def predict(
self, input_text_list: List[str], label_classes: List[str]
) -> Optional[Dict[str, Dict[str, float]]]:
api_url_list = [
"https://api-inference.huggingface.co/models/facebook/bart-large-mnli",
"https://api-inference.huggingface.co/models/MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
"https://api-inference.huggingface.co/models/joeddav/xlm-roberta-large-xnli",
]
for api_url in api_url_list:
try:
output = self._predict(
input_text_list=input_text_list,
label_classes=label_classes,
api_url=api_url,
)
logger.debug("Using response from API url: %s", api_url)
return output
except HuggingFaceError:
pass
except requests.exceptions.ReadTimeout:
pass
return None
44 changes: 44 additions & 0 deletions src/gptravel/core/services/engine/entity_recognizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Dict, Optional

import spacy
from spacy.language import Language

from gptravel.core.io.loggerconfig import logger

RECOGNIZED_ENTITIES_CACHE = {}


class EntityRecognizer:
    """Named-entity recognizer backed by a spaCy trained pipeline.

    Loading failures are tolerated: when the pipeline is unavailable the
    recognizer stays usable and `recognize` simply returns None.
    """

    def __init__(self, trained_pipeline: str = "en_core_web_md") -> None:
        # Only NER is required, so every unrelated pipeline component is
        # disabled to keep loading and inference cheap.
        unused_components = [
            "tok2vec",
            "tagger",
            "parser",
            "attribute_ruler",
            "lemmatizer",
        ]
        self._nlp = None
        try:
            self._nlp = spacy.load(trained_pipeline, disable=unused_components)
        except OSError:
            logger.warning("%s trained pipeline is not available", trained_pipeline)

    @property
    def nlp(self) -> Optional[Language]:
        # None when the trained pipeline could not be loaded.
        return self._nlp

    def explain(self, code_class: str) -> Optional[str]:
        """Return spaCy's human-readable description of an entity label."""
        return spacy.explain(code_class)

    def recognize(self, input_string: str) -> Optional[Dict[str, str]]:
        """Map entity text -> entity label for input_string.

        Returns None when the pipeline is unavailable or no entity is found.
        Results are memoized in the module-level cache.
        """
        if self._nlp is None:
            return None
        cached_entities = RECOGNIZED_ENTITIES_CACHE.get(input_string)
        if cached_entities is None:
            cached_entities = self._nlp(input_string).ents
            RECOGNIZED_ENTITIES_CACHE[input_string] = cached_entities
        if cached_entities:
            return {entity.text: entity.label_ for entity in cached_entities}
        return None
17 changes: 17 additions & 0 deletions src/gptravel/core/services/engine/wikipedia.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Optional

from wikipediaapi import Wikipedia, WikipediaPage


class WikipediaEngine:
    """Minimal lookup facade over the English Wikipedia API."""

    def __init__(self) -> None:
        # The first argument is the user-agent identifying this client.
        self._wiki_wiki = Wikipedia("Gptravel project", "en")

    def _page(self, title_page: str) -> WikipediaPage:
        # Single fetch point so all lookups share one configured client.
        return self._wiki_wiki.page(title=title_page)

    def url(self, title_page: str) -> Optional[str]:
        """Return the full URL of the page titled *title_page*, or None."""
        page = self._page(title_page)
        try:
            full_url = page.fullurl
        except KeyError:
            # wikipediaapi raises KeyError when the page does not exist.
            return None
        return full_url
16 changes: 11 additions & 5 deletions src/gptravel/core/services/geocoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,21 @@ def location_coordinates(self, location_name: str) -> Dict[str, Optional[float]]
return {"lat": fetched_location.latitude, "lon": fetched_location.longitude}
return {"lat": None, "lon": None}

def location_distance(self, location_name_1: str, location_name_2: str) -> float:
def location_distance(
self, location_name_1: str, location_name_2: str
) -> Optional[float]:
if location_name_1.lower() == location_name_2.lower():
return 0.0
location1_coords = self.location_coordinates(location_name_1)
location2_coords = self.location_coordinates(location_name_2)
return GRC(
(location1_coords["lat"], location1_coords["lon"]),
(location2_coords["lat"], location2_coords["lon"]),
).km
if (location1_coords["lat"] is not None) & (
location2_coords["lat"] is not None
):
return GRC(
(location1_coords["lat"], location1_coords["lon"]),
(location2_coords["lat"], location2_coords["lon"]),
).km
return None

def is_location_country_city_state(self, location_name: str) -> bool:
location_type = self._location_type(location_name)
Expand Down
5 changes: 5 additions & 0 deletions src/gptravel/core/services/score_builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from gptravel.core.services.engine.classifier import TextClassifier
from gptravel.core.services.engine.entity_recognizer import EntityRecognizer
from gptravel.core.services.geocoder import GeoCoder
from gptravel.core.services.scorer import (
ActivitiesDiversityScorer,
ActivityPlacesScorer,
CitiesCountryScorer,
DayGenerationScorer,
OptimizedItineraryScorer,
Expand All @@ -14,6 +16,9 @@ class ScorerOrchestrator:
def __init__(self, geocoder: GeoCoder, text_classifier: TextClassifier) -> None:
self._scorers = [
ActivitiesDiversityScorer(text_classifier),
ActivityPlacesScorer(
geolocator=geocoder, entity_recognizer=EntityRecognizer()
),
DayGenerationScorer(),
CitiesCountryScorer(geolocator=geocoder),
OptimizedItineraryScorer(geolocator=geocoder),
Expand Down
66 changes: 65 additions & 1 deletion src/gptravel/core/services/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from gptravel.core.io.loggerconfig import logger
from gptravel.core.services.config import ACTIVITIES_LABELS
from gptravel.core.services.engine.classifier import TextClassifier
from gptravel.core.services.engine.entity_recognizer import EntityRecognizer
from gptravel.core.services.engine.exception import HuggingFaceError
from gptravel.core.services.engine.tsp_solver import TSPSolver
from gptravel.core.services.geocoder import GeoCoder
Expand Down Expand Up @@ -37,7 +38,7 @@ def score_weight_key(self) -> str:
return self._score_weight_key

@property
def score_map(self) -> Dict[str, Dict[str, Union[float, int]]]:
def score_map(self) -> Dict[str, Dict[str, Any]]:
return self._score_map

@property
Expand Down Expand Up @@ -96,6 +97,8 @@ def score(
labeled_activities = self._classifier.predict(
input_text_list=activities_list, label_classes=self._activities_labels
)
if labeled_activities is None:
raise HuggingFaceError
aggregated_scores = {
key: sum(item[key] for item in labeled_activities.values())
for key in self._activities_labels
Expand Down Expand Up @@ -314,3 +317,64 @@ def score(
logger.debug("CitiesCountryScorer: End")
else:
logger.debug("CitiesCountryScorer: End -- No Computation needed")


class ActivityPlacesScorer(ScoreService):
    """Score how many recognized activity places plausibly exist in their city.

    Named entities are extracted from each activity; an entity counts as
    "existing" when geocoding "<entity>, <city>" yields the same country as
    geocoding the city itself.
    """

    def __init__(
        self,
        geolocator: GeoCoder,
        entity_recognizer: EntityRecognizer,
        score_weight: float = 1.0,
    ) -> None:
        service_name = "Activity Places"
        super().__init__(service_name, score_weight)
        self._geolocator = geolocator
        self._er = entity_recognizer

    def score(
        self, travel_plan: TravelPlanJSON, travel_plan_scores: TravelPlanScore
    ) -> None:
        """Compute the score and register it on *travel_plan_scores*.

        The score is existing_places / total_entities; the recognized
        entities are stored alongside the score as metadata.
        """
        logger.debug("ActivityPlacesScorer: Start")
        logger.debug("Start recognizing entities in travel activities")
        # Run the recognizer once per activity; the original ran it twice
        # (once in the comprehension filter and once for the value).
        recognized_places: Dict[str, List[Dict[str, str]]] = {}
        for city in travel_plan.travel_cities:
            recognitions = (
                self._er.recognize(activity)
                for activity in travel_plan.get_travel_activities_from_city(city)
            )
            recognized_places[city] = [
                recognized for recognized in recognitions if recognized is not None
            ]
        logger.debug("Recognized entities: %s", recognized_places)
        # Check whether those places exist in the corresponding city.
        logger.debug("Start geolocalization of found entities")
        existing_places = 0.0
        total_entities = 0.0
        for city, entities in recognized_places.items():
            for entity in entities:
                total_entities += 1
                entity_name = next(iter(entity)) + ", " + city
                entity_country = self._geolocator.country_from_location_name(
                    entity_name
                )
                if entity_country:
                    if entity_country == self._geolocator.country_from_location_name(
                        city
                    ):
                        logger.debug("The entity %s exists", entity_name)
                        existing_places += 1
                    else:
                        # logger.warning replaces the deprecated logger.warn
                        logger.warning("The entity %s does not exist", entity_name)
        # Guard against ZeroDivisionError when no entity was recognized; a
        # neutral full score is assumed in that case -- TODO confirm the
        # desired fallback value with the team.
        entity_scores = (
            existing_places / total_entities if total_entities > 0 else 1.0
        )
        logger.debug("ActivityPlacesScorer: score value = %f", entity_scores)
        logger.debug("ActivityPlacesScorer: weight value = %f", self._score_weight)
        travel_plan_scores.add_score(
            score_type=self._service_name,
            score_dict={
                travel_plan_scores.score_value_key: entity_scores,
                travel_plan_scores.score_weight_key: self._score_weight,
                "recognized_entities": recognized_places,
            },
        )
        logger.debug("ActivityPlacesScorer: End")
4 changes: 4 additions & 0 deletions src/gptravel/core/travel_planner/travel_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from gptravel.core.travel_planner.prompt import Prompt
from gptravel.core.utils.general import (
extract_inner_list_for_given_key,
extract_inner_lists_from_json,
extract_keys_by_depth_from_json,
)
Expand Down Expand Up @@ -58,6 +59,9 @@ def get_key_values_by_name(self, key_name: str) -> List[Any]:
except KeyError:
return []

def get_travel_activities_from_city(self, city_name: str) -> List[Any]:
return extract_inner_list_for_given_key(self._travel_plan, key=city_name)

@property
def travel_cities(self) -> List[str]:
return self.get_key_values_by_name("city")
Expand Down
32 changes: 28 additions & 4 deletions src/gptravel/core/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,34 @@ def _extract_keys(json_obj: Dict[Any, Any], curr_depth: int) -> None:


def extract_inner_lists_from_json(json_obj: Dict[Any, Any]) -> List[Any]:
    """Recursively collect the items of every list nested inside *json_obj*.

    Args:
        json_obj: A JSON-like structure (dicts of dicts/lists/scalars).

    Returns:
        A flat list with the items of every list found at any depth;
        scalars that are not inside a list are ignored.
    """
    inner_list: List[Any] = []
    if isinstance(json_obj, dict):
        for value in json_obj.values():
            inner_list.extend(extract_inner_lists_from_json(value))
    elif isinstance(json_obj, list):
        inner_list.extend(json_obj)
    return inner_list


def extract_inner_list_for_given_key(json_obj: Dict[Any, Any], key: Any) -> List[Any]:
    """Extract and flatten all lists stored under *key* at any nesting depth.

    Args:
        json_obj (dict): The input dictionary; may contain nested dicts.
        key: The key whose list values should be collected.

    Returns:
        list: Concatenation of every list found under *key*, in document
        order; an empty list when the key is absent or never maps to a list.
    """
    # Lists found at this level (only counted when the value is really a list).
    lists = (
        [json_obj[key]] if key in json_obj and isinstance(json_obj[key], list) else []
    )
    # Recurse only into dict values; other value types cannot hold the key.
    nested_lists = [
        extract_inner_list_for_given_key(value, key)
        for value in json_obj.values()
        if isinstance(value, dict)
    ]
    return [item for sublist in lists + nested_lists for item in sublist]
Loading