Skip to content

Commit

Permalink
Add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
aluminiumgeek committed Mar 23, 2016
1 parent 2f3323c commit 3cae804
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import json
from threading import Thread
from typing import Any, Callable, List, Optional, Union

import requests
from requests.exceptions import ConnectionError, ReadTimeout
Expand All @@ -22,16 +23,16 @@ class TemporaryStore(object):
def __init__(self):
self.store = {}

def get(self, key, default):
def get(self, key: str, default: Any) -> Any:
return self.store.get(key, default)

def set(self, key, value):
def set(self, key: str, value: Any) -> None:
self.store[key] = value


class Executor(object):

def __init__(self, callback):
def __init__(self, callback: Callable[[asyncio.Task, str], None]):
self.callback = callback

# Run new event loop in another thread
Expand All @@ -45,13 +46,13 @@ def init_loop(self):
asyncio.set_event_loop(self.loop)
self.loop.run_forever()

def call(self, module, *args, **kwargs):
def call(self, module: Callable[..., Optional[str]], *args, **kwargs) -> asyncio.Task:
chat_id = kwargs.get('chat_id')
task = asyncio.run_coroutine_threadsafe(self.run(module, *args, **kwargs), self.loop)
task.add_done_callback(functools.partial(self.callback, chat_id=chat_id))
return task

async def run(self, module, *args, **kwargs):
async def run(self, module: Callable[..., Optional[str]], *args, **kwargs):
if asyncio.iscoroutinefunction(module):
return await module(*args, **kwargs)
else:
Expand Down Expand Up @@ -188,7 +189,7 @@ def start(self):
tasks += list(process)
logging.debug('Running tasks: {}'.format(tasks))

def process(self, update):
def process(self, update: dict) -> Union[asyncio.Task, List[asyncio.Task], None]:
"""
Process an update
"""
Expand Down Expand Up @@ -248,7 +249,7 @@ def process(self, update):
if obj is not None:
return [self.executor.call(self.commands[command_type][cmd], self, obj, update, chat_id=self.chat_id) for cmd in self.commands[command_type]]

def pre_send(self, chat_id=None, action='typing'):
def pre_send(self, chat_id: Optional[str]=None, action: str='typing') -> None:
"""
Pre send hook. Send 'typing...' or another chat action
"""
Expand All @@ -259,7 +260,7 @@ def pre_send(self, chat_id=None, action='typing'):
}
self.call('sendChatAction', 'POST', data=data)

def send(self, chat_id=None, text=None, data={}):
def send(self, chat_id: Optional[str]=None, text: Optional[str]=None, data: dict={}) -> None:
"""
Send message to a chat
"""
Expand All @@ -268,7 +269,7 @@ def send(self, chat_id=None, text=None, data={}):
logging.debug('Sending: {}'.format(data))
self.call('sendMessage', 'POST', data=data)

def get_updates(self):
def get_updates(self) -> List[dict]:
"""
Get updates from telegram
"""
Expand All @@ -284,7 +285,7 @@ def get_updates(self):
return updates
return []

def call(self, method_name, http_method, **kwargs):
def call(self, method_name: str, http_method: str, **kwargs):
"""
Call a Telegram API method
"""
Expand All @@ -311,8 +312,8 @@ def call(self, method_name, http_method, **kwargs):
logging.error(resp.content)
return []

def _is_owner(self, update):
def _is_owner(self, update: dict) -> bool:
return update.get('message', {}).get('from', {}).get('username', '') == self.settings.owner

def _get_chat_id(self, update):
def _get_chat_id(self, update: dict) -> Optional[str]:
return update.get('message', {}).get('chat', {}).get('id', None)

0 comments on commit 3cae804

Please sign in to comment.