In [None]:
from typing import List, Optional

from stickler.structured_object_evaluator import StructuredModel, ComparableField
from stickler.structured_object_evaluator.models.hungarian_helper import HungarianHelper
from stickler.comparators.levenshtein import LevenshteinComparator
from stickler.comparators.exact import ExactComparator
from stickler.comparators.numeric import NumericComparator
from stickler.comparators.fuzzy import FuzzyComparator

class Sender(StructuredModel):
    """Sender details of a shipment.

    ``company`` is declared without a default; ``address`` and ``city`` default
    to ``None``. All three fields are compared with ``LevenshteinComparator``.
    """
    company: str = ComparableField(
        comparator=LevenshteinComparator(),
        threshold=0.8,
        weight=1.0,
        description="Sender's company name"
    )
    # The description must refer to the sender: this model describes the
    # shipping party, not the recipient.
    address: Optional[str] = ComparableField(
        comparator=LevenshteinComparator(),
        threshold=0.7,
        weight=0.8,
        description="Street address of the sender as it appears",
        default=None
    )
    city: Optional[str] = ComparableField(
        comparator=LevenshteinComparator(),
        threshold=0.7,
        weight=0.8,
        description="City of the sender",
        default=None
    )


class Recipient(StructuredModel):
    """Recipient details of a shipment.

    ``company`` is declared without a default; ``address`` and ``zipCode``
    default to ``None``.
    """
    company: str = ComparableField(
        comparator=LevenshteinComparator(),
        threshold=0.8,
        weight=1.0,
        description="Recipient's company name as it appears"
    )
    address: Optional[str] = ComparableField(
        comparator=LevenshteinComparator(),
        threshold=0.7,
        weight=0.8,
        description="Street address of the recipient as it appears",
        default=None
    )
    # A ZIP code is an identifier stored as a string, not a quantity, so it is
    # compared exactly. NOTE(review): a numeric comparison would presumably
    # treat "01234" and "1234" as equal -- confirm against the comparator docs.
    zipCode: Optional[str] = ComparableField(
        comparator=ExactComparator(),
        threshold=1.0,
        weight=0.8,
        description="Zipcode of the recipient's address",
        default=None
    )


class AdditionalCharge(StructuredModel):
    """One extra fee line on a shipment: a text label plus an optional amount."""
    # Free-text label of the charge, compared with Levenshtein similarity.
    description: str = ComparableField(
        description="Description of the additional charge as it appears",
        comparator=LevenshteinComparator(),
        weight=1.0,
        threshold=0.8
    )
    # Monetary value of the charge, compared numerically; may be absent.
    amount: Optional[float] = ComparableField(
        default=None,
        description="Amount of the additional charge",
        comparator=NumericComparator(),
        weight=1.0,
        threshold=0.95
    )


class MostRecentShipment(StructuredModel):
    """Shipment with the most recent date information.

    Groups the nested ``Sender`` and ``Recipient`` models, the shipment number
    and the list of additional charges. Every field defaults to ``None``.
    """
    sender: Optional[Sender] = ComparableField(
        threshold=0.8,
        weight=1.0,
        description="Sender information",
        default=None
    )
    recipient: Optional[Recipient] = ComparableField(
        threshold=0.8,
        weight=1.0,
        description="Recipient information",
        default=None
    )
    # Annotated Optional because the field defaults to None; a bare ``str``
    # annotation would contradict that default.
    # Compared exactly and given the largest weight in this model (2.0).
    shipmentNumber: Optional[str] = ComparableField(
        comparator=ExactComparator(),
        threshold=1.0,
        weight=2.0,
        description="Shipment number, used as ID",
        default=None
    )
    # List of nested models; no comparator or threshold is set here.
    additionalCharges: Optional[List[AdditionalCharge]] = ComparableField(
        weight=0.8,
        description="Details of additional charges applied to the shipment",
        default=None
    )


class ShippingInvoice(StructuredModel):
    """Top-level invoice schema; its single field wraps the most recent shipment."""
    # Carries the largest weight (3.0) declared anywhere in this schema.
    mostRecentShipment: Optional[MostRecentShipment] = ComparableField(
        default=None,
        description="Shipment with the most recent date (closest to today's date), always the last item in the list",
        weight=3.0,
        threshold=0.8
    )


In [None]:
# Ground truth: the reference extraction for the invoice.
gt_json = {
    "mostRecentShipment": {
        "sender": {"company": "Schumm, Cronin And Grady"},
        "recipient": {
            "address": "49418 Renner Key",
            "company": "Stiedemann - Hermann",
        },
        "shipmentNumber": "8186386200",
        "additionalCharges": [
            {"amount": 8.62, "description": "Change Of Address Fee"},
            {"amount": 5.78, "description": "Priority Service Charge"},
        ],
    }
}

# Prediction: the same invoice with deliberate errors, each marked inline.
pred_json = {
    "mostRecentShipment": {
        "sender": {
            "company": "Schumm, Cronin And Graddy",  # typo: "Graddy" vs "Grady"
        },
        "recipient": {
            "address": "49418 Renner Key",
            "company": "Stiedemann - Hermann",
        },
        "shipmentNumber": "8186386208",  # last digit wrong: ends in 8 instead of 0
        "additionalCharges": [
            {
                "amount": 8.62,
                "description": "Change Of Address Fees",  # slight wording change
            },
            # The second ground-truth charge is intentionally missing here.
        ],
    }
}

In [6]:
# Build model instances from the raw dicts: ground truth first, then the prediction.
gt_invoice = ShippingInvoice(**gt_json)
pred_invoice = ShippingInvoice(**pred_json)

In [9]:
# Compare the prediction against the ground truth and request the confusion
# matrix, which the following cells read from results['confusion_matrix'].
results = gt_invoice.compare_with(
    pred_invoice,
    include_confusion_matrix=True,
)

In [None]:
# Object-level ("overall") result: the saved output reports
# all_fields_matched=False, i.e. the predicted ShippingInvoice does not match
# the ground-truth ShippingInvoice as a whole.
results['confusion_matrix']['overall']

{'tp': 0,
 'fa': 0,
 'fd': 1,
 'fp': 1,
 'tn': 0,
 'fn': 0,
 'similarity_score': 0.0,
 'all_fields_matched': False,
 'derived': {'cm_precision': 0.0,
  'cm_recall': 0.0,
  'cm_f1': 0.0,
  'cm_accuracy': 0.0}}

In [None]:
# Aggregate results take the child nodes' performance into account when
# reporting a parent node's performance.
results['confusion_matrix']['aggregate']

{'tp': 5,
 'fa': 0,
 'fd': 1,
 'fp': 1,
 'tn': 3,
 'fn': 2,
 'derived': {'cm_precision': 0.8333333333333334,
  'cm_recall': 0.7142857142857143,
  'cm_f1': 0.7692307692307692,
  'cm_accuracy': 0.7272727272727273}}

In [None]:
# Field-level evaluation shows the detailed performance of a single field,
# here mostRecentShipment.additionalCharges and its sub-fields.
results['confusion_matrix']['fields']['mostRecentShipment']['fields']['additionalCharges']

{'overall': {'tp': 1,
  'fa': 0,
  'fd': 0,
  'fp': 0,
  'tn': 0,
  'fn': 1,
  'derived': {'cm_precision': 1.0,
   'cm_recall': 0.5,
   'cm_f1': 0.6666666666666666,
   'cm_accuracy': 0.5}},
 'fields': {'description': {'overall': {'tp': 1,
    'fa': 0,
    'fd': 0,
    'fp': 0,
    'tn': 0,
    'fn': 1,
    'derived': {'cm_precision': 1.0,
     'cm_recall': 0.5,
     'cm_f1': 0.6666666666666666,
     'cm_accuracy': 0.5}},
   'aggregate': {'tp': 1,
    'fa': 0,
    'fd': 0,
    'fp': 0,
    'tn': 0,
    'fn': 1,
    'derived': {'cm_precision': 1.0,
     'cm_recall': 0.5,
     'cm_f1': 0.6666666666666666,
     'cm_accuracy': 0.5}}},
  'amount': {'overall': {'tp': 1,
    'fa': 0,
    'fd': 0,
    'fp': 0,
    'tn': 0,
    'fn': 1,
    'derived': {'cm_precision': 1.0,
     'cm_recall': 0.5,
     'cm_f1': 0.6666666666666666,
     'cm_accuracy': 0.5}},
   'aggregate': {'tp': 1,
    'fa': 0,
    'fd': 0,
    'fp': 0,
    'tn': 0,
    'fn': 1,
    'derived': {'cm_precision': 1.0,
     'cm_recal

In [24]:
# Pull the additionalCharges lists out of both invoices. Plain attribute access
# is used because the attribute names are fixed; getattr() is only needed when
# the name is computed at runtime.
gt_additional_charges = gt_invoice.mostRecentShipment.additionalCharges
pred_additional_charges = pred_invoice.mostRecentShipment.additionalCharges

In [26]:
# Display the ground-truth charges (gt_json defines two entries).
gt_additional_charges

[AdditionalCharge(extra_fields={}, description='Change Of Address Fee', amount=8.62),
 AdditionalCharge(extra_fields={}, description='Priority Service Charge', amount=5.78)]

In [28]:
# Display the predicted charges (pred_json defines a single entry).
# NOTE(review): the saved output shows description='Change Of Address Fee',
# but pred_json sets "Change Of Address Fees" -- the output looks stale;
# restart the kernel and re-run all cells to refresh it.
pred_additional_charges

[AdditionalCharge(extra_fields={}, description='Change Of Address Fee', amount=8.62)]

In [30]:
# Inspect how the ground-truth and predicted charge lists were paired.
# NOTE(review): each matched pair is displayed as a 3-tuple, presumably
# (gt_index, pred_index, similarity) -- confirm against HungarianHelper's docs.
hungarian_helper = HungarianHelper()
hungarian_info = hungarian_helper.get_complete_matching_info(
    gt_additional_charges,
    pred_additional_charges,
)
matched_pairs = hungarian_info["matched_pairs"]
matched_pairs

[(0, 0, np.float64(1.0))]