In [1]:
import sqlite3
from utilities import binary_to_bitlist
import random
import networkx as nx
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


conn = sqlite3.connect("db_numprop-4_nestlim-100.db")
cursor = conn.cursor()

cursor.execute("SELECT * FROM data")
all_rows = cursor.fetchall()
column_names = [desc[0] for desc in cursor.description]



In [2]:
class BitCycler:
    def __init__(self, db_row, total_bits=16):
        """
        Initialize from a single DB row, extracting the 'category' bits and setting up cycling.
        """
        category_int = int(db_row[9])  # index 9 is the 'category' column
        self.bits = [int(bit) for bit in f'{category_int:0{total_bits}b}']
        self.index = 0
        self.total_bits = total_bits

    def get_next(self):
        """
        Return the next bit in the cycle, looping back to the beginning when done.
        """
        bit = self.bits[self.index]
        self.index = (self.index + 1) % self.total_bits
        return bit

    def reset(self):
        """
        Reset the cycle to the beginning.
        """
        self.index = 0


bitcycler = BitCycler(all_rows[0])



In [3]:
# Next correct value
current_bit_gen = None

def bit_generator(bits):
    """
    Yield each bit from a list of bits.
    """
    for bit in bits:
        yield bit

def set_category(category_int, total_bits=16):
    """
    Given a category integer, convert it into a bit list (of length total_bits)
    and initialize the global bit generator.
    """
    global current_bit_gen
    bits = [int(bit) for bit in f'{category_int:0{total_bits}b}']
    current_bit_gen = bit_generator(bits)

def get_next_correct():
    """
    Retrieve the next bit from the current category bit generator.
    Raises an error if the generator is exhausted.
    """
    global current_bit_gen
    try:
        return next(current_bit_gen)
    except StopIteration:
        raise ValueError("Category bit stream exhausted. Please set a new category.")


In [4]:
def find_allowable_combinations(tree, correct, assignments, x_counter=0):
    p, q, r, s = assignments

    f, *args = tree
    op = f


    # A helper to union results (for alternatives) while threading the counter.
    def union_results(results):
        union_list = []
        max_counter = x_counter  # starting counter
        for res, cnt in results:
            union_list.extend(res)
            max_counter = max(max_counter, cnt)
        return union_list, max_counter

    # A helper to combine two lists of constraint dictionaries.
    def combine(list1, list2):
        combined = []
        for d1 in list1:
            for d2 in list2:
                merged = d1.copy()
                conflict = False
                for key, value in d2.items():
                    if key in merged and merged[key] != value:
                        conflict = True
                        break
                    merged[key] = value
                if not conflict:
                    combined.append(merged)
        return combined

    # Process operators.
    if op == 'not':
        res, new_counter = find_allowable_combinations(args[0], 1 - correct, assignments, x_counter)
        return res, new_counter

    elif op == 'and':
        if correct == 1:
            left, counter_left = find_allowable_combinations(args[0], 1, assignments, x_counter)
            right, counter_right = find_allowable_combinations(args[1], 1, assignments, counter_left)
            return combine(left, right), counter_right
        else:
            branch1_left, counter1 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 1, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 0, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            branch3_left, counter3 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch3_right, counter3 = find_allowable_combinations(args[1], 0, assignments, counter3)
            poss3 = combine(branch3_left, branch3_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2), (poss3, counter3)
            ])
            return union_list, final_counter

    elif op == 'or':
        if correct == 1:
            branch1_left, counter1 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 1, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 0, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            branch3_left, counter3 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch3_right, counter3 = find_allowable_combinations(args[1], 1, assignments, counter3)
            poss3 = combine(branch3_left, branch3_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2), (poss3, counter3)
            ])
            return union_list, final_counter
        else:
            left, counter_left = find_allowable_combinations(args[0], 0, assignments, x_counter)
            right, counter_right = find_allowable_combinations(args[1], 0, assignments, counter_left)
            return combine(left, right), counter_right

    elif op == 'conditional':
        if correct == 1:
            branch1_left, counter1 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 0, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 1, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            branch3_left, counter3 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch3_right, counter3 = find_allowable_combinations(args[1], 1, assignments, counter3)
            poss3 = combine(branch3_left, branch3_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2), (poss3, counter3)
            ])
            return union_list, final_counter
        else:
            left, counter_left = find_allowable_combinations(args[0], 1, assignments, x_counter)
            right, counter_right = find_allowable_combinations(args[1], 0, assignments, counter_left)
            return combine(left, right), counter_right

    elif op == 'negated_conditional':
        if correct == 1:
            left, counter_left = find_allowable_combinations(args[0], 1, assignments, x_counter)
            right, counter_right = find_allowable_combinations(args[1], 0, assignments, counter_left)
            return combine(left, right), counter_right
        else:
            branch1_left, counter1 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 0, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 1, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            branch3_left, counter3 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch3_right, counter3 = find_allowable_combinations(args[1], 1, assignments, counter3)
            poss3 = combine(branch3_left, branch3_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2), (poss3, counter3)
            ])
            return union_list, final_counter

    elif op == 'biconditional':
        if correct == 1:
            branch1_left, counter1 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 1, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 0, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2)
            ])
            return union_list, final_counter
        else:
            branch1_left, counter1 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 0, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 1, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2)
            ])
            return union_list, final_counter

    elif op == 'negated_biconditional':
        if correct == 1:
            branch1_left, counter1 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 0, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 1, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2)
            ])
            return union_list, final_counter
        else:
            branch1_left, counter1 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 1, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 0, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2)
            ])
            return union_list, final_counter

    elif op == 'nand':
        if correct == 1:
            branch1_left, counter1 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 0, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 1, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            branch3_left, counter3 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch3_right, counter3 = find_allowable_combinations(args[1], 0, assignments, counter3)
            poss3 = combine(branch3_left, branch3_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2), (poss3, counter3)
            ])
            return union_list, final_counter
        else:
            left, counter_left = find_allowable_combinations(args[0], 1, assignments, x_counter)
            right, counter_right = find_allowable_combinations(args[1], 1, assignments, counter_left)
            return combine(left, right), counter_right

    elif op == 'nor':
        if correct == 1:
            left, counter_left = find_allowable_combinations(args[0], 0, assignments, x_counter)
            right, counter_right = find_allowable_combinations(args[1], 0, assignments, counter_left)
            return combine(left, right), counter_right
        else:
            branch1_left, counter1 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 0, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 1, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            branch3_left, counter3 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch3_right, counter3 = find_allowable_combinations(args[1], 1, assignments, counter3)
            poss3 = combine(branch3_left, branch3_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2), (poss3, counter3)
            ])
            return union_list, final_counter

    elif op in ('p', 'q', 'r', 's'):
        # Check if the predetermined value agrees with the expected
        val = {'p': p, 'q': q, 'r': r, 's': s}[op]
        return ([] if val != correct else [{}]), x_counter

    elif op == 'X':
        # For an unknown "X" node, assign a unique variable name.
        new_var = f"X_{x_counter}"
        return ([{new_var: correct}]), x_counter + 1





In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_target(input_row, nn_first_prediction, tree_candidate):
    """
    Given a truth table row, the neural network's first prediction, and a 
    candidate tree formula (tree_candidate), use find_allowable_combinations 
    to choose the closest candidate target vector.
    """
    assignments = tuple(input_row)
    correct = bitcycler.get_next()

    # Run find_allowable_combinations using the tree candidate.
    candidate_list, non_terminal_count = find_allowable_combinations(tree_candidate, correct, assignments)

    if not candidate_list:
        raise ValueError("No candidate combinations found for tree candidate: " + str(tree_candidate))

    # Convert each candidate dictionary into a vector.
    candidate_vectors = [
        torch.tensor(
            [candidate.get(f"X_{i}", 0) for i in range(non_terminal_count)], dtype=torch.float
        )
        for candidate in candidate_list
    ]
    candidate_tensor = torch.stack(candidate_vectors)
    
    # Reshape nn_first_prediction to be a batch of size 1.
    nn_pred = nn_first_prediction.view(1, -1)
    
    # Compute Euclidean distances.
    distances = torch.norm(candidate_tensor - nn_pred, dim=1)
    best_idx = torch.argmin(distances).item()
    best_candidate = candidate_vectors[best_idx]

    return best_candidate


class Net(nn.Module):
    def __init__(self, input_size, output_size):
        """
        input_size: number of truth table inputs (here, 4 for P, Q, R, S)
        output_size: target vector length determined by non-terminals (e.g. 2)
        """
        super(Net, self).__init__()
        self.ln1 = nn.Linear(input_size, 16)
        self.ln2 = nn.Linear(16, 16)
        self.ln3 = nn.Linear(16, output_size)

    def forward(self, x):
        # Scale inputs from [0,1] to [-1, 1]
        x = 2 * x - 1
        if not isinstance(x, torch.Tensor):
            x = torch.Tensor(x)
        x = F.relu(self.ln1(x))
        x = F.relu(self.ln2(x))
        x = self.ln3(x)
        return torch.sigmoid(x)  # outputs in [0,1]


def round_prediction(pred, threshold=0.5):
    return (pred > threshold).float()


def binary_to_bitlist(n, total):
    return [int(a) for a in f'{n:0{total}b}']


def train_on_truth_table(nn_model, truth_table, tree_candidate, max_iterations_per_row=1000):
    """
    For each truth table row, train the network until its rounded output
    matches the target computed for the given tree_candidate.
    
    Args:
        nn_model: The neural network model.
        truth_table: Truth table as a numpy array.
        tree_candidate: A candidate tree structure to be used in compute_target.
        max_iterations_per_row: Maximum number of training iterations per truth table row.

    Returns:
        iterations_per_row: A list of iteration counts for each truth table row.
    """
    optimizer = torch.optim.Adam(nn_model.parameters())
    iterations_per_row = []
    
    for i, row in enumerate(truth_table):
        row_input = np.array([row], dtype=np.float32) 
        
        with torch.no_grad():
            initial_pred = nn_model(torch.tensor(row_input))
        
        # Squeeze the prediction to shape (output_size,)
        initial_pred_vector = initial_pred.squeeze(0)
        
        # Compute the target for this row using the candidate tree.
        target = compute_target(row, initial_pred_vector, tree_candidate)
        
        training_iter = 0
        while training_iter < max_iterations_per_row:
            pred = nn_model(torch.tensor(row_input))
            pred_vector = pred.squeeze(0)
            pred_binary = round_prediction(pred_vector)
            
            # Check if the rounded prediction equals the target.
            if torch.equal(pred_binary, target):
                print(f"Row {i+1} correct after {training_iter} iterations: pred {pred_binary.numpy()} vs. target {target.numpy()}")
                break  # move to next row
            else:
                loss = F.binary_cross_entropy(pred, target.unsqueeze(0))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                training_iter += 1
                
        if training_iter == max_iterations_per_row:
            print(f"Row {i+1} did not converge in {max_iterations_per_row} iterations.")
        iterations_per_row.append(training_iter)
        
    return iterations_per_row


def evaluate_candidate_options(current_options, truth_table, nn_model, max_iterations_per_row=1000):
    """
    For each candidate tree option in the current_options dictionary, run the
    full training process on the truth table and record the total number of 
    training iterations across all rows.
    
    Args:
        current_options: A dictionary where keys are option indices and values 
                         are candidate tree formulas.
        truth_table: The full truth table (numpy array).
        nn_model: The neural network model.
        max_iterations_per_row: Maximum training iterations per truth table row.
        
    Returns:
        A dictionary mapping each candidate index to the total iteration count 
        (i.e. the total derivations it took) over the entire truth table.
    """
    candidate_iterations = {}
    
    # It might be beneficial to reinitialize or clone the model for each candidate option
    # if you want separate runs. Here, we assume nn_model is reinitialized externally
    # before each evaluation or that evaluation order doesn't matter.
    for idx, tree_candidate in current_options.items():
        print(f"\nEvaluating candidate option {idx} with tree: {tree_candidate}")
        
        # If you require an independent training run per candidate, you might need to
        # create a fresh instance of the model. For simplicity, we use the same model.
        iteration_counts = train_on_truth_table(nn_model, truth_table, tree_candidate, max_iterations_per_row)
        total_iterations = sum(iteration_counts)
        candidate_iterations[idx] = total_iterations
        print(f"Candidate {idx} total iterations: {total_iterations}")
        
    return candidate_iterations


# Example usage in your main pipeline:

if __name__ == "__main__":
    # Define your truth table (example with 16 rows for 4 variables)
    truth_table = np.array([
        [1, 1, 1, 1],
        [1, 1, 1, 0],
        [1, 1, 0, 1],
        [1, 1, 0, 0],
        [1, 0, 1, 1],
        [1, 0, 1, 0],
        [1, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 1, 1, 1],
        [0, 1, 1, 0],
        [0, 1, 0, 1],
        [0, 1, 0, 0],
        [0, 0, 1, 1],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [0, 0, 0, 0]
    ])
    
    input_size = 4

    # Needs to be replaced
    current_options = {
        0: ['or', ['and', 'q', 'X'], ['not', 'X']],
        1: ['and', ['or', 'X', 'q'], ['not', 'X']]
    }
    
    first_row = truth_table[0]
    dummy_tree = current_options[0]
    _, non_terminal_count = find_allowable_combinations(dummy_tree, correct=0, assignments=tuple(first_row))
    
    # Create dummy prediction of the proper size and get a dummy target
    dummy_prediction = torch.zeros(non_terminal_count)
    dummy_target = compute_target(first_row, dummy_prediction, dummy_tree)
    output_size = dummy_target.shape[0]
    
    nn_model = Net(input_size, output_size)
    
    # Evaluate each candidate option over the truth table.
    candidate_iteration_results = evaluate_candidate_options(current_options, truth_table, nn_model, max_iterations_per_row=1000)
   
    print("\nEvaluation results (candidate index : total iterations):")
    for idx, total_iters in candidate_iteration_results.items():
        print(f"Option {idx}: {total_iters} iterations")


Evaluating candidate option 0 with tree: ['or', ['and', 'q', 'X'], ['not', 'X']]
Row 1 correct after 0 iterations: pred [0. 0.] vs. target [0. 0.]
Row 2 correct after 0 iterations: pred [1. 0.] vs. target [1. 0.]
Row 3 correct after 0 iterations: pred [0. 0.] vs. target [0. 0.]
Row 4 correct after 0 iterations: pred [0. 0.] vs. target [0. 0.]
Row 5 correct after 3 iterations: pred [0. 0.] vs. target [0. 0.]
Row 6 correct after 0 iterations: pred [0. 0.] vs. target [0. 0.]
Row 7 correct after 0 iterations: pred [0. 0.] vs. target [0. 0.]
Row 8 correct after 8 iterations: pred [0. 1.] vs. target [0. 1.]
Row 9 correct after 13 iterations: pred [0. 1.] vs. target [0. 1.]
Row 10 correct after 0 iterations: pred [0. 1.] vs. target [0. 1.]
Row 11 correct after 2 iterations: pred [0. 1.] vs. target [0. 1.]
Row 12 correct after 0 iterations: pred [0. 1.] vs. target [0. 1.]
Row 13 correct after 0 iterations: pred [0. 1.] vs. target [0. 1.]
Row 14 correct after 0 iterations: pred [0. 1.] vs. tar

In [23]:
def extract_grammar_from_data_row(row, columns):
    """
    Build grammar from a row in the 'data' table.
    Each rule returns a nested tuple.
    """
    operator_names = {
        "A": "and",
        "O": "or",
        "C": "conditional",
        "NC": "not_conditional",
        "B": "biconditional",
        "X": "not_biconditional",
        "NA": "not_and",
        "NOR": "not_or",
        "N": "not",
    }

    grammar = {}
    for op, name in operator_names.items():
        if op in columns and row[columns.index(op)] == 1:
            if op == "N":
                grammar[op] = lambda X, name=name: (name, X)
            else:
                grammar[op] = lambda X, name=name: (name, X, X)
    return grammar



def expand_all_X(expr, grammar):
    """
    Recursively finds the leftmost 'X' in a nested tuple structure and replaces it
    with each possible grammar rule or terminal symbol.
    """
    if expr == "X":
        # Base case: single 'X' to replace
        expansions = []

        for rule in grammar.values():
            expansions.append(rule("X"))

        for terminal in ["p", "q", "r", "s"]:
            expansions.append(terminal)

        return expansions

    elif isinstance(expr, tuple):
        # Recursive case: traverse the structure to find the leftmost 'X'
        for i, sub in enumerate(expr):
            sub_expansions = expand_all_X(sub, grammar)
            if sub_expansions:
                # Replace the first expandable part and break
                results = []
                for new_sub in sub_expansions:
                    new_expr = list(expr)
                    new_expr[i] = new_sub
                    results.append(tuple(new_expr))
                return results
    return []


def run_derivation_for_row(row_idx, row, columns):
    """
    Expands from the starting expression 'X' using grammar derived from the row.
    This version performs a single iteration and returns the current expression
    along with the options dictionary. An external function can use the options dict
    to select the next node.

    Returns:
        current (str): The starting expression (or new node if already set).
        current_options (dict): Dictionary of expansion options indexed by integers.
    """
    print(f"Using row {row_idx}: {row}")
    grammar = extract_grammar_from_data_row(row, columns)
    
    # Start with the initial expression 'X'
    current = "X"

    # Get all possible expansions for the current expression
    expansions = expand_all_X(current, grammar)
    if not expansions:
        print("No expansions available.")
        return current, {}

    # Build an options dictionary mapping indices to expansion expressions
    current_options = {i: exp for i, exp in enumerate(expansions)}
    
    print(f"\nCurrent expression: {current}")
    print(f"Options dict:\n  {current_options}")

    # Return the current expression and options dictionary to be processed externally
    return current, current_options

# Plotting function for testing.
def plot_derivation_tree(G, title="Derivation Tree"):
    plt.figure(figsize=(10, 6))
    pos = nx.spring_layout(G, seed=42)
    nx.draw(G, pos, with_labels=True, node_size=1500, node_color="lightyellow", font_size=10, arrows=True)
    plt.title(title)
    plt.show()

def main_from_loaded_data(all_rows, column_names, row_index=0):
    if row_index >= len(all_rows):
        print(f"Row {row_index} out of range.")
        return
    row = all_rows[row_index]
    G = run_derivation_for_row(row_index, row, column_names)
    # plot_derivation_tree(G, f"Derivation from row {row_index}") plotting


if __name__ == "__main__":
    # Change the DB path and row_index to test other configurations
    main_from_loaded_data(all_rows, column_names, row_index=1000000)
    

Using row 1000000: (0, 0, 0, 1, 1, 0, 0, 1, 0, 34996, 5, 'B(q,B(C(r,s),C(p,B(q,r))))')

Current expression: X
Options dict:
  {0: ('conditional', 'X', 'X'), 1: ('biconditional', 'X', 'X'), 2: ('not_or', 'X', 'X'), 3: 'p', 4: 'q', 5: 'r', 6: 's'}
