In [None]:
#| default_exp core

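# safepython

> Run LLM-written Python inside a RestrictedPython sandbox, with an allow-list that controls which callables may be used.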

In [None]:
#| export
from fastcore.utils import *
from fastcore.xtras import asdict
from inspect import currentframe,Parameter,signature

import json,importlib,linecache,re,inspect,uuid,ast,warnings,collections,time,asyncio,urllib.parse,dataclasses,shlex,urllib
import zlib,unicodedata,binascii,enum,secrets,pickle,contextlib,types,keyword,httpx
import heapq, bisect, html, struct, decimal, fractions, pprint, fnmatch, base64
import random, statistics, difflib, csv, string, textwrap, hashlib, copy, datetime as dt_mod
import xml.etree.ElementTree as ET,ipaddress,colorsys,cmath,traceback,sys
from datetime import datetime
from urllib.parse import quote,unquote,urlencode
from io import StringIO,BytesIO
from collections import Counter,deque

In [None]:
#| export
from fastcore.imports import __llmtools__
from RestrictedPython import utility_builtins, safe_builtins,limited_builtins
from RestrictedPython.transformer import RestrictingNodeTransformer, INSPECT_ATTRIBUTES, copy_locations
from restrictedpython_async import *

In [None]:
#| export
def _find_frame_dict(sentinel:str):
    "Find the globals dict containing sentinel, or calling frame's globals if no sentinel"
    frame = currentframe().f_back.f_back
    if not sentinel: return frame.f_globals
    while frame:
        if sentinel in frame.f_globals: return frame.f_globals
        frame = frame.f_back
    raise ValueError(f"Could not find {sentinel} in any scope")

In [None]:
#| export
# Registry of names (plain, or `Owner.member`) that sandboxed code is permitted to call.
# NB: this must be a set literal; `set('pyrun')` would split the string into single characters.
__pytools__ = {'pyrun'}

def allow(*c):
    """Register callables that sandboxed code may use.

    Each argument is either:
    - a dict mapping an object (e.g. a module or class) to an iterable of member names, which
      registers the object's `__name__` plus one `"{__name__}.{member}"` entry per member; or
    - any other hashable value (normally a name string), which is added to the registry as-is.
    """
    for o in c:
        if isinstance(o, dict):
            for k,v in o.items():
                __pytools__.add(k.__name__)
                __pytools__.update(f'{k.__name__}.{m}' for m in v)
        else: __pytools__.add(o)

In [None]:
#| export
# Builtins handed to sandboxed code: the RestrictedPython builtin sets, `async_builtins`, the plain
# container constructors, and the real `__import__` so that `import` statements work.
all_builtins = (
    safe_builtins
    | utility_builtins
    | limited_builtins
    | async_builtins
    | {'dict': dict, 'list': list, 'set': set, 'tuple': tuple, 'frozenset': frozenset,
       '__import__': __import__}
)

In [None]:
#| export
def _safe_getattr(obj, name):
    "Attribute getter for sandboxed code: non-callable attrs pass through, callable attrs must be registered in `__llmtools__` or `__pytools__`"
    val = getattr(obj, name)
    if not callable(val): return val
    mro = type(obj).__mro__
    # Candidate registry keys: `Class.attr` and `module.QualName.attr` for every class in the MRO...
    keys = [f"{c.__name__}.{name}" for c in mro]
    keys.extend(f"{c.__module__}.{c.__qualname__}.{name}" for c in mro if hasattr(c, '__module__'))
    # ...plus `obj.__name__.attr`, for objects that carry their own `__name__` (e.g. modules, classes)
    owner = getattr(obj, '__name__', None)
    if owner: keys.append(f"{owner}.{name}")
    permitted = __llmtools__|__pytools__
    for k in keys:
        if k in permitted: return val
    raise AttributeError(f"Cannot access callable: {name}")

In [None]:
#| export
class _DirectPrint:
    def __init__(self, *a, **kw): pass
    def _call_print(self, *a, **kw): print(*a, **kw)
    def __call__(self, *a, **kw): print(*a, **kw)

In [None]:
#| export
class _Uncallable:
    def __init__(self, o, name):
        functools.update_wrapper(self, o)
        self._o,self._name = o,name
    def __call__(self, *a, **kw): raise PermissionError(f"Calling `{self._name}` is not permitted")
    def __getattr__(self, name): return getattr(self._o, name)
    def __repr__(self): return repr(self._o)

def _callable_ok(k, v, _ok):
    if k.endswith('_') or k in _ok: return True
    mod,qn = getattr(v, '__module__', None), getattr(v, '__qualname__', None)
    return bool(mod and qn and f"{mod}.{qn}" in _ok)

In [None]:
#| export
# Underscore-prefixed attribute names that `SafeTransformer.visit_Attribute` does not reject
ALLOWED_DUNDERS = {'__name__', '__module__', '__doc__', '__qualname__', '__file__'}

class SafeTransformer(RestrictingNodeTransformer):
    "Policy passed to `compile_restricted`; overrides attribute handling so that names in `ALLOWED_DUNDERS` are accepted"
    def visit_Attribute(self, node):
        "Check the attribute name, then route loads through `_getattr_` and stores/deletes through `_write_`"
        # Reject `_`-prefixed names, except the bare `_` and the allow-listed dunders.
        # NOTE(review): processing continues after `self.error(...)`, so it presumably records the
        # problem rather than raising -- confirm against RestrictedPython.
        if node.attr.startswith('_') and node.attr != '_' and node.attr not in ALLOWED_DUNDERS:
            self.error(node, f'"{node.attr}" is an invalid attribute name because it starts with "_".')
        if node.attr.endswith('__roles__'):
            self.error(node, f'"{node.attr}" is an invalid attribute name because it ends with "__roles__".')
        if node.attr in INSPECT_ATTRIBUTES:
            self.error(node, f'"{node.attr}" is a restricted name, that is forbidden to access in RestrictedPython.')
        if isinstance(node.ctx, ast.Load):
            # Rewrite `obj.attr` into `_getattr_(obj, "attr")` so reads go through the runtime guard
            node = self.node_contents_visit(node)
            new_node = ast.Call(func=ast.Name('_getattr_', ast.Load()), args=[node.value, ast.Constant(node.attr)], keywords=[])
            copy_locations(new_node, node)
            return new_node
        elif isinstance(node.ctx, (ast.Store, ast.Del)):
            # Rewrite `obj.attr = ...` / `del obj.attr` so the target object becomes `_write_(obj)`
            node = self.node_contents_visit(node)
            new_value = ast.Call(func=ast.Name('_write_', ast.Load()), args=[node.value], keywords=[])
            copy_locations(new_value, node.value)
            node.value = new_value
            return node
        else: raise NotImplementedError(f"Unknown ctx type: {type(node.ctx)}")

In [None]:
#| export
async def _run_python(code:str, g=None):
    """Run `code` under the `SafeTransformer` policy and return a dict describing the outcome.

    - `code`: source to execute. If its last statement is an expression, that expression is
      evaluated separately and its value is returned as the result.
    - `g`: namespace whose public (non-`_`-prefixed) entries are exposed to the code. Callables that
      `_callable_ok` rejects are wrapped in `_Uncallable`. Names created by the code that end with
      `_` (and do not start with `_`) are written back into `g`. Defaults to a fresh empty dict.

    Returns a dict holding the non-empty keys among `stdout`, `stderr`, `errors` and `result`, or
    `None` when all are empty. `SyntaxError`/`NameError` raised while compiling or running are
    reported under `errors`; a `SyntaxError` from the initial `ast.parse`, and any other exception,
    propagates to the caller.
    """
    if g is None: g = {}
    _ok = __llmtools__|__pytools__
    tools = {k:(v if not callable(v) or _callable_ok(k,v,_ok) else _Uncallable(v,k))
        for k,v in g.items() if not k.startswith('_')}
    def unpack(a,*args): return list(a)
    # Use a dict display rather than `dict(..., **tools)`: a public name in `g` that collides with
    # one of the fixed keys (e.g. a user global called `max`) would otherwise raise
    # "TypeError: dict() got multiple values for keyword argument". Entries from `g` win, matching
    # Python's usual rule that globals shadow builtins.
    rg = {'__builtins__': all_builtins, '_getattr_': _safe_getattr,
          '_getitem_': lambda o,k: o[k], '_getiter_': iter, '_print_': _DirectPrint, '_print': _DirectPrint(),
          '_unpack_sequence_': unpack, '_iter_unpack_sequence_': unpack,
          'enumerate': enumerate, 'sorted': sorted, 'reversed': reversed, 'max': max, 'min': min, **tools}
    loc,errs = {},[]
    sout, serr = StringIO(), StringIO()
    async def run(src, is_exec=True):
        # Compile and evaluate one chunk; returns None (after recording a message) on a handled error
        try:
            comp = compile_restricted(src, '<tool>', 'exec' if is_exec else 'eval', policy=SafeTransformer)
            res = eval(comp, rg, loc)
            if inspect.iscoroutine(res): res = await res
            return res
        except SyntaxError as e: errs.append(f'SyntaxError: {e}')
        except NameError as e: errs.append(f'`{e.name}` is not available in this sandbox; ask the user to add it to the available tools')
    # Copy `_`-suffixed (but not `_`-prefixed) locals created by the code back into `g`
    def _export(): g.update({k:v for k,v in loc.items() if k.endswith('_') and not k.startswith('_')})
    def _result(res=None):
        _export()
        d = {}
        if (out := sout.getvalue()): d['stdout'] = out
        if (err := serr.getvalue()): d['stderr'] = err
        if errs: d['errors'] = '\n'.join(errs)
        if res is not None: d['result'] = res
        return d or None
    tree = ast.parse(code)
    with contextlib.redirect_stdout(sout), contextlib.redirect_stderr(serr), warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=SyntaxWarning)
        if tree.body and isinstance(tree.body[-1], ast.Expr):
            # Split off a trailing expression so its value can be returned
            last = tree.body.pop()
            if tree.body:
                await run(ast.unparse(ast.Module(tree.body, [])))
                # Don't evaluate the final expression if the preceding statements failed
                if errs: return _result()
            res = await run(ast.unparse(ast.Expression(last.value)), False)
            return _result(res)
        await run(code)
        return _result()

In [None]:
#| export
class RunPython:
    # Callable tool object that runs restricted Python against the namespace `self.g`.
    # (Comments rather than a class docstring: `__doc__` is defined as a property below.)
    def __init__(self, g=None, sentinel=None):
        # Only fall back to frame lookup when no namespace was supplied: an explicitly passed empty
        # dict is a valid namespace and must be used as-is (a truthiness test would discard it).
        # NB: `_find_frame_dict` counts frames, so it has to be called directly from `__init__`.
        if g is None: g = _find_frame_dict(sentinel)
        self.g = g

    @property
    def __doc__(self):
        # Built on access so the listed symbols reflect the registries' current contents
        tools = ', '.join(sorted(__llmtools__|__pytools__))
        return f"""Execute restricted Python with access to LLM tools, returning last expression.
            `import` works in the usual way. All non-callable globals and non-callable attrs are usable.
            Callable globals are also usable if their name ends with `_` (but not `_`-prefixed).
            - This is an easy way for users to expose extra functions: `def my_helper_(...)`
            Callable object attrs are only accessible if `ClassName.method` is registered as a tool.
            Multiline code blocks can be used, including defining functions and variables, for use within the call.
            In addition most builtins are available, plus these symbols: {tools}

            **NB**: If `code` creates symbols that end with `_`, they will be exported to the calling namespace.
            - This is how you can use symbols that either human or AI can use again later.
            Examples: `len([1,2,3])` (builtin); `add_msg(content="hi")` (tool); `df.shape` (non-callable attr);
            `[x**2 for x in range(5)]` (last expression returned); `sorted(my_dict.items())` (builtin + non-callable attr)"""

    async def __call__(self,
        code:str # Python code to execute, can be multiple lines, include functions, etc
    ): # A dict containing up to 4 keys for non-empty vals: `(stdout=, stderr=, errors=, result=)`
        return await _run_python(code, g=self.g)

In [None]:
# No `g` or `sentinel` given, so the runner binds to the globals of the frame that constructs it
pyrun = RunPython()

In [None]:
# A trailing expression is returned under `result`; only `None` is dropped, so falsy values still appear
await pyrun('[]')

{'result': []}

In [None]:
# Printed output is captured and returned under `stdout`; the `None` result of `print` is omitted
await pyrun("print('tt')")

{'stdout': 'tt\n'}

In [None]:
# `f` is an ordinary global function; registering its name lets sandboxed code call it
def f(): warnings.warn('a warning')
allow('f')
# Statements before the final expression are executed first; the final expression becomes `result`
await pyrun('print("asdf"); f(); 1+1')

{'stdout': 'asdf\n',
 'result': 2}

In [None]:
#| export
def safe_type(o:object):
    "Return the class of `o`, exactly as the builtin `type(o)` does"
    cls = type(o)
    return cls

In [None]:
#| export
# Method names registered for both `StringIO` and `BytesIO` in the `allow` call that follows
_io_meths = ['getvalue', 'read', 'write', 'seek']

In [None]:
#| export
def docs(sym)->str:
    """Return the markdown documentation for `sym` (signature, docstring, and docments when present).
    **NB**: This is not an llm tool, so it has to be called through pyrun(). `sym` must be available in the namespace."""
    renderer = MarkdownRenderer(sym)
    return renderer._repr_markdown_()

In [None]:
#| export
# Core allow-list. Each dict entry registers the owner's `__name__` plus `Owner.member` for every
# listed member; the trailing strings register bare global names.
allow({
    re: ['search', 'findall', 'sub', 'match', 'compile', 'split', 'escape', 'fullmatch', 'subn'],
    json: ['loads', 'dumps', 'load'],
    math: ['sqrt', 'floor', 'ceil', 'log', 'log2', 'log10', 'gcd', 'isnan', 'isinf',
        'exp', 'sin', 'cos', 'tan', 'atan2', 'radians', 'degrees', 'factorial', 'comb', 'perm', 'prod', 'isclose',
        'fsum', 'hypot', 'isfinite', 'copysign'],
    collections: ['Counter', 'defaultdict', 'deque', 'namedtuple', 'OrderedDict', 'ChainMap'],
    tuple: ['index', 'count'],
    float: ['is_integer', 'fromhex'],
    Counter: ['most_common'],
    dict: ['keys', 'values', 'items', 'get', 'update', 'pop', 'setdefault', 'copy'],
    list: ['append', 'copy', 'extend', 'index', 'insert', 'pop', 'remove', 'reverse', 'sort', 'count'],
    set: ['add', 'discard', 'intersection', 'union', 'difference', 'update',
        'symmetric_difference', 'issubset', 'issuperset', 'copy', 'pop', 'remove'],
    str: ['split', 'join', 'replace', 'strip', 'lstrip', 'rstrip', 'startswith', 'endswith', 'lower', 'upper',
        'find', 'count', 'format', 'isdigit', 'isalpha', 'title', 'encode', 'splitlines', 'removeprefix', 'removesuffix',
        'zfill', 'center', 'ljust', 'rjust', 'maketrans', 'translate', 'casefold', 'partition', 'rpartition'],
    bytes: ['decode', 'fromhex', 'hex'],
    int: ['to_bytes', 'from_bytes', 'bit_length'],
    # Filesystem access through `Path` is limited to the read-only style methods listed here
    Path: ['read_text', 'glob', 'iterdir', 'exists', 'read_bytes', 'is_file', 'is_dir', 'stat', 'resolve',
        'with_suffix', 'with_name', 'relative_to', 'match', 'joinpath'],
    asyncio: ['gather'], copy: ['deepcopy'], httpx: ['get', 'options'],
    itertools: ['chain', 'islice', 'groupby', 'product', 'permutations', 'combinations', 'accumulate', 'starmap', 'zip_longest',
        'pairwise', 'takewhile', 'dropwhile', 'filterfalse', 'compress', 'count', 'repeat', 'cycle', 'tee', 'batched'],
    functools: ['reduce', 'partial', 'lru_cache', 'cache', 'wraps', 'cmp_to_key', 'total_ordering'],
    textwrap: ['dedent', 'indent', 'wrap', 'shorten', 'fill'],
    datetime: ['now', 'fromisoformat', 'strftime', 'strptime', 'isoformat'],
    dt_mod: ['timedelta', 'date', 'time', 'timezone'],
    operator: ['itemgetter', 'attrgetter', 'add', 'mul', 'sub', 'truediv', 'neg', 'contains',
        'getitem', 'mod', 'eq', 'ne', 'lt', 'gt', 'or_', 'and_', 'not_', 'pow', 'floordiv', 'xor'],
    frozenset: ['intersection', 'union', 'difference', 'symmetric_difference', 'issubset', 'issuperset', 'copy'],
    StringIO: _io_meths, BytesIO: _io_meths,
    }, 'urlencode', 'quote', 'unquote', 'string', 'safe_type', 'docs'
)

In [None]:
#| export
# Additional standard-library allow-list, registered the same way (owner `__name__` + `Owner.member`).
# `os.path` is keyed by the `__name__` of the module object it refers to.
allow({
    os.path: ['join', 'basename', 'dirname', 'splitext', 'exists', 'isfile', 'isdir', 'abspath',
        'relpath', 'expanduser', 'normpath'],
    base64: ['b64encode', 'b64decode', 'urlsafe_b64encode', 'urlsafe_b64decode'],
    hashlib: ['md5', 'sha256'],
    random: ['choice', 'randint', 'sample', 'shuffle', 'uniform', 'random'],
    statistics: ['mean', 'median', 'stdev'],
    difflib: ['unified_diff', 'ndiff'],
    csv: ['reader', 'DictReader'],
    heapq: ['nlargest', 'nsmallest', 'heappush', 'heappop'],
    bisect: ['bisect_left', 'bisect_right', 'insort'],
    html: ['escape', 'unescape'],
    struct: ['pack', 'unpack'],
    fnmatch: ['fnmatch', 'filter'],
    time: ['time', 'perf_counter'],
    urllib.parse: ['urlparse', 'parse_qs', 'parse_qsl', 'urlunparse', 'urljoin', 'quote_plus', 'unquote_plus'],
    dataclasses: ['dataclass', 'field', 'asdict', 'fields', 'replace', 'is_dataclass'],
    shlex: ['split', 'quote'],
    zlib: ['compress', 'decompress', 'crc32'],
    unicodedata: ['name', 'lookup', 'category', 'normalize'],
    binascii: ['hexlify', 'unhexlify'],
    enum: ['Enum', 'IntEnum'],
    secrets: ['token_hex', 'token_urlsafe'],
    deque: ['appendleft', 'popleft', 'rotate', 'extendleft'],
    ast: ['literal_eval', 'parse', 'dump', 'walk', 'unparse'],
    # NOTE(review): `pickle.loads` can execute arbitrary code when given crafted bytes -- confirm
    # that exposing it to sandboxed code is intended.
    pickle: ['loads', 'dumps'],
    contextlib: ['suppress', 'contextmanager'],
    inspect: ['getsource', 'getsourcefile', 'getsourcelines', 'getmodule', 'getdoc', 'getmembers',
        'signature', 'isclass', 'isfunction', 'ismethod', 'ismodule', 'getfile'],
    keyword: ['iskeyword', 'kwlist'],
    ET: ['fromstring', 'tostring'],
    ET.Element: ['findall', 'find', 'get', 'iter'],
    ipaddress: ['ip_address', 'ip_network'],
    colorsys: ['rgb_to_hsv', 'hsv_to_rgb', 'rgb_to_hls'],
    cmath: ['phase', 'polar', 'rect', 'sqrt'],
    decimal: ['Decimal'], fractions: ['Fraction'],
    uuid: ['uuid4'], pprint: ['pformat'], types: ['SimpleNamespace'],
    traceback: ['format_exc'], sys: ['getsizeof'], warnings: ['warn'],
})

In [None]:
# `dict.items` is registered, so the method call on a dict created inside the sandbox is permitted
await pyrun('''
a = {"b":1}
list(a.items())
''')

{'result': [('b', 1)]}

In [None]:
# `Path` and `Path.exists` are both registered; the method is matched through the instance's MRO
await pyrun('Path().exists()')

{'result': True}

In [None]:
# Functions on a module are matched by the module's `__name__`, which is how the `os.path` entry applies
await pyrun("os.path.join('/foo', 'bar', 'baz.py')")

{'result': '/foo/bar/baz.py'}

In [None]:
# Names ending in `_` that the code creates are copied into the namespace `pyrun` is bound to,
# so `a_` is readable here afterwards
await pyrun('a_=3')
a_

3

In [None]:
# An exported name is also visible to later `pyrun` calls
await pyrun('''aa_='33' ''')
await pyrun('''len(aa_) ''')

{'result': 2}

In [None]:
# Not registered with `allow` and not `_`-suffixed: sandboxed code can reference `g` but not call it
def g(): ...

In [None]:
# A callable that may not be called can still be passed as an argument to an allowed function
await pyrun('inspect.getsource(g)')

{'result': 'def g(): ...\n'}

In [None]:
# Calling a global that is not allowed raises `PermissionError`, which propagates out of `pyrun`
try: await pyrun('g()')
except PermissionError: print("Correct exception raised")
else: raise Exception("No exception")

Correct exception raised


In [None]:
# `re.compile` is registered, so reaching it as a module attribute succeeds
await pyrun('re.compile("a")')

{'result': re.compile(r'a', re.UNICODE)}

In [None]:
from re import compile

In [None]:
# A bare global is also accepted when its `module.qualname` (here `re.compile`) is registered
await pyrun('compile("a")')

{'result': re.compile(r'a', re.UNICODE)}

In [None]:
# `safe_type` is registered by name; the returned value may contain arbitrary objects (here a class)
await pyrun('''
dict(a=safe_type(1))
''')

{'result': {'a': int}}

In [None]:
# Async generators and a top-level `async for` work inside sandboxed code
await pyrun("""
async def agen():
    for x in [1,2]: yield x
res = []
async for x in agen(): res.append(x)
res
""")

{'result': [1, 2]}

In [None]:
# `import` and top-level `await` both work; `string.ascii_letters` is a non-callable attribute,
# so it needs no registration, while `asyncio.gather` is registered
await pyrun('''
import asyncio
async def fetch(n): return n * 10
print(string.ascii_letters)
await asyncio.gather(fetch(1), fetch(2), fetch(3))
''')

{'stdout': 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ\n',
 'result': [10, 20, 30]}

In [None]:
import numpy as np

In [None]:
# Third-party callables are enabled the same way: module functions as `module.func`,
# methods as `module.Class.method`
allow('numpy.array', 'numpy.ndarray.sum')
await pyrun('import numpy as np; np.array([1,2,3]).sum()')

{'result': 6}

## export -

In [None]:
#| hide
from nbdev import nbdev_export
# Regenerate the exported module(s) from the project's `#| export` cells
nbdev_export()