Skip to content

Commit

Permalink
[Dev] Allow top-level await in code statements (#3508)
Browse files Browse the repository at this point in the history
* [dev] allow top-level await in code statements

* style

* use staticmethod, cls is unneeded

* add asyncio and aiohttp to env

* fix repl

* add __builtins__ to repl env

* style...

* fix debug with no coro

* add `optimize=0` to eval
  • Loading branch information
Zephyrkul committed Feb 13, 2020
1 parent cc30726 commit 42a2327
Showing 1 changed file with 40 additions and 18 deletions.
58 changes: 40 additions & 18 deletions redbot/core/dev_commands.py
@@ -1,8 +1,11 @@
import ast
import asyncio
import aiohttp
import inspect
import io
import textwrap
import traceback
import types
import re
from contextlib import redirect_stdout
from copy import copy
Expand Down Expand Up @@ -35,6 +38,19 @@ def __init__(self):
self._last_result = None
self.sessions = set()

@staticmethod
def async_compile(source, filename, mode):
return compile(source, filename, mode, flags=ast.PyCF_ALLOW_TOP_LEVEL_AWAIT, optimize=0)

@staticmethod
async def maybe_await(coro):
for i in range(2):
if inspect.isawaitable(coro):
coro = await coro
else:
return coro
return coro

@staticmethod
def cleanup_code(content):
"""Automatically removes code blocks from the code."""
Expand All @@ -53,7 +69,9 @@ def get_syntax_error(e):
"""
if e.text is None:
return box("{0.__class__.__name__}: {0}".format(e), lang="py")
return box("{0.text}{1:>{0.offset}}\n{2}: {0}".format(e, "^", type(e).__name__), lang="py")
return box(
"{0.text}\n{1:>{0.offset}}\n{2}: {0}".format(e, "^", type(e).__name__), lang="py"
)

@staticmethod
def get_pages(msg: str):
Expand All @@ -75,8 +93,8 @@ async def debug(self, ctx, *, code):
If the return value of the code is a coroutine, it will be awaited,
and the result of that will be the bot's response.
Note: Only one statement may be evaluated. Using await, yield or
similar restricted keywords will result in a syntax error. For multiple
Note: Only one statement may be evaluated. Using certain restricted
keywords, e.g. yield, will result in a syntax error. For multiple
lines or asynchronous code, see [p]repl or [p]eval.
Environment Variables:
Expand All @@ -96,6 +114,8 @@ async def debug(self, ctx, *, code):
"author": ctx.author,
"guild": ctx.guild,
"message": ctx.message,
"asyncio": asyncio,
"aiohttp": aiohttp,
"discord": discord,
"commands": commands,
"_": self._last_result,
Expand All @@ -104,17 +124,15 @@ async def debug(self, ctx, *, code):
code = self.cleanup_code(code)

try:
result = eval(code, env)
compiled = self.async_compile(code, "<string>", "eval")
result = await self.maybe_await(eval(compiled, env))
except SyntaxError as e:
await ctx.send(self.get_syntax_error(e))
return
except Exception as e:
await ctx.send(box("{}: {!s}".format(type(e).__name__, e), lang="py"))
return

if inspect.isawaitable(result):
result = await result

self._last_result = result
result = self.sanitize_output(ctx, str(result))

Expand Down Expand Up @@ -149,6 +167,8 @@ async def _eval(self, ctx, *, body: str):
"author": ctx.author,
"guild": ctx.guild,
"message": ctx.message,
"asyncio": asyncio,
"aiohttp": aiohttp,
"discord": discord,
"commands": commands,
"_": self._last_result,
Expand All @@ -160,7 +180,8 @@ async def _eval(self, ctx, *, body: str):
to_compile = "async def func():\n%s" % textwrap.indent(body, " ")

try:
exec(to_compile, env)
compiled = self.async_compile(to_compile, "<string>", "exec")
exec(compiled, env)
except SyntaxError as e:
return await ctx.send(self.get_syntax_error(e))

Expand Down Expand Up @@ -192,9 +213,6 @@ async def repl(self, ctx):
The REPL will only recognise code as messages which start with a
backtick. This includes codeblocks, and as such multiple lines can be
evaluated.
You may not await any code in this REPL unless you define it inside an
async function.
"""
variables = {
"ctx": ctx,
Expand All @@ -203,7 +221,9 @@ async def repl(self, ctx):
"guild": ctx.guild,
"channel": ctx.channel,
"author": ctx.author,
"asyncio": asyncio,
"_": None,
"__builtins__": __builtins__,
}

if ctx.channel.id in self.sessions:
Expand All @@ -225,19 +245,19 @@ async def repl(self, ctx):
self.sessions.remove(ctx.channel.id)
return

executor = exec
executor = None
if cleaned.count("\n") == 0:
# single statement, potentially 'eval'
try:
code = compile(cleaned, "<repl session>", "eval")
code = self.async_compile(cleaned, "<repl session>", "eval")
except SyntaxError:
pass
else:
executor = eval

if executor is exec:
if executor is None:
try:
code = compile(cleaned, "<repl session>", "exec")
code = self.async_compile(cleaned, "<repl session>", "exec")
except SyntaxError as e:
await ctx.send(self.get_syntax_error(e))
continue
Expand All @@ -250,9 +270,11 @@ async def repl(self, ctx):

try:
with redirect_stdout(stdout):
result = executor(code, variables)
if inspect.isawaitable(result):
result = await result
if executor is None:
result = types.FunctionType(code, variables)()
else:
result = executor(code, variables)
result = await self.maybe_await(result)
except:
value = stdout.getvalue()
msg = "{}{}".format(value, traceback.format_exc())
Expand Down

0 comments on commit 42a2327

Please sign in to comment.