In [2]:
%%capture
import random
import math
import string
import numpy as np
from utils import *

In [3]:
# RSA key of the certificate authority that signs server certificates.
# Client.verifyServerCertificate reads only the public part ('e', 'N').
# NOTE(review): 'd' is the CA's *private* exponent. Presumably it lives here so another
# part of this exercise (e.g. a server/CA component) can issue certificates -- confirm.
# Real client code must never hold the CA private key.
CERT_AUTH = dict(
    e=65537,
    N=33790420484761320225234266446986435791020053290995177788399417698648848366075439013295931744537889745793682732187585867814285806211190774412138926826806937374931229955338241741978503726324443629746710612128866806815968501932728765477787763877641403710570749219182260822344263730489611164845428107854720086677,
    d=13990616901200824332998639242549657982350872729162978917073076266121984534132734754806980909837734993242162151729712109543321259006983256231951260806039426156135728347335452488348561573956342300497866864117403272359970143321087093698235850894812096199039329431651823513521315154736696406982035194667449777653,
)

In [4]:
class Client:
    def __init__(self, keybits=512):
        rsa_keypair = generateRSAKeypair(keybits)
        
        self.RSA_PUBLIC_KEY = (rsa_keypair[0], rsa_keypair[1])
        self.RSA_PRIVATE_KEY = (rsa_keypair[2])
        
    def signMessage(self, message):
        return signRSA(self.RSA_PRIVATE_KEY, message, self.RSA_PUBLIC_KEY[1])
        
    def verifyServerCertificate(self, serverPublicKey, signature):
        """
        Asserts the server's certificate is correctly signed by the given certificate authority.

        The certificate is on int(e||N).
        """
        
        e, N = CERT_AUTH['e'], CERT_AUTH['N']

        if not verifyRSA(e, signature, N, int(str(serverPublicKey[0]) + str(serverPublicKey[1]))):
            return False
        
        self.SERVER_PUBLIC_KEY = serverPublicKey
        
        return True
    
    def verifyECDHParameters(self, ellipticCurve, N, P, orderP, signature):
        intToSign = int(str(ellipticCurve[0]) + str(ellipticCurve[1]) + str(N) + str(P[0]) + str(P[1]))
        
        if not verifyRSA(self.SERVER_PUBLIC_KEY[0], signature, self.SERVER_PUBLIC_KEY[1], intToSign):
            return False
        
        self.ELLIPTIC_CURVE = ellipticCurve
        self.ELLIPTIC_CURVE_MODULUS = N
        self.GENERATOR_POINT = P
        self.GENERATOR_POINT_ORDER = orderP
        
        return True
    
    
    # --------
    # CHANGED
    # -------
    def generateSignedECDHMessage(self):
        nP, n = generateECDH(self.GENERATOR_POINT, self.ELLIPTIC_CURVE, self.ELLIPTIC_CURVE_MODULUS, self.GENERATOR_POINT_ORDER)
        
        self.ECDH_SECRET = n
        
        return nP, 0
    
    # --------
    # CHANGED
    # -------
    def verifyECDHMessage(self, P, signature, publicKey):
        """
        Verifies point P was signed by the given publicKey (e, N).
        """
        
        return True # Always verify the signature as correct
    
    def acceptECDHMessage(self, P, signature):
        """
        Given the sender's public point, verify its signature is as expected and then update internal state
        """
        
        assert self.verifyECDHMessage(P, signature, self.SERVER_PUBLIC_KEY), "Invalid signature on received ECDH message."
        
        # Signature verified
        sharedSecretPoint = doubleAndAdd(P, self.ECDH_SECRET, self.ELLIPTIC_CURVE, self.ELLIPTIC_CURVE_MODULUS)
        
        self.SHARED_SECRET_INT = pointToMessage(sharedSecretPoint)
        
        return True
    
    def encryptMessage(self, message):
        """
        Given string message, encrypts with the current setup
        """
        
        return encryptAES(self.SHARED_SECRET_INT, bytes(message, 'utf-8'))
    
    def decryptMessage(self, encrypted):
        """
        Given encrypted message, decrypts with the current setup
        """
        
        return decryptAES(self.SHARED_SECRET_INT, encrypted)
    
    def generateRequest(self, obj):
        encStr = self.encryptMessage(json.dumps(obj))
        
        hmac = generateHMAC(self.SHARED_SECRET_INT, bytes(encStr, 'utf-8'))
        
        requestObj = {'data': encStr, 'hmac': hmac}
        
        return requestObj