In [1]:
from math import log2, inf
from math import comb as binom
from scipy.stats import bernoulli
import numpy as np
import matplotlib.pyplot as plt

#install from https://github.com/Crypto-TII/syndrome_decoding_estimator
from sd_estimator.estimator import stern_complexity

In [None]:
def less_than_x(n,x,p):
    """Return P[X < x] for X ~ Binomial(n, p).

    The probability mass is summed over i = 0, ..., x-1, so x itself is
    excluded. `^` is exponentiation here (Sage preparser syntax).
    """
    q=1-p
    return sum(binom(n,i)*q^(n-i)*p^i for i in range(x))

In [None]:
def less_than_x_asym(n,x,p1,p2):
    """Return P[X1 + X2 < x] for independent X1 ~ Bin(n//2, p1), X2 ~ Bin(n//2, p2).

    For every value i of X1 below x, the mass P[X1 = i] is multiplied by
    P[X2 < x - i]. `^` is exponentiation here (Sage preparser syntax).
    """
    half=n//2

    def pmf(count,prob):
        # binomial probability mass of `count` successes out of `half` trials
        return binom(half,count)*(1-prob)^(half-count)*prob^count

    return sum(pmf(i,p1)*sum(pmf(o,p2) for o in range(x-i)) for i in range(x))

### General functions

In [None]:
#compute list of binary vectors of length n and weight k
def GospersHack(k,n):
    """Return all integers below 2^n with exactly k set bits, in increasing order.

    Each integer encodes a binary vector of length n and Hamming weight k.
    Successive values are produced with Gosper's hack.

    k == 0 returns [0] (the single weight-0 vector); k > n returns [].
    """
    if k==0:
        # the update step below would divide by the lowest set bit, which is 0 here
        return [0]
    L_vecs=[]
    limit=1<<n
    setb=(1<<k)-1
    while setb<limit:
        L_vecs.append(setb)
        c=setb&-setb    # lowest set bit of the current pattern
        r=setb+c        # carry that bit into the next free position
        # floor division keeps the step exact for every n (no float/rational detour)
        setb=int(((r^^setb)>>2)//c)|int(r)
    return L_vecs

#compute all binary vectors of length n with weight at most k
def all_gosps(k,n):
    """Return all n-bit integers of Hamming weight at most k.

    The list starts with 0 and then holds the GospersHack output for the
    weights 1, 2, ..., k in that order.
    """
    L=[0]
    for wt in range(1,k+1):
        L.extend(GospersHack(wt,n))
    return L

#Hamming weight
def weight(x):
    """Return the Hamming weight (number of set bits) of the non-negative integer x."""
    ones=0
    while x!=0:
        # x & (x-1) clears the lowest set bit, so the loop runs once per set bit
        x&=x-1
        ones+=1
    return ones

#plots two lists Ls=[[xi,yi]], Ll=[[vi,wi]]
def plot_lists(Ls,Ll):
    """Scatter-plot two point lists Ls=[[xi,yi]] and Ll=[[vi,wi]] on the current axes.

    Ls is drawn first, Ll second; only the first two entries of every point are used.
    """
    # split both inputs into coordinate lists before anything is drawn
    series=[([pt[0] for pt in pts],[pt[1] for pt in pts]) for pts in (Ls,Ll)]
    for xs,ys in series:
        plt.scatter(xs,ys)
    
def avg_lists(L):
    """Average the y-values of consecutive points that share the same x-value.

    L is a list of points [x, y]. Every run of adjacent points with equal x is
    collapsed into one point [x, mean of the run's y-values]; runs are detected
    by adjacency only, so equal x-values that are not neighbours stay separate.
    """
    avg=[]
    total=0
    n_points=0
    last=len(L)-1
    for pos,point in enumerate(L):
        total+=point[1]
        n_points+=1
        # close the run at the end of the list or when the next x differs
        if pos==last or point[0]!=L[pos+1][0]:
            avg.append([point[0],total/n_points])
            total=n_points=0
    return avg

### Functions for error experiments in standard format

In [None]:
#generates erroneous key in the standard format
def gen_cb_instance_long_rep(k,w,p1,p2):
    """Generate a key in the standard (bit-vector) format together with a noisy copy.

    k  -- block length; the key is a vector over GF(2) of length 2*k
    w  -- number of ones, placed at uniformly random distinct positions
    p1 -- probability that a one-entry is flipped
    p2 -- probability that a zero-entry is flipped

    Returns (L, L_f): the error-free key and its erroneous version.
    """
    x=bernoulli(p1)
    y=bernoulli(p2)

    # choose the support: shuffle all positions and take the first w
    positions=list(range(2*k))
    shuffle(positions)
    L=zero_vector(GF(2),2*k)
    for i in range(w):
        L[positions[i]]=1

    # one pass: copy every entry and add its flip bit (x for ones, y for zeros)
    L_f=zero_vector(GF(2),2*k)
    for i in range(2*k):
        flip=x.rvs(1)[0] if L[i] else y.rvs(1)[0]
        L_f[i]=L[i]+flip

    return L,L_f

In [None]:
#outputs the number of missing one entries in sf which are present in s
def missing_ones_long_rep(s,sf):
    """Count the positions that are one in s but zero in sf."""
    count=0
    for idx,bit in enumerate(s):
        if bit and not sf[idx]:
            count+=1
    return count

#iterations of prange ISD on code with n=2*k with c candidates fixed in the information set 
def comp(k,c,wp):
    """Return log2( C(2k-c, wp) / C(k-c, wp) ).

    Per the notebook this is the log2 number of Prange-ISD iterations on a code
    of length 2*k when c candidate positions are fixed in the information set
    and wp ones remain to be found.
    """
    all_placements=binom(2*k-c,wp)
    good_placements=binom(k-c,wp)
    return log2(all_placements/good_placements)

In [None]:
#expected number of iterations of Prange ISD to recover BIKE secret key from erroneous key with error probabilities p1 and p2
def complexity_long_rep(k,w,p1,p2):
    """Sample one noisy standard-format key and return comp() for it.

    A fresh key pair is drawn with gen_cb_instance_long_rep. The Hamming weight
    of the noisy key is used as the candidate count c, and the number of ones
    lost by the noise as wp.
    """
    key,noisy_key=gen_cb_instance_long_rep(k,w,p1,p2)
    wp=missing_ones_long_rep(key,noisy_key)
    c=noisy_key.hamming_weight()
    return comp(k,c,wp)

### Functions for Compact Format

In [None]:
#Generate erroneous BIKE secret key in compact format
def gen_cb_instance(k,w,p1,p2):
    """Generate a key in compact (index-list) format together with a noisy copy.

    k  -- block length; every index is stored on ceil(log2(k)) bits
    w  -- number of indices in the key
    p1 -- probability that a one-bit of an index is flipped
    p2 -- probability that a zero-bit of an index is flipped

    Returns (L, L_f): the list of w error-free indices and the list of their
    erroneous versions, in the same order.
    """
    bits=int(ceil(log2(k)))
    x=bernoulli(p1)
    y=bernoulli(p2)

    # randint includes both endpoints; valid positions in a block are 0..k-1
    L=[randint(0,k-1) for _ in range(w)]

    L_f=[]
    for i in L:
        val=i
        for j in range(bits):
            mask=1<<j
            # bit j is only touched in iteration j, so the test sees the original bit
            if val & mask:
                val^^=(x.rvs(1)[0]<<j)
            else:
                val^^=(y.rvs(1)[0]<<j)
        L_f.append(val)
    return L,L_f

In [None]:
#number iterations of Prange-ISD with two blocks of different weight
def time_2_blocks(k,c,w,wp):
    """Minimise the two-block Prange cost over the split p; return [cost, p].

    For every p in range(w-wp, min(c+1, k-wp-1)) the cost is
        log2 C(c, w-wp) + log2 C(2k-c, wp) - log2 C(p, w-wp) - log2 C(k-p, wp).
    The first p reaching the minimum is returned as [cost, p]; if the range is
    empty the result is [inf].

    NOTE: a later cell defines time_2_blocks again (returning only the cost).
    Once both cells have run, that later definition is the one in effect.
    """
    found=w-wp

    def cost(p):
        return log2(binom(c,found))+log2(binom(2*k-c,wp))-log2(binom(p,found))-log2(binom(k-p,wp))

    scored=([cost(p),p] for p in range(found,min(c+1,k-wp-1)))
    return min(scored,key=lambda entry: entry[0],default=[inf])

In [None]:
#number iterations of Prange-ISD with two blocks of different weight
def time_2_blocks(k,c,w,wp):
    """Return the minimal two-block Prange cost (log2) over the split p.

    For every p in range(w-wp, min(c+1, k-wp-1)) the cost is
        log2 C(c, w-wp) + log2 C(2k-c, wp) - log2 C(p, w-wp) - log2 C(k-p, wp).
    The smallest cost is returned; if the range is empty the result is inf.
    """
    found=w-wp
    costs=(
        log2(binom(c,found))+log2(binom(2*k-c,wp))-log2(binom(p,found))-log2(binom(k-p,wp))
        for p in range(found,min(c+1,k-wp-1))
    )
    return min(costs,default=inf)

In [None]:
#generate the set of candidate key indices based on the erroneous key indices Lf.
#Consider all x = y + e (bitwise XOR), where y in Lf and weight(e)<=maxw.
#Include only those x<k for which there exists a y in Lf such that Pr[x in L | y in Lf]>threshold.
def gen_candidates(maxw,threshold,bits,L_f,k,p1=None,p2=None):
    """Build the set of candidate key positions from the erroneous compact key L_f.

    For every noisy index y in L_f and every error pattern e of weight at most
    maxw, the value x = y XOR e is scored with the probability of observing y
    when the true index is x:
        (1-p2)^n00 * p2^n01 * p1^n10 * (1-p1)^n11
    where n_ab counts the bit positions with true bit a and observed bit b.
    x is kept if the score exceeds threshold and x < k.

    maxw      -- maximum Hamming weight of the error patterns
    threshold -- minimum score for a candidate to be kept
    bits      -- bit length of a stored index
    L_f       -- erroneous indices; entries from the midpoint on belong to the
                 second block and are stored shifted by k
    k         -- block length
    p1, p2    -- flip probabilities of one-bits / zero-bits. If omitted, the
                 notebook-level names p1 / p2 are used (the behaviour before
                 these parameters existed).

    Returns a set of positions in {0, ..., 2*k-1}.
    """
    if p1 is None:
        p1=globals()["p1"]
    if p2 is None:
        p2=globals()["p2"]

    Lvecs=all_gosps(maxw,bits)
    L_candidates=set()
    mask=(1<<bits)-1

    for l,i in enumerate(L_f):
        # same split as missing() and gen_erasure_candidates(): index < len/2 -> first block
        second_block=2*l>=len(L_f)
        for e in Lvecs:
            j=e^^i
            if not j<k:
                continue
            n11=weight(j&i)
            n00=weight((j^^mask)&(i^^mask))
            n10=weight(j&(i^^mask))
            n01=weight((j^^mask)&i)

            score=(1-p2)^n00*p2^n01*p1^n10*(1-p1)^n11
            if score>threshold:
                if second_block:
                    L_candidates.add(j+k)
                else:
                    L_candidates.add(j)
    return L_candidates

#determine number of missing one-indices of the BIKE secret key (L) in the list of candidates Lcand
def missing(Lcand,L,k):
    """Count the key indices of L that do not appear in Lcand.

    Indices in the first half of L are looked up as they are, indices in the
    second half are looked up shifted by k.
    """
    miss=0
    for i in range(len(L)):
        z=L[i] if 2*i<len(L) else L[i]+k
        if z not in Lcand:
            miss+=1
    return miss

#recursive function to determine the expected size of the union of (i+1) random sets of size L0 containing elements from {1,...,k}
def f_L(i,k,L0):
    """Evaluate U_0 = L0, U_i = U_{i-1}*(1 - L0/k) + L0 numerically."""
    if i==0:
        return L0
    return (f_L(i-1,k,L0)*(1-L0/k)+L0).n()

#approximation of the threshold required to obtain a list of candidates of size about k*0.95 from gen_candidates
def find_probability(k,w,L_f,p1,p2):
    """Search a threshold and an error weight that give about 0.95*k candidates.

    The threshold starts at 0.00001 and grows in steps of 0.00001 until the
    estimated candidate-list size (via f_L) is at most 0.95*k.

    Returns (prob, wi): the threshold reached and the error weight bound that
    belongs to it. p2 is accepted for a uniform signature but not used.
    """
    bits=int(ceil(log2(k)))
    avg_size=inf
    prob=0.00001
    factor=k/(1<<bits).n()
    # depends on L_f only, so it is computed once instead of once per step
    avg_weight=int(ceil((sum(weight(L_f[i]) for i in range(len(L_f)))/len(L_f)).n()))
    while(avg_size>k*0.95):
        prob+=0.00001

        for i in range(bits-avg_weight):
            ws=p1^i*(1-p1)^(bits-avg_weight-i)
            if ws <prob or i==bits-avg_weight-1:
                wi=i
                break

        avg_elements=sum(binom(bits-avg_weight,i) for i in range(min(wi,6)))
        avg_size=f_L(w//2-1,k,int(factor*avg_elements))*2
    return prob,wi

In [None]:
# Determine complexity for key recovery in compact format for given (or generated) erroneous BIKE key
def complexity_short_rep(k,w,p1,p2,threshold,keys=0,maxw=None):
    """Return (cost, c, wp) for key recovery from a noisy compact-format key.

    k, w      -- block length and number of key indices
    p1, p2    -- flip probabilities of one-bits / zero-bits
    threshold -- score threshold handed to gen_candidates
    keys      -- pair (L, L_f) to analyse; 0 draws a fresh pair with gen_cb_instance
    maxw      -- maximum error weight for gen_candidates; if None it is derived
                 from find_probability(k,w,L_f,p1,p2). Callers that evaluate
                 several thresholds for one key should pass it to avoid
                 repeating that search.

    Returns the time_2_blocks cost, the candidate count c, and the number wp
    of key indices missing from the candidates.
    """
    if keys==0:
        L,L_f=gen_cb_instance(k,w,p1,p2)
    else:
        L,L_f=keys
    if maxw is None:
        maxw=find_probability(k,w,L_f,p1,p2)[1]
    bits=int(ceil(log2(k)))
    # pass the probabilities explicitly so the scores match the key that was generated
    Lcand=gen_candidates(maxw,threshold,bits,L_f,k,p1=p1,p2=p2)
    c,wp=(len(Lcand),missing(Lcand,L,k))
    tmp=time_2_blocks(k,c,w,wp)
    return tmp,c,wp

In [None]:
#generates erroneous BIKE compact key and finds threshold leading to best attack complexity
def complexity_short_rep_increase_prob(k,w,p1,p2,prob_start=0, inc=0):
    """Draw one noisy compact key and return the best cost found over increasing thresholds.

    prob_start -- first threshold to try; 0 derives it from find_probability
    inc        -- threshold step; 0 selects the default step 0.00001

    The threshold is raised step by step. The search stops when the cost is
    more than 1 above the best value seen (with c < k), or when c < k and
    either the best value has been matched more than 3 times or it is below 0.001.
    """
    L,L_f=gen_cb_instance(k,w,p1,p2)
    iprob,maxw=find_probability(k,w,L_f,p1,p2)
    prob=max(iprob-0.0005,iprob/21)
    if p1>0.2:
        prob /=16
    if prob_start!=0:
        prob=prob_start
    step=0.00001 if inc==0 else inc
    t=inf
    eqc=0    # number of thresholds whose cost equalled the best one so far
    while(1):
        tmp,c,wp=complexity_short_rep(k,w,p1,p2,prob,keys=[L,L_f],maxw=maxw)
        if tmp<t:
            t=tmp
        elif tmp==t:
            eqc+=1
        elif c<k and tmp-t>1:
            break
        prob+=step
        if (eqc>3 or t<0.001) and c<k:
            break

    return t

### Experiment on complexity for symmetric error probabilities
Note that the error experiments on the compact format are very slow because of the candidate generation; it is best to run them overnight.

In [None]:
#choose parameter set to use
# k: block length (keys in standard format are vectors of length 2*k)
# w: number of one-entries of the key
k,w=(12323, 142)
#k,w=(40973, 274)

In [None]:
#used approximation of logarithm of polynomial factors of Prange's algorithm
# (added to every iteration count before plotting)
poly_factors =2.8*log2(k)+1

In [None]:
#Perform experiment for symmetric error, compute 10 data points per format (compact / standard)
# Both flip probabilities are set to p1, which runs from 0.025 upwards in steps of 0.025 while p1<0.301.
# Every sample is stored as [p1, log2 iterations + poly_factors].
L_short_sym=[]
L_long_sym=[]

# NOTE(review): keep the loop variable named p1 - code elsewhere in this
# notebook may read p1 as a notebook-level name.
p1=0.025
while p1<0.301:
    pstart=0  # 0 lets complexity_short_rep_increase_prob choose its own start threshold
    for i in range(10):
        tmp_short=complexity_short_rep_increase_prob(k,w,p1,p1,prob_start=pstart)
        tmp_long =complexity_long_rep(k,w,p1,p1)
        L_short_sym.append([p1,tmp_short+poly_factors])
        L_long_sym.append([p1,tmp_long+poly_factors])
    # progress output: last standard-format value for this p1 (without poly_factors)
    print(p1,tmp_long)
    p1+=0.025
    

In [None]:
# raw samples of the symmetric experiment: compact format first, standard format second
plot_lists(L_short_sym,L_long_sym)

In [None]:
# per-p1 averages of the symmetric experiment (A: compact format, B: standard format)
A,B=avg_lists(L_short_sym),avg_lists(L_long_sym)
plot_lists(A,B)

### Experiment on complexity for asymmetric error probabilities

In [None]:
#Perform experiment for asymmetric error (second flip probability fixed to 0.001), compute 1 data point per p1 and format (compact / standard)
# p1 runs from 0.025 upwards in steps of 0.025 while p1<0.301.
# Every sample is stored as [p1, log2 iterations + poly_factors].
L_short_asym=[]
L_long_asym=[]

# NOTE(review): keep the loop variable named p1 - code elsewhere in this
# notebook may read p1 as a notebook-level name.
p1=0.025
while p1<0.301:
    pstart=0  # 0 lets complexity_short_rep_increase_prob choose its own start threshold
    for i in range(1):
        tmp_short=complexity_short_rep_increase_prob(k,w,p1,0.001,prob_start=pstart)
        tmp_long =complexity_long_rep(k,w,p1,0.001)
        L_short_asym.append([p1,tmp_short+poly_factors])
        L_long_asym.append([p1,tmp_long+poly_factors])
    # progress output: last standard-format value for this p1 (without poly_factors)
    print(p1,tmp_long)
    p1+=0.025
    

In [None]:
# raw samples of the asymmetric experiment: compact format first, standard format second
plot_lists(L_short_asym,L_long_asym)

In [None]:
# per-p1 averages of the asymmetric experiment (A: compact format, B: standard format);
# this cell previously averaged the symmetric lists again
A,B=avg_lists(L_short_asym),avg_lists(L_long_asym)
plot_lists(A,B)

### Erasure experiments compact representation

In [None]:
#Generate partially erased BIKE private key in compact format
def simulate_erasure(k,w,p):
    """Draw the erased bit positions for a key of w indices in compact format.

    Every one of the ceil(log2(k)) bits of every index is erased independently
    with probability p.

    Returns a list of w lists; entry i holds the erased bit positions of index
    i in increasing order (empty if nothing was erased).
    """
    x=bernoulli(p)
    bits=int(ceil(log2(k)))
    # one draw per (index, bit), index-major, bit positions ascending
    return [[j for j in range(bits) if x.rvs(1)[0]] for _ in range(w)]

#Generate list of all candidates, based on known key coordinates (L) and erasure positions (Lerasures)
def gen_erasure_candidates(L,Lerasures,k):
    """Enumerate all positions compatible with a partially erased compact key.

    L         -- key indices; only their non-erased bits are used
    Lerasures -- Lerasures[c] lists the erased bit positions of L[c]
    k         -- block length

    For every index with at least one erased bit, all 2^(#erased) completions
    are added. Indices in the first half of L are added as they are, indices
    in the second half shifted by k. Indices without erasures contribute
    nothing to the result.

    Returns the set of candidate positions.
    """
    mask=(1<<int(ceil(log2(k))))-1
    cand=set()
    for c,erased in enumerate(Lerasures):
        if len(erased)==0:
            continue
        # block split taken from the key itself instead of a notebook-level w
        offset=0 if 2*c<len(L) else k
        # known part of the index: every erased bit forced to 0
        base=L[c]
        for b in erased:
            base&=(1<<b)^^mask
        # NOTE(review): completions >= k are kept, unlike gen_candidates which
        # filters with j<k - confirm whether that is intended.
        for j in range(2**len(erased)):
            key_index=base
            for o,b in enumerate(erased):
                if (j>>o)&1:
                    key_index^^=1<<b
            cand.add(key_index+offset)
    return cand

#bit complexity to recover partially erased BIKE secret key
def complexity_erasure(cand,L,Lerasures,k,w):
    """Bit complexity of recovering a partially erased key in compact format.

    cand      -- set of candidate positions
    L         -- key indices (not used by the computation)
    Lerasures -- erased bit positions per key index
    k, w      -- block length and number of key indices

    With fewer than k candidates only the polynomial term 2.8*log2(k) is
    returned. Otherwise the result is the "time" entry of stern_complexity for
    length len(cand), dimension len(cand)-k and weight wp, plus log2(len(cand)),
    where wp is w minus the number of indices without any erasure.
    """
    n_cand=len(cand)
    untouched=sum(1 for erased in Lerasures if len(erased)==0)
    wp=w-untouched
    if n_cand>=k:
        return stern_complexity(n_cand,n_cand-k,wp)["time"]+log2(n_cand)
    return log2(k)*2.8

In [None]:
# Erasure experiment, compact format: 100 samples for every erasure probability p,
# with p running from 0.4 upwards in steps of 0.005 while p<=0.54.
# Every sample is stored as [p, complexity_erasure(...)].
Lcomplexity=[]
p=0.4
while p<=0.54:
    for _ in range(100):
        Lerasures=simulate_erasure(k,w,p)
        # only the error-free key L is used below; the noisy copy Lf is ignored
        L,Lf=gen_cb_instance(k,w,p,p)
        cand=gen_erasure_candidates(L,Lerasures,k)
        Lcomplexity.append([p,complexity_erasure(cand,L,Lerasures,k,w)])
    # progress output
    print(p)
    p+=0.005
    


In [None]:
# drop samples whose complexity is infinite before averaging
Lcomplexity_short_1= [i for i in Lcomplexity if i[1]!=inf]

In [None]:
# per-p averages of the compact-format erasure experiment; the second series is left empty
Lcomplexity_short_avg=avg_lists(Lcomplexity_short_1)
plot_lists(Lcomplexity_short_avg,[])

### Erasure experiments standard format

In [None]:
#Generate partially-erased BIKE key in the standard format
def simulate_erasure_long_rep(k,w,p):
    """Draw the erased positions of a key in standard format.

    Every one of the 2*k positions is erased independently with probability p.
    w is accepted for a uniform signature but not used.

    Returns the erased positions as an increasing list of ints.
    """
    x=bernoulli(p)
    # one vectorised draw for all positions instead of 2*k scalar draws
    erased=x.rvs(2*k)
    return np.flatnonzero(erased).tolist()

#Complexity to recover partially-erased BIKE key in standard format
def complexity_erasure_long(L,Lerasures,k,w):
    """Bit complexity of recovering a partially erased key in standard format.

    L         -- error-free key, indexable by position
    Lerasures -- list of erased positions
    k         -- block length
    w         -- accepted for a uniform signature but not used

    With fewer than k erased positions only the polynomial term 2.8*log2(k) is
    returned. Otherwise the result is the "time" entry of stern_complexity for
    length len(Lerasures), dimension len(Lerasures)-k and weight wp, plus
    log2(len(Lerasures)), where wp is the number of erased positions that hold
    a one in L.
    """
    n_erased=len(Lerasures)
    wp=sum(1 for pos in Lerasures if L[pos]==1)
    if n_erased>=k:
        return stern_complexity(n_erased,n_erased-k,wp)["time"]+log2(n_erased)
    return 2.8*log2(k)

In [None]:
#Generate 10 datapoints for bit complexity of recovery of partially erased BIKE key in standard format
# p is set before its first use so that this cell does not depend on a value
# left over from an earlier cell.
p=0.35
# only the error-free key L is used below; the noisy copy Lf is ignored
L,Lf=gen_cb_instance_long_rep(k,w,p,p)
Lcomplexity_std=[]
while p<=1:
    for _ in range(10):
        Lerasures=simulate_erasure_long_rep(k,w,p)
        Lcomplexity_std.append([p,complexity_erasure_long(L,Lerasures,k,w)])
    # progress output
    print(p)
    p+=0.01

list_plot(Lcomplexity_std)

In [None]:
# drop samples whose complexity is infinite before averaging (replaces the list)
Lcomplexity_std=[i for i in Lcomplexity_std if i[1]!=inf]

In [None]:
# per-p averages of the standard-format erasure experiment, plotted together with
# the compact-format averages (Lcomplexity_short_avg must already exist)
Lcomplexity_std_avg = avg_lists(Lcomplexity_std)
plot_lists(Lcomplexity_std_avg,Lcomplexity_short_avg)

### 