Skip to content

Commit

Permalink
[commands] Inject the internal variables for bot.say & co explicitly.
Browse files Browse the repository at this point in the history
This is to catch cases where it wouldn't fail to find it when
inspecting the stack to catch these stack variables.
  • Loading branch information
Rapptz committed Jan 9, 2016
1 parent ad800e2 commit 0a07fc0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 19 deletions.
29 changes: 16 additions & 13 deletions discord/ext/commands/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@
from .context import Context
from .errors import CommandNotFound

def _get_variable(name):
stack = inspect.stack()
try:
for frames in stack:
current_locals = frames[0].f_locals
if name in current_locals:
return current_locals[name]
finally:
del stack

def when_mentioned(bot, msg):
"""A callable that implements a command prefix equivalent
to being mentioned, e.g. ``@bot ``."""
Expand Down Expand Up @@ -71,13 +81,6 @@ def __init__(self, command_prefix, **options):

# internal helpers

def _get_variable(self, name):
stack = inspect.stack()
for frames in stack:
current_locals = frames[0].f_locals
if name in current_locals:
return current_locals[name]

def _get_prefix(self, message):
prefix = self.command_prefix
if callable(prefix):
Expand Down Expand Up @@ -122,7 +125,7 @@ def say(self, content):
content : str
The content to pass to :class:`Client.send_message`
"""
destination = self._get_variable('_internal_channel')
destination = _get_variable('_internal_channel')
result = yield from self.send_message(destination, content)
return result

Expand All @@ -141,7 +144,7 @@ def whisper(self, content):
content : str
The content to pass to :class:`Client.send_message`
"""
destination = self._get_variable('_internal_author')
destination = _get_variable('_internal_author')
result = yield from self.send_message(destination, content)
return result

Expand All @@ -161,8 +164,8 @@ def reply(self, content):
content : str
The content to pass to :class:`Client.send_message`
"""
author = self._get_variable('_internal_author')
destination = self._get_variable('_internal_channel')
author = _get_variable('_internal_author')
destination = _get_variable('_internal_channel')
fmt = '{0.mention}, {1}'.format(author, str(content))
result = yield from self.send_message(destination, fmt)
return result
Expand All @@ -184,7 +187,7 @@ def upload(self, fp, name=None):
name
The second parameter to pass to :meth:`Client.send_file`
"""
destination = self._get_variable('_internal_channel')
destination = _get_variable('_internal_channel')
result = yield from self.send_file(destination, fp, name)
return result

Expand All @@ -202,7 +205,7 @@ def type(self):
---------
The :meth:`Client.send_typing` function.
"""
destination = self._get_variable('_internal_channel')
destination = _get_variable('_internal_channel')
yield from self.send_typing(destination)

# listener registration
Expand Down
26 changes: 20 additions & 6 deletions discord/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,25 @@
import inspect
import re
import discord
from functools import partial
import functools

from .errors import *
from .view import quoted_word

__all__ = [ 'Command', 'Group', 'GroupMixin', 'command', 'group',
'has_role', 'has_permissions', 'has_any_role', 'check' ]

def inject_context(ctx, coro):
@functools.wraps(coro)
@asyncio.coroutine
def wrapped(*args, **kwargs):
_internal_channel = ctx.message.channel
_internal_author = ctx.message.author

ret = yield from coro(*args, **kwargs)
return ret
return wrapped

def _convert_to_bool(argument):
lowered = argument.lower()
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
Expand Down Expand Up @@ -103,10 +114,11 @@ def handle_local_error(self, error, ctx):
except AttributeError:
return

injected = inject_context(ctx, coro)
if self.instance is not None:
discord.utils.create_task(coro(self.instance, error, ctx), loop=ctx.bot.loop)
discord.utils.create_task(injected(self.instance, error, ctx), loop=ctx.bot.loop)
else:
discord.utils.create_task(coro(error, ctx), loop=ctx.bot.loop)
discord.utils.create_task(injected(error, ctx), loop=ctx.bot.loop)

def _receive_item(self, message, argument, regex, receiver, generator):
match = re.match(regex, argument)
Expand Down Expand Up @@ -263,7 +275,8 @@ def invoke(self, ctx):
return

if self._parse_arguments(ctx):
yield from self.callback(*ctx.args, **ctx.kwargs)
injected = inject_context(ctx, self.callback)
yield from injected(*ctx.args, **ctx.kwargs)

def error(self, coro):
"""A decorator that registers a coroutine as a local error handler.
Expand Down Expand Up @@ -425,7 +438,8 @@ def invoke(self, ctx):
if trigger in self.commands:
ctx.invoked_subcommand = self.commands[trigger]

yield from self.callback(*ctx.args, **ctx.kwargs)
injected = inject_context(ctx, self.callback)
yield from injected(*ctx.args, **ctx.kwargs)

if ctx.invoked_subcommand:
ctx.invoked_with = trigger
Expand Down Expand Up @@ -616,7 +630,7 @@ def predicate(ctx):
if ch.is_private:
return False

getter = partial(discord.utils.get, msg.author.roles)
getter = functools.partial(discord.utils.get, msg.author.roles)
return any(getter(name=name) is not None for name in names)
return check(predicate)

Expand Down

0 comments on commit 0a07fc0

Please sign in to comment.