<a href="https://colab.research.google.com/github/amranzoabi/Cryptography/blob/main/Crpytology_SM4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

An application for secure payment with server identification, using Schnorr signatures for authentication, SM4 for encryption/decryption, and elliptic-curve Diffie–Hellman (ECDH) for key agreement.
# ** INSTRUCTIONS **
Run all cells in order. The last cell is the main program and is interactive (it prompts for input).

SM4 Algorithm

In [None]:
import struct
## CORE ##
# SM4 S-box: a 256-entry byte substitution table (16 rows x 16 entries).
# tau() looks each of the four bytes of a 32-bit word up in this table.
S = (
    0xd6, 0x90, 0xe9, 0xfe, 0xcc, 0xe1, 0x3d, 0xb7, 0x16, 0xb6, 0x14, 0xc2, 0x28, 0xfb, 0x2c, 0x05,
    0x2b, 0x67, 0x9a, 0x76, 0x2a, 0xbe, 0x04, 0xc3, 0xaa, 0x44, 0x13, 0x26, 0x49, 0x86, 0x06, 0x99,
    0x9c, 0x42, 0x50, 0xf4, 0x91, 0xef, 0x98, 0x7a, 0x33, 0x54, 0x0b, 0x43, 0xed, 0xcf, 0xac, 0x62,
    0xe4, 0xb3, 0x1c, 0xa9, 0xc9, 0x08, 0xe8, 0x95, 0x80, 0xdf, 0x94, 0xfa, 0x75, 0x8f, 0x3f, 0xa6,
    0x47, 0x07, 0xa7, 0xfc, 0xf3, 0x73, 0x17, 0xba, 0x83, 0x59, 0x3c, 0x19, 0xe6, 0x85, 0x4f, 0xa8,
    0x68, 0x6b, 0x81, 0xb2, 0x71, 0x64, 0xda, 0x8b, 0xf8, 0xeb, 0x0f, 0x4b, 0x70, 0x56, 0x9d, 0x35,
    0x1e, 0x24, 0x0e, 0x5e, 0x63, 0x58, 0xd1, 0xa2, 0x25, 0x22, 0x7c, 0x3b, 0x01, 0x21, 0x78, 0x87,
    0xd4, 0x00, 0x46, 0x57, 0x9f, 0xd3, 0x27, 0x52, 0x4c, 0x36, 0x02, 0xe7, 0xa0, 0xc4, 0xc8, 0x9e,
    0xea, 0xbf, 0x8a, 0xd2, 0x40, 0xc7, 0x38, 0xb5, 0xa3, 0xf7, 0xf2, 0xce, 0xf9, 0x61, 0x15, 0xa1,
    0xe0, 0xae, 0x5d, 0xa4, 0x9b, 0x34, 0x1a, 0x55, 0xad, 0x93, 0x32, 0x30, 0xf5, 0x8c, 0xb1, 0xe3,
    0x1d, 0xf6, 0xe2, 0x2e, 0x82, 0x66, 0xca, 0x60, 0xc0, 0x29, 0x23, 0xab, 0x0d, 0x53, 0x4e, 0x6f,
    0xd5, 0xdb, 0x37, 0x45, 0xde, 0xfd, 0x8e, 0x2f, 0x03, 0xff, 0x6a, 0x72, 0x6d, 0x6c, 0x5b, 0x51,
    0x8d, 0x1b, 0xaf, 0x92, 0xbb, 0xdd, 0xbc, 0x7f, 0x11, 0xd9, 0x5c, 0x41, 0x1f, 0x10, 0x5a, 0xd8,
    0x0a, 0xc1, 0x31, 0x88, 0xa5, 0xcd, 0x7b, 0xbd, 0x2d, 0x74, 0xd0, 0x12, 0xb8, 0xe5, 0xb4, 0xb0,
    0x89, 0x69, 0x97, 0x4a, 0x0c, 0x96, 0x77, 0x7e, 0x65, 0xb9, 0xf1, 0x09, 0xc5, 0x6e, 0xc6, 0x84,
    0x18, 0xf0, 0x7d, 0xec, 0x3a, 0xdc, 0x4d, 0x20, 0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48,
)

# FK system parameters: derive_keys() XORs these into the four 32-bit words
# of the user key before running the key schedule.
FK = (
    0xa3b1bac6, 0x56aa3350, 0x677d9197, 0xb27022dc,
)

# CK round constants: derive_keys() consumes one per key-schedule round and
# yields one round key per entry, so these 32 values produce 32 round keys.
CK = (
    0x00070e15, 0x1c232a31, 0x383f464d, 0x545b6269,
    0x70777e85, 0x8c939aa1, 0xa8afb6bd, 0xc4cbd2d9,
    0xe0e7eef5, 0xfc030a11, 0x181f262d, 0x343b4249,
    0x50575e65, 0x6c737a81, 0x888f969d, 0xa4abb2b9,
    0xc0c7ced5, 0xdce3eaf1, 0xf8ff060d, 0x141b2229,
    0x30373e45, 0x4c535a61, 0x686f767d, 0x848b9299,
    0xa0a7aeb5, 0xbcc3cad1, 0xd8dfe6ed, 0xf4fb0209,
    0x10171e25, 0x2c333a41, 0x484f565d, 0x646b7279,
)


def rotate_left(i32, k):
    """Cyclically rotate the 32-bit word ``i32`` left by ``k`` bits.

    Bits shifted out past bit 31 re-enter at bit 0. Intended for 0 <= k < 32.
    """
    shifted_up = (i32 << k) & 0xffffffff   # keep only the low 32 bits
    wrapped = i32 >> (32 - k)               # bits that fell off the top
    return shifted_up | wrapped


def tau(i32):
    """SM4 non-linear transform: pass each byte of a 32-bit word through ``S``.

    The top byte is taken as ``i32 >> 24`` without masking, so the input is
    expected to already be a 32-bit value.
    """
    byte3 = i32 >> 24
    byte2 = (i32 >> 16) & 0xff
    byte1 = (i32 >> 8) & 0xff
    byte0 = i32 & 0xff
    return (S[byte3] << 24) | (S[byte2] << 16) | (S[byte1] << 8) | S[byte0]


def linear_substitution_0(i32):
    """Round-function linear transform used by encode_block.

    Returns ``i32`` XOR its left rotations by 2, 10, 18 and 24 bits.
    """
    mixed = i32
    for amount in (2, 10, 18, 24):
        mixed ^= rotate_left(i32, amount)
    return mixed


def linear_substitution_1(i32):
    """Key-schedule linear transform used by derive_keys.

    Returns ``i32`` XOR its left rotations by 13 and 23 bits.
    """
    mixed = i32
    for amount in (13, 23):
        mixed ^= rotate_left(i32, amount)
    return mixed


def derive_keys(key):
    """Lazily generate the SM4 round keys for a 16-byte ``key``.

    The key is read as four big-endian 32-bit words, each XORed with its
    ``FK`` parameter. Then one round key is yielded per ``CK`` constant,
    each computed from a sliding window of the last four words.
    """
    window = [word ^ fk for word, fk in zip(struct.unpack(">IIII", key), FK)]
    for ck in CK:
        round_key = window[0] ^ linear_substitution_1(tau(window[1] ^ window[2] ^ window[3] ^ ck))
        window = window[1:] + [round_key]
        yield round_key


def encode_block(block, derived_keys):
    """Run the SM4 rounds over one block of four 32-bit words.

    Each round key drives one round; the four words are returned in reverse
    order (the final "R" transform). With round keys in reverse order this
    same routine decrypts.
    """
    w0, w1, w2, w3 = block  # unpack: rejects blocks that are not exactly 4 words
    state = [w0, w1, w2, w3]
    for round_key in derived_keys:
        newest = state[0] ^ linear_substitution_0(tau(state[1] ^ state[2] ^ state[3] ^ round_key))
        state = state[1:] + [newest]
    return tuple(reversed(state))
## CORE END ##

## COMPATIBILITY ##

# Python 2 provides the lazy `xrange`; Python 3 removed it and its `range` is
# already lazy. handle() uses `iter_range` to step through the message 16
# bytes at a time.
try:
    iter_range = xrange
except NameError:
    iter_range = range

## COMPATIBILITY END ##

class SM4Key(object):
    """A class for encryption using SM4 Key.

    The 16-byte key is expanded into its round keys once, in the constructor.
    Encryption uses them in order; decryption uses the same rounds with the
    keys reversed. Two instances compare (and hash) equal when their round
    keys are equal, i.e. when they were built from the same key.
    """

    def __init__(self, key):
        """Expand ``key`` into encryption and decryption round keys.

        :param key: {union[bytes, bytearray]} The 16-byte SM4 key
        """
        self.__encryption_key = guard_key(key)
        # Decryption is the same round function driven by the reversed keys.
        self.__decryption_key = self.__encryption_key[::-1]
        self.__key = key

    def encrypt(self, message, initial=None, padding=True):
        """Encrypts the message with the key object.

        :param message: {bytes} The message to be encrypted
        :param initial: {union[bytes, NoneType]} The initial value, using CBC Mode when is not None
        :param padding: {any} Uses PKCS#7 padding (16-byte blocks) when TRUTHY
        :return: {bytes} Encrypted bytes
        """
        return handle(message, self.__encryption_key, initial, padding, 1)

    def decrypt(self, message, initial=None, padding=True):
        """Decrypts the encrypted message with the key object.

        :param message: {bytes} The message to be decrypted
        :param initial: {union[bytes, NoneType]} The initial value, using CBC Mode when is not None
        :param padding: {any} Strips PKCS#7 padding (16-byte blocks) when TRUTHY
        :return: {bytes} Decrypted bytes
        """
        return handle(message, self.__decryption_key, initial, padding, 0)

    def __eq__(self, other):
        """Compare by round keys, so equality agrees with ``__hash__``."""
        if type(other) is not type(self):
            return NotImplemented
        return self.__encryption_key == other.__encryption_key

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so keep the two explicitly in sync.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash((self.__class__, self.__encryption_key))


def guard_key(key):
    """Validate an SM4 key and expand it into a tuple of round keys.

    :param key: {union[bytes, bytearray]} must be exactly 16 bytes
    :return: {tuple[int]} round keys produced by derive_keys
    :raises AssertionError: on a wrong type or length
    """
    normalized = bytes(key) if isinstance(key, bytearray) else key

    assert isinstance(normalized, bytes), "The key should be `bytes` or `bytearray`"
    assert len(normalized) == 16, "The key should be of length 16"

    round_keys = derive_keys(normalized)
    return tuple(round_keys)


def guard_message(message, padding, encryption):
    """Validate a message and, when encrypting with padding, PKCS#7-pad it.

    :param message: {union[bytes, bytearray]} data to process
    :param padding: {any} truthy to pad (encryption only)
    :param encryption: {any} truthy in encryption mode
    :return: {bytes} the (possibly padded) message, a multiple of 16 bytes long
    :raises AssertionError: on a non-bytes message, or an unpadded length
        that is not a multiple of 16
    """
    # Accept bytearray like guard_key/guard_initial do.
    if isinstance(message, bytearray):
        message = bytes(message)
    assert isinstance(message, bytes), "The message should be bytes"
    length = len(message)
    if encryption and padding:
        # PKCS#7: always append 1..16 bytes, each equal to the pad length, so
        # padding is present (and removable) even for block-aligned input.
        pad_len = 16 - (length & 15)
        return message + chr(pad_len).encode() * pad_len

    assert length & 15 == 0, (
        "The length of the message should be divisible by 16 "
        "(or set `padding` to `True` in encryption mode)"
    )
    return message


def guard_initial(initial):
    """Validate the CBC initial value and unpack it into four 32-bit words.

    :param initial: {union[bytes, bytearray, NoneType]} None selects ECB mode
    :return: None when ``initial`` is None, else a 4-tuple of big-endian words
    :raises AssertionError: on a wrong type or a length other than 16
    """
    if initial is None:
        return None

    if isinstance(initial, bytearray):
        initial = bytes(initial)

    assert isinstance(initial, bytes), "The initial value should be of type `bytes` or `bytearray`"
    # Exactly one block: the previous `& 15 == 0` test also let 0, 32, 48...
    # byte values through, which then failed inside struct.unpack.
    assert len(initial) == 16, "The initial value should be of length 16"
    return struct.unpack(">IIII", initial)


def _strip_pkcs7_padding(data):
    """Validate and remove PKCS#7 padding (16-byte blocks) from ``data``."""
    if not data:
        raise ValueError("Invalid padding: nothing to unpad")
    pad_len = ord(data[-1:])
    # A valid pad is 1..16 copies of its own length; anything else means a
    # wrong key or corrupted ciphertext, and slicing by it would return junk
    # (a trailing 0 byte would even turn `data[:-0]` into b"").
    if not 1 <= pad_len <= 16 or data[-pad_len:] != data[-1:] * pad_len:
        raise ValueError("Invalid padding")
    return data[:-pad_len]


def handle(message, key, initial, padding, encryption):
    """Run SM4 over a whole message in ECB (no IV) or CBC (with IV) mode.

    :param message: {bytes} input data
    :param key: {tuple[int]} round keys (reversed by the caller for decryption)
    :param initial: {union[bytes, bytearray, NoneType]} IV; None selects ECB
    :param padding: {any} truthy to add PKCS#7 padding when encrypting and
        strip it when decrypting
    :param encryption: {int} 1 to encrypt, 0 to decrypt
    :return: {bytes} the transformed message
    :raises ValueError: when decrypting with padding and the result does not
        end in valid PKCS#7 padding
    """
    message = guard_message(message, padding, encryption)
    initial = guard_initial(initial)

    blocks = (struct.unpack(">IIII", message[i: i + 16]) for i in iter_range(0, len(message), 16))

    if initial is None:
        # ECB
        encoded_blocks = ecb(blocks, key)
    else:
        # CBC
        encoded_blocks = cbc(blocks, key, initial, encryption)

    ret = b"".join(struct.pack(">IIII", *block) for block in encoded_blocks)
    if encryption or not padding:
        return ret
    return _strip_pkcs7_padding(ret)


def ecb(blocks, key):
    """Electronic Codebook mode: transform every block independently.

    A generator, so blocks are consumed and produced lazily.
    """
    for source_block in blocks:
        transformed = encode_block(source_block, key)
        yield transformed


def cbc(blocks, key, initial, encryption):
    """Cipher Block Chaining mode over 4-word blocks (a lazy generator).

    Encrypting: each plaintext block is XORed with the previous ciphertext
    block (the IV for the first) before the cipher runs.
    Decrypting: each block is run through the cipher, then XORed with the
    previous ciphertext block.
    """
    chain = initial
    for current in blocks:
        if encryption:
            chain = encode_block(tuple(c ^ v for c, v in zip(current, chain)), key)
            yield chain
        else:
            deciphered = encode_block(current, key)
            plain = tuple(d ^ v for d, v in zip(deciphered, chain))
            chain = current
            yield plain


if hasattr(bytes, "fromhex"):
    def h2b(byte_string):
        """Decode a hexadecimal string such as ``"0a1b"`` into ``bytes``."""
        return bytes.fromhex(byte_string)
else:
    # Python 2: `bytes` is `str` and has no fromhex, so go through bytearray.
    def h2b(byte_string):
        """Decode a hexadecimal string such as ``"0a1b"`` into ``bytes``."""
        return bytes(bytearray.fromhex(byte_string))

Elliptic Curve Diffie Hellman

In [None]:
import random

def add_points(p1, p2, prime=None, coef_a=None):
    """Add two points on the curve y^2 = x^3 + a*x + b over GF(prime).

    The point at infinity (the group identity) is represented as
    ``(float('inf'), float('inf'))``. The coefficient ``b`` does not appear
    in the addition formulas, so it is not needed.

    :param p1: first point, an (x, y) tuple or the point at infinity
    :param p2: second point, an (x, y) tuple or the point at infinity
    :param prime: field prime; when omitted, the module-level ``p`` is used
        (the original behaviour, which ``multiply_point`` relies on)
    :param coef_a: curve coefficient ``a``; when omitted, the module-level ``a``
    :return: the point p1 + p2
    """
    if prime is None:
        prime = p
    if coef_a is None:
        coef_a = a

    infinity = (float('inf'), float('inf'))
    if p1 == infinity:
        return p2
    if p2 == infinity:
        return p1

    x1, y1 = p1
    x2, y2 = p2

    # P + (-P) = O. This also covers doubling a point with y == 0, whose
    # tangent is vertical.
    if x1 == x2 and y1 == -y2 % prime:
        return infinity

    # Inverses via Fermat's little theorem (z^(prime-2) == z^-1 mod prime).
    if x1 == x2 and y1 == y2:
        slope = (3 * x1 * x1 + coef_a) * pow(2 * y1, prime - 2, prime) % prime  # tangent
    else:
        slope = (y1 - y2) * pow(x1 - x2, prime - 2, prime) % prime  # chord

    x3 = (slope * slope - x1 - x2) % prime
    y3 = (slope * (x1 - x3) - y1) % prime

    return (x3, y3)

def multiply_point(k, p):
    """Return the scalar multiple ``k * p`` via left-to-right double-and-add.

    Note that ``p`` here is a curve point, not the field prime. The result
    starts at the point at infinity ``(inf, inf)``, is doubled for every
    binary digit of ``k``, and gets ``p`` added for each 1-bit.
    """
    total = (float('inf'), float('inf'))
    for digit in bin(k)[2:]:
        total = add_points(total, total)
        if digit == '1':
            total = add_points(total, p)
    return total

def generate_ecdh_key_pair(p, a, b, G):
    """Generate an ECDH key pair with base point ``G``.

    :param p: {int} field prime; the private key is drawn uniformly from [1, p - 1]
    :param a: curve coefficient (unused here; point arithmetic reads the module-level curve)
    :param b: curve coefficient (unused here)
    :param G: base point (x, y)
    :return: (private_key, public_key), where public_key = private_key * G
    """
    # A private key must be unpredictable: `random` (Mersenne Twister) is
    # reconstructible from its outputs, so use the OS CSPRNG instead.
    import secrets

    # NOTE(review): ideally the key is drawn below the order of G, which is
    # not computed in this notebook; [1, p - 1] is kept from the original.
    private_key = secrets.randbelow(p - 1) + 1

    public_key = multiply_point(private_key, G)

    return private_key, public_key

def derive_shared_secret(private_key, peer_public_key, p, a, b, G):
    """Return the ECDH shared secret: the x-coordinate of private_key * peer_public_key.

    ``p``, ``a``, ``b`` and ``G`` are accepted for symmetry with
    ``generate_ecdh_key_pair`` but are not used; the curve arithmetic in
    ``add_points`` reads the module-level ``p`` and ``a``.

    :raises ValueError: if the product is the point at infinity, which has no
        usable x-coordinate (it would otherwise surface later as ``float('inf')``
        failing a ``.to_bytes`` conversion).
    """
    shared_point = multiply_point(private_key, peer_public_key)
    if shared_point == (float('inf'), float('inf')):
        raise ValueError("ECDH produced the point at infinity; no shared secret")
    return shared_point[0]

Schnorr Digital Signature

In [None]:
import hashlib
from hashlib import sha256
from random import randint
'''
# Function to generate a prime number
def generate_prime():
    # Generate a random prime number (simplified for demonstration purposes)
    # In practice, you should use a more robust prime generation algorithm
    prime = 23
    return prime
'''
# Function to compute modular exponentiation (base^exponent mod modulus)
def mod_exp(base, exponent, modulus):
    """Return ``base ** exponent mod modulus`` for a non-negative ``exponent``.

    Delegates to the built-in three-argument ``pow`` (square-and-multiply in C),
    which stays fast for very large exponents such as 256-bit hash values.
    (The previous hand-written loop was corrupted into a syntax error.)

    :raises ValueError: if ``exponent`` is negative; modular inverses are a
        separate operation (see ``mod_inverse``).
    """
    if exponent < 0:
        raise ValueError("exponent must be non-negative")
    return pow(base, exponent, modulus)

# Function to compute the modular inverse (a^-1 mod modulus)
def mod_inverse(a, modulus):
    """Return x in [1, modulus - 1] with (a * x) % modulus == 1.

    Uses the iterative extended Euclidean algorithm: O(log modulus) steps,
    instead of a linear search over every candidate.

    :raises ValueError: if no inverse exists (gcd(a, modulus) != 1, or modulus <= 1).
    """
    if modulus <= 1:
        raise ValueError("Inverse does not exist.")
    # Invariant: t_i * a == r_i (mod modulus) for both rows.
    r_prev, r_cur = modulus, a % modulus
    t_prev, t_cur = 0, 1
    while r_cur:
        quotient = r_prev // r_cur
        r_prev, r_cur = r_cur, r_prev - quotient * r_cur
        t_prev, t_cur = t_cur, t_prev - quotient * t_cur
    if r_prev != 1:
        raise ValueError("Inverse does not exist.")
    return t_prev % modulus

def _schnorr_challenge(r, message, prime):
    """Return the Schnorr challenge e = SHA-256(r || message) as an integer.

    Binding the commitment ``r`` into the hash is what prevents forgery:
    with e = H(message) alone, anyone could pick ``s`` and solve for
    r = g^s * y^e. ``r`` is encoded big-endian in the byte width of
    ``prime``, so signer and verifier always hash identical bytes.
    """
    width = (prime.bit_length() + 7) // 8
    digest = hashlib.sha256(r.to_bytes(width, byteorder='big') + message).hexdigest()
    return int(digest, 16)


# Function to compute the Schnorr signature
def schnorr_sign(message, prime, generator, secret_key):
    """Sign ``message`` with a Schnorr signature in the group of integers mod ``prime``.

    :param message: {bytes} data to sign; carried inside the returned packet
    :param prime: {int} public prime modulus
    :param generator: {int} public group base g
    :param secret_key: {int} signer's private key x (public key y = g^x mod prime)
    :return: {Packet} packet holding commitment r, response s and message m
    """
    # A guessable or repeated nonce k reveals secret_key, so draw it from the
    # CSPRNG. k = prime - 1 is excluded because g^(prime-1) == 1.
    import secrets

    # Step 1: fresh random nonce k in [1, prime - 2]
    k = secrets.randbelow(max(prime - 2, 1)) + 1

    # Step 2: commitment r = g^k
    r = mod_exp(generator, k, prime)

    # Step 3: challenge e = H(r || m)
    e = _schnorr_challenge(r, message, prime)

    # Step 4: response s = k - e*x, reduced mod the group exponent (prime - 1)
    s = (k - e * secret_key) % (prime - 1)

    return Packet(r, s, message)

# Function to verify the Schnorr signature
def schnorr_verify(signature, prime, generator, public_key):
    """Check a packet produced by ``schnorr_sign`` against ``public_key``.

    Recomputes e = H(r || m) and accepts iff g^s * y^e == r (mod prime),
    which holds for honest signatures because g^(k - e*x) * g^(x*e) = g^k.

    :return: {bool} True when the signature is valid
    """
    r, s = signature.r, signature.s

    # Reject out-of-range values up front: honest r lies in [1, prime - 1] and
    # honest s in [0, prime - 2].
    if not 0 < r < prime or not 0 <= s < prime - 1:
        return False

    e = _schnorr_challenge(r, signature.m, prime)
    v = (mod_exp(generator, s, prime) * mod_exp(public_key, e, prime)) % prime

    return v == r

Utilities

In [None]:
import random
import secrets
import sympy


class Packet:
    """A signed message: commitment ``r``, response ``s`` and payload ``m``."""

    def __init__(self, r, s, m):
        self.r, self.s, self.m = r, s, m



def generate_random_prime(bits):
    """Return a random prime whose bit length is exactly ``bits``.

    Each draw starts from a random ``bits``-bit number with the top bit forced
    on (so it is never too short) and advances to the next prime with
    ``sympy.nextprime``. If that prime overflows to ``bits + 1`` bits, the
    draw is repeated.

    :raises ValueError: if ``bits < 2``. No prime is shorter than 2 bits,
        so the search would otherwise never terminate.
    """
    if bits < 2:
        raise ValueError("bits must be at least 2")
    while True:
        start = random.getrandbits(bits) | (1 << (bits - 1))
        prime = sympy.nextprime(start)
        if prime.bit_length() == bits:
            return prime

def enlarge_key(key):
    """Derive a 16-byte SM4 key: the first 16 bytes of SHA-256(``key``)."""
    digest = hashlib.sha256(key).digest()
    return digest[:16]

def split_string_at_character(string, character):
    """Return the part of ``string`` before the first ``character`` (all of it if absent)."""
    head, _sep, _tail = string.partition(character)
    return head

SM4+ECDH+Schnorr attempt

Main

In [None]:
import random
import secrets
import sympy
from random import randint

#--- Schnorr Parameters ---#
# 256-bit modulus: with a 16-bit prime the discrete logarithm protecting the
# signing keys can be brute-forced in milliseconds.
SCHNORR_PRIME_BITS = 256
sPrime = generate_random_prime(SCHNORR_PRIME_BITS)
# Fixed base for the Schnorr group.
# NOTE(review): 2 is not guaranteed to generate the full multiplicative group mod sPrime.
sGenerator = 2
# Server's signing key pair (secret drawn from the CSPRNG).
s_SK = secrets.randbelow(sPrime - 2) + 1
s_PK = mod_exp(sGenerator, s_SK, sPrime)
# Alice's own signing key pair. Her messages must be signed with HER secret
# key; signing them with the server's key would prove nothing about Alice.
a_SK = secrets.randbelow(sPrime - 2) + 1
a_PK = mod_exp(sGenerator, a_SK, sPrime)
#--- Schnorr Parameters ---#

# Generate a random prime number with 256 bits - p for ECDH
p = generate_random_prime(256)

# Pick a random coefficient a and base point G = (x, y), then derive b so that
# G really lies on y^2 = x^3 + a*x + b. An independently random b would leave
# G off the advertised curve. Reject singular curves (4a^3 + 27b^2 == 0) and
# y == 0 (a point of order 2).
while True:
    a = secrets.randbelow(p)
    x = secrets.randbelow(p)
    y = secrets.randbelow(p)
    b = (y * y - pow(x, 3, p) - a * x) % p
    if y != 0 and (4 * pow(a, 3, p) + 27 * b * b) % p != 0:
        break
G = (x, y)

#print global ECDH values
print("[Global] \nGlobal ECDH values:\np = "+str(p)+"\nG = "+str(G)+"\na = "+str(a)+", b="+str(b))

#Generate key pair for Alice
alice_private_key, alice_public_key = generate_ecdh_key_pair(p, a, b, G)

# Generate key pair for Server
server_private_key, server_public_key = generate_ecdh_key_pair(p, a, b, G)

print("[Alice]\nAlice's private ECDH key = " + str(alice_private_key))
print("[Server]\nServer's private ECDH key = " + str(server_private_key))
print("[Alice]\nAlice's public ECDH key = " + str(alice_public_key))
print("[Server]\nServer's public ECDH key = " + str(server_public_key))


# The public key is a point (x, y), so each coordinate travels in its own signed packet.
server_public_key_bytes_1 = server_public_key[0].to_bytes(32, byteorder='big')
serverPacket1 = schnorr_sign(server_public_key_bytes_1, sPrime, sGenerator, s_SK)
print("\n\n*** [Server] \tSending packet containing first section of the server's public ECDH key to Alice for identity verification")
server_public_key_bytes_2 = server_public_key[1].to_bytes(32, byteorder='big')
serverPacket2 = schnorr_sign(server_public_key_bytes_2, sPrime, sGenerator, s_SK)
print("*** [Server] \tSending packet containing second section of the server's public ECDH key to Alice for identity verification")
valid1 = schnorr_verify(serverPacket1, sPrime, sGenerator, s_PK)  # verifying first packet
valid2 = schnorr_verify(serverPacket2, sPrime, sGenerator, s_PK)  # verifying second packet
valid = valid1 and valid2  # both halves must verify
if valid:
    print("[Alice] \tMessage accepted and continuing to create the shared key!")
    # Alice restores the server's public key from the two verified packets.
    serverPK = (int.from_bytes(serverPacket1.m, byteorder='big'),
                int.from_bytes(serverPacket2.m, byteorder='big'))
    alice_shared_secret = derive_shared_secret(alice_private_key, serverPK, p, a, b, G)
    print("[Alice] \tAlice's shared ECDH Key: " + str(alice_shared_secret))

    # Server derives shared secret using Alice's public key
    server_shared_secret = derive_shared_secret(server_private_key, alice_public_key, p, a, b, G)
    print("[Server] \tServer's shared ECDH Key: " + str(server_shared_secret))

    aSM4Key = enlarge_key(alice_shared_secret.to_bytes(32, byteorder='big'))
    print("[Alice] \tAlice's SM4 key derived by enlarging the ECDH key: " + str(aSM4Key))

    sSM4Key = enlarge_key(server_shared_secret.to_bytes(32, byteorder='big'))
    print("[Server] \tServer's SM4 key derived by enlarging the ECDH key: " + str(sSM4Key))

    # Initiating the key objects respectively for each key
    aliceSM4K = SM4Key(aSM4Key)  # SM4Key object for Alice
    serverSM4K = SM4Key(sSM4Key)  # SM4Key object for Server

    print("\n\n*** [Server] \tConnection Established \t***\n")
    # Alice performing payment
    print("** Login Process **")
    print("[Alice] Username: ")
    username = input()
    # '|' separates the fields; a username containing it would be truncated by the server.
    while '|' in username:
        print("\n[Alice] \tThe username may not contain '|', please enter it again: ")
        username = input()
    print("\n[Alice] Password: ")
    password = input()
    loginInfo = username + '|' + password

    plaintext = loginInfo.encode('utf-8')
    message = aliceSM4K.encrypt(plaintext)
    print("\n*** [Alice] \tLogin details submitted and encrypted")
    signedMessageAlice = schnorr_sign(message, sPrime, sGenerator, a_SK)
    print("*** [Alice] \tSending packet containing message (login details) to Server for identity verification\n")
    verifyAliceMessage = schnorr_verify(signedMessageAlice, sPrime, sGenerator, a_PK)
    print("*** [Server] \tMessage received and is awaiting identity verification")
    if verifyAliceMessage:
        print("*** [Server] \tMessage accepted and continuing to be decrypted")
        try:
            # Decode the bytes properly; str(bytes)[2:] mangled non-ASCII names.
            loginText = serverSM4K.decrypt(signedMessageAlice.m).decode('utf-8')
        except ValueError:  # invalid padding or undecodable bytes (wrong key / tampering)
            loginText = None
        if loginText is not None:
            print("*** [Server] \tMessage decrypted successfully!")
            messageUsername = split_string_at_character(loginText, '|')  # get username from the decrypted message
            print("[Server] \tUser " + messageUsername + " has successfully logged in\n")
            randomOverdue = randint(1, 200)
            print("[Alice] \tYour overdue is " + str(randomOverdue) + "$")
            print("[Alice] \tPlease enter your credit card details (16 digits)")
            creditcard = input()
            while len(creditcard) != 16 or any(ch not in "0123456789" for ch in creditcard):
                print("\n[Alice] \tPlease make sure to enter a valid credit card number (16 digits)")
                creditcard = input()
            plaintext = creditcard.encode('utf-8')
            message = aliceSM4K.encrypt(plaintext)
            print("\n*** [Alice] \tCredit card details submitted and encrypted")
            signedMessageAlice = schnorr_sign(message, sPrime, sGenerator, a_SK)
            print("*** [Alice] \tSending packet containing message (credit card details) to Server for identity verification\n")
            verifyAliceMessage = schnorr_verify(signedMessageAlice, sPrime, sGenerator, a_PK)
            print("*** [Server] \tMessage received and is awaiting identity verification")
            if verifyAliceMessage:
                print("*** [Server] \tMessage accepted and continuing to be decrypted")
                try:
                    cardText = serverSM4K.decrypt(signedMessageAlice.m).decode('utf-8')
                except ValueError:
                    cardText = None
                if cardText is not None:
                    print("*** [Alice] \tYour payment has been processed successfully, thanks for paying!")
                else:
                    print("***[Server] \tMessage decryption failed, exiting process.")
            else:
                print("***[Server] \tMessage verification failed, exiting process.")
        else:
            print("***[Server] \tMessage decryption failed, exiting process.")
    else:
        print("***[Server] \tMessage verification failed, exiting process.")
else:
    print("***[Alice] \tServer identity verification failed, exiting process.")

[Global] 
Global ECDH values:
p = 69366747227242718571457461194154658281591827005175504643902190811844956749403
G = (33270745349940538213116614696381042480078211444154093486988764838224200229687, 19019184430781829816790361688695888992581977487635309372920408104094127526194)
a = 60174489979152822413690937222500194111560043819098114507907347789650321871301, b=52148230455551979386473466305018615787126667917320155018824330009867477212725
[Alice]
Alice's private ECDH key = 58128768962545029653365682222032807679400558470261667630423091086021778058207
[Server]
Server's private ECDH key = 17001629139981626743429398397585524536622455767436965275973987784717752917152
[Alice]
Alice's public ECDH key = (66203894292904067173471092201683828126789966836590921314850287090697013799807, 7425168734497755267219226827219260197534179346651814526611793048955708971854)
[Server]
Server's public ECDH key = (64978222561370300611557244154500617774749401695325470468053863478360764408987, 12972069256165553196841350