In [6]:
import json
import logging
import os
import random
import time
import pickle
from collections import defaultdict
from typing import List, Dict, Any, Optional, Set, Tuple, DefaultDict
from tqdm import tqdm

# Import rdflib for parsing the RDF graph
import rdflib
from rdflib import Graph, URIRef, Literal, BNode
from rdflib.namespace import RDF, RDFS



# --- Utility Functions (from original src.utils) ---
# We include these directly to make the script self-contained
# and remove external dependencies.


def is_inv_rel(rel: str) -> bool:
    """Return True when ``rel`` ends with the ``#R`` inverse-direction marker."""
    inverse_suffix = "#R"
    return rel.endswith(inverse_suffix)


def get_inv_rel(rel: str) -> str:
    """
    Flip the direction marker of a relation name.

    A name ending in ``#R`` has that two-character suffix stripped; any other
    name gets ``#R`` appended.
    """
    return rel[:-2] if is_inv_rel(rel) else f"{rel}#R"


def get_readable_class(cls: str, schema: Optional[Dict[str, Any]] = None) -> str:
    """
    Return a display name for a class.

    The schema's ``description`` for the class is used when a non-empty schema
    is given and has one; otherwise the text after the last ``.`` in ``cls``.
    """
    if schema:
        class_table = schema["classes"]
        if cls in class_table:
            class_entry = class_table[cls]
            if "description" in class_entry:
                return class_entry["description"]
    # No usable description: fall back to the last dotted segment.
    return cls.rpartition(".")[2]


def get_readable_relation(rel: str, schema: Optional[Dict[str, Any]] = None) -> str:
    """
    Return a display name for a relation.

    The schema's ``description`` for the relation is used when a non-empty
    schema is given and has one; otherwise the text after the last ``.`` in
    ``rel``.
    """
    if schema:
        relation_table = schema["relations"]
        if rel in relation_table:
            relation_entry = relation_table[rel]
            if "description" in relation_entry:
                return relation_entry["description"]
    # No usable description: fall back to the last dotted segment.
    return rel.rpartition(".")[2]


def get_reverse_relation(rel: str, schema: Dict[str, Any]) -> Optional[str]:
    """
    Look up the ``reverse`` entry the schema records for ``rel``.

    Returns None when the relation is not in the schema or has no ``reverse``
    entry.
    """
    relation_entry = schema["relations"].get(rel, {})
    return relation_entry.get("reverse")


def get_reverse_readable_relation(rel: str, schema: Dict[str, Any]) -> Optional[str]:
    """
    Return the ``description`` of the relation that the schema lists as the
    reverse of ``rel``.

    Returns None when there is no reverse relation, when the reverse relation
    is not itself in the schema, or when it has no ``description``.
    """
    rev_rel = get_reverse_relation(rel, schema)
    if not rev_rel or rev_rel not in schema["relations"]:
        return None
    return schema["relations"][rev_rel].get("description")


def get_nodes_by_class(
    nodes: List[Dict[str, Any]], cls: str, except_nid: Optional[List[int]] = None
) -> List[Dict[str, Any]]:
    """
    Select, in input order, the nodes whose ``class`` equals ``cls``.

    Nodes whose ``nid`` appears in ``except_nid`` are left out; passing None
    excludes nothing.
    """
    excluded = [] if except_nid is None else except_nid
    matches = []
    for node in nodes:
        if node["class"] != cls:
            continue
        if node["nid"] in excluded:
            continue
        matches.append(node)
    return matches


def get_non_literals(
    nodes: List[Dict[str, Any]], except_nid: Optional[Set[int]] = None
) -> List[Dict[str, Any]]:
    """
    Select, in input order, the nodes whose ``class`` does not start with
    ``type.`` (the prefix this file uses for literal types).

    Nodes whose ``nid`` is in ``except_nid`` are left out; passing None
    excludes nothing.
    """
    skipped = set() if except_nid is None else except_nid

    def is_kept(node: Dict[str, Any]) -> bool:
        return node["nid"] not in skipped and not node["class"].startswith("type.")

    return list(filter(is_kept, nodes))


def legal_class(cls: str) -> bool:
    """Return True unless ``cls`` is a literal type, i.e. starts with ``type.``."""
    is_literal_type = cls.startswith("type.")
    return not is_literal_type


def legal_relation(rel: str) -> bool:
    """
    Hook for excluding relations from sampling.

    Currently a placeholder: ``rel`` is ignored and every relation is accepted.
    """
    # Add relation-specific filtering here when it is needed.
    return True


def graph_query_to_sexpr(*args, **kwargs) -> str:
    """
    Stand-in for the real s-expression converter.

    All arguments are ignored: a warning is logged on the root logger and the
    same fixed string is returned for every graph query. Replace this with the
    full ``graph_query_to_sexpr`` definition from ``src.utils.parser``.

    NOTE: because the result is identical for every query, the duplicate check
    in ``Explorer.explore`` treats every query after the first as already
    walked, so at most one walk can be collected while this stub is in use.
    """
    placeholder_sexpr = "(PlaceholderSExpression)"
    logging.warning("Using placeholder function for graph_query_to_sexpr")
    return placeholder_sexpr


def graph_query_to_sparql(*args, **kwargs) -> str:
    """
    Stand-in for the real SPARQL converter.

    All arguments are ignored: a warning is logged on the root logger and the
    same fixed query string is returned every time. Replace this with the full
    ``graph_query_to_sparql`` definition from ``src.utils.parser``.
    """
    placeholder_sparql = "SELECT ?x WHERE { ?x ?y ?z . } # (Placeholder SPARQL)"
    logging.warning("Using placeholder function for graph_query_to_sparql")
    return placeholder_sparql


# --- End of Utility Functions ---


In [7]:

# Hard-coded literal_map (as it was an external dependency).
# Maps each schema literal type name ("type.*") to the full XSD datatype URI
# that is appended after "^^" when a literal node id is built.
# "type.text" is an alias of "type.string" and "type.int" of "type.integer".
# Lookups elsewhere in this file index the dict directly, so a literal type
# used by the schema but missing here raises KeyError.
literal_map: Dict[str, str] = {
    "type.string": "http://www.w3.org/2001/XMLSchema#string",
    "type.text": "http://www.w3.org/2001/XMLSchema#string",
    "type.datetime": "http://www.w3.org/2001/XMLSchema#dateTime",
    "type.integer": "http://www.w3.org/2001/XMLSchema#int",
    "type.int": "http://www.w3.org/2001/XMLSchema#int",
    "type.float": "http://www.w3.org/2001/XMLSchema#float",
    "type.boolean": "http://www.w3.org/2001/XMLSchema#boolean",
}


# Setup basic logging: timestamped INFO-and-above messages on the root logger.
# NOTE(review): logging.basicConfig is documented to do nothing when the root
# logger already has handlers (unless force=True is passed), so in a kernel
# where logging was configured earlier this call may have no effect -- confirm.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
# Module-level logger used by Explorer for progress and diagnostic messages.
logger = logging.getLogger(__name__)


class Explorer:
    """
    Loads a pre-built RDF graph and its JSON schema to perform
    random schema-guided traversals (walks).
    """

    def __init__(self, kg_name: str):
        self.kg_name: str = kg_name
        self.schema: Optional[Dict[str, Any]] = None
        self.schema_dr: Dict[str, Tuple[str, str]] = {}
        self.classes: Set[str] = set()

        # In-memory representation of the graph and schema
        self.out_relations_cls: DefaultDict[str, set] = defaultdict(set)
        self.in_relations_cls: DefaultDict[str, set] = defaultdict(set)
        self.cls_2_entid: DefaultDict[str, set] = defaultdict(set)
        self.entid_2_cls_ent: Dict[str, Dict[str, Any]] = {}
        self.literals_by_cls_rel: DefaultDict[Tuple[str, str], set] = defaultdict(set)

    def load_graph_and_schema(
        self,
        schema_fpath: str,
        rdf_fpath: str,
        processed_fpath: Optional[str] = None,
        use_cache: bool = True,
    ):
        """
        Loads the JSON schema and the RDF graph file.
        It builds the in-memory representation needed for exploration.

        Args:
            schema_fpath: Path to the JSON schema file.
            rdf_fpath: Path to the RDF graph file (e.g., .nt, .ttl, .rdf).
            processed_fpath: Path to a .pkl file for caching the processed data.
            use_cache: If True, try to load from processed_fpath if it exists.
        """

        if use_cache and processed_fpath and os.path.exists(processed_fpath):
            logger.info(f"Loading cached processed data from {processed_fpath}")
            with open(processed_fpath, "rb") as f:
                processed = pickle.load(f)
                self.schema = processed["schema"]
                self.schema_dr = processed["schema_dr"]
                self.classes = processed["classes"]
                self.out_relations_cls = processed["out_relations_cls"]
                self.in_relations_cls = processed["in_relations_cls"]
                self.cls_2_entid = processed["cls_2_entid"]
                self.entid_2_cls_ent = processed["entid_2_cls_ent"]
                self.literals_by_cls_rel = processed["literals_by_cls_rel"]
            return

        logger.info(f"Processing schema from {schema_fpath}")

        # 1. Load Schema
        with open(schema_fpath, "r") as f:
            self.schema = json.load(f)

        if not self.schema:
            raise ValueError("Schema could not be loaded or is empty.")

        self.classes = set(self.schema.get("classes", {}).keys())

        for rel, rel_obj in self.schema.get("relations", {}).items():
            domain = rel_obj["domain"]
            range_ = rel_obj["range"]

            self.schema_dr[rel] = (domain, range_)
            self.out_relations_cls[domain].add(rel)
            self.in_relations_cls[range_].add(rel)

        logger.info(f"Loading RDF graph from {rdf_fpath}...")

        # 2. Load RDF Graph
        g = Graph()
        try:
            g.parse(rdf_fpath)
        except Exception as e:
            logger.error(f"Failed to parse RDF file {rdf_fpath}: {e}")
            raise

        logger.info(f"Graph loaded with {len(g)} triples. Indexing entities...")

        # 3. Build in-memory indexes from the graph

        # Get RDFS.label, fall back to a common alt
        label_prop = RDFS.label

        # Index Entities and their labels
        for cls in tqdm(self.classes, desc="Indexing entities by class"):
            if cls.startswith("type."):  # Skip literal types
                continue

            try:
                cls_uri = URIRef(cls)
                for ent_uri in g.subjects(RDF.type, cls_uri):
                    if not isinstance(ent_uri, URIRef):
                        continue  # Skip blank nodes

                    ent_id_str = str(ent_uri)
                    self.cls_2_entid[cls].add(ent_id_str)

                    # Get label
                    label_lit = g.value(ent_uri, label_prop)
                    label_str = (
                        str(label_lit) if label_lit else ent_id_str.split("/")[-1]
                    )

                    self.entid_2_cls_ent[ent_id_str] = {"class": cls, "name": label_str}
            except Exception as e:
                logger.warning(f"Error indexing class {cls}: {e}")

        # Index Literals by (domain_class, relation)
        for rel, (domain, range_) in tqdm(
            self.schema_dr.items(), desc="Indexing literals"
        ):
            if not range_.startswith("type."):  # Skip non-literal ranges
                continue

            try:
                domain_uri = URIRef(domain)
                rel_uri = URIRef(rel)

                # Find all subjects of the domain type
                for s_uri in g.subjects(RDF.type, domain_uri):
                    # For each subject, get the literal objects for this relation
                    for o_lit in g.objects(s_uri, rel_uri):
                        if isinstance(o_lit, Literal):
                            self.literals_by_cls_rel[(domain, rel)].add(str(o_lit))
            except Exception as e:
                logger.warning(f"Error indexing literals for relation {rel}: {e}")

        logger.info("Finished processing graph and schema.")

        # 4. Save to cache if path provided
        if use_cache and processed_fpath:
            logger.info(f"Saving processed data to cache at {processed_fpath}")
            try:
                with open(processed_fpath, "wb") as f:
                    pickle.dump(
                        {
                            "schema": self.schema,
                            "schema_dr": self.schema_dr,
                            "classes": self.classes,
                            "out_relations_cls": self.out_relations_cls,
                            "in_relations_cls": self.in_relations_cls,
                            "cls_2_entid": self.cls_2_entid,
                            "entid_2_cls_ent": self.entid_2_cls_ent,
                            "literals_by_cls_rel": self.literals_by_cls_rel,
                        },
                        f,
                        protocol=pickle.HIGHEST_PROTOCOL,
                    )
            except Exception as e:
                logger.warning(f"Failed to write to cache file {processed_fpath}: {e}")

    def explore(
        self,
        n_walks: int,
        edge_lengths: List[int],
        max_retries_per_iter: int = 5,
        always_ground_classes: bool = False,
        always_ground_literals: bool = False,
        sexpr_type_constraint: bool = True,
        n_per_pattern: int = 1,
        use_functions: Optional[List[str]] = None,
        max_skip: Optional[int] = None,
        verbose: bool = False,
        max_retries: Optional[int] = None,
        out_dir: Optional[str] = None,
        run_id: str = str(int(time.time())),
    ) -> List[Dict[str, Any]]:
        """
        Generates a set of random graph queries (walks) based on the loaded schema.
        This version does NOT execute the queries.
        """
        edge_dup, walk_dup = 0, 0
        iter = 0
        walked_sexpr: DefaultDict[str, int] = defaultdict(int)
        walked_sexpr_all: Set[str] = set()
        walks: List[Dict[str, Any]] = []
        fn_counts: DefaultDict[str, int] = defaultdict(int)
        edge_counts: DefaultDict[int, int] = defaultdict(int)
        node_counts: DefaultDict[int, int] = defaultdict(int)
        retained_iter: List[int] = []

        # Create output directory
        res_dir_fpath: Optional[str] = None
        if out_dir is not None:
            res_dir_parts = ["walks", self.kg_name, "-".join(map(str, edge_lengths))]
            if use_functions:
                res_dir_parts.append("-".join(use_functions))
            res_dir_parts.append(run_id)

            res_dir_fpath = os.path.join(out_dir, "_".join(res_dir_parts))
            os.makedirs(res_dir_fpath, exist_ok=True)
            logger.info(f"Saving results to {res_dir_fpath}")

        pbar = tqdm(total=n_walks, desc="Exploring")
        start_time = time.time()
        retries = 0
        retries_n_walks = 0
        retries_per_program: List[int] = []
        reached_retry_limit = False

        while len(walks) < n_walks and (max_skip is None or iter < n_walks + max_skip):
            # Retry loop count-out
            if len(walks) == retries_n_walks:
                retries += 1
                if max_retries is not None and retries > max_retries:
                    reached_retry_limit = True
                    break
            else:
                retries_per_program.append(retries)
                retries = 0
                retries_n_walks = len(walks)

            if verbose:
                logger.info(f"--- Iter {iter} ---")

            iter += 1
            sampled_n_edges = random.choice(edge_lengths)

            res = self.generate_graph_query(
                n_edges=sampled_n_edges,
                max_retries=max_retries_per_iter,
                always_ground_literals=always_ground_literals,
                always_ground_classes=always_ground_classes,
                use_functions=use_functions,
                verbose=verbose,
            )
            if res is None:
                edge_dup += 1
                if verbose:
                    logger.info(
                        f"Skipping (generate_graph_query failed) (count={edge_dup})"
                    )
                continue

            gq, gq_fn, gq_n_groundings = res

            sexpr_anon_noid = graph_query_to_sexpr(
                gq,
                type_constraint=sexpr_type_constraint,
                readable=True,
                readable_type="anon_noid",
            )
            sexpr_machine = graph_query_to_sexpr(
                gq, type_constraint=sexpr_type_constraint
            )

            if (
                walked_sexpr[sexpr_anon_noid] >= n_per_pattern
                or sexpr_machine in walked_sexpr_all
            ):
                walk_dup += 1
                if verbose:
                    logger.info(f"Already walked. Skipping (count={walk_dup})")
                continue

            # This is where SPARQL execution (filter_empty, prune_redundant, save_answers)
            # was. It has been removed.

            walked_sexpr[sexpr_anon_noid] += 1
            walked_sexpr_all.add(sexpr_machine)

            # Generate all representations
            sexpr_anon = graph_query_to_sexpr(
                gq,
                type_constraint=sexpr_type_constraint,
                readable=True,
                readable_type="anon",
            )
            sexpr_anon_rev = graph_query_to_sexpr(
                gq,
                type_constraint=sexpr_type_constraint,
                readable=True,
                readable_type="anon",
                use_reverse_relations=True,
            )
            sexpr_machine_rev = graph_query_to_sexpr(
                gq,
                type_constraint=sexpr_type_constraint,
                readable=False,
                use_reverse_relations=True,
            )
            sexpr_label = graph_query_to_sexpr(
                gq,
                type_constraint=sexpr_type_constraint,
                readable=True,
                readable_type="label",
            )
            sexpr_label_rev = graph_query_to_sexpr(
                gq,
                type_constraint=sexpr_type_constraint,
                readable=True,
                readable_type="label",
                use_reverse_relations=True,
            )
            sparql_query = graph_query_to_sparql(gq)  # Removed header addition

            walks.append(
                {
                    "qid": len(walks),
                    "function": gq_fn,
                    "num_node": len(gq["nodes"]),
                    "num_edge": len(gq["edges"]),
                    "graph_query": gq,
                    "s_expression_anon": sexpr_anon,
                    "s_expression_anon-rev": sexpr_anon_rev,
                    "s_expression_label": sexpr_label,
                    "s_expression_label-rev": sexpr_label_rev,
                    "s_expression_machine": sexpr_machine,
                    "s_expression_machine-rev": sexpr_machine_rev,
                    "sparql_query": sparql_query,
                    # "answer" field is removed
                }
            )
            retained_iter.append(iter)

            fn_counts[gq_fn] += 1
            edge_counts[len(gq["edges"])] += 1
            node_counts[len(gq["nodes"])] += 1

            if verbose:
                logger.info(json.dumps(walks[-1], indent=2))

            pbar.update(1)

        pbar.close()

        if max_retries is not None and reached_retry_limit:
            logger.info(f"Stopping exploration: reached retry limit ({max_retries})")
        elif n_walks is not None and max_skip is not None and len(walks) < n_walks:
            logger.info(
                f"Stopping exploration: reached maximum attempts ({n_walks + max_skip})"
            )

        end_time = time.time()
        seen_patterns = {k: v for k, v in dict(walked_sexpr).items() if v > 0}

        stats = {
            "n_iters": iter,
            "n_programs": len(walks),
            "time_taken": end_time - start_time,
            "retries": {
                "avg": sum(retries_per_program) / len(retries_per_program)
                if retries_per_program
                else 0,
                "max": max(retries_per_program) if retries_per_program else 0,
            },
            "patterns": {
                "total": len(seen_patterns),
                "min_programs": min(seen_patterns.values()) if seen_patterns else 0,
                "max_programs": max(seen_patterns.values()) if seen_patterns else 0,
                "avg_programs": sum(seen_patterns.values()) / len(seen_patterns)
                if seen_patterns
                else 0,
            },
            "n_skipped_seen_node-rel_pair": edge_dup,
            "n_skipped_seen_pattern": walk_dup,
            "n_skipped_empty_ans": 0,  # This is 0 because we removed the check
            "n_programs_per_fn": dict(fn_counts),
            "n_programs_per_node_count": dict(node_counts),
            "n_programs_per_edge_count": dict(edge_counts),
            "n_programs_per_pattern": seen_patterns,
            "retain_iter_count": retained_iter,
        }

        if res_dir_fpath:
            out_fpath = os.path.join(res_dir_fpath, "results.json")
            with open(out_fpath, "w") as fh:
                json.dump(walks, fh, indent=2)
            logger.info(f"Saved walks to {out_fpath}")

            stats_fpath = os.path.join(res_dir_fpath, "stats.json")
            with open(stats_fpath, "w") as fh:
                json.dump(stats, fh, indent=2)
            logger.info(f"Saved exploration statistics to {stats_fpath}")

        return walks

    def generate_graph_query(
        self,
        n_edges: int,
        max_retries: int = 3,
        always_ground_literals: bool = True,
        always_ground_classes: bool = False,
        verbose: bool = False,
        use_functions: Optional[List[str]] = None,
        ground_attempts_max: int = 5,
    ) -> Optional[Tuple[Dict[str, List], str, int]]:
        """
        Generates a single random graph query.
        This function is largely unchanged, as it depends on the in-memory
        schema representation, which is now populated by load_graph_and_schema.
        """
        if not self.schema or not self.classes:
            logger.error("Schema not loaded. Call load_graph_and_schema() first.")
            return None

        graph: Dict[str, List] = {"nodes": [], "edges": []}
        class_2_nid: DefaultDict[str, Set[int]] = defaultdict(set)
        sampled_node_rel: Set[Tuple[int, str]] = set()
        nodes_2_ground: Set[int] = set()
        ungrounded_terminal_node: Set[int] = set()
        n_nodes_grounded = 0

        # Sample an initial question node from all non-literal classes
        legal_classes = [c for c in self.classes if legal_class(c)]
        if not legal_classes:
            logger.error("No legal (non-literal) classes found in schema.")
            return None

        q_class = random.choice(legal_classes)

        q_node: Dict[str, Any] = {
            "nid": 0,
            "node_type": "class",
            "id": q_class,
            "class": q_class,
            "readable_name": get_readable_class(q_class, schema=self.schema),
            "question_node": 1,
            "function": "none",
        }
        graph["nodes"].append(q_node)
        class_2_nid[q_node["class"]].add(q_node["nid"])

        n_attempts = 0
        n_legal_relation_attempts = 0
        while (
            len(graph["edges"]) < n_edges
            and n_attempts < max_retries
            and n_legal_relation_attempts < max_retries
        ):
            next_edge: Dict[str, Any] = {}
            # Sample a node from the set of ungrounded and non-literal nodes
            non_literal_nodes = get_non_literals(
                graph["nodes"], except_nid=nodes_2_ground
            )
            if not non_literal_nodes:
                if verbose:
                    logger.info(
                        "No more non-literal nodes to expand. Stopping walk early."
                    )
                break  # Can't expand anymore

            node_2_expand = random.choice(non_literal_nodes)

            # Sample adjacent node relation
            possible_rels = list(self.out_relations_cls[node_2_expand["class"]]) + [
                f"{r}#R" for r in self.in_relations_cls[node_2_expand["class"]]
            ]

            if not possible_rels:
                n_attempts += 1
                if verbose:
                    logger.info(
                        f"Node {node_2_expand['nid']} has no relations. Retrying ({n_attempts})"
                    )
                continue

            next_rel = random.choice(possible_rels)

            if not legal_relation(next_rel):
                n_legal_relation_attempts += 1
                if verbose:
                    logger.info(
                        f"Illegal relation sampled. Retrying ({n_legal_relation_attempts})"
                    )
                continue
            n_legal_relation_attempts = 0

            # Re-try if the sampled (node, relation) pair has been sampled before
            if (node_2_expand["nid"], next_rel) in sampled_node_rel:
                n_attempts += 1
                if verbose:
                    logger.info(f"Already seen. Retrying ({n_attempts})")
                continue
            sampled_node_rel.add((node_2_expand["nid"], next_rel))
            sampled_node_rel.add((node_2_expand["nid"], get_inv_rel(next_rel)))
            n_attempts = 0

            ungrounded_terminal_node.discard(node_2_expand["nid"])

            if is_inv_rel(next_rel):
                rel_name = get_inv_rel(next_rel)
                rel_domain, rel_range = (
                    self.schema["relations"][rel_name]["range"],
                    self.schema["relations"][rel_name]["domain"],
                )
            else:
                rel_name = next_rel
                rel_domain, rel_range = (
                    self.schema["relations"][rel_name]["domain"],
                    self.schema["relations"][rel_name]["range"],
                )

            # Select next node
            cand_nodes_from_existing = get_nodes_by_class(
                graph["nodes"], cls=rel_range, except_nid=[node_2_expand["nid"]]
            )
            new_node: Dict[str, Any] = {
                "nid": len(graph["nodes"]),
                "node_type": "class",
                "readable_name": get_readable_class(rel_range, schema=self.schema),
                "question_node": 0,
                "function": "none",
                "id": rel_range,
                "class": rel_range,
            }

            if not cand_nodes_from_existing:
                add_new_node = True
                next_node = new_node
            else:
                if add_new_node := random.choice([True, False]):
                    next_node = new_node
                else:
                    next_node = random.choice(cand_nodes_from_existing)

            if is_inv_rel(next_rel):
                next_edge.update(
                    {
                        "start": next_node["nid"],
                        "end": node_2_expand["nid"],
                        "relation": rel_name,
                        "readable_name": get_readable_relation(
                            rel_name, schema=self.schema
                        ),
                        "reverse_relation": get_reverse_relation(
                            rel_name, schema=self.schema
                        ),
                        "reverse_readable_name": get_reverse_readable_relation(
                            rel_name, schema=self.schema
                        ),
                    }
                )
            else:
                next_edge.update(
                    {
                        "start": node_2_expand["nid"],
                        "end": next_node["nid"],
                        "relation": rel_name,
                        "readable_name": get_readable_relation(
                            rel_name, schema=self.schema
                        ),
                        "reverse_relation": get_reverse_relation(
                            rel_name, schema=self.schema
                        ),
                        "reverse_readable_name": get_reverse_readable_relation(
                            rel_name, schema=self.schema
                        ),
                    }
                )

            if add_new_node:
                if next_node["class"].startswith("type."):  # Check for literal
                    if always_ground_literals:
                        next_node["node_type"] = "literal"
                        grounding_cands = self.literals_by_cls_rel[
                            (node_2_expand["class"], next_edge["relation"])
                        ]
                        if not grounding_cands:
                            if verbose:
                                logger.info(
                                    f"No grounding values found for class `{node_2_expand['class']}` + relation `{next_edge['relation']}`. Skipping."
                                )
                            return None
                        next_node["readable_name"] = random.choice(
                            list(grounding_cands)
                        )
                        next_node["id"] = (
                            f'"{next_node["readable_name"]}"^^{literal_map[next_node["class"]]}'
                        )
                        n_nodes_grounded += 1
                    else:
                        next_node["grounding_helper"] = {
                            "cls": node_2_expand["class"],
                            "rel": next_edge["relation"],
                        }
                        if random.choice([True, False]):
                            nodes_2_ground.add(next_node["nid"])
                else:  # non-literal class
                    if always_ground_classes or random.choice([True, False]):
                        if len(self.cls_2_entid[next_node["class"]]) > 0:
                            nodes_2_ground.add(next_node["nid"])

                graph["nodes"].append(next_node)
                class_2_nid[next_node["class"]].add(next_node["nid"])
                if (
                    next_node["node_type"] == "class"
                    and next_node["nid"] not in nodes_2_ground
                ):
                    ungrounded_terminal_node.add(next_node["nid"])

            graph["edges"].append(next_edge)
            sampled_node_rel.add((next_node["nid"], next_rel))
            sampled_node_rel.add((next_node["nid"], get_inv_rel(next_rel)))

        if len(graph["edges"]) == 0 and n_edges > 0:
            if verbose:
                logger.info("Could not sample any edges.")
            return None

        if len(graph["edges"]) < n_edges and n_attempts >= max_retries:
            if verbose:
                logger.info(
                    "Reached maximum attempts in trying to sample new node-edge pairs"
                )
            return None

        # Mark ungrounded terminal nodes for grounding
        nodes_2_ground.update(ungrounded_terminal_node)

        # Add a function to the query
        function_cands = ["none"]
        if not q_class.startswith("type."):
            function_cands.append("count")

        numerical_nodes_not_q = class_2_nid["type.integer"].union(
            class_2_nid["type.float"]
        ).union(class_2_nid["type.datetime"]) - {q_node["nid"]}

        if len(numerical_nodes_not_q) > 0:
            function_cands.extend(["argmax", "argmin", ">", "<", ">=", "<="])
        else:
            function_cands += ["none"] * 2  # reduce likelihood of 'count'

        sampled_fn = random.choice(function_cands)

        if use_functions is not None and sampled_fn not in use_functions:
            return None

        if sampled_fn == "count":
            graph["nodes"][0]["function"] = "count"
        elif sampled_fn in ["argmin", "argmax"]:
            sampled_node_id = random.choice(list(numerical_nodes_not_q))
            graph["nodes"][sampled_node_id].update(
                {
                    "node_type": "literal",
                    "id": '"0"^^http://www.w3.org/2001/XMLSchema#int',
                    "readable_name": "0",
                    "function": sampled_fn,
                }
            )
            nodes_2_ground -= {sampled_node_id}
            n_nodes_grounded += 1
        elif sampled_fn in [">", "<", ">=", "<="]:
            sampled_node_id = random.choice(list(numerical_nodes_not_q))
            graph["nodes"][sampled_node_id]["function"] = sampled_fn
            if graph["nodes"][sampled_node_id]["node_type"] == "class":
                nodes_2_ground.add(sampled_node_id)

        # Try to ground something if nothing is grounded so far
        if n_nodes_grounded == 0:
            ground_attempts = 0
            while len(nodes_2_ground) == 0 and ground_attempts < ground_attempts_max:
                sampled_node = random.choice(graph["nodes"][1:])  # Don't ground q_node
                if sampled_node["node_type"] == "class":
                    nodes_2_ground.add(sampled_node["nid"])
                ground_attempts += 1

        if (
            n_nodes_grounded == 0
            and len(nodes_2_ground) == 0
            and len(graph["nodes"]) > 1
        ):
            # As a last resort, if we still have no groundings,
            # try to ground *any* non-question node
            non_q_nodes = [
                n for n in graph["nodes"] if n["nid"] != 0 and n["node_type"] == "class"
            ]
            if non_q_nodes:
                nodes_2_ground.add(random.choice(non_q_nodes)["nid"])

        # Ground nodes
        grounded_ents: Set[str] = set()
        for nid in nodes_2_ground:
            node = graph["nodes"][nid]
            node_cls = node["class"]

            if node_cls.startswith("type."):  # literal class
                g_helper = node.get("grounding_helper")
                if not g_helper:
                    if verbose:
                        logger.info(
                            f"Node {nid} marked for grounding but has no grounding_helper. Skipping."
                        )
                    continue

                grounding_cands = list(
                    self.literals_by_cls_rel[(g_helper["cls"], g_helper["rel"])]
                )
                if not grounding_cands:
                    if verbose:
                        logger.warning(
                            f"No grounding values found for `{g_helper['cls']}->{g_helper['rel']}`. Cannot ground."
                        )
                    return None  # This walk is invalid

                grounded_lit = random.choice(grounding_cands)
                node.update(
                    {
                        "node_type": "literal",
                        "id": f'"{grounded_lit}"^^{literal_map[node_cls]}',
                        "readable_name": grounded_lit,
                    }
                )
            else:  # entity class
                if not self.cls_2_entid[node_cls]:
                    if verbose:
                        logger.info(
                            f"No entities found for class {node_cls}. Skipping."
                        )
                    return None

                grounding_cands = list(self.cls_2_entid[node_cls] - grounded_ents)

                if not grounding_cands:
                    # All entities for this class are already used. Use one again.
                    grounding_cands = list(self.cls_2_entid[node_cls])
                    if not grounding_cands:
                        if verbose:
                            logger.info(
                                f"No grounding values found for class `{node_cls}`. Skipping"
                            )
                        return None

                grounded_ent = random.choice(grounding_cands)
                grounded_ents.add(grounded_ent)
                node.update(
                    {
                        "node_type": "entity",
                        "id": grounded_ent,
                        "readable_name": self.entid_2_cls_ent[grounded_ent]["name"],
                    }
                )
            n_nodes_grounded += 1

        if n_nodes_grounded == 0 and len(graph["nodes"]) > 1:
            if verbose:
                logger.info("Failed to ground any node. Skipping.")
            return None

        return graph, sampled_fn, n_nodes_grounded

In [8]:
# Instantiate the Explorer under the knowledge-graph name "workflow".
# This cell produced no output in the recorded run; the graph and schema are
# loaded by a separate load_graph_and_schema call.
workflow_explorer = Explorer(kg_name="workflow")

In [9]:
# Load the workflow schema (JSON) and the RDF data (Turtle) into the explorer.
# In the recorded run the log shows the schema and the .ttl file being processed
# from scratch (1053 triples) and the result then being saved to processed_fpath,
# i.e. this run wrote the pickle cache rather than reading it.
# NOTE(review): use_cache=True presumably reuses processed_fpath when that file
# already exists — confirm in load_graph_and_schema.
# All three paths are relative, so the cell depends on the working directory.
workflow_explorer.load_graph_and_schema(
    schema_fpath="data/graphs/workflow/schema.json",
    rdf_fpath="data/graphs/workflow/chatbs_sample.ttl",
    processed_fpath="data/graphs/workflow/processed_workflow.pkl",
    use_cache=True,
)

2025-11-05 14:09:19,653 - INFO - Processing schema from data/graphs/workflow/schema.json
2025-11-05 14:09:19,654 - INFO - Loading RDF graph from data/graphs/workflow/chatbs_sample.ttl...
2025-11-05 14:09:19,673 - INFO - Graph loaded with 1053 triples. Indexing entities...
Indexing entities by class: 100%|██████████| 15/15 [00:00<00:00, 52254.62it/s]
Indexing literals: 100%|██████████| 23/23 [00:00<00:00, 249273.88it/s]
2025-11-05 14:09:20,386 - INFO - Finished processing graph and schema.
2025-11-05 14:09:20,386 - INFO - Saving processed data to cache at data/graphs/workflow/processed_workflow.pkl


In [10]:
# Inspect schema_dr. In the recorded output it is a dict keyed by relation name
# whose values are 2-tuples of class names, e.g.
#   'provone.controls': ('provone.Controller', 'provone.Program')
# (presumably the relation's domain and range — confirm where schema_dr is built).
workflow_explorer.schema_dr

{'provone.hasSubProgram': ('provone.Program', 'provone.Program'),
 'provone.controls': ('provone.Controller', 'provone.Program'),
 'provone.hasInPort': ('provone.Program', 'provone.Port'),
 'provone.hasOutPort': ('provone.Program', 'provone.Port'),
 'provone.hasDefaultParam': ('provone.Port', 'prov.Entity'),
 'provone.connectsTo': ('provone.Port', 'provone.Channel'),
 'provone.programWasDerivedFrom': ('provone.Program', 'provone.Program'),
 'prov.used': ('provone.Execution', 'prov.Entity'),
 'prov.wasGeneratedBy': ('prov.Entity', 'provone.Execution'),
 'prov.wasAssociatedWith': ('provone.Execution', 'provone.User'),
 'prov.wasInformedBy': ('provone.Execution', 'provone.Execution'),
 'provone.wasPartOf': ('provone.Execution', 'provone.Execution'),
 'prov.qualifiedAssociation': ('provone.Execution', 'prov.Association'),
 'prov.agent': ('prov.Association', 'provone.User'),
 'prov.hadPlan': ('prov.Association', 'provone.Program'),
 'prov.qualifiedUsage': ('provone.Execution', 'prov.Usage')

In [11]:
# Inspect the loaded schema. The recorded (truncated) output shows a top-level
# 'classes' key mapping each class name to a dict with a 'description' string;
# the utility functions at the top of the notebook additionally read
# schema["relations"][rel] with optional "description" and "reverse" entries.
workflow_explorer.schema

{'classes': {'provone.Program': {'description': 'A computational task that consumes and produces data. Can be atomic or composite.'},
  'provone.Workflow': {'description': 'A distinguished Program, representing a computational experiment in its entirety. Subclass of Program.'},
  'provone.Port': {'description': 'Enables a Program to send or receive Entity items (Data, Visualization, or Document instances).'},
  'provone.Channel': {'description': 'Provides a connection between Ports that are defined for Programs.'},
  'provone.Controller': {'description': 'Specifies a Program that controls other Programs under a particular model of computation.'},
  'provone.Execution': {'description': 'Represents the execution of a Program. If the Program is a Workflow, this is the trace.'},
  'provone.User': {'description': 'The person(s) responsible for the execution of an Execution. Subclass of prov:Agent.'},
  'prov.Entity': {'description': 'A physical, digital, conceptual, or other kind of thing. 

In [12]:
# Inspect out_relations_cls: a defaultdict(set) that, in the recorded output,
# maps a class name to a set of relation names, e.g.
#   'provone.Controller': {'provone.controls'}
# (presumably the relations leaving that class — confirm where it is populated).
workflow_explorer.out_relations_cls

defaultdict(set,
            {'provone.Program': {'provone.hasInPort',
              'provone.hasOutPort',
              'provone.hasSubProgram',
              'provone.programWasDerivedFrom'},
             'provone.Controller': {'provone.controls'},
             'provone.Port': {'provone.connectsTo', 'provone.hasDefaultParam'},
             'provone.Execution': {'prov.qualifiedAssociation',
              'prov.qualifiedGeneration',
              'prov.qualifiedUsage',
              'prov.used',
              'prov.wasAssociatedWith',
              'prov.wasInformedBy',
              'provone.wasPartOf'},
             'prov.Entity': {'prov.wasGeneratedBy',
              'provone.dataWasDerivedFrom'},
             'prov.Association': {'prov.agent', 'prov.hadPlan'},
             'prov.Usage': {'provone.usageHadEntity',
              'provone.usageHadInPort'},
             'prov.Generation': {'provone.generationHadEntity',
              'provone.generationHadOutPort'},
             'prov.

In [13]:
# Sample walks over the workflow graph and display the value returned by explore().
#
# The grounding code visible in this notebook picks nodes, literals and entities
# with random.choice, i.e. from the module-level `random` generator. Seed it here
# so that Restart Kernel -> Run All reproduces the same walks; change EXPLORE_SEED
# to draw a different sample.
# NOTE(review): if explore() re-seeds or uses another RNG internally, this seed
# has no effect on that part — confirm in Explorer.explore.
EXPLORE_SEED = 42
random.seed(EXPLORE_SEED)

# In the recorded log the results directory was
# "exploration_results/walks_workflow_2-3_test_run_001", which appears to be built
# from out_dir, the kg name, edge_lengths and run_id.
workflow_explorer.explore(
    n_walks=10,
    edge_lengths=[2, 3],
    max_retries_per_iter=5,
    always_ground_classes=False,
    always_ground_literals=True,
    sexpr_type_constraint=True,
    n_per_pattern=1,
    use_functions=None,
    max_skip=50,
    verbose=True,
    max_retries=100,
    out_dir="exploration_results",
    run_id="test_run_001",
)

2025-11-05 14:13:56,013 - INFO - Saving results to exploration_results/walks_workflow_2-3_test_run_001
Exploring:   0%|          | 0/10 [00:00<?, ?it/s]2025-11-05 14:13:56,015 - INFO - --- Iter 0 ---
2025-11-05 14:13:56,016 - INFO - Node 0 has no relations. Retrying (1)
2025-11-05 14:13:56,016 - INFO - Node 0 has no relations. Retrying (2)
2025-11-05 14:13:56,017 - INFO - Node 0 has no relations. Retrying (3)
2025-11-05 14:13:56,017 - INFO - Node 0 has no relations. Retrying (4)
2025-11-05 14:13:56,017 - INFO - Node 0 has no relations. Retrying (5)
2025-11-05 14:13:56,018 - INFO - Could not sample any edges.
2025-11-05 14:13:56,018 - INFO - Skipping (generate_graph_query failed) (count=1)
2025-11-05 14:13:56,018 - INFO - --- Iter 1 ---
2025-11-05 14:13:56,019 - INFO - Already seen. Retrying (1)
2025-11-05 14:13:56,019 - INFO - No entities found for class provone.Port. Skipping.
2025-11-05 14:13:56,020 - INFO - Skipping (generate_graph_query failed) (count=2)
2025-11-05 14:13:56,020 - I

[]