In [16]:
import math
import random


def GCD(a,b):
    """Return the non-negative greatest common divisor of a and b (Euclid's algorithm)."""
    left, right = a, b
    while right != 0:
        # gcd(left, right) == gcd(right, left mod right)
        left, right = right, left % right
    return abs(left)
def random_nbit_number(k):
    """Return a random integer with exactly k bits, i.e. drawn from [2**(k-1), 2**k - 1]."""
    low, high = 2 ** (k - 1), 2 ** k - 1
    return random.randint(low, high)

# Fermat probable-prime test: False means x is certainly composite (or < 2), True means "probably prime".
def fermat_test(x, trials):
    """Fermat probable-prime test with a gcd pre-check on every base.

    Args:
        x: integer to test.
        trials: number of random bases a in [2, x-2] to try.

    Returns:
        False if x < 2 or some base proves x composite (shares a factor with x,
        or a**(x-1) != 1 mod x); True if every trial passed.
        Note: Carmichael numbers are only rejected when a sampled base happens
        to share a factor with x.
    """
    if x < 2:
        return False
    if x < 4:
        # 2 and 3 are prime; randrange(2, x - 1) below would be an empty range for them.
        return True
    for _ in range(trials):
        a = random.randrange(2, x - 1)
        if math.gcd(a, x) != 1:
            return False  # a nontrivial common factor proves x composite
        if pow(a, x - 1, x) != 1:
            return False  # a is a Fermat witness
    return True


# k: bit length of the prime field modulus p.
def RandEllCurve(k):
    """Return (p, A, B) describing a random curve y^2 = x^3 + A*x + B over F_p.

    p is a k-bit integer with p % 4 == 3 that passed 2000 Fermat trials
    (p = 3 mod 4 makes square_root's Tonelli-Shanks loop trivial, S = 1).
    A and B are uniform in [0, p-1] and are redrawn together until
    4*A**3 + 27*B**2 is nonzero mod p, i.e. the curve is non-singular.
    """
    while True:
        p = random_nbit_number(k)
        if p % 4 == 3 and fermat_test(p, 2000):
            break
    while True:
        A = random.randint(0, p - 1)
        B = random.randint(0, p - 1)
        if (4 * A ** 3 + 27 * B ** 2) % p != 0:
            break
    return (p, A, B)

# Legendre symbol: 1 if a is a nonzero square mod p, -1 if not, 0 if a is congruent to 0.
def is_square(a,p):
    """Legendre symbol (a | p) for a prime modulus p, via Euler's criterion.

    Returns:
        0 if p divides a, 1 if a is a nonzero quadratic residue mod p,
        -1 if it is a non-residue. For a non-prime p the raw value
        a**((p-1)//2) mod p is returned when it is neither 1 nor p-1.
    """
    a %= p
    if a == 0:
        return 0
    if p == 2:
        # Every nonzero residue mod 2 is a square; Euler's exponent (p-1)//2
        # would be 0 here and make pow() return 1 for every input.
        return 1
    val = pow(a, (p - 1) // 2, p)
    if val == p - 1:  # p - 1 is -1 mod p
        return -1
    return val


def factor_powers_of_2(a):
    """Split a nonzero integer a as a = Q * 2**S with Q odd; return (Q, S).

    Raises:
        ValueError: if a == 0 (zero is divisible by every power of 2, so the
            halving loop would never terminate).
    """
    if a == 0:
        raise ValueError("factor_powers_of_2: a must be nonzero")
    Q, S = a, 0
    while Q % 2 == 0:
        Q //= 2
        S += 1
    return (Q, S)



def square_root(a,p):
    """Return a square root of a modulo the prime p (Tonelli-Shanks).

    Returns:
        r in [0, p-1] with r*r % p == a % p. Returns 0 when p divides a.

    Raises:
        ValueError: "Not a square, cannot calculate root" if a is a
            quadratic non-residue mod p; a different ValueError if the
            algorithm cannot proceed (which only happens when p is not prime).
    """
    a %= p
    if a == 0:
        return 0
    if p == 2:
        return a
    half = (p - 1) // 2
    # Euler's criterion: a is a nonzero square iff a**((p-1)/2) == 1 mod p.
    if pow(a, half, p) != 1:
        raise ValueError("Not a square, cannot calculate root")
    # Write p - 1 = Q * 2**S with Q odd.
    Q, S = p - 1, 0
    while Q % 2 == 0:
        Q //= 2
        S += 1
    if S == 1:
        # p = 3 mod 4: the general loop below would return exactly this value.
        return pow(a, (p + 1) // 4, p)
    # Smallest quadratic non-residue z (Euler's criterion gives p - 1 == -1).
    for z in range(2, p):
        if pow(z, half, p) == p - 1:
            break
    else:
        raise ValueError("square_root: no quadratic non-residue found (is p prime?)")
    M = S
    c = pow(z, Q, p)
    t = pow(a, Q, p)
    R = pow(a, (Q + 1) // 2, p)
    while t != 1:
        # Least i in (0, M) with t**(2**i) == 1, found by repeated squaring.
        i, t_pow = 0, t
        while t_pow != 1 and i < M:
            t_pow = t_pow * t_pow % p
            i += 1
        if i >= M:
            raise ValueError("square_root: Tonelli-Shanks did not converge (is p prime?)")
        b = pow(c, 1 << (M - i - 1), p)
        M = i
        c = b * b % p
        t = t * c % p
        R = R * b % p
    return R

def RandomPoint(A,B,p):
    """Return a random affine point (x, y) on y^2 = x^3 + A*x + B over F_p.

    x is drawn uniformly from [0, p-1] until x^3 + A*x + B is a nonzero
    square mod p (so points with y == 0 are never produced). A fair coin
    then picks which of the two square roots becomes y (reduced into [0, p-1]).
    """
    while True:
        x = random.randint(0, p - 1)
        rhs = (pow(x, 3, p) + A * x + B) % p
        if is_square(rhs, p) == 1:
            break
    keep_positive = random.randint(0, 1)  # For choosing positive or negative root
    root = square_root(rhs, p)
    y = root if keep_positive else -root
    return (x, y % p)

def IsPtOnEll(A,B,p,x,y):
    """Return True iff (x, y) satisfies y^2 == x^3 + A*x + B modulo p."""
    return (pow(x, 3, p) + A * x + B) % p == pow(y, 2, p)

def extendedGCD(a,b):
    """Extended Euclidean algorithm.

    Returns (d, s, t) with a*s + b*t == d, where d is gcd(a, b) up to sign
    (d can be negative when the inputs are negative).
    """
    old_r, r = a, b
    old_s, s = 1, 0
    old_t, t = 0, 1
    while r != 0:
        quotient = old_r // r
        old_r, r = r, old_r % r
        old_s, s = s, old_s - quotient * s
        old_t, t = t, old_t - quotient * t
    return old_r, old_s, old_t

def multiplicative_inverse(a,m):
    """Return the inverse of a modulo m, as an integer in [0, m-1].

    For m == 1 the value 1 is returned (kept for compatibility).

    Raises:
        ValueError: if extendedGCD(a, m) does not report a gcd of 1.
    """
    gcd_value, coeff_a, _ = extendedGCD(a, m)
    if gcd_value != 1:
        raise ValueError('Numbers '+str(a)+' and '+str(m)+' are not coprime.')
    return 1 if m == 1 else coeff_a % m




def OppPt(x,y,p):
    """Return the negation of point (x, y): (x, -y mod p). x is passed through unchanged."""
    return (x, (-y) % p)


def PtSum(A,B,p,P,Q):
    """Add points P and Q on y^2 = x^3 + A*x + B over F_p (chord-and-tangent rule).

    The point at infinity is the integer 0. If either operand is 0 the other
    is returned unchanged. Otherwise the result is 0 (when Q == -P) or a
    tuple (x3, y3) with coordinates reduced into [0, p-1]. B is accepted
    for a uniform signature but the addition formulas do not need it.

    Raises:
        ValueError: propagated from multiplicative_inverse when the slope
            denominator is not invertible mod p.
    """
    if P == 0:
        return Q
    if Q == 0:
        return P
    (x1, y1), (x2, y2) = P, Q
    x1, y1 = x1 % p, y1 % p
    x2, y2 = x2 % p, y2 % p
    # P + (-P) is the point at infinity.
    if (x2, y2) == OppPt(x1, y1, p):
        return 0
    if (x1, y1) == (x2, y2):
        # Doubling: tangent slope (3*x1^2 + A) / (2*y1).
        numerator = (3 * pow(x1, 2, p) + A) % p
        denominator = (2 * y1) % p
    else:
        # Distinct points: chord slope (y2 - y1) / (x2 - x1).
        numerator = (y2 - y1) % p
        denominator = (x2 - x1) % p
    lam = numerator * multiplicative_inverse(denominator, p) % p
    x3 = (pow(lam, 2, p) - x1 - x2) % p
    y3 = (lam * (x1 - x3) - y1) % p
    return (x3, y3)

In [17]:
def MultPoint(E,P,n):
    """Compute the scalar multiple n*P on the curve E = (A, B, p).

    Uses recursive double-and-add: even n -> (n/2)*(2P), odd n -> P + (n-1)*P.
    The point at infinity is the integer 0; negative n multiplies the
    opposite point. Recursion depth grows roughly like 2*log2(|n|).
    """
    A, B, p = E
    if n == 0 or P == 0:
        return 0
    if n == 1:
        return P
    if n < 0:
        return MultPoint(E, OppPt(P[0], P[1], p), -n)
    if n % 2 == 0:
        return MultPoint(E, PtSum(A, B, p, P, P), n // 2)
    # Odd n >= 3. When (n-1)*P is 0, PtSum returns P unchanged.
    return PtSum(A, B, p, P, MultPoint(E, P, n - 1))


In [18]:
def prime_factorization(n):
    """Return the prime factors of n in non-decreasing order, with multiplicity (trial division).

    prime_factorization(1) == []. For n == 0 or negative n the result is [n] unchanged.
    """
    factors = []
    divisor = 2
    while divisor * divisor <= n:
        # Divide out every copy of this divisor before moving on.
        while n % divisor == 0:
            factors.append(divisor)
            n //= divisor
        divisor += 1
    if n != 1:
        factors.append(n)  # whatever remains is a prime larger than sqrt of the original
    return factors

In [19]:
def point_order(P,E,exact=False):
    """Find m with m*P == O (the point at infinity, encoded as 0) on E = (A, B, p).

    Baby-step giant-step over the Hasse interval: it looks for i, j with
    (p + 1)*P + i*(2s + 1)*P == j*P, |j| <= s, and returns
    m = p + 1 + (2s + 1)*i - j.

    By default m is a *multiple* of the order of P lying in the Hasse window.
    It equals #E(F_p) when the order of P exceeds 4*sqrt(p), because then only
    one multiple fits in the window. With exact=True, m is reduced to the exact
    order of P by dividing out its prime factors.

    Args:
        P: point as an (x, y) tuple, or 0 for the point at infinity.
        E: curve parameters (A, B, p).
        exact: return the exact order of P instead of a multiple (default False).

    Raises:
        RuntimeError: if no match is found (e.g. P is not on E or p is not prime).
    """
    (A2, B2, p2) = E
    if P != 0:
        # Canonical coordinates, so dictionary lookups match the reduced tuples PtSum returns.
        P = (P[0] % p2, P[1] % p2)
    s = math.ceil(math.sqrt(math.sqrt(p2)) + 1000)

    # Baby steps j*P for j = 1..s, one PtSum each instead of a fresh MultPoint per j.
    positive_steps = []
    current = 0
    for _ in range(s):
        current = PtSum(A2, B2, p2, current, P)
        positive_steps.append(current)
    # Map point -> j, keeping the first j seen while scanning +1..+s then -1..-s.
    # The point at infinity is handled separately (j = 0) below.
    baby_steps = {}
    for j, point in enumerate(positive_steps, start=1):
        if point != 0:
            baby_steps.setdefault(point, j)
    for j, point in enumerate(positive_steps, start=1):
        if point != 0:
            baby_steps.setdefault(OppPt(point[0], point[1], p2), -j)

    # Giant steps: R + i*Q for i in [lower_bound, upper_bound], built by repeated addition of Q.
    step = 2 * s + 1
    Q = MultPoint(E, P, step)
    R = MultPoint(E, P, p2 + 1)
    t = (2 * math.sqrt(p2)) / step
    lower_bound = math.floor(-1 * t)
    upper_bound = math.ceil(t)
    iQ = MultPoint(E, Q, lower_bound)
    for i2 in range(lower_bound, upper_bound + 1):
        candidate = PtSum(A2, B2, p2, R, iQ)
        if candidate == 0:
            j = 0
            break
        if candidate in baby_steps:
            j = baby_steps[candidate]
            break
        iQ = PtSum(A2, B2, p2, iQ, Q)
    else:
        raise RuntimeError(
            f"point_order: no match in the Hasse interval for P={P}, E={E} "
            "(is P on E and p prime?)"
        )
    m = p2 + 1 + step * i2 - j

    if exact and m > 0:
        # The order divides m; strip each prime factor while the quotient still kills P.
        for q in set(prime_factorization(m)):
            while m % q == 0 and MultPoint(E, P, m // q) == 0:
                m //= q
    return m

In [20]:
# Koblitz encoding base: message M is embedded at an x-coordinate in [M*k, M*k + k).
k = 1000


def EncodeMessage(A,B,p,M):
    """Embed a non-negative integer M into a point of y^2 = x^3 + A*x + B over F_p.

    Tries x = M*k, M*k + 1, ..., M*k + k - 1 and returns the first (x, y)
    whose right-hand side is a nonzero square mod p. DecodeMessage recovers
    M as x // k.

    Raises:
        ValueError: if M is negative or M*k + k > p. In that case candidate
            x-coordinates could reach p, and point arithmetic reduces x mod p,
            so decoding would yield the wrong message. Also raised if none of
            the k candidates lies on the curve.
    """
    mk = M * k
    if M < 0 or mk + k > p:
        raise ValueError(f"Message {M} cannot be encoded: need 0 <= M and M*{k} + {k} <= p = {p}")
    for x in range(mk, mk + k):
        right_side = (pow(x, 3, p) + A * x + B) % p
        if is_square(right_side, p) == 1:
            return (x, square_root(right_side, p))
    raise ValueError(f"No curve point found for message {M} among {k} candidate x-coordinates")


def DecodeMessage(A,B,p,P):
    """Recover the integer message from a point produced by EncodeMessage.

    A, B and p are unused; they are kept for a uniform signature. Exact
    integer division avoids float rounding for large coordinates.
    """
    return P[0] // k

# Key generation

In [21]:
def ElgamalPubPrivateKey(k, curve_bits=40, max_point_attempts=100):
    """Generate an EC-ElGamal key pair on a freshly generated random curve.

    Args:
        k: the base point Q is accepted only if point_order(Q, E) >= 2**(k-1).
            Inside this function the parameter shadows the module-level
            encoding base k.
        curve_bits: bit length of the field prime p (previously hard-coded to 40).
        max_point_attempts: random base points tried on one curve before that
            curve is discarded and a new one is generated. Without this, a curve
            with no suitable point would make the search loop forever.

    Returns:
        (public_key, private_key) with public_key = (A, B, p, Q, P) and
        private_key = (A, B, p, Q, P, x), where P = x*Q and x is drawn from
        [1, point_order(Q, E) - 1].

    Raises:
        ValueError: if k - 1 > curve_bits. Every point order is at most
            p + 1 + 2*sqrt(p) < 2**(curve_bits + 1), so the threshold is unreachable.

    NOTE(review): randomness comes from the `random` module, which is not a
    cryptographically secure generator. This is fine for a demo, not for real keys.
    """
    smallest_possible = 2**(k-1)
    if k - 1 > curve_bits:
        raise ValueError(
            f"k={k} is too large for a {curve_bits}-bit curve: "
            f"no point order can reach 2**{k - 1}"
        )
    while True:
        (p2, A2, B2) = RandEllCurve(curve_bits)
        E = (A2, B2, p2)
        for _ in range(max_point_attempts):
            Q = RandomPoint(A2, B2, p2)
            order = point_order(Q, E)
            if order >= smallest_possible:
                x = random.randint(1, order - 1)
                P = MultPoint(E, Q, x)
                public_key = (A2, B2, p2, Q, P)
                private_key = (A2, B2, p2, Q, P, x)
                return (public_key, private_key)
        # No suitable base point on this curve within the budget: try a new curve.

In [22]:
# Smoke test: print 50 freshly generated (public_key, private_key) pairs on 40-bit curves.
for x in range(50):
    key_pair = ElgamalPubPrivateKey(40)
    print(key_pair)

((968324707235, 484592667571, 1095738162031, (783759444836, 229261341135), (139796387005, 121753217196)), (968324707235, 484592667571, 1095738162031, (783759444836, 229261341135), (139796387005, 121753217196), 956406733901))
((610514252972, 258142696494, 864142004191, (2444275568, 219953846312), (150815222859, 634446536532)), (610514252972, 258142696494, 864142004191, (2444275568, 219953846312), (150815222859, 634446536532), 4181426027))
((239355470995, 116833239769, 767869031179, (199623428128, 119319224106), (520550353157, 236431810463)), (239355470995, 116833239769, 767869031179, (199623428128, 119319224106), (520550353157, 236431810463), 369723125286))
((301554716128, 72197296764, 623046813551, (119807514576, 66737009671), (261561826645, 307887124300)), (301554716128, 72197296764, 623046813551, (119807514576, 66737009671), (261561826645, 307887124300), 467527352407))
((328268813095, 711046525821, 850701031991, (152290197539, 115989412181), (407163179732, 506617127200)), (3282688130

# Encryption / Decryption functions

In [23]:
def ElgamalEncryption(pubKey,M):
    """Encrypt the integer message M under an EC-ElGamal public key.

    pubKey is (A, B, p, Q, P). Returns (C1, C2) = (r*Q, Pm + r*P), where
    Pm = EncodeMessage(A, B, p, M) and r is drawn from [1, point_order(Q) - 1].

    NOTE(review): point_order is recomputed on every call, which is expensive.
    """
    A, B, p, Q, P = pubKey
    curve = (A, B, p)
    nonce = random.randint(1, point_order(Q, curve) - 1)
    message_point = EncodeMessage(A, B, p, M)
    C1 = MultPoint(curve, Q, nonce)
    shared_point = MultPoint(curve, P, nonce)
    C2 = PtSum(A, B, p, message_point, shared_point)
    return (C1, C2)

def DecryptElgamal(privKey,cryptogram):
    """Decrypt an EC-ElGamal ciphertext (C1, C2) with private key (A, B, p, Q, P, x).

    Computes C2 - x*C1 and decodes the resulting point with DecodeMessage.
    x*C1 can be the point at infinity (the integer 0), for example when C1 is
    0 or when x*C1 wraps to O. The negation of 0 is 0, so that case is handled
    explicitly instead of indexing into an int.
    """
    (A, B, p, Q, P, x) = privKey
    E = (A, B, p)
    (C1, C2) = cryptogram
    xC1 = MultPoint(E, C1, x)
    # -O = O; only affine points have coordinates to negate.
    inv = 0 if xC1 == 0 else OppPt(xC1[0], xC1[1], p)
    decrypted_point = PtSum(A, B, p, C2, inv)
    return DecodeMessage(A, B, p, decrypted_point)

# Testing the encryption

In [24]:
# Round-trip a single message through key generation, encryption and decryption.
(public, private) = ElgamalPubPrivateKey(30)

to_encrypt = 67
encrypted = ElgamalEncryption(public,to_encrypt)
print(f"Plaintext: {to_encrypt}")
print(f"Encrypted: {encrypted}")
decrypted = DecryptElgamal(private,encrypted)
print(f"Decrypted: {decrypted}")
# Fail loudly instead of relying on a visual comparison of the printed values.
assert decrypted == to_encrypt, f"round trip failed: {decrypted} != {to_encrypt}"

Plaintext: 67
Encrypted: ((242610568303, 243762322854), (438009626790, 498607101354))
Decrypted: 67


In [25]:
# Round-trip the messages 1..100, each under its own freshly generated key pair.
for i, m in enumerate(range(1, 101)):
    public, private = ElgamalPubPrivateKey(30)
    encrypted = ElgamalEncryption(public, m)
    decrypted = DecryptElgamal(private, encrypted)
    print(f"m: {m},encrypted: {encrypted},decrypted: {decrypted}")
    assert decrypted==m

m: 1,encrypted: ((40486836020, 26894854117), (240360903787, 163125310898)),decrypted: 1
m: 2,encrypted: ((365514587460, 10281392387), (690298569882, 809507466080)),decrypted: 2
m: 3,encrypted: ((684228668636, 14244961661), (280811480650, 578580201419)),decrypted: 3
m: 4,encrypted: ((205235516560, 940020073841), (530625062714, 337499010764)),decrypted: 4
m: 5,encrypted: ((653933024504, 230459251642), (697732734553, 750125669188)),decrypted: 5
m: 6,encrypted: ((248660701564, 48271383880), (484245220429, 52891394147)),decrypted: 6
m: 7,encrypted: ((218415484098, 586394945994), (232476503962, 526089920110)),decrypted: 7
m: 8,encrypted: ((378710439951, 113115974623), (93224974023, 645806238257)),decrypted: 8
m: 9,encrypted: ((538920392041, 380530281088), (761290668328, 21176441830)),decrypted: 9
m: 10,encrypted: ((665934543050, 568864167673), (551627648297, 355059331245)),decrypted: 10
m: 11,encrypted: ((523413787813, 688960485916), (837109456403, 624104439631)),decrypted: 11
m: 12,encrypte