In [7]:
import hashlib
import os
import random
import unittest

from key import fe, ECPubKey, P_256, P_256_G, P_256_ORDER


def forward_map(u):
    """Map a field element to an affine point built from the P_256 constants.

    Two candidate x-coordinates are derived from u; the first one whose
    curve right-hand side x^3 + a*x + b is a square is kept, otherwise the
    second candidate (alpha times the first) is used.

    Parameters:
        u (of type fe) : any field element
    Returns:
        fe, fe : affine X and Y coordinates; Y is negated when needed so
                 that its parity agrees with the parity of u
    """
    curve_a = fe(P_256.a)
    curve_b = fe(P_256.b)
    # Exponent used to take the square root of the right-hand side
    # (presumably valid because P_256.p is 3 mod 4 -- confirm against key.py).
    root_exponent = (P_256.p + 1) // 4

    alpha = -(u ** 2)
    V = alpha * alpha + alpha

    # First candidate abscissa and its curve right-hand side.
    x = -curve_b / curve_a * (fe(1) + fe(1) / V)
    rhs = x * x * x + curve_a * x + curve_b
    if not rhs.is_square():
        # Fall back to the second candidate.
        x = alpha * x
        rhs = x * x * x + curve_a * x + curve_b

    y = rhs ** root_exponent
    if y.is_odd() != u.is_odd():
        y = -y
    return x, y

def reverse_map(x, y, i):
    """Reverse mapping function

    Parameters:
        x, y (of type fe) : X and Y coordinates of a point on the P-256 curve
        i                 : integer in range [0,3] selecting the formula
    Returns:
        u (of type fe) : such that forward_map(u) = (x,y), or None.

        - There can be up to 4 such inverses, and i selects which formula to use.
        - i = 0, 1 invert the first candidate of forward_map (x == X2);
          i = 2, 3 invert the second candidate (x == X3 == alpha * X2).
        - Each i can independently from other i values return a value or None.
        - Two i values of the same pair return the same u only when the
          corresponding discriminant is zero.
    Raises:
        ValueError : if i is not one of 0, 1, 2, 3.

    Note: the degenerate input a*x + b == 0 is not handled specially.
    """
    if i not in (0, 1, 2, 3):
        raise ValueError("i must be in range [0,3]")

    a = fe(P_256.a)
    b = fe(P_256.b)
    t = a * x + b
    second_candidate = i >= 2

    if not second_candidate:
        # x == X2 = -b/a * (1 + 1/V) with V = alpha^2 + alpha gives
        # V = -b/t, so alpha^2 + alpha + b/t = 0 and, with u^2 = -alpha,
        # u^2 = (1 -/+ sqrt(1 - 4*b/t)) / 2.
        delta = fe(1) - fe(4) * b / t
        if not delta.is_square():
            return None
        root = delta.sqrt()
        numerator = fe(1) - root if i == 0 else fe(1) + root
        v = numerator / fe(2)
    else:
        # x == X3 = alpha * X2 = -b/a * (alpha^2 + alpha + 1) / (alpha + 1)
        # gives b*alpha^2 + t*alpha + t = 0 and, with u^2 = -alpha,
        # u^2 = (t -/+ sqrt(t^2 - 4*b*t)) / (2*b).
        delta = t * t - fe(4) * b * t
        if not delta.is_square():
            return None
        root = delta.sqrt()
        numerator = t - root if i == 2 else t + root
        v = numerator / (fe(2) * b)

    # v is the required value of u^2; it must itself be a square.
    if not v.is_square():
        return None
    u = v.sqrt()
    # forward_map forces the parity of Y to equal the parity of u.
    if y.is_odd() != u.is_odd():
        u = -u

    if second_candidate:
        # forward_map only uses X3 when the right-hand side at X2 is not a
        # square; otherwise this u would map to X2 instead of x.
        alpha = -(u ** 2)
        V = alpha * alpha + alpha
        X2 = -b / a * (fe(1) + fe(1) / V)
        h2 = X2 * X2 * X2 + a * X2 + b
        if h2.is_square():
            return None

    return u

def reverse_map_list(x, y):
    """Collect the preimages of (x, y) under forward_map.

    Parameters:
        x, y (of type fe) : X and Y coordinates of a curve point
    Returns:
        list of fe : every u found such that forward_map(u) = (x, y), in the
                     order: the two roots for the first forward_map
                     candidate (x == X2), then the two roots for the second
                     candidate (x == X3 == alpha * X2). May be empty.

    Note: the degenerate input a*x + b == 0 is not handled specially.
    """
    a = fe(P_256.a)
    b = fe(P_256.b)
    t = a * x + b
    preimages = []

    # Case x == X2: u^2 = (1 -/+ sqrt(1 - 4*b/t)) / 2.
    delta1 = fe(1) - fe(4) * b / t
    if delta1.is_square():
        root = delta1.sqrt()
        for v in ((fe(1) - root) / fe(2), (fe(1) + root) / fe(2)):
            if not v.is_square():
                continue
            u = v.sqrt()
            # forward_map forces the parity of Y to equal the parity of u.
            if y.is_odd() != u.is_odd():
                u = -u
            preimages.append(u)

    # Case x == X3: b*alpha^2 + t*alpha + t = 0 with u^2 = -alpha, hence
    # u^2 = (t -/+ sqrt(t^2 - 4*b*t)) / (2*b).
    delta2 = t * t - fe(4) * b * t
    if delta2.is_square():
        root = delta2.sqrt()
        two_b = fe(2) * b
        for v in ((t - root) / two_b, (t + root) / two_b):
            if not v.is_square():
                continue
            u = v.sqrt()
            if y.is_odd() != u.is_odd():
                u = -u
            # Keep u only if forward_map would really pick X3, i.e. the
            # right-hand side at X2 is not a square.
            alpha = -(u ** 2)
            V = alpha * alpha + alpha
            X2 = -b / a * (fe(1) + fe(1) / V)
            h2 = X2 * X2 * X2 + a * X2 + b
            if not h2.is_square():
                preimages.append(u)

    return preimages
    
def encode(P, randombytes):
    """Encode point P as field elements (u, v) with decode(u, v) == P.

    Finds u, v such that forward_map(u) + forward_map(v) = P by repeatedly
    drawing a pseudo-random u, computing Q = P - forward_map(u) and trying
    one randomly selected reverse_map branch on Q.

    Parameters:
        P           : point as a tuple whose first two entries are the
                      integer X and Y coordinates (passed to P_256.affine)
        randombytes : bytes mixed into every hash as the randomness source
    Returns:
        fe, fe : the pair (u, v)
    """
    # A 32-byte branch hash holds 128 two-bit branch selectors.
    selectors_per_hash = 128
    count = 0
    branch_hash = None
    while True:
        # Every hash is
        # SHA256("SWU_P-256_encode\x00" + uint32{count} + randombytes + X + byte{Y & 1}).
        m = hashlib.sha256()
        m.update(b"SWU_P-256_encode\x00")
        m.update(count.to_bytes(4, 'little'))
        m.update(randombytes)
        m.update(P[0].to_bytes(32, byteorder='big'))
        m.update((P[1] & 1).to_bytes(1, 'big'))
        digest = m.digest()

        # Counter values 0, 129, 258, ... provide a fresh branch hash; the
        # 128 values after each one consume its selectors in order. Without
        # the refresh, the 129th attempt would index past the 32-byte hash.
        position = count % (selectors_per_hash + 1)
        count += 1
        if position == 0:
            branch_hash = digest
            continue
        slot = position - 1

        u = fe(int.from_bytes(digest, 'big'))
        mapped_u = forward_map(u)
        # convert to jacobian form for EC operations
        ge = (mapped_u[0].val, mapped_u[1].val, 1)
        T = P_256.negate(ge)
        Q = P_256.add(P_256.affine(P), P_256.affine(T))
        Q = P_256.affine(Q)
        if P_256.is_infinity(Q):
            # Mirrors decode(), which falls back to forward_map(u) when the
            # sum of the two mapped points is the point at infinity.
            Q = T
        j = (branch_hash[slot >> 2] >> ((slot & 3) << 1)) & 3  # 0~3 randomness
        x, y, z = Q
        v = reverse_map(fe(x), fe(y), j)
        if v is None:
            continue

        # Verify the candidate pair before returning it.
        x1, y1 = mapped_u
        x2, y2 = forward_map(v)
        Sum = P_256.add((x1.val, y1.val, 1), (x2.val, y2.val, 1))
        Sum = P_256.affine(Sum)
        if P[0] == Sum[0] and P[1] == Sum[1]:
            return u, v

def decode(u, v):
    """Recover the point represented by the pair (u, v).

    Both field elements are sent through forward_map and the two resulting
    points are added. If the sum is the point at infinity, the image of u
    alone is used instead.

    Returns:
        the result of P_256.affine on the selected point
    """
    mapped_u = forward_map(u)
    mapped_v = forward_map(v)
    # Lift both affine results to (x, y, 1) triples for the EC operations.
    T, S = ((pt[0].val, pt[1].val, 1) for pt in (mapped_u, mapped_v))
    total = P_256.add(T, S)
    if P_256.is_infinity(total):
        total = T
    return P_256.affine(total)

import secrets

def generate_random_bytes(num_bytes):
    """Return num_bytes random bytes drawn from the secrets module."""
    fresh = secrets.token_bytes(nbytes=num_bytes)
    return fresh

def encode_bytes(x, P, LT):
    """Hide x among its residue class: return x + k*P for a random k.

    k is drawn uniformly from all multipliers that keep the result below
    P * 2**LT, so decode_bytes (enc % P) recovers x whenever 0 <= x < P.

    Parameters:
        x  : integer to encode
        P  : modulus
        LT : number of extra bits; the result is below P * 2**LT
    Returns:
        int : x + k*P
    Raises:
        ValueError : (from secrets.randbelow) if x >= P * 2**LT.
    """
    span = P * (2 ** LT) - x
    # Ceiling division: counts every k >= 0 with x + k*P < P * 2**LT.
    # Floor division would drop the largest valid k whenever P does not
    # divide span (i.e. for every 0 < x < P).
    choices = -(-span // P)
    k = secrets.randbelow(choices)
    return x + k * P

def decode_bytes(enc, P, LT):
    """Undo encode_bytes: reduce enc modulo P (LT is accepted but unused)."""
    _, residue = divmod(enc, P)
    return residue

# LT: value passed as the LT argument of encode_bytes / decode_bytes.
LT = 32
# LP: bit length of the P_256 field prime.
LP = P_256.p.bit_length()
print(LP)

256
(12634124199229297121735045049902685327494650747802041101543265687219455220023, 67523986399000868244724969570677257032788302873594047388871657288506283485449, 1)
(12634124199229297121735045049902685327494650747802041101543265687219455220023, 67523986399000868244724969570677257032788302873594047388871657288506283485449, 1)


In [None]:
# m = secrets.randbelow(P_256_ORDER)
# A = P_256.affine(P_256.mul([(P_256_G, m)]))
# ge = (A[0], A[1], A[2])
# ge = P_256.affine(ge)
# print(ge)
# random_bytes = generate_random_bytes(32)
# u, v = encode(ge, random_bytes)
# P1 = decode(u, v)
# print(P1)

In [8]:
# m = secrets.randbelow(P_256_ORDER)
# A = P_256.affine(P_256.mul([(P_256_G, m)]))
# ge = (A[0], A[1], A[2])
# random_bytes = generate_random_bytes(32)

# u, v = encode(ge, random_bytes)
# u_ = encode_bytes(u.val, P_256.p, LT)
# v_ = encode_bytes(v.val, P_256.p, LT)
# print(ge)
# u = decode_bytes(u_, P_256.p, LT)
# v = decode_bytes(v_, P_256.p, LT)
# x, y, _ = decode(fe(u), fe(v))

# print(x,y)

(52860767161461039356405934751706078651960061079464608163674770690446641311140, 86327694312228102062771838221269177718629684539738179013090808699983711922820, 1)
52860767161461039356405934751706078651960061079464608163674770690446641311140 86327694312228102062771838221269177718629684539738179013090808699983711922820


In [None]:
# Generate 200000 random points, encode each as a (u, v) pair, pad both
# values with encode_bytes and dump everything as consecutive 36-byte
# big-endian strings into 'output_file'.
All = []
for i in range(200000):
    # Random scalar multiple of the generator, converted by P_256.affine.
    m = secrets.randbelow(P_256_ORDER)
    A = P_256.affine(P_256.mul([(P_256_G, m)]))
    ge = (A[0], A[1], A[2])
    random_bytes = generate_random_bytes(32)
    u, v = encode(ge, random_bytes)
    u_ = encode_bytes(u.val, P_256.p, LT)
    v_ = encode_bytes(v.val, P_256.p, LT)
    # 36 bytes = 288 bits (presumably LP + LT bits -- confirm).
    All.extend([u_.to_bytes(36, 'big'), v_.to_bytes(36, 'big')])
with open('output_file', 'wb') as file:
    file.writelines(All)