In [3]:
import sqlite3
from utilities import binary_to_bitlist

conn = sqlite3.connect("db_numprop-4_nestlim-100.db")
cursor = conn.cursor()

cursor.execute("SELECT * FROM data")
all_rows = cursor.fetchall()
column_names = [desc[0] for desc in cursor.description]



In [5]:
class BitCycler:
    def __init__(self, db_row, total_bits=16):
        """
        Initialize from a single DB row, extracting the 'category' bits and setting up cycling.
        """
        category_int = int(db_row[9])  # index 9 is the 'category' column
        self.bits = [int(bit) for bit in f'{category_int:0{total_bits}b}']
        self.index = 0
        self.total_bits = total_bits

    def get_next(self):
        """
        Return the next bit in the cycle, looping back to the beginning when done.
        """
        bit = self.bits[self.index]
        self.index = (self.index + 1) % self.total_bits
        return bit

    def reset(self):
        """
        Reset the cycle to the beginning.
        """
        self.index = 0


bitcycler = BitCycler(all_rows[0])



In [6]:
# Next correct value
current_bit_gen = None

def bit_generator(bits):
    """
    Yield each bit from a list of bits.
    """
    for bit in bits:
        yield bit

def set_category(category_int, total_bits=16):
    """
    Given a category integer, convert it into a bit list (of length total_bits)
    and initialize the global bit generator.
    """
    global current_bit_gen
    bits = [int(bit) for bit in f'{category_int:0{total_bits}b}']
    current_bit_gen = bit_generator(bits)

def get_next_correct():
    """
    Retrieve the next bit from the current category bit generator.
    Raises an error if the generator is exhausted.
    """
    global current_bit_gen
    try:
        return next(current_bit_gen)
    except StopIteration:
        raise ValueError("Category bit stream exhausted. Please set a new category.")


In [7]:
def find_allowable_combinations(tree, correct, assignments, x_counter=0):
    p, q, r, s = assignments

    f, *args = tree
    op = f


    # A helper to union results (for alternatives) while threading the counter.
    def union_results(results):
        union_list = []
        max_counter = x_counter  # starting counter
        for res, cnt in results:
            union_list.extend(res)
            max_counter = max(max_counter, cnt)
        return union_list, max_counter

    # A helper to combine two lists of constraint dictionaries.
    def combine(list1, list2):
        combined = []
        for d1 in list1:
            for d2 in list2:
                merged = d1.copy()
                conflict = False
                for key, value in d2.items():
                    if key in merged and merged[key] != value:
                        conflict = True
                        break
                    merged[key] = value
                if not conflict:
                    combined.append(merged)
        return combined

    # Process operators.
    if op == 'not':
        res, new_counter = find_allowable_combinations(args[0], 1 - correct, assignments, x_counter)
        return res, new_counter

    elif op == 'and':
        if correct == 1:
            left, counter_left = find_allowable_combinations(args[0], 1, assignments, x_counter)
            right, counter_right = find_allowable_combinations(args[1], 1, assignments, counter_left)
            return combine(left, right), counter_right
        else:
            branch1_left, counter1 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 1, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 0, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            branch3_left, counter3 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch3_right, counter3 = find_allowable_combinations(args[1], 0, assignments, counter3)
            poss3 = combine(branch3_left, branch3_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2), (poss3, counter3)
            ])
            return union_list, final_counter

    elif op == 'or':
        if correct == 1:
            branch1_left, counter1 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 1, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 0, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            branch3_left, counter3 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch3_right, counter3 = find_allowable_combinations(args[1], 1, assignments, counter3)
            poss3 = combine(branch3_left, branch3_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2), (poss3, counter3)
            ])
            return union_list, final_counter
        else:
            left, counter_left = find_allowable_combinations(args[0], 0, assignments, x_counter)
            right, counter_right = find_allowable_combinations(args[1], 0, assignments, counter_left)
            return combine(left, right), counter_right

    elif op == 'conditional':
        if correct == 1:
            branch1_left, counter1 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 0, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 1, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            branch3_left, counter3 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch3_right, counter3 = find_allowable_combinations(args[1], 1, assignments, counter3)
            poss3 = combine(branch3_left, branch3_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2), (poss3, counter3)
            ])
            return union_list, final_counter
        else:
            left, counter_left = find_allowable_combinations(args[0], 1, assignments, x_counter)
            right, counter_right = find_allowable_combinations(args[1], 0, assignments, counter_left)
            return combine(left, right), counter_right

    elif op == 'negated_conditional':
        if correct == 1:
            left, counter_left = find_allowable_combinations(args[0], 1, assignments, x_counter)
            right, counter_right = find_allowable_combinations(args[1], 0, assignments, counter_left)
            return combine(left, right), counter_right
        else:
            branch1_left, counter1 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 0, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 1, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            branch3_left, counter3 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch3_right, counter3 = find_allowable_combinations(args[1], 1, assignments, counter3)
            poss3 = combine(branch3_left, branch3_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2), (poss3, counter3)
            ])
            return union_list, final_counter

    elif op == 'biconditional':
        if correct == 1:
            branch1_left, counter1 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 1, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 0, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2)
            ])
            return union_list, final_counter
        else:
            branch1_left, counter1 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 0, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 1, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2)
            ])
            return union_list, final_counter

    elif op == 'negated_biconditional':
        if correct == 1:
            branch1_left, counter1 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 0, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 1, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2)
            ])
            return union_list, final_counter
        else:
            branch1_left, counter1 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 1, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 0, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2)
            ])
            return union_list, final_counter

    elif op == 'nand':
        if correct == 1:
            branch1_left, counter1 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 0, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 1, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            branch3_left, counter3 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch3_right, counter3 = find_allowable_combinations(args[1], 0, assignments, counter3)
            poss3 = combine(branch3_left, branch3_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2), (poss3, counter3)
            ])
            return union_list, final_counter
        else:
            left, counter_left = find_allowable_combinations(args[0], 1, assignments, x_counter)
            right, counter_right = find_allowable_combinations(args[1], 1, assignments, counter_left)
            return combine(left, right), counter_right

    elif op == 'nor':
        if correct == 1:
            left, counter_left = find_allowable_combinations(args[0], 0, assignments, x_counter)
            right, counter_right = find_allowable_combinations(args[1], 0, assignments, counter_left)
            return combine(left, right), counter_right
        else:
            branch1_left, counter1 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch1_right, counter1 = find_allowable_combinations(args[1], 0, assignments, counter1)
            poss1 = combine(branch1_left, branch1_right)

            branch2_left, counter2 = find_allowable_combinations(args[0], 0, assignments, x_counter)
            branch2_right, counter2 = find_allowable_combinations(args[1], 1, assignments, counter2)
            poss2 = combine(branch2_left, branch2_right)

            branch3_left, counter3 = find_allowable_combinations(args[0], 1, assignments, x_counter)
            branch3_right, counter3 = find_allowable_combinations(args[1], 1, assignments, counter3)
            poss3 = combine(branch3_left, branch3_right)

            union_list, final_counter = union_results([
                (poss1, counter1), (poss2, counter2), (poss3, counter3)
            ])
            return union_list, final_counter

    elif op in ('p', 'q', 'r', 's'):
        # Check if the predetermined value agrees with the expected
        val = {'p': p, 'q': q, 'r': r, 's': s}[op]
        return ([] if val != correct else [{}]), x_counter

    elif op == 'X':
        # For an unknown "X" node, assign a unique variable name.
        new_var = f"X_{x_counter}"
        return ([{new_var: correct}]), x_counter + 1





In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utilities import binary_to_bitlist  # assumed to exist


class Net(nn.Module):
    def __init__(self, input_size, output_size):
        """
        input_size: number of truth table inputs (here, 4 for P, Q, R, S)
        output_size: target vector length determined by non-terminals (e.g. 2)
        """
        super(Net, self).__init__()
        self.ln1 = nn.Linear(input_size, 16)
        self.ln2 = nn.Linear(16, 16)
        self.ln3 = nn.Linear(16, output_size)

    def forward(self, x):
        # Scale inputs from [0,1] to [-1, 1]
        x = 2 * x - 1
        if not isinstance(x, torch.Tensor):
            x = torch.Tensor(x)
        x = F.relu(self.ln1(x))
        x = F.relu(self.ln2(x))
        x = self.ln3(x)
        return torch.sigmoid(x)  # outputs in [0,1]


def round_prediction(pred, threshold=0.5):
    return (pred > threshold).float()

def binary_to_bitlist(n,total):
    return [int(a) for a in f'{n:0{total}b}']


def compute_target(input_row, nn_first_prediction):
    """
    Given a truth table row and the NN's first prediction,
    use find_allowable_combinations to choose the closest candidate target.
    """
    assignments = tuple(input_row)
    correct = bitcycler.get_next()
    candidate_list, non_terminal_count = find_allowable_combinations(tree, correct, assignments)


    # Convert each candidate dictionary into a vector
    candidate_vectors = [
        torch.tensor([candidate.get(f"X_{i}", 0) for i in range(non_terminal_count)], dtype=torch.float)
        for candidate in candidate_list
    ]

    candidate_tensor = torch.stack(candidate_vectors)  
    nn_first_prediction = nn_first_prediction.view(1, -1)  

    distances = torch.norm(candidate_tensor - nn_first_prediction, dim=1)
    best_idx = torch.argmin(distances).item()
    best_candidate = candidate_vectors[best_idx]

    return best_candidate



def train_on_truth_table(nn_model, truth_table, max_iterations_per_row=1000):
    """
    For each truth table row, train the network until its rounded output
    matches the target computed by the candidate function.
    
    Returns:
      A list of iteration counts for each truth table row.
    """
    optimizer = torch.optim.Adam(nn_model.parameters())
    iterations_per_row = []
    
    for i, row in enumerate(truth_table):
        row_input = np.array([row], dtype=np.float32) 
        

        with torch.no_grad():
            initial_pred = nn_model(torch.tensor(row_input))
        
        # Squeeze the prediction to shape (output_size,)
        initial_pred_vector = initial_pred.squeeze(0)
        
        # Compute the target using the updated candidate selection (closest to initial prediction).
        target = compute_target(row, initial_pred_vector)
        
        training_iter = 0
        while training_iter < max_iterations_per_row:
            pred = nn_model(torch.tensor(row_input))
            pred_vector = pred.squeeze(0)
            pred_binary = round_prediction(pred_vector)
            
            # Check if the rounded prediction equals the target.
            if torch.equal(pred_binary, target):
                print(f"Row {i+1} correct after {training_iter} iterations: pred {pred_binary.numpy()} vs. target {target.numpy()}")
                break  # move on to next row
            else:
                loss = F.binary_cross_entropy(pred, target.unsqueeze(0))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                training_iter += 1
                
        if training_iter == max_iterations_per_row:
            print(f"Row {i+1} did not converge in {max_iterations_per_row} iterations.")
        iterations_per_row.append(training_iter)
        
    return iterations_per_row


if __name__ == "__main__":
    truth_table = np.array([
        [1, 1, 1, 1],
        [1, 1, 1, 0],
        [1, 1, 0, 1],
        [1, 1, 0, 0],
        [1, 0, 1, 1],
        [1, 0, 1, 0],
        [1, 0, 0, 1],
        [1, 0, 0, 0],
        [0, 1, 1, 1],
        [0, 1, 1, 0],
        [0, 1, 0, 1],
        [0, 1, 0, 0],
        [0, 0, 1, 1],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [0, 0, 0, 0]
    ])

    input_size = 4 

    # Detect the number of non-terminals (Xs) using the first row.
    first_row = truth_table[0]
    tree = ['or', ['and', 'q', 'X'], ['not', 'X']]
    
    # Fake call to find number of Xs using your function
    _, non_terminal_count = find_allowable_combinations(tree, correct=0, assignments=tuple(first_row))

    # Create dummy prediction of correct size
    dummy_prediction = torch.zeros(non_terminal_count)

    # Pass prediction to get correctly sized target
    dummy_target = compute_target(first_row, dummy_prediction)

    output_size = dummy_target.shape[0]

    # Initialize your model
    nn_model = Net(input_size, output_size)

    # Train on full truth table
    iteration_counts = train_on_truth_table(nn_model, truth_table, max_iterations_per_row=1000)

    # Report
    print("Iterations per truth table row:")
    for i, cnt in enumerate(iteration_counts):
        print(f"Row {i+1}: {cnt}")
    total_iterations = sum(iteration_counts)
    print(f"Total iterations for full truth table: {total_iterations}")



Row 1 correct after 0 iterations: pred [1. 1.] vs. target [1. 1.]
Row 2 correct after 0 iterations: pred [0. 0.] vs. target [0. 0.]
Row 3 correct after 0 iterations: pred [1. 0.] vs. target [1. 0.]
Row 4 correct after 0 iterations: pred [0. 0.] vs. target [0. 0.]
Row 5 correct after 1 iterations: pred [1. 0.] vs. target [1. 0.]
Row 6 correct after 0 iterations: pred [0. 0.] vs. target [0. 0.]
Row 7 correct after 9 iterations: pred [1. 1.] vs. target [1. 1.]
Row 8 correct after 24 iterations: pred [0. 1.] vs. target [0. 1.]
Row 9 correct after 0 iterations: pred [0. 1.] vs. target [0. 1.]
Row 10 correct after 0 iterations: pred [0. 1.] vs. target [0. 1.]
Row 11 correct after 9 iterations: pred [0. 1.] vs. target [0. 1.]
Row 12 correct after 5 iterations: pred [0. 1.] vs. target [0. 1.]
Row 13 correct after 0 iterations: pred [0. 1.] vs. target [0. 1.]
Row 14 correct after 0 iterations: pred [0. 1.] vs. target [0. 1.]
Row 15 correct after 26 iterations: pred [1. 0.] vs. target [1. 0.]
Ro

In [16]:
import sqlite3
import random
import networkx as nx
import matplotlib.pyplot as plt

def extract_grammar_from_data_row(row, columns):
    """
    Build grammar from a row in the 'data' table.
    Each rule returns a nested tuple of the form:
      - Binary: (operator_name, 'X', 'X')
      - Unary:  (operator_name, 'X')
    """
    operator_names = {
        "A": "and",
        "O": "or",
        "C": "conditional",
        "NC": "not_conditional",
        "B": "biconditional",
        "X": "not_biconditional",
        "NA": "not_and",
        "NOR": "not_or",
        "N": "not",
    }

    grammar = {}
    for op, name in operator_names.items():
        if op in columns and row[columns.index(op)] == 1:
            if op == "N":
                grammar[op] = lambda X, name=name: (name, X)
            else:
                grammar[op] = lambda X, name=name: (name, X, X)
    return grammar



def expand_all_X(expr, grammar):
    """
    Recursively finds the leftmost 'X' in a nested tuple structure and replaces it
    with each possible grammar rule or terminal symbol.
    """
    if expr == "X":
        # Base case: single 'X' to replace
        expansions = []

        for rule in grammar.values():
            expansions.append(rule("X"))

        for terminal in ["p", "q", "r", "s"]:
            expansions.append(terminal)

        return expansions

    elif isinstance(expr, tuple):
        # Recursive case: traverse the structure to find the leftmost 'X'
        for i, sub in enumerate(expr):
            sub_expansions = expand_all_X(sub, grammar)
            if sub_expansions:
                # Replace the first expandable part and break
                results = []
                for new_sub in sub_expansions:
                    new_expr = list(expr)
                    new_expr[i] = new_sub
                    results.append(tuple(new_expr))
                return results
    return []


def run_derivation_for_row(row_idx, row, columns):
    """
    Expands from 'X' using grammar derived from the row.
    At each step:
      - current_options is overwritten to contain only indexed expansions.
      - Simulated times are used internally.
      - Prints everything for debugging.
    """
    print(f"Using row {row_idx}: {row}")
    grammar = extract_grammar_from_data_row(row, columns)

    start_expr = "X"
    current = start_expr

    current_options = {}

    while True:
        expansions = expand_all_X(current, grammar)
        if not expansions:
            break

        # Overwrite the options dictionary with numbered entries
        current_options.clear()
        for i, exp in enumerate(expansions):
            current_options[i] = exp

        # Simulated evaluation â€” in final implementation, your network picks from `current_options`
        simulated_times = {i: random.uniform(0, 10) for i in current_options}

        print(f"\nCurrent expression: {current}")
        print(f"Options dict:\n  {current_options}")

        # Simulate network picking the best (lowest simulated time)
        best_index = min(simulated_times, key=simulated_times.get)
        best_exp = current_options[best_index]
        print(f"  Selected: {best_exp}")

        current = best_exp  # Move to the next node

    print(f"\n[Row {row_idx}] Grammar used: {list(grammar.keys())}")
    return current_options


# Plotting function for testing.
def plot_derivation_tree(G, title="Derivation Tree"):
    plt.figure(figsize=(10, 6))
    pos = nx.spring_layout(G, seed=42)
    nx.draw(G, pos, with_labels=True, node_size=1500, node_color="lightyellow", font_size=10, arrows=True)
    plt.title(title)
    plt.show()

def main_from_loaded_data(all_rows, column_names, row_index=0):
    if row_index >= len(all_rows):
        print(f"Row {row_index} out of range.")
        return
    row = all_rows[row_index]
    G = run_derivation_for_row(row_index, row, column_names)
    # plot_derivation_tree(G, f"Derivation from row {row_index}") plotting


if __name__ == "__main__":
    # Change the DB path and row_index to test other configurations
    main_from_loaded_data(all_rows, column_names, row_index=100000)
    

Using row 100000: (1, 0, 1, 0, 0, 0, 0, 0, 0, 28940, 17, 'O(N(O(p,O(q,N(r)))),N(O(N(p),O(N(O(q,N(O(r,s)))),N(O(N(r),N(s)))))))')

Current expression: X
Options dict:
  {0: ('or', 'X', 'X'), 1: ('not', 'X'), 2: 'p', 3: 'q', 4: 'r', 5: 's'}
  Selected: ('or', 'X', 'X')

Current expression: ('or', 'X', 'X')
Options dict:
  {0: ('or', ('or', 'X', 'X'), 'X'), 1: ('or', ('not', 'X'), 'X'), 2: ('or', 'p', 'X'), 3: ('or', 'q', 'X'), 4: ('or', 'r', 'X'), 5: ('or', 's', 'X')}
  Selected: ('or', 'q', 'X')

Current expression: ('or', 'q', 'X')
Options dict:
  {0: ('or', 'q', ('or', 'X', 'X')), 1: ('or', 'q', ('not', 'X')), 2: ('or', 'q', 'p'), 3: ('or', 'q', 'q'), 4: ('or', 'q', 'r'), 5: ('or', 'q', 's')}
  Selected: ('or', 'q', 'r')

[Row 100000] Grammar used: ['O', 'N']
