# Breaking the Merkle-Hellman knapsack cryptosystem with the low-density attack

In [None]:
from tqdm import tqdm
import time

## The Merkle-Hellman Cryptosystem

In [None]:
def generate_keypair(n):
    """Generate a Merkle-Hellman key pair for n-bit messages.

    The private key is (skseq, w, m):
      * skseq -- a strictly superincreasing sequence of n positive integers,
      * m     -- a modulus exceeding sum(skseq),
      * w     -- a multiplier coprime to m (so it is invertible mod m).
    The public key is pkseq[i] = w * skseq[i] mod m.

    Returns ((skseq, w, m), pkseq).
    """
    # Draw a_i from ((2^i - 1) 2^n, 2^i 2^n]. Since sum_{j<i} a_j <= (2^i - 1) 2^n,
    # the exclusive lower bound (+1) makes every a_i strictly larger than the
    # sum of its predecessors, and keeps a_0 >= 1.
    skseq = [randint((2**i - 1) * 2**n + 1, 2**i * 2**n) for i in range(n)]
    # sum(skseq) <= (2^n - 1) 2^n < 2^(2n), so any m > 2^(2n+1) exceeds it.
    m = randint(2**(2*n+1) + 1, 2**(2*n+2) - 1)
    # w must be invertible mod m or decryption fails. Dividing a random value
    # by its gcd with m does NOT guarantee that (18 / gcd(18, 12) = 3, which
    # still divides 12), so resample until coprime instead.
    w = randint(2, m - 2)
    while gcd(w, m) != 1:
        w = randint(2, m - 2)
    pkseq = [Mod(w * a, m).lift() for a in skseq]
    return ((skseq, w, m), pkseq)

In [None]:
def generate_msg(n):
    """Return a random n-bit message as a list of booleans (each bit a fair coin)."""
    bits = []
    for _ in range(n):
        # A draw of 0 becomes True and a draw of 1 becomes False.
        bits.append(randint(0, 1) == 0)
    return bits

In [None]:
def solve_easy_knapsack(skseq, s):
    """Solve a superincreasing knapsack greedily.

    Walks skseq from its last (largest) element down to its first, taking an
    element whenever it still fits in the remaining target s.

    Returns a list of booleans aligned with skseq (True = element taken).
    When skseq is superincreasing and s is an attainable subset sum, this
    greedy choice recovers that subset.
    """
    taken = []
    for x in reversed(skseq):
        fits = x <= s
        if fits:
            s -= x
        taken.append(fits)
    # Built largest-first; reverse once to match skseq's order instead of
    # prepending on every step (which is O(n^2)).
    taken.reverse()
    return taken

In [None]:
def encode(msg, pk):
    """Encrypt msg under public key pk.

    The ciphertext is the sum of pk[i] over every position i where msg[i] is
    truthy. As with zip(), extra entries in the longer argument are ignored.
    """
    return sum(weight for bit, weight in zip(msg, pk) if bit)

In [None]:
def decode(c, sk):
    """Decrypt ciphertext c with private key sk = (skseq, w, m).

    Removes the modular disguise, s = w^{-1} * c mod m, and then solves the
    superincreasing knapsack skseq for s.

    Returns the message as a list of booleans aligned with skseq.
    """
    skseq, w, m = sk
    w_inv = inverse_mod(w, m)
    s = Mod(w_inv * c, m).lift()
    # The original printed s here as a debugging aid; decryption is now silent.
    return solve_easy_knapsack(skseq, s)

In [None]:
def msg_to_int(msg):
    """Interpret a bit list as a big-endian binary integer.

    msg[0] is the most significant bit. Each element is converted with int(),
    so booleans and 0/1 integers both work. An empty message maps to 0.
    """
    bits = ''.join(str(int(b)) for b in msg)
    # int(..., 2) parses base-2 directly, with no eval; '' would be rejected,
    # hence the explicit empty-message case.
    return int(bits, 2) if bits else 0

## Break tests

In [None]:
def break_test(n):
    """Run one basic low-density attack on a fresh n-bit Merkle-Hellman instance.

    Generates a key pair and a random message, encrypts it, and LLL-reduces the
    basis [[I_n, pk^T], [0, -c]]. The encrypted message m satisfies
    sum(m_i * pk_i) - c = 0, so (m, 0) is a short vector of this lattice.

    Returns True iff a reduced basis vector decodes to the encrypted message.
    """
    _, pk = generate_keypair(n)
    msg = generate_msg(n)
    c = encode(msg, pk)
    M = block_matrix([[identity_matrix(n), column_matrix([pk])],
                      [zero_matrix(1, n), matrix([[-c]])]])
    for v in M.LLL():
        # LLL may output the solution (m, 0) or its negation (-m, 0): try both.
        for w in (v, -v):
            *coords, tail = w
            if tail != 0 or not all(x == 0 or x == 1 for x in coords):
                continue
            candidate = [x == 1 for x in coords]
            # Vectors with tail 0 only guarantee sum(x_i pk_i) = k*c for some
            # integer k, so confirm this is a genuine solution for c itself.
            if encode(candidate, pk) == c:
                return candidate == msg
    return False

In [None]:
def break_test2(n):
    """Run one CJLOS-style low-density attack on a fresh n-bit instance.

    Uses the basis with rows (e_i, lmbd*pk_i) for i < n plus the extra row
    (1/2, ..., 1/2, lmbd*c). The encrypted message m gives the lattice vector
    sum(m_i * row_i) - last_row = (m - 1/2, 0), whose entries are all +-1/2.

    Returns True iff a reduced basis vector decodes to the encrypted message.
    """
    _, pk = generate_keypair(n)
    msg = generate_msg(n)
    c = encode(msg, pk)
    # The weight must satisfy lmbd > sqrt(n)/2 strictly. ceil() gives equality
    # when sqrt(n)/2 is an integer (e.g. n = 16), so take floor + 1 instead.
    lmbd = int(0.5*sqrt(n)) + 1
    TL = identity_matrix(n)
    TR = column_matrix([[lmbd*x for x in pk]])
    BL = ones_matrix(1, n) * (1/2)
    BR = matrix([[lmbd*c]])
    M = block_matrix([[TL, TR], [BL, BR]])
    half = 1/2
    for v in M.LLL():
        *coords, tail = v
        if tail != 0 or not all(x == half or x == -half for x in coords):
            continue
        # LLL may return +(m - 1/2, 0) or its negation. The two signs decode
        # to m and to its bitwise complement, so keep whichever candidate
        # actually re-encrypts to c.
        for bits in ([x + half for x in coords], [half - x for x in coords]):
            candidate = [bool(b) for b in bits]
            if encode(candidate, pk) == c:
                return candidate == msg
    return False

In [None]:
# Smoke test: one CJLOS attack on a 15-bit instance; displays True when the
# encrypted message was recovered.
break_test2(15)

In [None]:
# Success counts for the CJLOS attack: R[n] is how many of TRIALS fresh n-bit
# instances break_test2 broke. The loop variable n is intentionally global:
# the "single break test" cells below reuse its final value.
TRIALS = 100
R = dict()
for n in tqdm(range(10, 50)):  # 197, 203
    R[n] = sum(1 for _ in range(TRIALS) if break_test2(n))

In [None]:
# One line per knapsack size: "<n> <number of successful attacks>".
for n in range(10, 50):
    successes = R[n]
    print(n, successes)

## A single break test

In [None]:
# Fresh key pair using the global n left over from the loops above (49 after
# they run); display the public sequence.
sk, pk = generate_keypair(n)
pk

In [None]:
# Private key: (skseq, w, m).
sk

In [None]:
# Random n-bit message, displayed as an integer (msg[0] is the most significant bit).
msg = generate_msg(n)
msg_to_int(msg)

In [None]:
# Ciphertext: sum of the public weights selected by msg.
c = encode(msg, pk)
c

In [None]:
# Legitimate decryption with the private key; should match msg_to_int(msg) above.
decmsg = decode(c, sk)
msg_to_int(decmsg)

In [None]:
# Lattice basis, top-left block: the n x n identity.
TL = identity_matrix(n)
TL

In [None]:
# Top-right block: the public weights as an n x 1 column.
TR = column_matrix([pk])

In [None]:
# Bottom-left block: a 1 x n row of zeros.
BL = zero_matrix(1, n)


In [None]:
# Bottom-right entry: -c, so adding the message's rows to this row gives a
# vector whose last coordinate is sum(msg_i * pk_i) - c = 0.
BR = matrix([[-c]])
BR

In [None]:
# Assemble the (n+1) x (n+1) basis [[I, pk^T], [0, -c]].
M = block_matrix([[TL, TR], [BL, BR]])
M

In [None]:
# LLL-reduce the basis and look for the short vector (msg, 0).
SV = M.LLL()
print(SV)
for v in SV:
    # LLL may return the solution vector or its negation, so test both signs.
    # A genuine solution has last coordinate 0 and 0/1 message coordinates.
    for w in (v, -v):
        *coords, tail = w
        if tail == 0 and all(x == 1 or x == 0 for x in coords):
            print(w)
            msg_orig = int(''.join(str(int(x)) for x in coords), 2)
            print(msg_orig)