In [63]:
from graphviz import Digraph
from IPython.display import display, display_svg, clear_output
import ipywidgets as widgets
import pylab
import re

Full tree

# Bottom-up search

In [50]:
# Graph for the one-round bottom-up enumeration, laid out by dot with
# bottom-to-top ranks (rankdir="BT").
dot = Digraph(format='png', engine='dot')
dot.graph_attr.update({
    "rankdir": "BT",
    "size": "300,300!",
    "margin": "0.5",
    "overlap": "false",
})

# Running count of node ids handed out so far; reset whenever this cell is re-run.
node_counter = 0
def new_node_id():
    """Return the next unused node id ("n0", "n1", ...) and advance the global counter."""
    global node_counter
    issued = "n{}".format(node_counter)
    node_counter = node_counter + 1
    return issued

# Terminal expressions that seed the enumeration.
base_exprs = ["p", "q"]

# nodes: expression string -> Graphviz node id.
# derivations: node id -> record of how that expression was produced.
nodes = {}
derivations = {}

# Level 0 holds the terminals; each gets its own graph node and a rule-less derivation.
base_level = []
for terminal in base_exprs:
    terminal_id = new_node_id()
    dot.node(terminal_id, terminal)
    nodes[terminal] = terminal_id
    derivations[terminal_id] = dict(expr=terminal, rule=None, children=[])
    base_level.append(terminal)

# all_levels[k] lists the expressions first produced in round k.
all_levels = [base_level]

# Production functions: each returns (new expression, label of the rule applied).
def apply_unary(expr):
    """Apply the unary production S -> ¬S to `expr`."""
    return "¬{}".format(expr), "¬S"

def apply_binary(expr1, expr2, op):
    """Apply the binary production S -> (S op S) to the operands `expr1` and `expr2`."""
    combined = "({} {} {})".format(expr1, op, expr2)
    label = "(S {} S)".format(op)
    return combined, label

# Build larger expressions bottom-up: every round combines all expressions generated
# so far into new candidates and adds the ones not seen before as a new level.
max_depth = 1  # Number of production rounds applied on top of the base level.
for d in range(1, max_depth + 1):
    # Candidates of this round: expression -> (rule label, tuple of operand expressions).
    # Only the first derivation found for a given expression string is kept.
    new_level_set = {}
    
    # Unary: apply to every expression from all previous levels.
    for level in all_levels:
        for expr in level:
            new_expr, rule = apply_unary(expr)
            if new_expr not in new_level_set:
                new_level_set[new_expr] = (rule, (expr,))
                
    # Binary: combine every ordered pair of expressions from the available levels.
    available = [expr for level in all_levels for expr in level]
    for expr1 in available:
        for expr2 in available:
            for op in ["∧", "→"]:
                new_expr, rule = apply_binary(expr1, expr2, op)
                if new_expr not in new_level_set:
                    new_level_set[new_expr] = (rule, (expr1, expr2))
    
    # Add new level nodes.
    new_level = list(new_level_set.items())
    current_level_exprs = []
    for expr, (rule, children_exprs) in new_level:
        if expr in nodes:
            # Already drawn in an earlier round; keep its original node and derivation.
            continue
        nid = new_node_id()
        nodes[expr] = nid
        derivations[nid] = {"expr": expr, "rule": rule, "children": [nodes[child] for child in children_exprs]}
        current_level_exprs.append(expr)
        dot.node(nid, expr)
        # One edge per operand, pointing from the operand to the expression built from it.
        for child_expr in children_exprs:
            child_nid = nodes[child_expr]
            dot.edge(child_nid, nid, label=rule)
    all_levels.append(current_level_exprs)

# `filename` is the path of the DOT *source*; graphviz appends ".png" for the rendered
# image. Passing a name that already ends in ".png" would write DOT text to
# imgs/bottom_up_d2.png and the picture to imgs/bottom_up_d2.png.png.
# cleanup=True removes the intermediate DOT source after rendering.
filename = dot.render(filename='imgs/bottom_up_d2', cleanup=True)

In [51]:
# Graph for the two-round bottom-up enumeration, laid out with neato.
# NOTE(review): Graphviz documents rankdir as a dot-only attribute, so with neato the
# "BT" setting presumably has no effect — confirm the intended engine.
dot = Digraph(format='png', engine='neato')
dot.graph_attr["rankdir"] = "BT"
dot.graph_attr["overlap"] = "false"

# Number of node ids issued so far in this cell.
node_counter = 0
def new_node_id():
    """Hand out sequential node ids "n0", "n1", ... backed by the global counter."""
    global node_counter
    node_counter += 1
    return f"n{node_counter - 1}"

# Terminal expressions that seed the enumeration.
base_exprs = ["p", "q"]

# nodes maps an expression string to its node id; derivations maps a node id to
# the record {"expr", "rule", "children"} describing how it was built.
nodes, derivations = {}, {}

# Register every terminal as a level-0 node with no producing rule.
base_level = []
for atom in base_exprs:
    atom_id = new_node_id()
    nodes[atom] = atom_id
    derivations[atom_id] = dict(expr=atom, rule=None, children=[])
    dot.node(atom_id, atom)
    base_level.append(atom)

# One list of expressions per generation round, starting with the terminals.
all_levels = [base_level]

# Production functions: build the new expression string and name the rule used.
def apply_unary(expr):
    """Negate `expr` (rule S -> ¬S); returns (expression, rule label)."""
    negated = "¬{}".format(expr)
    return negated, "¬S"

def apply_binary(expr1, expr2, op):
    """Join `expr1` and `expr2` with `op` (rule S -> (S op S)); returns (expression, rule label)."""
    return "({} {} {})".format(expr1, op, expr2), "(S {} S)".format(op)

# Build larger expressions bottom-up: every round combines all expressions generated
# so far into new candidates and adds the ones not seen before as a new level.
max_depth = 2  # Number of production rounds applied on top of the base level.
for d in range(1, max_depth + 1):
    # Candidates of this round: expression -> (rule label, tuple of operand expressions).
    # Only the first derivation found for a given expression string is kept.
    new_level_set = {}
    
    # Unary: apply to every expression from all previous levels.
    for level in all_levels:
        for expr in level:
            new_expr, rule = apply_unary(expr)
            if new_expr not in new_level_set:
                new_level_set[new_expr] = (rule, (expr,))
                
    # Binary: combine every ordered pair of expressions from the available levels.
    available = [expr for level in all_levels for expr in level]
    for expr1 in available:
        for expr2 in available:
            for op in ["∧", "→"]:
                new_expr, rule = apply_binary(expr1, expr2, op)
                if new_expr not in new_level_set:
                    new_level_set[new_expr] = (rule, (expr1, expr2))
    
    # Add new level nodes.
    new_level = list(new_level_set.items())
    current_level_exprs = []
    for expr, (rule, children_exprs) in new_level:
        if expr in nodes:
            # Regenerated from an earlier round; keep its original node and derivation.
            continue
        nid = new_node_id()
        nodes[expr] = nid
        derivations[nid] = {"expr": expr, "rule": rule, "children": [nodes[child] for child in children_exprs]}
        current_level_exprs.append(expr)
        dot.node(nid, expr)
        # One edge per operand, pointing from the operand to the expression built from it.
        for child_expr in children_exprs:
            child_nid = nodes[child_expr]
            dot.edge(child_nid, nid, label=rule)
    all_levels.append(current_level_exprs)

# NOTE(review): a bare expression is only shown inline when it is the last statement of
# a cell; if the render call below lives in the same cell this line has no visible effect.
dot

# `filename` is the path of the DOT *source*; graphviz appends ".png" for the rendered
# image. Passing a name that already ends in ".png" would write DOT text to
# imgs/bottom_up_d3.png and the picture to imgs/bottom_up_d3.png.png.
# cleanup=True removes the intermediate DOT source after rendering.
filename = dot.render(filename='imgs/bottom_up_d3', cleanup=True)

In [52]:
# Graph for the semantically pruned enumeration: dot engine, bottom-to-top ranks.
dot = Digraph(format='png', engine='dot')
dot.graph_attr.update({"rankdir": "BT", "overlap": "false"})

# Running count of node ids handed out so far; reset whenever this cell is re-run.
node_counter = 0
def new_node_id():
    """Return a fresh node id of the form "n<k>", where k counts the ids issued before it."""
    global node_counter
    fresh, node_counter = f"n{node_counter}", node_counter + 1
    return fresh

# --- Semantic evaluation functions ---
def parse_formula(expr):
    """Turn a formula string into a nested syntax tree.

    Accepted grammar:
      formula -> "p" | "q"
               | "¬" formula
               | "(" formula "∧" formula ")"
               | "(" formula "→" formula ")"

    Atoms are returned as the strings "p" / "q", negation as ("not", sub) and a
    binary connective as (op, left, right). Surrounding whitespace is ignored.
    Raises ValueError when the text does not fit the grammar.
    """
    text = expr.strip()
    if text == "p" or text == "q":
        return text
    if text.startswith("¬"):
        return ("not", parse_formula(text[1:]))
    if text.startswith("(") and text.endswith(")"):
        body = text[1:-1].strip()
        nesting = 0
        for pos, ch in enumerate(body):
            if ch == '(':
                nesting += 1
            elif ch == ')':
                nesting -= 1
            elif nesting == 0 and ch in ('∧', '→'):
                # First connective outside any nested parentheses splits the formula.
                lhs = body[:pos].strip()
                rhs = body[pos + 1:].strip()
                return (ch, parse_formula(lhs), parse_formula(rhs))
    raise ValueError("Could not parse: " + text)

def evaluate_formula_tree(tree, env):
    """Compute the truth value of a parsed tree under `env`, a mapping from 'p' and 'q' to booleans."""
    if isinstance(tree, str):
        return env[tree]
    if not isinstance(tree, tuple):
        raise ValueError("Invalid tree structure")
    op = tree[0]
    if op == "not":
        return not evaluate_formula_tree(tree[1], env)
    if op == "∧":
        return evaluate_formula_tree(tree[1], env) and evaluate_formula_tree(tree[2], env)
    if op == "→":
        # Material implication: false only when the antecedent holds and the consequent does not.
        return (not evaluate_formula_tree(tree[1], env)) or evaluate_formula_tree(tree[2], env)
    raise ValueError("Unknown operator: " + op)

def truth_table(expr):
    """Evaluate `expr` under all four assignments of p and q.

    The result is a tuple ordered as
    (p=False,q=False), (False,True), (True,False), (True,True).
    """
    tree = parse_formula(expr)
    return tuple(
        evaluate_formula_tree(tree, {"p": p_val, "q": q_val})
        for p_val in (False, True)
        for q_val in (False, True)
    )

# --- Bottom-up generation with semantic pruning and display of pruned nodes ---
# Terminal expressions that seed the enumeration.
base_exprs = ["p", "q"]

# nodes: canonical expression string -> node id.
# derivations: node id -> production record.
nodes, derivations = {}, {}
# semantic_cache: truth table -> canonical expression having that table.
# pruned_cache: truth table -> first expression that was pruned for that table.
semantic_cache, pruned_cache = {}, {}

# Level 0: keep one terminal per distinct truth table.
base_level = []
for terminal in base_exprs:
    signature = truth_table(terminal)
    if signature not in semantic_cache:
        semantic_cache[signature] = terminal
        terminal_id = new_node_id()
        dot.node(terminal_id, terminal)
        nodes[terminal] = terminal_id
        derivations[terminal_id] = dict(expr=terminal, rule=None, children=[])
        base_level.append(terminal)

# One list of canonical expressions per round, starting with the terminals.
all_levels = [base_level]

# Production functions: each returns (new expression, label of the rule applied).
def apply_unary(expr):
    """Unary production S -> ¬S applied to `expr`."""
    rule = "¬S"
    return "¬{}".format(expr), rule

def apply_binary(expr1, expr2, op):
    """Binary production S -> (S op S) applied to `expr1` and `expr2`."""
    rule = "(S {} S)".format(op)
    return "({} {} {})".format(expr1, op, expr2), rule

# Generate larger expressions bottom-up, keeping one canonical expression per truth
# table and drawing (in red) the first expression pruned for each truth table.
max_depth = 3  # Number of production rounds applied on top of the base level.
for d in range(1, max_depth + 1):
    new_level_set = {}       # Maps new_expr -> (rule, children tuple) for canonical candidates.
    new_level_semantics = {} # Maps truth table (sig) -> canonical expression in the current level.
    pruned_level_set = {}    # Maps pruned expression -> (rule, children tuple).
    
    # Only use canonical expressions from previous levels.
    available = [expr for level in all_levels for expr in level]
    
    # NOTE(review): expressions from earlier rounds are regenerated every round (e.g. "¬p"
    # is produced again in round 2). Their signature is already in semantic_cache, so the
    # regenerated copy itself can become the "first pruned expression" for that signature
    # and is then drawn as a red duplicate of an existing node — confirm this is intended.

    # --- Unary productions ---
    for expr in available:
        new_expr, rule = apply_unary(expr)
        try:
            sig = truth_table(new_expr)
        except Exception:
            # Best effort: candidates that cannot be evaluated are skipped.
            continue
        # If an equivalent expression already exists globally or in the current level, record it as pruned.
        if sig in semantic_cache or sig in new_level_semantics:
            if sig not in pruned_cache:
                pruned_cache[sig] = new_expr
                pruned_level_set[new_expr] = (rule, (expr,))
            continue
        new_level_set[new_expr] = (rule, (expr,))
        new_level_semantics[sig] = new_expr
    
    # --- Binary productions ---
    for expr1 in available:
        for expr2 in available:
            for op in ["∧", "→"]:
                new_expr, rule = apply_binary(expr1, expr2, op)
                try:
                    sig = truth_table(new_expr)
                except Exception:
                    # Best effort: candidates that cannot be evaluated are skipped.
                    continue
                if sig in semantic_cache or sig in new_level_semantics:
                    if sig not in pruned_cache:
                        pruned_cache[sig] = new_expr
                        pruned_level_set[new_expr] = (rule, (expr1, expr2))
                    continue
                new_level_set[new_expr] = (rule, (expr1, expr2))
                new_level_semantics[sig] = new_expr
    
    # --- Add canonical nodes for this level ---
    canonical_exprs = []
    for expr, (rule, children_exprs) in new_level_set.items():
        sig = truth_table(expr)
        semantic_cache[sig] = expr  # Update global canonical cache.
        nid = new_node_id()
        nodes[expr] = nid
        derivations[nid] = {"expr": expr, "rule": rule, "children": [nodes[child] for child in children_exprs]}
        canonical_exprs.append(expr)
        dot.node(nid, expr)
        # One edge per operand, pointing from the operand to the expression built from it.
        for child_expr in children_exprs:
            child_nid = nodes[child_expr]
            dot.edge(child_nid, nid, label=rule)
    all_levels.append(canonical_exprs)
    
    # --- Add the first pruned node for each semantic signature in this level ---
    # Note: We do not add pruned nodes to 'all_levels' or update the 'nodes' mapping used for expansion.
    for expr, (rule, children_exprs) in pruned_level_set.items():
        # Draw pruned nodes in red.
        pid = new_node_id()
        # (Do not store pruned node in 'nodes' so that it is not used as parent later.)
        derivations[pid] = {"expr": expr, "rule": rule, "children": [nodes[child] for child in children_exprs]}
        dot.node(pid, expr, style="filled", fillcolor="red")
        for child_expr in children_exprs:
            child_nid = nodes[child_expr]
            dot.edge(child_nid, pid, label=rule)

# `filename` is the path of the DOT *source*; graphviz appends ".png" for the rendered
# image. Passing a name that already ends in ".png" would write DOT text to
# imgs/bottom_up_d3_pruned.png and the picture to imgs/bottom_up_d3_pruned.png.png.
# cleanup=True removes the intermediate DOT source after rendering.
filename = dot.render(filename='imgs/bottom_up_d3_pruned', cleanup=True)

In [53]:
# Graph for the pruned enumeration without pruned nodes: dot engine, bottom-to-top ranks.
dot = Digraph(format='png', engine='dot')
dot.graph_attr["rankdir"] = "BT"
dot.graph_attr["overlap"] = "false"

# Number of node ids issued so far in this cell.
node_counter = 0
def new_node_id():
    """Return the next sequential node id ("n0", "n1", ...), bumping the global counter."""
    global node_counter
    current = node_counter
    node_counter = current + 1
    return "n{}".format(current)

# --- Semantic evaluation functions ---

def parse_formula(expr):
    """Parse a formula string into a nested-tuple syntax tree.

    Grammar:
      formula -> "p" | "q"
               | "¬" formula
               | "(" formula "∧" formula ")"
               | "(" formula "→" formula ")"

    Returns "p"/"q" for atoms, ("not", sub) for negations and (op, left, right)
    for parenthesised binary formulas. Raises ValueError for anything else.
    """
    source = expr.strip()
    if source == "p" or source == "q":
        return source
    if source.startswith("¬"):
        return ("not", parse_formula(source[1:]))
    if source.startswith("(") and source.endswith(")"):
        content = source[1:-1].strip()
        level = 0
        for index, symbol in enumerate(content):
            if symbol == '(':
                level += 1
            elif symbol == ')':
                level -= 1
            elif level == 0 and symbol in ('∧', '→'):
                # The main operator is the first connective not enclosed in parentheses.
                left_text = content[:index].strip()
                right_text = content[index + 1:].strip()
                return (symbol, parse_formula(left_text), parse_formula(right_text))
    raise ValueError("Could not parse: " + source)

def evaluate_formula_tree(tree, env):
    """Evaluate a parsed syntax tree; `env` maps the atoms 'p' and 'q' to booleans."""
    if isinstance(tree, str):
        return env[tree]
    if not isinstance(tree, tuple):
        raise ValueError("Invalid tree structure")
    op = tree[0]
    if op == "not":
        return not evaluate_formula_tree(tree[1], env)
    if op == "∧":
        return evaluate_formula_tree(tree[1], env) and evaluate_formula_tree(tree[2], env)
    if op == "→":
        # a → b is evaluated as (not a) or b.
        return (not evaluate_formula_tree(tree[1], env)) or evaluate_formula_tree(tree[2], env)
    raise ValueError("Unknown operator: " + op)

def truth_table(expr):
    """Return the truth values of `expr` for every assignment of p and q.

    Order of the tuple: (p=False,q=False), (False,True), (True,False), (True,True).
    """
    tree = parse_formula(expr)
    assignments = [
        {"p": p_val, "q": q_val}
        for p_val in (False, True)
        for q_val in (False, True)
    ]
    return tuple(evaluate_formula_tree(tree, env) for env in assignments)

# --- Bottom-up generation with semantic pruning ---

# Terminal expressions that seed the enumeration.
base_exprs = ["p", "q"]

# nodes: expression string -> node id; derivations: node id -> production record;
# semantic_cache: truth table -> canonical expression having that table.
nodes, derivations, semantic_cache = {}, {}, {}

# Level 0: keep one terminal per distinct truth table.
base_level = []
for atom in base_exprs:
    signature = truth_table(atom)
    if signature not in semantic_cache:
        semantic_cache[signature] = atom
        atom_id = new_node_id()
        nodes[atom] = atom_id
        derivations[atom_id] = dict(expr=atom, rule=None, children=[])
        dot.node(atom_id, atom)
        base_level.append(atom)

# One list of expressions per round, starting with the terminals.
all_levels = [base_level]

# Production functions: build the new expression string and name the rule used.
def apply_unary(expr):
    """Return ("¬" + expr, "¬S"): the negated expression and the label of rule S -> ¬S."""
    return ("¬{}".format(expr), "¬S")

def apply_binary(expr1, expr2, op):
    """Return the parenthesised combination of both operands and the label of rule S -> (S op S)."""
    pieces = (expr1, op, expr2)
    return ("({} {} {})".format(*pieces), "(S {} S)".format(op))

# Build larger expressions bottom-up, keeping only the first expression found for
# each truth table (semantic pruning); pruned candidates are not drawn.
max_depth = 3  # Number of production rounds applied on top of the base level.
for d in range(1, max_depth + 1):
    # Candidates of this round: expression -> (rule label, tuple of operand expressions).
    new_level_set = {}
    
    # (1) Apply the unary rule to every expression from all previous levels.
    for level in all_levels:
        for expr in level:
            new_expr, rule = apply_unary(expr)
            try:
                sig = truth_table(new_expr)
            except Exception:
                # Best effort: candidates that cannot be evaluated are skipped.
                continue
            if sig in semantic_cache:
                # Equivalent to an expression from an earlier round.
                continue
            if new_expr not in new_level_set:
                new_level_set[new_expr] = (rule, (expr,))
    
    # (2) Apply binary rules: combine any two expressions from the available pool.
    available = [expr for level in all_levels for expr in level]
    for expr1 in available:
        for expr2 in available:
            for op in ["∧", "→"]:
                new_expr, rule = apply_binary(expr1, expr2, op)
                try:
                    sig = truth_table(new_expr)
                except Exception:
                    # Best effort: candidates that cannot be evaluated are skipped.
                    continue
                if sig in semantic_cache:
                    continue
                if new_expr not in new_level_set:
                    new_level_set[new_expr] = (rule, (expr1, expr2))
    
    # Add new level nodes, recording semantic signatures.
    new_level = list(new_level_set.items())  # Items are (expression, (rule, children_exprs)).
    current_level_exprs = []
    for expr, (rule, children_exprs) in new_level:
        try:
            sig = truth_table(expr)
        except Exception:
            continue
        if sig in semantic_cache:
            # An equivalent candidate from this same round was already accepted.
            continue
        semantic_cache[sig] = expr
        nid = new_node_id()
        nodes[expr] = nid
        derivations[nid] = {"expr": expr, "rule": rule, "children": [nodes[child] for child in children_exprs]}
        current_level_exprs.append(expr)
        dot.node(nid, expr)
        # Draw edges from the children (subexpressions) to the new expression.
        for child_expr in children_exprs:
            child_nid = nodes[child_expr]
            dot.edge(child_nid, nid, label=rule)
    all_levels.append(current_level_exprs)

# `filename` is the path of the DOT *source*; graphviz appends ".png" for the rendered
# image. Passing a name that already ends in ".png" would write DOT text to
# imgs/bottom_up_d3_pruned_cleaned.png and the picture to
# imgs/bottom_up_d3_pruned_cleaned.png.png.
# cleanup=True removes the intermediate DOT source after rendering.
filename = dot.render(filename='imgs/bottom_up_d3_pruned_cleaned', cleanup=True)

# Top-down search

In [61]:
# Lay the derivation tree out with the 'neato' engine.
# A fixed canvas size and a margin are set, and overlap="false" is requested.
dot = Digraph(format='png', engine='neato')
dot.graph_attr.update({
    "size": "30,30!",
    "margin": "0.5",
    "overlap": "false",
})

# Running count of node ids handed out so far; reset whenever this cell is re-run.
node_counter = 0
def new_node():
    """Return the next unused node id ("n0", "n1", ...) and advance the global counter."""
    global node_counter
    issued = "n{}".format(node_counter)
    node_counter = node_counter + 1
    return issued

def add_derivation(dot, current, derivation, current_depth, max_depth, productions):
    """
    Grow the leftmost-derivation tree below the node `current`.

    Parameters:
      - dot: graph object receiving the node() and edge() calls.
      - current: id of the node whose sentential form is `derivation`.
      - derivation: the sentential form to expand.
      - current_depth: depth of `current` in the tree.
      - max_depth: nodes at this depth are not expanded.
      - productions: right-hand sides tried, in order, for the nonterminal 'S'.

    Every production replaces the leftmost 'S' and yields one child; the edge to
    the child is labelled with the production. Children without any 'S' are
    terminal and filled light gray; the others are expanded recursively.
    """
    if current_depth >= max_depth or 'S' not in derivation:
        return

    for rhs in productions:
        expanded = derivation.replace("S", rhs, 1)
        child_id = new_node()
        still_open = "S" in expanded
        # Terminal sentences get the gray fill; open forms keep the default look.
        attrs = {} if still_open else {"style": "filled", "fillcolor": "lightgray"}
        dot.node(child_id, expanded, **attrs)
        dot.edge(current, child_id, label=rhs)
        if still_open:
            add_derivation(dot, child_id, expanded, current_depth + 1, max_depth, productions)

# Create the root of the tree with the starting symbol S.
start = new_node()
initial_derivation = "S"
dot.node(start, initial_derivation)

# Grammar productions (right-hand sides only):
# S -> p | q | (S ∧ S) | ¬S | (S → S)
productions = [
    "p", 
    "q", 
    "(S ∧ S)", 
    "¬S", 
    "(S → S)"
]

# Set a maximum depth to control the size of the tree.
max_depth = 2

add_derivation(dot, start, initial_derivation, 0, max_depth, productions)

# Render the derivation tree.
# `filename` is the path of the DOT *source*; graphviz appends ".png" for the rendered
# image. Passing a name that already ends in ".png" would write DOT text to
# imgs/top_down_d2.png and the picture to imgs/top_down_d2.png.png.
# cleanup=True removes the intermediate DOT source after rendering.
filename = dot.render(filename='imgs/top_down_d2', cleanup=True)

In [59]:
# Use the 'twopi' engine for a radial layout; only overlap removal is configured
# (no explicit canvas size or margin).
dot = Digraph(format='png', engine='twopi')
dot.graph_attr["overlap"] = "false"

# Number of node ids issued so far in this cell.
node_counter = 0
def new_node():
    """Hand out sequential node ids "n0", "n1", ... backed by the global counter."""
    global node_counter
    node_counter += 1
    return f"n{node_counter - 1}"

def add_derivation(dot, current, derivation, current_depth, max_depth, productions):
    """
    Recursively attach all one-step leftmost derivations of `derivation` to `current`.

    Parameters:
      - dot: graph object receiving the node() and edge() calls.
      - current: id of the node being expanded.
      - derivation: sentential form shown at `current`.
      - current_depth: depth of `current` in the tree.
      - max_depth: expansion stops once this depth is reached.
      - productions: right-hand sides substituted for the leftmost 'S'.

    Each child is connected by an edge labelled with the production applied.
    Sentences without 'S' are filled light gray and are leaves.
    """
    if current_depth >= max_depth or 'S' not in derivation:
        return

    next_depth = current_depth + 1
    for rhs in productions:
        sentential_form = derivation.replace("S", rhs, 1)
        child_id = new_node()
        is_terminal = "S" not in sentential_form
        if is_terminal:
            dot.node(child_id, sentential_form, style="filled", fillcolor="lightgray")
        else:
            dot.node(child_id, sentential_form)
        dot.edge(current, child_id, label=rhs)
        if is_terminal:
            continue
        add_derivation(dot, child_id, sentential_form, next_depth, max_depth, productions)

# Create the root of the tree with the starting symbol S.
start = new_node()
initial_derivation = "S"
dot.node(start, initial_derivation)

# Grammar productions (right-hand sides only):
# S -> p | q | (S ∧ S) | ¬S | (S → S)
productions = [
    "p", 
    "q", 
    "(S ∧ S)", 
    "¬S", 
    "(S → S)"
]

# Set a maximum depth to control the size of the tree.
max_depth = 3

add_derivation(dot, start, initial_derivation, 0, max_depth, productions)

# Render the derivation tree.
# `filename` is the path of the DOT *source*; graphviz appends ".png" for the rendered
# image. Passing a name that already ends in ".png" would write DOT text to
# imgs/top_down_d3.png and the picture to imgs/top_down_d3.png.png.
# cleanup=True removes the intermediate DOT source after rendering.
filename = dot.render(filename='imgs/top_down_d3', cleanup=True)

In [62]:
# Radial layout via the 'twopi' engine, with overlap="false" requested.
dot = Digraph(format='png', engine='twopi')
dot.graph_attr.update({"overlap": "false"})

# Running count of node ids handed out so far; reset whenever this cell is re-run.
node_counter = 0
def new_node():
    """Return a fresh node id of the form "n<k>", where k counts the ids issued before it."""
    global node_counter
    fresh, node_counter = f"n{node_counter}", node_counter + 1
    return fresh

def add_derivation(dot, current, derivation, current_depth, max_depth, productions):
    """
    Grow the leftmost-derivation tree below `current`, pruning double negations.

    Each production replaces the leftmost 'S' of `derivation` and produces one
    child node, linked by an edge labelled with that production.

    Child colouring and expansion:
      - no 'S' left: terminal, filled light gray, not expanded;
      - contains "¬¬S": pruned, filled red, not expanded;
      - otherwise: default look, expanded recursively until `max_depth`.
    """
    # Nothing to do at the depth limit or when no nonterminal remains.
    if current_depth >= max_depth or 'S' not in derivation:
        return

    for rhs in productions:
        # Only the leftmost S is rewritten.
        expanded = derivation.replace("S", rhs, 1)
        child_id = new_node()

        is_terminal = "S" not in expanded
        is_pruned = "¬¬S" in expanded

        # Terminal takes precedence over pruned (the two cannot coincide, since
        # "¬¬S" itself contains an S).
        if is_terminal:
            fill = "lightgray"
        elif is_pruned:
            fill = "red"
        else:
            fill = None

        if fill is None:
            dot.node(child_id, expanded)
        else:
            dot.node(child_id, expanded, style="filled", fillcolor=fill)

        dot.edge(current, child_id, label=rhs)

        # Pruned and terminal children are leaves.
        if is_pruned or is_terminal:
            continue
        add_derivation(dot, child_id, expanded, current_depth + 1, max_depth, productions)

# Create the root node with the starting symbol S.
start = new_node()
initial_derivation = "S"
dot.node(start, initial_derivation)

# Grammar productions (right-hand sides only):
# S -> p | q | (S ∧ S) | ¬S | (S → S)
productions = [
    "p", 
    "q", 
    "(S ∧ S)", 
    "¬S", 
    "(S → S)"
]

# Set maximum depth to control tree size.
max_depth = 3

add_derivation(dot, start, initial_derivation, 0, max_depth, productions)

# Render the derivation tree.
# `filename` is the path of the DOT *source*; graphviz appends ".png" for the rendered
# image. Passing a name that already ends in ".png" would write DOT text to
# imgs/top_down_d3_pruned.png and the picture to imgs/top_down_d3_pruned.png.png.
# cleanup=True removes the intermediate DOT source after rendering.
filename = dot.render(filename='imgs/top_down_d3_pruned', cleanup=True)

# Two types: numeric and boolean expressions

In [85]:
# Graph for the untyped derivation tree of the two-type grammar, laid out with neato.
dot = Digraph(format='png', engine='neato')
dot.graph_attr["overlap"] = "false"

# Number of node ids issued so far in this cell.
node_counter = 0
def new_node():
    """Return the next sequential node id ("n0", "n1", ...), bumping the global counter."""
    global node_counter
    current_index = node_counter
    node_counter = current_index + 1
    return "n{}".format(current_index)

def add_derivation(dot, current, derivation, current_depth, max_depth, productions):
    """
    Expand the derivation tree below `current`.

    Every production in turn replaces the leftmost "EXPR" of `derivation`,
    giving one child per production; the connecting edge carries the production
    as its label. Children with no "EXPR" left are terminal and filled light
    gray. Expansion stops at `max_depth` or when no "EXPR" remains.
    """
    if current_depth >= max_depth or "EXPR" not in derivation:
        return

    for rhs in productions:
        # Rewrite only the leftmost hole.
        expanded = derivation.replace("EXPR", rhs, 1)
        child_id = new_node()

        # Terminal sentences get the gray fill; open forms keep the default look.
        has_hole = "EXPR" in expanded
        attrs = {} if has_hole else {"style": "filled", "fillcolor": "lightgray"}
        dot.node(child_id, expanded, **attrs)

        dot.edge(current, child_id, label=rhs)

        # The recursive call returns immediately for terminal children.
        add_derivation(dot, child_id, expanded, current_depth + 1, max_depth, productions)

# Starting symbol.
start = new_node()
initial = "EXPR"
dot.node(start, initial)

# Grammar productions (right-hand sides only):
# EXPR -> (EXPR > EXPR) | ¬ EXPR | (EXPR ∧ EXPR) | 1 | (EXPR+EXPR) | x
productions = [
    "(EXPR > EXPR)",
    "¬ EXPR",
    "(EXPR ∧ EXPR)",
    "1",
    "(EXPR+EXPR)",
    "x"
]

max_depth = 2  # Adjust for deeper derivations.
add_derivation(dot, start, initial, 0, max_depth, productions)

# `filename` is the path of the DOT *source*; graphviz appends ".png" for the rendered
# image. Passing a name that already ends in ".png" would write DOT text to
# imgs/top_down_d2_2types.png and the picture to imgs/top_down_d2_2types.png.png.
# cleanup=True removes the intermediate DOT source after rendering.
filename = dot.render(filename='imgs/top_down_d2_2types', cleanup=True)

In [83]:
### Type Inference Utilities

# Number of type variables created so far.
type_var_counter = 0
def new_type_var():
    """Return a fresh type-variable name ("t0", "t1", ...) and advance the global counter."""
    global type_var_counter
    type_var_counter += 1
    return f"t{type_var_counter - 1}"

def occurs(var, t):
    """Occurs check: types are plain strings here, so `var` occurs in `t` only if they are equal."""
    return t == var

def apply_subst(subst, t):
    """Follow the bindings in `subst` starting from `t` until an unbound type is reached."""
    return apply_subst(subst, subst[t]) if t in subst else t

def unify(t1, t2, subst):
    """Make `t1` and `t2` equal under `subst`, extending it in place when needed.

    Both types are first resolved through `subst`. A name starting with "t" is
    treated as a type variable and gets bound to the other side (the left side
    is tried first). Returns the same `subst` object.
    Raises Exception when two distinct non-variable types meet.
    """
    left = apply_subst(subst, t1)
    right = apply_subst(subst, t2)
    if left == right:
        return subst
    for var, other in ((left, right), (right, left)):
        if var.startswith("t"):
            if occurs(var, other):
                raise Exception("Occurs check fails")
            subst[var] = other
            return subst
    raise Exception(f"Cannot unify {left} with {right}")

### Parser for the Language

# Our language has the productions:
#   EXPR -> (EXPR > EXPR) | ¬EXPR | (EXPR ∧ EXPR) | 1 | (EXPR+EXPR) | x
#
# The literal "EXPR" is treated as a hole.

def parse_expr(s):
    """Parse a (possibly partial) derivation string into a nested-tuple AST.

    Spaces are ignored.

    Returns:
      - "1", "x" or "EXPR" (the hole) for atoms;
      - ("not", operand) for "¬" followed by an expression;
      - (op, left, right) with op in {">", "∧", "+"} for "(left op right)".

    Raises:
      ValueError: if the text does not match the grammar. This includes empty
        input, e.g. the operand of a bare "¬" or a missing operand as in "(1>)".
    """
    s = s.replace(" ", "")
    if s in ("1", "x", "EXPR"):
        return s
    if s.startswith("¬"):
        return ("not", parse_expr(s[1:]))
    # startswith/endswith are safe on the empty string, unlike s[0] / s[-1].
    if s.startswith("(") and s.endswith(")"):
        inner = s[1:-1]
        depth = 0
        for i, c in enumerate(inner):
            if c == "(":
                depth += 1
            elif c == ")":
                depth -= 1
            elif depth == 0 and c in (">", "∧", "+"):
                # The first operator outside nested parentheses is the main connective.
                return (c, parse_expr(inner[:i]), parse_expr(inner[i + 1:]))
    raise ValueError("Cannot parse: " + s)

### Type Inference (simplified)

def infer(ast, subst):
    """Infer the type of `ast`, threading the substitution `subst` through.

    Atoms "1" and "x" are "num"; the hole "EXPR" gets a fresh type variable.
    "not" requires a bool operand and yields bool. The binary operators
    constrain both operands to one type and yield a result type:
      ">": num, num -> bool    "∧": bool, bool -> bool    "+": num, num -> num
    Returns (type, subst). Raises Exception for unknown operators or
    malformed trees; unification failures propagate from unify().
    """
    if ast in ("1", "x"):
        return "num", subst
    if ast == "EXPR":
        return new_type_var(), subst
    if not isinstance(ast, tuple):
        raise Exception("Invalid AST")

    op = ast[0]
    if op == "not":
        operand_type, subst = infer(ast[1], subst)
        subst = unify(operand_type, "bool", subst)
        return "bool", subst

    # operator -> (type required of both operands, type of the result)
    signatures = {
        ">": ("num", "bool"),
        "∧": ("bool", "bool"),
        "+": ("num", "num"),
    }
    if op in signatures:
        required, result = signatures[op]
        left_type, subst = infer(ast[1], subst)
        right_type, subst = infer(ast[2], subst)
        subst = unify(left_type, required, subst)
        subst = unify(right_type, required, subst)
        return result, subst
    raise Exception("Unknown operator: " + op)

def type_check(expr_str):
    """Parse `expr_str`, infer its type and force it to be bool; returns the resolved type."""
    inferred, bindings = infer(parse_expr(expr_str), {})
    # Force overall type to be boolean.
    bindings = unify(inferred, "bool", bindings)
    return apply_subst(bindings, inferred)

def is_well_typed(expr_str):
    """Return True when `expr_str` parses and type-checks as bool; any exception means False."""
    try:
        return type_check(expr_str) == "bool"
    except Exception:
        return False

### Top-Down Derivation Tree Generator

# Graph for the type-pruned derivation tree, laid out with neato.
dot = Digraph(format='png', engine='neato')
dot.graph_attr.update({"overlap": "false"})

# Running count of node ids handed out so far; reset whenever this cell is re-run.
node_counter = 0
def new_node():
    """Return the next unused node id ("n0", "n1", ...) and advance the global counter."""
    global node_counter
    issued = "n{}".format(node_counter)
    node_counter = node_counter + 1
    return issued

def add_derivation(dot, current, derivation, current_depth, max_depth, productions):
    """
    Draw the node `current` for `derivation` and expand it when it type-checks.

    The node itself is (re)declared here, coloured by its status under the
    assumption that the whole expression must be bool:
      - ill typed: filled red, not expanded;
      - well typed without "EXPR": terminal, filled light gray;
      - well typed with "EXPR": default look, expanded while below `max_depth`.

    Expansion replaces the leftmost "EXPR" by each production in turn; the edge
    to each child is labelled with the production applied.
    """
    # Ill-typed derivations are dead ends.
    if not is_well_typed(derivation):
        dot.node(current, derivation, style="filled", fillcolor="red")
        return

    has_hole = "EXPR" in derivation
    if has_hole:
        dot.node(current, derivation)
    else:
        dot.node(current, derivation, style="filled", fillcolor="lightgray")

    # Leaves: depth limit reached or nothing left to rewrite.
    if current_depth >= max_depth or not has_hole:
        return

    for rhs in productions:
        expanded = derivation.replace("EXPR", rhs, 1)
        child_id = new_node()
        # The edge is added first; the child's own node is declared by the recursive call.
        dot.edge(current, child_id, label=rhs)
        add_derivation(dot, child_id, expanded, current_depth + 1, max_depth, productions)

# Create the root node with the starting symbol.
start = new_node()
initial = "EXPR"
dot.node(start, initial)

# Define the grammar productions.
# (For ∧ we write "\\wedge" in the source and then replace it with "∧".)
productions = [
    "(EXPR>EXPR)",
    "¬EXPR",
    "(EXPR\\wedgeEXPR)",
    "1",
    "(EXPR+EXPR)",
    "x"
]
prods = [p.replace("\\wedge", "∧") for p in productions]

max_depth = 3  # Adjust for deeper or shallower trees.
add_derivation(dot, start, initial, 0, max_depth, prods)

# `filename` is the path of the DOT *source*; graphviz appends ".png" for the rendered
# image. Passing a name that already ends in ".png" would write DOT text to
# imgs/top_down_d3_2types_pruned.png and the picture to
# imgs/top_down_d3_2types_pruned.png.png.
# cleanup=True removes the intermediate DOT source after rendering.
filename = dot.render(filename='imgs/top_down_d3_2types_pruned', cleanup=True)