Skip to content

Conversation

@jfeser
Copy link
Contributor

@jfeser jfeser commented Feb 12, 2025

Adds defdata, defterm implementations for iterables. These are aimed at supporting comprehension expressions rather than arbitrary iterables.

>>> x = defop(list[int], name='x')
>>> gen = ((i, j) for i in x() for j in x() if i > 0)
>>> print(str(defterm(gen)))
filter_map(('int()', 'int!1()'), gt(int(), 0), (_BaseOperation(<function _.<locals>.func at 0x11e5bc5e0>, name=int, freshening=[1164120249]), _BaseOperation(<function _.<locals>.func at 0x11e5bc700>, name=int, freshening=[1685202287])), ('x()', 'x()'))

@eb8680
Copy link
Contributor

eb8680 commented Feb 13, 2025

I think we might want to narrow the scope to generator expressions (i.e. things of type generator/collections.abc.Generator), which are more introspectable. Analyzing generator expressions involves some gross Python internal stuff but after some initial exploration it seems ChatGPT should be able to write most of that for us.

@eb8680
Copy link
Contributor

eb8680 commented Feb 13, 2025

There are a number of more full-featured Python bytecode decompilers but the more of the language they try to cover the less portable across versions they are. Here's the implementation ChatGPT suggested for a nested generator expression decompiler (a much simpler problem), which I haven't tested but seems plausible:

import dis
import types

def extract_expression_from_bytecode(code_obj):
    """
    Reconstructs the generator expression from bytecode, including:
    - Function calls
    - Operations (e.g., +, -, *, /)
    - If-condition filters
    """
    instructions = list(dis.get_instructions(code_obj))
    expr_parts = []
    filter_parts = []
    in_condition = False  # Tracks if we're inside an 'if' condition

    for instr in instructions:
        if instr.opname in {"LOAD_FAST", "LOAD_NAME"}:
            expr_parts.append(instr.argval)  # Variable names
        elif instr.opname == "LOAD_CONST":
            expr_parts.append(repr(instr.argval))  # Constants
        elif instr.opname == "CALL_FUNCTION":
            # Function calls (e.g., f(y))
            arg_count = instr.arg
            args = ", ".join(expr_parts[-arg_count:])
            func_name = expr_parts[-arg_count - 1]
            expr_parts = expr_parts[:-arg_count - 1]  # Remove used parts
            expr_parts.append(f"{func_name}({args})")
        elif instr.opname.startswith("BINARY_"):
            # Handle binary operations
            right = expr_parts.pop()
            left = expr_parts.pop()
            op_symbol = instr.opname.replace("BINARY_", "").lower()
            expr_parts.append(f"({left} {op_symbol} {right})")
        elif instr.opname in {"POP_JUMP_IF_FALSE", "POP_JUMP_FORWARD_IF_FALSE"}:
            # Start of a conditional filter (if-clause)
            in_condition = True
        elif in_condition:
            # Capture the condition (used after "if")
            if instr.opname in {"LOAD_FAST", "LOAD_NAME", "LOAD_CONST"}:
                filter_parts.append(instr.argval)
            elif instr.opname.startswith("BINARY_"):
                # Handle operations in conditions
                right = filter_parts.pop()
                left = filter_parts.pop()
                op_symbol = instr.opname.replace("BINARY_", "").lower()
                filter_parts.append(f"({left} {op_symbol} {right})")
            elif instr.opname == "COMPARE_OP":
                # Handle comparisons (e.g., x % 2 == 0)
                right = filter_parts.pop()
                left = filter_parts.pop()
                op = instr.argval
                filter_parts.append(f"({left} {op} {right})")
            elif instr.opname == "JUMP_FORWARD":
                # End of the condition
                in_condition = False
    
    # Final expression
    expr = " ".join(expr_parts) if expr_parts else "<unknown_expr>"
    filter_expr = " and ".join(filter_parts) if filter_parts else ""

    return f"{expr} if {filter_expr}" if filter_expr else expr

def reconstruct_generator_expression(gen, depth=0):
    """Recursively reconstructs generator expressions, including function calls and conditions."""
    indent = "  " * depth  # Indentation for nesting
    code = gen.gi_code

    # Extract the generator expression
    expr = extract_expression_from_bytecode(code)

    # Get variable names used in iteration
    iter_vars = ", ".join(code.co_varnames)

    # Display the reconstructed generator expression
    reconstructed = f"{indent}({expr} for {iter_vars})"
    print(reconstructed)

    # Recursively check for nested generators inside co_consts
    for const in code.co_consts:
        if isinstance(const, types.CodeType):  # If it's another generator code object
            print(f"{indent}↳ Nested Generator:")
            reconstruct_generator_expression(type((lambda: (yield))()), depth + 1)  # Dummy generator to extract structure

# Example usage with an if-condition:
def f(y): return y

xs = (x for x in range(10) if x % 2 == 0)  # Generator with if-condition
nested_xs = ((f(y) for y in ys if y > 2) for ys in ((y_ + i for y_ in range(k) if y_ != 1) for k in range(1, 4)))

print("Reconstructed Generators:")
reconstruct_generator_expression(xs)
print("\nReconstructed Nested Generators:")
reconstruct_generator_expression(nested_xs)

Expected output it generated:

(x for x in range(10) if x % 2 == 0)
# Reconstructed Generators:
# (x for x in range if (x % 2 == 0))

((f(y) for y in ys if y > 2) 
 for ys in ((y_ + i for y_ in range(k) if y_ != 1) for k in range(1, 4)))
# Reconstructed Nested Generators:
# (f(y) for y in ys if (y > 2))
# ↳ Nested Generator:
#   (y_ + i for y_ in range if (y_ != 1))

@eb8680
Copy link
Contributor

eb8680 commented Feb 13, 2025

We don't want a decompiler, exactly, but we need to do similar introspection to get hold of the generator expression components and recursively defterm them.

class _ComplexTerm(_NumberTerm, Term[numbers.Complex]):
def __bool__(self) -> bool:
raise ValueError("Cannot convert term to bool")
return bool_(self)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overloading __bool__ on terms may not be precise enough to implement the behavior we want.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It definitely isn't as precise as we would like. Bool conversion inside terms won't behave correctly.

return _TensorTerm(op, *args, **kwargs)


@defterm.register(torch.Tensor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Iterable is so broad that we'd probably need to implement many overrides like this. Generator is narrower by comparison but still covers the syntax we want.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, agreed.

@jfeser
Copy link
Contributor Author

jfeser commented Feb 13, 2025

The current code isn't totally unreasonable on nested generators, but it doesn't preserve the nesting. For example:

>>> gen = ((i, j) for i in x() for j in x() if (i > 0))
>>> gen3 = ((j, k) for j in gen for k in gen)
>>> str(defterm(gen3))
"filter_map((('int()', 'int!1()'), ('int()', 'int!2()')), gt(int(), 0), (_BaseOperation(<function _.<locals>.func at 0x1362c2950>, name=int, freshening=[663389772]), _BaseOperation(<function _.<locals>.func at 0x1362c1b40>, name=int, freshening=[3360915926]), _BaseOperation(<function _.<locals>.func at 0x1362c1090>, name=int, freshening=[1660047089])), ('x()', 'x()', 'x()'))"

@jfeser
Copy link
Contributor Author

jfeser commented Feb 13, 2025

Here's the output from the gpt code:

>>> reconstruct_generator_expression(gen)
(.0(.0) i 0 i j None for .0, i, j)
>>> reconstruct_generator_expression(gen3)
(.0 j k None for .0, j, k)

The approach might work, but we'll need reasonably full-featured expression reconstruction.

@jfeser
Copy link
Contributor Author

jfeser commented Feb 19, 2025

This version is more precise, but requires wrappers for the iterables and guard:

For example:

x = defop(list[int], name='x')

gen = (i for i in _gen(x()))

gen1 = (i + 1 for i in _gen(x()))

gen4 = ((i, j) for i in _gen(x()) for j in _gen([1,2,3]) if _guard(i > 0))

gen2 = (j for i in _gen(x()) for j in _gen(y(i)))

gen3 = ((j, k) for j in _gen(gen) for k in _gen(gen))

Unfortunately, I don't think the _gen wrappers can be eliminated in general, even with bytecode manipulation. The _gen wrapper's purpose is to ensure that we can capture the inner iterables, not their iterators, because iterators are not introspectable. However, iterables are converted to iterators when a generator expression is evaluated, so even if we modify the bytecode of a generator, we will not be able to see these iterables unless we capture them before the generator expression is evaluated.

It is probably possible to remove the _guard wrapper with bytecode rewriting. However, the guard expression itself needs to be written using logical (&, ~, |) operations because the boolean operations are hardcoded to only work on booleans. Bytecode rewriting could help here, but we would need to reconstruct the boolean operations from the resulting short-circuit control flow.

@eb8680
Copy link
Contributor

eb8680 commented Jun 25, 2025

Closing in favor of #288

@eb8680 eb8680 closed this Jun 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants