-
Notifications
You must be signed in to change notification settings - Fork 172
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable JWT authentication in addition to old certificate authenticati…
…on (#40)
- Loading branch information
Showing
6 changed files
with
185 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
|
||
from hyper import HTTP20Connection | ||
from hyper.tls import init_context | ||
|
||
import jwt | ||
|
||
# For creating and comparing the time for the JWT token | ||
import time | ||
|
||
|
||
DEFAULT_TOKEN_LIFETIME = 3600 | ||
DEFAULT_TOKEN_ENCRYPTION_ALGORITHM = 'ES256' | ||
|
||
|
||
# Abstract Base class. This should not be instantiated directly. | ||
class Credentials(object): | ||
|
||
def __init__(self, ssl_context=None): | ||
self.__ssl_context = ssl_context | ||
|
||
# Creates a connection with the credentials, if available or necessary. | ||
def create_connection(self, server, port, proto): | ||
# self.__ssl_context may be none, and that's fine. | ||
return HTTP20Connection(server, port, | ||
ssl_context=self.__ssl_context, | ||
force_proto=proto or 'h2') | ||
|
||
def get_authorization_header(self, topic): | ||
return None | ||
|
||
|
||
# Credentials subclass for certificate authentication | ||
class CertificateCredentials(Credentials): | ||
def __init__(self, cert_file): | ||
ssl_context = init_context() | ||
ssl_context.load_cert_chain(cert_file) | ||
super(CertificateCredentials, self).__init__(ssl_context) | ||
|
||
|
||
# Credentials subclass for JWT token based authentication | ||
class TokenCredentials(Credentials): | ||
def __init__(self, auth_key_path, auth_key_id, team_id, | ||
encryption_algorithm=None, token_lifetime=None): | ||
self.__auth_key = self._get_signing_key(auth_key_path) | ||
self.__auth_key_id = auth_key_id | ||
self.__team_id = team_id | ||
self.__encryption_algorithm = DEFAULT_TOKEN_ENCRYPTION_ALGORITHM if \ | ||
encryption_algorithm is None else \ | ||
encryption_algorithm | ||
self.__token_lifetime = DEFAULT_TOKEN_LIFETIME if \ | ||
token_lifetime is None else token_lifetime | ||
|
||
# Dictionary of {topic: (issue time, ascii decoded token)} | ||
self.__topicTokens = {} | ||
|
||
# Use the default constructor because we don't have an SSL context | ||
super(TokenCredentials, self).__init__() | ||
|
||
def get_tokens(self): | ||
return [val[1] for val in self.__topicTokens] | ||
|
||
def get_authorization_header(self, topic): | ||
token = self._get_or_create_topic_token(topic) | ||
return 'bearer %s' % token | ||
|
||
def _isExpiredToken(self, issueDate): | ||
now = time.time() | ||
return now < issueDate + DEFAULT_TOKEN_LIFETIME | ||
|
||
def _get_or_create_topic_token(self, topic): | ||
# dict of topic to issue date and JWT token | ||
tokenPair = self.__topicTokens.get(topic) | ||
if tokenPair is None or self._isExpiredToken(tokenPair[0]): | ||
# Create a new token | ||
issuedAt = time.time() | ||
tokenDict = {'iss': self.__team_id, | ||
'iat': issuedAt | ||
} | ||
headers = {'alg': self.__encryption_algorithm, | ||
'kid': self.__auth_key_id, | ||
} | ||
jwtToken = jwt.encode(tokenDict, self.__auth_key, | ||
algorithm=self.__encryption_algorithm, | ||
headers=headers).decode('ascii') | ||
|
||
self.__topicTokens[topic] = (issuedAt, jwtToken) | ||
return jwtToken | ||
else: | ||
return tokenPair[1] | ||
|
||
def _get_signing_key(self, key_path): | ||
secret = '' | ||
if key_path: | ||
with open(key_path) as f: | ||
secret = f.read() | ||
return secret |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
-----BEGIN EC PARAMETERS----- | ||
BggqhkjOPQMBBw== | ||
-----END EC PARAMETERS----- | ||
-----BEGIN EC PRIVATE KEY----- | ||
MHcCAQEEIM0scFXkVBBc7d8DSL9rfFB1PvET/dWQa9eWfxpgaqaBoAoGCCqGSM49 | ||
AwEHoUQDQgAEnVxQ41VKMN6uSsSYCCdOUhCms+HT2VpUhFGML5SzYGoodKtRD/6J | ||
YI9Rxq1lPGHMwECSaPtPf9kVDCUM6UHvhA== | ||
-----END EC PRIVATE KEY----- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# This only tests the TokenCredentials test case, since the | ||
# CertificateCredentials would be mocked out anyway. | ||
# Namely: | ||
# - timing out of the token | ||
# - creating multiple tokens for different topics | ||
|
||
from unittest import TestCase, main | ||
|
||
import time | ||
|
||
from apns2.credentials import TokenCredentials | ||
|
||
|
||
class TokenCredentialsTestCase(TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
cls.key_path = 'test/eckey.pem' | ||
cls.team_id = '3Z24IP123A' | ||
cls.key_id = '1QBCDJ9RST' | ||
cls.topics = ('com.example.first_app', 'com.example.second_app',) | ||
cls.token_lifetime = 0.5 | ||
|
||
def setUp(self): | ||
# Create an 'ephemeral' token so we can test token timeouts. We | ||
# want a timeout long enough to last the test, but we don't want to | ||
# slow down the tests too much either. | ||
self.normal_creds = TokenCredentials(self.key_path, self.key_id, | ||
self.team_id) | ||
self.lasting_header = self.normal_creds.get_authorization_header( | ||
self.topics[0]) | ||
self.expiring_creds = \ | ||
TokenCredentials(self.key_path, self.key_id, | ||
self.team_id, | ||
token_lifetime=self.token_lifetime) | ||
self.expiring_header = self.expiring_creds.get_authorization_header( | ||
self.topics[0]) | ||
|
||
def test_create_multiple_topics(self): | ||
h1 = self.normal_creds.get_authorization_header(self.topics[0]) | ||
self.assertEqual(len(self.normal_creds.get_tokens()), 1) | ||
h2 = self.normal_creds.get_authorization_header(self.topics[1]) | ||
self.assertNotEqual(h1, h2) | ||
self.assertEqual(len(self.normal_creds.get_tokens()), 2) | ||
|
||
def test_token_expiration(self): | ||
# As long as the token lifetime hasn't elapsed, this should work. To | ||
# be really careful, we should check how much time has elapsed to | ||
# know if it fail. But, either way, we'd have to come up with a good | ||
# lifetime for future tests... | ||
time.sleep(self.token_lifetime) | ||
h3 = self.expiring_creds.get_authorization_header(self.topics[0]) | ||
self.assertNotEqual(self.expiring_header, h3) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |