In [1]:
import ast
import datetime
import re

import boto3
from envs import env
from jose import JWTError, jwt
import requests
from getpass import getpass

from aws_srp import AWSSRP
from exceptions import TokenVerificationException

from config import config

import uuid

In [15]:
class Garner:
    
    #user_class = UserObj
    #group_class = GroupObj
    
    def __init__(self, username=None, secret_key=None, password=None):
        
        self.user_pool_id = config["aws_user_pools_id"]
        self.user_pool_region = config["aws_cognito_region"]
        self.client_id = config["aws_user_pools_web_client_id"]
        
        self.username = username
        
        self.refresh_token = None
        self.token_type = None
        self.access_token = None
        self.pool_jwk = None
        self.id_token = None
        
        self.client = boto3.client("cognito-idp", region_name=self.user_pool_region)
        
        self.authenticate(password)
        
    def authenticate(self, password):
        """
        Authenticate the user using the SRP protocol
        :param password: The user's passsword
        :return:
        """
        if not password:
            print("Enter Password")
        aws = AWSSRP(
            username=self.username,
            password=(password if password else getpass("Password:")),
            pool_id=self.user_pool_id,
            client_id=self.client_id,
            client=self.client,
        )
        tokens = aws.authenticate_user('Password:')
        print("Authenticated")
        self.verify_token(tokens["AuthenticationResult"]["IdToken"], "id_token", "id")
        self.refresh_token = tokens["AuthenticationResult"]["RefreshToken"]
        self.verify_token(
            tokens["AuthenticationResult"]["AccessToken"], "access_token", "access"
        )
        self.token_type = tokens["AuthenticationResult"]["TokenType"]
        self.access_token = tokens['AuthenticationResult']['AccessToken']
        
    def get_keys(self):
        if self.pool_jwk:
            return self.pool_jwk

        # Check for the dictionary in environment variables.
        pool_jwk_env = env("COGNITO_JWKS", {}, var_type="dict")
        if pool_jwk_env:
            self.pool_jwk = pool_jwk_env
        # If it is not there use the requests library to get it
        else:
            self.pool_jwk = requests.get(
                "https://cognito-idp.{}.amazonaws.com/{}/.well-known/jwks.json".format(
                    self.user_pool_region, self.user_pool_id
                )
            ).json()
        return self.pool_jwk

    def get_key(self, kid):
        keys = self.get_keys().get("keys")
        key = list(filter(lambda x: x.get("kid") == kid, keys))
        return key[0]
    
    def verify_token(self, token, id_name, token_use):
        kid = jwt.get_unverified_header(token).get("kid")
        unverified_claims = jwt.get_unverified_claims(token)
        token_use_verified = unverified_claims.get("token_use") == token_use
        if not token_use_verified:
            raise TokenVerificationException("Your {} token use could not be verified.")
        hmac_key = self.get_key(kid)
        try:
            verified = jwt.decode(
                token,
                hmac_key,
                algorithms=["RS256"],
                audience=unverified_claims.get("aud"),
                issuer=unverified_claims.get("iss"),
            )
        except JWTError:
            raise TokenVerificationException("Your {} token could not be verified.")
        setattr(self, id_name, token)
        return verified
    
    def check_token(self, renew=True):
        """
        Checks the exp attribute of the access_token and either refreshes
        the tokens by calling the renew_access_tokens method or does nothing
        :param renew: bool indicating whether to refresh on expiration
        :return: bool indicating whether access_token has expired
        """
        if not self.access_token:
            raise AttributeError('Access Token Required to Check Token')
        now = datetime.datetime.now()
        dec_access_token = jwt.get_unverified_claims(self.access_token)

        if now > datetime.datetime.fromtimestamp(dec_access_token['exp']):
            expired = True
            if renew:
                self.renew_access_token()
        else:
            expired = False
        return expired
    
    def renew_access_token(self):
        """
        Sets a new access token on the User using the refresh token.
        """
        auth_params = {'REFRESH_TOKEN': self.refresh_token}
        #self._add_secret_hash(auth_params, 'SECRET_HASH')
        refresh_response = self.client.initiate_auth(
            ClientId=self.client_id,
            AuthFlow='REFRESH_TOKEN',
            AuthParameters=auth_params,
        )

        status_code = refresh_response.get(
            'HTTPStatusCode',
            refresh_response['ResponseMetadata']['HTTPStatusCode']
        )
        
        if status_code == 200:
            self.access_token = refresh_response['AuthenticationResult']['AccessToken']
            self.id_token: refresh_response['AuthenticationResult']['IdToken']
            self.token_type: refresh_response['AuthenticationResult']['TokenType']

In [21]:
u = Garner('e80cb5db-788a-4a84-9604-513cb3e8152b', 'secret key')

Enter Password


Password: ·········


Authenticated


In [26]:
u.check_token()
u.access_token

'eyJraWQiOiJnVWZFOW9sN0gzZng1aHlPeGxVXC9nYXZpK3pIcmNpcDhGOHhCOWpsdVwvaTA9IiwiYWxnIjoiUlMyNTYifQ.eyJzdWIiOiJlODBjYjVkYi03ODhhLTRhODQtOTYwNC01MTNjYjNlODE1MmIiLCJjb2duaXRvOmdyb3VwcyI6WyJnYXRoZXJlciIsImFkbWluIl0sImV2ZW50X2lkIjoiM2Q0YzRlNjktZjNlNS00OWQ0LWE5YjQtNzlkOTgxYzc3N2RhIiwidG9rZW5fdXNlIjoiYWNjZXNzIiwic2NvcGUiOiJhd3MuY29nbml0by5zaWduaW4udXNlci5hZG1pbiIsImF1dGhfdGltZSI6MTYwNDY5Mzc5MCwiaXNzIjoiaHR0cHM6XC9cL2NvZ25pdG8taWRwLmV1LXdlc3QtMi5hbWF6b25hd3MuY29tXC9ldS13ZXN0LTJfZXdJYXdwWndiIiwiZXhwIjoxNjA0Njk3NDA2LCJpYXQiOjE2MDQ2OTM4MDYsImp0aSI6ImM0NWRkNmEyLWY1YmQtNDU3Mi05MWJmLWZmNjQxZjI1MjI1OSIsImNsaWVudF9pZCI6IjVlajA2cm4xNG1vZTVvb2RtYzFqdm1uYjBoIiwidXNlcm5hbWUiOiJlODBjYjVkYi03ODhhLTRhODQtOTYwNC01MTNjYjNlODE1MmIifQ.XwAAIqoE79IneDRAB2kmxJKFdUgRxdjGSmG5RTeYwrN5BruzaG8F_SwLvFtTMSD3QJ4M1WT2BjOsd7-OOqRkzLoNn-MJNDm5fECiNQ0DppOhyILaraTISWpugU_MNMVohDVQizIkPDSZlkrvAYvpk54WrG3HC2cPyOmC_Ut71wYNsPxWbiw4TKpMNvOO6zW_iMoc9imtenvxiU2SQche5odiT03t6oXmSCfan98kwudzalBwlYfk57VGRiPVZc_OIctvDHh0x86d3QdFY8XRv7VnSAJya

In [27]:
import requests
import json

APPSYNC_API_ENDPOINT_URL = config["aws_appsync_graphqlEndpoint"]
#APPSYNC_API_KEY = "da2-fakeApiId123456"

headers = {
    #'Content-Type': "application/graphql",
    #'x-api-key': APPSYNC_API_KEY,
    #'Authorization': u.pool_jwk
    'Authorization': str(u.access_token)
}

params = {
    "filter": {"owner" : {"eq": u.username}}
}

def execute_gql(query):
    u.check_token()
    payload_obj = {"query": query, "variables" : params,}
    payload = json.dumps(payload_obj)
    response = requests.request("POST", APPSYNC_API_ENDPOINT_URL, data=payload, headers=headers)
    return response

query = """query ListPools(
    $filter: ModelPoolFilterInput
    $limit: Int
    $nextToken: String
  ) {
    listPools(filter: $filter, limit: $limit, nextToken: $nextToken) {
      items {
        id
        title
        privateKey
      }
    }
  }"""

output = execute_gql(query).json()

#print(output)
print("You manage:")
for item in output['data']['listPools']['items']:
    print("-{}".format(item['title']))

{'data': {'listPools': {'items': [{'id': '6ce54ac7-2761-4916-971c-c6286ffde71e', 'title': 'This pool is a test', 'privateKey': 'd9a49411-b7c2-42bd-bb43-16f4bdaccb40'}, {'id': '4f8237e1-637e-4f65-9d0c-64f211d52abb', 'title': 'kai man', 'privateKey': 'dbfaab17-b8b1-41ad-95ec-1568ac734109'}]}}}
You manage:
-This pool is a test
-kai man


In [16]:
object_name = 'movie.gif'
s3 = boto3.client('s3')

with open(object_name, 'rb') as data:
    loc = 'pooldata211140-dev'
    
    key = 'protected/'
    key += u.username
    key += '/'
    key += str(uuid.uuid4())
    key += '.'
    key += object_name.split('.')[-1]
    
    s3.upload_fileobj(data, loc, key)

NoCredentialsError: Unable to locate credentials

In [19]:
OBJECT_NAME.split('.')[-1]

'png'