In [1]:
import random


def miller_rabin(num, rounds=5):
    """Probabilistic Miller-Rabin primality test.

    Args:
        num: Integer to test.
        rounds: Number of random bases to try (default 5, the value the
            original implementation hard-coded).

    Returns:
        False if ``num`` is definitely not prime; True if ``num`` is prime
        or survived every round (i.e. is prime with high probability).
    """
    # Settle trivial inputs up front. Without these guards randrange() below
    # raises ValueError for num <= 3, and even inputs (t == 0) would never
    # leave the squaring loop.
    if num < 2:
        return False
    if num in (2, 3):
        return True
    if num % 2 == 0:
        return False

    # Decompose num - 1 as s * 2**t with s odd.
    s = num - 1
    t = 0
    while s % 2 == 0:
        s //= 2
        t += 1

    for _ in range(rounds):
        a = random.randrange(2, num - 1)
        v = pow(a, s, num)
        if v == 1 or v == num - 1:
            # This base gives no evidence of compositeness.
            continue
        # Square up to t - 1 times looking for num - 1 (i.e. -1 mod num).
        for _ in range(t - 1):
            v = pow(v, 2, num)  # modular square without a huge intermediate
            if v == num - 1:
                break
        else:
            # -1 never appeared: the base is a witness, num is composite.
            return False
    return True


def is_prime(num):
    """Return True if ``num`` is (very probably) a prime number.

    Numbers below 2 are rejected, the primes below 1000 are recognised
    directly, multiples of those primes are rejected by trial division, and
    everything else is handed to ``miller_rabin``.
    """
    # 0, 1 and negative numbers are never prime.
    if num < 2:
        return False

    # Every prime below 1000.
    small_primes = [
        2, 3, 5, 7, 11, 13, 17, 19, 23, 29,
        31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
        73, 79, 83, 89, 97, 101, 103, 107, 109, 113,
        127, 131, 137, 139, 149, 151, 157, 163, 167, 173,
        179, 181, 191, 193, 197, 199, 211, 223, 227, 229,
        233, 239, 241, 251, 257, 263, 269, 271, 277, 281,
        283, 293, 307, 311, 313, 317, 331, 337, 347, 349,
        353, 359, 367, 373, 379, 383, 389, 397, 401, 409,
        419, 421, 431, 433, 439, 443, 449, 457, 461, 463,
        467, 479, 487, 491, 499, 503, 509, 521, 523, 541,
        547, 557, 563, 569, 571, 577, 587, 593, 599, 601,
        607, 613, 617, 619, 631, 641, 643, 647, 653, 659,
        661, 673, 677, 683, 691, 701, 709, 719, 727, 733,
        739, 743, 751, 757, 761, 769, 773, 787, 797, 809,
        811, 821, 823, 827, 829, 839, 853, 857, 859, 863,
        877, 881, 883, 887, 907, 911, 919, 929, 937, 941,
        947, 953, 967, 971, 977, 983, 991, 997,
    ]

    # A known small prime needs no further work.
    if num in small_primes:
        return True

    # Cheap trial division filters out most composites.
    if any(num % divisor == 0 for divisor in small_primes):
        return False

    # Fall back to the probabilistic test for everything else.
    return miller_rabin(num)

In [2]:
import random
import secrets
from math import gcd
from random import getrandbits

from Crypto.Cipher import AES
from tqdm.notebook import tqdm

def pad(x):
    """Append NUL bytes to ``x`` so its length is a multiple of ``AES.block_size``.

    An input whose length is already a multiple of the block size still
    receives one full block of padding.
    """
    missing = AES.block_size - len(x) % AES.block_size
    return x + b'\0' * missing


def strip(x):
    """Remove every trailing NUL byte from ``x``."""
    return x.rstrip(b'\0')


def hexdump(x):
    """Render the bytes of ``x`` as space-separated two-digit uppercase hex."""
    return ' '.join(format(byte, '02X') for byte in x)

In [3]:
def randwhile(bits, until):
    """Draw random ``bits``-bit integers until one satisfies ``until``.

    Args:
        bits: Bit width passed to ``getrandbits`` for every draw.
        until: Predicate called on each candidate; the first candidate for
            which it returns a truthy value is returned.
    """
    candidate = getrandbits(bits)
    while not until(candidate):
        candidate = getrandbits(bits)
    return candidate
        
def primroots(p):
    """List the primitive roots modulo ``p`` by brute force.

    A candidate ``x`` in ``[1, p)`` is kept when the set of its powers
    ``x**1 .. x**(p-1)`` (mod ``p``) equals the set of residues in
    ``[1, p)`` that are coprime to ``p``. Progress is shown with ``tqdm``.

    Every candidate costs ``p - 1`` modular exponentiations, so this is only
    practical for small ``p``.
    """
    units = {r for r in range(1, p) if gcd(r, p) == 1}
    roots = []
    for candidate in tqdm(range(1, p)):
        generated = {pow(candidate, exponent, p) for exponent in range(1, p)}
        if units == generated:
            roots.append(candidate)
    return roots

# Agree on the public parameters: a prime n and one of its primitive roots g.
# randwhile keeps drawing 10-bit integers until is_prime accepts one; g is
# then picked at random from all primitive roots of n.
n = randwhile(10, is_prime)
g = random.choice(primroots(n))
print(f"n = {n}, g = {g}")

  0%|          | 0/460 [00:00<?, ?it/s]

n = 461, g = 306


In [6]:
# Party A: pick a secret exponent x and publish X = g^x mod n.
# Secrets are drawn from the OS CSPRNG via `secrets`; the `random` module's
# generator is documented as unsuitable for security purposes.
x = secrets.randbits(80)
X = pow(g, x, n)
print(f"A: x = {x}, X = {X}")

# Party B: pick a secret exponent y and publish Y = g^y mod n.
y = secrets.randbits(80)
Y = pow(g, y, n)
print(f"B: y = {y}, Y = {Y}")

# Each party derives the key from its own secret and the other's public value.
Ak = pow(Y, x, n)
print(f"A: full key = {Ak}")

Bk = pow(X, y, n)
print(f"B: full key = {Bk}")

# Both sides must arrive at the same value: (g^y)^x == (g^x)^y (mod n).
assert Ak == Bk
k = Ak

A: x = 356573612564656427238273, X = 27
B: y = 835226331706433831371593, Y = 47
A: full key = 230
B: full key = 230


In [7]:
def encrypt(key, m):
    """Encrypt the string ``m`` with AES in ECB mode.

    Args:
        key: Integer key; it is serialised to 16 big-endian bytes for AES.
        m: Text to encrypt; it is UTF-8 encoded and NUL-padded with ``pad``
            before encryption.

    Returns:
        The ciphertext as bytes.
    """
    key_bytes = key.to_bytes(16, byteorder='big')
    plaintext = pad(bytes(m, 'utf-8'))
    return AES.new(key_bytes, AES.MODE_ECB).encrypt(plaintext)

# Encrypt the message (party A) with the shared key k and show the
# ciphertext as hex.
msg = encrypt(k, 'Hello, world!')
print('A: encrypted =', hexdump(msg))

A: encrypted = 71 79 39 8D 22 AF 26 1B 97 44 C9 16 32 C7 A5 5C


In [8]:
def decrypt(key, m):
    """Decrypt AES-ECB ciphertext ``m`` back into a string.

    Args:
        key: Integer key; it is serialised to 16 big-endian bytes for AES.
        m: Ciphertext bytes.

    Returns:
        The plaintext decoded as UTF-8 after ``strip`` has removed the
        trailing NUL bytes.
    """
    key_bytes = key.to_bytes(16, byteorder='big')
    cipher = AES.new(key_bytes, AES.MODE_ECB)
    padded = cipher.decrypt(m)
    return strip(padded).decode('utf-8')

# Decrypt the message (party B) with the same shared key k.
res = decrypt(k, msg)
print('B: decrypted =', res)

B: decrypted = Hello, world!
