In [None]:
# @title Imports
%pip install --quiet google-generativeai\
tabulate

import json
import os
import ast
import google.generativeai as genai
from google.colab import userdata
from google.colab import files
from tabulate import tabulate
from openai import OpenAI


In [None]:
# @title LLM (HuggingFace) API Class

class HF_Qwen_API:
  """Small chat-completion wrapper around the Hugging Face router.

  The router is accessed through the OpenAI-compatible ``OpenAI`` client
  imported by the notebook; the API key is read from the Colab secret
  named ``HF_API``.
  """

  # Model requested for every completion unless an instance overrides it.
  # Alternative kept from the original code: "deepseek-ai/DeepSeek-V3-0324".
  model = "moonshotai/Kimi-K2-Instruct"

  def __init__(self, model=None):
    """
    Configure and initialise the Hugging Face router client.

    Args:
      model: Optional model identifier. When omitted, the class-level
        ``model`` default is used.
    """
    self.client = OpenAI(
        base_url="https://router.huggingface.co/v1",
        api_key=userdata.get("HF_API"),
    )
    if model is not None:
      self.model = model

    # Text of the most recent reply; empty until the first request.
    self.response = ''

  def get_full_response(self, query, images=None):
    """
    Send ``query`` as a single user message and return the reply text.

    Args:
      query: Prompt; it is formatted into a string before being sent.
      images: Accepted for interface compatibility but not used by this
        implementation. Defaults to ``None`` instead of a shared mutable
        list.

    Returns:
      The ``content`` of the first choice of the completion. The same value
      is also stored in ``self.response``.
    """
    completion = self.client.chat.completions.create(
      model=self.model,
      messages=[
          {
              "role": "user",
              "content": f"{query}"
          }
      ],
    )
    self.response = completion.choices[0].message.content
    return self.response


In [None]:
# @title Object Representations
#Representation of each action in the graph
class Action:
  """An attack action node of the graph.

  Two actions compare equal (and hash equally) when their ``technique_id``
  values are equal, regardless of any other field.
  """

  def __init__(self, name, unique_id, tactic_id, technique_id, description, connection):
    """
    Args:
      name: Display name of the action.
      unique_id: Identifier of this node.
      tactic_id: Tactic identifier of the action.
      technique_id: Technique identifier; drives equality and hashing.
      description: Free-text description.
      connection: Iterable of children; ``None`` is treated as "no children".
    """
    self.name = name
    self.unique_id = unique_id
    self.tactic_id = tactic_id
    self.technique_id = technique_id
    self.description = description
    self.assets = []
    # Copy so that addChild never mutates a list owned by the caller
    # (for example a list taken straight out of parsed JSON).
    self.children = list(connection) if connection is not None else []
    # -1 means "no level assigned yet".
    self.graph_level = -1

  def __eq__(self, other):
    if isinstance(other, Action):
      return self.technique_id == other.technique_id
    # Let Python try the reflected comparison; the final result for
    # unrelated types is still False.
    return NotImplemented

  def __str__(self):
    return f"{self.technique_id}"

  def __hash__(self):
    # Must stay consistent with __eq__, which compares technique_id only.
    return hash(self.technique_id)

  def addAsset(self, asset):
    """Attach ``asset`` to this action."""
    self.assets.append(asset)

  def addChild(self, child):
    """Append ``child`` to this action's children."""
    self.children.append(child)

# Representation of each asset in the graph
class Asset:
  """An asset node of the graph (any object that is not an action or operator)."""

  def __init__(self, asset_type, name, unqiue_id, description):
    """
    Args:
      asset_type: Type label of the asset.
      name: Display name of the asset.
      unqiue_id: Identifier of the asset. The misspelled parameter and
        attribute name are kept because they are part of the existing
        interface.
      description: Free-text description.
    """
    self.asset_type = asset_type
    self.name = name
    self.unqiue_id = unqiue_id
    self.description = description

  @property
  def unique_id(self):
    """Correctly spelled alias of ``unqiue_id``, matching ``Action.unique_id``."""
    return self.unqiue_id

  @unique_id.setter
  def unique_id(self, value):
    self.unqiue_id = value

  def __str__(self):
    return f"{self.asset_type}: {self.name}"

#Representation of an operator in the graph
class Operator:
  """A logical operator node of the graph that groups child nodes."""

  def __init__(self, operator, connection):
    """
    Args:
      operator: Operator label stored as-is.
      connection: Iterable of children; ``None`` is treated as "no children".
    """
    self.operator = operator
    # Copy so that addChild never mutates a list owned by the caller
    # (for example a list taken straight out of parsed JSON).
    self.children = list(connection) if connection is not None else []
    # -1 means "no level assigned yet".
    self.graph_level = -1

  def addChild(self, child):
    """Append ``child`` to this operator's children."""
    self.children.append(child)

In [None]:
# @title Json Reading
def readTruthJson(data):
  """
  Build the ground-truth graph from a list of attack-flow objects.

  Objects are dispatched on their "type" field:
    * "attack-action"    -> Action, children taken from "effect_refs"
    * "attack-operator"  -> Operator, children taken from "effect_refs"
    * "attack-condition" -> Operator "IF", children are "on_true_refs"
                            followed by "on_false_refs"
    * "relationship"     -> links an asset to an action, or an action to an
                            operator
    * "attack-flow"      -> supplies the root ids via "start_refs"
    * anything else      -> Asset

  NOTE(review): this is a single pass, so relationships and "object_refs"
  only resolve against objects that appeared earlier in ``data``.

  Args:
    data: Iterable of dicts (the "objects" list of the uploaded JSON).

  Returns:
    A tuple ``(actions, assets)``: ``actions`` is a dict id -> Action that
    contains only actions with a non-empty ``technique_id``; ``assets`` is
    the values view of the collected assets.
  """
  threat_dict = {}
  asset_dict = {}
  operator_dict = {}
  # Default to "no roots" so input without an attack-flow object does not
  # fail with an unbound variable further down.
  root_ids = []

  for item in data:
    item_type = item.get("type", "")
    if item_type == "attack-action":
      threat_Action = Action(item.get("name",""),item.get("id",""),item.get("tactic_id",""),item.get("technique_id",""),item.get("description",""),item.get("effect_refs",[]))
      threat_dict[item.get("id","")] = threat_Action
    elif item_type == "attack-operator":
      operator = Operator(item.get("operator",""),item.get("effect_refs",[]))
      operator_dict[item.get("id","")] = operator
    elif item_type == "attack-condition":
      operator = Operator("IF",item.get("on_true_refs",[])+item.get("on_false_refs",[]))
      operator_dict[item.get("id","")] = operator
    elif item_type == "relationship":
      # Either end of the relationship may be the action or the asset.
      action = threat_dict.get(item["source_ref"], threat_dict.get(item["target_ref"],None))
      asset = asset_dict.get(item["target_ref"], asset_dict.get(item["source_ref"],None))
      if isinstance(action, Action) and isinstance(asset, Asset):
        action.addAsset(asset)
      else:
        operator = operator_dict.get(item["source_ref"],None)
        if isinstance(operator, Operator) and isinstance(action, Action):
          # Children are still ids here; they are resolved to objects below.
          operator.addChild(action.unique_id)
    elif item_type == "attack-flow":
      # A missing or null "start_refs" is treated as "no roots".
      root_ids = item.get("start_refs") or []

    else:
      asset = Asset(item.get("type",""),item.get("name",""),item.get("id",""),item.get("description",""))
      if item.get("object_refs", None):
        # Objects carrying "object_refs" are attached directly to the
        # referenced actions and are not stored in asset_dict.
        for ref in item["object_refs"]:
          action = threat_dict.get(ref, None)
          if isinstance(action, Action):
            action.addAsset(asset)
      else:
        asset_dict[item.get("id","")] = asset


  all_objects = threat_dict | operator_dict
  resolve_children(all_objects)

  visited = {}
  for root in root_ids:
    root_action = threat_dict.get(root, operator_dict.get(root))
    if root_action is None:
      # An unknown start ref cannot be levelled; skip it instead of crashing.
      print(f"Warning: start ref {root} not found in object list.")
      continue
    assign_graph_levels(root_action, 0, visited)

  # Actions without a technique id are dropped from the returned mapping.
  cleared_threat_dict= {}

  for key, action in threat_dict.items():
    if action.technique_id != "":
      cleared_threat_dict[key] = action

  return cleared_threat_dict, asset_dict.values()

def readExperimentJson(data):
    """
    Build the experiment graph from the "techniques" list of an outcome file.

    Each item becomes an Action keyed by the first element of its "label"
    value. Every entry of "affected_assets" becomes an Asset whose type,
    name and description are all that entry. Items listed in "prerequisite"
    are looked up among the actions created so far and receive this action
    as a child.

    NOTE(review): ``label[0]`` is the first character when "label" is a
    string - confirm that labels are single characters or lists.

    Args:
      data: Iterable of dicts describing the predicted techniques.

    Returns:
      A tuple ``(actions, assets)``: ``actions`` is a dict label -> Action,
      ``assets`` is the values view of all created assets.
    """
    threat_dict = {}
    asset_dict = {}
    for item in data:
      label =   item.get("label","")[0]
      action = Action(item.get("action_name",""),label,item.get("tactic_id",""),item.get("technique_id",""),item.get("segment",""),[])
      # A missing or null "affected_assets" means "no assets" rather than a crash.
      for index, asset in enumerate(item.get("affected_assets") or []):
        new_asset = Asset(asset,asset,index,asset)
        action.addAsset(new_asset)
        asset_dict[f"{label}_{index}"] = new_asset
      threat_dict[label] = action

      # A missing or null "prerequisite" is treated like an empty list.
      prerequisites = item.get("prerequisite") or []
      if not prerequisites:
        #if there is no parent, its level 0
        action.graph_level = 0
      else:
        for parents in prerequisites:
          parent = threat_dict.get(parents)
          if parent:
            parent.addChild(action.unique_id)
            # With several resolvable parents the last one determines the level.
            action.graph_level = parent.graph_level + 1
    return threat_dict, asset_dict.values()

In [None]:
# @title Helper Functions
def resolve_children(all_objects):
    """
    Replace the child ids stored on every object with the objects themselves.

    Ids that are not keys of ``all_objects`` are reported with a warning and
    dropped from the resulting child list.

    Args:
      all_objects: Mapping id -> object; each object exposes ``children``.
    """
    for holder in all_objects.values():
        linked = []
        for ref in holder.children:
            try:
                target = all_objects[ref]
            except KeyError:
                print(f"Warning: child ID {ref} not found in object list.")
            else:
                linked.append(target)
        holder.children = linked

def assign_graph_levels(node, level=0, visited=None, max_depth=100):
    """
    Recursively write ``graph_level`` onto ``node`` and its descendants.

    Children that are operators, or actions without a technique id, stay on
    the level of their parent; every other child is placed one level deeper.
    A node that is reached again is only re-processed when the new level is
    greater than the one recorded for it.

    Args:
      node: Start node; ``None`` is ignored.
      level: Level assigned to ``node``.
      visited: Dict shared across the recursion that records the level
        assigned to each visited node. It is keyed by object identity,
        because Action equality is based on ``technique_id`` only and would
        otherwise make distinct nodes with the same (or empty) technique id
        look like the same node, skipping their subtrees.
      max_depth: Recursion guard; deeper levels are reported and skipped.
    """
    if visited is None:
        visited = {}

    if node is None:
        return

    if level > max_depth:
        print(f"Max depth exceeded at node {node}")
        return

    node_key = id(node)
    if node_key in visited and visited[node_key] >= level:
        return

    visited[node_key] = level
    node.graph_level = level

    for child in node.children:
        if child is None:
            continue

        is_operator_like = isinstance(child, Operator) or (
            isinstance(child, Action) and not child.technique_id
        )

        next_level = level if is_operator_like else level + 1
        assign_graph_levels(child, next_level, visited, max_depth)


def safe_div(n, d):
    """Return ``n / d`` rounded to three decimals, or 0.0 when ``d`` is falsy."""
    if not d:
        return 0.0
    quotient = n / d
    return round(quotient, 3)

def llm_compare_assets(truth, experiment):
    """
    Ask the LLM to match experiment assets against the truth assets.

    Args:
      truth: List of truth asset descriptions, interpolated into the prompt.
      experiment: List of experiment asset descriptions, interpolated into
        the prompt.

    Returns:
      The model reply as a string, with a Markdown code fence removed and
      the lower-case JSON booleans rewritten as ``True`` / ``False`` so that
      the caller can parse it as a Python literal. The prompt asks for one
      list: a verdict per experiment item plus the number of unmatched
      truth items.
    """
    import re  # local import: the notebook's import cell does not provide it

    model = HF_Qwen_API()
    query = f"For each item in the list of experiment results: {experiment}, compare with the full list of truth: {truth}. For each item in the result list, answer with True if it can be matched to an item in the second list, or False if not. Any item in the second list can only be matched once!. Furthermore examine the second list and respond how many objects have not been matched at all.  Give the result as one json list only stating the answer. Dont respond with any other text but the list."
    response = model.get_full_response(query)
    # Clean the response
    cleaned_response = response.strip()
    # Prefer the content of a complete fenced block (with or without a
    # language tag, even when the model adds text around it); otherwise fall
    # back to removing the plain fence markers.
    fenced = re.search(r"```[A-Za-z0-9_-]*\s*(.*?)\s*```", cleaned_response, re.DOTALL)
    if fenced:
        cleaned_response = fenced.group(1)
    else:
        cleaned_response = cleaned_response.replace("```json\n", "").replace("\n```", "")
    # Whole-word replacement so that text merely containing "true"/"false"
    # is left alone.
    cleaned_response = re.sub(r"\btrue\b", "True", cleaned_response)
    cleaned_response = re.sub(r"\bfalse\b", "False", cleaned_response)
    return cleaned_response


In [None]:
# @title Evaluation
def evaluateResult(truth_actions_dict, truth_assets, experiment_actions_dict, experiment_assets):
    """
    Compare an experiment graph with the ground truth.

    Actions are compared as sets, so they are matched through Action
    equality (technique id). Assets are matched by an LLM on their
    descriptions. Relationships are scored from asset links and graph levels
    of the matched actions, plus all links of unmatched actions.

    Args:
      truth_actions_dict: Mapping id -> Action of the ground truth.
      truth_assets: Iterable of ground-truth Asset objects.
      experiment_actions_dict: Mapping label -> Action of the experiment.
      experiment_assets: Iterable of experiment Asset objects. It is
        iterated more than once, so it must be re-iterable.

    Returns:
      A tuple of three dicts (actions, assets, relationships), each holding
      the total, TP / FP / FN counts and accuracy, recall, precision and F1.
    """
    truth_actions = truth_actions_dict.values()
    experiment_actions = experiment_actions_dict.values()
    truth_set = set(truth_actions)
    experiment_set = set(experiment_actions)

    #Evaluate Actions
    intersection = truth_set & experiment_set
    only_in_truth = truth_set - experiment_set
    only_in_experiment = experiment_set - truth_set
    union = truth_set | experiment_set

    true_positive_actions = len(intersection)
    false_positive_actions = len(only_in_experiment)
    false_negative_actions = len(only_in_truth)
    total_actions = len(union)

    #Evaluate Assets

    # Only assets with a description are sent to the LLM. Keep the filtered
    # list so that the positions of the LLM verdicts can be mapped back to
    # the right asset objects below.
    described_experiment_assets = [asset for asset in experiment_assets if asset.description != ""]
    asset_experiment_description = [asset.description for asset in described_experiment_assets]
    asset_truth_description = [f"{asset.name}: {asset.description}" for asset in truth_assets if asset.description != ""]

    llm_asset_evaluation = ast.literal_eval(llm_compare_assets(asset_truth_description, asset_experiment_description))
    # The last element is taken as the number of truth assets that were
    # never matched; the remaining elements are the per-asset verdicts.
    asset_unused = llm_asset_evaluation.pop()
    asset_result = llm_asset_evaluation

    true_positive_assets = sum(asset_result)
    false_positive_assets = len(asset_result) - true_positive_assets
    false_negative_assets = asset_unused
    total_assets = true_positive_assets + false_positive_assets + false_negative_assets

    #Evaluate Relationships
    true_positive_relationships = 0
    false_positive_relationships = 0
    false_negative_relationships = 0

    # Graph level of the most recent action whose level was credited.
    prev_correct_prediction = 0

    for given_action in truth_actions:
        for predicted_action in experiment_actions:
            if predicted_action == given_action:
                #Asset Relationship calculation
                for predicted_asset in predicted_action.assets:
                    # Index into the filtered list: asset_result[i] is the
                    # verdict for described_experiment_assets[i].
                    for asset_index, verified_asset in enumerate(described_experiment_assets):
                        if predicted_asset is verified_asset:
                            if 0 <= asset_index < len(asset_result) and asset_result[asset_index] == True:
                                true_positive_relationships += 1
                                break
                #Action Relationship calculation
                predicted_graph_level = predicted_action.graph_level
                given_graph_level = given_action.graph_level
                if predicted_graph_level == given_graph_level:
                    prev_correct_prediction = predicted_graph_level
                    true_positive_relationships += 1
                elif predicted_graph_level > prev_correct_prediction:
                    # A different level is still credited when it lies
                    # beyond the last credited level.
                    prev_correct_prediction = predicted_graph_level
                    true_positive_relationships += 1
                # Only the first matching experiment action is considered.
                break

    for action in truth_set:
        #If the action is only in the truth, all connections from it are missing
        if action in only_in_truth:
            false_negative_relationships += len(action.children)
            false_negative_relationships += len(action.assets)
    for action in experiment_set:
        #If the action is only in the experiment, all connections are wrong aswell
        if action in only_in_experiment:
            false_positive_relationships += len(action.children)
            false_positive_relationships += len(action.assets)

    total_relationships = true_positive_relationships + false_positive_relationships + false_negative_relationships

    return ({
        "Total Actions": total_actions,
        "True Positive Actions": true_positive_actions,         # True Prediction
        "False Positive Actions": false_positive_actions,       # Existing in Experiment, not in Ground Truth
        "False Negative Actions": false_negative_actions,       # Existing in Ground Truth, not in Experiment

        "Action Accuracy": safe_div(true_positive_actions, total_actions),
        "Action Recall": safe_div(true_positive_actions, true_positive_actions + false_negative_actions),
        "Action Precision": safe_div(true_positive_actions, true_positive_actions + false_positive_actions),
        "Action F1 Score": safe_div(2 * true_positive_actions, 2 * true_positive_actions + false_positive_actions + false_negative_actions),
    },
    {
        "Total Assets": total_assets,
        "True Positive Assets": true_positive_assets,          # Correct matches
        "False Positive Assets": false_positive_assets,        # In experiment, not in truth
        "False Negative Assets": false_negative_assets,        # In truth, not in experiment

        "Asset Accuracy": safe_div(true_positive_assets, total_assets),
        "Asset Recall": safe_div(true_positive_assets, true_positive_assets + false_negative_assets),
        "Asset Precision": safe_div(true_positive_assets, true_positive_assets + false_positive_assets),
        "Asset F1 Score": safe_div(2 * true_positive_assets, 2 * true_positive_assets + false_positive_assets + false_negative_assets),
    },
    {
        "Total Relationships": total_relationships,
        "True Positive Relationships": true_positive_relationships,         # Correct matches
        "False Positive Relationships": false_positive_relationships,       # In experiment, not in truth
        "False Negative Relationships": false_negative_relationships,       # In truth, not in experiment

        "Relationship Accuracy": safe_div(true_positive_relationships, total_relationships),
        "Relationship Recall": safe_div(true_positive_relationships, true_positive_relationships + false_negative_relationships),
        "Relationship Precision": safe_div(true_positive_relationships, true_positive_relationships + false_positive_relationships),
        "Relationship F1 Score": safe_div(2 * true_positive_relationships, 2 * true_positive_relationships + false_positive_relationships + false_negative_relationships),
    })

In [None]:
# Driver cell: upload one ground-truth file and any number of experiment
# files, evaluate each experiment against the truth, then print one table
# per file and a table of averages.
print("📥 Upload MITRE Ground Truth")
uploaded_truth = files.upload()
# Only the first uploaded file is used as ground truth.
filename_truth = list(uploaded_truth.keys())[0]

with open(filename_truth, 'r') as f:
    data = json.load(f)
    truth_actions, truth_assets = readTruthJson(data["objects"])

print("📥 Upload one or more Experiment Outcome JSONs")
uploaded_experiments = files.upload()

all_results = {}

# Defined once up front so that both the per-file tables and the average
# table use it, and so the accumulators below are sized from it.
headers = ["Type", "TP", "FP", "FN", "Accuracy", "Precision", "Recall", "F1 Score"]
metric_count = len(headers) - 1  # every column except "Type"

metrics_sum = {
    "Action": [0] * metric_count,
    "Asset": [0] * metric_count,
    "Relationship": [0] * metric_count
}
file_count = 0

for filename in uploaded_experiments:
    print(f"\n🔍 Comparing with: {filename}")
    with open(filename, 'r') as f:
        # Best effort per file: a failing file is reported and skipped so
        # that the remaining uploads are still evaluated.
        try:
            data = json.load(f)
            experiment_actions, experiment_assets = readExperimentJson(data["techniques"])

            action_result, asset_result, relationship_result = evaluateResult(
                truth_actions, truth_assets, experiment_actions, experiment_assets
            )

            all_results[filename] = {
                "action_result": action_result,
                "asset_result": asset_result,
                "relationship_result": relationship_result
            }

            print(f"✅ Comparison complete for {filename}")
            file_count += 1

        except Exception as e:
            print(f"❌ Failed to process {filename}: {e}")

print("\n📊 Summary of all comparisons:")
for name, result in all_results.items():
    print(f"\n📁 {name}")

    rows = []
    for key, label in [("action_result", "Action"), ("asset_result", "Asset"), ("relationship_result", "Relationship")]:
        res = result[key]
        row = [
            label,
            res.get(f"True Positive {label}s", 0),
            res.get(f"False Positive {label}s", 0),
            res.get(f"False Negative {label}s", 0),
            round(res.get(f"{label} Accuracy", 0), 3),
            round(res.get(f"{label} Precision", 0), 3),
            round(res.get(f"{label} Recall", 0), 3),
            round(res.get(f"{label} F1 Score", 0), 3),
        ]
        rows.append(row)

        # Accumulate everything except the leading "Type" label.
        metrics_sum[label] = [sum_val + row_val for sum_val, row_val in zip(metrics_sum[label], row[1:])]

    print(tabulate(rows, headers=headers, tablefmt="grid"))

if file_count > 0:
    avg_rows = []
    for label in ["Action", "Asset", "Relationship"]:
        avg = [round(val / file_count, 3) for val in metrics_sum[label]]
        avg_rows.append([label] + avg)

    print("\n📊 Average Metrics Across All Files:")
    print(tabulate(avg_rows, headers=headers, tablefmt="grid"))


📥 Upload MITRE Ground Truth


Saving SWIFT Heist.json to SWIFT Heist.json
📥 Upload one or more Experiment Outcome JSONs


Saving Swift_Gemini_1.json to Swift_Gemini_1.json
Saving Swift_Gemini_2.json to Swift_Gemini_2.json
Saving Swift_Gemini_3.json to Swift_Gemini_3.json
Saving Swift_Gemini_4.json to Swift_Gemini_4.json
Saving Swift_Gemini_5.json to Swift_Gemini_5.json

🔍 Comparing with: Swift_Gemini_1.json
✅ Comparison complete for Swift_Gemini_1.json

🔍 Comparing with: Swift_Gemini_2.json
✅ Comparison complete for Swift_Gemini_2.json

🔍 Comparing with: Swift_Gemini_3.json
✅ Comparison complete for Swift_Gemini_3.json

🔍 Comparing with: Swift_Gemini_4.json
✅ Comparison complete for Swift_Gemini_4.json

🔍 Comparing with: Swift_Gemini_5.json
✅ Comparison complete for Swift_Gemini_5.json

📊 Summary of all comparisons:

📁 Swift_Gemini_1.json
+--------------+------+------+------+------------+-------------+----------+------------+
| Type         |   TP |   FP |   FN |   Accuracy |   Precision |   Recall |   F1 Score |
| Action       |    2 |    4 |    4 |      0.2   |       0.333 |    0.333 |      0.333 |
+---