<a href="https://colab.research.google.com/github/DorShabat/Cryptology-Project/blob/main/combined.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pip install and Imports

In [1]:
%%capture
!pip install cryptography # for derive shared key from shared secret.
!pip install gmpy2 # for precision calculations.

## imports:

In [2]:
import numpy as np
import sympy
import os
import random
import hashlib
from sympy.ntheory.generate import randprime
import gmpy2

# for derive shared key from shared secret:
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend

# The Elliptic Curve Diffie-Hellman (ECDH)

## Define the Elliptic Curve

### helping functions:

In [3]:
def point_addition(P1, P2, P, a):
    if P1 == (None, None):
        return P2
    if P2 == (None, None):
        return P1
    if P1 == P2:
        return point_doubling(P1, P, a)

    x1, y1 = P1
    x2, y2 = P2

    if x1 == x2 and (y1 + y2) % P == 0:
        return (None, None)

    m = (y2 - y1) * pow(x2 - x1, -1, P) % P
    x3 = (m * m - x1 - x2) % P
    y3 = (m * (x1 - x3) - y1) % P

    return (x3, y3)

def point_doubling(P1, P, a):
    if P1 == (None, None):
        return (None, None)

    x1, y1 = P1

    m = (3 * x1 * x1 + a) * pow(2 * y1, -1, P) % P
    x3 = (m * m - 2 * x1) % P
    y3 = (m * (x1 - x3) - y1) % P

    return (x3, y3)


def multiply_point_by_scalar_mod_P(point, scalar, P, a):
    """
    Multiplies a point on an elliptic curve by a scalar using the double-and-add algorithm.

    Parameters:
    point (tuple): The point on the elliptic curve (x, y).
    scalar (int): The scalar to multiply the point by.
    P (int): The prime order of the finite field.
    a (int): The coefficient 'a' in the elliptic curve equation y^2 = x^3 + ax + b.

    Returns:
    tuple: The resulting point after multiplication.
    """

    result = (None, None)
    addend = point

    while scalar:
        if scalar & 1:
            result = point_addition(result, addend, P, a)
        addend = point_doubling(addend, P, a)
        scalar >>= 1

    return result

### parametes defining:

In [4]:
# well known parameters:
P = 6277101735386680763835789423207666416083908700390324961279
a = 6277101735386680763835789423207666416083908700390324961276
b = 2455155546008943817740293915197451784769108058161191238065
Gx = 602046282375688656758213480587526111916698976636884684818
Gy = 174050332293622031404857552280219410364023488927386650641
G = (Gx, Gy)

In [5]:
# small numbers example:
a = 2
b = 3
P = 97 #prime number
G = (3, 6)

## Alice

In [6]:
# private key:
private_key_a = 621
# public key:
A = multiply_point_by_scalar_mod_P(point=G, scalar=private_key_a, P=P, a=a)
print(A)

(3, 6)


## Bob

In [7]:
# private key : b
private_key_b = 77
# public key : B = bG
B = multiply_point_by_scalar_mod_P(point=G, scalar=private_key_b, P=P, a=a)
print(B)

(80, 10)


## Extchanging public keys...

`S_A = private_key_a * B`

`S_B = private_key_b * A`

Shared secret: `S_A == S_b`

## Alice

In [8]:
S_A = multiply_point_by_scalar_mod_P(point=B, scalar=private_key_a, P=P, a=a)
print(S_A)

(80, 10)


## Bob

In [9]:
S_B = multiply_point_by_scalar_mod_P(point=A, scalar=private_key_b, P=P, a=a)
print(S_B)

(80, 10)


In [10]:
if S_A == S_B:
  shared_secret = S_A
  print("Alice and Bob now shared a secret point.")
else:
  print("Alice and Bob shared different secrets.")

Alice and Bob now shared a secret point.


## Derive a shared key from the shared secret:

In [11]:
# shared secret (an integer): GxGy
shared_secret_int = shared_secret[0] * (10 ** len(str(shared_secret[0]))) + shared_secret[1]

# Calculate the number of bytes needed to represent the integer
num_bytes = (shared_secret_int.bit_length() + 7) // 8

# Convert the integer to a byte array
shared_secret_bytes = shared_secret_int.to_bytes(num_bytes, byteorder='big')

# Derive a key using HKDF
# Choose the desired output length of the key
output_key_length = 16 # 16-byte = 128-bit

# Create HKDF instance
hkdf = HKDF(
    algorithm=hashes.SHA256(),
    length=output_key_length,
    salt=None,
    info=b'handshake data',
    backend=default_backend()
)

# Perform key derivation
shared_key = hkdf.derive(shared_secret_bytes)

print("Derived Key:", shared_key.hex())


Derived Key: 55323c3cd820695175f5b961b1c77116


# Alice - SM4 + Schnorr

## SM4 functions

### SM4 CONSTANTS

In [19]:
SBOX = [
    0xd6, 0x90, 0xe9, 0xfe, 0xcc, 0xe1, 0x3d, 0xb7, 0x16, 0xb6, 0x14, 0xc2, 0x28, 0xfb, 0x2c, 0x05,
    0x2b, 0x67, 0x9a, 0x76, 0x2a, 0xbe, 0x04, 0xc3, 0xaa, 0x44, 0x13, 0x26, 0x49, 0x86, 0x06, 0x99,
    0x9c, 0x42, 0x50, 0xf4, 0x91, 0xef, 0x98, 0x7a, 0x33, 0x54, 0x0b, 0x43, 0xed, 0xcf, 0xac, 0x62,
    0xe4, 0xb3, 0x1c, 0xa9, 0xc9, 0x08, 0xe8, 0x95, 0x80, 0xdf, 0x94, 0xfa, 0x75, 0x8f, 0x3f, 0xa6,
    0x47, 0x07, 0xa7, 0xfc, 0xf3, 0x73, 0x17, 0xba, 0x83, 0x59, 0x3c, 0x19, 0xe6, 0x85, 0x4f, 0xa8,
    0x68, 0x6b, 0x81, 0xb2, 0x71, 0x64, 0xda, 0x8b, 0xf8, 0xeb, 0x0f, 0x4b, 0x70, 0x56, 0x9d, 0x35,
    0x1e, 0x24, 0x0e, 0x5e, 0x63, 0x58, 0xd1, 0xa2, 0x25, 0x22, 0x7c, 0x3b, 0x01, 0x21, 0x78, 0x87,
    0xd4, 0x00, 0x46, 0x57, 0x9f, 0xd3, 0x27, 0x52, 0x4c, 0x36, 0x02, 0xe7, 0xa0, 0xc4, 0xc8, 0x9e,
    0xea, 0xbf, 0x8a, 0xd2, 0x40, 0xc7, 0x38, 0xb5, 0xa3, 0xf7, 0xf2, 0xce, 0xf9, 0x61, 0x15, 0xa1,
    0xe0, 0xae, 0x5d, 0xa4, 0x9b, 0x34, 0x1a, 0x55, 0xad, 0x93, 0x32, 0x30, 0xf5, 0x8c, 0xb1, 0xe3,
    0x1d, 0xf6, 0xe2, 0x2e, 0x82, 0x66, 0xca, 0x60, 0xc0, 0x29, 0x23, 0xab, 0x0d, 0x53, 0x4e, 0x6f,
    0xd5, 0xdb, 0x37, 0x45, 0xde, 0xfd, 0x8e, 0x2f, 0x03, 0xff, 0x6a, 0x72, 0x6d, 0x6c, 0x5b, 0x51,
    0x8d, 0x1b, 0xaf, 0x92, 0xbb, 0xdd, 0xbc, 0x7f, 0x11, 0xd9, 0x5c, 0x41, 0x1f, 0x10, 0x5a, 0xd8,
    0x0a, 0xc1, 0x31, 0x88, 0xa5, 0xcd, 0x7b, 0xbd, 0x2d, 0x74, 0xd0, 0x12, 0xb8, 0xe5, 0xb4, 0xb0,
    0x89, 0x69, 0x97, 0x4a, 0x0c, 0x96, 0x77, 0x7e, 0x65, 0xb9, 0xf1, 0x09, 0xc5, 0x6e, 0xc6, 0x84,
    0x18, 0xf0, 0x7d, 0xec, 0x3a, 0xdc, 0x4d, 0x20, 0x79, 0xee, 0x5f, 0x3e, 0xd7, 0xcb, 0x39, 0x48
]

FK = [0xa3b1bac6, 0x56aa3350, 0x677d9197, 0xb27022dc]

CK = [
    0x00070e15, 0x1c232a31, 0x383f464d, 0x545b6269,
    0x70777e85, 0x8c939aa1, 0xa8afb6bd, 0xc4cbd2d9,
    0xe0e7eef5, 0xfc030a11, 0x181f262d, 0x343b4249,
    0x50575e65, 0x6c737a81, 0x888f969d, 0xa4abb2b9,
    0xc0c7ced5, 0xdce3eaf1, 0xf8ff060d, 0x141b2229,
    0x30373e45, 0x4c535a61, 0x686f767d, 0x848b9299,
    0xa0a7aeb5, 0xbcc3cad1, 0xd8dfe6ed, 0xf4fb0209,
    0x10171e25, 0x2c333a41, 0x484f565d, 0x646b7279
]

### SM4 helper Functions

functions for both Key Expantion, Decryption - Encryption.

In [20]:
def rotl(x, n):
    # Rotate left: shift x to the left by n bits and wrap around the overflowed bits to the right
    return ((x << n) & 0xFFFFFFFF) | (x >> (32 - n))

def tau(a):
    # Non-linear substitution using the S-Box. Each byte of the input is substituted using the S-Box
    return [
        SBOX[(a >> 24) & 0xFF],  # Substitution for the highest byte
        SBOX[(a >> 16) & 0xFF],  # Substitution for the second byte
        SBOX[(a >> 8) & 0xFF],   # Substitution for the third byte
        SBOX[a & 0xFF]           # Substitution for the lowest byte
    ]

### Key Expantion

In [21]:
def l_key(b):
    # Linear transformation for key expansion: rotate and XOR operations
    return b ^ rotl(b, 13) ^ rotl(b, 23)

def t_key(a):
    # Apply tau and then l_key to the input for key expansion
    a = tau(a)
    result = l_key((a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3])
    return result

def key_expansion(key):
    # Key expansion to generate round keys from the original key
    MK = [(key[i * 4] << 24) | (key[i * 4 + 1] << 16) | (key[i * 4 + 2] << 8) | key[i * 4 + 3] for i in range(4)]
    K = [MK[i] ^ FK[i] for i in range(4)]
    rk = []
    for i in range(32):
        K.append(K[i] ^ t_key(K[i + 1] ^ K[i + 2] ^ K[i + 3] ^ CK[i]))
        rk.append(K[-1])
    return rk

### Decryption - Encryption

In [22]:
def l_b(b):
    # Linear transformation: rotate and XOR operations to achieve diffusion
    return b ^ rotl(b, 2) ^ rotl(b, 10) ^ rotl(b, 18) ^ rotl(b, 24)

def t(a):
    # Apply tau and then l_b to the input
    a = tau(a)
    result = l_b((a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3])
    return result

def sm4_encrypt_block(plaintext, rk):
    # Encrypt a single block of plaintext using the round keys
    X = [(plaintext[i * 4] << 24) | (plaintext[i * 4 + 1] << 16) | (plaintext[i * 4 + 2] << 8) | plaintext[i * 4 + 3] for i in range(4)]
    for i in range(32):
        temp = t(X[i + 1] ^ X[i + 2] ^ X[i + 3] ^ rk[i])
        X.append(X[i] ^ temp)
    ciphertext = [(X[35 - i] >> (24 - j * 8)) & 0xFF for i in range(4) for j in range(4)]
    return ciphertext

def sm4_decrypt_block(ciphertext, rk):
    # Decrypt a single block of ciphertext using the round keys
    return sm4_encrypt_block(ciphertext, rk[::-1])

### Split text into blocks

In [23]:
# Function to split text into 128-bit blocks and pad if necessary
# can split into how many bit user wants, block_size = 8 -> 64 bit , 16 -> 128 bit ..
def split_into_blocks(text, block_size=16):
    # Ensure text length is a multiple of block_size
    padding_length = block_size - (len(text) % block_size)
    text += bytes([padding_length] * padding_length)
    # Split text into blocks
    return [text[i:i+block_size] for i in range(0, len(text), block_size)]

### Functions to call Encrypt, Decrypt, key using sm4 with ECB mode

In [24]:
# Function to encrypt text using ECB mode
def sm4_encrypt_ecb(text, key):
    rk = key_expansion(key)
    blocks = split_into_blocks(text)
    ciphertext = []
    for block in blocks:
        ciphertext_block = sm4_encrypt_block(block, rk)
        ciphertext.extend(ciphertext_block)
    return ciphertext

# Function to decrypt text using ECB mode
def sm4_decrypt_ecb(ciphertext, key):
    rk = key_expansion(key)
    blocks = [ciphertext[i:i+16] for i in range(0, len(ciphertext), 16)]
    plaintext = []
    for block in blocks:
        plaintext_block = sm4_decrypt_block(block, rk)
        plaintext.extend(plaintext_block)
    # Remove padding
    padding_length = plaintext[-1]
    return plaintext[:-padding_length]

## Alice encrypt text

In [25]:
key = shared_key #using the key that alice got
plaintext_msg = b'Alice`s bank account details.'

ciphertext = sm4_encrypt_ecb(plaintext_msg, key)
decrypted_plaintext = sm4_decrypt_ecb(ciphertext, key)

print("Plaintext:", plaintext_msg.decode())
print("Plaintext:", ' '.join([f'{x:02X}' for x in plaintext_msg]))
print("Ciphertext:", ' '.join([f'{x:02X}' for x in ciphertext]))

#alice wants to check the algorithem:
print("Decrypted Plaintext:", ' '.join([f'{x:02X}' for x in decrypted_plaintext]))
print("Decrypted Plaintext:", bytes(decrypted_plaintext).decode())

Plaintext: Alice`s bank account details.
Plaintext: 41 6C 69 63 65 60 73 20 62 61 6E 6B 20 61 63 63 6F 75 6E 74 20 64 65 74 61 69 6C 73 2E
Ciphertext: 83 58 56 83 B1 F9 4C 95 B5 F1 C5 E7 D7 04 64 D3 2C CF 6A 3D EE 67 6E 9F DE B9 D2 03 05 B8 72 01
Decrypted Plaintext: 41 6C 69 63 65 60 73 20 62 61 6E 6B 20 61 63 63 6F 75 6E 74 20 64 65 74 61 69 6C 73 2E
Decrypted Plaintext: Alice`s bank account details.


# Schnorr signature

## functions for generating the mathematical parameters:

In [26]:
def get_factor_of_prime_minus_one(prime): # get the factor of prime - 1
    if not sympy.isprime(prime):
        raise ValueError("The input number is not a prime number.")
    p_minus_1 = prime - 1
    factors = sympy.factorint(p_minus_1)
    return max(factors.keys())


def mod_exp(base, exp, mod): # modular exponentiation
    base = gmpy2.mpz(base)
    exp = gmpy2.mpz(exp)
    mod = gmpy2.mpz(mod)
    return gmpy2.powmod(base, exp, mod)


def find_A(Q, P): # find the non-trivial solution to A^Q = 1 mod P
    if Q == 0:
        raise ValueError("Q must be non-zero.")
    if P <= 1:
        raise ValueError("P must be greater than 1.")
    A = gmpy2.mpz(2)
    while A < P:
        if mod_exp(A, Q, P) == 1:
            return A
        A += 1
    raise ValueError("No non-trivial solution found.")


def mod_inverse(A, Q): # find the modular inverse of A under modulo Q
    A = gmpy2.mpz(A)
    Q = gmpy2.mpz(Q)
    g, x, y = extended_gcd(A, Q)
    if g != 1:
        raise ValueError(f"Modular inverse does not exist for A={A} and Q={Q}")
    else:
        return x % Q


def extended_gcd(a, b): # Extended Euclidean Algorithm
    if a == 0:
        return b, 0, 1
    else:
        g, x1, y1 = extended_gcd(b % a, a)
        x = y1 - (b // a) * x1
        y = x1
        return g, x, y


def calculate_V(A, s, P): # Calculate V as A^(-s) mod P
    A = gmpy2.mpz(A)
    s = gmpy2.mpz(s)
    P = gmpy2.mpz(P)
    V = gmpy2.powmod(A, -s, P)
    return V


def hash_function(data): # Hash function
    if isinstance(data, str):
        data = data.encode('utf-8')
    sha256 = hashlib.sha256()
    sha256.update(data)
    return sha256.hexdigest()

## Networking:

## Alice Sign and send

***Sign:***

`M = message to be signed`

global agreed:
```
P = a prime number     # typically 1024-bit
Q = a factor of P-1    # typically 160-bit
A = a^Q === 1 mod P
```
private key:
`s = random 0 < s < Q`

public verification key:
`V = A^(-s) mod P`

for signing:
```
r = random 0 < r < Q
x = A^r mod P
e = Hash(M||x)
y = (r+se)modQ
```

Send: `ciphertext | Signature(e , y) `


### generate parameters:

In [27]:
M = plaintext_msg

# P = randprime(2**1023, 2**1024 - 1)
P = randprime(2**32, 2**33 - 1)                # typicalyy 1024-bit number
Q = get_factor_of_prime_minus_one(P)           # typically 160-bit number
A = find_A(Q, P)
s = random.randint(1, Q-1)
V = calculate_V(A, s, P)
r = random.randint(1, Q-1)
x = mod_exp(A, r, P)
e = hash_function(str(M) + str(x))
y = (r + s*int(e, 16)) % Q

### parameters printing:

In [28]:
print(f'P = {P}')
print(f'Q = {Q}')
print(f'A = {A}')
print(f's = {s}')
print(f'V = {V}')
print(f'r = {r}')
print(f'x = {x}')
print(f'e = {e}')
print(f'y = {y}')

P = 5297840599
Q = 2367221
A = 427
s = 1142827
V = 3787623612
r = 374489
x = 1162997539
e = 0f97e606f10b8ccc6c4d4d4d988137b39612dc0f9381eba0a51742b6b8eec3c5
y = 1454441


# Bob


*** Verification:***
  
Received:
`(ciphertext,  e,  y)`

Known publicly:
`(A,  P,  V)`

compute:
* `x' = (A^y * V^e) mod P`
* `e' = H( M || x' )`

### print what recived and known:

In [29]:
print("   Recived form alice:\n")
print(f'ciphertext = **CIPHERD**')
print(f'e = {e}')
print(f'y = {y}')
print("\n\n  known public parameters:\n")
print(f'A = {A}')
print(f'P = {P}')
print(f'V = {V}')

   Recived form alice:

ciphertext = **CIPHERD**
e = 0f97e606f10b8ccc6c4d4d4d988137b39612dc0f9381eba0a51742b6b8eec3c5
y = 1454441


  known public parameters:

A = 427
P = 5297840599
V = 3787623612


## Decrypt msg:

In [30]:
decrypted_plaintext = sm4_decrypt_ecb(ciphertext, key)
print("Decrypted Plaintext:", bytes(decrypted_plaintext).decode())

Decrypted Plaintext: Alice`s bank account details.


### compute:

In [31]:
Ay = mod_exp(A, y, P)
Ve = mod_exp(V, int(e, 16), P)
AyVe = (Ay * Ve) % P
x_tag = AyVe

print(f"Alice`s x value is {x} and Bob`s x' value is {x_tag}")
if(x == x_tag):
  print("so, (x == x') = True")
else:
  print("so, (x != x') = False")

Alice`s x value is 1162997539 and Bob`s x' value is 1162997539
so, (x == x') = True


Technically, Bob should check if his computed x' is equal to Alice's x.

`(x' == x)` But he doen't know Alice`s x value beacuse it private.

So, Bob will Compute the e` by the hash function value and compare it with Alice's e value.

`e' = H( M || x' )`

In [32]:
decrypted_plaintext_bytes = bytes(decrypted_plaintext)
e_tag = hash_function(str(decrypted_plaintext_bytes) + str(x_tag))

print(f"Alice`s e = {e}")
print(f" Bob`s e' = {e_tag}\n")

if(e == e_tag):
  print("so, (e = e')")
  print("and it means the message M was really signed by Alice.")
else:
  print("so, (e != e')")
  print("and it means the message M was not signed by Alice or it was alerted by someone in the way.")


Alice`s e = 0f97e606f10b8ccc6c4d4d4d988137b39612dc0f9381eba0a51742b6b8eec3c5
 Bob`s e' = 0f97e606f10b8ccc6c4d4d4d988137b39612dc0f9381eba0a51742b6b8eec3c5

so, (e = e')
and it means the message M was really signed by Alice.
