Skip to content

Commit

Permalink
Enable JWT authentication in addition to old certificate authenticati…
Browse files Browse the repository at this point in the history
…on (#40)
  • Loading branch information
Masa authored and Pr0Ger committed May 3, 2017
1 parent 164d146 commit ae835a1
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 12 deletions.
27 changes: 18 additions & 9 deletions apns2/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import logging
from enum import Enum

from hyper import HTTP20Connection
from hyper.tls import init_context

from .errors import ConnectionFailed, exception_class_for_reason

# We don't generally need to know about the Credentials subclasses except to
# keep the old API, where APNsClient took a cert_file
from .credentials import CertificateCredentials


class NotificationPriority(Enum):
Immediate = '10'
Expand All @@ -30,18 +31,22 @@ class APNsClient(object):
DEFAULT_PORT = 443
ALTERNATIVE_PORT = 2197

def __init__(self, cert_file, use_sandbox=False, use_alternative_port=False, proto=None, json_encoder=None, password=None):
ssl_context = init_context()
ssl_context.load_cert_chain(cert_file, password=password)
self._init_connection(use_sandbox, use_alternative_port, ssl_context, proto)
def __init__(self, credentials, use_sandbox=False, use_alternative_port=False, proto=None, json_encoder=None, password=None):
if credentials is None or isinstance(credentials, str):
self.__credentials = CertificateCredentials(credentials)
else:
self.__credentials = credentials
self._init_connection(use_sandbox, use_alternative_port, proto)

self.__json_encoder = json_encoder
self.__max_concurrent_streams = None
self.__previous_server_max_concurrent_streams = None

def _init_connection(self, use_sandbox, use_alternative_port, ssl_context, proto):
def _init_connection(self, use_sandbox, use_alternative_port, proto):
server = self.SANDBOX_SERVER if use_sandbox else self.LIVE_SERVER
port = self.ALTERNATIVE_PORT if use_alternative_port else self.DEFAULT_PORT
self._connection = HTTP20Connection(server, port, ssl_context=ssl_context, force_proto=proto or 'h2')
self._connection = self.__credentials.create_connection(server, port,
proto)

def send_notification(self, token_hex, notification, topic, priority=NotificationPriority.Immediate,
expiration=None, collapse_id=None):
Expand All @@ -62,6 +67,10 @@ def send_notification_async(self, token_hex, notification, topic, priority=Notif
if expiration is not None:
headers['apns-expiration'] = '%d' % expiration

auth_header = self.__credentials.get_authorization_header(topic)
if auth_header is not None:
headers['authorization'] = auth_header

if collapse_id is not None:
headers['apns-collapse-id'] = collapse_id

Expand Down
96 changes: 96 additions & 0 deletions apns2/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@

from hyper import HTTP20Connection
from hyper.tls import init_context

import jwt

# For creating and comparing the time for the JWT token
import time


DEFAULT_TOKEN_LIFETIME = 3600
DEFAULT_TOKEN_ENCRYPTION_ALGORITHM = 'ES256'


# Abstract Base class. This should not be instantiated directly.
class Credentials(object):

def __init__(self, ssl_context=None):
self.__ssl_context = ssl_context

# Creates a connection with the credentials, if available or necessary.
def create_connection(self, server, port, proto):
# self.__ssl_context may be none, and that's fine.
return HTTP20Connection(server, port,
ssl_context=self.__ssl_context,
force_proto=proto or 'h2')

def get_authorization_header(self, topic):
return None


# Credentials subclass for certificate authentication
class CertificateCredentials(Credentials):
def __init__(self, cert_file):
ssl_context = init_context()
ssl_context.load_cert_chain(cert_file)
super(CertificateCredentials, self).__init__(ssl_context)


# Credentials subclass for JWT token based authentication
class TokenCredentials(Credentials):
def __init__(self, auth_key_path, auth_key_id, team_id,
encryption_algorithm=None, token_lifetime=None):
self.__auth_key = self._get_signing_key(auth_key_path)
self.__auth_key_id = auth_key_id
self.__team_id = team_id
self.__encryption_algorithm = DEFAULT_TOKEN_ENCRYPTION_ALGORITHM if \
encryption_algorithm is None else \
encryption_algorithm
self.__token_lifetime = DEFAULT_TOKEN_LIFETIME if \
token_lifetime is None else token_lifetime

# Dictionary of {topic: (issue time, ascii decoded token)}
self.__topicTokens = {}

# Use the default constructor because we don't have an SSL context
super(TokenCredentials, self).__init__()

def get_tokens(self):
return [val[1] for val in self.__topicTokens]

def get_authorization_header(self, topic):
token = self._get_or_create_topic_token(topic)
return 'bearer %s' % token

def _isExpiredToken(self, issueDate):
now = time.time()
return now < issueDate + DEFAULT_TOKEN_LIFETIME

def _get_or_create_topic_token(self, topic):
# dict of topic to issue date and JWT token
tokenPair = self.__topicTokens.get(topic)
if tokenPair is None or self._isExpiredToken(tokenPair[0]):
# Create a new token
issuedAt = time.time()
tokenDict = {'iss': self.__team_id,
'iat': issuedAt
}
headers = {'alg': self.__encryption_algorithm,
'kid': self.__auth_key_id,
}
jwtToken = jwt.encode(tokenDict, self.__auth_key,
algorithm=self.__encryption_algorithm,
headers=headers).decode('ascii')

self.__topicTokens[topic] = (issuedAt, jwtToken)
return jwtToken
else:
return tokenPair[1]

def _get_signing_key(self, key_path):
secret = ''
if key_path:
with open(key_path) as f:
secret = f.read()
return secret
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from setuptools import setup

dependencies = ['hyper>=0.7']
dependencies = [
'hyper>=0.7',
'PyJWT>=1.4.0',
'cryptography>=1.7.2',
]

try:
# noinspection PyUnresolvedReferences
Expand Down
8 changes: 8 additions & 0 deletions test/eckey.pem
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-----BEGIN EC PARAMETERS-----
BggqhkjOPQMBBw==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIM0scFXkVBBc7d8DSL9rfFB1PvET/dWQa9eWfxpgaqaBoAoGCCqGSM49
AwEHoUQDQgAEnVxQ41VKMN6uSsSYCCdOUhCms+HT2VpUhFGML5SzYGoodKtRD/6J
YI9Rxq1lPGHMwECSaPtPf9kVDCUM6UHvhA==
-----END EC PRIVATE KEY-----
4 changes: 2 additions & 2 deletions test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ def setUp(self):
self.mock_results = None
self.next_stream_id = 0

with patch('apns2.client.HTTP20Connection') as mock_connection_constructor, patch('apns2.client.init_context'):
with patch('apns2.credentials.HTTP20Connection') as mock_connection_constructor, patch('apns2.credentials.init_context'):
self.mock_connection = MagicMock()
self.mock_connection.get_response.side_effect = self.mock_get_response
self.mock_connection.request.side_effect = self.mock_request
self.mock_connection._conn.__enter__.return_value = self.mock_connection._conn
self.mock_connection._conn.remote_settings.max_concurrent_streams = 500
mock_connection_constructor.return_value = self.mock_connection
self.client = APNsClient(cert_file=None)
self.client = APNsClient(credentials=None)

@contextlib.contextmanager
def mock_get_response(self, stream_id):
Expand Down
56 changes: 56 additions & 0 deletions test/test_credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# This only tests the TokenCredentials test case, since the
# CertificateCredentials would be mocked out anyway.
# Namely:
# - timing out of the token
# - creating multiple tokens for different topics

from unittest import TestCase, main

import time

from apns2.credentials import TokenCredentials


class TokenCredentialsTestCase(TestCase):
@classmethod
def setUpClass(cls):
cls.key_path = 'test/eckey.pem'
cls.team_id = '3Z24IP123A'
cls.key_id = '1QBCDJ9RST'
cls.topics = ('com.example.first_app', 'com.example.second_app',)
cls.token_lifetime = 0.5

def setUp(self):
# Create an 'ephemeral' token so we can test token timeouts. We
# want a timeout long enough to last the test, but we don't want to
# slow down the tests too much either.
self.normal_creds = TokenCredentials(self.key_path, self.key_id,
self.team_id)
self.lasting_header = self.normal_creds.get_authorization_header(
self.topics[0])
self.expiring_creds = \
TokenCredentials(self.key_path, self.key_id,
self.team_id,
token_lifetime=self.token_lifetime)
self.expiring_header = self.expiring_creds.get_authorization_header(
self.topics[0])

def test_create_multiple_topics(self):
h1 = self.normal_creds.get_authorization_header(self.topics[0])
self.assertEqual(len(self.normal_creds.get_tokens()), 1)
h2 = self.normal_creds.get_authorization_header(self.topics[1])
self.assertNotEqual(h1, h2)
self.assertEqual(len(self.normal_creds.get_tokens()), 2)

def test_token_expiration(self):
# As long as the token lifetime hasn't elapsed, this should work. To
# be really careful, we should check how much time has elapsed to
# know if it fail. But, either way, we'd have to come up with a good
# lifetime for future tests...
time.sleep(self.token_lifetime)
h3 = self.expiring_creds.get_authorization_header(self.topics[0])
self.assertNotEqual(self.expiring_header, h3)


if __name__ == '__main__':
main()

0 comments on commit ae835a1

Please sign in to comment.