# In [299]:
from enum import Enum, auto
from pydantic import BaseModel
from typing import Dict, Any, Optional
from typing_extensions import Protocol, runtime_checkable
from result import Result, Ok, Err
import ast


# In [300]:
from typing import Callable
import warnings


class UnboundVariableFinder(ast.NodeVisitor):
    """Collect the names an expression reads without binding them.

    While visiting, every ``Name`` loaded (``ast.Load``) that has not already
    been bound is recorded. Names are bound by lambda parameters,
    comprehension targets, walrus targets, and (separately, in
    :attr:`imports`) by ``import`` statements. :attr:`unbound` then removes
    imported names and Python builtins.

    Bindings are tracked in a single flat set rather than per scope: once a
    name is bound anywhere earlier in traversal order it counts as bound for
    the rest of the tree.

    Assignment statements and class definitions are rejected with
    ``ValueError``.
    """

    _imports: set[str]
    _assigned: set[str]
    _unbound: set[str]
    # Target of the most recently visited walrus expression (any depth).
    _target: Optional[str] = None

    def __init__(self):
        self._imports = set()
        self._assigned = set()
        self._unbound = set()

    def visit_Assign(self, node):
        raise ValueError("Assignment is not allowed")

    # `x += 1` and `x: int = 1` are assignments too; without this their Store
    # targets would silently pass through and never be reported.
    visit_AugAssign = visit_Assign
    visit_AnnAssign = visit_Assign

    def visit_ClassDef(self, node):
        raise ValueError("Class definition is not allowed")

    def visit_NamedExpr(self, node):
        is_name = isinstance(node.target, ast.Name)
        if is_name:
            self._target = node.target.id
        # Visit the value before binding the target, so a self-reference such
        # as `(x := x + 1)` is still reported as unbound.
        self.generic_visit(node)
        if is_name:
            self._assigned.add(node.target.id)

    def visit_Lambda(self, node):
        args = node.args
        # Positional-only parameters (`lambda x, /: x`) bind names as well.
        params = [*args.posonlyargs, *args.args, *args.kwonlyargs]
        if args.vararg:
            params.append(args.vararg)
        if args.kwarg:
            params.append(args.kwarg)
        self._assigned.update(param.arg for param in params)
        self.generic_visit(node)

    def visit_Name(self, node):
        if isinstance(node.ctx, ast.Load) and node.id not in self._assigned:
            self._unbound.add(node.id)
        self.generic_visit(node)

    def visit_GeneratorExp(self, node):
        # Bind every loop target before the element expression is visited.
        # Walking the target handles nested tuples and starred names
        # (`for a, (b, *c) in ...`); only Store-context names are bindings, so
        # names read inside `for obj.attr in ...` targets are not bound.
        for gen in node.generators:
            for sub in ast.walk(gen.target):
                if isinstance(sub, ast.Name) and isinstance(sub.ctx, ast.Store):
                    self._assigned.add(sub.id)
        self.generic_visit(node)

    # List/set/dict comprehensions bind their targets exactly like generators.
    visit_ListComp = visit_GeneratorExp
    visit_SetComp = visit_GeneratorExp
    visit_DictComp = visit_GeneratorExp

    def visit_Import(self, node):
        for alias in node.names:
            # `import a.b.c` binds only the top-level name `a`.
            self._imports.add(alias.asname or alias.name.partition(".")[0])
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        for alias in node.names:
            self._imports.add(alias.asname or alias.name)
        self.generic_visit(node)

    @property
    def imports(self):
        """Names bound by the visited import statements."""
        return self._imports

    @property
    def builtin(self):
        """Names provided by the ``builtins`` module."""
        # `__builtins__` is the builtins module only in `__main__`; inside an
        # imported module it is a plain dict, whose dir() lists dict methods.
        import builtins
        return set(dir(builtins))

    @property
    def unbound(self):
        """Loaded names that are neither bound, imported, nor builtins."""
        return self._unbound - self._imports - self.builtin

class ImportValidator(ast.NodeVisitor):
    """Check that a parsed module contains nothing but import statements.

    ``visit`` returns normally when every statement is an ``import`` or a
    ``from ... import``. Any other node reached during traversal (anything
    that is not the ``Module`` itself or an ``alias``) raises ``ValueError``
    whose message carries an indented ``ast.dump`` of that node.
    """

    def visit_Import(self, node):
        # Accepted as-is; the statement's children are intentionally skipped.
        return None

    visit_ImportFrom = visit_Import

    def generic_visit(self, node):
        permitted = isinstance(node, (ast.Module, ast.alias))
        if not permitted:
            raise ValueError(
                "Invalid import statement " + ast.dump(node, indent=2))
        super().generic_visit(node)

class LazyExpr:
    """A Python expression compiled into a function that is evaluated on demand.

    ``raw`` is parsed as a module whose *last* statement must be an
    expression. That expression becomes the return value of a generated
    zero-argument function named :attr:`MAGIC_FN_NAME`. Any earlier statements
    in ``raw`` are discarded, with a warning. ``imports`` is an optional list
    of import statements placed ahead of the generated function.

    SECURITY: :meth:`eval` (and therefore calling the object) executes the
    code with ``exec``. Only build a ``LazyExpr`` from trusted strings.
    """

    MAGIC_FN_NAME = "lazy_expr"
    _raw: str
    _ast: ast.Module
    _imports: list[str]
    _finder: UnboundVariableFinder
    _target: Optional[str]

    def __init__(self, raw: str, imports: Optional[list[str]] = None):
        """Parse ``raw`` and ``imports`` and build the wrapper module.

        Raises:
            SyntaxError: if ``raw`` or the joined import lines do not parse.
            ValueError: if an import line is not an import statement, ``raw``
                has no statements, or its last statement is not an expression.
        """
        self._raw = raw
        self._imports = imports if imports else []
        preload_ast = ast.parse("\n".join(self._imports))
        # The preload section may contain import statements only.
        ImportValidator().visit(preload_ast)

        raw_ast = ast.parse(raw)
        if not raw_ast.body:
            raise ValueError("Empty AST body")
        last = raw_ast.body[-1]
        if not isinstance(last, ast.Expr):
            raise ValueError("Last statement is not an expression")
        if len(raw_ast.body) > 1:
            warnings.warn(
                "Multiple expressions in the body. Only the last one will be evaluated"
            )

        expr = last.value
        # Only a top-level `(name := ...)` names the result. A walrus nested
        # inside the expression, such as `f(y := 1)` or the inner one in
        # `(a := (b := 1))`, does not.
        self._target = (
            expr.target.id
            if isinstance(expr, ast.NamedExpr) and isinstance(expr.target, ast.Name)
            else None
        )

        func = ast.FunctionDef(name=self.MAGIC_FN_NAME,
                               args=ast.arguments(args=[],
                                                  vararg=None,
                                                  kwarg=None,
                                                  kwonlyargs=[],
                                                  kw_defaults=[],
                                                  posonlyargs=[],
                                                  defaults=[]),
                               body=[ast.Return(value=expr)],
                               decorator_list=[])
        preload_ast.body.append(func)
        self._ast = preload_ast
        # The synthetic FunctionDef/Return nodes have no source positions.
        ast.fix_missing_locations(self._ast)
        self._finder = UnboundVariableFinder()
        self._finder.visit(self._ast)

    @property
    def unbound(self):
        """
        Returns the set of unbound variables in the function
        """
        return self._finder.unbound

    @property
    def imports(self):
        """
        Returns the set of imported variables in the function
        """
        return self._finder.imports

    @property
    def target(self):
        """
        If the expression is a named expression (defined with walrus operator `:=`),
        returns the name of the variable; otherwise returns None.
        """
        return self._target

    def eval(self, env: Optional[Dict[str, Any]] = None):
        """Run the imports, then evaluate the expression and return its value.

        ``env`` supplies extra global names for the expression. These take
        precedence over names bound by the import statements.
        """
        compiled = compile(self._ast, "<string>", "exec")
        namespace: Dict[str, Any] = {}
        # Executes trusted code only (see class docstring).
        # pylint: disable-next=exec-used
        exec(compiled, namespace)
        # Fetch the function before merging `env` so that an `env` key equal
        # to MAGIC_FN_NAME cannot replace it. The function's globals are
        # `namespace`, so names added afterwards are still visible at call time.
        fn = namespace[self.MAGIC_FN_NAME]
        if env:
            namespace.update(env)
        return fn()

    # https://peps.python.org/pep-3102/
    def __call__(self, *args, env: Optional[Dict[str, Any]] = None, **kwargs):
        """
        Evaluates the LazyExpr as a function: the expression's value is called
        with ``*args``/``**kwargs``. ``env`` is keyword-only and passed to
        :meth:`eval`.

        Raises:
            TypeError: if the expression's value is not callable.
        """
        fn = self.eval(env)
        if not callable(fn):
            raise TypeError(
                f"Not a callable function. Actual type {type(fn)} ({fn})")
        return fn(*args, **kwargs)


# In [302]:
# Demo: evaluate a bare walrus expression and inspect what LazyExpr reports.
# Other inputs worth trying in place of the ones below:
#   imp = ["from datetime import datetime", "from typing import List"]
#   s = "lambda x: all(isinstance(i, datetime) for i in x) if isinstance(x, list) else False"
#   s = "1+1"
imp = []
s = "(aaaaa := 1)"

f = LazyExpr(s, imp)
# To inspect the generated module: print(ast.dump(f._ast, indent=2))

print(f.unbound)  # names the expression reads that nothing binds
print(f.target)   # walrus target name, if any

# `env` injects extra globals; this particular expression never reads `a`.
f.eval(env={"a": 1024})

# Output:
# set()
# aaaaa
#
#
# 1