# **Evaluation for GenAI-Powered STIX 2.1 Generator**

**Notebook Version:** 9.0  
**Author:** Giulio Triggiani  
**Python Version:** >= 3.8  
**Key Libraries:** `stix2`, `stix2validator`

---

## **Objective**
The objective is to **quantitatively evaluate** the performance of GenAI_STIX_2_1_Generator, an LLM-based tool for the automatic generation of Cyber Threat Intelligence reports in STIX 2.1 format.  

To measure the generator's effectiveness, this notebook compares a bundle produced by the tool with a reference bundle (“ground truth”) created manually by a CTI analyst.  

The analysis relies on the object-equivalence features of the official `stix2` library: STIX objects (SDO, SCO, and SRO) are compared **semantically**, so that matches in meaning are detected in addition to literal matches.  

The data obtained from this comparison is then used to calculate **standard performance metrics** such as Precision, Recall, and F1-Score, providing a clear and objective assessment of the quality of the generated bundle.

## **Workflow Overview**

1.   **Setup**: installs the Python libraries needed for validating and manipulating STIX objects;
2.   **Libraries and Environment**: imports the required modules, mounts Google Drive to access the files, and sets the comparison thresholds;
3.   **Support Functions**: defines the functions that perform validation and object extraction;
4.   **Metrics**: defines the functions that calculate the metrics and their averages;
5.   **Comparison**: performs the entire process: it browses folders, loads bundles, compares them, and prints the results;
6.   **Charts**: generates the analysis charts.



## **Part 1**: Setup

This block installs the Python libraries needed for validating and manipulating STIX objects.

In [None]:
# Installation of necessary libraries
print("--- Installing dependencies ---")
!pip install stix2[semantic] --quiet
!pip install stix2-validator --quiet
!pip install rapidfuzz --quiet
print("Installation complete.")

## **Part 2**: Importing Libraries and Setting Up the Environment

This block imports the required modules, mounts Google Drive to access the files, and sets the comparison thresholds.

The similarity threshold is the minimum score, between 0 and 100, that two objects must reach to be considered equivalent: it establishes “how similar they must be” to count as a match (a True Positive).
A value between 80 and 95 is recommended.
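
To make the role of the threshold concrete, here is a minimal sketch. It assumes a toy string similarity built on the standard library's `difflib`; this is only a stand-in for illustration, not the scoring that `stix2` applies, and the names `toy_similarity` and `is_match` are hypothetical.

```python
from difflib import SequenceMatcher


def toy_similarity(a: str, b: str) -> float:
    """Return a similarity score in [0, 100] (stand-in, NOT the stix2 algorithm)."""
    return 100.0 * SequenceMatcher(None, a.lower(), b.lower()).ratio()


def is_match(a: str, b: str, threshold: float = 80) -> bool:
    """Two values count as a match when their score reaches the threshold."""
    return toy_similarity(a, b) >= threshold


# The same pair can be a match or not depending on the threshold chosen
print(is_match("Fancy Bear", "FancyBear", threshold=80))
print(is_match("Fancy Bear", "FancyBear", threshold=95))
```

A lower threshold accepts looser matches and tends to raise the True Positive count, while a higher one is stricter and shifts borderline pairs into False Positives and False Negatives.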

NOTE: On the first run, Colab asks for permission to access Google Drive.

In [None]:
# Import and setup of the environment
import os
import json
import subprocess
from collections import defaultdict
from copy import deepcopy
import re  # Used in Part 5 to normalize YARA patterns

# Import from the stix2 library
from stix2.equivalence.object import object_similarity, object_equivalence
from stix2 import parse, MemoryStore

# To connect to Google Drive in Colab
from google.colab import drive

print("--- Google Drive Mount ---")
drive.mount('/content/drive', force_remount=True)
print("Google Drive mounted correctly.")

# --- CONFIGURATION ---
# Set the base directory on Google Drive where the folders are located
BASE_DRIVE_DIR = '/content/drive/MyDrive/Reports_Evaluation'

# Set the similarity threshold for considering two objects equivalent (from 0 to 100)
SIMILARITY_THRESHOLD = 80

# Set the similarity threshold for relationships (from 0 to 100)
RELATIONSHIP_SIMILARITY_THRESHOLD = 80

## **Part 3**: Support Functions
This cell contains the main functions that perform validation and object extraction.

Bundles are validated according to the STIX 2.1 standard.

In [None]:
import os
import json
from collections import defaultdict
from copy import deepcopy

# Import from validator
from stix2validator import validate_file, print_results

# Import from the stix2 library
from stix2.equivalence.object import object_similarity, object_equivalence

def validate_bundle(file_path):
    """
    Validate a STIX 2.1 bundle using the stix2validator library.
    Returns True if the bundle is valid, otherwise False.
    """
    print(f"\nValidating the bundle: {os.path.basename(file_path)}...")

    # Performs validation and obtains an object with the results
    results = validate_file(file_path)

    if results.is_valid:
        print(f"✅ Validation of {os.path.basename(file_path)} succeeded.")
        return True
    else:
        print(f"❌ Validation ERROR for {os.path.basename(file_path)}:")
        # Use the print_results function to display errors in a formatted manner.
        print_results(results)
        return False

def extract_and_categorize_objects(bundle):
    """
    Extracts objects from a bundle and categorizes them into SDO/SCO and SRO.
    """
    sdo_sco_list = []
    sro_list = []
    objects = bundle.get("objects", [])

    for obj in objects:
        obj_type = obj.get("type", "")
        if obj_type in ["relationship", "sighting", "sighting-of"]:
            sro_list.append(obj)
        else:
            sdo_sco_list.append(obj)

    return sdo_sco_list, sro_list

## **Part 4**: Function for calculating metrics
This is the main function for calculating metrics and averages.

Three main metrics are calculated:

*   **Precision**
*   **Recall**
*   **F1-Score**

each aggregated with three averaging methods:

*   **Micro averaging**
*   **Macro averaging**
*   **Weighted averaging**

Specifically:
*   **Precision**: This metric answers the question: “Of all the STIX elements that my tool generated, what fraction was actually correct?” A high precision score indicates that the tool is reliable and does not generate much “noise” or incorrect information. It is a measure of output quality.
*   **Recall**: This metric answers the question: “Of all the correct STIX elements that were present in the report, what fraction did my tool manage to find?” A high recall score indicates that the tool is comprehensive and does not omit much relevant information. It is a measure of coverage.
*   **F1-Score**: The F1-Score is the harmonic mean of Precision and Recall. It provides a single balanced score that takes both aspects into account. It is particularly useful when it is equally important to minimize noise (high precision) and maximize coverage (high recall). The F1-Score heavily penalizes systems that excel in one metric at the expense of the other.
*   **Micro Averaging**: Micro-averaging gives equal weight to each individual object instance. It answers the question: “Considering all STIX objects extracted from all reports, what fraction of individual instances were handled correctly?” This score is dominated by performance on the most numerous classes. If the tool is very good at extracting Indicators and IPv4 Addresses (which are very common), the Micro-F1 score will be high even if performance on rare but important classes such as Campaign or Threat Actor is poor. Note that micro-average Precision equals micro-average Recall only in single-label classification, where every error is both a False Positive and a False Negative; here FP and FN are counted independently, so the two values generally differ.
*   **Macro Averaging**: Macro-averaging gives equal weight to each class, regardless of its frequency. It answers the question: “On average, how does my tool perform on different types of CTI concepts?” This score treats the F1-Score for the Threat-Actor class (which may be based on only a few instances) as equally important as the F1-Score for the Indicator class (based on hundreds of instances). It is a more indicative measure of the tool's versatility and of its ability to handle even the rarest and most difficult types of entities. If the tool performs poorly on a minority class, the Macro-F1 score is significantly affected.
*   **Weighted Averaging**: This method is a compromise between Micro and Macro. Metrics are calculated for each type (like Macro) and then averaged with each class weighted by its support, so that class imbalance is taken into account (like Micro). In many scenarios the Weighted-average result is very close to the Micro-average one, as both give more importance to classes with more instances.
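
The difference between the three averages is easiest to see on a small worked example. The sketch below uses invented TP/FP/FN counts for two classes (illustrative numbers only, not results of this evaluation) and a hypothetical helper `precision_recall_f1`.

```python
def precision_recall_f1(tp, fp, fn):
    """Return (precision, recall, f1), using 0.0 when a denominator is zero."""
    p = tp / (tp + fp) if (tp + fp) else 0.0
    r = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * p * r / (p + r) if (p + r) else 0.0
    return p, r, f1


# Invented counts: one frequent class and one rare class
counts = {
    "indicator": {"TP": 90, "FP": 10, "FN": 10},
    "threat-actor": {"TP": 1, "FP": 3, "FN": 1},
}

per_class = {n: precision_recall_f1(c["TP"], c["FP"], c["FN"]) for n, c in counts.items()}
support = {n: c["TP"] + c["FN"] for n, c in counts.items()}
total_support = sum(support.values())

# Micro: pool all counts, then compute the metrics once
micro = precision_recall_f1(
    sum(c["TP"] for c in counts.values()),
    sum(c["FP"] for c in counts.values()),
    sum(c["FN"] for c in counts.values()),
)
# Macro: plain mean of the per-class metrics
macro = tuple(sum(m[i] for m in per_class.values()) / len(per_class) for i in range(3))
# Weighted: mean of the per-class metrics, weighted by support
weighted = tuple(
    sum(per_class[n][i] * support[n] for n in per_class) / total_support for i in range(3)
)

for name, (p, r, f1) in (("micro", micro), ("macro", macro), ("weighted", weighted)):
    print(f"{name:<9} P={p:.3f} R={r:.3f} F1={f1:.3f}")
```

With these counts the rare class pulls the Macro-F1 well below the Micro-F1, while the Weighted-F1 stays close to the Micro value. The weighted Recall coincides with the micro Recall, since weighting each class's recall by its support reproduces the pooled ratio.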

In [None]:
# Functions for Calculating and Printing Metrics
import pandas as pd
import matplotlib.pyplot as plt

def print_bundle_summary_metrics(results_dict, bundle_name):
    """
    Calculate and print the overall (micro-averaged) metrics for a single bundle.
    Returns the tuple (precision, recall, f1).
    """
    total_tp = sum(c['TP'] for c in results_dict.values())
    total_fp = sum(c['FP'] for c in results_dict.values())
    total_fn = sum(c['FN'] for c in results_dict.values())

    precision_micro = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
    recall_micro = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
    f1_micro = 2 * (precision_micro * recall_micro) / (precision_micro + recall_micro) if (precision_micro + recall_micro) > 0 else 0.0

    print(f"\n--- OVERALL METRICS FOR BUNDLE: {bundle_name} ---")
    print(f"  - Overall Precision: {precision_micro:.2f}")
    print(f"  - Overall Recall:    {recall_micro:.2f}")
    print(f"  - Overall F1-Score:  {f1_micro:.2f}")

    # Return the calculated values
    return precision_micro, recall_micro, f1_micro

def calculate_and_print_metrics(total_results_per_object):
    """
    Calculate and print the complete final report, including metrics by class
    and aggregate averages.
    Returns the per-class and aggregate metric dictionaries, used later for plotting.
    """
    metrics_per_class = {}

    # Calculation of metrics for each individual class of objects
    for obj_type, counts in total_results_per_object.items():
        tp = counts['TP']
        fp = counts['FP']
        fn = counts['FN']

        # Precision
        if (tp + fp) > 0:
            precision = tp / (tp + fp)
        else:
            precision = 0.0

        # Recall
        if (tp + fn) > 0:
            recall = tp / (tp + fn)
        else:
            recall = 0.0

        # F1-Score
        if (precision + recall) > 0:
            f1_score = 2 * (precision * recall) / (precision + recall)
        else:
            f1_score = 0.0

        # Support (number of actual instances for the class in the ground truth)
        support = tp + fn

        metrics_per_class[obj_type] = {
            'Precision': precision,
            'Recall': recall,
            'F1-Score': f1_score,
            'Support': support
        }

    # Calculation of averages
    total_tp = sum(c['TP'] for c in total_results_per_object.values())
    total_fp = sum(c['FP'] for c in total_results_per_object.values())
    total_fn = sum(c['FN'] for c in total_results_per_object.values())
    total_support = sum(m['Support'] for m in metrics_per_class.values())

    # --- Micro Averages ---
    if (total_tp + total_fp) > 0:
        precision_micro = total_tp / (total_tp + total_fp)
    else:
        precision_micro = 0.0
    if (total_tp + total_fn) > 0:
        recall_micro = total_tp / (total_tp + total_fn)
    else:
        recall_micro = 0.0
    if (precision_micro + recall_micro) > 0:
        f1_micro = 2 * (precision_micro * recall_micro) / (precision_micro + recall_micro)
    else:
        f1_micro = 0.0

    # --- Macro Average ---
    num_classes = len(metrics_per_class)
    if num_classes > 0:
        precision_macro = sum(m['Precision'] for m in metrics_per_class.values()) / num_classes
        recall_macro = sum(m['Recall'] for m in metrics_per_class.values()) / num_classes
        f1_macro = sum(m['F1-Score'] for m in metrics_per_class.values()) / num_classes
    else:
        precision_macro, recall_macro, f1_macro = 0.0, 0.0, 0.0

    # --- Weighted Average ---
    if total_support > 0:
        precision_weighted = sum(m['Precision'] * m['Support'] for m in metrics_per_class.values()) / total_support
        recall_weighted = sum(m['Recall'] * m['Support'] for m in metrics_per_class.values()) / total_support
        f1_weighted = sum(m['F1-Score'] * m['Support'] for m in metrics_per_class.values()) / total_support
    else:
        precision_weighted, recall_weighted, f1_weighted = 0.0, 0.0, 0.0

    # *** NEW: Store aggregate metrics in a dictionary ***
    aggregate_metrics = {
        'MICRO AVG': {'Precision': precision_micro, 'Recall': recall_micro, 'F1-Score': f1_micro, 'Support': total_support},
        'MACRO AVG': {'Precision': precision_macro, 'Recall': recall_macro, 'F1-Score': f1_macro, 'Support': total_support},
        'WEIGHTED AVG': {'Precision': precision_weighted, 'Recall': recall_weighted, 'F1-Score': f1_weighted, 'Support': total_support}
    }

    # --- Print the results as a table ---
    table_width = 77  # Total width of a table row, borders included
    print("\n" + "=" * table_width)
    print("📈 PERFORMANCE METRICS REPORT 📈".center(table_width))
    print("=" * table_width)
    print(f"| {'CLASS':<25} | {'PRECISION':>9} | {'RECALL':>9} | {'F1-SCORE':>9} | {'SUPPORT':>9} |")
    print("|" + "-" * (table_width - 2) + "|")

    for obj_type, metrics in sorted(metrics_per_class.items()):
        p = metrics['Precision']
        r = metrics['Recall']
        f1 = metrics['F1-Score']
        s = metrics['Support']

        # Column widths match the header row
        print(f"| {obj_type:<25} | {p:>9.2f} | {r:>9.2f} | {f1:>9.2f} | {s:>9} |")

    print("|" + "-" * (table_width - 2) + "|")
    print("|" + " " * (table_width - 2) + "|")

    print(f"| {'MICRO AVG':<25} | {precision_micro:>9.2f} | {recall_micro:>9.2f} | {f1_micro:>9.2f} | {total_support:>9} |")
    print(f"| {'MACRO AVG':<25} | {precision_macro:>9.2f} | {recall_macro:>9.2f} | {f1_macro:>9.2f} | {total_support:>9} |")
    print(f"| {'WEIGHTED AVG':<25} | {precision_weighted:>9.2f} | {recall_weighted:>9.2f} | {f1_weighted:>9.2f} | {total_support:>9} |")

    print("=" * table_width)

    # *** NEW: Return the calculated data ***
    return metrics_per_class, aggregate_metrics

## **Part 5**: Main Logic of Comparison
This is the main cell that performs the entire process: it browses folders, loads bundles, compares them, and prints the results.

The objects generated by the tool can belong to three categories:
*   **TP** (**True Positive**): objects correctly identified by the tool that match the ground truth;
*   **FP** (**False Positive**): objects that have been identified by the tool but are not part of the ground truth;
*   **FN** (**False Negative**): objects that have not been identified by the tool but are part of the ground truth.

Furthermore, TP, FP, and FN are also divided according to the type of object to which they belong in order to achieve a further level of granularity.

SDOs/SCOs and SROs are compared in two different ways:
*   SDOs/SCOs: for each object generated by the tool, the script searches the ground truth for a semantically equivalent object of the same type using `object_equivalence(gen_obj, exp_obj, threshold=SIMILARITY_THRESHOLD)`. A match counts as a **TP**, and both objects are marked as matched to avoid double counting. At the end of the comparison, the unmatched generated objects are counted as **FP** and the unmatched ground truth objects as **FN**.
*   SROs: relationships are compared on three components, the **source object**, the **target object**, and the **relationship type**, weighted 40, 40, and 20 respectively. Source and target are considered the same when the generated object's ID maps, through the SDO/SCO matches found earlier, to the ID referenced in the ground truth. With a threshold of 80, matching source and target objects are therefore sufficient for a match even when the relationship type differs. FP and FN are calculated as for SDOs/SCOs.
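
The weighted relationship score can be sketched as follows. This is a simplified illustration with hypothetical IDs and a hypothetical function name (`relationship_score`); unlike the notebook code, it also assumes that a reference missing from `id_map` never counts as a match.

```python
WEIGHTS = {"source_ref": 40, "target_ref": 40, "relationship_type": 20}


def relationship_score(gen_rel, exp_rel, id_map):
    """Score a generated relationship against an expected one (0 to 100)."""
    score = 0
    for ref in ("source_ref", "target_ref"):
        mapped = id_map.get(gen_rel.get(ref))
        # Assumption of this sketch: an unmapped reference is never a match
        if mapped is not None and mapped == exp_rel.get(ref):
            score += WEIGHTS[ref]
    if gen_rel.get("relationship_type") == exp_rel.get("relationship_type"):
        score += WEIGHTS["relationship_type"]
    return score


# Hypothetical IDs: generated object ID -> matched ground truth object ID
id_map = {"malware--gen-1": "malware--exp-1", "threat-actor--gen-2": "threat-actor--exp-2"}

gen_rel = {"source_ref": "threat-actor--gen-2", "target_ref": "malware--gen-1",
           "relationship_type": "delivers"}
exp_rel = {"source_ref": "threat-actor--exp-2", "target_ref": "malware--exp-1",
           "relationship_type": "uses"}

# Source and target match, the relationship type does not
print(relationship_score(gen_rel, exp_rel, id_map))
```

Because the two endpoints alone contribute 80 points, a threshold of 80 accepts this pair, whereas any threshold above 80 would also require the relationship types to agree.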

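Indicator objects are compared on their `pattern` property rather than with `object_equivalence`: the two patterns must have the same `pattern_type` and be equal after trimming and lowercasing. For indicators containing **YARA rules**, the normalization goes further: comments are removed, whitespace is collapsed, and the content between braces is stripped of special characters before the two strings are compared.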
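
As an illustration of the YARA normalization described above, here is a minimal standalone sketch. The function name `normalize_yara` and the sample rules are hypothetical; note that stripping everything after `//` would also truncate strings that contain `//` (for example URLs), which is a limitation of this simple approach.

```python
import re


def normalize_yara(pattern: str) -> str:
    """Normalize a YARA rule so that formatting differences are ignored."""
    text = pattern.lower()
    text = re.sub(r"//.*", "", text)   # Drop line comments
    text = " ".join(text.split())      # Collapse all whitespace to single spaces

    def clean_braces(match):
        # Inside {...}: keep only letters, digits, and spaces
        inner = re.sub(r"[^a-z0-9\s]", "", match.group(1))
        return "{" + " ".join(inner.split()) + "}"

    return re.sub(r"\{(.*?)\}", clean_braces, text, flags=re.DOTALL)


rule_a = 'rule Demo {\n  strings:\n    $a = "EVIL" // marker\n  condition:\n    $a\n}'
rule_b = 'rule demo { strings: $a = "evil" condition: $a }'

print(normalize_yara(rule_a) == normalize_yara(rule_b))
```

The normalization is deliberately lossy: two rules that differ only in case, comments, layout, or punctuation inside the braces are treated as the same indicator.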

NOTE: STIX metadata objects such as `report` and `marking-definition`, together with the relationships that reference them, are excluded from metric calculations.

In [None]:
def main_comparison_logic(generated_bundle_path, expert_bundle_path):
    """
    Compares a generated bundle with the expert bundle, using weighted logic for relationships.
    Excludes metadata objects and the relationships that reference them.
    """
    METADATA_OBJECTS_TO_EXCLUDE = ['report', 'marking-definition']

    # STEP 1: Validate the bundles before proceeding
    if not validate_bundle(generated_bundle_path) or not validate_bundle(expert_bundle_path):
        return {}
    with open(generated_bundle_path, 'r', encoding='utf-8') as f:
        generated_bundle_json = json.load(f)
    with open(expert_bundle_path, 'r', encoding='utf-8') as f:
        expert_bundle_json = json.load(f)

    # STEP 2: Division of objects into lists
    generated_sdo_sco, generated_sro = extract_and_categorize_objects(generated_bundle_json)
    expert_sdo_sco, expert_sro = extract_and_categorize_objects(expert_bundle_json)

    # STEP 3: Identify metadata objects and their IDs
    metadata_objects_gen = [obj for obj in generated_sdo_sco if obj.get('type') in METADATA_OBJECTS_TO_EXCLUDE]
    metadata_object_ids_gen = {obj['id'] for obj in metadata_objects_gen}

    metadata_objects_exp = [obj for obj in expert_sdo_sco if obj.get('type') in METADATA_OBJECTS_TO_EXCLUDE]
    metadata_object_ids_exp = {obj['id'] for obj in metadata_objects_exp}

    # STEP 4: Isolation of special indicators (MODIFIED - NO LONGER EXCLUDES BY TYPE)
    # We keep the structure but make the filters ineffective for pattern_type
    special_indicators_gen = []
    special_indicator_ids_gen = set()

    special_indicators_exp = []
    special_indicator_ids_exp = set()

    print("\nℹ️  No indicators excluded based on pattern_type.")

    # STEP 5: Create a unique set of all IDs to exclude
    all_excluded_ids_gen = metadata_object_ids_gen.union(special_indicator_ids_gen)
    all_excluded_ids_exp = metadata_object_ids_exp.union(special_indicator_ids_exp)

    # STEP 6: Filtering all objects and relationships to be compared
    sdo_sco_to_compare = [obj for obj in generated_sdo_sco if obj not in metadata_objects_gen and obj not in special_indicators_gen]
    sdo_sco_to_compare_exp = [obj for obj in expert_sdo_sco if obj not in metadata_objects_exp and obj not in special_indicators_exp]

    # STEP 7: Excludes relationships that connect to ANY excluded object
    sro_to_compare = [rel for rel in generated_sro if rel.get('source_ref') not in all_excluded_ids_gen and rel.get('target_ref') not in all_excluded_ids_gen]
    sro_to_compare_exp = [rel for rel in expert_sro if rel.get('source_ref') not in all_excluded_ids_exp and rel.get('target_ref') not in all_excluded_ids_exp]

    # STEP 8: Initialization of the dictionary for TP, FP, and FN
    results = defaultdict(lambda: {'TP': 0, 'FP': 0, 'FN': 0})
    id_map = {}

    unmatched_generated_sdo_sco = deepcopy(sdo_sco_to_compare)
    search_pool_expert_sdo_sco = deepcopy(sdo_sco_to_compare_exp)
    unmatched_generated_sro = deepcopy(sro_to_compare)
    search_pool_expert_sro = deepcopy(sro_to_compare_exp)

    print(f"\n--- Start comparison with similarity threshold >= {SIMILARITY_THRESHOLD} ---")

    # STEP 9: Comparison of all eligible SDOs/SCOs (REVISED LOGIC with Custom Indicator Comparison)
    print("Comparing allowed SDO/SCO...")

    # Copies of the original lists (filtered) from which we will remove the elements
    # These will ultimately be used to calculate FP and FN
    unmatched_generated_sdo_sco = deepcopy(sdo_sco_to_compare)
    search_pool_expert_sdo_sco = deepcopy(sdo_sco_to_compare_exp)

    # Set of IDs of objects already matched to avoid double counting
    matched_gen_ids = set()
    matched_exp_ids = set()

    # Helper function to normalise the content of YARA brackets
    def normalize_braces_content(text):
        def replacer(match):
            content = match.group(1)
            # Remove special characters and convert to lowercase
            cleaned_content = re.sub(r'[^a-z0-9\s]', '', content.lower())
            # Collapse multiple spaces
            normalized_content = ' '.join(cleaned_content.split())
            return '{' + normalized_content + '}'
        # Apply the replacement to all occurrences of {...}
        return re.sub(r'\{(.*?)\}', replacer, text, flags=re.DOTALL)

    for gen_obj in sdo_sco_to_compare:
        if gen_obj['id'] in matched_gen_ids:
            continue

        best_match_expert_obj = None

        for exp_obj in sdo_sco_to_compare_exp:
            if exp_obj['id'] in matched_exp_ids:
                continue

            if gen_obj.get('type') == exp_obj.get('type'):
                is_equivalent = False

                # SPECIFIC COMPARISON LOGIC FOR INDICATORS
                if gen_obj.get('type') == 'indicator':
                    pattern_gen = gen_obj.get('pattern', '')
                    pattern_exp = exp_obj.get('pattern', '')
                    pattern_type_gen = gen_obj.get('pattern_type', 'stix')
                    pattern_type_exp = exp_obj.get('pattern_type', 'stix')

                    if pattern_type_gen == pattern_type_exp:
                        if pattern_type_gen == 'stix':
                            norm_gen = pattern_gen.strip().lower()
                            norm_exp = pattern_exp.strip().lower()
                            if norm_gen == norm_exp:
                                is_equivalent = True
                        elif pattern_type_gen == 'yara':
                            # Customised normalisation for YARA
                            # 1. Lowercase
                            norm_gen = pattern_gen.lower()
                            norm_exp = pattern_exp.lower()
                            # 2. Remove comments //...
                            norm_gen = re.sub(r'//.*', '', norm_gen)
                            norm_exp = re.sub(r'//.*', '', norm_exp)
                            # 3. Normalise whitespace -> single space
                            norm_gen = ' '.join(norm_gen.split())
                            norm_exp = ' '.join(norm_exp.split())

                            norm_gen_braces = normalize_braces_content(norm_gen)
                            norm_exp_braces = normalize_braces_content(norm_exp)
                            if norm_gen_braces == norm_exp_braces:
                                is_equivalent = True

                        # Default handling for other pattern_type
                        else:
                             norm_gen = pattern_gen.strip().lower()
                             norm_exp = pattern_exp.strip().lower()
                             if norm_gen == norm_exp:
                                is_equivalent = True

                else:
                    try:
                        if object_equivalence(gen_obj, exp_obj, threshold=SIMILARITY_THRESHOLD):
                            is_equivalent = True
                    except Exception:
                        # Objects that stix2 cannot compare are treated as non-matching
                        is_equivalent = False

                # If you find a valid match, record the IDs and exit the internal loop.
                if is_equivalent:
                    obj_type = gen_obj['type']
                    results[obj_type]['TP'] += 1
                    id_map[gen_obj['id']] = exp_obj['id']
                    matched_gen_ids.add(gen_obj['id'])
                    matched_exp_ids.add(exp_obj['id'])
                    best_match_expert_obj = exp_obj
                    break

    # After the cycles, the 'unmatched_generated_sdo_sco' and 'search_pool_expert_sdo_sco' lists
    # still contain all the initial objects. Now we recalculate them correctly.
    final_unmatched_gen = [obj for obj in sdo_sco_to_compare if obj['id'] not in matched_gen_ids]
    final_unmatched_exp = [obj for obj in sdo_sco_to_compare_exp if obj['id'] not in matched_exp_ids]

    # Update the lists that will be used later (for SRO and final FP/FN calculation)
    # Overwrite the original lists defined in STEP 8 with the newly calculated ones
    unmatched_generated_sdo_sco = final_unmatched_gen
    search_pool_expert_sdo_sco = final_unmatched_exp

    # ======================= STEP 10: WEIGHTED COMPARISON FOR RELATIONSHIPS =======================
    print("Running the WEIGHTED comparison for allowed SROs...")

    # Let's define the weights for each component of the relationship
    weights = {'source_ref': 40, 'target_ref': 40, 'relationship_type': 20}

    for gen_rel in sro_to_compare:
        best_match_expert_rel = None
        highest_score = -1

        for exp_rel in search_pool_expert_sro:
            current_score = 0
            # Compare 'source_ref' using id_map
            if id_map.get(gen_rel.get('source_ref')) == exp_rel.get('source_ref'):
                current_score += weights['source_ref']

            # Compare 'target_ref' using id_map
            if id_map.get(gen_rel.get('target_ref')) == exp_rel.get('target_ref'):
                current_score += weights['target_ref']

            # Compare 'relationship_type'
            if gen_rel.get('relationship_type') == exp_rel.get('relationship_type'):
                current_score += weights['relationship_type']

            if current_score > highest_score:
                highest_score = current_score
                best_match_expert_rel = exp_rel

        # Check whether the best score exceeds the threshold
        if highest_score >= RELATIONSHIP_SIMILARITY_THRESHOLD:
            results['relationship']['TP'] += 1
            # Removes matched items so they don't need to be checked again
            unmatched_generated_sro.remove(gen_rel)
            if best_match_expert_rel in search_pool_expert_sro:
                 search_pool_expert_sro.remove(best_match_expert_rel)

    # ================================= END OF NEW LOGIC ==================================

    # STEP 11: Calculation of FP and FN
    print("Final calculation of FP and FN...")
    for unmatched_obj in unmatched_generated_sdo_sco + unmatched_generated_sro:
        results[unmatched_obj['type']]['FP'] += 1
    for unmatched_obj in search_pool_expert_sdo_sco + search_pool_expert_sro:
        results[unmatched_obj['type']]['FN'] += 1

    print("\n--- Calculation Results (TP, FP, FN) ---")
    print(json.dumps(dict(results), indent=2))
    return dict(results)

# SCRIPT EXECUTION

# Initialize the variables before starting the loop.
total_results_per_object = defaultdict(lambda: {'TP': 0, 'FP': 0, 'FN': 0})
grand_totals = {'TP': 0, 'FP': 0, 'FN': 0}

# *** NEW: Initialize list to store metrics for each bundle ***
all_bundle_metrics = []


if not os.path.exists(BASE_DRIVE_DIR):
    print(f"❌ ERROR: The specified directory does not exist: {BASE_DRIVE_DIR}")
else:
    for dirname in sorted(os.listdir(BASE_DRIVE_DIR)):
        dirpath = os.path.join(BASE_DRIVE_DIR, dirname)
        if os.path.isdir(dirpath):
            print(f"\n{'='*15} Analyzing the folder: {dirname} {'='*15}")
            ground_truth_path, predicted_path = None, None
            try:
                for filename in os.listdir(dirpath):
                    if filename.endswith("_gt.json"):
                        ground_truth_path = os.path.join(dirpath, filename)
                    elif filename.endswith("_pred.json"):
                        predicted_path = os.path.join(dirpath, filename)
            except FileNotFoundError:
                print(f"Unable to access the folder '{dirname}'.")
                continue

            if ground_truth_path and predicted_path:
                print(f"Ground Truth found: {os.path.basename(ground_truth_path)}")
                print(f"Predicted bundle found: {os.path.basename(predicted_path)}")

                # Call the function and save the result of the single bundle
                bundle_results = main_comparison_logic(predicted_path, ground_truth_path)

                if bundle_results:
                    # *** NEW: Calculate, print, AND capture metrics for the SINGLE bundle ***
                    p_micro, r_micro, f1_micro = print_bundle_summary_metrics(bundle_results, dirname)

                    # *** NEW: Add captured metrics to our list ***
                    all_bundle_metrics.append({
                        'bundle': dirname,
                        'Precision': p_micro,
                        'Recall': r_micro,
                        'F1-Score': f1_micro
                    })

                    # Update totals by object type
                    for obj_type, counts in bundle_results.items():
                        total_results_per_object[obj_type]['TP'] += counts['TP']
                        total_results_per_object[obj_type]['FP'] += counts['FP']
                        total_results_per_object[obj_type]['FN'] += counts['FN']
            else:
                print(f"Bundle pair not found.")

# Calculate the grand total by adding up the totals for each type of item.
for obj_type, counts in total_results_per_object.items():
    grand_totals['TP'] += counts['TP']
    grand_totals['FP'] += counts['FP']
    grand_totals['FN'] += counts['FN']

# --- Printing Final Results ---

# Print detailed totals by object type (raw data)
print("\n" + "="*50)
print("📊 TOTAL RESULTS PER STIX OBJECT TYPE (RAW DATA) 📊")
print("="*50)
print(json.dumps(total_results_per_object, indent=2))

# Print the total (raw data)
print("\n" + "="*50)
print("🏆 GRAND TOTAL (ALL OBJECTS) (RAW DATA) 🏆")
print("="*50)
print(json.dumps(grand_totals, indent=2))

# Call the function to calculate, print, and return metrics.
metrics_per_class, aggregate_metrics = calculate_and_print_metrics(total_results_per_object)

print("\n\n✅ Metric data captured and ready for visualization.")

## **Part 6**: Generation of analysis charts

The following charts are generated for statistical analysis:

*   **Aggregate metrics chart:** shows the overall performance of the model under the three averaging methods.
*   **Performance chart by STIX object type:** shows in detail how the model performs on each STIX object class, making its strengths and weaknesses visible.
*   **Distribution analysis chart:** shows how unbalanced the dataset is, with many instances of some classes and very few of others.
*   **Error analysis chart:** uses the raw counts to show the composition of “successes” and “errors” for each class.
*   **Performance variability chart by bundle:** shows how consistent the generator is by plotting the metrics of each individual bundle across the different types of reports.

The results are saved to an output folder.


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os
import json

print("--- Starting chart creation and data saving ---")

# --- 1. ENVIRONMENT SETUP ---

# Define the output folder (created if it does not exist)
OUTPUT_DIR = os.path.join(BASE_DRIVE_DIR, "Risultati_Grafici_Valutazione")
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"All files will be saved in: {OUTPUT_DIR}")

# --- 2. DATA PREPARATION ---

# Convert the metric dictionaries into pandas DataFrames
# Raw TP/FP/FN counts
df_raw_counts = pd.DataFrame.from_dict(total_results_per_object, orient='index').fillna(0).astype(int)
df_raw_counts = df_raw_counts.reindex(columns=['TP', 'FP', 'FN'], fill_value=0)  # Ensure all three columns exist

# Per-class metrics
df_metrics_class = pd.DataFrame.from_dict(metrics_per_class, orient='index')

# Aggregate metrics
df_metrics_agg = pd.DataFrame.from_dict(aggregate_metrics, orient='index')

# Per-bundle metrics
df_bundle_metrics = pd.DataFrame(all_bundle_metrics)


# --- 3. CHART CREATION AND SAVING ---

# --- CHART 1: Aggregate metrics (Micro, Macro, Weighted) ---
try:
    # DataFrame.plot creates its own figure, so the size is passed to it directly
    df_metrics_agg[['Precision', 'Recall', 'F1-Score']].plot(kind='bar', rot=0, figsize=(10, 6))
    plt.title('Aggregate Performance (Micro, Macro, Weighted)')
    plt.ylabel('Score')
    plt.xlabel('Averaging Method')
    plt.ylim(0, 1.0)
    plt.legend(loc='lower right')
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    plot_path_1 = os.path.join(OUTPUT_DIR, "1_grafico_metriche_aggregate.png")
    plt.savefig(plot_path_1)
    plt.show()
    print("✅ Chart 1 saved: 1_grafico_metriche_aggregate.png")
except Exception as e:
    print(f"❌ Error creating Chart 1: {e}")

# --- CHART 2: Performance by object type (only classes with Support > 0) ---
try:
    df_metrics_class_filtered = df_metrics_class[df_metrics_class['Support'] > 0].sort_values('F1-Score', ascending=False)
    # Pass ax explicitly so the plot is drawn on the figure created here
    fig, ax = plt.subplots(figsize=(15, 8))
    df_metrics_class_filtered[['Precision', 'Recall', 'F1-Score']].plot(kind='bar', ax=ax)
    plt.title('Performance by STIX Object Type (Support > 0)')
    plt.ylabel('Score')
    plt.xlabel('STIX Object Type')
    plt.ylim(0, 1.0)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.legend(loc='upper right')
    plt.tight_layout()
    plot_path_2 = os.path.join(OUTPUT_DIR, "2_grafico_performance_per_classe.png")
    plt.savefig(plot_path_2)
    plt.show()
    print("✅ Chart 2 saved: 2_grafico_performance_per_classe.png")
except Exception as e:
    print(f"❌ Error creating Chart 2: {e}")

# --- CHART 3: Support distribution (number of ground-truth instances) ---
try:
    df_support = df_metrics_class[df_metrics_class['Support'] > 0]['Support'].sort_values(ascending=False)
    fig, ax = plt.subplots(figsize=(15, 8))
    df_support.plot(kind='bar', color='skyblue', ax=ax)
    plt.title('Support Distribution by Class (Number of Ground-Truth Instances)')
    plt.ylabel('Number of Instances (Support)')
    plt.xlabel('STIX Object Type')
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    plot_path_3 = os.path.join(OUTPUT_DIR, "3_grafico_distribuzione_supporto.png")
    plt.savefig(plot_path_3)
    plt.show()
    print("✅ Chart 3 saved: 3_grafico_distribuzione_supporto.png")
except Exception as e:
    print(f"❌ Error creating Chart 3: {e}")

# --- CHART 4: Error analysis (TP, FP, FN) by class ---
try:
    # Keep only classes with at least one TP, FP, or FN to avoid empty bars
    df_raw_filtered = df_raw_counts[(df_raw_counts['TP'] > 0) | (df_raw_counts['FP'] > 0) | (df_raw_counts['FN'] > 0)]
    df_raw_filtered = df_raw_filtered.sort_values(['FN', 'FP'], ascending=[False, False])

    # A single figure with one size; pass ax so no extra empty figure is created
    fig, ax = plt.subplots(figsize=(15, 8))
    df_raw_filtered.plot(kind='bar', stacked=True, colormap='viridis', ax=ax)
    plt.title('Error Analysis: True Positives (TP), False Positives (FP), False Negatives (FN)')
    plt.ylabel('Object Count')
    plt.xlabel('STIX Object Type')
    plt.legend(loc='upper right')
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    plot_path_4 = os.path.join(OUTPUT_DIR, "4_grafico_analisi_errori_tp_fp_fn.png")
    plt.savefig(plot_path_4)
    plt.show()
    print("✅ Chart 4 saved: 4_grafico_analisi_errori_tp_fp_fn.png")
except Exception as e:
    print(f"❌ Error creating Chart 4: {e}")

# --- CHART 5: Performance variability by bundle ---
try:
    if not df_bundle_metrics.empty:
        df_bundle_metrics_sorted = df_bundle_metrics.sort_values('F1-Score', ascending=False)
        # Pass ax explicitly so the plot is drawn on the figure created here
        fig, ax = plt.subplots(figsize=(15, 8))
        df_bundle_metrics_sorted.plot(kind='bar', x='bundle', y='F1-Score', color='coral', ax=ax)
        plt.title('Performance Variability (F1-Score) by Bundle')
        plt.ylabel('F1-Score (Micro Avg)')
        plt.xlabel('Bundle')
        plt.ylim(0, 1.0)
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.tight_layout()
        plot_path_5 = os.path.join(OUTPUT_DIR, "5_grafico_variabilita_per_bundle.png")
        plt.savefig(plot_path_5)
        plt.show()
        print("✅ Chart 5 saved: 5_grafico_variabilita_per_bundle.png")
    else:
        print("ℹ️ Chart 5 skipped: no per-bundle data was captured.")
except Exception as e:
    print(f"❌ Error creating Chart 5: {e}")


# --- 4. SAVING THE RAW DATA ---

try:
    # Save the CSV reports
    df_metrics_class.to_csv(os.path.join(OUTPUT_DIR, "report_metriche_per_classe.csv"))
    df_metrics_agg.to_csv(os.path.join(OUTPUT_DIR, "report_metriche_aggregate.csv"))
    df_raw_counts.to_csv(os.path.join(OUTPUT_DIR, "report_conteggi_grezzi_tp_fp_fn.csv"))
    # Save the per-bundle metrics CSV
    df_bundle_metrics.to_csv(os.path.join(OUTPUT_DIR, "report_metriche_per_bundle.csv"), index=False)

    # Save the original dictionaries as JSON
    with open(os.path.join(OUTPUT_DIR, "dati_totali_per_oggetto.json"), 'w') as f:
        json.dump(total_results_per_object, f, indent=2)

    with open(os.path.join(OUTPUT_DIR, "dati_grand_total.json"), 'w') as f:
        json.dump(grand_totals, f, indent=2)

    print("✅ All data files (CSV and JSON) have been saved.")

except Exception as e:
    print(f"❌ Error while saving the data files: {e}")

print("--- Process completed. ---")