# Imports

In [None]:
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
import jwt
from datetime import datetime, timezone
import re
import hashlib
from typing import Any
import requests
import json
import pickle
import base64

# Token client class definition (mimics mobile app)

In [None]:
# Base URL of the development server (localhost, port 8000); TokenClient uses
# it as its default `base_url`.
DEV_URL = "http://127.0.0.1:8000"

In [None]:
class TokenClient:
    """Test client that mimics the mobile app's token requests.

    A client owns an RSA key pair, derives a short fingerprint from the public
    key and a hashed PIN from ``pin`` + fingerprint, and sends RS256-signed JWTs
    to ``base_url + api_url``. The ticket returned by the server is remembered
    and sent back with the next request.
    """

    app_version = "v0.0.5"
    api_url = "/mainrequest/"
    base_url = DEV_URL
    # Seconds to wait for the server before giving up; without a timeout
    # `requests` would block forever on an unresponsive server.
    request_timeout = 10

    def __init__(self, pin="1234567890", saved_token=None):
        """Create a fresh client, or restore one from ``save_token()`` output.

        Args:
            pin: PIN that is hashed together with the fingerprint.
            saved_token: Base64 string produced by ``save_token()``. When empty
                or ``None`` a new key pair is generated and the ticket is "".

        Raises:
            ValueError: If ``saved_token`` decodes to something other than a
                3-tuple.
        """
        if not saved_token:
            self.private_key, self.public_key = self._generate_keypair()
            self.ticket = ""
        else:
            decoded_data = self._restore_saved_token(saved_token)
            if decoded_data is None:
                raise ValueError(
                    "saved_token does not contain a "
                    "(private_key, public_key, ticket) tuple"
                )
            self.private_key, self.public_key, self.ticket = decoded_data

        # Derive the key strings on BOTH paths: the fingerprint (and therefore
        # the PIN hash and the JWT payload) is computed from `public_bytes`.
        self.private_bytes, self.public_bytes = self._get_key_bytes()
        self.fingerprint = self._get_fingerprint()
        self.pin = self._get_pin(pin)

    def _get_url(self) -> str:
        """Return the full URL of the main-request endpoint."""
        return self.base_url + self.api_url

    def _generate_keypair(self):
        """Generate a new 2048-bit RSA key pair as ``(private_key, public_key)``."""
        private_key = rsa.generate_private_key(key_size=2048, public_exponent=65537)
        return (private_key, private_key.public_key())

    @staticmethod
    def _strip_newlines(text: str) -> str:
        """Remove every CRLF / LF so a PEM block becomes a single line."""
        return re.sub(r'(\r\n)|\n', '', text)

    def _private_pem(self) -> bytes:
        """Serialize the private key as unencrypted PKCS8 PEM bytes."""
        return self.private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    def _public_pem(self) -> bytes:
        """Serialize the public key as SubjectPublicKeyInfo PEM bytes."""
        return self.public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )

    def _get_key_bytes(self) -> tuple[str, str]:
        """Return ``(private, public)`` PEM strings with all newlines removed."""
        private_bytes = self._strip_newlines(self._private_pem().decode())
        public_bytes = self._strip_newlines(self._public_pem().decode())
        return (private_bytes, public_bytes)

    def _get_fingerprint(self) -> str:
        """Return the first 6 bytes of SHA-256(public key string) as spaced hex."""
        digest = hashlib.sha256(self.public_bytes.encode('utf-8')).digest()
        return ' '.join(f'{byte:02X}' for byte in digest[:6])

    def _get_pin(self, pin) -> str:
        """Return the hex SHA-256 of ``pin`` concatenated with the fingerprint."""
        return hashlib.sha256(f"{pin}{self.fingerprint}".encode('utf-8')).hexdigest()

    def _get_jwt_dict(self) -> dict[str, Any]:
        """Build the JWT payload; ``request_time`` is the current UTC time."""
        return {
            "version": self.app_version,
            "request_time": datetime.now(timezone.utc).isoformat(),
            "public_key": self.public_bytes,
            "pin": self.pin,
            "ticket": self.ticket,
        }

    def _encode_jwt_token(self, jwt_dict: dict[str, Any]):
        """Sign ``jwt_dict`` with the client's private key using RS256."""
        return jwt.encode(jwt_dict, self.private_key, algorithm="RS256")

    def _get_request_data(self, jwt_token):
        """Wrap the signed token in the JSON body sent to the server."""
        return {'token': jwt_token}

    def _get_ticket_from_response(self, response):
        """Extract the ``ticket`` field from a JSON response body.

        Raises:
            KeyError: If the body has no ``ticket`` key.
            json.JSONDecodeError: If the body is not valid JSON.
        """
        return json.loads(response)['ticket']

    def send_mainrequest(self):
        """POST a freshly signed JWT and remember the ticket on HTTP 200.

        Returns:
            ``(status_code, raw_response_body)``.

        Raises:
            requests.exceptions.RequestException: On connection errors or when
                the server does not answer within ``request_timeout`` seconds.
        """
        jwt_token = self._encode_jwt_token(self._get_jwt_dict())
        request_data = self._get_request_data(jwt_token)

        response = requests.post(
            self._get_url(), json=request_data, timeout=self.request_timeout
        )

        if response.status_code == 200:
            self.ticket = self._get_ticket_from_response(response.content)

        return (response.status_code, response.content)

    def save_token(self) -> str:
        """Serialize the key pair and current ticket to a base64 string.

        The keys are stored as PEM bytes rather than as live key objects.
        NOTE(review): `cryptography` key objects are, as far as I know, not
        picklable, so pickling them directly is expected to fail - confirm
        against the installed version.
        """
        pickled_data = pickle.dumps(
            (self._private_pem(), self._public_pem(), self.ticket)
        )
        return base64.b64encode(pickled_data).decode()

    def _restore_saved_token(self, encoded_string: str):
        """Decode ``save_token()`` output into ``(private_key, public_key, ticket)``.

        Returns ``None`` when the decoded payload is not a 3-tuple. Tuples that
        hold PEM bytes are turned back into key objects; any other values are
        passed through unchanged (legacy payloads holding key objects).
        """
        pickled_data = base64.b64decode(encoded_string.encode())
        # SECURITY: pickle.loads can execute arbitrary code while decoding.
        # Only pass strings produced by save_token() from a trusted source.
        unpickled_data = pickle.loads(pickled_data)
        if not isinstance(unpickled_data, tuple) or len(unpickled_data) != 3:
            return None

        private_key, public_key, ticket = unpickled_data
        if isinstance(private_key, bytes):
            private_key = serialization.load_pem_private_key(private_key, password=None)
        if isinstance(public_key, bytes):
            public_key = serialization.load_pem_public_key(public_key)
        return (private_key, public_key, ticket)

    def __str__(self):
        return self.fingerprint

    def __repr__(self):
        return f"Token({self.fingerprint})"

# Code playground

In [None]:
class MassTokenCreator:

    tokens_dict : dict[TokenClient] = dict()

    def __init__(self, amount):
        for _ in range(amount):
            token = TokenClient()
            print(token.send_mainrequest())
            self.tokens_dict[token.fingerprint] = token
        print(list(self.tokens_dict.keys()))
        
    @classmethod
    def send_requests(self, tokens: list[str] | None = None):
        if tokens:
            for fingerprint in tokens:
                print(fingerprint, self.tokens_dict[fingerprint].send_mainrequest())
        else:
            for token in self.tokens_dict.values():
                print(token.fingerprint, token.send_mainrequest())

    @classmethod
    def clear_tokens(self, tokens: list[str] | None = None):
        if tokens:
            for token in tokens:
                self.tokens_dict.pop(token)
        else:
            self.tokens_dict.clear()


In [None]:
# Empty the shared token registry, then create a single new client; the
# constructor sends that client's first request and prints the result.
MassTokenCreator.clear_tokens()
MassTokenCreator(1)

In [None]:
# Send a main request from every token currently held in the registry.
MassTokenCreator.send_requests()