### ElGamal signature scheme
The ElGamal signature scheme was described by Taher ElGamal in 1985. It uses the discrete logarithm problem as its "trapdoor": in modular arithmetic, exponentiation is easy to compute, but its inverse (the discrete logarithm) is hard to compute.
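To get a feel for this asymmetry, here is a minimal sketch using a toy prime p = 23 and base g = 5 (illustrative values only, not the parameters generated later). Computing 5^6 mod 23 with Python's built-in three-argument `pow` takes a handful of multiplications. Recovering the exponent with the naive search below (`brute_force_dlog` is a hypothetical helper) may take up to p-1 steps. Better algorithms exist, such as baby-step giant-step, but for well-chosen primes of 2048 bits the known classical methods are still infeasible.

```python
def brute_force_dlog(g: int, y: int, p: int) -> int:
    """Find the smallest x with g**x % p == y by trying x = 0, 1, 2, ... (toy sizes only)."""
    x, acc = 0, 1
    while acc != y:
        acc = acc * g % p
        x += 1
        if x >= p:
            raise ValueError("no discrete logarithm found")
    return x

p_toy, g_toy = 23, 5
y_toy = pow(g_toy, 6, p_toy)                  # easy direction: fast modular exponentiation
print(y_toy)                                  # 8
print(brute_force_dlog(g_toy, y_toy, p_toy))  # 6: hard direction, exhaustive search
```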

This particular scheme is rarely used in practice. DSA, a modified variant, is used more often, though even DSA has been superseded by variants based on elliptic curves such as ECDSA. Moreover, with the growing threat posed by quantum computers, NIST selected [CRYSTALS-Dilithium](https://www.ibm.com/docs/en/zos/2.5.0?topic=cryptography-crystals-dilithium-digital-signature-algorithm) as its primary post-quantum signature scheme. Nevertheless, ElGamal is a good starting point for learning how digital signatures (may) work.


### Disclaimer: this code is not safe for real-world use

The implementations are for educational purposes only.

They are slow and almost certainly vulnerable to various attacks.

First, we'll need a couple of functions for generating prime numbers and computing modular inverses. I'll copy these from [a previous project](https://github.com/AdrianKlessa/ecc_elgamal/blob/main/ecc%20elgamal.ipynb).

In [None]:
import math
import random

def GCD(a,b):
    while b:
        a,b = b, a%b
    return abs(a)

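# Returns a random number with exactly k bits (between 2**(k-1) and 2**k - 1),
# used to make sure the numbers we work with are sufficiently large.
# By Bertrand's postulate there is always at least one prime in this range.
def random_nbit_number(k):

    smallest_possible = 2**(k-1)
    largest_possible = (2**k)-1
    return random.randint(smallest_possible,largest_possible)

# Fermat primality test: returns True if x is probably prime.
# Note: Carmichael numbers can fool it for every base coprime to x.
def fermat_test(x, trials):
    for _ in range(trials):
        a = random.randrange(2,x-1)
        d = GCD(a,x)
        if d!=1:
            return False  # a shares a nontrivial factor with x, so x is composite
        else:
            a_power = pow(a,x-1,x)
            if a_power!=1:
                return False  # a is a Fermat witness, so x is composite
    return True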

def extendedGCD(a,b):
    r,r1=a,b
    s,s1=1,0
    t,t1=0,1
    while r1!=0:
        q,r2=r//r1,r % r1
        r,s,t,r1,s1,t1=r1,s1,t1,r2,s-s1*q,t-t1*q
    d=r
    return d,s,t

# Multiplicative inverse of a, modulo m
def multiplicative_inverse(a,m):
    d,inv,_=extendedGCD(a,m)
    if d==1:
        if m==1:
            return 1 #for compatibility
        return inv%m
    else:
        raise ValueError('Numbers '+str(a)+' and '+str(m)+' are not coprime.')
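The Fermat test above can be fooled by Carmichael numbers: composites n for which a^(n-1) mod n equals 1 for every a coprime to n. As an alternative that this notebook does not itself use, here is a minimal sketch of the Miller-Rabin test, which has no such blind spot: each round with a random base wrongly accepts a composite with probability at most 1/4. The function name and the default of 32 rounds are illustrative choices.

```python
import random

def miller_rabin(n: int, rounds: int = 32) -> bool:
    """Return True if n is probably prime, False if n is definitely composite."""
    if n < 2:
        return False
    # Handle small primes and obvious multiples of them directly
    for sp in (2, 3, 5, 7, 11, 13, 17, 19, 23, 29):
        if n % sp == 0:
            return n == sp
    # Write n - 1 = d * 2**s with d odd
    d, s = n - 1, 0
    while d % 2 == 0:
        d //= 2
        s += 1
    for _ in range(rounds):
        a = random.randrange(2, n - 1)
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue  # this base is consistent with n being prime
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break  # reached -1, consistent with n being prime
        else:
            return False  # a is a witness that n is composite
    return True
```

A drop-in use would replace `fermat_test(q, 32)` with `miller_rabin(q)` in the prime search.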


We need to find a prime number p and a generator g of the multiplicative group of integers modulo p, (Z/pZ)*.

We'll find these by using a safe prime: a prime p = 2q+1 where q is also prime.
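For intuition, here is a small sketch that lists the safe primes below 100 using plain trial division, which is fine at this size and hopeless at 512 bits. Every p it returns satisfies p = 2q+1 with q prime. The helpers `is_prime_trial_division` and `safe_primes_below` are illustrative names, not part of the notebook.

```python
def is_prime_trial_division(n: int) -> bool:
    """Deterministic primality check by trial division (small n only)."""
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True

def safe_primes_below(limit: int) -> list:
    """Return every safe prime p < limit, i.e. p is prime and (p - 1) // 2 is prime."""
    return [p for p in range(2, limit)
            if is_prime_trial_division(p) and is_prime_trial_division((p - 1) // 2)]

print(safe_primes_below(100))  # [5, 7, 11, 23, 47, 59, 83]
```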

In [None]:
N = 512 # Signatures in this scheme are notoriously large
# Search for a prime q such that p = 2q + 1 is also prime (a safe prime)
q = random_nbit_number(N)
while not fermat_test(q, 32) or not fermat_test((2*q)+1, 32):
    q = random_nbit_number(N)
p=(2*q)+1

The cell above took 30s on a relatively powerful CPU. We should really be using N=2048, but this implementation was too slow for numbers that large.

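This is one of the advantages of DSA: it does not require a safe prime, it derives the generator of a prime-order subgroup from an explicit formula (g = h^((p-1)/q) mod p, where q is a prime divisor of p-1), and its signatures are smaller because both signature components are reduced modulo the much smaller q.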
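As a rough sketch of that formula with toy numbers (p = 607 and q = 101 are illustrative values, since 606 = 2 * 3 * 101): raising any h to the power (p-1)/q lands in the subgroup of order q, and if the result is not 1 it generates the whole subgroup because q is prime. No safe-prime search is needed. Real DSA parameters come from the generation procedures in FIPS 186, which this sketch does not reproduce; `subgroup_generator` is a hypothetical helper.

```python
def subgroup_generator(p: int, q: int, h: int = 2) -> int:
    """Return g = h^((p-1)/q) mod p, a generator of the order-q subgroup (q prime, q | p-1)."""
    if (p - 1) % q != 0:
        raise ValueError("q must divide p - 1")
    g = pow(h, (p - 1) // q, p)
    if g == 1:
        raise ValueError("h yields the trivial element; try another h")
    return g

p_dsa, q_dsa = 607, 101
g_dsa = subgroup_generator(p_dsa, q_dsa)
print(g_dsa)                     # 64, i.e. 2**6 mod 607
print(pow(g_dsa, q_dsa, p_dsa))  # 1: the order of g divides q
```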

In [None]:
p

In [None]:
q

The multiplicative group modulo a prime p has order p-1. By definition, a generator is an element whose order is also p-1.

We can use the fact that the order of an element divides the order of the group (Lagrange's theorem) to quickly check whether an element is a generator.

Since we're using a safe prime, the order of the group is p-1=2q (because p=2q+1). Because q is also prime, the only possible orders of an element are 1, 2, q and 2q. Only 1 has order 1 and only p-1 has order 2, so g is a generator exactly when neither g^2 nor g^q is congruent to 1 modulo p.
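Here is a small sanity check of that criterion on the toy safe prime p = 23 (q = 11), an illustrative value rather than anything the notebook generates. The shortcut test is compared against a brute-force order computation; for a safe prime exactly phi(2q) = q-1 elements are generators, so 10 of them here. Both helper names are hypothetical.

```python
def is_generator_safe_prime(g: int, p: int) -> bool:
    """Shortcut for a safe prime p = 2q + 1: g generates iff g^2 != 1 and g^q != 1 (mod p)."""
    q = (p - 1) // 2
    return pow(g, 2, p) != 1 and pow(g, q, p) != 1

def multiplicative_order(g: int, p: int) -> int:
    """Smallest k >= 1 with g^k = 1 (mod p), found by brute force (toy sizes only)."""
    k, acc = 1, g % p
    while acc != 1:
        acc = acc * g % p
        k += 1
    return k

p_small = 23
shortcut = [g for g in range(1, p_small) if is_generator_safe_prime(g, p_small)]
brute = [g for g in range(1, p_small) if multiplicative_order(g, p_small) == p_small - 1]
print(shortcut)           # [5, 7, 10, 11, 14, 15, 17, 19, 20, 21]
print(shortcut == brute)  # True
```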

In [None]:
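def find_generator_for_safe_prime(prime_p: int)->int:
    # For a safe prime p = 2q + 1, the possible element orders are 1, 2, q and 2q
    prime_q = (prime_p-1)//2
    g = random.randint(2,prime_p-2)
    # Keep sampling until g has order 2q, i.e. neither g^2 nor g^q is 1 mod p
    while pow(g,2, prime_p)==1 or pow(g,prime_q, prime_p)==1:
        g = random.randint(2,prime_p-2)
    assert pow(g,(prime_p-1), prime_p)==1
    return g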

g = find_generator_for_safe_prime(p)

In [None]:
g

In [None]:
pow(g,(p-1), p)

### Key generation

In [None]:
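def generate_keys(prime_p: int, generator: int)->tuple:
    # Private key x is random in [1, p-2]; public key is y = g^x mod p.
    # Note: `random` is not cryptographically secure; real code would use `secrets`.
    x = random.randint(1,prime_p-2)
    y = pow(generator,x,prime_p)
    return x,y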

In [None]:
x, y = generate_keys(p,g)

print(f"Private key: {x}")
print(f"Public key: {y}")

We'll be using SHA-256 as the hash function.
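A quick illustration of how a message becomes a number here, written as a standalone sketch: SHA-256 always produces a 32-byte digest, so slicing it to `N // 8 = 64` bytes, as `get_shortened_hash` does for N = 512, keeps the whole digest. The resulting integer is below 2^256, well under p-1 (which is at least 2^512 here). The variable `n_bits` is an illustrative stand-in for the notebook's `N`.

```python
from hashlib import sha256

n_bits = 512  # stands in for the notebook's N
digest = sha256(b"Let's see if this works").digest()
shortened = digest[:n_bits // 8]       # slicing 32 bytes to 64 keeps all 32

as_int = int.from_bytes(shortened, "big")
print(len(digest), len(shortened))     # 32 32
print(as_int.bit_length() <= 256)      # True
```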

### Signature generation & validation
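For reference, these are the equations that `sign` and `verify_signature` implement, with x the private key, y = g^x mod p the public key, H(m) the message hash and k a random value coprime with p-1. The last line shows why a valid signature passes verification: g^(p-1) is congruent to 1 modulo p by Fermat's little theorem, so exponents can be reduced modulo p-1.

$$
\begin{aligned}
r &= g^{k} \bmod p, \qquad s = \bigl(H(m) - x r\bigr)\,k^{-1} \bmod (p-1) \\
\Rightarrow\; H(m) &\equiv x r + k s \pmod{p-1} \\
\Rightarrow\; g^{H(m)} &\equiv g^{x r}\, g^{k s} = \bigl(g^{x}\bigr)^{r}\bigl(g^{k}\bigr)^{s} \equiv y^{r}\, r^{s} \pmod{p}
\end{aligned}
$$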

In [None]:
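from hashlib import sha256

N_bytes_count = N//8

# Hash the message and interpret the first N/8 bytes of the digest as an integer.
# SHA-256 digests are only 32 bytes, so for N >= 256 the whole digest is used.
def get_shortened_hash(message: bytes):
    message_hash = sha256(message).digest()[:N_bytes_count]
    return int.from_bytes(message_hash, 'big')

def sign(message: bytes, generator: int, prime_p: int, private_key: int)->tuple:
    message_hash = get_shortened_hash(message)
    while True:
        # The ephemeral key k must be coprime with p-1 so it is invertible modulo p-1
        k = random.randint(2, prime_p-2)
        if GCD(k, prime_p-1) != 1:
            continue
        r = pow(generator,k,prime_p)
        # s = (H(m) - x*r) * k^(-1) mod (p-1)
        s = (message_hash-private_key*r)
        s = s*multiplicative_inverse(k,prime_p-1) % (prime_p-1)
        # The verifier rejects s = 0, so pick a new k in that case
        if s != 0:
            return r,s

def verify_signature(message:bytes, signature:tuple, public_key: int, prime_p: int, generator:int)->bool:
    r,s = signature
    # Reject signature components outside their valid ranges
    if not 0 < r < prime_p:
        print("Invalid r")
        return False
    if not 0< s < prime_p-1:
        print("Invalid s")
        return False
    message_hash = get_shortened_hash(message)
    # Check that g^H(m) == y^r * r^s (mod p)
    left = pow(generator,message_hash,prime_p)
    right = pow(public_key,r, prime_p)*pow(r,s,prime_p)
    right = right % prime_p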
    return left == right
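To check the arithmetic by hand, here is a minimal deterministic sketch of the same equations on toy numbers: p = 23, g = 5, private key x = 6 (so y = 8), a made-up hash value h = 10 in place of SHA-256, and a fixed k = 3. These values are purely illustrative; a real signer must draw a fresh, secret k for every signature. `toy_sign` and `toy_verify` are hypothetical helpers, and Python's `pow(k, -1, m)` (3.8+) stands in for `multiplicative_inverse`.

```python
def toy_sign(h: int, g: int, p: int, x: int, k: int) -> tuple:
    """ElGamal signature of hash value h with a caller-supplied k (toy use only)."""
    r = pow(g, k, p)
    k_inv = pow(k, -1, p - 1)            # requires gcd(k, p - 1) == 1
    s = (h - x * r) * k_inv % (p - 1)
    return r, s

def toy_verify(h: int, signature: tuple, y: int, p: int, g: int) -> bool:
    """Check g^h == y^r * r^s (mod p) after range checks on r and s."""
    r, s = signature
    if not (0 < r < p and 0 < s < p - 1):
        return False
    return pow(g, h, p) == pow(y, r, p) * pow(r, s, p) % p

toy_p, toy_g, toy_x = 23, 5, 6
toy_y = pow(toy_g, toy_x, toy_p)                   # 8
toy_sig = toy_sign(10, toy_g, toy_p, toy_x, k=3)
print(toy_sig)                                     # (10, 20)
print(toy_verify(10, toy_sig, toy_y, toy_p, toy_g))  # True
print(toy_verify(11, toy_sig, toy_y, toy_p, toy_g))  # False: a different hash fails
```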


In [None]:
signature = sign(b"Let's see if this works", g, p, x)

In [None]:
signature

In [None]:
verify_signature(b"Let's see if this works", signature, y, p, g)

In [None]:
verify_signature(b"Let's see if this works with a modified message", signature, y, p, g)

In [None]:
verify_signature(b"Invalid_signature", (3, 17), y, p, g)