# In [None]:
from typing import Dict, List, Tuple, DefaultDict
import xml.etree.ElementTree as ET
from collections import defaultdict
import os


# In [None]:
def _strip_xml_namespaces(root: ET.Element) -> None:
    """Remove ``{namespace}`` prefixes from every element tag under *root*, in place.

    When a PNML file declares a default XML namespace, ElementTree reports tags
    as ``{uri}transition`` etc., which the plain ``.//transition`` style paths
    used by ``parse_pnml`` would silently fail to match.
    """
    for elem in root.iter():
        tag = elem.tag
        # Comments / processing instructions (if present) have non-str tags.
        if isinstance(tag, str) and tag.startswith("{"):
            elem.tag = tag.split("}", 1)[1]


def _pnml_label(elem: ET.Element, default: str) -> str:
    """Return the stripped ``<name>/<text>`` label of *elem*, or *default*.

    *default* is used when there is no ``name`` element, no ``text`` element
    inside it, or the text is empty / whitespace-only.
    """
    name_elem = elem.find(".//name")
    if name_elem is None:
        return default
    text_elem = name_elem.find(".//text")
    if text_elem is None or not text_elem.text:
        return default
    return text_elem.text.strip() or default


def _unique_transition_name(name: str, trans_id: str, taken: Dict[str, str]) -> str:
    """Return a display name for *trans_id* that is not already a key of *taken*.

    The first occurrence keeps *name*. Later ones get ``name_<last 3 chars of id>``.
    If that also collides, they get ``name_<full id>``, and then a numeric counter.
    """
    if name not in taken:
        return name
    candidate = f"{name}_{trans_id[-3:]}"
    if candidate not in taken:
        return candidate
    candidate = f"{name}_{trans_id}"
    counter = 2
    while candidate in taken:
        candidate = f"{name}_{trans_id}_{counter}"
        counter += 1
    return candidate


def parse_pnml(file_path: str) -> Tuple[
    Dict[str, str],                # transitions
    Dict[str, str],                # places
    DefaultDict[str, List[str]],   # arcs
    DefaultDict[str, List[str]],   # place_to_transition
    DefaultDict[str, List[str]],   # transition_to_place
    DefaultDict[str, List[str]]    # base_transition_map
]:
    """Parse a PNML (Petri Net Markup Language) file and extract its structural elements.

    Reads a PNML file and extracts transitions, places and arcs, together with
    the place/transition adjacency and role names attached to transitions.

    Args:
        file_path (str): Path to the PNML file to be parsed.

    Returns:
        tuple: A 6-tuple containing:
            - transitions (dict): Mapping of transition IDs to their display names
              (with role names appended if present)
            - places (dict): Mapping of place IDs to their names
            - arcs (defaultdict(list)): Mapping of source IDs to lists of target IDs
            - place_to_transition (defaultdict(list)): Mapping of place IDs to lists of target transition IDs
            - transition_to_place (defaultdict(list)): Mapping of transition IDs to lists of target place IDs
            - base_transition_map (defaultdict(list)): Mapping of base transition names
              (the raw label, before de-duplication or role decoration) to the IDs sharing it

    Features:
        - Handles transitions and places with or without explicit names
        - Accepts files with or without a default XML namespace
        - Makes duplicate transition names unique by adding suffixes
        - Extracts and incorporates role names from transitionResource elements

    Example:
        >>> file_path = "path/to/process_model.pnml"
        >>> (transitions, places, arcs,
        ...  place_to_transition, transition_to_place,
        ...  base_transition_map) = parse_pnml(file_path)
        >>> print(f"Found {len(transitions)} transitions and {len(places)} places")
        Found 10 transitions and 12 places

    Note:
        - Transition names with roles are formatted as "name [rolename: role]"
        - If a transition or place has no (non-blank) name, its ID is used as the name
        - Duplicate transition names get "_" plus the last 3 characters of their ID.
          If that still collides, the full ID is used, followed by a counter if needed.
    """
    tree = ET.parse(file_path)
    root = tree.getroot()
    _strip_xml_namespaces(root)

    transitions = {}
    places = {}
    arcs = defaultdict(list)
    place_to_transition = defaultdict(list)
    transition_to_place = defaultdict(list)
    base_transition_map = defaultdict(list)

    # Display names handed out so far (before role decoration) -> transition ID.
    transition_names = {}

    for transition in root.findall(".//transition"):
        trans_id = transition.get("id")

        base_name = _pnml_label(transition, trans_id)
        name = _unique_transition_name(base_name, trans_id, transition_names)
        transition_names[name] = trans_id
        base_transition_map[base_name].append(trans_id)

        # Role decoration is appended after the uniqueness check, so duplicates
        # are detected on the bare activity label.
        res_elem = transition.find(".//transitionResource")
        role = res_elem.attrib.get("roleName") if res_elem is not None else None
        if role:
            name += f" [rolename: {role}]"
        transitions[trans_id] = name

    for place in root.findall(".//place"):
        place_id = place.get("id")
        places[place_id] = _pnml_label(place, place_id)

    # Places and transitions are fully collected before arcs are classified,
    # so membership tests below see every node.
    for arc in root.findall(".//arc"):
        source = arc.get("source")
        target = arc.get("target")

        arcs[source].append(target)

        if source in places and target in transitions:
            place_to_transition[source].append(target)
        elif source in transitions and target in places:
            transition_to_place[source].append(target)

    return transitions, places, arcs, place_to_transition, transition_to_place, base_transition_map

def find_start_transitions(transitions, places, place_to_transition, transition_to_place):
    """Return the IDs of transitions that no other transition can enable.

    A transition is a start transition when none of its input places is an
    output place of any transition. Transitions with no input places at all
    therefore count as start transitions.

    Args:
        transitions: Mapping of transition ID -> label; its iteration order
            determines the order of the result.
        places: Unused; kept for signature compatibility.
        place_to_transition: Mapping of place ID -> list of consuming transition IDs.
        transition_to_place: Mapping of transition ID -> list of produced place IDs.

    Returns:
        list: Start transition IDs, in ``transitions`` order.
    """
    # Every place that some transition writes to. Computing this once replaces
    # the per-transition rescans of transition_to_place (previously O(T*P*T)).
    produced_places = {
        p for out_places in transition_to_place.values() for p in out_places
    }
    # Transitions that consume from at least one produced place are reachable
    # from another transition, so they cannot be start transitions.
    fed_by_transition = {
        t
        for p, t_list in place_to_transition.items()
        if p in produced_places
        for t in t_list
    }
    return [t_id for t_id in transitions if t_id not in fed_by_transition]

def detect_structure(transitions, places, place_to_transition, transition_to_place, base_transition_map, start_nodes=None):
    """Render the control-flow structure of a parsed Petri net as indented text.

    Performs a depth-first traversal over transitions, stepping from each
    transition through its output places to the transitions those places feed.
    One line is emitted per pattern, indented four spaces per traversal depth:

    - ``{ P1, P2 } ->? T (XOR-join)``: T has several input places and is_xor_join(T)
    - ``{ P1, P2 } -> T (AND-join)``: T has several input places otherwise
    - ``T ->{ U, V } (AND-split)``: T has several output places
    - ``T ->? { U, V } (XOR-split)``: T has one output place feeding several transitions
    - ``T -> U``: single successor (omitted when U has several input places,
      since U's own join line is emitted instead)
    - ``T -> End``: no successor
    - ``T (loop)``: T is reached again while still on the current DFS path

    Args:
        transitions: Mapping of transition ID -> display label.
        places: Mapping of place ID -> label (used for join lines).
        place_to_transition: Mapping of place ID -> consuming transition IDs.
        transition_to_place: Mapping of transition ID -> produced place IDs.
        base_transition_map: Mapping of base label -> transition IDs sharing it.
        start_nodes: Transition IDs to start from. If falsy, the start
            transitions are discovered with ``find_start_transitions``,
            falling back to the first transition.

    Returns:
        str: The structure lines joined with newlines.
    """
    structure = []
    visited = set()
    rec_stack = set()

    # Inverse of place_to_transition: transition ID -> its input places, in
    # place_to_transition order and without duplicates. Built once instead of
    # rescanning every place on each lookup.
    incoming_index = {}
    for pl, trans_list in place_to_transition.items():
        for t in trans_list:
            bucket = incoming_index.setdefault(t, [])
            if pl not in bucket:
                bucket.append(pl)

    # Transition ID -> base label. Deriving the base label by splitting the
    # display name on "_" broke labels that natively contain underscores.
    base_of = {}
    for base, ids in base_transition_map.items():
        for tid in ids:
            base_of.setdefault(tid, base)

    def get_next_transitions(trans_id):
        next_ts = []
        for pl in transition_to_place.get(trans_id, []):
            next_ts.extend(place_to_transition.get(pl, []))
        # Order-preserving dedupe: set() ordering varies with string hash
        # randomisation, which made the output differ from run to run.
        return list(dict.fromkeys(next_ts))

    def get_incoming_places(trans_id):
        return list(incoming_index.get(trans_id, ()))

    def is_xor_join(trans_id):
        base_name = base_of.get(trans_id)
        if base_name is None:
            # Fallback for maps that do not list this transition: strip the
            # " [rolename: ...]" decoration and any "_<suffix>" dedup tag.
            base_name = transitions[trans_id].split(' [')[0].split('_')[0]
        related_trans = base_transition_map.get(base_name, [])
        if len(related_trans) <= 1:
            return False

        incoming_places = get_incoming_places(trans_id)
        if len(incoming_places) <= 1:
            return False

        # NOTE(review): when base_name comes from base_of, trans_id is itself
        # in related_trans and in every incoming place's target list, so this
        # check always passes. It only matters on the fallback path above.
        for place in incoming_places:
            place_targets = place_to_transition[place]
            if not any(t in related_trans for t in place_targets):
                return False

        return True

    # Recursive DFS; very long chains may hit Python's recursion limit.
    def traverse(node, depth=0):
        if node in rec_stack:
            structure.append("    " * depth + f"{transitions[node]} (loop)")
            return

        if node in visited:
            return

        rec_stack.add(node)

        next_nodes = get_next_transitions(node)
        transition_labels = [transitions.get(t, t) for t in next_nodes]

        incoming_places = get_incoming_places(node)
        outgoing_places = transition_to_place.get(node, [])

        if len(incoming_places) > 1:
            place_names = [places.get(pl, pl) for pl in incoming_places]
            if is_xor_join(node):
                structure.append("    " * depth + f"{{ {', '.join(place_names)} }} ->? {transitions[node]} (XOR-join)")
            else:
                structure.append("    " * depth + f"{{ {', '.join(place_names)} }} -> {transitions[node]} (AND-join)")

        if len(outgoing_places) > 1:
            out_trans = []
            for p in outgoing_places:
                out_trans.extend(place_to_transition.get(p, []))
            out_trans = list(dict.fromkeys(out_trans))
            structure.append(
                "    " * depth
                + f"{transitions[node]} ->{{ {', '.join(transitions[t] for t in out_trans)} }} (AND-split)"
            )
        elif len(next_nodes) > 1:
            structure.append(
                "    " * depth
                + f"{transitions[node]} ->? {{ {', '.join(transition_labels)} }} (XOR-split)"
            )
        elif len(next_nodes) == 1:
            nxt = next_nodes[0]
            # A successor with several inputs prints its own join line, so the
            # plain "T -> U" edge is suppressed here.
            if len(get_incoming_places(nxt)) <= 1:
                structure.append("    " * depth + f"{transitions[node]} -> {transitions[nxt]}")
        else:
            structure.append("    " * depth + f"{transitions[node]} -> End")

        for target in next_nodes:
            traverse(target, depth + 1)

        rec_stack.remove(node)
        visited.add(node)

    if not start_nodes:
        discovered_starts = find_start_transitions(transitions, places, place_to_transition, transition_to_place)
        if discovered_starts:
            start_nodes = discovered_starts
        else:
            start_nodes = [next(iter(transitions), None)]

    for s in start_nodes:
        if s is not None:
            traverse(s)
    return "\n".join(structure)



# In [None]:
# Example 1: Process a single PNML file
file_path = r".pnml"  # Replace with your PNML file path

# The report is written next to the input as "<stem>_output.txt".
base_name = os.path.splitext(os.path.basename(file_path))[0]
output_file = os.path.join(os.path.dirname(file_path), base_name + "_output.txt")

# Parse the PNML file into its structural mappings.
(
    transitions,
    places,
    arcs,
    place_to_transition,
    transition_to_place,
    base_transition_map,
) = parse_pnml(file_path)

# Transitions that no other transition can enable are the entry points.
start_trans_list = find_start_transitions(transitions, places, place_to_transition, transition_to_place)
start_trans_output = f"Detected start transitions: {[transitions[sid] for sid in start_trans_list]}"

# Walk the net from those entry points and describe its splits/joins.
process_structure = detect_structure(
    transitions=transitions,
    places=places,
    place_to_transition=place_to_transition,
    transition_to_place=transition_to_place,
    base_transition_map=base_transition_map,
    start_nodes=start_trans_list,
)

full_output = f"{start_trans_output}\n\n[Process Structure]\n{process_structure}"

# Echo to the console first, then persist the identical text.
print(full_output)
with open(output_file, "w", encoding="utf-8") as f:
    f.write(full_output)