In [1]:
from sage.groups.generic import bsgs
from collections import namedtuple
from Crypto.Util.number import inverse, bytes_to_long
import hashlib
import random

from math import ceil, sqrt

# Affine curve points are plain (x, y) named tuples
Point = namedtuple("Point", ["x", "y"])
# Sentinel string standing in for the point at infinity (the group identity);
# code compares against it with ==
O = "INFINITY"

def is_on_curve(P):
    """Return True if P is O or a reduced point satisfying y^2 = x^3 + a*x + b (mod p).

    Reads the module-level curve parameters a, b and p. Coordinates must
    already be reduced into [0, p).
    """
    if P == O:
        # the point at infinity is always on the curve
        return True
    residue = (P.y**2 - (P.x**3 + a*P.x + b)) % p
    return residue == 0 and 0 <= P.x < p and 0 <= P.y < p

def point_inverse(P):
    """Return -P: the reflection (x, -y mod p); O is its own inverse."""
    return P if P == O else Point(P.x, (-P.y) % p)

def point_addition(P, Q):
    """Return P + Q using the affine chord-and-tangent group law.

    O is the identity and P + (-P) = O. Uses the module-level curve
    parameters a and p, and asserts that the result lies on the curve.
    """
    # identity cases
    if P == O:
        return Q
    if Q == O:
        return P
    # P + (-P) = O; this also covers doubling a point whose y is 0,
    # because such a point equals its own inverse
    if Q == point_inverse(P):
        return O

    if P == Q:
        # tangent slope (3x^2 + a) / (2y)
        slope = (3*P.x**2 + a)*inverse(2*P.y, p) % p
    else:
        # chord slope (yQ - yP) / (xQ - xP)
        slope = (Q.y - P.y) * inverse((Q.x - P.x), p) % p

    x3 = (slope**2 - P.x - Q.x) % p
    y3 = (slope*(P.x - x3) - P.y) % p
    result = Point(x3, y3)
    assert is_on_curve(result)
    return result

def point_multiply(P, d):
    """Return the scalar multiple d*P by left-to-right double-and-add.

    P -- a Point or O.
    d -- an integer scalar. d == 0 yields O; negative d is computed as
         (-d) * (-P).
    The result is asserted to lie on the curve.
    """
    if d < 0:
        # bin() of a negative number starts with '-0b', so slicing off two
        # characters would leave a stray 'b' and silently drop the sign.
        return point_multiply(point_inverse(P), -d)
    Q = O
    for bit in bin(d)[2:]:
        Q = point_addition(Q, Q)      # double
        if bit == '1':
            Q = point_addition(Q, P)  # add
    assert is_on_curve(Q)
    return Q

# Curve y^2 = x^3 + a*x + b over GF(p)
p = 9631668579539701602760432524602953084395033948174466686285759025897298205383
a = 9605275265879631008726467740646537125692167794341640822702313056611938432994
b = 7839838607707494463758049830515369383778931948114955676985180993569200375480

# Base point G
gx = 5664314881801362353989790109530444623032842167510027140490832957430741393367
gy = 3735011281298930501441332016708219762942193860515094934964869027614672869355
G = Point(gx, gy)

# Public points given by the challenge
A = Point(3829488417236560785272607696709023677752676859512573328792921651640651429215,
          7947434117984861166834877190207950006170738405923358235762824894524937052000)
B = Point(9587224500151531060103223864145463144550060225196219072827570145340119297428,
          2527809441042103520997737454058469252175392602635610992457770946515371529908)

# Hex-encoded ciphertext
enc = "1536c5b019bd24ddf9fc50de28828f727190ff121b709a6c63c4f823ec31780ad30d219f07a8c419c7afcdce900b6e89b37b18b6daede22e5445eb98f3ca2e40"

In [2]:
# Recover the curve coefficients from the two public points A and B.
# Both satisfy y^2 = x^3 + a*x + b (mod p); subtracting the two equations gives
#   (yA^2 - xA^3) - (yB^2 - xB^3) = a * (xA - xB)   (mod p)
# so a follows from a modular division, and b from either point.
# The division uses an explicit modular inverse instead of relying on Sage's
# Rational `% p` behaviour (in plain Python, `/` would produce a float).
a = ((A.y**2 - A.x**3) - (B.y**2 - B.x**3)) * inverse((A.x - B.x) % p, p) % p
print(a)
b = (A.y**2 - A.x**3 - a*A.x) % p
print(b)
# Both public points must lie on the recovered curve
assert is_on_curve(A) and is_on_curve(B)

9605275265879631008726467740646537125692167794341640822702313056611938432994
7839838607707494463758049830515369383778931948114955676985180993569200375480


In [3]:
O.__class__  # sanity check: the point-at-infinity sentinel is a plain str

<class 'str'>

In [4]:
# Pohlig–Hellman setup: take the group order to be n = p - 1 and factor it.
# NOTE(review): this assumes the order of G divides p - 1 — confirm with
# EllipticCurve(GF(p), [a, b]).order() if in doubt.
n = p - 1
fs = factor(n)
print(fs)

2 * 2329468847 * 2414146711 * 2484441769 * 2546315801 * 2988745687 * 3048801089 * 3618313243 * 4105685383


In [5]:
def my_bsgs(a, b, n):
    """Baby-step giant-step discrete log: find x in [0, n) with x*a == b.

    a -- base point (note: shadows the global curve coefficient `a` locally;
         the group operations still use the global one)
    b -- target point (likewise shadows the global `b`)
    n -- upper bound on the log, typically the order of `a`
    Returns the smallest such x and prints it, or prints a message and
    returns None when no solution exists.
    """
    # m = ceil(sqrt(n)). math.sqrt works in floating point and can round
    # down for large n, which would leave part of [0, n) unreachable.
    m = ceil(sqrt(n))
    while m * m < n:
        m += 1

    # Baby steps: table maps j*a -> j for 0 <= j < m. If a's order is smaller
    # than m the points repeat; keep the first (smallest) j.
    table = {}
    prevPoint = O
    for j in range(m):
        if prevPoint not in table:
            table[prevPoint] = j
        prevPoint = point_addition(prevPoint, a)

    # Giant steps: walk gama = b - i*m*a and look it up in the table.
    gama = b
    constant_m_alpha = point_inverse(point_multiply(a, m))
    for i in range(m):
        if gama in table:
            print("found: " + str(i*m+table[gama]))
            return i*m+table[gama]
        gama = point_addition(gama, constant_m_alpha)

    print("Solution Not Found")
    return None
    

In [6]:
# Pohlig–Hellman: for each prime-power factor q^e of n, project G and A into
# the subgroup of order q^e and solve the small discrete log there with BSGS.
xis = []

for q, e in fs:
    qe = q**e
    gi = point_multiply(G, n // qe)
    pi = point_multiply(A, n // qe)  # NOTE: shadows Sage's constant `pi`
    di = my_bsgs(gi, pi, qe)
    if di is None:
        # fail here rather than letting None reach crt() later
        raise ValueError("no discrete log found modulo %s" % qe)
    xis.append(di)

found: 1
found: 1109856587
found: 1432808300
found: 2086362528
found: 1831216898
found: 2399145070
found: 3043431419
found: 1439500374
found: 3507738586


In [7]:
# Combine the per-factor logs with the CRT to get the private scalar da.
# Moduli are the prime powers q^e of the factorization.
moduli = [q**e for q, e in fs]
print(xis)
print(moduli)
da = crt(xis, moduli)
print(da)
# The recovered scalar must reproduce the public point A; fail loudly if not
assert point_multiply(G, da) == A, "CRT result does not satisfy da*G == A"

[1, 1109856587, 1432808300, 2086362528, 1831216898, 2399145070, 3043431419, 1439500374, 3507738586]
[2, 2329468847, 2414146711, 2484441769, 2546315801, 2988745687, 3048801089, 3618313243, 4105685383]
6348321811382313695959733425907835828187085177602967426242446915729082292743


In [8]:
encode_flag = bytes.fromhex(enc)
k = point_multiply(B, da).x
k = hashlib.sha512(str(k).encode('ascii')).digest()
flag = bytes(ci ^^ ki for ci, ki in zip(encode_flag, k))
print('flag =', flag)

flag = b'FLAG{adbffefdb46a99fad0042dd3c10fdc414fadd25c}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
