In [94]:
import sys
sys.path.append('pizza/utils')
from trees import TopSemanticTree, ExpressSemanticTree
from semantic_matchers import is_unordered_exact_match, is_semantics_only_unordered_exact_match, \
                              is_semantics_only_unordered_exact_match_post_ER, is_unordered_exact_match_post_ER, is_semantics_only_unordered_exact_match_post_ER_top_top
from entity_resolution import PizzaSkillEntityResolver

In [95]:
# Sample pair used to try out the matcher in the next cell.
# First string: a bracketed ORDER parse that keeps the surrounding utterance words
# next to the labelled nodes (NUMBER, SIZE, TOPPING, STYLE and a negated TOPPING).
top_pred_string_same = "(ORDER good afternoon i'm in the mood for (PIZZAORDER (NUMBER a ) (SIZE medium ) pizza i'd love (TOPPING pineapple ) on it and i love (STYLE thin crust ) please do not put any (NOT (TOPPING ham ) ) ) on there i don't like ham on pizza )"
# Second string: a much smaller parse with only SIZE and TOPPING nodes; note the
# TOPPING value here ("love") differs from the one in the first string ("pineapple").
exr_string = " (ORDER (PIZZAORDER (SIZE medium ) (TOPPING love ) ) )"

In [96]:
# The matcher takes the two parse strings plus an entity-resolver instance.
resolver = PizzaSkillEntityResolver()
# Bare last expression so the notebook displays the returned score for the sample pair.
is_semantics_only_unordered_exact_match_post_ER_top_top(top_pred_string_same, exr_string, resolver)

0.4166666666666667

In [97]:
import csv
import json
import os

def json_to_csv(json_path, output_folder=None, src_key='test.SRC', top_key='test.TOP'):
    """
    Reads a JSON Lines file and creates two CSV files: test_set.csv and correct.csv.

    Each non-blank line of the input is parsed as one JSON object. The value under
    `src_key` goes to test_set.csv (columns: id, order) and the value under `top_key`
    goes to correct.csv (columns: id, output). `id` is the zero-based position of the
    record among the non-blank lines, so both files share the same ids.

    Errors (unreadable input, malformed JSON, missing key, unwritable output) are
    reported with a printed message rather than raised, so the function always
    returns None.

    Parameters:
        json_path (str): Path to the input JSON Lines file.
        output_folder (str): Path to the folder where CSV files will be saved.
                            If None, files are saved in the same folder as the JSON file.
                            The folder is created if it does not exist.
        src_key (str): Field holding the input text. Override when the input file
                       uses different field names.
        top_key (str): Field holding the reference output. Override when the input
                       file uses different field names.
    """
    # Determine the output folder
    if output_folder is None:
        output_folder = os.path.dirname(json_path)

    # Prepare paths for output files
    test_set_path = os.path.join(output_folder, 'test_set.csv')
    correct_path = os.path.join(output_folder, 'correct.csv')

    try:
        # Read the JSON Lines file; blank lines (e.g. a trailing newline) are not records.
        with open(json_path, 'r', encoding='utf-8') as f:
            data = [json.loads(line) for line in f if line.strip()]

        # Extract both columns before opening any output file, so a record with a
        # missing key cannot leave a half-written test_set.csv next to a stale correct.csv.
        sources = [entry[src_key] for entry in data]
        targets = [entry[top_key] for entry in data]

        # output_folder is '' when json_path is a bare filename; makedirs('') would fail.
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)

        # Write test_set.csv
        with open(test_set_path, 'w', encoding='utf-8', newline='') as test_file:
            writer = csv.writer(test_file)
            writer.writerow(['id', 'order'])
            writer.writerows(enumerate(sources))

        # Write correct.csv
        with open(correct_path, 'w', encoding='utf-8', newline='') as correct_file:
            writer = csv.writer(correct_file)
            writer.writerow(['id', 'output'])
            writer.writerows(enumerate(targets))

        print(f"CSV files created: \nTest Set: {test_set_path}\nCorrect: {correct_path}")

    except Exception as e:
        # Deliberately best-effort: report the problem in the notebook instead of raising.
        print(f"An error occurred: {str(e)}")
        
# Export the dev-set JSON as test_set.csv / correct.csv under ../mimic_competition.
json_to_csv('../dataset2/PIZZA_dev.json', output_folder='../mimic_competition')

CSV files created: 
Test Set: ../mimic_competition\test_set.csv
Correct: ../mimic_competition\correct.csv


In [98]:
import csv

def evaluate_accuracy(test_set_path, correct_path):
    """
    Scores a predictions CSV against a reference CSV and prints aggregate metrics.

    Both files must have an 'output' column. Rows are paired by position (the 'id'
    column is not consulted), and each pair is scored with
    is_semantics_only_unordered_exact_match_post_ER_top_top.

    If the two files have different numbers of rows, a warning is printed and only
    the leading rows present in both files are scored. If there is nothing to score,
    a message is printed and no metrics are reported. Any other failure is reported
    with a printed message rather than raised.

    Parameters:
        test_set_path (str): Path to the CSV holding the predictions to evaluate
                             (read from its 'output' column).
        correct_path (str): Path to the correct.csv file (reference 'output' column).

    Prints:
        ACC: Average score across all comparisons.
        Exact match: Fraction of comparisons whose score equals 1.
    """
    resolver = PizzaSkillEntityResolver()
    try:
        # newline='' lets the csv module handle line endings itself, so quoted
        # fields containing newlines are parsed correctly.
        with open(test_set_path, 'r', encoding='utf-8', newline='') as test_file, \
                open(correct_path, 'r', encoding='utf-8', newline='') as correct_file:
            predictions = [row['output'] for row in csv.DictReader(test_file)]
            references = [row['output'] for row in csv.DictReader(correct_file)]

        # zip() below stops at the shorter file; make that visible instead of silent.
        if len(predictions) != len(references):
            print(
                f"Warning: row count mismatch ({len(predictions)} predictions vs "
                f"{len(references)} references); scoring only the first "
                f"{min(len(predictions), len(references))} pairs."
            )

        # Call the semantic matching function on each positional pair
        scores = [
            is_semantics_only_unordered_exact_match_post_ER_top_top(result, top, resolver)
            for result, top in zip(predictions, references)
        ]

        # Avoid reporting an opaque "division by zero" when there is nothing to score.
        if not scores:
            print("No rows to evaluate.")
            return

        exact_matches = sum(1 for score in scores if score == 1)

        acc = sum(scores) / len(scores)
        exact_match = exact_matches / len(scores)

        print(f"ACC: {acc:.4f}")
        print(f"Exact match: {exact_match:.4f}")

    except Exception as e:
        # Deliberately best-effort: report the problem in the notebook instead of raising.
        print(f"An error occurred: {str(e)}")

# Score the predictions in 1.csv against the reference file ../mimic_competition/correct.csv.
evaluate_accuracy('../mimic_competition/1.csv', '../mimic_competition/correct.csv')

ACC: 0.7155
Exact match: 0.1606
