In [1]:
import ast
import astunparse
import math
from typing import List, Union

# Helper functions to create AST nodes for assignments and operations.
# These would typically be in a shared 'ast_lib.py' file.

def _make_load_name(name: str) -> ast.Name:
    """Creates an AST Name node for loading a variable."""
    return ast.Name(id=name, ctx=ast.Load())

def _make_store_name(name: str) -> ast.Name:
    """Creates an AST Name node for storing a variable."""
    return ast.Name(id=name, ctx=ast.Store())

def _make_assignment(target_name: str, value_node: ast.expr) -> ast.Assign:
    """Creates an AST assignment node."""
    return ast.Assign(targets=[_make_store_name(target_name)], value=value_node)

def _make_binary_op(left: ast.expr, op: ast.operator, right: ast.expr) -> ast.BinOp:
    """Creates an AST binary operation node."""
    return ast.BinOp(left=left, op=op, right=right)

def _make_or(left: ast.expr, right: ast.expr) -> ast.BinOp:
    """Creates an AST OR operation."""
    return _make_binary_op(left, ast.BitOr(), right)

def _make_and(left: ast.expr, right: ast.expr) -> ast.BinOp:
    """Creates an AST AND operation."""
    return _make_binary_op(left, ast.BitAnd(), right)

def _make_xor(left: ast.expr, right: ast.expr) -> ast.BinOp:
    """Creates an AST XOR operation."""
    return _make_binary_op(left, ast.BitXor(), right)


def _inline_half_adder(a: ast.expr, b: ast.expr, sum_name: str, carry_name: str) -> List[ast.stmt]:
    """
    Generates AST statements for a half adder (inlined).
    
    Args:
        a, b: AST expressions for the two input bits
        sum_name: Variable name for the sum output
        carry_name: Variable name for the carry output
    
    Returns:
        List of AST statements implementing the half adder
    """
    statements = []
    # sum = a XOR b
    sum_expr = _make_xor(a, b)
    statements.append(_make_assignment(sum_name, sum_expr))
    
    # carry = a AND b
    carry_expr = _make_and(a, b)
    statements.append(_make_assignment(carry_name, carry_expr))
    
    return statements


def _inline_full_adder(a: ast.expr, b: ast.expr, c: ast.expr, 
                       sum_name: str, carry_name: str) -> List[ast.stmt]:
    """
    Generates AST statements for a full adder (inlined).
    
    Args:
        a, b, c: AST expressions for the three input bits
        sum_name: Variable name for the sum output
        carry_name: Variable name for the carry output
    
    Returns:
        List of AST statements implementing the full adder
    """
    statements = []
    # sum = a XOR b XOR c
    sum_expr = _make_xor(_make_xor(a, b), c)
    statements.append(_make_assignment(sum_name, sum_expr))
    
    # carry = (a AND b) OR (b AND c) OR (a AND c)
    ab = _make_and(a, b)
    bc = _make_and(b, c)
    ac = _make_and(a, c)
    carry_expr = _make_or(_make_or(ab, bc), ac)
    statements.append(_make_assignment(carry_name, carry_expr))
    
    return statements

def generate_adder_tree_at_least_k(
    input_vars: List[Union[str, ast.Name]], k: int
) -> List[ast.stmt]:
    """
    Generates AST for an adder tree counter to check if at least k inputs are true.

    This version corrects the infinite loop by properly managing bits at each
    binary weight level during the tree construction.
    """
    if k == 0:
        return [_make_assignment("at_least_k", ast.Constant(value=True))]
    if not input_vars or k > len(input_vars):
        return [_make_assignment("at_least_k", ast.Constant(value=False))]

    n = len(input_vars)
    statements = []

    # --- Build the Adder Tree Layer by Layer (Corrected Logic) ---

    # bits_by_weight is a list of lists. index 0 holds all bits of weight 2^0,
    # index 1 holds bits of weight 2^1, and so on.
    bits_by_weight = [[] for _ in range(math.ceil(math.log2(n + 1)) + 1)]
    bits_by_weight[0] = [
        _make_load_name(var) if isinstance(var, str) else var for var in input_vars
    ]

    # Process until every weight level has at most one bit.
    # The 'work_to_do' check is the correct termination condition.
    work_to_do = any(len(bits) > 1 for bits in bits_by_weight)
    layer_num = 0

    while work_to_do:
        # Process each weight level that has more than one bit
        for weight, current_bits in enumerate(bits_by_weight):
            if len(current_bits) <= 1:
                continue

            new_sum_bits = []
            num_full_adders = len(current_bits) // 3
            
            # Process full adders (3 bits -> sum and carry)
            for i in range(num_full_adders):
                sum_name = f"s_{layer_num}_{weight}_{i}"
                carry_name = f"c_{layer_num}_{weight}_{i}"
                
                fa_inputs = current_bits[i*3 : i*3+3]
                fa_stmts = _inline_full_adder(fa_inputs[0], fa_inputs[1], fa_inputs[2], sum_name, carry_name)
                statements.extend(fa_stmts)

                new_sum_bits.append(_make_load_name(sum_name))
                bits_by_weight[weight + 1].append(_make_load_name(carry_name))

            # Process remaining bits
            remaining_bits = current_bits[num_full_adders*3:]
            if len(remaining_bits) == 2:
                # Half adder (2 bits -> sum and carry)
                sum_name = f"s_{layer_num}_{weight}_{num_full_adders}"
                carry_name = f"c_{layer_num}_{weight}_{num_full_adders}"
                
                ha_stmts = _inline_half_adder(remaining_bits[0], remaining_bits[1], sum_name, carry_name)
                statements.extend(ha_stmts)

                new_sum_bits.append(_make_load_name(sum_name))
                bits_by_weight[weight + 1].append(_make_load_name(carry_name))
            elif len(remaining_bits) == 1:
                # Pass through the single remaining bit
                new_sum_bits.append(remaining_bits[0])

            # The bits for the current weight are now the new sum bits
            bits_by_weight[weight] = new_sum_bits
        
        layer_num += 1
        work_to_do = any(len(bits) > 1 for bits in bits_by_weight)

    # --- The rest of the function remains the same ---

    # Flatten the list of lists to get the final count bits
    count_bits = []
    for bits in bits_by_weight:
        if bits:
            count_bits.append(bits[0])
        else:
            # If a weight level has no bits, it's equivalent to a 0
            count_bits.append(ast.Constant(value=False))
    
    # We need at most ceil(log2(n+1)) bits to represent the count
    num_count_bits = math.ceil(math.log2(n + 1))
    
    # Pad the count bits with False if the tree produced fewer bits
    while len(count_bits) < num_count_bits:
        count_bits.append(ast.Constant(value=False))

    # `k` in binary, padded to the same width
    k_binary = format(k, f'0{num_count_bits}b')

    # Initial state (for bits beyond MSB)
    prev_g = ast.Constant(value=False)
    prev_e = ast.Constant(value=True)

    for i in range(num_count_bits - 1, -1, -1):
        count_bit = count_bits[i]
        k_bit_is_one = (k_binary[num_count_bits - 1 - i] == '1')
        
        g_i_name = f"g_{i}"
        e_i_name = f"e_{i}"
        
        if k_bit_is_one:
            e_i_expr = _make_and(prev_e, count_bit)
        else:
            # CORRECTED: Use ast.Invert() for bitwise NOT (~)
            not_count_bit = ast.UnaryOp(op=ast.Invert(), operand=count_bit)
            e_i_expr = _make_and(prev_e, not_count_bit)
        statements.append(_make_assignment(e_i_name, e_i_expr))
        
        if not k_bit_is_one:
            and_term = _make_and(prev_e, count_bit)
            g_i_expr = _make_or(prev_g, and_term)
        else:
            g_i_expr = prev_g
        statements.append(_make_assignment(g_i_name, g_i_expr))
        
        prev_g = _make_load_name(g_i_name)
        prev_e = _make_load_name(e_i_name)

    at_least_k_expr = _make_or(prev_g, prev_e)
    statements.append(_make_assignment("at_least_k", at_least_k_expr))

    return statements

In [None]:
# --- Example Usage ---
# Demo: Generate code to check if at least 2 of 4 inputs are true
inputs = ['a', 'b', 'c', 'd']
k_val = 2

generated_ast_nodes = generate_adder_tree_at_least_k(inputs, k_val)
for statement in generated_ast_nodes:
    print(astunparse.unparse(statement))

# To demonstrate, we wrap the generated AST in a function
# and print the equivalent Python code.

# Create input parameters for the function
func_args = ast.arguments(
    args=[ast.arg(arg=name, annotation=None) for name in inputs],
    defaults=[],
    posonlyargs=[],
    kwonlyargs=[],
    kw_defaults=[]
)

# Create the function definition
func_def = ast.FunctionDef(
    name='check_at_least_2_of_4',
    args=func_args,
    body=generated_ast_nodes + [ast.Return(value=_make_load_name("at_least_k"))],
    decorator_list=[],
)

module = ast.Module(body=[func_def], type_ignores=[])

# Pretty-print the generated Python code
try:
    import astor
    print("--- Generated Python Code ---")
    print(astor.to_source(module))
except ImportError:
    print("Install 'astor' (`pip install astor`) to see the generated code.")

# You can also execute this code to test it
exec(compile(module, filename="<ast>", mode="exec"))




s_0_0_0 = ((a ^ b) ^ c)


c_0_0_0 = (((a & b) | (b & c)) | (a & c))


s_1_0_0 = (s_0_0_0 ^ d)


c_1_0_0 = (s_0_0_0 & d)


s_1_1_0 = (c_0_0_0 ^ c_1_0_0)


c_1_1_0 = (c_0_0_0 & c_1_0_0)


e_2 = (True & (~ c_1_1_0))


g_2 = (False | (True & c_1_1_0))


e_1 = (e_2 & s_1_1_0)


g_1 = g_2


e_0 = (e_1 & (~ s_1_0_0))


g_0 = (g_1 | (e_1 & s_1_0_0))


at_least_k = (g_0 | e_0)

--- Generated Python Code ---
def check_at_least_2_of_4(a, b, c, d):
    s_0_0_0 = a ^ b ^ c
    c_0_0_0 = a & b | b & c | a & c
    s_1_0_0 = s_0_0_0 ^ d
    c_1_0_0 = s_0_0_0 & d
    s_1_1_0 = c_0_0_0 ^ c_1_0_0
    c_1_1_0 = c_0_0_0 & c_1_0_0
    e_2 = True & ~c_1_1_0
    g_2 = False | True & c_1_1_0
    e_1 = e_2 & s_1_1_0
    g_1 = g_2
    e_0 = e_1 & ~s_1_0_0
    g_0 = g_1 | e_1 & s_1_0_0
    at_least_k = g_0 | e_0
    return at_least_k



TypeError: required field "lineno" missing from stmt

In [None]:
def check_at_least_2_of_4(a, b, c, d):
    s_0_0_0 = a ^ b ^ c
    c_0_0_0 = a & b | b & c | a & c
    s_1_0_0 = s_0_0_0 ^ d
    c_1_0_0 = s_0_0_0 & d
    s_1_1_0 = c_0_0_0 ^ c_1_0_0
    c_1_1_0 = c_0_0_0 & c_1_0_0
    e_2 = True & ~c_1_1_0
    g_2 = False | True & c_1_1_0
    e_1 = e_2 & s_1_1_0
    g_1 = g_2
    e_0 = e_1 & ~s_1_0_0
    g_0 = g_1 | e_1 & s_1_0_0
    at_least_k = g_0 | e_0
    return at_least_k

# Test the generated function
for a in [0, 1]:
    for b in [0, 1]:
        for c in [0, 1]:
            for d in [0, 1]:
                print(f"a={a}, b={b}, c={c}, d={d} -> at_least_2? {check_at_least_2_of_4(a, b, c, d)}")





a=0, b=0, c=0, d=0 -> at_least_2? 0
a=0, b=0, c=0, d=1 -> at_least_2? 0
a=0, b=0, c=1, d=0 -> at_least_2? 0
a=0, b=0, c=1, d=1 -> at_least_2? 1
a=0, b=1, c=0, d=0 -> at_least_2? 0
a=0, b=1, c=0, d=1 -> at_least_2? 1
a=0, b=1, c=1, d=0 -> at_least_2? 1
a=0, b=1, c=1, d=1 -> at_least_2? 1
a=1, b=0, c=0, d=0 -> at_least_2? 0
a=1, b=0, c=0, d=1 -> at_least_2? 1
a=1, b=0, c=1, d=0 -> at_least_2? 1
a=1, b=0, c=1, d=1 -> at_least_2? 1
a=1, b=1, c=0, d=0 -> at_least_2? 1
a=1, b=1, c=0, d=1 -> at_least_2? 1
a=1, b=1, c=1, d=0 -> at_least_2? 1
a=1, b=1, c=1, d=1 -> at_least_2? 1
