<a href="https://colab.research.google.com/github/Saanvi-U/Information-Network-and-Security-Lab/blob/master/secure_key_management_system.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from cryptography.hazmat.primitives.asymmetric import rsa, padding, dh
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.padding import PKCS7
import os
import base64
import uuid
import logging

# Send INFO-level records from the root logger to the default handler so the
# KMS audit messages (key generation, encryption, decryption, revocation) show up.
logging.basicConfig(level="INFO")

# Key Management System (KMS) Class
class SecureKeyManagementSystem:
    """In-memory key management system for AES, RSA and Diffie-Hellman keys.

    Keys are held only in this object's dictionaries. Nothing is persisted, so
    every key, and every ciphertext that can only be opened with it, is lost
    when the instance is discarded.
    """

    # AES block size in bytes. This is also the length of the CBC IV that is
    # prepended to every ciphertext produced by ``encrypt_with_aes``.
    _AES_BLOCK_BYTES = 16

    def __init__(self):
        # key_id -> (aes_key, salt, password). The salt and the random
        # "password" are the PBKDF2 inputs the key was derived from.
        self.symmetric_keys = {}
        # user_id -> (rsa_private_key, rsa_public_key)
        self.asymmetric_keys = {}

    @staticmethod
    def _oaep():
        """Return the OAEP padding (MGF1-SHA256 / SHA-256, no label) used for RSA."""
        return padding.OAEP(
            mgf=padding.MGF1(algorithm=hashes.SHA256()),
            algorithm=hashes.SHA256(),
            label=None,
        )

    def _aes_key(self, key_id):
        """Return the raw AES key stored under *key_id*; raise ValueError if absent."""
        key_data = self.symmetric_keys.get(key_id)
        if key_data is None:
            raise ValueError("Key not found for the given key_id.")
        return key_data[0]

    def generate_aes_key(self, key_id):
        """Create and store a 256-bit AES key under *key_id*.

        The key is derived with PBKDF2-HMAC-SHA256 (100,000 iterations) from a
        random 16-byte input and a random 16-byte salt. An existing key with
        the same *key_id* is overwritten.

        Returns:
            The base64-encoded raw key as a ``str``.
        """
        salt = os.urandom(16)
        # NOTE(review): deriving from a random "password" adds nothing over
        # os.urandom(32). It is kept so the stored (key, salt, password) tuple
        # keeps its existing shape.
        password = os.urandom(16)
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        )
        key = kdf.derive(password)
        self.symmetric_keys[key_id] = (key, salt, password)
        logging.info("AES key generated for key_id: %s", key_id)
        return base64.b64encode(key).decode()

    def encrypt_with_aes(self, key_id, plaintext):
        """Encrypt *plaintext* (a ``str``, UTF-8 encoded) with AES-CBC and PKCS7 padding.

        Returns:
            base64 of ``IV || ciphertext`` as a ``str``. A fresh random IV is
            generated for every call.

        Raises:
            ValueError: if no AES key is stored under *key_id*.
        """
        # NOTE(review): CBC without a MAC is malleable. If decryption errors are
        # observable to an attacker, it is also open to padding-oracle attacks.
        # Moving to AES-GCM would change the ciphertext format, so this is
        # flagged here rather than changed.
        key = self._aes_key(key_id)
        iv = os.urandom(self._AES_BLOCK_BYTES)
        encryptor = Cipher(algorithms.AES(key), modes.CBC(iv)).encryptor()
        padder = PKCS7(algorithms.AES.block_size).padder()
        padded = padder.update(plaintext.encode()) + padder.finalize()
        ciphertext = encryptor.update(padded) + encryptor.finalize()
        logging.info("Data encrypted using AES for key_id: %s", key_id)
        return base64.b64encode(iv + ciphertext).decode()

    def decrypt_with_aes(self, key_id, encrypted_data):
        """Decrypt base64 ``IV || ciphertext`` produced by ``encrypt_with_aes``.

        Returns:
            The decrypted plaintext decoded as UTF-8 ``str``.

        Raises:
            ValueError: if the key is missing, the data is not valid base64,
                the data is shorter than IV + one block or not block-aligned,
                or the padding is invalid.
        """
        key = self._aes_key(key_id)
        raw = base64.b64decode(encrypted_data)
        block = self._AES_BLOCK_BYTES
        # PKCS7 always adds at least one block, so valid input is the IV plus
        # one or more whole blocks.
        if len(raw) < 2 * block or len(raw) % block:
            raise ValueError("Encrypted data is too short or not block-aligned.")
        iv, ciphertext = raw[:block], raw[block:]
        decryptor = Cipher(algorithms.AES(key), modes.CBC(iv)).decryptor()
        decrypted_padded = decryptor.update(ciphertext) + decryptor.finalize()
        unpadder = PKCS7(algorithms.AES.block_size).unpadder()
        plaintext = unpadder.update(decrypted_padded) + unpadder.finalize()
        logging.info("Data decrypted using AES for key_id: %s", key_id)
        return plaintext.decode()

    def generate_rsa_key_pair(self, user_id, key_size=2048):
        """Generate and store an RSA key pair (e = 65537) for *user_id*.

        Any existing pair for *user_id* is overwritten.

        Returns:
            The RSA public key object.
        """
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_size,
        )
        public_key = private_key.public_key()
        self.asymmetric_keys[user_id] = (private_key, public_key)
        logging.info("RSA key pair generated for user_id: %s", user_id)
        return public_key

    def encrypt_with_rsa(self, user_id, plaintext):
        """Encrypt *plaintext* (``str``, UTF-8 encoded) with *user_id*'s RSA public key.

        OAEP-SHA256 with the default 2048-bit key limits the encoded plaintext
        to 190 bytes. Longer input is rejected by the library.

        Returns:
            The base64-encoded ciphertext as a ``str``.

        Raises:
            ValueError: if no key pair is stored for *user_id*.
        """
        public_key = self.asymmetric_keys.get(user_id, (None, None))[1]
        if public_key is None:
            raise ValueError("Public key not found for the given user_id.")
        encrypted = public_key.encrypt(plaintext.encode(), self._oaep())
        logging.info("Data encrypted using RSA for user_id: %s", user_id)
        return base64.b64encode(encrypted).decode()

    def decrypt_with_rsa(self, user_id, encrypted_data):
        """Decrypt base64 RSA-OAEP ciphertext with *user_id*'s private key.

        Returns:
            The plaintext decoded as UTF-8 ``str``.

        Raises:
            ValueError: if no key pair is stored for *user_id*.
        """
        private_key = self.asymmetric_keys.get(user_id, (None, None))[0]
        if private_key is None:
            raise ValueError("Private key not found for the given user_id.")
        decrypted = private_key.decrypt(base64.b64decode(encrypted_data), self._oaep())
        logging.info("Data decrypted using RSA for user_id: %s", user_id)
        return decrypted.decode()

    def generate_diffie_hellman_key(self, parameters=None, *, generator=2, key_size=2048):
        """Generate a Diffie-Hellman key pair.

        Both parties of an exchange must use the same DH parameters. To create
        a key that can exchange with an existing one, pass
        ``existing_private_key.parameters()`` as *parameters*.

        Args:
            parameters: DH parameters to generate the key from. When None,
                fresh parameters are generated from *generator* and
                *key_size*, which is slow at 2048 bits.
            generator: group generator, used only when *parameters* is None.
            key_size: modulus size in bits, used only when *parameters* is None.

        Returns:
            ``(private_key, public_key_pem_bytes)``, where the public key is
            serialized as PEM SubjectPublicKeyInfo.
        """
        if parameters is None:
            parameters = dh.generate_parameters(generator=generator, key_size=key_size)
        private_key = parameters.generate_private_key()
        serialized_public_key = private_key.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        logging.info("Diffie-Hellman keys generated.")
        return private_key, serialized_public_key

    def key_revocation(self, key_id):
        """Remove every key stored under *key_id* (AES and/or RSA).

        Returns:
            The string ``"Key Revoked Successfully"``.

        Raises:
            ValueError: if *key_id* is in neither key store.
        """
        # An id may be present in both stores. Revoke it everywhere so that no
        # key material is left behind.
        stores = [s for s in (self.symmetric_keys, self.asymmetric_keys) if key_id in s]
        if not stores:
            raise ValueError("Key ID not found.")
        for store in stores:
            del store[key_id]
        logging.info("Key with ID %s revoked successfully.", key_id)
        return "Key Revoked Successfully"

# ---------------------- Test Cases ----------------------
# Demonstration run of the KMS. Keys and IVs are random, so the printed
# ciphertexts and DH public key differ on every run.

# Create an instance of the KMS
kms = SecureKeyManagementSystem()

# Test Case 1: Symmetric Key Management (AES round trip)
aes_key_id = str(uuid.uuid4())
aes_key = kms.generate_aes_key(aes_key_id)  # base64 of the raw key; not used below
aes_ciphertext = kms.encrypt_with_aes(aes_key_id, "Sensitive Data")
print("AES Encrypted Data:", aes_ciphertext)  # Check encrypted output
aes_decrypted_text = kms.decrypt_with_aes(aes_key_id, aes_ciphertext)
print("Decrypted AES:", aes_decrypted_text)  # Expected Output: Sensitive Data

# Test Case 2: Asymmetric Key Management (RSA-OAEP round trip)
rsa_user = "userRSA"
kms.generate_rsa_key_pair(rsa_user)
rsa_enc_data = kms.encrypt_with_rsa(rsa_user, "Confidential")
print("RSA Encrypted Data:", rsa_enc_data)  # Check encrypted output
rsa_dec_data = kms.decrypt_with_rsa(rsa_user, rsa_enc_data)
print("Decrypted RSA:", rsa_dec_data)  # Expected Output: Confidential


# Test Case 3: Diffie-Hellman Key Exchange (key generation only)
dh_private, dh_public = kms.generate_diffie_hellman_key()
print("Serialized DH Public Key:", dh_public.decode())

# Test Case 4: Key Revocation Test
revocation_result = kms.key_revocation(aes_key_id)
print("Revocation Result:", revocation_result)  # Expected: Key Revoked Successfully

# Additional test: decrypting after revocation must fail with ValueError
# (the "Key not found" error from the KMS). Any other exception is a real bug
# and is allowed to propagate.
try:
    kms.decrypt_with_aes(aes_key_id, aes_ciphertext)
except ValueError as e:
    print("Expected error after key revocation:", e)
else:
    print("UNEXPECTED: decryption succeeded after key revocation")


AES Encrypted Data: 5Y3swX71aELJ/8D77KN1ns+2uc3skrsOAZ01FMfAM0g=
Decrypted AES: Sensitive Data
RSA Encrypted Data: ZCvaMdABZ9DlLTUSKm4x7MkHHL06n4Hs/28hyRIX6yQgHGATg+y+FsAUI0bPFNpgYxGrBYGks4mRnH5SAfU/jlbkKtkkL8heDbPENJWg7u0cnJnzXEfTeJ/ROorPTxeL+pkk46rEevvLtSVB35HvBsQ/OvckwjFLcpvzDZL8Wox/WHb4WKlUd9oNDpmLtVbkYiW57gkX0ZC0NdwbXnre4A8SC45V7V0H29GPpbecacsVP27pA05GKuylp0hPYyeZH7HxIwHSO7bux8fusFzAAJb++DD7LxPa1hPuIzplqdVkll9suSB1R5sSUfV8RR+Orr5LSpYz//1qlMHP41GyZw==
Decrypted RSA: Confidential
Serialized DH Public Key: -----BEGIN PUBLIC KEY-----
MIICJTCCARcGCSqGSIb3DQEDATCCAQgCggEBAP1zwZk+6fifedkm38lNVuZ884yT
mtpN8d+Z53+XQzBRRhbwi0bWlZ0STjQ6m2w3z8Z12fMiY5x2ShTopX8oGtxfrPa+
wHT0yW+XaEqOWHtyQcNSqXsKxZaV6Jaa2HML7mfur1QDfpIaUqWox/yjgbme03h+
fs77DYZqgoEtFbk3soacQp7SUw0E2aKkZSq1LyQfVFNmKEsdvkZoFTHMxJ0pfzIb
8jid5+5V3gRQLLGaD9H8nQYlyKzSUUfOK80Wzz4NBW7azc/sQ6BB9a2j1vn9bWVa
Iqh+O3XHHo6J0jrm/Zz9A5mkn10j+bR5a/iUCiTnz6VfpiwCIM51b8AjpVcCAQID
ggEGAAKCAQEAkTzf/ANcuvzMP4kSrHGWl6dy9e/rmRlh0EpiVWzQRPJMtsmJYMd7
Jzzk