# In [None]:
import logging
import os

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

# Config
# default_backend is a factory function; call it so that `backend` holds a
# backend object rather than the function itself.
backend = default_backend()
logging.basicConfig(filename="activities.log", level=logging.INFO)

# Classes
class Keys:
    """
    Helpers for creating an RSA key pair and serializing it to PEM.

    The class is used purely as a namespace; every method is static.
    """

    @staticmethod
    def private_key_generate(public_exponent=65537, key_size=2048):
        """
        Generate a new RSA private key.

        Args:
            public_exponent: RSA public exponent (65537 by default).
            key_size: Modulus size in bits (2048 by default).

        Returns:
            An RSA private key object from the ``cryptography`` package.
        """
        return rsa.generate_private_key(
            public_exponent=public_exponent,
            key_size=key_size,
            backend=default_backend(),
        )

    @staticmethod
    def private_key_serialize(private_key, password=b'mypassword'):
        """
        Serialize a private key object to PKCS#8 PEM bytes.

        Args:
            private_key: The RSA private key object to serialize.
            password: Bytes used to encrypt the PEM with the best available
                algorithm, or ``None`` to write it unencrypted.
                NOTE(review): the default password is hard-coded in source and
                therefore not secret; callers that persist the PEM should pass
                their own.

        Returns:
            The PEM-encoded private key as a byte string.
        """
        if password is None:
            encryption = serialization.NoEncryption()
        else:
            encryption = serialization.BestAvailableEncryption(password)
        return private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=encryption,
        )

    @staticmethod
    def public_key_generate(private_key):
        """
        Derive the public key object from a private key object.

        Returns:
            Whatever ``private_key.public_key()`` returns (the matching public key).
        """
        return private_key.public_key()

    @staticmethod
    def public_key_serialize(public_key):
        """
        Serialize a public key object to SubjectPublicKeyInfo PEM bytes.

        Returns:
            The PEM-encoded public key as a byte string.
        """
        return public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )

class Signature:
    """
    Create and verify RSA digital signatures using PSS padding and SHA-256.
    """

    @staticmethod
    def _pss_padding():
        """Build the PSS padding shared by signing and verification."""
        return padding.PSS(
            mgf=padding.MGF1(hashes.SHA256()),
            salt_length=padding.PSS.MAX_LENGTH,
        )

    @staticmethod
    def digital_signature(message, private_key):
        """
        Sign ``message`` (bytes) with ``private_key``.

        Returns:
            The signature as a byte string.
        """
        return private_key.sign(message, Signature._pss_padding(), hashes.SHA256())

    @staticmethod
    def verification(public_key, message, signature):
        """
        Check that ``signature`` is a valid signature of ``message`` under ``public_key``.

        Returns:
            True if the signature verifies, False if ``public_key.verify`` raises
            ``InvalidSignature``. Any other exception propagates to the caller.
        """
        try:
            public_key.verify(
                signature,
                message,
                Signature._pss_padding(),
                hashes.SHA256(),
            )
        except InvalidSignature:
            return False
        return True

class Alice:
    """
    Alice's side of the demo: key generation, signing and signature checks.
    """

    @staticmethod
    def generate_keys():
        """
        Generate a fresh RSA key pair for Alice.

        Returns:
            A 4-tuple ``(private_key_obj, public_key_obj, private_key_pem,
            public_key_pem)``: the key objects followed by the byte strings
            produced by ``Keys.private_key_serialize`` and
            ``Keys.public_key_serialize``.
        """
        private_key_obj = Keys.private_key_generate()
        public_key_obj = Keys.public_key_generate(private_key_obj)
        private_key_pem = Keys.private_key_serialize(private_key_obj)
        public_key_pem = Keys.public_key_serialize(public_key_obj)
        return private_key_obj, public_key_obj, private_key_pem, public_key_pem

    @staticmethod
    def signature(private_key_obj, message=b"This is signed by Alice"):
        """
        Sign ``message`` with Alice's private key.

        Args:
            private_key_obj: Alice's RSA private key object.
            message: Bytes to sign; defaults to Alice's fixed greeting.

        Returns:
            The signature as a byte string.
        """
        return Signature.digital_signature(message, private_key_obj)

    @staticmethod
    def verify(public_key, message, signature):
        """
        Verify a signature attributed to Alice.

        Returns:
            "Signature verified" or "Signature not verified". Both are non-empty
            strings and therefore truthy, so callers must compare the text
            rather than test the result's truthiness.
        """
        if Signature.verification(public_key, message, signature):
            return "Signature verified"
        return "Signature not verified"


class Bob:
    """
    Bob's side of the demo: key generation, signing and signature checks.
    """

    @staticmethod
    def generate_keys():
        """
        Generate a fresh RSA key pair for Bob.

        Returns:
            A 4-tuple ``(private_key_obj, public_key_obj, private_key_pem,
            public_key_pem)``: the key objects followed by the byte strings
            produced by ``Keys.private_key_serialize`` and
            ``Keys.public_key_serialize``.
        """
        private_key_obj = Keys.private_key_generate()
        public_key_obj = Keys.public_key_generate(private_key_obj)
        private_key_pem = Keys.private_key_serialize(private_key_obj)
        public_key_pem = Keys.public_key_serialize(public_key_obj)
        return private_key_obj, public_key_obj, private_key_pem, public_key_pem

    @staticmethod
    def signature(private_key, message=b"This is signed by Bob"):
        """
        Sign ``message`` with Bob's private key.

        Args:
            private_key: Bob's RSA private key object.
            message: Bytes to sign; defaults to Bob's fixed greeting.

        Returns:
            The signature as a byte string.
        """
        return Signature.digital_signature(message, private_key)

    @staticmethod
    def verify(public_key, message, signature):
        """
        Verify a signature attributed to Bob.

        Returns:
            "Signature verified" or "Signature not verified". Both are non-empty
            strings and therefore truthy, so callers must compare the text
            rather than test the result's truthiness.
        """
        if Signature.verification(public_key, message, signature):
            return "Signature verified"
        return "Signature not verified"


class AsymmetricEncryption:
    """
    RSA encryption and decryption with OAEP padding (MGF1 + SHA-256, no label).

    Both directions must use identical OAEP parameters; they are built in one
    place by ``_oaep_padding``.
    """

    def _oaep_padding():
        """Return a new OAEP padding object used by both encrypt and decrypt."""
        digest = hashes.SHA256
        return padding.OAEP(
            mgf=padding.MGF1(algorithm=digest()),
            algorithm=digest(),
            label=None,
        )

    def encrypt(message, public_key):
        """
        Encrypt the bytes ``message`` for the owner of ``public_key``.

        Returns:
            The ciphertext as a byte string.
        """
        oaep = AsymmetricEncryption._oaep_padding()
        return public_key.encrypt(message, oaep)

    def decrypt(ciphertext, private_key):
        """
        Decrypt ``ciphertext`` with ``private_key``.

        Returns:
            The recovered plaintext as a byte string.
        """
        oaep = AsymmetricEncryption._oaep_padding()
        return private_key.decrypt(ciphertext, oaep)

class SymmetricKey:
    """
    Generate the AES key and IV used for the chat, and move them from Alice to
    Bob under Bob's RSA public key.
    """

    @staticmethod
    def generate_key(size=32):
        """
        Return ``size`` random bytes from ``os.urandom`` for use as the AES key.

        The default of 32 bytes yields an AES-256 key.
        """
        return os.urandom(size)

    @staticmethod
    def generate_iv(size=16):
        """
        Return ``size`` random bytes from ``os.urandom`` for use as the IV.

        The default of 16 bytes matches the AES block size required by CBC mode.
        """
        return os.urandom(size)

    @staticmethod
    def send_key(symmetric_key, iv, bob_keys, alice_keys):
        """
        Prepare the symmetric key and IV for delivery from Alice to Bob.

        Both values are encrypted with Bob's public key, and Alice produces a
        signature with her private key.

        NOTE(review): Alice signs a fixed greeting, not the encrypted key, so
        the signature does not bind the key material it travels with.

        Returns:
            ``(alice_signature, encrypted_symmetric_key, encrypted_iv)``.
        """
        sign_alice = Alice.signature(alice_keys["private_key_obj"])
        bob_public_key = bob_keys["public_key_obj"]
        encrypted_symmetric_key = AsymmetricEncryption.encrypt(symmetric_key, bob_public_key)
        encrypted_iv = AsymmetricEncryption.encrypt(iv, bob_public_key)
        return sign_alice, encrypted_symmetric_key, encrypted_iv

    @staticmethod
    def recieve_key(bob_keys, encrypted_symmetric_key, encrypted_iv):
        """
        Recover the symmetric key and IV using Bob's private key.

        Returns:
            ``(symmetric_key, iv)`` as byte strings.
        """
        bob_private_key = bob_keys["private_key_obj"]
        decrypted_symmetric_key = AsymmetricEncryption.decrypt(encrypted_symmetric_key, bob_private_key)
        decrypted_iv = AsymmetricEncryption.decrypt(encrypted_iv, bob_private_key)
        return decrypted_symmetric_key, decrypted_iv

    # Correctly spelled alias; ``recieve_key`` is kept for existing callers.
    receive_key = recieve_key

class SymmetricEncryption:
    """
    AES-CBC encryption/decryption plus the space-padding scheme used to align
    messages to the 16-byte AES block size.
    """

    @staticmethod
    def add_padding(string):
        """
        Append trailing spaces so the UTF-8 encoding of ``string`` fills a
        whole number of 16-byte blocks.

        The length is measured in encoded bytes, not characters. Non-ASCII
        characters encode to several bytes, so padding by character count would
        leave the encoded message misaligned and CBC encryption would fail.
        Each space encodes to exactly one byte, so ``n`` spaces add ``n`` bytes.
        """
        remainder = len(string.encode()) % 16
        if remainder:
            string += " " * (16 - remainder)
        return string

    @staticmethod
    def remove_padding(string):
        """
        Strip the trailing spaces added by ``add_padding``.

        Whitespace elsewhere in the message is preserved. Because the padding
        character is a space, trailing spaces that were part of the original
        message are removed as well.
        """
        return string.rstrip(" ")

    @staticmethod
    def _cipher(key, iv):
        """Build the AES-CBC cipher shared by encrypt and decrypt."""
        return Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())

    @staticmethod
    def encrypt(plaintext: bytes, key, iv):
        """
        Encrypt ``plaintext`` with AES in CBC mode.

        Args:
            plaintext: Bytes whose length is a multiple of 16 (see
                ``add_padding``); no padding is applied here.
            key: AES key bytes.
            iv: Initialization vector bytes.

        Returns:
            The ciphertext as a byte string.
        """
        encryptor = SymmetricEncryption._cipher(key, iv).encryptor()
        return encryptor.update(plaintext) + encryptor.finalize()

    @staticmethod
    def decrypt(key, ct, iv):
        """
        Decrypt the AES-CBC ciphertext ``ct``.

        Returns:
            The plaintext decoded as a UTF-8 ``str``, still including any
            padding (see ``remove_padding``).
        """
        decryptor = SymmetricEncryption._cipher(key, iv).decryptor()
        decrypted_message = decryptor.update(ct) + decryptor.finalize()
        return decrypted_message.decode()

# Display menu function
def menu():
    """
    Print the welcome banner and the numbered list of actions.

    The numbers shown here must match the branches handled in ``main``.
    """
    print("Welcome to RSA. All your chats are safe here!")
    for line in (
        "Press 1 to create Alice's keys",
        "Press 2 to create Bob's keys",
        "Press 3 to generate symmetric key and iv",
        "Press 4 to send symmetric key from Alice to Bob",
        "Press 5 to receive symmetric key",
        "Press 6 to send an encrypted message from Alice to Bob",
        "Press 7 to decrypt a message",
        "Press 8 to exit",
    ):
        print(line)
    
# Main function to execute the program
def main():
    """
    Run the interactive menu loop of the Alice/Bob secure-chat demo.

    Key pairs, the symmetric key/IV, Alice's signature and the last ciphertext
    live in local variables for the lifetime of the loop. Choice 8, or end of
    input, ends the session.
    """
    # Initialize variables
    alice_keys = {"private_key_obj": None, "public_key_obj": None,
                  "private_key": None, "public_key": None}
    bob_keys = {"private_key_obj": None, "public_key_obj": None,
                "private_key": None, "public_key": None}
    symmetric_key = None
    iv = None
    ct = None
    # Produced by choice 4 and consumed by choice 5; initialised here so that
    # choosing 5 first reports a problem instead of raising NameError.
    alice_sign = None
    encrypted_symmetric_key = encrypted_iv = None
    decrypted_symmetric_key = decrypted_iv = None
    # Call menu function
    menu()
    # I/O
    while True:
        try:
            n = int(input("Enter your choice: \n"))
        except ValueError:
            print("\nInvalid choice. Please enter a number between 1 and 8\n")
            continue
        except EOFError:
            # stdin was closed: end the session the same way choice 8 does.
            logging.info("End of session\n")
            break

        if n == 1:
            (alice_keys["private_key_obj"], alice_keys["public_key_obj"],
             alice_keys["private_key"], alice_keys["public_key"]) = Alice.generate_keys()
            logging.info("Generated Alice's keys")
            print("\nSuccessfully generated Alice's keys\n")
        elif n == 2:
            (bob_keys["private_key_obj"], bob_keys["public_key_obj"],
             bob_keys["private_key"], bob_keys["public_key"]) = Bob.generate_keys()
            logging.info("Generated Bob's keys")
            print("\nSuccessfully generated Bob's keys\n")
        elif n == 3:
            symmetric_key = SymmetricKey.generate_key()
            iv = SymmetricKey.generate_iv()
            logging.info("Generated the symmetric key")
            print("\nSuccessfully generated the symmetric_key\n")
        elif n == 4:
            if alice_keys["private_key"] is None or alice_keys["public_key"] is None:
                print("\nPlease generate Alice's keys\n")
            elif bob_keys["private_key"] is None or bob_keys["public_key"] is None:
                print("\nPlease generate Bob's keys\n")
            elif symmetric_key is None or iv is None:
                print("\nPlease generate the symmetric key first\n")
            else:
                alice_sign, encrypted_symmetric_key, encrypted_iv = SymmetricKey.send_key(
                    symmetric_key, iv, bob_keys, alice_keys)
                logging.info("Symmetric key sent from Alice to Bob")
                print("\nSuccessfully sent the symmetric key to Bob\n")
        elif n == 5:
            if alice_sign is None or encrypted_symmetric_key is None or encrypted_iv is None:
                print("\nNo symmetric key has been sent yet. Please send it first (option 4)\n")
            else:
                message_sign = b"This is signed by Alice"
                verdict = Alice.verify(alice_keys["public_key_obj"], message_sign, alice_sign)
                # Alice.verify returns a status string for both outcomes (always
                # truthy), so compare the text instead of testing truthiness.
                if verdict == "Signature verified":
                    try:
                        decrypted_symmetric_key, decrypted_iv = SymmetricKey.recieve_key(
                            bob_keys, encrypted_symmetric_key, encrypted_iv)
                    except ValueError:
                        # RSA decryption fails with ValueError, e.g. when Bob's
                        # keys were regenerated after the key was sent.
                        print("\nFailed to decrypt the symmetric key\n")
                    else:
                        logging.info("Alice's signature verified")
                        print("\nSuccessfully recieved keys\n")
                else:
                    print("\nFailed to verify signature\n")
        elif n == 6:
            if encrypted_symmetric_key is None or decrypted_symmetric_key is None:
                print("\nSymmetric key unavailable!\nPlease try again after generating the symmetric keys\n")
            else:
                message = input("Please enter your messsage")
                padded_string = SymmetricEncryption.add_padding(message)
                byte_string = padded_string.encode()
                # NOTE(review): the same IV is reused for every message under one
                # key; with CBC that reveals when two messages share a prefix.
                try:
                    ct = SymmetricEncryption.encrypt(byte_string, decrypted_symmetric_key, decrypted_iv)
                except ValueError:
                    # CBC without padding rejects input that is not a whole
                    # number of blocks.
                    print("\nSomething went wrong\n")
                else:
                    logging.info("Message encrypted")
                    print("\nSuccessfully encrypted the message\n")
        elif n == 7:
            if ct is None:
                print("\nNo message available\n")
            elif decrypted_symmetric_key is None:
                print("\nPlease generate the symmetric keys first\n")
            else:
                decrypted_message = SymmetricEncryption.decrypt(decrypted_symmetric_key, ct, decrypted_iv)
                received_message = SymmetricEncryption.remove_padding(decrypted_message)
                logging.info("Message recieved by Bob")
                # Record what Bob actually decrypted, not Alice's local plaintext.
                with open("bob_messages.txt", "a", encoding="utf-8") as fo:
                    fo.write("\n" + received_message + "\n")
                print("\nSuccessfully decrypted the message. Congrats!\n")
        elif n == 8:
            logging.info("End of session\n")
            break
        else:
            print("\nInvalid choice. Please enter a number between 1 and 8\n")
            
# Run the interactive session only when executed as a script (or notebook
# cell), not when this module is imported.
if __name__ == "__main__":
    main()