# Evaluate SDO extraction accuracy

In [None]:
import json
from rapidfuzz import fuzz
from collections import defaultdict

In [None]:
def load_stix_bundle(file_path):
    """Read the UTF-8 JSON file at ``file_path`` and return the parsed content."""
    with open(file_path, mode="r", encoding="utf-8") as bundle_file:
        raw_text = bundle_file.read()
    # Parsing happens after the file is closed; the result is whatever the
    # JSON document's top-level value decodes to.
    return json.loads(raw_text)
    
# Load the merged bundle and the ground-truth bundle (in that order) so the
# two can be compared in the cells below.
output_bundle, ground_truth_bundle = (
    load_stix_bundle("../bundles/APT1_merged.json"),
    load_stix_bundle("../bundles/APT1_ground-truth.json"),
)

In [None]:
# The 19 object types that the STIX 2.1 specification defines as STIX Domain
# Objects. "sighting" is deliberately absent: the specification classifies it
# (together with "relationship") as a STIX Relationship Object, so counting it
# here would mix SROs into an SDO-only evaluation.
SDO_TYPES = {
    "attack-pattern", "campaign", "course-of-action", "grouping", "identity",
    "incident", "indicator", "infrastructure", "intrusion-set", "location",
    "malware", "malware-analysis", "note", "observed-data", "opinion",
    "report", "threat-actor", "tool", "vulnerability",
}


def extract_objects(stix_bundle):
    """Return the STIX Domain Objects (SDOs) contained in a STIX bundle.

    Args:
        stix_bundle: Mapping that may hold an "objects" entry with a list of
            dict-like STIX objects.

    Returns:
        A new list containing, in their original order, the objects whose
        "type" value is a member of ``SDO_TYPES``. The list is empty when
        the "objects" entry is missing, ``None`` or empty.
    """
    # `or []` also covers an explicit `"objects": null`, which `.get` with a
    # default would pass through as None and crash the comprehension.
    objects = stix_bundle.get("objects") or []
    return [obj for obj in objects if obj.get("type") in SDO_TYPES]


# Reduce both loaded bundles to their SDO lists (model output first, then the
# ground truth).
output_sdos, ground_truth = map(extract_objects, (output_bundle, ground_truth_bundle))

In [None]:
def _print_match_report(sdo, gt, name_similarity, type_similarity, desc_similarity, score):
    """Print the field-by-field comparison of one extracted SDO and its match."""
    # Field values are pulled into locals first: reusing double quotes inside
    # an f-string replacement field is a SyntaxError before Python 3.12.
    sdo_name = sdo.get("name", "")
    gt_name = gt.get("name", "")
    sdo_type = sdo.get("type", "")
    gt_type = gt.get("type", "")
    sdo_description = sdo.get("description", "")
    gt_description = gt.get("description", "")

    print(f"\n> {sdo_name}")

    print("\nName")
    print(f"Similarity: {name_similarity:.2f}")
    print(f"GT Name: {gt_name}")

    print("\nType")
    print(f"Similarity: {type_similarity:.2f}")
    print(f"Model Type: {sdo_type}")
    print(f"GT Type: {gt_type}")

    print("\nDescription")
    print(f"Similarity: {desc_similarity:.2f}")
    print(f"Model Description:\n{sdo_description}")
    print(f"GT Description:\n{gt_description}")

    print(f"\nTotal Score: {score:.2f}")


def calculate_sdo_score(output_bundle, ground_truth, name_threshold=80):
    """Score extracted SDOs against ground-truth SDOs and print a report.

    Each extracted object is paired with the FIRST ground-truth object (in
    list order) whose name similarity is strictly greater than
    ``name_threshold``. The pair's score is the mean of the name, type and
    description ``fuzz.ratio`` similarities. Extracted objects without a
    match contribute 0, and a ground-truth object may be matched by more
    than one extracted object.

    Args:
        output_bundle: List of extracted SDO dicts (despite the name, this is
            the list of objects, not a whole bundle).
        ground_truth: List of reference SDO dicts.
        name_threshold: Name similarity that must be exceeded for two objects
            to be treated as the same entity. Defaults to 80, the value that
            was previously hard-coded.

    Returns:
        The sum of pair scores divided by the number of extracted objects,
        then divided by 100. Returns 0.0 when ``output_bundle`` is empty.
    """
    total = 0
    for sdo in output_bundle:
        for gt in ground_truth:
            name_similarity = fuzz.ratio(sdo.get("name", ""), gt.get("name", ""))
            if name_similarity <= name_threshold:
                continue

            type_similarity = fuzz.ratio(sdo.get("type", ""), gt.get("type", ""))
            desc_similarity = fuzz.ratio(sdo.get("description", ""), gt.get("description", ""))
            score = (name_similarity + type_similarity + desc_similarity) / 3

            _print_match_report(sdo, gt, name_similarity, type_similarity, desc_similarity, score)

            total += score
            # Only the first ground-truth object above the threshold is used.
            break

    # Nothing extracted means nothing to average over; report 0 instead of
    # raising ZeroDivisionError.
    if output_bundle:
        normalized_score = (total / len(output_bundle)) / 100
    else:
        normalized_score = 0.0

    print(f"\nNormalized Score: {normalized_score:.2f}")

    return normalized_score

# Score the extracted SDO list against the ground-truth SDO list; the call
# also prints the per-match report and the normalized score.
sdo_score = calculate_sdo_score(
    output_bundle=output_sdos,
    ground_truth=ground_truth,
)