In [8]:
import os
import re
import random
import pickle
import xml.etree.ElementTree as ET
from collections import defaultdict, Counter
import multiprocessing
from multiprocessing import cpu_count
import time
import sys

# Configuration
# Folder scanned for input files whose names start with "pid" and end with ".graphml".
DATASET_FOLDER = "/Users/dhritichandan/Desktop/SFILES_2/Dhriti_Test/new_dataset/graphml_10k_dataset"
# Path of the pickle file the aggregated models are written to.
OUTPUT_FILE = "pid_models.pkl"
# Cap on how many successors of a node are followed when counting transitions;
# longer successor lists are randomly sampled down to this size.
MAX_NEIGHBORS_TO_SAMPLE = 75
# Number of most frequent three-node paths stored under "top_paths".
NUM_TOP_PATHS_TO_FIND = 5

def get_node_type(node_id):
    """
    Derive a generic node type from a node ID.

    Every digit and hyphen is stripped from ``node_id``; if nothing is left
    afterwards, the placeholder ``"unknown"`` is returned instead.
    """
    stripped = re.sub(r'[-\d]+', '', node_id)
    if not stripped:
        return "unknown"
    return stripped

def create_blanket_entry():
    """
    Build a fresh blanket record: one empty Counter each for "parents",
    "children" and "spouses".

    Kept as a named module-level function (not a lambda) so that a
    defaultdict using it as its factory can be pickled.
    """
    return {role: Counter() for role in ("parents", "children", "spouses")}

def _sample_neighbors(neighbors):
    """Return `neighbors` as-is, or a random subset capped at MAX_NEIGHBORS_TO_SAMPLE."""
    if len(neighbors) > MAX_NEIGHBORS_TO_SAMPLE:
        return random.sample(neighbors, MAX_NEIGHBORS_TO_SAMPLE)
    return neighbors


def _parse_graph(file_path):
    """Parse one GraphML file into (nodes, outgoing, incoming).

    `nodes` maps node ID -> generalized node type; `outgoing` / `incoming`
    map a node ID to the list of its successor / predecessor IDs.
    """
    root = ET.parse(file_path).getroot()
    ns = {"graphml": "http://graphml.graphdrawing.org/xmlns"}

    nodes = {}
    for node in root.findall(".//graphml:node", ns):
        n_id = node.attrib["id"]
        nodes[n_id] = get_node_type(n_id)

    outgoing = defaultdict(list)
    incoming = defaultdict(list)
    for edge in root.findall(".//graphml:edge", ns):
        source = edge.attrib["source"]
        target = edge.attrib["target"]
        # Ignore edges that reference a node not declared in this file.
        if source in nodes and target in nodes:
            outgoing[source].append(target)
            incoming[target].append(source)

    return nodes, outgoing, incoming


def _count_transitions(nodes, outgoing):
    """Count, for every A -> B -> C chain, how often type C follows (type A, type B)."""
    transitions = defaultdict(Counter)
    for node_a, successors_a in outgoing.items():
        for node_b in _sample_neighbors(successors_a):
            # .get() rather than [] so the defaultdict is not mutated mid-iteration.
            successors_b = outgoing.get(node_b, [])
            if not successors_b:
                continue
            context = (nodes[node_a], nodes[node_b])
            for node_c in _sample_neighbors(successors_b):
                transitions[context][nodes[node_c]] += 1
    return transitions


def _count_blanket(nodes, outgoing, incoming):
    """Count parent, child and spouse types per node type."""
    blanket = defaultdict(create_blanket_entry)
    for node_id, node_type in nodes.items():
        for parent in incoming.get(node_id, []):
            blanket[node_type]["parents"][nodes[parent]] += 1

        for child in outgoing.get(node_id, []):
            blanket[node_type]["children"][nodes[child]] += 1
            # Spouses are the other parents of this child.
            for spouse in incoming.get(child, []):
                if spouse != node_id:
                    blanket[node_type]["spouses"][nodes[spouse]] += 1
    return blanket


def _find_pump_paths(nodes, incoming):
    """Collect (grandparent type, parent type, node type) for nodes whose type contains 'pp' or 'pump'."""
    paths = []
    for node_id, node_type in nodes.items():
        lowered = node_type.lower()
        if 'pp' in lowered or 'pump' in lowered:
            # Walk two steps upstream of the matching node.
            for parent in incoming.get(node_id, []):
                for grandparent in incoming.get(parent, []):
                    paths.append((nodes[grandparent], nodes[parent], node_type))
    return paths


def process_single_file(file_path):
    """
    Worker function that parses AND analyzes a single GraphML file.

    Args:
        file_path: Path of the GraphML file to process.

    Returns:
        A tuple ``(transitions, blanket, paths)`` for this one file:
        ``transitions`` is a defaultdict mapping a (type A, type B) context to
        a Counter of following types; ``blanket`` is a defaultdict mapping a
        node type to its parents/children/spouses Counters; ``paths`` is a list
        of 3-tuples of node types ending at a pump-like node.

        Processing is best-effort: if anything fails, the failure is reported
        on stderr (previously it was swallowed silently) and empty structures
        are returned so one bad file does not abort the whole run.
    """
    try:
        nodes, outgoing, incoming = _parse_graph(file_path)
        return (
            _count_transitions(nodes, outgoing),
            _count_blanket(nodes, outgoing, incoming),
            _find_pump_paths(nodes, incoming),
        )
    except Exception as exc:
        # Leading newline keeps the message off the caller's '\r' progress line.
        print(f"\nWarning: skipped {file_path}: {exc!r}", file=sys.stderr)
        return (defaultdict(Counter), defaultdict(create_blanket_entry), [])

def build_and_save_models():
    """
    Analyze every matching GraphML file in DATASET_FOLDER and pickle the result.

    Files whose names start with "pid" and end with ".graphml" are handed to
    process_single_file in a process pool (one worker per CPU, 'fork' start
    method). The per-file results are merged as they arrive, a progress bar is
    drawn on stdout, and the merged models are written to OUTPUT_FILE as a dict
    with the keys "transitions", "blanket_model" and "top_paths".

    Returns None. If the dataset folder is missing, or an exception is raised
    during parallel processing, a message is printed and nothing is saved.
    """
    print(f"--- Starting Optimized Processing ---")
    started_at = time.time()

    if not os.path.exists(DATASET_FOLDER):
        print(f"Error: Folder {DATASET_FOLDER} not found.")
        return

    pid_files = []
    for entry in os.listdir(DATASET_FOLDER):
        if entry.startswith("pid") and entry.endswith(".graphml"):
            pid_files.append(os.path.join(DATASET_FOLDER, entry))

    total_files = len(pid_files)
    print(f"Found {total_files} files. Analyzing in parallel...")

    # Accumulators for the results of all files.
    merged_transitions = defaultdict(Counter)
    merged_blanket = defaultdict(create_blanket_entry)  # named factory keeps it picklable
    path_counts = Counter()
    bar_width = 30

    try:
        mp_context = multiprocessing.get_context('fork')
        with mp_context.Pool(processes=cpu_count()) as pool:
            # Results are consumed in completion order, not submission order.
            results = pool.imap_unordered(process_single_file, pid_files)
            for done, (file_trans, file_blanket, file_paths) in enumerate(results, 1):
                # Merge transition counts.
                for context, counts in file_trans.items():
                    merged_transitions[context].update(counts)

                # Merge blanket counts, one relation at a time.
                for n_type, stats in file_blanket.items():
                    target = merged_blanket[n_type]
                    for role in ("parents", "children", "spouses"):
                        target[role].update(stats[role])

                # Merge path candidates.
                path_counts.update(file_paths)

                # Redraw the progress bar in place.
                filled = bar_width * done // total_files
                bar = '█' * filled + '-' * (bar_width - filled)
                percent = (done / total_files) * 100
                sys.stdout.write(f'\rProcessing: |{bar}| {percent:.1f}% ({done}/{total_files})')
                sys.stdout.flush()

        print("\nAggregation complete.")

    except Exception as e:
        print(f"\n⚠️ Error during processing: {e}")
        return

    # Keep only the most frequent paths, each converted from tuple to list.
    top_paths = [
        list(path)
        for path, _ in path_counts.most_common(NUM_TOP_PATHS_TO_FIND)
    ]

    print(f"Saving models to {OUTPUT_FILE}...")
    models = {
        "transitions": merged_transitions,
        "blanket_model": dict(merged_blanket),
        "top_paths": top_paths
    }
    with open(OUTPUT_FILE, "wb") as f:
        pickle.dump(models, f)

    elapsed = time.time() - started_at
    print(f"--- Done in {elapsed:.2f} seconds ---")
    print(f"Run 'streamlit run pid_app.py' now.")

# Build and save the models only when this file is run directly, not on import.
if __name__ == "__main__":
    build_and_save_models()

--- Starting Optimized Processing ---
Found 10000 files. Analyzing in parallel...
Processing: |██████████████████████████████| 100.0% (10000/10000)
Aggregation complete.
Saving models to pid_models.pkl...
--- Done in 12.05 seconds ---
Run 'streamlit run pid_app.py' now.
