In [1]:
# For each binary connective: the (left, right) truth values that make the
# connective evaluate to 1, and those that make it evaluate to 0.
# The order inside each list is significant. It fixes the order of the
# candidates returned below. The caller picks the first of several equally
# close candidates (torch.argmin returns the first minimal index).
_BINARY_BRANCHES = {
    'and': {1: [(1, 1)], 0: [(0, 1), (1, 0), (0, 0)]},
    'or': {1: [(1, 1), (1, 0), (0, 1)], 0: [(0, 0)]},
    'conditional': {1: [(0, 0), (0, 1), (1, 1)], 0: [(1, 0)]},
    'negated_conditional': {1: [(1, 0)], 0: [(0, 0), (0, 1), (1, 1)]},
    'biconditional': {1: [(1, 1), (0, 0)], 0: [(1, 0), (0, 1)]},
    'negated_biconditional': {1: [(1, 0), (0, 1)], 0: [(1, 1), (0, 0)]},
    'nand': {1: [(0, 0), (0, 1), (1, 0)], 0: [(1, 1)]},
    'nor': {1: [(0, 0)], 0: [(1, 0), (0, 1), (1, 1)]},
}


def _combine(list1, list2):
    """Merge every dict of ``list1`` with every dict of ``list2``, dropping pairs that disagree on a key."""
    combined = []
    for d1 in list1:
        for d2 in list2:
            merged = d1.copy()
            conflict = False
            for key, value in d2.items():
                if key in merged and merged[key] != value:
                    conflict = True
                    break
                merged[key] = value
            if not conflict:
                combined.append(merged)
    return combined


def find_allowable_combinations(tree, correct, assignments, x_counter=0):
    """
    Enumerate assignments of the formula's unknown ``X`` leaves that make
    ``tree`` evaluate to ``correct``.

    Parameters
    ----------
    tree : nested list ``[op, arg0, arg1]`` or ``['not', arg]``, or a leaf
        string. Leaves are ``'p'``, ``'q'``, ``'r'``, ``'s'`` (fixed truth values
        taken from ``assignments``) or ``'X'`` (an unknown).
    correct : required truth value of ``tree``. Binary connectives treat any
        value other than 1 as 0.
    assignments : exactly four truth values, unpacked as ``(p, q, r, s)``.
    x_counter : index given to the first ``X`` leaf met in this subtree.

    Returns
    -------
    (candidates, next_counter)
        ``candidates`` is a list of dicts mapping ``"X_<k>"`` to its 0/1 value.
        An empty list means no assignment works. ``next_counter`` is
        ``x_counter`` plus the number of ``X`` leaves in ``tree``.

    Raises
    ------
    ValueError
        If ``tree`` contains an unknown operator or leaf.
    """
    p, q, r, s = assignments
    op, *args = tree

    if op == 'not':
        return find_allowable_combinations(args[0], 1 - correct, assignments, x_counter)

    if op in _BINARY_BRANCHES:
        wanted = 1 if correct == 1 else 0
        candidates = []
        final_counter = x_counter
        for left_value, right_value in _BINARY_BRANCHES[op][wanted]:
            # Each alternative restarts numbering at x_counter, so a given X
            # node gets the same name X_k in every alternative. Both children
            # are always visited, even when the left side is already
            # unsatisfiable, so that the counter stays correct.
            left, counter = find_allowable_combinations(args[0], left_value, assignments, x_counter)
            right, counter = find_allowable_combinations(args[1], right_value, assignments, counter)
            candidates.extend(_combine(left, right))
            final_counter = max(final_counter, counter)
        return candidates, final_counter

    if op in ('p', 'q', 'r', 's'):
        # A fixed leaf either agrees with the requirement (one empty
        # constraint set) or makes the branch impossible (no candidates).
        val = {'p': p, 'q': q, 'r': r, 's': s}[op]
        return ([] if val != correct else [{}]), x_counter

    if op == 'X':
        # An unknown leaf is simply forced to the required value.
        return [{f"X_{x_counter}": correct}], x_counter + 1

    raise ValueError(f"Unknown operator or leaf {op!r} in formula tree")





In [49]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utilities import binary_to_bitlist  # assumed to exist

# Assume find_allowable_combinations is available:
# from some_module import find_allowable_combinations

# -------------------------------------------------------------------
# 1. Define the network (adjusted to output a vector)
# -------------------------------------------------------------------
class Net(nn.Module):
    """Three-layer MLP mapping truth-table rows to per-``X`` sigmoid outputs.

    Layers: Linear(input_size, hidden_size) -> ReLU ->
    Linear(hidden_size, hidden_size) -> ReLU ->
    Linear(hidden_size, output_size) -> sigmoid.
    """

    def __init__(self, input_size, output_size, hidden_size=16):
        """
        input_size: number of truth table inputs (here, 4 for P, Q, R, S)
        output_size: target vector length determined by non-terminals (e.g. 2)
        hidden_size: width of both hidden layers (default 16, the previously
            hard-coded value)
        """
        super().__init__()
        self.ln1 = nn.Linear(input_size, hidden_size)
        self.ln2 = nn.Linear(hidden_size, hidden_size)
        self.ln3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return sigmoid outputs in [0, 1] with last dimension ``output_size``.

        ``x`` may be a tensor, a NumPy array or a (nested) list of 0/1 values
        whose last dimension is ``input_size``.
        """
        # Convert before rescaling. On a plain Python list, ``2 * x`` would
        # repeat the list instead of scaling it, and ``- 1`` would then raise.
        if not isinstance(x, torch.Tensor):
            x = torch.as_tensor(x, dtype=self.ln1.weight.dtype)
        # Scale inputs from [0, 1] to [-1, 1].
        x = 2 * x - 1
        x = F.relu(self.ln1(x))
        x = F.relu(self.ln2(x))
        x = self.ln3(x)
        return torch.sigmoid(x)

# -------------------------------------------------------------------
# 2. Helper function to round predictions (using 0.5 threshold)
# -------------------------------------------------------------------
def round_prediction(pred, threshold=0.5):
    """Binarise a tensor of probabilities against ``threshold``.

    Each element strictly greater than ``threshold`` becomes 1.0. Everything
    else, including values exactly equal to ``threshold``, becomes 0.0. The
    result keeps ``pred``'s shape and has dtype ``torch.float32``.
    """
    above = torch.gt(pred, threshold)
    return above.to(torch.float32)

# -------------------------------------------------------------------
# 3. Compute target based on the current truth table row
#    using the network's initial prediction.
# -------------------------------------------------------------------
def compute_target(input_row, nn_first_prediction, formula_tree=None, correct=0):
    """
    Choose the training target for one truth-table row.

    Enumerates, via ``find_allowable_combinations``, every assignment of the
    formula's ``X`` nodes that makes the formula evaluate to ``correct`` on
    ``input_row``. It then returns the one closest (Euclidean distance) to the
    network's first prediction. On ties the earliest candidate wins, because
    ``torch.argmin`` returns the first minimal index.

    Parameters
    ----------
    input_row : sequence of 4 truth values, used as (P, Q, R, S).
    nn_first_prediction : torch.Tensor with one entry per ``X`` node.
    formula_tree : nested-list formula. ``None`` (default) falls back to the
        module-level ``tree`` global, which the ``__main__`` section defines.
    correct : truth value the formula must take (default 0, the value that
        was previously hard-coded).

    Returns
    -------
    torch.Tensor
        float32 vector with one 0/1 entry per ``X`` node.

    Raises
    ------
    ValueError
        If no assignment satisfies the requirement on this row, or if the
        prediction's size differs from the number of ``X`` nodes.
    """
    # Resolved lazily: callers passing formula_tree need no global ``tree``.
    formula = tree if formula_tree is None else formula_tree
    assignments = tuple(input_row)

    candidate_list, non_terminal_count = find_allowable_combinations(formula, correct, assignments)
    if not candidate_list:
        raise ValueError(
            f"No assignment of the X nodes makes the formula evaluate to "
            f"{correct} for row {assignments}"
        )

    prediction = nn_first_prediction.reshape(1, -1)  # shape: (1, num_X)
    # Guard against silent broadcasting (e.g. a 1-element prediction).
    if prediction.shape[1] != non_terminal_count:
        raise ValueError(
            f"Prediction has {prediction.shape[1]} entries but the formula "
            f"has {non_terminal_count} X nodes"
        )

    # Convert each candidate dictionary into a vector ordered X_0, X_1, ...
    candidate_vectors = [
        torch.tensor([candidate.get(f"X_{i}", 0) for i in range(non_terminal_count)], dtype=torch.float)
        for candidate in candidate_list
    ]
    candidate_tensor = torch.stack(candidate_vectors)  # shape: (num_candidates, num_X)

    distances = torch.norm(candidate_tensor - prediction, dim=1)
    best_idx = torch.argmin(distances).item()
    return candidate_vectors[best_idx]


# -------------------------------------------------------------------
# 4. Train on the truth table sequentially, one row at a time.
# -------------------------------------------------------------------
def train_on_truth_table(nn_model, truth_table, max_iterations_per_row=1000):
    """
    Train ``nn_model`` row by row until its rounded output matches each row's target.

    For every row, the target is chosen once by ``compute_target`` from the
    network's prediction before training on that row. The model is then
    updated with binary cross-entropy until ``round_prediction`` of its
    output equals the target, or until ``max_iterations_per_row`` optimizer
    steps have been taken. The state after the final step is checked too, so
    a row that converges on the last step is reported as converged.

    A single Adam optimizer (default hyperparameters) is created per call and
    shared across all rows.

    Returns:
      A list with the number of optimizer steps taken for each truth table row.
    """
    optimizer = torch.optim.Adam(nn_model.parameters())
    iterations_per_row = []

    for i, row in enumerate(truth_table):
        # Batch of one row, float32; built once since it is loop-invariant.
        row_tensor = torch.tensor(np.array([row], dtype=np.float32))

        # First prediction on this row, not tracked by autograd.
        with torch.no_grad():
            initial_pred_vector = nn_model(row_tensor).squeeze(0)  # shape (output_size,)

        # Target: allowable candidate closest to the initial prediction.
        target = compute_target(row, initial_pred_vector)

        training_iter = 0
        converged = False
        while True:
            pred = nn_model(row_tensor)
            pred_binary = round_prediction(pred.squeeze(0))

            if torch.equal(pred_binary, target):
                converged = True
                print(f"Row {i+1} correct after {training_iter} iterations: pred {pred_binary.numpy()} vs. target {target.numpy()}")
                break
            # Budget exhausted; the check above already covered the final state.
            if training_iter >= max_iterations_per_row:
                break

            loss = F.binary_cross_entropy(pred, target.unsqueeze(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_iter += 1

        if not converged:
            print(f"Row {i+1} did not converge in {max_iterations_per_row} iterations.")
        iterations_per_row.append(training_iter)

    return iterations_per_row

# -------------------------------------------------------------------
# 5. Main: Prepare truth table, initialize network, and start training.
# -------------------------------------------------------------------
if __name__ == "__main__":
    import itertools

    input_size = 4  # P, Q, R, S

    # Full truth table over P, Q, R, S (16 rows), ordered from (1, 1, 1, 1)
    # down to (0, 0, 0, 0).
    truth_table = np.array(list(itertools.product((1, 0), repeat=input_size)))

    # Formula whose 'X' leaves are the unknowns the network learns to predict.
    # NOTE: compute_target reads this module-level name, so it must stay a global.
    tree = ['and', ['or', ['not', 'X'], ['and', 'X', 'X']], ['conditional', ['or', 'r', 's'], ['and', 'X', 'q']]]

    # One output unit per 'X' node. find_allowable_combinations numbers every
    # X leaf regardless of row or required truth value, so the first row is
    # enough to count them.
    _, output_size = find_allowable_combinations(tree, correct=0, assignments=tuple(truth_table[0]))

    nn_model = Net(input_size, output_size)

    # Train on full truth table
    iteration_counts = train_on_truth_table(nn_model, truth_table, max_iterations_per_row=1000)

    # Report
    print("Iterations per truth table row:")
    for i, cnt in enumerate(iteration_counts):
        print(f"Row {i+1}: {cnt}")
    total_iterations = sum(iteration_counts)
    print(f"Total iterations for full truth table: {total_iterations}")



Row 1 correct after 29 iterations: pred [0. 1. 1. 0.] vs. target [0. 1. 1. 0.]
Row 2 correct after 0 iterations: pred [0. 1. 1. 0.] vs. target [0. 1. 1. 0.]
Row 3 correct after 9 iterations: pred [0. 1. 1. 0.] vs. target [0. 1. 1. 0.]
Row 4 correct after 66 iterations: pred [1. 0. 1. 0.] vs. target [1. 0. 1. 0.]
Row 5 correct after 0 iterations: pred [0. 0. 1. 0.] vs. target [0. 0. 1. 0.]
Row 6 correct after 0 iterations: pred [1. 0. 1. 0.] vs. target [1. 0. 1. 0.]
Row 7 correct after 0 iterations: pred [0. 0. 1. 0.] vs. target [0. 0. 1. 0.]
Row 8 correct after 0 iterations: pred [1. 0. 1. 0.] vs. target [1. 0. 1. 0.]
Row 9 correct after 0 iterations: pred [0. 0. 1. 0.] vs. target [0. 0. 1. 0.]
Row 10 correct after 0 iterations: pred [0. 0. 1. 0.] vs. target [0. 0. 1. 0.]
Row 11 correct after 0 iterations: pred [0. 0. 1. 0.] vs. target [0. 0. 1. 0.]
Row 12 correct after 7 iterations: pred [1. 0. 1. 0.] vs. target [1. 0. 1. 0.]
Row 13 correct after 0 iterations: pred [1. 0. 1. 0.] vs. t

Training the neural network...
Row 1 correct after 46 iterations: pred [1. 0.] vs. target [1. 0.]
Row 2 correct after 14 iterations: pred [1. 0.] vs. target [1. 0.]
Row 3 correct after 8 iterations: pred [1. 0.] vs. target [1. 0.]
Row 4 correct after 8 iterations: pred [1. 0.] vs. target [1. 0.]
Row 5 correct after 0 iterations: pred [1. 0.] vs. target [1. 0.]
Row 6 correct after 0 iterations: pred [1. 0.] vs. target [1. 0.]
Row 7 correct after 0 iterations: pred [1. 0.] vs. target [1. 0.]
Row 8 correct after 0 iterations: pred [1. 0.] vs. target [1. 0.]
Row 9 correct after 0 iterations: pred [1. 0.] vs. target [1. 0.]
Row 10 correct after 0 iterations: pred [1. 0.] vs. target [1. 0.]
Row 11 correct after 3 iterations: pred [1. 0.] vs. target [1. 0.]
Row 12 correct after 8 iterations: pred [1. 0.] vs. target [1. 0.]
Row 13 correct after 0 iterations: pred [1. 0.] vs. target [1. 0.]
Row 14 correct after 0 iterations: pred [1. 0.] vs. target [1. 0.]
Row 15 correct after 0 iterations: pre