Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -158,6 +158,7 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
 .vscode/
+.ruff_cache/
 temp_data/
 src/gen_paths/playground.py
 data/
6 changes: 0 additions & 6 deletions eval_requirements.txt

This file was deleted.

10 changes: 0 additions & 10 deletions infer_requirements.txt

This file was deleted.

5 changes: 1 addition & 4 deletions mango/__init__.py
@@ -1,6 +1,3 @@
-import logging
-import os
-
-import packaging.version
+"""MANGO: A Benchmark for Evaluating Mapping and Navigation Abilities of Large Language Models."""

 __version__ = "1.0.0"
33 changes: 16 additions & 17 deletions mango/cli.py
@@ -6,40 +6,39 @@
     evaluate_model_route_finding,
 )

+
 def main():
-    parser = argparse.ArgumentParser(description='CLI for Mango Evaluation')
+    parser = argparse.ArgumentParser(description="CLI for Mango Evaluation")
     parser.add_argument(
-        '--mode',
-        choices=['rf', 'df'],
-        help='Evaluation mode: "route_finding" or "dest_finding"'
+        "--mode",
+        choices=["rf", "df"],
+        help='Evaluation mode: "route_finding" or "dest_finding"',
     )
     parser.add_argument(
-        '--rst-dir',
+        "--rst-dir",
         type=str,
-        help='Directory with result files (e.g., output from the model)'
+        help="Directory with result files (e.g., output from the model)",
     )
     parser.add_argument(
-        '--map-dir',
-        type=str,
-        default='./data',
-        help='Directory with map files'
+        "--map-dir", type=str, default="./data", help="Directory with map files"
     )
     parser.add_argument(
-        '--output-dir',
-        nargs='?',
-        default='./output',
-        help='Directory to save the evaluation summaries (default: ./output)'
+        "--output-dir",
+        nargs="?",
+        default="./output",
+        help="Directory to save the evaluation summaries (default: ./output)",
    )

     args = parser.parse_args()
     rst_dir_name = os.path.basename(args.rst_dir)
     args.output_dir = os.path.join(args.output_dir, rst_dir_name)
-    if args.mode == 'rf':
+    if args.mode == "rf":
         evaluate_model_route_finding(args.rst_dir, args.map_dir, args.output_dir)
-    elif args.mode == 'df':
+    elif args.mode == "df":
         evaluate_model_dest_finding(args.rst_dir, args.map_dir, args.output_dir)
     else:
         parser.print_help()

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
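For orientation, a minimal sketch of driving the reformatted CLI from Python rather than the shell; `./results/my_model` is a hypothetical model-output directory, and `--map-dir`/`--output-dir` fall back to their `./data` and `./output` defaults:

```python
import sys

from mango.cli import main

# Equivalent to the shell invocation:
#   mango --mode rf --rst-dir ./results/my_model
# "./results/my_model" is a hypothetical placeholder directory.
sys.argv = ["mango", "--mode", "rf", "--rst-dir", "./results/my_model"]
main()
```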
65 changes: 33 additions & 32 deletions mango/evaluation/map_evaluator.py
@@ -2,13 +2,13 @@
 import os
 from collections import deque
 from enum import IntEnum
-from typing import Optional

 from mango.evaluation.config import KEY_MAPPING
 from mango.evaluation.utils import edit_distance, parse_raw_output
 from mango.evaluation.utils.map_utils import get_game_info_with_G_eval
 import sys
 import traceback
+

 class EvalMetric(IntEnum):
     Strict = 0
     Loose = 1
@@ -20,7 +20,7 @@ class TaskType(IntEnum):


 class MapEvaluator:
-    def __init__(self, map_dir: str, map_name: str, key_mapping: Optional[dict] = None):
+    def __init__(self, map_dir: str, map_name: str, key_mapping: dict | None = None):
         (
             self.G_eval,
             self.G,
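One non-cosmetic detail in the hunk above: `dict | None` (PEP 604) is evaluated when the `def` statement runs, so it raises `TypeError` on Python 3.9 and earlier unless annotation evaluation is deferred. A minimal illustration with hypothetical helpers:

```python
from __future__ import annotations  # defers evaluation, so PEP 604 syntax is safe on 3.7+

from typing import Optional


def before(key_mapping: Optional[dict] = None) -> dict:
    """Pre-PR spelling; runs on any supported Python."""
    return key_mapping or {}


def after(key_mapping: dict | None = None) -> dict:
    """Post-PR spelling; without the __future__ import this needs Python 3.10+."""
    return key_mapping or {}
```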
@@ -86,9 +86,9 @@ def eval_actions(self, path, src_node, dst_node, metric):
                 key=lambda v: self.matching_score(node_action, v, metric),
             )
             path_actions.append(action)

         dest_nodes = self.bfs_get_multi_des(src_node, path_actions)

         return (
             max([self.matching_score(dst_node, node, metric) for node in dest_nodes])
             if dest_nodes
@@ -110,18 +110,15 @@ def eval_full_path(self, path, src_node, dst_node):
             location_after = edge[location_after_key].strip()
             action = edge[action_key].strip()
             if not self.G_eval.has_edge(location_before, location_after):
-
                 return 0
             if not any(
                 action == data.get(action_key)
                 for _, _, data in self.G_eval.edges(location_before, data=True)
                 if _ == location_after
             ):
-
                 return 0
         for i in range(len(path) - 1):
             if path[i][location_after_key] != path[i + 1][location_before_key]:
-
                 return 0
         return 1

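Both `eval_actions` above and `eval_dest_finding` below lean on `bfs_get_multi_des`, which this diff never shows. A hypothetical re-implementation, for orientation only (the real method lives elsewhere in `MapEvaluator`, and the `"action"` edge-attribute name is an assumption):

```python
import networkx as nx


def bfs_get_multi_des_sketch(G: nx.MultiDiGraph, src_node: str, actions: list[str]) -> set[str]:
    """Follow an action sequence from src_node, keeping every node reachable
    at each step; parallel edges sharing an action label can fan out."""
    frontier = {src_node}
    for action in actions:
        next_frontier = {
            dst
            for node in frontier
            for _, dst, data in G.out_edges(node, data=True)
            if data.get("action") == action  # attribute name assumed
        }
        if not next_frontier:
            return set()  # the action sequence falls off the map
        frontier = next_frontier
    return frontier
```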
@@ -138,36 +135,38 @@ def eval_route_finding(self, file_path):
             "is_easy": 0,
             "is_hard": 0,
         }
-        with open(file_path, "r") as f:
+        with open(file_path) as f:
             try:
                 infer_rst = json.load(f)
-            except Exception as e:
+            except Exception:
                 return_rst["parsing_error"] = 1
-                print('file parsing error: ',file_path)
+                print("file parsing error: ", file_path)
                 return return_rst
         try:
             sample_id = infer_rst[key_mapping["sample_id"]]
-            game_name = infer_rst[key_mapping["game_name"]]
+            _game_name = infer_rst[key_mapping["game_name"]]  # noqa: F841
             src_node = infer_rst[key_mapping["src_node"]]
             dst_node = infer_rst[key_mapping["dst_node"]]
             model_cutoff_num = infer_rst[key_mapping["model_cutoff_num"]]
-            min_step_total_answerable = infer_rst[key_mapping["min_step_total_answerable"]]
-            answerable = infer_rst[key_mapping["answerable"]]
+            min_step_total_answerable = infer_rst[
+                key_mapping["min_step_total_answerable"]
+            ]
+            _answerable = infer_rst[key_mapping["answerable"]]  # noqa: F841
             response = infer_rst[key_mapping["response"]]
-        except Exception as e:
+        except Exception:
             error_info = traceback.format_exc()
-            print('file parsing error: ',file_path,error_info)
+            print("file parsing error: ", file_path, error_info)
             return_rst["parsing_error"] = 1
             return return_rst
-        parsed_response,error_info = self.parse_llm_raw_output(response)
-
+        parsed_response, error_info = self.parse_llm_raw_output(response)
         if error_info is not None:
-            print('file parsing error: ',file_path,error_info)
+            print("file parsing error: ", file_path, error_info)
             return_rst["parsing_error"] = 1
             return return_rst
         if "[" in response and "]" in response:
             if parsed_response is None:
-                print('file parsing error: ',file_path)
+                print("file parsing error: ", file_path)
         if sample_id not in self.all_pairs.keys():
             return_rst["id_error"] = 1
             return return_rst
@@ -224,31 +223,33 @@ def eval_dest_finding(self, file_path):
             "is_easy": 0,
             "is_hard": 0,
         }
-        with open(file_path, "r") as f:
+        with open(file_path) as f:
             try:
                 infer_rst = json.load(f)
-            except Exception as e:
-                print('file parsing error: ',file_path)
+            except Exception:
+                print("file parsing error: ", file_path)
                 return_rst["parsing_error"] = 1
                 return return_rst
         try:
             sample_id = infer_rst[key_mapping["sample_id"]]
-            game_name = infer_rst[key_mapping["game_name"]]
+            _game_name = infer_rst[key_mapping["game_name"]]  # noqa: F841
             src_node = infer_rst[key_mapping["src_node"]]
-            dst_node = infer_rst[key_mapping["dst_node"]]
+            _dst_node = infer_rst[key_mapping["dst_node"]]  # noqa: F841
             model_cutoff_num = infer_rst[key_mapping["model_cutoff_num"]]
-            min_step_total_answerable = infer_rst[key_mapping["min_step_total_answerable"]]
-            answerable = infer_rst[key_mapping["answerable"]]
+            min_step_total_answerable = infer_rst[
+                key_mapping["min_step_total_answerable"]
+            ]
+            _answerable = infer_rst[key_mapping["answerable"]]  # noqa: F841
             response = infer_rst[key_mapping["response"]]
             action_list = infer_rst[key_mapping["action_list"]]
-        except Exception as e:
+        except Exception:
             error_info = traceback.format_exc()
-            print('file parsing error: ',file_path,error_info)
+            print("file parsing error: ", file_path, error_info)
             return_rst["parsing_error"] = 1
             return return_rst
-        parsed_response,error_info = self.parse_llm_raw_output(response)
+        parsed_response, error_info = self.parse_llm_raw_output(response)
         if error_info is not None:
-            print('file parsing error: ',file_path,error_info)
+            print("file parsing error: ", file_path, error_info)
             return_rst["parsing_error"] = 1
             return return_rst
         if sample_id not in self.all2all.keys():
@@ -273,7 +274,7 @@ def eval_dest_finding(self, file_path):
                 return_rst["is_easy"] = 1

         dst_nodes = self.bfs_get_multi_des(src_node, action_list)
-        #print(dst_nodes)
+        # print(dst_nodes)
         return_rst["loose_score"] = (
             max(
                 [
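Putting the evaluator together, a hedged usage sketch; the map name and file paths are hypothetical and must match the on-disk layout that `get_game_info_with_G_eval` expects under the map directory:

```python
from mango.evaluation.map_evaluator import MapEvaluator

# "zork1" and the result path are hypothetical placeholders.
evaluator = MapEvaluator(map_dir="./data", map_name="zork1")
rst = evaluator.eval_dest_finding("./results/my_model/zork1.dest_finding.0.json")
print(rst["loose_score"], rst["is_easy"], rst["parsing_error"])
```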
33 changes: 22 additions & 11 deletions mango/evaluation/scripts/evaluate.py
@@ -2,8 +2,6 @@
 import os
 from multiprocessing import Pool, cpu_count

-import fire
-
 from mango.evaluation.map_evaluator import MapEvaluator, TaskType


@@ -57,7 +55,7 @@ def evaluate_model(rst_dir, map_dir, output_dir, task_type):
     for map_name in map_names:
         summary_file = os.path.join(output_dir, f"{map_name}_summary.json")
         if os.path.exists(summary_file):
-            with open(summary_file, "r") as f:
+            with open(summary_file) as f:
                 all_summaries[map_name] = json.load(f)

     easy_all = [v["easy"] for k, v in all_summaries.items()]
@@ -90,7 +88,7 @@ def get_final_summary(list_of_dict):

     sum_up_key = [k for k, v in list_of_dict[0].items() if not k.endswith("average")]
     avg_by_total_keys = [
-        k for k, v in list_of_dict[0].items() if "error" in k and not k in sum_up_key
+        k for k, v in list_of_dict[0].items() if "error" in k and k not in sum_up_key
     ]
     avg_by_valid_keys = [
         k
@@ -143,11 +141,24 @@ def evaluate_model_route_finding(rst_dir, map_dir, output_dir):


 if __name__ == "__main__":
-    fire.Fire(
-        {
-            "evaluate_map_dest_finding": evaluate_map_dest_finding,
-            "evaluate_model_dest_finding": evaluate_model_dest_finding,
-            "evaluate_map_route_finding": evaluate_map_route_finding,
-            "evaluate_model_route_finding": evaluate_model_route_finding,
-        }
-    )
+    import argparse
+
+    parser = argparse.ArgumentParser(description="MANGO evaluation scripts")
+    parser.add_argument(
+        "command",
+        choices=[
+            "evaluate_map_dest_finding",
+            "evaluate_model_dest_finding",
+            "evaluate_map_route_finding",
+            "evaluate_model_route_finding",
+        ],
+    )
+    parser.add_argument("args", nargs="*")
+    args = parser.parse_args()
+    func = {
+        "evaluate_map_dest_finding": evaluate_map_dest_finding,
+        "evaluate_model_dest_finding": evaluate_model_dest_finding,
+        "evaluate_map_route_finding": evaluate_map_route_finding,
+        "evaluate_model_route_finding": evaluate_model_route_finding,
+    }[args.command]
+    func(*args.args)
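The `fire` dependency is gone, but the stdlib dispatcher above keeps the same positional call shape: a subcommand name followed by the chosen function's positional arguments. A usage sketch (all directories are hypothetical placeholders):

```python
import subprocess

# Old: fire exposed each function as a subcommand. The argparse version
# keeps that shape: <command> then the function's positional args.
subprocess.run(
    [
        "python",
        "mango/evaluation/scripts/evaluate.py",
        "evaluate_model_route_finding",
        "./results/my_model",  # rst_dir (hypothetical)
        "./data",              # map_dir
        "./output",            # output_dir
    ],
    check=True,
)
```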
3 changes: 2 additions & 1 deletion mango/evaluation/utils/__init__.py
@@ -1 +1,2 @@
-from mango.evaluation.utils.utils import edit_distance, parse_raw_output
+from mango.evaluation.utils.utils import edit_distance as edit_distance
+from mango.evaluation.utils.utils import parse_raw_output as parse_raw_output
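The redundant-looking `import X as X` is the established idiom (PEP 484's explicit re-export convention, which ruff's F401 honors in `__init__.py` files) for marking a name as an intentional part of the package's public surface. A sketch with a hypothetical package:

```python
# mypkg/__init__.py (hypothetical package layout)

# Plain form: linters and type checkers may flag this as an unused
# import rather than a public re-export.
from mypkg._impl import helper

# Aliased form: explicitly marks `helper` as an intentional re-export.
from mypkg._impl import helper as helper
```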
12 changes: 6 additions & 6 deletions mango/evaluation/utils/map_utils.py
@@ -12,22 +12,22 @@ def get_game_info(map_dir: str, game_name: str):
     locations_path = osp.join(map_dir, game_name, f"{game_name}.locations.json")
     walkthrough_path = osp.join(map_dir, game_name, f"{game_name}.walkthrough")
     all2all = {}
-    with open(all2all_path, "r") as f:
+    with open(all2all_path) as f:
         for line in f:
             data = json.loads(line)
             all2all[data["id"]] = data
     all_pairs = {}
-    with open(all_pairs_path, "r") as f:
+    with open(all_pairs_path) as f:
         for line in f:
             data = json.loads(line)
             all_pairs[data["id"]] = data
-    with open(edges_path, "r") as f:
+    with open(edges_path) as f:
         edges = json.load(f)
-    with open(actions_path, "r") as f:
+    with open(actions_path) as f:
         actions = json.load(f)
-    with open(locations_path, "r") as f:
+    with open(locations_path) as f:
         locations = json.load(f)
-    with open(walkthrough_path, "r") as f:
+    with open(walkthrough_path) as f:
         walkthrough = f.read()

     G = nx.MultiDiGraph()
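The repeated `open(..., "r")` to `open(...)` edits are the same cleanup throughout the PR, likely ruff's pyupgrade rule for redundant open modes: `"r"` (read, text) is already `open()`'s default. A tiny illustration (`example.json` is a hypothetical placeholder file):

```python
# Equivalent calls; "r" is open()'s default mode.
with open("example.json") as f:       # after this PR
    after_text = f.read()
with open("example.json", "r") as f:  # before this PR, explicit but redundant
    before_text = f.read()
assert before_text == after_text
```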